# Source code for ultrasound_metrics.motion.corrcoef

"""
Detect motion artifacts in ultrasound ensembles.

Uses the correlation between I/Q data across time, or across images.
"""

from typing import Optional

import equinox as eqx
import jax
import jax.numpy as jnp
from beartype import beartype as typechecker
from einops import rearrange
from jaxtyping import Array, Int, Num, PRNGKeyArray, Real, jaxtyped


@jaxtyped(typechecker=typechecker)
def pairwise_corrcoef_agg(
    vectors: Num[Array, " n_vectors *each_vector_dims"],
) -> float:
    """
    Summarize how strongly a set of vectors correlate with one another.

    Every pairwise correlation coefficient is computed via ``pairwise_corrcoef``
    and their mean is returned as a single Python float.

    Note: for complex coefficients only the real part of the mean is kept.
        This is an opinionated choice that should be reasonable for data
        that is expected to be in-phase.

    Parameters
    ----------
    vectors
        Array of shape (n_vectors, ...) containing the vectors to correlate.
        If each vector is multidimensional, it will be flattened to a 1D array.

    Returns
    -------
    float
        Real part of the mean over all pairwise correlation coefficients.
    """
    # Only the coefficients are needed; the accompanying pair indices are dropped.
    all_coefficients = pairwise_corrcoef(vectors=vectors)[0]
    mean_coefficient = jnp.mean(all_coefficients)
    return float(jnp.real(mean_coefficient))


@jaxtyped(typechecker=typechecker)
def pairwise_corrcoef(
    vectors: Num[Array, " n_vectors *each_vector_dims"],
    n_pairs: Optional[int] = None,
    key: Optional[PRNGKeyArray] = None,
) -> tuple[Num[Array, " n_pairs"], Int[Array, "n_pairs 2"]]:
    """
    Compute pairwise-correlation-coefficients between vectors efficiently using JAX.

    Parameters
    ----------
    vectors
        Array of shape (n_vectors, ...) containing the vectors to correlate.
        If each vector is multidimensional, it will be flattened to a 1D array.
    n_pairs
        Number of random pairs to compute. If None, or if at least as large as
        the number of distinct pairs, all pairs are computed.
    key
        JAX PRNG key for random sampling. If None, uses PRNGKey(0).

    Returns
    -------
    correlations
        Array of correlation coefficients.
    pairs
        Array of shape (n_pairs, 2) containing the indices of correlated pairs.

    Raises
    ------
    ValueError
        If ``n_pairs`` is negative.
    """
    # No subsampling requested: compute every pair.
    if n_pairs is None:
        return pairwise_corrcoef_all_jit(vectors)

    # A negative count would otherwise slip through the slice below and
    # silently return `max_pairs + n_pairs` pairs.
    if n_pairs < 0:
        raise ValueError(f"n_pairs must be non-negative, got {n_pairs}")

    # Enumerate every distinct (i, j) pair with i > j.
    tril_indices_tup = jnp.tril_indices(vectors.shape[0], k=-1)
    pairs = jnp.column_stack(tril_indices_tup)
    max_pairs = pairs.shape[0]

    # Requesting at least as many pairs as exist is the same as requesting all.
    if n_pairs >= max_pairs:
        return pairwise_corrcoef_all_jit(vectors)

    # Randomly select index-pairs outside of JIT
    # because JIT cannot handle dynamically-sized arrays
    if key is None:
        key = jax.random.PRNGKey(0)
    indices = jax.random.permutation(key, max_pairs)[:n_pairs]
    pairs = pairs[indices]

    return pairwise_corrcoef_jit(vectors, pairs)


@jaxtyped(typechecker=typechecker)
@eqx.filter_jit
def pairwise_corrcoef_all_jit(
    vectors: Num[Array, " n_vectors *each_vector_dims"],
) -> tuple[Num[Array, " n_pairs"], Int[Array, "n_pairs 2"]]:
    """
    Correlate every distinct pair of vectors (JIT-compiled).

    The full correlation matrix is computed once with ``corrcoef`` and its
    strictly-lower-triangular entries are returned, so each unordered pair
    appears once and self-correlations are excluded.

    Parameters
    ----------
    vectors
        Array of shape (n_vectors, ...) containing the vectors to correlate.
        If each vector is multidimensional, it will be flattened to a 1D array.

    Returns
    -------
    correlations
        Array of correlation coefficients, one per pair.
    pairs
        Array of shape (n_pairs, 2) containing the indices of correlated pairs.
    """
    # Collapse all trailing dimensions so each vector becomes one row.
    flattened = rearrange(vectors, "n_vectors ... -> n_vectors (...)")

    # k=-1 selects entries strictly below the diagonal (row index > column index).
    rows, cols = jnp.tril_indices(vectors.shape[0], k=-1)
    pair_indices = jnp.column_stack((rows, cols))

    coefficients = corrcoef(flattened)[rows, cols]
    return coefficients, pair_indices


@jaxtyped(typechecker=typechecker)
@eqx.filter_jit
def corrcoef(vectors: Num[Array, "n_observations n_features"]) -> Num[Array, " n_observations n_observations"]:
    """
    Compute the correlation-coefficient matrix between the rows of ``vectors``.

    A fast, simplified implementation of jnp.corrcoef: each row is standardized
    (mean removed, divided by its sample standard deviation) and the matrix
    of inner products between rows is formed, conjugating the second operand.

    Parameters
    ----------
    vectors
        Input data with shape (n_observations, n_features).
        e.g. each observation is a B-Mode within an ensemble
        and each feature is a voxel in the B-Mode.

    Returns
    -------
    ndarray
        Correlation-coefficient matrix, passed through ``_clip_corrcoef``.
    """
    n_features = vectors.shape[1]
    row_means = jnp.mean(vectors, axis=1, keepdims=True)
    # ddof=1 here pairs with the (n_features - 1) divisor applied below.
    row_stds = jnp.std(vectors, axis=1, ddof=1, keepdims=True)
    standardized = (vectors - row_means) / row_stds

    gram = standardized @ standardized.conj().T
    return _clip_corrcoef(corrcoefs=gram / (n_features - 1))


@jaxtyped(typechecker=typechecker)
@eqx.filter_jit
def pairwise_corrcoef_jit(
    vectors: Num[Array, " n_vectors *each_vector_dims"],
    pairs: Int[Array, " n_pairs 2"],
) -> tuple[Num[Array, " n_pairs"], Int[Array, "n_pairs 2"]]:
    """
    Compute subset of pairwise-correlation-coefficients between vectors with JIT-compiled JAX.

    Parameters
    ----------
    vectors
        Array of shape (n_vectors, ...) containing the vectors to correlate.
        If each vector is multidimensional, it will be flattened to a 1D array.
    pairs
        Array of shape (n_pairs, 2) containing the indices of pairs
        to compute correlations for.

    Returns
    -------
    correlations
        Array of correlation coefficients, one per row of ``pairs``. For complex
        input the second vector of each pair is conjugated, matching ``corrcoef``.
    pairs
        Array of shape (n_pairs, 2) containing the indices of correlated pairs
        (returned unchanged).
    """
    # Flatten all dimensions after the first into a single dimension
    vectors = rearrange(vectors, "n_vectors ... -> n_vectors (...)")
    n_features = vectors.shape[1]

    # Standardize each vector once so every pair reduces to a dot product.
    # The standard deviation uses ddof=1, so the dot product below must be
    # divided by (n_features - 1) for the result to be a correlation coefficient.
    vectors_normalized = (vectors - jnp.mean(vectors, axis=1)[:, None]) / jnp.std(vectors, axis=1, ddof=1)[:, None]

    first = vectors_normalized[pairs[:, 0]]
    second = vectors_normalized[pairs[:, 1]]

    # Dividing by n_features instead would scale every coefficient by
    # (n_features - 1) / n_features and disagree with `corrcoef`.
    correlations = jnp.sum(first * second.conj(), axis=1) / (n_features - 1)

    correlations = _clip_corrcoef(corrcoefs=correlations)
    return correlations, pairs


def _clip_corrcoef(corrcoefs: Num[Array, "*sizes"], tol: float = 0.2) -> Num[Array, "*sizes"]:
    """
    Clip correlation coefficients to [-1, 1], dispatching on real vs. complex input.

    Complex arrays are handled by ``_clip_corrcoef_complex`` and everything
    else by ``_clip_corrcoef_real``; ``tol`` is forwarded unchanged.

    NOTE(review): the default ``tol`` here (0.2) is looser than the 0.01 default
    of the two helpers it forwards to; confirm whether that is intentional.

    Parameters
    ----------
    corrcoefs
        Array of correlation coefficients.
    tol
        Tolerance for values slightly outside of [-1, 1].

    Returns
    -------
    ndarray
        Clipped correlation coefficients.
    """
    clip_fn = _clip_corrcoef_complex if jnp.iscomplexobj(corrcoefs) else _clip_corrcoef_real
    return clip_fn(corrcoefs=corrcoefs, tol=tol)


@jaxtyped(typechecker=typechecker)
def _clip_corrcoef_real(corrcoefs: Real[Array, "*sizes"], tol: float = 0.01) -> Real[Array, "*sizes"]:
    """
    Validate and clip real correlation coefficients to [-1, 1].

    Values within ``tol`` of the valid range are treated as numerical error and
    clipped; values further outside are reported via ``eqx.error_if``.

    Parameters
    ----------
    corrcoefs
        Array of real correlation coefficients.
    tol
        Tolerance for values slightly outside of [-1, 1].

    Returns
    -------
    ndarray
        Correlation coefficients clipped to [-1, 1].
    """
    lower_limit = -(1 + tol)
    upper_limit = 1 + tol

    # Check the lower bound first, then the upper bound on the checked array.
    checked = eqx.error_if(corrcoefs, corrcoefs < lower_limit, "Correlation coefficient below -1")
    checked = eqx.error_if(checked, checked > upper_limit, "Correlation coefficient above 1")

    return jnp.clip(checked, min=-1, max=1)


@jaxtyped(typechecker=typechecker)
def _clip_corrcoef_complex(corrcoefs: Num[Array, "*sizes"], tol: float = 0.01) -> Num[Array, "*sizes"]:
    """
    Validate and clip complex correlation coefficients component-wise.

    The real and imaginary parts are each passed through
    ``_clip_corrcoef_real`` independently and then recombined.

    Parameters
    ----------
    corrcoefs
        Array of complex correlation coefficients.
    tol
        Tolerance for values slightly outside of [-1, 1].

    Returns
    -------
    ndarray
        Correlation coefficients with real and imaginary parts clipped to [-1, 1].
    """
    # Real part is processed first, then the imaginary part.
    clipped_real, clipped_imag = (
        _clip_corrcoef_real(corrcoefs=part, tol=tol) for part in (jnp.real(corrcoefs), jnp.imag(corrcoefs))
    )
    return clipped_real + 1j * clipped_imag


if __name__ == "__main__":
    # Small timing demo: run the all-pairs correlation twice so the first
    # (compiling) call can be compared with a subsequent (cached) call.
    import time

    # Example usage:
    print("Example usage of pairwise_corrcoef")
    key = jax.random.PRNGKey(0)
    # Example 200-ensemble, 8960-channel, 60-timepoint data
    vectors = jax.random.normal(key, (200, 8_960, 60), dtype=jnp.complex64)

    # Example: Use all pairs
    n_pairs = None
    print(f"Vectors size and info: {vectors.shape} {vectors.dtype} {vectors.nbytes / 1e6} MB")

    # First call will compile.
    # JAX dispatches work asynchronously, so block until the results are ready;
    # otherwise the timer would stop before the computation has finished.
    start = time.perf_counter()
    correlations, pairs = jax.block_until_ready(pairwise_corrcoef(vectors, n_pairs=n_pairs))
    end = time.perf_counter()
    print(f"First call took: {end - start} seconds")
    print(f"Correlations shape: {correlations.shape}")

    # Subsequent calls will be much faster
    start = time.perf_counter()
    correlations, pairs = jax.block_until_ready(pairwise_corrcoef(vectors, n_pairs=n_pairs))
    end = time.perf_counter()
    print(f"Second call took: {end - start} seconds")
    print(f"Correlations shape: {correlations.shape}")