# spring
t
Hello, I have something that I cannot wrap my head around when it comes to observability and coroutines (in a Spring Boot application). In particular, we want to propagate tracing information through SQS, so we save the trace information within the message on the producer side, and on the consumer side we restore the trace before processing the message. Pseudocode (no worries, more code in thread):
// producer thread, saves its trace information within the message
producer.produce(message.withTraceDetails())

// consumer thread, actually even another application instance
// up to this consumer.read(), it's an independent trace
val message = consumer.read()
processAfterRestoring(message.traceDetails()) {
  // here the trace is restored from the message
  process(message)
}

fun processAfterRestoring(sourceTrace: Trace, supplier: () -> Any)
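As a rough illustration of the "save on the producer, restore on the consumer" idea in the pseudocode, the trace details can travel as plain string attributes on the message. This is a minimal sketch in plain Kotlin; `Message`, `TraceDetails`, and the attribute names `traceId` and `spanId` are hypothetical stand-ins, not the SQS SDK or Micrometer types:

```kotlin
// Hypothetical types for illustration only
data class TraceDetails(val traceId: String, val spanId: String)

data class Message(val body: String, val attributes: Map<String, String> = emptyMap())

// Producer side: copy the active trace identifiers into the message attributes
fun Message.withTraceDetails(trace: TraceDetails): Message =
    copy(attributes = attributes + mapOf("traceId" to trace.traceId, "spanId" to trace.spanId))

// Consumer side: read them back, or return null when the producer did not send any
fun Message.traceDetails(): TraceDetails? {
    val traceId = attributes["traceId"] ?: return null
    val spanId = attributes["spanId"] ?: return null
    return TraceDetails(traceId, spanId)
}

fun main() {
    val produced = Message("payload")
        .withTraceDetails(TraceDetails("4bf92f3577b34da6a3ce929d0e0e4736", "00f067aa0ba902b7"))
    // The consumer recovers the same identifiers the producer attached
    println(produced.traceDetails())
}
```

Returning `null` for a message without trace attributes lets the consumer decide whether to start a fresh trace instead of failing.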
The code we have works in normal scenarios, but not when `process(message)` uses coroutines, e.g.:
// here we have TraceONE
val message = consumer.read()
processAfterRestoring(message.traceDetails()) {
  // here the trace is restored from the message
  process(message)
}

fun process(message: Message) {
  // here the trace is restored from the message
  launchNewCoroutine {
    // but here the trace is again TraceONE
    doSomethingOn(message)
  }
}

fun processAfterRestoring(sourceTrace: Trace, supplier: () -> Any)
I reproduce some more code (in thread, hopefully I captured all the relevant pieces). I guess the 3 questions that pop up for me are:
• are we restoring traces incorrectly (not Kotlin)?
• is there some issue with our Spring context config (also not Kotlin)?
• what magic am I not seeing in `observationRegistry.asContextElement()` that causes the issue? Since restoring works when we don't use coroutines, I am kinda discounting the first 2 questions and focusing more on this last one.
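For the third question, one hypothesis (an assumption, not verified against the Micrometer sources here) is that the context element carries only the current Observation into the coroutine, so a span that was set directly through the span accessor is replaced by the span belonging to that Observation. The toy model below shows that effect with plain thread-locals and a `ThreadContextElement` from kotlinx.coroutines; `currentObservation`, `currentTrace`, and `ObservationOnlyElement` are hypothetical names, not Micrometer APIs:

```kotlin
import kotlinx.coroutines.Dispatchers
import kotlinx.coroutines.ThreadContextElement
import kotlinx.coroutines.runBlocking
import kotlinx.coroutines.withContext
import kotlin.coroutines.CoroutineContext

// Hypothetical stand-ins for "current observation" and "current trace"; not Micrometer types
val currentObservation = ThreadLocal<String?>()
val currentTrace = ThreadLocal<String?>()

// Captures ONLY the observation when it is created. Each time the coroutine runs on a
// thread, it reinstalls that observation and, with it, the trace the observation owns.
class ObservationOnlyElement(
    private val observation: String? = currentObservation.get()
) : ThreadContextElement<Pair<String?, String?>> {
    companion object Key : CoroutineContext.Key<ObservationOnlyElement>

    override val key: CoroutineContext.Key<*> get() = Key

    override fun updateThreadContext(context: CoroutineContext): Pair<String?, String?> {
        val previous = currentObservation.get() to currentTrace.get()
        currentObservation.set(observation)
        currentTrace.set(observation?.let { "trace-of-$it" })
        return previous
    }

    override fun restoreThreadContext(context: CoroutineContext, oldState: Pair<String?, String?>) {
        currentObservation.set(oldState.first)
        currentTrace.set(oldState.second)
    }
}

// Returns (trace seen before switching context, trace seen inside the coroutine)
fun demo(): Pair<String?, String?> = runBlocking {
    currentObservation.set("http-request")
    currentTrace.set("trace-of-http-request") // TraceONE, owned by the HTTP observation
    currentTrace.set("trace-from-message")    // restore the trace only; the observation is untouched

    val before = currentTrace.get()
    val inside = withContext(Dispatchers.Default + ObservationOnlyElement()) { currentTrace.get() }
    before to inside
}

fun main() {
    println(demo())
}
```

In this model `demo()` returns `trace-from-message` for the code before the coroutine and `trace-of-http-request` inside it, which mirrors the symptom described above. If the hypothesis holds, restoring the trace by opening an Observation (rather than only setting a span) would be the thing to try, but that is a guess to be checked against the real code.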
reproduction code
import io.kotest.matchers.nulls.shouldNotBeNull
import io.kotest.matchers.shouldBe
import io.kotest.matchers.shouldNotBe
import io.micrometer.context.ContextExecutorService
import io.micrometer.context.ContextRegistry
import io.micrometer.context.ContextSnapshotFactory
import io.micrometer.context.ThreadLocalAccessor
import io.micrometer.core.instrument.kotlin.asContextElement
import io.micrometer.observation.ObservationRegistry
import io.micrometer.tracing.Span
import io.micrometer.tracing.Tracer
import io.micrometer.tracing.contextpropagation.ObservationAwareSpanThreadLocalAccessor
import kotlinx.coroutines.CoroutineScope
import kotlinx.coroutines.asCoroutineDispatcher
import kotlinx.coroutines.async
import kotlinx.coroutines.runBlocking
import org.junit.jupiter.api.Test
import org.slf4j.LoggerFactory
import org.slf4j.MDC
import org.springframework.beans.factory.annotation.Value
import org.springframework.boot.test.context.SpringBootTest
import org.springframework.boot.test.web.client.TestRestTemplate
import org.springframework.boot.test.web.client.getForObject
import org.springframework.boot.test.web.server.LocalServerPort
import org.springframework.context.annotation.Bean
import org.springframework.context.annotation.Configuration
import org.springframework.stereotype.Component
import org.springframework.test.context.TestConstructor
import org.springframework.web.bind.annotation.GetMapping
import org.springframework.web.bind.annotation.PathVariable
import org.springframework.web.bind.annotation.RequestMapping
import org.springframework.web.bind.annotation.RestController
import java.util.concurrent.CompletableFuture
import java.util.concurrent.ExecutorService
import java.util.concurrent.Executors
import java.util.concurrent.ThreadPoolExecutor
import kotlin.use

@SpringBootTest(
    webEnvironment = SpringBootTest.WebEnvironment.RANDOM_PORT,
)
@TestConstructor(autowireMode = TestConstructor.AutowireMode.ALL)
class ObservationIT(
    @LocalServerPort private val port: Int = 0,
    @Value("\${server.servlet.context-path}") private val contextPath: String,
    private val testRestTemplate: TestRestTemplate
) {
    private val baseUrl = "http://localhost:$port$contextPath"

    @Test
    fun `coroutine vs java completable future`() {
        val traceId = testRestTemplate.getForObject<String>("$baseUrl/api/v1/observations").shouldNotBeNull()

        // Conventions:
        // - runningXYZ: this is the trace that received the HTTP call
        // - restoredXYZ: this is the trace the code is executed in (after it has been restored)
        // - restoredBeforeCoroutine: special case, this is the trace AFTER it has been restored, but before execution is delegated to the coroutine
        // - restoredInCoroutine: special case, this is the trace within the coroutine code
        val (restoredCoroutine, runningCoroutine) = testRestTemplate.getForObject<Pair<String, String>>("$baseUrl/api/v1/observations/consume-coroutine/$traceId").shouldNotBeNull()
        val (restoredBeforeCoroutine, restoredInCoroutine) = testRestTemplate.getForObject<Pair<String, String>>("$baseUrl/api/v1/observations/consume-coroutine-weirdness/$traceId").shouldNotBeNull()
        val (restoredCompletable, runningCompletable) = testRestTemplate.getForObject<Pair<String, String>>("$baseUrl/api/v1/observations/consume-future/$traceId").shouldNotBeNull()

        // this is expected, since the REST interface is the boundary for a new observation
        traceId.traceId()
            .shouldNotBe(runningCoroutine.traceId())
            .shouldNotBe(runningCompletable.traceId())

        traceId.traceId()
            .shouldBe(restoredBeforeCoroutine.traceId())
            // actually, we expected all 3 trace IDs to match, but `observationRegistry.asContextElement()` switched the coroutine back to the trace received by the controller
            .shouldNotBe(restoredInCoroutine.traceId())

        traceId.traceId()
            .shouldBe(restoredCompletable.traceId())
            // same here: we expected a match, but with `observationRegistry.asContextElement()` the coroutine ran in the trace received by the controller
            .shouldNotBe(restoredCoroutine.traceId())
    }

    private fun String.traceId() = this.split("/")[0]
    private fun String.spanId() = this.split("/")[1]
}

@RestController
@RequestMapping("/api/v1/observations")
class ObservationControllerTest(
    private val observation: CustomObservation,
    private val observationRegistry: ObservationRegistry,
    private val eventsExecutorService: ExecutorService,
) {
    // Generates a trace to be restored afterwards
    @GetMapping
    fun getTraceToRestore() = "${context()?.traceId()}/${context()?.spanId()}"

    @GetMapping("/consume-future/{traceId}/{spanId}")
    fun consumeFutureInRestoredTrace(@PathVariable traceId: String, @PathVariable spanId: String): Pair<String, String> =
        // new trace
        observation.consumeRestoringCorrelation(traceId, spanId, "restored") {
            // trace is restored from path params
            CompletableFuture.supplyAsync({
                // trace is still the restored one
                fullTrace()
            }, eventsExecutorService).join()
        } to fullTrace()

    @GetMapping("/consume-coroutine/{traceId}/{spanId}")
    fun consumeCoroutineInRestoredTrace(@PathVariable traceId: String, @PathVariable spanId: String): Pair<String, String> =
        // new trace (TraceOne)
        observation.consumeRestoringCorrelation(traceId, spanId, "restored") {
            // trace is restored from path params
            runBlocking {
                // trace is back to TraceOne
                CoroutineScope(eventsExecutorService.asCoroutineDispatcher()).async(observationRegistry.asContextElement()) {
                    fullTrace()
                }.await()
            }
        } to fullTrace()

    @GetMapping("/consume-coroutine-weirdness/{traceId}/{spanId}")
    fun consumeCoroutineInRestoredTraceWeirdness(@PathVariable traceId: String, @PathVariable spanId: String): Pair<String, String> =
        observation.consumeRestoringCorrelation(traceId, spanId, "restored") {
            fullTrace() to
            runBlocking {
                CoroutineScope(eventsExecutorService.asCoroutineDispatcher()).async(observationRegistry.asContextElement()) {
                    fullTrace()
                }.await()
            }
        }

    private fun fullTrace() = "${traceId()}/${spanId()}"

    private fun context() = observationRegistry.currentSpan()?.context()
}

@Configuration
class ObservationConfig {
    @Bean
    fun eventsExecutorService(): ExecutorService =
        ContextExecutorService.wrap(
            Executors.newVirtualThreadPerTaskExecutor(), 
            ContextSnapshotFactory.builder().contextRegistry(ContextRegistry.getInstance()).build()
        )

    @Bean
    fun contextAccessor(tracer: Tracer, observationRegistry: ObservationRegistry) =
        ObservationAwareSpanThreadLocalAccessor(observationRegistry, tracer).also {
            ContextRegistry.getInstance().registerThreadLocalAccessor(it)
        }
}

@Component
class CustomObservation(
    private val observationRegistry: ObservationRegistry,
    private val tracer: Tracer,
    private val spanAccessor: ThreadLocalAccessor<Span>,
) {
    private val contextSnapshotFactory = ContextSnapshotFactory.builder().contextRegistry(ContextRegistry.getInstance()).build()

    fun <T> consumeRestoringCorrelation(traceId: String?, spanId: String?, contextName: String, callable: () -> T): T =
        executeIn(traceId, spanId, Span.Kind.CONSUMER, contextName, callable)

    private fun <T> executeIn(
        traceId: String?,
        spanId: String?,
        kind: Span.Kind,
        contextName: String,
        callable: () -> T
    ): T =
        if (traceId == null || spanId == null) logger.warn("Missing trace details, creating a new one").let { callable() }
        else executeIn(TraceContext(traceId, spanId, kind, contextName), callable)

    private fun <T> executeIn(context: TraceContext, callable: () -> T): T =
        try {
            spanAccessor.setValue(buildSpan(context).start())
            contextSnapshotFactory.captureAll().setThreadLocals().use {
                callable()
            }
        } finally {
            spanAccessor.restore()
        }

    private fun buildSpan(context: TraceContext): Span.Builder {
        val traceContext = tracer.traceContextBuilder().traceId(context.traceId).spanId(context.spanId).build()
        return tracer.spanBuilder().setParent(traceContext).kind(context.kind).name(context.name)
    }

    private data class TraceContext(
        val traceId: String, val spanId: String, val kind: Span.Kind, val name: String
    )

    companion object {
        private val logger = LoggerFactory.getLogger(CustomObservation::class.java)
    }
}

fun traceId(): String? = MDC.get("traceId")
fun spanId(): String? = MDC.get("spanId")
e
Following. I tried to do something similar, but propagating the trace across Coroutine-channels. Sadly I haven’t found a solution 😕
t
btw, propagating traces to the coroutine works (for us) in the simple case
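// suspend is needed because await() can only be called from a coroutine or another suspend function
suspend fun doSomething() {
    // a trace is active here
    CoroutineScope(eventsExecutorService.asCoroutineDispatcher())
        .async(observationRegistry.asContextElement()) {
            // the same trace continues here
        }
        .await()
}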
but honestly we did not try more complex scenarios, and definitely not using channels or reactive endpoints