package com.yiidata.amc.jdbc

import java.util
import java.util.Properties

import com.google.common.collect.ImmutableList.Builder
import com.google.common.collect.{ImmutableList, UnmodifiableIterator}

import org.apache.commons.dbcp.BasicDataSource
import org.apache.spark.api.java.function.VoidFunction
import org.apache.spark.rdd.RDD

/**
 * Writes a Spark SQL DataFrame to a database via JDBC.
 *
 * <pre>
 * new JdbcDFWriter("update HOLIDAY_TABLE_1 set STATUS = '1' where DATETIME = ? and HOLIDAY = ? ",
 *   Array("DATETIME", "HOLIDAY"), select).save(pro)
 * </pre>
 *
 * <p>
 * JdbcRDDWriter and JdbcDFWriter read column values differently: JdbcDFWriter looks values up
 * by column name, while JdbcRDDWriter looks them up by column index.
 *
 * The number of PstateSetter instances must equal the number of ? placeholders in the SQL.
 * </p>
 *
 * Created by ZhenQin on 2018/2/6 18:01
 * Vendor: yiidata.com
 */
class JdbcDFWriter(sql: String, columnsNames: Array[String], df: org.apache.spark.sql.DataFrame) {

  {
    val arrays: util.List[String] = new util.ArrayList[String](columnsNames.length)
    columnsNames.foreach(c => arrays.add(c))
    // Append a space so a trailing "?" still yields a final split token.
    val split: Array[String] = (sql + " ").split("\\?")
    if (columnsNames.length != (split.length - 1)) {
      throw new IllegalArgumentException("The number of ? placeholders in the SQL must equal the number of columnsNames, and they must correspond one to one.")
    }
    // Fail fast if any given column is not present in the DataFrame schema.
    df.schema.foreach(s => arrays.remove(s.name))
    if (!arrays.isEmpty) {
      throw new IllegalArgumentException("unknown field: " + arrays.toString)
    }
  }

  /**
   * Saves this DataFrame to the database.
   *
   * @param props database connection properties (driver, url, user, password, optional batchsize)
   */
  def save(props: Properties): Unit = {
    val builder: Builder[PstateSetter] = ImmutableList.builder()
    // Map each schema field to a setter keyed by column name; the parameter order is restored below.
    val fieldMap: util.Map[String, PstateSetter] = new util.HashMap[String, PstateSetter](df.schema.length)
    df.schema.foreach(f => {
      f.dataType match {
        case org.apache.spark.sql.types.IntegerType => fieldMap.put(f.name, new IntPstateSetter(f.name))
        case org.apache.spark.sql.types.LongType => fieldMap.put(f.name, new LongPstateSetter(f.name))
        case org.apache.spark.sql.types.DoubleType => fieldMap.put(f.name, new DoublePstateSetter(f.name))
        case org.apache.spark.sql.types.FloatType => fieldMap.put(f.name, new FloatPstateSetter(f.name))
        case org.apache.spark.sql.types.ShortType => fieldMap.put(f.name, new ShortPstateSetter(f.name))
        case org.apache.spark.sql.types.ByteType => fieldMap.put(f.name, new BytePstateSetter(f.name))
        case org.apache.spark.sql.types.BooleanType => fieldMap.put(f.name, new BoolPstateSetter(f.name))
        case org.apache.spark.sql.types.StringType => fieldMap.put(f.name, new StringPstateSetter(f.name))
        case org.apache.spark.sql.types.BinaryType => fieldMap.put(f.name, new StringPstateSetter(f.name))
        case org.apache.spark.sql.types.TimestampType => fieldMap.put(f.name, new TimestampPstateSetter(f.name))
        case org.apache.spark.sql.types.DateType => fieldMap.put(f.name, new DatePstateSetter(f.name))
        case _: org.apache.spark.sql.types.DecimalType => fieldMap.put(f.name, new DecimalPstateSetter(f.name))
        case _ => // unsupported types get no setter; referencing such a column fails below
      }
    })
    // Add one setter per parameter column, in the order the columns were given.
    for (col <- columnsNames) {
      val setter = fieldMap.get(col)
      if (setter == null) {
        throw new IllegalArgumentException("unsupported column type for field: " + col)
      }
      builder.add(setter)
    }
    df.javaRDD.foreachPartition(new JdbcPartitionFunction(sql, builder.build(), props))
  }
}
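
/**
 * Illustrative usage sketch, not part of the original API: it shows the connection
 * Properties keys this package reads ("driver", "url", "user", "password" and the optional
 * "batchsize") together with the JdbcDFWriter call from the class comment above.
 * The JDBC driver class, URL and credentials below are hypothetical placeholders.
 */
object JdbcDFWriterUsageExample {

  /** `select` is any DataFrame whose schema contains the DATETIME and HOLIDAY columns. */
  def run(select: org.apache.spark.sql.DataFrame): Unit = {
    val pro = new Properties()
    pro.setProperty("driver", "com.mysql.jdbc.Driver")        // hypothetical driver class
    pro.setProperty("url", "jdbc:mysql://localhost:3306/amc") // hypothetical URL
    pro.setProperty("user", "user")
    pro.setProperty("password", "password")
    pro.setProperty("batchsize", "2000")                      // optional, defaults to 2000

    // Parameter columns are taken from the DataFrame by name, in the order given here.
    new JdbcDFWriter("update HOLIDAY_TABLE_1 set STATUS = '1' where DATETIME = ? and HOLIDAY = ? ",
      Array("DATETIME", "HOLIDAY"), select).save(pro)
  }
}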

/**
 * Writes an RDD to the database via JDBC.
 *
 * <pre>
 * new JdbcRDDWriter("update HOLIDAY_TABLE_1 set STATUS = '1' where DATETIME = ? and HOLIDAY = ? ",
 *   Array(new StringPstateSetter("$1"), new StringPstateSetter("$2")), rdd)
 * </pre>
 *
 * $1 reads the second column of each RDD line; indices start at $0.
 * $2 reads the third column of each RDD line; indices start at $0.
 *
 * JdbcRDDWriter and JdbcDFWriter read column values differently: JdbcDFWriter looks values up
 * by column name, while JdbcRDDWriter looks them up by column index.
 *
 * Each RDD element should be an Array(), so that array(0) retrieves the corresponding value.
 *
 * The number of PstateSetter instances must equal the number of ? placeholders in the SQL.
 *
 * @param sql         the update SQL to execute
 * @param columnTypes parameter setters, one per ? placeholder
 * @param rdd         RDD
 */
class JdbcRDDWriter(sql: String, columnTypes: Array[PstateSetter], @transient rdd: RDD[Array[Any]]) extends Serializable {

  {
    val split: Array[String] = (sql + " ").split("\\?")
    if (columnTypes.length != (split.length - 1)) {
      throw new IllegalArgumentException("The number of ? placeholders in the SQL must equal the number of columnTypes, and they must correspond one to one.")
    }
  }

  /**
   * Saves this RDD to the database.
   *
   * @param props database connection properties (driver, url, user, password, optional batchsize)
   */
  def save(props: Properties): Unit = {
    rdd.foreachPartition(iter => {
      val batchSize = Integer.parseInt(props.getProperty("batchsize", "2000"))
      // One single-connection pool per partition.
      val dataSource = new BasicDataSource()
      dataSource.setMaxActive(1)
      dataSource.setMinIdle(1)
      dataSource.setInitialSize(1)
      dataSource.setDriverClassName(props.getProperty("driver"))
      dataSource.setUrl(props.getProperty("url"))
      dataSource.setUsername(props.getProperty("user"))
      dataSource.setPassword(props.getProperty("password"))
      val conn = dataSource.getConnection
      conn.setAutoCommit(false)
      val st = conn.prepareStatement(sql)
      try {
        var counter = 0
        while (iter.hasNext) {
          val line = iter.next()
          var i = 1
          for (pstateSetter <- columnTypes) {
            val v = line(pstateSetter.index)
            pstateSetter.setValue(st, i, v)
            i += 1
          }
          counter += 1
          st.addBatch()
          // Execute once per full batch.
          if (counter >= batchSize) {
            st.executeBatch()
            st.clearBatch()
            counter = 0
          }
        }
        // Flush the final, partially filled batch in one go.
        if (counter > 0) {
          st.executeBatch()
        }
        conn.commit()
      } finally {
        // Close the statement, connection and pool.
        try {
          st.close()
          conn.close()
          dataSource.close()
        } catch {
          case e: Exception => e.printStackTrace()
        }
      }
    })
  }
}
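
/**
 * Illustrative usage sketch, assuming an existing RDD[Array[Any]] and the "$index" setter
 * convention shown in the class comment: element columns are addressed by position, $0 being
 * the first column. The JDBC connection details are hypothetical placeholders.
 */
object JdbcRDDWriterUsageExample {

  def run(rdd: RDD[Array[Any]]): Unit = {
    val props = new Properties()
    props.setProperty("driver", "com.mysql.jdbc.Driver")        // hypothetical driver class
    props.setProperty("url", "jdbc:mysql://localhost:3306/amc") // hypothetical URL
    props.setProperty("user", "user")
    props.setProperty("password", "password")

    // Column $1 binds the first ?, column $2 binds the second ?.
    new JdbcRDDWriter("update HOLIDAY_TABLE_1 set STATUS = '1' where DATETIME = ? and HOLIDAY = ? ",
      Array(new StringPstateSetter("$1"), new StringPstateSetter("$2")), rdd).save(props)
  }
}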

/**
 * Writes a local Array of rows to the database via JDBC.
 *
 * <pre>
 * new JdbcArrayWriter("update HOLIDAY_TABLE_1 set STATUS = '1' where DATETIME = ? and HOLIDAY = ? ",
 *   Array(new StringPstateSetter("$1"), new StringPstateSetter("$2")), array)
 * </pre>
 *
 * $1 reads the second column of each line; indices start at $0.
 * $2 reads the third column of each line; indices start at $0.
 *
 * Each element of the outer Array() should be an Array[Any], so that array(0) retrieves the corresponding value.
 *
 * The number of PstateSetter instances must equal the number of ? placeholders in the SQL.
 *
 * @param sql         the update SQL to execute
 * @param columnTypes parameter setters, one per ? placeholder
 * @param list        Array
 */
class JdbcArrayWriter(sql: String, columnTypes: Array[PstateSetter], @transient list: Array[Array[Any]]) extends Serializable {

  {
    val split: Array[String] = (sql + " ").split("\\?")
    if (columnTypes.length != (split.length - 1)) {
      throw new IllegalArgumentException("The number of ? placeholders in the SQL must equal the number of columnTypes, and they must correspond one to one.")
    }
  }

  /**
   * Saves this array to the database.
   *
   * @param props database connection properties (driver, url, user, password, optional batchsize)
   */
  def save(props: Properties): Unit = {
    val batchSize = Integer.parseInt(props.getProperty("batchsize", "2000"))
    // Single-connection pool; this writer works on a local array, without launching a Spark job.
    val dataSource = new BasicDataSource()
    dataSource.setMaxActive(1)
    dataSource.setMinIdle(1)
    dataSource.setInitialSize(1)
    dataSource.setDriverClassName(props.getProperty("driver"))
    dataSource.setUrl(props.getProperty("url"))
    dataSource.setUsername(props.getProperty("user"))
    dataSource.setPassword(props.getProperty("password"))
    val conn = dataSource.getConnection
    conn.setAutoCommit(false)
    val st = conn.prepareStatement(sql)
    try {
      var counter = 0
      val iter = list.iterator
      while (iter.hasNext) {
        val line = iter.next()
        var i = 1
        for (pstateSetter <- columnTypes) {
          val v = line(pstateSetter.index)
          pstateSetter.setValue(st, i, v)
          i += 1
        }
        counter += 1
        st.addBatch()
        // Execute once per full batch.
        if (counter >= batchSize) {
          st.executeBatch()
          st.clearBatch()
          counter = 0
        }
      }
      // Flush the final, partially filled batch in one go.
      if (counter > 0) {
        st.executeBatch()
      }
      conn.commit()
    } finally {
      // Close the statement, connection and pool.
      try {
        st.close()
        conn.close()
        dataSource.close()
      } catch {
        case e: Exception => e.printStackTrace()
      }
    }
  }
}
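
/**
 * Illustrative usage sketch, assuming a small local result assembled as Array[Array[Any]]
 * (for example, data collected on the driver). The row values and the connection Properties
 * are placeholders, built the same way as in the examples above.
 */
object JdbcArrayWriterUsageExample {

  def run(props: Properties): Unit = {
    // Two rows; the second column binds the first ?, the third column binds the second ?.
    val rows: Array[Array[Any]] = Array(
      Array[Any]("row-1", "2018-02-06", "NEW_YEAR"),
      Array[Any]("row-2", "2018-02-07", "SPRING_FESTIVAL"))

    new JdbcArrayWriter("update HOLIDAY_TABLE_1 set STATUS = '1' where DATETIME = ? and HOLIDAY = ? ",
      Array(new StringPstateSetter("$1"), new StringPstateSetter("$2")), rows).save(props)
  }
}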

/**
 * Writes one DataFrame partition through JDBC, executing a batch every batchSize rows and
 * committing once at the end of the partition.
 */
class JdbcPartitionFunction(sql: String, ps: ImmutableList[PstateSetter], pw: Properties)
  extends VoidFunction[util.Iterator[org.apache.spark.sql.Row]] {

  /**
   * Rows per JDBC batch.
   */
  val batchSize = Integer.parseInt(pw.getProperty("batchsize", "2000"))

  override def call(iter: util.Iterator[org.apache.spark.sql.Row]): Unit = {
    // One single-connection pool per partition.
    val dataSource = new BasicDataSource()
    dataSource.setMaxActive(1)
    dataSource.setMinIdle(1)
    dataSource.setInitialSize(1)
    dataSource.setDriverClassName(pw.getProperty("driver"))
    dataSource.setUrl(pw.getProperty("url"))
    dataSource.setUsername(pw.getProperty("user"))
    dataSource.setPassword(pw.getProperty("password"))
    val conn = dataSource.getConnection
    conn.setAutoCommit(false)
    val st = conn.prepareStatement(sql)
    var counter = 0
    try {
      while (iter.hasNext) {
        val row = iter.next()
        var i = 1
        val iterator: UnmodifiableIterator[PstateSetter] = ps.iterator()
        while (iterator.hasNext) {
          val pstateSetter: PstateSetter = iterator.next()
          // DataFrame setters read the value from the Row by column name.
          pstateSetter.setValue(st, i, row)
          i += 1
        }
        counter += 1
        st.addBatch()
        // Execute once per full batch.
        if (counter >= batchSize) {
          st.executeBatch()
          st.clearBatch()
          counter = 0
        }
      }
      // Flush the final, partially filled batch in one go.
      if (counter > 0) {
        st.executeBatch()
      }
      conn.commit()
    } finally {
      // Close the statement, connection and pool.
      try {
        st.close()
        conn.close()
        dataSource.close()
      } catch {
        case e: Exception => e.printStackTrace()
      }
    }
  }
}