57 lines
1.8 KiB
Python
57 lines
1.8 KiB
Python
![]() |
import torch
|
||
|
from torchvision import models, transforms
|
||
|
from PIL import Image
|
||
|
import torch.nn as nn
|
||
|
|
||
|
# Load a saved model checkpoint.
def load_model(save_path, device, num_classes=2):
    """Rebuild a ResNet-18 classifier and load trained weights into it.

    Args:
        save_path: Path to a file written by ``torch.save``. Either a dict that
            stores the weights under ``'model_state_dict'`` (the layout this
            script originally required) or a bare state dict.
        device: ``torch.device`` (or device string) the tensors are mapped to
            on load and the model is moved to.
        num_classes: Output size of the replacement final linear layer.
            Defaults to 2, the value previously hard-coded. It must match the
            checkpoint, otherwise ``load_state_dict`` raises.

    Returns:
        The model, on ``device`` and switched to evaluation mode.
    """
    # The architecture must match the one used for training: a plain
    # ResNet-18 (no pretrained weights requested) whose fc head is replaced
    # by a num_classes-way linear layer.
    model = models.resnet18()
    model.fc = nn.Linear(model.fc.in_features, num_classes)

    # NOTE(review): torch.load unpickles the file; only load checkpoints
    # from trusted sources.
    checkpoint = torch.load(save_path, map_location=device)
    # Accept both {'model_state_dict': ...} training checkpoints and raw
    # state dicts instead of failing with a KeyError on the latter.
    if isinstance(checkpoint, dict) and 'model_state_dict' in checkpoint:
        state_dict = checkpoint['model_state_dict']
    else:
        state_dict = checkpoint
    model.load_state_dict(state_dict)

    model = model.to(device)
    model.eval()  # inference behaviour for layers such as BatchNorm
    return model
|
||
|
|
||
|
# Preprocess an image for inference.
def preprocess_image(image_path):
    """Load an image file and turn it into a normalized single-image batch.

    Pipeline:
        - Resize so the shorter side is 256.
        - Center-crop to 224x224.
        - Convert to a float tensor.
        - Normalize with the per-channel mean/std commonly used for
          ImageNet-pretrained models.

    Args:
        image_path: Anything ``PIL.Image.open`` accepts (path or file object).

    Returns:
        A tensor of shape (1, 3, 224, 224).
    """
    data_transform = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ])

    # Open inside a context manager so the file handle is released promptly.
    # convert() returns an independent, fully loaded copy, so it stays usable
    # after the source image is closed.
    with Image.open(image_path) as img:
        rgb_image = img.convert("RGB")

    batch = data_transform(rgb_image).unsqueeze(0)  # add the batch dimension
    return batch
|
||
|
|
||
|
# Run inference and return the predicted class index.
def predict(model, image_tensor, device):
    """Return the index of the highest-scoring class for one input.

    Args:
        model: Callable mapping the input batch to scores; presumably a module
            already in eval mode returning (batch, num_classes) logits.
        image_tensor: Input batch; it is moved to ``device`` before the
            forward pass.
        device: Target device for the input.

    Returns:
        The argmax over dimension 1 as a Python int. ``.item()`` requires
        the batch to hold exactly one sample.
    """
    batch = image_tensor.to(device)
    # Autograd is not needed for inference; disabling it saves memory.
    with torch.no_grad():
        logits = model(batch)
        best_class = torch.max(logits, dim=1).indices
    return best_class.item()
|
||
|
|
||
|
def main(argv=None):
    """Command-line entry point: classify one image with a saved checkpoint.

    Args:
        argv: Optional list of command-line arguments; ``None`` means
            ``sys.argv[1:]``. Running with no arguments reproduces the
            original hard-coded paths.
    """
    import argparse

    parser = argparse.ArgumentParser(
        description='Classify a single image with a trained ResNet-18 checkpoint.')
    parser.add_argument(
        '--model-path', default='_epoch_10.pth',
        help='path to the saved model checkpoint (default: %(default)s)')
    parser.add_argument(
        '--image-path',
        default='/Users/jasonwong/Downloads/二维码测试1/NormalJpg1/data/val/real/cropped_qr_code_09310.jpg',
        help='path to the image to classify')
    args = parser.parse_args(argv)

    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    # Load the model.
    model = load_model(args.model_path, device)

    # Preprocess the image.
    image_tensor = preprocess_image(args.image_path)

    # Run the prediction.
    prediction = predict(model, image_tensor, device)

    # NOTE(review): this index -> name order must match the label order used
    # during training. If training used torchvision ImageFolder, labels follow
    # the alphabetical order of the class folder names; the default image
    # lives in a "real" folder, so confirm this mapping.
    class_names = ['Normal', 'Anomaly']
    print(f'Predicted class: {class_names[prediction]}')


if __name__ == '__main__':
    main()
|