dataset.py

import pickle
import torch
from torch.utils.data import Dataset, DataLoader, RandomSampler

# BertFeature must be imported explicitly here, otherwise loading the
# pickled features will raise an error.
try:
    from bert_re.preprocess import BertFeature
    from bert_re.preprocess import get_out, Processor
    import bert_config
except ImportError:
    from .preprocess import BertFeature
    from .preprocess import get_out, Processor
    from . import bert_config
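
# Note: the try branch covers running this file directly with the repository
# root on sys.path; the except branch covers importing dataset.py as part of
# the bert_re package, where relative imports are required.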


class ReDataset(Dataset):
    """Wraps a list of preprocessed BertFeature objects as a map-style Dataset."""

    def __init__(self, features):
        self.nums = len(features)
        self.features = features

    def __len__(self):
        return self.nums

    def __getitem__(self, index):
        example = self.features[index]
        data = {
            'token_ids': torch.tensor(example.token_ids).long(),
            'attention_masks': torch.tensor(example.attention_masks).float(),
            'token_type_ids': torch.tensor(example.token_type_ids).long(),
        }
        data['labels'] = torch.tensor(example.labels).long()
        data['ids'] = torch.tensor(example.ids).long()
        return data
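
# For reference, __getitem__ only relies on five attributes of each feature.
# A minimal stand-in would look like this (a sketch; the real BertFeature is
# defined in preprocess.py and may carry more fields):
#
#     class BertFeature:
#         def __init__(self, token_ids, attention_masks, token_type_ids,
#                      labels, ids):
#             self.token_ids = token_ids              # input token ids
#             self.attention_masks = attention_masks  # 1 = real token, 0 = pad
#             self.token_type_ids = token_type_ids    # segment ids
#             self.labels = labels                    # relation label id(s)
#             self.ids = ids                          # auxiliary indices
#                                                     # (e.g. entity positions)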


if __name__ == '__main__':
    args = bert_config.Args().get_parser()
    args.log_dir = './logs/'
    args.max_seq_len = 300
    args.bert_dir = '../model_hub/chinese-roberta-wwm-ext/'

    processor = Processor()

    # Build the label maps from rels.txt: one relation name per line,
    # with the line order fixing the label ids.
    label2id = {}
    id2label = {}
    with open('../data/re_mid_data/rels.txt', 'r') as fp:
        labels = fp.read().strip().split('\n')
    for i, label in enumerate(labels):
        label2id[label] = i
        id2label[i] = label
    print(label2id)
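
    # The commented calls below rebuild the features from the raw split files;
    # presumably this is what produced the cached .pkl loaded afterwards (see
    # preprocess.py), so they only need to run when the data or config changes.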
    # train_out = get_out(processor, './data/train.txt', args, id2label, 'train')
    # dev_out = get_out(processor, './data/test.txt', args, id2label, 'dev')
    # test_out = get_out(processor, './data/test.txt', args, id2label, 'test')

    with open('../data/re_final_data/train.pkl', 'rb') as f:
        train_out = pickle.load(f)
    train_features, train_callback_info = train_out

    # Sanity check: inspect the tensors of a single example.
    train_dataset = ReDataset(train_features)
    for data in train_dataset:
        print(data['token_ids'])
        print(data['attention_masks'])
        print(data['token_type_ids'])
        print(data['labels'])
        print(data['ids'])
        break
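
    # Each tensor printed above belongs to one example. Assuming preprocess
    # pads or truncates every sequence to args.max_seq_len, token_ids,
    # attention_masks and token_type_ids are each 1-D tensors of length 300.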

    args.train_batch_size = 2
    train_dataset = ReDataset(train_features)
    train_sampler = RandomSampler(train_dataset)
    train_loader = DataLoader(dataset=train_dataset,
                              batch_size=args.train_batch_size,
                              sampler=train_sampler,
                              num_workers=2)
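
    # DataLoader's default collate_fn understands dict samples: it stacks the
    # value of each key across the batch, so every train_data below is a dict
    # of tensors with a leading batch dimension.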
    for step, train_data in enumerate(train_loader):
        print(train_data['token_ids'].shape)
        print(train_data['attention_masks'].shape)
        print(train_data['token_type_ids'].shape)
        print(train_data['labels'])
        print(train_data['ids'])
        break
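
    # With batch_size=2 and max_seq_len=300 (and assuming fixed-length padding
    # in preprocess), the three .shape prints above should each show
    # torch.Size([2, 300]); labels and ids are batched the same way, with
    # shapes determined by how preprocess defines them.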