|
@@ -22,6 +22,7 @@ import com.taais.biz.domain.bo.DataAugmentationResultBo;
|
|
|
import com.taais.biz.domain.bo.VideoStableStartResultBo;
|
|
|
import com.taais.biz.domain.dto.Metric;
|
|
|
import com.taais.biz.domain.vo.DataAugmentationVo;
|
|
|
+import com.taais.biz.domain.vo.ImageUrlPair;
|
|
|
import com.taais.biz.domain.vo.TaskDictDataVo;
|
|
|
import com.taais.biz.domain.vo.VideoUrl;
|
|
|
import com.taais.biz.service.IVideoStableService;
|
|
@@ -102,6 +103,48 @@ public class DataAugmentationController extends BaseController {
|
|
|
// 如果文件名不是以数字开头,则返回一个很大的数,使其排在最后
|
|
|
return -1;
|
|
|
}
|
|
|
+
|
|
|
+ public static List<ImageUrlPair> generateFilePairs(Path inputPath, Path outputPath) throws IOException {
|
|
|
+ String localPath = TaaisConfig.getProfile();
|
|
|
+ Path localPathPath = Paths.get(localPath);
|
|
|
+
|
|
|
+
|
|
|
+ // 替换路径
|
|
|
+ return Files.list(inputPath)
|
|
|
+ .filter(Files::isRegularFile) // 只处理文件
|
|
|
+ .map(path -> {
|
|
|
+ String fileName = path.getFileName().toString();
|
|
|
+ ImageUrlPair imageUrlPair = new ImageUrlPair();
|
|
|
+ imageUrlPair.setInputUrl(inputPath.resolve(fileName).toString().replace(localPathPath.toString(), "/profile").replace("\\", "/"));
|
|
|
+ imageUrlPair.setOutputUrl(outputPath.resolve(fileName).toString().replace(localPathPath.toString(), "/profile").replace("\\", "/"));
|
|
|
+ return imageUrlPair;
|
|
|
+ })
|
|
|
+ .collect(Collectors.toList());
|
|
|
+ }
|
|
|
+ //新接口获取图片对序列,返回支持前端PreviewCompareImages组件的数据格式
|
|
|
+ @GetMapping("/imageCompare/{task_id}")
|
|
|
+ public CommonResult<List<ImageUrlPair>> getCompareImageSeq(@PathVariable("task_id") Long taskId) throws IOException {
|
|
|
+ DataAugmentation dataAugmentation = dataAugmentationService.getById(taskId);
|
|
|
+ Path inputPath = Paths.get(dataAugmentation.getInputPath());
|
|
|
+ Path outputPath = Paths.get(dataAugmentation.getOutputPath());
|
|
|
+
|
|
|
+ if ("多目标跟踪".equals(dataAugmentation.getTaskType())) {
|
|
|
+ inputPath = inputPath.resolve("input");
|
|
|
+ outputPath = outputPath.resolve("output");
|
|
|
+ }
|
|
|
+ List<ImageUrlPair> list;
|
|
|
+ if (!Files.exists(inputPath) || !Files.isDirectory(inputPath)) {
|
|
|
+ System.out.println("输入路径不存在或不是目录:" + inputPath.toString());
|
|
|
+ return CommonResult.fail("输入路径不存在或不是目录:" + inputPath.toString());
|
|
|
+ }
|
|
|
+ if (!Files.exists(outputPath) || !Files.isDirectory(outputPath)) {
|
|
|
+ System.out.println("输出路径不存在或不是目录:" + outputPath.toString());
|
|
|
+ return CommonResult.fail("输出路径不存在或不是目录:" + outputPath.toString());
|
|
|
+ }
|
|
|
+ list = generateFilePairs(inputPath, outputPath);
|
|
|
+ return CommonResult.success(list);
|
|
|
+
|
|
|
+ }
|
|
|
@GetMapping("/compare/{task_id}")
|
|
|
public ResponseEntity<Map<String,List<List<String>>>> getCompareImages(@PathVariable("task_id") Long taskId) {
|
|
|
try {
|
|
@@ -162,19 +205,20 @@ public class DataAugmentationController extends BaseController {
|
|
|
.filter(Files::isRegularFile) // 只选择常规文件(排除子目录、排除图像拼接算法_sift输入中的txt文件)
|
|
|
.filter(path -> !path.toString().toLowerCase().endsWith(".txt")).map(path -> {
|
|
|
return path.getFileName().toString();
|
|
|
- }).collect(Collectors.toList());
|
|
|
+ }).sorted().collect(Collectors.toList());
|
|
|
if ("目标毁伤评估".equals(taskType)) {
|
|
|
QueryWrapper queryWrapper = new QueryWrapper();
|
|
|
queryWrapper.eq("dict_type", "biz_data_augmentation");
|
|
|
queryWrapper.eq("dict_label", "目标毁伤评估必须输出文件名称");
|
|
|
List<SysDictData> list = sysDictDataService.list(queryWrapper);
|
|
|
+ String[] split = list.get(0).getDictValue().split(",");
|
|
|
if (list.isEmpty()) {
|
|
|
origin.add("未在数据字典中定义目标毁伤评估必须输出文件名称");
|
|
|
origin_list.add(origin);
|
|
|
images.put("error", origin_list);
|
|
|
return ResponseEntity.status(500).body(images);
|
|
|
} else {
|
|
|
- String[] split = list.get(0).getDictValue().split(",");
|
|
|
+
|
|
|
for (String fileName: split) {
|
|
|
if (!outputFileList.contains(fileName)) {
|
|
|
stable.add("输出目录不存在结果文件:" + fileName);
|
|
@@ -185,6 +229,25 @@ public class DataAugmentationController extends BaseController {
|
|
|
}
|
|
|
}
|
|
|
outputFileList.sort(Comparator.comparingInt(DataAugmentationController::extractNumber));
|
|
|
+ Map<String, Integer> indexMap = new HashMap<>();
|
|
|
+ for (int i = 0; i < split.length; i++) {
|
|
|
+ indexMap.put(split[i], i);
|
|
|
+ }
|
|
|
+ // 自定义比较器,基于indexMap中存储的位置信息
|
|
|
+ Comparator<String> customComparator = (o1, o2) -> {
|
|
|
+ Integer index1 = indexMap.get(o1);
|
|
|
+ Integer index2 = indexMap.get(o2);
|
|
|
+ // 如果o1或o2不在s数组中,则保持它们在原始列表中的相对顺序
|
|
|
+ if (index1 == null && index2 == null) {
|
|
|
+ return 0;
|
|
|
+ } else if (index1 == null) {
|
|
|
+ return 1;
|
|
|
+ } else if (index2 == null) {
|
|
|
+ return -1;
|
|
|
+ }
|
|
|
+ return index1.compareTo(index2);
|
|
|
+ };
|
|
|
+ Collections.sort(outputFileList, customComparator);
|
|
|
}
|
|
|
} else {
|
|
|
System.out.println(taskType + "任务:" + dataAugmentation.getId() + "未创建结果目录!");
|
|
@@ -200,7 +263,7 @@ public class DataAugmentationController extends BaseController {
|
|
|
.filter(path -> !path.toString().toLowerCase().endsWith(".json"))
|
|
|
.map(path -> {
|
|
|
return path.getFileName().toString();
|
|
|
- }).collect(Collectors.toList());
|
|
|
+ }).sorted().collect(Collectors.toList());
|
|
|
} else {
|
|
|
//算法输入是一对一时,输出文件名称和输入文件名称相同
|
|
|
inputPath1 = inputPath;
|
|
@@ -312,9 +375,6 @@ public class DataAugmentationController extends BaseController {
|
|
|
public void export(HttpServletResponse response, @RequestBody String zipDirPath) {
|
|
|
|
|
|
try {
|
|
|
- // 假设你已经知道ZIP文件的路径
|
|
|
-// String zipFilePath = dataAugmentationVo.getOutputPath() + ".zip";
|
|
|
-// System.out.println(dataAugmentationVo.toString());
|
|
|
|
|
|
String zipFilePath = zipDirPath.substring(1, zipDirPath.length() - 1) + ".zip";
|
|
|
System.out.println( zipFilePath);
|