Coverage for include/neural_net.py: 98%

81 statements  

coverage.py v7.6.1, created at 2024-12-05 17:26 +0000

from typing import Any, Callable, List, Sequence, Union

import torch
from torch import nn

from .utils import random_tensor

# ----------------------------------------------------------------------------------------------------------------------
# -- NN utility functions ----------------------------------------------------------------------------------------------
# ----------------------------------------------------------------------------------------------------------------------


def sigmoid(beta=torch.tensor(1.0)):
    """Extends the torch.sigmoid activation function by allowing for a slope parameter ``beta``."""

    return lambda x: torch.sigmoid(beta * x)
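# Usage sketch (not part of the original module): sigmoid(beta) returns a callable, so it is
# applied as sigmoid(beta)(x); with the default beta=1 it reduces to torch.sigmoid.
#
#   f = sigmoid(torch.tensor(2.0))
#   f(torch.tensor(0.0))   # tensor(0.5000), since sigmoid(2 * 0) = 0.5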

# Pytorch activation functions.
# Pairs of activation functions and whether they are part of the torch.nn module, in which case they must be called
# via func(*args, **kwargs)(x).


ACTIVATION_FUNCS = {
    "abs": [torch.abs, False],
    "celu": [torch.nn.CELU, True],
    "cos": [torch.cos, False],
    "cosine": [torch.cos, False],
    "elu": [torch.nn.ELU, True],
    "gelu": [torch.nn.GELU, True],
    "hardshrink": [torch.nn.Hardshrink, True],
    "hardsigmoid": [torch.nn.Hardsigmoid, True],
    "hardswish": [torch.nn.Hardswish, True],
    "hardtanh": [torch.nn.Hardtanh, True],
    "leakyrelu": [torch.nn.LeakyReLU, True],
    "linear": [None, False],
    "logsigmoid": [torch.nn.LogSigmoid, True],
    "mish": [torch.nn.Mish, True],
    "prelu": [torch.nn.PReLU, True],
    "relu": [torch.nn.ReLU, True],
    "rrelu": [torch.nn.RReLU, True],
    "selu": [torch.nn.SELU, True],
    "sigmoid": [sigmoid, True],
    "silu": [torch.nn.SiLU, True],
    "sin": [torch.sin, False],
    "sine": [torch.sin, False],
    "softplus": [torch.nn.Softplus, True],
    "softshrink": [torch.nn.Softshrink, True],
    "softsign": [torch.nn.Softsign, True],
    "swish": [torch.nn.SiLU, True],
    "tanh": [torch.nn.Tanh, True],
    "tanhshrink": [torch.nn.Tanhshrink, True],
    "threshold": [torch.nn.Threshold, True],
}
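# Resolution sketch (illustrative only; mirrors the lookup done in get_activation_funcs below):
# entries flagged True are torch.nn classes that must be instantiated, entries flagged False
# are plain functions used directly.
#
#   name, args = "hardtanh", (-2, 2)
#   cls, is_module = ACTIVATION_FUNCS[name]     # torch.nn.Hardtanh, True
#   act = cls(*args) if is_module else cls      # Hardtanh(min_val=-2, max_val=2)
#   y = act(torch.tensor([-5.0, 0.5, 5.0]))     # tensor([-2.0000, 0.5000, 2.0000])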

def get_architecture(
    input_size: int, output_size: int, n_layers: int, cfg: dict
) -> List[int]:
    """Builds the list of layer sizes: input size, n_layers hidden layers, output size.

    The config contains a 'default' hidden-layer width and an optional 'layer_specific' entry
    detailing exceptions from the default.
    """
    # Apply default to all hidden layers
    _nodes = [cfg.get("default")] * n_layers

    # Update layer-specific settings
    _layer_specific = cfg.get("layer_specific", {})
    for layer_id, layer_size in _layer_specific.items():
        _nodes[layer_id] = layer_size

    return [input_size] + _nodes + [output_size]
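# Example (hypothetical config, illustrative only):
#
#   get_architecture(3, 2, 2, {"default": 16, "layer_specific": {1: 8}})
#   # -> [3, 16, 8, 2]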

def get_activation_funcs(n_layers: int, cfg: dict) -> List[Callable]:
    """Extracts the activation functions from the config. The config is a dictionary containing the
    default activation function, and a layer-specific entry detailing exceptions from the default. 'None' entries
    are interpreted as linear layers.

    .. Example:
        activation_funcs:
            default: relu
            layer_specific:
                0: ~
                2: tanh
                3:
                    name: HardTanh
                    args:
                        - -2  # min_value
                        - +2  # max_value
    """

    def _single_layer_func(layer_cfg: Union[str, dict]) -> Callable:
        """Return the activation function from an entry for a single layer."""

        # Entry is a single string
        if isinstance(layer_cfg, str):
            _f = ACTIVATION_FUNCS[layer_cfg.lower()]
            if _f[1]:
                return _f[0]()
            else:
                return _f[0]

        # Entry is a dictionary containing args and kwargs
        elif isinstance(layer_cfg, dict):
            _f = ACTIVATION_FUNCS[layer_cfg.get("name").lower()]
            if _f[1]:
                return _f[0](*layer_cfg.get("args", ()), **layer_cfg.get("kwargs", {}))
            else:
                return _f[0]

        # 'None' entries mean a linear layer, i.e. no activation function
        elif layer_cfg is None:
            return ACTIVATION_FUNCS["linear"][0]

        else:
            raise ValueError(f"Unrecognized activation function {layer_cfg}!")

    # Use default activation function on all layers
    _funcs = [_single_layer_func(cfg.get("default"))] * (n_layers + 1)

    # Change activation functions on specified layers
    _layer_specific = cfg.get("layer_specific", {})
    for layer_id, layer_cfg in _layer_specific.items():
        _funcs[layer_id] = _single_layer_func(layer_cfg)

    return _funcs
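# Example (hypothetical config, illustrative only): with n_layers=2 hidden layers the returned
# list has n_layers + 1 = 3 entries, one per nn.Linear layer, and None marks a linear layer.
#
#   get_activation_funcs(2, {"default": "relu", "layer_specific": {2: None}})
#   # -> [ReLU(), ReLU(), None]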

def get_bias(n_layers: int, cfg: dict) -> List[Any]:
    """Extracts the bias initialisation settings from the config. The config is a dictionary containing the
    default, and a layer-specific entry detailing exceptions from the default. 'None' entries
    are interpreted as unbiased layers.

    .. Example:
        biases:
            default: ~
            layer_specific:
                0: [-1, 1]
                3: [2, 3]
    """

    # Use the default value on all layers
    biases = [cfg.get("default")] * (n_layers + 1)

    # Amend bias on specified layers
    _layer_specific = cfg.get("layer_specific", {})
    for layer_id, layer_bias in _layer_specific.items():
        biases[layer_id] = layer_bias

    return biases
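# Example (hypothetical config, illustrative only): 'default' selects PyTorch's default uniform
# initialisation, an interval [a, b] means uniform initialisation on that interval, and None
# disables the bias for that layer.
#
#   get_bias(2, {"default": "default", "layer_specific": {0: [-1, 1], 2: None}})
#   # -> [[-1, 1], "default", None]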

# -----------------------------------------------------------------------------
# -- Neural net class ---------------------------------------------------------
# -----------------------------------------------------------------------------


class NeuralNet(nn.Module):
    OPTIMIZERS = {
        "Adagrad": torch.optim.Adagrad,
        "Adam": torch.optim.Adam,
        "AdamW": torch.optim.AdamW,
        "SparseAdam": torch.optim.SparseAdam,
        "Adamax": torch.optim.Adamax,
        "ASGD": torch.optim.ASGD,
        "LBFGS": torch.optim.LBFGS,
        "NAdam": torch.optim.NAdam,
        "RAdam": torch.optim.RAdam,
        "RMSprop": torch.optim.RMSprop,
        "Rprop": torch.optim.Rprop,
        "SGD": torch.optim.SGD,
    }

    def __init__(
        self,
        *,
        input_size: int,
        output_size: int,
        num_layers: int,
        nodes_per_layer: dict,
        activation_funcs: dict,
        biases: dict,
        prior: Union[list, dict] = None,
        prior_max_iter: int = 500,
        prior_tol: float = 1e-5,
        optimizer: str = "Adam",
        learning_rate: float = 0.002,
        optimizer_kwargs: dict = {},
        **__,
    ):
        """
        :param input_size: the number of input values
        :param output_size: the number of output values
        :param num_layers: the number of hidden layers
        :param nodes_per_layer: a dictionary specifying the number of nodes per layer
        :param activation_funcs: a dictionary specifying the activation functions to use
        :param biases: a dictionary containing the initialisation parameters for the bias
        :param prior (optional): initial prior distribution of the parameters. If given, the neural net will
            initially output a random value within that distribution.
        :param prior_tol (optional): the tolerance with which the prior distribution should be met
        :param prior_max_iter (optional): maximum number of training iterations to hit the prior target
        :param optimizer: the name of the optimizer to use. Default is the torch.optim.Adam optimizer.
        :param learning_rate: the learning rate of the optimizer. Default is 0.002.
        :param optimizer_kwargs: additional keyword arguments passed to the optimizer
        :param __: Additional model parameters (ignored)
        """

        super().__init__()
        self.flatten = nn.Flatten()

        self.input_dim = input_size
        self.output_dim = output_size
        self.hidden_dim = num_layers

        # Get architecture, activation functions, and layer bias
        self.architecture = get_architecture(
            input_size, output_size, num_layers, nodes_per_layer
        )
        self.activation_funcs = get_activation_funcs(num_layers, activation_funcs)
        self.bias = get_bias(num_layers, biases)

        # Add the neural net layers
        self.layers = nn.ModuleList()
        for i in range(len(self.architecture) - 1):
            layer = nn.Linear(
                self.architecture[i],
                self.architecture[i + 1],
                bias=self.bias[i] is not None,
            )

            # Initialise the biases of the layers with a uniform distribution
            if self.bias[i] is not None:
                # Use the pytorch default if indicated
                if self.bias[i] == "default":
                    torch.nn.init.uniform_(layer.bias)
                # Initialise the bias on explicitly provided intervals
                else:
                    torch.nn.init.uniform_(layer.bias, self.bias[i][0], self.bias[i][1])

            self.layers.append(layer)

        # Get the optimizer
        self.optimizer = self.OPTIMIZERS[optimizer](
            self.parameters(), lr=learning_rate, **optimizer_kwargs
        )

        # Get the initial distribution and initialise
        self.prior_distribution = prior
        self.initialise_to_prior(tol=prior_tol, max_iter=prior_max_iter)

    def initialise_to_prior(self, *, tol: float = 1e-5, max_iter: int = 500) -> None:
        """Initialises the neural net to output values following a prior distribution. A random target tensor is
        drawn from the prior distribution and the neural network is trained to output that value. Training is
        performed until the neural network output matches the drawn value (which typically only takes a few
        seconds), or until a maximum iteration count is reached.

        :param tol: the target error between the neural net's initial output and the drawn value
        :param max_iter: maximum number of training steps to perform in the while loop
        """

        # If no initial distribution is given, nothing happens
        if self.prior_distribution is None:
            return

        # Draw a target tensor following the given prior distribution
        target = random_tensor(self.prior_distribution, size=(self.output_dim,))

        # Generate a prediction and train the net to output the given target
        prediction = self.forward(torch.rand(self.input_dim))
        iter = 0

        # Use a separate optimizer for the training
        optim = torch.optim.Adam(self.parameters(), lr=0.002)
        while torch.norm(prediction - target) > tol and iter < max_iter:
            prediction = self.forward(torch.rand(self.input_dim))
            loss = torch.nn.functional.mse_loss(target, prediction, reduction="sum")
            loss.backward()
            optim.step()
            optim.zero_grad()
            iter += 1

    # ... Evaluation functions .........................................................................................

    # The model forward pass
    def forward(self, x):
        for i in range(len(self.layers)):
            if self.activation_funcs[i] is None:
                x = self.layers[i](x)
            else:
                x = self.activation_funcs[i](self.layers[i](x))
        return x
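

if __name__ == "__main__":
    # Minimal usage sketch (not part of the original module; the configuration values are
    # hypothetical): builds a small net with two hidden layers and runs a single forward pass.
    # Because of the relative import above, run this as a module, e.g. python -m <package>.neural_net.
    net = NeuralNet(
        input_size=3,
        output_size=2,
        num_layers=2,
        nodes_per_layer={"default": 16, "layer_specific": {1: 8}},
        activation_funcs={"default": "relu", "layer_specific": {2: None}},
        biases={"default": "default"},
    )
    print(net.architecture)    # [3, 16, 8, 2]
    print(net(torch.rand(3)))  # output tensor of shape (2,)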