Coverage for models / Covid / ensemble_training / Langevin.py: 22%

73 statements  

« prev     ^ index     » next       coverage.py v7.13.1, created at 2026-01-21 16:26 +0000

1import logging 

2import time 

3 

4import h5py as h5 

5import torch 

6 

7log = logging.getLogger(__name__) 

8 

9import sys 

10from os.path import dirname as up 

11 

12sys.path.append(up(up(__file__))) 

13sys.path.append(up(up(up(__file__)))) 

14from dantro._import_tools import import_module_from_path 

15 

16base = import_module_from_path(mod_path=up(up(up(__file__))), mod_str="include") 

17 

18 

class Covid_Langevin_sampler(base.MetropolisAdjustedLangevin):
    """
    A Metropolis-adjusted Langevin sampler for the Covid model that inherits from the base class.

    The sampler draws an initial parameter guess from a prior and then produces MCMC
    samples of the parameters to learn. Each likelihood evaluation simulates the Covid
    compartmental ODE over a randomly chosen sub-interval (batch) of the training data
    and uses the squared prediction error as the negative log-likelihood.
    """

    def __init__(
        self,
        *,
        true_data: torch.Tensor,
        prior: dict,
        lr: float = 1e-2,
        lr_final: float = 1e-4,
        max_itr: float = 1e4,
        beta: float = 0.99,
        Lambda: float = 1e-15,
        centered: bool = False,
        write_start: int = 1,
        write_every: int = 1,
        batch_size: int = 1,
        dt: float,
        k_q: float = 10.25,
        Berlin_data_loss: bool = False,
        to_learn: list,
        time_dependent_parameters: dict = None,
        true_parameters: dict,
        h5File: h5.File,
        **__,
    ):
        """Initialises the sampler.

        :param true_data: time series of compartment densities used to calculate the
            likelihood; indexed by time along dim 0 (assumed shape
            (time, compartment, ...) — TODO confirm against caller)
        :param prior: prior settings from which the initial guess is drawn
        :param lr: initial learning rate (forwarded to the parent sampler)
        :param lr_final: final learning rate (forwarded to the parent sampler)
        :param max_itr: maximum number of iterations (forwarded to the parent sampler)
        :param beta: parent-sampler parameter (forwarded unchanged)
        :param Lambda: parent-sampler parameter (forwarded unchanged)
        :param centered: parent-sampler flag (forwarded unchanged)
        :param write_start: iteration from which to start writing data
        :param write_every: write interval
        :param batch_size: length of the random time window simulated per loss call
        :param dt: time step of the ODE integration (keyword-only, no default)
        :param k_q: quarantine rate constant of the Covid model
        :param Berlin_data_loss: if True, use the Berlin-specific weighted loss that
            drops the D and CT compartments and combines the Q compartments
        :param to_learn: names of the parameters to learn (keyword-only, no default)
        :param time_dependent_parameters: maps a parameter name to a list of time
            ranges; within range ``idx`` the entry ``"<name>_<idx>"`` is substituted
        :param true_parameters: fixed values for parameters that are not learned
        :param h5File: h5 file to write data to
        :param __: any additional (ignored) keyword arguments
        """
        # Parameters to learn: map each name to its index in the parameter vector
        self.to_learn = {key: idx for idx, key in enumerate(to_learn)}
        self.time_dependent_parameters = (
            time_dependent_parameters if time_dependent_parameters else {}
        )
        # Fixed parameter values, converted to float tensors once up front
        self.true_parameters = {
            key: torch.tensor(val, dtype=torch.float)
            for key, val in true_parameters.items()
        }
        # Union of learned and fixed parameter names
        self.all_parameters = set(self.to_learn.keys())
        self.all_parameters.update(self.true_parameters.keys())
        # Dimension of the parameter vector being sampled
        self.N = len(self.to_learn)

        # Draw an initial guess from the prior; size is the number of learned parameters
        init_guess: torch.Tensor = base.random_tensor(
            prior,
            size=(
                len(
                    self.to_learn.keys(),
                )
            ),
        )

        # Initialise the parent class with the initial values
        super().__init__(
            true_data=true_data,
            init_guess=init_guess,
            lr=lr,
            lr_final=lr_final,
            max_itr=max_itr,
            beta=beta,
            Lambda=Lambda,
            centered=centered,
            write_start=write_start,
            write_every=write_every,
            batch_size=batch_size,
            h5File=h5File,
        )

        # Covid equation parameters
        self.dt = torch.tensor(dt, dtype=torch.float)
        self.k_q = torch.tensor(k_q, dtype=torch.float)
        self.Berlin_data_loss = Berlin_data_loss

        # Drop D, CT compartments for Berlin model, combine Q compartments.
        # alpha holds the inverse of the per-compartment totals over time and is used
        # to weight the per-compartment squared errors in the Berlin loss below.
        # NOTE(review): index 7 is skipped and indices 8:11 are summed — presumably
        # these are the D and Q compartments respectively; confirm the compartment
        # ordering against the data source.
        if self.Berlin_data_loss:
            alpha = torch.sum(self.true_data, dim=0).squeeze()
            alpha = torch.cat([alpha[0:7], torch.sum(alpha[8:11], 0, keepdim=True)], 0)
            self.alpha = torch.squeeze(alpha ** (-1))

        # Create datasets for the predictions; one row per written sample,
        # one column per learned parameter (self.h5group is set by the parent)
        self.dset_parameters = self.h5group.create_dataset(
            "parameters",
            (0, len(self.to_learn.keys())),
            maxshape=(None, len(self.to_learn.keys())),
            chunks=True,
            compression=3,
        )
        self.dset_parameters.attrs["dim_names"] = ["sample", "parameter"]
        self.dset_parameters.attrs["coords_mode__sample"] = "trivial"
        self.dset_parameters.attrs["coords_mode__parameter"] = "values"
        self.dset_parameters.attrs["coords__parameter"] = to_learn

        # Calculate the initial values of the loss and its gradient.
        # self.loss, self.x and self.grad are two-slot buffers (current/proposed)
        # provided by the parent class; both slots start out identical.
        self.loss[0] = self.loss_function(self.x[0])
        self.loss[1].data = self.loss[0].data

        self.grad[0].data = torch.autograd.grad(
            self.loss[0], [self.x[0]], create_graph=False
        )[0].data
        self.grad[1].data = self.grad[0].data

    def loss_function(self, input: torch.Tensor) -> torch.Tensor:
        r"""Calculates the loss (negative log-likelihood function) of a vector of parameters via simulation.

        A random start time is drawn, the ODE is integrated forward for
        ``batch_size - 1`` explicit-Euler steps from the true density at
        ``start - 1``, and the squared deviation from the true data over that
        window is returned.

        :param input: the vector of parameters (one entry per learned parameter)
        :return: likelihood || \hat{T}(\hat{Lambda}) - T ||_2
        """

        # Pick a random batch start in [1, T - batch_size). torch.randint requires
        # high > low, so the single-choice case is special-cased.
        # NOTE(review): windows where true_data.shape[0] - batch_size < 1 are not
        # handled and would raise — confirm callers guarantee enough data.
        if self.true_data.shape[0] - self.batch_size == 1:
            start = 1
        else:
            start = torch.randint(
                1, self.true_data.shape[0] - self.batch_size, (1,)
            ).item()

        # Seed the trajectory with the true densities one step before the window
        densities = [self.true_data[start - 1]]

        # Assemble the full parameter dict: learned entries come from `input`,
        # the rest from the fixed true parameters
        parameters = {
            p: input[self.to_learn[p]]
            if p in self.to_learn.keys()
            else self.true_parameters[p]
            for p in self.all_parameters
        }

        for t in range(start, start + self.batch_size - 1):
            # Substitute time-dependent parameters: within range idx of `key`,
            # use the entry "<key>_<idx>" instead of "<key>". An open-ended
            # range (falsy upper bound) is rewritten in place to cover all times.
            for key, ranges in self.time_dependent_parameters.items():
                for idx, r in enumerate(ranges):
                    if not r[1]:
                        r[1] = len(self.true_data) + 1
                    if r[0] <= t < r[1]:
                        parameters[key] = parameters[key + f"_{idx}"]
                        break

            # Effective quarantine rate, proportional to the last compartment
            # (presumably the contact-tracing agent density — TODO confirm)
            k_Q = self.k_q * parameters["k_CT"] * densities[-1][-1]

            # Solve the ODE: one explicit-Euler step of the 12-compartment
            # system, clipped to the physical range [0, 1]
            densities.append(
                torch.clip(
                    densities[-1]
                    + torch.stack(
                        [
                            (-parameters["k_E"] * densities[-1][2] - k_Q)
                            * densities[-1][0]
                            + parameters["k_S"] * densities[-1][8],
                            parameters["k_E"] * densities[-1][0] * densities[-1][2]
                            - (parameters["k_I"] + k_Q) * densities[-1][1],
                            parameters["k_I"] * densities[-1][1]
                            - (parameters["k_R"] + parameters["k_SY"] + k_Q)
                            * densities[-1][2],
                            parameters["k_R"]
                            * (
                                densities[-1][2]
                                + densities[-1][4]
                                + densities[-1][5]
                                + densities[-1][6]
                                + densities[-1][10]
                            ),
                            parameters["k_SY"] * (densities[-1][2] + densities[-1][10])
                            - (parameters["k_R"] + parameters["k_H"])
                            * densities[-1][4],
                            parameters["k_H"] * densities[-1][4]
                            - (parameters["k_R"] + parameters["k_C"])
                            * densities[-1][5],
                            parameters["k_C"] * densities[-1][5]
                            - (parameters["k_R"] + parameters["k_D"])
                            * densities[-1][6],
                            parameters["k_D"] * densities[-1][6],
                            -parameters["k_S"] * densities[-1][8]
                            + k_Q * densities[-1][0],
                            -parameters["k_I"] * densities[-1][9]
                            + k_Q * densities[-1][1],
                            parameters["k_I"] * densities[-1][9]
                            + k_Q * densities[-1][2]
                            - (parameters["k_SY"] + parameters["k_R"])
                            * densities[-1][10],
                            parameters["k_SY"] * densities[-1][2]
                            - self.k_q
                            * torch.sum(densities[-1][0:3])
                            * densities[-1][-1],
                        ]
                    )
                    * self.dt,
                    0,
                    1,
                )
            )

        densities = torch.stack(densities)

        if self.Berlin_data_loss:
            # Scale loss to prevent numerical underflow of the preconditioner (which is inversely proportional to the
            # gradient). Compartments 0-6 are compared directly; compartments
            # 8:11 are summed before comparison against column 8 of the data
            # (index 7 and the last compartment are excluded, matching alpha).
            loss = 5e4 * torch.dot(
                self.alpha,
                torch.concat(
                    [
                        torch.sum(
                            torch.pow(
                                densities[:, 0:7]
                                - self.true_data[start : start + self.batch_size, 0:7],
                                2,
                            ),
                            dim=0,
                        ),
                        torch.sum(
                            torch.pow(
                                torch.sum(densities[:, 8:11], dim=1)
                                - self.true_data[start : start + self.batch_size, 8],
                                2,
                            ),
                            dim=0,
                            keepdim=True,
                        ),
                    ],
                    0,
                ).squeeze(),
            )

        else:
            # Plain sum-of-squares error over the whole window
            loss = torch.sum(
                torch.pow(
                    densities - self.true_data[start : start + self.batch_size], 2
                )
            )

        return loss

    def write_parameters(self) -> None:
        """Appends the current parameter vector to the h5 dataset, provided the
        current iteration (``self.time``, maintained by the parent class) is past
        ``write_start`` and falls on the ``write_every`` interval."""
        if self.time > self.write_start and self.time % self.write_every == 0:
            self.dset_parameters.resize(self.dset_parameters.shape[0] + 1, axis=0)
            self.dset_parameters[-1, :] = torch.flatten(self.x[0].detach()).numpy()

249 

250 

def perform_sampling(h5file, training_data, model_cfg: dict) -> None:
    """Runs the Covid Langevin sampler.

    Initialises a :py:class:`Covid_Langevin_sampler` from the model configuration,
    draws ``n_samples`` MCMC samples (writing the loss and parameters after each),
    and records the total sampling time.

    :param h5file: hdf5 file to write the data to
    :param training_data: training data used to calculate the likelihood
        (indexed [time, :, :]; sliced by the optional ``training_data_size`` entry)
    :param model_cfg: configuration file; note that ``n_samples`` is popped from
        the ``MCMC`` section so the remainder can be forwarded as keyword arguments
    """

    # Number of samples (popped so **model_cfg["MCMC"] below does not pass it on)
    n_samples = model_cfg["MCMC"].pop("n_samples")

    # Initialise the sampler
    sampler = Covid_Langevin_sampler(
        h5File=h5file,
        true_data=training_data[
            model_cfg["Data"].get("training_data_size", slice(None, None)), :, :
        ],
        to_learn=model_cfg["Training"]["to_learn"],
        time_dependent_parameters=model_cfg["Data"].get(
            "time_dependent_parameters", None
        ),
        true_parameters=model_cfg["Training"].get("true_parameters", {}),
        dt=model_cfg["Data"]["synthetic_data"]["dt"],
        k_q=model_cfg["Data"]["synthetic_data"]["k_q"],
        **model_cfg["MCMC"],
    )

    # Track the sampling time
    start_time = time.time()

    # Collect n_samples
    for i in range(n_samples):
        sampler.sample()
        sampler.write_loss()
        sampler.write_parameters()
        # Fix: report 1-based progress (previously logged "Collected 0 of N"
        # after the first sample); use lazy %-formatting so the message is only
        # built when the INFO level is enabled.
        log.info("Collected %d of %d; current loss: %s", i + 1, n_samples, sampler.loss[1])

    # Write out the total sampling time
    sampler.write_time(time.time() - start_time)

    # NOTE(review): `success` is not a stdlib logging.Logger method — presumably
    # the surrounding framework registers a custom 'success' level; confirm
    # before running this module outside that environment.
    log.success(" MCMC sampling complete.")