Bläddra i källkod

feat: 目标毁伤评估

28968 8 månader sedan
förälder
incheckning
e3ec9eab67

+ 62 - 55
taais-modules/taais-biz/src/main/java/com/taais/biz/controller/DataAugmentationController.java

@@ -82,77 +82,84 @@ public class DataAugmentationController extends BaseController {
         try {
             DataAugmentation dataAugmentation = dataAugmentationService.getById(taskId);
 
-            Path inputPath = null;
-            Path outputPath = null;
-            String osName = System.getProperty("os.name");
-            // 判断是否是Windows环境
-//            if (osName.toLowerCase().contains("windows")) {
-//                inputPath = Paths.get("D:", dataAugmentation.getInputPath());
-//                outputPath = Paths.get("D:", dataAugmentation.getOutputPath());
-//            } else {
-//                inputPath = Paths.get(dataAugmentation.getInputPath());
-//                outputPath = Paths.get(dataAugmentation.getOutputPath());
-//            }
-            inputPath = Paths.get(dataAugmentation.getInputPath());
-            outputPath = Paths.get(dataAugmentation.getOutputPath());
-            System.out.println("inputPath: " + inputPath.toString());
-            System.out.println("outputPath: " + outputPath.toString());
-            Path imagePath = getImageAtPathIdx(inputPath, idx);
+            Path inputPath = Paths.get(dataAugmentation.getInputPath());
+            Path outputPath = Paths.get(dataAugmentation.getOutputPath());
+//            System.out.println("inputPath: " + inputPath.toString());
+//            System.out.println("outputPath: " + outputPath.toString());
+            if (!Files.exists(inputPath) || !Files.isDirectory(inputPath)) {
+                System.out.println("输入路径不存在或不是目录:" + inputPath.toString());
+                return ResponseEntity.status(500).build();
+            }
+            if (!Files.exists(outputPath) || !Files.isDirectory(outputPath)) {
+                System.out.println("输出路径不存在或不是目录:" + outputPath.toString());
+                return ResponseEntity.status(500).build();
+            }
+            Path imagePath = getImageAtPathIdx(inputPath, idx);  //按照自然排序获取索引为idx的imagePath
+            List<String> outputFileList = new ArrayList<>(); //初始化结果文件list
+            List<String> inputFileList = new ArrayList<>(); //初始化输入文件list
 
-            String fileName = imagePath.getFileName().toString();
-            Map<String, List<String>> images = new HashMap<>();
-            //图像拼接算法有多个输入图片
+
+            //图像拼接算法、目标毁伤评估有多张输入图片,则imagePath是个目录
             if ("侦察图像拼接算法_sift".equals(dataAugmentation.getTaskType()) || "侦察图像拼接算法_coordinate".equals(dataAugmentation.getTaskType())) {
                 String lastDirectoryName = imagePath.getFileName().toString();
-                outputPath = outputPath.resolve(lastDirectoryName);  //得到推理结果目录
-                Stream<Path> stream = Files.list(outputPath);
+                inputPath = imagePath;
+                outputPath = outputPath.resolve(lastDirectoryName);  //得到算法输出的结果目录
+
                 if (Files.exists(outputPath) && Files.isDirectory(outputPath)) {
+                    Stream<Path> outFilePathStream = Files.list(outputPath);
                     // 检查流中是否有下一个元素
-                    Path firstRegularFile = stream.filter(Files::isRegularFile).findFirst().orElse(null);
-                    if (firstRegularFile != null) {
-                        // 找到第一个常规文件
-                            fileName = firstRegularFile.getFileName().toString();
-                            System.out.println("找到的第一个常规文件: " + fileName);
-
-
-                    } else {
-                        System.out.println("图像拼接算法未在结果目录下保存结果文件!");
-                        return ResponseEntity.status(500).build();
-                    }
+                    outputFileList = outFilePathStream
+                        .filter(Files::isRegularFile)  // 只选择常规文件(排除子目录、排除图像拼接算法_sift输入中的txt文件)
+                        .filter(path -> !path.toString().toLowerCase().endsWith(".txt")).map(path -> {
+                            return path.getFileName().toString();
+                        }).collect(Collectors.toList());
                 } else {
                     System.out.println("图像拼接算法未创建结果目录!");
                     return ResponseEntity.status(500).build();
                 }
                 // 收集所有文件的路径,并编码为Base64字符串
-                Stream<Path> newPathStream = Files.list(imagePath);
-                List<String> origin = newPathStream
-                    .filter(Files::isRegularFile)  // 只选择常规文件(排除子目录)
+                Stream<Path> inputFilePathStream = Files.list(imagePath);
+                inputFileList = inputFilePathStream
+                    .filter(Files::isRegularFile)  // 只选择常规文件(排除子目录、排除图像拼接算法_sift输入中的txt文件
                     .filter(path -> !path.toString().toLowerCase().endsWith(".txt")).map(path -> {
-                        try {
-                            return Base64.getEncoder().encodeToString(Files.readAllBytes(path));
-                        } catch (IOException e) {
-                            throw new RuntimeException(e);
-                        }
-                    })
-                    .collect(Collectors.toList());
-                images.put("origin", origin);
+                            return path.getFileName().toString();
+                    }).collect(Collectors.toList());
             } else {
-                byte[] image1 = Files.readAllBytes(imagePath);
-                String base64Image1 = Base64.getEncoder().encodeToString(image1);
-                ArrayList<String> origin = new ArrayList<>();
-                origin.add(base64Image1);
-                images.put("origin", origin);
+                //算法输入是一对一时,输出文件名称和输入文件名称相同
+                String inputFileName = imagePath.getFileName().toString();
+                inputFileList.add(inputFileName);
+                outputFileList.add(inputFileName);
             }
-            Path resolve = outputPath.resolve(fileName);
-            if (!Files.exists(resolve)) {
-                System.out.println("结果文件不存在:" + resolve.toString());
+            Map<String, List<String>> images = new HashMap<>();
+            ArrayList<String> stable = new ArrayList<>();
+            ArrayList<String> origin = new ArrayList<>();
+            if (inputFileList.isEmpty()) {
+                System.out.println(inputPath.toString() + ":输入文件为空" );
                 return ResponseEntity.status(500).build();
             }
-            byte[] image2 = Files.readAllBytes(resolve);
-            String base64Image2 = Base64.getEncoder().encodeToString(image2);   // 将图片编码成Base64字符串
-            ArrayList<String> stable = new ArrayList<>();
-            stable.add(base64Image2);
+            if (outputFileList.isEmpty()) {
+                System.out.println(outputFileList.toString() + ":输入文件为空" );
+                return ResponseEntity.status(500).build();
+            }
+            for (String inputFileName: inputFileList) {
+                Path inputFilePath = inputPath.resolve(inputFileName);
+                if (!Files.exists(inputFilePath)) {
+                    System.out.println("输入文件不存在:" + inputFilePath.toString());
+                    return ResponseEntity.status(500).build();
+                }
+                origin.add(Base64.getEncoder().encodeToString(Files.readAllBytes(inputFilePath)));
+            }
+            for (String outputFileName: outputFileList) {
+                Path outputFilePath = outputPath.resolve(outputFileName);
+                if (!Files.exists(outputFilePath)) {
+                    System.out.println("结果文件不存在:" + outputFilePath.toString());
+                    return ResponseEntity.status(500).build();
+                }
+                stable.add(Base64.getEncoder().encodeToString(Files.readAllBytes(outputFilePath)));
+            }
+            images.put("origin", origin);
             images.put("stable", stable);
+
             // 返回Map
             return ResponseEntity.ok(images);
         } catch (Exception e) {

+ 1 - 1
taais-modules/taais-biz/src/main/java/com/taais/biz/service/impl/DataAugmentationServiceImpl.java

@@ -279,7 +279,7 @@ public class DataAugmentationServiceImpl extends BaseServiceImpl<DataAugmentatio
             else {
 //                dataAugmentation.setStatus(BizConstant.ModelStatus.FAILED);
 //                updateById(dataAugmentation);
-                return CommonResult.fail("任务开始失败,请检查算法服务!");
+                return CommonResult.fail("任务开始失败,请检查算法服务是否启动!");
             }
         } catch (IOException e) {
             e.printStackTrace();