|
@@ -16,15 +16,15 @@ import java.util.stream.Stream;
|
|
import com.alibaba.fastjson2.JSONArray;
|
|
import com.alibaba.fastjson2.JSONArray;
|
|
import com.alibaba.fastjson2.JSONObject;
|
|
import com.alibaba.fastjson2.JSONObject;
|
|
import com.mybatisflex.core.query.QueryWrapper;
|
|
import com.mybatisflex.core.query.QueryWrapper;
|
|
|
|
+import com.taais.biz.domain.CommonAlgorithmConfig;
|
|
import com.taais.biz.domain.DataAugmentation;
|
|
import com.taais.biz.domain.DataAugmentation;
|
|
|
|
+import com.taais.biz.domain.bo.CommonAlgorithmConfigBo;
|
|
import com.taais.biz.domain.bo.DataAugmentationBo;
|
|
import com.taais.biz.domain.bo.DataAugmentationBo;
|
|
import com.taais.biz.domain.bo.DataAugmentationResultBo;
|
|
import com.taais.biz.domain.bo.DataAugmentationResultBo;
|
|
import com.taais.biz.domain.bo.VideoStableStartResultBo;
|
|
import com.taais.biz.domain.bo.VideoStableStartResultBo;
|
|
import com.taais.biz.domain.dto.Metric;
|
|
import com.taais.biz.domain.dto.Metric;
|
|
-import com.taais.biz.domain.vo.DataAugmentationVo;
|
|
|
|
-import com.taais.biz.domain.vo.ImageUrlPair;
|
|
|
|
-import com.taais.biz.domain.vo.TaskDictDataVo;
|
|
|
|
-import com.taais.biz.domain.vo.VideoUrl;
|
|
|
|
|
|
+import com.taais.biz.domain.vo.*;
|
|
|
|
+import com.taais.biz.service.ICommonAlgorithmConfigService;
|
|
import com.taais.biz.service.IVideoStableService;
|
|
import com.taais.biz.service.IVideoStableService;
|
|
import com.taais.common.core.config.TaaisConfig;
|
|
import com.taais.common.core.config.TaaisConfig;
|
|
import com.taais.common.core.service.OssService;
|
|
import com.taais.common.core.service.OssService;
|
|
@@ -77,6 +77,9 @@ public class DataAugmentationController extends BaseController {
|
|
private ISysDictDataService sysDictDataService;
|
|
private ISysDictDataService sysDictDataService;
|
|
@Autowired
|
|
@Autowired
|
|
private IVideoStableService videoStableService;
|
|
private IVideoStableService videoStableService;
|
|
|
|
+
|
|
|
|
+ @Autowired
|
|
|
|
+ private ICommonAlgorithmConfigService commonAlgorithmConfigService;
|
|
@GetMapping("/compare/num/{task_id}")
|
|
@GetMapping("/compare/num/{task_id}")
|
|
|
|
|
|
public CommonResult getCompareNum(@PathVariable("task_id") Long taskId) {
|
|
public CommonResult getCompareNum(@PathVariable("task_id") Long taskId) {
|
|
@@ -207,27 +210,27 @@ public class DataAugmentationController extends BaseController {
|
|
return path.getFileName().toString();
|
|
return path.getFileName().toString();
|
|
}).sorted().collect(Collectors.toList());
|
|
}).sorted().collect(Collectors.toList());
|
|
if ("目标毁伤评估".equals(taskType)) {
|
|
if ("目标毁伤评估".equals(taskType)) {
|
|
- QueryWrapper queryWrapper = new QueryWrapper();
|
|
|
|
- queryWrapper.eq("dict_type", "biz_data_augmentation");
|
|
|
|
- queryWrapper.eq("dict_label", "目标毁伤评估必须输出文件名称");
|
|
|
|
- List<SysDictData> list = sysDictDataService.list(queryWrapper);
|
|
|
|
- String[] split = list.get(0).getDictValue().split(",");
|
|
|
|
- if (list.isEmpty()) {
|
|
|
|
- origin.add("未在数据字典中定义目标毁伤评估必须输出文件名称");
|
|
|
|
- origin_list.add(origin);
|
|
|
|
- images.put("error", origin_list);
|
|
|
|
- return ResponseEntity.status(500).body(images);
|
|
|
|
- } else {
|
|
|
|
-
|
|
|
|
- for (String fileName: split) {
|
|
|
|
- if (!outputFileList.contains(fileName)) {
|
|
|
|
- stable.add("输出目录不存在结果文件:" + fileName);
|
|
|
|
- stable_list.add(stable);
|
|
|
|
- images.put("error", stable_list);
|
|
|
|
- return ResponseEntity.status(500).body(images);
|
|
|
|
- }
|
|
|
|
|
|
+// QueryWrapper queryWrapper = new QueryWrapper();
|
|
|
|
+// queryWrapper.eq("dict_type", "biz_data_augmentation");
|
|
|
|
+// queryWrapper.eq("dict_label", "目标毁伤评估必须输出文件名称");
|
|
|
|
+// List<SysDictData> list = sysDictDataService.list(queryWrapper);
|
|
|
|
+ String[] split = new String[]{"diffIm.jpg","largest_salient_region.jpg","final_region.jpg"};
|
|
|
|
+// if (list.isEmpty()) {
|
|
|
|
+// origin.add("未在数据字典中定义目标毁伤评估必须输出文件名称");
|
|
|
|
+// origin_list.add(origin);
|
|
|
|
+// images.put("error", origin_list);
|
|
|
|
+// return ResponseEntity.status(500).body(images);
|
|
|
|
+// } else {
|
|
|
|
+
|
|
|
|
+ for (String fileName: split) {
|
|
|
|
+ if (!outputFileList.contains(fileName)) {
|
|
|
|
+ stable.add("输出目录不存在结果文件:" + fileName);
|
|
|
|
+ stable_list.add(stable);
|
|
|
|
+ images.put("error", stable_list);
|
|
|
|
+ return ResponseEntity.status(500).body(images);
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
+
|
|
outputFileList.sort(Comparator.comparingInt(DataAugmentationController::extractNumber));
|
|
outputFileList.sort(Comparator.comparingInt(DataAugmentationController::extractNumber));
|
|
Map<String, Integer> indexMap = new HashMap<>();
|
|
Map<String, Integer> indexMap = new HashMap<>();
|
|
for (int i = 0; i < split.length; i++) {
|
|
for (int i = 0; i < split.length; i++) {
|
|
@@ -523,28 +526,49 @@ public class DataAugmentationController extends BaseController {
|
|
}
|
|
}
|
|
|
|
|
|
@SaCheckPermission("demo:dataAugmentation:query")
|
|
@SaCheckPermission("demo:dataAugmentation:query")
|
|
- @GetMapping(value = "/getTaskDictData")
|
|
|
|
- public CommonResult<List<TaskDictDataVo>> getTaskDictData() {
|
|
|
|
|
|
+ @PostMapping(value = "/getTaskDictData")
|
|
|
|
+ public CommonResult<List<TaskDictDataVo>> getTaskDictData(@RequestBody String module) {
|
|
|
|
+ module = module.substring(1, module.length() - 1); //去除字符串两把的引号
|
|
ArrayList<TaskDictDataVo> taskDictDataVos = new ArrayList<>();
|
|
ArrayList<TaskDictDataVo> taskDictDataVos = new ArrayList<>();
|
|
|
|
+ //之前从数据字典中获取算法任务及其对应超参配置
|
|
|
|
+// QueryWrapper queryWrapper = new QueryWrapper();
|
|
|
|
+// queryWrapper.eq("dict_type", "biz_data_augmentation");
|
|
|
|
+// queryWrapper.eq("dict_label", "任务类型");
|
|
|
|
+// List<SysDictData> taskTypeList = sysDictDataService.list(queryWrapper);
|
|
|
|
+// for (SysDictData sysDictData: taskTypeList) {
|
|
|
|
+// String dictValue = sysDictData.getDictValue();
|
|
|
|
+// queryWrapper.clear();
|
|
|
|
+// queryWrapper.eq("dict_type", "biz_data_augmentation");
|
|
|
|
+// queryWrapper.eq("dict_label", dictValue + "超参配置");
|
|
|
|
+// List<SysDictData> hyparamList = sysDictDataService.list(queryWrapper);
|
|
|
|
+// if (hyparamList.isEmpty()) {
|
|
|
|
+// return CommonResult.fail("未在type为biz_data_augmentation的字典数据中设置" + dictValue + "的超参配置");
|
|
|
|
+// }
|
|
|
|
+// TaskDictDataVo taskDictDataVo = new TaskDictDataVo();
|
|
|
|
+// taskDictDataVo.setTaskType(dictValue);
|
|
|
|
+// taskDictDataVo.setHyperparameterConfiguration(hyparamList.get(0).getDictValue());
|
|
|
|
+// taskDictDataVos.add(taskDictDataVo);
|
|
|
|
+// }
|
|
|
|
+ //现在从新建的通用算法配置表中获取
|
|
|
|
+ //List<CommonAlgorithmConfig> list = commonAlgorithmConfigService.selectAll();
|
|
|
|
+// System.out.println("model" + module);
|
|
QueryWrapper queryWrapper = new QueryWrapper();
|
|
QueryWrapper queryWrapper = new QueryWrapper();
|
|
- queryWrapper.eq("dict_type", "biz_data_augmentation");
|
|
|
|
- queryWrapper.eq("dict_label", "任务类型");
|
|
|
|
- List<SysDictData> taskTypeList = sysDictDataService.list(queryWrapper);
|
|
|
|
- for (SysDictData sysDictData: taskTypeList) {
|
|
|
|
- String dictValue = sysDictData.getDictValue();
|
|
|
|
- queryWrapper.clear();
|
|
|
|
- queryWrapper.eq("dict_type", "biz_data_augmentation");
|
|
|
|
- queryWrapper.eq("dict_label", dictValue + "超参配置");
|
|
|
|
- List<SysDictData> hyparamList = sysDictDataService.list(queryWrapper);
|
|
|
|
- if (hyparamList.isEmpty()) {
|
|
|
|
- return CommonResult.fail("未在type为biz_data_augmentation的字典数据中设置" + dictValue + "的超参配置");
|
|
|
|
- }
|
|
|
|
|
|
+ queryWrapper.eq("module", module);
|
|
|
|
+ List<CommonAlgorithmConfig> list = commonAlgorithmConfigService.list(queryWrapper);
|
|
|
|
+ if (list.isEmpty()) {
|
|
|
|
+ return CommonResult.fail("通用算法配置表为空!");
|
|
|
|
+ }
|
|
|
|
+ for (CommonAlgorithmConfig commonAlgorithmConfig: list) {
|
|
|
|
+
|
|
TaskDictDataVo taskDictDataVo = new TaskDictDataVo();
|
|
TaskDictDataVo taskDictDataVo = new TaskDictDataVo();
|
|
- taskDictDataVo.setTaskType(dictValue);
|
|
|
|
- taskDictDataVo.setHyperparameterConfiguration(hyparamList.get(0).getDictValue());
|
|
|
|
|
|
+ taskDictDataVo.setTaskType(commonAlgorithmConfig.getAlgorithmName());
|
|
|
|
+ taskDictDataVo.setHyperparameterConfiguration(commonAlgorithmConfig.getParameters());
|
|
taskDictDataVos.add(taskDictDataVo);
|
|
taskDictDataVos.add(taskDictDataVo);
|
|
|
|
+
|
|
|
|
+
|
|
}
|
|
}
|
|
return CommonResult.success(taskDictDataVos);
|
|
return CommonResult.success(taskDictDataVos);
|
|
|
|
+
|
|
}
|
|
}
|
|
private static boolean hasSupportedExtension(String filePath, List<String> extensions) {
|
|
private static boolean hasSupportedExtension(String filePath, List<String> extensions) {
|
|
for (String ext : extensions) {
|
|
for (String ext : extensions) {
|