# utils.py

# coding=utf-8
import functools
import logging
import random
import time

import numpy as np
import torch
from torch.nn.utils.rnn import pad_sequence
def sequence_padding(inputs, length=None, value=0, seq_dims=1, mode='post'):
    """Pad a batch of sequences to the same length."""
    if len(inputs) == 0:
        return inputs
    if isinstance(inputs[0], (np.ndarray, list)):
        if length is None:
            length = np.max([np.shape(x)[:seq_dims] for x in inputs], axis=0)
        elif not hasattr(length, '__getitem__'):
            length = [length]

        # Truncate each sample to the target length before padding.
        slices = [np.s_[:length[i]] for i in range(seq_dims)]
        slices = tuple(slices) if len(slices) > 1 else slices[0]
        pad_width = [(0, 0) for _ in np.shape(inputs[0])]

        outputs = []
        for x in inputs:
            x = x[slices]
            for i in range(seq_dims):
                if mode == 'post':
                    pad_width[i] = (0, length[i] - np.shape(x)[i])
                elif mode == 'pre':
                    pad_width[i] = (length[i] - np.shape(x)[i], 0)
                else:
                    raise ValueError('"mode" argument must be "post" or "pre".')
            x = np.pad(x, pad_width, 'constant', constant_values=value)
            outputs.append(x)
        return np.array(outputs)
    elif isinstance(inputs[0], torch.Tensor):
        # pad_sequence always pads at the end, so "mode" is effectively "post" here.
        if length is not None:
            inputs = [i[:length] for i in inputs]
        return pad_sequence(inputs, padding_value=value, batch_first=True)
    else:
        raise ValueError('"inputs" argument must be a list of tensors/lists/ndarrays.')
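
# A minimal usage sketch (illustration only, not part of the original module):
#   batch = [[1, 2, 3], [4, 5]]
#   sequence_padding(batch, value=0)    # -> array([[1, 2, 3], [4, 5, 0]])
#   tensors = [torch.tensor([1, 2, 3]), torch.tensor([4, 5])]
#   sequence_padding(tensors, value=0)  # -> tensor([[1, 2, 3], [4, 5, 0]])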
def timer(func):
    """
    Decorator that reports a function's wall-clock running time.
    :param func:
    :return:
    """
    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        start = time.time()
        res = func(*args, **kwargs)
        end = time.time()
        print("{} took about {:.4f} seconds".format(func.__name__, end - start))
        return res
    return wrapper
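
# A minimal usage sketch (illustration only; the decorated function is hypothetical):
#   @timer
#   def train_one_epoch():
#       time.sleep(0.1)
#   train_one_epoch()  # prints "train_one_epoch took about 0.1xxx seconds"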
def set_seed(seed=123):
    """
    Set the random seeds so that experiments are reproducible.
    :param seed:
    :return:
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
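
# A minimal usage sketch (illustration only, not part of the original module):
#   set_seed(42)
#   # For stricter determinism one could additionally set
#   # torch.backends.cudnn.deterministic = True and
#   # torch.backends.cudnn.benchmark = False, at some speed cost.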
def set_logger(log_path):
    """
    Configure the root logger to write to both a file and the console.
    :param log_path:
    :return:
    """
    logger = logging.getLogger()
    logger.setLevel(logging.INFO)
    # Each call to set_logger would otherwise add another handler and cause
    # duplicated log lines, so only add a handler if one of that class is missing.
    if not any(handler.__class__ == logging.FileHandler for handler in logger.handlers):
        file_handler = logging.FileHandler(log_path)
        formatter = logging.Formatter(
            '%(asctime)s - %(levelname)s - %(filename)s - %(funcName)s - %(lineno)d - %(message)s')
        file_handler.setFormatter(formatter)
        logger.addHandler(file_handler)
    if not any(handler.__class__ == logging.StreamHandler for handler in logger.handlers):
        stream_handler = logging.StreamHandler()
        stream_handler.setFormatter(logging.Formatter('%(message)s'))
        logger.addHandler(stream_handler)
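
# A minimal usage sketch (illustration only; the log file path is hypothetical):
#   set_logger('./train.log')
#   logging.info('logging to both the console and ./train.log')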