|
@@ -5,10 +5,8 @@ import java.io.IOException;
|
|
|
import java.nio.file.Files;
|
|
|
import java.nio.file.Path;
|
|
|
import java.nio.file.Paths;
|
|
|
-import java.util.Base64;
|
|
|
-import java.util.HashMap;
|
|
|
-import java.util.List;
|
|
|
-import java.util.Map;
|
|
|
+import java.util.*;
|
|
|
+import java.util.stream.Collectors;
|
|
|
import java.util.stream.Stream;
|
|
|
|
|
|
import com.mybatisflex.core.query.QueryWrapper;
|
|
@@ -65,8 +63,7 @@ public class DataAugmentationController extends BaseController {
|
|
|
|
|
|
private static Path getImageAtPathIdx(Path inputPath, int idx) throws IOException {
|
|
|
try (Stream<Path> paths = Files.list(inputPath)) {
|
|
|
- return paths.filter(Files::isRegularFile)
|
|
|
- .sorted()
|
|
|
+ return paths.sorted()
|
|
|
.skip(idx)
|
|
|
.findFirst()
|
|
|
.orElseThrow(() -> new IllegalArgumentException("Index " + idx + " is out of bounds for the list of files in " + inputPath));
|
|
@@ -74,7 +71,7 @@ public class DataAugmentationController extends BaseController {
|
|
|
}
|
|
|
|
|
|
@GetMapping("/compare/{task_id}/{idx}")
|
|
|
- public ResponseEntity<Map<String, String>> getCompareImages(@PathVariable("task_id") Long taskId, @PathVariable("idx") int idx) {
|
|
|
+ public ResponseEntity<Map<String,List<String>>> getCompareImages(@PathVariable("task_id") Long taskId, @PathVariable("idx") int idx) {
|
|
|
try {
|
|
|
DataAugmentation dataAugmentation = dataAugmentationService.getById(taskId);
|
|
|
|
|
@@ -82,35 +79,56 @@ public class DataAugmentationController extends BaseController {
|
|
|
Path outputPath = null;
|
|
|
String osName = System.getProperty("os.name");
|
|
|
// 判断是否是Windows环境
|
|
|
- if (osName.toLowerCase().contains("windows")) {
|
|
|
- inputPath = Paths.get("C:", dataAugmentation.getInputPath());
|
|
|
- outputPath = Paths.get("C:", dataAugmentation.getOutputPath());
|
|
|
- } else {
|
|
|
- inputPath = Paths.get(dataAugmentation.getInputPath());
|
|
|
- outputPath = Paths.get(dataAugmentation.getOutputPath());
|
|
|
- }
|
|
|
-// inputPath = Paths.get(dataAugmentation.getInputPath());
|
|
|
-// outputPath = Paths.get(dataAugmentation.getOutputPath());
|
|
|
+// if (osName.toLowerCase().contains("windows")) {
|
|
|
+// inputPath = Paths.get("D:", dataAugmentation.getInputPath());
|
|
|
+// outputPath = Paths.get("D:", dataAugmentation.getOutputPath());
|
|
|
+// } else {
|
|
|
+// inputPath = Paths.get(dataAugmentation.getInputPath());
|
|
|
+// outputPath = Paths.get(dataAugmentation.getOutputPath());
|
|
|
+// }
|
|
|
+ inputPath = Paths.get(dataAugmentation.getInputPath());
|
|
|
+ outputPath = Paths.get(dataAugmentation.getOutputPath());
|
|
|
System.out.println("inputPath: " + inputPath.toString());
|
|
|
System.out.println("outputPath: " + outputPath.toString());
|
|
|
Path imagePath = getImageAtPathIdx(inputPath, idx);
|
|
|
String fileName = imagePath.getFileName().toString();
|
|
|
+ Map<String, List<String>> images = new HashMap<>();
|
|
|
+
|
|
|
+ //图像拼接算法有多个输入图片
|
|
|
+ if ("图像拼接".equals(dataAugmentation.getTaskType())) {
|
|
|
+
|
|
|
+ Stream<Path> pathStream = Files.list(imagePath);
|
|
|
+ Optional<Path> firstPath = pathStream.filter(Files::isRegularFile).findFirst();
|
|
|
+ String firstFileName = firstPath.get().getFileName().toString();
|
|
|
+ int lastDotIndex = firstFileName.lastIndexOf('.');
|
|
|
+ fileName = fileName + firstFileName.substring(lastDotIndex);
|
|
|
+ // 收集所有文件的路径,并编码为Base64字符串
|
|
|
+ Stream<Path> newPathStream = Files.list(imagePath);
|
|
|
+ List<String> origin = newPathStream
|
|
|
+ .filter(Files::isRegularFile) // 只选择常规文件(排除子目录)
|
|
|
+ .map(path -> {
|
|
|
+ try {
|
|
|
+ return Base64.getEncoder().encodeToString(Files.readAllBytes(path));
|
|
|
+ } catch (IOException e) {
|
|
|
+ throw new RuntimeException(e);
|
|
|
+ }
|
|
|
+ })
|
|
|
+ .collect(Collectors.toList());
|
|
|
+ images.put("origin", origin);
|
|
|
+ } else {
|
|
|
+ byte[] image1 = Files.readAllBytes(imagePath);
|
|
|
+ String base64Image1 = Base64.getEncoder().encodeToString(image1);
|
|
|
+ ArrayList<String> origin = new ArrayList<>();
|
|
|
+ origin.add(base64Image1);
|
|
|
+ images.put("origin", origin);
|
|
|
+ }
|
|
|
Path resolve = outputPath.resolve(fileName);
|
|
|
-
|
|
|
-
|
|
|
- byte[] image1 = Files.readAllBytes(imagePath);
|
|
|
byte[] image2 = Files.readAllBytes(resolve);
|
|
|
|
|
|
-
|
|
|
- // 将图片编码成Base64字符串
|
|
|
- String base64Image1 = Base64.getEncoder().encodeToString(image1);
|
|
|
- String base64Image2 = Base64.getEncoder().encodeToString(image2);
|
|
|
-
|
|
|
- // 创建一个Map来存储Base64编码的图片
|
|
|
- Map<String, String> images = new HashMap<>();
|
|
|
- images.put("origin", base64Image1);
|
|
|
- images.put("stable", base64Image2);
|
|
|
-
|
|
|
+ String base64Image2 = Base64.getEncoder().encodeToString(image2); // 将图片编码成Base64字符串
|
|
|
+ ArrayList<String> stable = new ArrayList<>();
|
|
|
+ stable.add(base64Image2);
|
|
|
+ images.put("stable", stable);
|
|
|
// 返回Map
|
|
|
return ResponseEntity.ok(images);
|
|
|
} catch (Exception e) {
|
|
@@ -119,10 +137,10 @@ public class DataAugmentationController extends BaseController {
|
|
|
}
|
|
|
}
|
|
|
|
|
|
- @PostMapping("/get_result")
|
|
|
- public CommonResult getResult(@Valid @RequestBody VideoStableStartResultBo videoStableStartResultBo) {
|
|
|
- return dataAugmentationService.getResult(videoStableStartResultBo);
|
|
|
- }
|
|
|
+ @PostMapping("/get_result")
|
|
|
+ public CommonResult getResult(@Valid @RequestBody VideoStableStartResultBo videoStableStartResultBo) {
|
|
|
+ return dataAugmentationService.getResult(videoStableStartResultBo);
|
|
|
+ }
|
|
|
|
|
|
@GetMapping("/start/{id}")
|
|
|
public CommonResult start(@PathVariable("id") Long id) {
|