# pythonproject/qrtrain_v1.py
#
# 121 lines
# 4.2 KiB
# Python
# Raw Permalink Normal View History
#
# 2024-08-07 09:46:47 +08:00
import cv2
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import models, transforms
from torch.utils.data import DataLoader, Dataset
from sklearn.svm import SVC
from sklearn.metrics import accuracy_score
from PIL import Image
import numpy as np
import os
from tqdm import tqdm
import joblib
# Custom dataset class
class QRCodeDataset(Dataset):
    """Binary image dataset of real vs. fake QR codes.

    ``root_dir`` must contain two sub-directories, ``real`` and ``fake``.
    Every file directly inside them whose name ends in one of
    ``IMG_EXTENSIONS`` (case-insensitive) becomes one sample. Labels are
    ``0`` for ``real`` and ``1`` for ``fake``.

    Args:
        root_dir: Directory holding the ``real``/``fake`` sub-directories.
        transform: Optional callable applied to each PIL image in
            ``__getitem__`` (e.g. a torchvision ``Compose``).
    """

    # Recognised image suffixes; compared against the lower-cased file name.
    IMG_EXTENSIONS = ('.png', '.jpg', '.jpeg')

    def __init__(self, root_dir, transform=None):
        self.root_dir = root_dir
        self.transform = transform
        self.image_paths = []
        self.labels = []
        for label in ['real', 'fake']:
            label_dir = os.path.join(root_dir, label)
            # sorted() makes sample order independent of the filesystem's
            # directory-listing order, so runs are reproducible.
            for img_name in sorted(os.listdir(label_dir)):
                # Only image files; lower() so e.g. 'IMG_01.JPG' is not skipped.
                if img_name.lower().endswith(self.IMG_EXTENSIONS):
                    self.image_paths.append(os.path.join(label_dir, img_name))
                    self.labels.append(0 if label == 'real' else 1)

    def __len__(self):
        """Return the number of image samples found."""
        return len(self.image_paths)

    def __getitem__(self, idx):
        """Return ``(image, label)`` for sample ``idx``.

        The image is converted to 3-channel RGB so that grayscale, palette
        and RGBA files all work with a 3-channel ``Normalize``. The
        underlying file is closed right away rather than left to the
        garbage collector.
        """
        img_path = self.image_paths[idx]
        with Image.open(img_path) as img:
            # convert() returns a new, fully loaded image, so it stays valid
            # after the source file is closed.
            image = img.convert('RGB')
        label = self.labels[idx]
        if self.transform:
            image = self.transform(image)
        return image, label
# Data preprocessing
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    # Standard ImageNet channel mean/std, matching the pretrained ResNet weights.
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

# Load the datasets. The data root defaults to the original author's path and
# can be overridden with the QRCODE_DATA_ROOT environment variable, so the
# script runs on other machines without editing the source.
DATA_ROOT = os.environ.get('QRCODE_DATA_ROOT', '/Users/jasonwong/Downloads/二维码测试1/data')
BATCH_SIZE = 32
train_dataset = QRCodeDataset(root_dir=os.path.join(DATA_ROOT, 'train'), transform=transform)
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
val_dataset = QRCodeDataset(root_dir=os.path.join(DATA_ROOT, 'val'), transform=transform)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False)

# Load a pretrained ResNet-50 and drop its last child module (the final
# fully-connected layer), leaving a trunk that outputs pooled features.
# NOTE(review): `pretrained=True` is deprecated since torchvision 0.13 in favour
# of `weights=models.ResNet50_Weights.DEFAULT`; kept here for older versions.
model = models.resnet50(pretrained=True)
modules = list(model.children())[:-1]  # remove the final fully-connected layer
model = nn.Sequential(*modules)

# Use the GPU when one is available
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model = model.to(device)
# Extract ORB features
def extract_orb_features(image_path):
    """Return the ORB descriptors of the image at ``image_path``.

    The image is read as single-channel grayscale. An ORB detector with
    OpenCV's default settings then detects keypoints and computes their
    descriptors.

    Args:
        image_path: Path to an image file readable by ``cv2.imread``.

    Returns:
        A 2-D array with one descriptor row per detected keypoint. If no
        keypoints are found, a single all-zero row of width 32 is returned
        so callers can still average over axis 0.

    Raises:
        ValueError: If OpenCV cannot read or decode the file.
    """
    image = cv2.imread(image_path, cv2.IMREAD_GRAYSCALE)
    if image is None:
        # cv2.imread signals failure (missing/corrupt file, unsupported format,
        # or on some platforms a non-ASCII path) by returning None instead of
        # raising. Without this check, detectAndCompute fails with an opaque
        # OpenCV error that does not name the offending file.
        raise ValueError(f"Could not read image file: {image_path}")
    orb = cv2.ORB_create()
    _keypoints, descriptors = orb.detectAndCompute(image, None)
    if descriptors is None:
        # No keypoints detected: return a zero row so the result is never empty.
        return np.zeros((1, 32))
    return descriptors
# Extract ResNet features
def extract_resnet_features(loader, net=None, target_device=None):
    """Run a feature-extractor network over ``loader`` and collect its outputs.

    Args:
        loader: Iterable of ``(inputs, labels)`` batches, e.g. a DataLoader.
        net: Module to run. Defaults to the module-level ``model``.
        target_device: Device the inputs are moved to. Defaults to the
            module-level ``device``.

    Returns:
        ``(features, labels)``. ``features`` is a CPU tensor with one row per
        sample, holding each network output flattened after the batch
        dimension. ``labels`` is the concatenation of the label batches.

    Raises:
        ValueError: If ``loader`` yields no batches.
    """
    net = model if net is None else net
    target_device = device if target_device is None else target_device
    net.eval()
    feature_batches = []
    label_batches = []
    with torch.no_grad():
        for inputs, label in tqdm(loader, desc="Extracting ResNet features"):
            output = net(inputs.to(target_device))
            # Flatten everything after the batch dim; the pooled ResNet trunk
            # emits (B, C, 1, 1), which becomes (B, C).
            feature_batches.append(torch.flatten(output, start_dim=1).cpu())
            label_batches.append(label)
    if not feature_batches:
        # torch.cat([]) would otherwise fail with a less informative error.
        raise ValueError("loader yielded no batches; cannot extract features")
    return torch.cat(feature_batches), torch.cat(label_batches)
# ResNet features for the training split, then the validation split;
# each call returns a (features, labels) pair.
train_resnet_features, train_labels = extract_resnet_features(loader=train_loader)
val_resnet_features, val_labels = extract_resnet_features(loader=val_loader)
# Extract ORB features from the training and validation datasets
def extract_orb_features_from_dataset(dataset):
    """Return one averaged ORB descriptor per image in ``dataset``.

    Each image's descriptor matrix (one row per keypoint) is averaged over
    its rows, giving one fixed-length global vector per image. Rows follow
    ``dataset.image_paths`` order.

    Args:
        dataset: Object exposing an ``image_paths`` sequence of file paths.

    Returns:
        A NumPy array with one row per image path.
    """
    paths = tqdm(dataset.image_paths, desc="Extracting ORB features")
    # Averaging over axis 0 collapses a variable keypoint count into one row.
    return np.array([np.mean(extract_orb_features(path), axis=0) for path in paths])
# ORB features for each split (one averaged descriptor per image)
train_orb_features = extract_orb_features_from_dataset(dataset=train_dataset)
val_orb_features = extract_orb_features_from_dataset(dataset=val_dataset)

# Feature fusion: join ResNet and ORB features column-wise, sample by sample.
# NOTE(review): the two groups are likely on very different scales (means of
# ORB descriptor bytes vs. ResNet activations) and are not standardised before
# the linear SVM; consider a StandardScaler if accuracy matters.
train_features = np.concatenate([train_resnet_features.numpy(), train_orb_features], axis=1)
val_features = np.concatenate([val_resnet_features.numpy(), val_orb_features], axis=1)

# Train a linear-kernel SVM on the training split, then score it on the validation split
svm = SVC(kernel='linear')
svm.fit(X=train_features, y=train_labels)
val_preds = svm.predict(X=val_features)
val_acc = accuracy_score(y_true=val_labels, y_pred=val_preds)
print(f'Validation Accuracy: {val_acc:.4f}')

# Save the ResNet feature-extractor weights.
# NOTE(review): this is the state_dict of the nn.Sequential trunk, so keys are
# index-prefixed; loading presumably requires rebuilding the same Sequential.
torch.save(model.state_dict(), 'resnet_feature_extractor.pth')
# Save the fitted SVM classifier
joblib.dump(svm, 'svm_classifier.joblib')