mydetection.py 4.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120
  1. import subprocess
  2. import os
  3. import numpy as np
  4. import cv2
  5. import matplotlib.pyplot as plt
  6. from matplotlib.patches import Rectangle
  7. plt.rcParams['font.sans-serif'] = ['SimHei']
  8. plt.rcParams['axes.unicode_minus'] = False
  9. from my_config import averagePooling, maxPooling, xyOverlap
  10. import sys
  11. plt.ioff()
  12. def ship_detection_fun(pic_path, label_path):
  13. sar_img = cv2.imread(pic_path)
  14. # 读取标签矩阵
  15. label_matrix = np.loadtxt(label_path)
  16. label_matrix[:, 2] -= label_matrix[:, 0]
  17. label_matrix[:, 3] -= label_matrix[:, 1]
  18. # 转换为灰度图像
  19. image_init = cv2.cvtColor(sar_img, cv2.COLOR_BGR2GRAY)
  20. # 进行卷积和池化操作
  21. image = cv2.filter2D(image_init, -1, np.ones((13, 13)) / 169)
  22. image = averagePooling(image, [5, 5])
  23. # 进行 GMM 聚类
  24. from sklearn.mixture import GaussianMixture
  25. numComponents = 10
  26. selected_components = 2
  27. gmm = GaussianMixture(n_components=numComponents)
  28. image_data = np.float32(image.reshape(-1, 1))
  29. gmm.fit(image_data)
  30. cluster_idx = gmm.predict(image_data)
  31. clusterCenters = gmm.means_
  32. max_idx_vec = np.argsort(clusterCenters.flatten())[-selected_components:]
  33. idx_keep = np.isin(cluster_idx, max_idx_vec)
  34. cluster_idx[~idx_keep] = 0
  35. out_put_image = cluster_idx.reshape(image.shape)
  36. out_put_image = cv2.filter2D(out_put_image, -1, np.ones((12, 12)) / 144)
  37. out_put_image = maxPooling(out_put_image, [5, 5])
  38. contours, _ = cv2.findContours(np.uint8(out_put_image), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
  39. boundingBoxes = np.zeros((len(contours), 5))
  40. for i, contour in enumerate(contours):
  41. x, y, w, h = cv2.boundingRect(contour)
  42. boundingBoxes[i, :] = [0, x, y, w, h]
  43. # 存储bounding box
  44. folder_path = './results_label/'
  45. if not os.path.exists(folder_path):
  46. os.makedirs(folder_path)
  47. filePath = f'{folder_path}{os.path.splitext(os.path.basename(pic_path))[0]}.txt'
  48. np.savetxt(filePath, boundingBoxes, fmt='%d', delimiter=' ')
  49. Ngt = len(label_matrix)
  50. Ntt = 0
  51. Nfa = 0
  52. for i in range(len(contours)):
  53. count = 0
  54. for j in range(len(label_matrix)):
  55. areaOverlap = xyOverlap(label_matrix[j], boundingBoxes[i, 1:])
  56. if areaOverlap == 1:
  57. count += 1
  58. break
  59. if count == 0:
  60. Nfa += 1
  61. else:
  62. Ntt += 1
  63. FoM = Ntt / (Ngt + Nfa)
  64. Precision = Ntt / (Ntt + Nfa)
  65. Far = Nfa / (Ntt + Nfa)
  66. # 显示图像结果并存储
  67. fig1, axes1 = plt.subplots(2, 1)
  68. axes1[0].imshow(image_init, cmap='gray')
  69. axes1[0].set_title('原始图像')
  70. axes1[1].imshow(image_init, cmap='gray')
  71. for i in range(len(contours)):
  72. axes1[1].add_patch(Rectangle((boundingBoxes[i, 1], boundingBoxes[i, 2]), boundingBoxes[i, 3], boundingBoxes[i, 4], edgecolor='r', linewidth=2, fill=False))
  73. axes1[1].set_title('检测结果')
  74. axes1[0].axis('off')
  75. axes1[1].axis('off')
  76. plt.axis('off')
  77. plt.tight_layout()
  78. folder_path = './results_pic/'
  79. if not os.path.exists(folder_path):
  80. os.makedirs(folder_path)
  81. plt.savefig(f'{folder_path}{os.path.splitext(os.path.basename(pic_path))[0]}.jpg', bbox_inches='tight', pad_inches= 0.0)
  82. plt.close(fig1)
  83. fig2 = plt.figure()
  84. plt.imshow(image_init, cmap='gray')
  85. for i in range(len(contours)):
  86. plt.gca().add_patch(Rectangle((boundingBoxes[i, 1], boundingBoxes[i, 2]), boundingBoxes[i, 3], boundingBoxes[i, 4], edgecolor='r', linewidth=2, fill=False))
  87. plt.axis('off')
  88. folder_path = './results_pic_original/'
  89. if not os.path.exists(folder_path):
  90. os.makedirs(folder_path)
  91. plt.savefig(f'{folder_path}{os.path.splitext(os.path.basename(pic_path))[0]}.jpg', bbox_inches='tight', pad_inches = 0.0)
  92. absolute_path = os.path.abspath(f'{folder_path}{os.path.splitext(os.path.basename(pic_path))[0]}.jpg')
  93. print(absolute_path)
  94. if __name__ == "__main__":
  95. if len(sys.argv) != 3:
  96. sys.exit(1)
  97. pic_path = sys.argv[1]
  98. label_path = sys.argv[2]
  99. ship_detection_fun(pic_path, label_path)