# coroutines
p
I’m attempting to implement a command processor that can accept commands of different priority levels, and that can cancel low-priority commands (which, in the domain I’m working in, are likely to be very long) when a new higher-priority command is enqueued behind them. Using Channels and a `select` statement, I found it easy to implement a single consumer coroutine that automatically favors commands from the higher-priority channels, but cancelling lower-priority commands is harder. I’ve got something I thought should work, but my unit test shows the task isn’t getting cancelled 🧵
```kotlin
import kotlinx.coroutines.*
import kotlinx.coroutines.channels.Channel
import kotlinx.coroutines.selects.select

class PriorityInterruptibleExecutorService<R>(coroutineScope: CoroutineScope, priorityLevels: Int = 3) {
    // One queue per priority level; index 0 is the highest priority.
    private val channels = List(priorityLevels) {
        Channel<TaskData<R>>(Channel.UNLIMITED)
    }

    init {
        coroutineScope.launch {
            while (isActive) {
                // select is biased to its first clause, so higher-priority channels are read first.
                val taskData = select<TaskData<R>> {
                    for (channel in channels) {
                        channel.onReceive { it }
                    }
                }
                handleTask(taskData)
            }
        }
    }

    suspend fun submit(priority: Int, block: suspend () -> R): Deferred<R> {
        val taskPair = TaskData(priority, Task(block), CompletableDeferred())
        channels[priority].send(taskPair)
        return taskPair.future
    }

    private suspend fun handleTask(currentTaskData: TaskData<R>): Boolean = coroutineScope {
        if (!isTaskInterruptible(currentTaskData)) {
            currentTaskData.future.complete(currentTaskData.task.block())
        } else {
            val currentTaskBlock = currentTaskData.taskJob ?: async { currentTaskData.task.block() }
            // Wait for either the running task's result or the next queued task.
            val next = getDataOrNextTask(currentTaskBlock)
            val currentResult = next.first
            if (currentResult != null) {
                currentTaskData.future.complete(currentResult)
            } else {
                val nextTask = requireNotNull(next.second)
                if (shouldInterruptTask(nextTask.priority)) {
                    val failure = TaskInterruptedException()
                    currentTaskBlock.cancel(failure)
                    currentTaskData.future.completeExceptionally(failure)
                    // Run the interrupting task's own block, not the cancelled task's.
                    nextTask.future.complete(nextTask.task.block())
                } else {
                    handleTask(currentTaskData.copy(taskJob = currentTaskBlock))
                    // nextTask was already taken off its channel; run it once the current task is done.
                    handleTask(nextTask)
                }
            }
        }
    }

    private suspend fun getDataOrNextTask(data: Deferred<R>): Pair<R?, TaskData<R>?> = select {
        data.onAwait { Pair(it, null) }
        for (channel in channels) {
            channel.onReceive { Pair(null, it) }
        }
    }

    private fun shouldInterruptTask(nextTaskPriority: Int) = nextTaskPriority == 0
    private fun isTaskInterruptible(taskData: TaskData<R>) = taskData.priority == channels.lastIndex
}

class TaskInterruptedException : CancellationException()

data class TaskData<R>(
    val priority: Int,
    val task: Task<R>,
    val future: CompletableDeferred<R>,
    val taskJob: Deferred<R>? = null
)

data class Task<R>(val block: suspend () -> R)
```
```kotlin
import kotlinx.coroutines.*

fun test() = runBlocking {
    val job = launch {
        val service = PriorityInterruptibleExecutorService<String>(this)
        val first = service.submit(2) {
            delay(2_000)
            "first".also { println(it) }
        }
        first.invokeOnCompletion { cause ->
            // getCompleted() throws if `first` was cancelled, so only read the value on success.
            val value = if (cause == null) first.getCompleted() else "nothing"
            println("got $value and cancelled ${first.isCancelled}")
        }
        val second = service.submit(0) {
            delay(40)
            "second".also { println(it) }
        }
        second.invokeOnCompletion { println("got ${second.getCompleted()}") }
        val third = service.submit(0) {
            delay(40)
            "third".also { println(it) }
        }
        third.invokeOnCompletion { println("got ${third.getCompleted()}") }
    }
    delay(3_000)
    job.cancelAndJoin()
}
```
My test prints:
```
second
got second
third
got third
first
got first and cancelled false
```
I’m not sure why the first lowest priority task (highest index) returns false for cancelled.
Also, in general, I’m not sure if I’m taking the right approach here. Channels and select statements seem great for the priority portion, but I’m not sure they fit cancellation very well.
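The "priority portion" relies on a documented property of `select` in kotlinx.coroutines: it is biased toward its first clause, so when several `onReceive` clauses are ready at once, the one registered first wins (`selectUnbiased` picks randomly among ready clauses instead). A minimal sketch of that behaviour, assuming kotlinx.coroutines 1.5+ for `trySend`; `receiveByPriority` and `demoOrder` are hypothetical helper names:

```kotlin
import kotlinx.coroutines.channels.Channel
import kotlinx.coroutines.runBlocking
import kotlinx.coroutines.selects.select

// Receives one item, preferring lower channel indices: select is biased to its first ready clause.
suspend fun receiveByPriority(channels: List<Channel<String>>): String =
    select {
        for (channel in channels) {
            channel.onReceive { it }
        }
    }

// Hypothetical demo: enqueue out of priority order, then drain in priority order.
fun demoOrder(): List<String> = runBlocking {
    val channels = List(3) { Channel<String>(Channel.UNLIMITED) }
    channels[2].trySend("low") // enqueued first, but on the lowest-priority channel
    channels[0].trySend("high")
    channels[1].trySend("medium")
    val order = mutableListOf<String>()
    repeat(3) { order += receiveByPriority(channels) }
    order
}
```

Even though `"low"` is enqueued first, `demoOrder()` should return `[high, medium, low]`. The same bias explains the test output above: all three tasks are already queued when the consumer coroutine first runs, so both priority-0 tasks are taken before the priority-2 one.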
s
Your select loop is a good starting point 👍.

> I’m not sure why the first lowest priority task (highest index) returns false for cancelled.

The reason is fairly simple: your loop reads from the higher-priority channels first, so the higher-priority tasks have already started by the time the lower-priority task is considered. Your cancellation logic only considers tasks that aren't yet running, i.e. ones that are still in the queue (channel). But by the time the lower-priority task checks whether it should cancel itself, the queue is already empty.

> Also, in general, I’m not sure if I’m taking the right approach here.

I think you'll find it easier to reason about if you separate the cancellation logic from the task handling itself. Your `handleTask` function is currently very complicated, and the fact that you're reading from the channel(s) in two different places (the select loop and the `handleTask` function) makes it very hard to keep track of what's going on.

Consider moving the cancellation logic into the loop itself. While each task is running, the `select` loop can also continue to run, waiting for any tasks with higher priority than the current task. If it finds one, it can cancel the current task and start the new one. On the other hand, if the current task completes successfully, the select loop can continue to its next iteration and start waiting for tasks with any priority. It still might not fix your current test case, but it should at least make it easier to grok the problem.

As an aside, it's best to avoid `suspend` functions that also return a `Job` or `Deferred`. Since your channels are all `UNLIMITED`, you don't really need to use the suspending version of `send` in your `submit` function.
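A minimal sketch of the loop described above, not the code from this thread. Assumptions: kotlinx.coroutines 1.5+ (for `trySend`), the rule that any strictly higher-priority task preempts the running one (the asker's `shouldInterruptTask` is narrower), and the hypothetical names `PreemptingExecutor`, `Pending`, and a redeclared `TaskInterruptedException`. The running task is held in a local `execution`, a single `select` waits for either its result or a higher-priority arrival, and `submit` is non-suspending, as the aside suggests:

```kotlin
import kotlinx.coroutines.*
import kotlinx.coroutines.channels.Channel
import kotlinx.coroutines.selects.select

// Hypothetical exception used to fail a task that was preempted.
class TaskInterruptedException : CancellationException("preempted by a higher-priority task")

// Hypothetical executor: index 0 is the highest priority.
class PreemptingExecutor<R>(scope: CoroutineScope, priorityLevels: Int = 3) {
    private class Pending<R>(
        val priority: Int,
        val block: suspend () -> R,
        val result: CompletableDeferred<R>,
    )

    private val channels = List(priorityLevels) { Channel<Pending<R>>(Channel.UNLIMITED) }

    // Not suspending: trySend on an open UNLIMITED channel always succeeds.
    fun submit(priority: Int, block: suspend () -> R): Deferred<R> {
        val pending = Pending(priority, block, CompletableDeferred<R>())
        channels[priority].trySend(pending)
        return pending.result
    }

    init {
        scope.launch {
            var preemptedBy: Pending<R>? = null
            while (isActive) {
                // Run the task that preempted the last one, or else the highest-priority
                // waiting task (select prefers its first ready clause, i.e. the lowest index).
                val current = preemptedBy ?: select<Pending<R>> {
                    for (channel in channels) channel.onReceive { it }
                }
                val execution = async {
                    try {
                        Result.success(current.block())
                    } catch (e: CancellationException) {
                        throw e // let cancellation propagate normally
                    } catch (e: Throwable) {
                        Result.failure(e) // a failing task must not kill the loop
                    }
                }
                // While the task runs, also listen on the strictly higher-priority channels only.
                preemptedBy = select<Pending<R>?> {
                    execution.onAwait { outcome ->
                        current.result.completeWith(outcome)
                        null
                    }
                    for (channel in channels.subList(0, current.priority)) {
                        channel.onReceive { it }
                    }
                }
                if (preemptedBy != null) {
                    val failure = TaskInterruptedException()
                    execution.cancel(failure)
                    current.result.completeExceptionally(failure)
                }
            }
        }
    }
}
```

For example, submitting a `delay(10_000)` task at priority 2 and, once it has started, a short task at priority 0 should complete the priority-0 `Deferred` normally and leave the priority-2 one cancelled with `TaskInterruptedException`. Because each task's exceptions are caught inside its `async`, a failing task is reported through its own `Deferred` instead of cancelling the loop.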
p
Thank you, Sam, very helpful as usual! I think you’re right; with your perspective I can simplify this greatly:
```kotlin
init {
    coroutineScope.launch {
        var currentTaskData: TaskData<R>? = null
        while (isActive) {
            val currentTaskExecution = currentTaskData?.let { async { it.task.block() } }
            val resultOrTask = getDataOrNextTask(currentTaskExecution)
            val result = resultOrTask.first
            if (result != null) {
                currentTaskData!!.future.complete(result)
                // Clear the finished task; otherwise it is started again on the next iteration.
                currentTaskData = null
            } else {
                currentTaskData = resultOrTask.second
            }
        }
    }
}
```
(`getDataOrNextTask()` was updated to make the argument nullable.)
Oh, I forgot to actually handle cancelling the task 🙃. I was up late with this one last night.
```kotlin
init {
    coroutineScope.launch {
        val currentTaskChannel: Channel<TaskData<R>> = Channel(Channel.UNLIMITED)
        while (isActive) {
            val currentTaskData = currentTaskChannel.tryReceive().getOrNull()
            val currentTaskExecution = currentTaskData?.taskJob ?: currentTaskData?.let { async { it.task.block() } }
            val resultOrTask = getDataOrNextTask(currentTaskExecution)
            val result = resultOrTask.first
            if (result != null) {
                currentTaskData!!.future.complete(Result.success(result))
            } else {
                val nextTask = requireNotNull(resultOrTask.second)
                if (shouldInterruptTask(currentTaskData, nextTask)) {
                    val failure = TaskInterruptedException()
                    currentTaskData?.future?.complete(Result.failure(failure))
                    currentTaskExecution?.cancel(failure)
                    currentTaskChannel.trySend(nextTask)
                } else {
                    // Caveat: the running task is re-queued behind anything already waiting in
                    // currentTaskChannel, so a task queued earlier can start while it is still running.
                    if (currentTaskData != null) {
                        currentTaskChannel.trySend(currentTaskData.copy(taskJob = currentTaskExecution))
                    }
                    currentTaskChannel.trySend(nextTask)
                }
            }
        }
    }
}

private fun shouldInterruptTask(taskData: TaskData<R>?, nextTask: TaskData<R>?) =
    taskData?.priority == channels.lastIndex && nextTask?.priority == 0

fun submit(priority: Int, tag: String, block: suspend () -> R): Deferred<Result<R>> {
    val taskPair = TaskData(tag, priority, Task(block), CompletableDeferred())
    channels[priority].trySend(taskPair)
    return taskPair.future
}

private suspend fun getDataOrNextTask(data: Deferred<R>?): Pair<R?, TaskData<R>?> = select {
    data?.onAwait?.invoke { Pair(it, null) }
    for (channel in channels) {
        channel.onReceive { Pair(null, it) }
    }
}
```
Maybe there’s a way to simplify it further, but I think this is better!