1- # License: BSD 3-Clause
21from __future__ import annotations
32
43import pytest
54import pandas as pd
65from openml ._api .resources .task import TaskV1API , TaskV2API
6+ from openml ._api .resources .base .fallback import FallbackProxy
77from openml .exceptions import OpenMLNotSupportedError
88from openml .testing import TestAPIBase
9- from openml .tasks .task import TaskType
109from openml .enums import APIVersion
10+ from openml .tasks .task import TaskType
11+
1112
12- class TestTaskV1 (TestAPIBase ):
13+ class TestTaskV1API (TestAPIBase ):
1314 def setUp (self ):
1415 super ().setUp ()
15- self .resource = TaskV1API (self .http_clients [APIVersion .V1 ])
16+ self .client = self .http_clients [APIVersion .V1 ]
17+ self .task = TaskV1API (self .client )
1618
1719 @pytest .mark .uses_test_server ()
1820 def test_list_tasks (self ):
1921 """Verify V1 list endpoint returns a populated DataFrame."""
20- tasks_df = self .resource .list (limit = 5 , offset = 0 )
22+ tasks_df = self .task .list (limit = 5 , offset = 0 )
2123 assert isinstance (tasks_df , pd .DataFrame )
2224 assert not tasks_df .empty
2325 assert "tid" in tasks_df .columns
2426
25- @pytest .mark .uses_test_server ()
26- def test_estimation_procedure_list (self ):
27- """Verify that estimation procedure list endpoint works."""
28- procs = self .resource ._get_estimation_procedure_list ()
29- assert isinstance (procs , list )
30- assert len (procs ) > 0
31- assert "id" in procs [0 ]
32-
33- class TestTaskV2 (TestAPIBase ):
27+ class TestTaskV2API (TestAPIBase ):
3428 def setUp (self ):
3529 super ().setUp ()
36- self .resource = TaskV2API (self .http_clients [APIVersion .V2 ])
30+ self .client = self .http_clients [APIVersion .V2 ]
31+ self .task = TaskV2API (self .client )
3732
3833 @pytest .mark .uses_test_server ()
3934 def test_list_tasks (self ):
4035 """Verify V2 list endpoint returns a populated DataFrame."""
4136 with pytest .raises (OpenMLNotSupportedError ):
42- self .resource .list (limit = 5 , offset = 0 )
37+ self .task .list (limit = 5 , offset = 0 )
4338
4439class TestTasksCombined (TestAPIBase ):
4540 def setUp (self ):
4641 super ().setUp ()
47- self .v1_resource = TaskV1API (self .http_clients [APIVersion .V1 ])
48- self .v2_resource = TaskV2API (self .http_clients [APIVersion .V2 ])
42+ self .v1_client = self .http_clients [APIVersion .V1 ]
43+ self .v2_client = self .http_clients [APIVersion .V2 ]
44+ self .task_v1 = TaskV1API (self .v1_client )
45+ self .task_v2 = TaskV2API (self .v2_client )
46+ self .task_fallback = FallbackProxy (self .task_v1 , self .task_v2 )
4947
5048 def _get_first_tid (self , task_type : TaskType ) -> int :
5149 """Helper to find an existing task ID for a given type using the V1 resource."""
52- tasks = self .v1_resource .list (limit = 1 , offset = 0 , task_type = task_type )
50+ tasks = self .task_v1 .list (limit = 1 , offset = 0 , task_type = task_type )
5351 if tasks .empty :
5452 pytest .skip (f"No tasks of type { task_type } found on test server." )
5553 return int (tasks .iloc [0 ]["tid" ])
5654
5755 @pytest .mark .uses_test_server ()
58- def test_v2_get_task (self ):
59- """Verify that we can get a task from V2 API using a task ID found via V1."""
56+ def test_get_matches (self ):
57+ """Verify that we can get a task from V2 API and it matches V1."""
58+ # Refactored to match the 'test_get_matches' style from Reference
59+ tid = self ._get_first_tid (TaskType .SUPERVISED_CLASSIFICATION )
60+
61+ output_v1 = self .task_v1 .get (tid )
62+ output_v2 = self .task_v2 .get (tid )
63+
64+ assert int (output_v1 .task_id ) == tid
65+ assert int (output_v2 .task_id ) == tid
66+ assert output_v1 .task_id == output_v2 .task_id
67+ assert output_v1 .task_type == output_v2 .task_type
68+
69+ @pytest .mark .uses_test_server ()
70+ def test_get_fallback (self ):
71+ """Verify the fallback proxy works for retrieving tasks."""
6072 tid = self ._get_first_tid (TaskType .SUPERVISED_CLASSIFICATION )
61- task_v1 = self .v1_resource .get (tid )
62- task_v2 = self .v2_resource .get (tid )
63- assert int (task_v1 .task_id ) == tid
64- assert int (task_v2 .task_id ) == tid
73+ output_fallback = self .task_fallback .get (tid )
74+ assert int (output_fallback .task_id ) == tid
0 commit comments