Multi-Backend Support

This library supports multiple array computation backends through the Array API Standard, allowing you to use the same metric functions with NumPy, JAX, PyTorch, and CuPy arrays.
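
To make the idea of dispatching through the Array API Standard concrete, here is a minimal sketch of a backend-agnostic, coherence-style computation. It is not this library's implementation: the names `get_namespace` and `coherence_factor_sketch`, the formula used, and the choice of axis 0 as the channel axis are assumptions made for illustration. Real code would more likely rely on `array_api_compat.array_namespace`, which also covers libraries whose arrays do not expose `__array_namespace__` themselves.

```python
import numpy as np


def get_namespace(x):
    """Return the array namespace for x (hypothetical helper).

    Arrays that implement the Array API Standard expose
    __array_namespace__(). NumPy releases before 2.0 do not,
    so plain ndarrays fall back to the numpy module.
    """
    if hasattr(x, "__array_namespace__"):
        return x.__array_namespace__()
    if isinstance(x, np.ndarray):
        return np
    raise TypeError(f"Unsupported array type: {type(x)!r}")


def coherence_factor_sketch(data, axis=0):
    """Illustrative coherence factor: |sum s|^2 / (N * sum |s|^2).

    Assumes `axis` indexes the receive channels. This is an assumption,
    not the documented behaviour of um.coherence_factor.
    """
    xp = get_namespace(data)
    n_channels = data.shape[axis]
    coherent = xp.abs(xp.sum(data, axis=axis)) ** 2
    incoherent = n_channels * xp.sum(xp.abs(data) ** 2, axis=axis)
    return coherent / incoherent
```

Because every operation goes through `xp`, the result is an array from the same library as the input. The sketch does not guard against an all-zero input, where the denominator is zero.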

Why Multi-Backend Support?

  • Interoperability: Integrate with your preferred array library

  • Performance: Leverage GPU acceleration and JIT compilation when available

  • Future-Proofing: Add new backends as the ecosystem evolves

Supported Backends

Backend    CPU    GPU    JIT Compilation    Automatic Differentiation
NumPy      Yes    No     No                 No
JAX        Yes    Yes    Yes                Yes
PyTorch    Yes    Yes    ⚠️*                Yes
CuPy       No     Yes    ⚠️*                No

*torch.jit and cupy.fuse support depend on metric complexity

Quick Start

Basic Usage

import numpy as np
import jax.numpy as jnp
import torch
import ultrasound_metrics as um

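# Same function, different backends
rng = np.random.default_rng(0)
shape = (64, 100, 100)
data_np = rng.standard_normal(shape) + 1j * rng.standard_normal(shape)  # complex128
data_jax = jnp.array(data_np)  # stored as complex64 unless jax_enable_x64 is set
data_torch = torch.tensor(data_np)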

# All return arrays of the same backend type
coherence_np = um.coherence_factor(data_np)      # numpy array
coherence_jax = um.coherence_factor(data_jax)    # jax array
coherence_torch = um.coherence_factor(data_torch) # torch tensor

GPU Acceleration

import torch
import ultrasound_metrics as um

# Create data on GPU
data_gpu = torch.randn(128, 200, 300, device='cuda', dtype=torch.complex64)

# Computation happens on GPU
coherence_gpu = um.coherence_factor(data_gpu)  # Result stays on GPU
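
The example above assumes a CUDA device is present. A small helper such as the following sketch can pick the device at runtime so the same script also runs on CPU-only machines. The name `pick_torch_device` is hypothetical, and the check for an installed PyTorch is only there so the snippet runs anywhere.

```python
import importlib.util


def pick_torch_device():
    """Return "cuda" when PyTorch can see a GPU, otherwise "cpu"."""
    if importlib.util.find_spec("torch") is None:
        # PyTorch is not installed, so there is nothing to place on a GPU.
        return "cpu"
    import torch

    return "cuda" if torch.cuda.is_available() else "cpu"


device = pick_torch_device()
print(f"Using device: {device}")
```

The returned string can be passed directly as the `device` argument, for example `torch.randn(128, 200, 300, device=device, dtype=torch.complex64)`.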

JIT Compilation with JAX

import jax
import jax.numpy as jnp
import ultrasound_metrics as um

# JIT compile for repeated use
coherence_factor_jit = jax.jit(um.coherence_factor)

# First call compiles, subsequent calls are fast
data = jnp.ones((64, 100, 100), dtype=jnp.complex64)
coherence = coherence_factor_jit(data)  # Compilation + execution
coherence = coherence_factor_jit(data)  # Fast execution only
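
To see the difference between the compiling call and later calls, the calls can be timed. JAX dispatches work asynchronously, so a fair measurement has to wait for the result before stopping the clock. The helper below is a generic sketch: `time_calls` and its `sync` parameter are hypothetical names, not part of this library.

```python
import time


def time_calls(fn, data, repeats=3, sync=None):
    """Return the wall-clock duration of each of `repeats` calls to fn(data).

    `sync` is an optional callable applied to each result before the clock
    stops, for backends that return before the computation has finished.
    """
    timings = []
    for _ in range(repeats):
        start = time.perf_counter()
        result = fn(data)
        if sync is not None:
            sync(result)  # e.g. lambda r: r.block_until_ready() for JAX
        timings.append(time.perf_counter() - start)
    return timings
```

With JAX this could be called as `time_calls(coherence_factor_jit, data, sync=lambda r: r.block_until_ready())`. The first entry then includes compilation time, while later entries reflect execution only.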

Gradient-Based Optimization

Coming soon!

Key Features

Type Preservation

Output arrays always match the input array type:

# numpy in → numpy out
coherence_np = um.coherence_factor(numpy_data)      # numpy.ndarray
# jax in → jax out
coherence_jax = um.coherence_factor(jax_data)       # jax.Array
# torch in → torch out
coherence_torch = um.coherence_factor(torch_data)   # torch.Tensor
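
In your own test suite, type preservation can be checked with a few lines. The sketch below uses a hypothetical stand-in metric rather than a real function from this library, and `check_type_preserved` is likewise an illustrative name.

```python
import numpy as np


def stand_in_metric(data):
    """Hypothetical placeholder for a metric such as um.coherence_factor."""
    return abs(data)  # element-wise magnitude via the array's own __abs__


def check_type_preserved(metric, data):
    """Call metric(data) and raise TypeError if the result type differs."""
    result = metric(data)
    if type(result) is not type(data):
        raise TypeError(
            f"Expected {type(data).__name__}, got {type(result).__name__}"
        )
    return result


magnitudes = check_type_preserved(stand_in_metric, np.ones((2, 3), dtype=np.complex64))
```

The same check can be run with JAX arrays or PyTorch tensors by passing them in place of the NumPy array.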

Performance Tips

Choosing a Backend

  • NumPy: CPU computations, prototyping, simple workflows

  • JAX: Research, JIT compilation, complex mathematical operations

  • PyTorch: Deep learning integration, gradient computation

  • CuPy: GPU acceleration with NumPy-like API

Best Practices

  • Stay in one backend: Avoid switching between array types during computation

  • Use JIT when available: JIT compilation (jax.jit, or torch.jit where the metric supports it) can substantially speed up repeated calls on inputs of the same shape and dtype

  • Keep data on the same device: Ensure all arrays in a computation live on the CPU, or all on the same GPU

Common Issues

  • Mixed Backends: Don’t mix array types in the same operation

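import numpy as np
import torch
import ultrasound_metrics as um

numpy_data = np.array([1, 2, 3])
torch_data = torch.tensor([4, 5, 6])
um.some_metric(numpy_data, torch_data)  # Error: inputs come from different backends (some_metric is a placeholder name)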
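
One way to catch this early in your own code is to verify that all inputs come from the same library before calling a metric. The following is a heuristic sketch, not part of this library: `require_single_backend` is a hypothetical helper that compares the top-level module of each array's type, which may not distinguish every case (for example, traced JAX values inside `jax.jit`).

```python
import numpy as np


def require_single_backend(*arrays):
    """Return the shared backend name, or raise TypeError if backends differ.

    The backend is taken to be the top-level module of each array's type,
    e.g. "numpy" for numpy.ndarray or "torch" for torch.Tensor.
    """
    backends = {type(a).__module__.split(".")[0] for a in arrays}
    if len(backends) > 1:
        raise TypeError(f"Mixed array backends: {sorted(backends)}")
    return backends.pop() if backends else None


a = np.array([1, 2, 3])
b = np.array([4, 5, 6])
backend = require_single_backend(a, b)  # both inputs are NumPy arrays
```

When the check fails, convert explicitly before the call, for instance with `torch.from_numpy(numpy_data)` to move a NumPy array into PyTorch.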