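"""Preprocess tab-separated relation-extraction data into BERT features.

Reads label/text/entity-position lines, encodes them with a BertTokenizer,
and pickles the resulting BertFeature lists for the train/dev/test splits.
"""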
import os
import pickle
import logging
import codecs

import numpy as np
from transformers import BertTokenizer

try:
    import bert_config
    from bert_re.utils import utils
except ImportError:
    import bert_config
    from utils import utils

logger = logging.getLogger(__name__)
class InputExample:
    def __init__(self, set_type, text, labels=None, ids=None):
        self.set_type = set_type
        self.text = text
        self.labels = labels
        self.ids = ids
class BaseFeature:
    def __init__(self, token_ids, attention_masks, token_type_ids):
        # BERT inputs
        self.token_ids = token_ids
        self.attention_masks = attention_masks
        self.token_type_ids = token_type_ids
class BertFeature(BaseFeature):
    def __init__(self, token_ids, attention_masks, token_type_ids, labels=None, ids=None):
        super().__init__(
            token_ids=token_ids,
            attention_masks=attention_masks,
            token_type_ids=token_type_ids)
        # relation label and entity position ids
        self.labels = labels
        self.ids = ids
class Processor:
    @staticmethod
    def read_txt(file_path):
        with codecs.open(file_path, 'r', encoding='utf-8') as f:
            raw_examples = f.read().strip()
        return raw_examples

    def get_examples(self, raw_examples, set_type):
        examples = []
        # Each line is tab-separated: label, text, then four entity position indices.
        for line in raw_examples.split('\n'):
            line = line.split('\t')
            if len(line) == 6:
                labels = line[0]
                text = line[1]
                ids = [int(line[2]), int(line[3]), int(line[4]), int(line[5])]
                examples.append(InputExample(set_type=set_type,
                                             text=text,
                                             labels=labels,
                                             ids=ids))
        return examples
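# A sketch of the expected input line format (tab-separated). The meaning of
# the four indices is assumed here: start/end positions of the two entities
# in the text.
#
#   label<TAB>text<TAB>head_start<TAB>head_end<TAB>tail_start<TAB>tail_end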
def convert_bert_example(ex_idx, example: InputExample, tokenizer: BertTokenizer, max_seq_len, label2id):
    set_type = example.set_type
    raw_text = example.text
    labels = example.labels
    ids = example.ids

    # callback tuple: (raw text, label id), kept for later evaluation
    labels = label2id[labels]
    callback_info = (raw_text, labels)

    ids = list(ids)
    # character-level tokenization: one token per character
    tokens = [i for i in raw_text]
    encode_dict = tokenizer.encode_plus(text=tokens,
                                        add_special_tokens=True,
                                        max_length=max_seq_len,
                                        truncation='longest_first',
                                        padding="max_length",
                                        return_token_type_ids=True,
                                        return_attention_mask=True)
    token_ids = encode_dict['input_ids']
    attention_masks = encode_dict['attention_mask']
    token_type_ids = encode_dict['token_type_ids']

    if ex_idx < 3:
        decode_text = tokenizer.decode(np.array(token_ids)[np.where(np.array(attention_masks) == 1)[0]].tolist())
        logger.info(f"*** {set_type}_example-{ex_idx} ***")
        logger.info(f"text: {decode_text}")
        logger.info(f"token_ids: {token_ids}")
        logger.info(f"attention_masks: {attention_masks}")
        logger.info(f"token_type_ids: {token_type_ids}")
        logger.info(f"labels: {labels}")
        logger.info(f"ids: {ids}")

    feature = BertFeature(
        # BERT inputs
        token_ids=token_ids,
        attention_masks=attention_masks,
        token_type_ids=token_type_ids,
        labels=labels,
        ids=ids,
    )
    return feature, callback_info
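# Minimal usage sketch (all values here are hypothetical; assumes a BERT
# vocab file on disk and a label2id built from rels.txt):
#
#   tok = BertTokenizer('./vocab.txt')
#   ex = InputExample('train', '发动机冷却液温度过高', labels='故障原因', ids=[0, 2, 6, 9])
#   feat, cb = convert_bert_example(0, ex, tok, max_seq_len=64,
#                                   label2id={'故障原因': 0})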
def convert_examples_to_features(examples, max_seq_len, bert_dir, label2id):
    tokenizer = BertTokenizer(os.path.join(bert_dir, 'vocab.txt'))
    features = []
    callback_info = []
    logger.info(f'Convert {len(examples)} examples to features')
    longer_count = 0
    for i, example in enumerate(examples):
        # skip examples whose entity positions fall outside the max length
        ids = example.ids
        flag = False
        for x in ids:
            if x > max_seq_len - 1:
                longer_count += 1
                flag = True
                break
        if flag:
            continue
        feature, tmp_callback = convert_bert_example(
            ex_idx=i,
            example=example,
            max_seq_len=max_seq_len,
            tokenizer=tokenizer,
            label2id=label2id,
        )
        if feature is None:
            continue
        features.append(feature)
        callback_info.append(tmp_callback)
    logger.info(f'Build {len(features)} features')
    logger.info(f"Examples exceeding max_seq_len: {longer_count}")
    out = (features,)
    if not len(callback_info):
        return out
    out += (callback_info,)
    return out
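# The return value is (features,) when no callback info was collected,
# otherwise (features, callback_info). A minimal unpacking sketch:
#
#   out = convert_examples_to_features(examples, 512, bert_dir, label2id)
#   features = out[0]
#   callback_info = out[1] if len(out) > 1 else []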
def get_out(processor, txt_path, args, id2label, label2id, mode):
    raw_examples = processor.read_txt(txt_path)
    examples = processor.get_examples(raw_examples, mode)
    # print a few examples for a quick sanity check
    for i, example in enumerate(examples):
        print("==========================")
        print(example.text)
        print(example.labels)
        print(example.ids)
        print("==========================")
        if i == 5:
            break
    out = convert_examples_to_features(examples, args.max_seq_len, args.bert_dir, label2id)

    def save_pkl(data_dir, data, desc):
        """Save data as {desc}.pkl under data_dir."""
        with open(os.path.join(data_dir, '{}.pkl'.format(desc)), 'wb') as f:
            pickle.dump(data, f)

    save_path = os.path.join(args.data_dir, 're_final_data')
    if not os.path.exists(save_path):
        os.makedirs(save_path)
    save_pkl(save_path, out, mode)
    return out
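# A minimal sketch of reading the saved features back (the path assumes the
# dgre layout configured below):
#
#   with open('../data/dgre/re_final_data/train.pkl', 'rb') as f:
#       out = pickle.load(f)
#   features = out[0]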
if __name__ == '__main__':
    data_name = "dgre"
    args = bert_config.Args().get_parser()
    args.log_dir = './logs/'
    args.bert_dir = '../model_hub/chinese-roberta-wwm-ext/'
    utils.set_logger(os.path.join(args.log_dir, 'preprocess.log'))
    logger.info(vars(args))

    if data_name == "dgre":
        args.max_seq_len = 512
        args.data_dir = '../data/dgre/'
        re_mid_data_path = '../data/dgre/re_mid_data'
    elif data_name == "duie":
        args.max_seq_len = 300
        re_mid_data_path = '../data/re_mid_data'

    processor = Processor()

    # build the label maps from rels.txt (one relation label per line)
    label2id = {}
    id2label = {}
    with open(re_mid_data_path + '/rels.txt', 'r', encoding='utf-8') as fp:
        labels = fp.read().strip().split('\n')
    for i, j in enumerate(labels):
        label2id[j] = i
        id2label[i] = j
    print(label2id)

    train_out = get_out(processor, re_mid_data_path + '/train.txt', args, id2label, label2id, 'train')
    dev_out = get_out(processor, re_mid_data_path + '/dev.txt', args, id2label, label2id, 'dev')
    # the test split reuses dev.txt here
    test_out = get_out(processor, re_mid_data_path + '/dev.txt', args, id2label, label2id, 'test')