Coverage for include/langevin.py: 14%

122 statements  

coverage.py v7.6.1, created at 2024-12-05 17:26 +0000

import copy

import h5py as h5
import torch
from torch.optim.optimizer import Optimizer


class pSGLD(Optimizer):
    """Implements pSGLD algorithm based on https://arxiv.org/pdf/1512.07666.pdf

    Built on the PyTorch RMSprop implementation
    (https://pytorch.org/docs/stable/_modules/torch/optim/rmsprop.html#RMSprop)

    Adapted from https://github.com/alisiahkoohi/Langevin-dynamics
    """

    def __init__(
        self,
        params,
        lr: float = 1e-2,
        beta: float = 0.99,
        Lambda: float = 1e-15,
        weight_decay: float = 0,
        centered: bool = False,
    ):
        """
        Initializes the pSGLD optimizer.

        Args:
            params (iterable): Iterable of parameters to optimize.
            lr (float, optional): Learning rate. Default is 1e-2.
            beta (float, optional): Exponential moving average coefficient.
                Default is 0.99.
            Lambda (float, optional): Epsilon value. Default is 1e-15.
            weight_decay (float, optional): Weight decay coefficient. Default
                is 0.
            centered (bool, optional): Whether to use centered gradients.
                Default is False.
        """
        if not 0.0 <= lr:
            raise ValueError(f"Invalid learning rate: {lr}")
        if not 0.0 <= Lambda:
            raise ValueError(f"Invalid epsilon value: {Lambda}")
        if not 0.0 <= weight_decay:
            raise ValueError(f"Invalid weight_decay value: {weight_decay}")
        if not 0.0 <= beta:
            raise ValueError(f"Invalid beta value: {beta}")

        defaults = dict(
            lr=lr,
            beta=beta,
            Lambda=Lambda,
            centered=centered,
            weight_decay=weight_decay,
        )
        super().__init__(params, defaults)

    def __setstate__(self, state):
        super().__setstate__(state)
        for group in self.param_groups:
            group.setdefault("centered", False)

    def step(self, closure=None):
        """Performs a single optimization step.

        Args:
            closure (callable, optional): A closure that reevaluates the model
                and returns the loss.

        Returns:
            torch.Tensor: Value of G (as defined in the algorithm) after the step.
        """
        loss = None
        if closure is not None:
            loss = closure()
        for group in self.param_groups:
            for p in group["params"]:
                if p.grad is None:
                    continue
                grad = p.grad.data
                if grad.is_sparse:
                    raise RuntimeError("pSGLD does not support sparse gradients")
                state = self.state[p]

                # State initialization
                if len(state) == 0:
                    state["step"] = 0
                    state["V"] = torch.zeros_like(p.data)
                    if group["centered"]:
                        state["grad_avg"] = torch.zeros_like(p.data)

                V = state["V"]
                beta = group["beta"]
                state["step"] += 1

                if group["weight_decay"] != 0:
                    grad = grad.add(p.data, alpha=group["weight_decay"])

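                # RMSprop-style preconditioner: V is an exponential moving average of
                # the squared gradient, and G = sqrt(V) + Lambda below is the diagonal
                # preconditioner used for both the drift and the noise scale.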

                V.mul_(beta).addcmul_(grad, grad, value=1 - beta)

                if group["centered"]:
                    grad_avg = state["grad_avg"]
                    grad_avg.mul_(beta).add_(grad, alpha=1 - beta)
                    G = (
                        V.addcmul(grad_avg, grad_avg, value=-1)
                        .sqrt_()
                        .add_(group["Lambda"])
                    )
                else:
                    G = V.sqrt().add_(group["Lambda"])
                p.data.addcdiv_(grad, G, value=-group["lr"])

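                # Langevin diffusion term: inject Gaussian noise with standard
                # deviation sqrt(2 * lr / G), turning the preconditioned gradient
                # step into a pSGLD sampling step.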

                noise_std = 2 * group["lr"] / G
                noise_std = noise_std.sqrt()
                noise = p.data.new(p.data.size()).normal_(mean=0, std=1) * noise_std
                p.data.add_(noise)

                # Only consider absolute values
                p.data.abs_()

        return G


class MetropolisAdjustedLangevin:
    """
    A class implementing the Metropolis-Adjusted Langevin algorithm. Adapted from
    https://github.com/alisiahkoohi/Langevin-dynamics

    Args:
        true_data (torch.Tensor): Training data.
        init_guess (torch.Tensor): Initial input tensor.
        lr (float, optional): Initial learning rate. Default is 1e-2.
        lr_final (float, optional): Final learning rate. Default is 1e-4.
        max_itr (float, optional): Maximum number of iterations. Default is
            1e4.
        beta (float, optional): Exponential moving average coefficient of the
            pSGLD proposal. Default is 0.99.
        Lambda (float, optional): Epsilon value of the pSGLD proposal. Default
            is 1e-15.
        centered (bool, optional): Whether the pSGLD proposal uses centered
            gradients. Default is False.
        write_start (int, optional): Burn-in; data is only written after this
            time. Default is 1.
        write_every (int, optional): Thinning; data is written only every
            write_every-th iteration. Default is 1.
        batch_size (int, optional): Batch size. Default is 1.
        h5File (h5.File): The h5py file to write data to.
    """

    def __init__(
        self,
        *,
        true_data: torch.Tensor,
        init_guess: torch.Tensor,
        lr: float = 1e-2,
        lr_final: float = 1e-4,
        max_itr: float = 1e4,
        beta: float = 0.99,
        Lambda: float = 1e-15,
        centered: bool = False,
        write_start: int = 1,
        write_every: int = 1,
        batch_size: int = 1,
        h5File: h5.File,
        **__,
    ):
        super().__init__()

        # Training data
        self.true_data = true_data

        # Burn-in
        self.write_start = write_start

        # Thinning
        self.write_every = write_every

        # Batch size
        self.batch_size = batch_size

        # Create an h5 Group for the langevin estimation
        self.h5group = h5File.require_group("langevin_data")

        # Dataset for the log-likelihood
        self.dset_loss = self.h5group.create_dataset(
            "loss",
            (0,),
            maxshape=(None,),
            chunks=True,
            compression=3,
        )
        self.dset_loss.attrs["dim_names"] = ["sample"]
        self.dset_loss.attrs["coords_mode__sample"] = "trivial"

        # Track the total time required to run the samples
        self.dset_time = self.h5group.create_dataset(
            "time",
            (0,),
            maxshape=(1,),
            chunks=True,
            compression=3,
        )
        self.dset_time.attrs["dim_names"] = ["time"]
        self.dset_time.attrs["coords_mode__time"] = "trivial"

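        # The containers below each hold two slots: index 0 is the current
        # (accepted) state of the chain, index 1 is the proposal being evaluated.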

        # Set the initial guess
        self.x = [
            torch.zeros(init_guess.shape, device=init_guess.device, requires_grad=True),
            torch.zeros(init_guess.shape, device=init_guess.device, requires_grad=True),
        ]
        self.x[0].data = init_guess.data.clone()
        self.x[1].data = init_guess.data.clone()

        # Loss container
        self.loss = [
            torch.zeros([1], device=init_guess.device),
            torch.zeros([1], device=init_guess.device),
        ]

        # Gradient container
        self.grad = [
            torch.zeros(init_guess.shape, device=init_guess.device),
            torch.zeros(init_guess.shape, device=init_guess.device),
        ]

        # Optimizer
        self.optim = pSGLD(
            [self.x[1]],
            lr,
            weight_decay=0.0,
            beta=beta,
            Lambda=Lambda,
            centered=centered,
        )
        self.P = [
            torch.ones(init_guess.shape, device=init_guess.device, requires_grad=False),
            torch.ones(init_guess.shape, device=init_guess.device, requires_grad=False),
        ]

        self.lr = lr
        self.lr_final = lr_final
        self.max_itr = max_itr
        self.lr_fn = self.decay_fn(lr=lr, lr_final=lr_final, max_itr=max_itr)
        self.time = 0

    def sample(self, *, force_accept: bool = False) -> tuple:
        """
        Perform a Metropolis-Hastings step to generate a sample. The sample can
        be force-accepted, e.g. if it is the first sample.

        Args:
            force_accept (bool, optional): Whether to accept the proposal
                regardless of the acceptance probability. Default is False.

        Returns:
            tuple: A tuple containing the sampled input tensor and the
                corresponding loss value.
        """
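        # NOTE: `self.loss_function` (the negative log-density to be sampled) is not
        # defined in this class; it is expected to be supplied elsewhere, e.g. by a
        # subclass or by assigning a callable to the instance.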

        accepted = False
        self.lr_decay()

        while not accepted:
            self.x[1].grad = self.grad[1].data
            self.P[1] = self.optim.step()
            self.loss[1] = self.loss_function(self.x[1])
            self.grad[1].data = torch.autograd.grad(
                self.loss[1], [self.x[1]], create_graph=False
            )[0].data
            alpha = min([1.0, self.sample_prob()])
            if torch.rand([1]) <= alpha or force_accept:
                self.grad[0].data = self.grad[1].data
                self.loss[0].data = self.loss[1].data
                self.x[0].data = self.x[1].data
                self.P[0].data = self.P[1].data
                accepted = True

            else:
                self.x[1].data = self.x[0].data
                self.P[1].data = self.P[0].data

        self.time += 1

        return copy.deepcopy(self.x[1].data), self.loss[1].item()

    def proposal_dist(self, idx: int) -> torch.Tensor:
        """
        Calculate the log-density (up to an additive constant) of the Langevin
        proposal distribution for the Metropolis-Hastings correction.

        Args:
            idx (int): Index of the current tensor.

        Returns:
            torch.Tensor: The log-density of the proposal.
        """
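        # Computed below (up to an additive constant):
        #   -(1 / (4 * lr)) * (x[idx] - x[idx^1] - lr * grad[idx^1] / P)^T
        #                   * P * (x[idx] - x[idx^1] - lr * grad[idx^1] / P)
        # where lr is the current step size and P the preconditioner from pSGLD.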

        return (
            -(0.25 / self.lr_fn(self.time))
            * (
                self.x[idx]
                - self.x[idx ^ 1]
                - self.lr_fn(self.time) * self.grad[idx ^ 1] / self.P[1]
            )
            * self.P[1]
            @ (
                self.x[idx]
                - self.x[idx ^ 1]
                - self.lr_fn(self.time) * self.grad[idx ^ 1] / self.P[1]
            )
        )

    def sample_prob(self) -> torch.Tensor:
        """
        Calculate the acceptance probability for Metropolis-Hastings.

        Returns:
            torch.Tensor: The acceptance probability.
        """
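        # Metropolis-Hastings ratio: exp(loss[0] - loss[1]) is the target-density
        # ratio (the loss plays the role of a negative log-density), and the second
        # factor is the proposal correction q(x_old | x_new) / q(x_new | x_old).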

        return torch.exp(-self.loss[1] + self.loss[0]) * torch.exp(
            self.proposal_dist(0) - self.proposal_dist(1)
        )

    def decay_fn(
        self, lr: float = 1e-2, lr_final: float = 1e-4, max_itr: float = 1e4
    ) -> callable:
        """
        Generate a learning rate decay function.

        Args:
            lr (float, optional): Initial learning rate. Default is 1e-2.
            lr_final (float, optional): Final learning rate. Default is 1e-4.
            max_itr (float, optional): Maximum number of iterations. Default is
                1e4.

        Returns:
            callable: Learning rate decay function.
        """
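        # Polynomial decay lr(t) = a * (b + t) ** gamma with gamma = -0.55, where a
        # and b are chosen such that lr(0) == lr and lr(max_itr) == lr_final.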

        gamma = -0.55
        b = max_itr / ((lr_final / lr) ** (1 / gamma) - 1.0)
        a = lr / (b**gamma)

        def lr_fn(t: float, a: float = a, b: float = b, gamma: float = gamma) -> float:
            return a * ((b + t) ** gamma)

        return lr_fn

    def lr_decay(self):
        """
        Decay the learning rate of the optimizer.
        """
        for param_group in self.optim.param_groups:
            param_group["lr"] = self.lr_fn(self.time)

    def write_loss(self):
        """
        Write out the loss.
        """
        if self.time > self.write_start and self.time % self.write_every == 0:
            self.dset_loss.resize(self.dset_loss.shape[0] + 1, axis=0)
            self.dset_loss[-1] = self.loss[0].detach().numpy()

    def write_time(self, time):
        """
        Write out the current time.
        """
        self.dset_time.resize(self.dset_time.shape[0] + 1, axis=0)
        self.dset_time[-1] = time
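

# ---------------------------------------------------------------------------
# Illustrative usage sketch (not part of the original module). The sampler above
# calls `self.loss_function`, which is not defined in this file, so a user is
# expected to supply it, e.g. via a subclass as below. All names in this block
# (`GaussianLangevin`, `example.h5`, the toy Gaussian loss) are hypothetical.
# ---------------------------------------------------------------------------
if __name__ == "__main__":

    class GaussianLangevin(MetropolisAdjustedLangevin):
        """Toy subclass: the loss is the negative log-density of an isotropic
        Gaussian centered on the training data."""

        def loss_function(self, x: torch.Tensor) -> torch.Tensor:
            return 0.5 * torch.sum((x - self.true_data) ** 2)

    with h5.File("example.h5", "w") as f:
        sampler = GaussianLangevin(
            true_data=torch.rand(4),
            init_guess=torch.ones(4),
            lr=1e-2,
            lr_final=1e-4,
            max_itr=200,
            h5File=f,
        )
        for it in range(200):
            # Force-accept the very first sample to initialise the chain
            sample_value, loss_value = sampler.sample(force_accept=(it == 0))
            sampler.write_loss()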