# onehot_generate.py
import json
import os

import numpy as np
  4. path = os.path.abspath(os.path.dirname(__file__))
  5. def onehot_generate():
  6. # 从文件读取映射关系
  7. with open(path + "/mapping.json", "r", encoding="utf-8") as json_file:
  8. loaded_data = json.load(json_file)
  9. # 提取关键词和故障现象
  10. keywords_list = []
  11. phenomenon_list = []
  12. for mapping in loaded_data:
  13. keywords_list.extend(mapping["keywords"])
  14. phenomenon_list.append(mapping["phenomenon"])
  15. # 生成关键词词表
  16. vocabulary = sorted(set(keywords_list))
  17. # 生成one-hot编码
  18. one_hot_vectors = []
  19. for word in vocabulary:
  20. one_hot = [1 if word == w else 0 for w in vocabulary]
  21. one_hot_vectors.append(one_hot)
  22. # 转换为NumPy数组
  23. one_hot_array = np.array(one_hot_vectors, dtype=np.int32)
  24. # 生成关键词与one-hot编码的字典
  25. keyword_one_hot_dict = {}
  26. for i, word in enumerate(vocabulary):
  27. keyword_one_hot_dict[word] = one_hot_array[i].tolist()
  28. # 生成故障现象词的向量
  29. phenomenon_vectors = []
  30. for phenomenon in phenomenon_list:
  31. vector_sum = np.zeros(len(vocabulary), dtype=np.int32)
  32. for keyword in keywords_list:
  33. if keyword in phenomenon:
  34. vector_sum = np.logical_or(vector_sum, np.array(keyword_one_hot_dict[keyword], dtype=np.int32))
  35. phenomenon_vectors.append(vector_sum.tolist())
  36. # 创建故障现象词的向量字典
  37. phenomenon_vector_dict = {}
  38. for i, phenomenon in enumerate(phenomenon_list):
  39. phenomenon_vector_dict[phenomenon] = [int(val) for val in phenomenon_vectors[i]]
  40. # 保存关键词与one-hot编码的字典为JSON文件
  41. with open(path + "/keyword_one_hot_dict.json", "w", encoding="utf-8") as json_file:
  42. json.dump(keyword_one_hot_dict, json_file, ensure_ascii=False, indent=4)
  43. # 保存故障现象词的向量字典为JSON文件
  44. with open(path + "/phenomenon_vector_dict.json", "w", encoding="utf-8") as json_file:
  45. json.dump(phenomenon_vector_dict, json_file, ensure_ascii=False, indent=4)