Coverage for models/SIR/DataGeneration.py: 81%
93 statements
« prev ^ index » next coverage.py v7.6.1, created at 2024-12-05 17:26 +0000
« prev ^ index » next coverage.py v7.6.1, created at 2024-12-05 17:26 +0000
1import logging
3import h5py as h5
4import numpy as np
5import torch
7from .ABM import SIR_ABM
9log = logging.getLogger(__name__)
12# --- Data generation functions ------------------------------------------------------------------------------------
def generate_data_from_ABM(
    *,
    cfg: dict,
    parameters=None,
    positions=None,
    kinds=None,
    counts=None,
    write_init_state: bool = True,
    **__,
):
    """
    Runs the ABM for ``cfg["num_steps"]`` iterations and writes out the data, if datasets are passed.

    :param cfg: the data generation configuration settings. The whole dictionary is passed as keyword
        arguments to ``SIR_ABM``; it must contain the key ``num_steps``
    :param parameters: (optional) the parameters to use to run the model. Defaults to the ABM's own
        ``p_infect`` and ``t_infectious``
    :param positions: (optional) the dataset to write the agent positions to
    :param kinds: (optional) the dataset to write the ABM kinds to
    :param counts: (optional) the dataset to write the ABM counts (as densities) to
    :param write_init_state: whether to store the initial densities as the first entry of the returned tensor
    :param __: additional keyword arguments are accepted and ignored
    :return: a float tensor of shape ``(num_steps + 1, 3, 1)`` if ``write_init_state`` is set, else
        ``(num_steps, 3, 1)``, holding the densities after each step
    """

    log.info(" Initialising the ABM ... ")

    ABM = SIR_ABM(**cfg)
    num_steps: int = cfg["num_steps"]

    # When the initial state occupies row 0, the result of step i belongs in row i + 1.
    offset = 1 if write_init_state else 0
    data = torch.empty((num_steps + offset, 3, 1), dtype=torch.float)

    if write_init_state:
        data[0, :, :] = ABM.current_counts.float() / ABM.N

    if parameters is None:
        parameters = torch.tensor([ABM.p_infect, ABM.t_infectious])

    log.info(" Generating synthetic data ... ")
    for step in range(num_steps):
        # Run the ABM for a single step
        ABM.run_single(parameters=parameters)

        # Get the densities
        densities = ABM.current_counts.float() / ABM.N

        # Each dataset is grown by one entry along the time axis before writing.
        # Explicit None checks: an empty (length-0) container must still be written to.
        if positions is not None:
            positions.resize(positions.shape[0] + 1, axis=0)
            positions[-1, :, :] = ABM.current_positions

        if kinds is not None:
            kinds.resize(kinds.shape[0] + 1, axis=0)
            kinds[-1, :] = ABM.current_kinds

        if counts is not None:
            counts.resize(counts.shape[0] + 1, axis=0)
            counts[-1, :] = densities

        # Append the new counts to the training dataset, without clobbering the initial state
        data[step + offset] = densities

        log.debug(" Completed run %s of %s ... ", step, num_steps)

    return data
def generate_smooth_data(
    *,
    cfg: dict = None,
    num_steps: int = None,
    parameters=None,
    init_state: torch.Tensor,
    counts=None,
    write_init_state: bool = True,
    requires_grad: bool = False,
    **__,
):
    """
    Generates a dataset of SIR-counts by iteratively solving the system of differential equations.

    :param cfg: (optional) configuration dictionary. ``cfg["num_steps"]`` is read if ``num_steps`` is not given;
        ``cfg["p_infect"]``, ``cfg["t_infectious"]`` and ``cfg["sigma"]`` are read if ``parameters`` is not given
    :param num_steps: (optional) the number of iterations; takes precedence over ``cfg["num_steps"]``
    :param parameters: (optional) the three model parameters, in the order (p_infect, t_infectious, sigma).
        May be a tensor or a plain sequence of numbers
    :param init_state: the initial state; it is cloned, so the argument itself is not modified
    :param counts: (optional) the dataset to write the states to; it is grown by one entry per written state
    :param write_init_state: whether to store the initial state as the first entry of the output (and of ``counts``)
    :param requires_grad: value assigned to ``requires_grad`` of the cloned initial state
    :param __: additional keyword arguments are accepted and ignored
    :return: a float tensor of shape ``(num_steps + 1, 3, 1)`` if ``write_init_state`` is set, else
        ``(num_steps, 3, 1)``
    """

    num_steps: int = cfg["num_steps"] if num_steps is None else num_steps

    # When the initial state occupies row 0, the result of step i belongs in row i + 1.
    offset = 1 if write_init_state else 0
    data = torch.empty((num_steps + offset, 3, 1), dtype=torch.float)

    if parameters is None:
        parameters = torch.tensor(
            [cfg["p_infect"], cfg["t_infectious"], cfg["sigma"]], dtype=torch.float
        )
    elif not torch.is_tensor(parameters):
        # Plain sequences of numbers are accepted as well; the torch operations below need tensors
        parameters = torch.as_tensor(parameters, dtype=torch.float)

    # Write out the initial state if required
    if write_init_state:
        data[0] = init_state
        if counts is not None:
            counts.resize(counts.shape[0] + 1, axis=0)
            counts[-1, :] = init_state

    current_state = init_state.clone()
    current_state.requires_grad = requires_grad

    for step in range(num_steps):
        # Generate the transformation matrix
        # Patients only start recovering after a certain time: the steep sigmoid switches
        # tau from (almost) zero to 1 / parameters[1] once step exceeds parameters[1]
        w = torch.normal(torch.tensor(0.0), torch.tensor(1.0))
        tau = 1 / parameters[1] * torch.sigmoid(1000 * (step / parameters[1] - 1))
        matrix = torch.vstack(
            [
                torch.tensor([-parameters[0], -parameters[2] * w]),
                torch.tensor([parameters[0], -tau + parameters[2] * w]),
                torch.tensor([0, tau]),
            ]
        )

        # The matrix acts on the vector (state[0] * state[1], state[1]); the result is kept within [0, 1]
        interaction = torch.vstack(
            [current_state[0] * current_state[1], current_state[1]]
        )
        current_state = torch.clip(
            current_state + torch.matmul(matrix, interaction), 0.0, 1.0
        )

        data[step + offset] = current_state

        # Explicit None check: an empty (length-0) container must still be written to
        if counts is not None:
            counts.resize(counts.shape[0] + 1, axis=0)
            counts[-1, :] = current_state

    return data
def _create_true_counts_dataset(h5group, num_entries: int):
    """Creates the resizable ``true_counts`` dataset in ``h5group`` and labels its dimensions."""
    dset = h5group.create_dataset(
        "true_counts",
        (num_entries, 3, 1),
        maxshape=(None, 3, 1),
        chunks=True,
        compression=3,
        dtype=float,
    )
    dset.attrs["dim_names"] = ["time", "kind", "dim_name__0"]
    dset.attrs["coords_mode__time"] = "trivial"
    dset.attrs["coords_mode__kind"] = "values"
    dset.attrs["coords__kind"] = ["susceptible", "infected", "recovered"]
    dset.attrs["coords_mode__dim_name__0"] = "trivial"
    return dset


def _create_agent_datasets(h5group, N: int):
    """Creates the empty, resizable ``position`` and ``kinds`` datasets for ``N`` agents in ``h5group``."""
    # Initialise agent position dataset
    dset_position = h5group.create_dataset(
        "position",
        (0, N, 2),
        maxshape=(None, N, 2),
        chunks=True,
        compression=3,
    )
    dset_position.attrs["dim_names"] = ["time", "agent_id", "coords"]
    dset_position.attrs["coords_mode__time"] = "trivial"
    dset_position.attrs["coords_mode__agent_id"] = "trivial"
    dset_position.attrs["coords_mode__coords"] = "values"
    dset_position.attrs["coords__coords"] = ["x", "y"]

    # Initialise agent kind dataset
    dset_kinds = h5group.create_dataset(
        "kinds",
        (0, N),
        maxshape=(None, N),
        chunks=True,
        compression=3,
    )
    dset_kinds.attrs["dim_names"] = ["time", "agent_id"]
    dset_kinds.attrs["coords_mode__time"] = "trivial"
    dset_kinds.attrs["coords_mode__agent_id"] = "trivial"

    return dset_position, dset_kinds


def get_SIR_data(*, data_cfg: dict, h5group: h5.Group, write_init_state: bool = True):
    """Returns the training data for the SIR model. If a directory is passed, the
    data is loaded from that directory. Otherwise, synthetic training data is generated, either from an ABM,
    or by iteratively solving the temporal ODE system.

    :param data_cfg: the data configuration. Must contain either the key ``load_from_dir`` (path of an h5 file
        holding the dataset ``SIR/true_counts``) or the key ``synthetic_data`` (a dictionary with the entries
        ``type``, either ``smooth`` or ``from_ABM``, and ``N``, plus the settings of the chosen generator).
        ``load_from_dir`` takes precedence if both are present
    :param h5group: the h5 group to which the ``true_counts`` dataset (and, for ABM data, the ``position`` and
        ``kinds`` datasets) are written
    :param write_init_state: whether the synthetic data should include the initial state; not used when
        loading data
    :return: the training data as a torch tensor
    :raises ValueError: if neither key is present, or if the synthetic data type is not recognised
    """
    if "load_from_dir" in data_cfg:
        with h5.File(data_cfg["load_from_dir"], "r") as f:
            data = np.array(f["SIR"]["true_counts"])

        # Copy the loaded data into the output group
        dset_true_counts = _create_true_counts_dataset(h5group, len(data))
        dset_true_counts[:, :, :] = data

        return torch.from_numpy(data).float()

    if "synthetic_data" in data_cfg:
        synthetic_cfg = data_cfg["synthetic_data"]

        # True counts: created empty, the generators grow the dataset entry by entry
        dset_true_counts = _create_true_counts_dataset(h5group, 0)

        # --- Generate the data ----------------------------------------------------------------------------------------
        data_type = synthetic_cfg["type"]

        if data_type == "smooth":
            N = synthetic_cfg["N"]

            # Initial state: a single infected agent, expressed as densities
            init_state = torch.tensor([[N - 1], [1], [0]], dtype=torch.float) / N
            return generate_smooth_data(
                cfg=synthetic_cfg,
                init_state=init_state,
                counts=dset_true_counts,
                write_init_state=write_init_state,
            )

        if data_type == "from_ABM":
            dset_position, dset_kinds = _create_agent_datasets(
                h5group, synthetic_cfg["N"]
            )
            return generate_data_from_ABM(
                cfg=synthetic_cfg,
                positions=dset_position,
                kinds=dset_kinds,
                counts=dset_true_counts,
                write_init_state=write_init_state,
            )

        raise ValueError(
            f"Unrecognised argument {data_type}! 'type' must be one of 'smooth' or 'from_ABM'!"
        )

    raise ValueError(
        "You must supply one of the 'load_from_dir' or 'synthetic_data' keys!"
    )