Coverage for models / Covid / ensemble_training / DataGeneration.py: 11%

71 statements  

« prev     ^ index     » next       coverage.py v7.13.1, created at 2026-01-21 16:26 +0000

1import logging 

2 

3import h5py as h5 

4import numpy as np 

5import torch 

6 

7from .kinds import Compartments 

8 

# Logger for this module (named after the module's import path).
log = logging.getLogger(name=__name__)

10 

def _resolve_parameter(cfg, parameters, key: str, index: int) -> torch.Tensor:
    """Return one model rate parameter, taken either from ``cfg`` or from an override.

    :param cfg: configuration mapping; ``cfg[key]`` is only read when ``parameters`` is ``None``
    :param parameters: (optional) override container, indexed with ``index``
    :param key: name of the parameter in ``cfg`` (e.g. ``"k_S"``)
    :param index: position of the parameter inside ``parameters``
    :return: ``torch.Tensor``; when read from ``cfg`` it is float32, 0-dim for a scalar
        and 1-dim for a sequence (time-dependent parameter)
    """
    if parameters is None:
        return torch.tensor(cfg[key], dtype=torch.float)
    return parameters[index]


def _value_at(param: torch.Tensor, t: int) -> torch.Tensor:
    """Return ``param[t]`` if ``param`` is time-dependent (``dim() > 0``), else ``param`` itself."""
    return param[t] if param.dim() > 0 else param


def generate_smooth_data(
    cfg, *, parameters=None, init_state: torch.Tensor = None
) -> torch.Tensor:
    """Generates a dataset of counts for each compartment by iteratively solving the system of differential equations.

    The system is integrated with a forward Euler step of size ``dt``; after every
    step the state is clipped to ``[0, 1]``.

    :param cfg: configuration file, containing parameter values (possibly as a ``Sequence``, if time-dependent),
        number of steps, burn-in period, etc. Keys read: ``num_steps``, ``dt``, ``burn_in`` (default 0),
        ``k_q`` (default 10.25) and, unless ``parameters`` is given, ``k_S``, ``k_E``, ``k_I``, ``k_R``,
        ``k_SY``, ``k_H``, ``k_C``, ``k_D``, ``k_CT``
    :param parameters: (optional) parameters used to override cfg settings
    :param init_state: (optional) initial state to use; defaults to a generic density if ``None``.
        Written into a ``(12, 1)`` slot, so it must be assignable to that shape
    :return: ``torch.Tensor`` training dataset of shape ``(num_steps, 12, 1)``, with the burn-in period discarded
    :raises ValueError: if ``num_steps + burn_in`` is smaller than 1 (there would be no row for the initial state)
    """

    # Get config settings
    num_steps: int = cfg["num_steps"]
    burn_in: int = cfg.get("burn_in", 0)
    dt: float = cfg["dt"]
    k_q: float = cfg.get("k_q", 10.25)

    n_total = num_steps + burn_in
    if n_total < 1:
        # The initial state is always written to row 0; fail clearly rather than
        # with an IndexError from an empty buffer.
        raise ValueError(
            f"num_steps + burn_in must be at least 1, got {num_steps} + {burn_in}"
        )

    # Use a generic initial state if None passed
    if init_state is None:
        init_state = torch.zeros(12, 1, dtype=torch.float)
        # High number of susceptible agents
        init_state[Compartments.susceptible.value] = 0.9933
        # Some infected agents: the remainder of the unit density
        init_state[Compartments.infected.value] = (
            1.0 - init_state[Compartments.susceptible.value]
        )

    # Empty dataset for counts: the initial state is always written
    data = torch.empty((n_total, 12, 1), dtype=torch.float)
    data[0, :] = init_state

    # Get the model parameters; these can be overridden with the ``parameters`` argument.
    # NOTE(review): overrides are indexed by the index of the matching *compartment*
    # (e.g. k_CT -> Compartments.contact_traced), not by a dense 0..8 parameter index;
    # confirm this matches the layout of ``parameters`` produced by callers.
    k_S = _resolve_parameter(cfg, parameters, "k_S", Compartments.susceptible.value)
    k_E = _resolve_parameter(cfg, parameters, "k_E", Compartments.exposed.value)
    k_I = _resolve_parameter(cfg, parameters, "k_I", Compartments.infected.value)
    k_R = _resolve_parameter(cfg, parameters, "k_R", Compartments.recovered.value)
    k_SY = _resolve_parameter(cfg, parameters, "k_SY", Compartments.symptomatic.value)
    k_H = _resolve_parameter(cfg, parameters, "k_H", Compartments.hospitalized.value)
    k_C = _resolve_parameter(cfg, parameters, "k_C", Compartments.critical.value)
    k_D = _resolve_parameter(cfg, parameters, "k_D", Compartments.deceased.value)
    k_CT = _resolve_parameter(
        cfg, parameters, "k_CT", Compartments.contact_traced.value
    )

    # Solve the ODE
    for t in range(1, n_total):
        # Get the time-dependent parameters, if given
        k_S_t = _value_at(k_S, t)
        k_E_t = _value_at(k_E, t)
        k_I_t = _value_at(k_I, t)
        k_R_t = _value_at(k_R, t)
        k_SY_t = _value_at(k_SY, t)
        k_H_t = _value_at(k_H, t)
        k_C_t = _value_at(k_C, t)
        k_D_t = _value_at(k_D, t)
        k_CT_t = _value_at(k_CT, t)

        # Previous state, split into the compartments the right-hand side reads
        prev = data[t - 1]
        y_S = prev[Compartments.susceptible.value]
        y_E = prev[Compartments.exposed.value]
        y_I = prev[Compartments.infected.value]
        y_SY = prev[Compartments.symptomatic.value]
        y_H = prev[Compartments.hospitalized.value]
        y_C = prev[Compartments.critical.value]
        y_QS = prev[Compartments.quarantine_S.value]
        y_QE = prev[Compartments.quarantine_E.value]
        y_QI = prev[Compartments.quarantine_I.value]
        y_CT = prev[Compartments.contact_traced.value]

        # Quarantine rate, proportional to the contact-traced compartment
        k_Q_t = k_q * k_CT_t * y_CT

        # Right-hand side, one entry per compartment in the order
        # S, E, I, R, SY, H, C, D, Q_S, Q_E, Q_I, CT.
        # NOTE(review): this order must agree with the Compartments enum values; confirm in kinds.py.
        dy = torch.stack(
            [
                (-k_E_t * y_I - k_Q_t) * y_S + k_S_t * y_QS,
                k_E_t * y_S * y_I - (k_I_t + k_Q_t) * y_E,
                k_I_t * y_E - (k_R_t + k_SY_t + k_Q_t) * y_I,
                k_R_t * (y_I + y_SY + y_H + y_C + y_QI),
                k_SY_t * (y_I + y_QI) - (k_R_t + k_H_t) * y_SY,
                k_H_t * y_SY - (k_R_t + k_C_t) * y_H,
                k_C_t * y_H - (k_R_t + k_D_t) * y_C,
                k_D_t * y_C,
                -k_S_t * y_QS + k_Q_t * y_S,
                -k_I_t * y_QE + k_Q_t * y_E,
                k_I_t * y_QE + k_Q_t * y_I - (k_SY_t + k_R_t) * y_QI,
                # prev[0:3] sums the first three compartments (presumably S, E, I; confirm enum order)
                k_SY_t * y_I - k_q * torch.sum(prev[0:3]) * y_CT,
            ]
        )

        # Forward Euler step, clipped to [0, 1]
        data[t, :] = torch.clip(prev + dy * dt, 0, 1)

    # Return the data, discarding the burn-in, if specified
    return data[burn_in:]

154 

155 

def _expand_time_dependent(
    values, intervals, num_steps: int, burn_in: int
) -> np.ndarray:
    """Expand piecewise-constant parameter values into one value per time step.

    Interval ``j`` is filled with ``values[j]`` from where the previous interval
    stopped up to (excluding) step ``upper_j + burn_in``; a falsy upper bound
    (e.g. ``None``) means ``num_steps``. Lower bounds are ignored, so the first
    interval also covers the burn-in period. Steps not reached by any interval
    stay 0; upper bounds beyond the end of the run are clipped to its length.

    :param values: parameter values, one per interval
    :param intervals: sequence of ``(lower, upper)`` pairs
    :param num_steps: number of steps after the burn-in
    :param burn_in: number of burn-in steps at the start of the sequence
    :return: ``np.ndarray`` of length ``num_steps + burn_in``
    """
    total = num_steps + burn_in
    seq = np.zeros(total)
    start = 0
    for j, (_, upper) in enumerate(intervals):
        if not upper:
            upper = num_steps
        # Clip so an upper bound past the end cannot index beyond ``seq``
        stop = min(upper + burn_in, total)
        if stop > start:
            seq[start:stop] = values[j]
            start = stop
    return seq


def get_data(data_cfg: dict, h5group: h5.Group) -> torch.Tensor:
    """Returns the training data for the Covid model.

    If ``load_from_dir`` is given, the data is loaded from the HDF5 file at that path
    (dataset ``Covid/true_counts``). Otherwise, if ``synthetic_data`` is given, synthetic
    training data is generated by iteratively solving the ODE system. Time-dependent
    parameters listed under ``time_dependent_parameters`` (name -> list of ``(lower, upper)``
    intervals) are expanded into per-step sequences first. The caller's ``data_cfg`` is not
    modified.

    In both cases the data is also written to ``h5group`` as dataset ``true_counts``.

    :param data_cfg: configuration file
    :param h5group: hdf5.group to which to write the data
    :return: ``torch.Tensor`` training data
    :raises ValueError: if neither ``load_from_dir`` nor ``synthetic_data`` is present
    """

    # Load training data from file
    if "load_from_dir" in data_cfg:
        log.info(" Loading training data ...")
        # Load training data from hdf5 file
        with h5.File(data_cfg["load_from_dir"], "r") as f:
            training_data = torch.from_numpy(
                np.array(f["Covid"]["true_counts"])
            ).float()

    # Generate synthetic data
    elif "synthetic_data" in data_cfg:
        log.info(" Generating training data ...")
        # Get the time dependent parameters: names and intervals
        time_dependent_params: dict = data_cfg.get("time_dependent_parameters", {})
        num_steps: int = data_cfg["synthetic_data"]["num_steps"]
        burn_in: int = data_cfg["synthetic_data"].get("burn_in", 0)

        # Work on a shallow copy: writing the expanded sequences back into the caller's
        # config would make a second call read per-step arrays as per-interval values.
        sim_cfg = dict(data_cfg["synthetic_data"])
        for key, intervals in time_dependent_params.items():
            sim_cfg[key] = _expand_time_dependent(
                sim_cfg[key], intervals, num_steps, burn_in
            )

        # Generate training data by integrating the model equations
        training_data = generate_smooth_data(sim_cfg)

    else:
        raise ValueError(
            "You must supply one of 'load_from_dir' or 'synthetic_data' keys!"
        )

    # Save training data to hdf5 dataset and return
    dset_true_counts = h5group.create_dataset(
        "true_counts",
        training_data.shape,
        maxshape=training_data.shape,
        chunks=True,
        compression=3,
        dtype=float,
    )

    # Dimension/coordinate metadata: (time, kind, dim_name__0), with the
    # compartment names as coordinates of the "kind" axis
    dset_true_counts.attrs["dim_names"] = ["time", "kind", "dim_name__0"]
    dset_true_counts.attrs["coords_mode__time"] = "trivial"
    dset_true_counts.attrs["coords_mode__kind"] = "values"
    dset_true_counts.attrs["coords__kind"] = [k.name for k in Compartments]
    dset_true_counts.attrs["coords_mode__dim_name__0"] = "trivial"

    dset_true_counts[:, :, :] = training_data

    return training_data