2024-08-07 09:46:45

Admin 2024-08-07 09:46:47 +08:00
parent 47f231bb21
commit 6bacf2fc56
44 changed files with 660 additions and 0 deletions

BIN
0.png (Before: 1.9 KiB) … Bag.png (Before: 61 KiB)

36 binary image files changed; previews not shown (Before sizes range from 1.9 KiB to 414 KiB).

41
FocalLoss.py Normal file
View File

@@ -0,0 +1,41 @@
import torch
import torch.nn as nn
import torch.nn.functional as F


class FocalLoss(nn.Module):
    """Focal loss for class-imbalanced classification (Lin et al., 2017)."""

    def __init__(self, gamma=2, alpha=None, reduction='mean'):
        super(FocalLoss, self).__init__()
        self.gamma = gamma
        self.alpha = alpha
        if isinstance(alpha, (float, int)):
            self.alpha = torch.Tensor([alpha, 1 - alpha])
        if isinstance(alpha, list):
            self.alpha = torch.Tensor(alpha)
        self.reduction = reduction

    def forward(self, input, target):
        # Flatten spatial dimensions so each pixel is one sample
        if input.dim() > 2:
            input = input.view(input.size(0), input.size(1), -1)  # N,C,H,W => N,C,H*W
            input = input.transpose(1, 2)                         # N,C,H*W => N,H*W,C
            input = input.contiguous().view(-1, input.size(2))    # N,H*W,C => N*H*W,C
        target = target.view(-1, 1)

        # log p_t for the target class of each sample
        logpt = F.log_softmax(input, dim=1)
        logpt = logpt.gather(1, target)
        logpt = logpt.view(-1)
        pt = logpt.exp()

        # Optional per-class weighting via alpha
        if self.alpha is not None:
            if self.alpha.type() != input.data.type():
                self.alpha = self.alpha.type_as(input.data)
            at = self.alpha.gather(0, target.data.view(-1))
            logpt = logpt * at

        # Focal term (1 - p_t)^gamma down-weights easy examples
        loss = -1 * (1 - pt) ** self.gamma * logpt
        if self.reduction == 'sum':
            loss = loss.sum()
        elif self.reduction == 'mean':
            loss = loss.mean()
        return loss
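
A minimal smoke test for the class above; the batch shape and the alpha values here are illustrative assumptions, not taken from the training code:

# Hypothetical smoke test for FocalLoss; shapes and alpha are illustrative.
logits = torch.randn(8, 2, requires_grad=True)  # N=8 samples, C=2 classes
labels = torch.randint(0, 2, (8,))              # integer class indices
criterion = FocalLoss(gamma=2, alpha=[0.25, 0.75], reduction='mean')
loss = criterion(logits, labels)
loss.backward()  # gradients flow back through the logits as usual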

149
processQr.py Normal file
View File

@@ -0,0 +1,149 @@
import random
import cv2
import numpy as np
import os
from tqdm import tqdm
import shutil


def correct_qr_code_images(input_folder, output_folder, margin=5):
    # Make sure the output folder exists
    if not os.path.exists(output_folder):
        os.makedirs(output_folder)

    # Collect all images in the input folder
    images = [os.path.join(input_folder, f) for f in os.listdir(input_folder)
              if f.endswith(('.png', '.jpg', '.jpeg'))]
    total_images = len(images)

    # Make sure the WeChatQRCode class and its model files load correctly
    weChatQr = cv2.wechat_qrcode.WeChatQRCode(
        "wechat_qrcode/detect.prototxt",
        "wechat_qrcode/detect.caffemodel",
        "wechat_qrcode/sr.prototxt",
        "wechat_qrcode/sr.caffemodel"
    )

    # Show progress with tqdm
    for i, img_path in enumerate(tqdm(images, desc="Processing images", total=total_images)):
        img = cv2.imread(img_path)

        # Convert to grayscale
        gray_img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
        kernel = np.ones((1, 3), np.uint8)
        dilated_img = cv2.dilate(img, kernel, iterations=1)

        # Binarize the grayscale image with OTSU thresholding
        _, binary_image = cv2.threshold(gray_img, 0, 255, cv2.THRESH_OTSU)
        # Invert black and white
        # binary_img_inv = cv2.bitwise_not(binary_image)
        # Show the inverted binary image
        # cv2.imshow('Inverted Binary Image', dilated_img)
        # Close the window after a key press
        # cv2.waitKey(0)
        # cv2.destroyAllWindows()

        res, points = weChatQr.detectAndDecode(dilated_img)
        if len(res) > 0:
            point = points[0]
            width = max(np.linalg.norm(point[0] - point[1]), np.linalg.norm(point[2] - point[3]))
            height = max(np.linalg.norm(point[1] - point[2]), np.linalg.norm(point[3] - point[0]))

            # Destination corners, padded by the margin
            points_dst = np.array([[margin, margin],
                                   [width - margin, margin],
                                   [width - margin, height - margin],
                                   [margin, height - margin]], dtype=np.float32)

            # Build the perspective-transform matrix
            matrix = cv2.getPerspectiveTransform(point, points_dst)
            # Apply the perspective transform to rectify the image
            corrected_image = cv2.warpPerspective(img, matrix, (int(width), int(height)))

            # Save the cropped QR-code image
            output_path = os.path.join(output_folder, f"cropped_qr_code_{i + 1:05d}.jpg")
            cv2.imwrite(output_path, corrected_image)
        else:
            print(f"No QR code found in {img_path}.")


def migrate_images(source_folder, target_folder):
    # Make sure the target folder exists
    if not os.path.exists(target_folder):
        os.makedirs(target_folder)

    # Image counter
    image_counter = 0
    # List of all image paths found
    image_paths = []

    # Walk the source folder and all of its subfolders
    for root, dirs, files in os.walk(source_folder):
        for dir in dirs:
            # Only look inside folders named 'output'
            if dir == 'output':
                # Build the full path to the output folder
                output_path = os.path.join(root, dir)
                # Iterate over every file in the output folder
                for file in os.listdir(output_path):
                    # Keep only image files
                    if file.lower().endswith(('.png', '.jpg', '.jpeg', '.bmp', '.gif', '.tiff')):
                        image_paths.append(os.path.join(output_path, file))

    # Show progress with tqdm
    for i, image_path in enumerate(tqdm(image_paths, desc="Migrating images", unit="image")):
        # Extract the file name and extension from the path
        file_name = os.path.basename(image_path)
        # Build the target path, numbering the files with i
        target_file = os.path.join(target_folder, f"image_{i + 1:04d}{os.path.splitext(file_name)[1]}")
        # Copy the image to the target folder
        shutil.copy2(image_path, target_file)
        image_counter += 1

    print(f"Total images migrated: {image_counter}")


def split_dataset(dataset_dir, output_dir, train_ratio=0.8):
    for label in ['real', 'fake']:
        label_dir = os.path.join(dataset_dir, label)
        images = os.listdir(label_dir)
        random.shuffle(images)

        train_size = int(len(images) * train_ratio)
        train_images = images[:train_size]
        val_images = images[train_size:]

        train_output_dir = os.path.join(output_dir, 'train', label)
        val_output_dir = os.path.join(output_dir, 'val', label)
        os.makedirs(train_output_dir, exist_ok=True)
        os.makedirs(val_output_dir, exist_ok=True)

        for image in train_images:
            shutil.copy(os.path.join(label_dir, image), os.path.join(train_output_dir, image))
        for image in val_images:
            shutil.copy(os.path.join(label_dir, image), os.path.join(val_output_dir, image))


# Example usage
input_folder = '/Users/jasonwong/Downloads/二维码测试1/手机验证码'  # replace with your input folder path
output_folder = '/Users/jasonwong/Downloads/二维码测试1/手机验证码/output'  # replace with your output folder path
correct_qr_code_images(input_folder, output_folder)

# source_folder = '/Users/jasonwong/Downloads/二维码测试1/仿品'  # replace with your source folder path
# target_folder = '/Users/jasonwong/Downloads/二维码测试1/dataset/fake'  # replace with your target folder path
# migrate_images(source_folder, target_folder)

# # Split the dataset with this script
# dataset_dir = '/Users/jasonwong/Downloads/二维码测试1/dataset'  # original dataset path
# output_dir = '/Users/jasonwong/Downloads/二维码测试1/data'  # split dataset path
# split_dataset(dataset_dir, output_dir)
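
One optional sanity check, not part of this commit: re-run the same WeChat detector over a rectified crop and confirm it still decodes. The crop path below is a placeholder:

# Hypothetical verification pass over a rectified crop (path is a placeholder).
import cv2

qr = cv2.wechat_qrcode.WeChatQRCode(
    "wechat_qrcode/detect.prototxt", "wechat_qrcode/detect.caffemodel",
    "wechat_qrcode/sr.prototxt", "wechat_qrcode/sr.caffemodel")
crop = cv2.imread('output/cropped_qr_code_00001.jpg')  # placeholder path
res, _ = qr.detectAndDecode(crop)
print('decoded after rectification:', res[0] if len(res) else '<nothing>')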

222
qrtrain.py Normal file
View File

@@ -0,0 +1,222 @@
import os
import numpy as np
import torch
import torch.nn as nn
from torch.optim.lr_scheduler import StepLR
from torchvision import models, transforms
from torch.utils.data import DataLoader
from torchvision.datasets import ImageFolder
from torch.optim import Adam
from tqdm import tqdm
from torchvision.transforms import functional
from FocalLoss import FocalLoss
from torchvision.models import ResNet18_Weights
from torch.cuda.amp import GradScaler, autocast

# Load a pretrained ResNet-18
model = models.resnet18(weights=ResNet18_Weights.IMAGENET1K_V1)

# Move the model to the GPU, if one is available
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model = model.to(device)

# Replace the last layer to fit the binary classification task
num_features = model.fc.in_features
model.fc = nn.Linear(num_features, 2)  # assuming a two-class problem
model.fc = model.fc.to(device)

# Freeze all layers
for param in model.parameters():
    param.requires_grad = False

# Unfreeze only the last layer for fine-tuning
# for param in model.fc.parameters():
#     param.requires_grad = True

# Unfreeze the last few layers
for name, param in model.named_parameters():
    if "layer4" in name or "fc" in name:  # only the last block and the head
        param.requires_grad = True


def random_brightness(image):
    """
    Randomly adjust the brightness of an image.

    A random brightness factor is drawn and applied to the image. This is a
    common data-augmentation step: it teaches the model to handle images under
    different lighting conditions and so helps it generalize.

    Args:
        image (Tensor): input image tensor.

    Returns:
        Tensor: image with randomly adjusted brightness.
    """
    # Draw a random brightness factor
    brightness_factor = torch.rand(1).item()
    # Adjust the brightness of the image
    return functional.adjust_brightness(image, brightness_factor)


def random_contrast(image):
    """
    Randomly adjust the contrast of an image.

    A random contrast factor is drawn uniformly between 0.5 (lowest contrast)
    and 1.5 (highest contrast) and applied to the image.

    Args:
        image (Tensor): input image tensor.

    Returns:
        Tensor: image with randomly adjusted contrast.
    """
    # Draw a random contrast factor in [0.5, 1.5)
    contrast_factor = torch.rand(1).item() + 0.5
    # Adjust the contrast with the drawn factor
    return functional.adjust_contrast(image, contrast_factor)


def random_hue(image):
    """
    Randomly adjust the hue of an image.

    The hue shift is drawn uniformly between -0.05 and 0.05. The goal is data
    augmentation: the model sees more variants during training, which improves
    generalization.

    Args:
        image (Tensor): input image tensor.

    Returns:
        Tensor: image with randomly adjusted hue.
    """
    # Draw a random hue factor in [-0.05, 0.05)
    hue_factor = (torch.rand(1).item() - 0.5) / 10
    # Shift the hue of the image by the drawn factor
    return functional.adjust_hue(image, hue_factor)


def add_gaussian_noise(image):
    """
    Add Gaussian noise to an image.

    Real-world images are routinely corrupted by noise; adding Gaussian noise
    during training makes the model more robust and improves its
    generalization in real scenarios.

    Args:
        image: input image as a torch tensor.

    Returns:
        The image with Gaussian noise added, same dtype and shape as the
        input. The noise comes from torch.randn_like, so its shape matches the
        image, and it is scaled by the small constant 0.1 to limit its
        strength.
    """
    # Generate noise with the same shape as the image and add it
    return image + torch.randn_like(image) * 0.1


data_transforms = {
    'train': transforms.Compose([
        transforms.RandomResizedCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.RandomRotation(degrees=15),  # rotate by up to 15 degrees
        transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1),  # random color changes
        transforms.ToTensor(),
        # transforms.Lambda(lambda x: functional.adjust_brightness(x, brightness_factor=torch.rand(1).item())),  # random brightness
        # transforms.Lambda(lambda x: functional.adjust_contrast(x, contrast_factor=torch.rand(1).item() + 0.5)),  # random contrast
        # transforms.Lambda(lambda x: functional.adjust_hue(x, hue_factor=(torch.rand(1).item()-0.5)/10)),  # random hue
        # transforms.Lambda(lambda x: x + torch.randn_like(x) * 0.1),  # random Gaussian noise
        transforms.Lambda(random_brightness),
        transforms.Lambda(random_contrast),
        transforms.Lambda(random_hue),
        transforms.Lambda(add_gaussian_noise),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ]),
    'val': transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ]),
}

# Load the datasets
data_dir = '/Users/jasonwong/Downloads/二维码测试1/data'
image_datasets = {x: ImageFolder(os.path.join(data_dir, x), data_transforms[x])
                  for x in ['train', 'val']}
dataloaders = {x: DataLoader(image_datasets[x], batch_size=32, shuffle=True, num_workers=4)
               for x in ['train', 'val']}
dataset_sizes = {x: len(image_datasets[x]) for x in ['train', 'val']}

# Set up the loss function and optimizer
num_normal_samples = 7891
num_anomaly_samples = 846
class_sample_counts = [num_normal_samples, num_anomaly_samples]  # replace with your actual sample counts
weights = 1.0 / np.array(class_sample_counts)
weights /= weights.sum()  # make the weights sum to 1
class_weights = torch.FloatTensor(weights).to(device)  # convert to a PyTorch tensor and move it to the device

# criterion = nn.CrossEntropyLoss()
# optimizer = torch.optim.SGD(model.fc.parameters(), lr=0.001, momentum=0.9)
optimizer = Adam(model.parameters(),
                 lr=1e-4,             # learning rate
                 betas=(0.9, 0.999),  # beta_1 and beta_2
                 eps=1e-08)           # epsilon
scheduler = StepLR(optimizer, step_size=7, gamma=0.1)

# Initialize the GradScaler
# scaler = GradScaler()


# Train the model
def train_model(model, dataloaders, dataset_sizes, device, num_epochs=10, save_path='path_to_save/full_model'):
    # criterion = FocalLoss(gamma=2, alpha=[0.25, 0.75], reduction='mean')
    # Weighted cross-entropy loss
    criterion = nn.CrossEntropyLoss(weight=class_weights)
    for epoch in range(num_epochs):
        print(f'Starting epoch {epoch + 1}/{num_epochs}')
        for phase in ['train', 'val']:
            if phase == 'train':
                model.train()
            else:
                model.eval()

            running_loss = 0.0
            running_corrects = 0

            for inputs, labels in tqdm(dataloaders[phase], desc=f'Epoch {epoch + 1} {phase.capitalize()}'):
                inputs = inputs.to(device)
                labels = labels.to(device)

                optimizer.zero_grad()
                with torch.set_grad_enabled(phase == 'train'):
                    outputs = model(inputs)
                    _, preds = torch.max(outputs, 1)
                    loss = criterion(outputs, labels)

                    if phase == 'train':
                        loss.backward()
                        optimizer.step()

                running_loss += loss.item() * inputs.size(0)
                running_corrects += torch.sum(preds == labels.data)

            epoch_loss = running_loss / dataset_sizes[phase]
            epoch_acc = running_corrects.double() / dataset_sizes[phase]
            print(f'{phase} Loss: {epoch_loss:.4f} Acc: {epoch_acc:.4f}')

        scheduler.step()
        torch.save({
            'epoch': epoch + 1,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'loss': epoch_loss,
        }, f'{save_path}_epoch_{epoch + 1}.pth')


if __name__ == '__main__':
    train_model(model, dataloaders, dataset_sizes, device, num_epochs=10, save_path='')
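
The class counts (7891/846) above are hardcoded. As a hedged alternative sketch, the same weights can be derived from the ImageFolder targets so they always track the data on disk; this reuses image_datasets, np, torch, and device from this script:

# Hypothetical replacement for the hardcoded counts; derives the class
# weights from the training split itself.
from collections import Counter

counts = Counter(image_datasets['train'].targets)        # {class_index: n_samples}
class_sample_counts = [counts[i] for i in sorted(counts)]
weights = 1.0 / np.array(class_sample_counts, dtype=np.float64)
weights /= weights.sum()                                 # normalize to sum to 1
class_weights = torch.FloatTensor(weights).to(device)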

56
qrtrain_eval.py Normal file
View File

@@ -0,0 +1,56 @@
import torch
from torchvision import models, transforms
from PIL import Image
import torch.nn as nn


# Load the saved model
def load_model(save_path, device):
    model = models.resnet18()  # redefine the model architecture
    num_features = model.fc.in_features
    model.fc = nn.Linear(num_features, 2)  # assuming a two-class problem
    checkpoint = torch.load(save_path, map_location=device)
    model.load_state_dict(checkpoint['model_state_dict'])
    model = model.to(device)
    model.eval()  # switch the model to evaluation mode
    return model


# Preprocess an image
def preprocess_image(image_path):
    data_transform = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ])
    image = Image.open(image_path).convert("RGB")
    image = data_transform(image)
    image = image.unsqueeze(0)  # add a batch dimension
    return image


# Run inference and return the predicted class index
def predict(model, image_tensor, device):
    image_tensor = image_tensor.to(device)
    with torch.no_grad():
        outputs = model(image_tensor)
        _, preds = torch.max(outputs, 1)
    return preds.item()


if __name__ == '__main__':
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    model_path = '_epoch_10.pth'  # replace with your model path
    image_path = '/Users/jasonwong/Downloads/二维码测试1/NormalJpg1/data/val/real/cropped_qr_code_09310.jpg'  # replace with the image you want to test

    # Load the model
    model = load_model(model_path, device)
    # Preprocess the image
    image_tensor = preprocess_image(image_path)
    # Predict
    prediction = predict(model, image_tensor, device)

    class_names = ['Normal', 'Anomaly']
    print(f'Predicted class: {class_names[prediction]}')
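
If a confidence score is wanted alongside the label, a small variant of predict (an assumption, not part of this commit) can return the softmax probability of the chosen class:

# Hypothetical variant of predict() that also reports confidence.
import torch.nn.functional as F

def predict_with_confidence(model, image_tensor, device):
    # Same forward pass as predict(), but keep the class probabilities.
    image_tensor = image_tensor.to(device)
    with torch.no_grad():
        probs = F.softmax(model(image_tensor), dim=1)
    conf, pred = torch.max(probs, 1)
    return pred.item(), conf.item()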

72
qrtrain_predict_v1.py Normal file
View File

@@ -0,0 +1,72 @@
import os
import cv2
import torch
import torch.nn as nn
from torchvision import models, transforms
from sklearn.svm import SVC
import joblib
from PIL import Image
import numpy as np

# Load a ResNet model and drop its last layer
model = models.resnet50(pretrained=False)
modules = list(model.children())[:-1]  # drop the final fully connected layer
model = nn.Sequential(*modules)

# Load the trained ResNet feature-extractor weights
model.load_state_dict(torch.load('resnet_feature_extractor.pth'))
model.eval()

# Use the GPU
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model = model.to(device)

# Load the SVM classifier
svm = joblib.load('svm_classifier.joblib')

# Preprocessing
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])


# Extract ORB features
def extract_orb_features(image_path):
    image = cv2.imread(image_path, cv2.IMREAD_GRAYSCALE)
    orb = cv2.ORB_create()
    keypoints, descriptors = orb.detectAndCompute(image, None)
    if descriptors is None:
        return np.zeros((1, 32))  # return a zero array when no keypoints are detected
    return descriptors


# Extract ResNet features
def extract_resnet_features(image):
    image = transform(image).unsqueeze(0)  # add a batch dimension
    image = image.to(device)
    with torch.no_grad():
        features = model(image)
        features = features.view(features.size(0), -1)  # flatten the feature vector
    return features.cpu().numpy()


# Inference
def predict(image_path):
    # ORB features
    orb_features = extract_orb_features(image_path)
    orb_features = np.mean(orb_features, axis=0).reshape(1, -1)  # mean descriptor as a global feature

    # ResNet features
    image = Image.open(image_path).convert('RGB')
    resnet_features = extract_resnet_features(image)

    # Fuse the features
    features = np.hstack((resnet_features, orb_features))

    # Classify
    prediction = svm.predict(features)
    return 'real' if prediction[0] == 0 else 'fake'


# Example inference call
image_path = 'demo.png'
result = predict(image_path)
print(f'The QR code is {result}.')
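
A hedged convenience sketch, not in the commit: score every crop in a folder with the predict helper above. The folder name is a placeholder:

# Hypothetical batch scoring over a directory of crops (placeholder path).
folder = 'output'
for name in sorted(os.listdir(folder)):
    if name.lower().endswith(('.png', '.jpg', '.jpeg')):
        print(name, '->', predict(os.path.join(folder, name)))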

120
qrtrain_v1.py Normal file
View File

@@ -0,0 +1,120 @@
import cv2
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import models, transforms
from torch.utils.data import DataLoader, Dataset
from sklearn.svm import SVC
from sklearn.metrics import accuracy_score
from PIL import Image
import numpy as np
import os
from tqdm import tqdm
import joblib


# Custom dataset class
class QRCodeDataset(Dataset):
    def __init__(self, root_dir, transform=None):
        self.root_dir = root_dir
        self.transform = transform
        self.image_paths = []
        self.labels = []
        for label in ['real', 'fake']:
            label_dir = os.path.join(root_dir, label)
            for img_name in os.listdir(label_dir):
                if img_name.endswith(('.png', '.jpg', '.jpeg')):  # only add image files
                    self.image_paths.append(os.path.join(label_dir, img_name))
                    self.labels.append(0 if label == 'real' else 1)

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        img_path = self.image_paths[idx]
        image = Image.open(img_path).convert('RGB')  # force 3 channels for the transform
        label = self.labels[idx]
        if self.transform:
            image = self.transform(image)
        return image, label


# Preprocessing
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

# Load the datasets
train_dataset = QRCodeDataset(root_dir='/Users/jasonwong/Downloads/二维码测试1/data/train', transform=transform)
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
val_dataset = QRCodeDataset(root_dir='/Users/jasonwong/Downloads/二维码测试1/data/val', transform=transform)
val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False)

# Load a pretrained ResNet and drop its last layer
model = models.resnet50(pretrained=True)
modules = list(model.children())[:-1]  # drop the final fully connected layer
model = nn.Sequential(*modules)

# Use the GPU
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model = model.to(device)


# Extract ORB features
def extract_orb_features(image_path):
    image = cv2.imread(image_path, cv2.IMREAD_GRAYSCALE)
    orb = cv2.ORB_create()
    keypoints, descriptors = orb.detectAndCompute(image, None)
    if descriptors is None:
        return np.zeros((1, 32))  # return a zero array when no keypoints are detected
    return descriptors


# Extract ResNet features
def extract_resnet_features(loader):
    model.eval()
    features = []
    labels = []
    with torch.no_grad():
        for inputs, label in tqdm(loader, desc="Extracting ResNet features"):
            inputs = inputs.to(device)
            output = model(inputs)
            output = output.view(output.size(0), -1)  # flatten the feature vectors
            features.append(output.cpu())
            labels.append(label)
    return torch.cat(features), torch.cat(labels)


# Extract ResNet features from the training and validation sets
train_resnet_features, train_labels = extract_resnet_features(train_loader)
val_resnet_features, val_labels = extract_resnet_features(val_loader)


# Extract ORB features from the training and validation sets
def extract_orb_features_from_dataset(dataset):
    features = []
    for img_path in tqdm(dataset.image_paths, desc="Extracting ORB features"):
        orb_features = extract_orb_features(img_path)
        features.append(np.mean(orb_features, axis=0))  # mean descriptor as a global feature
    return np.array(features)


train_orb_features = extract_orb_features_from_dataset(train_dataset)
val_orb_features = extract_orb_features_from_dataset(val_dataset)

# Fuse the features
train_features = np.hstack((train_resnet_features.numpy(), train_orb_features))
val_features = np.hstack((val_resnet_features.numpy(), val_orb_features))

# Train and evaluate an SVM classifier
svm = SVC(kernel='linear')
svm.fit(train_features, train_labels.numpy())
val_preds = svm.predict(val_features)
val_acc = accuracy_score(val_labels.numpy(), val_preds)
print(f'Validation Accuracy: {val_acc:.4f}')

# Save the ResNet feature-extractor weights
torch.save(model.state_dict(), 'resnet_feature_extractor.pth')
# Save the SVM classifier
joblib.dump(svm, 'svm_classifier.joblib')
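
An optional follow-up, not in the commit: given the roughly 9:1 real/fake imbalance noted in qrtrain.py, per-class precision and recall from sklearn's classification_report are more informative than accuracy alone. This reuses val_labels and val_preds from above:

# Per-class metrics for the fused-feature SVM.
from sklearn.metrics import classification_report

print(classification_report(val_labels.numpy(), val_preds, target_names=['real', 'fake']))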

BIN
res/000035.jpg Normal file
Binary file not shown (After: 274 KiB).

BIN
res/007563.jpg Normal file
Binary file not shown (After: 83 KiB).