"""Train an SVM on fused ResNet-50 + ORB features to separate real/fake QR codes.

Pipeline (runs at import/execution time, top to bottom):
  1. Load the train/val image folders (``real`` -> label 0, ``fake`` -> label 1).
  2. Extract one ResNet-50 feature vector per image (final FC layer removed).
  3. Extract one mean ORB descriptor per image.
  4. Concatenate both feature sets, fit a linear SVM, report validation
     accuracy, and save the feature extractor weights and the SVM.
"""
import os

import cv2
import joblib
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from PIL import Image
from sklearn.metrics import accuracy_score
from sklearn.svm import SVC
from torch.utils.data import DataLoader, Dataset
from torchvision import models, transforms
from tqdm import tqdm


# Custom dataset class
class QRCodeDataset(Dataset):
    """Image dataset read from ``root_dir/real`` and ``root_dir/fake``.

    Attributes:
        root_dir: Directory containing the ``real`` and ``fake`` sub-folders.
        transform: Optional callable applied to each loaded PIL image.
        image_paths: Full paths of all collected image files.
        labels: Integer labels parallel to ``image_paths``
            (0 for ``real``, 1 for ``fake``).
    """

    def __init__(self, root_dir, transform=None):
        self.root_dir = root_dir
        self.transform = transform
        self.image_paths = []
        self.labels = []
        for label in ['real', 'fake']:
            label_dir = os.path.join(root_dir, label)
            # os.listdir() order is arbitrary; sort so that sample indices are
            # reproducible between runs and between feature extractors.
            for img_name in sorted(os.listdir(label_dir)):
                # Only keep image files; compare case-insensitively so that
                # e.g. ".PNG" / ".JPG" are not silently dropped.
                if img_name.lower().endswith(('.png', '.jpg', '.jpeg')):
                    self.image_paths.append(os.path.join(label_dir, img_name))
                    self.labels.append(0 if label == 'real' else 1)

    def __len__(self):
        """Return the number of collected images."""
        return len(self.image_paths)

    def __getitem__(self, idx):
        """Return ``(image, label)`` for sample ``idx``.

        The image is always converted to 3-channel RGB before the optional
        transform is applied.
        """
        img_path = self.image_paths[idx]
        # Grayscale / palette / RGBA files would otherwise produce 1- or
        # 4-channel tensors, which do not match the 3-channel Normalize and
        # ResNet input. convert() also forces the pixel data to be loaded, so
        # the file handle can be closed right away.
        with Image.open(img_path) as img:
            image = img.convert('RGB')
        label = self.labels[idx]
        if self.transform:
            image = self.transform(image)
        return image, label


# Data preprocessing
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

# Load the datasets
train_dataset = QRCodeDataset(root_dir='/Users/jasonwong/Downloads/二维码测试1/data/train', transform=transform)
# shuffle must be False: this loader is only used for feature extraction, and
# the ORB features below are computed in dataset.image_paths order. A shuffled
# loader would pair each image's ResNet features (and label) with the ORB
# features of a different image when the two are concatenated.
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=False)
val_dataset = QRCodeDataset(root_dir='/Users/jasonwong/Downloads/二维码测试1/data/val', transform=transform)
val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False)

# Load the pretrained ResNet model and drop its last layer.
# NOTE(review): newer torchvision releases prefer the `weights=` argument;
# `pretrained=True` is kept for compatibility with older installs.
model = models.resnet50(pretrained=True)
modules = list(model.children())[:-1]  # remove the final fully connected layer
model = nn.Sequential(*modules)

# Use the GPU when available
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model = model.to(device)


# Extract ORB features
def extract_orb_features(image_path, orb=None):
    """Return the ORB descriptors of the image at ``image_path``.

    Args:
        image_path: Path of the image file; it is read as grayscale.
        orb: Optional ORB detector to reuse. A new one is created when omitted.

    Returns:
        The descriptor array produced by ``detectAndCompute``, or a
        ``(1, 32)`` zero array when the image cannot be read or no
        descriptors are found.
    """
    if orb is None:
        orb = cv2.ORB_create()
    image = cv2.imread(image_path, cv2.IMREAD_GRAYSCALE)
    if image is None:
        # cv2.imread returns None instead of raising when the file cannot be
        # read or decoded; fall back to the same zero descriptor used for
        # images without keypoints.
        return np.zeros((1, 32))
    _keypoints, descriptors = orb.detectAndCompute(image, None)
    if descriptors is None:
        # No keypoints detected: return a zero array
        return np.zeros((1, 32))
    return descriptors


# Extract ResNet features
def extract_resnet_features(loader):
    """Run the module-level ``model`` over ``loader`` and collect its outputs.

    Args:
        loader: Iterable yielding ``(inputs, label)`` batches.

    Returns:
        ``(features, labels)``: the per-batch outputs flattened to one row per
        sample and concatenated, and the concatenated labels, both in the
        order the loader yields them.
    """
    model.eval()
    features = []
    labels = []
    with torch.no_grad():
        for inputs, label in tqdm(loader, desc="Extracting ResNet features"):
            inputs = inputs.to(device)
            output = model(inputs)
            output = output.view(output.size(0), -1)  # flatten the feature vector
            features.append(output.cpu())
            labels.append(label)
    return torch.cat(features), torch.cat(labels)


# Extract ResNet features from the training and validation datasets
train_resnet_features, train_labels = extract_resnet_features(train_loader)
val_resnet_features, val_labels = extract_resnet_features(val_loader)


# Extract ORB features from the training and validation datasets
def extract_orb_features_from_dataset(dataset):
    """Return one global ORB feature per image in ``dataset.image_paths``.

    Each image's descriptors are averaged over axis 0 to give a fixed-length
    vector. Rows follow the order of ``dataset.image_paths``.
    """
    orb = cv2.ORB_create()  # create the detector once and reuse it for every image
    features = []
    for img_path in tqdm(dataset.image_paths, desc="Extracting ORB features"):
        orb_features = extract_orb_features(img_path, orb)
        features.append(np.mean(orb_features, axis=0))  # mean as the global feature
    return np.array(features)


train_orb_features = extract_orb_features_from_dataset(train_dataset)
val_orb_features = extract_orb_features_from_dataset(val_dataset)

# Feature fusion: row i of both arrays refers to image i because both
# extractors walk the dataset in image_paths order.
train_features = np.hstack((train_resnet_features.numpy(), train_orb_features))
val_features = np.hstack((val_resnet_features.numpy(), val_orb_features))

# Train and evaluate an SVM classifier.
# Labels are converted to NumPy explicitly rather than handing torch tensors
# to scikit-learn.
train_targets = train_labels.numpy()
val_targets = val_labels.numpy()
svm = SVC(kernel='linear')
svm.fit(train_features, train_targets)
val_preds = svm.predict(val_features)
val_acc = accuracy_score(val_targets, val_preds)
print(f'Validation Accuracy: {val_acc:.4f}')

# Save the ResNet feature extractor weights
torch.save(model.state_dict(), 'resnet_feature_extractor.pth')

# Save the SVM classifier
joblib.dump(svm, 'svm_classifier.joblib')