# neo4j_jaro_winkler.py
"""Rasa NLU component that fuzzy-matches message text against names stored in Neo4j."""
from typing import Any, Dict, List, Optional, Text

from rasa.nlu.components import Component
from rasa.shared.nlu.training_data.message import Message
from rasa.shared.nlu.training_data.training_data import TrainingData
from neo4j import GraphDatabase
import jellyfish


def _quote_cypher_identifier(identifier: Text) -> Text:
    """Return *identifier* backtick-quoted for use as a Cypher label or property key.

    Labels and property keys cannot be passed as query parameters, so they must be
    interpolated into the query text. Quoting them (doubling any embedded backtick)
    keeps names with non-ASCII or special characters from breaking or altering
    the query.
    """
    return "`" + str(identifier).replace("`", "``") + "`"


class Neo4jJaroWinkler(Component):
    """Custom component to query Neo4j and calculate Jaro-Winkler similarity.

    For each processed message, every value of ``property`` on nodes labelled
    ``label`` is fetched from Neo4j and compared with the message text. The
    best-scoring name whose score is strictly greater than ``threshold`` is stored
    on the message as ``highest_similarity_name`` / ``highest_similarity_score``.
    If no name qualifies, both are set to ``None``.
    """

    name = "Neo4jJaroWinkler"

    # NOTE(review): the database password is hard-coded here. Prefer supplying
    # "auth" via the pipeline config or a secret store instead of committing it.
    defaults = {
        "uri": "bolt://localhost:7687",
        "auth": ("neo4j", "fdx3081475970"),
        "label": "故障名称",
        "property": "name",
        "threshold": 0.85,
    }

    def __init__(self, component_config: Optional[Dict[Text, Any]] = None) -> None:
        # Set first so that __del__ is safe even if anything below raises.
        self.driver = None
        super().__init__(component_config)
        # Use self.component_config: the base class builds it by merging
        # ``defaults`` with the user's settings. The raw ``component_config``
        # argument may be None or contain only the keys the user overrode.
        config = self.component_config
        self.uri = config["uri"]
        # YAML/JSON configs deliver the credentials as a list; the driver wants a tuple.
        self.auth = tuple(config["auth"])
        self.label = config["label"]
        self.property = config["property"]
        self.threshold = float(config["threshold"])
        # The query depends only on configuration, so build it once.
        self._query = (
            f"MATCH (n:{_quote_cypher_identifier(self.label)}) "
            f"RETURN n.{_quote_cypher_identifier(self.property)} AS name"
        )
        self.driver = GraphDatabase.driver(self.uri, auth=self.auth)

    def train(self, training_data: TrainingData, config: Dict[Text, Any], **kwargs: Any) -> None:
        """Train the component (no-op: there is nothing to learn)."""
        pass

    def _fetch_names(self) -> List[Text]:
        """Return every string value of ``property`` on nodes labelled ``label``.

        Non-string values (e.g. ``None`` for nodes lacking the property) are
        skipped, because the similarity function only accepts strings.
        """
        with self.driver.session() as session:
            values = (record["name"] for record in session.run(self._query))
            # Materialise inside the session: the result is unusable once it closes.
            return [value for value in values if isinstance(value, str)]

    def process(self, message: Message, **kwargs: Any) -> None:
        """Annotate *message* with the closest Neo4j name by Jaro-Winkler similarity.

        Sets ``highest_similarity_name`` and ``highest_similarity_score`` on the
        message. Both are ``None`` when no name scores above ``threshold``.
        Messages without text are left untouched.
        """
        target_string = message.data.get("text")
        if not target_string:
            return

        best_name = None
        best_score = None
        for candidate in self._fetch_names():
            score = jellyfish.jaro_winkler_similarity(candidate, target_string)
            # Strict ">" in both checks: only scores above the threshold count,
            # and on ties the first candidate wins.
            if score > self.threshold and (best_score is None or score > best_score):
                best_name, best_score = candidate, score

        message.set("highest_similarity_name", best_name)
        message.set("highest_similarity_score", best_score)

    def persist(self, file_name: Text, model_dir: Text) -> Optional[Dict[Text, Any]]:
        """Persist the component (nothing to save; config is stored by Rasa)."""
        pass

    def __del__(self) -> None:
        """Close the Neo4j driver when the component is garbage-collected."""
        driver = getattr(self, "driver", None)
        if driver is not None:
            driver.close()