from typing import Any, Sequence, Union

import numpy as np
import pandas as pd
import scipy.integrate
import scipy.signal
import scipy.stats
import xarray as xr

from utopya.eval import is_operation

from ._op_utils import _get_hist_bins_ranges, _hist, _interpolate, apply_along_dim

# --- Custom DAG operations for the NeuralABM model --------------------------------------------------------------------


# ----------------------------------------------------------------------------------------------------------------------
# DATA RESHAPING AND REORGANIZING
# ----------------------------------------------------------------------------------------------------------------------
@is_operation("concat_along")
def concat(objs: Sequence, name: str, dims: Sequence, *args, **kwargs):
    """Combines the ``pd.Index`` and ``xr.concat`` functions into one.

    :param objs: the xarray objects to be concatenated
    :param name: the name of the new dimension
    :param dims: the coordinates of the new dimension
    :param args: passed to ``xr.concat``
    :param kwargs: passed to ``xr.concat``
    :return: objects concatenated along the new dimension
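
    Example (an illustrative sketch; the arrays and names are placeholders)::

        import numpy as np
        import xarray as xr

        a = xr.DataArray(np.zeros(3), dims=["time"])
        b = xr.DataArray(np.ones(3), dims=["time"])

        # Concatenate along a new "seed" dimension with coordinates [0, 1]
        res = concat([a, b], name="seed", dims=[0, 1])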

    """
    return xr.concat(objs, pd.Index(dims, name=name), *args, **kwargs)


@is_operation("flatten_dims")
@apply_along_dim
def flatten_dims(
    ds: Union[xr.Dataset, xr.DataArray],
    *,
    dims: dict,
    new_coords: Sequence = None,
) -> Union[xr.Dataset, xr.DataArray]:
    """Flattens dimensions of an xarray object into a new dimension. New coordinates can be assigned;
    otherwise, the new dimension is simply given trivial integer coordinates. The operation is a combination
    of stacking and subsequently dropping the multiindex.

    :param ds: the xarray object to reshape
    :param dims: a dictionary, keyed by the name of the new dimension, with the dimensions to be flattened as the value
    :param new_coords: (optional) coordinates for the new dimension
    :return: the xarray object with flattened dimensions
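
    Example (an illustrative sketch; assumes the decorated function can be called
    directly on small toy data)::

        import numpy as np
        import xarray as xr

        da = xr.DataArray(
            np.zeros((2, 3)),
            dims=["a", "b"],
            coords={"a": [0, 1], "b": [0, 1, 2]},
        )

        # Flatten "a" and "b" into a single "sample" dimension of length 6
        flattened = flatten_dims(da, dims={"sample": ["a", "b"]})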

    """
    new_dim, dims_to_stack = list(dims.keys())[0], list(dims.values())[0]

    # Check if the new dimension name already exists. If it already exists, use a temporary name for
    # the new dimension and switch back later
    _renamed = False
    if new_dim in list(ds.coords.keys()):
        new_dim = f"__{new_dim}__"
        _renamed = True

    # Stack and drop the dimensions
    ds = ds.stack({new_dim: dims_to_stack})
    q = set(dims_to_stack)
    q.add(new_dim)
    ds = ds.drop_vars(q)

    # Rename the stacked dimension back to the originally intended name
    if _renamed:
        ds = ds.rename({new_dim: list(dims.keys())[0]})
        new_dim = list(dims.keys())[0]

    # Add coordinates to the new dimension and return
    if new_coords is None:
        return ds.assign_coords({new_dim: np.arange(len(ds.coords[new_dim]))})
    else:
        return ds.assign_coords({new_dim: new_coords})


@is_operation("broadcast")
@apply_along_dim
def broadcast(
    ds1: xr.DataArray, ds2: xr.DataArray, *, x: str = "x", p: str = "loss", **kwargs
) -> xr.Dataset:
    """Broadcasts two ``xr.DataArray`` objects together and returns a dataset with given ``x`` and ``p`` as variable names.

    :param ds1: the first array
    :param ds2: the second array
    :param x: name for the new first variable
    :param p: name for the new second variable
    :param kwargs: passed on to ``xr.broadcast``
    :return: ``xr.Dataset`` with variables ``x`` and ``p``
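
    Example (an illustrative sketch; the arrays are random placeholders)::

        import numpy as np
        import xarray as xr

        samples = xr.DataArray(np.linspace(0, 1, 5), dims=["sample"])
        loss = xr.DataArray(np.random.rand(5), dims=["sample"])

        # Dataset with variables "x" and "loss", broadcast over "sample"
        ds = broadcast(samples, loss, x="x", p="loss")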

    """
    return xr.broadcast(xr.Dataset({x: ds1, p: ds2}), **kwargs)[0]


# ----------------------------------------------------------------------------------------------------------------------
# BASIC STATISTICS FUNCTIONS
# ----------------------------------------------------------------------------------------------------------------------
@is_operation("stat")
@apply_along_dim
def stat_function(
    data: xr.Dataset, *, stat: str, x: str, y: str, **kwargs
) -> Union[xr.DataArray, xr.Dataset]:
    """Basic statistical function which returns statistical properties of a one-dimensional dataset representing
    x- and y(x)-values.

    :param data: ``xr.Dataset`` along which to calculate the statistic. The dataset must contain the ``y`` key as
        a variable, but ``x`` may also be a coordinate name.
    :param stat: type of statistic to calculate: can be ``mean``, ``std``, ``iqr``, ``mode``, or ``avg_peak_width``.
        When calculating the mode, both the x-value and the y-value are returned, and when calculating peak widths,
        both the mean width and the standard deviation are calculated.
    :param x: label of the x-values; can be a variable in the dataset or a coordinate
    :param y: label of the function values
    :param kwargs: kwargs passed to the respective calculation function
    :return: the computed statistic
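
    Example (an illustrative sketch, assuming a normalised density)::

        import numpy as np
        import xarray as xr

        grid = np.linspace(-5, 5, 101)
        density = np.exp(-(grid**2) / 2) / np.sqrt(2 * np.pi)
        data = xr.Dataset({"x": ("bin", grid), "loss": ("bin", density)})

        m = stat_function(data, stat="mean", x="x", y="loss")  # approximately 0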

    """

    _permitted_stat_functions = ["mean", "std", "iqr", "mode", "avg_peak_width"]
    if stat not in _permitted_stat_functions:
        raise ValueError(
            f"Unrecognised stat function '{stat}'; choose from '{', '.join(_permitted_stat_functions)}'."
        )

    # x-values can be either a variable or a coordinate
    if x in data.coords.keys():
        _x_vals = data.coords[x]
    else:
        _x_vals = data[x]

    # Ignore nans in the values
    _x_vals, _y_vals = _x_vals[~np.isnan(data[y])], data[y][~np.isnan(data[y])]

    # ------------------------------------------------------------------------------------------------------------------
    # Expectation value: m = int f(x) x dx
    # ------------------------------------------------------------------------------------------------------------------

    if stat == "mean":
        _res = scipy.integrate.trapezoid(_y_vals * _x_vals, _x_vals, **kwargs)
        return xr.DataArray(_res, name=stat)

    # ------------------------------------------------------------------------------------------------------------------
    # Standard deviation: std^2 = int (x - m)^2 f(x) dx
    # ------------------------------------------------------------------------------------------------------------------

    elif stat == "std":
        _m = stat_function(data, x=x, y=y, stat="mean")
        _res = np.sqrt(
            scipy.integrate.trapezoid(_y_vals * (_x_vals - _m) ** 2, _x_vals, **kwargs)
        )
        return xr.DataArray(_res, name=stat)

    # ------------------------------------------------------------------------------------------------------------------
    # Interquartile range: length between the first and third quartile
    # ------------------------------------------------------------------------------------------------------------------

    elif stat == "iqr":
        _int = scipy.integrate.trapezoid(_y_vals, _x_vals, **kwargs)
        _a_0 = -1.0
        __int = 0.0
        _res = 0.0
        for i in range(1, len(_x_vals)):
            __int += scipy.integrate.trapezoid(
                _y_vals[i - 1 : i + 1], _x_vals[i - 1 : i + 1], **kwargs
            )
            if __int > 0.25 * _int and _a_0 == -1:
                _a_0 = _x_vals[i].item()
            if __int > 0.75 * _int:
                _res = _x_vals[i].item() - _a_0
                break
        return xr.DataArray(_res, name=stat)

    # ------------------------------------------------------------------------------------------------------------------
    # Mode: both the x-value and the y-value of the mode are returned
    # ------------------------------------------------------------------------------------------------------------------

    elif stat == "mode":
        # Get the index of the mode and select it
        idx_max = np.argmax(_y_vals.data, **kwargs)
        mode_x = _x_vals[idx_max]
        mode_y = _y_vals[idx_max]
        return xr.Dataset(data_vars=dict(mode_x=mode_x, mode_y=mode_y))

    # ------------------------------------------------------------------------------------------------------------------
    # Average peak width: both the mean width and the standard deviation of the widths are returned
    # ------------------------------------------------------------------------------------------------------------------

    elif stat == "avg_peak_width":
        if "width" not in kwargs:
            raise ValueError("'width' kwarg required for 'scipy.signal.find_peaks'!")

        # Pad the array with a zero at the beginning and the end to ensure peaks at the boundaries are found
        _y_vals = np.concatenate(([0], _y_vals, [0]))

        # Find the peaks along the array
        peaks = scipy.signal.find_peaks(_y_vals, **kwargs)

        # Calculate the mean and standard deviation of the peak widths
        # (assumes uniformly spaced x-values)
        mean, std = (
            np.mean(peaks[1]["widths"]) * np.diff(_x_vals)[0],
            np.std(peaks[1]["widths"]) * np.diff(_x_vals)[0],
        )

        return xr.Dataset(data_vars=dict(mean_peak_width=mean, peak_width_std=std))

@is_operation("mean")
@apply_along_dim
def mean(*args, **kwargs) -> xr.DataArray:
    """Computes the mean of a dataset"""
    return stat_function(*args, stat="mean", **kwargs)


@is_operation("std")
@apply_along_dim
def std(*args, **kwargs) -> xr.DataArray:
    """Computes the standard deviation of a dataset"""
    return stat_function(*args, stat="std", **kwargs)


@is_operation("iqr")
@apply_along_dim
def iqr(*args, **kwargs) -> xr.DataArray:
    """Computes the interquartile range of a dataset"""
    return stat_function(*args, stat="iqr", **kwargs)


@is_operation("mode")
@apply_along_dim
def mode(*args, **kwargs) -> xr.Dataset:
    """Computes the mode of a dataset"""
    return stat_function(*args, stat="mode", **kwargs)


@is_operation("avg_peak_width")
@apply_along_dim
def avg_peak_width(*args, **kwargs) -> xr.Dataset:
    """Computes the average peak width and the standard deviation of the peak widths of a dataset"""
    return stat_function(*args, stat="avg_peak_width", **kwargs)

@is_operation("p_value")
@apply_along_dim
def p_value(
    data: xr.Dataset, point: Any, *, x: str, y: str, null: str = "mean"
) -> xr.DataArray:
    """Calculates the p-value of a ``point`` from a Dataset containing x-y-pairs. The integral under y is
    assumed to be normalised for the p-value to be meaningful. The p-value can be calculated with respect
    to either the mean or the mode of the distribution.

    :param data: ``xr.Dataset`` containing the x and p(x) values
    :param point: point at which to calculate the p-value
    :param x: label of the x-values; can be a variable in the dataset or a coordinate
    :param y: label of the function values; assumed to be normalised
    :param null: (optional) null hypothesis with respect to which the p-value is to be calculated; can be
        either ``mean`` or ``mode``
    :return: ``xr.DataArray`` of the p-value of ``point``
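
    Example (an illustrative sketch, assuming a normalised density)::

        import numpy as np
        import xarray as xr

        grid = np.linspace(-5, 5, 101)
        density = np.exp(-(grid**2) / 2) / np.sqrt(2 * np.pi)
        data = xr.Dataset({"x": ("bin", grid), "loss": ("bin", density)})

        # One-sided p-value with respect to the mean; approximately 0.025
        p = p_value(data, 1.96, x="x", y="loss")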

    """

    # x can be either a variable or a coordinate
    if x in data.coords.keys():
        _x_vals = data.coords[x]
    else:
        _x_vals = data[x]

    if isinstance(point, xr.DataArray):
        point = point.data

    # Calculate the value of the null of the distribution
    m = (
        mean(data, x=x, y=y).data
        if null == "mean"
        else mode(data, x=x, y=y)["mode_x"].data
    )

    # Calculate the index of the point
    t_index = np.argmin(np.abs(_x_vals - point).data)

    # Calculate the p-value depending on the location of the point
    if point >= m:
        return xr.DataArray(
            scipy.integrate.trapezoid(data[y][t_index:], _x_vals[t_index:]),
            name="p_value",
        )

    else:
        return xr.DataArray(
            scipy.integrate.trapezoid(data[y][:t_index], _x_vals[:t_index]),
            name="p_value",
        )


@is_operation("normalize")
@apply_along_dim
def normalize(
    distribution: xr.Dataset, *, x: str, y: str, norm: float = 1, **kwargs
) -> xr.Dataset:
    """Normalises a probability distribution of x- and y-values.

    :param distribution: ``xr.Dataset`` of x- and y-values
    :param x: label of the x-values
    :param y: label of the function values
    :param norm: (optional) value to which to normalise the distribution
    :param kwargs: passed to ``scipy.integrate.trapezoid``
    :return: the normalised probability distribution
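
    Example (an illustrative sketch; the values are placeholders)::

        import numpy as np
        import xarray as xr

        grid = np.linspace(0, 1, 11)
        dist = xr.Dataset({"x": ("bin", grid), "y": ("bin", 2 * np.ones(11))})

        # The integral of y over x is 2 before the call, and 1 afterwards
        dist = normalize(dist, x="x", y="y")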

    """

    integral = scipy.integrate.trapezoid(distribution[y], distribution[x], **kwargs)
    distribution[y] *= norm / integral
    return distribution


# ----------------------------------------------------------------------------------------------------------------------
# HISTOGRAMS
# ----------------------------------------------------------------------------------------------------------------------


@is_operation("hist")
@apply_along_dim
def hist(
    da: xr.DataArray,
    bins: Any = 100,
    ranges: Any = None,
    *,
    dim: str = None,
    axis: int = None,
    normalize: Union[float, bool] = False,
    use_bins_as_coords: bool = False,
    **kwargs,
) -> Union[xr.Dataset, xr.DataArray]:
    """Applies ``np.histogram`` using the ``apply_along_dim`` decorator to allow histogramming along multiple
    dimensions. This function applies ``np.histogram`` along a single axis of an ``xr.DataArray`` object;
    it is recommended to only use ``apply_along_dim`` across small dimensions, as splitting and recombining the
    xarray objects is very expensive.

    :param da: the ``xr.DataArray`` on which to apply the histogram function
    :param bins: the bins to use, passed to ``np.histogram``. This can be a single integer, in which case it is
        interpreted as the number of bins, a Sequence defining the bin edges, or a string defining the method to use.
        See ``np.histogram`` for details.
    :param ranges: (optional) the lower and upper range of the bins
    :param dim: (optional) the dimension along which to apply the operation. If not passed, an ``axis`` argument
        must be provided
    :param axis: (optional) the axis along which to apply ``np.histogram``
    :param normalize: whether to normalize the counts. Can be a boolean or a float, in which case the counts are
        normalized to that value
    :param use_bins_as_coords: whether to use the bin centres as coordinates of the dataset, or as variables. If true,
        an ``xr.DataArray`` is returned, with the bin centres as coordinates and the counts as the data. This may
        cause incompatibilities with ``apply_along_dim``, since different samples have different bin centres. For this
        reason, the default behaviour is to return an ``xr.Dataset`` with the bin centres and counts as variables,
        and ``bin_idx`` as the coordinate. This enables combining different histograms with different bin centres
        (but the same number of bins) into a single dataset. If passed, ``ranges`` must also be passed to ensure
        all histogram bins are identical.
    :param kwargs: passed to ``np.histogram``
    :return: ``xr.DataArray`` or ``xr.Dataset`` containing the bin centres either as coordinates or as variables,
        and the counts
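
    Example (an illustrative sketch; assumes the helper functions accept a plain
    list as the ``ranges`` argument)::

        import numpy as np
        import xarray as xr

        da = xr.DataArray(np.random.rand(10, 1000), dims=["seed", "sample"])

        # Histogram along "sample", using the bin centres as coordinates
        h = hist(da, bins=50, ranges=[0, 1], dim="sample", use_bins_as_coords=True)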

    """
    if dim is None and axis is None:
        raise ValueError("Must supply either 'dim' or 'axis' arguments!")

    if use_bins_as_coords and ranges is None:
        raise ValueError(
            "Setting 'use_bins_as_coords' to 'True' requires passing a 'ranges' argument to "
            "ensure all coordinates are equal"
        )

    # Get the axis along which to apply the operations
    if dim is not None:
        axis = list(da.dims).index(dim)

    # Get the bins and range objects
    bins, ranges = _get_hist_bins_ranges(da, bins, ranges, axis)

    # Apply the histogram function along the axis
    res = np.apply_along_axis(
        _hist, axis, da.data, bins=bins, range=ranges, normalize=normalize, **kwargs
    )

    # Get the counts and the bin centres. Note that the bin centres are equal along every dimension!
    counts, bin_centres = np.take(res, 0, axis=axis), np.take(res, 1, axis=axis)

    # Put the dataset back together again, relabelling the coordinate dimension that was binned
    coords = dict(da.coords)

    # Bin centres are to be used as coordinates
    if use_bins_as_coords:
        sel = [0] * len(np.shape(bin_centres))
        sel[axis] = None
        bin_centres = bin_centres[tuple(sel)].flatten()
        coords.update({dim: bin_centres})

        res = xr.DataArray(
            counts,
            dims=list(da.sizes.keys()),
            coords=coords,
            name=da.name if da.name else "count",
        )
        return res.rename({dim: "x"})

    else:
        coords.update({dim: np.arange(np.shape(bin_centres)[axis])})
        other_dim = list(coords.keys())
        other_dim.remove(dim)
        attrs = [*other_dim, "bin_idx"] if other_dim else ["bin_idx"]
        coords["bin_idx"] = coords.pop(dim)

        return xr.Dataset(
            data_vars={
                da.name if da.name else "count": (attrs, counts),
                "x": (attrs, bin_centres),
            },
            coords=coords,
        )

# ----------------------------------------------------------------------------------------------------------------------
# DISTANCES BETWEEN PROBABILITY DENSITIES
# ----------------------------------------------------------------------------------------------------------------------
@is_operation("distances_between_distributions")
@apply_along_dim
def distances_between_distributions(
    P: Union[xr.DataArray, xr.Dataset],
    Q: Union[xr.DataArray, xr.Dataset],
    *,
    stat: str,
    p: float = 2,
    x: str = None,
    y: str = None,
    **kwargs,
) -> xr.DataArray:
    """Calculates distances between two distributions P and Q. Possible distances are:

    - Hellinger distance: d(P, Q) = 1/2 * integral (sqrt(P(x)) - sqrt(Q(x)))^2 dx
    - Relative entropy: d(P, Q) = integral P(x) log(P(x)/Q(x)) dx
    - Lp distance: d(P, Q) = (integral (P(x) - Q(x))^p dx)^(1/p)

    These distances are calculated on the common support of P and Q; if P and Q have different discretisation
    levels, the functions are interpolated.

    :param P: one-dimensional ``xr.DataArray`` or ``xr.Dataset`` of values for P. If an ``xr.Dataset``, ``x`` and ``y``
        arguments must be passed.
    :param Q: one-dimensional ``xr.DataArray`` or ``xr.Dataset`` of values for Q. If an ``xr.Dataset``, ``x`` and ``y``
        arguments must be passed.
    :param stat: which distance function to use; can be ``Hellinger``, ``relative_entropy``, or ``Lp``
    :param p: p-value for the Lp distance
    :param x: label of the x-values to use if P and Q are ``xr.Dataset`` s
    :param y: label of the y-values to use if P and Q are ``xr.Dataset`` s
    :param kwargs: passed on to ``scipy.integrate.trapezoid``
    :return: the distance between P and Q
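
    Example (an illustrative sketch; the distributions are unnormalised stand-ins)::

        import numpy as np
        import xarray as xr

        grid = np.linspace(-5, 5, 101)
        P = xr.DataArray(np.exp(-(grid**2)), dims="x", coords={"x": grid})
        Q = xr.DataArray(np.exp(-((grid - 1) ** 2)), dims="x", coords={"x": grid})

        d = distances_between_distributions(P, Q, stat="Hellinger")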

    """

    _permitted_stat_functions = ["Hellinger", "relative_entropy", "Lp"]
    if stat not in _permitted_stat_functions:
        raise ValueError(
            f"Unrecognised stat function '{stat}'; choose from '{', '.join(_permitted_stat_functions)}'."
        )

    # If P and Q are datasets, convert to DataArrays
    if isinstance(P, xr.Dataset):
        P = xr.DataArray(
            P[y], coords={"x": P[x] if x in list(P.data_vars) else P.coords[x]}
        )
    if isinstance(Q, xr.Dataset):
        Q = xr.DataArray(
            Q[y], coords={"x": Q[x] if x in list(Q.data_vars) else Q.coords[x]}
        )

    # Interpolate P and Q on their common support
    P, Q, grid = _interpolate(P, Q)

    # Hellinger distance
    if stat == "Hellinger":
        return xr.DataArray(
            0.5
            * scipy.integrate.trapezoid(
                np.square(np.sqrt(P) - np.sqrt(Q)), grid, **kwargs
            ),
            name="Hellinger_distance",
        )

    # Relative entropy
    elif stat == "relative_entropy":
        P, Q = np.where(P != 0, P, 1), np.where(Q != 0, Q, 1)
        return xr.DataArray(
            scipy.integrate.trapezoid(P * np.log(P / Q), grid, **kwargs),
            name="relative_entropy",
        )

    # Lp distance
    elif stat == "Lp":
        return xr.DataArray(
            scipy.integrate.trapezoid((P - Q) ** p, grid, **kwargs) ** (1 / p),
            name="Lp_distance",
        )

@is_operation("Hellinger_distance")
@apply_along_dim
def Hellinger_distance(*args, **kwargs) -> xr.DataArray:
    """Computes the Hellinger distance between two distributions"""
    return distances_between_distributions(*args, stat="Hellinger", **kwargs)


@is_operation("relative_entropy")
@apply_along_dim
def relative_entropy(*args, **kwargs) -> xr.DataArray:
    """Computes the relative entropy between two distributions"""
    return distances_between_distributions(*args, stat="relative_entropy", **kwargs)


@is_operation("Lp_distance")
@apply_along_dim
def Lp_distance(*args, **kwargs) -> xr.DataArray:
    """Computes the Lp distance between two distributions"""
    return distances_between_distributions(*args, stat="Lp", **kwargs)

# ----------------------------------------------------------------------------------------------------------------------
# PROBABILITY DENSITY FUNCTIONS
# ----------------------------------------------------------------------------------------------------------------------


@is_operation("joint_2D")
@apply_along_dim
def joint_2D(
    x: xr.DataArray,
    y: xr.DataArray,
    values: xr.DataArray,
    bins: Union[int, xr.DataArray] = 100,
    ranges: xr.DataArray = None,
    *,
    statistic: Union[str, callable] = "mean",
    normalize: Union[bool, float] = False,
    dx: float = None,
    dy: float = None,
    dim_names: Sequence = ("x", "y"),
    **kwargs,
) -> xr.DataArray:
    """
    Computes the two-dimensional joint distribution of a dataset of parameters by calling
    ``scipy.stats.binned_statistic_2d``. The function returns a statistic for each bin (typically the mean).

    :param x: DataArray of samples in the first dimension
    :param y: DataArray of samples in the second dimension
    :param values: DataArray of values to be binned
    :param bins: (optional) ``bins`` argument to ``scipy.stats.binned_statistic_2d``
    :param ranges: (optional) ``range`` argument to ``scipy.stats.binned_statistic_2d``
    :param statistic: (optional) ``statistic`` argument to ``scipy.stats.binned_statistic_2d``
    :param normalize: (optional) whether to normalize the joint (False by default), and the normalisation value
        (1 by default)
    :param dx: (optional) the spatial differential dx to use for normalisation. If provided, the norm will not be
        calculated by integrating against the x-values, but rather by assuming the coordinates are spaced ``dx`` apart
    :param dy: (optional) the spatial differential dy to use for normalisation. If provided, the norm will not be
        calculated by integrating against the y-values, but rather by assuming the coordinates are spaced ``dy`` apart
    :param dim_names: (optional) names of the two dimensions
    :param kwargs: passed to ``scipy.stats.binned_statistic_2d``
    :return: ``xr.DataArray`` of the joint distribution
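
    Example (an illustrative sketch; the sample arrays are random placeholders)::

        import numpy as np
        import xarray as xr

        a = xr.DataArray(np.random.rand(1000), dims=["sample"])
        b = xr.DataArray(np.random.rand(1000), dims=["sample"])
        loss = xr.DataArray(np.random.rand(1000), dims=["sample"])

        # Normalised joint distribution of the mean loss over a 20 x 20 grid
        joint = joint_2D(a, b, loss, bins=20, normalize=True)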

    """

    # Get the number of bins
    if isinstance(bins, xr.DataArray):
        bins = bins.data

    # Allow passing 'None' arguments in the plot config for certain entries of the range arg.
    # This allows clipping only on some dimensions without having to specify every limit
    if ranges is not None:
        ranges = (
            np.array(ranges.data)
            if isinstance(ranges, xr.DataArray)
            else np.array(ranges)
        )
        for idx in range(len(ranges)):
            if None in ranges[idx]:
                ranges[idx] = (
                    [np.min(x), np.max(x)] if idx == 0 else [np.min(y), np.max(y)]
                )
    else:
        ranges = kwargs.pop("range", None)

    # Get the statistics and bin edges
    stat, x_edge, y_edge, _ = scipy.stats.binned_statistic_2d(
        x, y, values, statistic=statistic, bins=bins, range=ranges, **kwargs
    )

    # Normalise the joint distribution, if specified
    if normalize:
        if dy is None:
            int_y = [
                scipy.integrate.trapezoid(
                    stat[i][~np.isnan(stat[i])],
                    0.5 * (y_edge[1:] + y_edge[:-1])[~np.isnan(stat[i])],
                )
                for i in range(stat.shape[0])
            ]
        else:
            int_y = [
                scipy.integrate.trapezoid(stat[i][~np.isnan(stat[i])], dx=dy)
                for i in range(stat.shape[0])
            ]

        norm = (
            scipy.integrate.trapezoid(int_y, 0.5 * (x_edge[1:] + x_edge[:-1]))
            if dx is None
            else scipy.integrate.trapezoid(int_y, dx=dx)
        )
        if norm == 0:
            norm = 1
        stat /= norm if isinstance(normalize, bool) else norm / normalize

    return xr.DataArray(
        data=stat,
        dims=dim_names,
        coords={
            dim_names[0]: 0.5 * (x_edge[1:] + x_edge[:-1]),
            dim_names[1]: 0.5 * (y_edge[1:] + y_edge[:-1]),
        },
        name="joint",
    )

@is_operation("joint_2D_ds")
@apply_along_dim
def joint_2D_ds(
    ds: Union[xr.DataArray, xr.Dataset],
    values: xr.DataArray,
    bins: xr.DataArray = 100,
    ranges: xr.DataArray = None,
    *,
    x: str,
    y: str,
    **kwargs,
) -> xr.DataArray:
    """Computes a two-dimensional joint from a single dataset with ``x`` and ``y`` given as variables, or from
    a DataArray with ``x`` and ``y`` given as coordinates of a ``parameter`` dimension."""

    if isinstance(ds, xr.Dataset):
        return joint_2D(ds[x], ds[y], values, bins, ranges, dim_names=(x, y), **kwargs)
    elif isinstance(ds, xr.DataArray):
        return joint_2D(
            ds.sel(dict(parameter=x)),
            ds.sel(dict(parameter=y)),
            values,
            bins,
            ranges,
            dim_names=(x, y),
            **kwargs,
        )

@is_operation("marginal_from_joint")
@apply_along_dim
def marginal_from_joint(
    joint: xr.DataArray,
    *,
    parameter: str,
    normalize: Union[bool, float] = True,
    scale_y_bins: bool = False,
) -> xr.Dataset:
    """
    Computes a marginal from a two-dimensional joint distribution by integrating over one parameter. Normalizes
    the marginal, if specified. NaN values in the joint are skipped when normalising: they are not zero, just unknown.
    Since x-values may differ for different parameters, the x-values are variables in a dataset, not coordinates.
    The coordinates are given by the bin index, thereby allowing marginals across multiple parameters to be combined
    into a single xr.Dataset.

    :param joint: the joint distribution over which to marginalise
    :param parameter: the parameter to retain; the joint is integrated over the other coordinate
    :param normalize: whether to normalize the marginal distribution. If true, normalizes to 1, else normalizes to
        a given value
    :param scale_y_bins: whether to scale the integration over y by the range of the given values (y_max - y_min)
    :return: ``xr.Dataset`` of the marginal, with x- and y-values as variables and the bin index as coordinate
    """

    # Get the integration coordinate
    integration_coord = [c for c in list(joint.coords) if c != parameter][0]

    # Marginalise over the integration coordinate
    marginal = np.array([])
    for p in joint.coords[parameter]:
        _y, _x = joint.sel({parameter: p}).data, joint.coords[integration_coord]
        if scale_y_bins and not np.isnan(_y).all():
            _f = np.nanmax(_y) - np.nanmin(_y)
            _f = 1.0 / _f if _f != 0 else 1.0
        else:
            _f = 1.0
        marginal = np.append(
            marginal,
            _f * scipy.integrate.trapezoid(_y[~np.isnan(_y)], _x[~np.isnan(_y)]),
        )

    # Normalise, if given
    if normalize:
        norm = scipy.integrate.trapezoid(marginal, joint.coords[parameter])
        if norm == 0:
            norm = 1
        marginal /= norm if isinstance(normalize, bool) else norm / normalize

    # Return a dataset with x- and y-values as variables, and coordinates given by the bin index.
    # This allows combining different marginals with different x-values but an identical number of bins
    # into a single dataset
    return xr.Dataset(
        data_vars=dict(
            x=(["bin_idx"], joint.coords[parameter].data),
            y=(["bin_idx"], marginal),
        ),
        coords=dict(
            bin_idx=(["bin_idx"], np.arange(len(joint.coords[parameter].data)))
        ),
    )

@is_operation("marginal")
@apply_along_dim
def marginal(
    x: xr.DataArray,
    prob: xr.DataArray,
    bins: Union[int, xr.DataArray] = None,
    ranges: Union[Sequence, xr.DataArray] = None,
    *,
    parameter: str = "x",
    normalize: Union[bool, float] = True,
    scale_y_bins: bool = False,
    **kwargs,
) -> xr.Dataset:
    """
    Computes a marginal directly from an ``xr.DataArray`` of x-values and an ``xr.DataArray`` of probabilities by first
    computing the joint distribution and then marginalising over the probability. This way, points that are sampled
    multiple times only contribute once to the marginal, which is thus not a representation of the frequency with which
    each point is sampled, but of the calculated likelihood function.

    :param x: array of samples of the first variable (the parameter estimates)
    :param prob: array of samples of (unnormalised) probability values
    :param bins: bins to use for both dimensions
    :param ranges: ranges to use for both dimensions. Defaults to the minimum and maximum along each dimension
    :param parameter: the parameter for which to compute the marginal. Defaults to the first dimension.
    :param normalize: whether to normalize the marginal
    :param scale_y_bins: whether to scale the integration over y by the range of the given values (y_max - y_min)
    :param kwargs: other kwargs, passed to ``joint_2D``
    :return: ``xr.Dataset`` of the marginal densities
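
    Example (an illustrative sketch; the arrays are random placeholders)::

        import numpy as np
        import xarray as xr

        estimates = xr.DataArray(np.random.rand(1000), dims=["sample"])
        loss = xr.DataArray(np.random.rand(1000), dims=["sample"])

        # Marginal density of the estimates, weighted by the loss
        m = marginal(estimates, loss, bins=50)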

    """
    joint = joint_2D(x, prob, prob, bins, ranges, normalize=normalize, **kwargs)
    return marginal_from_joint(
        joint, parameter=parameter, normalize=normalize, scale_y_bins=scale_y_bins
    )


@is_operation("marginal_from_ds")
@apply_along_dim
def marginal_from_ds(
    ds: xr.Dataset,
    bins: xr.DataArray = 100,
    ranges: xr.DataArray = None,
    *,
    x: str,
    y: str,
    **kwargs,
) -> xr.Dataset:
    """Computes the marginal from a single dataset with x and y given as variables."""
    return marginal(ds[x], ds[y], bins, ranges, **kwargs)

@is_operation("joint_DD")
@apply_along_dim
def joint_DD(
    sample: xr.DataArray,
    values: xr.DataArray,
    bins: Union[int, xr.DataArray] = 100,
    ranges: xr.DataArray = None,
    *,
    statistic: Union[str, callable] = "mean",
    normalize: Union[bool, float] = False,
    dim_names: Sequence = None,
    **kwargs,
) -> xr.DataArray:
    """
    Computes the d-dimensional joint distribution of a dataset of parameters by calling
    ``scipy.stats.binned_statistic_dd``. This function can handle at most 32 parameters. A statistic for each bin
    is returned (the mean by default).

    :param sample: ``xr.DataArray`` of samples of shape ``(N, D)``
    :param values: ``xr.DataArray`` of values to be binned, of shape ``(N, )``
    :param bins: (optional) ``bins`` argument to ``scipy.stats.binned_statistic_dd``
    :param ranges: (optional) ``range`` argument to ``scipy.stats.binned_statistic_dd``
    :param statistic: (optional) ``statistic`` argument to ``scipy.stats.binned_statistic_dd``
    :param normalize: (not implemented) whether to normalize the joint (False by default),
        and the normalisation value (1 by default)
    :param dim_names: (optional) names of the dimensions; by default, read off from the last dimension of ``sample``
    :param kwargs: passed to ``scipy.stats.binned_statistic_dd``
    :return: ``xr.DataArray`` of the joint distribution
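
    Example (an illustrative sketch; names and shapes are placeholders)::

        import numpy as np
        import xarray as xr

        # 1000 samples of 3 parameters, and one value per sample
        sample = xr.DataArray(
            np.random.rand(1000, 3),
            dims=["sample", "parameter"],
            coords={"parameter": ["a", "b", "c"]},
        )
        values = xr.DataArray(np.random.rand(1000), dims=["sample"])

        joint = joint_DD(sample, values, bins=10)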

    """
    if normalize:
        raise NotImplementedError(
            "Normalisation for d-dimensional joints is not yet implemented!"
        )

    # Get the number of bins
    if isinstance(bins, xr.DataArray):
        bins = bins.data

    dim_names = (
        sample.coords[list(sample.dims)[-1]].data if dim_names is None else dim_names
    )

    # Allow passing 'None' arguments in the plot config for certain entries of the range arg.
    # This allows clipping only on some dimensions without having to specify every limit
    if ranges is not None:
        ranges = ranges.data if isinstance(ranges, xr.DataArray) else ranges
        for idx in range(len(ranges)):
            if None in ranges[idx]:
                ranges[idx] = [np.min(sample.data[:, idx]), np.max(sample.data[:, idx])]
    else:
        ranges = kwargs.pop("range", None)

    # Get the statistics and bin edges
    stat, bin_edges, _ = scipy.stats.binned_statistic_dd(
        sample, values, statistic=statistic, bins=bins, range=ranges, **kwargs
    )

    return xr.DataArray(
        data=stat,
        dims=list(dim_names),
        coords={dim_names[i]: 0.5 * (b[1:] + b[:-1]) for i, b in enumerate(bin_edges)},
        name="joint",
    )

# ----------------------------------------------------------------------------------------------------------------------
# MCMC operations
# ----------------------------------------------------------------------------------------------------------------------
@is_operation("batch_mean")
@apply_along_dim
def batch_mean(da: xr.DataArray, *, batch_size: int = None) -> xr.Dataset:
    """Computes the mean of a single sampling chain over batches of length B. The default batch length is
    int(sqrt(N)), where N is the length of the chain.

    :param da: dataarray of samples
    :param batch_size: (optional) batch length over which to compute averages
    :return: averages of the batches
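
    Example (an illustrative sketch; the chain is a random placeholder)::

        import numpy as np
        import xarray as xr

        chain = xr.DataArray(np.random.rand(100), dims=["sample"])

        # Batch means with the default batch length int(sqrt(100)) = 10
        means = batch_mean(chain)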

    """
    vals = da.data

    # Use the default batch length, as documented above
    if batch_size is None:
        batch_size = max(int(np.sqrt(len(vals))), 1)

    means = np.array([])
    windows = np.arange(0, len(vals), batch_size)
    if len(windows) == 1:
        windows = np.append(windows, len(vals) - 1)
    else:
        if windows[-1] != len(vals) - 1:
            windows = np.append(windows, len(vals) - 1)
    for idx, start_idx in enumerate(windows[:-1]):
        means = np.append(means, np.mean(vals[start_idx : windows[idx + 1]]))

    return xr.Dataset(
        data_vars=dict(means=("batch_idx", means)),
        coords=dict(batch_idx=("batch_idx", np.arange(len(means)))),
    )

@is_operation("gelman_rubin")
@apply_along_dim
def gelman_rubin(da: xr.Dataset, *, step_size: int = 1) -> xr.Dataset:
    """Computes the Gelman-Rubin statistic of multiple sampling chains as a function of the chain length.
    The chains are expected to be indexed by a ``seed`` dimension, with the samples lying along a ``sample``
    dimension.

    :param da: the chains, with ``seed`` and ``sample`` dimensions
    :param step_size: (optional) interval of chain lengths at which to evaluate the statistic
    :return: ``xr.Dataset`` of the Gelman-Rubin statistic, indexed by chain length
    """
    R = []
    for i in range(step_size, len(da.coords["sample"]), step_size):
        da_sub = da.isel({"sample": slice(0, i)})
        L = len(da_sub.coords["sample"])

        chain_mean = da_sub.mean("sample")
        between_chain_variance = L * chain_mean.std("seed", ddof=1) ** 2
        within_chain_variance = da_sub.std("sample", ddof=1) ** 2
        W = within_chain_variance.mean("seed")
        R.append(((L - 1) * W / L + 1 / L * between_chain_variance) / W)

    return xr.Dataset(
        data_vars=dict(gelman_rubin=("sample", R)),
        coords=dict(
            sample=("sample", np.arange(step_size, len(da.coords["sample"]), step_size))
        ),
    )

# ----------------------------------------------------------------------------------------------------------------------
# CSV operations
# ----------------------------------------------------------------------------------------------------------------------
@is_operation("to_csv")
def to_csv(
    data: Union[xr.Dataset, xr.DataArray], path: str
) -> Union[xr.Dataset, xr.DataArray]:
    """Writes the data to a csv file at ``path`` and returns the data unchanged."""
    df = data.to_dataframe()
    df.to_csv(path)
    return data