Skip to content

Commit 7a7f79a

Browse files
feat(openai): native content-block streaming for chat completions
1 parent f7d7e0b commit 7a7f79a

6 files changed

Lines changed: 636 additions & 0 deletions

File tree

libs/partners/openai/langchain_openai/__init__.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,10 @@
33
from langchain_openai._version import __version__
44
from langchain_openai.chat_models import AzureChatOpenAI, ChatOpenAI
55
from langchain_openai.chat_models._client_utils import StreamChunkTimeoutError
6+
from langchain_openai.chat_models._stream_events import (
7+
aconvert_openai_completions_stream,
8+
convert_openai_completions_stream,
9+
)
610
from langchain_openai.embeddings import AzureOpenAIEmbeddings, OpenAIEmbeddings
711
from langchain_openai.llms import AzureOpenAI, OpenAI
812
from langchain_openai.tools import custom_tool
@@ -16,5 +20,7 @@
1620
"OpenAIEmbeddings",
1721
"StreamChunkTimeoutError",
1822
"__version__",
23+
"aconvert_openai_completions_stream",
24+
"convert_openai_completions_stream",
1925
"custom_tool",
2026
]
Lines changed: 162 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,162 @@
1+
"""Native content-block streaming-event converter for OpenAI Chat Completions.
2+
3+
Drives raw OpenAI Chat Completions chunks into the shared `BlockStreamTracker`,
4+
reusing `BaseChatOpenAI._convert_chunk_to_generation_chunk` for content
5+
extraction (it already yields indexed content blocks + tool-call chunks). This
6+
converter is the reuse seam for OpenAI-compatible providers (deepseek, groq,
7+
fireworks, xai, openrouter), which adapt their chunk shape to OpenAI's and call
8+
it with a different `provider`.
9+
"""
10+
11+
from __future__ import annotations
12+
13+
from collections.abc import Callable
14+
from typing import TYPE_CHECKING, Any
15+
16+
from langchain_core.language_models.stream_events import (
17+
BlockStreamTracker,
18+
accumulate_usage,
19+
build_message_finish,
20+
iter_protocol_blocks,
21+
)
22+
from langchain_core.messages import AIMessageChunk
23+
24+
if TYPE_CHECKING:
25+
from collections.abc import AsyncIterator, Iterator
26+
27+
from langchain_core.outputs import ChatGenerationChunk
28+
from langchain_protocol.protocol import (
29+
MessageMetadata,
30+
MessagesData,
31+
MessageStartData,
32+
)
33+
34+
# Call signature of the bound `BaseChatOpenAI._convert_chunk_to_generation_chunk`.
35+
MakeChunk = Callable[..., "ChatGenerationChunk | None"]
36+
37+
38+
def _message_start(
    message_id: str | None, model: str | None, provider: str
) -> MessageStartData:
    """Build the `message-start` protocol event for an AI message.

    Args:
        message_id: Id to report on the event; a falsy value becomes `""`.
        model: Model name; only added to the metadata when truthy.
        provider: Provider id, always recorded in the metadata.

    Returns:
        The `message-start` event dict with role `"ai"`.
    """
    # Keep `provider` first so the metadata key order is unchanged.
    metadata: MessageMetadata = (
        {"provider": provider, "model": model} if model else {"provider": provider}
    )
    event_id = message_id if message_id else ""
    return {
        "event": "message-start",
        "role": "ai",
        "id": event_id,
        "metadata": metadata,
    }
50+
51+
52+
def convert_openai_completions_stream(
    raw: Iterator[Any],
    make_chunk: MakeChunk,
    *,
    base_generation_info: dict[str, Any] | None = None,
    message_id: str | None = None,
    provider: str = "openai",
) -> Iterator[MessagesData]:
    """Convert a raw OpenAI Chat Completions chunk stream to protocol events.

    Emits `message-start` on the first chunk that yields a generation, feeds
    every protocol block of each generated message through a
    `BlockStreamTracker`, then flushes the tracker with `finish_all()` and
    emits the event built by `build_message_finish`. A stream that never
    yields a generation produces no events at all.

    Args:
        raw: Raw OpenAI chunks (dicts or SDK objects with `model_dump`).
        make_chunk: `BaseChatOpenAI._convert_chunk_to_generation_chunk`, injected
            so the converter stays pure and unit-testable.
        base_generation_info: Passed to `make_chunk` until it first returns a
            generation, `{}` thereafter. Chunks for which `make_chunk` returns
            `None` are skipped and do not consume it.
        message_id: Message id for `message-start`. Left empty by default so
            the v3 stream's seeded LangChain run id stands (matching the compat
            bridge); the provider completion id is deliberately not used here.
        provider: `model_provider` id for downstream reuse (groq, deepseek, ...).

    Yields:
        Protocol `MessagesData` lifecycle events.
    """
    tracker = BlockStreamTracker()
    started = False
    usage: dict[str, Any] | None = None
    response_metadata: dict[str, Any] = {"model_provider": provider}
    model: str | None = None
    first = True

    for chunk in raw:
        if not isinstance(chunk, dict):
            chunk = chunk.model_dump()
        # Remember the first model name seen; it is reported on `message-start`.
        if model is None and chunk.get("model"):
            model = chunk["model"]
        gen = make_chunk(chunk, AIMessageChunk, base_generation_info if first else {})
        if gen is None:
            continue
        # Only clear the flag once a generation was actually produced;
        # otherwise a skipped leading chunk would silently swallow
        # `base_generation_info`.
        first = False
        msg = gen.message
        if not started:
            started = True
            yield _message_start(message_id, model, provider)
        for key, block in iter_protocol_blocks(msg):
            yield from tracker.feed(key, block)
        usage_metadata = getattr(msg, "usage_metadata", None)
        if usage_metadata:
            usage = accumulate_usage(usage, usage_metadata)
        # Message-level metadata wins over generation-level info on key clashes.
        merged = {**(gen.generation_info or {}), **(msg.response_metadata or {})}
        if merged:
            response_metadata.update(merged)
            # `_convert_chunk_to_generation_chunk` hardcodes
            # `model_provider="openai"`; re-apply the caller's `provider` so
            # OpenAI-compatible reuse (groq, deepseek, ...) isn't mislabeled.
            response_metadata["model_provider"] = provider

    if not started:
        return
    yield from tracker.finish_all()
    yield build_message_finish(usage=usage, response_metadata=response_metadata)
113+
114+
115+
async def aconvert_openai_completions_stream(
    raw: AsyncIterator[Any],
    make_chunk: MakeChunk,
    *,
    base_generation_info: dict[str, Any] | None = None,
    message_id: str | None = None,
    provider: str = "openai",
) -> AsyncIterator[MessagesData]:
    """Async twin of `convert_openai_completions_stream`. `make_chunk` is sync.

    Emits `message-start` on the first chunk that yields a generation, feeds
    every protocol block of each generated message through a
    `BlockStreamTracker`, then flushes the tracker with `finish_all()` and
    emits the event built by `build_message_finish`. A stream that never
    yields a generation produces no events at all.

    Args:
        raw: Async iterator of raw OpenAI chunks (dicts or SDK objects with
            `model_dump`).
        make_chunk: Synchronous chunk-to-generation converter; may return
            `None` to skip a chunk.
        base_generation_info: Passed to `make_chunk` until it first returns a
            generation, `{}` thereafter. Chunks for which `make_chunk` returns
            `None` are skipped and do not consume it.
        message_id: Message id for `message-start`; a falsy value is emitted
            as an empty string.
        provider: `model_provider` id recorded on the emitted metadata.

    Yields:
        Protocol `MessagesData` lifecycle events.
    """
    tracker = BlockStreamTracker()
    started = False
    usage: dict[str, Any] | None = None
    response_metadata: dict[str, Any] = {"model_provider": provider}
    model: str | None = None
    first = True

    async for chunk in raw:
        if not isinstance(chunk, dict):
            chunk = chunk.model_dump()
        # Remember the first model name seen; it is reported on `message-start`.
        if model is None and chunk.get("model"):
            model = chunk["model"]
        gen = make_chunk(chunk, AIMessageChunk, base_generation_info if first else {})
        if gen is None:
            continue
        # Only clear the flag once a generation was actually produced;
        # otherwise a skipped leading chunk would silently swallow
        # `base_generation_info`.
        first = False
        msg = gen.message
        if not started:
            started = True
            yield _message_start(message_id, model, provider)
        for key, block in iter_protocol_blocks(msg):
            # `yield from` is not allowed in async generators, so re-yield.
            for ev in tracker.feed(key, block):
                yield ev
        usage_metadata = getattr(msg, "usage_metadata", None)
        if usage_metadata:
            usage = accumulate_usage(usage, usage_metadata)
        # Message-level metadata wins over generation-level info on key clashes.
        merged = {**(gen.generation_info or {}), **(msg.response_metadata or {})}
        if merged:
            response_metadata.update(merged)
            # `_convert_chunk_to_generation_chunk` hardcodes
            # `model_provider="openai"`; re-apply the caller's `provider` so
            # OpenAI-compatible reuse (groq, deepseek, ...) isn't mislabeled.
            response_metadata["model_provider"] = provider

    if not started:
        return
    for ev in tracker.finish_all():
        yield ev
    yield build_message_finish(usage=usage, response_metadata=response_metadata)

libs/partners/openai/langchain_openai/chat_models/base.py

Lines changed: 150 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,10 @@
5656
LanguageModelInput,
5757
ModelProfileRegistry,
5858
)
59+
from langchain_core.language_models._compat_bridge import (
60+
achunks_to_events,
61+
chunks_to_events,
62+
)
5963
from langchain_core.language_models.chat_models import (
6064
BaseChatModel,
6165
LangSmithParams,
@@ -149,11 +153,16 @@
149153
_convert_from_v1_to_responses,
150154
_convert_to_v03_ai_message,
151155
)
156+
from langchain_openai.chat_models._stream_events import (
157+
aconvert_openai_completions_stream,
158+
convert_openai_completions_stream,
159+
)
152160
from langchain_openai.data._profiles import _PROFILES
153161

154162
if TYPE_CHECKING:
155163
import httpx
156164
from langchain_core.language_models import ModelProfile
165+
from langchain_protocol.protocol import MessagesData
157166
from openai.types.responses import Response
158167

159168
logger = logging.getLogger(__name__)
@@ -1911,6 +1920,147 @@ async def _astream(
19111920
)
19121921
yield generation_chunk
19131922

1923+
def _stream_chat_model_events(
    self,
    messages: list[BaseMessage],
    stop: list[str] | None = None,
    run_manager: CallbackManagerForLLMRun | None = None,
    *,
    message_id: str | None = None,
    **kwargs: Any,
) -> Iterator[MessagesData]:
    """Yield native content-block events for a Chat Completions stream.

    The Responses API, structured output (`response_format`) and raw-header
    mode are not specialized here; for those the legacy `_stream` output is
    adapted through the compat bridge instead. Detected by core's
    `_iter_v2_events`.
    """
    # `response_format` may come from call kwargs or be baked into
    # `model_kwargs`; both end up in the payload, so either one triggers the
    # bridge. On `ChatOpenAI`, `_stream` itself routes to the Responses path
    # when applicable.
    use_bridge = (
        self._use_responses_api({**kwargs, **self.model_kwargs})
        or kwargs.get("response_format") is not None
        or self.model_kwargs.get("response_format") is not None
        or self.include_response_headers
    )
    if use_bridge:
        # kwargs go through unmodified (as core's `_iter_v2_events` would):
        # `_stream` resolves `stream_usage` on its own and the Responses path
        # rejects a stray `stream_usage` kwarg, so none is injected here.
        legacy_chunks = self._stream(
            messages,
            stop=stop,
            run_manager=run_manager,
            **kwargs,
        )
        yield from chunks_to_events(legacy_chunks, message_id=message_id)
        return

    self._ensure_sync_client_available()
    kwargs["stream"] = True
    include_usage = self._should_stream_usage(
        kwargs.pop("stream_usage", None), **kwargs
    )
    if include_usage:
        kwargs["stream_options"] = {"include_usage": include_usage}
    request = self._get_request_payload(messages, stop=stop, **kwargs)
    try:
        with self.client.create(**request) as response:
            events = convert_openai_completions_stream(
                response,
                self._convert_chunk_to_generation_chunk,
                message_id=message_id,
            )
            for event in events:
                if run_manager is not None and event["event"] == "content-block-delta":
                    delta = event["delta"]
                    if delta.get("type") == "text-delta":
                        # The v3 events path reports text only: block and
                        # usage detail travel on the events themselves, so
                        # the legacy `chunk=`/`logprobs=` callback args are
                        # not threaded.
                        run_manager.on_llm_new_token(str(delta.get("text", "")))
                yield event
    except openai.BadRequestError as e:
        _handle_openai_bad_request(e)
    except openai.APIError as e:
        _handle_openai_api_error(e)
1993+
1994+
async def _astream_chat_model_events(
1995+
self,
1996+
messages: list[BaseMessage],
1997+
stop: list[str] | None = None,
1998+
run_manager: AsyncCallbackManagerForLLMRun | None = None,
1999+
*,
2000+
message_id: str | None = None,
2001+
**kwargs: Any,
2002+
) -> AsyncIterator[MessagesData]:
2003+
"""Async twin of `_stream_chat_model_events`."""
2004+
if (
2005+
self._use_responses_api({**kwargs, **self.model_kwargs})
2006+
or kwargs.get("response_format") is not None
2007+
or self.model_kwargs.get("response_format") is not None
2008+
or self.include_response_headers
2009+
):
2010+
# Forward kwargs untouched (as core's `_aiter_v2_events` would):
2011+
# `_astream` handles `stream_usage` itself, and the Responses path
2012+
# rejects a stray `stream_usage` kwarg, so we must not inject one.
2013+
async for event in achunks_to_events(
2014+
self._astream(
2015+
messages,
2016+
stop=stop,
2017+
run_manager=run_manager,
2018+
**kwargs,
2019+
),
2020+
message_id=message_id,
2021+
):
2022+
yield event
2023+
return
2024+
2025+
kwargs["stream"] = True
2026+
stream_usage = self._should_stream_usage(
2027+
kwargs.pop("stream_usage", None), **kwargs
2028+
)
2029+
if stream_usage:
2030+
kwargs["stream_options"] = {"include_usage": stream_usage}
2031+
payload = self._get_request_payload(messages, stop=stop, **kwargs)
2032+
try:
2033+
response = await self.async_client.create(**payload)
2034+
async with response as stream:
2035+
# Mirror `_astream`: apply per-chunk stall protection before the
2036+
# converter consumes the stream.
2037+
timed_stream = _astream_with_chunk_timeout(
2038+
stream,
2039+
self.stream_chunk_timeout,
2040+
model_name=self.model_name,
2041+
)
2042+
async for event in aconvert_openai_completions_stream(
2043+
timed_stream,
2044+
self._convert_chunk_to_generation_chunk,
2045+
message_id=message_id,
2046+
):
2047+
if (
2048+
run_manager is not None
2049+
and event["event"] == "content-block-delta"
2050+
and event["delta"].get("type") == "text-delta"
2051+
):
2052+
# Text-only by design on the v3 events path: the events
2053+
# themselves carry block/usage detail, so the legacy
2054+
# `chunk=`/`logprobs=` callback args are not threaded.
2055+
await run_manager.on_llm_new_token(
2056+
str(event["delta"].get("text", ""))
2057+
)
2058+
yield event
2059+
except openai.BadRequestError as e:
2060+
_handle_openai_bad_request(e)
2061+
except openai.APIError as e:
2062+
_handle_openai_api_error(e)
2063+
19142064
async def _agenerate(
19152065
self,
19162066
messages: list[BaseMessage],

libs/partners/openai/tests/unit_tests/chat_models/test_base.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -635,13 +635,38 @@ def mock_create(*args: Any, **kwargs: Any) -> MockSyncContextManager:
635635
events = list(llm.stream_events("你的名字叫什么?只回答名字", version="v3"))
636636

637637
assert_valid_event_stream(events)
638+
# `message-start` carries the stream's LangChain run id (threaded from core),
639+
# not the provider completion id and not an empty string.
640+
assert events[0]["event"] == "message-start"
641+
assert events[0]["id"]
642+
assert not events[0]["id"].startswith("chatcmpl")
638643
# At minimum, a text block with the accumulated answer.
639644
finishes = [e for e in events if e["event"] == "content-block-finish"]
640645
assert len(finishes) >= 1
641646
text_finishes = [f for f in finishes if f["content"]["type"] == "text"]
642647
assert len(text_finishes) == 1
643648

644649

650+
async def test_openai_astream_events_v3_lifecycle(mock_openai_completion: list) -> None:
    """Check the v3 event lifecycle of `astream_events` on the async client.

    Async twin of `test_openai_stream_events_v3_lifecycle`: the stream must be
    a valid event stream and finish exactly one text content block.
    """
    from langchain_tests.utils.stream_lifecycle import assert_valid_event_stream

    model = ChatOpenAI(model="gpt-4o", api_key=SecretStr("test"))
    fake_client = MagicMock()

    async def fake_create(*args: Any, **kwargs: Any) -> MockAsyncContextManager:
        return MockAsyncContextManager(mock_openai_completion)

    fake_client.create = fake_create
    with patch.object(model, "async_client", fake_client):
        event_stream = await model.astream_events("test", version="v3")
        events = [event async for event in event_stream]

    assert_valid_event_stream(events)
    # Short-circuit keeps `content` lookups limited to block-finish events.
    text_finishes = [
        event
        for event in events
        if event["event"] == "content-block-finish"
        and event["content"]["type"] == "text"
    ]
    assert len(text_finishes) == 1
668+
669+
645670
@pytest.fixture
646671
def mock_completion() -> dict:
647672
return {

0 commit comments

Comments
 (0)