Coverage for models/SIR/DataGeneration.py: 81%

93 statements  

« prev     ^ index     » next       coverage.py v7.6.1, created at 2024-12-05 17:26 +0000

1import logging 

2 

3import h5py as h5 

4import numpy as np 

5import torch 

6 

7from .ABM import SIR_ABM 

8 

# Logger shared by the data-generation functions in this module, named after the module itself.
log = logging.getLogger(name=__name__)

10 

11 

12# --- Data generation functions ------------------------------------------------------------------------------------ 

def generate_data_from_ABM(
    *,
    cfg: dict,
    parameters=None,
    positions=None,
    kinds=None,
    counts=None,
    write_init_state: bool = True,
    **__,
):
    """
    Runs the ABM for ``cfg["num_steps"]`` iterations and returns the time series of its densities
    (``ABM.current_counts / ABM.N``). If datasets are passed, the per-step data is also appended
    to them.

    :param cfg: the data generation configuration settings; passed to ``SIR_ABM`` as keyword
        arguments and must contain ``num_steps``
    :param parameters: (optional) the parameters to use to run the model. Defaults to
        ``[ABM.p_infect, ABM.t_infectious]`` read from the initialised ABM
    :param positions: (optional) resizable dataset to which ``ABM.current_positions`` is appended
        after every step
    :param kinds: (optional) resizable dataset to which ``ABM.current_kinds`` is appended after
        every step
    :param counts: (optional) resizable dataset to which the densities are appended (including
        the initial densities if ``write_init_state`` is set)
    :param write_init_state: whether the initial densities form the first entry of the output
    :return: float tensor of shape ``(num_steps + 1, 3, 1)`` if ``write_init_state`` else
        ``(num_steps, 3, 1)``
    """

    log.info(" Initialising the ABM ... ")

    ABM = SIR_ABM(**cfg)
    num_steps: int = cfg["num_steps"]

    # Row at which the first simulated step is stored: row 0 is reserved for the initial state
    # when it is written out, so every subsequent step is shifted by one.
    offset = 1 if write_init_state else 0
    data = torch.empty((num_steps + offset, 3, 1), dtype=torch.float)

    if write_init_state:
        init_densities = ABM.current_counts.float() / ABM.N
        data[0, :, :] = init_densities

        # Also store the initial state so the counts dataset matches the returned tensor row
        # for row (the stored counts are what gets reloaded as training data later).
        if counts is not None:
            counts.resize(counts.shape[0] + 1, axis=0)
            counts[-1, :] = init_densities

    if parameters is None:
        parameters = torch.tensor([ABM.p_infect, ABM.t_infectious])

    log.info(" Generating synthetic data ... ")
    for step in range(num_steps):
        # Advance the ABM by one step and compute the new densities
        ABM.run_single(parameters=parameters)
        densities = ABM.current_counts.float() / ABM.N

        if positions is not None:
            positions.resize(positions.shape[0] + 1, axis=0)
            positions[-1, :, :] = ABM.current_positions

        if kinds is not None:
            kinds.resize(kinds.shape[0] + 1, axis=0)
            kinds[-1, :] = ABM.current_kinds

        if counts is not None:
            counts.resize(counts.shape[0] + 1, axis=0)
            counts[-1, :] = densities

        # Shifting by `offset` keeps the initial state in row 0 and fills the final row
        data[step + offset] = densities

        log.debug(" Completed run %s of %s ... ", step, num_steps)

    return data

81 

82 

83def generate_smooth_data( 

84 *, 

85 cfg: dict = None, 

86 num_steps: int = None, 

87 parameters=None, 

88 init_state: torch.tensor, 

89 counts=None, 

90 write_init_state: bool = True, 

91 requires_grad: bool = False, 

92 **__, 

93): 

94 """ 

95 Generates a dataset of SIR-counts by iteratively solving the system of differential equations. 

96 """ 

97 

98 num_steps: int = cfg["num_steps"] if num_steps is None else num_steps 

99 data = ( 

100 torch.empty((num_steps, 3, 1), dtype=torch.float) 

101 if not write_init_state 

102 else torch.empty((num_steps + 1, 3, 1), dtype=torch.float) 

103 ) 

104 

105 parameters = ( 

106 torch.tensor( 

107 [cfg["p_infect"], cfg["t_infectious"], cfg["sigma"]], dtype=torch.float 

108 ) 

109 if parameters is None 

110 else parameters 

111 ) 

112 

113 # Write out the initial state if required 

114 if write_init_state: 

115 data[0] = init_state 

116 if counts: 

117 counts.resize(counts.shape[0] + 1, axis=0) 

118 counts[-1, :] = init_state 

119 

120 current_state = init_state.clone() 

121 current_state.requires_grad = requires_grad 

122 

123 for _ in range(num_steps): 

124 # Generate the transformation matrix 

125 # Patients only start recovering after a certain time 

126 w = torch.normal(torch.tensor(0.0), torch.tensor(1.0)) 

127 tau = 1 / parameters[1] * torch.sigmoid(1000 * (_ / parameters[1] - 1)) 

128 matrix = torch.vstack( 

129 [ 

130 torch.tensor([-parameters[0], -parameters[2] * w]), 

131 torch.tensor([parameters[0], -tau + parameters[2] * w]), 

132 torch.tensor([0, tau]), 

133 ] 

134 ) 

135 current_state = torch.clip( 

136 current_state 

137 + torch.matmul( 

138 matrix, 

139 torch.vstack([current_state[0] * current_state[1], current_state[1]]), 

140 ), 

141 0.0, 

142 1.0, 

143 ) 

144 

145 if write_init_state: 

146 data[_ + 1] = current_state 

147 else: 

148 data[_] = current_state 

149 

150 if counts: 

151 counts.resize(counts.shape[0] + 1, axis=0) 

152 counts[-1, :] = current_state 

153 

154 return data 

155 

156 

def _create_true_counts_dataset(h5group: h5.Group, length: int):
    """Creates the resizable ``true_counts`` dataset of shape ``(length, 3, 1)`` in ``h5group``
    and attaches its dimension/coordinate metadata. Returns the new dataset."""
    dset = h5group.create_dataset(
        "true_counts",
        (length, 3, 1),
        maxshape=(None, 3, 1),
        chunks=True,
        compression=3,
        dtype=float,
    )
    dset.attrs["dim_names"] = ["time", "kind", "dim_name__0"]
    dset.attrs["coords_mode__time"] = "trivial"
    dset.attrs["coords_mode__kind"] = "values"
    dset.attrs["coords__kind"] = [
        "susceptible",
        "infected",
        "recovered",
    ]
    dset.attrs["coords_mode__dim_name__0"] = "trivial"
    return dset


def _create_agent_datasets(h5group: h5.Group, N: int):
    """Creates the empty, resizable ``position`` (time x N x 2) and ``kinds`` (time x N) datasets
    used to record the ABM agents, with their metadata. Returns ``(position, kinds)``."""
    dset_position = h5group.create_dataset(
        "position",
        (0, N, 2),
        maxshape=(None, N, 2),
        chunks=True,
        compression=3,
    )
    dset_position.attrs["dim_names"] = ["time", "agent_id", "coords"]
    dset_position.attrs["coords_mode__time"] = "trivial"
    dset_position.attrs["coords_mode__agent_id"] = "trivial"
    dset_position.attrs["coords_mode__coords"] = "values"
    dset_position.attrs["coords__coords"] = ["x", "y"]

    dset_kinds = h5group.create_dataset(
        "kinds",
        (0, N),
        maxshape=(None, N),
        chunks=True,
        compression=3,
    )
    dset_kinds.attrs["dim_names"] = ["time", "agent_id"]
    dset_kinds.attrs["coords_mode__time"] = "trivial"
    dset_kinds.attrs["coords_mode__agent_id"] = "trivial"

    return dset_position, dset_kinds


def get_SIR_data(*, data_cfg: dict, h5group: h5.Group, write_init_state: bool = True):
    """Returns the training data for the SIR model and records it in ``h5group["true_counts"]``.

    If ``data_cfg`` has a ``load_from_dir`` key, its value is the path of an HDF5 file from which
    ``["SIR"]["true_counts"]`` is loaded. In that case ``write_init_state`` has no effect.

    Otherwise ``data_cfg["synthetic_data"]`` configures synthetic data generation. Depending on
    its ``type``, the data is produced either by iteratively solving the temporal ODE system
    (``"smooth"``) or by running the ABM (``"from_ABM"``, which additionally records agent
    positions and kinds).

    :param data_cfg: the data configuration
    :param h5group: the group in which the datasets are created
    :param write_init_state: whether synthetic data includes the initial state as its first entry
    :return: the training data as a float tensor
    :raises ValueError: if neither key is present, or the synthetic data ``type`` is unknown
    """
    if "load_from_dir" in data_cfg:
        with h5.File(data_cfg["load_from_dir"], "r") as f:
            data = np.array(f["SIR"]["true_counts"])

        dset_true_counts = _create_true_counts_dataset(h5group, len(data))
        dset_true_counts[:, :, :] = data

        return torch.from_numpy(data).float()

    if "synthetic_data" not in data_cfg:
        raise ValueError(
            "You must supply one of 'load_from_dir' or 'synthetic_data' keys!"
        )

    synthetic_cfg = data_cfg["synthetic_data"]

    # Starts empty; the generators append one row per recorded step
    dset_true_counts = _create_true_counts_dataset(h5group, 0)

    # --- Generate the data ----------------------------------------------------------------------------------------
    data_type = synthetic_cfg["type"]

    if data_type == "smooth":
        N = synthetic_cfg["N"]
        # A single infected individual in a population of N
        init_state = torch.tensor([[N - 1], [1], [0]], dtype=torch.float) / N
        return generate_smooth_data(
            cfg=synthetic_cfg,
            init_state=init_state,
            counts=dset_true_counts,
            write_init_state=write_init_state,
        )

    if data_type == "from_ABM":
        dset_position, dset_kinds = _create_agent_datasets(h5group, synthetic_cfg["N"])
        return generate_data_from_ABM(
            cfg=synthetic_cfg,
            positions=dset_position,
            kinds=dset_kinds,
            counts=dset_true_counts,
            write_init_state=write_init_state,
        )

    raise ValueError(
        f"Unrecognised argument {data_type}! 'Type' must be one of 'smooth' or 'from_ABM'!"
    )