Selaa lähdekoodia

Merge branch 'develop' of http://47.108.150.237:10000/www/taais into develop

WANGKANG 8 kuukautta sitten
vanhempi
sitoutus
313e82562e

+ 64 - 23
taais-modules/taais-biz/src/main/java/com/taais/biz/controller/DataAugmentationController.java

@@ -9,6 +9,8 @@ import java.nio.file.Files;
 import java.nio.file.Path;
 import java.nio.file.Paths;
 import java.util.*;
+import java.util.regex.Matcher;
+import java.util.regex.Pattern;
 import java.util.stream.Collectors;
 import java.util.stream.Stream;
 
@@ -20,7 +22,9 @@ import com.taais.biz.domain.bo.VideoStableStartResultBo;
 import com.taais.biz.domain.vo.DataAugmentationVo;
 import com.taais.biz.service.IVideoStableService;
 import com.taais.common.core.service.OssService;
+import com.taais.system.domain.SysDictData;
 import com.taais.system.domain.SysOss;
+import com.taais.system.service.ISysDictDataService;
 import com.taais.system.service.ISysOssService;
 import jakarta.validation.Valid;
 import lombok.RequiredArgsConstructor;
@@ -60,6 +64,8 @@ public class DataAugmentationController extends BaseController {
     @Autowired
     private ISysOssService ossService;
     @Autowired
+    private ISysDictDataService sysDictDataService;
+    @Autowired
     private IVideoStableService videoStableService;
     @GetMapping("/compare/num/{task_id}")
 
@@ -76,31 +82,46 @@ public class DataAugmentationController extends BaseController {
                 .orElseThrow(() -> new IllegalArgumentException("Index " + idx + " is out of bounds for the list of files in " + inputPath));
         }
     }
-
+    //提取文件名称中数字进行排序
+    private static int extractNumber(String fileName) {
+        Pattern pattern = Pattern.compile("^\\d+");
+        Matcher matcher = pattern.matcher(fileName);
+        if (matcher.find()) {
+            return Integer.parseInt(matcher.group());
+        }
+        // 如果文件名不是以数字开头,则返回一个很大的数,使其排在最后
+        return -1;
+    }
     @GetMapping("/compare/{task_id}/{idx}")
     public ResponseEntity<Map<String,List<String>>> getCompareImages(@PathVariable("task_id") Long taskId, @PathVariable("idx") int idx) {
         try {
+            Map<String, List<String>> images = new HashMap<>();
+            ArrayList<String> stable = new ArrayList<>();
+            ArrayList<String> origin = new ArrayList<>();
             DataAugmentation dataAugmentation = dataAugmentationService.getById(taskId);
-
             Path inputPath = Paths.get(dataAugmentation.getInputPath());
             Path outputPath = Paths.get(dataAugmentation.getOutputPath());
 //            System.out.println("inputPath: " + inputPath.toString());
 //            System.out.println("outputPath: " + outputPath.toString());
             if (!Files.exists(inputPath) || !Files.isDirectory(inputPath)) {
                 System.out.println("输入路径不存在或不是目录:" + inputPath.toString());
-                return ResponseEntity.status(500).build();
+                origin.add("输入路径不存在或不是目录:" + inputPath.toString());
+                images.put("error", origin);
+                return ResponseEntity.status(500).body(images);
             }
             if (!Files.exists(outputPath) || !Files.isDirectory(outputPath)) {
                 System.out.println("输出路径不存在或不是目录:" + outputPath.toString());
-                return ResponseEntity.status(500).build();
+                origin.add("输出路径不存在或不是目录:" + outputPath.toString());
+                images.put("error", origin);
+                return ResponseEntity.status(500).body(images);
             }
             Path imagePath = getImageAtPathIdx(inputPath, idx);  //按照自然排序获取索引为idx的imagePath
             List<String> outputFileList = new ArrayList<>(); //初始化结果文件list
             List<String> inputFileList = new ArrayList<>(); //初始化输入文件list
 
-
+            String taskType = dataAugmentation.getTaskType();
             //图像拼接算法、目标毁伤评估有多张输入图片,则imagePath是个目录
-            if ("侦察图像拼接算法_sift".equals(dataAugmentation.getTaskType()) || "侦察图像拼接算法_coordinate".equals(dataAugmentation.getTaskType())) {
+            if ("侦察图像拼接算法_sift".equals(taskType) || "侦察图像拼接算法_coordinate".equals(taskType) || "目标毁伤评估".equals(taskType)) {
                 String lastDirectoryName = imagePath.getFileName().toString();
                 inputPath = imagePath;
                 outputPath = outputPath.resolve(lastDirectoryName);  //得到算法输出的结果目录
@@ -113,11 +134,33 @@ public class DataAugmentationController extends BaseController {
                         .filter(path -> !path.toString().toLowerCase().endsWith(".txt")).map(path -> {
                             return path.getFileName().toString();
                         }).collect(Collectors.toList());
+                    if ("目标毁伤评估".equals(taskType)) {
+                        QueryWrapper queryWrapper = new QueryWrapper();
+                        queryWrapper.eq("dict_type", "biz_data_augmentation");
+                        queryWrapper.eq("dict_label", "目标毁伤评估必须输出文件名称");
+                        List<SysDictData> list = sysDictDataService.list(queryWrapper);
+                        if (list.isEmpty()) {
+                            origin.add("未在数据字典中定义目标毁伤评估必须输出文件名称");
+                            images.put("error", origin);
+                            return ResponseEntity.status(500).body(images);
+                        } else {
+                            String[] split = list.get(0).getDictValue().split(",");
+                            for (String fileName: split) {
+                                if (!outputFileList.contains(fileName)) {
+                                    origin.add("输出目录不存在结果文件:" + fileName);
+                                    images.put("error", origin);
+                                    return ResponseEntity.status(500).body(images);
+                                }
+                            }
+                        }
+                        outputFileList.sort(Comparator.comparingInt(DataAugmentationController::extractNumber));
+                    }
                 } else {
-                    System.out.println("图像拼接算法未创建结果目录!");
-                    return ResponseEntity.status(500).build();
+                    System.out.println("图像拼接算法任务:" + dataAugmentation.getId() + "未创建结果目录!");
+                    origin.add("图像拼接算法任务:" + dataAugmentation.getId() + "未创建结果目录!");
+                    images.put("error", origin);
+                    return ResponseEntity.status(500).body(images);
                 }
-                // 收集所有文件的路径,并编码为Base64字符串
                 Stream<Path> inputFilePathStream = Files.list(imagePath);
                 inputFileList = inputFilePathStream
                     .filter(Files::isRegularFile)  // 只选择常规文件(排除子目录、排除图像拼接算法_sift输入中的txt文件)
@@ -130,31 +173,29 @@ public class DataAugmentationController extends BaseController {
                 inputFileList.add(inputFileName);
                 outputFileList.add(inputFileName);
             }
-            Map<String, List<String>> images = new HashMap<>();
-            ArrayList<String> stable = new ArrayList<>();
-            ArrayList<String> origin = new ArrayList<>();
+
             if (inputFileList.isEmpty()) {
                 System.out.println(inputPath.toString() + ":输入文件为空" );
-                return ResponseEntity.status(500).build();
+                origin.add(inputPath.toString() + ":输入文件为空");
+                images.put("error", origin);
+                return ResponseEntity.status(500).body(images);
             }
             if (outputFileList.isEmpty()) {
-                System.out.println(outputFileList.toString() + ":输入文件为空" );
-                return ResponseEntity.status(500).build();
+                System.out.println(outputFileList.toString() + ":输出文件为空" );
+                origin.add(outputFileList.toString() + ":输入文件为空");
+                images.put("error", origin);
+                return ResponseEntity.status(500).body(images);
             }
             for (String inputFileName: inputFileList) {
                 Path inputFilePath = inputPath.resolve(inputFileName);
-                if (!Files.exists(inputFilePath)) {
-                    System.out.println("输入文件不存在:" + inputFilePath.toString());
-                    return ResponseEntity.status(500).build();
-                }
                 origin.add(Base64.getEncoder().encodeToString(Files.readAllBytes(inputFilePath)));
             }
             for (String outputFileName: outputFileList) {
                 Path outputFilePath = outputPath.resolve(outputFileName);
-                if (!Files.exists(outputFilePath)) {
-                    System.out.println("结果文件不存在:" + outputFilePath.toString());
-                    return ResponseEntity.status(500).build();
-                }
+//                if (!Files.exists(outputFilePath)) {
+//                    System.out.println("结果文件不存在:" + outputFilePath.toString());
+//                    return ResponseEntity.status(500).build();
+//                }
                 stable.add(Base64.getEncoder().encodeToString(Files.readAllBytes(outputFilePath)));
             }
             images.put("origin", origin);

+ 11 - 11
taais-modules/taais-biz/src/main/java/com/taais/biz/service/impl/DataAugmentationServiceImpl.java

@@ -438,19 +438,19 @@ public class DataAugmentationServiceImpl extends BaseServiceImpl<DataAugmentatio
         File[] files = directory.listFiles();
 
         // 初始化文件计数器
-        int count = 0;
+//        int count = 0;
 
         // 遍历文件和子文件夹
-        for (File file : files) {
-            if (file.isFile()) {
-                // 如果是文件,计数器加1
-                count++;
-            } else if (file.isDirectory()) {
-                // 如果是子文件夹,递归调用countFiles方法
-                count += countFiles(file);
-            }
-        }
+//        for (File file : files) {
+//            if (file.isFile()) {
+//                // 如果是文件,计数器加1
+//                count++;
+//            } else if (file.isDirectory()) {
+//                // 如果是子文件夹,递归调用countFiles方法
+//                count += countFiles(file);
+//            }
+//        }
 
-        return count;
+        return files.length;
     }
 }