Coverage for tests/test_data_ops.py: 100%

163 statements  

coverage.py v7.6.1, created at 2024-12-05 17:26 +0000

import sys
from os.path import dirname as up

import numpy as np
import pytest
import scipy.integrate
import xarray as xr
from dantro._import_tools import import_module_from_path
from pkg_resources import resource_filename

from utopya.yaml import load_yml

sys.path.insert(0, up(up(up(__file__))))

ops = import_module_from_path(
    mod_path=up(up(up(__file__))), mod_str="model_plots.data_ops"
)

# Load the test config
CFG_FILENAME = resource_filename("tests", "cfgs/data_ops.yml")
test_cfg = load_yml(CFG_FILENAME)


# ----------------------------------------------------------------------------------------------------------------------
# DECORATOR
# ----------------------------------------------------------------------------------------------------------------------


def test_apply_along_dim():
    """Tests the apply_along_dim decorator; an illustrative cross-check using
    stock xarray is sketched below this test."""

    @ops.apply_along_dim
    def _test_func(data):
        return xr.DataArray(np.mean(data), name=data.name)

    # Test on a DataArray
    da = xr.DataArray(
        np.ones((10, 10, 10)),
        dims=["x", "y", "z"],
        coords={k: np.arange(10) for k in ["x", "y", "z"]},
        name="foo",
    )

    res = _test_func(da, along_dim=["x"])

    # apply_along_dim returns a Dataset
    assert isinstance(res, xr.Dataset)

    # Assert the operation was applied along that dimension
    assert set(res.coords.keys()) == {"y", "z"}
    assert list(res.data_vars) == ["foo"]
    assert all(res["foo"].data.flatten() == 1)

    # Assert that specifying ``along_dim`` is the same as excluding the
    # complementary dimensions
    assert res == _test_func(da, exclude_dim=["y", "z"])

    # Apply along all dimensions
    res = _test_func(da, along_dim=["x", "y", "z"])
    assert res == xr.Dataset(data_vars=dict(foo=([], 1)))

    # Test calling the decorated function without ``along_dim`` or ``exclude_dim``
    res = _test_func(da)
    assert isinstance(res, type(da))
    assert res == xr.DataArray(1)

    # Test applying the decorator to a function of multiple arguments
    @ops.apply_along_dim
    def _test_func_nD(da1, da2, da3, *, op: str = "sum"):
        if op == "sum":
            return (da1 * da2 * da3).sum()
        else:
            return (da1 * da2 * da3).mean()

    res = _test_func_nD(da, 2 * da, -1 * da, along_dim=["x"])
    assert set(res.coords.keys()) == {"y", "z"}
    assert all(res["foo"].data.flatten() == len(da.coords["x"]) * 2 * -1)
    assert res == _test_func_nD(da, 2 * da, -1 * da, exclude_dim=["y", "z"])

    # Test passing kwargs to the function
    res = _test_func_nD(da, 2 * da, -1 * da, along_dim=["x"], op="mean")
    assert all(res["foo"].data.flatten() == 1 * 2 * -1)

    # Test that passing both ``along_dim`` and ``exclude_dim`` raises an error
    with pytest.raises(ValueError, match="Cannot provide both"):
        _test_func(da, along_dim=["x"], exclude_dim=["y"])
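

# --- Illustrative sketch, not part of ``model_plots.data_ops`` --------------
# For the particular mean operation used above, the decorated call
# ``_test_func(da, along_dim=["x"])`` should agree with xarray's native
# reduction wrapped into a Dataset. This helper only documents that
# correspondence; it is not how ``apply_along_dim`` is implemented, and it is
# never called by the tests.
def _native_mean_along_x(da: xr.DataArray) -> xr.Dataset:
    """Reduces over ``x`` with stock xarray, mimicking the output structure
    asserted in ``test_apply_along_dim``."""
    return da.mean(dim="x").to_dataset(name=da.name)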



# ----------------------------------------------------------------------------------------------------------------------
# DATA RESHAPING AND REORGANIZING
# ----------------------------------------------------------------------------------------------------------------------
def test_concat():
    """Tests concatenation of multiple xarray objects; a stock-xarray reference
    construction is sketched below this test."""

    da = xr.DataArray(
        np.ones((10, 10, 10)),
        dims=["x", "y", "z"],
        coords={k: np.arange(10) for k in ["x", "y", "z"]},
        name="foo",
    )

    # Test on DataArrays and Datasets
    for _obj in [da, xr.Dataset(dict(foo=da))]:
        res = ops.concat([_obj, _obj, _obj], "type", ["a", "b", "c"])
        assert set(res.coords.keys()) == set(list(_obj.coords.keys()) + ["type"])
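

# --- Illustrative sketch, not part of ``model_plots.data_ops`` --------------
# A reference construction with stock xarray/pandas that yields the same kind
# of object asserted above: the inputs concatenated along a new, named
# coordinate. Whether ``ops.concat`` is implemented this way is not shown in
# this file.
def _reference_concat(objs: list, dim_name: str, labels: list):
    import pandas as pd

    return xr.concat(objs, dim=pd.Index(labels, name=dim_name))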



def test_flatten_dims():
    """Tests flattening multiple coordinate dimensions of xarray objects into
    one; a stack-based reference is sketched below this test."""

    da = xr.DataArray(
        np.ones((5, 5, 5)),
        dims=["dim1", "dim2", "dim3"],
        coords={k: np.arange(5) for k in ["dim1", "dim2", "dim3"]},
    )

    # Test on DataArrays and Datasets
    for _obj in [da, xr.Dataset(dict(foo=da))]:
        res = ops.flatten_dims(_obj, dims={"dim2": ["dim2", "dim3"]})
        assert set(res.coords.keys()) == {"dim1", "dim2"}
        assert len(res.coords["dim2"]) == len(_obj.coords["dim2"]) * len(
            _obj.coords["dim3"]
        )

        # Test reassigning coordinates
        res = ops.flatten_dims(
            _obj, dims={"dim2": ["dim2", "dim3"]}, new_coords=np.arange(25, 50, 1)
        )
        assert all(res.coords["dim2"] == np.arange(25, 50, 1))
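

# --- Illustrative sketch, not part of ``model_plots.data_ops`` --------------
# The size bookkeeping asserted above can be reproduced with xarray's native
# ``stack``: collapsing two dimensions of length 5 gives one of length 25.
# ``ops.flatten_dims`` additionally renames the flattened dimension and can
# reassign its coordinates, which plain ``stack`` does not do by itself.
def _reference_flatten(da: xr.DataArray) -> xr.DataArray:
    stacked = da.stack(dim2_and_dim3=("dim2", "dim3"))
    assert stacked.sizes["dim2_and_dim3"] == da.sizes["dim2"] * da.sizes["dim3"]
    return stacked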



def test_broadcast():
    """Tests broadcasting xr.DataArray objects into a single Dataset; a
    stock-xarray reference is sketched below this test."""
    da1 = xr.DataArray(np.random.rand(10, 3), dims=["sample", "parameter"])
    da2 = xr.DataArray(np.exp(-np.linspace(0, 1, 10)), dims=["sample"])
    res = ops.broadcast(da1, da2, x="x", p="loss")

    assert isinstance(res, xr.Dataset)
    assert set(res.data_vars) == {"x", "loss"}
    assert res == ops.broadcast(da2, da1)
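

# --- Illustrative sketch, not part of ``model_plots.data_ops`` --------------
# A stock-xarray construction satisfying the same assertions as above:
# broadcast both arrays to the union of their dimensions, then collect them as
# named variables of one Dataset. The actual ``ops.broadcast`` implementation
# may differ.
def _reference_broadcast(da1: xr.DataArray, da2: xr.DataArray) -> xr.Dataset:
    x, loss = xr.broadcast(da1, da2)
    return xr.Dataset(dict(x=x, loss=loss))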



# ----------------------------------------------------------------------------------------------------------------------
# BASIC STATISTICS FUNCTIONS
# ----------------------------------------------------------------------------------------------------------------------
def test_stat_function():
    """Tests the statistics functions on a normal distribution; the asserted
    reference constants are recomputed in the sketch below this test."""
    _x = np.linspace(-5, 5, 1000)
    _m, _std = 0.0, 1.0
    _f = np.exp(-((_x - _m) ** 2) / (2 * _std**2))
    _norm = scipy.integrate.trapezoid(_f, _x)

    ds = xr.Dataset(dict(y=(["x"], _f)), coords=dict(x=_x))

    # Test normalisation
    ds = ops.normalize(ds, x="x", y="y")
    assert scipy.integrate.trapezoid(ds["y"], ds.coords["x"]) == pytest.approx(
        1.0, 1e-4
    )

    # Test mean
    mean = ops.stat_function(ds, stat="mean", x="x", y="y")
    assert mean == pytest.approx(_m, abs=1e-3)

    # Test mean with x as a coordinate or as a variable
    ds = xr.Dataset(dict(y=(["x"], ds["y"].data), x_val=ds.coords["x"]))
    assert mean == ops.stat_function(ds, stat="mean", x="x_val", y="y")
    assert mean == ops.stat_function(ds, stat="mean", x="x", y="y")

    # Test standard deviation
    std = ops.stat_function(ds, stat="std", x="x", y="y")
    assert std == pytest.approx(_std, abs=1e-3)

    # Test interquartile range
    iqr = ops.stat_function(ds, stat="iqr", x="x", y="y")
    assert iqr == pytest.approx(1.34, abs=1e-2)

    # Test mode
    mode = ops.stat_function(ds, stat="mode", x="x", y="y")
    assert mode["mode_x"].data.item() == pytest.approx(_m, abs=1e-2)
    assert mode["mode_y"].data.item() == pytest.approx(1.0 / _norm, abs=1e-2)

    # Test peak width calculation
    peak_widths = ops.stat_function(ds, stat="avg_peak_width", x="x", y="y", width=1)
    assert peak_widths["mean_peak_width"] == pytest.approx(2.355 * _std, abs=1e-2)
    assert peak_widths["peak_width_std"] == 0.0

    # Test p-value calculation
    assert ops.p_value(ds, 0.0, x="x", y="y") == pytest.approx(0.5, abs=1e-1)
    assert ops.p_value(ds, -1, x="x", y="y") == pytest.approx(0.159, abs=5e-3)
    assert ops.p_value(ds, -1, x="x", y="y") == pytest.approx(
        ops.p_value(ds, +1, x="x", y="y"), abs=5e-3
    )
    assert ops.p_value(ds, xr.DataArray(0.0), x="x", y="y") == ops.p_value(
        ds, 0.0, x="x", y="y"
    )

    # Assert the p-value for a Gaussian w.r.t. the mean is the same as w.r.t.
    # the mode
    assert ops.p_value(ds, -1, x="x", y="y", null="mean") == ops.p_value(
        ds, -1, x="x", y="y", null="mode"
    )
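

# --- Illustrative cross-check, not part of ``model_plots.data_ops`` ---------
# The literal constants asserted above follow from standard identities of the
# unit-variance normal distribution; recomputing them makes the magic numbers
# (IQR ≈ 1.34, peak width ≈ 2.355, tail mass ≈ 0.159) traceable. The helper is
# a sketch for reference only and is never called by the tests.
def _gaussian_reference_values() -> dict:
    from scipy.stats import norm

    return dict(
        iqr=norm.ppf(0.75) - norm.ppf(0.25),  # interquartile range ≈ 1.349
        fwhm=2 * np.sqrt(2 * np.log(2)),  # full width at half maximum ≈ 2.355
        tail_mass_below_minus_1=norm.cdf(-1.0),  # one-sided p-value ≈ 0.159
    )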



# ----------------------------------------------------------------------------------------------------------------------
# HISTOGRAMS
# ----------------------------------------------------------------------------------------------------------------------
def test_hist():
    """Tests the histogram functions; a plain-numpy cross-check is sketched
    below this test."""

    _n_samples = 100
    _n_vals = 100
    _n_bins = 20

    # Test histogramming a 1D array
    da = xr.DataArray(
        np.random.rand(_n_vals), dims=["i"], coords=dict(i=np.arange(_n_vals))
    )
    hist = ops.hist(da, _n_bins, dim="i")
    assert set(hist.coords.keys()) == {"bin_idx"}
    assert set(hist.data_vars) == {"count", "x"}

    # Test that the total number of counts has not changed
    assert all(hist["count"].sum("bin_idx").data.flatten() == _n_vals)

    # Repeat the same thing, this time using the bin centres as coordinates
    hist = ops.hist(da, _n_bins, [0, 1], dim="i", use_bins_as_coords=True)
    assert isinstance(hist, xr.DataArray)
    assert set(hist.coords.keys()) == {"x"}
    assert all(
        hist.coords["x"].data
        == 0.5
        * (np.linspace(0, 1, _n_bins + 1)[1:] + np.linspace(0, 1, _n_bins + 1)[:-1])
    )
    assert all(hist.sum("x").data.flatten() == _n_vals)

    # Test histogramming a 2D array
    da = xr.DataArray(
        np.random.rand(_n_samples, _n_vals),
        dims=["sample", "i"],
        coords=dict(sample=np.arange(_n_samples), i=np.arange(_n_vals)),
    )
    hist = ops.hist(da, _n_bins, dim="i")
    assert set(hist.coords.keys()) == {"sample", "bin_idx"}
    assert set(hist.data_vars) == {"count", "x"}
    assert all(hist["count"].sum("bin_idx").data.flatten() == _n_vals)

    hist = ops.hist(da, _n_bins, dim="i", ranges=[0, 1], use_bins_as_coords=True)
    assert set(hist.coords.keys()) == {"sample", "x"}
    assert all(hist.sum("x").data.flatten() == _n_vals)

    # Test histogramming a 3D array
    da = da.expand_dims(dict(dim0=np.arange(4)))
    hist = ops.hist(
        da,
        dim="i",
        exclude_dim=["dim0"],
        bins=_n_bins,
        ranges=[0, 1],
        use_bins_as_coords=True,
    )
    assert set(hist.coords.keys()) == {"sample", "x", "dim0"}

    hist = ops.hist(da, dim="i", exclude_dim=["dim0"], bins=_n_bins)
    assert set(hist.coords.keys()) == {"sample", "bin_idx", "dim0"}

    # Test normalisation of bin counts
    hist_normalised = ops.hist(da, bins=100, dim="i", ranges=[0, 1], normalize=2.0)
    assert hist_normalised["count"].sum("bin_idx").data.flatten() == pytest.approx(
        2.0, abs=1e-10
    )

    # Test restricting the range: no data lies in [1, 2], so all counts are zero
    hist_clipped = ops.hist(da, bins=100, dim="i", ranges=[1, 2])
    assert all(hist_clipped["count"].data.flatten() == 0)
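

# --- Illustrative cross-check, not part of ``model_plots.data_ops`` ---------
# The bin-centre coordinates and the conservation of counts asserted above
# match what a plain numpy histogram gives; this only spells out the
# expectation and is not a statement about how ``ops.hist`` is implemented.
def _reference_hist_1d(values: np.ndarray, n_bins: int, low: float, high: float):
    counts, edges = np.histogram(values, bins=n_bins, range=(low, high))
    centres = 0.5 * (edges[1:] + edges[:-1])
    # Counts are conserved as long as all values fall inside [low, high]
    assert counts.sum() == len(values)
    return counts, centres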



# ----------------------------------------------------------------------------------------------------------------------
# PROBABILITY DENSITY OPERATIONS
# ----------------------------------------------------------------------------------------------------------------------


def test_joint_and_marginal_2D():
    """Tests that two-dimensional joint distributions, and the marginals
    obtained from them, are correctly calculated; a reference construction is
    sketched below this test."""

    # Generate a 2D Gaussian on a square domain
    _lower, _upper, _n_vals = -5, +5, 1000
    _bins = 50
    _m, _std = 0.0, 1.0
    _x, _y = (_upper - _lower) * np.random.rand(_n_vals) + _lower, (
        _upper - _lower
    ) * np.random.rand(_n_vals) + _lower
    _f = np.exp(-((_x - _m) ** 2) / (2 * _std**2)) * np.exp(
        -((_y - _m) ** 2) / (2 * _std**2)
    )

    # Calculate the joint distribution
    joint = ops.joint_2D(_x, _y, _f, bins=_bins, normalize=1.0)
    assert set(joint.coords.keys()) == {"x", "y"}

    # Assert the maximum of the joint is roughly in the middle
    assert all(
        [
            idx == pytest.approx(25, abs=2)
            for idx in np.unravel_index(joint.argmax(), joint.shape)
        ]
    )

    # Assert the marginal distribution of each dimension is normalized
    marginal_x = ops.marginal_from_joint(joint, parameter="x")
    marginal_y = ops.marginal_from_joint(joint, parameter="y")
    assert scipy.integrate.trapezoid(marginal_x["y"], marginal_x["x"]) == pytest.approx(
        1.0, abs=1e-4
    )
    assert scipy.integrate.trapezoid(marginal_y["y"], marginal_y["x"]) == pytest.approx(
        1.0, abs=1e-4
    )

    # Assert the marginals are again approximately Gaussian
    assert ops.mean(marginal_x, x="x", y="y") == pytest.approx(0.0, abs=1e-1)
    assert ops.mean(marginal_y, x="x", y="y") == pytest.approx(0.0, abs=1e-1)
    assert ops.std(marginal_x, x="x", y="y") == pytest.approx(1.0, abs=5e-2)
    assert ops.std(marginal_y, x="x", y="y") == pytest.approx(1.0, abs=5e-2)

    # Assert the alternative joint operation does the same thing
    joint_from_ds = ops.joint_2D_ds(
        xr.Dataset(dict(x=_x, y=_y)), _f, x="x", y="y", bins=_bins, normalize=1.0
    )
    assert joint_from_ds.equals(joint)

    # Test the three-dimensional joint
    _z = (_upper - _lower) * np.random.rand(_n_vals) + _lower
    samples = (
        ops.concat(
            [xr.DataArray(_x), xr.DataArray(_y), xr.DataArray(_z)],
            "parameter",
            ["a", "b", "c"],
        )
        .transpose()
        .assign_coords(dict(dim_0=np.arange(_n_vals)))
    )

    _f = xr.DataArray(_f * np.exp(-((_z - _m) ** 2) / (2 * _std**2))).assign_coords(
        dict(dim_0=np.arange(_n_vals))
    )

    joint_3D = ops.joint_DD(samples, _f, bins=50)
    assert set(joint_3D.coords) == {"a", "b", "c"}
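

# --- Illustrative sketch, not part of ``model_plots.data_ops`` --------------
# A binned joint density of the kind tested above can be approximated with
# numpy's weighted 2D histogram, followed by the same normalisation to unit
# integral that ``normalize=1.0`` asserts for ``ops.joint_2D``. This is a
# reference construction only, not the module's implementation.
def _reference_joint_2d(x, y, f, bins: int = 50) -> xr.DataArray:
    counts, x_edges, y_edges = np.histogram2d(x, y, bins=bins, weights=f)
    x_c = 0.5 * (x_edges[1:] + x_edges[:-1])
    y_c = 0.5 * (y_edges[1:] + y_edges[:-1])
    joint = xr.DataArray(counts, dims=["x", "y"], coords=dict(x=x_c, y=y_c))
    # Normalise so that the joint integrates to 1 over both dimensions
    norm = scipy.integrate.trapezoid(scipy.integrate.trapezoid(joint.data, y_c), x_c)
    return joint / norm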



def test_distances_between_densities():
    """Tests the distance metrics between distributions (Hellinger distance,
    relative entropy, Lp distance); the Hellinger convention used in the
    comments is sketched below this test."""
    _x = np.linspace(-5, 5, 500)
    Gaussian = xr.DataArray(
        np.exp(-(_x**2) / 2), dims=["x"], coords=dict(x=_x)
    )  # mean = 0, std = 1
    Gaussian /= scipy.integrate.trapezoid(Gaussian.data, _x)

    assert ops.Hellinger_distance(Gaussian, Gaussian) == 0
    assert ops.relative_entropy(Gaussian, Gaussian) == 0
    assert ops.Lp_distance(Gaussian, Gaussian, p=2) == 0

    # Test calculating the distances on an xr.Dataset instead
    Gaussian = xr.Dataset(dict(_x=Gaussian.coords["x"], y=Gaussian))
    assert ops.Hellinger_distance(Gaussian, Gaussian, x="_x", y="y") == 0
    Gaussian = xr.Dataset(dict(y=Gaussian["y"]))
    assert ops.Hellinger_distance(Gaussian, Gaussian, x="x", y="y") == 0

    _x = np.linspace(0, 5, 500)
    Uniform1 = xr.DataArray(
        np.array([1 if 1 <= x <= 3 else 0 for x in _x]), dims=["x"], coords=dict(x=_x)
    )
    Uniform2 = xr.DataArray(
        np.array([1 if 2 <= x <= 4 else 0 for x in _x]), dims=["x"], coords=dict(x=_x)
    )

    # The total area where the two do not overlap is 2, so the Hellinger
    # distance is 1/2 * 2 = 1
    assert ops.Hellinger_distance(Uniform1, Uniform2) == pytest.approx(1.0, abs=1e-2)

    # Test that interpolation works: shifted uniform distributions with an area
    # of non-overlap of 3, so the Hellinger distance is 3/2
    Uniform1 = xr.DataArray(
        np.array([1 if 2 <= x <= 4 else 0 for x in np.linspace(1, 6, 500)]),
        dims=["x"],
        coords=dict(x=np.linspace(1, 6, 500)),
    )
    Uniform2 = xr.DataArray(
        np.array([1 if 3 <= x else 0 for x in np.linspace(2, 7, 750)]),
        dims=["x"],
        coords=dict(x=np.linspace(2, 7, 750)),
    )

    assert ops.Hellinger_distance(Uniform1, Uniform2) == pytest.approx(1.5, abs=1e-2)
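

# --- Illustrative cross-check, not part of ``model_plots.data_ops`` ---------
# The values asserted above are consistent with the convention
#
#     Hellinger(p, q) = 1/2 * ∫ (sqrt(p) - sqrt(q))**2 dx
#
# evaluated on a shared grid; the interpolation onto a common grid that the
# tested operation performs for mismatched coordinates is omitted here. Sketch
# for reference only; the module's exact definition is not shown in this file.
def _reference_hellinger(p: xr.DataArray, q: xr.DataArray) -> float:
    integrand = (np.sqrt(p) - np.sqrt(q)) ** 2
    return 0.5 * float(scipy.integrate.trapezoid(integrand.data, p.coords["x"].data))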