Coverage for models / SIR / ensemble_training / NN.py: 15%
79 statements
« prev ^ index » next coverage.py v7.13.1, created at 2026-01-21 16:26 +0000
1import sys
2from os.path import dirname as up
3import h5py as h5
4import numpy as np
5import torch
6from dantro._import_tools import import_module_from_path
8sys.path.append(up(up(up(__file__))))
10base = import_module_from_path(mod_path=up(up(up(__file__))), mod_str="include")
class SIR_NN:
    """Trains a neural network to calibrate the parameters of a stochastic
    SIR model: the network predicts (a subset of) the parameters
    ``p_infect``, ``t_infectious``, ``sigma``; the SIR ODE is integrated
    with those predictions and the loss against the training time series
    is backpropagated through the integration."""

    def __init__(
        self,
        *,
        rng: np.random.Generator,
        h5group: h5.Group,
        neural_net: base.BaseNN,
        loss_function: dict,
        to_learn: list,
        true_parameters: dict = None,
        write_every: int = 1,
        write_start: int = 1,
        training_data: torch.Tensor,
        batch_size: int,
        scaling_factors: dict = None,
        **__,
    ):
        """Initialize the model instance with a previously constructed RNG and
        HDF5 group to write the output data to.

        :param rng (np.random.Generator): The shared RNG
        :param h5group (h5.Group): The output file group to write data to
        :param neural_net: The neural network
        :param loss_function (dict): the loss function to use
        :param to_learn: the list of parameter names to learn
        :param true_parameters: the dictionary of true parameters
        :param training_data: the training data to use
        :param write_every: write every iteration
        :param write_start: iteration at which to start writing
        :param batch_size: epoch batch size: instead of calculating the entire time series,
            only a subsample of length batch_size can be used. The likelihood is then
            scaled up accordingly.
        :param scaling_factors: factors by which the parameters are to be scaled
        """
        # None-sentinels instead of mutable default arguments ({} shared
        # across all instances).
        if true_parameters is None:
            true_parameters = {}
        if scaling_factors is None:
            scaling_factors = {}

        self._h5group = h5group
        self._rng = rng

        self.neural_net = neural_net
        self.neural_net.optimizer.zero_grad()
        # Loss function is looked up by (case-insensitive) name in the
        # project-wide registry and instantiated with the given args/kwargs.
        self.loss_function = base.LOSS_FUNCTIONS[loss_function.get("name").lower()](
            loss_function.get("args", None), **loss_function.get("kwargs", {})
        )

        self.current_loss = torch.tensor(0.0)

        # Map each learnable parameter name to its index in the network output
        self.to_learn = {key: idx for idx, key in enumerate(to_learn)}
        self.true_parameters = {
            key: torch.tensor(val, dtype=torch.float)
            for key, val in true_parameters.items()
        }
        self.current_predictions = torch.zeros(len(to_learn))

        # The training data and batch ids. (Previously this was computed
        # twice, verbatim; it is now computed exactly once.)
        self.training_data = training_data
        self.batches = self._generate_batches(training_data.shape[0], batch_size)

        self.scaling_factors = scaling_factors

        # --- Set up chunked datasets to store the state data in --------------
        # Write the loss after every batch
        self._dset_loss = self._h5group.create_dataset(
            "loss",
            (0,),
            maxshape=(None,),
            chunks=True,
            compression=3,
        )
        self._dset_loss.attrs["dim_names"] = ["batch"]
        self._dset_loss.attrs["coords_mode__batch"] = "start_and_step"
        self._dset_loss.attrs["coords__batch"] = [write_start, write_every]

        # Write the computation time of every epoch
        self.dset_time = self._h5group.create_dataset(
            "computation_time",
            (0,),
            maxshape=(None,),
            chunks=True,
            compression=3,
        )
        self.dset_time.attrs["dim_names"] = ["epoch"]
        self.dset_time.attrs["coords_mode__epoch"] = "trivial"

        # Write the parameter predictions after every batch
        self.dset_parameters = self._h5group.create_dataset(
            "parameters",
            (0, len(self.to_learn)),
            maxshape=(None, len(self.to_learn)),
            chunks=True,
            compression=3,
        )
        self.dset_parameters.attrs["dim_names"] = ["batch", "parameter"]
        self.dset_parameters.attrs["coords_mode__batch"] = "start_and_step"
        self.dset_parameters.attrs["coords__batch"] = [write_start, write_every]
        self.dset_parameters.attrs["coords_mode__parameter"] = "values"
        self.dset_parameters.attrs["coords__parameter"] = to_learn

        # Batches processed so far, and the write cadence
        self._time = 0
        self._write_every = write_every
        self._write_start = write_start

    @staticmethod
    def _generate_batches(n_steps: int, batch_size: int) -> np.ndarray:
        """Return the batch start indices ``[0, B, 2B, ...]``, extended so the
        sequence always ends on the final index of the time series
        (``n_steps - 1``). The short-circuit ``or`` reproduces the original
        two-branch logic, including the single-batch edge case.
        """
        batches = np.arange(0, n_steps, batch_size)
        if len(batches) == 1 or batches[-1] != n_steps - 1:
            batches = np.append(batches, n_steps - 1)
        return batches

    def _extract_parameter(
        self, name: str, predicted_parameters: torch.Tensor
    ) -> torch.Tensor:
        """Return the parameter value to use in the ODE: the (scaled) network
        prediction if ``name`` is being learned, else its fixed true value."""
        if name in self.to_learn:
            return (
                self.scaling_factors.get(name, 1.0)
                * predicted_parameters[self.to_learn[name]]
            )
        return self.true_parameters[name]

    def epoch(self):
        """
        An epoch is a pass over the entire dataset. The dataset is processed in batches, where B < L is the batch
        number. After each batch, the parameters of the neural network are updated. For example, if L = 100 and
        B = 50, two passes are made over the dataset -- one over the first 50 steps, and one
        over the second 50. The entire time series is processed, even if L is not divisible into equal segments of
        length B. For instance, if B is 30, the time series is processed in 3 steps of 30 and one of 10.
        """
        # Process the training data in batches
        for batch_no, batch_idx in enumerate(self.batches[:-1]):
            predicted_parameters = self.neural_net(
                torch.flatten(self.training_data[batch_idx])
            )

            # Get the parameters: infection rate, recovery time, noise variance
            p = self._extract_parameter("p_infect", predicted_parameters)
            t = self._extract_parameter("t_infectious", predicted_parameters)
            sigma = self._extract_parameter("sigma", predicted_parameters)

            current_densities = self.training_data[batch_idx].clone()
            current_densities.requires_grad_(True)

            loss = torch.tensor(0.0, requires_grad=True)

            for ele in range(batch_idx + 1, self.batches[batch_no + 1] + 1):
                # Recovery rate: sharp sigmoid switches recovery on once the
                # elapsed time exceeds the infectious period t
                tau = 1 / t * torch.sigmoid(1000 * (ele / t - 1))

                # Random noise. NOTE(review): drawn from torch's global RNG,
                # not the shared self._rng passed at construction -- confirm
                # this is intended for reproducibility.
                w = torch.normal(torch.tensor(0.0), torch.tensor(1.0))

                # Euler step of the (noisy) SIR ODE, clipped so the
                # S/I/R densities stay in [0, 1]
                current_densities = torch.clip(
                    current_densities
                    + torch.stack(
                        [
                            (-p * current_densities[0] - sigma * w)
                            * current_densities[1],
                            (p * current_densities[0] + sigma * w - tau)
                            * current_densities[1],
                            tau * current_densities[1],
                        ]
                    ),
                    0.0,
                    1.0,
                )

                # Accumulate the loss, scaled up by the ratio of the full
                # series length to this batch's length (likelihood scaling)
                loss = loss + self.loss_function(
                    current_densities, self.training_data[ele]
                ) * (
                    self.training_data.shape[0]
                    / (self.batches[batch_no + 1] - batch_idx)
                )

            # Update the network parameters after each batch
            loss.backward()
            self.neural_net.optimizer.step()
            self.neural_net.optimizer.zero_grad()
            self.current_loss = loss.clone().detach().cpu().numpy().item()
            self.current_predictions = predicted_parameters.clone().detach().cpu()

            # Scale the stored predictions for output
            for param in self.to_learn:
                self.current_predictions[self.to_learn[param]] *= (
                    self.scaling_factors.get(param, 1.0)
                )
            self._time += 1
            self.write_data()

    def write_data(self):
        """Write the current state (loss and parameter predictions) into the state dataset.

        In the case of HDF5 data writing that is used here, this requires to
        extend the dataset size prior to writing; this way, the newly written
        data is always in the last row of the dataset.
        """
        if self._time >= self._write_start and (self._time % self._write_every == 0):
            self._dset_loss.resize(self._dset_loss.shape[0] + 1, axis=0)
            self._dset_loss[-1] = self.current_loss
            self.dset_parameters.resize(self.dset_parameters.shape[0] + 1, axis=0)
            self.dset_parameters[-1, :] = [
                self.current_predictions[self.to_learn[p]] for p in self.to_learn
            ]