Peter Farlow
04/05/2024, 6:17 AMPeter Farlow
04/05/2024, 6:17 AMclass PriorityInterruptibleExecutorService<R>(coroutineScope: CoroutineScope, priorityLevels: Int = 3) {
private val channels = List(priorityLevels) {
Channel<TaskData<R>>(Channel.UNLIMITED)
}
init {
coroutineScope.launch {
while (isActive) {
val taskData = select {
for (channel in channels) {
channel.onReceive { it }
}
}
handleTask(taskData)
}
}
}
suspend fun submit(priority: Int, block: suspend () -> R): Deferred<R> {
val taskPair = TaskData(priority, Task(block), CompletableDeferred())
channels[priority].send(taskPair)
return taskPair.future
}
private suspend fun handleTask(currentTaskData: TaskData<R>): Boolean = coroutineScope {
if (!isTaskInterruptible(currentTaskData)) {
currentTaskData.future.complete(currentTaskData.task.block())
} else {
val currentTaskBlock = currentTaskData.taskJob ?: async { currentTaskData.task.block() }
val next = getDataOrNextTask(currentTaskBlock)
val currentResult = next.first
if (currentResult != null) {
currentTaskData.future.complete(currentResult)
} else {
val nextTask = requireNotNull(next.second)
if (shouldInterruptTask(nextTask.priority)) {
val failure = TaskInterruptedException()
currentTaskBlock.cancel(failure)
currentTaskData.future.completeExceptionally(failure)
nextTask.future.complete(currentTaskData.task.block())
} else {
handleTask(currentTaskData.copy(taskJob = currentTaskBlock))
}
}
}
}
private suspend fun getDataOrNextTask(data: Deferred<R>): Pair<R?, TaskData<R>?> = select {
data.onAwait { Pair(it, null) }
for (channel in channels) {
channel.onReceive { Pair(null, it) }
}
}
private fun shouldInterruptTask(nextTaskPriority: Int) = nextTaskPriority == 0
private fun isTaskInterruptible(taskData: TaskData<R>) = taskData.priority == channels.lastIndex
}
class TaskInterruptedException : CancellationException()
data class TaskData<R>(
val priority: Int,
val task: Task<R>,
val future: CompletableDeferred<R>,
val taskJob: Deferred<R>? = null
)
data class Task<R>(val block: suspend () -> R)
Peter Farlow
04/05/2024, 6:17 AMfun test() = runBlocking {
val job = launch {
val service = PriorityInterruptibleExecutorService<String>(this)
val first = service.submit(2) {
delay(2_000)
"first".also { println(it) }
}
first.invokeOnCompletion { println("got ${first.getCompleted()} and cancelled ${first.isCancelled}") }
val second = service.submit(0) {
delay(40)
"second".also { println(it) }
}
second.invokeOnCompletion { println("got ${second.getCompleted()}") }
val third = service.submit(0) {
delay(40)
"third".also { println(it) }
}
third.invokeOnCompletion { println("got ${third.getCompleted()}") }
}
delay(3_000)
job.cancelAndJoin()
}
Peter Farlow
04/05/2024, 6:17 AMsecond
got second
third
got third
first
got first and cancelled false
Peter Farlow
04/05/2024, 6:18 AMPeter Farlow
04/05/2024, 6:19 AMSam
04/05/2024, 7:56 AMhandleTask
function is currently very complicated, and the fact that you're reading from the channel(s) in two different places (the select loop and the handleTask
function) makes it very hard to keep track of what's going on.
Consider moving the cancellation logic into the loop itself. While each task is running, the select
loop can also continue to run, waiting for any tasks with higher priority than the current task. If it finds one, it can cancel the current task and start the new one. On the other hand, if the current task completes successfully, the select loop can continue to its next iteration and start waiting for tasks with any priority.
It still might not fix your current test case, but it should at least make it easier to grok the problem.
As an aside, it's best to avoid suspend
functions that also return a Job
or Deferred
. Since your channels are all UNLIMITED
, you don't really need to use the suspending version of send
in your submit
function.Peter Farlow
04/05/2024, 12:47 PMinit {
coroutineScope.launch {
var currentTaskData: TaskData<R>? = null
while (isActive) {
val currentTaskExecution = currentTaskData?.let { async { it.task.block() } }
val resultOrTask = getDataOrNextTask(currentTaskExecution)
val result = resultOrTask.first
if (result != null) {
currentTaskData!!.future.complete(result)
} else {
currentTaskData = resultOrTask.second
}
}
}
}
Peter Farlow
04/05/2024, 12:48 PMgetDataOrNextTask()
was updated to make the argument nullable)Peter Farlow
04/05/2024, 12:50 PMPeter Farlow
04/05/2024, 1:57 PMinit {
coroutineScope.launch {
val currentTaskChannel: Channel<TaskData<R>> = Channel(Channel.UNLIMITED)
while (isActive) {
val currentTaskData = currentTaskChannel.tryReceive().getOrNull()
val currentTaskExecution = currentTaskData?.taskJob ?: currentTaskData?.let { async { it.task.block() } }
val resultOrTask = getDataOrNextTask(currentTaskExecution)
val result = resultOrTask.first
if (result != null) {
currentTaskData!!.future.complete(Result.success(result))
} else {
val nextTask = requireNotNull(resultOrTask.second)
if (shouldInterruptTask(currentTaskData, nextTask)) {
val failure = TaskInterruptedException()
currentTaskData?.future?.complete(Result.failure(failure))
currentTaskExecution?.cancel(failure)
currentTaskChannel.trySend(nextTask)
} else {
if (currentTaskData != null) {
currentTaskChannel.trySend(currentTaskData.copy(taskJob = currentTaskExecution))
}
currentTaskChannel.trySend(nextTask)
}
}
}
}
}
private fun shouldInterruptTask(taskData: TaskData<R>?, nextTask: TaskData<R>?) =
taskData?.priority == channels.lastIndex && nextTask?.priority == 0
fun submit(priority: Int, tag: String, block: suspend () -> R): Deferred<Result<R>> {
val taskPair = TaskData(tag, priority, Task(block), CompletableDeferred())
channels[priority].trySend(taskPair)
return taskPair.future
}
private suspend fun getDataOrNextTask(data: Deferred<R>?): Pair<R?, TaskData<R>?> = select {
data?.onAwait?.invoke { Pair(it, null) }
for (channel in channels) {
channel.onReceive { Pair(null, it) }
}
}
Peter Farlow
04/05/2024, 1:58 PM