Coverage for model_plots / Covid / trajectories_from_densities.py: 25%
73 statements
« prev ^ index » next coverage.py v7.13.1, created at 2026-01-21 16:26 +0000
1import logging
2from typing import Union
4import numpy as np
5import torch
6import xarray as xr
7from tqdm import trange
8from tqdm.contrib.logging import logging_redirect_tqdm
10from models.Covid import generate_smooth_data
11from utopya.eval import is_operation
13log = logging.getLogger(__name__)
15from ..SIR.trajectories_from_densities import _adjust_for_time_dependency
def _combine_compartments(
    da: Union[xr.DataArray, xr.Dataset],
    combine: dict,
) -> Union[xr.DataArray, xr.Dataset]:
    """Sums groups of compartments into aggregate compartments.

    For each ``new_kind -> [members]`` entry of ``combine``, the member
    compartments are summed along ``kind``, appended under ``new_kind``,
    and the original members are removed.
    """
    for new_kind, members in combine.items():
        aggregate = (
            da.sel({"kind": members}).sum("kind").expand_dims({"kind": [new_kind]})
        )
        da = xr.concat([da, aggregate], dim="kind").drop_sel({"kind": members})

    return da
def _drop_compartments(
    da: Union[xr.DataArray, xr.Dataset],
    drop: list,
) -> Union[xr.DataArray, xr.Dataset]:
    """Removes the listed compartments from the ``kind`` dimension."""

    return da.drop_sel(dict(kind=drop))
def _calculate_residuals(da: xr.DataArray) -> xr.DataArray:
    """Calculates the relative residuals between the predictions and the true data.

    For both the mode and the mean prediction, the residual is
    ``(prediction - true) / true`` per time step and compartment.

    :param da: DataArray holding ``prediction_mode``, ``prediction_mean`` and
        ``true_counts`` entries along its ``type`` dimension
    :return: DataArray of ``mode_residual`` and ``mean_residual`` with
        divisions by zero masked to NaN
    """
    # Compute the shared denominator/reference only once
    true_data = da.sel({"type": "true_counts"}, drop=True)

    residuals = xr.DataArray(
        data=[
            (da.sel({"type": "prediction_mode"}, drop=True) - true_data) / true_data,
            (da.sel({"type": "prediction_mean"}, drop=True) - true_data) / true_data,
        ],
        dims=["type", "time", "kind"],
        coords=dict(
            type=["mode_residual", "mean_residual"],
            kind=da.coords["kind"],
            time=da.coords["time"],
        ),
    )

    # Mask divisions by zero. The previous check (``!= np.inf``) only caught
    # +inf; a negative difference over a zero true count yields -inf, which
    # would otherwise poison downstream means. NaNs pass through unchanged.
    return xr.where(np.isfinite(residuals), residuals, np.nan)
@is_operation("Covid_densities_residuals")
def print_residuals(data, train_cut: int = 200):
    """Prints a summary of the residuals for each compartment"""

    # Squared relative residuals; the square root of their time-average is
    # the L2 error per compartment.
    sq_residuals = _calculate_residuals(data) ** 2

    intervals = {"train": slice(None, train_cut), "test": slice(train_cut, None)}

    for key, window in intervals.items():
        log.remark("------------------------------------------------------")
        log.remark(f"L2 residuals in {key}")
        l2 = np.sqrt(sq_residuals.isel({"time": window}).mean("time", skipna=True))
        for k in sq_residuals.coords["kind"]:
            mean_err = np.around(
                l2.sel({"kind": k, "type": "mean_residual"}).data.item(), 5
            )
            mode_err = np.around(
                l2.sel({"kind": k, "type": "mode_residual"}).data.item(), 5
            )
            log.remark(f"    {k.item().capitalize()}: {mean_err} (mean), {mode_err} (mode)")
        avg_mean = np.around(
            l2.sel({"type": "mean_residual"}).mean("kind").data.item(), 5
        )
        avg_mode = np.around(
            l2.sel({"type": "mode_residual"}).mean("kind").data.item(), 5
        )
        log.remark(f"    Average: {avg_mean} (mean), {avg_mode} (mode)")

    return data
@is_operation("Covid_densities_from_joint")
def densities_from_joint(
    parameters: xr.Dataset,
    prob: xr.Dataset,
    *,
    true_counts: xr.Dataset,
    cfg: dict,
    combine: dict = None,
    drop: list = None,
    mean: xr.Dataset = None,
) -> xr.Dataset:
    """Runs the model with the estimated parameters, given in an ``xr.Dataset`` ``parameters``,
    and weights each time series with its corresponding probability, given ``prob``.
    The probabilities must be normalised to 1.

    :param parameters: the ``xr.Dataset`` of parameter estimates, indexed by the sample dimension
    :param prob: the xr.Dataset of probabilities associated with each estimate, indexed by sample
    :param true_counts: the xr.Dataset of true counts
    :param cfg: the run configuration of the original data
    :param combine: (optional) dictionary of compartments to combine by summing
    :param drop: (optional) list of compartments to drop
    :param mean: (optional) can pass the mean dataset instead of calculating it; needed for MCMC calculations
    :return: an ``xr.Dataset`` of the mean, mode, std, and true densities (if given) for all compartments
    """

    # Name of the sample dimension (first coordinate of ``prob``)
    sample_dim: str = list(prob.coords.keys())[0]

    # Sample configuration. Copy the nested dict so that updating it with
    # per-sample parameter values below does not mutate the caller's ``cfg``.
    sample_cfg = dict(cfg["Data"]["synthetic_data"])
    sample_cfg["num_steps"] = len(true_counts.coords["time"])
    sample_cfg["burn_in"] = 0

    n_samples = len(parameters.coords[sample_dim])

    # Initial state shared by all runs: the first true time step
    init_state = torch.from_numpy(
        true_counts.isel({"time": 0}, drop=True).data
    ).float()

    res = []

    with logging_redirect_tqdm():
        pbar = trange(n_samples)
        # The description is constant, so set it once instead of per iteration
        pbar.set_description(
            f"Drawing {n_samples} samples from joint distribution: "
        )
        for s in pbar:
            # Construct the configuration, taking time-dependent parameters into account
            sample = parameters.isel({sample_dim: s}, drop=True)
            sample_cfg.update(
                {
                    p.item(): sample.sel({"parameter": p}).item()
                    for p in sample.coords["parameter"]
                }
            )
            param_cfg = _adjust_for_time_dependency(sample_cfg, cfg, true_counts)

            # Generate smooth data for this parameter sample
            generated_data = generate_smooth_data(
                cfg=param_cfg,
                init_state=init_state,
            ).numpy()

            res.append(
                xr.DataArray(
                    data=[[generated_data]],
                    dims=[sample_dim, "type", "time", "kind", "dim_name__0"],
                    coords=dict(
                        **{sample_dim: [s]},
                        type=["prediction_mean"],
                        **true_counts.coords,
                    ),
                ).squeeze(["dim_name__0"], drop=True)
            )

    # Concatenate all the time series
    res = xr.concat(res, dim=sample_dim)

    # Get the index of the most likely parameter estimate
    mode_idx = prob.argmax(dim=sample_dim)
    sample_cfg.update(
        {
            p.item(): parameters.isel({sample_dim: mode_idx}, drop=True)
            .sel({"parameter": p})
            .item()
            for p in parameters.coords["parameter"]
        }
    )

    # Perform a run using the mode
    mode_params = _adjust_for_time_dependency(sample_cfg, cfg, true_counts)
    mode_data = generate_smooth_data(
        cfg=mode_params,
        init_state=init_state,
    ).numpy()

    mode_data = xr.DataArray(
        data=[mode_data],
        dims=["type", "time", "kind", "dim_name__0"],
        coords=dict(type=["prediction_mode"], **true_counts.coords),
    ).squeeze(["dim_name__0"], drop=True)

    # Combine compartments, if given
    if combine:
        res = _combine_compartments(res, combine)
        mode_data = _combine_compartments(mode_data, combine)
        true_counts = _combine_compartments(true_counts, combine)

    # Drop compartments, if given
    if drop:
        res = _drop_compartments(res, drop)
        mode_data = _drop_compartments(mode_data, drop)
        true_counts = _drop_compartments(true_counts, drop)

    # Reshape the probability array to (sample, 1, 1, 1); numpy broadcasting
    # against the (sample, type, time, kind) predictions makes the previous
    # explicit ``np.repeat`` tiling (including a no-op repeat along the type
    # axis) unnecessary.
    prob = np.reshape(prob.data, (len(prob.coords[sample_dim]), 1, 1, 1))

    # Calculate the probability-weighted mean and standard deviation of the
    # predictions (unless a precomputed mean was passed in, e.g. for MCMC)
    mean = (res * prob).sum(sample_dim) if mean is None else mean
    std = np.sqrt(((res - mean) ** 2 * prob).sum(sample_dim))

    # Add a type to the true counts to allow concatenation
    true_counts = true_counts.expand_dims({"type": ["true_counts"]}).squeeze(
        "dim_name__0", drop=True
    )

    # True counts and mode run carry zero standard deviation by construction
    mean = xr.concat([mean, true_counts, mode_data], dim="type")
    std = xr.concat([std, 0 * true_counts, 0 * mode_data], dim="type")

    return xr.Dataset(data_vars=dict(mean=mean, std=std))