Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
170 changes: 170 additions & 0 deletions evalbench/test/dataform_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,170 @@
"""Unit tests for DataformHelper utility."""

from typing import Generator
from unittest.mock import MagicMock, patch

from google.api_core import exceptions as api_exceptions
from google.cloud import dataform_v1beta1
import pytest
from util.dataform import DataformHelper

PROJECT_ID = "test-project"
LOCATION = "us-west4"
REPO_ID = "test-repo"
WORKSPACE_ID = "test-workspace"


@pytest.fixture(name="mock_client")
def fixture_mock_client() -> Generator[MagicMock, None, None]:
with patch("util.dataform.dataform_v1beta1.DataformClient") as mock_class:
mock_instance = MagicMock()
mock_class.return_value = mock_instance
yield mock_instance


@pytest.fixture(name="helper")
def fixture_helper(mock_client: MagicMock) -> DataformHelper:
del mock_client
return DataformHelper(PROJECT_ID, LOCATION)


def test_create_repository_success(
mock_client: MagicMock, helper: DataformHelper
):
mock_response = MagicMock()
mock_response.name = (
f"projects/{PROJECT_ID}/locations/{LOCATION}"
f"/repositories/{REPO_ID}"
)
mock_client.create_repository.return_value = mock_response

repo_name = helper.create_repository(REPO_ID)

assert repo_name == mock_response.name
mock_client.create_repository.assert_called_once()


def test_create_repository_generic_exception(
mock_client: MagicMock, helper: DataformHelper
):
mock_client.create_repository.side_effect = Exception("failed")

with pytest.raises(Exception) as exc_info:
helper.create_repository(REPO_ID)

assert "failed" in str(exc_info.value)
mock_client.create_repository.assert_called_once()


def test_create_workspace_success(
mock_client: MagicMock, helper: DataformHelper
):
mock_response = MagicMock()
mock_response.name = (
f"projects/{PROJECT_ID}/locations/{LOCATION}"
f"/repositories/{REPO_ID}/workspaces/{WORKSPACE_ID}"
)
mock_client.create_workspace.return_value = mock_response

workspace_name = helper.create_workspace(REPO_ID, WORKSPACE_ID)

assert workspace_name == mock_response.name
mock_client.create_workspace.assert_called_once()


def test_create_workspace_generic_exception(
mock_client: MagicMock, helper: DataformHelper
):
mock_client.create_workspace.side_effect = Exception("failed")

with pytest.raises(Exception) as exc_info:
helper.create_workspace(REPO_ID, WORKSPACE_ID)

assert "failed" in str(exc_info.value)
mock_client.create_workspace.assert_called_once()


def test_delete_workspace_success(
mock_client: MagicMock, helper: DataformHelper
):
helper.delete_workspace(REPO_ID, WORKSPACE_ID)

mock_client.delete_workspace.assert_called_once_with(
request={
"name": (
f"projects/{PROJECT_ID}/locations/{LOCATION}"
f"/repositories/{REPO_ID}/workspaces/{WORKSPACE_ID}"
)
}
)


def test_delete_workspace_not_found(
mock_client: MagicMock, helper: DataformHelper
):
mock_client.delete_workspace.side_effect = (
api_exceptions.NotFound("not found")
)

helper.delete_workspace(REPO_ID, WORKSPACE_ID)

mock_client.delete_workspace.assert_called_once()


def test_delete_workspace_exception(
mock_client: MagicMock, helper: DataformHelper
):
mock_client.delete_workspace.side_effect = Exception("failed")

with pytest.raises(Exception) as exc_info:
helper.delete_workspace(REPO_ID, WORKSPACE_ID)

assert "failed" in str(exc_info.value)
mock_client.delete_workspace.assert_called_once()


def test_delete_repository_success(
mock_client: MagicMock, helper: DataformHelper
):
# Mock list_workspaces to return two workspaces
ws1 = MagicMock()
ws1.name = (
f"projects/{PROJECT_ID}/locations/{LOCATION}"
f"/repositories/{REPO_ID}/workspaces/ws1"
)
ws2 = MagicMock()
ws2.name = (
f"projects/{PROJECT_ID}/locations/{LOCATION}"
f"/repositories/{REPO_ID}/workspaces/ws2"
)
mock_client.list_workspaces.return_value = [ws1, ws2]

# Patch the helper's own delete_workspace method to verify delegation
with patch.object(helper, "delete_workspace") as mock_delete_ws:
helper.delete_repository(REPO_ID)

mock_client.list_workspaces.assert_called_once()
assert mock_delete_ws.call_count == 2
mock_delete_ws.assert_any_call(REPO_ID, "ws1")
mock_delete_ws.assert_any_call(REPO_ID, "ws2")

mock_client.delete_repository.assert_called_once_with(
request={
"name": (
f"projects/{PROJECT_ID}/locations/{LOCATION}"
f"/repositories/{REPO_ID}"
),
"force": True,
}
)


def test_delete_repository_exception(
mock_client: MagicMock, helper: DataformHelper
):
mock_client.list_workspaces.side_effect = Exception("failed")

with pytest.raises(Exception) as exc_info:
helper.delete_repository(REPO_ID)

assert "failed" in str(exc_info.value)
148 changes: 148 additions & 0 deletions evalbench/util/dataform.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,148 @@
"""Utility for managing temporary GCP Dataform repositories and workspaces."""

import logging

from google.api_core import exceptions as api_exceptions
from google.cloud import dataform_v1beta1

logger = logging.getLogger(__name__)


class DataformHelper:
"""Helper class to interact with Google Cloud Dataform API."""

def __init__(self, project_id: str, location: str):
"""Initializes the Dataform client helper.

Args:
project_id: The GCP Project ID.
location: The GCP region (e.g. 'us-west4').
"""
self.client = dataform_v1beta1.DataformClient()
self.parent = f"projects/{project_id}/locations/{location}"

def create_repository(self, repository_id: str) -> str:
"""Creates a new Dataform repository in the project and location.

Args:
repository_id: The unique ID for the repository.

Returns:
The full resource path of the created repository.
"""
repository_path = f"{self.parent}/repositories/{repository_id}"
logger.info("Creating Dataform repository: %s", repository_path)

# We create a clean, empty repository object.
repository = dataform_v1beta1.Repository()

try:
response = self.client.create_repository(
request={
"parent": self.parent,
"repository_id": repository_id,
"repository": repository,
}
)
logger.info("Successfully created repository: %s", response.name)
return response.name
except Exception:
logger.exception(
"Failed to create repository: %s", repository_id
)
raise

def create_workspace(self, repository_id: str,
workspace_id: str) -> str:
"""Creates a new Dataform workspace inside the specified repository.

Args:
repository_id: The ID of the parent repository.
workspace_id: The unique ID for the workspace.

Returns:
The full resource path of the created workspace.
"""
repository_path = f"{self.parent}/repositories/{repository_id}"
workspace_path = f"{repository_path}/workspaces/{workspace_id}"
logger.info("Creating Dataform workspace: %s", workspace_path)

workspace = dataform_v1beta1.Workspace()

try:
response = self.client.create_workspace(
request={
"parent": repository_path,
"workspace_id": workspace_id,
"workspace": workspace,
}
)
logger.info("Successfully created workspace: %s", response.name)
return response.name
except Exception:
logger.exception(
"Failed to create workspace %s in repo %s",
workspace_id,
repository_id,
)
raise

def delete_workspace(self, repository_id: str,
workspace_id: str) -> None:
"""Deletes a Dataform workspace inside the specified repository.

Args:
repository_id: The ID of the parent repository.
workspace_id: The unique ID for the workspace.
"""
repository_path = f"{self.parent}/repositories/{repository_id}"
workspace_path = f"{repository_path}/workspaces/{workspace_id}"
logger.info("Deleting Dataform workspace: %s", workspace_path)

try:
self.client.delete_workspace(request={"name": workspace_path})
logger.info("Successfully deleted workspace: %s", workspace_path)
except api_exceptions.NotFound:
logger.warning("Workspace already deleted: %s", workspace_path)
except Exception:
logger.exception(
"Failed to delete workspace %s in repo %s",
workspace_id,
repository_id,
)
raise

def delete_repository(self, repository_id: str) -> None:
"""Deletes a Dataform repository and all its nested resources.

This performs a cascading delete by first programmatically deleting
all workspaces inside the repository, and then deleting the
repository itself with the force flag enabled.

Args:
repository_id: The ID of the repository to delete.
"""
repository_path = f"{self.parent}/repositories/{repository_id}"
logger.info("Deleting Dataform repository: %s", repository_path)

try:
workspaces = self.client.list_workspaces(
request={"parent": repository_path}
)
for ws in workspaces:
ws_id = ws.name.split("/")[-1]
self.delete_workspace(repository_id, ws_id)

self.client.delete_repository(
request={"name": repository_path, "force": True}
)
logger.info(
"Successfully deleted repository and nested resources: %s",
repository_path,
)
except Exception:
logger.exception(
"Failed to delete repository and nested resources: %s",
repository_id,
)
raise
Loading