Source code for plotez.backend.utilities

"""
PlotEZ Backend Utilities.

Utility classes and functions for plot parameter management and data validation.
"""

from __future__ import annotations

__all__ = [
    "LinePlotConfig",
    "ScatterPlotConfig",
    "ErrorPlotConfig",
    "ErrorBandConfig",
    "plot_or_scatter",
    "split_dictionary",
    "dual_axes_data_validation",
    "dual_axes_label_management",
]

from dataclasses import dataclass, field
from typing import Any, Literal, Sequence
from warnings import warn

from numpy.typing import ArrayLike

from plotez.backend.CONSTANTS import ERROR_ATTRS, ERROR_BAND_ATTRS, LINE_ATTRS, SCATTER_ATTRS
from plotez.backend.error_handling import (
    AxisLabelError,
    EmptyDataError,
    LabelConflictWarning,
    TwinXDataError,
    TwinYDataError,
)

# Return type of ``dual_axes_label_management``:
# ``(x1y1_label, x1y2_label, x2y1_label, plot_title, axis_labels)``.
label_management = tuple[str, str, str, str, list[str]]


def _populate(_class, dictionary: dict[str, Any], mapping):
    """
    Create a config dataclass instance from a dictionary, applying key aliases.

    Parameters
    ----------
    _class :
        The dataclass type to instantiate.
    dictionary :
        Raw parameter dictionary, possibly using shorthand keys.
    mapping :
        Alias-to-canonical-name mapping for shorthand keys.

    Returns
    -------
    instance
        An instance of ``_class`` populated from the mapped dictionary.
    """
    mapped = {mapping.get(k, k): v for k, v in dictionary.items()}

    known_fields = _class.__dict__["__annotations__"].keys() - {"_extra"}

    known = {k: v for k, v in mapped.items() if k in known_fields}
    extra = {k: v for k, v in mapped.items() if k not in known_fields}

    return _class(**known, _extra=extra)


@dataclass
class LinePlotConfig:
    """
    Configuration class for line plots.

    Each styling field accepts either a single value or a sequence of values
    and defaults to ``None``. Fields left as ``None`` are omitted from
    :meth:`get_dict`.
    """

    color: str | Sequence[str] | None = None
    linewidth: float | Sequence[float] | None = None
    linestyle: str | Sequence[str] | None = None
    alpha: float | Sequence[float] | None = None
    marker: str | Sequence[str] | None = None
    markersize: float | Sequence[float] | None = None
    markerfacecolor: str | Sequence[str] | None = None
    markeredgecolor: str | Sequence[str] | None = None
    markeredgewidth: float | Sequence[float] | None = None

    # Catch-all for parameters without a dedicated field - pass as dict to this field directly
    _extra: dict[str, Any] = field(default_factory=dict, repr=False)

    @classmethod
    def populate(cls, dictionary: dict[str, Any]) -> "LinePlotConfig":
        """Build a LinePlotConfig from a dictionary, resolving shorthand keys through ``LINE_ATTRS``."""
        return _populate(_class=cls, dictionary=dictionary, mapping=LINE_ATTRS)

    def get_dict(self) -> dict[str, Any]:
        """
        Return all set parameters as a plain dictionary (intended as matplotlib keyword arguments).

        Public attributes whose value is not ``None`` come first; the entries of
        ``_extra`` are merged afterwards and win on key collisions.
        """
        params = {}
        for name, value in vars(self).items():
            # Skip private attributes (e.g. ``_extra``) and unset fields.
            if name.startswith("_") or value is None:
                continue
            params[name] = value
        params.update(self._extra)
        return params

    def __repr__(self):
        """Return ``ClassName(key=value, ...)`` listing explicit and extra parameters sorted by name."""
        params = self.get_dict()
        rendered = [f"{name}={params[name]!r}" for name in sorted(params)]
        return f"{type(self).__name__}({', '.join(rendered)})"
@dataclass
class ErrorBandConfig:
    """
    Configuration class for error bands (shaded fill regions).

    All fields default to ``None`` except ``alpha``, which defaults to ``0.25``.
    Fields left as ``None`` are omitted from :meth:`get_dict`.
    """

    color: str | None = None
    alpha: float = 0.25
    linewidth: float | None = None
    edgecolor: str | None = None
    linestyle: str | None = None
    hatch: str | Literal["/", "\\", "|", "-", "+", "x", "o", "O", ".", "*"] | None = None
    interpolate: bool | None = None
    step: str | Literal["pre", "post", "mid"] | None = None

    # Catch-all for parameters without a dedicated field.
    _extra: dict[str, Any] = field(default_factory=dict, repr=False)

    @classmethod
    def populate(cls, dictionary: dict[str, Any]) -> "ErrorBandConfig":
        """Build an ErrorBandConfig from a dictionary, resolving shorthand keys through ``ERROR_BAND_ATTRS``."""
        return _populate(_class=cls, dictionary=dictionary, mapping=ERROR_BAND_ATTRS)

    def get_dict(self) -> dict[str, Any]:
        """
        Return all set parameters as a plain dictionary (intended as matplotlib keyword arguments).

        Public attributes whose value is not ``None`` come first; the entries of
        ``_extra`` are merged afterwards and win on key collisions.
        """
        params = {}
        for name, value in vars(self).items():
            # Skip private attributes (e.g. ``_extra``) and unset fields.
            if name.startswith("_") or value is None:
                continue
            params[name] = value
        params.update(self._extra)
        return params

    def __repr__(self):
        """Return ``ClassName(key=value, ...)`` listing explicit and extra parameters sorted by name."""
        params = self.get_dict()
        rendered = [f"{name}={params[name]!r}" for name in sorted(params)]
        return f"{type(self).__name__}({', '.join(rendered)})"
@dataclass
class ErrorPlotConfig:
    """
    Configuration class for error bar plots.

    Every field defaults to ``None``; fields left as ``None`` are omitted
    from :meth:`get_dict`.
    """

    # Core line appearance
    color: str | None = None
    linewidth: float | None = None
    linestyle: str | None = None
    alpha: float | None = None

    # Error bar appearance
    ecolor: str | None = None
    elinewidth: float | None = None

    # Markers for the individual data points
    marker: str | None = None
    markersize: float | None = None
    markerfacecolor: str | None = None
    markeredgecolor: str | None = None

    # Visual refinement of the error bars
    capsize: float | None = None
    capthick: float | None = None
    errorevery: int | tuple | None = None

    # Catch-all for parameters without a dedicated field - pass as dict to this field directly
    _extra: dict[str, Any] = field(default_factory=dict, repr=False)

    @classmethod
    def populate(cls, dictionary: dict[str, Any]) -> "ErrorPlotConfig":
        """Build an ErrorPlotConfig from a dictionary, resolving shorthand keys through ``ERROR_ATTRS``."""
        return _populate(_class=cls, dictionary=dictionary, mapping=ERROR_ATTRS)

    def get_dict(self) -> dict[str, Any]:
        """
        Return all set parameters as a plain dictionary (intended as matplotlib keyword arguments).

        Public attributes whose value is not ``None`` come first; the entries of
        ``_extra`` are merged afterwards and win on key collisions.
        """
        params = {}
        for name, value in vars(self).items():
            # Skip private attributes (e.g. ``_extra``) and unset fields.
            if name.startswith("_") or value is None:
                continue
            params[name] = value
        params.update(self._extra)
        return params

    def __repr__(self):
        """Return ``ClassName(key=value, ...)`` listing explicit and extra parameters sorted by name."""
        params = self.get_dict()
        rendered = [f"{name}={params[name]!r}" for name in sorted(params)]
        return f"{type(self).__name__}({', '.join(rendered)})"
@dataclass
class ScatterPlotConfig:
    """
    Configuration class for scatter plots.

    Every field defaults to ``None``; fields left as ``None`` are omitted
    from :meth:`get_dict`.
    """

    color: str | None = None
    s: float | None = None
    alpha: float | None = None
    marker: str | None = None
    cmap: str | None = None
    edgecolors: str | None = None
    facecolors: str | None = None

    # Catch-all for parameters without a dedicated field - pass as dict to this field directly
    _extra: dict[str, Any] = field(default_factory=dict, repr=False)

    @classmethod
    def populate(cls, dictionary: dict[str, Any]) -> "ScatterPlotConfig":
        """Build a ScatterPlotConfig from a dictionary, resolving shorthand keys through ``SCATTER_ATTRS``."""
        return _populate(_class=cls, dictionary=dictionary, mapping=SCATTER_ATTRS)

    def get_dict(self) -> dict[str, Any]:
        """
        Return all set parameters as a plain dictionary (intended as matplotlib keyword arguments).

        Public attributes whose value is not ``None`` come first; the entries of
        ``_extra`` are merged afterwards and win on key collisions.
        """
        params = {}
        for name, value in vars(self).items():
            # Skip private attributes (e.g. ``_extra``) and unset fields.
            if name.startswith("_") or value is None:
                continue
            params[name] = value
        params.update(self._extra)
        return params

    def __repr__(self):
        """Return ``ClassName(key=value, ...)`` listing explicit and extra parameters sorted by name."""
        params = self.get_dict()
        rendered = [f"{name}={params[name]!r}" for name in sorted(params)]
        return f"{type(self).__name__}({', '.join(rendered)})"
# Union of the config classes defined above; used to annotate ``split_dictionary``.
LSE = LinePlotConfig | ScatterPlotConfig | ErrorPlotConfig | ErrorBandConfig
def plot_or_scatter(axes, scatter: bool):
    """
    Select the drawing method of an axis for the requested plot type.

    Parameters
    ----------
    axes :
        The matplotlib axis whose ``scatter`` or ``plot`` method is returned.
    scatter :
        If truthy, the scatter method is selected; otherwise the plot method.

    Returns
    -------
    function
        ``axes.scatter`` when ``scatter`` is truthy, otherwise ``axes.plot``.
    """
    if scatter:
        return axes.scatter
    return axes.plot
def split_dictionary(plot_instance: LSE) -> tuple[LSE, LSE]:
    """
    Split a config instance's parameters into two separate instances.

    Parameters
    ----------
    plot_instance :
        A config instance whose parameters (as returned by its ``get_dict()``)
        may hold a list or tuple with one value per resulting instance.

    Returns
    -------
    Tuple
        Two instances of the same type as ``plot_instance``, each created via
        the class's ``populate`` method. For every parameter that is a list or
        tuple with at least two elements, the first instance receives element 0
        and the second instance receives element 1; further elements are ignored.

    Notes
    -----
    No error is raised for parameters that cannot be split: any value that is
    not a list or tuple, or that has fewer than two elements, is handed to both
    instances unchanged.
    """
    config_cls = plot_instance.__class__
    first_params: dict[str, Any] = {}
    second_params: dict[str, Any] = {}

    for name, value in plot_instance.get_dict().items():
        splittable = isinstance(value, (list, tuple)) and len(value) >= 2
        if splittable:
            first_params[name], second_params[name] = value[0], value[1]
        else:
            # Scalar or short sequence: both instances share the very same value.
            first_params[name] = second_params[name] = value

    return config_cls.populate(first_params), config_cls.populate(second_params)
def dual_axes_data_validation(
    x1_data: ArrayLike,
    x2_data: ArrayLike,
    y1_data: ArrayLike,
    y2_data: ArrayLike,
    use_twin_x: bool,
    axis_labels: Sequence[str] | None,
) -> None:
    """
    Validate the data and parameters for dual-axes plotting.

    Parameters
    ----------
    x1_data :
        Data for the primary x-axis.
    x2_data :
        Data for the secondary x-axis (used in dual x-axis plots). Should be `None` if `use_twin_x` is True.
    y1_data :
        Data for the primary y-axis.
    y2_data :
        Data for the secondary y-axis (used in dual y-axis plots). Should be `None` if `use_twin_x` is False.
    use_twin_x :
        If True, a dual y-axis plot is expected; otherwise, a dual x-axis plot is expected.
    axis_labels :
        Sequence of axis labels. Must have exactly three elements:

        - Label for the x-axis of the primary plot.
        - Label for the y-axis of the primary plot.
        - Label for the secondary axis (x or y).

    Raises
    ------
    AxisLabelError
        If ``axis_labels`` is ``None``, is a plain string, or does not have exactly three elements.
    EmptyDataError
        If ``x1_data`` or ``y1_data`` is empty.
    TwinXDataError
        If ``x2_data`` is provided when ``use_twin_x`` is ``True``.
    TwinYDataError
        If ``y2_data`` is provided when ``use_twin_x`` is ``False``.
    """
    # ``None`` is allowed by the annotation; report it as a label error instead of
    # letting ``len(None)`` fail with an unrelated TypeError.
    if axis_labels is None:
        raise AxisLabelError("axis_labels must be a list of 3 strings, got None.")

    # A bare string has a length too, so it must be rejected before the length check.
    if isinstance(axis_labels, str):
        raise AxisLabelError(
            f"axis_labels must be a list of 3 strings, not a plain string. Did you mean ['{axis_labels}']?"
        )

    if len(axis_labels) != 3:
        raise AxisLabelError("The axis_labels should have a length of 3.")

    # ``len(...) == 0`` rather than truthiness: the truth value of a multi-element
    # NumPy array is ambiguous.
    if len(x1_data) == 0 or len(y1_data) == 0:
        raise EmptyDataError("Primary x or y data is empty. Please provide valid data.")

    if use_twin_x and x2_data is not None:
        raise TwinXDataError("Dual Y-axis plot requested but 'x2_data' given.")

    if not use_twin_x and y2_data is not None:
        raise TwinYDataError("Dual X-axis plot requested but 'y2_data' given.")
def dual_axes_label_management(
    x1y1_label: str | None = None,
    x1y2_label: str | None = None,
    x2y1_label: str | None = None,
    auto_label: bool = False,
    axis_labels: Sequence[str] | None = None,
    plot_title: str | None = None,
    use_twin_x: bool = True,
) -> label_management:
    """
    Manage labels and titles for dual-axes plots.

    Parameters
    ----------
    x1y1_label : str, optional
        Label for the primary plot (X1 vs. Y1). Ignored if `auto_label=True`.
    x1y2_label : str, optional
        Label for the secondary Y-axis plot (X1 vs. Y2), used if `use_twin_x` is True.
        Ignored if `auto_label=True`.
    x2y1_label : str, optional
        Label for the secondary X-axis plot (X2 vs. Y1), used if `use_twin_x` is False.
        Ignored if `auto_label=True`.
    auto_label : bool, default False
        If True, **overwrites all provided labels** with automatic defaults.
        When True, all label parameters are ignored.
    axis_labels : Sequence[str], optional
        Axis labels as [x_label, y1_label, y2_or_x2_label]. Ignored if `auto_label=True`.

        - Dual Y-axis: [primary x, primary y, secondary y]
        - Dual X-axis: [primary x, primary y, secondary x]
    plot_title : str, optional
        Plot title. Ignored if `auto_label=True`.
    use_twin_x : bool, default True
        If True, dual Y-axis plot. If False, dual X-axis plot.

    Returns
    -------
    tuple[str, str, str, str, list[str]]
        (x1y1_label, x1y2_label, x2y1_label, plot_title, axis_labels)

    Raises
    ------
    AxisLabelError
        If ``axis_labels`` is a plain string instead of a sequence of strings.

    Notes
    -----
    When `auto_label=True`, all user-provided labels are **replaced** (a
    warning is issued through ``_auto_handler`` if any were given) with:

    - Dual Y-axis defaults: axis_labels=['X', '$Y_1$', '$Y_2$'],
      x1y1_label='$X_1$ vs $Y_1$', x1y2_label='$X_1$ vs $Y_2$', x2y1_label=''
    - Dual X-axis defaults: axis_labels=['$X_1$', 'Y', '$X_2$'],
      x1y1_label='Y vs $X_1$', x1y2_label='', x2y1_label='Y vs $X_2$'
    - plot_title='Plot'

    When `auto_label=False`, missing labels are replaced with empty strings;
    this includes ``None`` entries inside ``axis_labels``.
    """
    # A bare string is itself a sequence of characters; reject it before it is
    # silently split into per-character labels.
    if isinstance(axis_labels, str):
        raise AxisLabelError(
            f"axis_labels must be a list of 3 strings, not a plain string. Did you mean ['{axis_labels}']?"
        )

    if auto_label:
        # Warn if the user provided labels that are about to be overwritten.
        _auto_handler(axis_labels=axis_labels, x1y1_label=x1y1_label, x1y2_label=x1y2_label, x2y1_label=x2y1_label)

        if use_twin_x:
            axis_labels = ["X", r"$Y_1$", r"$Y_2$"]
            x1y1_label = r"$X_1$ vs $Y_1$"
            x1y2_label = r"$X_1$ vs $Y_2$"
            x2y1_label = ""
        else:
            axis_labels = [r"$X_1$", "Y", r"$X_2$"]
            x1y1_label = r"Y vs $X_1$"
            x1y2_label = ""
            x2y1_label = r"Y vs $X_2$"

        plot_title = "Plot"
    else:
        # Use provided values or empty strings. ``None`` entries inside
        # ``axis_labels`` count as missing too, so the result is always list[str].
        if axis_labels:
            axis_labels = ["" if label is None else label for label in axis_labels]
        else:
            axis_labels = ["", "", ""]
        x1y1_label = x1y1_label or ""
        x1y2_label = x1y2_label or ""
        x2y1_label = x2y1_label or ""
        plot_title = plot_title or ""

    return x1y1_label, x1y2_label, x2y1_label, plot_title, axis_labels
def _auto_handler( axis_labels: Sequence[str] | None, x1y1_label: str | None, x1y2_label: str | None, x2y1_label: str | None ): provided_labels = [] if x1y1_label is not None: provided_labels.append("x1y1_label") if x1y2_label is not None: provided_labels.append("x1y2_label") if x2y1_label is not None: provided_labels.append("x2y1_label") if axis_labels is not None and not all(x is None for x in axis_labels): provided_labels.append("axis_labels") if provided_labels: warn( message=f"`auto_label=True` will override provided labels: {', '.join(provided_labels)}", category=LabelConflictWarning, stacklevel=2, )