# coding=utf-8
"""Helpers for batch padding, timing, seeding and logger configuration."""
import functools
import logging
import random
import time

import numpy as np
import torch
from torch.nn.utils.rnn import pad_sequence


def sequence_padding(inputs, length=None, value=0, seq_dims=1, mode='post'):
    """Pad a batch of sequences to a common length.

    :param inputs: a sequence whose items are all ``np.ndarray``/``list`` or all
        ``torch.Tensor`` (the type of ``inputs[0]`` selects the code path).
    :param length: target length.
        - ndarray/list items: an int, or a per-dimension sequence indexed for
          each of the first ``seq_dims`` dimensions. ``None`` uses the maximum
          size of each of those dimensions across ``inputs``. Items longer than
          ``length`` are truncated (the head is kept); shorter ones are padded
          up to exactly ``length``.
        - tensor items: an int used to truncate each tensor with ``t[:length]``.
          NOTE: padding then goes to the longest (truncated) tensor, so the
          result can be shorter than ``length``.
    :param value: fill value used for padding.
    :param seq_dims: number of leading dimensions to pad (ndarray/list items only).
    :param mode: ``'post'`` pads after the data, ``'pre'`` pads before it.
    :return: ``np.ndarray`` for ndarray/list items; for tensor items a tensor
        produced by ``pad_sequence(..., batch_first=True)`` (batch dimension
        first); ``None`` when ``inputs`` is empty or its items are of an
        unsupported type.
    :raises ValueError: if ``mode`` is neither ``'post'`` nor ``'pre'``
        (checked for ndarray/list and tensor items).
    """
    # Without this guard inputs[0] below raises IndexError on an empty batch.
    if len(inputs) == 0:
        return None

    if isinstance(inputs[0], (np.ndarray, list)):
        if mode not in ('post', 'pre'):
            raise ValueError('"mode" argument must be "post" or "pre".')

        if length is None:
            length = np.max([np.shape(x)[:seq_dims] for x in inputs], axis=0)
        elif not hasattr(length, '__getitem__'):
            # Scalar length: wrap it so it can be indexed per dimension.
            length = [length]

        slices = [np.s_[:length[i]] for i in range(seq_dims)]
        # A single dimension uses a plain slice so list items can be sliced too.
        slices = tuple(slices) if len(slices) > 1 else slices[0]
        # Dimensions beyond seq_dims keep (0, 0) padding, i.e. are left untouched.
        pad_width = [(0, 0) for _ in np.shape(inputs[0])]

        outputs = []
        for x in inputs:
            x = x[slices]
            for i in range(seq_dims):
                gap = length[i] - np.shape(x)[i]
                pad_width[i] = (0, gap) if mode == 'post' else (gap, 0)
            x = np.pad(x, pad_width, 'constant', constant_values=value)
            outputs.append(x)
        return np.array(outputs)

    if isinstance(inputs[0], torch.Tensor):
        if length is not None:
            inputs = [t[:length] for t in inputs]
        if mode == 'post':
            return pad_sequence(inputs, padding_value=value, batch_first=True)
        if mode == 'pre':
            # pad_sequence always pads at the end. Reverse each sequence, pad,
            # then reverse the sequence axis (dim 1 with batch_first=True) so
            # the padding ends up in front and the original order is restored.
            reversed_seqs = [t.flip(0) for t in inputs]
            return pad_sequence(reversed_seqs, padding_value=value, batch_first=True).flip(1)
        raise ValueError('"mode" argument must be "post" or "pre".')

    # Unsupported element type: the original deliberately returned None here
    # instead of raising, so that best-effort behaviour is kept.
    return None


def timer(func):
    """Decorator that prints the wall-clock time of each call to ``func``.

    The elapsed time is measured with ``time.time()`` and printed after the
    call returns. The wrapped function's return value is passed through.

    :param func: function to time.
    :return: the wrapping function (metadata copied via ``functools.wraps``).
    """
    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        start = time.time()
        res = func(*args, **kwargs)
        end = time.time()
        print("{}共耗时约{:.4f}秒".format(func.__name__, end - start))
        return res

    return wrapper


def set_seed(seed=123):
    """Seed Python, NumPy and PyTorch (CPU and all CUDA devices) RNGs for reproducibility.

    :param seed: seed value passed to every generator.
    """
    random.seed(seed)
    torch.manual_seed(seed)
    np.random.seed(seed)
    torch.cuda.manual_seed_all(seed)


def set_logger(log_path):
    """Configure the root logger to log at INFO level to a file and to the console.

    :param log_path: path of the log file (written as UTF-8).
    """
    logger = logging.getLogger()
    logger.setLevel(logging.INFO)

    # Every call would otherwise add another handler and duplicate each log
    # line, so only add a handler if the root logger has none of that type yet.
    # The class is compared exactly because FileHandler subclasses
    # StreamHandler; an isinstance check would conflate the two.
    if not any(type(handler) is logging.FileHandler for handler in logger.handlers):
        # Explicit UTF-8 so non-ASCII (e.g. Chinese) messages do not fail on
        # platforms whose default encoding cannot represent them.
        file_handler = logging.FileHandler(log_path, encoding='utf-8')
        formatter = logging.Formatter(
            '%(asctime)s - %(levelname)s - %(filename)s - %(funcName)s - %(lineno)d - %(message)s')
        file_handler.setFormatter(formatter)
        logger.addHandler(file_handler)

    if not any(type(handler) is logging.StreamHandler for handler in logger.handlers):
        stream_handler = logging.StreamHandler()
        stream_handler.setFormatter(logging.Formatter('%(message)s'))
        logger.addHandler(stream_handler)