Source code for CompSensePack.eval

"""
eval.py - Functions for Signal-to-Noise Ratio (SNR) calculation and signal plotting.

This module provides utility functions for:
1. Calculating the Signal-to-Noise Ratio (SNR) between an original and a reconstructed signal.
2. Plotting the original and reconstructed signals together, with options to display the SNR and save the plot.

Note:
-----
- Requires `matplotlib` and `numpy`.
- Useful for comparing signals and visualizing differences in various signal processing applications.

Example usage:
--------------
>>> import numpy as np
>>> from eval import calculate_snr, plot_signals

>>> original_signal = np.random.rand(100)
>>> reconstructed_signal = np.random.rand(100)

>>> snr = calculate_snr(original_signal, reconstructed_signal)
>>> plot_signals(original_signal, reconstructed_signal, snr=snr, save_path='./plots')
"""

import os
import numpy as np
import matplotlib.pyplot as plt


[docs] def calculate_snr(signal, recovered_signal): """ Calculates the Signal-to-Noise Ratio (SNR) between the original signal and the recovered signal. Parameters ---------- signal : numpy.ndarray The original signal. recovered_signal : numpy.ndarray The recovered signal after some processing or recovery algorithm. Returns ------- snr : float The Signal-to-Noise Ratio (SNR) in decibels (dB). Notes ----- - The SNR is calculated as 20 * log10(norm(signal) / norm(signal - recovered_signal)). - A higher SNR value indicates a better recovery, with less error relative to the original signal. - If the signals are identical, the SNR would be infinite. If the recovered signal has no similarity, the SNR will be very low or negative. Example ------- >>> original = np.random.rand(100) >>> recovered = original + np.random.normal(0, 0.1, 100) >>> snr = calculate_snr(original, recovered) >>> print(f"SNR: {snr:.2f} dB") """ # Ensure the signals are numpy arrays if not isinstance(signal, np.ndarray) or not isinstance(recovered_signal, np.ndarray): raise ValueError("Both signal and recovered_signal must be numpy arrays.") # Calculate the error between the signals error = recovered_signal - signal # Calculate and return the SNR in dB snr = 20 * np.log10(np.linalg.norm(signal) / np.linalg.norm(error)) return snr
[docs] def plot_signals(original_signal, reconstructed_signal, snr=None, original_name="Original Signal", reconstructed_name="Reconstructed Signal", save_path=None, filename=None, start_pct=0.0, num_samples=None, show_snr_box=False): """ Plots a section of the original signal and the reconstructed signal on the same plot with the given names, displays the Signal-to-Noise Ratio (SNR) in a text box if enabled, and saves the plot to a specified directory. Parameters ---------- original_signal : numpy.ndarray The original signal to be plotted. reconstructed_signal : numpy.ndarray The reconstructed signal to be plotted. snr : float, optional (default=None) The Signal-to-Noise Ratio to display. If None, it will be computed using the original and reconstructed signals. original_name : str, optional (default="Original Signal") The name to display for the original signal in the plot. reconstructed_name : str, optional (default="Reconstructed Signal") The name to display for the reconstructed signal in the plot. save_path : str, optional The directory path where the plot should be saved. If None, the plot will not be saved. filename : str, optional The name of the file to save the plot as. If None and save_path is provided, a default name will be generated. start_pct : float, optional (default=0.0) The percentage (between 0 and 1) of the way through the signal to start plotting. For example, 0.5 means start from the halfway point of the signals. num_samples : int, optional (default=None) The number of samples to plot from the start point. If None, it will plot to the end of the signals. show_snr_box : bool, optional (default=False) Whether to display the SNR value in a text box on the plot. Returns ------- None This function does not return any value. It either displays or saves the plot. Notes ----- - Ensure the original and reconstructed signals have the same length; otherwise, a ValueError will be raised. - The plot shows a section of the signals starting at `start_pct` and plots `num_samples` samples. - The SNR can be displayed in a text box if `show_snr_box=True` and the SNR value is provided or calculated. Example ------- >>> original = np.sin(np.linspace(0, 10, 100)) >>> reconstructed = original + np.random.normal(0, 0.1, 100) >>> plot_signals(original, reconstructed, snr=20, save_path='./plots') """ # Ensure the signals have the same length if len(original_signal) != len(reconstructed_signal): raise ValueError("The original signal and the reconstructed signal must have the same length.") # Calculate the start index based on percentage start_idx = int(start_pct * len(original_signal)) # Determine the end index based on num_samples if num_samples is not None: end_idx = start_idx + num_samples else: end_idx = len(original_signal) # Ensure that the end index does not exceed the signal length end_idx = min(end_idx, len(original_signal)) # Slice the signals to the selected section original_signal_section = original_signal[start_idx:end_idx] reconstructed_signal_section = reconstructed_signal[start_idx:end_idx] # Calculate SNR if not provided if snr is None and show_snr_box: snr = calculate_snr(original_signal_section, reconstructed_signal_section) # Create the plot plt.figure(figsize=(10, 6)) plt.plot(original_signal_section, label=original_name, color='#1f77b4', linewidth=1.5) plt.plot(reconstructed_signal_section, label=reconstructed_name, color='#ff7f0e', linestyle='--', linewidth=1.5) # Title and labels plt.title(f"{original_name} vs {reconstructed_name} (Section: {start_pct*100:.1f}% - {num_samples} samples)") plt.xlabel('Sample Index') plt.ylabel('Amplitude') # Add a legend in the upper-right corner with a white background plt.legend(loc='upper right', frameon=True, facecolor='white') # Display SNR in a text box in the top-left corner if show_snr_box is True if show_snr_box and snr is not None: plt.text(0.05, 0.95, f'SNR: {snr:.2f} dB', transform=plt.gca().transAxes, fontsize=12, verticalalignment='top', horizontalalignment='left', bbox=dict(boxstyle='round', facecolor='white', alpha=0.8)) # Grid and show plot plt.grid(True) # Save the plot if a save path is provided if save_path is not None: # Ensure the save directory exists os.makedirs(save_path, exist_ok=True) # Use provided filename or generate a default one if filename is None: filename = f"{original_name}_vs_{reconstructed_name}_section.png" # Define the file path to save the plot file_path = os.path.join(save_path, filename) plt.savefig(file_path) print(f"Plot saved to {file_path}") # Display the plot plt.show()