# create_eval.py (12 KB)

import base64
import glob
import io
import json
import math
import os
import os.path as osp
import random
import shutil

import cv2
import numpy as np
import pandas as pd
import PIL.Image
import shapely
import shapely
from shapely import ops
from shapely.geometry import box, Polygon, MultiPolygon, collection
from shapely.geometry import Point
  19. root_path = "../runs/ma/2023-06-28/detect/exp"
  20. mask_iou=True
  21. cls_names=["crack","hole","debonding","rarefaction"]
  22. gtss=dict() #所有图片的gt,字典的键是图片名,值是掩码类别和坐标
  23. predss=dict() #所有图片的预测结果,字典是图片名,值是掩码类别和坐标
  24. filenames=dict() #键值对是图片名称和它对应的完整存储路径
  25. with open(os.path.join(root_path,"detect_for_val.txt"),"r", encoding="utf-8") as f:
  26. lines=f.readlines()
  27. for line in lines:
  28. file = line.split(" ",2)[-1].rsplit(":",1)[0]
  29. jfile = file.rsplit(".",1)[0]+".json"
  30. bsname = os.path.basename(file).rsplit(".",1)[0] #图片基本名称
  31. #gts[parent+"+"+bsname]=[]
  32. gts_cls = []
  33. for _ in range(len(cls_names)):
  34. gts_cls.append([])
  35. print("file.strip(): ", file.strip())
  36. if os.path.exists(jfile.strip()):
  37. with open(jfile.strip(), "rb") as f_json: #读取gt文件的json文件
  38. json_dict = json.load(f_json)
  39. for shape in json_dict["shapes"]: #读取标注的每一个掩码
  40. ps = []
  41. # 如果是圆,就将圆转化为多边形
  42. if len(shape["points"]) == 2: #圆有的是采用两点标注法,计算圆心和圆的半径,然后转化为Polygon
  43. # print("circle,dst_img_path:",dst_img_path)
  44. r = math.sqrt(math.pow(shape["points"][0][0] - shape["points"][1][0], 2) + math.pow(
  45. shape["points"][0][1] - shape["points"][1][1], 2))
  46. circle = Point(shape["points"][0]).buffer(r)
  47. poly1 = Polygon(circle)
  48. else:
  49. poly1 = Polygon(shape["points"])
  50. # print("bsname: ", bsname,parent)
  51. # if bsname=="478.521mm" and parent=="BB1B2-20170928":
  52. # print("shape: ", shape)
  53. if "crack" in shape["label"] : #只挑选自己指定的类别,设备类不做指标计算
  54. idx = cls_names.index("crack")
  55. gts_cls[idx].append(poly1) #gts_cls按类别保存每一类的所有标注框
  56. elif "hole" in shape["label"] : #只挑选自己指定的类别,设备类不做指标计算
  57. idx = cls_names.index("hole")
  58. gts_cls[idx].append(poly1) #gts_cls按类别保存每一类的所有标注框
  59. elif "debonding" in shape["label"]: # 只挑选自己指定的类别,设备类不做指标计算
  60. idx = cls_names.index("debonding")
  61. gts_cls[idx].append(poly1) # gts_cls按类别保存每一类的所有标注框
  62. elif "rarefaction" in shape["label"]: # 只挑选自己指定的类别,设备类不做指标计算
  63. idx = cls_names.index("rarefaction")
  64. gts_cls[idx].append(poly1) # gts_cls按类别保存每一类的所有标注框
  65. gtss[bsname]=gts_cls
  66. filenames[bsname]=file
  67. print("gtss: ",len(gtss),gtss)
  68. for file in glob.glob(root_path+"/labels/*"): #遍历每一个检测结果txt文件
  69. bsname=os.path.basename(file).rsplit(".",1)[0]
  70. preds_cls = []
  71. for _ in range(len(cls_names)):
  72. preds_cls.append([])
  73. with open(file,"r", encoding="utf-8") as f:
  74. for line in f.readlines():
  75. line=line.strip().split()
  76. if len(line)==0:
  77. continue
  78. cls_id=int(line[0])
  79. if cls_id>len(cls_names):
  80. print("cls_id 超出类别索引")
  81. else:
  82. ps = line[1:]
  83. polygon_ps = []
  84. for i in range(0, len(ps), 2):
  85. x, y = int(float(ps[i])), int(float(ps[i + 1]))
  86. polygon_ps.append([x, y ])
  87. poly = Polygon(polygon_ps)
  88. preds_cls[cls_id].append(poly)
  89. predss[bsname]=preds_cls
  90. print("predss: ",len(predss),predss)
  91. nums=dict()
  92. tp_gt_preds=[]
  93. prs=[]
  94. for _ in range(len(cls_names)):
  95. tp_gt_pred_cls=[]
  96. pr_cls=[]
  97. for _ in range(3):
  98. tp_gt_pred_cls.append(0) #装每一类的tp、gt、pred
  99. pr_cls.append([])
  100. pr_cls.append([]) #装每一类的p、r
  101. tp_gt_preds.append(tp_gt_pred_cls)
  102. prs.append(pr_cls)
  103. #print("tp_gt_preds: ", tp_gt_preds)
  104. for k,gts_cls in gtss.items(): #遍历每一个gt文件的真值框
  105. #print("k:",k,gts_cls)
  106. tp_gt_pred_cls=[]
  107. #print(type(predss))
  108. if k not in predss.keys():
  109. for idx in range(len(cls_names)): # 按类别遍历每一个类别的真值框
  110. tp_gt_pred = [0, 0, 0, 0, 0]
  111. gts = gts_cls[idx]
  112. gt_num = len(gts)
  113. tp_gt_pred[2] = gt_num
  114. tp_gt_pred_cls.append(tp_gt_pred)
  115. nums[k] = tp_gt_pred_cls
  116. continue
  117. preds_cls=predss[k]
  118. #print(k,gts_cls)
  119. #print(k,preds_cls)
  120. for idx in range(len(cls_names)): #按类别遍历每一个类别的真值框
  121. tp_gt_pred = [0, 0, 0,0,0]
  122. gts,preds = gts_cls[idx],preds_cls[idx]
  123. gt_num,pred_num=len(gts),len(preds)
  124. tp_gt_pred[2]=gt_num
  125. tp_gt_pred[1]=pred_num
  126. tp_num=0
  127. pred_match=[]
  128. for _ in range(len(preds)):
  129. pred_match.append(0)
  130. for gt in gts: #遍历每一个真值框
  131. gt_ps = (list(gt.exterior.coords))
  132. xs=[]
  133. ys=[]
  134. for p in gt_ps:
  135. xs.append(p[0])
  136. ys.append(p[1])
  137. gt_x1,gt_x2,gt_y1,gt_y2=min(xs),max(xs),min(ys),max(ys) #根据真值多边形的点坐标,计算外接矩形的左上角和右下角坐标
  138. gt_a=(gt_x2-gt_x1)*(gt_y2-gt_y1)
  139. # print(type(union)) #<class 'shapely.geometry.multipoint.MultiPoint'>
  140. ps = np.array(ps, dtype=np.float).astype(np.int)
  141. ps = ps.reshape(-1, 2)
  142. for i in range(len(preds)): #遍历对应图片的对应类的所有检测框
  143. pred=preds[i]
  144. if pred_match[i]==1: #如果预测框和真值框匹配上了,就不再遍历该预测框
  145. continue
  146. pred_ps = (list(pred.exterior.coords)) #预测多边形的点集
  147. xs = []
  148. ys = []
  149. for p in pred_ps:
  150. xs.append(p[0])
  151. ys.append(p[1])
  152. pred_x1, pred_x2, pred_y1, pred_y2 = min(xs), max(xs), min(ys), max(ys) #预测多边形的最小外接矩形
  153. pred_a = (pred_x2 - pred_x1) * (pred_y2 - pred_y1)
  154. x1=max([gt_x1,pred_x1])
  155. x2 = min([gt_x2, pred_x2])
  156. y1 = max([gt_y1, pred_y1])
  157. y2 = min([gt_y2, pred_y2]) #预测框与真值框的交集矩形
  158. if x2>x1 and y2>y1:
  159. inter_a=(x2-x1)*(y2-y1)
  160. iou=inter_a/(gt_a+pred_a-inter_a) #
  161. matched=False
  162. if "crack" in cls_names[idx]: #不同类别有不同的tp iou阈值
  163. if iou>0.3:
  164. matched=True
  165. elif "hole" in cls_names[idx]:
  166. if iou > 0.2:
  167. matched = True
  168. elif "debonding" in cls_names[idx]:
  169. if iou > 0.4:
  170. matched = True
  171. elif "rarefaction" in cls_names[idx]:
  172. if iou>0.5:
  173. matched=True
  174. if matched: #如果预测框和真值框匹配上了,对应的标记位就置为1
  175. pred_match[i]=1
  176. tp_num +=1
  177. break #只要有一个pred与gt匹配上了,tp就加1,然后退出,不再遍历后面的pred
  178. tp_gt_pred[0] = tp_num #保存当前图片的tp,并计算该图片的p和r
  179. tp_gt_pred[3] = tp_num * 1.0 / (pred_num + 1.0e-10)
  180. tp_gt_pred[4] = tp_num * 1.0 / (gt_num + 1.0e-10)
  181. tp_gt_pred_cls.append(tp_gt_pred)
  182. #print("tp_gt_preds: ", tp_gt_preds,i)
  183. tp_gt_preds[idx][0] += tp_num #将当前类的所有图片的tp、pred和gt累加起来
  184. tp_gt_preds[idx][2] += gt_num
  185. tp_gt_preds[idx][1] += pred_num
  186. nums[k]=tp_gt_pred_cls #k是图片名,值是该图每一类的tp、pred、gt、p和r
  187. #print("k , nums[k]: ", k, nums[k])
  188. dict_e=dict()
  189. dict_e["图片名称"]=[]
  190. for j in range(len(cls_names)):
  191. for i in ["tp","pred","gt","p","r"]:
  192. dict_e[str(j)+"_"+i]=[] #装tp、gt、pred,p,r
  193. #将所有图片的检测结果保存到字典里,准备存入excel表里,表列名就是图片名和每一类的tp、gt、pred,p,r
  194. for name,tp_gt_pred_cls in nums.items():
  195. dict_e["图片名称"].append(name)
  196. for i in range(len(tp_gt_pred_cls)):
  197. inds=["tp","pred","gt","p","r"]
  198. for j in range(5): #tp,gt,pred,p,r
  199. dict_e[str(i)+"_"+inds[j]].append(tp_gt_pred_cls[i][j])
  200. #print("dict_e: ", dict_e)
  201. dict_e["图片名称"].insert(0,"合计")
  202. #将所有图片的tp、gt、pred累加起来,计算总的tp、gt、pred,并在此基础上计算所有图片每一类的tp、gt、pred,p,r。
  203. for i in range(len(cls_names)):
  204. inds = ["tp", "pred", "gt"]
  205. for j in range(3): # tp,gt,pred
  206. t =sum(dict_e[str(i) + "_" + inds[j]])
  207. dict_e[str(i) + "_" + inds[j]].insert(0,t)
  208. tps=dict_e[str(i) + "_" + inds[0]][0]
  209. preds = dict_e[str(i) + "_" + inds[1]][0]
  210. gts = dict_e[str(i) + "_" + inds[2]][0]
  211. p_t=tps/(preds+1.0e-10)
  212. r_t=tps/(gts+1.0e-10)
  213. dict_e[str(i) + "_p" ].insert(0,p_t)
  214. dict_e[str(i) + "_r"].insert(0,r_t)
  215. dict_g=dict()
  216. dict_g["图片名称"]=[]
  217. for j in range(len(cls_names)):
  218. for i in ["tp","pred","gt","p","r"]:
  219. dict_g[str(j)+"_"+i]=[] #装tp、gt、pred,p,r
  220. for i in range(2,len(dict_e["图片名称"]),):
  221. good=True
  222. for j in range(len(cls_names)):
  223. inds = ["tp", "pred", "gt","p","r"]
  224. if dict_e[str(j)+"_"+inds[2]][i]>0:
  225. if dict_e[str(j)+"_"+inds[3]][i]<0.5 or dict_e[str(j)+"_"+inds[4]][i]<0.4:
  226. good=False
  227. break
  228. else:
  229. if dict_e[str(j)+"_"+inds[1]][i]>0:
  230. good=False
  231. break
  232. if good:
  233. dict_g["图片名称"].append(dict_e["图片名称"][i])
  234. for j in range(len(cls_names)):
  235. inds = ["tp", "pred", "gt", "p", "r"]
  236. for k in range(len(inds)):
  237. dict_g[str(j) + "_" + inds[k]].append(dict_e[str(j) + "_" + inds[k]][i])
  238. dict_g["图片名称"].insert(0, "合计")
  239. for i in range(len(cls_names)):
  240. inds = ["tp", "pred", "gt"]
  241. for j in range(3): # tp,gt,pred
  242. t = sum(dict_g[str(i) + "_" + inds[j]])
  243. dict_g[str(i) + "_" + inds[j]].insert(0, t)
  244. tps = dict_g[str(i) + "_" + inds[0]][0]
  245. preds = dict_g[str(i) + "_" + inds[1]][0]
  246. gts = dict_g[str(i) + "_" + inds[2]][0]
  247. p_t = tps / (preds + 1.0e-10)
  248. r_t = tps / (gts + 1.0e-10)
  249. dict_g[str(i) + "_p"].insert(0, p_t)
  250. dict_g[str(i) + "_r"].insert(0, r_t)
  251. dst_file= "../eval/eval.txt"
  252. with open(dst_file, "w",encoding="utf-8") as f:
  253. for i in range(2, len(dict_g["图片名称"])):
  254. f.writelines(filenames[dict_g["图片名称"][i]]+"\n")