# get_result.py — end-to-end inference script: BERT-CRF NER followed by BERT relation extraction.
  1. import sys
  2. sys.path.append("./bert_bilstm_crf_ner")
  3. sys.path.append("./bert_re")
  4. import bert_bilstm_crf_ner.config as ner_config
  5. # import bert_bilstm_crf_ner.bert_ner_model as ner_model
  6. import bert_bilstm_crf_ner.main as ner_main
  7. import bert_re.main as re_main
  8. import bert_re.bert_config as re_config
  9. import os
  10. import re
  11. import logging
  12. from transformers import BertTokenizer
  13. from bert_bilstm_crf_ner import bert_ner_model as ner_model
  14. import bert_re.models as re_model
  15. logger = logging.getLogger(__name__)
  16. def get_ner_result(raw_text):
  17. # 命名实体识别相关
  18. model_name = 'bert_crf'
  19. ner_args = ner_config.Args().get_parser()
  20. ner_args.bert_dir = './model_hub/chinese-roberta-wwm-ext/'
  21. ner_args.gpu_ids = "-1"
  22. ner_args.use_lstm = 'False'
  23. ner_args.use_crf = 'True'
  24. ner_args.num_tags = 5
  25. ner_args.max_seq_len = 512
  26. ner_args.num_layers = 1
  27. ner_args.lstm_hidden = 128
  28. nerlabel2id = {}
  29. id2nerlabel = {}
  30. with open('./data/dgre/mid_data/ner_labels.txt','r') as fp:
  31. ner_labels = fp.read().strip().split('\n')
  32. for i,j in enumerate(ner_labels):
  33. nerlabel2id[j] = i
  34. id2nerlabel[i] = j
  35. logger.info(id2nerlabel)
  36. bertForNer = ner_main.BertForNer(ner_args, None, None, None, id2nerlabel)
  37. model_path = './bert_bilstm_crf_ner/checkpoints/{}/model.pt'.format(model_name)
  38. pred_entities = bertForNer.predict(raw_text, model_path)
  39. return pred_entities
  40. def get_re_result(entities, raw_text):
  41. # 首先先区分是主体还是客体
  42. subjects = []
  43. objects = []
  44. for info in entities:
  45. print(info)
  46. if info[2] == 'subject':
  47. subjects.append((info[0],info[1],info[1]+len(info[0])))
  48. elif info[2] == 'object':
  49. objects.append((info[0],info[1],info[1]+len(info[0])))
  50. print(subjects)
  51. print(objects)
  52. re_args = re_config.Args().get_parser()
  53. re_args.bert_dir = './model_hub/chinese-roberta-wwm-ext/'
  54. re_args.gpu_ids = "-1"
  55. re_args.num_tags = 5
  56. re_args.max_seq_len = 512
  57. trainer = re_main.Trainer(re_args, None, None, None)
  58. re_args.output_dir = './bert_re/checkpoints/'
  59. tokenizer = BertTokenizer.from_pretrained(re_args.bert_dir)
  60. process_data = transforme_re_data(subjects, objects, raw_text)
  61. label2id = {}
  62. id2label = {}
  63. with open('./data/dgre/re_mid_data/rels.txt','r',encoding='utf-8') as fp:
  64. labels = fp.read().strip().split('\n')
  65. for i,j in enumerate(labels):
  66. label2id[j] = i
  67. id2label[i] = j
  68. for data in process_data:
  69. relation = trainer.predict(tokenizer, data[0], id2label, re_args, data[1])
  70. print("==========================")
  71. print(raw_text)
  72. print("主体:", data[2][0])
  73. print("客体:", data[2][1])
  74. print("关系:", "".join(relation))
  75. def transforme_re_data(subjects, objects, text):
  76. # 遍历每一个主体和客体
  77. tmp_text = text
  78. process_data = []
  79. for sub in subjects:
  80. for obj in objects:
  81. if obj[0] in sub[0]:
  82. text = text[:sub[1]] + '&'*len(sub[0]) + text[sub[2]:]
  83. text = text[:obj[1]] + '%'*len(obj[0]) + text[obj[2]:]
  84. text = re.sub('&'*len(sub[0]),'#'+'&'*len(sub[0])+'#', text)
  85. text = re.sub('%'*len(obj[0]),'$'+'%'*len(obj[0])+'$', text)
  86. else:
  87. text = text[:obj[1]] + '%'*len(obj[0]) + text[obj[2]:]
  88. text = text[:sub[1]] + '&'*len(sub[0]) + text[sub[2]:]
  89. text = re.sub('%'*len(obj[0]),'$'+'%'*len(obj[0])+'$', text)
  90. text = re.sub('&'*len(sub[0]),'#'+'&'*len(sub[0])+'#', text)
  91. try:
  92. sub_re = re.search('&'*len(sub[0]), text)
  93. sub_re_span = sub_re.span()
  94. sub_re_start = sub_re_span[0]
  95. sub_re_end = sub_re_span[1]+1
  96. obj_res = re.search('%'*len(obj[0]), text)
  97. obj_re_span = obj_res.span()
  98. obj_re_start = obj_re_span[0]
  99. obj_re_end = obj_re_span[1]+1
  100. text = re.sub('&'*len(sub[0]),sub[0],text)
  101. text = re.sub('%'*len(obj[0]),obj[0],text)
  102. except Exception as e:
  103. print(e)
  104. continue
  105. process_data.append((text,[sub[1],sub[2],obj[1],obj[2]],(sub,obj)))
  106. # 恢复text
  107. text = tmp_text
  108. return process_data
  109. if __name__ == '__main__':
  110. # raw_texts = [
  111. # '明早起飞》是由明太鱼作词,满江作曲,戴娆演唱的一首歌曲',
  112. # '古董相机收藏与鉴赏》是由高继生、高峻岭编著,浙江科学技术出版社出版的一本书籍',
  113. # '谢顺光,男,祖籍江西都昌,出生于景德镇陶瓷世家',
  114. # ]
  115. raw_texts = [
  116. '故障现象:转向时有“咯噔”声原因分析:转向机与转向轴处缺油解决措施:向此处重新覆盖一层润滑脂后,故障消失',
  117. '1045号汽车故障报告故障现象打开点火开关,操作左前电动座椅开关,座椅6个方向均不动作故障原因六向电动座椅线束磨破搭铁修复方法包扎磨破线束,从新固定。',
  118. ]
  119. for raw_text in raw_texts:
  120. entities = get_ner_result(raw_text)
  121. get_re_result(entities, raw_text)