diff --git a/crates/coglet-python/src/input.rs b/crates/coglet-python/src/input.rs index 0c23336495..61f8f17d6e 100644 --- a/crates/coglet-python/src/input.rs +++ b/crates/coglet-python/src/input.rs @@ -79,19 +79,21 @@ pub fn prepare_input( func: &Bound<'_, PyAny>, ) -> PyResult { let fields = classify_fields(py, func)?; - coerce_url_strings(py, input, &fields)?; + coerce_typed_inputs(py, input, &fields)?; let cleanup_paths = download_url_paths_into_dict(py, input)?; Ok(PreparedInput::new(input.clone().unbind(), cleanup_paths)) } -/// Whether a field should be coerced as a `cog.File` or `cog.Path`. +/// Whether a field should be coerced as a `cog.File`, `cog.Path`, or `cog.Secret`. #[derive(Clone, Copy, Debug, PartialEq, Eq)] enum FieldKind { File, Path, + Secret, } -/// Result of inspecting a Python function's type annotations for File and Path fields. +/// Result of inspecting a Python function's type annotations for File, Path, +/// and Secret fields. #[derive(Debug)] struct FieldClassification { /// Fields typed as `cog.File` (or `list[File]`, `Optional[File]`, etc.) @@ -100,22 +102,31 @@ struct FieldClassification { /// Fields typed as `cog.Path` (or `list[Path]`, `Optional[Path]`, etc.) /// These use `Path.validate()` for URL coercion. path_fields: HashSet, + /// Fields typed as `cog.Secret` (only direct `Secret`, `Optional[Secret]`, + /// or `Secret | None`; `list[Secret]` is intentionally not classified). + /// These wrap plain string values in `cog.types.Secret`. + secret_fields: HashSet, } /// Inspect a Python function's type annotations to find parameters typed as -/// `cog.File` or `cog.Path` (including `list[...]`, `Optional[...]`, -/// `... | None`, etc.). +/// `cog.File`, `cog.Path`, or `cog.Secret`. For File and Path this includes +/// the `list[...]`, `Optional[...]`, and `... | None` forms. For Secret only +/// the direct `Secret`, `Optional[Secret]`, and `Secret | None` forms are +/// classified; `list[Secret]` is intentionally not classified (see the note +/// in `classify_type` below). /// -/// Returns a `FieldClassification` so that `coerce_url_strings` only coerces -/// fields that are actually File- or Path-typed, leaving `str` and other types -/// untouched. +/// Returns a `FieldClassification` so that `coerce_typed_inputs` only coerces +/// fields that are actually File-, Path-, or Secret-typed, leaving `str` and +/// other types untouched. fn classify_fields(py: Python<'_>, func: &Bound<'_, PyAny>) -> PyResult { let mut file_fields = HashSet::new(); let mut path_fields = HashSet::new(); + let mut secret_fields = HashSet::new(); let cog_types = py.import("cog.types")?; let cog_file_class = cog_types.getattr("File")?; let cog_path_class = cog_types.getattr("Path")?; + let cog_secret_class = cog_types.getattr("Secret")?; // typing.get_type_hints resolves string annotations and handles forward refs let typing = py.import("typing")?; @@ -134,11 +145,14 @@ fn classify_fields(py: Python<'_>, func: &Bound<'_, PyAny>) -> PyResult| -> PyResult> { if ty.is(&cog_file_class) { return Ok(Some(FieldKind::File)); @@ -146,6 +160,9 @@ fn classify_fields(py: Python<'_>, func: &Bound<'_, PyAny>) -> PyResult, func: &Bound<'_, PyAny>) -> PyResult { path_fields.insert(name_str); } + FieldKind::Secret => { + secret_fields.insert(name_str); + } } continue; } @@ -211,6 +231,9 @@ fn classify_fields(py: Python<'_>, func: &Bound<'_, PyAny>) -> PyResult { path_fields.insert(name_str.clone()); } + FieldKind::Secret => { + secret_fields.insert(name_str.clone()); + } } break; } @@ -219,34 +242,42 @@ fn classify_fields(py: Python<'_>, func: &Bound<'_, PyAny>) -> PyResult `File.validate()` -> returns IO-like `URLFile` +/// (URL strings only: http://, https://, data:) /// - `Path`-typed fields -> `Path.validate()` -> returns `URLPath` (downloaded later) +/// (URL strings only) +/// - `Secret`-typed fields -> `Secret(value)` -> wraps any plain string value /// -/// Only fields whose declared type is `File` or `Path` (including `list[...]`, -/// `Optional[...]`, etc.) are coerced. Fields typed as `str` or any other type -/// are left untouched, even if the value looks like a URL. +/// Only fields whose declared type is `File`, `Path`, or `Secret` are coerced. +/// For File and Path this includes the `list[...]`, `Optional[...]`, and +/// `... | None` forms; for Secret only the direct `Secret`, `Optional[Secret]`, +/// and `Secret | None` forms are coerced (`list[Secret]` is intentionally not). +/// Fields typed as `str` or any other type are left untouched, even if the +/// value looks like a URL. /// /// This replaces the type coercion that `_adt.py`'s `PrimitiveType.normalize()` /// previously performed. -fn coerce_url_strings( +fn coerce_typed_inputs( py: Python<'_>, payload: &Bound<'_, PyDict>, fields: &FieldClassification, @@ -254,10 +285,21 @@ fn coerce_url_strings( let cog_types = py.import("cog.types")?; let path_validate = cog_types.getattr("Path")?.getattr("validate")?; let file_validate = cog_types.getattr("File")?.getattr("validate")?; + let secret_class = cog_types.getattr("Secret")?; for (key, value) in payload.iter() { let key_str: String = key.extract().unwrap_or_default(); + // Secret-typed fields: wrap the plain string value in cog.types.Secret. + // Unlike File/Path, this is not URL-conditional -- any string is wrapped. + if fields.secret_fields.contains(&key_str) { + if value.extract::().is_ok() { + let coerced = secret_class.call1((&value,))?; + payload.set_item(&key, coerced)?; + } + continue; + } + // Only coerce fields that are declared as File or Path types. // str-typed and other fields are left as-is even if values look like URLs. let validate = if fields.file_fields.contains(&key_str) { @@ -630,6 +672,51 @@ mod tests { "non-Path types incorrectly detected as path: {:?}", c.path_fields ); + assert!( + c.secret_fields.is_empty(), + "non-Secret types incorrectly detected as secret: {:?}", + c.secret_fields + ); + } + + #[test] + #[ignore] // Requires cog Python package in PYTHONPATH + fn detect_direct_secret() { + let c = classify_for("from cog import Secret\ndef func(a: Secret, b: str): ..."); + assert!( + c.secret_fields.contains("a"), + "direct Secret annotation not detected" + ); + assert!( + !c.secret_fields.contains("b"), + "str incorrectly flagged as Secret" + ); + assert!( + !c.file_fields.contains("a") && !c.path_fields.contains("a"), + "Secret incorrectly flagged as File or Path" + ); + } + + #[test] + #[ignore] // Requires cog Python package in PYTHONPATH + fn detect_optional_secret() { + let c = classify_for( + "from typing import Optional\nfrom cog import Secret\ndef func(a: Optional[Secret]): ...", + ); + assert!( + c.secret_fields.contains("a"), + "Optional[Secret] annotation not detected" + ); + } + + #[test] + #[ignore] // Requires cog Python package in PYTHONPATH + fn detect_pep604_secret_or_none() { + let c = classify_for("from cog import Secret\ndef func(a: Secret | None): ..."); + assert!( + c.secret_fields.contains("a"), + "Secret | None annotation not detected as Secret" + ); } #[test] diff --git a/crates/coglet-python/tests/test_coglet.py b/crates/coglet-python/tests/test_coglet.py index cdc37a3686..57dcdd40a8 100644 --- a/crates/coglet-python/tests/test_coglet.py +++ b/crates/coglet-python/tests/test_coglet.py @@ -124,6 +124,55 @@ def predict(self, name: str = "World") -> str: return predictor +@pytest.fixture +def secret_predictor(tmp_path: Path) -> Path: + """Create a predictor exercising Secret input coercion. + + predict() echoes back, for each parameter, the runtime type name and (for + secrets) the unwrapped value so tests can assert how each annotation is + coerced. Each field is encoded as ``name=|`` and fields are + joined with ``;``. For ``None`` values the unwrapped value is the literal + string ``None``. + """ + predictor = tmp_path / "predict.py" + predictor.write_text(""" +from typing import Optional +from cog import BasePredictor, Secret + +class Predictor(BasePredictor): + def setup(self): + pass + + def predict( + self, + api_token: Secret, + plain: str = "", + opt_secret: Optional[Secret] = None, + pep604_secret: Secret | None = None, + ) -> str: + def describe(value): + type_name = type(value).__name__ + if isinstance(value, Secret): + return f"{type_name}|{value.get_secret_value()}" + return f"{type_name}|{value}" + + return ";".join([ + f"api_token={describe(api_token)}", + f"plain={describe(plain)}", + f"opt_secret={describe(opt_secret)}", + f"pep604_secret={describe(pep604_secret)}", + ]) +""") + + # Create cog.yaml + cog_yaml = tmp_path / "cog.yaml" + cog_yaml.write_text(""" +predict: "predict.py:Predictor" +""") + + return predictor + + @pytest.fixture def generator_predictor(tmp_path: Path) -> Path: """Create a generator predictor.""" @@ -471,6 +520,66 @@ def test_includes_predict_time(self, sync_predictor: Path): assert result["metrics"]["predict_time"] >= 0 +class TestSecretInput: + """Tests for cog.Secret input coercion.""" + + def _fields(self, output: str) -> dict: + """Parse the predictor's ``name=|`` encoding into a dict.""" + fields = {} + for part in output.split(";"): + name, encoded = part.split("=", 1) + type_name, value = encoded.split("|", 1) + fields[name] = (type_name, value) + return fields + + def test_direct_secret_is_wrapped(self, secret_predictor: Path): + """A ``Secret``-annotated param wraps the submitted string in Secret.""" + with CogletServer(secret_predictor) as server: + result = server.predict({"api_token": "sk-test-12345"}) + assert result["status"] == "succeeded" + fields = self._fields(result["output"]) + assert fields["api_token"] == ("Secret", "sk-test-12345") + + def test_optional_secret_with_value_is_wrapped(self, secret_predictor: Path): + """An ``Optional[Secret]`` param with a value wraps it in Secret.""" + with CogletServer(secret_predictor) as server: + result = server.predict( + {"api_token": "sk-test-12345", "opt_secret": "sk-opt-67890"} + ) + assert result["status"] == "succeeded" + fields = self._fields(result["output"]) + assert fields["opt_secret"] == ("Secret", "sk-opt-67890") + + def test_pep604_secret_with_value_is_wrapped(self, secret_predictor: Path): + """A ``Secret | None`` (PEP 604) param with a value wraps it in Secret.""" + with CogletServer(secret_predictor) as server: + result = server.predict( + {"api_token": "sk-test-12345", "pep604_secret": "sk-604-abcde"} + ) + assert result["status"] == "succeeded" + fields = self._fields(result["output"]) + assert fields["pep604_secret"] == ("Secret", "sk-604-abcde") + + def test_plain_str_is_not_wrapped(self, secret_predictor: Path): + """A plain ``str`` param stays a str and is NOT wrapped in Secret.""" + with CogletServer(secret_predictor) as server: + result = server.predict( + {"api_token": "sk-test-12345", "plain": "sk-test-12345"} + ) + assert result["status"] == "succeeded" + fields = self._fields(result["output"]) + assert fields["plain"] == ("str", "sk-test-12345") + + def test_optional_secret_omitted_stays_none(self, secret_predictor: Path): + """An omitted ``Optional[Secret]`` param stays None, not Secret("").""" + with CogletServer(secret_predictor) as server: + result = server.predict({"api_token": "sk-test-12345"}) + assert result["status"] == "succeeded" + fields = self._fields(result["output"]) + assert fields["opt_secret"] == ("NoneType", "None") + assert fields["pep604_secret"] == ("NoneType", "None") + + class TestGeneratorPredictor: """Tests for generator predictor.""" diff --git a/examples/hello-replicate/.dockerignore b/examples/hello-replicate/.dockerignore new file mode 100644 index 0000000000..1d4c71fdac --- /dev/null +++ b/examples/hello-replicate/.dockerignore @@ -0,0 +1,20 @@ +# The .dockerignore file excludes files from the container build process. +# +# https://docs.docker.com/engine/reference/builder/#dockerignore-file + +# Exclude Git files +**/.git +**/.github +**/.gitignore + +# Exclude Python tooling +.python-version + +# Exclude Python cache files +__pycache__ +.mypy_cache +.pytest_cache +.ruff_cache + +# Exclude Python virtual environment +/venv diff --git a/examples/hello-replicate/README.md b/examples/hello-replicate/README.md new file mode 100644 index 0000000000..1d8af57e8f --- /dev/null +++ b/examples/hello-replicate/README.md @@ -0,0 +1,49 @@ +# hello-replicate + +An example Cog model that demonstrates `cog.Secret` inputs by calling the +[Replicate API](https://replicate.com/docs) from inside a prediction. + +Given an input image, the model: + +1. Sends the image to `anthropic/claude-4-sonnet` to generate a detailed prompt + describing it. +2. Feeds that prompt to `black-forest-labs/flux-dev` to re-create the image. +3. Returns the generated image. + +## Secrets + +The Replicate API token is declared as a `cog.Secret` input: + +```python +from cog import Input, Secret + +def run( + replicate_api_token: Secret = Input( + description="Replicate API token used to call other models", + ), +) -> Path: + client = Client(api_token=replicate_api_token.get_secret_value()) + ... +``` + +`cog.Secret` redacts its value in logs and string representations. Read the +underlying value with `get_secret_value()`. + +## Run it + +Avoid passing the token literally on the command line, since it can leak +through your shell history and process listings. Instead, read it from an +environment variable: + +```sh +export REPLICATE_API_TOKEN=r8_... # set once, ideally via a secrets manager / not inline in shared shells +cog predict -i image=@cat.png -i replicate_api_token="$REPLICATE_API_TOKEN" +``` + +You can also read the token from a file (for example +`-i replicate_api_token="$(cat token.txt)"`) if that fits your workflow better. + +> **Note:** `cog.Secret` redacts the value in model logs and string +> representations, but it cannot protect a secret that is already exposed by +> your own shell history, environment, or process listing. Keeping the token +> out of those places is your responsibility. diff --git a/examples/hello-replicate/cat.png b/examples/hello-replicate/cat.png new file mode 100644 index 0000000000..15296784ac Binary files /dev/null and b/examples/hello-replicate/cat.png differ diff --git a/examples/hello-replicate/cog.yaml b/examples/hello-replicate/cog.yaml new file mode 100644 index 0000000000..573975d916 --- /dev/null +++ b/examples/hello-replicate/cog.yaml @@ -0,0 +1,28 @@ +# Configuration for Cog ⚙️ +# Reference: https://cog.run/yaml + +build: + # set to true if your model requires a GPU + gpu: false + + # a list of ubuntu apt packages to install + # system_packages: + # - "libgl1-mesa-glx" + # - "libglib2.0-0" + + # python version in the form '3.11' or '3.11.4' + python_version: "3.12" + + # path to a Python requirements.txt file + python_requirements: requirements.txt + + # enable fast boots + fast: false + + # commands run after the environment is setup + # run: + # - "echo env is ready!" + # - "echo another command if needed" + +# main.py defines how predictions are run on your model +run: "main.py:run" diff --git a/examples/hello-replicate/main.py b/examples/hello-replicate/main.py new file mode 100644 index 0000000000..39b70a5fcf --- /dev/null +++ b/examples/hello-replicate/main.py @@ -0,0 +1,35 @@ +import os +import tempfile +import warnings + +from replicate.client import Client + +from cog import ExperimentalFeatureWarning, Input, Path, Secret + +warnings.filterwarnings("ignore", category=ExperimentalFeatureWarning) + + +def run( + image: Path = Input(description="Input image to test"), + replicate_api_token: Secret = Input( + description="Replicate API token used to call other models", + ), +) -> Path: + replicate = Client(api_token=replicate_api_token.get_secret_value()) + claude_prompt = """ +You have been asked to generate a prompt for an image model that should re-create the +image provided to you exactly. Please describe the provided image in great detail +paying close attention to the contents, layout and style. + """ + prompt = replicate.run( + "anthropic/claude-4-sonnet", input={"prompt": claude_prompt, "image": image} + ) + output = replicate.run( + "black-forest-labs/flux-dev", input={"prompt": "".join(prompt)} + ) + + with tempfile.TemporaryDirectory(delete=False) as tmpdir: + dest_path = os.path.join(tmpdir, "output.webp") + with open(dest_path, "wb") as file: + file.write(output[0].read()) + return Path(dest_path) diff --git a/examples/hello-replicate/requirements.txt b/examples/hello-replicate/requirements.txt new file mode 100644 index 0000000000..83f1f0bf4e --- /dev/null +++ b/examples/hello-replicate/requirements.txt @@ -0,0 +1,24 @@ +# This is a normal Python requirements.txt file. + +# You can add dependencies directly from PyPI: +# +# numpy==1.26.4 +# torch==2.2.1 +# torchvision==0.17.1 +replicate>=1.0.7 + + +# You can also add Git repos as dependencies, but you'll need to add git to the system_packages list in cog.yaml: +# +# build: +# system_packages: +# - "git" +# +# Then you can use a URL like this: +# +# git+https://github.com/huggingface/transformers + + +# You can also pin Git repos to a specific commit: +# +# git+https://github.com/huggingface/transformers@2d1602a