diff --git a/swift/template/base.py b/swift/template/base.py index e4ec978021..63aed1ec83 100644 --- a/swift/template/base.py +++ b/swift/template/base.py @@ -777,8 +777,8 @@ def prepare_generate_kwargs(self, generate_kwargs: Dict[str, Any], *, model=None @staticmethod def _save_pil_image(image: Image.Image) -> str: - img_bytes = image.tobytes() - img_hash = hashlib.sha256(img_bytes).hexdigest() + img_meta = f'{image.mode}:{image.width}:{image.height}:'.encode() + img_hash = hashlib.sha256(img_meta + image.tobytes()).hexdigest() tmp_dir = os.path.join(get_cache_dir(), 'tmp', 'images') logger.info_once(f'create tmp_dir: {tmp_dir}') os.makedirs(tmp_dir, exist_ok=True) diff --git a/tests/llm/test_template.py b/tests/llm/test_template.py index 062cb439fd..1da378652a 100644 --- a/tests/llm/test_template.py +++ b/tests/llm/test_template.py @@ -1,10 +1,14 @@ import os +import tempfile import torch import unittest +from unittest.mock import patch +from PIL import Image from swift.infer_engine import RequestConfig, TransformersEngine from swift.model import get_processor from swift.template import get_template +from swift.template.base import Template from swift.utils import get_logger, seed_everything # os.environ['CUDA_VISIBLE_DEVICES'] = '0' @@ -103,6 +107,28 @@ def test_tool_message_join(self): f'{observation}tool2\n{observation}tool3\n') assert res == ground_truth + def test_save_pil_image_uses_dimensions_in_cache_key(self): + width_a, height_a = 120, 80 + width_b, height_b = 80, 120 + self.assertEqual(width_a * height_a, width_b * height_b) + + pixels = bytearray() + for i in range(width_a * height_a): + row = i // width_a + pixels.extend((255, 60, 60) if row % 10 < 5 else (60, 60, 255)) + img_bytes = bytes(pixels) + + image_a = Image.frombytes('RGB', (width_a, height_a), img_bytes) + image_b = Image.frombytes('RGB', (width_b, height_b), img_bytes) + + with tempfile.TemporaryDirectory() as cache_dir, patch('swift.template.base.get_cache_dir', return_value=cache_dir): + path_a = Template._save_pil_image(image_a) + path_b = Template._save_pil_image(image_b) + + self.assertNotEqual(path_a, path_b) + self.assertEqual(Image.open(path_a).size, (width_a, height_a)) + self.assertEqual(Image.open(path_b).size, (width_b, height_b)) + if __name__ == '__main__': unittest.main()