import torch
import numpy as np
from torch.utils.data import DataLoader, Dataset
try:
    from utils.utils import sequence_padding
except ImportError:  # fall back to a package-relative import
    try:
        from .utils.utils import sequence_padding
    except ImportError:
        sequence_padding = None  # defined below as a fallback sketch
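# Fallback sketch, used only if the project's utils helper could not be
# imported above. Assumption: the real sequence_padding right-pads (or
# truncates) each sequence to `length` with `value`, which matches how it
# is called in Collate.collate_fn below.
if sequence_padding is None:
    def sequence_padding(inputs, length=None, value=0):
        if length is None:
            length = max(len(x) for x in inputs)
        outputs = []
        for x in inputs:
            x = list(x)[:length]  # truncate anything longer than `length`
            outputs.append(x + [value] * (length - len(x)))  # right-pad
        return np.array(outputs)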

class ListDataset(Dataset):
    def __init__(self, file_path=None, data=None, **kwargs):
        self.kwargs = kwargs
        if isinstance(file_path, (str, list)):
            self.data = self.load_data(file_path)
        elif isinstance(data, list):
            self.data = data
        else:
            raise ValueError('Pass either file_path (str or list of paths) or data (list of examples)')

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        return self.data[index]

    @staticmethod
    def load_data(file_path):
        # Placeholder: subclasses override this to parse the file into a list.
        return file_path
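
# ListDataset covers two construction paths: pass file_path and it calls
# load_data (overridden by subclasses) to parse the file, or pass an
# already-built list of examples via data, e.g.
#   MyDataset(file_path='train.txt')  or  ListDataset(data=[(text, label, ids), ...])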

# Load the dataset.
class MyDataset(ListDataset):
    @staticmethod
    def load_data(filename):
        examples = []
        with open(filename, encoding='utf-8') as f:
            raw_examples = f.readlines()
        # Each line is tab-separated: label, text, then four entity indices.
        for item in raw_examples:
            item = item.strip().split('\t')
            if len(item) < 6:  # need label, text and all four indices
                continue
            text = item[1]
            labels = item[0]
            ids = item[2:6]
            # Note: the indices in ids already account for the [CLS] token.
            examples.append((text, labels, ids))
        return examples
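
# Hypothetical example of one train.txt line under the assumed format
# "label\ttext\tsubj_start\tsubj_end\tobj_start\tobj_end" (indices already
# offset by +1 for [CLS]); the actual dgre data may differ:
#   部件故障\t发动机在怠速时有异响\t1\t3\t8\t10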

class Collate:
    def __init__(self, max_len, tag2id, device, tokenizer):
        self.maxlen = max_len
        self.tag2id = tag2id
        self.id2tag = {v: k for k, v in tag2id.items()}
        self.device = device
        self.tokenizer = tokenizer

    def collate_fn(self, batch):
        batch_labels = []
        batch_ids = []
        batch_token_ids = []
        batch_attention_mask = []
        batch_token_type_ids = []
        callback = []
        for text, label, ids in batch:
            if len(text) > self.maxlen - 2:
                text = text[:self.maxlen - 2]  # reserve room for [CLS] and [SEP]
            tokens = [ch for ch in text]  # character-level tokens for Chinese text
            tokens = ['[CLS]'] + tokens + ['[SEP]']
            # Skip examples whose entity indices fall beyond the truncated text.
            if any(int(j) > self.maxlen - 2 for j in ids):
                continue
            token_ids = self.tokenizer.convert_tokens_to_ids(tokens)
            batch_token_ids.append(token_ids)  # length already capped above
            batch_attention_mask.append([1] * len(token_ids))
            batch_token_type_ids.append([0] * len(token_ids))
            batch_labels.append(int(self.tag2id[label]))
            batch_ids.append([int(m) for m in ids])
            callback.append((text, label))
        batch_token_ids = torch.tensor(sequence_padding(batch_token_ids, length=self.maxlen), dtype=torch.long, device=self.device)
        attention_mask = torch.tensor(sequence_padding(batch_attention_mask, length=self.maxlen), dtype=torch.long, device=self.device)
        token_type_ids = torch.tensor(sequence_padding(batch_token_type_ids, length=self.maxlen), dtype=torch.long, device=self.device)
        batch_labels = torch.tensor(batch_labels, dtype=torch.long, device=self.device)
        batch_ids = torch.tensor(batch_ids, dtype=torch.long, device=self.device)
        return batch_token_ids, attention_mask, token_type_ids, batch_labels, batch_ids, callback
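
# Shapes returned by Collate.collate_fn for a filtered batch of size B:
#   batch_token_ids (B, max_len), attention_mask (B, max_len),
#   token_type_ids (B, max_len), batch_labels (B,), batch_ids (B, 4),
#   plus callback: a list of (text, label) pairs kept for evaluation.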

if __name__ == "__main__":
    from transformers import BertTokenizer

    max_len = 300
    tokenizer = BertTokenizer.from_pretrained('../model_hub/chinese-roberta-wwm-ext/vocab.txt')
    train_dataset = MyDataset(file_path='../data/dgre/re_mid_data/train.txt')
    # print(train_dataset[0])
    with open('../data/dgre/re_mid_data/rels.txt', 'r', encoding='utf-8') as fp:
        labels = fp.read().strip().split('\n')  # strip() avoids a trailing empty label
    id2tag = {}
    tag2id = {}
    for i, label in enumerate(labels):
        id2tag[i] = label
        tag2id[label] = i
    print(tag2id)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    collate = Collate(max_len=max_len, tag2id=tag2id, device=device, tokenizer=tokenizer)
    collate.collate_fn(train_dataset[:16])

    batch_size = 2
    train_dataset = train_dataset[:10]
    train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, collate_fn=collate.collate_fn)
    for i, batch in enumerate(train_dataloader):
        for j in range(len(batch) - 1):  # skip the trailing callback list
            print(batch[j].shape)
        break