match_entity_extractor.py 2.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263
  1. import os
  2. from itertools import combinations
  3. from typing import Any, Text, Dict
  4. from rasa.nlu.extractors.extractor import EntityExtractor
  class MatchEntityExtractor(EntityExtractor):
      """Entity extractor that tags literal dictionary terms found in the text.

      Every ``*.txt`` file in ``dictionary_path`` defines one entity type: the
      file name without the ``.txt`` suffix is the entity name, and each
      non-blank line of the file is a term that is searched for verbatim in
      the message text.

      Configuration keys (see ``defaults``):
          dictionary_path: Directory holding the ``*.txt`` dictionaries.
              Required.
          take_long: When truthy and one match lies completely inside
              another, keep only the longer match.
          take_short: When truthy and one match lies completely inside
              another, keep only the shorter match. Mutually exclusive with
              ``take_long``.
      """

      provides = ["entities"]

      defaults = {
          "dictionary_path": None,
          "take_long": None,
          "take_short": None,
      }

      # Only files with this suffix are treated as dictionaries.
      _DICTIONARY_SUFFIX = ".txt"

      def __init__(self, component_config=None):
          """Read the configuration and load all dictionaries into memory.

          Args:
              component_config: Component configuration passed on to the
                  parent constructor; the effective settings are then read
                  back from ``self.component_config``.

          Raises:
              ValueError: If ``take_long`` and ``take_short`` are both
                  truthy, or if ``dictionary_path`` is not configured.
          """
          super(MatchEntityExtractor, self).__init__(component_config)
          self.dictionary_path = self.component_config.get("dictionary_path")
          self.take_long = self.component_config.get("take_long")
          self.take_short = self.component_config.get("take_short")
          if self.take_long and self.take_short:
              raise ValueError("take_long and take_short can not be both True")
          if not self.dictionary_path:
              # os.listdir(None) would silently scan the current working
              # directory, so fail loudly on a missing setting instead.
              raise ValueError("dictionary_path must be set for MatchEntityExtractor")
          # Maps entity name -> list of terms used for exact matching.
          self.data = self._load_dictionaries(self.dictionary_path)

      @classmethod
      def _load_dictionaries(cls, dictionary_path):
          """Return ``{entity_name: [term, ...]}`` read from *dictionary_path*."""
          suffix = cls._DICTIONARY_SUFFIX
          data = {}
          # Sorted so the entity order does not depend on the OS listing order.
          for file_name in sorted(os.listdir(dictionary_path)):
              if not file_name.endswith(suffix):
                  continue
              file_path = os.path.join(dictionary_path, file_name)
              if not os.path.isfile(file_path):
                  continue
              # "utf-8-sig" also reads plain UTF-8, and strips a BOM that
              # would otherwise stay glued to the first term of the file.
              with open(file_path, mode="r", encoding="utf-8-sig") as f:
                  lines = f.read().splitlines()
              # A blank term would "match" at offset 0 of every message.
              data[file_name[: -len(suffix)]] = [line for line in lines if line.strip()]
          return data

      @staticmethod
      def _find_occurrences(text, entity, term):
          """Return one entity dict per non-overlapping occurrence of *term*."""
          found = []
          if not term:
              return found
          start = text.find(term)
          while start != -1:
              end = start + len(term)
              found.append({
                  "start": start,
                  "end": end,
                  "value": term,
                  "entity": entity,
                  "confidence": 1,
              })
              # Resume after this hit so occurrences of one term never overlap.
              start = text.find(term, end)
          return found

      @staticmethod
      def _resolve_nested(matches, keep_long):
          """Drop one side of every pair in which one span contains the other.

          Args:
              matches: Entity dicts carrying ``start`` and ``end`` offsets.
              keep_long: ``True`` drops the shorter match of a nested pair,
                  ``False`` drops the longer one.

          Returns:
              A new list with the surviving matches in their original order.
              Matches that only partially overlap are all kept.
          """
          dropped = set()
          for (i, first), (j, second) in combinations(enumerate(matches), 2):
              first_covers = first["start"] <= second["start"] and second["end"] <= first["end"]
              second_covers = second["start"] <= first["start"] and first["end"] <= second["end"]
              if not (first_covers or second_covers):
                  continue
              first_len = first["end"] - first["start"]
              second_len = second["end"] - second["start"]
              # On equal lengths (identical spans) the later match counts as
              # the "long" one, so exactly one of the two survives.
              long_idx, short_idx = (i, j) if first_len > second_len else (j, i)
              dropped.add(short_idx if keep_long else long_idx)
          return [match for idx, match in enumerate(matches) if idx not in dropped]

      def process(self, message, **kwargs):
          """Find dictionary terms in ``message.text`` and store them as entities.

          Every occurrence of every term is reported with confidence 1. When
          ``take_long`` or ``take_short`` is set, nested matches are reduced
          first. The result is passed through ``add_extractor_name`` and
          appended to the entities already present on the message.
          """
          text = message.text or ""
          matches = []
          for entity, terms in self.data.items():
              for term in terms:
                  matches.extend(self._find_occurrences(text, entity, term))
          if self.take_long or self.take_short:
              matches = self._resolve_nested(matches, keep_long=bool(self.take_long))
          extracted = self.add_extractor_name(matches)
          # Append rather than overwrite, so entities set by earlier
          # components in the pipeline are preserved.
          message.set(
              "entities",
              message.get("entities", []) + extracted,
              add_to_output=True,
          )

      @classmethod
      def load(cls, meta: Dict[Text, Any], model_dir=None, model_metadata=None, cached_component=None, **kwargs):
          """Rebuild the extractor from *meta*.

          Nothing is read from ``model_dir``; the dictionaries are loaded
          again from the ``dictionary_path`` stored in *meta*.
          """
          return cls(meta)