Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion swift/template/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -810,8 +810,12 @@ def prepare_generate_kwargs(self, generate_kwargs: Dict[str, Any], *, model=None

@staticmethod
def _save_pil_image(image: Image.Image) -> str:
# `Image.tobytes()` only returns the flattened pixel stream, without mode or shape.
# Include them in the cache key so images that share pixel bytes but differ in
# mode/size do not collide onto the same cached file.
img_bytes = image.tobytes()
img_hash = hashlib.sha256(img_bytes).hexdigest()
meta = f'{image.mode}-{image.width}x{image.height}-'.encode()
img_hash = hashlib.sha256(meta + img_bytes).hexdigest()
Comment on lines +817 to +818

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

Concatenating `meta + img_bytes` creates a new bytes object in memory, which copies the entire image byte stream. For large images, this can lead to unnecessary memory overhead and performance degradation. Instead, you can update the hash incrementally using `hasher.update()` to avoid this extra memory allocation.

Suggested change
meta = f'{image.mode}-{image.width}x{image.height}-'.encode()
img_hash = hashlib.sha256(meta + img_bytes).hexdigest()
meta = f'{image.mode}-{image.width}x{image.height}-'.encode()
hasher = hashlib.sha256(meta)
hasher.update(img_bytes)
img_hash = hasher.hexdigest()

tmp_dir = os.path.join(get_cache_dir(), 'tmp', 'images')
logger.info_once(f'create tmp_dir: {tmp_dir}')
os.makedirs(tmp_dir, exist_ok=True)
Expand Down
30 changes: 30 additions & 0 deletions tests/general/test_template.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,8 +72,38 @@ def test_mllm_dataset_map():
_test_dataset_map('Qwen/Qwen2-VL-7B-Instruct', 'modelscope/coco_2014_caption:validation#100')


def test_save_pil_image_dimension_collision():
    """Check that same-bytes images of different shapes get distinct saved files.

    Two RGB images (120x80 and 80x120) are created from one shared pixel
    buffer, both are passed through `Template._save_pil_image`, and the test
    asserts that the returned paths differ and that each saved file reopens
    with the size of the image it was saved from.
    """
    from PIL import Image

    from swift.template.base import Template

    # Transposed shapes: identical pixel count, so a single buffer fits both.
    size_a = (120, 80)
    size_b = (80, 120)
    assert size_a[0] * size_a[1] == size_b[0] * size_b[1]

    # Build the buffer one row (of shape A) at a time: rows whose index
    # modulo 10 is below 5 are red, the remaining rows are blue.
    red = bytes((255, 60, 60))
    blue = bytes((60, 60, 255))
    img_bytes = b''.join(
        (red if row % 10 < 5 else blue) * size_a[0] for row in range(size_a[1]))

    image_a = Image.frombytes('RGB', size_a, img_bytes)
    image_b = Image.frombytes('RGB', size_b, img_bytes)
    # Precondition: the flattened pixel streams are indistinguishable.
    assert image_a.tobytes() == image_b.tobytes()

    path_a = Template._save_pil_image(image_a)
    path_b = Template._save_pil_image(image_b)

    # Different dimensions must not collide onto the same cache file.
    assert path_a != path_b
    for saved_path, expected_size in ((path_a, size_a), (path_b, size_b)):
        with Image.open(saved_path) as saved:
            assert saved.size == expected_size


if __name__ == '__main__':
    # When run directly as a script, execute every test in file order.
    for run_test in (
            test_template,
            test_mllm,
            test_llm_dataset_map,
            test_mllm_dataset_map,
            test_save_pil_image_dimension_collision,
    ):
        run_test()