Coverage for include/utils.py: 93%
15 statements
« prev ^ index » next coverage.py v7.6.1, created at 2024-12-05 17:26 +0000
1from typing import Union
3import torch
5""" General utility functions """
def random_tensor(
    cfg: Union[dict, list],
    *,
    size: Union[tuple, None] = None,
    device: str = "cpu",
    **__,
) -> torch.Tensor:
    """Generates a multi-dimensional random tensor. Each entry can be initialised separately, or a common
    initialisation configuration is used for each entry. For instance, the configuration

    .. code-block::

        cfg:
          distribution: uniform
          parameters:
            lower: 0
            upper: 1

    together with `size: (2, 2)` will initialise a 2x2 matrix with entries drawn from a uniform distribution on
    [0, 1]. The configuration

    .. code-block::

        cfg:
          - distribution: uniform
            parameters:
              lower: 0
              upper: 1
          - distribution: normal
            parameters:
              mean: 0.5
              std: 0.1

    will initialise a one-dimensional tensor of length 2 (shape ``(2,)``) whose entries are drawn from
    different distributions.

    :param cfg: the configuration entry containing the initialisation data: either a single dict with the
        keys ``distribution`` and ``parameters`` (shared by every entry), or a list of such dicts (one per entry)
    :param size (optional): the size of the tensor, in case the configuration is not a list. If ``None``
        (the default), a tensor of shape ``(1,)`` is generated. Ignored for list configurations.
    :param device: the device onto which to load the data
    :param __: additional kwargs (ignored)
    :return: the float tensor of random variables
    :raises ValueError: if a distribution type is not recognised, or if the lower bound of a uniform
        distribution exceeds its upper bound
    """

    def _random_tensor_1d(
        *, distribution: str, parameters: dict, s: tuple = (1,), **__
    ) -> torch.Tensor:
        """Generates a random float tensor of size ``s`` on ``device`` according to one distribution.

        :param distribution: the type of distribution. Can be 'uniform' or 'normal'.
        :param parameters: the parameters relevant to the respective distribution
            ('lower' and 'upper' for 'uniform'; 'mean' and 'std' for 'normal')
        :param s: the size of the random tensor
        :param __: additional kwargs (ignored)
        :return: the tensor of samples
        :raises ValueError: for an unrecognised distribution, or if 'lower' > 'upper'
        """

        # Uniform distribution in an interval
        if distribution == "uniform":
            l, u = parameters.get("lower"), parameters.get("upper")
            if l > u:
                raise ValueError(
                    f"Upper bound must be greater or equal to lower bound; got {l} and {u}!"
                )

            # Affine map of U[0, 1) samples onto [l, u). The Python scalars are applied
            # directly to the on-device tensor; no intermediate CPU tensors are needed.
            return (u - l) * torch.rand(s, dtype=torch.float, device=device) + l

        # Normal distribution
        elif distribution == "normal":
            return torch.normal(
                parameters.get("mean"),
                parameters.get("std"),
                size=s,
                device=device,
                dtype=torch.float,
            )

        else:
            raise ValueError(f"Unrecognised distribution type {distribution}!")

    if isinstance(cfg, list):
        # torch.cat raises on an empty sequence, so return an empty float tensor explicitly.
        if not cfg:
            return torch.empty((0,), dtype=torch.float, device=device)

        # Each entry yields a (1,)-shaped tensor already on `device`. Concatenating on-device
        # gives shape (len(cfg),) and avoids converting every sample to a Python scalar and
        # copying the result back to the device.
        return torch.cat([_random_tensor_1d(**entry) for entry in cfg])

    # Without an explicit size, fall back to the helper's default shape (1,). Passing
    # `s=None` through would override that default and make torch reject the size.
    return _random_tensor_1d(**cfg, s=(1,) if size is None else size)