SwaggerExtentionSupport.java 8.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252
  1. package com.primeton.dgs.kernel.core.configure;
  2. import com.google.common.collect.ImmutableMap;
  3. import com.google.common.collect.Lists;
  4. import com.google.common.collect.Multimap;
  5. import com.google.common.collect.Sets;
  6. import com.google.common.net.MediaType;
  7. import com.primeton.dgs.kernel.core.web.AppBaseDispatchCommand;
  8. import io.swagger.models.Path;
  9. import io.swagger.models.Response;
  10. import io.swagger.models.Swagger;
  11. import org.apache.commons.lang.StringUtils;
  12. import org.springframework.context.ApplicationContext;
  13. import org.springframework.http.HttpMethod;
  14. import springfox.documentation.builders.ResponseMessageBuilder;
  15. import springfox.documentation.schema.ModelRef;
  16. import springfox.documentation.service.Operation;
  17. import org.slf4j.Logger;
  18. import org.slf4j.LoggerFactory;
  19. import org.springframework.beans.factory.InitializingBean;
  20. import org.springframework.beans.factory.annotation.Autowired;
  21. import org.springframework.beans.factory.annotation.Value;
  22. import org.springframework.context.ApplicationListener;
  23. import org.springframework.context.event.ContextRefreshedEvent;
  24. import org.springframework.stereotype.Repository;
  25. import springfox.documentation.service.ApiDescription;
  26. import springfox.documentation.service.ApiListing;
  27. import springfox.documentation.service.Documentation;
  28. import springfox.documentation.service.ResponseMessage;
  29. import springfox.documentation.service.Tag;
  30. import springfox.documentation.spring.web.DocumentationCache;
  31. import springfox.documentation.spring.web.json.JsonSerializer;
  32. import springfox.documentation.spring.web.plugins.Docket;
  33. import springfox.documentation.swagger2.mappers.ServiceModelToSwagger2Mapper;
  34. import java.lang.reflect.Method;
  35. import java.lang.reflect.Modifier;
  36. import java.util.ArrayList;
  37. import java.util.Collection;
  38. import java.util.HashMap;
  39. import java.util.HashSet;
  40. import java.util.Map;
  41. import java.util.Set;
  42. import java.util.TreeSet;
  43. /**
  44. *
  45. *
  46. * 元数据老结构,支持 Swagger 适配器
  47. *
  48. * <pre>
  49. *
  50. * Created by zhaopx.
  51. * User: zhaopx
  52. * Date: 2020/8/28
  53. * Time: 14:34
  54. *
  55. * </pre>
  56. *
  57. * @author zhaopx
  58. */
  59. @Repository
  60. public class SwaggerExtentionSupport implements ApplicationListener<ContextRefreshedEvent> {
  61. /**
  62. * 版本
  63. */
  64. public final static String API_VER = "7.1.0";
  65. @Autowired
  66. DocumentationCache documentationCache;
  67. @Autowired
  68. private ServiceModelToSwagger2Mapper mapper;
  69. @Autowired
  70. private JsonSerializer jsonSerializer;
  71. @Value("${swagger.group}")
  72. private String groupName;
  73. private static Logger log = LoggerFactory.getLogger(SwaggerExtentionSupport.class);
  74. /**
  75. * 扫描 cotext 中 .do 的 bean,扫描 方法
  76. * @param context
  77. */
  78. private void initSwagger(ApplicationContext context) {
  79. Documentation documentation = documentationCache.documentationByGroup(groupName);
  80. if(documentation == null) {
  81. // 如果 groupName 指定的下没有,则挂载 default 上
  82. documentation = documentationCache.documentationByGroup(Docket.DEFAULT_GROUP_NAME);
  83. }
  84. if (documentation != null) {
  85. // 取得所有的 API 合集
  86. Multimap<String, ApiListing> apiListings = documentation.getApiListings();
  87. //Swagger swagger = mapper.mapDocumentation(documentation);
  88. String[] beanDefinitionNames = context.getBeanDefinitionNames();
  89. //String[] beanDefinitionNames = new String[]{"test.do", "param.do", "rule.do"};
  90. for (String name : beanDefinitionNames) {
  91. if(name.endsWith(".do")) {
  92. Class<?> aClass = context.getBean(name).getClass();
  93. if(!AppBaseDispatchCommand.class.isAssignableFrom(aClass)) {
  94. // 必须是 AppBaseDispatchCommand 的子类,才继续
  95. continue;
  96. }
  97. log.info("add swagger bean {}", name);
  98. Method[] servletMethods = aClass.getDeclaredMethods();
  99. for (Method servletMethod : servletMethods) {
  100. String methodName = servletMethod.getName();
  101. if(!"init".equals(methodName) && Modifier.isPublic(servletMethod.getModifiers())) {
  102. // 返回 tags
  103. Set<Tag> tags = addApi(apiListings, documentation.getBasePath(), name, methodName);
  104. if(!tags.isEmpty()) {
  105. documentation.getTags().addAll(tags);
  106. }
  107. }
  108. }
  109. }
  110. }
  111. log.info("swagger apis size: {}", apiListings.size());
  112. }
  113. }
  114. private Set<Tag> addApi(Multimap<String, ApiListing> apiListings,
  115. String basePath,
  116. String beanName,
  117. String methodName) {
  118. // 获取去除了 .do 的 名称
  119. String optGroup = getName(beanName);
  120. String apiId = optGroup + "_" + methodName;
  121. String optId = methodName;
  122. // 生成唯一ID
  123. Collection<ApiListing> apis = apiListings.get(apiId);
  124. if(apis == null) {
  125. // 后面只是用 apis 的 size 获取长度
  126. apis = new HashSet<>();
  127. }
  128. ArrayList<ApiDescription> apis1 = new ArrayList<>();
  129. ArrayList<Operation> operations = new ArrayList<>();
  130. ResponseMessage v200 = new ResponseMessageBuilder().code(200).message("OK").build();
  131. ResponseMessage v401 = new ResponseMessageBuilder().code(401).message("Unauthorized").build();
  132. ResponseMessage v403 = new ResponseMessageBuilder().code(403).message("Forbidden").build();
  133. ResponseMessage v404 = new ResponseMessageBuilder().code(404).message("Not Found").build();
  134. // tag
  135. HashSet<Tag> tags = new HashSet<>();
  136. // description 是生成 API JS 的文件名
  137. // optGroup 是生成的函数名
  138. Tag tag = new Tag(optGroup, optGroup + "Controller");
  139. tags.add(tag);
  140. // 注意 position,必须是不重复的值
  141. Operation operaGet = new Operation(
  142. HttpMethod.GET,
  143. "do exec " + optGroup + "." + methodName,
  144. "",
  145. new ModelRef("ResponseCode"),
  146. optId+"UsingGET",
  147. 0,
  148. Sets.newHashSet(tag.getName()),
  149. Sets.newHashSet(MediaType.ANY_TYPE.toString()),
  150. Sets.newHashSet(MediaType.create("application", "json").toString()),
  151. new HashSet<>(),
  152. new ArrayList<>(),
  153. new ArrayList<>(),
  154. Sets.newHashSet(v200, v401, v403, v404),
  155. "",
  156. false,
  157. new ArrayList<>()
  158. );
  159. // Operation 只需要 tag name,他决定了该 api 在 Swagger 上挂载的tag
  160. /*
  161. Operation operaPost = new Operation(
  162. HttpMethod.POST,
  163. "do exec " + optGroup + "." + methodName,
  164. "",
  165. new ModelRef("ResponseCode"),
  166. optId+"UsingPOST",
  167. 0,
  168. Sets.newHashSet(tag.getName()),
  169. Sets.newHashSet(MediaType.ANY_TYPE.toString()),
  170. Sets.newHashSet(MediaType.create("application", "json").toString()),
  171. new HashSet<>(),
  172. new ArrayList<>(),
  173. new ArrayList<>(),
  174. Sets.newHashSet(v200, v401, v403, v404),
  175. "",
  176. false,
  177. new ArrayList<>()
  178. );
  179. */
  180. operations.add(operaGet);
  181. //operations.add(operaPost);
  182. String url = "/" + beanName + "?invoke=" + methodName;
  183. apis1.add(new ApiDescription(groupName,
  184. url,
  185. beanName+"." + methodName,
  186. operations, false));
  187. // 注意 position,必须是不重复的值
  188. ApiListing apiListing = new ApiListing(
  189. API_VER,
  190. basePath,
  191. "/" + beanName,
  192. new HashSet<>(),new HashSet<>(),"", new HashSet<>(), new ArrayList<>(),
  193. apis1,
  194. new HashMap<>(), beanName+"." + methodName, apis.size(), tags);
  195. // 放到api列表中
  196. apiListings.put(apiId, apiListing);
  197. // 返回 tag,tag 会显示到 Swagger Content
  198. return tags;
  199. }
  200. private String getName(String beanName) {
  201. if(StringUtils.isBlank(beanName)) {
  202. return beanName;
  203. }
  204. int i = beanName.indexOf(".");
  205. if(i > 0) {
  206. return beanName.substring(0, i);
  207. }
  208. return beanName;
  209. }
  210. @Override
  211. public void onApplicationEvent(ContextRefreshedEvent event) {
  212. // Spring 框架加载完全后,扫描 bean,获取 servlet
  213. initSwagger(event.getApplicationContext());
  214. }
  215. }