Coverage for models / Covid / ensemble_training / DataGeneration.py: 11%
71 statements
« prev ^ index » next coverage.py v7.13.1, created at 2026-01-21 16:26 +0000
1import logging
3import h5py as h5
4import numpy as np
5import torch
7from .kinds import Compartments
9log = logging.getLogger(__name__)
def generate_smooth_data(
    cfg, *, parameters=None, init_state: torch.Tensor = None) -> torch.Tensor:
    """Generates a dataset of counts for each compartment by iteratively solving the system of differential equations.

    :param cfg: configuration file, containing parameter values (possibly as a ``Sequence``, if time-dependent),
        number of steps, burn-in period, etc.
    :param parameters: (optional) parameters used to override cfg settings; when given, entries are
        indexed by ``Compartments.<name>.value`` and are expected to be ``torch.Tensor``s
        (scalar, or 1-d of length ``num_steps + burn_in`` for time-dependent rates)
    :param init_state: (optional) initial state of shape ``(12, 1)``; defaults to a generic density if ``None``
    :return: ``torch.Tensor`` training dataset of shape ``(num_steps, 12, 1)``, with the burn-in period discarded
    """

    # Get config settings
    num_steps: int = cfg["num_steps"]
    burn_in: int = cfg.get("burn_in", 0)
    dt: float = cfg["dt"]
    k_q: float = cfg.get("k_q", 10.25)

    # Use a generic initial state if None passed
    if init_state is None:
        init_state = torch.zeros(12, 1, dtype=torch.float)
        init_state[
            Compartments.susceptible.value
        ] = 0.9933  # High number of susceptible agents
        init_state[Compartments.infected.value] = (
            1.0 - init_state[Compartments.susceptible.value]
        )  # Some infected agents

    # Empty dataset for counts: the initial state is always written
    data = torch.empty((num_steps + burn_in, 12, 1), dtype=torch.float)
    data[0, :] = init_state

    def _rate(key: str, compartment) -> torch.Tensor:
        """Fetch one rate constant: the cfg entry, unless overridden via ``parameters``."""
        if parameters is None:
            return torch.tensor(cfg[key], dtype=torch.float)
        return parameters[compartment.value]

    # Get the model parameters; these can be overridden with the ``parameters`` argument
    k_S = _rate("k_S", Compartments.susceptible)
    k_E = _rate("k_E", Compartments.exposed)
    k_I = _rate("k_I", Compartments.infected)
    k_R = _rate("k_R", Compartments.recovered)
    k_SY = _rate("k_SY", Compartments.symptomatic)
    k_H = _rate("k_H", Compartments.hospitalized)
    k_C = _rate("k_C", Compartments.critical)
    k_D = _rate("k_D", Compartments.deceased)
    k_CT = _rate("k_CT", Compartments.contact_traced)

    def _at(k: torch.Tensor, t: int) -> torch.Tensor:
        """Select the value of a rate at time ``t``: 1-d rates are time-dependent, 0-d apply always."""
        return k[t] if k.dim() > 0 else k

    # Solve the ODE
    for t in range(1, num_steps + burn_in):
        # Get the time-dependent parameters, if given
        k_S_t = _at(k_S, t)
        k_E_t = _at(k_E, t)
        k_I_t = _at(k_I, t)
        k_R_t = _at(k_R, t)
        k_SY_t = _at(k_SY, t)
        k_H_t = _at(k_H, t)
        k_C_t = _at(k_C, t)
        k_D_t = _at(k_D, t)
        k_CT_t = _at(k_CT, t)

        # Previous state, hoisted out of the many per-compartment reads below
        prev = data[t - 1]

        # Calculate k_Q (quarantine rate, driven by the contact-traced density)
        k_Q_t = k_q * k_CT_t * prev[Compartments.contact_traced.value]

        # Right-hand side of the ODE system, one entry per compartment in enum order
        dy = torch.stack(
            [
                (-k_E_t * prev[Compartments.infected.value] - k_Q_t)
                * prev[Compartments.susceptible.value]
                + k_S_t * prev[Compartments.quarantine_S.value],
                k_E_t
                * prev[Compartments.susceptible.value]
                * prev[Compartments.infected.value]
                - (k_I_t + k_Q_t) * prev[Compartments.exposed.value],
                k_I_t * prev[Compartments.exposed.value]
                - (k_R_t + k_SY_t + k_Q_t) * prev[Compartments.infected.value],
                k_R_t
                * (
                    prev[Compartments.infected.value]
                    + prev[Compartments.symptomatic.value]
                    + prev[Compartments.hospitalized.value]
                    + prev[Compartments.critical.value]
                    + prev[Compartments.quarantine_I.value]
                ),
                k_SY_t
                * (
                    prev[Compartments.infected.value]
                    + prev[Compartments.quarantine_I.value]
                )
                - (k_R_t + k_H_t) * prev[Compartments.symptomatic.value],
                k_H_t * prev[Compartments.symptomatic.value]
                - (k_R_t + k_C_t) * prev[Compartments.hospitalized.value],
                k_C_t * prev[Compartments.hospitalized.value]
                - (k_R_t + k_D_t) * prev[Compartments.critical.value],
                k_D_t * prev[Compartments.critical.value],
                -k_S_t * prev[Compartments.quarantine_S.value]
                + k_Q_t * prev[Compartments.susceptible.value],
                -k_I_t * prev[Compartments.quarantine_E.value]
                + k_Q_t * prev[Compartments.exposed.value],
                k_I_t * prev[Compartments.quarantine_E.value]
                + k_Q_t * prev[Compartments.infected.value]
                - (k_SY_t + k_R_t) * prev[Compartments.quarantine_I.value],
                k_SY_t * prev[Compartments.infected.value]
                - k_q
                # NOTE(review): prev[0:3] presumably covers S, E, I — depends on the
                # ordering of the ``Compartments`` enum; confirm against .kinds
                * torch.sum(prev[0:3])
                * prev[Compartments.contact_traced.value],
            ]
        )

        # Solve the ODE (simple forward Euler); densities are clipped to [0, 1]
        data[t, :] = torch.clip(prev + dy * dt, 0, 1)

    # Return the data, discarding the burn-in, if specified
    return data[burn_in:]
def get_data(data_cfg: dict, h5group: h5.Group) -> torch.Tensor:
    """Returns the training data for the Covid model. If a directory is passed, the data is loaded from that directory.
    Otherwise, synthetic training data is generated by iteratively solving the ODE system.

    :param data_cfg: configuration dict; must contain either a ``load_from_dir`` key (path to an
        hdf5 file with a ``Covid/true_counts`` dataset) or a ``synthetic_data`` key (passed on
        to :py:func:`generate_smooth_data`)
    :param h5group: hdf5 group to which to write the data
    :return: ``torch.Tensor`` training data
    :raises ValueError: if neither ``load_from_dir`` nor ``synthetic_data`` is given
    """

    # Load training data from file
    if "load_from_dir" in data_cfg:
        log.info(" Loading training data ...")

        # Load training data from hdf5 file
        with h5.File(data_cfg["load_from_dir"], "r") as f:
            training_data = torch.from_numpy(
                np.array(f["Covid"]["true_counts"])
            ).float()

    # Generate synthetic data
    elif "synthetic_data" in data_cfg:
        log.info(" Generating training data ...")

        # Get the time dependent parameters: names and intervals
        time_dependent_params: dict = data_cfg.get("time_dependent_parameters", {})
        num_steps: int = data_cfg["synthetic_data"]["num_steps"]
        burn_in: int = data_cfg["synthetic_data"].get("burn_in", 0)

        # Replace any time-dependent parameters with a per-step sequence; note that this
        # mutates ``data_cfg["synthetic_data"]`` in place
        for key, intervals in time_dependent_params.items():
            p = np.zeros(num_steps + burn_in)
            i = 0
            for j, interval in enumerate(intervals):
                _, upper = interval
                # NOTE(review): a falsy upper bound (None or 0) means "until the end",
                # so an explicit upper bound of 0 is not representable — confirm intent
                if not upper:
                    upper = num_steps
                # Fill the j-th parameter value up to the (burn-in-shifted) interval end
                fill_to = upper + burn_in
                p[i:fill_to] = data_cfg["synthetic_data"][key][j]
                i = max(i, fill_to)
            data_cfg["synthetic_data"][key] = p

        # Generate training data by integrating the model equations
        training_data = generate_smooth_data(data_cfg["synthetic_data"])

    else:
        raise ValueError(
            "You must supply one of 'load_from_dir' or 'synthetic_data' keys!"
        )

    # Save training data to hdf5 dataset and return
    dset_true_counts = h5group.create_dataset(
        "true_counts",
        training_data.shape,
        maxshape=training_data.shape,
        chunks=True,
        compression=3,  # integer compression => gzip at level 3
        dtype=float,
    )

    dset_true_counts.attrs["dim_names"] = ["time", "kind", "dim_name__0"]
    dset_true_counts.attrs["coords_mode__time"] = "trivial"
    dset_true_counts.attrs["coords_mode__kind"] = "values"
    dset_true_counts.attrs["coords__kind"] = [k.name for k in Compartments]
    dset_true_counts.attrs["coords_mode__dim_name__0"] = "trivial"

    dset_true_counts[:, :, :] = training_data

    return training_data