Coverage for model_plots / plots.py: 15%

88 statements  

« prev     ^ index     » next       coverage.py v7.13.1, created at 2026-01-21 16:26 +0000

1import xarray as xr 

2from dantro.plot.funcs.generic import make_facet_grid_plot 

3import copy 

4from typing import Union, Sequence 

5from utopya.eval import PlotHelper, is_plot_func 

6import scipy 

7from dantro.plot.funcs._utils import plot_errorbar 

8 

@make_facet_grid_plot(
    map_as="dataset",
    encodings=("x", "y", "yerr", "hue", "col", "row", "alpha", "lw"),
    supported_hue_styles=("discrete",),
    hue_style="discrete",
    add_guide=False,
    register_as_kind="density",
)
def plot_prob_density(
    ds: xr.Dataset,
    hlpr: PlotHelper,
    *,
    _is_facetgrid: bool,
    x: str,
    y: str,
    yerr: str = None,
    hue: str = None,
    label: str = None,
    add_legend: bool = True,
    smooth_kwargs: dict = None,
    linestyle: Union[str, Sequence] = "solid",
    **plot_kwargs,
):
    """Probability density plot for estimated parameters, which combines line-
    and errorband functionality into a single plot. Crucially, the x-value does
    not need to be a dataset coordinate. Is xarray facet_grid compatible.

    :param ds: dataset to plot
    :param hlpr: PlotHelper
    :param _is_facetgrid: whether the plot is a facet_grid instance or not
        (determined by the decorator function)
    :param x: coordinate or variable to use as the x-value
    :param y: values to plot onto the y-axis
    :param yerr: (optional) variable to use for the errorbands. If None, no
        errorbands are plotted.
    :param hue: (optional) variable to plot onto the hue dimension
    :param label: (optional) label for the plot, if the hue dimension is unused
    :param add_legend: whether to add a legend; only applies if ``hue`` is set
    :param smooth_kwargs: (optional) dictionary for the smoothing settings. The
        key ``enabled`` (bool) switches smoothing on, ``smoothing`` is the
        Gaussian sigma, and all remaining keys are passed on to
        ``scipy.ndimage.gaussian_filter1d``. Settings can be given for all
        parameters at once, or per parameter by nesting them under the
        parameter name.
    :param linestyle: line style; if ``hue`` is used, a sequence may be given
        to set one style per hue value (indexed by hue position)
    :param plot_kwargs: passed to the plot function
    """
    # Explicit submodule import instead of relying on ``scipy.ndimage`` being
    # reachable as an attribute after a bare ``import scipy``.
    from scipy.ndimage import gaussian_filter1d

    # Avoid a shared mutable default; an empty dict means "no smoothing"
    if smooth_kwargs is None:
        smooth_kwargs = {}

    def _plot_1d(*, _x, _y, _yerr, _smooth_kwargs, _ax, _label=None, **_plot_kwargs):
        """Plot a single parameter density onto ``_ax``, optionally smoothing
        ``_y`` (and ``_yerr``) with a 1D Gaussian filter. Returns the artist(s)
        to use as the legend handle.

        ``_smooth_kwargs`` is modified in place (the control keys are popped),
        so callers must pass a copy.
        """
        smooth = _smooth_kwargs.pop("enabled", False)
        sigma = _smooth_kwargs.pop("smoothing", None)

        # Smooth the y values, if requested
        if smooth:
            _y = gaussian_filter1d(_y, sigma, **_smooth_kwargs)

        # Without an error variable, draw a plain line
        if _yerr is None:
            (line,) = _ax.plot(_x, _y, label=_label, **_plot_kwargs)
            return line

        # Otherwise draw the line together with a shaded errorband, smoothing
        # the error with the same settings
        if smooth:
            _yerr = gaussian_filter1d(_yerr, sigma, **_smooth_kwargs)

        return plot_errorbar(
            ax=_ax,
            x=_x,
            y=_y,
            yerr=_yerr,
            label=_label,
            fill_between=True,
            **_plot_kwargs,
        )

    # Determine the parameter name, used to look up per-parameter smoothing
    # settings. If no candidate coordinate is found, ``pname`` stays None and
    # the top-level ``smooth_kwargs`` are used (previously: UnboundLocalError).
    pname = None
    if "parameter" in ds.coords:
        pname = ds.coords["parameter"].values.item()
    else:
        for _c in ds.coords:
            # Skip scalar (0-d) coordinates and the hue coordinate; the last
            # remaining coordinate is used
            if ds.coords[_c].shape == () or (hue is not None and _c == hue):
                continue
            pname = _c

    if hue:
        # Track the legend handles and labels
        handles, labels = [], []
        for i, coord in enumerate(ds.coords[hue].values):
            sel = {hue: coord}
            # x may be a shared coordinate or a hue-dependent data variable
            x_vals = ds.coords[x] if x in ds.coords else ds[x].sel(sel)
            y_vals = ds[y].sel(sel)
            yerr_vals = ds[yerr].sel(sel) if yerr is not None else None

            handle = _plot_1d(
                _x=x_vals,
                _y=y_vals,
                _yerr=yerr_vals,
                _smooth_kwargs=copy.deepcopy(smooth_kwargs.get(pname, smooth_kwargs)),
                _ax=hlpr.ax,
                _label=f"{coord}",
                linestyle=linestyle if isinstance(linestyle, str) else linestyle[i],
                **plot_kwargs,
            )
            handles.append(handle)
            labels.append(f"{coord}")

        if not _is_facetgrid:
            if add_legend:
                hlpr.ax.legend(handles, labels, title=hue)
        else:
            # Register the handles with the helper; the figure-level legend is
            # configured through the ``set_figlegend`` defaults
            hlpr.track_handles_labels(handles, labels)
            if add_legend:
                hlpr.provide_defaults("set_figlegend", title=hue)

    else:
        x_vals = ds.coords[x] if x in ds.coords else ds[x]
        y_vals = ds[y]
        yerr_vals = ds[yerr] if yerr is not None else None

        _plot_1d(
            _x=x_vals,
            _y=y_vals,
            _yerr=yerr_vals,
            _ax=hlpr.ax,
            _smooth_kwargs=copy.deepcopy(smooth_kwargs.get(pname, smooth_kwargs)),
            _label=label,
            linestyle=linestyle,
            **plot_kwargs,
        )


@make_facet_grid_plot(
    map_as="dataset",
    encodings=("x", "y", "hue", "col", "row"),
    supported_hue_styles=("discrete",),
    hue_style="discrete",
    add_guide=False,
    register_as_kind="line_and_scatter",
)
def line_and_scatter(
    ds: xr.Dataset,
    hlpr: PlotHelper,
    *,
    _is_facetgrid: bool,
    x: str = None,
    y: str = None,
    scatter: str,
    hue: str,
    add_legend: bool = True,
    line_kwargs: dict = None,
    scatter_kwargs: dict = None,
):
    """Combined line and scatter plot.

    For every value along the ``hue`` coordinate, ``y`` is drawn as a line and
    ``scatter`` as scatter points, both against the ``x`` coordinate. The legend
    holds one entry per hue value plus a single grey 'True data' entry standing
    in for the scatter points.

    :param ds: dataset to plot
    :param hlpr: PlotHelper
    :param _is_facetgrid: whether the plot is a facet_grid instance or not
        (determined by the decorator function)
    :param x: name of the coordinate of ``ds`` to use for the x-values
    :param y: data variable drawn as lines
    :param scatter: data variable drawn as scatter points
    :param hue: coordinate to iterate over; one line and one scatter per value
    :param add_legend: whether to add a legend
    :param line_kwargs: (optional) passed to ``ax.plot``
    :param scatter_kwargs: (optional) passed to ``ax.scatter``
    """
    from matplotlib.lines import Line2D

    # Avoid shared mutable defaults
    line_kwargs = {} if line_kwargs is None else line_kwargs
    scatter_kwargs = {} if scatter_kwargs is None else scatter_kwargs

    handles, labels = [], []
    scatter_artist = None
    for coord in ds.coords[hue].values:
        x_vals = ds.coords[x].data
        lines = hlpr.ax.plot(x_vals, ds[y].sel({hue: coord}), **line_kwargs, label=coord)
        scatter_artist = hlpr.ax.scatter(
            x_vals, ds[scatter].sel({hue: coord}), **scatter_kwargs, label=None
        )
        handles.append(lines[0])
        labels.append(f"{coord}")

    # Proxy legend entry for the scatter points, copying the marker shape and
    # size of the last scatter. Skipped if the hue coordinate is empty, where
    # no scatter was drawn (previously: UnboundLocalError).
    if scatter_artist is not None:
        true_data_handle = Line2D(
            [],
            [],
            marker=scatter_artist.get_paths()[0],
            # scatter sizes are areas (points**2); markersize is a length
            markersize=scatter_artist.get_sizes()[0] ** 0.5,
            linestyle="None",
            color="grey",
            label="True data",
            markerfacecolor="grey",
            markeredgecolor="grey",
        )
        handles.append(true_data_handle)
        labels.append("True data")

    # Add legend
    if not _is_facetgrid:
        if add_legend:
            hlpr.ax.legend(handles, labels, title=hue)
    else:
        hlpr.track_handles_labels(handles, labels)
        if add_legend:
            hlpr.provide_defaults("set_figlegend", title=hue)


@make_facet_grid_plot(
    map_as="dataset",
    encodings=("x", "y", "hue", "col", "row"),
    supported_hue_styles=("discrete",),
    hue_style="discrete",
    add_guide=False,
    register_as_kind="errorbands_and_scatter",
)
def errorbands_and_scatter(
    ds: xr.Dataset,
    hlpr: PlotHelper,
    *,
    _is_facetgrid: bool,
    x: str = None,
    y: str,
    yerr: str,
    scatter: str,
    hue: str,
    add_legend: bool = True,
    line_kwargs: dict = None,
    scatter_kwargs: dict = None,
):
    """Combined errorband and scatter plot.

    For every value along the ``hue`` coordinate, ``y`` is drawn as a line with
    a shaded errorband of width ``yerr`` (``plot_errorbar`` with
    ``fill_between=True``). ``scatter`` is drawn as scatter points on top, all
    against the ``x`` coordinate. The legend holds one entry per hue value plus
    a single grey 'True data' entry standing in for the scatter points.

    :param ds: dataset to plot
    :param hlpr: PlotHelper
    :param _is_facetgrid: whether the plot is a facet_grid instance or not
        (determined by the decorator function)
    :param x: name of the coordinate of ``ds`` to use for the x-values
    :param y: data variable drawn as the central line
    :param yerr: data variable giving the errorband
    :param scatter: data variable drawn as scatter points
    :param hue: coordinate to iterate over; one errorband and scatter per value
    :param add_legend: whether to add a legend
    :param line_kwargs: (optional) passed to ``plot_errorbar``
    :param scatter_kwargs: (optional) passed to ``ax.scatter``
    """
    from matplotlib.lines import Line2D

    # Avoid shared mutable defaults
    line_kwargs = {} if line_kwargs is None else line_kwargs
    scatter_kwargs = {} if scatter_kwargs is None else scatter_kwargs

    handles, labels = [], []
    scatter_artist = None
    for coord in ds.coords[hue].values:
        x_vals = ds.coords[x].data
        ds_sel = ds.sel({hue: coord})
        handle = plot_errorbar(
            ax=hlpr.ax,
            x=x_vals,
            y=ds_sel[y],
            yerr=ds_sel[yerr],
            label=f"{coord}",
            fill_between=True,
            **line_kwargs,
        )
        scatter_artist = hlpr.ax.scatter(
            x_vals, ds[scatter].sel({hue: coord}), **scatter_kwargs, label=None
        )
        handles.append(handle)
        labels.append(f"{coord}")

    # Proxy legend entry for the scatter points, copying the marker shape and
    # size of the last scatter. Skipped if the hue coordinate is empty, where
    # no scatter was drawn (previously: UnboundLocalError).
    if scatter_artist is not None:
        true_data_handle = Line2D(
            [],
            [],
            marker=scatter_artist.get_paths()[0],
            # scatter sizes are areas (points**2); markersize is a length
            markersize=scatter_artist.get_sizes()[0] ** 0.5,
            linestyle="None",
            color="grey",
            label="True data",
            markerfacecolor="grey",
            markeredgecolor="grey",
        )
        handles.append(true_data_handle)
        labels.append("True data")

    # Add legend
    if not _is_facetgrid:
        if add_legend:
            hlpr.ax.legend(handles, labels, title=hue)
    else:
        hlpr.track_handles_labels(handles, labels)
        if add_legend:
            hlpr.provide_defaults("set_figlegend", title=hue)