# YOLOv5 🚀 by Ultralytics, GPL-3.0 license
"""
Plotting utils
"""

import contextlib
import math
import os
from copy import copy
from pathlib import Path
from urllib.error import URLError

import cv2
import matplotlib
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sn
import torch
from PIL import Image, ImageDraw, ImageFont

from utils import TryExcept, threaded
from utils.general import (CONFIG_DIR, FONT, LOGGER, check_font, check_requirements, clip_boxes, increment_path,
                           is_ascii, xywh2xyxy, xyxy2xywh)
from utils.metrics import fitness
from utils.segment.general import scale_image

# Settings
RANK = int(os.getenv('RANK', -1))
matplotlib.rc('font', **{'size': 11})
matplotlib.use('Agg')  # for writing to files only

class Colors:
    # Ultralytics color palette https://ultralytics.com/
    def __init__(self):
        # hex = matplotlib.colors.TABLEAU_COLORS.values()
        hexs = ('FF3838', 'FF9D97', 'FF701F', 'FFB21D', '0018EC', '48F90A', '92CC17', '3DDB86', '1A9334', '00D4BB',
                '2C99A8', '00C2FF', '344593', '6473FF', 'CFD231', '8438FF', '520085', 'CB38FF', 'FF95C8', 'FF37C7')
        self.palette = [self.hex2rgb(f'#{c}') for c in hexs]
        self.n = len(self.palette)

    def __call__(self, i, bgr=False):
        c = self.palette[int(i) % self.n]
        return (c[2], c[1], c[0]) if bgr else c

    @staticmethod
    def hex2rgb(h):  # rgb order (PIL)
        return tuple(int(h[1 + i:1 + i + 2], 16) for i in (0, 2, 4))


colors = Colors()  # create instance for 'from utils.plots import colors'
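
# Usage sketch (illustrative comment, not part of the original module): `colors` maps any
# integer class index to a fixed palette entry, cycling modulo the 20-entry palette:
#   colors(0)            # (255, 56, 56)  RGB tuple for PIL drawing
#   colors(0, bgr=True)  # (56, 56, 255)  BGR tuple for cv2 drawing
#   colors(23)           # same as colors(3), indices wrap around the palette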

def check_pil_font(font=FONT, size=10):
    # Return a PIL TrueType Font, downloading to CONFIG_DIR if necessary
    font = Path(font)
    font = font if font.exists() else (CONFIG_DIR / font.name)
    try:
        return ImageFont.truetype(str(font) if font.exists() else font.name, size)
    except Exception:  # download if missing
        try:
            check_font(font)
            return ImageFont.truetype(str(font), size)
        except TypeError:
            check_requirements('Pillow>=8.4.0')  # known issue https://github.com/ultralytics/yolov5/issues/5374
        except URLError:  # not online
            return ImageFont.load_default()
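
# Note (added for clarity): e.g. check_pil_font(size=12) returns a PIL ImageFont for FONT,
# downloading the font file to CONFIG_DIR on first use and falling back to
# ImageFont.load_default() when offline.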

class Annotator:
    # YOLOv5 Annotator for train/val mosaics and jpgs and detect/hub inference annotations
    def __init__(self, im, line_width=None, font_size=None, font='Arial.ttf', pil=False, example='abc'):
        assert im.data.contiguous, 'Image not contiguous. Apply np.ascontiguousarray(im) to Annotator() input images.'
        non_ascii = not is_ascii(example)  # non-latin labels, i.e. asian, arabic, cyrillic
        self.pil = pil or non_ascii
        if self.pil:  # use PIL
            self.im = im if isinstance(im, Image.Image) else Image.fromarray(im)
            self.draw = ImageDraw.Draw(self.im)
            self.font = check_pil_font(font='Arial.Unicode.ttf' if non_ascii else font,
                                       size=font_size or max(round(sum(self.im.size) / 2 * 0.035), 12))
        else:  # use cv2
            self.im = im
        self.lw = line_width or max(round(sum(im.shape) / 2 * 0.003), 2)  # line width

    def box_label(self, box, label='', color=(128, 128, 128), txt_color=(255, 255, 255), flag=True):
        # Add one xyxy box to image with label
        if self.pil or not is_ascii(label):
            if flag:
                self.draw.rectangle(box, width=self.lw, outline=color)  # box
            if label:
                w, h = self.font.getsize(label)  # text width, height
                outside = box[1] - h >= 0  # label fits outside box
                if flag:
                    self.draw.rectangle(
                        (box[0], box[1] - h if outside else box[1], box[0] + w + 1,
                         box[1] + 1 if outside else box[1] + h + 1),
                        fill=color,
                    )
                    # self.draw.text((box[0], box[1]), label, fill=txt_color, font=self.font, anchor='ls')  # for PIL>8.0
                    self.draw.text((box[0], box[1] - h if outside else box[1]), label, fill=txt_color, font=self.font)
                else:
                    self.draw.text((box[0], box[1] - h if outside else box[1]), label.split(' ')[0], fill=color,
                                   font=self.font)
        else:  # cv2
            p1, p2 = (int(box[0]), int(box[1])), (int(box[2]), int(box[3]))
            if flag:
                cv2.rectangle(self.im, p1, p2, color, thickness=self.lw, lineType=cv2.LINE_AA)
            if label:
                tf = max(self.lw - 1, 1)  # font thickness
                w, h = cv2.getTextSize(label, 0, fontScale=self.lw / 3, thickness=tf)[0]  # text width, height
                outside = p1[1] - h >= 3
                p2 = p1[0] + w, p1[1] - h - 3 if outside else p1[1] + h + 3
                if flag:
                    cv2.rectangle(self.im, p1, p2, color, -1, cv2.LINE_AA)  # filled
                    cv2.putText(self.im,
                                label, (p1[0], p1[1] - 2 if outside else p1[1] + h + 2),
                                0,
                                self.lw / 3,
                                txt_color,
                                thickness=tf,
                                lineType=cv2.LINE_AA)
                else:
                    cv2.putText(self.im,
                                label.split(' ')[0], (p1[0], p1[1] - 2 if outside else p1[1] + h + 2),
                                0,
                                self.lw / 3,
                                color,
                                thickness=tf,
                                lineType=cv2.LINE_AA)

    def masks(self, masks, colors, im_gpu, alpha=0.5, retina_masks=False):
        """Plot masks at once.
        Args:
            masks (tensor): predicted masks on cuda, shape: [n, h, w]
            colors (List[List[Int]]): colors for predicted masks, [[r, g, b] * n]
            im_gpu (tensor): img is in cuda, shape: [3, h, w], range: [0, 1]
            alpha (float): mask transparency: 0.0 fully transparent, 1.0 opaque
        """
        if self.pil:
            # convert to numpy first
            self.im = np.asarray(self.im).copy()
        if len(masks) == 0:
            self.im[:] = im_gpu.permute(1, 2, 0).contiguous().cpu().numpy() * 255
        colors = torch.tensor(colors, device=im_gpu.device, dtype=torch.float32) / 255.0
        colors = colors[:, None, None]  # shape(n,1,1,3)
        masks = masks.unsqueeze(3)  # shape(n,h,w,1)
        masks_color = masks * (colors * alpha)  # shape(n,h,w,3)

        inv_alph_masks = (1 - masks * alpha).cumprod(0)  # shape(n,h,w,1)
        mcs = (masks_color * inv_alph_masks).sum(0) * 2  # mask color summand shape(n,h,w,3)

        im_gpu = im_gpu.flip(dims=[0])  # flip channel
        im_gpu = im_gpu.permute(1, 2, 0).contiguous()  # shape(h,w,3)
        im_gpu = im_gpu * inv_alph_masks[-1] + mcs
        im_mask = (im_gpu * 255).byte().cpu().numpy()
        self.im[:] = im_mask if retina_masks else scale_image(im_gpu.shape, im_mask, self.im.shape)
        if self.pil:
            # convert im back to PIL and update draw
            self.fromarray(self.im)

    def masks_cpu(self, masks, colors, im_gpu, alpha=0.5, retina_masks=False):
        """Plot masks at once (numpy/CPU variant of masks()).
        Args:
            masks (numpy.ndarray): predicted masks on CPU, shape: [n, h, w]
            colors (List[List[Int]]): colors for predicted masks, [[r, g, b] * n]
            im_gpu (numpy.ndarray): image on CPU, shape: [3, h, w], range: [0, 1]
            alpha (float): mask transparency: 0.0 fully transparent, 1.0 opaque
        """
        import time
        m_1_t = time.time()
        if self.pil:
            # convert to numpy first
            self.im = np.asarray(self.im).copy()
        m_2_t = time.time()
        print("m_2_t: ", m_2_t - m_1_t)
        if len(masks) == 0:
            self.im[:] = im_gpu.transpose(1, 2, 0) * 255
        m_3_t = time.time()
        print("m_3_t: ", m_3_t - m_2_t)
        colors = np.array(colors) / 255.0  # torch.tensor(colors, device=im_gpu.device, dtype=torch.float32) / 255.0
        colors = colors[:, None, None]  # shape(n,1,1,3)
        masks = masks[..., None]  # shape(n,h,w,1)
        masks_color = masks * (colors * alpha)  # shape(n,h,w,3)
        m_4_t = time.time()
        print("m_4_t: ", m_4_t - m_3_t)
        inv_alph_masks = (1 - masks * alpha).cumprod(0)  # shape(n,h,w,1)
        m_5_t = time.time()
        print("m_5_t: ", m_5_t - m_4_t)
        mcs = (masks_color * inv_alph_masks).sum(0) * 2  # mask color summand shape(n,h,w,3)
        m_6_t = time.time()
        print("m_6_t: ", m_6_t - m_5_t)
        # im_gpu = im_gpu.flip(dims=[0])  # flip channel
        im_gpu = im_gpu.transpose(1, 2, 0)  # shape(h,w,3)
        im_gpu = im_gpu * inv_alph_masks[-1] + mcs
        im_mask = (im_gpu * 255).astype(np.uint8)
        self.im[:] = im_mask if retina_masks else scale_image(im_gpu.shape, im_mask, self.im.shape)
        if self.pil:
            # convert im back to PIL and update draw
            self.fromarray(self.im)

    def rectangle(self, xy, fill=None, outline=None, width=1):
        # Add rectangle to image (PIL-only)
        self.draw.rectangle(xy, fill, outline, width)

    def text(self, xy, text, txt_color=(255, 255, 255), anchor='top'):
        # Add text to image (PIL-only)
        if anchor == 'bottom':  # start y from font bottom
            w, h = self.font.getsize(text)  # text width, height
            xy[1] += 1 - h
        self.draw.text(xy, text, fill=txt_color, font=self.font)

    def fromarray(self, im):
        # Update self.im from a numpy array
        self.im = im if isinstance(im, Image.Image) else Image.fromarray(im)
        self.draw = ImageDraw.Draw(self.im)

    def result(self):
        # Return annotated image as array
        return np.asarray(self.im)
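
# Usage sketch (illustrative only; assumes a BGR uint8 image loaded with cv2.imread):
#   im = cv2.imread('bus.jpg')
#   annotator = Annotator(im, line_width=2)                 # cv2 backend (pil=False, ASCII labels)
#   annotator.box_label([50, 60, 200, 220], 'person 0.91',  # one xyxy box + label
#                       color=colors(0, bgr=True))
#   annotated = annotator.result()                          # numpy array, ready for cv2.imwrite()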

def feature_visualization(x, module_type, stage, n=32, save_dir=Path('runs/detect/exp')):
    """
    x:              Features to be visualized
    module_type:    Module type
    stage:          Module stage within model
    n:              Maximum number of feature maps to plot
    save_dir:       Directory to save results
    """
    if 'Detect' not in module_type:
        batch, channels, height, width = x.shape  # batch, channels, height, width
        if height > 1 and width > 1:
            f = save_dir / f"stage{stage}_{module_type.split('.')[-1]}_features.png"  # filename

            blocks = torch.chunk(x[0].cpu(), channels, dim=0)  # select batch index 0, block by channels
            n = min(n, channels)  # number of plots
            fig, ax = plt.subplots(math.ceil(n / 8), 8, tight_layout=True)  # 8 rows x n/8 cols
            ax = ax.ravel()
            plt.subplots_adjust(wspace=0.05, hspace=0.05)
            for i in range(n):
                ax[i].imshow(blocks[i].squeeze())  # cmap='gray'
                ax[i].axis('off')

            LOGGER.info(f'Saving {f}... ({n}/{channels})')
            plt.savefig(f, dpi=300, bbox_inches='tight')
            plt.close()
            np.save(str(f.with_suffix('.npy')), x[0].cpu().numpy())  # npy save
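
# Usage sketch (illustrative only; tensor shape and module name are assumptions): visualize up
# to 32 channels of an intermediate activation, writing a PNG and a .npy dump to save_dir:
#   x = torch.rand(1, 64, 80, 80)  # dummy (batch, channels, height, width) feature map
#   feature_visualization(x, module_type='models.common.C3', stage=2, save_dir=Path('.'))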

def hist2d(x, y, n=100):
    # 2d histogram used in labels.png and evolve.png
    xedges, yedges = np.linspace(x.min(), x.max(), n), np.linspace(y.min(), y.max(), n)
    hist, xedges, yedges = np.histogram2d(x, y, (xedges, yedges))
    xidx = np.clip(np.digitize(x, xedges) - 1, 0, hist.shape[0] - 1)
    yidx = np.clip(np.digitize(y, yedges) - 1, 0, hist.shape[1] - 1)
    return np.log(hist[xidx, yidx])

def butter_lowpass_filtfilt(data, cutoff=1500, fs=50000, order=5):
    from scipy.signal import butter, filtfilt

    # https://stackoverflow.com/questions/28536191/how-to-filter-smooth-with-scipy-numpy
    def butter_lowpass(cutoff, fs, order):
        nyq = 0.5 * fs
        normal_cutoff = cutoff / nyq
        return butter(order, normal_cutoff, btype='low', analog=False)

    b, a = butter_lowpass(cutoff, fs, order=order)
    return filtfilt(b, a, data)  # forward-backward filter
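
# Usage sketch (illustrative only): smooth a noisy 1-D series with the default
# 1500 Hz cutoff / 50 kHz sampling settings:
#   noisy = np.sin(np.linspace(0, 4 * np.pi, 500)) + 0.1 * np.random.randn(500)
#   smooth = butter_lowpass_filtfilt(noisy)  # same length, zero-phase low-pass filtered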

def output_to_target(output, max_det=300):
    # Convert model output to target format [batch_id, class_id, x, y, w, h, conf] for plotting
    targets = []
    for i, o in enumerate(output):
        box, conf, cls = o[:max_det, :6].cpu().split((4, 1, 1), 1)
        j = torch.full((conf.shape[0], 1), i)
        targets.append(torch.cat((j, cls, xyxy2xywh(box), conf), 1))
    return torch.cat(targets, 0).numpy()
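
# Note (added for clarity): `output` is the per-image list produced by NMS, each element an
# (n, 6+) tensor of [x1, y1, x2, y2, conf, cls]; the return value is an (N, 7) numpy array of
# [batch_id, class_id, x, y, w, h, conf] rows in the format consumed by plot_images() below.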

@threaded
def plot_images(images, targets, paths=None, fname='images.jpg', names=None):
    # Plot image grid with labels
    if isinstance(images, torch.Tensor):
        images = images.cpu().float().numpy()
    if isinstance(targets, torch.Tensor):
        targets = targets.cpu().numpy()

    max_size = 1920  # max image size
    max_subplots = 16  # max image subplots, i.e. 4x4
    bs, _, h, w = images.shape  # batch size, _, height, width
    bs = min(bs, max_subplots)  # limit plot images
    ns = np.ceil(bs ** 0.5)  # number of subplots (square)
    if np.max(images[0]) <= 1:
        images *= 255  # de-normalise (optional)

    # Build Image
    mosaic = np.full((int(ns * h), int(ns * w), 3), 255, dtype=np.uint8)  # init
    for i, im in enumerate(images):
        if i == max_subplots:  # if last batch has fewer images than we expect
            break
        x, y = int(w * (i // ns)), int(h * (i % ns))  # block origin
        im = im.transpose(1, 2, 0)
        mosaic[y:y + h, x:x + w, :] = im

    # Resize (optional)
    scale = max_size / ns / max(h, w)
    if scale < 1:
        h = math.ceil(scale * h)
        w = math.ceil(scale * w)
        mosaic = cv2.resize(mosaic, tuple(int(x * ns) for x in (w, h)))

    # Annotate
    fs = int((h + w) * ns * 0.01)  # font size
    annotator = Annotator(mosaic, line_width=round(fs / 10), font_size=fs, pil=True, example=names)
    for i in range(i + 1):
        x, y = int(w * (i // ns)), int(h * (i % ns))  # block origin
        annotator.rectangle([x, y, x + w, y + h], None, (255, 255, 255), width=2)  # borders
        if paths:
            annotator.text((x + 5, y + 5), text=Path(paths[i]).name[:40], txt_color=(220, 220, 220))  # filenames
        if len(targets) > 0:
            ti = targets[targets[:, 0] == i]  # image targets
            boxes = xywh2xyxy(ti[:, 2:6]).T
            classes = ti[:, 1].astype('int')
            labels = ti.shape[1] == 6  # labels if no conf column
            conf = None if labels else ti[:, 6]  # check for confidence presence (label vs pred)

            if boxes.shape[1]:
                if boxes.max() <= 1.01:  # if normalized with tolerance 0.01
                    boxes[[0, 2]] *= w  # scale to pixels
                    boxes[[1, 3]] *= h
                elif scale < 1:  # absolute coords need scale if image scales
                    boxes *= scale
            boxes[[0, 2]] += x
            boxes[[1, 3]] += y
            for j, box in enumerate(boxes.T.tolist()):
                cls = classes[j]
                color = colors(cls)
                cls = names[cls] if names else cls
                if labels or conf[j] > 0.25:  # 0.25 conf thresh
                    label = f'{cls}' if labels else f'{cls} {conf[j]:.1f}'
                    annotator.box_label(box, label, color=color)
    annotator.im.save(fname)  # save
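
# Usage sketch (illustrative only; `imgs`, `pred`, `paths` and `names` are assumed to come
# from a dataloader batch and its NMS output):
#   targets = output_to_target(pred)                      # (N, 7) plotting targets
#   plot_images(imgs, targets, paths, 'pred.jpg', names)  # imgs: (bs, 3, h, w) tensor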

def plot_lr_scheduler(optimizer, scheduler, epochs=300, save_dir=''):
    # Plot LR simulating training for full epochs
    optimizer, scheduler = copy(optimizer), copy(scheduler)  # do not modify originals
    y = []
    for _ in range(epochs):
        scheduler.step()
        y.append(optimizer.param_groups[0]['lr'])
    plt.plot(y, '.-', label='LR')
    plt.xlabel('epoch')
    plt.ylabel('LR')
    plt.grid()
    plt.xlim(0, epochs)
    plt.ylim(0)
    plt.savefig(Path(save_dir) / 'LR.png', dpi=200)
    plt.close()
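
# Usage sketch (illustrative only; the linear LR lambda is an assumption, not taken from
# this module):
#   opt = torch.optim.SGD([torch.zeros(1, requires_grad=True)], lr=0.01)
#   sch = torch.optim.lr_scheduler.LambdaLR(opt, lambda e: (1 - e / 300) * 0.99 + 0.01)
#   plot_lr_scheduler(opt, sch, epochs=300, save_dir='.')  # writes LR.png to save_dir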

def plot_val_txt():  # from utils.plots import *; plot_val()
    # Plot val.txt histograms
    x = np.loadtxt('val.txt', dtype=np.float32)
    box = xyxy2xywh(x[:, :4])
    cx, cy = box[:, 0], box[:, 1]

    fig, ax = plt.subplots(1, 1, figsize=(6, 6), tight_layout=True)
    ax.hist2d(cx, cy, bins=600, cmax=10, cmin=0)
    ax.set_aspect('equal')
    plt.savefig('hist2d.png', dpi=300)

    fig, ax = plt.subplots(1, 2, figsize=(12, 6), tight_layout=True)
    ax[0].hist(cx, bins=600)
    ax[1].hist(cy, bins=600)
    plt.savefig('hist1d.png', dpi=200)

def plot_targets_txt():  # from utils.plots import *; plot_targets_txt()
    # Plot targets.txt histograms
    x = np.loadtxt('targets.txt', dtype=np.float32).T
    s = ['x targets', 'y targets', 'width targets', 'height targets']
    fig, ax = plt.subplots(2, 2, figsize=(8, 8), tight_layout=True)
    ax = ax.ravel()
    for i in range(4):
        ax[i].hist(x[i], bins=100, label=f'{x[i].mean():.3g} +/- {x[i].std():.3g}')
        ax[i].legend()
        ax[i].set_title(s[i])
    plt.savefig('targets.jpg', dpi=200)

def plot_val_study(file='', dir='', x=None):  # from utils.plots import *; plot_val_study()
    # Plot file=study.txt generated by val.py (or plot all study*.txt in dir)
    save_dir = Path(file).parent if file else Path(dir)
    plot2 = False  # plot additional results
    if plot2:
        ax = plt.subplots(2, 4, figsize=(10, 6), tight_layout=True)[1].ravel()

    fig2, ax2 = plt.subplots(1, 1, figsize=(8, 4), tight_layout=True)
    # for f in [save_dir / f'study_coco_{x}.txt' for x in ['yolov5n6', 'yolov5s6', 'yolov5m6', 'yolov5l6', 'yolov5x6']]:
    for f in sorted(save_dir.glob('study*.txt')):
        y = np.loadtxt(f, dtype=np.float32, usecols=[0, 1, 2, 3, 7, 8, 9], ndmin=2).T
        x = np.arange(y.shape[1]) if x is None else np.array(x)
        if plot2:
            s = ['P', 'R', 'mAP@.5', 'mAP@.5:.95', 't_preprocess (ms/img)', 't_inference (ms/img)', 't_NMS (ms/img)']
            for i in range(7):
                ax[i].plot(x, y[i], '.-', linewidth=2, markersize=8)
                ax[i].set_title(s[i])

        j = y[3].argmax() + 1
        ax2.plot(y[5, 1:j],
                 y[3, 1:j] * 1E2,
                 '.-',
                 linewidth=2,
                 markersize=8,
                 label=f.stem.replace('study_coco_', '').replace('yolo', 'YOLO'))

    ax2.plot(1E3 / np.array([209, 140, 97, 58, 35, 18]), [34.6, 40.5, 43.0, 47.5, 49.7, 51.5],
             'k.-',
             linewidth=2,
             markersize=8,
             alpha=.25,
             label='EfficientDet')

    ax2.grid(alpha=0.2)
    ax2.set_yticks(np.arange(20, 60, 5))
    ax2.set_xlim(0, 57)
    ax2.set_ylim(25, 55)
    ax2.set_xlabel('GPU Speed (ms/img)')
    ax2.set_ylabel('COCO AP val')
    ax2.legend(loc='lower right')
    f = save_dir / 'study.png'
    print(f'Saving {f}...')
    plt.savefig(f, dpi=300)

@TryExcept()  # known issue https://github.com/ultralytics/yolov5/issues/5395
def plot_labels(labels, names=(), save_dir=Path('')):
    # plot dataset labels
    LOGGER.info(f"Plotting labels to {save_dir / 'labels.jpg'}... ")
    c, b = labels[:, 0], labels[:, 1:].transpose()  # classes, boxes
    nc = int(c.max() + 1)  # number of classes
    x = pd.DataFrame(b.transpose(), columns=['x', 'y', 'width', 'height'])

    # seaborn correlogram
    sn.pairplot(x, corner=True, diag_kind='auto', kind='hist', diag_kws=dict(bins=50), plot_kws=dict(pmax=0.9))
    plt.savefig(save_dir / 'labels_correlogram.jpg', dpi=200)
    plt.close()

    # matplotlib labels
    matplotlib.use('svg')  # faster
    ax = plt.subplots(2, 2, figsize=(8, 8), tight_layout=True)[1].ravel()
    y = ax[0].hist(c, bins=np.linspace(0, nc, nc + 1) - 0.5, rwidth=0.8)
    with contextlib.suppress(Exception):  # color histogram bars by class
        [y[2].patches[i].set_color([x / 255 for x in colors(i)]) for i in range(nc)]  # known issue #3195
    ax[0].set_ylabel('instances')
    if 0 < len(names) < 30:
        ax[0].set_xticks(range(len(names)))
        ax[0].set_xticklabels(list(names.values()), rotation=90, fontsize=10)
    else:
        ax[0].set_xlabel('classes')
    sn.histplot(x, x='x', y='y', ax=ax[2], bins=50, pmax=0.9)
    sn.histplot(x, x='width', y='height', ax=ax[3], bins=50, pmax=0.9)

    # rectangles
    labels[:, 1:3] = 0.5  # center
    labels[:, 1:] = xywh2xyxy(labels[:, 1:]) * 2000
    img = Image.fromarray(np.ones((2000, 2000, 3), dtype=np.uint8) * 255)
    for cls, *box in labels[:1000]:
        ImageDraw.Draw(img).rectangle(box, width=1, outline=colors(cls))  # plot
    ax[1].imshow(img)
    ax[1].axis('off')

    for a in [0, 1, 2, 3]:
        for s in ['top', 'right', 'left', 'bottom']:
            ax[a].spines[s].set_visible(False)

    plt.savefig(save_dir / 'labels.jpg', dpi=200)
    matplotlib.use('Agg')
    plt.close()

def imshow_cls(im, labels=None, pred=None, names=None, nmax=25, verbose=False, f=Path('images.jpg')):
    # Show classification image grid with labels (optional) and predictions (optional)
    from utils.augmentations import denormalize

    names = names or [f'class{i}' for i in range(1000)]
    blocks = torch.chunk(denormalize(im.clone()).cpu().float(), len(im),
                         dim=0)  # select batch index 0, block by channels
    n = min(len(blocks), nmax)  # number of plots
    m = min(8, round(n ** 0.5))  # 8 x 8 default
    fig, ax = plt.subplots(math.ceil(n / m), m)  # 8 rows x n/8 cols
    ax = ax.ravel() if m > 1 else [ax]
    # plt.subplots_adjust(wspace=0.05, hspace=0.05)
    for i in range(n):
        ax[i].imshow(blocks[i].squeeze().permute((1, 2, 0)).numpy().clip(0.0, 1.0))
        ax[i].axis('off')
        if labels is not None:
            s = names[labels[i]] + (f'—{names[pred[i]]}' if pred is not None else '')
            ax[i].set_title(s, fontsize=8, verticalalignment='top')
    plt.savefig(f, dpi=300, bbox_inches='tight')
    plt.close()
    if verbose:
        LOGGER.info(f"Saving {f}")
        if labels is not None:
            LOGGER.info('True: ' + ' '.join(f'{names[i]:3s}' for i in labels[:nmax]))
        if pred is not None:
            LOGGER.info('Predicted:' + ' '.join(f'{names[i]:3s}' for i in pred[:nmax]))
    return f

def plot_evolve(evolve_csv='path/to/evolve.csv'):  # from utils.plots import *; plot_evolve()
    # Plot evolve.csv hyp evolution results
    evolve_csv = Path(evolve_csv)
    data = pd.read_csv(evolve_csv)
    keys = [x.strip() for x in data.columns]
    x = data.values
    f = fitness(x)
    j = np.argmax(f)  # max fitness index
    plt.figure(figsize=(10, 12), tight_layout=True)
    matplotlib.rc('font', **{'size': 8})
    print(f'Best results from row {j} of {evolve_csv}:')
    for i, k in enumerate(keys[7:]):
        v = x[:, 7 + i]
        mu = v[j]  # best single result
        plt.subplot(6, 5, i + 1)
        plt.scatter(v, f, c=hist2d(v, f, 20), cmap='viridis', alpha=.8, edgecolors='none')
        plt.plot(mu, f.max(), 'k+', markersize=15)
        plt.title(f'{k} = {mu:.3g}', fontdict={'size': 9})  # limit to 40 characters
        if i % 5 != 0:
            plt.yticks([])
        print(f'{k:>15}: {mu:.3g}')
    f = evolve_csv.with_suffix('.png')  # filename
    plt.savefig(f, dpi=200)
    plt.close()
    print(f'Saved {f}')

def plot_results(file='path/to/results.csv', dir=''):
    # Plot training results.csv. Usage: from utils.plots import *; plot_results('path/to/results.csv')
    save_dir = Path(file).parent if file else Path(dir)
    fig, ax = plt.subplots(2, 5, figsize=(12, 6), tight_layout=True)
    ax = ax.ravel()
    files = list(save_dir.glob('results*.csv'))
    assert len(files), f'No results.csv files found in {save_dir.resolve()}, nothing to plot.'
    for f in files:
        try:
            data = pd.read_csv(f)
            s = [x.strip() for x in data.columns]
            x = data.values[:, 0]
            for i, j in enumerate([1, 2, 3, 4, 5, 8, 9, 10, 6, 7]):
                y = data.values[:, j].astype('float')
                # y[y == 0] = np.nan  # don't show zero values
                ax[i].plot(x, y, marker='.', label=f.stem, linewidth=2, markersize=8)
                ax[i].set_title(s[j], fontsize=12)
                # if j in [8, 9, 10]:  # share train and val loss y axes
                #     ax[i].get_shared_y_axes().join(ax[i], ax[i - 5])
        except Exception as e:
            LOGGER.info(f'Warning: Plotting error for {f}: {e}')
    ax[1].legend()
    fig.savefig(save_dir / 'results.png', dpi=200)
    plt.close()

def profile_idetection(start=0, stop=0, labels=(), save_dir=''):
    # Plot iDetection '*.txt' per-image logs. from utils.plots import *; profile_idetection()
    ax = plt.subplots(2, 4, figsize=(12, 6), tight_layout=True)[1].ravel()
    s = ['Images', 'Free Storage (GB)', 'RAM Usage (GB)', 'Battery', 'dt_raw (ms)', 'dt_smooth (ms)', 'real-world FPS']
    files = list(Path(save_dir).glob('frames*.txt'))
    for fi, f in enumerate(files):
        try:
            results = np.loadtxt(f, ndmin=2).T[:, 90:-30]  # clip first and last rows
            n = results.shape[1]  # number of rows
            x = np.arange(start, min(stop, n) if stop else n)
            results = results[:, x]
            t = (results[0] - results[0].min())  # set t0=0s
            results[0] = x
            for i, a in enumerate(ax):
                if i < len(results):
                    label = labels[fi] if len(labels) else f.stem.replace('frames_', '')
                    a.plot(t, results[i], marker='.', label=label, linewidth=1, markersize=5)
                    a.set_title(s[i])
                    a.set_xlabel('time (s)')
                    # if fi == len(files) - 1:
                    #     a.set_ylim(bottom=0)
                    for side in ['top', 'right']:
                        a.spines[side].set_visible(False)
                else:
                    a.remove()
        except Exception as e:
            print(f'Warning: Plotting error for {f}; {e}')
    ax[1].legend()
    plt.savefig(Path(save_dir) / 'idetection_profile.png', dpi=200)

def save_one_box(xyxy, im, file=Path('im.jpg'), gain=1.02, pad=10, square=False, BGR=False, save=True):
    # Save image crop as {file} with crop size multiple {gain} and {pad} pixels. Save and/or return crop
    xyxy = torch.tensor(xyxy).view(-1, 4)
    b = xyxy2xywh(xyxy)  # boxes
    if square:
        b[:, 2:] = b[:, 2:].max(1)[0].unsqueeze(1)  # attempt rectangle to square
    b[:, 2:] = b[:, 2:] * gain + pad  # box wh * gain + pad
    xyxy = xywh2xyxy(b).long()
    clip_boxes(xyxy, im.shape)
    crop = im[int(xyxy[0, 1]):int(xyxy[0, 3]), int(xyxy[0, 0]):int(xyxy[0, 2]), ::(1 if BGR else -1)]
    if save:
        file.parent.mkdir(parents=True, exist_ok=True)  # make directory
        f = str(increment_path(file).with_suffix('.jpg'))
        # cv2.imwrite(f, crop)  # save BGR, https://github.com/ultralytics/yolov5/issues/7007 chroma subsampling issue
        Image.fromarray(crop[..., ::-1]).save(f, quality=95, subsampling=0)  # save RGB
    return crop
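
# Usage sketch (illustrative only; assumes `det` is an (n, 6) tensor of NMS detections and
# `im` the original BGR image they refer to):
#   for *xyxy, conf, cls in det:
#       save_one_box(xyxy, im, file=Path('crops') / f'{int(cls)}.jpg', BGR=True)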