data_loader.py

import json
import torch
import numpy as np
from torch.utils.data import DataLoader, Dataset

try:
    from utils.utils import sequence_padding
except ImportError:
    from .utils.utils import sequence_padding
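
# `sequence_padding` lives in utils/utils.py and is not shown in this file. For
# reference only, here is a minimal sketch of what it is assumed to do, based on
# how it is called in Collate.collate_fn below: pad or truncate each sequence to
# `length` with a fill value and return an array. The name and signature of this
# sketch are hypothetical; the project's own implementation may differ.
def _sequence_padding_sketch(inputs, length=None, value=0):
    """Pad/truncate a list of sequences to the same length and return a numpy array."""
    if length is None:
        length = max(len(x) for x in inputs)
    padded = []
    for x in inputs:
        x = list(x)[:length]
        padded.append(x + [value] * (length - len(x)))
    return np.array(padded)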


class ListDataset(Dataset):
    """Generic dataset that either loads examples from a file or wraps an in-memory list."""

    def __init__(self, file_path=None, data=None, **kwargs):
        self.kwargs = kwargs
        if isinstance(file_path, (str, list)):
            self.data = self.load_data(file_path)
        elif isinstance(data, list):
            self.data = data
        else:
            raise ValueError('Pass either file_path (str or list) or data (list)')

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        return self.data[index]

    @staticmethod
    def load_data(file_path):
        return file_path


# Load the dataset
class MyDataset(ListDataset):
    @staticmethod
    def load_data(filename):
        examples = []
        with open(filename, encoding='utf-8') as f:
            raw_examples = f.readlines()
        # Each line is a tab-separated record: label, text, and four entity position indices
        for i, item in enumerate(raw_examples):
            print(i, item)
            item = item.strip().split('\t')
            if len(item) < 6:  # need label, text and four indices
                continue
            text = item[1]
            labels = item[0]
            ids = item[2:6]
            examples.append((text, labels, ids))  # note: the indices in ids already include the [CLS] offset
        return examples
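
# Assumed layout of each line in train.txt (inferred from load_data above; the
# field names are illustrative, not confirmed by the source):
#   relation_label \t sentence_text \t idx_1 \t idx_2 \t idx_3 \t idx_4
# where the four indices are token positions that already include the +1 offset
# introduced by the [CLS] token.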


class Collate:
    def __init__(self, max_len, tag2id, device, tokenizer):
        self.maxlen = max_len
        self.tag2id = tag2id
        self.id2tag = {v: k for k, v in tag2id.items()}
        self.device = device
        self.tokenizer = tokenizer

    def collate_fn(self, batch):
        batch_labels = []
        batch_ids = []
        batch_token_ids = []
        batch_attention_mask = []
        batch_token_type_ids = []
        callback = []
        for i, (text, label, ids) in enumerate(batch):
            if len(text) > self.maxlen - 2:
                text = text[:self.maxlen - 2]
            tokens = [ch for ch in text]  # character-level tokens
            tokens = ['[CLS]'] + tokens + ['[SEP]']
            # Skip examples whose entity indices fall beyond the maximum text length
            flag = False
            for j in ids:
                if int(j) > self.maxlen - 2:
                    flag = True
                    break
            if flag:
                continue
            token_ids = self.tokenizer.convert_tokens_to_ids(tokens)
            batch_token_ids.append(token_ids)  # the length has already been limited above
            batch_attention_mask.append([1] * len(token_ids))
            batch_token_type_ids.append([0] * len(token_ids))
            batch_labels.append(int(self.tag2id[label]))
            batch_ids.append([int(m) for m in ids])
            callback.append((text, label))
        # Pad every sequence to maxlen and build tensors on the target device:
        # token ids / attention mask / token type ids are (batch, maxlen),
        # labels are (batch,), ids are (batch, 4)
        batch_token_ids = torch.tensor(sequence_padding(batch_token_ids, length=self.maxlen), dtype=torch.long, device=self.device)
        attention_mask = torch.tensor(sequence_padding(batch_attention_mask, length=self.maxlen), dtype=torch.long, device=self.device)
        token_type_ids = torch.tensor(sequence_padding(batch_token_type_ids, length=self.maxlen), dtype=torch.long, device=self.device)
        batch_labels = torch.tensor(batch_labels, dtype=torch.long, device=self.device)
        batch_ids = torch.tensor(batch_ids, dtype=torch.long, device=self.device)
        return batch_token_ids, attention_mask, token_type_ids, batch_labels, batch_ids, callback


if __name__ == "__main__":
    from transformers import BertTokenizer

    max_len = 300
    tokenizer = BertTokenizer.from_pretrained('../model_hub/chinese-roberta-wwm-ext/vocab.txt')
    train_dataset = MyDataset(file_path='../data/dgre/re_mid_data/train.txt')
    # print(train_dataset[0])

    with open('../data/dgre/re_mid_data/rels.txt', 'r', encoding='utf-8') as fp:
        labels = fp.read().split('\n')
    id2tag = {}
    tag2id = {}
    for i, label in enumerate(labels):
        id2tag[i] = label
        tag2id[label] = i
    print(tag2id)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    collate = Collate(max_len=max_len, tag2id=tag2id, device=device, tokenizer=tokenizer)
    collate.collate_fn(train_dataset[:16])

    batch_size = 2
    train_dataset = train_dataset[:10]
    train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, collate_fn=collate.collate_fn)
    for i, batch in enumerate(train_dataloader):
        leng = len(batch) - 1  # the last element is the callback list, not a tensor
        for j in range(leng):
            print(batch[j].shape)
        break