pythonproject/qrtrain_v1.py
2024-08-07 09:46:47 +08:00

121 lines
4.2 KiB
Python

import os

import cv2
import joblib
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from PIL import Image
from sklearn.metrics import accuracy_score
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import StandardScaler
from sklearn.svm import SVC
from torch.utils.data import DataLoader, Dataset
from torchvision import models, transforms
from tqdm import tqdm
# Custom dataset class
class QRCodeDataset(Dataset):
    """Real/fake QR-code image dataset backed by a two-folder directory layout.

    Every image file under ``root_dir/real`` gets label 0 and every image file
    under ``root_dir/fake`` gets label 1. Files are listed in sorted order within
    each folder ('real' first, then 'fake'), so sample indices are stable across
    runs and filesystems.

    Args:
        root_dir: directory containing the ``real`` and ``fake`` subdirectories.
        transform: optional callable applied to each PIL image in ``__getitem__``.

    Raises:
        FileNotFoundError: (from ``os.listdir``) if either subdirectory is missing.
    """

    # File extensions (compared case-insensitively) that are treated as images.
    IMAGE_EXTENSIONS = ('.png', '.jpg', '.jpeg')

    def __init__(self, root_dir, transform=None):
        self.root_dir = root_dir
        self.transform = transform
        self.image_paths = []  # path of every image, indexed like the dataset
        self.labels = []  # 0 = real, 1 = fake; parallel to image_paths
        for label in ['real', 'fake']:
            label_dir = os.path.join(root_dir, label)
            # sorted(): os.listdir order is arbitrary and filesystem-dependent.
            for img_name in sorted(os.listdir(label_dir)):
                # Only keep image files; lower() so '.PNG' / '.JPG' are not skipped.
                if img_name.lower().endswith(self.IMAGE_EXTENSIONS):
                    self.image_paths.append(os.path.join(label_dir, img_name))
                    self.labels.append(0 if label == 'real' else 1)

    def __len__(self):
        """Return the number of images found."""
        return len(self.image_paths)

    def __getitem__(self, idx):
        """Return ``(image, label)`` for sample *idx*.

        The image is loaded as 3-channel RGB and then passed through
        ``transform`` if one was given.
        """
        img_path = self.image_paths[idx]
        # convert('RGB'): grayscale, palette or RGBA files would otherwise yield a
        # tensor whose channel count does not match the 3-channel Normalize /
        # ResNet input. The context manager closes the file handle that
        # Image.open keeps open for lazy loading.
        with Image.open(img_path) as img:
            image = img.convert('RGB')
        label = self.labels[idx]
        if self.transform:
            image = self.transform(image)
        return image, label
# Data preprocessing: resize to 224x224 and normalise with the standard
# ImageNet per-channel mean/std (matching the pretrained backbone).
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])
# Load the datasets.
# Root of the data directory; override with the QR_DATA_ROOT environment variable.
DATA_ROOT = os.environ.get('QR_DATA_ROOT', '/Users/jasonwong/Downloads/二维码测试1/data')
train_dataset = QRCodeDataset(root_dir=os.path.join(DATA_ROOT, 'train'), transform=transform)
# shuffle=False is required: this loader is only used for feature extraction, and
# the ResNet feature rows must stay in train_dataset.image_paths order so they line
# up with the ORB features (computed by iterating image_paths) when the two are
# stacked side by side.
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=False)
val_dataset = QRCodeDataset(root_dir=os.path.join(DATA_ROOT, 'val'), transform=transform)
val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False)
# Load an ImageNet-pretrained ResNet-50 and drop its final fully-connected layer
# so the network outputs pooled feature maps instead of class logits.
try:
    # torchvision >= 0.13 deprecates `pretrained=` in favour of `weights=`.
    # IMAGENET1K_V1 is the checkpoint `pretrained=True` loaded, so the extracted
    # features are unchanged.
    model = models.resnet50(weights=models.ResNet50_Weights.IMAGENET1K_V1)
except AttributeError:
    # Older torchvision without the weights enums.
    model = models.resnet50(pretrained=True)
modules = list(model.children())[:-1]  # every child except the final fc layer
model = nn.Sequential(*modules)
# Use the GPU when one is available.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model = model.to(device)
# ORB feature extraction
def extract_orb_features(image_path):
    """Compute ORB descriptors for the image at *image_path*.

    The file is decoded as a grayscale image and passed to an OpenCV ORB
    detector with default settings.

    Args:
        image_path: path to an image file readable by OpenCV.

    Returns:
        The descriptor matrix from ``detectAndCompute`` (one row per keypoint),
        or a ``(1, 32)`` zero array when no keypoints are detected.

    Raises:
        FileNotFoundError: if *image_path* does not exist (from ``np.fromfile``).
        ValueError: if the file exists but cannot be decoded as an image.
    """
    # np.fromfile + cv2.imdecode instead of cv2.imread: imread returns None
    # silently for unreadable files and cannot open non-ASCII paths on some
    # platforms (notably Windows); this project's data directory name is Chinese.
    buffer = np.fromfile(image_path, dtype=np.uint8)
    image = cv2.imdecode(buffer, cv2.IMREAD_GRAYSCALE) if buffer.size else None
    if image is None:
        # Fail clearly here rather than with an opaque assertion inside ORB.
        raise ValueError(f"Could not decode image: {image_path}")
    orb = cv2.ORB_create()
    _, descriptors = orb.detectAndCompute(image, None)
    if descriptors is None:
        return np.zeros((1, 32))  # no keypoints detected: return a single zero row
    return descriptors
# ResNet feature extraction
def extract_resnet_features(loader, net=None, dev=None):
    """Run every batch of *loader* through a feature extractor and collect the outputs.

    Args:
        loader: iterable yielding ``(inputs, labels)`` batches, where ``inputs``
            is a tensor accepted by the network and ``labels`` a tensor.
        net: feature-extraction module. Defaults to the module-level ``model``.
        dev: device to run on. Defaults to the module-level ``device``.

    Returns:
        ``(features, labels)``: ``features`` is a CPU tensor with one flattened
        feature row per sample, ``labels`` the concatenated label batches; both
        are in the order the loader yielded them.

    Raises:
        ValueError: if *loader* yields no batches.
    """
    net = model if net is None else net
    dev = device if dev is None else dev
    net.eval()  # inference mode (BatchNorm uses running statistics)
    features = []
    labels = []
    with torch.no_grad():
        for inputs, label in tqdm(loader, desc="Extracting ResNet features"):
            output = net(inputs.to(dev))
            # Flatten everything after the batch dimension into one feature vector.
            features.append(torch.flatten(output, 1).cpu())
            labels.append(label)
    if not features:
        # torch.cat([]) would otherwise fail with an unhelpful message.
        raise ValueError("loader yielded no batches; cannot extract features")
    return torch.cat(features), torch.cat(labels)
# Pooled ResNet embeddings plus their labels for both splits (train is extracted
# first, then val).
(train_resnet_features, train_labels), (val_resnet_features, val_labels) = (
    extract_resnet_features(train_loader),
    extract_resnet_features(val_loader),
)
# Dataset-level ORB features for the training and validation splits.
def extract_orb_features_from_dataset(dataset):
    """Summarise every image of *dataset* as the mean of its ORB descriptors.

    Args:
        dataset: object exposing an ``image_paths`` sequence (e.g. QRCodeDataset).

    Returns:
        np.ndarray with one row per entry of ``dataset.image_paths``, in that
        order; each row is the column-wise mean of the image's descriptor matrix.
    """
    progress = tqdm(dataset.image_paths, desc="Extracting ORB features")
    # Averaging over keypoints (axis 0) turns a variable number of descriptors
    # into one fixed-length global vector per image.
    # NOTE(review): this averages the raw bytes of binary ORB descriptors, which
    # is a crude summary; a bag-of-words or bit-unpacked mean may work better.
    return np.array([np.mean(extract_orb_features(path), axis=0) for path in progress])
train_orb_features = extract_orb_features_from_dataset(train_dataset)
val_orb_features = extract_orb_features_from_dataset(val_dataset)
# Feature fusion.
# ResNet rows are in DataLoader order while ORB rows follow dataset.image_paths,
# so the two only describe the same images if the loaders did not shuffle.
# Comparing the loader's label sequence with the dataset's is a cheap guard
# against silently fusing mismatched rows (it cannot catch a reordering that
# happens to preserve the label sequence).
for split_name, split_labels, split_dataset in (
    ('train', train_labels, train_dataset),
    ('val', val_labels, val_dataset),
):
    if not np.array_equal(split_labels.numpy(), np.asarray(split_dataset.labels)):
        raise RuntimeError(
            f'{split_name}: ResNet features are not in dataset order; '
            'extract them with a DataLoader created with shuffle=False.'
        )
train_features = np.hstack((train_resnet_features.numpy(), train_orb_features))
val_features = np.hstack((val_resnet_features.numpy(), val_orb_features))
# Train and evaluate the SVM classifier.
# The ORB part (means of uint8 descriptor bytes, range 0-255) and the ResNet
# activations live on very different scales, and SVMs are not scale-invariant,
# so standardise every feature first. The pipeline exposes the same fit/predict
# API as a bare SVC.
svm = make_pipeline(StandardScaler(), SVC(kernel='linear'))
svm.fit(train_features, train_labels.numpy())
val_preds = svm.predict(val_features)
val_acc = accuracy_score(val_labels.numpy(), val_preds)
print(f'Validation Accuracy: {val_acc:.4f}')
# Save the ResNet feature-extractor weights.
torch.save(model.state_dict(), 'resnet_feature_extractor.pth')
# Save the classifier (scaler + SVM pipeline); call .predict on raw fused features.
joblib.dump(svm, 'svm_classifier.joblib')