Coverage for tests/core/test_utils.py: 95%

42 statements  

« prev     ^ index     » next       coverage.py v7.6.1, created at 2024-12-05 17:26 +0000

1import sys 

2from builtins import * 

3from os.path import dirname as up 

4 

5import pytest 

6import torch 

7from dantro._import_tools import import_module_from_path 

8from pkg_resources import resource_filename 

9 

10from utopya.yaml import load_yml 

11 

# Directory three levels above this file, i.e. the repository root when this
# module lives at tests/core/test_utils.py.
_REPO_ROOT = up(up(up(__file__)))

# Put the root first on the import path, then load ``include.utils`` from it.
sys.path.insert(0, _REPO_ROOT)
utils = import_module_from_path(mod_path=_REPO_ROOT, mod_str="include.utils")

# Test cases for ``utils.random_tensor``, read from the YAML file that ships
# inside the ``tests`` package.
CFG_FILENAME = resource_filename("tests", "cfgs/test_utils.yml")
test_cfg = load_yml(CFG_FILENAME)


def test_random_tensor():
    """Check ``utils.random_tensor`` against every case in ``test_cfg``.

    Each top-level entry of the YAML test config is one case. The optional
    control keys ``_raises``, ``_warns`` and ``_match`` select the expected
    outcome. The remaining mapping (or its ``cfg`` sub-key, if present) is
    passed to ``utils.random_tensor``.

    ``_raises`` / ``_warns`` may be ``True`` (expect a generic ``Exception`` /
    ``UserWarning``) or the name of an exception / warning class. ``_match``
    is the regex passed to ``pytest.raises`` / ``pytest.warns``; it defaults
    to ``" "``.
    """
    import builtins

    def _resolve_type(spec, default):
        # A string names a class visible in this module's globals, falling
        # back to ``builtins``; any non-string value selects ``default``.
        if not isinstance(spec, str):
            return default
        module_globals = globals()
        if spec in module_globals:
            return module_globals[spec]
        return getattr(builtins, spec)

    def _test_entry(cfg, tensor):
        # Only uniform distributions have bounds this test can verify.
        if cfg["distribution"] == "uniform":
            assert cfg["parameters"]["lower"] <= tensor <= cfg["parameters"]["upper"]

    for config in test_cfg.values():
        # Work on a shallow copy. Popping the control keys from the shared
        # module-level ``test_cfg`` would strip them for any later use (e.g. a
        # repeated run), silently turning error cases into success cases.
        config = dict(config)
        _raises = config.pop("_raises", False)
        _exp_exc = _resolve_type(_raises, Exception)
        _warns = config.pop("_warns", False)
        _exp_warning = _resolve_type(_warns, UserWarning)
        _match = config.pop("_match", " ")

        cfg = config.get("cfg", config)

        if not _raises:
            for size in [(1,), (4, 4, 4)]:
                t = utils.random_tensor(cfg, size=size)

                # A list config is expected to yield one entry per list item;
                # a single config is expected to honour the requested size.
                if isinstance(cfg, list):
                    assert len(t) == len(cfg)
                else:
                    assert t.shape == torch.Size(size)

                for idx, entry in enumerate(torch.flatten(t)):
                    _test_entry(cfg if isinstance(cfg, dict) else cfg[idx], entry)

        if not _raises and not _warns:
            utils.random_tensor(cfg, size=(1,))

        elif _warns and not _raises:
            with pytest.warns(_exp_warning, match=_match):
                utils.random_tensor(cfg, size=(1,))

        elif _raises and not _warns:
            with pytest.raises(_exp_exc, match=_match):
                utils.random_tensor(cfg, size=(1,))

        # NOTE(review): a case that sets both ``_raises`` and ``_warns`` falls
        # through every branch above and is currently not checked at all --
        # confirm whether such cases exist in the config.