Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
148 changes: 143 additions & 5 deletions cosmos/operators/_watcher/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
get_dbt_event_xcom_key,
get_status_xcom_key,
get_xcom_val,
is_dbt_node_status_failed,
is_dbt_node_status_skipped,
is_dbt_node_status_success,
is_dbt_node_status_terminal,
Expand Down Expand Up @@ -101,6 +102,12 @@


def _process_dbt_log_event(task_instance: Any, dbt_log: dict[str, Any]) -> None:
"""Push a single dbt log event straight to XCom (one write per event, overwriting).

Retained for callers that do not supply a ``node_event_buffer`` (e.g. the Kubernetes watcher),
which keep the original behaviour. The local watcher uses the buffered path
(``_accumulate_dbt_log_event`` + ``_flush_dbt_event``) instead.
"""
logger.debug("dbt_log: %s", dbt_log)
data = dbt_log.get("data", {})
info = dbt_log.get("info", {})
Expand Down Expand Up @@ -129,6 +136,129 @@ def _process_dbt_log_event(task_instance: Any, dbt_log: dict[str, Any]) -> None:
safe_xcom_push(task_instance=task_instance, key=get_dbt_event_xcom_key(unique_id), value=dbt_event)


# Guards the in-memory per-node event buffer. dbt runner callbacks fire from multiple threads, so
# accumulating into the buffer (and flushing it) must be serialised.
_node_event_buffer_lock = threading.Lock()


def _new_buffer_entry() -> dict[str, Any]:
return {"status": None, "start_time": None, "finish_time": None, "msg": None, "_has_error_msg": False}


def _accumulate_dbt_log_event(node_event_buffer: dict[str, dict[str, Any]], dbt_log: dict[str, Any]) -> None:
"""Merge a single dbt JSON log event into the producer's in-memory per-node buffer.

Called for every allowlisted dbt event during the producer run. Nothing is written to XCom here;
``_flush_dbt_event`` writes the merged event when the node is terminal (and again if a later error
event adds the message). This avoids a per-event XCom write (and global-lock contention) for every
event of every node.

The error text lives in different places per invocation mode: ``run_result.message`` on
``NodeFinished`` (dbt-runner mode) and ``data.msg`` on ``RunResultError`` (subprocess mode). We read
both, and let an error message win so it is not overwritten by a later, often empty, message.
"""
info = dbt_log.get("info", {})
event_name = info.get("name")
if event_name not in _DBT_EVENT_ALLOWLIST:
return
data = dbt_log.get("data", {})
node_info = data.get("node_info") or {}
unique_id = node_info.get("unique_id")
if not unique_id:
return

status = node_info.get("node_status")
start_time = node_info.get("node_started_at")
finish_time = node_info.get("node_finished_at")
run_result = data.get("run_result") or {}
msg = run_result.get("message") or data.get("msg") or info.get("msg")
is_error = (
event_name in _DBT_ERROR_EVENTS_TYPES
or is_dbt_node_status_failed(status)
or is_dbt_node_status_failed(run_result.get("status"))
)

with _node_event_buffer_lock:
entry = node_event_buffer.setdefault(unique_id, _new_buffer_entry())
# Some non-lifecycle events (e.g. RunResultError) carry the literal string "None" as
# node_status; ignore it so it cannot clobber the real terminal status.
if status and status != "None":
entry["status"] = status
if start_time:
entry["start_time"] = _iso_to_string(start_time)
if finish_time:
entry["finish_time"] = _iso_to_string(finish_time)
if msg:
# An error message wins and must not be overwritten by a later (often empty or generic)
# message; a non-error message only fills an as-yet-empty slot.
if is_error:
entry["msg"] = msg
entry["_has_error_msg"] = True
elif not entry["_has_error_msg"]:
entry["msg"] = msg


def _flush_dbt_event(
task_instance: Any,
node_event_buffer: dict[str, dict[str, Any]],
unique_id: str,
terminal_status: str | None = None,
) -> None:
"""Write a node's buffered event to XCom.

Called when the node reaches a terminal status (with ``terminal_status`` so the authoritative status
is stamped -- subprocess's status-bearing ``LogModelResult`` is not allowlisted, so the entry may
otherwise lack one) and again if a later error event adds the message. The entry is intentionally
not removed: it may be re-flushed, and the whole buffer is cleared at the start of ``execute``.
"""
with _node_event_buffer_lock:
entry = node_event_buffer.setdefault(unique_id, _new_buffer_entry())
if terminal_status and terminal_status != "None":
entry["status"] = terminal_status
value = {key: val for key, val in entry.items() if key != "_has_error_msg"}
safe_xcom_push(task_instance=task_instance, key=get_dbt_event_xcom_key(unique_id), value=value)


def _flush_dbt_event_on_error(
task_instance: Any, node_event_buffer: dict[str, dict[str, Any]], dbt_log: dict[str, Any]
) -> None:
"""Re-flush a node when an allowlisted error event carrying a message arrives.

Subprocess mode emits the error text in a ``RunResultError`` event that arrives *after* the terminal
status event already flushed, so this second flush delivers that message to the consumer.
"""
info = dbt_log.get("info", {})
if info.get("name") not in _DBT_ERROR_EVENTS_TYPES:
return
data = dbt_log.get("data", {})
# Mirror the message sources used by _accumulate_dbt_log_event (run_result.message / data.msg /
# info.msg) so an error whose text is only in info.msg still triggers the re-flush.
if not ((data.get("run_result") or {}).get("message") or data.get("msg") or info.get("msg")):
return
unique_id = (data.get("node_info") or {}).get("unique_id")
if unique_id:
_flush_dbt_event(task_instance, node_event_buffer, unique_id)


def _record_dbt_log_event(
node_event_buffer: dict[str, dict[str, Any]] | None, context: Any, dbt_log: dict[str, Any]
) -> None:
"""Record a dbt log event, choosing the buffered path or the original per-event push.

Local watcher (``node_event_buffer`` provided): accumulate in memory; the merged event is flushed
to XCom when the node is terminal, and again here if this is an error event carrying the message
(subprocess emits it after the terminal event). Otherwise (e.g. the Kubernetes watcher): push per event.
"""
ti = context.get("ti") if context else None
if node_event_buffer is not None:
_accumulate_dbt_log_event(node_event_buffer, dbt_log)
if ti is not None:
_flush_dbt_event_on_error(ti, node_event_buffer, dbt_log)
return
if ti:
_process_dbt_log_event(ti, dbt_log)


def _extract_compiled_sql(
project_dir: str, unique_id: str, node_path: str | None, resource_type: str | None
) -> str | None:
Expand Down Expand Up @@ -303,6 +433,7 @@ def store_dbt_resource_status_from_log(
dataset_namespace: str | None = None,
should_generate_model_uris: bool = True,
upstream_failure_skipped_ids: set[str] | None = None,
node_event_buffer: dict[str, dict[str, Any]] | None = None,
) -> None:
"""
Parses a single line from dbt JSON logs and stores node status to Airflow XCom.
Expand Down Expand Up @@ -333,19 +464,21 @@ def store_dbt_resource_status_from_log(
``NodeFinished`` events with ``node_status="skipped"`` for these unique_ids
are rewritten to ``"failed"`` so the consumer sensor fails (and Airflow can
retry it) rather than going SKIPPED. See #2698.
:param node_event_buffer: Mutable in-memory accumulator keyed by node unique_id. Each dbt event
is merged here (no XCom write); the merged event is flushed to XCom when the node reaches a
terminal state and may be re-flushed if a later error event provides the message (e.g. subprocess mode).
This replaces the previous per-event XCom push, which wrote on every event (lock contention) and let a trailing ``NodeFinished`` overwrite the captured error message.
"""
# extra_kwargs is typed Any and may be falsy/None; normalise so the .get() calls below are safe.
extra_kwargs = extra_kwargs or {}
try:
log_line = json.loads(line)
context = extra_kwargs.get("context") if extra_kwargs else None
ti = context.get("ti") if context else None

if ti:
_process_dbt_log_event(ti, log_line)
except json.JSONDecodeError:
_surface_non_json_stdout(line)
log_line = {}
else:
context = extra_kwargs.get("context")
_record_dbt_log_event(node_event_buffer, context, log_line)
if context is not None:
_store_startup_event_from_log(context["ti"], log_line)
Comment thread
pankajastro marked this conversation as resolved.
node_info = log_line.get("data", {}).get("node_info", {})
Expand Down Expand Up @@ -403,6 +536,11 @@ def store_dbt_resource_status_from_log(
context["ti"], project_dir, unique_id, node_info.get("node_path"), node_info.get("resource_type")
)

# Flush this node's buffered structured event now that it is terminal, stamping the
# authoritative terminal status (subprocess's status-bearing event is not allowlisted).
if node_event_buffer is not None and unique_id:
_flush_dbt_event(context["ti"], node_event_buffer, unique_id, terminal_status=dbt_node_status)

# Additionally, log the message from dbt logs
_log_dbt_msg(log_line)

Expand Down
6 changes: 6 additions & 0 deletions cosmos/operators/watcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -242,6 +242,10 @@ def __init__(self, *args: Any, **kwargs: Any) -> None:
# Mutable dict populated lazily from the manifest; shared with the log parser.
self._dataset_namespace: str | None = None
self._model_outlet_uris: dict[str, list[str]] = {}
# In-memory per-node event buffer shared with the log parser. dbt events accumulate here and
# are flushed to XCom at terminal (and may be re-flushed if a later error event carries the message),
# instead of one XCom write per event.
Comment thread
pankajastro marked this conversation as resolved.
self._node_event_buffer: dict[str, dict[str, Any]] = {}
# Mutable set populated by the log parser when dbt emits SkippingDetails
# or LogSkipBecauseError for a node; subsequent "skipped" terminal events
# for those unique_ids are rewritten to "failed" so the consumer sensor
Expand All @@ -262,6 +266,7 @@ def _make_parse_callable(self) -> Callable[[str, Any], None]:
dataset_namespace=self._dataset_namespace,
should_generate_model_uris=self._should_generate_model_uris,
upstream_failure_skipped_ids=self._upstream_failure_skipped_ids,
node_event_buffer=self._node_event_buffer,
)

def run_subprocess(self, command: list[str], env: dict[str, str], cwd: str, **kwargs: Any) -> Any:
Expand Down Expand Up @@ -468,6 +473,7 @@ def execute(self, context: Context, **kwargs: Any) -> Any:
)
self._model_outlet_uris.clear()
self._upstream_failure_skipped_ids.clear()
self._node_event_buffer.clear()

task_instance = context.get("ti")
if task_instance is None:
Expand Down
120 changes: 119 additions & 1 deletion tests/operators/_watcher/test_watcher_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,14 @@
import pytest
from airflow.exceptions import AirflowException, AirflowSkipException

from cosmos.operators._watcher.base import BaseConsumerSensor, _process_dbt_log_event
from cosmos.operators._watcher.base import (
BaseConsumerSensor,
_accumulate_dbt_log_event,
_flush_dbt_event,
_flush_dbt_event_on_error,
_process_dbt_log_event,
store_dbt_resource_status_from_log,
)
from cosmos.operators.local import DbtRunLocalOperator


Expand Down Expand Up @@ -274,3 +281,114 @@ def test_no_status_increments_poke_retry(self):

assert result is False
assert sensor.poke_retry_number == 1


class TestNodeEventBuffer:
"""Tests for the in-memory per-node event buffer (local watcher log shipping)."""

UID = "model.pkg.my_model"

def _runner_node_finished(self, node_status="error", message="DB Error: table missing"):
"""dbt-runner mode: NodeFinished carries the error in run_result.message."""
return {
"info": {"name": "NodeFinished", "msg": ""},
"data": {
"node_info": {
"unique_id": self.UID,
"node_status": node_status,
"node_started_at": "2024-01-01T00:00:00",
"node_finished_at": "2024-01-01T00:01:00",
},
"run_result": {"status": node_status, "message": message},
},
}

def _run_result_error(self, msg="DB Error: table missing"):
"""subprocess mode: the error arrives in a RunResultError event (node_status='None')."""
return {
"info": {"name": "RunResultError"},
"data": {"node_info": {"unique_id": self.UID, "node_status": "None"}, "msg": msg},
}

def _flushed_value(self, buffer, **kwargs):
with patch("cosmos.operators._watcher.base.safe_xcom_push") as mock_push:
_flush_dbt_event(Mock(), buffer, self.UID, **kwargs)
return mock_push.call_args.kwargs["value"]

def test_runner_mode_captures_error_from_run_result_message(self):
"""dbt-runner mode: the error in NodeFinished.run_result.message reaches the consumer."""
buffer: dict = {}
_accumulate_dbt_log_event(buffer, self._runner_node_finished(message="DB Error: table missing"))
value = self._flushed_value(buffer, terminal_status="error")
assert value["status"] == "error"
assert value["msg"] == "DB Error: table missing"
assert value["start_time"] == "00:00:00" and value["finish_time"] == "00:01:00"
assert "_has_error_msg" not in value # internal flag stripped before push

def test_subprocess_mode_captures_error_from_run_result_error_event(self):
"""subprocess mode: terminal status flushes first; the later RunResultError re-flush adds the msg."""
buffer: dict = {}
# Terminal status event (LogModelResult-equivalent) flushes first -- status stamped, no msg yet.
first = self._flushed_value(buffer, terminal_status="error")
assert first["status"] == "error" and first["msg"] is None
# RunResultError arrives afterwards with the error text.
err = self._run_result_error(msg="DB Error: table missing")
_accumulate_dbt_log_event(buffer, err)
with patch("cosmos.operators._watcher.base.safe_xcom_push") as mock_push:
_flush_dbt_event_on_error(Mock(), buffer, err)
value = mock_push.call_args.kwargs["value"]
assert value["status"] == "error" # not clobbered to "None"
assert value["msg"] == "DB Error: table missing"

def test_error_message_not_overwritten_by_later_empty_message(self):
buffer: dict = {}
_accumulate_dbt_log_event(buffer, self._run_result_error(msg="real error"))
_accumulate_dbt_log_event(buffer, self._runner_node_finished(message="")) # later, empty
value = self._flushed_value(buffer, terminal_status="error")
assert value["msg"] == "real error"

def test_none_string_status_is_ignored(self):
buffer: dict = {}
_accumulate_dbt_log_event(buffer, self._run_result_error(msg="err")) # node_status == "None"
assert buffer[self.UID]["status"] is None

@pytest.mark.parametrize("extra_kwargs", [None, {}])
def test_store_dbt_resource_status_from_log_tolerates_falsy_extra_kwargs(self, extra_kwargs):
"""A falsy/None extra_kwargs must not raise during parsing (no context to push to)."""
line = '{"info": {"name": "NodeFinished"}, "data": {"node_info": {"unique_id": "model.p.m", "node_status": "success"}}}'
# Should be a no-op (no context/ti) rather than raising AttributeError.
store_dbt_resource_status_from_log(line, extra_kwargs, node_event_buffer={})

def test_accumulate_ignores_non_allowlisted_and_missing_unique_id(self):
buffer: dict = {}
_accumulate_dbt_log_event(
buffer, {"info": {"name": "LogStartLine"}, "data": {"node_info": {"unique_id": self.UID}}}
)
_accumulate_dbt_log_event(buffer, {"info": {"name": "NodeFinished"}, "data": {"node_info": {}}})
assert buffer == {}

def test_flush_on_error_ignores_non_error_and_messageless_events(self):
buffer: dict = {}
with patch("cosmos.operators._watcher.base.safe_xcom_push") as mock_push:
# Not an error event -> no flush.
_flush_dbt_event_on_error(
Mock(), buffer, {"info": {"name": "NodeFinished"}, "data": {"node_info": {"unique_id": self.UID}}}
)
# Error event but no message -> no flush.
_flush_dbt_event_on_error(
Mock(), buffer, {"info": {"name": "RunResultError"}, "data": {"node_info": {"unique_id": self.UID}}}
)
mock_push.assert_not_called()

def test_flush_on_error_when_message_only_in_info_msg(self):
"""An allowlisted error event whose text is only in info.msg must still re-flush."""
buffer: dict = {}
event = {
"info": {"name": "RunResultError", "msg": "the error"},
"data": {"node_info": {"unique_id": self.UID, "node_status": "None"}},
}
_accumulate_dbt_log_event(buffer, event)
with patch("cosmos.operators._watcher.base.safe_xcom_push") as mock_push:
_flush_dbt_event_on_error(Mock(), buffer, event)
assert mock_push.called
assert mock_push.call_args.kwargs["value"]["msg"] == "the error"
Loading