Coverage for model_plots / Covid / violinplot.py: 17%

52 statements  

« prev     ^ index     » next       coverage.py v7.13.1, created at 2026-01-21 16:26 +0000

1import copy 

2 

3import numpy as np 

4import scipy.integrate 

5import scipy.ndimage 

6import xarray as xr 

7from dantro.plot.funcs.generic import make_facet_grid_plot 

8 

9from utopya.eval import PlotHelper, is_plot_func 

10 

11 

@make_facet_grid_plot(
    map_as="dataset",
    encodings=("x", "y", "hue", "col", "row"),
    supported_hue_styles=("discrete",),
    hue_style="discrete",
    add_guide=False,
)
def violin_plot(
    ds: xr.Dataset,
    hlpr: PlotHelper,
    *,
    _is_facetgrid: bool,
    x: str,
    y: str,
    hue: str,
    add_legend: bool = True,
    show_means: bool = True,
    show_modes: bool = True,
    format_y_label: bool = False,
    mean_kwargs: dict = None,
    mode_kwargs: dict = None,
    smooth_kwargs: dict = None,
    **plot_kwargs,
):
    r"""Plot split violins of one-dimensional densities, one half per ``hue`` entry.

    For every entry along the ``hue`` coordinate, the density ``y`` is drawn over
    its support ``x`` with ``ax.fill_betweenx``. The ``x`` values run along the
    vertical axis and the density extends horizontally from zero. Consecutive hue
    entries alternate between the left (first, third, ...) and the right (second,
    fourth, ...) side, so the plot is only really readable if ``hue`` has at most
    two entries. The means and modes of the densities can optionally be marked
    with scatter points.

    :param ds: ``xr.Dataset`` holding the support and density values
    :param hlpr: ``PlotHelper`` instance
    :param _is_facetgrid: whether the plot is part of a facet grid. Controls
        whether the legend is drawn on the axis or tracked for a figure legend.
    :param x: name of the support variable (a coordinate or a data variable);
        its values are drawn along the vertical axis
    :param y: name of the density data variable; its values are drawn
        horizontally, mirrored to the left for every other ``hue`` entry
    :param hue: name of the coordinate whose entries are drawn alternately on
        the left and right side
    :param add_legend: whether to add a legend (a figure legend in facet grids)
    :param show_means: whether to mark the mean of each density
    :param show_modes: whether to mark the mode of each density
    :param format_y_label: whether to label the vertical axis in the Berlin
        SEIRD publication style ``$\lambda_{\rm X}$``; requires a ``parameter``
        coordinate
    :param mean_kwargs: kwargs for the mean markers, passed to ``ax.scatter``
    :param mode_kwargs: kwargs for the mode markers, passed to ``ax.scatter``
    :param smooth_kwargs: smoothing configuration, either global or keyed by the
        parameter name. ``enabled`` toggles smoothing and ``smoothing`` is the
        Gaussian ``sigma``. All remaining entries are passed to
        ``scipy.ndimage.gaussian_filter1d``.
    :param plot_kwargs: kwargs for the densities, passed to ``ax.fill_betweenx``
        (defaults: ``alpha=0.6``, ``lw=2``; both may be overridden)
    :raises ValueError: if smoothing is enabled but no ``smoothing`` value is set
    """
    # Resolve defaults here rather than in the signature to avoid shared
    # mutable default arguments.
    if mean_kwargs is None:
        mean_kwargs = dict(s=15, color="#48675A", lw=0.3, edgecolor="#3D4244")
    if mode_kwargs is None:
        mode_kwargs = dict(s=15, color="#F5DDA9", lw=0.3, edgecolor="#3D4244")
    if smooth_kwargs is None:
        smooth_kwargs = {}

    # Style defaults for the filled densities. User kwargs take precedence
    # instead of colliding with hard-coded keyword arguments.
    fill_kwargs = {"alpha": 0.6, **plot_kwargs}
    if "lw" not in fill_kwargs and "linewidth" not in fill_kwargs:
        fill_kwargs["lw"] = 2

    def _plot_1d(x_vals, y_vals, side: int, *, smooth_cfg: dict, label: str):
        """Draw one half-violin and, optionally, its mean and mode markers.

        ``side`` is -1 (left) or +1 (right). Returns the fill handle and the
        mean/mode marker handles (``None`` where a marker was not drawn).
        """
        # Plain numpy arrays keep the indexing below independent of whether
        # the values came from xarray or from the smoothing filter.
        support = np.asarray(x_vals)
        density = np.asarray(y_vals)

        # Pop from a copy: the same config is reused for every hue entry.
        opts = dict(smooth_cfg)
        enabled = opts.pop("enabled", False)
        sigma = opts.pop("smoothing", None)
        if enabled:
            if sigma is None:
                raise ValueError(
                    "Smoothing is enabled, but no 'smoothing' (sigma) value "
                    "was specified."
                )
            density = scipy.ndimage.gaussian_filter1d(density, sigma, **opts)

        handle = hlpr.ax.fill_betweenx(
            support, side * density, 0, **dict(fill_kwargs, label=label)
        )

        mean_handle = mode_handle = None
        if density.size == 0:
            # Nothing to locate a mean or mode on
            return handle, mean_handle, mode_handle

        if show_means:
            # Divide by the total mass so that unnormalized densities (and a
            # descending support) still yield the correct mean.
            mass = scipy.integrate.trapezoid(density, support)
            if np.isfinite(mass) and mass != 0:
                mean_x = scipy.integrate.trapezoid(support * density, support) / mass
                # Marker sits at the density of the grid point nearest the mean
                mean_y = density[np.argmin(np.abs(support - mean_x))]
                mean_handle = hlpr.ax.scatter(
                    side * mean_y, mean_x, **dict(mean_kwargs, label="Mean")
                )

        if show_modes:
            peak = np.argmax(density)
            mode_handle = hlpr.ax.scatter(
                side * density[peak],
                support[peak],
                **dict(mode_kwargs, label="Mode"),
            )

        return handle, mean_handle, mode_handle

    # Key used to look up per-parameter smoothing settings
    if "parameter" in ds.coords:
        pname = ds.coords["parameter"].values.item()
    else:
        pname = list(ds.coords)[0]
    smooth_cfg = smooth_kwargs.get(pname, smooth_kwargs)

    handles, labels = [], []
    mean_handle = mode_handle = None
    for i, coord in enumerate(ds.coords[hue].values):
        if x in ds.coords:
            x_vals = ds.coords[x]
        else:
            x_vals = ds[x].sel({hue: coord})
        y_vals = ds[y].sel({hue: coord})

        # First entry on the left (-1), second on the right (+1), ...
        side = (-1) ** (i + 1)
        handle, _mean, _mode = _plot_1d(
            x_vals, y_vals, side, smooth_cfg=smooth_cfg, label=f"{coord}"
        )
        handles.append(handle)
        labels.append(f"{coord}")

        # Mean/mode markers share one style, so one legend entry each suffices
        if _mean is not None:
            mean_handle = _mean
        if _mode is not None:
            mode_handle = _mode

    if mean_handle is not None:
        handles.append(mean_handle)
        labels.append("Mean")
    if mode_handle is not None:
        handles.append(mode_handle)
        labels.append("Mode")

    if add_legend:
        if _is_facetgrid:
            hlpr.track_handles_labels(handles, labels)
            hlpr.provide_defaults("set_figlegend", title="")
        else:
            hlpr.ax.legend(handles, labels, title="")

    if format_y_label:
        # e.g. "k_E_I" -> "$\lambda_{\rm E,I}$" (first two characters dropped)
        y_label = (
            r"$\lambda_{\rm "
            + ds.coords["parameter"].item()[2:].replace("_", ",")
            + "}$"
        )
        hlpr.provide_defaults("set_labels", y={"label": y_label})

    # Left-hand densities are drawn as negative values; label the horizontal
    # ticks with absolute values so both sides read as positive.
    # NOTE(review): the first auto-generated tick is dropped; the reason is
    # not visible here (possibly a tick outside the drawn range) — confirm.
    xticks = hlpr.ax.get_xticks()[1:]
    hlpr.ax.set_xticks(xticks, labels=np.round(np.abs(xticks), 2))