# Source code for plotez.plotez

"""
PlotEZ - Mundane plotting made easy.

This module provides simplified plotting functions for common visualization tasks.
"""

from __future__ import annotations

# Public plotting API of this module: the names exported by ``from <module> import *``.
# Private helpers defined here (e.g. ``_label_sanitizer``) are deliberately not listed.
__all__ = [
    "plot_errorband",
    "plot_errorband_relative",
    "plot_errorbar",
    "plot_two_column_file",
    "plot_xy",
    "plot_xyy",
    "plot_xxy",
    "plot_with_dual_axes",
    "two_subplots",
    "n_plotter",
    "plot_density",
    "plot_hist",
]

from warnings import warn

import matplotlib.pyplot as plt
import numpy as np

from .backend import (
    ErrorBandConfig,
    ErrorPlotConfig,
    HistogramConfig,
    LinePlotConfig,
    ScatterPlotConfig,
    dual_axes_data_validation,
    plot_or_scatter,
)
from .backend.error_handling import ColumnCountError, ConfigurationError, OrientationError, ShapeError
from .typing import ArrayLike, Axes, AxesReturn, NDArray

# =============================================================================
# Error Visualization Functions
# =============================================================================


def plot_errorband_relative(
    x_data: ArrayLike,
    y_data: ArrayLike,
    y_lower: int | float | ArrayLike | None = None,
    y_upper: int | float | ArrayLike | None = None,
    x_label: str = "X",
    y_label: str = "Y",
    plot_title: str = "XY ErrorBand",
    data_label: str = "X vs. Y",
    line: bool = True,
    band_config: ErrorBandConfig | None = None,
    line_config: LinePlotConfig | dict | None = None,
    axis: Axes | None = None,
    figure_kwargs: dict | None = None,
) -> Axes:
    """
    Plot a line graph with a shaded error band given *relative* (offset) errors.

    This is a thin layer over :func:`plot_errorband`. Here ``y_lower`` and
    ``y_upper`` are distances measured from ``y_data``, not absolute band
    edges. They are turned into absolute edges (``y_data - y_lower`` and
    ``y_data + y_upper``), and every other argument is forwarded unchanged.

    Parameters
    ----------
    x_data :
        The independent variable values to plot.
    y_data :
        The central values to plot.
    y_lower :
        Downward offset from ``y_data`` that defines the lower band edge. If
        ``None``, the band is made symmetric using ``y_upper``. At least one
        of ``y_lower`` or ``y_upper`` must be provided.
    y_upper :
        Upward offset from ``y_data`` that defines the upper band edge. If
        ``None``, the band is made symmetric using ``y_lower``. At least one
        of ``y_lower`` or ``y_upper`` must be provided.
    x_label :
        The label for the x-axis.
    y_label :
        The label for the y-axis.
    plot_title :
        The title of the plot.
    data_label :
        Legend label for the data series. It is attached to the line when
        ``line=True`` and to the band when ``line=False``.
    line :
        Whether to draw a line through the central values over the band.
    band_config :
        Styling for the error band; defaults are used when ``None``.
    line_config :
        Styling for the line; defaults are used when ``None``.
    axis :
        Existing Matplotlib axes to draw on. When given, ``figure_kwargs`` is
        ignored.
    figure_kwargs :
        Keyword arguments for ``plt.subplots`` when a new figure is created.
        Ignored if ``axis`` is provided.

    Returns
    -------
    Axes
        The Matplotlib Axes on which the plot was drawn.

    Raises
    ------
    ConfigurationError
        If both ``y_lower`` and ``y_upper`` are ``None`` (raised by
        :func:`plot_errorband`).

    See Also
    --------
    plot_errorband : The absolute-bounds version of this function.
    """
    xs = np.asarray(x_data)
    centre = np.asarray(y_data)

    # Translate each supplied offset into an absolute edge. A missing offset
    # stays None so that plot_errorband mirrors the edge that was given.
    lower_edge = None
    if y_lower is not None:
        lower_edge = centre - np.asarray(y_lower)

    upper_edge = None
    if y_upper is not None:
        upper_edge = centre + np.asarray(y_upper)

    return plot_errorband(
        x_data=xs,
        y_data=centre,
        y_lower=lower_edge,
        y_upper=upper_edge,
        x_label=x_label,
        y_label=y_label,
        plot_title=plot_title,
        data_label=data_label,
        line=line,
        band_config=band_config,
        line_config=line_config,
        axis=axis,
        figure_kwargs=figure_kwargs,
    )
def plot_errorband(
    x_data: ArrayLike,
    y_data: ArrayLike,
    y_lower: int | float | ArrayLike | None = None,
    y_upper: int | float | ArrayLike | None = None,
    x_label: str = "X",
    y_label: str = "Y",
    plot_title: str = "XY ErrorBand",
    data_label: str = "X vs. Y",
    line: bool = True,
    band_config: ErrorBandConfig | None = None,
    line_config: LinePlotConfig | dict | None = None,
    axis: Axes | None = None,
    figure_kwargs: dict | None = None,
) -> Axes:
    """
    Plot a line graph with a shaded error band representing uncertainty.

    Parameters
    ----------
    x_data :
        The independent variable values to plot.
    y_data :
        The central values to plot.
    y_lower :
        The lower bound of the error band. If ``None``, it is inferred as a
        symmetric reflection of ``y_upper`` through ``y_data``. At least one of
        ``y_lower`` or ``y_upper`` must be provided.
    y_upper :
        The upper bound of the error band. If ``None``, it is inferred as a
        symmetric reflection of ``y_lower`` through ``y_data``. At least one of
        ``y_lower`` or ``y_upper`` must be provided.
    x_label :
        The label for the x-axis.
    y_label :
        The label for the y-axis.
    plot_title :
        The title of the plot.
    data_label :
        The label for the data series, used in the legend. If ``line=True``, the
        label is attached to the line. If ``line=False``, it is attached to the
        band. When empty, a ``label`` from ``line_config`` (``line=True``) or
        from ``band_config`` (``line=False``) is used instead.
    line :
        Whether to draw a line through the central values over the error band.
    band_config :
        Configuration for the error band styling. If ``None``, defaults are used.
    line_config :
        Configuration for the line styling, as a ``LinePlotConfig`` or a dict
        accepted by ``LinePlotConfig.populate``. If ``None``, defaults are used.
    axis :
        Pre-existing Matplotlib axes to draw on. If provided, ``figure_kwargs``
        is ignored.
    figure_kwargs :
        Keyword arguments passed to ``plt.subplots`` when creating a new figure.
        Ignored if ``axis`` is provided.

    Returns
    -------
    Axes
        The Matplotlib Axes on which the plot was drawn.

    Raises
    ------
    ConfigurationError
        If both ``y_lower`` and ``y_upper`` are ``None``.

    Warns
    -----
    UserWarning
        If ``figure_kwargs`` is given together with ``axis``, or if both
        ``data_label`` and a config ``label`` are given (``data_label`` wins).
    """
    x, y = np.asarray(x_data), np.asarray(y_data)

    if y_lower is None and y_upper is None:
        raise ConfigurationError("At least one of `y_lower` or `y_upper` must be provided for the error band.")

    lower = None if y_lower is None else np.asarray(y_lower)
    upper = None if y_upper is None else np.asarray(y_upper)
    # Infer a missing edge by reflecting the supplied edge through the central values.
    if lower is None:
        lower = y - (upper - y)
    elif upper is None:
        upper = y + (y - lower)

    if axis is not None:
        ax = axis
        if figure_kwargs:
            warn("`figure_kwargs` is ignored when `axis` is provided.", UserWarning, stacklevel=2)
    else:
        _, ax = plt.subplots(**(figure_kwargs or {}))

    band_kwargs = band_config.get_dict() if band_config else ErrorBandConfig().get_dict()

    if isinstance(line_config, dict):
        line_config = LinePlotConfig.populate(line_config)
    line_kwargs = line_config.get_dict() if line_config else LinePlotConfig().get_dict()

    if line:
        line_label = data_label or line_kwargs.get("label") or None
        if data_label and "label" in line_kwargs:
            warn(
                "Both `data_label` and `line_config['label']` are provided. Using `data_label`.",
                UserWarning,
                stacklevel=2,
            )
        # The label is passed explicitly below, so it must not also travel in **line_kwargs.
        line_kwargs.pop("label", None)
        ax.fill_between(x, lower, upper, **band_kwargs)
        ax.plot(x, y, label=line_label, **line_kwargs)
    else:
        # The band carries the legend entry. Remove any config label first so it
        # cannot collide with the explicit ``label=`` keyword (that would be a TypeError).
        config_label = band_kwargs.pop("label", None)
        if data_label and config_label:
            warn(
                "Both `data_label` and `band_config['label']` are provided. Using `data_label`.",
                UserWarning,
                stacklevel=2,
            )
        ax.fill_between(x, lower, upper, label=data_label if data_label else config_label, **band_kwargs)

    ax.set_xlabel(x_label)
    ax.set_ylabel(y_label)
    ax.set_title(plot_title)
    # Only build a legend when some artist is labelled. Otherwise matplotlib warns
    # "No artists with labels found to put in legend".
    if ax.get_legend_handles_labels()[0]:
        ax.legend()
    return ax
def plot_errorbar(
    x_data: ArrayLike,
    y_data: ArrayLike,
    x_err: int | float | ArrayLike | None = None,
    y_err: int | float | ArrayLike | None = None,
    x_label: str = "X",
    y_label: str = "Y",
    plot_title: str = "XY ErrorBar",
    data_label: str | None = "X vs. Y",
    errorbar_config: ErrorPlotConfig | None = None,
    axis: Axes | None = None,
    figure_kwargs: dict | None = None,
) -> Axes:
    """
    Plot an error bar graph with optional error ranges, labels, and configurations.

    Parameters
    ----------
    x_data :
        The x-coordinates of the data points.
    y_data :
        The y-coordinates of the data points.
    x_err :
        Error margins for x-coordinates. Can be:

        - Scalar: symmetric error for all points
        - 1D array (N,): symmetric errors, one per point
        - 2D array (2, N): asymmetric [lower_errors, upper_errors]
    y_err :
        Error margins for y-coordinates. Accepts the same forms as ``x_err``.
    x_label :
        The label for the x-axis.
    y_label :
        The label for the y-axis.
    plot_title :
        The title of the plot.
    data_label :
        The label for the data points, shown in the plot legend. If ``None``
        (or empty), no legend is drawn.
    errorbar_config :
        Custom configurations for the error bars. If ``None``, default
        configurations are used.
    axis :
        A matplotlib Axes object on which the plot will be rendered. If ``None``,
        a new subplot is created using ``figure_kwargs``.
    figure_kwargs :
        Keyword arguments for creating the figure and axis when ``axis`` is not
        provided. Ignored (with a ``UserWarning``) if ``axis`` is provided.

    Returns
    -------
    Axes
        The Matplotlib Axes on which the plot was drawn.

    Raises
    ------
    ShapeError
        If a 2-D ``x_err`` or ``y_err`` does not have shape ``(2, N)``.
    """

    def _checked_error(err, name):
        # Convert to an array and reject 2-D input that is not (2, N) asymmetric errors.
        if err is None:
            return None
        err_arr: NDArray = np.asarray(err)
        if err_arr.ndim == 2 and err_arr.shape[0] != 2:
            raise ShapeError(f"Asymmetric {name} must have shape (2, N), got {err_arr.shape}")
        return err_arr

    x, y = np.asarray(x_data), np.asarray(y_data)
    x_err = _checked_error(x_err, "x_err")
    y_err = _checked_error(y_err, "y_err")

    if axis is not None:
        if figure_kwargs:
            warn("`figure_kwargs` is ignored when `axis` is provided.", UserWarning, stacklevel=2)
        ax = axis
    else:
        _, ax = plt.subplots(**(figure_kwargs or {}))

    ebc = errorbar_config.get_dict() if errorbar_config else ErrorPlotConfig().get_dict()
    ax.errorbar(x, y, xerr=x_err, yerr=y_err, label=data_label, **ebc)

    ax.set_xlabel(x_label)
    ax.set_ylabel(y_label)
    ax.set_title(plot_title)
    # The documented contract is "no legend without a label". This matches plot_hist and
    # avoids matplotlib's "No artists with labels" warning.
    if data_label:
        ax.legend()
    return ax
# =============================================================================
# File I/O Functions
# =============================================================================
def plot_two_column_file(
    file_name: str,
    delimiter: str = ",",
    skip_header: bool = False,
    x_label: str = "X",
    y_label: str = "Y",
    data_label: str = "XY Data",
    plot_title: str = "XY Plot",
    is_scatter: bool = False,
    plot_config: LinePlotConfig | ScatterPlotConfig | None = None,
    figure_kwargs: dict | None = None,
    axis: Axes | None = None,
) -> Axes:
    """Read a two-column file (x, y) and plot the data.

    Parameters
    ----------
    file_name :
        The path to the file to be plotted. The file should contain two columns
        (x and y data).
    delimiter :
        The delimiter used in the file (default is ',').
    skip_header :
        If True, skips the first row in the given data file, otherwise does
        nothing. Default is False.
    x_label :
        The label for the x-axis.
    y_label :
        The label for the y-axis.
    data_label :
        Data label for the plot to put in the legend. Defaults to 'XY Data'.
    plot_title :
        The title for the plot.
    is_scatter :
        If True, creates a scatter plot. Otherwise, creates a line plot.
        Default is False.
    plot_config :
        Configuration object for line or scatter styling. If None, a default
        ``LinePlotConfig`` is used.
    figure_kwargs :
        Keyword arguments for creating the figure and axis when `axis` is not
        provided. Ignored if `axis` is provided.
    axis :
        The axis object to draw the plots on. If not passed, a new axis object
        will be created internally.

    Returns
    -------
    Axes
        The axes object of the plot.

    Raises
    ------
    ColumnCountError
        If the file does not contain exactly two columns (this includes files
        that ``genfromtxt`` reads as a 1-D array, such as single-column or
        single-row files).
    """
    data = np.genfromtxt(file_name, delimiter=delimiter, skip_header=int(skip_header))

    # genfromtxt squeezes a single-column or single-row file to a 1-D array. In that case
    # ``data.shape[1]`` would raise IndexError instead of the documented error.
    if data.ndim != 2:
        raise ColumnCountError(
            "The file must contain exactly two columns of data with at least two rows; "
            f"got an array of shape {data.shape}."
        )
    if data.shape[1] != 2:
        raise ColumnCountError("The file must contain exactly two columns of data.")

    x_data, y_data = data.T
    return plot_xy(
        x_data=x_data,
        y_data=y_data,
        x_label=x_label,
        y_label=y_label,
        data_label=data_label,
        plot_title=plot_title,
        is_scatter=is_scatter,
        plot_config=plot_config,
        figure_kwargs=figure_kwargs,
        axis=axis,
    )
# =============================================================================
# Simple Plotting Functions
# =============================================================================
def plot_xy(
    x_data: ArrayLike,
    y_data: ArrayLike,
    x_label: str = "X",
    y_label: str = "Y",
    data_label: str = "XY Data",
    plot_title: str = "XY Plot",
    is_scatter: bool = False,
    plot_config: LinePlotConfig | ScatterPlotConfig | None = None,
    figure_kwargs: dict | None = None,
    axis: Axes | None = None,
) -> Axes:
    """Plot ``y_data`` against ``x_data`` on a single pair of axes.

    This delegates to :func:`plot_with_dual_axes` with no secondary data, so
    exactly one ``Axes`` is returned.

    Parameters
    ----------
    x_data :
        The data for the x-axis.
    y_data :
        The data for the y-axis.
    x_label :
        The label for the x-axis.
    y_label :
        The label for the y-axis.
    data_label :
        Data label for the plot to put in the legend.
    plot_title :
        The title for the plot.
    is_scatter :
        If True, creates a scatter plot; otherwise a line plot. Default is False.
    plot_config :
        Configuration object for line or scatter styling. If None, a default
        ``LinePlotConfig`` is used.
    figure_kwargs :
        Keyword arguments for creating the figure and axis when `axis` is not
        provided. Ignored if `axis` is provided.
    axis :
        The axis object to draw the plots on. If not passed, a new axis object
        will be created internally.

    Returns
    -------
    Axes
        The axes object of the plot.
    """
    # No secondary axis is drawn, so its label slot is left empty.
    labels = (x_label, y_label, "")

    single_axis: Axes = plot_with_dual_axes(
        x_data,
        y_data,
        x1y1_label=data_label,
        use_twin_x=False,
        axis_labels=labels,
        plot_title=plot_title,
        is_scatter=is_scatter,
        plot_config=plot_config,
        figure_kwargs=figure_kwargs,
        axis=axis,
    )
    return single_axis
# =============================================================================
# Dual-Axis Plotting Functions
# =============================================================================
def plot_xyy(
    x_data: ArrayLike,
    y1_data: ArrayLike,
    y2_data: ArrayLike,
    x_label: str = "X",
    y1_label: str = r"Y$_1$",
    y2_label: str = r"Y$_2$",
    data_labels: list[str] | tuple[str, ...] | None = None,  # noqa
    plot_title: str = "XYY Plot",
    is_scatter: bool = False,
    plot_config: LinePlotConfig | ScatterPlotConfig | None = None,
    figure_kwargs: dict | None = None,
    axis: Axes | None = None,
) -> tuple[Axes, Axes]:
    """Plot two y-series (`y1_data`, `y2_data`) against one shared `x_data`.

    The second series is drawn on a twin y-axis that shares the x-axis.

    Parameters
    ----------
    x_data :
        The x-axis data for both plots.
    y1_data :
        The first set of y-axis data to be plotted against `x_data`.
    y2_data :
        The second set of y-axis data to be plotted against `x_data`.
    x_label :
        The label for the x-axis.
    y1_label :
        The label for the first (left) y-axis.
    y2_label :
        The label for the second (right) y-axis.
    data_labels :
        Legend labels for the two datasets; only the first two entries are used.
        Default is ``(r"X vs. Y$_1$", r"X vs. Y$_2$")``. Passing a mutable
        ``list`` is deprecated; use a ``tuple`` instead.
    plot_title :
        The title for the plot.
    is_scatter :
        Whether to create a scatter plot (`True`) or a line plot (`False`).
        Default is `False`.
    plot_config :
        Configuration object for line or scatter styling. If None, a default
        ``LinePlotConfig`` is used.
    figure_kwargs :
        Keyword arguments for creating the figure and axis when `axis` is not
        provided. Ignored if `axis` is provided.
    axis :
        A Matplotlib axis to plot on. If `None`, a new axis is created.

    Returns
    -------
    tuple[Axes, Axes]
        ``(primary_axis, secondary_axis)`` for the dual y-axis plot.
    """
    if data_labels is None:
        labels = (r"X vs. Y$_1$", r"X vs. Y$_2$")
    else:
        labels = tuple(data_labels)
    first_label, second_label = labels[0], labels[1]

    axes_pair: tuple[Axes, Axes] = plot_with_dual_axes(
        x1_data=x_data,
        y1_data=y1_data,
        y2_data=y2_data,
        x1y1_label=first_label,
        x1y2_label=second_label,
        use_twin_x=True,
        axis_labels=(x_label, y1_label, y2_label),
        plot_title=plot_title,
        is_scatter=is_scatter,
        plot_config=plot_config,
        figure_kwargs=figure_kwargs,
        axis=axis,
    )
    return axes_pair
def plot_xxy(
    x1_data: ArrayLike,
    x2_data: ArrayLike,
    y_data: ArrayLike,
    y_label: str = "Y",
    x1_label: str = r"X$_1$",
    x2_label: str = r"X$_2$",
    data_labels: list[str] | tuple[str, ...] | None = None,  # noqa
    plot_title: str = "",
    is_scatter: bool = False,
    plot_config: LinePlotConfig | ScatterPlotConfig | None = None,
    figure_kwargs: dict | None = None,
    axis: Axes | None = None,
) -> tuple[Axes, Axes]:
    """Plot one shared `y_data` against two x-series (`x1_data`, `x2_data`).

    The second series is drawn on a twin x-axis that shares the y-axis.

    Parameters
    ----------
    x1_data :
        The first set of x-axis data to be plotted against `y_data`.
    x2_data :
        The second set of x-axis data to be plotted against `y_data`.
    y_data :
        The y-axis data for both plots.
    y_label :
        The label for the y-axis.
    x1_label :
        The label for the first (bottom) x-axis.
    x2_label :
        The label for the second (top) x-axis.
    data_labels :
        Legend labels for the two datasets; only the first two entries are used.
        Default is ``(r"Y vs. X$_1$", r"Y vs. X$_2$")``. Passing a mutable
        ``list`` is deprecated; use a ``tuple`` instead.
    plot_title :
        The title for the plot. An empty string leaves the title unset.
    is_scatter :
        Whether to create a scatter plot (`True`) or a line plot (`False`).
        Default is `False`.
    plot_config :
        Configuration object for line or scatter styling. If None, a default
        ``LinePlotConfig`` is used.
    figure_kwargs :
        Keyword arguments for creating the figure and axis when `axis` is not
        provided. Ignored if `axis` is provided.
    axis :
        A Matplotlib axis to plot on. If `None`, a new axis is created.

    Returns
    -------
    tuple[Axes, Axes]
        ``(primary_axis, secondary_axis)`` for the dual x-axis plot.
    """
    if data_labels is None:
        labels = (r"Y vs. X$_1$", r"Y vs. X$_2$")
    else:
        labels = tuple(data_labels)
    first_label, second_label = labels[0], labels[1]

    # For a twin-y layout, the third axis label is the secondary *x* label.
    axes_pair: tuple[Axes, Axes] = plot_with_dual_axes(
        x1_data=x1_data,
        y1_data=y_data,
        x2_data=x2_data,
        x1y1_label=first_label,
        x2y1_label=second_label,
        use_twin_x=False,
        axis_labels=(x1_label, y_label, x2_label),
        plot_title=plot_title,
        is_scatter=is_scatter,
        plot_config=plot_config,
        figure_kwargs=figure_kwargs,
        axis=axis,
    )
    return axes_pair
def plot_with_dual_axes(
    x1_data: ArrayLike,
    y1_data: ArrayLike,
    x2_data: ArrayLike | None = None,
    y2_data: ArrayLike | None = None,
    x1y1_label: str = r"X$_1$ vs. Y$_1$",
    x1y2_label: str = r"X$_1$ vs. Y$_2$",
    x2y1_label: str = r"X$_2$ vs. Y$_1$",
    use_twin_x: bool = False,
    axis_labels: list[str] | tuple[str, ...] | None = None,  # noqa
    plot_title: str = "DualAxesPlot",
    is_scatter: bool = False,
    plot_config: LinePlotConfig | ScatterPlotConfig | None = None,
    figure_kwargs: dict | None = None,
    axis: Axes | None = None,
) -> AxesReturn:
    """Plot the data with options for dual axes (x or y) or single axis.

    Parameters
    ----------
    x1_data :
        Data for the primary x-axis.
    y1_data :
        Data for the primary y-axis.
    x2_data :
        Data for the secondary x-axis (used for dual x-axis plots).
    y2_data :
        Data for the secondary y-axis (used for dual y-axis plots).
    x1y1_label :
        Label for the plot of X1 vs. Y1.
    x1y2_label :
        Label for the plot of X1 vs. Y2 (when using dual Y-axes).
    x2y1_label :
        Label for the plot of X2 vs. Y1 (when using dual X-axes).
    use_twin_x :
        If True, creates a dual y-axis plot (``twinx``). If False, a dual x-axis
        plot (``twiny``) is created only when ``x2_data`` is given; otherwise a
        single-axis plot is produced. Default is False.
    axis_labels :
        Three axis labels. Index 0 labels the primary x-axis and index 1 the
        primary y-axis. Index 2 labels the secondary axis: the twin y-axis when
        ``use_twin_x`` is True, or the twin x-axis when ``x2_data`` is used.
        Defaults to ``["X", r"Y$_1$", r"Y$_2$"]``. Passing a mutable ``list``
        is deprecated; use a ``tuple`` instead.
    plot_title :
        Title of the plot. An empty string leaves the title unset.
    is_scatter :
        If True, creates scatter plot; otherwise, line plot. Default is False.
    plot_config :
        Configuration object for line or scatter styling. If None, a default
        ``LinePlotConfig`` is used. List-valued entries are per series: element 0
        styles the primary series and element 1 (or element 0 if absent) styles
        the secondary series.
    figure_kwargs :
        Keyword arguments for creating the figure and axis when `axis` is not
        provided. Ignored if `axis` is provided.
    axis :
        The axis object to draw the plots on. If not passed, a new axis object
        will be created internally.

    Returns
    -------
    tuple[Axes, Axes] or Axes
        A tuple of ``(primary_axis, secondary_axis)`` when dual axes are used,
        otherwise a single ``Axes``.
    """
    _axis_labels: list[str] = list(axis_labels) if axis_labels is not None else ["X", r"Y$_1$", r"Y$_2$"]
    x1_data, y1_data, x2_data, y2_data = dual_axes_data_validation(
        x1_data=x1_data,
        x2_data=x2_data,
        y1_data=y1_data,
        y2_data=y2_data,
        use_twin_x=use_twin_x,
        axis_labels=_axis_labels,
    )

    if axis is not None:
        ax1 = axis
    else:
        _, ax1 = plt.subplots(**(figure_kwargs or {}))

    plot_dict = plot_config.get_dict() if plot_config is not None else LinePlotConfig().get_dict()
    # Per-series style: list values hold [primary, secondary]. The secondary
    # falls back to the primary entry when only one value is given.
    dict1 = {key: (value[0] if isinstance(value, list) else value) for key, value in plot_dict.items()}
    dict2 = {
        key: (
            value[1] if isinstance(value, list) and len(value) > 1 else (value[0] if isinstance(value, list) else value)
        )
        for key, value in plot_dict.items()
    }

    plot_or_scatter(axes=ax1, scatter=is_scatter)(x1_data, y1_data, label=x1y1_label, **dict1)
    ax1.set_xlabel(_axis_labels[0])
    ax1.set_ylabel(_axis_labels[1])
    if plot_title:
        ax1.set_title(plot_title)

    ax2 = None
    if use_twin_x:
        ax2 = ax1.twinx()
        if y2_data is not None:
            plot_or_scatter(axes=ax2, scatter=is_scatter)(x1_data, y2_data, label=x1y2_label, **dict2)
        ax2.set_ylabel(_axis_labels[2])
    elif x2_data is not None:
        ax2 = ax1.twiny()
        plot_or_scatter(axes=ax2, scatter=is_scatter)(x2_data, y1_data, label=x2y1_label, **dict2)
        ax2.set_xlabel(_axis_labels[2])

    if x1y1_label or x1y2_label or x2y1_label:
        # Gather handles from both axes so one legend covers both series.
        handles, labels = ax1.get_legend_handles_labels()
        if ax2 is not None:
            handles2, labels2 = ax2.get_legend_handles_labels()
            handles += handles2
            labels += labels2
        # The default label arguments are always truthy, even for series that were
        # never drawn. Only build a legend when a drawn artist actually has a label,
        # instead of an empty legend box.
        if handles:
            ax1.legend(handles, labels, loc="best")

    return (ax1, ax2) if ax2 is not None else ax1
# =============================================================================
# Multi-Panel Plotting Functions
# =============================================================================
def two_subplots(
    x_data: ArrayLike | list[ArrayLike],
    y_data: ArrayLike | list[ArrayLike],
    x_labels: list[str] | tuple[str, ...] | None = None,  # noqa
    y_labels: list[str] | tuple[str, ...] | None = None,  # noqa
    data_labels: list[str] | tuple[str, ...] | None = None,  # noqa
    plot_title: str = "TwoSubPlots",
    subplot_titles: list[str] | tuple[str, ...] | None = None,  # noqa
    orientation: str = "h",
    is_scatter: bool = False,
    plot_config: LinePlotConfig | ScatterPlotConfig | None = None,
    figure_kwargs: dict | None = None,
) -> NDArray:
    """Create a pair of subplots side by side or stacked, then delegate to :func:`n_plotter`.

    Parameters
    ----------
    x_data :
        List containing x-axis data arrays for each subplot.
    y_data :
        List containing y-axis data arrays for each subplot.
    x_labels :
        Labels for the x-axes of each subplot. Defaults to
        ``[r"X$_1$", r"X$_2$"]``. Passing a mutable ``list`` is deprecated; use
        a ``tuple`` instead.
    y_labels :
        Labels for the y-axes of each subplot. Defaults to
        ``[r"Y$_1$", r"Y$_2$"]``. Passing a mutable ``list`` is deprecated; use
        a ``tuple`` instead.
    data_labels :
        Labels for the data series of each subplot. Defaults to
        ``[r"X$_1$ vs. Y$_1$", r"X$_2$ vs. Y$_2$"]``. Passing a mutable ``list``
        is deprecated; use a ``tuple`` instead.
    plot_title :
        Title of the whole figure.
    subplot_titles :
        Titles for the individual subplots. Defaults to two empty strings.
        Passing a mutable ``list`` is deprecated; use a ``tuple`` instead.
    orientation :
        ``'h'``/``'horizontal'`` for a 1x2 grid, or ``'v'``/``'vertical'`` for a
        2x1 grid.
    is_scatter :
        If `True`, plots data as scatter plots; otherwise, as line plots.
    plot_config :
        Configuration object for line or scatter styling. If None, a default
        ``LinePlotConfig`` is used.
    figure_kwargs :
        Keyword arguments for creating the figure and axes, passed to
        ``plt.subplots``.

    Returns
    -------
    NDArray
        A shaped ``(n_rows, n_cols)`` array of Matplotlib ``Axes`` objects.

    Raises
    ------
    OrientationError
        If ``orientation`` is none of ``'h'``, ``'horizontal'``, ``'v'``,
        ``'vertical'``.
    """

    def _as_list(values, fallback):
        # Materialise user sequences as lists; fall back to the defaults when omitted.
        return fallback if values is None else list(values)

    x_label_list = _as_list(x_labels, [r"X$_1$", r"X$_2$"])
    y_label_list = _as_list(y_labels, [r"Y$_1$", r"Y$_2$"])
    series_labels = _as_list(data_labels, [r"X$_1$ vs. Y$_1$", r"X$_2$ vs. Y$_2$"])
    panel_titles = _as_list(subplot_titles, ["", ""])

    if orientation in ("h", "horizontal"):
        grid_rows, grid_cols = 1, 2
    elif orientation in ("v", "vertical"):
        grid_rows, grid_cols = 2, 1
    else:
        raise OrientationError("The orientation must be either 'h/horizontal' or 'v/vertical'.")

    return n_plotter(
        x_data=x_data,
        y_data=y_data,
        n_rows=grid_rows,
        n_cols=grid_cols,
        x_labels=x_label_list,
        y_labels=y_label_list,
        data_labels=series_labels,
        plot_title=plot_title,
        subplot_titles=panel_titles,
        is_scatter=is_scatter,
        plot_config=plot_config,
        figure_kwargs=figure_kwargs,
    )
def _label_sanitizer( n_rows: int, n_cols: int, x_labels: list[str] | None, y_labels: list[str] | None, data_labels: list[str] | None, subplot_titles: list[str] | None, plot_title: str | None, ) -> tuple[list[str], list[str], list[str], list[str], str]: n = n_rows * n_cols def _pad(labels: list[str] | None, name: str) -> list[str]: if labels is None: return [""] * n if len(labels) < n: warn( f"`{name}` has {len(labels)} element(s) but a {n_rows}×{n_cols} grid " f"({n} subplots) was requested. Padding with empty strings for the remaining {n - len(labels)}.", UserWarning, stacklevel=3, ) return labels + [""] * (n - len(labels)) if len(labels) > n: warn( f"`{name}` has {len(labels)} element(s) but a {n_rows}×{n_cols} grid " f"({n} subplots) was requested. Trimming the last {len(labels) - n} element(s).", UserWarning, stacklevel=3, ) return labels[:n] return labels return ( _pad(labels=x_labels, name="x_labels"), _pad(labels=y_labels, name="y_labels"), _pad(labels=data_labels, name="data_labels"), _pad(labels=subplot_titles, name="subplot_titles"), plot_title or "", )
def n_plotter(
    x_data: ArrayLike | list[ArrayLike],
    y_data: ArrayLike | list[ArrayLike],
    n_rows: int,
    n_cols: int,
    x_labels: list[str] | None = None,
    y_labels: list[str] | None = None,
    data_labels: list[str] | None = None,
    plot_title: str | None = None,
    subplot_titles: list[str] | None = None,
    is_scatter: bool = False,
    plot_config: LinePlotConfig | ScatterPlotConfig | None = None,
    figure_kwargs: dict | None = None,
) -> NDArray:
    """
    Plot multiple subplots in a grid with optional customization for each subplot.

    Parameters
    ----------
    x_data :
        Sequence of x-axis data arrays, one per subplot (row-major order).
    y_data :
        Sequence of y-axis data arrays, one per subplot (row-major order).
    n_rows :
        Number of rows in the subplot grid.
    n_cols :
        Number of columns in the subplot grid.
    x_labels :
        Labels for the x-axes of each subplot. Padded or trimmed (with a warning)
        to ``n_rows * n_cols`` entries.
    y_labels :
        Labels for the y-axes of each subplot. Padded or trimmed like ``x_labels``.
    data_labels :
        Labels for the data series in each subplot. A subplot whose label is
        empty gets no legend.
    plot_title :
        Title of the whole figure.
    subplot_titles :
        Titles for the individual subplots, if required.
    is_scatter :
        If `True`, plots data as scatter plots; otherwise, plots as line plots.
    plot_config :
        Configuration object for line or scatter styling. If None, a default
        ``LinePlotConfig`` is used. List/tuple values are cycled across subplots.
    figure_kwargs :
        Keyword arguments passed to ``plt.subplots``. ``nrows``, ``ncols`` and
        ``squeeze`` are ignored, because the grid comes from ``n_rows``/``n_cols``
        and the result is always 2-D. With ``sharex`` in ``{True, "all", "col"}``,
        only the bottom row gets x-axis labels. With ``sharey`` in
        ``{True, "all", "row"}``, only the first column gets y-axis labels. This
        mirrors matplotlib's own tick-label hiding.

    Returns
    -------
    NDArray
        A shaped ``(n_rows, n_cols)`` array of Matplotlib ``Axes`` objects.

    Raises
    ------
    ConfigurationError
        If ``plot_config`` is a ``ScatterPlotConfig`` while ``is_scatter`` is False.
    """
    sp_dict = dict(figure_kwargs) if figure_kwargs else {}
    # This function owns the grid shape and forces a 2-D result. A user-supplied
    # ``squeeze`` would otherwise clash with ``squeeze=False`` below (TypeError).
    for reserved in ("nrows", "ncols", "squeeze"):
        sp_dict.pop(reserved, None)

    if isinstance(plot_config, ScatterPlotConfig) and not is_scatter:
        raise ConfigurationError(
            "`plot_config` is a `ScatterPlotConfig` but `is_scatter=False`. "
            "Set `is_scatter=True` or pass a `LinePlotConfig` instead."
        )

    plot_items = plot_config.get_dict() if plot_config else LinePlotConfig().get_dict()  # type: ignore
    fig, axs = plt.subplots(n_rows, n_cols, **sp_dict, squeeze=False)
    flat_axs = axs.flatten()
    n_panels = n_rows * n_cols

    # One style dict per subplot: sequence-valued entries are cycled by subplot index.
    main_dict = [
        {key: (value[c % len(value)] if isinstance(value, (list, tuple)) else value) for key, value in plot_items.items()}
        for c in range(n_panels)
    ]

    _x_labels, _y_labels, _data_labels, _subplot_titles, _plot_title = _label_sanitizer(
        n_rows=n_rows,
        n_cols=n_cols,
        x_labels=x_labels,
        y_labels=y_labels,
        data_labels=data_labels,
        subplot_titles=subplot_titles,
        plot_title=plot_title,
    )

    def _share_mode(value):
        # matplotlib maps bool sharing flags to "all" (True) / "none" (False).
        if isinstance(value, bool):
            return "all" if value else "none"
        return value

    # Match matplotlib's own tick-label hiding. Strings such as "none" are truthy,
    # so plain truthiness would wrongly treat them as shared.
    hide_inner_x = _share_mode(sp_dict.get("sharex", False)) in ("all", "col")
    hide_inner_y = _share_mode(sp_dict.get("sharey", False)) in ("all", "row")
    first_bottom_index = (n_rows - 1) * n_cols  # flat index of the bottom row's first subplot

    for index, ax in enumerate(flat_axs):
        label = _data_labels[index]
        plot_or_scatter(axes=ax, scatter=is_scatter)(x_data[index], y_data[index], label=label, **main_dict[index])
        if not hide_inner_x or index >= first_bottom_index:
            ax.set_xlabel(_x_labels[index])
        if not hide_inner_y or index % n_cols == 0:
            ax.set_ylabel(_y_labels[index])
        if label:
            ax.legend(loc="best")
        ax.set_title(_subplot_titles[index])

    fig.suptitle(_plot_title)
    return axs
def plot_density(
    x_data: ArrayLike,
    x_label: str = "X",
    y_label: str = "Density",
    plot_title: str = "Density Plot",
    data_label: str | None = None,
    hist_config: HistogramConfig | dict | None = None,
    axis: Axes | None = None,
    figure_kwargs: dict | None = None,
) -> Axes:
    """
    Plot a density histogram based on the given data and configuration.

    Parameters
    ----------
    x_data :
        The data array used for generating the density plot.
    x_label :
        The label for the x-axis. Default is "X".
    y_label :
        The label for the y-axis. Default is "Density".
    plot_title :
        The title of the density plot. Default is "Density Plot".
    data_label :
        The optional label for the dataset being visualized. Default is None.
    hist_config :
        The histogram configuration, as a `HistogramConfig`, a dictionary, or
        None. ``density=True`` is always enforced on a *copy*, so the caller's
        object is never modified. A dict without a truthy ``density`` triggers
        a ``UserWarning``. Default is None.
    axis :
        The Matplotlib Axes object on which to draw the plot. If None, a new set
        of axes is created. Default is None.
    figure_kwargs :
        Optional keyword arguments passed when creating a new Matplotlib figure.
        These arguments are ignored if an existing axis is provided. Default is
        None.

    Returns
    -------
    Axes
        The Matplotlib Axes on which the density plot was drawn.
    """
    from copy import copy

    if isinstance(hist_config, dict):
        if not hist_config.get("density"):
            warn("Setting `density=True` in `hist_config` for density plot.", UserWarning, stacklevel=2)
        hist_config = {**hist_config, "density": True}
    elif isinstance(hist_config, HistogramConfig):
        # Shallow-copy before forcing density, as the dict branch does. Assigning on
        # the caller's object would leak density=True into later uses of that config.
        hist_config = copy(hist_config)
        hist_config.density = True
    else:
        hist_config = HistogramConfig(density=True)

    return plot_hist(
        x_data=x_data,
        x_label=x_label,
        y_label=y_label,
        plot_title=plot_title,
        data_label=data_label,
        hist_config=hist_config,
        axis=axis,
        figure_kwargs=figure_kwargs,
    )
def plot_hist(
    x_data: ArrayLike,
    x_label: str = "X",
    y_label: str = "Counts",
    plot_title: str = "Histogram",
    data_label: str | None = None,
    hist_config: HistogramConfig | dict | None = None,
    axis: Axes | None = None,
    figure_kwargs: dict | None = None,
) -> Axes:
    """
    Plot a histogram of the data.

    Parameters
    ----------
    x_data :
        Array or sequence of data points to be histogrammed.
    x_label :
        Label for the x-axis.
    y_label :
        Label for the y-axis. If the histogram is normalised (``density``) and
        this is left at its default ``"Counts"``, ``"Density"`` is used instead.
        Any other value is used as given.
    plot_title :
        Title for the plot.
    data_label :
        Label for the data series, used in the legend. If omitted, a ``label``
        from ``hist_config`` is used.
    hist_config :
        Configuration object (or dict) for histogram styling. If `None`, default
        configurations are used. ``bins`` defaults to 32 when unset/falsy.
    axis :
        An existing matplotlib axis object on which to plot. If `None`, a new
        figure and axis are created.
    figure_kwargs :
        Keyword arguments for creating the figure and axis when `axis` is not
        provided. Ignored if `axis` is provided.

    Returns
    -------
    Axes
        The Matplotlib Axes on which the histogram was drawn.

    Raises
    ------
    ConfigurationError
        If both ``data_label`` and ``hist_config['label']`` are provided.
    """
    x = np.asarray(x_data)

    if isinstance(hist_config, dict):
        h_config = HistogramConfig.populate(hist_config).get_dict()
    elif isinstance(hist_config, HistogramConfig):
        h_config = hist_config.get_dict()
    else:
        h_config = HistogramConfig().get_dict()
    # Work on a private copy, since it is edited below (label extraction, default bins).
    h_config = dict(h_config)

    # Validate before creating a figure so a failed call does not leave one behind.
    if data_label and "label" in h_config:
        raise ConfigurationError("Both `data_label` and `hist_config['label']` cannot be provided.")
    # Move any config label out of the kwargs. Passing it alongside ``label=`` would
    # raise "got multiple values for keyword argument 'label'".
    label = data_label or h_config.pop("label", None)

    if axis is not None:
        ax = axis
    else:
        _, ax = plt.subplots(**(figure_kwargs or {}))

    if not h_config.get("bins"):
        h_config["bins"] = 32

    ax.hist(x, label=label, **h_config)
    ax.set_xlabel(x_label)
    # Use "Density" for normalised histograms only when the caller kept the default
    # label. An explicit y_label (e.g. from plot_density) is respected.
    is_density = bool(h_config.get("density"))
    ax.set_ylabel("Density" if is_density and y_label == "Counts" else y_label)
    ax.set_title(plot_title)
    if label:
        ax.legend()
    return ax