package org.apache.kyuubi.engine.jdbc.session; import org.apache.kyuubi.config.KyuubiConf; import org.apache.kyuubi.engine.jdbc.connection.JdbcConnectionProvider; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import java.io.IOException; import java.io.PrintWriter; import java.lang.reflect.InvocationHandler; import java.lang.reflect.Method; import java.lang.reflect.Proxy; import java.sql.Connection; import java.sql.SQLException; import java.sql.SQLFeatureNotSupportedException; import java.util.Date; import java.util.Properties; import java.util.Set; import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.ConcurrentMap; import java.util.concurrent.Semaphore; import java.util.concurrent.TimeUnit; /** * 自实现的基于 JdbcConnectionProvider 的数据源 连接池 * @author zhaopx */ public class SimpleDataSource implements PooledDataSource { final Logger log = LoggerFactory.getLogger(SimpleDataSource.class); /** * 缓存的池子 */ private final ConcurrentMap pool = new ConcurrentHashMap(); /** * 5 个连接,最大 */ private int maxSize = 10; /** * 最小连接数 */ private int minSize = 1; /** * 连接等待时间 */ private int waitTime = 30000; private Semaphore semaphore; private final Properties connProps = new Properties(); private JdbcConnectionProvider jdbcConnectionProvider; private KyuubiConf kyuubiConf; public SimpleDataSource(JdbcConnectionProvider jdbcConnectionProvider, KyuubiConf kyuubiConf) { this.jdbcConnectionProvider = jdbcConnectionProvider; this.kyuubiConf = kyuubiConf; this.maxSize = Math.max(Integer.parseInt(kyuubiConf.getOption("kyuubi.engine.jdbc.pool.maxSize").getOrElse(()-> "10")), 10); this.minSize = Math.max(Integer.parseInt(kyuubiConf.getOption("kyuubi.engine.jdbc.pool.minSize").getOrElse(()-> "1")), 1); this.waitTime = Math.max(Integer.parseInt(kyuubiConf.getOption("kyuubi.engine.jdbc.pool.maxWait").getOrElse(()-> "30000")), 1000); initConnections(); } private void initConnections() { log.info("Initializing simple data source{ pool.max = " + maxSize + ", pool.min = " + minSize + "}"); semaphore = new Semaphore(maxSize, false); if (minSize > 0 && minSize < maxSize) { try { // 尝试获得连接 Connection conn = getRealConnection(null, null); conn.close(); } catch (SQLException e) { throw new RuntimeException(e); } } } public void close() throws IOException { Exception ex = null; for (JdbcConnectionWrapper conn : pool.keySet()) { try { conn.directClose(); } catch (Exception e) { ex = e; } } pool.clear(); if(ex != null) { throw new IOException(ex); } log.info("closed data source{ pool.max = " + maxSize + ", pool.min = " + minSize + "}"); } /** * 关闭连接,这里是软关闭 * @param realConnection * @throws SQLException */ private void closeConnection(Connection realConnection, Connection proxyConnection) throws SQLException { synchronized (pool) { if (pool.size() <= maxSize && realConnection instanceof JdbcConnectionWrapper && ((JdbcConnectionWrapper)realConnection).isValidFlag()) { // 正常的连接不关闭,放到池中 pool.put((JdbcConnectionWrapper)realConnection, new Date()); return; } else if(pool.size() <= maxSize && !(realConnection instanceof JdbcConnectionWrapper)) { pool.put(new JdbcConnectionWrapper(proxyConnection, realConnection), new Date()); return; } } try { realConnection.close(); } finally { semaphore.release(); } } /** * 关闭连接,这里是软关闭 * @param realConnection * @throws SQLException */ public void closeConnectionAndRemove(Connection realConnection) throws SQLException { if(realConnection == null) { return; } synchronized (pool) { // 从缓存移除 if(realConnection instanceof JdbcConnectionWrapper) { pool.remove((JdbcConnectionWrapper) realConnection); } else { pool.remove(realConnection); } } try { if(realConnection instanceof JdbcConnectionWrapper) { ((JdbcConnectionWrapper)realConnection).directClose(); } else { realConnection.close(); } } catch (Exception ignore) { } finally { semaphore.release(); } } public Connection getConnection() throws SQLException { return getConnection(null, null); } public Connection getConnection(String username, String password) throws SQLException { synchronized (pool) { if (!pool.isEmpty()) { JdbcConnectionWrapper realConn = pool.keySet().iterator().next(); pool.remove(realConn); if(realConn.isValidFlag()) { return realConn; } // hive jdbc 不支持设置 AutoCommit //realConn.setAutoCommit(true); return getProxyConnection(realConn); } } try { if (semaphore.tryAcquire(waitTime, TimeUnit.MILLISECONDS)) { return getProxyConnection(getRealConnection(username, password)); } else { throw new RuntimeException("Connection pool is full: " + maxSize); } } catch (SQLException e) { semaphore.release(); throw e; } catch (InterruptedException e) { throw new RuntimeException(e); } } private Connection getProxyConnection(final Connection realConnection) { InvocationHandler handler = new InvocationHandler() { public Object invoke(Object proxy, Method method, Object[] params) throws Exception { Object ret = null; if ("close".equals(method.getName())) { closeConnection(realConnection, (Connection)proxy); } else if ("directClose".equals(method.getName())) { // 实际的关闭 try { realConnection.close(); } catch (Exception ignore) {} ret = Void.TYPE.newInstance(); } else if ("unwrap".equals(method.getName())) { ret = realConnection; } else { ret = method.invoke(realConnection, params); } return ret; } }; return new JdbcConnectionWrapper((JdbcConnection) Proxy.newProxyInstance(JdbcConnection.class.getClassLoader(), new Class[] { JdbcConnection.class }, handler), realConnection); } public Connection getRealConnection(String username, String password) throws SQLException { try { return jdbcConnectionProvider.getConnection(kyuubiConf); } catch (Exception e) { throw new SQLException(e); } } public void setProperties(Properties properties){ this.connProps.putAll(properties); } public PrintWriter getLogWriter() throws SQLException { return null; } public void setLogWriter(PrintWriter out) throws SQLException { } public void setLoginTimeout(int seconds) throws SQLException { } public int getLoginTimeout() throws SQLException { return 0; } public java.util.logging.Logger getParentLogger() throws SQLFeatureNotSupportedException { return null; } public T unwrap(Class iface) throws SQLException { return null; } public boolean isWrapperFor(Class iface) throws SQLException { return false; } public void setIdleValidationQuery(int idleInSeconds,String validationQuery){ //do noting } public int getMaxSize() { return maxSize; } public int getMinSize() { return minSize; } public int getWaitTime() { return waitTime; } public Properties getConnProps() { return connProps; } /** * 返回内置的绑定 * @return */ public Set getPoolSet() { return pool.keySet(); } }