Coverage for model_plots/nw_ops.py: 43% (81 statements)
import sys
from typing import Any, Sequence, Union

import numpy as np
import scipy.integrate
import xarray as xr

from .data_ops import *


# ----------------------------------------------------------------------------------------------------------------------
# ADJACENCY MATRIX OPERATIONS
# ----------------------------------------------------------------------------------------------------------------------
@is_operation("triangles")
@apply_along_dim
def triangles(
    A: xr.DataArray,
    *args,
    input_core_dims: list = ["j"],
    offset=0,
    axis1=1,
    axis2=2,
    directed: bool = True,
    **kwargs,
) -> xr.DataArray:
23 """Calculates the number of triangles on each node from an adjacency matrix A along one dimension.
24 The number of triangles are given by
26 t(i) = sum_{jk} A_{ij} A_{jk} A_{ki}
28 in the directed case, which is simply the i-th entry of the diagonal of A**3. If the network is directed,
29 the number of triangles must be divided by 2. It is recommended to use ``xr.apply_ufunc`` for the inner
30 (the sample) dimension, as the ``apply_along_dim`` decorator is quite slow.
32 :param A: the adjacency matrix
33 :param offset: (optional) passed to ``np.diagonal``. Offset of the diagonal from the main diagonal.
34 Can be positive or negative. Defaults to main diagonal (0).
35 :param axis1: (optional) passed to ``np.diagonal``. Axis to be used as the first axis of the
36 2-D sub-arrays from which the diagonals should be taken. Defaults to first axis (0).
37 :param axis2: (optional) passed to ``np.diagonal``. Axis to be used as the second axis of the 2-D sub-arrays from
38 which the diagonals should be taken. Defaults to second axis (1).
39 :param input_core_dims: passed to ``xr.apply_ufunc``
40 :param directed: (optional, bool) whether the network is directed. If not, the number of triangle on each node
41 is divided by 2.
42 :param args, kwargs: additional args and kwargs passed to ``np.linalg.matrix_power``
43 """
    res = xr.apply_ufunc(
        np.diagonal,
        xr.apply_ufunc(np.linalg.matrix_power, A, 3, *args, **kwargs),
        offset,
        axis1,
        axis2,
        input_core_dims=[input_core_dims, [], [], []],
    )

    if not directed:
        res /= 2

    return res.rename("triangles")
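
# A minimal usage sketch for ``triangles`` (hypothetical data; assumes the decorators
# pass ``xr.DataArray`` inputs straight through). For a single sample of the complete
# graph on 3 nodes, diag(A**3) = 2 everywhere, so each node sits on exactly one
# undirected triangle:
#
#   >>> A = xr.DataArray(
#   ...     np.array([[[0, 1, 1], [1, 0, 1], [1, 1, 0]]]),
#   ...     dims=("sample", "i", "j"),
#   ... )
#   >>> triangles(A, directed=False)   # values [[1., 1., 1.]] over dims (sample, i)
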
@is_operation("binned_nw_statistic")
@apply_along_dim
def binned_nw_statistic(
    nw_statistic: xr.DataArray,
    *,
    bins: Any,
    ranges: Sequence = None,
    normalize: Union[bool, float] = False,
    sample_dim: str = "batch",
    **kwargs,
) -> xr.DataArray:
71 """Calculates a binned statistic from an adjacency matrix statistic along the batch dimension. This function uses
72 the `hist_1D` function to speed up computation. Since network statistics are binned along common x-values for each
73 prediction element, the x-coordinate is written as a coordinate, rather than a variable like in other marginal
74 calculations.
76 :param nw_statistic: the xr.DataArray of adjacency matrix statistics (e.g. the degrees), indexed by 'batch'
77 :param bins: bins to use. Any argument admitted by `np.histogram` is permissible
78 :param ranges: (float, float), optional: range of the bins to use. Defaults to the minimum and maximum value
79 along *all* predictions.
80 :param normalize: whether to normalize bin counts.
81 :param sample_dim: name of the sampling dimension, which will be excluded from histogramming
82 :param kwargs: passed to ``hist``
83 :return: xr.Dataset of binned statistics, indexed by the batch index and x-value
84 """
    _along_dim = list(nw_statistic.coords)
    _along_dim.remove(sample_dim)
    return hist(
        nw_statistic,
        bins=bins,
        ranges=ranges,
        dim=_along_dim[0],
        normalize=normalize,
        use_bins_as_coords=True,
        **kwargs,
    ).rename("y")
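
# A usage sketch for ``binned_nw_statistic`` (hypothetical degree data; assumes ``hist``
# is provided by data_ops with the call signature used above):
#
#   >>> degrees = xr.DataArray(
#   ...     np.random.randint(0, 10, (5, 100)),
#   ...     dims=("batch", "vertex_idx"),
#   ...     coords=dict(batch=np.arange(5), vertex_idx=np.arange(100)),
#   ... )
#   >>> binned_nw_statistic(degrees, bins=10, ranges=(0, 10), normalize=True)
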
@is_operation("sel_matrix_indices")
@apply_along_dim
def sel_matrix_indices(
    A: xr.DataArray, indices: xr.Dataset, drop: bool = False
) -> xr.DataArray:
104 """Selects entries from an adjacency matrix A given in ``indices``. If specified, coordinate labels
105 are dropped.
107 :param A: adjacency matrix with rows and columns labelled ``i`` and ``j``
108 :param indices: ``xr.Dataset`` of indices to dropped; variables should be ``i`` and ``j``
109 :param drop: whether to drop the ``i`` and ``j`` coordinate labels
110 :return: selected entries of ``A``
111 """
    A = A.isel(i=(indices["i"]), j=(indices["j"]))
    return A.drop_vars(["i", "j"]) if drop else A
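
# A usage sketch for ``sel_matrix_indices`` (hypothetical indices): the entries
# A[0, 1] and A[2, 2] are picked out pointwise along a new 'idx' dimension:
#
#   >>> A = xr.DataArray(
#   ...     np.arange(9).reshape(3, 3),
#   ...     dims=("i", "j"),
#   ...     coords=dict(i=np.arange(3), j=np.arange(3)),
#   ... )
#   >>> idx = xr.Dataset(dict(i=("idx", [0, 2]), j=("idx", [1, 2])))
#   >>> sel_matrix_indices(A, idx, drop=True)   # values [1, 8]
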
@is_operation("largest_entry_indices")
@apply_along_dim
def largest_entry_indices(
    A: xr.DataArray, n: int, *, symmetric: bool = False
) -> xr.Dataset:
122 """Returns the two-dimensional indices of the n largest entries in an adjacency matrix as well as the corresponding
123 values. If the matrix is symmetric, only the upper triangle is considered. The entries are returned sorted from
124 highest to lowest.
126 :param A: adjacency matrix
127 :param n: number of entries to obtain
128 :param symmetric: (optional) whether the adjacency matrix is symmetric
129 :return: ``xr.Dataset`` of largest entries and their indices
130 """
    if symmetric:
        indices_i, indices_j = np.unravel_index(
            np.argsort(np.triu(A).ravel()), np.shape(A)
        )
    else:
        indices_i, indices_j = np.unravel_index(np.argsort(A.data.ravel()), np.shape(A))

    i, j = indices_i[-n:][::-1], indices_j[-n:][::-1]
    vals = A.data[i, j]

    return xr.Dataset(
        data_vars=dict(i=("idx", i), j=("idx", j), value=("idx", vals)),
        coords=dict(idx=("idx", np.arange(len(i)))),
    )
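
# A usage sketch for ``largest_entry_indices`` on a symmetric 3x3 matrix; the two
# largest upper-triangle entries are (0, 1) -> 0.9 and (1, 2) -> 0.7:
#
#   >>> A = xr.DataArray(
#   ...     np.array([[0.0, 0.9, 0.1], [0.9, 0.0, 0.7], [0.1, 0.7, 0.0]]),
#   ...     dims=("i", "j"),
#   ... )
#   >>> largest_entry_indices(A, 2, symmetric=True)
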
# ----------------------------------------------------------------------------------------------------------------------
# DISTRIBUTION OPERATIONS
# ----------------------------------------------------------------------------------------------------------------------
@is_operation("marginal_distribution")
@apply_along_dim
def marginal_distribution(
    predictions: xr.DataArray,
    probabilities: xr.DataArray,
    true_values: xr.DataArray = None,
    *,
    bin_coord: str = "x",
    y: str = "MLE",
    yerr: str = "std",
    **kwargs,
) -> xr.Dataset:
165 """Calculates the marginal distribution from a dataset of binned network statistic (e.g. degree distributions).
166 The joint of the statistics and the loss is calculated, the marginal over the loss returned, with a y and yerr value
167 calculated. The y value can either be the mean or the mode distribution, and the yerr value is the standard deviation
168 of the marginal over the loss on each statistic bin. If passed, the true distribution is also appended to the dataset.
170 :param predictions: 2D ``xr.DataArray`` of predictions; indexed by sample dimension and bin dimension
171 :param probabilities: 1D ``xr.DataArray`` of probabilities, indexed by sample dimension
172 :param true_values: (optional) 1D ``xr.DataArray`` of true distributions, indexed by bin dimension
173 :param bin_coord: (optional) name of the x-dimension; default is 'x'
174 :param y: statistic to calculate for the y variable; default is the maximum likelihood estimator, can also be the
175 ``mean``
176 :param yerr: error statistic to use for the y variable; default is the standard deviation (std), but can also be
177 the interquartile range (iqr)
178 :param kwargs: kwargs, passed to ``marginal_from_ds``
179 :return: ``xr.Dataset`` of y and yerr values as variables, and x-values as coordinates. If the true values are
180 passed, also contains a ``type`` dimension.
181 """
    # Temporarily rename the 'x' dimension to avoid potential naming conflicts with the marginal operation,
    # which also produces 'x' values. This is only strictly necessary if the x-dimension is called 'x'.
    predictions = predictions.rename({bin_coord: f"_{bin_coord}"})

    # Broadcast the predictions and probabilities together
    predictions_and_loss = broadcast(predictions, probabilities, x="y", p="prob")

    # Calculate the distribution marginal for each bin
    marginals = marginal_from_ds(
        predictions_and_loss, x="y", y="prob", exclude_dim=[f"_{bin_coord}"], **kwargs
    )

    # Calculate the y-statistic: mode (default) or mean
    if y == "mode" or y == "MLE":
        p_max = probabilities.idxmax()
        _y_vals = predictions.sel({p_max.name: p_max.data}, drop=True)
    elif y == "mean":
        _y_vals = mean(marginals, along_dim=["bin_idx"], x="x", y="y")["mean"]

    # Calculate the error statistic (by default, the standard deviation) from y
    _y_err_vals: xr.DataArray = stat_function(
        marginals, along_dim=["bin_idx"], x="x", y="y", stat=yerr
    )[yerr]

    # The interquartile range is a total range, so divide by 2, since error bands are shown as ± err
    if yerr == "iqr":
        _y_err_vals /= 2

    # Combine y and yerr values into a single dataset and rename the 'x' dimension
    res = xr.Dataset(dict(y=_y_vals, yerr=_y_err_vals)).rename(
        {f"_{bin_coord}": bin_coord}
    )

    # If the true values were given, add them to the dataset. The true values naturally have zero error.
    if true_values is not None:
        # Assign the x coordinates from res to ensure compatibility; they should be the same anyway,
        # but might differ because of precision errors
        true_values = xr.Dataset(
            dict(y=true_values, yerr=0 * true_values)
        ).assign_coords({bin_coord: res.coords[bin_coord]})
        res = concat([res, true_values], "type", [y, "True values"])

    return res
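
# A usage sketch for ``marginal_distribution`` (hypothetical shapes; ``broadcast``,
# ``marginal_from_ds``, ``mean``, ``stat_function``, and ``concat`` are assumed to
# come from data_ops):
#
#   >>> preds = xr.DataArray(..., dims=("sample", "x"))   # binned statistic per sample
#   >>> probs = xr.DataArray(..., dims=("sample",))       # loss/probability per sample
#   >>> ds = marginal_distribution(preds, probs, y="mean", yerr="iqr")
#   >>> ds["y"], ds["yerr"]   # marginal estimate and +/- error band per x-bin
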
@is_operation("marginal_distribution_stats")
@apply_along_dim
def marginal_distribution_stats(
    predictions: xr.DataArray,
    probabilities: xr.DataArray,
    *,
    distance_to: str = None,
    stat: Sequence,
    **kwargs,
) -> xr.DataArray:
236 """Calculates the statistics of a marginal distribution. This operation circumvents having to first compile
237 marginals for all dimensions when sweeping, only to then apply a statistics function along the bin dimension,
238 thereby saving memory.
240 The ``std`` and ``Hellinger`` and ``KL`` error statistics require different marginalisations: the first requires
241 marginalising over the probability, while the second and third require marginalising over the counts. This is
242 because the ``Hellinger`` and ``KL`` divergences require the probability bins to line up, i.e. to represent the
243 same predicted distribution, so that the distance to a target distribution can be computed.
245 :param ds: dataset containing x and y variables for which to calculate the marginal
246 :param bins: bins to use for the marginal
247 :param ranges: ranges to use for the marginal
248 :param x: x dimension
249 :param y: function values p(x)
250 :param stats: list or string of statistics to calculate. Can be any argument accepted by ``_stat_function``, or
251 ``mode``, ``Hellinger``, or ``KL``.
252 :param kwargs: additional kwargs, passed to the marginal function
253 :return: xr.Dataset of marginal statistics
254 """
    stat = set(stat)
    if "Hellinger" in stat or "KL" in stat:
        if distance_to is None:
            raise ValueError(
                "Calculating Hellinger or relative entropy statistics requires the 'distance_to' kwarg!"
            )

    # Temporarily rename the 'x' dimension to avoid naming conflicts with the marginal operation,
    # which also produces 'x' values
    predictions = predictions.rename({"x": "_x"})

    # Broadcast the predictions and probabilities together, and drop any distributions that are completely zero
    predictions_and_loss = broadcast(predictions, probabilities, x="y", p="prob")
    predictions_and_loss = predictions_and_loss.where(
        predictions_and_loss["prob"] > 0, drop=True
    )

    # Calculate the distribution marginal for each bin. These are different for the Hellinger and KL divergences,
    # since these require the marginal coordinates to align for each _x value, since they must represent one
    # single distribution.
    if stat != {"KL", "Hellinger"} or (
        ("Hellinger" in stat or "KL" in stat) and distance_to == "mean"
    ):
        marginal_over_p = marginal_from_ds(
            predictions_and_loss, x="y", y="prob", exclude_dim=["_x"], **kwargs
        )
283 if "Hellinger" in stat or "KL" in stat:
284 # For Hellinger and KL statistics, marginalise over the counts dimension
285 marginal_over_counts = marginal_from_ds(
286 predictions_and_loss,
287 x="prob",
288 y="y",
289 exclude_dim=["_x"],
290 normalize=False,
291 **kwargs,
292 )
294 # Get the Q distribution with respect to which the error is to be calculated
295 if distance_to == "mode" or distance_to == "MLE":
296 _y_vals = marginal_over_counts["y"].isel({"bin_idx": -1}, drop=True)
297 elif distance_to == "mean":
298 _y_vals = mean(marginal_over_p, along_dim=["bin_idx"], x="x", y="y")["mean"]
300 # Get the binned loss values associated with each marginal entry
301 prob_binned = marginal_over_counts["x"].isel({"_x": 0}, drop=True)
    # Calculate all required statistics
    res = []

    for _stat in stat:
        # Average Hellinger distance
        if _stat == "Hellinger":
            _distributions = Hellinger_distance(
                marginal_over_counts["y"],
                _y_vals.expand_dims(
                    {"bin_idx": marginal_over_counts.coords["bin_idx"]}
                ),
                exclude_dim=["bin_idx"],
            )
            _err = (
                prob_binned * _distributions["Hellinger_distance"] / prob_binned.sum()
            ).sum("bin_idx")
            res.append(
                xr.DataArray(_err.data, name="stat").expand_dims({"type": [_stat]})
            )

        # Average relative entropy
        elif _stat == "KL":
            _distributions = relative_entropy(
                marginal_over_counts["y"],
                _y_vals.expand_dims(
                    {"bin_idx": marginal_over_counts.coords["bin_idx"]}
                ),
                exclude_dim=["bin_idx"],
            )

            _err = (
                prob_binned
                * abs(_distributions["relative_entropy"])
                / prob_binned.sum("bin_idx")
            ).sum("bin_idx")
            res.append(
                xr.DataArray(_err.data, name="stat").expand_dims({"type": [_stat]})
            )
        else:
            # Integrate the statistic (e.g. the standard deviation) along x
            _err = stat_function(
                marginal_over_p, along_dim=["bin_idx"], x="x", y="y", stat=_stat
            )[_stat]
            res.append(
                xr.DataArray(
                    scipy.integrate.trapezoid(_err.data, _err.coords["_x"]), name="stat"
                ).expand_dims({"type": [_stat]})
            )

    return xr.concat(res, "type")
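
# A usage sketch for ``marginal_distribution_stats`` (hypothetical data;
# ``Hellinger_distance`` and ``relative_entropy`` are assumed to come from data_ops):
#
#   >>> preds = xr.DataArray(..., dims=("sample", "x"))
#   >>> probs = xr.DataArray(..., dims=("sample",))
#   >>> errs = marginal_distribution_stats(
#   ...     preds, probs, distance_to="MLE", stat=["Hellinger", "KL"]
#   ... )
#   >>> errs.sel(type="Hellinger")   # probability-weighted average Hellinger distance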