From 19b6a337fcec6ee4cfae66f7f786ba12570909b7 Mon Sep 17 00:00:00 2001 From: nictru Date: Thu, 22 Jan 2026 14:26:55 +0000 Subject: [PATCH 01/15] Install optuna --- poetry.lock | 232 ++++++++++++++++++++++++++++++++++++++++++++++++- pyproject.toml | 1 + 2 files changed, 231 insertions(+), 2 deletions(-) diff --git a/poetry.lock b/poetry.lock index 6e8e7a5e..ec9d6845 100644 --- a/poetry.lock +++ b/poetry.lock @@ -182,6 +182,26 @@ files = [ {file = "alabaster-1.0.0.tar.gz", hash = "sha256:c00dca57bca26fa62a6d7d0a9fcce65f3e026e9bfe33e9c538fd3fbb2144fd9e"}, ] +[[package]] +name = "alembic" +version = "1.18.1" +description = "A database migration tool for SQLAlchemy." +optional = false +python-versions = ">=3.10" +groups = ["main"] +files = [ + {file = "alembic-1.18.1-py3-none-any.whl", hash = "sha256:f1c3b0920b87134e851c25f1f7f236d8a332c34b75416802d06971df5d1b7810"}, + {file = "alembic-1.18.1.tar.gz", hash = "sha256:83ac6b81359596816fb3b893099841a0862f2117b2963258e965d70dc62fb866"}, +] + +[package.dependencies] +Mako = "*" +SQLAlchemy = ">=1.4.0" +typing-extensions = ">=4.12" + +[package.extras] +tz = ["tzdata"] + [[package]] name = "annotated-types" version = "0.7.0" @@ -710,7 +730,7 @@ version = "6.10.1" description = "Add colours to the output of Python's logging module." 
optional = false python-versions = ">=3.6" -groups = ["development"] +groups = ["main", "development"] files = [ {file = "colorlog-6.10.1-py3-none-any.whl", hash = "sha256:2d7e8348291948af66122cff006c9f8da6255d224e7cf8e37d8de2df3bad8c9c"}, {file = "colorlog-6.10.1.tar.gz", hash = "sha256:eb4ae5cb65fe7fec7773c2306061a8e63e02efc2c72eba9d27b0fa23c94f1321"}, @@ -1533,6 +1553,69 @@ gitdb = ">=4.0.1,<5" doc = ["sphinx (>=7.1.2,<7.2)", "sphinx-autodoc-typehints", "sphinx_rtd_theme"] test = ["coverage[toml]", "ddt (>=1.1.1,!=1.4.3)", "mock ; python_version < \"3.8\"", "mypy (==1.18.2) ; python_version >= \"3.9\"", "pre-commit", "pytest (>=7.3.1)", "pytest-cov", "pytest-instafail", "pytest-mock", "pytest-sugar", "typing-extensions ; python_version < \"3.11\""] +[[package]] +name = "greenlet" +version = "3.3.0" +description = "Lightweight in-process concurrent programming" +optional = false +python-versions = ">=3.10" +groups = ["main"] +markers = "platform_machine == \"aarch64\" or platform_machine == \"ppc64le\" or platform_machine == \"x86_64\" or platform_machine == \"amd64\" or platform_machine == \"AMD64\" or platform_machine == \"win32\" or platform_machine == \"WIN32\"" +files = [ + {file = "greenlet-3.3.0-cp310-cp310-macosx_11_0_universal2.whl", hash = "sha256:6f8496d434d5cb2dce025773ba5597f71f5410ae499d5dd9533e0653258cdb3d"}, + {file = "greenlet-3.3.0-cp310-cp310-manylinux_2_24_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:b96dc7eef78fd404e022e165ec55327f935b9b52ff355b067eb4a0267fc1cffb"}, + {file = "greenlet-3.3.0-cp310-cp310-manylinux_2_24_ppc64le.manylinux_2_28_ppc64le.whl", hash = "sha256:73631cd5cccbcfe63e3f9492aaa664d278fda0ce5c3d43aeda8e77317e38efbd"}, + {file = "greenlet-3.3.0-cp310-cp310-manylinux_2_24_s390x.manylinux_2_28_s390x.whl", hash = "sha256:b299a0cb979f5d7197442dccc3aee67fce53500cd88951b7e6c35575701c980b"}, + {file = "greenlet-3.3.0-cp310-cp310-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = 
"sha256:7dee147740789a4632cace364816046e43310b59ff8fb79833ab043aefa72fd5"}, + {file = "greenlet-3.3.0-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:39b28e339fc3c348427560494e28d8a6f3561c8d2bcf7d706e1c624ed8d822b9"}, + {file = "greenlet-3.3.0-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:b3c374782c2935cc63b2a27ba8708471de4ad1abaa862ffdb1ef45a643ddbb7d"}, + {file = "greenlet-3.3.0-cp310-cp310-win_amd64.whl", hash = "sha256:b49e7ed51876b459bd645d83db257f0180e345d3f768a35a85437a24d5a49082"}, + {file = "greenlet-3.3.0-cp311-cp311-macosx_11_0_universal2.whl", hash = "sha256:e29f3018580e8412d6aaf5641bb7745d38c85228dacf51a73bd4e26ddf2a6a8e"}, + {file = "greenlet-3.3.0-cp311-cp311-manylinux_2_24_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:a687205fb22794e838f947e2194c0566d3812966b41c78709554aa883183fb62"}, + {file = "greenlet-3.3.0-cp311-cp311-manylinux_2_24_ppc64le.manylinux_2_28_ppc64le.whl", hash = "sha256:4243050a88ba61842186cb9e63c7dfa677ec146160b0efd73b855a3d9c7fcf32"}, + {file = "greenlet-3.3.0-cp311-cp311-manylinux_2_24_s390x.manylinux_2_28_s390x.whl", hash = "sha256:670d0f94cd302d81796e37299bcd04b95d62403883b24225c6b5271466612f45"}, + {file = "greenlet-3.3.0-cp311-cp311-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:6cb3a8ec3db4a3b0eb8a3c25436c2d49e3505821802074969db017b87bc6a948"}, + {file = "greenlet-3.3.0-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:2de5a0b09eab81fc6a382791b995b1ccf2b172a9fec934747a7a23d2ff291794"}, + {file = "greenlet-3.3.0-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:4449a736606bd30f27f8e1ff4678ee193bc47f6ca810d705981cfffd6ce0d8c5"}, + {file = "greenlet-3.3.0-cp311-cp311-win_amd64.whl", hash = "sha256:7652ee180d16d447a683c04e4c5f6441bae7ba7b17ffd9f6b3aff4605e9e6f71"}, + {file = "greenlet-3.3.0-cp312-cp312-macosx_11_0_universal2.whl", hash = "sha256:b01548f6e0b9e9784a2c99c5651e5dc89ffcbe870bc5fb2e5ef864e9cc6b5dcb"}, + {file = 
"greenlet-3.3.0-cp312-cp312-manylinux_2_24_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:349345b770dc88f81506c6861d22a6ccd422207829d2c854ae2af8025af303e3"}, + {file = "greenlet-3.3.0-cp312-cp312-manylinux_2_24_ppc64le.manylinux_2_28_ppc64le.whl", hash = "sha256:e8e18ed6995e9e2c0b4ed264d2cf89260ab3ac7e13555b8032b25a74c6d18655"}, + {file = "greenlet-3.3.0-cp312-cp312-manylinux_2_24_s390x.manylinux_2_28_s390x.whl", hash = "sha256:c024b1e5696626890038e34f76140ed1daf858e37496d33f2af57f06189e70d7"}, + {file = "greenlet-3.3.0-cp312-cp312-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:047ab3df20ede6a57c35c14bf5200fcf04039d50f908270d3f9a7a82064f543b"}, + {file = "greenlet-3.3.0-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:2d9ad37fc657b1102ec880e637cccf20191581f75c64087a549e66c57e1ceb53"}, + {file = "greenlet-3.3.0-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:83cd0e36932e0e7f36a64b732a6f60c2fc2df28c351bae79fbaf4f8092fe7614"}, + {file = "greenlet-3.3.0-cp312-cp312-win_amd64.whl", hash = "sha256:a7a34b13d43a6b78abf828a6d0e87d3385680eaf830cd60d20d52f249faabf39"}, + {file = "greenlet-3.3.0-cp313-cp313-macosx_11_0_universal2.whl", hash = "sha256:a1e41a81c7e2825822f4e068c48cb2196002362619e2d70b148f20a831c00739"}, + {file = "greenlet-3.3.0-cp313-cp313-manylinux_2_24_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:9f515a47d02da4d30caaa85b69474cec77b7929b2e936ff7fb853d42f4bf8808"}, + {file = "greenlet-3.3.0-cp313-cp313-manylinux_2_24_ppc64le.manylinux_2_28_ppc64le.whl", hash = "sha256:7d2d9fd66bfadf230b385fdc90426fcd6eb64db54b40c495b72ac0feb5766c54"}, + {file = "greenlet-3.3.0-cp313-cp313-manylinux_2_24_s390x.manylinux_2_28_s390x.whl", hash = "sha256:30a6e28487a790417d036088b3bcb3f3ac7d8babaa7d0139edbaddebf3af9492"}, + {file = "greenlet-3.3.0-cp313-cp313-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:087ea5e004437321508a8d6f20efc4cfec5e3c30118e1417ea96ed1d93950527"}, + {file = 
"greenlet-3.3.0-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:ab97cf74045343f6c60a39913fa59710e4bd26a536ce7ab2397adf8b27e67c39"}, + {file = "greenlet-3.3.0-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:5375d2e23184629112ca1ea89a53389dddbffcf417dad40125713d88eb5f96e8"}, + {file = "greenlet-3.3.0-cp313-cp313-win_amd64.whl", hash = "sha256:9ee1942ea19550094033c35d25d20726e4f1c40d59545815e1128ac58d416d38"}, + {file = "greenlet-3.3.0-cp314-cp314-macosx_11_0_universal2.whl", hash = "sha256:60c2ef0f578afb3c8d92ea07ad327f9a062547137afe91f38408f08aacab667f"}, + {file = "greenlet-3.3.0-cp314-cp314-manylinux_2_24_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:0a5d554d0712ba1de0a6c94c640f7aeba3f85b3a6e1f2899c11c2c0428da9365"}, + {file = "greenlet-3.3.0-cp314-cp314-manylinux_2_24_ppc64le.manylinux_2_28_ppc64le.whl", hash = "sha256:3a898b1e9c5f7307ebbde4102908e6cbfcb9ea16284a3abe15cab996bee8b9b3"}, + {file = "greenlet-3.3.0-cp314-cp314-manylinux_2_24_s390x.manylinux_2_28_s390x.whl", hash = "sha256:dcd2bdbd444ff340e8d6bdf54d2f206ccddbb3ccfdcd3c25bf4afaa7b8f0cf45"}, + {file = "greenlet-3.3.0-cp314-cp314-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:5773edda4dc00e173820722711d043799d3adb4f01731f40619e07ea2750b955"}, + {file = "greenlet-3.3.0-cp314-cp314-musllinux_1_2_aarch64.whl", hash = "sha256:ac0549373982b36d5fd5d30beb8a7a33ee541ff98d2b502714a09f1169f31b55"}, + {file = "greenlet-3.3.0-cp314-cp314-musllinux_1_2_x86_64.whl", hash = "sha256:d198d2d977460358c3b3a4dc844f875d1adb33817f0613f663a656f463764ccc"}, + {file = "greenlet-3.3.0-cp314-cp314-win_amd64.whl", hash = "sha256:73f51dd0e0bdb596fb0417e475fa3c5e32d4c83638296e560086b8d7da7c4170"}, + {file = "greenlet-3.3.0-cp314-cp314t-macosx_11_0_universal2.whl", hash = "sha256:d6ed6f85fae6cdfdb9ce04c9bf7a08d666cfcfb914e7d006f44f840b46741931"}, + {file = "greenlet-3.3.0-cp314-cp314t-manylinux_2_24_aarch64.manylinux_2_28_aarch64.whl", hash = 
"sha256:d9125050fcf24554e69c4cacb086b87b3b55dc395a8b3ebe6487b045b2614388"}, + {file = "greenlet-3.3.0-cp314-cp314t-manylinux_2_24_ppc64le.manylinux_2_28_ppc64le.whl", hash = "sha256:87e63ccfa13c0a0f6234ed0add552af24cc67dd886731f2261e46e241608bee3"}, + {file = "greenlet-3.3.0-cp314-cp314t-manylinux_2_24_s390x.manylinux_2_28_s390x.whl", hash = "sha256:2662433acbca297c9153a4023fe2161c8dcfdcc91f10433171cf7e7d94ba2221"}, + {file = "greenlet-3.3.0-cp314-cp314t-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:3c6e9b9c1527a78520357de498b0e709fb9e2f49c3a513afd5a249007261911b"}, + {file = "greenlet-3.3.0-cp314-cp314t-musllinux_1_2_aarch64.whl", hash = "sha256:286d093f95ec98fdd92fcb955003b8a3d054b4e2cab3e2707a5039e7b50520fd"}, + {file = "greenlet-3.3.0-cp314-cp314t-musllinux_1_2_x86_64.whl", hash = "sha256:6c10513330af5b8ae16f023e8ddbfb486ab355d04467c4679c5cfe4659975dd9"}, + {file = "greenlet-3.3.0.tar.gz", hash = "sha256:a82bb225a4e9e4d653dd2fb7b8b2d36e4fb25bc0165422a11e48b88e9e6f78fb"}, +] + +[package.extras] +docs = ["Sphinx", "furo"] +test = ["objgraph", "psutil", "setuptools"] + [[package]] name = "h11" version = "0.16.0" @@ -2048,6 +2131,26 @@ cli = ["jsonargparse[signatures] (>=4.38.0)", "tomlkit"] docs = ["requests (>=2.0.0)"] typing = ["mypy (>=1.0.0)", "types-setuptools"] +[[package]] +name = "mako" +version = "1.3.10" +description = "A super-fast templating language that borrows the best ideas from the existing templating languages." 
+optional = false +python-versions = ">=3.8" +groups = ["main"] +files = [ + {file = "mako-1.3.10-py3-none-any.whl", hash = "sha256:baef24a52fc4fc514a0887ac600f9f1cff3d82c61d4d700a1fa84d597b88db59"}, + {file = "mako-1.3.10.tar.gz", hash = "sha256:99579a6f39583fa7e5630a28c3c1f440e4e97a414b80372649c0ce338da2ea28"}, +] + +[package.dependencies] +MarkupSafe = ">=0.9.2" + +[package.extras] +babel = ["Babel"] +lingua = ["lingua"] +testing = ["pytest"] + [[package]] name = "markdown-it-py" version = "4.0.0" @@ -2968,6 +3071,33 @@ files = [ {file = "nvidia_nvtx_cu12-12.8.90-py3-none-win_amd64.whl", hash = "sha256:619c8304aedc69f02ea82dd244541a83c3d9d40993381b3b590f1adaed3db41e"}, ] +[[package]] +name = "optuna" +version = "4.7.0" +description = "A hyperparameter optimization framework" +optional = false +python-versions = ">=3.9" +groups = ["main"] +files = [ + {file = "optuna-4.7.0-py3-none-any.whl", hash = "sha256:e41ec84018cecc10eabf28143573b1f0bde0ba56dba8151631a590ecbebc1186"}, + {file = "optuna-4.7.0.tar.gz", hash = "sha256:d91817e2079825557bd2e97de2e8c9ae260bfc99b32712502aef8a5095b2d2c0"}, +] + +[package.dependencies] +alembic = ">=1.5.0" +colorlog = "*" +numpy = "*" +packaging = ">=20.0" +PyYAML = "*" +sqlalchemy = ">=1.4.2" +tqdm = "*" + +[package.extras] +checking = ["mypy", "mypy_boto3_s3", "ruff", "scipy-stubs ; python_version >= \"3.10\"", "types-PyYAML", "types-redis", "types-setuptools", "types-tqdm", "typing_extensions (>=3.10.0.0)"] +document = ["ase", "cmaes (>=0.12.0)", "fvcore", "kaleido (<0.4)", "lightgbm", "matplotlib (!=3.6.0)", "pandas", "pillow", "plotly (>=4.9.0)", "scikit-learn", "sphinx", "sphinx-copybutton", "sphinx-gallery", "sphinx-notfound-page", "sphinx_rtd_theme (>=1.2.0)", "torch", "torchvision"] +optional = ["boto3", "cmaes (>=0.12.0)", "google-cloud-storage", "greenlet", "grpcio", "matplotlib (!=3.6.0)", "pandas", "plotly (>=4.9.0)", "protobuf (>=5.28.1)", "redis", "scikit-learn (>=0.24.2)", "scipy", "torch"] +test = ["fakeredis[lua]", 
"greenlet", "grpcio", "kaleido (<0.4)", "moto", "protobuf (>=5.28.1)", "pytest", "pytest-xdist", "scipy (>=1.9.2)", "torch"] + [[package]] name = "packaging" version = "26.0" @@ -5145,6 +5275,104 @@ lint = ["mypy", "ruff (==0.5.5)", "types-docutils"] standalone = ["Sphinx (>=5)"] test = ["pytest"] +[[package]] +name = "sqlalchemy" +version = "2.0.46" +description = "Database Abstraction Library" +optional = false +python-versions = ">=3.7" +groups = ["main"] +files = [ + {file = "sqlalchemy-2.0.46-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:895296687ad06dc9b11a024cf68e8d9d3943aa0b4964278d2553b86f1b267735"}, + {file = "sqlalchemy-2.0.46-cp310-cp310-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:ab65cb2885a9f80f979b85aa4e9c9165a31381ca322cbde7c638fe6eefd1ec39"}, + {file = "sqlalchemy-2.0.46-cp310-cp310-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:52fe29b3817bd191cc20bad564237c808967972c97fa683c04b28ec8979ae36f"}, + {file = "sqlalchemy-2.0.46-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:09168817d6c19954d3b7655da6ba87fcb3a62bb575fb396a81a8b6a9fadfe8b5"}, + {file = "sqlalchemy-2.0.46-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:be6c0466b4c25b44c5d82b0426b5501de3c424d7a3220e86cd32f319ba56798e"}, + {file = "sqlalchemy-2.0.46-cp310-cp310-win32.whl", hash = "sha256:1bc3f601f0a818d27bfe139f6766487d9c88502062a2cd3a7ee6c342e81d5047"}, + {file = "sqlalchemy-2.0.46-cp310-cp310-win_amd64.whl", hash = "sha256:e0c05aff5c6b1bb5fb46a87e0f9d2f733f83ef6cbbbcd5c642b6c01678268061"}, + {file = "sqlalchemy-2.0.46-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:261c4b1f101b4a411154f1da2b76497d73abbfc42740029205d4d01fa1052684"}, + {file = "sqlalchemy-2.0.46-cp311-cp311-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:181903fe8c1b9082995325f1b2e84ac078b1189e2819380c2303a5f90e114a62"}, + {file = 
"sqlalchemy-2.0.46-cp311-cp311-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:590be24e20e2424a4c3c1b0835e9405fa3d0af5823a1a9fc02e5dff56471515f"}, + {file = "sqlalchemy-2.0.46-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:7568fe771f974abadce52669ef3a03150ff03186d8eb82613bc8adc435a03f01"}, + {file = "sqlalchemy-2.0.46-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:ebf7e1e78af38047e08836d33502c7a278915698b7c2145d045f780201679999"}, + {file = "sqlalchemy-2.0.46-cp311-cp311-win32.whl", hash = "sha256:9d80ea2ac519c364a7286e8d765d6cd08648f5b21ca855a8017d9871f075542d"}, + {file = "sqlalchemy-2.0.46-cp311-cp311-win_amd64.whl", hash = "sha256:585af6afe518732d9ccd3aea33af2edaae4a7aa881af5d8f6f4fe3a368699597"}, + {file = "sqlalchemy-2.0.46-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:3a9a72b0da8387f15d5810f1facca8f879de9b85af8c645138cba61ea147968c"}, + {file = "sqlalchemy-2.0.46-cp312-cp312-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:2347c3f0efc4de367ba00218e0ae5c4ba2306e47216ef80d6e31761ac97cb0b9"}, + {file = "sqlalchemy-2.0.46-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:9094c8b3197db12aa6f05c51c05daaad0a92b8c9af5388569847b03b1007fb1b"}, + {file = "sqlalchemy-2.0.46-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:37fee2164cf21417478b6a906adc1a91d69ae9aba8f9533e67ce882f4bb1de53"}, + {file = "sqlalchemy-2.0.46-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:b1e14b2f6965a685c7128bd315e27387205429c2e339eeec55cb75ca4ab0ea2e"}, + {file = "sqlalchemy-2.0.46-cp312-cp312-win32.whl", hash = "sha256:412f26bb4ba942d52016edc8d12fb15d91d3cd46b0047ba46e424213ad407bcb"}, + {file = "sqlalchemy-2.0.46-cp312-cp312-win_amd64.whl", hash = "sha256:ea3cd46b6713a10216323cda3333514944e510aa691c945334713fca6b5279ff"}, + {file = "sqlalchemy-2.0.46-cp313-cp313-macosx_11_0_arm64.whl", hash = 
"sha256:93a12da97cca70cea10d4b4fc602589c4511f96c1f8f6c11817620c021d21d00"}, + {file = "sqlalchemy-2.0.46-cp313-cp313-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:af865c18752d416798dae13f83f38927c52f085c52e2f32b8ab0fef46fdd02c2"}, + {file = "sqlalchemy-2.0.46-cp313-cp313-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:8d679b5f318423eacb61f933a9a0f75535bfca7056daeadbf6bd5bcee6183aee"}, + {file = "sqlalchemy-2.0.46-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:64901e08c33462acc9ec3bad27fc7a5c2b6491665f2aa57564e57a4f5d7c52ad"}, + {file = "sqlalchemy-2.0.46-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:e8ac45e8f4eaac0f9f8043ea0e224158855c6a4329fd4ee37c45c61e3beb518e"}, + {file = "sqlalchemy-2.0.46-cp313-cp313-win32.whl", hash = "sha256:8d3b44b3d0ab2f1319d71d9863d76eeb46766f8cf9e921ac293511804d39813f"}, + {file = "sqlalchemy-2.0.46-cp313-cp313-win_amd64.whl", hash = "sha256:77f8071d8fbcbb2dd11b7fd40dedd04e8ebe2eb80497916efedba844298065ef"}, + {file = "sqlalchemy-2.0.46-cp313-cp313t-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:a1e8cc6cc01da346dc92d9509a63033b9b1bda4fed7a7a7807ed385c7dccdc10"}, + {file = "sqlalchemy-2.0.46-cp313-cp313t-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:96c7cca1a4babaaf3bfff3e4e606e38578856917e52f0384635a95b226c87764"}, + {file = "sqlalchemy-2.0.46-cp313-cp313t-musllinux_1_2_aarch64.whl", hash = "sha256:b2a9f9aee38039cf4755891a1e50e1effcc42ea6ba053743f452c372c3152b1b"}, + {file = "sqlalchemy-2.0.46-cp313-cp313t-musllinux_1_2_x86_64.whl", hash = "sha256:db23b1bf8cfe1f7fda19018e7207b20cdb5168f83c437ff7e95d19e39289c447"}, + {file = "sqlalchemy-2.0.46-cp314-cp314-macosx_11_0_arm64.whl", hash = "sha256:56bdd261bfd0895452006d5316cbf35739c53b9bb71a170a331fa0ea560b2ada"}, + {file = 
"sqlalchemy-2.0.46-cp314-cp314-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:33e462154edb9493f6c3ad2125931e273bbd0be8ae53f3ecd1c161ea9a1dd366"}, + {file = "sqlalchemy-2.0.46-cp314-cp314-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:9bcdce05f056622a632f1d44bb47dbdb677f58cad393612280406ce37530eb6d"}, + {file = "sqlalchemy-2.0.46-cp314-cp314-musllinux_1_2_aarch64.whl", hash = "sha256:8e84b09a9b0f19accedcbeff5c2caf36e0dd537341a33aad8d680336152dc34e"}, + {file = "sqlalchemy-2.0.46-cp314-cp314-musllinux_1_2_x86_64.whl", hash = "sha256:4f52f7291a92381e9b4de9050b0a65ce5d6a763333406861e33906b8aa4906bf"}, + {file = "sqlalchemy-2.0.46-cp314-cp314-win32.whl", hash = "sha256:70ed2830b169a9960193f4d4322d22be5c0925357d82cbf485b3369893350908"}, + {file = "sqlalchemy-2.0.46-cp314-cp314-win_amd64.whl", hash = "sha256:3c32e993bc57be6d177f7d5d31edb93f30726d798ad86ff9066d75d9bf2e0b6b"}, + {file = "sqlalchemy-2.0.46-cp314-cp314t-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:4dafb537740eef640c4d6a7c254611dca2df87eaf6d14d6a5fca9d1f4c3fc0fa"}, + {file = "sqlalchemy-2.0.46-cp314-cp314t-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:42a1643dc5427b69aca967dae540a90b0fbf57eaf248f13a90ea5930e0966863"}, + {file = "sqlalchemy-2.0.46-cp314-cp314t-musllinux_1_2_aarch64.whl", hash = "sha256:ff33c6e6ad006bbc0f34f5faf941cfc62c45841c64c0a058ac38c799f15b5ede"}, + {file = "sqlalchemy-2.0.46-cp314-cp314t-musllinux_1_2_x86_64.whl", hash = "sha256:82ec52100ec1e6ec671563bbd02d7c7c8d0b9e71a0723c72f22ecf52d1755330"}, + {file = "sqlalchemy-2.0.46-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:6ac245604295b521de49b465bab845e3afe6916bcb2147e5929c8041b4ec0545"}, + {file = "sqlalchemy-2.0.46-cp38-cp38-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = 
"sha256:1e6199143d51e3e1168bedd98cc698397404a8f7508831b81b6a29b18b051069"}, + {file = "sqlalchemy-2.0.46-cp38-cp38-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:716be5bcabf327b6d5d265dbdc6213a01199be587224eb991ad0d37e83d728fd"}, + {file = "sqlalchemy-2.0.46-cp38-cp38-musllinux_1_2_aarch64.whl", hash = "sha256:6f827fd687fa1ba7f51699e1132129eac8db8003695513fcf13fc587e1bd47a5"}, + {file = "sqlalchemy-2.0.46-cp38-cp38-musllinux_1_2_x86_64.whl", hash = "sha256:c805fa6e5d461329fa02f53f88c914d189ea771b6821083937e79550bf31fc19"}, + {file = "sqlalchemy-2.0.46-cp38-cp38-win32.whl", hash = "sha256:3aac08f7546179889c62b53b18ebf1148b10244b3405569c93984b0388d016a7"}, + {file = "sqlalchemy-2.0.46-cp38-cp38-win_amd64.whl", hash = "sha256:0cc3117db526cad3e61074100bd2867b533e2c7dc1569e95c14089735d6fb4fe"}, + {file = "sqlalchemy-2.0.46-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:90bde6c6b1827565a95fde597da001212ab436f1b2e0c2dcc7246e14db26e2a3"}, + {file = "sqlalchemy-2.0.46-cp39-cp39-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:94b1e5f3a5f1ff4f42d5daab047428cd45a3380e51e191360a35cef71c9a7a2a"}, + {file = "sqlalchemy-2.0.46-cp39-cp39-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:93bb0aae40b52c57fd74ef9c6933c08c040ba98daf23ad33c3f9893494b8d3ce"}, + {file = "sqlalchemy-2.0.46-cp39-cp39-musllinux_1_2_aarch64.whl", hash = "sha256:c4e2cc868b7b5208aec6c960950b7bb821f82c2fe66446c92ee0a571765e91a5"}, + {file = "sqlalchemy-2.0.46-cp39-cp39-musllinux_1_2_x86_64.whl", hash = "sha256:965c62be8256d10c11f8907e7a8d3e18127a4c527a5919d85fa87fd9ecc2cfdc"}, + {file = "sqlalchemy-2.0.46-cp39-cp39-win32.whl", hash = "sha256:9397b381dcee8a2d6b99447ae85ea2530dcac82ca494d1db877087a13e38926d"}, + {file = "sqlalchemy-2.0.46-cp39-cp39-win_amd64.whl", hash = "sha256:4396c948d8217e83e2c202fbdcc0389cf8c93d2c1c5e60fa5c5a955eae0e64be"}, + {file = "sqlalchemy-2.0.46-py3-none-any.whl", hash 
= "sha256:f9c11766e7e7c0a2767dda5acb006a118640c9fc0a4104214b96269bfb78399e"}, + {file = "sqlalchemy-2.0.46.tar.gz", hash = "sha256:cf36851ee7219c170bb0793dbc3da3e80c582e04a5437bc601bfe8c85c9216d7"}, +] + +[package.dependencies] +greenlet = {version = ">=1", markers = "platform_machine == \"aarch64\" or platform_machine == \"ppc64le\" or platform_machine == \"x86_64\" or platform_machine == \"amd64\" or platform_machine == \"AMD64\" or platform_machine == \"win32\" or platform_machine == \"WIN32\""} +typing-extensions = ">=4.6.0" + +[package.extras] +aiomysql = ["aiomysql (>=0.2.0)", "greenlet (>=1)"] +aioodbc = ["aioodbc", "greenlet (>=1)"] +aiosqlite = ["aiosqlite", "greenlet (>=1)", "typing_extensions (!=3.10.0.1)"] +asyncio = ["greenlet (>=1)"] +asyncmy = ["asyncmy (>=0.2.3,!=0.2.4,!=0.2.6)", "greenlet (>=1)"] +mariadb-connector = ["mariadb (>=1.0.1,!=1.1.2,!=1.1.5,!=1.1.10)"] +mssql = ["pyodbc"] +mssql-pymssql = ["pymssql"] +mssql-pyodbc = ["pyodbc"] +mypy = ["mypy (>=0.910)"] +mysql = ["mysqlclient (>=1.4.0)"] +mysql-connector = ["mysql-connector-python"] +oracle = ["cx_oracle (>=8)"] +oracle-oracledb = ["oracledb (>=1.0.1)"] +postgresql = ["psycopg2 (>=2.7)"] +postgresql-asyncpg = ["asyncpg", "greenlet (>=1)"] +postgresql-pg8000 = ["pg8000 (>=1.29.1)"] +postgresql-psycopg = ["psycopg (>=3.0.7)"] +postgresql-psycopg2binary = ["psycopg2-binary"] +postgresql-psycopg2cffi = ["psycopg2cffi"] +postgresql-psycopgbinary = ["psycopg[binary] (>=3.0.7)"] +pymysql = ["pymysql"] +sqlcipher = ["sqlcipher3_binary"] + [[package]] name = "starlette" version = "0.52.1" @@ -6402,4 +6630,4 @@ multiprocessing = ["pydantic", "ray"] [metadata] lock-version = "2.1" python-versions = ">=3.11,<3.14" -content-hash = "b48d33e2c6e66c3fa8aa5b42db423e8577a44e19090b766f65051f4b9587dde4" +content-hash = "205db54ef1c6b5ee0b5d9821e6bc9f10d020e674c71fe62a03a477c226ab0b47" diff --git a/pyproject.toml b/pyproject.toml index d6056275..c8a0b56a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ 
-50,6 +50,7 @@ poetry = "^2.0.1" starlette = ">=0.49.1" pydantic = { version = ">=2.5", optional = true } wandb = "^0.24.0" +optuna = "^4.7.0" [tool.poetry.requires-plugins] poetry-plugin-export = ">=1.8" From 292c99f33408ebb4f27cb6cbb2a87535af8e8b25 Mon Sep 17 00:00:00 2001 From: nictru Date: Thu, 22 Jan 2026 14:46:05 +0000 Subject: [PATCH 02/15] First implementation draft --- drevalpy/experiment.py | 352 +++++++++++++++++++++++++++++------ drevalpy/models/drp_model.py | 29 +++ 2 files changed, 321 insertions(+), 60 deletions(-) diff --git a/drevalpy/experiment.py b/drevalpy/experiment.py index 635c1ea3..f31a2082 100644 --- a/drevalpy/experiment.py +++ b/drevalpy/experiment.py @@ -49,6 +49,7 @@ def drug_response_experiment( path_data: str = "data", model_checkpoint_dir: str = "TEMPORARY", hyperparameter_tuning=True, + n_trials: int = 20, final_model_on_full_data: bool = False, wandb_project: str | None = None, ) -> None: @@ -99,7 +100,9 @@ def drug_response_experiment( :param overwrite: whether to overwrite existing results :param path_data: path to the data directory, usually data/ :param model_checkpoint_dir: directory to save model checkpoints. If "TEMPORARY", a temporary directory is created. - :param hyperparameter_tuning: whether to run in debug mode - if False, only select first hyperparameter set + :param hyperparameter_tuning: whether to perform hyperparameter tuning. If False, uses the first hyperparameter + configuration from the search space. + :param n_trials: number of Bayesian optimization trials for hyperparameter tuning. Default is 20. :param final_model_on_full_data: if True, a final/production model is saved in the results directory. If hyperparameter_tuning is true, the final model is produced according to the hyperparameter tuning procedure which was evaluated in the nested cross validation. 
@@ -222,6 +225,7 @@ def drug_response_experiment( "metric": hpam_optimization_metric, "path_data": path_data, "model_checkpoint_dir": model_checkpoint_dir, + "n_trials": n_trials, } # During hyperparameter tuning, create separate wandb runs per trial if enabled @@ -383,6 +387,7 @@ def drug_response_experiment( test_mode=test_mode, val_ratio=0.1, hyperparameter_tuning=hyperparameter_tuning, + n_trials=n_trials, ) consolidate_single_drug_model_predictions( @@ -1148,53 +1153,168 @@ def train_and_evaluate( return results +def _deep_equal(a: Any, b: Any) -> bool: + """ + Compare two values for equality, handling nested structures. + + :param a: first value + :param b: second value + :returns: True if values are equal (including nested structures) + """ + if isinstance(a, list) and isinstance(b, list): + if len(a) != len(b): + return False + return all(_deep_equal(ai, bi) for ai, bi in zip(a, b, strict=True)) + elif isinstance(a, dict) and isinstance(b, dict): + if set(a.keys()) != set(b.keys()): + return False + return all(_deep_equal(a[k], b[k]) for k in a.keys()) + else: + return a == b + + +def _sample_hyperparameters_from_search_space(trial, search_space: dict[str, Any]) -> dict[str, Any]: + """ + Sample hyperparameters from a search space definition using Optuna. 
+ + :param trial: Optuna trial object + :param search_space: dictionary mapping parameter names to their search space definitions + :returns: dictionary of sampled hyperparameters + """ + sampled = {} + for param_name, param_def in search_space.items(): + if isinstance(param_def, dict) and "type" in param_def: + # Structured search space definition for continuous ranges + param_type = param_def["type"] + low = param_def["low"] + high = param_def["high"] + log_scale = param_def.get("log", False) + + if param_type == "int": + sampled[param_name] = trial.suggest_int(param_name, low, high, log=log_scale) + elif param_type == "float": + if log_scale: + sampled[param_name] = trial.suggest_float(param_name, low, high, log=True) + else: + sampled[param_name] = trial.suggest_float(param_name, low, high) + else: + raise ValueError(f"Unknown parameter type: {param_type}") + elif isinstance(param_def, list): + # Categorical choices + if len(param_def) == 1: + # Single value, no tuning needed + sampled[param_name] = param_def[0] + else: + sampled[param_name] = trial.suggest_categorical(param_name, param_def) + else: + # Single fixed value (not a list or dict) + sampled[param_name] = param_def + + return sampled + + def hpam_tune( model: DRPModel, train_dataset: DrugResponseDataset, validation_dataset: DrugResponseDataset, - hpam_set: list[dict], + hpam_set: list[dict] | dict[str, Any], early_stopping_dataset: DrugResponseDataset | None = None, response_transformation: TransformerMixin | None = None, metric: str = "RMSE", path_data: str = "data", model_checkpoint_dir: str = "TEMPORARY", + n_trials: int = 20, *, split_index: int | None = None, wandb_project: str | None = None, wandb_base_config: dict[str, Any] | None = None, ) -> dict: """ - Tune the hyperparameters for the given model in an iterative manner. + Tune hyperparameters using Bayesian optimization with Optuna. 
+ + This function uses Optuna's TPE (Tree-structured Parzen Estimator) sampler + for efficient hyperparameter search. Trials are run sequentially. :param model: model to use :param train_dataset: training dataset :param validation_dataset: validation dataset - :param hpam_set: hyperparameters to tune + :param hpam_set: either a search space dictionary (for Bayesian optimization) or + a list of hyperparameter configurations (legacy grid search format) :param early_stopping_dataset: early stopping dataset :param response_transformation: normalizer to use for the response data :param metric: metric to evaluate which model is the best :param path_data: path to the data directory, e.g., data/ :param model_checkpoint_dir: directory to save model checkpoints + :param n_trials: number of Bayesian optimization trials to run :param split_index: optional CV split index, used for naming wandb runs :param wandb_project: optional wandb project name; if provided, enables per-trial wandb runs :param wandb_base_config: optional base config dict to include in each wandb run :returns: best hyperparameters :raises AssertionError: if hpam_set is empty """ - if len(hpam_set) == 0: - raise AssertionError("hpam_set must contain at least one hyperparameter configuration") - if len(hpam_set) == 1: - return hpam_set[0] + import optuna + from optuna.samplers import TPESampler + + # Handle legacy list format (grid search) - convert to search space + if isinstance(hpam_set, list): + if len(hpam_set) == 0: + raise AssertionError("hpam_set must contain at least one hyperparameter configuration") + if len(hpam_set) == 1: + return hpam_set[0] + + # Convert list of dicts to search space by extracting unique values per parameter + # Handle nested structures (like lists of lists) by using a list-based approach + search_space: dict[str, Any] = {} + all_keys = set() + for config in hpam_set: + all_keys.update(config.keys()) + + for key in all_keys: + # Collect all values for this key, preserving order 
and handling unhashable types + values = [] + seen = [] + for config in hpam_set: + if key in config: + value = config.get(key) + # For unhashable types (lists, dicts), use deep comparison + if isinstance(value, (list, dict)): + # Check if we've seen an equivalent value + if not any(_deep_equal(value, v) for v in seen): + values.append(value) + seen.append(value) + else: + # For hashable types, use set for deduplication + if value not in values: + values.append(value) + if len(values) == 1: + search_space[key] = values[0] + else: + search_space[key] = values + else: + search_space = hpam_set + + # Check if there's anything to tune + tunable_params = [ + k + for k, v in search_space.items() + if isinstance(v, (list, dict)) and (not isinstance(v, list) or len(v) > 1) + ] + if not tunable_params: + # No tuning needed, return fixed values + return {k: (v[0] if isinstance(v, list) else v) for k, v in search_space.items()} # Mark that we're in hyperparameter tuning phase - # This prevents updating wandb.config during tuning - we'll only log final best hyperparameters model._in_hyperparameter_tuning = True - best_hyperparameters = None mode = get_mode(metric) - best_score = float("inf") if mode == "min" else float("-inf") - for trial_idx, hyperparameter in enumerate(hpam_set): - print(f"Training model with hyperparameters: {hyperparameter}") + direction = "minimize" if mode == "min" else "maximize" + + def objective(trial): + # Sample hyperparameters + hyperparameter = _sample_hyperparameters_from_search_space(trial, search_space) + trial_idx = trial.number + + print(f"Trial {trial_idx}: Training model with hyperparameters: {hyperparameter}") # Create a separate wandb run for each hyperparameter trial if enabled if wandb_project is not None: @@ -1222,40 +1342,52 @@ def hpam_tune( finish_previous=True, ) - # During hyperparameter tuning, don't update wandb config via log_hyperparameters - # Trial hyperparameters are stored in wandb.config for each run - score = 
train_and_evaluate( - model=model, - hpams=hyperparameter, - path_data=path_data, - train_dataset=train_dataset, - validation_dataset=validation_dataset, - early_stopping_dataset=early_stopping_dataset, - metric=metric, - response_transformation=response_transformation, - model_checkpoint_dir=model_checkpoint_dir, - )[metric] + try: + score = train_and_evaluate( + model=model, + hpams=hyperparameter, + path_data=path_data, + train_dataset=train_dataset, + validation_dataset=validation_dataset, + early_stopping_dataset=early_stopping_dataset, + metric=metric, + response_transformation=response_transformation, + model_checkpoint_dir=model_checkpoint_dir, + )[metric] + + if np.isnan(score): + # Return a bad score for NaN results + score = float("inf") if mode == "min" else float("-inf") + else: + print(f"Trial {trial_idx}: {metric} = {np.round(score, 4)}") - # Note: train_and_evaluate() already logs val_* metrics once via - # DRPModel.compute_and_log_final_metrics(..., prefix="val_"). - # Avoid logging val_{metric} again here (it would create duplicate points). 
- if np.isnan(score): + except Exception as e: + print(f"Trial {trial_idx} failed: {e}") + score = float("inf") if mode == "min" else float("-inf") + + finally: if model.is_wandb_enabled(): model.finish_wandb() - continue - if (mode == "min" and score < best_score) or (mode == "max" and score > best_score): - print(f"current best {metric} score: {np.round(score, 3)}") - best_score = score - best_hyperparameters = hyperparameter + return score + + # Create and run the Optuna study + study = optuna.create_study(direction=direction, sampler=TPESampler(seed=42)) + study.optimize(objective, n_trials=n_trials, show_progress_bar=True) + + # Get best hyperparameters + best_hyperparameters = study.best_params - # Close this trial's run after all logging is done - if model.is_wandb_enabled(): - model.finish_wandb() + # Fill in fixed parameters that weren't tuned + for key, value in search_space.items(): + if key not in best_hyperparameters: + if isinstance(value, list) and len(value) == 1: + best_hyperparameters[key] = value[0] + elif not isinstance(value, (list, dict)): + best_hyperparameters[key] = value - if best_hyperparameters is None: - warnings.warn("all hpams lead to NaN respone. using last hpam combination.", stacklevel=2) - best_hyperparameters = hyperparameter + print(f"\nBest trial: {study.best_trial.number}") + print(f"Best {metric}: {np.round(study.best_value, 4)}") return best_hyperparameters @@ -1265,50 +1397,132 @@ def hpam_tune_raytune( train_dataset: DrugResponseDataset, validation_dataset: DrugResponseDataset, early_stopping_dataset: DrugResponseDataset | None, - hpam_set: list[dict], + hpam_set: list[dict] | dict[str, Any], response_transformation: TransformerMixin | None = None, metric: str = "RMSE", ray_path: str = "raytune", path_data: str = "data", model_checkpoint_dir: str = "TEMPORARY", + n_trials: int = 20, ) -> dict: """ - Tune the hyperparameters for the given model using Ray Tune. Ray[tune] must be installed. 
+ Tune hyperparameters using Bayesian optimization with Ray Tune and Optuna. + + This function uses Ray Tune with OptunaSearch for parallel Bayesian optimization. + Ray[tune] and optuna must be installed. :param model: model to use :param train_dataset: training dataset :param validation_dataset: validation dataset :param early_stopping_dataset: early stopping dataset - :param hpam_set: hyperparameters to tune + :param hpam_set: either a search space dictionary (for Bayesian optimization) or + a list of hyperparameter configurations (legacy grid search format) :param response_transformation: normalizer for response data :param metric: evaluation metric :param ray_path: path to the raytune directory :param path_data: path to data directory, e.g., data/ :param model_checkpoint_dir: directory for model checkpoints + :param n_trials: number of Bayesian optimization trials to run :returns: best hyperparameters :raises ValueError: if best_result is None """ - print("Starting hyperparameter tuning with Ray Tune ...") - print(f"Hyperparameter combinations to evaluate: {len(hpam_set)}") - print() - - if len(hpam_set) == 1: - return hpam_set[0] - import ray from ray import tune + from ray.tune.search.optuna import OptunaSearch + + print("Starting hyperparameter tuning with Ray Tune (Bayesian optimization) ...") + + # Handle legacy list format (grid search) - convert to search space + if isinstance(hpam_set, list): + if len(hpam_set) == 0: + raise AssertionError("hpam_set must contain at least one hyperparameter configuration") + if len(hpam_set) == 1: + return hpam_set[0] + + # Convert list of dicts to search space + search_space: dict[str, Any] = {} + all_keys = set() + for config in hpam_set: + all_keys.update(config.keys()) + + for key in all_keys: + # Collect all values for this key, preserving order and handling unhashable types + values = [] + seen = [] + for config in hpam_set: + if key in config: + value = config.get(key) + # For unhashable types (lists, dicts), use 
deep comparison + if isinstance(value, (list, dict)): + # Check if we've seen an equivalent value + if not any(_deep_equal(value, v) for v in seen): + values.append(value) + seen.append(value) + else: + # For hashable types, use set for deduplication + if value not in values: + values.append(value) + if len(values) == 1: + search_space[key] = values[0] + else: + search_space[key] = values + else: + search_space = hpam_set + + # Check if there's anything to tune + tunable_params = [ + k + for k, v in search_space.items() + if isinstance(v, (list, dict)) and (not isinstance(v, list) or len(v) > 1) + ] + if not tunable_params: + return {k: (v[0] if isinstance(v, list) else v) for k, v in search_space.items()} + + # Convert search space to Ray Tune format + ray_search_space = {} + fixed_params = {} + for param_name, param_def in search_space.items(): + if isinstance(param_def, dict) and "type" in param_def: + param_type = param_def["type"] + low = param_def["low"] + high = param_def["high"] + log_scale = param_def.get("log", False) + + if param_type == "int": + if log_scale: + ray_search_space[param_name] = tune.lograndint(low, high) + else: + ray_search_space[param_name] = tune.randint(low, high + 1) + elif param_type == "float": + if log_scale: + ray_search_space[param_name] = tune.loguniform(low, high) + else: + ray_search_space[param_name] = tune.uniform(low, high) + elif isinstance(param_def, list): + if len(param_def) == 1: + fixed_params[param_name] = param_def[0] + else: + ray_search_space[param_name] = tune.choice(param_def) + else: + fixed_params[param_name] = param_def + + print(f"Tunable parameters: {list(ray_search_space.keys())}") + print(f"Fixed parameters: {list(fixed_params.keys())}") + print(f"Number of trials: {n_trials}") + print() path_data = os.path.abspath(path_data) if not ray.is_initialized(): ray.init(_temp_dir=os.path.join(os.path.expanduser("~"), "raytmp")) resources_per_trial = {"gpu": 1} if torch.cuda.is_available() else {"cpu": 1} - def 
trainable(hpams): + def trainable(config): try: - inner = hpams["hpams"] + # Merge sampled params with fixed params + hyperparameter = {**fixed_params, **config} result = train_and_evaluate( model=model, - hpams=inner, + hpams=hyperparameter, path_data=path_data, train_dataset=train_dataset, validation_dataset=validation_dataset, @@ -1317,35 +1531,50 @@ def trainable(hpams): response_transformation=response_transformation, model_checkpoint_dir=model_checkpoint_dir, ) - tune.report(metrics={metric: result[metric]}) + return {metric: result[metric]} except Exception as e: import traceback print("Trial failed:", e) traceback.print_exc() + # Return bad score on failure + mode = get_mode(metric) + return {metric: float("inf") if mode == "min" else float("-inf")} trainable = tune.with_resources(trainable, resources_per_trial) - param_space = {"hpams": tune.grid_search(hpam_set)} + + mode = get_mode(metric) + optuna_search = OptunaSearch(metric=metric, mode=mode, seed=42) tuner = tune.Tuner( trainable, - param_space=param_space, + param_space=ray_search_space, run_config=tune.RunConfig( storage_path=ray_path, name="hpam_tuning", ), tune_config=tune.TuneConfig( metric=metric, - mode=get_mode(metric), + mode=mode, + search_alg=optuna_search, + num_samples=n_trials, + max_concurrent_trials=1, # Run one at a time for Bayesian optimization ), ) results = tuner.fit() - best_result = results.get_best_result(metric=metric, mode=get_mode(metric)) + best_result = results.get_best_result(metric=metric, mode=mode) ray.shutdown() + if best_result.config is None: raise ValueError("Ray failed; no best result.") - return best_result.config["hpams"] + + # Merge best config with fixed params + best_hyperparameters = {**fixed_params, **best_result.config} + + print(f"\nBest {metric}: {np.round(best_result.metrics[metric], 4)}") + + return best_hyperparameters @pipeline_function @@ -1470,6 +1699,7 @@ def train_final_model( test_mode: str = "LCO", val_ratio: float = 0.1, 
hyperparameter_tuning: bool = True, + n_trials: int = 20, ) -> None: """ Final Production Model Training. @@ -1493,6 +1723,7 @@ def train_final_model( :param test_mode: split logic for validation (LCO, LDO, LTO, LPO) :param val_ratio: validation size ratio :param hyperparameter_tuning: whether to perform hyperparameter tuning + :param n_trials: number of Bayesian optimization trials for hyperparameter tuning """ print("Training final model with application-specific validation strategy ...") @@ -1524,6 +1755,7 @@ def train_final_model( metric=metric, path_data=path_data, model_checkpoint_dir=model_checkpoint_dir, + n_trials=n_trials, ) else: best_hpams = hpam_set[0] diff --git a/drevalpy/models/drp_model.py b/drevalpy/models/drp_model.py index 7599be2e..2f580741 100644 --- a/drevalpy/models/drp_model.py +++ b/drevalpy/models/drp_model.py @@ -289,6 +289,35 @@ def get_hyperparameter_set(cls) -> list[dict[str, Any]]: grid = list(ParameterGrid(hpams)) return grid + @classmethod + @pipeline_function + def get_hyperparameter_search_space(cls) -> dict[str, Any]: + """ + Load the raw hyperparameter search space from a YAML file. + + This method returns the search space definition without expanding it into + all combinations. Useful for Bayesian optimization where we sample from + the space rather than enumerating all combinations. 
+ + :returns: dictionary mapping parameter names to their search space definitions + :raises ValueError: if the hyperparameters are not in the correct format + :raises KeyError: if the model is not found in the hyperparameters file + """ + hyperparameter_file = os.path.join(os.path.dirname(inspect.getfile(cls)), "hyperparameters.yaml") + + with open(hyperparameter_file, encoding="utf-8") as f: + try: + hpams = yaml.safe_load(f)[cls.get_model_name()] + except yaml.YAMLError as exc: + raise ValueError(f"Error in hyperparameters.yaml: {exc}") from exc + except KeyError as key_exc: + raise KeyError(f"Model {cls.get_model_name()} not found in hyperparameters.yaml") from key_exc + + if hpams is None: + return {} + + return hpams + @property @abstractmethod def cell_line_views(self) -> list[str]: From 84fcf1a7292165f33767f6de32143b5cd648c572 Mon Sep 17 00:00:00 2001 From: nictru Date: Thu, 22 Jan 2026 15:39:35 +0000 Subject: [PATCH 03/15] Fix docstrings --- drevalpy/experiment.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/drevalpy/experiment.py b/drevalpy/experiment.py index f31a2082..5f40c9d2 100644 --- a/drevalpy/experiment.py +++ b/drevalpy/experiment.py @@ -1180,6 +1180,7 @@ def _sample_hyperparameters_from_search_space(trial, search_space: dict[str, Any :param trial: Optuna trial object :param search_space: dictionary mapping parameter names to their search space definitions :returns: dictionary of sampled hyperparameters + :raises ValueError: if an unknown parameter type is encountered in the search space """ sampled = {} for param_name, param_def in search_space.items(): @@ -1424,6 +1425,7 @@ def hpam_tune_raytune( :param model_checkpoint_dir: directory for model checkpoints :param n_trials: number of Bayesian optimization trials to run :returns: best hyperparameters + :raises AssertionError: if hpam_set is empty :raises ValueError: if best_result is None """ import ray From befd82ddec770cff94d9b79c78dc02f2e2574b7d Mon Sep 17 00:00:00 2001 From: nictru 
Date: Thu, 22 Jan 2026 15:41:40 +0000 Subject: [PATCH 04/15] Pre-commit --- drevalpy/experiment.py | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/drevalpy/experiment.py b/drevalpy/experiment.py index 5f40c9d2..cebd6a70 100644 --- a/drevalpy/experiment.py +++ b/drevalpy/experiment.py @@ -1296,9 +1296,7 @@ def hpam_tune( # Check if there's anything to tune tunable_params = [ - k - for k, v in search_space.items() - if isinstance(v, (list, dict)) and (not isinstance(v, list) or len(v) > 1) + k for k, v in search_space.items() if isinstance(v, (list, dict)) and (not isinstance(v, list) or len(v) > 1) ] if not tunable_params: # No tuning needed, return fixed values @@ -1473,9 +1471,7 @@ def hpam_tune_raytune( # Check if there's anything to tune tunable_params = [ - k - for k, v in search_space.items() - if isinstance(v, (list, dict)) and (not isinstance(v, list) or len(v) > 1) + k for k, v in search_space.items() if isinstance(v, (list, dict)) and (not isinstance(v, list) or len(v) > 1) ] if not tunable_params: return {k: (v[0] if isinstance(v, list) else v) for k, v in search_space.items()} From d0d47d45f685948173c8a70258b80f98fba194cc Mon Sep 17 00:00:00 2001 From: nictru Date: Thu, 22 Jan 2026 15:49:48 +0000 Subject: [PATCH 05/15] MyPy --- drevalpy/experiment.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/drevalpy/experiment.py b/drevalpy/experiment.py index cebd6a70..cdf3890c 100644 --- a/drevalpy/experiment.py +++ b/drevalpy/experiment.py @@ -1266,17 +1266,17 @@ def hpam_tune( # Convert list of dicts to search space by extracting unique values per parameter # Handle nested structures (like lists of lists) by using a list-based approach search_space: dict[str, Any] = {} - all_keys = set() + all_keys: set[str] = set() for config in hpam_set: all_keys.update(config.keys()) for key in all_keys: # Collect all values for this key, preserving order and handling unhashable types - values = [] - seen = [] + 
values: list[Any] = [] + seen: list[Any] = [] for config in hpam_set: if key in config: - value = config.get(key) + value = config[key] # Use direct access since we know key exists # For unhashable types (lists, dicts), use deep comparison if isinstance(value, (list, dict)): # Check if we've seen an equivalent value @@ -1441,17 +1441,17 @@ def hpam_tune_raytune( # Convert list of dicts to search space search_space: dict[str, Any] = {} - all_keys = set() + all_keys: set[str] = set() for config in hpam_set: all_keys.update(config.keys()) for key in all_keys: # Collect all values for this key, preserving order and handling unhashable types - values = [] - seen = [] + values: list[Any] = [] + seen: list[Any] = [] for config in hpam_set: if key in config: - value = config.get(key) + value = config[key] # Use direct access since we know key exists # For unhashable types (lists, dicts), use deep comparison if isinstance(value, (list, dict)): # Check if we've seen an equivalent value From a23778d3a982991c7641c5ee701af479bc924e9a Mon Sep 17 00:00:00 2001 From: nictru Date: Thu, 22 Jan 2026 15:58:39 +0000 Subject: [PATCH 06/15] Make continuous hyperparameter spaces work --- drevalpy/experiment.py | 12 +++++++++ drevalpy/models/drp_model.py | 47 +++++++++++++++++++++++++++++++++--- 2 files changed, 55 insertions(+), 4 deletions(-) diff --git a/drevalpy/experiment.py b/drevalpy/experiment.py index cdf3890c..480bc71a 100644 --- a/drevalpy/experiment.py +++ b/drevalpy/experiment.py @@ -1186,6 +1186,12 @@ def _sample_hyperparameters_from_search_space(trial, search_space: dict[str, Any for param_name, param_def in search_space.items(): if isinstance(param_def, dict) and "type" in param_def: # Structured search space definition for continuous ranges + if "default" not in param_def: + raise ValueError( + f"Hyperparameter '{param_name}' has continuous range definition " + f"but missing required 'default' field. 
" + f"Please add a 'default' value to use when hyperparameter_tuning=False." + ) param_type = param_def["type"] low = param_def["low"] high = param_def["high"] @@ -1481,6 +1487,12 @@ def hpam_tune_raytune( fixed_params = {} for param_name, param_def in search_space.items(): if isinstance(param_def, dict) and "type" in param_def: + if "default" not in param_def: + raise ValueError( + f"Hyperparameter '{param_name}' has continuous range definition " + f"but missing required 'default' field. " + f"Please add a 'default' value to use when hyperparameter_tuning=False." + ) param_type = param_def["type"] low = param_def["low"] high = param_def["high"] diff --git a/drevalpy/models/drp_model.py b/drevalpy/models/drp_model.py index 2f580741..cde01b67 100644 --- a/drevalpy/models/drp_model.py +++ b/drevalpy/models/drp_model.py @@ -282,11 +282,50 @@ def get_hyperparameter_set(cls) -> list[dict[str, Any]]: if hpams is None: return [{}] - # each param should be a list + # Convert continuous ranges to default values for grid expansion + # This handles the case when hyperparameter_tuning=False + processed_hpams: dict[str, Any] = {} for hp in hpams: - if not isinstance(hpams[hp], list): - hpams[hp] = [hpams[hp]] - grid = list(ParameterGrid(hpams)) + value = hpams[hp] + # If it's a continuous range definition, require and use the default value + if isinstance(value, dict) and "type" in value: + if "default" not in value: + raise ValueError( + f"Hyperparameter '{hp}' has continuous range definition but missing required 'default' field. " + f"Please add a 'default' value to use when hyperparameter_tuning=False." 
+ ) + # Validate default is within range + low = value["low"] + high = value["high"] + default = value["default"] + param_type = value["type"] + + if param_type == "int": + if not isinstance(default, int): + raise ValueError( + f"Hyperparameter '{hp}': default must be an integer, got {type(default).__name__}" + ) + if default < low or default > high: + raise ValueError( + f"Hyperparameter '{hp}': default value {default} is outside range [{low}, {high}]" + ) + elif param_type == "float": + if not isinstance(default, (int, float)): + raise ValueError( + f"Hyperparameter '{hp}': default must be a float, got {type(default).__name__}" + ) + if default < low or default > high: + raise ValueError( + f"Hyperparameter '{hp}': default value {default} is outside range [{low}, {high}]" + ) + + processed_hpams[hp] = [default] + elif isinstance(value, list): + processed_hpams[hp] = value + else: + # Single value + processed_hpams[hp] = [value] + grid = list(ParameterGrid(processed_hpams)) return grid @classmethod From 655f073610c10e7dc09bd4e32686a64b5919c56e Mon Sep 17 00:00:00 2001 From: nictru Date: Thu, 22 Jan 2026 16:03:29 +0000 Subject: [PATCH 07/15] Fix continuous hyperparameter sampling --- drevalpy/experiment.py | 13 ++++++++++--- 1 file changed, 10 insertions(+), 3 deletions(-) diff --git a/drevalpy/experiment.py b/drevalpy/experiment.py index 480bc71a..e2d95f2b 100644 --- a/drevalpy/experiment.py +++ b/drevalpy/experiment.py @@ -176,8 +176,12 @@ def drug_response_experiment( ) parent_dir = os.path.dirname(predictions_path) - model_hpam_set = model_class.get_hyperparameter_set() - if not hyperparameter_tuning: + if hyperparameter_tuning: + # Use raw search space for Bayesian optimization + model_hpam_set = model_class.get_hyperparameter_search_space() + else: + # Use expanded grid and take first (default) configuration + model_hpam_set = model_class.get_hyperparameter_set() model_hpam_set = [model_hpam_set[0]] if response_data.cv_splits is None: @@ -1753,8 +1757,9 @@ 
def train_final_model( else: early_stopping_dataset = None - hpam_set = model.get_hyperparameter_set() if hyperparameter_tuning: + # Use raw search space for Bayesian optimization + hpam_set = model.get_hyperparameter_search_space() best_hpams = hpam_tune( model=model, train_dataset=train_dataset, @@ -1768,6 +1773,8 @@ def train_final_model( n_trials=n_trials, ) else: + # Use expanded grid and take first (default) configuration + hpam_set = model.get_hyperparameter_set() best_hpams = hpam_set[0] print(f"Best hyperparameters for final model: {best_hpams}") From bd5d0d2fdfd4a385576f989927a4492536dd616f Mon Sep 17 00:00:00 2001 From: nictru Date: Fri, 9 Jan 2026 10:35:01 +0000 Subject: [PATCH 08/15] Add PCA featurizer --- .../featurizer/create_transcriptome_pca.py | 114 ++++++++++++++++++ tests/test_featurizers.py | 60 +++++++++ 2 files changed, 174 insertions(+) create mode 100644 drevalpy/datasets/featurizer/create_transcriptome_pca.py diff --git a/drevalpy/datasets/featurizer/create_transcriptome_pca.py b/drevalpy/datasets/featurizer/create_transcriptome_pca.py new file mode 100644 index 00000000..fd642577 --- /dev/null +++ b/drevalpy/datasets/featurizer/create_transcriptome_pca.py @@ -0,0 +1,114 @@ +"""Preprocesses transcriptome (gene expression) data using PCA dimensionality reduction.""" + +import argparse +from pathlib import Path + +import joblib +import numpy as np +import pandas as pd +from sklearn.decomposition import PCA +from sklearn.preprocessing import StandardScaler + +from drevalpy.datasets.utils import CELL_LINE_IDENTIFIER + + +def main(): + """Process transcriptome data and save PCA-transformed features. + + :raises FileNotFoundError: If the gene expression file is not found. 
+ """ + parser = argparse.ArgumentParser(description="Preprocess transcriptome (gene expression) data using PCA.") + parser.add_argument("dataset_name", type=str, help="The name of the dataset to process.") + parser.add_argument( + "--n_components", + type=int, + default=100, + help="Number of principal components to keep (default: 100)", + ) + parser.add_argument( + "--data_path", + type=str, + default="data", + help="Path to the data folder (default: data)", + ) + parser.add_argument( + "--feature_type", + type=str, + default="gene_expression", + help="Type of transcriptome feature to use (default: gene_expression)", + ) + args = parser.parse_args() + + dataset_name = args.dataset_name + n_components = args.n_components + data_dir = Path(args.data_path).resolve() + feature_type = args.feature_type + + # Input file: gene expression CSV + input_file = data_dir / dataset_name / f"{feature_type}.csv" + # Output files: PCA features CSV and fitted PCA/scaler objects + output_file = data_dir / dataset_name / f"cell_line_{feature_type}_pca_{n_components}.csv" + pca_file = data_dir / dataset_name / f"cell_line_{feature_type}_pca_{n_components}_pca.pkl" + scaler_file = data_dir / dataset_name / f"cell_line_{feature_type}_pca_{n_components}_scaler.pkl" + + if not input_file.exists(): + raise FileNotFoundError(f"Error: {input_file} not found.") + + print(f"Loading transcriptome data from {input_file}...") + # Load gene expression data + # Format: rows are cell lines (indexed by cell_line_name), columns are genes + ge_df = pd.read_csv(input_file, index_col=CELL_LINE_IDENTIFIER) + ge_df.index = ge_df.index.astype(str) + + # Drop cellosaurus_id if present + if "cellosaurus_id" in ge_df.columns: + ge_df = ge_df.drop(columns=["cellosaurus_id"]) + + print(f"Loaded {len(ge_df)} cell lines with {len(ge_df.columns)} genes") + print(f"Performing PCA with {n_components} components...") + + # Extract cell line IDs and gene expression matrix + cell_line_ids = ge_df.index.values + 
gene_expression_matrix = ge_df.values.astype(np.float32) + + # Handle missing values: fill with 0 or mean (using 0 as default) + if np.isnan(gene_expression_matrix).any(): + print("Warning: Found NaN values. Filling with 0.") + gene_expression_matrix = np.nan_to_num(gene_expression_matrix, nan=0.0) + + # Standardize the data before PCA + scaler = StandardScaler() + gene_expression_scaled = scaler.fit_transform(gene_expression_matrix) + + # Perform PCA + pca = PCA(n_components=n_components) + pca_features = pca.fit_transform(gene_expression_scaled) + + print(f"PCA explained variance ratio: {pca.explained_variance_ratio_.sum():.4f}") + print(f"PCA explained variance (first 10 components): {pca.explained_variance_ratio_[:10]}") + + # Create output DataFrame + pca_df = pd.DataFrame( + pca_features, + index=cell_line_ids, + columns=[f"PC{i + 1}" for i in range(n_components)], + ) + pca_df.index.name = CELL_LINE_IDENTIFIER + pca_df = pca_df.reset_index() + + # Save PCA-transformed features + pca_df.to_csv(output_file, index=False) + print(f"PCA features saved to {output_file}") + + # Save fitted PCA and scaler for potential future use (e.g., transforming new data) + joblib.dump(pca, pca_file) + print(f"Fitted PCA model saved to {pca_file}") + + joblib.dump(scaler, scaler_file) + print(f"Fitted scaler saved to {scaler_file}") + + print("Finished processing transcriptome PCA featurization.") + + +if __name__ == "__main__": + main() diff --git a/tests/test_featurizers.py b/tests/test_featurizers.py index a2db9b69..08d78368 100644 --- a/tests/test_featurizers.py +++ b/tests/test_featurizers.py @@ -3,6 +3,7 @@ import sys from unittest.mock import patch +import numpy as np import pandas as pd import torch @@ -123,6 +124,7 @@ def test_bpe_smiles_featurizer(tmp_path): except ImportError: print("subword-nmt package not installed; skipping BPE SMILES featurizer test.") return + dataset = "testset" data_dir = tmp_path / dataset data_dir.mkdir(parents=True) @@ -155,3 +157,61 @@ def 
test_bpe_smiles_featurizer(tmp_path): assert len(feature_cols) == 128 # Values should be numeric (character ordinals, may be stored as float in CSV) assert pd.api.types.is_numeric_dtype(df_out[feature_cols[0]]) + + +def test_transcriptome_pca_featurizer(tmp_path): + """ + Test transcriptome PCA featurizer end-to-end. + + :param tmp_path: Temporary path provided by pytest. + """ + try: + import drevalpy.datasets.featurizer.create_transcriptome_pca as pca_feat + except ImportError: + print("sklearn package not installed; skipping transcriptome PCA featurizer test.") + return + + dataset = "testset" + data_dir = tmp_path / dataset + data_dir.mkdir(parents=True) + + # Create fake gene expression CSV + # Format: rows are cell lines, columns are genes + n_cell_lines = 10 + n_genes = 100 + cell_line_names = [f"CL{i}" for i in range(n_cell_lines)] + gene_names = [f"GENE{i}" for i in range(n_genes)] + + # Generate some random gene expression data + np.random.seed(42) + ge_data = np.random.randn(n_cell_lines, n_genes).astype(np.float32) + + ge_df = pd.DataFrame(ge_data, index=cell_line_names, columns=gene_names) + ge_df.index.name = "cell_line_name" + ge_df = ge_df.reset_index() + + (data_dir / "gene_expression.csv").write_text(ge_df.to_csv(index=False)) + + # Run the featurizer + with patch.object( + sys, + "argv", + ["prog", dataset, "--data_path", str(tmp_path), "--n_components", "10"], + ): + pca_feat.main() + + # Check output files + output_file = data_dir / "cell_line_gene_expression_pca_10.csv" + pca_file = data_dir / "cell_line_gene_expression_pca_10_pca.pkl" + scaler_file = data_dir / "cell_line_gene_expression_pca_10_scaler.pkl" + + assert output_file.exists() + assert pca_file.exists() + assert scaler_file.exists() + + # Verify output CSV structure + df_out = pd.read_csv(output_file) + assert "cell_line_name" in df_out.columns + assert len(df_out.columns) == 11 # cell_line_name + 10 PC columns + assert len(df_out) == n_cell_lines + assert all(f"PC{i + 1}" in 
df_out.columns for i in range(10)) From d98552fd4b4266c0dc6b12a81c143af6b13fd021 Mon Sep 17 00:00:00 2001 From: nictru Date: Fri, 9 Jan 2026 12:15:00 +0000 Subject: [PATCH 09/15] Add PCA neural network --- .../SimpleNeuralNetwork/hyperparameters.yaml | 19 +++ .../simple_neural_network.py | 48 +++++- drevalpy/models/__init__.py | 4 +- tests/models/test_global_models.py | 143 ++++++++++++++++++ 4 files changed, 210 insertions(+), 4 deletions(-) diff --git a/drevalpy/models/SimpleNeuralNetwork/hyperparameters.yaml b/drevalpy/models/SimpleNeuralNetwork/hyperparameters.yaml index 721991f2..47440d56 100644 --- a/drevalpy/models/SimpleNeuralNetwork/hyperparameters.yaml +++ b/drevalpy/models/SimpleNeuralNetwork/hyperparameters.yaml @@ -62,3 +62,22 @@ ChemBERTaNeuralNetwork: - 16 max_epochs: - 100 + +PCANeuralNetwork: + dropout_prob: + - 0.3 + units_per_layer: + - - 32 + - 16 + - 8 + - 4 + - - 128 + - 64 + - 32 + - - 64 + - 64 + - 32 + n_components: + - 100 + max_epochs: + - 100 diff --git a/drevalpy/models/SimpleNeuralNetwork/simple_neural_network.py b/drevalpy/models/SimpleNeuralNetwork/simple_neural_network.py index 077e69ff..837e02a9 100644 --- a/drevalpy/models/SimpleNeuralNetwork/simple_neural_network.py +++ b/drevalpy/models/SimpleNeuralNetwork/simple_neural_network.py @@ -1,4 +1,4 @@ -"""Contains the SimpleNeuralNetwork and the ChemBERTaNeuralNetwork model.""" +"""Contains the SimpleNeuralNetwork, ChemBERTaNeuralNetwork, and PCANeuralNetwork models.""" import json import os @@ -12,6 +12,7 @@ from sklearn.preprocessing import StandardScaler from drevalpy.datasets.dataset import DrugResponseDataset, FeatureDataset +from drevalpy.datasets.utils import CELL_LINE_IDENTIFIER from ..drp_model import DRPModel from ..utils import load_and_select_gene_features, load_drug_fingerprint_features, scale_gene_expression @@ -90,7 +91,7 @@ def train( gene_expression_scaler=self.gene_expression_scaler, ) - dim_gex = 
next(iter(cell_line_input.features.values()))["gene_expression"].shape[0] + dim_gex = next(iter(cell_line_input.features.values()))[self.cell_line_views[0]].shape[0] dim_fingerprint = next(iter(drug_input.features.values()))[self.drug_views[0]].shape[0] self.hyperparameters["input_dim_gex"] = dim_gex self.hyperparameters["input_dim_fp"] = dim_fingerprint @@ -159,7 +160,7 @@ def predict( ) x = self.get_concatenated_features( - cell_line_view="gene_expression", + cell_line_view=self.cell_line_views[0], drug_view=self.drug_views[0], cell_line_ids_output=cell_line_ids, drug_ids_output=drug_ids, @@ -292,3 +293,44 @@ def load_drug_features(self, data_path: str, dataset_name: str) -> FeatureDatase features[drug_id] = {"chemberta_embeddings": embedding} return FeatureDataset(features) + + +class PCANeuralNetwork(SimpleNeuralNetwork): + """Neural Network model using PCA-transformed gene expression and fingerprints.""" + + cell_line_views = ["gene_expression_pca"] + + @classmethod + def get_model_name(cls) -> str: + """ + Returns the model name. + + :returns: PCANeuralNetwork + """ + return "PCANeuralNetwork" + + def load_cell_line_features(self, data_path: str, dataset_name: str) -> FeatureDataset: + """ + Loads the PCA-transformed gene expression features. + + :param data_path: Path to the data, e.g., data/ + :param dataset_name: name of the dataset, e.g., GDSC1 + :returns: FeatureDataset containing the PCA features + :raises FileNotFoundError: if the PCA features file is not found + """ + n_components = self.hyperparameters.get("n_components", 100) + pca_file = os.path.join(data_path, dataset_name, f"cell_line_gene_expression_pca_{n_components}.csv") + if not os.path.exists(pca_file): + raise FileNotFoundError( + f"PCA features file not found: {pca_file}. " + f"Please create it first with create_transcriptome_pca.py using --n_components {n_components}." 
+ ) + + pca_df = pd.read_csv(pca_file, dtype={CELL_LINE_IDENTIFIER: str}) + features = {} + for _, row in pca_df.iterrows(): + cell_line_id = row[CELL_LINE_IDENTIFIER] + embedding = row.drop(CELL_LINE_IDENTIFIER).to_numpy(dtype=np.float32) + features[cell_line_id] = {"gene_expression_pca": embedding} + + return FeatureDataset(features) diff --git a/drevalpy/models/__init__.py b/drevalpy/models/__init__.py index 5ecf2e4f..f8b934b8 100644 --- a/drevalpy/models/__init__.py +++ b/drevalpy/models/__init__.py @@ -30,6 +30,7 @@ "DrugGNN", "ChemBERTaNeuralNetwork", "PharmaFormerModel", + "PCANeuralNetwork", ] from .baselines.multi_omics_random_forest import MultiOmicsRandomForest @@ -57,7 +58,7 @@ from .MOLIR.molir import MOLIR from .PharmaFormer.pharmaformer import PharmaFormerModel from .SimpleNeuralNetwork.multiomics_neural_network import MultiOmicsNeuralNetwork -from .SimpleNeuralNetwork.simple_neural_network import ChemBERTaNeuralNetwork, SimpleNeuralNetwork +from .SimpleNeuralNetwork.simple_neural_network import ChemBERTaNeuralNetwork, PCANeuralNetwork, SimpleNeuralNetwork from .SRMF.srmf import SRMF from .SuperFELTR.superfeltr import SuperFELTR @@ -93,6 +94,7 @@ "DrugGNN": DrugGNN, "ChemBERTaNeuralNetwork": ChemBERTaNeuralNetwork, "PharmaFormer": PharmaFormerModel, + "PCANeuralNetwork": PCANeuralNetwork, } # MODEL_FACTORY is used in the pipeline! 
diff --git a/tests/models/test_global_models.py b/tests/models/test_global_models.py index 69e20b2b..535b98e6 100644 --- a/tests/models/test_global_models.py +++ b/tests/models/test_global_models.py @@ -5,9 +5,13 @@ from typing import cast import numpy as np +import pandas as pd import pytest +from sklearn.decomposition import PCA +from sklearn.preprocessing import StandardScaler from drevalpy.datasets.dataset import DrugResponseDataset +from drevalpy.datasets.utils import CELL_LINE_IDENTIFIER from drevalpy.evaluation import evaluate from drevalpy.experiment import cross_study_prediction from drevalpy.models import MODEL_FACTORY @@ -156,3 +160,142 @@ def test_global_models( split_index=0, single_drug_id=None, ) + + +def test_pca_neural_network( + sample_dataset: DrugResponseDataset, + cross_study_dataset: DrugResponseDataset, +) -> None: + """ + Test PCANeuralNetwork model. + + This test creates PCA features from gene expression data, then trains and evaluates + the PCANeuralNetwork model. 
+ + :param sample_dataset: from conftest.py + :param cross_study_dataset: from conftest.py + :raises ValueError: if drug input is None + """ + test_mode = "LTO" + model_name = "PCANeuralNetwork" + n_components = 10 # Use small number for testing + + drug_response = sample_dataset + drug_response.split_dataset(n_cv_splits=2, mode=test_mode, validation_ratio=0.4) + assert drug_response.cv_splits is not None + split = drug_response.cv_splits[0] + train_dataset = split["train"] + val_es_dataset = split["validation_es"] + es_dataset = split["early_stopping"] + + path_data = os.path.join("..", "data") + + # Create PCA features from gene expression data + ge_file = os.path.join(path_data, "TOYv1", "gene_expression.csv") + ge_df = pd.read_csv(ge_file, index_col=CELL_LINE_IDENTIFIER) + ge_df.index = ge_df.index.astype(str) + if "cellosaurus_id" in ge_df.columns: + ge_df = ge_df.drop(columns=["cellosaurus_id"]) + + # Perform PCA + scaler = StandardScaler() + ge_scaled = scaler.fit_transform(ge_df.values) + pca = PCA(n_components=n_components) + pca_features = pca.fit_transform(ge_scaled) + + # Save PCA features + pca_df = pd.DataFrame( + pca_features, + index=ge_df.index, + columns=[f"PC{i + 1}" for i in range(n_components)], + ) + pca_df.index.name = CELL_LINE_IDENTIFIER + pca_df = pca_df.reset_index() + + pca_output_file = os.path.join(path_data, "TOYv1", f"cell_line_gene_expression_pca_{n_components}.csv") + pca_df.to_csv(pca_output_file, index=False) + + try: + # Load model and features - need to build_model first to set n_components in hyperparameters + model_class = cast(type[DRPModel], MODEL_FACTORY[model_name]) + model = model_class() + hpams = model.get_hyperparameter_set() + hpam_combi = hpams[0] + hpam_combi["units_per_layer"] = [2, 2] + hpam_combi["max_epochs"] = 1 + hpam_combi["n_components"] = n_components + model.build_model(hyperparameters=hpam_combi) + + cell_line_input = model.load_cell_line_features(data_path=path_data, dataset_name="TOYv1") + drug_input = 
model.load_drug_features(data_path=path_data, dataset_name="TOYv1") + if drug_input is None: + raise ValueError("Drug input is None") + + cell_lines_to_keep = cell_line_input.identifiers + drugs_to_keep = drug_input.identifiers + + train_dataset.reduce_to(cell_line_ids=cell_lines_to_keep, drug_ids=drugs_to_keep) + val_es_dataset.reduce_to(cell_line_ids=cell_lines_to_keep, drug_ids=drugs_to_keep) + es_dataset.reduce_to(cell_line_ids=cell_lines_to_keep, drug_ids=drugs_to_keep) + + with tempfile.TemporaryDirectory() as tmpdirname: + model.train( + output=train_dataset, + cell_line_input=cell_line_input, + drug_input=drug_input, + output_earlystopping=es_dataset, + model_checkpoint_dir=tmpdirname, + ) + + prediction_dataset = val_es_dataset + prediction_dataset._predictions = model.predict( + drug_ids=prediction_dataset.drug_ids, + cell_line_ids=prediction_dataset.cell_line_ids, + drug_input=drug_input, + cell_line_input=cell_line_input, + ) + + # Save and load test + with tempfile.TemporaryDirectory() as model_dir: + model.save(model_dir) + loaded_model = model_class.load(model_dir) + assert isinstance(loaded_model, DRPModel) + + preds_before = model.predict( + drug_ids=prediction_dataset.drug_ids, + cell_line_ids=prediction_dataset.cell_line_ids, + drug_input=drug_input, + cell_line_input=cell_line_input, + ) + preds_after = loaded_model.predict( + drug_ids=prediction_dataset.drug_ids, + cell_line_ids=prediction_dataset.cell_line_ids, + drug_input=drug_input, + cell_line_input=cell_line_input, + ) + + assert preds_before.shape == preds_after.shape + assert isinstance(preds_after, np.ndarray) + + metrics = evaluate(prediction_dataset, metric=["Pearson"]) + print(f"Model: {model_name}, Pearson: {metrics['Pearson']}") + assert metrics["Pearson"] >= -1.0 + + with tempfile.TemporaryDirectory() as temp_dir: + print(f"Running cross-study prediction for {model_name}") + cross_study_prediction( + dataset=cross_study_dataset, + model=model, + test_mode=test_mode, + 
train_dataset=train_dataset, + path_data=path_data, + early_stopping_dataset=None, + response_transformation=None, + path_out=temp_dir, + split_index=0, + single_drug_id=None, + ) + finally: + # Clean up the generated PCA file + if os.path.exists(pca_output_file): + os.remove(pca_output_file) From 23a84265092c5ea1d615903b43dbb37f2de54386 Mon Sep 17 00:00:00 2001 From: nictru Date: Fri, 9 Jan 2026 15:18:07 +0000 Subject: [PATCH 10/15] Improve featurizer structure --- drevalpy/datasets/featurizer/__init__.py | 52 + .../datasets/featurizer/cell_line/__init__.py | 9 + .../datasets/featurizer/cell_line/base.py | 259 +++++ drevalpy/datasets/featurizer/cell_line/pca.py | 200 ++++ .../create_chemberta_drug_embeddings.py | 85 -- .../datasets/featurizer/create_drug_graphs.py | 145 --- .../featurizer/create_molgnet_embeddings.py | 917 ------------------ .../featurizer/create_transcriptome_pca.py | 114 --- drevalpy/datasets/featurizer/drug/__init__.py | 13 + drevalpy/datasets/featurizer/drug/base.py | 193 ++++ .../datasets/featurizer/drug/chemberta.py | 101 ++ .../datasets/featurizer/drug/drug_graph.py | 220 +++++ drevalpy/datasets/featurizer/drug/molgnet.py | 817 ++++++++++++++++ drevalpy/models/DrugGNN/drug_gnn.py | 2 +- .../simple_neural_network.py | 45 +- tests/test_featurizers.py | 315 +++++- 16 files changed, 2150 insertions(+), 1337 deletions(-) create mode 100644 drevalpy/datasets/featurizer/__init__.py create mode 100644 drevalpy/datasets/featurizer/cell_line/__init__.py create mode 100644 drevalpy/datasets/featurizer/cell_line/base.py create mode 100644 drevalpy/datasets/featurizer/cell_line/pca.py delete mode 100644 drevalpy/datasets/featurizer/create_chemberta_drug_embeddings.py delete mode 100644 drevalpy/datasets/featurizer/create_drug_graphs.py delete mode 100644 drevalpy/datasets/featurizer/create_molgnet_embeddings.py delete mode 100644 drevalpy/datasets/featurizer/create_transcriptome_pca.py create mode 100644 drevalpy/datasets/featurizer/drug/__init__.py 
create mode 100644 drevalpy/datasets/featurizer/drug/base.py create mode 100644 drevalpy/datasets/featurizer/drug/chemberta.py create mode 100644 drevalpy/datasets/featurizer/drug/drug_graph.py create mode 100644 drevalpy/datasets/featurizer/drug/molgnet.py diff --git a/drevalpy/datasets/featurizer/__init__.py b/drevalpy/datasets/featurizer/__init__.py new file mode 100644 index 00000000..3a178ca9 --- /dev/null +++ b/drevalpy/datasets/featurizer/__init__.py @@ -0,0 +1,52 @@ +"""Featurizers for converting drug and cell line data to embeddings. + +This module provides abstract base classes and concrete implementations for +featurizing drugs and cell lines for drug response prediction models. + +Drug Featurizers: + - DrugFeaturizer: Abstract base class for drug featurizers + - ChemBERTaFeaturizer: ChemBERTa transformer embeddings from SMILES + - DrugGraphFeaturizer: Molecular graph representations + - MolGNetFeaturizer: MolGNet graph neural network embeddings + +Cell Line Featurizers: + - CellLineFeaturizer: Abstract base class for cell line featurizers + - PCAFeaturizer: PCA dimensionality reduction for omics data + +Example usage:: + + from drevalpy.datasets.featurizer import ChemBERTaFeaturizer, PCAFeaturizer + + # Drug features + drug_featurizer = ChemBERTaFeaturizer(device="cuda") + drug_features = drug_featurizer.load_or_generate("data", "GDSC1") + + # Cell line features + cell_featurizer = PCAFeaturizer(n_components=100) + cell_features = cell_featurizer.load_or_generate("data", "GDSC1") +""" + +# Cell line featurizers +from .cell_line import ( + CellLineFeaturizer, + PCAFeaturizer, +) + +# Drug featurizers +from .drug import ( + ChemBERTaFeaturizer, + DrugFeaturizer, + DrugGraphFeaturizer, + MolGNetFeaturizer, +) + +__all__ = [ + # Drug featurizers + "DrugFeaturizer", + "ChemBERTaFeaturizer", + "DrugGraphFeaturizer", + "MolGNetFeaturizer", + # Cell line featurizers + "CellLineFeaturizer", + "PCAFeaturizer", +] diff --git 
a/drevalpy/datasets/featurizer/cell_line/__init__.py b/drevalpy/datasets/featurizer/cell_line/__init__.py new file mode 100644 index 00000000..47a5f613 --- /dev/null +++ b/drevalpy/datasets/featurizer/cell_line/__init__.py @@ -0,0 +1,9 @@ +"""Cell line featurizers for converting omics data to embeddings.""" + +from .base import CellLineFeaturizer +from .pca import PCAFeaturizer + +__all__ = [ + "CellLineFeaturizer", + "PCAFeaturizer", +] diff --git a/drevalpy/datasets/featurizer/cell_line/base.py b/drevalpy/datasets/featurizer/cell_line/base.py new file mode 100644 index 00000000..a441ed32 --- /dev/null +++ b/drevalpy/datasets/featurizer/cell_line/base.py @@ -0,0 +1,259 @@ +"""Abstract base class for cell line featurizers.""" + +from abc import ABC, abstractmethod +from pathlib import Path +from typing import Any + +import numpy as np +import pandas as pd + +from drevalpy.datasets.dataset import FeatureDataset +from drevalpy.datasets.utils import CELL_LINE_IDENTIFIER + + +class CellLineFeaturizer(ABC): + """Abstract base class for cell line featurizers. + + Cell line featurizers convert omics data (e.g., gene expression, methylation) + into numerical embeddings that can be used as input features for machine learning models. + + Supports both single-omics and multi-omics featurization through the `omics_types` + parameter. 
+ + Subclasses must implement: + - featurize(): Convert omics data for a single cell line to its embedding + - get_feature_name(): Return the name of the feature view + - get_output_filename(): Return the filename for cached embeddings + + The base class provides: + - load_or_generate(): Load cached embeddings or generate and cache them + - generate_embeddings(): Generate embeddings for all cell lines in a dataset + - load_embeddings(): Load pre-generated embeddings from disk + """ + + # Supported omics types and their corresponding file names + OMICS_FILE_MAPPING = { + "gene_expression": "gene_expression.csv", + "methylation": "methylation.csv", + "mutations": "mutations.csv", + "copy_number_variation": "copy_number_variation.csv", + } + + def __init__(self, omics_types: list[str] | str = "gene_expression"): + """Initialize the featurizer. + + :param omics_types: Single omics type or list of omics types to use. + Supported types: 'gene_expression', 'methylation', + 'mutations', 'copy_number_variation' + :raises ValueError: If an unsupported omics type is provided + """ + if isinstance(omics_types, str): + omics_types = [omics_types] + + for omics_type in omics_types: + if omics_type not in self.OMICS_FILE_MAPPING: + raise ValueError( + f"Unsupported omics type: {omics_type}. " f"Supported types: {list(self.OMICS_FILE_MAPPING.keys())}" + ) + + self.omics_types = omics_types + + @abstractmethod + def featurize(self, omics_data: dict[str, np.ndarray]) -> np.ndarray | Any: + """Convert omics data to a feature representation. + + :param omics_data: Dictionary mapping omics type to data array for a single cell line + :returns: Feature representation (numpy array or other format) + """ + + @classmethod + @abstractmethod + def get_feature_name(cls) -> str: + """Return the name of the feature view. + + This name is used as the key in the FeatureDataset. 
+ + :returns: Feature view name (e.g., 'gene_expression_pca') + """ + + @abstractmethod + def get_output_filename(self) -> str: + """Return the filename for cached embeddings. + + Note: This is an instance method (not classmethod) because the filename + may depend on featurizer parameters (e.g., n_components for PCA). + + :returns: Filename (e.g., 'cell_line_gene_expression_pca_100.csv') + """ + + def load_or_generate(self, data_path: str, dataset_name: str) -> FeatureDataset: + """Load cached embeddings or generate and cache them if not available. + + This is the main entry point for using a featurizer. It checks if + pre-generated embeddings exist and loads them, otherwise generates + new embeddings and saves them for future use. + + :param data_path: Path to the data directory (e.g., 'data/') + :param dataset_name: Name of the dataset (e.g., 'GDSC1') + :returns: FeatureDataset containing the cell line embeddings + """ + output_path = Path(data_path) / dataset_name / self.get_output_filename() + + if output_path.exists(): + return self.load_embeddings(data_path, dataset_name) + else: + print(f"Embeddings not found at {output_path}. Generating...") + return self.generate_embeddings(data_path, dataset_name) + + def _load_omics_data(self, data_path: str, dataset_name: str) -> dict[str, pd.DataFrame]: + """Load omics data files for the specified omics types. + + :param data_path: Path to the data directory + :param dataset_name: Name of the dataset + :returns: Dictionary mapping omics type to DataFrame + :raises FileNotFoundError: If any required omics file is not found + """ + data_dir = Path(data_path) / dataset_name + omics_data = {} + + for omics_type in self.omics_types: + filename = self.OMICS_FILE_MAPPING[omics_type] + filepath = data_dir / filename + + if not filepath.exists(): + raise FileNotFoundError( + f"Omics data file not found: {filepath}. " f"Please ensure the {omics_type} data is available." 
+ ) + + df = pd.read_csv(filepath, dtype={CELL_LINE_IDENTIFIER: str}) + df = df.set_index(CELL_LINE_IDENTIFIER) + omics_data[omics_type] = df + + return omics_data + + def generate_embeddings(self, data_path: str, dataset_name: str) -> FeatureDataset: + """Generate embeddings for all cell lines in a dataset and save to disk. + + :param data_path: Path to the data directory + :param dataset_name: Name of the dataset + :returns: FeatureDataset containing the generated embeddings + """ + data_dir = Path(data_path).resolve() + output_file = data_dir / dataset_name / self.get_output_filename() + + # Load omics data + omics_data = self._load_omics_data(data_path, dataset_name) + + # Get common cell line IDs across all omics types + cell_line_ids = None + for _omics_type, df in omics_data.items(): + if cell_line_ids is None: + cell_line_ids = set(df.index) + else: + cell_line_ids = cell_line_ids.intersection(set(df.index)) + + cell_line_ids = sorted(list(cell_line_ids)) + print(f"Processing {len(cell_line_ids)} cell lines for dataset {dataset_name}...") + + # Generate embeddings + embeddings_list = [] + valid_cell_line_ids = [] + + for cell_line_id in cell_line_ids: + try: + # Prepare omics data for this cell line + cell_omics = {} + for omics_type, df in omics_data.items(): + cell_omics[omics_type] = df.loc[cell_line_id].to_numpy(dtype=np.float32) + + embedding = self.featurize(cell_omics) + embeddings_list.append(embedding) + valid_cell_line_ids.append(cell_line_id) + except Exception as e: + print(f"Failed to process cell line {cell_line_id}: {e}") + continue + + # Save embeddings + self._save_embeddings(embeddings_list, valid_cell_line_ids, output_file, omics_data) + + print(f"Embeddings saved to {output_file}") + + # Return as FeatureDataset + return self._create_feature_dataset(embeddings_list, valid_cell_line_ids) + + def _save_embeddings( + self, + embeddings: list, + cell_line_ids: list[str], + output_path: Path, + omics_data: dict[str, pd.DataFrame] | None = 
None, + ) -> None: + """Save embeddings to disk. + + Default implementation saves as CSV. Subclasses can override for other formats. + + :param embeddings: List of embedding arrays + :param cell_line_ids: List of cell line identifiers + :param output_path: Path to save the embeddings + :param omics_data: Optional omics data (may be used by subclasses for saving fitted models) + """ + embeddings_df = pd.DataFrame(embeddings) + embeddings_df.insert(0, CELL_LINE_IDENTIFIER, cell_line_ids) + embeddings_df.to_csv(output_path, index=False) + + def _create_feature_dataset(self, embeddings: list, cell_line_ids: list[str]) -> FeatureDataset: + """Create a FeatureDataset from embeddings. + + :param embeddings: List of embedding arrays + :param cell_line_ids: List of cell line identifiers + :returns: FeatureDataset containing the embeddings + """ + feature_name = self.get_feature_name() + features = {} + for cell_line_id, embedding in zip(cell_line_ids, embeddings, strict=True): + if isinstance(embedding, np.ndarray): + features[cell_line_id] = {feature_name: embedding.astype(np.float32)} + else: + features[cell_line_id] = {feature_name: embedding} + return FeatureDataset(features) + + def load_embeddings(self, data_path: str, dataset_name: str) -> FeatureDataset: + """Load pre-generated embeddings from disk. + + :param data_path: Path to the data directory + :param dataset_name: Name of the dataset + :returns: FeatureDataset containing the embeddings + :raises FileNotFoundError: If the embeddings file is not found + """ + embeddings_file = Path(data_path) / dataset_name / self.get_output_filename() + + if not embeddings_file.exists(): + raise FileNotFoundError( + f"Embeddings file not found: {embeddings_file}. " + f"Use load_or_generate() to automatically generate embeddings." 
+ ) + + embeddings_df = pd.read_csv(embeddings_file, dtype={CELL_LINE_IDENTIFIER: str}) + feature_name = self.get_feature_name() + features = {} + + for _, row in embeddings_df.iterrows(): + cell_line_id = row[CELL_LINE_IDENTIFIER] + embedding = row.drop(CELL_LINE_IDENTIFIER).to_numpy(dtype=np.float32) + features[cell_line_id] = {feature_name: embedding} + + return FeatureDataset(features) + + +def main(): + """Entry point for running featurizer from command line. + + This function should be overridden by subclasses that support CLI usage. + + :raises NotImplementedError: Always, as subclasses should implement their own main() + """ + raise NotImplementedError("Subclasses should implement their own main() function") + + +if __name__ == "__main__": + main() diff --git a/drevalpy/datasets/featurizer/cell_line/pca.py b/drevalpy/datasets/featurizer/cell_line/pca.py new file mode 100644 index 00000000..25b095ad --- /dev/null +++ b/drevalpy/datasets/featurizer/cell_line/pca.py @@ -0,0 +1,200 @@ +"""PCA featurizer for cell line gene expression data.""" + +import argparse +import pickle # noqa: S403 +from pathlib import Path + +import numpy as np +import pandas as pd +from sklearn.decomposition import PCA +from sklearn.preprocessing import StandardScaler + +from drevalpy.datasets.dataset import FeatureDataset +from drevalpy.datasets.utils import CELL_LINE_IDENTIFIER + +from .base import CellLineFeaturizer + + +class PCAFeaturizer(CellLineFeaturizer): + """Featurizer that applies PCA to gene expression data. + + This featurizer standardizes gene expression data and applies PCA + to reduce dimensionality. It is designed specifically for transcriptomics + (gene expression) data. + + Example usage:: + + featurizer = PCAFeaturizer(n_components=100) + features = featurizer.load_or_generate("data", "GDSC1") + """ + + def __init__(self, n_components: int = 100): + """Initialize the PCA featurizer. 
+ + :param n_components: Number of principal components to keep + """ + super().__init__(omics_types="gene_expression") + self.n_components = n_components + self._scaler: StandardScaler | None = None + self._pca: PCA | None = None + self._fitted = False + + def featurize(self, omics_data: dict[str, np.ndarray]) -> np.ndarray: + """Apply PCA transformation to gene expression data. + + :param omics_data: Dictionary with 'gene_expression' key containing the data + :returns: PCA-transformed features + :raises RuntimeError: If the PCA model is not fitted + :raises ValueError: If gene_expression data is not provided + """ + if not self._fitted: + raise RuntimeError("PCA model is not fitted. Call generate_embeddings() or fit() first.") + + if "gene_expression" not in omics_data: + raise ValueError("gene_expression data is required for PCA featurizer") + + data = omics_data["gene_expression"].reshape(1, -1) + scaled = self._scaler.transform(data) + return self._pca.transform(scaled).flatten() + + def fit(self, gene_expression_df: pd.DataFrame) -> None: + """Fit the scaler and PCA model on gene expression data. + + :param gene_expression_df: DataFrame with cell lines as rows and genes as columns + """ + data = gene_expression_df.values + + self._scaler = StandardScaler() + scaled_data = self._scaler.fit_transform(data) + + n_components = min(self.n_components, min(scaled_data.shape)) + self._pca = PCA(n_components=n_components) + self._pca.fit(scaled_data) + + self._fitted = True + + @classmethod + def get_feature_name(cls) -> str: + """Return the feature view name. + + :returns: 'gene_expression_pca' + """ + return "gene_expression_pca" + + def get_output_filename(self) -> str: + """Return the output filename for cached embeddings. + + :returns: Filename like 'cell_line_gene_expression_pca_100.csv' + """ + return f"cell_line_gene_expression_pca_{self.n_components}.csv" + + def _get_model_filename(self) -> str: + """Return the filename for the fitted model. 
+ + :returns: Filename like 'cell_line_gene_expression_pca_100_models.pkl' + """ + return f"cell_line_gene_expression_pca_{self.n_components}_models.pkl" + + def generate_embeddings(self, data_path: str, dataset_name: str) -> FeatureDataset: + """Generate PCA embeddings for all cell lines and save to disk. + + :param data_path: Path to the data directory + :param dataset_name: Name of the dataset + :returns: FeatureDataset containing the PCA embeddings + :raises FileNotFoundError: If the gene expression file is not found + """ + data_dir = Path(data_path).resolve() + output_file = data_dir / dataset_name / self.get_output_filename() + model_file = data_dir / dataset_name / self._get_model_filename() + + # Load gene expression data + ge_file = data_dir / dataset_name / "gene_expression.csv" + if not ge_file.exists(): + raise FileNotFoundError(f"Gene expression file not found: {ge_file}") + + ge_df = pd.read_csv(ge_file, dtype={CELL_LINE_IDENTIFIER: str}) + ge_df = ge_df.set_index(CELL_LINE_IDENTIFIER) + + cell_line_ids = list(ge_df.index) + print(f"Processing {len(cell_line_ids)} cell lines for dataset {dataset_name}...") + + # Fit the model + self.fit(ge_df) + + # Transform all cell lines + scaled_data = self._scaler.transform(ge_df.values) + embeddings = self._pca.transform(scaled_data) + + # Save embeddings + embeddings_df = pd.DataFrame(embeddings) + embeddings_df.insert(0, CELL_LINE_IDENTIFIER, cell_line_ids) + embeddings_df.to_csv(output_file, index=False) + + # Save fitted models + with open(model_file, "wb") as f: + pickle.dump({"scaler": self._scaler, "pca": self._pca}, f) + + print(f"Embeddings saved to {output_file}") + print(f"Fitted models saved to {model_file}") + + # Return as FeatureDataset + return self._create_feature_dataset(list(embeddings), cell_line_ids) + + def load_embeddings(self, data_path: str, dataset_name: str) -> FeatureDataset: + """Load pre-generated PCA embeddings from disk. 
+ + Also loads the fitted scaler and PCA model for future transformations. + + :param data_path: Path to the data directory + :param dataset_name: Name of the dataset + :returns: FeatureDataset containing the embeddings + :raises FileNotFoundError: If the embeddings or model file is not found + """ + embeddings_file = Path(data_path) / dataset_name / self.get_output_filename() + model_file = Path(data_path) / dataset_name / self._get_model_filename() + + if not embeddings_file.exists(): + raise FileNotFoundError( + f"Embeddings file not found: {embeddings_file}. " + f"Use load_or_generate() to automatically generate embeddings." + ) + + # Load fitted models if available + if model_file.exists(): + with open(model_file, "rb") as f: + models = pickle.load(f) # noqa: S301 + self._scaler = models["scaler"] + self._pca = models["pca"] + self._fitted = True + else: + raise FileNotFoundError( + f"Fitted model file not found: {model_file}. " f"Use generate_embeddings() to fit and save model." + ) + + # Load embeddings + embeddings_df = pd.read_csv(embeddings_file, dtype={CELL_LINE_IDENTIFIER: str}) + feature_name = self.get_feature_name() + features = {} + + for _, row in embeddings_df.iterrows(): + cell_line_id = row[CELL_LINE_IDENTIFIER] + embedding = row.drop(CELL_LINE_IDENTIFIER).to_numpy(dtype=np.float32) + features[cell_line_id] = {feature_name: embedding} + + return FeatureDataset(features) + + +def main(): + """Generate PCA embeddings for cell line gene expression from command line.""" + parser = argparse.ArgumentParser(description="Generate PCA embeddings for cell line gene expression.") + parser.add_argument("dataset_name", type=str, help="The name of the dataset to process.") + parser.add_argument("--data_path", type=str, default="data", help="Path to the data folder") + parser.add_argument("--n_components", type=int, default=100, help="Number of PCA components") + args = parser.parse_args() + + featurizer = PCAFeaturizer(n_components=args.n_components) + 
featurizer.generate_embeddings(args.data_path, args.dataset_name) + + +if __name__ == "__main__": + main() diff --git a/drevalpy/datasets/featurizer/create_chemberta_drug_embeddings.py b/drevalpy/datasets/featurizer/create_chemberta_drug_embeddings.py deleted file mode 100644 index 3d19ff18..00000000 --- a/drevalpy/datasets/featurizer/create_chemberta_drug_embeddings.py +++ /dev/null @@ -1,85 +0,0 @@ -"""Preprocesses drug SMILES strings into ChemBERTa embeddings.""" - -import argparse -from pathlib import Path - -import pandas as pd -import torch -from tqdm import tqdm - -try: - from transformers import AutoModel, AutoTokenizer -except ImportError: - raise ImportError( - "Please install transformers package for ChemBERTa embedding featurizer: pip install transformers" - ) -# Load ChemBERTa -tokenizer = AutoTokenizer.from_pretrained("seyonec/ChemBERTa-zinc-base-v1") -model = AutoModel.from_pretrained("seyonec/ChemBERTa-zinc-base-v1") -model.eval() - - -def _smiles_to_chemberta(smiles: str, device="cpu"): - inputs = tokenizer(smiles, return_tensors="pt", truncation=True) - inputs = {k: v.to(device) for k, v in inputs.items()} - - with torch.no_grad(): - outputs = model(**inputs) - hidden_states = outputs.last_hidden_state - - embedding = hidden_states.mean(dim=1).squeeze(0) - return embedding.cpu().numpy() - - -def main(): - """Process drug SMILES and save ChemBERTa embeddings. - - :raises Exception: If a drug fails to process. 
- """ - parser = argparse.ArgumentParser(description="Preprocess drug SMILES to ChemBERTa embeddings.") - parser.add_argument("dataset_name", type=str, help="The name of the dataset to process.") - parser.add_argument("--device", type=str, default="cpu", help="Torch device (cpu or cuda)") - parser.add_argument("--data_path", type=str, default="data", help="Path to the data folder") - args = parser.parse_args() - - dataset_name = args.dataset_name - device = args.device - data_dir = Path(args.data_path).resolve() - - smiles_file = data_dir / dataset_name / "drug_smiles.csv" - output_file = data_dir / dataset_name / "drug_chemberta_embeddings.csv" - - if not smiles_file.exists(): - print(f"Error: {smiles_file} not found.") - return - - smiles_df = pd.read_csv(smiles_file, dtype={"canonical_smiles": str, "pubchem_id": str}) - embeddings_list = [] - drug_ids = [] - - print(f"Processing {len(smiles_df)} drugs for dataset {dataset_name}...") - - for row in tqdm(smiles_df.itertuples(index=False), total=len(smiles_df)): - drug_id = row.pubchem_id - smiles = row.canonical_smiles - - try: - embedding = _smiles_to_chemberta(smiles, device=device) - embeddings_list.append(embedding) - drug_ids.append(drug_id) - except Exception as e: - print() - print(smiles) - print() - print(f"Failed to process {drug_id}") - raise e - - embeddings_array = pd.DataFrame(embeddings_list) - embeddings_array.insert(0, "pubchem_id", drug_ids) - embeddings_array.to_csv(output_file, index=False) - - print(f"Finished processing. Embeddings saved to {output_file}") - - -if __name__ == "__main__": - main() diff --git a/drevalpy/datasets/featurizer/create_drug_graphs.py b/drevalpy/datasets/featurizer/create_drug_graphs.py deleted file mode 100644 index c79e09ca..00000000 --- a/drevalpy/datasets/featurizer/create_drug_graphs.py +++ /dev/null @@ -1,145 +0,0 @@ -""" -Preprocesses drug SMILES strings into graph representations. 
- -This script takes a dataset name as input, reads the corresponding -drug_smiles.csv file, and converts each SMILES string into a -torch_geometric.data.Data object. The resulting graph objects are saved -to {data_path}/{dataset_name}/drug_graphs/{drug_name}.pt. -""" - -import argparse -import os -from pathlib import Path - -import pandas as pd -import torch -from torch_geometric.data import Data -from tqdm import tqdm - -try: - from rdkit import Chem -except ImportError: - raise ImportError("Please install rdkit package for drug graphs featurizer: pip install rdkit") - -# Atom feature configuration -ATOM_FEATURES = { - "atomic_num": list(range(1, 119)), - "degree": [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10], - "formal_charge": [-5, -4, -3, -2, -1, 0, 1, 2, 3, 4, 5], - "num_hs": [0, 1, 2, 3, 4, 5, 6, 7, 8], - "hybridization": [ - Chem.rdchem.HybridizationType.SP, - Chem.rdchem.HybridizationType.SP2, - Chem.rdchem.HybridizationType.SP3, - Chem.rdchem.HybridizationType.SP3D, - Chem.rdchem.HybridizationType.SP3D2, - ], -} - -# Bond feature configuration -BOND_FEATURES = { - "bond_type": [ - Chem.rdchem.BondType.SINGLE, - Chem.rdchem.BondType.DOUBLE, - Chem.rdchem.BondType.TRIPLE, - Chem.rdchem.BondType.AROMATIC, - ] -} - - -def _one_hot_encode(value, choices): - """Create a one-hot encoding for a value in a list of choices. - - :param value: The value to be one-hot encoded. - :param choices: A list of possible choices for the value. - :return: A list representing the one-hot encoding. - """ - encoding = [0] * (len(choices) + 1) - index = choices.index(value) if value in choices else -1 - encoding[index] = 1 - return encoding - - -def _smiles_to_graph(smiles: str): - """ - Converts a SMILES string to a torch_geometric.data.Data object. - - :param smiles: The SMILES string for the drug. - :return: A Data object representing the molecular graph, or None if conversion fails. 
- """ - mol = Chem.MolFromSmiles(smiles) - if mol is None: - return None - - # Atom features - atom_features_list = [] - for atom in mol.GetAtoms(): - features = [] - features.extend(_one_hot_encode(atom.GetAtomicNum(), ATOM_FEATURES["atomic_num"])) - features.extend(_one_hot_encode(atom.GetDegree(), ATOM_FEATURES["degree"])) - features.extend(_one_hot_encode(atom.GetFormalCharge(), ATOM_FEATURES["formal_charge"])) - features.extend(_one_hot_encode(atom.GetTotalNumHs(), ATOM_FEATURES["num_hs"])) - features.extend(_one_hot_encode(atom.GetHybridization(), ATOM_FEATURES["hybridization"])) - features.append(atom.GetIsAromatic()) - features.append(atom.IsInRing()) - atom_features_list.append(features) - x = torch.tensor(atom_features_list, dtype=torch.float) - - # Edge index and edge features - edge_indices = [] - edge_features_list = [] - for bond in mol.GetBonds(): - i = bond.GetBeginAtomIdx() - j = bond.GetEndAtomIdx() - - # Edge features - features = [] - features.extend(_one_hot_encode(bond.GetBondType(), BOND_FEATURES["bond_type"])) - features.append(bond.GetIsConjugated()) - features.append(bond.IsInRing()) - - edge_indices.extend([[i, j], [j, i]]) - edge_features_list.extend([features, features]) # Same features for both directions - - edge_index = torch.tensor(edge_indices, dtype=torch.long).t().contiguous() - edge_attr = torch.tensor(edge_features_list, dtype=torch.float) - - return Data(x=x, edge_index=edge_index, edge_attr=edge_attr) - - -def main(): - """Main function to run the preprocessing.""" - parser = argparse.ArgumentParser(description="Preprocess drug SMILES to graphs.") - parser.add_argument("dataset_name", type=str, help="The name of the dataset to process.") - parser.add_argument("--data_path", type=str, default="data", help="Path to the data folder") - args = parser.parse_args() - - dataset_name = args.dataset_name - data_dir = Path(args.data_path).resolve() - smiles_file = data_dir / dataset_name / "drug_smiles.csv" - output_dir = data_dir / 
dataset_name / "drug_graphs" - - if not smiles_file.exists(): - print(f"Error: {smiles_file} not found.") - return - - os.makedirs(output_dir, exist_ok=True) - - smiles_df = pd.read_csv(smiles_file) - - print(f"Processing {len(smiles_df)} drugs for dataset {dataset_name}...") - - for _, row in tqdm(smiles_df.iterrows(), total=smiles_df.shape[0]): - drug_id = row["pubchem_id"] - smiles = row["canonical_smiles"] - - graph = _smiles_to_graph(smiles) - - if graph: - torch.save(graph, output_dir / f"{drug_id}.pt") - - print(f"Finished processing. Graphs saved to {output_dir}") - - -if __name__ == "__main__": - main() diff --git a/drevalpy/datasets/featurizer/create_molgnet_embeddings.py b/drevalpy/datasets/featurizer/create_molgnet_embeddings.py deleted file mode 100644 index 2d337e80..00000000 --- a/drevalpy/datasets/featurizer/create_molgnet_embeddings.py +++ /dev/null @@ -1,917 +0,0 @@ -#!/usr/bin/env python3 -"""MolGNet feature extraction utilities (needed for DIPK and adapted from the DIPK github). - -Creates MolGNet embeddings for molecules given their SMILES strings. This module needs torch_scatter. 
- python create_molgnet_embeddings.py dataset_name --checkpoint meta/MolGNet.pt --data_path data -""" - -import argparse -import math -import os -import pickle # noqa: S403 -from pathlib import Path -from typing import Any, Optional - -import numpy as np -import pandas as pd -import torch -import torch.nn.functional as torch_nn_f -from torch import nn -from torch.nn import Parameter -from torch_geometric.data import Data -from torch_geometric.utils import add_self_loops, softmax -from tqdm import tqdm - -try: - from rdkit import Chem - from rdkit.Chem.rdchem import Mol as RDMol -except ImportError: - raise ImportError("Please install rdkit package for MolGNet featurizer: pip install rdkit") - -# building graphs -allowable_features: dict[str, list[Any]] = { - "atomic_num": list(range(1, 122)), - "formal_charge": ["unk", -5, -4, -3, -2, -1, 0, 1, 2, 3, 4, 5], - "chirality": [ - "unk", - Chem.rdchem.ChiralType.CHI_UNSPECIFIED, - Chem.rdchem.ChiralType.CHI_TETRAHEDRAL_CW, - Chem.rdchem.ChiralType.CHI_TETRAHEDRAL_CCW, - Chem.rdchem.ChiralType.CHI_OTHER, - ], - "hybridization": [ - "unk", - Chem.rdchem.HybridizationType.S, - Chem.rdchem.HybridizationType.SP, - Chem.rdchem.HybridizationType.SP2, - Chem.rdchem.HybridizationType.SP3, - Chem.rdchem.HybridizationType.SP3D, - Chem.rdchem.HybridizationType.SP3D2, - Chem.rdchem.HybridizationType.UNSPECIFIED, - ], - "numH": ["unk", 0, 1, 2, 3, 4, 5, 6, 7, 8], - "implicit_valence": ["unk", 0, 1, 2, 3, 4, 5, 6], - "degree": ["unk", 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10], - "isaromatic": [False, True], - "bond_type": [ - "unk", - Chem.rdchem.BondType.SINGLE, - Chem.rdchem.BondType.DOUBLE, - Chem.rdchem.BondType.TRIPLE, - Chem.rdchem.BondType.AROMATIC, - ], - "bond_dirs": [ - Chem.rdchem.BondDir.NONE, - Chem.rdchem.BondDir.ENDUPRIGHT, - Chem.rdchem.BondDir.ENDDOWNRIGHT, - ], - "bond_isconjugated": [False, True], - "bond_inring": [False, True], - "bond_stereo": [ - "STEREONONE", - "STEREOANY", - "STEREOZ", - "STEREOE", - "STEREOCIS", - 
"STEREOTRANS", - ], -} - -atom_dic = [ - len(allowable_features["atomic_num"]), - len(allowable_features["formal_charge"]), - len(allowable_features["chirality"]), - len(allowable_features["hybridization"]), - len(allowable_features["numH"]), - len(allowable_features["implicit_valence"]), - len(allowable_features["degree"]), - len(allowable_features["isaromatic"]), -] -bond_dic = [ - len(allowable_features["bond_type"]), - len(allowable_features["bond_dirs"]), - len(allowable_features["bond_isconjugated"]), - len(allowable_features["bond_inring"]), - len(allowable_features["bond_stereo"]), -] -atom_cumsum = np.cumsum(atom_dic) -bond_cumsum = np.cumsum(bond_dic) - - -def mol_to_graph_data_obj_complex(mol: RDMol) -> Data: - """Convert an RDKit Mol into a torch_geometric ``Data`` object. - - The function encodes a fixed set of atom and bond categorical - features and returns a ``Data`` instance with ``x``, ``edge_index`` - and ``edge_attr`` fields. It mirrors the feature layout expected by - the MolGNet implementation used in this repository. - - :param mol: RDKit ``Mol`` instance. Must not be ``None``. - :return: A ``torch_geometric.data.Data`` object with node and edge fields. - :raises ValueError: If ``mol`` is ``None``. 
- """ - if mol is None: - raise ValueError("mol must not be None") - atom_features_list: list = [] - # Shortcuts for feature lists - fc_list = allowable_features["formal_charge"] - ch_list = allowable_features["chirality"] - hyb_list = allowable_features["hybridization"] - numh_list = allowable_features["numH"] - imp_list = allowable_features["implicit_valence"] - deg_list = allowable_features["degree"] - isa_list = allowable_features["isaromatic"] - bt_list = allowable_features["bond_type"] - bd_list = allowable_features["bond_dirs"] - bic_list = allowable_features["bond_isconjugated"] - bir_list = allowable_features["bond_inring"] - bs_list = allowable_features["bond_stereo"] - for atom in mol.GetAtoms(): - a_idx = allowable_features["atomic_num"].index(atom.GetAtomicNum()) - fc_idx = fc_list.index(atom.GetFormalCharge()) + atom_cumsum[0] - ch_idx = ch_list.index(atom.GetChiralTag()) + atom_cumsum[1] - hyb_idx = hyb_list.index(atom.GetHybridization()) + atom_cumsum[2] - numh_idx = numh_list.index(atom.GetTotalNumHs()) + atom_cumsum[3] - imp_idx = imp_list.index(atom.GetImplicitValence()) + atom_cumsum[4] - deg_idx = deg_list.index(atom.GetDegree()) + atom_cumsum[5] - isa_idx = isa_list.index(atom.GetIsAromatic()) + atom_cumsum[6] - - atom_feature = [ - a_idx, - fc_idx, - ch_idx, - hyb_idx, - numh_idx, - imp_idx, - deg_idx, - isa_idx, - ] - atom_features_list.append(atom_feature) - x = torch.tensor(np.array(atom_features_list), dtype=torch.long) - - # bonds - num_bond_features = 5 - if len(mol.GetBonds()) > 0: - edges_list = [] - edge_features_list = [] - for bond in mol.GetBonds(): - i = bond.GetBeginAtomIdx() - j = bond.GetEndAtomIdx() - bt = bt_list.index(bond.GetBondType()) - bd = bd_list.index(bond.GetBondDir()) + bond_cumsum[0] - bic = bic_list.index(bond.GetIsConjugated()) + bond_cumsum[1] - bir = bir_list.index(bond.IsInRing()) + bond_cumsum[2] - bs = bs_list.index(str(bond.GetStereo())) + bond_cumsum[3] - - edge_feature = [bt, bd, bic, bir, bs] - 
edges_list.append((i, j)) - edge_features_list.append(edge_feature) - edges_list.append((j, i)) - edge_features_list.append(edge_feature) - edge_index = torch.tensor(np.array(edges_list).T, dtype=torch.long) - edge_attr = torch.tensor(np.array(edge_features_list), dtype=torch.long) - else: - edge_index = torch.empty((2, 0), dtype=torch.long) - edge_attr = torch.empty((0, num_bond_features), dtype=torch.long) - - data = Data(x=x, edge_index=edge_index, edge_attr=edge_attr) - return data - - -class SelfLoop: - """Callable that appends self-loops and matching edge attributes. - - This helper mutates the provided ``Data`` object by adding self-loop - entries to ``edge_index`` and a corresponding edge attribute row for - every node. - """ - - def __call__(self, data: Data) -> Data: - """Modify ``data`` in-place by adding self-loop indices and corresponding edge attributes. - - :param data: ``torch_geometric.data.Data`` to modify. - :return: The modified ``Data`` object (same instance). - """ - num_nodes = data.num_nodes - data.edge_index, _ = add_self_loops(data.edge_index, num_nodes=num_nodes) - self_loop_attr = torch.LongTensor([0, 5, 8, 10, 12]).repeat(num_nodes, 1) - data.edge_attr = torch.cat((data.edge_attr, self_loop_attr), dim=0) - return data - - -class AddSegId: - """Attach zero-valued segment id tensors to nodes and edges. - - The created ``node_seg`` and ``edge_seg`` tensors are added to the - provided ``Data`` instance and used by the MolGNet embedding layers. - """ - - def __init__(self) -> None: - """Create an AddSegId callable (no parameters).""" - pass - - def __call__(self, data: Data) -> Data: - """Attach zero-filled ``node_seg`` and ``edge_seg`` tensors to ``data``. - - :param data: ``torch_geometric.data.Data`` to modify. - :return: The modified ``Data`` object (same instance). 
- """ - num_nodes = data.num_nodes - num_edges = data.num_edges - node_seg = [0 for _ in range(num_nodes)] - edge_seg = [0 for _ in range(num_edges)] - data.edge_seg = torch.LongTensor(edge_seg) - data.node_seg = torch.LongTensor(node_seg) - return data - - -# MolGNet model - - -class BertLayerNorm(nn.Module): - """Layer normalization compatible with BERT-style implementations. - - :param hidden_size: Dimension of the last axis to normalize. - :param eps: Small epsilon for numerical stability. - """ - - def __init__(self, hidden_size, eps=1e-12): - """Create a BertLayerNorm module. - - :param hidden_size: Dimension of the last axis to normalize. - :param eps: Small epsilon for numerical stability. - """ - super().__init__() - self.shape = torch.Size((hidden_size,)) - self.eps = eps - self.weight = nn.Parameter(torch.ones(hidden_size)) - self.bias = nn.Parameter(torch.zeros(hidden_size)) - - def forward(self, x: torch.Tensor) -> torch.Tensor: - """Apply layer normalization to the last dimension of ``x``. - - :param x: Input tensor. - :return: Normalized tensor with same shape as ``x``. - """ - u = x.mean(-1, keepdim=True) - s = (x - u).pow(2).mean(-1, keepdim=True) - x = (x - u) / torch.sqrt(s + self.eps) - x = self.weight * x + self.bias - return x - - -def gelu(x: torch.Tensor) -> torch.Tensor: - """Gaussian Error Linear Unit activation (approximation). - - :param x: Input tensor. - :return: Activated tensor. - """ - return x * 0.5 * (1.0 + torch.erf(x / math.sqrt(2))) - - -def bias_gelu(bias: torch.Tensor, y: torch.Tensor) -> torch.Tensor: - """Apply GELU to ``bias + y``. - - :param bias: Bias tensor to add. - :param y: Linear output tensor. - :return: GELU applied to ``bias + y``. - """ - x = bias + y - return x * 0.5 * (1.0 + torch.erf(x / math.sqrt(2))) - - -class LinearActivation(nn.Module): - """Linear layer with optional bias-aware GELU activation. - - :param in_features: Input feature dimension. - :param out_features: Output feature dimension. 
- :param bias: Whether to use a bias parameter and the biased GELU. - """ - - def __init__(self, in_features: int, out_features: int, bias: bool = True) -> None: - """ - Create a LinearActivation module. - - :param in_features: Input feature dimension. - :param out_features: Output feature dimension. - :param bias: Whether to use a bias parameter and the biased GELU. - """ - super().__init__() - self.in_features = in_features - self.out_features = out_features - if bias: - self.biased_act_fn = bias_gelu - else: - self.act_fn = gelu - self.weight = Parameter(torch.Tensor(out_features, in_features)) - if bias: - self.bias = Parameter(torch.Tensor(out_features)) - else: - self.register_parameter("bias", None) - self.reset_parameters() - - def reset_parameters(self) -> None: - """Initialize the layer parameters.""" - nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5)) - if self.bias is not None: - fan_in, _ = nn.init._calculate_fan_in_and_fan_out(self.weight) - bound = 1 / math.sqrt(fan_in) - nn.init.uniform_(self.bias, -bound, bound) - - def forward(self, input: torch.Tensor) -> torch.Tensor: - """Apply the linear transformation and activation. - - :param input: Input tensor of shape [N, in_features]. - :return: Transformed tensor of shape [N, out_features]. - """ - if self.bias is not None: - linear_out = torch_nn_f.linear(input, self.weight, None) - return self.biased_act_fn(self.bias, linear_out) - else: - return self.act_fn(torch_nn_f.linear(input, self.weight, self.bias)) - - -class Intermediate(nn.Module): - """Intermediate feed-forward block used inside GT layers. - - :param hidden: Hidden dimension size. - """ - - def __init__(self, hidden: int) -> None: - """Create the intermediate dense activation block. - - :param hidden: Hidden dimension size. - """ - super().__init__() - self.dense_act = LinearActivation(hidden, 4 * hidden) - - def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: - """Apply the dense activation to the hidden states. 
- - :param hidden_states: Input tensor of shape [N, hidden]. - :return: Transformed tensor of shape [N, 4*hidden]. - """ - hidden_states = self.dense_act(hidden_states) - return hidden_states - - -class AttentionOut(nn.Module): - """Post-attention output block: projection, dropout and residual norm. - - :param hidden: Hidden dimension used for the linear projection. - :param dropout: Dropout probability. - """ - - def __init__(self, hidden: int, dropout: float) -> None: - """Create an AttentionOut block. - - :param hidden: Hidden dimension used for projection. - :param dropout: Dropout probability. - """ - super().__init__() - self.dense = nn.Linear(hidden, hidden) - self.LayerNorm = BertLayerNorm(hidden, eps=1e-12) - self.dropout = nn.Dropout(dropout) - - def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor: - """Project attention outputs and apply layer norm with residual. - - :param hidden_states: Attention output tensor. - :param input_tensor: Residual tensor to add before normalization. - :return: Normalized tensor with the same shape as ``input_tensor``. - """ - hidden_states = self.dense(hidden_states) - hidden_states = self.dropout(hidden_states) - hidden_states = self.LayerNorm(hidden_states + input_tensor) - return hidden_states - - -class GTOut(nn.Module): - """Output projection used in GT blocks. - - :param hidden: Hidden dimension. - :param dropout: Dropout probability. - """ - - def __init__(self, hidden: int, dropout: float) -> None: - """Create a GTOut projection block. - - :param hidden: Hidden dimension. - :param dropout: Dropout probability. - """ - super().__init__() - self.dense = nn.Linear(hidden * 4, hidden) - self.LayerNorm = BertLayerNorm(hidden, eps=1e-12) - self.dropout = nn.Dropout(dropout) - - def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor: - """Project intermediate states back to hidden dimension and normalize. 
- - :param hidden_states: Intermediate tensor of shape [N, 4*hidden]. - :param input_tensor: Residual tensor to add. - :return: Tensor of shape [N, hidden]. - """ - hidden_states = self.dense(hidden_states) - hidden_states = self.dropout(hidden_states) - hidden_states = self.LayerNorm(hidden_states + input_tensor) - return hidden_states - - -class MessagePassing(nn.Module): - """Minimal MessagePassing base class used by the MolGNet layers. - - This class provides a lightweight implementation of propagate/ - message/aggregate/update used in graph convolutions. - - :param aggr: Aggregation method (e.g., 'add', 'mean'). - :param flow: Message flow direction. - :param node_dim: Node dimension index (unused in this minimal impl). - """ - - def __init__(self, aggr: str = "add", flow: str = "source_to_target", node_dim: int = 0) -> None: - """Create a MessagePassing helper. - - :param aggr: Aggregation method (e.g., 'add' or 'mean'). - :param flow: Message flow direction. - :param node_dim: Node dimension index. - """ - super().__init__() - self.aggr = aggr - self.flow = flow - self.node_dim = node_dim - - def propagate(self, edge_index: torch.Tensor, size: Optional[tuple[int, int]] = None, **kwargs) -> torch.Tensor: - """Run full message-passing: message -> aggregate -> update. - - :param edge_index: Edge indices tensor of shape [2, E]. - :param size: Optional pair describing (num_nodes_source, num_nodes_target). - :param kwargs: Additional data (e.g., node features) needed for message computation. - :raises ValueError: If required inputs (e.g., 'x') are missing or indexing fails. - :return: Updated node tensor after aggregation. 
- """ - i = 1 if self.flow == "source_to_target" else 0 - j = 0 if i == 1 else 1 - x = kwargs.get("x") - if x is None: - raise ValueError("propagate requires node features passed as keyword 'x'") - try: - x_i = x[edge_index[i]] - x_j = x[edge_index[j]] - except Exception as exc: # defensive - raise ValueError("failed to index node features with edge_index") from exc - msg = self.message( - edge_index_i=edge_index[i], - edge_index_j=edge_index[j], - x_i=x_i, - x_j=x_j, - **kwargs, - ) - # determine number of destination nodes for aggregation - if hasattr(x, "size"): - dim_size = x.size(0) - else: - dim_size = len(x) - out = self.aggregate(msg, index=edge_index[i], dim_size=dim_size) - out = self.update(out) - return out - - def message(self, *args: Any, **kwargs: Any) -> torch.Tensor: - """Default message function returning neighbor features. - - Subclasses may provide richer signatures; this generic form allows - subclass overrides while keeping the base class typed. - - :param args: Positional arguments forwarded by propagate. - :param kwargs: Keyword arguments forwarded by propagate. - :raises ValueError: If required node features are not present. - :return: Message tensor. - """ - x_j = kwargs.get("x_j") if "x_j" in kwargs else (args[1] if len(args) > 1 else None) - if x_j is None: - raise ValueError("message requires node features 'x_j'") - return x_j - - def aggregate(self, inputs: torch.Tensor, index: torch.Tensor, dim_size: Optional[int] = None) -> torch.Tensor: - """Aggregate messages using ``torch_scatter.scatter``. - - :param inputs: Message tensor of shape [E, hidden]. - :param index: Indices to aggregate into nodes. - :param dim_size: Optional target size for the aggregation dimension. - :return: Aggregated node tensor. 
- """ - from torch_scatter import scatter # local dependency - - return scatter( - inputs, - index, - dim=0, - dim_size=dim_size, - reduce=self.aggr, - ) - - def update(self, inputs: torch.Tensor) -> torch.Tensor: - """Identity update by default. - - Override to apply post-aggregation transformations. - - :param inputs: Aggregated node tensor. - :return: Updated tensor. - """ - return inputs - - -class GraphAttentionConv(MessagePassing): - """Graph attention convolution used by MolGNet. - - :param hidden: Hidden feature dimension. - :param heads: Number of attention heads. - :param dropout: Attention dropout probability. - """ - - def __init__(self, hidden: int, heads: int = 3, dropout: float = 0.0) -> None: - """Create a GraphAttentionConv. - - :param hidden: Hidden feature dimension. - :param heads: Number of attention heads. - :param dropout: Dropout probability. - :raises ValueError: If hidden is not divisible by heads. - """ - super().__init__() - self.hidden = hidden - self.heads = heads - if hidden % heads != 0: - raise ValueError("hidden must be divisible by heads") - self.query = nn.Linear(hidden, heads * int(hidden / heads)) - self.key = nn.Linear(hidden, heads * int(hidden / heads)) - self.value = nn.Linear(hidden, heads * int(hidden / heads)) - self.attn_drop = nn.Dropout(dropout) - - def forward( - self, - x: torch.Tensor, - edge_index: torch.Tensor, - edge_attr: torch.Tensor, - size: Optional[tuple[int, int]] = None, - ) -> torch.Tensor: - """Execute the graph attention conv over the provided inputs. - - :param x: Node feature tensor. - :param edge_index: Edge indices tensor. - :param edge_attr: Edge attribute tensor. - :param size: Optional size tuple. - :return: Updated node tensor after attention. 
- """ - pseudo = edge_attr.unsqueeze(-1) if edge_attr.dim() == 1 else edge_attr - return self.propagate(edge_index=edge_index, x=x, pseudo=pseudo) - - def message( - self, - edge_index_i: torch.Tensor, - x_i: torch.Tensor, - x_j: torch.Tensor, - pseudo: torch.Tensor, - size_i: Optional[int] = None, - **kwargs, - ) -> torch.Tensor: - """Compute messages using multi-head attention between nodes. - - :param edge_index_i: Source indices for edges. - :param x_i: Node features for source nodes. - :param x_j: Node features for target nodes. - :param pseudo: Edge pseudo-features (edge attributes). - :param size_i: Optional number of destination nodes. - :param kwargs: Additional keyword arguments (ignored). - :return: Message tensor shaped for aggregation. - """ - query = self.query(x_i).view( - -1, - self.heads, - int(self.hidden / self.heads), - ) - key = self.key(x_j + pseudo).view( - -1, - self.heads, - int(self.hidden / self.heads), - ) - value = self.value(x_j + pseudo).view( - -1, - self.heads, - int(self.hidden / self.heads), - ) - denom = math.sqrt(int(self.hidden / self.heads)) - alpha = (query * key).sum(dim=-1) / denom - alpha = softmax(src=alpha, index=edge_index_i, num_nodes=size_i) - alpha = self.attn_drop(alpha.view(-1, self.heads, 1)) - return alpha * value - - def update(self, aggr_out: torch.Tensor) -> torch.Tensor: - """Reshape aggregated outputs from multi-head to flat hidden dim. - - :param aggr_out: Aggregated output tensor of shape [N*heads, head_dim]. - :return: Reshaped tensor of shape [N, hidden]. - """ - aggr_out = aggr_out.view(-1, self.heads * int(self.hidden / self.heads)) - return aggr_out - - -class GTLayer(nn.Module): - """Graph Transformer layer composed from attention and feed-forward blocks. - - :param hidden: Hidden dimension size. - :param heads: Number of attention heads. - :param dropout: Dropout probability. - :param num_message_passing: Number of internal message passing steps. 
- """ - - def __init__(self, hidden: int, heads: int, dropout: float, num_message_passing: int) -> None: - """Create a GTLayer composed of attention and feed-forward blocks. - - :param hidden: Hidden dimension size. - :param heads: Number of attention heads. - :param dropout: Dropout probability. - :param num_message_passing: Number of internal message passing steps. - """ - super().__init__() - self.attention = GraphAttentionConv(hidden, heads, dropout) - self.att_out = AttentionOut(hidden, dropout) - self.intermediate = Intermediate(hidden) - self.output = GTOut(hidden, dropout) - self.gru = nn.GRU(hidden, hidden) - self.LayerNorm = BertLayerNorm(hidden, eps=1e-12) - self.time_step = num_message_passing - - def forward(self, x: torch.Tensor, edge_index: torch.Tensor, edge_attr: torch.Tensor) -> torch.Tensor: - """Run the GT layer for the configured number of message-passing steps. - - :param x: Node feature tensor of shape [N, hidden]. - :param edge_index: Edge index tensor. - :param edge_attr: Edge attribute tensor. - :return: Updated node tensor of shape [N, hidden]. - """ - h = x.unsqueeze(0) - for _ in range(self.time_step): - attention_output = self.attention.forward(x, edge_index, edge_attr) - attention_output = self.att_out.forward(attention_output, x) - intermediate_output = self.intermediate.forward(attention_output) - m = self.output.forward(intermediate_output, attention_output) - x, h = self.gru(m.unsqueeze(0), h) - x = self.LayerNorm.forward(x.squeeze(0)) - return x - - -class MolGNet(torch.nn.Module): - """MolGNet model implementation used for node embeddings. - - This implementation is intentionally minimal and only includes the - components required to run a checkpoint and produce per-node - embeddings saved by the featurizer script. - - :param num_layer: Number of GT layers. - :param emb_dim: Embedding dimensionality per node. - :param heads: Number of attention heads. - :param num_message_passing: Message passing steps per layer. 
- :param drop_ratio: Dropout probability. - """ - - def __init__( - self, - num_layer: int, - emb_dim: int, - heads: int, - num_message_passing: int, - drop_ratio: float = 0, - ) -> None: - """Create a MolGNet instance. - - :param num_layer: Number of GT layers. - :param emb_dim: Embedding dimensionality per node. - :param heads: Number of attention heads. - :param num_message_passing: Message passing steps per layer. - :param drop_ratio: Dropout probability. - """ - super().__init__() - self.num_layer = num_layer - self.drop_ratio = drop_ratio - self.x_embedding = torch.nn.Embedding(178, emb_dim) - self.x_seg_embed = torch.nn.Embedding(3, emb_dim) - self.edge_embedding = torch.nn.Embedding(18, emb_dim) - self.edge_seg_embed = torch.nn.Embedding(3, emb_dim) - self.reset_parameters() - self.gnns = torch.nn.ModuleList( - [GTLayer(emb_dim, heads, drop_ratio, num_message_passing) for _ in range(num_layer)] - ) - - def reset_parameters(self) -> None: - """Re-initialize embedding parameters with Xavier uniform. - - This mirrors common initialization used for transformer-style - embeddings. - """ - torch.nn.init.xavier_uniform_(self.x_embedding.weight.data) - torch.nn.init.xavier_uniform_(self.x_seg_embed.weight.data) - torch.nn.init.xavier_uniform_(self.edge_embedding.weight.data) - torch.nn.init.xavier_uniform_(self.edge_seg_embed.weight.data) - - def forward(self, *argv: Any) -> torch.Tensor: - """Forward pass supporting two calling conventions. - - Accepts either explicit tensors (x, edge_index, edge_attr, node_seg, - edge_seg) or a single ``Data`` object containing those attributes. - - :param argv: Positional arguments as described above. - :raises ValueError: If an unsupported number of arguments is provided. - :return: Node embeddings tensor of shape [N, emb_dim]. 
- """ - if len(argv) == 5: - x, edge_index, edge_attr, node_seg, edge_seg = (argv[0], argv[1], argv[2], argv[3], argv[4]) - elif len(argv) == 1: - data = argv[0] - x, edge_index, edge_attr, node_seg, edge_seg = ( - data.x, - data.edge_index, - data.edge_attr, - data.node_seg, - data.edge_seg, - ) - else: - raise ValueError("unmatched number of arguments.") - x = self.x_embedding(x).sum(1) + self.x_seg_embed(node_seg) - edge_attr = self.edge_embedding(edge_attr).sum(1) - edge_attr = edge_attr + self.edge_seg_embed(edge_seg) - for gnn in self.gnns: - x = gnn(x, edge_index, edge_attr) - return x - - -def tensor_to_csv_friendly(tensor: Any) -> np.ndarray: - """Convert a tensor-like object into a NumPy array safe for CSV output. - - :param tensor: Input tensor or array-like object. - :return: NumPy array on CPU. - """ - if isinstance(tensor, torch.Tensor): - return tensor.cpu().detach().numpy() - return np.array(tensor) - - -def run(args: argparse.Namespace) -> None: - """Execute the featurization pipeline for a given dataset. - - The function builds graphs from SMILES, runs the MolGNet checkpoint - to extract node embeddings, and writes per-drug CSVs and pickles in - the dataset folder. - - :param args: Parsed CLI arguments. - :raises FileNotFoundError: If expected files or directories are missing. - :raises ValueError: If expected columns are missing in the input CSV. - :raises Exception: For various failures during graph building or inference. - """ - # Use dataset-oriented paths: {data_path}/{dataset_name}/... - # Expand user (~) and resolve to an absolute path. 
- data_dir = Path(args.data_path).expanduser().resolve() - dataset_dir = data_dir / args.dataset_name - if not dataset_dir.exists(): - raise FileNotFoundError(f"Dataset directory not found: {dataset_dir}") - - out_graphs = str(dataset_dir / "GRAPH_dict.pkl") - out_molg = str(dataset_dir / "MolGNet_dict.pkl") - - # read input csv (expected at {data_path}/{dataset_name}/drug_smiles.csv) - smiles_csv = dataset_dir / "drug_smiles.csv" - if not smiles_csv.exists(): - raise FileNotFoundError(f"Expected SMILES CSV at: {smiles_csv}") - df = pd.read_csv(smiles_csv) - if args.smiles_col not in df.columns or args.id_col not in df.columns: - msg = f"Provided columns not in CSV: {args.smiles_col}, " f"{args.id_col}" - raise ValueError(msg) - df = df.dropna(subset=[args.smiles_col]) - smiles_map = dict(zip(df[args.id_col], df[args.smiles_col])) - - # Build graphs - graph_dict: dict[Any, Data] = {} - failed_conversions = [] - for idx, smi in tqdm(smiles_map.items(), desc="building graphs"): - mol = Chem.MolFromSmiles(smi) - if mol is None: - failed_conversions.append((idx, smi, "MolFromSmiles returned None")) - continue - try: - graph_dict[idx] = mol_to_graph_data_obj_complex(mol) - except Exception as e: - failed_conversions.append((idx, smi, str(e))) - if failed_conversions: - print(f"\n{len(failed_conversions)} molecules failed to convert to graphs.") - for idx, smi, err in failed_conversions: - print(f"Failed to convert {idx} (SMILES: {smi}): {err}") - else: - print("\nAll molecules converted to graphs successfully.") - # save graphs to dataset folder - with open(out_graphs, "wb") as f: - pickle.dump(graph_dict, f) - # load model - if args.device: - device = torch.device(args.device) - else: - device = torch.device("cuda" if torch.cuda.is_available() else "cpu") - - num_layer = 5 - emb_dim = 768 - heads = 12 - msg_pass = 3 - drop = 0.0 - model = MolGNet( - num_layer=num_layer, - emb_dim=emb_dim, - heads=heads, - num_message_passing=msg_pass, - drop_ratio=drop, - ) - # Prefer 
pathlib operations when working with Path objects - checkpoint_path = data_dir / args.checkpoint - ckpt = torch.load(checkpoint_path, map_location=device) # noqa S614 - try: - model.load_state_dict(ckpt) - except Exception: - if isinstance(ckpt, dict) and "state_dict" in ckpt: - model.load_state_dict(ckpt["state_dict"]) - else: - raise - model = model.to(device) - model.eval() - - self_loop = SelfLoop() - add_seg = AddSegId() - - molgnet_dict: dict[Any, torch.Tensor] = {} - with torch.no_grad(): - for idx, graph in tqdm(graph_dict.items(), desc="running model"): - try: - g = self_loop(graph) - g = add_seg(g) - g = g.to(device) - emb = model(g) - molgnet_dict[idx] = emb.cpu() - except Exception as e: - print(f"Inference failed for {idx}: {e}") - - with open(out_molg, "wb") as f: - pickle.dump(molgnet_dict, f) - - # write per-drug CSVs to {dataset_dir}/DIPK_features/Drugs - out_drugs_dir = dataset_dir / "DIPK_features/Drugs" - os.makedirs(out_drugs_dir, exist_ok=True) - for idx, emb in tqdm(molgnet_dict.items(), desc="writing csvs"): - arr = tensor_to_csv_friendly(emb) - df_emb = pd.DataFrame(arr) - out_path = out_drugs_dir / f"MolGNet_{idx}.csv" - df_emb.to_csv(out_path, sep="\t", index=False) - - print("Done.") - print("Graphs saved to:", out_graphs) - print("Node embeddings saved to:", out_molg) - print("Per-drug CSVs in:", out_drugs_dir) - - -def parse_args() -> argparse.Namespace: - """Parse command-line arguments. - - :return: Parsed arguments namespace. 
- """ - p = argparse.ArgumentParser(description=("Standalone MolGNet extractor " "(dataset-oriented)")) - p.add_argument( - "dataset_name", - help="Name of the dataset (folder under data_path)", - ) - p.add_argument( - "--data_path", - default="data", - help="Top-level data folder path", - ) - p.add_argument( - "--smiles-col", - dest="smiles_col", - default="canonical_smiles", - help="Column name for SMILES in input CSV", - ) - p.add_argument( - "--id-col", - dest="id_col", - default="pubchem_id", - help="Column name for unique ID in input CSV", - ) - p.add_argument( - "--checkpoint", - default="MolGNet.pt", - help="MolGNet checkpoint (state_dict), can be obtained from Zenodo: https://doi.org/10.5281/zenodo.12633909", - ) - p.add_argument( - "--device", - default=None, - help="torch device string, e.g. cpu or cuda:0", - ) - return p.parse_args() - - -if __name__ == "__main__": - args = parse_args() - run(args) diff --git a/drevalpy/datasets/featurizer/create_transcriptome_pca.py b/drevalpy/datasets/featurizer/create_transcriptome_pca.py deleted file mode 100644 index fd642577..00000000 --- a/drevalpy/datasets/featurizer/create_transcriptome_pca.py +++ /dev/null @@ -1,114 +0,0 @@ -"""Preprocesses transcriptome (gene expression) data using PCA dimensionality reduction.""" - -import argparse -from pathlib import Path - -import joblib -import numpy as np -import pandas as pd -from sklearn.decomposition import PCA -from sklearn.preprocessing import StandardScaler - -from drevalpy.datasets.utils import CELL_LINE_IDENTIFIER - - -def main(): - """Process transcriptome data and save PCA-transformed features. - - :raises FileNotFoundError: If the gene expression file is not found. 
- """ - parser = argparse.ArgumentParser(description="Preprocess transcriptome (gene expression) data using PCA.") - parser.add_argument("dataset_name", type=str, help="The name of the dataset to process.") - parser.add_argument( - "--n_components", - type=int, - default=100, - help="Number of principal components to keep (default: 100)", - ) - parser.add_argument( - "--data_path", - type=str, - default="data", - help="Path to the data folder (default: data)", - ) - parser.add_argument( - "--feature_type", - type=str, - default="gene_expression", - help="Type of transcriptome feature to use (default: gene_expression)", - ) - args = parser.parse_args() - - dataset_name = args.dataset_name - n_components = args.n_components - data_dir = Path(args.data_path).resolve() - feature_type = args.feature_type - - # Input file: gene expression CSV - input_file = data_dir / dataset_name / f"{feature_type}.csv" - # Output files: PCA features CSV and fitted PCA/scaler objects - output_file = data_dir / dataset_name / f"cell_line_{feature_type}_pca_{n_components}.csv" - pca_file = data_dir / dataset_name / f"cell_line_{feature_type}_pca_{n_components}_pca.pkl" - scaler_file = data_dir / dataset_name / f"cell_line_{feature_type}_pca_{n_components}_scaler.pkl" - - if not input_file.exists(): - raise FileNotFoundError(f"Error: {input_file} not found.") - - print(f"Loading transcriptome data from {input_file}...") - # Load gene expression data - # Format: rows are cell lines (indexed by cell_line_name), columns are genes - ge_df = pd.read_csv(input_file, index_col=CELL_LINE_IDENTIFIER) - ge_df.index = ge_df.index.astype(str) - - # Drop cellosaurus_id if present - if "cellosaurus_id" in ge_df.columns: - ge_df = ge_df.drop(columns=["cellosaurus_id"]) - - print(f"Loaded {len(ge_df)} cell lines with {len(ge_df.columns)} genes") - print(f"Performing PCA with {n_components} components...") - - # Extract cell line IDs and gene expression matrix - cell_line_ids = ge_df.index.values - 
gene_expression_matrix = ge_df.values.astype(np.float32) - - # Handle missing values: fill with 0 or mean (using 0 as default) - if np.isnan(gene_expression_matrix).any(): - print("Warning: Found NaN values. Filling with 0.") - gene_expression_matrix = np.nan_to_num(gene_expression_matrix, nan=0.0) - - # Standardize the data before PCA - scaler = StandardScaler() - gene_expression_scaled = scaler.fit_transform(gene_expression_matrix) - - # Perform PCA - pca = PCA(n_components=n_components) - pca_features = pca.fit_transform(gene_expression_scaled) - - print(f"PCA explained variance ratio: {pca.explained_variance_ratio_.sum():.4f}") - print(f"PCA explained variance (first 10 components): {pca.explained_variance_ratio_[:10]}") - - # Create output DataFrame - pca_df = pd.DataFrame( - pca_features, - index=cell_line_ids, - columns=[f"PC{i + 1}" for i in range(n_components)], - ) - pca_df.index.name = CELL_LINE_IDENTIFIER - pca_df = pca_df.reset_index() - - # Save PCA-transformed features - pca_df.to_csv(output_file, index=False) - print(f"PCA features saved to {output_file}") - - # Save fitted PCA and scaler for potential future use (e.g., transforming new data) - joblib.dump(pca, pca_file) - print(f"Fitted PCA model saved to {pca_file}") - - joblib.dump(scaler, scaler_file) - print(f"Fitted scaler saved to {scaler_file}") - - print("Finished processing transcriptome PCA featurization.") - - -if __name__ == "__main__": - main() diff --git a/drevalpy/datasets/featurizer/drug/__init__.py b/drevalpy/datasets/featurizer/drug/__init__.py new file mode 100644 index 00000000..c392efa4 --- /dev/null +++ b/drevalpy/datasets/featurizer/drug/__init__.py @@ -0,0 +1,13 @@ +"""Drug featurizers for converting drug representations to embeddings.""" + +from .base import DrugFeaturizer +from .chemberta import ChemBERTaFeaturizer +from .drug_graph import DrugGraphFeaturizer +from .molgnet import MolGNetFeaturizer + +__all__ = [ + "DrugFeaturizer", + "ChemBERTaFeaturizer", + 
"DrugGraphFeaturizer", + "MolGNetFeaturizer", +] diff --git a/drevalpy/datasets/featurizer/drug/base.py b/drevalpy/datasets/featurizer/drug/base.py new file mode 100644 index 00000000..c03e334c --- /dev/null +++ b/drevalpy/datasets/featurizer/drug/base.py @@ -0,0 +1,193 @@ +"""Abstract base class for drug featurizers.""" + +from abc import ABC, abstractmethod +from pathlib import Path +from typing import Any + +import numpy as np +import pandas as pd + +from drevalpy.datasets.dataset import FeatureDataset +from drevalpy.datasets.utils import DRUG_IDENTIFIER + + +class DrugFeaturizer(ABC): + """Abstract base class for drug featurizers. + + Drug featurizers convert drug representations (e.g., SMILES strings) into + numerical embeddings that can be used as input features for machine learning models. + + Subclasses must implement: + - featurize(): Convert a single drug to its embedding + - get_feature_name(): Return the name of the feature view + - get_output_filename(): Return the filename for cached embeddings + + The base class provides: + - load_or_generate(): Load cached embeddings or generate and cache them + - generate_embeddings(): Generate embeddings for all drugs in a dataset + - load_embeddings(): Load pre-generated embeddings from disk + """ + + def __init__(self, device: str = "cpu"): + """Initialize the featurizer. + + :param device: Device to use for computation (e.g., 'cpu', 'cuda') + """ + self.device = device + + @abstractmethod + def featurize(self, smiles: str) -> np.ndarray | Any: + """Convert a SMILES string to a feature representation. + + :param smiles: SMILES string representing the drug + :returns: Feature representation (numpy array or other format like torch_geometric.Data) + """ + + @classmethod + @abstractmethod + def get_feature_name(cls) -> str: + """Return the name of the feature view. + + This name is used as the key in the FeatureDataset. 
+ + :returns: Feature view name (e.g., 'chemberta_embeddings') + """ + + @classmethod + @abstractmethod + def get_output_filename(cls) -> str: + """Return the filename for cached embeddings. + + :returns: Filename (e.g., 'drug_chemberta_embeddings.csv') + """ + + def load_or_generate(self, data_path: str, dataset_name: str) -> FeatureDataset: + """Load cached embeddings or generate and cache them if not available. + + This is the main entry point for using a featurizer. It checks if + pre-generated embeddings exist and loads them, otherwise generates + new embeddings and saves them for future use. + + :param data_path: Path to the data directory (e.g., 'data/') + :param dataset_name: Name of the dataset (e.g., 'GDSC1') + :returns: FeatureDataset containing the drug embeddings + """ + output_path = Path(data_path) / dataset_name / self.get_output_filename() + + if output_path.exists(): + return self.load_embeddings(data_path, dataset_name) + else: + print(f"Embeddings not found at {output_path}. Generating...") + return self.generate_embeddings(data_path, dataset_name) + + def generate_embeddings(self, data_path: str, dataset_name: str) -> FeatureDataset: + """Generate embeddings for all drugs in a dataset and save to disk. 
+ + :param data_path: Path to the data directory + :param dataset_name: Name of the dataset + :returns: FeatureDataset containing the generated embeddings + :raises FileNotFoundError: If the drug_smiles.csv file is not found + """ + data_dir = Path(data_path).resolve() + smiles_file = data_dir / dataset_name / "drug_smiles.csv" + output_file = data_dir / dataset_name / self.get_output_filename() + + if not smiles_file.exists(): + raise FileNotFoundError(f"SMILES file not found: {smiles_file}") + + smiles_df = pd.read_csv(smiles_file, dtype={"canonical_smiles": str, DRUG_IDENTIFIER: str}) + + embeddings_list = [] + drug_ids = [] + + print(f"Processing {len(smiles_df)} drugs for dataset {dataset_name}...") + + for row in smiles_df.itertuples(index=False): + drug_id = getattr(row, DRUG_IDENTIFIER) + smiles = row.canonical_smiles + + try: + embedding = self.featurize(smiles) + embeddings_list.append(embedding) + drug_ids.append(drug_id) + except Exception as e: + print(f"Failed to process drug {drug_id} (SMILES: {smiles}): {e}") + continue + + # Save embeddings + self._save_embeddings(embeddings_list, drug_ids, output_file) + + print(f"Embeddings saved to {output_file}") + + # Return as FeatureDataset + return self._create_feature_dataset(embeddings_list, drug_ids) + + def _save_embeddings(self, embeddings: list, drug_ids: list[str], output_path: Path) -> None: + """Save embeddings to disk. + + Default implementation saves as CSV. Subclasses can override for other formats. + + :param embeddings: List of embedding arrays + :param drug_ids: List of drug identifiers + :param output_path: Path to save the embeddings + """ + embeddings_df = pd.DataFrame(embeddings) + embeddings_df.insert(0, DRUG_IDENTIFIER, drug_ids) + embeddings_df.to_csv(output_path, index=False) + + def _create_feature_dataset(self, embeddings: list, drug_ids: list[str]) -> FeatureDataset: + """Create a FeatureDataset from embeddings. 
+ + :param embeddings: List of embedding arrays + :param drug_ids: List of drug identifiers + :returns: FeatureDataset containing the embeddings + """ + feature_name = self.get_feature_name() + features = {} + for drug_id, embedding in zip(drug_ids, embeddings, strict=True): + if isinstance(embedding, np.ndarray): + features[drug_id] = {feature_name: embedding.astype(np.float32)} + else: + features[drug_id] = {feature_name: embedding} + return FeatureDataset(features) + + def load_embeddings(self, data_path: str, dataset_name: str) -> FeatureDataset: + """Load pre-generated embeddings from disk. + + :param data_path: Path to the data directory + :param dataset_name: Name of the dataset + :returns: FeatureDataset containing the embeddings + :raises FileNotFoundError: If the embeddings file is not found + """ + embeddings_file = Path(data_path) / dataset_name / self.get_output_filename() + + if not embeddings_file.exists(): + raise FileNotFoundError( + f"Embeddings file not found: {embeddings_file}. " + f"Use load_or_generate() to automatically generate embeddings." + ) + + embeddings_df = pd.read_csv(embeddings_file, dtype={DRUG_IDENTIFIER: str}) + feature_name = self.get_feature_name() + features = {} + + for _, row in embeddings_df.iterrows(): + drug_id = row[DRUG_IDENTIFIER] + embedding = row.drop(DRUG_IDENTIFIER).to_numpy(dtype=np.float32) + features[drug_id] = {feature_name: embedding} + + return FeatureDataset(features) + + +def main(): + """Entry point for running featurizer from command line. + + This function should be overridden by subclasses that support CLI usage. 
+ + :raises NotImplementedError: Always, as subclasses should implement their own main() + """ + raise NotImplementedError("Subclasses should implement their own main() function") + + +if __name__ == "__main__": + main() diff --git a/drevalpy/datasets/featurizer/drug/chemberta.py b/drevalpy/datasets/featurizer/drug/chemberta.py new file mode 100644 index 00000000..8e564a0e --- /dev/null +++ b/drevalpy/datasets/featurizer/drug/chemberta.py @@ -0,0 +1,101 @@ +"""ChemBERTa drug featurizer for generating embeddings from SMILES strings.""" + +import argparse + +import numpy as np + +from .base import DrugFeaturizer + + +class ChemBERTaFeaturizer(DrugFeaturizer): + """Featurizer that generates ChemBERTa embeddings from SMILES strings. + + ChemBERTa is a transformer model pre-trained on chemical SMILES strings. + This featurizer uses the model to generate fixed-size embeddings for drugs. + + Example usage:: + + featurizer = ChemBERTaFeaturizer(device="cuda") + features = featurizer.load_or_generate("data", "GDSC1") + """ + + def __init__(self, device: str = "cpu"): + """Initialize the ChemBERTa featurizer. + + :param device: Device to use for computation ('cpu' or 'cuda') + """ + super().__init__(device=device) + self._tokenizer = None + self._model = None + + def _load_model(self): + """Lazily load the ChemBERTa model and tokenizer. 
+ + :raises ImportError: If transformers or torch packages are not installed + """ + if self._model is None: + try: + import torch # noqa: F401 + from transformers import AutoModel, AutoTokenizer + except ImportError: + raise ImportError( + "Please install transformers package for ChemBERTa featurizer: pip install transformers torch" + ) + + self._tokenizer = AutoTokenizer.from_pretrained("seyonec/ChemBERTa-zinc-base-v1") + self._model = AutoModel.from_pretrained("seyonec/ChemBERTa-zinc-base-v1") + self._model.to(self.device) + self._model.eval() + + def featurize(self, smiles: str) -> np.ndarray: + """Convert a SMILES string to a ChemBERTa embedding. + + :param smiles: SMILES string representing the drug + :returns: ChemBERTa embedding as numpy array + """ + import torch + + self._load_model() + + inputs = self._tokenizer(smiles, return_tensors="pt", truncation=True) + inputs = {k: v.to(self.device) for k, v in inputs.items()} + + with torch.no_grad(): + outputs = self._model(**inputs) + hidden_states = outputs.last_hidden_state + + # Mean pooling over sequence length + embedding = hidden_states.mean(dim=1).squeeze(0) + return embedding.cpu().numpy() + + @classmethod + def get_feature_name(cls) -> str: + """Return the feature view name. + + :returns: 'chemberta_embeddings' + """ + return "chemberta_embeddings" + + @classmethod + def get_output_filename(cls) -> str: + """Return the output filename for cached embeddings. 
+ + :returns: 'drug_chemberta_embeddings.csv' + """ + return "drug_chemberta_embeddings.csv" + + +def main(): + """Process drug SMILES and save ChemBERTa embeddings from command line.""" + parser = argparse.ArgumentParser(description="Generate ChemBERTa embeddings for drugs.") + parser.add_argument("dataset_name", type=str, help="The name of the dataset to process.") + parser.add_argument("--device", type=str, default="cpu", help="Torch device (cpu or cuda)") + parser.add_argument("--data_path", type=str, default="data", help="Path to the data folder") + args = parser.parse_args() + + featurizer = ChemBERTaFeaturizer(device=args.device) + featurizer.generate_embeddings(args.data_path, args.dataset_name) + + +if __name__ == "__main__": + main() diff --git a/drevalpy/datasets/featurizer/drug/drug_graph.py b/drevalpy/datasets/featurizer/drug/drug_graph.py new file mode 100644 index 00000000..e76bde87 --- /dev/null +++ b/drevalpy/datasets/featurizer/drug/drug_graph.py @@ -0,0 +1,220 @@ +"""Drug graph featurizer for converting SMILES to molecular graphs.""" + +import argparse +import os +from pathlib import Path + +import torch +from torch_geometric.data import Data + +from drevalpy.datasets.dataset import FeatureDataset + +from .base import DrugFeaturizer + +try: + from rdkit import Chem +except ImportError: + Chem = None + + +# Atom feature configuration +ATOM_FEATURES = { + "atomic_num": list(range(1, 119)), + "degree": [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10], + "formal_charge": [-5, -4, -3, -2, -1, 0, 1, 2, 3, 4, 5], + "num_hs": [0, 1, 2, 3, 4, 5, 6, 7, 8], + "hybridization": [], # Will be populated after rdkit import check +} + +# Bond feature configuration +BOND_FEATURES = { + "bond_type": [], # Will be populated after rdkit import check +} + + +def _init_rdkit_features(): + """Initialize RDKit-dependent feature configurations. 
+ + :raises ImportError: If rdkit package is not installed + """ + if Chem is None: + raise ImportError("Please install rdkit package for drug graphs featurizer: pip install rdkit") + + ATOM_FEATURES["hybridization"] = [ + Chem.rdchem.HybridizationType.SP, + Chem.rdchem.HybridizationType.SP2, + Chem.rdchem.HybridizationType.SP3, + Chem.rdchem.HybridizationType.SP3D, + Chem.rdchem.HybridizationType.SP3D2, + ] + BOND_FEATURES["bond_type"] = [ + Chem.rdchem.BondType.SINGLE, + Chem.rdchem.BondType.DOUBLE, + Chem.rdchem.BondType.TRIPLE, + Chem.rdchem.BondType.AROMATIC, + ] + + +def _one_hot_encode(value, choices): + """Create a one-hot encoding for a value in a list of choices. + + :param value: The value to be one-hot encoded. + :param choices: A list of possible choices for the value. + :return: A list representing the one-hot encoding. + """ + encoding = [0] * (len(choices) + 1) + index = choices.index(value) if value in choices else -1 + encoding[index] = 1 + return encoding + + +class DrugGraphFeaturizer(DrugFeaturizer): + """Featurizer that converts SMILES strings to molecular graphs. + + The graphs are stored as torch_geometric.data.Data objects with: + - x: Node features (atom features) + - edge_index: Edge connectivity + - edge_attr: Edge features (bond features) + + Example usage:: + + featurizer = DrugGraphFeaturizer() + features = featurizer.load_or_generate("data", "GDSC1") + """ + + def __init__(self, device: str = "cpu"): + """Initialize the drug graph featurizer. + + :param device: Device to use (not used for graph generation, but kept for API consistency) + """ + super().__init__(device=device) + _init_rdkit_features() + + def featurize(self, smiles: str) -> Data | None: + """Convert a SMILES string to a molecular graph. 
+ + :param smiles: SMILES string representing the drug + :returns: torch_geometric.data.Data object or None if conversion fails + """ + mol = Chem.MolFromSmiles(smiles) + if mol is None: + return None + + # Atom features + atom_features_list = [] + for atom in mol.GetAtoms(): + features = [] + features.extend(_one_hot_encode(atom.GetAtomicNum(), ATOM_FEATURES["atomic_num"])) + features.extend(_one_hot_encode(atom.GetDegree(), ATOM_FEATURES["degree"])) + features.extend(_one_hot_encode(atom.GetFormalCharge(), ATOM_FEATURES["formal_charge"])) + features.extend(_one_hot_encode(atom.GetTotalNumHs(), ATOM_FEATURES["num_hs"])) + features.extend(_one_hot_encode(atom.GetHybridization(), ATOM_FEATURES["hybridization"])) + features.append(atom.GetIsAromatic()) + features.append(atom.IsInRing()) + atom_features_list.append(features) + x = torch.tensor(atom_features_list, dtype=torch.float) + + # Edge index and edge features + edge_indices = [] + edge_features_list = [] + for bond in mol.GetBonds(): + i = bond.GetBeginAtomIdx() + j = bond.GetEndAtomIdx() + + # Edge features + features = [] + features.extend(_one_hot_encode(bond.GetBondType(), BOND_FEATURES["bond_type"])) + features.append(bond.GetIsConjugated()) + features.append(bond.IsInRing()) + + edge_indices.extend([[i, j], [j, i]]) + edge_features_list.extend([features, features]) # Same features for both directions + + edge_index = torch.tensor(edge_indices, dtype=torch.long).t().contiguous() + edge_attr = torch.tensor(edge_features_list, dtype=torch.float) + + return Data(x=x, edge_index=edge_index, edge_attr=edge_attr) + + @classmethod + def get_feature_name(cls) -> str: + """Return the feature view name. + + :returns: 'drug_graphs' + """ + return "drug_graphs" + + @classmethod + def get_output_filename(cls) -> str: + """Return the output directory name for cached graphs. 
+ + :returns: 'drug_graphs' + """ + return "drug_graphs" + + def _save_embeddings(self, embeddings: list, drug_ids: list[str], output_path: Path) -> None: + """Save graph embeddings to disk as individual .pt files. + + :param embeddings: List of Data objects + :param drug_ids: List of drug identifiers + :param output_path: Directory path to save the graphs + """ + os.makedirs(output_path, exist_ok=True) + for drug_id, graph in zip(drug_ids, embeddings, strict=True): + if graph is not None: + torch.save(graph, output_path / f"{drug_id}.pt") + + def load_embeddings(self, data_path: str, dataset_name: str) -> FeatureDataset: + """Load pre-generated graph embeddings from disk. + + :param data_path: Path to the data directory + :param dataset_name: Name of the dataset + :returns: FeatureDataset containing the graph embeddings + :raises FileNotFoundError: If the graphs directory is not found + """ + graphs_dir = Path(data_path) / dataset_name / self.get_output_filename() + + if not graphs_dir.exists(): + raise FileNotFoundError( + f"Graphs directory not found: {graphs_dir}. " + f"Use load_or_generate() to automatically generate graphs." + ) + + feature_name = self.get_feature_name() + features = {} + + for graph_file in graphs_dir.glob("*.pt"): + drug_id = graph_file.stem + graph = torch.load(graph_file) # noqa: S614 + features[drug_id] = {feature_name: graph} + + return FeatureDataset(features) + + def load_or_generate(self, data_path: str, dataset_name: str) -> FeatureDataset: + """Load cached graphs or generate and cache them if not available. + + :param data_path: Path to the data directory + :param dataset_name: Name of the dataset + :returns: FeatureDataset containing the drug graphs + """ + output_path = Path(data_path) / dataset_name / self.get_output_filename() + + if output_path.exists() and any(output_path.glob("*.pt")): + return self.load_embeddings(data_path, dataset_name) + else: + print(f"Graphs not found at {output_path}. 
Generating...") + return self.generate_embeddings(data_path, dataset_name) + + +def main(): + """Process drug SMILES and save molecular graphs from command line.""" + parser = argparse.ArgumentParser(description="Generate molecular graphs for drugs.") + parser.add_argument("dataset_name", type=str, help="The name of the dataset to process.") + parser.add_argument("--data_path", type=str, default="data", help="Path to the data folder") + args = parser.parse_args() + + featurizer = DrugGraphFeaturizer() + featurizer.generate_embeddings(args.data_path, args.dataset_name) + + +if __name__ == "__main__": + main() diff --git a/drevalpy/datasets/featurizer/drug/molgnet.py b/drevalpy/datasets/featurizer/drug/molgnet.py new file mode 100644 index 00000000..6bce7467 --- /dev/null +++ b/drevalpy/datasets/featurizer/drug/molgnet.py @@ -0,0 +1,817 @@ +"""MolGNet drug featurizer for generating graph-based embeddings. + +This module provides a featurizer that uses the MolGNet model to generate +node embeddings for molecules. It requires a pre-trained MolGNet checkpoint. 
+""" + +import argparse +import math +import os +import pickle # noqa: S403 +from pathlib import Path +from typing import Any, Optional + +import numpy as np +import pandas as pd +import torch +import torch.nn.functional as torch_nn_f +from torch import nn +from torch.nn import Parameter +from torch_geometric.data import Data +from torch_geometric.utils import add_self_loops, softmax + +from drevalpy.datasets.dataset import FeatureDataset + +from .base import DrugFeaturizer + +try: + from rdkit import Chem + from rdkit.Chem.rdchem import Mol as RDMol +except ImportError: + Chem = None + RDMol = None + + +# Feature configuration for MolGNet graph building +allowable_features: dict[str, list[Any]] = { + "atomic_num": list(range(1, 122)), + "formal_charge": ["unk", -5, -4, -3, -2, -1, 0, 1, 2, 3, 4, 5], + "chirality": [], # Populated after rdkit import check + "hybridization": [], # Populated after rdkit import check + "numH": ["unk", 0, 1, 2, 3, 4, 5, 6, 7, 8], + "implicit_valence": ["unk", 0, 1, 2, 3, 4, 5, 6], + "degree": ["unk", 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10], + "isaromatic": [False, True], + "bond_type": [], # Populated after rdkit import check + "bond_dirs": [], # Populated after rdkit import check + "bond_isconjugated": [False, True], + "bond_inring": [False, True], + "bond_stereo": [ + "STEREONONE", + "STEREOANY", + "STEREOZ", + "STEREOE", + "STEREOCIS", + "STEREOTRANS", + ], +} + + +def _init_rdkit_features(): + """Initialize RDKit-dependent feature configurations. 
+ + :raises ImportError: If rdkit package is not installed + """ + if Chem is None: + raise ImportError("Please install rdkit package for MolGNet featurizer: pip install rdkit") + + allowable_features["chirality"] = [ + "unk", + Chem.rdchem.ChiralType.CHI_UNSPECIFIED, + Chem.rdchem.ChiralType.CHI_TETRAHEDRAL_CW, + Chem.rdchem.ChiralType.CHI_TETRAHEDRAL_CCW, + Chem.rdchem.ChiralType.CHI_OTHER, + ] + allowable_features["hybridization"] = [ + "unk", + Chem.rdchem.HybridizationType.S, + Chem.rdchem.HybridizationType.SP, + Chem.rdchem.HybridizationType.SP2, + Chem.rdchem.HybridizationType.SP3, + Chem.rdchem.HybridizationType.SP3D, + Chem.rdchem.HybridizationType.SP3D2, + Chem.rdchem.HybridizationType.UNSPECIFIED, + ] + allowable_features["bond_type"] = [ + "unk", + Chem.rdchem.BondType.SINGLE, + Chem.rdchem.BondType.DOUBLE, + Chem.rdchem.BondType.TRIPLE, + Chem.rdchem.BondType.AROMATIC, + ] + allowable_features["bond_dirs"] = [ + Chem.rdchem.BondDir.NONE, + Chem.rdchem.BondDir.ENDUPRIGHT, + Chem.rdchem.BondDir.ENDDOWNRIGHT, + ] + + +# Compute cumulative sums for feature indexing +atom_dic = [ + len(allowable_features["atomic_num"]), + 12, # formal_charge + 5, # chirality + 8, # hybridization + 10, # numH + 7, # implicit_valence + 12, # degree + 2, # isaromatic +] +bond_dic = [ + 5, # bond_type + 3, # bond_dirs + 2, # bond_isconjugated + 2, # bond_inring + 6, # bond_stereo +] +atom_cumsum = np.cumsum(atom_dic) +bond_cumsum = np.cumsum(bond_dic) + + +def mol_to_graph_data_obj_complex(mol: "RDMol") -> Data: + """Convert an RDKit Mol into a torch_geometric Data object for MolGNet. 
+ + :param mol: RDKit Mol instance + :returns: torch_geometric.data.Data object + :raises ValueError: If mol is None + """ + if mol is None: + raise ValueError("mol must not be None") + + _init_rdkit_features() + + atom_features_list: list = [] + fc_list = allowable_features["formal_charge"] + ch_list = allowable_features["chirality"] + hyb_list = allowable_features["hybridization"] + numh_list = allowable_features["numH"] + imp_list = allowable_features["implicit_valence"] + deg_list = allowable_features["degree"] + isa_list = allowable_features["isaromatic"] + bt_list = allowable_features["bond_type"] + bd_list = allowable_features["bond_dirs"] + bic_list = allowable_features["bond_isconjugated"] + bir_list = allowable_features["bond_inring"] + bs_list = allowable_features["bond_stereo"] + + for atom in mol.GetAtoms(): + a_idx = allowable_features["atomic_num"].index(atom.GetAtomicNum()) + fc_idx = fc_list.index(atom.GetFormalCharge()) + atom_cumsum[0] + ch_idx = ch_list.index(atom.GetChiralTag()) + atom_cumsum[1] + hyb_idx = hyb_list.index(atom.GetHybridization()) + atom_cumsum[2] + numh_idx = numh_list.index(atom.GetTotalNumHs()) + atom_cumsum[3] + imp_idx = imp_list.index(atom.GetImplicitValence()) + atom_cumsum[4] + deg_idx = deg_list.index(atom.GetDegree()) + atom_cumsum[5] + isa_idx = isa_list.index(atom.GetIsAromatic()) + atom_cumsum[6] + + atom_feature = [a_idx, fc_idx, ch_idx, hyb_idx, numh_idx, imp_idx, deg_idx, isa_idx] + atom_features_list.append(atom_feature) + x = torch.tensor(np.array(atom_features_list), dtype=torch.long) + + # bonds + num_bond_features = 5 + if len(mol.GetBonds()) > 0: + edges_list = [] + edge_features_list = [] + for bond in mol.GetBonds(): + i = bond.GetBeginAtomIdx() + j = bond.GetEndAtomIdx() + bt = bt_list.index(bond.GetBondType()) + bd = bd_list.index(bond.GetBondDir()) + bond_cumsum[0] + bic = bic_list.index(bond.GetIsConjugated()) + bond_cumsum[1] + bir = bir_list.index(bond.IsInRing()) + bond_cumsum[2] + bs = 
bs_list.index(str(bond.GetStereo())) + bond_cumsum[3] + + edge_feature = [bt, bd, bic, bir, bs] + edges_list.append((i, j)) + edge_features_list.append(edge_feature) + edges_list.append((j, i)) + edge_features_list.append(edge_feature) + edge_index = torch.tensor(np.array(edges_list).T, dtype=torch.long) + edge_attr = torch.tensor(np.array(edge_features_list), dtype=torch.long) + else: + edge_index = torch.empty((2, 0), dtype=torch.long) + edge_attr = torch.empty((0, num_bond_features), dtype=torch.long) + + return Data(x=x, edge_index=edge_index, edge_attr=edge_attr) + + +class SelfLoop: + """Callable that appends self-loops and matching edge attributes.""" + + def __call__(self, data: Data) -> Data: + """Add self-loop indices and corresponding edge attributes. + + :param data: torch_geometric.data.Data to modify + :returns: Modified Data object + """ + num_nodes = data.num_nodes + data.edge_index, _ = add_self_loops(data.edge_index, num_nodes=num_nodes) + self_loop_attr = torch.LongTensor([0, 5, 8, 10, 12]).repeat(num_nodes, 1) + data.edge_attr = torch.cat((data.edge_attr, self_loop_attr), dim=0) + return data + + +class AddSegId: + """Attach zero-valued segment id tensors to nodes and edges.""" + + def __call__(self, data: Data) -> Data: + """Attach zero-filled node_seg and edge_seg tensors. + + :param data: torch_geometric.data.Data to modify + :returns: Modified Data object + """ + num_nodes = data.num_nodes + num_edges = data.num_edges + data.edge_seg = torch.LongTensor([0] * num_edges) + data.node_seg = torch.LongTensor([0] * num_nodes) + return data + + +# MolGNet model components + + +class BertLayerNorm(nn.Module): + """Layer normalization compatible with BERT-style implementations.""" + + def __init__(self, hidden_size: int, eps: float = 1e-12) -> None: + """Initialize the layer normalization. 
+ + :param hidden_size: Size of the hidden dimension + :param eps: Small constant for numerical stability + """ + super().__init__() + self.eps = eps + self.weight = nn.Parameter(torch.ones(hidden_size)) + self.bias = nn.Parameter(torch.zeros(hidden_size)) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """Apply layer normalization. + + :param x: Input tensor + :returns: Normalized tensor + """ + u = x.mean(-1, keepdim=True) + s = (x - u).pow(2).mean(-1, keepdim=True) + x = (x - u) / torch.sqrt(s + self.eps) + return self.weight * x + self.bias + + +def gelu(x: torch.Tensor) -> torch.Tensor: + """Apply Gaussian Error Linear Unit activation. + + :param x: Input tensor + :returns: Activated tensor + """ + return x * 0.5 * (1.0 + torch.erf(x / math.sqrt(2))) + + +def bias_gelu(bias: torch.Tensor, y: torch.Tensor) -> torch.Tensor: + """Apply GELU activation to bias + y. + + :param bias: Bias tensor + :param y: Input tensor + :returns: Activated tensor + """ + x = bias + y + return x * 0.5 * (1.0 + torch.erf(x / math.sqrt(2))) + + +class LinearActivation(nn.Module): + """Linear layer with optional bias-aware GELU activation.""" + + def __init__(self, in_features: int, out_features: int, bias: bool = True) -> None: + """Initialize the linear activation layer. 
+ + :param in_features: Number of input features + :param out_features: Number of output features + :param bias: Whether to include a bias term + """ + super().__init__() + self.in_features = in_features + self.out_features = out_features + if bias: + self.biased_act_fn = bias_gelu + else: + self.act_fn = gelu + self.weight = Parameter(torch.Tensor(out_features, in_features)) + if bias: + self.bias = Parameter(torch.Tensor(out_features)) + else: + self.register_parameter("bias", None) + self.reset_parameters() + + def reset_parameters(self) -> None: + """Reset layer parameters using Kaiming initialization.""" + nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5)) + if self.bias is not None: + fan_in, _ = nn.init._calculate_fan_in_and_fan_out(self.weight) + bound = 1 / math.sqrt(fan_in) + nn.init.uniform_(self.bias, -bound, bound) + + def forward(self, input: torch.Tensor) -> torch.Tensor: + """Apply linear transformation with GELU activation. + + :param input: Input tensor + :returns: Transformed tensor + """ + if self.bias is not None: + linear_out = torch_nn_f.linear(input, self.weight, None) + return self.biased_act_fn(self.bias, linear_out) + else: + return self.act_fn(torch_nn_f.linear(input, self.weight, self.bias)) + + +class Intermediate(nn.Module): + """Intermediate feed-forward block used inside GT layers.""" + + def __init__(self, hidden: int) -> None: + """Initialize the intermediate layer. + + :param hidden: Hidden dimension size + """ + super().__init__() + self.dense_act = LinearActivation(hidden, 4 * hidden) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + """Apply feed-forward transformation. + + :param hidden_states: Input tensor + :returns: Transformed tensor + """ + return self.dense_act(hidden_states) + + +class AttentionOut(nn.Module): + """Post-attention output block: projection, dropout and residual norm.""" + + def __init__(self, hidden: int, dropout: float) -> None: + """Initialize the attention output layer. 
+ + :param hidden: Hidden dimension size + :param dropout: Dropout probability + """ + super().__init__() + self.dense = nn.Linear(hidden, hidden) + self.LayerNorm = BertLayerNorm(hidden, eps=1e-12) + self.dropout = nn.Dropout(dropout) + + def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor: + """Apply output transformation with residual connection. + + :param hidden_states: Attention output tensor + :param input_tensor: Original input for residual connection + :returns: Transformed tensor + """ + hidden_states = self.dense(hidden_states) + hidden_states = self.dropout(hidden_states) + return self.LayerNorm(hidden_states + input_tensor) + + +class GTOut(nn.Module): + """Output projection used in GT blocks.""" + + def __init__(self, hidden: int, dropout: float) -> None: + """Initialize the GT output layer. + + :param hidden: Hidden dimension size + :param dropout: Dropout probability + """ + super().__init__() + self.dense = nn.Linear(hidden * 4, hidden) + self.LayerNorm = BertLayerNorm(hidden, eps=1e-12) + self.dropout = nn.Dropout(dropout) + + def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor: + """Apply output transformation with residual connection. + + :param hidden_states: Intermediate output tensor + :param input_tensor: Original input for residual connection + :returns: Transformed tensor + """ + hidden_states = self.dense(hidden_states) + hidden_states = self.dropout(hidden_states) + return self.LayerNorm(hidden_states + input_tensor) + + +class MessagePassing(nn.Module): + """Minimal MessagePassing base class used by the MolGNet layers.""" + + def __init__(self, aggr: str = "add", flow: str = "source_to_target", node_dim: int = 0) -> None: + """Initialize the message passing layer. 
+ + :param aggr: Aggregation method ('add', 'mean', 'max') + :param flow: Direction of message flow + :param node_dim: Dimension along which to aggregate + """ + super().__init__() + self.aggr = aggr + self.flow = flow + self.node_dim = node_dim + + def propagate(self, edge_index: torch.Tensor, size: Optional[tuple[int, int]] = None, **kwargs) -> torch.Tensor: + """Propagate messages along edges. + + :param edge_index: Edge connectivity tensor + :param size: Optional size tuple for bipartite graphs + :param kwargs: Additional arguments including node features 'x' + :returns: Aggregated messages + :raises ValueError: If node features 'x' are not provided + """ + i = 1 if self.flow == "source_to_target" else 0 + j = 0 if i == 1 else 1 + x = kwargs.get("x") + if x is None: + raise ValueError("propagate requires node features passed as keyword 'x'") + x_i = x[edge_index[i]] + x_j = x[edge_index[j]] + msg = self.message( + edge_index_i=edge_index[i], + edge_index_j=edge_index[j], + x_i=x_i, + x_j=x_j, + **kwargs, + ) + dim_size = x.size(0) if hasattr(x, "size") else len(x) + out = self.aggregate(msg, index=edge_index[i], dim_size=dim_size) + return self.update(out) + + def message(self, *args: Any, **kwargs: Any) -> torch.Tensor: + """Compute messages for each edge. + + :param args: Positional arguments + :param kwargs: Keyword arguments including 'x_j' for source node features + :returns: Message tensor + :raises ValueError: If 'x_j' is not provided + """ + x_j = kwargs.get("x_j") if "x_j" in kwargs else (args[1] if len(args) > 1 else None) + if x_j is None: + raise ValueError("message requires node features 'x_j'") + return x_j + + def aggregate(self, inputs: torch.Tensor, index: torch.Tensor, dim_size: Optional[int] = None) -> torch.Tensor: + """Aggregate messages at target nodes. 
+ + :param inputs: Message tensor + :param index: Target node indices + :param dim_size: Number of target nodes + :returns: Aggregated tensor + """ + from torch_scatter import scatter + + return scatter(inputs, index, dim=0, dim_size=dim_size, reduce=self.aggr) + + def update(self, inputs: torch.Tensor) -> torch.Tensor: + """Update node representations after aggregation. + + :param inputs: Aggregated messages + :returns: Updated node representations + """ + return inputs + + +class GraphAttentionConv(MessagePassing): + """Graph attention convolution used by MolGNet.""" + + def __init__(self, hidden: int, heads: int = 3, dropout: float = 0.0) -> None: + """Initialize the graph attention convolution. + + :param hidden: Hidden dimension size + :param heads: Number of attention heads + :param dropout: Dropout probability + :raises ValueError: If hidden is not divisible by heads + """ + super().__init__() + self.hidden = hidden + self.heads = heads + if hidden % heads != 0: + raise ValueError("hidden must be divisible by heads") + self.query = nn.Linear(hidden, heads * int(hidden / heads)) + self.key = nn.Linear(hidden, heads * int(hidden / heads)) + self.value = nn.Linear(hidden, heads * int(hidden / heads)) + self.attn_drop = nn.Dropout(dropout) + + def forward( + self, + x: torch.Tensor, + edge_index: torch.Tensor, + edge_attr: torch.Tensor, + size: Optional[tuple[int, int]] = None, + ) -> torch.Tensor: + """Apply graph attention convolution. 
+ + :param x: Node feature tensor + :param edge_index: Edge connectivity tensor + :param edge_attr: Edge attribute tensor + :param size: Optional size tuple for bipartite graphs + :returns: Updated node features + """ + pseudo = edge_attr.unsqueeze(-1) if edge_attr.dim() == 1 else edge_attr + return self.propagate(edge_index=edge_index, x=x, pseudo=pseudo) + + def message( + self, + edge_index_i: torch.Tensor, + x_i: torch.Tensor, + x_j: torch.Tensor, + pseudo: torch.Tensor, + size_i: Optional[int] = None, + **kwargs, + ) -> torch.Tensor: + """Compute attention-weighted messages. + + :param edge_index_i: Target node indices + :param x_i: Target node features + :param x_j: Source node features + :param pseudo: Edge features + :param size_i: Number of target nodes + :param kwargs: Additional arguments + :returns: Attention-weighted messages + """ + query = self.query(x_i).view(-1, self.heads, int(self.hidden / self.heads)) + key = self.key(x_j + pseudo).view(-1, self.heads, int(self.hidden / self.heads)) + value = self.value(x_j + pseudo).view(-1, self.heads, int(self.hidden / self.heads)) + denom = math.sqrt(int(self.hidden / self.heads)) + alpha = (query * key).sum(dim=-1) / denom + alpha = softmax(src=alpha, index=edge_index_i, num_nodes=size_i) + alpha = self.attn_drop(alpha.view(-1, self.heads, 1)) + return alpha * value + + def update(self, aggr_out: torch.Tensor) -> torch.Tensor: + """Reshape aggregated output. + + :param aggr_out: Aggregated attention output + :returns: Reshaped tensor + """ + return aggr_out.view(-1, self.heads * int(self.hidden / self.heads)) + + +class GTLayer(nn.Module): + """Graph Transformer layer composed from attention and feed-forward blocks.""" + + def __init__(self, hidden: int, heads: int, dropout: float, num_message_passing: int) -> None: + """Initialize the Graph Transformer layer. 
+ + :param hidden: Hidden dimension size + :param heads: Number of attention heads + :param dropout: Dropout probability + :param num_message_passing: Number of message passing iterations + """ + super().__init__() + self.attention = GraphAttentionConv(hidden, heads, dropout) + self.att_out = AttentionOut(hidden, dropout) + self.intermediate = Intermediate(hidden) + self.output = GTOut(hidden, dropout) + self.gru = nn.GRU(hidden, hidden) + self.LayerNorm = BertLayerNorm(hidden, eps=1e-12) + self.time_step = num_message_passing + + def forward(self, x: torch.Tensor, edge_index: torch.Tensor, edge_attr: torch.Tensor) -> torch.Tensor: + """Apply Graph Transformer layer. + + :param x: Node feature tensor + :param edge_index: Edge connectivity tensor + :param edge_attr: Edge attribute tensor + :returns: Updated node features + """ + h = x.unsqueeze(0) + for _ in range(self.time_step): + attention_output = self.attention.forward(x, edge_index, edge_attr) + attention_output = self.att_out.forward(attention_output, x) + intermediate_output = self.intermediate.forward(attention_output) + m = self.output.forward(intermediate_output, attention_output) + x, h = self.gru(m.unsqueeze(0), h) + x = self.LayerNorm.forward(x.squeeze(0)) + return x + + +class MolGNet(torch.nn.Module): + """MolGNet model implementation used for node embeddings.""" + + def __init__( + self, + num_layer: int, + emb_dim: int, + heads: int, + num_message_passing: int, + drop_ratio: float = 0, + ) -> None: + """Initialize the MolGNet model. 
+ + :param num_layer: Number of Graph Transformer layers + :param emb_dim: Embedding dimension + :param heads: Number of attention heads + :param num_message_passing: Number of message passing iterations per layer + :param drop_ratio: Dropout ratio + """ + super().__init__() + self.num_layer = num_layer + self.drop_ratio = drop_ratio + self.x_embedding = torch.nn.Embedding(178, emb_dim) + self.x_seg_embed = torch.nn.Embedding(3, emb_dim) + self.edge_embedding = torch.nn.Embedding(18, emb_dim) + self.edge_seg_embed = torch.nn.Embedding(3, emb_dim) + self.reset_parameters() + self.gnns = torch.nn.ModuleList( + [GTLayer(emb_dim, heads, drop_ratio, num_message_passing) for _ in range(num_layer)] + ) + + def reset_parameters(self) -> None: + """Reset model parameters using Xavier initialization.""" + torch.nn.init.xavier_uniform_(self.x_embedding.weight.data) + torch.nn.init.xavier_uniform_(self.x_seg_embed.weight.data) + torch.nn.init.xavier_uniform_(self.edge_embedding.weight.data) + torch.nn.init.xavier_uniform_(self.edge_seg_embed.weight.data) + + def forward(self, *argv: Any) -> torch.Tensor: + """Forward pass through the MolGNet model. 
+ + :param argv: Either 5 tensors (x, edge_index, edge_attr, node_seg, edge_seg) + or a single Data object + :returns: Node embeddings + :raises ValueError: If incorrect number of arguments provided + """ + if len(argv) == 5: + x, edge_index, edge_attr, node_seg, edge_seg = argv + elif len(argv) == 1: + data = argv[0] + x, edge_index, edge_attr, node_seg, edge_seg = ( + data.x, + data.edge_index, + data.edge_attr, + data.node_seg, + data.edge_seg, + ) + else: + raise ValueError("unmatched number of arguments.") + x = self.x_embedding(x).sum(1) + self.x_seg_embed(node_seg) + edge_attr = self.edge_embedding(edge_attr).sum(1) + edge_attr = edge_attr + self.edge_seg_embed(edge_seg) + for gnn in self.gnns: + x = gnn(x, edge_index, edge_attr) + return x + + +class MolGNetFeaturizer(DrugFeaturizer): + """Featurizer that generates MolGNet node embeddings from SMILES strings. + + MolGNet is a graph neural network that produces per-node embeddings for + molecules. This featurizer requires a pre-trained MolGNet checkpoint. + + Example usage:: + + featurizer = MolGNetFeaturizer(checkpoint_path="data/MolGNet.pt", device="cuda") + features = featurizer.load_or_generate("data", "GDSC1") + """ + + # Default model hyperparameters + NUM_LAYER = 5 + EMB_DIM = 768 + HEADS = 12 + MSG_PASS = 3 + DROP = 0.0 + + def __init__(self, checkpoint_path: str = "data/MolGNet.pt", device: str = "cpu"): + """Initialize the MolGNet featurizer. + + :param checkpoint_path: Path to the MolGNet checkpoint file + :param device: Device to use for computation ('cpu' or 'cuda') + """ + super().__init__(device=device) + self.checkpoint_path = checkpoint_path + self._model = None + self._self_loop = SelfLoop() + self._add_seg = AddSegId() + + def _load_model(self): + """Lazily load the MolGNet model. 
+ + :raises Exception: If checkpoint loading fails + """ + if self._model is None: + _init_rdkit_features() + + self._model = MolGNet( + num_layer=self.NUM_LAYER, + emb_dim=self.EMB_DIM, + heads=self.HEADS, + num_message_passing=self.MSG_PASS, + drop_ratio=self.DROP, + ) + + device = torch.device(self.device) + ckpt = torch.load(self.checkpoint_path, map_location=device) # noqa: S614 + try: + self._model.load_state_dict(ckpt) + except Exception: + if isinstance(ckpt, dict) and "state_dict" in ckpt: + self._model.load_state_dict(ckpt["state_dict"]) + else: + raise + + self._model = self._model.to(device) + self._model.eval() + + def featurize(self, smiles: str) -> torch.Tensor | None: + """Convert a SMILES string to MolGNet node embeddings. + + :param smiles: SMILES string representing the drug + :returns: Node embeddings tensor or None if conversion fails + """ + _init_rdkit_features() + self._load_model() + + mol = Chem.MolFromSmiles(smiles) + if mol is None: + return None + + graph = mol_to_graph_data_obj_complex(mol) + graph = self._self_loop(graph) + graph = self._add_seg(graph) + graph = graph.to(self.device) + + with torch.no_grad(): + embeddings = self._model(graph) + + return embeddings.cpu() + + @classmethod + def get_feature_name(cls) -> str: + """Return the feature view name. + + :returns: 'molgnet_embeddings' + """ + return "molgnet_embeddings" + + @classmethod + def get_output_filename(cls) -> str: + """Return the output filename for cached embeddings. + + :returns: 'MolGNet_dict.pkl' + """ + return "MolGNet_dict.pkl" + + def _save_embeddings(self, embeddings: list, drug_ids: list[str], output_path: Path) -> None: + """Save MolGNet embeddings to disk as a pickle file. 
+ + :param embeddings: List of embedding tensors + :param drug_ids: List of drug identifiers + :param output_path: Path to save the embeddings + """ + molgnet_dict = {} + for drug_id, emb in zip(drug_ids, embeddings, strict=True): + if emb is not None: + molgnet_dict[drug_id] = emb + + with open(output_path, "wb") as f: + pickle.dump(molgnet_dict, f) + + # Also save per-drug CSVs for DIPK compatibility + dataset_dir = output_path.parent + out_drugs_dir = dataset_dir / "DIPK_features" / "Drugs" + os.makedirs(out_drugs_dir, exist_ok=True) + + for drug_id, emb in molgnet_dict.items(): + arr = emb.cpu().detach().numpy() if isinstance(emb, torch.Tensor) else np.array(emb) + df_emb = pd.DataFrame(arr) + out_csv = out_drugs_dir / f"MolGNet_{drug_id}.csv" + df_emb.to_csv(out_csv, sep="\t", index=False) + + def load_embeddings(self, data_path: str, dataset_name: str) -> FeatureDataset: + """Load pre-generated MolGNet embeddings from disk. + + :param data_path: Path to the data directory + :param dataset_name: Name of the dataset + :returns: FeatureDataset containing the embeddings + :raises FileNotFoundError: If the embeddings file is not found + """ + embeddings_file = Path(data_path) / dataset_name / self.get_output_filename() + + if not embeddings_file.exists(): + raise FileNotFoundError( + f"MolGNet embeddings file not found: {embeddings_file}. " + f"Use load_or_generate() to automatically generate embeddings." 
+ ) + + with open(embeddings_file, "rb") as f: + molgnet_dict = pickle.load(f) # noqa: S301 + + feature_name = self.get_feature_name() + features = {} + + for drug_id, emb in molgnet_dict.items(): + features[str(drug_id)] = {feature_name: emb} + + return FeatureDataset(features) + + +def main(): + """Process drug SMILES and save MolGNet embeddings from command line.""" + parser = argparse.ArgumentParser(description="Generate MolGNet embeddings for drugs.") + parser.add_argument("dataset_name", type=str, help="The name of the dataset to process.") + parser.add_argument("--data_path", type=str, default="data", help="Path to the data folder") + parser.add_argument( + "--checkpoint", + type=str, + default="data/MolGNet.pt", + help="Path to MolGNet checkpoint (can be obtained from Zenodo: https://doi.org/10.5281/zenodo.12633909)", + ) + parser.add_argument("--device", type=str, default="cpu", help="Torch device (cpu or cuda)") + args = parser.parse_args() + + featurizer = MolGNetFeaturizer(checkpoint_path=args.checkpoint, device=args.device) + featurizer.generate_embeddings(args.data_path, args.dataset_name) + + +if __name__ == "__main__": + main() diff --git a/drevalpy/models/DrugGNN/drug_gnn.py b/drevalpy/models/DrugGNN/drug_gnn.py index 89f0fb7d..de379e36 100644 --- a/drevalpy/models/DrugGNN/drug_gnn.py +++ b/drevalpy/models/DrugGNN/drug_gnn.py @@ -441,7 +441,7 @@ def load_drug_features(self, data_path: str, dataset_name: str) -> FeatureDatase if not graph_path.exists(): raise FileNotFoundError( f"Drug graph directory not found at {graph_path}. " - f"Please run 'create_drug_graphs.py' for the {dataset_name} dataset." + f"Please use DrugGraphFeaturizer to generate graphs for the {dataset_name} dataset." 
) drug_graphs = {} diff --git a/drevalpy/models/SimpleNeuralNetwork/simple_neural_network.py b/drevalpy/models/SimpleNeuralNetwork/simple_neural_network.py index 837e02a9..9fed75da 100644 --- a/drevalpy/models/SimpleNeuralNetwork/simple_neural_network.py +++ b/drevalpy/models/SimpleNeuralNetwork/simple_neural_network.py @@ -7,12 +7,11 @@ import joblib import numpy as np -import pandas as pd import torch from sklearn.preprocessing import StandardScaler from drevalpy.datasets.dataset import DrugResponseDataset, FeatureDataset -from drevalpy.datasets.utils import CELL_LINE_IDENTIFIER +from drevalpy.datasets.featurizer import ChemBERTaFeaturizer, PCAFeaturizer from ..drp_model import DRPModel from ..utils import load_and_select_gene_features, load_drug_fingerprint_features, scale_gene_expression @@ -273,26 +272,15 @@ def load_drug_features(self, data_path: str, dataset_name: str) -> FeatureDatase """ Loads the ChemBERTa embeddings. + Uses the ChemBERTaFeaturizer to load pre-generated embeddings or generate + them automatically if they don't exist. + :param data_path: Path to the ChemBERTa embeddings, e.g., data/ :param dataset_name: name of the dataset, e.g., GDSC1 :returns: FeatureDataset containing the ChemBERTa embeddings - :raises FileNotFoundError: if the ChemBERTa embeddings file is not found """ - chemberta_file = os.path.join(data_path, dataset_name, "drug_chemberta_embeddings.csv") - if not os.path.exists(chemberta_file): - raise FileNotFoundError( - f"ChemBERTa embeddings file not found: {chemberta_file}. " - "Please create it first with the respective drug_featurizer." 
- ) - - chemberta_df = pd.read_csv(chemberta_file, dtype={"pubchem_id": str}) - features = {} - for _, row in chemberta_df.iterrows(): - drug_id = row["pubchem_id"] - embedding = row.drop("pubchem_id").to_numpy(dtype=np.float32) - features[drug_id] = {"chemberta_embeddings": embedding} - - return FeatureDataset(features) + featurizer = ChemBERTaFeaturizer(device="cuda" if torch.cuda.is_available() else "cpu") + return featurizer.load_or_generate(data_path, dataset_name) class PCANeuralNetwork(SimpleNeuralNetwork): @@ -313,24 +301,13 @@ def load_cell_line_features(self, data_path: str, dataset_name: str) -> FeatureD """ Loads the PCA-transformed gene expression features. + Uses the PCAFeaturizer to load pre-generated embeddings or generate + them automatically if they don't exist. + :param data_path: Path to the data, e.g., data/ :param dataset_name: name of the dataset, e.g., GDSC1 :returns: FeatureDataset containing the PCA features - :raises FileNotFoundError: if the PCA features file is not found """ n_components = self.hyperparameters.get("n_components", 100) - pca_file = os.path.join(data_path, dataset_name, f"cell_line_gene_expression_pca_{n_components}.csv") - if not os.path.exists(pca_file): - raise FileNotFoundError( - f"PCA features file not found: {pca_file}. " - f"Please create it first with create_transcriptome_pca.py using --n_components {n_components}." 
- ) - - pca_df = pd.read_csv(pca_file, dtype={CELL_LINE_IDENTIFIER: str}) - features = {} - for _, row in pca_df.iterrows(): - cell_line_id = row[CELL_LINE_IDENTIFIER] - embedding = row.drop(CELL_LINE_IDENTIFIER).to_numpy(dtype=np.float32) - features[cell_line_id] = {"gene_expression_pca": embedding} - - return FeatureDataset(features) + featurizer = PCAFeaturizer(n_components=n_components, omics_types="gene_expression") + return featurizer.load_or_generate(data_path, dataset_name) diff --git a/tests/test_featurizers.py b/tests/test_featurizers.py index 08d78368..8e766194 100644 --- a/tests/test_featurizers.py +++ b/tests/test_featurizers.py @@ -1,10 +1,11 @@ -"""Tests for drug featurizers.""" +"""Tests for drug and cell line featurizers.""" import sys from unittest.mock import patch import numpy as np import pandas as pd +import pytest import torch @@ -15,10 +16,10 @@ def test_chemberta_featurizer(tmp_path): :param tmp_path: Temporary path provided by pytest. """ try: - import drevalpy.datasets.featurizer.create_chemberta_drug_embeddings as chemberta + from drevalpy.datasets.featurizer import ChemBERTaFeaturizer except ImportError: - print("transformers package not installed; skipping ChemBERTa featurizer test.") - return + pytest.skip("transformers package not installed; skipping ChemBERTa featurizer test.") + dataset = "testset" data_dir = tmp_path / dataset data_dir.mkdir(parents=True) @@ -27,20 +28,54 @@ def test_chemberta_featurizer(tmp_path): df = pd.DataFrame({"pubchem_id": ["X1"], "canonical_smiles": ["CCO"]}) (data_dir / "drug_smiles.csv").write_text(df.to_csv(index=False)) - fake_embedding = [1.0, 2.0, 3.0] + fake_embedding = np.array([1.0, 2.0, 3.0], dtype=np.float32) - with patch.object(chemberta, "_smiles_to_chemberta", return_value=fake_embedding), patch.object( - sys, "argv", ["prog", dataset, "--data_path", str(tmp_path)] - ): + featurizer = ChemBERTaFeaturizer(device="cpu") - chemberta.main() + with patch.object(featurizer, "featurize", 
return_value=fake_embedding): + result = featurizer.generate_embeddings(str(tmp_path), dataset) out_file = data_dir / "drug_chemberta_embeddings.csv" assert out_file.exists() df_out = pd.read_csv(out_file) assert df_out.pubchem_id.tolist() == ["X1"] - assert df_out.iloc[0, 1:].tolist() == fake_embedding + assert df_out.iloc[0, 1:].tolist() == fake_embedding.tolist() + + # Test that FeatureDataset is returned correctly + assert "X1" in result.features + assert "chemberta_embeddings" in result.features["X1"] + + +def test_chemberta_featurizer_cli(tmp_path): + """ + Test ChemBERTa featurizer CLI entry point. + + :param tmp_path: Temporary path provided by pytest. + """ + try: + from drevalpy.datasets.featurizer.drug import chemberta + except ImportError: + pytest.skip("transformers package not installed; skipping ChemBERTa featurizer CLI test.") + + dataset = "testset" + data_dir = tmp_path / dataset + data_dir.mkdir(parents=True) + + # fake input CSV + df = pd.DataFrame({"pubchem_id": ["X1"], "canonical_smiles": ["CCO"]}) + (data_dir / "drug_smiles.csv").write_text(df.to_csv(index=False)) + + fake_embedding = np.array([1.0, 2.0, 3.0], dtype=np.float32) + + with ( + patch.object(chemberta.ChemBERTaFeaturizer, "featurize", return_value=fake_embedding), + patch.object(sys, "argv", ["prog", dataset, "--data_path", str(tmp_path)]), + ): + chemberta.main() + + out_file = data_dir / "drug_chemberta_embeddings.csv" + assert out_file.exists() def test_graph_featurizer(tmp_path): @@ -50,10 +85,10 @@ def test_graph_featurizer(tmp_path): :param tmp_path: Temporary path provided by pytest. 
""" try: - import drevalpy.datasets.featurizer.create_drug_graphs as graphs + from drevalpy.datasets.featurizer import DrugGraphFeaturizer except ImportError: - print("rdkit package not installed; skipping graph featurizer test.") - return + pytest.skip("rdkit package not installed; skipping graph featurizer test.") + dataset = "testset" data_dir = tmp_path / dataset data_dir.mkdir(parents=True) @@ -62,9 +97,39 @@ def test_graph_featurizer(tmp_path): df = pd.DataFrame({"pubchem_id": ["D1"], "canonical_smiles": ["CCO"]}) (data_dir / "drug_smiles.csv").write_text(df.to_csv(index=False)) - # run main exactly as the script would - sys.argv = ["prog", dataset, "--data_path", str(tmp_path)] - graphs.main() + featurizer = DrugGraphFeaturizer() + result = featurizer.generate_embeddings(str(tmp_path), dataset) + + # expected output file + out_file = data_dir / "drug_graphs" / "D1.pt" + assert out_file.exists() + + # Test that FeatureDataset is returned correctly + assert "D1" in result.features + assert "drug_graphs" in result.features["D1"] + + +def test_graph_featurizer_cli(tmp_path): + """ + Test graph featurizer CLI entry point. + + :param tmp_path: Temporary path provided by pytest. + """ + try: + from drevalpy.datasets.featurizer.drug import drug_graph + except ImportError: + pytest.skip("rdkit package not installed; skipping graph featurizer CLI test.") + + dataset = "testset" + data_dir = tmp_path / dataset + data_dir.mkdir(parents=True) + + # write minimal SMILES CSV + df = pd.DataFrame({"pubchem_id": ["D1"], "canonical_smiles": ["CCO"]}) + (data_dir / "drug_smiles.csv").write_text(df.to_csv(index=False)) + + with patch.object(sys, "argv", ["prog", dataset, "--data_path", str(tmp_path)]): + drug_graph.main() # expected output file out_file = data_dir / "drug_graphs" / "D1.pt" @@ -78,10 +143,54 @@ def test_molgnet_featurizer(tmp_path): :param tmp_path: Temporary path provided by pytest. 
""" try: - import drevalpy.datasets.featurizer.create_molgnet_embeddings as molg + from drevalpy.datasets.featurizer import MolGNetFeaturizer + from drevalpy.datasets.featurizer.drug import molgnet except ImportError: - print("rdkit package not installed; skipping molgnet featurizer test.") - return + pytest.skip("rdkit package not installed; skipping molgnet featurizer test.") + + ds = "testset" + ds_dir = tmp_path / ds + ds_dir.mkdir(parents=True) + + # minimal SMILES CSV + df = pd.DataFrame({"pubchem_id": ["D1"], "canonical_smiles": ["CCO"]}) + (ds_dir / "drug_smiles.csv").write_text(df.to_csv(index=False)) + + # Create a fake checkpoint file + checkpoint_path = str(tmp_path / "MolGNet.pt") + + featurizer = MolGNetFeaturizer(checkpoint_path=checkpoint_path, device="cpu") + + with ( + # we dont need real model weights for this test, takes too long to load + patch("drevalpy.datasets.featurizer.drug.molgnet.torch.load", return_value={}), + # prevent load_state_dict from complaining + patch.object(molgnet.MolGNet, "load_state_dict", return_value=None), + # cheap forward pass + patch.object(molgnet.MolGNet, "forward", return_value=torch.zeros((1, 768))), + ): + result = featurizer.generate_embeddings(str(tmp_path), ds) + + # verify outputs + assert (ds_dir / "DIPK_features/Drugs" / "MolGNet_D1.csv").exists() + assert (ds_dir / "MolGNet_dict.pkl").exists() + + # Test that FeatureDataset is returned correctly + assert "D1" in result.features + assert "molgnet_embeddings" in result.features["D1"] + + +def test_molgnet_featurizer_cli(tmp_path): + """ + Test MolGNet featurizer CLI entry point. + + :param tmp_path: Temporary path provided by pytest. 
+ """ + try: + from drevalpy.datasets.featurizer.drug import molgnet + except ImportError: + pytest.skip("rdkit package not installed; skipping molgnet featurizer CLI test.") + ds = "testset" ds_dir = tmp_path / ds ds_dir.mkdir(parents=True) @@ -92,22 +201,19 @@ def test_molgnet_featurizer(tmp_path): with ( # we dont need real model weights for this test, takes too long to load - patch("drevalpy.datasets.featurizer.create_molgnet_embeddings.torch.load", return_value={}), + patch("drevalpy.datasets.featurizer.drug.molgnet.torch.load", return_value={}), # prevent load_state_dict from complaining - patch.object(molg.MolGNet, "load_state_dict", return_value=None), + patch.object(molgnet.MolGNet, "load_state_dict", return_value=None), # cheap forward pass - patch.object(molg.MolGNet, "forward", return_value=torch.zeros((1, 768))), - # avoid writing pickles - patch.object(molg.pickle, "dump", return_value=None), + patch.object(molgnet.MolGNet, "forward", return_value=torch.zeros((1, 768))), # simulate CLI patch.object( sys, "argv", - ["prog", ds, "--data_path", str(tmp_path), "--checkpoint", "MolGNet.pt"], + ["prog", ds, "--data_path", str(tmp_path), "--checkpoint", str(tmp_path / "MolGNet.pt")], ), ): - args = molg.parse_args() - molg.run(args) + molgnet.main() # verify outputs assert (ds_dir / "DIPK_features/Drugs" / "MolGNet_D1.csv").exists() @@ -166,10 +272,9 @@ def test_transcriptome_pca_featurizer(tmp_path): :param tmp_path: Temporary path provided by pytest. 
""" try: - import drevalpy.datasets.featurizer.create_transcriptome_pca as pca_feat + from drevalpy.datasets.featurizer import PCAFeaturizer except ImportError: - print("sklearn package not installed; skipping transcriptome PCA featurizer test.") - return + pytest.skip("sklearn package not installed; skipping transcriptome PCA featurizer test.") dataset = "testset" data_dir = tmp_path / dataset @@ -193,25 +298,153 @@ def test_transcriptome_pca_featurizer(tmp_path): (data_dir / "gene_expression.csv").write_text(ge_df.to_csv(index=False)) # Run the featurizer - with patch.object( - sys, - "argv", - ["prog", dataset, "--data_path", str(tmp_path), "--n_components", "10"], - ): - pca_feat.main() + featurizer = PCAFeaturizer(n_components=10) + result = featurizer.generate_embeddings(str(tmp_path), dataset) # Check output files output_file = data_dir / "cell_line_gene_expression_pca_10.csv" - pca_file = data_dir / "cell_line_gene_expression_pca_10_pca.pkl" - scaler_file = data_dir / "cell_line_gene_expression_pca_10_scaler.pkl" + model_file = data_dir / "cell_line_gene_expression_pca_10_models.pkl" assert output_file.exists() - assert pca_file.exists() - assert scaler_file.exists() + assert model_file.exists() # Verify output CSV structure df_out = pd.read_csv(output_file) assert "cell_line_name" in df_out.columns assert len(df_out.columns) == 11 # cell_line_name + 10 PC columns assert len(df_out) == n_cell_lines - assert all(f"PC{i + 1}" in df_out.columns for i in range(10)) + + # Test that FeatureDataset is returned correctly + assert "CL0" in result.features + assert "gene_expression_pca" in result.features["CL0"] + assert result.features["CL0"]["gene_expression_pca"].shape == (10,) + + +def test_pca_featurizer_cli(tmp_path): + """ + Test transcriptome PCA featurizer CLI entry point. + + :param tmp_path: Temporary path provided by pytest. 
+ """ + try: + from drevalpy.datasets.featurizer.cell_line import pca + except ImportError: + pytest.skip("sklearn package not installed; skipping transcriptome PCA featurizer CLI test.") + + dataset = "testset" + data_dir = tmp_path / dataset + data_dir.mkdir(parents=True) + + # Create fake gene expression CSV + n_cell_lines = 10 + n_genes = 100 + cell_line_names = [f"CL{i}" for i in range(n_cell_lines)] + gene_names = [f"GENE{i}" for i in range(n_genes)] + + np.random.seed(42) + ge_data = np.random.randn(n_cell_lines, n_genes).astype(np.float32) + + ge_df = pd.DataFrame(ge_data, index=cell_line_names, columns=gene_names) + ge_df.index.name = "cell_line_name" + ge_df = ge_df.reset_index() + + (data_dir / "gene_expression.csv").write_text(ge_df.to_csv(index=False)) + + # Run the featurizer via CLI + with patch.object( + sys, + "argv", + ["prog", dataset, "--data_path", str(tmp_path), "--n_components", "10"], + ): + pca.main() + + # Check output files + output_file = data_dir / "cell_line_gene_expression_pca_10.csv" + assert output_file.exists() + + +def test_pca_featurizer_load_or_generate(tmp_path): + """ + Test that load_or_generate loads existing embeddings or generates new ones. + + :param tmp_path: Temporary path provided by pytest. 
+ """ + try: + from drevalpy.datasets.featurizer import PCAFeaturizer + except ImportError: + pytest.skip("sklearn package not installed; skipping PCA featurizer test.") + + dataset = "testset" + data_dir = tmp_path / dataset + data_dir.mkdir(parents=True) + + # Create fake gene expression CSV + n_cell_lines = 10 + n_genes = 100 + cell_line_names = [f"CL{i}" for i in range(n_cell_lines)] + gene_names = [f"GENE{i}" for i in range(n_genes)] + + np.random.seed(42) + ge_data = np.random.randn(n_cell_lines, n_genes).astype(np.float32) + + ge_df = pd.DataFrame(ge_data, index=cell_line_names, columns=gene_names) + ge_df.index.name = "cell_line_name" + ge_df = ge_df.reset_index() + + (data_dir / "gene_expression.csv").write_text(ge_df.to_csv(index=False)) + + # First call should generate embeddings + featurizer1 = PCAFeaturizer(n_components=10) + result1 = featurizer1.load_or_generate(str(tmp_path), dataset) + + # Second call should load existing embeddings + featurizer2 = PCAFeaturizer(n_components=10) + result2 = featurizer2.load_or_generate(str(tmp_path), dataset) + + # Results should be the same + assert set(result1.features.keys()) == set(result2.features.keys()) + for cell_line_id in result1.features: + np.testing.assert_array_almost_equal( + result1.features[cell_line_id]["gene_expression_pca"], + result2.features[cell_line_id]["gene_expression_pca"], + ) + + +def test_chemberta_featurizer_load_or_generate(tmp_path): + """ + Test that load_or_generate loads existing embeddings or generates new ones. + + :param tmp_path: Temporary path provided by pytest. 
+ """ + try: + from drevalpy.datasets.featurizer import ChemBERTaFeaturizer + except ImportError: + pytest.skip("transformers package not installed; skipping ChemBERTa featurizer test.") + + dataset = "testset" + data_dir = tmp_path / dataset + data_dir.mkdir(parents=True) + + # fake input CSV + df = pd.DataFrame({"pubchem_id": ["X1", "X2"], "canonical_smiles": ["CCO", "CC"]}) + (data_dir / "drug_smiles.csv").write_text(df.to_csv(index=False)) + + fake_embedding = np.array([1.0, 2.0, 3.0], dtype=np.float32) + + featurizer1 = ChemBERTaFeaturizer(device="cpu") + + # First call should generate embeddings + with patch.object(featurizer1, "featurize", return_value=fake_embedding): + result1 = featurizer1.load_or_generate(str(tmp_path), dataset) + + # Second call should load existing embeddings (no mock needed) + featurizer2 = ChemBERTaFeaturizer(device="cpu") + result2 = featurizer2.load_or_generate(str(tmp_path), dataset) + + # Results should be the same + assert set(result1.features.keys()) == set(result2.features.keys()) + for drug_id in result1.features: + np.testing.assert_array_almost_equal( + result1.features[drug_id]["chemberta_embeddings"], + result2.features[drug_id]["chemberta_embeddings"], + ) From 77bba71adb34eee83c5d81536b3edcd5dfee050d Mon Sep 17 00:00:00 2001 From: nictru Date: Fri, 9 Jan 2026 15:33:51 +0000 Subject: [PATCH 11/15] Fix failing tests --- .gitignore | 1 + drevalpy/datasets/featurizer/cell_line/pca.py | 8 ++------ .../models/SimpleNeuralNetwork/simple_neural_network.py | 2 +- tests/test_featurizers.py | 1 - 4 files changed, 4 insertions(+), 8 deletions(-) diff --git a/.gitignore b/.gitignore index 5e873b1f..e41b2127 100644 --- a/.gitignore +++ b/.gitignore @@ -59,6 +59,7 @@ coverage.xml .hypothesis/ .pytest_cache/ cover/ +tests/lightning_logs/ # Translations *.mo diff --git a/drevalpy/datasets/featurizer/cell_line/pca.py b/drevalpy/datasets/featurizer/cell_line/pca.py index 25b095ad..630a2570 100644 --- 
a/drevalpy/datasets/featurizer/cell_line/pca.py +++ b/drevalpy/datasets/featurizer/cell_line/pca.py @@ -148,7 +148,7 @@ def load_embeddings(self, data_path: str, dataset_name: str) -> FeatureDataset: :param data_path: Path to the data directory :param dataset_name: Name of the dataset :returns: FeatureDataset containing the embeddings - :raises FileNotFoundError: If the embeddings or model file is not found + :raises FileNotFoundError: If the embeddings file is not found """ embeddings_file = Path(data_path) / dataset_name / self.get_output_filename() model_file = Path(data_path) / dataset_name / self._get_model_filename() @@ -159,17 +159,13 @@ def load_embeddings(self, data_path: str, dataset_name: str) -> FeatureDataset: f"Use load_or_generate() to automatically generate embeddings." ) - # Load fitted models if available + # Load fitted models if available (optional - only needed for transforming new data) if model_file.exists(): with open(model_file, "rb") as f: models = pickle.load(f) # noqa: S301 self._scaler = models["scaler"] self._pca = models["pca"] self._fitted = True - else: - raise FileNotFoundError( - f"Fitted model file not found: {model_file}. " f"Use generate_embeddings() to fit and save model." 
- ) # Load embeddings embeddings_df = pd.read_csv(embeddings_file, dtype={CELL_LINE_IDENTIFIER: str}) diff --git a/drevalpy/models/SimpleNeuralNetwork/simple_neural_network.py b/drevalpy/models/SimpleNeuralNetwork/simple_neural_network.py index 9fed75da..a5fe20bc 100644 --- a/drevalpy/models/SimpleNeuralNetwork/simple_neural_network.py +++ b/drevalpy/models/SimpleNeuralNetwork/simple_neural_network.py @@ -309,5 +309,5 @@ def load_cell_line_features(self, data_path: str, dataset_name: str) -> FeatureD :returns: FeatureDataset containing the PCA features """ n_components = self.hyperparameters.get("n_components", 100) - featurizer = PCAFeaturizer(n_components=n_components, omics_types="gene_expression") + featurizer = PCAFeaturizer(n_components=n_components) return featurizer.load_or_generate(data_path, dataset_name) diff --git a/tests/test_featurizers.py b/tests/test_featurizers.py index 8e766194..f27e178a 100644 --- a/tests/test_featurizers.py +++ b/tests/test_featurizers.py @@ -6,7 +6,6 @@ import numpy as np import pandas as pd import pytest -import torch def test_chemberta_featurizer(tmp_path): From ea77c5b69f82ab441e391ddd0c51416ae3c99909 Mon Sep 17 00:00:00 2001 From: nictru Date: Fri, 9 Jan 2026 16:04:59 +0000 Subject: [PATCH 12/15] Implement featurizer mixins --- drevalpy/datasets/featurizer/__init__.py | 25 +++++++++++ .../datasets/featurizer/cell_line/__init__.py | 3 +- drevalpy/datasets/featurizer/cell_line/pca.py | 41 +++++++++++++++++ drevalpy/datasets/featurizer/drug/__init__.py | 9 ++-- .../datasets/featurizer/drug/chemberta.py | 45 +++++++++++++++++-- .../datasets/featurizer/drug/drug_graph.py | 30 +++++++++++++ drevalpy/datasets/featurizer/drug/molgnet.py | 41 +++++++++++++++++ .../simple_neural_network.py | 35 ++------------- 8 files changed, 190 insertions(+), 39 deletions(-) diff --git a/drevalpy/datasets/featurizer/__init__.py b/drevalpy/datasets/featurizer/__init__.py index 3a178ca9..5ba97ed4 100644 --- a/drevalpy/datasets/featurizer/__init__.py 
+++ b/drevalpy/datasets/featurizer/__init__.py @@ -13,6 +13,12 @@ - CellLineFeaturizer: Abstract base class for cell line featurizers - PCAFeaturizer: PCA dimensionality reduction for omics data +Mixins for DRP Models: + - ChemBERTaMixin: Provides load_drug_features using ChemBERTa + - DrugGraphMixin: Provides load_drug_features using DrugGraphFeaturizer + - MolGNetMixin: Provides load_drug_features using MolGNet + - PCAMixin: Provides load_cell_line_features using PCA + Example usage:: from drevalpy.datasets.featurizer import ChemBERTaFeaturizer, PCAFeaturizer @@ -24,20 +30,34 @@ # Cell line features cell_featurizer = PCAFeaturizer(n_components=100) cell_features = cell_featurizer.load_or_generate("data", "GDSC1") + +Example using mixins in a model:: + + from drevalpy.models.drp_model import DRPModel + from drevalpy.datasets.featurizer import ChemBERTaMixin, PCAMixin + + class MyModel(ChemBERTaMixin, PCAMixin, DRPModel): + # ChemBERTaMixin provides load_drug_features + # PCAMixin provides load_cell_line_features + ... 
""" # Cell line featurizers from .cell_line import ( CellLineFeaturizer, PCAFeaturizer, + PCAMixin, ) # Drug featurizers from .drug import ( ChemBERTaFeaturizer, + ChemBERTaMixin, DrugFeaturizer, DrugGraphFeaturizer, + DrugGraphMixin, MolGNetFeaturizer, + MolGNetMixin, ) __all__ = [ @@ -49,4 +69,9 @@ # Cell line featurizers "CellLineFeaturizer", "PCAFeaturizer", + # Mixins + "ChemBERTaMixin", + "DrugGraphMixin", + "MolGNetMixin", + "PCAMixin", ] diff --git a/drevalpy/datasets/featurizer/cell_line/__init__.py b/drevalpy/datasets/featurizer/cell_line/__init__.py index 47a5f613..8c11b2ba 100644 --- a/drevalpy/datasets/featurizer/cell_line/__init__.py +++ b/drevalpy/datasets/featurizer/cell_line/__init__.py @@ -1,9 +1,10 @@ """Cell line featurizers for converting omics data to embeddings.""" from .base import CellLineFeaturizer -from .pca import PCAFeaturizer +from .pca import PCAFeaturizer, PCAMixin __all__ = [ "CellLineFeaturizer", "PCAFeaturizer", + "PCAMixin", ] diff --git a/drevalpy/datasets/featurizer/cell_line/pca.py b/drevalpy/datasets/featurizer/cell_line/pca.py index 630a2570..9813625e 100644 --- a/drevalpy/datasets/featurizer/cell_line/pca.py +++ b/drevalpy/datasets/featurizer/cell_line/pca.py @@ -180,6 +180,47 @@ def load_embeddings(self, data_path: str, dataset_name: str) -> FeatureDataset: return FeatureDataset(features) +class PCAMixin: + """Mixin that provides PCA-transformed gene expression loading for DRP models. + + This mixin implements load_cell_line_features using the PCAFeaturizer. + It automatically generates embeddings if they don't exist. + + The number of PCA components can be configured via: + - hyperparameters['n_components'] (if the model has hyperparameters) + - pca_n_components class attribute (default: 100) + + Example usage:: + + from drevalpy.models.drp_model import DRPModel + from drevalpy.datasets.featurizer.cell_line.pca import PCAMixin + + class MyModel(PCAMixin, DRPModel): + cell_line_views = ["gene_expression_pca"] + ... 
+ """ + + pca_n_components: int = 100 + + def load_cell_line_features(self, data_path: str, dataset_name: str) -> FeatureDataset: + """Load PCA-transformed gene expression features. + + Uses the PCAFeaturizer to load pre-generated embeddings or generate + them automatically if they don't exist. + + :param data_path: Path to the data directory, e.g., 'data/' + :param dataset_name: Name of the dataset, e.g., 'GDSC1' + :returns: FeatureDataset containing the PCA-transformed gene expression + """ + # Try to get n_components from hyperparameters if available + n_components = self.pca_n_components + if hasattr(self, "hyperparameters") and self.hyperparameters is not None: + n_components = self.hyperparameters.get("n_components", n_components) + + featurizer = PCAFeaturizer(n_components=n_components) + return featurizer.load_or_generate(data_path, dataset_name) + + def main(): """Generate PCA embeddings for cell line gene expression from command line.""" parser = argparse.ArgumentParser(description="Generate PCA embeddings for cell line gene expression.") diff --git a/drevalpy/datasets/featurizer/drug/__init__.py b/drevalpy/datasets/featurizer/drug/__init__.py index c392efa4..c7c5fbd7 100644 --- a/drevalpy/datasets/featurizer/drug/__init__.py +++ b/drevalpy/datasets/featurizer/drug/__init__.py @@ -1,13 +1,16 @@ """Drug featurizers for converting drug representations to embeddings.""" from .base import DrugFeaturizer -from .chemberta import ChemBERTaFeaturizer -from .drug_graph import DrugGraphFeaturizer -from .molgnet import MolGNetFeaturizer +from .chemberta import ChemBERTaFeaturizer, ChemBERTaMixin +from .drug_graph import DrugGraphFeaturizer, DrugGraphMixin +from .molgnet import MolGNetFeaturizer, MolGNetMixin __all__ = [ "DrugFeaturizer", "ChemBERTaFeaturizer", + "ChemBERTaMixin", "DrugGraphFeaturizer", + "DrugGraphMixin", "MolGNetFeaturizer", + "MolGNetMixin", ] diff --git a/drevalpy/datasets/featurizer/drug/chemberta.py 
b/drevalpy/datasets/featurizer/drug/chemberta.py index 8e564a0e..a49b6a3b 100644 --- a/drevalpy/datasets/featurizer/drug/chemberta.py +++ b/drevalpy/datasets/featurizer/drug/chemberta.py @@ -3,6 +3,9 @@ import argparse import numpy as np +import torch + +from drevalpy.datasets.dataset import FeatureDataset from .base import DrugFeaturizer @@ -35,7 +38,6 @@ def _load_model(self): """ if self._model is None: try: - import torch # noqa: F401 from transformers import AutoModel, AutoTokenizer except ImportError: raise ImportError( @@ -53,8 +55,6 @@ def featurize(self, smiles: str) -> np.ndarray: :param smiles: SMILES string representing the drug :returns: ChemBERTa embedding as numpy array """ - import torch - self._load_model() inputs = self._tokenizer(smiles, return_tensors="pt", truncation=True) @@ -85,6 +85,45 @@ def get_output_filename(cls) -> str: return "drug_chemberta_embeddings.csv" +class ChemBERTaMixin: + """Mixin that provides ChemBERTa drug embeddings loading for DRP models. + + This mixin implements load_drug_features using the ChemBERTaFeaturizer. + It automatically generates embeddings if they don't exist. + + Class attributes that can be overridden: + - chemberta_device: Device for ChemBERTa model ('cpu', 'cuda', or 'auto') + + Example usage:: + + from drevalpy.models.drp_model import DRPModel + from drevalpy.datasets.featurizer.drug.chemberta import ChemBERTaMixin + + class MyModel(ChemBERTaMixin, DRPModel): + drug_views = ["chemberta_embeddings"] + ... + """ + + chemberta_device: str = "auto" + + def load_drug_features(self, data_path: str, dataset_name: str) -> FeatureDataset: + """Load ChemBERTa drug embeddings. + + Uses the ChemBERTaFeaturizer to load pre-generated embeddings or generate + them automatically if they don't exist. 
+ + :param data_path: Path to the data directory, e.g., 'data/' + :param dataset_name: Name of the dataset, e.g., 'GDSC1' + :returns: FeatureDataset containing the ChemBERTa embeddings + """ + device = self.chemberta_device + if device == "auto": + device = "cuda" if torch.cuda.is_available() else "cpu" + + featurizer = ChemBERTaFeaturizer(device=device) + return featurizer.load_or_generate(data_path, dataset_name) + + def main(): """Process drug SMILES and save ChemBERTa embeddings from command line.""" parser = argparse.ArgumentParser(description="Generate ChemBERTa embeddings for drugs.") diff --git a/drevalpy/datasets/featurizer/drug/drug_graph.py b/drevalpy/datasets/featurizer/drug/drug_graph.py index e76bde87..6692d8b7 100644 --- a/drevalpy/datasets/featurizer/drug/drug_graph.py +++ b/drevalpy/datasets/featurizer/drug/drug_graph.py @@ -205,6 +205,36 @@ def load_or_generate(self, data_path: str, dataset_name: str) -> FeatureDataset: return self.generate_embeddings(data_path, dataset_name) +class DrugGraphMixin: + """Mixin that provides drug graph loading for DRP models. + + This mixin implements load_drug_features using the DrugGraphFeaturizer. + It automatically generates graphs if they don't exist. + + Example usage:: + + from drevalpy.models.drp_model import DRPModel + from drevalpy.datasets.featurizer.drug.drug_graph import DrugGraphMixin + + class MyModel(DrugGraphMixin, DRPModel): + drug_views = ["drug_graphs"] + ... + """ + + def load_drug_features(self, data_path: str, dataset_name: str) -> FeatureDataset: + """Load drug graph features. + + Uses the DrugGraphFeaturizer to load pre-generated graphs or generate + them automatically if they don't exist. 
+ + :param data_path: Path to the data directory, e.g., 'data/' + :param dataset_name: Name of the dataset, e.g., 'GDSC1' + :returns: FeatureDataset containing the drug graphs + """ + featurizer = DrugGraphFeaturizer() + return featurizer.load_or_generate(data_path, dataset_name) + + def main(): """Process drug SMILES and save molecular graphs from command line.""" parser = argparse.ArgumentParser(description="Generate molecular graphs for drugs.") diff --git a/drevalpy/datasets/featurizer/drug/molgnet.py b/drevalpy/datasets/featurizer/drug/molgnet.py index 6bce7467..e0e03884 100644 --- a/drevalpy/datasets/featurizer/drug/molgnet.py +++ b/drevalpy/datasets/featurizer/drug/molgnet.py @@ -795,6 +795,47 @@ def load_embeddings(self, data_path: str, dataset_name: str) -> FeatureDataset: return FeatureDataset(features) +class MolGNetMixin: + """Mixin that provides MolGNet drug embeddings loading for DRP models. + + This mixin implements load_drug_features using the MolGNetFeaturizer. + It automatically generates embeddings if they don't exist. + + Class attributes that can be overridden: + - molgnet_checkpoint_path: Path to MolGNet checkpoint (default: 'data/MolGNet.pt') + - molgnet_device: Device for MolGNet model ('cpu', 'cuda', or 'auto') + + Example usage:: + + from drevalpy.models.drp_model import DRPModel + from drevalpy.datasets.featurizer.drug.molgnet import MolGNetMixin + + class MyModel(MolGNetMixin, DRPModel): + drug_views = ["molgnet_embeddings"] + ... + """ + + molgnet_checkpoint_path: str = "data/MolGNet.pt" + molgnet_device: str = "auto" + + def load_drug_features(self, data_path: str, dataset_name: str) -> FeatureDataset: + """Load MolGNet drug embeddings. + + Uses the MolGNetFeaturizer to load pre-generated embeddings or generate + them automatically if they don't exist. 
+ + :param data_path: Path to the data directory, e.g., 'data/' + :param dataset_name: Name of the dataset, e.g., 'GDSC1' + :returns: FeatureDataset containing the MolGNet embeddings + """ + device = self.molgnet_device + if device == "auto": + device = "cuda" if torch.cuda.is_available() else "cpu" + + featurizer = MolGNetFeaturizer(checkpoint_path=self.molgnet_checkpoint_path, device=device) + return featurizer.load_or_generate(data_path, dataset_name) + + def main(): """Process drug SMILES and save MolGNet embeddings from command line.""" parser = argparse.ArgumentParser(description="Generate MolGNet embeddings for drugs.") diff --git a/drevalpy/models/SimpleNeuralNetwork/simple_neural_network.py b/drevalpy/models/SimpleNeuralNetwork/simple_neural_network.py index a5fe20bc..6185f4bc 100644 --- a/drevalpy/models/SimpleNeuralNetwork/simple_neural_network.py +++ b/drevalpy/models/SimpleNeuralNetwork/simple_neural_network.py @@ -11,7 +11,7 @@ from sklearn.preprocessing import StandardScaler from drevalpy.datasets.dataset import DrugResponseDataset, FeatureDataset -from drevalpy.datasets.featurizer import ChemBERTaFeaturizer, PCAFeaturizer +from drevalpy.datasets.featurizer import ChemBERTaMixin, PCAMixin from ..drp_model import DRPModel from ..utils import load_and_select_gene_features, load_drug_fingerprint_features, scale_gene_expression @@ -254,7 +254,7 @@ def load(cls, directory: str) -> "SimpleNeuralNetwork": return instance -class ChemBERTaNeuralNetwork(SimpleNeuralNetwork): +class ChemBERTaNeuralNetwork(ChemBERTaMixin, SimpleNeuralNetwork): """ChemBERTa Neural Network model using gene expression and ChemBERTa drug embeddings.""" drug_views = ["chemberta_embeddings"] @@ -268,22 +268,8 @@ def get_model_name(cls) -> str: """ return "ChemBERTaNeuralNetwork" - def load_drug_features(self, data_path: str, dataset_name: str) -> FeatureDataset: - """ - Loads the ChemBERTa embeddings. 
- - Uses the ChemBERTaFeaturizer to load pre-generated embeddings or generate - them automatically if they don't exist. - - :param data_path: Path to the ChemBERTa embeddings, e.g., data/ - :param dataset_name: name of the dataset, e.g., GDSC1 - :returns: FeatureDataset containing the ChemBERTa embeddings - """ - featurizer = ChemBERTaFeaturizer(device="cuda" if torch.cuda.is_available() else "cpu") - return featurizer.load_or_generate(data_path, dataset_name) - -class PCANeuralNetwork(SimpleNeuralNetwork): +class PCANeuralNetwork(PCAMixin, SimpleNeuralNetwork): """Neural Network model using PCA-transformed gene expression and fingerprints.""" cell_line_views = ["gene_expression_pca"] @@ -296,18 +282,3 @@ def get_model_name(cls) -> str: :returns: PCANeuralNetwork """ return "PCANeuralNetwork" - - def load_cell_line_features(self, data_path: str, dataset_name: str) -> FeatureDataset: - """ - Loads the PCA-transformed gene expression features. - - Uses the PCAFeaturizer to load pre-generated embeddings or generate - them automatically if they don't exist. 
- - :param data_path: Path to the data, e.g., data/ - :param dataset_name: name of the dataset, e.g., GDSC1 - :returns: FeatureDataset containing the PCA features - """ - n_components = self.hyperparameters.get("n_components", 100) - featurizer = PCAFeaturizer(n_components=n_components) - return featurizer.load_or_generate(data_path, dataset_name) From eb7c0bbb7287869a205166380890835e5ecbd0cb Mon Sep 17 00:00:00 2001 From: nictru Date: Fri, 9 Jan 2026 16:18:24 +0000 Subject: [PATCH 13/15] Remove redundant get_model_name functions --- drevalpy/models/DrugGNN/drug_gnn.py | 8 --- drevalpy/models/MOLIR/molir.py | 9 ---- drevalpy/models/SRMF/srmf.py | 9 ---- .../multiomics_neural_network.py | 9 ---- .../simple_neural_network.py | 27 ---------- drevalpy/models/SuperFELTR/superfeltr.py | 9 ---- .../baselines/multi_omics_random_forest.py | 9 ---- drevalpy/models/baselines/naive_pred.py | 54 ------------------- .../baselines/singledrug_elastic_net.py | 18 ------- .../baselines/singledrug_random_forest.py | 18 ------- drevalpy/models/baselines/sklearn_models.py | 36 ------------- drevalpy/models/drp_model.py | 4 +- 12 files changed, 2 insertions(+), 208 deletions(-) diff --git a/drevalpy/models/DrugGNN/drug_gnn.py b/drevalpy/models/DrugGNN/drug_gnn.py index de379e36..ff8f3bb3 100644 --- a/drevalpy/models/DrugGNN/drug_gnn.py +++ b/drevalpy/models/DrugGNN/drug_gnn.py @@ -234,14 +234,6 @@ def __init__(self): self.model: DrugGNNModule | None = None self.hyperparameters = {} - @classmethod - def get_model_name(cls) -> str: - """Return the name of the model. - - :return: The name of the model. - """ - return "DrugGNN" - @property def cell_line_views(self) -> list[str]: """Return the sources the model needs as input for describing the cell line. 
diff --git a/drevalpy/models/MOLIR/molir.py b/drevalpy/models/MOLIR/molir.py index 4aaab346..4c456c59 100644 --- a/drevalpy/models/MOLIR/molir.py +++ b/drevalpy/models/MOLIR/molir.py @@ -48,15 +48,6 @@ def __init__(self) -> None: self.gene_expression_scaler = StandardScaler() self.selector: VarianceFeatureSelector | None = None - @classmethod - def get_model_name(cls) -> str: - """ - Returns the model name. - - :returns: MOLIR - """ - return "MOLIR" - def build_model(self, hyperparameters: dict[str, Any]) -> None: """ Builds the model from hyperparameters. diff --git a/drevalpy/models/SRMF/srmf.py b/drevalpy/models/SRMF/srmf.py index 8e1aaeb0..1d1a5eda 100644 --- a/drevalpy/models/SRMF/srmf.py +++ b/drevalpy/models/SRMF/srmf.py @@ -52,15 +52,6 @@ def __init__(self) -> None: self.max_iter: int = 50 self.seed: int = 1 - @classmethod - def get_model_name(cls) -> str: - """ - Returns the model name. - - :returns: SRMF - """ - return "SRMF" - def build_model(self, hyperparameters: dict) -> None: """ Initializes hyperparameters for SRMF model. diff --git a/drevalpy/models/SimpleNeuralNetwork/multiomics_neural_network.py b/drevalpy/models/SimpleNeuralNetwork/multiomics_neural_network.py index 39b778a1..17e95f14 100644 --- a/drevalpy/models/SimpleNeuralNetwork/multiomics_neural_network.py +++ b/drevalpy/models/SimpleNeuralNetwork/multiomics_neural_network.py @@ -42,15 +42,6 @@ def __init__(self): self.pca_ncomp = 100 self.gene_expression_scaler = StandardScaler() - @classmethod - def get_model_name(cls) -> str: - """ - Returns the model name. - - :returns: MultiOmicsNeuralNetwork - """ - return "MultiOmicsNeuralNetwork" - def build_model(self, hyperparameters: dict): """ Builds the model from hyperparameters. 
diff --git a/drevalpy/models/SimpleNeuralNetwork/simple_neural_network.py b/drevalpy/models/SimpleNeuralNetwork/simple_neural_network.py index 6185f4bc..4df87970 100644 --- a/drevalpy/models/SimpleNeuralNetwork/simple_neural_network.py +++ b/drevalpy/models/SimpleNeuralNetwork/simple_neural_network.py @@ -35,15 +35,6 @@ def __init__(self): self.model = None self.gene_expression_scaler = StandardScaler() - @classmethod - def get_model_name(cls) -> str: - """ - Returns the model name. - - :returns: SimpleNeuralNetwork - """ - return "SimpleNeuralNetwork" - def build_model(self, hyperparameters: dict): """ Builds the model from hyperparameters. @@ -259,26 +250,8 @@ class ChemBERTaNeuralNetwork(ChemBERTaMixin, SimpleNeuralNetwork): drug_views = ["chemberta_embeddings"] - @classmethod - def get_model_name(cls) -> str: - """ - Returns the model name. - - :returns: ChemBERTaNeuralNetwork - """ - return "ChemBERTaNeuralNetwork" - class PCANeuralNetwork(PCAMixin, SimpleNeuralNetwork): """Neural Network model using PCA-transformed gene expression and fingerprints.""" cell_line_views = ["gene_expression_pca"] - - @classmethod - def get_model_name(cls) -> str: - """ - Returns the model name. - - :returns: PCANeuralNetwork - """ - return "PCANeuralNetwork" diff --git a/drevalpy/models/SuperFELTR/superfeltr.py b/drevalpy/models/SuperFELTR/superfeltr.py index 6334dfc5..a4058157 100644 --- a/drevalpy/models/SuperFELTR/superfeltr.py +++ b/drevalpy/models/SuperFELTR/superfeltr.py @@ -61,15 +61,6 @@ def __init__(self) -> None: self.copy_number_variation_features = None self.selectors: dict[str, VarianceFeatureSelector] = {} - @classmethod - def get_model_name(cls) -> str: - """ - Returns the model name. - - :returns: SuperFELTR - """ - return "SuperFELTR" - def build_model(self, hyperparameters) -> None: """ Builds the model from hyperparameters. 
diff --git a/drevalpy/models/baselines/multi_omics_random_forest.py b/drevalpy/models/baselines/multi_omics_random_forest.py index 005bb14e..afe3e5c1 100644 --- a/drevalpy/models/baselines/multi_omics_random_forest.py +++ b/drevalpy/models/baselines/multi_omics_random_forest.py @@ -32,15 +32,6 @@ def __init__(self): self.pca = None self.pca_ncomp = 100 - @classmethod - def get_model_name(cls) -> str: - """ - Returns the model name. - - :returns: MultiOmicsRandomForest - """ - return "MultiOmicsRandomForest" - def build_model(self, hyperparameters: dict): """ Builds the model from hyperparameters. diff --git a/drevalpy/models/baselines/naive_pred.py b/drevalpy/models/baselines/naive_pred.py index db1391c6..759bd499 100644 --- a/drevalpy/models/baselines/naive_pred.py +++ b/drevalpy/models/baselines/naive_pred.py @@ -109,15 +109,6 @@ def __init__(self): """ super().__init__() - @classmethod - def get_model_name(cls) -> str: - """ - Returns the model name. - - :returns: NaivePredictor - """ - return "NaivePredictor" - def train( self, output: DrugResponseDataset, @@ -191,15 +182,6 @@ def __init__(self): super().__init__() self.drug_means = None - @classmethod - def get_model_name(cls) -> str: - """ - Returns the model name. - - :returns: NaiveDrugMeanPredictor - """ - return "NaiveDrugMeanPredictor" - def train( self, output: DrugResponseDataset, @@ -299,15 +281,6 @@ def __init__(self): super().__init__() self.cell_line_means = None - @classmethod - def get_model_name(cls) -> str: - """ - Returns the model name. - - :returns: NaiveCellLineMeanPredictor - """ - return "NaiveCellLineMeanPredictor" - def train( self, output: DrugResponseDataset, @@ -406,15 +379,6 @@ def __init__(self): super().__init__() self.tissue_means = None - @classmethod - def get_model_name(cls) -> str: - """ - Returns the model name. 
- - :returns: NaiveTissueMeanPredictor - """ - return "NaiveTissueMeanPredictor" - def train( self, output: DrugResponseDataset, @@ -518,15 +482,6 @@ def __init__(self): self.cell_line_effects = {} self.drug_effects = {} - @classmethod - def get_model_name(cls) -> str: - """ - Returns the name of the model. - - :return: The name of the model as a string. - """ - return "NaiveMeanEffectsPredictor" - def train( self, output: DrugResponseDataset, @@ -643,15 +598,6 @@ def __init__(self): super().__init__() self.tissue_drug_means = None - @classmethod - def get_model_name(cls) -> str: - """ - Returns the model name. - - :returns: NaiveTissueDrugMeanPredictor - """ - return "NaiveTissueDrugMeanPredictor" - def save(self, directory: str) -> None: """ Saves the model parameters to the given directory. diff --git a/drevalpy/models/baselines/singledrug_elastic_net.py b/drevalpy/models/baselines/singledrug_elastic_net.py index 13b27379..957b39fb 100644 --- a/drevalpy/models/baselines/singledrug_elastic_net.py +++ b/drevalpy/models/baselines/singledrug_elastic_net.py @@ -35,15 +35,6 @@ def build_model(self, hyperparameters): self.model = ElasticNet(**hyperparameters) self.gene_expression_scaler = StandardScaler() - @classmethod - def get_model_name(cls) -> str: - """ - Returns the model name. - - :returns: SingleDrugElasticNet - """ - return "SingleDrugElasticNet" - def train( self, output: DrugResponseDataset, @@ -173,15 +164,6 @@ def build_model(self, hyperparameters: dict): ) super().build_model(hyperparameters) - @classmethod - def get_model_name(cls) -> str: - """ - Returns the model name. - - :returns: SingleDrugProteomicsElasticNet - """ - return "SingleDrugProteomicsElasticNet" - def load_cell_line_features(self, data_path: str, dataset_name: str) -> FeatureDataset: """ Loads the proteomics data. 
diff --git a/drevalpy/models/baselines/singledrug_random_forest.py b/drevalpy/models/baselines/singledrug_random_forest.py index c09567d7..6749a754 100644 --- a/drevalpy/models/baselines/singledrug_random_forest.py +++ b/drevalpy/models/baselines/singledrug_random_forest.py @@ -19,15 +19,6 @@ class SingleDrugRandomForest(RandomForest): drug_views = [] early_stopping = False - @classmethod - def get_model_name(cls) -> str: - """ - Returns the model name. - - :returns: SingleDrugRandomForest - """ - return "SingleDrugRandomForest" - def train( self, output: DrugResponseDataset, @@ -148,15 +139,6 @@ def build_model(self, hyperparameters: dict): normalization_width=self.normalization_width, ) - @classmethod - def get_model_name(cls) -> str: - """ - Returns the model name. - - :returns: SingleDrugProteomicsRandomForest - """ - return "SingleDrugProteomicsRandomForest" - def load_cell_line_features(self, data_path: str, dataset_name: str) -> FeatureDataset: """ Loads the proteomics features. diff --git a/drevalpy/models/baselines/sklearn_models.py b/drevalpy/models/baselines/sklearn_models.py index 425727b3..b70a6b12 100644 --- a/drevalpy/models/baselines/sklearn_models.py +++ b/drevalpy/models/baselines/sklearn_models.py @@ -39,15 +39,6 @@ def __init__(self): self.gene_expression_scaler = StandardScaler() self.proteomics_transformer = None - @classmethod - def get_model_name(cls) -> str: - """ - Returns the model name. - - :raises NotImplementedError: If the method is not implemented in the child class. - """ - raise NotImplementedError("get_model_name method has to be implemented in the child class.") - def build_model(self, hyperparameters: dict): """ Builds the model from hyperparameters. @@ -251,15 +242,6 @@ def build_model(self, hyperparameters: dict): class RandomForest(SklearnModel): """RandomForest model for drug response prediction.""" - @classmethod - def get_model_name(cls) -> str: - """ - Returns the model name. 
- - :returns: RandomForest - """ - return "RandomForest" - def build_model(self, hyperparameters: dict): """ Builds the model from hyperparameters. @@ -308,15 +290,6 @@ def build_model(self, hyperparameters: dict): class GradientBoosting(SklearnModel): """Gradient Boosting model for drug response prediction.""" - @classmethod - def get_model_name(cls) -> str: - """ - Returns the model name. - - :returns: GradientBoosting - """ - return "GradientBoosting" - def build_model(self, hyperparameters: dict): """ Builds the model from hyperparameters. @@ -375,15 +348,6 @@ def build_model(self, hyperparameters: dict): normalization_width=self.normalization_width, ) - @classmethod - def get_model_name(cls) -> str: - """ - Returns the model name. - - :returns: ProteomicsRandomForest - """ - return "ProteomicsRandomForest" - def load_cell_line_features(self, data_path: str, dataset_name: str) -> FeatureDataset: """ Loads the cell line features. diff --git a/drevalpy/models/drp_model.py b/drevalpy/models/drp_model.py index cde01b67..b3736e2b 100644 --- a/drevalpy/models/drp_model.py +++ b/drevalpy/models/drp_model.py @@ -251,14 +251,14 @@ def finish_wandb(self) -> None: self.wandb_run = None @classmethod - @abstractmethod @pipeline_function def get_model_name(cls) -> str: """ Returns the name of the model. 
- :return: model name + :return: model name (the class name) """ + return cls.__name__ @classmethod @pipeline_function From 792eb8e2954e7e91c5be00c2182c4b9e2eb01102 Mon Sep 17 00:00:00 2001 From: nictru Date: Fri, 9 Jan 2026 16:49:21 +0000 Subject: [PATCH 14/15] Fix typing issues --- drevalpy/datasets/featurizer/cell_line/base.py | 13 +++++-------- drevalpy/datasets/featurizer/cell_line/pca.py | 7 +++++-- drevalpy/datasets/featurizer/drug/chemberta.py | 4 ++++ drevalpy/datasets/featurizer/drug/drug_graph.py | 2 +- drevalpy/datasets/featurizer/drug/molgnet.py | 4 ++++ 5 files changed, 19 insertions(+), 11 deletions(-) diff --git a/drevalpy/datasets/featurizer/cell_line/base.py b/drevalpy/datasets/featurizer/cell_line/base.py index a441ed32..499321f2 100644 --- a/drevalpy/datasets/featurizer/cell_line/base.py +++ b/drevalpy/datasets/featurizer/cell_line/base.py @@ -145,21 +145,18 @@ def generate_embeddings(self, data_path: str, dataset_name: str) -> FeatureDatas omics_data = self._load_omics_data(data_path, dataset_name) # Get common cell line IDs across all omics types - cell_line_ids = None + cell_line_ids_set: set[str] = set(next(iter(omics_data.values())).index) if omics_data else set() for _omics_type, df in omics_data.items(): - if cell_line_ids is None: - cell_line_ids = set(df.index) - else: - cell_line_ids = cell_line_ids.intersection(set(df.index)) + cell_line_ids_set = cell_line_ids_set.intersection(set(df.index)) - cell_line_ids = sorted(list(cell_line_ids)) - print(f"Processing {len(cell_line_ids)} cell lines for dataset {dataset_name}...") + cell_line_ids_list = sorted(list(cell_line_ids_set)) + print(f"Processing {len(cell_line_ids_list)} cell lines for dataset {dataset_name}...") # Generate embeddings embeddings_list = [] valid_cell_line_ids = [] - for cell_line_id in cell_line_ids: + for cell_line_id in cell_line_ids_list: try: # Prepare omics data for this cell line cell_omics = {} diff --git a/drevalpy/datasets/featurizer/cell_line/pca.py b/drevalpy/datasets/featurizer/cell_line/pca.py index 
9813625e..27880d07 100644 --- a/drevalpy/datasets/featurizer/cell_line/pca.py +++ b/drevalpy/datasets/featurizer/cell_line/pca.py @@ -47,7 +47,7 @@ def featurize(self, omics_data: dict[str, np.ndarray]) -> np.ndarray: :raises RuntimeError: If the PCA model is not fitted :raises ValueError: If gene_expression data is not provided """ - if not self._fitted: + if not self._fitted or self._scaler is None or self._pca is None: raise RuntimeError("PCA model is not fitted. Call generate_embeddings() or fit() first.") if "gene_expression" not in omics_data: @@ -102,6 +102,7 @@ def generate_embeddings(self, data_path: str, dataset_name: str) -> FeatureDatas :param dataset_name: Name of the dataset :returns: FeatureDataset containing the PCA embeddings :raises FileNotFoundError: If the gene expression file is not found + :raises RuntimeError: If fitting fails """ data_dir = Path(data_path).resolve() output_file = data_dir / dataset_name / self.get_output_filename() @@ -121,7 +122,9 @@ def generate_embeddings(self, data_path: str, dataset_name: str) -> FeatureDatas # Fit the model self.fit(ge_df) - # Transform all cell lines + # Transform all cell lines (scaler and pca are guaranteed to be set after fit()) + if self._scaler is None or self._pca is None: + raise RuntimeError("Fitting failed: scaler or PCA model is None") scaled_data = self._scaler.transform(ge_df.values) embeddings = self._pca.transform(scaled_data) diff --git a/drevalpy/datasets/featurizer/drug/chemberta.py b/drevalpy/datasets/featurizer/drug/chemberta.py index a49b6a3b..9b642d77 100644 --- a/drevalpy/datasets/featurizer/drug/chemberta.py +++ b/drevalpy/datasets/featurizer/drug/chemberta.py @@ -54,9 +54,13 @@ def featurize(self, smiles: str) -> np.ndarray: :param smiles: SMILES string representing the drug :returns: ChemBERTa embedding as numpy array + :raises RuntimeError: If model is not loaded """ self._load_model() + if self._tokenizer is None or self._model is None: + raise RuntimeError("Model not 
loaded. Call _load_model() first.") + inputs = self._tokenizer(smiles, return_tensors="pt", truncation=True) inputs = {k: v.to(self.device) for k, v in inputs.items()} diff --git a/drevalpy/datasets/featurizer/drug/drug_graph.py b/drevalpy/datasets/featurizer/drug/drug_graph.py index 6692d8b7..cdd0182b 100644 --- a/drevalpy/datasets/featurizer/drug/drug_graph.py +++ b/drevalpy/datasets/featurizer/drug/drug_graph.py @@ -27,7 +27,7 @@ } # Bond feature configuration -BOND_FEATURES = { +BOND_FEATURES: dict[str, list] = { "bond_type": [], # Will be populated after rdkit import check } diff --git a/drevalpy/datasets/featurizer/drug/molgnet.py b/drevalpy/datasets/featurizer/drug/molgnet.py index e0e03884..53e872ea 100644 --- a/drevalpy/datasets/featurizer/drug/molgnet.py +++ b/drevalpy/datasets/featurizer/drug/molgnet.py @@ -707,6 +707,7 @@ def featurize(self, smiles: str) -> torch.Tensor | None: :param smiles: SMILES string representing the drug :returns: Node embeddings tensor or None if conversion fails + :raises RuntimeError: If model is not loaded """ _init_rdkit_features() self._load_model() @@ -720,6 +721,9 @@ def featurize(self, smiles: str) -> torch.Tensor | None: graph = self._add_seg(graph) graph = graph.to(self.device) + if self._model is None: + raise RuntimeError("Model not loaded. 
Call _load_model() first.") + with torch.no_grad(): embeddings = self._model(graph) From e2edb4cd35022843e3a5e5cb1a61b0945300ed50 Mon Sep 17 00:00:00 2001 From: nictru Date: Fri, 9 Jan 2026 17:02:39 +0000 Subject: [PATCH 15/15] Fix production dataset colname issue --- drevalpy/datasets/featurizer/cell_line/pca.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/drevalpy/datasets/featurizer/cell_line/pca.py b/drevalpy/datasets/featurizer/cell_line/pca.py index 27880d07..4fba4ae8 100644 --- a/drevalpy/datasets/featurizer/cell_line/pca.py +++ b/drevalpy/datasets/featurizer/cell_line/pca.py @@ -116,6 +116,9 @@ def generate_embeddings(self, data_path: str, dataset_name: str) -> FeatureDatas ge_df = pd.read_csv(ge_file, dtype={CELL_LINE_IDENTIFIER: str}) ge_df = ge_df.set_index(CELL_LINE_IDENTIFIER) + # Drop non-numeric columns (e.g., cellosaurus_id) + ge_df = ge_df.select_dtypes(include=[np.number]) + cell_line_ids = list(ge_df.index) print(f"Processing {len(cell_line_ids)} cell lines for dataset {dataset_name}...")