trainUtils.py

# coding=utf-8
import os
import logging

import torch
# Note: transformers' AdamW is deprecated in recent releases; torch.optim.AdamW is a drop-in alternative.
from transformers import AdamW, get_linear_schedule_with_warmup

logger = logging.getLogger(__name__)


def build_optimizer_and_scheduler(args, model, t_total):
    module = (
        model.module if hasattr(model, "module") else model
    )

    # Differential learning rates: BERT, CRF and the remaining modules each get
    # their own learning rate, and bias / LayerNorm weights are excluded from weight decay.
    no_decay = ["bias", "LayerNorm.weight"]
    model_param = list(module.named_parameters())

    bert_param_optimizer = []
    crf_param_optimizer = []
    other_param_optimizer = []

    for name, para in model_param:
        space = name.split('.')
        if space[0] == 'bert_module':
            bert_param_optimizer.append((name, para))
        elif space[0] == 'crf':
            crf_param_optimizer.append((name, para))
        else:
            other_param_optimizer.append((name, para))

    optimizer_grouped_parameters = [
        # BERT module
        {"params": [p for n, p in bert_param_optimizer if not any(nd in n for nd in no_decay)],
         "weight_decay": args.weight_decay, 'lr': args.lr},
        {"params": [p for n, p in bert_param_optimizer if any(nd in n for nd in no_decay)],
         "weight_decay": 0.0, 'lr': args.lr},
        # CRF module
        {"params": [p for n, p in crf_param_optimizer if not any(nd in n for nd in no_decay)],
         "weight_decay": args.weight_decay, 'lr': args.crf_lr},
        {"params": [p for n, p in crf_param_optimizer if any(nd in n for nd in no_decay)],
         "weight_decay": 0.0, 'lr': args.crf_lr},

        # Other modules, with their own (differential) learning rate
        {"params": [p for n, p in other_param_optimizer if not any(nd in n for nd in no_decay)],
         "weight_decay": args.weight_decay, 'lr': args.other_lr},
        {"params": [p for n, p in other_param_optimizer if any(nd in n for nd in no_decay)],
         "weight_decay": 0.0, 'lr': args.other_lr},
    ]

    optimizer = AdamW(optimizer_grouped_parameters, lr=args.lr, eps=args.adam_epsilon)
    scheduler = get_linear_schedule_with_warmup(
        optimizer, num_warmup_steps=int(args.warmup_proportion * t_total), num_training_steps=t_total
    )

    return optimizer, scheduler
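
# A minimal usage sketch (assumed calling code, not part of this module): `args` is an
# argparse.Namespace carrying lr, crf_lr, other_lr, weight_decay, adam_epsilon and
# warmup_proportion; names such as train_loader, train_epochs and max_grad_norm below
# are hypothetical.
#
#     t_total = len(train_loader) * args.train_epochs
#     optimizer, scheduler = build_optimizer_and_scheduler(args, model, t_total)
#     for epoch in range(args.train_epochs):
#         for batch in train_loader:
#             loss = model(**batch)
#             loss.backward()
#             torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm)
#             optimizer.step()
#             scheduler.step()
#             optimizer.zero_grad()
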
def save_model(args, model, model_name):
    """Save the model that performed best on the dev set."""
    output_dir = os.path.join(args.output_dir, '{}'.format(model_name))
    if not os.path.exists(output_dir):
        os.makedirs(output_dir, exist_ok=True)

    # take care of model distributed / parallel training
    model_to_save = (
        model.module if hasattr(model, "module") else model
    )
    logger.info('Saving model checkpoint to {}'.format(output_dir))
    torch.save(model_to_save.state_dict(), os.path.join(output_dir, 'model.pt'))


def save_model_step(args, model, global_step):
    """Save a model checkpoint named after the current global_step."""
    output_dir = os.path.join(args.output_dir, 'checkpoint-{}'.format(global_step))
    if not os.path.exists(output_dir):
        os.makedirs(output_dir, exist_ok=True)

    # take care of model distributed / parallel training
    model_to_save = (
        model.module if hasattr(model, "module") else model
    )
    logger.info('Saving model checkpoint to {}'.format(output_dir))
    torch.save(model_to_save.state_dict(), os.path.join(output_dir, 'model.pt'))
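
# A minimal checkpointing sketch (assumed calling code; dev_f1 / best_f1 and
# args.save_steps are hypothetical names):
#
#     if dev_f1 > best_f1:
#         best_f1 = dev_f1
#         save_model(args, model, 'best')            # -> <output_dir>/best/model.pt
#     if args.save_steps > 0 and global_step % args.save_steps == 0:
#         save_model_step(args, model, global_step)  # -> <output_dir>/checkpoint-<step>/model.pt
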
def load_model_and_parallel(model, gpu_ids, ckpt_path=None, strict=True):
    """
    Load model weights and move the model to GPU(s) (single or multi GPU).
    """
    gpu_ids = gpu_ids.split(',')

    # set the device to the first cuda id ('-1' means CPU)
    device = torch.device("cpu" if gpu_ids[0] == '-1' else "cuda:" + gpu_ids[0])

    if ckpt_path is not None:
        logger.info('Load ckpt from {}'.format(ckpt_path))
        model.load_state_dict(torch.load(ckpt_path, map_location=torch.device('cpu')), strict=strict)

    model.to(device)

    if len(gpu_ids) > 1:
        logger.info('Use multi gpus in: {}'.format(gpu_ids))
        gpu_ids = [int(x) for x in gpu_ids]
        model = torch.nn.DataParallel(model, device_ids=gpu_ids)
    else:
        logger.info('Use single gpu in: {}'.format(gpu_ids))

    return model, device
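
# A minimal usage sketch (assumed calling code; args.gpu_ids such as '0,1', or '-1'
# for CPU-only, and the checkpoint path below are hypothetical):
#
#     ckpt = os.path.join(args.output_dir, 'best', 'model.pt')
#     model, device = load_model_and_parallel(model, args.gpu_ids, ckpt_path=ckpt, strict=True)
#     model.eval()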