Coverage for model_plots/nw_ops.py: 43%

81 statements, coverage.py v7.6.1, created at 2024-12-05 17:26 +0000

import sys

import scipy.integrate

from .data_ops import *


# ----------------------------------------------------------------------------------------------------------------------
# ADJACENCY MATRIX OPERATIONS
# ----------------------------------------------------------------------------------------------------------------------
@is_operation("triangles")
@apply_along_dim
def triangles(
    A: xr.DataArray,
    *args,
    input_core_dims: list = ["j"],
    offset=0,
    axis1=1,
    axis2=2,
    directed: bool = True,
    **kwargs,
) -> xr.DataArray:

23 """Calculates the number of triangles on each node from an adjacency matrix A along one dimension. 

24 The number of triangles are given by 

25 

26 t(i) = sum_{jk} A_{ij} A_{jk} A_{ki} 

27 

28 in the directed case, which is simply the i-th entry of the diagonal of A**3. If the network is directed, 

29 the number of triangles must be divided by 2. It is recommended to use ``xr.apply_ufunc`` for the inner 

30 (the sample) dimension, as the ``apply_along_dim`` decorator is quite slow. 

31 

32 :param A: the adjacency matrix 

33 :param offset: (optional) passed to ``np.diagonal``. Offset of the diagonal from the main diagonal. 

34 Can be positive or negative. Defaults to main diagonal (0). 

35 :param axis1: (optional) passed to ``np.diagonal``. Axis to be used as the first axis of the 

36 2-D sub-arrays from which the diagonals should be taken. Defaults to first axis (0). 

37 :param axis2: (optional) passed to ``np.diagonal``. Axis to be used as the second axis of the 2-D sub-arrays from 

38 which the diagonals should be taken. Defaults to second axis (1). 

39 :param input_core_dims: passed to ``xr.apply_ufunc`` 

40 :param directed: (optional, bool) whether the network is directed. If not, the number of triangle on each node 

41 is divided by 2. 

42 :param args, kwargs: additional args and kwargs passed to ``np.linalg.matrix_power`` 

43 """ 


    res = xr.apply_ufunc(
        np.diagonal,
        xr.apply_ufunc(np.linalg.matrix_power, A, 3, *args, **kwargs),
        offset,
        axis1,
        axis2,
        input_core_dims=[input_core_dims, [], [], []],
    )

    if not directed:
        res /= 2

    return res.rename("triangles")
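
# Usage sketch (illustrative; assumes ``apply_along_dim`` passes the array
# through unchanged): on a directed 3-cycle 0 -> 1 -> 2 -> 0, diag(A**3) is 1
# on every node, i.e. each node sits on exactly one directed triangle. A
# leading batch dimension is fine, since ``np.linalg.matrix_power`` operates
# on stacks of matrices:
#
#   A = xr.DataArray(
#       np.array([[[0, 1, 0], [0, 0, 1], [1, 0, 0]]]),
#       dims=("batch", "i", "j"),
#   )
#   triangles(A, directed=True)  # -> values [[1, 1, 1]]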



@is_operation("binned_nw_statistic")
@apply_along_dim
def binned_nw_statistic(
    nw_statistic: xr.DataArray,
    *,
    bins: Any,
    ranges: Sequence = None,
    normalize: Union[bool, float] = False,
    sample_dim: str = "batch",
    **kwargs,
) -> xr.DataArray:

71 """Calculates a binned statistic from an adjacency matrix statistic along the batch dimension. This function uses 

72 the `hist_1D` function to speed up computation. Since network statistics are binned along common x-values for each 

73 prediction element, the x-coordinate is written as a coordinate, rather than a variable like in other marginal 

74 calculations. 

75 

76 :param nw_statistic: the xr.DataArray of adjacency matrix statistics (e.g. the degrees), indexed by 'batch' 

77 :param bins: bins to use. Any argument admitted by `np.histogram` is permissible 

78 :param ranges: (float, float), optional: range of the bins to use. Defaults to the minimum and maximum value 

79 along *all* predictions. 

80 :param normalize: whether to normalize bin counts. 

81 :param sample_dim: name of the sampling dimension, which will be excluded from histogramming 

82 :param kwargs: passed to ``hist`` 

83 :return: xr.Dataset of binned statistics, indexed by the batch index and x-value 

84 """ 


    _along_dim = list(nw_statistic.coords)
    _along_dim.remove(sample_dim)
    return hist(
        nw_statistic,
        bins=bins,
        ranges=ranges,
        dim=_along_dim[0],
        normalize=normalize,
        use_bins_as_coords=True,
        **kwargs,
    ).rename("y")
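
# Usage sketch (illustrative; ``degrees`` is a hypothetical 2D array indexed
# by 'batch' and a vertex dimension): bin each degree sequence into 10 shared
# bins on [0, 20] and normalise each histogram:
#
#   binned = binned_nw_statistic(
#       degrees, bins=10, ranges=(0, 20), normalize=True, sample_dim="batch"
#   )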



@is_operation("sel_matrix_indices")
@apply_along_dim
def sel_matrix_indices(
    A: xr.DataArray, indices: xr.Dataset, drop: bool = False
) -> xr.DataArray:

104 """Selects entries from an adjacency matrix A given in ``indices``. If specified, coordinate labels 

105 are dropped. 

106 

107 :param A: adjacency matrix with rows and columns labelled ``i`` and ``j`` 

108 :param indices: ``xr.Dataset`` of indices to dropped; variables should be ``i`` and ``j`` 

109 :param drop: whether to drop the ``i`` and ``j`` coordinate labels 

110 :return: selected entries of ``A`` 

111 """ 


    A = A.isel(i=(indices["i"]), j=(indices["j"]))
    return A.drop_vars(["i", "j"]) if drop else A
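
# Usage sketch (illustrative; ``A`` is a hypothetical adjacency matrix with
# dimensions 'i' and 'j'): select the entries A[0, 2] and A[1, 3] and drop
# their coordinate labels:
#
#   idx = xr.Dataset(dict(i=("idx", [0, 1]), j=("idx", [2, 3])))
#   entries = sel_matrix_indices(A, idx, drop=True)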



@is_operation("largest_entry_indices")
@apply_along_dim
def largest_entry_indices(
    A: xr.DataArray, n: int, *, symmetric: bool = False
) -> xr.Dataset:
    """Returns the two-dimensional indices of the n largest entries in an adjacency matrix as well as the
    corresponding values. If the matrix is symmetric, only the upper triangle is considered. The entries are
    returned sorted from highest to lowest.

    :param A: adjacency matrix
    :param n: number of entries to obtain
    :param symmetric: (optional) whether the adjacency matrix is symmetric
    :return: ``xr.Dataset`` of largest entries and their indices
    """


    if symmetric:
        indices_i, indices_j = np.unravel_index(
            np.argsort(np.triu(A).ravel()), np.shape(A)
        )
    else:
        indices_i, indices_j = np.unravel_index(np.argsort(A.data.ravel()), np.shape(A))

    i, j = indices_i[-n:][::-1], indices_j[-n:][::-1]
    vals = A.data[i, j]

    return xr.Dataset(
        data_vars=dict(i=("idx", i), j=("idx", j), value=("idx", vals)),
        coords=dict(idx=("idx", np.arange(len(i)))),
    )
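
# Usage sketch (illustrative; ``A`` as above): the ten largest entries of a
# symmetric weight matrix, sorted from largest to smallest:
#
#   top = largest_entry_indices(A, 10, symmetric=True)
#   top["i"], top["j"], top["value"]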



# ----------------------------------------------------------------------------------------------------------------------
# DISTRIBUTION OPERATIONS
# ----------------------------------------------------------------------------------------------------------------------


@is_operation("marginal_distribution")
@apply_along_dim
def marginal_distribution(
    predictions: xr.DataArray,
    probabilities: xr.DataArray,
    true_values: xr.DataArray = None,
    *,
    bin_coord: str = "x",
    y: str = "MLE",
    yerr: str = "std",
    **kwargs,
) -> xr.Dataset:

165 """Calculates the marginal distribution from a dataset of binned network statistic (e.g. degree distributions). 

166 The joint of the statistics and the loss is calculated, the marginal over the loss returned, with a y and yerr value 

167 calculated. The y value can either be the mean or the mode distribution, and the yerr value is the standard deviation 

168 of the marginal over the loss on each statistic bin. If passed, the true distribution is also appended to the dataset. 

169 

170 :param predictions: 2D ``xr.DataArray`` of predictions; indexed by sample dimension and bin dimension 

171 :param probabilities: 1D ``xr.DataArray`` of probabilities, indexed by sample dimension 

172 :param true_values: (optional) 1D ``xr.DataArray`` of true distributions, indexed by bin dimension 

173 :param bin_coord: (optional) name of the x-dimension; default is 'x' 

174 :param y: statistic to calculate for the y variable; default is the maximum likelihood estimator, can also be the 

175 ``mean`` 

176 :param yerr: error statistic to use for the y variable; default is the standard deviation (std), but can also be 

177 the interquartile range (iqr) 

178 :param kwargs: kwargs, passed to ``marginal_from_ds`` 

179 :return: ``xr.Dataset`` of y and yerr values as variables, and x-values as coordinates. If the true values are 

180 passed, also contains a ``type`` dimension. 

181 """ 


    # Temporarily rename the 'x' dimension to avoid potential naming conflicts with the marginal operation,
    # which also produces 'x' values. This is only strictly necessary if the bin dimension is called 'x'.
    predictions = predictions.rename({bin_coord: f"_{bin_coord}"})

    # Broadcast the predictions and probabilities together
    predictions_and_loss = broadcast(predictions, probabilities, x="y", p="prob")

    # Calculate the distribution marginal for each bin
    marginals = marginal_from_ds(
        predictions_and_loss, x="y", y="prob", exclude_dim=[f"_{bin_coord}"], **kwargs
    )

    # Calculate the y-statistic: mode (default) or mean
    if y == "mode" or y == "MLE":
        p_max = probabilities.idxmax()
        _y_vals = predictions.sel({p_max.name: p_max.data}, drop=True)
    elif y == "mean":
        _y_vals = mean(marginals, along_dim=["bin_idx"], x="x", y="y")["mean"]


    # Calculate the error statistic (by default, the standard deviation)
    _y_err_vals: xr.DataArray = stat_function(
        marginals, along_dim=["bin_idx"], x="x", y="y", stat=yerr
    )[yerr]

    # The interquartile range is a total range, so divide by 2, since error bands are shown as ± err
    if yerr == "iqr":
        _y_err_vals /= 2

    # Combine y and yerr values into a single dataset and restore the name of the bin dimension
    res = xr.Dataset(dict(y=_y_vals, yerr=_y_err_vals)).rename(
        {f"_{bin_coord}": bin_coord}
    )

    # If the true values were given, add them to the dataset. The true values naturally have zero error.
    if true_values is not None:
        # Assign the x coordinates from res to ensure compatibility; they should be the same anyway,
        # but might differ because of precision errors
        true_values = xr.Dataset(
            dict(y=true_values, yerr=0 * true_values)
        ).assign_coords({bin_coord: res.coords[bin_coord]})
        res = concat([res, true_values], "type", [y, "True values"])

    return res
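
# Usage sketch (illustrative; ``predictions`` and ``probabilities`` are
# hypothetical arrays with dims ('batch', 'x') and ('batch',) respectively):
# a mean prediction with interquartile error bands, plus the true values:
#
#   res = marginal_distribution(
#       predictions, probabilities, true_values, y="mean", yerr="iqr"
#   )
#   res["y"], res["yerr"]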



@is_operation("marginal_distribution_stats")
@apply_along_dim
def marginal_distribution_stats(
    predictions: xr.DataArray,
    probabilities: xr.DataArray,
    *,
    distance_to: str = None,
    stat: Sequence,
    **kwargs,
) -> xr.DataArray:

236 """Calculates the statistics of a marginal distribution. This operation circumvents having to first compile 

237 marginals for all dimensions when sweeping, only to then apply a statistics function along the bin dimension, 

238 thereby saving memory. 

239 

240 The ``std`` and ``Hellinger`` and ``KL`` error statistics require different marginalisations: the first requires 

241 marginalising over the probability, while the second and third require marginalising over the counts. This is 

242 because the ``Hellinger`` and ``KL`` divergences require the probability bins to line up, i.e. to represent the 

243 same predicted distribution, so that the distance to a target distribution can be computed. 

244 

245 :param ds: dataset containing x and y variables for which to calculate the marginal 

246 :param bins: bins to use for the marginal 

247 :param ranges: ranges to use for the marginal 

248 :param x: x dimension 

249 :param y: function values p(x) 

250 :param stats: list or string of statistics to calculate. Can be any argument accepted by ``_stat_function``, or 

251 ``mode``, ``Hellinger``, or ``KL``. 

252 :param kwargs: additional kwargs, passed to the marginal function 

253 :return: xr.Dataset of marginal statistics 

254 """ 


    stat = set(stat)
    if "Hellinger" in stat or "KL" in stat:
        if distance_to is None:
            raise ValueError(
                "Calculating Hellinger or relative entropy statistics requires the 'distance_to' kwarg!"
            )

    # Temporarily rename the 'x' dimension to avoid naming conflicts with the marginal operation,
    # which also produces 'x' values
    predictions = predictions.rename({"x": "_x"})

    # Broadcast the predictions and probabilities together, and drop any distributions that are completely zero
    predictions_and_loss = broadcast(predictions, probabilities, x="y", p="prob")
    predictions_and_loss = predictions_and_loss.where(
        predictions_and_loss["prob"] > 0, drop=True
    )

    # Calculate the distribution marginal for each bin. These marginals differ for the Hellinger and KL
    # divergences, which require the marginal coordinates to align for each _x value, since they must
    # represent one single distribution.
    if stat != {"KL", "Hellinger"} or (
        ("Hellinger" in stat or "KL" in stat) and distance_to == "mean"
    ):
        marginal_over_p = marginal_from_ds(
            predictions_and_loss, x="y", y="prob", exclude_dim=["_x"], **kwargs
        )


    if "Hellinger" in stat or "KL" in stat:
        # For Hellinger and KL statistics, marginalise over the counts dimension
        marginal_over_counts = marginal_from_ds(
            predictions_and_loss,
            x="prob",
            y="y",
            exclude_dim=["_x"],
            normalize=False,
            **kwargs,
        )

        # Get the Q distribution with respect to which the error is to be calculated
        if distance_to == "mode" or distance_to == "MLE":
            _y_vals = marginal_over_counts["y"].isel({"bin_idx": -1}, drop=True)
        elif distance_to == "mean":
            _y_vals = mean(marginal_over_p, along_dim=["bin_idx"], x="x", y="y")["mean"]

        # Get the binned loss values associated with each marginal entry
        prob_binned = marginal_over_counts["x"].isel({"_x": 0}, drop=True)


    # Calculate all required statistics
    res = []

    for _stat in stat:
        # Average Hellinger distance
        if _stat == "Hellinger":
            _distributions = Hellinger_distance(
                marginal_over_counts["y"],
                _y_vals.expand_dims(
                    {"bin_idx": marginal_over_counts.coords["bin_idx"]}
                ),
                exclude_dim=["bin_idx"],
            )
            _err = (
                prob_binned * _distributions["Hellinger_distance"] / prob_binned.sum()
            ).sum("bin_idx")
            res.append(
                xr.DataArray(_err.data, name="stat").expand_dims({"type": [_stat]})
            )

        # Average relative entropy
        elif _stat == "KL":
            _distributions = relative_entropy(
                marginal_over_counts["y"],
                _y_vals.expand_dims(
                    {"bin_idx": marginal_over_counts.coords["bin_idx"]}
                ),
                exclude_dim=["bin_idx"],
            )

            _err = (
                prob_binned
                * abs(_distributions["relative_entropy"])
                / prob_binned.sum("bin_idx")
            ).sum("bin_idx")
            res.append(
                xr.DataArray(_err.data, name="stat").expand_dims({"type": [_stat]})
            )
        else:
            # Integrate the statistic (e.g. the standard deviation) along x
            _err = stat_function(
                marginal_over_p, along_dim=["bin_idx"], x="x", y="y", stat=_stat
            )[_stat]
            res.append(
                xr.DataArray(
                    scipy.integrate.trapezoid(_err.data, _err.coords["_x"]), name="stat"
                ).expand_dims({"type": [_stat]})
            )

    return xr.concat(res, "type")
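
# Usage sketch (illustrative; ``predictions`` and ``probabilities`` as above):
# the integrated standard deviation and the average Hellinger distance to the
# mean predicted distribution:
#
#   stats = marginal_distribution_stats(
#       predictions, probabilities, distance_to="mean", stat=["std", "Hellinger"]
#   )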