# coroutines
g
Hello, I've created this `BatchProcessor<I,O>` class to process `(I) -> O` in batches of `List<I>` to `List<O>` (the batches are defined by a maximum number of `I` or a maximum duration to receive `List<I>`). The difficulty here is that I want to be able to retrieve the individual response (`val o = batching.process(i)`) even if the actual calculation is done in batches 👇
```kotlin
import kotlinx.coroutines.*
import kotlinx.coroutines.channels.Channel
import kotlinx.coroutines.flow.*
import kotlinx.coroutines.future.future
import kotlinx.coroutines.sync.Mutex
import kotlinx.coroutines.sync.withLock
import java.io.Closeable
import java.util.concurrent.CompletableFuture

class BatchProcessor<I,O>(
    val maxMessage: Int,
    val maxDuration: Long,
    val processor: (List<I>) -> List<O>
) : Closeable {
    private val scope = CoroutineScope(Dispatchers.IO)
    private val outputChannel = MutableSharedFlow<List<Pair<I, O>>>()
    private val inputChannel = Channel<I>()

    override fun close() = scope.cancel()

    // start listening and processing inputChannel by batch
    fun startAsync() = scope.future {
        inputChannel
            .receiveAsFlow()
            .collectBatch(maxMessage, maxDuration) { messages ->
                outputChannel.emit(
                    messages.zip(processor(messages))
                )
            }
    }

    fun process(message: I): O = processAsync(message).get()

    fun processAsync(message: I): CompletableFuture<O> = scope.future {
        // Launch a coroutine to send the message to the inputChannel
        launch { inputChannel.send(message) }

        // Retrieve the response from the outputChannel corresponding to the sent message
        outputChannel
            .first { it.any { it.first == message } } // Wait for the first element in the outputChannel that contains the message
            .first { it.first == message }            // Get the first pair where the first element matches the sent message
            .second                                   // Return the "second" part of the pair
    }

    private suspend fun <M> Flow<M>.collectBatch(
        maxMessage: Int,       // Maximum number of messages to collect before processing the batch
        maxDuration: Long,     // Maximum duration (in milliseconds) before processing the batch
        action: suspend (List<M>) -> Unit  // Action to execute on each batch of messages
    ) = coroutineScope {
        // Preconditions check
        require(maxMessage > 1) { "batch size must be > 1" }
        require(maxDuration > 1) { "batch duration must be > 1" }

        // Initialize buffer and mutex for synchronization
        val buffer = mutableListOf<M>()
        val bufferMutex = Mutex()
        lateinit var timeoutJob: Job

        // Function to start the timeout job
        fun CoroutineScope.startTimeoutJob() = launch {
            try {
                delay(maxDuration)
                // If we reach the timeout before reaching the batch size
                val batch = bufferMutex.withLock {
                    ArrayList(buffer).also { buffer.clear() }
                }
                // If batch is not empty, execute the action on the batch
                if (batch.isNotEmpty()) {
                    action(batch)
                }
            } catch (e: CancellationException) {
                // Do nothing if the job is cancelled
            }
        }

        try {
            // Collect elements from the flow
            collect { value ->
                var batch: List<M>? = null
                bufferMutex.withLock {
                    // Add the current element to the buffer
                    buffer.add(value)
                    when (buffer.size) {
                        // After the first element, start the timeout job
                        1 -> timeoutJob = startTimeoutJob()
                        // If we reach the batch size before the timeout
                        maxMessage -> {
                            timeoutJob.cancel()
                            batch = ArrayList(buffer)
                            buffer.clear()
                        }

                        else -> null
                    }
                }
                // If a batch is ready, execute the action on it
                batch?.let { action(it) }
            }
        } catch (e: CancellationException) {
            // Do nothing in case of cancellation
        } finally {
            // Cancel the timeout job
            timeoutJob.cancel()
        }
    }
}
```
Here is an example of usage:
```kotlin
import kotlin.random.Random
import java.util.concurrent.atomic.AtomicInteger

fun main() {
    val processor: (List<Int>) -> List<String> = { list -> list.map { it.toString() }}

    val i = AtomicInteger(0)
    BatchProcessor(maxMessage = 100, maxDuration = 1000, processor).use { batching ->
        batching.startAsync()

        repeat(1000) {
            val futures = List(200) {
                batching.processAsync(i.incrementAndGet())
            }
            futures.forEach { println(it.get()) }
        }

        println("done")
    }
}
```
Note: `BatchProcessor` is used in a Java context, so I cannot use suspend methods. This seems to work well, but for some values of the loops above, the main function gets stuck after a few seconds. This does not seem to happen when `outputChannel` has a replay parameter (e.g. `val outputChannel = MutableSharedFlow<List<Pair<I, O>>>(10)`). Any help to understand or improve this code would be greatly appreciated!
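One plausible contributor to the hang, offered as a hypothesis rather than a confirmed diagnosis: `MutableSharedFlow()` defaults to `replay = 0`, and a value emitted while no collector is subscribed is simply dropped. In `processAsync`, the launched `send` races with the `first { ... }` subscription, so a batch result could be emitted before its caller starts collecting. The sketch below isolates that behaviour with a hypothetical `lateSubscriberSees` helper; a replay cache keeps the value for a late subscriber, which would be consistent with the observation about `MutableSharedFlow(10)`.

```kotlin
import kotlinx.coroutines.flow.MutableSharedFlow
import kotlinx.coroutines.flow.first
import kotlinx.coroutines.runBlocking
import kotlinx.coroutines.withTimeoutOrNull

// Hypothetical helper: emit one value before anyone collects, then check whether
// a subscriber that arrives afterwards still receives it.
fun lateSubscriberSees(replay: Int): Boolean = runBlocking {
    val flow = MutableSharedFlow<Int>(replay = replay)
    flow.emit(42) // no subscribers yet: with replay = 0 the value is dropped
    // Without a timeout, first() would suspend forever when the value was dropped.
    withTimeoutOrNull(100) { flow.first() } != null
}

fun main() {
    println(lateSubscriberSees(replay = 0)) // false: the late subscriber only sees the timeout
    println(lateSubscriberSees(replay = 1)) // true: the value is served from the replay cache
}
```

Note that a replay cache of 10 only helps while the matching result is among the last 10 emissions, so it narrows the race rather than removing it.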
s
Instead of an output channel, I'd create the `CompletableFuture` right away and put it into the input channel alongside the input. Much easier to keep track of everything that way!
Here's my attempt (untested)
```kotlin
class BatchProcessor<I, O>(
  val maxMessage: Int,
  val maxDuration: Long,
  val processor: (List<I>) -> List<O>
) : Closeable {
  private val scope = CoroutineScope(Dispatchers.IO)
  override fun close() = scope.cancel()

  private inner class Task(val input: I, val result: CompletableDeferred<O> = CompletableDeferred())

  private val tasks = Channel<Task>()

  fun processAsync(message: I): CompletableFuture<O> = scope.future {
    Task(message).also { tasks.send(it) }.result.await()
  }

  fun startAsync() = scope.launch {
    while (isActive) {
      nextBatch().process()
    }
  }

  private fun List<Task>.process() {
    val results = processor(map { it.input })
    zip(results).forEach { (task, output) ->
      task.result.complete(output)
    }
  }

  private suspend fun nextBatch(): List<Task> = buildList {
    while (isEmpty()) withTimeoutOrNull(maxDuration) {
      while (size < maxMessage && isActive) {
        for (task in tasks) add(task)
      }
    }
  }
}
```
g
It's an impressive improvement. I've had to change `nextBatch` a bit to start the timeout after the first message of the batch:
```kotlin
private suspend fun nextBatch(): List<Task> = buildList {
    add(tasks.receive())
    withTimeoutOrNull(maxDuration) {
        for (task in tasks) {
            add(task)
            if (size == maxMessage) break
        }
    }
}
```
But I still have an issue (with your version or mine) that I don't understand. When I do:
```kotlin
suspend fun main() {
    val processor: (List<Int>) -> List<String> = { list -> list.map { it.toString() }}

    val i = AtomicInteger(0)
    BatchProcessor(maxMessage = 100, maxDuration = 1000, processor).use { batching ->
        batching.startAsync()

        val futures = List(30) {
            delay(100)
            batching.processAsync(i.incrementAndGet())
        }
        futures.forEach { println(it.get()) }

        println("done")
    }
}
```
The last `get()` never returns. It works when I remove the `delay`???
s
😄 Your `nextBatch()` function makes much more sense than mine, nice improvements/fixes 👍.
As for your `delay()` issue, I wonder if it's related to the fact that you're using a suspending `main()` function? You're kind of mixing suspending and blocking worlds by doing that. And calling `delay()` in `suspend fun main()` is actually going to switch you onto a different thread, for funky reasons.
Since you're using `get()`, you're best off testing this in non-coroutine land and leaving suspending functions out of it on the consumer side. Either that, or switch from `get()` to `await()`.
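As a rough illustration of the `await()` route, here is a minimal sketch that leaves `BatchProcessor` out. The hypothetical `timedFuture` helper stands in for `processAsync`: it returns a `CompletableFuture` that only completes once a `withTimeoutOrNull` timer fires. `await()` is the `CompletionStage` extension from `kotlinx.coroutines.future`, so the caller suspends instead of blocking whatever thread `delay()` resumed it on.

```kotlin
import kotlinx.coroutines.CoroutineScope
import kotlinx.coroutines.Dispatchers
import kotlinx.coroutines.awaitCancellation
import kotlinx.coroutines.cancel
import kotlinx.coroutines.delay
import kotlinx.coroutines.future.await
import kotlinx.coroutines.future.future
import kotlinx.coroutines.withTimeoutOrNull
import java.util.concurrent.CompletableFuture

// Hypothetical stand-in for processAsync(): the returned future only completes
// after the withTimeoutOrNull timer fires (awaitCancellation() never returns on its own).
fun timedFuture(scope: CoroutineScope, timeoutMs: Long): CompletableFuture<String> =
    scope.future {
        withTimeoutOrNull<String>(timeoutMs) { awaitCancellation() } ?: "timed out"
    }

suspend fun main() {
    val scope = CoroutineScope(Dispatchers.IO)
    delay(100) // from here on, main may be running on the coroutines timer thread
    // await() suspends; unlike get(), it does not block the thread main is running on.
    println(timedFuture(scope, 200).await())
    scope.cancel()
}
```

Swapping `.await()` for `.get()` in `main` would, per the explanation later in this thread, be expected to hang in the same way as the `BatchProcessor` test.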
g
You are right: it works when I use `it.await()`, OR when I use a non-suspend main function and replace `delay(100)` with `Thread.sleep(100)`. But I still do not understand why, as the scopes are different.
For the record, here is the final code:
```kotlin
import kotlinx.coroutines.*
import kotlinx.coroutines.channels.Channel
import kotlinx.coroutines.future.future
import java.io.Closeable
import java.util.concurrent.CompletableFuture

class BatchProcessor<I, O>(
    val maxMessage: Int,
    val maxDuration: Long,
    val processor: (List<I>) -> List<O>
) : Closeable {
    private val scope = CoroutineScope(Dispatchers.IO)
    override fun close() = scope.cancel()

    // Each input travels through the channel together with the deferred that will hold its output
    private inner class Task(val input: I, val output: CompletableDeferred<O> = CompletableDeferred())

    private val tasks = Channel<Task>()

    fun processAsync(message: I): CompletableFuture<O> = scope.future {
        Task(message).also { tasks.send(it) }.output.await()
    }

    fun process(message: I): O = processAsync(message).get()

    fun startAsync() = scope.future {
        while (isActive) {
            nextBatch().process()
        }
    }

    private fun List<Task>.process() {
        val results = processor(map { it.input })
        zip(results).forEach { (task, output) ->
            task.output.complete(output)
        }
    }

    // Wait for the first task, then collect more until maxMessage or maxDuration is reached
    private suspend fun nextBatch(): List<Task> = buildList {
        add(tasks.receive())
        withTimeoutOrNull(maxDuration) {
            while (size < maxMessage) {
                add(tasks.receive())
            }
        }
    }
}
```
Thank you @Sam!
s
Glad to help! I'm still a bit stumped by the deadlock issue; it feels like there should be a simple explanation, but I haven't spotted it yet.
g
Yeah, it's really strange. It seems the last `nextBatch` never ends.
s
Okay, so I figured it out.
In `suspend fun main()`, when you call `delay()`, it schedules a continuation using a timer thread. After that, the remainder of your main function runs entirely on that timer thread. Turns out, the same timer thread is also used internally by the `withTimeoutOrNull` function.
So the blocking calls in the main function end up blocking the timer thread and preventing the timeout from resuming. Nasty.
The basic fix is the same as always: don't block threads that belong to coroutines, unless using `Dispatchers.IO`. But the mechanism behind it sure is an interesting one! I think it could only really happen when using `suspend fun main()`. Maybe also with `Dispatchers.Unconfined`.
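A small sketch of the mechanism above, again without `BatchProcessor`. `runLikeSuspendMain` is a hypothetical helper that approximates how a `suspend fun main()` is started on the JVM: the block runs with an empty coroutine context (no dispatcher) while the calling thread waits for it. After `delay()`, the block resumes on the thread that serviced the delay (g reports `kotlinx.coroutines.DefaultExecutor` below). Following the "unless using `Dispatchers.IO`" advice, the blocking `get()` is moved onto `Dispatchers.IO`, which should leave the timer thread free to fire the `withTimeoutOrNull` timeout.

```kotlin
import kotlinx.coroutines.CoroutineScope
import kotlinx.coroutines.Dispatchers
import kotlinx.coroutines.awaitCancellation
import kotlinx.coroutines.cancel
import kotlinx.coroutines.delay
import kotlinx.coroutines.future.future
import kotlinx.coroutines.withContext
import kotlinx.coroutines.withTimeoutOrNull
import java.util.concurrent.CountDownLatch
import kotlin.coroutines.Continuation
import kotlin.coroutines.EmptyCoroutineContext
import kotlin.coroutines.startCoroutine

// Hypothetical helper approximating `suspend fun main()`: no dispatcher in the context,
// and the calling thread blocks until the coroutine completes.
fun <T> runLikeSuspendMain(block: suspend () -> T): T {
    val done = CountDownLatch(1)
    var outcome: Result<T>? = null
    block.startCoroutine(Continuation<T>(EmptyCoroutineContext) { result ->
        outcome = result
        done.countDown()
    })
    done.await()
    return outcome!!.getOrThrow()
}

// Returns the thread that resumed after delay() and the value of a timeout-driven future.
fun demo(): Pair<String, String> {
    val scope = CoroutineScope(Dispatchers.IO)
    try {
        return runLikeSuspendMain {
            delay(10)
            val resumedOn = Thread.currentThread().name
            val future = scope.future {
                withTimeoutOrNull<String>(100) { awaitCancellation() } ?: "timed out"
            }
            // Calling future.get() right here could block the thread that fires the timeout.
            // Running the blocking call on Dispatchers.IO leaves that thread free.
            resumedOn to withContext(Dispatchers.IO) { future.get() }
        }
    } finally {
        scope.cancel()
    }
}

fun main() {
    val (resumedOn, value) = demo()
    println("resumed on: $resumedOn")
    println(value)
}
```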
g
After the `delay` in main, the thread is indeed `kotlinx.coroutines.DefaultExecutor`. And yes, it seems that this thread is used internally by `withTimeoutOrNull`.
So I guess this thread is blocked by the `get()`, and `withTimeoutOrNull` does not complete, as you said. It's tricky.
Actually, there is still a tricky issue here:
```kotlin
private suspend fun nextBatch(): List<Task> = buildList {
    add(tasks.receive())
    withTimeoutOrNull(maxDuration) {
        while (size < maxMessage) {
            add(tasks.receive())
        }
    }
}
```
It can happen that `withTimeoutOrNull` triggers at the same time as `tasks.receive()`, in which case the corresponding task from the channel will be skipped...
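One common way to avoid losing an element here, sketched as an assumption rather than something tested in this thread, is to swap `withTimeoutOrNull` for a `select` that races `onReceive` against `onTimeout`: exactly one clause is selected, so an element is not taken from the channel and then discarded by the timeout. `receiveBatch` is a hypothetical extension; `onTimeout` is marked `@ExperimentalCoroutinesApi` in kotlinx.coroutines, hence the opt-in.

```kotlin
import kotlinx.coroutines.ExperimentalCoroutinesApi
import kotlinx.coroutines.channels.Channel
import kotlinx.coroutines.channels.ReceiveChannel
import kotlinx.coroutines.runBlocking
import kotlinx.coroutines.selects.onTimeout
import kotlinx.coroutines.selects.select

// Hypothetical replacement for nextBatch(): waits (without a timeout) for the first element,
// then collects up to maxSize elements or until maxDurationMs has elapsed since the first one.
@OptIn(ExperimentalCoroutinesApi::class)
suspend fun <T : Any> ReceiveChannel<T>.receiveBatch(maxSize: Int, maxDurationMs: Long): List<T> {
    val batch = mutableListOf(receive())
    val deadline = System.nanoTime() + maxDurationMs * 1_000_000
    while (batch.size < maxSize) {
        val remainingMs = (deadline - System.nanoTime()) / 1_000_000
        if (remainingMs <= 0) break
        // Exactly one clause is selected: either an element is received,
        // or the timeout wins and the element stays in the channel.
        val next = select<T?> {
            this@receiveBatch.onReceive { it }
            onTimeout(remainingMs) { null }
        } ?: break
        batch += next
    }
    return batch
}

fun main() = runBlocking {
    val channel = Channel<Int>(Channel.UNLIMITED)
    repeat(5) { channel.send(it) }
    println(channel.receiveBatch(maxSize = 3, maxDurationMs = 200)) // [0, 1, 2]
    println(channel.receiveBatch(maxSize = 3, maxDurationMs = 200)) // [3, 4], after the 200 ms timeout
}
```

In `BatchProcessor`, `nextBatch()` could then delegate to `tasks.receiveBatch(maxMessage, maxDuration)`; as in g's version, the timeout is measured from the first task of the batch.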