Coverage for model_plots/SIR/trajectories_from_densities.py: 21%

56 statements  

coverage.py v7.6.1, created at 2024-12-05 17:26 +0000

import logging

import numpy as np
import torch
import xarray as xr
from tqdm import trange
from tqdm.contrib.logging import logging_redirect_tqdm

from models.SIR import generate_smooth_data
from utopya.eval import is_operation

log = logging.getLogger(__name__)


def _adjust_for_time_dependency(
    param_cfg: dict, cfg: dict, true_counts: xr.DataArray
) -> dict:
    """Adjusts the parameter configuration for time-dependent parameters, if given."""

    time_dependent_parameters = cfg["Data"].get("time_dependent_parameters", {})

    # Extend any time-dependent parameters, if necessary
    for param in time_dependent_parameters:
        val = np.zeros(len(true_counts.coords["time"]))
        i = 0
        # Replace each time-dependent parameter with a series: each interval
        # is filled with the value of the corresponding interval parameter;
        # an open upper bound runs to the end of the time series
        for j, interval in enumerate(time_dependent_parameters[param]):
            _, upper = interval
            if not upper:
                upper = len(val)
            while i < upper:
                val[i] = param_cfg[f"{param}_{j}"]
                i += 1
        param_cfg[param] = val

    return param_cfg
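
# Hedged worked example for ``_adjust_for_time_dependency`` (the parameter name
# ``beta`` and all numbers below are hypothetical, chosen only to illustrate
# the interval logic):
#
#     cfg["Data"]["time_dependent_parameters"] == {"beta": [[0, 3], [3, None]]}
#     param_cfg == {"beta_0": 0.2, "beta_1": 0.5}
#
# With five entries in ``true_counts.coords["time"]``, the call fills in
#
#     param_cfg["beta"] == array([0.2, 0.2, 0.2, 0.5, 0.5])
#
# i.e. one value per time coordinate, switching at the interval boundary.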


@is_operation("SIR_densities_from_joint")
def densities_from_joint(
    joint: xr.DataArray,
    *,
    true_counts: xr.DataArray,
    cfg: dict,
) -> xr.Dataset:
    """Runs the SIR ODE model with the estimated parameters given in the joint
    distribution and weights each resulting time series with its corresponding
    probability. The probabilities are normalised to 1 over the drawn samples.

    :param joint: the ``xr.DataArray`` of the joint parameter distribution
    :param true_counts: the ``xr.DataArray`` of true counts
    :param cfg: the run configuration of the original data (only needed if
        parameters are time-dependent)
    :return: an ``xr.Dataset`` of the mean, mode, std, and true densities for
        all compartments
    """

    res = []

    # Stack the parameters into a single coordinate
    joint = joint.stack(sample=list(joint.coords))

    # Remove parameters with probability 0
    joint = joint.where(joint > 0, drop=True)

    # Normalise the probabilities to 1 (this is not the same as integrating
    # over the joint -- we are calculating the expectation value only over the
    # samples we draw, not over the entire joint distribution!)
    joint /= joint.sum()
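
    # For orientation, a hypothetical example of the stacking and normalisation
    # above (grid shape and values made up): a joint defined on a 2x2 grid over
    # coordinates ``beta`` and ``gamma`` with unnormalised probabilities
    #
    #     [[0.1, 0.2],
    #      [0.3, 0.2]]
    #
    # becomes a length-4 ``sample`` coordinate of (beta, gamma) tuples with
    # weights [0.125, 0.25, 0.375, 0.25], which sum to 1 over exactly the
    # drawn samples.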

    # Copy to avoid mutating the run configuration passed in via ``cfg``
    sample_cfg = dict(cfg["Data"]["synthetic_data"])

    with logging_redirect_tqdm():
        for s in (pbar := trange(len(joint.coords["sample"]))):
            pbar.set_description(
                f"Drawing {len(joint.coords['sample'])} samples from joint distribution: "
            )

            # Get the sample and unstack it back into individual parameter
            # coordinates
            sample = joint.isel({"sample": [s]})
            sample = sample.unstack("sample")

            # Construct the configuration, taking time-dependent parameters
            # into account
            sample_cfg.update({key: val.item() for key, val in sample.coords.items()})
            param_cfg = _adjust_for_time_dependency(sample_cfg, cfg, true_counts)

            # Generate smooth data from the sampled parameters, starting from
            # the true initial state
            generated_data = generate_smooth_data(
                cfg=param_cfg,
                num_steps=len(true_counts.coords["time"]) - 1,
                dt=cfg["Data"]["synthetic_data"].get("dt", None),
                init_state=torch.from_numpy(
                    true_counts.isel({"time": 0}, drop=True).data
                ).float(),
                write_init_state=True,
            ).numpy()

            res.append(
                xr.DataArray(
                    data=[[generated_data]],
                    dims=["sample", "type", "time", "kind", "dim_name__0"],
                    coords=dict(
                        sample=[s],
                        type=["mean prediction"],
                        time=true_counts.coords["time"],
                        kind=true_counts.coords["kind"],
                        dim_name__0=true_counts.coords["dim_name__0"],
                    ),
                )
            )

    # Concatenate all the time series; ``res`` then has dims
    # (sample, type, time, kind)
    res = xr.concat(res, dim="sample").squeeze(["dim_name__0"], drop=True)

    # Get the index of the most likely parameter
    mode = joint.isel({"sample": joint.argmax(dim="sample")})
    sample_cfg.update({key: val.item() for key, val in mode.coords.items()})

    # Perform a run using the mode
    mode_params = _adjust_for_time_dependency(sample_cfg, cfg, true_counts)
    mode_data = generate_smooth_data(
        cfg=mode_params,
        num_steps=len(true_counts.coords["time"]) - 1,
        dt=cfg["Data"]["synthetic_data"].get("dt", None),
        init_state=torch.from_numpy(
            true_counts.isel({"time": 0}, drop=True).data
        ).float(),
        write_init_state=True,
    ).numpy()

    mode_data = xr.DataArray(
        data=[mode_data],
        dims=["type", "time", "kind", "dim_name__0"],
        coords=dict(
            type=["mode prediction"],
            time=true_counts.coords["time"],
            kind=true_counts.coords["kind"],
            dim_name__0=[0],
        ),
    ).squeeze(["dim_name__0"], drop=True)

    # Reshape the probability array so that it broadcasts against the
    # (sample, type, time, kind) dimensions of the predictions
    prob = np.reshape(joint.data, (len(joint.coords["sample"]), 1, 1, 1))

    # Calculate the mean and standard deviation by weighting each prediction
    # with its associated probability
    mean = (res * prob).sum("sample")
    std = np.sqrt(((res - mean) ** 2 * prob).sum("sample"))
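
    # In symbols, with weights p_s (summing to 1 over the drawn samples) and
    # predicted series x_s, the two lines above compute
    #
    #     mean = sum_s p_s * x_s
    #     std  = sqrt(sum_s p_s * (x_s - mean) ** 2)
    #
    # i.e. the probability-weighted mean and standard deviation of the runs.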

    # Add a type to the true counts to allow concatenation
    true_counts = true_counts.expand_dims({"type": ["true data"]}).squeeze(
        "dim_name__0", drop=True
    )

    mean = xr.concat([mean, true_counts, mode_data], dim="type")
    std = xr.concat([std, 0 * true_counts, 0 * mode_data], dim="type")

    data = xr.Dataset(data_vars=dict(mean=mean, std=std))

    return data
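
# Structure of the returned dataset (as constructed above): data variables
# ``mean`` and ``std``, each with dims (type, time, kind), where ``type``
# ranges over "mean prediction", "true data", and "mode prediction"; the
# ``std`` entries for the true data and the mode prediction are identically
# zero by construction.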