# main.py

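# Training / evaluation / prediction entry point for a BERT-based relation
# extraction model (models.BertForRelationExtraction). The script builds
# DataLoaders from re_mid_data/train.txt and re_mid_data/dev.txt, trains for
# args.train_epochs epochs, saves the final weights to checkpoints/best.pt,
# reports accuracy / micro-F1 / macro-F1 on the dev set, and finally predicts
# the relation for a single marked-up sentence.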
from pprint import pprint
import os
import logging
import json
import shutil
from sklearn.metrics import accuracy_score, f1_score, classification_report
import torch
import torch.nn as nn
import numpy as np
import pickle
from torch.utils.data import DataLoader, RandomSampler
from transformers import BertTokenizer
from tqdm import tqdm

try:
    import bert_config
    import preprocess
    # BertFeature must be imported explicitly: it is needed when unpickling the cached features
    from preprocess import BertFeature
    import dataset
    import models
    import utils
    from data_loader import Collate, MyDataset
except Exception as e:
    from . import bert_config
    from . import preprocess
    # BertFeature must be imported explicitly: it is needed when unpickling the cached features
    from .preprocess import BertFeature
    from . import dataset
    from . import models
    from . import utils
    from .data_loader import Collate, MyDataset

logger = logging.getLogger(__name__)

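# Trainer wires the relation-extraction model, the Adam optimizer and the
# cross-entropy loss to the train / dev / test DataLoaders, and exposes
# train(), dev(), test(), predict() plus the checkpoint helpers
# save_ckp() / load_ckp().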
class Trainer:
    def __init__(self, args, train_loader, dev_loader, test_loader):
        self.args = args
        gpu_ids = args.gpu_ids.split(',')
        self.device = torch.device("cpu" if gpu_ids[0] == '-1' else "cuda:" + gpu_ids[0])
        self.model = models.BertForRelationExtraction(args)
        self.optimizer = torch.optim.Adam(params=self.model.parameters(), lr=self.args.lr)
        self.criterion = nn.CrossEntropyLoss()
        self.train_loader = train_loader
        self.dev_loader = dev_loader
        self.test_loader = test_loader
        self.model.to(self.device)

    def load_ckp(self, model, checkpoint_path):
        checkpoint = torch.load(checkpoint_path, map_location=self.device)
        model.load_state_dict(checkpoint['state_dict'])
        return model

    def save_ckp(self, state, checkpoint_path):
        torch.save(state, checkpoint_path)

    """
    def save_ckp(self, state, is_best, checkpoint_path, best_model_path):
        tmp_checkpoint_path = checkpoint_path
        torch.save(state, tmp_checkpoint_path)
        if is_best:
            tmp_best_model_path = best_model_path
            shutil.copyfile(tmp_checkpoint_path, tmp_best_model_path)
    """
    def train(self):
        total_step = len(self.train_loader) * self.args.train_epochs
        global_step = 0
        eval_step = 100
        best_dev_micro_f1 = 0.0
        for epoch in range(self.args.train_epochs):
            for train_step, train_data in enumerate(self.train_loader):
                self.model.train()
                token_ids = train_data[0].to(self.device)
                attention_masks = train_data[1].to(self.device)
                token_type_ids = train_data[2].to(self.device)
                labels = train_data[3].to(self.device)
                ids = train_data[4].to(self.device)
                train_outputs = self.model(token_ids, attention_masks, token_type_ids, ids)
                loss = self.criterion(train_outputs, labels)
                self.optimizer.zero_grad()
                loss.backward()
                self.optimizer.step()
                logger.info(
                    "【train】 epoch:{} step:{}/{} loss:{:.6f}".format(epoch, global_step, total_step, loss.item()))
                global_step += 1
                # Since the dataset is fairly large, we just save the final model
                # instead of evaluating every eval_step steps.
                # if global_step % eval_step == 0:
                #     dev_loss, dev_outputs, dev_targets = self.dev()
                #     accuracy, micro_f1, macro_f1 = self.get_metrics(dev_outputs, dev_targets)
                #     logger.info(
                #         "【dev】 loss:{:.6f} accuracy:{:.4f} micro_f1:{:.4f} macro_f1:{:.4f}".format(dev_loss, accuracy, micro_f1, macro_f1))
                #     if macro_f1 > best_dev_micro_f1:
                #         logger.info("------------>Saving the best model so far")
                #         checkpoint = {
                #             'epoch': epoch,
                #             'loss': dev_loss,
                #             'state_dict': self.model.state_dict(),
                #             'optimizer': self.optimizer.state_dict(),
                #         }
                #         best_dev_micro_f1 = macro_f1
                #         checkpoint_path = os.path.join(self.args.output_dir, 'best.pt')
                #         self.save_ckp(checkpoint, checkpoint_path)
                # if global_step == 4000:
                #     checkpoint_path = os.path.join(self.args.output_dir, 'best.pt')
                #     checkpoint = {
                #         'state_dict': self.model.state_dict(),
                #     }
                #     self.save_ckp(checkpoint, checkpoint_path)
                #     break
        checkpoint_path = os.path.join(self.args.output_dir, 'best.pt')
        checkpoint = {
            'state_dict': self.model.state_dict(),
        }
        self.save_ckp(checkpoint, checkpoint_path)
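
    # dev() and test() run the model without gradients and return the summed
    # batch loss plus the argmax predictions and the gold labels, which feed
    # get_metrics() / get_classification_report().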
    def dev(self):
        self.model.eval()
        total_loss = 0.0
        dev_outputs = []
        dev_targets = []
        with torch.no_grad():
            for dev_step, dev_data in enumerate(self.dev_loader):
                token_ids = dev_data[0].to(self.device)
                attention_masks = dev_data[1].to(self.device)
                token_type_ids = dev_data[2].to(self.device)
                labels = dev_data[3].to(self.device)
                ids = dev_data[4].to(self.device)
                outputs = self.model(token_ids, attention_masks, token_type_ids, ids)
                loss = self.criterion(outputs, labels)
                # val_loss = val_loss + ((1 / (dev_step + 1))) * (loss.item() - val_loss)
                total_loss += loss.item()
                outputs = np.argmax(outputs.cpu().detach().numpy(), axis=1).flatten()
                dev_outputs.extend(outputs.tolist())
                dev_targets.extend(labels.cpu().detach().numpy().tolist())
        return total_loss, dev_outputs, dev_targets

    def test(self, checkpoint_path):
        model = self.model
        model = self.load_ckp(model, checkpoint_path)
        model.eval()
        model.to(self.device)
        total_loss = 0.0
        test_outputs = []
        test_targets = []
        with torch.no_grad():
            for test_step, test_data in enumerate(tqdm(self.test_loader, ncols=100)):
                token_ids = test_data[0].to(self.device)
                attention_masks = test_data[1].to(self.device)
                token_type_ids = test_data[2].to(self.device)
                labels = test_data[3].to(self.device)
                ids = test_data[4].to(self.device)
                outputs = model(token_ids, attention_masks, token_type_ids, ids)
                loss = self.criterion(outputs, labels)
                # val_loss = val_loss + ((1 / (dev_step + 1))) * (loss.item() - val_loss)
                total_loss += loss.item()
                outputs = np.argmax(outputs.cpu().detach().numpy(), axis=1).flatten()
                test_outputs.extend(outputs.tolist())
                test_targets.extend(labels.cpu().detach().numpy().tolist())
        return total_loss, test_outputs, test_targets
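
    # predict() loads checkpoints/best.pt and classifies a single raw sentence.
    # `text` is the sentence with one entity wrapped in '#...#' and the other in
    # '$...$'; `ids` holds four integers that appear to be the character offsets
    # of those entity spans (see the examples at the bottom of the file).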
    def predict(self, tokenizer, text, id2label, args, ids):
        model = self.model
        checkpoint = os.path.join(args.output_dir, 'best.pt')
        model = self.load_ckp(model, checkpoint)
        model.eval()
        model.to(self.device)
        with torch.no_grad():
            text = [i for i in text]
            inputs = tokenizer.encode_plus(
                text=text,
                add_special_tokens=True,
                max_length=args.max_seq_len,
                truncation='longest_first',
                padding="max_length",
                return_token_type_ids=True,
                return_attention_mask=True,
                return_tensors='pt')
            token_ids = inputs['input_ids'].to(self.device).long()
            attention_masks = inputs['attention_mask'].to(self.device)
            token_type_ids = inputs['token_type_ids'].to(self.device)
            # Shift the entity offsets by 1, presumably to account for the [CLS]
            # token that add_special_tokens=True prepends.
            ids = torch.from_numpy(np.array([[x + 1 for x in ids]])).to(self.device)
            outputs = model(token_ids, attention_masks, token_type_ids, ids)
            outputs = np.argmax(outputs.cpu().detach().numpy(), axis=1).flatten().tolist()
            if len(outputs) != 0:
                outputs = [id2label[i] for i in outputs]
                return outputs
            else:
                return '不好意思,我没有识别出来'

    def get_metrics(self, outputs, targets):
        accuracy = accuracy_score(targets, outputs)
        micro_f1 = f1_score(targets, outputs, average='micro')
        macro_f1 = f1_score(targets, outputs, average='macro')
        return accuracy, micro_f1, macro_f1

    def get_classification_report(self, outputs, targets, labels):
        report = classification_report(targets, outputs, target_names=labels)
        return report
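
# Script entry point. Most hyperparameters from bert_config.Args are overridden
# below. Expected data layout under args.data_dir (../data/dgre/):
#   re_mid_data/rels.txt   - one relation label per line (sets num_tags)
#   re_mid_data/train.txt  - training examples
#   re_mid_data/dev.txt    - dev examples (also reused as the test set)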
if __name__ == '__main__':
    args = bert_config.Args().get_parser()
    utils.utils.set_seed(args.seed)
    utils.utils.set_logger(os.path.join(args.log_dir, 'main.log'))

    args.bert_dir = "../model_hub/chinese-roberta-wwm-ext/"
    args.data_dir = "../data/dgre/"
    args.log_dir = "./logs/"
    args.output_dir = "./checkpoints/"
    args.num_tags = 5  # determined by the number of relation labels in rels.txt
    args.seed = 123
    args.gpu_ids = "0"
    args.max_seq_len = 512
    args.lr = 3e-5
    args.other_lr = 3e-4
    args.train_batch_size = 2
    args.train_epochs = 1
    args.eval_batch_size = 2
    args.dropout_prob = 0.3

    processor = preprocess.Processor()

    re_mid_data_path = os.path.join(args.data_dir, 're_mid_data')
    re_final_data_path = os.path.join(args.data_dir, 're_final_data')

    label2id = {}
    id2label = {}
    with open(re_mid_data_path + '/rels.txt', 'r', encoding='utf-8') as fp:
        labels = fp.read().strip().split('\n')
    for i, j in enumerate(labels):
        label2id[j] = i
        id2label[i] = j
    print(label2id)

    # train_out = preprocess.get_out(processor, './data/train.txt', args, id2label, 'train')
    # dev_out = preprocess.get_out(processor, './data/test.txt', args, id2label, 'dev')
    # test_out = preprocess.get_out(processor, './data/test.txt', args, id2label, 'test')

    # train_out = pickle.load(open(re_final_data_path + '/train.pkl', 'rb'))
    # dev_out = pickle.load(open(re_final_data_path + '/dev.pkl', 'rb'))
    # test_out = pickle.load(open(re_final_data_path + '/dev.pkl', 'rb'))

    # train_features, train_callback_info = train_out
    # train_dataset = dataset.ReDataset(train_features)
    # train_sampler = RandomSampler(train_dataset)
    # train_loader = DataLoader(dataset=train_dataset,
    #                           batch_size=args.train_batch_size,
    #                           sampler=train_sampler,
    #                           num_workers=2)

    # dev_features, dev_callback_info = dev_out[:500]
    # dev_dataset = dataset.ReDataset(dev_features)
    # dev_loader = DataLoader(dataset=dev_dataset,
    #                         batch_size=args.eval_batch_size,
    #                         num_workers=2)

    # test_features, test_callback_info = dev_out
    # test_dataset = dataset.ReDataset(test_features)
    # test_loader = DataLoader(dataset=test_dataset,
    #                          batch_size=args.eval_batch_size,
    #                          num_workers=2)

    gpu_id = args.gpu_ids.split(',')[0]
    device = torch.device("cpu" if gpu_id == '-1' else "cuda:" + gpu_id)
    tokenizer = BertTokenizer.from_pretrained(args.bert_dir)
    collate = Collate(max_len=args.max_seq_len, tag2id=label2id, device=device, tokenizer=tokenizer)
    train_dataset = MyDataset(file_path=re_mid_data_path + '/train.txt')
    train_loader = DataLoader(train_dataset, batch_size=args.train_batch_size, shuffle=True, collate_fn=collate.collate_fn)
    dev_dataset = MyDataset(file_path=re_mid_data_path + '/dev.txt')
    dev_loader = DataLoader(dev_dataset, batch_size=args.eval_batch_size, shuffle=False, collate_fn=collate.collate_fn)
    test_loader = dev_loader

    trainer = Trainer(args, train_loader, dev_loader, test_loader)

    # Train and validate
    trainer.train()

    # Test
    logger.info('======== testing ========')
    checkpoint_path = './checkpoints/best.pt'
    total_loss, test_outputs, test_targets = trainer.test(checkpoint_path)
    accuracy, micro_f1, macro_f1 = trainer.get_metrics(test_outputs, test_targets)
    logger.info(
        "【test】 loss:{:.6f} accuracy:{:.4f} micro_f1:{:.4f} macro_f1:{:.4f}".format(total_loss, accuracy, micro_f1, macro_f1))
    # report = trainer.get_classification_report(test_outputs, test_targets, labels)
    # logger.info(report)

    # Predict from a file of labelled examples
    # with open(re_mid_data_path + '/predict.txt', 'r') as fp:
    #     lines = fp.readlines()
    #     for line in lines:
    #         line = line.strip().split('\t')
    #         label = line[0]
    #         text = line[1]
    #         ids = [int(line[2]), int(line[3]), int(line[4]), int(line[5])]
    #         logger.info(text)
    #         result = trainer.predict(tokenizer, text, id2label, args, ids)
    #         logger.info("Predicted label: " + "".join(result))
    #         logger.info("True label: " + label)
    #         logger.info("==========================")

    # Predict a single example
    # # text = '丈夫 这件婚事原本与陈$国峻$无关,但陈国峻却“欲求配而无由,夜间乃潜入#天城公主#所居通之 34 39 9 12'
    # text = '1537年,#亨利八世#和他的第三个王后$简·西摩$生了一个男孩:爱德华(后来的爱德华六世)。'
    # ids = [34, 39, 9, 12]
    # print('Predicted label:', trainer.predict(tokenizer, text, id2label, args, ids))
    # print('True label:', '丈夫')
    text = '62号汽车故障报告综合情况:故障现象:加速后,丢开油门,#发动机#$熄火$。'
    ids = [29, 33, 34, 37]
    print('Predicted label:', trainer.predict(tokenizer, text, id2label, args, ids))
    print('True label:', '部件故障')