Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 29 additions & 0 deletions flexmeasures/api/v3_0/sensors.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,29 @@
partial_sensor_schema = SensorSchema(partial=True, exclude=["generic_asset_id"])
annotation_schema = AnnotationSchema()


def regressors_loader(config: dict | None) -> list[Sensor]:
"""Extract regressor sensors from the forecasting config for permission checking.

:param config: Deserialized forecasting config (output of TrainPredictPipelineConfigSchema),
which already contains resolved Sensor objects for regressor fields.
:returns: Deduplicated list of regressor Sensor objects, or an empty list if no
config or no regressors are specified.
"""
if not config:
return []
return list(
{
sensor
for regressor_list in [
config.get("future_regressors", []),
config.get("past_regressors", []),
]
for sensor in regressor_list
}
)


# Create ForecasterParametersSchema OpenAPI compatible schema
EXCLUDED_FORECASTING_FIELDS = [
# todo: hide these in the config schema instead
Expand Down Expand Up @@ -1561,6 +1584,12 @@ def get_status(self, id, sensor):
as_kwargs=True,
)
@permission_required_for_context("create-children", ctx_arg_name="sensor_to_save")
@permission_required_for_context(
"read",
ctx_arg_name="config",
ctx_loader=regressors_loader,
pass_ctx_to_loader=True,
)
def trigger_forecast(self, id: int, **params):
"""
.. :quickref: Forecasts; Trigger forecasting job for one sensor
Expand Down
80 changes: 80 additions & 0 deletions flexmeasures/api/v3_0/tests/test_forecasting_api.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from datetime import timedelta

from flask import current_app
import isodate
import pytest
Expand All @@ -8,6 +10,9 @@
from flexmeasures.api.tests.utils import get_auth_token
from flexmeasures.data.services.forecasting import handle_forecasting_exception
from flexmeasures.data.models.forecasting.pipelines import TrainPredictPipeline
from flexmeasures.data import db
from flexmeasures.data.models.generic_assets import GenericAsset, GenericAssetType
from flexmeasures.data.models.time_series import Sensor


@pytest.mark.parametrize("requesting_user", ["test_admin_user@seita.nl"], indirect=True)
Expand Down Expand Up @@ -126,3 +131,78 @@ def test_trigger_and_fetch_forecasts(

# API should return exactly these most-recent beliefs
assert api_forecasts == expected_values


@pytest.mark.parametrize(
"regressor_field",
["future-regressors", "past-regressors", "regressors"],
)
@pytest.mark.parametrize(
"requesting_user", ["test_supplier_user_4@seita.nl"], indirect=True
)
def test_trigger_forecast_with_unreadable_regressor_returns_403(
app,
setup_roles_users_fresh_db,
setup_accounts_fresh_db,
requesting_user,
regressor_field,
):
"""Triggering a forecast that uses a regressor the requesting user cannot read must return 403."""

supplier_account = setup_accounts_fresh_db["Supplier"]
prosumer_account = setup_accounts_fresh_db["Prosumer"]

asset_type = GenericAssetType(name="test-asset-type-regressor-perm")
db.session.add(asset_type)

# Target sensor: owned by Supplier account – requesting user has create-children here
supplier_asset = GenericAsset(
name=f"supplier-target-asset-{regressor_field}",
generic_asset_type=asset_type,
owner=supplier_account,
)
db.session.add(supplier_asset)
target_sensor = Sensor(
name=f"supplier-target-sensor-{regressor_field}",
unit="kW",
event_resolution=timedelta(hours=1),
generic_asset=supplier_asset,
)
db.session.add(target_sensor)

# Regressor sensor: owned by Prosumer account – requesting user has no read access here
prosumer_asset = GenericAsset(
name=f"prosumer-private-regressor-asset-{regressor_field}",
generic_asset_type=asset_type,
owner=prosumer_account,
)
db.session.add(prosumer_asset)
private_regressor = Sensor(
name=f"prosumer-private-regressor-sensor-{regressor_field}",
unit="kW",
event_resolution=timedelta(hours=1),
generic_asset=prosumer_asset,
)
db.session.add(private_regressor)
db.session.commit()

client = app.test_client()
token = get_auth_token(client, "test_supplier_user_4@seita.nl", "testtest")

payload = {
"start": "2025-01-05T00:00:00+00:00",
"end": "2025-01-05T02:00:00+00:00",
"max-forecast-horizon": "PT1H",
"forecast-frequency": "PT1H",
"config": {
"train-start": "2025-01-01T00:00:00+00:00",
"retrain-frequency": "PT1H",
regressor_field: [private_regressor.id],
},
}

trigger_url = url_for("SensorAPI:trigger_forecast", id=target_sensor.id)
trigger_res = client.post(
trigger_url, json=payload, headers={"Authorization": token}
)
assert trigger_res.status_code == 403
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is there anything in the response message that we could check that let's the user know which sensor or which field (e.g. "config", or even better, regressor_field) was to blame?

9 changes: 8 additions & 1 deletion flexmeasures/auth/decorators.py
Original file line number Diff line number Diff line change
Expand Up @@ -200,10 +200,17 @@ def decorated_view(*args, **kwargs):
else:
context = context_from_args

check_access(context, permission)
_check_access_for_context(context, permission)

return fn(*args, **kwargs)

return decorated_view

return wrapper


def _check_access_for_context(context, permission: str):
"""Check access for a single context or for each context in a list."""
contexts = context if isinstance(context, list) else [context]
for c in contexts:
check_access(c, permission)
8 changes: 7 additions & 1 deletion flexmeasures/data/schemas/forecasting/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -504,7 +504,7 @@ def resolve_config( # noqa: C901

m_viewpoints = max(predict_period // forecast_frequency, 1)

return dict(
result = dict(
sensor=target_sensor,
model_save_dir=model_save_dir,
output_path=output_path,
Expand All @@ -519,13 +519,19 @@ def resolve_config( # noqa: C901
beliefs_before=data.get("belief_time"),
m_viewpoints=m_viewpoints,
)
# Pass through any additional keys declared in subclass schemas (e.g. config in ForecastingTriggerSchema)
for key in data:
if key not in result:
result[key] = data[key]
return result


class ForecastingTriggerSchema(ForecasterParametersSchema):

config = fields.Nested(
TrainPredictPipelineConfigSchema(),
required=False,
load_default={},
metadata={
"description": "Changing any of these will result in a new data source ID."
},
Expand Down
Loading