Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
150 changes: 128 additions & 22 deletions docs/source-pytorch/accelerators/accelerator_prepare.rst
Original file line number Diff line number Diff line change
Expand Up @@ -77,12 +77,27 @@ See :ref:`replace-sampler-ddp` for more information.
Synchronize validation and test logging
***************************************

When running in distributed mode, we have to ensure that the validation and test step logging calls are synchronized across processes.
This is done by adding ``sync_dist=True`` to all ``self.log`` calls in the validation and test step. This will automatically average values across all processes.
This ensures that each GPU worker has the same behaviour when tracking model checkpoints, which is important for later downstream tasks such as testing the best checkpoint across all workers.
The ``sync_dist`` option can also be used in logging calls during the step methods, but be aware that this can lead to significant communication overhead and slow down your training.
When running in distributed mode, each rank runs ``validation_step`` and ``test_step`` on its own
shard of the data. Without explicit synchronization, the value your logger persists is rank 0's
local value — computed on just ``1 / world_size`` of the validation or test set. That is the
metric your :class:`~lightning.pytorch.callbacks.ModelCheckpoint` and
:class:`~lightning.pytorch.callbacks.EarlyStopping` callbacks see, so an unsynchronized metric
can silently pick the wrong checkpoint.

Note if you use any built in metrics or custom metrics that use `TorchMetrics <https://torchmetrics.readthedocs.io/>`_, these do not need to be updated and are automatically handled for you.
Lightning gives you three tools to fix this, and they are **not interchangeable**:

- ``sync_dist=True`` — mean-reduces a scalar across ranks. Correct only for averageable metrics.
- `TorchMetrics <https://torchmetrics.readthedocs.io/>`__ — syncs the metric's internal *state*, then computes. Correct for non-averageable metrics such as F1 or AUC.
- :meth:`~lightning.pytorch.core.LightningModule.all_gather` — gathers raw tensors across ranks so you can compute any reduction yourself.

Pick the lightest tool that fits the metric. If you accumulate per-step outputs and compute a
custom metric in ``on_validation_epoch_end`` (or ``on_test_epoch_end``), jump to
:ref:`manual-all-gather` — that is the pattern most DDP custom-metric questions come down to.

``sync_dist=True``
==================

The simplest option. Lightning mean-reduces each logged value across all ranks before storing it.

.. testcode::

Expand All @@ -101,31 +116,122 @@ Note if you use any built in metrics or custom metrics that use `TorchMetrics <h
# Add sync_dist=True to sync logging across all GPU workers (may have performance impact)
self.log("test_loss", loss, on_step=True, on_epoch=True, sync_dist=True)

It is possible to perform some computation manually and log the reduced result on rank 0 as follows:
The ``sync_dist`` option can also be used in logging calls during the training step, but be aware
that this can lead to significant communication overhead and slow down your training.

.. warning::
``sync_dist=True`` averages per-rank *values*. It is only correct when
``mean(per_rank_metric) == global_metric``. It is **wrong** for F1, AUC, and precision or
recall on imbalanced classes — the mean of per-rank F1 scores is not the global F1 score.
For those metrics, reach for TorchMetrics instead.

TorchMetrics
============

`TorchMetrics <https://torchmetrics.readthedocs.io/>`__ handles the non-averageable case by
syncing the metric's internal *state* (for example, the running counts of true and false
positives) across ranks, then computing the metric from the merged state. The result matches
what you would get by evaluating on one rank with the full dataset. No ``sync_dist`` flag is
needed; the metric synchronizes itself when it is logged.

.. code-block:: python

def __init__(self):
super().__init__()
self.outputs = []
from torchmetrics.classification import BinaryF1Score


def test_step(self, batch, batch_idx):
x, y = batch
tensors = self(x)
self.outputs.append(tensors)
return tensors
class LitModel(LightningModule):
def __init__(self):
super().__init__()
self.val_f1 = BinaryF1Score()


def validation_step(self, batch, batch_idx):
x, y = batch
logits = self(x)
self.val_f1.update(logits, y)
# Passing the metric object to self.log triggers DDP sync at epoch end.
self.log("val_f1", self.val_f1, on_epoch=True)

This is the recommended option for any classification, retrieval, or ranking metric.

.. _manual-all-gather:

Manual ``all_gather``
=====================

def on_test_epoch_end(self):
mean = torch.mean(self.all_gather(self.outputs))
self.outputs.clear() # free memory
Use this when your metric is a custom computation over outputs accumulated across the whole
epoch — the case where neither ``sync_dist`` nor TorchMetrics fits. The pattern is: accumulate
per-step outputs into a list on the module, then at epoch end call
:meth:`~lightning.pytorch.core.LightningModule.all_gather` to combine each rank's contributions
before computing the metric. ``all_gather`` returns a tensor of shape
``[world_size, *tensor_shape]`` and every rank receives the same result.

# When you call `self.log` only on rank 0, don't forget to add
# `rank_zero_only=True` to avoid deadlocks on synchronization.
# Caveat: monitoring this is unimplemented, see https://github.com/Lightning-AI/pytorch-lightning/issues/15852
if self.trainer.is_global_zero:
self.log("my_reduced_metric", mean, rank_zero_only=True)
.. code-block:: python

class LitModel(LightningModule):
def __init__(self):
super().__init__()
self.val_outputs = []


def validation_step(self, batch, batch_idx):
x, y = batch
predictions = self(x)
self.val_outputs.append(predictions)
return predictions


def on_validation_epoch_end(self):
# self.all_gather returns a tensor of shape [world_size, *tensor_shape] on every rank.
gathered = self.all_gather(self.val_outputs)
metric = my_custom_metric(gathered)
self.val_outputs.clear() # free memory before the next epoch

# When you call `self.log` only on rank 0, don't forget to add
# `rank_zero_only=True` to avoid deadlocks on synchronization.
# Caveat: monitoring this is unimplemented, see https://github.com/Lightning-AI/pytorch-lightning/issues/15852
if self.trainer.is_global_zero:
self.log("my_custom_val_metric", metric, rank_zero_only=True)

The same pattern applies to ``test_step`` / ``on_test_epoch_end``.

A common source of confusion here is that ``on_validation_epoch_end`` runs on every rank, so at
first glance the metric looks like it is being computed ``world_size`` times. After
``all_gather`` every rank already holds the *same* gathered tensor, so every rank computes the
*same* value — the redundant work is cheap and the result is correct. The ``is_global_zero``
guard belongs around ``self.log``, not around the computation. Never guard ``all_gather``
itself with ``is_global_zero``: it is a collective, and if some ranks skip it the program will
hang.

Which one should I use?
=======================

.. list-table::
:header-rows: 1
:widths: 45 55

* - Metric
- Use
* - Averageable scalar (loss, accuracy, MSE)
- ``sync_dist=True``
* - Classification or ranking metric (F1, AUC, precision, recall)
- TorchMetrics
* - Custom reduction over gathered tensors
- ``self.all_gather()``

Common pitfalls
===============

- **Using** ``sync_dist=True`` **on a non-averageable metric.** The logged value is the mean of
per-rank metrics, which is not the global metric. Use TorchMetrics instead.
- **Guarding** ``all_gather`` **with** ``is_global_zero``. Collectives must be called on every
rank. Put the guard around ``self.log``, not around the gather.
- **Passing** ``rank_zero_only=True`` **to** ``self.log`` **without synchronizing first.** Rank 0
logs its local value only, which is the ``1 / world_size`` problem this section opens with.

See also: the `TorchMetrics distributed evaluation guide
<https://lightning.ai/docs/torchmetrics/stable/pages/overview.html#metrics-and-distributed-training-ddp>`_
for how TorchMetrics synchronizes state internally.


----
Expand Down
Loading