Coverage for include / neural_net.py: 80%
112 statements
« prev ^ index » next coverage.py v7.13.1, created at 2026-01-21 16:26 +0000
1from typing import Any, List, Sequence, Union
3import torch
4from torch import nn
6from .utils import random_tensor
8# ----------------------------------------------------------------------------------------------------------------------
9# -- NN utility functions ----------------------------------------------------------------------------------------------
10# ----------------------------------------------------------------------------------------------------------------------
def sigmoid(beta=torch.tensor(1.0)):
    """Extends the torch.nn.sigmoid activation function by allowing for a slope parameter.

    :param beta: slope of the sigmoid; the returned callable computes sigmoid(beta * x)
    :return: a callable applying the scaled sigmoid elementwise
    """

    def _scaled_sigmoid(x):
        return torch.sigmoid(beta * x)

    return _scaled_sigmoid
# Pytorch activation functions.
# Maps a lowercase name to a pair [func, is_module]: `is_module` is True when the entry
# is a torch.nn module class (or a factory such as `sigmoid` above) that must first be
# instantiated via func(*args, **kwargs) before being applied to a tensor, and False
# when the entry is a plain elementwise function applied directly as func(x).
# The "linear" entry maps to None, which callers interpret as the identity (no activation).
ACTIVATION_FUNCS = {
    "abs": [torch.abs, False],
    "celu": [torch.nn.CELU, True],
    "cos": [torch.cos, False],
    "cosine": [torch.cos, False],
    "elu": [torch.nn.ELU, True],
    "gelu": [torch.nn.GELU, True],
    "hardshrink": [torch.nn.Hardshrink, True],
    "hardsigmoid": [torch.nn.Hardsigmoid, True],
    "hardswish": [torch.nn.Hardswish, True],
    "hardtanh": [torch.nn.Hardtanh, True],
    "leakyrelu": [torch.nn.LeakyReLU, True],
    "linear": [None, False],
    "logsigmoid": [torch.nn.LogSigmoid, True],
    "mish": [torch.nn.Mish, True],
    "prelu": [torch.nn.PReLU, True],
    "relu": [torch.nn.ReLU, True],
    "rrelu": [torch.nn.RReLU, True],
    "selu": [torch.nn.SELU, True],
    "sigmoid": [sigmoid, True],
    "silu": [torch.nn.SiLU, True],
    "sin": [torch.sin, False],
    "sine": [torch.sin, False],
    "softplus": [torch.nn.Softplus, True],
    "softshrink": [torch.nn.Softshrink, True],
    "softsign": [torch.nn.Softsign, True],
    "swish": [torch.nn.SiLU, True],
    "tanh": [torch.nn.Tanh, True],
    "tanhshrink": [torch.nn.Tanhshrink, True],
    "threshold": [torch.nn.Threshold, True],
}
# Supported torch optimizers, resolved by name from the torch.optim namespace.
OPTIMIZERS = {
    name: getattr(torch.optim, name)
    for name in (
        "Adagrad",
        "Adam",
        "AdamW",
        "SparseAdam",
        "Adamax",
        "ASGD",
        "LBFGS",
        "NAdam",
        "RAdam",
        "RMSprop",
        "Rprop",
        "SGD",
    )
}
def get_architecture(
    input_size: int, output_size: int, n_layers: int, cfg: dict
) -> List[int]:
    """Assemble the per-layer node counts [input, hidden ..., output] from the config.

    :param input_size: number of input nodes
    :param output_size: number of output nodes
    :param n_layers: number of hidden layers
    :param cfg: dict with a 'default' hidden-layer size and an optional
        'layer_specific' mapping of layer index -> size overriding the default
    :return: list of layer sizes, including the input and output layers
    """
    # Start every hidden layer at the default size
    hidden = [cfg.get("default") for _ in range(n_layers)]

    # Apply layer-specific overrides
    for idx, size in cfg.get("layer_specific", {}).items():
        hidden[idx] = size

    return [input_size, *hidden, output_size]
def get_single_layer_func(layer_cfg: Union[str, dict]) -> callable:
    """Return the activation function from an entry for a single layer.

    :param layer_cfg: either the activation name (str, case-insensitive), a dict
        with a 'name' key and optional 'args'/'kwargs' entries, or None for a
        linear (identity) layer
    :return: the activation callable, or None for a linear layer
    :raises KeyError: if the activation name is not in ACTIVATION_FUNCS
    :raises ValueError: if layer_cfg is of an unrecognized type
    """
    # Entry is a single string: look up by name; instantiate torch.nn modules
    if isinstance(layer_cfg, str):
        func, is_module = ACTIVATION_FUNCS[layer_cfg.lower()]
        return func() if is_module else func

    # Entry is a dictionary containing args and kwargs
    elif isinstance(layer_cfg, dict):
        # Use indexing (not .get) so a missing 'name' raises a clear KeyError
        # instead of an AttributeError on None
        func, is_module = ACTIVATION_FUNCS[layer_cfg["name"].lower()]
        if is_module:
            return func(*layer_cfg.get("args", ()), **layer_cfg.get("kwargs", {}))
        return func

    # 'None' entries are interpreted as linear layers
    elif layer_cfg is None:
        # BUG FIX: the original assigned ACTIVATION_FUNCS["linear"][0] to a local
        # but never returned it (the function returned None only by accident,
        # since the "linear" entry happens to be None). Return it explicitly.
        return ACTIVATION_FUNCS["linear"][0]

    else:
        raise ValueError(f"Unrecognized activation function {layer_cfg}!")
def get_activation_funcs(n_layers: int, cfg: dict) -> List[callable]:
    """Extracts the activation functions from the config. The config is a dictionary containing the
    default activation function, and a layer-specific entry detailing exceptions from the default. 'None' entries
    are interpreted as linear layers.

    .. Example:
        activation_funcs:
            default: relu
            layer_specific:
                0: ~
                2: tanh
                3:
                    name: HardTanh
                    args:
                        - -2  # min_value
                        - +2  # max_value
    """
    # One function per layer transition (n hidden layers + the output layer),
    # all initialised to the default
    default_func = get_single_layer_func(cfg.get("default"))
    funcs = [default_func] * (n_layers + 1)

    # Override the default on explicitly configured layers
    for idx, entry in cfg.get("layer_specific", {}).items():
        funcs[idx] = get_single_layer_func(entry)

    return funcs
def get_bias(n_layers: int, cfg: dict) -> List[Any]:
    """Extracts the bias initialisation settings from the config. The config is a dictionary containing the
    default, and a layer-specific entry detailing exceptions from the default. 'None' entries
    are interpreted as unbiased layers.

    .. Example:
        biases:
            default: ~
            layer_specific:
                0: [-1, 1]
                3: [2, 3]
    """
    # One entry per layer transition (n hidden layers + the output layer),
    # all initialised to the default
    layer_biases = [cfg.get("default") for _ in range(n_layers + 1)]

    # Override the default on explicitly configured layers
    for idx, interval in cfg.get("layer_specific", {}).items():
        layer_biases[idx] = interval

    return layer_biases
163# -----------------------------------------------------------------------------
164# -- Neural net class ---------------------------------------------------------
165# -----------------------------------------------------------------------------
class BaseNN(nn.Module):
    """Base neural network architecture class: a stack of linear layers built from a
    config of layer sizes, activation functions, and bias initialisation settings."""

    def __init__(
        self,
        *,
        input_size: int,
        output_size: int,
        num_layers: int,
        nodes_per_layer: dict,
        activation_funcs: dict,
        biases: dict,
        optimizer: str = "Adam",
        learning_rate: float = 0.002,
        optimizer_kwargs: dict = None,
        **__,
    ):
        """Base neural network architecture class.

        :param input_size: the number of input values
        :param output_size: the number of output values
        :param num_layers: the number of hidden layers
        :param nodes_per_layer: a dictionary specifying the number of nodes per layer
        :param activation_funcs: a dictionary specifying the activation functions to use
        :param biases: a dictionary containing the initialisation parameters for the bias
        :param optimizer: the name of the optimizer to use. Default is the torch.optim.Adam optimizer.
        :param learning_rate: the learning rate of the optimizer. Default is 0.002.
        :param optimizer_kwargs: additional keyword arguments passed to the optimizer.
            Default is an empty dict. (BUG FIX: was a shared mutable default ``{}``.)
        :param __: Additional model parameters (ignored)
        """
        super().__init__()
        self.flatten = nn.Flatten()

        self.input_dim = input_size
        self.output_dim = output_size
        # NOTE(review): despite the name, this stores the *number* of hidden
        # layers, not a layer width; kept for backward compatibility.
        self.hidden_dim = num_layers

        # Get architecture, activation functions, and layer bias
        self.architecture = get_architecture(
            input_size, output_size, num_layers, nodes_per_layer
        )
        self.activation_funcs = get_activation_funcs(num_layers, activation_funcs)
        self.bias = get_bias(num_layers, biases)

        # Add the neural net layers
        self.layers = nn.ModuleList()
        for i in range(len(self.architecture) - 1):
            layer = nn.Linear(
                self.architecture[i],
                self.architecture[i + 1],
                bias=self.bias[i] is not None,
            )

            # Initialise the biases of the layers with a uniform distribution
            if self.bias[i] is not None:
                # Use the pytorch default interval if indicated
                if self.bias[i] == "default":
                    torch.nn.init.uniform_(layer.bias)
                # Initialise the bias on explicitly provided intervals
                else:
                    torch.nn.init.uniform_(layer.bias, self.bias[i][0], self.bias[i][1])

            self.layers.append(layer)

        # Get the optimizer; BUG FIX: `optimizer_kwargs` no longer uses a mutable
        # default argument (a dict default is shared across all instances)
        self.optimizer = OPTIMIZERS[optimizer](
            self.parameters(),
            lr=learning_rate,
            **(optimizer_kwargs if optimizer_kwargs is not None else {}),
        )
class FeedForwardNN(BaseNN):
    """Standard feed-forward neural network, optionally initialised so that its
    initial output follows a given prior distribution."""

    def __init__(
        self,
        *,
        input_size: int,
        output_size: int,
        num_layers: int,
        nodes_per_layer: dict,
        activation_funcs: dict,
        biases: dict,
        prior: Union[list, dict] = None,
        prior_max_iter: int = 500,
        prior_tol: float = 1e-5,
        optimizer: str = "Adam",
        learning_rate: float = 0.002,
        optimizer_kwargs: dict = None,
        **__,
    ):
        """Standard feed-forward architecture neural network class.

        :param input_size: the number of input values
        :param output_size: the number of output values
        :param num_layers: the number of hidden layers
        :param nodes_per_layer: a dictionary specifying the number of nodes per layer
        :param activation_funcs: a dictionary specifying the activation functions to use
        :param biases: a dictionary containing the initialisation parameters for the bias
        :param prior (optional): initial prior distribution of the parameters. If given, the neural net will
            initially output a random value within that distribution.
        :param prior_tol (optional): the tolerance with which the prior distribution should be met
        :param prior_max_iter (optional): maximum number of training iterations to hit the prior target
        :param optimizer: the name of the optimizer to use. Default is the torch.optim.Adam optimizer.
        :param learning_rate: the learning rate of the optimizer. Default is 0.002.
        :param optimizer_kwargs: additional keyword arguments passed to the optimizer.
            Default is an empty dict. (BUG FIX: was a shared mutable default ``{}``.)
        :param __: Additional model parameters (ignored)
        """
        super().__init__(
            input_size=input_size,
            output_size=output_size,
            num_layers=num_layers,
            nodes_per_layer=nodes_per_layer,
            activation_funcs=activation_funcs,
            biases=biases,
            optimizer=optimizer,
            learning_rate=learning_rate,
            optimizer_kwargs=optimizer_kwargs if optimizer_kwargs is not None else {},
        )

        # Get the initial distribution and initialise
        self.prior_distribution = prior
        self.initialise_to_prior(tol=prior_tol, max_iter=prior_max_iter)

    def initialise_to_prior(self, *, tol: float = 1e-5, max_iter: int = 500) -> None:
        """Initialises the neural net to output values following a prior distribution. The random tensor is drawn
        following a prior distribution and the neural network trained to output that value. Training is performed
        until the neural network output matches the drawn value (which typically only takes a few seconds), or until
        a maximum iteration count is reached.

        :param tol: the target error on the neural net initial output and drawn value.
        :param max_iter: maximum number of training steps to perform in the while loop
        """
        # If no initial distribution is given, nothing happens
        if self.prior_distribution is None:
            return

        # Draw a target tensor following the given prior distribution
        target = random_tensor(self.prior_distribution, size=(self.output_dim,))

        # Generate a prediction and train the net to output the given target
        prediction = self.forward(torch.rand(self.input_dim))
        n_iter = 0  # BUG FIX: renamed from `iter`, which shadowed the builtin

        # Use a separate optimizer for the training so the model's own optimizer
        # state is untouched by this initialisation
        optim = torch.optim.Adam(self.parameters(), lr=0.002)
        while torch.norm(prediction - target) > tol and n_iter < max_iter:
            prediction = self.forward(torch.rand(self.input_dim))
            # Arguments in the conventional (input, target) order; the loss value
            # is unchanged since MSE is symmetric
            loss = torch.nn.functional.mse_loss(prediction, target, reduction="sum")
            loss.backward()
            optim.step()
            optim.zero_grad()
            n_iter += 1

    # ... Evaluation functions .........................................................................................
    def forward(self, x):
        """The model forward pass: apply each layer followed by its activation
        function (None entries are linear layers)."""
        for i in range(len(self.layers)):
            if self.activation_funcs[i] is None:
                x = self.layers[i](x)
            else:
                x = self.activation_funcs[i](self.layers[i](x))
        return x
class RNN(BaseNN):
    """Vanilla recurrent neural network: the latent state is concatenated to the
    input and split off again from the output at every step."""

    def __init__(
        self,
        *,
        input_size: int,
        output_size: int,
        latent_dim: int,
        num_layers: int,
        nodes_per_layer: dict,
        activation_funcs: dict,
        biases: dict,
        latent_activation_func: Union[str, dict] = 'tanh',
        initial_latent_state: torch.Tensor = None,
        optimizer: str = "Adam",
        learning_rate: float = 0.002,
        optimizer_kwargs: dict = None,
        **__,
    ):
        """Vanilla recurrent neural network with a z-dimensional latent dimension.

        :param input_size: the number of input values
        :param output_size: the number of output values
        :param latent_dim: latent dimension
        :param num_layers: the number of hidden layers
        :param nodes_per_layer: a dictionary specifying the number of nodes per layer
        :param activation_funcs: a dictionary specifying the activation functions to use
        :param biases: a dictionary containing the initialisation parameters for the bias
        :param latent_activation_func: a dictionary specifying the activation function to use on the latent state.
            Default is hyperbolic tangent.
        :param initial_latent_state: initial value of the latent state; defaults to zeros
        :param optimizer: the name of the optimizer to use. Default is the torch.optim.Adam optimizer.
        :param learning_rate: the learning rate of the optimizer. Default is 0.002.
        :param optimizer_kwargs: additional keyword arguments passed to the optimizer.
            Default is an empty dict. (BUG FIX: was a shared mutable default ``{}``.)
        :param __: Additional model parameters (ignored)
        """
        super().__init__(
            input_size=input_size + latent_dim,
            output_size=output_size + latent_dim,
            num_layers=num_layers,
            nodes_per_layer=nodes_per_layer,
            activation_funcs=activation_funcs,
            biases=biases,
            optimizer=optimizer,
            learning_rate=learning_rate,
            optimizer_kwargs=optimizer_kwargs if optimizer_kwargs is not None else {},
        )
        self.latent_dim = latent_dim
        self.z = initial_latent_state if initial_latent_state is not None else torch.zeros(latent_dim)
        self.z0 = self.z.clone()

        # Activation function to use on the latent slice of the last layer's output
        self.latent_activation_func = get_single_layer_func(latent_activation_func)
        if self.latent_activation_func is not None:
            output_func = self.activation_funcs[-1]
            if output_func is None:
                # BUG FIX: the original captured the last activation and called it
                # unconditionally, crashing at forward time whenever the output
                # activation is linear (None). Treat None as the identity here.
                self.activation_funcs[-1] = lambda x: torch.cat(
                    (x[: -self.latent_dim], self.latent_activation_func(x[-self.latent_dim:]))
                )
            else:
                self.activation_funcs[-1] = lambda x: torch.cat(
                    (output_func(x[: -self.latent_dim]), self.latent_activation_func(x[-self.latent_dim:]))
                )

    # ... Evaluation functions .........................................................................................
    def forward(self, x, z=None):
        """The model forward pass.

        :param x: input tensor; a 2D input is treated as a sequence of 1D steps,
            with the latent state carried forward between steps
        :param z: optional latent state to use instead of the stored one (1D case only)
        """
        # 2D case: recursively apply the 1D case row by row; self.z is updated by
        # each step, so the latent state propagates through the sequence
        if x.dim() == 2:
            return torch.stack([self.forward(x[i]) for i in range(len(x))])

        # 1D case: append the latent state to the input
        x = torch.cat([x, self.z if z is None else z])
        for i in range(len(self.layers)):
            if self.activation_funcs[i] is None:
                x = self.layers[i](x)
            else:
                x = self.activation_funcs[i](self.layers[i](x))
        # Split off and store the new latent state; return the visible output
        self.z = x[-self.latent_dim:]
        return x[:-self.latent_dim]

    def reset_hidden_state(self, z=None):
        """Reset the latent state to the stored initial state, or to the given tensor."""
        self.z = self.z0.clone() if z is None else z.clone()
class GRU:
    """Gated recurrent unit network — placeholder, not yet implemented."""
    pass
class LSTM:
    """Long short-term memory network — placeholder, not yet implemented."""
    pass