Source code for plotez.backend.utilities

"""
PlotEZ Backend Utilities.

Utility classes and functions for plot parameter management and data validation.
"""

from __future__ import annotations

from typing import TYPE_CHECKING

# Names exported by a star-import of this module: the plot-parameter config
# dataclasses plus the standalone helper functions defined below.
__all__ = [
    "dual_axes_data_validation",
    "ErrorBandConfig",
    "ErrorPlotConfig",
    "HistogramConfig",
    "LinePlotConfig",
    "plot_or_scatter",
    "ScatterPlotConfig",
    "split_dictionary",
]

from dataclasses import dataclass, field
from typing import Any, Literal

import numpy as np

from ..typing import ArrayLike, HatchStyle, NDArray
from .CONSTANTS import ERROR_ATTRS, ERROR_BAND_ATTRS, HIST_ATTRS, LINE_ATTRS, SCATTER_ATTRS
from .error_handling import AxisLabelError, EmptyDataError, TwinXDataError, TwinYDataError

if TYPE_CHECKING:
    from ..typing import LSE


def _populate(_class: type, dictionary: dict[str, Any], mapping: dict[str, str]):
    """
    Create a config dataclass instance from a dictionary, applying key aliases.

    Parameters
    ----------
    _class :
        The dataclass type to instantiate. It must declare an ``_extra`` field
        that accepts a dict of leftover parameters.
    dictionary :
        Raw parameter dictionary, possibly using shorthand keys.
    mapping :
        Alias-to-canonical-name mapping for shorthand keys. Keys that are not
        in the mapping are used unchanged.

    Returns
    -------
    instance
        An instance of ``_class``. Keys naming one of its ``__init__`` fields
        are passed as keyword arguments. Every other key is collected into
        ``_extra``.
    """
    from dataclasses import fields

    # Translate shorthand aliases to canonical field names; unknown keys pass through.
    mapped = {mapping.get(key, key): value for key, value in dictionary.items()}

    # ``dataclasses.fields`` also reports fields inherited from base dataclasses
    # and does not depend on the raw ``__annotations__`` entry in the class
    # ``__dict__``. Only ``init=True`` fields may be passed to the constructor.
    known_fields = {f.name for f in fields(_class) if f.init} - {"_extra"}

    known = {k: v for k, v in mapped.items() if k in known_fields}
    extra = {k: v for k, v in mapped.items() if k not in known_fields}

    return _class(**known, _extra=extra)


@dataclass
class LinePlotConfig:
    """
    Keyword arguments for a line plot.

    Every named field defaults to ``None`` and is only forwarded once it has
    been set. Each field may hold a single value or a list of values. Keywords
    without a dedicated field are stored in ``_extra``.
    """

    color: str | list[str] | None = None
    linewidth: int | float | list[int | float] | None = None
    linestyle: str | list[str] | None = None
    alpha: int | float | list[int | float] | None = None
    marker: str | list[str] | None = None
    markersize: int | float | list[int | float] | None = None
    markerfacecolor: str | list[str] | None = None
    markeredgecolor: str | list[str] | None = None
    markeredgewidth: int | float | list[int | float] | None = None
    # Catch-all for keywords without a dedicated field; pass a dict here directly.
    _extra: dict[str, Any] = field(default_factory=dict, repr=False)

    @classmethod
    def populate(cls, dictionary: dict[str, Any]) -> "LinePlotConfig":
        """Build an instance from ``dictionary``, expanding shorthand keys via ``LINE_ATTRS``."""
        return _populate(_class=cls, dictionary=dictionary, mapping=LINE_ATTRS)

    def get_dict(self) -> dict[str, Any]:
        """
        Collect the keyword arguments to forward to matplotlib.

        Returns
        -------
        dict
            Public fields whose value is not ``None``, followed by the entries
            of ``_extra``. An ``_extra`` key overrides a field of the same name.
        """
        kwargs = {
            name: value
            for name, value in vars(self).items()
            if value is not None and not name.startswith("_")
        }
        kwargs.update(self._extra)
        return kwargs

    def __repr__(self):
        """List every parameter ``get_dict`` would emit, ordered by name."""
        ordered = sorted(self.get_dict().items())
        body = ", ".join(f"{name}={value!r}" for name, value in ordered)
        return f"{type(self).__name__}({body})"
@dataclass
class ErrorBandConfig:
    """
    Keyword arguments for a shaded error band (filled region).

    ``alpha`` defaults to ``0.25``, so it is always forwarded unless it is
    explicitly set to ``None``. Every other named field is forwarded only once
    it has been set. Keywords without a dedicated field go into ``_extra``.
    """

    color: str | list[str] | None = None
    alpha: int | float | list[int | float] = 0.25
    linewidth: int | float | list[int | float] | None = None
    edgecolor: str | list[str] | None = None
    linestyle: str | list[str] | None = None
    hatch: HatchStyle | list[HatchStyle] | None = None
    interpolate: bool | None = None
    step: str | Literal["pre", "post", "mid"] | None = None
    # Catch-all for keywords without a dedicated field.
    _extra: dict[str, Any] = field(default_factory=dict, repr=False)

    @classmethod
    def populate(cls, dictionary: dict[str, Any]) -> "ErrorBandConfig":
        """Build an instance from ``dictionary``, expanding shorthand keys via ``ERROR_BAND_ATTRS``."""
        return _populate(_class=cls, dictionary=dictionary, mapping=ERROR_BAND_ATTRS)

    def get_dict(self) -> dict[str, Any]:
        """
        Collect the keyword arguments to forward to matplotlib.

        Returns
        -------
        dict
            Public fields whose value is not ``None``, followed by the entries
            of ``_extra``. An ``_extra`` key overrides a field of the same name.
        """
        selected = {}
        for name, value in vars(self).items():
            if name.startswith("_") or value is None:
                continue
            selected[name] = value
        selected.update(self._extra)
        return selected

    def __repr__(self):
        """Show the forwarded parameters, ordered by name."""
        rendered = [f"{name}={value!r}" for name, value in sorted(self.get_dict().items())]
        return f"{type(self).__name__}({', '.join(rendered)})"
@dataclass
class ErrorPlotConfig:
    """
    Keyword arguments for an error-bar plot.

    Every named field defaults to ``None`` and is only forwarded once set.
    Keywords without a dedicated field go into ``_extra``.
    """

    # Main line appearance.
    color: str | None = None
    linewidth: int | float | None = None
    linestyle: str | None = None
    alpha: int | float | None = None
    # Error-bar appearance.
    ecolor: str | None = None
    elinewidth: int | float | None = None
    # Data-point markers.
    marker: str | None = None
    markersize: int | float | None = None
    markerfacecolor: str | None = None
    markeredgecolor: str | None = None
    # Caps and error-bar thinning.
    capsize: int | float | None = None
    capthick: int | float | None = None
    errorevery: int | tuple | None = None
    # Catch-all for keywords without a dedicated field; pass a dict here directly.
    _extra: dict[str, Any] = field(default_factory=dict, repr=False)

    @classmethod
    def populate(cls, dictionary: dict[str, Any]) -> "ErrorPlotConfig":
        """Build an instance from ``dictionary``, expanding shorthand keys via ``ERROR_ATTRS``."""
        return _populate(_class=cls, dictionary=dictionary, mapping=ERROR_ATTRS)

    def get_dict(self) -> dict[str, Any]:
        """
        Collect the keyword arguments to forward to matplotlib.

        Returns
        -------
        dict
            Public fields whose value is not ``None``, followed by the entries
            of ``_extra``. An ``_extra`` key overrides a field of the same name.
        """
        chosen = {
            key: val
            for key, val in vars(self).items()
            if not (key.startswith("_") or val is None)
        }
        chosen.update(self._extra)
        return chosen

    def __repr__(self):
        """Show the forwarded parameters, ordered by name."""
        items = sorted(self.get_dict().items())
        inner = ", ".join(f"{key}={val!r}" for key, val in items)
        return f"{type(self).__name__}({inner})"
@dataclass
class ScatterPlotConfig:
    """
    Keyword arguments for a scatter plot.

    Every named field defaults to ``None`` and is only forwarded once set. Each
    field may hold a single value or a list of values. Keywords without a
    dedicated field go into ``_extra``.
    """

    color: str | list[str] | None = None
    s: int | float | list[int | float] | None = None
    alpha: int | float | list[int | float] | None = None
    marker: str | list[str] | None = None
    cmap: str | list[str] | None = None
    edgecolors: str | list[str] | None = None
    facecolors: str | list[str] | None = None
    # Catch-all for keywords without a dedicated field; pass a dict here directly.
    _extra: dict[str, Any] = field(default_factory=dict, repr=False)

    @classmethod
    def populate(cls, dictionary: dict[str, Any]) -> "ScatterPlotConfig":
        """Build an instance from ``dictionary``, expanding shorthand keys via ``SCATTER_ATTRS``."""
        return _populate(_class=cls, dictionary=dictionary, mapping=SCATTER_ATTRS)

    def get_dict(self) -> dict[str, Any]:
        """
        Collect the keyword arguments to forward to matplotlib.

        Returns
        -------
        dict
            Public fields whose value is not ``None``, followed by the entries
            of ``_extra``. An ``_extra`` key overrides a field of the same name.
        """
        out: dict[str, Any] = {}
        for attr, setting in vars(self).items():
            if setting is not None and not attr.startswith("_"):
                out[attr] = setting
        out.update(self._extra)
        return out

    def __repr__(self):
        """Show the forwarded parameters, ordered by name."""
        fragments = (f"{attr}={setting!r}" for attr, setting in sorted(self.get_dict().items()))
        return "{}({})".format(type(self).__name__, ", ".join(fragments))
@dataclass
class HistogramConfig:
    """
    Configuration class for histogram plots.

    Every named field defaults to ``None`` and is only forwarded once set.
    Keywords without a dedicated field go into ``_extra``.
    """

    bins: int | None = None
    density: bool | None = None
    histtype: str | None = None
    color: str | None = None
    alpha: int | float | None = None
    edgecolor: str | None = None
    facecolor: str | None = None
    linewidth: int | float | None = None
    orientation: str | None = None
    cumulative: bool | None = None
    hatch: HatchStyle | None = None
    # Catch-all for keywords without a dedicated field; pass a dict here directly.
    _extra: dict[str, Any] = field(default_factory=dict, repr=False)

    @classmethod
    def populate(cls, dictionary: dict[str, Any]) -> "HistogramConfig":
        """Create a HistogramConfig instance from a dictionary, using a mapping for shorthand keys."""
        return _populate(_class=cls, dictionary=dictionary, mapping=HIST_ATTRS)

    def get_dict(self) -> dict[str, Any]:
        """
        Get all parameters as dict for matplotlib.

        Returns
        -------
        dict
            Public fields whose value is not ``None``, followed by the entries
            of ``_extra``. An ``_extra`` key overrides a field of the same name.
        """
        result = {k: v for k, v in self.__dict__.items() if not k.startswith("_") and v is not None}
        result.update(self._extra)
        return result

    def __repr__(self):
        """Pretty repr showing both explicit and extra params, sorted by name."""
        all_params = self.get_dict()
        # Sorted to match the repr of every other config class in this module.
        param_str = ", ".join(f"{k}={v!r}" for k, v in sorted(all_params.items()))
        return f"{self.__class__.__name__}({param_str})"
def plot_or_scatter(axes, scatter: bool):
    """
    Pick the drawing method of ``axes`` for the requested plot type.

    Parameters
    ----------
    axes :
        The matplotlib axis whose plotting method is wanted.
    scatter :
        Truthy to get ``axes.scatter``; falsy to get ``axes.plot``.

    Returns
    -------
    function
        The bound method ``axes.scatter`` or ``axes.plot``.
    """
    if not scatter:
        return axes.plot
    return axes.scatter
def split_dictionary(plot_instance: LSE) -> tuple[LSE, LSE]:
    """
    Split a config instance's parameters into two separate instances.

    Parameters
    ----------
    plot_instance :
        A config instance (exposing ``get_dict`` and a ``populate``
        classmethod) whose parameter values may be lists or tuples holding one
        setting per resulting instance.

    Returns
    -------
    tuple
        ``(instance1, instance2)``. Both are of the same type as
        ``plot_instance`` and are rebuilt through its ``populate`` classmethod.

    Notes
    -----
    Each parameter value is handled as follows:

    * list/tuple with two or more elements: element 0 goes to ``instance1``
      and element 1 to ``instance2``. Any further elements are ignored.
    * list/tuple with exactly one element: the element is unwrapped and
      given to both instances.
    * anything else (scalars, empty lists/tuples, other containers): the
      value is passed unchanged to both instances.

    No exception is raised for sequences that do not hold exactly two values.
    To share a tuple-valued setting (e.g. an RGB colour) between both
    instances, wrap it in a one-element list so it is not split.
    """
    params_first: dict[str, Any] = {}
    params_second: dict[str, Any] = {}

    for name, value in plot_instance.get_dict().items():
        if isinstance(value, (list, tuple)) and len(value) >= 2:
            first, second = value[0], value[1]
        elif isinstance(value, (list, tuple)) and len(value) == 1:
            # One setting shared by both instances: unwrap it so the consumer
            # receives the bare value rather than a one-element list.
            first = second = value[0]
        else:
            # Scalar (or empty/other value): both instances get it unchanged.
            first = second = value
        params_first[name] = first
        params_second[name] = second

    config_cls = type(plot_instance)
    return config_cls.populate(params_first), config_cls.populate(params_second)
def dual_axes_data_validation(
    x1_data: ArrayLike,
    x2_data: ArrayLike | None,
    y1_data: ArrayLike,
    y2_data: ArrayLike | None,
    use_twin_x: bool,
    axis_labels: list[str | None],
) -> tuple[NDArray, NDArray, NDArray | None, NDArray | None]:
    """
    Validate the data and parameters for dual-axes plotting.

    Parameters
    ----------
    x1_data :
        Data for the primary x-axis.
    x2_data :
        Data for the secondary x-axis (used in dual x-axis plots). Should be
        `None` if `use_twin_x` is True.
    y1_data :
        Data for the primary y-axis.
    y2_data :
        Data for the secondary y-axis (used in dual y-axis plots). Should be
        `None` if `use_twin_x` is False.
    use_twin_x :
        If True, a dual y-axis plot is expected; otherwise, a dual x-axis plot
        is expected.
    axis_labels :
        List of axis labels. Must have exactly three elements:
            - Label for the x-axis of the primary plot.
            - Label for the y-axis of the primary plot.
            - Label for the secondary axis (x or y).

    Returns
    -------
    tuple
        ``(x1_data, y1_data, x2_data, y2_data)`` converted with
        ``np.asarray``. Note this order differs from the parameter order.
        ``x2_data`` / ``y2_data`` stay ``None`` when not given.

    Raises
    ------
    AxisLabelError
        If ``axis_labels`` is a plain string or does not have exactly three
        elements.
    EmptyDataError
        If ``x1_data`` or ``y1_data`` contains no values.
    TwinXDataError
        If ``x2_data`` is provided when ``use_twin_x`` is ``True``.
    TwinYDataError
        If ``y2_data`` is provided when ``use_twin_x`` is ``False``.
    """
    # Cheap label checks first, so a bad ``axis_labels`` is reported even if the
    # data would also fail conversion.
    if isinstance(axis_labels, str):
        # Suggest a full 3-element list; a 1-element list would immediately fail
        # the length check below.
        raise AxisLabelError(
            "axis_labels must be a list of 3 strings, not a plain string. "
            f"Did you mean {[axis_labels, None, None]!r}?"
        )
    if len(axis_labels) != 3:
        raise AxisLabelError("The axis_labels should have a length of 3.")

    x1_data, y1_data = np.asarray(x1_data), np.asarray(y1_data)
    x2_data = None if x2_data is None else np.asarray(x2_data)
    y2_data = None if y2_data is None else np.asarray(y2_data)

    # ``size`` rather than ``len``: ``len`` raises TypeError on 0-d arrays
    # (scalar input) and misses arrays such as shape (3, 0) that hold no values.
    if x1_data.size == 0 or y1_data.size == 0:
        raise EmptyDataError("Primary x or y data is empty. Please provide valid data.")

    if use_twin_x and x2_data is not None:
        raise TwinXDataError("Dual Y-axis plot requested but 'x2_data' given.")
    if not use_twin_x and y2_data is not None:
        raise TwinYDataError("Dual X-axis plot requested but 'y2_data' given.")

    return x1_data, y1_data, x2_data, y2_data