Gilles Barbier
10/01/2024, 11:54 AM
BatchProcessor<I,O> class to process (I) -> O in batches of List<I> to List<O> (the batches are defined by a maximum number of I or a maximum duration to receive List<I>). The difficulty here is that I want to be able to retrieve the individual response (val o = batching.process(i)) even if the actual calculation is done by batches 👇
Gilles Barbier
10/01/2024, 12:03 PM
import kotlinx.coroutines.*
import kotlinx.coroutines.channels.Channel
import kotlinx.coroutines.flow.*
import kotlinx.coroutines.future.future
import kotlinx.coroutines.sync.Mutex
import kotlinx.coroutines.sync.withLock
import java.io.Closeable
import java.util.concurrent.CompletableFuture
class BatchProcessor<I,O>(
    val maxMessage: Int,
    val maxDuration: Long,
    val processor: (List<I>) -> List<O>
) : Closeable {
    private val scope = CoroutineScope(Dispatchers.IO)
    private val outputChannel = MutableSharedFlow<List<Pair<I, O>>>()
    private val inputChannel = Channel<I>()

    override fun close() = scope.cancel()

    // start listening and processing inputChannel by batch
    fun startAsync() = scope.future {
        inputChannel
            .receiveAsFlow()
            .collectBatch(maxMessage, maxDuration) { messages ->
                outputChannel.emit(
                    messages.zip(processor(messages))
                )
            }
    }

    fun process(message: I): O = processAsync(message).get()

    fun processAsync(message: I): CompletableFuture<O> = scope.future {
        // Launch a coroutine to send the message to the inputChannel
        launch { inputChannel.send(message) }
        // Retrieve the response from the outputChannel corresponding to the sent message
        outputChannel
            .first { it.any { it.first == message } } // Wait for the first element in the outputChannel that contains the message
            .first { it.first == message } // Get the first pair where the first element matches the sent message
            .second // Return the "second" part of the pair
    }

    private suspend fun <M> Flow<M>.collectBatch(
        maxMessage: Int, // Maximum number of messages to collect before processing the batch
        maxDuration: Long, // Maximum duration (in milliseconds) before processing the batch
        action: suspend (List<M>) -> Unit // Action to execute on each batch of messages
    ) = coroutineScope {
        // Preconditions check
        require(maxMessage > 1) { "batch size must be > 1" }
        require(maxDuration > 1) { "batch duration must be > 1" }
        // Initialize buffer and mutex for synchronization
        val buffer = mutableListOf<M>()
        val bufferMutex = Mutex()
        lateinit var timeoutJob: Job

        // Function to start the timeout job
        fun CoroutineScope.startTimeoutJob() = launch {
            try {
                delay(maxDuration)
                // If we reach the timeout before reaching the batch size
                val batch = bufferMutex.withLock {
                    ArrayList(buffer).also { buffer.clear() }
                }
                // If batch is not empty, execute the action on the batch
                if (batch.isNotEmpty()) {
                    action(batch)
                }
            } catch (e: CancellationException) {
                // Do nothing if the job is cancelled
            }
        }

        try {
            // Collect elements from the flow
            collect { value ->
                var batch: List<M>? = null
                bufferMutex.withLock {
                    // Add the current element to the buffer
                    buffer.add(value)
                    when (buffer.size) {
                        // After the first element, start the timeout job
                        1 -> timeoutJob = startTimeoutJob()
                        // If we reach the batch size before the timeout
                        maxMessage -> {
                            timeoutJob.cancel()
                            batch = ArrayList(buffer)
                            buffer.clear()
                        }
                        else -> null
                    }
                }
                // If a batch is ready, execute the action on it
                batch?.let { action(it) }
            }
        } catch (e: CancellationException) {
            // Do nothing in case of cancellation
        } finally {
            // Cancel the timeout job
            timeoutJob.cancel()
        }
    }
}
Here is an example of use:
import kotlin.random.Random
import java.util.concurrent.atomic.AtomicInteger
fun main() {
    val processor: (List<Int>) -> List<String> = { list -> list.map { it.toString() }}
    val i = AtomicInteger(0)
    BatchProcessor(maxMessage = 100, maxDuration = 1000, processor).use { batching ->
        batching.startAsync()
        repeat(1000) {
            val futures = List(200) {
                batching.processAsync(i.incrementAndGet())
            }
            futures.forEach { println(it.get()) }
        }
        println("done")
    }
}
Note: BatchProcessor is used in a Java context, so I cannot use suspend methods. This seems to work well, but for some values of the loops above the main function gets stuck after a few seconds. This does not seem to happen when outputChannel has some replay parameter (e.g. val outputChannel = MutableSharedFlow<List<Pair<I, O>>>(10)) - any help to understand or improve this code would be greatly appreciated!
Sam
10/01/2024, 12:58 PM
CompletableFuture right away, and put it into the input channel alongside the input. Much easier to keep track of everything that way!
Sam
10/01/2024, 12:58 PM
class BatchProcessor<I, O>(
    val maxMessage: Int,
    val maxDuration: Long,
    val processor: (List<I>) -> List<O>
) : Closeable {
    private val scope = CoroutineScope(Dispatchers.IO)

    override fun close() = scope.cancel()

    private inner class Task(val input: I, val result: CompletableDeferred<O> = CompletableDeferred())

    private val tasks = Channel<Task>()

    fun processAsync(message: I): CompletableFuture<O> = scope.future {
        Task(message).also { tasks.send(it) }.result.await()
    }

    fun startAsync() = scope.launch {
        while (isActive) {
            nextBatch().process()
        }
    }

    private fun List<Task>.process() {
        val results = processor(map { it.input })
        zip(results).forEach { (task, output) ->
            task.result.complete(output)
        }
    }

    private suspend fun nextBatch(): List<Task> = buildList {
        while (isEmpty()) withTimeoutOrNull(maxDuration) {
            while (size < maxMessage && isActive) {
                for (task in tasks) add(task)
            }
        }
    }
}
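A plausible reason the original SharedFlow-based version can hang, and why a replay value hides it: a MutableSharedFlow created with no replay does not keep values for subscribers that arrive late, so if a batch is emitted before processAsync has started collecting, that result is never seen. The snippet below is only a minimal sketch of that behaviour in isolation (lateSubscriberSees is a hypothetical helper, not part of the code above); it is not a diagnosis of the exact interleaving in this thread.

```kotlin
import kotlinx.coroutines.flow.MutableSharedFlow
import kotlinx.coroutines.flow.first
import kotlinx.coroutines.runBlocking
import kotlinx.coroutines.withTimeoutOrNull

// Hypothetical helper: emit one value BEFORE anyone subscribes, then try to read it.
fun lateSubscriberSees(replay: Int): Int? = runBlocking {
    val flow = MutableSharedFlow<Int>(replay)
    // With no subscribers, emit() returns immediately; the value only
    // survives if it fits in the replay cache.
    flow.emit(42)
    // Subscribe afterwards and give up after 200 ms instead of hanging forever.
    withTimeoutOrNull(200) { flow.first() }
}

fun main() {
    println(lateSubscriberSees(replay = 0)) // null: the emission was dropped
    println(lateSubscriberSees(replay = 1)) // 42: served from the replay cache
}
```

Handing each caller its own CompletableDeferred, as in the version above, avoids this because the result is stored on the task rather than broadcast to whoever happens to be subscribed.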
Gilles Barbier
10/01/2024, 2:31 PM
nextBatch to start the timeout after the first message of the batch:
private suspend fun nextBatch(): List<Task> = buildList {
    add(tasks.receive())
    withTimeoutOrNull(maxDuration) {
        for (task in tasks) {
            add(task)
            if (size == maxMessage) break
        }
    }
}
But I still have an issue (with your version or mine) that I fail to understand - when I do:
suspend fun main() {
    val processor: (List<Int>) -> List<String> = { list -> list.map { it.toString() }}
    val i = AtomicInteger(0)
    BatchProcessor(maxMessage = 100, maxDuration = 1000, processor).use { batching ->
        batching.startAsync()
        val futures = List(30) {
            delay(100)
            batching.processAsync(i.incrementAndGet())
        }
        futures.forEach { println(it.get()) }
        println("done")
    }
}
the last get() never comes back. It works when I remove the delay ???
Sam
10/01/2024, 2:35 PM
nextBatch() function makes much more sense than mine, nice improvements/fixes 👍.
Sam
10/01/2024, 2:36 PM
delay() issue, I wonder if it's related to the fact that you're using a suspending main() function? You're kind of mixing suspending/blocking worlds by doing that. And calling delay() in suspend fun main() is actually going to switch you onto a different thread, for funky reasons.
Sam
10/01/2024, 2:37 PM
get(), you're best off testing this in non-coroutine land, and leaving suspending functions out of it on the consumer side.
Sam
10/01/2024, 2:37 PM
get() to await()
Gilles Barbier
10/01/2024, 2:40 PM
it.await() OR use a non suspend main function, replacing delay(100) by Thread.sleep(100). But I still do not understand why, as the scopes are different
Gilles Barbier
10/01/2024, 2:47 PM
class BatchProcessor<I, O>(
    val maxMessage: Int,
    val maxDuration: Long,
    val processor: (List<I>) -> List<O>
) : Closeable {
    private val scope = CoroutineScope(Dispatchers.IO)

    override fun close() = scope.cancel()

    private inner class Task(val input: I, val output: CompletableDeferred<O> = CompletableDeferred())

    private val tasks = Channel<Task>()

    fun processAsync(message: I): CompletableFuture<O> = scope.future {
        Task(message).also { tasks.send(it) }.output.await()
    }

    fun process(message: I): O = processAsync(message).get()

    fun startAsync() = scope.future {
        while (isActive) {
            nextBatch().process()
        }
    }

    private fun List<Task>.process() {
        val results = processor(map { it.input })
        zip(results).forEach { (task, output) ->
            task.output.complete(output)
        }
    }

    private suspend fun nextBatch(): List<Task> = buildList {
        add(tasks.receive())
        withTimeoutOrNull(maxDuration) {
            while (size < maxMessage) {
                add(tasks.receive())
            }
        }
    }
}
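For the consumer side, the await() workaround mentioned earlier in the thread can be sketched as follows. fakeProcessAsync is a hypothetical stand-in for batching.processAsync so that the snippet runs by itself; the only point is that await() (from kotlinx.coroutines.future) suspends the caller instead of blocking whichever thread it happens to be running on.

```kotlin
import kotlinx.coroutines.delay
import kotlinx.coroutines.future.await
import java.util.concurrent.CompletableFuture

// Hypothetical stand-in for batching.processAsync(i): completes on another thread.
fun fakeProcessAsync(i: Int): CompletableFuture<String> =
    CompletableFuture.supplyAsync { i.toString() }

// Same shape as the failing main(): delay() between submissions,
// but results are awaited (suspending) rather than fetched with get() (blocking).
suspend fun collectResults(n: Int): List<String> {
    val futures = List(n) { index ->
        delay(10)
        fakeProcessAsync(index + 1)
    }
    return futures.map { it.await() }
}

suspend fun main() {
    println(collectResults(3)) // [1, 2, 3]
}
```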
Thank you @Sam!
Sam
10/01/2024, 2:54 PM
Gilles Barbier
10/01/2024, 2:57 PM
nextBatch never ends
Sam
10/01/2024, 3:03 PM
Sam
10/01/2024, 3:04 PM
suspend fun main(), when you call delay(), it schedules a continuation using a timer thread. After that, the remainder of your main function runs entirely on that timer thread. Turns out, the same timer thread is also used internally by the withTimeoutOrNull function.
Sam
10/01/2024, 3:04 PM
Sam
10/01/2024, 3:07 PM
suspend fun main(). Maybe also with Dispatchers.Unconfined.
Gilles Barbier
10/01/2024, 4:37 PM
delay in main, the thread is indeed kotlinx.coroutines.DefaultExecutor. And yes it seems that this thread is used internally by withTimeoutOrNull.
Gilles Barbier
10/01/2024, 4:38 PM
get() and withTimeoutOrNull does not complete. As you said. It's tricky.
Gilles Barbier
11/01/2024, 2:49 PM
private suspend fun nextBatch(): List<Task> = buildList {
    add(tasks.receive())
    withTimeoutOrNull(maxDuration) {
        while (size < maxMessage) {
            add(tasks.receive())
        }
    }
}
it can happen that withTimeoutOrNull triggers at the same time as tasks.receive(), and the corresponding task from the channel will be skipped...
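One way to sidestep that race, sketched under assumptions: instead of cancelling receive() from the outside with withTimeoutOrNull, use a select with both an onReceive clause and an onTimeout clause. select picks exactly one clause, so a task is either added to the batch or left in the channel for the next batch. The free function nextBatch and its parameters are illustrative only (not the class member above), and onTimeout is, as far as I know, still marked experimental in kotlinx.coroutines, hence the opt-in.

```kotlin
import kotlinx.coroutines.ExperimentalCoroutinesApi
import kotlinx.coroutines.channels.Channel
import kotlinx.coroutines.channels.ReceiveChannel
import kotlinx.coroutines.runBlocking
import kotlinx.coroutines.selects.*

@OptIn(ExperimentalCoroutinesApi::class)
suspend fun <T> nextBatch(tasks: ReceiveChannel<T>, maxMessage: Int, maxDuration: Long): List<T> {
    // Wait without a timeout for the first element; the clock starts after it.
    val batch = mutableListOf(tasks.receive())
    val deadline = System.nanoTime() + maxDuration * 1_000_000
    while (batch.size < maxMessage) {
        val remainingMs = (deadline - System.nanoTime()) / 1_000_000
        if (remainingMs <= 0) break
        // Exactly one clause runs: either an element is received and kept,
        // or the timeout fires and nothing is taken from the channel.
        val received = select<Boolean> {
            tasks.onReceive { batch.add(it); true }
            onTimeout(remainingMs) { false }
        }
        if (!received) break
    }
    return batch
}

fun main() = runBlocking {
    val tasks = Channel<Int>(Channel.UNLIMITED)
    repeat(5) { tasks.send(it) }
    println(nextBatch(tasks, maxMessage = 3, maxDuration = 100)) // [0, 1, 2]
    println(nextBatch(tasks, maxMessage = 3, maxDuration = 100)) // [3, 4]
}
```

Another option, if the withTimeoutOrNull shape is kept, would be to create the channel with an onUndeliveredElement handler so that a task lost to cancellation can at least be failed or re-queued rather than silently dropped.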