# koog-agentic-framework
s
Hi all, could somebody please shed some light on how the Persistency feature works and how to use it within custom strategy graph nodes? For example, let's say I have a standard single-run strategy, but with my own custom node implementations. How could I incorporate the Persistency feature into these nodes so that progress is saved somewhere every time a node has executed? In my code, each time a message comes through from a user, I create a new `AIAgent` instance with this strategy graph. What is the best practice here to make sure we "pick up from where we left off" for that particular user? I need a way for a user to have a back-and-forth conversation, which means storing their conversation state somewhere between messages and retrieving it when a new user message comes through.
I was looking at maybe having a final node in my strategy graph that launches an `llm.readSession`, extracts the prompt, and saves it in my DB. Then every time a user sends a message, the first node in the graph can retrieve their prompt from the DB (which includes all past messages and tool calls/results) and simply append their new message to it. But I'm thinking maybe the Persistency feature can address these concerns for me and will prevent me from re-inventing the wheel. Cheers!
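The manual approach described here can be sketched without any Koog types: keep each user's message history in a store keyed by `userId`, load it at the start of a run, append the new message, and save the result at the end. `ChatMessage`, `ConversationStore`, and `handleTurn` are hypothetical names for illustration only; in a real setup the history would come from the prompt read via `llm.readSession`, the reply from the LLM call, and the store would be a DB table rather than a map.

```kotlin
// Hypothetical, simplified stand-in for a prompt message (not Koog's Message type).
data class ChatMessage(val role: String, val content: String)

// Hypothetical store: one conversation history per userId.
// A real implementation would read/write a database row instead of a map.
class ConversationStore {
    private val histories = mutableMapOf<String, List<ChatMessage>>()

    fun load(userId: String): List<ChatMessage> = histories[userId].orEmpty()

    fun save(userId: String, history: List<ChatMessage>) {
        histories[userId] = history
    }
}

// One request/response turn: load the history, append the user message,
// produce a reply (a caller-supplied function stands in for the LLM), then persist.
fun handleTurn(
    store: ConversationStore,
    userId: String,
    userText: String,
    reply: (List<ChatMessage>) -> String
): String {
    val history = store.load(userId) + ChatMessage("user", userText)
    val answer = reply(history)
    store.save(userId, history + ChatMessage("assistant", answer))
    return answer
}
```

Because the state lives in the store rather than in the agent object, a brand-new agent per request is fine as long as it loads and saves through the same `userId` key.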
a
The persistency feature automatically handles snapshotting everything that is persisted per checkpoint (prompt/messages etc.). You just need to implement a `StorageProvider` that works with your DB and call `rollbackToLatestCheckpoint` when executing your agent. You can save checkpoints for specific users by associating the checkpointing `agentId`s with your `userId`s (or their sessions).
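When a fresh `AIAgent` is created per request, one way to apply this is to derive the agent ID deterministically from the user (and, here, tenant) identifiers, so every request for the same user maps to the same set of checkpoints. The `agentIdFor` helper and its `company:brand:userId` format are hypothetical; the resulting string is what you would pass as the agent's `id` (as in the `id = "1234"` snippet later in this thread) and what your storage provider would filter on.

```kotlin
// Hypothetical helper: build a stable, deterministic agent ID per user.
// The same inputs always produce the same ID, so a new AIAgent created for each
// request can still be matched to that user's earlier checkpoints.
fun agentIdFor(company: String, brand: String, userId: String): String {
    require(userId.isNotBlank()) { "userId must not be blank" }
    // The separator prevents ("ab", "c") and ("a", "bc") from producing the same ID,
    // as long as the individual parts never contain ':' themselves.
    return listOf(company, brand, userId).joinToString(":")
}
```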
I think ideally you wouldn't need to spawn multiple `AIAgent` instances to handle multiple users/`agentId`s, but I'm not sure whether the library currently supports running an agent with a different `agentId` than the one it was created with.
s
Yeah so currently I have this code:
```kotlin
suspend fun processMessage(
    company: String,
    brand: String,
    channelEvent: TextChannelEvent,
    channelInterface: ChannelInterface
): AgentUtterance {
    val awsAccessKeyId = parameterService.getOrThrow("bedrock/access-key-id")
    val awsSecretAccessKey = parameterService.getOrThrow("bedrock/secret-access-key")
    val agent = createAgent(
        company = company,
        brand = brand,
        channelEvent = channelEvent,
        toolFactory = toolFactory,
        promptExecutor = providePromptExecutor(awsAccessKeyId, awsSecretAccessKey)
    )
    try {
        val response = agent.run(channelEvent.userText)
        return AgentUtterance().addElement(TopicMessage("CONCIERGE_RESPONSE", response.body))
    } catch (e: Exception) {
        return AgentUtterance().addElement(TopicMessage("CONCIERGE_RESPONSE", "Sorry, something went wrong. Please try again later."))
    }
}
```
So I spawn a new `AIAgent` instance for each message that comes through.
I'll have a look at what each checkpoint contains. If I can create/retrieve them based on `userId` as you say, it should be fairly straightforward. Unless I'm getting the wrong end of the stick here and I should give each user their own `AIAgent` instance, so that each user is served by exactly the same one every time. So let's say a user's ID is `1234`; then I'd create the `AIAgent` instance like this:
```kotlin
return AIAgent(
    promptExecutor = promptExecutor,
    strategy = createStrategy(),
    agentConfig = createAgentConfig(company, brand, channelEvent.channelContext.userId, conversationSessionManager, isNewUser),
    toolRegistry = toolRegistry,
    id = "1234"
)
```
Then when I call `persistency().rollbackToLatestCheckpoint`, it will retrieve the latest checkpoint for that user (I think). I'm not sure how the `id` field is used for `AIAgent`s.
And then I'm guessing that, rather than manually creating/saving a checkpoint at each node, I can do this?
```kotlin
install(Persistency) {
    enableAutomaticPersistency = true
}
```
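Conceptually, automatic persistency means the runtime records a checkpoint at every node transition, instead of each node calling a save function itself. The sketch below models that idea with hypothetical types (`NodeCheckpoint`, `runGraph`, and a graph reduced to an ordered list of named steps); it is not Koog's implementation, only an illustration of what "a checkpoint per executed node" means.

```kotlin
// Hypothetical checkpoint record: the node being entered and the input it receives.
data class NodeCheckpoint(val runId: String, val nodeId: String, val lastInput: String)

// Hypothetical linear "graph": an ordered list of (nodeId, step) pairs.
// The runner, not the individual nodes, records a checkpoint at every node
// transition, which is the idea behind automatic persistency in this model.
fun runGraph(
    runId: String,
    input: String,
    nodes: List<Pair<String, (String) -> String>>,
    onCheckpoint: (NodeCheckpoint) -> Unit
): String {
    var current = input
    for ((nodeId, step) in nodes) {
        onCheckpoint(NodeCheckpoint(runId, nodeId, current)) // record before the node runs
        current = step(current)
    }
    return current
}
```

In this model the node bodies stay free of persistence code; swapping `onCheckpoint` for a DB write is the only change needed to persist progress.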
@Anastasiia Zarechneva @Pavel Gorgulov Could you please shed some light on this 🙏 How do I save the checkpoint for a specific user? And when that user sends a message in the future, how do I retrieve that checkpoint? How do the Agent ID and Checkpoint ID tie into this?
I must be getting confused here, because this is the code for `persistency().getCheckpointById()`, which I'd call after I've saved one manually:
```kotlin
/**
 * Retrieves a specific checkpoint by ID for the specified agent.
 *
 * @param checkpointId The ID of the checkpoint to retrieve
 * @return The checkpoint data with the specified ID, or null if not found
 */
public suspend fun getCheckpointById(checkpointId: String): AgentCheckpointData? =
    persistencyStorageProvider.getCheckpoints().firstOrNull { it.checkpointId == checkpointId }
```
That means my custom `StorageProvider` has to implement a function that returns the checkpoints for ALL users, as `persistencyStorageProvider.getCheckpoints()` doesn't accept any parameters for filtering, like a `userId`. There could be millions of checkpoints in there for a large system, and this function returns all of them and only then filters? That can't be right, unless I'm misunderstanding something here.
Same thing with `persistencyStorageProvider.getLatestCheckpoint()`. Why would this function ever be useful? Surely calling it runs a real risk of returning a checkpoint that doesn't belong to the intended user at all?
Unless the intention for Koog is that each user request is handled by the same agent instance each time? For example: User 1 -> Agent 1, User 2 -> Agent 2, User 3 -> Agent 3. When User 1 sends a message, it's always handled by Agent 1, and likewise for the other users. If that's the case, how do we make sure that each user's request is handled by their associated agent? Because right now I'm just calling this:
```kotlin
suspend fun processMessage(
    company: String,
    brand: String,
    channelEvent: TextChannelEvent
): AgentUtterance {
    val awsAccessKeyId = parameterService.getOrThrow("bedrock/access-key-id")
    val awsSecretAccessKey = parameterService.getOrThrow("bedrock/secret-access-key")
    val agent = createAgent(
        company = company,
        brand = brand,
        channelEvent = channelEvent,
        toolFactory = toolFactory,
        promptExecutor = providePromptExecutor(awsAccessKeyId, awsSecretAccessKey)
    )
    try {
        val response = agent.run(channelEvent.userText)
        return AgentUtterance().addElement(TopicMessage("CONCIERGE_RESPONSE", response.body))
    } catch (e: Exception) {
        return AgentUtterance().addElement(TopicMessage("CONCIERGE_RESPONSE", "Sorry, something went wrong. Please try again later."))
    }
}
```
```kotlin
private fun createAgent(
    company: String,
    brand: String,
    channelEvent: TextChannelEvent,
    toolFactory: ToolFactory,
    promptExecutor: PromptExecutor
): AIAgent<String, ConciergeStructuredResponse> {
    val (chatUser: ChatUser, isNewUser) = userService.getOrCreateUser(
        channelEvent.channelContext.userId,
        channelEvent.channelContext.channelType,
        company,
        brand,
        channelEvent
    )
    val toolRegistry = createToolRegistry(company, brand, channelEvent, chatUser, toolFactory)
    return AIAgent(
        promptExecutor = promptExecutor,
        strategy = createStrategy(),
        agentConfig = createAgentConfig(company, brand, channelEvent.channelContext.userId, conversationSessionManager, isNewUser),
        toolRegistry = toolRegistry
    ) {
        install(Persistency)
    }
}
```
This runs each time a user message comes through my system, so I'm creating a completely new agent every time. Also, how does the `Prompt` tie into this? I noticed that each `Prompt` has an `id` field, which the KDoc describes as a "unique ID". In what way is this ID unique, and why does it need to be? Is this ID used anywhere else in Koog?
a
Looking at the CheckpointExample in the koog repo, it looks like the pattern is to make the desired ID a member of the storage provider implementation (`class InMemoryPersistencyStorageProvider(private val persistenceId: String) : PersistencyStorageProvider`); the function implementations then use that ID to filter the DB.
This doesn't seem ideal performance-wise, though, as it means a new provider instance is created for every user. I think having parameters in the provider functions, like you mentioned, would be a better library design for Koog.
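Reduced to plain Kotlin, the pattern from that example looks roughly like this: the scoping ID is captured at construction time, so a no-argument `getLatestCheckpoint()` can still filter by it. `ScopedCheckpointStore`, `SavedCheckpoint`, and `CheckpointTable` are hypothetical stand-ins, not Koog's `PersistencyStorageProvider` (whose methods are `suspend` functions); the point is only the shape of the design, one provider instance per ID.

```kotlin
// Hypothetical checkpoint record tagged with the ID it belongs to.
data class SavedCheckpoint(val persistenceId: String, val checkpointId: String, val createdAt: Long)

// Hypothetical shared backing storage (a stand-in for a DB table).
class CheckpointTable {
    val rows = mutableListOf<SavedCheckpoint>()
}

// Scoped provider: the ID is fixed in the constructor, so the methods need no parameters.
// This mirrors the CheckpointExample pattern: one provider instance per user/agent.
class ScopedCheckpointStore(
    private val persistenceId: String,
    private val table: CheckpointTable
) {
    fun saveCheckpoint(checkpointId: String, createdAt: Long) {
        table.rows.add(SavedCheckpoint(persistenceId, checkpointId, createdAt))
    }

    // Linear scan here; a real DB would use WHERE persistence_id = ? backed by an index.
    fun getCheckpoints(): List<SavedCheckpoint> =
        table.rows.filter { it.persistenceId == persistenceId }

    fun getLatestCheckpoint(): SavedCheckpoint? =
        getCheckpoints().maxByOrNull { it.createdAt }
}
```

The provider objects themselves are cheap; the concern raised above is mainly about API ergonomics, since the caller must construct a correctly scoped provider for every request.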
s
Ahh I see, thank you for pointing out that example
And yeah, looking at `InMemoryPersistencyStorageProvider` and how it uses `agentId`s for filtering, I guess you could always just create a class that implements `PersistencyStorageProvider` and accepts a `userId` argument instead, and then `getLatestCheckpoint` can use that for filtering. But then, as you say, that doesn't scale well, as you'll need a separate object for each user request.
a
I opened a PR to Koog to address this. It just passes the `AIAgentContext` through to the persistency storage provider methods, so you can access `context.agentId` or the rest of the context dynamically: https://github.com/JetBrains/koog/pull/536
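With the context passed into each call, a single shared provider instance can serve every user, because the scoping ID arrives as an argument instead of through the constructor. This is a hypothetical sketch of that shape only: `RunContext` stands in for `AIAgentContext` (of which only an `agentId` is assumed), `SharedCheckpointStore` is an illustrative name, and the exact signatures in the PR may differ.

```kotlin
// Hypothetical stand-in for AIAgentContext; only the agentId matters here.
data class RunContext(val agentId: String)

// Hypothetical checkpoint record.
data class StoredCheckpoint(val agentId: String, val checkpointId: String, val createdAt: Long)

// One shared instance for all users: every method receives the context and
// looks up that agent's checkpoints directly, so no per-user provider is needed.
class SharedCheckpointStore {
    private val byAgent = mutableMapOf<String, MutableList<StoredCheckpoint>>()

    fun saveCheckpoint(context: RunContext, checkpointId: String, createdAt: Long) {
        byAgent.getOrPut(context.agentId) { mutableListOf() }
            .add(StoredCheckpoint(context.agentId, checkpointId, createdAt))
    }

    fun getCheckpoints(context: RunContext): List<StoredCheckpoint> =
        byAgent[context.agentId].orEmpty()

    fun getLatestCheckpoint(context: RunContext): StoredCheckpoint? =
        getCheckpoints(context).maxByOrNull { it.createdAt }
}
```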
a
@Mark Tkachenko could you please help us here? 🙂
s
Yeah, I'm very confused about the concept of referring to an "agent" by an agent ID and its intended use in Koog 😄 Why would I need to reference anything using that? And what is being referenced when I use it? Is it the strategy graph? The prompt? Or is the entire agent object being persisted somewhere between sessions or something? For instance, if I call `persistency().getCheckpoints()`, the exact same `AIAgent` instance must have been invoked with `AIAgent.run()` multiple times for that function to be useful (unless it's being used between the point where `AIAgent.run()` is called and the point where the strategy graph finishes), which just isn't a pattern I'm familiar with. Not only that, it would also have to have been invoked on behalf of a single user ONLY, to prevent that function from returning checkpoints for multiple users.
If that's the case, why are we persisting this exact agent and its checkpoints? Why don't we just persist all checkpoints regardless of agent or user, and then have a function that returns only the checkpoints belonging to a single user (or some other filtering criterion)? Then the specific `AIAgent` instance that was used wouldn't matter. For example, each user has a user ID because a user object might contain mutable state that needs to be persisted in a database, so we can retrieve and modify that state for the user by referencing their ID. What about agents? What state are we retrieving or referencing when we refer to an agent by its ID?
Unless I'm mistaken here and there are multiple patterns that Koog accommodates. The one I'm using is: for each user request that comes through, we spawn a fresh, new `AIAgent` instance. All we have to do then is modify the `Prompt` to include all the messages in the conversation history from previous sessions.
a
I'm having similar confusion with the `AgentMemory` feature; the documentation seems to assume an agent only serves and persists data for a single user.
s
Yeah it does seem that way...
v
Hi! Really sorry that it caused some confusion, and thanks for sharing your feedback; we’ll definitely work on this shortly to improve the experience. … As for the `AgentMemory` feature, it actually doesn’t assume that the agent persists the data for one single user. There are `MemorySubject`s that are serializable and can be stored and serialized however you want in your instance of the `AgentMemoryProvider` (it’s an interface that you’ll likely implement for your favorite storage, e.g. S3 or, more likely, a real database):
```kotlin
public suspend fun save(fact: Fact, subject: MemorySubject, scope: MemoryScope)
```
Basically, say you have some `User` class: not the one from our examples (that one is actually a singleton `User` object, which is there mainly for educational purposes, to keep things simple), but a real one that has a `userId: String`, for example:
```kotlin
@Serializable
data class User(val userId: String) : MemorySubject() {
    override val name: String = "user"
    override val promptDescription: String =
        "User's preferences, settings, and behavior patterns, expectations from the agent, preferred messaging style, etc."
    override val priorityLevel: Int = 2
}
```
Then you’ll implement `AgentMemoryProvider.save` so that it saves the information under the right keys in your database (e.g. userId + subject type -> fact). We were planning to implement out-of-the-box database support in the near future, but you can already do it yourself for your use case.
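The keying scheme described here (userId + subject type -> facts) can be sketched with an in-memory map. `FactKey`, `UserFactStore`, and the plain `String` facts, subjects, and scopes are hypothetical simplifications; a real `AgentMemoryProvider.save(fact: Fact, subject: MemorySubject, scope: MemoryScope)` implementation would take Koog's types and write to your database, ideally with an index on the same key columns.

```kotlin
// Hypothetical composite key: which user, which memory subject, which scope.
data class FactKey(val userId: String, val subject: String, val scope: String)

// Hypothetical fact store; facts are plain strings here for simplicity.
class UserFactStore {
    private val facts = mutableMapOf<FactKey, MutableList<String>>()

    fun save(userId: String, subject: String, scope: String, fact: String) {
        facts.getOrPut(FactKey(userId, subject, scope)) { mutableListOf() }.add(fact)
    }

    fun load(userId: String, subject: String, scope: String): List<String> =
        facts[FactKey(userId, subject, scope)].orEmpty()
}
```

Because `userId` is part of every key, one store (and one agent definition) can serve any number of users without their facts mixing.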
m
@Sam Hello! Thank you for the feedback! The Persistency feature doesn’t use the agent ID for persistence. Basically, the whole idea was to require the `PersistencyStorageProvider` to retrieve the proper checkpoint for a particular agent; that’s the responsibility of the `PersistencyStorageProvider` implementation, so we don’t need to pass an agent ID or anything else from the call site. So if you have, for example, a Postgres DB with millions of agents (and millions of snapshots), the provider is responsible for using the proper table with proper indexes to retrieve the right checkpoint for the agent. That’s what @Aria did in her PR about SQL persistence, for example: https://github.com/JetBrains/koog/pull/481
s
Thanks for the responses! What do you mean by a "particular agent"? I'm very confused about this. Do you mean something like the strategy graph configuration, because we need to make sure we retrieve checkpoints that are compatible with the strategy graph we're executing?
m
I mean a particular agent instance. You most likely don’t want to share the same agent between users, for example, because you might run into numerous problems, like context sharing etc.
> because we need to make sure we retrieve checkpoints that are compatible with the strategy graph we’re executing
Yes, of course: if you restore an agent from a snapshot, you want it to be based on the same graph.
s
Ohhh I see. Sorry about that haha, I thought you were referring to these agents as if they were living somewhere between sessions or something, my bad 😅 Indeed, for each request that comes through I'm creating a new `AIAgent` instance to handle that request. So yeah, what I'm struggling with is implementing a way for the first node in my strategy graph to "load" where that particular user left off, and also making sure that I update this state at each node. I know you can set `enableAutomaticPersistency = true`, but I'd like to do it manually at first just so I know what's going on behind the scenes.
m
So what you want is to use this API:
```kotlin
context.withPersistency(context) { ctx ->
    createCheckpoint(
        agentContext = ctx,
        nodeId = "current-node-id",
        lastInput = inputData,
        lastInputType = inputType,
        checkpointId = ctx.runId,
    )
}
```
and `getLatestCheckpoint`/`rollbackToCheckpoint` when loading the state.
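To make the save/restore pair concrete, here is a hypothetical, self-contained model of it: each node records `(nodeId, lastInput)` before it runs, and restoring means resuming from the recorded node with the recorded input. `ResumePoint`, `LinearGraph`, and its `run` method are illustrative names, not Koog's `createCheckpoint`/`rollbackToCheckpoint`, which work on the agent context and are `suspend` functions.

```kotlin
// Hypothetical checkpoint: the node to resume at and the input it was given.
data class ResumePoint(val nodeId: String, val lastInput: String)

// Hypothetical linear graph of (nodeId, step) pairs, in execution order.
class LinearGraph(private val nodes: List<Pair<String, (String) -> String>>) {

    // Run from the first node, or from a saved checkpoint if one is given.
    // A ResumePoint is recorded before each node runs, mirroring a manual
    // createCheckpoint(nodeId = ..., lastInput = ...) call at the top of the node.
    fun run(input: String, from: ResumePoint?, save: (ResumePoint) -> Unit): String {
        val start = if (from == null) 0 else nodes.indexOfFirst { it.first == from.nodeId }
        require(start >= 0) { "checkpoint refers to an unknown node: ${from?.nodeId}" }
        var current = from?.lastInput ?: input
        for ((nodeId, step) in nodes.drop(start)) {
            save(ResumePoint(nodeId, current))
            current = step(current)
        }
        return current
    }
}
```

Restoring only needs the saved `nodeId` to exist in the same graph, which is why a checkpoint must be restored against the graph that produced it.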
s
Can the `nodeId` just be the same as the name of the node, or does it need to be globally unique? If it can't be the name, how come the node name isn't used here? Is it because the name is optional? Also, I'm assuming that `lastInput` and `lastInputType` are taken from the node's input argument? So in this case:
```kotlin
// Node that updates the prompt with the user's message, calls the LLM and adds the resulting response to the prompt
fun AIAgentSubgraphBuilderBase<*, *>.nodeUpdatePromptWithUserMessageAndCallLLM(
    name: String? = null,
    allowToolCalls: Boolean = true
): AIAgentNodeDelegate<String, Message.Response> =
    node(name) { userMessage ->
        llm.writeSession {
            updatePrompt {
                user(userMessage)
            }
            if (allowToolCalls) requestLLM()
            else requestLLMWithoutTools()
        }
    }
```
`lastInput` would be `userMessage`, and `lastInputType` would be `String`? And finally, how can I access `context` within a node? For some reason it's not picking it up for me. Or does it need to be `context()` or `withContext()`, and what's the difference between these? When I do have it, where is `runId` set? Sorry for all the questions 😅
m
> does it need to be globally unique
Graph nodes should have unique IDs in order to use persistency.
> `lastInput` would be `userMessage`, and `lastInputType` would be `String`
Yes, that’s correct.
> And finally, how can I access `context` within a node
The context is `this` inside the `node { }` builder.
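The reason a name like `context` doesn't resolve is that, as answered here, the node body runs with the context as `this`; in Kotlin that is a lambda with a receiver, so the context's members are in scope directly (which is also why `llm.writeSession { ... }` works unqualified in the node above). This plain-Kotlin sketch shows the mechanism with hypothetical `NodeContext` and `node` definitions; they are not Koog's types, and `runId` here is just a field on the stand-in.

```kotlin
// Hypothetical stand-in for the agent context a node body runs against.
class NodeContext(val runId: String, val agentId: String)

// Hypothetical node builder: the body is a function literal with receiver,
// so inside it `this` is the NodeContext and its members are in scope.
fun <I, O> node(body: NodeContext.(I) -> O): NodeContext.(I) -> O = body

val describe = node<String, String> { input ->
    // `agentId` and `runId` resolve against `this`, the NodeContext.
    "agent=$agentId run=$runId input=$input"
}
```

If you need to hand the context to another function, `this` is the value to pass.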
s
To what level does it need to be unique? Like a different ID every time a node is executed? If so, how come? Can persistency not use the checkpointId for this?
m
No, not every time a node is executed; just give the node a unique ID in the builder and that’s enough.
s
How do you give a node an ID in the node builder? And what do you mean by uniqueness here? Is it in the sense that no other node has that ID, or that no other node execution has that ID? For example, I could have one node with the ID "test" and another with "test-2", and that would be okay. Or does the ID have to be something like `"test-" + UUID.randomUUID().toString()` so that it's never the same twice? i.e. pretty much no hard-coded strings allowed.
m
`nodeId` should be unique inside the agent graph, not globally unique, because when we’re restoring the checkpoint we look into that agent's graph, not all agents' graphs. You assign the `nodeId` by giving the node a name; if you don’t give it a name, the ID will be created from the name of the variable for that node.
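In other words, uniqueness is scoped to one graph: two different graphs may both contain a node named `"test"`, but one graph may not contain it twice, and fixed names like `"test"` and `"test-2"` are fine without any UUIDs. A hypothetical sketch of the per-graph lookup a restore would rely on (`indexNodes` is an illustrative helper, not Koog's API):

```kotlin
// Hypothetical: build the nodeId -> node lookup used when restoring a checkpoint.
// Duplicate IDs are rejected within one graph; other graphs are never consulted.
fun <N> indexNodes(nodes: List<Pair<String, N>>): Map<String, N> {
    val duplicates = nodes.groupBy { it.first }.filterValues { it.size > 1 }.keys
    require(duplicates.isEmpty()) { "duplicate node IDs in graph: $duplicates" }
    return nodes.toMap()
}
```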
s
nice one! thanks 🙂
been a massive help
Oh yeah, one more thing: `runId`. Does that have to be globally unique? And where do I set it? Or is it only set here, in `AIAgent.run()`:
```kotlin
override suspend fun run(agentInput: Input): Output {
    runningMutex.withLock {
        if (isRunning) {
            throw IllegalStateException("Agent is already running")
        }

        isRunning = true
    }

    pipeline.prepareFeatures()

    val sessionUuid = Uuid.random()
    val runId = sessionUuid.toString()
    // ... (rest of run() omitted)
```
Because if so, wouldn't providing `checkpointId = ctx.runId` to a `createCheckpoint()` call lead to multiple checkpoints having the same checkpoint ID? Or is that why `nodeId` is also provided to `createCheckpoint()` as an argument, so we know which checkpoint to return when `getLatestCheckpoint()` is called?
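One way to reason about this question is to look at what happens when several checkpoints share an ID under the lookup quoted earlier in the thread (`getCheckpoints().firstOrNull { it.checkpointId == checkpointId }`): the first match wins, so later checkpoints with the same ID are never returned by that lookup. The sketch below reproduces that with hypothetical types (`RunCheckpoint`, `checkpointsForRun`); whether duplicates are appended or overwritten depends entirely on your storage provider, and generating a fresh random ID per checkpoint (here `UUID.randomUUID()`) sidesteps the issue.

```kotlin
import java.util.UUID

// Hypothetical checkpoint record.
data class RunCheckpoint(val checkpointId: String, val nodeId: String)

// Same lookup shape as getCheckpointById in the thread: scan and take the first match.
fun getCheckpointById(all: List<RunCheckpoint>, checkpointId: String): RunCheckpoint? =
    all.firstOrNull { it.checkpointId == checkpointId }

// Simulate one run that checkpoints at several nodes; idFor decides each checkpoint's ID.
fun checkpointsForRun(
    runId: String,
    nodeIds: List<String>,
    idFor: (String) -> String
): List<RunCheckpoint> = nodeIds.map { nodeId -> RunCheckpoint(idFor(runId), nodeId) }
```

With `idFor = { runId -> runId }` both checkpoints of a run share one ID and the lookup always returns the first; with a random ID each checkpoint stays individually addressable.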