From 0cf1cd63915264e3f6fcdcc21fa01d16ebdc761e Mon Sep 17 00:00:00 2001 From: aGallea Date: Mon, 30 Mar 2026 17:21:55 +0300 Subject: [PATCH 1/6] feat(deps): replace openai with litellm for multi-provider LLM support --- pyproject.toml | 4 +- uv.lock | 147 +++++++++++++++++++++++++++++++++++++++++++++++-- 2 files changed, 144 insertions(+), 7 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 61f6c42..553ee41 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -15,7 +15,7 @@ dependencies = [ "dash>=3.4,<4", "plotly>=6.5,<7", "aiohttp>=3.11,<4", - "openai>=1.60,<2", + "litellm>=1.65,<2", "scikit-learn>=1.6,<2", "numpy>=2.2,<3", "Pillow>=11,<12", @@ -83,7 +83,7 @@ module = [ "plotly.*", "sklearn.*", "PIL.*", - "openai.*", + "litellm.*", "aiohttp.*", "umap.*", ] diff --git a/uv.lock b/uv.lock index db32f7a..ad1128b 100644 --- a/uv.lock +++ b/uv.lock @@ -366,8 +366,8 @@ dependencies = [ { name = "chromadb", marker = "(platform_machine == 'x86_64' and sys_platform == 'linux') or sys_platform == 'darwin'" }, { name = "dash", marker = "(platform_machine == 'x86_64' and sys_platform == 'linux') or sys_platform == 'darwin'" }, { name = "fastapi", marker = "(platform_machine == 'x86_64' and sys_platform == 'linux') or sys_platform == 'darwin'" }, + { name = "litellm", marker = "(platform_machine == 'x86_64' and sys_platform == 'linux') or sys_platform == 'darwin'" }, { name = "numpy", marker = "(platform_machine == 'x86_64' and sys_platform == 'linux') or sys_platform == 'darwin'" }, - { name = "openai", marker = "(platform_machine == 'x86_64' and sys_platform == 'linux') or sys_platform == 'darwin'" }, { name = "pillow", marker = "(platform_machine == 'x86_64' and sys_platform == 'linux') or sys_platform == 'darwin'" }, { name = "plotly", marker = "(platform_machine == 'x86_64' and sys_platform == 'linux') or sys_platform == 'darwin'" }, { name = "pydantic", marker = "(platform_machine == 'x86_64' and sys_platform == 'linux') or sys_platform == 'darwin'" }, @@ -402,9 +402,9 @@ requires-dist = [ { name = "dash", specifier = ">=3.4,<4" }, { name = "fastapi", specifier = ">=0.115,<1" }, { name = "httpx", marker = "extra == 'dev'", specifier = ">=0.28,<1" }, + { name = "litellm", specifier = ">=1.65,<2" }, { name = "mypy", marker = "extra == 'dev'", specifier = ">=1.14,<2" }, { name = "numpy", specifier = ">=2.2,<3" }, - { name = "openai", specifier = ">=1.60,<2" }, { name = "pillow", specifier = ">=11,<12" }, { name = "plotly", specifier = ">=6.5,<7" }, { name = "pre-commit", marker = "extra == 'dev'", specifier = ">=4,<5" }, @@ -441,6 +441,24 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/9e/dd/d0ee25348ac58245ee9f90b6f3cbb666bf01f69be7e0911f9851bddbda16/fastapi-0.129.0-py3-none-any.whl", hash = "sha256:b4946880e48f462692b31c083be0432275cbfb6e2274566b1be91479cc1a84ec", size = 102950, upload-time = "2026-02-12T13:54:54.528Z" }, ] +[[package]] +name = "fastuuid" +version = "0.14.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/c3/7d/d9daedf0f2ebcacd20d599928f8913e9d2aea1d56d2d355a93bfa2b611d7/fastuuid-0.14.0.tar.gz", hash = "sha256:178947fc2f995b38497a74172adee64fdeb8b7ec18f2a5934d037641ba265d26", size = 18232, upload-time = "2025-10-19T22:19:22.402Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/a5/83/ae12dd39b9a39b55d7f90abb8971f1a5f3c321fd72d5aa83f90dc67fe9ed/fastuuid-0.14.0-cp313-cp313-macosx_10_12_x86_64.macosx_11_0_arm64.macosx_10_12_universal2.whl", hash = "sha256:77a09cb7427e7af74c594e409f7731a0cf887221de2f698e1ca0ebf0f3139021", size = 510720, upload-time = "2025-10-19T22:42:34.633Z" }, + { url = "https://files.pythonhosted.org/packages/53/b0/a4b03ff5d00f563cc7546b933c28cb3f2a07344b2aec5834e874f7d44143/fastuuid-0.14.0-cp313-cp313-macosx_10_12_x86_64.whl", hash = "sha256:9bd57289daf7b153bfa3e8013446aa144ce5e8c825e9e366d455155ede5ea2dc", size = 262024, upload-time = "2025-10-19T22:30:25.482Z" }, + { url = "https://files.pythonhosted.org/packages/9c/6d/64aee0a0f6a58eeabadd582e55d0d7d70258ffdd01d093b30c53d668303b/fastuuid-0.14.0-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:ac60fc860cdf3c3f327374db87ab8e064c86566ca8c49d2e30df15eda1b0c2d5", size = 251679, upload-time = "2025-10-19T22:36:14.096Z" }, + { url = "https://files.pythonhosted.org/packages/f0/d3/8ce11827c783affffd5bd4d6378b28eb6cc6d2ddf41474006b8d62e7448e/fastuuid-0.14.0-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:33e678459cf4addaedd9936bbb038e35b3f6b2061330fd8f2f6a1d80414c0f87", size = 278278, upload-time = "2025-10-19T22:29:43.809Z" }, + { url = "https://files.pythonhosted.org/packages/bc/17/354d04951ce114bf4afc78e27a18cfbd6ee319ab1829c2d5fb5e94063ac6/fastuuid-0.14.0-cp313-cp313-musllinux_1_1_x86_64.whl", hash = "sha256:1383fff584fa249b16329a059c68ad45d030d5a4b70fb7c73a08d98fd53bcdab", size = 450921, upload-time = "2025-10-19T22:31:02.151Z" }, + { url = "https://files.pythonhosted.org/packages/16/c9/8c7660d1fe3862e3f8acabd9be7fc9ad71eb270f1c65cce9a2b7a31329ab/fastuuid-0.14.0-cp314-cp314-macosx_10_12_x86_64.macosx_11_0_arm64.macosx_10_12_universal2.whl", hash = "sha256:b852a870a61cfc26c884af205d502881a2e59cc07076b60ab4a951cc0c94d1ad", size = 510600, upload-time = "2025-10-19T22:43:44.17Z" }, + { url = "https://files.pythonhosted.org/packages/4c/f4/a989c82f9a90d0ad995aa957b3e572ebef163c5299823b4027986f133dfb/fastuuid-0.14.0-cp314-cp314-macosx_10_12_x86_64.whl", hash = "sha256:c7502d6f54cd08024c3ea9b3514e2d6f190feb2f46e6dbcd3747882264bb5f7b", size = 262069, upload-time = "2025-10-19T22:43:38.38Z" }, + { url = "https://files.pythonhosted.org/packages/da/6c/a1a24f73574ac995482b1326cf7ab41301af0fabaa3e37eeb6b3df00e6e2/fastuuid-0.14.0-cp314-cp314-macosx_11_0_arm64.whl", hash = "sha256:1ca61b592120cf314cfd66e662a5b54a578c5a15b26305e1b8b618a6f22df714", size = 251543, upload-time = "2025-10-19T22:32:22.537Z" }, + { url = "https://files.pythonhosted.org/packages/ef/33/4105ca574f6ded0af6a797d39add041bcfb468a1255fbbe82fcb6f592da2/fastuuid-0.14.0-cp314-cp314-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a8a0dfea3972200f72d4c7df02c8ac70bad1bb4c58d7e0ec1e6f341679073a7f", size = 278283, upload-time = "2025-10-19T22:29:02.812Z" }, + { url = "https://files.pythonhosted.org/packages/e0/c8/2ce1c78f983a2c4987ea865d9516dbdfb141a120fd3abb977ae6f02ba7ca/fastuuid-0.14.0-cp314-cp314-musllinux_1_1_x86_64.whl", hash = "sha256:ec27778c6ca3393ef662e2762dba8af13f4ec1aaa32d08d77f71f2a70ae9feb8", size = 450837, upload-time = "2025-10-19T22:34:37.178Z" }, +] + [[package]] name = "filelock" version = "3.24.3" @@ -733,6 +751,33 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/7b/91/984aca2ec129e2757d1e4e3c81c3fcda9d0f85b74670a094cc443d9ee949/joblib-1.5.3-py3-none-any.whl", hash = "sha256:5fc3c5039fc5ca8c0276333a188bbd59d6b7ab37fe6632daa76bc7f9ec18e713", size = 309071, upload-time = "2025-12-15T08:41:44.973Z" }, ] +[[package]] +name = "jsonschema" +version = "4.26.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "attrs", marker = "(platform_machine == 'x86_64' and sys_platform == 'linux') or sys_platform == 'darwin'" }, + { name = "jsonschema-specifications", marker = "(platform_machine == 'x86_64' and sys_platform == 'linux') or sys_platform == 'darwin'" }, + { name = "referencing", marker = "(platform_machine == 'x86_64' and sys_platform == 'linux') or sys_platform == 'darwin'" }, + { name = "rpds-py", marker = "(platform_machine == 'x86_64' and sys_platform == 'linux') or sys_platform == 'darwin'" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/b3/fc/e067678238fa451312d4c62bf6e6cf5ec56375422aee02f9cb5f909b3047/jsonschema-4.26.0.tar.gz", hash = "sha256:0c26707e2efad8aa1bfc5b7ce170f3fccc2e4918ff85989ba9ffa9facb2be326", size = 366583, upload-time = "2026-01-07T13:41:07.246Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/69/90/f63fb5873511e014207a475e2bb4e8b2e570d655b00ac19a9a0ca0a385ee/jsonschema-4.26.0-py3-none-any.whl", hash = "sha256:d489f15263b8d200f8387e64b4c3a75f06629559fb73deb8fdfb525f2dab50ce", size = 90630, upload-time = "2026-01-07T13:41:05.306Z" }, +] + +[[package]] +name = "jsonschema-specifications" +version = "2025.9.1" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "referencing", marker = "(platform_machine == 'x86_64' and sys_platform == 'linux') or sys_platform == 'darwin'" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/19/74/a633ee74eb36c44aa6d1095e7cc5569bebf04342ee146178e2d36600708b/jsonschema_specifications-2025.9.1.tar.gz", hash = "sha256:b540987f239e745613c7a9176f3edb72b832a4ac465cf02712288397832b5e8d", size = 32855, upload-time = "2025-09-08T01:34:59.186Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/41/45/1a4ed80516f02155c51f51e8cedb3c1902296743db0bbc66608a0db2814f/jsonschema_specifications-2025.9.1-py3-none-any.whl", hash = "sha256:98802fee3a11ee76ecaca44429fda8a41bff98b00a0f2838151b113f210cc6fe", size = 18437, upload-time = "2025-09-08T01:34:57.871Z" }, +] + [[package]] name = "kubernetes" version = "35.0.0" @@ -773,6 +818,29 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/2b/39/191d3d28abc26c9099b19852e6c99f7f6d400b82fa5a4e80291bd3803e19/librt-0.8.1-cp314-cp314t-musllinux_1_2_x86_64.whl", hash = "sha256:cc3656283d11540ab0ea01978378e73e10002145117055e03722417aeab30994", size = 263001, upload-time = "2026-02-17T16:12:43.627Z" }, ] +[[package]] +name = "litellm" +version = "1.82.6" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "aiohttp", marker = "(platform_machine == 'x86_64' and sys_platform == 'linux') or sys_platform == 'darwin'" }, + { name = "click", marker = "(platform_machine == 'x86_64' and sys_platform == 'linux') or sys_platform == 'darwin'" }, + { name = "fastuuid", marker = "(platform_machine == 'x86_64' and sys_platform == 'linux') or sys_platform == 'darwin'" }, + { name = "httpx", marker = "(platform_machine == 'x86_64' and sys_platform == 'linux') or sys_platform == 'darwin'" }, + { name = "importlib-metadata", marker = "(platform_machine == 'x86_64' and sys_platform == 'linux') or sys_platform == 'darwin'" }, + { name = "jinja2", marker = "(platform_machine == 'x86_64' and sys_platform == 'linux') or sys_platform == 'darwin'" }, + { name = "jsonschema", marker = "(platform_machine == 'x86_64' and sys_platform == 'linux') or sys_platform == 'darwin'" }, + { name = "openai", marker = "(platform_machine == 'x86_64' and sys_platform == 'linux') or sys_platform == 'darwin'" }, + { name = "pydantic", marker = "(platform_machine == 'x86_64' and sys_platform == 'linux') or sys_platform == 'darwin'" }, + { name = "python-dotenv", marker = "(platform_machine == 'x86_64' and sys_platform == 'linux') or sys_platform == 'darwin'" }, + { name = "tiktoken", marker = "(platform_machine == 'x86_64' and sys_platform == 'linux') or sys_platform == 'darwin'" }, + { name = "tokenizers", marker = "(platform_machine == 'x86_64' and sys_platform == 'linux') or sys_platform == 'darwin'" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/29/75/1c537aa458426a9127a92bc2273787b2f987f4e5044e21f01f2eed5244fd/litellm-1.82.6.tar.gz", hash = "sha256:2aa1c2da21fe940c33613aa447119674a3ad4d2ad5eb064e4d5ce5ee42420136", size = 17414147, upload-time = "2026-03-22T06:36:00.452Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/02/6c/5327667e6dbe9e98cbfbd4261c8e91386a52e38f41419575854248bbab6a/litellm-1.82.6-py3-none-any.whl", hash = "sha256:164a3ef3e19f309e3cabc199bef3d2045212712fefdfa25fc7f75884a5b5b205", size = 15591595, upload-time = "2026-03-22T06:35:56.795Z" }, +] + [[package]] name = "llvmlite" version = "0.46.0" @@ -1176,7 +1244,7 @@ wheels = [ [[package]] name = "openai" -version = "1.109.1" +version = "2.30.0" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "anyio", marker = "(platform_machine == 'x86_64' and sys_platform == 'linux') or sys_platform == 'darwin'" }, @@ -1188,9 +1256,9 @@ dependencies = [ { name = "tqdm", marker = "(platform_machine == 'x86_64' and sys_platform == 'linux') or sys_platform == 'darwin'" }, { name = "typing-extensions", marker = "(platform_machine == 'x86_64' and sys_platform == 'linux') or sys_platform == 'darwin'" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/c6/a1/a303104dc55fc546a3f6914c842d3da471c64eec92043aef8f652eb6c524/openai-1.109.1.tar.gz", hash = "sha256:d173ed8dbca665892a6db099b4a2dfac624f94d20a93f46eb0b56aae940ed869", size = 564133, upload-time = "2025-09-24T13:00:53.075Z" } +sdist = { url = "https://files.pythonhosted.org/packages/88/15/52580c8fbc16d0675d516e8749806eda679b16de1e4434ea06fb6feaa610/openai-2.30.0.tar.gz", hash = "sha256:92f7661c990bda4b22a941806c83eabe4896c3094465030dd882a71abe80c885", size = 676084, upload-time = "2026-03-25T22:08:59.96Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/1d/2a/7dd3d207ec669cacc1f186fd856a0f61dbc255d24f6fdc1a6715d6051b0f/openai-1.109.1-py3-none-any.whl", hash = "sha256:6bcaf57086cf59159b8e27447e4e7dd019db5d29a438072fbd49c290c7e65315", size = 948627, upload-time = "2025-09-24T13:00:50.754Z" }, + { url = "https://files.pythonhosted.org/packages/2a/9e/5bfa2270f902d5b92ab7d41ce0475b8630572e71e349b2a4996d14bdda93/openai-2.30.0-py3-none-any.whl", hash = "sha256:9a5ae616888eb2748ec5e0c5b955a51592e0b201a11f4262db920f2a78c5231d", size = 1146656, upload-time = "2026-03-25T22:08:58.2Z" }, ] [[package]] @@ -1695,6 +1763,19 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/da/92/1446574745d74df0c92e6aa4a7b0b3130706a4142b2d1a5869f2eaa423c6/pyyaml-6.0.3-cp314-cp314t-musllinux_1_2_x86_64.whl", hash = "sha256:16249ee61e95f858e83976573de0f5b2893b3677ba71c9dd36b9cf8be9ac6d65", size = 829923, upload-time = "2025-09-25T21:32:54.537Z" }, ] +[[package]] +name = "referencing" +version = "0.37.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "attrs", marker = "(platform_machine == 'x86_64' and sys_platform == 'linux') or sys_platform == 'darwin'" }, + { name = "rpds-py", marker = "(platform_machine == 'x86_64' and sys_platform == 'linux') or sys_platform == 'darwin'" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/22/f5/df4e9027acead3ecc63e50fe1e36aca1523e1719559c499951bb4b53188f/referencing-0.37.0.tar.gz", hash = "sha256:44aefc3142c5b842538163acb373e24cce6632bd54bdb01b21ad5863489f50d8", size = 78036, upload-time = "2025-10-13T15:30:48.871Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/2c/58/ca301544e1fa93ed4f80d724bf5b194f6e4b945841c5bfd555878eea9fcb/referencing-0.37.0-py3-none-any.whl", hash = "sha256:381329a9f99628c9069361716891d34ad94af76e461dcb0335825aecc7692231", size = 26766, upload-time = "2025-10-13T15:30:47.625Z" }, +] + [[package]] name = "regex" version = "2026.1.15" @@ -1773,6 +1854,30 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/ef/45/615f5babd880b4bd7d405cc0dc348234c5ffb6ed1ea33e152ede08b2072d/rich-14.3.2-py3-none-any.whl", hash = "sha256:08e67c3e90884651da3239ea668222d19bea7b589149d8014a21c633420dbb69", size = 309963, upload-time = "2026-02-01T16:20:46.078Z" }, ] +[[package]] +name = "rpds-py" +version = "0.30.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/20/af/3f2f423103f1113b36230496629986e0ef7e199d2aa8392452b484b38ced/rpds_py-0.30.0.tar.gz", hash = "sha256:dd8ff7cf90014af0c0f787eea34794ebf6415242ee1d6fa91eaba725cc441e84", size = 69469, upload-time = "2025-11-30T20:24:38.837Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/ed/dc/d61221eb88ff410de3c49143407f6f3147acf2538c86f2ab7ce65ae7d5f9/rpds_py-0.30.0-cp313-cp313-macosx_10_12_x86_64.whl", hash = "sha256:f83424d738204d9770830d35290ff3273fbb02b41f919870479fab14b9d303b2", size = 374887, upload-time = "2025-11-30T20:22:41.812Z" }, + { url = "https://files.pythonhosted.org/packages/fd/32/55fb50ae104061dbc564ef15cc43c013dc4a9f4527a1f4d99baddf56fe5f/rpds_py-0.30.0-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:e7536cd91353c5273434b4e003cbda89034d67e7710eab8761fd918ec6c69cf8", size = 358904, upload-time = "2025-11-30T20:22:43.479Z" }, + { url = "https://files.pythonhosted.org/packages/b7/de/f7192e12b21b9e9a68a6d0f249b4af3fdcdff8418be0767a627564afa1f1/rpds_py-0.30.0-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:9027da1ce107104c50c81383cae773ef5c24d296dd11c99e2629dbd7967a20c6", size = 394025, upload-time = "2025-11-30T20:22:50.196Z" }, + { url = "https://files.pythonhosted.org/packages/5f/60/525a50f45b01d70005403ae0e25f43c0384369ad24ffe46e8d9068b50086/rpds_py-0.30.0-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:945dccface01af02675628334f7cf49c2af4c1c904748efc5cf7bbdf0b579f95", size = 563020, upload-time = "2025-11-30T20:22:58.2Z" }, + { url = "https://files.pythonhosted.org/packages/ff/1b/b10de890a0def2a319a2626334a7f0ae388215eb60914dbac8a3bae54435/rpds_py-0.30.0-cp313-cp313t-macosx_10_12_x86_64.whl", hash = "sha256:eb0b93f2e5c2189ee831ee43f156ed34e2a89a78a66b98cadad955972548be5a", size = 364443, upload-time = "2025-11-30T20:23:04.878Z" }, + { url = "https://files.pythonhosted.org/packages/0d/bf/27e39f5971dc4f305a4fb9c672ca06f290f7c4e261c568f3dea16a410d47/rpds_py-0.30.0-cp313-cp313t-macosx_11_0_arm64.whl", hash = "sha256:922e10f31f303c7c920da8981051ff6d8c1a56207dbdf330d9047f6d30b70e5e", size = 353375, upload-time = "2025-11-30T20:23:06.342Z" }, + { url = "https://files.pythonhosted.org/packages/60/ca/780cf3b1a32b18c0f05c441958d3758f02544f1d613abf9488cd78876378/rpds_py-0.30.0-cp313-cp313t-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:51a1234d8febafdfd33a42d97da7a43f5dcb120c1060e352a3fbc0c6d36e2083", size = 383843, upload-time = "2025-11-30T20:23:14.638Z" }, + { url = "https://files.pythonhosted.org/packages/6d/61/21b8c41f68e60c8cc3b2e25644f0e3681926020f11d06ab0b78e3c6bbff1/rpds_py-0.30.0-cp313-cp313t-musllinux_1_2_x86_64.whl", hash = "sha256:4c5f36a861bc4b7da6516dbdf302c55313afa09b81931e8280361a4f6c9a2d27", size = 555806, upload-time = "2025-11-30T20:23:22.488Z" }, + { url = "https://files.pythonhosted.org/packages/86/81/dad16382ebbd3d0e0328776d8fd7ca94220e4fa0798d1dc5e7da48cb3201/rpds_py-0.30.0-cp314-cp314-macosx_10_12_x86_64.whl", hash = "sha256:68f19c879420aa08f61203801423f6cd5ac5f0ac4ac82a2368a9fcd6a9a075e0", size = 362099, upload-time = "2025-11-30T20:23:27.316Z" }, + { url = "https://files.pythonhosted.org/packages/2b/60/19f7884db5d5603edf3c6bce35408f45ad3e97e10007df0e17dd57af18f8/rpds_py-0.30.0-cp314-cp314-macosx_11_0_arm64.whl", hash = "sha256:ec7c4490c672c1a0389d319b3a9cfcd098dcdc4783991553c332a15acf7249be", size = 353192, upload-time = "2025-11-30T20:23:29.151Z" }, + { url = "https://files.pythonhosted.org/packages/ce/81/9a91c0111ce1758c92516a3e44776920b579d9a7c09b2b06b642d4de3f0f/rpds_py-0.30.0-cp314-cp314-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:47e77dc9822d3ad616c3d5759ea5631a75e5809d5a28707744ef79d7a1bcfcad", size = 382112, upload-time = "2025-11-30T20:23:36.842Z" }, + { url = "https://files.pythonhosted.org/packages/21/20/7ff5f3c8b00c8a95f75985128c26ba44503fb35b8e0259d812766ea966c7/rpds_py-0.30.0-cp314-cp314-musllinux_1_2_x86_64.whl", hash = "sha256:46e83c697b1f1c72b50e5ee5adb4353eef7406fb3f2043d64c33f20ad1c2fc53", size = 553371, upload-time = "2025-11-30T20:23:46.004Z" }, + { url = "https://files.pythonhosted.org/packages/9e/68/154fe0194d83b973cdedcdcc88947a2752411165930182ae41d983dcefa6/rpds_py-0.30.0-cp314-cp314t-macosx_10_12_x86_64.whl", hash = "sha256:7e6ecfcb62edfd632e56983964e6884851786443739dbfe3582947e87274f7cb", size = 364868, upload-time = "2025-11-30T20:23:52.494Z" }, + { url = "https://files.pythonhosted.org/packages/83/69/8bbc8b07ec854d92a8b75668c24d2abcb1719ebf890f5604c61c9369a16f/rpds_py-0.30.0-cp314-cp314t-macosx_11_0_arm64.whl", hash = "sha256:a1d0bc22a7cdc173fedebb73ef81e07faef93692b8c1ad3733b67e31e1b6e1b8", size = 353747, upload-time = "2025-11-30T20:23:54.036Z" }, + { url = "https://files.pythonhosted.org/packages/c2/c7/736e00ebf39ed81d75544c0da6ef7b0998f8201b369acf842f9a90dc8fce/rpds_py-0.30.0-cp314-cp314t-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:626a7433c34566535b6e56a1b39a7b17ba961e97ce3b80ec62e6f1312c025551", size = 383765, upload-time = "2025-11-30T20:24:01.759Z" }, + { url = "https://files.pythonhosted.org/packages/85/70/92482ccffb96f5441aab93e26c4d66489eb599efdcf96fad90c14bbfb976/rpds_py-0.30.0-cp314-cp314t-musllinux_1_2_x86_64.whl", hash = "sha256:dbd936cde57abfee19ab3213cf9c26be06d60750e60a8e4dd85d1ab12c8b1f40", size = 556030, upload-time = "2025-11-30T20:24:10.956Z" }, +] + [[package]] name = "ruff" version = "0.15.1" @@ -1954,6 +2059,34 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/32/d5/f9a850d79b0851d1d4ef6456097579a9005b31fea68726a4ae5f2d82ddd9/threadpoolctl-3.6.0-py3-none-any.whl", hash = "sha256:43a0b8fd5a2928500110039e43a5eed8480b918967083ea48dc3ab9f13c4a7fb", size = 18638, upload-time = "2025-03-13T13:49:21.846Z" }, ] +[[package]] +name = "tiktoken" +version = "0.12.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "regex", marker = "(platform_machine == 'x86_64' and sys_platform == 'linux') or sys_platform == 'darwin'" }, + { name = "requests", marker = "(platform_machine == 'x86_64' and sys_platform == 'linux') or sys_platform == 'darwin'" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/7d/ab/4d017d0f76ec3171d469d80fc03dfbb4e48a4bcaddaa831b31d526f05edc/tiktoken-0.12.0.tar.gz", hash = "sha256:b18ba7ee2b093863978fcb14f74b3707cdc8d4d4d3836853ce7ec60772139931", size = 37806, upload-time = "2025-10-06T20:22:45.419Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/00/61/441588ee21e6b5cdf59d6870f86beb9789e532ee9718c251b391b70c68d6/tiktoken-0.12.0-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:775c2c55de2310cc1bc9a3ad8826761cbdc87770e586fd7b6da7d4589e13dab3", size = 1050802, upload-time = "2025-10-06T20:22:00.96Z" }, + { url = "https://files.pythonhosted.org/packages/1f/05/dcf94486d5c5c8d34496abe271ac76c5b785507c8eae71b3708f1ad9b45a/tiktoken-0.12.0-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:a01b12f69052fbe4b080a2cfb867c4de12c704b56178edf1d1d7b273561db160", size = 993995, upload-time = "2025-10-06T20:22:02.788Z" }, + { url = "https://files.pythonhosted.org/packages/0c/da/c028aa0babf77315e1cef357d4d768800c5f8a6de04d0eac0f377cb619fa/tiktoken-0.12.0-cp313-cp313-manylinux_2_28_x86_64.whl", hash = "sha256:4a1a4fcd021f022bfc81904a911d3df0f6543b9e7627b51411da75ff2fe7a1be", size = 1151986, upload-time = "2025-10-06T20:22:05.173Z" }, + { url = "https://files.pythonhosted.org/packages/f4/f8/4db272048397636ac7a078d22773dd2795b1becee7bc4922fe6207288d57/tiktoken-0.12.0-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:9baf52f84a3f42eef3ff4e754a0db79a13a27921b457ca9832cf944c6be4f8f3", size = 1255097, upload-time = "2025-10-06T20:22:07.403Z" }, + { url = "https://files.pythonhosted.org/packages/ce/76/994fc868f88e016e6d05b0da5ac24582a14c47893f4474c3e9744283f1d5/tiktoken-0.12.0-cp313-cp313t-macosx_10_13_x86_64.whl", hash = "sha256:d5f89ea5680066b68bcb797ae85219c72916c922ef0fcdd3480c7d2315ffff16", size = 1050309, upload-time = "2025-10-06T20:22:10.939Z" }, + { url = "https://files.pythonhosted.org/packages/f6/b8/57ef1456504c43a849821920d582a738a461b76a047f352f18c0b26c6516/tiktoken-0.12.0-cp313-cp313t-macosx_11_0_arm64.whl", hash = "sha256:b4e7ed1c6a7a8a60a3230965bdedba8cc58f68926b835e519341413370e0399a", size = 993712, upload-time = "2025-10-06T20:22:12.115Z" }, + { url = "https://files.pythonhosted.org/packages/05/df/4f80030d44682235bdaecd7346c90f67ae87ec8f3df4a3442cb53834f7e4/tiktoken-0.12.0-cp313-cp313t-manylinux_2_28_x86_64.whl", hash = "sha256:06a9f4f49884139013b138920a4c393aa6556b2f8f536345f11819389c703ebb", size = 1151875, upload-time = "2025-10-06T20:22:14.559Z" }, + { url = "https://files.pythonhosted.org/packages/78/a7/f8ead382fce0243cb625c4f266e66c27f65ae65ee9e77f59ea1653b6d730/tiktoken-0.12.0-cp313-cp313t-musllinux_1_2_x86_64.whl", hash = "sha256:0ee8f9ae00c41770b5f9b0bb1235474768884ae157de3beb5439ca0fd70f3e25", size = 1253794, upload-time = "2025-10-06T20:22:16.624Z" }, + { url = "https://files.pythonhosted.org/packages/72/05/3abc1db5d2c9aadc4d2c76fa5640134e475e58d9fbb82b5c535dc0de9b01/tiktoken-0.12.0-cp314-cp314-macosx_10_13_x86_64.whl", hash = "sha256:a90388128df3b3abeb2bfd1895b0681412a8d7dc644142519e6f0a97c2111646", size = 1050188, upload-time = "2025-10-06T20:22:19.563Z" }, + { url = "https://files.pythonhosted.org/packages/e3/7b/50c2f060412202d6c95f32b20755c7a6273543b125c0985d6fa9465105af/tiktoken-0.12.0-cp314-cp314-macosx_11_0_arm64.whl", hash = "sha256:da900aa0ad52247d8794e307d6446bd3cdea8e192769b56276695d34d2c9aa88", size = 993978, upload-time = "2025-10-06T20:22:20.702Z" }, + { url = "https://files.pythonhosted.org/packages/f5/de/9341a6d7a8f1b448573bbf3425fa57669ac58258a667eb48a25dfe916d70/tiktoken-0.12.0-cp314-cp314-manylinux_2_28_x86_64.whl", hash = "sha256:d186a5c60c6a0213f04a7a802264083dea1bbde92a2d4c7069e1a56630aef830", size = 1151216, upload-time = "2025-10-06T20:22:23.085Z" }, + { url = "https://files.pythonhosted.org/packages/b3/1e/b651ec3059474dab649b8d5b69f5c65cd8fcd8918568c1935bd4136c9392/tiktoken-0.12.0-cp314-cp314-musllinux_1_2_x86_64.whl", hash = "sha256:8f317e8530bb3a222547b85a58583238c8f74fd7a7408305f9f63246d1a0958b", size = 1254567, upload-time = "2025-10-06T20:22:25.671Z" }, + { url = "https://files.pythonhosted.org/packages/ac/a4/72eed53e8976a099539cdd5eb36f241987212c29629d0a52c305173e0a68/tiktoken-0.12.0-cp314-cp314t-macosx_10_13_x86_64.whl", hash = "sha256:c2c714c72bc00a38ca969dae79e8266ddec999c7ceccd603cc4f0d04ccd76365", size = 1050473, upload-time = "2025-10-06T20:22:27.775Z" }, + { url = "https://files.pythonhosted.org/packages/e6/d7/0110b8f54c008466b19672c615f2168896b83706a6611ba6e47313dbc6e9/tiktoken-0.12.0-cp314-cp314t-macosx_11_0_arm64.whl", hash = "sha256:cbb9a3ba275165a2cb0f9a83f5d7025afe6b9d0ab01a22b50f0e74fee2ad253e", size = 993855, upload-time = "2025-10-06T20:22:28.799Z" }, + { url = "https://files.pythonhosted.org/packages/4e/2b/fc46c90fe5028bd094cd6ee25a7db321cb91d45dc87531e2bdbb26b4867a/tiktoken-0.12.0-cp314-cp314t-manylinux_2_28_x86_64.whl", hash = "sha256:584c3ad3d0c74f5269906eb8a659c8bfc6144a52895d9261cdaf90a0ae5f4de0", size = 1150736, upload-time = "2025-10-06T20:22:30.996Z" }, + { url = "https://files.pythonhosted.org/packages/ab/0d/c1ad6f4016a3968c048545f5d9b8ffebf577774b2ede3e2e352553b685fe/tiktoken-0.12.0-cp314-cp314t-musllinux_1_2_x86_64.whl", hash = "sha256:5edb8743b88d5be814b1a8a8854494719080c28faaa1ccbef02e87354fe71ef0", size = 1253706, upload-time = "2025-10-06T20:22:33.385Z" }, +] + [[package]] name = "tokenizers" version = "0.22.2" @@ -2001,6 +2134,10 @@ dependencies = [ ] wheels = [ { url = "https://files.pythonhosted.org/packages/ec/23/2c9fe0c9c27f7f6cb865abcea8a4568f29f00acaeadfc6a37f6801f84cb4/torch-2.10.0-2-cp313-none-macosx_11_0_arm64.whl", hash = "sha256:e521c9f030a3774ed770a9c011751fb47c4d12029a3d6522116e48431f2ff89e", size = 79498254, upload-time = "2026-02-10T21:44:44.095Z" }, + { url = "https://files.pythonhosted.org/packages/ab/c6/4dfe238342ffdcec5aef1c96c457548762d33c40b45a1ab7033bb26d2ff2/torch-2.10.0-3-cp313-cp313-manylinux_2_28_x86_64.whl", hash = "sha256:80b1b5bfe38eb0e9f5ff09f206dcac0a87aadd084230d4a36eea5ec5232c115b", size = 915627275, upload-time = "2026-03-11T14:16:11.325Z" }, + { url = "https://files.pythonhosted.org/packages/d8/f0/72bf18847f58f877a6a8acf60614b14935e2f156d942483af1ffc081aea0/torch-2.10.0-3-cp313-cp313t-manylinux_2_28_x86_64.whl", hash = "sha256:46b3574d93a2a8134b3f5475cfb98e2eb46771794c57015f6ad1fb795ec25e49", size = 915523474, upload-time = "2026-03-11T14:17:44.422Z" }, + { url = "https://files.pythonhosted.org/packages/f4/39/590742415c3030551944edc2ddc273ea1fdfe8ffb2780992e824f1ebee98/torch-2.10.0-3-cp314-cp314-manylinux_2_28_x86_64.whl", hash = "sha256:b1d5e2aba4eb7f8e87fbe04f86442887f9167a35f092afe4c237dfcaaef6e328", size = 915632474, upload-time = "2026-03-11T14:15:13.666Z" }, + { url = "https://files.pythonhosted.org/packages/b6/8e/34949484f764dde5b222b7fe3fede43e4a6f0da9d7f8c370bb617d629ee2/torch-2.10.0-3-cp314-cp314t-manylinux_2_28_x86_64.whl", hash = "sha256:0228d20b06701c05a8f978357f657817a4a63984b0c90745def81c18aedfa591", size = 915523882, upload-time = "2026-03-11T14:14:46.311Z" }, { url = "https://files.pythonhosted.org/packages/98/fb/5160261aeb5e1ee12ee95fe599d0541f7c976c3701d607d8fc29e623229f/torch-2.10.0-cp313-cp313-manylinux_2_28_x86_64.whl", hash = "sha256:6b71486353fce0f9714ca0c9ef1c850a2ae766b409808acd58e9678a3edb7738", size = 915716445, upload-time = "2026-01-21T16:22:45.353Z" }, { url = "https://files.pythonhosted.org/packages/1a/0b/39929b148f4824bc3ad6f9f72a29d4ad865bcf7ebfc2fa67584773e083d2/torch-2.10.0-cp313-cp313t-macosx_14_0_arm64.whl", hash = "sha256:3202429f58309b9fa96a614885eace4b7995729f44beb54d3e4a47773649d382", size = 79851305, upload-time = "2026-01-21T16:24:09.209Z" }, { url = "https://files.pythonhosted.org/packages/54/fd/b207d1c525cb570ef47f3e9f836b154685011fce11a2f444ba8a4084d042/torch-2.10.0-cp313-cp313t-manylinux_2_28_x86_64.whl", hash = "sha256:6021db85958db2f07ec94e1bc77212721ba4920c12a18dc552d2ae36a3eb163f", size = 915612644, upload-time = "2026-01-21T16:21:47.019Z" }, From 19f7415e5e99ae4ac61fe0300ad9a508c7e71528 Mon Sep 17 00:00:00 2001 From: aGallea Date: Mon, 30 Mar 2026 17:36:02 +0300 Subject: [PATCH 2/6] refactor: remove GPT cluster naming code and add name field to SubClusterInfo --- embedding_cluster/ai_naming.py | 111 +++++++++++++ embedding_cluster/scatter_plot.py | 55 +------ embedding_cluster/server/models.py | 4 +- embedding_cluster/server/routes/plot.py | 3 - embedding_cluster/settings.py | 8 - tests/test_ai_naming.py | 198 ++++++++++++++++++++++++ tests/test_scatter_plot.py | 88 +---------- tests/test_server_plot.py | 3 - 8 files changed, 319 insertions(+), 151 deletions(-) create mode 100644 embedding_cluster/ai_naming.py create mode 100644 tests/test_ai_naming.py diff --git a/embedding_cluster/ai_naming.py b/embedding_cluster/ai_naming.py new file mode 100644 index 0000000..ed62fd2 --- /dev/null +++ b/embedding_cluster/ai_naming.py @@ -0,0 +1,111 @@ +from __future__ import annotations + +import logging + +import litellm + +logger = logging.getLogger(__name__) + +# Alias for testability (easy to mock) +litellm_completion = litellm.completion + +SYSTEM_PROMPT_TOP_LEVEL = ( + "Your role is to find a very short (max 5 words), concise name " + "for a group of items, one name to rule them all. " + "The user will provide a list of item names. Do your best." +) + +SYSTEM_PROMPT_SUB_CLUSTER = ( + "Your role is to find a very short (max 5 words), concise name " + "for a sub-group of items within a larger group called " + '"{parent_name}". ' + "The name should distinguish this sub-group from its siblings " + "while relating to the parent theme. The user will provide a " + "list of item names. Do your best." +) + + +def _call_llm( + messages: list[dict[str, str]], + api_key: str, + model: str, + base_url: str | None = None, + temperature: float = 0.5, +) -> str: + """Call LiteLLM and return the response content.""" + kwargs: dict[str, object] = { + "model": model, + "messages": messages, + "api_key": api_key, + "temperature": temperature, + } + if base_url: + kwargs["api_base"] = base_url + + response = litellm_completion(**kwargs) + content: str = response.choices[0].message.content or "" + return (content[:30] + "..") if len(content) > 30 else content + + +def get_cluster_name( + item_names: list[str], + api_key: str, + model: str, + base_url: str | None = None, + temperature: float = 0.5, +) -> str: + """Generate a short name for a cluster of items.""" + user_content = "\n".join(f"name: {name}" for name in item_names) + messages = [ + {"role": "system", "content": SYSTEM_PROMPT_TOP_LEVEL}, + {"role": "user", "content": user_content}, + ] + return _call_llm(messages, api_key, model, base_url, temperature) + + +def get_sub_cluster_name( + item_names: list[str], + api_key: str, + model: str, + base_url: str | None = None, + temperature: float = 0.5, + parent_cluster_name: str | None = None, +) -> str: + """Generate a short name for a sub-cluster.""" + if parent_cluster_name: + system_content = SYSTEM_PROMPT_SUB_CLUSTER.format( + parent_name=parent_cluster_name, + ) + else: + system_content = SYSTEM_PROMPT_TOP_LEVEL + + user_content = "\n".join(f"name: {name}" for name in item_names) + messages = [ + {"role": "system", "content": system_content}, + {"role": "user", "content": user_content}, + ] + return _call_llm(messages, api_key, model, base_url, temperature) + + +def test_connection( + api_key: str, + model: str, + base_url: str | None = None, +) -> tuple[bool, str | None]: + """Test LLM connection. Returns (success, error).""" + try: + kwargs: dict[str, object] = { + "model": model, + "messages": [{"role": "user", "content": "Say hello"}], + "api_key": api_key, + "max_tokens": 5, + } + if base_url: + kwargs["api_base"] = base_url + litellm_completion(**kwargs) + return True, None + except Exception as exc: + error_msg = str(exc) + if api_key and api_key in error_msg: + error_msg = error_msg.replace(api_key, "***") + return False, error_msg diff --git a/embedding_cluster/scatter_plot.py b/embedding_cluster/scatter_plot.py index 532d1f1..13ecedc 100644 --- a/embedding_cluster/scatter_plot.py +++ b/embedding_cluster/scatter_plot.py @@ -1,14 +1,12 @@ from __future__ import annotations import logging -import random from typing import TYPE_CHECKING, Any import chromadb import numpy as np import plotly.graph_objects as go from dash import Dash, Input, Output, callback, dcc, html, no_update -from openai import OpenAI from sklearn.cluster import KMeans from sklearn.decomposition import PCA from sklearn.manifold import TSNE @@ -71,31 +69,6 @@ def reduce_dimensions( return result -def gpt_get_cluster_name(info: str, settings: Settings) -> str: - openai_client = OpenAI() - messages: list[dict[str, str]] = [ - { - "role": "system", - "content": ( - "Your role is to find a very short (max 5 words), concise " - "name for a group of items, one name to rule them all. " - "the user will provide a list of item names. do your best" - ), - }, - { - "role": "user", - "content": info, - }, - ] - completion = openai_client.chat.completions.create( - model=settings.gpt_default_model, - temperature=settings.gpt_default_temperature, - messages=messages, # type: ignore[arg-type] # pyright: ignore[reportArgumentType] - ) - content = completion.choices[0].message.content or "" - return (content[:30] + "..") if len(content) > 30 else content - - def load_chromadb_collection(settings: Settings) -> Any: chromadb_client: ClientAPI = chromadb.PersistentClient(path="./chromadb") collection = chromadb_client.get_or_create_collection( @@ -201,9 +174,9 @@ def generate_cluster_props( num_clusters: int, pred_arr: Any, collection_content_text_display: list[str], - settings: Settings, num_products_for_cluster_name: int = 10, ) -> tuple[list[list[int]], list[str]]: + _ = (collection_content_text_display, num_products_for_cluster_name) clusters_indices: list[list[int]] = [] cluster_names: list[str] = [] group_index = 1 @@ -211,28 +184,8 @@ def generate_cluster_props( curr_cluster_indices = [i for i, x in enumerate(pred_arr) if x == cluster_i] clusters_indices.append(curr_cluster_indices) logger.info("Generating cluster %d names ...", cluster_i) - if settings.gpt_generate_cluster_name is True: - random_product_indexes = random.sample( - range(0, len(curr_cluster_indices)), - min( - num_products_for_cluster_name, - len(curr_cluster_indices), - ), - ) - curr_descriptions = "" - for product_index in random_product_indexes: - idx = curr_cluster_indices[product_index] - item = ( - collection_content_text_display[idx] - if idx < len(collection_content_text_display) - else f"Item {idx}" - ) - curr_descriptions += f"name: {item} \n" - cluster_name = gpt_get_cluster_name(curr_descriptions, settings) - cluster_names.append(cluster_name) - else: - cluster_names.append(f"Group {group_index}") - group_index += 1 + cluster_names.append(f"Group {group_index}") + group_index += 1 return clusters_indices, cluster_names @@ -288,7 +241,7 @@ def compute_plot_data(settings: Settings) -> dict[str, Any]: ) clusters_indices, cluster_names = generate_cluster_props( - num_clusters, pred_arr, collection_content_text_display, settings + num_clusters, pred_arr, collection_content_text_display ) # Build structured point data diff --git a/embedding_cluster/server/models.py b/embedding_cluster/server/models.py index 5f42c1f..2a036df 100644 --- a/embedding_cluster/server/models.py +++ b/embedding_cluster/server/models.py @@ -70,9 +70,6 @@ class PlotRequest(BaseModel): num_clusters: int = 10 text_display_fields: list[str] | None = None image_field: str | None = None - gpt_generate_cluster_name: bool = False - gpt_default_model: str = "gpt-3.5-turbo" - gpt_default_temperature: float = 0.51 reduction_algorithm: Literal["tsne", "umap", "pca"] = "tsne" tsne_perplexity: float = 30.0 tsne_learning_rate: str = "auto" @@ -199,6 +196,7 @@ class SubClusterInfo(BaseModel): index: int count: int color: str + name: str | None = None class SubClusterResponse(BaseModel): diff --git a/embedding_cluster/server/routes/plot.py b/embedding_cluster/server/routes/plot.py index 35da576..5529f96 100644 --- a/embedding_cluster/server/routes/plot.py +++ b/embedding_cluster/server/routes/plot.py @@ -43,9 +43,6 @@ async def _run_compute(task_state: TaskState, request: PlotRequest) -> None: num_clusters=request.num_clusters, text_display_fields=request.text_display_fields, image_field=request.image_field, - gpt_generate_cluster_name=request.gpt_generate_cluster_name, - gpt_default_model=request.gpt_default_model, - gpt_default_temperature=request.gpt_default_temperature, reduction_algorithm=request.reduction_algorithm, tsne_perplexity=request.tsne_perplexity, tsne_learning_rate=request.tsne_learning_rate, diff --git a/embedding_cluster/settings.py b/embedding_cluster/settings.py index a7677ec..f424f31 100644 --- a/embedding_cluster/settings.py +++ b/embedding_cluster/settings.py @@ -75,11 +75,3 @@ class Settings(BaseSettings): umap_n_neighbors: int = Field(default=15, description="UMAP number of neighbors") umap_min_dist: float = Field(default=0.1, description="UMAP minimum distance") umap_metric: str = Field(default="cosine", description="UMAP distance metric") - - gpt_generate_cluster_name: bool = Field( - default=False, description="Generate cluster names using GPT" - ) - gpt_default_model: str = Field(default="gpt-3.5-turbo", description="GPT model name") - gpt_default_temperature: float = Field( - default=0.51, description="GPT model temperature" - ) diff --git a/tests/test_ai_naming.py b/tests/test_ai_naming.py new file mode 100644 index 0000000..1bc308a --- /dev/null +++ b/tests/test_ai_naming.py @@ -0,0 +1,198 @@ +from __future__ import annotations + +from unittest.mock import MagicMock, patch + + +class TestGetClusterName: + def test_returns_short_name(self) -> None: + from embedding_cluster.ai_naming import get_cluster_name + + mock_response = MagicMock() + mock_choice = MagicMock() + mock_choice.message.content = "Athletic Footwear" + mock_response.choices = [mock_choice] + + with patch( + "embedding_cluster.ai_naming.litellm_completion", + return_value=mock_response, + ): + result = get_cluster_name( + item_names=["Running Shoes", "Basketball Sneakers"], + api_key="test-key", + model="gpt-4o-mini", + ) + + assert result == "Athletic Footwear" + + def test_truncates_long_name(self) -> None: + from embedding_cluster.ai_naming import get_cluster_name + + mock_response = MagicMock() + mock_choice = MagicMock() + mock_choice.message.content = "A" * 50 + mock_response.choices = [mock_choice] + + with patch( + "embedding_cluster.ai_naming.litellm_completion", + return_value=mock_response, + ): + result = get_cluster_name( + item_names=["item1"], + api_key="test-key", + model="gpt-4o-mini", + ) + + assert len(result) == 32 # 30 chars + ".." + + def test_none_content_returns_empty(self) -> None: + from embedding_cluster.ai_naming import get_cluster_name + + mock_response = MagicMock() + mock_choice = MagicMock() + mock_choice.message.content = None + mock_response.choices = [mock_choice] + + with patch( + "embedding_cluster.ai_naming.litellm_completion", + return_value=mock_response, + ): + result = get_cluster_name( + item_names=["item1"], + api_key="test-key", + model="gpt-4o-mini", + ) + + assert result == "" + + def test_passes_base_url(self) -> None: + from embedding_cluster.ai_naming import get_cluster_name + + mock_response = MagicMock() + mock_choice = MagicMock() + mock_choice.message.content = "Name" + mock_response.choices = [mock_choice] + + with patch( + "embedding_cluster.ai_naming.litellm_completion", + return_value=mock_response, + ) as mock_completion: + get_cluster_name( + item_names=["item1"], + api_key="test-key", + model="gpt-4o-mini", + base_url="http://localhost:11434", + ) + + mock_completion.assert_called_once() + call_kwargs = mock_completion.call_args[1] + assert call_kwargs["api_base"] == "http://localhost:11434" + + def test_passes_temperature(self) -> None: + from embedding_cluster.ai_naming import get_cluster_name + + mock_response = MagicMock() + mock_choice = MagicMock() + mock_choice.message.content = "Name" + mock_response.choices = [mock_choice] + + with patch( + "embedding_cluster.ai_naming.litellm_completion", + return_value=mock_response, + ) as mock_completion: + get_cluster_name( + item_names=["item1"], + api_key="test-key", + model="gpt-4o-mini", + temperature=0.7, + ) + + mock_completion.assert_called_once() + call_kwargs = mock_completion.call_args[1] + assert call_kwargs["temperature"] == 0.7 + + +class TestGetSubClusterName: + def test_includes_parent_context(self) -> None: + from embedding_cluster.ai_naming import get_sub_cluster_name + + mock_response = MagicMock() + mock_choice = MagicMock() + mock_choice.message.content = "Running Shoes" + mock_response.choices = [mock_choice] + + with patch( + "embedding_cluster.ai_naming.litellm_completion", + return_value=mock_response, + ) as mock_completion: + result = get_sub_cluster_name( + item_names=["Nike Air Max", "Adidas Ultraboost"], + api_key="test-key", + model="gpt-4o-mini", + parent_cluster_name="Athletic Footwear", + ) + + assert result == "Running Shoes" + call_kwargs = mock_completion.call_args[1] + system_msg = call_kwargs["messages"][0]["content"] + assert "Athletic Footwear" in system_msg + + def test_without_parent_name_uses_default_prompt(self) -> None: + from embedding_cluster.ai_naming import get_sub_cluster_name + + mock_response = MagicMock() + mock_choice = MagicMock() + mock_choice.message.content = "Sub Name" + mock_response.choices = [mock_choice] + + with patch( + "embedding_cluster.ai_naming.litellm_completion", + return_value=mock_response, + ) as mock_completion: + result = get_sub_cluster_name( + item_names=["item1"], + api_key="test-key", + model="gpt-4o-mini", + ) + + assert result == "Sub Name" + call_kwargs = mock_completion.call_args[1] + system_msg = call_kwargs["messages"][0]["content"] + assert "sub-group" not in system_msg + + +class TestTestConnection: + def test_success(self) -> None: + from embedding_cluster.ai_naming import test_connection + + mock_response = MagicMock() + mock_choice = MagicMock() + mock_choice.message.content = "Hello" + mock_response.choices = [mock_choice] + + with patch( + "embedding_cluster.ai_naming.litellm_completion", + return_value=mock_response, + ): + success, error = test_connection( + api_key="test-key", + model="gpt-4o-mini", + ) + + assert success is True + assert error is None + + def test_failure_redacts_key(self) -> None: + from embedding_cluster.ai_naming import test_connection + + with patch( + "embedding_cluster.ai_naming.litellm_completion", + side_effect=Exception("Invalid API key: sk-1234567890abcdef"), + ): + success, error = test_connection( + api_key="sk-1234567890abcdef", + model="gpt-4o-mini", + ) + + assert success is False + assert error is not None + assert "sk-1234567890abcdef" not in error diff --git a/tests/test_scatter_plot.py b/tests/test_scatter_plot.py index 6e8128e..2e1227b 100644 --- a/tests/test_scatter_plot.py +++ b/tests/test_scatter_plot.py @@ -1,5 +1,6 @@ from __future__ import annotations +from typing import cast from unittest.mock import MagicMock, patch import numpy as np @@ -50,10 +51,9 @@ def test_custom_separator(self) -> None: class TestGenerateClusterProps: - def test_without_gpt(self, monkeypatch: pytest.MonkeyPatch) -> None: + def test_without_gpt(self) -> None: from embedding_cluster.scatter_plot import generate_cluster_props - settings = Settings() pred_arr = [0, 0, 1, 1, 2] text_display = ["a", "b", "c", "d", "e"] @@ -61,7 +61,6 @@ def test_without_gpt(self, monkeypatch: pytest.MonkeyPatch) -> None: num_clusters=3, pred_arr=pred_arr, collection_content_text_display=text_display, - settings=settings, ) assert len(clusters_indices) == 3 @@ -71,83 +70,6 @@ def test_without_gpt(self, monkeypatch: pytest.MonkeyPatch) -> None: assert clusters_indices[2] == [4] assert cluster_names == ["Group 1", "Group 2", "Group 3"] - def test_with_gpt(self, monkeypatch: pytest.MonkeyPatch) -> None: - from embedding_cluster.scatter_plot import generate_cluster_props - - monkeypatch.setenv("GPT_GENERATE_CLUSTER_NAME", "true") - settings = Settings() - pred_arr = [0, 0] - text_display = ["item1", "item2"] - - with patch( - "embedding_cluster.scatter_plot.gpt_get_cluster_name", - return_value="Cool Group", - ): - _clusters_indices, cluster_names = generate_cluster_props( - num_clusters=1, - pred_arr=pred_arr, - collection_content_text_display=text_display, - settings=settings, - ) - - assert cluster_names == ["Cool Group"] - - -class TestGptGetClusterName: - def test_gpt_get_cluster_name(self) -> None: - from embedding_cluster.scatter_plot import gpt_get_cluster_name - - settings = Settings() - - with patch("embedding_cluster.scatter_plot.OpenAI") as mock_openai_cls: - mock_client = MagicMock() - mock_openai_cls.return_value = mock_client - mock_completion = MagicMock() - mock_choice = MagicMock() - mock_choice.message.content = "Fashion Items" - mock_completion.choices = [mock_choice] - mock_client.chat.completions.create.return_value = mock_completion - - result = gpt_get_cluster_name("item1\nitem2", settings) - - assert result == "Fashion Items" - - def test_gpt_truncates_long_name(self) -> None: - from embedding_cluster.scatter_plot import gpt_get_cluster_name - - settings = Settings() - - with patch("embedding_cluster.scatter_plot.OpenAI") as mock_openai_cls: - mock_client = MagicMock() - mock_openai_cls.return_value = mock_client - mock_completion = MagicMock() - mock_choice = MagicMock() - mock_choice.message.content = "A" * 50 - mock_completion.choices = [mock_choice] - mock_client.chat.completions.create.return_value = mock_completion - - result = gpt_get_cluster_name("info", settings) - - assert len(result) == 32 # 30 chars + ".." - - def test_gpt_none_content(self) -> None: - from embedding_cluster.scatter_plot import gpt_get_cluster_name - - settings = Settings() - - with patch("embedding_cluster.scatter_plot.OpenAI") as mock_openai_cls: - mock_client = MagicMock() - mock_openai_cls.return_value = mock_client - mock_completion = MagicMock() - mock_choice = MagicMock() - mock_choice.message.content = None - mock_completion.choices = [mock_choice] - mock_client.chat.completions.create.return_value = mock_completion - - result = gpt_get_cluster_name("info", settings) - - assert result == "" - class TestLoadChromadbCollection: def test_load(self) -> None: @@ -690,8 +612,8 @@ def test_points_aligned_with_cluster_labels(self) -> None: endpoint to return items from the wrong cluster. """ result = self._run_compute(self._make_settings(), n_points=6) - points = result["points"] - labels = result["cluster_labels"] + points = cast("list[dict[str, object]]", result["points"]) + labels = cast("list[int]", result["cluster_labels"]) assert len(points) == len(labels) for i, (point, label) in enumerate(zip(points, labels, strict=True)): assert point["cluster"] == label, ( @@ -706,7 +628,7 @@ def test_points_preserve_original_id_order(self) -> None: index-based lookup in cluster-detail and sub-cluster endpoints. """ result = self._run_compute(self._make_settings(), n_points=6) - points = result["points"] + points = cast("list[dict[str, object]]", result["points"]) expected_ids = [str(i) for i in range(6)] actual_ids = [p["id"] for p in points] assert actual_ids == expected_ids diff --git a/tests/test_server_plot.py b/tests/test_server_plot.py index 4aa1d76..7f0bb25 100644 --- a/tests/test_server_plot.py +++ b/tests/test_server_plot.py @@ -241,9 +241,6 @@ async def test_compute_with_all_fields(app: FastAPI, mock_compute: None) -> None "num_clusters": 5, "text_display_fields": ["name", "description"], "image_field": "imageUrl", - "gpt_generate_cluster_name": True, - "gpt_default_model": "gpt-4", - "gpt_default_temperature": 0.7, }, ) From 283166730d1e4001442e8382c7a24e2fa1f43f03 Mon Sep 17 00:00:00 2001 From: aGallea Date: Mon, 30 Mar 2026 17:56:37 +0300 Subject: [PATCH 3/6] feat(ai): add AI cluster naming endpoints and tests --- embedding_cluster/server/app.py | 2 + embedding_cluster/server/models.py | 35 +++ embedding_cluster/server/routes/ai.py | 160 ++++++++++++ tests/test_server_ai.py | 351 ++++++++++++++++++++++++++ tests/test_settings.py | 11 - 5 files changed, 548 insertions(+), 11 deletions(-) create mode 100644 embedding_cluster/server/routes/ai.py create mode 100644 tests/test_server_ai.py diff --git a/embedding_cluster/server/app.py b/embedding_cluster/server/app.py index 7a352da..3784879 100644 --- a/embedding_cluster/server/app.py +++ b/embedding_cluster/server/app.py @@ -8,6 +8,7 @@ from fastapi.responses import FileResponse from fastapi.staticfiles import StaticFiles +from embedding_cluster.server.routes.ai import router as ai_router from embedding_cluster.server.routes.annotations import ( router as annotations_router, ) @@ -43,6 +44,7 @@ def create_app() -> FastAPI: async def health_check() -> dict[str, str]: return {"status": "ok"} + app.include_router(ai_router) app.include_router(collections_router) app.include_router(csv_router) app.include_router(index_router) diff --git a/embedding_cluster/server/models.py b/embedding_cluster/server/models.py index 2a036df..56dcd22 100644 --- a/embedding_cluster/server/models.py +++ b/embedding_cluster/server/models.py @@ -246,3 +246,38 @@ class ClusterAnnotation(BaseModel): class AnnotationsResponse(BaseModel): job_id: str clusters: dict[str, ClusterAnnotation] + + +class AiNamingRequest(BaseModel): + job_id: str + cluster_indices: list[int] + api_key: str + model: str + base_url: str | None = None + temperature: float = 0.5 + + +class AiNamingResponse(BaseModel): + names: dict[str, str] + + +class AiSubClusterNamingRequest(BaseModel): + job_id: str + point_ids: list[str] + sub_cluster_labels: list[int] + api_key: str + model: str + base_url: str | None = None + temperature: float = 0.5 + parent_cluster_name: str | None = None + + +class AiTestConnectionRequest(BaseModel): + api_key: str + model: str + base_url: str | None = None + + +class AiTestConnectionResponse(BaseModel): + success: bool + error: str | None = None diff --git a/embedding_cluster/server/routes/ai.py b/embedding_cluster/server/routes/ai.py new file mode 100644 index 0000000..bd029fa --- /dev/null +++ b/embedding_cluster/server/routes/ai.py @@ -0,0 +1,160 @@ +from __future__ import annotations + +import logging +import random +from typing import Any, cast + +from fastapi import APIRouter, HTTPException + +from embedding_cluster.ai_naming import ( + get_cluster_name, + get_sub_cluster_name, +) +from embedding_cluster.ai_naming import ( + test_connection as ai_test_connection, +) +from embedding_cluster.server.models import ( + AiNamingRequest, + AiNamingResponse, + AiSubClusterNamingRequest, + AiTestConnectionRequest, + AiTestConnectionResponse, +) +from embedding_cluster.server.tasks import TaskStatus, task_registry + +logger = logging.getLogger(__name__) + +router = APIRouter(prefix="/api/ai", tags=["ai"]) + +MAX_SAMPLE_ITEMS = 10 + + +def _get_completed_job(job_id: str) -> dict[str, Any]: + task = task_registry.get(job_id) + if task is None: + raise HTTPException(status_code=404, detail="Job not found") + if task.status != TaskStatus.COMPLETED: + raise HTTPException(status_code=409, detail="Job not completed") + return cast("dict[str, Any]", task.result) + + +def _get_item_names_for_cluster( + result: dict[str, Any], + cluster_index: int, +) -> list[str]: + points = cast("list[dict[str, Any]]", result["points"]) + cluster_labels = cast("list[int]", result["cluster_labels"]) + + cluster_point_indices = [ + i for i, label in enumerate(cluster_labels) if label == cluster_index + ] + + sample_indices = random.sample( + cluster_point_indices, + min(MAX_SAMPLE_ITEMS, len(cluster_point_indices)), + ) + + names: list[str] = [] + for idx in sample_indices: + point = points[idx] + metadata = cast("dict[str, Any]", point.get("metadata", {})) + name_parts = [str(v) for v in metadata.values()] + names.append( + ", ".join(name_parts) if name_parts else f"Item {idx}", + ) + + return names + + +def _get_item_names_for_sub_cluster( + result: dict[str, Any], + point_ids: list[str], + sub_cluster_labels: list[int], + sub_cluster_index: int, +) -> list[str]: + points = cast("list[dict[str, Any]]", result["points"]) + point_id_to_point = {cast("str", p["id"]): p for p in points} + + sub_indices = [ + i for i, label in enumerate(sub_cluster_labels) if label == sub_cluster_index + ] + + sample_indices = random.sample( + sub_indices, + min(MAX_SAMPLE_ITEMS, len(sub_indices)), + ) + + names: list[str] = [] + for idx in sample_indices: + pid = point_ids[idx] + point = point_id_to_point.get(pid) + if point: + metadata = cast("dict[str, Any]", point.get("metadata", {})) + name_parts = [str(v) for v in metadata.values()] + names.append( + ", ".join(name_parts) if name_parts else f"Item {pid}", + ) + else: + names.append(f"Item {pid}") + + return names + + +@router.post("/name-clusters", response_model=AiNamingResponse) +async def name_clusters(request: AiNamingRequest) -> AiNamingResponse: + result = _get_completed_job(request.job_id) + + names: dict[str, str] = {} + for cluster_index in request.cluster_indices: + item_names = _get_item_names_for_cluster(result, cluster_index) + name = get_cluster_name( + item_names=item_names, + api_key=request.api_key, + model=request.model, + base_url=request.base_url, + temperature=request.temperature, + ) + names[str(cluster_index)] = name + + return AiNamingResponse(names=names) + + +@router.post("/name-sub-clusters", response_model=AiNamingResponse) +async def name_sub_clusters( + request: AiSubClusterNamingRequest, +) -> AiNamingResponse: + result = _get_completed_job(request.job_id) + + unique_labels = sorted(set(request.sub_cluster_labels)) + names: dict[str, str] = {} + + for label in unique_labels: + item_names = _get_item_names_for_sub_cluster( + result, + request.point_ids, + request.sub_cluster_labels, + label, + ) + name = get_sub_cluster_name( + item_names=item_names, + api_key=request.api_key, + model=request.model, + base_url=request.base_url, + temperature=request.temperature, + parent_cluster_name=request.parent_cluster_name, + ) + names[str(label)] = name + + return AiNamingResponse(names=names) + + +@router.post("/test-connection", response_model=AiTestConnectionResponse) +async def test_connection( + request: AiTestConnectionRequest, +) -> AiTestConnectionResponse: + success, error = ai_test_connection( + api_key=request.api_key, + model=request.model, + base_url=request.base_url, + ) + return AiTestConnectionResponse(success=success, error=error) diff --git a/tests/test_server_ai.py b/tests/test_server_ai.py new file mode 100644 index 0000000..1c8ca1f --- /dev/null +++ b/tests/test_server_ai.py @@ -0,0 +1,351 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING, cast +from unittest.mock import MagicMock, patch + +import pytest +from fastapi import status +from httpx import ASGITransport, AsyncClient + +if TYPE_CHECKING: + from fastapi import FastAPI + +from embedding_cluster.server.app import create_app +from embedding_cluster.server.tasks import TaskStatus, task_registry + + +@pytest.fixture +def app() -> FastAPI: + return create_app() + + +@pytest.fixture +def completed_job() -> str: + """Create a completed job in the task registry and return its ID.""" + task = task_registry.create() + task.status = TaskStatus.COMPLETED + task.result = { + "points": [ + { + "id": "p1", + "x": 1.0, + "y": 2.0, + "z": 3.0, + "cluster": 0, + "metadata": {"name": "Running Shoes"}, + }, + { + "id": "p2", + "x": 4.0, + "y": 5.0, + "z": 6.0, + "cluster": 0, + "metadata": {"name": "Basketball Sneakers"}, + }, + { + "id": "p3", + "x": 7.0, + "y": 8.0, + "z": 9.0, + "cluster": 1, + "metadata": {"name": "Summer Dress"}, + }, + ], + "cluster_labels": [0, 0, 1], + "clusters": [ + {"index": 0, "name": "Group 1", "color": "#ff0000", "count": 2}, + {"index": 1, "name": "Group 2", "color": "#00ff00", "count": 1}, + ], + "total_points": 3, + } + return task.job_id + + +def _mock_llm_response(content: str = "Athletic Footwear") -> MagicMock: + mock_response = MagicMock() + mock_choice = MagicMock() + mock_choice.message.content = content + mock_response.choices = [mock_choice] + return mock_response + + +class TestNameClusters: + @pytest.mark.asyncio + async def test_names_clusters_successfully( + self, app: FastAPI, completed_job: str + ) -> None: + with patch( + "embedding_cluster.ai_naming.litellm_completion", + return_value=_mock_llm_response("Athletic Footwear"), + ): + async with AsyncClient( + transport=ASGITransport(app=app), base_url="http://test" + ) as client: + response = await client.post( + "/api/ai/name-clusters", + json={ + "job_id": completed_job, + "cluster_indices": [0], + "api_key": "test-key", + "model": "gpt-4o-mini", + }, + ) + + assert response.status_code == status.HTTP_200_OK + data = cast("dict[str, object]", response.json()) + names = cast("dict[str, str]", data["names"]) + assert "0" in names + assert names["0"] == "Athletic Footwear" + + @pytest.mark.asyncio + async def test_names_multiple_clusters( + self, app: FastAPI, completed_job: str + ) -> None: + call_count = 0 + + def side_effect(**kwargs: object) -> MagicMock: + nonlocal call_count + call_count += 1 + if call_count == 1: + return _mock_llm_response("Athletic Footwear") + return _mock_llm_response("Fashion Dresses") + + with patch( + "embedding_cluster.ai_naming.litellm_completion", + side_effect=side_effect, + ): + async with AsyncClient( + transport=ASGITransport(app=app), base_url="http://test" + ) as client: + response = await client.post( + "/api/ai/name-clusters", + json={ + "job_id": completed_job, + "cluster_indices": [0, 1], + "api_key": "test-key", + "model": "gpt-4o-mini", + }, + ) + + assert response.status_code == status.HTTP_200_OK + data = cast("dict[str, object]", response.json()) + names = cast("dict[str, str]", data["names"]) + assert "0" in names + assert "1" in names + + @pytest.mark.asyncio + async def test_job_not_found(self, app: FastAPI) -> None: + async with AsyncClient( + transport=ASGITransport(app=app), base_url="http://test" + ) as client: + response = await client.post( + "/api/ai/name-clusters", + json={ + "job_id": "nonexistent-id", + "cluster_indices": [0], + "api_key": "test-key", + "model": "gpt-4o-mini", + }, + ) + + assert response.status_code == status.HTTP_404_NOT_FOUND + + @pytest.mark.asyncio + async def test_job_not_completed(self, app: FastAPI) -> None: + task = task_registry.create() + task.status = TaskStatus.RUNNING + + async with AsyncClient( + transport=ASGITransport(app=app), base_url="http://test" + ) as client: + response = await client.post( + "/api/ai/name-clusters", + json={ + "job_id": task.job_id, + "cluster_indices": [0], + "api_key": "test-key", + "model": "gpt-4o-mini", + }, + ) + + assert response.status_code == status.HTTP_409_CONFLICT + + @pytest.mark.asyncio + async def test_passes_optional_params(self, app: FastAPI, completed_job: str) -> None: + with patch( + "embedding_cluster.ai_naming.litellm_completion", + return_value=_mock_llm_response("Name"), + ) as mock_completion: + async with AsyncClient( + transport=ASGITransport(app=app), base_url="http://test" + ) as client: + await client.post( + "/api/ai/name-clusters", + json={ + "job_id": completed_job, + "cluster_indices": [0], + "api_key": "test-key", + "model": "gpt-4o-mini", + "base_url": "http://localhost:11434", + "temperature": 0.8, + }, + ) + + call_kwargs = mock_completion.call_args[1] + assert call_kwargs["api_base"] == "http://localhost:11434" + assert call_kwargs["temperature"] == 0.8 + + +class TestNameSubClusters: + @pytest.mark.asyncio + async def test_names_sub_clusters_successfully( + self, app: FastAPI, completed_job: str + ) -> None: + with patch( + "embedding_cluster.ai_naming.litellm_completion", + return_value=_mock_llm_response("Running Shoes"), + ): + async with AsyncClient( + transport=ASGITransport(app=app), base_url="http://test" + ) as client: + response = await client.post( + "/api/ai/name-sub-clusters", + json={ + "job_id": completed_job, + "point_ids": ["p1", "p2"], + "sub_cluster_labels": [0, 1], + "api_key": "test-key", + "model": "gpt-4o-mini", + }, + ) + + assert response.status_code == status.HTTP_200_OK + data = cast("dict[str, object]", response.json()) + names = cast("dict[str, str]", data["names"]) + assert "0" in names + assert "1" in names + + @pytest.mark.asyncio + async def test_with_parent_cluster_name( + self, app: FastAPI, completed_job: str + ) -> None: + with patch( + "embedding_cluster.ai_naming.litellm_completion", + return_value=_mock_llm_response("Running Shoes"), + ) as mock_completion: + async with AsyncClient( + transport=ASGITransport(app=app), base_url="http://test" + ) as client: + await client.post( + "/api/ai/name-sub-clusters", + json={ + "job_id": completed_job, + "point_ids": ["p1"], + "sub_cluster_labels": [0], + "api_key": "test-key", + "model": "gpt-4o-mini", + "parent_cluster_name": "Athletic Footwear", + }, + ) + + call_kwargs = mock_completion.call_args[1] + system_msg = call_kwargs["messages"][0]["content"] + assert "Athletic Footwear" in system_msg + + @pytest.mark.asyncio + async def test_job_not_found(self, app: FastAPI) -> None: + async with AsyncClient( + transport=ASGITransport(app=app), base_url="http://test" + ) as client: + response = await client.post( + "/api/ai/name-sub-clusters", + json={ + "job_id": "nonexistent-id", + "point_ids": ["p1"], + "sub_cluster_labels": [0], + "api_key": "test-key", + "model": "gpt-4o-mini", + }, + ) + + assert response.status_code == status.HTTP_404_NOT_FOUND + + +class TestTestConnection: + @pytest.mark.asyncio + async def test_successful_connection(self, app: FastAPI) -> None: + with patch( + "embedding_cluster.ai_naming.litellm_completion", + return_value=_mock_llm_response("Hello"), + ): + async with AsyncClient( + transport=ASGITransport(app=app), base_url="http://test" + ) as client: + response = await client.post( + "/api/ai/test-connection", + json={ + "api_key": "test-key", + "model": "gpt-4o-mini", + }, + ) + + assert response.status_code == status.HTTP_200_OK + data = cast("dict[str, object]", response.json()) + assert data["success"] is True + assert data["error"] is None + + @pytest.mark.asyncio + async def test_failed_connection(self, app: FastAPI) -> None: + with patch( + "embedding_cluster.ai_naming.litellm_completion", + side_effect=Exception("Connection refused"), + ): + async with AsyncClient( + transport=ASGITransport(app=app), base_url="http://test" + ) as client: + response = await client.post( + "/api/ai/test-connection", + json={ + "api_key": "test-key", + "model": "gpt-4o-mini", + }, + ) + + assert response.status_code == status.HTTP_200_OK + data = cast("dict[str, object]", response.json()) + assert data["success"] is False + assert data["error"] is not None + assert "Connection refused" in cast("str", data["error"]) + + @pytest.mark.asyncio + async def test_with_base_url(self, app: FastAPI) -> None: + with patch( + "embedding_cluster.ai_naming.litellm_completion", + return_value=_mock_llm_response("Hello"), + ) as mock_completion: + async with AsyncClient( + transport=ASGITransport(app=app), base_url="http://test" + ) as client: + await client.post( + "/api/ai/test-connection", + json={ + "api_key": "test-key", + "model": "gpt-4o-mini", + "base_url": "http://localhost:11434", + }, + ) + + call_kwargs = mock_completion.call_args[1] + assert call_kwargs["api_base"] == "http://localhost:11434" + + @pytest.mark.asyncio + async def test_missing_required_fields(self, app: FastAPI) -> None: + async with AsyncClient( + transport=ASGITransport(app=app), base_url="http://test" + ) as client: + response = await client.post( + "/api/ai/test-connection", + json={}, + ) + + assert response.status_code == status.HTTP_422_UNPROCESSABLE_ENTITY diff --git a/tests/test_settings.py b/tests/test_settings.py index 8d7dfed..6befd5f 100644 --- a/tests/test_settings.py +++ b/tests/test_settings.py @@ -36,12 +36,6 @@ def test_default_none_fields(self) -> None: assert s.image_field is None assert s.id_field is None - def test_default_gpt_settings(self) -> None: - s = Settings() - assert s.gpt_generate_cluster_name is False - assert s.gpt_default_model == "gpt-3.5-turbo" - assert s.gpt_default_temperature == pytest.approx(0.51) - def test_default_model_names(self) -> None: s = Settings() assert s.image_model_name == "openai/clip-vit-base-patch32" @@ -86,11 +80,6 @@ def test_text_display_fields_from_env(self, monkeypatch: pytest.MonkeyPatch) -> s = Settings() assert s.text_display_fields == ["productDisplayName"] - def test_boolean_field_from_env(self, monkeypatch: pytest.MonkeyPatch) -> None: - monkeypatch.setenv("GPT_GENERATE_CLUSTER_NAME", "true") - s = Settings() - assert s.gpt_generate_cluster_name is True - def test_start_end_lines_from_env(self, monkeypatch: pytest.MonkeyPatch) -> None: monkeypatch.setenv("INDEX_START_LINE", "5") monkeypatch.setenv("INDEX_END_LINE", "100") From b4987c03755eb56bd584f5a4ed2b748510b01b1e Mon Sep 17 00:00:00 2001 From: aGallea Date: Mon, 30 Mar 2026 18:08:53 +0300 Subject: [PATCH 4/6] feat(frontend): add Settings page and remove GPT cluster naming UI --- frontend/src/App.tsx | 5 + frontend/src/api/ai.ts | 52 +++++ frontend/src/components/plot/PlotControls.tsx | 52 ----- frontend/src/pages/SettingsPage.tsx | 200 ++++++++++++++++++ frontend/src/types/index.ts | 48 ++++- 5 files changed, 302 insertions(+), 55 deletions(-) create mode 100644 frontend/src/api/ai.ts create mode 100644 frontend/src/pages/SettingsPage.tsx diff --git a/frontend/src/App.tsx b/frontend/src/App.tsx index c97f470..d6c95a4 100644 --- a/frontend/src/App.tsx +++ b/frontend/src/App.tsx @@ -4,6 +4,7 @@ import { useEffect, useRef } from 'react' import HomePage from './pages/HomePage' import IndexPage from './pages/IndexPage' import PlotPage from './pages/PlotPage' +import SettingsPage from './pages/SettingsPage' import { usePlotStore } from './stores/plotStore' const queryClient = new QueryClient() @@ -35,6 +36,9 @@ function NavBar() { Plot + + Settings + @@ -86,6 +90,7 @@ export default function App() { } /> } /> } /> + } /> diff --git a/frontend/src/api/ai.ts b/frontend/src/api/ai.ts new file mode 100644 index 0000000..d0bab61 --- /dev/null +++ b/frontend/src/api/ai.ts @@ -0,0 +1,52 @@ +import type { AiNamingRequest, AiNamingResponse, AiTestConnectionRequest, AiTestConnectionResponse } from "../types"; +import { apiPost } from "./client"; + +const AI_SETTINGS_KEY = "ai-cluster-naming-settings"; + +export interface StoredAiSettings { + provider: string; + model: string; + apiKey: string; + baseUrl: string; + temperature: number; +} + +export const DEFAULT_AI_SETTINGS: StoredAiSettings = { + provider: "openai", + model: "gpt-4o-mini", + apiKey: "", + baseUrl: "", + temperature: 0.5, +}; + +export function loadAiSettings(): StoredAiSettings { + try { + const raw = localStorage.getItem(AI_SETTINGS_KEY); + if (!raw) return { ...DEFAULT_AI_SETTINGS }; + return { ...DEFAULT_AI_SETTINGS, ...JSON.parse(raw) } as StoredAiSettings; + } catch { + return { ...DEFAULT_AI_SETTINGS }; + } +} + +export function saveAiSettings(settings: StoredAiSettings): void { + localStorage.setItem(AI_SETTINGS_KEY, JSON.stringify(settings)); +} + +export async function testAiConnection( + request: AiTestConnectionRequest, +): Promise { + return apiPost("/ai/test-connection", request); +} + +export async function nameAiClusters( + request: AiNamingRequest, +): Promise { + return apiPost("/ai/name-clusters", request); +} + +export async function nameAiSubClusters( + request: AiNamingRequest, +): Promise { + return apiPost("/ai/name-sub-clusters", request); +} diff --git a/frontend/src/components/plot/PlotControls.tsx b/frontend/src/components/plot/PlotControls.tsx index c02d164..92a0e69 100644 --- a/frontend/src/components/plot/PlotControls.tsx +++ b/frontend/src/components/plot/PlotControls.tsx @@ -36,10 +36,6 @@ export default function PlotControls({ onCompute, isComputing }: PlotControlsPro const [numClusters, setNumClusters] = useState(10) const [textDisplayFields, setTextDisplayFields] = useState([]) const [imageField, setImageField] = useState('') - const [gptEnabled, setGptEnabled] = useState(false) - const [gptModel, setGptModel] = useState('gpt-3.5-turbo') - const [gptTemperature, setGptTemperature] = useState(0.51) - const { renderMode, setRenderMode, pointSize, setPointSize, reductionAlgorithm, setReductionAlgorithm, @@ -100,9 +96,6 @@ export default function PlotControls({ onCompute, isComputing }: PlotControlsPro num_clusters: numClusters, text_display_fields: textDisplayFields, image_field: imageField || undefined, - gpt_generate_cluster_name: gptEnabled, - gpt_default_model: gptEnabled ? gptModel : undefined, - gpt_default_temperature: gptEnabled ? gptTemperature : undefined, reduction_algorithm: reductionAlgorithm, ...(reductionAlgorithm === 'tsne' && { tsne_perplexity: tsnePerplexity, @@ -320,51 +313,6 @@ export default function PlotControls({ onCompute, isComputing }: PlotControlsPro - {/* GPT Settings */} - -
-
- setGptEnabled(e.target.checked)} - className="mr-2" - /> - -
- - {gptEnabled && ( -
-
- - setGptModel(e.target.value)} - className="w-full border border-gray-300 rounded px-2 py-1 text-sm" - /> -
-
- - setGptTemperature(Number(e.target.value))} - className="w-full border border-gray-300 rounded px-2 py-1 text-sm" - /> -
-
- )} -
-
- {/* Rendering (Render Mode + Point Size) */}
diff --git a/frontend/src/pages/SettingsPage.tsx b/frontend/src/pages/SettingsPage.tsx new file mode 100644 index 0000000..7449fd9 --- /dev/null +++ b/frontend/src/pages/SettingsPage.tsx @@ -0,0 +1,200 @@ +import { useState, useEffect } from 'react'; +import { + loadAiSettings, + saveAiSettings, + testAiConnection, + StoredAiSettings, + DEFAULT_AI_SETTINGS, +} from '../api/ai'; + +export default function SettingsPage() { + const [settings, setSettings] = useState(DEFAULT_AI_SETTINGS); + const [isSaved, setIsSaved] = useState(false); + const [testStatus, setTestStatus] = useState<'idle' | 'testing' | 'success' | 'error'>('idle'); + const [testMessage, setTestMessage] = useState(''); + const [showApiKey, setShowApiKey] = useState(false); + + useEffect(() => { + setSettings(loadAiSettings()); + }, []); + + const handleChange = (field: keyof StoredAiSettings, value: string | number) => { + setSettings((prev) => ({ ...prev, [field]: value })); + setIsSaved(false); + setTestStatus('idle'); + }; + + const handleSave = () => { + saveAiSettings(settings); + setIsSaved(true); + setTimeout(() => setIsSaved(false), 3000); + }; + + const handleTestConnection = async () => { + setTestStatus('testing'); + setTestMessage(''); + try { + const result = await testAiConnection({ + api_key: settings.apiKey, + model: settings.model, + base_url: settings.baseUrl || undefined, + }); + + if (result.success) { + setTestStatus('success'); + setTestMessage('Connection successful!'); + } else { + setTestStatus('error'); + setTestMessage(result.error || 'Connection failed.'); + } + } catch (err: unknown) { + setTestStatus('error'); + const msg = err instanceof Error ? err.message : 'Unknown error occurred.'; + setTestMessage(`Failed to test connection: ${msg}`); + } + }; + + return ( +
+
+

AI Settings

+

+ Configure AI provider for cluster naming +

+
+ +
+
+
+ + handleChange('provider', e.target.value)} + className="w-full border-gray-300 rounded-md shadow-sm focus:border-blue-500 focus:ring-blue-500 sm:text-sm p-2 border" + placeholder="e.g. openai" + /> +
+ +
+ + handleChange('model', e.target.value)} + className="w-full border-gray-300 rounded-md shadow-sm focus:border-blue-500 focus:ring-blue-500 sm:text-sm p-2 border" + placeholder="e.g. gpt-4o-mini" + /> +
+ +
+ +
+ handleChange('apiKey', e.target.value)} + className="w-full border-gray-300 rounded-md shadow-sm focus:border-blue-500 focus:ring-blue-500 sm:text-sm p-2 border pr-10" + placeholder="sk-..." + /> + +
+
+ +
+ + handleChange('baseUrl', e.target.value)} + className="w-full border-gray-300 rounded-md shadow-sm focus:border-blue-500 focus:ring-blue-500 sm:text-sm p-2 border" + placeholder="e.g. https://api.openai.com/v1" + /> +
+ +
+ + handleChange('temperature', parseFloat(e.target.value))} + className="w-full border-gray-300 rounded-md shadow-sm focus:border-blue-500 focus:ring-blue-500 sm:text-sm p-2 border" + /> +
+ +
+
+ + + {testStatus === 'success' && ( + + + + + {testMessage} + + )} + + {testStatus === 'error' && ( + + {testMessage} + + )} +
+ +
+ {isSaved && ( + + Settings saved! + + )} + +
+
+
+
+
+ ); +} diff --git a/frontend/src/types/index.ts b/frontend/src/types/index.ts index 75eec4c..a05c3fe 100644 --- a/frontend/src/types/index.ts +++ b/frontend/src/types/index.ts @@ -65,9 +65,6 @@ export interface PlotRequest { num_clusters?: number text_display_fields?: string[] image_field?: string - gpt_generate_cluster_name?: boolean - gpt_default_model?: string - gpt_default_temperature?: number reduction_algorithm?: ReductionAlgorithm tsne_perplexity?: number tsne_learning_rate?: string @@ -184,6 +181,7 @@ export interface SubClusterInfo { index: number; count: number; color: string; + name?: string; } export interface SubClusterResponse { @@ -235,3 +233,47 @@ export interface AnnotationsResponse { job_id: string; clusters: Record; } + +// AI Naming +export interface AiSettings { + provider: string; + model: string; + apiKey: string; + baseUrl: string; + temperature: number; +} + +export interface AiNamingRequest { + job_id: string; + cluster_indices: number[]; + api_key: string; + model: string; + base_url?: string; + temperature?: number; +} + +export interface AiNamingResponse { + names: Record; +} + +export interface AiSubClusterNamingRequest { + job_id: string; + point_ids: string[]; + sub_cluster_labels: number[]; + api_key: string; + model: string; + base_url?: string; + temperature?: number; + parent_cluster_name?: string; +} + +export interface AiTestConnectionRequest { + api_key: string; + model: string; + base_url?: string; +} + +export interface AiTestConnectionResponse { + success: boolean; + error: string | null; +} From 8d4a609c1ed787089be85fb901402722c4fe7daf Mon Sep 17 00:00:00 2001 From: aGallea Date: Mon, 30 Mar 2026 18:23:39 +0300 Subject: [PATCH 5/6] feat(frontend): add AI cluster naming integration to legend and drawer --- frontend/src/api/ai.ts | 10 ++- .../components/plot/ClusterDetailDrawer.tsx | 54 ++++++++++++- .../src/components/plot/ClusterLegend.tsx | 81 +++++++++++++++++-- frontend/src/stores/plotStore.ts | 26 ++++++ 4 files changed, 158 insertions(+), 13 deletions(-) diff --git a/frontend/src/api/ai.ts b/frontend/src/api/ai.ts index d0bab61..59f57a9 100644 --- a/frontend/src/api/ai.ts +++ b/frontend/src/api/ai.ts @@ -1,4 +1,10 @@ -import type { AiNamingRequest, AiNamingResponse, AiTestConnectionRequest, AiTestConnectionResponse } from "../types"; +import type { + AiNamingRequest, + AiNamingResponse, + AiSubClusterNamingRequest, + AiTestConnectionRequest, + AiTestConnectionResponse, +} from "../types"; import { apiPost } from "./client"; const AI_SETTINGS_KEY = "ai-cluster-naming-settings"; @@ -46,7 +52,7 @@ export async function nameAiClusters( } export async function nameAiSubClusters( - request: AiNamingRequest, + request: AiSubClusterNamingRequest, ): Promise { return apiPost("/ai/name-sub-clusters", request); } diff --git a/frontend/src/components/plot/ClusterDetailDrawer.tsx b/frontend/src/components/plot/ClusterDetailDrawer.tsx index ab7c956..7db58b3 100644 --- a/frontend/src/components/plot/ClusterDetailDrawer.tsx +++ b/frontend/src/components/plot/ClusterDetailDrawer.tsx @@ -8,7 +8,8 @@ import { subCluster, subClusterByPointIds, } from '../../api/plot' -import type { ClusterDetailResponse, SuggestKResponse } from '../../types' +import { loadAiSettings, nameAiSubClusters } from '../../api/ai' +import type { ClusterDetailResponse, SubClusterResponse, SuggestKResponse } from '../../types' import SelectedPointsDistancePanel from './SelectedPointsDistancePanel' interface ClusterDetailDrawerProps { @@ -35,6 +36,9 @@ export default function ClusterDetailDrawer({ jobId, imageField }: ClusterDetail const drillIntoSubCluster = usePlotStore((s) => s.drillIntoSubCluster) const setIsLoadingDrill = usePlotStore((s) => s.setIsLoadingDrill) const isLoadingDrill = usePlotStore((s) => s.isLoadingDrill) + const isNamingSubClusters = usePlotStore((s) => s.isNamingSubClusters) + const setIsNamingSubClusters = usePlotStore((s) => s.setIsNamingSubClusters) + const updateSubClusterNames = usePlotStore((s) => s.updateSubClusterNames) const [page, setPage] = useState(1) const [isEditingName, setIsEditingName] = useState(false) @@ -100,6 +104,35 @@ export default function ClusterDetailDrawer({ jobId, imageField }: ClusterDetail setExpandedSubClusterIndex(null) }, [currentLevel?.label]) + const autoNameSubClusters = useCallback(async ( + data: SubClusterResponse, + parentName: string | undefined, + ) => { + const settings = loadAiSettings() + if (!settings.apiKey || !jobId) return + + setIsNamingSubClusters(true) + try { + const pointIds = data.points.map((p) => p.id) + const labels = data.points.map((p) => p.sub_cluster) + const response = await nameAiSubClusters({ + job_id: jobId, + point_ids: pointIds, + sub_cluster_labels: labels, + api_key: settings.apiKey, + model: settings.provider ? `${settings.provider}/${settings.model}` : settings.model, + base_url: settings.baseUrl || undefined, + temperature: settings.temperature, + parent_cluster_name: parentName, + }) + updateSubClusterNames(response.names) + } catch { + // AI naming is best-effort; failures are silent + } finally { + setIsNamingSubClusters(false) + } + }, [jobId, setIsNamingSubClusters, updateSubClusterNames]) + const handlePageChange = useCallback((newPage: number) => { if (clusterIndex == null) return setPage(newPage) @@ -224,12 +257,16 @@ export default function ClusterDetailDrawer({ jobId, imageField }: ClusterDetail }) setSelectedSubClusterIndex(null) drillIntoSubCluster(selectedSubCluster, data) + const parentLabel = currentLevel.label + autoNameSubClusters(data, parentLabel) } else { const data = await subCluster(jobId, clusterIndex, { num_sub_clusters: subClusterK, }) setSelectedSubClusterIndex(null) drillIntoCluster(clusterIndex, data) + const parentName = annotation?.name ?? cluster?.name + autoNameSubClusters(data, parentName) } } catch { setIsLoadingDrill(false) @@ -245,6 +282,9 @@ export default function ClusterDetailDrawer({ jobId, imageField }: ClusterDetail setIsLoadingDrill, drillIntoCluster, drillIntoSubCluster, + autoNameSubClusters, + annotation, + cluster, ]) const handleToggleSubClusterExpanded = useCallback((subClusterIndex: number) => { @@ -390,7 +430,15 @@ export default function ClusterDetailDrawer({ jobId, imageField }: ClusterDetail className="px-4 py-2 border-b border-gray-200 flex-1 min-h-0 overflow-y-auto" data-testid="drawer-subcluster-list" > -
Current sub-clusters
+
+ Current sub-clusters + {isNamingSubClusters && ( +
+
+ Naming... +
+ )} +
{subClusters.map((sc) => { const scColor = CLUSTER_COLORS[sc.index % CLUSTER_COLORS.length] @@ -421,7 +469,7 @@ export default function ClusterDetailDrawer({ jobId, imageField }: ClusterDetail style={{ backgroundColor: scColor }} /> - Sub {sc.index} + {sc.name || `Sub ${sc.index}`} {sc.count} pts diff --git a/frontend/src/components/plot/ClusterLegend.tsx b/frontend/src/components/plot/ClusterLegend.tsx index 9f82a50..f836cc5 100644 --- a/frontend/src/components/plot/ClusterLegend.tsx +++ b/frontend/src/components/plot/ClusterLegend.tsx @@ -1,5 +1,7 @@ -import { useCallback } from 'react' +import { useCallback, useState } from 'react' import { usePlotStore, CLUSTER_COLORS } from '../../stores/plotStore' +import { loadAiSettings, nameAiClusters } from '../../api/ai' +import { updateAnnotation, getAnnotations } from '../../api/plot' export default function ClusterLegend() { const plotData = usePlotStore((state) => state.plotData) @@ -14,8 +16,14 @@ export default function ClusterLegend() { const setSelectedCluster = usePlotStore((state) => state.setSelectedCluster) const selectedCluster = usePlotStore((state) => state.selectedCluster) const annotations = usePlotStore((state) => state.annotations) + const setAnnotations = usePlotStore((state) => state.setAnnotations) const drillPath = usePlotStore((state) => state.drillPath) const isLoadingDrill = usePlotStore((state) => state.isLoadingDrill) + const isNamingClusters = usePlotStore((state) => state.isNamingClusters) + const setIsNamingClusters = usePlotStore((state) => state.setIsNamingClusters) + const plotJobId = usePlotStore((state) => state.plotJobId) + + const [namingError, setNamingError] = useState(null) const handleClickCluster = useCallback( (clusterIndex: number, isSelected: boolean) => { @@ -24,6 +32,42 @@ export default function ClusterLegend() { [setSelectedCluster], ) + const handleNameWithAi = useCallback(async () => { + if (!plotData || !plotJobId || isNamingClusters) return + + const settings = loadAiSettings() + if (!settings.apiKey) { + setNamingError('Configure AI settings first (Settings page)') + return + } + + setIsNamingClusters(true) + setNamingError(null) + + try { + const clusterIndices = plotData.clusters.map((c) => c.index) + const response = await nameAiClusters({ + job_id: plotJobId, + cluster_indices: clusterIndices, + api_key: settings.apiKey, + model: settings.provider ? `${settings.provider}/${settings.model}` : settings.model, + base_url: settings.baseUrl || undefined, + temperature: settings.temperature, + }) + + for (const [indexStr, name] of Object.entries(response.names)) { + await updateAnnotation(plotJobId, Number(indexStr), { name }) + } + + const updated = await getAnnotations(plotJobId) + setAnnotations(updated) + } catch (err) { + setNamingError(err instanceof Error ? err.message : 'AI naming failed') + } finally { + setIsNamingClusters(false) + } + }, [plotData, plotJobId, isNamingClusters, setIsNamingClusters, setAnnotations]) + if (!plotData) return null const handleShowAll = () => resetVisibleClusters(plotData.clusters.length) @@ -56,12 +100,33 @@ export default function ClusterLegend() { Show All ) : ( - + <> + + + {namingError && ( + {namingError} + )} + )} {isLoadingDrill && (
@@ -118,7 +183,7 @@ export default function ClusterLegend() { />
- Sub {sc.index} + {sc.name || `Sub ${sc.index}`}
{sc.count} points diff --git a/frontend/src/stores/plotStore.ts b/frontend/src/stores/plotStore.ts index c0f144e..2b19d74 100644 --- a/frontend/src/stores/plotStore.ts +++ b/frontend/src/stores/plotStore.ts @@ -27,6 +27,8 @@ interface PlotState { drillPath: DrillLevel[] subClusterColorMap: Map | null isLoadingDrill: boolean + isNamingClusters: boolean + isNamingSubClusters: boolean imageField: string | null plotJobId: string | null plotCollectionName: string | null @@ -61,6 +63,9 @@ interface PlotState { navigateBack: () => void resetDrill: () => void setIsLoadingDrill: (loading: boolean) => void + setIsNamingClusters: (loading: boolean) => void + setIsNamingSubClusters: (loading: boolean) => void + updateSubClusterNames: (names: Record) => void isolateCluster: (index: number) => void toggleSubCluster: (index: number) => void isolateSubCluster: (index: number) => void @@ -121,6 +126,8 @@ export const usePlotStore = create((set) => ({ drillPath: [], subClusterColorMap: null, isLoadingDrill: false, + isNamingClusters: false, + isNamingSubClusters: false, imageField: null, plotJobId: null, plotCollectionName: null, @@ -247,6 +254,25 @@ export const usePlotStore = create((set) => ({ setIsLoadingDrill: (loading) => set({ isLoadingDrill: loading }), + setIsNamingClusters: (loading) => set({ isNamingClusters: loading }), + setIsNamingSubClusters: (loading) => set({ isNamingSubClusters: loading }), + + updateSubClusterNames: (names) => + set((state) => { + if (state.drillPath.length === 0) return {} + const newPath = [...state.drillPath] + const currentLevel = { ...newPath[newPath.length - 1] } + currentLevel.subClusterData = { + ...currentLevel.subClusterData, + sub_clusters: currentLevel.subClusterData.sub_clusters.map((sc) => ({ + ...sc, + name: names[String(sc.index)] ?? sc.name, + })), + } + newPath[newPath.length - 1] = currentLevel + return { drillPath: newPath } + }), + isolateCluster: (index) => set(() => ({ visibleClusters: new Set([index]) })), From 97543c7c644381517ec7c2c55f1a85a398941672 Mon Sep 17 00:00:00 2001 From: aGallea Date: Tue, 31 Mar 2026 10:52:55 +0300 Subject: [PATCH 6/6] fix(frontend): use type-only import for StoredAiSettings interface --- frontend/src/pages/SettingsPage.tsx | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/frontend/src/pages/SettingsPage.tsx b/frontend/src/pages/SettingsPage.tsx index 7449fd9..1de00cf 100644 --- a/frontend/src/pages/SettingsPage.tsx +++ b/frontend/src/pages/SettingsPage.tsx @@ -1,9 +1,9 @@ import { useState, useEffect } from 'react'; +import type { StoredAiSettings } from '../api/ai'; import { loadAiSettings, saveAiSettings, testAiConnection, - StoredAiSettings, DEFAULT_AI_SETTINGS, } from '../api/ai';