# pythonproject/qrtrain_predict_v1.py
# 2024-08-07 09:46:47 +08:00
#
# 73 lines
# 2.1 KiB
# Python

import os
import cv2
import torch
import torch.nn as nn
from torchvision import models, transforms
from sklearn.svm import SVC
import joblib
from PIL import Image
import numpy as np
# Pick the compute device first so the checkpoint can be mapped onto it.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# Build a ResNet-50 and drop its final fully connected layer so the network
# is used as a feature extractor.
model = models.resnet50(pretrained=False)
modules = list(model.children())[:-1]
model = nn.Sequential(*modules)
# Load the trained feature-extractor weights. map_location keeps a checkpoint
# that was saved on a GPU loadable on a CPU-only machine.
model.load_state_dict(
    torch.load('resnet_feature_extractor.pth', map_location=device)
)
model.eval()
model = model.to(device)
# Load the SVM classifier.
# NOTE(review): joblib.load unpickles the file, so only load a trusted model.
svm = joblib.load('svm_classifier.joblib')
# Image preprocessing: resize to 224x224, convert to a tensor, then normalize
# each channel with the mean/std values below.
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])
# Extract ORB features
def extract_orb_features(image_path):
    """Compute ORB descriptors for the image at *image_path*.

    Args:
        image_path: Path of the image file; it is read as grayscale.

    Returns:
        The descriptor array produced by ``orb.detectAndCompute``, or a
        ``(1, 32)`` array of zeros when no descriptors were found.

    Raises:
        FileNotFoundError: If *image_path* does not exist.
        OSError: If the file exists but OpenCV cannot decode it.
    """
    image = cv2.imread(image_path, cv2.IMREAD_GRAYSCALE)
    if image is None:
        # cv2.imread reports failure by returning None rather than raising,
        # so fail loudly here instead of extracting features from nothing.
        if not os.path.exists(image_path):
            raise FileNotFoundError(f"Image file not found: {image_path}")
        raise OSError(f"Could not decode image: {image_path}")

    orb = cv2.ORB_create()
    _keypoints, descriptors = orb.detectAndCompute(image, None)
    if descriptors is None:
        # No keypoints detected: fall back to a single all-zero descriptor.
        return np.zeros((1, 32))
    return descriptors
# Extract ResNet features
def extract_resnet_features(image):
    """Run *image* through the module-level ResNet feature extractor.

    Args:
        image: Input accepted by the module-level ``transform`` pipeline
            (``predict`` passes an RGB PIL image).

    Returns:
        A NumPy array with one row holding the flattened model output.
    """
    # Preprocess, add the batch dimension, and move to the model's device.
    batch = transform(image).unsqueeze(0).to(device)
    with torch.no_grad():
        raw_output = model(batch)
    # Collapse everything after the batch dimension into one feature vector.
    flattened = raw_output.view(raw_output.shape[0], -1)
    return flattened.cpu().numpy()
# Inference function
def predict(image_path):
    """Classify the QR code image at *image_path* as ``'real'`` or ``'fake'``.

    ORB descriptors are averaged into one global vector, concatenated with
    the ResNet feature vector, and the fused row is passed to the SVM.

    Args:
        image_path: Path of the image file to classify.

    Returns:
        ``'real'`` when the SVM predicts label 0, otherwise ``'fake'``.
    """
    # Average all ORB descriptors into a single-row global feature.
    orb_features = extract_orb_features(image_path)
    orb_features = np.mean(orb_features, axis=0).reshape(1, -1)

    # Open the image in a context manager so the file handle is released
    # deterministically; convert() returns an independent in-memory image.
    with Image.open(image_path) as source_image:
        image = source_image.convert('RGB')
    resnet_features = extract_resnet_features(image)

    # Feature fusion: ResNet features first, then the ORB features.
    features = np.hstack((resnet_features, orb_features))

    prediction = svm.predict(features)
    # features holds a single row, so read its one label explicitly instead
    # of relying on the truth value of an array comparison.
    return 'real' if prediction[0] == 0 else 'fake'
# Example inference call: classify 'demo.png' and print the verdict.
# NOTE(review): these statements are at module level, so they also run
# whenever this file is imported.
image_path = 'demo.png'
result = predict(image_path)
print(f'The QR code is {result}.')