import torch from torch.utils.data import Dataset, DataLoader try: from preprocess import BertFeature from utils import commonUtils except Exception as e: from .preprocess import BertFeature from .utils import commonUtils else: from preprocess import BertFeature from utils import commonUtils class NerDataset(Dataset): def __init__(self, features): # self.callback_info = callback_info self.nums = len(features) self.token_ids = [torch.tensor(example.token_ids).long() for example in features] self.attention_masks = [torch.tensor(example.attention_masks, dtype=torch.uint8) for example in features] self.token_type_ids = [torch.tensor(example.token_type_ids).long() for example in features] self.labels = [torch.tensor(example.labels).long() for example in features] def __len__(self): return self.nums def __getitem__(self, index): data = { 'token_ids': self.token_ids[index], 'attention_masks': self.attention_masks[index], 'token_type_ids': self.token_type_ids[index] } data['labels'] = self.labels[index] return data