Coverage for model_plots/prob_density.py: 18%
50 statements
« prev ^ index » next coverage.py v7.6.1, created at 2024-12-05 17:26 +0000
« prev ^ index » next coverage.py v7.6.1, created at 2024-12-05 17:26 +0000
1import copy
2from typing import Sequence, Union
4import scipy.ndimage
5import xarray as xr
6from dantro.plot.funcs._utils import plot_errorbar
7from dantro.plot.funcs.generic import make_facet_grid_plot
9from utopya.eval import PlotHelper, is_plot_func
@make_facet_grid_plot(
    map_as="dataset",
    encodings=("x", "y", "yerr", "hue", "col", "row", "alpha", "lw"),
    supported_hue_styles=("discrete",),
    hue_style="discrete",
    add_guide=False,
    register_as_kind='density'
)
def plot_prob_density(
    ds: xr.Dataset,
    hlpr: PlotHelper,
    *,
    _is_facetgrid: bool,
    x: str,
    y: str,
    yerr: Union[str, None] = None,
    hue: Union[str, None] = None,
    label: Union[str, None] = None,
    add_legend: bool = True,
    smooth_kwargs: Union[dict, None] = None,
    linestyle: Union[str, Sequence] = "solid",
    **plot_kwargs,
):
    """Probability density plot for estimated parameters, which combines line- and errorband functionality into a
    single plot. Crucially, the x-value does not need to be a dataset coordinate. Is xarray facet_grid compatible.

    :param ds: dataset to plot
    :param hlpr: PlotHelper
    :param _is_facetgrid: whether the plot is a facet_grid instance or not (determined by the decorator function)
    :param x: coordinate or variable to use as the x-value. If it is not a coordinate, it is selected per hue value.
    :param y: values to plot onto the y-axis
    :param yerr: (optional) variable to use for the errorbands. If None, a plain line is plotted.
    :param hue: (optional) variable to plot onto the hue dimension; one line (or errorband) is drawn per hue value
    :param label: (optional) label for the plot, if the hue dimension is unused
    :param add_legend: whether to add a legend (only used when ``hue`` is set)
    :param smooth_kwargs: dictionary for the smoothing settings. Smoothing can be set for all parameters or by
        parameter: if the dictionary contains a key equal to the current parameter name, that entry is used,
        otherwise the dictionary itself is used. Recognised keys are ``enabled`` (bool, default False) and
        ``smoothing`` (the sigma passed to ``scipy.ndimage.gaussian_filter1d``); all remaining keys are forwarded
        to ``gaussian_filter1d`` as keyword arguments. Defaults to no smoothing.
    :param linestyle: a single linestyle, or (when ``hue`` is set) a sequence of linestyles, one per hue value.
        A sequence shorter than the number of hue values is cycled.
    :param plot_kwargs: passed to the plot function
    """
    # Use a fresh dict per call instead of a shared mutable default argument
    if smooth_kwargs is None:
        smooth_kwargs = {}

    def _plot_1d(*, _x, _y, _yerr, _smooth_kwargs, _ax, _label=None, **_plot_kwargs):
        """Plots a single parameter density and smooths the marginal. Returns the artists for the legend.

        ``_smooth_kwargs`` is consumed (keys are popped), so callers must pass a copy.
        """
        smooth = _smooth_kwargs.pop("enabled", False)
        sigma = _smooth_kwargs.pop("smoothing", None)

        # Smooth the y values, if enabled
        if smooth:
            _y = scipy.ndimage.gaussian_filter1d(_y, sigma, **_smooth_kwargs)

        # If no yerr is given, plot a single line
        if _yerr is None:
            (line,) = _ax.plot(_x, _y, label=_label, **_plot_kwargs)
            return line

        # Else, plot errorbands, smoothing the y error as well if enabled
        if smooth:
            _yerr = scipy.ndimage.gaussian_filter1d(_yerr, sigma, **_smooth_kwargs)

        return plot_errorbar(
            ax=_ax,
            x=_x,
            y=_y,
            yerr=_yerr,
            label=_label,
            fill_between=True,
            **_plot_kwargs,
        )

    # Determine the parameter name, used to look up per-parameter smoothing settings.
    # Defaults to None so that the lookup below falls back to the global settings
    # instead of failing with an UnboundLocalError when no candidate coordinate exists.
    pname = None
    if "parameter" in ds.coords:
        pname = ds.coords["parameter"].values.item()
    else:
        for _c in ds.coords:
            # Exclude scalar coordinates and the hue coordinate; the last remaining one wins
            if ds.coords[_c].shape == ():
                continue
            if hue is not None and _c == hue:
                continue
            pname = _c

    # Smoothing settings for this parameter; deep-copied per call because _plot_1d pops keys
    param_smooth_kwargs = smooth_kwargs.get(pname, smooth_kwargs)
    x_is_coord = x in ds.coords

    if hue:
        # Track the legend handles and labels
        _handles, _labels = [], []
        single_linestyle = isinstance(linestyle, str)

        for i, coord in enumerate(ds.coords[hue].values):
            x_vals = ds.coords[x] if x_is_coord else ds[x].sel({hue: coord})
            y_vals = ds[y].sel({hue: coord})
            yerr_vals = ds[yerr].sel({hue: coord}) if yerr is not None else None
            coord_label = f"{coord}"

            handle = _plot_1d(
                _x=x_vals,
                _y=y_vals,
                _yerr=yerr_vals,
                _smooth_kwargs=copy.deepcopy(param_smooth_kwargs),
                _ax=hlpr.ax,
                _label=coord_label,
                linestyle=linestyle if single_linestyle else linestyle[i % len(linestyle)],
                **plot_kwargs,
            )

            _handles.append(handle)
            _labels.append(coord_label)

        if not _is_facetgrid:
            if add_legend:
                hlpr.ax.legend(_handles, _labels, title=hue)
        else:
            hlpr.track_handles_labels(_handles, _labels)
            if add_legend:
                hlpr.provide_defaults("set_figlegend", title=hue)

    else:
        x_vals = ds.coords[x] if x_is_coord else ds[x]
        y_vals = ds[y]
        yerr_vals = ds[yerr] if yerr is not None else None

        _plot_1d(
            _x=x_vals,
            _y=y_vals,
            _yerr=yerr_vals,
            _ax=hlpr.ax,
            _smooth_kwargs=copy.deepcopy(param_smooth_kwargs),
            _label=label,
            linestyle=linestyle,
            **plot_kwargs,
        )