# models.py (4.3 KB)
  1. from transformers import BertModel
  2. import torch.nn as nn
  3. import numpy as np
  4. import torch
  5. try:
  6. import bert_config
  7. except Exception as e:
  8. import bert_config
  9. import json
  10. from torch.utils.data import Dataset, DataLoader, RandomSampler
# BertFeature must be imported explicitly here; otherwise an error is raised.
  12. try:
  13. from bert_re.dataset import ReDataset
  14. from preprocess import BertFeature
  15. from bert_re.preprocess import get_out, Processor
  16. except Exception as e:
  17. from bert_re.dataset import ReDataset
  18. from bert_re.preprocess import BertFeature
  19. from bert_re.preprocess import get_out, Processor
  20. class BertForRelationExtraction(nn.Module):
  21. def __init__(self, args):
  22. super(BertForRelationExtraction, self).__init__()
  23. self.bert = BertModel.from_pretrained(args.bert_dir)
  24. self.bert_config = self.bert.config
  25. out_dims = self.bert_config.hidden_size
  26. self.dropout = nn.Dropout(args.dropout_prob)
  27. self.linear = nn.Linear(out_dims * 4, args.num_tags)
  28. def forward(self, token_ids, attention_masks, token_type_ids, ids):
  29. bert_outputs = self.bert(
  30. input_ids = token_ids,
  31. attention_mask = attention_masks,
  32. token_type_ids = token_type_ids,
  33. )
  34. # 获取每一个token的嵌入
  35. seq_out = bert_outputs[0] # [batchsize, max_len, 768]
  36. batch_size = seq_out.size(0)
  37. seq_ent = torch.cat([torch.index_select(seq_out[i,:,:],0,ids[i,:].long()).unsqueeze(0) for i in range(batch_size)], 0) # [batchsize, 4, 768]
  38. # print(seq_ent.shape)
  39. seq_ent = self.dropout(seq_ent)
  40. seq_ent = seq_ent.view(batch_size, -1)
  41. seq_ent = self.linear(seq_ent)
  42. return seq_ent
  43. if __name__ == '__main__':
  44. """更换了新的数据读取的方式,以下不适用了"""
  45. args = bert_config.Args().get_parser()
  46. args.log_dir = './logs/'
  47. args.max_seq_len = 128
  48. args.bert_dir = '../model_hub/bert-base-chinese/'
  49. processor = Processor()
  50. label2id = {}
  51. id2label = {}
  52. with open('./data/rel_dict.json', 'r',encoding='utf-8') as fp:
  53. labels = json.loads(fp.read())
  54. for k, v in labels.items():
  55. label2id[k] = v
  56. id2label[v] = k
  57. print(label2id)
  58. train_out = get_out(processor, './data/train.txt', args, id2label, 'train')
  59. dev_out = get_out(processor, './data/test.txt', args, id2label, 'dev')
  60. test_out = get_out(processor, './data/test.txt', args, id2label, 'test')
  61. # import pickle
  62. # with open('./data/cnews/final_data/train.pickle','wb') as fp:
  63. # pickle.dump(train_out, fp)
  64. # with open('./data/cnews/final_data/dev.pickle','wb') as fp:
  65. # pickle.dump(dev_out, fp)
  66. # with open('./data/cnews/final_data/test.pickle','wb') as fp:
  67. # pickle.dump(test_out, fp)
  68. #
  69. # train_out = pickle.load(open('./data/cnews/final_data/dev.pickle','rb'))
  70. train_features, train_callback_info = train_out
  71. train_dataset = ReDataset(train_features)
  72. # for data in train_dataset:
  73. # print(data['token_ids'])
  74. # print(data['attention_masks'])
  75. # print(data['token_type_ids'])
  76. # print(data['labels'])
  77. # print(data['ids'])
  78. # break
  79. args.train_batch_size = 32
  80. train_dataset = ReDataset(train_features)
  81. train_sampler = RandomSampler(train_dataset)
  82. train_loader = DataLoader(dataset=train_dataset,
  83. batch_size=args.train_batch_size,
  84. sampler=train_sampler,
  85. num_workers=2)
  86. device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
  87. model = BertForRelationExtraction(args, device)
  88. model.to(device)
  89. model.train()
  90. for step, train_data in enumerate(train_loader):
  91. # print(train_data['token_ids'].shape)
  92. # print(train_data['attention_masks'].shape)
  93. # print(train_data['token_type_ids'].shape)
  94. # print(train_data['labels'])
  95. # print(train_data['ids'].shape)
  96. token_ids = train_data['token_ids'].to(device)
  97. attention_masks = train_data['attention_masks'].to(device)
  98. token_type_ids = train_data['token_type_ids'].to(device)
  99. label_ids = train_data['labels'].to(device)
  100. ids = train_data['ids'].to(device)
  101. output = model(token_ids,attention_masks,token_type_ids,ids)
  102. if step == 1:
  103. break