question_classifier.py 5.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166
  1. import os
  2. import ahocorasick
  class QuestionClassifier:
      """Rule-based question classifier driven by plain-text word lists.

      All word lists are read from the ``dict/`` directory next to this file.
      ``classify`` picks the dictionary entity closest to the question and tags
      the question with types according to which keyword lists occur in it.
      """

      def __init__(self):
          """Load the dictionary files and build the lookup structures.

          Raises:
              OSError: if any of the ``dict/*.txt`` files cannot be opened.
          """
          # os.path.dirname is portable; splitting the path on '/' yields an
          # empty directory on Windows, where abspath uses backslashes.
          cur_dir = os.path.dirname(os.path.abspath(__file__))
          # Paths of the feature-word files.
          self.entity_path = os.path.join(cur_dir, 'dict/entity.txt')
          self.fault_path = os.path.join(cur_dir, 'dict/fault.txt')
          self.parts_path = os.path.join(cur_dir, 'dict/parts.txt')
          self.deny_path = os.path.join(cur_dir, 'dict/deny.txt')
          # Entity and deny files hold one word/phrase per non-blank line.
          self.entity_wds = self._read_lines(self.entity_path)
          self.region_words = set(self.entity_wds)
          # Loaded for callers; no method of this class reads it.
          self.deny_words = self._read_lines(self.deny_path)
          # Aho-Corasick automaton over the domain words (used by check_medical).
          self.region_tree = self.build_actree(list(self.region_words))
          # Word -> list of type tags.
          self.wdtype_dict = self.build_wdtype_dict()
          # Interrogative keywords marking "cause" and "how to fix" questions.
          self.cause_qwds = ['原因', '成因', '为什么', '怎么会', '怎样会', '如何会', '为啥', '为何']
          self.solve_qwds = ['解决', '处理', '修理', '修复', '维修', '怎么修', '咋修', '怎么办']
          # Fault and part files may hold several whitespace-separated words per line.
          self.faults = self._read_words(self.fault_path)
          self.parts = self._read_words(self.parts_path)
          print('model init finished ......')

      @staticmethod
      def _read_lines(path):
          """Return the stripped, non-blank lines of the UTF-8 text file at *path*."""
          with open(path, encoding='utf-8') as file:
              return [line.strip() for line in file if line.strip()]

      @staticmethod
      def _read_words(path):
          """Return every whitespace-separated token of the UTF-8 file at *path*, in file order."""
          with open(path, encoding='utf-8') as file:
              return [word for line in file for word in line.split()]

      def _closest_entity(self, question):
          """Return ``(entity, score)`` for the entity with the smallest edit_distance to *question*.

          On equal scores the longer entity is kept; on equal score and length
          the later entity in ``entity_wds`` wins. Returns ``('', inf)`` when
          ``entity_wds`` is empty.
          """
          best_word, best_score = '', float('inf')
          for word in self.entity_wds:
              score = edit_distance(question, word)
              if score < best_score or (score == best_score and len(word) >= len(best_word)):
                  best_word, best_score = word, score
          return best_word, best_score

      def classify(self, question):
          """Classify *question* (main entry point).

          Args:
              question: the raw question text.

          Returns:
              ``{}`` when the entity dictionary is empty or the closest entity
              shares no character with the question. Otherwise a dict with
              ``'args'`` mapping the chosen entity to ``['entity']`` and
              ``'question_types'``, a list of type labels.
          """
          entity, score = self._closest_entity(question)
          # edit_distance only inserts/deletes, so a score equal to the combined
          # length means the entity and the question share no characters.
          if not entity or score == len(entity) + len(question):
              return {}
          data = {'args': {entity: ['entity']}}
          question_types = []
          # Fault words -> the user wants a fix.
          if self.check_words(self.faults, question):
              question_types.append('solve')
          # Cause keywords.
          if self.check_words(self.cause_qwds, question):
              question_types.append('entity_cause')
          # Part names.
          if self.check_words(self.parts, question):
              question_types.append('parts')
          # Explicit "how to fix" keywords.
          # NOTE(review): combined with a fault word this appends 'solve' a
          # second time; confirm downstream tolerates duplicates.
          if self.check_words(self.solve_qwds, question):
              question_types.append('solve')
          # Nothing specific matched: fall back to the entity description.
          if not question_types:
              question_types = ['entity_desc', 'entity_desc1']
          data['question_types'] = question_types
          return data

      def build_wdtype_dict(self):
          """Map every region word to a fresh list of its type tags (``['entity']`` for entity words)."""
          entity_set = set(self.entity_wds)  # O(1) membership instead of a list scan per word
          return {wd: (['entity'] if wd in entity_set else []) for wd in self.region_words}

      def build_actree(self, wordlist):
          """Build an Aho-Corasick automaton whose payload for each word is ``(index, word)``."""
          actree = ahocorasick.Automaton()
          for index, word in enumerate(wordlist):
              actree.add_word(word, (index, word))
          actree.make_automaton()
          return actree

      def check_medical(self, question):
          """Return the domain words found in *question*, mapped to their wdtype_dict tags.

          A match is dropped when it is a proper substring of another match, so
          only the longest overlapping words survive.
          """
          region_wds = [hit[1][1] for hit in self.region_tree.iter(question)]
          stop_wds = {
              wd1 for wd1 in region_wds for wd2 in region_wds
              if wd1 in wd2 and wd1 != wd2
          }
          final_wds = [wd for wd in region_wds if wd not in stop_wds]
          return {wd: self.wdtype_dict.get(wd) for wd in final_wds}

      def check_words(self, wds, sent):
          """Return True if any word in *wds* occurs as a substring of *sent*."""
          return any(wd in sent for wd in wds)
  def edit_distance(text1, text2):
      """Return the insertion/deletion-only edit distance between two sequences.

      Unlike Levenshtein distance there is no substitution step: a mismatched
      character costs one deletion plus one insertion. The result therefore
      equals ``len(text1) + len(text2) - 2 * LCS(text1, text2)``. It reaches its
      maximum, ``len(text1) + len(text2)``, exactly when the inputs share no
      character.

      Only two DP rows are kept, so memory is O(len(text2)) rather than
      O(len(text1) * len(text2)).

      Args:
          text1: first sequence (typically a str).
          text2: second sequence.

      Returns:
          The non-negative integer distance.
      """
      # prev[j] holds the distance between text1[:i-1] and text2[:j].
      prev = list(range(len(text2) + 1))
      for i, ch1 in enumerate(text1, start=1):
          curr = [i] + [0] * len(text2)
          for j, ch2 in enumerate(text2, start=1):
              if ch1 == ch2:
                  curr[j] = prev[j - 1]
              else:
                  # Delete from text1 (prev[j]) or insert into it (curr[j - 1]).
                  curr[j] = min(prev[j], curr[j - 1]) + 1
          prev = curr
      return prev[-1]
  if __name__ == '__main__':
      # Interactive loop: classify each question typed at the prompt.
      handler = QuestionClassifier()
      while True:
          try:
              question = input('input an question:')
          except (EOFError, KeyboardInterrupt):
              # Ctrl-D / Ctrl-C ends the session cleanly instead of a traceback.
              print()
              break
          data = handler.classify(question)
          print(data)