123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228 |
- from fastapi import FastAPI, HTTPException, Request
- from pydantic import BaseModel
- import aiohttp
- import logging
- from fastapi.responses import JSONResponse
- import json
- from typing import Optional, List, Dict
- import asyncio
# Application instance; the route and exception handler in this file attach to it.
app = FastAPI()
# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
class Query(BaseModel):
    """Request body accepted by the ``/kgqa/ask`` endpoint."""

    # Natural-language question passed to each backend caller.
    question: str
    # Caller-supplied user identifier; echoed back in the response data.
    userId: str
async def call_generate_sql(question: str, userId: str) -> Optional[List]:
    """Ask the SQL-generation service to answer *question*.

    The question, followed by a fixed instruction suffix, is sent as the
    ``prompt`` query parameter of a GET request.

    Args:
        question: The user's natural-language question.
        userId: Accepted so all backend callers share one signature; it is
            not sent to this service.

    Returns:
        The decoded JSON body when the service answers with HTTP 200,
        otherwise ``None``. Failures are logged and never raised.
    """
    url = "http://192.168.0.106:9040/ai/ollama/syncChatArr"
    # Hand the prompt to aiohttp via ``params`` so it is percent-encoded.
    # Interpolating it into the URL string let characters such as "&", "#"
    # or "+" in the question truncate or corrupt the query string.
    params = {"prompt": f"{question},其中不要关心数字内容"}
    try:
        async with aiohttp.ClientSession() as session:
            async with session.get(url, params=params) as response:
                if response.status == 200:
                    return await response.json()
                logger.error(f"Error calling generate-sql service: {response.status} - {await response.text()}")
                return None
    except Exception as e:
        # Best-effort backend: any transport or decoding failure degrades to "no answer".
        logger.error(f"Unexpected error while calling generate-sql service: {e}")
        return None
_THINK_OPEN = "<think>"
_THINK_CLOSE = "</think>"


def _route_think_markup(chunk, in_think, pending_think, think_blocks):
    """Split one streamed text chunk into answer text and think-block text.

    Completed think blocks are appended (stripped) to *think_blocks*.

    Args:
        chunk: Text of one streamed delta.
        in_think: Whether a ``<think>`` block is open when the chunk starts.
        pending_think: Text collected so far for the open think block.
        think_blocks: List that receives each completed think block.

    Returns:
        ``(answer_text, in_think, pending_think)`` — the part of the chunk
        that belongs to the answer, plus the updated parser state.

    A tag that is split across two chunks is not recognised.
    """
    answer_parts = []
    while chunk:
        if in_think:
            body, closing, chunk = chunk.partition(_THINK_CLOSE)
            # A repeated opening tag inside an open block carries no text.
            pending_think += body.replace(_THINK_OPEN, "")
            if not closing:
                break
            think_blocks.append(pending_think.strip())
            pending_think = ""
            in_think = False
            continue

        open_at = chunk.find(_THINK_OPEN)
        close_at = chunk.find(_THINK_CLOSE)
        if open_at == -1 and close_at == -1:
            answer_parts.append(chunk)
            break
        if close_at != -1 and (open_at == -1 or close_at < open_at):
            # Closing tag without an opening tag: the text in front of it in
            # this chunk is recorded as a think block.
            think_blocks.append((pending_think + chunk[:close_at]).strip())
            pending_think = ""
            chunk = chunk[close_at + len(_THINK_CLOSE):]
            continue
        # Text in front of the opening tag is ordinary answer text.
        answer_parts.append(chunk[:open_at])
        chunk = chunk[open_at + len(_THINK_OPEN):]
        in_think = True
    return "".join(answer_parts), in_think, pending_think


async def call_llm(question: str, userId: str) -> Optional[Dict]:
    """Call the knowledge-base chat service and assemble its streamed reply.

    The response body is read line by line. Each ``data: `` line is parsed as
    JSON; the ``delta.content`` of every choice is split into answer text and
    ``<think>...</think>`` blocks, and every ``docs`` list is accumulated.

    Args:
        question: The user's natural-language question.
        userId: Identifier of the requesting user.

    Returns:
        ``{"answer": str, "docs": list, "think": list[str]}`` on HTTP 200,
        otherwise ``None``. Failures are logged and never raised.

    NOTE(review): the payload below is a placeholder in this listing, so
    neither argument is visibly used here — confirm against the real config.
    """
    url = "http://192.168.100.100:7861/chat/kb_chat"
    payload = {
        # ... keep the original payload configuration
    }
    answer_parts = []
    all_docs = []
    think_blocks = []  # every completed think block
    pending_think = ""  # text of the think block currently open
    in_think = False
    try:
        async with aiohttp.ClientSession() as session:
            async with session.post(url, json=payload) as response:
                if response.status != 200:
                    logger.error(f"Error calling service: {response.status} - {await response.text()}")
                    return None
                async for line in response.content:
                    decoded_line = line.decode('utf-8')
                    if not decoded_line.startswith('data: '):
                        continue
                    json_str = decoded_line[len('data: '):]
                    try:
                        json_obj = json.loads(json_str)
                    except json.JSONDecodeError as e:
                        logger.error(f"无法解析的块: {json_str[:50]}...")
                        logger.error(f"错误信息: {e}")
                        continue
                    if not isinstance(json_obj, dict):
                        continue
                    # Extract and merge the content of each choice. ``or``
                    # guards against explicit nulls, which previously raised
                    # and discarded the whole answer.
                    for choice in json_obj.get("choices") or []:
                        content = (choice.get("delta") or {}).get("content") or ""
                        if not content:
                            continue
                        text, in_think, pending_think = _route_think_markup(
                            content, in_think, pending_think, think_blocks
                        )
                        answer_parts.append(text)
                    all_docs.extend(json_obj.get("docs") or [])
                # Keep a think block that the stream never closed.
                if in_think and pending_think.strip():
                    think_blocks.append(pending_think.strip())
                return {"answer": "".join(answer_parts), "docs": all_docs, "think": think_blocks}
    except Exception as e:
        # Best-effort backend: any transport failure degrades to "no answer".
        logger.error(f"Unexpected error while calling service: {e}")
        return None
def _empty_rasa_result(answer=None):
    """Build the result dict with only *answer* populated."""
    return {
        "answer": answer,
        "ossId": None,
        "fileName": None,
        "filePage": None,
        "graph": None
    }


def _parse_rasa_messages(data):
    """Convert the decoded Rasa reply into the flat result dict.

    The first message supplies the answer text. A second message, when
    present, is expected to hold a quoted dict-like string with ``docs``,
    ``data`` and ``links``; if it cannot be parsed, only the answer is kept.
    """
    if not isinstance(data, list) or not data or not isinstance(data[0], dict):
        # An empty or unexpected reply is "no answer", not a server error.
        logger.warning("Rasa response contained no usable messages.")
        return _empty_rasa_result()

    answer = data[0].get("text", "No answer found")
    if len(data) < 2:
        # Plain text reply: there is no extra payload to parse.
        return _empty_rasa_result(answer)

    try:
        raw_extra = data[1].get("text", "No graph found")
        # Drop the wrapping characters and turn single quotes into double
        # quotes so the payload can be read as JSON.
        entity_data = json.loads(raw_extra[1:-1].replace("'", '"'))
        doc = entity_data['docs'][0]
        graph = {
            "data": entity_data['data'],
            "links": entity_data['links']
        }
        return {
            "answer": answer,
            "ossId": str(doc['ossId']),
            "fileName": doc['file_name'],
            "filePage": doc["filePage"],
            "graph": str(graph)
        }
    except (json.JSONDecodeError, KeyError, IndexError, TypeError, AttributeError):
        logger.error("Failed to parse additional data from Rasa response.")
        # No extra data, or it could not be parsed: return only the answer.
        return _empty_rasa_result(answer)


async def call_rasa(question: str, userId: str) -> dict:
    """Send *question* to the Rasa REST webhook and flatten its reply.

    Args:
        question: The user's message.
        userId: Identifier of the requesting user.

    Returns:
        A dict with keys ``answer``, ``ossId``, ``fileName``, ``filePage``
        and ``graph``. All values are ``None`` when the call takes longer
        than 10 seconds.

    Raises:
        HTTPException: With the upstream status for a non-200 reply, or 500
            for a non-JSON reply or any unexpected failure.
    """
    url = "http://localhost:5005/webhooks/rest/webhook"
    # Rasa's REST channel takes the conversation id from "sender"; without it
    # every user lands in the same default conversation. "userId" is still
    # sent for any consumer that already relies on it.
    payload = {"sender": userId, "userId": userId, "message": question}

    async def _call_rasa():
        try:
            async with aiohttp.ClientSession() as session:
                async with session.post(url, json=payload) as response:
                    if response.status != 200:
                        logger.error(f"Error calling Rasa server: {response.status} - {await response.text()}")
                        raise HTTPException(status_code=response.status, detail="Error while calling Rasa Server")
                    content_type = response.headers.get("Content-Type", "")
                    if "application/json" not in content_type:
                        logger.error(f"Invalid content type: {content_type}")
                        raise HTTPException(status_code=500, detail="Invalid response content type")
                    data = await response.json()
                    logger.info(f"Response from Rasa: {data}")
                    return _parse_rasa_messages(data)
        except HTTPException:
            # Keep the specific status/detail raised above instead of
            # rewriting it to a generic 500 below.
            raise
        except Exception as e:
            logger.error(f"Unexpected error while calling Rasa server: {e}")
            raise HTTPException(status_code=500, detail="Internal Server Error") from e

    try:
        # Use asyncio.wait_for to cap the call at 10 seconds.
        return await asyncio.wait_for(_call_rasa(), timeout=10.0)
    except asyncio.TimeoutError:
        # On timeout, return an empty result.
        return _empty_rasa_result()
@app.post("/kgqa/ask")
async def ask_rasa(query: Query):
    """Query the three backends concurrently and merge their answers.

    Args:
        query: Request body carrying the question and the user id.

    Returns:
        ``{"code": 200, "msg": ..., "data": {...}}`` where ``data`` always
        holds ``userId`` and, per backend:

        * ``sqlAnswer`` – set whenever the call did not raise (may be ``None``);
        * ``llmAnswer`` – set only for a non-``None`` result;
        * ``graphAnswer`` – the Rasa result, or ``None`` when it has no graph.

        A backend that raised is logged and its key is omitted.
    """
    results = await asyncio.gather(
        call_generate_sql(query.question, query.userId),
        call_llm(query.question, query.userId),
        call_rasa(query.question, query.userId),
        return_exceptions=True,
    )
    # Same order as the coroutines passed to gather above.
    response_keys = ("sqlAnswer", "llmAnswer", "graphAnswer")

    consolidated_response = {"userId": query.userId}
    for idx, (key, result) in enumerate(zip(response_keys, results)):
        if isinstance(result, Exception):
            logger.error(f"Error occurred during request {idx}: {result}")
            continue
        if key == "sqlAnswer":
            consolidated_response[key] = result
        elif result is not None:
            # A Rasa reply without a graph is reported as no graph answer.
            if key == "graphAnswer" and result['graph'] is None:
                result = None
            consolidated_response[key] = result
    return {"code": 200, "msg": "操作成功", "data": consolidated_response}
@app.exception_handler(HTTPException)
async def http_exception_handler(request: Request, exc: HTTPException):
    """Render an ``HTTPException`` as JSON, keeping its status code.

    The response body is ``{"detail": <exc.detail>}``.
    """
    body = {"detail": exc.detail}
    return JSONResponse(content=body, status_code=exc.status_code)
if __name__ == "__main__":
    # Imported here so uvicorn is only loaded when the file is run as a script.
    import uvicorn
    # Serve the app on all interfaces, port 7074.
    uvicorn.run(app, host="0.0.0.0", port=7074)
|