Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
125 changes: 106 additions & 19 deletions crates/coglet-python/src/input.rs
Original file line number Diff line number Diff line change
Expand Up @@ -79,19 +79,21 @@ pub fn prepare_input(
func: &Bound<'_, PyAny>,
) -> PyResult<PreparedInput> {
let fields = classify_fields(py, func)?;
coerce_url_strings(py, input, &fields)?;
coerce_typed_inputs(py, input, &fields)?;
let cleanup_paths = download_url_paths_into_dict(py, input)?;
Ok(PreparedInput::new(input.clone().unbind(), cleanup_paths))
}

/// Whether a field should be coerced as a `cog.File` or `cog.Path`.
/// Whether a field should be coerced as a `cog.File`, `cog.Path`, or `cog.Secret`.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
enum FieldKind {
File,
Path,
Secret,
}

/// Result of inspecting a Python function's type annotations for File and Path fields.
/// Result of inspecting a Python function's type annotations for File, Path,
/// and Secret fields.
#[derive(Debug)]
struct FieldClassification {
/// Fields typed as `cog.File` (or `list[File]`, `Optional[File]`, etc.)
Expand All @@ -100,22 +102,31 @@ struct FieldClassification {
/// Fields typed as `cog.Path` (or `list[Path]`, `Optional[Path]`, etc.)
/// These use `Path.validate()` for URL coercion.
path_fields: HashSet<String>,
/// Fields typed as `cog.Secret` (only direct `Secret`, `Optional[Secret]`,
/// or `Secret | None`; `list[Secret]` is intentionally not classified).
/// These wrap plain string values in `cog.types.Secret`.
secret_fields: HashSet<String>,
}

/// Inspect a Python function's type annotations to find parameters typed as
/// `cog.File` or `cog.Path` (including `list[...]`, `Optional[...]`,
/// `... | None`, etc.).
/// `cog.File`, `cog.Path`, or `cog.Secret`. For File and Path this includes
/// the `list[...]`, `Optional[...]`, and `... | None` forms. For Secret only
/// the direct `Secret`, `Optional[Secret]`, and `Secret | None` forms are
/// classified; `list[Secret]` is intentionally not classified (see the note
/// in `classify_type` below).
///
/// Returns a `FieldClassification` so that `coerce_url_strings` only coerces
/// fields that are actually File- or Path-typed, leaving `str` and other types
/// untouched.
/// Returns a `FieldClassification` so that `coerce_typed_inputs` only coerces
/// fields that are actually File-, Path-, or Secret-typed, leaving `str` and
/// other types untouched.
fn classify_fields(py: Python<'_>, func: &Bound<'_, PyAny>) -> PyResult<FieldClassification> {
let mut file_fields = HashSet::new();
let mut path_fields = HashSet::new();
let mut secret_fields = HashSet::new();

let cog_types = py.import("cog.types")?;
let cog_file_class = cog_types.getattr("File")?;
let cog_path_class = cog_types.getattr("Path")?;
let cog_secret_class = cog_types.getattr("Secret")?;

// typing.get_type_hints resolves string annotations and handles forward refs
let typing = py.import("typing")?;
Expand All @@ -134,18 +145,24 @@ fn classify_fields(py: Python<'_>, func: &Bound<'_, PyAny>) -> PyResult<FieldCla
return Ok(FieldClassification {
file_fields,
path_fields,
secret_fields,
});
}
};

// Helper: returns the FieldKind if ty is File/Path or list[File]/list[Path].
// Helper: returns the FieldKind if ty is File/Path/Secret or list[File]/list[Path].
// Note: list[Secret] is intentionally not classified; `coerce_typed_inputs`
// only wraps single string Secret values, so lists of secrets are not coerced.
let classify_type = |ty: &Bound<'_, PyAny>| -> PyResult<Option<FieldKind>> {
if ty.is(&cog_file_class) {
return Ok(Some(FieldKind::File));
}
if ty.is(&cog_path_class) {
return Ok(Some(FieldKind::Path));
}
if ty.is(&cog_secret_class) {
return Ok(Some(FieldKind::Secret));
}
let inner_origin = get_origin.call1((ty,))?;
if !inner_origin.is_none() && inner_origin.is(&builtins_list) {
let inner_args = get_args.call1((ty,))?;
Expand Down Expand Up @@ -184,6 +201,9 @@ fn classify_fields(py: Python<'_>, func: &Bound<'_, PyAny>) -> PyResult<FieldCla
FieldKind::Path => {
path_fields.insert(name_str);
}
FieldKind::Secret => {
secret_fields.insert(name_str);
}
}
continue;
}
Expand Down Expand Up @@ -211,6 +231,9 @@ fn classify_fields(py: Python<'_>, func: &Bound<'_, PyAny>) -> PyResult<FieldCla
FieldKind::Path => {
path_fields.insert(name_str.clone());
}
FieldKind::Secret => {
secret_fields.insert(name_str.clone());
}
}
break;
}
Expand All @@ -219,45 +242,64 @@ fn classify_fields(py: Python<'_>, func: &Bound<'_, PyAny>) -> PyResult<FieldCla
}
}

if !file_fields.is_empty() || !path_fields.is_empty() {
if !file_fields.is_empty() || !path_fields.is_empty() || !secret_fields.is_empty() {
tracing::debug!(
"Detected File-typed fields: {:?}, Path-typed fields: {:?}",
"Detected File-typed fields: {:?}, Path-typed fields: {:?}, Secret-typed fields: {:?}",
file_fields,
path_fields
path_fields,
secret_fields
);
}

Ok(FieldClassification {
file_fields,
path_fields,
secret_fields,
})
}

/// Coerce URL string values in the input dict to the appropriate cog types.
/// Coerce input dict values to the appropriate cog types.
///
/// After `json.loads()`, all values are plain Python types. URL strings
/// (http://, https://, data:) that represent file inputs need to be converted:
/// After `json.loads()`, all values are plain Python types. Typed fields need
/// to be converted:
/// - `File`-typed fields -> `File.validate()` -> returns IO-like `URLFile`
/// (URL strings only: http://, https://, data:)
/// - `Path`-typed fields -> `Path.validate()` -> returns `URLPath` (downloaded later)
/// (URL strings only)
/// - `Secret`-typed fields -> `Secret(value)` -> wraps any plain string value
///
/// Only fields whose declared type is `File` or `Path` (including `list[...]`,
/// `Optional[...]`, etc.) are coerced. Fields typed as `str` or any other type
/// are left untouched, even if the value looks like a URL.
/// Only fields whose declared type is `File`, `Path`, or `Secret` are coerced.
/// For File and Path this includes the `list[...]`, `Optional[...]`, and
/// `... | None` forms; for Secret only the direct `Secret`, `Optional[Secret]`,
/// and `Secret | None` forms are coerced (`list[Secret]` is intentionally not).
/// Fields typed as `str` or any other type are left untouched, even if the
/// value looks like a URL.
///
/// This replaces the type coercion that `_adt.py`'s `PrimitiveType.normalize()`
/// previously performed.
fn coerce_url_strings(
fn coerce_typed_inputs(
py: Python<'_>,
payload: &Bound<'_, PyDict>,
fields: &FieldClassification,
) -> PyResult<()> {
let cog_types = py.import("cog.types")?;
let path_validate = cog_types.getattr("Path")?.getattr("validate")?;
let file_validate = cog_types.getattr("File")?.getattr("validate")?;
let secret_class = cog_types.getattr("Secret")?;

for (key, value) in payload.iter() {
let key_str: String = key.extract().unwrap_or_default();

// Secret-typed fields: wrap the plain string value in cog.types.Secret.
// Unlike File/Path, this is not URL-conditional -- any string is wrapped.
if fields.secret_fields.contains(&key_str) {
if value.extract::<String>().is_ok() {
let coerced = secret_class.call1((&value,))?;
payload.set_item(&key, coerced)?;
}
continue;
}

// Only coerce fields that are declared as File or Path types.
// str-typed and other fields are left as-is even if values look like URLs.
let validate = if fields.file_fields.contains(&key_str) {
Expand Down Expand Up @@ -630,6 +672,51 @@ mod tests {
"non-Path types incorrectly detected as path: {:?}",
c.path_fields
);
assert!(
c.secret_fields.is_empty(),
"non-Secret types incorrectly detected as secret: {:?}",
c.secret_fields
);
}

#[test]
#[ignore] // Requires cog Python package in PYTHONPATH
fn detect_direct_secret() {

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This adds useful Secret annotation coverage, but these tests are all #[ignore] and only validate classification. The behavior this PR changes is runtime coercion in coerce_typed_inputs, so default CI still would not catch a predictor receiving a plain str instead of cog.Secret. Could you add active coglet-python pytest coverage that asserts a predictor annotated with Secret receives a cog.Secret whose get_secret_value() returns the submitted value? Ideally cover Secret, Optional[Secret], Union[Secret, None] / Secret | None, plus negative cases like non-secret str staying str and optional None staying None.

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Added active pytest coverage in TestSecretInput (Secret/Optional/PEP604 + negatives) — done in 6340d2f.

let c = classify_for("from cog import Secret\ndef func(a: Secret, b: str): ...");
assert!(
c.secret_fields.contains("a"),
"direct Secret annotation not detected"
);
assert!(
!c.secret_fields.contains("b"),
"str incorrectly flagged as Secret"
);
assert!(
!c.file_fields.contains("a") && !c.path_fields.contains("a"),
"Secret incorrectly flagged as File or Path"
);
}

#[test]
#[ignore] // Requires cog Python package in PYTHONPATH
fn detect_optional_secret() {
let c = classify_for(
"from typing import Optional\nfrom cog import Secret\ndef func(a: Optional[Secret]): ...",
);
assert!(
c.secret_fields.contains("a"),
"Optional[Secret] annotation not detected"
);
}

#[test]
#[ignore] // Requires cog Python package in PYTHONPATH
fn detect_pep604_secret_or_none() {
let c = classify_for("from cog import Secret\ndef func(a: Secret | None): ...");
assert!(
c.secret_fields.contains("a"),
"Secret | None annotation not detected as Secret"
);
}

#[test]
Expand Down
109 changes: 109 additions & 0 deletions crates/coglet-python/tests/test_coglet.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,55 @@ def predict(self, name: str = "World") -> str:
return predictor


@pytest.fixture
def secret_predictor(tmp_path: Path) -> Path:
"""Create a predictor exercising Secret input coercion.

predict() echoes back, for each parameter, the runtime type name and (for
secrets) the unwrapped value so tests can assert how each annotation is
coerced. Each field is encoded as ``name=<type>|<value>`` and fields are
joined with ``;``. For ``None`` values the unwrapped value is the literal
string ``None``.
"""
predictor = tmp_path / "predict.py"
predictor.write_text("""
from typing import Optional
from cog import BasePredictor, Secret

class Predictor(BasePredictor):
def setup(self):
pass

def predict(
self,
api_token: Secret,
plain: str = "",
opt_secret: Optional[Secret] = None,
pep604_secret: Secret | None = None,
) -> str:
def describe(value):
type_name = type(value).__name__
if isinstance(value, Secret):
return f"{type_name}|{value.get_secret_value()}"
return f"{type_name}|{value}"

return ";".join([
f"api_token={describe(api_token)}",
f"plain={describe(plain)}",
f"opt_secret={describe(opt_secret)}",
f"pep604_secret={describe(pep604_secret)}",
])
""")

# Create cog.yaml
cog_yaml = tmp_path / "cog.yaml"
cog_yaml.write_text("""
predict: "predict.py:Predictor"
""")

return predictor


@pytest.fixture
def generator_predictor(tmp_path: Path) -> Path:
"""Create a generator predictor."""
Expand Down Expand Up @@ -471,6 +520,66 @@ def test_includes_predict_time(self, sync_predictor: Path):
assert result["metrics"]["predict_time"] >= 0


class TestSecretInput:
"""Tests for cog.Secret input coercion."""

def _fields(self, output: str) -> dict:
"""Parse the predictor's ``name=<type>|<value>`` encoding into a dict."""
fields = {}
for part in output.split(";"):
name, encoded = part.split("=", 1)
type_name, value = encoded.split("|", 1)
fields[name] = (type_name, value)
return fields

def test_direct_secret_is_wrapped(self, secret_predictor: Path):
"""A ``Secret``-annotated param wraps the submitted string in Secret."""
with CogletServer(secret_predictor) as server:
result = server.predict({"api_token": "sk-test-12345"})
assert result["status"] == "succeeded"
fields = self._fields(result["output"])
assert fields["api_token"] == ("Secret", "sk-test-12345")

def test_optional_secret_with_value_is_wrapped(self, secret_predictor: Path):
"""An ``Optional[Secret]`` param with a value wraps it in Secret."""
with CogletServer(secret_predictor) as server:
result = server.predict(
{"api_token": "sk-test-12345", "opt_secret": "sk-opt-67890"}
)
assert result["status"] == "succeeded"
fields = self._fields(result["output"])
assert fields["opt_secret"] == ("Secret", "sk-opt-67890")

def test_pep604_secret_with_value_is_wrapped(self, secret_predictor: Path):
"""A ``Secret | None`` (PEP 604) param with a value wraps it in Secret."""
with CogletServer(secret_predictor) as server:
result = server.predict(
{"api_token": "sk-test-12345", "pep604_secret": "sk-604-abcde"}
)
assert result["status"] == "succeeded"
fields = self._fields(result["output"])
assert fields["pep604_secret"] == ("Secret", "sk-604-abcde")

def test_plain_str_is_not_wrapped(self, secret_predictor: Path):
"""A plain ``str`` param stays a str and is NOT wrapped in Secret."""
with CogletServer(secret_predictor) as server:
result = server.predict(
{"api_token": "sk-test-12345", "plain": "sk-test-12345"}
)
assert result["status"] == "succeeded"
fields = self._fields(result["output"])
assert fields["plain"] == ("str", "sk-test-12345")

def test_optional_secret_omitted_stays_none(self, secret_predictor: Path):
"""An omitted ``Optional[Secret]`` param stays None, not Secret("")."""
with CogletServer(secret_predictor) as server:
result = server.predict({"api_token": "sk-test-12345"})
assert result["status"] == "succeeded"
fields = self._fields(result["output"])
assert fields["opt_secret"] == ("NoneType", "None")
assert fields["pep604_secret"] == ("NoneType", "None")


class TestGeneratorPredictor:
"""Tests for generator predictor."""

Expand Down
20 changes: 20 additions & 0 deletions examples/hello-replicate/.dockerignore
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
# The .dockerignore file excludes files from the container build process.
#
# https://docs.docker.com/engine/reference/builder/#dockerignore-file

# Exclude Git files
**/.git
**/.github
**/.gitignore

# Exclude Python tooling
.python-version

# Exclude Python cache files
__pycache__
.mypy_cache
.pytest_cache
.ruff_cache

# Exclude Python virtual environment
/venv
Loading