# Multi-Backend Support

This library supports multiple array computation backends through the [Array API Standard](https://data-apis.org/array-api/), allowing you to use the same metric functions with NumPy, JAX, PyTorch, and CuPy arrays.

## Why Multi-Backend Support?

- **Interoperability**: Integrate with your preferred array library
- **Performance**: Leverage GPU acceleration and JIT compilation when available
- **Future-Proofing**: Add new backends as the ecosystem evolves

## Supported Backends

| Backend | CPU | GPU | JIT Compilation | Automatic Differentiation |
|---------|-----|-----|-----------------|---------------------------|
| NumPy   | ✅   | ❌   | ❌               | ❌                         |
| JAX     | ✅   | ✅   | ✅               | ✅                         |
| PyTorch | ✅   | ✅   | ⚠️*              | ✅                         |
| CuPy    | ❌   | ✅   | ⚠️*              | ❌                         |

\* Support for `torch.jit` and `cupy.fuse` depends on the complexity of the metric.

## Quick Start

### Basic Usage
```python
import numpy as np
import jax.numpy as jnp
import torch
import ultrasound_metrics as um

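# Same function, different backends
rng = np.random.default_rng(0)
shape = (64, 100, 100)
data_np = rng.standard_normal(shape) + 1j * rng.standard_normal(shape)  # complex128
data_jax = jnp.array(data_np)  # complex64 unless jax_enable_x64 is set
data_torch = torch.tensor(data_np)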

# All return arrays of the same backend type
coherence_np = um.coherence_factor(data_np)      # numpy array
coherence_jax = um.coherence_factor(data_jax)    # jax array
coherence_torch = um.coherence_factor(data_torch) # torch tensor
```

### GPU Acceleration
```python
import torch
import ultrasound_metrics as um

# Create data on GPU
data_gpu = torch.randn(128, 200, 300, device='cuda', dtype=torch.complex64)

# Computation happens on GPU
coherence_gpu = um.coherence_factor(data_gpu)  # Result stays on GPU
```

### JIT Compilation with JAX
```python
import jax
import jax.numpy as jnp
import ultrasound_metrics as um

# JIT compile for repeated use
coherence_factor_jit = jax.jit(um.coherence_factor)

# First call compiles, subsequent calls are fast
data = jnp.ones((64, 100, 100), dtype=jnp.complex64)
coherence = coherence_factor_jit(data)  # Compilation + execution
coherence = coherence_factor_jit(data)  # Fast execution only
```

### Gradient-Based Optimization
Coming soon!

## Key Features

### Type Preservation
Output arrays always match the input array type:
```python
# numpy in → numpy out
coherence_np = um.coherence_factor(numpy_data)      # numpy.ndarray
# jax in → jax out
coherence_jax = um.coherence_factor(jax_data)       # jax.Array
# torch in → torch out
coherence_torch = um.coherence_factor(torch_data)   # torch.Tensor
```
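
One way this kind of type preservation can be achieved is to look up the input array's own namespace and call only standard functions on it. The sketch below is an assumption about the mechanism, not the library's actual implementation: `get_namespace` and `coherence_factor_sketch` are hypothetical names, the channel axis is assumed to be axis 0, and the formula is the common textbook definition of the coherence factor.

```python
import numpy as np

def get_namespace(x):
    """Return the array namespace for x (hypothetical helper)."""
    if isinstance(x, np.ndarray):
        return np  # works on both NumPy 1.x and 2.x
    if hasattr(x, "__array_namespace__"):
        return x.__array_namespace__()  # Array API standard hook
    raise TypeError(f"Unsupported array type: {type(x).__name__}")

def coherence_factor_sketch(x):
    """Textbook coherence factor, assuming channels lie along axis 0."""
    xp = get_namespace(x)
    n_channels = x.shape[0]
    coherent = xp.abs(xp.sum(x, axis=0)) ** 2
    incoherent = n_channels * xp.sum(xp.abs(x) ** 2, axis=0)
    return coherent / incoherent

data = np.ones((4, 2, 3), dtype=np.complex64)
cf = coherence_factor_sketch(data)
print(type(cf).__name__, cf.shape)  # ndarray (2, 3)
```

Because every operation goes through `xp`, the result is created by the same library that produced the input. Backends whose arrays do not expose `__array_namespace__` would need a compatibility layer such as the `array-api-compat` package.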


## Performance Tips

### Choosing a Backend
- **NumPy**: CPU computations, prototyping, simple workflows
- **JAX**: Research, JIT compilation, complex mathematical operations
- **PyTorch**: Deep learning integration, gradient computation
- **CuPy**: GPU acceleration with NumPy-like API

### Best Practices
- **Stay in one backend**: Avoid switching between array types during computation
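- **Use JIT when available**: `jax.jit` (and, for some metrics, `torch.jit`) can speed up repeated computations; the first call pays a one-time compilation cost
- **Keep data on the same device**: Ensure all arrays in a computation live on the same device, either all on CPU or all on the same GPU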

## Common Issues

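- **Mixed Backends**: Don't mix array types in the same operation. Convert all inputs to a single backend first.
```python
import numpy as np
import torch
import ultrasound_metrics as um

numpy_data = np.array([1, 2, 3])
torch_data = torch.tensor([4, 5, 6])
um.some_metric(numpy_data, torch_data)  # Error! (`some_metric` stands in for any two-array metric)
```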
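
A cheap guard can turn mixed inputs into a clear error before any computation starts. The following is a minimal sketch: `check_same_backend` is a hypothetical helper, not part of `ultrasound_metrics`, and it identifies a backend by the top-level module of each array's type. That is only a heuristic (for example, concrete JAX arrays report `jaxlib`), so a production check would more likely compare array namespaces.

```python
import numpy as np

def check_same_backend(*arrays):
    """Raise TypeError if the inputs come from different array libraries."""
    backends = {type(a).__module__.split(".")[0] for a in arrays}
    if len(backends) > 1:
        raise TypeError(f"Mixed array backends: {sorted(backends)}")
    return backends.pop()

class FakeTensor:
    """Stand-in for another library's array type, used only for this demo."""
    __module__ = "torch"

backend = check_same_backend(np.zeros(3), np.ones(3))
print(backend)  # numpy

try:
    check_same_backend(np.zeros(3), FakeTensor())
    mixed_detected = False
except TypeError:
    mixed_detected = True
print(mixed_detected)  # True
```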
