Coverage for model_plots/SIR/trajectories_from_densities.py: 21% (56 statements)

import logging

import numpy as np
import torch
import xarray as xr
from tqdm import trange
from tqdm.contrib.logging import logging_redirect_tqdm

from models.SIR import generate_smooth_data
from utopya.eval import is_operation

log = logging.getLogger(__name__)


def _adjust_for_time_dependency(
    param_cfg: dict, cfg: dict, true_counts: xr.Dataset
) -> dict:
    """Adjusts the parameter configuration for time-dependent parameters, if given."""

    time_dependent_parameters = cfg["Data"].get("time_dependent_parameters", {})

    # Extend any time-dependent parameters, if necessary
    for param in time_dependent_parameters.keys():
        val = np.zeros(len(true_counts.coords["time"]))
        i = 0
        # Replace the time-dependent parameter with a series; a falsy upper
        # bound means the interval extends to the end of the series
        for j, interval in enumerate(time_dependent_parameters[param]):
            _, upper = interval
            if not upper:
                upper = len(val)
            while i < upper:
                val[i] = param_cfg[str(param) + f"_{j}"]
                i += 1
        param_cfg[param] = val

    return param_cfg
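

# Illustrative example for ``_adjust_for_time_dependency`` (hypothetical
# parameter name and intervals, not taken from any actual run config): with
#
#     cfg["Data"]["time_dependent_parameters"] = {"p_infect": [(0, 10), (10, None)]}
#     param_cfg = {"p_infect_0": 0.2, "p_infect_1": 0.05}
#
# and 20 entries in ``true_counts.coords["time"]``, the helper sets
# ``param_cfg["p_infect"]`` to a series that is 0.2 on steps 0-9 and 0.05 on
# steps 10-19.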


@is_operation("SIR_densities_from_joint")
def densities_from_joint(
    joint: xr.Dataset,
    *,
    true_counts: xr.Dataset,
    cfg: dict,
) -> xr.Dataset:
    """Runs the SIR ODE model with the estimated parameters given in the
    ``xr.Dataset`` and weights each resulting time series with its corresponding
    probability. The probabilities must be normalised to 1.

    :param joint: the ``xr.Dataset`` of the joint parameter distribution
    :param true_counts: the ``xr.Dataset`` of true counts
    :param cfg: the run configuration of the original data (only needed if
        parameters are time-dependent)
    :return: an ``xr.Dataset`` of the mean, mode, std, and true densities for
        all compartments
    """
    res = []

    # Stack the parameters into a single coordinate
    joint = joint.stack(sample=list(joint.coords))
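    # Illustrative (hypothetical coordinate names): a joint density with coords
    # ``p_infect`` (size 10) and ``t_infectious`` (size 10) becomes a 1-d array
    # of length 100 with a MultiIndex coordinate ``sample``, so each entry is
    # one parameter combination together with its probability.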

    # Remove parameters with probability 0
    joint = joint.where(joint > 0, drop=True)

    # Normalise the probabilities to 1 (this is not the same as integrating over
    # the joint -- we are calculating the expectation value only over the samples
    # we are drawing, not over the entire joint distribution!)
    joint /= joint.sum()

    # Work on a copy so that the run configuration itself is not mutated when
    # sampled parameter values are written into it below
    sample_cfg = dict(cfg["Data"]["synthetic_data"])

    with logging_redirect_tqdm():
        for s in (pbar := trange(len(joint.coords["sample"]))):
            pbar.set_description(
                f"Drawing {len(joint.coords['sample'])} samples from the joint distribution: "
            )

            # Get the sample
            sample = joint.isel({"sample": [s]})

            # Construct the configuration, taking time-dependent parameters into account
            sample = sample.unstack("sample")
            sample_cfg.update({key: val.item() for key, val in sample.coords.items()})
            param_cfg = _adjust_for_time_dependency(sample_cfg, cfg, true_counts)

            # Generate smooth data
            generated_data = generate_smooth_data(
                cfg=param_cfg,
                num_steps=len(true_counts.coords["time"]) - 1,
                dt=cfg["Data"]["synthetic_data"].get("dt", None),
                init_state=torch.from_numpy(
                    true_counts.isel({"time": 0}, drop=True).data
                ).float(),
                write_init_state=True,
            ).numpy()
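
            # Note: ``num_steps`` excludes the initial state, which is written
            # separately via ``write_init_state=True``, so the generated series
            # has the same length as ``true_counts.coords["time"]`` and can be
            # labelled with that coordinate below.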

            res.append(
                xr.DataArray(
                    data=[[generated_data]],
                    dims=["sample", "type", "time", "kind", "dim_name__0"],
                    coords=dict(
                        sample=[s],
                        type=["mean prediction"],
                        time=true_counts.coords["time"],
                        kind=true_counts.coords["kind"],
                        dim_name__0=true_counts.coords["dim_name__0"],
                    ),
                )
            )

    # Concatenate all the time series
    res = xr.concat(res, dim="sample").squeeze(["dim_name__0"], drop=True)

    # Get the most likely parameter combination (the mode of the joint)
    mode = joint.isel({"sample": joint.argmax(dim="sample")})
    sample_cfg.update({key: val.item() for key, val in mode.coords.items()})

    # Perform a run using the mode
    mode_params = _adjust_for_time_dependency(sample_cfg, cfg, true_counts)
    mode_data = generate_smooth_data(
        cfg=mode_params,
        num_steps=len(true_counts.coords["time"]) - 1,
        dt=cfg["Data"]["synthetic_data"].get("dt", None),
        init_state=torch.from_numpy(
            true_counts.isel({"time": 0}, drop=True).data
        ).float(),
        write_init_state=True,
    ).numpy()

    mode_data = xr.DataArray(
        data=[mode_data],
        dims=["type", "time", "kind", "dim_name__0"],
        coords=dict(
            type=["mode prediction"],
            time=true_counts.coords["time"],
            kind=true_counts.coords["kind"],
            dim_name__0=[0],
        ),
    ).squeeze(["dim_name__0"], drop=True)

    # Reshape the probability array so that it broadcasts against the stacked
    # predictions with dims (sample, type, time, kind)
    prob = np.reshape(joint.data, (len(joint.coords["sample"]), 1, 1, 1))
    prob_stacked = np.repeat(prob, 1, 1)
    prob_stacked = np.repeat(prob_stacked, len(true_counts.coords["time"]), 2)
    prob_stacked = np.repeat(prob_stacked, len(true_counts.coords["kind"]), 3)
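
    # Calculate the mean and standard deviation by weighting each prediction
    # with its associated probability; in formula terms,
    #     mean = sum_i p_i * x_i
    #     std  = sqrt(sum_i p_i * (x_i - mean)^2)
    # where p_i is the renormalised probability of sample i and x_i its
    # generated trajectory.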
    mean = (res * prob_stacked).sum("sample")
    std = np.sqrt(((res - mean) ** 2 * prob).sum("sample"))

    # Add a type to the true counts to allow concatenation
    true_counts = true_counts.expand_dims({"type": ["true data"]}).squeeze(
        "dim_name__0", drop=True
    )

    mean = xr.concat([mean, true_counts, mode_data], dim="type")
    std = xr.concat([std, 0 * true_counts, 0 * mode_data], dim="type")

    data = xr.Dataset(data_vars=dict(mean=mean, std=std))

    return data
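

# Minimal usage sketch (hypothetical coordinates and data; not part of this
# module): ``densities_from_joint`` is registered as the utopya operation
# ``SIR_densities_from_joint`` and is typically invoked from a plot
# configuration, but it can also be called directly, e.g.
#
#     joint = xr.DataArray(
#         np.random.rand(10, 10),
#         dims=("p_infect", "t_infectious"),
#         coords=dict(
#             p_infect=np.linspace(0, 1, 10),
#             t_infectious=np.linspace(1, 20, 10),
#         ),
#     )
#     data = densities_from_joint(joint, true_counts=true_counts, cfg=cfg)
#
# ``data["mean"]`` then contains the probability-weighted mean trajectory
# alongside the mode prediction and the true data; ``data["std"]`` holds the
# matching standard deviations (zero for the mode and true-data entries).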