general.py 46 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953954955956957958959960961962963964965966967968969970971972973974975976977978979980981982983984985986987988989990991992993994995996997998999100010011002100310041005100610071008100910101011101210131014101510161017101810191020102110221023102410251026102710281029103010311032103310341035103610371038103910401041104210431044104510461047104810491050105110521053105410551056105710581059106010611062106310641065106610671068106910701071107210731074107510761077107810791080108110821083108410851086108710881089109010911092109310941095109610971098109911001101110211031104110511061107110811091110111111121113111411151116111711181119112011211122112311241125112611271128112911301131113211331134113511361137113811391140
  1. # YOLOv5 🚀 by Ultralytics, GPL-3.0 license
  2. """
  3. General utils
  4. """
  5. import contextlib
  6. import glob
  7. import inspect
  8. import logging
  9. import logging.config
  10. import math
  11. import os
  12. import platform
  13. import random
  14. import re
  15. import signal
  16. import sys
  17. import time
  18. import urllib
  19. from copy import deepcopy
  20. from datetime import datetime
  21. from itertools import repeat
  22. from multiprocessing.pool import ThreadPool
  23. from pathlib import Path
  24. from subprocess import check_output
  25. from tarfile import is_tarfile
  26. from typing import Optional
  27. from zipfile import ZipFile, is_zipfile
  28. import cv2
  29. import IPython
  30. import numpy as np
  31. import pandas as pd
  32. import pkg_resources as pkg
  33. import torch
  34. import torchvision
  35. import yaml
  36. from utils import TryExcept, emojis
  37. from utils.downloads import gsutil_getsize
  38. from utils.metrics import box_iou, fitness
  39. FILE = Path(__file__).resolve()
  40. ROOT = FILE.parents[1] # YOLOv5 root directory
  41. RANK = int(os.getenv('RANK', -1))
  42. # Settings
  43. NUM_THREADS = min(8, max(1, os.cpu_count() - 1)) # number of YOLOv5 multiprocessing threads
  44. DATASETS_DIR = Path(os.getenv('YOLOv5_DATASETS_DIR', ROOT.parent / 'datasets')) # global datasets directory
  45. AUTOINSTALL = str(os.getenv('YOLOv5_AUTOINSTALL', True)).lower() == 'true' # global auto-install mode
  46. VERBOSE = str(os.getenv('YOLOv5_VERBOSE', True)).lower() == 'true' # global verbose mode
  47. TQDM_BAR_FORMAT = '{l_bar}{bar:10}{r_bar}' # tqdm bar format
  48. FONT = 'Arial.ttf' # https://ultralytics.com/assets/Arial.ttf
  49. torch.set_printoptions(linewidth=320, precision=5, profile='long')
  50. np.set_printoptions(linewidth=320, formatter={'float_kind': '{:11.5g}'.format}) # format short g, %precision=5
  51. pd.options.display.max_columns = 10
  52. cv2.setNumThreads(0) # prevent OpenCV from multithreading (incompatible with PyTorch DataLoader)
  53. os.environ['NUMEXPR_MAX_THREADS'] = str(NUM_THREADS) # NumExpr max threads
  54. os.environ['OMP_NUM_THREADS'] = '1' if platform.system() == 'darwin' else str(NUM_THREADS) # OpenMP (PyTorch and SciPy)
  55. def is_ascii(s=''):
  56. # Is string composed of all ASCII (no UTF) characters? (note str().isascii() introduced in python 3.7)
  57. s = str(s) # convert list, tuple, None, etc. to str
  58. return len(s.encode().decode('ascii', 'ignore')) == len(s)
  59. def is_chinese(s='人工智能'):
  60. # Is string composed of any Chinese characters?
  61. return bool(re.search('[\u4e00-\u9fff]', str(s)))
  62. def is_colab():
  63. # Is environment a Google Colab instance?
  64. return 'google.colab' in sys.modules
  65. def is_notebook():
  66. # Is environment a Jupyter notebook? Verified on Colab, Jupyterlab, Kaggle, Paperspace
  67. ipython_type = str(type(IPython.get_ipython()))
  68. return 'colab' in ipython_type or 'zmqshell' in ipython_type
  69. def is_kaggle():
  70. # Is environment a Kaggle Notebook?
  71. return os.environ.get('PWD') == '/kaggle/working' and os.environ.get('KAGGLE_URL_BASE') == 'https://www.kaggle.com'
  72. def is_docker() -> bool:
  73. """Check if the process runs inside a docker container."""
  74. if Path("/.dockerenv").exists():
  75. return True
  76. try: # check if docker is in control groups
  77. with open("/proc/self/cgroup") as file:
  78. return any("docker" in line for line in file)
  79. except OSError:
  80. return False
  81. def is_writeable(dir, test=False):
  82. # Return True if directory has write permissions, test opening a file with write permissions if test=True
  83. if not test:
  84. return os.access(dir, os.W_OK) # possible issues on Windows
  85. file = Path(dir) / 'tmp.txt'
  86. try:
  87. with open(file, 'w'): # open file with write permissions
  88. pass
  89. file.unlink() # remove file
  90. return True
  91. except OSError:
  92. return False
  93. LOGGING_NAME = "yolov5"
  94. def set_logging(name=LOGGING_NAME, verbose=True):
  95. # sets up logging for the given name
  96. rank = int(os.getenv('RANK', -1)) # rank in world for Multi-GPU trainings
  97. level = logging.INFO if verbose and rank in {-1, 0} else logging.ERROR
  98. logging.config.dictConfig({
  99. "version": 1,
  100. "disable_existing_loggers": False,
  101. "formatters": {
  102. name: {
  103. "format": "%(message)s"}},
  104. "handlers": {
  105. name: {
  106. "class": "logging.StreamHandler",
  107. "formatter": name,
  108. "level": level,}},
  109. "loggers": {
  110. name: {
  111. "level": level,
  112. "handlers": [name],
  113. "propagate": False,}}})
  114. set_logging(LOGGING_NAME) # run before defining LOGGER
  115. LOGGER = logging.getLogger(LOGGING_NAME) # define globally (used in train.py, val.py, detect.py, etc.)
  116. if platform.system() == 'Windows':
  117. for fn in LOGGER.info, LOGGER.warning:
  118. setattr(LOGGER, fn.__name__, lambda x: fn(emojis(x))) # emoji safe logging
  119. def user_config_dir(dir='Ultralytics', env_var='YOLOV5_CONFIG_DIR'):
  120. # Return path of user configuration directory. Prefer environment variable if exists. Make dir if required.
  121. env = os.getenv(env_var)
  122. if env:
  123. path = Path(env) # use environment variable
  124. else:
  125. cfg = {'Windows': 'AppData/Roaming', 'Linux': '.config', 'Darwin': 'Library/Application Support'} # 3 OS dirs
  126. path = Path.home() / cfg.get(platform.system(), '') # OS-specific config dir
  127. path = (path if is_writeable(path) else Path('/tmp')) / dir # GCP and AWS lambda fix, only /tmp is writeable
  128. path.mkdir(exist_ok=True) # make if required
  129. return path
  130. CONFIG_DIR = user_config_dir() # Ultralytics settings dir; resolved (and created if missing) once, when this module is imported
  131. class Profile(contextlib.ContextDecorator):
  132. # YOLOv5 Profile class. Usage: @Profile() decorator or 'with Profile():' context manager
  133. def __init__(self, t=0.0):
  134. self.t = t
  135. self.cuda = torch.cuda.is_available()
  136. def __enter__(self):
  137. self.start = self.time()
  138. return self
  139. def __exit__(self, type, value, traceback):
  140. self.dt = self.time() - self.start # delta-time
  141. self.t += self.dt # accumulate dt
  142. def time(self):
  143. if self.cuda:
  144. torch.cuda.synchronize()
  145. return time.time()
  146. class Timeout(contextlib.ContextDecorator):
  147. # YOLOv5 Timeout class. Usage: @Timeout(seconds) decorator or 'with Timeout(seconds):' context manager
  148. def __init__(self, seconds, *, timeout_msg='', suppress_timeout_errors=True):
  149. self.seconds = int(seconds)
  150. self.timeout_message = timeout_msg
  151. self.suppress = bool(suppress_timeout_errors)
  152. def _timeout_handler(self, signum, frame):
  153. raise TimeoutError(self.timeout_message)
  154. def __enter__(self):
  155. if platform.system() != 'Windows': # not supported on Windows
  156. signal.signal(signal.SIGALRM, self._timeout_handler) # Set handler for SIGALRM
  157. signal.alarm(self.seconds) # start countdown for SIGALRM to be raised
  158. def __exit__(self, exc_type, exc_val, exc_tb):
  159. if platform.system() != 'Windows':
  160. signal.alarm(0) # Cancel SIGALRM if it's scheduled
  161. if self.suppress and exc_type is TimeoutError: # Suppress TimeoutError
  162. return True
  163. class WorkingDirectory(contextlib.ContextDecorator):
  164. # Usage: @WorkingDirectory(dir) decorator or 'with WorkingDirectory(dir):' context manager
  165. def __init__(self, new_dir):
  166. self.dir = new_dir # new dir
  167. self.cwd = Path.cwd().resolve() # current dir
  168. def __enter__(self):
  169. os.chdir(self.dir)
  170. def __exit__(self, exc_type, exc_val, exc_tb):
  171. os.chdir(self.cwd)
  172. def methods(instance):
  173. # Get class/instance methods
  174. return [f for f in dir(instance) if callable(getattr(instance, f)) and not f.startswith("__")]
  175. def print_args(args: Optional[dict] = None, show_file=True, show_func=False):
  176. # Print function arguments (optional args dict)
  177. x = inspect.currentframe().f_back # previous frame
  178. file, _, func, _, _ = inspect.getframeinfo(x)
  179. if args is None: # get args automatically
  180. args, _, _, frm = inspect.getargvalues(x)
  181. args = {k: v for k, v in frm.items() if k in args}
  182. try:
  183. file = Path(file).resolve().relative_to(ROOT).with_suffix('')
  184. except ValueError:
  185. file = Path(file).stem
  186. s = (f'{file}: ' if show_file else '') + (f'{func}: ' if show_func else '')
  187. LOGGER.info(colorstr(s) + ', '.join(f'{k}={v}' for k, v in args.items()))
  188. def init_seeds(seed=0, deterministic=False):
  189. # Initialize random number generator (RNG) seeds https://pytorch.org/docs/stable/notes/randomness.html
  190. random.seed(seed)
  191. np.random.seed(seed)
  192. torch.manual_seed(seed)
  193. torch.cuda.manual_seed(seed)
  194. torch.cuda.manual_seed_all(seed) # for Multi-GPU, exception safe
  195. # torch.backends.cudnn.benchmark = True # AutoBatch problem https://github.com/ultralytics/yolov5/issues/9287
  196. if deterministic and check_version(torch.__version__, '1.12.0'): # https://github.com/ultralytics/yolov5/pull/8213
  197. torch.use_deterministic_algorithms(True)
  198. torch.backends.cudnn.deterministic = True
  199. os.environ['CUBLAS_WORKSPACE_CONFIG'] = ':4096:8'
  200. os.environ['PYTHONHASHSEED'] = str(seed)
  201. def intersect_dicts(da, db, exclude=()):
  202. # Dictionary intersection of matching keys and shapes, omitting 'exclude' keys, using da values
  203. return {k: v for k, v in da.items() if k in db and all(x not in k for x in exclude) and v.shape == db[k].shape}
  204. def get_default_args(func):
  205. # Get func() default arguments
  206. signature = inspect.signature(func)
  207. return {k: v.default for k, v in signature.parameters.items() if v.default is not inspect.Parameter.empty}
  208. def get_latest_run(search_dir='.'):
  209. # Return path to most recent 'last.pt' in /runs (i.e. to --resume from)
  210. last_list = glob.glob(f'{search_dir}/**/last*.pt', recursive=True)
  211. return max(last_list, key=os.path.getctime) if last_list else ''
  212. def file_age(path=__file__):
  213. # Return days since last file update
  214. dt = (datetime.now() - datetime.fromtimestamp(Path(path).stat().st_mtime)) # delta
  215. return dt.days # + dt.seconds / 86400 # fractional days
  216. def file_date(path=__file__):
  217. # Return human-readable file modification date, i.e. '2021-3-26'
  218. t = datetime.fromtimestamp(Path(path).stat().st_mtime)
  219. return f'{t.year}-{t.month}-{t.day}'
  220. def file_size(path):
  221. # Return file/dir size (MB)
  222. mb = 1 << 20 # bytes to MiB (1024 ** 2)
  223. path = Path(path)
  224. if path.is_file():
  225. return path.stat().st_size / mb
  226. elif path.is_dir():
  227. return sum(f.stat().st_size for f in path.glob('**/*') if f.is_file()) / mb
  228. else:
  229. return 0.0
  230. def check_online():
  231. # Check internet connectivity
  232. import socket
  233. def run_once():
  234. # Check once
  235. try:
  236. socket.create_connection(("1.1.1.1", 443), 5) # check host accessibility
  237. return True
  238. except OSError:
  239. return False
  240. return run_once() or run_once() # check twice to increase robustness to intermittent connectivity issues
  241. def git_describe(path=ROOT): # path must be a directory
  242. # Return human-readable git description, i.e. v5.0-5-g3e25f1e https://git-scm.com/docs/git-describe
  243. try:
  244. assert (Path(path) / '.git').is_dir()
  245. return check_output(f'git -C {path} describe --tags --long --always', shell=True).decode()[:-1]
  246. except Exception:
  247. return ''
  248. @TryExcept()
  249. @WorkingDirectory(ROOT)
  250. def check_git_status(repo='ultralytics/yolov5', branch='master'):
  251. # YOLOv5 status check, recommend 'git pull' if code is out of date
  252. url = f'https://github.com/{repo}'
  253. msg = f', for updates see {url}'
  254. s = colorstr('github: ') # string
  255. assert Path('.git').exists(), s + 'skipping check (not a git repository)' + msg
  256. assert check_online(), s + 'skipping check (offline)' + msg
  257. splits = re.split(pattern=r'\s', string=check_output('git remote -v', shell=True).decode())
  258. matches = [repo in s for s in splits]
  259. if any(matches):
  260. remote = splits[matches.index(True) - 1]
  261. else:
  262. remote = 'ultralytics'
  263. check_output(f'git remote add {remote} {url}', shell=True)
  264. check_output(f'git fetch {remote}', shell=True, timeout=5) # git fetch
  265. local_branch = check_output('git rev-parse --abbrev-ref HEAD', shell=True).decode().strip() # checked out
  266. n = int(check_output(f'git rev-list {local_branch}..{remote}/{branch} --count', shell=True)) # commits behind
  267. if n > 0:
  268. pull = 'git pull' if remote == 'origin' else f'git pull {remote} {branch}'
  269. s += f"⚠️ YOLOv5 is out of date by {n} commit{'s' * (n > 1)}. Use `{pull}` or `git clone {url}` to update."
  270. else:
  271. s += f'up to date with {url} ✅'
  272. LOGGER.info(s)
  273. @WorkingDirectory(ROOT)
  274. def check_git_info(path='.'):
  275. # YOLOv5 git info check, return {remote, branch, commit}
  276. check_requirements('gitpython')
  277. import git
  278. try:
  279. repo = git.Repo(path)
  280. remote = repo.remotes.origin.url.replace('.git', '') # i.e. 'https://github.com/ultralytics/yolov5'
  281. commit = repo.head.commit.hexsha # i.e. '3134699c73af83aac2a481435550b968d5792c0d'
  282. try:
  283. branch = repo.active_branch.name # i.e. 'main'
  284. except TypeError: # not on any branch
  285. branch = None # i.e. 'detached HEAD' state
  286. return {'remote': remote, 'branch': branch, 'commit': commit}
  287. except git.exc.InvalidGitRepositoryError: # path is not a git dir
  288. return {'remote': None, 'branch': None, 'commit': None}
  289. def check_python(minimum='3.7.0'):
  290. # Check current python version vs. required python version
  291. check_version(platform.python_version(), minimum, name='Python ', hard=True)
  292. def check_version(current='0.0.0', minimum='0.0.0', name='version ', pinned=False, hard=False, verbose=False):
  293. # Check version vs. required version
  294. current, minimum = (pkg.parse_version(x) for x in (current, minimum))
  295. result = (current == minimum) if pinned else (current >= minimum) # bool
  296. s = f'WARNING ⚠️ {name}{minimum} is required by YOLOv5, but {name}{current} is currently installed' # string
  297. if hard:
  298. assert result, emojis(s) # assert min requirements met
  299. if verbose and not result:
  300. LOGGER.warning(s)
  301. return result
  302. @TryExcept()
  303. def check_requirements(requirements=ROOT / 'requirements.txt', exclude=(), install=True, cmds=''):
  304. # Check installed dependencies meet YOLOv5 requirements (pass *.txt file or list of packages or single package str)
  305. prefix = colorstr('red', 'bold', 'requirements:')
  306. check_python() # check python version
  307. if isinstance(requirements, Path): # requirements.txt file
  308. file = requirements.resolve()
  309. assert file.exists(), f"{prefix} {file} not found, check failed."
  310. with file.open() as f:
  311. requirements = [f'{x.name}{x.specifier}' for x in pkg.parse_requirements(f) if x.name not in exclude]
  312. elif isinstance(requirements, str):
  313. requirements = [requirements]
  314. s = ''
  315. n = 0
  316. for r in requirements:
  317. try:
  318. pkg.require(r)
  319. except (pkg.VersionConflict, pkg.DistributionNotFound): # exception if requirements not met
  320. s += f'"{r}" '
  321. n += 1
  322. if s and install and AUTOINSTALL: # check environment variable
  323. LOGGER.info(f"{prefix} YOLOv5 requirement{'s' * (n > 1)} {s}not found, attempting AutoUpdate...")
  324. try:
  325. # assert check_online(), "AutoUpdate skipped (offline)"
  326. LOGGER.info(check_output(f'pip install {s} {cmds}', shell=True).decode())
  327. source = file if 'file' in locals() else requirements
  328. s = f"{prefix} {n} package{'s' * (n > 1)} updated per {source}\n" \
  329. f"{prefix} ⚠️ {colorstr('bold', 'Restart runtime or rerun command for updates to take effect')}\n"
  330. LOGGER.info(s)
  331. except Exception as e:
  332. LOGGER.warning(f'{prefix} ❌ {e}')
  333. def check_img_size(imgsz, s=32, floor=0):
  334. # Verify image size is a multiple of stride s in each dimension
  335. if isinstance(imgsz, int): # integer i.e. img_size=640
  336. new_size = max(make_divisible(imgsz, int(s)), floor)
  337. else: # list i.e. img_size=[640, 480]
  338. imgsz = list(imgsz) # convert to list if tuple
  339. new_size = [max(make_divisible(x, int(s)), floor) for x in imgsz]
  340. if new_size != imgsz:
  341. LOGGER.warning(f'WARNING ⚠️ --img-size {imgsz} must be multiple of max stride {s}, updating to {new_size}')
  342. return new_size
  343. def check_imshow(warn=False):
  344. # Check if environment supports image displays
  345. try:
  346. assert not is_notebook()
  347. assert not is_docker()
  348. cv2.imshow('test', np.zeros((1, 1, 3)))
  349. cv2.waitKey(1)
  350. cv2.destroyAllWindows()
  351. cv2.waitKey(1)
  352. return True
  353. except Exception as e:
  354. if warn:
  355. LOGGER.warning(f'WARNING ⚠️ Environment does not support cv2.imshow() or PIL Image.show()\n{e}')
  356. return False
  357. def check_suffix(file='yolov5s.pt', suffix=('.pt',), msg=''):
  358. # Check file(s) for acceptable suffix
  359. if file and suffix:
  360. if isinstance(suffix, str):
  361. suffix = [suffix]
  362. for f in file if isinstance(file, (list, tuple)) else [file]:
  363. s = Path(f).suffix.lower() # file suffix
  364. if len(s):
  365. assert s in suffix, f"{msg}{f} acceptable suffix is {suffix}"
  366. def check_yaml(file, suffix=('.yaml', '.yml')):
  367. # Search/download YAML file (if necessary) and return path, checking suffix
  368. return check_file(file, suffix)
  369. def check_file(file, suffix=''):
  370. # Search/download file (if necessary) and return path
  371. check_suffix(file, suffix) # optional
  372. file = str(file) # convert to str()
  373. if os.path.isfile(file) or not file: # exists
  374. return file
  375. elif file.startswith(('http:/', 'https:/')): # download
  376. url = file # warning: Pathlib turns :// -> :/
  377. file = Path(urllib.parse.unquote(file).split('?')[0]).name # '%2F' to '/', split https://url.com/file.txt?auth
  378. if os.path.isfile(file):
  379. LOGGER.info(f'Found {url} locally at {file}') # file already exists
  380. else:
  381. LOGGER.info(f'Downloading {url} to {file}...')
  382. torch.hub.download_url_to_file(url, file)
  383. assert Path(file).exists() and Path(file).stat().st_size > 0, f'File download failed: {url}' # check
  384. return file
  385. elif file.startswith('clearml://'): # ClearML Dataset ID
  386. assert 'clearml' in sys.modules, "ClearML is not installed, so cannot use ClearML dataset. Try running 'pip install clearml'."
  387. return file
  388. else: # search
  389. files = []
  390. for d in 'data', 'models', 'utils': # search directories
  391. files.extend(glob.glob(str(ROOT / d / '**' / file), recursive=True)) # find file
  392. assert len(files), f'File not found: {file}' # assert file was found
  393. assert len(files) == 1, f"Multiple files match '{file}', specify exact path: {files}" # assert unique
  394. return files[0] # return file
  395. def check_font(font=FONT, progress=False):
  396. # Download font to CONFIG_DIR if necessary
  397. font = Path(font)
  398. file = CONFIG_DIR / font.name
  399. if not font.exists() and not file.exists():
  400. url = f'https://ultralytics.com/assets/{font.name}'
  401. LOGGER.info(f'Downloading {url} to {file}...')
  402. torch.hub.download_url_to_file(url, str(file), progress=progress)
  403. def check_dataset(data, autodownload=True):
  404. # Download, check and/or unzip dataset if not found locally. `data` may be a zip/tar archive path, a dataset YAML path, or (presumably) an already-loaded mapping; returns the dataset dictionary with 'nc' added and train/val/test entries rewritten as resolved path strings
  405. # Download (optional): when `data` is an archive, unpack it under DATASETS_DIR and continue with a *.yaml found inside
  406. extract_dir = ''
  407. if isinstance(data, (str, Path)) and (is_zipfile(data) or is_tarfile(data)):
  408. download(data, dir=f'{DATASETS_DIR}/{Path(data).stem}', unzip=True, delete=False, curl=False, threads=1)
  409. data = next((DATASETS_DIR / Path(data).stem).rglob('*.yaml')) # first *.yaml found; next() has no default, so StopIteration is raised when there is none
  410. extract_dir, autodownload = data.parent, False # paths become relative to the extracted yaml's folder; autodownload is switched off
  411. # Read yaml (optional)
  412. if isinstance(data, (str, Path)):
  413. data = yaml_load(data) # dictionary
  414. # Checks: required keys, class-name format, class count
  415. for k in 'train', 'val', 'names':
  416. assert k in data, emojis(f"data.yaml '{k}:' field missing ❌")
  417. if isinstance(data['names'], (list, tuple)): # old array format
  418. data['names'] = dict(enumerate(data['names'])) # convert to dict
  419. assert all(isinstance(k, int) for k in data['names'].keys()), 'data.yaml names keys must be integers, i.e. 2: car'
  420. data['nc'] = len(data['names']) # number of classes is derived from names
  421. # Resolve paths
  422. path = Path(extract_dir or data.get('path') or '') # optional 'path' default to '.'
  423. if not path.is_absolute():
  424. path = (ROOT / path).resolve() # relative dataset roots are interpreted relative to ROOT
  425. data['path'] = path # download scripts
  426. for k in 'train', 'val', 'test':
  427. if data.get(k): # prepend path
  428. if isinstance(data[k], str):
  429. x = (path / data[k]).resolve()
  430. if not x.exists() and data[k].startswith('../'): # retry with the leading '../' removed
  431. x = (path / data[k][3:]).resolve()
  432. data[k] = str(x)
  433. else: # non-str entry is treated as an iterable of paths
  434. data[k] = [str((path / x).resolve()) for x in data[k]]
  435. # Parse yaml: the download step below runs only when a 'val' path is missing on disk
  436. train, val, test, s = (data.get(x) for x in ('train', 'val', 'test', 'download'))
  437. if val:
  438. val = [Path(x).resolve() for x in (val if isinstance(val, list) else [val])] # val path
  439. if not all(x.exists() for x in val):
  440. LOGGER.info('\nDataset not found ⚠️, missing paths %s' % [str(x) for x in val if not x.exists()])
  441. if not s or not autodownload: # no 'download' entry, or downloading disabled
  442. raise Exception('Dataset not found ❌')
  443. t = time.time()
  444. if s.startswith('http') and s.endswith('.zip'): # URL
  445. f = Path(s).name # filename
  446. LOGGER.info(f'Downloading {s} to {f}...')
  447. torch.hub.download_url_to_file(s, f)
  448. Path(DATASETS_DIR).mkdir(parents=True, exist_ok=True) # create root
  449. unzip_file(f, path=DATASETS_DIR) # unzip
  450. Path(f).unlink() # remove zip
  451. r = None # success
  452. elif s.startswith('bash '): # bash script
  453. LOGGER.info(f'Running {s} ...')
  454. r = os.system(s) # exit status of the shell command
  455. else: # python script
  456. r = exec(s, {'yaml': data}) # return None. NOTE(review): executes the YAML 'download' text as Python code (and the bash branch runs it in a shell) - only use dataset YAMLs from trusted sources
  457. dt = f'({round(time.time() - t, 1)}s)'
  458. s = f"success ✅ {dt}, saved to {colorstr('bold', DATASETS_DIR)}" if r in (0, None) else f"failure {dt} ❌"
  459. LOGGER.info(f"Dataset download {s}")
  460. check_font('Arial.ttf' if is_ascii(data['names']) else 'Arial.Unicode.ttf', progress=True) # download fonts
  461. return data # dictionary
  462. def check_amp(model):
  463. # Check PyTorch Automatic Mixed Precision (AMP) functionality. Return True on correct operation
  464. from models.common import AutoShape, DetectMultiBackend
  465. def amp_allclose(model, im):
  466. # All close FP32 vs AMP results
  467. m = AutoShape(model, verbose=False) # model
  468. a = m(im).xywhn[0] # FP32 inference
  469. m.amp = True
  470. b = m(im).xywhn[0] # AMP inference
  471. return a.shape == b.shape and torch.allclose(a, b, atol=0.1) # close to 10% absolute tolerance
  472. prefix = colorstr('AMP: ')
  473. device = next(model.parameters()).device # get model device
  474. if device.type in ('cpu', 'mps'):
  475. return False # AMP only used on CUDA devices
  476. f = ROOT / 'data' / 'images' / 'bus.jpg' # image to check
  477. im = f if f.exists() else 'https://ultralytics.com/images/bus.jpg' if check_online() else np.ones((640, 640, 3))
  478. try:
  479. assert amp_allclose(deepcopy(model), im) or amp_allclose(DetectMultiBackend('yolov5n.pt', device), im)
  480. LOGGER.info(f'{prefix}checks passed ✅')
  481. return True
  482. except Exception:
  483. help_url = 'https://github.com/ultralytics/yolov5/issues/7908'
  484. LOGGER.warning(f'{prefix}checks failed ❌, disabling Automatic Mixed Precision. See {help_url}')
  485. return False
  486. def yaml_load(file='data.yaml'):
  487. # Single-line safe yaml loading
  488. with open(file, errors='ignore') as f:
  489. return yaml.safe_load(f)
  490. def yaml_save(file='data.yaml', data={}):
  491. # Single-line safe yaml saving
  492. with open(file, 'w') as f:
  493. yaml.safe_dump({k: str(v) if isinstance(v, Path) else v for k, v in data.items()}, f, sort_keys=False)
  494. def unzip_file(file, path=None, exclude=('.DS_Store', '__MACOSX')):
  495. # Unzip a *.zip file to path/, excluding files containing strings in exclude list
  496. if path is None:
  497. path = Path(file).parent # default path
  498. with ZipFile(file) as zipObj:
  499. for f in zipObj.namelist(): # list all archived filenames in the zip
  500. if all(x not in f for x in exclude):
  501. zipObj.extract(f, path=path)
  502. def url2file(url):
  503. # Convert URL to filename, i.e. https://url.com/file.txt?auth -> file.txt
  504. url = str(Path(url)).replace(':/', '://') # Pathlib turns :// -> :/
  505. return Path(urllib.parse.unquote(url)).name.split('?')[0] # '%2F' to '/', split https://url.com/file.txt?auth
  506. def download(url, dir='.', unzip=True, delete=True, curl=False, threads=1, retry=3):
  507. # Multithreaded file download and unzip function, used in data.yaml for autodownload
  508. def download_one(url, dir):
  509. # Download 1 file
  510. success = True
  511. if os.path.isfile(url):
  512. f = Path(url) # filename
  513. else: # does not exist
  514. f = dir / Path(url).name
  515. LOGGER.info(f'Downloading {url} to {f}...')
  516. for i in range(retry + 1):
  517. if curl:
  518. s = 'sS' if threads > 1 else '' # silent
  519. r = os.system(
  520. f'curl -# -{s}L "{url}" -o "{f}" --retry 9 -C -') # curl download with retry, continue
  521. success = r == 0
  522. else:
  523. torch.hub.download_url_to_file(url, f, progress=threads == 1) # torch download
  524. success = f.is_file()
  525. if success:
  526. break
  527. elif i < retry:
  528. LOGGER.warning(f'⚠️ Download failure, retrying {i + 1}/{retry} {url}...')
  529. else:
  530. LOGGER.warning(f'❌ Failed to download {url}...')
  531. if unzip and success and (f.suffix == '.gz' or is_zipfile(f) or is_tarfile(f)):
  532. LOGGER.info(f'Unzipping {f}...')
  533. if is_zipfile(f):
  534. unzip_file(f, dir) # unzip
  535. elif is_tarfile(f):
  536. os.system(f'tar xf {f} --directory {f.parent}') # unzip
  537. elif f.suffix == '.gz':
  538. os.system(f'tar xfz {f} --directory {f.parent}') # unzip
  539. if delete:
  540. f.unlink() # remove zip
  541. dir = Path(dir)
  542. dir.mkdir(parents=True, exist_ok=True) # make directory
  543. if threads > 1:
  544. pool = ThreadPool(threads)
  545. pool.imap(lambda x: download_one(*x), zip(url, repeat(dir))) # multithreaded
  546. pool.close()
  547. pool.join()
  548. else:
  549. for u in [url] if isinstance(url, (str, Path)) else url:
  550. download_one(u, dir)
  551. def make_divisible(x, divisor):
  552. # Returns nearest x divisible by divisor
  553. if isinstance(divisor, torch.Tensor):
  554. divisor = int(divisor.max()) # to int
  555. return math.ceil(x / divisor) * divisor
  556. def clean_str(s):
  557. # Cleans a string by replacing special characters with underscore _
  558. return re.sub(pattern="[|@#!¡·$€%&()=?¿^*;:,¨´><+]", repl="_", string=s)
  559. def one_cycle(y1=0.0, y2=1.0, steps=100):
  560. # lambda function for sinusoidal ramp from y1 to y2 https://arxiv.org/pdf/1812.01187.pdf
  561. return lambda x: ((1 - math.cos(x * math.pi / steps)) / 2) * (y2 - y1) + y1
  562. def colorstr(*input):
  563. # Colors a string https://en.wikipedia.org/wiki/ANSI_escape_code, i.e. colorstr('blue', 'hello world')
  564. *args, string = input if len(input) > 1 else ('blue', 'bold', input[0]) # color arguments, string
  565. colors = {
  566. 'black': '\033[30m', # basic colors
  567. 'red': '\033[31m',
  568. 'green': '\033[32m',
  569. 'yellow': '\033[33m',
  570. 'blue': '\033[34m',
  571. 'magenta': '\033[35m',
  572. 'cyan': '\033[36m',
  573. 'white': '\033[37m',
  574. 'bright_black': '\033[90m', # bright colors
  575. 'bright_red': '\033[91m',
  576. 'bright_green': '\033[92m',
  577. 'bright_yellow': '\033[93m',
  578. 'bright_blue': '\033[94m',
  579. 'bright_magenta': '\033[95m',
  580. 'bright_cyan': '\033[96m',
  581. 'bright_white': '\033[97m',
  582. 'end': '\033[0m', # misc
  583. 'bold': '\033[1m',
  584. 'underline': '\033[4m'}
  585. return ''.join(colors[x] for x in args) + f'{string}' + colors['end']
  586. def labels_to_class_weights(labels, nc=80):
  587. # Get class weights (inverse frequency) from training labels
  588. if labels[0] is None: # no labels loaded
  589. return torch.Tensor()
  590. labels = np.concatenate(labels, 0) # labels.shape = (866643, 5) for COCO
  591. classes = labels[:, 0].astype(int) # labels = [class xywh]
  592. weights = np.bincount(classes, minlength=nc) # occurrences per class
  593. # Prepend gridpoint count (for uCE training)
  594. # gpi = ((320 / 32 * np.array([1, 2, 4])) ** 2 * 3).sum() # gridpoints per image
  595. # weights = np.hstack([gpi * len(labels) - weights.sum() * 9, weights * 9]) ** 0.5 # prepend gridpoints to start
  596. weights[weights == 0] = 1 # replace empty bins with 1
  597. weights = 1 / weights # number of targets per class
  598. weights /= weights.sum() # normalize
  599. return torch.from_numpy(weights).float()
  600. def labels_to_image_weights(labels, nc=80, class_weights=np.ones(80)):
  601. # Produces image weights based on class_weights and image contents
  602. # Usage: index = random.choices(range(n), weights=image_weights, k=1) # weighted image sample
  603. class_counts = np.array([np.bincount(x[:, 0].astype(int), minlength=nc) for x in labels])
  604. return (class_weights.reshape(1, nc) * class_counts).sum(1)
  605. def coco80_to_coco91_class(): # converts 80-index (val2014) to 91-index (paper)
  606. # https://tech.amikelive.com/node-718/what-object-categories-labels-are-in-coco-dataset/
  607. # a = np.loadtxt('data/coco.names', dtype='str', delimiter='\n')
  608. # b = np.loadtxt('data/coco_paper.names', dtype='str', delimiter='\n')
  609. # x1 = [list(a[i] == b).index(True) + 1 for i in range(80)] # darknet to coco
  610. # x2 = [list(b[i] == a).index(True) if any(b[i] == a) else None for i in range(91)] # coco to darknet
  611. return [
  612. 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 27, 28, 31, 32, 33, 34,
  613. 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63,
  614. 64, 65, 67, 70, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 84, 85, 86, 87, 88, 89, 90]
  615. def xyxy2xywh(x):
  616. # Convert nx4 boxes from [x1, y1, x2, y2] to [x, y, w, h] where xy1=top-left, xy2=bottom-right
  617. y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)
  618. y[:, 0] = (x[:, 0] + x[:, 2]) / 2 # x center
  619. y[:, 1] = (x[:, 1] + x[:, 3]) / 2 # y center
  620. y[:, 2] = x[:, 2] - x[:, 0] # width
  621. y[:, 3] = x[:, 3] - x[:, 1] # height
  622. return y
  623. def xywh2xyxy(x):
  624. # Convert nx4 boxes from [x, y, w, h] to [x1, y1, x2, y2] where xy1=top-left, xy2=bottom-right
  625. y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)
  626. y[:, 0] = x[:, 0] - x[:, 2] / 2 # top left x
  627. y[:, 1] = x[:, 1] - x[:, 3] / 2 # top left y
  628. y[:, 2] = x[:, 0] + x[:, 2] / 2 # bottom right x
  629. y[:, 3] = x[:, 1] + x[:, 3] / 2 # bottom right y
  630. return y
  631. def xywhn2xyxy(x, w=640, h=640, padw=0, padh=0):
  632. # Convert nx4 boxes from [x, y, w, h] normalized to [x1, y1, x2, y2] where xy1=top-left, xy2=bottom-right
  633. y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)
  634. y[:, 0] = w * (x[:, 0] - x[:, 2] / 2) + padw # top left x
  635. y[:, 1] = h * (x[:, 1] - x[:, 3] / 2) + padh # top left y
  636. y[:, 2] = w * (x[:, 0] + x[:, 2] / 2) + padw # bottom right x
  637. y[:, 3] = h * (x[:, 1] + x[:, 3] / 2) + padh # bottom right y
  638. return y
  639. def xyxy2xywhn(x, w=640, h=640, clip=False, eps=0.0):
  640. # Convert nx4 boxes from [x1, y1, x2, y2] to [x, y, w, h] normalized where xy1=top-left, xy2=bottom-right
  641. if clip:
  642. clip_boxes(x, (h - eps, w - eps)) # warning: inplace clip
  643. y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)
  644. y[:, 0] = ((x[:, 0] + x[:, 2]) / 2) / w # x center
  645. y[:, 1] = ((x[:, 1] + x[:, 3]) / 2) / h # y center
  646. y[:, 2] = (x[:, 2] - x[:, 0]) / w # width
  647. y[:, 3] = (x[:, 3] - x[:, 1]) / h # height
  648. return y
  649. def xyn2xy(x, w=640, h=640, padw=0, padh=0):
  650. # Convert normalized segments into pixel segments, shape (n,2)
  651. y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)
  652. y[:, 0] = w * x[:, 0] + padw # top left x
  653. y[:, 1] = h * x[:, 1] + padh # top left y
  654. return y
  655. def segment2box(segment, width=640, height=640):
  656. # Convert 1 segment label to 1 box label, applying inside-image constraint, i.e. (xy1, xy2, ...) to (xyxy)
  657. x, y = segment.T # segment xy
  658. inside = (x >= 0) & (y >= 0) & (x <= width) & (y <= height)
  659. x, y, = x[inside], y[inside]
  660. return np.array([x.min(), y.min(), x.max(), y.max()]) if any(x) else np.zeros((1, 4)) # xyxy
  661. def segments2boxes(segments):
  662. # Convert segment labels to box labels, i.e. (cls, xy1, xy2, ...) to (cls, xywh)
  663. boxes = []
  664. for s in segments:
  665. x, y = s.T # segment xy
  666. boxes.append([x.min(), y.min(), x.max(), y.max()]) # cls, xyxy
  667. return xyxy2xywh(np.array(boxes)) # cls, xywh
  668. def resample_segments(segments, n=1000):
  669. # Up-sample an (n,2) segment
  670. for i, s in enumerate(segments):
  671. s = np.concatenate((s, s[0:1, :]), axis=0)
  672. x = np.linspace(0, len(s) - 1, n)
  673. xp = np.arange(len(s))
  674. segments[i] = np.concatenate([np.interp(x, xp, s[:, i]) for i in range(2)]).reshape(2, -1).T # segment xy
  675. return segments
  676. def scale_boxes(img1_shape, boxes, img0_shape, ratio_pad=None):
  677. # Rescale boxes (xyxy) from img1_shape to img0_shape
  678. if ratio_pad is None: # calculate from img0_shape
  679. gain = min(img1_shape[0] / img0_shape[0], img1_shape[1] / img0_shape[1]) # gain = old / new
  680. pad = (img1_shape[1] - img0_shape[1] * gain) / 2, (img1_shape[0] - img0_shape[0] * gain) / 2 # wh padding
  681. else:
  682. gain = ratio_pad[0][0]
  683. pad = ratio_pad[1]
  684. boxes[:, [0, 2]] -= pad[0] # x padding
  685. boxes[:, [1, 3]] -= pad[1] # y padding
  686. boxes[:, :4] /= gain
  687. clip_boxes(boxes, img0_shape)
  688. return boxes
  689. def scale_segments(img1_shape, segments, img0_shape, ratio_pad=None, normalize=False):
  690. # Rescale coords (xyxy) from img1_shape to img0_shape
  691. if ratio_pad is None: # calculate from img0_shape
  692. gain = min(img1_shape[0] / img0_shape[0], img1_shape[1] / img0_shape[1]) # gain = old / new
  693. pad = (img1_shape[1] - img0_shape[1] * gain) / 2, (img1_shape[0] - img0_shape[0] * gain) / 2 # wh padding
  694. else:
  695. gain = ratio_pad[0][0]
  696. pad = ratio_pad[1]
  697. segments[:, 0] -= pad[0] # x padding
  698. segments[:, 1] -= pad[1] # y padding
  699. segments /= gain
  700. clip_segments(segments, img0_shape)
  701. if normalize:
  702. segments[:, 0] /= img0_shape[1] # width
  703. segments[:, 1] /= img0_shape[0] # height
  704. return segments
  705. def clip_boxes(boxes, shape):
  706. # Clip boxes (xyxy) to image shape (height, width)
  707. if isinstance(boxes, torch.Tensor): # faster individually
  708. boxes[:, 0].clamp_(0, shape[1]) # x1
  709. boxes[:, 1].clamp_(0, shape[0]) # y1
  710. boxes[:, 2].clamp_(0, shape[1]) # x2
  711. boxes[:, 3].clamp_(0, shape[0]) # y2
  712. else: # np.array (faster grouped)
  713. boxes[:, [0, 2]] = boxes[:, [0, 2]].clip(0, shape[1]) # x1, x2
  714. boxes[:, [1, 3]] = boxes[:, [1, 3]].clip(0, shape[0]) # y1, y2
  715. def clip_segments(segments, shape):
  716. # Clip segments (xy1,xy2,...) to image shape (height, width)
  717. if isinstance(segments, torch.Tensor): # faster individually
  718. segments[:, 0].clamp_(0, shape[1]) # x
  719. segments[:, 1].clamp_(0, shape[0]) # y
  720. else: # np.array (faster grouped)
  721. segments[:, 0] = segments[:, 0].clip(0, shape[1]) # x
  722. segments[:, 1] = segments[:, 1].clip(0, shape[0]) # y
  723. def non_max_suppression(
  724. prediction,
  725. conf_thres=0.25,
  726. iou_thres=0.45,
  727. classes=None,
  728. agnostic=False,
  729. multi_label=False,
  730. labels=(),
  731. max_det=300,
  732. nm=0, # number of masks
  733. ):
  734. """Non-Maximum Suppression (NMS) on inference results to reject overlapping detections
  735. Returns:
  736. list of detections, on (n,6) tensor per image [xyxy, conf, cls,mask]
  737. """
  738. #print("prediction.shape: ", prediction.shape) #[1, 16128, 44]
  739. if isinstance(prediction, (list, tuple)): # YOLOv5 model in validation model, output = (inference_out, loss_out)
  740. prediction = prediction[0] # select only inference output
  741. device = prediction.device
  742. mps = 'mps' in device.type # Apple MPS
  743. if mps: # MPS not fully supported yet, convert tensors to CPU before NMS
  744. prediction = prediction.cpu()
  745. bs = prediction.shape[0] # batch size
  746. nc = prediction.shape[2] - nm - 5 # number of classes
  747. xc = prediction[..., 4] > conf_thres # candidates
  748. # Checks
  749. assert 0 <= conf_thres <= 1, f'Invalid Confidence threshold {conf_thres}, valid values are between 0.0 and 1.0'
  750. assert 0 <= iou_thres <= 1, f'Invalid IoU {iou_thres}, valid values are between 0.0 and 1.0'
  751. # Settings
  752. # min_wh = 2 # (pixels) minimum box width and height
  753. max_wh = 7680 # (pixels) maximum box width and height
  754. max_nms = 30000 # maximum number of boxes into torchvision.ops.nms()
  755. time_limit = 0.5 + 0.05 * bs # seconds to quit after
  756. redundant = True # require redundant detections
  757. multi_label &= nc > 1 # multiple labels per box (adds 0.5ms/img)
  758. merge = False # use merge-NMS
  759. t = time.time()
  760. mi = 5 + nc # mask start index
  761. output = [torch.zeros((0, 6 + nm), device=prediction.device)] * bs
  762. for xi, x in enumerate(prediction): # image index, image inference
  763. # Apply constraints
  764. # x[((x[..., 2:4] < min_wh) | (x[..., 2:4] > max_wh)).any(1), 4] = 0 # width-height
  765. x = x[xc[xi]] # confidence
  766. # Cat apriori labels if autolabelling
  767. if labels and len(labels[xi]):
  768. lb = labels[xi]
  769. v = torch.zeros((len(lb), nc + nm + 5), device=x.device)
  770. v[:, :4] = lb[:, 1:5] # box
  771. v[:, 4] = 1.0 # conf
  772. v[range(len(lb)), lb[:, 0].long() + 5] = 1.0 # cls
  773. x = torch.cat((x, v), 0)
  774. # If none remain process next image
  775. if not x.shape[0]:
  776. continue
  777. # Compute conf
  778. x[:, 5:] *= x[:, 4:5] # conf = obj_conf * cls_conf
  779. # Box/Mask
  780. box = xywh2xyxy(x[:, :4]) # center_x, center_y, width, height) to (x1, y1, x2, y2)
  781. mask = x[:, mi:] # zero columns if no masks
  782. # Detections matrix nx6 (xyxy, conf, cls)
  783. if multi_label:
  784. i, j = (x[:, 5:mi] > conf_thres).nonzero(as_tuple=False).T
  785. x = torch.cat((box[i], x[i, 5 + j, None], j[:, None].float(), mask[i]), 1)
  786. else: # best class only
  787. conf, j = x[:, 5:mi].max(1, keepdim=True)
  788. x = torch.cat((box, conf, j.float(), mask), 1)[conf.view(-1) > conf_thres]
  789. # Filter by class
  790. if classes is not None:
  791. x = x[(x[:, 5:6] == torch.tensor(classes, device=x.device)).any(1)]
  792. # Apply finite constraint
  793. # if not torch.isfinite(x).all():
  794. # x = x[torch.isfinite(x).all(1)]
  795. # Check shape
  796. n = x.shape[0] # number of boxes
  797. if not n: # no boxes
  798. continue
  799. elif n > max_nms: # excess boxes
  800. x = x[x[:, 4].argsort(descending=True)[:max_nms]] # sort by confidence
  801. else:
  802. x = x[x[:, 4].argsort(descending=True)] # sort by confidence
  803. # Batched NMS
  804. c = x[:, 5:6] * (0 if agnostic else max_wh) # classes
  805. boxes, scores = x[:, :4] + c, x[:, 4] # boxes (offset by class), scores
  806. i = torchvision.ops.nms(boxes, scores, iou_thres) # NMS
  807. if i.shape[0] > max_det: # limit detections
  808. i = i[:max_det]
  809. if merge and (1 < n < 3E3): # Merge NMS (boxes merged using weighted mean)
  810. # update boxes as boxes(i,4) = weights(i,n) * boxes(n,4)
  811. iou = box_iou(boxes[i], boxes) > iou_thres # iou matrix
  812. weights = iou * scores[None] # box weights
  813. x[i, :4] = torch.mm(weights, x[:, :4]).float() / weights.sum(1, keepdim=True) # merged boxes
  814. if redundant:
  815. i = i[iou.sum(1) > 1] # require redundancy
  816. output[xi] = x[i]
  817. if mps:
  818. output[xi] = output[xi].to(device)
  819. if (time.time() - t) > time_limit:
  820. LOGGER.warning(f'WARNING ⚠️ NMS time limit {time_limit:.3f}s exceeded')
  821. break # time limit exceeded
  822. return output
  823. def strip_optimizer(f='best.pt', s=''): # from utils.general import *; strip_optimizer()
  824. # Strip optimizer from 'f' to finalize training, optionally save as 's'
  825. x = torch.load(f, map_location=torch.device('cpu'))
  826. if x.get('ema'):
  827. x['model'] = x['ema'] # replace model with ema
  828. for k in 'optimizer', 'best_fitness', 'ema', 'updates': # keys
  829. x[k] = None
  830. x['epoch'] = -1
  831. x['model'].half() # to FP16
  832. for p in x['model'].parameters():
  833. p.requires_grad = False
  834. torch.save(x, s or f)
  835. mb = os.path.getsize(s or f) / 1E6 # filesize
  836. LOGGER.info(f"Optimizer stripped from {f},{f' saved as {s},' if s else ''} {mb:.1f}MB")
  837. def print_mutation(keys, results, hyp, save_dir, bucket, prefix=colorstr('evolve: ')):
  838. evolve_csv = save_dir / 'evolve.csv'
  839. evolve_yaml = save_dir / 'hyp_evolve.yaml'
  840. keys = tuple(keys) + tuple(hyp.keys()) # [results + hyps]
  841. keys = tuple(x.strip() for x in keys)
  842. vals = results + tuple(hyp.values())
  843. n = len(keys)
  844. # Download (optional)
  845. if bucket:
  846. url = f'gs://{bucket}/evolve.csv'
  847. if gsutil_getsize(url) > (evolve_csv.stat().st_size if evolve_csv.exists() else 0):
  848. os.system(f'gsutil cp {url} {save_dir}') # download evolve.csv if larger than local
  849. # Log to evolve.csv
  850. s = '' if evolve_csv.exists() else (('%20s,' * n % keys).rstrip(',') + '\n') # add header
  851. with open(evolve_csv, 'a') as f:
  852. f.write(s + ('%20.5g,' * n % vals).rstrip(',') + '\n')
  853. # Save yaml
  854. with open(evolve_yaml, 'w') as f:
  855. data = pd.read_csv(evolve_csv, skipinitialspace=True)
  856. data = data.rename(columns=lambda x: x.strip()) # strip keys
  857. i = np.argmax(fitness(data.values[:, :4])) #
  858. generations = len(data)
  859. f.write('# YOLOv5 Hyperparameter Evolution Results\n' + f'# Best generation: {i}\n' +
  860. f'# Last generation: {generations - 1}\n' + '# ' + ', '.join(f'{x.strip():>20s}' for x in keys[:7]) +
  861. '\n' + '# ' + ', '.join(f'{x:>20.5g}' for x in data.values[i, :7]) + '\n\n')
  862. yaml.safe_dump(data.loc[i][7:].to_dict(), f, sort_keys=False)
  863. # Print to screen
  864. LOGGER.info(prefix + f'{generations} generations finished, current result:\n' + prefix +
  865. ', '.join(f'{x.strip():>20s}' for x in keys) + '\n' + prefix + ', '.join(f'{x:20.5g}'
  866. for x in vals) + '\n\n')
  867. if bucket:
  868. os.system(f'gsutil cp {evolve_csv} {evolve_yaml} gs://{bucket}') # upload
  869. def apply_classifier(x, model, img, im0):
  870. # Apply a second stage classifier to YOLO outputs
  871. # Example model = torchvision.models.__dict__['efficientnet_b0'](pretrained=True).to(device).eval()
  872. im0 = [im0] if isinstance(im0, np.ndarray) else im0
  873. for i, d in enumerate(x): # per image
  874. if d is not None and len(d):
  875. d = d.clone()
  876. # Reshape and pad cutouts
  877. b = xyxy2xywh(d[:, :4]) # boxes
  878. b[:, 2:] = b[:, 2:].max(1)[0].unsqueeze(1) # rectangle to square
  879. b[:, 2:] = b[:, 2:] * 1.3 + 30 # pad
  880. d[:, :4] = xywh2xyxy(b).long()
  881. # Rescale boxes from img_size to im0 size
  882. scale_boxes(img.shape[2:], d[:, :4], im0[i].shape)
  883. # Classes
  884. pred_cls1 = d[:, 5].long()
  885. ims = []
  886. for a in d:
  887. cutout = im0[i][int(a[1]):int(a[3]), int(a[0]):int(a[2])]
  888. im = cv2.resize(cutout, (224, 224)) # BGR
  889. im = im[:, :, ::-1].transpose(2, 0, 1) # BGR to RGB, to 3x416x416
  890. im = np.ascontiguousarray(im, dtype=np.float32) # uint8 to float32
  891. im /= 255 # 0 - 255 to 0.0 - 1.0
  892. ims.append(im)
  893. pred_cls2 = model(torch.Tensor(ims).to(d.device)).argmax(1) # classifier prediction
  894. x[i] = x[i][pred_cls1 == pred_cls2] # retain matching class detections
  895. return x
  896. def increment_path(path, exist_ok=False, sep='', mkdir=False):
  897. # Increment file or directory path, i.e. runs/exp --> runs/exp{sep}2, runs/exp{sep}3, ... etc.
  898. path = Path(path) # os-agnostic
  899. if path.exists() and not exist_ok:
  900. path, suffix = (path.with_suffix(''), path.suffix) if path.is_file() else (path, '')
  901. # Method 1
  902. for n in range(2, 9999):
  903. p = f'{path}{sep}{n}{suffix}' # increment path
  904. if not os.path.exists(p): #
  905. break
  906. path = Path(p)
  907. # Method 2 (deprecated)
  908. # dirs = glob.glob(f"{path}{sep}*") # similar paths
  909. # matches = [re.search(rf"{path.stem}{sep}(\d+)", d) for d in dirs]
  910. # i = [int(m.groups()[0]) for m in matches if m] # indices
  911. # n = max(i) + 1 if i else 2 # increment number
  912. # path = Path(f"{path}{sep}{n}{suffix}") # increment path
  913. if mkdir:
  914. path.mkdir(parents=True, exist_ok=True) # make directory
  915. return path
# OpenCV Chinese-friendly functions ------------------------------------------------------------------------------------
# Keep a reference to the original cv2.imshow: cv2.imshow is rebound to the imshow() wrapper below,
# and that wrapper must call the original rather than itself.
imshow_ = cv2.imshow  # copy to avoid recursion errors
  918. def imread(path, flags=cv2.IMREAD_COLOR):
  919. return cv2.imdecode(np.fromfile(path, np.uint8), flags)
  920. def imwrite(path, im):
  921. try:
  922. cv2.imencode(Path(path).suffix, im)[1].tofile(path)
  923. return True
  924. except Exception:
  925. return False
  926. def imshow(path, im):
  927. imshow_(path.encode('unicode_escape').decode(), im)
# Rebind the cv2 module attributes to the wrappers defined above, so later calls to
# cv2.imread / cv2.imwrite / cv2.imshow through the module use them.
cv2.imread, cv2.imwrite, cv2.imshow = imread, imwrite, imshow  # redefine
  929. # Variables ------------------------------------------------------------------------------------------------------------