Coverage for models/SIR/NN.py: 21%

85 statements  

coverage.py v7.6.1, created at 2024-12-05 17:26 +0000

import sys
from os.path import dirname as up

import coloredlogs
import h5py as h5
import numpy as np
import torch
from dantro import logging
from dantro._import_tools import import_module_from_path

sys.path.append(up(up(__file__)))
sys.path.append(up(up(up(__file__))))

SIR = import_module_from_path(mod_path=up(up(__file__)), mod_str="SIR")
base = import_module_from_path(mod_path=up(up(up(__file__))), mod_str="include")

log = logging.getLogger(__name__)
coloredlogs.install(fmt="%(levelname)s %(message)s", level="INFO", logger=log)


class SIR_NN:
    def __init__(
        self,
        *,
        rng: np.random.Generator,
        h5group: h5.Group,
        neural_net: base.NeuralNet,
        loss_function: dict,
        to_learn: list,
        true_parameters: dict = {},
        write_every: int = 1,
        write_start: int = 1,
        training_data: torch.Tensor,
        batch_size: int,
        scaling_factors: dict = {},
        **__,
    ):
        """Initialize the model instance with a previously constructed RNG and
        an HDF5 group to write the output data to.

        :param rng (np.random.Generator): the shared RNG
        :param h5group (h5.Group): the output file group to write data to
        :param neural_net: the neural network
        :param loss_function (dict): the loss function to use
        :param to_learn: the list of parameter names to learn
        :param true_parameters: the dictionary of true parameters
        :param training_data: the training data to use
        :param write_every: how often to write data (in batches)
        :param write_start: iteration at which to start writing
        :param batch_size: epoch batch size: instead of calculating the entire time
            series, only a subsample of length batch_size can be used. The likelihood
            is then scaled up accordingly.
        :param scaling_factors: factors by which the parameters are to be scaled
        """

        self._h5group = h5group
        self._rng = rng

        self.neural_net = neural_net
        self.neural_net.optimizer.zero_grad()
        self.loss_function = base.LOSS_FUNCTIONS[loss_function.get("name").lower()](
            loss_function.get("args", None), **loss_function.get("kwargs", {})
        )

        self.current_loss = torch.tensor(0.0)

        self.to_learn = {key: idx for idx, key in enumerate(to_learn)}
        self.true_parameters = {
            key: torch.tensor(val, dtype=torch.float)
            for key, val in true_parameters.items()
        }
        self.current_predictions = torch.zeros(len(to_learn))

        self.training_data = training_data

        # Generate the batch ids
        batches = np.arange(0, self.training_data.shape[0], batch_size)
        if len(batches) == 1:
            batches = np.append(batches, training_data.shape[0] - 1)
        elif batches[-1] != training_data.shape[0] - 1:
            batches = np.append(batches, training_data.shape[0] - 1)

        self.batches = batches
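        # Worked example (hypothetical values): for a time series of length 100
        # and batch_size = 30, np.arange yields [0, 30, 60, 90]; the final index
        # 99 is then appended, so self.batches = [0, 30, 60, 90, 99].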

        self.scaling_factors = scaling_factors

        # --- Set up chunked datasets to store the state data in --------------
        # Write the loss after every batch
        self._dset_loss = self._h5group.create_dataset(
            "loss",
            (0,),
            maxshape=(None,),
            chunks=True,
            compression=3,
        )
        self._dset_loss.attrs["dim_names"] = ["batch"]
        self._dset_loss.attrs["coords_mode__batch"] = "start_and_step"
        self._dset_loss.attrs["coords__batch"] = [write_start, write_every]

        # Write the computation time of every epoch
        self.dset_time = self._h5group.create_dataset(
            "computation_time",
            (0,),
            maxshape=(None,),
            chunks=True,
            compression=3,
        )
        self.dset_time.attrs["dim_names"] = ["epoch"]
        self.dset_time.attrs["coords_mode__epoch"] = "trivial"

        # Write the parameter predictions after every batch
        self.dset_parameters = self._h5group.create_dataset(
            "parameters",
            (0, len(self.to_learn.keys())),
            maxshape=(None, len(self.to_learn.keys())),
            chunks=True,
            compression=3,
        )
        self.dset_parameters.attrs["dim_names"] = ["batch", "parameter"]
        self.dset_parameters.attrs["coords_mode__batch"] = "start_and_step"
        self.dset_parameters.attrs["coords__batch"] = [write_start, write_every]
        self.dset_parameters.attrs["coords_mode__parameter"] = "values"
        self.dset_parameters.attrs["coords__parameter"] = to_learn


        # Batches processed
        self._time = 0
        self._write_every = write_every
        self._write_start = write_start

    def epoch(self):
        """
        An epoch is one pass over the entire dataset. The dataset is processed in
        batches of size B < L, where L is the length of the time series, and the
        parameters of the neural network are updated after each batch. For example,
        if L = 100 and B = 50, the dataset is processed in two batches: one over
        the first 50 steps and one over the second 50. The entire time series is
        processed even if L is not divisible into equal segments of length B: if
        B is 30, for instance, the series is processed in three segments of 30
        steps and one shorter final segment.
        """

        # Process the training data in batches
        for batch_no, batch_idx in enumerate(self.batches[:-1]):
            predicted_parameters = self.neural_net(
                torch.flatten(self.training_data[batch_idx])
            )

            # Get the parameters: infection rate, recovery time, noise variance
            p = (
                self.scaling_factors.get("p_infect", 1.0)
                * predicted_parameters[self.to_learn["p_infect"]]
                if "p_infect" in self.to_learn.keys()
                else self.true_parameters["p_infect"]
            )
            t = (
                self.scaling_factors.get("t_infectious", 1.0)
                * predicted_parameters[self.to_learn["t_infectious"]]
                if "t_infectious" in self.to_learn.keys()
                else self.true_parameters["t_infectious"]
            )
            sigma = (
                self.scaling_factors.get("sigma", 1.0)
                * predicted_parameters[self.to_learn["sigma"]]
                if "sigma" in self.to_learn.keys()
                else self.true_parameters["sigma"]
            )

            current_densities = self.training_data[batch_idx].clone()
            current_densities.requires_grad_(True)

            loss = torch.tensor(0.0, requires_grad=True)

            for ele in range(batch_idx + 1, self.batches[batch_no + 1] + 1):
                # Recovery rate: smoothly switches on the rate 1/t once the
                # time index exceeds the infectious period t
                tau = 1 / t * torch.sigmoid(1000 * (ele / t - 1))

                # Random noise
                w = torch.normal(torch.tensor(0.0), torch.tensor(1.0))

                # Solve the ODE: perform one noisy Euler step of the SIR
                # dynamics and clip the densities to [0, 1]
                current_densities = torch.clip(
                    current_densities
                    + torch.stack(
                        [
                            (-p * current_densities[0] - sigma * w)
                            * current_densities[1],
                            (p * current_densities[0] + sigma * w - tau)
                            * current_densities[1],
                            tau * current_densities[1],
                        ]
                    ),
                    0.0,
                    1.0,
                )

                # Calculate the loss, scaled up by the ratio of the full series
                # length to the batch length
                loss = loss + self.loss_function(
                    current_densities, self.training_data[ele]
                ) * (
                    self.training_data.shape[0]
                    / (self.batches[batch_no + 1] - batch_idx)
                )

            loss.backward()
            self.neural_net.optimizer.step()
            self.neural_net.optimizer.zero_grad()
            self.current_loss = loss.clone().detach().cpu().numpy().item()
            self.current_predictions = predicted_parameters.clone().detach().cpu()

            # Scale the parameters
            for param in self.to_learn.keys():
                self.current_predictions[self.to_learn[param]] *= (
                    self.scaling_factors.get(param, 1.0)
                )
            self._time += 1
            self.write_data()

    def write_data(self):
        """Write the current state (loss and parameter predictions) into the state dataset.

        In the case of the HDF5 data writing used here, this requires extending
        the dataset size prior to writing; this way, the newly written data is
        always in the last row of the dataset.
        """
        if self._time >= self._write_start and (self._time % self._write_every == 0):
            self._dset_loss.resize(self._dset_loss.shape[0] + 1, axis=0)
            self._dset_loss[-1] = self.current_loss
            self.dset_parameters.resize(self.dset_parameters.shape[0] + 1, axis=0)
            self.dset_parameters[-1, :] = [
                self.current_predictions[self.to_learn[p]] for p in self.to_learn.keys()
            ]
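

# A minimal usage sketch (illustrative only, not part of the measured module).
# It assumes a base.NeuralNet instance mapping the flattened densities to
# len(to_learn) outputs, a loss-function name registered in base.LOSS_FUNCTIONS
# ("MSELoss" is an assumption), synthetic training data of assumed shape
# (time, 3, 1), and a hypothetical `num_epochs` setting:
#
#   import h5py as h5
#   import numpy as np
#   import torch
#
#   h5file = h5.File("output.h5", "w")
#   model = SIR_NN(
#       rng=np.random.default_rng(42),
#       h5group=h5file.create_group("SIR_NN"),
#       neural_net=neural_net,                 # a base.NeuralNet instance
#       loss_function={"name": "MSELoss"},
#       to_learn=["p_infect", "t_infectious", "sigma"],
#       training_data=torch.rand(100, 3, 1),
#       batch_size=30,
#   )
#   for _ in range(num_epochs):
#       model.epoch()
#   h5file.close()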