123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255 |
- from fastapi import FastAPI, HTTPException, Request
- from pydantic import BaseModel
- import aiohttp
- import logging
- from fastapi.responses import JSONResponse
- import json
- from typing import Optional, List, Dict
- import asyncio
# Configure logging once at import time; the rest of the module logs via `logger`.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Application object that the route and exception handler below register on.
app = FastAPI()
# Request body of the /kgqa/ask endpoint.
class Query(BaseModel):
    # Question text that is forwarded to every backend call.
    question: str
    # Caller-supplied user identifier; echoed back in the response payload.
    userId: str
async def call_generate_sql(question: str, userId: str) -> Optional[List]:
    """Send *question* to the generate-SQL service and return its JSON reply.

    The question, followed by a fixed instruction suffix, is sent as the
    ``prompt`` query parameter of a GET request.

    Args:
        question: The user's question text.
        userId: Not used by this function.

    Returns:
        The decoded JSON body when the service answers with HTTP 200,
        otherwise ``None``. Non-200 responses and exceptions are logged.
    """
    url = "http://192.168.72.100:9040/ai/ollama/syncChatArr"
    # Pass the prompt through ``params`` so it is percent-encoded. When the
    # question is interpolated straight into the URL, characters such as
    # ``&``, ``#``, ``+`` or ``=`` truncate the prompt or inject extra
    # query parameters.
    params = {"prompt": f"{question},其中不要关心数字内容"}
    try:
        async with aiohttp.ClientSession() as session:
            async with session.get(url, params=params) as response:
                if response.status == 200:
                    return await response.json()
                logger.error(f"Error calling generate-sql service: {response.status} - {await response.text()}")
                return None
    except Exception as e:
        # Best-effort backend: report the failure (with traceback) and let
        # the caller continue with the other backends.
        logger.exception(f"Unexpected error while calling generate-sql service: {e}")
        return None
def _parse_kb_chat_line(raw_line: bytes) -> Optional[Dict]:
    """Decode one streamed line and return the JSON object it carries, if any.

    Returns ``None`` for lines without the ``data: `` prefix, for empty
    payloads, for a ``[DONE]`` marker, and for payloads that are not a JSON
    object. Payloads that fail to parse are logged.
    """
    decoded_line = raw_line.decode('utf-8')
    if not decoded_line.startswith('data: '):
        return None
    json_str = decoded_line[len('data: '):].strip()
    # NOTE(review): "[DONE]" is the OpenAI-style end-of-stream marker; skipping
    # it is harmless if this server never sends one.
    if not json_str or json_str == "[DONE]":
        return None
    try:
        json_obj = json.loads(json_str)
    except json.JSONDecodeError as e:
        logger.error(f"无法解析的块: {json_str[:50]}...")
        logger.error(f"错误信息: {e}")
        return None
    return json_obj if isinstance(json_obj, dict) else None


async def call_llm(question: str, userId: str) -> Optional[Dict]:
    """Query the knowledge-base chat service and collect its streamed reply.

    Posts the question (plus a fixed instruction suffix) to ``kb_chat`` with
    streaming enabled, then concatenates every ``choices[].delta.content``
    fragment and gathers every ``docs`` entry from the streamed chunks.

    Args:
        question: The user's question text.
        userId: Not used by this function.

    Returns:
        ``{"answer": <concatenated text>, "docs": <list of docs>}`` when the
        service answers with HTTP 200, otherwise ``None``. Non-200 responses
        and exceptions are logged.
    """
    url = "http://192.168.72.100:7861/chat/kb_chat"
    als = f"{question}?,不要生成,不要统计,只需要展示原文和相近问题的回复有什么"
    # Use the module logger instead of print so the output follows the
    # configured logging setup.
    logger.info("question %s", als)
    payload = {
        "query": als,
        "mode": "local_kb",
        "kb_name": "lqbz",
        "top_k": 3,
        "score_threshold": 2,
        "history": [
            {
                "content": '',
                "role": "user"
            },
            {
                "content": "",
                "role": "assistant"
            }
        ],
        "stream": True,  # enable streaming responses
        "model": "deepseek-r1:7b",
        "temperature": 0.7,
        "max_tokens": 0,
        "prompt_name": "default",
        "return_direct": False
    }
    answer_parts = []
    all_docs = []
    try:
        async with aiohttp.ClientSession() as session:
            async with session.post(url, json=payload) as response:
                if response.status != 200:
                    logger.error(f"Error calling service: {response.status} - {await response.text()}")
                    return None
                async for line in response.content:
                    json_obj = _parse_kb_chat_line(line)
                    if json_obj is None:
                        continue
                    # ``or`` guards keep a null "choices"/"delta"/"docs" in one
                    # chunk from aborting the whole stream.
                    for choice in json_obj.get("choices") or []:
                        content = (choice.get("delta") or {}).get("content", "")
                        if content:
                            # Appended verbatim so line breaks are preserved.
                            answer_parts.append(content)
                    all_docs.extend(json_obj.get("docs") or [])
                return {"answer": "".join(answer_parts), "docs": all_docs}
    except Exception as e:
        # Best-effort backend: report the failure (with traceback) and let
        # the caller continue with the other backends.
        logger.exception(f"Unexpected error while calling service: {e}")
        return None
def _rasa_result(answer=None, ossId=None, fileName=None, filePage=None, graph=None) -> dict:
    """Build the fixed-shape result dict returned by ``call_rasa``."""
    return {
        "answer": answer,
        "ossId": ossId,
        "fileName": fileName,
        "filePage": filePage,
        "graph": graph
    }


def _parse_rasa_messages(data) -> dict:
    """Convert the decoded Rasa reply into the ``call_rasa`` result dict."""
    # An empty or malformed message list carries no answer; return the empty
    # result instead of failing on data[0].
    if not isinstance(data, list) or not data or not isinstance(data[0], dict):
        return _rasa_result()

    first = data[0]
    # A single message with recipient_id and text is a plain text reply.
    if len(data) == 1 and 'recipient_id' in first and 'text' in first:
        return _rasa_result(answer=first.get("text"))

    answer = first.get("text", "No answer found")
    try:
        # The second message's text is a quoted, single-quote-style dict:
        # drop the outer characters and convert quotes so it parses as JSON.
        entity_lists = data[1].get("text", "No graph found")[1:-1] if len(data) > 1 else "None"
        entity_data = json.loads(entity_lists.replace("'", '"'))
        doc = entity_data['docs'][0]
        ossid = doc['ossId']
        file_name = doc['file_name']
        file_page = doc["filePage"]
        graph = {
            "data": entity_data['data'],
            "links": entity_data['links']
        }
    except (json.JSONDecodeError, KeyError, IndexError, TypeError, AttributeError):
        logger.error("Failed to parse additional data from Rasa response.")
        # No usable extra data: return only the basic answer.
        return _rasa_result(answer=answer)
    return _rasa_result(
        answer=answer,
        ossId=str(ossid),
        fileName=file_name,
        filePage=file_page,
        graph=str(graph),
    )


async def call_rasa(question: str, userId: str) -> dict:
    """Send *question* to the Rasa REST webhook and normalise its reply.

    Args:
        question: The user's question text, sent as ``message``.
        userId: The user identifier, sent as ``sender`` and ``userId``.

    Returns:
        A dict with the keys ``answer``, ``ossId``, ``fileName``,
        ``filePage`` and ``graph``. All values are ``None`` when the call
        does not finish within 10 seconds or the reply has no messages.

    Raises:
        HTTPException: With the upstream status code when Rasa answers with
            a non-200 status, or with status 500 for a non-JSON content type
            or any unexpected error.
    """
    url = "http://localhost:5005/webhooks/rest/webhook"
    # The Rasa REST channel identifies the conversation by "sender"; without
    # it every user shares one conversation. "userId" is kept so any consumer
    # of the old key still sees it.
    payload = {"sender": userId, "userId": userId, "message": question}

    async def _call_rasa():
        try:
            async with aiohttp.ClientSession() as session:
                async with session.post(url, json=payload) as response:
                    if response.status != 200:
                        logger.error(f"Error calling Rasa server: {response.status} - {await response.text()}")
                        raise HTTPException(status_code=response.status, detail="Error while calling Rasa Server")
                    content_type = response.headers.get("Content-Type", "")
                    if "application/json" not in content_type:
                        logger.error(f"Invalid content type: {content_type}")
                        raise HTTPException(status_code=500, detail="Invalid response content type")
                    data = await response.json()
                    logger.info(f"Response from Rasa: {data}")
                    return _parse_rasa_messages(data)
        except HTTPException:
            # Deliberately raised above: keep its status code and detail
            # instead of rewriting it to a generic 500 below.
            raise
        except Exception as e:
            logger.exception(f"Unexpected error while calling Rasa server: {e}")
            raise HTTPException(status_code=500, detail="Internal Server Error") from e

    try:
        # Bound the whole Rasa round trip to 10 seconds.
        return await asyncio.wait_for(_call_rasa(), timeout=10.0)
    except asyncio.TimeoutError:
        # On timeout, return the empty result.
        return _rasa_result()
def _split_think(text: str) -> tuple:
    """Split a model reply into ``(think, answer)`` around its think tags.

    Without a ``<think>`` tag the whole text is the answer. With one, the
    text before the first ``</think>`` (minus the opening tag) is the
    reasoning and the text after it is the answer. If the closing tag is
    missing, everything is treated as reasoning and the answer is empty.
    """
    if "<think>" not in text:
        return "", text
    # partition never raises; split(...)[1] raised IndexError when the
    # closing tag was missing (e.g. a truncated stream).
    head, _, tail = text.partition("</think>")
    return head.replace("<think>", ""), tail


@app.post("/kgqa/ask")
async def ask_rasa(query: Query):
    """Fan the question out to the three backends and merge their replies.

    The generate-SQL, knowledge-base LLM and Rasa calls run concurrently. A
    backend that raises is logged and left out of the response.

    Returns:
        ``{"code": 200, "msg": "操作成功", "data": {...}}`` where ``data``
        holds ``userId`` plus, when available, ``sqlAnswer``, ``llmAnswer``
        (``think``/``answer``/``docs``) and ``graphAnswer`` (``None`` when
        Rasa produced no graph).
    """
    tasks = [
        call_generate_sql(query.question, query.userId),
        call_llm(query.question, query.userId),
        call_rasa(query.question, query.userId)
    ]
    results = await asyncio.gather(*tasks, return_exceptions=True)

    consolidated_response = {"userId": query.userId}
    for idx, result in enumerate(results):
        if isinstance(result, Exception):
            logger.error(f"Error occurred during request {idx}: {result}")
            continue
        if idx == 0:
            # The SQL answer is reported even when it is None.
            consolidated_response["sqlAnswer"] = result
        elif result is None:
            continue
        elif idx == 1:
            think, answer = _split_think(result['answer'])
            consolidated_response["llmAnswer"] = {
                "think": think,
                "answer": answer,
                "docs": result['docs'],
            }
        elif idx == 2:
            # A Rasa reply without graph data is reported as None.
            consolidated_response["graphAnswer"] = None if result['graph'] is None else result

    return {"code": 200, "msg": "操作成功", "data": consolidated_response}
@app.exception_handler(HTTPException)
async def http_exception_handler(request: Request, exc: HTTPException):
    """Render an ``HTTPException`` as a JSON ``{"detail": ...}`` response.

    Args:
        request: The request being handled (unused).
        exc: The raised exception; its status code, detail and headers are
            copied onto the response.

    Returns:
        A ``JSONResponse`` carrying the exception's status code and detail.
    """
    return JSONResponse(
        status_code=exc.status_code,
        content={"detail": exc.detail},
        # Forward headers attached to the exception (None when absent) so
        # they are not silently dropped by this custom handler.
        headers=getattr(exc, "headers", None),
    )
if __name__ == "__main__":
    # Imported here so uvicorn is only needed when the file is run as a script.
    import uvicorn

    # Serve the app on all interfaces, port 7074.
    uvicorn.run(app, host="0.0.0.0", port=7074)
|