Coverage for model_plots / SIR / trajectories_from_densities.py: 21%

56 statements  

« prev     ^ index     » next       coverage.py v7.13.1, created at 2026-01-21 16:26 +0000

1import logging 

2import numpy as np 

3import torch 

4import xarray as xr 

5from tqdm import trange 

6from tqdm.contrib.logging import logging_redirect_tqdm 

7 

8from models.SIR import generate_smooth_data 

9from utopya.eval import is_operation 

10 

11log = logging.getLogger(__name__) 

12 

13def _adjust_for_time_dependency( 

14 param_cfg: dict, cfg: dict, true_counts: xr.Dataset 

15) -> dict: 

16 """Adjusts the parameter configuration for time dependent parameters, if given.""" 

17 

18 time_dependent_parameters = cfg["Data"].get("time_dependent_parameters", {}) 

19 

20 # Extend any time-dependent parameters, if necessary 

21 for param in time_dependent_parameters.keys(): 

22 val = np.zeros(len(true_counts.coords["time"])) 

23 i = 0 

24 # Replace any time-dependent parameters with a series 

25 for j, interval in enumerate(time_dependent_parameters[param]): 

26 _, upper = interval 

27 if not upper: 

28 upper = len(val) 

29 while i < upper: 

30 val[i] = param_cfg[str(param) + f"_{j}"] 

31 i += 1 

32 param_cfg[param] = val 

33 

34 return param_cfg 

35 

36 

@is_operation("SIR_densities_from_joint")
def densities_from_joint(
    joint: xr.Dataset,
    *,
    true_counts: xr.Dataset,
    cfg: dict,
) -> xr.Dataset:
    """Runs the SIR ODE model with the estimated parameters, given in the xr.Dataset, and weights each time series
    with its corresponding probability. The probabilities must be normalised to 1.

    :param joint: the ``xr.Dataset`` of the joint parameter distribution
    :param true_counts: the xr.Dataset of true counts
    :param cfg: the run configuration of the original data (only needed if parameters are time dependent)
    :return: an ``xr.Dataset`` of the mean, mode, std, and true densities for all compartments
    """

    res = []

    # Stack the parameters into a single coordinate
    joint = joint.stack(sample=list(joint.coords))

    # Remove parameters with probability 0
    joint = joint.where(joint > 0, drop=True)

    # Normalise the probabilities to 1 (this is not the same as integrating over the joint -- we are calculating the
    # expectation value only over the samples we are drawing, not of the entire joint distribution!)
    joint /= joint.sum()

    # Take a shallow copy: the per-sample ``update`` calls below must not
    # mutate the caller's ``cfg["Data"]["synthetic_data"]`` dictionary
    # (the original aliased it and silently polluted the input config).
    sample_cfg = dict(cfg["Data"]["synthetic_data"])

    # Loop-invariant quantities for the ODE runs
    num_steps = len(true_counts.coords["time"]) - 1
    dt = cfg["Data"]["synthetic_data"].get("dt", None)

    n_samples = len(joint.coords["sample"])
    with logging_redirect_tqdm():
        for s in (pbar := trange(n_samples)):
            pbar.set_description(
                f"Drawing {n_samples} samples from joint distribution: "
            )

            # Get the sample and recover the individual parameter coordinates
            sample = joint.isel({"sample": [s]}).unstack("sample")

            # Construct the configuration, taking time-dependent parameters into account
            sample_cfg.update({key: val.item() for key, val in sample.coords.items()})
            param_cfg = _adjust_for_time_dependency(sample_cfg, cfg, true_counts)

            # Generate smooth data for this parameter sample
            generated_data = generate_smooth_data(
                cfg=param_cfg,
                num_steps=num_steps,
                dt=dt,
                init_state=torch.from_numpy(
                    true_counts.isel({"time": 0}, drop=True).data
                ).float(),
                write_init_state=True,
            ).numpy()

            res.append(
                xr.DataArray(
                    data=[[generated_data]],
                    dims=["sample", "type", "time", "kind", "dim_name__0"],
                    coords=dict(
                        sample=[s],
                        type=["mean prediction"],
                        time=true_counts.coords["time"],
                        kind=true_counts.coords["kind"],
                        dim_name__0=true_counts.coords["dim_name__0"],
                    ),
                )
            )

    # Concatenate all the time series
    res = xr.concat(res, dim="sample").squeeze(["dim_name__0"], drop=True)

    # Get the index of the most likely parameter sample (the distribution mode)
    mode = joint.isel({"sample": joint.argmax(dim="sample")})
    sample_cfg.update({key: val.item() for key, val in mode.coords.items()})

    # Perform a run using the mode
    mode_params = _adjust_for_time_dependency(sample_cfg, cfg, true_counts)
    mode_data = generate_smooth_data(
        cfg=mode_params,
        num_steps=num_steps,
        dt=dt,
        init_state=torch.from_numpy(
            true_counts.isel({"time": 0}, drop=True).data
        ).float(),
        write_init_state=True,
    ).numpy()

    mode_data = xr.DataArray(
        data=[mode_data],
        dims=["type", "time", "kind", "dim_name__0"],
        coords=dict(
            type=["mode prediction"],
            time=true_counts.coords["time"],
            kind=true_counts.coords["kind"],
            dim_name__0=[0],
        ),
    ).squeeze(["dim_name__0"], drop=True)

    # Reshape the probability array so it broadcasts against (sample, type, time, kind)
    prob = np.reshape(joint.data, (len(joint.coords["sample"]), 1, 1, 1))
    prob_stacked = np.repeat(prob, 1, 1)
    prob_stacked = np.repeat(prob_stacked, len(true_counts.coords["time"]), 2)
    prob_stacked = np.repeat(prob_stacked, len(true_counts.coords["kind"]), 3)

    # Calculate the mean and standard deviation by multiplying the predictions with the associated probability
    mean = (res * prob_stacked).sum("sample")
    std = np.sqrt(((res - mean) ** 2 * prob).sum("sample"))

    # Add a type to the true counts to allow concatenation
    true_counts = true_counts.expand_dims({"type": ["true data"]}).squeeze(
        "dim_name__0", drop=True
    )

    # True data and the mode prediction carry zero spread by construction
    mean = xr.concat([mean, true_counts, mode_data], dim="type")
    std = xr.concat([std, 0 * true_counts, 0 * mode_data], dim="type")

    return xr.Dataset(data_vars=dict(mean=mean, std=std))