| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116 | 
							- from flask import Flask, request, jsonify
 
- from sentence_transformers import SentenceTransformer
 
- from sklearn.metrics.pairwise import cosine_similarity
 
- import json
 
# Flask application that hosts the /vectorize endpoint defined below.
app = Flask(__name__)

# Process-wide cache of loaded SentenceTransformer models keyed by model name;
# vectorize_text fills it so repeated requests reuse an already-loaded model.
model_cache = dict()
 
def split_text(text, block_size, overlap_chars, delimiter):
    """Split *text* into blocks of roughly *block_size* characters.

    The text is cut on *delimiter*; consecutive pieces are re-joined with a
    single space while the running block stays within *block_size*. A piece
    that is itself longer than *block_size* becomes a block on its own (it is
    not truncated). Each block after the first is then prefixed with the last
    *overlap_chars* characters of the block before it, so text that spans a
    block boundary still appears together in one block. Because the prefix is
    added after packing, a block may exceed *block_size* by up to
    *overlap_chars*.

    Args:
        text: Source text to split.
        block_size: Target maximum length of a packed block, in characters.
        overlap_chars: Number of trailing characters of the previous block to
            prepend to each following block; ``0`` or less disables overlap.
        delimiter: Separator for the initial split (passed to ``str.split``).

    Returns:
        A list of non-empty text blocks in document order; empty when *text*
        contains nothing but delimiters.
    """
    text_blocks = []
    current_block = ""
    for chunk in text.split(delimiter):
        # Only account for the joining space when there is something to join.
        separator = " " if current_block else ""
        if len(current_block) + len(separator) + len(chunk) <= block_size:
            current_block += separator + chunk
        else:
            # Flush only real content: an oversized first chunk must not
            # produce an empty block.
            if current_block:
                text_blocks.append(current_block)
            current_block = chunk
    if current_block:
        text_blocks.append(current_block)

    # s[-0:] is the whole string, so zero overlap must be handled explicitly.
    if overlap_chars <= 0:
        return text_blocks

    # Each later block carries its predecessor's tail. The bare block is not
    # emitted a second time, which would only yield near-duplicate hits.
    overlap_blocks = text_blocks[:1]
    for previous, block in zip(text_blocks, text_blocks[1:]):
        overlap_blocks.append(previous[-overlap_chars:] + block)
    return overlap_blocks
 
def vectorize_text_blocks(text_blocks, model):
    """Embed all of *text_blocks* in one call to *model*.

    Args:
        text_blocks: Sequence of strings to embed.
        model: Object exposing ``encode(sentences)``; in this module a
            SentenceTransformer taken from ``model_cache``.

    Returns:
        Whatever ``model.encode`` returns for the batch. Callers index the
        result in step with *text_blocks*.
    """
    embeddings = model.encode(text_blocks)
    return embeddings
 
def retrieve_top_k(query, knowledge_base, k, block_size, overlap_chars, delimiter, model):
    """Return the *k* knowledge-base blocks most similar to *query*.

    *knowledge_base* is split with ``split_text``. Every block and the query
    are embedded with *model*, and blocks are ranked by cosine similarity to
    the query, highest first.

    Args:
        query: Text to search for.
        knowledge_base: Text to search in.
        k: Maximum number of blocks to return; values <= 0 yield no results.
        block_size: Forwarded to ``split_text``.
        overlap_chars: Forwarded to ``split_text``.
        delimiter: Forwarded to ``split_text``.
        model: Embedding model exposing ``encode`` (a SentenceTransformer here).

    Returns:
        ``(texts, embeddings)``: two parallel lists of at most *k* items,
        ordered from most to least similar. Both are empty when the knowledge
        base yields no blocks or *k* <= 0.
    """
    # Split the knowledge base into text blocks.
    text_blocks = split_text(knowledge_base, block_size, overlap_chars, delimiter)

    # Bail out before encoding. cosine_similarity rejects an empty matrix, and
    # a non-positive k would make the old ``[-k:]`` slice select all (k == 0)
    # or nearly all (k < 0) blocks.
    if not text_blocks or k <= 0:
        return [], []

    # Embed the blocks and the query; reshape guarantees a 2-D single-row
    # matrix, the input shape cosine_similarity expects.
    knowledge_vectors = vectorize_text_blocks(text_blocks, model)
    query_vector = model.encode([query]).reshape(1, -1)

    similarities = cosine_similarity(query_vector, knowledge_vectors)[0]

    # Indices by descending similarity, truncated to at most k.
    top_k_indices = similarities.argsort()[::-1][:k]

    # Return the selected blocks together with their embeddings.
    top_k_texts = [text_blocks[i] for i in top_k_indices]
    top_k_embeddings = [knowledge_vectors[i] for i in top_k_indices]
    return top_k_texts, top_k_embeddings
 
def _json_error(message, status):
    """Return a Flask ``(response, status)`` pair whose body is ``{"error": message}``."""
    return jsonify({"error": message}), status


@app.route('/vectorize', methods=['POST'])
def vectorize_text():
    """Return the top-k block embeddings for each query text.

    Expects a JSON object body:

    * ``text`` (required): list of strings. ``text[0]`` is the knowledge base,
      and every entry (including ``text[0]``) is run as a query against it.
      A single string is accepted as a one-element list.
    * ``model_name``: SentenceTransformer model name
      (default ``"msmarco-distilbert-base-tas-b"``).
    * ``delimiter`` (default: newline), ``k`` (default 3), ``block_size``
      (default 500), ``overlap_chars`` (default 50): retrieval settings.

    Returns:
        200 with ``{"topKEmbeddings": [...]}``: one list of embeddings (each
        converted with ``tolist()``) per query, in request order.
        400 with ``{"error": ...}`` for a missing or invalid body or parameter.
        500 with ``{"error": ...}`` when the model cannot be loaded.
    """
    # silent=True yields None for a missing, non-JSON or malformed body
    # instead of raising, so it gets the same JSON 400 as other bad input.
    data = request.get_json(silent=True)
    if not isinstance(data, dict):
        return _json_error("Request body must be a JSON object.", 400)
    # Debug level, lazily formatted: request bodies may be large or sensitive.
    app.logger.debug("Received request data: %s", data)

    text_list = data.get("text", [])
    if not text_list:
        return _json_error("Text is required.", 400)
    if isinstance(text_list, str):
        # Otherwise indexing the string would pick single characters as the
        # knowledge base and as queries.
        text_list = [text_list]
    if not isinstance(text_list, list) or not all(isinstance(t, str) for t in text_list):
        return _json_error("Text must be a string or a list of strings.", 400)

    model_name = data.get("model_name", "msmarco-distilbert-base-tas-b")  # default model
    if not isinstance(model_name, str):
        return _json_error("model_name must be a string.", 400)

    delimiter = data.get("delimiter", "\n")  # default delimiter
    # None stays allowed (str.split(None) splits on whitespace runs), but an
    # empty string makes str.split raise ValueError.
    if delimiter is not None and not (isinstance(delimiter, str) and delimiter):
        return _json_error("delimiter must be a non-empty string.", 400)

    try:
        k = int(data.get("k", 3))  # blocks returned per query
        block_size = int(data.get("block_size", 500))  # target block length
        overlap_chars = int(data.get("overlap_chars", 50))  # overlap length
    except (TypeError, ValueError, OverflowError):
        return _json_error("k, block_size and overlap_chars must be integers.", 400)

    model = model_cache.get(model_name)
    if model is None:
        # NOTE(review): model_name is client-controlled, so any caller can make
        # the server load (and presumably download) an arbitrary model, and
        # model_cache is unbounded. Consider an allowlist.
        try:
            model = SentenceTransformer(model_name)
        except Exception as e:  # request boundary: report any load failure
            app.logger.exception("Failed to load model %r", model_name)
            return _json_error(f"Failed to load model: {e}", 500)
        model_cache[model_name] = model  # reuse on later requests

    # text_list[0] is the knowledge base. Every entry, including the first, is
    # queried against it, so a one-element list queries the text with itself.
    # NOTE(review): retrieve_top_k re-splits and re-encodes the knowledge base
    # on every call; encoding it once per request would save len(text_list)-1
    # redundant encodes.
    top_k_texts_all = []
    top_k_embeddings_all = []
    for query in text_list:
        top_k_texts, top_k_embeddings = retrieve_top_k(
            query, text_list[0], k, block_size, overlap_chars, delimiter, model
        )
        top_k_texts_all.append(top_k_texts)
        # Embedding rows (ndarray) are not JSON-serializable; convert to lists.
        top_k_embeddings_all.append([embedding.tolist() for embedding in top_k_embeddings])

    # Debug level, lazily formatted: dumping every vector on each request is
    # costly and noisy outside development.
    app.logger.debug("Top K texts: %s", top_k_texts_all)
    app.logger.debug("Top K embeddings: %s", top_k_embeddings_all)

    return jsonify({
        "topKEmbeddings": top_k_embeddings_all
    })
 
if __name__ == '__main__':
    # Debug mode is no longer forced on. The Werkzeug interactive debugger
    # permits arbitrary code execution and must not be exposed on 0.0.0.0.
    # Flask's app.run() still enables debug when the FLASK_DEBUG environment
    # variable is set, for local development.
    app.run(host="0.0.0.0", port=5000)
 
 
  |