# main.py — BERT-based NER: training, evaluation, and entity prediction script.
  1. import os
  2. import logging
  3. import numpy as np
  4. import torch
  5. # 之前是自定义评价指标
  6. try:
  7. from utils import commonUtils, metricsUtils, decodeUtils, trainUtils
  8. import config
  9. import dataset
  10. from preprocess import BertFeature
  11. import bert_ner_model
  12. except Exception as e:
  13. from utils import commonUtils, metricsUtils, decodeUtils, trainUtils
  14. import config
  15. import dataset
  16. from preprocess import BertFeature
  17. import bert_ner_model
  18. # 现在我们使用seqeval库里面的
  19. from seqeval.metrics import accuracy_score,precision_score,recall_score,f1_score,classification_report
  20. # 要显示传入BertFeature
  21. from torch.utils.data import DataLoader, RandomSampler
  22. from transformers import BertTokenizer
  23. args = config.Args().get_parser()
  24. commonUtils.set_seed(args.seed)
  25. logger = logging.getLogger(__name__)
  26. class BertForNer:
  27. def __init__(self, args, train_loader, dev_loader, test_loader, idx2tag):
  28. self.train_loader = train_loader
  29. self.dev_loader = dev_loader
  30. self.test_loader = test_loader
  31. self.args = args
  32. self.idx2tag = idx2tag
  33. model = bert_ner_model.BertNerModel(args)
  34. self.model, self.device = trainUtils.load_model_and_parallel(model, args.gpu_ids)
  35. if self.train_loader:
  36. self.t_total = len(self.train_loader) * args.train_epochs
  37. self.optimizer, self.scheduler = trainUtils.build_optimizer_and_scheduler(args, model, self.t_total)
  38. def train(self):
  39. # Train
  40. global_step = 0
  41. self.model.zero_grad()
  42. eval_steps = 1 #每多少个step打印损失及进行验证
  43. best_f1 = 0.0
  44. for epoch in range(args.train_epochs):
  45. for step, batch_data in enumerate(self.train_loader):
  46. self.model.train()
  47. for key in batch_data.keys():
  48. if key != 'texts':
  49. batch_data[key] = batch_data[key].to(self.device)
  50. loss, logits = self.model(batch_data['token_ids'], batch_data['attention_masks'], batch_data['token_type_ids'], batch_data['labels'])
  51. torch.nn.utils.clip_grad_norm_(self.model.parameters(), args.max_grad_norm)
  52. # loss.backward(loss.clone().detach())
  53. loss.backward()
  54. self.optimizer.step()
  55. self.scheduler.step()
  56. self.model.zero_grad()
  57. logger.info('【train】 epoch:{} {}/{} loss:{:.4f}'.format(epoch, global_step, self.t_total, loss.item()))
  58. global_step += 1
  59. """这里验证耗时有点长,我们最后直接保存模型就好
  60. if global_step % eval_steps == 0:
  61. dev_loss, accuracy, precision, recall, f1 = self.dev()
  62. logger.info('[eval] loss:{:.4f} accuracy:{:.4f} precision={:.4f} recall={:.4f} f1_score={:.4f}'.format(dev_loss, accuracy, precision, recall, f1))
  63. if f1 > best_f1:
  64. trainUtils.save_model(args, self.model, model_name, global_step)
  65. f1 = f1_score
  66. """
  67. trainUtils.save_model(args, self.model, model_name)
  68. def dev(self):
  69. self.model.eval()
  70. with torch.no_grad():
  71. batch_output_all = []
  72. batch_true_all = []
  73. tot_dev_loss = 0.0
  74. for eval_step, dev_batch_data in enumerate(self.dev_loader):
  75. for key in dev_batch_data.keys():
  76. dev_batch_data[key] = dev_batch_data[key].to(self.device)
  77. dev_loss, dev_logits = self.model(dev_batch_data['token_ids'], dev_batch_data['attention_masks'],dev_batch_data['token_type_ids'], dev_batch_data['labels'])
  78. tot_dev_loss += dev_loss.item()
  79. if self.args.use_crf == 'True':
  80. batch_output = dev_logits
  81. else:
  82. batch_output = dev_logits.detach().cpu().numpy()
  83. batch_output = np.argmax(batch_output, axis=2)
  84. if len(batch_output_all) == 0:
  85. batch_output_all = batch_output
  86. # 获取真实的长度标签
  87. tmp_labels = dev_batch_data['labels'].detach().cpu().numpy()
  88. tmp_masks = dev_batch_data['attention_masks'].detach().cpu().numpy()
  89. # print(tmp_labels.shape)
  90. # print(tmp_masks.shape)
  91. batch_output_all = [list(map(lambda x:self.idx2tag[x], i)) for i in batch_output_all]
  92. batch_true_all = [list(tmp_labels[i][tmp_masks[i]==1]) for i in range(tmp_labels.shape[0])]
  93. batch_true_all = [list(map(lambda x:self.idx2tag[x], i)) for i in batch_true_all]
  94. # print(batch_output_all[1])
  95. # print(batch_true_all[1])
  96. else:
  97. batch_output = [list(map(lambda x:self.idx2tag[x], i)) for i in batch_output]
  98. batch_output_all = np.append(batch_output_all, batch_output, axis=0)
  99. tmp_labels = dev_batch_data['labels'].detach().cpu().numpy()
  100. tmp_masks = dev_batch_data['attention_masks'].detach().cpu().numpy()
  101. tmp_batch_true_all = [list(tmp_labels[i][tmp_masks[i]==1]) for i in range(tmp_labels.shape[0])]
  102. tmp_batch_true_all = [list(map(lambda x:self.idx2tag[x], i)) for i in tmp_batch_true_all]
  103. batch_true_all = np.append(batch_true_all, tmp_batch_true_all, axis=0)
  104. accuracy = accuracy_score(batch_true_all, batch_output_all)
  105. precision = precision_score(batch_true_all, batch_output_all)
  106. recall = recall_score(batch_true_all, batch_output_all)
  107. f1 = f1_score(batch_true_all, batch_output_all)
  108. return tot_dev_loss, accuracy, precision, recall, f1
  109. def test(self, model_path):
  110. model = bert_ner_model.BertNerModel(self.args)
  111. model, device = trainUtils.load_model_and_parallel(model, self.args.gpu_ids, model_path)
  112. model.eval()
  113. pred_label = []
  114. true_label = []
  115. with torch.no_grad():
  116. for eval_step, dev_batch_data in enumerate(self.test_loader):
  117. for key in dev_batch_data.keys():
  118. dev_batch_data[key] = dev_batch_data[key].to(device)
  119. _, logits = model(dev_batch_data['token_ids'], dev_batch_data['attention_masks'],dev_batch_data['token_type_ids'],dev_batch_data['labels'])
  120. if self.args.use_crf == 'True':
  121. batch_output = logits
  122. else:
  123. batch_output = logits.detach().cpu().numpy()
  124. batch_output = np.argmax(batch_output, axis=2)
  125. if len(pred_label) == 0:
  126. tmp_labels = dev_batch_data['labels'].detach().cpu().numpy()
  127. tmp_masks = dev_batch_data['attention_masks'].detach().cpu().numpy()
  128. pred_label = [list(map(lambda x:self.idx2tag[x], i)) for i in batch_output]
  129. # true_label = dev_batch_data['labels'].detach().cpu().numpy().tolist()
  130. true_label = [list(tmp_labels[i][tmp_masks[i]==1]) for i in range(tmp_labels.shape[0])]
  131. true_label = [list(map(lambda x:self.idx2tag[x], i)) for i in true_label]
  132. print(pred_label)
  133. print(true_label)
  134. else:
  135. # pred_label = np.append(pred_label, batch_output, axis=0)
  136. # true_label = np.append(pred_label, dev_batch_data['labels'].detach().cpu().numpy().tolist(), axis=0)
  137. batch_output = [list(map(lambda x:self.idx2tag[x], i)) for i in batch_output]
  138. print(batch_output)
  139. pred_label = np.append(pred_label, batch_output)
  140. print(pred_label)
  141. tmp_labels = dev_batch_data['labels'].detach().cpu().numpy()
  142. # print( tmp_labels)
  143. tmp_masks = dev_batch_data['attention_masks'].detach().cpu().numpy()
  144. tmp_batch_true_all = [list(tmp_labels[i][tmp_masks[i]==1]) for i in range(tmp_labels.shape[0])]
  145. tmp_batch_true_all = [list(map(lambda x:self.idx2tag[x], i)) for i in tmp_batch_true_all]
  146. true_label = np.append(true_label, tmp_batch_true_all)
  147. #logger.info(classification_report(true_label, pred_label))
  148. # pred_label = str(pred_label)
  149. # true_label = str(true_label)
  150. # print(classification_report(true_label, pred_label))
  151. def predict(self, raw_text, model_path):
  152. model = bert_ner_model.BertNerModel(self.args)
  153. model, device = trainUtils.load_model_and_parallel(model, self.args.gpu_ids, model_path)
  154. model.eval()
  155. with torch.no_grad():
  156. tokenizer = BertTokenizer(
  157. os.path.join(self.args.bert_dir, 'vocab.txt'))
  158. # tokens = commonUtils.fine_grade_tokenize(raw_text, tokenizer)
  159. tokens = [i for i in raw_text]
  160. encode_dict = tokenizer.encode_plus(text=tokens,
  161. max_length=self.args.max_seq_len,
  162. padding='max_length',
  163. truncation='longest_first',
  164. is_pretokenized=True,
  165. return_token_type_ids=True,
  166. return_attention_mask=True)
  167. # tokens = ['[CLS]'] + tokens + ['[SEP]']
  168. token_ids = torch.from_numpy(np.array(encode_dict['input_ids'])).unsqueeze(0)
  169. token_ids=torch.LongTensor(token_ids.numpy())
  170. attention_masks = torch.from_numpy(np.array(encode_dict['attention_mask'],dtype=np.uint8)).unsqueeze(0)
  171. # attention_masks=torch.LongTensor(attention_masks.numpy(),dtype=np.uint8)
  172. token_type_ids = torch.from_numpy(np.array(encode_dict['token_type_ids'])).unsqueeze(0)
  173. token_type_ids=torch.LongTensor(token_type_ids.numpy())
  174. logits = model(token_ids.to(device), attention_masks.to(device), token_type_ids.to(device), None)
  175. if self.args.use_crf != "True":
  176. logits = logits.detach().cpu().numpy()
  177. logits = np.argmax(logits, axis=2)
  178. pred_label = [list(map(lambda x:self.idx2tag[x], i)) for i in logits]
  179. # assert len(pred_label[0]) == len(tokens)+2
  180. pred_entities = decodeUtils.get_entities(pred_label[0][1:1+len(tokens)], "".join(tokens))
  181. logger.info(pred_entities)
  182. return pred_entities
  183. if __name__ == '__main__':
  184. data_name = 'dgre'
  185. args.bert_dir = "../model_hub/chinese-roberta-wwm-ext/" # 预训练模型名称
  186. args.data_dir = "../data/dgre/"
  187. args.log_dir = "./logs/"
  188. args.output_dir = "./checkpoints/"
  189. args.num_tags = 5
  190. args.seed = 123
  191. args.gpu_ids = "0"
  192. args.max_seq_len = 512 # 和preprocess.py里面的一致
  193. args.lr = 3e-5
  194. args.crf_lr = 3e-2
  195. args.other_lr = 3e-4
  196. args.train_batch_size = 2
  197. args.train_epochs = 5
  198. args.eval_batch_size = 2
  199. args.max_grad_norm = 1
  200. args.warmup_proportion = 0.1
  201. args.adam_epsilon = 1e-8
  202. args.weight_decay = 0.01
  203. args.lstm_hidden = 128
  204. args.num_layers = 1
  205. args.use_lstm = 'False'
  206. args.use_crf = 'True'
  207. args.dropout_prob = 0.3
  208. args.dropout = 0.3
  209. # args.train_epochs = 1
  210. # args.train_batch_size = 32
  211. # args.max_seq_len = 300
  212. model_name = ''
  213. if args.use_lstm == 'True' and args.use_crf == 'False':
  214. model_name = 'bert_bilstm'
  215. if args.use_lstm == 'True' and args.use_crf == 'True':
  216. model_name = 'bert_bilstm_crf'
  217. if args.use_lstm == 'False' and args.use_crf == 'True':
  218. model_name = 'bert_crf'
  219. if args.use_lstm == 'False' and args.use_crf == 'False':
  220. model_name = 'bert'
  221. commonUtils.set_logger(os.path.join(args.log_dir, '{}.log'.format(model_name)))
  222. data_path = os.path.join(args.data_dir, 'ner_final_data')
  223. mid_data_path = os.path.join(args.data_dir, 'mid_data')
  224. # 真实标签
  225. ent_labels_path = mid_data_path + '/ent_labels.txt'
  226. # 序列标注标签B I O
  227. ner_labels_path = mid_data_path + '/ner_labels.txt'
  228. with open(ent_labels_path, 'r') as fp:
  229. ent_labels = fp.read().strip().split('\n')
  230. entlabel2id = {}
  231. id2entlabel = {}
  232. for i,j in enumerate(ent_labels):
  233. entlabel2id[j] = i
  234. id2entlabel[i] = j
  235. nerlabel2id = {}
  236. id2nerlabel = {}
  237. with open(ner_labels_path,'r') as fp:
  238. ner_labels = fp.read().strip().split('\n')
  239. for i,j in enumerate(ner_labels):
  240. nerlabel2id[j] = i
  241. id2nerlabel[i] = j
  242. logger.info(id2nerlabel)
  243. args.num_tags = len(ner_labels)
  244. logger.info(args)
  245. train_features, train_callback_info = commonUtils.read_pkl(data_path, 'train')
  246. train_dataset = dataset.NerDataset(train_features)
  247. train_sampler = RandomSampler(train_dataset)
  248. train_loader = DataLoader(dataset=train_dataset,
  249. batch_size=args.train_batch_size,
  250. sampler=train_sampler,
  251. num_workers=2)
  252. dev_features, dev_callback_info = commonUtils.read_pkl(data_path, 'dev')
  253. dev_dataset = dataset.NerDataset(dev_features)
  254. dev_loader = DataLoader(dataset=dev_dataset,
  255. batch_size=args.eval_batch_size,
  256. num_workers=2)
  257. # test_features, test_callback_info = commonUtils.read_pkl(data_path, 'test')
  258. # test_dataset = dataset.NerDataset(test_features)
  259. # test_loader = DataLoader(dataset=test_dataset,
  260. # batch_size=args.eval_batch_size,
  261. # num_workers=2)
  262. bertForNer = BertForNer(args, train_loader, dev_loader, dev_loader, id2nerlabel)
  263. bertForNer.train()
  264. model_path = './checkpoints/{}/model.pt'.format(model_name)
  265. bertForNer.test(model_path)
  266. if data_name == "dgre":
  267. raw_text = "211号汽车故障报告综合情况:故障现象:开暖风鼓风机运转时有异常响声。故障原因简要分析:该故障是鼓风机运转时有异响由此可以判断可能原因:1鼓风机故障 2鼓风机内有杂物"
  268. elif data_name == "duie":
  269. raw_text = "《单身》是Outsider演唱的歌曲,收录于专辑《2辑Maestro》。描写一个人单身的感觉,单身即是痛苦的也是幸福的,在于人们如何去看待s"
  270. logger.info(raw_text)
  271. bertForNer.predict(raw_text, model_path)