Coverage for model_plots / Covid / violinplot.py: 17%
52 statements
« prev ^ index » next coverage.py v7.13.1, created at 2026-01-21 16:26 +0000
« prev ^ index » next coverage.py v7.13.1, created at 2026-01-21 16:26 +0000
1import copy
3import numpy as np
4import scipy.integrate
5import scipy.ndimage
6import xarray as xr
7from dantro.plot.funcs.generic import make_facet_grid_plot
9from utopya.eval import PlotHelper, is_plot_func
@make_facet_grid_plot(
    map_as="dataset",
    encodings=("x", "y", "hue", "col", "row"),
    supported_hue_styles=("discrete",),
    hue_style="discrete",
    add_guide=False,
)
def violin_plot(
    ds: xr.Dataset,
    hlpr: PlotHelper,
    *,
    _is_facetgrid: bool,
    x: str,
    y: str,
    hue: str,
    add_legend: bool = True,
    show_means: bool = True,
    show_modes: bool = True,
    format_y_label: bool = False,
    mean_kwargs: dict = None,
    mode_kwargs: dict = None,
    smooth_kwargs: dict = None,
    **plot_kwargs,
):
    r"""Plots a violin plot of one-dimensional distributions.

    Successive entries along the ``hue`` dimension are drawn alternately on the left
    (mirrored, i.e. negated) and right side of the vertical axis, so the plot is only
    really meaningful if the ``hue`` dimension has length 2. The mean and mode of each
    distribution can additionally be marked as discrete points.

    :param ds: ``xr.Dataset`` of data values
    :param hlpr: ``PlotHelper`` instance
    :param _is_facetgrid: whether the plot is drawn inside a facet grid (supplied by the
        decorator). If True, legend handles are tracked on the helper for a figure
        legend instead of being drawn on the axis.
    :param x: name of the coordinate or data variable holding the support of the
        distributions; its values are drawn along the vertical axis
    :param y: name of the data variable holding the distribution values; these are
        drawn horizontally (negated for every other ``hue`` entry)
    :param hue: dimension whose entries are alternately plotted on the left and right
    :param add_legend: whether to add a legend (or, in a facet grid, to track handles
        and labels for a figure legend)
    :param show_means: (optional) whether to show the means of the distributions
    :param show_modes: (optional) whether to show the modes of the distributions
    :param format_y_label: (optional) whether to format the y-label to match the Berlin
        SEIRD publication style ``$\lambda_{\rm X}$``
    :param mean_kwargs: kwargs for the mean markers, passed to ``ax.scatter``. Defaults to
        ``dict(s=15, color="#48675A", lw=0.3, edgecolor="#3D4244")``.
    :param mode_kwargs: kwargs for the mode markers, passed to ``ax.scatter``. Defaults to
        ``dict(s=15, color="#F5DDA9", lw=0.3, edgecolor="#3D4244")``.
    :param smooth_kwargs: (optional) smoothing configuration, given either per parameter
        name (looked up by the parameter name) or globally. The keys ``enabled`` (bool)
        and ``smoothing`` (Gaussian sigma) are consumed; all remaining keys are passed
        to ``scipy.ndimage.gaussian_filter1d``.
    :param plot_kwargs: kwargs for the distributions, passed to ``ax.fill_betweenx``;
        they may override the defaults ``alpha=0.6`` and ``lw=2``.
    """
    # Resolve defaults here rather than in the signature to avoid shared mutable defaults
    if mean_kwargs is None:
        mean_kwargs = dict(s=15, color="#48675A", lw=0.3, edgecolor="#3D4244")
    if mode_kwargs is None:
        mode_kwargs = dict(s=15, color="#F5DDA9", lw=0.3, edgecolor="#3D4244")
    if smooth_kwargs is None:
        smooth_kwargs = {}

    def _plot_1d(
        _x, _y, _yfactor, *, _smooth_kwargs: dict = None, label: str, **_plot_kwargs
    ):
        """Plots a single, optionally smoothed, distribution on one side of the axis.

        ``_yfactor`` is -1 (left side) or +1 (right side). Returns the fill artist and
        the mean and mode marker artists (``None`` for markers that are not shown).
        """
        # Shallow copy so popping the control keys never mutates the caller's dict
        _smooth_kwargs = dict(_smooth_kwargs or {})
        smooth = _smooth_kwargs.pop("enabled", False)
        sigma = _smooth_kwargs.pop("smoothing", None)

        # Work on plain arrays throughout: gaussian_filter1d returns an ndarray anyway,
        # and this keeps the statistics below independent of whether smoothing ran
        x_arr = np.asarray(_x)
        y_arr = np.asarray(_y)
        if smooth:
            if sigma is None:
                raise ValueError(
                    "Smoothing is enabled, but no 'smoothing' value (the Gaussian "
                    "sigma) was given in the smoothing kwargs!"
                )
            y_arr = scipy.ndimage.gaussian_filter1d(y_arr, sigma, **_smooth_kwargs)

        # Default style, overridable through the plot kwargs
        fill_kwargs = {"alpha": 0.6, "lw": 2, **_plot_kwargs}
        handle = hlpr.ax.fill_betweenx(
            x_arr, _yfactor * y_arr, 0, label=label, **fill_kwargs
        )

        mean_handle = None
        if show_means:
            # Normalise by the total mass, so the mean is correct even if the
            # (possibly smoothed) density does not integrate exactly to 1
            mass = scipy.integrate.trapezoid(y_arr, x_arr)
            mean_x = scipy.integrate.trapezoid(x_arr * y_arr, x_arr)
            if mass:
                mean_x = mean_x / mass
            # Place the marker at the density value of the grid point closest to the mean
            mean_y = y_arr[np.argmin(np.abs(x_arr - mean_x))]
            mean_handle = hlpr.ax.scatter(
                _yfactor * mean_y, mean_x, **mean_kwargs, label="Mean"
            )

        mode_handle = None
        if show_modes:
            # nanargmax mirrors the NaN-skipping of xarray's argmax / max
            mode_idx = np.nanargmax(y_arr)
            mode_handle = hlpr.ax.scatter(
                _yfactor * y_arr[mode_idx], x_arr[mode_idx], **mode_kwargs, label="Mode"
            )

        return handle, mean_handle, mode_handle

    # Name used to look up per-parameter smoothing settings and to build the y-label
    if "parameter" in ds.coords:
        pname = ds.coords["parameter"].values.item()
    else:
        pname = next(iter(ds.coords), None)
    param_smooth_kwargs = smooth_kwargs.get(pname, smooth_kwargs)

    handles, labels = [], []
    # Mean/mode markers look identical for every hue entry; list each only once
    marker_handles = {}
    for i, coord in enumerate(ds.coords[hue].values):
        if x in ds.coords:
            x_vals = ds.coords[x]
        else:
            x_vals = ds[x].sel({hue: coord})
        y_vals = ds[y].sel({hue: coord})

        # Even entries go on the left (mirrored), odd entries on the right
        side = -1 if i % 2 == 0 else 1
        handle, mean_handle, mode_handle = _plot_1d(
            x_vals,
            y_vals,
            side,
            _smooth_kwargs=param_smooth_kwargs,
            label=f"{coord}",
            **plot_kwargs,
        )
        handles.append(handle)
        labels.append(f"{coord}")

        if mean_handle is not None:
            marker_handles.setdefault("Mean", mean_handle)
        if mode_handle is not None:
            marker_handles.setdefault("Mode", mode_handle)

    for marker_label, marker_handle in marker_handles.items():
        handles.append(marker_handle)
        labels.append(marker_label)

    if add_legend:
        if _is_facetgrid:
            hlpr.track_handles_labels(handles, labels)
            hlpr.provide_defaults("set_figlegend", title="")
        else:
            hlpr.ax.legend(handles, labels, title="")

    if format_y_label:
        # Drops the first two characters of the parameter name (presumably a prefix;
        # NOTE(review): confirm) and uses commas as subscript separators
        y_label = r"$\lambda_{\rm " + str(pname)[2:].replace("_", ",") + "}$"
        hlpr.provide_defaults("set_labels", y={"label": y_label})

    # Left-side distributions are drawn negated, so label the x-ticks with absolute
    # values. The first tick is dropped, as in the original implementation.
    xticks = hlpr.ax.get_xticks()[1:]
    hlpr.ax.set_xticks(xticks, labels=np.round(np.abs(xticks), 2))