feat: add dtype option to LightningModule.load_from_checkpoint #21791

Open
gaurav0107 wants to merge 2 commits into Lightning-AI:master from gaurav0107:fix/20833-add-a-dtype-option-for-load-from-checkpo
Conversation

@gaurav0107

Contributor

What does this PR do?

Adds an optional dtype argument to LightningModule.load_from_checkpoint
(plumbed through the internal _load_from_checkpoint) so that the restored
model's floating-point parameters and buffers can be cast to a target
torch.dtype at load time — for example loading an fp32/bf16-trained model
directly as float16 for faster inference/sampling, without a separate manual
.to(dtype) call. This mirrors how the existing map_location argument controls
device placement.
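To make the intended equivalence concrete, here is a minimal plain-PyTorch sketch. `load_tiny_model` is a hypothetical stand-in written for this illustration; it is not Lightning's `_load_from_checkpoint`, and the tiny `nn.Linear` model and file name are assumptions.

```python
import os
import tempfile
from typing import Optional

import torch
from torch import nn


def load_tiny_model(path: str, map_location=None, dtype: Optional[torch.dtype] = None) -> nn.Module:
    """Hypothetical loader: restore weights, then optionally cast them."""
    model = nn.Linear(4, 2)
    model.load_state_dict(torch.load(path, map_location=map_location))
    if dtype is not None:  # dtype=None keeps the checkpoint's dtype
        model = model.to(dtype=dtype)
    return model


with tempfile.TemporaryDirectory() as tmp:
    ckpt = os.path.join(tmp, "tiny.pt")
    torch.save(nn.Linear(4, 2).state_dict(), ckpt)

    default_model = load_tiny_model(ckpt)  # stays float32
    half_model = load_tiny_model(ckpt, map_location="cpu", dtype=torch.float16)
    manual_model = load_tiny_model(ckpt).to(torch.float16)  # the manual equivalent
```

In this sketch the `dtype=` path and the manual `.to(torch.float16)` path produce identical weights, which is the behavior the description claims for the default post-load cast.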

The change is strictly opt-in:

  • dtype defaults to None, which preserves the current behavior exactly.
  • When provided, the cast is applied via torch.nn.Module.to(device=..., dtype=...),
    which casts only floating-point (and complex) parameters and buffers and leaves integer buffers unchanged.
  • It is added for LightningModule only; LightningDataModule carries no
    tensors to cast.
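The casting rule in the second bullet can be checked with plain PyTorch. The sketch below uses a hypothetical `TinyModel` (not part of this PR) with one integer buffer and one floating-point buffer.

```python
import torch
from torch import nn


class TinyModel(nn.Module):
    """Hypothetical stand-in for a restored module."""

    def __init__(self) -> None:
        super().__init__()
        self.linear = nn.Linear(4, 2)  # float32 parameters
        # Integer buffer (e.g. a step counter): expected to keep its dtype.
        self.register_buffer("steps", torch.zeros(1, dtype=torch.long))
        # Floating-point buffer (e.g. a running statistic): expected to be cast.
        self.register_buffer("running_mean", torch.zeros(2))


model = TinyModel().to(device="cpu", dtype=torch.float16)

param_dtypes = {name: p.dtype for name, p in model.named_parameters()}
buffer_dtypes = {name: b.dtype for name, b in model.named_buffers()}
```

After the call, the parameters and `running_mean` are `torch.float16`, while `steps` remains `torch.int64`.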

Downcasting (e.g. float32 -> float16) reduces numerical precision, so the
behavior is documented in the docstring and left entirely to the user's choice.
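As a small illustration of that precision loss (the three values are arbitrary and chosen only for this sketch), casting float32 to float16 can round, underflow, or overflow:

```python
import torch

# 0.1 is rounded, 1e-8 is below float16's smallest subnormal,
# and 70000 exceeds float16's largest finite value (65504).
values = torch.tensor([0.1, 1e-8, 70000.0], dtype=torch.float32)
half = values.to(torch.float16)

# Cast back to float32 to compare against the originals.
round_trip = half.to(torch.float32)
abs_error = (round_trip - values).abs()
```

Here `half[1]` becomes zero and `half[2]` becomes `inf`, so whether a downcast is acceptable depends on the value ranges in the model.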

Fixes #20833

Before submitting
  • Was this discussed/agreed via a GitHub issue? Yes — requested in Add a dtype option for load_from_checkpoint #20833.
  • Did you read the contributor guideline, Pull Request section?
  • Did you make sure your PR does only one thing, instead of bundling different changes together?
  • Did you make sure to update the documentation with your changes? (docstring updated)
  • Did you write any new necessary tests? (test_load_from_checkpoint_dtype, CPU-only)
  • Did you verify new and existing tests pass locally with your changes? ruff check/ruff format --check pass locally; the full pytest suite was not run in my environment (no GPU/torch dev install) and is left to CI.
  • Did you list all the breaking changes introduced by this pull request? None — additive, default None.
  • Did you update the CHANGELOG?

PR review

Anyone in the community is welcome to review the PR.

Reviewer checklist
  • Is this pull request ready for review? (if not, please submit in draft mode)
  • Check that all items from Before submitting are resolved
  • Make sure the title is self-explanatory and the description concisely explains the PR
  • Add labels and milestones (and optionally projects) to the PR so it can be classified

Adds an optional `dtype` argument to `LightningModule.load_from_checkpoint`
(plumbed through `_load_from_checkpoint`) that casts the restored model's
floating-point parameters and buffers to the requested dtype at load time,
mirroring how `map_location` controls device placement. Defaults to None,
which preserves the existing behavior. Non-floating-point tensors are left
unchanged by `torch.nn.Module.to`.

Fixes Lightning-AI#20833

gaurav0107 marked this pull request as ready for review June 28, 2026 21:54
@codecov-commenter

codecov-commenter commented Jun 28, 2026

⚠️ Please install the Codecov app to ensure uploads and comments are reliably processed by Codecov.

Codecov Report

✅ All modified and coverable lines are covered by tests.
✅ Project coverage is 79%. Comparing base (4819088) to head (413ce07).
✅ All tests successful. No failed tests found.
❗ Your organization needs to install the Codecov GitHub app to enable full functionality.

❗ There is a different number of reports uploaded between BASE (4819088) and HEAD (413ce07). Click for more details.

HEAD has 178 uploads less than BASE
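| Flag | BASE (4819088) | HEAD (413ce07) |
| --- | --- | --- |
| python3.10 | 6 | 3 |
| cpu | 84 | 40 |
| lightning | 30 | 15 |
| pytest | 42 | 0 |
| python | 6 | 3 |
| lightning_fabric | 27 | 0 |
| python3.12 | 24 | 11 |
| python3.12.7 | 18 | 9 |
| python3.11 | 12 | 6 |
| python3.13 | 18 | 8 |
| pytorch2.8 | 6 | 5 |
| pytorch_lightning | 27 | 25 |
| pytest-full | 42 | 40 |
| pytorch2.9 | 6 | 5 |
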
Additional details and impacted files
@@            Coverage Diff            @@
##           master   #21791     +/-   ##
=========================================
- Coverage      87%      79%     -8%     
=========================================
  Files         270      267      -3     
  Lines       24010    23953     -57     
=========================================
- Hits        20787    18832   -1955     
- Misses       3223     5121   +1898     

@deependujha deependujha left a comment

Collaborator

I'm not very sure if it has any real UX improvement: check comment

@gaurav0107

Contributor Author

Thanks for taking a look, and fair points — let me answer honestly rather than defend the diff as-is.

You're right on two counts. Terminology: this is a dtype/precision cast, not quantization. And on the training path it's genuinely redundant — Trainer(precision=...) owns dtype/autocast there, and I wouldn't want a second, differently-scoped way to set that.

Where I think a case remains is the narrow no-Trainer inference/sampling path, which is what @arijit-hub originally asked for in #20833: load a checkpoint straight into a target dtype instead of a manual post-load .to(dtype).

Being candid about the strongest honest version of that, and where the current diff falls short:

  • As pure convenience you're correct — dtype= vs .to(dtype) is the same one-liner, and that alone isn't worth a new argument.
  • The real motivation is peak memory: for a large model, materializing every param in the checkpoint's fp32 and then casting keeps a full fp32 copy alive at the moment of the cast, whereas casting weights as they're loaded never holds the full-precision model. For big checkpoints that's the difference between fitting and OOMing, and a post-hoc .to(dtype) can't give it to you.
  • But this diff casts after load (model.to(device, dtype)), so as written it does NOT yet deliver that — it's exactly the one-liner you describe. If we keep it, I'd move the cast into the load path so it actually earns its place.
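
A rough sketch of what a load-time cast could look like in plain PyTorch follows. `cast_state_dict_` is a hypothetical helper invented for this illustration and is not what the diff implements. The in-memory `state_dict` stands in for the result of `torch.load`, and how much peak memory is actually saved depends on how the checkpoint is read from disk.

```python
import torch
from torch import nn


def cast_state_dict_(state_dict: dict, dtype: torch.dtype) -> dict:
    """Hypothetical helper: cast floating-point entries one tensor at a time.

    Each entry is replaced in place, so the dict does not hold every tensor
    in both precisions at once. Non-floating-point entries are left untouched.
    """
    for name in list(state_dict):
        tensor = state_dict[name]
        if torch.is_tensor(tensor) and tensor.is_floating_point():
            state_dict[name] = tensor.to(dtype)
    return state_dict


# Stand-in for a checkpoint: float32 weights plus BatchNorm's int64 counter.
source = nn.Sequential(nn.Linear(8, 8), nn.BatchNorm1d(8))
state_dict = cast_state_dict_(source.state_dict(), torch.float16)

# Build the target directly in float16, so no float32 copy of it is created.
target = nn.Sequential(
    nn.Linear(8, 8, dtype=torch.float16),
    nn.BatchNorm1d(8, dtype=torch.float16),
)
target.load_state_dict(state_dict)
```

In this toy case the original float32 tensors stay alive because `source` still owns them; with a real checkpoint the loaded dict would be the only owner, so each full-precision tensor could be released as soon as it is replaced.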

So, back to you: would you be open to it scoped strictly to the no-Trainer inference path, implemented as a load-time cast (real memory win) with a docstring/example spelling out that use-case and pointing training users to Trainer(precision=...)? Or if a third dtype entry point still isn't worth the API surface even then, I'm happy to close it — genuinely your call either way.

@deependujha

Collaborator

@gaurav0107 thanks for your patience. Went through the code quickly and it makes sense to have it. Thanks again!
