package com.primeton.dsp.dataservice.utils; import lombok.NonNull; import lombok.extern.slf4j.Slf4j; import org.apache.commons.dbcp2.BasicDataSource; import org.apache.commons.dbutils.DbUtils; import org.apache.commons.dbutils.QueryRunner; import org.apache.commons.dbutils.handlers.MapListHandler; import org.apache.commons.io.IOUtils; import org.apache.commons.lang.StringUtils; import javax.sql.DataSource; import java.io.Closeable; import java.io.IOException; import java.sql.*; import java.sql.Date; import java.text.ParseException; import java.text.SimpleDateFormat; import java.util.*; import java.util.regex.Matcher; import java.util.regex.Pattern; /** * * 数据库执行类。主要封装了查询,更新,批量更新等一些执行SQL的方法。 * *
 *
 * Created by zhaopx.
 * User: zhaopx
 * Date: 2020-03-26
 * Time: 10:20
 *
 * 
* Database execution helper: wraps query, update and batch-update SQL execution on top of a
 * {@link DataSource}.
 *
 * @author zhaopx
 */
@Slf4j
public class SQLRunner implements Closeable {

    /** Matches {@code ${field}} placeholders; group 1 is the field name. Compiled once and reused. */
    private static final Pattern PLACEHOLDER_PATTERN = Pattern.compile("\\$\\{(\\w+)\\}");

    /**
     * Matches the identifier that follows a TABLE / JOIN / FROM keyword (case-insensitive). The
     * leading word boundary stops matches inside identifiers such as {@code datefrom x}.
     */
    private static final Pattern TABLE_PATTERN =
            Pattern.compile("\\b(TABLE|JOIN|FROM)\\s+(\\w+)", Pattern.CASE_INSENSITIVE);

    /** Data source that every method obtains its connections from. */
    protected final DataSource ds;

    /**
     * Creates a runner from the {@code jdbc.driverClassName}, {@code jdbc.url}, {@code jdbc.user}
     * and {@code jdbc.password} properties.
     *
     * @param config JDBC connection properties
     */
    public SQLRunner(Properties config) {
        this(config.getProperty("jdbc.driverClassName"),
                config.getProperty("jdbc.url"),
                config.getProperty("jdbc.user"),
                config.getProperty("jdbc.password"));
    }

    /**
     * Creates a runner backed by a private DBCP2 pool (max 2 connections, min 1 idle, 1 initial).
     * The pool is closed by {@link #close()}.
     *
     * @param driverClass JDBC driver class name
     * @param url         JDBC URL
     * @param user        database user
     * @param password    database password
     */
    public SQLRunner(String driverClass, String url, String user, String password) {
        BasicDataSource dataSource = new BasicDataSource();
        dataSource.setMaxTotal(2);
        dataSource.setMinIdle(1);
        dataSource.setInitialSize(1);
        dataSource.setDriverClassName(driverClass);
        dataSource.setUrl(url);
        dataSource.setUsername(user);
        dataSource.setPassword(password);
        this.ds = dataSource;
    }

    /**
     * Creates a runner on an existing data source. {@link #close()} only closes it when it is a
     * {@link BasicDataSource}.
     *
     * @param ds data source to use; must not be null
     */
    public SQLRunner(@NonNull DataSource ds) {
        this.ds = ds;
    }

    /**
     * Executes a statement without parameters (for example INSERT/UPDATE/DELETE or DDL).
     *
     * @param sql SQL to execute
     * @return the row count returned by {@link Statement#executeUpdate(String)}
     * @throws SQLException if {@code sql} is null, the data source returns a null connection,
     *                      or execution fails
     */
    public int update(String sql) throws SQLException {
        if (sql == null) {
            throw new SQLException("Null SQL statement");
        }
        Connection conn = ds.getConnection();
        if (conn == null) {
            throw new SQLException("Null connection");
        }
        Statement statement = null;
        try {
            statement = conn.createStatement();
            return statement.executeUpdate(sql);
        } finally {
            close(conn, statement, null);
        }
    }

    /**
     * Executes a parameterized update through {@link QueryRunner}.
     *
     * @param sql    SQL with {@code ?} placeholders
     * @param params values bound to the placeholders
     * @return the row count returned by {@link QueryRunner#update}
     * @throws SQLException if execution fails
     */
    public int update(String sql, Object[] params) throws SQLException {
        return new QueryRunner(this.ds).update(sql, params);
    }

    /**
     * Batch update.
     *
     * @param sql    SQL whose values are written as {@code ?} placeholders
     * @param params one parameter array per batch entry
     * @return the per-entry results returned by {@link QueryRunner#batch}
     * @throws SQLException if execution fails
     */
    public int[] updateBatch(String sql, List<Object[]> params) throws SQLException {
        QueryRunner runner = new QueryRunner(this.ds);
        return runner.batch(sql, params.toArray(new Object[params.size()][]));
    }

    /**
     * Runs a query with parameters.
     *
     * @param sql    SQL with {@code ?} placeholders
     * @param params values bound to the placeholders; null or empty runs {@link #query(String)}
     * @return the rows and column types, with the executed SQL attached
     * @throws SQLException if execution fails
     */
    public QueryResult query(String sql, Object[] params) throws SQLException {
        if (params == null || params.length == 0) {
            // Original note: QueryRunner reports an error when given no parameters,
            // so parameterless SQL goes through the plain-Statement path instead.
            return query(sql);
        }
        QueryRunner runner = new QueryRunner(this.ds);
        RSTypeHandler typeHandler = new RSTypeHandler();
        List<Map<String, Object>> list = runner.query(sql, typeHandler, params);
        QueryResult queryResult = new QueryResult(typeHandler.getTypes(), list);
        queryResult.setSql(sql);
        return queryResult;
    }

    /**
     * Runs a query without parameters.
     *
     * <p>If the driver reports that the statement did not generate a result set (its message
     * contains "not generate a result set"), a warning is logged and an empty result is returned.
     *
     * @param sql SQL to execute
     * @return the rows and column types, with the executed SQL attached
     * @throws SQLException if {@code sql} is null, the data source returns a null connection,
     *                      or execution fails for any other reason
     */
    public QueryResult query(String sql) throws SQLException {
        // Validate before borrowing a connection, consistent with update(String).
        if (sql == null) {
            throw new SQLException("Null SQL statement");
        }
        Connection conn = ds.getConnection();
        if (conn == null) {
            throw new SQLException("Null connection");
        }
        Statement statement = null;
        ResultSet rs = null;
        try {
            statement = conn.createStatement();
            rs = statement.executeQuery(sql);
            RSTypeHandler typeHandler = new RSTypeHandler();
            List<Map<String, Object>> list = typeHandler.handle(rs);
            QueryResult queryResult = new QueryResult(typeHandler.getTypes(), list);
            queryResult.setSql(sql);
            return queryResult;
        } catch (SQLException e) {
            // getMessage() may be null; guard it so the original exception is not masked by an NPE.
            String message = e.getMessage();
            if (message != null && message.contains("not generate a result set")) {
                log.warn("execute sql: '" + sql + "' success, no result set.", e);
                QueryResult queryResult = new QueryResult(Collections.emptyList(), Collections.emptyList());
                queryResult.setSql(sql);
                return queryResult;
            }
            throw e;
        } finally {
            close(conn, statement, rs);
        }
    }

    /**
     * Runs a query with parameters. Despite its name, this does not limit the result to one row;
     * it behaves exactly like {@link #query(String, Object[])}.
     *
     * @param sql    SQL with {@code ?} placeholders
     * @param params values bound to the placeholders; null or empty runs {@link #query(String)}
     * @return the rows and column types, with the executed SQL attached
     * @throws SQLException if execution fails
     */
    public QueryResult queryOne(String sql, Object[] params) throws SQLException {
        return query(sql, params);
    }

    /**
     * Closes the underlying data source if it is a {@link BasicDataSource}. Other
     * {@link DataSource} implementations are left untouched; casting them unconditionally would
     * throw {@link ClassCastException}. Close failures are logged rather than thrown.
     */
    @Override
    public void close() throws IOException {
        if (!(ds instanceof BasicDataSource)) {
            return;
        }
        try {
            ((BasicDataSource) ds).close();
        } catch (SQLException e) {
            log.warn("close data source failed.", e);
        }
    }

    /**
     * Quietly closes JDBC resources in the order result set, statement, connection. Any exception
     * from one close is ignored so that the remaining resources are still closed.
     *
     * @param conn connection, may be null
     * @param ps   statement, may be null
     * @param rs   result set, may be null
     */
    public static void close(Connection conn, Statement ps, ResultSet rs) {
        try {
            if (rs != null) {
                rs.close();
            }
        } catch (Exception ignored) {
            // best-effort: continue closing the statement and connection
        }
        try {
            if (ps != null) {
                ps.close();
            }
        } catch (Exception ignored) {
            // best-effort: continue closing the connection
        }
        try {
            if (conn != null) {
                conn.close();
            }
        } catch (Exception ignored) {
            // best-effort: nothing left to close
        }
    }

    /**
     * Replaces the {@code ${field}} placeholders in {@code sql} with values from {@code params}
     * and returns the new SQL.
     *
     * <p>Placeholders whose field is missing or null in {@code params} are left as-is. Values are
     * inserted literally: characters such as {@code $} and {@code \} are not interpreted as regex
     * group references. Values are substituted as raw text, not escaped, so callers must not pass
     * untrusted input here.
     *
     * @param sql    SQL that may contain {@code ${field}} placeholders
     * @param params replacement values, converted with {@link String#valueOf(Object)}; if null or
     *               empty, {@code sql} is returned unchanged
     * @return the SQL with placeholders replaced
     */
    public static String getMatchSQL(String sql, Map<String, ?> params) {
        if (params == null || params.isEmpty()) {
            return sql;
        }
        Matcher matcher = PLACEHOLDER_PATTERN.matcher(sql);
        StringBuffer sb = new StringBuffer();
        while (matcher.find()) {
            Object value = params.get(matcher.group(1));
            if (value == null) {
                // Not calling appendReplacement keeps the placeholder text in the output.
                continue;
            }
            matcher.appendReplacement(sb, Matcher.quoteReplacement(String.valueOf(value)));
        }
        matcher.appendTail(sb);
        return sb.toString();
    }

    /**
     * Extracts the identifiers that follow TABLE, JOIN or FROM keywords (case-insensitive). For a
     * join query this returns several names. Only the first {@code \w+} run is captured, so a
     * schema-qualified name such as {@code db.t} yields {@code db}.
     *
     * @param sql SQL text
     * @return table names in order of appearance; empty if none are found
     */
    public static String[] getSQLTable(String sql) {
        Matcher matcher = TABLE_PATTERN.matcher(sql);
        List<String> names = new ArrayList<>(3);
        while (matcher.find()) {
            names.add(matcher.group(2));
        }
        return names.toArray(new String[0]);
    }

    /**
     * Converts {@code val} to the Java type matching the database column type {@code type}.
     * Unknown or null types fall back to {@link String#valueOf(Object)}.
     *
     * @param val  field value; null is returned as null
     * @param type database column type name, matched case-insensitively
     * @return the converted value
     * @throws NumberFormatException    if a numeric type's value cannot be parsed
     * @throws IllegalArgumentException if a date/time value cannot be parsed
     */
    public static Object convertVal(Object val, String type) {
        if (val == null) {
            return null;
        }
        // Locale.ROOT: default-locale case mapping would turn "int" into "İNT" under a Turkish locale.
        String dbType = type == null ? "" : type.toUpperCase(Locale.ROOT);
        switch (dbType) {
            case "VARCHAR":
            case "VARCHAR2":
            case "TEXT":
            case "BLOB":
                return String.valueOf(val);
            case "TINYINT":
            case "SMALLINT":
            case "INT":
                return val instanceof Number ? ((Number) val).intValue() : Integer.parseInt(String.valueOf(val));
            case "DECIMAL":
            case "FLOAT":
                // NOTE(review): DECIMAL is narrowed to float, which loses precision; kept for caller compatibility.
                return val instanceof Number ? ((Number) val).floatValue() : Float.parseFloat(String.valueOf(val));
            case "DOUBLE":
                // Double.parseDouble: parsing with Float.parseFloat would silently round to float precision.
                return val instanceof Number ? ((Number) val).doubleValue() : Double.parseDouble(String.valueOf(val));
            case "BIGINT":
                return val instanceof Number ? ((Number) val).longValue() : Long.parseLong(String.valueOf(val));
            case "DATETIME":
                return new Timestamp(toEpochMillis(val, "yyyy-MM-dd HH:mm:ss"));
            case "DATE":
                return new Date(toEpochMillis(val, "yyyy-MM-dd"));
            case "TIME":
                return new Time(toEpochMillis(val, "HH:mm:ss"));
            default:
                return String.valueOf(val);
        }
    }

    /**
     * Returns epoch milliseconds for a {@link java.util.Date}, an epoch-millisecond {@link Number},
     * or a string formatted with {@code pattern}.
     *
     * @throws IllegalArgumentException if the string does not match {@code pattern}
     */
    private static long toEpochMillis(Object val, String pattern) {
        if (val instanceof java.util.Date) {
            return ((java.util.Date) val).getTime();
        }
        if (val instanceof Number) {
            return ((Number) val).longValue();
        }
        try {
            // SimpleDateFormat is not thread-safe, so a fresh instance is created per call.
            return new SimpleDateFormat(pattern).parse(val.toString()).getTime();
        } catch (ParseException e) {
            throw new IllegalArgumentException(e);
        }
    }
}