Coverage for model_plots/prob_density.py: 18%

50 statements  

« prev     ^ index     » next       coverage.py v7.6.1, created at 2024-12-05 17:26 +0000

1import copy 

2from typing import Sequence, Union 

3 

4import scipy.ndimage 

5import xarray as xr 

6from dantro.plot.funcs._utils import plot_errorbar 

7from dantro.plot.funcs.generic import make_facet_grid_plot 

8 

9from utopya.eval import PlotHelper, is_plot_func 

10 

11 

def _find_parameter_name(ds: xr.Dataset, hue: str = None):
    """Return the name of the parameter shown in ``ds``, or ``None`` if it
    cannot be determined.

    If ``ds`` has a ``parameter`` coordinate, its value (which must be a
    single element, as ``.item()`` is used) is returned. Otherwise the
    non-scalar coordinates of ``ds`` other than the ``hue`` coordinate are
    considered; if several qualify, the last one in iteration order is used.
    If none qualify, ``None`` is returned instead of leaving the name unbound.
    """
    if "parameter" in ds.coords:
        return ds.coords["parameter"].values.item()

    pname = None
    for _c in ds.coords:
        # Skip scalar (0-d) coordinates; they cannot be the plotted dimension
        if ds.coords[_c].shape == ():
            continue
        # The hue coordinate distinguishes lines, it is not the parameter
        if hue is not None and _c == hue:
            continue
        pname = _c
    return pname


@make_facet_grid_plot(
    map_as="dataset",
    encodings=("x", "y", "yerr", "hue", "col", "row", "alpha", "lw"),
    supported_hue_styles=("discrete",),
    hue_style="discrete",
    add_guide=False,
    register_as_kind="density",
)
def plot_prob_density(
    ds: xr.Dataset,
    hlpr: PlotHelper,
    *,
    _is_facetgrid: bool,
    x: str,
    y: str,
    yerr: str = None,
    hue: str = None,
    label: str = None,
    add_legend: bool = True,
    smooth_kwargs: Union[dict, None] = None,
    linestyle: Union[str, Sequence] = "solid",
    **plot_kwargs,
):
    """Probability density plot for estimated parameters, which combines line-
    and errorband functionality into a single plot. Crucially, the x-value does
    not need to be a dataset coordinate. Is xarray facet_grid compatible.

    :param ds: dataset to plot
    :param hlpr: PlotHelper
    :param _is_facetgrid: whether the plot is a facet_grid instance or not
        (determined by the decorator function)
    :param x: coordinate or variable to use as the x-value
    :param y: values to plot onto the y-axis
    :param yerr: (optional) variable to use for the errorbands. If None, no
        errorbands are plotted.
    :param hue: (optional) variable to plot onto the hue dimension
    :param label: (optional) label for the plot, if the hue dimension is unused
    :param add_legend: whether to add a legend
    :param smooth_kwargs: (optional) dictionary for the smoothing settings.
        Smoothing can be set for all parameters or by parameter (keyed by the
        parameter name). The keys ``enabled`` and ``smoothing`` (the Gaussian
        sigma) are consumed here; all remaining keys are passed on to
        ``scipy.ndimage.gaussian_filter1d``.
    :param linestyle: a single linestyle, or a sequence of linestyles used
        per hue value. A sequence shorter than the number of hue values is
        cycled through.
    :param plot_kwargs: passed to the plot function
    """
    # Avoid a shared mutable default argument
    if smooth_kwargs is None:
        smooth_kwargs = {}

    def _plot_1d(*, _x, _y, _yerr, _smooth_kwargs, _ax, _label=None, **_plot_kwargs):
        """Plots a single parameter density onto ``_ax`` and optionally smooths
        it. Returns the artist(s) for the legend.

        Note that ``_smooth_kwargs`` is mutated (keys are popped), so callers
        must pass a private copy.
        """
        smooth = _smooth_kwargs.pop("enabled", False)
        sigma = _smooth_kwargs.pop("smoothing", None)

        # Smooth the y values, if requested
        if smooth:
            _y = scipy.ndimage.gaussian_filter1d(_y, sigma, **_smooth_kwargs)

        # Without yerr, plot a single line
        if _yerr is None:
            (line,) = _ax.plot(_x, _y, label=_label, **_plot_kwargs)
            return line

        # Otherwise plot errorbands, smoothing the error as well if requested
        if smooth:
            _yerr = scipy.ndimage.gaussian_filter1d(_yerr, sigma, **_smooth_kwargs)

        return plot_errorbar(
            ax=_ax,
            x=_x,
            y=_y,
            yerr=_yerr,
            label=_label,
            fill_between=True,
            **_plot_kwargs,
        )

    # Resolve the smoothing settings for this parameter once; fall back to
    # the global settings if there are no parameter-specific ones.
    pname = _find_parameter_name(ds, hue)
    param_smooth_kwargs = smooth_kwargs.get(pname, smooth_kwargs)

    if hue:
        # Track the legend handles and labels
        _handles, _labels = [], []

        for i, coord in enumerate(ds.coords[hue].values):
            sel = {hue: coord}

            # A coordinate x is shared by all hue values; a data variable is not
            x_vals = ds.coords[x] if x in ds.coords else ds[x].sel(sel)
            y_vals = ds[y].sel(sel)
            yerr_vals = ds[yerr].sel(sel) if yerr is not None else None

            if isinstance(linestyle, str):
                _ls = linestyle
            else:
                _ls = linestyle[i % len(linestyle)]

            handle = _plot_1d(
                _x=x_vals,
                _y=y_vals,
                _yerr=yerr_vals,
                # deepcopy: _plot_1d pops keys from the settings it receives
                _smooth_kwargs=copy.deepcopy(param_smooth_kwargs),
                _ax=hlpr.ax,
                _label=f"{coord}",
                linestyle=_ls,
                **plot_kwargs,
            )

            _handles.append(handle)
            _labels.append(f"{coord}")

        if not _is_facetgrid:
            if add_legend:
                hlpr.ax.legend(_handles, _labels, title=hue)
        else:
            # In a facet grid, collect handles for a single figure legend
            hlpr.track_handles_labels(_handles, _labels)
            if add_legend:
                hlpr.provide_defaults("set_figlegend", title=hue)

    else:
        x_vals = ds.coords[x] if x in ds.coords else ds[x]
        y_vals = ds[y]
        yerr_vals = ds[yerr] if yerr is not None else None

        _plot_1d(
            _x=x_vals,
            _y=y_vals,
            _yerr=yerr_vals,
            _ax=hlpr.ax,
            # deepcopy: _plot_1d pops keys from the settings it receives
            _smooth_kwargs=copy.deepcopy(param_smooth_kwargs),
            _label=label,
            linestyle=linestyle,
            **plot_kwargs,
        )