Coverage for model_plots/_op_utils.py: 96%

50 statements  

« prev     ^ index     » next       coverage.py v7.6.1, created at 2024-12-05 17:26 +0000

1import itertools 

2from typing import Sequence, Union, Tuple 

3 

4import numpy as np 

5import xarray as xr 

6 

7# ---------------------------------------------------------------------------------------------------------------------- 

8# UTILITY FUNCTIONS FOR CUSTOM DAG OPERATIONS 

9# ---------------------------------------------------------------------------------------------------------------------- 

10 

11 

def apply_along_dim(func):
    """Decorator which allows for applying a function, acting on aligned array-likes, along dimensions of
    xarray objects. The datasets must be aligned. All functions using this decorator should therefore only take
    xarray objects as positional arguments that can be indexed along common dimensions. All other arguments
    should be keywords.

    The decorated function keeps the name and docstring of ``func`` and accepts two additional keyword
    arguments, ``along_dim`` and ``exclude_dim``:

    :param func: the function to decorate; when dimensions are given, its return value must provide
        ``expand_dims`` and be mergeable via ``xr.merge``
    :param along_dim: the dimensions along which to apply the operation. A single dimension may be given
        as a plain string.
    :param exclude_dim: the dimensions to exclude. This is an alternative to providing the 'along_dim' argument.
        Cannot provide both 'along_dim' and 'exclude_dim'. A single dimension may be given as a plain string.
    :raises ValueError: if both ``along_dim`` and ``exclude_dim`` are given (and non-empty)
    :return: the decorated function. If ``along_dim`` or ``exclude_dim`` are given, it returns the result of
        ``xr.merge`` over the individual results, else it returns the return value of ``func``.
    """
    # Function-scope import: keeps this block self-contained with respect to the module's import list
    import functools

    @functools.wraps(func)
    def _apply_along_axes(
        *args,
        along_dim: Sequence = None,
        exclude_dim: Sequence = None,
        **kwargs,
    ):
        if along_dim and exclude_dim:
            raise ValueError("Cannot provide both 'along_dim' and 'exclude_dim'!")

        # Without any dimension specification, the function is applied to the objects as they are
        if along_dim is None and exclude_dim is None:
            return func(*args, **kwargs)

        # A bare string names a single dimension. Wrapping it in a list avoids substring matching in the
        # membership test and character-wise iteration over the name further down.
        if isinstance(along_dim, str):
            along_dim = [along_dim]
        if isinstance(exclude_dim, str):
            exclude_dim = [exclude_dim]

        # Determine the dimensions to iterate over: either those given explicitly, or all coordinates of
        # the first argument that the function is *not* applied along
        if exclude_dim is None:
            excluded_dims = [c for c in args[0].coords if c not in along_dim]
        else:
            excluded_dims = list(exclude_dim)
        excluded_coords = [args[0].coords[dim].data for dim in excluded_dims]

        # Apply the function separately for every combination of coordinate values in the excluded dimensions
        dsets = []
        for values in itertools.product(*excluded_coords):
            selector = dict(zip(excluded_dims, values))

            # Strip all arguments of the excluded dimensions by selecting a single coordinate value ...
            result = func(
                *[arg.sel(selector, drop=True) for arg in args], **kwargs
            )

            # ... and add the coordinates back afterwards, as length-one dimensions, for merging
            dsets.append(
                result.expand_dims(
                    dim={dim: [value] for dim, value in selector.items()}
                )
            )

        # Merge the individual results into one object and return
        return xr.merge(dsets)

    return _apply_along_axes

80 

81 

82def _hist( 

83 obj, *, normalize: Union[bool, float] = None, **kwargs 

84) -> Tuple[np.ndarray, np.ndarray]: 

85 """Applies ``numpy.histogram`` along an axis of an object and returns the counts and bin centres. 

86 If specified, the counts are normalised. Returns the counts and bin centres (not bin edges!) 

87 :param obj: data to bin 

88 :param normalize: (optional) whether to normalise the counts; can be a boolean or a float, in which 

89 case the float is interpreted as the normalisation constant. 

90 :param kwargs: passed on to ``np.histogram`` 

91 :return: bin counts and bin centres 

92 """ 

93 _counts, _edges = np.histogram(obj, **kwargs) 

94 _counts = _counts.astype(float) 

95 _bin_centres = 0.5 * (_edges[:-1] + _edges[1:]) 

96 

97 # Normalise, if given 

98 if normalize: 

99 norm = np.nansum(_counts) 

100 norm = 1.0 if norm == 0.0 else norm 

101 _counts /= norm if isinstance(normalize, bool) else norm / normalize 

102 return _counts, _bin_centres 

103 

104 

def _get_hist_bins_ranges(ds, bins, ranges, axis):
    """Prepares the ``bins`` and ``ranges`` arguments so that they can be handed to a histogram function.
    ``xr.DataArray`` inputs are replaced by their underlying data, and ``None`` entries inside a given
    range are replaced by a minimum (first entry) or maximum (any other entry) taken from the data.

    :param ds: object to be binned; must provide a ``data`` attribute that can be indexed with ``axis``
    :param bins: ``bins`` argument to the histogramming function; if an ``xr.DataArray`` object, is replaced
        by its ``data``
    :param ranges: ``ranges`` argument to the histogramming function; if ``None``, it is returned unchanged,
        otherwise it is converted into a numpy array and its ``None`` entries are filled
    :param axis: index used to select the values from ``ds.data`` from which the minimum and maximum are taken
    :return: bins and ranges for the histogram
    """
    # Unwrap the bins from an xarray object, if required
    if isinstance(bins, xr.DataArray):
        bins = bins.data

    # No range given: nothing to fill
    if ranges is None:
        return bins, ranges

    # Work on a numpy copy of the range so that entries can be assigned
    if isinstance(ranges, xr.DataArray):
        ranges = ranges.data
    ranges = np.array(ranges)

    # Replace ``None`` bounds by the data extrema
    # NOTE(review): ``ds.data[axis]`` indexes the data with ``axis`` rather than reducing along it -- confirm
    for pos, bound in enumerate(ranges):
        if bound is not None:
            continue
        if pos == 0:
            ranges[pos] = np.min(ds.data[axis])
        else:
            ranges[pos] = np.max(ds.data[axis])

    return bins, ranges

137 

138 

def _interpolate(
    _p: xr.DataArray, _q: xr.DataArray
) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
    """Brings two one-dimensional densities ``_p`` and ``_q`` onto a common grid by linear interpolation.
    The grid spans the intersection of their supports, and its number of points is the sum of the
    number of points of the two individual grids. If both densities already share identical coordinates,
    their data is returned as it is. ``_p`` and ``_q`` must be one-dimensional.

    :param _p: first density; its first coordinate is used as its grid
    :param _q: second density; its first coordinate is used as its grid
    :return: the values of ``_p`` and ``_q`` on the common grid, and the common grid itself
    """
    # The grids are given by the first coordinate of each density
    p_dim = list(_p.coords.keys())[0]
    q_dim = list(_q.coords.keys())[0]
    p_grid, q_grid = _p.coords[p_dim], _q.coords[q_dim]

    # Nothing to interpolate if both densities already live on the same grid
    if len(p_grid.data) == len(q_grid.data) and all(p_grid.data == q_grid.data):
        return _p.data, _q.data, p_grid.data

    # Common support: from the larger of the two first points to the smaller of the two last points
    lower = np.max([p_grid[0].item(), q_grid[0].item()])
    upper = np.min([p_grid[-1].item(), q_grid[-1].item()])

    # Interpolate both densities on the intersection of their supports
    common_grid = np.linspace(lower, upper, len(p_grid) + len(q_grid))
    p_on_grid = np.interp(common_grid, p_grid, _p)
    q_on_grid = np.interp(common_grid, q_grid, _q)

    return p_on_grid, q_on_grid, common_grid