import json
def create_one_hot_encoding(vocabulary, word):
    """Return the one-hot list for ``word`` with respect to ``vocabulary``.

    Args:
        vocabulary: Mapping from word to its position in the output list. The
            list has ``len(vocabulary)`` entries, so positions are expected to
            lie in ``0 .. len(vocabulary) - 1``.
        word: The word to encode.

    Returns:
        A list of ``len(vocabulary)`` ints. It is all zeros, except for a single
        1 at ``vocabulary[word]`` when ``word`` is a key of ``vocabulary``.
        Words that are not in the vocabulary yield an all-zero list.
    """
    size = len(vocabulary)
    vector = [0 for _ in range(size)]
    if word not in vocabulary:
        # Out-of-vocabulary word: leave every position at 0.
        return vector
    vector[vocabulary[word]] = 1
    return vector
# Read the vocabulary list from a text file, one word per line.
# NOTE: both paths below are relative to the current working directory,
# not to this script's location.
# "utf-8-sig" decodes plain UTF-8 exactly like "utf-8", but it also drops a
# leading byte-order mark. Some editors write one, and without this it would
# become part of the first word. Blank / whitespace-only lines are skipped so
# they do not turn into an "" vocabulary entry.
with open("./vocabulary.txt", "r", encoding="utf-8-sig") as file:
    loaded_vocabulary = [word for word in (line.strip() for line in file) if word]

# Build the word -> index mapping.
# dict.fromkeys de-duplicates while keeping first-occurrence order, so indices
# are contiguous (0 .. V-1). Without this, a repeated word would keep its LAST
# line index, which can exceed len(vocabulary_dict) - 1 and make
# create_one_hot_encoding raise IndexError.
vocabulary_dict = {
    word: index for index, word in enumerate(dict.fromkeys(loaded_vocabulary))
}

# Build the word -> one-hot vector mapping.
# Each vector has len(vocabulary_dict) entries, so the result holds V * V ints.
one_hot_dict = {
    word: create_one_hot_encoding(vocabulary_dict, word) for word in vocabulary_dict
}

# Save as JSON (UTF-8). ensure_ascii=False writes non-ASCII words literally
# rather than as \uXXXX escapes.
with open("./one_hot_dict.json", "w", encoding="utf-8") as json_file:
    json.dump(one_hot_dict, json_file, ensure_ascii=False)
print(one_hot_dict)