# coding=utf-8
import functools
import logging
import random
import time

import numpy as np
import torch
from torch.nn.utils.rnn import pad_sequence
def sequence_padding(inputs, length=None, value=0, seq_dims=1, mode='post'):
    """Pad a batch of sequences to the same length."""
    # Guard against empty batches before indexing inputs[0].
    if len(inputs) == 0:
        return np.array([])

    if isinstance(inputs[0], (np.ndarray, list)):
        # Infer the target length per sequence dimension if not given.
        if length is None:
            length = np.max([np.shape(x)[:seq_dims] for x in inputs], axis=0)
        elif not hasattr(length, '__getitem__'):
            length = [length]

        # Truncate each sequence dimension to the target length.
        slices = [np.s_[:length[i]] for i in range(seq_dims)]
        slices = tuple(slices) if len(slices) > 1 else slices[0]

        outputs = []
        for x in inputs:
            x = x[slices]
            pad_width = [(0, 0) for _ in np.shape(x)]
            for i in range(seq_dims):
                if mode == 'post':
                    pad_width[i] = (0, length[i] - np.shape(x)[i])
                elif mode == 'pre':
                    pad_width[i] = (length[i] - np.shape(x)[i], 0)
                else:
                    raise ValueError('"mode" argument must be "post" or "pre".')
            x = np.pad(x, pad_width, 'constant', constant_values=value)
            outputs.append(x)
        return np.array(outputs)

    elif isinstance(inputs[0], torch.Tensor):
        # pad_sequence always pads at the end, i.e. "post" mode.
        if length is not None:
            inputs = [i[:length] for i in inputs]
        return pad_sequence(inputs, padding_value=value, batch_first=True)
    else:
        raise ValueError('"inputs" argument must contain tensors/lists/ndarrays.')
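
# Illustrative sketch (not part of the original module): how sequence_padding
# behaves on nested lists and on torch tensors. The values are made up.
def _demo_sequence_padding():
    batch = [[1, 2, 3], [4, 5], [6]]
    print(sequence_padding(batch))               # [[1 2 3] [4 5 0] [6 0 0]]
    print(sequence_padding(batch, mode='pre'))   # [[1 2 3] [0 4 5] [0 0 6]]
    tensors = [torch.tensor([1, 2, 3]), torch.tensor([4])]
    print(sequence_padding(tensors))             # tensor([[1, 2, 3], [4, 0, 0]])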
def timer(func):
    """Decorator that reports a function's wall-clock running time.

    :param func: the function to time
    :return: the wrapped function
    """
    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        start = time.time()
        res = func(*args, **kwargs)
        end = time.time()
        print("{} took about {:.4f}s".format(func.__name__, end - start))
        return res
    return wrapper
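
# Illustrative sketch (not part of the original module): timing a function
# with the @timer decorator.
@timer
def _demo_slow_add(a, b):
    time.sleep(0.1)
    return a + b

# _demo_slow_add(1, 2) would print something like:
#   _demo_slow_add took about 0.1001s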
def set_seed(seed=123):
    """Set the random seeds so that experiments are reproducible.

    :param seed: the seed value
    :return: None
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
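
# Illustrative sketch (not part of the original module): with the same seed,
# random draws are reproducible across runs.
def _demo_set_seed():
    set_seed(123)
    first = torch.randn(2)
    set_seed(123)
    second = torch.randn(2)
    assert torch.equal(first, second)  # identical because the seed was reset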
def set_logger(log_path):
    """Configure the root logger.

    :param log_path: path of the log file
    :return: None
    """
    logger = logging.getLogger()
    logger.setLevel(logging.INFO)
    # Each call to set_logger would otherwise add another handler and cause
    # duplicated output, so only add a handler if one of that exact class is
    # not already registered on the root logger.
    if not any(handler.__class__ == logging.FileHandler for handler in logger.handlers):
        file_handler = logging.FileHandler(log_path)
        formatter = logging.Formatter(
            '%(asctime)s - %(levelname)s - %(filename)s - %(funcName)s - %(lineno)d - %(message)s')
        file_handler.setFormatter(formatter)
        logger.addHandler(file_handler)
    if not any(handler.__class__ == logging.StreamHandler for handler in logger.handlers):
        stream_handler = logging.StreamHandler()
        stream_handler.setFormatter(logging.Formatter('%(message)s'))
        logger.addHandler(stream_handler)
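
# Illustrative sketch (not part of the original module): set_logger wires the
# root logger, so plain logging calls go to both the console and the file.
# The './train.log' path below is hypothetical.
def _demo_set_logger():
    set_logger('./train.log')
    logging.info('logging to console and ./train.log')
    set_logger('./train.log')  # calling again does not duplicate handlers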