package com.primeton.dsp.datarelease.api.sql; import com.alibaba.druid.sql.SQLUtils; import com.google.common.base.Preconditions; import com.primeton.dsp.datarelease.api.model.Field; import com.primeton.dsp.datarelease.api.model.SelectField; import com.primeton.dsp.datarelease.api.model.Table; import com.primeton.dsp.datarelease.api.model.WhereCause; import lombok.NonNull; import net.sf.jsqlparser.JSQLParserException; import net.sf.jsqlparser.expression.Alias; import net.sf.jsqlparser.expression.BinaryExpression; import net.sf.jsqlparser.expression.DateValue; import net.sf.jsqlparser.expression.DoubleValue; import net.sf.jsqlparser.expression.Expression; import net.sf.jsqlparser.expression.LongValue; import net.sf.jsqlparser.expression.StringValue; import net.sf.jsqlparser.expression.TimestampValue; import net.sf.jsqlparser.expression.operators.conditional.AndExpression; import net.sf.jsqlparser.expression.operators.conditional.OrExpression; import net.sf.jsqlparser.expression.operators.relational.*; import net.sf.jsqlparser.parser.CCJSqlParserUtil; import net.sf.jsqlparser.schema.Column; import net.sf.jsqlparser.statement.select.*; import net.sf.jsqlparser.util.SelectUtils; import org.apache.commons.lang.StringUtils; import java.util.*; import java.util.stream.Collectors; /** *
 * 
 * Created by zhaopx.
 * User: zhaopx
 * Date: 2019/2/21
 * Time: 10:29
 * Vendor: primeton.com
 *
 * 
* <p>Builds {@code SELECT ... JOIN ...} statements over two or more {@link JoinTable}s with
 * JSqlParser, optionally appends WHERE conditions, and renders the result as SQL text.</p>
 *
 * @author zhaopx
 */
public abstract class SQLJoin {

    /**
     * Renders the built statement as SQL.
     *
     * @param format {@code true} to pretty-print with Druid's MySQL formatter,
     *               {@code false} for JSqlParser's single-line rendering
     * @return the SQL text
     */
    public abstract String show(boolean format);

    /**
     * Incrementally assembles a join query. The builder keeps mutable, unsynchronized state and
     * modifies a single {@link Select} in place.
     */
    public static class Builder {

        /** The SELECT being built; {@code join} and {@code where} modify it in place. */
        Select select;

        /** All selected fields, de-duplicated by upper-cased column name. */
        List<SelectField> fields;

        /** Tables taking part in the query, keyed by upper-cased table name. */
        final Map<String, JoinTable> CACHED_TABLE = new HashMap<>();

        /** Tables taking part in the query, keyed by upper-cased alias (used to detect alias clashes). */
        final Map<String, JoinTable> ALIAS_NAME_CACHED_TABLE = new HashMap<>();

        /**
         * Starts a query that joins two tables on {@code left.leftField = right.rightField}.
         *
         * <p>The selected columns of both tables are merged; a right-table column whose name
         * (case-insensitive) is already selected from the left table is dropped. If the right
         * table's alias clashes with the left one, a numeric suffix is appended to it.
         *
         * @param left       left table; becomes the FROM table
         * @param leftField  join column of the left table
         * @param right      right table
         * @param rightField join column of the right table
         * @param type       join type; {@code FULL_OUTER_JOIN} renders a FULL join
         */
        public Builder(JoinTable left, String leftField, JoinTable right, String rightField, JoinType type) {
            List<SelectField> merged = new ArrayList<>();
            Set<String> seenColumns = new HashSet<>();
            appendDistinct(left.getTableFields(), seenColumns, merged);
            appendDistinct(right.getTableFields(), seenColumns, merged);

            ALIAS_NAME_CACHED_TABLE.put(key(left.getAlias()), left);
            registerAlias(right);

            // NOTE(review): like the original array cast, this requires SelectField to implement
            // Expression (otherwise the array store fails at runtime) — confirm in the model class.
            Expression[] columns = merged.toArray(new Expression[merged.size()]);
            this.select = SelectUtils.buildSelectFromTableAndExpressions(left.getTable(), columns);
            this.fields = merged;
            addJoin(left.getJoinColumn(leftField), right, right.getJoinColumn(rightField), type);

            CACHED_TABLE.put(key(left.getTableName()), left);
            CACHED_TABLE.put(key(right.getTableName()), right);
        }

        /**
         * Joins one more table onto the current result.
         *
         * <p>Columns of {@code rightJoinTable} whose name (case-insensitive) is already selected are
         * skipped. An alias clash is resolved by appending the smallest free numeric suffix.
         *
         * @param leftJoinTable  left side of the ON clause; normally a table already in the query
         * @param leftField      join column of {@code leftJoinTable}
         * @param rightJoinTable the table being added
         * @param rightField     join column of {@code rightJoinTable}
         * @param type           join type
         * @return this builder
         */
        public Builder join(JoinTable leftJoinTable, String leftField,
                            JoinTable rightJoinTable, String rightField, JoinType type) {
            JoinTable leftTable = CACHED_TABLE.get(key(leftJoinTable.getTableName()));
            if (leftTable == null) {
                // NOTE(review): the table is only cached here, not added to FROM/JOIN, so the ON
                // clause may reference a table that is not part of the query — confirm intent.
                leftTable = leftJoinTable;
                CACHED_TABLE.put(key(leftJoinTable.getTableName()), leftJoinTable);
            }
            registerAlias(rightJoinTable);

            Set<String> seenColumns = new HashSet<>();
            for (SelectField field : fields) {
                seenColumns.add(key(field.getExpression().getColumnName()));
            }
            List<SelectField> added = appendDistinct(rightJoinTable.getTableFields(), seenColumns, fields);

            // Wrap each field as a select item the same way the constructor's columns are
            // (Expression -> SelectExpressionItem), so joined columns keep their aliases too.
            Expression[] addedExpressions = added.toArray(new Expression[added.size()]);
            SelectItem[] items = new SelectItem[addedExpressions.length];
            for (int i = 0; i < addedExpressions.length; i++) {
                items[i] = new SelectExpressionItem(addedExpressions[i]);
            }
            ((PlainSelect) select.getSelectBody()).addSelectItems(items);

            addJoin(leftTable.getJoinColumn(leftField), rightJoinTable,
                    rightJoinTable.getJoinColumn(rightField), type);
            CACHED_TABLE.put(key(rightJoinTable.getTableName()), rightJoinTable);
            return this;
        }

        /**
         * Tells whether a table (matched case-insensitively by name) is part of this query.
         *
         * @param tableName table name
         * @return {@code true} if the table has been joined
         */
        public boolean isCachedTable(@NonNull String tableName) {
            return CACHED_TABLE.get(key(tableName)) != null;
        }

        /**
         * Appends WHERE conditions. {@code join} should not be called after {@code where}.
         *
         * <p>Conditions are chained left to right onto any existing WHERE clause. When there is no
         * existing clause the first condition's {@code cond} is ignored; every other condition is
         * combined with OR when its {@code cond} is {@code "or"} (case-insensitive), otherwise AND.
         *
         * <p>NOTE(review): no parentheses are added, so when AND and OR are mixed the rendered SQL
         * presumably follows normal SQL precedence (AND before OR) rather than left to right —
         * confirm this is what callers expect.
         *
         * @param wheres conditions to add; {@code null} or empty is a no-op
         * @return this builder
         * @throws IllegalArgumentException if a condition refers to a table that is not in the
         *                                  query, lacks a required value, or has a malformed INT value
         */
        public Builder where(WhereCause... wheres) {
            if (wheres == null || wheres.length == 0) {
                return this;
            }
            PlainSelect ps = (PlainSelect) select.getSelectBody();
            Expression combined = ps.getWhere();
            for (WhereCause cause : wheres) {
                Expression expr = buildExpression(requireTable(cause.getTableName()), cause);
                if (combined == null) {
                    combined = expr;
                } else if ("or".equalsIgnoreCase(cause.getCond())) {
                    combined = new OrExpression(combined, expr);
                } else {
                    combined = new AndExpression(combined, expr);
                }
            }
            ps.setWhere(combined);
            return this;
        }

        /**
         * Appends WHERE conditions; see {@link #where(WhereCause...)}.
         *
         * @param wheres conditions to add; {@code null} or empty is a no-op
         * @return this builder
         */
        public Builder where(Collection<? extends WhereCause> wheres) {
            if (wheres == null) {
                return this;
            }
            return where(wheres.toArray(new WhereCause[wheres.size()]));
        }

        /**
         * Creates the renderable result. The returned object reads this builder's {@link Select}
         * each time {@code show} is called, so later {@code join}/{@code where} calls are reflected.
         *
         * @return the SQL renderer
         */
        public SQLJoin build() {
            return new SQLJoin() {
                @Override
                public String show(boolean format) {
                    if (!format) {
                        return select.toString();
                    }
                    String sql = select.toString();
                    sql = SQLUtils.formatMySql(sql);
                    return sql;
                }
            };
        }

        /** Normalizes a table name or alias into a cache key. */
        private static String key(String name) {
            return StringUtils.upperCase(name);
        }

        /** Adds each candidate whose upper-cased column name is unseen to {@code target}; returns the ones added. */
        private static List<SelectField> appendDistinct(List<SelectField> candidates, Set<String> seen,
                                                        List<SelectField> target) {
            List<SelectField> added = new ArrayList<>(candidates.size());
            for (SelectField col : candidates) {
                if (seen.add(key(col.getExpression().getColumnName()))) {
                    target.add(col);
                    added.add(col);
                }
            }
            return added;
        }

        /** Registers a table's alias, first renaming it with the smallest free numeric suffix if it clashes. */
        private void registerAlias(JoinTable table) {
            String alias = table.getAlias();
            if (StringUtils.isNotBlank(alias) && ALIAS_NAME_CACHED_TABLE.containsKey(key(alias))) {
                // Start at the current table count (gives "1" for the first clash, as before).
                int suffix = ALIAS_NAME_CACHED_TABLE.size();
                while (ALIAS_NAME_CACHED_TABLE.containsKey(key(alias + suffix))) {
                    suffix++;
                }
                table.setAlias(alias + suffix);
            }
            ALIAS_NAME_CACHED_TABLE.put(key(table.getAlias()), table);
        }

        /** Adds {@code JOIN right ON leftColumn = rightColumn} with the given join type. */
        private void addJoin(Column leftColumn, JoinTable right, Column rightColumn, JoinType type) {
            Join join = SelectUtils.addJoin(select, right.getTable(), null);
            EqualsTo on = new EqualsTo();
            on.setLeftExpression(leftColumn);
            on.setRightExpression(rightColumn);
            join.setOnExpression(on);
            applyJoinType(join, type);
        }

        /** Maps a {@link JoinType} onto JSqlParser's join flags; anything else becomes FULL. */
        private static void applyJoinType(Join join, JoinType type) {
            switch (type) {
                case INNER_JOIN:
                    join.setInner(true);
                    break;
                case LEFT_JOIN:
                    join.setLeft(true);
                    break;
                case RIGHT_JOIN:
                    join.setRight(true);
                    break;
                default:
                    join.setFull(true);
            }
        }

        /** Looks up a joined table by name, failing with a descriptive message if it is absent. */
        private JoinTable requireTable(String tableName) {
            JoinTable table = CACHED_TABLE.get(key(tableName));
            if (table == null) {
                throw new IllegalArgumentException("Table is not part of this join: " + tableName);
            }
            return table;
        }

        /**
         * Converts one front-end condition into a SQL expression.
         *
         * <p>A comma-separated value with more than one element always becomes IN (or NOT IN when
         * the operator is {@code notin}). Otherwise the operator selects the expression:
         * {@code in}, {@code notin}, {@code l} (LIKE '%value%'), {@code > >= < <= <> !=};
         * {@code =} and unknown operators produce an equality. The {@code cause} is not modified.
         *
         * @param table table owning {@code cause.getFieldName()}
         * @param cause the condition
         * @return the condition expression
         */
        private Expression buildExpression(JoinTable table, WhereCause cause) {
            String opera = cause.getOpera() == null ? "" : cause.getOpera();
            String value = cause.getValue();
            boolean notIn = "notin".equalsIgnoreCase(opera);
            boolean in = notIn || "in".equalsIgnoreCase(opera)
                    // Multiple values can only be expressed as IN.
                    || (value != null && value.split(",").length > 1);

            Column column = new Column(table.getTable(), cause.getFieldName());
            if (in) {
                return buildInExpression(column, cause, notIn);
            }
            if ("l".equalsIgnoreCase(opera)) {
                Preconditions.checkNotNull(value, " like value must not be blank");
                // A literal '%' in the input is escaped so it matches itself (MySQL's default LIKE escape is '\').
                String likeValue = value.replace("%", "\\%");
                return compare(new LikeExpression(), column, quoted("%" + likeValue + "%"));
            }

            Expression valueExpr = buildValueExpression(cause);
            switch (opera) {
                case ">":
                    return compare(new GreaterThan(), column, valueExpr);
                case ">=":
                    return compare(new GreaterThanEquals(), column, valueExpr);
                case "<":
                    return compare(new MinorThan(), column, valueExpr);
                case "<=":
                    return compare(new MinorThanEquals(), column, valueExpr);
                case "<>":
                case "!=":
                    return compare(new NotEqualsTo(), column, valueExpr);
                default:
                    // "=" and any unrecognized operator
                    return compare(new EqualsTo(), column, valueExpr);
            }
        }

        /** Builds {@code column [NOT] IN ('v1', 'v2', ...)} from the comma-separated value (elements are string literals). */
        private static InExpression buildInExpression(Column column, WhereCause cause, boolean not) {
            List<Expression> inVals = new ArrayList<>();
            for (String val : requireValue(cause).split(",")) {
                inVals.add(quoted(val));
            }
            MultiExpressionList list = new MultiExpressionList();
            list.addExpressionList(new ExpressionList(inVals));

            InExpression inExpression = new InExpression();
            inExpression.setLeftExpression(column);
            inExpression.setNot(not);
            inExpression.setRightItemsList(list);
            return inExpression;
        }

        /**
         * Builds the right-hand side of a scalar comparison: a column of another joined table when
         * both {@code toTableName} and {@code toFieldName} are set (e.g. {@code a.AGE > b.AGE}),
         * otherwise a constant typed by {@code cause.getType()} (INT, DOUBLE, DATE, DATETIME,
         * anything else is a string literal).
         */
        private Expression buildValueExpression(WhereCause cause) {
            if (StringUtils.isNotBlank(cause.getToTableName()) && StringUtils.isNotBlank(cause.getToFieldName())) {
                return new Column(requireTable(cause.getToTableName()).getTable(), cause.getToFieldName());
            }
            String value = requireValue(cause);
            String type = cause.getType();
            if ("INT".equalsIgnoreCase(type)) {
                // Validate first: depending on the JSqlParser version LongValue may keep the raw
                // text and print it unquoted, which would let arbitrary SQL through.
                return new LongValue(String.valueOf(Long.parseLong(value.trim())));
            } else if ("DOUBLE".equalsIgnoreCase(type)) {
                return new DoubleValue(value);
            } else if ("DATE".equalsIgnoreCase(type)) {
                // expected format: yyyy-[M]M-[d]d
                return new DateValue(value);
            } else if ("DATETIME".equalsIgnoreCase(type)) {
                // expected format: yyyy-[M]M-[d]d HH:mm:ss
                return new TimestampValue(value);
            }
            return quoted(value);
        }

        /** Returns the condition's value, failing with a descriptive message when it is missing. */
        private static String requireValue(WhereCause cause) {
            if (cause.getValue() == null) {
                throw new IllegalArgumentException("No value given for condition on field " + cause.getFieldName());
            }
            return cause.getValue();
        }

        /**
         * Wraps raw text as a SQL string literal, doubling embedded single quotes so user input
         * cannot terminate the literal (StringValue appears to print its value verbatim between
         * quotes — verify for the JSqlParser version in use).
         * NOTE(review): MySQL also treats backslash as an escape inside literals unless
         * NO_BACKSLASH_ESCAPES is set; confirm the target dialect.
         */
        private static StringValue quoted(String raw) {
            return new StringValue(raw.replace("'", "''"));
        }

        /** Fills in both operands of a binary operator and returns it. */
        private static Expression compare(BinaryExpression operator, Expression left, Expression right) {
            operator.setLeftExpression(left);
            operator.setRightExpression(right);
            return operator;
        }
    }

    /**
     * A model {@link Table} paired with its JSqlParser counterpart, whose alias is kept in sync.
     *
     * @author zhaopx
     */
    public static class JoinTable extends Table {

        /** JSqlParser view of this table. */
        final net.sf.jsqlparser.schema.Table table;

        /**
         * Wraps an existing model table.
         *
         * @param table model table whose name, alias and fields are copied
         */
        public JoinTable(Table table) {
            this(table.getTableName(), table.getAlias(), table.getFields());
        }

        /**
         * @param tableName table name
         * @param alias     alias used in the join; blank means no alias
         * @param fields    fields of this table to select after joining
         */
        public JoinTable(String tableName, String alias, List<Field> fields) {
            super(tableName, alias, fields);
            table = new net.sf.jsqlparser.schema.Table(tableName);
            table.setAlias(toSqlAlias(alias));
        }

        public net.sf.jsqlparser.schema.Table getTable() {
            return table;
        }

        /**
         * Builds one select field per model field, qualified with this table and aliased with
         * {@code AS} when the model field has a non-blank alias.
         *
         * @return the select fields, or an empty list when the table has no fields
         */
        public List<SelectField> getTableFields() {
            if (getFields() == null) {
                return Collections.emptyList();
            }
            List<SelectField> cols = new ArrayList<>(getFields().size());
            for (Field field : getFields()) {
                SelectField selectField = new SelectField(new Column(table, field.getFieldName()));
                // Skip blank aliases; a null alias would otherwise be rendered literally.
                if (StringUtils.isNotBlank(field.getAlias())) {
                    selectField.setAlias(new Alias(field.getAlias(), true));
                }
                cols.add(selectField);
            }
            return cols;
        }

        /**
         * @param joinField column name
         * @return the column qualified with this table (its alias when set)
         */
        public Column getJoinColumn(String joinField) {
            return new Column(table, joinField);
        }

        @Override
        public void setAlias(String alias) {
            super.setAlias(alias);
            table.setAlias(toSqlAlias(alias));
        }

        /** Converts a model alias to a JSqlParser table alias; blank means "no alias". */
        private static Alias toSqlAlias(String alias) {
            return StringUtils.isBlank(alias) ? null : new Alias(alias, false);
        }
    }

    /**
     * Join types.
     *
     * @author zhaopx
     */
    public static enum JoinType {
        /** Inner join. */
        INNER_JOIN,
        /** Left outer join. */
        LEFT_JOIN,
        /** Right outer join. */
        RIGHT_JOIN,
        /** Full outer join. */
        FULL_OUTER_JOIN;

        /**
         * Parses a join type by constant name, case-insensitively.
         *
         * @param joinType the name, e.g. {@code "left_join"}
         * @return the matching type, or {@link #INNER_JOIN} when nothing matches (including null)
         */
        public static JoinType joinType(String joinType) {
            for (JoinType candidate : values()) {
                if (candidate.name().equalsIgnoreCase(joinType)) {
                    return candidate;
                }
            }
            return INNER_JOIN;
        }
    }
}