feat(ops): add keras.ops.image.euclidean_dist_transform#23120
Conversation
There was a problem hiding this comment.
Code Review
This pull request introduces the euclidean_dist_transform operation to compute the Euclidean distance transform of binary images, with implementations provided for the JAX, NumPy, TensorFlow, and PyTorch backends, alongside comprehensive unit tests. The review feedback recommends adding integer dtype validation within the symbolic compute_output_spec method to catch invalid inputs early, and refactoring the NumPy backend implementation to use the top-level scipy import and np.squeeze for better consistency across backends.
Important
The consumer version of Gemini Code Assist on GitHub is being sunset. Starting June 18, 2026, new organization installations will be blocked, and all code review activity will officially cease on July 17, 2026.
For more details on the timeline and next steps, please review the Help Documentation.
| def compute_output_spec(self, images): | ||
| images_shape = list(images.shape) | ||
| if len(images_shape) not in (3, 4): | ||
| raise ValueError( | ||
| "Invalid images rank: expected rank 3 (single image) " | ||
| "or rank 4 (batch of images). " | ||
| f"Received: images.shape={images_shape}" | ||
| ) | ||
| return KerasTensor(shape=images_shape, dtype=self.dtype) |
There was a problem hiding this comment.
The compute_output_spec method of EuclideanDistTransform should validate that the input images has an integer dtype. This ensures that invalid inputs (like float tensors) are caught early during symbolic tracing/model building, adhering to the Keras API guidelines of catching user errors as early as possible.
def compute_output_spec(self, images):
images_shape = list(images.shape)
if len(images_shape) not in (3, 4):
raise ValueError(
"Invalid images rank: expected rank 3 (single image) "
"or rank 4 (batch of images). "
f"Received: images.shape={images_shape}"
)
if not backend.is_int_dtype(images.dtype):
raise TypeError(
"`euclidean_dist_transform` expects an integer-dtype input. "
f"Received: images.dtype={images.dtype}"
)
return KerasTensor(shape=images_shape, dtype=self.dtype)References
- Catch user errors early and anticipate common mistakes. Do user input validation as soon as possible. (link)
| def euclidean_dist_transform(images, dtype="float32", data_format=None): | ||
| from scipy import ndimage | ||
|
|
||
| data_format = backend.standardize_data_format(data_format) | ||
| images = convert_to_tensor(images) | ||
| _validate_edt_input(images) | ||
|
|
||
| unbatched = images.ndim == 3 | ||
| if unbatched: | ||
| images = images[np.newaxis, ...] | ||
| if data_format == "channels_first": | ||
| images = np.transpose(images, (0, 2, 3, 1)) | ||
|
|
||
| np_dtype = backend.standardize_dtype(dtype) | ||
| out = np.empty(images.shape, dtype=np_dtype) | ||
| for b in range(images.shape[0]): | ||
| for c in range(images.shape[-1]): | ||
| out[b, :, :, c] = ndimage.distance_transform_edt( | ||
| images[b, :, :, c] != 0 | ||
| ) | ||
|
|
||
| if data_format == "channels_first": | ||
| out = np.transpose(out, (0, 3, 1, 2)) | ||
| if unbatched: | ||
| out = out[0] | ||
| return out |
There was a problem hiding this comment.
In the NumPy backend, scipy is already imported at the top level from keras.src.utils.module_utils. We should use scipy.ndimage directly instead of locally importing from scipy import ndimage to leverage the lazy-loading wrapper and avoid redundant imports. Additionally, using np.squeeze(out, axis=0) instead of out[0] is more consistent with the JAX, TensorFlow, and PyTorch backend implementations.
def euclidean_dist_transform(images, dtype="float32", data_format=None):
data_format = backend.standardize_data_format(data_format)
images = convert_to_tensor(images)
_validate_edt_input(images)
unbatched = images.ndim == 3
if unbatched:
images = images[np.newaxis, ...]
if data_format == "channels_first":
images = np.transpose(images, (0, 2, 3, 1))
np_dtype = backend.standardize_dtype(dtype)
out = np.empty(images.shape, dtype=np_dtype)
for b in range(images.shape[0]):
for c in range(images.shape[-1]):
out[b, :, :, c] = scipy.ndimage.distance_transform_edt(
images[b, :, :, c] != 0
)
if data_format == "channels_first":
out = np.transpose(out, (0, 3, 1, 2))
if unbatched:
out = np.squeeze(out, axis=0)
return out
Codecov Report❌ Patch coverage is
Additional details and impacted files@@ Coverage Diff @@
## master #23120 +/- ##
==========================================
- Coverage 84.69% 84.07% -0.63%
==========================================
Files 464 464
Lines 68957 69083 +126
Branches 11331 11365 +34
==========================================
- Hits 58404 58082 -322
- Misses 7624 8082 +458
+ Partials 2929 2919 -10
Flags with carried forward coverage won't be shown. Click here to find out more. ☔ View full report in Codecov by Harness. 🚀 New features to boost your workflow:
|
Description
Adds
keras.ops.image.euclidean_dist_transform— for each non-background pixel of a binary image, computes the Euclidean distance to the nearest background (zero-valued) pixel. Each channel is processed independently. Closes #22038.This is commonly needed for distance-based loss functions on segmentation masks (boundary loss, signed distance regression, weighted CE) and for distance-aware visualization.
Signature
Behavior
imageswith the requested float dtype.scipy.ndimage.distance_transform_edt(already a Keras dependency) through each backend's appropriate callback mechanism:tf.numpy_functionjax.pure_callbackNotImplementedError(matches thesobel_edgesprecedent)API additions
keras.ops.image.euclidean_dist_transformTests
EuclideanDistTransformTestinkeras/src/ops/image_test.pycovers:channels_firstandchannels_lastfloat32default,float16explicit)All 9 tests pass on tensorflow, jax, torch, numpy backends locally. openvino is covered in CI (raises with a clear message and the test class skips).
Contributor Agreement
Please review our AI-Assisted Contribution Policy and check all boxes below before submitting your PR for review:
Note: Failing to adhere to this agreement may result in your future PRs no longer being reviewed.