feat: support stage config overrides via with_ #2086

Merged
praateekmahajan merged 8 commits into NVIDIA-NeMo:main from praateekmahajan:codex/stage-spec-overrides
Jun 23, 2026

@praateekmahajan

@praateekmahajan praateekmahajan commented Jun 16, 2026

Contributor
  1. Adds with_ overrides for ray_stage_spec, xenna_stage_spec, and num_workers. Stage spec overrides are shallow-merged, with user-provided keys winning; a typed unset sentinel preserves the difference between omitted and explicit None.
  2. Applies num_workers to Ray Data task stages via TaskPoolStrategy; actor stages continue to use ActorPoolStrategy.
  3. Lets Xenna use stage.num_workers() for cluster-wide worker counts, rejects xenna_stage_spec["num_workers"], and raises if a stage sets both num_workers() and Xenna num_workers_per_node.
  4. Reserves num_workers as the ProcessingStage backend worker hook and renames the ASR dataloader setting to dataloader_num_workers so stage-local worker counts do not shadow num_workers().
  5. Moves single-input/fanout defaults to num_workers() == 1 for file partitioning, URL generation, PDF partitioning, cluster-wise pairwise partitioning, audio manifest reader/writer, audio segment extraction, and ALM long-form manifest reading. PDF partitioning now also advertises itself as a Ray fanout stage.
  6. No deduplication/removal stage behavior was removed; the deduplication change is limited to semantic pairwise partitioning using generic num_workers() instead of Xenna-only num_workers_per_node.
  7. Keeps PyAnnote's existing xenna_num_workers constructor option as a compatibility alias, but routes it through num_workers() instead of xenna_stage_spec.
  8. Adds coverage for with_ backend overrides, reserved num_workers enforcement, Xenna worker fallback/rejection, Ray Data task worker sizing, and updated singleton-stage defaults.
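
For readers unfamiliar with the pattern in item 1, here is a minimal sketch of a shallow merge guarded by a typed unset sentinel. `ToyStage`, `UNSET`, and the spec keys are hypothetical stand-ins chosen for illustration; this is not the code in nemo_curator/stages/base.py, only one way the described behavior could look.

```python
from __future__ import annotations

import copy
from typing import Any, Final


class _UnsetType:
    """Sentinel type: distinguishes an omitted argument from an explicit None."""

    def __repr__(self) -> str:
        return "UNSET"


UNSET: Final = _UnsetType()


class ToyStage:
    """Hypothetical stand-in for a stage; not the real ProcessingStage."""

    def ray_stage_spec(self) -> dict[str, Any]:
        # Illustrative keys only.
        return {"is_actor_stage": False, "is_fanout_stage": True}

    def num_workers(self) -> int | None:
        return 1

    def with_(
        self,
        ray_stage_spec: dict[str, Any] | _UnsetType = UNSET,
        num_workers: int | None | _UnsetType = UNSET,
    ) -> ToyStage:
        new = copy.deepcopy(self)  # the original stage is left untouched
        if not isinstance(ray_stage_spec, _UnsetType):
            # Shallow merge: start from the current spec, user-provided keys win.
            merged = {**self.ray_stage_spec(), **ray_stage_spec}
            new.ray_stage_spec = lambda: dict(merged)
        if not isinstance(num_workers, _UnsetType):
            # An explicit None is a real override; an omitted argument is not.
            new.num_workers = lambda: num_workers
        return new
```

With this sketch, `ToyStage().with_(num_workers=None).num_workers()` returns None, while `ToyStage().with_().num_workers()` keeps the default of 1, which is the omitted-versus-None distinction the sentinel exists to preserve.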

Signed-off-by: Praateek <praateekm@gmail.com>
@copy-pr-bot

copy-pr-bot Bot commented Jun 16, 2026


This pull request requires additional validation before any workflows can run on NVIDIA's runners.

Pull request vetters can view their responsibilities here.

Contributors can view more details about this message here.

@praateekmahajan praateekmahajan changed the title feat: add stage worker override controls feat: support *_stage_spec and num_workers being overridden with_ Jun 16, 2026
…e-spec-overrides

Signed-off-by: Praateek <praateekm@gmail.com>
@praateekmahajan praateekmahajan marked this pull request as ready for review June 22, 2026 16:59
@praateekmahajan

Contributor Author

@claude review

@greptile-apps

greptile-apps Bot commented Jun 22, 2026

Contributor

Greptile Summary

This PR unifies backend worker sizing under a single num_workers() method and a with_() override API across all three backends (Ray Data, Ray Actor, Xenna), replacing the previous per-backend xenna_stage_spec()["num_workers"] pattern. It also promotes several fanout/source stages from per-node to cluster-wide singleton defaults.

  • with_() overrides: ProcessingStage.with_() gains ray_stage_spec, xenna_stage_spec, and num_workers parameters with shallow-merge semantics; a typed _UnsetType sentinel distinguishes "not passed" from explicit None.
  • Backend routing: XennaExecutor now reads stage.num_workers() for StageSpec.num_workers and rejects "num_workers" in xenna_stage_spec(); RayDataStageAdapter applies TaskPoolStrategy(size=num_workers) for task stages instead of silently ignoring the value.
  • Field rename in ASR stages: num_workers dataclass field renamed to dataloader_num_workers in BaseASRProcessorStage, NeMoASRAlignerStage, and SplitASRAlignJoinStage to avoid conflict with the new guard that blocks num_workers as a class-level attribute.
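
The Xenna routing rules summarized above can be sketched as a small pure-Python function. `resolve_xenna_workers` is a hypothetical helper written for this illustration, and the error messages are invented; the conflict rule (raising when both num_workers() and num_workers_per_node are set) follows the PR description rather than the actual executor code.

```python
from __future__ import annotations

from typing import Any


def resolve_xenna_workers(
    xenna_stage_spec: dict[str, Any], num_workers: int | None
) -> dict[str, Any]:
    """Hypothetical helper: combine a stage's Xenna spec with its num_workers()."""
    if "num_workers" in xenna_stage_spec:
        # Worker counts must come from num_workers(), never from the spec dict.
        raise ValueError(
            "Set worker counts via num_workers(), not xenna_stage_spec['num_workers']."
        )
    per_node = xenna_stage_spec.get("num_workers_per_node")
    if num_workers is not None and per_node is not None:
        # Cluster-wide and per-node counts are treated as mutually exclusive here.
        raise ValueError("num_workers() and num_workers_per_node cannot both be set.")
    return {
        **xenna_stage_spec,
        "num_workers": num_workers,
        "num_workers_per_node": per_node,
    }
```

Under these assumptions, a stage with `num_workers() == 4` and an empty spec resolves to `{"num_workers": 4, "num_workers_per_node": None}`, matching the last step of the sequence diagram in this summary.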

Confidence Score: 4/5

Safe to merge after fixing the tutorial notebook that now uses the rejected xenna_stage_spec["num_workers"] pattern.

The new XennaExecutor validation that rejects "num_workers" in xenna_stage_spec() is correct and well-tested, but tutorials/synthetic/nemo_data_designer/ndd_data_generation_example.ipynb still monkey-patches xenna_stage_spec to return {"num_workers": 2}, which will raise ValueError when the tutorial is executed against a Xenna backend. All other changes — the with_() merge API, TaskPoolStrategy wiring, field rename in ASR stages — are internally consistent and covered by tests.

tutorials/synthetic/nemo_data_designer/ndd_data_generation_example.ipynb needs to be updated to ndd_stage = ndd_stage.with_(num_workers=2) before this PR ships.

Important Files Changed

| Filename | Overview |
| --- | --- |
| nemo_curator/stages/base.py | Core change: adds with_ overrides for ray_stage_spec, xenna_stage_spec, and num_workers, plus a `__init_subclass__` guard that blocks num_workers as a class attribute; guard has a narrow gap for num_workers = None without annotation. |
| nemo_curator/backends/xenna/executor.py | Routes num_workers() to StageSpec.num_workers and rejects num_workers inside xenna_stage_spec(); logic is clean but the validation now breaks the tutorial notebook that monkey-patches xenna_stage_spec. |
| nemo_curator/backends/ray_data/adapter.py | Replaces the ignore-with-warning path for task-stage num_workers with TaskPoolStrategy(size=num_workers); straightforward and covered by tests. |
| nemo_curator/stages/audio/tagging/inference/nemo_asr_align.py | Renames num_workers dataclass field to dataloader_num_workers to avoid conflict with the new num_workers() sizing method; all internal references updated correctly. |
| tutorials/synthetic/nemo_data_designer/ndd_data_generation_example.ipynb | Still uses the now-rejected pattern of setting num_workers inside xenna_stage_spec via monkey-patching; will raise ValueError at runtime when executed against Xenna. |
| tests/backends/test_xenna_executor.py | New test file covering num_workers routing, num_workers_per_node coexistence, and rejection of num_workers in xenna_stage_spec; good coverage of the new validation paths. |
| tests/stages/common/test_base.py | Adds tests for with_ backend spec overrides (merge semantics, chaining, None override, rejection of num_workers in xenna_stage_spec) and the new num_workers attribute guard. |

Sequence Diagram

%%{init: {'theme': 'neutral'}}%%
sequenceDiagram
    participant User
    participant Stage as ProcessingStage
    participant With as with_()
    participant Ray as RayDataStageAdapter
    participant Xenna as XennaExecutor

    User->>Stage: "stage.with_(num_workers=4, ray_stage_spec={...})"
    Stage->>With: deepcopy(self)
    With->>With: merge ray_stage_spec (user keys win)
    With->>With: "set num_workers = _num_workers_method(4)"
    With-->>User: new_instance

    User->>Ray: process_dataset(dataset)
    Ray->>Stage: stage.num_workers() → 4
    Ray->>Ray: "TaskPoolStrategy(size=4) for task stage"
    Ray->>Ray: ActorPoolStrategy for actor stage

    User->>Xenna: execute([stage])
    Xenna->>Stage: stage.xenna_stage_spec()
    Xenna->>Xenna: reject if num_workers in spec
    Xenna->>Stage: stage.num_workers() → 4
    Xenna->>Xenna: "StageSpec(num_workers=4, num_workers_per_node=None)"

Reviews (6): Last reviewed commit: "Merge remote-tracking branch 'upstream/m..."

Comment thread nemo_curator/stages/base.py Outdated

@VibhuJawa VibhuJawa left a comment

Contributor

LGTM

Signed-off-by: Praateek <praateekm@gmail.com>
Signed-off-by: Praateek <praateekm@gmail.com>
@praateekmahajan praateekmahajan changed the title feat: support *_stage_spec and num_workers being overridden with_ feat: support stage config overrides via with_ Jun 22, 2026
…errides

Signed-off-by: Praateek <praateekm@gmail.com>
@praateekmahajan

Contributor Author

/ok to test 4f5d147

@weijiac0619

Contributor

Just out of curiosity, what happened before this PR, when num_workers() == 1 didn't apply to Ray Data for single-input/fanout stages?

def xenna_stage_spec(self) -> dict[str, Any]:
    return {"num_workers_per_node": 1}

def ray_stage_spec(self) -> dict[str, object]:
    return {RayStageSpecKeys.IS_FANOUT_STAGE: True}

Contributor

Quick question: do we now set IS_FANOUT_STAGE manually in the pipeline? Is there any logic we should follow when setting IS_FANOUT_STAGE? Did we miss setting it to true for PDF partitioning before this PR?

Contributor Author

Yes, we missed it before this PR. I know @sarahyurick is working with an external contributor to make it dynamic, but I haven't taken a close look at it.

Contributor

Does setting IS_FANOUT_STAGE make that stage much faster? Are there any benchmark results you could share?

@praateekmahajan

Contributor Author

@weijiac0619 Ray Data starts leniently from 1 "concurrent" worker anyway, so we never had this problem with single-input stages on Ray Data; only Xenna had the issue.

…errides

Signed-off-by: Praateek <praateekm@gmail.com>

# Conflicts:
#	nemo_curator/stages/interleaved/pdf/nemotron_parse/partitioning.py
#	tests/stages/interleaved/pdf/nemotron_parse/test_stages.py
@praateekmahajan

Contributor Author

/ok to test c06bcda
