# bert_ner_model.py
  1. import torch
  2. import torch.nn as nn
  3. from torchcrf import CRF
  4. # import config
  5. try:
  6. from bert_base_model import BaseModel
  7. except Exception as e:
  8. from bert_base_model import BaseModel
  9. class BertNerModel(BaseModel):
  10. def __init__(self,
  11. args,
  12. **kwargs):
  13. super(BertNerModel, self).__init__(bert_dir=args.bert_dir, dropout_prob=args.dropout_prob)
  14. self.args = args
  15. self.num_layers = args.num_layers
  16. self.lstm_hidden = args.lstm_hidden
  17. gpu_ids = args.gpu_ids.split(',')
  18. device = torch.device("cpu" if gpu_ids[0] == '-1' else "cuda:" + gpu_ids[0])
  19. self.device = device
  20. out_dims = self.bert_config.hidden_size
  21. if args.use_lstm == 'True':
  22. self.lstm = nn.LSTM(out_dims, args.lstm_hidden, args.num_layers, bidirectional=True,batch_first=True, dropout=args.dropout).to(device)
  23. self.linear = nn.Linear(args.lstm_hidden * 2, args.num_tags).to(device)
  24. self.criterion = nn.CrossEntropyLoss().to(device)
  25. init_blocks = [self.linear]
  26. # init_blocks = [self.classifier]
  27. self._init_weights(init_blocks, initializer_range=self.bert_config.initializer_range)
  28. else:
  29. mid_linear_dims = kwargs.pop('mid_linear_dims', 256)
  30. self.mid_linear = nn.Sequential(
  31. nn.Linear(out_dims, mid_linear_dims),
  32. nn.ReLU(),
  33. nn.Dropout(args.dropout)).to(device)
  34. #
  35. out_dims = mid_linear_dims
  36. # self.dropout = nn.Dropout(dropout_prob)
  37. self.classifier = nn.Linear(out_dims, args.num_tags).to(device)
  38. # self.criterion = nn.CrossEntropyLoss(reduction='none')
  39. self.criterion = nn.CrossEntropyLoss().to(device)
  40. init_blocks = [self.mid_linear, self.classifier]
  41. # init_blocks = [self.classifier]
  42. self._init_weights(init_blocks, initializer_range=self.bert_config.initializer_range)
  43. if args.use_crf == 'True':
  44. self.crf = CRF(args.num_tags, batch_first=True).to(device)
  45. def init_hidden(self, batch_size):
  46. h0 = torch.randn(2 * self.num_layers, batch_size, self.lstm_hidden, requires_grad=True).to(self.device)
  47. c0 = torch.randn(2 * self.num_layers, batch_size, self.lstm_hidden, requires_grad=True).to(self.device)
  48. return h0, c0
  49. def forward(self,
  50. token_ids,
  51. attention_masks,
  52. token_type_ids,
  53. labels):
  54. bert_outputs = self.bert_module(
  55. input_ids=token_ids,
  56. attention_mask=attention_masks,
  57. token_type_ids=token_type_ids
  58. )
  59. # 常规
  60. seq_out = bert_outputs[0] # [batchsize, max_len, 768]
  61. batch_size = seq_out.size(0)
  62. if self.args.use_lstm == 'True':
  63. hidden = self.init_hidden(batch_size)
  64. seq_out, (hn, _) = self.lstm(seq_out, hidden)
  65. seq_out = seq_out.contiguous().view(-1, self.lstm_hidden * 2)
  66. seq_out = self.linear(seq_out)
  67. seq_out = seq_out.contiguous().view(batch_size, self.args.max_seq_len, -1) #[batchsize, max_len, num_tags]
  68. else:
  69. seq_out = self.mid_linear(seq_out) # [batchsize, max_len, 256]
  70. # seq_out = self.dropout(seq_out)
  71. seq_out = self.classifier(seq_out) # [24, 256, 53]
  72. if self.args.use_crf == 'True':
  73. logits = self.crf.decode(seq_out, mask=attention_masks)
  74. if labels is None:
  75. return logits
  76. loss = -self.crf(seq_out, labels, mask=attention_masks, reduction='mean')
  77. outputs = (loss, ) + (logits,)
  78. return outputs
  79. else:
  80. logits = seq_out
  81. if labels is None:
  82. return logits
  83. active_loss = attention_masks.view(-1) == 1
  84. active_logits = logits.view(-1, logits.size()[2])[active_loss]
  85. active_labels = labels.view(-1)[active_loss]
  86. loss = self.criterion(active_logits, active_labels)
  87. outputs = (loss,) + (logits,)
  88. return outputs
  89. if __name__ == '__main__':
  90. args = config.Args().get_parser()
  91. args.num_tags = 33
  92. args.use_lstm = 'True'
  93. args.use_crf = 'True'
  94. model = BertNerModel(args)
  95. for name,weight in model.named_parameters():
  96. print(name)