unknown 8 сар өмнө
parent
commit
9d065ffe2a

+ 38 - 20
taais-modules/taais-biz/src/main/java/com/taais/biz/controller/DataAugmentationController.java

@@ -5,10 +5,8 @@ import java.io.IOException;
 import java.nio.file.Files;
 import java.nio.file.Files;
 import java.nio.file.Path;
 import java.nio.file.Path;
 import java.nio.file.Paths;
 import java.nio.file.Paths;
-import java.util.Base64;
-import java.util.HashMap;
-import java.util.List;
-import java.util.Map;
+import java.util.*;
+import java.util.stream.Collectors;
 import java.util.stream.Stream;
 import java.util.stream.Stream;
 
 
 import com.mybatisflex.core.query.QueryWrapper;
 import com.mybatisflex.core.query.QueryWrapper;
@@ -65,8 +63,7 @@ public class DataAugmentationController extends BaseController {
 
 
     private static Path getImageAtPathIdx(Path inputPath, int idx) throws IOException {
     private static Path getImageAtPathIdx(Path inputPath, int idx) throws IOException {
         try (Stream<Path> paths = Files.list(inputPath)) {
         try (Stream<Path> paths = Files.list(inputPath)) {
-            return paths.filter(Files::isRegularFile)
-                .sorted()
+            return paths.sorted()
                 .skip(idx)
                 .skip(idx)
                 .findFirst()
                 .findFirst()
                 .orElseThrow(() -> new IllegalArgumentException("Index " + idx + " is out of bounds for the list of files in " + inputPath));
                 .orElseThrow(() -> new IllegalArgumentException("Index " + idx + " is out of bounds for the list of files in " + inputPath));
@@ -74,7 +71,7 @@ public class DataAugmentationController extends BaseController {
     }
     }
 
 
     @GetMapping("/compare/{task_id}/{idx}")
     @GetMapping("/compare/{task_id}/{idx}")
-    public ResponseEntity<Map<String, String>> getCompareImages(@PathVariable("task_id") Long taskId, @PathVariable("idx") int idx) {
+    public ResponseEntity<Map<String,List<String>>> getCompareImages(@PathVariable("task_id") Long taskId, @PathVariable("idx") int idx) {
         try {
         try {
             DataAugmentation dataAugmentation = dataAugmentationService.getById(taskId);
             DataAugmentation dataAugmentation = dataAugmentationService.getById(taskId);
 
 
@@ -95,22 +92,43 @@ public class DataAugmentationController extends BaseController {
             System.out.println("outputPath: " + outputPath.toString());
             System.out.println("outputPath: " + outputPath.toString());
             Path imagePath = getImageAtPathIdx(inputPath, idx);
             Path imagePath = getImageAtPathIdx(inputPath, idx);
             String fileName = imagePath.getFileName().toString();
             String fileName = imagePath.getFileName().toString();
+            Map<String, List<String>> images = new HashMap<>();
+
+            //图像拼接算法有多个输入图片
+            if ("图像拼接".equals(dataAugmentation.getTaskType())) {
+
+                Stream<Path> pathStream = Files.list(imagePath);
+                Optional<Path> firstPath = pathStream.filter(Files::isRegularFile).findFirst();
+                String firstFileName = firstPath.get().getFileName().toString();
+                int lastDotIndex = firstFileName.lastIndexOf('.');
+                fileName = fileName + firstFileName.substring(lastDotIndex);
+                // 收集所有文件的路径,并编码为Base64字符串
+                Stream<Path> newPathStream = Files.list(imagePath);
+                List<String> origin = newPathStream
+                    .filter(Files::isRegularFile)  // 只选择常规文件(排除子目录)
+                    .map(path -> {
+                        try {
+                            return Base64.getEncoder().encodeToString(Files.readAllBytes(path));
+                        } catch (IOException e) {
+                            throw new RuntimeException(e);
+                        }
+                    })
+                    .collect(Collectors.toList());
+                images.put("origin", origin);
+            } else {
+                byte[] image1 = Files.readAllBytes(imagePath);
+                String base64Image1 = Base64.getEncoder().encodeToString(image1);
+                ArrayList<String> origin = new ArrayList<>();
+                origin.add(base64Image1);
+                images.put("origin", origin);
+            }
             Path resolve = outputPath.resolve(fileName);
             Path resolve = outputPath.resolve(fileName);
-
-
-            byte[] image1 = Files.readAllBytes(imagePath);
             byte[] image2 = Files.readAllBytes(resolve);
             byte[] image2 = Files.readAllBytes(resolve);
 
 
-
-            // 将图片编码成Base64字符串
-            String base64Image1 = Base64.getEncoder().encodeToString(image1);
-            String base64Image2 = Base64.getEncoder().encodeToString(image2);
-
-            // 创建一个Map来存储Base64编码的图片
-            Map<String, String> images = new HashMap<>();
-            images.put("origin", base64Image1);
-            images.put("stable", base64Image2);
-
+            String base64Image2 = Base64.getEncoder().encodeToString(image2);   // 将图片编码成Base64字符串
+            ArrayList<String> stable = new ArrayList<>();
+            stable.add(base64Image2);
+            images.put("stable", stable);
             // 返回Map
             // 返回Map
             return ResponseEntity.ok(images);
             return ResponseEntity.ok(images);
         } catch (Exception e) {
         } catch (Exception e) {

+ 22 - 7
taais-modules/taais-biz/src/main/java/com/taais/biz/service/impl/DataAugmentationServiceImpl.java

@@ -46,12 +46,15 @@ import org.springframework.transaction.annotation.Transactional;
 
 
 import java.io.File;
 import java.io.File;
 import java.io.IOException;
 import java.io.IOException;
+import java.nio.file.DirectoryStream;
+import java.nio.file.Files;
 import java.nio.file.Path;
 import java.nio.file.Path;
 import java.nio.file.Paths;
 import java.nio.file.Paths;
 import java.util.Arrays;
 import java.util.Arrays;
 import java.util.Date;
 import java.util.Date;
 import java.util.HashMap;
 import java.util.HashMap;
 import java.util.List;
 import java.util.List;
+import java.util.stream.Stream;
 
 
 import static com.taais.biz.domain.table.DataAugmentationTableDef.DATA_AUGMENTATION;
 import static com.taais.biz.domain.table.DataAugmentationTableDef.DATA_AUGMENTATION;
 /**
 /**
@@ -245,16 +248,28 @@ public class DataAugmentationServiceImpl extends BaseServiceImpl<DataAugmentatio
 //        makeDir(inputPath.toString());
 //        makeDir(inputPath.toString());
 //        makeDir(outputPath.toString());
 //        makeDir(outputPath.toString());
 //        makeDir(logPath.toString());
 //        makeDir(logPath.toString());
-        int lastUnderscoreIndex = fileName_without_suffix.lastIndexOf('_');
+//        Path path = Paths.get(inputPath);
         ZipUtils.unzip(resourcePath, inputPath.toString());
         ZipUtils.unzip(resourcePath, inputPath.toString());
-//        String inputPathString = "D:\\program\\taais\\duijie\\code-niguang\\" + fileName_without_suffix + "_images\\" + fileName_without_suffix.substring(0, lastUnderscoreIndex);
-        String inputPathString = inputPath.toString() + fileName_without_suffix.substring(0, lastUnderscoreIndex);
-        dataAugmentation.setLog(logPath.toString());
-        dataAugmentation.setInputPath(inputPathString);
+        //获取inputPath的下一级目录
+        if (Files.exists(inputPath) && Files.isDirectory(inputPath)) {
+            try (DirectoryStream<Path> stream = Files.newDirectoryStream(path)) {
+                // 遍历目录流以找到第一个子目录
+                for (Path entry : stream) {
+                    if (Files.isDirectory(entry)) {
+                        // 打印第一个子目录的名称,并跳出循环
+                        dataAugmentation.setInputPath(entry.toString());
+                        break;  // 只处理第一个子目录后退出
+                    }
+                }
+            } catch (IOException e) {
+                return CommonResult.fail(e.toString());
+            }
+        } else {
+            return CommonResult.fail("The provided path is not a valid directory.");
+        }
+        dataAugmentation.setAlgorithmPath(logPath.toString());
         dataAugmentation.setOutputPath(outputPath.toString());
         dataAugmentation.setOutputPath(outputPath.toString());
-
         dataAugmentation.setStartTime(new Date());
         dataAugmentation.setStartTime(new Date());
-
         dataAugmentation.setStatus(BizConstant.VideoStatus.RUNNING);
         dataAugmentation.setStatus(BizConstant.VideoStatus.RUNNING);
         updateById(dataAugmentation);
         updateById(dataAugmentation);
         SysDictDataBo sysDictDataBo = new SysDictDataBo();
         SysDictDataBo sysDictDataBo = new SysDictDataBo();