Skip to content

Commit 65d749e

Browse files
Gkrumbach07 and claude committed
feat: allow live model switching on running agentic sessions (RHOAIENG-56044)
Enable users to change the LLM model on running sessions without stopping/restarting. Model changes take effect on the next AG-UI run, preserving conversation history.

Backend Changes:
- Remove Running/Creating phase restriction for llmSettings updates in UpdateSession
- Add model validation against manifest and runner provider compatibility
- Pass current model to runner via X-Current-Model and X-Current-Model-Vertex-ID headers
- Resolve Vertex AI model IDs from manifest when needed

Runner Changes:
- Extract model from request headers on each /run invocation
- Inject model into forwarded_props to override env var for that run
- Support Vertex AI model ID override via LLM_MODEL_VERTEX_ID env var

Frontend Changes:
- Add UpdateAgenticSessionRequest type with llmSettings field
- Add updateSession() API function for PUT /projects/:project/agentic-sessions/:name
- (UI implementation deferred - users can update via API directly)

Architecture:
- Next-run-only approach: no pod restart needed, leverages existing forwarded_props mechanism
- Backend resolves Vertex ID from ConfigMap, consistent with pod creation flow
- No context handoff: conversation history replayed naturally (works for same-family switches)

References: RHOAIENG-56044, GitHub Issue #1090

Co-Authored-By: Claude Sonnet 4.5 <noreply@anthropic.com>
1 parent d1117d8 commit 65d749e

6 files changed

Lines changed: 158 additions & 11 deletions

File tree

components/backend/handlers/sessions.go

100644 → 100755
Lines changed: 46 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1215,17 +1215,57 @@ func UpdateSession(c *gin.Context) {
12151215
return
12161216
}
12171217

1218-
// Prevent spec changes while session is running or being created
1218+
// Prevent most spec changes while session is running or being created
1219+
// Exception: llmSettings.model can be updated live (takes effect on next run)
12191220
if status, ok := item.Object["status"].(map[string]interface{}); ok {
12201221
if phase, ok := status["phase"].(string); ok {
12211222
if strings.EqualFold(phase, "Running") || strings.EqualFold(phase, "Creating") {
1222-
c.JSON(http.StatusConflict, gin.H{
1223-
"error": "Cannot modify session specification while the session is running",
1224-
"phase": phase,
1225-
})
1226-
return
1223+
// Check if request is ONLY updating llmSettings (live model switching)
1224+
isOnlyModelUpdate := req.LLMSettings != nil &&
1225+
req.InitialPrompt == nil &&
1226+
req.DisplayName == nil &&
1227+
req.Timeout == nil
1228+
1229+
if !isOnlyModelUpdate {
1230+
c.JSON(http.StatusConflict, gin.H{
1231+
"error": "Cannot modify session specification while the session is running (except llmSettings)",
1232+
"phase": phase,
1233+
})
1234+
return
1235+
}
1236+
log.Printf("Live model update on Running session %s: %+v", sessionName, req.LLMSettings)
1237+
}
1238+
}
1239+
}
1240+
1241+
// Validate model if being updated
1242+
if req.LLMSettings != nil && req.LLMSettings.Model != "" {
1243+
// Get runner type from existing session spec to determine provider
1244+
spec := item.Object["spec"].(map[string]interface{})
1245+
runnerTypeID := DefaultRunnerType // default
1246+
if envVars, ok := spec["environmentVariables"].(map[string]interface{}); ok {
1247+
if rt, ok := envVars["RUNNER_TYPE"].(string); ok && rt != "" {
1248+
runnerTypeID = rt
12271249
}
12281250
}
1251+
1252+
// Resolve provider for validation
1253+
runnerProvider := ""
1254+
if rt, rtErr := GetRuntime(runnerTypeID); rtErr == nil {
1255+
runnerProvider = rt.Provider
1256+
} else {
1257+
log.Printf("WARNING: could not resolve runner type %q from registry: %v", runnerTypeID, rtErr)
1258+
}
1259+
1260+
// Validate model is available for this runner's provider
1261+
k8sClt, _ := GetK8sClientsForRequest(c)
1262+
if !isModelAvailable(c.Request.Context(), k8sClt, req.LLMSettings.Model, runnerProvider, project) {
1263+
c.JSON(http.StatusBadRequest, gin.H{
1264+
"error": "Model is not available for this runner type",
1265+
"model": req.LLMSettings.Model,
1266+
})
1267+
return
1268+
}
12291269
}
12301270

12311271
// Update spec

components/backend/websocket/agui_proxy.go

100644 → 100755
Lines changed: 65 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -367,8 +367,12 @@ func HandleAGUIRunProxy(c *gin.Context) {
367367
log.Printf("Run with per-user credentials for %s/%s", projectName, sessionName)
368368
}
369369

370+
// Fetch current model from session spec for live model switching
371+
currentModel := getCurrentModel(projectName, sessionName)
372+
currentModelVertexID := getCurrentModelVertexID(projectName, sessionName, currentModel)
373+
370374
// Start background goroutine to proxy runner SSE → persist + broadcast
371-
go proxyRunnerStream(runnerURL, bodyBytes, sessionName, runID, threadID, currentUserID, currentUserName, callerToken)
375+
go proxyRunnerStream(runnerURL, bodyBytes, sessionName, runID, threadID, currentUserID, currentUserName, callerToken, currentModel, currentModelVertexID)
372376

373377
// Return metadata immediately — events arrive via GET /agui/events
374378
c.JSON(http.StatusOK, gin.H{
@@ -382,13 +386,14 @@ func HandleAGUIRunProxy(c *gin.Context) {
382386
// a background goroutine so the POST /agui/run handler can return immediately.
383387
// If userID is provided, forwards user context headers for credential scoping.
384388
// callerToken is the original user's bearer token for per-user credential requests.
385-
func proxyRunnerStream(runnerURL string, bodyBytes []byte, sessionName, runID, threadID, userID, userName, callerToken string) {
389+
// currentModel and currentModelVertexID are forwarded to enable live model switching.
390+
func proxyRunnerStream(runnerURL string, bodyBytes []byte, sessionName, runID, threadID, userID, userName, callerToken, currentModel, currentModelVertexID string) {
386391
logSuffix := ""
387392
if userID != "" {
388393
logSuffix = fmt.Sprintf(" (user=%s)", userID)
389394
}
390395
log.Printf("AGUI Proxy: connecting to runner at %s%s", runnerURL, logSuffix)
391-
resp, err := connectToRunner(runnerURL, bodyBytes, userID, userName, callerToken)
396+
resp, err := connectToRunner(runnerURL, bodyBytes, userID, userName, callerToken, currentModel, currentModelVertexID)
392397
if err != nil {
393398
log.Printf("AGUI Proxy: runner unavailable for %s: %v", sessionName, err)
394399
// Publish error events so GET /agui/events subscribers see the failure
@@ -784,7 +789,7 @@ var runnerHTTPClient = &http.Client{
784789
// container startup time and K8s Service DNS propagation. Retries on
785790
// "connection refused", "no such host", and "dial tcp" errors with
786791
// exponential backoff (500ms initial, 1.5x, capped at 5s, 15 attempts).
787-
func connectToRunner(runnerURL string, bodyBytes []byte, userID, userName, callerToken string) (*http.Response, error) {
792+
func connectToRunner(runnerURL string, bodyBytes []byte, userID, userName, callerToken, currentModel, currentModelVertexID string) (*http.Response, error) {
788793
maxAttempts := 15
789794
retryDelay := 500 * time.Millisecond
790795
maxDelay := 5 * time.Second
@@ -809,6 +814,13 @@ func connectToRunner(runnerURL string, bodyBytes []byte, userID, userName, calle
809814
if callerToken != "" {
810815
req.Header.Set("X-Caller-Token", callerToken)
811816
}
817+
// Forward current model for live model switching
818+
if currentModel != "" {
819+
req.Header.Set("X-Current-Model", currentModel)
820+
}
821+
if currentModelVertexID != "" {
822+
req.Header.Set("X-Current-Model-Vertex-ID", currentModelVertexID)
823+
}
812824

813825
resp, err := runnerHTTPClient.Do(req)
814826
if err == nil {
@@ -1290,3 +1302,52 @@ func HandleTaskList(c *gin.Context) {
12901302
body, _ := io.ReadAll(io.LimitReader(resp.Body, 1024*1024))
12911303
c.Data(resp.StatusCode, resp.Header.Get("Content-Type"), body)
12921304
}
1305+
1306+
// ─── Live Model Switching Helpers ────────────────────────────────────
1307+
1308+
// getCurrentModel fetches the current model from the session spec.
1309+
// Returns empty string if unable to read (runner will fall back to env var).
1310+
func getCurrentModel(projectName, sessionName string) string {
1311+
if handlers.DynamicClient == nil {
1312+
return ""
1313+
}
1314+
1315+
gvr := handlers.GetAgenticSessionV1Alpha1Resource()
1316+
obj, err := handlers.DynamicClient.Resource(gvr).Namespace(projectName).Get(
1317+
context.Background(), sessionName, metav1.GetOptions{},
1318+
)
1319+
if err != nil {
1320+
return ""
1321+
}
1322+
1323+
model, _, _ := unstructured.NestedString(obj.Object, "spec", "llmSettings", "model")
1324+
return model
1325+
}
1326+
1327+
// getCurrentModelVertexID resolves the Vertex AI model ID for the given model.
1328+
// Returns empty string if Vertex ID mapping is not needed or unavailable.
1329+
func getCurrentModelVertexID(projectName, sessionName, model string) string {
1330+
if model == "" {
1331+
return ""
1332+
}
1333+
1334+
// Load model manifest from ConfigMap mount
1335+
manifestPath := handlers.ManifestPath()
1336+
manifest, err := handlers.LoadManifest(manifestPath)
1337+
if err != nil {
1338+
log.Printf("WARNING: failed to load model manifest for Vertex ID resolution: %v", err)
1339+
return ""
1340+
}
1341+
1342+
// Find model entry and extract Vertex ID
1343+
for _, entry := range manifest.Models {
1344+
if entry.ID == model && entry.Available {
1345+
if entry.VertexID != "" {
1346+
log.Printf("Resolved Vertex ID for model %q: %s", model, entry.VertexID)
1347+
}
1348+
return entry.VertexID
1349+
}
1350+
}
1351+
1352+
return ""
1353+
}

components/frontend/src/services/api/sessions.ts

100644 → 100755
Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ import type {
1212
ListAgenticSessionsPaginatedResponse,
1313
StopAgenticSessionRequest,
1414
StopAgenticSessionResponse,
15+
UpdateAgenticSessionRequest,
1516
CloneAgenticSessionRequest,
1617
CloneAgenticSessionResponse,
1718
PaginationParams,
@@ -188,6 +189,20 @@ export async function getSessionPodEvents(
188189
return apiClient.get(`/projects/${projectName}/agentic-sessions/${sessionName}/pod-events`);
189190
}
190191

192+
/**
 * Update a session's spec (supports live model switching).
 *
 * Sends `updates` as the body of a PUT request to
 * `/projects/{projectName}/agentic-sessions/{sessionName}` and resolves with
 * the response, typed as an AgenticSession.
 */
export async function updateSession(
  projectName: string,
  sessionName: string,
  updates: UpdateAgenticSessionRequest
): Promise<AgenticSession> {
  const endpoint = `/projects/${projectName}/agentic-sessions/${sessionName}`;
  return apiClient.put<AgenticSession, UpdateAgenticSessionRequest>(endpoint, updates);
}
205+
191206
/**
192207
* Update the display name of a session
193208
*/

components/frontend/src/types/api/sessions.ts

100644 → 100755
Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -180,6 +180,13 @@ export type StopAgenticSessionResponse = {
180180
message: string;
181181
};
182182

183+
/**
 * Request body for updating an agentic session.
 *
 * Every field is optional; include only the parts of the session to change.
 */
export type UpdateAgenticSessionRequest = {
  /** Partial LLM settings; any subset of LLMSettings fields may be supplied. */
  llmSettings?: Partial<LLMSettings>;
  /** New display name for the session. */
  displayName?: string;
  /** New initial prompt for the session. */
  initialPrompt?: string;
  /** New timeout value (units are not specified here — confirm against the backend). */
  timeout?: number;
};
189+
183190
export type CloneAgenticSessionRequest = {
184191
targetProject: string;
185192
newSessionName: string;

components/runners/ambient-runner/ambient_runner/bridges/claude/bridge.py

100644100755
Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -140,12 +140,29 @@ async def run(
140140
current_user_id: str = "",
141141
current_user_name: str = "",
142142
caller_token: str = "",
143+
current_model: str = "",
144+
current_model_vertex_id: str = "",
143145
) -> AsyncIterator[BaseEvent]:
144-
"""Full run lifecycle: initialize → session worker → tracing."""
146+
"""Full run lifecycle: initialize → session worker → tracing.
147+
148+
Live model switching: current_model and current_model_vertex_id from
149+
the backend override the env var model for this run.
150+
"""
145151
thread_id = input_data.thread_id or (self._context.session_id if self._context else "")
146152

147153
await self._initialize_run(thread_id, current_user_id, current_user_name, caller_token)
148154

155+
# Live model switching: inject current model into forwarded_props
156+
if current_model:
157+
if input_data.forwarded_props is None:
158+
input_data.forwarded_props = {}
159+
input_data.forwarded_props["model"] = current_model
160+
logger.info(f"Live model switch: using {current_model}")
161+
# Update env var for Vertex AI mapping (used by setup_sdk_authentication)
162+
if current_model_vertex_id:
163+
os.environ["LLM_MODEL_VERTEX_ID"] = current_model_vertex_id
164+
logger.info(f"Vertex ID set: {current_model_vertex_id}")
165+
149166
from ag_ui_claude_sdk.utils import process_messages
150167

151168
user_msg, _ = process_messages(input_data)

components/runners/ambient-runner/ambient_runner/endpoints/run.py

100644100755
Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -68,13 +68,18 @@ async def run_agent(input_data: RunnerInput, request: Request):
6868
# The caller's bearer token — used for credential requests so each user
6969
# can only access their own credentials (no BOT_TOKEN impersonation).
7070
caller_token = request.headers.get("x-caller-token", "")
71+
# Extract current model for live model switching
72+
current_model = request.headers.get("x-current-model", "")
73+
current_model_vertex_id = request.headers.get("x-current-model-vertex-id", "")
7174
if current_user_id:
7275
from ambient_runner.platform.auth import sanitize_user_context
7376

7477
current_user_id, current_user_name = sanitize_user_context(
7578
current_user_id, current_user_name
7679
)
7780
logger.info(f"Run user context: {current_user_id}")
81+
if current_model:
82+
logger.info(f"Run with live model: {current_model}")
7883

7984
logger.info(
8085
f"Run: thread_id={run_agent_input.thread_id}, run_id={run_agent_input.run_id}"
@@ -87,6 +92,8 @@ async def event_stream():
8792
current_user_id=current_user_id,
8893
current_user_name=current_user_name,
8994
caller_token=caller_token,
95+
current_model=current_model,
96+
current_model_vertex_id=current_model_vertex_id,
9097
):
9198
try:
9299
yield encoder.encode(event)

0 commit comments

Comments
 (0)