121 lines
4.2 KiB
Python
121 lines
4.2 KiB
Python
![]() |
import cv2
|
||
|
import torch
|
||
|
import torch.nn as nn
|
||
|
import torch.optim as optim
|
||
|
from torchvision import models, transforms
|
||
|
from torch.utils.data import DataLoader, Dataset
|
||
|
from sklearn.svm import SVC
|
||
|
from sklearn.metrics import accuracy_score
|
||
|
from PIL import Image
|
||
|
import numpy as np
|
||
|
import os
|
||
|
from tqdm import tqdm
|
||
|
import joblib
|
||
|
|
||
|
# Custom dataset: images stored in one sub-directory per class.
class QRCodeDataset(Dataset):
    """Image dataset of QR codes labelled by the sub-directory they live in.

    ``root_dir`` must contain one sub-directory per entry of ``classes``
    (``real`` and ``fake`` by default). Every file in those directories with a
    recognised image extension becomes one sample whose label is the index of
    its class in ``classes`` (so by default ``real`` -> 0 and ``fake`` -> 1).

    Args:
        root_dir: Directory holding the per-class sub-directories.
        transform: Optional callable applied to each RGB PIL image in
            ``__getitem__`` (e.g. a torchvision ``Compose``).
        classes: Sub-directory names, in label order.

    Raises:
        FileNotFoundError: If a class sub-directory does not exist.
    """

    # Compared against the lower-cased file name, so ``photo.JPG`` matches too.
    IMAGE_EXTENSIONS = ('.png', '.jpg', '.jpeg')

    def __init__(self, root_dir, transform=None, classes=('real', 'fake')):
        self.root_dir = root_dir
        self.transform = transform
        self.image_paths = []
        self.labels = []

        for label_index, class_name in enumerate(classes):
            label_dir = os.path.join(root_dir, class_name)
            # Sort so the sample order does not depend on the filesystem's
            # directory-listing order (keeps runs reproducible).
            for img_name in sorted(os.listdir(label_dir)):
                if img_name.lower().endswith(self.IMAGE_EXTENSIONS):
                    self.image_paths.append(os.path.join(label_dir, img_name))
                    self.labels.append(label_index)

    def __len__(self):
        """Return the number of image samples found."""
        return len(self.image_paths)

    def __getitem__(self, idx):
        """Return ``(image, label)`` for sample *idx*.

        The image is decoded as 3-channel RGB before ``transform`` is applied.
        """
        img_path = self.image_paths[idx]
        # Convert to RGB: grayscale, palette or RGBA files would otherwise
        # reach the 3-channel Normalize / ResNet input with the wrong number
        # of channels. The context manager closes the file handle that
        # Image.open keeps open for lazy loading; convert() returns a fully
        # loaded copy that stays valid after the file is closed.
        with Image.open(img_path) as img:
            image = img.convert('RGB')
        label = self.labels[idx]

        if self.transform is not None:
            image = self.transform(image)

        return image, label
|
||
|
|
||
|
# Preprocessing applied to every image: resize to 224x224, convert to a float
# tensor, then normalise each channel with the standard ImageNet mean/std
# statistics.
transform = transforms.Compose(
    [
        transforms.Resize(size=(224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)
|
||
|
|
||
|
# Load the datasets.
# The data root can be overridden with the QR_DATA_ROOT environment variable;
# it defaults to the original author's local path.
DATA_ROOT = os.environ.get('QR_DATA_ROOT', '/Users/jasonwong/Downloads/二维码测试1/data')
BATCH_SIZE = 32

train_dataset = QRCodeDataset(root_dir=os.path.join(DATA_ROOT, 'train'), transform=transform)
# shuffle=False is required: the ORB features are extracted later in
# train_dataset.image_paths order and concatenated row-by-row with the ResNet
# features/labels produced by this loader. Shuffling here would pair each
# sample's ResNet features and label with another image's ORB features.
# (No network is trained here, so shuffling gives no benefit anyway.)
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=False)

val_dataset = QRCodeDataset(root_dir=os.path.join(DATA_ROOT, 'val'), transform=transform)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False)
|
||
|
|
||
|
# Load an ImageNet-pretrained ResNet-50 and remove its last layer (the fully
# connected classifier), so the network outputs pooled feature maps.
_resnet50_weights = getattr(models, 'ResNet50_Weights', None)
if _resnet50_weights is not None:
    # torchvision >= 0.13 deprecates `pretrained=True`; IMAGENET1K_V1 is the
    # exact checkpoint that `pretrained=True` selects, so results are unchanged.
    model = models.resnet50(weights=_resnet50_weights.IMAGENET1K_V1)
else:
    # Older torchvision without the multi-weight API.
    model = models.resnet50(pretrained=True)
modules = list(model.children())[:-1]  # drop the final fully connected layer
model = nn.Sequential(*modules)

# Use the first CUDA GPU when one is available, otherwise the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model = model.to(device)
|
||
|
|
||
|
# Extract ORB features.
def extract_orb_features(image_path, orb=None):
    """Compute ORB descriptors for the image stored at *image_path*.

    The image is read as single-channel grayscale before detection.

    Args:
        image_path: Path to an image file readable by ``cv2.imread``.
        orb: Optional pre-built ORB detector to reuse across calls; a new
            ``cv2.ORB_create()`` detector with default settings is used when
            omitted.

    Returns:
        The descriptor matrix from ``orb.detectAndCompute``, or a ``(1, 32)``
        all-zero array when no keypoints are detected.

    Raises:
        ValueError: If OpenCV cannot read or decode the file.
    """
    image = cv2.imread(image_path, cv2.IMREAD_GRAYSCALE)
    if image is None:
        # cv2.imread signals failure by returning None instead of raising;
        # passing None on would fail later inside detectAndCompute with an
        # unhelpful message.
        raise ValueError(f"Could not read image: {image_path}")

    if orb is None:
        orb = cv2.ORB_create()
    _keypoints, descriptors = orb.detectAndCompute(image, None)
    if descriptors is None:
        # No keypoints detected: return one all-zero row so callers that
        # average over rows still get a vector of width 32.
        return np.zeros((1, 32))
    return descriptors
|
||
|
|
||
|
# Extract ResNet features.
def extract_resnet_features(loader, feature_model=None, target_device=None):
    """Run every batch of *loader* through the backbone and collect its outputs.

    Args:
        loader: Iterable of ``(inputs, labels)`` batches, e.g. a ``DataLoader``
            over ``QRCodeDataset``.
        feature_model: Network to run; defaults to the module-level ``model``.
        target_device: Device the inputs are moved to; defaults to the
            module-level ``device``.

    Returns:
        ``(features, labels)``: a CPU tensor with one flattened feature row per
        sample, in loader order, and the concatenated label batches.

    Raises:
        ValueError: If *loader* yields no batches.
    """
    net = model if feature_model is None else feature_model
    dev = device if target_device is None else target_device

    net.eval()  # inference behaviour for BatchNorm/Dropout layers
    feature_batches = []
    label_batches = []
    with torch.no_grad():
        for inputs, label in tqdm(loader, desc="Extracting ResNet features"):
            output = net(inputs.to(dev))
            # Flatten all dimensions after the batch dimension into one
            # feature vector per sample.
            feature_batches.append(torch.flatten(output, start_dim=1).cpu())
            label_batches.append(label)

    if not feature_batches:
        raise ValueError("loader yielded no batches; cannot extract features")
    return torch.cat(feature_batches), torch.cat(label_batches)


# Extract ResNet features from the training and validation sets.
train_resnet_features, train_labels = extract_resnet_features(train_loader)
val_resnet_features, val_labels = extract_resnet_features(val_loader)
|
||
|
|
||
|
# Extract ORB features from the training and validation datasets.
def extract_orb_features_from_dataset(dataset):
    """Build a matrix of pooled ORB descriptors, one row per image.

    Images are visited in ``dataset.image_paths`` order. For each image, the
    descriptor rows returned by ``extract_orb_features`` are averaged
    column-wise into a single global vector.
    """
    # NOTE(review): ORB descriptors are packed binary strings, so a byte-wise
    # mean is only a coarse summary; unpacking the bits first may work better.
    pooled_rows = [
        np.mean(extract_orb_features(path), axis=0)
        for path in tqdm(dataset.image_paths, desc="Extracting ORB features")
    ]
    return np.array(pooled_rows)


train_orb_features, val_orb_features = (
    extract_orb_features_from_dataset(split) for split in (train_dataset, val_dataset)
)
|
||
|
|
||
|
# Feature fusion: place each sample's ORB vector to the right of its ResNet
# feature vector (column-wise concatenation of the two feature matrices).
train_features = np.concatenate([train_resnet_features.numpy(), train_orb_features], axis=1)
val_features = np.concatenate([val_resnet_features.numpy(), val_orb_features], axis=1)
|
||
|
|
||
|
# Train and evaluate the SVM classifier.
# SVMs are not scale-invariant, and the fused vectors mix ResNet activations
# with byte-wise ORB means (typically in the 0-255 range), so every column is
# standardised before fitting. The scaler lives inside the pipeline, so
# `svm.predict` still takes raw fused features and the saved model carries
# its own scaling.
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import StandardScaler

svm = make_pipeline(StandardScaler(), SVC(kernel='linear'))
svm.fit(train_features, train_labels.numpy())

val_preds = svm.predict(val_features)
val_acc = accuracy_score(val_labels.numpy(), val_preds)
print(f'Validation Accuracy: {val_acc:.4f}')
|
||
|
|
||
|
# Persist the artefacts to the current working directory. First, the weights
# of the truncated ResNet feature extractor: only the state_dict is stored,
# so the same nn.Sequential must be rebuilt before load_state_dict. Tensors
# keep their device; pass map_location to torch.load on CPU-only machines.
torch.save(obj=model.state_dict(), f='resnet_feature_extractor.pth')

# Then the fitted SVM classifier, serialised with joblib.
joblib.dump(value=svm, filename='svm_classifier.joblib')
|