summarylogtreecommitdiffstats
path: root/fix-tests.patch
blob: 099da39417b8698b25f03b59721d2a97fdb03763 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
diff --git a/pydantic_evals/pydantic_evals/evaluators/spec.py b/pydantic_evals/pydantic_evals/evaluators/spec.py
index 4ae841112..b20aa5957 100644
--- a/pydantic_evals/pydantic_evals/evaluators/spec.py
+++ b/pydantic_evals/pydantic_evals/evaluators/spec.py
@@ -64,8 +64,7 @@ class EvaluatorSpec(BaseModel):
         return {}
 
     @model_validator(mode='wrap')
-    @classmethod
-    def deserialize(cls, value: Any, handler: ModelWrapValidatorHandler[EvaluatorSpec]) -> EvaluatorSpec:
+    def deserialize(self, handler: ModelWrapValidatorHandler[EvaluatorSpec]) -> EvaluatorSpec:
         """Deserialize an EvaluatorSpec from various formats.
 
         This validator handles the various short forms of evaluator specifications,
@@ -82,11 +81,11 @@ class EvaluatorSpec(BaseModel):
             ValidationError: If the value cannot be deserialized.
         """
         try:
-            result = handler(value)
+            result = handler(self)
             return result
         except ValidationError as exc:
             try:
-                deserialized = _SerializedEvaluatorSpec.model_validate(value)
+                deserialized = _SerializedEvaluatorSpec.model_validate(self)
             except ValidationError:
                 raise exc  # raise the original error
             return deserialized.to_evaluator_spec()
diff --git a/tests/evals/test_dataset.py b/tests/evals/test_dataset.py
index 31f91bfc0..4635c1b82 100644
--- a/tests/evals/test_dataset.py
+++ b/tests/evals/test_dataset.py
@@ -8,7 +8,6 @@ from typing import Any, Literal, cast
 
 import pytest
 import yaml
-from _pytest.python_api import RaisesContext
 from dirty_equals import HasRepr, IsNumber
 from inline_snapshot import snapshot
 from pydantic import BaseModel, TypeAdapter
@@ -964,7 +963,7 @@ async def test_from_text_failure():
         ],
         'evaluators': ['NotAnEvaluator'],
     }
-    with cast(RaisesContext[ExceptionGroup[Any]], pytest.raises(ExceptionGroup)) as exc_info:
+    with pytest.raises(ExceptionGroup) as exc_info:
         Dataset[TaskInput, TaskOutput, TaskMetadata].from_text(json.dumps(dataset_dict))
     assert exc_info.value == HasRepr(
         repr(
@@ -994,7 +993,7 @@ async def test_from_text_failure():
         ],
         'evaluators': ['LLMJudge'],
     }
-    with cast(RaisesContext[ExceptionGroup[Any]], pytest.raises(ExceptionGroup)) as exc_info:
+    with pytest.raises(ExceptionGroup) as exc_info:
         Dataset[TaskInput, TaskOutput, TaskMetadata].from_text(json.dumps(dataset_dict))
     assert exc_info.value == HasRepr(  # pragma: lax no cover
         repr(
diff --git a/tests/evals/test_utils.py b/tests/evals/test_utils.py
index 71219a308..8b6e06908 100644
--- a/tests/evals/test_utils.py
+++ b/tests/evals/test_utils.py
@@ -7,7 +7,6 @@ from functools import partial
 from typing import Any, cast
 
 import pytest
-from _pytest.python_api import RaisesContext
 from dirty_equals import HasRepr
 
 from ..conftest import try_import
@@ -144,7 +143,7 @@ async def test_task_group_gather_with_error():
         return 3
 
     tasks = [task1, task2, task3]
-    with cast(RaisesContext[ExceptionGroup[Any]], pytest.raises(ExceptionGroup)) as exc_info:
+    with pytest.raises(ExceptionGroup) as exc_info:
         await task_group_gather(tasks)
 
     assert exc_info.value == HasRepr(
diff --git a/tests/models/test_fallback.py b/tests/models/test_fallback.py
index 5ab588ab5..006addf31 100644
--- a/tests/models/test_fallback.py
+++ b/tests/models/test_fallback.py
@@ -7,7 +7,6 @@ from datetime import timezone
 from typing import Any, cast
 
 import pytest
-from _pytest.python_api import RaisesContext
 from dirty_equals import IsJson
 from inline_snapshot import snapshot
 from pydantic_core import to_json
@@ -298,7 +297,7 @@ async def test_first_failed_instrumented_stream(capfire: CaptureLogfire) -> None
 def test_all_failed() -> None:
     fallback_model = FallbackModel(failure_model, failure_model)
     agent = Agent(model=fallback_model)
-    with cast(RaisesContext[ExceptionGroup[Any]], pytest.raises(ExceptionGroup)) as exc_info:
+    with pytest.raises(ExceptionGroup) as exc_info:
         agent.run_sync('hello')
     assert 'All models from FallbackModel failed' in exc_info.value.args[0]
     exceptions = exc_info.value.exceptions
@@ -321,7 +320,7 @@ def add_missing_response_model(spans: list[dict[str, Any]]) -> list[dict[str, An
 def test_all_failed_instrumented(capfire: CaptureLogfire) -> None:
     fallback_model = FallbackModel(failure_model, failure_model)
     agent = Agent(model=fallback_model, instrument=True)
-    with cast(RaisesContext[ExceptionGroup[Any]], pytest.raises(ExceptionGroup)) as exc_info:
+    with pytest.raises(ExceptionGroup) as exc_info:
         agent.run_sync('hello')
     assert 'All models from FallbackModel failed' in exc_info.value.args[0]
     exceptions = exc_info.value.exceptions
@@ -488,7 +487,7 @@ async def test_first_failed_streaming() -> None:
 async def test_all_failed_streaming() -> None:
     fallback_model = FallbackModel(failure_model_stream, failure_model_stream)
     agent = Agent(model=fallback_model)
-    with cast(RaisesContext[ExceptionGroup[Any]], pytest.raises(ExceptionGroup)) as exc_info:
+    with pytest.raises(ExceptionGroup) as exc_info:
         async with agent.run_stream('hello') as result:
             [c async for c, _is_last in result.stream_responses(debounce_by=None)]  # pragma: lax no cover
     assert 'All models from FallbackModel failed' in exc_info.value.args[0]