Coverage for model_plots / plots.py: 15%
88 statements
« prev ^ index » next coverage.py v7.13.1, created at 2026-01-21 16:26 +0000
« prev ^ index » next coverage.py v7.13.1, created at 2026-01-21 16:26 +0000
1import xarray as xr
2from dantro.plot.funcs.generic import make_facet_grid_plot
3import copy
4from typing import Union, Sequence
5from utopya.eval import PlotHelper, is_plot_func
6import scipy
7from dantro.plot.funcs._utils import plot_errorbar
@make_facet_grid_plot(
    map_as="dataset",
    encodings=("x", "y", "yerr", "hue", "col", "row", "alpha", "lw"),
    supported_hue_styles=("discrete",),
    hue_style="discrete",
    add_guide=False,
    register_as_kind='density'
)
def plot_prob_density(
    ds: xr.Dataset,
    hlpr: PlotHelper,
    *,
    _is_facetgrid: bool,
    x: str,
    y: str,
    yerr: str = None,
    hue: str = None,
    label: str = None,
    add_legend: bool = True,
    smooth_kwargs: dict = None,
    linestyle: Union[str, Sequence] = "solid",
    **plot_kwargs,
):
    """Probability density plot for estimated parameters, which combines line- and errorband functionality into a
    single plot. Crucially, the x-value does not need to be a dataset coordinate. Is xarray facet_grid compatible.

    :param ds: dataset to plot
    :param hlpr: PlotHelper
    :param _is_facetgrid: whether the plot is a facet_grid instance or not (determined by the decorator function)
    :param x: coordinate or variable to use as the x-value
    :param y: values to plot onto the y-axis
    :param yerr: (optional) variable to use for the errorbands. If None, no errorbands are plotted.
    :param hue: (optional) variable to plot onto the hue dimension
    :param label: (optional) label for the plot, if the hue dimension is unused
    :param add_legend: whether to add a legend. Only used when ``hue`` is set; without a hue dimension no
        legend is drawn or tracked.
    :param smooth_kwargs: (optional) dictionary for the smoothing settings. The key ``enabled`` switches
        smoothing on, ``smoothing`` is the Gaussian sigma, and all remaining keys are passed on to
        ``scipy.ndimage.gaussian_filter1d``. The settings may also be given per parameter, i.e. as
        ``{<parameter name>: {...}}``; if the resolved parameter name is not a key, the whole dict is used.
    :param linestyle: line style; when ``hue`` is set, a sequence is indexed by the position of the hue value
    :param plot_kwargs: passed to the plot function
    """
    # Local import: a bare ``import scipy`` only exposes ``scipy.ndimage`` on
    # SciPy versions that lazily load submodules.
    from scipy.ndimage import gaussian_filter1d

    if smooth_kwargs is None:
        smooth_kwargs = {}

    def _plot_1d(*, _x, _y, _yerr, _smooth_kwargs, _ax, _label=None, **_plot_kwargs):
        """Plots a single parameter density onto ``_ax``, optionally smoothing y (and yerr).

        Returns the artist(s) to use as the legend handle.
        """
        # Work on a private copy: the settings are popped below and must not
        # leak back into the caller's (possibly shared) dictionary.
        opts = copy.deepcopy(_smooth_kwargs)
        smooth = opts.pop("enabled", False)
        sigma = opts.pop("smoothing", None)

        # Smooth the y values, if requested
        if smooth:
            _y = gaussian_filter1d(_y, sigma, **opts)

        # If no yerr is given, plot a single line
        if _yerr is None:
            (line,) = _ax.plot(_x, _y, label=_label, **_plot_kwargs)
            return line

        # Otherwise plot errorbands, smoothing the error with the same settings
        if smooth:
            _yerr = gaussian_filter1d(_yerr, sigma, **opts)

        return plot_errorbar(
            ax=_ax,
            x=_x,
            y=_y,
            yerr=_yerr,
            label=_label,
            fill_between=True,
            **_plot_kwargs,
        )

    # Resolve the parameter name, used to look up per-parameter smoothing settings
    if "parameter" in ds.coords:
        pname = ds.coords["parameter"].values.item()
    else:
        # Fall back to the last coordinate that is neither scalar (0-d) nor the
        # hue dimension. If there is none, no per-parameter lookup is possible.
        candidates = [
            _c for _c in ds.coords
            if ds.coords[_c].shape != () and (hue is None or _c != hue)
        ]
        pname = candidates[-1] if candidates else None

    smooth_settings = smooth_kwargs.get(pname, smooth_kwargs)
    x_is_coord = x in ds.coords

    if hue:
        # Track the legend handles and labels
        _handles, _labels = [], []
        for i, coord in enumerate(ds.coords[hue].values):
            sel = {hue: coord}
            x_vals = ds.coords[x] if x_is_coord else ds[x].sel(sel)
            y_vals = ds[y].sel(sel)
            yerr_vals = ds[yerr].sel(sel) if yerr is not None else None

            handle = _plot_1d(
                _x=x_vals,
                _y=y_vals,
                _yerr=yerr_vals,
                _smooth_kwargs=smooth_settings,
                _ax=hlpr.ax,
                _label=f"{coord}",
                linestyle=linestyle if isinstance(linestyle, str) else linestyle[i],
                **plot_kwargs,
            )

            _handles.append(handle)
            _labels.append(f"{coord}")

        if not _is_facetgrid:
            if add_legend:
                hlpr.ax.legend(_handles, _labels, title=hue)
        else:
            # In a facet grid the legend is drawn once for the whole figure
            hlpr.track_handles_labels(_handles, _labels)
            if add_legend:
                hlpr.provide_defaults("set_figlegend", title=hue)

    else:
        x_vals = ds.coords[x] if x_is_coord else ds[x]
        y_vals = ds[y]
        yerr_vals = ds[yerr] if yerr is not None else None

        _plot_1d(
            _x=x_vals,
            _y=y_vals,
            _yerr=yerr_vals,
            _ax=hlpr.ax,
            _smooth_kwargs=smooth_settings,
            _label=label,
            linestyle=linestyle,
            **plot_kwargs,
        )
@make_facet_grid_plot(
    map_as="dataset",
    encodings=("x", "y", "hue", "col", "row"),
    supported_hue_styles=("discrete",),
    hue_style="discrete",
    add_guide=False,
    register_as_kind="line_and_scatter"
)
def line_and_scatter(
    ds: xr.Dataset,
    hlpr: PlotHelper,
    *,
    _is_facetgrid: bool,
    x: str = None,
    y: str = None,
    scatter: str,
    hue: str,
    add_legend: bool = True,
    line_kwargs: dict = None,
    scatter_kwargs: dict = None
):
    """Combined line and scatter plot.

    For every value of the ``hue`` coordinate, ``y`` is drawn as a line and ``scatter`` as scatter points,
    both against the ``x`` coordinate. The legend contains one entry per hue value (the lines) plus one grey
    "True data" entry representing the scatter points.

    :param ds: dataset to plot
    :param hlpr: PlotHelper
    :param _is_facetgrid: whether the plot is a facet_grid instance or not (determined by the decorator function)
    :param x: coordinate of ``ds`` to use for the x-axis (required in practice despite the default)
    :param y: data variable to draw as lines
    :param scatter: data variable to draw as scatter points
    :param hue: coordinate to iterate over; one line/scatter pair is drawn per value
    :param add_legend: whether to add a legend
    :param line_kwargs: (optional) passed to ``ax.plot``
    :param scatter_kwargs: (optional) passed to ``ax.scatter``
    """
    from matplotlib.lines import Line2D

    line_kwargs = {} if line_kwargs is None else line_kwargs
    scatter_kwargs = {} if scatter_kwargs is None else scatter_kwargs

    x_vals = ds.coords[x].data
    handles = []
    labels = []
    scatter_artist = None
    for coord in ds.coords[hue].values:
        sel = {hue: coord}
        lines = hlpr.ax.plot(x_vals, ds[y].sel(sel), **line_kwargs, label=coord)
        scatter_artist = hlpr.ax.scatter(x_vals, ds[scatter].sel(sel), **scatter_kwargs, label=None)
        handles.append(lines[0])
        labels.append(f"{coord}")

    # Add a dummy legend entry for the scatter points, mimicking the marker
    # shape and size of the last scatter drawn. Skipped if nothing was drawn
    # (empty hue coordinate), where there is no marker to copy.
    if scatter_artist is not None:
        true_data_handle = Line2D(
            [], [],
            marker=scatter_artist.get_paths()[0],
            # scatter sizes are areas (points^2), Line2D markersize is a length
            markersize=scatter_artist.get_sizes()[0]**0.5,
            linestyle='None',
            color='grey',
            label='True data',
            markerfacecolor='grey',
            markeredgecolor='grey'
        )
        handles.append(true_data_handle)
        labels.append('True data')

    # Add legend
    if not _is_facetgrid:
        if add_legend:
            hlpr.ax.legend(handles, labels, title=hue)
    else:
        # In a facet grid the legend is drawn once for the whole figure
        hlpr.track_handles_labels(handles, labels)
        if add_legend:
            hlpr.provide_defaults("set_figlegend", title=hue)
@make_facet_grid_plot(
    map_as="dataset",
    encodings=("x", "y", "hue", "col", "row"),
    supported_hue_styles=("discrete",),
    hue_style="discrete",
    add_guide=False,
    register_as_kind="errorbands_and_scatter"
)
def errorbands_and_scatter(
    ds: xr.Dataset,
    hlpr: PlotHelper,
    *,
    _is_facetgrid: bool,
    x: str = None,
    y: str,
    yerr: str,
    scatter: str,
    hue: str,
    add_legend: bool = True,
    line_kwargs: dict = None,
    scatter_kwargs: dict = None
):
    """Combined errorband and scatter plot.

    For every value of the ``hue`` coordinate, ``y`` is drawn with a filled errorband of width ``yerr``, and
    ``scatter`` is drawn as scatter points, both against the ``x`` coordinate. The legend contains one entry
    per hue value (the errorbands) plus one grey "True data" entry representing the scatter points.

    :param ds: dataset to plot
    :param hlpr: PlotHelper
    :param _is_facetgrid: whether the plot is a facet_grid instance or not (determined by the decorator function)
    :param x: coordinate of ``ds`` to use for the x-axis (required in practice despite the default)
    :param y: data variable to draw as the central line
    :param yerr: data variable to use for the errorband
    :param scatter: data variable to draw as scatter points
    :param hue: coordinate to iterate over; one errorband/scatter pair is drawn per value
    :param add_legend: whether to add a legend
    :param line_kwargs: (optional) passed to ``plot_errorbar``
    :param scatter_kwargs: (optional) passed to ``ax.scatter``
    """
    from matplotlib.lines import Line2D

    line_kwargs = {} if line_kwargs is None else line_kwargs
    scatter_kwargs = {} if scatter_kwargs is None else scatter_kwargs

    x_vals = ds.coords[x].data
    handles = []
    labels = []
    scatter_artist = None
    for coord in ds.coords[hue].values:
        subset = ds.sel({hue: coord})
        errorband = plot_errorbar(
            ax=hlpr.ax,
            x=x_vals,
            y=subset[y],
            yerr=subset[yerr],
            label=f'{coord}',
            fill_between=True,
            **line_kwargs
        )
        scatter_artist = hlpr.ax.scatter(x_vals, subset[scatter], **scatter_kwargs, label=None)
        handles.append(errorband)
        labels.append(f"{coord}")

    # Add a dummy legend entry for the scatter points, mimicking the marker
    # shape and size of the last scatter drawn. Skipped if nothing was drawn
    # (empty hue coordinate), where there is no marker to copy.
    if scatter_artist is not None:
        true_data_handle = Line2D(
            [], [],
            marker=scatter_artist.get_paths()[0],
            # scatter sizes are areas (points^2), Line2D markersize is a length
            markersize=scatter_artist.get_sizes()[0]**0.5,
            linestyle='None',
            color='grey',
            label='True data',
            markerfacecolor='grey',
            markeredgecolor='grey'
        )
        handles.append(true_data_handle)
        labels.append('True data')

    # Add legend
    if not _is_facetgrid:
        if add_legend:
            hlpr.ax.legend(handles, labels, title=hue)
    else:
        # In a facet grid the legend is drawn once for the whole figure
        hlpr.track_handles_labels(handles, labels)
        if add_legend:
            hlpr.provide_defaults("set_figlegend", title=hue)