BizTrainingServiceImpl.java 5.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180
  1. package com.ips.system.service.impl;
  2. import java.util.Date;
  3. import java.util.HashMap;
  4. import java.util.List;
  5. import java.util.Map;
  6. import java.util.concurrent.CompletableFuture;
  7. import com.fasterxml.jackson.core.JsonProcessingException;
  8. import com.fasterxml.jackson.core.type.TypeReference;
  9. import com.fasterxml.jackson.databind.ObjectMapper;
  10. import com.ips.common.utils.DateUtils;
  11. import com.ips.common.utils.StringUtils;
  12. import com.ips.system.domain.AlgorithmConfig;
  13. import com.ips.system.dto.AlgorithmParamsDto;
  14. import com.ips.system.service.IAlgorithmConfigService;
  15. import com.ips.system.utils.AlgorithmCaller;
  16. import org.slf4j.Logger;
  17. import org.slf4j.LoggerFactory;
  18. import org.springframework.beans.factory.annotation.Autowired;
  19. import org.springframework.stereotype.Service;
  20. import com.ips.system.mapper.BizTrainingMapper;
  21. import com.ips.system.domain.BizTraining;
  22. import com.ips.system.service.IBizTrainingService;
  23. /**
  24. * 模型训练Service业务层处理
  25. *
  26. * @author Allen
  27. * @date 2025-05-21
  28. */
  29. @Service
  30. public class BizTrainingServiceImpl implements IBizTrainingService {
  31. protected final Logger logger = LoggerFactory.getLogger(this.getClass());
  32. @Autowired
  33. private BizTrainingMapper bizTrainingMapper;
  34. @Autowired
  35. private IAlgorithmConfigService algorithmConfigService;
  36. /**
  37. * 查询模型训练
  38. *
  39. * @param id 模型训练主键
  40. * @return 模型训练
  41. */
  42. @Override
  43. public BizTraining selectBizTrainingById(Long id) {
  44. return bizTrainingMapper.selectBizTrainingById(id);
  45. }
  46. /**
  47. * 查询模型训练列表
  48. *
  49. * @param bizTraining 模型训练
  50. * @return 模型训练
  51. */
  52. @Override
  53. public List<BizTraining> selectBizTrainingList(BizTraining bizTraining) {
  54. return bizTrainingMapper.selectBizTrainingList(bizTraining);
  55. }
  56. /**
  57. * 新增模型训练
  58. *
  59. * @param bizTraining 模型训练
  60. * @return 结果
  61. */
  62. @Override
  63. public int insertBizTraining(BizTraining bizTraining) {
  64. bizTraining.setCreateTime(DateUtils.getNowDate());
  65. return bizTrainingMapper.insertBizTraining(bizTraining);
  66. }
  67. /**
  68. * 修改模型训练
  69. *
  70. * @param bizTraining 模型训练
  71. * @return 结果
  72. */
  73. @Override
  74. public int updateBizTraining(BizTraining bizTraining) {
  75. bizTraining.setUpdateTime(DateUtils.getNowDate());
  76. return bizTrainingMapper.updateBizTraining(bizTraining);
  77. }
  78. /**
  79. * 批量删除模型训练
  80. *
  81. * @param ids 需要删除的模型训练主键
  82. * @return 结果
  83. */
  84. @Override
  85. public int deleteBizTrainingByIds(Long[] ids) {
  86. return bizTrainingMapper.deleteBizTrainingByIds(ids);
  87. }
  88. /**
  89. * 删除模型训练信息
  90. *
  91. * @param id 模型训练主键
  92. * @return 结果
  93. */
  94. @Override
  95. public int deleteBizTrainingById(Long id) {
  96. return bizTrainingMapper.deleteBizTrainingById(id);
  97. }
  98. @Override
  99. public String run(Long id) throws JsonProcessingException {
  100. BizTraining training = bizTrainingMapper.selectBizTrainingById(id);
  101. if (training == null || training.getAlgorithmId() == null) {
  102. return "无法找到该任务,id:" + id;
  103. }
  104. Long algorithmId = training.getAlgorithmId();
  105. AlgorithmConfig algorithmConfig = algorithmConfigService.selectAlgorithmConfigById(algorithmId);
  106. training.setStartTime(new Date());
  107. training.setStatus("1");
  108. this.updateBizTraining(training);
  109. CompletableFuture.runAsync(() -> {
  110. // 异步逻辑
  111. doRun(algorithmConfig, training);
  112. });
  113. return null;
  114. }
  115. private void doRun(AlgorithmConfig algorithmConfig, BizTraining training) {
  116. String algorithmPath = algorithmConfig.getAlgorithmPath();
  117. // 组装json
  118. String inputPath = training.getInputPath();
  119. String outputPath = training.getOutputPath();
  120. ObjectMapper objectMapper = new ObjectMapper();
  121. Map<String, Object> params = new HashMap<>(0);
  122. String errorMsg = "";
  123. try {
  124. if (StringUtils.isNotEmpty(training.getAlgorithmParams())) {
  125. objectMapper.readValue(
  126. training.getAlgorithmParams(),
  127. new TypeReference<Map<String, Object>>() {
  128. }
  129. );
  130. }
  131. AlgorithmParamsDto algorithmParamsDto = new AlgorithmParamsDto(inputPath, outputPath, params);
  132. // 对象 → JSON 字符串
  133. String json = objectMapper.writeValueAsString(algorithmParamsDto);
  134. // 处理算法
  135. errorMsg = AlgorithmCaller.executeAlgorithm(algorithmConfig.getAlgorithmName(),algorithmPath, json);
  136. } catch (JsonProcessingException e) {
  137. logger.error("格式化失败", e);
  138. errorMsg = "格式化失败";
  139. }
  140. //处理结果
  141. if (StringUtils.isEmpty(errorMsg)) {
  142. training.setStatus("2");
  143. } else {
  144. training.setStatus("3");
  145. }
  146. training.setEndTime(new Date());
  147. this.updateBizTraining(training);
  148. }
  149. @Override
  150. public BizTraining getResultDetails(Long id) {
  151. BizTraining bizTraining = this.selectBizTrainingById(id);
  152. String inputPath = bizTraining.getInputPath();
  153. // try {
  154. // List<FileInfoDTO> fileInfoDTOList = CommonUtils.getSortedFiles(inputPath, "dat");
  155. // extractedFeatures.setFolderInfoDTO(fileInfoDTOList);
  156. // } catch (IOException e) {
  157. // logger.error("读取文件夹错误", e);
  158. // return null;
  159. // }
  160. return bizTraining;
  161. }
  162. }