Coverage for tests/core/test_utils.py: 95%
42 statements
« prev ^ index » next coverage.py v7.6.1, created at 2024-12-05 17:26 +0000
« prev ^ index » next coverage.py v7.6.1, created at 2024-12-05 17:26 +0000
1import sys
2from builtins import *
3from os.path import dirname as up
5import pytest
6import torch
7from dantro._import_tools import import_module_from_path
8from pkg_resources import resource_filename
10from utopya.yaml import load_yml
# Three dirname() calls on tests/core/test_utils.py give the directory that
# contains ``tests/`` (presumably the repository root). Put it on the import
# path, then load the project's ``include.utils`` module from it as ``utils``.
_REPO_ROOT = up(up(up(__file__)))
sys.path.insert(0, _REPO_ROOT)
utils = import_module_from_path(mod_path=_REPO_ROOT, mod_str="include.utils")

# Load the test config
CFG_FILENAME = resource_filename("tests", "cfgs/test_utils.yml")
test_cfg = load_yml(CFG_FILENAME)
def test_random_tensor():
    """Test ``utils.random_tensor`` on every case in the YAML test config.

    Each top-level entry of ``test_cfg`` is one case. Three optional control
    keys steer the expectation and are stripped before ``random_tensor`` is
    called:

    - ``_raises``: truthy if the call should raise. A string names the
      exception type, looked up in this module's globals (which contain the
      builtins via ``from builtins import *``); any other truthy value means
      ``Exception``.
    - ``_warns``: truthy if the call should warn. A string names the warning
      type in the same way; any other truthy value means ``UserWarning``.
    - ``_match``: regex passed to ``pytest.raises`` / ``pytest.warns``.
      Defaults to ``" "``, i.e. the message must contain a space.

    The tensor config passed to ``random_tensor`` is the case's ``cfg`` entry
    if it has one, otherwise the case itself (minus the control keys).

    For cases not expected to raise, tensors of size ``(1,)`` and
    ``(4, 4, 4)`` are generated, their shape/length is checked, and every
    element governed by a ``"uniform"`` entry must lie in ``[lower, upper]``.

    NOTE(review): a case with both ``_raises`` and ``_warns`` set matches none
    of the final branches, so its expected failure is not checked.
    """

    def _check_entry(entry_cfg, value):
        """Assert ``value`` is consistent with the single entry ``entry_cfg``.

        Only ``"uniform"`` entries are checked (``lower <= value <= upper``);
        other distributions are accepted without a value check.
        """
        if entry_cfg["distribution"] == "uniform":
            params = entry_cfg["parameters"]
            assert params["lower"] <= value <= params["upper"]

    for case in test_cfg.values():
        # Pop the control keys from a shallow copy rather than from the
        # shared module-level ``test_cfg``: mutating it made the test
        # unrepeatable in one process (a second run would no longer see
        # ``_raises`` and would push failing configs down the success path).
        config = dict(case)
        raises = config.pop("_raises", False)
        exp_exc = globals()[raises] if isinstance(raises, str) else Exception
        warns = config.pop("_warns", False)
        exp_warning = globals()[warns] if isinstance(warns, str) else UserWarning
        match = config.pop("_match", " ")

        cfg = config["cfg"] if "cfg" in config else config

        if not raises:
            for size in [(1,), (4, 4, 4)]:
                t = utils.random_tensor(cfg, size=size)

                if isinstance(cfg, list):
                    # List configs: the first dimension must equal the
                    # number of entries.
                    assert len(t) == len(cfg)
                else:
                    assert t.shape == torch.Size(size)

                # A dict config governs every element; a list config is
                # indexed element by element of the flattened tensor.
                for idx, value in enumerate(torch.flatten(t)):
                    _check_entry(cfg if isinstance(cfg, dict) else cfg[idx], value)

        if not raises and not warns:
            # Plain smoke call; the success path was already checked above.
            utils.random_tensor(cfg, size=(1,))

        elif warns and not raises:
            with pytest.warns(exp_warning, match=match):
                utils.random_tensor(cfg, size=(1,))

        elif raises and not warns:
            with pytest.raises(exp_exc, match=match):
                utils.random_tensor(cfg, size=(1,))