preprocess.py

import os
import json
import logging

from transformers import BertTokenizer

try:
    # absolute imports when the script is run directly
    from utils import cutSentences, commonUtils
    import config
except ImportError:
    # relative imports when the module is used as part of a package
    from .utils import cutSentences, commonUtils
    from . import config

logger = logging.getLogger(__name__)

class InputExample:
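    """One raw example: the original text plus its subject/object entity annotations,
    or the merged sentence-level labels when long texts are cut into sentences."""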
    def __init__(self, set_type, text, subject_labels=None, object_labels=None, labels=None):
        self.set_type = set_type
        self.text = text
        self.subject_labels = subject_labels
        self.object_labels = object_labels
        # sentence-level labels produced by cutSentences.refactor_labels when cut_sent=True
        self.labels = labels

class BaseFeature:
    def __init__(self, token_ids, attention_masks, token_type_ids):
        # BERT inputs
        self.token_ids = token_ids
        self.attention_masks = attention_masks
        self.token_type_ids = token_type_ids

class BertFeature(BaseFeature):
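    """Extends BaseFeature with the per-token label ids used as NER training targets."""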
    def __init__(self, token_ids, attention_masks, token_type_ids, labels=None):
        super(BertFeature, self).__init__(
            token_ids=token_ids,
            attention_masks=attention_masks,
            token_type_ids=token_type_ids)
        # labels
        self.labels = labels

class NerProcessor:
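    """Loads the raw JSON data and turns each record into InputExample objects,
    optionally cutting long texts into sentences of at most cut_sent_len characters."""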
    def __init__(self, cut_sent=True, cut_sent_len=256):
        self.cut_sent = cut_sent
        self.cut_sent_len = cut_sent_len

    @staticmethod
    def read_json(file_path):
        with open(file_path, encoding='utf-8') as f:
            raw_examples = json.load(f)
        return raw_examples

    def get_examples(self, raw_examples, set_type):
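        """Build InputExample objects; when cut_sent is True the text is split into
        sentences and the labels are re-anchored to each sentence's offsets."""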
        examples = []
        # iterate over the dicts loaded from the JSON file
        for i, item in enumerate(raw_examples):
            # print(i, item)
            text = item['text']
            if self.cut_sent:
                sentences = cutSentences.cut_sent_for_bert(text, self.cut_sent_len)
                start_index = 0
                for sent in sentences:
                    labels = cutSentences.refactor_labels(sent, item['labels'], start_index)
                    start_index += len(sent)
                    examples.append(InputExample(set_type=set_type,
                                                 text=sent,
                                                 labels=labels))
            else:
                subject_labels = item['subject_labels']
                object_labels = item['object_labels']
                if len(subject_labels) != 0:
                    subject_labels = [('subject', label[1], label[2]) for label in subject_labels]
                if len(object_labels) != 0:
                    object_labels = [('object', label[1], label[2]) for label in object_labels]
                examples.append(InputExample(set_type=set_type,
                                             text=text,
                                             subject_labels=subject_labels,
                                             object_labels=object_labels))
        return examples

def convert_bert_example(ex_idx, example: InputExample, tokenizer: BertTokenizer,
                         max_seq_len, nerlabel2id, ent_labels):
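    """Convert one InputExample into a BertFeature: character-level tokens, BIO label
    ids aligned to the original text, and a (text, gold-entity dict) callback tuple."""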
    set_type = example.set_type
    raw_text = example.text
    subject_entities = example.subject_labels
    object_entities = example.object_labels
    entities = subject_entities + object_entities

    # tuple holding the raw text
    callback_info = (raw_text,)
    # dict of gold entities keyed by entity type
    callback_labels = {x: [] for x in ent_labels}
    # _label: (entity type, entity text, entity start position)
    for _label in entities:
        # print(_label)
        callback_labels[_label[0]].append((_label[0], _label[1]))
    callback_info += (callback_labels,)

    # For sequence labelling, BERT's wordpiece tokenizer may shift the annotation
    # offsets, so the text is tokenized character by character instead.
    # tokens = commonUtils.fine_grade_tokenize(raw_text, tokenizer)
    tokens = [i for i in raw_text]
    assert len(tokens) == len(raw_text)

    # information for dev callback
    # ========================
    label_ids = [0] * len(tokens)
    # tag entity labels, e.g. (T1, DRUG_DOSAGE, 447, 450, 小蜜丸)
    for ent in entities:
        # ent: ('PER', '陈元', 0)
        ent_type = ent[0]    # entity type
        ent_start = ent[-1]  # start position
        ent_end = ent_start + len(ent[1]) - 1
        if ent_start == ent_end:
            label_ids[ent_start] = nerlabel2id['B-' + ent_type]
        else:
            try:
                label_ids[ent_start] = nerlabel2id['B-' + ent_type]
                label_ids[ent_end] = nerlabel2id['I-' + ent_type]
                for i in range(ent_start + 1, ent_end):
                    label_ids[i] = nerlabel2id['I-' + ent_type]
            except Exception as e:
                print(ent)
                print(tokens)
                import sys
                sys.exit(0)
    # truncate to leave room for [CLS]/[SEP], then pad
    if len(label_ids) > max_seq_len - 2:
        label_ids = label_ids[:max_seq_len - 2]
    label_ids = [0] + label_ids + [0]
    # pad
    if len(label_ids) < max_seq_len:
        pad_length = max_seq_len - len(label_ids)
        label_ids = label_ids + [0] * pad_length  # [CLS], [SEP] and [PAD] positions are all labelled O
    assert len(label_ids) == max_seq_len, f'{len(label_ids)}'
    # ========================

    encode_dict = tokenizer.encode_plus(text=tokens,
                                        max_length=max_seq_len,
                                        padding='max_length',
                                        truncation='longest_first',
                                        return_token_type_ids=True,
                                        return_attention_mask=True)
    tokens = ['[CLS]'] + tokens + ['[SEP]']
    token_ids = encode_dict['input_ids']
    attention_masks = encode_dict['attention_mask']
    token_type_ids = encode_dict['token_type_ids']
    if ex_idx < 3:
        logger.info(f"*** {set_type}_example-{ex_idx} ***")
        print(tokenizer.decode(token_ids[:len(raw_text)]))
        logger.info(f'text: {" ".join(tokens)}')
        logger.info(f"token_ids: {token_ids}")
        logger.info(f"attention_masks: {attention_masks}")
        logger.info(f"token_type_ids: {token_type_ids}")
        logger.info(f"labels: {label_ids}")
        logger.info('length: ' + str(len(token_ids)))
        # for word, token, attn, label in zip(tokens, token_ids, attention_masks, label_ids):
        #     print(word + ' ' + str(token) + ' ' + str(attn) + ' ' + str(label))

    feature = BertFeature(
        # bert inputs
        token_ids=token_ids,
        attention_masks=attention_masks,
        token_type_ids=token_type_ids,
        labels=label_ids,
    )
    return feature, callback_info

def convert_examples_to_features(examples, max_seq_len, bert_dir, nerlabel2id, ent_labels):
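    """Run convert_bert_example over all examples and return (features,) or,
    when callback information was collected, (features, callback_info)."""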
    tokenizer = BertTokenizer(os.path.join(bert_dir, 'vocab.txt'))
    features = []
    callback_info = []
    logger.info(f'Convert {len(examples)} examples to features')
    for i, example in enumerate(examples):
        """
        subject_entities = example.subject_labels
        object_entities = example.object_labels
        entities = subject_entities + object_entities
        flag = False
        for ent in entities:
            start_id = ent[1]
            end_id = len(ent[0]) + ent[1]
            if start_id >= max_seq_len - 2 or end_id >= max_seq_len - 2:
                flag = True
                break
        if flag:
            continue
        """
        feature, tmp_callback = convert_bert_example(
            ex_idx=i,
            example=example,
            max_seq_len=max_seq_len,
            nerlabel2id=nerlabel2id,
            tokenizer=tokenizer,
            ent_labels=ent_labels,
        )
        if feature is None:
            continue
        features.append(feature)
        callback_info.append(tmp_callback)
    logger.info(f'Build {len(features)} features')
    out = (features,)
    if not len(callback_info):
        return out
    out += (callback_info,)
    return out

def get_data(processor, raw_data_path, json_file, mode, nerlabel2id, ent_labels, args):
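    """Read one JSON split, build its features, and pickle them under
    args.data_dir/ner_final_data."""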
    raw_examples = processor.read_json(os.path.join(raw_data_path, json_file))
    examples = processor.get_examples(raw_examples, mode)
    data = convert_examples_to_features(examples, args.max_seq_len, args.bert_dir, nerlabel2id, ent_labels)
    save_path = os.path.join(args.data_dir, 'ner_final_data')
    if not os.path.exists(save_path):
        os.makedirs(save_path)
    commonUtils.save_pkl(save_path, data, mode)
    return data

def save_file(filename, data, id2nerlabel):
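    """Dump the features as a CoNLL-style text file: one 'character label' pair per
    line, with a blank line between examples."""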
    features, callback_info = data
    file = open(filename, 'w', encoding='utf-8')
    for feature, tmp_callback in zip(features, callback_info):
        text, gt_entities = tmp_callback
        for word, label in zip(text, feature.labels[1:len(text) + 1]):
            file.write(word + ' ' + id2nerlabel[label] + '\n')
        file.write('\n')
    file.close()

if __name__ == '__main__':
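    # Preprocess the selected dataset: read the label files, build NER features,
    # pickle them, and write a CoNLL-style dump of the training split for inspection.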
    dataset = "dgre"
    args = config.Args().get_parser()
    args.bert_dir = '../model_hub/chinese-roberta-wwm-ext/'
    commonUtils.set_logger(os.path.join(args.log_dir, 'preprocess.log'))

    if dataset == "dgre":
        args.data_dir = '../data/dgre/'
        args.max_seq_len = 512
    elif dataset == "duie":
        args.data_dir = '../data/'
        args.max_seq_len = 300

    mid_data_path = os.path.join(args.data_dir, 'mid_data')
    # entity type labels
    ent_labels_path = mid_data_path + '/ent_labels.txt'
    # sequence-labelling labels (B/I/O)
    ner_labels_path = mid_data_path + '/ner_labels.txt'

    with open(ent_labels_path, 'r') as fp:
        ent_labels = fp.read().strip().split('\n')
    entlabel2id = {}
    id2entlabel = {}
    for i, j in enumerate(ent_labels):
        entlabel2id[j] = i
        id2entlabel[i] = j

    nerlabel2id = {}
    id2nerlabel = {}
    with open(ner_labels_path, 'r') as fp:
        ner_labels = fp.read().strip().split('\n')
    for i, j in enumerate(ner_labels):
        nerlabel2id[j] = i
        id2nerlabel[i] = j

    processor = NerProcessor(cut_sent=False, cut_sent_len=args.max_seq_len)

    train_data = get_data(processor, mid_data_path, "train.json", "train", nerlabel2id, ent_labels, args)
    save_file(os.path.join(args.data_dir, "{}_{}_cut.txt".format(dataset, args.max_seq_len)), train_data, id2nerlabel)
    dev_data = get_data(processor, mid_data_path, "dev.json", "dev", nerlabel2id, ent_labels, args)