dataset.py 1.2 KB

123456789101112131415161718192021222324252627282930313233343536
  1. import torch
  2. from torch.utils.data import Dataset, DataLoader
  3. try:
  4. from preprocess import BertFeature
  5. from utils import commonUtils
  6. except Exception as e:
  7. from .preprocess import BertFeature
  8. from .utils import commonUtils
  9. else:
  10. from preprocess import BertFeature
  11. from utils import commonUtils
  12. class NerDataset(Dataset):
  13. def __init__(self, features):
  14. # self.callback_info = callback_info
  15. self.nums = len(features)
  16. self.token_ids = [torch.tensor(example.token_ids).long() for example in features]
  17. self.attention_masks = [torch.tensor(example.attention_masks, dtype=torch.uint8) for example in features]
  18. self.token_type_ids = [torch.tensor(example.token_type_ids).long() for example in features]
  19. self.labels = [torch.tensor(example.labels).long() for example in features]
  20. def __len__(self):
  21. return self.nums
  22. def __getitem__(self, index):
  23. data = {
  24. 'token_ids': self.token_ids[index],
  25. 'attention_masks': self.attention_masks[index],
  26. 'token_type_ids': self.token_type_ids[index]
  27. }
  28. data['labels'] = self.labels[index]
  29. return data