# Coverage for include/langevin.py: 14% (122 statements)
import copy

import h5py as h5
import torch
from torch.optim.optimizer import Optimizer


class pSGLD(Optimizer):
    """Implements the pSGLD algorithm based on https://arxiv.org/pdf/1512.07666.pdf

    Built on the PyTorch RMSprop implementation
    (https://pytorch.org/docs/stable/_modules/torch/optim/rmsprop.html#RMSprop)

    Adapted from https://github.com/alisiahkoohi/Langevin-dynamics
    """

    def __init__(
        self,
        params,
        lr: float = 1e-2,
        beta: float = 0.99,
        Lambda: float = 1e-15,
        weight_decay: float = 0,
        centered: bool = False,
    ):
        """
        Initializes the pSGLD optimizer.

        Args:
            params (iterable): Iterable of parameters to optimize.
            lr (float, optional): Learning rate. Default is 1e-2.
            beta (float, optional): Exponential moving average coefficient.
                Default is 0.99.
            Lambda (float, optional): Epsilon value. Default is 1e-15.
            weight_decay (float, optional): Weight decay coefficient. Default
                is 0.
            centered (bool, optional): Whether to use centered gradients.
                Default is False.
        """
        if not 0.0 <= lr:
            raise ValueError(f"Invalid learning rate: {lr}")
        if not 0.0 <= Lambda:
            raise ValueError(f"Invalid epsilon value: {Lambda}")
        if not 0.0 <= weight_decay:
            raise ValueError(f"Invalid weight_decay value: {weight_decay}")
        if not 0.0 <= beta:
            raise ValueError(f"Invalid beta value: {beta}")

        defaults = dict(
            lr=lr,
            beta=beta,
            Lambda=Lambda,
            centered=centered,
            weight_decay=weight_decay,
        )
        super().__init__(params, defaults)

    def __setstate__(self, state):
        super().__setstate__(state)
        for group in self.param_groups:
            group.setdefault("centered", False)

    def step(self, closure=None):
        """Performs a single optimization step.

        Args:
            closure (callable, optional): A closure that reevaluates the model
                and returns the loss.

        Returns:
            torch.Tensor: Value of the preconditioner G (as defined in the
                algorithm) after the step.
        """
        loss = None
        if closure is not None:
            loss = closure()
        for group in self.param_groups:
            for p in group["params"]:
                if p.grad is None:
                    continue
                grad = p.grad.data
                if grad.is_sparse:
                    raise RuntimeError("pSGLD does not support sparse gradients")
                state = self.state[p]

                # State initialization
                if len(state) == 0:
                    state["step"] = 0
                    state["V"] = torch.zeros_like(p.data)
                    if group["centered"]:
                        state["grad_avg"] = torch.zeros_like(p.data)

                V = state["V"]
                beta = group["beta"]
                state["step"] += 1

                if group["weight_decay"] != 0:
                    grad = grad.add(p.data, alpha=group["weight_decay"])

                # Exponential moving average of the squared gradient
                V.mul_(beta).addcmul_(grad, grad, value=1 - beta)

                # Preconditioner G
                if group["centered"]:
                    grad_avg = state["grad_avg"]
                    grad_avg.mul_(beta).add_(grad, alpha=1 - beta)
                    G = (
                        V.addcmul(grad_avg, grad_avg, value=-1)
                        .sqrt_()
                        .add_(group["Lambda"])
                    )
                else:
                    G = V.sqrt().add_(group["Lambda"])

                # Preconditioned gradient step
                p.data.addcdiv_(grad, G, value=-group["lr"])

                # Inject Gaussian noise with variance 2 * lr / G
                noise_std = 2 * group["lr"] / G
                noise_std = noise_std.sqrt()
                noise = p.data.new(p.data.size()).normal_(mean=0, std=1) * noise_std
                p.data.add_(noise)

                # Only consider absolute values
                p.data.abs_()

        return G
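
# -----------------------------------------------------------------------------
# Illustrative sketch (not part of the original module): using pSGLD on its own
# to draw samples of a single parameter tensor. `negative_log_density` is a
# hypothetical user-supplied scalar loss; note that `step()` takes the absolute
# value of the parameters and returns the preconditioner G.
#
#     x = torch.ones(2, requires_grad=True)
#     optim = pSGLD([x], lr=1e-2, Lambda=1e-2)  # larger Lambda tempers the noise
#     samples = []
#     for _ in range(1000):
#         optim.zero_grad()
#         loss = negative_log_density(x)
#         loss.backward()
#         G = optim.step()
#         samples.append(x.detach().clone())
# -----------------------------------------------------------------------------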


class MetropolisAdjustedLangevin:
    """
    A class implementing the Metropolis-adjusted Langevin algorithm. Adapted from
    https://github.com/alisiahkoohi/Langevin-dynamics

    Args:
        true_data (torch.Tensor): Training data.
        init_guess (torch.Tensor): Initial input tensor.
        lr (float, optional): Initial learning rate. Default is 1e-2.
        lr_final (float, optional): Final learning rate. Default is 1e-4.
        max_itr (float, optional): Maximum number of iterations. Default is
            1e4.
        beta (float, optional): Exponential moving average coefficient passed
            to the pSGLD optimizer. Default is 0.99.
        Lambda (float, optional): Epsilon value passed to the pSGLD optimizer.
            Default is 1e-15.
        centered (bool, optional): Whether the pSGLD optimizer uses centered
            gradients. Default is False.
        write_start (int, optional): Burn-in: first iteration from which the
            loss is written. Default is 1.
        write_every (int, optional): Thinning: write the loss only every
            `write_every` iterations. Default is 1.
        batch_size (int, optional): Batch size. Default is 1.
        h5File (h5.File): The h5 file to write data to.
    """

    def __init__(
        self,
        *,
        true_data: torch.Tensor,
        init_guess: torch.Tensor,
        lr: float = 1e-2,
        lr_final: float = 1e-4,
        max_itr: float = 1e4,
        beta: float = 0.99,
        Lambda: float = 1e-15,
        centered: bool = False,
        write_start: int = 1,
        write_every: int = 1,
        batch_size: int = 1,
        h5File: h5.File,
        **__,
    ):
        super().__init__()

        # Training data
        self.true_data = true_data

        # Burn-in
        self.write_start = write_start

        # Thinning
        self.write_every = write_every

        # Batch size
        self.batch_size = batch_size

        # Create an h5 Group for the Langevin estimation
        self.h5group = h5File.require_group("langevin_data")

        # Dataset for the log-likelihood
        self.dset_loss = self.h5group.create_dataset(
            "loss",
            (0,),
            maxshape=(None,),
            chunks=True,
            compression=3,
        )
        self.dset_loss.attrs["dim_names"] = ["sample"]
        self.dset_loss.attrs["coords_mode__sample"] = "trivial"

        # Track the total time required to run the samples
        self.dset_time = self.h5group.create_dataset(
            "time",
            (0,),
            maxshape=(1,),
            chunks=True,
            compression=3,
        )
        self.dset_time.attrs["dim_names"] = ["time"]
        self.dset_time.attrs["coords_mode__time"] = "trivial"

        # Set the initial guess
        self.x = [
            torch.zeros(init_guess.shape, device=init_guess.device, requires_grad=True),
            torch.zeros(init_guess.shape, device=init_guess.device, requires_grad=True),
        ]
        self.x[0].data = init_guess.data.clone()
        self.x[1].data = init_guess.data.clone()

        # Loss container
        self.loss = [
            torch.zeros([1], device=init_guess.device),
            torch.zeros([1], device=init_guess.device),
        ]

        # Gradient container
        self.grad = [
            torch.zeros(init_guess.shape, device=init_guess.device),
            torch.zeros(init_guess.shape, device=init_guess.device),
        ]

        # Optimizer
        self.optim = pSGLD(
            [self.x[1]],
            lr,
            weight_decay=0.0,
            beta=beta,
            Lambda=Lambda,
            centered=centered,
        )

        # Preconditioner container
        self.P = [
            torch.ones(init_guess.shape, device=init_guess.device, requires_grad=False),
            torch.ones(init_guess.shape, device=init_guess.device, requires_grad=False),
        ]

        # Learning rate schedule and iteration counter
        self.lr = lr
        self.lr_final = lr_final
        self.max_itr = max_itr
        self.lr_fn = self.decay_fn(lr=lr, lr_final=lr_final, max_itr=max_itr)
        self.time = 0

    def sample(self, *, force_accept: bool = False) -> tuple:
        """
        Perform a Metropolis-Hastings step to generate a sample. The sample can
        be force-accepted, e.g. if it is the first sample.

        Args:
            force_accept (bool, optional): Whether to accept the proposal
                unconditionally. Default is False.

        Returns:
            tuple: A tuple containing the sampled input tensor and the
                corresponding loss value.
        """
        accepted = False
        self.lr_decay()
        while not accepted:
            # Propose a new state with a single pSGLD step; the returned
            # preconditioner is stored for use in the proposal density
            self.x[1].grad = self.grad[1].data
            self.P[1] = self.optim.step()

            # Evaluate the loss and gradient of the proposal; `loss_function`
            # is not defined in this class and must be supplied elsewhere
            self.loss[1] = self.loss_function(self.x[1])
            self.grad[1].data = torch.autograd.grad(
                self.loss[1], [self.x[1]], create_graph=False
            )[0].data

            # Metropolis-Hastings accept/reject step
            alpha = min([1.0, self.sample_prob()])
            if torch.rand([1]) <= alpha or force_accept:
                self.grad[0].data = self.grad[1].data
                self.loss[0].data = self.loss[1].data
                self.x[0].data = self.x[1].data
                self.P[0].data = self.P[1].data
                accepted = True
            else:
                self.x[1].data = self.x[0].data
                self.P[1].data = self.P[0].data

        self.time += 1

        return copy.deepcopy(self.x[1].data), self.loss[1].item()

    def proposal_dist(self, idx: int) -> torch.Tensor:
        """
        Calculate the log of the proposal distribution for Metropolis-Hastings,
        evaluated at the state with index `idx` given the other state.

        Args:
            idx (int): Index of the current tensor.

        Returns:
            torch.Tensor: The log of the proposal distribution (up to an
                additive constant).
        """
        return (
            -(0.25 / self.lr_fn(self.time))
            * (
                self.x[idx]
                - self.x[idx ^ 1]
                - self.lr_fn(self.time) * self.grad[idx ^ 1] / self.P[1]
            )
            * self.P[1]
            @ (
                self.x[idx]
                - self.x[idx ^ 1]
                - self.lr_fn(self.time) * self.grad[idx ^ 1] / self.P[1]
            )
        )

    def sample_prob(self) -> torch.Tensor:
        """
        Calculate the Metropolis-Hastings acceptance ratio, i.e. the acceptance
        probability before taking the minimum with 1.

        Returns:
            torch.Tensor: The acceptance ratio.
        """
        return torch.exp(-self.loss[1] + self.loss[0]) * torch.exp(
            self.proposal_dist(0) - self.proposal_dist(1)
        )

    def decay_fn(
        self, lr: float = 1e-2, lr_final: float = 1e-4, max_itr: float = 1e4
    ) -> callable:
        """
        Generate a learning rate decay function.

        Args:
            lr (float, optional): Initial learning rate. Default is 1e-2.
            lr_final (float, optional): Final learning rate. Default is 1e-4.
            max_itr (float, optional): Maximum number of iterations. Default is
                1e4.

        Returns:
            callable: Learning rate decay function.
        """
        gamma = -0.55
        b = max_itr / ((lr_final / lr) ** (1 / gamma) - 1.0)
        a = lr / (b**gamma)

        def lr_fn(t: float, a: float = a, b: float = b, gamma: float = gamma) -> float:
            return a * ((b + t) ** gamma)

        return lr_fn
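
    # Note on the schedule above: with a = lr / b**gamma and
    # b = max_itr / ((lr_final / lr)**(1 / gamma) - 1), the returned function
    # satisfies lr_fn(0) = lr and lr_fn(max_itr) = lr_final, i.e. it decays
    # polynomially (exponent gamma = -0.55) from the initial to the final
    # learning rate over max_itr iterations.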

    def lr_decay(self):
        """
        Decay the learning rate of the optimizer.
        """
        for param_group in self.optim.param_groups:
            param_group["lr"] = self.lr_fn(self.time)

    def write_loss(self):
        """
        Write out the loss.
        """
        if self.time > self.write_start and self.time % self.write_every == 0:
            self.dset_loss.resize(self.dset_loss.shape[0] + 1, axis=0)
            self.dset_loss[-1] = self.loss[0].detach().numpy()

    def write_time(self, time):
        """
        Write out the current time.
        """
        self.dset_time.resize(self.dset_time.shape[0] + 1, axis=0)
        self.dset_time[-1] = time
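

# -----------------------------------------------------------------------------
# Minimal usage sketch (not part of the original module). The class references
# `loss_function`, which it does not define; here a quadratic loss is attached
# to the instance purely for illustration. The file name, target, and all
# settings below are hypothetical.
# -----------------------------------------------------------------------------
if __name__ == "__main__":
    with h5.File("langevin_example.h5", "w") as f:
        # Positive target, since pSGLD.step() takes absolute parameter values
        data = torch.rand(10) + 1.0

        sampler = MetropolisAdjustedLangevin(
            true_data=data,
            init_guess=torch.ones(10),
            lr=1e-2,
            lr_final=1e-4,
            max_itr=100,
            Lambda=1.0,  # keeps the preconditioner well-conditioned early on
            h5File=f,
        )

        # Hypothetical target density: a Gaussian centred on the data, so the
        # loss is its negative log-density up to an additive constant
        sampler.loss_function = lambda x: 0.5 * torch.sum((x - data) ** 2)

        # Force-accept the first sample to initialise the loss and gradient
        # buffers, then run a short chain, writing the loss after each sample
        sampler.sample(force_accept=True)
        for _ in range(100):
            sampler.sample()
            sampler.write_loss()
        sampler.write_time(sampler.time)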