"""
PlotEZ Backend Utilities.
Utility classes and functions for plot parameter management and data validation.
"""
from __future__ import annotations
from typing import TYPE_CHECKING
__all__ = [
"dual_axes_data_validation",
"error_offset_validation",
"errorband_validation",
"errorbar_validation",
"ErrorBandConfig",
"ErrorPlotConfig",
"HistogramConfig",
"LinePlotConfig",
"plot_or_scatter",
"ScatterPlotConfig",
"split_dictionary",
"validate_1d",
"validate_equal_length",
]
from dataclasses import dataclass, field, fields
from typing import Any, Literal

import numpy as np

from ..errors import (
    AxisLabelError,
    ConfigurationError,
    DataLengthError,
    EmptyDataError,
    ShapeError,
    TwinXDataError,
    TwinYDataError,
)
from ..typing import ArrayLike, HatchStyle, NDArray
from .CONSTANTS import ERROR_ATTRS, ERROR_BAND_ATTRS, HIST_ATTRS, LINE_ATTRS, SCATTER_ATTRS

if TYPE_CHECKING:
    from ..typing import LSE
def _populate(_class: type, dictionary: dict[str, Any], mapping: dict[str, str]) -> Any:
    """
    Create a config dataclass instance from a dictionary, applying key aliases.

    Keys are first translated through ``mapping`` (keys absent from ``mapping`` are kept as-is). Keys naming an
    ``__init__`` field of ``_class`` are passed as keyword arguments; every remaining key is collected into the
    ``_extra`` dict.

    Parameters
    ----------
    _class :
        The dataclass type to instantiate. It must define an ``_extra`` field.
    dictionary :
        Raw parameter dictionary, possibly using shorthand keys.
    mapping :
        Alias-to-canonical-name mapping for shorthand keys.

    Returns
    -------
    instance
        An instance of ``_class`` populated from the mapped dictionary.

    Notes
    -----
    If an alias and its canonical name are both present, the one appearing later in ``dictionary`` wins.
    """
    mapped = {mapping.get(key, key): value for key, value in dictionary.items()}
    # ``fields`` also covers fields inherited from dataclass bases (a class's own ``__annotations__`` does not, and
    # may be missing entirely on a subclass); fields declared with ``init=False`` cannot be passed to ``__init__``.
    known_fields = {f.name for f in fields(_class) if f.init} - {"_extra"}
    known = {key: value for key, value in mapped.items() if key in known_fields}
    extra = {key: value for key, value in mapped.items() if key not in known_fields}
    return _class(**known, _extra=extra)
def dual_axes_data_validation(
    x1_data: ArrayLike,
    x2_data: ArrayLike | None,
    y1_data: ArrayLike,
    y2_data: ArrayLike | None,
    use_twin_x: bool,
    axis_labels: list[str | None],
) -> tuple[NDArray, NDArray, NDArray | None, NDArray | None]:
    """
    Check and normalise the inputs of a dual-axes plot.

    The primary ``x1_data``/``y1_data`` pair is always required. With ``use_twin_x`` the x-axis is shared and a
    second y series (``y2_data``) may be added; otherwise the y-axis is shared and ``x2_data`` may be added.

    Parameters
    ----------
    x1_data :
        Values for the primary x-axis. Must be 1-D and non-empty.
    x2_data :
        Values for the secondary x-axis, or ``None``. Only allowed when ``use_twin_x`` is False; when given it must
        be 1-D and as long as ``y1_data``.
    y1_data :
        Values for the primary y-axis. Must be 1-D, non-empty and as long as ``x1_data``.
    y2_data :
        Values for the secondary y-axis, or ``None``. Only allowed when ``use_twin_x`` is True; when given it must
        be 1-D and as long as ``x1_data``.
    use_twin_x :
        True for a dual y-axis plot (shared x-axis); False for a dual x-axis plot (shared y-axis).
    axis_labels :
        Exactly three labels: primary x-axis, primary y-axis, and the secondary axis (x or y).

    Returns
    -------
    tuple
        ``(x1_data, y1_data, x2_data, y2_data)`` converted with ``np.asarray``; absent secondary data stays ``None``.

    Raises
    ------
    AxisLabelError
        If ``axis_labels`` is a plain string or does not have exactly three elements.
    ShapeError
        If any supplied data array is not 1-D.
    EmptyDataError
        If ``x1_data`` or ``y1_data`` is empty.
    DataLengthError
        If paired arrays differ in length.
    TwinXDataError
        If ``x2_data`` is provided when ``use_twin_x`` is ``True``.
    TwinYDataError
        If ``y2_data`` is provided when ``use_twin_x`` is ``False``.
    """
    if isinstance(axis_labels, str):
        raise AxisLabelError(
            f"axis_labels must be a list of 3 strings, not a plain string. Did you mean ['{axis_labels}']?"
        )

    def _as_checked_1d(values: ArrayLike | None, label: str) -> NDArray | None:
        # Optional secondary data: convert and require 1-D only when it is present.
        if values is None:
            return None
        converted = np.asarray(values)
        validate_1d(converted, names=[label])
        return converted

    x1_data = np.asarray(x1_data)
    y1_data = np.asarray(y1_data)
    validate_1d(x1_data, y1_data, names=["x1_data", "y1_data"])
    if 0 in (len(x1_data), len(y1_data)):
        raise EmptyDataError("Primary x or y data is empty. Please provide valid data.")
    validate_equal_length(x1_data, y1_data, names=["x1_data", "y1_data"])

    x2_data = _as_checked_1d(x2_data, "x2_data")
    y2_data = _as_checked_1d(y2_data, "y2_data")

    if len(axis_labels) != 3:  # noqa
        raise AxisLabelError("The axis_labels should have a length of 3.")

    if use_twin_x:
        # Shared x-axis: the only permitted secondary series is a second y.
        if x2_data is not None:
            raise TwinXDataError("Dual Y-axis plot requested but 'x2_data' given.")
        if y2_data is not None:
            validate_equal_length(x1_data, y2_data, names=["x1_data", "y2_data"])
    else:
        # Shared y-axis: the only permitted secondary series is a second x.
        if y2_data is not None:
            raise TwinYDataError("Dual X-axis plot requested but 'y2_data' given.")
        if x2_data is not None:
            validate_equal_length(x2_data, y1_data, names=["x2_data", "y1_data"])

    return x1_data, y1_data, x2_data, y2_data
@dataclass
class ErrorBandConfig:
    """
    Styling options for error bands (shaded fill regions).

    Attributes left as ``None`` count as "not set" and are omitted from :meth:`get_dict`; ``alpha`` is the only
    attribute with a non-``None`` default (``0.25``). Keyword arguments without a dedicated attribute can be stored
    in ``_extra``.
    """

    color: str | list[str] | None = None
    alpha: float | list[float] = 0.25
    linewidth: float | list[float] | None = None
    edgecolor: str | list[str] | None = None
    linestyle: str | list[str] | None = None
    hatch: HatchStyle | list[HatchStyle] | None = None
    interpolate: bool | None = None
    step: str | Literal["pre", "post", "mid"] | None = None
    # Free-form keyword arguments that have no dedicated field above
    _extra: dict[str, Any] = field(default_factory=dict, repr=False)

    @classmethod
    def populate(cls, dictionary: dict[str, Any]) -> "ErrorBandConfig":
        """Build an instance from ``dictionary``, expanding shorthand keys through ``ERROR_BAND_ATTRS``."""
        return _populate(cls, dictionary, ERROR_BAND_ATTRS)

    def get_dict(self) -> dict[str, Any]:
        """Return the keyword arguments to forward to matplotlib.

        Public attributes that are not ``None`` come first, in field order; ``_extra`` entries are merged on top
        and win on key collisions.
        """
        explicit = {}
        for name, value in vars(self).items():
            if name.startswith("_") or value is None:
                continue
            explicit[name] = value
        return {**explicit, **self._extra}

    def __repr__(self):
        """Show the class name followed by every effective parameter, ordered by name."""
        params = self.get_dict()
        body = ", ".join(f"{name}={params[name]!r}" for name in sorted(params))
        return f"{type(self).__name__}({body})"
@dataclass
class ErrorPlotConfig:
    """
    Styling options for error bar plots.

    Attributes left as ``None`` count as "not set" and are omitted from :meth:`get_dict`. Keyword arguments without
    a dedicated attribute can be stored in ``_extra``.
    """

    # Main data line
    color: str | None = None
    linewidth: float | None = None
    linestyle: str | None = None
    alpha: float | None = None
    # Error bar lines
    ecolor: str | None = None
    elinewidth: float | None = None
    # Data-point markers
    marker: str | None = None
    markersize: float | None = None
    markerfacecolor: str | None = None
    markeredgecolor: str | None = None
    # Caps and ``errorevery``
    capsize: float | None = None
    capthick: float | None = None
    errorevery: int | tuple | None = None
    # Free-form keyword arguments without a dedicated field; pass a dict to this field directly
    _extra: dict[str, Any] = field(default_factory=dict, repr=False)

    @classmethod
    def populate(cls, dictionary: dict[str, Any]) -> "ErrorPlotConfig":
        """Return a new instance built from ``dictionary``; shorthand keys are expanded via ``ERROR_ATTRS``."""
        return _populate(cls, dictionary, ERROR_ATTRS)

    def get_dict(self) -> dict[str, Any]:
        """Collect the keyword arguments to forward to matplotlib.

        Non-``None`` public attributes come first (in field order); ``_extra`` is merged on top and wins on
        key collisions.
        """
        merged = {key: value for key, value in vars(self).items() if value is not None and not key.startswith("_")}
        merged.update(self._extra)
        return merged

    def __repr__(self):
        """Render the class name with all effective parameters, sorted by name."""
        rendered = []
        for key, value in sorted(self.get_dict().items()):
            rendered.append(f"{key}={value!r}")
        joined = ", ".join(rendered)
        return f"{type(self).__name__}({joined})"
@dataclass
class HistogramConfig:
    """
    Configuration class for histogram plots.

    Attributes left as ``None`` are omitted from :meth:`get_dict`; keyword arguments without a dedicated attribute
    can be stored in ``_extra``.
    """

    bins: int | None = None
    density: bool | None = None
    histtype: str | None = None
    color: str | None = None
    alpha: float | None = None
    edgecolor: str | None = None
    facecolor: str | None = None
    linewidth: float | None = None
    orientation: str | None = None
    cumulative: bool | None = None
    hatch: HatchStyle | None = None
    # For extra params - pass as dict to this field directly
    _extra: dict[str, Any] = field(default_factory=dict, repr=False)

    @classmethod
    def populate(cls, dictionary: dict[str, Any]) -> "HistogramConfig":
        """Create a HistogramConfig instance from a dictionary, using a mapping for shorthand keys."""
        return _populate(_class=cls, dictionary=dictionary, mapping=HIST_ATTRS)

    def get_dict(self) -> dict[str, Any]:
        """Get all parameters as dict for matplotlib.

        Non-``None`` public attributes are included first, then ``_extra`` is merged in (overriding on key
        collisions).
        """
        result = {k: v for k, v in self.__dict__.items() if not k.startswith("_") and v is not None}
        result.update(self._extra)
        return result

    def __repr__(self):
        """Pretty repr showing both explicit and extra params, sorted by name."""
        all_params = self.get_dict()
        # Sorted by name, the same ordering the other config classes in this module use.
        param_str = ", ".join(f"{k}={v!r}" for k, v in sorted(all_params.items()))
        return f"{self.__class__.__name__}({param_str})"
@dataclass
class LinePlotConfig:
    """
    Styling options for line plots.

    Every attribute defaults to ``None`` ("not set"), and unset attributes are left out of :meth:`get_dict`.
    Keyword arguments without a dedicated attribute can be stored in ``_extra``.
    """

    color: str | list[str] | None = None
    linewidth: float | list[float] | None = None
    linestyle: str | list[str] | None = None
    alpha: float | list[float] | None = None
    marker: str | list[str] | None = None
    markersize: float | list[float] | None = None
    markerfacecolor: str | list[str] | None = None
    markeredgecolor: str | list[str] | None = None
    markeredgewidth: float | list[float] | None = None
    # Free-form keyword arguments without a dedicated field; pass a dict to this field directly
    _extra: dict[str, Any] = field(default_factory=dict, repr=False)

    @classmethod
    def populate(cls, dictionary: dict[str, Any]) -> "LinePlotConfig":
        """Build an instance from ``dictionary``, resolving shorthand keys with ``LINE_ATTRS``."""
        return _populate(cls, dictionary, LINE_ATTRS)

    def get_dict(self) -> dict[str, Any]:
        """Return the matplotlib keyword arguments held by this config.

        Set public attributes keep their field order; ``_extra`` entries are applied afterwards and override
        attributes of the same name.
        """
        params: dict[str, Any] = {}
        for attr_name, attr_value in self.__dict__.items():
            is_private = attr_name.startswith("_")
            if not is_private and attr_value is not None:
                params[attr_name] = attr_value
        params.update(self._extra)
        return params

    def __repr__(self):
        """Describe the config using its effective parameters in alphabetical order."""
        effective = self.get_dict()
        pieces = [f"{name}={effective[name]!r}" for name in sorted(effective)]
        return f"{type(self).__name__}({', '.join(pieces)})"
def plot_or_scatter(axes, scatter: bool):
    """
    Pick the drawing method of ``axes`` that matches the requested plot type.

    Parameters
    ----------
    axes :
        The matplotlib axis whose method is returned.
    scatter :
        Truthy to get the scatter method; falsy to get the line-plot method.

    Returns
    -------
    function
        ``axes.scatter`` when ``scatter`` is truthy, otherwise ``axes.plot``.
    """
    method_name = "scatter" if scatter else "plot"
    return getattr(axes, method_name)
@dataclass
class ScatterPlotConfig:
    """
    Styling options for scatter plots.

    Unset (``None``) attributes are skipped by :meth:`get_dict`; keyword arguments without a dedicated attribute
    can be stored in ``_extra``.
    """

    color: str | list[str] | None = None
    s: float | list[float] | None = None
    alpha: float | list[float] | None = None
    marker: str | list[str] | None = None
    cmap: str | list[str] | None = None
    edgecolors: str | list[str] | None = None
    facecolors: str | list[str] | None = None
    # Free-form keyword arguments without a dedicated field; pass a dict to this field directly
    _extra: dict[str, Any] = field(default_factory=dict, repr=False)

    @classmethod
    def populate(cls, dictionary: dict[str, Any]) -> "ScatterPlotConfig":
        """Construct an instance from ``dictionary``; shorthand keys are translated via ``SCATTER_ATTRS``."""
        return _populate(cls, dictionary, SCATTER_ATTRS)

    def get_dict(self) -> dict[str, Any]:
        """Gather the matplotlib keyword arguments: set public attributes first, then ``_extra`` on top."""
        public_items = (
            (key, value) for key, value in vars(self).items() if not key.startswith("_") and value is not None
        )
        collected = dict(public_items)
        collected.update(self._extra)
        return collected

    def __repr__(self):
        """Return ``ClassName(name=value, ...)`` listing every effective parameter sorted by name."""
        ordered = sorted(self.get_dict().items())
        text = ", ".join(f"{key}={value!r}" for key, value in ordered)
        return f"{self.__class__.__name__}({text})"
def split_dictionary(plot_instance: LSE) -> tuple[LSE, LSE]:
    """
    Split a config instance's parameters into two separate instances.

    Parameters
    ----------
    plot_instance :
        A config instance providing ``get_dict()`` and a ``populate`` classmethod (such as the config dataclasses
        in this module). Parameters given as a list or tuple of at least two values are distributed between the
        two resulting instances; every other parameter is shared by both.

    Returns
    -------
    tuple
        Two new instances of the same type as ``plot_instance``, built with ``populate``. For each list/tuple
        parameter with two or more elements, the first instance receives element ``0`` and the second instance
        receives element ``1``.

    Notes
    -----
    - No error is raised for malformed parameters; sequence lengths are not validated.
    - Elements beyond the first two of a longer list/tuple are silently ignored.
    - Scalars, empty sequences and single-element sequences are passed unchanged to both instances (a one-element
      list stays a one-element list).
    - Tuples are split exactly like lists, so a tuple-valued parameter (e.g. an RGB color tuple) is split
      element-wise as well.
    """
    parameters = plot_instance.get_dict()
    first_params: dict[str, Any] = {}
    second_params: dict[str, Any] = {}
    for name, value in parameters.items():
        if isinstance(value, (list, tuple)) and len(value) >= 2:
            first_params[name], second_params[name] = value[0], value[1]
        else:
            # Not a per-instance pair: both instances share the value unchanged.
            first_params[name] = second_params[name] = value
    config_cls = plot_instance.__class__
    return config_cls.populate(first_params), config_cls.populate(second_params)
def validate_1d(*arrays: ArrayLike, names: list[str]) -> None:
    """Verify that every supplied array is exactly 1-D.

    Parameters
    ----------
    *arrays :
        Arrays to check. Usually already converted with ``np.asarray``, but any array-like accepted by
        ``np.ndim`` / ``np.shape`` (lists, tuples, scalars) also works.
    names :
        Human-readable name for each array, used in the error message. Must contain exactly one entry per array.

    Raises
    ------
    ValueError
        If ``names`` does not have the same number of entries as ``arrays``.
    ShapeError
        If any array has ``ndim != 1``.
    """
    if len(names) != len(arrays):
        # zip() would silently stop at the shorter sequence and leave some arrays unchecked.
        raise ValueError(
            f"Got {len(arrays)} array(s) but {len(names)} name(s); 'names' needs exactly one entry per array."
        )
    for arr, name in zip(arrays, names):
        if np.ndim(arr) != 1:
            raise ShapeError(f"'{name}' must be 1D, got shape {np.shape(arr)}")
def validate_equal_length(*arrays: ArrayLike, names: list[str]) -> None:
    """Verify that all supplied arrays share the same length.

    Scalar (0-d) inputs are skipped — they broadcast freely and do not need a length check.

    Parameters
    ----------
    *arrays :
        Arrays to check. Usually already converted with ``np.asarray``, but any array-like accepted by ``np.ndim``
        and ``len`` also works.
    names :
        Human-readable name for each array, used in the error message. Must contain exactly one entry per array.

    Raises
    ------
    ValueError
        If ``names`` does not have the same number of entries as ``arrays``.
    DataLengthError
        If the non-scalar arrays do not all have the same length.
    """
    if len(names) != len(arrays):
        # zip() would silently stop at the shorter sequence and leave some arrays unchecked.
        raise ValueError(
            f"Got {len(arrays)} array(s) but {len(names)} name(s); 'names' needs exactly one entry per array."
        )
    sized = [(name, len(arr)) for arr, name in zip(arrays, names) if np.ndim(arr) >= 1]
    if len({length for _, length in sized}) > 1:
        parts = ", ".join(f"'{name}'={length}" for name, length in sized)
        raise DataLengthError(f"Arrays must have equal length, got: {parts}")
def _validate_error_array(err: ArrayLike, data: NDArray, axis: str) -> NDArray:
    """Convert one error specification to an array and check it against its data array.

    ``axis`` (``"x"`` or ``"y"``) is only used to build the names that appear in error messages.
    """
    err_arr = np.asarray(err)
    if err_arr.ndim > 2 or (err_arr.ndim == 2 and err_arr.shape[0] != 2):
        raise ShapeError(f"Asymmetric `{axis}_err` must have shape (2, N), got {err_arr.shape}")
    if err_arr.ndim == 1:
        # Symmetric errors: one value per data point.
        validate_equal_length(data, err_arr, names=[f"{axis}_data", f"{axis}_err"])
    elif err_arr.ndim == 2 and err_arr.shape[1] != len(data):
        # Asymmetric errors: shape (2, N), where N must match the number of data points.
        raise DataLengthError(f"'{axis}_err' must have length {len(data)} along axis 1, got {err_arr.shape[1]}")
    # A 0-d (scalar) error is accepted without any length check.
    return err_arr


def errorbar_validation(
    x: NDArray, y: NDArray, x_err: ArrayLike | None, y_err: ArrayLike | None
) -> tuple[NDArray | None, NDArray | None]:
    """Validate the input arrays for error bars to ensure compatibility with the provided data arrays.

    Parameters
    ----------
    x :
        The array of x-coordinate data.
    y :
        The array of y-coordinate data.
    x_err :
        The error values associated with the x-coordinate data.
        Can represent either symmetric or asymmetric errors.
        - A scalar applies to every point and is not length-checked.
        - If symmetric, it should be a 1D array with the same length as ``x``.
        - If asymmetric, it should be a 2D array with shape (2, N), where N is the length of ``x``.
        - Can also be `None`, in which case no validation is performed for x-error.
    y_err :
        The error values associated with the y-coordinate data, following the same rules as ``x_err``
        (checked against ``y``).

    Returns
    -------
    tuple[NDArray | None, NDArray | None]
        A tuple containing the validated `x_err` and `y_err` arrays (converted with ``np.asarray``), respectively.
        If no validation is performed for a specific error array (i.e., it is `None`), it is returned as `None`.

    Raises
    ------
    ShapeError
        Raised if the shape of `x_err` or `y_err` does not conform to valid symmetric or asymmetric error formats.
    DataLengthError
        Raised if the length of the `x_err` or `y_err` array does not match the length of `x` or `y` data array.
    """
    if x_err is not None:
        x_err = _validate_error_array(x_err, x, "x")
    if y_err is not None:
        y_err = _validate_error_array(y_err, y, "y")
    return x_err, y_err
def errorband_validation(
    x: NDArray, y: NDArray, y_lower: ArrayLike | None, y_upper: ArrayLike | None
) -> tuple[NDArray | None, NDArray | None]:
    """Validate the bounds of an error band against its data.

    Parameters
    ----------
    x :
        Primary data on the x-axis; must share its length with ``y`` and the non-scalar bounds.
    y :
        Primary data on the y-axis; must have the same length as ``x``.
    y_lower :
        Optional lower bound of the band. A 0-d (scalar) value is accepted as-is; an array must be 1-D and as
        long as ``x``.
    y_upper :
        Optional upper bound of the band, following the same rules as ``y_lower``.

    Returns
    -------
    tuple[NDArray | None, NDArray | None]
        ``(y_lower, y_upper)`` converted with ``np.asarray``; a bound that was ``None`` is returned as ``None``.

    Raises
    ------
    ConfigurationError
        If both ``y_lower`` and ``y_upper`` are ``None``.
    ShapeError
        If a non-scalar bound is not 1-D.
    DataLengthError
        If ``x``, ``y`` and the non-scalar bounds do not all share the same length.
    """
    if y_lower is None and y_upper is None:
        raise ConfigurationError("At least one of `y_lower` or `y_upper` must be provided for the error band.")
    if y_lower is not None:
        y_lower = np.asarray(y_lower)
        if y_lower.ndim > 0:
            validate_1d(y_lower, names=["y_lower"])
    if y_upper is not None:
        y_upper = np.asarray(y_upper)
        if y_upper.ndim > 0:
            validate_1d(y_upper, names=["y_upper"])
    # Build (array, name) pairs once so the arrays and their names cannot drift out of step.
    supplied = [(bound, name) for bound, name in ((y_lower, "y_lower"), (y_upper, "y_upper")) if bound is not None]
    validate_equal_length(
        x,
        y,
        *(bound for bound, _ in supplied),
        names=["x_data", "y_data", *(name for _, name in supplied)],
    )
    return y_lower, y_upper
def error_offset_validation(
    y: NDArray, y_lower: ArrayLike | None, y_upper: ArrayLike | None
) -> tuple[NDArray | None, NDArray | None]:
    """Validate and process the lower and upper error offset arrays relative to a target array.

    Parameters
    ----------
    y :
        The target array against which the error offsets are validated.
    y_lower :
        The lower error offsets, or ``None``. A 0-d (scalar) value is accepted as-is; an array must be 1-D and
        have the same length as ``y``.
    y_upper :
        The upper error offsets, or ``None``, following the same rules as ``y_lower``.

    Returns
    -------
    tuple[NDArray | None, NDArray | None]
        ``(y_lower, y_upper)`` converted with ``np.asarray``; an offset that was ``None`` is returned as ``None``.

    Raises
    ------
    ShapeError
        If a non-scalar offset is not 1-D.
    DataLengthError
        If a 1-D offset does not have the same length as ``y``.

    Notes
    -----
    Unlike ``errorband_validation``, both offsets may be ``None``.
    """
    lower_offset = None if y_lower is None else np.asarray(y_lower)
    upper_offset = None if y_upper is None else np.asarray(y_upper)
    for offset, name in ((lower_offset, "y_lower"), (upper_offset, "y_upper")):
        # 0-d offsets broadcast, so only array offsets need the shape and length checks.
        if offset is not None and offset.ndim > 0:
            validate_1d(offset, names=[name])
            validate_equal_length(y, offset, names=["y_data", name])
    return lower_offset, upper_offset