123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316 |
- package com.taais.biz.controller;
- import com.alibaba.fastjson2.JSON;
- import com.taais.biz.constant.BizConstant;
- import com.taais.biz.domain.DataAugmentation;
- import com.taais.biz.domain.TaskTrackResultBo;
- import com.taais.biz.domain.bo.*;
- import com.taais.biz.domain.dto.TaskResultDTO;
- import com.taais.biz.service.*;
- import com.taais.biz.service.impl.*;
- import com.taais.biz.service.service.impl.ObjectMatchServiceImpl;
- import com.taais.common.core.config.TaaisConfig;
- import com.taais.common.core.core.domain.CommonResult;
- import com.taais.common.log.annotation.Log;
- import com.taais.common.log.enums.BusinessType;
- import com.taais.common.web.core.BaseController;
- import jakarta.annotation.Resource;
- import jakarta.validation.Valid;
- import lombok.RequiredArgsConstructor;
- import org.apache.commons.lang3.StringUtils;
- import org.slf4j.Logger;
- import org.slf4j.LoggerFactory;
- import org.springframework.beans.factory.annotation.Autowired;
- import org.springframework.core.io.InputStreamResource;
- import org.springframework.http.HttpHeaders;
- import org.springframework.http.MediaType;
- import org.springframework.http.ResponseEntity;
- import org.springframework.validation.annotation.Validated;
- import org.springframework.web.bind.annotation.*;
- import java.io.File;
- import java.io.FileInputStream;
- import java.io.IOException;
- import java.nio.file.Path;
- import java.nio.file.Paths;
- import java.util.Arrays;
- import java.util.Date;
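/**
 * REST endpoints exposed under {@code /public}: algorithm task-result callbacks
 * ({@code /taskResult} and the {@code .../get_result} routes), a task-run trigger
 * ({@code /taskRun}), and video streaming for data-augmentation records
 * ({@code /video/input/{id}}, {@code /video/output/{id}}).
 *
 * @author allen
 */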
- @Validated
- @RequiredArgsConstructor
- @RestController
- @RequestMapping("/public")
- public class PublicController extends BaseController {
- private static final Logger log = LoggerFactory.getLogger(PublicController.class);
- @Resource
- private IVideoStableService videoStableService;
- @Resource
- private IToInfraredService toInfraredService;
- @Resource
- private ITrackSequenceService trackSequenceService;
- @Resource
- private ITargetDetectionService targetDetectionService;
- @Resource
- private IDataAugmentationService dataAugmentationService;
- @Resource
- IAlgorithmDataProcessService dataProcessService;
- @Resource
- IAlgorithmBizProcessService bizProcessService;
- @Resource
- TargetIdentificationSubtaskDetailsServiceImpl detailsService;
- @Resource
- private IAlgorithmTaskService algorithmTaskService;
- @Resource
- AlgorithmModelServiceImpl algorithmModelService;
- @Resource
- ObjectTraceMergeServiceImpl objectTraceMergeService;
- @Resource
- ObjectMatchServiceImpl objectMatchService;
- @Resource
- DataAmplificationTaskServiceImpl dataAmplificationTaskService;
- @Resource
- DataServiceImpl dataService;
- @PostMapping("/taskResult")
- public CommonResult<Void> taskResult(@RequestBody TaskResultDTO resultDTO) {
- log.info("taskResult start,params:{}", resultDTO);
- String errorMsg = checkDTO(resultDTO);
- if (StringUtils.isNotEmpty(errorMsg)) {
- log.error("taskResult error,{}", errorMsg);
- return CommonResult.fail(errorMsg);
- }
- String bizType = resultDTO.getBizType();
- if(BizConstant.TYPE_DATA_BIZ_PROCESS.equals(bizType)){
- //errorMsg = bizProcessService.taskResult(resultDTO);
- TargetIdentificationSubtaskDetailsBo detailsBo = detailsService.getById(resultDTO.getBizId());
- detailsBo.setStatus(resultDTO.getStatus() != 200 ? BizConstant.TASK_STATUS_FAILED :
- resultDTO.getMsg().contains("finish") ? BizConstant.TASK_STATUS_SUCCEED : BizConstant.TASK_STATUS_PROCESSING);
- detailsBo.setEndTime(new Date());
- detailsBo.setCostSecond((detailsBo.getEndTime().getTime() - detailsBo.getStartTime().getTime()) / 1000);
- detailsService.update(detailsBo);
- // 保存模型
- if (BizConstant.TASK_STATUS_SUCCEED.equals(detailsBo.getStatus()) && detailsBo.getName().contains("训练")) {
- Long algorithmId = detailsBo.getAlgorithmId();
- AlgorithmModelBo bo = new AlgorithmModelBo();
- bo.setAlgorithmId(algorithmId);
- String _path = TaaisConfig.getProfile() + "/task" + detailsBo.getResultPath() + "/weights";
- File dir = new File(_path);
- String SUFFIX_NAME = "NONAME";
- if (dir.exists()) {
- File[] files = dir.listFiles();
- for (File file : files) {
- if (file.isFile() && file.getName().contains("best")) {
- SUFFIX_NAME = file.getName();
- break;
- }
- }
- }
- bo.setModelAddress("/profile/task" + detailsBo.getResultPath() + "/weights/" + SUFFIX_NAME);
- bo.setModelName(detailsBo.getName() + "_" + detailsBo.getCreateTime().toString());
- algorithmModelService.insert(bo);
- }
- } else if (BizConstant.TYPE_OBJ_TRACE.equals(bizType)){
- ObjectTraceMergeBo bo = objectTraceMergeService.getById(resultDTO.getBizId());
- if (bo == null) {
- return CommonResult.fail("bo为null");
- }
- bo.setEndTime(new Date());
- bo.setCostSecond((bo.getEndTime().getTime() - bo.getStartTime().getTime()) / 1000);
- bo.setStatus(resultDTO.getStatus() == 200 ? BizConstant.TASK_STATUS_SUCCEED : BizConstant.TASK_STATUS_FAILED);
- objectTraceMergeService.update(bo);
- } else if (BizConstant.TYPE_OBJ_MATCH.equals(bizType)){
- ObjectMatchBo bo = objectMatchService.getById(resultDTO.getBizId());
- if (bo == null) {
- return CommonResult.fail("bo为null");
- }
- bo.setEndTime(new Date());
- bo.setCostSecond((bo.getEndTime().getTime() - bo.getStartTime().getTime()) / 1000);
- bo.setStatus(resultDTO.getStatus() == 200 ? BizConstant.TASK_STATUS_SUCCEED : BizConstant.TASK_STATUS_FAILED);
- objectMatchService.update(bo);
- } else if (BizConstant.TYPE_DATA_PROCESS.equals(bizType)) {
- errorMsg = dataProcessService.taskResult(resultDTO);
- } else if (BizConstant.TYPE_DATA_EXPAND.equals(bizType)) {
- DataAmplificationTaskBo bo = dataAmplificationTaskService.getById(resultDTO.getBizId());
- log.info("data_expand callback: {}, {}", resultDTO, bo);
- if (bo != null) {
- try {
- bo.setInputImagePath(String.valueOf(Integer.parseInt(bo.getInputImagePath()) - 1));
- if (bo.getInputImagePath().equals("0")) {
- bo.setEndTime(new Date());
- bo.setCostSecond((int) ((bo.getEndTime().getTime() - bo.getStartTime().getTime()) / 1000));
- bo.setStatus(resultDTO.getStatus() == 200 ? BizConstant.TASK_STATUS_SUCCEED : BizConstant.TASK_STATUS_FAILED);
- if ("PERSISTENCE".equals(bo.getRemarks())) {
- dataService.registerExistedData(bo.getOutputImagePath() + "/result", bo);
- }
- }
- } catch (Exception e) {
- log.error("error: {}", e.getMessage());
- } finally {
- dataAmplificationTaskService.update(bo);
- }
- }
- } else {
- log.error("这种情况是不可能发生的,参数:{}",resultDTO);
- return CommonResult.fail("这种情况是不可能发生的");
- }
- if(StringUtils.isNotEmpty(errorMsg)){
- return CommonResult.fail(errorMsg);
- }
- return CommonResult.success();
- }
- private String checkDTO(TaskResultDTO resultDTO) {
- Integer status = resultDTO.getStatus();
- if(status != 200 && status != 500){
- return "status 只能是200或500";
- }
- Long bizId = resultDTO.getBizId();
- if(bizId == null){
- return "bizId 不能为null";
- }
- String bizType = resultDTO.getBizType();
- if(!BizConstant.TYPE_OBJ_MATCH.equals(bizType) &&
- !BizConstant.TYPE_OBJ_TRACE.equals(bizType) &&
- !BizConstant.TYPE_DATA_BIZ_PROCESS.equals(bizType) &&
- !BizConstant.TYPE_DATA_PROCESS.equals(bizType) &&
- !BizConstant.TYPE_DATA_EXPAND.equals(bizType)){
- return "status 只能是"+BizConstant.TYPE_DATA_BIZ_PROCESS+"或"+BizConstant.TYPE_DATA_PROCESS;
- }
- return null;
- }
- @GetMapping("/taskRun")
- public CommonResult<Void> taskRun() {
- algorithmTaskService.taskRun();
- return CommonResult.success();
- }
- @PostMapping("/videoStable/get_result")
- public CommonResult getResult(@Valid @RequestBody VideoStableStartResultBo videoStableStartResultBo) {
- log.info("/videoStable/get_result ,params:{}", videoStableStartResultBo);
- return videoStableService.getResult(videoStableStartResultBo);
- }
- @PostMapping("/dataAugmentation/get_result")
- public CommonResult Result(@Valid @RequestBody DataAugmentationResultBo dataAugmentationStartResultBo) {
- log.info("/dataAugmentation/get_result ,params:{}", dataAugmentationStartResultBo);
- return dataAugmentationService.getResult(dataAugmentationStartResultBo);
- }
- @PostMapping("/task/get_result")
- public CommonResult getResult(@Valid @RequestBody TaskTrackResultBo taskTrackResultBo) {
- log.info("/task/get_result ,params:{}", taskTrackResultBo);
- if(BizConstant.BizType.TO_INFRARED.equals(taskTrackResultBo.getBizType())) {
- return toInfraredService.getResult(taskTrackResultBo);
- }
- else if(BizConstant.BizType.TRACK_SEQUENCE.equals(taskTrackResultBo.getBizType())) {
- return trackSequenceService.getResult(taskTrackResultBo);
- }
- else if(BizConstant.BizType.TARGET_DETECTION.equals(taskTrackResultBo.getBizType())) {
- return targetDetectionService.getResult(taskTrackResultBo);
- }
- else {
- return CommonResult.fail("业务类型不支持");
- }
- }
- @GetMapping("/video/input/{id}")
- public ResponseEntity<InputStreamResource> streamVideo(@PathVariable Long id) throws IOException {
- // 视频文件存储路径
- DataAugmentation byId = dataAugmentationService.getById(id);
- System.out.println("dataaa:"+byId);
- String inputPath = byId.getInputPath();
- Path path = Paths.get(inputPath, "Project_Test.avi");
- String filename = path.toString();
- File videoFile = new File(filename);
- if (!videoFile.exists()) {
- return ResponseEntity.notFound().build();
- }
- // 获取视频文件的扩展名来判断 MIME 类型
- String fileExtension = getFileExtension(filename);
- MediaType mediaType = getMediaType(fileExtension);
- // 打开视频文件流
- FileInputStream videoStream = new FileInputStream(videoFile);
- // 返回响应流
- return ResponseEntity.ok()
- .contentType(mediaType)
- .header(HttpHeaders.CONTENT_DISPOSITION, "inline; filename=\"" + filename + "\"") // 可以直接播放
- .body(new InputStreamResource(videoStream));
- }
- @GetMapping("/video/output/{id}")
- public ResponseEntity<InputStreamResource> streamOutputVideo(@PathVariable Long id) throws IOException {
- // 视频文件存储路径
- DataAugmentation byId = dataAugmentationService.getById(id);
- String inputPath = byId.getOutputPath();
- Path path = Paths.get(inputPath, "result.avi");
- String filename = path.toString();
- File videoFile = new File(filename);
- if (!videoFile.exists()) {
- return ResponseEntity.notFound().build();
- }
- // 获取视频文件的扩展名来判断 MIME 类型
- String fileExtension = getFileExtension(filename);
- MediaType mediaType = getMediaType(fileExtension);
- // 打开视频文件流
- FileInputStream videoStream = new FileInputStream(videoFile);
- // 返回响应流
- return ResponseEntity.ok()
- .contentType(mediaType)
- .header(HttpHeaders.CONTENT_DISPOSITION, "inline; filename=\"" + filename + "\"") // 可以直接播放
- .body(new InputStreamResource(videoStream));
- }
- // 获取文件扩展名
- private String getFileExtension(String filename) {
- int dotIndex = filename.lastIndexOf('.');
- return (dotIndex > 0) ? filename.substring(dotIndex + 1) : "";
- }
- // 根据文件扩展名返回正确的 MIME 类型
- private MediaType getMediaType(String extension) {
- switch (extension.toLowerCase()) {
- case "mp4":
- return MediaType.valueOf("video/mp4");
- case "avi":
- return MediaType.valueOf("video/x-msvideo");
- case "mov":
- return MediaType.valueOf("video/quicktime");
- case "webm":
- return MediaType.valueOf("video/webm");
- case "flv":
- return MediaType.valueOf("video/x-flv");
- case "ogg":
- return MediaType.valueOf("video/ogg");
- default:
- return MediaType.valueOf("application/octet-stream"); // 默认二进制流
- }
- }
- // 根据文件扩展名返回正确的 MIME 类型
- }
|