test.py 10 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255
import ast
import asyncio
import json
import logging
from typing import Dict, List, Optional

import aiohttp
from fastapi import FastAPI, HTTPException, Request
from fastapi.responses import JSONResponse
from pydantic import BaseModel
  9. app = FastAPI()
  10. # 配置日志
  11. logging.basicConfig(level=logging.INFO)
  12. logger = logging.getLogger(__name__)
  13. class Query(BaseModel):
  14. question: str
  15. userId: str
  16. async def call_generate_sql(question: str, userId: str) -> Optional[List]:
  17. url = f"http://192.168.72.100:9040/ai/ollama/syncChatArr?prompt={question},其中不要关心数字内容"
  18. try:
  19. async with aiohttp.ClientSession() as session:
  20. async with session.get(url) as response:
  21. if response.status == 200:
  22. return await response.json()
  23. else:
  24. logger.error(f"Error calling generate-sql service: {response.status} - {await response.text()}")
  25. return None
  26. except Exception as e:
  27. logger.error(f"Unexpected error while calling generate-sql service: {e}")
  28. return None
  29. async def call_llm(question: str, userId: str) -> Optional[Dict]:
  30. url = "http://192.168.72.100:7861/chat/kb_chat"
  31. # qusetion = f"我的的问题是{question},R1,你可以在知识库中帮我找到问题的答案么,其中我想要的输出和原文一样就可以,尽可能的将完整的内容回复给我就好,如果在原文中有明确的列出了某个问题的现象/原因/原理/排故流程等内容的话,请将原文的内容输出给我就好,"
  32. als =f"{question}?,不要生成,不要统计,只需要展示原文和相近问题的回复有什么"
  33. print("question", als)
  34. # payload = {
  35. # "query": question,
  36. # "mode": "local_kb",
  37. # "kb_name": "lqbz",
  38. # "top_k": 3,
  39. # "score_threshold": 2,
  40. # "history": [
  41. # {
  42. # "content": "",
  43. # "role": "user"
  44. # },
  45. # {
  46. # "content": "",
  47. # "role": "assistant"
  48. # }
  49. # ],
  50. # "stream": True,
  51. # "model": "qwen2.5-coder:7b",
  52. # "temperature": 0.7,
  53. # "max_tokens": 0,
  54. # "prompt_name": "default",
  55. # "return_direct": False
  56. # }
  57. #
  58. payload = {
  59. "query": als,
  60. "mode": "local_kb",
  61. "kb_name": "lqbz",
  62. "top_k": 3,
  63. "score_threshold": 2,
  64. "history": [
  65. {
  66. "content": '',
  67. "role": "user"
  68. },
  69. {
  70. "content": "",
  71. "role": "assistant"
  72. }
  73. ],
  74. "stream": True, # 设置为True以启用流式传输
  75. "model": "deepseek-r1:7b",
  76. "temperature": 0.7,
  77. "max_tokens": 0,
  78. "prompt_name": "default",
  79. "return_direct": False
  80. }
  81. combined_text = ""
  82. all_docs = []
  83. try:
  84. async with aiohttp.ClientSession() as session:
  85. async with session.post(url, json=payload) as response:
  86. if response.status == 200:
  87. async for line in response.content:
  88. decoded_line = line.decode('utf-8')
  89. if decoded_line.startswith('data: '):
  90. json_str = decoded_line[len('data: '):]
  91. try:
  92. json_obj = json.loads(json_str)
  93. # 提取choices中的内容并合并
  94. for choice in json_obj.get("choices", []):
  95. content = choice.get("delta", {}).get("content", "")
  96. if content:
  97. combined_text += content # 直接追加,保留换行符
  98. # 收集docs中的内容
  99. docs = json_obj.get("docs", [])
  100. all_docs.extend(docs)
  101. except json.JSONDecodeError as e:
  102. logger.error(f"无法解析的块: {json_str[:50]}...")
  103. logger.error(f"错误信息: {e}")
  104. return {"answer": combined_text, "docs": all_docs}
  105. else:
  106. logger.error(f"Error calling service: {response.status} - {await response.text()}")
  107. return None
  108. except Exception as e:
  109. logger.error(f"Unexpected error while calling service: {e}")
  110. return None
  111. async def call_rasa(question: str, userId: str) -> dict:
  112. url = "http://localhost:5005/webhooks/rest/webhook"
  113. payload = {"userId": userId, "message": question}
  114. async def _call_rasa():
  115. try:
  116. async with aiohttp.ClientSession() as session:
  117. async with session.post(url, json=payload) as response:
  118. if response.status == 200:
  119. content_type = response.headers.get("Content-Type", "")
  120. if "application/json" in content_type:
  121. data = await response.json()
  122. logger.info(f"Response from Rasa: {data}")
  123. # 判断是否是简单的文本回复
  124. if (isinstance(data, list) and len(data) == 1 and
  125. isinstance(data[0], dict) and
  126. 'recipient_id' in data[0] and
  127. 'text' in data[0]):
  128. # 这里只返回 text 的内容
  129. return {
  130. "answer": data[0].get("text"),
  131. "ossId": None,
  132. "fileName": None,
  133. "filePage": None,
  134. "graph": None
  135. }
  136. answer = data[0].get("text", "No answer found")
  137. entity_lists = data[1].get("text", "No graph found")[1:-1] if len(data) > 1 else "None"
  138. entity_str = entity_lists.replace("'", '"')
  139. try:
  140. entity_data = json.loads(entity_str)
  141. doc = entity_data['docs'][0]
  142. ossid = doc['ossId']
  143. file_name = doc['file_name']
  144. filePage = doc["filePage"]
  145. graph = {
  146. "data": entity_data['data'],
  147. "links": entity_data['links']
  148. }
  149. return {
  150. "answer": answer,
  151. "ossId": str(ossid),
  152. "fileName": file_name,
  153. "filePage": filePage,
  154. "graph": str(graph)
  155. }
  156. except (json.JSONDecodeError, KeyError, IndexError):
  157. logger.error("Failed to parse additional data from Rasa response.")
  158. # 如果没有额外的数据或解析失败,则只返回基本的回答
  159. return {
  160. "answer": answer,
  161. "ossId": None,
  162. "fileName": None,
  163. "filePage": None,
  164. "graph": None
  165. }
  166. else:
  167. logger.error(f"Invalid content type: {content_type}")
  168. raise HTTPException(status_code=500, detail="Invalid response content type")
  169. else:
  170. logger.error(f"Error calling Rasa server: {response.status} - {await response.text()}")
  171. raise HTTPException(status_code=response.status, detail="Error while calling Rasa Server")
  172. except Exception as e:
  173. logger.error(f"Unexpected error while calling Rasa server: {e}")
  174. raise HTTPException(status_code=500, detail="Internal Server Error")
  175. try:
  176. # 使用 asyncio.wait_for 设置超时时间为10秒
  177. result = await asyncio.wait_for(_call_rasa(), timeout=10.0)
  178. return result
  179. except asyncio.TimeoutError:
  180. # 如果超时,返回空结果
  181. return {
  182. "answer": None,
  183. "ossId": None,
  184. "fileName": None,
  185. "filePage": None,
  186. "graph": None
  187. }
  188. @app.post("/kgqa/ask")
  189. async def ask_rasa(query: Query):
  190. tasks = [
  191. call_generate_sql(query.question, query.userId),
  192. call_llm(query.question, query.userId),
  193. call_rasa(query.question, query.userId)
  194. ]
  195. results = await asyncio.gather(*tasks, return_exceptions=True)
  196. consolidated_response = {"userId": query.userId}
  197. for idx, result in enumerate(results):
  198. if isinstance(result, Exception):
  199. logger.error(f"Error occurred during request {idx}: {result}")
  200. continue
  201. if idx == 0:
  202. consolidated_response.update({"sqlAnswer": result})
  203. elif idx == 1 and result is not None:
  204. # print(result)
  205. res = result['answer']
  206. if "<think>" in res:
  207. think = res.split("</think>")[0]
  208. think = think.replace("<think>", "")
  209. answer = res.split("</think>")[1]
  210. else:
  211. think = ""
  212. answer = res
  213. docs = result['docs']
  214. result = {"think": think, "answer": answer, "docs": docs}
  215. consolidated_response.update({"llmAnswer": result})
  216. elif idx == 2 and result is not None:
  217. if result['graph'] == None:
  218. result = None
  219. consolidated_response.update({"graphAnswer": result})
  220. return {"code": 200, "msg": "操作成功", "data": consolidated_response}
  221. @app.exception_handler(HTTPException)
  222. async def http_exception_handler(request: Request, exc: HTTPException):
  223. return JSONResponse(
  224. status_code=exc.status_code,
  225. content={"detail": exc.detail},
  226. )
  227. if __name__ == "__main__":
  228. import uvicorn
  229. uvicorn.run(app, host="0.0.0.0", port=7074)