Coverage for tests/models/test_SIR_DataGeneration.py: 97%
31 statements
« prev ^ index » next coverage.py v7.6.1, created at 2024-12-05 17:26 +0000
1import sys
2from os.path import dirname as up
4import h5py as h5
5import pytest
6import torch
7from dantro._import_tools import import_module_from_path
8from pkg_resources import resource_filename
10from utopya.yaml import load_yml
# Repository root: three directory levels above this test file.
_REPO_ROOT = up(up(up(__file__)))

# Make the repository importable, then load the SIR model package from it.
sys.path.insert(0, _REPO_ROOT)
SIR = import_module_from_path(mod_path=_REPO_ROOT, mod_str="models.SIR")

# Load the test config
CFG_FILENAME = resource_filename("tests", "cfgs/SIR_DataGeneration.yml")
test_cfg = load_yml(CFG_FILENAME)
21# Test that ABM data and smooth data are generated
22def test_data_generation(tmpdir):
23 # Create an h5File in the temporary directory for the
24 h5dir = tmpdir.mkdir("hdf5_data")
26 for name, config in test_cfg.items():
27 h5file = h5.File(h5dir.join(f"test_{name}.h5"), "w")
28 h5group = h5file.create_group("SIR")
30 synthetic_data = SIR.get_SIR_data(
31 data_cfg=config, h5group=h5group, write_init_state=False
32 )
34 n = config["synthetic_data"]["num_steps"]
35 assert len(synthetic_data) == n
37 # Check the densities are consistent in the noiseless case, and non-negative in the noisy case
38 sigma = config["synthetic_data"].pop("sigma", 0)
39 if sigma == 0:
40 assert (
41 torch.round(torch.sum(synthetic_data, dim=1), decimals=4)
42 == torch.tensor([1.0])
43 ).all()
44 else:
45 assert (synthetic_data >= 0).all()
47 # Check infection has taken place
48 assert torch.max(synthetic_data, axis=0)[1][1] > synthetic_data[-1][1]
50 # Check recovery has taken place
51 assert synthetic_data[0][0] > synthetic_data[-1][0]
52 assert synthetic_data[-1][-1] > synthetic_data[0][-1]
54 # Check data has been written to the h5Group
55 if config["synthetic_data"]["type"] == "from_ABM":
56 assert list(h5group.keys()) == ["kinds", "position", "true_counts"]
57 elif config["synthetic_data"]["type"] == "smooth":
58 assert list(h5group.keys()) == ["true_counts"]