pythonproject/qrtrain_eval.py

57 lines
1.8 KiB
Python
Raw Permalink Normal View History

2024-08-07 09:46:47 +08:00
import torch
from torchvision import models, transforms
from PIL import Image
import torch.nn as nn
def load_model(save_path, device, num_classes=2):
    """Rebuild a ResNet-18 classifier and load trained weights into it.

    Args:
        save_path: Path of a checkpoint written with ``torch.save``. It may be
            either a dict holding the weights under ``'model_state_dict'``
            (the layout this script was written for) or a bare state dict.
        device: ``torch.device`` or device string; the weights are mapped onto
            it while loading and the model is moved there.
        num_classes: Number of outputs of the replacement final ``fc`` layer.
            Must match the checkpoint; defaults to 2 (binary classification).

    Returns:
        The ``nn.Module`` on ``device``, switched to evaluation mode.
    """
    # Build the bare architecture (no pretrained weights requested); every
    # parameter is overwritten from the checkpoint below.
    model = models.resnet18()
    # Replace the classification head so its size matches the trained model.
    model.fc = nn.Linear(model.fc.in_features, num_classes)
    # NOTE(review): torch.load unpickles the file -- only load trusted checkpoints.
    checkpoint = torch.load(save_path, map_location=device)
    # Training checkpoints often bundle the weights with optimizer/epoch
    # metadata; accept that layout as well as a plain state dict.
    if isinstance(checkpoint, dict) and 'model_state_dict' in checkpoint:
        state_dict = checkpoint['model_state_dict']
    else:
        state_dict = checkpoint
    model.load_state_dict(state_dict)
    model = model.to(device)
    # Switch layers such as BatchNorm to inference behavior.
    model.eval()
    return model
def preprocess_image(image_path):
    """Load an image file and turn it into a normalized single-image batch.

    The image is converted to RGB, resized so its shorter side is 256 px,
    center-cropped to 224x224, converted to a float tensor and normalized
    per channel.

    Args:
        image_path: Path (or anything ``PIL.Image.open`` accepts) of the image.

    Returns:
        A tensor of shape ``(1, 3, 224, 224)``.
    """
    data_transform = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        # ImageNet channel mean/std; presumably the same values used by the
        # training pipeline -- confirm if training preprocessing changes.
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ])
    # Close the file deterministically. convert() returns a new, fully loaded
    # image, so it stays valid after the file is closed.
    with Image.open(image_path) as img:
        image = img.convert("RGB")
    image = data_transform(image)
    # Add the leading batch dimension expected by the model.
    return image.unsqueeze(0)
def predict(model, image_tensor, device):
    """Run one forward pass and return the index of the top-scoring class.

    Args:
        model: Callable (e.g. an ``nn.Module``) whose output holds the class
            scores along dimension 1.
        image_tensor: Input batch; it is moved to ``device`` before the call.
        device: ``torch.device`` or device string to run the forward pass on.

    Returns:
        The predicted class index as a Python ``int``. ``.item()`` only
        succeeds when the batch holds exactly one sample.
    """
    batch = image_tensor.to(device)
    # Inference only: skip autograd graph construction.
    with torch.no_grad():
        scores = model(batch)
        best_index = scores.max(dim=1).indices
    return best_index.item()
if __name__ == '__main__':
    import argparse

    # Paths are command-line options; the defaults are the values that used to
    # be hard-coded, so running without arguments behaves as before.
    parser = argparse.ArgumentParser(
        description='Classify a single image with a trained ResNet-18 checkpoint.')
    parser.add_argument(
        '--model-path',
        default='_epoch_10.pth',
        help='checkpoint file to load (default: %(default)s)')
    parser.add_argument(
        '--image-path',
        default='/Users/jasonwong/Downloads/二维码测试1/NormalJpg1/data/val/real/cropped_qr_code_09310.jpg',
        help='image file to classify')
    args = parser.parse_args()

    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    # Load the trained model.
    model = load_model(args.model_path, device)
    # Preprocess the input image into a 1-image batch.
    image_tensor = preprocess_image(args.image_path)
    # Run inference.
    prediction = predict(model, image_tensor, device)
    # NOTE(review): list index must equal the training label index. If training
    # used torchvision ImageFolder, labels follow the alphabetical order of the
    # class folder names -- confirm this mapping.
    class_names = ['Normal', 'Anomaly']
    print(f'Predicted class: {class_names[prediction]}')