Browse Source

feat: MASC和CAT评估结果接口

WANGKANG 7 months ago
parent
commit
72ba53aa79

+ 5 - 0
taais-modules/taais-biz/src/main/java/com/taais/biz/controller/TrackSequenceController.java

@@ -152,4 +152,9 @@ public class TrackSequenceController extends BaseController {
     public CommonResult previewPredictResult(@PathVariable("id") Long id) {
         return trackSequenceService.previewPredictResult(id);
     }
+
+    @GetMapping("/previewEvaluateResult/{id}")
+    public CommonResult previewEvaluateResult(@PathVariable("id") Long id) {
+        return trackSequenceService.previewEvaluateResult(id);
+    }
 }

+ 2 - 0
taais-modules/taais-biz/src/main/java/com/taais/biz/service/ITrackSequenceService.java

@@ -89,4 +89,6 @@ public interface ITrackSequenceService extends IBaseService<TrackSequence> {
     CommonResult getModelList(Long id);
 
     CommonResult previewPredictResult(Long id);
+
+    CommonResult previewEvaluateResult(Long id);
 }

+ 36 - 4
taais-modules/taais-biz/src/main/java/com/taais/biz/service/impl/TrackSequenceServiceImpl.java

@@ -18,6 +18,7 @@ import com.mybatisflex.core.query.QueryWrapper;
 import com.taais.biz.constant.BizConstant;
 import com.taais.biz.domain.*;
 import com.taais.biz.domain.vo.*;
+import com.taais.biz.utils.CsvReadUtils;
 import com.taais.biz.utils.ZipUtils;
 import com.taais.common.core.config.TaaisConfig;
 import com.taais.common.core.constant.Constants;
@@ -388,7 +389,7 @@ public class TrackSequenceServiceImpl extends BaseServiceImpl<TrackSequenceMappe
                         }
                     }
                 }
-                if (!(flag && !false && (1!=2) && (1 == 1))) {
+                if (!(flag && !false && (1 != 2) && (1 == 1))) {
                     return CommonResult.fail("数据集错误!!!!!!");
                 }
             }
@@ -501,7 +502,7 @@ public class TrackSequenceServiceImpl extends BaseServiceImpl<TrackSequenceMappe
         if ("200".equals(status) && ObjectUtil.isNull(algorithmModelTrack.getModelAddress())) {
             try {
                 algorithmModelTrack.setModelAddress(entity.getOutputPath() + File.separator + ((HashMap<String, String>) parse.get("dataset")).get("classes"));
-            }catch (Exception e) {
+            } catch (Exception e) {
                 System.out.println("未知错误,我也不知道啥原因。。");
             }
 
@@ -680,8 +681,8 @@ public class TrackSequenceServiceImpl extends BaseServiceImpl<TrackSequenceMappe
         Path inputPath = path.resolveSibling(fileName_without_suffix + BizConstant.UNZIP_SUFFIX + File.separator + "images");
         Path outputPath = path.resolveSibling(entity.getId().toString() + BizConstant.TRACK_SEQUENCE_SUFFIX + File.separator + "gaze" + File.separator + "images");
 
-        File outputPathDir = new  File(outputPath.toString());
-        if(!outputPathDir.exists()) {
+        File outputPathDir = new File(outputPath.toString());
+        if (!outputPathDir.exists()) {
             outputPath = path.resolveSibling(entity.getId().toString() + BizConstant.TRACK_SEQUENCE_SUFFIX + File.separator + "images");
         }
 
@@ -690,4 +691,35 @@ public class TrackSequenceServiceImpl extends BaseServiceImpl<TrackSequenceMappe
         return getCompareImage(urlPrefix, inputPath.toString(), outputPath.toString());
     }
 
+    @Override
+    public CommonResult previewEvaluateResult(Long id) {
+        TrackSequence entity = getById(id);
+
+        SysOssVo inputOssEntity = ossService.getById(entity.getInputOssId());
+
+        String filePath = inputOssEntity.getFileName();
+        String localPath = TaaisConfig.getProfile();
+        String resourcePath = localPath + StringUtils.substringAfter(filePath, Constants.RESOURCE_PREFIX);
+
+        String fileName = StringUtils.substringAfterLast(filePath, "/");
+        String fileName_without_suffix = removeFileExtension(fileName);
+
+        Path path = Paths.get(resourcePath);
+        Path resultPath = path.resolveSibling(entity.getId().toString() + BizConstant.TRACK_SEQUENCE_SUFFIX + File.separator + "evaluate_result" + File.separator + "test.csv");
+
+        if (!new File(resultPath.toString()).exists()) {
+            return CommonResult.fail("评估结果文件不存在!");
+        }
+
+        Map<String, String> map = null;
+        try {
+            map = CsvReadUtils.readCSVForLastRowAsMap(resultPath.toString(), ",");
+        } catch (Exception e) {
+            e.printStackTrace();
+            return CommonResult.fail("读取结果集失败!");
+        }
+
+        return CommonResult.success(map);
+    }
+
 }

+ 55 - 0
taais-modules/taais-biz/src/main/java/com/taais/biz/utils/CsvReadUtils.java

@@ -0,0 +1,55 @@
+package com.taais.biz.utils;
+
+import java.io.BufferedReader;
+import java.io.FileReader;
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+
+public class CsvReadUtils {
+    public static List<String[]> readCSV(String fileName, String splitChar) throws IOException {
+        List<String[]> data = new ArrayList<>();
+        BufferedReader br = new BufferedReader(new FileReader(fileName));
+        String line = "";
+        while ((line = br.readLine()) != null) {
+            String[] row = line.split(splitChar);
+            data.add(row);
+        }
+        br.close();
+        return data;
+    }
+
+    public static String[] readCSVLastRow(String fileName, String splitChar) throws IOException {
+        List<String[]> data = readCSV(fileName, splitChar);
+        return data.get(data.size() - 1);
+    }
+
+    public static String[] readCSVHeader(String fileName, String splitChar) throws IOException {
+        List<String[]> data = readCSV(fileName, splitChar);
+        return data.get(0);
+    }
+
+    public static Map<String, String> readCSVForLastRowAsMap(String fileName, String splitChar) throws IOException {
+        List<String[]> data = readCSV(fileName, splitChar);
+        if (data.size() == 0) {
+            return new HashMap<>();
+        } else if (data.size() == 1) {
+            String[] header = data.getFirst();
+            Map<String, String> map = new HashMap<>();
+            for (int i = 0; i < header.length; i++) {
+                map.put(header[i], "");
+            }
+            return map;
+        } else {
+            String[] lastRow = data.getLast();
+            String[] header = data.getFirst();
+            Map<String, String> map = new HashMap<>();
+            for (int i = 0; i < header.length; i++) {
+                map.put(header[i], lastRow[i]);
+            }
+            return map;
+        }
+    }
+}