123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132 |
- # from typing import Any, Optional, Text, Dict
- # from rasa.nlu.components import Component
- # from rasa.shared.nlu.training_data.message import Message
- # from rasa.shared.nlu.training_data.training_data import TrainingData
- # from neo4j import GraphDatabase
- # import jellyfish
- #
- # class Neo4jJaroWinkler(Component):
- # """Custom component to query Neo4j and calculate Jaro-Winkler similarity."""
- #
- # name = "Neo4jJaroWinkler"
- #
- # defaults = {
- # "uri": "bolt://localhost:7687",
- # "auth": ("neo4j", "fdx3081475970"),
- # "label": "故障名称",
- # "property": "name",
- # "threshold": 0.85
- # }
- #
- # def __init__(self, component_config: Optional[Dict[Text, Any]] = None) -> None:
- # super().__init__(component_config)
- # self.uri = component_config.get("uri")
- # self.auth = component_config.get("auth")
- # self.label = component_config.get("label")
- # self.property = component_config.get("property")
- # self.threshold = component_config.get("threshold")
- # self.driver = GraphDatabase.driver(self.uri, auth=self.auth)
- #
- # def train(self, training_data: TrainingData, config: Dict[Text, Any], **kwargs: Any) -> None:
- # """Train the component."""
- # pass
- #
- # def process(self, message: Message, **kwargs: Any) -> None:
- # """Process a message."""
- # target_string = message.text
- # query = f"""
- # MATCH (n:{self.label})
- # RETURN n.{self.property} AS name
- # """
- #
- # with self.driver.session() as session:
- # result = session.run(query)
- # names = [record["name"] for record in result]
- #
- # similarity_scores = [(name, jellyfish.jaro_winkler_similarity(name, target_string)) for name in names]
- # filtered_similarities = [(name, score) for name, score in similarity_scores if score > self.threshold]
- #
- # if filtered_similarities:
- # highest_similarity_name, highest_similarity_score = max(filtered_similarities, key=lambda x: x[1])
- # message.set("highest_similarity_name", highest_similarity_name)
- # message.set("highest_similarity_score", highest_similarity_score)
- # else:
- # message.set("highest_similarity_name", None)
- # message.set("highest_similarity_score", None)
- #
- # def persist(self, file_name: Text, model_dir: Text) -> Optional[Dict[Text, Any]]:
- # """Persist the component."""
- # pass
- #
- # def __del__(self):
- # """Close the Neo4j driver when the component is deleted."""
- # self.driver.close()
- # neo4j_jaro_winkler.py
- from typing import Any, Optional, Text, Dict
- from rasa.nlu.components import Component
- from rasa.shared.nlu.training_data.message import Message
- from rasa.shared.nlu.training_data.training_data import TrainingData
- from neo4j import GraphDatabase
- import jellyfish
class Neo4jJaroWinkler(Component):
    """Rasa NLU component that fuzzy-matches the message text against Neo4j.

    For every processed message the component fetches one property from all
    nodes carrying a configured label, scores each value against the message
    text with Jaro-Winkler similarity, and stores the best match whose score
    is strictly above the configured threshold on the message under the keys
    ``highest_similarity_name`` and ``highest_similarity_score`` (both
    ``None`` when nothing qualifies).
    """

    name = "Neo4jJaroWinkler"

    # NOTE(review): a password is hard-coded in the defaults below. It is kept
    # so existing deployments keep working, but it should be supplied through
    # the pipeline configuration (or a secret store) and rotated.
    defaults = {
        "uri": "bolt://localhost:7687",
        "auth": ("neo4j", "fdx3081475970"),
        "label": "故障名称",  # node label to scan (Chinese for "fault name")
        "property": "name",
        "threshold": 0.85
    }

    def __init__(self, component_config: Optional[Dict[Text, Any]] = None) -> None:
        """Read the settings and open the Neo4j driver.

        Args:
            component_config: Pipeline configuration for this component. May
                be ``None`` or omit keys; missing keys fall back to
                ``defaults``.
        """
        super().__init__(component_config)
        # Merge explicitly so a missing/None config or a partially specified
        # one still yields a complete set of settings.
        settings = {**self.defaults, **(component_config or {})}
        self.uri = settings["uri"]
        # YAML delivers the credentials as a list; the driver is handed a tuple.
        self.auth = tuple(settings["auth"])
        self.label = settings["label"]
        self.property = settings["property"]
        self.threshold = settings["threshold"]
        self.driver = GraphDatabase.driver(self.uri, auth=self.auth)

    @staticmethod
    def _quote_identifier(identifier: Text) -> Text:
        """Return *identifier* backtick-quoted for safe use in a Cypher query."""
        # Labels and property keys cannot be passed as query parameters, so
        # they are quoted instead; an embedded backtick is escaped by doubling.
        return "`" + str(identifier).replace("`", "``") + "`"

    def train(self, training_data: TrainingData, config: Dict[Text, Any], **kwargs: Any) -> None:
        """Do nothing; this component has no trainable state."""
        pass

    def process(self, message: Message, **kwargs: Any) -> None:
        """Attach the best-matching node name and its score to *message*.

        Messages without text are left untouched. Otherwise
        ``highest_similarity_name`` and ``highest_similarity_score`` are set
        to the best candidate scoring strictly above ``self.threshold``, or
        to ``None`` when no candidate does.
        """
        target_string = message.data.get("text")
        if not target_string:
            return

        query = (
            f"MATCH (n:{self._quote_identifier(self.label)}) "
            f"RETURN n.{self._quote_identifier(self.property)} AS name"
        )
        # Materialise the records while the session is still open.
        with self.driver.session() as session:
            candidates = [record["name"] for record in session.run(query)]

        best_name = None
        best_score = None
        for candidate in candidates:
            # Nodes lacking the property come back as None (and the property
            # may hold non-text values); the similarity function needs strings.
            if not isinstance(candidate, str):
                continue
            score = jellyfish.jaro_winkler_similarity(candidate, target_string)
            # Strict comparisons keep the first candidate on ties.
            if score > self.threshold and (best_score is None or score > best_score):
                best_name, best_score = candidate, score

        message.set("highest_similarity_name", best_name)
        message.set("highest_similarity_score", best_score)

    def persist(self, file_name: Text, model_dir: Text) -> Optional[Dict[Text, Any]]:
        """Persist nothing; the component keeps no model artefacts."""
        pass

    def __del__(self):
        """Close the Neo4j driver when the component is garbage-collected."""
        # The attribute is absent if __init__ failed before creating the driver.
        driver = getattr(self, "driver", None)
        if driver is not None:
            driver.close()
|