// SysTrainController.java (9.7 KB)
  1. package com.cirs.biz.controller;
import com.alibaba.fastjson2.JSON;
import com.alibaba.fastjson2.JSONObject;
import com.cirs.biz.domain.SysTrain;
import com.cirs.biz.domain.TElectronComponent;
import com.cirs.biz.domain.TrainReturn;
import com.cirs.biz.domain.VerificationData;
import com.cirs.biz.service.ISysTrainService;
import com.cirs.common.annotation.Log;
import com.cirs.common.core.controller.BaseController;
import com.cirs.common.core.domain.AjaxResult;
import com.cirs.common.core.page.TableDataInfo;
import com.cirs.common.enums.BusinessType;
import com.cirs.common.utils.DictUtils;
import com.cirs.common.utils.poi.ExcelUtil;
import com.fasterxml.jackson.databind.ObjectMapper;
import java.io.File;
import java.io.InputStream;
import java.sql.Timestamp;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import javax.servlet.http.HttpServletResponse;
import org.apache.commons.io.FileUtils;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.boot.autoconfigure.web.reactive.function.client.WebClientAutoConfiguration;
import org.springframework.http.HttpMethod;
import org.springframework.http.MediaType;
import org.springframework.security.access.prepost.PreAuthorize;
import org.springframework.web.bind.annotation.*;
import org.springframework.web.multipart.MultipartFile;
import org.springframework.web.reactive.function.BodyInserters;
import org.springframework.web.reactive.function.client.WebClient;
import reactor.core.publisher.Mono;
  35. /**
  36. * 训练集数据列Controller
  37. *
  38. * @author allen
  39. * @date 2023-11-28
  40. */
  41. @RestController
  42. @RequestMapping("/biz/train")
  43. public class SysTrainController extends BaseController
  44. {
  45. @Autowired
  46. private ISysTrainService sysTrainService;
  47. // 创建 WebClient 对象
  48. private WebClient webClient = WebClient.builder()
  49. // .baseUrl("http://jsonplaceholder.typicode.com")
  50. .build();
  51. /**
  52. * 查询训练集数据列列表
  53. */
  54. @PreAuthorize("@ss.hasPermi('biz:train:list')")
  55. @GetMapping("/list")
  56. public TableDataInfo list(SysTrain sysTrain)
  57. {
  58. startPage();
  59. List<SysTrain> list = sysTrainService.selectSysTrainList(sysTrain);
  60. return getDataTable(list);
  61. }
  62. /**
  63. * 导出训练集数据列列表
  64. */
  65. @PreAuthorize("@ss.hasPermi('biz:train:export')")
  66. @Log(title = "训练集数据列", businessType = BusinessType.EXPORT)
  67. @PostMapping("/export")
  68. public void export(HttpServletResponse response, SysTrain sysTrain)
  69. {
  70. List<SysTrain> list = sysTrainService.selectSysTrainList(sysTrain);
  71. ExcelUtil<SysTrain> util = new ExcelUtil<SysTrain>(SysTrain.class);
  72. util.exportExcel(response, list, "训练集数据列数据");
  73. }
  74. /**
  75. * 获取训练集数据列详细信息
  76. */
  77. @PreAuthorize("@ss.hasPermi('biz:train:query')")
  78. @GetMapping(value = "/{id}")
  79. public AjaxResult getInfo(@PathVariable("id") Long id)
  80. {
  81. return success(sysTrainService.selectSysTrainById(id));
  82. }
  83. /**
  84. * 新增训练集数据列
  85. */
  86. @PreAuthorize("@ss.hasPermi('biz:train:add')")
  87. @Log(title = "训练集数据列", businessType = BusinessType.INSERT)
  88. @PostMapping
  89. public AjaxResult add(@RequestBody SysTrain sysTrain)
  90. {
  91. return toAjax(sysTrainService.insertSysTrain(sysTrain));
  92. }
  93. /**
  94. * 修改训练集数据列
  95. */
  96. @PreAuthorize("@ss.hasPermi('biz:train:edit')")
  97. @Log(title = "训练集数据列", businessType = BusinessType.UPDATE)
  98. @PutMapping
  99. public AjaxResult edit(@RequestBody SysTrain sysTrain)
  100. {
  101. return toAjax(sysTrainService.updateSysTrain(sysTrain));
  102. }
  103. /**
  104. * 删除训练集数据列
  105. */
  106. @PreAuthorize("@ss.hasPermi('biz:train:remove')")
  107. @Log(title = "训练集数据列", businessType = BusinessType.DELETE)
  108. @DeleteMapping("/{ids}")
  109. public AjaxResult remove(@PathVariable Long[] ids)
  110. {
  111. return toAjax(sysTrainService.deleteSysTrainByIds(ids));
  112. }
  113. @PreAuthorize("@ss.hasPermi('biz:train:import')")
  114. @PostMapping("/importData")
  115. public AjaxResult importData(MultipartFile file, boolean updateSupport) throws Exception
  116. {
  117. ExcelUtil<SysTrain> util = new ExcelUtil<SysTrain>(SysTrain.class);
  118. List<SysTrain> trainList = util.importExcel(file.getInputStream());
  119. String operName = getUsername();
  120. String message = sysTrainService.importTrain(trainList, updateSupport, operName);
  121. return success(message);
  122. }
  123. @PostMapping("/importTemplate")
  124. public void importTemplate(HttpServletResponse response)
  125. {
  126. ExcelUtil<SysTrain> util = new ExcelUtil<SysTrain>(SysTrain.class);
  127. util.importTemplateExcel(response, "训练集数据");
  128. }
  129. @GetMapping("/componentIds")
  130. public AjaxResult getComponentIds()
  131. {
  132. List<SysTrain> train_dataset = sysTrainService.alldata();
  133. int idx = 0;
  134. for(;idx < train_dataset.size(); idx++){
  135. SysTrain data = train_dataset.get(idx);
  136. if(data.getResult1Id()==null){//不改动元器件id时可行
  137. Long component_id1 = sysTrainService.getComponentId(data.getResult1());
  138. Long component_id2 = sysTrainService.getComponentId(data.getResult2());
  139. Long component_id3 = sysTrainService.getComponentId(data.getResult3());
  140. Long component_id4 = sysTrainService.getComponentId(data.getResult4());
  141. Long component_id5 = sysTrainService.getComponentId(data.getResult5());
  142. data.setResult1Id(component_id1);
  143. data.setResult2Id(component_id2);
  144. data.setResult3Id(component_id3);
  145. data.setResult4Id(component_id4);
  146. data.setResult5Id(component_id5);
  147. edit(data);//更新数据库
  148. }
  149. }
  150. return success();
  151. }
  152. @PreAuthorize("@ss.hasPermi('biz:train:train')")
  153. @GetMapping("/train")
  154. public AjaxResult train() {
  155. try {
  156. String model_path = DictUtils.getDictValue("biz_algorithm_config","model_path");
  157. String train_uri = DictUtils.getDictValue("biz_algorithm_config","train_uri");
  158. Map<String, Object> objectMap=new HashMap<>();
  159. objectMap.put("dataSet", sysTrainService.getComponentids());
  160. objectMap.put("modelPath", model_path);
  161. // 创建ObjectMapper实例
  162. ObjectMapper mapper = new ObjectMapper();
  163. // 将对象转换为JSON字符串
  164. String json = mapper.writeValueAsString(objectMap);
  165. logger.info("json : {}",json);
  166. // 发送请求
  167. // todo wangruilin uri 应该是一个全地址+端口,这个地址可以通过数据字典配置和获取
  168. Mono<String> mono = webClient
  169. .post() // POST 请求
  170. .uri(train_uri) // 请求路径
  171. .contentType(MediaType.APPLICATION_JSON_UTF8)
  172. .syncBody(objectMap)
  173. .retrieve() // 获取响应体
  174. .bodyToMono(String.class); //响应数据类型转换
  175. String res = mono.block();
  176. logger.info(res);
  177. //接下来就传入算法即可
  178. // System.out.println(JSON.toJSONString(objectMap));
  179. AjaxResult result = new AjaxResult();
  180. result.put("msg","成功");
  181. result.put("data",res);
  182. result.put("code",200);
  183. return result;
  184. } catch (Exception e) {
  185. return error("训练失败");
  186. }
  187. }
  188. @PreAuthorize("@ss.hasPermi('biz:train:recommend')")
  189. @PostMapping ("/recommend")
  190. public AjaxResult recommend(@RequestBody TrainReturn recommend_args) {
  191. try {
  192. String recommend_uri = DictUtils.getDictValue("biz_algorithm_config","recommend_uri");
  193. String model_path = DictUtils.getDictValue("biz_algorithm_config","model_path");
  194. Map<String, Object> objectMap=new HashMap<>();
  195. objectMap.put("useScene", recommend_args.getUseScene());
  196. objectMap.put("SearchCondition",recommend_args.getSearchCondition());
  197. objectMap.put("modelPath", model_path);
  198. objectMap.put("result1Id",recommend_args.getResult1Id());
  199. objectMap.put("result2Id",recommend_args.getResult2Id());
  200. objectMap.put("result3Id",recommend_args.getResult3Id());
  201. objectMap.put("result4Id",recommend_args.getResult4Id());
  202. objectMap.put("result5Id",recommend_args.getResult5Id());
  203. //接下来就传入算法
  204. // System.out.println(JSON.toJSONString(objectMap));
  205. // todo wangruilin uri 应该是一个全地址+端口,这个地址可以通过数据字典配置和获取
  206. Mono<String> mono = webClient
  207. .post() // POST 请求
  208. .uri(recommend_uri) // 请求路径
  209. .contentType(MediaType.APPLICATION_JSON_UTF8)
  210. .syncBody(objectMap)
  211. .retrieve() // 获取响应体
  212. .bodyToMono(String.class); //响应数据类型转换
  213. String res = mono.block();
  214. // // todo wangruilin 改成logger
  215. logger.info(res);
  216. AjaxResult result = new AjaxResult();
  217. result.put("data",res);
  218. result.put("code",200);
  219. return result;
  220. } catch (Exception e) {
  221. return error("推荐元器件失败");
  222. }
  223. }
  224. @GetMapping("/getTraindataset")
  225. public TableDataInfo getTraindataset(){
  226. try{
  227. startPage();
  228. return getDataTable(sysTrainService.alldata());
  229. } catch (Exception e){
  230. return null;
  231. }
  232. }
  233. }