summarylogtreecommitdiffstats
path: root/change_tests_to_testing.patch
blob: 277c6b67e76f06c0caf85fccbfda1f71d5146905 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
diff --git a/xskillscore/tests/test_accessor_deterministic.py b/xskillscore/tests/test_accessor_deterministic.py
index d8a827d..3c0530e 100644
--- a/xskillscore/tests/test_accessor_deterministic.py
+++ b/xskillscore/tests/test_accessor_deterministic.py
@@ -1,6 +1,6 @@
 import pytest
 import xarray as xr
-from xarray.tests import assert_allclose
+from xarray.testing import assert_allclose
 
 from xskillscore.core.deterministic import (
     effective_sample_size,
diff --git a/xskillscore/tests/test_accessor_probabilistic.py b/xskillscore/tests/test_accessor_probabilistic.py
index c32e7df..6a8ebbb 100644
--- a/xskillscore/tests/test_accessor_probabilistic.py
+++ b/xskillscore/tests/test_accessor_probabilistic.py
@@ -2,7 +2,7 @@ import numpy as np
 import pytest
 import xarray as xr
 from scipy.stats import norm
-from xarray.tests import assert_allclose
+from xarray.testing import assert_allclose
 
 from xskillscore.core.probabilistic import (
     brier_score,
diff --git a/xskillscore/tests/test_deterministic.py b/xskillscore/tests/test_deterministic.py
index 8b44645..6c8955b 100644
--- a/xskillscore/tests/test_deterministic.py
+++ b/xskillscore/tests/test_deterministic.py
@@ -1,7 +1,7 @@
 import numpy as np
 import pytest
 import xarray as xr
-from xarray.tests import assert_allclose
+from xarray.testing import assert_allclose
 
 from xskillscore.core.deterministic import (
     _preprocess_dims,
diff --git a/xskillscore/tests/test_probabilistic.py b/xskillscore/tests/test_probabilistic.py
index 283122a..68db18d 100644
--- a/xskillscore/tests/test_probabilistic.py
+++ b/xskillscore/tests/test_probabilistic.py
@@ -7,7 +7,7 @@ from dask import is_dask_collection
 from scipy.stats import norm
 from sklearn.calibration import calibration_curve
 from sklearn.metrics import roc_auc_score, roc_curve
-from xarray.tests import assert_allclose, assert_identical
+from xarray.testing import assert_allclose, assert_identical
 
 from xskillscore.core.probabilistic import (
     brier_score,
diff --git a/xskillscore/tests/test_skipna_functionality.py b/xskillscore/tests/test_skipna_functionality.py
index 8c39efa..c6f9ed0 100644
--- a/xskillscore/tests/test_skipna_functionality.py
+++ b/xskillscore/tests/test_skipna_functionality.py
@@ -2,7 +2,7 @@ from typing import Callable, List
 
 import numpy as np
 import pytest
-from xarray.tests import assert_allclose, raise_if_dask_computes
+from xarray.testing import assert_allclose
 
 from xskillscore.core.deterministic import (
     linslope,
@@ -20,6 +20,41 @@ from xskillscore.core.deterministic import (
     spearman_r_p_value,
 )
 
+from contextlib import nullcontext
+try:
+    import dask
+except ImportError:
+    has_dask = False
+else:
+    has_dask = True
+
+
+class CountingScheduler:
+    """Simple dask scheduler counting the number of computes.
+
+    Reference: https://stackoverflow.com/questions/53289286/"""
+
+    def __init__(self, max_computes=0):
+        self.total_computes = 0
+        self.max_computes = max_computes
+
+    def __call__(self, dsk, keys, **kwargs):
+        self.total_computes += 1
+        if self.total_computes > self.max_computes:
+            raise RuntimeError(
+                "Too many computes. Total: %d > max: %d."
+                % (self.total_computes, self.max_computes)
+            )
+        return dask.get(dsk, keys, **kwargs)
+
+
+def raise_if_dask_computes(max_computes=0):
+    # return a dummy context manager so that this can be used for non-dask objects
+    if not has_dask:
+        return nullcontext()
+    scheduler = CountingScheduler(max_computes)
+    return dask.config.set(scheduler=scheduler)
+
 WEIGHTED_METRICS: List[Callable] = [
     linslope,
     pearson_r,