zzz.py 10 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228
  1. from fastapi import FastAPI, HTTPException, Request
  2. from pydantic import BaseModel
  3. import aiohttp
  4. import logging
  5. from fastapi.responses import JSONResponse
  6. import json
  7. from typing import Optional, List, Dict
  8. import asyncio
  9. app = FastAPI()
  10. # 配置日志
  11. logging.basicConfig(level=logging.INFO)
  12. logger = logging.getLogger(__name__)
  13. class Query(BaseModel):
  14. question: str
  15. userId: str
  16. async def call_generate_sql(question: str, userId: str) -> Optional[List]:
  17. url = f"http://192.168.0.106:9040/ai/ollama/syncChatArr?prompt={question},其中不要关心数字内容"
  18. try:
  19. async with aiohttp.ClientSession() as session:
  20. async with session.get(url) as response:
  21. if response.status == 200:
  22. return await response.json()
  23. else:
  24. logger.error(f"Error calling generate-sql service: {response.status} - {await response.text()}")
  25. return None
  26. except Exception as e:
  27. logger.error(f"Unexpected error while calling generate-sql service: {e}")
  28. return None
  29. async def call_llm(question: str, userId: str) -> Optional[Dict]:
  30. url = "http://192.168.100.100:7861/chat/kb_chat"
  31. payload = {
  32. # ... 保留原有的payload配置
  33. }
  34. combined_text = ""
  35. all_docs = []
  36. think_blocks = [] # 存储所有的think块内容
  37. current_think_block = "" # 当前的think块内容
  38. in_think_block = False
  39. try:
  40. async with aiohttp.ClientSession() as session:
  41. async with session.post(url, json=payload) as response:
  42. if response.status == 200:
  43. async for line in response.content:
  44. decoded_line = line.decode('utf-8')
  45. if decoded_line.startswith('data: '):
  46. json_str = decoded_line[len('data: '):]
  47. try:
  48. json_obj = json.loads(json_str)
  49. # 提取choices中的内容并合并
  50. for choice in json_obj.get("choices", []):
  51. content = choice.get("delta", {}).get("content", "")
  52. if content:
  53. while "<think>" in content or "</think>" in content:
  54. start_index = content.find("<think>")
  55. end_index = content.find("</think>")
  56. if start_index != -1 and end_index != -1:
  57. if not in_think_block:
  58. in_think_block = True
  59. think_part = content[start_index + len("<think>"):end_index]
  60. current_think_block += think_part
  61. think_blocks.append(current_think_block.strip())
  62. current_think_block = ""
  63. content = content[end_index + len("</think>"):]
  64. in_think_block = False
  65. elif start_index != -1:
  66. if not in_think_block:
  67. in_think_block = True
  68. content = content[start_index + len("<think>"):]
  69. elif end_index != -1:
  70. current_think_block += content[:end_index]
  71. think_blocks.append(current_think_block.strip())
  72. current_think_block = ""
  73. content = content[end_index + len("</think>"):]
  74. in_think_block = False
  75. break
  76. else:
  77. break
  78. if in_think_block:
  79. current_think_block += content
  80. else:
  81. combined_text += content
  82. docs = json_obj.get("docs", [])
  83. all_docs.extend(docs)
  84. except json.JSONDecodeError as e:
  85. logger.error(f"无法解析的块: {json_str[:50]}...")
  86. logger.error(f"错误信息: {e}")
  87. return {"answer": combined_text, "docs": all_docs, "think": think_blocks}
  88. else:
  89. logger.error(f"Error calling service: {response.status} - {await response.text()}")
  90. return None
  91. except Exception as e:
  92. logger.error(f"Unexpected error while calling service: {e}")
  93. return None
  94. async def call_rasa(question: str, userId: str) -> dict:
  95. url = "http://localhost:5005/webhooks/rest/webhook"
  96. payload = {"userId": userId, "message": question}
  97. async def _call_rasa():
  98. try:
  99. async with aiohttp.ClientSession() as session:
  100. async with session.post(url, json=payload) as response:
  101. if response.status == 200:
  102. content_type = response.headers.get("Content-Type", "")
  103. if "application/json" in content_type:
  104. data = await response.json()
  105. logger.info(f"Response from Rasa: {data}")
  106. # 判断是否是简单的文本回复
  107. if (isinstance(data, list) and len(data) == 1 and
  108. isinstance(data[0], dict) and
  109. 'recipient_id' in data[0] and
  110. 'text' in data[0]):
  111. # 这里只返回 text 的内容
  112. return {
  113. "answer": data[0].get("text"),
  114. "ossId": None,
  115. "fileName": None,
  116. "filePage": None,
  117. "graph": None
  118. }
  119. answer = data[0].get("text", "No answer found")
  120. entity_lists = data[1].get("text", "No graph found")[1:-1] if len(data) > 1 else "None"
  121. entity_str = entity_lists.replace("'", '"')
  122. try:
  123. entity_data = json.loads(entity_str)
  124. doc = entity_data['docs'][0]
  125. ossid = doc['ossId']
  126. file_name = doc['file_name']
  127. filePage = doc["filePage"]
  128. graph = {
  129. "data": entity_data['data'],
  130. "links": entity_data['links']
  131. }
  132. return {
  133. "answer": answer,
  134. "ossId": str(ossid),
  135. "fileName": file_name,
  136. "filePage": filePage,
  137. "graph": str(graph)
  138. }
  139. except (json.JSONDecodeError, KeyError, IndexError):
  140. logger.error("Failed to parse additional data from Rasa response.")
  141. # 如果没有额外的数据或解析失败,则只返回基本的回答
  142. return {
  143. "answer": answer,
  144. "ossId": None,
  145. "fileName": None,
  146. "filePage": None,
  147. "graph": None
  148. }
  149. else:
  150. logger.error(f"Invalid content type: {content_type}")
  151. raise HTTPException(status_code=500, detail="Invalid response content type")
  152. else:
  153. logger.error(f"Error calling Rasa server: {response.status} - {await response.text()}")
  154. raise HTTPException(status_code=response.status, detail="Error while calling Rasa Server")
  155. except Exception as e:
  156. logger.error(f"Unexpected error while calling Rasa server: {e}")
  157. raise HTTPException(status_code=500, detail="Internal Server Error")
  158. try:
  159. # 使用 asyncio.wait_for 设置超时时间为10秒
  160. result = await asyncio.wait_for(_call_rasa(), timeout=10.0)
  161. return result
  162. except asyncio.TimeoutError:
  163. # 如果超时,返回空结果
  164. return {
  165. "answer": None,
  166. "ossId": None,
  167. "fileName": None,
  168. "filePage": None,
  169. "graph": None
  170. }
  171. @app.post("/kgqa/ask")
  172. async def ask_rasa(query: Query):
  173. tasks = [
  174. call_generate_sql(query.question, query.userId),
  175. call_llm(query.question, query.userId),
  176. call_rasa(query.question, query.userId)
  177. ]
  178. results = await asyncio.gather(*tasks, return_exceptions=True)
  179. consolidated_response = {"userId": query.userId}
  180. for idx, result in enumerate(results):
  181. if isinstance(result, Exception):
  182. logger.error(f"Error occurred during request {idx}: {result}")
  183. continue
  184. if idx == 0:
  185. consolidated_response.update({"sqlAnswer": result})
  186. elif idx == 1 and result is not None:
  187. consolidated_response.update({"llmAnswer": result})
  188. elif idx == 2 and result is not None:
  189. if result['graph'] == None:
  190. result = None
  191. consolidated_response.update({"graphAnswer": result})
  192. return {"code": 200, "msg": "操作成功", "data": consolidated_response}
  193. @app.exception_handler(HTTPException)
  194. async def http_exception_handler(request: Request, exc: HTTPException):
  195. return JSONResponse(
  196. status_code=exc.status_code,
  197. content={"detail": exc.detail},
  198. )
  199. if __name__ == "__main__":
  200. import uvicorn
  201. uvicorn.run(app, host="0.0.0.0", port=7074)