# Multi-Backend Support
This library supports multiple array computation backends through the Array API Standard, allowing you to use the same metric functions with NumPy, JAX, PyTorch, and CuPy arrays.
## Why Multi-Backend Support?
- **Interoperability**: Integrate with your preferred array library
- **Performance**: Leverage GPU acceleration and JIT compilation when available
- **Future-Proofing**: Add new backends as the ecosystem evolves
## Supported Backends
| Backend | CPU | GPU | JIT Compilation | Automatic Differentiation |
|---------|-----|-----|-----------------|---------------------------|
| NumPy   | ✅  | ❌  | ❌              | ❌                        |
| JAX     | ✅  | ✅  | ✅              | ✅                        |
| PyTorch | ✅  | ✅  | ⚠️*             | ✅                        |
| CuPy    | ❌  | ✅  | ⚠️*             | ❌                        |
\*`torch.jit` and `cupy.fuse` support depends on metric complexity.
## Quick Start
### Basic Usage
```python
import numpy as np
import jax.numpy as jnp
import torch

import ultrasound_metrics as um

# Same function, different backends
rng = np.random.default_rng()
data_np = rng.standard_normal((64, 100, 100)) + 1j * rng.standard_normal((64, 100, 100))
data_jax = jnp.array(data_np)
data_torch = torch.tensor(data_np)

# All return arrays of the same backend type
coherence_np = um.coherence_factor(data_np)        # numpy array
coherence_jax = um.coherence_factor(data_jax)      # jax array
coherence_torch = um.coherence_factor(data_torch)  # torch tensor
```
### GPU Acceleration

```python
import torch

import ultrasound_metrics as um

# Create data on GPU
data_gpu = torch.randn(128, 200, 300, device='cuda', dtype=torch.complex64)

# Computation happens on GPU
coherence_gpu = um.coherence_factor(data_gpu)  # Result stays on GPU
```
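The result stays on the GPU until you move it. If you need it on the host, for example for plotting or saving, copy it back explicitly; the snippet below is a minimal sketch using plain PyTorch calls, nothing library-specific:

```python
# Copy the GPU result back to the host as a NumPy array
coherence_cpu = coherence_gpu.cpu().numpy()
```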
### JIT Compilation with JAX

```python
import jax
import jax.numpy as jnp

import ultrasound_metrics as um

# JIT compile for repeated use
coherence_factor_jit = jax.jit(um.coherence_factor)

# First call compiles, subsequent calls are fast
data = jnp.ones((64, 100, 100), dtype=jnp.complex64)
coherence = coherence_factor_jit(data)  # Compilation + execution
coherence = coherence_factor_jit(data)  # Fast execution only
```
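JAX dispatches work asynchronously and recompiles whenever input shapes or dtypes change, so reuse shapes where possible and call `block_until_ready()` when timing. A small timing sketch, assuming `coherence_factor` returns a single array:

```python
import time

start = time.perf_counter()
coherence_factor_jit(data).block_until_ready()  # first call: compile + run
print(f"first call:  {time.perf_counter() - start:.3f} s")

start = time.perf_counter()
coherence_factor_jit(data).block_until_ready()  # cached: run only
print(f"second call: {time.perf_counter() - start:.3f} s")
```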
### Gradient-Based Optimization
Coming soon!
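Until dedicated examples land here, the sketch below shows what differentiating through a metric with JAX could look like. The scalar `objective` (mean coherence factor over the image) is purely illustrative, and it assumes `coherence_factor` is traceable by JAX and returns a real-valued array, as the table above indicates:

```python
import jax
import jax.numpy as jnp

import ultrasound_metrics as um

def objective(channel_data):
    # Illustrative scalar objective: mean coherence factor over the image
    return jnp.mean(um.coherence_factor(channel_data))

data = jnp.ones((64, 100, 100), dtype=jnp.complex64)
gradient = jax.grad(objective)(data)  # gradient w.r.t. the complex channel data
```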
## Key Features
### Type Preservation
Output arrays always match the input array type:
```python
# numpy in → numpy out
coherence_np = um.coherence_factor(numpy_data)      # numpy.ndarray

# jax in → jax out
coherence_jax = um.coherence_factor(jax_data)       # jax.numpy.ndarray

# torch in → torch out
coherence_torch = um.coherence_factor(torch_data)   # torch.Tensor
```
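If you want to verify this behaviour in your own pipeline, a quick check (reusing the arrays and imports from the Basic Usage example above):

```python
assert isinstance(coherence_np, np.ndarray)
assert isinstance(coherence_jax, jnp.ndarray)    # alias for jax.Array
assert isinstance(coherence_torch, torch.Tensor)
```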
## Performance Tips
### Choosing a Backend
- **NumPy**: CPU computations, prototyping, simple workflows
- **JAX**: Research, JIT compilation, complex mathematical operations
- **PyTorch**: Deep learning integration, gradient computation
- **CuPy**: GPU acceleration with a NumPy-like API (see the sketch below)
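CuPy is the one backend without a quick-start example above, so here is a minimal sketch. It assumes a CUDA-capable GPU and that `coherence_factor` accepts CuPy arrays, as the table indicates:

```python
import cupy as cp

import ultrasound_metrics as um

# Complex channel data allocated directly on the GPU
data_cp = (cp.random.standard_normal((64, 100, 100))
           + 1j * cp.random.standard_normal((64, 100, 100))).astype(cp.complex64)

coherence_cp = um.coherence_factor(data_cp)  # CuPy array, stays on the GPU
coherence_np = cp.asnumpy(coherence_cp)      # copy back to a NumPy array when needed
```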
### Best Practices
- **Stay in one backend**: Avoid switching between array types during a computation
- **Use JIT when available**: JAX and PyTorch can dramatically speed up repeated computations
- **Keep data on the same device**: Keep all arrays consistently on the CPU or on the same GPU (see the sketch after this list)
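A minimal sketch of the device rule in PyTorch; `apodization` is a hypothetical second input, used only to illustrate that all arrays should live on the same device:

```python
import torch

device = "cuda" if torch.cuda.is_available() else "cpu"

# Hypothetical inputs for illustration; both live on the same device
channel_data = torch.randn(64, 100, 100, dtype=torch.complex64, device=device)
apodization = torch.ones(64, 1, 1, device=device)  # or move an existing array with .to(device)

weighted = channel_data * apodization  # no implicit device transfers
```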
## Common Issues
**Mixed Backends**: Don't mix array types in the same operation:
```python
numpy_data = np.array([1, 2, 3])
torch_data = torch.tensor([4, 5, 6])

um.some_metric(numpy_data, torch_data)  # Error!
```
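The fix is to convert everything to a single backend before calling the metric; which backend you standardize on is up to you (`some_metric` is the same hypothetical function as above):

```python
# Everything as torch tensors ...
um.some_metric(torch.from_numpy(numpy_data), torch_data)

# ... or everything as numpy arrays
um.some_metric(numpy_data, torch_data.numpy())
```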