ctf4science.visualization_module.Visualization
- class ctf4science.visualization_module.Visualization(config_path: Path | str | None = None)
Bases: object
Generates visualizations of model predictions and evaluation metrics.
Configuration is loaded from a YAML file (default or custom path) and stored in config.
- Parameters:
- config_path : str or Path, optional
Path to a custom YAML config file. If None, the default config bundled with the package is used.
Methods
compare_prediction(truth, predictions[, ...]): Create a side-by-side comparison of truth, prediction(s), and error (2D data).
generate_all_plots(dataset_name, batch_path, ...): Generate all applicable plot types for the dataset and save them under each pair directory.
plot_errors(errors[, labels]): Plot error metrics over time or across sub-datasets.
plot_from_batch(dataset_name, pair_id, batch_id): Plot data from a batch directory for a given plot type.
plot_histograms(truth, predictions, modes, bins): Plot histograms of each variable over the last modes time steps.
plot_prediction(ax, data[, vmin, vmax, ...]): Plot a 2D array on the given axes with optional color scale and labels.
plot_psd(truth, predictions, k, modes[, labels]): Plot log power spectral density over the last k time steps.
plot_trajectories(truth, predictions[, labels]): Plot stacked trajectories for each variable, comparing truth and predictions.
save_figure_results(fig, dataset_name, ...): Save the figure to the results directory under a visualizations subfolder.
- Raises:
- FileNotFoundError
If the custom config file does not exist.
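The config lookup described above can be sketched as follows. This is a minimal sketch under assumptions, not the package's implementation: resolve_config_path and its default_path argument are hypothetical, and parsing the YAML itself is left out.

```python
from pathlib import Path
from typing import Optional, Union


def resolve_config_path(
    config_path: Optional[Union[str, Path]], default_path: Path
) -> Path:
    """Hypothetical helper: choose the custom config if given, else the default.

    default_path stands in for the config file bundled with the package;
    its real location is not documented here.
    """
    if config_path is None:
        # No custom path supplied: fall back to the packaged default.
        return default_path
    path = Path(config_path)
    if not path.is_file():
        # Mirrors the documented FileNotFoundError for a missing custom file.
        raise FileNotFoundError(f"Config file not found: {path}")
    return path
```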
Notes
Method details:
plot_trajectories(self, truth, predictions, labels=None, **kwargs):
Plot stacked trajectories for each variable, comparing truth and predictions.
- Parameters:
truth : ndarray. Ground truth, shape (time_steps, variables).
predictions : list of ndarray. List of prediction arrays, each same shape as truth.
labels : list of str, optional. Labels for each prediction.
**kwargs : Custom config overrides (e.g. figure_size).
- Returns:
plt.Figure. The generated figure.
- Raises:
ValueError if any prediction's shape does not match truth.
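As a rough illustration of the shape check and the per-variable stacking, the hypothetical helper trajectory_panels below gathers the series each stacked subplot would draw. It is a sketch assuming NumPy arrays; the default label text and the label-count check are assumptions, not documented behavior.

```python
from typing import List, Optional, Sequence, Tuple

import numpy as np


def trajectory_panels(
    truth: np.ndarray,
    predictions: Sequence[np.ndarray],
    labels: Optional[List[str]] = None,
) -> List[Tuple[np.ndarray, List[Tuple[str, np.ndarray]]]]:
    """Group series per variable, one entry per stacked subplot (hypothetical)."""
    truth = np.asarray(truth)
    if truth.ndim != 2:
        raise ValueError(f"truth must be 2D (time_steps, variables), got {truth.shape}")
    for i, pred in enumerate(predictions):
        if np.shape(pred) != truth.shape:
            raise ValueError(
                f"prediction {i} has shape {np.shape(pred)}, expected {truth.shape}"
            )
    if labels is None:
        # Default label text is an assumption, not the library's.
        labels = [f"Prediction {i + 1}" for i in range(len(predictions))]
    if len(labels) != len(predictions):
        # Assumption: mismatched label counts are rejected rather than truncated.
        raise ValueError(f"expected {len(predictions)} labels, got {len(labels)}")
    panels = []
    for var in range(truth.shape[1]):
        # Each panel holds the truth column plus every prediction's matching column.
        series = [(lab, np.asarray(p)[:, var]) for lab, p in zip(labels, predictions)]
        panels.append((truth[:, var], series))
    return panels
```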
plot_errors(self, errors, labels=None, **kwargs):
Plot error metrics over time or across sub-datasets.
- Parameters:
errors : dict. Keys are metric names; values are lists of floats.
labels : list of str, optional. Labels for each error type.
**kwargs : Custom config overrides.
- Returns:
plt.Figure. The generated figure.
- Raises:
ValueError if errors is empty, the metric lists differ in length, or the number of labels does not match the number of metrics.
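The documented ValueError conditions for plot_errors can be expressed as a small validator. validate_errors is a hypothetical name, and returning the metric names as the default labels is an assumption.

```python
from typing import Dict, List, Optional, Sequence


def validate_errors(
    errors: Dict[str, Sequence[float]], labels: Optional[List[str]] = None
) -> List[str]:
    """Apply the documented plot_errors checks and return legend labels.

    Hypothetical helper; falling back to the metric names is an assumption.
    """
    if not errors:
        raise ValueError("errors must contain at least one metric")
    lengths = {name: len(values) for name, values in errors.items()}
    if len(set(lengths.values())) > 1:
        raise ValueError(f"metric lists differ in length: {lengths}")
    if labels is None:
        return list(errors.keys())
    if len(labels) != len(errors):
        raise ValueError(f"expected {len(errors)} labels, got {len(labels)}")
    return list(labels)
```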
plot_histograms(self, truth, predictions, modes, bins, labels=None, **kwargs):
Plot histograms over the last modes time steps per variable.
- Parameters:
truth : ndarray. Ground truth, shape (time_steps, variables).
predictions : list of ndarray. Same shape as truth.
modes : int. Number of last time steps to use.
bins : int. Number of histogram bins.
labels : list of str, optional. Labels for each prediction.
**kwargs : Custom config overrides.
- Returns:
plt.Figure. The generated figure.
- Raises:
ValueError if modes exceeds the number of time steps or prediction shapes do not match truth.
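A sketch of the data side of plot_histograms, assuming NumPy: take the last modes rows of each array and histogram each variable. last_step_histograms is hypothetical, and sharing bin edges between truth and predictions is a design choice of this sketch, not necessarily the library's.

```python
from typing import List, Sequence

import numpy as np


def last_step_histograms(
    truth: np.ndarray, predictions: Sequence[np.ndarray], modes: int, bins: int
) -> List[List[np.ndarray]]:
    """Histogram counts per variable over the last `modes` steps (hypothetical).

    Returns counts[var][0] for truth and counts[var][i + 1] for prediction i.
    """
    truth = np.asarray(truth)
    time_steps, n_vars = truth.shape
    if modes > time_steps:
        raise ValueError(f"modes={modes} exceeds time_steps={time_steps}")
    if modes < 1:
        # Lower bound is an assumption; the docs only mention modes > time_steps.
        raise ValueError("modes must be at least 1")
    arrays = [truth] + [np.asarray(p) for p in predictions]
    for p in arrays[1:]:
        if p.shape != truth.shape:
            raise ValueError(f"prediction shape {p.shape} != truth shape {truth.shape}")
    counts = []
    for var in range(n_vars):
        windows = [a[-modes:, var] for a in arrays]
        # Shared bin edges so truth and prediction histograms are directly comparable.
        edges = np.histogram_bin_edges(np.concatenate(windows), bins=bins)
        counts.append([np.histogram(w, bins=edges)[0] for w in windows])
    return counts
```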
plot_psd(self, truth, predictions, k, modes, labels=None, **kwargs):
Plot log power spectral density over the last k time steps.
- Parameters:
truth : ndarray. Ground truth, shape (time_steps, spatial_points).
predictions : list of ndarray. Same shape as truth.
k : int. Number of last time steps for PSD.
modes : int. Number of frequency modes to plot.
labels : list of str, optional. Labels for each prediction.
**kwargs : Custom config overrides.
- Returns:
plt.Figure. The generated figure.
- Raises:
ValueError if k or modes is invalid or prediction shapes do not match truth.
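One plausible reading of the PSD computation is sketched below: an FFT along the spatial axis, power averaged over the last k snapshots, the first modes wavenumbers kept, then a log10. The helper log_psd, the averaging, and the 1e-12 floor are all assumptions; plot_psd would presumably apply something like this to truth and to each prediction.

```python
import numpy as np


def log_psd(field: np.ndarray, k: int, modes: int) -> np.ndarray:
    """Log10 power spectrum of the last k time steps, first `modes` wavenumbers.

    Assumption: the FFT runs along the spatial axis and power is averaged over
    the k time steps; the library may normalize or average differently.
    """
    field = np.asarray(field, dtype=float)
    time_steps, n_points = field.shape
    if not 1 <= k <= time_steps:
        raise ValueError(f"k must be in [1, {time_steps}], got {k}")
    n_freq = n_points // 2 + 1  # number of non-negative rfft frequencies
    if not 1 <= modes <= n_freq:
        raise ValueError(f"modes must be in [1, {n_freq}], got {modes}")
    spectrum = np.fft.rfft(field[-k:], axis=1)  # FFT of each spatial snapshot
    power = np.mean(np.abs(spectrum) ** 2, axis=0)  # average over the k steps
    # Small floor avoids log10(0) for wavenumbers with no energy.
    return np.log10(power[:modes] + 1e-12)
```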
plot_from_batch(self, dataset_name, pair_id, batch_id, plot_type='trajectories', **kwargs):
Load data from a batch directory and produce one of the supported plot types.
- Parameters:
dataset_name : str. Dataset name (e.g. 'ODE_Lorenz').
pair_id : int. Pair ID for the sub-dataset.
batch_id : str or path-like. Path to the batch directory (containing predictions.npy and, optionally, evaluation_results.yaml).
plot_type : str, optional. One of 'trajectories', 'histograms', 'psd', 'errors', '2d_comparison'. Default 'trajectories'.
**kwargs : Passed to the underlying plot method.
- Returns:
plt.Figure. The generated figure.
- Raises:
FileNotFoundError if required files are missing; ValueError if plot_type is unsupported or the data are invalid.
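The file checks listed for plot_from_batch might look roughly like this. load_batch_predictions is a hypothetical helper; it assumes predictions.npy sits directly in the batch directory and that only the 'errors' plot needs evaluation_results.yaml. Loading the ground truth for the dataset is not shown.

```python
from pathlib import Path
from typing import Union

import numpy as np

# Plot types listed in the plot_from_batch documentation.
SUPPORTED_PLOT_TYPES = {"trajectories", "histograms", "psd", "errors", "2d_comparison"}


def load_batch_predictions(batch_dir: Union[str, Path], plot_type: str) -> np.ndarray:
    """Hypothetical loader mirroring the documented checks of plot_from_batch."""
    if plot_type not in SUPPORTED_PLOT_TYPES:
        raise ValueError(f"unsupported plot_type {plot_type!r}")
    batch_dir = Path(batch_dir)
    pred_file = batch_dir / "predictions.npy"
    if not pred_file.is_file():
        raise FileNotFoundError(f"missing {pred_file}")
    results_file = batch_dir / "evaluation_results.yaml"
    if plot_type == "errors" and not results_file.is_file():
        # Assumption: only the error plot needs the optional evaluation results.
        raise FileNotFoundError(f"missing {results_file}")
    return np.load(pred_file)
```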
generate_all_plots(self, dataset_name, batch_path, **kwargs):
Generate all applicable plot types for the dataset and save them under each pair directory in batch_path.
- Parameters:
dataset_name : str. Dataset name.
batch_path : str or path-like. Path to the batch directory (containing pair* subdirs).
**kwargs : Passed to plot_from_batch.
- Returns:
None.
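Discovering the pair* subdirectories could be done as follows. find_pair_dirs is hypothetical and assumes the directories are named pair<N> with an integer N; generate_all_plots would then call plot_from_batch once per pair and plot type.

```python
import re
from pathlib import Path
from typing import List, Tuple, Union

# Assumed naming scheme: pair1, pair2, ... (the "pair*" subdirs mentioned above).
_PAIR_DIR = re.compile(r"pair(\d+)")


def find_pair_dirs(batch_path: Union[str, Path]) -> List[Tuple[int, Path]]:
    """Return (pair_id, directory) for each pair<N> subdir, sorted by pair_id."""
    found = []
    for child in Path(batch_path).iterdir():
        match = _PAIR_DIR.fullmatch(child.name)
        if child.is_dir() and match:
            found.append((int(match.group(1)), child))
    # Numeric sort keeps pair10 after pair9 rather than right after pair1.
    return sorted(found)
```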
save_figure_results(self, fig, dataset_name, model_name, batch_name, pair_id, plot_type, results_dir=None):
Save the figure to the results directory under a visualizations subfolder.
- Parameters:
fig : plt.Figure. The figure to save.
dataset_name : str. Name of the dataset.
model_name : str. Name of the model.
batch_name : str. Batch identifier.
pair_id : int. Sub-dataset identifier.
plot_type : str. Type of plot (e.g. 'trajectories', 'histograms').
results_dir : str or Path, optional. Base path for results; default is results/{dataset}/{model}/{batch}/pair{pair_id}/visualizations.
- Returns:
None.
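The default output location can be reproduced with pathlib. visualization_path is a hypothetical helper; the PNG filename and treating an explicit results_dir as the final folder are assumptions.

```python
from pathlib import Path
from typing import Optional, Union


def visualization_path(
    dataset_name: str,
    model_name: str,
    batch_name: str,
    pair_id: int,
    plot_type: str,
    results_dir: Optional[Union[str, Path]] = None,
) -> Path:
    """Build the output file path for a saved figure (hypothetical helper)."""
    if results_dir is None:
        # Default layout documented for save_figure_results.
        results_dir = (
            Path("results") / dataset_name / model_name / batch_name
            / f"pair{pair_id}" / "visualizations"
        )
    # Assumption: an explicit results_dir is used as the target folder as-is,
    # and figures are saved as PNG files named after the plot type.
    return Path(results_dir) / f"{plot_type}.png"
```

To save, a caller would create the folder with path.parent.mkdir(parents=True, exist_ok=True) and then call fig.savefig(path).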
plot_prediction(self, ax, data, vmin=None, vmax=None, show_ticks=True, show_xlabel=False, show_ylabel=False):
Plot a 2D array on the given axes (e.g. for spatio-temporal data).
- Parameters:
ax : matplotlib.axes.Axes. Axes to plot on.
data : ndarray. 2D array, shape (time_steps, spatial_dim).
vmin, vmax : float, optional. Color scale limits.
show_ticks : bool, optional. Whether to show axis ticks. Default True.
show_xlabel : bool, optional. Whether to show x-axis label. Default False.
show_ylabel : bool, optional. Whether to show y-axis label. Default False.
- Returns:
matplotlib.image.AxesImage. The image returned by imshow.
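A standalone sketch of what plot_prediction describes, using matplotlib's imshow. plot_field is a hypothetical stand-in, not the library function; the axis label text and origin='lower' (so time increases upward) are assumptions.

```python
from typing import Optional

import matplotlib

matplotlib.use("Agg")  # non-interactive backend so the sketch runs headless
import matplotlib.pyplot as plt
import numpy as np
from matplotlib.axes import Axes
from matplotlib.image import AxesImage


def plot_field(
    ax: Axes,
    data: np.ndarray,
    vmin: Optional[float] = None,
    vmax: Optional[float] = None,
    show_ticks: bool = True,
    show_xlabel: bool = False,
    show_ylabel: bool = False,
) -> AxesImage:
    """Hypothetical stand-in for plot_prediction: rows are time, columns are space."""
    im = ax.imshow(data, aspect="auto", origin="lower", vmin=vmin, vmax=vmax)
    if not show_ticks:
        ax.set_xticks([])
        ax.set_yticks([])
    if show_xlabel:
        ax.set_xlabel("Space")  # label text is an assumption
    if show_ylabel:
        ax.set_ylabel("Time")
    return im
```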
compare_prediction(self, truth, predictions, cbar_options=None, show_ticks=True, show_titles=True):
Create a side-by-side comparison of truth, prediction(s), and error (2D data).
- Parameters:
truth : ndarray. Ground truth, shape (time_steps, spatial_dim).
predictions : list of ndarray. Prediction arrays, same shape as truth.
cbar_options : dict, optional. Colorbar options (show, orientation, shrink, ticks, label).
show_ticks : bool, optional. Whether to show axis ticks. Default True.
show_titles : bool, optional. Whether to show subplot titles. Default True.
- Returns:
plt.Figure. The comparison figure.
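Before drawing, a side-by-side comparison needs a common color scale and the error fields. comparison_data is a hypothetical helper that assumes the error panel shows the absolute difference; its vmin and vmax could be passed to plot_prediction for the truth and prediction panels.

```python
from typing import List, Sequence, Tuple

import numpy as np


def comparison_data(
    truth: np.ndarray, predictions: Sequence[np.ndarray]
) -> Tuple[float, float, List[np.ndarray]]:
    """Shared color limits and per-prediction error fields (hypothetical helper).

    Assumption: error is the absolute difference |truth - prediction|.
    """
    truth = np.asarray(truth, dtype=float)
    preds = [np.asarray(p, dtype=float) for p in predictions]
    for p in preds:
        if p.shape != truth.shape:
            raise ValueError(f"prediction shape {p.shape} != truth shape {truth.shape}")
    # One color scale for truth and predictions keeps the panels visually comparable.
    vmin = float(min([truth.min()] + [p.min() for p in preds]))
    vmax = float(max([truth.max()] + [p.max() for p in preds]))
    errors = [np.abs(truth - p) for p in preds]
    return vmin, vmax, errors
```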