"""Classify a single image with a fine-tuned ResNet-18 checkpoint.

Usage:
    python <this_script> [--model MODEL_PATH] [--image IMAGE_PATH]

With no arguments the default paths below are used.
"""
import argparse

import torch
import torch.nn as nn
from PIL import Image
from torchvision import models, transforms

# Defaults used when no command-line arguments are given.
DEFAULT_MODEL_PATH = '_epoch_10.pth'  # replace with your model path
DEFAULT_IMAGE_PATH = '/Users/jasonwong/Downloads/二维码测试1/NormalJpg1/data/val/real/cropped_qr_code_09310.jpg'  # replace with the image to test

# Output index -> human-readable label; its length sets the size of the fc head.
CLASS_NAMES = ['Normal', 'Anomaly']

# Eval-time preprocessing: resize, center-crop to 224x224, convert to a CHW
# float tensor, and normalize with the standard ImageNet channel mean/std.
# Built once at import instead of on every preprocess_image() call.
_DATA_TRANSFORM = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
])


def load_model(save_path, device, num_classes=2):
    """Rebuild a ResNet-18 with a ``num_classes``-way head and load saved weights.

    Args:
        save_path: Path to a file written by ``torch.save``. It may be either a
            dict holding the weights under ``'model_state_dict'`` or a bare
            state dict.
        device: ``torch.device`` to map the weights onto and move the model to.
        num_classes: Output size of the replacement ``fc`` layer. It must match
            the checkpoint; defaults to 2 (binary classification).

    Returns:
        The model on ``device``, switched to eval mode.
    """
    # Architecture only: no pretrained weights are requested, because every
    # parameter is overwritten by the checkpoint below.
    model = models.resnet18()
    model.fc = nn.Linear(model.fc.in_features, num_classes)

    # NOTE(review): recent PyTorch releases default torch.load to
    # weights_only=True. That is fine for tensors/ints/optimizer state but
    # rejects arbitrary pickled objects stored in the checkpoint.
    checkpoint = torch.load(save_path, map_location=device)
    if isinstance(checkpoint, dict) and 'model_state_dict' in checkpoint:
        state_dict = checkpoint['model_state_dict']
    else:
        # Assume the file is the state dict itself.
        state_dict = checkpoint
    model.load_state_dict(state_dict)

    model = model.to(device)
    model.eval()  # inference mode for dropout / BatchNorm layers
    return model


def preprocess_image(image_path):
    """Load an image file and return a normalized tensor of shape (1, 3, 224, 224)."""
    # Context manager closes the underlying file handle as soon as the pixels
    # have been converted, instead of leaving it to the garbage collector.
    with Image.open(image_path) as img:
        image = _DATA_TRANSFORM(img.convert("RGB"))
    return image.unsqueeze(0)  # add the batch dimension


def predict(model, image_tensor, device):
    """Run ``model`` on ``image_tensor`` and return the predicted class index.

    ``image_tensor`` must hold exactly one image, because ``.item()`` only
    accepts a single-element tensor.
    """
    image_tensor = image_tensor.to(device)
    with torch.no_grad():
        outputs = model(image_tensor)
    # Index of the largest logit along the class dimension.
    return outputs.argmax(dim=1).item()


def main(argv=None):
    """Parse CLI arguments, run inference on one image, and print the label."""
    parser = argparse.ArgumentParser(
        description='Classify one image with a saved ResNet-18 checkpoint.')
    parser.add_argument('--model', default=DEFAULT_MODEL_PATH,
                        help='path to the .pth checkpoint')
    parser.add_argument('--image', default=DEFAULT_IMAGE_PATH,
                        help='path to the image to classify')
    args = parser.parse_args(argv)

    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    model = load_model(args.model, device, num_classes=len(CLASS_NAMES))
    image_tensor = preprocess_image(args.image)
    prediction = predict(model, image_tensor, device)
    print(f'Predicted class: {CLASS_NAMES[prediction]}')


if __name__ == '__main__':
    main()