commonUtils.py

# coding=utf-8
import functools
import random
import os
import json
import logging
import time
import pickle
import numpy as np
import torch

def timer(func):
    """
    Timing decorator: prints how long the wrapped function took.
    :param func: function to time
    :return: wrapped function
    """
    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        start = time.time()
        res = func(*args, **kwargs)
        end = time.time()
        print("{} took about {:.4f} seconds".format(func.__name__, end - start))
        return res
    return wrapper
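
# Illustrative usage of @timer (build_vocab is a hypothetical function,
# not part of this module):
#
#     @timer
#     def build_vocab(corpus):
#         ...
#
# Calling build_vocab(...) then prints a line like
# "build_vocab took about 0.5123 seconds".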

def set_seed(seed=123):
    """
    Set the random seeds so experiments are reproducible.
    :param seed: seed value shared by random, numpy and torch
    :return:
    """
    random.seed(seed)
    torch.manual_seed(seed)
    np.random.seed(seed)
    torch.cuda.manual_seed_all(seed)
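
# Note: bit-for-bit reproducibility on GPU usually also requires
# torch.backends.cudnn.deterministic = True and
# torch.backends.cudnn.benchmark = False, at some speed cost;
# set_seed does not touch those flags.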

def set_logger(log_path):
    """
    Configure logging to both a file and the console.
    :param log_path: path of the log file
    :return:
    """
    logger = logging.getLogger()
    logger.setLevel(logging.INFO)
    # Each call to set_logger would otherwise add another handler and cause
    # duplicate log lines, so only add a handler if the root logger lacks one.
    if not any(handler.__class__ == logging.FileHandler for handler in logger.handlers):
        file_handler = logging.FileHandler(log_path)
        formatter = logging.Formatter(
            '%(asctime)s - %(levelname)s - %(filename)s - %(funcName)s - %(lineno)d - %(message)s')
        file_handler.setFormatter(formatter)
        logger.addHandler(file_handler)
    if not any(handler.__class__ == logging.StreamHandler for handler in logger.handlers):
        stream_handler = logging.StreamHandler()
        stream_handler.setFormatter(logging.Formatter('%(message)s'))
        logger.addHandler(stream_handler)
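
# Illustrative usage (the path './out/train.log' is an assumption):
#
#     set_logger('./out/train.log')
#     logging.info('training started')  # goes to both the console and the log file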

def save_json(data_dir, data, desc):
    """Save data as a JSON file named <desc>.json under data_dir."""
    with open(os.path.join(data_dir, '{}.json'.format(desc)), 'w', encoding='utf-8') as f:
        json.dump(data, f, ensure_ascii=False, indent=2)

def read_json(data_dir, desc):
    """Load data from the JSON file <desc>.json under data_dir."""
    with open(os.path.join(data_dir, '{}.json'.format(desc)), 'r', encoding='utf-8') as f:
        data = json.load(f)
    return data

def save_pkl(data_dir, data, desc):
    """Save data as a pickle file named <desc>.pkl under data_dir."""
    with open(os.path.join(data_dir, '{}.pkl'.format(desc)), 'wb') as f:
        pickle.dump(data, f)

def read_pkl(data_dir, desc):
    """Load data from the pickle file <desc>.pkl under data_dir."""
    with open(os.path.join(data_dir, '{}.pkl'.format(desc)), 'rb') as f:
        data = pickle.load(f)
    return data
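
# Illustrative round trip for the I/O helpers above ('./data' and the sample
# payload are assumptions):
#
#     save_json('./data', {'O': 0, 'B-PER': 1}, 'label_map')   # writes ./data/label_map.json
#     label_map = read_json('./data', 'label_map')
#     save_pkl('./data', label_map, 'label_map')               # writes ./data/label_map.pkl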

def fine_grade_tokenize(raw_text, tokenizer):
    """
    For sequence labeling, BERT's subword tokenizer can shift the label
    alignment, so tokenize character by character instead.
    """
    tokens = []
    for _ch in raw_text:
        if _ch in [' ', '\t', '\n']:
            tokens.append('[UNK]')
        else:
            if not len(tokenizer.tokenize(_ch)):
                tokens.append('[UNK]')
            else:
                tokens.append(_ch)
    return tokens
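
# A minimal, runnable self-check under `python commonUtils.py` (a sketch;
# the temporary directory and sample data are assumptions, not part of the
# original module):
if __name__ == '__main__':
    import tempfile

    set_seed(123)
    set_logger(os.path.join(tempfile.gettempdir(), 'commonUtils_demo.log'))

    @timer
    def demo_io():
        # Round-trip the JSON and pickle helpers through a throwaway directory.
        with tempfile.TemporaryDirectory() as tmp_dir:
            save_json(tmp_dir, {'answer': 42}, 'demo')
            assert read_json(tmp_dir, 'demo') == {'answer': 42}
            save_pkl(tmp_dir, [1, 2, 3], 'demo')
            assert read_pkl(tmp_dir, 'demo') == [1, 2, 3]

    demo_io()
    logging.info('round-trip checks passed')
    # fine_grade_tokenize needs a BERT-style tokenizer, e.g. (assuming the
    # transformers package and a downloaded checkpoint):
    #     from transformers import BertTokenizer
    #     tokenizer = BertTokenizer.from_pretrained('bert-base-chinese')
    #     tokens = fine_grade_tokenize('hello world', tokenizer)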