# models/SIR/Langevin.py

import logging
import sys
from os.path import dirname as up

import h5py as h5
import torch
from dantro._import_tools import import_module_from_path

log = logging.getLogger(__name__)

# Make the parent model directories importable and load the shared "include"
# module, which provides the sampler base class and utilities
sys.path.append(up(up(__file__)))
sys.path.append(up(up(up(__file__))))

base = import_module_from_path(mod_path=up(up(up(__file__))), mod_str="include")


class SIR_Langevin_sampler(base.MetropolisAdjustedLangevin):
    """A Metropolis-adjusted Langevin sampler for the stochastic SIR model,
    inheriting from the ``MetropolisAdjustedLangevin`` base class."""

    def __init__(
        self,
        *,
        true_data: torch.Tensor,
        prior: dict,
        lr: float = 1e-2,
        lr_final: float = 1e-4,
        max_itr: float = 1e4,
        beta: float = 0.99,
        Lambda: float = 1e-15,
        centered: bool = False,
        write_start: int = 1,
        write_every: int = 1,
        batch_size: int = 1,
        h5File: h5.File,
        to_learn: list,
        true_parameters: dict,
        **__,
    ):
        """Initialises the sampler. Parameters not listed in ``to_learn`` are
        fixed to the values given in ``true_parameters``.
        """

        # Indices of the parameters to learn
        self.to_learn = {key: idx for idx, key in enumerate(to_learn)}

        # True values of all parameters, as tensors
        self.true_parameters = {
            key: torch.tensor(val, dtype=torch.float)
            for key, val in true_parameters.items()
        }

        # Draw an initial guess from the prior
        init_guess = base.random_tensor(prior, size=(len(to_learn),))

        super().__init__(
            true_data=true_data,
            init_guess=init_guess,
            lr=lr,
            lr_final=lr_final,
            max_itr=max_itr,
            beta=beta,
            Lambda=Lambda,
            centered=centered,
            write_start=write_start,
            write_every=write_every,
            batch_size=batch_size,
            h5File=h5File,
        )

        # Create a dataset for the parameter samples; the attributes label the
        # dimensions and coordinates for loading with dantro
        self.dset_parameters = self.h5group.create_dataset(
            "parameters",
            (0, len(self.to_learn)),
            maxshape=(None, len(self.to_learn)),
            chunks=True,
            compression=3,
        )
        self.dset_parameters.attrs["dim_names"] = ["sample", "parameter"]
        self.dset_parameters.attrs["coords_mode__sample"] = "trivial"
        self.dset_parameters.attrs["coords_mode__parameter"] = "values"
        self.dset_parameters.attrs["coords__parameter"] = to_learn

        # Calculate the initial values of the loss and its gradient
        self.loss[0] = self.loss_function(self.x[0])
        self.loss[1].data = self.loss[0].data

        self.grad[0].data = torch.autograd.grad(
            self.loss[0], [self.x[0]], create_graph=False
        )[0].data
        self.grad[1].data = self.grad[0].data

    def loss_function(self, params):
        """Calculates the loss (negative log-likelihood) of a parameter vector
        by running a forward simulation of the stochastic SIR dynamics on a
        randomly chosen window of the true data.
        """

        # Choose a random start index for the batch window; index 0 is
        # reserved as the initial condition of the simulation
        if self.true_data.shape[0] - self.batch_size == 1:
            start = 1
        else:
            start = torch.randint(
                1, self.true_data.shape[0] - self.batch_size, (1,)
            ).item()

        # Seed the simulation with the true densities at the window start
        densities = [self.true_data[start - 1]]

        # Get the parameters: infection rate p, infectious period t, noise
        # level sigma, and alpha. Learned parameters are read from the input
        # vector; all others fall back to their true values. Note the factor
        # of 30 rescaling t_infectious, which reappears in write_parameters.
        p = (
            params[self.to_learn["p_infect"]]
            if "p_infect" in self.to_learn
            else self.true_parameters["p_infect"]
        )
        t = (
            30 * params[self.to_learn["t_infectious"]]
            if "t_infectious" in self.to_learn
            else self.true_parameters["t_infectious"]
        )
        sigma = (
            params[self.to_learn["sigma"]]
            if "sigma" in self.to_learn
            else self.true_parameters["sigma"]
        )
        alpha = (
            params[self.to_learn["alpha"]]
            if "alpha" in self.to_learn
            else self.true_parameters["alpha"]
        )
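
        # Each pass of the loop below is one Euler-Maruyama step of the
        # stochastic SIR dynamics (S, I, R are the entries of each density
        # vector):
        #
        #     S' = S + (-p * S - sigma * w) * I + 1 / (10000 + alpha)
        #     I' = I + ( p * S + sigma * w - tau) * I + 1 / (10000 + alpha)
        #     R' = R + tau * I + 1 / (10000 + alpha)
        #
        # where tau = sigmoid(1000 * (s / t - 1)) / t is a smooth switch that
        # turns on recovery once the infectious period t has elapsed, and
        # w ~ N(0, 1); the result is clipped to [0, 1] to keep the compartments
        # valid densities.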
        for s in range(start, start + self.batch_size - 1):
            # Recovery rate
            tau = 1 / t * torch.sigmoid(1000 * (s / t - 1))

            # Random noise
            w = torch.normal(torch.tensor(0.0), torch.tensor(1.0))

            # Perform one step of the dynamics and clip to [0, 1]
            densities.append(
                torch.clip(
                    densities[-1]
                    + torch.stack(
                        [
                            (-p * densities[-1][0] - sigma * w) * densities[-1][1]
                            + 1 / (10000 + alpha),
                            (p * densities[-1][0] + sigma * w - tau) * densities[-1][1]
                            + 1 / (10000 + alpha),
                            tau * densities[-1][1] + 1 / (10000 + alpha),
                        ]
                    ),
                    0,
                    1,
                )
            )

        densities = torch.stack(densities)

        # Calculate the loss as the summed squared error against the true data
        return torch.nn.functional.mse_loss(
            densities, self.true_data[start : start + self.batch_size], reduction="sum"
        )

    def write_parameters(self):
        """Writes the current parameter sample to the h5 dataset."""
        if self.time > self.write_start and self.time % self.write_every == 0:
            self.dset_parameters.resize(self.dset_parameters.shape[0] + 1, axis=0)
            # Apply the factor of 30 from the loss function so that the
            # physical value of t_infectious is written; this assumes
            # to_learn == [p_infect, t_infectious, sigma]
            self.dset_parameters[-1, :] = torch.flatten(self.x[0].detach()).numpy() * [
                1,
                30,
                1,
            ]
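

if __name__ == "__main__":
    # Minimal usage sketch (hypothetical): the synthetic data and the schema
    # of ``prior`` (consumed by base.random_tensor) are assumptions; in the
    # full model, the configuration and the MALA iteration loop are supplied
    # by the surrounding framework and the base class, respectively.
    with h5.File("SIR_Langevin_example.h5", "w") as f:
        sampler = SIR_Langevin_sampler(
            true_data=torch.rand(100, 3, 1),  # placeholder data, shape (T, 3, 1)
            prior={
                "distribution": "uniform",
                "parameters": {"lower": 0, "upper": 1},
            },  # assumed schema
            h5File=f,
            to_learn=["p_infect", "t_infectious", "sigma"],
            true_parameters={"alpha": 0.0},  # alpha is held fixed
            batch_size=10,
        )
        # A single loss evaluation at the current position; the sampling loop
        # itself is driven by the base-class machinery (not shown here)
        log.info("Initial loss: %s", sampler.loss_function(sampler.x[0]))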