pythonproject/qrtrain_predict_v1.py

73 lines
2.1 KiB
Python
Raw Permalink Normal View History

2024-08-07 09:46:47 +08:00
import os
import cv2
import torch
import torch.nn as nn
from torchvision import models, transforms
from sklearn.svm import SVC
import joblib
from PIL import Image
import numpy as np
# Build a ResNet-50 backbone and drop its final fully connected layer, so the
# network outputs pooled feature maps instead of class logits.
# NOTE: `pretrained=` is deprecated in torchvision >= 0.13 in favour of
# `weights=None`; it is kept here for compatibility with older releases.
model = models.resnet50(pretrained=False)
modules = list(model.children())[:-1]  # drop the final fully connected layer
model = nn.Sequential(*modules)

# Use the GPU when one is available. The device is chosen *before* the weights
# are loaded so that the checkpoint can be mapped onto it.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Load the trained feature-extractor weights. `map_location` lets a checkpoint
# that was saved from a CUDA device be loaded on a CPU-only machine; without it
# torch.load fails when the saved tensors reference an unavailable device.
# NOTE(review): torch.load unpickles the file -- only load trusted checkpoints.
model.load_state_dict(torch.load('resnet_feature_extractor.pth', map_location=device))
model.eval()
model = model.to(device)

# Load the SVM classifier that predict() applies to the fused feature vector.
# NOTE(review): joblib.load also unpickles -- only load trusted files.
svm = joblib.load('svm_classifier.joblib')

# Preprocessing for the ResNet input: resize to 224x224, convert to a tensor,
# then normalize each channel with the commonly used ImageNet mean/std values.
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])
# Extract ORB features.
def extract_orb_features(image_path):
    """Compute ORB descriptors for the image stored at *image_path*.

    The image is read as grayscale and passed to an ORB detector created with
    OpenCV's default settings.

    Args:
        image_path: Path of the image file to read.

    Returns:
        The descriptor array returned by ``orb.detectAndCompute`` (one row per
        detected keypoint), or a ``(1, 32)`` array of zeros when no keypoints
        are found.

    Raises:
        FileNotFoundError: If OpenCV cannot read an image from *image_path*.
    """
    image = cv2.imread(image_path, cv2.IMREAD_GRAYSCALE)
    if image is None:
        # cv2.imread reports failure by returning None rather than raising;
        # without this check the problem surfaces later as a less descriptive
        # OpenCV error inside detectAndCompute.
        raise FileNotFoundError(f'Could not read image: {image_path}')
    orb = cv2.ORB_create()
    _keypoints, descriptors = orb.detectAndCompute(image, None)
    if descriptors is None:
        # No keypoints detected: return a single all-zero row (32 columns,
        # presumably matching the ORB descriptor width) so callers can still
        # average over rows.
        return np.zeros((1, 32))
    return descriptors
# Extract ResNet features.
def extract_resnet_features(image):
    """Return the flattened backbone embedding of a single image.

    The image (a PIL image when called from predict()) is preprocessed with the
    module-level ``transform`` and wrapped as a batch of one. It is moved to
    ``device`` and run through ``model`` with gradient tracking disabled.

    Returns:
        A NumPy array of shape ``(1, D)``, where ``D`` is the number of values
        the backbone produces for one image.
    """
    batch = transform(image).unsqueeze(0).to(device)  # prepend the batch axis
    with torch.no_grad():
        embedding = model(batch)
    # Collapse every non-batch dimension into one feature vector per image.
    flat = embedding.flatten(start_dim=1)
    return flat.cpu().numpy()
# Inference entry point.
def predict(image_path):
    """Classify the QR-code image at *image_path* as ``'real'`` or ``'fake'``.

    ORB descriptors are averaged into one global vector. That vector is
    appended after the ResNet embedding, and ``svm`` classifies the fused
    vector.

    Args:
        image_path: Path of the image file to classify.

    Returns:
        ``'real'`` when the SVM predicts label 0, otherwise ``'fake'``.
    """
    # ORB branch: average the per-keypoint descriptors into a single row vector.
    orb_features = extract_orb_features(image_path)
    orb_features = np.mean(orb_features, axis=0).reshape(1, -1)

    # ResNet branch. Image.open keeps the file open for lazy loading; the
    # context manager closes it, and convert() has already loaded the pixels.
    with Image.open(image_path) as img:
        image = img.convert('RGB')
    resnet_features = extract_resnet_features(image)

    # Feature fusion: ResNet first, then ORB. Presumably this must match the
    # order used when the SVM was trained -- TODO confirm against training code.
    features = np.hstack((resnet_features, orb_features))

    # svm.predict returns one label per input row; read the single row's label
    # explicitly instead of truth-testing a one-element array.
    label = svm.predict(features)[0]
    return 'real' if label == 0 else 'fake'
# Example inference call. It is guarded so that importing this module (e.g. to
# reuse predict()) does not run a prediction as a side effect. An image path
# may be given as the first command-line argument; 'demo.png' is the default.
if __name__ == '__main__':
    import sys

    image_path = sys.argv[1] if len(sys.argv) > 1 else 'demo.png'
    result = predict(image_path)
    print(f'The QR code is {result}.')