Skip to content

feat(ops): add keras.ops.image.euclidean_dist_transform#23120

Open
rstar327 wants to merge 1 commit into
keras-team:masterfrom
rstar327:add-euclidean-dist-transform
Open

feat(ops): add keras.ops.image.euclidean_dist_transform#23120
rstar327 wants to merge 1 commit into
keras-team:masterfrom
rstar327:add-euclidean-dist-transform

Conversation

@rstar327

Copy link
Copy Markdown
Contributor

Description

Adds keras.ops.image.euclidean_dist_transform — for each non-background pixel of a binary image, computes the Euclidean distance to the nearest background (zero-valued) pixel. Each channel is processed independently. Closes #22038.

This is commonly needed for distance-based loss functions on segmentation masks (boundary loss, signed distance regression, weighted CE) and for distance-aware visualization.

Signature

keras.ops.image.euclidean_dist_transform(
    images,           # 3D (H, W, C) or 4D (N, H, W, C) integer tensor
    dtype="float32",  # output float dtype
    data_format=None, # "channels_last" | "channels_first" | None -> config
)

Behavior

  • Returns a tensor of the same shape as images with the requested float dtype.
  • Per-channel: distance from each non-zero pixel to the nearest zero pixel; zero pixels have distance 0.
  • Implementation reuses scipy.ndimage.distance_transform_edt (already a Keras dependency) through each backend's appropriate callback mechanism:
    • tensorflow: tf.numpy_function
    • jax: jax.pure_callback
    • torch / numpy: direct numpy call on device-host buffer
    • openvino: raises NotImplementedError (matches the sobel_edges precedent)

API additions

  • keras.ops.image.euclidean_dist_transform

Tests

EuclideanDistTransformTest in keras/src/ops/image_test.py covers:

  • 3D and 4D inputs, channels_first and channels_last
  • Numeric agreement with the scipy reference on a randomized input
  • Closed-form check on a single-background-pixel image (distance equals geometric distance to that pixel)
  • Dtype handling (float32 default, float16 explicit)
  • Input validation (rejects float input, rejects rank-2 input)
  • Symbolic call shape/dtype propagation

All 9 tests pass on tensorflow, jax, torch, numpy backends locally. openvino is covered in CI (raises with a clear message and the test class skips).

Contributor Agreement

Please review our AI-Assisted Contribution Policy and check all boxes below before submitting your PR for review:

  • I am a human, and not a bot.
  • I will be responsible for responding to review comments in a timely manner.
  • I will work with the maintainers to push this PR forward until submission.

Note: Failing to adhere to this agreement may result in your future PRs no longer being reviewed.

@gemini-code-assist gemini-code-assist Bot left a comment

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request introduces the euclidean_dist_transform operation to compute the Euclidean distance transform of binary images, with implementations provided for the JAX, NumPy, TensorFlow, and PyTorch backends, alongside comprehensive unit tests. The review feedback recommends adding integer dtype validation within the symbolic compute_output_spec method to catch invalid inputs early, and refactoring the NumPy backend implementation to use the top-level scipy import and np.squeeze for better consistency across backends.

Important

The consumer version of Gemini Code Assist on GitHub is being sunset. Starting June 18, 2026, new organization installations will be blocked, and all code review activity will officially cease on July 17, 2026.
For more details on the timeline and next steps, please review the Help Documentation.

Comment thread keras/src/ops/image.py
Comment on lines +2424 to +2432
def compute_output_spec(self, images):
images_shape = list(images.shape)
if len(images_shape) not in (3, 4):
raise ValueError(
"Invalid images rank: expected rank 3 (single image) "
"or rank 4 (batch of images). "
f"Received: images.shape={images_shape}"
)
return KerasTensor(shape=images_shape, dtype=self.dtype)

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The compute_output_spec method of EuclideanDistTransform should validate that the input images has an integer dtype. This ensures that invalid inputs (like float tensors) are caught early during symbolic tracing/model building, adhering to the Keras API guidelines of catching user errors as early as possible.

    def compute_output_spec(self, images):
        images_shape = list(images.shape)
        if len(images_shape) not in (3, 4):
            raise ValueError(
                "Invalid images rank: expected rank 3 (single image) "
                "or rank 4 (batch of images). "
                f"Received: images.shape={images_shape}"
            )
        if not backend.is_int_dtype(images.dtype):
            raise TypeError(
                "`euclidean_dist_transform` expects an integer-dtype input. "
                f"Received: images.dtype={images.dtype}"
            )
        return KerasTensor(shape=images_shape, dtype=self.dtype)
References
  1. Catch user errors early and anticipate common mistakes. Do user input validation as soon as possible. (link)

Comment on lines +1239 to +1264
def euclidean_dist_transform(images, dtype="float32", data_format=None):
from scipy import ndimage

data_format = backend.standardize_data_format(data_format)
images = convert_to_tensor(images)
_validate_edt_input(images)

unbatched = images.ndim == 3
if unbatched:
images = images[np.newaxis, ...]
if data_format == "channels_first":
images = np.transpose(images, (0, 2, 3, 1))

np_dtype = backend.standardize_dtype(dtype)
out = np.empty(images.shape, dtype=np_dtype)
for b in range(images.shape[0]):
for c in range(images.shape[-1]):
out[b, :, :, c] = ndimage.distance_transform_edt(
images[b, :, :, c] != 0
)

if data_format == "channels_first":
out = np.transpose(out, (0, 3, 1, 2))
if unbatched:
out = out[0]
return out

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

In the NumPy backend, scipy is already imported at the top level from keras.src.utils.module_utils. We should use scipy.ndimage directly instead of locally importing from scipy import ndimage to leverage the lazy-loading wrapper and avoid redundant imports. Additionally, using np.squeeze(out, axis=0) instead of out[0] is more consistent with the JAX, TensorFlow, and PyTorch backend implementations.

def euclidean_dist_transform(images, dtype="float32", data_format=None):
    data_format = backend.standardize_data_format(data_format)
    images = convert_to_tensor(images)
    _validate_edt_input(images)

    unbatched = images.ndim == 3
    if unbatched:
        images = images[np.newaxis, ...]
    if data_format == "channels_first":
        images = np.transpose(images, (0, 2, 3, 1))

    np_dtype = backend.standardize_dtype(dtype)
    out = np.empty(images.shape, dtype=np_dtype)
    for b in range(images.shape[0]):
        for c in range(images.shape[-1]):
            out[b, :, :, c] = scipy.ndimage.distance_transform_edt(
                images[b, :, :, c] != 0
            )

    if data_format == "channels_first":
        out = np.transpose(out, (0, 3, 1, 2))
    if unbatched:
        out = np.squeeze(out, axis=0)
    return out

@codecov-commenter

codecov-commenter commented Jun 20, 2026

Copy link
Copy Markdown

Codecov Report

❌ Patch coverage is 97.61905% with 3 lines in your changes missing coverage. Please review.
✅ Project coverage is 84.07%. Comparing base (e0bfa2d) to head (4d664a0).

Files with missing lines Patch % Lines
keras/src/ops/image.py 82.35% 2 Missing and 1 partial ⚠️
Additional details and impacted files
@@            Coverage Diff             @@
##           master   #23120      +/-   ##
==========================================
- Coverage   84.69%   84.07%   -0.63%     
==========================================
  Files         464      464              
  Lines       68957    69083     +126     
  Branches    11331    11365      +34     
==========================================
- Hits        58404    58082     -322     
- Misses       7624     8082     +458     
+ Partials     2929     2919      -10     
Flag Coverage Δ
keras 83.89% <97.61%> (-0.61%) ⬇️
keras-cpu 83.89% <97.61%> (+0.02%) ⬆️
keras-gpu ?
keras-jax 57.92% <36.50%> (-0.25%) ⬇️
keras-numpy 53.68% <33.33%> (-0.04%) ⬇️
keras-openvino 59.41% <7.93%> (-0.10%) ⬇️
keras-tensorflow 59.48% <36.50%> (-0.33%) ⬇️
keras-torch 58.60% <33.33%> (-0.44%) ⬇️
keras-tpu ?

Flags with carried forward coverage won't be shown. Click here to find out more.

☔ View full report in Codecov by Harness.
📢 Have feedback on the report? Share it here.

🚀 New features to boost your workflow:
  • ❄️ Test Analytics: Detect flaky tests, report on failures, and find test suite problems.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

Projects

None yet

Development

Successfully merging this pull request may close these issues.

Add euclidean_dist_transform

3 participants