# Coverage report header (coverage.py v7.13.1, created 2026-01-21 16:26 +0000):
# 15% of 79 statements in models/SIR/ensemble_training/NN.py covered.

1import sys 

2from os.path import dirname as up 

3import h5py as h5 

4import numpy as np 

5import torch 

6from dantro._import_tools import import_module_from_path 

7 

8sys.path.append(up(up(up(__file__)))) 

9 

10base = import_module_from_path(mod_path=up(up(up(__file__))), mod_str="include") 

11 

class SIR_NN:
    """Neural-network calibrator for a stochastic SIR model.

    Each epoch iterates over the training time series in batches: for every
    batch the network predicts the model parameters, the SIR equations are
    integrated forward over the batch window, and the mismatch between the
    integrated densities and the data yields a loss that is backpropagated
    through the network. Losses and parameter predictions are streamed to
    chunked HDF5 datasets.
    """

    def __init__(
        self,
        *,
        rng: np.random.Generator,
        h5group: "h5.Group",
        neural_net: "base.BaseNN",
        loss_function: dict,
        to_learn: list,
        true_parameters: dict = None,
        write_every: int = 1,
        write_start: int = 1,
        training_data: torch.Tensor,
        batch_size: int,
        scaling_factors: dict = None,
        **__,
    ):
        """Initialize the model instance with a previously constructed RNG and
        HDF5 group to write the output data to.

        :param rng: the shared RNG
        :param h5group: the output file group to write data to
        :param neural_net: the neural network
        :param loss_function: name/args/kwargs dict selecting an entry of
            ``base.LOSS_FUNCTIONS``
        :param to_learn: the list of parameter names to learn
        :param true_parameters: the dictionary of true (fixed) values for
            parameters that are not learned
        :param write_every: write every nth batch
        :param write_start: iteration at which to start writing
        :param training_data: the training data to use
        :param batch_size: epoch batch size: instead of calculating the entire
            time series, only a subsample of length batch_size can be used.
            The likelihood is then scaled up accordingly.
        :param scaling_factors: factors by which the parameters are to be
            scaled
        """
        # Avoid mutable default arguments: the previous `= {}` defaults were
        # shared across all instances.
        true_parameters = {} if true_parameters is None else true_parameters
        scaling_factors = {} if scaling_factors is None else scaling_factors

        self._h5group = h5group
        self._rng = rng

        self.neural_net = neural_net
        self.neural_net.optimizer.zero_grad()
        self.loss_function = base.LOSS_FUNCTIONS[loss_function.get("name").lower()](
            loss_function.get("args", None), **loss_function.get("kwargs", {})
        )

        self.current_loss = torch.tensor(0.0)

        # Map each learnable parameter name to its index in the net output
        self.to_learn = {key: idx for idx, key in enumerate(to_learn)}
        self.true_parameters = {
            key: torch.tensor(val, dtype=torch.float)
            for key, val in true_parameters.items()
        }
        self.current_predictions = torch.zeros(len(to_learn))

        # The training data and batch ids. NOTE: this setup was previously
        # duplicated verbatim further down in __init__; it is now done once.
        self.training_data = training_data
        self.batches = self._generate_batches(training_data.shape[0], batch_size)

        self.scaling_factors = scaling_factors

        # --- Set up chunked datasets to store the state data in --------------
        # Write the loss after every batch
        self._dset_loss = self._h5group.create_dataset(
            "loss",
            (0,),
            maxshape=(None,),
            chunks=True,
            compression=3,
        )
        self._dset_loss.attrs["dim_names"] = ["batch"]
        self._dset_loss.attrs["coords_mode__batch"] = "start_and_step"
        self._dset_loss.attrs["coords__batch"] = [write_start, write_every]

        # Write the computation time of every epoch
        self.dset_time = self._h5group.create_dataset(
            "computation_time",
            (0,),
            maxshape=(None,),
            chunks=True,
            compression=3,
        )
        self.dset_time.attrs["dim_names"] = ["epoch"]
        self.dset_time.attrs["coords_mode__epoch"] = "trivial"

        # Write the parameter predictions after every batch
        self.dset_parameters = self._h5group.create_dataset(
            "parameters",
            (0, len(self.to_learn.keys())),
            maxshape=(None, len(self.to_learn.keys())),
            chunks=True,
            compression=3,
        )
        self.dset_parameters.attrs["dim_names"] = ["batch", "parameter"]
        self.dset_parameters.attrs["coords_mode__batch"] = "start_and_step"
        self.dset_parameters.attrs["coords__batch"] = [write_start, write_every]
        self.dset_parameters.attrs["coords_mode__parameter"] = "values"
        self.dset_parameters.attrs["coords__parameter"] = to_learn

        # Number of batches processed so far, and the write cadence
        self._time = 0
        self._write_every = write_every
        self._write_start = write_start

    @staticmethod
    def _generate_batches(n_steps: int, batch_size: int) -> np.ndarray:
        """Return the batch start indices for a series of ``n_steps`` frames.

        Indices are ``0, batch_size, 2 * batch_size, ...``. The final index
        ``n_steps - 1`` is appended whenever it is not already the last entry
        (and unconditionally when only one index exists, matching the
        original behavior), so the trailing — possibly shorter — segment is
        always processed.
        """
        batches = np.arange(0, n_steps, batch_size)
        if len(batches) == 1 or batches[-1] != n_steps - 1:
            batches = np.append(batches, n_steps - 1)
        return batches

    def _get_parameter(self, name: str, predicted_parameters) -> torch.Tensor:
        """Return the value of parameter ``name``: the (scaled) network
        prediction if it is being learned, the fixed true value otherwise.

        Raises KeyError if ``name`` is neither learned nor supplied in
        ``true_parameters`` (same as the original inline logic).
        """
        if name in self.to_learn:
            return (
                self.scaling_factors.get(name, 1.0)
                * predicted_parameters[self.to_learn[name]]
            )
        return self.true_parameters[name]

    def epoch(self):
        """Run one pass over the entire dataset.

        The dataset is processed in batches, where B < L is the batch number.
        After each batch, the parameters of the neural network are updated.
        For example, if L = 100 and B = 50, two passes are made over the
        dataset -- one over the first 50 steps, and one over the second 50.
        The entire time series is processed, even if L is not divisible into
        equal segments of length B. For instance, if B is 30, the time series
        is processed in 3 steps of 30 and one of 10.
        """
        # Process the training data in batches
        for batch_no, batch_idx in enumerate(self.batches[:-1]):
            predicted_parameters = self.neural_net(
                torch.flatten(self.training_data[batch_idx])
            )

            # Get the parameters: infection rate, recovery time, noise variance
            p = self._get_parameter("p_infect", predicted_parameters)
            t = self._get_parameter("t_infectious", predicted_parameters)
            sigma = self._get_parameter("sigma", predicted_parameters)

            current_densities = self.training_data[batch_idx].clone()
            current_densities.requires_grad_(True)

            loss = torch.tensor(0.0, requires_grad=True)

            for ele in range(batch_idx + 1, self.batches[batch_no + 1] + 1):
                # Recovery rate: steep sigmoid approximates switching on once
                # the infectious period t has elapsed
                tau = 1 / t * torch.sigmoid(1000 * (ele / t - 1))

                # Random noise
                w = torch.normal(torch.tensor(0.0), torch.tensor(1.0))

                # Solve the ODE, clipping densities to [0, 1]
                current_densities = torch.clip(
                    current_densities
                    + torch.stack(
                        [
                            (-p * current_densities[0] - sigma * w)
                            * current_densities[1],
                            (p * current_densities[0] + sigma * w - tau)
                            * current_densities[1],
                            tau * current_densities[1],
                        ]
                    ),
                    0.0,
                    1.0,
                )

                # Accumulate the loss, scaled up from the batch length to the
                # full series length
                loss = loss + self.loss_function(
                    current_densities, self.training_data[ele]
                ) * (
                    self.training_data.shape[0]
                    / (self.batches[batch_no + 1] - batch_idx)
                )

            loss.backward()
            self.neural_net.optimizer.step()
            self.neural_net.optimizer.zero_grad()
            self.current_loss = loss.clone().detach().cpu().numpy().item()
            self.current_predictions = predicted_parameters.clone().detach().cpu()

            # Scale the parameters before writing them out
            for param in self.to_learn.keys():
                self.current_predictions[
                    self.to_learn[param]
                ] *= self.scaling_factors.get(param, 1.0)
            self._time += 1
            self.write_data()

    def write_data(self):
        """Write the current state (loss and parameter predictions) into the
        state dataset.

        In the case of HDF5 data writing that is used here, this requires to
        extend the dataset size prior to writing; this way, the newly written
        data is always in the last row of the dataset.
        """
        if self._time >= self._write_start and (self._time % self._write_every == 0):
            self._dset_loss.resize(self._dset_loss.shape[0] + 1, axis=0)
            self._dset_loss[-1] = self.current_loss
            self.dset_parameters.resize(self.dset_parameters.shape[0] + 1, axis=0)
            self.dset_parameters[-1, :] = [
                self.current_predictions[self.to_learn[p]] for p in self.to_learn.keys()
            ]