JdbcDFWriter.scala

package com.yiidata.amc.jdbc

import java.util
import java.util.Properties

import com.google.common.collect.ImmutableList.Builder
import com.google.common.collect.{ImmutableList, UnmodifiableIterator}
import org.apache.commons.dbcp.BasicDataSource
import org.apache.spark.api.java.function.VoidFunction
import org.apache.spark.rdd.RDD

/**
 * <pre>
 * new JdbcDFWriter("update HOLIDAY_TABLE_1 set STATUS = '1' where DATETIME = ? and HOLIDAY = ? ",
 *     Array("DATETIME", "HOLIDAY"), select).save(pro)
 * </pre>
 *
 * <p>
 * JdbcRDDWriter and JdbcDFWriter differ clearly in how they read column values: JdbcDFWriter reads
 * values by column name, while JdbcRDDWriter reads values by column index.
 *
 * The number of PstateSetter instances must equal the number of ? placeholders in the SQL.
 * </p>
 *
 * Created by ZhenQin on 2018/2/6 0006-18:01
 * Vendor: yiidata.com
 */
class JdbcDFWriter(sql: String, columnsNames: Array[String], df: org.apache.spark.sql.DataFrame) {

  // Constructor-time validation
  {
    val arrays: util.List[String] = new util.ArrayList[String](columnsNames.length)
    columnsNames.foreach(c => arrays.add(c))
    val split: Array[String] = (sql + " ").split("\\?")
    if (columnsNames.length != (split.length - 1)) {
      throw new IllegalArgumentException("The number of ? placeholders in the SQL must equal the number of columnsNames, and they must correspond one to one.")
    }
    // If any requested column is not contained in the DataFrame, throw an exception
    df.schema.foreach(s => arrays.remove(s.name))
    if (!arrays.isEmpty) {
      throw new IllegalArgumentException("unknown field: " + arrays.toString)
    }
  }

  /**
   * Save this DataFrame to the database.
   *
   * @param props database connection properties
   */
  def save(props: Properties): Unit = {
    val builder: Builder[PstateSetter] = ImmutableList.builder()
    // Note: order matters here
    val fieldMap: util.Map[String, PstateSetter] = new util.HashMap[String, PstateSetter](df.schema.length)
    df.schema.foreach(f => {
      f.dataType match {
        case org.apache.spark.sql.types.IntegerType => fieldMap.put(f.name, new IntPstateSetter(f.name))
        case org.apache.spark.sql.types.LongType => fieldMap.put(f.name, new LongPstateSetter(f.name))
        case org.apache.spark.sql.types.DoubleType => fieldMap.put(f.name, new DoublePstateSetter(f.name))
        case org.apache.spark.sql.types.FloatType => fieldMap.put(f.name, new FloatPstateSetter(f.name))
        case org.apache.spark.sql.types.ShortType => fieldMap.put(f.name, new ShortPstateSetter(f.name))
        case org.apache.spark.sql.types.ByteType => fieldMap.put(f.name, new BytePstateSetter(f.name))
        case org.apache.spark.sql.types.BooleanType => fieldMap.put(f.name, new BoolPstateSetter(f.name))
        case org.apache.spark.sql.types.StringType => fieldMap.put(f.name, new StringPstateSetter(f.name))
        case org.apache.spark.sql.types.BinaryType => fieldMap.put(f.name, new StringPstateSetter(f.name)) // binary columns go through the string setter
        case org.apache.spark.sql.types.TimestampType => fieldMap.put(f.name, new TimestampPstateSetter(f.name))
        case org.apache.spark.sql.types.DateType => fieldMap.put(f.name, new DatePstateSetter(f.name))
        case _: org.apache.spark.sql.types.DecimalType => fieldMap.put(f.name, new DecimalPstateSetter(f.name))
        case _ => None // unsupported data types get no setter
      }
    })
    // Add every parameter column, in order
    for (col <- columnsNames) {
      builder.add(fieldMap.get(col))
    }
    df.javaRDD.foreachPartition(new JdbcPartitionFunction(sql, builder.build(), props))
  }
}
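
// ---------------------------------------------------------------------------
// A minimal, hypothetical usage sketch for JdbcDFWriter (not part of the
// original source). It only illustrates the Properties keys this file reads
// ("driver", "url", "user", "password", and the optional "batchsize") and the
// name-based column binding; the SparkSession setup, the H2 driver/URL and the
// sample data are assumptions.
// ---------------------------------------------------------------------------
object JdbcDFWriterExample {

  def main(args: Array[String]): Unit = {
    val spark = org.apache.spark.sql.SparkSession.builder()
      .appName("JdbcDFWriterExample")
      .master("local[*]")
      .getOrCreate()
    import spark.implicits._

    // Hypothetical data; the column names must match the ? placeholders in order.
    val holidays = Seq(("2018-02-15", "SPRING_FESTIVAL")).toDF("DATETIME", "HOLIDAY")

    val props = new Properties()
    props.setProperty("driver", "org.h2.Driver")   // assumed JDBC driver on the classpath
    props.setProperty("url", "jdbc:h2:mem:amc")    // assumed JDBC URL
    props.setProperty("user", "sa")
    props.setProperty("password", "")
    props.setProperty("batchsize", "2000")         // optional, defaults to 2000

    new JdbcDFWriter(
      "update HOLIDAY_TABLE_1 set STATUS = '1' where DATETIME = ? and HOLIDAY = ? ",
      Array("DATETIME", "HOLIDAY"),
      holidays
    ).save(props)

    spark.stop()
  }
}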

/**
 * Writes a JDBC RDD to the database.
 *
 * <pre>
 * new JdbcRDDWriter("update HOLIDAY_TABLE_1 set STATUS = '1' where DATETIME = ? and HOLIDAY = ? ",
 *     Array(new StringPstateSetter("$1"), new StringPstateSetter("$2")), rdd)
 * </pre>
 *
 * $1 takes the second column of the RDD line and $2 takes the third; indexing starts at $0.
 *
 * JdbcRDDWriter and JdbcDFWriter differ clearly in how they read column values: JdbcDFWriter reads
 * values by column name, while JdbcRDDWriter reads values by column index.
 *
 * Each RDD element is normally an Array(), so the corresponding value can be fetched with array(0).
 *
 * The number of PstateSetter instances must equal the number of ? placeholders in the SQL.
 *
 * @param sql the update SQL to execute
 * @param columnTypes the column setters
 * @param rdd the RDD to write
 */
class JdbcRDDWriter(sql: String, columnTypes: Array[PstateSetter], @transient rdd: RDD[Array[Any]]) extends Serializable {

  // Constructor-time validation
  {
    val split: Array[String] = (sql + " ").split("\\?")
    if (columnTypes.length != (split.length - 1)) {
      throw new IllegalArgumentException("The number of ? placeholders in the SQL must equal the number of columnTypes, and they must correspond one to one.")
    }
  }

  /**
   * Save this RDD to the database.
   *
   * @param props database connection properties
   */
  def save(props: Properties): Unit = {
    rdd.foreachPartition(iter => {
      val batchSize = Integer.parseInt(props.getProperty("batchsize", "2000"))
      val dataSource = new BasicDataSource()
      dataSource.setMaxActive(1)
      dataSource.setMinIdle(1)
      dataSource.setInitialSize(1)
      dataSource.setDriverClassName(props.getProperty("driver"))
      dataSource.setUrl(props.getProperty("url"))
      dataSource.setUsername(props.getProperty("user"))
      dataSource.setPassword(props.getProperty("password"))
      val conn = dataSource.getConnection
      conn.setAutoCommit(false)
      val st = conn.prepareStatement(sql)
      try {
        var counter = 0
        while (iter.hasNext) {
          val line = iter.next()
          var i = 1
          for (pstateSetter <- columnTypes) {
            val v = line(pstateSetter.index)
            pstateSetter.setValue(st, i, v)
            i += 1
          }
          counter += 1
          st.addBatch()
          // Execute once per full batch
          if (counter >= batchSize) {
            st.executeBatch()
            st.clearBatch()
            counter = 0
          }
        }
        // Flush the trailing rows that did not fill a full batch
        if (counter % batchSize > 0) {
          st.executeBatch()
        }
        conn.commit()
      } finally {
        // Close the connection
        try {
          st.close()
          conn.close()
          dataSource.close()
        } catch {
          case e: Exception => e.printStackTrace()
        }
      }
    })
  }
}
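
// ---------------------------------------------------------------------------
// A minimal, hypothetical usage sketch for JdbcRDDWriter (not part of the
// original source). It shows the index-based binding: "$1" takes the second
// element of each Array[Any] line, "$2" the third. The SparkContext setup and
// the H2 connection settings are assumptions.
// ---------------------------------------------------------------------------
object JdbcRDDWriterExample {

  def main(args: Array[String]): Unit = {
    val conf = new org.apache.spark.SparkConf().setAppName("JdbcRDDWriterExample").setMaster("local[*]")
    val sc = new org.apache.spark.SparkContext(conf)

    // Hypothetical lines: column 0 is an id, column 1 the date, column 2 the holiday name.
    val rdd: RDD[Array[Any]] = sc.parallelize(Seq(
      Array[Any](1, "2018-02-15", "SPRING_FESTIVAL"),
      Array[Any](2, "2018-02-16", "SPRING_FESTIVAL")
    ))

    val props = new Properties()
    props.setProperty("driver", "org.h2.Driver")   // assumed JDBC driver on the classpath
    props.setProperty("url", "jdbc:h2:mem:amc")    // assumed JDBC URL
    props.setProperty("user", "sa")
    props.setProperty("password", "")

    new JdbcRDDWriter(
      "update HOLIDAY_TABLE_1 set STATUS = '1' where DATETIME = ? and HOLIDAY = ? ",
      Array(new StringPstateSetter("$1"), new StringPstateSetter("$2")),
      rdd
    ).save(props)

    sc.stop()
  }
}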

/**
 * Writes an Array to the database via JDBC.
 *
 * <pre>
 * new JdbcArrayWriter("update HOLIDAY_TABLE_1 set STATUS = '1' where DATETIME = ? and HOLIDAY = ? ",
 *     Array(new StringPstateSetter("$1"), new StringPstateSetter("$2")), array)
 * </pre>
 *
 * $1 takes the second column of each line and $2 takes the third; indexing starts at $0.
 *
 * Each element of the Array() should be an Array[Any], so the corresponding value can be fetched with array(0).
 *
 * The number of PstateSetter instances must equal the number of ? placeholders in the SQL.
 *
 * @param sql the update SQL to execute
 * @param columnTypes the column setters
 * @param list the array of rows to write
 */
class JdbcArrayWriter(sql: String, columnTypes: Array[PstateSetter], @transient list: Array[Array[Any]]) extends Serializable {

  // Constructor-time validation
  {
    val split: Array[String] = (sql + " ").split("\\?")
    if (columnTypes.length != (split.length - 1)) {
      throw new IllegalArgumentException("The number of ? placeholders in the SQL must equal the number of columnTypes, and they must correspond one to one.")
    }
  }

  /**
   * Save this array to the database.
   *
   * @param props database connection properties
   */
  def save(props: Properties): Unit = {
    val batchSize = Integer.parseInt(props.getProperty("batchsize", "2000"))
    val dataSource = new BasicDataSource()
    dataSource.setMaxActive(1)
    dataSource.setMinIdle(1)
    dataSource.setInitialSize(1)
    dataSource.setDriverClassName(props.getProperty("driver"))
    dataSource.setUrl(props.getProperty("url"))
    dataSource.setUsername(props.getProperty("user"))
    dataSource.setPassword(props.getProperty("password"))
    val conn = dataSource.getConnection
    conn.setAutoCommit(false)
    val st = conn.prepareStatement(sql)
    try {
      var counter = 0
      val iter = list.iterator
      while (iter.hasNext) {
        val line = iter.next()
        var i = 1
        for (pstateSetter <- columnTypes) {
          val v = line(pstateSetter.index)
          pstateSetter.setValue(st, i, v)
          i += 1
        }
        counter += 1
        st.addBatch()
        // Execute once per full batch
        if (counter >= batchSize) {
          st.executeBatch()
          st.clearBatch()
          counter = 0
        }
      }
      // Flush the trailing rows that did not fill a full batch
      if (counter % batchSize > 0) {
        st.executeBatch()
      }
      conn.commit()
    } finally {
      // Close the connection
      try {
        st.close()
        conn.close()
        dataSource.close()
      } catch {
        case e: Exception => e.printStackTrace()
      }
    }
  }
}
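
// ---------------------------------------------------------------------------
// A minimal, hypothetical usage sketch for JdbcArrayWriter (not part of the
// original source). No Spark context is needed; the rows live in a local
// Array[Array[Any]]. The sample data and the H2 connection settings are
// assumptions.
// ---------------------------------------------------------------------------
object JdbcArrayWriterExample {

  def main(args: Array[String]): Unit = {
    // Hypothetical rows: column 0 is an id, column 1 the date, column 2 the holiday name.
    val rows: Array[Array[Any]] = Array(
      Array[Any](1, "2018-02-15", "SPRING_FESTIVAL"),
      Array[Any](2, "2018-02-16", "SPRING_FESTIVAL")
    )

    val props = new Properties()
    props.setProperty("driver", "org.h2.Driver")   // assumed JDBC driver on the classpath
    props.setProperty("url", "jdbc:h2:mem:amc")    // assumed JDBC URL
    props.setProperty("user", "sa")
    props.setProperty("password", "")

    new JdbcArrayWriter(
      "update HOLIDAY_TABLE_1 set STATUS = '1' where DATETIME = ? and HOLIDAY = ? ",
      Array(new StringPstateSetter("$1"), new StringPstateSetter("$2")),
      rows
    ).save(props)
  }
}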

/**
 * Writes one DataFrame partition to the database through a prepared statement.
 */
class JdbcPartitionFunction(sql: String, ps: ImmutableList[PstateSetter], pw: Properties)
  extends VoidFunction[util.Iterator[org.apache.spark.sql.Row]] {

  /**
   * A batch is executed every batchSize rows.
   */
  val batchSize = Integer.parseInt(pw.getProperty("batchsize", "2000"))

  override def call(iter: util.Iterator[org.apache.spark.sql.Row]): Unit = {
    val dataSource = new BasicDataSource()
    dataSource.setMaxActive(1)
    dataSource.setMinIdle(1)
    dataSource.setInitialSize(1)
    dataSource.setDriverClassName(pw.getProperty("driver"))
    dataSource.setUrl(pw.getProperty("url"))
    dataSource.setUsername(pw.getProperty("user"))
    dataSource.setPassword(pw.getProperty("password"))
    val conn = dataSource.getConnection
    conn.setAutoCommit(false)
    val st = conn.prepareStatement(sql)
    var counter = 0
    try {
      while (iter.hasNext) {
        val row = iter.next()
        var i = 1
        val iterator: UnmodifiableIterator[PstateSetter] = ps.iterator()
        while (iterator.hasNext) {
          val pstateSetter: PstateSetter = iterator.next()
          pstateSetter.setValue(st, i, row)
          i += 1
        }
        counter += 1
        st.addBatch()
        // Execute once per full batch
        if (counter >= batchSize) {
          st.executeBatch()
          st.clearBatch()
          counter = 0
        }
      }
      // Flush the trailing rows that did not fill a full batch
      if (counter % batchSize > 0) {
        st.executeBatch()
      }
      conn.commit()
    } finally {
      // Close the connection
      try {
        st.close()
        conn.close()
        dataSource.close()
      } catch {
        case e: Exception => e.printStackTrace()
      }
    }
  }
}
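
// ---------------------------------------------------------------------------
// The PstateSetter hierarchy lives elsewhere in the project; the commented-out
// shape below is only a hypothetical reconstruction inferred from how this file
// calls it (a name-or-"$index" constructor argument, an `index` field used by
// the RDD/Array writers, and setValue(st, pos, value) receiving either a raw
// value or a Row). It is a sketch, not the project's actual implementation, and
// is kept commented out to avoid clashing with the real classes in this package.
// ---------------------------------------------------------------------------
// abstract class PstateSetter(name: String) extends Serializable {
//   /** Column index parsed from a "$n" name; only meaningful for index-based binding. */
//   val index: Int = if (name.startsWith("$")) name.substring(1).toInt else -1
//
//   /** Bind `value` (a raw value, or a Row looked up by column name) to position `pos`. */
//   def setValue(st: java.sql.PreparedStatement, pos: Int, value: Any): Unit
// }
//
// class StringPstateSetter(name: String) extends PstateSetter(name) {
//   override def setValue(st: java.sql.PreparedStatement, pos: Int, value: Any): Unit = value match {
//     case row: org.apache.spark.sql.Row => st.setString(pos, String.valueOf(row.getAs[Any](name)))
//     case v                             => st.setString(pos, String.valueOf(v))
//   }
// }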