|  | @@ -0,0 +1,198 @@
 | 
	
		
			
				|  |  | +package org.ruoyi.common.chat.localModels;
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +import io.micrometer.common.util.StringUtils;
 | 
	
		
			
				|  |  | +import lombok.extern.slf4j.Slf4j;
 | 
	
		
			
				|  |  | +import okhttp3.OkHttpClient;
 | 
	
		
			
				|  |  | +import org.ruoyi.common.chat.entity.models.LocalModelsSearchRequest;
 | 
	
		
			
				|  |  | +import org.ruoyi.common.chat.entity.models.LocalModelsSearchResponse;
 | 
	
		
			
				|  |  | +import org.springframework.stereotype.Service;
 | 
	
		
			
				|  |  | +import retrofit2.Call;
 | 
	
		
			
				|  |  | +import retrofit2.Callback;
 | 
	
		
			
				|  |  | +import retrofit2.Response;
 | 
	
		
			
				|  |  | +import retrofit2.Retrofit;
 | 
	
		
			
				|  |  | +import retrofit2.converter.jackson.JacksonConverterFactory;
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +import java.util.List;
 | 
	
		
			
				|  |  | +import java.util.concurrent.CountDownLatch;
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +@Slf4j
 | 
	
		
			
				|  |  | +@Service
 | 
	
		
			
				|  |  | +public class LocalModelsofitClient {
 | 
	
		
			
				|  |  | +    private static final String BASE_URL = "http://127.0.0.1:5000"; // Flask 服务的 URL
 | 
	
		
			
				|  |  | +    private static Retrofit retrofit = null;
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +    // 获取 Retrofit 实例
 | 
	
		
			
				|  |  | +    public static Retrofit getRetrofitInstance() {
 | 
	
		
			
				|  |  | +        if (retrofit == null) {
 | 
	
		
			
				|  |  | +            OkHttpClient client = new OkHttpClient.Builder()
 | 
	
		
			
				|  |  | +                    .build();
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +            retrofit = new Retrofit.Builder()
 | 
	
		
			
				|  |  | +                    .baseUrl(BASE_URL)
 | 
	
		
			
				|  |  | +                    .client(client)
 | 
	
		
			
				|  |  | +                    .addConverterFactory(JacksonConverterFactory.create()) // 使用 Jackson 处理 JSON 转换
 | 
	
		
			
				|  |  | +                    .build();
 | 
	
		
			
				|  |  | +        }
 | 
	
		
			
				|  |  | +        return retrofit;
 | 
	
		
			
				|  |  | +    }
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +    /**
 | 
	
		
			
				|  |  | +     * 向 Flask 服务发送文本向量化请求
 | 
	
		
			
				|  |  | +     *
 | 
	
		
			
				|  |  | +     * @param queries 查询文本列表
 | 
	
		
			
				|  |  | +     * @param modelName 模型名称
 | 
	
		
			
				|  |  | +     * @param delimiter 文本分隔符
 | 
	
		
			
				|  |  | +     * @param topK 返回的结果数
 | 
	
		
			
				|  |  | +     * @param blockSize 文本块大小
 | 
	
		
			
				|  |  | +     * @param overlapChars 重叠字符数
 | 
	
		
			
				|  |  | +     * @return 返回计算得到的 Top K 嵌入向量列表
 | 
	
		
			
				|  |  | +     */
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +    public static List<List<Double>> getTopKEmbeddings(
 | 
	
		
			
				|  |  | +            List<String> queries,
 | 
	
		
			
				|  |  | +            String modelName,
 | 
	
		
			
				|  |  | +            String delimiter,
 | 
	
		
			
				|  |  | +            int topK,
 | 
	
		
			
				|  |  | +            int blockSize,
 | 
	
		
			
				|  |  | +            int overlapChars) {
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +        modelName = (!StringUtils.isEmpty(modelName)) ? modelName : "msmarco-distilbert-base-tas-b"; // 默认模型名称
 | 
	
		
			
				|  |  | +        delimiter = (!StringUtils.isEmpty(delimiter) ) ? delimiter : ".";                             // 默认分隔符
 | 
	
		
			
				|  |  | +        topK = (topK > 0) ? topK : 3;                                                  // 默认返回 3 个结果
 | 
	
		
			
				|  |  | +        blockSize = (blockSize > 0) ? blockSize : 500;                                 // 默认文本块大小为 500
 | 
	
		
			
				|  |  | +        overlapChars = (overlapChars > 0) ? overlapChars : 50;                         // 默认重叠字符数为 50
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +        // 创建 Retrofit 实例
 | 
	
		
			
				|  |  | +        Retrofit retrofit = getRetrofitInstance();
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +        // 创建 SearchService 接口
 | 
	
		
			
				|  |  | +        SearchService service = retrofit.create(SearchService.class);
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +        // 创建请求对象 LocalModelsSearchRequest
 | 
	
		
			
				|  |  | +        LocalModelsSearchRequest request = new LocalModelsSearchRequest(
 | 
	
		
			
				|  |  | +                queries,            // 查询文本列表
 | 
	
		
			
				|  |  | +                modelName,          // 模型名称
 | 
	
		
			
				|  |  | +                delimiter,          // 文本分隔符
 | 
	
		
			
				|  |  | +                topK,               // 返回的结果数
 | 
	
		
			
				|  |  | +                blockSize,          // 文本块大小
 | 
	
		
			
				|  |  | +                overlapChars        // 重叠字符数
 | 
	
		
			
				|  |  | +        );
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +        final CountDownLatch latch = new CountDownLatch(1);  // 创建一个 CountDownLatch
 | 
	
		
			
				|  |  | +        final List<List<Double>>[] topKEmbeddings = new List[]{null}; // 使用数组来存储结果(因为 Java 不支持直接修改 List)
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +        // 发起异步请求
 | 
	
		
			
				|  |  | +        service.vectorize(request).enqueue(new Callback<LocalModelsSearchResponse>() {
 | 
	
		
			
				|  |  | +            @Override
 | 
	
		
			
				|  |  | +            public void onResponse(Call<LocalModelsSearchResponse> call, Response<LocalModelsSearchResponse> response) {
 | 
	
		
			
				|  |  | +                if (response.isSuccessful()) {
 | 
	
		
			
				|  |  | +                    LocalModelsSearchResponse searchResponse = response.body();
 | 
	
		
			
				|  |  | +                    if (searchResponse != null) {
 | 
	
		
			
				|  |  | +                        topKEmbeddings[0] = searchResponse.getTopKEmbeddings().get(0);  // 获取结果
 | 
	
		
			
				|  |  | +                        log.info("Successfully retrieved embeddings");
 | 
	
		
			
				|  |  | +                    } else {
 | 
	
		
			
				|  |  | +                        log.error("Response body is null");
 | 
	
		
			
				|  |  | +                    }
 | 
	
		
			
				|  |  | +                } else {
 | 
	
		
			
				|  |  | +                    log.error("Request failed. HTTP error code: " + response.code());
 | 
	
		
			
				|  |  | +                }
 | 
	
		
			
				|  |  | +                latch.countDown();  // 请求完成,减少计数
 | 
	
		
			
				|  |  | +            }
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +            @Override
 | 
	
		
			
				|  |  | +            public void onFailure(Call<LocalModelsSearchResponse> call, Throwable t) {
 | 
	
		
			
				|  |  | +                t.printStackTrace();
 | 
	
		
			
				|  |  | +                log.error("Request failed: ", t);
 | 
	
		
			
				|  |  | +                latch.countDown();  // 请求失败,减少计数
 | 
	
		
			
				|  |  | +            }
 | 
	
		
			
				|  |  | +        });
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +        try {
 | 
	
		
			
				|  |  | +            latch.await();  // 等待请求完成
 | 
	
		
			
				|  |  | +        } catch (InterruptedException e) {
 | 
	
		
			
				|  |  | +            e.printStackTrace();
 | 
	
		
			
				|  |  | +        }
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +        return topKEmbeddings[0];  // 返回结果
 | 
	
		
			
				|  |  | +    }
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +//    public static void main(String[] args) {
 | 
	
		
			
				|  |  | +//        // 示例调用
 | 
	
		
			
				|  |  | +//        List<String> queries = Arrays.asList("What is artificial intelligence?", "AI is transforming industries.");
 | 
	
		
			
				|  |  | +//        String modelName = "msmarco-distilbert-base-tas-b";
 | 
	
		
			
				|  |  | +//        String delimiter = ".";
 | 
	
		
			
				|  |  | +//        int topK = 3;
 | 
	
		
			
				|  |  | +//        int blockSize = 500;
 | 
	
		
			
				|  |  | +//        int overlapChars = 50;
 | 
	
		
			
				|  |  | +//
 | 
	
		
			
				|  |  | +//        List<List<Double>> topKEmbeddings = getTopKEmbeddings(queries, modelName, delimiter, topK, blockSize, overlapChars);
 | 
	
		
			
				|  |  | +//
 | 
	
		
			
				|  |  | +//        // 打印结果
 | 
	
		
			
				|  |  | +//        if (topKEmbeddings != null) {
 | 
	
		
			
				|  |  | +//            System.out.println("Top K embeddings: ");
 | 
	
		
			
				|  |  | +//            for (List<Double> embedding : topKEmbeddings) {
 | 
	
		
			
				|  |  | +//                System.out.println(embedding);
 | 
	
		
			
				|  |  | +//            }
 | 
	
		
			
				|  |  | +//        } else {
 | 
	
		
			
				|  |  | +//            System.out.println("No embeddings returned.");
 | 
	
		
			
				|  |  | +//        }
 | 
	
		
			
				|  |  | +//    }
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +//    public static void main(String[] args) {
 | 
	
		
			
				|  |  | +//        // 创建 Retrofit 实例
 | 
	
		
			
				|  |  | +//        Retrofit retrofit = LocalModelsofitClient.getRetrofitInstance();
 | 
	
		
			
				|  |  | +//
 | 
	
		
			
				|  |  | +//        // 创建 SearchService 接口
 | 
	
		
			
				|  |  | +//        SearchService service = retrofit.create(SearchService.class);
 | 
	
		
			
				|  |  | +//
 | 
	
		
			
				|  |  | +//        // 创建请求对象 LocalModelsSearchRequest
 | 
	
		
			
				|  |  | +//        LocalModelsSearchRequest request = new LocalModelsSearchRequest(
 | 
	
		
			
				|  |  | +//                Arrays.asList("What is artificial intelligence?", "AI is transforming industries."), // 查询文本列表
 | 
	
		
			
				|  |  | +//                "msmarco-distilbert-base-tas-b",  // 模型名称
 | 
	
		
			
				|  |  | +//                ".",  // 分隔符
 | 
	
		
			
				|  |  | +//                3,  // 返回的结果数
 | 
	
		
			
				|  |  | +//                500,  // 文本块大小
 | 
	
		
			
				|  |  | +//                50  // 重叠字符数
 | 
	
		
			
				|  |  | +//        );
 | 
	
		
			
				|  |  | +//
 | 
	
		
			
				|  |  | +//        // 发起请求
 | 
	
		
			
				|  |  | +//        service.vectorize(request).enqueue(new Callback<LocalModelsSearchResponse>() {
 | 
	
		
			
				|  |  | +//            @Override
 | 
	
		
			
				|  |  | +//            public void onResponse(Call<LocalModelsSearchResponse> call, Response<LocalModelsSearchResponse> response) {
 | 
	
		
			
				|  |  | +//                if (response.isSuccessful()) {
 | 
	
		
			
				|  |  | +//                    LocalModelsSearchResponse searchResponse = response.body();
 | 
	
		
			
				|  |  | +//                    System.out.println("Response Body: " + response.body());  // Print the whole response body for debugging
 | 
	
		
			
				|  |  | +//
 | 
	
		
			
				|  |  | +//                    if (searchResponse != null) {
 | 
	
		
			
				|  |  | +//                        // If the response is not null, process it.
 | 
	
		
			
				|  |  | +//                        // Example: Extract the embeddings and print them
 | 
	
		
			
				|  |  | +//                        List<List<List<Double>>> topKEmbeddings = searchResponse.getTopKEmbeddings();
 | 
	
		
			
				|  |  | +//                        if (topKEmbeddings != null) {
 | 
	
		
			
				|  |  | +//                            // Print the Top K embeddings
 | 
	
		
			
				|  |  | +//
 | 
	
		
			
				|  |  | +//                        } else {
 | 
	
		
			
				|  |  | +//                            System.err.println("Top K embeddings are null");
 | 
	
		
			
				|  |  | +//                        }
 | 
	
		
			
				|  |  | +//
 | 
	
		
			
				|  |  | +//                        // If there is more information you want to process, handle it here
 | 
	
		
			
				|  |  | +//
 | 
	
		
			
				|  |  | +//                    } else {
 | 
	
		
			
				|  |  | +//                        System.err.println("Response body is null");
 | 
	
		
			
				|  |  | +//                    }
 | 
	
		
			
				|  |  | +//                } else {
 | 
	
		
			
				|  |  | +//                    System.err.println("Request failed. HTTP error code: " + response.code());
 | 
	
		
			
				|  |  | +//                    log.error("Failed to retrieve data. HTTP error code: " + response.code());
 | 
	
		
			
				|  |  | +//                }
 | 
	
		
			
				|  |  | +//            }
 | 
	
		
			
				|  |  | +//
 | 
	
		
			
				|  |  | +//            @Override
 | 
	
		
			
				|  |  | +//            public void onFailure(Call<LocalModelsSearchResponse> call, Throwable t) {
 | 
	
		
			
				|  |  | +//                // 请求失败,打印错误
 | 
	
		
			
				|  |  | +//                t.printStackTrace();
 | 
	
		
			
				|  |  | +//                log.error("Request failed: ", t);
 | 
	
		
			
				|  |  | +//            }
 | 
	
		
			
				|  |  | +//        });
 | 
	
		
			
				|  |  | +//    }
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +}
 |