# Array API Design Principles

*This page is for developers who want to understand the implementation details or contribute to ultrasound-metrics. For user-focused information, see [Multi-Backend Support](multi_backend_support.md).*

This page explains the technical design principles behind ultrasound-metrics' multi-backend support and how it enables consistent behavior across different array computation libraries.

## The Array API Standard

The [Array API Standard](https://data-apis.org/array-api/) provides a unified interface for array operations across NumPy, JAX, PyTorch, CuPy, and other libraries. It focuses on a core set of operations that work consistently across backends.

### Key Principles

- **Unified Interface**: The same function signatures work across all supported backends
- **Type Preservation**: The input array type determines the output array type (NumPy in, NumPy out)
- **Functional Design**: Avoids relying on in-place mutation, which immutable backends such as JAX do not support
- **Consistent Behavior**: Standardized broadcasting and type-promotion rules

### Coverage

The standard specifies the most common array operations:

- **Core functions**: `asarray`, `reshape`, `matrix_transpose`, arithmetic operations
- **Mathematical functions**: `sin`, `cos`, `exp`, `log`, statistical functions
- **Indexing and slicing**: Slicing, integer indexing, boolean masking
- and [more](https://data-apis.org/array-api/latest/API_specification/index.html)

Optional [extensions](https://data-apis.org/array-api/latest/extensions/index.html) include:

- **Linear algebra** (`xp.linalg`): norms, decompositions, and other matrix routines
- **FFT** (`xp.fft`): forward and inverse transforms

You can read the official [Array API standard](https://data-apis.org/array-api/latest/index.html)
for more detail, but hopefully the above information
gives you enough context to review `ultrasound-metrics` code.

## Design Patterns in `ultrasound-metrics`

- **Single Implementation**: Write metric code once and it works across all backends
- **Type Safety**: Runtime checking with `jaxtyping` catches shape/dtype mismatches
- **Performance**: JIT compilation and GPU acceleration are available whenever the backend provides them
- **Future-Proof**: New Array API-compliant backends work without code changes

### 1. Array Namespace Detection

Use [`array_api_compat.array_namespace`](https://data-apis.org/array-api-compat/helper-functions.html#array_api_compat.array_namespace) to detect the input array library and dispatch to the correct operations.

```python
from array_api_compat import array_namespace

def coherence_factor(channel_images):
    # Automatically detect the array library
    xp = array_namespace(channel_images)

    # Use detected namespace for all operations
    coherent_sum = xp.sum(channel_images, axis=0)
    power_sum = xp.sum(xp.abs(channel_images)**2, axis=0)
    return xp.abs(coherent_sum)**2 / (power_sum * channel_images.shape[0])
```

### 2. Type Annotations with jaxtyping

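Use `jaxtyping`, with `beartype` as the run-time checker, to validate `.shape` and `.dtype`:

```python
from array_api_compat import array_namespace
from beartype import beartype
from jaxtyping import Complex, Float, jaxtyped

# `ArrayAPIObj` is assumed to be a type alias for the supported array
# types, defined elsewhere in the package (not shown here).

@jaxtyped(typechecker=beartype)  # check type-hints at run-time
def coherence_factor(
    # annotate arguments and return-values for documentation
    channel_images: Complex[ArrayAPIObj, "receive_elements *img_dims"],
) -> Float[ArrayAPIObj, "*img_dims"]:
    # Shapes and dtypes are validated on each call and on return
    xp = array_namespace(channel_images)
    coherent_sum = xp.sum(channel_images, axis=0)
    power_sum = xp.sum(xp.abs(channel_images) ** 2, axis=0)
    return xp.abs(coherent_sum) ** 2 / (power_sum * channel_images.shape[0])
```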

We currently use the separate `jaxtyping` library for annotations; in the future, the
[array-api-typing](https://github.com/data-apis/array-api-typing) project may provide
Array API type hints out of the box.

### 3. Handling Missing Functions
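When the Array API standard doesn't provide a needed operation:

```python
from array_api_compat import array_namespace, device

def custom_nanmax(x, axis=None):
    xp = array_namespace(x)

    # Prefer the backend-specific implementation when the namespace has one
    if hasattr(xp, 'nanmax'):
        return xp.nanmax(x, axis=axis)

    # Fallback: implement using Array API primitives.
    # Replace NaNs with -inf (matching x's dtype and device), then take the max.
    # Note: all-NaN slices give -inf here, whereas NumPy's nanmax gives NaN.
    mask = xp.isnan(x)
    neg_inf = xp.asarray(-xp.inf, dtype=x.dtype, device=device(x))
    return xp.max(xp.where(mask, neg_inf, x), axis=axis)
```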

If the function cannot easily be built from Array API primitives
but the underlying libraries implement it, use conditional backend-specific code:

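```python
import numpy as np
from array_api_compat import is_jax_array, is_numpy_array, is_torch_array

def histogram(x, bins, range=None, weights=None, density=False):
    # Use optimized backend-specific implementations when available
    if is_numpy_array(x):
        return np.histogram(x, bins=bins, range=range, weights=weights, density=density)
    elif is_jax_array(x):
        import jax.numpy as jnp
        return jnp.histogram(x, bins=bins, range=range, weights=weights, density=density)
    elif is_torch_array(x):
        import torch
        # ...
    else:
        # Unknown array libraries are not supported on this code path
        raise TypeError(f"histogram() does not support arrays of type {type(x).__name__}")
```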

The conditional approach is a compromise because it:

- drops support for unknown array libraries
- might not work under JIT compilation

but it does simplify the initial implementation.

### 4. Test across backends

Tests should run against all supported backends,
and also against the `array-api-strict` backend if the metric uses only standard Array API functions.

See the `tests/` folder for examples using the correct decorator for multi-backend testing.

## Resources

- **Array API Standard**: [data-apis.org/array-api](https://data-apis.org/array-api/)
- **Compatibility Library**: [array-api-compat](https://data-apis.org/array-api-compat/)
- **SciPy Development Guide**: [Adding Array API Support](https://docs.scipy.org/doc/scipy/dev/api-dev/array_api.html)
- **Array API in Scikit-learn**: [PyData 2023 talk by Thomas J. Fan](https://www.youtube.com/watch?v=c_s8tr1AizA)
- **Type Checking**: [jaxtyping Documentation](https://docs.kidger.site/jaxtyping/)
