package com.yiidata.intergration.web.task;
import lombok.extern.slf4j.Slf4j;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.Executor;
import java.util.concurrent.Semaphore;
import java.util.concurrent.locks.Lock;
import java.util.concurrent.locks.ReentrantLock;
/**
*
* 异步任务队列
*
*
*
* Created by zhaopx.
* User: zhaopx
* Date: 2020/11/16
* Time: 15:36
*
*
*
* @author zhaopx
*/
@Slf4j
public class AsyncTaskQueue {
/**
* 任务缓存
*/
private final static Map TASK_QUEUE = new ConcurrentHashMap<>();
/**
* 任务错误队列
*/
private final static Map ERROR_TASK_QUEUE = new ConcurrentHashMap<>();
/**
* 正在运行的队列
*/
private final static Set RUNNING_TASK_QUEUE = new HashSet<>();
/**
* 控制 Spark 并发的信号量
*/
private final Semaphore semaphore;
/**
* 公平锁
*/
private final static Lock LOCK = new ReentrantLock();
private static AsyncTaskQueue SPARK_QUEUE;
private AsyncTaskQueue(int permits) {
semaphore = new Semaphore(permits);
}
/**
* 初次调用有效,
* @return
*/
public static AsyncTaskQueue getInstance() {
return getInstance(3);
}
/**
* 按照配置,设置并发量。 第一次调用有效
* @param permits
* @return
*/
public static synchronized AsyncTaskQueue getInstance(int permits) {
if(SPARK_QUEUE == null) {
SPARK_QUEUE = new AsyncTaskQueue(permits);
}
return SPARK_QUEUE;
}
/**
* 添加任务
* @param taskId
* @param taskInfo
*/
public static boolean addTask(String taskId, Map taskInfo) {
LOCK.lock();
try {
if(!TASK_QUEUE.containsKey(taskId)) {
TASK_QUEUE.put(taskId, taskInfo);
log.info("add task: {} , params: {}", taskId, String.valueOf(taskInfo));
return true;
}
} finally {
LOCK.unlock();
}
return false;
}
/**
* 获取当前需要执行队列的长度
* @return
*/
public static int getPendingTaskSize() {
LOCK.lock();
try {
HashMap tmpMap = new HashMap<>(TASK_QUEUE);
for (String s : RUNNING_TASK_QUEUE) {
tmpMap.remove(s);
}
return tmpMap.size();
} finally {
LOCK.unlock();
}
}
/**
* 获取当前需要执行队列
* @return
*/
public static Set getPendingTasks() {
LOCK.lock();
try {
HashMap tmpMap = new HashMap<>(TASK_QUEUE);
for (String s : RUNNING_TASK_QUEUE) {
tmpMap.remove(s);
}
return tmpMap.keySet();
} finally {
LOCK.unlock();
}
}
/**
* 获取当前正在执行任务的长度
* @return
*/
public static int getRunningTaskSize() {
return RUNNING_TASK_QUEUE.size();
}
public static Object getTaskInfo(String taskId) {
return TASK_QUEUE.get(taskId);
}
/**
* 移除任务
* @param taskId
*/
public static void removeTask(String taskId) {
LOCK.lock();
try {
TASK_QUEUE.remove(taskId);
RUNNING_TASK_QUEUE.remove(taskId);
log.info("remove task: {}", taskId);
} finally {
LOCK.unlock();
}
}
/**
* 错误的任务报告
* @param taskId
*/
public static void reportError(String taskId) {
LOCK.lock();
try {
Object errorTaskInfo = TASK_QUEUE.remove(taskId);
ERROR_TASK_QUEUE.put(taskId, errorTaskInfo);
RUNNING_TASK_QUEUE.remove(taskId);
} finally {
LOCK.unlock();
}
}
/**
* 判断任务是否正在运行
* @param taskId
* @return
*/
public static boolean runningTask(String taskId) {
return RUNNING_TASK_QUEUE.contains(taskId);
}
/**
* 执行该函数
* @param executor
* @param task
*/
public void execute(Executor executor, final SuperTask task) {
executor.execute(()->{
final String runningTaskId = task.getTaskId();
// 有任务需要运行
if(AsyncTaskQueue.runningTask(runningTaskId)) {
// 取得的待运行的task,不能是正在运行的列表中的
log.info("task {} running.", runningTaskId);
return;
}
// 获得一个许可
try {
semaphore.acquire();
} catch (InterruptedException e) {
return;
}
try {
// 运行任务
RUNNING_TASK_QUEUE.add(runningTaskId);
log.info("running task: {}", runningTaskId);
task.run();
log.info("finished task: {}", runningTaskId);
// 执行成功,移除
removeTask(runningTaskId);
} catch (Exception e) {
log.info("执行任务异常。error task: " + runningTaskId, e);
// 运行错误
reportError(runningTaskId);
} finally {
// 释放许可
semaphore.release();
}
});
}
}