dataloaders.py

# YOLOv5 🚀 by Ultralytics, AGPL-3.0 license
"""
Dataloaders and dataset utils
"""

import contextlib
import glob
import hashlib
import json
import math
import os
import random
import shutil
import time
from itertools import repeat
from multiprocessing.pool import Pool, ThreadPool
from pathlib import Path
from threading import Thread
from urllib.parse import urlparse

import numpy as np
import psutil
import torch
import torch.nn.functional as F
import torchvision
import yaml
from PIL import ExifTags, Image, ImageOps
from torch.utils.data import DataLoader, Dataset, dataloader, distributed
from tqdm import tqdm

from utils.augmentations import (Albumentations, augment_hsv, classify_albumentations, classify_transforms, copy_paste,
                                 letterbox, mixup, random_perspective)
from utils.general import (DATASETS_DIR, LOGGER, NUM_THREADS, TQDM_BAR_FORMAT, check_dataset, check_requirements,
                           check_yaml, clean_str, cv2, is_colab, is_kaggle, segments2boxes, unzip_file, xyn2xy,
                           xywh2xyxy, xywhn2xyxy, xyxy2xywhn)
from utils.torch_utils import torch_distributed_zero_first

# Parameters
HELP_URL = 'See https://docs.ultralytics.com/yolov5/tutorials/train_custom_data'
IMG_FORMATS = 'bmp', 'dng', 'jpeg', 'jpg', 'mpo', 'png', 'tif', 'tiff', 'webp', 'pfm'  # include image suffixes
VID_FORMATS = 'asf', 'avi', 'gif', 'm4v', 'mkv', 'mov', 'mp4', 'mpeg', 'mpg', 'ts', 'wmv'  # include video suffixes
LOCAL_RANK = int(os.getenv('LOCAL_RANK', -1))  # https://pytorch.org/docs/stable/elastic/run.html
RANK = int(os.getenv('RANK', -1))
PIN_MEMORY = str(os.getenv('PIN_MEMORY', True)).lower() == 'true'  # global pin_memory for dataloaders

# Get orientation exif tag
for orientation in ExifTags.TAGS.keys():
    if ExifTags.TAGS[orientation] == 'Orientation':
        break


def get_hash(paths):
    # Returns a single hash value of a list of paths (files or dirs)
    size = sum(os.path.getsize(p) for p in paths if os.path.exists(p))  # sizes
    h = hashlib.sha256(str(size).encode())  # hash sizes
    h.update(''.join(paths).encode())  # hash paths
    return h.hexdigest()  # return hash


def exif_size(img):
    # Returns exif-corrected PIL size
    s = img.size  # (width, height)
    with contextlib.suppress(Exception):
        rotation = dict(img._getexif().items())[orientation]
        if rotation in [6, 8]:  # rotation 270 or 90
            s = (s[1], s[0])
    return s


def exif_transpose(image):
    """
    Transpose a PIL image accordingly if it has an EXIF Orientation tag.
    Inplace version of https://github.com/python-pillow/Pillow/blob/master/src/PIL/ImageOps.py exif_transpose()

    :param image: The image to transpose.
    :return: An image.
    """
    exif = image.getexif()
    orientation = exif.get(0x0112, 1)  # default 1
    if orientation > 1:
        method = {
            2: Image.FLIP_LEFT_RIGHT,
            3: Image.ROTATE_180,
            4: Image.FLIP_TOP_BOTTOM,
            5: Image.TRANSPOSE,
            6: Image.ROTATE_270,
            7: Image.TRANSVERSE,
            8: Image.ROTATE_90}.get(orientation)
        if method is not None:
            image = image.transpose(method)
            del exif[0x0112]
            image.info['exif'] = exif.tobytes()
    return image


def seed_worker(worker_id):
    # Set dataloader worker seed https://pytorch.org/docs/stable/notes/randomness.html#dataloader
    worker_seed = torch.initial_seed() % 2 ** 32
    np.random.seed(worker_seed)
    random.seed(worker_seed)


def create_dataloader(path,
                      imgsz,
                      batch_size,
                      stride,
                      single_cls=False,
                      hyp=None,
                      augment=False,
                      cache=False,
                      pad=0.0,
                      rect=False,
                      rank=-1,
                      workers=8,
                      image_weights=False,
                      quad=False,
                      prefix='',
                      shuffle=False,
                      seed=0):
    if rect and shuffle:
        LOGGER.warning('WARNING ⚠️ --rect is incompatible with DataLoader shuffle, setting shuffle=False')
        shuffle = False
    with torch_distributed_zero_first(rank):  # init dataset *.cache only once if DDP
        dataset = LoadImagesAndLabels(
            path,
            imgsz,
            batch_size,
            augment=augment,  # augmentation
            hyp=hyp,  # hyperparameters
            rect=rect,  # rectangular batches
            cache_images=cache,
            single_cls=single_cls,
            stride=int(stride),
            pad=pad,
            image_weights=image_weights,
            prefix=prefix)

    batch_size = min(batch_size, len(dataset))
    nd = torch.cuda.device_count()  # number of CUDA devices
    nw = min([os.cpu_count() // max(nd, 1), batch_size if batch_size > 1 else 0, workers])  # number of workers
    sampler = None if rank == -1 else distributed.DistributedSampler(dataset, shuffle=shuffle)
    loader = DataLoader if image_weights else InfiniteDataLoader  # only DataLoader allows for attribute updates
    generator = torch.Generator()
    generator.manual_seed(6148914691236517205 + seed + RANK)
    return loader(dataset,
                  batch_size=batch_size,
                  shuffle=shuffle and sampler is None,
                  num_workers=nw,
                  sampler=sampler,
                  pin_memory=PIN_MEMORY,
                  collate_fn=LoadImagesAndLabels.collate_fn4 if quad else LoadImagesAndLabels.collate_fn,
                  worker_init_fn=seed_worker,
                  generator=generator), dataset
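
# Usage sketch: the dataset path and hyp dict below are illustrative stand-ins, not values defined in this
# file; the hyp keys shown are the ones LoadImagesAndLabels.__getitem__ and load_mosaic actually read:
#   hyp = {'mosaic': 1.0, 'mixup': 0.0, 'degrees': 0.0, 'translate': 0.1, 'scale': 0.5, 'shear': 0.0,
#          'perspective': 0.0, 'hsv_h': 0.015, 'hsv_s': 0.7, 'hsv_v': 0.4, 'flipud': 0.0, 'fliplr': 0.5,
#          'copy_paste': 0.0}
#   loader, dataset = create_dataloader('data/coco128/images/train2017', imgsz=640, batch_size=16,
#                                       stride=32, hyp=hyp, augment=True, shuffle=True)
#   for imgs, targets, paths, shapes in loader:  # imgs: (B,3,H,W) uint8, targets: (N,6) [image_idx, cls, xywhn]
#       pass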


class InfiniteDataLoader(dataloader.DataLoader):
    """ Dataloader that reuses workers

    Uses same syntax as vanilla DataLoader
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        object.__setattr__(self, 'batch_sampler', _RepeatSampler(self.batch_sampler))
        self.iterator = super().__iter__()

    def __len__(self):
        return len(self.batch_sampler.sampler)

    def __iter__(self):
        for _ in range(len(self)):
            yield next(self.iterator)


class _RepeatSampler:
    """ Sampler that repeats forever

    Args:
        sampler (Sampler)
    """

    def __init__(self, sampler):
        self.sampler = sampler

    def __iter__(self):
        while True:
            yield from iter(self.sampler)
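
# Usage sketch: InfiniteDataLoader behaves like a standard DataLoader, but because _RepeatSampler never
# raises StopIteration, the worker processes survive across epochs instead of being respawned each epoch:
#   loader = InfiniteDataLoader(dataset, batch_size=16, num_workers=4)
#   for epoch in range(epochs):
#       for batch in loader:  # len(loader) batches per epoch, served by the same long-lived workers
#           pass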


class LoadScreenshots:
    # YOLOv5 screenshot dataloader, i.e. `python detect.py --source "screen 0 100 100 512 256"`
    def __init__(self, source, img_size=640, stride=32, auto=True, transforms=None):
        # source = [screen_number left top width height] (pixels)
        check_requirements('mss')
        import mss

        source, *params = source.split()
        self.screen, left, top, width, height = 0, None, None, None, None  # default to full screen 0
        if len(params) == 1:
            self.screen = int(params[0])
        elif len(params) == 4:
            left, top, width, height = (int(x) for x in params)
        elif len(params) == 5:
            self.screen, left, top, width, height = (int(x) for x in params)
        self.img_size = img_size
        self.stride = stride
        self.transforms = transforms
        self.auto = auto
        self.mode = 'stream'
        self.frame = 0
        self.sct = mss.mss()

        # Parse monitor shape
        monitor = self.sct.monitors[self.screen]
        self.top = monitor['top'] if top is None else (monitor['top'] + top)
        self.left = monitor['left'] if left is None else (monitor['left'] + left)
        self.width = width or monitor['width']
        self.height = height or monitor['height']
        self.monitor = {'left': self.left, 'top': self.top, 'width': self.width, 'height': self.height}

    def __iter__(self):
        return self

    def __next__(self):
        # mss screen capture: get raw pixels from the screen as np array
        im0 = np.array(self.sct.grab(self.monitor))[:, :, :3]  # [:, :, :3] BGRA to BGR
        s = f'screen {self.screen} (LTWH): {self.left},{self.top},{self.width},{self.height}: '

        if self.transforms:
            im = self.transforms(im0)  # transforms
        else:
            im = letterbox(im0, self.img_size, stride=self.stride, auto=self.auto)[0]  # padded resize
            im = im.transpose((2, 0, 1))[::-1]  # HWC to CHW, BGR to RGB
            im = np.ascontiguousarray(im)  # contiguous
        self.frame += 1
        return str(self.screen), im, im0, None, s  # screen, img, original img, im0s, s
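
# Usage sketch (requires the `mss` package; captures a 512x256 region at offset (100, 100) on screen 0):
#   for screen, im, im0, _, s in LoadScreenshots('screen 0 100 100 512 256', img_size=640):
#       pass  # im: letterboxed CHW RGB array ready for inference, im0: raw BGR capture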


class LoadImagesWhole:
    # YOLOv5 image dataloader that keeps large images near native resolution, i.e. `python detect.py --source image.jpg`
    # Note: video paths are collected and _new_video/_cv2_rotate are kept for parity with LoadImages,
    # but __next__ only decodes still images
    def __init__(self, path, img_size=(640, 640), stride=32, auto=True, transforms=None, vid_stride=1):
        # img_size must be an (h, w) tuple: __next__ indexes it to decide whether to upscale small images
        if isinstance(path, str) and Path(path).suffix == '.txt':  # *.txt file with img/vid/dir on each line
            path = Path(path).read_text().rsplit()
        files = []
        for p in sorted(path) if isinstance(path, (list, tuple)) else [path]:
            p = str(Path(p).resolve())
            if '*' in p:
                files.extend(sorted(glob.glob(p, recursive=True)))  # glob
            elif os.path.isdir(p):
                files.extend(sorted(glob.glob(os.path.join(p, '*.*'))))  # dir
            elif os.path.isfile(p):
                files.append(p)  # files
            else:
                raise FileNotFoundError(f'{p} does not exist')

        images = [x for x in files if x.split('.')[-1].lower() in IMG_FORMATS]
        videos = [x for x in files if x.split('.')[-1].lower() in VID_FORMATS]
        ni, nv = len(images), len(videos)

        self.img_size = img_size
        self.stride = stride
        self.files = images + videos
        self.nf = ni + nv  # number of files
        self.video_flag = [False] * ni + [True] * nv
        self.mode = 'image'
        self.auto = auto
        self.transforms = transforms  # optional
        self.vid_stride = vid_stride  # video frame-rate stride
        self.cap = None
        assert self.nf > 0, f'No images or videos found in {p}. ' \
                            f'Supported formats are:\nimages: {IMG_FORMATS}\nvideos: {VID_FORMATS}'

    def __iter__(self):
        self.count = 0
        return self

    def __next__(self):
        if self.count == self.nf:
            raise StopIteration
        path = self.files[self.count]

        # Read image
        self.count += 1
        im0 = cv2.imread(path)  # BGR
        assert im0 is not None, f'Image Not Found {path}'
        s = f'image {self.count}/{self.nf} {path}: '

        if self.transforms:
            im = self.transforms(im0)  # transforms
        else:
            # im = letterbox(im0, self.img_size, stride=self.stride, auto=self.auto)[0]  # padded resize
            if im0.shape[0] < self.img_size[0] or im0.shape[1] < self.img_size[1]:
                im = letterbox(im0, self.img_size, stride=self.stride, auto=self.auto)[0]  # padded resize
            else:
                # keep native resolution, padded up to the nearest stride multiple
                im_newsize0 = math.ceil(im0.shape[0] / self.stride) * self.stride
                im_newsize1 = math.ceil(im0.shape[1] / self.stride) * self.stride
                im = letterbox(im0, (im_newsize0, im_newsize1), stride=self.stride, auto=self.auto)[0]  # padded resize
            im = im.transpose((2, 0, 1))[::-1]  # HWC to CHW, BGR to RGB
            im = np.ascontiguousarray(im)  # contiguous

        return path, im, im0, self.cap, s

    def _new_video(self, path):
        # Create a new video capture object
        self.frame = 0
        self.cap = cv2.VideoCapture(path)
        self.frames = int(self.cap.get(cv2.CAP_PROP_FRAME_COUNT) / self.vid_stride)
        self.orientation = int(self.cap.get(cv2.CAP_PROP_ORIENTATION_META))  # rotation degrees
        # self.cap.set(cv2.CAP_PROP_ORIENTATION_AUTO, 0)  # disable https://github.com/ultralytics/yolov5/issues/8493

    def _cv2_rotate(self, im):
        # Rotate a cv2 video manually
        if self.orientation == 0:
            return cv2.rotate(im, cv2.ROTATE_90_CLOCKWISE)
        elif self.orientation == 180:
            return cv2.rotate(im, cv2.ROTATE_90_COUNTERCLOCKWISE)
        elif self.orientation == 90:
            return cv2.rotate(im, cv2.ROTATE_180)
        return im

    def __len__(self):
        return self.nf  # number of files


class LoadImages_batch(Dataset):
    # YOLOv5 image/video dataloader wrapped as a map-style Dataset so batches can be built with a torch DataLoader
    def __init__(self, path, img_size=(640, 640), stride=32, auto=True, transforms=None, vid_stride=1):
        # img_size should be a tuple such as (640, 640): cv2.resize in __getitem__ requires a tuple dsize
        if isinstance(path, str) and Path(path).suffix == '.txt':  # *.txt file with img/vid/dir on each line
            path = Path(path).read_text().rsplit()
        files = []
        for p in sorted(path) if isinstance(path, (list, tuple)) else [path]:
            p = str(Path(p).resolve())
            if '*' in p:
                files.extend(sorted(glob.glob(p, recursive=True)))  # glob
            elif os.path.isdir(p):
                files.extend(sorted(glob.glob(os.path.join(p, '*.*'))))  # dir
            elif os.path.isfile(p):
                files.append(p)  # files
            else:
                raise FileNotFoundError(f'{p} does not exist')

        images = [x for x in files if x.split('.')[-1].lower() in IMG_FORMATS]
        videos = [x for x in files if x.split('.')[-1].lower() in VID_FORMATS]
        ni, nv = len(images), len(videos)

        self.img_size = img_size
        self.stride = stride
        self.files = images + videos
        self.nf = ni + nv  # number of files
        self.video_flag = [False] * ni + [True] * nv
        self.mode = 'image'
        self.auto = auto
        self.count = 0
        self.transforms = transforms  # optional
        self.vid_stride = vid_stride  # video frame-rate stride
        if any(videos):
            self._new_video(videos[0])  # new video
        else:
            self.cap = None
        assert self.nf > 0, f'No images or videos found in {p}. ' \
                            f'Supported formats are:\nimages: {IMG_FORMATS}\nvideos: {VID_FORMATS}'

    def __getitem__(self, index):
        path = self.files[index]

        if self.video_flag[self.count]:
            # Read video. The video branch keeps sequential decoder state (self.count, self.cap),
            # so it only behaves correctly with a sequential (non-shuffled) sampler
            self.mode = 'video'
            for _ in range(self.vid_stride):
                self.cap.grab()
            ret_val, im0 = self.cap.retrieve()
            while not ret_val:
                self.count += 1
                self.cap.release()
                if self.count == self.nf:  # last video
                    raise StopIteration
                path = self.files[self.count]
                self._new_video(path)
                ret_val, im0 = self.cap.read()

            self.frame += 1
            # im0 = self._cv2_rotate(im0)  # for use if cv2 autorotation is False
            s = f'video {self.count + 1}/{self.nf} ({self.frame}/{self.frames}) {path}: '
        else:
            # Read image
            self.count += 1
            im0 = cv2.imread(path)  # BGR
            assert im0 is not None, f'Image Not Found {path}'
            s = f'image {path}: '

        if self.transforms:
            im = self.transforms(im0)  # transforms
        else:
            im = letterbox(im0, self.img_size, stride=self.stride, auto=self.auto)[0]  # padded resize
            im = cv2.resize(im, self.img_size)  # TODO: temporary workaround, a real deployment should not need this resize
            im = im.transpose((2, 0, 1))[::-1]  # HWC to CHW, BGR to RGB
            im = np.ascontiguousarray(im)  # contiguous

        return path, torch.from_numpy(im), torch.from_numpy(im0), s

    def _new_video(self, path):
        # Create a new video capture object
        self.frame = 0
        self.cap = cv2.VideoCapture(path)
        self.frames = int(self.cap.get(cv2.CAP_PROP_FRAME_COUNT) / self.vid_stride)
        self.orientation = int(self.cap.get(cv2.CAP_PROP_ORIENTATION_META))  # rotation degrees
        # self.cap.set(cv2.CAP_PROP_ORIENTATION_AUTO, 0)  # disable https://github.com/ultralytics/yolov5/issues/8493

    def _cv2_rotate(self, im):
        # Rotate a cv2 video manually
        if self.orientation == 0:
            return cv2.rotate(im, cv2.ROTATE_90_CLOCKWISE)
        elif self.orientation == 180:
            return cv2.rotate(im, cv2.ROTATE_90_COUNTERCLOCKWISE)
        elif self.orientation == 90:
            return cv2.rotate(im, cv2.ROTATE_180)
        return im

    def __len__(self):
        return self.nf  # number of files

    @staticmethod
    def collate_fn(batch):
        # Stack the preprocessed images; keep originals as a plain list since their shapes may differ
        path, img, img0, s = zip(*batch)  # transposed; tuples of b items each
        img0s = list(img0)
        return path, torch.stack(img, 0), img0s, s
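
# Usage sketch ('data/images' is an illustrative path): because every sample is letterboxed and then resized
# to a fixed img_size above, the custom collate_fn can stack the batch into one tensor:
#   ds = LoadImages_batch('data/images', img_size=(640, 640))
#   dl = DataLoader(ds, batch_size=8, collate_fn=LoadImages_batch.collate_fn)
#   for paths, imgs, im0s, s in dl:  # imgs: (B,3,H,W) tensor, im0s: list of original BGR arrays
#       pass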


class LoadImages:
    # YOLOv5 image/video dataloader, i.e. `python detect.py --source image.jpg/vid.mp4`
    def __init__(self, path, img_size=640, stride=32, auto=True, transforms=None, vid_stride=1):
        if isinstance(path, str) and Path(path).suffix == '.txt':  # *.txt file with img/vid/dir on each line
            path = Path(path).read_text().rsplit()
        files = []
        for p in sorted(path) if isinstance(path, (list, tuple)) else [path]:
            p = str(Path(p).resolve())
            if '*' in p:
                files.extend(sorted(glob.glob(p, recursive=True)))  # glob
            elif os.path.isdir(p):
                files.extend(sorted(glob.glob(os.path.join(p, '*.*'))))  # dir
            elif os.path.isfile(p):
                files.append(p)  # files
            else:
                raise FileNotFoundError(f'{p} does not exist')

        images = [x for x in files if x.split('.')[-1].lower() in IMG_FORMATS]
        videos = [x for x in files if x.split('.')[-1].lower() in VID_FORMATS]
        ni, nv = len(images), len(videos)

        self.img_size = img_size
        self.stride = stride
        self.files = images + videos
        self.nf = ni + nv  # number of files
        self.video_flag = [False] * ni + [True] * nv
        self.mode = 'image'
        self.auto = auto
        self.transforms = transforms  # optional
        self.vid_stride = vid_stride  # video frame-rate stride
        if any(videos):
            self._new_video(videos[0])  # new video
        else:
            self.cap = None
        assert self.nf > 0, f'No images or videos found in {p}. ' \
                            f'Supported formats are:\nimages: {IMG_FORMATS}\nvideos: {VID_FORMATS}'

    def __iter__(self):
        self.count = 0
        return self

    def __next__(self):
        if self.count == self.nf:
            raise StopIteration
        path = self.files[self.count]

        if self.video_flag[self.count]:
            # Read video
            self.mode = 'video'
            for _ in range(self.vid_stride):
                self.cap.grab()
            ret_val, im0 = self.cap.retrieve()
            while not ret_val:
                self.count += 1
                self.cap.release()
                if self.count == self.nf:  # last video
                    raise StopIteration
                path = self.files[self.count]
                self._new_video(path)
                ret_val, im0 = self.cap.read()

            self.frame += 1
            # im0 = self._cv2_rotate(im0)  # for use if cv2 autorotation is False
            s = f'video {self.count + 1}/{self.nf} ({self.frame}/{self.frames}) {path}: '

        else:
            # Read image
            self.count += 1
            im0 = cv2.imread(path)  # BGR
            assert im0 is not None, f'Image Not Found {path}'
            s = f'image {self.count}/{self.nf} {path}: '

        if self.transforms:
            im = self.transforms(im0)  # transforms
        else:
            im = letterbox(im0, self.img_size, stride=self.stride, auto=self.auto)[0]  # padded resize
            im = im.transpose((2, 0, 1))[::-1]  # HWC to CHW, BGR to RGB
            im = np.ascontiguousarray(im)  # contiguous

        return path, im, im0, self.cap, s

    def _new_video(self, path):
        # Create a new video capture object
        self.frame = 0
        self.cap = cv2.VideoCapture(path)
        self.frames = int(self.cap.get(cv2.CAP_PROP_FRAME_COUNT) / self.vid_stride)
        self.orientation = int(self.cap.get(cv2.CAP_PROP_ORIENTATION_META))  # rotation degrees
        # self.cap.set(cv2.CAP_PROP_ORIENTATION_AUTO, 0)  # disable https://github.com/ultralytics/yolov5/issues/8493

    def _cv2_rotate(self, im):
        # Rotate a cv2 video manually
        if self.orientation == 0:
            return cv2.rotate(im, cv2.ROTATE_90_CLOCKWISE)
        elif self.orientation == 180:
            return cv2.rotate(im, cv2.ROTATE_90_COUNTERCLOCKWISE)
        elif self.orientation == 90:
            return cv2.rotate(im, cv2.ROTATE_180)
        return im

    def __len__(self):
        return self.nf  # number of files
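
# Usage sketch ('data/samples' is an illustrative path holding images and/or videos):
#   for path, im, im0, cap, s in LoadImages('data/samples', img_size=640, stride=32):
#       pass  # im: letterboxed CHW RGB array, im0: original BGR image, cap: cv2.VideoCapture or None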


class LoadStreams:
    # YOLOv5 streamloader, i.e. `python detect.py --source 'rtsp://example.com/media.mp4'  # RTSP, RTMP, HTTP streams`
    def __init__(self, sources='file.streams', img_size=640, stride=32, auto=True, transforms=None, vid_stride=1):
        torch.backends.cudnn.benchmark = True  # faster for fixed-size inference
        self.mode = 'stream'
        self.img_size = img_size
        self.stride = stride
        self.vid_stride = vid_stride  # video frame-rate stride
        sources = Path(sources).read_text().rsplit() if os.path.isfile(sources) else [sources]
        n = len(sources)
        self.sources = [clean_str(x) for x in sources]  # clean source names for later
        self.imgs, self.fps, self.frames, self.threads = [None] * n, [0] * n, [0] * n, [None] * n
        for i, s in enumerate(sources):  # index, source
            # Start thread to read frames from video stream
            st = f'{i + 1}/{n}: {s}... '
            if urlparse(s).hostname in ('www.youtube.com', 'youtube.com', 'youtu.be'):  # if source is YouTube video
                # YouTube format i.e. 'https://www.youtube.com/watch?v=Zgi9g1ksQHc' or 'https://youtu.be/Zgi9g1ksQHc'
                check_requirements(('pafy', 'youtube_dl==2020.12.2'))
                import pafy
                s = pafy.new(s).getbest(preftype='mp4').url  # YouTube URL
            s = eval(s) if s.isnumeric() else s  # i.e. s = '0' local webcam
            if s == 0:
                assert not is_colab(), '--source 0 webcam unsupported on Colab. Rerun command in a local environment.'
                assert not is_kaggle(), '--source 0 webcam unsupported on Kaggle. Rerun command in a local environment.'
            cap = cv2.VideoCapture(s)
            assert cap.isOpened(), f'{st}Failed to open {s}'
            w = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
            h = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
            fps = cap.get(cv2.CAP_PROP_FPS)  # warning: may return 0 or nan
            self.frames[i] = max(int(cap.get(cv2.CAP_PROP_FRAME_COUNT)), 0) or float('inf')  # infinite stream fallback
            self.fps[i] = max((fps if math.isfinite(fps) else 0) % 100, 0) or 30  # 30 FPS fallback
            _, self.imgs[i] = cap.read()  # guarantee first frame
            self.threads[i] = Thread(target=self.update, args=([i, cap, s]), daemon=True)
            LOGGER.info(f'{st} Success ({self.frames[i]} frames {w}x{h} at {self.fps[i]:.2f} FPS)')
            self.threads[i].start()
        LOGGER.info('')  # newline

        # check for common shapes
        s = np.stack([letterbox(x, img_size, stride=stride, auto=auto)[0].shape for x in self.imgs])
        self.rect = np.unique(s, axis=0).shape[0] == 1  # rect inference if all shapes equal
        self.auto = auto and self.rect
        self.transforms = transforms  # optional
        if not self.rect:
            LOGGER.warning('WARNING ⚠️ Stream shapes differ. For optimal performance supply similarly-shaped streams.')

    def update(self, i, cap, stream):
        # Read stream `i` frames in daemon thread
        n, f = 0, self.frames[i]  # frame number, frame array
        while cap.isOpened() and n < f:
            n += 1
            cap.grab()  # .read() = .grab() followed by .retrieve()
            if n % self.vid_stride == 0:
                success, im = cap.retrieve()
                if success:
                    self.imgs[i] = im
                else:
                    LOGGER.warning('WARNING ⚠️ Video stream unresponsive, please check your IP camera connection.')
                    self.imgs[i] = np.zeros_like(self.imgs[i])
                    cap.open(stream)  # re-open stream if signal was lost
            time.sleep(0.0)  # wait time

    def __iter__(self):
        self.count = -1
        return self

    def __next__(self):
        self.count += 1
        if not all(x.is_alive() for x in self.threads) or cv2.waitKey(1) == ord('q'):  # q to quit
            cv2.destroyAllWindows()
            raise StopIteration

        im0 = self.imgs.copy()
        if self.transforms:
            im = np.stack([self.transforms(x) for x in im0])  # transforms
        else:
            im = np.stack([letterbox(x, self.img_size, stride=self.stride, auto=self.auto)[0] for x in im0])  # resize
            im = im[..., ::-1].transpose((0, 3, 1, 2))  # BGR to RGB, BHWC to BCHW
            im = np.ascontiguousarray(im)  # contiguous

        return self.sources, im, im0, None, ''

    def __len__(self):
        return len(self.sources)  # 1E12 frames = 32 streams at 30 FPS for 30 years
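
# Usage sketch (the RTSP URL is an illustrative stand-in; frames are grabbed by background daemon threads,
# so each iteration returns the most recent frame from every stream):
#   for sources, im, im0, _, s in LoadStreams('rtsp://example.com/media.mp4', img_size=640):
#       pass  # im: (n_streams, 3, H, W) batch, im0: list of raw BGR frames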


def img2label_paths(img_paths):
    # Define label paths as a function of image paths
    sa, sb = f'{os.sep}images{os.sep}', f'{os.sep}labels{os.sep}'  # /images/, /labels/ substrings
    return [sb.join(x.rsplit(sa, 1)).rsplit('.', 1)[0] + '.txt' for x in img_paths]
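
# Example (illustrative paths, POSIX separators shown): img2label_paths(['coco/images/train/0001.jpg'])
# returns ['coco/labels/train/0001.txt']; only the last '/images/' path component is swapped for '/labels/'.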


class LoadImagesAndLabels(Dataset):
    # YOLOv5 train_loader/val_loader, loads images and labels for training and validation
    cache_version = 0.6  # dataset labels *.cache version
    rand_interp_methods = [cv2.INTER_NEAREST, cv2.INTER_LINEAR, cv2.INTER_CUBIC, cv2.INTER_AREA, cv2.INTER_LANCZOS4]

    def __init__(self,
                 path,
                 img_size=640,
                 batch_size=16,
                 augment=False,
                 hyp=None,
                 rect=False,
                 image_weights=False,
                 cache_images=False,
                 single_cls=False,
                 stride=32,
                 pad=0.0,
                 min_items=0,
                 prefix=''):
        self.img_size = img_size
        self.augment = augment
        self.hyp = hyp
        self.image_weights = image_weights
        self.rect = False if image_weights else rect
        self.mosaic = self.augment and not self.rect  # load 4 images at a time into a mosaic (only during training)
        self.mosaic_border = [-img_size // 2, -img_size // 2]
        self.stride = stride
        self.path = path
        self.albumentations = Albumentations(size=img_size) if augment else None

        try:
            f = []  # image files
            for p in path if isinstance(path, list) else [path]:
                p = Path(p)  # os-agnostic
                if p.is_dir():  # dir
                    f += glob.glob(str(p / '**' / '*.*'), recursive=True)
                    # f = list(p.rglob('*.*'))  # pathlib
                elif p.is_file():  # file
                    with open(p) as t:
                        t = t.read().strip().splitlines()
                        parent = str(p.parent) + os.sep
                        f += [x.replace('./', parent, 1) if x.startswith('./') else x for x in t]  # to global path
                        # f += [p.parent / x.lstrip(os.sep) for x in t]  # to global path (pathlib)
                else:
                    raise FileNotFoundError(f'{prefix}{p} does not exist')
            self.im_files = sorted(x.replace('/', os.sep) for x in f if x.split('.')[-1].lower() in IMG_FORMATS)
            # self.img_files = sorted([x for x in f if x.suffix[1:].lower() in IMG_FORMATS])  # pathlib
            assert self.im_files, f'{prefix}No images found'
        except Exception as e:
            raise Exception(f'{prefix}Error loading data from {path}: {e}\n{HELP_URL}') from e

        # Check cache
        self.label_files = img2label_paths(self.im_files)  # labels
        cache_path = (p if p.is_file() else Path(self.label_files[0]).parent).with_suffix('.cache')
        try:
            cache, exists = np.load(cache_path, allow_pickle=True).item(), True  # load dict
            assert cache['version'] == self.cache_version  # matches current version
            assert cache['hash'] == get_hash(self.label_files + self.im_files)  # identical hash
        except Exception:
            cache, exists = self.cache_labels(cache_path, prefix), False  # run cache ops

        # Display cache
        nf, nm, ne, nc, n = cache.pop('results')  # found, missing, empty, corrupt, total
        if exists and LOCAL_RANK in {-1, 0}:
            d = f'Scanning {cache_path}... {nf} images, {nm + ne} backgrounds, {nc} corrupt'
            tqdm(None, desc=prefix + d, total=n, initial=n, bar_format=TQDM_BAR_FORMAT)  # display cache results
            if cache['msgs']:
                LOGGER.info('\n'.join(cache['msgs']))  # display warnings
        assert nf > 0 or not augment, f'{prefix}No labels found in {cache_path}, can not start training. {HELP_URL}'

        # Read cache
        [cache.pop(k) for k in ('hash', 'version', 'msgs')]  # remove items
        labels, shapes, self.segments = zip(*cache.values())
        nl = len(np.concatenate(labels, 0))  # number of labels
        assert nl > 0 or not augment, f'{prefix}All labels empty in {cache_path}, can not start training. {HELP_URL}'
        self.labels = list(labels)
        self.shapes = np.array(shapes)
        self.im_files = list(cache.keys())  # update
        self.label_files = img2label_paths(cache.keys())  # update

        # Filter images
        if min_items:
            include = np.array([len(x) >= min_items for x in self.labels]).nonzero()[0].astype(int)
            LOGGER.info(f'{prefix}{n - len(include)}/{n} images filtered from dataset')
            self.im_files = [self.im_files[i] for i in include]
            self.label_files = [self.label_files[i] for i in include]
            self.labels = [self.labels[i] for i in include]
            self.segments = [self.segments[i] for i in include]
            self.shapes = self.shapes[include]  # wh

        # Create indices
        n = len(self.shapes)  # number of images
        bi = np.floor(np.arange(n) / batch_size).astype(int)  # batch index
        nb = bi[-1] + 1  # number of batches
        self.batch = bi  # batch index of image
        self.n = n
        self.indices = range(n)

        # Update labels
        include_class = []  # filter labels to include only these classes (optional)
        self.segments = list(self.segments)
        include_class_array = np.array(include_class).reshape(1, -1)
        for i, (label, segment) in enumerate(zip(self.labels, self.segments)):
            if include_class:
                j = (label[:, 0:1] == include_class_array).any(1)
                self.labels[i] = label[j]
                if segment:
                    self.segments[i] = [segment[idx] for idx, elem in enumerate(j) if elem]
            if single_cls:  # single-class training, merge all classes into 0
                self.labels[i][:, 0] = 0

        # Rectangular Training
        if self.rect:
            # Sort by aspect ratio
            s = self.shapes  # wh
            ar = s[:, 1] / s[:, 0]  # aspect ratio
            irect = ar.argsort()
            self.im_files = [self.im_files[i] for i in irect]
            self.label_files = [self.label_files[i] for i in irect]
            self.labels = [self.labels[i] for i in irect]
            self.segments = [self.segments[i] for i in irect]
            self.shapes = s[irect]  # wh
            ar = ar[irect]

            # Set training image shapes
            shapes = [[1, 1]] * nb
            for i in range(nb):
                ari = ar[bi == i]
                mini, maxi = ari.min(), ari.max()
                if maxi < 1:
                    shapes[i] = [maxi, 1]
                elif mini > 1:
                    shapes[i] = [1, 1 / mini]

            self.batch_shapes = np.ceil(np.array(shapes) * img_size / stride + pad).astype(int) * stride

        # Cache images into RAM/disk for faster training
        if cache_images == 'ram' and not self.check_cache_ram(prefix=prefix):
            cache_images = False
        self.ims = [None] * n
        self.npy_files = [Path(f).with_suffix('.npy') for f in self.im_files]
        if cache_images:
            b, gb = 0, 1 << 30  # bytes of cached images, bytes per gigabyte
            self.im_hw0, self.im_hw = [None] * n, [None] * n
            fcn = self.cache_images_to_disk if cache_images == 'disk' else self.load_image
            results = ThreadPool(NUM_THREADS).imap(fcn, range(n))
            pbar = tqdm(enumerate(results), total=n, bar_format=TQDM_BAR_FORMAT, disable=LOCAL_RANK > 0)
            for i, x in pbar:
                if cache_images == 'disk':
                    b += self.npy_files[i].stat().st_size
                else:  # 'ram'
                    self.ims[i], self.im_hw0[i], self.im_hw[i] = x  # im, hw_orig, hw_resized = load_image(self, i)
                    b += self.ims[i].nbytes
                pbar.desc = f'{prefix}Caching images ({b / gb:.1f}GB {cache_images})'
            pbar.close()

    def check_cache_ram(self, safety_margin=0.1, prefix=''):
        # Check image caching requirements vs available memory
        b, gb = 0, 1 << 30  # bytes of cached images, bytes per gigabyte
        n = min(self.n, 30)  # extrapolate from 30 random images
        for _ in range(n):
            im = cv2.imread(random.choice(self.im_files))  # sample image
            ratio = self.img_size / max(im.shape[0], im.shape[1])  # max(h, w) ratio
            b += im.nbytes * ratio ** 2
        mem_required = b * self.n / n  # GB required to cache dataset into RAM
        mem = psutil.virtual_memory()
        cache = mem_required * (1 + safety_margin) < mem.available  # to cache or not to cache, that is the question
        if not cache:
            LOGGER.info(f'{prefix}{mem_required / gb:.1f}GB RAM required, '
                        f'{mem.available / gb:.1f}/{mem.total / gb:.1f}GB available, '
                        f"{'caching images ✅' if cache else 'not caching images ⚠️'}")
        return cache

    def cache_labels(self, path=Path('./labels.cache'), prefix=''):
        # Cache dataset labels, check images and read shapes
        x = {}  # dict
        nm, nf, ne, nc, msgs = 0, 0, 0, 0, []  # number missing, found, empty, corrupt, messages
        desc = f'{prefix}Scanning {path.parent / path.stem}...'
        with Pool(NUM_THREADS) as pool:
            pbar = tqdm(pool.imap(verify_image_label, zip(self.im_files, self.label_files, repeat(prefix))),
                        desc=desc,
                        total=len(self.im_files),
                        bar_format=TQDM_BAR_FORMAT)
            for im_file, lb, shape, segments, nm_f, nf_f, ne_f, nc_f, msg in pbar:
                nm += nm_f
                nf += nf_f
                ne += ne_f
                nc += nc_f
                if im_file:
                    x[im_file] = [lb, shape, segments]
                if msg:
                    msgs.append(msg)
                pbar.desc = f'{desc} {nf} images, {nm + ne} backgrounds, {nc} corrupt'

        pbar.close()
        if msgs:
            LOGGER.info('\n'.join(msgs))
        if nf == 0:
            LOGGER.warning(f'{prefix}WARNING ⚠️ No labels found in {path}. {HELP_URL}')
        x['hash'] = get_hash(self.label_files + self.im_files)
        x['results'] = nf, nm, ne, nc, len(self.im_files)
        x['msgs'] = msgs  # warnings
        x['version'] = self.cache_version  # cache version
        try:
            np.save(path, x)  # save cache for next time
            path.with_suffix('.cache.npy').rename(path)  # remove .npy suffix
            LOGGER.info(f'{prefix}New cache created: {path}')
        except Exception as e:
            LOGGER.warning(f'{prefix}WARNING ⚠️ Cache directory {path.parent} is not writeable: {e}')  # not writeable
        return x

    def __len__(self):
        return len(self.im_files)

    # def __iter__(self):
    #     self.count = -1
    #     print('ran dataset iter')
    #     # self.shuffled_vector = np.random.permutation(self.nF) if self.augment else np.arange(self.nF)
    #     return self

    def __getitem__(self, index):
        index = self.indices[index]  # linear, shuffled, or image_weights

        hyp = self.hyp
        mosaic = self.mosaic and random.random() < hyp['mosaic']
        if mosaic:
            # Load mosaic
            img, labels = self.load_mosaic(index)
            shapes = None

            # MixUp augmentation
            if random.random() < hyp['mixup']:
                img, labels = mixup(img, labels, *self.load_mosaic(random.randint(0, self.n - 1)))

        else:
            # Load image
            img, (h0, w0), (h, w) = self.load_image(index)

            # Letterbox
            shape = self.batch_shapes[self.batch[index]] if self.rect else self.img_size  # final letterboxed shape
            img, ratio, pad = letterbox(img, shape, auto=False, scaleup=self.augment)
            shapes = (h0, w0), ((h / h0, w / w0), pad)  # for COCO mAP rescaling

            labels = self.labels[index].copy()
            if labels.size:  # normalized xywh to pixel xyxy format
                labels[:, 1:] = xywhn2xyxy(labels[:, 1:], ratio[0] * w, ratio[1] * h, padw=pad[0], padh=pad[1])

            if self.augment:
                img, labels = random_perspective(img,
                                                 labels,
                                                 degrees=hyp['degrees'],
                                                 translate=hyp['translate'],
                                                 scale=hyp['scale'],
                                                 shear=hyp['shear'],
                                                 perspective=hyp['perspective'])

        nl = len(labels)  # number of labels
        if nl:
            labels[:, 1:5] = xyxy2xywhn(labels[:, 1:5], w=img.shape[1], h=img.shape[0], clip=True, eps=1E-3)

        if self.augment:
            # Albumentations
            img, labels = self.albumentations(img, labels)
            nl = len(labels)  # update after albumentations

            # HSV color-space
            augment_hsv(img, hgain=hyp['hsv_h'], sgain=hyp['hsv_s'], vgain=hyp['hsv_v'])

            # Flip up-down
            if random.random() < hyp['flipud']:
                img = np.flipud(img)
                if nl:
                    labels[:, 2] = 1 - labels[:, 2]

            # Flip left-right
            if random.random() < hyp['fliplr']:
                img = np.fliplr(img)
                if nl:
                    labels[:, 1] = 1 - labels[:, 1]

            # Cutouts
            # labels = cutout(img, labels, p=0.5)
            # nl = len(labels)  # update after cutout

        labels_out = torch.zeros((nl, 6))
        if nl:
            labels_out[:, 1:] = torch.from_numpy(labels)

        # Convert
        img = img.transpose((2, 0, 1))[::-1]  # HWC to CHW, BGR to RGB
        img = np.ascontiguousarray(img)

        return torch.from_numpy(img), labels_out, self.im_files[index], shapes

    def load_image(self, i):
        # Loads 1 image from dataset index 'i', returns (im, original hw, resized hw)
        im, f, fn = self.ims[i], self.im_files[i], self.npy_files[i],
        if im is None:  # not cached in RAM
            if fn.exists():  # load npy
                im = np.load(fn)
            else:  # read image
                im = cv2.imread(f)  # BGR
                assert im is not None, f'Image Not Found {f}'
            h0, w0 = im.shape[:2]  # orig hw
            r = self.img_size / max(h0, w0)  # ratio
            if r != 1:  # if sizes are not equal
                interp = cv2.INTER_LINEAR if (self.augment or r > 1) else cv2.INTER_AREA
                im = cv2.resize(im, (math.ceil(w0 * r), math.ceil(h0 * r)), interpolation=interp)
            return im, (h0, w0), im.shape[:2]  # im, hw_original, hw_resized
        return self.ims[i], self.im_hw0[i], self.im_hw[i]  # im, hw_original, hw_resized

    def cache_images_to_disk(self, i):
        # Saves an image as an *.npy file for faster loading
        f = self.npy_files[i]
        if not f.exists():
            np.save(f.as_posix(), cv2.imread(self.im_files[i]))

    def load_mosaic(self, index):
        # YOLOv5 4-mosaic loader. Loads 1 image + 3 random images into a 4-image mosaic
        labels4, segments4 = [], []
        s = self.img_size
        yc, xc = (int(random.uniform(-x, 2 * s + x)) for x in self.mosaic_border)  # mosaic center x, y
        indices = [index] + random.choices(self.indices, k=3)  # 3 additional image indices
        random.shuffle(indices)
        for i, index in enumerate(indices):
            # Load image
            img, _, (h, w) = self.load_image(index)

            # place img in img4
            if i == 0:  # top left
                img4 = np.full((s * 2, s * 2, img.shape[2]), 114, dtype=np.uint8)  # base image with 4 tiles
                x1a, y1a, x2a, y2a = max(xc - w, 0), max(yc - h, 0), xc, yc  # xmin, ymin, xmax, ymax (large image)
                x1b, y1b, x2b, y2b = w - (x2a - x1a), h - (y2a - y1a), w, h  # xmin, ymin, xmax, ymax (small image)
            elif i == 1:  # top right
                x1a, y1a, x2a, y2a = xc, max(yc - h, 0), min(xc + w, s * 2), yc
                x1b, y1b, x2b, y2b = 0, h - (y2a - y1a), min(w, x2a - x1a), h
            elif i == 2:  # bottom left
                x1a, y1a, x2a, y2a = max(xc - w, 0), yc, xc, min(s * 2, yc + h)
                x1b, y1b, x2b, y2b = w - (x2a - x1a), 0, w, min(y2a - y1a, h)
            elif i == 3:  # bottom right
                x1a, y1a, x2a, y2a = xc, yc, min(xc + w, s * 2), min(s * 2, yc + h)
                x1b, y1b, x2b, y2b = 0, 0, min(w, x2a - x1a), min(y2a - y1a, h)

            img4[y1a:y2a, x1a:x2a] = img[y1b:y2b, x1b:x2b]  # img4[ymin:ymax, xmin:xmax]
            padw = x1a - x1b
            padh = y1a - y1b

            # Labels
            labels, segments = self.labels[index].copy(), self.segments[index].copy()
            if labels.size:
                labels[:, 1:] = xywhn2xyxy(labels[:, 1:], w, h, padw, padh)  # normalized xywh to pixel xyxy format
                segments = [xyn2xy(x, w, h, padw, padh) for x in segments]
            labels4.append(labels)
            segments4.extend(segments)

        # Concat/clip labels
        labels4 = np.concatenate(labels4, 0)
        for x in (labels4[:, 1:], *segments4):
            np.clip(x, 0, 2 * s, out=x)  # clip when using random_perspective()
        # img4, labels4 = replicate(img4, labels4)  # replicate

        # Augment
        img4, labels4, segments4 = copy_paste(img4, labels4, segments4, p=self.hyp['copy_paste'])
        img4, labels4 = random_perspective(img4,
                                           labels4,
                                           segments4,
                                           degrees=self.hyp['degrees'],
                                           translate=self.hyp['translate'],
                                           scale=self.hyp['scale'],
                                           shear=self.hyp['shear'],
                                           perspective=self.hyp['perspective'],
                                           border=self.mosaic_border)  # border to remove

        return img4, labels4

    def load_mosaic9(self, index):
        # YOLOv5 9-mosaic loader. Loads 1 image + 8 random images into a 9-image mosaic
        labels9, segments9 = [], []
        s = self.img_size
        indices = [index] + random.choices(self.indices, k=8)  # 8 additional image indices
        random.shuffle(indices)
        hp, wp = -1, -1  # height, width previous
        for i, index in enumerate(indices):
            # Load image
            img, _, (h, w) = self.load_image(index)

            # place img in img9
            if i == 0:  # center
                img9 = np.full((s * 3, s * 3, img.shape[2]), 114, dtype=np.uint8)  # base image with 9 tiles
                h0, w0 = h, w
                c = s, s, s + w, s + h  # xmin, ymin, xmax, ymax (base) coordinates
            elif i == 1:  # top
                c = s, s - h, s + w, s
            elif i == 2:  # top right
                c = s + wp, s - h, s + wp + w, s
            elif i == 3:  # right
                c = s + w0, s, s + w0 + w, s + h
            elif i == 4:  # bottom right
                c = s + w0, s + hp, s + w0 + w, s + hp + h
            elif i == 5:  # bottom
                c = s + w0 - w, s + h0, s + w0, s + h0 + h
            elif i == 6:  # bottom left
                c = s + w0 - wp - w, s + h0, s + w0 - wp, s + h0 + h
            elif i == 7:  # left
                c = s - w, s + h0 - h, s, s + h0
            elif i == 8:  # top left
                c = s - w, s + h0 - hp - h, s, s + h0 - hp

            padx, pady = c[:2]
            x1, y1, x2, y2 = (max(x, 0) for x in c)  # allocate coords

            # Labels
            labels, segments = self.labels[index].copy(), self.segments[index].copy()
            if labels.size:
                labels[:, 1:] = xywhn2xyxy(labels[:, 1:], w, h, padx, pady)  # normalized xywh to pixel xyxy format
                segments = [xyn2xy(x, w, h, padx, pady) for x in segments]
            labels9.append(labels)
            segments9.extend(segments)

            # Image
            img9[y1:y2, x1:x2] = img[y1 - pady:, x1 - padx:]  # img9[ymin:ymax, xmin:xmax]
            hp, wp = h, w  # height, width previous

        # Offset
        yc, xc = (int(random.uniform(0, s)) for _ in self.mosaic_border)  # mosaic center x, y
        img9 = img9[yc:yc + 2 * s, xc:xc + 2 * s]

        # Concat/clip labels
        labels9 = np.concatenate(labels9, 0)
        labels9[:, [1, 3]] -= xc
        labels9[:, [2, 4]] -= yc
        c = np.array([xc, yc])  # centers
        segments9 = [x - c for x in segments9]

        for x in (labels9[:, 1:], *segments9):
            np.clip(x, 0, 2 * s, out=x)  # clip when using random_perspective()
        # img9, labels9 = replicate(img9, labels9)  # replicate

        # Augment
        img9, labels9, segments9 = copy_paste(img9, labels9, segments9, p=self.hyp['copy_paste'])
        img9, labels9 = random_perspective(img9,
                                           labels9,
                                           segments9,
                                           degrees=self.hyp['degrees'],
                                           translate=self.hyp['translate'],
                                           scale=self.hyp['scale'],
                                           shear=self.hyp['shear'],
                                           perspective=self.hyp['perspective'],
                                           border=self.mosaic_border)  # border to remove

        return img9, labels9

    @staticmethod
    def collate_fn(batch):
        im, label, path, shapes = zip(*batch)  # transposed
        for i, lb in enumerate(label):
            lb[:, 0] = i  # add target image index for build_targets()
        return torch.stack(im, 0), torch.cat(label, 0), path, shapes

    @staticmethod
    def collate_fn4(batch):
        im, label, path, shapes = zip(*batch)  # transposed
        n = len(shapes) // 4
        im4, label4, path4, shapes4 = [], [], path[:n], shapes[:n]

        ho = torch.tensor([[0.0, 0, 0, 1, 0, 0]])
        wo = torch.tensor([[0.0, 0, 1, 0, 0, 0]])
        s = torch.tensor([[1, 1, 0.5, 0.5, 0.5, 0.5]])  # scale
        for i in range(n):  # zidane torch.zeros(16,3,720,1280)  # BCHW
            i *= 4
            if random.random() < 0.5:
                im1 = F.interpolate(im[i].unsqueeze(0).float(), scale_factor=2.0, mode='bilinear',
                                    align_corners=False)[0].type(im[i].type())
                lb = label[i]
            else:
                im1 = torch.cat((torch.cat((im[i], im[i + 1]), 1), torch.cat((im[i + 2], im[i + 3]), 1)), 2)
                lb = torch.cat((label[i], label[i + 1] + ho, label[i + 2] + wo, label[i + 3] + ho + wo), 0) * s
            im4.append(im1)
            label4.append(lb)

        for i, lb in enumerate(label4):
            lb[:, 0] = i  # add target image index for build_targets()

        return torch.stack(im4, 0), torch.cat(label4, 0), path4, shapes4
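
# Usage sketch ('data/coco128/images/train2017' is a stand-in for any YOLO-format image directory with a
# parallel 'labels' directory; the first call scans the dataset and writes a *.cache file next to the labels):
#   dataset = LoadImagesAndLabels('data/coco128/images/train2017', img_size=640, batch_size=16)
#   img, labels_out, file, shapes = dataset[0]  # labels_out: (n,6) rows of [image_idx placeholder, cls, xywhn]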


# Ancillary functions --------------------------------------------------------------------------------------------------
def flatten_recursive(path=DATASETS_DIR / 'coco128'):
    # Flatten a recursive directory by bringing all files to top level
    new_path = Path(f'{str(path)}_flat')
    if os.path.exists(new_path):
        shutil.rmtree(new_path)  # delete output folder
    os.makedirs(new_path)  # make new output folder
    for file in tqdm(glob.glob(f'{str(Path(path))}/**/*.*', recursive=True)):
        shutil.copyfile(file, new_path / Path(file).name)


def extract_boxes(path=DATASETS_DIR / 'coco128'):  # from utils.dataloaders import *; extract_boxes()
    # Convert detection dataset into classification dataset, with one directory per class
    path = Path(path)  # images dir
    shutil.rmtree(path / 'classification') if (path / 'classification').is_dir() else None  # remove existing
    files = list(path.rglob('*.*'))
    n = len(files)  # number of files
    for im_file in tqdm(files, total=n):
        if im_file.suffix[1:] in IMG_FORMATS:
            # image
            im = cv2.imread(str(im_file))[..., ::-1]  # BGR to RGB
            h, w = im.shape[:2]

            # labels
            lb_file = Path(img2label_paths([str(im_file)])[0])
            if Path(lb_file).exists():
                with open(lb_file) as f:
                    lb = np.array([x.split() for x in f.read().strip().splitlines()], dtype=np.float32)  # labels

                for j, x in enumerate(lb):
                    c = int(x[0])  # class
                    f = (path / 'classification') / f'{c}' / f'{path.stem}_{im_file.stem}_{j}.jpg'  # new filename (same dir removed above)
                    if not f.parent.is_dir():
                        f.parent.mkdir(parents=True)

                    b = x[1:] * [w, h, w, h]  # box
                    # b[2:] = b[2:].max()  # rectangle to square
                    b[2:] = b[2:] * 1.2 + 3  # pad
                    b = xywh2xyxy(b.reshape(-1, 4)).ravel().astype(int)

                    b[[0, 2]] = np.clip(b[[0, 2]], 0, w)  # clip boxes outside of image
                    b[[1, 3]] = np.clip(b[[1, 3]], 0, h)
                    assert cv2.imwrite(str(f), im[b[1]:b[3], b[0]:b[2]]), f'box failure in {f}'


def autosplit(path=DATASETS_DIR / 'coco128/images', weights=(0.9, 0.1, 0.0), annotated_only=False):
    """ Autosplit a dataset into train/val/test splits and save path/autosplit_*.txt files
    Usage: from utils.dataloaders import *; autosplit()
    Arguments
        path:            Path to images directory
        weights:         Train, val, test weights (list, tuple)
        annotated_only:  Only use images with an annotated txt file
    """
    path = Path(path)  # images dir
    files = sorted(x for x in path.rglob('*.*') if x.suffix[1:].lower() in IMG_FORMATS)  # image files only
    n = len(files)  # number of files
    random.seed(0)  # for reproducibility
    indices = random.choices([0, 1, 2], weights=weights, k=n)  # assign each image to a split

    txt = ['autosplit_train.txt', 'autosplit_val.txt', 'autosplit_test.txt']  # 3 txt files
    for x in txt:
        if (path.parent / x).exists():
            (path.parent / x).unlink()  # remove existing

    print(f'Autosplitting images from {path}' + ', using *.txt labeled images only' * annotated_only)
    for i, img in tqdm(zip(indices, files), total=n):
        if not annotated_only or Path(img2label_paths([str(img)])[0]).exists():  # check label
            with open(path.parent / txt[i], 'a') as f:
                f.write(f'./{img.relative_to(path.parent).as_posix()}' + '\n')  # add image to txt file
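
# Example: weights are relative split probabilities (random.choices normalizes them); this writes
# autosplit_*.txt files next to the images directory:
#   autosplit(DATASETS_DIR / 'coco128/images', weights=(0.8, 0.2, 0.0), annotated_only=True)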


def verify_image_label(args):
    # Verify one image-label pair
    im_file, lb_file, prefix = args
    nm, nf, ne, nc, msg, segments = 0, 0, 0, 0, '', []  # number (missing, found, empty, corrupt), message, segments
    try:
        # verify images
        im = Image.open(im_file)
        im.verify()  # PIL verify
        shape = exif_size(im)  # image size
        assert (shape[0] > 9) & (shape[1] > 9), f'image size {shape} <10 pixels'
        assert im.format.lower() in IMG_FORMATS, f'invalid image format {im.format}'
        if im.format.lower() in ('jpg', 'jpeg'):
            with open(im_file, 'rb') as f:
                f.seek(-2, 2)
                if f.read() != b'\xff\xd9':  # corrupt JPEG
                    ImageOps.exif_transpose(Image.open(im_file)).save(im_file, 'JPEG', subsampling=0, quality=100)
                    msg = f'{prefix}WARNING ⚠️ {im_file}: corrupt JPEG restored and saved'

        # verify labels
        if os.path.isfile(lb_file):
            nf = 1  # label found
            with open(lb_file) as f:
                lb = [x.split() for x in f.read().strip().splitlines() if len(x)]
                if any(len(x) > 6 for x in lb):  # is segment
                    classes = np.array([x[0] for x in lb], dtype=np.float32)
                    segments = [np.array(x[1:], dtype=np.float32).reshape(-1, 2) for x in lb]  # (cls, xy1...)
                    lb = np.concatenate((classes.reshape(-1, 1), segments2boxes(segments)), 1)  # (cls, xywh)
                lb = np.array(lb, dtype=np.float32)
            nl = len(lb)
            if nl:
                assert lb.shape[1] == 5, f'labels require 5 columns, {lb.shape[1]} columns detected'
                assert (lb >= 0).all(), f'negative label values {lb[lb < 0]}'
                assert (lb[:, 1:] <= 1).all(), f'non-normalized or out of bounds coordinates {lb[:, 1:][lb[:, 1:] > 1]}'
                _, i = np.unique(lb, axis=0, return_index=True)
                if len(i) < nl:  # duplicate row check
                    lb = lb[i]  # remove duplicates
                    if segments:
                        segments = [segments[x] for x in i]
                    msg = f'{prefix}WARNING ⚠️ {im_file}: {nl - len(i)} duplicate labels removed'
            else:
                ne = 1  # label empty
                lb = np.zeros((0, 5), dtype=np.float32)
        else:
            nm = 1  # label missing
            lb = np.zeros((0, 5), dtype=np.float32)
        return im_file, lb, shape, segments, nm, nf, ne, nc, msg
    except Exception as e:
        nc = 1
        msg = f'{prefix}WARNING ⚠️ {im_file}: ignoring corrupt image/label: {e}'
        return [None, None, None, None, nm, nf, ne, nc, msg]


class HUBDatasetStats:
    """ Class for generating HUB dataset JSON and `-hub` dataset directory

    Arguments
        path:           Path to data.yaml or data.zip (with data.yaml inside data.zip)
        autodownload:   Attempt to download dataset if not found locally

    Usage
        from utils.dataloaders import HUBDatasetStats
        stats = HUBDatasetStats('coco128.yaml', autodownload=True)  # usage 1
        stats = HUBDatasetStats('path/to/coco128.zip')  # usage 2
        stats.get_json(save=False)
        stats.process_images()
    """

    def __init__(self, path='coco128.yaml', autodownload=False):
        # Initialize class
        zipped, data_dir, yaml_path = self._unzip(Path(path))
        try:
            with open(check_yaml(yaml_path), errors='ignore') as f:
                data = yaml.safe_load(f)  # data dict
                if zipped:
                    data['path'] = data_dir
        except Exception as e:
            raise Exception('error/HUB/dataset_stats/yaml_load') from e

        check_dataset(data, autodownload)  # download dataset if missing
        self.hub_dir = Path(data['path'] + '-hub')
        self.im_dir = self.hub_dir / 'images'
        self.im_dir.mkdir(parents=True, exist_ok=True)  # makes /images
        self.stats = {'nc': data['nc'], 'names': list(data['names'].values())}  # statistics dictionary
        self.data = data

    @staticmethod
    def _find_yaml(dir):
        # Return data.yaml file
        files = list(dir.glob('*.yaml')) or list(dir.rglob('*.yaml'))  # try root level first and then recursive
        assert files, f'No *.yaml file found in {dir}'
        if len(files) > 1:
            files = [f for f in files if f.stem == dir.stem]  # prefer *.yaml files that match dir name
            assert files, f'Multiple *.yaml files found in {dir}, only 1 *.yaml file allowed'
        assert len(files) == 1, f'Multiple *.yaml files found: {files}, only 1 *.yaml file allowed in {dir}'
        return files[0]

    def _unzip(self, path):
        # Unzip data.zip
        if not str(path).endswith('.zip'):  # path is data.yaml
            return False, None, path
        assert Path(path).is_file(), f'Error unzipping {path}, file not found'
        unzip_file(path, path=path.parent)  # unzip
        dir = path.with_suffix('')  # dataset directory == zip name
        assert dir.is_dir(), f'Error unzipping {path}, {dir} not found. path/to/abc.zip MUST unzip to path/to/abc/'
        return True, str(dir), self._find_yaml(dir)  # zipped, data_dir, yaml_path

    def _hub_ops(self, f, max_dim=1920):
        # HUB ops for 1 image 'f': resize and save at reduced quality in /dataset-hub for web/app viewing
        f_new = self.im_dir / Path(f).name  # dataset-hub image filename
        try:  # use PIL
            im = Image.open(f)
            r = max_dim / max(im.height, im.width)  # ratio
            if r < 1.0:  # image too large
                im = im.resize((int(im.width * r), int(im.height * r)))
            im.save(f_new, 'JPEG', quality=50, optimize=True)  # save
        except Exception as e:  # use OpenCV
            LOGGER.info(f'WARNING ⚠️ HUB ops PIL failure {f}: {e}')
            im = cv2.imread(f)
            im_height, im_width = im.shape[:2]
            r = max_dim / max(im_height, im_width)  # ratio
            if r < 1.0:  # image too large
                im = cv2.resize(im, (int(im_width * r), int(im_height * r)), interpolation=cv2.INTER_AREA)
            cv2.imwrite(str(f_new), im)

    def get_json(self, save=False, verbose=False):
        # Return dataset JSON for Ultralytics HUB
        def _round(labels):
            # Update labels to integer class and 4 decimal place floats
            return [[int(c), *(round(x, 4) for x in points)] for c, *points in labels]

        for split in 'train', 'val', 'test':
            if self.data.get(split) is None:
                self.stats[split] = None  # i.e. no test set
                continue
            dataset = LoadImagesAndLabels(self.data[split])  # load dataset
            x = np.array([
                np.bincount(label[:, 0].astype(int), minlength=self.data['nc'])
                for label in tqdm(dataset.labels, total=dataset.n, desc='Statistics')])  # shape(128x80)
            self.stats[split] = {
                'instance_stats': {
                    'total': int(x.sum()),
                    'per_class': x.sum(0).tolist()},
                'image_stats': {
                    'total': dataset.n,
                    'unlabelled': int(np.all(x == 0, 1).sum()),
                    'per_class': (x > 0).sum(0).tolist()},
                'labels': [{
                    str(Path(k).name): _round(v.tolist())} for k, v in zip(dataset.im_files, dataset.labels)]}

        # Save, print and return
        if save:
            stats_path = self.hub_dir / 'stats.json'
            print(f'Saving {stats_path.resolve()}...')
            with open(stats_path, 'w') as f:
                json.dump(self.stats, f)  # save stats.json
        if verbose:
            print(json.dumps(self.stats, indent=2, sort_keys=False))
        return self.stats

    def process_images(self):
        # Compress images for Ultralytics HUB
        for split in 'train', 'val', 'test':
            if self.data.get(split) is None:
                continue
            dataset = LoadImagesAndLabels(self.data[split])  # load dataset
            desc = f'{split} images'
            for _ in tqdm(ThreadPool(NUM_THREADS).imap(self._hub_ops, dataset.im_files), total=dataset.n, desc=desc):
                pass
        print(f'Done. All images saved to {self.im_dir}')
        return self.im_dir


# Classification dataloaders -------------------------------------------------------------------------------------------
class ClassificationDataset(torchvision.datasets.ImageFolder):
    """
    YOLOv5 Classification Dataset.

    Arguments
        root:  Dataset path
        transform:  torchvision transforms, used by default
        album_transform: Albumentations transforms, used if installed
    """

    def __init__(self, root, augment, imgsz, cache=False):
        super().__init__(root=root)
        self.torch_transforms = classify_transforms(imgsz)
        self.album_transforms = classify_albumentations(augment, imgsz) if augment else None
        self.cache_ram = cache is True or cache == 'ram'
        self.cache_disk = cache == 'disk'
        self.samples = [list(x) + [Path(x[0]).with_suffix('.npy'), None] for x in self.samples]  # file, index, npy, im

    def __getitem__(self, i):
        f, j, fn, im = self.samples[i]  # filename, index, filename.with_suffix('.npy'), image
        if self.cache_ram and im is None:
            im = self.samples[i][3] = cv2.imread(f)
        elif self.cache_disk:
            if not fn.exists():  # load npy
                np.save(fn.as_posix(), cv2.imread(f))
            im = np.load(fn)
        else:  # read image
            im = cv2.imread(f)  # BGR
        if self.album_transforms:
            sample = self.album_transforms(image=cv2.cvtColor(im, cv2.COLOR_BGR2RGB))['image']
        else:
            sample = self.torch_transforms(im)
        return sample, j


def create_classification_dataloader(path,
                                     imgsz=224,
                                     batch_size=16,
                                     augment=True,
                                     cache=False,
                                     rank=-1,
                                     workers=8,
                                     shuffle=True):
    # Returns Dataloader object to be used with YOLOv5 Classifier
    with torch_distributed_zero_first(rank):  # init dataset *.cache only once if DDP
        dataset = ClassificationDataset(root=path, imgsz=imgsz, augment=augment, cache=cache)
    batch_size = min(batch_size, len(dataset))
    nd = torch.cuda.device_count()
    nw = min([os.cpu_count() // max(nd, 1), batch_size if batch_size > 1 else 0, workers])
    sampler = None if rank == -1 else distributed.DistributedSampler(dataset, shuffle=shuffle)
    generator = torch.Generator()
    generator.manual_seed(6148914691236517205 + RANK)
    return InfiniteDataLoader(dataset,
                              batch_size=batch_size,
                              shuffle=shuffle and sampler is None,
                              num_workers=nw,
                              sampler=sampler,
                              pin_memory=PIN_MEMORY,
                              worker_init_fn=seed_worker,
                              generator=generator)  # or DataLoader(persistent_workers=True)
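
# Usage sketch ('data/imagenette/train' is a stand-in; expects torchvision ImageFolder layout root/<class>/<img>):
#   loader = create_classification_dataloader('data/imagenette/train', imgsz=224, batch_size=64, augment=True)
#   for imgs, labels in loader:
#       pass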