Coverage for models / Covid / ensemble_training / Langevin.py: 22%
73 statements
« prev ^ index » next coverage.py v7.13.1, created at 2026-01-21 16:26 +0000
1import logging
2import time
4import h5py as h5
5import torch
7log = logging.getLogger(__name__)
9import sys
10from os.path import dirname as up
12sys.path.append(up(up(__file__)))
13sys.path.append(up(up(up(__file__))))
14from dantro._import_tools import import_module_from_path
16base = import_module_from_path(mod_path=up(up(up(__file__))), mod_str="include")
class Covid_Langevin_sampler(base.MetropolisAdjustedLangevin):
    """
    A Metropolis-adjusted Langevin sampler for the Covid model that inherits from the base class.

    The sampler draws an initial parameter vector from a prior, delegates the
    MALA mechanics (step proposals, accept/reject, learning-rate schedule) to
    the parent class, and supplies a model-specific negative log-likelihood:
    the squared error between a short simulated trajectory of the Covid
    compartment ODE and a randomly chosen window of the observed data.
    """

    def __init__(
        self,
        *,
        true_data: torch.Tensor,
        prior: dict,
        lr: float = 1e-2,
        lr_final: float = 1e-4,
        max_itr: float = 1e4,
        beta: float = 0.99,
        Lambda: float = 1e-15,
        centered: bool = False,
        write_start: int = 1,
        write_every: int = 1,
        batch_size: int = 1,
        dt: float,
        k_q: float = 10.25,
        Berlin_data_loss: bool = False,
        to_learn: list,
        time_dependent_parameters: dict = None,
        true_parameters: dict,
        h5File: h5.File,
        **__,
    ):
        """Initialises the sampler.

        :param true_data: observed compartment densities; indexed
            ``true_data[time, compartment, ...]`` (first axis is time — see
            the window selection in :py:meth:`loss_function`)
        :param prior: prior configuration passed to ``base.random_tensor`` to
            draw the initial parameter guess
        :param lr: initial learning rate (forwarded to the parent class)
        :param lr_final: final learning rate of the decay schedule
        :param max_itr: maximum number of iterations for the schedule
        :param beta: preconditioner smoothing parameter (forwarded to parent)
        :param Lambda: preconditioner regularisation (forwarded to parent)
        :param centered: whether to use the centered variant (forwarded)
        :param write_start: first iteration from which data is written
        :param write_every: write interval (in iterations)
        :param batch_size: length of the trajectory window simulated per
            likelihood evaluation
        :param dt: time step of the ODE integration (keyword-only, required)
        :param k_q: global quarantine rate constant
        :param Berlin_data_loss: if True, use the Berlin-data loss, which
            drops some compartments and combines the quarantine compartments
        :param to_learn: names of the parameters to be sampled
        :param time_dependent_parameters: mapping of parameter name to a list
            of ``[start, stop]`` time ranges selecting indexed variants
            (``<name>_0``, ``<name>_1``, ...); may be None
        :param true_parameters: fixed (not sampled) parameter values
        :param h5File: h5 file to write the data to
        :param **__: any additional kwargs are ignored
        """
        # Parameters to learn: name -> position in the sampled vector
        self.to_learn = {key: idx for idx, key in enumerate(to_learn)}
        self.time_dependent_parameters = (
            time_dependent_parameters if time_dependent_parameters else {}
        )
        # Fixed parameters, converted to float tensors once up front
        self.true_parameters = {
            key: torch.tensor(val, dtype=torch.float)
            for key, val in true_parameters.items()
        }
        # Union of learned and fixed parameter names, used to assemble the
        # full parameter dictionary inside loss_function
        self.all_parameters = set(self.to_learn.keys())
        self.all_parameters.update(self.true_parameters.keys())
        # Dimension of the sampled parameter vector
        self.N = len(self.to_learn)

        # Draw an initial guess from the prior
        init_guess: torch.Tensor = base.random_tensor(
            prior,
            size=(
                len(
                    self.to_learn.keys(),
                )
            ),
        )

        # Initialise the parent class with the initial values
        super().__init__(
            true_data=true_data,
            init_guess=init_guess,
            lr=lr,
            lr_final=lr_final,
            max_itr=max_itr,
            beta=beta,
            Lambda=Lambda,
            centered=centered,
            write_start=write_start,
            write_every=write_every,
            batch_size=batch_size,
            h5File=h5File,
        )

        # Covid equation parameters
        self.dt = torch.tensor(dt, dtype=torch.float)
        self.k_q = torch.tensor(k_q, dtype=torch.float)
        self.Berlin_data_loss = Berlin_data_loss

        # Drop D, CT compartments for Berlin model, combine Q compartments.
        # alpha holds per-compartment inverse totals (summed over time), used
        # as weights in the Berlin loss so each compartment contributes on a
        # comparable scale.
        if self.Berlin_data_loss:
            alpha = torch.sum(self.true_data, dim=0).squeeze()
            alpha = torch.cat([alpha[0:7], torch.sum(alpha[8:11], 0, keepdim=True)], 0)
            self.alpha = torch.squeeze(alpha ** (-1))

        # Create datasets for the predictions (resizable along the sample axis;
        # attrs follow the dantro labelled-dataset convention)
        self.dset_parameters = self.h5group.create_dataset(
            "parameters",
            (0, len(self.to_learn.keys())),
            maxshape=(None, len(self.to_learn.keys())),
            chunks=True,
            compression=3,
        )
        self.dset_parameters.attrs["dim_names"] = ["sample", "parameter"]
        self.dset_parameters.attrs["coords_mode__sample"] = "trivial"
        self.dset_parameters.attrs["coords_mode__parameter"] = "values"
        self.dset_parameters.attrs["coords__parameter"] = to_learn

        # Calculate the initial values of the loss and its gradient; both
        # chain slots (current and proposed, presumably — confirm against the
        # parent class) are seeded with the same state via in-place .data
        # copies so no autograd history is attached.
        self.loss[0] = self.loss_function(self.x[0])
        self.loss[1].data = self.loss[0].data

        self.grad[0].data = torch.autograd.grad(
            self.loss[0], [self.x[0]], create_graph=False
        )[0].data
        self.grad[1].data = self.grad[0].data

    def loss_function(self, input):
        r"""Calculates the loss (negative log-likelihood function) of a vector of parameters via simulation.

        A window of ``batch_size`` time steps is chosen at random from the
        observed data; the Covid compartment ODE is integrated (explicit
        Euler, clipped to [0, 1]) over that window starting from the observed
        state one step before it, and the squared deviation from the
        observations is returned.

        :param parameters: the vector of parameters
        :return: likelihood || \hat{T}(\hat{Lambda}) - T ||_2
        """
        # Random window start; torch.randint's upper bound is exclusive, and
        # start >= 1 so that true_data[start - 1] is always valid.
        if self.true_data.shape[0] - self.batch_size == 1:
            start = 1
        else:
            start = torch.randint(
                1, self.true_data.shape[0] - self.batch_size, (1,)
            ).item()

        # Seed the simulated trajectory with the observed state preceding the
        # window. NOTE(review): densities[k] (time start-1+k) is later
        # compared against true_data[start+k] — i.e. a one-step-ahead
        # comparison; confirm this shift is intended.
        densities = [self.true_data[start - 1]]

        # Assemble the full parameter dict: sampled entries come from the
        # input vector, all others from the fixed true parameters.
        parameters = {
            p: input[self.to_learn[p]]
            if p in self.to_learn.keys()
            else self.true_parameters[p]
            for p in self.all_parameters
        }

        for t in range(start, start + self.batch_size - 1):
            # Select the active variant of each time-dependent parameter for
            # this time step. NOTE(review): writing r[1] back mutates the
            # caller-supplied config in place (an open-ended range is replaced
            # by len(true_data) + 1 on first use).
            for key, ranges in self.time_dependent_parameters.items():
                for idx, r in enumerate(ranges):
                    if not r[1]:
                        r[1] = len(self.true_data) + 1
                    if r[0] <= t < r[1]:
                        parameters[key] = parameters[key + f"_{idx}"]
                        break

            # Effective quarantine rate, driven by the contact-tracing
            # compartment (last entry of the state vector)
            k_Q = self.k_q * parameters["k_CT"] * densities[-1][-1]

            # Solve the ODE: one explicit Euler step, with densities clipped
            # to [0, 1]. The stacked entries are the time derivatives of the
            # compartments in positional order — presumably (by the rate
            # names) 0: susceptible, 1: exposed, 2: infected, 3: recovered,
            # 4: symptomatic, 5: hospitalised, 6: critical, 7: deceased,
            # 8-10: quarantined counterparts, 11: contact tracing — TODO
            # confirm against the model definition.
            densities.append(
                torch.clip(
                    densities[-1]
                    + torch.stack(
                        [
                            (-parameters["k_E"] * densities[-1][2] - k_Q)
                            * densities[-1][0]
                            + parameters["k_S"] * densities[-1][8],
                            parameters["k_E"] * densities[-1][0] * densities[-1][2]
                            - (parameters["k_I"] + k_Q) * densities[-1][1],
                            parameters["k_I"] * densities[-1][1]
                            - (parameters["k_R"] + parameters["k_SY"] + k_Q)
                            * densities[-1][2],
                            parameters["k_R"]
                            * (
                                densities[-1][2]
                                + densities[-1][4]
                                + densities[-1][5]
                                + densities[-1][6]
                                + densities[-1][10]
                            ),
                            parameters["k_SY"] * (densities[-1][2] + densities[-1][10])
                            - (parameters["k_R"] + parameters["k_H"])
                            * densities[-1][4],
                            parameters["k_H"] * densities[-1][4]
                            - (parameters["k_R"] + parameters["k_C"])
                            * densities[-1][5],
                            parameters["k_C"] * densities[-1][5]
                            - (parameters["k_R"] + parameters["k_D"])
                            * densities[-1][6],
                            parameters["k_D"] * densities[-1][6],
                            -parameters["k_S"] * densities[-1][8]
                            + k_Q * densities[-1][0],
                            -parameters["k_I"] * densities[-1][9]
                            + k_Q * densities[-1][1],
                            parameters["k_I"] * densities[-1][9]
                            + k_Q * densities[-1][2]
                            - (parameters["k_SY"] + parameters["k_R"])
                            * densities[-1][10],
                            parameters["k_SY"] * densities[-1][2]
                            - self.k_q
                            * torch.sum(densities[-1][0:3])
                            * densities[-1][-1],
                        ]
                    )
                    * self.dt,
                    0,
                    1,
                )
            )

        densities = torch.stack(densities)

        if self.Berlin_data_loss:
            # Scale loss to prevent numerical underflow of the preconditioner (which is inversely proportional to the
            # gradient)
            loss = 5e4 * torch.dot(
                self.alpha,
                torch.concat(
                    [
                        # Per-compartment squared errors for compartments 0-6
                        torch.sum(
                            torch.pow(
                                densities[:, 0:7]
                                - self.true_data[start : start + self.batch_size, 0:7],
                                2,
                            ),
                            dim=0,
                        ),
                        # Quarantine compartments 8-10 are summed before
                        # comparison against the combined observed column 8
                        torch.sum(
                            torch.pow(
                                torch.sum(densities[:, 8:11], dim=1)
                                - self.true_data[start : start + self.batch_size, 8],
                                2,
                            ),
                            dim=0,
                            keepdim=True,
                        ),
                    ],
                    0,
                ).squeeze(),
            )

        else:
            # Plain sum-of-squares loss over the whole window
            loss = torch.sum(
                torch.pow(
                    densities - self.true_data[start : start + self.batch_size], 2
                )
            )

        return loss

    def write_parameters(self):
        """Appends the current parameter sample to the h5 dataset.

        Writes only after ``write_start`` iterations and then every
        ``write_every``-th iteration; the dataset is grown by one row per
        write.
        """
        if self.time > self.write_start and self.time % self.write_every == 0:
            self.dset_parameters.resize(self.dset_parameters.shape[0] + 1, axis=0)
            self.dset_parameters[-1, :] = torch.flatten(self.x[0].detach()).numpy()
def perform_sampling(h5file, training_data, model_cfg: dict) -> None:
    """Runs the Covid Langevin sampler.

    :param h5file: hdf5 file to write the data to
    :param training_data: training data used to calculate the likelihood;
        first axis is time
    :param model_cfg: configuration file; note that ``MCMC.n_samples`` is
        popped from it (the dict is mutated in place)
    """

    # Number of samples; popped so it is not forwarded with the MCMC kwargs
    n_samples = model_cfg["MCMC"].pop("n_samples")

    # Restrict the training data to the first `training_data_size` time
    # frames, if configured. (Fix: the previous code indexed with the raw
    # config value, so an integer N selected the *single* frame N — dropping
    # the time axis — instead of the first N frames.)
    training_data_size = model_cfg["Data"].get("training_data_size", None)

    # Initialise the sampler
    sampler = Covid_Langevin_sampler(
        h5File=h5file,
        true_data=training_data[:training_data_size, :, :],
        to_learn=model_cfg["Training"]["to_learn"],
        time_dependent_parameters=model_cfg["Data"].get(
            "time_dependent_parameters", None
        ),
        true_parameters=model_cfg["Training"].get("true_parameters", {}),
        dt=model_cfg["Data"]["synthetic_data"]["dt"],
        k_q=model_cfg["Data"]["synthetic_data"]["k_q"],
        **model_cfg["MCMC"],
    )

    # Track the sampling time
    start_time = time.time()

    # Collect n_samples
    for i in range(n_samples):
        sampler.sample()
        sampler.write_loss()
        sampler.write_parameters()
        # i + 1: report samples actually collected so far (was off by one,
        # logging "Collected 0 of n" on the first iteration)
        log.info(f"Collected {i + 1} of {n_samples}; current loss: {sampler.loss[1]}")

    # Write out the total sampling time
    sampler.write_time(time.time() - start_time)

    # NOTE(review): plain `logging` loggers have no `success` method; this
    # presumably relies on a custom SUCCESS level registered elsewhere (e.g.
    # by utopya) — confirm before running this module standalone.
    log.success(" MCMC sampling complete.")