1         ##完全使用本地权重,识别时根据识别准确率来确定是否绘制
 2         import matplotlib.pyplot as plt
 3         import torch
 4         import torchvision.transforms as T
 5         import torchvision
 6         import cv2
 7         from torchvision.io.image import read_image
 8         from torchvision.models.detection import FasterRCNN_ResNet50_FPN_V2_Weights
 9 
10         import warnings
11         warnings.filterwarnings("ignore",category=ResourceWarning)
12         warnings.filterwarnings("ignore",category=DeprecationWarning)
13 
116         img_path = "./jupyterlab/doc/ccc.jpg"        ##骑着自行车的美女,任选
17         img = read_image(img_path)##用pytorch提供的io函数
18 
19         weights_info = FasterRCNN_ResNet50_FPN_V2_Weights.DEFAULT
20         ##读本地权重文件,权重文件到pytorch网站下载
21         model = torchvision.models.detection.maskrcnn_resnet50_fpn_v2(weights=None, progress=False, weights_backbone=None)
22         myweights = torch.load('E:/study_2022/working_python/maskrcnn_resnet50_fpn_v2_coco-73cbd019.pth')
23         model.load_state_dict(myweights)
24         model.eval()##识别工作模式
25         
26         preprocess = weights_info.transforms()
27         batch = [preprocess(img)]
28         prediction = model(batch)[0]
29         labels = [weights_info.meta["categories"][i] for i in prediction["labels"]]
30         boxes = [i for i in prediction["boxes"]]
31         scores = [i for i in prediction["scores"]]
32 
33         myimg = cv2.imread(img_path)
35         myimg = cv2.cvtColor(myimg, cv2.COLOR_BGR2RGB)
36         for i,score in enumerate(scores):
37             if score.item() < 0.9 : continue##舍弃准确率90%以下的
38             myimg = cv2.addWeighted(myimg, alpha=0.5, src2=myimg, beta=0.5, gamma=1)
39             ##注意:cv2这里只接受整型坐标值
40             start_point = (int(boxes[i][0]), int(boxes[i][1]))
41             end_point = (int(boxes[i][2]), int(boxes[i][3]))
42             cv2.rectangle(myimg, start_point, end_point, color = (255,0,0), thickness=3)
43             cv2.putText(myimg, labels[i], start_point, cv2.FONT_HERSHEY_SIMPLEX, 2, color = (255,0,0), thickness=3)
44         plt.figure(figsize=(7, 5))
45         plt.imshow(myimg)
46         plt.xticks([])
47         plt.yticks([])
48         plt.show()

 

原文地址:http://www.cnblogs.com/ace007/p/16828883.html

1. 本站所有资源来源于用户上传和网络,如有侵权请邮件联系站长! 2. 分享目的仅供大家学习和交流,请务用于商业用途! 3. 如果你也有好源码或者教程,可以到用户中心发布,分享有积分奖励和额外收入! 4. 本站提供的源码、模板、插件等等其他资源,都不包含技术服务请大家谅解! 5. 如有链接无法下载、失效或广告,请联系管理员处理! 6. 本站资源售价只是赞助,收取费用仅维持本站的日常运营所需! 7. 如遇到加密压缩包,默认解压密码为"gltf",如遇到无法解压的请联系管理员! 8. 因为资源和程序源码均为可复制品,所以不支持任何理由的退款兑现,请斟酌后支付下载 声明:如果标题没有注明"已测试"或者"测试可用"等字样的资源源码均未经过站长测试.特别注意没有标注的源码不保证任何可用性