Coverage for model_plots/_op_utils.py: 96%
50 statements
« prev ^ index » next coverage.py v7.6.1, created at 2024-12-05 17:26 +0000
« prev ^ index » next coverage.py v7.6.1, created at 2024-12-05 17:26 +0000
import functools
import itertools
from typing import Sequence, Union, Tuple

import numpy as np
import xarray as xr
7# ----------------------------------------------------------------------------------------------------------------------
8# UTILITY FUNCTIONS FOR CUSTOM DAG OPERATIONS
9# ----------------------------------------------------------------------------------------------------------------------
def apply_along_dim(func):
    """Decorator which allows for applying a function, acting on aligned array-likes, along dimensions of
    xarray objects. The datasets must be aligned. All functions using this header should therefore only take
    xarray objects as arguments that can be indexed along common dimensions. All other arguments should be keywords.

    The decorated function accepts two additional keyword arguments, ``along_dim`` and ``exclude_dim``. Its name
    and docstring are those of ``func``.

    :param args: Sequence of xarray objects (``xr.Dataset`` or ``xr.DataArray``) which are to be aligned
    :param along_dim: the dimensions along which to apply the operation. A single dimension name may be given
        as a string.
    :param exclude_dim: the dimensions to exclude. This is an alternative to providing the 'along_dim' argument.
        Cannot provide both 'along_dim' and 'exclude_dim'. A single dimension name may be given as a string.
    :param kwargs: passed to function
    :return: if ``along_dim`` or ``exclude_dim`` are given, returns a ``xr.Dataset`` of merged arrays, else returns
        the return type of ``func``.
    :raises ValueError: if both ``along_dim`` and ``exclude_dim`` are provided
    """

    @functools.wraps(func)
    def _apply_along_axes(
        *args,
        along_dim: Sequence = None,
        exclude_dim: Sequence = None,
        **kwargs,
    ):
        if along_dim and exclude_dim:
            raise ValueError("Cannot provide both 'along_dim' and 'exclude_dim'!")

        # Without any dimension specification the function is applied to the arguments as they are
        if along_dim is None and exclude_dim is None:
            return func(*args, **kwargs)

        # A bare string is a single dimension name, not a sequence of one-character names. Wrapping it
        # prevents the membership test below from degrading into a substring test, and prevents a string
        # ``exclude_dim`` from being iterated character by character.
        if isinstance(along_dim, str):
            along_dim = [along_dim]
        if isinstance(exclude_dim, str):
            exclude_dim = [exclude_dim]

        # Get the coordinates for all the dimensions that are to be excluded; the first argument serves
        # as the reference for the available coordinates
        if exclude_dim is None:
            excluded_dims = [c for c in args[0].coords.keys() if c not in along_dim]
        else:
            excluded_dims = exclude_dim
        excluded_coords = [args[0].coords[d].data for d in excluded_dims]

        # Collect the results of the individual function calls
        dsets = []

        # Iterate over all coordinate combinations in the excluded dimensions and apply the function separately
        for idx in itertools.product(*(range(len(c)) for c in excluded_coords)):
            # Coordinate value selected in each excluded dimension for this iteration
            selection = {
                dim: coords[i]
                for dim, coords, i in zip(excluded_dims, excluded_coords, idx)
            }

            # Strip all arguments of the excluded coordinates, apply the function, and add the
            # coordinates back afterwards as length-one dimensions so the results can be merged
            result = func(
                *[arg.sel(selection, drop=True) for arg in args],
                **kwargs,
            )
            dsets.append(
                result.expand_dims(
                    dim={dim: [value] for dim, value in selection.items()}
                )
            )

        # Merge the datasets into one and return
        return xr.merge(dsets)

    return _apply_along_axes
82def _hist(
83 obj, *, normalize: Union[bool, float] = None, **kwargs
84) -> Tuple[np.ndarray, np.ndarray]:
85 """Applies ``numpy.histogram`` along an axis of an object and returns the counts and bin centres.
86 If specified, the counts are normalised. Returns the counts and bin centres (not bin edges!)
87 :param obj: data to bin
88 :param normalize: (optional) whether to normalise the counts; can be a boolean or a float, in which
89 case the float is interpreted as the normalisation constant.
90 :param kwargs: passed on to ``np.histogram``
91 :return: bin counts and bin centres
92 """
93 _counts, _edges = np.histogram(obj, **kwargs)
94 _counts = _counts.astype(float)
95 _bin_centres = 0.5 * (_edges[:-1] + _edges[1:])
97 # Normalise, if given
98 if normalize:
99 norm = np.nansum(_counts)
100 norm = 1.0 if norm == 0.0 else norm
101 _counts /= norm if isinstance(normalize, bool) else norm / normalize
102 return _counts, _bin_centres
def _get_hist_bins_ranges(ds, bins, ranges, axis):
    """Prepares the ``bins`` and ``ranges`` arguments for a histogramming function. ``xr.DataArray`` inputs are
    reduced to their underlying data, and ``None`` entries in the ranges are replaced by values taken from
    the dataset.

    :param ds: dataset to be binned
    :param bins: ``bins`` argument to the histogramming function; an ``xr.DataArray`` is replaced by its
        ``.data``, anything else is passed through unchanged
    :param ranges: ``ranges`` argument to the histogramming function; if ``None``, it is returned as ``None``.
        Otherwise it is converted into a numpy array, and every ``None`` entry is replaced by the minimum
        (first entry) or the maximum (all other entries) of ``ds.data[axis]``.
    :param axis: index used to select the data (``ds.data[axis]``) from which missing range limits are taken
    :return: bins and ranges for the histogram
    """
    # Unwrap bins given as an xarray object
    if isinstance(bins, xr.DataArray):
        bins = bins.data

    # No ranges given: nothing to fill
    if ranges is None:
        return bins, ranges

    # Work on a numpy copy of the ranges so that ``None`` entries can be overwritten in place
    ranges = np.array(ranges.data if isinstance(ranges, xr.DataArray) else ranges)
    for pos, limit in enumerate(ranges):
        if limit is None:
            # The first entry is the lower limit; any other entry is treated as an upper limit
            ranges[pos] = (
                np.min(ds.data[axis]) if pos == 0 else np.max(ds.data[axis])
            )

    return bins, ranges
def _interpolate(
    _p: xr.DataArray, _q: xr.DataArray
) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
    """Linearly interpolates two one-dimensional densities ``_p`` and ``_q`` onto a shared grid. The grid spans
    the interval between the larger of the two first coordinate values and the smaller of the two last
    coordinate values, and its number of points is the sum of the two individual grid lengths.
    If both densities already share identical coordinates, their data is returned without interpolation.

    :param _p: first density; its first coordinate is used as the grid
    :param _q: second density; its first coordinate is used as the grid
    :return: the two densities evaluated on the common grid, and the common grid itself
    """
    # The first coordinate of each array is taken as its grid
    dim_p = list(_p.coords.keys())[0]
    dim_q = list(_q.coords.keys())[0]
    x_p, x_q = _p.coords[dim_p], _q.coords[dim_q]

    # Identical grids: return the densities as they are
    if len(x_p.data) == len(x_q.data) and all(x_p.data == x_q.data):
        return _p.data, _q.data, x_p.data

    # Limits of the common support
    lower = np.max([x_p[0].item(), x_q[0].item()])
    upper = np.min([x_p[-1].item(), x_q[-1].item()])

    # Interpolate both functions on the intersection of their supports
    grid = np.linspace(lower, upper, len(x_p) + len(x_q))
    p_on_grid = np.interp(grid, x_p, _p)
    q_on_grid = np.interp(grid, x_q, _q)

    return p_on_grid, q_on_grid, grid