"""
Measure signal-clipping in ultrasound pulse-echo data (RF or I/Q).
Uses time-domain amplitude thresholding.
"""
from enum import Enum
from typing import Any, Optional, Union
import equinox as eqx
import jax
import jax.numpy as jnp
from beartype import beartype as typechecker
from jaxtyping import Array, Bool, Num, Real, jaxtyped
# Helper type alias for plain Python scalar values (float or int), used in
# annotations alongside the jaxtyping array types.
[docs]
Scalar = Union[float, int]
[docs]
class ClipDetectMethod(str, Enum):
    """Names of the clipping detectors that `clip_ratio` can dispatch to.

    Members are also ``str`` instances, so a member compares equal to its
    lowercase value (``ClipDetectMethod.THRESHOLD == "threshold"``).
    """

    # Compare samples against caller-supplied amplitude limits.
    THRESHOLD = "threshold"

    # Flag samples that sit close to the observed min/max of each trace.
    MAX_AMPLITUDE = "max_amplitude"
@jaxtyped(typechecker=typechecker)
def clip_ratio(
    data: Num[Array, " *batch sample"],
    method: Union[ClipDetectMethod, str] = ClipDetectMethod.THRESHOLD,
    **kwargs: Any,
) -> float:
    """
    Return the fraction of samples in ``data`` that are flagged as clipped.

    The per-sample clip mask is produced by the detector selected through
    ``method`` and then averaged over every element of ``data``.

    Parameters
    ----------
    data
        Data array of shape (..., n_samples).
    method
        Which detector to run; either a ``ClipDetectMethod`` member or its
        string value (``"threshold"`` or ``"max_amplitude"``).
    **kwargs
        Forwarded unchanged to the selected detector
        (``is_clipped_threshold`` or ``is_clipped_max_amplitude``).

    Returns
    -------
    float
        Mean of the boolean clip mask, i.e. the ratio of clipped samples.

    Raises
    ------
    ValueError
        If ``method`` does not name a known detector.
    """
    # Compare with ``==`` so that plain strings match the str-valued enum members.
    if method == ClipDetectMethod.THRESHOLD:
        detector = is_clipped_threshold
    elif method == ClipDetectMethod.MAX_AMPLITUDE:
        detector = is_clipped_max_amplitude
    else:
        raise ValueError(f"Invalid clip detection method: {method}")

    clipped_mask = detector(data, **kwargs)
    return float(clipped_mask.mean())
@jaxtyped(typechecker=typechecker)
@eqx.filter_jit
def is_clipped_threshold(
    data: Num[Array, " *batch sample"],
    *,
    max_threshold: float,
    min_threshold: Optional[float] = None,
) -> Bool[Array, " *batch sample"]:
    """
    Flag samples whose amplitude falls outside fixed limits.

    Real-valued data is flagged where it is strictly below ``min_threshold``
    or strictly above ``max_threshold``. Complex-valued data is flagged where
    its magnitude is greater than or equal to ``max_threshold``;
    ``min_threshold`` is not used in that case.

    Parameters
    ----------
    data
        Data array of shape (..., n_samples).
    max_threshold
        Maximum allowed amplitude.
    min_threshold
        Minimum allowed amplitude for real data. If None, ``-max_threshold``
        is used (symmetric limits).

    Returns
    -------
    ndarray
        Boolean mask with shape=data.shape indicating which samples were clipped.
    """
    if not jnp.iscomplexobj(data):
        lower_limit = -max_threshold if min_threshold is None else min_threshold
        below_range = data < lower_limit
        above_range = data > max_threshold
        return below_range | above_range

    # Complex samples are judged by magnitude only.
    # NOTE(review): this comparison is inclusive (>=) while the real-data
    # comparisons above are exclusive -- confirm the asymmetry is intended.
    magnitude = jnp.abs(data)
    return magnitude >= max_threshold
@eqx.filter_jit
@jaxtyped(typechecker=typechecker)
def _is_clipped_max_amplitude_core(
    data_to_check: Real[Array, " sample"],
    *,
    low_factor: Optional[float],
    high_factor: Optional[float],
    min_contiguous: int,
) -> Bool[Array, " sample"]:
    """
    Detect clip candidates on one real-valued trace from its own min/max.

    A sample is a candidate when it lies within ``low_factor * range`` of the
    trace minimum or within ``high_factor * range`` of the trace maximum,
    where ``range = max - min``. The candidates are returned only if at least
    one run of ``min_contiguous`` consecutive candidates exists; in that case
    every candidate is returned, not just those inside the run. Otherwise the
    mask is all False.

    When the trace is constant the range is 0, so every sample meets the
    threshold of whichever side is active.

    Parameters
    ----------
    data_to_check
        Data array of shape (sample,).
    low_factor
        Fraction of the data range above the minimum that counts as clipped,
        or None to skip the low side.
    high_factor
        Fraction of the data range below the maximum that counts as clipped,
        or None to skip the high side.
    min_contiguous
        Minimum number of contiguous candidate samples required before the
        trace is considered clipped.

    Returns
    -------
    ndarray
        Boolean mask with shape=data_to_check.shape indicating which samples were clipped.
    """
    lowest = jnp.min(data_to_check)
    highest = jnp.max(data_to_check)
    span = highest - lowest

    # Accumulate candidates from whichever sides are enabled.
    candidates = jnp.zeros_like(data_to_check, dtype=bool)
    if low_factor is not None:
        candidates = candidates | (data_to_check <= lowest + (low_factor * span))
    if high_factor is not None:
        candidates = candidates | (data_to_check >= highest - (high_factor * span))

    if min_contiguous > 1:
        # A box filter of length min_contiguous reaches min_contiguous exactly
        # where that many consecutive candidates line up.
        window = jnp.ones(min_contiguous)
        run_counts = jnp.convolve(candidates, window, mode="same")
        run_found = jnp.any(run_counts >= min_contiguous)
    else:
        # With a run length of 1, every candidate counts as a clip on its own.
        run_found = jnp.array(True)

    return jnp.where(run_found, candidates, jnp.zeros_like(candidates))
@jaxtyped(typechecker=typechecker)
def is_clipped_max_amplitude(
    data: Num[Array, " *batch sample"],
    *,
    range_factor: Union[Scalar, tuple[Optional[Scalar], Optional[Scalar]]] = 0.01,
    min_contiguous: int = 1,
) -> Bool[Array, " *batch sample"]:
    """
    Detect clipped samples by looking for values near each trace's own min/max.

    Every time-series along the last axis is processed independently, so each
    one gets thresholds derived from its own data range.

    Parameters
    ----------
    data
        Data array of shape (..., n_samples). Complex data is converted to
        magnitude before checking.
    range_factor
        How far inside the data's min/max range to consider as clipping.
        Can be either:

        - scalar: for real data, a symmetric factor applied to both ends
          (e.g., 0.01 means within 1% of the data range from the min or max).
          For complex data, the factor applies to the upper end of the
          magnitude range only.
        - tuple: (lower, upper) factors for asymmetric thresholds. Use None
          for either value to disable detection on that side.

        Use range_factor when clipping is soft, or when there is some
        preprocessing (such as filtering or lossy compression) that softens
        hard-clipping.
    min_contiguous
        Minimum number of contiguous samples outside the range to consider
        that the signal was clipped. This is for guessing whether the signal
        was clipped, rather than just that the actual signal amplitude has
        some range.

    Returns
    -------
    ndarray
        Boolean mask with shape=data.shape indicating which samples were clipped.
        For zero-size input the mask is empty.

    Raises
    ------
    ValueError
        If ``min_contiguous`` is less than 1, if ``range_factor`` is neither a
        scalar nor a 2-tuple, if both sides of ``range_factor`` are None, or
        if either factor is negative.
    """
    if min_contiguous < 1:
        raise ValueError("min_contiguous must be at least 1 for clip detection")

    is_real = bool(jnp.isrealobj(data))

    # Normalise range_factor into a (low, high) pair of Optional[float].
    low_factor: Optional[float]
    high_factor: Optional[float]
    if isinstance(range_factor, (int, float)):
        factor = float(range_factor)
        # For complex data we check the magnitude, whose minimum (0) is not
        # clipping, so only the upper side is enabled.
        low_factor = factor if is_real else None
        high_factor = factor
    else:
        # Explicit check instead of `assert`, which is stripped under -O.
        if not isinstance(range_factor, tuple) or len(range_factor) != 2:
            raise ValueError("range_factor must be a scalar or a (lower, upper) tuple")
        low_factor, high_factor = (
            None if side is None else float(side) for side in range_factor
        )

    if low_factor is None and high_factor is None:
        raise ValueError("At least one side of range_factor must not be None to detect clips")
    if low_factor is not None and low_factor < 0:
        raise ValueError("low_factor must be non-negative")
    if high_factor is not None and high_factor < 0:
        raise ValueError("high_factor must be non-negative")

    original_shape = data.shape

    # The per-trace min/max reductions cannot run on an empty trace, so
    # return an empty mask directly for zero-size input.
    if data.size == 0:
        return jnp.zeros(original_shape, dtype=bool)

    # For complex data, we take the magnitude, assuming that clipping happens
    # before conversion from real -> complex (RF -> IQ).
    data_to_check: Real[Array, " *batch sample"] = data if is_real else jnp.abs(data)

    # Flatten all leading dimensions into one batch axis so a single vmap
    # covers every trace.
    reshaped_data: Real[Array, " batch_all sample"] = jax.lax.collapse(
        data_to_check, start_dimension=0, stop_dimension=-1
    )

    batch_result = jax.vmap(
        lambda trace: _is_clipped_max_amplitude_core(
            data_to_check=trace,
            low_factor=low_factor,
            high_factor=high_factor,
            min_contiguous=min_contiguous,
        )
    )(reshaped_data)

    # Restore the caller's original dimensions.
    return batch_result.reshape(original_shape)