Coverage for models/SIR/NN.py: 21%
85 statements
coverage.py v7.6.1, created at 2024-12-05 17:26 +0000
import sys
from os.path import dirname as up

import coloredlogs
import h5py as h5
import numpy as np
import torch
from dantro import logging
from dantro._import_tools import import_module_from_path

sys.path.append(up(up(__file__)))
sys.path.append(up(up(up(__file__))))

SIR = import_module_from_path(mod_path=up(up(__file__)), mod_str="SIR")
base = import_module_from_path(mod_path=up(up(up(__file__))), mod_str="include")

log = logging.getLogger(__name__)
coloredlogs.install(fmt="%(levelname)s %(message)s", level="INFO", logger=log)


class SIR_NN:
    def __init__(
        self,
        *,
        rng: np.random.Generator,
        h5group: h5.Group,
        neural_net: base.NeuralNet,
        loss_function: dict,
        to_learn: list,
        true_parameters: dict = {},
        write_every: int = 1,
        write_start: int = 1,
        training_data: torch.Tensor,
        batch_size: int,
        scaling_factors: dict = {},
        **__,
    ):
        """Initialize the model instance with a previously constructed RNG and
        HDF5 group to write the output data to.

        :param rng (np.random.Generator): The shared RNG
        :param h5group (h5.Group): The output file group to write data to
        :param neural_net: The neural network
        :param loss_function (dict): the loss function to use
        :param to_learn: the list of parameter names to learn
        :param true_parameters: the dictionary of true parameters
        :param training_data: the training data to use
        :param write_every: write data every write_every iterations
        :param write_start: iteration at which to start writing
        :param batch_size: epoch batch size: instead of calculating the entire time
            series, only a subsample of length batch_size can be used. The
            likelihood is then scaled up accordingly.
        :param scaling_factors: factors by which the parameters are to be scaled
        """
        self._h5group = h5group
        self._rng = rng

        self.neural_net = neural_net
        self.neural_net.optimizer.zero_grad()
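        # Look up and construct the loss function from the config dict. As an
        # illustration (assumed config layout, not taken from this file), an entry
        # might look like {"name": "MSELoss", "kwargs": {"reduction": "sum"}}.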
        self.loss_function = base.LOSS_FUNCTIONS[loss_function.get("name").lower()](
            loss_function.get("args", None), **loss_function.get("kwargs", {})
        )

        self.current_loss = torch.tensor(0.0)

        self.to_learn = {key: idx for idx, key in enumerate(to_learn)}
        self.true_parameters = {
            key: torch.tensor(val, dtype=torch.float)
            for key, val in true_parameters.items()
        }
        self.current_predictions = torch.zeros(len(to_learn))

        self.training_data = training_data

        # Generate the batch ids
        batches = np.arange(0, self.training_data.shape[0], batch_size)
        if len(batches) == 1:
            batches = np.append(batches, training_data.shape[0] - 1)
        elif batches[-1] != training_data.shape[0] - 1:
            batches = np.append(batches, training_data.shape[0] - 1)

        self.batches = batches
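        # Illustration (hypothetical numbers): with 100 time steps and batch_size=30,
        # batches becomes [0, 30, 60, 90, 99], i.e. the start index of each batch
        # followed by the last index of the time series.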

        self.scaling_factors = scaling_factors

        # --- Set up chunked dataset to store the state data in ------------------
        # Write the loss after every batch
        self._dset_loss = self._h5group.create_dataset(
            "loss",
            (0,),
            maxshape=(None,),
            chunks=True,
            compression=3,
        )
        self._dset_loss.attrs["dim_names"] = ["batch"]
        self._dset_loss.attrs["coords_mode__batch"] = "start_and_step"
        self._dset_loss.attrs["coords__batch"] = [write_start, write_every]
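        # The coords_mode__* / coords__* attributes are metadata for the (dantro)
        # data loading side; "start_and_step" presumably lets it reconstruct the
        # batch coordinates from write_start and write_every.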

        # Write the computation time of every epoch
        self.dset_time = self._h5group.create_dataset(
            "computation_time",
            (0,),
            maxshape=(None,),
            chunks=True,
            compression=3,
        )
        self.dset_time.attrs["dim_names"] = ["epoch"]
        self.dset_time.attrs["coords_mode__epoch"] = "trivial"

        # Write the parameter predictions after every batch
        self.dset_parameters = self._h5group.create_dataset(
            "parameters",
            (0, len(self.to_learn.keys())),
            maxshape=(None, len(self.to_learn.keys())),
            chunks=True,
            compression=3,
        )
        self.dset_parameters.attrs["dim_names"] = ["batch", "parameter"]
        self.dset_parameters.attrs["coords_mode__batch"] = "start_and_step"
        self.dset_parameters.attrs["coords__batch"] = [write_start, write_every]
        self.dset_parameters.attrs["coords_mode__parameter"] = "values"
        self.dset_parameters.attrs["coords__parameter"] = to_learn

        # Batches processed
        self._time = 0
        self._write_every = write_every
        self._write_start = write_start

    def epoch(self):
        """
        An epoch is one pass over the entire dataset. The dataset is processed in
        batches of size B, where B <= L and L is the length of the time series.
        After each batch, the parameters of the neural network are updated. For
        example, if L = 100 and B = 50, two passes are made over the dataset: one
        over the first 50 steps and one over the second 50. The entire time series
        is processed even if L is not divisible into equal segments of length B.
        For instance, if B is 30, the time series is processed in three steps of
        30 and one of 10.
        """

        # Process the training data in batches
        for batch_no, batch_idx in enumerate(self.batches[:-1]):
            predicted_parameters = self.neural_net(
                torch.flatten(self.training_data[batch_idx])
            )
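
            # predicted_parameters holds the network's raw output: one value per
            # entry of to_learn, indexed via self.to_learn.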

            # Get the parameters: infection rate, recovery time, noise variance
            p = (
                self.scaling_factors.get("p_infect", 1.0)
                * predicted_parameters[self.to_learn["p_infect"]]
                if "p_infect" in self.to_learn.keys()
                else self.true_parameters["p_infect"]
            )
            t = (
                self.scaling_factors.get("t_infectious", 1.0)
                * predicted_parameters[self.to_learn["t_infectious"]]
                if "t_infectious" in self.to_learn.keys()
                else self.true_parameters["t_infectious"]
            )
            sigma = (
                self.scaling_factors.get("sigma", 1.0)
                * predicted_parameters[self.to_learn["sigma"]]
                if "sigma" in self.to_learn.keys()
                else self.true_parameters["sigma"]
            )

            current_densities = self.training_data[batch_idx].clone()
            current_densities.requires_grad_(True)

            loss = torch.tensor(0.0, requires_grad=True)

            for ele in range(batch_idx + 1, self.batches[batch_no + 1] + 1):
                # Recovery rate
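                # (The steep sigmoid acts as a smooth, differentiable step function:
                # tau is ~0 while ele < t and ~1/t once ele exceeds t.)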
                tau = 1 / t * torch.sigmoid(1000 * (ele / t - 1))

                # Random noise
                w = torch.normal(torch.tensor(0.0), torch.tensor(1.0))

                # Solve the ODE
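                # One explicit Euler step (dt = 1) of the noisy SIR dynamics, with
                # S, I, R = current_densities[0], [1], [2]:
                #   dS = -(p * S + sigma * w) * I
                #   dI = +(p * S + sigma * w - tau) * I
                #   dR = tau * I
                # The densities are then clipped to the interval [0, 1].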
                current_densities = torch.clip(
                    current_densities
                    + torch.stack(
                        [
                            (-p * current_densities[0] - sigma * w)
                            * current_densities[1],
                            (p * current_densities[0] + sigma * w - tau)
                            * current_densities[1],
                            tau * current_densities[1],
                        ]
                    ),
                    0.0,
                    1.0,
                )

                # Calculate loss
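                # The batch loss is rescaled by (total series length / batch length)
                # so that shorter batches contribute on the same scale as the full
                # series (cf. the batch_size docstring above).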
                loss = loss + self.loss_function(
                    current_densities, self.training_data[ele]
                ) * (
                    self.training_data.shape[0]
                    / (self.batches[batch_no + 1] - batch_idx)
                )

            loss.backward()
            self.neural_net.optimizer.step()
            self.neural_net.optimizer.zero_grad()
            self.current_loss = loss.clone().detach().cpu().numpy().item()
            self.current_predictions = predicted_parameters.clone().detach().cpu()

            # Scale the parameters
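            # The stored predictions are the raw network outputs multiplied by their
            # scaling factors, matching the scaling applied to p, t and sigma above.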
            for param in self.to_learn.keys():
                self.current_predictions[self.to_learn[param]] *= (
                    self.scaling_factors.get(param, 1.0)
                )
            self._time += 1
            self.write_data()

    def write_data(self):
        """Write the current state (loss and parameter predictions) into the state
        dataset.

        In the case of the HDF5 data writing used here, this requires extending the
        dataset size prior to writing; this way, the newly written data is always
        in the last row of the dataset.
        """
        if self._time >= self._write_start and (self._time % self._write_every == 0):
            self._dset_loss.resize(self._dset_loss.shape[0] + 1, axis=0)
            self._dset_loss[-1] = self.current_loss

            self.dset_parameters.resize(self.dset_parameters.shape[0] + 1, axis=0)
            self.dset_parameters[-1, :] = [
                self.current_predictions[self.to_learn[p]] for p in self.to_learn.keys()
            ]