|
2 | 2 | """Tests for Flow V1 → V2 API Migration.""" |
3 | 3 | from __future__ import annotations |
4 | 4 |
|
| 5 | +import uuid |
| 6 | + |
5 | 7 | import pytest |
6 | 8 |
|
7 | 9 | from openml._api.resources import FallbackProxy, FlowsV1, FlowsV2 |
@@ -81,37 +83,67 @@ def test_list_with_tag_limit_offset(self): |
81 | 83 | if len(flows_df) > 0: |
82 | 84 | assert "id" in flows_df.columns |
83 | 85 |
|
84 | | - @pytest.mark.uses_test_server() |
85 | | - def test_publish(self): |
86 | | - """Test publishing a sklearn flow using V1 API.""" |
87 | | - from openml_sklearn.extension import SklearnExtension |
88 | | - from sklearn.tree import ExtraTreeRegressor |
89 | | - clf = ExtraTreeRegressor() |
90 | | - extension = SklearnExtension() |
91 | | - dt_flow = extension.model_to_flow(clf) |
92 | | - published_flow = self.resource.publish(dt_flow) |
93 | | - assert isinstance(published_flow, OpenMLFlow) |
94 | | - assert getattr(published_flow, "id", None) is not None |
95 | | - |
96 | 86 | @pytest.mark.uses_test_server() |
97 | 87 | def test_delete(self): |
98 | 88 | """Test deleting a flow using V1 API.""" |
99 | 89 | from openml_sklearn.extension import SklearnExtension |
100 | 90 | from sklearn.tree import ExtraTreeRegressor |
| 91 | + |
101 | 92 | clf = ExtraTreeRegressor() |
102 | 93 | extension = SklearnExtension() |
103 | 94 | dt_flow = extension.model_to_flow(clf) |
| 95 | + |
| 96 | + # Check if flow exists, if not publish it |
104 | 97 | flow_id = self.resource.exists( |
105 | 98 | name=dt_flow.name, |
106 | | - external_version=dt_flow.external_version |
| 99 | + external_version=dt_flow.external_version, |
107 | 100 | ) |
| 101 | + |
| 102 | + if not flow_id: |
| 103 | + # Publish the flow first |
| 104 | + file_elements = dt_flow._get_file_elements() |
| 105 | + if "description" not in file_elements: |
| 106 | + file_elements["description"] = dt_flow._to_xml() |
| 107 | + |
| 108 | + flow_id = self.resource.publish(file_elements) |
| 109 | + |
| 110 | + # Now delete it |
108 | 111 | result = self.resource.delete(flow_id) |
109 | 112 | assert result is True |
| 113 | + |
| 114 | + # Verify it no longer exists |
110 | 115 | exists = self.resource.exists( |
111 | 116 | name=dt_flow.name, |
112 | | - external_version=dt_flow.external_version |
| 117 | + external_version=dt_flow.external_version, |
113 | 118 | ) |
114 | 119 | assert exists is False |
| 120 | + |
| 121 | + @pytest.mark.uses_test_server() |
| 122 | + def test_publish(self): |
| 123 | + """Test publishing a sklearn flow using V1 API.""" |
| 124 | + from openml_sklearn.extension import SklearnExtension |
| 125 | + from sklearn.tree import ExtraTreeRegressor |
| 126 | + |
| 127 | + clf = ExtraTreeRegressor() |
| 128 | + extension = SklearnExtension() |
| 129 | + dt_flow = extension.model_to_flow(clf) |
| 130 | + |
| 131 | + # Check if flow already exists |
| 132 | + flow_id = self.resource.exists( |
| 133 | + name=dt_flow.name, |
| 134 | + external_version=dt_flow.external_version, |
| 135 | + ) |
| 136 | + |
| 137 | + if not flow_id: |
| 138 | + file_elements = dt_flow._get_file_elements() |
| 139 | + if "description" not in file_elements: |
| 140 | + print("Adding description to flow XML") |
| 141 | + file_elements["description"] = dt_flow._to_xml() |
| 142 | + |
| 143 | + flow_id = self.resource.publish(file_elements) |
| 144 | + |
| 145 | + assert isinstance(flow_id, int) |
| 146 | + assert flow_id > 0 |
115 | 147 |
|
116 | 148 |
|
117 | 149 |
|
|
0 commit comments