package com.yiidata.amc.jdbc

import java.util
import java.util.Properties

import com.google.common.collect.ImmutableList.Builder
import com.google.common.collect.{UnmodifiableIterator, ImmutableList}
import org.apache.commons.dbcp.BasicDataSource
import org.apache.spark.api.java.function.VoidFunction
import org.apache.spark.rdd.RDD

/**
 *
* new JdbcDFWriter("update HOLIDAY_TABLE_1 set STATUS = '1' where DATETIME = ? and HOLIDAY = ? ",
* Array("DATETIME", "HOLIDAY"), select).save(pro)
*
*
*
 * JdbcRDDWriter and JdbcDFWriter differ in how they read column values: JdbcDFWriter
 * looks values up by column name, while JdbcRDDWriter uses the column index.
 *
 * The number of PstateSetters must equal the number of ? placeholders in the SQL.
 *
 * Created by ZhenQin on 2018/2/6 0006-18:01
 * Vendor: yiidata.com
 */
class JdbcDFWriter(sql:String, columnsNames:Array[String], df:org.apache.spark.sql.DataFrame) {

  {
    val arrays: util.List[String] = new util.ArrayList[String](columnsNames.length)
    columnsNames.foreach(c => { arrays.add(c) })

    val split: Array[String] = (sql + " ").split("\\?")
    if (columnsNames.length != (split.length - 1)) {
      throw new IllegalArgumentException("The number of ? placeholders in the SQL must equal the number of columnsNames, in matching order.")
    }

    // If any of the given columns is not present in the DataFrame schema, fail fast
    df.schema.foreach(s => { arrays.remove(s.name) })
    if (!arrays.isEmpty) {
      throw new IllegalArgumentException("unknown field: " + arrays.toString)
    }
  }

  /**
   * Save this DataFrame to the database.
   *
   * @param props database connection properties
   */
  def save(props: Properties): Unit = {
    val builder: Builder[PstateSetter] = ImmutableList.builder()

    // Build a name -> setter lookup; the builder below must be filled in columnsNames order
    val fieldMap: util.Map[String, PstateSetter] = new util.HashMap[String, PstateSetter](df.schema.length)
    df.schema.foreach(f => {
      f.dataType match {
        case org.apache.spark.sql.types.IntegerType => fieldMap.put(f.name, new IntPstateSetter(f.name))
        case org.apache.spark.sql.types.LongType => fieldMap.put(f.name, new LongPstateSetter(f.name))
        case org.apache.spark.sql.types.DoubleType => fieldMap.put(f.name, new DoublePstateSetter(f.name))
        case org.apache.spark.sql.types.FloatType => fieldMap.put(f.name, new FloatPstateSetter(f.name))
        case org.apache.spark.sql.types.ShortType => fieldMap.put(f.name, new ShortPstateSetter(f.name))
        case org.apache.spark.sql.types.ByteType => fieldMap.put(f.name, new BytePstateSetter(f.name))
        case org.apache.spark.sql.types.BooleanType => fieldMap.put(f.name, new BoolPstateSetter(f.name))
        case org.apache.spark.sql.types.StringType => fieldMap.put(f.name, new StringPstateSetter(f.name))
        case org.apache.spark.sql.types.BinaryType => fieldMap.put(f.name, new StringPstateSetter(f.name))
        case org.apache.spark.sql.types.TimestampType => fieldMap.put(f.name, new TimestampPstateSetter(f.name))
        case org.apache.spark.sql.types.DateType => fieldMap.put(f.name, new DatePstateSetter(f.name))
        case _: org.apache.spark.sql.types.DecimalType => fieldMap.put(f.name, new DecimalPstateSetter(f.name))
        case _ => None
      }
    })

    // Add one setter per parameter column, in the order given by columnsNames
    for (col <- columnsNames) {
      builder.add(fieldMap.get(col))
    }
    df.javaRDD.foreachPartition(new JdbcPartitionFunction(sql, builder.build(), props))
  }
}
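// A minimal usage sketch for JdbcDFWriter (not part of the original module). It assumes a
// SparkSession and a MySQL driver on the classpath; the driver class, JDBC URL, credentials
// and the HOLIDAY_TABLE_1 table below are hypothetical placeholders.
object JdbcDFWriterExample {
  def main(args: Array[String]): Unit = {
    val spark = org.apache.spark.sql.SparkSession.builder()
      .appName("JdbcDFWriterExample")
      .master("local[*]")
      .getOrCreate()
    import spark.implicits._

    // Values are looked up by DataFrame column name, in the order given by columnsNames
    val select = Seq(
      ("2018-02-15", "1"),
      ("2018-02-16", "1")
    ).toDF("DATETIME", "HOLIDAY")

    val pro = new Properties()
    pro.setProperty("driver", "com.mysql.jdbc.Driver")          // hypothetical JDBC driver
    pro.setProperty("url", "jdbc:mysql://localhost:3306/test")  // hypothetical connection URL
    pro.setProperty("user", "root")
    pro.setProperty("password", "root")
    pro.setProperty("batchsize", "2000")                        // optional, defaults to 2000

    new JdbcDFWriter(
      "update HOLIDAY_TABLE_1 set STATUS = '1' where DATETIME = ? and HOLIDAY = ? ",
      Array("DATETIME", "HOLIDAY"),
      select
    ).save(pro)

    spark.stop()
  }
}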
/**
 * Writes the rows of an RDD to the database via JDBC.
 *
 * new JdbcRDDWriter("update HOLIDAY_TABLE_1 set STATUS = '1' where DATETIME = ? and HOLIDAY = ? ",
 *   Array(new StringPstateSetter("$1"), new StringPstateSetter("$2")), rdd)
 *
 * $1 reads the second column of each RDD line; indexing starts at $0.
 * $2 reads the third column of each RDD line; indexing starts at $0.
 *
 * JdbcRDDWriter and JdbcDFWriter differ in how they read column values: JdbcDFWriter
 * looks values up by column name, while JdbcRDDWriter uses the column index.
 *
 * Each element of the RDD should be an Array(), so the corresponding value can be
 * retrieved with array(0), array(1), and so on.
 *
 * The number of PstateSetters must equal the number of ? placeholders in the SQL.
 *
 * @param sql the update SQL to execute
 * @param columnTypes one parameter setter per ? placeholder
 * @param rdd the RDD to write
*/
class JdbcRDDWriter(sql:String, columnTypes:Array[PstateSetter], @transient rdd:RDD[Array[Any]]) extends Serializable {
{
val split: Array[String] = (sql + " ").split("\\?")
if(columnTypes.length != (split.length -1)){
throw new IllegalArgumentException("SQL 中 ? 数量必须 和 columnsNames 数量相等,且必须对应。")
}
}
/**
   * Save this RDD to the database.
   *
   * @param props database connection properties
*/
def save(props:Properties):Unit = {
rdd.foreachPartition(iter =>{
val batchSize = Integer.parseInt(props.getProperty("batchsize", "2000"))
val dataSource = new BasicDataSource()
dataSource.setMaxActive(1)
dataSource.setMinIdle(1)
dataSource.setInitialSize(1)
dataSource.setDriverClassName(props.getProperty("driver"))
dataSource.setUrl(props.getProperty("url"))
dataSource.setUsername(props.getProperty("user"))
dataSource.setPassword(props.getProperty("password"))
val conn = dataSource.getConnection
conn.setAutoCommit(false)
val st = conn.prepareStatement(sql)
try {
var counter = 0
while (iter.hasNext) {
val line = iter.next()
var i = 1
for (pstateSetter <- columnTypes) {
val v = line(pstateSetter.index)
pstateSetter.setValue(st, i, v)
i += 1
}
counter += 1
st.addBatch()
          // Execute the statement once per full batch
if (counter >= batchSize) {
st.executeBatch()
st.clearBatch()
counter = 0
}
}
        // Flush the final, partially filled batch
        if (counter > 0) {
          st.executeBatch()
        }
        conn.commit()
} finally {
        // Close the statement, connection and pool
try {
st.close()
conn.close()
dataSource.close()
} catch {
case e:Exception => e.printStackTrace()
}
}
})
}
}
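// A minimal usage sketch for JdbcRDDWriter (not part of the original module). It assumes a
// SparkSession and a MySQL driver on the classpath; the driver class, JDBC URL, credentials
// and the HOLIDAY_TABLE_1 table below are hypothetical placeholders.
object JdbcRDDWriterExample {
  def main(args: Array[String]): Unit = {
    val spark = org.apache.spark.sql.SparkSession.builder()
      .appName("JdbcRDDWriterExample")
      .master("local[*]")
      .getOrCreate()

    // Each RDD element is an Array[Any]; "$0" addresses the first element, "$1" the second
    val rdd: RDD[Array[Any]] = spark.sparkContext.parallelize(Seq(
      Array[Any]("2018-02-15", "1"),
      Array[Any]("2018-02-16", "1")
    ))

    val props = new Properties()
    props.setProperty("driver", "com.mysql.jdbc.Driver")          // hypothetical JDBC driver
    props.setProperty("url", "jdbc:mysql://localhost:3306/test")  // hypothetical connection URL
    props.setProperty("user", "root")
    props.setProperty("password", "root")
    props.setProperty("batchsize", "2000")                        // optional, defaults to 2000

    new JdbcRDDWriter(
      "update HOLIDAY_TABLE_1 set STATUS = '1' where DATETIME = ? and HOLIDAY = ? ",
      Array(new StringPstateSetter("$0"), new StringPstateSetter("$1")),
      rdd
    ).save(props)

    spark.stop()
  }
}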
/**
 * Writes an in-memory array of rows to the database via JDBC.
 *
 * new JdbcArrayWriter("update HOLIDAY_TABLE_1 set STATUS = '1' where DATETIME = ? and HOLIDAY = ? ",
 *   Array(new StringPstateSetter("$1"), new StringPstateSetter("$2")), array)
 *
 * $1 reads the second column of each row; indexing starts at $0.
 * $2 reads the third column of each row; indexing starts at $0.
 *
 * Each element of the outer Array() should be an Array[Any], so the corresponding value
 * can be retrieved with array(0), array(1), and so on.
 *
 * The number of PstateSetters must equal the number of ? placeholders in the SQL.
 *
 * @param sql the update SQL to execute
 * @param columnTypes one parameter setter per ? placeholder
 * @param list the rows to write
*/
class JdbcArrayWriter(sql:String, columnTypes:Array[PstateSetter], @transient list:Array[Array[Any]]) extends Serializable {
{
val split: Array[String] = (sql + " ").split("\\?")
if(columnTypes.length != (split.length -1)){
throw new IllegalArgumentException("SQL 中 ? 数量必须 和 columnsNames 数量相等,且必须对应。")
}
}
/**
   * Save this array to the database.
   *
   * @param props database connection properties
*/
def save(props:Properties):Unit = {
val batchSize = Integer.parseInt(props.getProperty("batchsize", "2000"))
val dataSource = new BasicDataSource()
dataSource.setMaxActive(1)
dataSource.setMinIdle(1)
dataSource.setInitialSize(1)
dataSource.setDriverClassName(props.getProperty("driver"))
dataSource.setUrl(props.getProperty("url"))
dataSource.setUsername(props.getProperty("user"))
dataSource.setPassword(props.getProperty("password"))
val conn = dataSource.getConnection
conn.setAutoCommit(false)
val st = conn.prepareStatement(sql)
try {
var counter = 0
val iter = list.iterator
while (iter.hasNext) {
val line = iter.next()
var i = 1
for (pstateSetter <- columnTypes) {
val v = line(pstateSetter.index)
pstateSetter.setValue(st, i, v)
i += 1
}
counter += 1
st.addBatch()
        // Execute the statement once per full batch
if (counter >= batchSize) {
st.executeBatch()
st.clearBatch()
counter = 0
}
}
      // Flush the final, partially filled batch
      if (counter > 0) {
        st.executeBatch()
      }
      conn.commit()
} finally {
      // Close the statement, connection and pool
try {
st.close()
conn.close()
dataSource.close()
} catch {
case e:Exception => e.printStackTrace()
}
}
}
}
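// A minimal usage sketch for JdbcArrayWriter (not part of the original module). No Spark is
// required; the MySQL driver class, JDBC URL, credentials and the HOLIDAY_TABLE_1 table
// referenced below are hypothetical placeholders.
object JdbcArrayWriterExample {
  def main(args: Array[String]): Unit = {
    // Each row is an Array[Any]; "$0" addresses the first element, "$1" the second
    val rows: Array[Array[Any]] = Array(
      Array[Any]("2018-02-15", "1"),
      Array[Any]("2018-02-16", "1")
    )

    val props = new Properties()
    props.setProperty("driver", "com.mysql.jdbc.Driver")          // hypothetical JDBC driver
    props.setProperty("url", "jdbc:mysql://localhost:3306/test")  // hypothetical connection URL
    props.setProperty("user", "root")
    props.setProperty("password", "root")
    props.setProperty("batchsize", "2000")                        // optional, defaults to 2000

    new JdbcArrayWriter(
      "update HOLIDAY_TABLE_1 set STATUS = '1' where DATETIME = ? and HOLIDAY = ? ",
      Array(new StringPstateSetter("$0"), new StringPstateSetter("$1")),
      rows
    ).save(props)
  }
}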
class JdbcPartitionFunction(sql:String, ps:ImmutableList[PstateSetter], pw:Properties)
extends VoidFunction[util.Iterator[org.apache.spark.sql.Row]] {
/**
   * Rows per JDBC batch; the statement batch is executed every batchSize rows
*/
val batchSize = Integer.parseInt(pw.getProperty("batchsize", "2000"))
override def call(iter: util.Iterator[org.apache.spark.sql.Row]): Unit = {
val dataSource = new BasicDataSource()
dataSource.setMaxActive(1)
dataSource.setMinIdle(1)
dataSource.setInitialSize(1)
dataSource.setDriverClassName(pw.getProperty("driver"))
dataSource.setUrl(pw.getProperty("url"))
dataSource.setUsername(pw.getProperty("user"))
dataSource.setPassword(pw.getProperty("password"))
val conn = dataSource.getConnection
conn.setAutoCommit(false)
val st = conn.prepareStatement(sql)
var counter = 0
try {
while(iter.hasNext()) {
val row = iter.next()
var i = 1
val iterator: UnmodifiableIterator[PstateSetter] = ps.iterator()
while(iterator.hasNext) {
val pstateSetter: PstateSetter = iterator.next()
pstateSetter.setValue(st, i, row)
i += 1
}
counter += 1
st.addBatch()
        // Execute the statement once per full batch
if(counter >= batchSize){
st.executeBatch()
st.clearBatch()
counter = 0
}
}
      // Flush the final, partially filled batch
      if (counter > 0) {
        st.executeBatch()
      }
      conn.commit()
} finally {
      // Close the statement, connection and pool
try {
st.close()
conn.close()
dataSource.close()
} catch {
case e:Exception => e.printStackTrace()
}
}
}
}