feat: add dtype option to LightningModule.load_from_checkpoint#21791
feat: add dtype option to LightningModule.load_from_checkpoint#21791gaurav0107 wants to merge 2 commits into
dtype option to LightningModule.load_from_checkpoint#21791Conversation
Adds an optional `dtype` argument to `LightningModule.load_from_checkpoint` (plumbed through `_load_from_checkpoint`) that casts the restored model's floating-point parameters and buffers to the requested dtype at load time, mirroring how `map_location` controls device placement. Defaults to None, which preserves the existing behavior. Non-floating-point tensors are left unchanged by `torch.nn.Module.to`. Fixes Lightning-AI#20833
|
Codecov Report✅ All modified and coverable lines are covered by tests.
Additional details and impacted files@@ Coverage Diff @@
## master #21791 +/- ##
=========================================
- Coverage 87% 79% -8%
=========================================
Files 270 267 -3
Lines 24010 23953 -57
=========================================
- Hits 20787 18832 -1955
- Misses 3223 5121 +1898 |
deependujha
left a comment
There was a problem hiding this comment.
I'm not very sure if it has any real UX improvement: check comment
|
Thanks for taking a look, and fair points — let me answer honestly rather than defend the diff as-is. You're right on two counts. Terminology: this is a dtype/precision cast, not quantization. And on the training path it's genuinely redundant — Where I think a case remains is the narrow no-Trainer inference/sampling path, which is what @arijit-hub originally asked for in #20833: load a checkpoint straight into a target dtype instead of a manual post-load Being candid about the strongest honest version of that, and where the current diff falls short:
So, back to you: would you be open to it scoped strictly to the no-Trainer inference path, implemented as a load-time cast (real memory win) with a docstring/example spelling out that use-case and pointing training users to |
|
@gaurav0107 thanks for your patience. Went through the code quickly and it makes sense to have it. Thanks again! |
What does this PR do?
Adds an optional
dtypeargument toLightningModule.load_from_checkpoint(plumbed through the internal
_load_from_checkpoint) so that the restoredmodel's floating-point parameters and buffers can be cast to a target
torch.dtypeat load time — for example loading an fp32/bf16-trained modeldirectly as
float16for faster inference/sampling, without a separate manual.to(dtype)call. This mirrors how the existingmap_locationargument controlsdevice placement.
The change is strictly opt-in:
dtypedefaults toNone, which preserves the current behavior exactly.torch.nn.Module.to(device=..., dtype=...),which only casts floating-point tensors and leaves integer buffers unchanged.
LightningModuleonly;LightningDataModulecarries notensors to cast.
Downcasting (e.g.
float32->float16) reduces numerical precision, so thebehavior is documented in the docstring and left entirely to the user's choice.
Fixes #20833
Before submitting
dtypeoption forload_from_checkpoint#20833.test_load_from_checkpoint_dtype, CPU-only)ruff check/ruff format --checkpass locally; the full pytest suite was not run in my environment (no GPU/torch dev install) and is left to CI.None.PR review
Anyone in the community is welcome to review the PR.
Reviewer checklist