preprocess.py

import os
import pickle
import logging
import codecs
from transformers import BertTokenizer

try:
    import bert_config
    from bert_re.utils import utils
except Exception as e:
    import bert_config
    from utils import utils

import numpy as np
import json

logger = logging.getLogger(__name__)
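
# Note: bert_config and utils are local modules from this repo, not pip packages; the
# try/except above presumably covers the two ways the script gets run (as part of the
# bert_re package vs. directly from this directory). That reading is an assumption based
# on the two import paths, not something stated in the source.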
class InputExample:
    def __init__(self, set_type, text, labels=None, ids=None):
        self.set_type = set_type
        self.text = text
        self.labels = labels
        self.ids = ids


class BaseFeature:
    def __init__(self, token_ids, attention_masks, token_type_ids):
        # BERT inputs
        self.token_ids = token_ids
        self.attention_masks = attention_masks
        self.token_type_ids = token_type_ids


class BertFeature(BaseFeature):
    def __init__(self, token_ids, attention_masks, token_type_ids, labels=None, ids=None):
        super(BertFeature, self).__init__(
            token_ids=token_ids,
            attention_masks=attention_masks,
            token_type_ids=token_type_ids)
        # labels
        self.labels = labels
        # ids
        self.ids = ids
class Processor:
    @staticmethod
    def read_txt(file_path):
        with codecs.open(file_path, 'r', encoding='utf-8') as f:
            raw_examples = f.read().strip()
        return raw_examples

    def get_examples(self, raw_examples, set_type):
        examples = []
        # fields come from the upstream JSON data, flattened into tab-separated lines:
        # label, text, and four entity position ids
        for line in raw_examples.split('\n'):
            line = line.split('\t')
            if len(line) == 6:
                labels = line[0]
                text = line[1]
                ids = [int(line[2]), int(line[3]), int(line[4]), int(line[5])]
                examples.append(InputExample(set_type=set_type,
                                             text=text,
                                             labels=labels,
                                             ids=ids))
        return examples
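

# Illustrative (hypothetical) input record for get_examples, one tab-separated line:
#   <relation_label>\t<sentence text>\t<id1>\t<id2>\t<id3>\t<id4>
# The four integers are assumed to mark entity positions in the sentence (e.g. subject
# and object spans); they are later checked against max_seq_len in
# convert_examples_to_features, so any example whose entities fall past the maximum
# length is dropped rather than silently truncated.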
def convert_bert_example(ex_idx, example: InputExample, tokenizer: BertTokenizer, max_seq_len, label2id):
    set_type = example.set_type
    raw_text = example.text
    labels = example.labels
    ids = example.ids
    # callback tuple: raw text plus the numeric label
    callback_info = (raw_text,)
    callback_labels = label2id[labels]
    callback_info += (callback_labels,)
    labels = label2id[labels]
    # label_ids = label2id[labels]
    ids = [x for x in ids]
    # tokenize character by character so positions stay aligned with the raw text
    tokens = [i for i in raw_text]
    encode_dict = tokenizer.encode_plus(text=tokens,
                                        add_special_tokens=True,
                                        max_length=max_seq_len,
                                        truncation='longest_first',
                                        padding="max_length",
                                        return_token_type_ids=True,
                                        return_attention_mask=True)
    token_ids = encode_dict['input_ids']
    attention_masks = encode_dict['attention_mask']
    token_type_ids = encode_dict['token_type_ids']
    if ex_idx < 3:
        # log the first few examples as a sanity check
        decode_text = tokenizer.decode(np.array(token_ids)[np.where(np.array(attention_masks) == 1)[0]].tolist())
        logger.info(f"*** {set_type}_example-{ex_idx} ***")
        logger.info(f"text: {decode_text}")
        logger.info(f"token_ids: {token_ids}")
        logger.info(f"attention_masks: {attention_masks}")
        logger.info(f"token_type_ids: {token_type_ids}")
        logger.info(f"labels: {labels}")
        logger.info(f"ids: {ids}")
    feature = BertFeature(
        # bert inputs
        token_ids=token_ids,
        attention_masks=attention_masks,
        token_type_ids=token_type_ids,
        labels=labels,
        ids=ids
    )
    return feature, callback_info
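

# convert_bert_example returns a (feature, callback_info) pair: the BertFeature carries the
# BERT inputs plus the numeric label and position ids, while callback_info keeps
# (raw_text, label_id) for evaluation-time lookup. Because encode_plus is called with
# padding="max_length" and truncation enabled, token_ids, attention_masks and
# token_type_ids all come back with length max_seq_len.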
def convert_examples_to_features(examples, max_seq_len, bert_dir, label2id):
    tokenizer = BertTokenizer(os.path.join(bert_dir, 'vocab.txt'))
    features = []
    callback_info = []
    logger.info(f'Convert {len(examples)} examples to features')
    longer_count = 0
    for i, example in enumerate(examples):
        ids = example.ids
        # skip examples whose entity positions fall outside the truncated sequence
        flag = False
        for x in ids:
            if x > max_seq_len - 1:
                longer_count += 1
                flag = True
                break
        if flag:
            continue
        feature, tmp_callback = convert_bert_example(
            ex_idx=i,
            example=example,
            max_seq_len=max_seq_len,
            tokenizer=tokenizer,
            label2id=label2id,
        )
        if feature is None:
            continue
        features.append(feature)
        callback_info.append(tmp_callback)
    logger.info(f'Build {len(features)} features')
    logger.info(f"Examples exceeding the max length: {longer_count}")
    out = (features,)
    if not len(callback_info):
        return out
    out += (callback_info,)
    return out
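

# The returned value is (features,) when no callback info was collected, otherwise
# (features, callback_info); get_out below pickles this tuple as-is.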
def get_out(processor, txt_path, args, id2label, label2id, mode):
    raw_examples = processor.read_txt(txt_path)
    examples = processor.get_examples(raw_examples, mode)
    # print the first few parsed examples for inspection
    for i, example in enumerate(examples):
        print("==========================")
        print(example.text)
        print(example.labels)
        print(example.ids)
        print("==========================")
        if i == 5:
            break
    out = convert_examples_to_features(examples, args.max_seq_len, args.bert_dir, label2id)

    def save_pkl(data_dir, data, desc):
        """Save data as a .pkl file."""
        with open(os.path.join(data_dir, '{}.pkl'.format(desc)), 'wb') as f:
            pickle.dump(data, f)

    save_path = os.path.join(args.data_dir, 're_final_data')
    if not os.path.exists(save_path):
        os.makedirs(save_path)
    save_pkl(save_path, out, mode)
    return out
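

# get_out writes the feature tuple to {args.data_dir}/re_final_data/{mode}.pkl, i.e.
# train.pkl / dev.pkl / test.pkl for the three calls in the __main__ block below.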
if __name__ == '__main__':
    data_name = "dgre"
    args = bert_config.Args().get_parser()
    args.log_dir = './logs/'
    args.bert_dir = '../model_hub/chinese-roberta-wwm-ext/'
    utils.set_logger(os.path.join(args.log_dir, 'preprocess.log'))
    logger.info(vars(args))
    if data_name == "dgre":
        args.max_seq_len = 512
        args.data_dir = '../data/dgre/'
        re_mid_data_path = '../data/dgre/re_mid_data'
    elif data_name == "duie":
        args.max_seq_len = 300
        re_mid_data_path = '../data/re_mid_data'
    processor = Processor()
    label2id = {}
    id2label = {}
    with open(re_mid_data_path + '/rels.txt', 'r', encoding='utf-8') as fp:
        labels = fp.read().split('\n')
    for i, j in enumerate(labels):
        label2id[j] = i
        id2label[i] = j
    print(label2id)
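    # label2id maps each relation name in rels.txt to its line index, e.g. (hypothetical
    # relation names) {'rel_a': 0, 'rel_b': 1, ...}; id2label is the inverse mapping.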
    train_out = get_out(processor, re_mid_data_path + '/train.txt', args, id2label, label2id, 'train')
    dev_out = get_out(processor, re_mid_data_path + '/dev.txt', args, id2label, label2id, 'dev')
    test_out = get_out(processor, re_mid_data_path + '/dev.txt', args, id2label, label2id, 'test')
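    # Minimal sketch of a typical run (assuming the relative paths above exist and that
    # bert_config.Args exposes max_seq_len, data_dir, bert_dir and log_dir):
    #   python preprocess.py
    # This preprocesses train.txt and dev.txt, and reuses dev.txt as the test split.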