Coverage for tests/models/test_SIR_DataGeneration.py: 97%

31 statements  

« prev     ^ index     » next       coverage.py v7.6.1, created at 2024-12-05 17:26 +0000

1import sys 

2from os.path import dirname as up 

3 

4import h5py as h5 

5import pytest 

6import torch 

7from dantro._import_tools import import_module_from_path 

8from pkg_resources import resource_filename 

9 

10from utopya.yaml import load_yml 

11 

# Repository root: three levels above tests/models/test_SIR_DataGeneration.py.
# It is put on sys.path so that the ``models`` package can be imported from it.
_REPO_ROOT = up(up(up(__file__)))
sys.path.insert(0, _REPO_ROOT)

SIR = import_module_from_path(mod_path=_REPO_ROOT, mod_str="models.SIR")

# Load the test config: a mapping of run name -> data-generation config.
# NOTE(review): pkg_resources is deprecated upstream; consider migrating to
# importlib.resources once load_yml's accepted path types are confirmed.
CFG_FILENAME = resource_filename("tests", "cfgs/SIR_DataGeneration.yml")
test_cfg = load_yml(CFG_FILENAME)

19 

20 

def test_data_generation(tmpdir):
    """Test that ABM data and smooth data are generated correctly.

    For every ``(name, config)`` entry in ``test_cfg`` this generates SIR data
    via ``SIR.get_SIR_data``. The data is written into a fresh HDF5 file under
    ``tmpdir``. The test then checks that:

    - the returned data has ``config["synthetic_data"]["num_steps"]`` entries
      along its first dimension;
    - the compartment densities summed over dim 1 equal 1 (within 1e-4) when
      there is no noise (``sigma`` absent or 0), and are non-negative otherwise;
    - the peak of compartment 1 exceeds its final value;
    - compartment 0 decreases and the last compartment increases between the
      first and the last time step;
    - the expected datasets were written to the ``SIR`` HDF5 group.

    Assumes the compartment order along dim 1 is (S, I, R). TODO: confirm
    this against ``get_SIR_data``.
    """
    # Directory holding one HDF5 output file per test configuration
    h5dir = tmpdir.mkdir("hdf5_data")

    for name, config in test_cfg.items():
        # The context manager closes (and flushes) the file even when an
        # assertion below fails, so no handles leak across configurations.
        with h5.File(h5dir.join(f"test_{name}.h5"), "w") as h5file:
            h5group = h5file.create_group("SIR")

            synthetic_data = SIR.get_SIR_data(
                data_cfg=config, h5group=h5group, write_init_state=False
            )
            synth_cfg = config["synthetic_data"]

            n = synth_cfg["num_steps"]
            assert len(synthetic_data) == n

            # Densities are consistent in the noiseless case and non-negative
            # in the noisy case. Use .get rather than .pop so that the shared,
            # module-level test config is not mutated.
            sigma = synth_cfg.get("sigma", 0)
            if sigma == 0:
                totals = torch.sum(synthetic_data, dim=1)
                assert torch.allclose(totals, torch.ones_like(totals), atol=1e-4)
            else:
                assert (synthetic_data >= 0).all()

            # Check infection has taken place: the peak infected density must
            # exceed the final one. torch.max(..., dim=0) returns
            # (values, indices); the densities are in .values, not index [1].
            peak_densities = torch.max(synthetic_data, dim=0).values
            assert peak_densities[1] > synthetic_data[-1][1]

            # Check recovery has taken place
            assert synthetic_data[0][0] > synthetic_data[-1][0]
            assert synthetic_data[-1][-1] > synthetic_data[0][-1]

            # Check data has been written to the h5Group (the file must still
            # be open here, so this stays inside the ``with`` block)
            if synth_cfg["type"] == "from_ABM":
                assert list(h5group.keys()) == ["kinds", "position", "true_counts"]
            elif synth_cfg["type"] == "smooth":
                assert list(h5group.keys()) == ["true_counts"]