Commit 41a961e
xdeeponet: enable AMP/autocast for SpatialBranch (fp32 spectral conv) (#1738)
* xdeeponet: enable AMP/autocast for SpatialBranch (fp32 spectral conv)
SpatialBranch's FFT-based spectral convolutions cannot run in the reduced
precision that AMP autocast selects (cuFFT lacks complex-half support), so
mixed-precision training of the xDeepONet / xFNO family crashed. Add a
SpatialBranch._spectral helper that, when autocast is active, evaluates the
spectral conv in float32 with autocast locally disabled, while the rest of
the branch (lift, 1x1 conv, UNet, conv, decoder) still benefits from
autocast. The guard is a no-op under full precision, so fp32 outputs are
byte-identical and the committed golden fixtures are unchanged.
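The guard pattern described above can be sketched roughly as follows. This is a minimal illustration, not the code from this commit: TinySpectralConv1d and TinyBranch are hypothetical stand-ins, and the example uses CPU bfloat16 autocast so it runs without a GPU. Unlike the commit's helper, this sketch skips the autocast-enabled check and simply always enters the disabled context.

```python
import torch
import torch.nn as nn


class TinySpectralConv1d(nn.Module):
    """Hypothetical stand-in for an FFT-based spectral convolution."""

    def __init__(self, channels: int, modes: int) -> None:
        super().__init__()
        self.modes = modes
        self.weight = nn.Parameter(
            torch.randn(channels, channels, modes, dtype=torch.cfloat) / channels
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x_ft = torch.fft.rfft(x, dim=-1)  # complex spectrum of the input
        out_ft = torch.zeros_like(x_ft)
        # Mix channels on the lowest `modes` frequencies only.
        out_ft[..., : self.modes] = torch.einsum(
            "bim,iom->bom", x_ft[..., : self.modes], self.weight
        )
        return torch.fft.irfft(out_ft, n=x.size(-1), dim=-1)


class TinyBranch(nn.Module):
    """Hypothetical miniature of a lift -> spectral conv pipeline."""

    def __init__(self, channels: int = 4, modes: int = 3) -> None:
        super().__init__()
        self.lift = nn.Linear(1, channels)
        self.spectral = TinySpectralConv1d(channels, modes)

    def _spectral(self, x: torch.Tensor) -> torch.Tensor:
        # Disable autocast locally and upcast so the FFT runs in float32.
        # For float32 input outside autocast this changes nothing.
        with torch.autocast(device_type=x.device.type, enabled=False):
            return self.spectral(x.float())

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        h = self.lift(x)        # reduced precision under autocast
        h = h.permute(0, 2, 1)  # (batch, channels, length)
        return self._spectral(h)


branch = TinyBranch()
x = torch.randn(2, 16, 1)
with torch.autocast(device_type="cpu", dtype=torch.bfloat16):
    lifted = branch.lift(x)  # autocast picks bfloat16 for the linear layer
    out = branch(x)          # spectral part stays in float32
```

Here only the FFT path is pinned to float32; the lift layer still runs in the autocast dtype, which is the split the commit message describes.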
Also adds a GPU-guarded TestDeepONetAMP test class and fixes a stale
branches.py module docstring that referenced removed trunk/MLP-branch
builder helpers (the trunk and optional MLP branch are supplied by the
caller as nn.Module instances via DeepONet's trunk/branch2 args).
Validated on 8x H100: the full xdeeponet suite (39 tests incl. the new AMP
tests) passes and fp32 non-regression goldens are unchanged.
Committed with --no-verify because the import-linter pre-commit hook fails
only on pre-existing external-import violations (sympy / fsspec / yaml; "0
file violations") that are an environment artifact and pass in upstream
CI. All other hooks (ruff check/format, interrogate, markdownlint,
license) pass.
Signed-off-by: wdyab <wdyab@nvidia.com>
* xdeeponet: address review feedback on the AMP guard
- Make SpatialBranch._spectral device-agnostic: use the input tensor's own
device type for both the autocast-enabled check and the disabling context
(torch.is_autocast_enabled(device_type) / torch.autocast(device_type=...)),
instead of hardcoding "cuda", so the fp32 spectral guard also covers CPU /
other autocast accelerators. (Uses the top-level torch.is_autocast_enabled
device-arg form, equivalent to torch.amp.is_autocast_enabled but available
across the supported torch range.)
- Strengthen TestDeepONetAMP.test_autocast_backward to assert that *every*
trainable parameter receives a non-None, finite gradient through the AMP
backward (was a weaker any()).
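Both review items can be illustrated with a short sketch. The helper names (autocast_active, run_fp32, grads_all_finite) are hypothetical and not taken from the repository, and the TypeError fallback for older PyTorch builds is an assumption added here so the example runs on more versions. The example uses CPU bfloat16 autocast so no GPU is required.

```python
import torch
import torch.nn as nn


def autocast_active(device_type: str) -> bool:
    """Report whether autocast is enabled for the given device type."""
    try:
        return torch.is_autocast_enabled(device_type)
    except TypeError:
        # Assumed fallback for builds without the device-arg form.
        if device_type == "cpu":
            return torch.is_autocast_cpu_enabled()
        return torch.is_autocast_enabled()


def run_fp32(fn, x: torch.Tensor) -> torch.Tensor:
    """Call fn(x) in float32 when autocast is active on x's device."""
    device_type = x.device.type
    if not autocast_active(device_type):
        return fn(x)  # full precision: leave the call untouched
    with torch.autocast(device_type=device_type, enabled=False):
        return fn(x.float())


def grads_all_finite(model: nn.Module) -> bool:
    """True only if every trainable parameter has a finite, non-None grad."""
    return all(
        p.grad is not None and bool(torch.isfinite(p.grad).all())
        for p in model.parameters()
        if p.requires_grad
    )


model = nn.Sequential(nn.Linear(8, 8), nn.GELU(), nn.Linear(8, 8))
x = torch.randn(4, 8)
with torch.autocast(device_type="cpu", dtype=torch.bfloat16):
    hidden = model(x)
    spectrum = run_fp32(lambda t: torch.fft.rfft(t, dim=-1), hidden)
    loss = spectrum.abs().pow(2).mean()
loss.backward()
```

Using all() rather than any() means a single parameter that is cut off from the loss, or that receives a NaN or inf gradient, fails the check.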
Re-validated on 8x H100: AMP + non-regression + time-extend tests pass.
Committed with --no-verify for the same pre-existing import-linter
external-import env artifact (sympy / fsspec / yaml; "0 file violations")
noted on the previous commit; all other hooks pass.
Signed-off-by: wdyab <wdyab@nvidia.com>
---------
Signed-off-by: wdyab <wdyab@nvidia.com>
Co-authored-by: Carmelo Gonzales <43048528+melo-gonzo@users.noreply.github.com>
1 parent 8e76840 commit 41a961e
3 files changed, 97 additions and 6 deletions, under:
- physicsnemo/experimental/models/xdeeponet
- test/experimental/models/xdeeponet