|
@@ -1,5 +1,6 @@
|
|
package com.taais.biz.service.impl;
|
|
package com.taais.biz.service.impl;
|
|
|
|
|
|
|
|
+import java.io.File;
|
|
import java.util.ArrayList;
|
|
import java.util.ArrayList;
|
|
import java.util.Arrays;
|
|
import java.util.Arrays;
|
|
import java.util.HashMap;
|
|
import java.util.HashMap;
|
|
@@ -37,6 +38,7 @@ import com.taais.biz.service.IAlgorithmModelTrackService;
|
|
|
|
|
|
import static com.taais.biz.domain.table.AlgorithmModelTrackTableDef.ALGORITHM_MODEL_TRACK;
|
|
import static com.taais.biz.domain.table.AlgorithmModelTrackTableDef.ALGORITHM_MODEL_TRACK;
|
|
import static com.taais.biz.constant.BizConstant.AlgorithmType;
|
|
import static com.taais.biz.constant.BizConstant.AlgorithmType;
|
|
|
|
+import static com.taais.biz.utils.ZipUtils.unzip;
|
|
|
|
|
|
/**
|
|
/**
|
|
* 算法模型配置Service业务层处理
|
|
* 算法模型配置Service业务层处理
|
|
@@ -149,7 +151,7 @@ public class AlgorithmModelTrackServiceImpl extends BaseServiceImpl<AlgorithmMod
|
|
public CommonResult<String> insert(AlgorithmModelTrackBo algorithmModelTrackBo) {
|
|
public CommonResult<String> insert(AlgorithmModelTrackBo algorithmModelTrackBo) {
|
|
AlgorithmConfigTrackVo algorithmConfigTrackVo = algorithmConfigTrackService.selectById(algorithmModelTrackBo.getAlgorithmId());
|
|
AlgorithmConfigTrackVo algorithmConfigTrackVo = algorithmConfigTrackService.selectById(algorithmModelTrackBo.getAlgorithmId());
|
|
if (algorithmConfigTrackVo.getType().equals(AlgorithmType.REASONING) && ObjectUtil.isEmpty(algorithmModelTrackBo.getModelInputOssId())) {
|
|
if (algorithmConfigTrackVo.getType().equals(AlgorithmType.REASONING) && ObjectUtil.isEmpty(algorithmModelTrackBo.getModelInputOssId())) {
|
|
- return CommonResult.fail("预测算法必须上传模型输入文件");
|
|
|
|
|
|
+ return CommonResult.fail("预测算法必须上传模型文件");
|
|
}
|
|
}
|
|
|
|
|
|
AlgorithmModelTrack algorithmModelTrack = MapstructUtils.convert(algorithmModelTrackBo, AlgorithmModelTrack.class);
|
|
AlgorithmModelTrack algorithmModelTrack = MapstructUtils.convert(algorithmModelTrackBo, AlgorithmModelTrack.class);
|
|
@@ -161,13 +163,27 @@ public class AlgorithmModelTrackServiceImpl extends BaseServiceImpl<AlgorithmMod
|
|
String filePath = inputOssEntity.getFileName();
|
|
String filePath = inputOssEntity.getFileName();
|
|
String localPath = TaaisConfig.getProfile();
|
|
String localPath = TaaisConfig.getProfile();
|
|
resourcePath = localPath + StringUtils.substringAfter(filePath, Constants.RESOURCE_PREFIX);
|
|
resourcePath = localPath + StringUtils.substringAfter(filePath, Constants.RESOURCE_PREFIX);
|
|
- algorithmModelTrack.setModelAddress(resourcePath);
|
|
|
|
|
|
+
|
|
|
|
+
|
|
|
|
+ if (resourcePath.endsWith(".zip")) {
|
|
|
|
+ String modelAddress = resourcePath.substring(0, resourcePath.lastIndexOf("."));
|
|
|
|
+ File file = new File(modelAddress);
|
|
|
|
+ if (!file.exists()) {
|
|
|
|
+ unzip(resourcePath, modelAddress);
|
|
|
|
+ }
|
|
|
|
+ algorithmModelTrack.setModelAddress(modelAddress);
|
|
|
|
+ } else if (resourcePath.endsWith(".pt")) {
|
|
|
|
+ algorithmModelTrack.setModelAddress(resourcePath);
|
|
|
|
+ } else {
|
|
|
|
+ return CommonResult.fail("模型格式不正确,请上传.zip或.pt文件");
|
|
|
|
+ }
|
|
} else {
|
|
} else {
|
|
// String localPath = TaaisConfig.getUploadPath();
|
|
// String localPath = TaaisConfig.getUploadPath();
|
|
// String path = DateUtils.datePath() + "/" + IdUtil.fastSimpleUUID();
|
|
// String path = DateUtils.datePath() + "/" + IdUtil.fastSimpleUUID();
|
|
// resourcePath = localPath + "/" + path + ".pt";
|
|
// resourcePath = localPath + "/" + path + ".pt";
|
|
// 本来这里是默认整一个虚拟模型地址,后面发现不合适便去掉了
|
|
// 本来这里是默认整一个虚拟模型地址,后面发现不合适便去掉了
|
|
}
|
|
}
|
|
|
|
+
|
|
if (AlgorithmType.REASONING.equals(algorithmConfigTrackVo.getType()) || AlgorithmType.TEST.equals(algorithmConfigTrackVo.getType())) {
|
|
if (AlgorithmType.REASONING.equals(algorithmConfigTrackVo.getType()) || AlgorithmType.TEST.equals(algorithmConfigTrackVo.getType())) {
|
|
algorithmModelTrack.setModelStatus(BizConstant.ModelStatus.END);
|
|
algorithmModelTrack.setModelStatus(BizConstant.ModelStatus.END);
|
|
} else if (AlgorithmType.TRAIN.equals(algorithmConfigTrackVo.getType())) {
|
|
} else if (AlgorithmType.TRAIN.equals(algorithmConfigTrackVo.getType())) {
|