package com.primeton.dsp.datarelease.api.sql;
import com.alibaba.druid.sql.SQLUtils;
import com.google.common.base.Preconditions;
import com.primeton.dsp.datarelease.api.model.Field;
import com.primeton.dsp.datarelease.api.model.SelectField;
import com.primeton.dsp.datarelease.api.model.Table;
import com.primeton.dsp.datarelease.api.model.WhereCause;
import lombok.NonNull;
import net.sf.jsqlparser.JSQLParserException;
import net.sf.jsqlparser.expression.Alias;
import net.sf.jsqlparser.expression.BinaryExpression;
import net.sf.jsqlparser.expression.DateValue;
import net.sf.jsqlparser.expression.DoubleValue;
import net.sf.jsqlparser.expression.Expression;
import net.sf.jsqlparser.expression.LongValue;
import net.sf.jsqlparser.expression.StringValue;
import net.sf.jsqlparser.expression.TimestampValue;
import net.sf.jsqlparser.expression.operators.conditional.AndExpression;
import net.sf.jsqlparser.expression.operators.conditional.OrExpression;
import net.sf.jsqlparser.expression.operators.relational.*;
import net.sf.jsqlparser.parser.CCJSqlParserUtil;
import net.sf.jsqlparser.schema.Column;
import net.sf.jsqlparser.statement.select.*;
import net.sf.jsqlparser.util.SelectUtils;
import org.apache.commons.lang.StringUtils;
import java.util.*;
import java.util.stream.Collectors;
/**
 * Builds multi-table {@code SELECT ... JOIN ...} statements with JSqlParser, optionally followed
 * by {@code WHERE} conditions, and renders them as SQL text.
 *
 * Created by zhaopx.
 * User: zhaopx
 * Date: 2019/2/21
 * Time: 10:29
 * Vendor: primeton.com
 *
 * @author zhaopx
 */
public abstract class SQLJoin {

    /**
     * Renders the SQL statement.
     *
     * @param format {@code true} to pretty-print the SQL (the implementation returned by
     *               {@link Builder#build()} uses Druid's {@code SQLUtils.formatMySql});
     *               {@code false} for the plain {@code toString()} rendering of the statement
     * @return the SQL text
     */
    public abstract String show(boolean format);

    /**
     * Assembles a join statement step by step.
     *
     * <p>The constructor joins the first two tables, {@link #join} adds further tables,
     * {@link #where} appends conditions, and {@link #build} returns a renderable {@link SQLJoin}.
     * The builder holds mutable, unsynchronized state.
     */
    public static class Builder {

        /** The statement being assembled. */
        Select select;

        /** Selected columns of all joined tables; each column name (case-insensitive) appears once. */
        List<SelectField> fields;

        /** Joined tables keyed by upper-cased table name; used to resolve WHERE conditions. */
        final Map<String, JoinTable> CACHED_TABLE = new HashMap<>();

        /** Joined tables keyed by upper-cased alias; used to keep table aliases unique. */
        final Map<String, JoinTable> ALIAS_NAME_CACHED_TABLE = new HashMap<>();

        /**
         * Joins two tables:
         * {@code SELECT <columns> FROM left <type> JOIN right ON left.leftField = right.rightField}.
         *
         * @param left       left (FROM) table
         * @param leftField  join column of the left table
         * @param right      right table; its alias gets the suffix "1" when it clashes with the
         *                   left table's alias (case-insensitive)
         * @param rightField join column of the right table
         * @param type       join type; anything other than INNER/LEFT/RIGHT becomes a FULL join
         */
        public Builder(JoinTable left, String leftField, JoinTable right, String rightField, JoinType type) {
            // Merge the columns of both tables; a column name already selected is skipped.
            List<SelectField> fields = new ArrayList<>();
            Set<String> seenColumns = new HashSet<>();
            addDistinctFields(left.getTableFields(), seenColumns, fields);
            addDistinctFields(right.getTableFields(), seenColumns, fields);

            ALIAS_NAME_CACHED_TABLE.put(StringUtils.upperCase(left.getAlias()), left);
            if (ALIAS_NAME_CACHED_TABLE.containsKey(StringUtils.upperCase(right.getAlias()))) {
                // Same alias as the left table: rename the right one.
                right.setAlias(uniqueAlias(right.getAlias(), 1));
            }
            ALIAS_NAME_CACHED_TABLE.put(StringUtils.upperCase(right.getAlias()), right);

            Expression[] columns = fields.toArray(new Expression[0]);
            Select select = SelectUtils.buildSelectFromTableAndExpressions(left.getTable(), columns);
            appendJoin(select, left, leftField, right, rightField, type);

            CACHED_TABLE.put(StringUtils.upperCase(left.getTableName()), left);
            CACHED_TABLE.put(StringUtils.upperCase(right.getTableName()), right);
            this.select = select;
            this.fields = fields;
        }

        /**
         * Joins one more table onto the current result.
         *
         * <p>Produces {@code ... <type> JOIN rightJoinTable ON left.leftField = rightJoinTable.rightField}.
         * Columns of {@code rightJoinTable} whose names are already selected are skipped.
         *
         * @param leftJoinTable  table on the left side of the ON condition; if a table with the same
         *                       name was joined before, that cached instance is used instead
         * @param leftField      join column of the left table
         * @param rightJoinTable table to join; its alias gets a numeric suffix if already taken
         * @param rightField     join column of the right table
         * @param type           join type; anything other than INNER/LEFT/RIGHT becomes a FULL join
         * @return this builder
         */
        public Builder join(JoinTable leftJoinTable, String leftField, JoinTable rightJoinTable, String rightField, JoinType type) {
            String leftKey = StringUtils.upperCase(leftJoinTable.getTableName());
            JoinTable leftTable = CACHED_TABLE.get(leftKey);
            if (leftTable == null) {
                // NOTE(review): an unknown left table is only registered in the cache here; it is
                // not added to the FROM/JOIN clause, so the ON condition may reference a table that
                // does not appear in the statement.
                leftTable = leftJoinTable;
                CACHED_TABLE.put(leftKey, leftJoinTable);
            }

            if (ALIAS_NAME_CACHED_TABLE.containsKey(StringUtils.upperCase(rightJoinTable.getAlias()))) {
                rightJoinTable.setAlias(uniqueAlias(rightJoinTable.getAlias(), ALIAS_NAME_CACHED_TABLE.size()));
            }
            ALIAS_NAME_CACHED_TABLE.put(StringUtils.upperCase(rightJoinTable.getAlias()), rightJoinTable);

            Set<String> seenColumns = new HashSet<>();
            for (SelectField field : fields) {
                seenColumns.add(columnKey(field));
            }
            List<SelectField> newFields = new ArrayList<>();
            addDistinctFields(rightJoinTable.getTableFields(), seenColumns, newFields);
            fields.addAll(newFields);

            // Select the SelectField objects themselves, as the constructor does, rather than their
            // bare columns, so that the field aliases are kept for columns added here as well.
            Expression[] newColumns = newFields.toArray(new Expression[0]);
            SelectItem[] newItems = new SelectItem[newColumns.length];
            for (int i = 0; i < newColumns.length; i++) {
                newItems[i] = new SelectExpressionItem(newColumns[i]);
            }
            ((PlainSelect) select.getSelectBody()).addSelectItems(newItems);

            appendJoin(select, leftTable, leftField, rightJoinTable, rightField, type);
            CACHED_TABLE.put(StringUtils.upperCase(rightJoinTable.getTableName()), rightJoinTable);
            return this;
        }

        /**
         * Tells whether a table (matched case-insensitively by name) takes part in the join.
         *
         * @param tableName table name
         * @return {@code true} if the table has been joined
         */
        public boolean isCachedTable(@NonNull String tableName) {
            return CACHED_TABLE.containsKey(StringUtils.upperCase(tableName));
        }

        /**
         * Appends WHERE conditions. Once {@code where} has been called, {@code join} should no
         * longer be called.
         *
         * <p>Conditions are chained left to right. Each one is connected to everything before it
         * with OR when its {@code cond} is "or" (case-insensitive), otherwise with AND. When the
         * statement has no WHERE yet, the first condition starts it and its connector is ignored.
         *
         * <p>NOTE(review): the chain is built without explicit parentheses. If the renderer does not
         * parenthesise nested AND/OR expressions, SQL precedence (AND before OR) applies to the
         * printed text. Confirm this is the intended grouping.
         *
         * @param wheres conditions to add; {@code null} or empty adds nothing
         * @return this builder
         * @throws NullPointerException if a condition refers to a table that is not part of the join
         */
        public Builder where(WhereCause... wheres) {
            if (wheres == null || wheres.length == 0) {
                return this;
            }
            PlainSelect plainSelect = (PlainSelect) select.getSelectBody();
            Expression condition = plainSelect.getWhere();
            int next = 0;
            if (condition == null) {
                condition = buildExpression(requireTable(wheres[0].getTableName()), wheres[0]);
                next = 1;
            }
            for (int i = next; i < wheres.length; i++) {
                condition = connect(condition, wheres[i]);
            }
            plainSelect.setWhere(condition);
            return this;
        }

        /**
         * Appends WHERE conditions; see {@link #where(WhereCause...)}.
         *
         * @param wheres conditions to add; {@code null} or empty adds nothing
         * @return this builder
         */
        public Builder where(Collection<? extends WhereCause> wheres) {
            if (wheres == null) {
                return this;
            }
            return where(wheres.toArray(new WhereCause[0]));
        }

        /**
         * Creates the renderable result. The returned object reads this builder's statement when
         * {@link SQLJoin#show(boolean)} is called.
         *
         * @return the SQL view of the statement built so far
         */
        public SQLJoin build() {
            return new SQLJoin() {
                @Override
                public String show(boolean format) {
                    String sql = select.toString();
                    return format ? SQLUtils.formatMySql(sql) : sql;
                }
            };
        }

        /** Connects {@code cause} to the expression built so far with OR ("or") or AND (anything else). */
        private Expression connect(Expression left, WhereCause cause) {
            Expression right = buildExpression(requireTable(cause.getTableName()), cause);
            if ("or".equalsIgnoreCase(cause.getCond())) {
                return new OrExpression(left, right);
            }
            return new AndExpression(left, right);
        }

        /**
         * Converts a front-end condition into a SQL expression.
         *
         * <p>A comma-separated value with more than one element must be an IN / NOT IN condition.
         * Any other operator is treated as "in" for it. The caller's {@code WhereCause} is not modified.
         *
         * @param table table the condition's column belongs to
         * @param cause condition
         * @return the SQL expression
         */
        private Expression buildExpression(JoinTable table, WhereCause cause) {
            String opera = cause.getOpera();
            String value = cause.getValue();
            boolean multiValue = value != null && value.split(",").length > 1;
            if (multiValue && !"in".equalsIgnoreCase(opera) && !"notin".equalsIgnoreCase(opera)) {
                opera = "in";
            }
            return buildCondition(table, cause, opera);
        }

        /**
         * Builds one condition for the given operator.
         *
         * <p>Supported operators are: in, notin and l (LIKE), compared case-insensitively; and
         * {@code > >= < <= <> !=}, compared exactly. {@code =} and any unrecognised operator
         * produce an equality test. The right-hand value is only built for comparison operators,
         * so IN/LIKE values are never parsed as a single typed literal.
         */
        private Expression buildCondition(JoinTable table, WhereCause cause, String opera) {
            Column column = new Column(table.getTable(), cause.getFieldName());
            if ("in".equalsIgnoreCase(opera)) {
                return buildIn(column, cause.getValue(), false);
            }
            if ("notin".equalsIgnoreCase(opera)) {
                return buildIn(column, cause.getValue(), true);
            }
            if ("l".equalsIgnoreCase(opera)) {
                return buildLike(column, cause.getValue());
            }
            BinaryExpression comparison;
            if (">".equals(opera)) {
                comparison = new GreaterThan();
            } else if (">=".equals(opera)) {
                comparison = new GreaterThanEquals();
            } else if ("<".equals(opera)) {
                comparison = new MinorThan();
            } else if ("<=".equals(opera)) {
                comparison = new MinorThanEquals();
            } else if ("<>".equals(opera) || "!=".equals(opera)) {
                comparison = new NotEqualsTo();
            } else {
                comparison = new EqualsTo();
            }
            comparison.setLeftExpression(column);
            comparison.setRightExpression(buildValue(cause));
            return comparison;
        }

        /**
         * Builds the right-hand side of a comparison.
         *
         * <p>If both a second table and a field name are given, the result is that table's column
         * (e.g. {@code a.AGE > b.AGE}). Otherwise it is a constant typed by {@code cause.getType()}:
         * INT, DOUBLE, DATE ({@code yyyy-[M]M-[d]d}), DATETIME ({@code yyyy-[M]M-[d]d HH:mm:ss}),
         * or a quoted string for anything else.
         */
        private Expression buildValue(WhereCause cause) {
            if (StringUtils.isNotBlank(cause.getToTableName()) && StringUtils.isNotBlank(cause.getToFieldName())) {
                JoinTable otherTable = requireTable(cause.getToTableName());
                return new Column(otherTable.getTable(), cause.getToFieldName());
            }
            String type = cause.getType();
            String value = cause.getValue();
            if ("INT".equalsIgnoreCase(type)) {
                return new LongValue(value);
            }
            if ("DOUBLE".equalsIgnoreCase(type)) {
                return new DoubleValue(value);
            }
            if ("DATE".equalsIgnoreCase(type)) {
                return new DateValue(value);
            }
            if ("DATETIME".equalsIgnoreCase(type)) {
                return new TimestampValue(value);
            }
            return stringLiteral(value);
        }

        /** Builds {@code column [NOT] IN ('v1', 'v2', ...)} from a comma-separated value. */
        private static Expression buildIn(Column column, String value, boolean negate) {
            Preconditions.checkNotNull(value, "in value must not be null");
            List<Expression> inValues = new ArrayList<>();
            for (String element : value.split(",")) {
                inValues.add(stringLiteral(element));
            }
            MultiExpressionList list = new MultiExpressionList();
            list.addExpressionList(new ExpressionList(inValues));
            InExpression inExpression = new InExpression();
            inExpression.setLeftExpression(column);
            inExpression.setNot(negate);
            inExpression.setRightItemsList(list);
            return inExpression;
        }

        /**
         * Builds {@code column LIKE '%value%'}. A '%' typed by the user matches a literal percent
         * sign, so it is escaped as {@code \%}.
         */
        private static Expression buildLike(Column column, String value) {
            Preconditions.checkNotNull(value, " like value must not be blank");
            // String.replace is literal. With replaceAll, the replacement "\\%" (i.e. \%) is a regex
            // replacement in which the backslash only quotes '%', so nothing would be escaped.
            String likeValue = value.replace("%", "\\%");
            LikeExpression likeExpression = new LikeExpression();
            likeExpression.setLeftExpression(column);
            likeExpression.setRightExpression(stringLiteral("%" + likeValue + "%"));
            return likeExpression;
        }

        /**
         * Creates a SQL string literal whose embedded single quotes are doubled, so the value
         * cannot terminate the literal early.
         *
         * <p>The text is handed to {@code StringValue} already wrapped in quotes. NOTE(review): the
         * JSqlParser constructor is assumed to strip one surrounding pair of quotes; confirm against
         * the version in use. MySQL without NO_BACKSLASH_ESCAPES also treats backslash as an escape
         * character; bind parameters would be the robust fix.
         */
        private static StringValue stringLiteral(String raw) {
            Preconditions.checkNotNull(raw, "condition value must not be null");
            return new StringValue("'" + raw.replace("'", "''") + "'");
        }

        /** Looks up a joined table by name (case-insensitive); fails with a descriptive NPE if it is absent. */
        private JoinTable requireTable(String tableName) {
            JoinTable table = CACHED_TABLE.get(StringUtils.upperCase(tableName));
            return Preconditions.checkNotNull(table, "table %s is not part of this join", tableName);
        }

        /** Returns {@code alias + n} for the smallest {@code n >= firstSuffix} that is not yet in use. */
        private String uniqueAlias(String alias, int firstSuffix) {
            int suffix = firstSuffix;
            while (ALIAS_NAME_CACHED_TABLE.containsKey(StringUtils.upperCase(alias + suffix))) {
                suffix++;
            }
            return alias + suffix;
        }

        /** Adds {@code ... <type> JOIN right ON left.leftField = right.rightField} to the statement. */
        private static void appendJoin(Select select, JoinTable left, String leftField,
                                       JoinTable right, String rightField, JoinType type) {
            Join join = SelectUtils.addJoin(select, right.getTable(), null);
            EqualsTo on = new EqualsTo();
            on.setLeftExpression(left.getJoinColumn(leftField));
            on.setRightExpression(right.getJoinColumn(rightField));
            join.setOnExpression(on);
            switch (type) {
                case INNER_JOIN:
                    join.setInner(true);
                    break;
                case LEFT_JOIN:
                    join.setLeft(true);
                    break;
                case RIGHT_JOIN:
                    join.setRight(true);
                    break;
                default:
                    join.setFull(true);
            }
        }

        /** Copies into {@code out} each candidate whose column name has not been seen yet, recording it as seen. */
        private static void addDistinctFields(List<SelectField> candidates, Set<String> seenColumns,
                                              List<SelectField> out) {
            for (SelectField candidate : candidates) {
                if (seenColumns.add(columnKey(candidate))) {
                    out.add(candidate);
                }
            }
        }

        /** Case-insensitive key of a selected column: its upper-cased column name. */
        private static String columnKey(SelectField field) {
            return StringUtils.upperCase(field.getExpression().getColumnName());
        }
    }

    /**
     * A table taking part in a join, paired with the JSqlParser table object used to render it.
     *
     * @author zhaopx
     */
    public static class JoinTable extends Table {

        /** JSqlParser representation of this table; shared by every column created from it. */
        final net.sf.jsqlparser.schema.Table table;

        /**
         * Creates a join table from a table model.
         *
         * @param table table model providing name, alias and fields
         */
        public JoinTable(Table table) {
            this(table.getTableName(), table.getAlias(), table.getFields());
        }

        /**
         * Creates a join table.
         *
         * @param tableName table name
         * @param alias     alias used in the join; a blank alias is not rendered, so columns are
         *                  then qualified with the table name
         * @param fields    fields of this table to select after joining
         */
        public JoinTable(String tableName, String alias, List<Field> fields) {
            super(tableName, alias, fields);
            table = new net.sf.jsqlparser.schema.Table(tableName);
            table.setAlias(toSqlAlias(alias));
        }

        /** @return the JSqlParser table object */
        public net.sf.jsqlparser.schema.Table getTable() {
            return table;
        }

        /**
         * Builds a fresh select field for each field of this table. Each one carries an
         * {@code AS} alias taken from the field.
         *
         * @return the select fields; empty when the table has no fields
         */
        public List<SelectField> getTableFields() {
            if (getFields() == null) {
                return Collections.emptyList();
            }
            List<SelectField> cols = new ArrayList<>(getFields().size());
            for (Field col : getFields()) {
                Column column = new Column(table, col.getFieldName());
                SelectField exprs = new SelectField(column);
                exprs.setAlias(new Alias(col.getAlias(), true));
                cols.add(exprs);
            }
            return cols;
        }

        /**
         * @param joinField column name
         * @return a column of this table, for use in an ON condition
         */
        public Column getJoinColumn(String joinField) {
            return new Column(table, joinField);
        }

        /** Changes the alias on both the model and the JSqlParser table object. */
        @Override
        public void setAlias(String alias) {
            super.setAlias(alias);
            table.setAlias(toSqlAlias(alias));
        }

        /** Returns a JSqlParser alias (rendered without AS), or {@code null} for a blank alias so none is printed. */
        private static Alias toSqlAlias(String alias) {
            return StringUtils.isBlank(alias) ? null : new Alias(alias, false);
        }
    }

    /**
     * Join types.
     *
     * @author zhaopx
     */
    public enum JoinType {
        /** Inner join. */
        INNER_JOIN,
        /** Left outer join. */
        LEFT_JOIN,
        /** Right outer join. */
        RIGHT_JOIN,
        /** Full outer join. */
        FULL_OUTER_JOIN;

        /**
         * Parses a join type by constant name, ignoring case.
         *
         * @param joinType name such as "left_join"; may be {@code null}
         * @return the matching type, or {@link #INNER_JOIN} when nothing matches
         */
        public static JoinType joinType(String joinType) {
            for (JoinType candidate : values()) {
                if (candidate.name().equalsIgnoreCase(joinType)) {
                    return candidate;
                }
            }
            return INNER_JOIN;
        }
    }
}