summarylogtreecommitdiffstats
path: root/fix_expand_dims.patch
blob: 9c09fe2cdd670581646aa76de580908fd245d4b9 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
diff --git a/properscoring/_crps.py b/properscoring/_crps.py
index 8fbb276..beef9a0 100644
--- a/properscoring/_crps.py
+++ b/properscoring/_crps.py
@@ -219,12 +219,15 @@ def _crps_ensemble_vectorized(observations, forecasts, weights=1):
         observations = observations[..., np.newaxis]
         with suppress_warnings('Mean of empty slice'):
             score = np.nanmean(weights * abs(forecasts - observations), -1)
+        # to fit new version numpy
+        # see https://numpy.org/doc/stable/reference/generated/numpy.expand_dims.html
+        axis_tweak = lambda axis, arr: arr.ndim if axis > arr.ndim else 0 if axis < -arr.ndim - 1 else axis
         # insert new axes along last and second to last forecast dimensions so
         # forecasts_diff expands with the array broadcasting
-        forecasts_diff = (np.expand_dims(forecasts, -1) -
-                          np.expand_dims(forecasts, -2))
-        weights_matrix = (np.expand_dims(weights, -1) *
-                          np.expand_dims(weights, -2))
+        forecasts_diff = (np.expand_dims(forecasts, axis_tweak(-1, forecasts)) -
+                          np.expand_dims(forecasts, axis_tweak(-2, forecasts)))
+        weights_matrix = (np.expand_dims(weights, axis_tweak(-1, weights)) *
+                          np.expand_dims(weights, axis_tweak(-2, weights)))
         with suppress_warnings('Mean of empty slice'):
             score += -0.5 * np.nanmean(weights_matrix * abs(forecasts_diff),
                                        axis=(-2, -1))