1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
|
diff --git a/pydantic_evals/pydantic_evals/evaluators/spec.py b/pydantic_evals/pydantic_evals/evaluators/spec.py
index 4ae841112..b20aa5957 100644
--- a/pydantic_evals/pydantic_evals/evaluators/spec.py
+++ b/pydantic_evals/pydantic_evals/evaluators/spec.py
@@ -64,8 +64,7 @@ class EvaluatorSpec(BaseModel):
return {}
@model_validator(mode='wrap')
- @classmethod
- def deserialize(cls, value: Any, handler: ModelWrapValidatorHandler[EvaluatorSpec]) -> EvaluatorSpec:
+ def deserialize(self, handler: ModelWrapValidatorHandler[EvaluatorSpec]) -> EvaluatorSpec:
"""Deserialize an EvaluatorSpec from various formats.
This validator handles the various short forms of evaluator specifications,
@@ -82,11 +81,11 @@ class EvaluatorSpec(BaseModel):
ValidationError: If the value cannot be deserialized.
"""
try:
- result = handler(value)
+ result = handler(self)
return result
except ValidationError as exc:
try:
- deserialized = _SerializedEvaluatorSpec.model_validate(value)
+ deserialized = _SerializedEvaluatorSpec.model_validate(self)
except ValidationError:
raise exc # raise the original error
return deserialized.to_evaluator_spec()
diff --git a/tests/evals/test_dataset.py b/tests/evals/test_dataset.py
index 31f91bfc0..4635c1b82 100644
--- a/tests/evals/test_dataset.py
+++ b/tests/evals/test_dataset.py
@@ -8,7 +8,6 @@ from typing import Any, Literal, cast
import pytest
import yaml
-from _pytest.python_api import RaisesContext
from dirty_equals import HasRepr, IsNumber
from inline_snapshot import snapshot
from pydantic import BaseModel, TypeAdapter
@@ -964,7 +963,7 @@ async def test_from_text_failure():
],
'evaluators': ['NotAnEvaluator'],
}
- with cast(RaisesContext[ExceptionGroup[Any]], pytest.raises(ExceptionGroup)) as exc_info:
+ with pytest.raises(ExceptionGroup) as exc_info:
Dataset[TaskInput, TaskOutput, TaskMetadata].from_text(json.dumps(dataset_dict))
assert exc_info.value == HasRepr(
repr(
@@ -994,7 +993,7 @@ async def test_from_text_failure():
],
'evaluators': ['LLMJudge'],
}
- with cast(RaisesContext[ExceptionGroup[Any]], pytest.raises(ExceptionGroup)) as exc_info:
+ with pytest.raises(ExceptionGroup) as exc_info:
Dataset[TaskInput, TaskOutput, TaskMetadata].from_text(json.dumps(dataset_dict))
assert exc_info.value == HasRepr( # pragma: lax no cover
repr(
diff --git a/tests/evals/test_utils.py b/tests/evals/test_utils.py
index 71219a308..8b6e06908 100644
--- a/tests/evals/test_utils.py
+++ b/tests/evals/test_utils.py
@@ -7,7 +7,6 @@ from functools import partial
from typing import Any, cast
import pytest
-from _pytest.python_api import RaisesContext
from dirty_equals import HasRepr
from ..conftest import try_import
@@ -144,7 +143,7 @@ async def test_task_group_gather_with_error():
return 3
tasks = [task1, task2, task3]
- with cast(RaisesContext[ExceptionGroup[Any]], pytest.raises(ExceptionGroup)) as exc_info:
+ with pytest.raises(ExceptionGroup) as exc_info:
await task_group_gather(tasks)
assert exc_info.value == HasRepr(
diff --git a/tests/models/test_fallback.py b/tests/models/test_fallback.py
index 5ab588ab5..006addf31 100644
--- a/tests/models/test_fallback.py
+++ b/tests/models/test_fallback.py
@@ -7,7 +7,6 @@ from datetime import timezone
from typing import Any, cast
import pytest
-from _pytest.python_api import RaisesContext
from dirty_equals import IsJson
from inline_snapshot import snapshot
from pydantic_core import to_json
@@ -298,7 +297,7 @@ async def test_first_failed_instrumented_stream(capfire: CaptureLogfire) -> None
def test_all_failed() -> None:
fallback_model = FallbackModel(failure_model, failure_model)
agent = Agent(model=fallback_model)
- with cast(RaisesContext[ExceptionGroup[Any]], pytest.raises(ExceptionGroup)) as exc_info:
+ with pytest.raises(ExceptionGroup) as exc_info:
agent.run_sync('hello')
assert 'All models from FallbackModel failed' in exc_info.value.args[0]
exceptions = exc_info.value.exceptions
@@ -321,7 +320,7 @@ def add_missing_response_model(spans: list[dict[str, Any]]) -> list[dict[str, An
def test_all_failed_instrumented(capfire: CaptureLogfire) -> None:
fallback_model = FallbackModel(failure_model, failure_model)
agent = Agent(model=fallback_model, instrument=True)
- with cast(RaisesContext[ExceptionGroup[Any]], pytest.raises(ExceptionGroup)) as exc_info:
+ with pytest.raises(ExceptionGroup) as exc_info:
agent.run_sync('hello')
assert 'All models from FallbackModel failed' in exc_info.value.args[0]
exceptions = exc_info.value.exceptions
@@ -488,7 +487,7 @@ async def test_first_failed_streaming() -> None:
async def test_all_failed_streaming() -> None:
fallback_model = FallbackModel(failure_model_stream, failure_model_stream)
agent = Agent(model=fallback_model)
- with cast(RaisesContext[ExceptionGroup[Any]], pytest.raises(ExceptionGroup)) as exc_info:
+ with pytest.raises(ExceptionGroup) as exc_info:
async with agent.run_stream('hello') as result:
[c async for c, _is_last in result.stream_responses(debounce_by=None)] # pragma: lax no cover
assert 'All models from FallbackModel failed' in exc_info.value.args[0]
|