package com.yiidata.intergration.web.task; import lombok.extern.slf4j.Slf4j; import java.util.HashMap; import java.util.HashSet; import java.util.Map; import java.util.Set; import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.Executor; import java.util.concurrent.Semaphore; import java.util.concurrent.locks.Lock; import java.util.concurrent.locks.ReentrantLock; /** * * 异步任务队列 * *
 *
 * Created by zhaopx.
 * User: zhaopx
 * Date: 2020/11/16
 * Time: 15:36
 *
 * 
* * @author zhaopx */ @Slf4j public class AsyncTaskQueue { /** * 任务缓存 */ private final static Map TASK_QUEUE = new ConcurrentHashMap<>(); /** * 任务错误队列 */ private final static Map ERROR_TASK_QUEUE = new ConcurrentHashMap<>(); /** * 正在运行的队列 */ private final static Set RUNNING_TASK_QUEUE = new HashSet<>(); /** * 控制 Spark 并发的信号量 */ private final Semaphore semaphore; /** * 公平锁 */ private final static Lock LOCK = new ReentrantLock(); private static AsyncTaskQueue SPARK_QUEUE; private AsyncTaskQueue(int permits) { semaphore = new Semaphore(permits); } /** * 初次调用有效, * @return */ public static AsyncTaskQueue getInstance() { return getInstance(3); } /** * 按照配置,设置并发量。 第一次调用有效 * @param permits * @return */ public static synchronized AsyncTaskQueue getInstance(int permits) { if(SPARK_QUEUE == null) { SPARK_QUEUE = new AsyncTaskQueue(permits); } return SPARK_QUEUE; } /** * 添加任务 * @param taskId * @param taskInfo */ public static boolean addTask(String taskId, Map taskInfo) { LOCK.lock(); try { if(!TASK_QUEUE.containsKey(taskId)) { TASK_QUEUE.put(taskId, taskInfo); log.info("add task: {} , params: {}", taskId, String.valueOf(taskInfo)); return true; } } finally { LOCK.unlock(); } return false; } /** * 获取当前需要执行队列的长度 * @return */ public static int getPendingTaskSize() { LOCK.lock(); try { HashMap tmpMap = new HashMap<>(TASK_QUEUE); for (String s : RUNNING_TASK_QUEUE) { tmpMap.remove(s); } return tmpMap.size(); } finally { LOCK.unlock(); } } /** * 获取当前需要执行队列 * @return */ public static Set getPendingTasks() { LOCK.lock(); try { HashMap tmpMap = new HashMap<>(TASK_QUEUE); for (String s : RUNNING_TASK_QUEUE) { tmpMap.remove(s); } return tmpMap.keySet(); } finally { LOCK.unlock(); } } /** * 获取当前正在执行任务的长度 * @return */ public static int getRunningTaskSize() { return RUNNING_TASK_QUEUE.size(); } public static Object getTaskInfo(String taskId) { return TASK_QUEUE.get(taskId); } /** * 移除任务 * @param taskId */ public static void removeTask(String taskId) { LOCK.lock(); try { TASK_QUEUE.remove(taskId); RUNNING_TASK_QUEUE.remove(taskId); log.info("remove task: {}", taskId); } finally { LOCK.unlock(); } } /** * 错误的任务报告 * @param taskId */ public static void reportError(String taskId) { LOCK.lock(); try { Object errorTaskInfo = TASK_QUEUE.remove(taskId); ERROR_TASK_QUEUE.put(taskId, errorTaskInfo); RUNNING_TASK_QUEUE.remove(taskId); } finally { LOCK.unlock(); } } /** * 判断任务是否正在运行 * @param taskId * @return */ public static boolean runningTask(String taskId) { return RUNNING_TASK_QUEUE.contains(taskId); } /** * 执行该函数 * @param executor * @param task */ public void execute(Executor executor, final SuperTask task) { executor.execute(()->{ final String runningTaskId = task.getTaskId(); // 有任务需要运行 if(AsyncTaskQueue.runningTask(runningTaskId)) { // 取得的待运行的task,不能是正在运行的列表中的 log.info("task {} running.", runningTaskId); return; } // 获得一个许可 try { semaphore.acquire(); } catch (InterruptedException e) { return; } try { // 运行任务 RUNNING_TASK_QUEUE.add(runningTaskId); log.info("running task: {}", runningTaskId); task.run(); log.info("finished task: {}", runningTaskId); // 执行成功,移除 removeTask(runningTaskId); } catch (Exception e) { log.info("执行任务异常。error task: " + runningTaskId, e); // 运行错误 reportError(runningTaskId); } finally { // 释放许可 semaphore.release(); } }); } }