import os
import cv2
import torch
import torch.nn as nn
from torchvision import models, transforms
from sklearn.svm import SVC
import joblib
from PIL import Image
import numpy as np

# Use the GPU when available. The device is chosen before any weights are
# loaded so torch.load can remap checkpoint tensors onto it (a checkpoint saved
# on a GPU otherwise fails to load on a CPU-only machine).
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Build a ResNet-50 backbone and drop its last child (the fully connected
# layer), so the network outputs pooled features instead of class logits.
# NOTE(review): `pretrained=` is deprecated in torchvision >= 0.13 in favour of
# `weights=None`; kept as-is for compatibility with older torchvision.
model = models.resnet50(pretrained=False)
modules = list(model.children())[:-1]  # every child except the final FC layer
model = nn.Sequential(*modules)

# Load the trained feature-extractor weights. load_state_dict is strict by
# default, so the checkpoint keys must match this nn.Sequential wrapper.
# SECURITY: torch.load unpickles the file — only load trusted checkpoints
# (consider weights_only=True on torch >= 1.13).
model.load_state_dict(torch.load('resnet_feature_extractor.pth', map_location=device))
model.eval()
model = model.to(device)

# Load the trained SVM classifier.
# SECURITY: joblib.load also unpickles — only load trusted files.
svm = joblib.load('svm_classifier.joblib')

# Preprocessing for the ResNet branch: fixed 224x224 resize, tensor
# conversion, then per-channel normalization with the constants below.
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

# One ORB detector reused for every image instead of being rebuilt per call.
_orb = cv2.ORB_create()


def extract_orb_features(image_path: str) -> np.ndarray:
    """Return the ORB descriptors of the image at *image_path*.

    The image is read in grayscale. Each row of the result is one keypoint
    descriptor; when no keypoints are detected a (1, 32) array of zeros is
    returned so downstream averaging still yields a 32-wide vector.

    Raises:
        OSError: if OpenCV cannot read the image file.
    """
    image = cv2.imread(image_path, cv2.IMREAD_GRAYSCALE)
    if image is None:
        # cv2.imread signals failure by returning None rather than raising;
        # without this check the error surfaces later as an opaque cv2.error.
        raise OSError(f"Could not read image file: {image_path}")
    _keypoints, descriptors = _orb.detectAndCompute(image, None)
    if descriptors is None:
        return np.zeros((1, 32))  # no keypoints detected
    return descriptors


def extract_resnet_features(image: Image.Image) -> np.ndarray:
    """Return the flattened ResNet feature vector for a PIL *image*.

    The result has shape (1, num_features): one row for the single image.
    """
    batch = transform(image).unsqueeze(0)  # add the batch dimension
    batch = batch.to(device)
    with torch.no_grad():
        features = model(batch)
    features = features.view(features.size(0), -1)  # flatten to (batch, features)
    return features.cpu().numpy()


def predict(image_path: str) -> str:
    """Classify the image at *image_path* as ``'real'`` or ``'fake'``.

    Fuses the ResNet feature vector with the mean ORB descriptor (a global
    hand-crafted feature) and feeds the concatenation to the SVM. SVM label 0
    maps to ``'real'``; any other label maps to ``'fake'``.
    """
    # Hand-crafted branch: average all ORB descriptors into one row vector.
    orb_features = extract_orb_features(image_path)
    orb_features = np.mean(orb_features, axis=0).reshape(1, -1)

    # Deep branch: the context manager closes the file handle Image.open keeps.
    with Image.open(image_path) as img:
        image = img.convert('RGB')
    resnet_features = extract_resnet_features(image)

    # Feature fusion: [resnet | orb] along the feature axis.
    features = np.hstack((resnet_features, orb_features))

    # svm.predict returns one label per row; compare the single scalar label
    # rather than relying on the truthiness of a 1-element array.
    prediction = svm.predict(features)
    return 'real' if prediction[0] == 0 else 'fake'


if __name__ == '__main__':
    import sys

    # Optional image path on the command line; defaults to the demo image.
    image_path = sys.argv[1] if len(sys.argv) > 1 else 'demo.png'
    result = predict(image_path)
    print(f'The QR code is {result}.')