SwaggerExtentionSupport.java 8.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238
  1. package com.primeton.dgs.kernel.core.configure;
  2. import com.google.common.collect.ImmutableMap;
  3. import com.google.common.collect.Lists;
  4. import com.google.common.collect.Multimap;
  5. import com.google.common.collect.Sets;
  6. import com.google.common.net.MediaType;
  7. import com.primeton.dgs.kernel.core.web.AppBaseDispatchCommand;
  8. import io.swagger.models.Path;
  9. import io.swagger.models.Response;
  10. import io.swagger.models.Swagger;
  11. import org.apache.commons.lang.StringUtils;
  12. import org.springframework.context.ApplicationContext;
  13. import org.springframework.http.HttpMethod;
  14. import springfox.documentation.builders.ResponseMessageBuilder;
  15. import springfox.documentation.schema.ModelRef;
  16. import springfox.documentation.service.Operation;
  17. import org.slf4j.Logger;
  18. import org.slf4j.LoggerFactory;
  19. import org.springframework.beans.factory.InitializingBean;
  20. import org.springframework.beans.factory.annotation.Autowired;
  21. import org.springframework.beans.factory.annotation.Value;
  22. import org.springframework.context.ApplicationListener;
  23. import org.springframework.context.event.ContextRefreshedEvent;
  24. import org.springframework.stereotype.Repository;
  25. import springfox.documentation.service.ApiDescription;
  26. import springfox.documentation.service.ApiListing;
  27. import springfox.documentation.service.Documentation;
  28. import springfox.documentation.service.ResponseMessage;
  29. import springfox.documentation.service.Tag;
  30. import springfox.documentation.spring.web.DocumentationCache;
  31. import springfox.documentation.spring.web.json.JsonSerializer;
  32. import springfox.documentation.spring.web.plugins.Docket;
  33. import springfox.documentation.swagger2.mappers.ServiceModelToSwagger2Mapper;
  34. import java.lang.reflect.Method;
  35. import java.lang.reflect.Modifier;
  36. import java.util.ArrayList;
  37. import java.util.Collection;
  38. import java.util.HashMap;
  39. import java.util.HashSet;
  40. import java.util.Map;
  41. import java.util.Set;
  42. import java.util.TreeSet;
  43. /**
  44. *
  45. *
  46. * 元数据老结构,支持 Swagger 适配器
  47. *
  48. * <pre>
  49. *
  50. * Created by zhaopx.
  51. * User: zhaopx
  52. * Date: 2020/8/28
  53. * Time: 14:34
  54. *
  55. * </pre>
  56. *
  57. * @author zhaopx
  58. */
  59. @Repository
  60. public class SwaggerExtentionSupport implements ApplicationListener<ContextRefreshedEvent> {
  61. /**
  62. * 版本
  63. */
  64. public final static String API_VER = "7.1.0";
  65. @Autowired
  66. DocumentationCache documentationCache;
  67. @Autowired
  68. private ServiceModelToSwagger2Mapper mapper;
  69. @Autowired
  70. private JsonSerializer jsonSerializer;
  71. @Value("${swagger.group}")
  72. private String groupName;
  73. private static Logger log = LoggerFactory.getLogger(SwaggerExtentionSupport.class);
  74. /**
  75. * 扫描 cotext 中 .do 的 bean,扫描 方法
  76. * @param context
  77. */
  78. private void initSwagger(ApplicationContext context) {
  79. Documentation documentation = documentationCache.documentationByGroup(groupName);
  80. if(documentation == null) {
  81. // 如果 groupName 指定的下没有,则挂载 default 上
  82. documentation = documentationCache.documentationByGroup(Docket.DEFAULT_GROUP_NAME);
  83. }
  84. if (documentation != null) {
  85. // 取得所有的 API 合集
  86. Multimap<String, ApiListing> apiListings = documentation.getApiListings();
  87. //Swagger swagger = mapper.mapDocumentation(documentation);
  88. String[] beanDefinitionNames = context.getBeanDefinitionNames();
  89. for (String name : beanDefinitionNames) {
  90. if(name.endsWith(".do")) {
  91. Class<?> aClass = context.getBean(name).getClass();
  92. if(!AppBaseDispatchCommand.class.isAssignableFrom(aClass)) {
  93. // 必须是 AppBaseDispatchCommand 的子类,才继续
  94. continue;
  95. }
  96. log.info("add swagger bean {}", name);
  97. Method[] servletMethods = aClass.getDeclaredMethods();
  98. for (Method servletMethod : servletMethods) {
  99. String methodName = servletMethod.getName();
  100. if(!"init".equals(methodName) && Modifier.isPublic(servletMethod.getModifiers())) {
  101. // 返回 tags
  102. Set<Tag> tags = addApi(apiListings, documentation.getBasePath(), name, methodName);
  103. if(!tags.isEmpty()) {
  104. documentation.getTags().addAll(tags);
  105. }
  106. }
  107. }
  108. }
  109. }
  110. log.info("swagger apis size: {}", apiListings.size());
  111. }
  112. }
  113. private Set<Tag> addApi(Multimap<String, ApiListing> apiListings,
  114. String basePath,
  115. String beanName,
  116. String methodName) {
  117. // 获取去除了 .do 的 名称
  118. String optGroup = getName(beanName);
  119. String optId = optGroup + "_" + methodName;
  120. // 生成唯一ID
  121. Collection<ApiListing> apis = apiListings.get(optId);
  122. if(apis == null) {
  123. // 后面只是用 apis 的 size 获取长度
  124. apis = new HashSet<>();
  125. }
  126. ArrayList<ApiDescription> apis1 = new ArrayList<>();
  127. ArrayList<Operation> operations = new ArrayList<>();
  128. ResponseMessageBuilder v1 = new ResponseMessageBuilder();
  129. v1.code(200).message("OK");
  130. // tag
  131. HashSet<Tag> tags = new HashSet<>();
  132. tags.add(new Tag(optGroup, beanName+"." + methodName));
  133. // 注意 position,必须是不重复的值
  134. Operation operaGet = new Operation(
  135. HttpMethod.GET,
  136. "do exec " + optGroup + "." + methodName,
  137. "",
  138. new ModelRef("ResponseCode"),
  139. optId+"UsingGET",
  140. 0,
  141. Sets.newHashSet(optGroup),
  142. Sets.newHashSet(MediaType.ANY_TYPE.toString()),
  143. Sets.newHashSet(MediaType.create("application", "json").toString()),
  144. new HashSet<>(),
  145. new ArrayList<>(),
  146. new ArrayList<>(),
  147. Sets.newHashSet(v1.build()),
  148. "",
  149. false,
  150. new ArrayList<>()
  151. );
  152. Operation operaPost = new Operation(
  153. HttpMethod.POST,
  154. "do exec " + optGroup + "." + methodName,
  155. "",
  156. new ModelRef("ResponseCode"),
  157. optId+"UsingPOST",
  158. 1,
  159. Sets.newHashSet(optGroup),
  160. Sets.newHashSet(MediaType.ANY_TYPE.toString()),
  161. Sets.newHashSet(MediaType.create("application", "json").toString()),
  162. new HashSet<>(),
  163. new ArrayList<>(),
  164. new ArrayList<>(),
  165. Sets.newHashSet(v1.build()),
  166. "",
  167. false,
  168. new ArrayList<>()
  169. );
  170. operations.add(operaGet);
  171. operations.add(operaPost);
  172. String url = "/" + beanName + "?invoke=" + methodName;
  173. apis1.add(new ApiDescription(groupName,
  174. url,
  175. beanName+"." + methodName,
  176. operations, false));
  177. // 注意 position,必须是不重复的值
  178. ApiListing apiListing = new ApiListing(
  179. API_VER,
  180. basePath,
  181. "/" + beanName,
  182. new HashSet<>(),new HashSet<>(),"", new HashSet<>(), new ArrayList<>(),
  183. apis1,
  184. new HashMap<>(), beanName+"." + methodName, apis.size(), tags);
  185. // 放到api列表中
  186. apiListings.put(optId, apiListing);
  187. return tags;
  188. }
  189. private String getName(String beanName) {
  190. if(StringUtils.isBlank(beanName)) {
  191. return beanName;
  192. }
  193. int i = beanName.indexOf(".");
  194. if(i > 0) {
  195. return beanName.substring(0, i);
  196. }
  197. return beanName;
  198. }
  199. @Override
  200. public void onApplicationEvent(ContextRefreshedEvent event) {
  201. // Spring 框架加载完全后,扫描 bean,获取 servlet
  202. initSwagger(event.getApplicationContext());
  203. }
  204. }