Coverage for tests/test_data_ops.py: 100% (163 statements)
coverage.py v7.6.1, created at 2024-12-05 17:26 +0000
import sys
from os.path import dirname as up

import numpy as np
import pytest
import scipy.integrate
import xarray as xr
from dantro._import_tools import import_module_from_path
from pkg_resources import resource_filename

from utopya.yaml import load_yml

sys.path.insert(0, up(up(up(__file__))))

ops = import_module_from_path(
    mod_path=up(up(up(__file__))), mod_str="model_plots.data_ops"
)

# Load the test config
CFG_FILENAME = resource_filename("tests", "cfgs/data_ops.yml")
test_cfg = load_yml(CFG_FILENAME)

# ----------------------------------------------------------------------------------------------------------------------
# DECORATOR
# ----------------------------------------------------------------------------------------------------------------------

def test_apply_along_dim():
    """Tests the apply_along_dim decorator"""

    @ops.apply_along_dim
    def _test_func(data):
        return xr.DataArray(np.mean(data), name=data.name)

    # Test on a DataArray
    da = xr.DataArray(
        np.ones((10, 10, 10)),
        dims=["x", "y", "z"],
        coords={k: np.arange(10) for k in ["x", "y", "z"]},
        name="foo",
    )

    res = _test_func(da, along_dim=["x"])

    # apply_along_dim returns a Dataset
    assert isinstance(res, xr.Dataset)

    # Assert the operation was applied along that dimension
    assert set(res.coords.keys()) == {"y", "z"}
    assert list(res.data_vars) == ["foo"]
    assert all(res["foo"].data.flatten() == 1)

    # Assert that applying along ``x`` is equivalent to excluding all other dimensions
    assert res == _test_func(da, exclude_dim=["y", "z"])

    # Apply along all dimensions
    res = _test_func(da, along_dim=["x", "y", "z"])
    assert res == xr.Dataset(data_vars=dict(foo=([], 1)))

    # Without ``along_dim`` or ``exclude_dim``, the decorator passes the call through unchanged
    res = _test_func(da)
    assert isinstance(res, type(da))
    assert res == xr.DataArray(1)

    # Test applying the decorator to a function taking multiple arguments
    @ops.apply_along_dim
    def _test_func_nD(da1, da2, da3, *, op: str = "sum"):
        if op == "sum":
            return (da1 * da2 * da3).sum()
        else:
            return (da1 * da2 * da3).mean()

    res = _test_func_nD(da, 2 * da, -1 * da, along_dim=["x"])
    assert set(res.coords.keys()) == {"y", "z"}
    assert all(res["foo"].data.flatten() == len(da.coords["x"]) * 2 * -1)
    assert res == _test_func_nD(da, 2 * da, -1 * da, exclude_dim=["y", "z"])

    # Test passing keyword arguments through to the wrapped function
    res = _test_func_nD(da, 2 * da, -1 * da, along_dim=["x"], op="mean")
    assert all(res["foo"].data.flatten() == 1 * 2 * -1)

    # Passing both ``along_dim`` and ``exclude_dim`` raises an error
    with pytest.raises(ValueError, match="Cannot provide both"):
        _test_func(da, along_dim=["x"], exclude_dim=["y"])

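
# Illustrative reference, not part of the original test suite: for the all-ones
# array used above, applying a mean along ``x`` is equivalent to the plain
# xarray reduction below, which is the behaviour the decorated ``_test_func``
# is expected to reproduce, packaged into an xr.Dataset. The helper name
# ``_plain_xarray_reference`` is hypothetical.
def _plain_xarray_reference():
    da = xr.DataArray(
        np.ones((10, 10, 10)),
        dims=["x", "y", "z"],
        coords={k: np.arange(10) for k in ["x", "y", "z"]},
        name="foo",
    )
    # Reducing along "x" leaves dims ("y", "z"); to_dataset gives a Dataset
    # with a single data variable "foo", matching the decorator output checked above
    reference = da.mean("x").to_dataset()
    assert set(reference.coords.keys()) == {"y", "z"}
    assert bool((reference["foo"] == 1).all())
    return reference
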
# ----------------------------------------------------------------------------------------------------------------------
# DATA RESHAPING AND REORGANIZING
# ----------------------------------------------------------------------------------------------------------------------

def test_concat():
    """Tests concatenation of multiple xarray objects"""

    da = xr.DataArray(
        np.ones((10, 10, 10)),
        dims=["x", "y", "z"],
        coords={k: np.arange(10) for k in ["x", "y", "z"]},
        name="foo",
    )

    # Test on DataArrays and Datasets
    for _obj in [da, xr.Dataset(dict(foo=da))]:
        res = ops.concat([_obj, _obj, _obj], "type", ["a", "b", "c"])
        assert set(res.coords.keys()) == set(list(_obj.coords.keys()) + ["type"])

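
# Illustrative sketch (an assumption, not the actual ``ops.concat``): the
# behaviour exercised above corresponds to xarray's own concat along a new
# dimension, followed by assigning explicit coordinate labels to that dimension.
def _concat_sketch(objs, dim_name, labels):
    combined = xr.concat(objs, dim=dim_name)  # creates the new dimension
    return combined.assign_coords({dim_name: labels})
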
def test_flatten_dims():
    """Tests flattening dimensions of xarray objects into a single dimension"""

    da = xr.DataArray(
        np.ones((5, 5, 5)),
        dims=["dim1", "dim2", "dim3"],
        coords={k: np.arange(5) for k in ["dim1", "dim2", "dim3"]},
    )

    # Test on DataArrays and Datasets
    for _obj in [da, xr.Dataset(dict(foo=da))]:
        res = ops.flatten_dims(_obj, dims={"dim2": ["dim2", "dim3"]})
        assert set(res.coords.keys()) == {"dim1", "dim2"}
        assert len(res.coords["dim2"]) == len(_obj.coords["dim2"]) * len(
            _obj.coords["dim3"]
        )

        # Test reassigning the coordinates of the flattened dimension
        res = ops.flatten_dims(
            _obj, dims={"dim2": ["dim2", "dim3"]}, new_coords=np.arange(25, 50, 1)
        )
        assert all(res.coords["dim2"] == np.arange(25, 50, 1))

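
# Illustrative sketch (an assumption, not the actual ``ops.flatten_dims``):
# flattening two dimensions into one can be expressed with xarray's ``stack``,
# dropping the resulting MultiIndex and renaming the stacked dimension. The
# helper name and default coordinates are hypothetical.
def _flatten_dims_sketch(obj, new_dim, dims_to_flatten, new_coords=None):
    tmp = "_stacked"
    flat = obj.stack({tmp: dims_to_flatten})
    flat = flat.reset_index(tmp, drop=True).rename({tmp: new_dim})
    if new_coords is None:
        new_coords = np.arange(flat.sizes[new_dim])
    return flat.assign_coords({new_dim: new_coords})
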
def test_broadcast():
    """Tests broadcasting xr.DataArrays into a single xr.Dataset"""
    da1 = xr.DataArray(np.random.rand(10, 3), dims=["sample", "parameter"])
    da2 = xr.DataArray(np.exp(-np.linspace(0, 1, 10)), dims=["sample"])
    res = ops.broadcast(da1, da2, x="x", p="loss")

    assert isinstance(res, xr.Dataset)
    assert set(res.data_vars) == {"x", "loss"}
    assert res == ops.broadcast(da2, da1)

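
# Illustrative sketch (an assumption, not the actual ``ops.broadcast``):
# combining two DataArrays of different dimensionality into a single Dataset
# with named variables can be done via xr.broadcast.
def _broadcast_sketch(da1, da2, *, x="x", p="loss"):
    b1, b2 = xr.broadcast(da1, da2)  # align and broadcast to common dims
    return xr.Dataset({x: b1, p: b2})
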
# ----------------------------------------------------------------------------------------------------------------------
# BASIC STATISTICS FUNCTIONS
# ----------------------------------------------------------------------------------------------------------------------

def test_stat_function():
    """Tests the statistics functions on a normal distribution"""
    _x = np.linspace(-5, 5, 1000)
    _m, _std = 0.0, 1.0
    _f = np.exp(-((_x - _m) ** 2) / (2 * _std**2))
    _norm = scipy.integrate.trapezoid(_f, _x)

    ds = xr.Dataset(dict(y=(["x"], _f)), coords=dict(x=_x))

    # Test normalisation
    ds = ops.normalize(ds, x="x", y="y")
    assert scipy.integrate.trapezoid(ds["y"], ds.coords["x"]) == pytest.approx(
        1.0, 1e-4
    )

    # Test the mean
    mean = ops.stat_function(ds, stat="mean", x="x", y="y")
    assert mean == pytest.approx(_m, abs=1e-3)

    # Test the mean with x as a coordinate or as a variable
    ds = xr.Dataset(dict(y=(["x"], ds["y"].data), x_val=ds.coords["x"]))
    assert mean == ops.stat_function(ds, stat="mean", x="x_val", y="y")
    assert mean == ops.stat_function(ds, stat="mean", x="x", y="y")

    # Test the standard deviation
    std = ops.stat_function(ds, stat="std", x="x", y="y")
    assert std == pytest.approx(_std, abs=1e-3)

    # Test the interquartile range
    iqr = ops.stat_function(ds, stat="iqr", x="x", y="y")
    assert iqr == pytest.approx(1.34, abs=1e-2)

    # Test the mode
    mode = ops.stat_function(ds, stat="mode", x="x", y="y")
    assert mode["mode_x"].data.item() == pytest.approx(_m, abs=1e-2)
    assert mode["mode_y"].data.item() == pytest.approx(1.0 / _norm, abs=1e-2)

    # Test the peak width calculation
    peak_widths = ops.stat_function(ds, stat="avg_peak_width", x="x", y="y", width=1)
    assert peak_widths["mean_peak_width"] == pytest.approx(2.355 * _std, abs=1e-2)
    assert peak_widths["peak_width_std"] == 0.0

    # Test the p-value calculation
    assert ops.p_value(ds, 0.0, x="x", y="y") == pytest.approx(0.5, abs=1e-1)
    assert ops.p_value(ds, -1, x="x", y="y") == pytest.approx(0.159, abs=5e-3)
    assert ops.p_value(ds, -1, x="x", y="y") == pytest.approx(
        ops.p_value(ds, +1, x="x", y="y"), abs=5e-3
    )
    assert ops.p_value(ds, xr.DataArray(0.0), x="x", y="y") == ops.p_value(
        ds, 0.0, x="x", y="y"
    )

    # Assert that the p-value for a Gaussian with respect to the mean is the same
    # as with respect to the mode
    assert ops.p_value(ds, -1, x="x", y="y", null="mean") == ops.p_value(
        ds, -1, x="x", y="y", null="mode"
    )

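
# Worked reference for the expected values above (not part of the original test
# suite): for a density f(x), the mean and standard deviation follow from the
# usual integral definitions, evaluated here with trapezoidal quadrature. The
# full width at half maximum of a Gaussian is 2 * sqrt(2 * ln 2) * std
# ≈ 2.355 * std, which is the peak width checked above.
def _density_stats_reference(x, f):
    norm = scipy.integrate.trapezoid(f, x)
    p = f / norm  # normalised density
    mean = scipy.integrate.trapezoid(x * p, x)
    std = np.sqrt(scipy.integrate.trapezoid((x - mean) ** 2 * p, x))
    return mean, std
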
# ----------------------------------------------------------------------------------------------------------------------
# HISTOGRAMS
# ----------------------------------------------------------------------------------------------------------------------

def test_hist():
    """Tests the histogram functions"""

    _n_samples = 100
    _n_vals = 100
    _n_bins = 20

    # Test histogramming a 1D array
    da = xr.DataArray(
        np.random.rand(_n_vals), dims=["i"], coords=dict(i=np.arange(_n_vals))
    )
    hist = ops.hist(da, _n_bins, dim="i")
    assert set(hist.coords.keys()) == {"bin_idx"}
    assert set(hist.data_vars) == {"count", "x"}

    # Test that the total number of counts has not changed
    assert all(hist["count"].sum("bin_idx").data.flatten() == _n_vals)

    # Repeat the same thing, this time using the bin centres as coordinates
    hist = ops.hist(da, _n_bins, [0, 1], dim="i", use_bins_as_coords=True)
    assert isinstance(hist, xr.DataArray)
    assert set(hist.coords.keys()) == {"x"}
    assert all(
        hist.coords["x"].data
        == 0.5
        * (np.linspace(0, 1, _n_bins + 1)[1:] + np.linspace(0, 1, _n_bins + 1)[:-1])
    )
    assert all(hist.sum("x").data.flatten() == _n_vals)

    # Test histogramming a 2D array
    da = xr.DataArray(
        np.random.rand(_n_samples, _n_vals),
        dims=["sample", "i"],
        coords=dict(sample=np.arange(_n_samples), i=np.arange(_n_vals)),
    )
    hist = ops.hist(da, _n_bins, dim="i")
    assert set(hist.coords.keys()) == {"sample", "bin_idx"}
    assert set(hist.data_vars) == {"count", "x"}
    assert all(hist["count"].sum("bin_idx").data.flatten() == _n_vals)

    hist = ops.hist(da, _n_bins, dim="i", ranges=[0, 1], use_bins_as_coords=True)
    assert set(hist.coords.keys()) == {"sample", "x"}
    assert all(hist.sum("x").data.flatten() == _n_vals)

    # Test histogramming a 3D array
    da = da.expand_dims(dict(dim0=np.arange(4)))
    hist = ops.hist(
        da,
        dim="i",
        exclude_dim=["dim0"],
        bins=_n_bins,
        ranges=[0, 1],
        use_bins_as_coords=True,
    )
    assert set(hist.coords.keys()) == {"sample", "x", "dim0"}

    hist = ops.hist(da, dim="i", exclude_dim=["dim0"], bins=_n_bins)
    assert set(hist.coords.keys()) == {"sample", "bin_idx", "dim0"}

    # Test normalisation of the bin counts
    hist_normalised = ops.hist(da, bins=100, dim="i", ranges=[0, 1], normalize=2.0)
    assert hist_normalised["count"].sum("bin_idx").data.flatten() == pytest.approx(
        2.0, abs=1e-10
    )

    # Test restricting the range: all samples lie in [0, 1), so a [1, 2] range
    # produces no counts
    hist_clipped = ops.hist(da, bins=100, dim="i", ranges=[1, 2])
    assert all(hist_clipped["count"].data.flatten() == 0)

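
# Illustrative reference (an assumption, not the actual ``ops.hist``): the
# bin-centre coordinates asserted above are simply the midpoints of
# np.histogram's bin edges over the given range.
def _bin_centres_reference(data, bins, low, high):
    counts, edges = np.histogram(data, bins=bins, range=(low, high))
    centres = 0.5 * (edges[1:] + edges[:-1])
    return counts, centres
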
# ----------------------------------------------------------------------------------------------------------------------
# PROBABILITY DENSITY OPERATIONS
# ----------------------------------------------------------------------------------------------------------------------

def test_joint_and_marginal_2D():
    """Tests that two-dimensional joint distributions, and the marginals obtained
    from them, are correctly calculated"""

    # Generate a 2D Gaussian on a square domain
    _lower, _upper, _n_vals = -5, +5, 1000
    _bins = 50
    _m, _std = 0.0, 1.0
    _x, _y = (_upper - _lower) * np.random.rand(_n_vals) + _lower, (
        _upper - _lower
    ) * np.random.rand(_n_vals) + _lower
    _f = np.exp(-((_x - _m) ** 2) / (2 * _std**2)) * np.exp(
        -((_y - _m) ** 2) / (2 * _std**2)
    )

    # Calculate the joint distribution
    joint = ops.joint_2D(_x, _y, _f, bins=_bins, normalize=1.0)
    assert set(joint.coords.keys()) == {"x", "y"}

    # Assert the maximum of the joint is roughly in the middle of the domain
    assert all(
        [
            idx == pytest.approx(25, abs=2)
            for idx in np.unravel_index(joint.argmax(), joint.shape)
        ]
    )

    # Assert the marginal distribution of each dimension is normalised
    marginal_x = ops.marginal_from_joint(joint, parameter="x")
    marginal_y = ops.marginal_from_joint(joint, parameter="y")
    assert scipy.integrate.trapezoid(marginal_x["y"], marginal_x["x"]) == pytest.approx(
        1.0, abs=1e-4
    )
    assert scipy.integrate.trapezoid(marginal_y["y"], marginal_y["x"]) == pytest.approx(
        1.0, abs=1e-4
    )

    # Assert the marginals are again approximately Gaussian
    assert ops.mean(marginal_x, x="x", y="y") == pytest.approx(0.0, abs=1e-1)
    assert ops.mean(marginal_y, x="x", y="y") == pytest.approx(0.0, abs=1e-1)
    assert ops.std(marginal_x, x="x", y="y") == pytest.approx(1.0, abs=5e-2)
    assert ops.std(marginal_y, x="x", y="y") == pytest.approx(1.0, abs=5e-2)

    # Assert the alternative joint operation on an xr.Dataset gives the same result
    joint_from_ds = ops.joint_2D_ds(
        xr.Dataset(dict(x=_x, y=_y)), _f, x="x", y="y", bins=_bins, normalize=1.0
    )
    assert joint_from_ds.equals(joint)

    # Test the three-dimensional joint distribution
    _z = (_upper - _lower) * np.random.rand(_n_vals) + _lower
    samples = (
        ops.concat(
            [xr.DataArray(_x), xr.DataArray(_y), xr.DataArray(_z)],
            "parameter",
            ["a", "b", "c"],
        )
        .transpose()
        .assign_coords(dict(dim_0=np.arange(_n_vals)))
    )

    _f = xr.DataArray(_f * np.exp(-((_z - _m) ** 2) / (2 * _std**2))).assign_coords(
        dict(dim_0=np.arange(_n_vals))
    )

    joint_3D = ops.joint_DD(samples, _f, bins=50)
    assert set(joint_3D.coords) == {"a", "b", "c"}

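
# Illustrative sketch (an assumption, not the actual ``ops.joint_2D``): a
# weighted 2D histogram of the samples, normalised to unit mass, is one way to
# obtain a joint density of the kind tested above.
def _joint_2D_reference(x, y, weights, bins):
    counts, x_edges, y_edges = np.histogram2d(x, y, bins=bins, weights=weights)
    dx = np.diff(x_edges)[:, None]
    dy = np.diff(y_edges)[None, :]
    density = counts / (counts * dx * dy).sum()  # integrates to 1 over the domain
    return xr.DataArray(
        density,
        dims=["x", "y"],
        coords=dict(
            x=0.5 * (x_edges[1:] + x_edges[:-1]),
            y=0.5 * (y_edges[1:] + y_edges[:-1]),
        ),
    )
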
def test_distances_between_densities():
    """Tests the distance measures (Hellinger distance, relative entropy, Lp
    distance) between distributions"""
    _x = np.linspace(-5, 5, 500)
    Gaussian = xr.DataArray(
        np.exp(-(_x**2) / 2), dims=["x"], coords=dict(x=_x)
    )  # mean = 0, std = 1
    Gaussian /= scipy.integrate.trapezoid(Gaussian.data, _x)

    assert ops.Hellinger_distance(Gaussian, Gaussian) == 0
    assert ops.relative_entropy(Gaussian, Gaussian) == 0
    assert ops.Lp_distance(Gaussian, Gaussian, p=2) == 0

    # Test calculating the distances on an xr.Dataset instead
    Gaussian = xr.Dataset(dict(_x=Gaussian.coords["x"], y=Gaussian))
    assert ops.Hellinger_distance(Gaussian, Gaussian, x="_x", y="y") == 0
    Gaussian = xr.Dataset(dict(y=Gaussian["y"]))
    assert ops.Hellinger_distance(Gaussian, Gaussian, x="x", y="y") == 0

    _x = np.linspace(0, 5, 500)
    Uniform1 = xr.DataArray(
        np.array([1 if 1 <= x <= 3 else 0 for x in _x]), dims=["x"], coords=dict(x=_x)
    )
    Uniform2 = xr.DataArray(
        np.array([1 if 2 <= x <= 4 else 0 for x in _x]), dims=["x"], coords=dict(x=_x)
    )

    # The total area over which the two do not overlap is 2, so the Hellinger
    # distance is 1/2 * 2 = 1
    assert ops.Hellinger_distance(Uniform1, Uniform2) == pytest.approx(1.0, abs=1e-2)

    # Test that interpolation works: shifted uniform distributions with an area of
    # non-overlap of 3, so the Hellinger distance is 3/2
    Uniform1 = xr.DataArray(
        np.array([1 if 2 <= x <= 4 else 0 for x in np.linspace(1, 6, 500)]),
        dims=["x"],
        coords=dict(x=np.linspace(1, 6, 500)),
    )
    Uniform2 = xr.DataArray(
        np.array([1 if 3 <= x else 0 for x in np.linspace(2, 7, 750)]),
        dims=["x"],
        coords=dict(x=np.linspace(2, 7, 750)),
    )

    assert ops.Hellinger_distance(Uniform1, Uniform2) == pytest.approx(1.5, abs=1e-2)
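
# Worked reference (an assumption about the convention used here, inferred from
# the comments in the test above): the Hellinger distance values checked above
# are consistent with the integral 1/2 * \int (sqrt(p) - sqrt(q))^2 dx evaluated
# on a shared grid, e.g. the non-overlapping unit-height uniform densities give
# 1/2 * 2 = 1 and 1/2 * 3 = 3/2.
def _hellinger_reference(p, q, x):
    return 0.5 * scipy.integrate.trapezoid((np.sqrt(p) - np.sqrt(q)) ** 2, x)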