diff --git a/0.png b/0.png
deleted file mode 100644
index 1956009..0000000
Binary files a/0.png and /dev/null differ
diff --git a/20240507153436.png b/20240507153436.png
deleted file mode 100644
index 55f9097..0000000
Binary files a/20240507153436.png and /dev/null differ
diff --git a/20240507153436.png.zip b/20240507153436.png.zip
deleted file mode 100644
index cddf719..0000000
Binary files a/20240507153436.png.zip and /dev/null differ
diff --git a/20240507173217.png b/20240507173217.png
deleted file mode 100644
index a565ab8..0000000
Binary files a/20240507173217.png and /dev/null differ
diff --git a/20240507173217.png.zip b/20240507173217.png.zip
deleted file mode 100644
index 162fdc2..0000000
Binary files a/20240507173217.png.zip and /dev/null differ
diff --git a/20240517105758.png b/20240517105758.png
deleted file mode 100644
index 8187845..0000000
Binary files a/20240517105758.png and /dev/null differ
diff --git a/20240517105800.png b/20240517105800.png
deleted file mode 100644
index 4fc8367..0000000
Binary files a/20240517105800.png and /dev/null differ
diff --git a/20240517105802.png b/20240517105802.png
deleted file mode 100644
index 0f1d8b0..0000000
Binary files a/20240517105802.png and /dev/null differ
diff --git a/20240517105803.png b/20240517105803.png
deleted file mode 100644
index b8b250f..0000000
Binary files a/20240517105803.png and /dev/null differ
diff --git a/20240517105805.png b/20240517105805.png
deleted file mode 100644
index 777d5b8..0000000
Binary files a/20240517105805.png and /dev/null differ
diff --git a/20240517105806.png b/20240517105806.png
deleted file mode 100644
index 750151c..0000000
Binary files a/20240517105806.png and /dev/null differ
diff --git a/20240517105808.png b/20240517105808.png
deleted file mode 100644
index 6e0ece2..0000000
Binary files a/20240517105808.png and /dev/null differ
diff --git a/20240517105810.png b/20240517105810.png
deleted file mode 100644
index c82d626..0000000
Binary files a/20240517105810.png and /dev/null differ
diff --git a/20240517105812.png b/20240517105812.png
deleted file mode 100644
index e97254b..0000000
Binary files a/20240517105812.png and /dev/null differ
diff --git a/20240517105813.png b/20240517105813.png
deleted file mode 100644
index cb08518..0000000
Binary files a/20240517105813.png and /dev/null differ
diff --git a/20240517105815.png b/20240517105815.png
deleted file mode 100644
index 78c8005..0000000
Binary files a/20240517105815.png and /dev/null differ
diff --git a/20240517105816.png b/20240517105816.png
deleted file mode 100644
index 2680452..0000000
Binary files a/20240517105816.png and /dev/null differ
diff --git a/20240517105818.png b/20240517105818.png
deleted file mode 100644
index b404a5b..0000000
Binary files a/20240517105818.png and /dev/null differ
diff --git a/20240517105820.png b/20240517105820.png
deleted file mode 100644
index 6755d2d..0000000
Binary files a/20240517105820.png and /dev/null differ
diff --git a/20240517105821.png b/20240517105821.png
deleted file mode 100644
index 40e5c3e..0000000
Binary files a/20240517105821.png and /dev/null differ
diff --git a/20240517105823.png b/20240517105823.png
deleted file mode 100644
index 7416139..0000000
Binary files a/20240517105823.png and /dev/null differ
diff --git a/20240517105824.png b/20240517105824.png
deleted file mode 100644
index 482dcee..0000000
Binary files a/20240517105824.png and /dev/null differ
diff --git a/20240517105826.png b/20240517105826.png
deleted file mode 100644
index 577486b..0000000
Binary files a/20240517105826.png and /dev/null differ
diff --git a/20240517105828.png b/20240517105828.png
deleted file mode 100644
index bf8eb2a..0000000
Binary files a/20240517105828.png and /dev/null differ
diff --git a/20240517105829.png b/20240517105829.png
deleted file mode 100644
index 224ac17..0000000
Binary files a/20240517105829.png and /dev/null differ
diff --git a/20240517115712.png b/20240517115712.png
deleted file mode 100644
index 99232ea..0000000
Binary files a/20240517115712.png and /dev/null differ
diff --git a/20240517115712.tiff b/20240517115712.tiff
deleted file mode 100644
index 3ce6146..0000000
Binary files a/20240517115712.tiff and /dev/null differ
diff --git a/20240517135556.png b/20240517135556.png
deleted file mode 100644
index f7713fc..0000000
Binary files a/20240517135556.png and /dev/null differ
diff --git a/20240517135556.tiff b/20240517135556.tiff
deleted file mode 100644
index 6f7cda8..0000000
Binary files a/20240517135556.tiff and /dev/null differ
diff --git a/20240517140142.png b/20240517140142.png
deleted file mode 100644
index 7f6fb75..0000000
Binary files a/20240517140142.png and /dev/null differ
diff --git a/20240517140142.tiff b/20240517140142.tiff
deleted file mode 100644
index 8813fb7..0000000
Binary files a/20240517140142.tiff and /dev/null differ
diff --git a/20240517140325.png b/20240517140325.png
deleted file mode 100644
index ba025e8..0000000
Binary files a/20240517140325.png and /dev/null differ
diff --git a/20240517140325.tiff b/20240517140325.tiff
deleted file mode 100644
index 254cec4..0000000
Binary files a/20240517140325.tiff and /dev/null differ
diff --git a/20240517140418.png b/20240517140418.png
deleted file mode 100644
index 2bf4d88..0000000
Binary files a/20240517140418.png and /dev/null differ
diff --git a/20240517140418.tiff b/20240517140418.tiff
deleted file mode 100644
index 6ee8295..0000000
Binary files a/20240517140418.tiff and /dev/null differ
diff --git a/Bag.png b/Bag.png
deleted file mode 100644
index 5b930ab..0000000
Binary files a/Bag.png and /dev/null differ
diff --git a/FocalLoss.py b/FocalLoss.py
new file mode 100644
index 0000000..69d2933
--- /dev/null
+++ b/FocalLoss.py
@@ -0,0 +1,41 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+class FocalLoss(nn.Module):
+    def __init__(self, gamma=2, alpha=None, reduction='mean'):
+        super(FocalLoss, self).__init__()
+        self.gamma = gamma
+        self.alpha = alpha
+        if isinstance(alpha, (float, int)):
+            self.alpha = torch.Tensor([alpha, 1 - alpha])
+        if isinstance(alpha, list):
+            self.alpha = torch.Tensor(alpha)
+        self.reduction = reduction
+
+    def forward(self, input, target):
+        if input.dim() > 2:
+            input = input.view(input.size(0), input.size(1), -1)  # N,C,H,W => N,C,H*W
+            input = input.transpose(1, 2)  # N,C,H*W => N,H*W,C
+            input = input.contiguous().view(-1, input.size(2))  # N,H*W,C => N*H*W,C
+        target = target.view(-1, 1)
+
+        logpt = F.log_softmax(input, dim=1)
+        logpt = logpt.gather(1, target)
+        logpt = logpt.view(-1)
+        pt = logpt.exp()
+
+        if self.alpha is not None:
+            if self.alpha.type() != input.data.type():
+                self.alpha = self.alpha.type_as(input.data)
+            at = self.alpha.gather(0, target.data.view(-1))
+            logpt = logpt * at
+
+        loss = -1 * (1 - pt) ** self.gamma * logpt
+
+        if self.reduction == 'sum':
+            loss = loss.sum()
+        elif self.reduction == 'mean':
+            loss = loss.mean()
+
+        return loss
\ No newline at end of file
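
FocalLoss implements FL(p_t) = -alpha_t * (1 - p_t)^gamma * log(p_t) on raw logits. A minimal smoke test, assuming logits of shape (N, C) and integer class targets (all tensor values below are illustrative):

    import torch
    from FocalLoss import FocalLoss

    criterion = FocalLoss(gamma=2, alpha=[0.25, 0.75], reduction='mean')
    logits = torch.randn(8, 2)           # raw scores for 8 samples, 2 classes
    targets = torch.randint(0, 2, (8,))  # integer class labels
    loss = criterion(logits, targets)
    print(loss.item())
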
diff --git a/processQr.py b/processQr.py
new file mode 100644
index 0000000..fc27cb5
--- /dev/null
+++ b/processQr.py
@@ -0,0 +1,149 @@
+import random
+
+import cv2
+import numpy as np
+import os
+from tqdm import tqdm
+import shutil
+
+
+def correct_qr_code_images(input_folder, output_folder, margin=5):
+    # Make sure the output folder exists
+    if not os.path.exists(output_folder):
+        os.makedirs(output_folder)
+
+    # Collect all images in the input folder
+    images = [os.path.join(input_folder, f) for f in os.listdir(input_folder) if f.endswith(('.png', '.jpg', '.jpeg'))]
+    total_images = len(images)
+
+    # Make sure the WeChatQRCode class and its model files load correctly
+    weChatQr = cv2.wechat_qrcode.WeChatQRCode(
+        "wechat_qrcode/detect.prototxt",
+        "wechat_qrcode/detect.caffemodel",
+        "wechat_qrcode/sr.prototxt",
+        "wechat_qrcode/sr.caffemodel"
+    )
+
+    # Show a progress bar with tqdm
+    for i, img_path in enumerate(tqdm(images, desc="Processing images", total=total_images)):
+        img = cv2.imread(img_path)
+        # Convert to grayscale
+        gray_img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
+
+        kernel = np.ones((1, 3), np.uint8)
+        dilated_img = cv2.dilate(img, kernel, iterations=1)
+
+        # Binarize the grayscale image with OTSU thresholding
+        # (kept from earlier experiments; only dilated_img is used below)
+        _, binary_image = cv2.threshold(gray_img, 0, 255, cv2.THRESH_OTSU)
+        # Invert black and white
+        # binary_img_inv = cv2.bitwise_not(binary_image)
+
+        # Show the inverted binary image
+        # cv2.imshow('Inverted Binary Image', dilated_img)
+
+        # Close the window on a key press
+        # cv2.waitKey(0)
+        # cv2.destroyAllWindows()
+
+        res, points = weChatQr.detectAndDecode(dilated_img)
+
+        if len(res) > 0:
+            point = points[0]
+            width = max(np.linalg.norm(point[0] - point[1]), np.linalg.norm(point[2] - point[3]))
+            height = max(np.linalg.norm(point[1] - point[2]), np.linalg.norm(point[3] - point[0]))
+
+            # Destination corners, padded by the requested margin
+            points_dst = np.array([[margin, margin],
+                                   [width - margin, margin],
+                                   [width - margin, height - margin],
+                                   [margin, height - margin]], dtype=np.float32)
+
+            # Build the perspective-transform matrix
+            matrix = cv2.getPerspectiveTransform(point, points_dst)
+
+            # Apply the perspective transform to rectify the image
+            corrected_image = cv2.warpPerspective(img, matrix, (int(width), int(height)))
+
+            # Save the cropped QR-code image
+            output_path = os.path.join(output_folder, f"cropped_qr_code_{i + 1:05d}.jpg")
+            cv2.imwrite(output_path, corrected_image)
+        else:
+            print(f"No QR code found in {img_path}.")
+
+
+def migrate_images(source_folder, target_folder):
+    # Make sure the target folder exists
+    if not os.path.exists(target_folder):
+        os.makedirs(target_folder)
+
+    # Image counter
+    image_counter = 0
+    # All collected image paths
+    image_paths = []
+
+    # Walk the source folder and all of its subfolders
+    for root, dirs, files in os.walk(source_folder):
+        for dir in dirs:
+            # Only look inside folders named 'output'
+            if dir == 'output':
+                # Build the full path of the output folder
+                output_path = os.path.join(root, dir)
+                # Iterate over every file in the output folder
+                for file in os.listdir(output_path):
+                    # Only collect image files
+                    if file.lower().endswith(('.png', '.jpg', '.jpeg', '.bmp', '.gif', '.tiff')):
+                        image_paths.append(os.path.join(output_path, file))
+
+    # Copy with a tqdm progress bar
+    for i, image_path in enumerate(tqdm(image_paths, desc="Migrating images", unit="image")):
+        # Extract the file name and extension from the path
+        file_name = os.path.basename(image_path)
+        # Build the target path, numbering the file with i and keeping its extension
+        target_file = os.path.join(target_folder, f"image_{i + 1:04d}{os.path.splitext(file_name)[1]}")
+        # Copy the image into the target folder
+        shutil.copy2(image_path, target_file)
+        image_counter += 1
+
+    print(f"Total images migrated: {image_counter}")
+
+
+def split_dataset(dataset_dir, output_dir, train_ratio=0.8):
+    for label in ['real', 'fake']:
+        label_dir = os.path.join(dataset_dir, label)
+        images = os.listdir(label_dir)
+        random.shuffle(images)
+
+        train_size = int(len(images) * train_ratio)
+
+        train_images = images[:train_size]
+        val_images = images[train_size:]
+
+        train_output_dir = os.path.join(output_dir, 'train', label)
+        val_output_dir = os.path.join(output_dir, 'val', label)
+
+        os.makedirs(train_output_dir, exist_ok=True)
+        os.makedirs(val_output_dir, exist_ok=True)
+
+        for image in train_images:
+            shutil.copy(os.path.join(label_dir, image), os.path.join(train_output_dir, image))
+
+        for image in val_images:
+            shutil.copy(os.path.join(label_dir, image), os.path.join(val_output_dir, image))
+
+
+# Usage example
+input_folder = '/Users/jasonwong/Downloads/二维码测试1/手机验证码'  # replace with your input folder path
+output_folder = '/Users/jasonwong/Downloads/二维码测试1/手机验证码/output'  # replace with your output folder path
+correct_qr_code_images(input_folder, output_folder)
+
+# source_folder = '/Users/jasonwong/Downloads/二维码测试1/仿品'  # replace with your source folder path
+# target_folder = '/Users/jasonwong/Downloads/二维码测试1/dataset/fake'  # replace with your target folder path
+# migrate_images(source_folder, target_folder)
+
+# # Split the dataset with this script
+# dataset_dir = '/Users/jasonwong/Downloads/二维码测试1/dataset'  # original dataset path
+# output_dir = '/Users/jasonwong/Downloads/二维码测试1/data'  # split dataset path
+# split_dataset(dataset_dir, output_dir)
\ No newline at end of file
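
Taken together, the three helpers form a small preprocessing pipeline: rectify the raw scans, gather the rectified crops, then split them into train/val. A sketch of chaining them, with hypothetical folder names, assuming the module-level example calls at the bottom of processQr.py are commented out so the import has no side effects:

    import os
    from processQr import correct_qr_code_images, migrate_images, split_dataset

    for src, cls in [('scans/real', 'real'), ('scans/fake', 'fake')]:  # hypothetical paths
        correct_qr_code_images(src, os.path.join(src, 'output'))
        migrate_images(src, os.path.join('dataset', cls))  # picks up the 'output' subfolder
    split_dataset('dataset', 'data', train_ratio=0.8)
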
diff --git a/qrtrain.py b/qrtrain.py
new file mode 100644
index 0000000..571159e
--- /dev/null
+++ b/qrtrain.py
@@ -0,0 +1,222 @@
+import os
+
+import numpy as np
+import torch
+import torch.nn as nn
+from torch.optim.lr_scheduler import StepLR
+from torchvision import models, transforms
+from torch.utils.data import DataLoader
+from torchvision.datasets import ImageFolder
+from torch.optim import Adam
+from tqdm import tqdm
+from torchvision.transforms import functional
+from FocalLoss import FocalLoss
+from torchvision.models import ResNet18_Weights
+from torch.cuda.amp import GradScaler, autocast
+
+# Load a pretrained ResNet-18
+model = models.resnet18(weights=ResNet18_Weights.IMAGENET1K_V1)
+
+# Move the model to the GPU, if one is available
+device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
+model = model.to(device)
+
+# Replace the final layer for the binary classification task
+num_features = model.fc.in_features
+model.fc = nn.Linear(num_features, 2)  # two classes: real vs. fake
+model.fc = model.fc.to(device)
+
+# Freeze all layers
+for param in model.parameters():
+    param.requires_grad = False
+
+# Unfreeze only the final layer for fine-tuning
+# for param in model.fc.parameters():
+#     param.requires_grad = True
+
+# Unfreeze the last few layers
+for name, param in model.named_parameters():
+    if "layer4" in name or "fc" in name:  # only unfreeze layer4 and the classifier
+        param.requires_grad = True
+
+
+def random_brightness(image):
+    """
+    Randomly adjust the brightness of an image.
+
+    A random brightness factor is generated and applied to the image. This is a
+    common data-augmentation step: it teaches the model to handle images under
+    different lighting conditions and so helps it generalize better.
+
+    Note that the factor is drawn from [0, 1), so this transform only darkens.
+
+    Args:
+        image (Tensor): input image tensor.
+
+    Returns:
+        Tensor: the image with randomly adjusted brightness.
+    """
+    # Generate a random brightness factor
+    brightness_factor = torch.rand(1).item()
+    # Adjust the image brightness
+    return functional.adjust_brightness(image, brightness_factor)
+
+
+def random_contrast(image):
+    """
+    Randomly adjust the contrast of an image.
+
+    A contrast factor is drawn uniformly from [0.5, 1.5), where 0.5 is the
+    lowest contrast and 1.5 the highest, and applied to the image.
+
+    Args:
+        image (Tensor): input image tensor.
+
+    Returns:
+        Tensor: the image with randomly adjusted contrast.
+    """
+    # Generate a random contrast factor in [0.5, 1.5)
+    contrast_factor = torch.rand(1).item() + 0.5
+    # Adjust the image contrast with the generated factor
+    return functional.adjust_contrast(image, contrast_factor)
+
+
+def random_hue(image):
+    """
+    Randomly adjust the hue of an image.
+
+    The hue shift is drawn uniformly from [-0.05, 0.05). The goal is data
+    augmentation: the model sees more variants during training, which improves
+    its ability to generalize.
+
+    Args:
+        image (Tensor): input image tensor.
+
+    Returns:
+        Tensor: the image with randomly adjusted hue.
+    """
+    # Generate a random hue factor in [-0.05, 0.05)
+    hue_factor = (torch.rand(1).item() - 0.5) / 10
+    # Shift the image hue by the generated factor
+    return functional.adjust_hue(image, hue_factor)
+
+
+def add_gaussian_noise(image):
+    """
+    Add Gaussian noise to an image.
+
+    This simulates the noise that real-world images are often subject to.
+    Training on noisy inputs makes the model more robust and improves its
+    generalization to real scenes.
+
+    The noise is generated with torch.randn_like, so it has the same shape as
+    the image, and is scaled by the small constant 0.1 to control its strength.
+
+    Args:
+        image: input image as a torch tensor.
+
+    Returns:
+        The image with Gaussian noise added, same type and shape as the input.
+    """
+    # Generate a Gaussian noise tensor of the same shape and add it to the image
+    return image + torch.randn_like(image) * 0.1
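+
+# A quick sanity check for the augmentations above (illustrative): each one maps
+# a (C, H, W) float tensor to a tensor of the same shape, so they can be dropped
+# straight into a transforms.Compose pipeline, e.g.:
+#
+#   >>> img = torch.rand(3, 224, 224)
+#   >>> add_gaussian_noise(img).shape
+#   torch.Size([3, 224, 224])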
+
+
+data_transforms = {
+    'train': transforms.Compose([
+        transforms.RandomResizedCrop(224),
+        transforms.RandomHorizontalFlip(),
+        transforms.RandomRotation(degrees=15),  # rotate by up to 15 degrees
+        transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1),  # random color changes
+        transforms.ToTensor(),
+        # transforms.Lambda(lambda x: functional.adjust_brightness(x, brightness_factor=torch.rand(1).item())),  # random brightness
+        # transforms.Lambda(lambda x: functional.adjust_contrast(x, contrast_factor=torch.rand(1).item() + 0.5)),  # random contrast
+        # transforms.Lambda(lambda x: functional.adjust_hue(x, hue_factor=(torch.rand(1).item() - 0.5) / 10)),  # random hue
+        # transforms.Lambda(lambda x: x + torch.randn_like(x) * 0.1),  # random Gaussian noise
+        transforms.Lambda(random_brightness),
+        transforms.Lambda(random_contrast),
+        transforms.Lambda(random_hue),
+        transforms.Lambda(add_gaussian_noise),
+        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
+    ]),
+    'val': transforms.Compose([
+        transforms.Resize(256),
+        transforms.CenterCrop(224),
+        transforms.ToTensor(),
+        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
+    ]),
+}
+
+# Load the datasets
+data_dir = '/Users/jasonwong/Downloads/二维码测试1/data'
+image_datasets = {x: ImageFolder(os.path.join(data_dir, x), data_transforms[x])
+                  for x in ['train', 'val']}
+dataloaders = {x: DataLoader(image_datasets[x], batch_size=32, shuffle=True, num_workers=4)
+               for x in ['train', 'val']}
+dataset_sizes = {x: len(image_datasets[x]) for x in ['train', 'val']}
+
+# Set up the loss function and optimizer
+num_normal_samples = 7891
+num_anomaly_samples = 846
+# NOTE: the order must match ImageFolder's alphabetical class indices
+# ('fake' -> 0, 'real' -> 1); replace with your actual sample counts.
+class_sample_counts = [num_normal_samples, num_anomaly_samples]
+weights = 1.0 / np.array(class_sample_counts)
+weights /= weights.sum()  # normalize so the weights sum to 1 (roughly [0.097, 0.903] here)
+class_weights = torch.FloatTensor(weights).to(device)  # move the weights to the device as a tensor
+# criterion = nn.CrossEntropyLoss()
+# optimizer = torch.optim.SGD(model.fc.parameters(), lr=0.001, momentum=0.9)
+optimizer = Adam(model.parameters(),
+                 lr=1e-4,  # learning rate
+                 betas=(0.9, 0.999),  # beta_1 and beta_2
+                 eps=1e-08)  # epsilon
+scheduler = StepLR(optimizer, step_size=7, gamma=0.1)
+
+# Initialize a GradScaler (for mixed precision; currently unused)
+# scaler = GradScaler()
+
+# Train the model
+def train_model(model, dataloaders, dataset_sizes, device, num_epochs=10, save_path='path_to_save/full_model'):
+    # criterion = FocalLoss(gamma=2, alpha=[0.25, 0.75], reduction='mean')
+    # Weighted cross-entropy loss
+    criterion = nn.CrossEntropyLoss(weight=class_weights)
+
+    for epoch in range(num_epochs):
+        print(f'Starting epoch {epoch + 1}/{num_epochs}')
+
+        for phase in ['train', 'val']:
+            if phase == 'train':
+                model.train()
+            else:
+                model.eval()
+
+            running_loss = 0.0
+            running_corrects = 0
+
+            for inputs, labels in tqdm(dataloaders[phase], desc=f'Epoch {epoch + 1} {phase.capitalize()}'):
+                inputs = inputs.to(device)
+                labels = labels.to(device)
+
+                optimizer.zero_grad()
+
+                with torch.set_grad_enabled(phase == 'train'):
+                    outputs = model(inputs)
+                    _, preds = torch.max(outputs, 1)
+                    loss = criterion(outputs, labels)
+
+                    if phase == 'train':
+                        loss.backward()
+                        optimizer.step()
+
+                running_loss += loss.item() * inputs.size(0)
+                running_corrects += torch.sum(preds == labels.data)
+
+            epoch_loss = running_loss / dataset_sizes[phase]
+            epoch_acc = running_corrects.double() / dataset_sizes[phase]
+
+            print(f'{phase} Loss: {epoch_loss:.4f} Acc: {epoch_acc:.4f}')
+
+        scheduler.step()
+
+        torch.save({
+            'epoch': epoch + 1,
+            'model_state_dict': model.state_dict(),
+            'optimizer_state_dict': optimizer.state_dict(),
+            'loss': epoch_loss,
+        }, f'{save_path}_epoch_{epoch + 1}.pth')
+
+
+if __name__ == '__main__':
+    train_model(model, dataloaders, dataset_sizes, device, num_epochs=10, save_path='')
\ No newline at end of file
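
The class counts in qrtrain.py are hard-coded (7891 normal vs. 846 anomaly). A sketch of deriving them from the ImageFolder dataset instead, so the weights stay correct when the data changes (variable names follow the script above):

    import numpy as np

    targets = np.array(image_datasets['train'].targets)  # class index per sample
    class_sample_counts = np.bincount(targets)           # e.g. array([7891, 846])
    weights = 1.0 / class_sample_counts
    weights /= weights.sum()
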
diff --git a/qrtrain_eval.py b/qrtrain_eval.py
new file mode 100644
index 0000000..7f38223
--- /dev/null
+++ b/qrtrain_eval.py
@@ -0,0 +1,56 @@
+import torch
+from torchvision import models, transforms
+from PIL import Image
+import torch.nn as nn
+
+# Load a saved checkpoint
+def load_model(save_path, device):
+    model = models.resnet18()  # rebuild the architecture
+    num_features = model.fc.in_features
+    model.fc = nn.Linear(num_features, 2)  # two classes, matching training
+
+    checkpoint = torch.load(save_path, map_location=device)
+    model.load_state_dict(checkpoint['model_state_dict'])
+
+    model = model.to(device)
+    model.eval()  # switch to evaluation mode
+    return model
+
+# Preprocess an image
+def preprocess_image(image_path):
+    data_transform = transforms.Compose([
+        transforms.Resize(256),
+        transforms.CenterCrop(224),
+        transforms.ToTensor(),
+        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
+    ])
+
+    image = Image.open(image_path).convert("RGB")
+    image = data_transform(image)
+    image = image.unsqueeze(0)  # add a batch dimension
+    return image
+
+# Run inference and return the predicted class index
+def predict(model, image_tensor, device):
+    image_tensor = image_tensor.to(device)
+    with torch.no_grad():
+        outputs = model(image_tensor)
+        _, preds = torch.max(outputs, 1)
+    return preds.item()
+
+if __name__ == '__main__':
+    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
+    model_path = '_epoch_10.pth'  # replace with your model path
+    image_path = '/Users/jasonwong/Downloads/二维码测试1/NormalJpg1/data/val/real/cropped_qr_code_09310.jpg'  # replace with the image to test
+
+    # Load the model
+    model = load_model(model_path, device)
+
+    # Preprocess the image
+    image_tensor = preprocess_image(image_path)
+
+    # Predict
+    prediction = predict(model, image_tensor, device)
+
+    # NOTE: verify this order. During training, ImageFolder sorts class folders
+    # alphabetically, so 'fake' maps to index 0 and 'real' to index 1.
+    class_names = ['Normal', 'Anomaly']
+    print(f'Predicted class: {class_names[prediction]}')
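
predict returns only the argmax; when triaging borderline QR codes it can help to see the model's confidence as well. A small extension sketch reusing the names from the script above (the function name is hypothetical):

    import torch
    import torch.nn.functional as F

    def predict_with_confidence(model, image_tensor, device):
        image_tensor = image_tensor.to(device)
        with torch.no_grad():
            probs = F.softmax(model(image_tensor), dim=1)  # class probabilities
        conf, pred = torch.max(probs, 1)
        return pred.item(), conf.item()
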
diff --git a/qrtrain_predict_v1.py b/qrtrain_predict_v1.py
new file mode 100644
index 0000000..d962df0
--- /dev/null
+++ b/qrtrain_predict_v1.py
@@ -0,0 +1,72 @@
+import os
+import cv2
+import torch
+import torch.nn as nn
+from torchvision import models, transforms
+from sklearn.svm import SVC
+import joblib
+from PIL import Image
+import numpy as np
+
+# Load a ResNet model and drop its final layer
+model = models.resnet50(pretrained=False)
+modules = list(model.children())[:-1]  # remove the final fully connected layer
+model = nn.Sequential(*modules)
+
+# Load the trained ResNet feature-extractor weights
+model.load_state_dict(torch.load('resnet_feature_extractor.pth'))
+model.eval()
+
+# Use the GPU if available
+device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
+model = model.to(device)
+
+# Load the SVM classifier
+svm = joblib.load('svm_classifier.joblib')
+
+# Preprocessing
+transform = transforms.Compose([
+    transforms.Resize((224, 224)),
+    transforms.ToTensor(),
+    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
+])
+
+# Extract ORB features
+def extract_orb_features(image_path):
+    image = cv2.imread(image_path, cv2.IMREAD_GRAYSCALE)
+    orb = cv2.ORB_create()
+    keypoints, descriptors = orb.detectAndCompute(image, None)
+    if descriptors is None:
+        return np.zeros((1, 32))  # return a zero array if no keypoints were found
+    return descriptors
+
+# Extract ResNet features
+def extract_resnet_features(image):
+    image = transform(image).unsqueeze(0)  # add a batch dimension
+    image = image.to(device)
+    with torch.no_grad():
+        features = model(image)
+        features = features.view(features.size(0), -1)  # flatten the feature vector
+    return features.cpu().numpy()
+
+# Inference
+def predict(image_path):
+    # Extract ORB features
+    orb_features = extract_orb_features(image_path)
+    orb_features = np.mean(orb_features, axis=0).reshape(1, -1)  # mean descriptor as a global feature
+
+    # Extract ResNet features
+    image = Image.open(image_path).convert('RGB')
+    resnet_features = extract_resnet_features(image)
+
+    # Feature fusion
+    features = np.hstack((resnet_features, orb_features))
+
+    # Classify
+    prediction = svm.predict(features)
+    return 'real' if prediction == 0 else 'fake'
+
+# Example inference call
+image_path = 'demo.png'
+result = predict(image_path)
+print(f'The QR code is {result}.')
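
The fused vector concatenates the flattened ResNet-50 pooling output (2048 values) with the 32-value mean ORB descriptor, so the SVM sees 2080 features per image; this must match the fusion used when the classifier was trained in qrtrain_v1.py. A quick shape check, illustrative and reusing the names from predict above:

    features = np.hstack((resnet_features, orb_features))
    assert features.shape == (1, 2048 + 32)  # ResNet-50 pooled features + mean ORB descriptor
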
diff --git a/qrtrain_v1.py b/qrtrain_v1.py
new file mode 100644
index 0000000..7770559
--- /dev/null
+++ b/qrtrain_v1.py
@@ -0,0 +1,120 @@
+import cv2
+import torch
+import torch.nn as nn
+import torch.optim as optim
+from torchvision import models, transforms
+from torch.utils.data import DataLoader, Dataset
+from sklearn.svm import SVC
+from sklearn.metrics import accuracy_score
+from PIL import Image
+import numpy as np
+import os
+from tqdm import tqdm
+import joblib
+
+# Custom dataset class
+class QRCodeDataset(Dataset):
+    def __init__(self, root_dir, transform=None):
+        self.root_dir = root_dir
+        self.transform = transform
+        self.image_paths = []
+        self.labels = []
+
+        for label in ['real', 'fake']:
+            label_dir = os.path.join(root_dir, label)
+            for img_name in os.listdir(label_dir):
+                if img_name.endswith(('.png', '.jpg', '.jpeg')):  # only add image files
+                    self.image_paths.append(os.path.join(label_dir, img_name))
+                    self.labels.append(0 if label == 'real' else 1)
+
+    def __len__(self):
+        return len(self.image_paths)
+
+    def __getitem__(self, idx):
+        img_path = self.image_paths[idx]
+        image = Image.open(img_path).convert('RGB')  # force 3 channels for the transform
+        label = self.labels[idx]
+
+        if self.transform:
+            image = self.transform(image)
+
+        return image, label
+
+# Preprocessing
+transform = transforms.Compose([
+    transforms.Resize((224, 224)),
+    transforms.ToTensor(),
+    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
+])
+
+# Load the datasets. shuffle must stay False here: the ResNet features come from
+# the loaders while the ORB features follow dataset.image_paths, and the two are
+# fused row by row, so their sample order has to match.
+train_dataset = QRCodeDataset(root_dir='/Users/jasonwong/Downloads/二维码测试1/data/train', transform=transform)
+train_loader = DataLoader(train_dataset, batch_size=32, shuffle=False)
+
+val_dataset = QRCodeDataset(root_dir='/Users/jasonwong/Downloads/二维码测试1/data/val', transform=transform)
+val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False)
+
+# Load a pretrained ResNet and drop its final layer
+model = models.resnet50(pretrained=True)
+modules = list(model.children())[:-1]  # remove the final fully connected layer
+model = nn.Sequential(*modules)
+
+# Use the GPU if available
+device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
+model = model.to(device)
+
+# Extract ORB features
+def extract_orb_features(image_path):
+    image = cv2.imread(image_path, cv2.IMREAD_GRAYSCALE)
+    orb = cv2.ORB_create()
+    keypoints, descriptors = orb.detectAndCompute(image, None)
+    if descriptors is None:
+        return np.zeros((1, 32))  # return a zero array if no keypoints were found
+    return descriptors
+
+# Extract ResNet features
+def extract_resnet_features(loader):
+    model.eval()
+    features = []
+    labels = []
+    with torch.no_grad():
+        for inputs, label in tqdm(loader, desc="Extracting ResNet features"):
+            inputs = inputs.to(device)
+            output = model(inputs)
+            output = output.view(output.size(0), -1)  # flatten the feature vectors
+            features.append(output.cpu())
+            labels.append(label)
+    return torch.cat(features), torch.cat(labels)
+
+# Extract ResNet features from the training and validation sets
+train_resnet_features, train_labels = extract_resnet_features(train_loader)
+val_resnet_features, val_labels = extract_resnet_features(val_loader)
+
+# Extract ORB features from the training and validation sets
+def extract_orb_features_from_dataset(dataset):
+    features = []
+    for img_path in tqdm(dataset.image_paths, desc="Extracting ORB features"):
+        orb_features = extract_orb_features(img_path)
+        features.append(np.mean(orb_features, axis=0))  # mean descriptor as a global feature
+    return np.array(features)
+
+train_orb_features = extract_orb_features_from_dataset(train_dataset)
+val_orb_features = extract_orb_features_from_dataset(val_dataset)
+
+# Feature fusion
+train_features = np.hstack((train_resnet_features.numpy(), train_orb_features))
+val_features = np.hstack((val_resnet_features.numpy(), val_orb_features))
+
+# Train and evaluate an SVM classifier
+svm = SVC(kernel='linear')
+svm.fit(train_features, train_labels)
+
+val_preds = svm.predict(val_features)
+val_acc = accuracy_score(val_labels, val_preds)
+print(f'Validation Accuracy: {val_acc:.4f}')
+
+# Save the ResNet feature-extractor weights
+torch.save(model.state_dict(), 'resnet_feature_extractor.pth')
+
+# Save the SVM classifier
+joblib.dump(svm, 'svm_classifier.joblib')
diff --git a/res/000035.jpg b/res/000035.jpg
new file mode 100644
index 0000000..799e3cc
Binary files /dev/null and b/res/000035.jpg differ
diff --git a/res/007563.jpg b/res/007563.jpg
new file mode 100644
index 0000000..2e3aebc
Binary files /dev/null and b/res/007563.jpg differ