Coverage for models/SIR/Langevin.py: 32%

44 statements  

« prev     ^ index     » next       coverage.py v7.6.1, created at 2024-12-05 17:26 +0000

import logging

# Module-level logger (not referenced elsewhere in this file).
log = logging.getLogger(__name__)

import sys
from os.path import dirname as up

import h5py as h5
import torch
from dantro._import_tools import import_module_from_path

# Make the directory two dirname-levels above this file (``models/``, given
# this file lives at models/SIR/Langevin.py) and the directory above that
# importable, so sibling model packages and project-level modules resolve.
sys.path.append(up(up(__file__)))
sys.path.append(up(up(up(__file__))))

# Load the project's shared ``include`` module relative to the directory
# three dirname-levels above this file. It supplies the sampler base class
# (``MetropolisAdjustedLangevin``) and ``random_tensor`` used below.
base = import_module_from_path(mod_path=up(up(up(__file__))), mod_str="include")

16 

17 

class SIR_Langevin_sampler(base.MetropolisAdjustedLangevin):
    """
    A Metropolis-adjusted Langevin sampler that inherits from the base class.

    Samples the parameters of a stochastic SIR model by simulating compartment
    densities and comparing them with ``true_data``. Parameters named in
    ``to_learn`` are sampled; all others are fixed to their values in
    ``true_parameters``.
    """

    # Some learned parameters are sampled on a rescaled axis: the model uses
    # ``factor * sampled_value``. Parameters not listed here use a factor of 1.
    # Shared by ``loss_function`` and ``write_parameters`` so the simulated and
    # the written values always use the same scaling.
    _PARAMETER_SCALES = {"t_infectious": 30.0}

    def __init__(
        self,
        *,
        true_data: torch.Tensor,
        prior: dict,
        lr: float = 1e-2,
        lr_final: float = 1e-4,
        max_itr: float = 1e4,
        beta: float = 0.99,
        Lambda: float = 1e-15,
        centered: bool = False,
        write_start: int = 1,
        write_every: int = 1,
        batch_size: int = 1,
        h5File: h5.File,
        to_learn: list,
        true_parameters: dict,
        **__,
    ):
        """Set up the sampler, its output dataset, and the initial loss/gradient.

        :param true_data: observed densities; the first dimension is time (it
            is sliced along dim 0 in ``loss_function``)
        :param prior: prior specification passed to ``base.random_tensor`` to
            draw the initial parameter guess
        :param lr, lr_final, max_itr, beta, Lambda, centered, write_start,
            write_every, batch_size, h5File: forwarded unchanged to
            ``base.MetropolisAdjustedLangevin``
        :param to_learn: names of the parameters to sample; their order fixes
            each parameter's index in the sampled vector
        :param true_parameters: values for the model parameters; used for every
            parameter not in ``to_learn`` (converted to float tensors)
        :param __: additional keyword arguments are accepted and ignored
        """
        # Map each learned parameter name to its index in the parameter vector
        self.to_learn = {key: idx for idx, key in enumerate(to_learn)}
        self.true_parameters = {
            key: torch.tensor(val, dtype=torch.float)
            for key, val in true_parameters.items()
        }

        # Draw an initial guess from the prior
        init_guess = base.random_tensor(prior, size=(len(to_learn),))

        super().__init__(
            true_data=true_data,
            init_guess=init_guess,
            lr=lr,
            lr_final=lr_final,
            max_itr=max_itr,
            beta=beta,
            Lambda=Lambda,
            centered=centered,
            write_start=write_start,
            write_every=write_every,
            batch_size=batch_size,
            h5File=h5File,
        )

        # Create a resizable (sample x parameter) dataset for the samples
        n_params = len(self.to_learn)
        self.dset_parameters = self.h5group.create_dataset(
            "parameters",
            (0, n_params),
            maxshape=(None, n_params),
            chunks=True,
            compression=3,
        )
        self.dset_parameters.attrs["dim_names"] = ["sample", "parameter"]
        self.dset_parameters.attrs["coords_mode__sample"] = "trivial"
        self.dset_parameters.attrs["coords_mode__parameter"] = "values"
        self.dset_parameters.attrs["coords__parameter"] = to_learn

        # Calculate the initial values of the loss and its gradient
        self.loss[0] = self.loss_function(self.x[0])
        self.loss[1].data = self.loss[0].data

        self.grad[0].data = torch.autograd.grad(
            self.loss[0], [self.x[0]], create_graph=False
        )[0].data
        self.grad[1].data = self.grad[0].data

    def _get_parameter(self, name, input):
        """Return the model value of parameter ``name``.

        If ``name`` is being learned, its entry in ``input`` is taken, multiplied
        by its factor in ``_PARAMETER_SCALES`` (if any); otherwise the fixed
        value from ``true_parameters`` is returned unscaled.
        """
        if name not in self.to_learn:
            return self.true_parameters[name]
        value = input[self.to_learn[name]]
        scale = self._PARAMETER_SCALES.get(name)
        return value if scale is None else scale * value

    def loss_function(self, input):
        """Calculates the loss (negative log-likelihood function) of a vector of parameters via simulation.

        A random window of ``batch_size`` consecutive time steps is chosen from
        ``true_data``. The stochastic SIR update is integrated forward from the
        observation just before that window, and the summed squared error
        between the simulated and the observed densities is returned.

        :param input: vector of the learned parameters, indexed as in
            ``self.to_learn``
        :return: scalar tensor, the summed squared error
        :raises ValueError: if ``batch_size`` is not smaller than the number of
            time steps in ``true_data``
        """
        n_steps = self.true_data.shape[0]

        # Valid window starts are 1 .. n_steps - batch_size inclusive: start - 1
        # must index an observation and start:start + batch_size must fit.
        # torch.randint's upper bound is exclusive, hence the + 1.
        max_start = n_steps - self.batch_size
        if max_start < 1:
            raise ValueError(
                f"batch_size ({self.batch_size}) must be smaller than the number "
                f"of time steps in true_data ({n_steps})"
            )
        if max_start == 1:
            # Only one valid window; avoid drawing from the RNG
            start = 1
        else:
            start = torch.randint(1, max_start + 1, (1,)).item()

        # Get the parameters: infection rate, recovery time, noise variance, alpha
        p = self._get_parameter("p_infect", input)
        t = self._get_parameter("t_infectious", input)
        sigma = self._get_parameter("sigma", input)
        alpha = self._get_parameter("alpha", input)

        # Constant term added to every compartment at every step
        offset = 1 / (10000 + alpha)

        densities = [self.true_data[start - 1]]
        for s in range(start, start + self.batch_size - 1):
            # Recovery rate: a steep sigmoid step, ~0 while s < t and ~1/t after
            tau = 1 / t * torch.sigmoid(1000 * (s / t - 1))

            # Random noise
            w = torch.normal(torch.tensor(0.0), torch.tensor(1.0))

            # Solve the ODE. Compartment 1 is the one that drives both infection
            # and recovery; compartment 0 loses what 1 gains from infection, and
            # compartment 2 gains what 1 loses to recovery.
            prev = densities[-1]
            force = p * prev[0] + sigma * w
            step = torch.stack(
                [
                    -force * prev[1] + offset,
                    (force - tau) * prev[1] + offset,
                    tau * prev[1] + offset,
                ]
            )
            densities.append(torch.clip(prev + step, 0, 1))

        densities = torch.stack(densities)

        # NOTE(review): densities[0] is true_data[start - 1] itself, yet it is
        # paired with true_data[start] below, so each simulated state appears
        # to be compared with the observation one step later. Confirm that this
        # offset is intended.
        return torch.nn.functional.mse_loss(
            densities, self.true_data[start : start + self.batch_size], reduction="sum"
        )

    def write_parameters(self):
        """Append the current parameter sample to the ``parameters`` dataset.

        Writes only once ``self.time`` exceeds ``write_start`` and on every
        ``write_every``-th step. Each value is rescaled to model units with the
        same per-parameter factors ``loss_function`` applies (e.g.
        ``t_infectious`` is multiplied by 30), looked up by parameter name so
        the result does not depend on the order or number of entries in
        ``to_learn``.
        """
        if self.time > self.write_start and self.time % self.write_every == 0:
            self.dset_parameters.resize(self.dset_parameters.shape[0] + 1, axis=0)
            values = torch.flatten(self.x[0].detach()).numpy()
            scales = [1.0] * len(values)
            for name, idx in self.to_learn.items():
                scales[idx] = self._PARAMETER_SCALES.get(name, 1.0)
            self.dset_parameters[-1, :] = values * scales