Coverage for include/utils.py: 93%

15 statements  

« prev     ^ index     » next       coverage.py v7.6.1, created at 2024-12-05 17:26 +0000

1from typing import Union 

2 

3import torch 

4 

5""" General utility functions """ 

6 

7 

def random_tensor(
    cfg: Union[dict, list], *, size: tuple = None, device: str = "cpu", **__
) -> torch.Tensor:
    """Generates a multi-dimensional random tensor. Each entry can be initialised separately, or a common
    initialisation configuration is used for each entry. For instance, the configuration

    .. code-block::

        cfg:
          distribution: uniform
          parameters:
            lower: 0
            upper: 1

    together with `size: (2, 2)` will initialise a 2x2 matrix with entries drawn from a uniform distribution on
    [0, 1]. The configuration

    .. code-block::

        cfg:
          - distribution: uniform
            parameters:
              lower: 0
              upper: 1
          - distribution: normal
            parameters:
              mean: 0.5
              std: 0.1

    will initialise a one-dimensional tensor of shape (2,) with entries drawn from different distributions
    (one entry per list element). ``size`` is ignored when the configuration is a list.

    :param cfg: the configuration entry containing the initialisation data
    :param size (optional): the size of the tensor, in case the configuration is not a list. Defaults to
        ``(1,)`` when omitted.
    :param device: the device onto which to load the data
    :param __: additional kwargs (ignored)
    :return: the tensor of random variables, with dtype ``torch.float``, located on ``device``
    :raises ValueError: if a distribution type is unrecognised, a required distribution parameter is missing,
        or a uniform lower bound exceeds its upper bound
    """

    def _require(parameters: dict, key: str, distribution: str):
        """Returns ``parameters[key]``, raising a descriptive error if it is absent or None."""
        value = (parameters or {}).get(key)
        if value is None:
            raise ValueError(
                f"Missing parameter '{key}' for distribution '{distribution}'!"
            )
        return value

    def _random_tensor_1d(
        *, distribution: str, parameters: dict, s: tuple = (1,), **__
    ) -> torch.Tensor:
        """Generates a random tensor according to a distribution.

        :param distribution: the type of distribution. Can be 'uniform' or 'normal'.
        :param parameters: the parameters relevant to the respective distribution
        :param s: the size of the random tensor
        :return: a float tensor of size ``s`` on ``device``
        """

        # Uniform distribution in an interval
        if distribution == "uniform":
            lower = _require(parameters, "lower", distribution)
            upper = _require(parameters, "upper", distribution)
            if lower > upper:
                raise ValueError(
                    f"Upper bound must be greater or equal to lower bound; got {lower} and {upper}!"
                )
            # Affine map of U[0, 1) samples onto the requested interval, generated directly on `device`.
            return torch.rand(s, dtype=torch.float, device=device) * (upper - lower) + lower

        # Normal distribution
        elif distribution == "normal":
            return torch.normal(
                _require(parameters, "mean", distribution),
                _require(parameters, "std", distribution),
                size=s,
                device=device,
                dtype=torch.float,
            )

        else:
            raise ValueError(f"Unrecognised distribution type {distribution}!")

    if isinstance(cfg, list):
        if not cfg:
            # torch.cat rejects an empty sequence; an empty list yields an empty float tensor.
            return torch.empty(0, dtype=torch.float, device=device)
        # Each entry produces a shape-(1,) tensor already on `device`; concatenating keeps the data there
        # instead of converting every sample to a Python float (and syncing with the device) one by one.
        return torch.cat([_random_tensor_1d(**entry) for entry in cfg])

    # Without an explicit size, fall back to a single sample rather than passing `None` to torch.
    return _random_tensor_1d(**cfg, s=(1,) if size is None else size)