-
Notifications
You must be signed in to change notification settings - Fork 23
Expand file tree
/
Copy pathsticky_sessions.py
More file actions
216 lines (182 loc) · 7.06 KB
/
Copy pathsticky_sessions.py
File metadata and controls
216 lines (182 loc) · 7.06 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
"""Sticky sessions: route consecutive requests in a conversation to the same provider.
Maintains a mapping of conversation_id -> (provider, model) with a configurable TTL.
Ensures conversation continuity by preferring the same provider that handled
previous requests in the same conversation.
"""
from __future__ import annotations

import asyncio
import hashlib
import logging
import time
from dataclasses import dataclass, field
from threading import Lock
from typing import Any
logger = logging.getLogger(__name__)

# How long (seconds) a session may sit unused before it counts as expired.
DEFAULT_SESSION_TTL = 30 * 60
# Delay (seconds) between passes of the background cleanup loop.
CLEANUP_INTERVAL = 5 * 60
@dataclass
class SessionEntry:
    """Record binding one conversation to the provider/model that serves it.

    Timestamps are ``time.time()`` seconds. The entry counts as expired once
    more than ``ttl`` seconds have elapsed since ``last_used_at``.
    """

    conversation_id: str
    provider: str
    model: str
    created_at: float
    last_used_at: float
    request_count: int = 0
    ttl: float = DEFAULT_SESSION_TTL

    @property
    def is_expired(self) -> bool:
        """Whether the idle time since the last use exceeds ``ttl``."""
        idle_for = time.time() - self.last_used_at
        return idle_for > self.ttl

    def touch(self) -> None:
        """Mark the entry as used now and count one more request."""
        self.last_used_at = time.time()
        self.request_count += 1

    def to_dict(self) -> dict[str, Any]:
        """Return a plain-dict snapshot plus the seconds left before expiry."""
        snapshot: dict[str, Any] = {
            "conversation_id": self.conversation_id,
            "provider": self.provider,
            "model": self.model,
            "created_at": self.created_at,
            "last_used_at": self.last_used_at,
            "request_count": self.request_count,
            "ttl": self.ttl,
        }
        # Clamp at zero so an already-expired entry never reports negative time.
        remaining = self.ttl - (time.time() - self.last_used_at)
        snapshot["expires_in"] = max(0, remaining)
        return snapshot
class StickySessionManager:
    """Manages sticky sessions for conversation continuity.

    When a conversation starts, the chosen provider is recorded.
    Subsequent requests with the same conversation_id will prefer
    the same provider, unless the session has expired.

    All reads and writes of the session table happen under a
    ``threading.Lock``. The optional background cleanup loop runs as an
    asyncio task on the current event loop.
    """

    def __init__(self, default_ttl: float = DEFAULT_SESSION_TTL) -> None:
        """Create an empty manager.

        Args:
            default_ttl: Idle lifetime (seconds) given to new sessions when
                ``set`` is not passed an explicit ``ttl``.
        """
        self._sessions: dict[str, SessionEntry] = {}
        self._lock = Lock()
        self._default_ttl = default_ttl
        self._cleanup_task: asyncio.Task | None = None

    def get(
        self, conversation_id: str,
    ) -> tuple[str | None, str | None]:
        """Get the (provider, model) for a conversation.

        A hit refreshes the session via ``touch`` (extending its lifetime and
        counting the request). An expired entry is evicted on access.

        Returns:
            ``(provider, model)``, or ``(None, None)`` if ``conversation_id``
            is empty or no active session exists.
        """
        if not conversation_id:
            return None, None
        with self._lock:
            entry = self._sessions.get(conversation_id)
            if entry is None:
                return None, None
            if entry.is_expired:
                del self._sessions[conversation_id]
                logger.debug("Session expired: %s", conversation_id)
                return None, None
            entry.touch()
            return entry.provider, entry.model

    def set(
        self,
        conversation_id: str,
        provider: str,
        model: str,
        ttl: float | None = None,
    ) -> None:
        """Create or update a sticky session.

        Args:
            conversation_id: Session key; an empty value is ignored.
            provider: Provider to bind the conversation to.
            model: Model to bind the conversation to.
            ttl: Idle lifetime in seconds. A falsy value (``None`` or 0)
                means "not specified": an existing session keeps its TTL and
                a new session gets the manager default.
        """
        if not conversation_id:
            return
        now = time.time()
        with self._lock:
            existing = self._sessions.get(conversation_id)
            # Only refresh a still-live entry. An expired one that has not
            # been evicted yet must not be revived with its stale
            # created_at / request_count, so it is replaced below instead.
            if existing is not None and not existing.is_expired:
                existing.provider = provider
                existing.model = model
                if ttl:
                    # Previously an explicit ttl was silently dropped on update.
                    existing.ttl = ttl
                existing.touch()
                return
            self._sessions[conversation_id] = SessionEntry(
                conversation_id=conversation_id,
                provider=provider,
                model=model,
                created_at=now,
                last_used_at=now,
                ttl=ttl or self._default_ttl,
            )

    def remove(self, conversation_id: str) -> bool:
        """Remove a sticky session.

        Returns:
            True if an entry (expired or not) was present and removed.
        """
        with self._lock:
            return self._sessions.pop(conversation_id, None) is not None

    def get_all(self) -> list[dict[str, Any]]:
        """Return all active sessions as dicts, evicting expired ones."""
        with self._lock:
            active = []
            expired = []
            for cid, entry in self._sessions.items():
                if entry.is_expired:
                    expired.append(cid)
                else:
                    active.append(entry.to_dict())
            # Delete after iterating; mutating a dict during iteration raises.
            for cid in expired:
                del self._sessions[cid]
            return active

    def get_stats(self) -> dict[str, Any]:
        """Return session statistics.

        ``expired_sessions`` counts entries that are expired but not yet
        evicted. ``total_requests`` sums ``request_count`` over every stored
        entry, expired ones included.
        """
        with self._lock:
            total = len(self._sessions)
            active = sum(1 for e in self._sessions.values() if not e.is_expired)
            total_requests = sum(e.request_count for e in self._sessions.values())
            return {
                "total_sessions": total,
                "active_sessions": active,
                "expired_sessions": total - active,
                "total_requests": total_requests,
                "default_ttl": self._default_ttl,
            }

    def cleanup_expired(self) -> int:
        """Remove all expired sessions. Returns the number removed."""
        with self._lock:
            expired = [cid for cid, e in self._sessions.items() if e.is_expired]
            for cid in expired:
                del self._sessions[cid]
        if expired:
            logger.info("Cleaned up %d expired sessions", len(expired))
        return len(expired)

    def start_cleanup_loop(self) -> None:
        """Start the background cleanup task on the running event loop.

        Idempotent: if a cleanup task is already running, this is a no-op
        rather than spawning a duplicate loop and orphaning the first task.
        ``asyncio.create_task`` requires a running event loop.
        """
        if self._cleanup_task is not None and not self._cleanup_task.done():
            return
        self._cleanup_task = asyncio.create_task(self._cleanup_loop())

    async def stop_cleanup(self) -> None:
        """Cancel the background cleanup task (if any) and wait for it."""
        task = self._cleanup_task
        if task is None:
            return
        # Clear the handle first so a later start_cleanup_loop starts fresh.
        self._cleanup_task = None
        task.cancel()
        try:
            await task
        except asyncio.CancelledError:
            pass

    async def _cleanup_loop(self) -> None:
        """Evict expired sessions every ``CLEANUP_INTERVAL`` seconds forever."""
        while True:
            try:
                self.cleanup_expired()
            except Exception as e:
                # Keep the loop alive, but record the traceback for diagnosis.
                logger.exception("Session cleanup error: %s", e)
            await asyncio.sleep(CLEANUP_INTERVAL)

    def extract_conversation_id(self, messages: list[dict[str, Any]]) -> str:
        """Derive a conversation ID from a chat message list.

        The ID is ``"conv_"`` followed by 16 hex digits of the SHA-256 digest
        of the text of the first user message that has non-empty text. For
        multimodal (list) content, the string ``text`` fields of
        ``type == "text"`` parts are joined with single spaces. Non-dict
        messages and parts are skipped.

        SHA-256 is used instead of the builtin ``hash()``, because ``hash()``
        of a ``str`` is salted per process (PYTHONHASHSEED). That salting gives
        the same conversation different IDs across workers and restarts,
        defeating stickiness.

        Returns:
            The conversation ID, or ``""`` if ``messages`` is empty or no user
            message carries text.
        """
        if not messages:
            return ""
        for msg in messages:
            if not isinstance(msg, dict) or msg.get("role") != "user":
                continue
            content = msg.get("content", "")
            if isinstance(content, str):
                text = content
            elif isinstance(content, list):
                # Multimodal content: keep only well-formed text parts.
                text = " ".join(
                    part["text"]
                    for part in content
                    if isinstance(part, dict)
                    and part.get("type") == "text"
                    and isinstance(part.get("text"), str)
                )
            else:
                continue
            if text:
                return self._stable_conversation_id(text)
        return ""

    @staticmethod
    def _stable_conversation_id(text: str) -> str:
        """Return a process-independent ``conv_`` ID for ``text``."""
        # surrogatepass: JSON-decoded text may carry lone surrogates, which
        # strict UTF-8 encoding would reject.
        digest = hashlib.sha256(text.encode("utf-8", "surrogatepass")).hexdigest()
        return f"conv_{digest[:16]}"
# Module-level manager instance, created once at import time and shared by
# every importer of this module.
sticky_sessions = StickySessionManager(default_ttl=DEFAULT_SESSION_TTL)