73 lines
2.1 KiB
Python
73 lines
2.1 KiB
Python
import os
|
|
import cv2
|
|
import torch
|
|
import torch.nn as nn
|
|
from torchvision import models, transforms
|
|
from sklearn.svm import SVC
|
|
import joblib
|
|
from PIL import Image
|
|
import numpy as np
|
|
|
|
# Build the ResNet-50 feature extractor: keep every child module except the
# last one (the final fully connected layer), so the network outputs features
# instead of class scores.
# NOTE(review): newer torchvision versions deprecate `pretrained=`; the
# equivalent there is `weights=None`.
model = models.resnet50(pretrained=False)
modules = list(model.children())[:-1]  # drop the final fully connected layer
model = nn.Sequential(*modules)

# Load the trained weights of the truncated network. map_location='cpu' lets a
# checkpoint that was saved from GPU tensors load on a CPU-only machine as
# well; the model is moved to the selected device just below.
# NOTE(review): torch.load unpickles the file, so only load trusted checkpoints.
model.load_state_dict(torch.load('resnet_feature_extractor.pth', map_location='cpu'))
model.eval()  # inference mode (affects layers such as BatchNorm)

# Use the first GPU when CUDA is available, otherwise run on the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model = model.to(device)
|
|
|
|
# Load the trained SVM that classifies the fused (ResNet + ORB) feature rows.
# NOTE(review): joblib.load unpickles the file; only load trusted artifacts.
svm = joblib.load(filename='svm_classifier.joblib')

# Preprocessing for the ResNet branch: resize to 224x224, convert to a tensor,
# then normalize each of the three channels with the mean/std values below.
transform = transforms.Compose(
    [
        transforms.Resize(size=(224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)
|
|
|
|
# Extract ORB features
def extract_orb_features(image_path):
    """Compute ORB descriptors for the image stored at ``image_path``.

    The image is loaded as grayscale and passed to an ORB detector created
    with OpenCV's default settings.

    Args:
        image_path: Path of the image file to read.

    Returns:
        The descriptor array returned by ``detectAndCompute`` (one row per
        detected keypoint), or a ``(1, 32)`` array of zeros when no keypoints
        are detected.

    Raises:
        FileNotFoundError: If ``image_path`` does not point to an existing file.
        OSError: If the file exists but OpenCV cannot decode it.
    """
    image = cv2.imread(image_path, cv2.IMREAD_GRAYSCALE)
    # cv2.imread signals failure by returning None instead of raising. Left
    # unchecked, the None image goes to detectAndCompute. There it can end up
    # in the zero-descriptor fallback below, silently standing in for the ORB
    # features of an image that was never read.
    if image is None:
        if not os.path.isfile(image_path):
            raise FileNotFoundError(f'Image file not found: {image_path}')
        raise OSError(f'OpenCV could not decode image: {image_path}')

    orb = cv2.ORB_create()
    _, descriptors = orb.detectAndCompute(image, None)  # keypoints are unused
    if descriptors is None:
        # No keypoints detected: return one zero row so callers can still
        # average over rows. NOTE(review): 32 presumably matches ORB's default
        # descriptor length in bytes.
        return np.zeros((1, 32))
    return descriptors
|
|
|
|
# Extract ResNet features
def extract_resnet_features(image):
    """Run ``image`` through the ResNet backbone and return flattened features.

    Args:
        image: An image accepted by the module-level ``transform`` (``predict``
            passes a PIL image converted to RGB).

    Returns:
        A NumPy array of shape ``(1, D)``: the backbone output for a batch of
        one image, flattened after the batch dimension.
    """
    # Preprocess, prepend a batch dimension of size 1, and move to the device.
    batch = transform(image).unsqueeze(0).to(device)

    with torch.no_grad():
        output = model(batch)

    # Collapse everything after the batch axis into a single feature axis.
    flat = torch.flatten(output, start_dim=1)
    return flat.cpu().numpy()
|
|
|
|
# Inference
def predict(image_path):
    """Classify the image at ``image_path`` as ``'real'`` or ``'fake'``.

    ORB descriptors and ResNet features are extracted from the same file,
    concatenated into a single row vector, and classified by the SVM.

    Args:
        image_path: Path of the image file to classify.

    Returns:
        ``'real'`` when the SVM predicts label ``0``, otherwise ``'fake'``.
    """
    # ORB yields a variable number of descriptor rows; average them
    # column-wise into one global vector with a leading axis of size 1.
    orb_features = extract_orb_features(image_path)
    orb_features = np.mean(orb_features, axis=0).reshape(1, -1)

    # ResNet features. The context manager closes the file handle
    # deterministically; convert() returns an independent, fully loaded image
    # that stays valid after the file is closed.
    with Image.open(image_path) as pil_image:
        image = pil_image.convert('RGB')
    resnet_features = extract_resnet_features(image)

    # Feature fusion: ResNet columns first, then ORB columns.
    # NOTE(review): presumably this order matches how the SVM was trained.
    features = np.hstack((resnet_features, orb_features))

    # svm.predict returns one label per input row; features has exactly one row.
    prediction = svm.predict(features)
    return 'real' if prediction[0] == 0 else 'fake'
|
|
|
|
# Example inference call. Guarded so that importing this module (for example
# to reuse `predict`) does not run inference on demo.png and print a result.
if __name__ == '__main__':
    image_path = 'demo.png'
    result = predict(image_path)
    print(f'The QR code is {result}.')
|