Coverage for models / SIR / ensemble_training / Langevin.py: 29%
42 statements
« prev ^ index » next coverage.py v7.13.1, created at 2026-01-21 16:26 +0000
import sys
from os.path import dirname as up

import h5py as h5
import torch
from dantro._import_tools import import_module_from_path

# Make the parent and grandparent directories importable so that shared
# project modules (the "include" package) can be resolved below.
sys.path.append(up(up(__file__)))
sys.path.append(up(up(up(__file__))))

# Load the project-level `include` package under the name `base`; it provides
# the MetropolisAdjustedLangevin base class and the random_tensor helper.
base = import_module_from_path(mod_path=up(up(up(__file__))), mod_str="include")
class SIR_Langevin_sampler(base.MetropolisAdjustedLangevin):
    """Metropolis-adjusted Langevin (MALA) sampler for a stochastic SIR model.

    Samples the posterior over a subset of the SIR model parameters
    (``p_infect``, ``t_infectious``, ``sigma``, ``alpha``) by simulating the
    ODE forward on a randomly chosen window of the observed time series and
    using the summed MSE to the observations as the negative log-likelihood.
    Parameter samples are streamed to an HDF5 dataset.
    """

    def __init__(
        self,
        *,
        true_data: torch.Tensor,
        prior: dict,
        lr: float = 1e-2,
        lr_final: float = 1e-4,
        max_itr: float = 1e4,
        beta: float = 0.99,
        Lambda: float = 1e-15,
        centered: bool = False,
        write_start: int = 1,
        write_every: int = 1,
        batch_size: int = 1,
        h5File: h5.File,
        to_learn: list,
        true_parameters: dict,
        **__,
    ):
        """Set up the sampler state and the output dataset.

        :param true_data: observed densities; assumes shape (time, 3, ...)
            with compartments (S, I, R) — TODO confirm against data loader
        :param prior: prior specification passed to ``base.random_tensor``
        :param lr: initial learning rate / step size
        :param lr_final: final learning rate (annealing target)
        :param max_itr: maximum number of sampler iterations
        :param beta: momentum/averaging coefficient (see base class)
        :param Lambda: numerical regularisation constant (see base class)
        :param centered: whether to use the centered sampler variant
        :param write_start: first iteration from which samples are written
        :param write_every: write interval (in iterations)
        :param batch_size: length of the simulated window per loss evaluation
        :param h5File: open HDF5 file to write the parameter samples into
        :param to_learn: names of the parameters to sample
        :param true_parameters: fixed values for parameters NOT in ``to_learn``
        """
        # Map each learnable parameter name to its index in the sample vector
        self.to_learn = {key: idx for idx, key in enumerate(to_learn)}

        # Fixed (non-learned) parameters, stored as float tensors
        self.true_parameters = {
            key: torch.tensor(val, dtype=torch.float)
            for key, val in true_parameters.items()
        }

        # Draw an initial guess from the prior
        init_guess = base.random_tensor(prior, size=(len(to_learn),))

        super().__init__(
            true_data=true_data,
            init_guess=init_guess,
            lr=lr,
            lr_final=lr_final,
            max_itr=max_itr,
            beta=beta,
            Lambda=Lambda,
            centered=centered,
            write_start=write_start,
            write_every=write_every,
            batch_size=batch_size,
            h5File=h5File,
        )

        # Resizable dataset for the parameter samples: one row per written
        # sample, one column per learned parameter. compression=3 is gzip
        # level 3 in h5py's integer shorthand.
        self.dset_parameters = self.h5group.create_dataset(
            "parameters",
            (0, len(self.to_learn)),
            maxshape=(None, len(self.to_learn)),
            chunks=True,
            compression=3,
        )
        self.dset_parameters.attrs["dim_names"] = ["sample", "parameter"]
        self.dset_parameters.attrs["coords_mode__sample"] = "trivial"
        self.dset_parameters.attrs["coords_mode__parameter"] = "values"
        self.dset_parameters.attrs["coords__parameter"] = to_learn

        # Initialise loss and gradient for both chain states (current and
        # proposal) to the same value so the first MALA step is well-defined.
        self.loss[0] = self.loss_function(self.x[0])
        self.loss[1].data = self.loss[0].data

        self.grad[0].data = torch.autograd.grad(
            self.loss[0], [self.x[0]], create_graph=False
        )[0].data
        self.grad[1].data = self.grad[0].data

    def loss_function(self, input):
        """Negative log-likelihood of a parameter vector, estimated by
        simulating the stochastic SIR ODE on a random window of the data.

        :param input: parameter sample vector, indexed via ``self.to_learn``
        :returns: scalar summed MSE between the simulated window and the
            observed densities ``true_data[start : start + batch_size]``
        """
        # Pick a random window start; the special case avoids calling
        # torch.randint with an empty range when only one start is possible.
        if self.true_data.shape[0] - self.batch_size == 1:
            start = 1
        else:
            start = torch.randint(
                1, self.true_data.shape[0] - self.batch_size, (1,)
            ).item()

        # Seed the simulation with the observation preceding the window.
        # NOTE(review): this first entry is compared against true_data[start]
        # (one step later) in the loss below — confirm this offset is intended.
        densities = [self.true_data[start - 1]]

        # Resolve each parameter: take it from the sample vector if learned,
        # otherwise fall back to the fixed true value. t_infectious is sampled
        # on a [0, 1]-ish scale and stretched by a factor of 30 here.
        p = (
            input[self.to_learn["p_infect"]]
            if "p_infect" in self.to_learn.keys()
            else self.true_parameters["p_infect"]
        )
        t = (
            30 * input[self.to_learn["t_infectious"]]
            if "t_infectious" in self.to_learn.keys()
            else self.true_parameters["t_infectious"]
        )
        sigma = (
            input[self.to_learn["sigma"]]
            if "sigma" in self.to_learn.keys()
            else self.true_parameters["sigma"]
        )
        alpha = (
            input[self.to_learn["alpha"]]
            if "alpha" in self.to_learn.keys()
            else self.true_parameters["alpha"]
        )

        for s in range(start, start + self.batch_size - 1):
            # Recovery rate: smoothly switched on once s exceeds the
            # infectious time t (steep sigmoid approximates a step function)
            tau = 1 / t * torch.sigmoid(1000 * (s / t - 1))

            # Standard-normal noise driving the stochastic infection term
            w = torch.normal(torch.tensor(0.0), torch.tensor(1.0))

            # One Euler step of the noisy SIR dynamics for (S, I, R),
            # clipped to the physical density range [0, 1]. The
            # 1 / (10000 + alpha) term is a small alpha-dependent drift.
            densities.append(
                torch.clip(
                    densities[-1]
                    + torch.stack(
                        [
                            (-p * densities[-1][0] - sigma * w) * densities[-1][1]
                            + 1 / (10000 + alpha),
                            (p * densities[-1][0] + sigma * w - tau) * densities[-1][1]
                            + 1 / (10000 + alpha),
                            tau * densities[-1][1] + 1 / (10000 + alpha),
                        ]
                    ),
                    0,
                    1,
                )
            )

        densities = torch.stack(densities)

        # Summed squared error over the whole window acts as the
        # (unnormalised) negative log-likelihood.
        return torch.nn.functional.mse_loss(
            densities, self.true_data[start : start + self.batch_size], reduction="sum"
        )

    def write_parameters(self):
        """Append the current parameter sample to the HDF5 dataset.

        Writes only from iteration ``write_start`` onwards and only every
        ``write_every`` iterations. The stored values are rescaled to their
        physical units: ``t_infectious`` is sampled divided by 30 (see
        ``loss_function``), so it is multiplied back by 30 here.
        """
        if self.time > self.write_start and self.time % self.write_every == 0:
            self.dset_parameters.resize(self.dset_parameters.shape[0] + 1, axis=0)
            # BUGFIX: the scale vector was hard-coded as [1, 30, 1], which
            # assumed exactly three learned parameters with t_infectious in
            # second position; derive it from the parameter names instead so
            # any subset/ordering of to_learn is written correctly.
            scale = torch.tensor(
                [30.0 if name == "t_infectious" else 1.0 for name in self.to_learn]
            )
            self.dset_parameters[-1, :] = (
                torch.flatten(self.x[0].detach()) * scale
            ).numpy()