thanksforallthefish
04/02/2025, 8:08 AM
// producer thread, saves its trace information within the message
// Illustrative excerpt from the discussion, not compilable as-is.
producer.produce(message.withTraceDetails())
// consumer thread, actually even another application instance
// up to this consumer.read, it's an independent trace
val message = consumer.read()
processAfterRestoring(message.traceDetails()) {
    // here the trace is restored from the message
    process(message)
}
// Signature only (body not shown in this excerpt): runs `supplier` with `sourceTrace` restored.
fun processAfterRestoring(sourceTrace: Trace, supplier: () -> Any)
The code we have works in normal scenarios, but not when `process(message)` uses coroutines, e.g.:
// Illustrative excerpt from the discussion, not compilable as-is.
// here we have TraceONE
val message = consumer.read()
processAfterRestoring(message.traceDetails()) {
    // here the trace is restored from the message
    process(message)
}
fun process(message: Message) {
    // here the trace is still the one restored from the message
    // NOTE(review): `launchNewCoroutine` presumably stands for a coroutine builder such as launch/async.
    launchNewCoroutine {
        // but here the trace is again TraceONE
        doSomethingOn(message)
    }
}
// Signature only (body not shown in this excerpt).
fun processAfterRestoring(sourceTrace: Trace, supplier: () -> Any)
I'm reproducing more of the code in the thread (hopefully I captured all the relevant pieces).
I guess the three questions that come to mind are:
• Are we restoring traces incorrectly (not Kotlin-specific)?
• Is there some issue with our Spring context configuration (also not Kotlin-specific)?
• What magic am I not seeing in `observationRegistry.asContextElement()` that causes the issue? (Since restoring works when we don't use coroutines, I am largely discounting the first two questions and focusing on this last one.)
thanksforallthefish
04/02/2025, 8:09 AM
import io.kotest.matchers.nulls.shouldNotBeNull
import io.kotest.matchers.shouldBe
import io.kotest.matchers.shouldNotBe
import io.micrometer.context.ContextExecutorService
import io.micrometer.context.ContextRegistry
import io.micrometer.context.ContextSnapshotFactory
import io.micrometer.context.ThreadLocalAccessor
import io.micrometer.core.instrument.kotlin.asContextElement
import io.micrometer.observation.ObservationRegistry
import io.micrometer.tracing.Span
import io.micrometer.tracing.Tracer
import io.micrometer.tracing.contextpropagation.ObservationAwareSpanThreadLocalAccessor
import kotlinx.coroutines.CoroutineScope
import kotlinx.coroutines.asCoroutineDispatcher
import kotlinx.coroutines.async
import kotlinx.coroutines.runBlocking
import org.junit.jupiter.api.Test
import org.slf4j.LoggerFactory
import org.slf4j.MDC
import org.springframework.beans.factory.annotation.Value
import org.springframework.boot.test.context.SpringBootTest
import org.springframework.boot.test.web.client.TestRestTemplate
import org.springframework.boot.test.web.client.getForObject
import org.springframework.boot.test.web.server.LocalServerPort
import org.springframework.context.annotation.Bean
import org.springframework.context.annotation.Configuration
import org.springframework.stereotype.Component
import org.springframework.test.context.TestConstructor
import org.springframework.web.bind.annotation.GetMapping
import org.springframework.web.bind.annotation.PathVariable
import org.springframework.web.bind.annotation.RequestMapping
import org.springframework.web.bind.annotation.RestController
import java.util.concurrent.CompletableFuture
import java.util.concurrent.ExecutorService
import java.util.concurrent.Executors
import java.util.concurrent.ThreadPoolExecutor
import kotlin.use
/**
 * Contrasts trace propagation onto `eventsExecutorService` through `CompletableFuture.supplyAsync`
 * with propagation through a coroutine started with `observationRegistry.asContextElement()`.
 *
 * Every endpoint of `ObservationControllerTest` answers with `"traceId/spanId"` strings. The
 * assertions pin the behaviour observed today; where that differs from the desired behaviour,
 * the comment next to the assertion says so.
 */
@SpringBootTest(
    webEnvironment = SpringBootTest.WebEnvironment.RANDOM_PORT,
)
@TestConstructor(autowireMode = TestConstructor.AutowireMode.ALL)
class ObservationIT(
    @LocalServerPort private val port: Int = 0,
    @Value("\${server.servlet.context-path}") private val contextPath: String,
    private val testRestTemplate: TestRestTemplate,
) {
    // Plain URL: the previous "<http://localhost>" form (a chat-link artefact) is not a valid URL.
    private val baseUrl = "http://localhost:$port$contextPath"

    @Test
    fun `coroutine vs java completable future`() {
        // "traceId/spanId" of a fresh server-side trace; the endpoints below restore it from the path.
        val traceId = testRestTemplate.getForObject<String>("$baseUrl/api/v1/observations").shouldNotBeNull()
        // Conventions:
        // - runningXYZ: the trace that received the HTTP call
        // - restoredXYZ: the trace the code is executed in (after it has been restored)
        // - restoredBeforeCoroutine: the trace AFTER it has been restored, but before execution is delegated to the coroutine
        // - restoredInCoroutine: the trace within the coroutine code
        val (restoredCoroutine, runningCoroutine) = testRestTemplate
            .getForObject<Pair<String, String>>("$baseUrl/api/v1/observations/consume-coroutine/$traceId")
            .shouldNotBeNull()
        val (restoredBeforeCoroutine, restoredInCoroutine) = testRestTemplate
            .getForObject<Pair<String, String>>("$baseUrl/api/v1/observations/consume-coroutine-weirdness/$traceId")
            .shouldNotBeNull()
        val (restoredCompletable, runningCompletable) = testRestTemplate
            .getForObject<Pair<String, String>>("$baseUrl/api/v1/observations/consume-future/$traceId")
            .shouldNotBeNull()

        // Expected: each REST call is the boundary of a new observation, so it runs in its own trace.
        traceId.traceId()
            .shouldNotBe(runningCoroutine.traceId())
            .shouldNotBe(runningCompletable.traceId())
        traceId.traceId()
            .shouldBe(restoredBeforeCoroutine.traceId())
            // Desired: all three equal. Observed: inside the coroutine, `observationRegistry.asContextElement()`
            // brings back the trace that received the HTTP call instead of the restored one.
            .shouldNotBe(restoredInCoroutine.traceId())
        traceId.traceId()
            .shouldBe(restoredCompletable.traceId())
            // Desired: all three equal. Observed: inside the coroutine, `observationRegistry.asContextElement()`
            // brings back the trace that received the HTTP call instead of the restored one.
            .shouldNotBe(restoredCoroutine.traceId())
    }

    /** Trace-id part of a `"traceId/spanId"` string. */
    private fun String.traceId() = this.split("/")[0]

    /** Span-id part of a `"traceId/spanId"` string. */
    private fun String.spanId() = this.split("/")[1]
}
/**
 * Test-only endpoints used by `ObservationIT`.
 *
 * Each endpoint restores a trace from its path variables through [CustomObservation] and reports,
 * as `"traceId/spanId"` strings built by `fullTrace()` (which reads the MDC), which trace is
 * active at different points.
 */
@RestController
@RequestMapping("/api/v1/observations")
class ObservationControllerTest(
    private val observation: CustomObservation,
    private val observationRegistry: ObservationRegistry,
    private val eventsExecutorService: ExecutorService,
) {
    // Generates a trace to be restored afterwards: "traceId/spanId" of the current span
    // (each part renders as "null" when there is no current span).
    @GetMapping
    fun getTraceToRestore() = "${context()?.traceId()}/${context()?.spanId()}"

    // Returns (trace seen inside the executor task, trace of the HTTP request itself).
    // The second element is evaluated after consumeRestoringCorrelation has returned.
    @GetMapping("/consume-future/{traceId}/{spanId}")
    fun consumeFutureInRestoredTrace(@PathVariable traceId: String, @PathVariable spanId: String): Pair<String, String> =
        // new trace
        observation.consumeRestoringCorrelation(traceId, spanId, "restored") {
            // trace is restored from path params
            CompletableFuture.supplyAsync({
                // trace is still the restored one
                fullTrace()
            }, eventsExecutorService).join()
        } to fullTrace()

    // Returns (trace seen inside the coroutine, trace of the HTTP request itself).
    @GetMapping("/consume-coroutine/{traceId}/{spanId}")
    fun consumeCoroutineInRestoredTrace(@PathVariable traceId: String, @PathVariable spanId: String): Pair<String, String> =
        // new trace (TraceOne)
        observation.consumeRestoringCorrelation(traceId, spanId, "restored") {
            // trace is restored from path params
            runBlocking {
                // Inside the async block below, the trace is back to TraceOne (as asserted by ObservationIT).
                // NOTE(review): presumably asContextElement() re-applies the Observation that is current when the
                // element is created (the HTTP request's), rather than the manually restored span. Confirm against
                // Micrometer's KotlinObservationContextElement.
                CoroutineScope(eventsExecutorService.asCoroutineDispatcher()).async(observationRegistry.asContextElement()) {
                    fullTrace()
                }.await()
            }
        } to fullTrace()

    // Returns (trace just before the coroutine, trace inside the coroutine). Both are read while the
    // restored trace should be active.
    @GetMapping("/consume-coroutine-weirdness/{traceId}/{spanId}")
    fun consumeCoroutineInRestoredTraceWeirdness(@PathVariable traceId: String, @PathVariable spanId: String): Pair<String, String> =
        observation.consumeRestoringCorrelation(traceId, spanId, "restored") {
            fullTrace() to
                runBlocking {
                    CoroutineScope(eventsExecutorService.asCoroutineDispatcher()).async(observationRegistry.asContextElement()) {
                        fullTrace()
                    }.await()
                }
        }

    // "traceId/spanId" as currently present in the MDC.
    private fun fullTrace() = "${traceId()}/${spanId()}"

    // NOTE(review): `currentSpan()` is not among the visible imports; presumably an extension defined
    // elsewhere in the project. Confirm.
    private fun context() = observationRegistry.currentSpan()?.context()
}
/**
 * Beans wiring thread-local context propagation for the observation tests.
 */
@Configuration
class ObservationConfig {
    /**
     * Virtual-thread-per-task executor wrapped in [ContextExecutorService]. Snapshots come from the
     * global [ContextRegistry], so tasks run with the thread-local context captured for them.
     */
    @Bean
    fun eventsExecutorService(): ExecutorService {
        val delegate = Executors.newVirtualThreadPerTaskExecutor()
        val snapshots = ContextSnapshotFactory.builder()
            .contextRegistry(ContextRegistry.getInstance())
            .build()
        return ContextExecutorService.wrap(delegate, snapshots)
    }

    /**
     * Observation-aware span accessor. It is also registered in the global [ContextRegistry], so
     * snapshots taken from that registry include the current span.
     */
    @Bean
    fun contextAccessor(tracer: Tracer, observationRegistry: ObservationRegistry): ObservationAwareSpanThreadLocalAccessor {
        val accessor = ObservationAwareSpanThreadLocalAccessor(observationRegistry, tracer)
        ContextRegistry.getInstance().registerThreadLocalAccessor(accessor)
        return accessor
    }
}
/**
 * Runs code inside a span whose parent is identified by externally supplied trace/span ids
 * (e.g. ids carried inside a consumed message).
 *
 * NOTE(review): `observationRegistry` is not used by this class; it is kept so the constructor
 * signature is unchanged.
 */
@Component
class CustomObservation(
    private val observationRegistry: ObservationRegistry,
    private val tracer: Tracer,
    private val spanAccessor: ThreadLocalAccessor<Span>,
) {
    private val contextSnapshotFactory =
        ContextSnapshotFactory.builder().contextRegistry(ContextRegistry.getInstance()).build()

    /**
     * Runs [callable] in a new [Span.Kind.CONSUMER] span named [contextName], parented on [traceId]/[spanId].
     *
     * If either id is null, a warning is logged and [callable] runs in whatever context is current.
     *
     * @return the result of [callable]; exceptions it throws propagate unchanged.
     */
    fun <T> consumeRestoringCorrelation(traceId: String?, spanId: String?, contextName: String, callable: () -> T): T =
        executeIn(traceId, spanId, Span.Kind.CONSUMER, contextName, callable)

    private fun <T> executeIn(
        traceId: String?,
        spanId: String?,
        kind: Span.Kind,
        contextName: String,
        callable: () -> T,
    ): T =
        if (traceId == null || spanId == null) {
            logger.warn("Missing trace details, creating a new one")
            callable()
        } else {
            executeIn(TraceContext(traceId, spanId, kind, contextName), callable)
        }

    private fun <T> executeIn(context: TraceContext, callable: () -> T): T {
        val span = buildSpan(context).start()
        try {
            // Only undo the accessor once it has actually been given our span, so restore()
            // never runs for a span that failed to build or start.
            spanAccessor.setValue(span)
            try {
                // Captures every registered thread-local (now including the new span) and re-applies
                // them on this thread until the callable returns.
                // NOTE(review): capture and apply happen on the same thread; confirm this is still needed.
                return contextSnapshotFactory.captureAll().setThreadLocals().use { callable() }
            } finally {
                spanAccessor.restore()
            }
        } finally {
            // A started span must be ended, otherwise it is never reported.
            span.end()
        }
    }

    /** Builds (but does not start) a span whose parent is the trace/span ids carried by [context]. */
    private fun buildSpan(context: TraceContext): Span.Builder {
        val traceContext = tracer.traceContextBuilder().traceId(context.traceId).spanId(context.spanId).build()
        return tracer.spanBuilder().setParent(traceContext).kind(context.kind).name(context.name)
    }

    private data class TraceContext(
        val traceId: String, val spanId: String, val kind: Span.Kind, val name: String
    )

    companion object {
        private val logger = LoggerFactory.getLogger(CustomObservation::class.java)
    }
}
/** Value stored under the `traceId` key of the SLF4J [MDC], or `null` when absent. */
fun traceId(): String? {
    val mdcKey = "traceId"
    return MDC.get(mdcKey)
}
/** Value stored under the `spanId` key of the SLF4J [MDC], or `null` when absent. */
fun spanId(): String? {
    val mdcKey = "spanId"
    return MDC.get(mdcKey)
}
Emil Kantis
04/02/2025, 4:01 PM
thanksforallthefish
04/04/2025, 6:45 AM
/**
 * Example that, per the author, keeps the caller's active trace inside the coroutine.
 *
 * Must be `suspend`: `Deferred.await()` is a suspending call and cannot be invoked from a regular function.
 *
 * NOTE(review): `CoroutineScope(...)` creates a standalone scope, so the async job is not a child of the
 * caller. `withContext(...)` would keep structured concurrency; it is not used here to keep the example's
 * reported propagation behaviour unchanged.
 */
suspend fun doSomething() {
    // a trace is active here
    CoroutineScope(eventsExecutorService.asCoroutineDispatcher()).async(observationRegistry.asContextElement()) {
        // the same trace continues here
    }.await()
}
But honestly, we have not tried more complex scenarios, and definitely not ones using channels or reactive endpoints.