Coverage for models / SIR / ensemble_training / Langevin.py: 29%

42 statements  

« prev     ^ index     » next       coverage.py v7.13.1, created at 2026-01-21 16:26 +0000

1import sys 

2from os.path import dirname as up 

3 

4import h5py as h5 

5import torch 

6from dantro._import_tools import import_module_from_path 

7 

# Make the directories one and two levels above this file's own directory
# importable. The entries are appended in the same order as two separate
# ``append`` calls would add them.
sys.path.extend([up(up(__file__)), up(up(up(__file__)))])

# Load the shared ``include`` module from the directory two levels above this
# file's own directory.
base = import_module_from_path(
    mod_path=up(up(up(__file__))),
    mod_str="include",
)

12 

13 

class SIR_Langevin_sampler(base.MetropolisAdjustedLangevin):
    """
    A Metropolis-adjusted Langevin sampler that inherits from the base class.

    Each sample is a vector with one entry per name in ``to_learn``. Parameters
    that are not learned are read from ``true_parameters`` instead. Some
    parameters are sampled in rescaled coordinates (see ``_PARAMETER_SCALES``);
    the scale is applied both when simulating in :meth:`loss_function` and when
    storing samples in :meth:`write_parameters`.
    """

    # Factor by which a raw sampler coordinate is multiplied to obtain the actual
    # parameter value. Parameters not listed here are used unscaled.
    _PARAMETER_SCALES = {"t_infectious": 30}

    def __init__(
        self,
        *,
        true_data: torch.Tensor,
        prior: dict,
        lr: float = 1e-2,
        lr_final: float = 1e-4,
        max_itr: float = 1e4,
        beta: float = 0.99,
        Lambda: float = 1e-15,
        centered: bool = False,
        write_start: int = 1,
        write_every: int = 1,
        batch_size: int = 1,
        h5File: h5.File,
        to_learn: list,
        true_parameters: dict,
        **__,
    ):
        """Set up the sampler, its output dataset, and the initial loss and gradient.

        :param true_data: observed time series the simulation is compared against;
            indexed along its first axis by time step
        :param prior: prior specification passed to ``base.random_tensor`` to draw
            the initial guess
        :param to_learn: names of the parameters to sample; the position of a name
            in this list is its index in the sample vector
        :param true_parameters: values used for every parameter that is *not* in
            ``to_learn``; each value is converted to a float tensor
        :param lr, lr_final, max_itr, beta, Lambda, centered, write_start,
            write_every, batch_size, h5File: forwarded unchanged to the base class
        :param __: any further keyword arguments are accepted and ignored
        """
        # Parameters to learn: map each name to its index in the sample vector
        self.to_learn = {key: idx for idx, key in enumerate(to_learn)}
        self.true_parameters = {
            key: torch.tensor(val, dtype=torch.float)
            for key, val in true_parameters.items()
        }

        # Draw an initial guess from the prior
        init_guess = base.random_tensor(prior, size=(len(to_learn),))

        super().__init__(
            true_data=true_data,
            init_guess=init_guess,
            lr=lr,
            lr_final=lr_final,
            max_itr=max_itr,
            beta=beta,
            Lambda=Lambda,
            centered=centered,
            write_start=write_start,
            write_every=write_every,
            batch_size=batch_size,
            h5File=h5File,
        )

        # Create a resizable dataset for the sampled parameters: one row per
        # written sample, one column per learned parameter.
        # NOTE(review): ``self.h5group`` is presumably created by the base class
        # constructor -- confirm.
        n_parameters = len(self.to_learn)
        self.dset_parameters = self.h5group.create_dataset(
            "parameters",
            (0, n_parameters),
            maxshape=(None, n_parameters),
            chunks=True,
            compression=3,
        )
        self.dset_parameters.attrs["dim_names"] = ["sample", "parameter"]
        self.dset_parameters.attrs["coords_mode__sample"] = "trivial"
        self.dset_parameters.attrs["coords_mode__parameter"] = "values"
        self.dset_parameters.attrs["coords__parameter"] = to_learn

        # Calculate the initial values of the loss and its gradient, and copy them
        # into the second slot of each pair.
        self.loss[0] = self.loss_function(self.x[0])
        self.loss[1].data = self.loss[0].data

        self.grad[0].data = torch.autograd.grad(
            self.loss[0], [self.x[0]], create_graph=False
        )[0].data
        self.grad[1].data = self.grad[0].data

    def _get_parameter(self, input, name):
        """Return parameter ``name``: the (rescaled) entry of ``input`` if it is
        learned, otherwise the fixed value from ``true_parameters``."""
        if name not in self.to_learn:
            return self.true_parameters[name]

        value = input[self.to_learn[name]]
        scale = self._PARAMETER_SCALES.get(name)
        return value if scale is None else scale * value

    def loss_function(self, input):
        """Calculates the loss (negative log-likelihood function) of a vector of parameters via simulation.

        A window of ``batch_size`` time steps is chosen (at random unless only one
        start index is possible), the noisy dynamics are iterated across it, and
        the summed squared error against the corresponding slice of ``true_data``
        is returned.

        Requires ``true_data.shape[0] - batch_size >= 1``; otherwise
        ``torch.randint`` is called with an empty range and raises.

        :param input: sample vector, indexed via ``self.to_learn``
        :return: scalar tensor holding the sum-reduced MSE loss
        """

        # ``torch.randint`` excludes its upper bound, so it cannot be used when
        # the only admissible start index is 1.
        if self.true_data.shape[0] - self.batch_size == 1:
            start = 1
        else:
            start = torch.randint(
                1, self.true_data.shape[0] - self.batch_size, (1,)
            ).item()

        # NOTE(review): the trajectory is seeded with ``true_data[start - 1]`` but
        # compared against ``true_data[start : start + batch_size]`` below --
        # confirm that this one-step offset is intended.
        densities = [self.true_data[start - 1]]

        # Get the parameters: infection rate, recovery time, noise variance
        p = self._get_parameter(input, "p_infect")
        t = self._get_parameter(input, "t_infectious")
        sigma = self._get_parameter(input, "sigma")
        alpha = self._get_parameter(input, "alpha")

        for s in range(start, start + self.batch_size - 1):
            # Recovery rate: the steep sigmoid acts as a smooth switch that turns
            # the rate on once ``s`` exceeds ``t``
            tau = 1 / t * torch.sigmoid(1000 * (s / t - 1))

            # Random noise
            w = torch.normal(torch.tensor(0.0), torch.tensor(1.0))

            # Solve the ODE: one explicit step from the previous state, clipped
            # to [0, 1]. Entries 0 and 1 of the state are presumably the
            # susceptible and infected densities -- TODO confirm.
            previous = densities[-1]
            s_dens, i_dens = previous[0], previous[1]
            increment = torch.stack(
                [
                    (-p * s_dens - sigma * w) * i_dens + 1 / (10000 + alpha),
                    (p * s_dens + sigma * w - tau) * i_dens + 1 / (10000 + alpha),
                    tau * i_dens + 1 / (10000 + alpha),
                ]
            )
            densities.append(torch.clip(previous + increment, 0, 1))

        densities = torch.stack(densities)

        # Calculate loss
        return torch.nn.functional.mse_loss(
            densities, self.true_data[start : start + self.batch_size], reduction="sum"
        )

    def write_parameters(self):
        """Append the current sample to the ``parameters`` dataset.

        Nothing is written unless ``self.time`` is past ``write_start`` and is a
        multiple of ``write_every``. Each learned parameter is multiplied by its
        entry in ``_PARAMETER_SCALES`` (or by 1), so the stored values use the
        same scaling as the simulation regardless of which parameters are learned
        or in which order.
        """
        if self.time > self.write_start and self.time % self.write_every == 0:
            # ``self.to_learn`` is built with ``enumerate``, so iterating it
            # yields the names in sample-vector order.
            scales = [self._PARAMETER_SCALES.get(name, 1) for name in self.to_learn]

            self.dset_parameters.resize(self.dset_parameters.shape[0] + 1, axis=0)
            self.dset_parameters[-1, :] = (
                torch.flatten(self.x[0].detach()).numpy() * scales
            )