# Source code for ultrasound_metrics.channel.noise

"""
Estimate the signal-to-noise ratio (SNR) of an ultrasound transducer system.

This module provides methods to calculate SNR using repeated pulse-echo
measurements of a phantom or tissue.

Notes
-----
Motivation:
    We would like to understand the power/SNR tradeoffs, e.g.:
    - How much does AFE power improve SNR?
    - How much does combining channels via summation improve SNR?
    - How much does increasing chip-repeats improve SNR?

Implemented Methods:
    Repeated Measurements Variance (best for a static phantom):
        - Acquire multiple identical frames (same configuration)
        - Calculate the mean signal across frames (this represents the true signal)
        - Calculate the standard deviation across frames (this represents noise)
        - SNR = power of the mean signal / variance across frames (noise power), reported in dB

    Signal Differences in Adjacent Frames:
        - Since tissue moves slowly and we acquire ultrafast ultrasound:
        - Calculate frame-to-frame differences
        - Assuming the tissue moves much slower than the frame rate,
          frame-to-frame differences primarily represent transmit and receive-channel noise
        - Compare mean signal amplitude to this difference (noise) to estimate SNR

Alternative Methods (Not Implemented):
    Transmit-Off Noise Measurements:
        - Collect data with the transmitter turned off (tx-off)
        - This is all noise, so it's easy to calculate SNR.
        - Downside: noise may be nonlinear or signal-dependent, so this
          would not capture the practical SNR.

    Spectral Analysis:
        - Assumes that the signal is narrowband and that the noise is flat.
        - This is not true for our data, so we don't implement this.

Differences from B-Mode SNR:
    In ultrasound or image analysis, we often calculate SNR in the B-mode image.
    For example, checking the strength of a reflector in the image.

    You might take a spatial region analysis approach, e.g.:
        - Select regions devoid of signal (e.g. water) to calculate noise power
        - If a reflector is present, you can also calculate signal power

    This is a great way to calculate image-related SNR. However, it does not
    directly characterize the electronic noise, e.g. thermal or quantization noise.
    Scatterer SNR probably belongs in a different submodule of this same repository.
"""

from enum import Enum

import equinox as eqx
import jax
import jax.numpy as jnp
from beartype import beartype as typechecker
from jaxtyping import Array, Float, Num, jaxtyped


class NoiseEstimationMethod(str, Enum):
    """
    Available strategies for estimating channel noise.

    Members are also ``str`` instances, so each one compares equal to its
    plain string value (e.g. ``"repeated_measurements"``).

    Attributes
    ----------
    REPEATED_MEASUREMENTS
        Use the repeated-measurements variance method.
    ADJACENT_FRAMES
        Use the adjacent-frames difference method.
    """

    REPEATED_MEASUREMENTS = "repeated_measurements"
    ADJACENT_FRAMES = "adjacent_frames"
@jaxtyped(typechecker=typechecker)
@eqx.filter_jit
def noise_repeated_measurements(
    data: Num[Array, " repetition *samples"],
) -> tuple[Float[Array, ""], Float[Array, ""], Float[Array, ""]]:
    """
    Calculate noise from repeated measurements of a static phantom.

    Assumes the average over repetitions is representative of the true signal,
    so the spread across repetitions is the noise.

    Note: aggregates (sums) over all sample dimensions, returning scalars.

    Parameters
    ----------
    data
        Transducer data (e.g. complex I/Q), e.g. with shape
        (repetition, space, time), where the leading `repetition` axis indexes
        repeated acquisitions.

    Returns
    -------
    tuple
        Tuple (snr_db, signal_power, noise_power), where:
        - snr_db: signal-to-noise power-ratio in decibels
          (10*log10(signal_power/noise_power))
        - signal_power: signal power
        - noise_power: noise power

    Notes
    -----
    The noise power uses ``jnp.var`` with its default ``ddof=0`` (population
    variance). With a single repetition the variance is zero, so the ratio is
    not finite.
    """
    # Signal power: magnitude squared of the across-repetition mean, summed over samples
    mean_over_repetitions = data.mean(axis=0)
    signal_power_per_repetition = (jnp.abs(mean_over_repetitions) ** 2).sum()

    # Noise power: variance across repetitions, summed over samples.
    # Variance of complex random variable = Var(I) + Var(Q)
    noise_power_per_repetition = jnp.var(data, axis=0).sum()

    # Power ratio, so decibels are 10*log10 (not 20*log10)
    snr = signal_power_per_repetition / noise_power_per_repetition
    snr_db = 10 * jnp.log10(snr)

    return snr_db, signal_power_per_repetition, noise_power_per_repetition


@jaxtyped(typechecker=typechecker)
@eqx.filter_jit
def noise_adjacent_frames(
    data: Num[Array, "repetition ..."],
) -> tuple[Float[Array, ""], Float[Array, ""], Float[Array, ""]]:
    """
    Calculate SNR using differences between adjacent frames.

    Assumes tissue moves slowly, so frame-to-frame differences primarily
    represent noise.

    Parameters
    ----------
    data
        Complex I/Q data with shape (repetition, space, time).

    Returns
    -------
    tuple
        Tuple (snr_db, signal_power, noise_power), where:
        - snr_db: signal-to-noise power-ratio in decibels
          (10*log10(signal_power/noise_power))
        - signal_power: signal power
        - noise_power: noise power

    Notes
    -----
    Both powers are summed over every adjacent-frame pair and every sample.
    With fewer than two frames there are no pairs, both sums are zero, and
    the ratio is NaN.
    """
    # Signal estimate: the mean of each adjacent pair,
    # i.e. we assume that we can average out the noise across 2 adjacent frames
    signal = (data[:-1] + data[1:]) / 2

    # Assume noise is difference from the rolling mean
    # (half the frame-to-frame difference)
    noise = jnp.diff(data, axis=0) / 2

    # Powers are magnitude squared, summed over all pairs and samples
    signal_power = (jnp.abs(signal) ** 2).sum()
    noise_power = (jnp.abs(noise) ** 2).sum()

    # Power ratio converted to decibels
    snr = signal_power / noise_power
    snr_db = 10 * jnp.log10(snr)

    return snr_db, signal_power, noise_power


def channel_noise(
    data: Num[Array, "repetition ..."], method: NoiseEstimationMethod
) -> tuple[Float[Array, ""], Float[Array, ""], Float[Array, ""]]:
    """
    Noise estimation method, helper function to select the appropriate method.

    Parameters
    ----------
    data
        Complex I/Q data with shape (repetition, space, time).
    method
        Noise estimation method.

    Returns
    -------
    tuple
        Tuple (snr_db, signal_power, noise_power).

    Raises
    ------
    ValueError
        If `method` is not one of the known estimation methods.
    """
    if method == NoiseEstimationMethod.REPEATED_MEASUREMENTS:
        return noise_repeated_measurements(data)
    if method == NoiseEstimationMethod.ADJACENT_FRAMES:
        return noise_adjacent_frames(data)
    raise ValueError(f"Invalid noise estimation method: {method}")


def _run_example() -> None:
    """Run both estimators on synthetic data and print SNR and call timings."""
    import time

    # Example usage
    print("Example usage of noise_repeated_measurements")

    # Use an independent subkey for every random draw: reusing one key would
    # make the I and Q components identical (perfectly correlated).
    key = jax.random.PRNGKey(0)
    signal_i_key, signal_q_key, noise_i_key, noise_q_key = jax.random.split(key, 4)

    # Create example data: 10 loops, 64 channels, 100 time points
    # Signal component (same across loops but with different amplitude per channel/time)
    signal = jax.random.normal(signal_i_key, (1, 64, 100)) + 1j * jax.random.normal(signal_q_key, (1, 64, 100))

    # Add noise component (different for each loop)
    noise_scale = 0.1
    noise = noise_scale * (
        jax.random.normal(noise_i_key, (10, 64, 100)) + 1j * jax.random.normal(noise_q_key, (10, 64, 100))
    )

    # Combine signal and noise (the signal broadcasts across the loop axis)
    data = signal + noise

    # First call will compile
    for method in NoiseEstimationMethod:
        start = time.perf_counter()
        # JAX dispatches asynchronously; block so the timing covers the computation
        snr_db, signal_power, noise_power = jax.block_until_ready(channel_noise(data, method))
        elapsed = time.perf_counter() - start
        print(f"First call with {method=!s} took: {elapsed} seconds")
        print(f"{method=!s}: {jnp.mean(snr_db):.2f} dB")

    # Subsequent calls will be faster
    for method in NoiseEstimationMethod:
        start = time.perf_counter()
        snr_db, signal_power, noise_power = jax.block_until_ready(channel_noise(data, method))
        elapsed = time.perf_counter() - start
        print(f"Second call with {method=!s} took: {elapsed} seconds")


if __name__ == "__main__":
    _run_example()