"""
Signal-to-Noise Ratio (SNR) Metric for Raw RF Data
====================================================

This example demonstrates how to:

1. Load raw RF data from the PICMUS dataset
2. Visualize the RF signal to identify regions of interest
3. Automatically detect signal and noise regions
4. (Optional) Manually adjust parameters based on what you see
5. Compute and print the RF SNR

This workflow uses real raw RF data from the PICMUS challenge dataset [1]_.
The SNR calculation follows standard ultrasound analysis practices, with automated
region detection inspired by analysis tools like ultraspy [2]_.

References
----------
.. [1] H. Liebgott, A. Rodriguez-Molares, F. Cervenansky, J. A. Jensen and O. Bernard,
       "Plane-Wave Imaging Challenge in Medical Ultrasound," 2016 IEEE International
       Ultrasonics Symposium (IUS), Tours, France, 2016, pp. 1-4,
       doi: 10.1109/ULTSYM.2016.7728908.

.. [2] P. Ecarlat, E. Carcreff, F. Varray, H. Liebgott and B. Nicolas, "Get Ready to Spy on Ultrasound: Meet ultraspy,"
       2023 IEEE International Ultrasonics Symposium (IUS), Montreal, QC, Canada, 2023, pp. 1-4,
       doi: 10.1109/IUS51837.2023.10307778.

Example
-------

"""
# %%
# 1. Import required modules and functions
# ------------------------------------------------------------------------

import matplotlib.pyplot as plt

from ultrasound_metrics.data.uff import inspect_dataset, load_dataset
from ultrasound_metrics.metrics.rf_snr import compute_rf_snr, find_signal_and_noise

# %%
# 2. Load the ultrasound dataset and extract raw RF data
# ------------------------------------------------------------------------

# Dataset identifier shared by the metadata lookup and the data load below.
dataset_name = "picmus_resolution_experiment"

# Show the dataset's descriptive metadata before reading any samples.
dataset_info = inspect_dataset(dataset_name)
print(
    f"Dataset: {dataset_info['name']}",
    f"Description: {dataset_info['description']}",
    sep="\n",
)

# Read the raw RF samples stored under the "/channel_data" key.
print("\n=== Loading Raw RF Data ===")
channel_data = load_dataset(dataset_name, key="/channel_data")
print(
    f"Raw channel data shape: {channel_data.shape}",
    f"Raw channel data type: {channel_data.dtype}",
    sep="\n",
)

# The unpacking below requires a 3-D array; its axes are taken to be
# (samples, channels, frames).
frame_idx = 0  # First frame
samples, channels, frames = channel_data.shape
channel_idx = channels // 2  # Middle channel

# Keep a single 1-D trace (one channel of one frame) for the RF SNR analysis.
rf_data = channel_data[:, channel_idx, frame_idx]
print(
    f"Extracted RF data shape: {rf_data.shape}",
    f"Channel index: {channel_idx} (out of {channels})",
    f"Frame index: {frame_idx} (out of {frames})",
    f"Number of samples: {samples}",
    sep="\n",
)

# %%
# 3. Visualize the RF data to help identify regions
# ------------------------------------------------------------------------

# Two stacked panels: the complete trace on top, a zoom on its start below.
fig, (ax_full, ax_zoom) = plt.subplots(2, 1, figsize=(15, 8))

# Standard deviation of the whole trace, reused for both guide lines and the legend.
rf_std = rf_data.std()

# Top panel: the full RF signal.
ax_full.plot(rf_data, "b-", linewidth=0.5, alpha=0.8)
ax_full.set(
    title=f"Full RF Signal - Channel {channel_idx}",
    xlabel="Sample Index",
    ylabel="Amplitude",
)
ax_full.grid(True, alpha=0.3)

# Guide lines at zero and at +/- one standard deviation help pick out regions.
ax_full.axhline(y=0, color="k", linestyle="--", alpha=0.5)
ax_full.axhline(y=rf_std, color="g", linestyle=":", alpha=0.7, label=f"±1std = ±{rf_std:.3f}")
ax_full.axhline(y=-rf_std, color="g", linestyle=":", alpha=0.7)
ax_full.legend()

# Bottom panel: at most the first 500 samples, to inspect startup artifacts.
zoom_samples = min(500, len(rf_data))
ax_zoom.plot(rf_data[:zoom_samples], "r-", linewidth=1)
ax_zoom.set(
    title=f"Zoomed View - First {zoom_samples} Samples",
    xlabel="Sample Index",
    ylabel="Amplitude",
)
ax_zoom.grid(True, alpha=0.3)
ax_zoom.axhline(y=0, color="k", linestyle="--", alpha=0.5)

fig.tight_layout()
plt.show()

# %%
# 4. Manual parameter adjustment based on visualization
# ------------------------------------------------------------------------

print("\n=== Manual Parameter Adjustment ===")
print("Based on the visualization above, you can adjust these parameters:")

# Tunable region-detection parameters. Edit these values after inspecting the
# RF plot, then re-run the cells that follow.
ignore_until = 10  # Skip initial startup artifacts
signal_width = 50  # Width of signal region around maximum
noise_offset = 100  # Gap between signal and noise regions

print("Current parameters:")
print(f"- ignore_until: {ignore_until} (skip first {ignore_until} samples)")
print(f"- signal_width: {signal_width} (signal region width)")
print(f"- noise_offset: {noise_offset} (gap between signal and noise)")

# The parameters are defined earlier in this cell, so the printed instructions
# point the reader "above" rather than "below".
print("\nTo adjust these parameters:")
print("1. Look at the RF signal plot above")
print("2. Identify where startup artifacts end (usually first 50-200 samples)")
print("3. Modify the 'ignore_until' parameter above")
print("4. Re-run the analysis")

# %%
# 5. Automatic region detection with user-defined parameters
# ------------------------------------------------------------------------

print("\n=== Automatic Region Detection ===")
print(f"Using parameters: ignore_until={ignore_until}, signal_width={signal_width}, noise_offset={noise_offset}")

# Collect the user-tunable settings in one place and forward them as keywords.
detection_settings = {
    "signal_width": signal_width,
    "noise_offset": noise_offset,
    "ignore_until": ignore_until,
}
signal_data_auto, noise_data_auto = find_signal_and_noise(rf_data, show=True, **detection_settings)

# Only the two extracted sample arrays are available here, not their index
# boundaries, so the summary reports region sizes.
auto_signal_samples, auto_noise_samples = len(signal_data_auto), len(noise_data_auto)

print(
    "\nAutomatic region summary:",
    f"- Signal: {auto_signal_samples} samples",
    f"- Noise: {auto_noise_samples} samples",
    sep="\n",
)

# %%
# 6. Manual adjustment of automatic regions (OPTIONAL)
# ------------------------------------------------------------------------

print("\n=== Manual Adjustment of Automatic Regions ===")
print("The automatic detection may not capture the ideal regions.")
print("You can manually adjust the bounds based on what you see in the plot above.")

# Manual adjustment parameters (modify these based on visualization)
print("\n=== Manual Adjustment Parameters ===")
print("Modify these values based on the visualization above:")

# Region bounds are sample indices used as half-open slices [start, end).
# Example: restrict noise to samples up to the 200th sample.
manual_signal_start = 20  # Start of signal region (big spike)
manual_signal_end = 40  # End of signal region
manual_noise_start = 40  # Start of noise region
manual_noise_end = 200  # Limit noise region to first 200 samples

print("Manual adjustments:")
print(f"- Signal: samples {manual_signal_start}-{manual_signal_end} (manual selection)")
print(f"- Noise: samples {manual_noise_start}-{manual_noise_end} (limited to 200)")

# Extract manually adjusted regions
signal_data = rf_data[manual_signal_start:manual_signal_end]
noise_data = rf_data[manual_noise_start:manual_noise_end]

# Slicing never raises for inverted or out-of-range bounds; it just yields an
# empty region, which would make the statistics and SNR computed later
# meaningless. Fail early with a message that names the offending bounds.
# len() is used rather than truthiness because rf_data is presumably a NumPy
# array (TODO confirm), whose truth value is not a reliable emptiness test.
if len(signal_data) == 0 or len(noise_data) == 0:
    raise ValueError(
        "Manual region bounds select no samples: "
        f"signal [{manual_signal_start}:{manual_signal_end}] -> {len(signal_data)} samples, "
        f"noise [{manual_noise_start}:{manual_noise_end}] -> {len(noise_data)} samples "
        f"(RF trace has {len(rf_data)} samples)"
    )

print("\nAdjusted regions:")
print(f"- Signal: {len(signal_data)} samples")
print(f"- Noise: {len(noise_data)} samples")


# %%
# 7. Print region summary
# ------------------------------------------------------------------------

# Sizes of the two regions that feed the SNR computation.
signal_samples, noise_samples = len(signal_data), len(noise_data)

print(
    "\n=== Region Selection Summary ===",
    f"Signal region: {signal_samples} samples (manually adjusted)",
    f"Noise region: {noise_samples} samples (manually adjusted)",
    sep="\n",
)

# %%
# 8. Print basic statistics
# ------------------------------------------------------------------------

print("\n=== Region Statistics ===")

# Report mean and standard deviation for the signal region, then the noise region.
signal_mean = signal_data.mean()
print(f"Signal mean: {signal_mean:.3f}")
signal_std = signal_data.std()
print(f"Signal std: {signal_std:.3f}")
noise_mean = noise_data.mean()
print(f"Noise mean: {noise_mean:.3f}")
noise_std = noise_data.std()
print(f"Noise std: {noise_std:.3f}")

# %%
# 9. Compute and print RF SNR
# ------------------------------------------------------------------------

# Compute the SNR of the selected regions; show=True is forwarded to the metric.
snr_db = compute_rf_snr(signal_data, noise_data, show=True)

print(
    "\n=== RF SNR Results ===",
    f"RF SNR: {snr_db:.2f} dB",
    "Data source: Raw RF data from /channel_data (manually adjusted regions)",
    sep="\n",
)

# %%
# 10. Compare different SNR calculation methods
# ------------------------------------------------------------------------

print("\n=== SNR Method Comparison ===")


def _snr_for_regions(**options):
    """Return the RF SNR of the selected signal/noise regions for the given options."""
    return compute_rf_snr(signal_data, noise_data, **options)


# Each variant is computed and reported in turn so the output order is unchanged.

# Standard SNR (centered, average power)
snr_standard = _snr_for_regions(center=True, max_power=False)
print(f"Standard SNR (centered, avg power): {snr_standard:.2f} dB")

# Same centering, but with max_power enabled
snr_max_power = _snr_for_regions(center=True, max_power=True)
print(f"SNR with max power: {snr_max_power:.2f} dB")

# Average power again, this time with centering disabled
snr_uncentered = _snr_for_regions(center=False, max_power=False)
print(f"SNR without centering: {snr_uncentered:.2f} dB")
