bert_base_model.py

import os

import torch.nn as nn
from transformers import BertModel


class BaseModel(nn.Module):
    def __init__(self, bert_dir, dropout_prob):
        super(BaseModel, self).__init__()
        config_path = os.path.join(bert_dir, 'config.json')

        # assert os.path.exists(bert_dir) and os.path.exists(config_path), \
        #     'pretrained bert file does not exist'

        # output_hidden_states=True exposes every encoder layer's hidden states;
        # hidden_dropout_prob passed here overrides the value in config.json.
        self.bert_module = BertModel.from_pretrained(bert_dir,
                                                     output_hidden_states=True,
                                                     hidden_dropout_prob=dropout_prob)
        self.bert_config = self.bert_module.config
    @staticmethod
    def _init_weights(blocks, **kwargs):
        """
        Parameter initialization: initialize Linear / Embedding / LayerNorm
        layers the same way BERT initializes its own.
        """
        # Read the range once with get() rather than pop(), so every module
        # in every block sees the same value instead of only the first one.
        initializer_range = kwargs.get('initializer_range', 0.02)
        for block in blocks:
            for module in block.modules():
                if isinstance(module, nn.Linear):
                    # BERT draws Linear and Embedding weights from
                    # N(0, initializer_range) and zeroes biases.
                    nn.init.normal_(module.weight, mean=0, std=initializer_range)
                    if module.bias is not None:
                        nn.init.zeros_(module.bias)
                elif isinstance(module, nn.Embedding):
                    nn.init.normal_(module.weight, mean=0, std=initializer_range)
                elif isinstance(module, nn.LayerNorm):
                    nn.init.ones_(module.weight)
                    nn.init.zeros_(module.bias)
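

# --- Usage sketch (illustrative, not from the original file) ---
# A minimal example of how BaseModel and _init_weights fit together: a
# hypothetical task head stacks a Linear classifier on BERT's pooled output
# and re-initializes it with BERT's own initializer_range. ClassifierModel,
# num_labels, and the forward() logic are assumptions for illustration only.
class ClassifierModel(BaseModel):
    def __init__(self, bert_dir, dropout_prob, num_labels):
        super(ClassifierModel, self).__init__(bert_dir, dropout_prob)
        self.classifier = nn.Linear(self.bert_config.hidden_size, num_labels)
        # Initialize the newly added block the same way BERT initializes
        # its own layers.
        self._init_weights([self.classifier],
                           initializer_range=self.bert_config.initializer_range)

    def forward(self, input_ids, attention_mask):
        bert_out = self.bert_module(input_ids=input_ids,
                                    attention_mask=attention_mask)
        # pooler_output: (batch_size, hidden_size) summary of each sequence
        return self.classifier(bert_out.pooler_output)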