package com.yiidata.amc.jdbc

import java.util
import java.util.Properties

import com.google.common.collect.ImmutableList.Builder
import com.google.common.collect.{UnmodifiableIterator, ImmutableList}
import org.apache.commons.dbcp.BasicDataSource
import org.apache.spark.api.java.function.VoidFunction
import org.apache.spark.rdd.RDD

/**
 *
 * <pre>
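 *   // Illustrative sketch (not from the original code): these are the Properties
 *   // keys that save() reads in this file; the driver/url/user/password values
 *   // are placeholders, and "select" below is the DataFrame being written.
 *   val pro = new java.util.Properties()
 *   pro.setProperty("driver", "com.mysql.jdbc.Driver")
 *   pro.setProperty("url", "jdbc:mysql://host:3306/db")
 *   pro.setProperty("user", "user")
 *   pro.setProperty("password", "secret")
 *   pro.setProperty("batchsize", "2000")  // optional, defaults to "2000"
 *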
  * new JdbcDFWriter("update HOLIDAY_TABLE_1 set STATUS = '1' where DATETIME = ? and HOLIDAY = ? ",
  *    Array("DATETIME", "HOLIDAY"), select).save(pro)
 * </pre>
 *
 * JdbcRDDWriter and JdbcDFWriter differ noticeably in how they read column values:
 * JdbcDFWriter looks values up by column name, while JdbcRDDWriter looks them up
 * by column index.
 *
 * The number of PstateSetter instances must equal the number of ? placeholders in the SQL.
 *
 * Created by ZhenQin on 2018/2/6 0006-18:01
 * Vendor: yiidata.com
 */
class JdbcDFWriter(sql: String, columnsNames: Array[String], df: org.apache.spark.sql.DataFrame) {

  {
    val arrays: util.List[String] = new util.ArrayList[String](columnsNames.length)
    columnsNames.foreach(c => arrays.add(c))

    val split: Array[String] = (sql + " ").split("\\?")
    if (columnsNames.length != (split.length - 1)) {
      throw new IllegalArgumentException(
        "The number of ? placeholders in the SQL must equal the number of columnsNames, in matching order.")
    }

    // Throw if any requested column is not present in the DataFrame
    df.schema.foreach(s => arrays.remove(s.name))
    if (!arrays.isEmpty) {
      throw new IllegalArgumentException("unknown field: " + arrays.toString)
    }
  }

  /**
   * Save this DataFrame to the database.
   *
   * @param props database connection properties
   */
  def save(props: Properties): Unit = {
    val builder: Builder[PstateSetter] = ImmutableList.builder()

    // Build a setter per schema field, keyed by column name (the order is applied below)
    val fieldMap: util.Map[String, PstateSetter] = new util.HashMap[String, PstateSetter](df.schema.length)
    df.schema.foreach(f => {
      f.dataType match {
        case org.apache.spark.sql.types.IntegerType => fieldMap.put(f.name, new IntPstateSetter(f.name))
        case org.apache.spark.sql.types.LongType => fieldMap.put(f.name, new LongPstateSetter(f.name))
        case org.apache.spark.sql.types.DoubleType => fieldMap.put(f.name, new DoublePstateSetter(f.name))
        case org.apache.spark.sql.types.FloatType => fieldMap.put(f.name, new FloatPstateSetter(f.name))
        case org.apache.spark.sql.types.ShortType => fieldMap.put(f.name, new ShortPstateSetter(f.name))
        case org.apache.spark.sql.types.ByteType => fieldMap.put(f.name, new BytePstateSetter(f.name))
        case org.apache.spark.sql.types.BooleanType => fieldMap.put(f.name, new BoolPstateSetter(f.name))
        case org.apache.spark.sql.types.StringType => fieldMap.put(f.name, new StringPstateSetter(f.name))
        case org.apache.spark.sql.types.BinaryType => fieldMap.put(f.name, new StringPstateSetter(f.name))
        case org.apache.spark.sql.types.TimestampType => fieldMap.put(f.name, new TimestampPstateSetter(f.name))
        case org.apache.spark.sql.types.DateType => fieldMap.put(f.name, new DatePstateSetter(f.name))
        case _: org.apache.spark.sql.types.DecimalType => fieldMap.put(f.name, new DecimalPstateSetter(f.name))
        case _ => None // unsupported type: no setter registered
      }
    })

    // Add the parameter setters in the order of columnsNames, matching the ? order in the SQL
    for (col <- columnsNames) {
      builder.add(fieldMap.get(col))
    }
    df.javaRDD.foreachPartition(new JdbcPartitionFunction(sql, builder.build(), props))
  }
}

/**
 * Stores an RDD into the database via JDBC.
 *
 * <pre>
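 *   // Illustrative sketch (not from the original code): each RDD record is an
 *   // Array[Any] addressed by position; below, "2018-02-06" is $1 and "SPRING"
 *   // is $2. "spark" is an assumed SparkSession and the values are placeholders.
 *   val rdd: RDD[Array[Any]] = spark.sparkContext.parallelize(
 *     Seq(Array[Any]("1", "2018-02-06", "SPRING")))
 *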
  * new JdbcRDDWriter("update HOLIDAY_TABLE_1 set STATUS = '1' where DATETIME = ? and HOLIDAY = ? ",
  *    Array(new StringPstateSetter("$1"), new StringPstateSetter("$2")), rdd)
 * </pre>
 *
 * $1 reads the second column of an RDD record; indices start at $0.
 * $2 reads the third column of an RDD record; indices start at $0.
 *
 * JdbcRDDWriter and JdbcDFWriter differ noticeably in how they read column values:
 * JdbcDFWriter looks values up by column name, while JdbcRDDWriter looks them up
 * by column index.
 *
 * Each RDD record should be an Array, so that a value can be read as array(0).
 *
 * The number of PstateSetter instances must equal the number of ? placeholders in the SQL.
 *
 * @param sql         the update SQL to execute
 * @param columnTypes the parameter setters, one per ? placeholder
 * @param rdd         the RDD to write
 */
class JdbcRDDWriter(sql: String, columnTypes: Array[PstateSetter], @transient rdd: RDD[Array[Any]]) extends Serializable {

  {
    val split: Array[String] = (sql + " ").split("\\?")
    if (columnTypes.length != (split.length - 1)) {
      throw new IllegalArgumentException(
        "The number of ? placeholders in the SQL must equal the number of columnTypes, in matching order.")
    }
  }

  /**
   * Save this RDD to the database.
   *
   * @param props database connection properties
   */
  def save(props: Properties): Unit = {
    rdd.foreachPartition(iter => {
      val batchSize = Integer.parseInt(props.getProperty("batchsize", "2000"))

      val dataSource = new BasicDataSource()
      dataSource.setMaxActive(1)
      dataSource.setMinIdle(1)
      dataSource.setInitialSize(1)
      dataSource.setDriverClassName(props.getProperty("driver"))
      dataSource.setUrl(props.getProperty("url"))
      dataSource.setUsername(props.getProperty("user"))
      dataSource.setPassword(props.getProperty("password"))

      val conn = dataSource.getConnection
      conn.setAutoCommit(false)
      val st = conn.prepareStatement(sql)
      try {
        var counter = 0
        while (iter.hasNext) {
          val line = iter.next()
          var i = 1
          for (pstateSetter <- columnTypes) {
            val v = line(pstateSetter.index)
            pstateSetter.setValue(st, i, v)
            i += 1
          }
          counter += 1
          st.addBatch()

          // Execute once per full batch
          if (counter >= batchSize) {
            st.executeBatch()
            st.clearBatch()
            counter = 0
          }
        }

        // Flush the trailing rows that did not fill a full batch
        if (counter > 0) {
          st.executeBatch()
        }
        conn.commit()
      } finally {
        // Close resources
        try {
          st.close()
          conn.close()
          dataSource.close()
        } catch {
          case e: Exception => e.printStackTrace()
        }
      }
    })
  }
}

/**
 * Stores an Array of records into the database via JDBC.
 *
 * <pre>
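 *   // Illustrative sketch (not from the original code): each element is an
 *   // Array[Any] addressed by position; below, "2018-02-06" is $1 and "SPRING"
 *   // is $2, and the values are placeholders.
 *   val array: Array[Array[Any]] = Array(Array[Any]("1", "2018-02-06", "SPRING"))
 *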
  * new JdbcArrayWriter("update HOLIDAY_TABLE_1 set STATUS = '1' where DATETIME = ? and HOLIDAY = ? ",
  *    Array(new StringPstateSetter("$1"), new StringPstateSetter("$2")), array)
 * </pre>
 *
 * $1 reads the second column of a record; indices start at $0.
 * $2 reads the third column of a record; indices start at $0.
 *
 * Each element of the Array should itself be an Array[Any], so that a value can be
 * read as array(0).
 *
 * The number of PstateSetter instances must equal the number of ? placeholders in the SQL.
 *
 * @param sql         the update SQL to execute
 * @param columnTypes the parameter setters, one per ? placeholder
 * @param list        the records to write
 */
class JdbcArrayWriter(sql: String, columnTypes: Array[PstateSetter], @transient list: Array[Array[Any]]) extends Serializable {

  {
    val split: Array[String] = (sql + " ").split("\\?")
    if (columnTypes.length != (split.length - 1)) {
      throw new IllegalArgumentException(
        "The number of ? placeholders in the SQL must equal the number of columnTypes, in matching order.")
    }
  }

  /**
   * Save this Array to the database.
   *
   * @param props database connection properties
   */
  def save(props: Properties): Unit = {
    val batchSize = Integer.parseInt(props.getProperty("batchsize", "2000"))

    val dataSource = new BasicDataSource()
    dataSource.setMaxActive(1)
    dataSource.setMinIdle(1)
    dataSource.setInitialSize(1)
    dataSource.setDriverClassName(props.getProperty("driver"))
    dataSource.setUrl(props.getProperty("url"))
    dataSource.setUsername(props.getProperty("user"))
    dataSource.setPassword(props.getProperty("password"))

    val conn = dataSource.getConnection
    conn.setAutoCommit(false)
    val st = conn.prepareStatement(sql)
    try {
      var counter = 0
      val iter = list.iterator
      while (iter.hasNext) {
        val line = iter.next()
        var i = 1
        for (pstateSetter <- columnTypes) {
          val v = line(pstateSetter.index)
          pstateSetter.setValue(st, i, v)
          i += 1
        }
        counter += 1
        st.addBatch()

        // Execute once per full batch
        if (counter >= batchSize) {
          st.executeBatch()
          st.clearBatch()
          counter = 0
        }
      }

      // Flush the trailing rows that did not fill a full batch
      if (counter > 0) {
        st.executeBatch()
      }
      conn.commit()
    } finally {
      // Close resources
      try {
        st.close()
        conn.close()
        dataSource.close()
      } catch {
        case e: Exception => e.printStackTrace()
      }
    }
  }
}

/**
 * Writes one partition of Rows to the database in JDBC batches.
 */
class JdbcPartitionFunction(sql: String, ps: ImmutableList[PstateSetter], pw: Properties)
  extends VoidFunction[util.Iterator[org.apache.spark.sql.Row]] {

  /**
   * Execute a batch every batchSize rows.
   */
  val batchSize = Integer.parseInt(pw.getProperty("batchsize", "2000"))

  override def call(iter: util.Iterator[org.apache.spark.sql.Row]): Unit = {
    val dataSource = new BasicDataSource()
    dataSource.setMaxActive(1)
    dataSource.setMinIdle(1)
    dataSource.setInitialSize(1)
    dataSource.setDriverClassName(pw.getProperty("driver"))
    dataSource.setUrl(pw.getProperty("url"))
    dataSource.setUsername(pw.getProperty("user"))
    dataSource.setPassword(pw.getProperty("password"))

    val conn = dataSource.getConnection
    conn.setAutoCommit(false)
    val st = conn.prepareStatement(sql)
    var counter = 0
    try {
      while (iter.hasNext()) {
        val row = iter.next()
        var i = 1
        val iterator: UnmodifiableIterator[PstateSetter] = ps.iterator()
        while (iterator.hasNext) {
          val pstateSetter: PstateSetter = iterator.next()
          pstateSetter.setValue(st, i, row)
          i += 1
        }
        counter += 1
        st.addBatch()

        // Execute once per full batch
        if (counter >= batchSize) {
          st.executeBatch()
          st.clearBatch()
          counter = 0
        }
      }

      // Flush the trailing rows that did not fill a full batch
      if (counter > 0) {
        st.executeBatch()
      }
      conn.commit()
    } finally {
      // Close resources
      try {
        st.close()
        conn.close()
        dataSource.close()
      } catch {
        case e: Exception => e.printStackTrace()
      }
    }
  }
}