neo4j_jaro_winkler.py 5.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132
  1. # from typing import Any, Optional, Text, Dict
  2. # from rasa.nlu.components import Component
  3. # from rasa.shared.nlu.training_data.message import Message
  4. # from rasa.shared.nlu.training_data.training_data import TrainingData
  5. # from neo4j import GraphDatabase
  6. # import jellyfish
  7. #
  8. # class Neo4jJaroWinkler(Component):
  9. # """Custom component to query Neo4j and calculate Jaro-Winkler similarity."""
  10. #
  11. # name = "Neo4jJaroWinkler"
  12. #
  13. # defaults = {
  14. # "uri": "bolt://localhost:7687",
  15. # "auth": ("neo4j", "fdx3081475970"),
  16. # "label": "故障名称",
  17. # "property": "name",
  18. # "threshold": 0.85
  19. # }
  20. #
  21. # def __init__(self, component_config: Optional[Dict[Text, Any]] = None) -> None:
  22. # super().__init__(component_config)
  23. # self.uri = component_config.get("uri")
  24. # self.auth = component_config.get("auth")
  25. # self.label = component_config.get("label")
  26. # self.property = component_config.get("property")
  27. # self.threshold = component_config.get("threshold")
  28. # self.driver = GraphDatabase.driver(self.uri, auth=self.auth)
  29. #
  30. # def train(self, training_data: TrainingData, config: Dict[Text, Any], **kwargs: Any) -> None:
  31. # """Train the component."""
  32. # pass
  33. #
  34. # def process(self, message: Message, **kwargs: Any) -> None:
  35. # """Process a message."""
  36. # target_string = message.text
  37. # query = f"""
  38. # MATCH (n:{self.label})
  39. # RETURN n.{self.property} AS name
  40. # """
  41. #
  42. # with self.driver.session() as session:
  43. # result = session.run(query)
  44. # names = [record["name"] for record in result]
  45. #
  46. # similarity_scores = [(name, jellyfish.jaro_winkler_similarity(name, target_string)) for name in names]
  47. # filtered_similarities = [(name, score) for name, score in similarity_scores if score > self.threshold]
  48. #
  49. # if filtered_similarities:
  50. # highest_similarity_name, highest_similarity_score = max(filtered_similarities, key=lambda x: x[1])
  51. # message.set("highest_similarity_name", highest_similarity_name)
  52. # message.set("highest_similarity_score", highest_similarity_score)
  53. # else:
  54. # message.set("highest_similarity_name", None)
  55. # message.set("highest_similarity_score", None)
  56. #
  57. # def persist(self, file_name: Text, model_dir: Text) -> Optional[Dict[Text, Any]]:
  58. # """Persist the component."""
  59. # pass
  60. #
  61. # def __del__(self):
  62. # """Close the Neo4j driver when the component is deleted."""
  63. # self.driver.close()
  64. # neo4j_jaro_winkler.py
  65. from typing import Any, Optional, Text, Dict
  66. from rasa.nlu.components import Component
  67. from rasa.shared.nlu.training_data.message import Message
  68. from rasa.shared.nlu.training_data.training_data import TrainingData
  69. from neo4j import GraphDatabase
  70. import jellyfish
  71. class Neo4jJaroWinkler(Component):
  72. """Custom component to query Neo4j and calculate Jaro-Winkler similarity."""
  73. name = "Neo4jJaroWinkler"
  74. defaults = {
  75. "uri": "bolt://localhost:7687",
  76. "auth": ("neo4j", "fdx3081475970"),
  77. "label": "故障名称",
  78. "property": "name",
  79. "threshold": 0.85
  80. }
  81. def __init__(self, component_config: Optional[Dict[Text, Any]] = None) -> None:
  82. super().__init__(component_config)
  83. self.uri = component_config.get("uri")
  84. self.auth = tuple(component_config.get("auth"))
  85. self.label = component_config.get("label")
  86. self.property = component_config.get("property")
  87. self.threshold = component_config.get("threshold")
  88. self.driver = GraphDatabase.driver(self.uri, auth=self.auth)
  89. def train(self, training_data: TrainingData, config: Dict[Text, Any], **kwargs: Any) -> None:
  90. """Train the component."""
  91. pass
  92. def process(self, message: Message, **kwargs: Any) -> None:
  93. """Process a message."""
  94. target_string = message.data.get("text")
  95. if not target_string:
  96. return
  97. query = f"""
  98. MATCH (n:{self.label})
  99. RETURN n.{self.property} AS name
  100. """
  101. with self.driver.session() as session:
  102. result = session.run(query)
  103. names = [record["name"] for record in result]
  104. similarity_scores = [(name, jellyfish.jaro_winkler_similarity(name, target_string)) for name in names]
  105. filtered_similarities = [(name, score) for name, score in similarity_scores if score > self.threshold]
  106. if filtered_similarities:
  107. highest_similarity_name, highest_similarity_score = max(filtered_similarities, key=lambda x: x[1])
  108. message.set("highest_similarity_name", highest_similarity_name)
  109. message.set("highest_similarity_score", highest_similarity_score)
  110. else:
  111. message.set("highest_similarity_name", None)
  112. message.set("highest_similarity_score", None)
  113. def persist(self, file_name: Text, model_dir: Text) -> Optional[Dict[Text, Any]]:
  114. """Persist the component."""
  115. pass
  116. def __del__(self):
  117. """Close the Neo4j driver when the component is deleted."""
  118. self.driver.close()