import os
import json
import logging

from transformers import BertTokenizer

try:
    from utils import cutSentences, commonUtils
    import config
except ImportError:
    # fall back to package-relative imports when run as a module
    from .utils import cutSentences, commonUtils
    from . import config

logger = logging.getLogger(__name__)


class InputExample:
    def __init__(self, set_type, text, labels=None, subject_labels=None, object_labels=None):
        self.set_type = set_type
        self.text = text
        # labels carries the refactored span labels when sentences are cut;
        # subject_labels / object_labels are used when cut_sent=False
        self.labels = labels
        self.subject_labels = subject_labels
        self.object_labels = object_labels


class BaseFeature:
    def __init__(self, token_ids, attention_masks, token_type_ids):
        # BERT inputs
        self.token_ids = token_ids
        self.attention_masks = attention_masks
        self.token_type_ids = token_type_ids


class BertFeature(BaseFeature):
    def __init__(self, token_ids, attention_masks, token_type_ids, labels=None):
        super(BertFeature, self).__init__(
            token_ids=token_ids,
            attention_masks=attention_masks,
            token_type_ids=token_type_ids)
        # sequence labeling targets
        self.labels = labels


class NerProcessor:
    def __init__(self, cut_sent=True, cut_sent_len=256):
        self.cut_sent = cut_sent
        self.cut_sent_len = cut_sent_len

    @staticmethod
    def read_json(file_path):
        with open(file_path, encoding='utf-8') as f:
            raw_examples = json.load(f)
        return raw_examples

    def get_examples(self, raw_examples, set_type):
        examples = []
        # raw_examples is the list of dicts loaded from the json file
        for i, item in enumerate(raw_examples):
            # print(i, item)
            text = item['text']
            if self.cut_sent:
                sentences = cutSentences.cut_sent_for_bert(text, self.cut_sent_len)
                start_index = 0
                for sent in sentences:
                    labels = cutSentences.refactor_labels(sent, item['labels'], start_index)
                    start_index += len(sent)
                    examples.append(InputExample(set_type=set_type,
                                                 text=sent,
                                                 labels=labels))
            else:
                subject_labels = item['subject_labels']
                object_labels = item['object_labels']
                if len(subject_labels) != 0:
                    subject_labels = [('subject', label[1], label[2]) for label in subject_labels]
                if len(object_labels) != 0:
                    object_labels = [('object', label[1], label[2]) for label in object_labels]
                examples.append(InputExample(set_type=set_type,
                                             text=text,
                                             subject_labels=subject_labels,
                                             object_labels=object_labels))
        return examples
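

# Illustrative only (the concrete values are made up): the shape of one raw item
# that get_examples expects when cut_sent=False, based on the fields accessed
# above. Index 1 of each label entry is the entity string and index 2 its
# character start offset; index 0 is ignored here and replaced by 'subject' /
# 'object'.
_EXAMPLE_RAW_ITEM = {
    'text': '陈元是某公司的法定代表人。',
    'subject_labels': [['T1', '陈元', 0]],
    'object_labels': [['T2', '某公司', 3]],
}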


def convert_bert_example(ex_idx, example: InputExample, tokenizer: BertTokenizer,
                         max_seq_len, nerlabel2id, ent_labels):
    set_type = example.set_type
    raw_text = example.text
    subject_entities = example.subject_labels
    object_entities = example.object_labels
    entities = subject_entities + object_entities
    # text tuple
    callback_info = (raw_text,)
    # label dict: entity type -> list of (entity type, entity text)
    callback_labels = {x: [] for x in ent_labels}
    # _label: (entity type, entity text, entity start position)
    for _label in entities:
        # print(_label)
        callback_labels[_label[0]].append((_label[0], _label[1]))
    callback_info += (callback_labels,)
    # For sequence labeling, BERT's WordPiece tokenizer may shift the label offsets,
    # so the text is split into single characters instead.
    # tokens = commonUtils.fine_grade_tokenize(raw_text, tokenizer)
    tokens = [i for i in raw_text]
    assert len(tokens) == len(raw_text)
    # information for dev callback
    # ========================
    label_ids = [0] * len(tokens)
    # tag entity labels, e.g. (T1, DRUG_DOSAGE, 447, 450, 小蜜丸)
    for ent in entities:
        # ent: ('PER', '陈元', 0)
        ent_type = ent[0]    # entity type
        ent_start = ent[-1]  # start position
        ent_end = ent_start + len(ent[1]) - 1
        if ent_start == ent_end:
            label_ids[ent_start] = nerlabel2id['B-' + ent_type]
        else:
            try:
                label_ids[ent_start] = nerlabel2id['B-' + ent_type]
                label_ids[ent_end] = nerlabel2id['I-' + ent_type]
                for i in range(ent_start + 1, ent_end):
                    label_ids[i] = nerlabel2id['I-' + ent_type]
            except Exception:
                # surface the offending entity instead of silently exiting
                logger.error(f'failed to tag entity {ent} in tokens {tokens}')
                raise
    if len(label_ids) > max_seq_len - 2:
        label_ids = label_ids[:max_seq_len - 2]
    label_ids = [0] + label_ids + [0]
    # pad
    if len(label_ids) < max_seq_len:
        pad_length = max_seq_len - len(label_ids)
        label_ids = label_ids + [0] * pad_length  # [CLS], [SEP] and [PAD] are all labeled O
    assert len(label_ids) == max_seq_len, f'{len(label_ids)}'
    # ========================
    encode_dict = tokenizer.encode_plus(text=tokens,
                                        max_length=max_seq_len,
                                        padding='max_length',
                                        truncation='longest_first',
                                        return_token_type_ids=True,
                                        return_attention_mask=True)
    tokens = ['[CLS]'] + tokens + ['[SEP]']
    token_ids = encode_dict['input_ids']
    attention_masks = encode_dict['attention_mask']
    token_type_ids = encode_dict['token_type_ids']
    if ex_idx < 3:
        logger.info(f"*** {set_type}_example-{ex_idx} ***")
        # decode the encoded text (including [CLS]/[SEP]) as a sanity check
        logger.info(tokenizer.decode(token_ids[:len(raw_text) + 2]))
        logger.info(f'text: {" ".join(tokens)}')
        logger.info(f"token_ids: {token_ids}")
        logger.info(f"attention_masks: {attention_masks}")
        logger.info(f"token_type_ids: {token_type_ids}")
        logger.info(f"labels: {label_ids}")
        logger.info('length: ' + str(len(token_ids)))
        # for word, token, attn, label in zip(tokens, token_ids, attention_masks, label_ids):
        #     print(word + ' ' + str(token) + ' ' + str(attn) + ' ' + str(label))

    feature = BertFeature(
        # bert inputs
        token_ids=token_ids,
        attention_masks=attention_masks,
        token_type_ids=token_type_ids,
        labels=label_ids,
    )
    return feature, callback_info
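

# Worked example (illustrative; assumes a hypothetical nerlabel2id of
# {'O': 0, 'B-subject': 1, 'I-subject': 2, 'B-object': 3, 'I-object': 4}):
# for raw_text '陈元是某公司的法定代表人。' with entities ('subject', '陈元', 0) and
# ('object', '某公司', 3), the tagging loop above produces the per-character tags
#   陈/B-subject 元/I-subject 是/O 某/B-object 公/I-object 司/I-object 的/O ... 。/O
# i.e. label_ids = [1, 2, 0, 3, 4, 4, 0, 0, 0, 0, 0, 0, 0] before [CLS]/[SEP]/padding.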


def convert_examples_to_features(examples, max_seq_len, bert_dir, nerlabel2id, ent_labels):
    tokenizer = BertTokenizer(os.path.join(bert_dir, 'vocab.txt'))
    features = []
    callback_info = []
    logger.info(f'Convert {len(examples)} examples to features')
    for i, example in enumerate(examples):
        """
        subject_entities = example.subject_labels
        object_entities = example.object_labels
        entities = subject_entities + object_entities
        flag = False
        for ent in entities:
            start_id = ent[1]
            end_id = len(ent[0]) + ent[1]
            if start_id >= max_seq_len - 2 or end_id >= max_seq_len - 2:
                flag = True
                break
        if flag:
            continue
        """
        feature, tmp_callback = convert_bert_example(
            ex_idx=i,
            example=example,
            max_seq_len=max_seq_len,
            nerlabel2id=nerlabel2id,
            tokenizer=tokenizer,
            ent_labels=ent_labels,
        )
        if feature is None:
            continue
        features.append(feature)
        callback_info.append(tmp_callback)
    logger.info(f'Build {len(features)} features')
    out = (features,)
    if not len(callback_info):
        return out
    out += (callback_info,)
    return out
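

# Minimal usage sketch (an assumption, not part of this module's pipeline): how the
# BertFeature list produced above could be collated into tensors for training.
# Field names match the BertFeature class defined in this file; torch is assumed
# to be available.
def build_loader(features, batch_size=32):
    import torch
    from torch.utils.data import TensorDataset, DataLoader
    token_ids = torch.tensor([f.token_ids for f in features], dtype=torch.long)
    attention_masks = torch.tensor([f.attention_masks for f in features], dtype=torch.long)
    token_type_ids = torch.tensor([f.token_type_ids for f in features], dtype=torch.long)
    labels = torch.tensor([f.labels for f in features], dtype=torch.long)
    dataset = TensorDataset(token_ids, attention_masks, token_type_ids, labels)
    return DataLoader(dataset, batch_size=batch_size, shuffle=True)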


def get_data(processor, raw_data_path, json_file, mode, nerlabel2id, ent_labels, args):
    raw_examples = processor.read_json(os.path.join(raw_data_path, json_file))
    examples = processor.get_examples(raw_examples, mode)
    data = convert_examples_to_features(examples, args.max_seq_len, args.bert_dir, nerlabel2id, ent_labels)
    save_path = os.path.join(args.data_dir, 'ner_final_data')
    if not os.path.exists(save_path):
        os.makedirs(save_path)
    commonUtils.save_pkl(save_path, data, mode)
    return data


def save_file(filename, data, id2nerlabel):
    features, callback_info = data
    with open(filename, 'w', encoding='utf-8') as file:
        for feature, tmp_callback in zip(features, callback_info):
            text, gt_entities = tmp_callback
            # feature.labels[0] corresponds to [CLS], so skip it to align with the raw text
            for word, label in zip(text, feature.labels[1:len(text) + 1]):
                file.write(word + ' ' + id2nerlabel[label] + '\n')
            file.write('\n')
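

# Each example is written in a simple "char label" per-line format with a blank
# line between examples, e.g. for the illustrative sentence used earlier:
#   陈 B-subject
#   元 I-subject
#   是 O
#   某 B-object
#   ...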


if __name__ == '__main__':
    dataset = "dgre"
    args = config.Args().get_parser()
    args.bert_dir = '../model_hub/chinese-roberta-wwm-ext/'
    commonUtils.set_logger(os.path.join(args.log_dir, 'preprocess.log'))

    if dataset == "dgre":
        args.data_dir = '../data/dgre/'
        args.max_seq_len = 512
    elif dataset == "duie":
        args.data_dir = '../data/'
        args.max_seq_len = 300

    mid_data_path = os.path.join(args.data_dir, 'mid_data')
    # entity type labels
    ent_labels_path = mid_data_path + '/ent_labels.txt'
    # sequence labeling labels (B/I/O scheme)
    ner_labels_path = mid_data_path + '/ner_labels.txt'
    with open(ent_labels_path, 'r', encoding='utf-8') as fp:
        ent_labels = fp.read().strip().split('\n')
    entlabel2id = {}
    id2entlabel = {}
    for i, j in enumerate(ent_labels):
        entlabel2id[j] = i
        id2entlabel[i] = j
    nerlabel2id = {}
    id2nerlabel = {}
    with open(ner_labels_path, 'r', encoding='utf-8') as fp:
        ner_labels = fp.read().strip().split('\n')
    for i, j in enumerate(ner_labels):
        nerlabel2id[j] = i
        id2nerlabel[i] = j

    processor = NerProcessor(cut_sent=False, cut_sent_len=args.max_seq_len)

    train_data = get_data(processor, mid_data_path, "train.json", "train", nerlabel2id, ent_labels, args)
    save_file(os.path.join(args.data_dir, "{}_{}_cut.txt".format(dataset, args.max_seq_len)),
              train_data, id2nerlabel)
    dev_data = get_data(processor, mid_data_path, "dev.json", "dev", nerlabel2id, ent_labels, args)