Coverage for models / Covid / ensemble_training / NN.py: 17%

96 statements  

coverage.py v7.13.1, created at 2026-01-21 16:26 +0000

1import sys 

2from os.path import dirname as up 

3 

4import h5py as h5 

5import numpy as np 

6import torch 

7from dantro import logging 

8from dantro._import_tools import import_module_from_path 

9 

10sys.path.append(up(up(__file__))) 

11sys.path.append(up(up(up(__file__)))) 

12 

13Covid = import_module_from_path(mod_path=up(up(__file__)), mod_str="Covid") 

14base = import_module_from_path(mod_path=up(up(up(__file__))), mod_str="include") 

15 

16log = logging.getLogger(__name__) 

17 

18# ---------------------------------------------------------------------------------------------------------------------- 

19# Model implementation 

20# ---------------------------------------------------------------------------------------------------------------------- 

21 

22 

class Covid_NN:
    """Neural-network calibration model for a Covid compartmental ODE.

    A neural network maps an observed state of the epidemic onto estimates of
    the ODE parameters; the ODE is then integrated forward over a batch of
    time steps, and the loss between the integrated densities and the training
    data drives a gradient step on the network. Loss values and parameter
    estimates are streamed into chunked HDF5 datasets.
    """

    def __init__(
        self,
        *,
        rng: np.random.Generator,
        h5group: h5.Group,
        neural_net: base.BaseNN,
        loss_function: dict,
        to_learn: list,
        time_dependent_parameters: dict = None,
        true_parameters: dict = None,
        dt: float,
        k_q: float = 10.25,
        Berlin_data_loss: bool = False,
        write_every: int = 1,
        write_start: int = 1,
        training_data: torch.Tensor,
        batch_size: int,
        scaling_factors: dict = None,
        **__,
    ):
        """Initialize the model instance with a previously constructed RNG and
        HDF5 group to write the output data to.

        Args:
            rng (np.random.Generator): The shared RNG
            h5group (h5.Group): The output file group to write data to
            neural_net: The neural network
            loss_function (dict): the loss function to use; must contain a
                "name" key and may contain "args" / "kwargs"
            to_learn: the list of parameter names to learn
            time_dependent_parameters: dictionary of time-dependent parameters
                and their granularity
            true_parameters: the dictionary of true parameters (default: none)
            dt: time differential
            k_q: contact tracing rate
            Berlin_data_loss: whether to use the loss structure unique to the
                Berlin data
            write_every: write every iteration
            write_start: iteration at which to start writing
            training_data: the training data to use
            batch_size: epoch batch size: instead of calculating the entire
                time series, only a subsample of length batch_size can be used.
                The likelihood is then scaled up accordingly.
            scaling_factors: dictionary of scaling factors for the different
                parameters. Parameter estimates are multiplied by these to
                ensure all parameters are roughly of the same order of
                magnitude (default: 1.0 for every parameter)
        """
        # Fix: `true_parameters` and `scaling_factors` previously used mutable
        # `{}` defaults; `None` sentinels are backward-compatible and avoid the
        # shared-mutable-default pitfall.
        true_parameters = true_parameters if true_parameters is not None else {}
        scaling_factors = scaling_factors if scaling_factors is not None else {}

        self._h5group = h5group
        self._rng = rng

        self.neural_net = neural_net
        self.neural_net.optimizer.zero_grad()

        # Instantiate the loss function from its (case-insensitive) name and
        # optional positional/keyword arguments
        self.loss_function = base.LOSS_FUNCTIONS[loss_function.get("name").lower()](
            loss_function.get("args", None), **loss_function.get("kwargs", {})
        )

        self.dt = torch.tensor(dt, dtype=torch.float)
        self.k_q = torch.tensor(k_q, dtype=torch.float)
        self.Berlin_data_loss = Berlin_data_loss

        self.current_loss = torch.tensor(0.0)

        # Map each learnable parameter name to its index in the network output
        self.to_learn = {key: idx for idx, key in enumerate(to_learn)}
        self.time_dependent_parameters = (
            time_dependent_parameters if time_dependent_parameters else {}
        )
        self.true_parameters = {
            key: torch.tensor(val, dtype=torch.float)
            for key, val in true_parameters.items()
        }
        self.all_parameters = set(self.to_learn.keys())
        self.all_parameters.update(self.true_parameters.keys())
        self.current_predictions = torch.zeros(len(self.to_learn), dtype=torch.float)

        # Training data
        self.training_data = training_data

        # Generate the batch start indices; always close the last batch at the
        # final time index. (The original had two branches performing the same
        # append -- merged here, preserving the degenerate single-batch case
        # in which the final index is appended unconditionally.)
        batches = np.arange(0, self.training_data.shape[0], batch_size)
        if len(batches) == 1 or batches[-1] != training_data.shape[0] - 1:
            batches = np.append(batches, training_data.shape[0] - 1)
        self.batches = batches

        # --- Set up chunked datasets to store the state data in ---------------
        self._dset_loss = self._h5group.create_dataset(
            "loss",
            (0,),
            maxshape=(None,),
            chunks=True,
            compression=3,
        )
        self._dset_loss.attrs["dim_names"] = ["batch"]
        self._dset_loss.attrs["coords_mode__batch"] = "start_and_step"
        self._dset_loss.attrs["coords__batch"] = [write_start, write_every]

        self.dset_time = self._h5group.create_dataset(
            "computation_time",
            (0,),
            maxshape=(None,),
            chunks=True,
            compression=3,
        )
        self.dset_time.attrs["dim_names"] = ["epoch"]
        self.dset_time.attrs["coords_mode__epoch"] = "trivial"

        # Create a dataset for the parameter estimates
        self.dset_parameters = self._h5group.create_dataset(
            "parameters",
            (0, len(self.to_learn.keys())),
            maxshape=(None, len(self.to_learn.keys())),
            chunks=True,
            compression=3,
        )
        self.dset_parameters.attrs["dim_names"] = ["batch", "parameter"]
        self.dset_parameters.attrs["coords_mode__batch"] = "start_and_step"
        self.dset_parameters.attrs["coords__batch"] = [write_start, write_every]
        self.dset_parameters.attrs["coords_mode__parameter"] = "values"
        self.dset_parameters.attrs["coords__parameter"] = to_learn

        # ----------------------------------------------------------------------
        # Batches processed so far, and the write cadence
        self._time = 0
        self._write_every = write_every
        self._write_start = write_start

        # Calculate the coefficients of each term in the loss function:
        # \alpha_i^{-1} = \int T_i(t) dt
        # Compartments with zero integrated mass get coefficient 1 to avoid
        # division by zero.
        alpha = torch.sum(training_data, dim=0) * self.dt
        alpha = torch.where(alpha > 0, alpha, torch.tensor(1.0))
        # Combine indices 8:11 into a single entry and drop index 7, matching
        # the Berlin-data loss layout used in `epoch`.
        self.alpha = (
            torch.cat([alpha[0:7], torch.sum(alpha[8:11], 0, keepdim=True)], 0)
        ) ** (-1)

        # Reduced data model
        # for idx in [0, 1, 2, 3, 7]:  # S, E, I, R, Q are dropped
        #     self.alpha[idx] = 0

        # Collect all jump points of the time-dependent parameters
        self.jump_points = {}
        if self.time_dependent_parameters:
            self.jump_points = set(
                np.hstack(
                    [
                        np.array(interval).flatten()
                        for _, interval in self.time_dependent_parameters.items()
                    ]
                )
            )
            if None in self.jump_points:
                self.jump_points.remove(None)

        # Scaling factor for each learnable parameter, defaulting to 1.0.
        # Fix: the original built a tensor from a list of 0-d tensors, which
        # torch deprecates; build directly from floats instead. Assumes each
        # configured scaling factor is a scalar, as in the original.
        self.scaling_factors = torch.tensor(
            [float(scaling_factors.get(key, 1.0)) for key in self.to_learn],
            dtype=torch.float,
        )

    def epoch(self):
        """Trains the model for a single epoch (one pass over all batches)."""

        # Process the training data in batches
        for batch_no, batch_idx in enumerate(self.batches[:-1]):
            # Make a prediction from the state at the batch start
            predicted_parameters = self.neural_net(
                torch.flatten(self.training_data[batch_idx])
            )

            # Combine the predicted (scaled) and true parameters into a
            # single dictionary
            parameters = {
                p: (
                    predicted_parameters[self.to_learn[p]]
                    * self.scaling_factors[self.to_learn[p]]
                    if p in self.to_learn
                    else self.true_parameters[p]
                )
                for p in self.all_parameters
            }

            # Get the initial values for the ODE integration
            current_densities = self.training_data[batch_idx].clone()
            current_densities.requires_grad_(True)
            densities = [current_densities]

            # Integrate the ODE for B steps
            for ele in range(batch_idx + 1, self.batches[batch_no + 1] + 1):
                # Adjust for time-dependency: pick the interval containing the
                # current step. Fix: the original wrote the open-ended upper
                # bound back into the caller-supplied config (`r[1] = ...`);
                # a local bound keeps the input dictionary unmutated while
                # selecting the same interval.
                for key, ranges in self.time_dependent_parameters.items():
                    for idx, r in enumerate(ranges):
                        end = r[1] if r[1] else len(self.training_data) + 1
                        if r[0] <= ele < end:
                            parameters[key] = parameters[key + f"_{idx}"]
                            break

                # Calculate the k_Q parameter from the current CT figures and
                # the k_CT estimate
                prev = densities[-1]
                k_Q = self.k_q * parameters["k_CT"] * prev[-1]

                # One explicit Euler step of the 12-compartment ODE, clipped
                # to the [0, 1] density range.
                # NOTE(review): compartment ordering is fixed by the training
                # data; index meanings are not fully documented here -- verify
                # against the data pipeline before reordering.
                densities.append(
                    torch.clip(
                        prev
                        + torch.stack(
                            [
                                (-parameters["k_E"] * prev[2] - k_Q) * prev[0]
                                + parameters["k_S"] * prev[8],
                                parameters["k_E"] * prev[0] * prev[2]
                                - (parameters["k_I"] + k_Q) * prev[1],
                                parameters["k_I"] * prev[1]
                                - (parameters["k_R"] + parameters["k_SY"] + k_Q)
                                * prev[2],
                                parameters["k_R"]
                                * (prev[2] + prev[4] + prev[5] + prev[6] + prev[10]),
                                parameters["k_SY"] * (prev[2] + prev[10])
                                - (parameters["k_R"] + parameters["k_H"]) * prev[4],
                                parameters["k_H"] * prev[4]
                                - (parameters["k_R"] + parameters["k_C"]) * prev[5],
                                parameters["k_C"] * prev[5]
                                - (parameters["k_R"] + parameters["k_D"]) * prev[6],
                                parameters["k_D"] * prev[6],
                                -parameters["k_S"] * prev[8] + k_Q * prev[0],
                                -parameters["k_I"] * prev[9] + k_Q * prev[1],
                                parameters["k_I"] * prev[9]
                                + k_Q * prev[2]
                                - (parameters["k_SY"] + parameters["k_R"])
                                * prev[10],
                                parameters["k_SY"] * prev[2]
                                - self.k_q * torch.sum(prev[0:3]) * prev[-1],
                            ]
                        )
                        * self.dt,
                        0,
                        1,
                    )
                )

            # Discard the initial condition
            densities = torch.stack(densities[1:])

            if self.Berlin_data_loss:
                # For the Berlin dataset, combine the quarantine compartments
                # and drop the deceased compartment, which is not present in
                # the ABM data
                densities = torch.cat(
                    [
                        densities[:, :7],
                        torch.sum(densities[:, 8:11], dim=1, keepdim=True),
                    ],
                    dim=1,
                )
                loss = (
                    self.alpha
                    * self.loss_function(
                        densities,
                        torch.cat(
                            [
                                self.training_data[
                                    batch_idx + 1 : self.batches[batch_no + 1] + 1, :7
                                ],
                                self.training_data[
                                    batch_idx + 1 : self.batches[batch_no + 1] + 1, [8]
                                ],
                            ],
                            1,
                        ),
                    ).sum(dim=0)
                ).sum()

            # Regular loss function, normalised by the batch length
            else:
                loss = self.loss_function(
                    densities,
                    self.training_data[batch_idx + 1 : self.batches[batch_no + 1] + 1],
                ) / (self.batches[batch_no + 1] - batch_idx)

            # Perform a gradient descent step
            loss.backward()
            self.neural_net.optimizer.step()
            self.neural_net.optimizer.zero_grad()
            self.current_loss = loss.clone().detach().cpu().numpy().item()

            # Record the scaled parameter estimates of this batch.
            # Fix: stack the 0-d tensors instead of calling torch.tensor on a
            # list of tensors, which torch deprecates.
            detached = predicted_parameters.clone().detach().cpu()
            self.current_predictions = torch.stack(
                [
                    detached[self.to_learn[p]] * self.scaling_factors[self.to_learn[p]]
                    for p in self.to_learn.keys()
                ]
            )
            self._time += 1
            self.write_data()

    def write_data(self):
        """Write the current state (loss and parameter predictions) into the
        state dataset.

        In the case of HDF5 data writing that is used here, this requires to
        extend the dataset size prior to writing; this way, the newly written
        data is always in the last row of the dataset.
        """
        if self._time >= self._write_start and (self._time % self._write_every == 0):
            self._dset_loss.resize(self._dset_loss.shape[0] + 1, axis=0)
            self._dset_loss[-1] = self.current_loss
            self.dset_parameters.resize(self.dset_parameters.shape[0] + 1, axis=0)
            self.dset_parameters[-1, :] = [
                self.current_predictions[self.to_learn[p]] for p in self.to_learn.keys()
            ]