Coverage for model_plots / SIR / trajectories_from_densities.py: 21%
56 statements
« prev ^ index » next coverage.py v7.13.1, created at 2026-01-21 16:26 +0000
1import logging
2import numpy as np
3import torch
4import xarray as xr
5from tqdm import trange
6from tqdm.contrib.logging import logging_redirect_tqdm
8from models.SIR import generate_smooth_data
9from utopya.eval import is_operation
11log = logging.getLogger(__name__)
13def _adjust_for_time_dependency(
14 param_cfg: dict, cfg: dict, true_counts: xr.Dataset
15) -> dict:
16 """Adjusts the parameter configuration for time dependent parameters, if given."""
18 time_dependent_parameters = cfg["Data"].get("time_dependent_parameters", {})
20 # Extend any time-dependent parameters, if necessary
21 for param in time_dependent_parameters.keys():
22 val = np.zeros(len(true_counts.coords["time"]))
23 i = 0
24 # Replace any time-dependent parameters with a series
25 for j, interval in enumerate(time_dependent_parameters[param]):
26 _, upper = interval
27 if not upper:
28 upper = len(val)
29 while i < upper:
30 val[i] = param_cfg[str(param) + f"_{j}"]
31 i += 1
32 param_cfg[param] = val
34 return param_cfg
@is_operation("SIR_densities_from_joint")
def densities_from_joint(
    joint: xr.Dataset,
    *,
    true_counts: xr.Dataset,
    cfg: dict,
) -> xr.Dataset:
    """Runs the SIR ODE model with the estimated parameters, given in the xr.Dataset, and weights each time series with
    its corresponding probability. The probabilities must be normalised to 1.

    :param joint: the ``xr.Dataset`` of the joint parameter distribution
    :param true_counts: the xr.Dataset of true counts
    :param cfg: the run configuration of the original data (only needed if parameters are time dependent)
    :return: an ``xr.Dataset`` of the mean, mode, std, and true densities for all compartments
    """
    # Stack the parameters into a single coordinate
    joint = joint.stack(sample=list(joint.coords))

    # Remove parameters with probability 0
    joint = joint.where(joint > 0, drop=True)

    # Normalise the probabilities to 1 (this is not the same as integrating over the joint -- we are calculating the
    # expectation value only over the samples we are drawing, not of the entire joint distribution!)
    joint /= joint.sum()

    # BUGFIX: work on a *copy* of the synthetic-data configuration. The
    # previous version aliased cfg["Data"]["synthetic_data"] and updated it in
    # place, leaking sampled parameter values (and time-dependent parameter
    # series) back into the caller's cfg.
    sample_cfg = dict(cfg["Data"]["synthetic_data"])

    # Loop-invariant run settings, hoisted out of the sampling loop.
    # NOTE: init_state is deliberately rebuilt per run below, since
    # torch.from_numpy shares the underlying buffer.
    num_steps = len(true_counts.coords["time"]) - 1
    dt = cfg["Data"]["synthetic_data"].get("dt", None)
    n_samples = len(joint.coords["sample"])

    res = []
    with logging_redirect_tqdm():
        pbar = trange(n_samples)
        pbar.set_description(
            f"Drawing {n_samples} samples from joint distribution: "
        )
        for s in pbar:
            # Get the sample and construct its parameter configuration,
            # taking time-dependent parameters into account
            sample = joint.isel({"sample": [s]}).unstack("sample")
            sample_cfg.update(
                {key: val.item() for key, val in sample.coords.items()}
            )
            param_cfg = _adjust_for_time_dependency(sample_cfg, cfg, true_counts)

            # Generate smooth data from the sampled parameters, starting from
            # the true counts at t=0
            generated_data = generate_smooth_data(
                cfg=param_cfg,
                num_steps=num_steps,
                dt=dt,
                init_state=torch.from_numpy(
                    true_counts.isel({"time": 0}, drop=True).data
                ).float(),
                write_init_state=True,
            ).numpy()

            res.append(
                xr.DataArray(
                    data=[[generated_data]],
                    dims=["sample", "type", "time", "kind", "dim_name__0"],
                    coords=dict(
                        sample=[s],
                        type=["mean prediction"],
                        time=true_counts.coords["time"],
                        kind=true_counts.coords["kind"],
                        dim_name__0=true_counts.coords["dim_name__0"],
                    ),
                )
            )

    # Concatenate all the time series
    res = xr.concat(res, dim="sample").squeeze(["dim_name__0"], drop=True)

    # Perform one additional run using the most likely parameter sample (mode)
    mode = joint.isel({"sample": joint.argmax(dim="sample")})
    sample_cfg.update({key: val.item() for key, val in mode.coords.items()})
    mode_params = _adjust_for_time_dependency(sample_cfg, cfg, true_counts)
    mode_data = generate_smooth_data(
        cfg=mode_params,
        num_steps=num_steps,
        dt=dt,
        init_state=torch.from_numpy(
            true_counts.isel({"time": 0}, drop=True).data
        ).float(),
        write_init_state=True,
    ).numpy()

    mode_data = xr.DataArray(
        data=[mode_data],
        dims=["type", "time", "kind", "dim_name__0"],
        coords=dict(
            type=["mode prediction"],
            time=true_counts.coords["time"],
            kind=true_counts.coords["kind"],
            dim_name__0=[0],
        ),
    ).squeeze(["dim_name__0"], drop=True)

    # Per-sample probability, shaped (sample, 1, 1, 1) so numpy broadcasting
    # aligns it with res's (sample, type, time, kind) dims. The explicit
    # np.repeat stacking of the previous version was redundant: the std line
    # below already relied on plain broadcasting.
    prob = np.reshape(joint.data, (n_samples, 1, 1, 1))

    # Probability-weighted expectation and standard deviation over the samples
    mean = (res * prob).sum("sample")
    std = np.sqrt(((res - mean) ** 2 * prob).sum("sample"))

    # Add a type to the true counts to allow concatenation
    true_counts = true_counts.expand_dims({"type": ["true data"]}).squeeze(
        "dim_name__0", drop=True
    )

    mean = xr.concat([mean, true_counts, mode_data], dim="type")
    # True data and the mode prediction are single runs, so their std is zero
    std = xr.concat([std, 0 * true_counts, 0 * mode_data], dim="type")

    return xr.Dataset(data_vars=dict(mean=mean, std=std))