BizTrainingServiceImpl.java 5.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181
  1. package com.ips.system.service.impl;
  2. import java.util.Date;
  3. import java.util.HashMap;
  4. import java.util.List;
  5. import java.util.Map;
  6. import java.util.concurrent.CompletableFuture;
  7. import com.fasterxml.jackson.core.JsonProcessingException;
  8. import com.fasterxml.jackson.core.type.TypeReference;
  9. import com.fasterxml.jackson.databind.ObjectMapper;
  10. import com.ips.common.utils.DateUtils;
  11. import com.ips.common.utils.StringUtils;
  12. import com.ips.system.domain.AlgorithmConfig;
  13. import com.ips.system.domain.Distillation;
  14. import com.ips.system.dto.AlgorithmParamsDto;
  15. import com.ips.system.service.IAlgorithmConfigService;
  16. import com.ips.system.utils.AlgorithmCaller;
  17. import org.slf4j.Logger;
  18. import org.slf4j.LoggerFactory;
  19. import org.springframework.beans.factory.annotation.Autowired;
  20. import org.springframework.stereotype.Service;
  21. import com.ips.system.mapper.BizTrainingMapper;
  22. import com.ips.system.domain.BizTraining;
  23. import com.ips.system.service.IBizTrainingService;
  24. /**
  25. * 模型训练Service业务层处理
  26. *
  27. * @author Allen
  28. * @date 2025-05-21
  29. */
  30. @Service
  31. public class BizTrainingServiceImpl implements IBizTrainingService {
  32. protected final Logger logger = LoggerFactory.getLogger(this.getClass());
  33. @Autowired
  34. private BizTrainingMapper bizTrainingMapper;
  35. @Autowired
  36. private IAlgorithmConfigService algorithmConfigService;
  37. /**
  38. * 查询模型训练
  39. *
  40. * @param id 模型训练主键
  41. * @return 模型训练
  42. */
  43. @Override
  44. public BizTraining selectBizTrainingById(Long id) {
  45. return bizTrainingMapper.selectBizTrainingById(id);
  46. }
  47. /**
  48. * 查询模型训练列表
  49. *
  50. * @param bizTraining 模型训练
  51. * @return 模型训练
  52. */
  53. @Override
  54. public List<BizTraining> selectBizTrainingList(BizTraining bizTraining) {
  55. return bizTrainingMapper.selectBizTrainingList(bizTraining);
  56. }
  57. /**
  58. * 新增模型训练
  59. *
  60. * @param bizTraining 模型训练
  61. * @return 结果
  62. */
  63. @Override
  64. public int insertBizTraining(BizTraining bizTraining) {
  65. bizTraining.setCreateTime(DateUtils.getNowDate());
  66. return bizTrainingMapper.insertBizTraining(bizTraining);
  67. }
  68. /**
  69. * 修改模型训练
  70. *
  71. * @param bizTraining 模型训练
  72. * @return 结果
  73. */
  74. @Override
  75. public int updateBizTraining(BizTraining bizTraining) {
  76. bizTraining.setUpdateTime(DateUtils.getNowDate());
  77. return bizTrainingMapper.updateBizTraining(bizTraining);
  78. }
  79. /**
  80. * 批量删除模型训练
  81. *
  82. * @param ids 需要删除的模型训练主键
  83. * @return 结果
  84. */
  85. @Override
  86. public int deleteBizTrainingByIds(Long[] ids) {
  87. return bizTrainingMapper.deleteBizTrainingByIds(ids);
  88. }
  89. /**
  90. * 删除模型训练信息
  91. *
  92. * @param id 模型训练主键
  93. * @return 结果
  94. */
  95. @Override
  96. public int deleteBizTrainingById(Long id) {
  97. return bizTrainingMapper.deleteBizTrainingById(id);
  98. }
  99. @Override
  100. public String run(Long id) throws JsonProcessingException {
  101. BizTraining training = bizTrainingMapper.selectBizTrainingById(id);
  102. if (training == null || training.getAlgorithmId() == null) {
  103. return "无法找到该任务,id:" + id;
  104. }
  105. Long algorithmId = training.getAlgorithmId();
  106. AlgorithmConfig algorithmConfig = algorithmConfigService.selectAlgorithmConfigById(algorithmId);
  107. training.setStartTime(new Date());
  108. training.setStatus("1");
  109. this.updateBizTraining(training);
  110. CompletableFuture.runAsync(() -> {
  111. // 异步逻辑
  112. doRun(algorithmConfig, training);
  113. });
  114. return null;
  115. }
  116. private void doRun(AlgorithmConfig algorithmConfig, BizTraining training) {
  117. String algorithmPath = algorithmConfig.getAlgorithmPath();
  118. // 组装json
  119. String inputPath = training.getInputPath();
  120. String outputPath = training.getOutputPath();
  121. ObjectMapper objectMapper = new ObjectMapper();
  122. Map<String, Object> params = new HashMap<>(0);
  123. String errorMsg = "";
  124. try {
  125. if (StringUtils.isNotEmpty(training.getAlgorithmParams())) {
  126. objectMapper.readValue(
  127. training.getAlgorithmParams(),
  128. new TypeReference<Map<String, Object>>() {
  129. }
  130. );
  131. }
  132. AlgorithmParamsDto algorithmParamsDto = new AlgorithmParamsDto(inputPath, outputPath, params);
  133. // 对象 → JSON 字符串
  134. String json = objectMapper.writeValueAsString(algorithmParamsDto);
  135. // 处理算法
  136. errorMsg = AlgorithmCaller.executeAlgorithm(algorithmPath, json);
  137. } catch (JsonProcessingException e) {
  138. logger.error("格式化失败", e);
  139. errorMsg = "格式化失败";
  140. }
  141. //处理结果
  142. if (StringUtils.isEmpty(errorMsg)) {
  143. training.setStatus("2");
  144. } else {
  145. training.setStatus("3");
  146. }
  147. training.setEndTime(new Date());
  148. this.updateBizTraining(training);
  149. }
  150. @Override
  151. public BizTraining getResultDetails(Long id) {
  152. BizTraining bizTraining = this.selectBizTrainingById(id);
  153. String inputPath = bizTraining.getInputPath();
  154. // try {
  155. // List<FileInfoDTO> fileInfoDTOList = CommonUtils.getSortedFiles(inputPath, "dat");
  156. // extractedFeatures.setFolderInfoDTO(fileInfoDTOList);
  157. // } catch (IOException e) {
  158. // logger.error("读取文件夹错误", e);
  159. // return null;
  160. // }
  161. return bizTraining;
  162. }
  163. }