answer_search.py 3.4 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586878889909192939495969798
  1. from py2neo import Graph
  class AnswerSearcher:
      """Run Cypher queries against Neo4j and render the rows as answer strings."""

      # question_type -> (row key of the listed items, row key of the subject,
      # reply template). These question types all render as
      # "subject + deduplicated, capped list of items".
      _LIST_TEMPLATES = {
          'entity_cause': ('c.name', 'a.name', '{0}现象的原因包括:{1}'),
          'parts': ('a.name', 'c.name', '{0}的性能故障可能有:{1}'),
          'solve': ('c.name', 'a.name', '{0}的排故流程是:{1}'),
      }

      def __init__(self, uri="bolt://127.0.0.1:7687", user="neo4j",
                   password="123456", num_limit=20):
          """Create the py2neo ``Graph`` client used to run queries.

          Args:
              uri: Neo4j connection URI.
              user: database user name ("neo4j" unless it was changed).
              password: database password.
              num_limit: maximum number of items listed per answer for the
                  'entity_cause', 'parts' and 'solve' question types.
          """
          # NOTE(review): the defaults keep the previously hard-coded
          # credentials for backward compatibility; prefer passing them in
          # from configuration.
          self.g = Graph(uri, user=user, password=password)
          self.num_limit = num_limit

      def search_main(self, sqls):
          """Execute the Cypher queries and return the formatted answers.

          Args:
              sqls: iterable of dicts, each with a 'question_type' key and a
                  'sql' key holding an iterable of Cypher query strings.

          Returns:
              A list with one formatted answer string per entry of ``sqls``;
              entries whose formatted answer is empty are omitted.
          """
          final_answers = []
          for sql_ in sqls:
              question_type = sql_['question_type']
              answers = []
              for query in sql_['sql']:
                  # .data() returns the whole result as a list of dicts,
                  # one per row.
                  answers.extend(self.g.run(query).data())
              final_answer = self.answer_prettify(question_type, answers)
              if final_answer:
                  final_answers.append(final_answer)
          return final_answers

      @staticmethod
      def _unique(items):
          """Return ``items`` as a list without duplicates, in first-seen order."""
          return list(dict.fromkeys(items))

      @classmethod
      def _group_nodes(cls, answers, nodes_key):
          """Group node names by each row's 'rel_name', merging all rows.

          ``row[nodes_key]`` may be a list of names or a single name; ``None``
          contributes nothing. Merged names are deduplicated in first-seen
          order, and relations keep the order in which they first appear.
          """
          grouped = {}
          for row in answers:
              nodes = row[nodes_key]
              if nodes is None:
                  nodes = []
              elif isinstance(nodes, str):
                  # A lone name must not be split into characters by join().
                  nodes = [nodes]
              grouped.setdefault(row['rel_name'], []).extend(nodes)
          return {rel: cls._unique(names) for rel, names in grouped.items()}

      def answer_prettify(self, question_type, answers):
          """Render query rows as a reply for the given question type.

          Args:
              question_type: 'entity_cause', 'parts', 'solve', 'entity_desc'
                  or 'entity_desc1'.
              answers: list of result rows (dicts), as built by search_main.

          Returns:
              The reply string, or '' when there are no rows or the question
              type is not recognised.
          """
          if not answers:
              return ''
          if question_type in self._LIST_TEMPLATES:
              item_key, subject_key, template = self._LIST_TEMPLATES[question_type]
              # Dedupe in first-seen order so the reply, and which items survive
              # the num_limit cut, are deterministic (set order is not).
              items = self._unique(row[item_key] for row in answers)
              return template.format(answers[0][subject_key],
                                     ';'.join(items[:self.num_limit]))
          if question_type == 'entity_desc':
              subject = answers[0]['c.name']
              grouped = self._group_nodes(answers, 'target_nodes')
              return ';\n'.join(
                  '{0}{1}有:{2}'.format(subject, rel, '、'.join(names))
                  for rel, names in grouped.items())
          if question_type == 'entity_desc1':
              subject = answers[0]['a.name']
              grouped = self._group_nodes(answers, 'source_nodes')
              return ';'.join(
                  '{0}被发出的{1}关系有:{2}'.format(subject, rel, '、'.join(names))
                  for rel, names in grouped.items())
          return ''
  if __name__ == "__main__":
      # Script entry point: only builds a searcher with the default settings;
      # no query is run here.
      searcher = AnswerSearcher()