Skip to content

Commit 2bd81d9

Browse files
iwillspeakcorpo-iwillspeaksnus-kin
authored
fix: detect and preserve image format instead of hardcoding RAW_UINT8 (#67)
* fix: detect and preserve image format instead of hardcoding RAW_UINT8 Previously, all images were sent with IMAGE_FORMAT_RAW_UINT8 regardless of their actual format (JPEG, PNG, etc.), causing incorrect metadata to be sent to the API when images weren't resized. Changes: - Add image format detection via magic number signatures for PNG, JPEG, GIF, BMP, WebP, and TIFF formats - Update ImageData to automatically detect format during initialization - Preserve detected format throughout transformation pipeline - Convert UNSPECIFIED format to RAW_UINT8 before sending to ensure API never receives UNSPECIFIED - Update resize transformer to set format to RAW_UINT8 when converting to raw pixel data This ensures: 1. Native image formats (PNG, JPEG, etc.) are correctly identified and preserved when sent without resizing 2. Resized images are correctly marked as RAW_UINT8 3. IMAGE_FORMAT_UNSPECIFIED is never sent over the API (defaults to RAW_UINT8) Tests: Added 46 new tests covering format detection, API validation, and format preservation across both streaming and single classification methods. All 166 tests passing. * build: Suppress Lint Warnings * refactor: use named constants for image format magic bytes - Remove __future__ annotations import from input_model.py - Replace inline byte literals with named constants in image_format_detector.py - Calculate lengths dynamically using len() instead of hardcoded values - Eliminates need for noqa comments on magic value comparisons * style: ruff format and line end fix --------- Co-authored-by: Will Speak <will.speak@kroll.com> Co-authored-by: snus-kin <tcarroll@snufk.in>
1 parent 7a710bb commit 2bd81d9

12 files changed

Lines changed: 462 additions & 5 deletions

File tree

.env.example

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,4 +6,4 @@ OAUTH_AUDIENCE=crisp-athena-live
66

77
# Athena server configuration
88
# ATHENA_HOST=trust-messages.crispthinking.com
9-
ATHENA_AFFILIATE=athena-test
9+
ATHENA_AFFILIATE=athena-test

.pre-commit-config.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ repos:
1313
hooks:
1414
- id: basedpyright
1515
name: basedpyright
16-
entry: basedpyright
16+
entry: uv run basedpyright
1717
language: system
1818
types_or: [python, pyi]
1919
pass_filenames: false

src/resolver_athena_client/client/athena_client.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -239,12 +239,18 @@ async def classify_single(
239239
else RequestEncoding.REQUEST_ENCODING_UNCOMPRESSED
240240
)
241241

242+
# Ensure we never send UNSPECIFIED format over the API
243+
# If format is still UNSPECIFIED, default to RAW_UINT8
244+
image_format = processed_image.image_format
245+
if image_format == ImageFormat.IMAGE_FORMAT_UNSPECIFIED:
246+
image_format = ImageFormat.IMAGE_FORMAT_RAW_UINT8
247+
242248
classification_input = ClassificationInput(
243249
affiliate=self.options.affiliate,
244250
correlation_id=correlation_id,
245251
encoding=request_encoding,
246252
data=processed_image.data,
247-
format=ImageFormat.IMAGE_FORMAT_RAW_UINT8,
253+
format=image_format,
248254
hashes=[
249255
ImageHash(
250256
value=hash_value,
Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,73 @@
1+
"""Utility for detecting image formats from raw bytes."""
2+
3+
from resolver_athena_client.generated.athena.models_pb2 import ImageFormat
4+
5+
PNG_MAGIC_BYTES = b"\x89PNG"
6+
JPEG_MAGIC_BYTES = b"\xff\xd8\xff"
7+
GIF87A_MAGIC_BYTES = b"GIF87a"
8+
GIF89A_MAGIC_BYTES = b"GIF89a"
9+
BMP_MAGIC_BYTES = b"BM"
10+
WEBP_RIFF_MAGIC_BYTES = b"RIFF"
11+
WEBP_WEBP_MAGIC_BYTES = b"WEBP"
12+
TIFF_LE_MAGIC_BYTES = b"II*\x00"
13+
TIFF_BE_MAGIC_BYTES = b"MM\x00*"
14+
15+
16+
def detect_image_format(data: bytes) -> ImageFormat.ValueType: # noqa: PLR0911
17+
"""Detect image format from raw bytes using magic number signatures.
18+
19+
Args:
20+
----
21+
data: Raw image bytes to analyze
22+
23+
Returns:
24+
-------
25+
ImageFormat enum value representing the detected format
26+
27+
"""
28+
if not data:
29+
return ImageFormat.IMAGE_FORMAT_UNSPECIFIED
30+
31+
# Check magic numbers for common image formats
32+
# PNG: starts with PNG_MAGIC_BYTES
33+
png_len = len(PNG_MAGIC_BYTES)
34+
if len(data) >= png_len and data[:png_len] == PNG_MAGIC_BYTES:
35+
return ImageFormat.IMAGE_FORMAT_PNG
36+
37+
# JPEG: starts with JPEG_MAGIC_BYTES
38+
jpeg_len = len(JPEG_MAGIC_BYTES)
39+
if len(data) >= jpeg_len and data[:jpeg_len] == JPEG_MAGIC_BYTES:
40+
return ImageFormat.IMAGE_FORMAT_JPEG
41+
42+
# GIF: starts with GIF87A_MAGIC_BYTES or GIF89A_MAGIC_BYTES
43+
gif_len = len(GIF87A_MAGIC_BYTES)
44+
if len(data) >= gif_len and data[:gif_len] in (
45+
GIF87A_MAGIC_BYTES,
46+
GIF89A_MAGIC_BYTES,
47+
):
48+
return ImageFormat.IMAGE_FORMAT_GIF
49+
50+
# BMP: starts with BMP_MAGIC_BYTES
51+
bmp_len = len(BMP_MAGIC_BYTES)
52+
if len(data) >= bmp_len and data[:bmp_len] == BMP_MAGIC_BYTES:
53+
return ImageFormat.IMAGE_FORMAT_BMP
54+
55+
# WebP: RIFF....WEBP (12 bytes minimum for full signature)
56+
webp_min_len = len(WEBP_RIFF_MAGIC_BYTES) + len(WEBP_WEBP_MAGIC_BYTES) + 4
57+
if (
58+
len(data) >= webp_min_len
59+
and data[:4] == WEBP_RIFF_MAGIC_BYTES
60+
and data[8:12] == WEBP_WEBP_MAGIC_BYTES
61+
):
62+
return ImageFormat.IMAGE_FORMAT_WEBP
63+
64+
# TIFF: little-endian or big-endian magic bytes
65+
tiff_len = len(TIFF_LE_MAGIC_BYTES)
66+
if len(data) >= tiff_len and (
67+
data[:tiff_len] == TIFF_LE_MAGIC_BYTES
68+
or data[:tiff_len] == TIFF_BE_MAGIC_BYTES
69+
):
70+
return ImageFormat.IMAGE_FORMAT_TIFF
71+
72+
# Fallback when format cannot be determined
73+
return ImageFormat.IMAGE_FORMAT_UNSPECIFIED

src/resolver_athena_client/client/models/input_model.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,14 @@
66
"""
77

88
import hashlib
9+
from typing import TYPE_CHECKING
10+
11+
from resolver_athena_client.client.image_format_detector import (
12+
detect_image_format,
13+
)
14+
15+
if TYPE_CHECKING:
16+
from resolver_athena_client.generated.athena.models_pb2 import ImageFormat
917

1018

1119
class ImageData:
@@ -24,6 +32,8 @@ class ImageData:
2432
Attributes:
2533
----------
2634
data: The raw bytes of the image (modified in-place by transformers).
35+
image_format: The format of the image data (e.g., JPEG, PNG, RAW_UINT8).
36+
Updated by transformers when they change the format.
2737
sha256_hashes: List of SHA256 hashes tracking image transformations.
2838
Index 0 is the original image, subsequent indices track
2939
transformations.
@@ -66,6 +76,9 @@ def __init__(self, image_bytes: bytes) -> None:
6676
6777
"""
6878
self.data: bytes = image_bytes
79+
self.image_format: ImageFormat.ValueType = detect_image_format(
80+
image_bytes
81+
)
6982
self.sha256_hashes: list[str] = [
7083
hashlib.sha256(image_bytes).hexdigest()
7184
]

src/resolver_athena_client/client/transformers/classification_input.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -48,15 +48,20 @@ def __init__(
4848
def _create_classification_input(
4949
self, image_data: ImageData
5050
) -> ClassificationInput:
51-
# Get image format and data
51+
# Ensure we never send UNSPECIFIED format over the API
52+
# If format is still UNSPECIFIED, default to RAW_UINT8
53+
image_format = image_data.image_format
54+
if image_format == ImageFormat.IMAGE_FORMAT_UNSPECIFIED:
55+
image_format = ImageFormat.IMAGE_FORMAT_RAW_UINT8
56+
5257
return ClassificationInput(
5358
affiliate=self.affiliate,
5459
correlation_id=self.correlation_provider.get_correlation_id(
5560
image_data.data
5661
),
5762
data=image_data.data,
5863
encoding=self.request_encoding,
59-
format=ImageFormat.IMAGE_FORMAT_RAW_UINT8,
64+
format=image_format,
6065
)
6166

6267
@override

src/resolver_athena_client/client/transformers/core.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313

1414
from resolver_athena_client.client.consts import EXPECTED_HEIGHT, EXPECTED_WIDTH
1515
from resolver_athena_client.client.models import ImageData
16+
from resolver_athena_client.generated.athena.models_pb2 import ImageFormat
1617

1718
# Global optimization constants
1819
_target_size = (EXPECTED_WIDTH, EXPECTED_HEIGHT)
@@ -73,6 +74,7 @@ def process_image() -> tuple[bytes, bool]:
7374
# Only modify data and add hashes if transformation occurred
7475
if was_transformed:
7576
image_data.data = resized_bytes
77+
image_data.image_format = ImageFormat.IMAGE_FORMAT_RAW_UINT8
7678
image_data.add_transformation_hashes()
7779

7880
return image_data

tests/client/models/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
"""Tests for model classes."""
Lines changed: 105 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,105 @@
1+
"""Tests for ImageData model."""
2+
3+
import pytest
4+
5+
from resolver_athena_client.client.models import ImageData
6+
from resolver_athena_client.generated.athena.models_pb2 import ImageFormat
7+
8+
9+
def test_image_data_detects_png_format() -> None:
10+
"""Test that PNG format is detected during initialization."""
11+
png_data = b"\x89PNG\r\n\x1a\n" + b"\x00" * 100
12+
image_data = ImageData(png_data)
13+
14+
assert image_data.image_format == ImageFormat.IMAGE_FORMAT_PNG
15+
assert image_data.data == png_data
16+
assert len(image_data.sha256_hashes) == 1
17+
assert len(image_data.md5_hashes) == 1
18+
19+
20+
def test_image_data_detects_jpeg_format() -> None:
21+
"""Test that JPEG format is detected during initialization."""
22+
jpeg_data = b"\xff\xd8\xff\xe0" + b"\x00" * 100
23+
image_data = ImageData(jpeg_data)
24+
25+
assert image_data.image_format == ImageFormat.IMAGE_FORMAT_JPEG
26+
assert image_data.data == jpeg_data
27+
28+
29+
def test_image_data_detects_gif_format() -> None:
30+
"""Test that GIF format is detected during initialization."""
31+
gif_data = b"GIF89a" + b"\x00" * 100
32+
image_data = ImageData(gif_data)
33+
34+
assert image_data.image_format == ImageFormat.IMAGE_FORMAT_GIF
35+
36+
37+
def test_image_data_detects_bmp_format() -> None:
38+
"""Test that BMP format is detected during initialization."""
39+
bmp_data = b"BM" + b"\x00" * 100
40+
image_data = ImageData(bmp_data)
41+
42+
assert image_data.image_format == ImageFormat.IMAGE_FORMAT_BMP
43+
44+
45+
def test_image_data_detects_webp_format() -> None:
46+
"""Test that WebP format is detected during initialization."""
47+
webp_data = b"RIFF\x00\x00\x00\x00WEBP" + b"\x00" * 100
48+
image_data = ImageData(webp_data)
49+
50+
assert image_data.image_format == ImageFormat.IMAGE_FORMAT_WEBP
51+
52+
53+
def test_image_data_unspecified_for_unknown_format() -> None:
54+
"""Test that unknown data results in UNSPECIFIED format."""
55+
unknown_data = b"not_a_valid_image"
56+
image_data = ImageData(unknown_data)
57+
58+
assert image_data.image_format == ImageFormat.IMAGE_FORMAT_UNSPECIFIED
59+
60+
61+
def test_image_data_unspecified_for_empty_data() -> None:
62+
"""Test that empty data results in UNSPECIFIED format."""
63+
image_data = ImageData(b"")
64+
65+
assert image_data.image_format == ImageFormat.IMAGE_FORMAT_UNSPECIFIED
66+
67+
68+
def test_image_data_transformation_preserves_format() -> None:
69+
"""Test that format is preserved when transformation hashes are added."""
70+
png_data = b"\x89PNG\r\n\x1a\n" + b"\x00" * 100
71+
image_data = ImageData(png_data)
72+
73+
assert image_data.image_format == ImageFormat.IMAGE_FORMAT_PNG
74+
75+
# Simulate transformation
76+
image_data.data = b"transformed_data"
77+
image_data.add_transformation_hashes()
78+
79+
# Format should still be PNG (transformers will update it if needed)
80+
assert image_data.image_format == ImageFormat.IMAGE_FORMAT_PNG
81+
assert len(image_data.sha256_hashes) == 2 # noqa: PLR2004
82+
assert len(image_data.md5_hashes) == 2 # noqa: PLR2004
83+
84+
85+
@pytest.mark.parametrize(
86+
("data", "expected_format"),
87+
[
88+
(b"\x89PNG\r\n\x1a\n", ImageFormat.IMAGE_FORMAT_PNG),
89+
(b"\xff\xd8\xff", ImageFormat.IMAGE_FORMAT_JPEG),
90+
(b"GIF87a", ImageFormat.IMAGE_FORMAT_GIF),
91+
(b"GIF89a", ImageFormat.IMAGE_FORMAT_GIF),
92+
(b"BM", ImageFormat.IMAGE_FORMAT_BMP),
93+
(b"RIFF\x00\x00\x00\x00WEBP", ImageFormat.IMAGE_FORMAT_WEBP),
94+
(b"II*\x00", ImageFormat.IMAGE_FORMAT_TIFF),
95+
(b"MM\x00*", ImageFormat.IMAGE_FORMAT_TIFF),
96+
(b"unknown", ImageFormat.IMAGE_FORMAT_UNSPECIFIED),
97+
(b"", ImageFormat.IMAGE_FORMAT_UNSPECIFIED),
98+
],
99+
)
100+
def test_image_data_format_detection_parametrized(
101+
data: bytes, expected_format: ImageFormat.ValueType
102+
) -> None:
103+
"""Test format detection with various image data."""
104+
image_data = ImageData(data)
105+
assert image_data.image_format == expected_format

0 commit comments

Comments
 (0)