Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 26 additions & 0 deletions api/experimentation/migrations/0009_add_rollout_segment.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
# Generated by Django 5.2.14 on 2026-06-19 09:59

import django.db.models.deletion
from django.db import migrations, models


class Migration(migrations.Migration):

dependencies = [
("experimentation", "0008_experiment_results"),
("segments", "0030_add_default_to_segment_version"),
]

operations = [
migrations.AddField(
model_name="experiment",
name="rollout_segment",
field=models.OneToOneField(
blank=True,
null=True,
on_delete=django.db.models.deletion.SET_NULL,
related_name="experiment_rollout",
to="segments.segment",
),
),
]
7 changes: 7 additions & 0 deletions api/experimentation/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,13 @@ class Experiment(LifecycleModelMixin, SoftDeleteExportableModel): # type: ignor
updated_at = models.DateTimeField(auto_now=True)
started_at = models.DateTimeField(null=True, blank=True)
ended_at = models.DateTimeField(null=True, blank=True)
rollout_segment = models.OneToOneField(
"segments.Segment",
on_delete=models.SET_NULL,
related_name="experiment_rollout",
null=True,
blank=True,
)

class Meta:
constraints = [
Expand Down
75 changes: 75 additions & 0 deletions api/experimentation/serializers.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from django.db.models import QuerySet
from rest_framework import serializers

from core.dataclasses import AuthorData
from environments.models import Environment
from experimentation.dataclasses import WarehouseEventStats
from experimentation.metric_definitions import validate_metric_definition
Expand All @@ -18,14 +19,24 @@
WarehouseConnection,
WarehouseType,
)
from experimentation.services import (
create_experiment_rollout,
get_experiment_rollout,
)
from experimentation.types import (
SNOWFLAKE_DEFAULTS,
MetricExperimentResult,
SnowflakeConfig,
)
from features.feature_states.serializers import (
FeatureValueSerializer,
MultivariateValueSerializer,
validate_multivariate_state_values,
)
from features.feature_types import MULTIVARIATE
from features.models import Feature
from features.multivariate.serializers import NestedMultivariateFeatureOptionSerializer
from features.versioning.dataclasses import MultivariateValueChangeSet


class WarehouseConnectionSerializer(serializers.ModelSerializer): # type: ignore[type-arg]
Expand Down Expand Up @@ -207,6 +218,35 @@ class ExperimentMetricInlineSerializer(serializers.Serializer): # type: ignore[
expected_direction = serializers.ChoiceField(choices=ExpectedDirection.choices)


class ExperimentRolloutSerializer(serializers.Serializer): # type: ignore[type-arg]
enabled = serializers.BooleanField(required=True)
rollout_percentage = serializers.FloatField(
required=True, min_value=0, max_value=100
)
feature_state_value = FeatureValueSerializer(required=True)
multivariate_feature_state_values = MultivariateValueSerializer(
many=True, required=False
)

@staticmethod
def to_service_kwargs(data: dict[str, Any], request: Any) -> dict[str, Any]:
value = data["feature_state_value"]
return {
"enabled": data["enabled"],
"rollout_percentage": data["rollout_percentage"],
"feature_state_value": value["value"],
"value_type": value["type"],
"multivariate_values": [
MultivariateValueChangeSet(
multivariate_feature_option_id=mv["multivariate_feature_option"],
percentage_allocation=mv["percentage_allocation"],
)
for mv in data.get("multivariate_feature_state_values", [])
],
"author": AuthorData.from_request(request),
}


class ExperimentSerializer(serializers.ModelSerializer): # type: ignore[type-arg]
# Annotated with the common base type so ExperimentListSerializer can
# override the field with a read-only representation.
Expand All @@ -215,6 +255,9 @@ class ExperimentSerializer(serializers.ModelSerializer): # type: ignore[type-ar
required=False,
write_only=True,
)
experiment_rollout: Any = ExperimentRolloutSerializer(
required=False, write_only=True
)

class Meta:
model = Experiment
Expand All @@ -225,6 +268,7 @@ class Meta:
"hypothesis",
"status",
"metrics",
"experiment_rollout",
"created_at",
"updated_at",
"started_at",
Expand Down Expand Up @@ -260,9 +304,28 @@ def validate(self, attrs: dict[str, Any]) -> dict[str, Any]:
raise serializers.ValidationError(
{"metrics": "Cannot change the metrics of an existing experiment."}
)
if self.instance is not None and "experiment_rollout" in attrs:
raise serializers.ValidationError(
{
"experiment_rollout": (
"Cannot change the rollout via this endpoint; "
"use the rollout endpoint instead."
)
}
)
self._validate_metrics(attrs.get("metrics") or [])
self._validate_rollout(attrs)
return attrs

def _validate_rollout(self, attrs: dict[str, Any]) -> None:
rollout = attrs.get("experiment_rollout")
feature = attrs.get("feature")
if not rollout or feature is None:
return
validate_multivariate_state_values(
feature, rollout.get("multivariate_feature_state_values", [])
)

def _validate_metrics(self, metrics: list[dict[str, Any]]) -> None:
metric_ids = [entry["metric"].id for entry in metrics]
if len(metric_ids) != len(set(metric_ids)):
Expand All @@ -272,6 +335,7 @@ def _validate_metrics(self, metrics: list[dict[str, Any]]) -> None:

def create(self, validated_data: dict[str, Any]) -> Experiment:
metrics: list[dict[str, Any]] = validated_data.pop("metrics", [])
rollout: dict[str, Any] | None = validated_data.pop("experiment_rollout", None)
with transaction.atomic():
experiment: Experiment = super().create(validated_data)
ExperimentMetric.objects.bulk_create(
Expand All @@ -282,6 +346,13 @@ def create(self, validated_data: dict[str, Any]) -> Experiment:
)
for entry in metrics
)
if rollout is not None:
create_experiment_rollout(
experiment,
**ExperimentRolloutSerializer.to_service_kwargs(
rollout, self.context["request"]
),
)
return experiment


Expand Down Expand Up @@ -338,6 +409,10 @@ class ExperimentListSerializer(ExperimentSerializer):
many=True,
read_only=True,
)
experiment_rollout = serializers.SerializerMethodField()

def get_experiment_rollout(self, experiment: Experiment) -> dict[str, Any] | None:
return get_experiment_rollout(experiment)


class ExperimentExposuresSerializer(serializers.ModelSerializer): # type: ignore[type-arg]
Expand Down
128 changes: 128 additions & 0 deletions api/experimentation/services.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,12 @@
from django.conf import settings
from django.db.models import Q
from django.utils import timezone
from flag_engine.segments.constants import PERCENTAGE_SPLIT
from rest_framework.exceptions import ValidationError

from audit.models import AuditLog
from audit.related_object_type import RelatedObjectType
from core.dataclasses import AuthorData
from experimentation.constants import (
CONTROL_VARIANT_KEY,
EXPERIMENT_FLAG,
Expand Down Expand Up @@ -50,14 +53,21 @@
srm_p_value,
)
from features.models import FeatureState
from features.value_types import BOOLEAN, INTEGER, STRING
from features.versioning.dataclasses import FlagChangeSet, MultivariateValueChangeSet
from features.versioning.versioning_service import update_flag
from integrations.flagsmith.client import get_openfeature_client
from segments.models import Condition, Segment, SegmentRule

_ROLLOUT_VALUE_TYPE = {INTEGER: "integer", STRING: "string", BOOLEAN: "boolean"}

if typing.TYPE_CHECKING:
from collections.abc import Sequence
from datetime import datetime

from experimentation.models import Experiment, Metric, WarehouseConnection
from experimentation.types import ExposureGranularity
from features.feature_states.models import FeatureValueType
from organisations.models import Organisation
from users.models import FFAdminUser

Expand Down Expand Up @@ -512,6 +522,124 @@ def transition_experiment_status(
return experiment


def _create_rollout_segment(
experiment: Experiment, rollout_percentage: float
) -> Segment:
segment: Segment = Segment.objects.create(
name=f"experiment-{experiment.id}-rollout",
project=experiment.feature.project,
is_system_segment=True,
)
rule = SegmentRule.objects.create(segment=segment, type=SegmentRule.ALL_RULE)
Condition.objects.create(
rule=rule,
operator=PERCENTAGE_SPLIT,
property="$.identity.key",
value=str(rollout_percentage),
)
return segment


def create_experiment_rollout(
experiment: Experiment,
*,
enabled: bool,
rollout_percentage: float,
feature_state_value: str,
value_type: FeatureValueType,
multivariate_values: list[MultivariateValueChangeSet],
author: AuthorData,
) -> None:
segment = _create_rollout_segment(experiment, rollout_percentage)
experiment.rollout_segment = segment
experiment.save()
update_flag(
experiment.environment,
experiment.feature,
FlagChangeSet(
author=author,
enabled=enabled,
feature_state_value=feature_state_value,
type_=value_type,
segment_id=segment.id,
multivariate_values=multivariate_values,
),
)


def update_experiment_rollout(
experiment: Experiment,
*,
enabled: bool,
rollout_percentage: float,
feature_state_value: str,
value_type: FeatureValueType,
multivariate_values: list[MultivariateValueChangeSet],
author: AuthorData,
) -> None:
if experiment.status in (ExperimentStatus.RUNNING, ExperimentStatus.COMPLETED):
raise ValidationError(
f"Cannot update the rollout of a {experiment.status} experiment."
)
segment = experiment.rollout_segment
if segment is None:
raise ValidationError("Experiment has no rollout to update.")

condition = Condition.objects.get(rule__segment=segment, operator=PERCENTAGE_SPLIT)
condition.value = str(rollout_percentage)
condition.save()
update_flag(
experiment.environment,
experiment.feature,
FlagChangeSet(
author=author,
enabled=enabled,
feature_state_value=feature_state_value,
type_=value_type,
segment_id=segment.id,
multivariate_values=multivariate_values,
),
)


def get_experiment_rollout(experiment: Experiment) -> dict[str, typing.Any] | None:
segment_id = experiment.rollout_segment_id
if segment_id is None:
return None

feature_state = (
experiment.feature.feature_states.filter(
environment=experiment.environment,
feature_segment__segment_id=segment_id,
identity__isnull=True,
)
.order_by("-live_from", "-version")
.first()
)
if feature_state is None:
return None

condition = Condition.objects.get(
rule__segment_id=segment_id, operator=PERCENTAGE_SPLIT
)
value = feature_state.feature_state_value
return {
"enabled": feature_state.enabled,
"rollout_percentage": float(condition.value or 0),
"feature_state_value": {
"type": _ROLLOUT_VALUE_TYPE.get(value.type or STRING, "string"),
"value": str(value.value),
},
"multivariate_feature_state_values": [
{
"multivariate_feature_option": mv.multivariate_feature_option_id,
"percentage_allocation": mv.percentage_allocation,
}
for mv in feature_state.multivariate_feature_state_values.all()
],
}


def mark_warehouse_pending_connection(
connection: WarehouseConnection,
) -> WarehouseConnection:
Expand Down
Loading
Loading