@@ -0,0 +1,340 @@
+package com.yiidata.amc.jdbc
+
+import java.util
+import java.util.Properties
+
+import com.google.common.collect.ImmutableList.Builder
+import com.google.common.collect.{UnmodifiableIterator, ImmutableList}
+import org.apache.commons.dbcp.BasicDataSource
+import org.apache.spark.api.java.function.VoidFunction
+import org.apache.spark.rdd.RDD
+
+
+
+/**
+ * Write a Spark DataFrame to the database through a parameterized JDBC SQL statement.
+ *
+ * <pre>
+ * new JdbcDFWriter("update HOLIDAY_TABLE_1 set STATUS = '1' where DATETIME = ? and HOLIDAY = ? ",
+ *     Array("DATETIME", "HOLIDAY"), select).save(pro)
+ * </pre>
+ *
+ * <p>
+ * JdbcRDDWriter and JdbcDFWriter differ in how column values are read: JdbcDFWriter reads
+ * values by column name, while JdbcRDDWriter reads them by column index.
+ *
+ * The number of PstateSetter instances must equal the number of ? placeholders in the SQL.
+ * </p>
+ *
+ * Created by ZhenQin on 2018/2/6 0006-18:01
+ * Vendor: yiidata.com
+ */
+class JdbcDFWriter(sql: String, columnsNames: Array[String], df: org.apache.spark.sql.DataFrame) {
+
+  {
+    val arrays: util.List[String] = new util.ArrayList[String](columnsNames.length)
+    columnsNames.foreach(c => arrays.add(c))
+
+    val split: Array[String] = (sql + " ").split("\\?")
+    if (columnsNames.length != (split.length - 1)) {
+      throw new IllegalArgumentException("The number of ? placeholders in the SQL must equal the number of columnsNames, and they must correspond one-to-one.")
+    }
+
+    // If any of the given columns is not present in the DataFrame, throw an exception
+    df.schema.foreach(s => arrays.remove(s.name))
+    if (!arrays.isEmpty) {
+      throw new IllegalArgumentException("unknown field: " + arrays.toString)
+    }
+  }
+
+  /**
+   * Save this DataFrame to the database.
+   *
+   * @param props database connection properties (driver, url, user, password, optional batchsize)
+   */
+  def save(props: Properties): Unit = {
+    val builder: Builder[PstateSetter] = ImmutableList.builder()
+
+    // Note the ordering: setters are keyed by field name here and added in columnsNames order below
+    val fieldMap: util.Map[String, PstateSetter] = new util.HashMap[String, PstateSetter](df.schema.length)
+    df.schema.foreach(f => {
+      f.dataType match {
+        case org.apache.spark.sql.types.IntegerType => fieldMap.put(f.name, new IntPstateSetter(f.name))
+        case org.apache.spark.sql.types.LongType => fieldMap.put(f.name, new LongPstateSetter(f.name))
+        case org.apache.spark.sql.types.DoubleType => fieldMap.put(f.name, new DoublePstateSetter(f.name))
+        case org.apache.spark.sql.types.FloatType => fieldMap.put(f.name, new FloatPstateSetter(f.name))
+        case org.apache.spark.sql.types.ShortType => fieldMap.put(f.name, new ShortPstateSetter(f.name))
+        case org.apache.spark.sql.types.ByteType => fieldMap.put(f.name, new BytePstateSetter(f.name))
+        case org.apache.spark.sql.types.BooleanType => fieldMap.put(f.name, new BoolPstateSetter(f.name))
+        case org.apache.spark.sql.types.StringType => fieldMap.put(f.name, new StringPstateSetter(f.name))
+        case org.apache.spark.sql.types.BinaryType => fieldMap.put(f.name, new StringPstateSetter(f.name))
+        case org.apache.spark.sql.types.TimestampType => fieldMap.put(f.name, new TimestampPstateSetter(f.name))
+        case org.apache.spark.sql.types.DateType => fieldMap.put(f.name, new DatePstateSetter(f.name))
+        case _: org.apache.spark.sql.types.DecimalType => fieldMap.put(f.name, new DecimalPstateSetter(f.name))
+        case _ => // unsupported type: no setter is registered for this field
+      }
+    })
+
+    // Add the parameter columns in the order given by columnsNames
+    for (col <- columnsNames) {
+      builder.add(fieldMap.get(col))
+    }
+
+    df.javaRDD.foreachPartition(new JdbcPartitionFunction(sql, builder.build(), props))
+  }
+}
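+
+// A minimal usage sketch, assuming a SparkSession named `spark` and a MySQL target
+// (driver class and URL are placeholders). The property keys "driver", "url", "user",
+// "password" and the optional "batchsize" are the ones read when saving:
+//
+//   val props = new Properties()
+//   props.setProperty("driver", "com.mysql.jdbc.Driver")
+//   props.setProperty("url", "jdbc:mysql://localhost:3306/amc")
+//   props.setProperty("user", "amc")
+//   props.setProperty("password", "secret")
+//   props.setProperty("batchsize", "1000")
+//
+//   val select = spark.sql("select DATETIME, HOLIDAY from HOLIDAY_TABLE_1")
+//   new JdbcDFWriter("update HOLIDAY_TABLE_1 set STATUS = '1' where DATETIME = ? and HOLIDAY = ? ",
+//     Array("DATETIME", "HOLIDAY"), select).save(props)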
+
+
+/**
+ * Write an RDD to the database over JDBC.
+ *
+ * <pre>
+ * new JdbcRDDWriter("update HOLIDAY_TABLE_1 set STATUS = '1' where DATETIME = ? and HOLIDAY = ? ",
+ *     Array(new StringPstateSetter("$1"), new StringPstateSetter("$2")), rdd)
+ * </pre>
+ *
+ * "$1" takes the second column of each RDD line and "$2" the third; indexing starts at "$0".
+ *
+ * JdbcRDDWriter and JdbcDFWriter differ in how column values are read: JdbcDFWriter reads
+ * values by column name, while JdbcRDDWriter reads them by column index.
+ *
+ * Each RDD element should normally be an Array(), so that a column value can be retrieved
+ * as array(0), array(1), and so on.
+ *
+ * The number of PstateSetter instances must equal the number of ? placeholders in the SQL.
+ *
+ * @param sql         the update SQL to execute
+ * @param columnTypes the parameter setters, one per ? placeholder
+ * @param rdd         the RDD to write
+ */
+class JdbcRDDWriter(sql: String, columnTypes: Array[PstateSetter], @transient rdd: RDD[Array[Any]]) extends Serializable {
+
+  {
+    val split: Array[String] = (sql + " ").split("\\?")
+    if (columnTypes.length != (split.length - 1)) {
+      throw new IllegalArgumentException("The number of ? placeholders in the SQL must equal the number of columnTypes, and they must correspond one-to-one.")
+    }
+  }
+
+  /**
+   * Save this RDD to the database.
+   *
+   * @param props database connection properties
+   */
+  def save(props: Properties): Unit = {
+    rdd.foreachPartition(iter => {
+      val batchSize = Integer.parseInt(props.getProperty("batchsize", "2000"))
+      val dataSource = new BasicDataSource()
+      dataSource.setMaxActive(1)
+      dataSource.setMinIdle(1)
+      dataSource.setInitialSize(1)
+      dataSource.setDriverClassName(props.getProperty("driver"))
+      dataSource.setUrl(props.getProperty("url"))
+      dataSource.setUsername(props.getProperty("user"))
+      dataSource.setPassword(props.getProperty("password"))
+      val conn = dataSource.getConnection
+      conn.setAutoCommit(false)
+      val st = conn.prepareStatement(sql)
+
+      try {
+        var counter = 0
+        while (iter.hasNext) {
+          val line = iter.next()
+          var i = 1
+
+          for (pstateSetter <- columnTypes) {
+            val v = line(pstateSetter.index)
+            pstateSetter.setValue(st, i, v)
+            i += 1
+          }
+
+          counter += 1
+          st.addBatch()
+
+          // Execute once per full batch
+          if (counter >= batchSize) {
+            st.executeBatch()
+            st.clearBatch()
+            counter = 0
+          }
+        }
+
+        // Flush the final, partially filled batch in one go
+        if (counter % batchSize > 0) {
+          st.executeBatch()
+        }
+        conn.commit()
+      } finally {
+        // Close the statement, connection and data source
+        try {
+          st.close()
+          conn.close()
+          dataSource.close()
+        } catch {
+          case e: Exception => e.printStackTrace()
+        }
+      }
+    })
+  }
+}
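+
+// A minimal usage sketch, assuming a SparkContext named `sc` and the same `props` as above;
+// each RDD element is an Array[Any] whose columns are addressed by index via "$0", "$1", ...
+//
+//   val rdd: RDD[Array[Any]] = sc.parallelize(Seq(
+//     Array[Any]("2018-01-01", "NEW_YEAR"),
+//     Array[Any]("2018-02-16", "SPRING_FESTIVAL")))
+//   new JdbcRDDWriter("update HOLIDAY_TABLE_1 set STATUS = '1' where DATETIME = ? and HOLIDAY = ? ",
+//     Array(new StringPstateSetter("$0"), new StringPstateSetter("$1")), rdd).save(props)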
+
+
+/**
+ * Write an array of rows to the database over JDBC.
+ *
+ * <pre>
+ * new JdbcArrayWriter("update HOLIDAY_TABLE_1 set STATUS = '1' where DATETIME = ? and HOLIDAY = ? ",
+ *     Array(new StringPstateSetter("$1"), new StringPstateSetter("$2")), array)
+ * </pre>
+ *
+ * "$1" takes the second column of each row and "$2" the third; indexing starts at "$0".
+ *
+ * Each element of the outer Array() should be an Array[Any], so that a column value can be
+ * retrieved as array(0), array(1), and so on.
+ *
+ * The number of PstateSetter instances must equal the number of ? placeholders in the SQL.
+ *
+ * @param sql         the update SQL to execute
+ * @param columnTypes the parameter setters, one per ? placeholder
+ * @param list        the rows to write
+ */
+class JdbcArrayWriter(sql: String, columnTypes: Array[PstateSetter], @transient list: Array[Array[Any]]) extends Serializable {
+
+  {
+    val split: Array[String] = (sql + " ").split("\\?")
+    if (columnTypes.length != (split.length - 1)) {
+      throw new IllegalArgumentException("The number of ? placeholders in the SQL must equal the number of columnTypes, and they must correspond one-to-one.")
+    }
+  }
+
+  /**
+   * Save these rows to the database.
+   *
+   * @param props database connection properties
+   */
+  def save(props: Properties): Unit = {
+    val batchSize = Integer.parseInt(props.getProperty("batchsize", "2000"))
+    val dataSource = new BasicDataSource()
+    dataSource.setMaxActive(1)
+    dataSource.setMinIdle(1)
+    dataSource.setInitialSize(1)
+    dataSource.setDriverClassName(props.getProperty("driver"))
+    dataSource.setUrl(props.getProperty("url"))
+    dataSource.setUsername(props.getProperty("user"))
+    dataSource.setPassword(props.getProperty("password"))
+    val conn = dataSource.getConnection
+    conn.setAutoCommit(false)
+    val st = conn.prepareStatement(sql)
+    try {
+      var counter = 0
+      val iter = list.iterator
+      while (iter.hasNext) {
+        val line = iter.next()
+        var i = 1
+
+        for (pstateSetter <- columnTypes) {
+          val v = line(pstateSetter.index)
+          pstateSetter.setValue(st, i, v)
+          i += 1
+        }
+
+        counter += 1
+        st.addBatch()
+
+        // Execute once per full batch
+        if (counter >= batchSize) {
+          st.executeBatch()
+          st.clearBatch()
+          counter = 0
+        }
+      }
+
+      // Flush the final, partially filled batch in one go
+      if (counter % batchSize > 0) {
+        st.executeBatch()
+      }
+      conn.commit()
+    } finally {
+      // Close the statement, connection and data source
+      try {
+        st.close()
+        conn.close()
+        dataSource.close()
+      } catch {
+        case e: Exception => e.printStackTrace()
+      }
+    }
+  }
+}
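+
+// A minimal usage sketch with rows built directly on the driver (values are illustrative,
+// `props` as above); no Spark context is needed for JdbcArrayWriter:
+//
+//   val rows: Array[Array[Any]] = Array(
+//     Array[Any]("2018-01-01", "NEW_YEAR"),
+//     Array[Any]("2018-02-16", "SPRING_FESTIVAL"))
+//   new JdbcArrayWriter("update HOLIDAY_TABLE_1 set STATUS = '1' where DATETIME = ? and HOLIDAY = ? ",
+//     Array(new StringPstateSetter("$0"), new StringPstateSetter("$1")), rows).save(props)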
+
+
+class JdbcPartitionFunction(sql: String, ps: ImmutableList[PstateSetter], pw: Properties)
+  extends VoidFunction[util.Iterator[org.apache.spark.sql.Row]] {
+
+  /**
+   * Commit a batch every batchSize rows.
+   */
+  val batchSize = Integer.parseInt(pw.getProperty("batchsize", "2000"))
+
+  override def call(iter: util.Iterator[org.apache.spark.sql.Row]): Unit = {
+    val dataSource = new BasicDataSource()
+    dataSource.setMaxActive(1)
+    dataSource.setMinIdle(1)
+    dataSource.setInitialSize(1)
+    dataSource.setDriverClassName(pw.getProperty("driver"))
+    dataSource.setUrl(pw.getProperty("url"))
+    dataSource.setUsername(pw.getProperty("user"))
+    dataSource.setPassword(pw.getProperty("password"))
+    val conn = dataSource.getConnection
+    conn.setAutoCommit(false)
+    val st = conn.prepareStatement(sql)
+    var counter = 0
+    try {
+      while (iter.hasNext) {
+        val row = iter.next()
+        var i = 1
+        val iterator: UnmodifiableIterator[PstateSetter] = ps.iterator()
+        while (iterator.hasNext) {
+          val pstateSetter: PstateSetter = iterator.next()
+          pstateSetter.setValue(st, i, row)
+          i += 1
+        }
+
+        counter += 1
+        st.addBatch()
+
+        // Execute once per full batch
+        if (counter >= batchSize) {
+          st.executeBatch()
+          st.clearBatch()
+          counter = 0
+        }
+      }
+
+      // Flush the final, partially filled batch in one go
+      if (counter % batchSize > 0) {
+        st.executeBatch()
+      }
+      conn.commit()
+    } finally {
+      // Close the statement, connection and data source
+      try {
+        st.close()
+        conn.close()
+        dataSource.close()
+      } catch {
+        case e: Exception => e.printStackTrace()
+      }
+    }
+  }
+}
+