Coverage for models/Covid/ensemble_training/NN.py: 17% (96 statements)
import sys
from os.path import dirname as up

import h5py as h5
import numpy as np
import torch
from dantro import logging
from dantro._import_tools import import_module_from_path

sys.path.append(up(up(__file__)))
sys.path.append(up(up(up(__file__))))

Covid = import_module_from_path(mod_path=up(up(__file__)), mod_str="Covid")
base = import_module_from_path(mod_path=up(up(up(__file__))), mod_str="include")

log = logging.getLogger(__name__)

# ----------------------------------------------------------------------------------------------------------------------
# Model implementation
# ----------------------------------------------------------------------------------------------------------------------

class Covid_NN:
    def __init__(
        self,
        *,
        rng: np.random.Generator,
        h5group: h5.Group,
        neural_net: base.BaseNN,
        loss_function: dict,
        to_learn: list,
        time_dependent_parameters: dict = None,
        true_parameters: dict = {},
        dt: float,
        k_q: float = 10.25,
        Berlin_data_loss: bool = False,
        write_every: int = 1,
        write_start: int = 1,
        training_data: torch.Tensor,
        batch_size: int,
        scaling_factors: dict = {},
        **__,
    ):
44 """Initialize the model instance with a previously constructed RNG and
45 HDF5 group to write the output data to.
47 Args:
48 rng (np.random.Generator): The shared RNG
49 h5group (h5.Group): The output file group to write data to
50 neural_net: The neural network
51 loss_function (dict): the loss function to use
52 to_learn: the list of parameter names to learn
53 time_dependent_parameters: dictionary of time-dependent parameters and their granularity
54 true_parameters: the dictionary of true parameters
55 dt: time differential
56 k_q: contact tracing rate
57 Berlin_data_loss: whether to use the loss structure unique to the Berlin data
58 write_every: write every iteration
59 write_start: iteration at which to start writing
60 training_data: the training data to use
61 batch_size: epoch batch size: instead of calculating the entire time series,
62 only a subsample of length batch_size can be used. The likelihood is then
63 scaled up accordingly.
64 scaling_factors: dictionary of scaling factors for the different parameters. Parameter estimates are
65 multiplied by these to ensure all parameters are roughly of the same order of magnitude
66 """
        self._h5group = h5group
        self._rng = rng

        self.neural_net = neural_net
        self.neural_net.optimizer.zero_grad()
        self.loss_function = base.LOSS_FUNCTIONS[loss_function.get("name").lower()](
            loss_function.get("args", None), **loss_function.get("kwargs", {})
        )

        self.dt = torch.tensor(dt, dtype=torch.float)
        self.k_q = torch.tensor(k_q, dtype=torch.float)
        self.Berlin_data_loss = Berlin_data_loss

        self.current_loss = torch.tensor(0.0)

        self.to_learn = {key: idx for idx, key in enumerate(to_learn)}
        self.time_dependent_parameters = (
            time_dependent_parameters if time_dependent_parameters else {}
        )
        self.true_parameters = {
            key: torch.tensor(val, dtype=torch.float)
            for key, val in true_parameters.items()
        }
        self.all_parameters = set(self.to_learn.keys())
        self.all_parameters.update(self.true_parameters.keys())
        self.current_predictions = torch.zeros(len(self.to_learn), dtype=torch.float)

        # Training data
        self.training_data = training_data

        # Generate the batch ids; the final index is always included so that
        # the last batch runs to the end of the time series
        batches = np.arange(0, self.training_data.shape[0], batch_size)
        if len(batches) == 1 or batches[-1] != training_data.shape[0] - 1:
            batches = np.append(batches, training_data.shape[0] - 1)

        self.batches = batches
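
        # Example: for a time series of length 250 and batch_size = 100 this
        # yields batches = [0, 100, 200, 249]; consecutive index pairs delimit
        # the sub-series integrated per gradient step in ``epoch``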

        # --- Set up chunked datasets to store the state data in -------------------------------------------------------
        self._dset_loss = self._h5group.create_dataset(
            "loss",
            (0,),
            maxshape=(None,),
            chunks=True,
            compression=3,
        )
        self._dset_loss.attrs["dim_names"] = ["batch"]
        self._dset_loss.attrs["coords_mode__batch"] = "start_and_step"
        self._dset_loss.attrs["coords__batch"] = [write_start, write_every]
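        # (The "start_and_step" coordinate mode lets dantro reconstruct the
        # batch coordinates from the [write_start, write_every] pair on loading.)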

        self.dset_time = self._h5group.create_dataset(
            "computation_time",
            (0,),
            maxshape=(None,),
            chunks=True,
            compression=3,
        )
        self.dset_time.attrs["dim_names"] = ["epoch"]
        self.dset_time.attrs["coords_mode__epoch"] = "trivial"

        # Create a dataset for the parameter estimates
        self.dset_parameters = self._h5group.create_dataset(
            "parameters",
            (0, len(self.to_learn.keys())),
            maxshape=(None, len(self.to_learn.keys())),
            chunks=True,
            compression=3,
        )
        self.dset_parameters.attrs["dim_names"] = ["batch", "parameter"]
        self.dset_parameters.attrs["coords_mode__batch"] = "start_and_step"
        self.dset_parameters.attrs["coords__batch"] = [write_start, write_every]
        self.dset_parameters.attrs["coords_mode__parameter"] = "values"
        self.dset_parameters.attrs["coords__parameter"] = to_learn

        # --------------------------------------------------------------------------------------------------------------
        # Batches processed
        self._time = 0
        self._write_every = write_every
        self._write_start = write_start

        # Calculate the coefficients of each term in the loss function:
        # \alpha_i^{-1} = \int T_i(t) dt
        alpha = torch.sum(training_data, dim=0) * self.dt
        alpha = torch.where(alpha > 0, alpha, torch.tensor(1.0))
        self.alpha = (
            torch.cat([alpha[0:7], torch.sum(alpha[8:11], 0, keepdim=True)], 0)
        ) ** (-1)
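        # This yields 8 weights, matching the Berlin-reduced data layout used
        # in ``epoch``: compartments 0-6 plus the pooled quarantine
        # compartments 8-10, each weighted by the inverse integral of its
        # training time series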

        # Reduced data model
        # for idx in [0, 1, 2, 3, 7]:  # S, E, I, R, Q are dropped
        #     self.alpha[idx] = 0

        # Get all the jump points
        self.jump_points = {}
        if self.time_dependent_parameters:
            self.jump_points = set(
                np.hstack(
                    [
                        np.array(interval).flatten()
                        for _, interval in self.time_dependent_parameters.items()
                    ]
                )
            )
            if None in self.jump_points:
                self.jump_points.remove(None)

        # Get the scaling factors; parameters without an entry default to 1
        self.scaling_factors = torch.tensor(
            [scaling_factors.get(key, 1.0) for key in self.to_learn.keys()],
            dtype=torch.float,
        )
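        # Example (hypothetical values): with to_learn = ["k_E", "k_I"] and
        # scaling_factors = {"k_E": 10.0}, this gives tensor([10., 1.])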

    def epoch(self):
        """Trains the model for a single epoch"""

        # Process the training data in batches
        for batch_no, batch_idx in enumerate(self.batches[:-1]):
            # Make a prediction
            predicted_parameters = self.neural_net(
                torch.flatten(self.training_data[batch_idx])
            )

            # Combine the predicted and true parameters into a dictionary
            parameters = {
                p: predicted_parameters[self.to_learn[p]]
                * self.scaling_factors[self.to_learn[p]]
                if p in self.to_learn.keys()
                else self.true_parameters[p]
                for p in self.all_parameters
            }

            # Get the initial values
            current_densities = self.training_data[batch_idx].clone()
            current_densities.requires_grad_(True)
            densities = [current_densities]

            # Integrate the ODE over the current batch
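            # Compartment indices, as read off from the equations below:
            # 0: S, 1: E, 2: I, 3: R, 4: SY (symptomatic), 5: H (hospitalised),
            # 6: C (critical), 7: D (deceased), 8-10: quarantined S/E/I,
            # 11 (last): contact-traced agents CT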
            for ele in range(batch_idx + 1, self.batches[batch_no + 1] + 1):
                # Adjust for time-dependency
                for key, ranges in self.time_dependent_parameters.items():
                    for idx, r in enumerate(ranges):
                        if not r[1]:
                            r[1] = len(self.training_data) + 1
                        if r[0] <= ele < r[1]:
                            parameters[key] = parameters[key + f"_{idx}"]
                            break
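                # Example (hypothetical config): time_dependent_parameters =
                # {"k_SY": [[0, 50], [50, None]]} selects "k_SY_0" for steps
                # 0-49 and "k_SY_1" from step 50 on; a None upper bound is
                # replaced by the full series length above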

                # Calculate the k_Q parameter from the current CT figures and the k_CT estimate
                k_Q = self.k_q * parameters["k_CT"] * densities[-1][-1]

                # Solve the ODE
                densities.append(
                    torch.clip(
                        densities[-1]
                        + torch.stack(
                            [
                                (-parameters["k_E"] * densities[-1][2] - k_Q)
                                * densities[-1][0]
                                + parameters["k_S"] * densities[-1][8],
                                parameters["k_E"] * densities[-1][0] * densities[-1][2]
                                - (parameters["k_I"] + k_Q) * densities[-1][1],
                                parameters["k_I"] * densities[-1][1]
                                - (parameters["k_R"] + parameters["k_SY"] + k_Q)
                                * densities[-1][2],
                                parameters["k_R"]
                                * (
                                    densities[-1][2]
                                    + densities[-1][4]
                                    + densities[-1][5]
                                    + densities[-1][6]
                                    + densities[-1][10]
                                ),
                                parameters["k_SY"]
                                * (densities[-1][2] + densities[-1][10])
                                - (parameters["k_R"] + parameters["k_H"])
                                * densities[-1][4],
                                parameters["k_H"] * densities[-1][4]
                                - (parameters["k_R"] + parameters["k_C"])
                                * densities[-1][5],
                                parameters["k_C"] * densities[-1][5]
                                - (parameters["k_R"] + parameters["k_D"])
                                * densities[-1][6],
                                parameters["k_D"] * densities[-1][6],
                                -parameters["k_S"] * densities[-1][8]
                                + k_Q * densities[-1][0],
                                -parameters["k_I"] * densities[-1][9]
                                + k_Q * densities[-1][1],
                                parameters["k_I"] * densities[-1][9]
                                + k_Q * densities[-1][2]
                                - (parameters["k_SY"] + parameters["k_R"])
                                * densities[-1][10],
                                parameters["k_SY"] * densities[-1][2]
                                - self.k_q
                                * torch.sum(densities[-1][0:3])
                                * densities[-1][-1],
                            ]
                        )
                        * self.dt,
                        0,
                        1,
                    )
                )
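            # Each explicit Euler step above is clipped to [0, 1] so that the
            # compartment densities remain physical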

            # Discard the initial condition
            densities = torch.stack(densities[1:])

            if self.Berlin_data_loss:
                # For the Berlin dataset, combine the quarantine compartments and drop the
                # deceased compartment, which is not present in the ABM data
                densities = torch.cat(
                    [
                        densities[:, :7],
                        torch.sum(densities[:, 8:11], dim=1, keepdim=True),
                    ],
                    dim=1,
                )
                loss = (
                    self.alpha
                    * self.loss_function(
                        densities,
                        torch.cat(
                            [
                                self.training_data[
                                    batch_idx + 1 : self.batches[batch_no + 1] + 1, :7
                                ],
                                self.training_data[
                                    batch_idx + 1 : self.batches[batch_no + 1] + 1, [8]
                                ],
                            ],
                            1,
                        ),
                    ).sum(dim=0)
                ).sum()

            # Regular loss function
            else:
                loss = self.loss_function(
                    densities,
                    self.training_data[batch_idx + 1 : self.batches[batch_no + 1] + 1],
                ) / (self.batches[batch_no + 1] - batch_idx)
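                # Note: the regular loss is divided by the number of
                # integration steps, i.e. averaged over the batch length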

            # Perform a gradient descent step
            loss.backward()
            self.neural_net.optimizer.step()
            self.neural_net.optimizer.zero_grad()
            self.current_loss = loss.clone().detach().cpu().numpy().item()
            self.current_predictions = torch.tensor(
                [
                    predicted_parameters.clone().detach().cpu()[self.to_learn[p]]
                    * self.scaling_factors[self.to_learn[p]]
                    for p in self.to_learn.keys()
                ]
            )
            self._time += 1
            self.write_data()

    def write_data(self):
        """Write the current state (loss and parameter predictions) into the state dataset.

        With the HDF5 data writing used here, this requires extending the
        dataset size prior to writing; that way, the newly written data is
        always in the last row of the dataset.
        """
        if self._time >= self._write_start and (self._time % self._write_every == 0):
            self._dset_loss.resize(self._dset_loss.shape[0] + 1, axis=0)
            self._dset_loss[-1] = self.current_loss
            self.dset_parameters.resize(self.dset_parameters.shape[0] + 1, axis=0)
            self.dset_parameters[-1, :] = [
                self.current_predictions[self.to_learn[p]] for p in self.to_learn.keys()
            ]
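

# A minimal usage sketch (hypothetical names and values; in practice the
# surrounding dantro-based harness constructs these objects from the run
# configuration):
#
#     with h5.File("output.h5", "w") as f:
#         model = Covid_NN(
#             rng=np.random.default_rng(42),
#             h5group=f.create_group("NN"),
#             neural_net=net,                     # a base.BaseNN instance
#             loss_function={"name": "MSELoss"},  # key must exist in base.LOSS_FUNCTIONS
#             to_learn=["k_E", "k_I", "k_R"],
#             dt=0.1,
#             training_data=data,                 # tensor of shape (T, n_compartments)
#             batch_size=100,
#         )
#         for _ in range(num_epochs):
#             model.epoch()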