# model_plots/data_ops.py
from typing import Any, Sequence, Union

import numpy as np
import pandas as pd
import scipy.integrate
import scipy.signal
import scipy.stats
import xarray as xr

from utopya.eval import is_operation

from ._op_utils import _get_hist_bins_ranges, _hist, _interpolate, apply_along_dim

# --- Custom DAG operations for the NeuralABM model --------------------------------------------------------------------


# ----------------------------------------------------------------------------------------------------------------------
# DATA RESHAPING AND REORGANIZING
# ----------------------------------------------------------------------------------------------------------------------
@is_operation("concat_along")
def concat(objs: Sequence, name: str, dims: Sequence, *args, **kwargs):
    """Combines the ``pd.Index`` and ``xr.concat`` functions into one.

    :param objs: the xarray objects to be concatenated
    :param name: the name of the new dimension
    :param dims: the coordinates of the new dimension
    :param args: passed to ``xr.concat``
    :param kwargs: passed to ``xr.concat``
    :return: the objects concatenated along the new dimension
    """
    return xr.concat(objs, pd.Index(dims, name=name), *args, **kwargs)
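
# Usage sketch for ``concat_along`` (hypothetical arrays; the dimension name 'seed'
# below is an illustrative assumption, not part of the model):
#
#     a = xr.DataArray(np.zeros(3), dims=["x"])
#     b = xr.DataArray(np.ones(3), dims=["x"])
#     stacked = concat([a, b], name="seed", dims=[0, 1])
#     # -> DataArray with a new 'seed' dimension carrying coordinates [0, 1]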


@is_operation("flatten_dims")
@apply_along_dim
def flatten_dims(
    ds: Union[xr.Dataset, xr.DataArray],
    *,
    dims: dict,
    new_coords: Sequence = None,
) -> Union[xr.Dataset, xr.DataArray]:
    """Flattens dimensions of an xarray object into a new dimension. New coordinates can be assigned;
    otherwise, the new dimension is given trivial integer coordinates. The operation is a combination
    of stacking and subsequently dropping the multiindex.

    :param ds: the xarray object to reshape
    :param dims: a dictionary keyed by the name of the new dimension, with the dimensions to be
        flattened as the value
    :param new_coords: (optional) coordinates for the new dimension
    :return: the xarray object with flattened dimensions
    """
    new_dim, dims_to_stack = list(dims.keys())[0], list(dims.values())[0]

    # Check if the new dimension name already exists. If so, use a temporary name for the new
    # dimension and switch back later
    _renamed = False
    if new_dim in list(ds.coords.keys()):
        new_dim = f"__{new_dim}__"
        _renamed = True

    # Stack the dimensions and drop the resulting multiindex
    ds = ds.stack({new_dim: dims_to_stack})
    q = set(dims_to_stack)
    q.add(new_dim)
    ds = ds.drop_vars(q)

    # Rename the stacked dimension back to the originally intended name
    if _renamed:
        ds = ds.rename({new_dim: list(dims.keys())[0]})
        new_dim = list(dims.keys())[0]

    # Add coordinates to the new dimension and return
    if new_coords is None:
        return ds.assign_coords({new_dim: np.arange(len(ds.coords[new_dim]))})
    else:
        return ds.assign_coords({new_dim: new_coords})
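
# Usage sketch for ``flatten_dims`` (hypothetical dimension names 'a', 'b', 'sample'):
#
#     ds = xr.DataArray(np.zeros((2, 3)), dims=["a", "b"])
#     flat = flatten_dims(ds, dims={"sample": ["a", "b"]})
#     # -> object with a single 'sample' dimension of length 6 and coordinates 0..5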


@is_operation("broadcast")
@apply_along_dim
def broadcast(
    ds1: xr.DataArray, ds2: xr.DataArray, *, x: str = "x", p: str = "loss", **kwargs
) -> xr.Dataset:
    """Broadcasts two ``xr.DataArray`` s together and returns a dataset with the given ``x`` and ``p``
    as variable names.

    :param ds1: the first array
    :param ds2: the second array
    :param x: name for the new first variable
    :param p: name for the new second variable
    :param kwargs: passed on to ``xr.broadcast``
    :return: ``xr.Dataset`` with variables ``x`` and ``p``
    """
    return xr.broadcast(xr.Dataset({x: ds1, p: ds2}), **kwargs)[0]
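
# Usage sketch for ``broadcast`` (hypothetical arrays; the variable names are the defaults):
#
#     xs = xr.DataArray(np.linspace(0, 1, 10), dims=["i"])
#     ls = xr.DataArray(np.random.rand(10), dims=["i"])
#     ds = broadcast(xs, ls)   # -> Dataset with variables 'x' and 'loss'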


# ----------------------------------------------------------------------------------------------------------------------
# BASIC STATISTICS FUNCTIONS
# ----------------------------------------------------------------------------------------------------------------------
@is_operation("stat")
@apply_along_dim
def stat_function(
    data: xr.Dataset, *, stat: str, x: str, y: str, **kwargs
) -> Union[xr.DataArray, xr.Dataset]:
    """Basic statistical function which returns statistical properties of a one-dimensional dataset
    representing x- and y(x)-values.

    :param data: ``xr.Dataset`` along which to calculate the statistic. The dataset must contain the
        ``y`` key as a variable, but ``x`` may also be a coordinate name.
    :param stat: type of statistic to calculate: can be ``mean``, ``std``, ``iqr``, ``mode``, or
        ``avg_peak_width``. When calculating the mode, both the x-value and the y-value are returned,
        and when calculating peak widths, both the mean width and the standard deviation of the widths
        are calculated.
    :param x: label of the x-values; can be a variable in the dataset or a coordinate
    :param y: label of the function values
    :param kwargs: kwargs passed to the respective calculation function
    :return: the computed statistic
    """
    _permitted_stat_functions = ["mean", "std", "iqr", "mode", "avg_peak_width"]
    if stat not in _permitted_stat_functions:
        raise ValueError(
            f"Unrecognised stat function '{stat}'; choose from '{', '.join(_permitted_stat_functions)}'."
        )

    # x-values can be either a variable or a coordinate
    if x in data.coords.keys():
        _x_vals = data.coords[x]
    else:
        _x_vals = data[x]

    # Ignore NaNs in the values
    _x_vals, _y_vals = _x_vals[~np.isnan(data[y])], data[y][~np.isnan(data[y])]

    # ------------------------------------------------------------------------------------------------------------------
    # Expectation value: m = int f(x) x dx
    # ------------------------------------------------------------------------------------------------------------------
    if stat == "mean":
        _res = scipy.integrate.trapezoid(_y_vals * _x_vals, _x_vals, **kwargs)
        return xr.DataArray(_res, name=stat)

    # ------------------------------------------------------------------------------------------------------------------
    # Standard deviation: std^2 = int (x - m)^2 f(x) dx
    # ------------------------------------------------------------------------------------------------------------------
    elif stat == "std":
        _m = stat_function(data, x=x, y=y, stat="mean")
        _res = np.sqrt(
            scipy.integrate.trapezoid(_y_vals * (_x_vals - _m) ** 2, _x_vals, **kwargs)
        )
        return xr.DataArray(_res, name=stat)

    # ------------------------------------------------------------------------------------------------------------------
    # Interquartile range: distance between the first and third quartiles
    # ------------------------------------------------------------------------------------------------------------------
    elif stat == "iqr":
        _int = scipy.integrate.trapezoid(_y_vals, _x_vals, **kwargs)
        _a_0 = -1.0
        __int = 0.0
        _res = 0.0
        for i in range(1, len(_x_vals)):
            __int += scipy.integrate.trapezoid(
                _y_vals[i - 1 : i + 1], _x_vals[i - 1 : i + 1], **kwargs
            )
            if __int > 0.25 * _int and _a_0 == -1:
                _a_0 = _x_vals[i].item()
            if __int > 0.75 * _int:
                _res = _x_vals[i].item() - _a_0
                break
        return xr.DataArray(_res, name=stat)

    # ------------------------------------------------------------------------------------------------------------------
    # Mode: both the x-value and the y-value of the mode are returned
    # ------------------------------------------------------------------------------------------------------------------
    elif stat == "mode":
        # Get the index of the mode and select it
        idx_max = np.argmax(_y_vals.data, **kwargs)
        mode_x = _x_vals[idx_max]
        mode_y = _y_vals[idx_max]
        return xr.Dataset(data_vars=dict(mode_x=mode_x, mode_y=mode_y))

    # ------------------------------------------------------------------------------------------------------------------
    # Average peak width: both the mean and the standard deviation of the widths are returned
    # ------------------------------------------------------------------------------------------------------------------
    elif stat == "avg_peak_width":
        if "width" not in kwargs:
            raise ValueError("'width' kwarg required for 'scipy.signal.find_peaks'!")

        # Pad the array with a zero at the beginning and the end to ensure peaks at the ends are found
        _y_vals = np.concatenate(([0], _y_vals, [0]))

        # Find the peaks along the array
        peaks = scipy.signal.find_peaks(_y_vals, **kwargs)

        # Calculate the mean and standard deviation of the peak widths
        mean, std = (
            np.mean(peaks[1]["widths"]) * np.diff(_x_vals)[0],
            np.std(peaks[1]["widths"]) * np.diff(_x_vals)[0],
        )

        return xr.Dataset(data_vars=dict(mean_peak_width=mean, peak_width_std=std))
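
# Usage sketch for ``stat`` (hypothetical one-dimensional density; the integrals above
# assume the distribution is normalised):
#
#     grid = np.linspace(-5, 5, 101)
#     pdf = np.exp(-grid**2 / 2) / np.sqrt(2 * np.pi)
#     ds = xr.Dataset({"p": ("x", pdf)}, coords={"x": grid})
#     m = stat_function(ds, stat="mean", x="x", y="p")   # ~0 for this Gaussian
#     s = stat_function(ds, stat="std", x="x", y="p")    # ~1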


@is_operation("mean")
@apply_along_dim
def mean(*args, **kwargs) -> xr.Dataset:
    """Computes the mean of a dataset"""
    return stat_function(*args, stat="mean", **kwargs)


@is_operation("std")
@apply_along_dim
def std(*args, **kwargs) -> xr.Dataset:
    """Computes the standard deviation of a dataset"""
    return stat_function(*args, stat="std", **kwargs)


@is_operation("iqr")
@apply_along_dim
def iqr(*args, **kwargs) -> xr.Dataset:
    """Computes the interquartile range of a dataset"""
    return stat_function(*args, stat="iqr", **kwargs)


@is_operation("mode")
@apply_along_dim
def mode(*args, **kwargs) -> xr.Dataset:
    """Computes the mode of a dataset"""
    return stat_function(*args, stat="mode", **kwargs)


@is_operation("avg_peak_width")
@apply_along_dim
def avg_peak_width(*args, **kwargs) -> xr.Dataset:
    """Computes the average peak width and the std of the peak widths of a dataset"""
    return stat_function(*args, stat="avg_peak_width", **kwargs)


@is_operation("p_value")
@apply_along_dim
def p_value(
    data: xr.Dataset, point: Any, *, x: str, y: str, null: str = "mean"
) -> xr.DataArray:
    """Calculates the p-value of a ``point`` from a dataset containing x-y-pairs. It is assumed that
    the integral under y is normalised for the p-value to be meaningful. The p-value can be calculated
    with respect to either the mean or the mode of the distribution.

    :param data: ``xr.Dataset`` containing the x and p(x) values
    :param point: point at which to calculate the p-value
    :param x: label of the x-values; can be a variable in the dataset or a coordinate
    :param y: label of the function values; assumed to be normalised
    :param null: (optional) null with respect to which the p-value is to be calculated; can be either
        ``mean`` or ``mode``
    :return: ``xr.DataArray`` of the p-value of ``point``
    """
    # x can be either a variable or a coordinate
    if x in data.coords.keys():
        _x_vals = data.coords[x]
    else:
        _x_vals = data[x]

    if isinstance(point, xr.DataArray):
        point = point.data

    # Calculate the value of the null of the distribution
    m = (
        mean(data, x=x, y=y).data
        if null == "mean"
        else mode(data, x=x, y=y)["mode_x"].data
    )

    # Calculate the index of the point
    t_index = np.argmin(np.abs(_x_vals - point).data)

    # Calculate the p-value depending on the location of the point relative to the null
    if point >= m:
        return xr.DataArray(
            scipy.integrate.trapezoid(data[y][t_index:], _x_vals[t_index:]),
            name="p_value",
        )
    else:
        return xr.DataArray(
            scipy.integrate.trapezoid(data[y][:t_index], _x_vals[:t_index]),
            name="p_value",
        )
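
# Usage sketch for ``p_value`` (hypothetical normalised density; names are assumptions):
#
#     grid = np.linspace(-5, 5, 101)
#     pdf = np.exp(-grid**2 / 2) / np.sqrt(2 * np.pi)
#     ds = xr.Dataset({"p": ("x", pdf)}, coords={"x": grid})
#     pv = p_value(ds, 1.5, x="x", y="p")
#     # -> upper-tail mass beyond 1.5, since 1.5 lies above the mean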


@is_operation("normalize")
@apply_along_dim
def normalize(
    distribution: xr.Dataset, *, x: str, y: str, norm: float = 1, **kwargs
) -> xr.Dataset:
    """Normalises a probability distribution of x- and y-values.

    :param distribution: ``xr.Dataset`` of x- and y-values
    :param x: the x-values
    :param y: the function values
    :param norm: (optional) value to which to normalise the distribution
    :param kwargs: passed to ``scipy.integrate.trapezoid``
    :return: the normalised probability distribution
    """
    integral = scipy.integrate.trapezoid(distribution[y], distribution[x], **kwargs)
    distribution[y] *= norm / integral
    return distribution
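
# Usage sketch for ``normalize`` (hypothetical unnormalised density):
#
#     grid = np.linspace(0, 1, 50)
#     raw = xr.Dataset({"x": ("i", grid), "p": ("i", 3.0 * np.ones(50))})
#     normed = normalize(raw, x="x", y="p")
#     # -> the integral of 'p' against 'x' now equals 1 (norm=1 by default)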


# ----------------------------------------------------------------------------------------------------------------------
# HISTOGRAMS
# ----------------------------------------------------------------------------------------------------------------------
@is_operation("hist")
@apply_along_dim
def hist(
    da: xr.DataArray,
    bins: Any = 100,
    ranges: Any = None,
    *,
    dim: str = None,
    axis: int = None,
    normalize: Union[float, bool] = False,
    use_bins_as_coords: bool = False,
    **kwargs,
) -> Union[xr.Dataset, xr.DataArray]:
    """Applies ``np.histogram`` using the ``apply_along_dim`` decorator to allow histogramming along
    multiple dimensions. This function applies ``np.histogram`` along a single axis of an
    ``xr.DataArray`` object; it is recommended to only use ``apply_along_dim`` across small dimensions,
    as splitting and recombining the xarray objects is very expensive.

    :param da: the ``xr.DataArray`` on which to apply the histogram function
    :param bins: the bins to use, passed to ``np.histogram``. This can be a single integer, in which
        case it is interpreted as the number of bins, a Sequence defining the bin edges, or a string
        defining the method to use. See ``np.histogram`` for details.
    :param ranges: (optional) the lower and upper range of the bins
    :param dim: the dimension along which to apply the operation. If not passed, an ``axis`` argument
        must be provided.
    :param axis: (optional) the axis along which to apply ``np.histogram``
    :param normalize: whether to normalize the counts. Can be a boolean or a float, in which case the
        counts are normalized to that value.
    :param use_bins_as_coords: whether to use the bin centres as coordinates of the dataset, or as
        variables. If true, an ``xr.DataArray`` is returned, with the bin centres as coordinates and
        the counts as the data. This may cause incompatibilities with ``apply_along_dim``, since
        different samples may have different bin centres. For this reason, the default behaviour is to
        return an ``xr.Dataset`` with the bin centres and counts as variables, and ``bin_idx`` as the
        coordinate. This enables combining different histograms with different bin centres (but the
        same number of bins) into a single dataset. If passed, ``ranges`` must also be passed to ensure
        all histogram bins are identical.
    :param kwargs: passed to ``np.histogram``
    :return: ``xr.DataArray`` or ``xr.Dataset`` containing the bin centres, either as coordinates or as
        variables, and the counts
    """
    if dim is None and axis is None:
        raise ValueError("Must supply either 'dim' or 'axis' arguments!")

    if use_bins_as_coords and ranges is None:
        raise ValueError(
            "Setting 'use_bins_as_coords' to 'True' requires passing a 'ranges' argument to "
            "ensure all coordinates are equal"
        )

    # Get the axis along which to apply the operation
    if dim is not None:
        axis = list(da.dims).index(dim)

    # Get the bins and range objects
    bins, ranges = _get_hist_bins_ranges(da, bins, ranges, axis)

    # Apply the histogram function along the axis
    res = np.apply_along_axis(
        _hist, axis, da.data, bins=bins, range=ranges, normalize=normalize, **kwargs
    )

    # Get the counts and the bin centres. Note that the bin centres are equal along every dimension!
    counts, bin_centres = np.take(res, 0, axis=axis), np.take(res, 1, axis=axis)

    # Put the dataset back together again, relabelling the coordinate dimension that was binned
    coords = dict(da.coords)

    # Bin centres are to be used as coordinates
    if use_bins_as_coords:
        sel = [0] * len(np.shape(bin_centres))
        sel[axis] = None
        bin_centres = bin_centres[tuple(sel)].flatten()
        coords.update({dim: bin_centres})

        res = xr.DataArray(
            counts,
            dims=list(da.sizes.keys()),
            coords=coords,
            name=da.name if da.name else "count",
        )
        return res.rename({dim: "x"})

    else:
        coords.update({dim: np.arange(np.shape(bin_centres)[axis])})
        other_dim = list(coords.keys())
        other_dim.remove(dim)
        attrs = [*other_dim, "bin_idx"] if other_dim else ["bin_idx"]
        coords["bin_idx"] = coords.pop(dim)

        return xr.Dataset(
            data_vars={
                da.name if da.name else "count": (attrs, counts),
                "x": (attrs, bin_centres),
            },
            coords=coords,
        )
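
# Usage sketch for ``hist`` (hypothetical samples; the 'time' dimension name is an
# assumption, and the exact 'ranges' format is handled by the _get_hist_bins_ranges helper):
#
#     da = xr.DataArray(np.random.rand(1000), dims=["time"])
#     h = hist(da, bins=50, ranges=[0, 1], dim="time")
#     # -> xr.Dataset with 'count' and bin-centre variable 'x' over a 'bin_idx' coordinate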


# ----------------------------------------------------------------------------------------------------------------------
# DISTANCES BETWEEN PROBABILITY DENSITIES
# ----------------------------------------------------------------------------------------------------------------------
@is_operation("distances_between_distributions")
@apply_along_dim
def distances_between_distributions(
    P: Union[xr.DataArray, xr.Dataset],
    Q: Union[xr.DataArray, xr.Dataset],
    *,
    stat: str,
    p: float = 2,
    x: str = None,
    y: str = None,
    **kwargs,
) -> xr.DataArray:
    """Calculates distances between two distributions P and Q. Possible distances are:

    - Hellinger distance: d(P, Q) = 1/2 * integral (sqrt(P(x)) - sqrt(Q(x)))^2 dx
    - Relative entropy: d(P, Q) = integral P(x) log(P(x)/Q(x)) dx
    - Lp distance: d(P, Q) = (integral (P(x) - Q(x))^p dx)^{1/p}

    These distances are calculated on the common support of P and Q; if P and Q have different
    discretisation levels, the functions are interpolated.

    :param P: one-dimensional ``xr.DataArray`` or ``xr.Dataset`` of values for P. If an ``xr.Dataset``,
        the ``x`` and ``y`` arguments must be passed.
    :param Q: one-dimensional ``xr.DataArray`` or ``xr.Dataset`` of values for Q. If an ``xr.Dataset``,
        the ``x`` and ``y`` arguments must be passed.
    :param stat: which distance function to use
    :param p: p-value for the Lp distance
    :param x: x-values to use if P and Q are ``xr.Dataset`` s
    :param y: y-values to use if P and Q are ``xr.Dataset`` s
    :param kwargs: passed on to ``scipy.integrate.trapezoid``
    :return: the distance between P and Q
    """
    _permitted_stat_functions = ["Hellinger", "relative_entropy", "Lp"]
    if stat not in _permitted_stat_functions:
        raise ValueError(
            f"Unrecognised stat function '{stat}'; choose from '{', '.join(_permitted_stat_functions)}'."
        )

    # If P and Q are datasets, convert them to DataArrays
    if isinstance(P, xr.Dataset):
        P = xr.DataArray(
            P[y], coords={"x": P[x] if x in list(P.data_vars) else P.coords[x]}
        )
    if isinstance(Q, xr.Dataset):
        Q = xr.DataArray(
            Q[y], coords={"x": Q[x] if x in list(Q.data_vars) else Q.coords[x]}
        )

    # Interpolate P and Q on their common support
    P, Q, grid = _interpolate(P, Q)

    # Hellinger distance
    if stat == "Hellinger":
        return xr.DataArray(
            0.5
            * scipy.integrate.trapezoid(
                np.square(np.sqrt(P) - np.sqrt(Q)), grid, **kwargs
            ),
            name="Hellinger_distance",
        )

    # Relative entropy
    elif stat == "relative_entropy":
        P, Q = np.where(P != 0, P, 1), np.where(Q != 0, Q, 1)
        return xr.DataArray(
            scipy.integrate.trapezoid(P * np.log(P / Q), grid, **kwargs),
            name="relative_entropy",
        )

    # Lp distance
    elif stat == "Lp":
        return xr.DataArray(
            scipy.integrate.trapezoid((P - Q) ** p, grid, **kwargs) ** (1 / p),
            name="Lp_distance",
        )
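
# Usage sketch for the distance operations (hypothetical densities; the common-support
# interpolation is handled by the _interpolate helper):
#
#     grid = np.linspace(0, 1, 100)
#     P = xr.DataArray(np.ones(100), dims=["x"], coords={"x": grid})
#     Q = xr.DataArray(2 * grid, dims=["x"], coords={"x": grid})
#     d = distances_between_distributions(P, Q, stat="Hellinger")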


@is_operation("Hellinger_distance")
@apply_along_dim
def Hellinger_distance(*args, **kwargs) -> xr.DataArray:
    """Computes the Hellinger distance between two distributions"""
    return distances_between_distributions(*args, stat="Hellinger", **kwargs)


@is_operation("relative_entropy")
@apply_along_dim
def relative_entropy(*args, **kwargs) -> xr.DataArray:
    """Computes the relative entropy between two distributions"""
    return distances_between_distributions(*args, stat="relative_entropy", **kwargs)


@is_operation("Lp_distance")
@apply_along_dim
def Lp_distance(*args, **kwargs) -> xr.DataArray:
    """Computes the Lp distance between two distributions"""
    return distances_between_distributions(*args, stat="Lp", **kwargs)


# ----------------------------------------------------------------------------------------------------------------------
# PROBABILITY DENSITY FUNCTIONS
# ----------------------------------------------------------------------------------------------------------------------
@is_operation("joint_2D")
@apply_along_dim
def joint_2D(
    x: xr.DataArray,
    y: xr.DataArray,
    values: xr.DataArray,
    bins: Union[int, xr.DataArray] = 100,
    ranges: xr.DataArray = None,
    *,
    statistic: Union[str, callable] = "mean",
    normalize: Union[bool, float] = False,
    dx: float = None,
    dy: float = None,
    dim_names: Sequence = ("x", "y"),
    **kwargs,
) -> xr.DataArray:
    """
    Computes the two-dimensional joint distribution of a dataset of parameters by calling
    ``scipy.stats.binned_statistic_2d``. The function returns a statistic for each bin (the mean by
    default).

    :param x: DataArray of samples in the first dimension
    :param y: DataArray of samples in the second dimension
    :param values: DataArray of values to be binned
    :param bins: (optional) ``bins`` argument to ``scipy.stats.binned_statistic_2d``
    :param ranges: (optional) ``range`` argument to ``scipy.stats.binned_statistic_2d``
    :param statistic: (optional) ``statistic`` argument to ``scipy.stats.binned_statistic_2d``
    :param normalize: (optional) whether to normalize the joint (False by default), and the
        normalisation value (1 by default)
    :param dx: (optional) the spatial differential dx to use for normalisation. If provided, the norm
        will not be calculated by integrating against the x-values, but rather by assuming the
        coordinates are spaced ``dx`` apart
    :param dy: (optional) the spatial differential dy to use for normalisation. If provided, the norm
        will not be calculated by integrating against the y-values, but rather by assuming the
        coordinates are spaced ``dy`` apart
    :param dim_names: (optional) names of the two dimensions
    :return: ``xr.DataArray`` of the joint distribution
    """
    # Get the number of bins
    if isinstance(bins, xr.DataArray):
        bins = bins.data

    # Allow passing 'None' arguments in the plot config for certain entries of the range arg.
    # This allows clipping only on some dimensions without having to specify every limit
    if ranges is not None:
        ranges = (
            np.array(ranges.data)
            if isinstance(ranges, xr.DataArray)
            else np.array(ranges)
        )
        for idx in range(len(ranges)):
            if None in ranges[idx]:
                ranges[idx] = (
                    [np.min(x), np.max(x)] if idx == 0 else [np.min(y), np.max(y)]
                )
    else:
        ranges = kwargs.pop("range", None)

    # Get the statistic and bin edges
    stat, x_edge, y_edge, _ = scipy.stats.binned_statistic_2d(
        x, y, values, statistic=statistic, bins=bins, range=ranges, **kwargs
    )

    # Normalise the joint distribution, if specified
    if normalize:
        if dy is None:
            int_y = [
                scipy.integrate.trapezoid(
                    stat[i][~np.isnan(stat[i])],
                    0.5 * (y_edge[1:] + y_edge[:-1])[~np.isnan(stat[i])],
                )
                for i in range(stat.shape[0])
            ]
        else:
            int_y = [
                scipy.integrate.trapezoid(stat[i][~np.isnan(stat[i])], dx=dy)
                for i in range(stat.shape[0])
            ]

        norm = (
            scipy.integrate.trapezoid(int_y, 0.5 * (x_edge[1:] + x_edge[:-1]))
            if dx is None
            else scipy.integrate.trapezoid(int_y, dx=dx)
        )
        if norm == 0:
            norm = 1
        stat /= norm if isinstance(normalize, bool) else norm / normalize

    return xr.DataArray(
        data=stat,
        dims=dim_names,
        coords={
            dim_names[0]: 0.5 * (x_edge[1:] + x_edge[:-1]),
            dim_names[1]: 0.5 * (y_edge[1:] + y_edge[:-1]),
        },
        name="joint",
    )
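
# Usage sketch for ``joint_2D`` (hypothetical samples; the parameter names 'a' and 'b'
# are illustrative assumptions):
#
#     a = xr.DataArray(np.random.rand(500))
#     b = xr.DataArray(np.random.rand(500))
#     loss = xr.DataArray(np.random.rand(500))
#     j = joint_2D(a, b, loss, bins=20, normalize=True, dim_names=("a", "b"))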


@is_operation("joint_2D_ds")
@apply_along_dim
def joint_2D_ds(
    ds: Union[xr.DataArray, xr.Dataset],
    values: xr.DataArray,
    bins: xr.DataArray = 100,
    ranges: xr.DataArray = None,
    *,
    x: str,
    y: str,
    **kwargs,
) -> xr.DataArray:
    """Computes a two-dimensional joint from a single dataset with x and y given as variables, or from
    a DataArray with x and y given as coordinate dimensions."""
    if isinstance(ds, xr.Dataset):
        return joint_2D(ds[x], ds[y], values, bins, ranges, dim_names=(x, y), **kwargs)
    elif isinstance(ds, xr.DataArray):
        return joint_2D(
            ds.sel(dict(parameter=x)),
            ds.sel(dict(parameter=y)),
            values,
            bins,
            ranges,
            dim_names=(x, y),
            **kwargs,
        )


@is_operation("marginal_from_joint")
@apply_along_dim
def marginal_from_joint(
    joint: xr.DataArray,
    *,
    parameter: str,
    normalize: Union[bool, float] = True,
    scale_y_bins: bool = False,
) -> xr.Dataset:
    """
    Computes a marginal from a two-dimensional joint distribution by integrating out the other
    coordinate. Normalizes the marginal, if specified. NaN values in the joint are skipped when
    normalising: they are not zero, just unknown. Since x-values may differ for different parameters,
    the x-values are variables in a dataset, not coordinates. The coordinates are given by the bin
    index, thereby allowing marginals across multiple parameters to be combined into a single
    ``xr.Dataset``.

    :param joint: the joint distribution over which to marginalise
    :param parameter: the parameter whose marginal to compute; the joint is integrated over the other
        coordinate
    :param normalize: whether to normalize the marginal distribution. If true, normalizes to 1, else
        normalizes to the given value.
    :param scale_y_bins: whether to scale the integration over y by the range of the given values
        (y_max - y_min)
    """
    # Get the integration coordinate
    integration_coord = [c for c in list(joint.coords) if c != parameter][0]

    # Marginalise over the integration coordinate
    marginal = np.array([])
    for p in joint.coords[parameter]:
        _y, _x = joint.sel({parameter: p}).data, joint.coords[integration_coord]
        if scale_y_bins and not np.isnan(_y).all():
            _f = np.nanmax(_y) - np.nanmin(_y)
            _f = 1.0 / _f if _f != 0 else 1.0
        else:
            _f = 1.0
        marginal = np.append(
            marginal,
            _f * scipy.integrate.trapezoid(_y[~np.isnan(_y)], _x[~np.isnan(_y)]),
        )

    # Normalise, if given
    if normalize:
        norm = scipy.integrate.trapezoid(marginal, joint.coords[parameter])
        if norm == 0:
            norm = 1
        marginal /= norm if isinstance(normalize, bool) else norm / normalize

    # Return a dataset with x- and y-values as variables, and coordinates given by the bin index.
    # This allows combining different marginals with different x-values but an identical number of
    # bins into a single dataset
    return xr.Dataset(
        data_vars=dict(
            x=(["bin_idx"], joint.coords[parameter].data),
            y=(["bin_idx"], marginal),
        ),
        coords=dict(
            bin_idx=(["bin_idx"], np.arange(len(joint.coords[parameter].data)))
        ),
    )
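
# Usage sketch for ``marginal_from_joint``, continuing the hypothetical joint ``j``
# from the ``joint_2D`` sketch above:
#
#     m = marginal_from_joint(j, parameter="a")
#     # -> xr.Dataset with variables 'x' (bin centres of 'a') and 'y' (marginal density),
#     #    indexed by 'bin_idx'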


@is_operation("marginal")
@apply_along_dim
def marginal(
    x: xr.DataArray,
    prob: xr.DataArray,
    bins: Union[int, xr.DataArray] = 100,
    ranges: Union[Sequence, xr.DataArray] = None,
    *,
    parameter: str = "x",
    normalize: Union[bool, float] = True,
    scale_y_bins: bool = False,
    **kwargs,
) -> xr.Dataset:
    """
    Computes a marginal directly from an ``xr.DataArray`` of x-values and an ``xr.DataArray`` of
    probabilities by first computing the joint distribution and then marginalising over the
    probability. This way, points that are sampled multiple times only contribute once to the
    marginal, which is thus a representation not of the frequency with which each point is sampled,
    but of the calculated likelihood function.

    :param x: array of samples of the first variable (the parameter estimates)
    :param prob: array of samples of (unnormalised) probability values
    :param bins: bins to use for both dimensions
    :param ranges: ranges to use for both dimensions; defaults to the minimum and maximum along each
        dimension
    :param parameter: the parameter over which to marginalise. Defaults to the first dimension.
    :param normalize: whether to normalize the marginal
    :param scale_y_bins: whether to scale the integration over y by the range of the given values
        (y_max - y_min)
    :param kwargs: other kwargs, passed to ``joint_2D``
    :return: ``xr.Dataset`` of the marginal densities
    """
    joint = joint_2D(x, prob, prob, bins, ranges, normalize=normalize, **kwargs)
    return marginal_from_joint(
        joint, parameter=parameter, normalize=normalize, scale_y_bins=scale_y_bins
    )


@is_operation("marginal_from_ds")
@apply_along_dim
def marginal_from_ds(
    ds: xr.Dataset,
    bins: xr.DataArray = 100,
    ranges: xr.DataArray = None,
    *,
    x: str,
    y: str,
    **kwargs,
) -> xr.Dataset:
    """Computes the marginal from a single dataset with x and y given as variables."""
    return marginal(ds[x], ds[y], bins, ranges, **kwargs)


@is_operation("joint_DD")
@apply_along_dim
def joint_DD(
    sample: xr.DataArray,
    values: xr.DataArray,
    bins: Union[int, xr.DataArray] = 100,
    ranges: xr.DataArray = None,
    *,
    statistic: Union[str, callable] = "mean",
    normalize: Union[bool, float] = False,
    dim_names: Sequence = None,
    **kwargs,
) -> xr.DataArray:
    """
    Computes the d-dimensional joint distribution of a dataset of parameters by calling
    ``scipy.stats.binned_statistic_dd``. This function can handle at most 32 parameters. A statistic
    for each bin is returned (the mean by default).

    :param sample: ``xr.DataArray`` of samples of shape ``(N, D)``
    :param values: ``xr.DataArray`` of values to be binned, of shape ``(N, )``
    :param bins: ``bins`` argument to ``scipy.stats.binned_statistic_dd``
    :param ranges: ``range`` argument to ``scipy.stats.binned_statistic_dd``
    :param statistic: (optional) ``statistic`` argument to ``scipy.stats.binned_statistic_dd``
    :param normalize: (not implemented) whether to normalize the joint (False by default), and the
        normalisation value (1 by default)
    :param dim_names: (optional) names of the dimensions
    :return: ``xr.DataArray`` of the joint distribution
    """
    if normalize:
        raise NotImplementedError(
            "Normalisation for d-dimensional joints is not yet implemented!"
        )

    # Get the number of bins
    if isinstance(bins, xr.DataArray):
        bins = bins.data

    dim_names = (
        sample.coords[list(sample.dims)[-1]].data if dim_names is None else dim_names
    )

    # Allow passing 'None' arguments in the plot config for certain entries of the range arg.
    # This allows clipping only on some dimensions without having to specify every limit
    if ranges is not None:
        ranges = ranges.data if isinstance(ranges, xr.DataArray) else ranges
        for idx in range(len(ranges)):
            if None in ranges[idx]:
                # Fall back to the sample extent of the respective parameter column
                ranges[idx] = [np.min(sample.data[:, idx]), np.max(sample.data[:, idx])]
    else:
        ranges = kwargs.pop("range", None)

    # Get the statistic and bin edges
    stat, bin_edges, _ = scipy.stats.binned_statistic_dd(
        sample, values, statistic=statistic, bins=bins, range=ranges, **kwargs
    )

    return xr.DataArray(
        data=stat,
        dims=dim_names,
        coords={dim_names[i]: 0.5 * (b[1:] + b[:-1]) for i, b in enumerate(bin_edges)},
        name="joint",
    )


# ----------------------------------------------------------------------------------------------------------------------
# MCMC operations
# ----------------------------------------------------------------------------------------------------------------------
@is_operation("batch_mean")
@apply_along_dim
def batch_mean(da: xr.DataArray, *, batch_size: int = None) -> xr.Dataset:
    """Computes the mean of a single sampling chain over batches of length B. The default batch length
    is int(sqrt(N)), where N is the length of the chain.

    :param da: DataArray of samples
    :param batch_size: batch length over which to compute the averages
    :return: the means of the batches
    """
    vals = da.data

    # Use the default batch length of int(sqrt(N)), if none is given
    if batch_size is None:
        batch_size = max(int(np.sqrt(len(vals))), 1)

    means = np.array([])
    windows = np.arange(0, len(vals), batch_size)
    if len(windows) == 1 or windows[-1] != len(vals) - 1:
        windows = np.append(windows, len(vals) - 1)
    for idx, start_idx in enumerate(windows[:-1]):
        means = np.append(means, np.mean(vals[start_idx : windows[idx + 1]]))

    return xr.Dataset(
        data_vars=dict(means=("batch_idx", means)),
        coords=dict(batch_idx=("batch_idx", np.arange(len(means)))),
    )
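
# Usage sketch for ``batch_mean`` (hypothetical chain of 100 samples):
#
#     chain = xr.DataArray(np.random.randn(100), dims=["sample"])
#     bm = batch_mean(chain)   # batches of length int(sqrt(100)) = 10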


@is_operation("gelman_rubin")
@apply_along_dim
def gelman_rubin(da: xr.Dataset, *, step_size: int = 1) -> xr.Dataset:
    """Computes the Gelman-Rubin convergence diagnostic on a set of sampling chains, evaluated on
    increasingly long initial sections of the chains. The data is expected to contain a ``sample``
    dimension (the steps within each chain) and a ``seed`` dimension (the chains).

    :param da: the chain data over which to calculate the diagnostic
    :param step_size: (optional) interval between the chain lengths at which the diagnostic is
        evaluated
    :return: ``xr.Dataset`` of the diagnostic, indexed by chain length
    """
    R = []
    for i in range(step_size, len(da.coords["sample"]), step_size):
        da_sub = da.isel({"sample": slice(0, i)})
        L = len(da_sub.coords["sample"])

        # Calculate the between-chain and within-chain variances
        chain_mean = da_sub.mean("sample")
        between_chain_variance = L * chain_mean.std("seed", ddof=1) ** 2
        within_chain_variance = da_sub.std("sample", ddof=1) ** 2
        W = within_chain_variance.mean("seed")
        R.append(((L - 1) * W / L + 1 / L * between_chain_variance) / W)

    return xr.Dataset(
        data_vars=dict(gelman_rubin=("sample", R)),
        coords=dict(
            sample=("sample", np.arange(step_size, len(da.coords["sample"]), step_size))
        ),
    )
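
# Usage sketch for ``gelman_rubin`` (hypothetical chains: 4 seeds, 100 samples each;
# the dimension names follow the expectations stated in the docstring):
#
#     chains = xr.DataArray(np.random.randn(4, 100), dims=["seed", "sample"])
#     rhat = gelman_rubin(chains, step_size=10)
#     # values close to 1 indicate converged chains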


# ----------------------------------------------------------------------------------------------------------------------
# CSV operations
# ----------------------------------------------------------------------------------------------------------------------
@is_operation("to_csv")
def to_csv(
    data: Union[xr.Dataset, xr.DataArray], path: str
) -> Union[xr.Dataset, xr.DataArray]:
    """Writes an xarray object to a csv file at the given path and passes the data through unchanged,
    allowing this operation to be used within a longer chain of DAG operations.

    :param data: the xarray object to write out
    :param path: path of the target csv file
    :return: the unchanged data
    """
    df = data.to_dataframe()
    df.to_csv(path)
    return data