ctf4science.visualization_module.Visualization

class ctf4science.visualization_module.Visualization(config_path: Path | str | None = None)

Bases: object

Generates visualizations of model predictions and evaluation metrics.

Configuration is loaded from a YAML file (default or custom path) and stored in config.

Parameters:
config_path : str or Path, optional

Path to a custom YAML config file. If None, the package default is used.

Methods

compare_prediction(truth, predictions[, ...])

Create a side-by-side comparison of truth, prediction(s), and error (2D data).

generate_all_plots(dataset_name, batch_path, ...)

Generate all applicable plot types for the dataset and save under each pair dir.

plot_errors(errors[, labels])

Plot error metrics over time or across sub-datasets.

plot_from_batch(dataset_name, pair_id, batch_id)

Plot data from a batch directory for a given plot type.

plot_histograms(truth, predictions, modes, bins)

Plot histograms of variables over the last modes time steps.

plot_prediction(ax, data[, vmin, vmax, ...])

Plot a 2D array on the given axes with optional color scale and labels.

plot_psd(truth, predictions, k, modes[, labels])

Plot log power spectral density over the last k time steps.

plot_trajectories(truth, predictions[, labels])

Plot stacked trajectories for each variable, comparing truth and predictions.

save_figure_results(fig, dataset_name, ...)

Save the figure to the results directory under a visualizations subfolder.

Raises:
FileNotFoundError

If the custom config file does not exist.
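The fallback described above can be sketched without the package. This is a minimal illustration, not the library's code: `resolve_config_path` and `DEFAULT_CONFIG` are hypothetical names, and the default location shown is a placeholder because the real package default is not stated here.

```python
from pathlib import Path

# Hypothetical placeholder; the actual package default path is not documented here.
DEFAULT_CONFIG = Path("config") / "visualization.yaml"


def resolve_config_path(config_path=None):
    """Return the config file to load, following the documented fallback."""
    if config_path is None:
        # No custom path given: fall back to the default.
        return DEFAULT_CONFIG
    path = Path(config_path)
    if not path.is_file():
        # A custom path that does not exist is an error.
        raise FileNotFoundError(f"Config file not found: {path}")
    return path
```

Accepting both `str` and `Path` and normalizing with `Path(...)` keeps the check in one place.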

Notes

Class Methods:

plot_trajectories(self, truth, predictions, labels=None, **kwargs):

  • Plot stacked trajectories for each variable, comparing truth and predictions.

  • Parameters:
    • truth : ndarray. Ground truth, shape (time_steps, variables).

    • predictions : list of ndarray. Prediction arrays, each with the same shape as truth.

    • labels : list of str, optional. Labels for each prediction.

    • **kwargs : Custom config overrides (e.g. figure_size).

  • Returns:
    • plt.Figure. The generated figure.

  • Raises ValueError if prediction shapes do not match truth.
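To make the expected inputs concrete, the sketch below builds a synthetic truth array of shape (time_steps, variables) and applies the shape rule stated above. `check_prediction_shapes` is a hypothetical helper written for illustration; it is not part of ctf4science.

```python
import numpy as np


def check_prediction_shapes(truth, predictions):
    """Raise ValueError unless every prediction has the same shape as truth."""
    for index, prediction in enumerate(predictions):
        if prediction.shape != truth.shape:
            raise ValueError(
                f"prediction {index} has shape {prediction.shape}, "
                f"expected {truth.shape}"
            )


# Synthetic data: 200 time steps, 3 variables.
t = np.linspace(0.0, 10.0, 200)
truth = np.column_stack([np.sin(t), np.cos(t), 0.1 * t])
prediction = truth + 0.05  # stand-in for a model forecast

check_prediction_shapes(truth, [prediction])  # passes silently
```

Arrays prepared this way match the documented layout for `truth` and `predictions`.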

plot_errors(self, errors, labels=None, **kwargs):

  • Plot error metrics over time or across sub-datasets.

  • Parameters:
    • errors : dict. Keys are metric names; values are lists of floats.

    • labels : list of str, optional. Labels for each error type.

    • **kwargs : Custom config overrides.

  • Returns:
    • plt.Figure. The generated figure.

  • Raises ValueError if errors is empty, the value lists differ in length, or the number of labels does not match.
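The three failure conditions can be written out as a small check. This is a sketch under assumptions: `validate_errors` is a hypothetical name, and comparing the label count against the number of metrics is one reading of "labels for each error type", not confirmed behavior.

```python
def validate_errors(errors, labels=None):
    """Check an errors dict (metric name -> list of floats); return the series length."""
    if not errors:
        raise ValueError("errors must not be empty")
    lengths = {len(values) for values in errors.values()}
    if len(lengths) != 1:
        raise ValueError(f"metric lists differ in length: {sorted(lengths)}")
    # Assumption: one label per metric.
    if labels is not None and len(labels) != len(errors):
        raise ValueError(
            f"expected {len(errors)} labels, got {len(labels)}"
        )
    return lengths.pop()


errors = {"short_time": [0.10, 0.12, 0.15], "long_time": [0.40, 0.38, 0.45]}
n_points = validate_errors(errors, labels=["Short-time", "Long-time"])
```

The metric names in the example are made up for illustration.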

plot_histograms(self, truth, predictions, modes, bins, labels=None, **kwargs):

  • Plot histograms over the last modes time steps per variable.

  • Parameters:
    • truth : ndarray. Ground truth, shape (time_steps, variables).

    • predictions : list of ndarray. Prediction arrays, each with the same shape as truth.

    • modes : int. Number of last time steps to use.

    • bins : int. Number of histogram bins.

    • labels : list of str, optional. Labels for each prediction.

    • **kwargs : Custom config overrides.

  • Returns:
    • plt.Figure. The generated figure.

  • Raises ValueError if modes > time_steps or the shapes do not match.
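The role of `modes` and `bins` can be shown with NumPy alone. The sketch below is an illustration, not the library's implementation: `tail_histograms` is a hypothetical helper that slices the last `modes` rows and bins each variable separately.

```python
import numpy as np


def tail_histograms(data, modes, bins):
    """Histogram each variable over the last `modes` time steps."""
    time_steps, n_variables = data.shape
    if modes > time_steps:
        raise ValueError(f"modes={modes} exceeds time_steps={time_steps}")
    tail = data[-modes:]  # shape (modes, n_variables)
    # np.histogram returns (counts, bin_edges) for each column.
    return [np.histogram(tail[:, j], bins=bins) for j in range(n_variables)]


rng = np.random.default_rng(0)
truth = rng.normal(size=(500, 3))
histograms = tail_histograms(truth, modes=100, bins=20)
```

Each histogram covers exactly `modes` samples, so the counts for a variable sum to `modes`.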

plot_psd(self, truth, predictions, k, modes, labels=None, **kwargs):

  • Plot log power spectral density over the last k time steps.

  • Parameters:
    • truth : ndarray. Ground truth, shape (time_steps, spatial_points).

    • predictions : list of ndarray. Prediction arrays, each with the same shape as truth.

    • k : int. Number of last time steps for PSD.

    • modes : int. Number of frequency modes to plot.

    • labels : list of str, optional. Labels for each prediction.

    • **kwargs : Custom config overrides.

  • Returns:
    • plt.Figure. The generated figure.

  • Raises ValueError if k or modes is invalid or the shapes do not match.
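The documentation does not spell out how the spectrum is computed. One plausible reading, and it is only an assumption, is an FFT along the spatial axis, averaged over the last `k` time steps, with the first `modes` coefficients kept. The hypothetical `log_psd` helper below sketches that reading with NumPy.

```python
import numpy as np


def log_psd(data, k, modes):
    """Log power spectrum over space, averaged over the last k time steps."""
    time_steps, n_points = data.shape
    if not 0 < k <= time_steps:
        raise ValueError(f"k must be in 1..{time_steps}, got {k}")
    max_modes = n_points // 2 + 1  # number of rfft coefficients
    if not 0 < modes <= max_modes:
        raise ValueError(f"modes must be in 1..{max_modes}, got {modes}")
    spectrum = np.fft.rfft(data[-k:], axis=1)        # shape (k, max_modes)
    power = np.mean(np.abs(spectrum) ** 2, axis=0)   # average over time
    return np.log(power[:modes] + 1e-20)             # small offset avoids log(0)


# A field with a single spatial wavenumber (3) at every time step.
x = np.linspace(0.0, 2.0 * np.pi, 64, endpoint=False)
truth = np.tile(np.sin(3.0 * x), (10, 1))
psd = log_psd(truth, k=5, modes=16)
peak_mode = int(np.argmax(psd))
```

For this single-wavenumber field the spectrum peaks at mode 3, which is a quick sanity check before comparing truth and prediction spectra.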

plot_from_batch(self, dataset_name, pair_id, batch_id, plot_type='trajectories', **kwargs):

  • Load data from a batch directory and produce one of the supported plot types.

  • Parameters:
    • dataset_name : str. Dataset name (e.g. 'ODE_Lorenz').

    • pair_id : int. Pair ID for the sub-dataset.

    • batch_id : str or path-like. Path to the batch directory (containing predictions.npy and, optionally, evaluation_results.yaml).

    • plot_type : str, optional. One of 'trajectories', 'histograms', 'psd', 'errors', '2d_comparison'. Default 'trajectories'.

    • **kwargs : Passed to the underlying plot method.

  • Returns:
    • plt.Figure. The generated figure.

  • Raises FileNotFoundError if required files are missing; ValueError if plot_type is unsupported or the data is invalid.
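The two error paths can be illustrated with a standalone check. This is a sketch under assumptions: `check_batch_inputs` is a hypothetical helper, and the mapping of 'errors' to evaluation_results.yaml (with every other type reading predictions.npy) is a guess based on the file names listed above, not documented behavior.

```python
from pathlib import Path

SUPPORTED_PLOT_TYPES = ("trajectories", "histograms", "psd", "errors", "2d_comparison")


def check_batch_inputs(batch_dir, plot_type="trajectories"):
    """Validate plot_type, then return the file that plot type would read."""
    if plot_type not in SUPPORTED_PLOT_TYPES:
        raise ValueError(
            f"Unsupported plot_type {plot_type!r}; choose from {SUPPORTED_PLOT_TYPES}"
        )
    # Assumption: only the 'errors' plot needs the evaluation file.
    needed = "evaluation_results.yaml" if plot_type == "errors" else "predictions.npy"
    target = Path(batch_dir) / needed
    if not target.is_file():
        raise FileNotFoundError(f"Required file missing: {target}")
    return target
```

Validating `plot_type` before touching the file system means a typo in the plot name is reported as a ValueError even when the batch directory is incomplete.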

generate_all_plots(self, dataset_name, batch_path, **kwargs):

  • Generate all applicable plot types for the dataset and save under each pair dir in batch_path.

  • Parameters:
    • dataset_name : str. Dataset name.

    • batch_path : str or path-like. Path to the batch directory (containing pair* subdirs).

    • **kwargs : Passed to plot_from_batch.

  • Returns:
    • None.
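Discovering the pair* subdirectories can be sketched with pathlib. `find_pair_dirs` is a hypothetical helper, and the assumption that directory names follow the pair{pair_id} pattern with a numeric ID comes from the default results path documented for save_figure_results.

```python
from pathlib import Path


def find_pair_dirs(batch_path):
    """Return (pair_id, path) tuples for pair* subdirectories, sorted by ID."""
    pairs = []
    for path in Path(batch_path).glob("pair*"):
        suffix = path.name[len("pair"):]
        # Skip files and anything whose suffix is not a plain integer.
        if path.is_dir() and suffix.isdigit():
            pairs.append((int(suffix), path))
    return sorted(pairs)
```

Sorting on the integer ID rather than the name keeps pair10 after pair2, which a plain lexicographic sort would not.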

save_figure_results(self, fig, dataset_name, model_name, batch_name, pair_id, plot_type, results_dir=None):

  • Save the figure to the results directory under a visualizations subfolder.

  • Parameters:
    • fig : plt.Figure. The figure to save.

    • dataset_name : str. Name of the dataset.

    • model_name : str. Name of the model.

    • batch_name : str. Batch identifier.

    • pair_id : int. Sub-dataset identifier.

    • plot_type : str. Type of plot (e.g. 'trajectories', 'histograms').

    • results_dir : str or Path, optional. Base path for results; default is results/{dataset}/{model}/{batch}/pair{pair_id}/visualizations.

  • Returns:
    • None.
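The default location can be reproduced with pathlib. This sketch only builds the directory path from the pattern given for results_dir; `default_visualizations_dir` is a hypothetical helper, the model and batch names are made up, and the file name and image format used inside that folder are not documented here, so they are left out.

```python
from pathlib import Path


def default_visualizations_dir(dataset_name, model_name, batch_name, pair_id):
    """Build results/{dataset}/{model}/{batch}/pair{pair_id}/visualizations."""
    return (
        Path("results")
        / dataset_name
        / model_name
        / batch_name
        / f"pair{pair_id}"
        / "visualizations"
    )


# 'my_model' and 'batch_0' are illustrative names only.
target_dir = default_visualizations_dir("ODE_Lorenz", "my_model", "batch_0", 1)
```

A caller would typically create the directory with `target_dir.mkdir(parents=True, exist_ok=True)` before writing a figure into it.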

plot_prediction(self, ax, data, vmin=None, vmax=None, show_ticks=True, show_xlabel=False, show_ylabel=False):

  • Plot a 2D array on the given axes (e.g. for spatio-temporal data).

  • Parameters:
    • ax : matplotlib.axes.Axes. Axes to plot on.

    • data : ndarray. 2D array, shape (time_steps, spatial_dim).

    • vmin, vmax : float, optional. Color scale limits.

    • show_ticks : bool, optional. Whether to show axis ticks. Default True.

    • show_xlabel : bool, optional. Whether to show x-axis label. Default False.

    • show_ylabel : bool, optional. Whether to show y-axis label. Default False.

  • Returns:
    • matplotlib.image.AxesImage. The image from imshow.
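A comparable drawing routine can be written directly against Matplotlib. The sketch below is not the library's code: `draw_field` is a hypothetical stand-in, and the `aspect` and `origin` settings are assumptions about how a (time_steps, spatial_dim) array might be displayed.

```python
import matplotlib

matplotlib.use("Agg")  # non-interactive backend so the sketch runs headless
import matplotlib.pyplot as plt
import numpy as np


def draw_field(ax, data, vmin=None, vmax=None, show_ticks=True):
    """Draw a 2D array with imshow and return the resulting AxesImage."""
    image = ax.imshow(data, aspect="auto", origin="lower", vmin=vmin, vmax=vmax)
    if not show_ticks:
        ax.set_xticks([])
        ax.set_yticks([])
    return image


# Synthetic field: 50 time steps, 20 spatial points.
time = np.linspace(0.0, 6.0, 50)[:, None]
space = np.linspace(0.0, 3.0, 20)[None, :]
data = np.sin(time) * np.cos(space)

fig, ax = plt.subplots()
image = draw_field(ax, data, vmin=-1.0, vmax=1.0, show_ticks=False)
```

Returning the AxesImage lets the caller attach a colorbar with `fig.colorbar(image, ax=ax)`.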

compare_prediction(self, truth, predictions, cbar_options=None, show_ticks=True, show_titles=True):

  • Create a side-by-side comparison of truth, prediction(s), and error (2D data).

  • Parameters:
    • truth : ndarray. Ground truth, shape (time_steps, spatial_dim).

    • predictions : list of ndarray. Prediction arrays, same shape as truth.

    • cbar_options : dict, optional. Colorbar options (show, orientation, shrink, ticks, label).

    • show_ticks : bool, optional. Whether to show axis ticks. Default True.

    • show_titles : bool, optional. Whether to show subplot titles. Default True.

  • Returns:
    • plt.Figure. The comparison figure.
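The quantities behind such a comparison can be computed with NumPy. The sketch below rests on assumptions: `comparison_limits` is a hypothetical helper, the error is taken as prediction minus truth, and sharing one color range across truth and predictions (with a symmetric range for the error) is a common convention rather than documented behavior.

```python
import numpy as np


def comparison_limits(truth, predictions):
    """Return per-prediction errors, shared (vmin, vmax), and a symmetric error limit."""
    for index, prediction in enumerate(predictions):
        if prediction.shape != truth.shape:
            raise ValueError(
                f"prediction {index} has shape {prediction.shape}, "
                f"expected {truth.shape}"
            )
    errors = [prediction - truth for prediction in predictions]
    fields = [truth, *predictions]
    # One color range for truth and predictions so panels are comparable.
    vmin = min(float(field.min()) for field in fields)
    vmax = max(float(field.max()) for field in fields)
    # Symmetric range for the error panels, centered on zero.
    error_limit = max(float(np.abs(error).max()) for error in errors)
    return errors, (vmin, vmax), error_limit


truth = np.zeros((4, 5))
prediction = np.full((4, 5), 2.0)
errors, (vmin, vmax), error_limit = comparison_limits(truth, [prediction])
```

The explicit shape check matters because `prediction - truth` would otherwise broadcast silently for some mismatched shapes.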