|
@@ -5,10 +5,8 @@ import java.io.IOException;
|
|
import java.nio.file.Files;
|
|
import java.nio.file.Files;
|
|
import java.nio.file.Path;
|
|
import java.nio.file.Path;
|
|
import java.nio.file.Paths;
|
|
import java.nio.file.Paths;
|
|
-import java.util.Base64;
|
|
|
|
-import java.util.HashMap;
|
|
|
|
-import java.util.List;
|
|
|
|
-import java.util.Map;
|
|
|
|
|
|
+import java.util.*;
|
|
|
|
+import java.util.stream.Collectors;
|
|
import java.util.stream.Stream;
|
|
import java.util.stream.Stream;
|
|
|
|
|
|
import com.mybatisflex.core.query.QueryWrapper;
|
|
import com.mybatisflex.core.query.QueryWrapper;
|
|
@@ -65,8 +63,7 @@ public class DataAugmentationController extends BaseController {
|
|
|
|
|
|
private static Path getImageAtPathIdx(Path inputPath, int idx) throws IOException {
|
|
private static Path getImageAtPathIdx(Path inputPath, int idx) throws IOException {
|
|
try (Stream<Path> paths = Files.list(inputPath)) {
|
|
try (Stream<Path> paths = Files.list(inputPath)) {
|
|
- return paths.filter(Files::isRegularFile)
|
|
|
|
- .sorted()
|
|
|
|
|
|
+ return paths.sorted()
|
|
.skip(idx)
|
|
.skip(idx)
|
|
.findFirst()
|
|
.findFirst()
|
|
.orElseThrow(() -> new IllegalArgumentException("Index " + idx + " is out of bounds for the list of files in " + inputPath));
|
|
.orElseThrow(() -> new IllegalArgumentException("Index " + idx + " is out of bounds for the list of files in " + inputPath));
|
|
@@ -74,7 +71,7 @@ public class DataAugmentationController extends BaseController {
|
|
}
|
|
}
|
|
|
|
|
|
@GetMapping("/compare/{task_id}/{idx}")
|
|
@GetMapping("/compare/{task_id}/{idx}")
|
|
- public ResponseEntity<Map<String, String>> getCompareImages(@PathVariable("task_id") Long taskId, @PathVariable("idx") int idx) {
|
|
|
|
|
|
+ public ResponseEntity<Map<String,List<String>>> getCompareImages(@PathVariable("task_id") Long taskId, @PathVariable("idx") int idx) {
|
|
try {
|
|
try {
|
|
DataAugmentation dataAugmentation = dataAugmentationService.getById(taskId);
|
|
DataAugmentation dataAugmentation = dataAugmentationService.getById(taskId);
|
|
|
|
|
|
@@ -95,22 +92,43 @@ public class DataAugmentationController extends BaseController {
|
|
System.out.println("outputPath: " + outputPath.toString());
|
|
System.out.println("outputPath: " + outputPath.toString());
|
|
Path imagePath = getImageAtPathIdx(inputPath, idx);
|
|
Path imagePath = getImageAtPathIdx(inputPath, idx);
|
|
String fileName = imagePath.getFileName().toString();
|
|
String fileName = imagePath.getFileName().toString();
|
|
|
|
+ Map<String, List<String>> images = new HashMap<>();
|
|
|
|
+
|
|
|
|
+ //图像拼接算法有多个输入图片
|
|
|
|
+ if ("图像拼接".equals(dataAugmentation.getTaskType())) {
|
|
|
|
+
|
|
|
|
+ Stream<Path> pathStream = Files.list(imagePath);
|
|
|
|
+ Optional<Path> firstPath = pathStream.filter(Files::isRegularFile).findFirst();
|
|
|
|
+ String firstFileName = firstPath.get().getFileName().toString();
|
|
|
|
+ int lastDotIndex = firstFileName.lastIndexOf('.');
|
|
|
|
+ fileName = fileName + firstFileName.substring(lastDotIndex);
|
|
|
|
+ // 收集所有文件的路径,并编码为Base64字符串
|
|
|
|
+ Stream<Path> newPathStream = Files.list(imagePath);
|
|
|
|
+ List<String> origin = newPathStream
|
|
|
|
+ .filter(Files::isRegularFile) // 只选择常规文件(排除子目录)
|
|
|
|
+ .map(path -> {
|
|
|
|
+ try {
|
|
|
|
+ return Base64.getEncoder().encodeToString(Files.readAllBytes(path));
|
|
|
|
+ } catch (IOException e) {
|
|
|
|
+ throw new RuntimeException(e);
|
|
|
|
+ }
|
|
|
|
+ })
|
|
|
|
+ .collect(Collectors.toList());
|
|
|
|
+ images.put("origin", origin);
|
|
|
|
+ } else {
|
|
|
|
+ byte[] image1 = Files.readAllBytes(imagePath);
|
|
|
|
+ String base64Image1 = Base64.getEncoder().encodeToString(image1);
|
|
|
|
+ ArrayList<String> origin = new ArrayList<>();
|
|
|
|
+ origin.add(base64Image1);
|
|
|
|
+ images.put("origin", origin);
|
|
|
|
+ }
|
|
Path resolve = outputPath.resolve(fileName);
|
|
Path resolve = outputPath.resolve(fileName);
|
|
-
|
|
|
|
-
|
|
|
|
- byte[] image1 = Files.readAllBytes(imagePath);
|
|
|
|
byte[] image2 = Files.readAllBytes(resolve);
|
|
byte[] image2 = Files.readAllBytes(resolve);
|
|
|
|
|
|
-
|
|
|
|
- // 将图片编码成Base64字符串
|
|
|
|
- String base64Image1 = Base64.getEncoder().encodeToString(image1);
|
|
|
|
- String base64Image2 = Base64.getEncoder().encodeToString(image2);
|
|
|
|
-
|
|
|
|
- // 创建一个Map来存储Base64编码的图片
|
|
|
|
- Map<String, String> images = new HashMap<>();
|
|
|
|
- images.put("origin", base64Image1);
|
|
|
|
- images.put("stable", base64Image2);
|
|
|
|
-
|
|
|
|
|
|
+ String base64Image2 = Base64.getEncoder().encodeToString(image2); // 将图片编码成Base64字符串
|
|
|
|
+ ArrayList<String> stable = new ArrayList<>();
|
|
|
|
+ stable.add(base64Image2);
|
|
|
|
+ images.put("stable", stable);
|
|
// 返回Map
|
|
// 返回Map
|
|
return ResponseEntity.ok(images);
|
|
return ResponseEntity.ok(images);
|
|
} catch (Exception e) {
|
|
} catch (Exception e) {
|