Coverage for model_plots / Covid / trajectories_from_densities.py: 25%

73 statements  

« prev     ^ index     » next       coverage.py v7.13.1, created at 2026-01-21 16:26 +0000

1import logging 

2from typing import Union 

3 

4import numpy as np 

5import torch 

6import xarray as xr 

7from tqdm import trange 

8from tqdm.contrib.logging import logging_redirect_tqdm 

9 

10from models.Covid import generate_smooth_data 

11from utopya.eval import is_operation 

12 

13log = logging.getLogger(__name__) 

14 

15from ..SIR.trajectories_from_densities import _adjust_for_time_dependency 

16 

def _combine_compartments(
    da: Union[xr.DataArray, xr.Dataset],
    combine: dict,
) -> Union[xr.DataArray, xr.Dataset]:
    """Sum groups of compartments into single, newly named compartments.

    :param da: data indexed by a ``kind`` dimension
    :param combine: mapping from new compartment name to the list of existing
        compartment names that should be summed and replaced by it
    :return: the data with each listed group collapsed into its new compartment
    """
    for new_kind, members in combine.items():
        # Aggregate the member compartments into a single new entry ...
        summed = (
            da.sel({"kind": members})
            .sum("kind")
            .expand_dims({"kind": [new_kind]})
        )
        # ... append it, then remove the original members
        da = xr.concat([da, summed], dim="kind").drop_sel({"kind": members})

    return da

29 

30 

31def _drop_compartments( 

32 da: Union[xr.DataArray, xr.Dataset], 

33 drop: list, 

34) -> Union[xr.DataArray, xr.Dataset]: 

35 """Drops certain compartments""" 

36 return da.drop_sel({"kind": drop}) 

37 

38 

def _calculate_residuals(da: xr.DataArray) -> xr.DataArray:
    """Calculates the relative residuals between the predictions and the true data.

    For both the mode and the mean prediction, the relative error
    ``(prediction - truth) / truth`` is computed for every time step and
    compartment. Entries where the true count is zero produce ``+inf``/``-inf``
    (or NaN for 0/0); all non-finite entries are replaced with NaN so that
    downstream aggregations (e.g. ``mean(skipna=True)``) can ignore them.

    :param da: array with a ``type`` dimension containing the entries
        ``prediction_mode``, ``prediction_mean``, and ``true_counts``
    :return: array with ``type`` entries ``mode_residual`` and ``mean_residual``
    """
    # Common denominator/reference for both residual types
    truth = da.sel({"type": "true_counts"}, drop=True)

    residuals = xr.DataArray(
        data=[
            (da.sel({"type": "prediction_mode"}, drop=True) - truth) / truth,
            (da.sel({"type": "prediction_mean"}, drop=True) - truth) / truth,
        ],
        dims=["type", "time", "kind"],
        coords=dict(
            type=["mode_residual", "mean_residual"],
            kind=da.coords["kind"],
            time=da.coords["time"],
        ),
    )

    # BUGFIX: the previous check ``residuals != np.inf`` only masked +inf;
    # a zero true count with a negative numerator yields -inf, which leaked
    # through and corrupted downstream means. ``np.isfinite`` masks both
    # infinities (NaN entries stay NaN either way).
    return xr.where(np.isfinite(residuals), residuals, np.nan)

64 

65 

@is_operation("Covid_densities_residuals")
def print_residuals(data, train_cut: int = 200):
    """Logs a summary of the L2 residuals for each compartment, split into
    a training and a test period at ``train_cut``.

    :param data: array of predictions and true counts, passed through unchanged
    :param train_cut: time index separating the training from the test period
    :return: ``data``, unmodified
    """
    squared_residuals = _calculate_residuals(data) ** 2

    intervals = {"train": slice(None, train_cut), "test": slice(train_cut, None)}

    for interval_name, window in intervals.items():
        log.remark("------------------------------------------------------")
        log.remark(f"L2 residuals in {interval_name}")
        # Root-mean-square of the relative residuals over the time window
        l2 = np.sqrt(
            squared_residuals.isel({"time": window}).mean("time", skipna=True)
        )
        for kind in squared_residuals.coords["kind"]:
            mean_val = np.around(
                l2.sel({"kind": kind, "type": "mean_residual"}).data.item(), 5
            )
            mode_val = np.around(
                l2.sel({"kind": kind, "type": "mode_residual"}).data.item(), 5
            )
            log.remark(
                f" {kind.item().capitalize()}: {mean_val} (mean), {mode_val} (mode)"
            )
        avg_mean = np.around(
            l2.sel({"type": "mean_residual"}).mean("kind").data.item(), 5
        )
        avg_mode = np.around(
            l2.sel({"type": "mode_residual"}).mean("kind").data.item(), 5
        )
        log.remark(f" Average: {avg_mean} (mean), {avg_mode} (mode)")

    return data

87 

88 

@is_operation("Covid_densities_from_joint")
def densities_from_joint(
    parameters: xr.Dataset,
    prob: xr.Dataset,
    *,
    true_counts: xr.Dataset,
    cfg: dict,
    combine: dict = None,
    drop: list = None,
    mean: xr.Dataset = None,
) -> xr.Dataset:
    """Runs the model with the estimated parameters, given in an ``xr.Dataset`` ``parameters``,
    and weights each time series with its corresponding probability, given ``prob``.
    The probabilities must be normalised to 1.

    :param parameters: the ``xr.Dataset`` of parameter estimates, indexed by the sample dimension
    :param prob: the xr.Dataset of probabilities associated with each estimate, indexed by sample
    :param true_counts: the xr.Dataset of true counts
    :param cfg: the run configuration of the original data
    :param combine: (optional) dictionary of compartments to combine by summing
    :param drop: (optional) list of compartments to drop
    :param mean: (optional) can pass the mean dataset instead of calculating it; needed for MCMC calculations
    :return: an ``xr.Dataset`` of the mean, mode, std, and true densities (if given) for all compartments
    """

    # Name of sample dimension: taken to be the first coordinate of ``prob``
    sample_dim: str = list(prob.coords.keys())[0]

    # Sample configuration, used as the base config for every generated run.
    # NOTE(review): this is a reference into ``cfg`` (not a copy), so the
    # in-place updates below mutate the caller's ``cfg`` — confirm intended.
    sample_cfg = cfg["Data"]["synthetic_data"]
    sample_cfg["num_steps"] = len(true_counts.coords["time"])
    sample_cfg["burn_in"] = 0
    res = []

    # Generate one smooth time series per parameter sample
    with logging_redirect_tqdm():
        for s in (pbar := trange(len(parameters.coords[sample_dim]))):
            pbar.set_description(
                f"Drawing {len(parameters.coords[sample_dim])} samples from joint distribution: "
            )

            # Construct the configuration, taking time-dependent parameters into account
            sample = parameters.isel({sample_dim: s}, drop=True)
            sample_cfg.update(
                {
                    p.item(): sample.sel({"parameter": p}).item()
                    for p in sample.coords["parameter"]
                }
            )
            param_cfg = _adjust_for_time_dependency(sample_cfg, cfg, true_counts)

            # Generate smooth data, starting from the first true count
            generated_data = generate_smooth_data(
                cfg=param_cfg,
                init_state=torch.from_numpy(
                    true_counts.isel({"time": 0}, drop=True).data
                ).float(),
            ).numpy()

            # Wrap the run in a DataArray indexed by sample; the trailing
            # singleton ``dim_name__0`` axis (from the true-counts layout)
            # is squeezed out immediately.
            res.append(
                xr.DataArray(
                    data=[[generated_data]],
                    dims=[sample_dim, "type", "time", "kind", "dim_name__0"],
                    coords=dict(
                        **{sample_dim: [s]},
                        type=["prediction_mean"],
                        **true_counts.coords,
                    ),
                ).squeeze(["dim_name__0"], drop=True)
            )

    # Concatenate all the time series
    res = xr.concat(res, dim=sample_dim)

    # Get the index of the most likely parameter
    mode_idx = prob.argmax(dim=sample_dim)
    sample_cfg.update(
        {
            p.item(): parameters.isel({sample_dim: mode_idx}, drop=True)
            .sel({"parameter": p})
            .item()
            for p in parameters.coords["parameter"]
        }
    )

    # Perform a run using the mode (most likely) parameter estimate
    mode_params = _adjust_for_time_dependency(sample_cfg, cfg, true_counts)
    mode_data = generate_smooth_data(
        cfg=mode_params,
        init_state=torch.from_numpy(
            true_counts.isel({"time": 0}, drop=True).data
        ).float(),
    ).numpy()

    mode_data = xr.DataArray(
        data=[mode_data],
        dims=["type", "time", "kind", "dim_name__0"],
        coords=dict(type=["prediction_mode"], **true_counts.coords),
    ).squeeze(["dim_name__0"], drop=True)

    # Combine compartments, if given (applied consistently to predictions,
    # mode run, and true counts)
    if combine:
        res = _combine_compartments(res, combine)
        mode_data = _combine_compartments(mode_data, combine)
        true_counts = _combine_compartments(true_counts, combine)

    # Drop compartments, if given
    if drop:
        res = _drop_compartments(res, drop)
        mode_data = _drop_compartments(mode_data, drop)
        true_counts = _drop_compartments(true_counts, drop)

    # Reshape the probability array to (sample, 1, 1, 1) so it broadcasts
    # against ``res`` (sample, type, time, kind)
    prob = np.reshape(prob.data, (len(prob.coords[sample_dim]), 1, 1, 1))
    prob_stacked = np.repeat(prob, 1, 1)
    prob_stacked = np.repeat(prob_stacked, len(res.coords["time"]), 2)
    prob_stacked = np.repeat(prob_stacked, len(res.coords["kind"]), 3)

    # Calculate the mean and standard deviation by multiplying the predictions
    # with the associated probability (probability-weighted moments; relies on
    # ``prob`` summing to 1)
    mean = (res * prob_stacked).sum(sample_dim) if mean is None else mean
    std = np.sqrt(((res - mean) ** 2 * prob).sum(sample_dim))

    # Add a type to the true counts to allow concatenation
    true_counts = true_counts.expand_dims({"type": ["true_counts"]}).squeeze(
        "dim_name__0", drop=True
    )

    # True counts and mode run get zero std so all types share the same layout
    mean = xr.concat([mean, true_counts, mode_data], dim="type")
    std = xr.concat([std, 0 * true_counts, 0 * mode_data], dim="type")

    return xr.Dataset(data_vars=dict(mean=mean, std=std))