PublicController.java 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316
  1. package com.taais.biz.controller;
  2. import com.alibaba.fastjson2.JSON;
  3. import com.taais.biz.constant.BizConstant;
  4. import com.taais.biz.domain.DataAugmentation;
  5. import com.taais.biz.domain.TaskTrackResultBo;
  6. import com.taais.biz.domain.bo.*;
  7. import com.taais.biz.domain.dto.TaskResultDTO;
  8. import com.taais.biz.service.*;
  9. import com.taais.biz.service.impl.*;
  10. import com.taais.biz.service.service.impl.ObjectMatchServiceImpl;
  11. import com.taais.common.core.config.TaaisConfig;
  12. import com.taais.common.core.core.domain.CommonResult;
  13. import com.taais.common.log.annotation.Log;
  14. import com.taais.common.log.enums.BusinessType;
  15. import com.taais.common.web.core.BaseController;
  16. import jakarta.annotation.Resource;
  17. import jakarta.validation.Valid;
  18. import lombok.RequiredArgsConstructor;
  19. import org.apache.commons.lang3.StringUtils;
  20. import org.slf4j.Logger;
  21. import org.slf4j.LoggerFactory;
  22. import org.springframework.beans.factory.annotation.Autowired;
  23. import org.springframework.core.io.InputStreamResource;
  24. import org.springframework.http.HttpHeaders;
  25. import org.springframework.http.MediaType;
  26. import org.springframework.http.ResponseEntity;
  27. import org.springframework.validation.annotation.Validated;
  28. import org.springframework.web.bind.annotation.*;
  29. import java.io.File;
  30. import java.io.FileInputStream;
  31. import java.io.IOException;
  32. import java.nio.file.Path;
  33. import java.nio.file.Paths;
  34. import java.util.Arrays;
  35. import java.util.Date;
  36. /**
  37. * @author allen
  38. */
  39. @Validated
  40. @RequiredArgsConstructor
  41. @RestController
  42. @RequestMapping("/public")
  43. public class PublicController extends BaseController {
  44. private static final Logger log = LoggerFactory.getLogger(PublicController.class);
  45. @Resource
  46. private IVideoStableService videoStableService;
  47. @Resource
  48. private IToInfraredService toInfraredService;
  49. @Resource
  50. private ITrackSequenceService trackSequenceService;
  51. @Resource
  52. private ITargetDetectionService targetDetectionService;
  53. @Resource
  54. private IDataAugmentationService dataAugmentationService;
  55. @Resource
  56. IAlgorithmDataProcessService dataProcessService;
  57. @Resource
  58. IAlgorithmBizProcessService bizProcessService;
  59. @Resource
  60. TargetIdentificationSubtaskDetailsServiceImpl detailsService;
  61. @Resource
  62. private IAlgorithmTaskService algorithmTaskService;
  63. @Resource
  64. AlgorithmModelServiceImpl algorithmModelService;
  65. @Resource
  66. ObjectTraceMergeServiceImpl objectTraceMergeService;
  67. @Resource
  68. ObjectMatchServiceImpl objectMatchService;
  69. @Resource
  70. DataAmplificationTaskServiceImpl dataAmplificationTaskService;
  71. @Resource
  72. DataServiceImpl dataService;
  73. @PostMapping("/taskResult")
  74. public CommonResult<Void> taskResult(@RequestBody TaskResultDTO resultDTO) {
  75. log.info("taskResult start,params:{}", resultDTO);
  76. String errorMsg = checkDTO(resultDTO);
  77. if (StringUtils.isNotEmpty(errorMsg)) {
  78. log.error("taskResult error,{}", errorMsg);
  79. return CommonResult.fail(errorMsg);
  80. }
  81. String bizType = resultDTO.getBizType();
  82. if(BizConstant.TYPE_DATA_BIZ_PROCESS.equals(bizType)){
  83. //errorMsg = bizProcessService.taskResult(resultDTO);
  84. TargetIdentificationSubtaskDetailsBo detailsBo = detailsService.getById(resultDTO.getBizId());
  85. detailsBo.setStatus(resultDTO.getStatus() != 200 ? BizConstant.TASK_STATUS_FAILED :
  86. resultDTO.getMsg().contains("finish") ? BizConstant.TASK_STATUS_SUCCEED : BizConstant.TASK_STATUS_PROCESSING);
  87. detailsBo.setEndTime(new Date());
  88. detailsBo.setCostSecond((detailsBo.getEndTime().getTime() - detailsBo.getStartTime().getTime()) / 1000);
  89. detailsService.update(detailsBo);
  90. // 保存模型
  91. if (BizConstant.TASK_STATUS_SUCCEED.equals(detailsBo.getStatus()) && detailsBo.getName().contains("训练")) {
  92. Long algorithmId = detailsBo.getAlgorithmId();
  93. AlgorithmModelBo bo = new AlgorithmModelBo();
  94. bo.setAlgorithmId(algorithmId);
  95. String _path = TaaisConfig.getProfile() + "/task" + detailsBo.getResultPath() + "/weights";
  96. File dir = new File(_path);
  97. String SUFFIX_NAME = "NONAME";
  98. if (dir.exists()) {
  99. File[] files = dir.listFiles();
  100. for (File file : files) {
  101. if (file.isFile() && file.getName().contains("best")) {
  102. SUFFIX_NAME = file.getName();
  103. break;
  104. }
  105. }
  106. }
  107. bo.setModelAddress("/profile/task" + detailsBo.getResultPath() + "/weights/" + SUFFIX_NAME);
  108. bo.setModelName(detailsBo.getName() + "_" + detailsBo.getCreateTime().toString());
  109. algorithmModelService.insert(bo);
  110. }
  111. } else if (BizConstant.TYPE_OBJ_TRACE.equals(bizType)){
  112. ObjectTraceMergeBo bo = objectTraceMergeService.getById(resultDTO.getBizId());
  113. if (bo == null) {
  114. return CommonResult.fail("bo为null");
  115. }
  116. bo.setEndTime(new Date());
  117. bo.setCostSecond((bo.getEndTime().getTime() - bo.getStartTime().getTime()) / 1000);
  118. bo.setStatus(resultDTO.getStatus() == 200 ? BizConstant.TASK_STATUS_SUCCEED : BizConstant.TASK_STATUS_FAILED);
  119. objectTraceMergeService.update(bo);
  120. } else if (BizConstant.TYPE_OBJ_MATCH.equals(bizType)){
  121. ObjectMatchBo bo = objectMatchService.getById(resultDTO.getBizId());
  122. if (bo == null) {
  123. return CommonResult.fail("bo为null");
  124. }
  125. bo.setEndTime(new Date());
  126. bo.setCostSecond((bo.getEndTime().getTime() - bo.getStartTime().getTime()) / 1000);
  127. bo.setStatus(resultDTO.getStatus() == 200 ? BizConstant.TASK_STATUS_SUCCEED : BizConstant.TASK_STATUS_FAILED);
  128. objectMatchService.update(bo);
  129. } else if (BizConstant.TYPE_DATA_PROCESS.equals(bizType)) {
  130. errorMsg = dataProcessService.taskResult(resultDTO);
  131. } else if (BizConstant.TYPE_DATA_EXPAND.equals(bizType)) {
  132. DataAmplificationTaskBo bo = dataAmplificationTaskService.getById(resultDTO.getBizId());
  133. log.info("data_expand callback: {}, {}", resultDTO, bo);
  134. if (bo != null) {
  135. try {
  136. bo.setInputImagePath(String.valueOf(Integer.parseInt(bo.getInputImagePath()) - 1));
  137. if (bo.getInputImagePath().equals("0")) {
  138. bo.setEndTime(new Date());
  139. bo.setCostSecond((int) ((bo.getEndTime().getTime() - bo.getStartTime().getTime()) / 1000));
  140. bo.setStatus(resultDTO.getStatus() == 200 ? BizConstant.TASK_STATUS_SUCCEED : BizConstant.TASK_STATUS_FAILED);
  141. if ("PERSISTENCE".equals(bo.getRemarks())) {
  142. dataService.registerExistedData(bo.getOutputImagePath() + "/result", bo);
  143. }
  144. }
  145. } catch (Exception e) {
  146. log.error("error: {}", e.getMessage());
  147. } finally {
  148. dataAmplificationTaskService.update(bo);
  149. }
  150. }
  151. } else {
  152. log.error("这种情况是不可能发生的,参数:{}",resultDTO);
  153. return CommonResult.fail("这种情况是不可能发生的");
  154. }
  155. if(StringUtils.isNotEmpty(errorMsg)){
  156. return CommonResult.fail(errorMsg);
  157. }
  158. return CommonResult.success();
  159. }
  160. private String checkDTO(TaskResultDTO resultDTO) {
  161. Integer status = resultDTO.getStatus();
  162. if(status != 200 && status != 500){
  163. return "status 只能是200或500";
  164. }
  165. Long bizId = resultDTO.getBizId();
  166. if(bizId == null){
  167. return "bizId 不能为null";
  168. }
  169. String bizType = resultDTO.getBizType();
  170. if(!BizConstant.TYPE_OBJ_MATCH.equals(bizType) &&
  171. !BizConstant.TYPE_OBJ_TRACE.equals(bizType) &&
  172. !BizConstant.TYPE_DATA_BIZ_PROCESS.equals(bizType) &&
  173. !BizConstant.TYPE_DATA_PROCESS.equals(bizType) &&
  174. !BizConstant.TYPE_DATA_EXPAND.equals(bizType)){
  175. return "status 只能是"+BizConstant.TYPE_DATA_BIZ_PROCESS+"或"+BizConstant.TYPE_DATA_PROCESS;
  176. }
  177. return null;
  178. }
  179. @GetMapping("/taskRun")
  180. public CommonResult<Void> taskRun() {
  181. algorithmTaskService.taskRun();
  182. return CommonResult.success();
  183. }
  184. @PostMapping("/videoStable/get_result")
  185. public CommonResult getResult(@Valid @RequestBody VideoStableStartResultBo videoStableStartResultBo) {
  186. log.info("/videoStable/get_result ,params:{}", videoStableStartResultBo);
  187. return videoStableService.getResult(videoStableStartResultBo);
  188. }
  189. @PostMapping("/dataAugmentation/get_result")
  190. public CommonResult Result(@Valid @RequestBody DataAugmentationResultBo dataAugmentationStartResultBo) {
  191. log.info("/dataAugmentation/get_result ,params:{}", dataAugmentationStartResultBo);
  192. return dataAugmentationService.getResult(dataAugmentationStartResultBo);
  193. }
  194. @PostMapping("/task/get_result")
  195. public CommonResult getResult(@Valid @RequestBody TaskTrackResultBo taskTrackResultBo) {
  196. log.info("/task/get_result ,params:{}", taskTrackResultBo);
  197. if(BizConstant.BizType.TO_INFRARED.equals(taskTrackResultBo.getBizType())) {
  198. return toInfraredService.getResult(taskTrackResultBo);
  199. }
  200. else if(BizConstant.BizType.TRACK_SEQUENCE.equals(taskTrackResultBo.getBizType())) {
  201. return trackSequenceService.getResult(taskTrackResultBo);
  202. }
  203. else if(BizConstant.BizType.TARGET_DETECTION.equals(taskTrackResultBo.getBizType())) {
  204. return targetDetectionService.getResult(taskTrackResultBo);
  205. }
  206. else {
  207. return CommonResult.fail("业务类型不支持");
  208. }
  209. }
  210. @GetMapping("/video/input/{id}")
  211. public ResponseEntity<InputStreamResource> streamVideo(@PathVariable Long id) throws IOException {
  212. // 视频文件存储路径
  213. DataAugmentation byId = dataAugmentationService.getById(id);
  214. System.out.println("dataaa:"+byId);
  215. String inputPath = byId.getInputPath();
  216. Path path = Paths.get(inputPath, "Project_Test.avi");
  217. String filename = path.toString();
  218. File videoFile = new File(filename);
  219. if (!videoFile.exists()) {
  220. return ResponseEntity.notFound().build();
  221. }
  222. // 获取视频文件的扩展名来判断 MIME 类型
  223. String fileExtension = getFileExtension(filename);
  224. MediaType mediaType = getMediaType(fileExtension);
  225. // 打开视频文件流
  226. FileInputStream videoStream = new FileInputStream(videoFile);
  227. // 返回响应流
  228. return ResponseEntity.ok()
  229. .contentType(mediaType)
  230. .header(HttpHeaders.CONTENT_DISPOSITION, "inline; filename=\"" + filename + "\"") // 可以直接播放
  231. .body(new InputStreamResource(videoStream));
  232. }
  233. @GetMapping("/video/output/{id}")
  234. public ResponseEntity<InputStreamResource> streamOutputVideo(@PathVariable Long id) throws IOException {
  235. // 视频文件存储路径
  236. DataAugmentation byId = dataAugmentationService.getById(id);
  237. String inputPath = byId.getOutputPath();
  238. Path path = Paths.get(inputPath, "result.avi");
  239. String filename = path.toString();
  240. File videoFile = new File(filename);
  241. if (!videoFile.exists()) {
  242. return ResponseEntity.notFound().build();
  243. }
  244. // 获取视频文件的扩展名来判断 MIME 类型
  245. String fileExtension = getFileExtension(filename);
  246. MediaType mediaType = getMediaType(fileExtension);
  247. // 打开视频文件流
  248. FileInputStream videoStream = new FileInputStream(videoFile);
  249. // 返回响应流
  250. return ResponseEntity.ok()
  251. .contentType(mediaType)
  252. .header(HttpHeaders.CONTENT_DISPOSITION, "inline; filename=\"" + filename + "\"") // 可以直接播放
  253. .body(new InputStreamResource(videoStream));
  254. }
  255. // 获取文件扩展名
  256. private String getFileExtension(String filename) {
  257. int dotIndex = filename.lastIndexOf('.');
  258. return (dotIndex > 0) ? filename.substring(dotIndex + 1) : "";
  259. }
  260. // 根据文件扩展名返回正确的 MIME 类型
  261. private MediaType getMediaType(String extension) {
  262. switch (extension.toLowerCase()) {
  263. case "mp4":
  264. return MediaType.valueOf("video/mp4");
  265. case "avi":
  266. return MediaType.valueOf("video/x-msvideo");
  267. case "mov":
  268. return MediaType.valueOf("video/quicktime");
  269. case "webm":
  270. return MediaType.valueOf("video/webm");
  271. case "flv":
  272. return MediaType.valueOf("video/x-flv");
  273. case "ogg":
  274. return MediaType.valueOf("video/ogg");
  275. default:
  276. return MediaType.valueOf("application/octet-stream"); // 默认二进制流
  277. }
  278. }
  279. // 根据文件扩展名返回正确的 MIME 类型
  280. }