Ver Fonte

Merge branch 'develop_1016' of http://47.108.150.237:10000/www/taais into dev-lzy

# Conflicts:
#	taais-modules/taais-biz/src/main/java/com/taais/biz/service/impl/DataServiceImpl.java
Eureka há 8 meses atrás
pai
commit
588cfe29de

+ 1 - 1
taais-common/taais-common-core/src/main/java/com/taais/common/core/utils/file/MimeTypeUtils.java

@@ -37,7 +37,7 @@ public class MimeTypeUtils {
         // pdf
         "pdf",
         // pt
-        "pt"
+        "pt", "pth"
     };
 
     public static String getExtension(String prefix) {

+ 8 - 0
taais-modules/taais-biz/src/main/java/com/taais/biz/controller/AlgorithmModelController.java

@@ -3,6 +3,8 @@ package com.taais.biz.controller;
 import java.util.List;
 
 import com.taais.biz.service.impl.AlgorithmModelServiceImpl;
+import com.taais.system.domain.vo.SysOssVo;
+import com.taais.system.service.ISysOssService;
 import lombok.RequiredArgsConstructor;
 import jakarta.servlet.http.HttpServletResponse;
 import cn.dev33.satoken.annotation.SaCheckPermission;
@@ -69,6 +71,9 @@ public class AlgorithmModelController extends BaseController {
         return CommonResult.success(algorithmModelService.getModelByAlgorithmId(id));
     }
 
+    @Resource
+    private ISysOssService sysOssService;
+
     /**
      * 新增算法模型配置
      */
@@ -77,6 +82,9 @@ public class AlgorithmModelController extends BaseController {
     @RepeatSubmit()
     @PostMapping
     public CommonResult<Void> add(@Validated @RequestBody AlgorithmModelBo algorithmModelBo) {
+        String modelAddress = algorithmModelBo.getModelAddress();
+        SysOssVo vo = sysOssService.getById(Long.parseLong(modelAddress));
+        algorithmModelBo.setModelAddress("/profile" + vo.getUrl().split("/profile")[1]);
         boolean inserted = algorithmModelService.insert(algorithmModelBo);
         if (!inserted) {
             return CommonResult.fail("新增算法模型配置记录失败!");

+ 51 - 33
taais-modules/taais-biz/src/main/java/com/taais/biz/controller/DataAugmentationController.java

@@ -5,10 +5,8 @@ import java.io.IOException;
 import java.nio.file.Files;
 import java.nio.file.Path;
 import java.nio.file.Paths;
-import java.util.Base64;
-import java.util.HashMap;
-import java.util.List;
-import java.util.Map;
+import java.util.*;
+import java.util.stream.Collectors;
 import java.util.stream.Stream;
 
 import com.mybatisflex.core.query.QueryWrapper;
@@ -65,8 +63,7 @@ public class DataAugmentationController extends BaseController {
 
     private static Path getImageAtPathIdx(Path inputPath, int idx) throws IOException {
         try (Stream<Path> paths = Files.list(inputPath)) {
-            return paths.filter(Files::isRegularFile)
-                .sorted()
+            return paths.sorted()
                 .skip(idx)
                 .findFirst()
                 .orElseThrow(() -> new IllegalArgumentException("Index " + idx + " is out of bounds for the list of files in " + inputPath));
@@ -74,7 +71,7 @@ public class DataAugmentationController extends BaseController {
     }
 
     @GetMapping("/compare/{task_id}/{idx}")
-    public ResponseEntity<Map<String, String>> getCompareImages(@PathVariable("task_id") Long taskId, @PathVariable("idx") int idx) {
+    public ResponseEntity<Map<String,List<String>>> getCompareImages(@PathVariable("task_id") Long taskId, @PathVariable("idx") int idx) {
         try {
             DataAugmentation dataAugmentation = dataAugmentationService.getById(taskId);
 
@@ -82,35 +79,56 @@ public class DataAugmentationController extends BaseController {
             Path outputPath = null;
             String osName = System.getProperty("os.name");
             // 判断是否是Windows环境
-            if (osName.toLowerCase().contains("windows")) {
-                inputPath = Paths.get("C:", dataAugmentation.getInputPath());
-                outputPath = Paths.get("C:", dataAugmentation.getOutputPath());
-            } else {
-                inputPath = Paths.get(dataAugmentation.getInputPath());
-                outputPath = Paths.get(dataAugmentation.getOutputPath());
-            }
-//            inputPath = Paths.get(dataAugmentation.getInputPath());
-//            outputPath = Paths.get(dataAugmentation.getOutputPath());
+//            if (osName.toLowerCase().contains("windows")) {
+//                inputPath = Paths.get("D:", dataAugmentation.getInputPath());
+//                outputPath = Paths.get("D:", dataAugmentation.getOutputPath());
+//            } else {
+//                inputPath = Paths.get(dataAugmentation.getInputPath());
+//                outputPath = Paths.get(dataAugmentation.getOutputPath());
+//            }
+            inputPath = Paths.get(dataAugmentation.getInputPath());
+            outputPath = Paths.get(dataAugmentation.getOutputPath());
             System.out.println("inputPath: " + inputPath.toString());
             System.out.println("outputPath: " + outputPath.toString());
             Path imagePath = getImageAtPathIdx(inputPath, idx);
             String fileName = imagePath.getFileName().toString();
+            Map<String, List<String>> images = new HashMap<>();
+
+            //图像拼接算法有多个输入图片
+            if ("图像拼接".equals(dataAugmentation.getTaskType())) {
+
+                Stream<Path> pathStream = Files.list(imagePath);
+                Optional<Path> firstPath = pathStream.filter(Files::isRegularFile).findFirst();
+                String firstFileName = firstPath.get().getFileName().toString();
+                int lastDotIndex = firstFileName.lastIndexOf('.');
+                fileName = fileName + firstFileName.substring(lastDotIndex);
+                // 收集所有文件的路径,并编码为Base64字符串
+                Stream<Path> newPathStream = Files.list(imagePath);
+                List<String> origin = newPathStream
+                    .filter(Files::isRegularFile)  // 只选择常规文件(排除子目录)
+                    .map(path -> {
+                        try {
+                            return Base64.getEncoder().encodeToString(Files.readAllBytes(path));
+                        } catch (IOException e) {
+                            throw new RuntimeException(e);
+                        }
+                    })
+                    .collect(Collectors.toList());
+                images.put("origin", origin);
+            } else {
+                byte[] image1 = Files.readAllBytes(imagePath);
+                String base64Image1 = Base64.getEncoder().encodeToString(image1);
+                ArrayList<String> origin = new ArrayList<>();
+                origin.add(base64Image1);
+                images.put("origin", origin);
+            }
             Path resolve = outputPath.resolve(fileName);
-
-
-            byte[] image1 = Files.readAllBytes(imagePath);
             byte[] image2 = Files.readAllBytes(resolve);
 
-
-            // 将图片编码成Base64字符串
-            String base64Image1 = Base64.getEncoder().encodeToString(image1);
-            String base64Image2 = Base64.getEncoder().encodeToString(image2);
-
-            // 创建一个Map来存储Base64编码的图片
-            Map<String, String> images = new HashMap<>();
-            images.put("origin", base64Image1);
-            images.put("stable", base64Image2);
-
+            String base64Image2 = Base64.getEncoder().encodeToString(image2);   // 将图片编码成Base64字符串
+            ArrayList<String> stable = new ArrayList<>();
+            stable.add(base64Image2);
+            images.put("stable", stable);
             // 返回Map
             return ResponseEntity.ok(images);
         } catch (Exception e) {
@@ -119,10 +137,10 @@ public class DataAugmentationController extends BaseController {
         }
     }
 
-     @PostMapping("/get_result")
-     public CommonResult getResult(@Valid @RequestBody VideoStableStartResultBo videoStableStartResultBo) {
-         return dataAugmentationService.getResult(videoStableStartResultBo);
-     }
+    @PostMapping("/get_result")
+    public CommonResult getResult(@Valid @RequestBody VideoStableStartResultBo videoStableStartResultBo) {
+        return dataAugmentationService.getResult(videoStableStartResultBo);
+    }
 
     @GetMapping("/start/{id}")
     public CommonResult start(@PathVariable("id") Long id) {

+ 53 - 31
taais-modules/taais-biz/src/main/java/com/taais/biz/service/impl/DataAugmentationServiceImpl.java

@@ -46,12 +46,15 @@ import org.springframework.transaction.annotation.Transactional;
 
 import java.io.File;
 import java.io.IOException;
+import java.nio.file.DirectoryStream;
+import java.nio.file.Files;
 import java.nio.file.Path;
 import java.nio.file.Paths;
 import java.util.Arrays;
 import java.util.Date;
 import java.util.HashMap;
 import java.util.List;
+import java.util.stream.Stream;
 
 import static com.taais.biz.domain.table.DataAugmentationTableDef.DATA_AUGMENTATION;
 /**
@@ -230,31 +233,49 @@ public class DataAugmentationServiceImpl extends BaseServiceImpl<DataAugmentatio
 
         String fileName = StringUtils.substringAfterLast(filePath, "/");
         String fileName_without_suffix = removeFileExtension(fileName);
-
-        Path path = Paths.get(resourcePath);
-        Path inputPath = path.resolveSibling(fileName_without_suffix + "_images");
-        Path outputPath = path.resolveSibling(fileName_without_suffix + "_stable");
-        Path logPath = path.resolveSibling(fileName_without_suffix + "_log");
+        Path path = null;
+        Path inputPath = null;
+        Path outputPath = null;
+        Path logPath = null;
+//        Path path = Paths.get();
+        //String osName = System.getProperty("os.name");
+        //判断是否是Windows环境
+//        if (osName.toLowerCase().contains("windows")) {
+//            path = Paths.get("D:", resourcePath);
+//        } else {
+//            path = Paths.get(resourcePath);
+//
+//        }
+        path = Paths.get(resourcePath);
+        inputPath = path.resolveSibling(fileName_without_suffix + "_input");
+        outputPath = path.resolveSibling(fileName_without_suffix + "_output");
+        logPath = path.resolveSibling(fileName_without_suffix + "_log");
         makeDir(inputPath.toString());
         makeDir(outputPath.toString());
-        makeDir(localPath.toString());
-        //本地测试代码
-//        String inputPath = "D:\\program\\taais\\duijie\\code-niguang\\" + fileName_without_suffix + "_images";
-//        String outputPath = "D:\\program\\taais\\duijie\\code-niguang\\" + fileName_without_suffix + "_output";
-//        String logPath = "D:\\program\\taais\\duijie\\code-niguang\\" + fileName_without_suffix + "_log";
-//        makeDir(inputPath.toString());
-//        makeDir(outputPath.toString());
-//        makeDir(logPath.toString());
-        int lastUnderscoreIndex = fileName_without_suffix.lastIndexOf('_');
+        makeDir(logPath.toString());
+
         ZipUtils.unzip(resourcePath, inputPath.toString());
-//        String inputPathString = "D:\\program\\taais\\duijie\\code-niguang\\" + fileName_without_suffix + "_images\\" + fileName_without_suffix.substring(0, lastUnderscoreIndex);
-        String inputPathString = inputPath.toString() + fileName_without_suffix.substring(0, lastUnderscoreIndex);
-        dataAugmentation.setLog(logPath.toString());
-        dataAugmentation.setInputPath(inputPathString);
+//        //获取inputPath的下一级目录
+//        if (Files.exists(inputPath) && Files.isDirectory(inputPath)) {
+//            try (DirectoryStream<Path> stream = Files.newDirectoryStream(inputPath)) {
+//                // 遍历目录流以找到第一个子目录
+//                for (Path entry : stream) {
+//                    if (Files.isDirectory(entry)) {
+//                        // 打印第一个子目录的名称,并跳出循环
+//                        dataAugmentation.setInputPath(entry.toString());
+//                        break;  // 只处理第一个子目录后退出
+//                    }
+//                }
+//            } catch (IOException e) {
+//                return CommonResult.fail(e.toString());
+//            }
+//        } else {
+//            return CommonResult.fail("The provided path is not a valid directory.");
+//        }
+        dataAugmentation.setInputPath(inputPath.toString());
+        dataAugmentation.setAlgorithmPath(logPath.toString());
         dataAugmentation.setOutputPath(outputPath.toString());
-
         dataAugmentation.setStartTime(new Date());
-
         dataAugmentation.setStatus(BizConstant.VideoStatus.RUNNING);
         updateById(dataAugmentation);
         SysDictDataBo sysDictDataBo = new SysDictDataBo();
@@ -269,7 +290,7 @@ public class DataAugmentationServiceImpl extends BaseServiceImpl<DataAugmentatio
         TransmissionObject transmissionObject = new TransmissionObject();
         transmissionObject.setBizId(dataAugmentation.getId());
         transmissionObject.setBizType(dataAugmentation.getTaskType());
-        transmissionObject.setLogPath(dataAugmentation.getLog());
+        transmissionObject.setLogPath(dataAugmentation.getAlgorithmPath());
         transmissionObject.setSourcePath(dataAugmentation.getInputPath());
         transmissionObject.setResultPath(dataAugmentation.getOutputPath());
         transmissionObject.setOtherParams(dataAugmentation.getHyperparameterConfiguration());
@@ -335,17 +356,18 @@ public class DataAugmentationServiceImpl extends BaseServiceImpl<DataAugmentatio
 
         Path inputPath = null;
         Path outputPath = null;
-        String osName = System.getProperty("os.name");
+//        String osName = System.getProperty("os.name");
         // 判断是否是Windows环境
-        if (osName.toLowerCase().contains("windows")) {
-            inputPath = Paths.get("C:", dataAugmentation.getInputPath());
-            outputPath = Paths.get("C:", dataAugmentation.getOutputPath());
-        } else {
-            inputPath = Paths.get(dataAugmentation.getInputPath());
-            outputPath = Paths.get(dataAugmentation.getOutputPath());
-        }
-//        inputPath = Paths.get(dataAugmentation.getInputPath());
-//        outputPath = Paths.get(dataAugmentation.getOutputPath());
+//        if (osName.toLowerCase().contains("windows")) {
+//            inputPath = Paths.get("D:", dataAugmentation.getInputPath());
+//            outputPath = Paths.get("D:", dataAugmentation.getOutputPath());
+//        } else {
+//            inputPath = Paths.get(dataAugmentation.getInputPath());
+//            outputPath = Paths.get(dataAugmentation.getOutputPath());
+//        }
+
+        inputPath = Paths.get(dataAugmentation.getInputPath());
+        outputPath = Paths.get(dataAugmentation.getOutputPath());
         // 创建File对象
         File in_directory = new File(inputPath.toString());
         File out_directory = new File(outputPath.toString());

+ 1 - 1
taais-modules/taais-biz/src/main/java/com/taais/biz/service/impl/DataServiceImpl.java

@@ -73,7 +73,7 @@ public class DataServiceImpl extends BaseServiceImpl<DataMapper, Data> implement
     private static final String ZIP = ".zip";
     private static final String RAR = ".rar";
     private static final String TXT = ".txt";
-    private static final String[] VALID_EXTENSIONS = {".jpg", ".jpeg", ".png"};
+    private static final String[] VALID_EXTENSIONS = {".jpg", ".jpeg", ".png", ".bmp"};
     private static final String PYTHON_DATA_AMPLIFY_API = "http://127.0.0.1:11001/augment";
     private static final String RESULT_CODE = "status";
     private static final String RESULT_STATUS = "200";

+ 7 - 2
taais-modules/taais-biz/src/main/java/com/taais/biz/service/impl/TargetIdentificationSubtaskServiceImpl.java

@@ -228,7 +228,12 @@ public class TargetIdentificationSubtaskServiceImpl extends BaseServiceImpl<Targ
                     hasModelProperty = true;
                     _modelId = object.getString("modelId");
                 }
-                algorithmRequestDto.getOtherParams().put(object.getString("agName"), object.getString("defaultValue"));
+                String value = object.getString("defaultValue");
+                if(NumberUtils.isCreatable(value)){
+                    algorithmRequestDto.getOtherParams().put(object.getString("agName"), NumberUtils.createNumber(value));
+                } else {
+                    algorithmRequestDto.getOtherParams().put(object.getString("agName"), value);
+                }
             }
         } catch (Exception e) {
             log.error(e.getMessage());
@@ -256,7 +261,7 @@ public class TargetIdentificationSubtaskServiceImpl extends BaseServiceImpl<Targ
             return;
         }
         // send http
-        System.out.println("todo request: " + algorithmRequestDto.toString());
+        log.info("todo request: " + algorithmRequestDto);
 
         try {
             String res = HttpUtil.post(url, JSONUtil.toJsonStr(algorithmRequestDto));