Coverage for tests/models/test_data_loading.py: 88%

25 statements  

« prev     ^ index     » next       coverage.py v7.6.1, created at 2024-12-05 17:26 +0000

1import sys 

2from os.path import dirname as up 

3 

4from dantro._import_tools import import_module_from_path 

5from dantro._yaml import load_yml 

6from pkg_resources import resource_filename 

7 

8from utopya.testtools import ModelTest 

9 

# Repository root: ``dirname`` applied three times to this file's path
# (file -> tests/models -> tests -> repository root)
_REPO_ROOT = up(up(up(__file__)))

# Put the repository root first on the import path so that the ``models``
# and ``tests`` packages can be found
sys.path.insert(0, _REPO_ROOT)

# NOTE(review): SIR and HW are not referenced in the visible code; presumably
# imported for their import-time side effects or by other tests -- confirm.
SIR = import_module_from_path(mod_path=_REPO_ROOT, mod_str="models.SIR")
HW = import_module_from_path(mod_path=_REPO_ROOT, mod_str="models.HarrisWilson")

# Resolve the YAML test configuration inside the ``tests`` package and load it
CFG_FILENAME = resource_filename("tests", "cfgs/test_data_loading.yml")
test_cfg = load_yml(CFG_FILENAME)

20 

21 

def test_data_loading():
    """Run each configured model, then run it again from its own output data.

    Every entry of ``test_cfg`` must contain a ``model`` key naming the model.
    That key is popped (this mutates ``test_cfg`` in place), and the remaining
    keys are forwarded to ``ModelTest.create_run_load``.

    For models that support loading data from file, the parameter space is
    then pointed at the ``data/uni0/data.h5`` file written by the first run,
    and the model is run a second time. Both runs must produce a truthy second
    element in the returned tuple.
    """
    # Keys of the ``load_from_dir`` mapping for the network-based models
    network_data_keys = ("network", "eigen_frequencies", "training_data")

    for config in test_cfg.values():
        # Get the model type
        model_name = config.pop("model")

        mtc = ModelTest(model_name)
        model = mtc.create_run_load(**config)

        # NOTE(review): presumably model == (Multiverse, DataManager) as
        # returned by utopya's create_run_load -- confirm.
        assert model[1]

        # Data file written by universe 0 of the run that just finished
        data_path = model[0]._dirs["run"] + "/data/uni0/data.h5"

        # Load the previously generated data and run again.
        # Compare the model *name*: ``model`` is the tuple returned above and
        # would never be found in these lists.
        # NOTE: ``update`` replaces the whole "Data" entry of the model's
        # parameter space, dropping any other keys it contained.
        if model_name in ("HarrisWilson", "SIR"):
            config["parameter_space"][model_name].update(
                {"Data": {"load_from_dir": data_path}}
            )
        elif model_name in ("Kuramoto", "HarrisWilsonNW"):
            # Build one mapping with every key. Calling ``update`` once per key
            # would overwrite "Data" each time, leaving only the last key.
            config["parameter_space"][model_name].update(
                {
                    "Data": {
                        "load_from_dir": {
                            key: data_path for key in network_data_keys
                        }
                    }
                }
            )

        model = mtc.create_run_load(**config)

        assert model[1]

        # Release the run objects before the next model is created
        del model