|
@@ -0,0 +1,166 @@
|
|
|
+
|
|
|
+import os
|
|
|
+import ahocorasick
|
|
|
class QuestionClassifier:
    """Match a question to a dictionary entity and tag it with question types.

    The word lists are read from the ``dict/`` directory located next to this
    source file. ``classify`` picks the entity word closest to the question
    according to the module-level ``edit_distance`` function and then labels
    the question by simple keyword containment.
    """

    def __init__(self):
        """Load the word lists and build the lookup structures."""
        # os.path.dirname works with any path separator, unlike splitting on '/'.
        cur_dir = os.path.dirname(os.path.abspath(__file__))

        # Feature-word file locations.
        self.entity_path = os.path.join(cur_dir, 'dict/entity.txt')
        self.fault_path = os.path.join(cur_dir, 'dict/fault.txt')
        self.parts_path = os.path.join(cur_dir, 'dict/parts.txt')
        self.deny_path = os.path.join(cur_dir, 'dict/deny.txt')

        # Feature words: one stripped, non-empty entry per line.
        self.entity_wds = self._read_lines(self.entity_path)
        self.region_words = set(self.entity_wds)
        self.deny_words = self._read_lines(self.deny_path)

        # Aho-Corasick automaton over the domain words.
        self.region_tree = self.build_actree(list(self.region_words))
        # Word -> list of type labels.
        self.wdtype_dict = self.build_wdtype_dict()

        # Interrogative keywords that signal the question intent.
        self.cause_qwds = ['原因','成因', '为什么', '怎么会', '怎样会', '如何会', '为啥', '为何']
        self.solve_qwds = ['解决','处理','修理','修复','维修','怎么修','咋修','怎么办']

        # Fault and part vocabularies: whitespace-separated tokens.
        self.faults = self._read_tokens(self.fault_path)
        self.parts = self._read_tokens(self.parts_path)

        print('model init finished ......')

    @staticmethod
    def _read_lines(path):
        """Return the stripped, non-empty lines of the UTF-8 text file *path*."""
        # The context manager closes the handle even if decoding fails.
        with open(path, encoding='utf-8') as handle:
            return [line.strip() for line in handle if line.strip()]

    @staticmethod
    def _read_tokens(path):
        """Return every whitespace-separated token of the UTF-8 text file *path*."""
        with open(path, 'r', encoding='utf-8') as handle:
            # str.split() with no argument also splits on newlines, so this
            # yields the same tokens as splitting each line separately.
            return handle.read().split()

    def classify(self, question):
        """Classify *question*.

        Returns a dict with two keys:

        * ``'args'``: ``{entity: ['entity']}`` for the closest entity word.
        * ``'question_types'``: list of type labels for the question.

        Returns an empty dict when there is no entity word, or when the best
        distance equals ``len(entity) + len(question)``.
        """
        best_entity = None
        # No artificial upper bound: the closest entity wins however far away
        # it is, so long questions still resolve to a real entity.
        best_score = float('inf')
        for entity in self.entity_wds:
            score = edit_distance(question, entity)
            if score > best_score:
                continue
            # On a tie keep the longer entity; equal lengths favour the later one.
            if score == best_score and len(entity) < len(best_entity):
                continue
            best_entity, best_score = entity, score

        if best_entity is None:
            return {}
        # A distance equal to the combined length is treated as "no match".
        if best_score == len(best_entity) + len(question):
            return {}

        return {
            'args': {best_entity: ['entity']},
            'question_types': self._question_types(question),
        }

    def _question_types(self, question):
        """Return the ordered, duplicate-free type labels triggered by *question*."""
        rules = (
            (self.faults, 'solve'),            # a fault word implies a fix is wanted
            (self.cause_qwds, 'entity_cause'),
            (self.parts, 'parts'),
            (self.solve_qwds, 'solve'),
        )
        labels = []
        for words, label in rules:
            # Skip labels already recorded so 'solve' is not reported twice.
            if label not in labels and self.check_words(words, question):
                labels.append(label)
        # With no keyword hit, fall back to the descriptive types.
        return labels or ['entity_desc', 'entity_desc1']

    def build_wdtype_dict(self):
        """Map each word in ``region_words`` to its list of type labels."""
        # Set membership avoids scanning the entity list once per word.
        entity_set = set(self.entity_wds)
        return {
            wd: ['entity'] if wd in entity_set else []
            for wd in self.region_words
        }

    def build_actree(self, wordlist):
        """Build an Aho-Corasick automaton whose payloads are ``(index, word)``."""
        actree = ahocorasick.Automaton()
        for index, word in enumerate(wordlist):
            actree.add_word(word, (index, word))
        actree.make_automaton()
        return actree

    def check_medical(self, question):
        """Return ``{word: types}`` for the domain words found in *question*.

        A matched word that is a proper substring of another matched word is
        dropped, so only the longest overlapping matches remain.
        """
        matched = [hit[1][1] for hit in self.region_tree.iter(question)]
        shadowed = {
            short
            for short in matched
            for longer in matched
            if short != longer and short in longer
        }
        return {
            wd: self.wdtype_dict.get(wd)
            for wd in matched
            if wd not in shadowed
        }

    def check_words(self, wds, sent):
        """Return True if any word in *wds* occurs as a substring of *sent*."""
        return any(wd in sent for wd in wds)
|
|
|
+
|
|
|
def edit_distance(text1, text2):
    """Return the insertion/deletion distance between *text1* and *text2*.

    Only inserting or deleting a character costs 1; there is no substitution
    step, so two strings with no character in common are
    ``len(text1) + len(text2)`` apart.
    """
    # Row for the empty prefix of text1: turning "" into text2[:j] takes j inserts.
    previous = list(range(len(text2) + 1))

    for row, left_char in enumerate(text1, start=1):
        # First column: deleting all `row` characters of the text1 prefix.
        current = [row]
        for col, right_char in enumerate(text2, start=1):
            if left_char == right_char:
                # Matching characters cost nothing; carry the diagonal value.
                cost = previous[col - 1]
            else:
                # Cheapest of "delete from text1" and "insert from text2".
                cost = min(previous[col], current[col - 1]) + 1
            current.append(cost)
        previous = current

    return previous[-1]
|
|
|
if __name__ == '__main__':
    # Interactive loop: read a question from stdin and print its classification.
    classifier = QuestionClassifier()
    while True:
        print(classifier.classify(input('input an question:')))
|