diff --git a/alembic.ini b/alembic.ini new file mode 100644 index 00000000..90fdbe2d --- /dev/null +++ b/alembic.ini @@ -0,0 +1,35 @@ +[alembic] +script_location = alembic +sqlalchemy.url = sqlite:///zsim/data/zsim.db + +[loggers] +keys = root,sqlalchemy,alembic + +[handlers] +keys = console + +[formatters] +keys = generic + +[logger_root] +level = WARN +handlers = console + +[logger_sqlalchemy] +level = WARN +handlers = +qualname = sqlalchemy.engine + +[logger_alembic] +level = INFO +handlers = +qualname = alembic + +[handler_console] +class = StreamHandler +args = (sys.stdout,) +level = NOTSET +formatter = generic + +[formatter_generic] +format = %(levelname)-5.5s [%(name)s] %(message)s diff --git a/alembic/env.py b/alembic/env.py new file mode 100644 index 00000000..40bdf971 --- /dev/null +++ b/alembic/env.py @@ -0,0 +1,75 @@ +"""Alembic环境配置""" + +from __future__ import annotations + +import sys +from logging.config import fileConfig +from pathlib import Path + +from sqlalchemy import engine_from_config, pool + +from alembic import context + +PROJECT_ROOT = Path(__file__).resolve().parents[1] +if str(PROJECT_ROOT) not in sys.path: + sys.path.insert(0, str(PROJECT_ROOT)) + +config = context.config +if config.config_file_name is not None: + fileConfig(config.config_file_name) + + +def _load_metadata(): + """加载SQLAlchemy元数据""" + + import zsim.api_src.services.database.apl_db # noqa: F401 + import zsim.api_src.services.database.character_db # noqa: F401 + import zsim.api_src.services.database.enemy_db # noqa: F401 + import zsim.api_src.services.database.session_db # noqa: F401 + from zsim.api_src.services.database.orm import Base + + return Base.metadata + + +def _get_database_url() -> str: + """获取同步数据库URL""" + + from zsim.api_src.services.database.orm import get_sync_database_url + + return get_sync_database_url() + + +target_metadata = _load_metadata() +config.set_main_option("sqlalchemy.url", _get_database_url()) + + +def run_migrations_offline() -> None: + """Offline模式运行迁移""" + + url = config.get_main_option("sqlalchemy.url") + context.configure(url=url, target_metadata=target_metadata, literal_binds=True) + + with context.begin_transaction(): + context.run_migrations() + + +def run_migrations_online() -> None: + """Online模式运行迁移""" + + connectable = engine_from_config( + config.get_section(config.config_ini_section) or {}, + prefix="sqlalchemy.", + poolclass=pool.NullPool, + ) + + with connectable.connect() as connection: + context.configure(connection=connection, target_metadata=target_metadata) + + with context.begin_transaction(): + context.run_migrations() + + +if context.is_offline_mode(): + run_migrations_offline() +else: + run_migrations_online() diff --git a/alembic/script.py.mako b/alembic/script.py.mako new file mode 100644 index 00000000..c29f1c0b --- /dev/null +++ b/alembic/script.py.mako @@ -0,0 +1,30 @@ +"""${message} + +Revision ID: ${up_revision} +Revises:${" " + (down_revision | comma,n) if down_revision else ""} +Create Date: ${create_date} + +""" + +from __future__ import annotations + +from typing import Sequence + +from alembic import op +import sqlalchemy as sa +${imports if imports else ""} + +revision: str = ${repr(up_revision)} +down_revision: str | Sequence[str] | None = ${repr(down_revision)} +branch_labels: str | Sequence[str] | None = ${repr(branch_labels)} +depends_on: str | Sequence[str] | None = ${repr(depends_on)} + + +def upgrade() -> None: + """执行升级操作""" + ${upgrades if upgrades else "pass"} + + +def downgrade() -> None: + """执行回滚操作""" + ${downgrades if downgrades else "pass"} diff --git a/alembic/versions/.gitkeep b/alembic/versions/.gitkeep new file mode 100644 index 00000000..e69de29b diff --git a/alembic/versions/74ee1818bd42_init_schema.py b/alembic/versions/74ee1818bd42_init_schema.py new file mode 100644 index 00000000..6c1c5a16 --- /dev/null +++ b/alembic/versions/74ee1818bd42_init_schema.py @@ -0,0 +1,94 @@ +"""init schema + +Revision ID: 74ee1818bd42 +Revises: +Create Date: 2025-10-07 12:40:12.492096 + +""" + +from __future__ import annotations + +from typing import Sequence + +import sqlalchemy as sa + +from alembic import op + +revision: str = "74ee1818bd42" +down_revision: str | Sequence[str] | None = None +branch_labels: str | Sequence[str] | None = None +depends_on: str | Sequence[str] | None = None + + +def upgrade() -> None: + """执行升级操作""" + # ### commands auto generated by Alembic - please adjust! ### + op.create_table("apl_configs", + sa.Column("id", sa.String(length=64), nullable=False), + sa.Column("title", sa.String(length=255), nullable=False), + sa.Column("author", sa.String(length=255), nullable=True), + sa.Column("comment", sa.Text(), nullable=True), + sa.Column("create_time", sa.String(length=32), nullable=False), + sa.Column("latest_change_time", sa.String(length=32), nullable=False), + sa.Column("content", sa.Text(), nullable=False), + sa.PrimaryKeyConstraint("id") + ) + op.create_table("character_configs", + sa.Column("config_id", sa.String(length=128), nullable=False), + sa.Column("name", sa.String(length=255), nullable=False), + sa.Column("config_name", sa.String(length=255), nullable=False), + sa.Column("weapon", sa.String(length=255), nullable=False), + sa.Column("weapon_level", sa.Integer(), nullable=False), + sa.Column("cinema", sa.Integer(), nullable=False), + sa.Column("crit_balancing", sa.Boolean(), nullable=False), + sa.Column("crit_rate_limit", sa.Float(), nullable=False), + sa.Column("scATK_percent", sa.Integer(), nullable=False), + sa.Column("scATK", sa.Integer(), nullable=False), + sa.Column("scHP_percent", sa.Integer(), nullable=False), + sa.Column("scHP", sa.Integer(), nullable=False), + sa.Column("scDEF_percent", sa.Integer(), nullable=False), + sa.Column("scDEF", sa.Integer(), nullable=False), + sa.Column("scAnomalyProficiency", sa.Integer(), nullable=False), + sa.Column("scPEN", sa.Integer(), nullable=False), + sa.Column("scCRIT", sa.Integer(), nullable=False), + sa.Column("scCRIT_DMG", sa.Integer(), nullable=False), + sa.Column("drive4", sa.Text(), nullable=False), + sa.Column("drive5", sa.Text(), nullable=False), + sa.Column("drive6", sa.Text(), nullable=False), + sa.Column("equip_style", sa.String(length=255), nullable=False), + sa.Column("equip_set4", sa.String(length=255), nullable=True), + sa.Column("equip_set2_a", sa.String(length=255), nullable=True), + sa.Column("equip_set2_b", sa.String(length=255), nullable=True), + sa.Column("equip_set2_c", sa.String(length=255), nullable=True), + sa.Column("create_time", sa.String(length=32), nullable=False), + sa.Column("update_time", sa.String(length=32), nullable=False), + sa.PrimaryKeyConstraint("config_id") + ) + op.create_table("enemy_configs", + sa.Column("config_id", sa.String(length=128), nullable=False), + sa.Column("enemy_index", sa.Integer(), nullable=False), + sa.Column("enemy_adjust", sa.Text(), nullable=False), + sa.Column("create_time", sa.String(length=32), nullable=False), + sa.Column("update_time", sa.String(length=32), nullable=False), + sa.PrimaryKeyConstraint("config_id") + ) + op.create_table("sessions", + sa.Column("session_id", sa.String(length=128), nullable=False), + sa.Column("session_name", sa.String(length=255), nullable=False), + sa.Column("create_time", sa.String(length=32), nullable=False), + sa.Column("status", sa.String(length=32), nullable=False), + sa.Column("session_run", sa.Text(), nullable=True), + sa.Column("session_result", sa.Text(), nullable=True), + sa.PrimaryKeyConstraint("session_id") + ) + # ### end Alembic commands ### + + +def downgrade() -> None: + """执行回滚操作""" + # ### commands auto generated by Alembic - please adjust! ### + op.drop_table("sessions") + op.drop_table("enemy_configs") + op.drop_table("character_configs") + op.drop_table("apl_configs") + # ### end Alembic commands ### diff --git a/pyproject.toml b/pyproject.toml index 7b34778d..a56a6a7c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -49,6 +49,9 @@ dependencies = [ "httpx>=0.28.1", "dotenv>=0.9.9", "tomli-w>=1.2.0", + "sqlalchemy>=2.0.43", + "alembic>=1.16.5", + "greenlet>=3.0.3", ] [tool.ruff] diff --git a/uv.lock b/uv.lock index 84312a0f..7837efa4 100644 --- a/uv.lock +++ b/uv.lock @@ -32,6 +32,20 @@ wheels = [ { url = "https://pypi.tuna.tsinghua.edu.cn/packages/f5/10/6c25ed6de94c49f88a91fa5018cb4c0f3625f31d5be9f771ebe5cc7cd506/aiosqlite-0.21.0-py3-none-any.whl", hash = "sha256:2549cf4057f95f53dcba16f2b64e8e2791d7e1adedb13197dd8ed77bb226d7d0", size = 15792, upload-time = "2025-02-03T07:30:13.6Z" }, ] +[[package]] +name = "alembic" +version = "1.16.5" +source = { registry = "https://pypi.tuna.tsinghua.edu.cn/simple" } +dependencies = [ + { name = "mako" }, + { name = "sqlalchemy" }, + { name = "typing-extensions" }, +] +sdist = { url = "https://pypi.tuna.tsinghua.edu.cn/packages/9a/ca/4dc52902cf3491892d464f5265a81e9dff094692c8a049a3ed6a05fe7ee8/alembic-1.16.5.tar.gz", hash = "sha256:a88bb7f6e513bd4301ecf4c7f2206fe93f9913f9b48dac3b78babde2d6fe765e", size = 1969868, upload-time = "2025-08-27T18:02:05.668Z" } +wheels = [ + { url = "https://pypi.tuna.tsinghua.edu.cn/packages/39/4a/4c61d4c84cfd9befb6fa08a702535b27b21fff08c946bc2f6139decbf7f7/alembic-1.16.5-py3-none-any.whl", hash = "sha256:e845dfe090c5ffa7b92593ae6687c5cb1a101e91fa53868497dbd79847f9dbe3", size = 247355, upload-time = "2025-08-27T18:02:07.37Z" }, +] + [[package]] name = "altair" version = "5.5.0" @@ -448,6 +462,39 @@ wheels = [ { url = "https://pypi.tuna.tsinghua.edu.cn/packages/1d/9a/4114a9057db2f1462d5c8f8390ab7383925fe1ac012eaa42402ad65c2963/GitPython-3.1.44-py3-none-any.whl", hash = "sha256:9e0e10cda9bed1ee64bc9a6de50e7e38a9c9943241cd7f585f6df3ed28011110", size = 207599, upload-time = "2025-01-02T07:32:40.731Z" }, ] +[[package]] +name = "greenlet" +version = "3.2.4" +source = { registry = "https://pypi.tuna.tsinghua.edu.cn/simple" } +sdist = { url = "https://pypi.tuna.tsinghua.edu.cn/packages/03/b8/704d753a5a45507a7aab61f18db9509302ed3d0a27ac7e0359ec2905b1a6/greenlet-3.2.4.tar.gz", hash = "sha256:0dca0d95ff849f9a364385f36ab49f50065d76964944638be9691e1832e9f86d", size = 188260, upload-time = "2025-08-07T13:24:33.51Z" } +wheels = [ + { url = "https://pypi.tuna.tsinghua.edu.cn/packages/44/69/9b804adb5fd0671f367781560eb5eb586c4d495277c93bde4307b9e28068/greenlet-3.2.4-cp312-cp312-macosx_11_0_universal2.whl", hash = "sha256:3b67ca49f54cede0186854a008109d6ee71f66bd57bb36abd6d0a0267b540cdd", size = 274079, upload-time = "2025-08-07T13:15:45.033Z" }, + { url = "https://pypi.tuna.tsinghua.edu.cn/packages/46/e9/d2a80c99f19a153eff70bc451ab78615583b8dac0754cfb942223d2c1a0d/greenlet-3.2.4-cp312-cp312-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:ddf9164e7a5b08e9d22511526865780a576f19ddd00d62f8a665949327fde8bb", size = 640997, upload-time = "2025-08-07T13:42:56.234Z" }, + { url = "https://pypi.tuna.tsinghua.edu.cn/packages/3b/16/035dcfcc48715ccd345f3a93183267167cdd162ad123cd93067d86f27ce4/greenlet-3.2.4-cp312-cp312-manylinux2014_ppc64le.manylinux_2_17_ppc64le.whl", hash = "sha256:f28588772bb5fb869a8eb331374ec06f24a83a9c25bfa1f38b6993afe9c1e968", size = 655185, upload-time = "2025-08-07T13:45:27.624Z" }, + { url = "https://pypi.tuna.tsinghua.edu.cn/packages/31/da/0386695eef69ffae1ad726881571dfe28b41970173947e7c558d9998de0f/greenlet-3.2.4-cp312-cp312-manylinux2014_s390x.manylinux_2_17_s390x.whl", hash = "sha256:5c9320971821a7cb77cfab8d956fa8e39cd07ca44b6070db358ceb7f8797c8c9", size = 649926, upload-time = "2025-08-07T13:53:15.251Z" }, + { url = "https://pypi.tuna.tsinghua.edu.cn/packages/68/88/69bf19fd4dc19981928ceacbc5fd4bb6bc2215d53199e367832e98d1d8fe/greenlet-3.2.4-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:c60a6d84229b271d44b70fb6e5fa23781abb5d742af7b808ae3f6efd7c9c60f6", size = 651839, upload-time = "2025-08-07T13:18:30.281Z" }, + { url = "https://pypi.tuna.tsinghua.edu.cn/packages/19/0d/6660d55f7373b2ff8152401a83e02084956da23ae58cddbfb0b330978fe9/greenlet-3.2.4-cp312-cp312-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:3b3812d8d0c9579967815af437d96623f45c0f2ae5f04e366de62a12d83a8fb0", size = 607586, upload-time = "2025-08-07T13:18:28.544Z" }, + { url = "https://pypi.tuna.tsinghua.edu.cn/packages/8e/1a/c953fdedd22d81ee4629afbb38d2f9d71e37d23caace44775a3a969147d4/greenlet-3.2.4-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:abbf57b5a870d30c4675928c37278493044d7c14378350b3aa5d484fa65575f0", size = 1123281, upload-time = "2025-08-07T13:42:39.858Z" }, + { url = "https://pypi.tuna.tsinghua.edu.cn/packages/3f/c7/12381b18e21aef2c6bd3a636da1088b888b97b7a0362fac2e4de92405f97/greenlet-3.2.4-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:20fb936b4652b6e307b8f347665e2c615540d4b42b3b4c8a321d8286da7e520f", size = 1151142, upload-time = "2025-08-07T13:18:22.981Z" }, + { url = "https://pypi.tuna.tsinghua.edu.cn/packages/e9/08/b0814846b79399e585f974bbeebf5580fbe59e258ea7be64d9dfb253c84f/greenlet-3.2.4-cp312-cp312-win_amd64.whl", hash = "sha256:a7d4e128405eea3814a12cc2605e0e6aedb4035bf32697f72deca74de4105e02", size = 299899, upload-time = "2025-08-07T13:38:53.448Z" }, + { url = "https://pypi.tuna.tsinghua.edu.cn/packages/49/e8/58c7f85958bda41dafea50497cbd59738c5c43dbbea5ee83d651234398f4/greenlet-3.2.4-cp313-cp313-macosx_11_0_universal2.whl", hash = "sha256:1a921e542453fe531144e91e1feedf12e07351b1cf6c9e8a3325ea600a715a31", size = 272814, upload-time = "2025-08-07T13:15:50.011Z" }, + { url = "https://pypi.tuna.tsinghua.edu.cn/packages/62/dd/b9f59862e9e257a16e4e610480cfffd29e3fae018a68c2332090b53aac3d/greenlet-3.2.4-cp313-cp313-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:cd3c8e693bff0fff6ba55f140bf390fa92c994083f838fece0f63be121334945", size = 641073, upload-time = "2025-08-07T13:42:57.23Z" }, + { url = "https://pypi.tuna.tsinghua.edu.cn/packages/f7/0b/bc13f787394920b23073ca3b6c4a7a21396301ed75a655bcb47196b50e6e/greenlet-3.2.4-cp313-cp313-manylinux2014_ppc64le.manylinux_2_17_ppc64le.whl", hash = "sha256:710638eb93b1fa52823aa91bf75326f9ecdfd5e0466f00789246a5280f4ba0fc", size = 655191, upload-time = "2025-08-07T13:45:29.752Z" }, + { url = "https://pypi.tuna.tsinghua.edu.cn/packages/f2/d6/6adde57d1345a8d0f14d31e4ab9c23cfe8e2cd39c3baf7674b4b0338d266/greenlet-3.2.4-cp313-cp313-manylinux2014_s390x.manylinux_2_17_s390x.whl", hash = "sha256:c5111ccdc9c88f423426df3fd1811bfc40ed66264d35aa373420a34377efc98a", size = 649516, upload-time = "2025-08-07T13:53:16.314Z" }, + { url = "https://pypi.tuna.tsinghua.edu.cn/packages/7f/3b/3a3328a788d4a473889a2d403199932be55b1b0060f4ddd96ee7cdfcad10/greenlet-3.2.4-cp313-cp313-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:d76383238584e9711e20ebe14db6c88ddcedc1829a9ad31a584389463b5aa504", size = 652169, upload-time = "2025-08-07T13:18:32.861Z" }, + { url = "https://pypi.tuna.tsinghua.edu.cn/packages/ee/43/3cecdc0349359e1a527cbf2e3e28e5f8f06d3343aaf82ca13437a9aa290f/greenlet-3.2.4-cp313-cp313-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:23768528f2911bcd7e475210822ffb5254ed10d71f4028387e5a99b4c6699671", size = 610497, upload-time = "2025-08-07T13:18:31.636Z" }, + { url = "https://pypi.tuna.tsinghua.edu.cn/packages/b8/19/06b6cf5d604e2c382a6f31cafafd6f33d5dea706f4db7bdab184bad2b21d/greenlet-3.2.4-cp313-cp313-musllinux_1_1_aarch64.whl", hash = "sha256:00fadb3fedccc447f517ee0d3fd8fe49eae949e1cd0f6a611818f4f6fb7dc83b", size = 1121662, upload-time = "2025-08-07T13:42:41.117Z" }, + { url = "https://pypi.tuna.tsinghua.edu.cn/packages/a2/15/0d5e4e1a66fab130d98168fe984c509249c833c1a3c16806b90f253ce7b9/greenlet-3.2.4-cp313-cp313-musllinux_1_1_x86_64.whl", hash = "sha256:d25c5091190f2dc0eaa3f950252122edbbadbb682aa7b1ef2f8af0f8c0afefae", size = 1149210, upload-time = "2025-08-07T13:18:24.072Z" }, + { url = "https://pypi.tuna.tsinghua.edu.cn/packages/0b/55/2321e43595e6801e105fcfdee02b34c0f996eb71e6ddffca6b10b7e1d771/greenlet-3.2.4-cp313-cp313-win_amd64.whl", hash = "sha256:554b03b6e73aaabec3745364d6239e9e012d64c68ccd0b8430c64ccc14939a8b", size = 299685, upload-time = "2025-08-07T13:24:38.824Z" }, + { url = "https://pypi.tuna.tsinghua.edu.cn/packages/22/5c/85273fd7cc388285632b0498dbbab97596e04b154933dfe0f3e68156c68c/greenlet-3.2.4-cp314-cp314-macosx_11_0_universal2.whl", hash = "sha256:49a30d5fda2507ae77be16479bdb62a660fa51b1eb4928b524975b3bde77b3c0", size = 273586, upload-time = "2025-08-07T13:16:08.004Z" }, + { url = "https://pypi.tuna.tsinghua.edu.cn/packages/d1/75/10aeeaa3da9332c2e761e4c50d4c3556c21113ee3f0afa2cf5769946f7a3/greenlet-3.2.4-cp314-cp314-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:299fd615cd8fc86267b47597123e3f43ad79c9d8a22bebdce535e53550763e2f", size = 686346, upload-time = "2025-08-07T13:42:59.944Z" }, + { url = "https://pypi.tuna.tsinghua.edu.cn/packages/c0/aa/687d6b12ffb505a4447567d1f3abea23bd20e73a5bed63871178e0831b7a/greenlet-3.2.4-cp314-cp314-manylinux2014_ppc64le.manylinux_2_17_ppc64le.whl", hash = "sha256:c17b6b34111ea72fc5a4e4beec9711d2226285f0386ea83477cbb97c30a3f3a5", size = 699218, upload-time = "2025-08-07T13:45:30.969Z" }, + { url = "https://pypi.tuna.tsinghua.edu.cn/packages/dc/8b/29aae55436521f1d6f8ff4e12fb676f3400de7fcf27fccd1d4d17fd8fecd/greenlet-3.2.4-cp314-cp314-manylinux2014_s390x.manylinux_2_17_s390x.whl", hash = "sha256:b4a1870c51720687af7fa3e7cda6d08d801dae660f75a76f3845b642b4da6ee1", size = 694659, upload-time = "2025-08-07T13:53:17.759Z" }, + { url = "https://pypi.tuna.tsinghua.edu.cn/packages/92/2e/ea25914b1ebfde93b6fc4ff46d6864564fba59024e928bdc7de475affc25/greenlet-3.2.4-cp314-cp314-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:061dc4cf2c34852b052a8620d40f36324554bc192be474b9e9770e8c042fd735", size = 695355, upload-time = "2025-08-07T13:18:34.517Z" }, + { url = "https://pypi.tuna.tsinghua.edu.cn/packages/72/60/fc56c62046ec17f6b0d3060564562c64c862948c9d4bc8aa807cf5bd74f4/greenlet-3.2.4-cp314-cp314-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:44358b9bf66c8576a9f57a590d5f5d6e72fa4228b763d0e43fee6d3b06d3a337", size = 657512, upload-time = "2025-08-07T13:18:33.969Z" }, + { url = "https://pypi.tuna.tsinghua.edu.cn/packages/e3/a5/6ddab2b4c112be95601c13428db1d8b6608a8b6039816f2ba09c346c08fc/greenlet-3.2.4-cp314-cp314-win_amd64.whl", hash = "sha256:e37ab26028f12dbb0ff65f29a8d3d44a765c61e729647bf2ddfbbed621726f01", size = 303425, upload-time = "2025-08-07T13:32:27.59Z" }, +] + [[package]] name = "h11" version = "0.16.0" @@ -674,6 +721,18 @@ wheels = [ { url = "https://pypi.tuna.tsinghua.edu.cn/packages/d1/5d/c059c180c84f7962db0aeae7c3b9303ed1d73d76f2bfbc32bc231c8be314/macholib-1.16.3-py2.py3-none-any.whl", hash = "sha256:0e315d7583d38b8c77e815b1ecbdbf504a8258d8b3e17b61165c6feb60d18f2c" }, ] +[[package]] +name = "mako" +version = "1.3.10" +source = { registry = "https://pypi.tuna.tsinghua.edu.cn/simple" } +dependencies = [ + { name = "markupsafe" }, +] +sdist = { url = "https://pypi.tuna.tsinghua.edu.cn/packages/9e/38/bd5b78a920a64d708fe6bc8e0a2c075e1389d53bef8413725c63ba041535/mako-1.3.10.tar.gz", hash = "sha256:99579a6f39583fa7e5630a28c3c1f440e4e97a414b80372649c0ce338da2ea28", size = 392474, upload-time = "2025-04-10T12:44:31.16Z" } +wheels = [ + { url = "https://pypi.tuna.tsinghua.edu.cn/packages/87/fb/99f81ac72ae23375f22b7afdb7642aba97c00a713c217124420147681a2f/mako-1.3.10-py3-none-any.whl", hash = "sha256:baef24a52fc4fc514a0887ac600f9f1cff3d82c61d4d700a1fa84d597b88db59", size = 78509, upload-time = "2025-04-10T12:50:53.297Z" }, +] + [[package]] name = "markupsafe" version = "3.0.2" @@ -1619,6 +1678,35 @@ wheels = [ { url = "https://pypi.tuna.tsinghua.edu.cn/packages/e9/44/75a9c9421471a6c4805dbf2356f7c181a29c1879239abab1ea2cc8f38b40/sniffio-1.3.1-py3-none-any.whl", hash = "sha256:2f6da418d1f1e0fddd844478f41680e794e6051915791a034ff65e5f100525a2", size = 10235, upload-time = "2024-02-25T23:20:01.196Z" }, ] +[[package]] +name = "sqlalchemy" +version = "2.0.43" +source = { registry = "https://pypi.tuna.tsinghua.edu.cn/simple" } +dependencies = [ + { name = "greenlet", marker = "(python_full_version < '3.14' and platform_machine == 'AMD64') or (python_full_version < '3.14' and platform_machine == 'WIN32') or (python_full_version < '3.14' and platform_machine == 'aarch64') or (python_full_version < '3.14' and platform_machine == 'amd64') or (python_full_version < '3.14' and platform_machine == 'ppc64le') or (python_full_version < '3.14' and platform_machine == 'win32') or (python_full_version < '3.14' and platform_machine == 'x86_64')" }, + { name = "typing-extensions" }, +] +sdist = { url = "https://pypi.tuna.tsinghua.edu.cn/packages/d7/bc/d59b5d97d27229b0e009bd9098cd81af71c2fa5549c580a0a67b9bed0496/sqlalchemy-2.0.43.tar.gz", hash = "sha256:788bfcef6787a7764169cfe9859fe425bf44559619e1d9f56f5bddf2ebf6f417", size = 9762949, upload-time = "2025-08-11T14:24:58.438Z" } +wheels = [ + { url = "https://pypi.tuna.tsinghua.edu.cn/packages/61/db/20c78f1081446095450bdc6ee6cc10045fce67a8e003a5876b6eaafc5cc4/sqlalchemy-2.0.43-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:20d81fc2736509d7a2bd33292e489b056cbae543661bb7de7ce9f1c0cd6e7f24", size = 2134891, upload-time = "2025-08-11T15:51:13.019Z" }, + { url = "https://pypi.tuna.tsinghua.edu.cn/packages/45/0a/3d89034ae62b200b4396f0f95319f7d86e9945ee64d2343dcad857150fa2/sqlalchemy-2.0.43-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:25b9fc27650ff5a2c9d490c13c14906b918b0de1f8fcbb4c992712d8caf40e83", size = 2123061, upload-time = "2025-08-11T15:51:14.319Z" }, + { url = "https://pypi.tuna.tsinghua.edu.cn/packages/cb/10/2711f7ff1805919221ad5bee205971254845c069ee2e7036847103ca1e4c/sqlalchemy-2.0.43-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:6772e3ca8a43a65a37c88e2f3e2adfd511b0b1da37ef11ed78dea16aeae85bd9", size = 3320384, upload-time = "2025-08-11T15:52:35.088Z" }, + { url = "https://pypi.tuna.tsinghua.edu.cn/packages/6e/0e/3d155e264d2ed2778484006ef04647bc63f55b3e2d12e6a4f787747b5900/sqlalchemy-2.0.43-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:1a113da919c25f7f641ffbd07fbc9077abd4b3b75097c888ab818f962707eb48", size = 3329648, upload-time = "2025-08-11T15:56:34.153Z" }, + { url = "https://pypi.tuna.tsinghua.edu.cn/packages/5b/81/635100fb19725c931622c673900da5efb1595c96ff5b441e07e3dd61f2be/sqlalchemy-2.0.43-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:4286a1139f14b7d70141c67a8ae1582fc2b69105f1b09d9573494eb4bb4b2687", size = 3258030, upload-time = "2025-08-11T15:52:36.933Z" }, + { url = "https://pypi.tuna.tsinghua.edu.cn/packages/0c/ed/a99302716d62b4965fded12520c1cbb189f99b17a6d8cf77611d21442e47/sqlalchemy-2.0.43-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:529064085be2f4d8a6e5fab12d36ad44f1909a18848fcfbdb59cc6d4bbe48efe", size = 3294469, upload-time = "2025-08-11T15:56:35.553Z" }, + { url = "https://pypi.tuna.tsinghua.edu.cn/packages/5d/a2/3a11b06715149bf3310b55a98b5c1e84a42cfb949a7b800bc75cb4e33abc/sqlalchemy-2.0.43-cp312-cp312-win32.whl", hash = "sha256:b535d35dea8bbb8195e7e2b40059e2253acb2b7579b73c1b432a35363694641d", size = 2098906, upload-time = "2025-08-11T15:55:00.645Z" }, + { url = "https://pypi.tuna.tsinghua.edu.cn/packages/bc/09/405c915a974814b90aa591280623adc6ad6b322f61fd5cff80aeaef216c9/sqlalchemy-2.0.43-cp312-cp312-win_amd64.whl", hash = "sha256:1c6d85327ca688dbae7e2b06d7d84cfe4f3fffa5b5f9e21bb6ce9d0e1a0e0e0a", size = 2126260, upload-time = "2025-08-11T15:55:02.965Z" }, + { url = "https://pypi.tuna.tsinghua.edu.cn/packages/41/1c/a7260bd47a6fae7e03768bf66451437b36451143f36b285522b865987ced/sqlalchemy-2.0.43-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:e7c08f57f75a2bb62d7ee80a89686a5e5669f199235c6d1dac75cd59374091c3", size = 2130598, upload-time = "2025-08-11T15:51:15.903Z" }, + { url = "https://pypi.tuna.tsinghua.edu.cn/packages/8e/84/8a337454e82388283830b3586ad7847aa9c76fdd4f1df09cdd1f94591873/sqlalchemy-2.0.43-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:14111d22c29efad445cd5021a70a8b42f7d9152d8ba7f73304c4d82460946aaa", size = 2118415, upload-time = "2025-08-11T15:51:17.256Z" }, + { url = "https://pypi.tuna.tsinghua.edu.cn/packages/cf/ff/22ab2328148492c4d71899d62a0e65370ea66c877aea017a244a35733685/sqlalchemy-2.0.43-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:21b27b56eb2f82653168cefe6cb8e970cdaf4f3a6cb2c5e3c3c1cf3158968ff9", size = 3248707, upload-time = "2025-08-11T15:52:38.444Z" }, + { url = "https://pypi.tuna.tsinghua.edu.cn/packages/dc/29/11ae2c2b981de60187f7cbc84277d9d21f101093d1b2e945c63774477aba/sqlalchemy-2.0.43-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:9c5a9da957c56e43d72126a3f5845603da00e0293720b03bde0aacffcf2dc04f", size = 3253602, upload-time = "2025-08-11T15:56:37.348Z" }, + { url = "https://pypi.tuna.tsinghua.edu.cn/packages/b8/61/987b6c23b12c56d2be451bc70900f67dd7d989d52b1ee64f239cf19aec69/sqlalchemy-2.0.43-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:5d79f9fdc9584ec83d1b3c75e9f4595c49017f5594fee1a2217117647225d738", size = 3183248, upload-time = "2025-08-11T15:52:39.865Z" }, + { url = "https://pypi.tuna.tsinghua.edu.cn/packages/86/85/29d216002d4593c2ce1c0ec2cec46dda77bfbcd221e24caa6e85eff53d89/sqlalchemy-2.0.43-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:9df7126fd9db49e3a5a3999442cc67e9ee8971f3cb9644250107d7296cb2a164", size = 3219363, upload-time = "2025-08-11T15:56:39.11Z" }, + { url = "https://pypi.tuna.tsinghua.edu.cn/packages/b6/e4/bd78b01919c524f190b4905d47e7630bf4130b9f48fd971ae1c6225b6f6a/sqlalchemy-2.0.43-cp313-cp313-win32.whl", hash = "sha256:7f1ac7828857fcedb0361b48b9ac4821469f7694089d15550bbcf9ab22564a1d", size = 2096718, upload-time = "2025-08-11T15:55:05.349Z" }, + { url = "https://pypi.tuna.tsinghua.edu.cn/packages/ac/a5/ca2f07a2a201f9497de1928f787926613db6307992fe5cda97624eb07c2f/sqlalchemy-2.0.43-cp313-cp313-win_amd64.whl", hash = "sha256:971ba928fcde01869361f504fcff3b7143b47d30de188b11c6357c0505824197", size = 2123200, upload-time = "2025-08-11T15:55:07.932Z" }, + { url = "https://pypi.tuna.tsinghua.edu.cn/packages/b8/d9/13bdde6521f322861fab67473cec4b1cc8999f3871953531cf61945fad92/sqlalchemy-2.0.43-py3-none-any.whl", hash = "sha256:1681c21dd2ccee222c2fe0bef671d1aef7c504087c9c4e800371cfcc8ac966fc", size = 1924759, upload-time = "2025-08-11T15:39:53.024Z" }, +] + [[package]] name = "stack-data" version = "0.6.3" @@ -1888,9 +1976,11 @@ source = { editable = "." } dependencies = [ { name = "aiofiles" }, { name = "aiosqlite" }, + { name = "alembic" }, { name = "dash" }, { name = "dotenv" }, { name = "fastapi" }, + { name = "greenlet" }, { name = "httpx" }, { name = "numpy" }, { name = "pandas" }, @@ -1900,6 +1990,7 @@ dependencies = [ { name = "pydantic" }, { name = "pywebview" }, { name = "setuptools" }, + { name = "sqlalchemy" }, { name = "streamlit" }, { name = "streamlit-ace" }, { name = "tomli-w" }, @@ -1929,9 +2020,11 @@ dev = [ requires-dist = [ { name = "aiofiles", specifier = ">=24.1.0" }, { name = "aiosqlite", specifier = ">=0.21.0" }, + { name = "alembic", specifier = ">=1.16.5" }, { name = "dash", specifier = "~=2.18.2" }, { name = "dotenv", specifier = ">=0.9.9" }, { name = "fastapi", specifier = ">=0.115.12" }, + { name = "greenlet", specifier = ">=3.0.3" }, { name = "httpx", specifier = ">=0.28.1" }, { name = "numpy", specifier = "~=2.2.5" }, { name = "pandas", specifier = "~=2.2.3" }, @@ -1942,6 +2035,7 @@ requires-dist = [ { name = "pywebview", specifier = ">=5.4" }, { name = "pywin32", marker = "extra == 'windows'", specifier = ">=308" }, { name = "setuptools", specifier = "~=75.1.0" }, + { name = "sqlalchemy", specifier = ">=2.0.43" }, { name = "streamlit", specifier = "~=1.44.0" }, { name = "streamlit-ace", specifier = ">=0.1.1" }, { name = "tomli-w", specifier = ">=1.2.0" }, diff --git a/zsim/api_src/services/database/apl_db.py b/zsim/api_src/services/database/apl_db.py index 95cf731c..ee90dd67 100644 --- a/zsim/api_src/services/database/apl_db.py +++ b/zsim/api_src/services/database/apl_db.py @@ -3,436 +3,500 @@ 负责APL相关数据的数据库操作 """ +from __future__ import annotations + import asyncio import os +import tomllib import uuid +from datetime import datetime from typing import Any -import aiosqlite import tomli_w -import tomllib +from sqlalchemy import String, Text, delete, select +from sqlalchemy.orm import Mapped, mapped_column + +from zsim.api_src.services.database.orm import Base, get_async_engine, get_async_session +from zsim.define import COSTOM_APL_DIR, DEFAULT_APL_DIR -from zsim.define import COSTOM_APL_DIR, DEFAULT_APL_DIR, SQLITE_PATH + +class APLConfigORM(Base): + + __tablename__ = "apl_configs" + + id: Mapped[str] = mapped_column(String(64), primary_key=True) + title: Mapped[str] = mapped_column(String(255), nullable=False) + author: Mapped[str | None] = mapped_column(String(255), nullable=True) + comment: Mapped[str | None] = mapped_column(Text, nullable=True) + create_time: Mapped[str] = mapped_column(String(32), nullable=False) + latest_change_time: Mapped[str] = mapped_column(String(32), nullable=False) + content: Mapped[str] = mapped_column(Text, nullable=False) class APLDatabase: """APL数据库操作类""" - def __init__(self): - """初始化APL数据库""" - # 确保数据库目录存在 - os.makedirs(os.path.dirname(SQLITE_PATH), exist_ok=True) - # 不在这里初始化数据库,而是在首次使用时异步初始化 + def __init__(self) -> None: + """初始化APL数据库实例""" self._initialized = False - async def _ensure_initialized(self): - """确保数据库已初始化""" - if not self._initialized: - await self._init_database() - self._initialized = True - - async def _init_database(self): - """初始化数据库表""" - async with aiosqlite.connect(SQLITE_PATH) as db: - await db.execute( - """CREATE TABLE IF NOT EXISTS apl_configs ( - id TEXT PRIMARY KEY, - title TEXT NOT NULL, - author TEXT, - comment TEXT, - create_time TEXT, - latest_change_time TEXT, - content TEXT NOT NULL - ) - """ - ) - await db.commit() + async def _ensure_initialized(self) -> None: + """确保数据库元数据已创建""" + if self._initialized: + return + async with get_async_engine().begin() as conn: + await conn.run_sync(Base.metadata.create_all) + self._initialized = True def get_apl_templates(self) -> list[dict[str, Any]]: - """获取所有APL模板""" - templates = [] + """获取所有APL模板。 - # 获取默认APL模板 - default_templates = self._get_apl_from_dir(DEFAULT_APL_DIR, "default") - templates.extend(default_templates) - - # 获取自定义APL模板 - custom_templates = self._get_apl_from_dir(COSTOM_APL_DIR, "custom") - templates.extend(custom_templates) + Returns: + list[dict[str, Any]]: 模板信息列表。 + """ + templates = [] + templates.extend(self._get_apl_from_dir(DEFAULT_APL_DIR, "default")) + templates.extend(self._get_apl_from_dir(COSTOM_APL_DIR, "custom")) return templates def get_apl_config(self, config_id: str) -> dict[str, Any] | None: - """获取特定APL配置""" - try: - return asyncio.get_event_loop().run_until_complete( - self._get_apl_config_async(config_id) - ) - except Exception as e: - print(f"Error loading APL config {config_id}: {e}") + """获取特定APL配置。 + + Args: + config_id (str): APL配置ID。 + + Returns: + dict[str, Any] | None: APL配置内容,未找到时返回None。 + """ + if not config_id or not isinstance(config_id, str): return None + return asyncio.get_event_loop().run_until_complete( + self._get_apl_config_async(config_id) + ) + async def _get_apl_config_async(self, config_id: str) -> dict[str, Any] | None: - """异步获取特定APL配置""" + """异步获取特定APL配置。 + + Args: + config_id (str): APL配置ID。 + + Returns: + dict[str, Any] | None: APL配置内容,未找到时返回None。 + """ + await self._ensure_initialized() - async with aiosqlite.connect(SQLITE_PATH) as db: - async with db.execute( - "SELECT title, author, comment, create_time, latest_change_time, content FROM apl_configs WHERE id = ?", - (config_id,), - ) as cursor: - row = await cursor.fetchone() - if row: - # 解析TOML内容 - content = tomllib.loads(row[5]) - result = { - "title": row[0], - "author": row[1], - "comment": row[2], - "create_time": row[3], - "latest_change_time": row[4], - **content, - } - return result + async with get_async_session() as session: + result = await session.execute( + select(APLConfigORM).where(APLConfigORM.id == config_id) + ) + record = result.scalar_one_or_none() + if record is None: return None + content = tomllib.loads(record.content) + return { + "title": record.title, + "author": record.author, + "comment": record.comment, + "create_time": record.create_time, + "latest_change_time": record.latest_change_time, + **content, + } def create_apl_config(self, config_data: dict[str, Any]) -> str: - """创建新的APL配置""" - # 生成唯一ID + """创建新的APL配置。 + + Args: + config_data (dict[str, Any]): APL配置数据。 + + Returns: + str: 新建配置的ID。 + + Raises: + Exception: 当写入数据库失败时抛出。 + """ + if not config_data or not isinstance(config_data, dict): + raise ValueError("配置数据不能为空且必须是字典类型") + config_id = str(uuid.uuid4()) + asyncio.get_event_loop().run_until_complete( + self._create_apl_config_async(config_id, config_data) + ) + return config_id - try: - asyncio.get_event_loop().run_until_complete( - self._create_apl_config_async(config_id, config_data) - ) - return config_id - except Exception as e: - raise Exception(f"Failed to create APL config: {e}") + async def _create_apl_config_async( + self, config_id: str, config_data: dict[str, Any] + ) -> None: + """异步创建APL配置。 - async def _create_apl_config_async(self, config_id: str, config_data: dict[str, Any]) -> None: - """异步创建新的APL配置""" - await self._ensure_initialized() - from datetime import datetime + Args: + config_id (str): 新配置ID。 + config_data (dict[str, Any]): APL配置数据。 + """ - # 提取通用信息 + await self._ensure_initialized() + current_time = datetime.now().isoformat() title = config_data.get("title", "") author = config_data.get("author", "") comment = config_data.get("comment", "") - - # 系统决定创建时间和最后修改时间 - current_time = datetime.now().isoformat() - create_time = current_time - latest_change_time = current_time - - # 创建要存储的配置数据副本(不包含通用信息) content_data = config_data.copy() content_data.pop("title", None) content_data.pop("author", None) content_data.pop("comment", None) content_data.pop("create_time", None) content_data.pop("latest_change_time", None) - - # 将配置数据转换为TOML格式 content = tomli_w.dumps(content_data) - async with aiosqlite.connect(SQLITE_PATH) as db: - await db.execute( - """ - INSERT INTO apl_configs - (id, title, author, comment, create_time, latest_change_time, content) - VALUES (?, ?, ?, ?, ?, ?, ?) - """, - (config_id, title, author, comment, create_time, latest_change_time, content), + async with get_async_session() as session: + session.add( + APLConfigORM( + id=config_id, + title=title, + author=author, + comment=comment, + create_time=current_time, + latest_change_time=current_time, + content=content, + ) ) - await db.commit() + await session.commit() def update_apl_config(self, config_id: str, config_data: dict[str, Any]) -> bool: - """更新APL配置""" - try: - return asyncio.get_event_loop().run_until_complete( - self._update_apl_config_async(config_id, config_data) - ) - except Exception as e: - print(f"Error updating APL config {config_id}: {e}") - return False + """更新APL配置。 - async def _update_apl_config_async(self, config_id: str, config_data: dict[str, Any]) -> bool: - """异步更新APL配置""" - await self._ensure_initialized() - from datetime import datetime + Args: + config_id (str): APL配置ID。 + config_data (dict[str, Any]): 更新后的数据。 - # 提取通用信息 - title = config_data.get("title", "") - author = config_data.get("author", "") - comment = config_data.get("comment", "") + Returns: + bool: 更新成功返回True,否则False。 + """ + if not config_id or not isinstance(config_id, str): + return False + if not config_data or not isinstance(config_data, dict): + return False - # 获取现有的create_time(保持不变) - # 系统决定最后修改时间 - latest_change_time = datetime.now().isoformat() + return asyncio.get_event_loop().run_until_complete( + self._update_apl_config_async(config_id, config_data) + ) - # 创建要存储的配置数据副本(不包含通用信息) - content_data = config_data.copy() - content_data.pop("title", None) - content_data.pop("author", None) - content_data.pop("comment", None) - content_data.pop("create_time", None) - content_data.pop("latest_change_time", None) + async def _update_apl_config_async( + self, config_id: str, config_data: dict[str, Any] + ) -> bool: + """异步更新APL配置。 - # 将配置数据转换为TOML格式 - content = tomli_w.dumps(content_data) + Args: + config_id (str): APL配置ID。 + config_data (dict[str, Any]): 更新后的数据。 + + Returns: + bool: 更新成功返回True,否则False。 + """ - async with aiosqlite.connect(SQLITE_PATH) as db: - # 先获取现有的create_time - async with db.execute( - "SELECT create_time FROM apl_configs WHERE id = ?", (config_id,) - ) as cursor: - row = await cursor.fetchone() - if row: - create_time = row[0] - else: - # 如果记录不存在,使用当前时间作为create_time - create_time = latest_change_time - - cursor = await db.execute( - """ - UPDATE apl_configs - SET title = ?, author = ?, comment = ?, create_time = ?, latest_change_time = ?, content = ? - WHERE id = ? - """, - (title, author, comment, create_time, latest_change_time, content, config_id), + await self._ensure_initialized() + async with get_async_session() as session: + result = await session.execute( + select(APLConfigORM).where(APLConfigORM.id == config_id) ) - await db.commit() - return cursor.rowcount > 0 + record = result.scalar_one_or_none() + if record is None: + return False + + latest_change_time = datetime.now().isoformat() + record.title = config_data.get("title", "") + record.author = config_data.get("author", "") + record.comment = config_data.get("comment", "") + record.latest_change_time = latest_change_time + + content_data = config_data.copy() + content_data.pop("title", None) + content_data.pop("author", None) + content_data.pop("comment", None) + content_data.pop("create_time", None) + content_data.pop("latest_change_time", None) + record.content = tomli_w.dumps(content_data) + + await session.flush() + await session.commit() + return True def delete_apl_config(self, config_id: str) -> bool: - """删除APL配置""" - try: - return asyncio.get_event_loop().run_until_complete( - self._delete_apl_config_async(config_id) - ) - except Exception as e: - print(f"Error deleting APL config {config_id}: {e}") + """删除APL配置。 + + Args: + config_id (str): APL配置ID。 + + Returns: + bool: 删除成功返回True,否则False。 + """ + if not config_id or not isinstance(config_id, str): return False + return asyncio.get_event_loop().run_until_complete( + self._delete_apl_config_async(config_id) + ) + async def _delete_apl_config_async(self, config_id: str) -> bool: - """异步删除APL配置""" + """异步删除APL配置。 + + Args: + config_id (str): APL配置ID。 + + Returns: + bool: 删除成功返回True,否则False。 + """ + await self._ensure_initialized() - async with aiosqlite.connect(SQLITE_PATH) as db: - cursor = await db.execute("DELETE FROM apl_configs WHERE id = ?", (config_id,)) - await db.commit() - return cursor.rowcount > 0 + async with get_async_session() as session: + result = await session.execute( + delete(APLConfigORM).where(APLConfigORM.id == config_id) + ) + if result.rowcount == 0: + await session.rollback() + return False + await session.commit() + return True def export_apl_config(self, config_id: str, file_path: str) -> bool: - """导出APL配置到TOML文件""" - try: - config = self.get_apl_config(config_id) - if config: - # 创建要导出的配置数据副本(不包含数据库特定字段) - export_data = config.copy() - export_data.pop("create_time", None) - export_data.pop("latest_change_time", None) - - # 将配置数据转换为TOML格式并保存到文件 - with open(file_path, "wb") as f: - tomli_w.dump(export_data, f) - return True - else: - return False - except Exception as e: - print(f"Error exporting APL config {config_id}: {e}") + """导出APL配置到TOML文件。 + + Args: + config_id (str): APL配置ID。 + file_path (str): 导出文件路径。 + + Returns: + bool: 导出成功返回True,否则False。 + """ + if not config_id or not isinstance(config_id, str): + return False + if not file_path or not isinstance(file_path, str): + return False + + config = self.get_apl_config(config_id) + if config is None: return False + export_data = config.copy() + export_data.pop("create_time", None) + export_data.pop("latest_change_time", None) + + # 确保目标目录存在 + os.makedirs(os.path.dirname(file_path), exist_ok=True) + + # 使用tomli_w.dump写入文件对象 + with open(file_path, "wb") as file: + tomli_w.dump(export_data, file) + return True + def import_apl_config(self, file_path: str) -> str | None: - """从TOML文件导入APL配置""" - try: - # 读取TOML文件 - with open(file_path, "rb") as f: - config_data = tomllib.load(f) - - # 生成唯一ID - config_id = str(uuid.uuid4()) - - # 保存到数据库(create_time和latest_change_time将由系统决定) - asyncio.get_event_loop().run_until_complete( - self._create_apl_config_async(config_id, config_data) - ) - return config_id - except Exception as e: - print(f"Error importing APL config from {file_path}: {e}") + """从TOML文件导入APL配置。 + + Args: + file_path (str): APL文件路径。 + + Returns: + str | None: 导入成功时返回新配置ID,否则None。 + """ + if not file_path or not isinstance(file_path, str): + return None + if not os.path.exists(file_path): return None - def get_apl_files(self) -> list[dict[str, Any]]: - """获取所有APL文件列表""" - files = [] + with open(file_path, "rb") as file: + config_data = tomllib.load(file) + + config_id = str(uuid.uuid4()) + asyncio.get_event_loop().run_until_complete( + self._create_apl_config_async(config_id, config_data) + ) + return config_id - # 获取默认APL文件 - default_files = self._get_apl_files_from_dir(DEFAULT_APL_DIR, "default") - files.extend(default_files) + def get_apl_files(self) -> list[dict[str, Any]]: + """获取所有APL文件列表。 - # 获取自定义APL文件 - custom_files = self._get_apl_files_from_dir(COSTOM_APL_DIR, "custom") - files.extend(custom_files) + Returns: + list[dict[str, Any]]: APL文件信息列表。 + """ + files = [] + files.extend(self._get_apl_files_from_dir(DEFAULT_APL_DIR, "default")) + files.extend(self._get_apl_files_from_dir(COSTOM_APL_DIR, "custom")) return files def get_apl_file_content(self, file_id: str) -> dict[str, Any] | None: - """获取APL文件内容""" - # 根据file_id获取对应的APL文件内容 - # 解析file_id获取source和相对路径 - try: - if file_id.startswith("default_"): - rel_path = file_id[len("default_") :] - base_dir = DEFAULT_APL_DIR - elif file_id.startswith("custom_"): - rel_path = file_id[len("custom_") :] - base_dir = COSTOM_APL_DIR - else: - return None - - file_path = os.path.join(base_dir, rel_path) + """获取APL文件内容。 - if os.path.exists(file_path): - with open(file_path, "r", encoding="utf-8") as f: - content = f.read() + Args: + file_id (str): APL文件标识。 - return {"file_id": file_id, "content": content, "file_path": file_path} - else: - return None - except Exception as e: - print(f"Error reading APL file {file_id}: {e}") + Returns: + dict[str, Any] | None: 文件内容信息,未找到时返回None。 + """ + if not file_id or not isinstance(file_id, str): return None - def create_apl_file(self, file_data: dict[str, Any]) -> str: - """创建新的APL文件""" - # 实现创建APL文件的逻辑 - try: - name = file_data.get("name", "new_apl.toml") - content = file_data.get("content", "") - - # 确保文件名以.toml结尾 - if not name.endswith(".toml"): - name += ".toml" - - # 保存到自定义目录 - file_path = os.path.join(COSTOM_APL_DIR, name) + if file_id.startswith("default_"): + rel_path = file_id[len("default_") :] + base_dir = DEFAULT_APL_DIR + elif file_id.startswith("custom_"): + rel_path = file_id[len("custom_") :] + base_dir = COSTOM_APL_DIR + else: + return None - # 确保目录存在 - os.makedirs(COSTOM_APL_DIR, exist_ok=True) + file_path = os.path.join(base_dir, rel_path) + if not os.path.exists(file_path): + return None - # 写入文件 - with open(file_path, "w", encoding="utf-8") as f: - f.write(content) + with open(file_path, "r", encoding="utf-8") as file: + content = file.read() + return {"file_id": file_id, "content": content, "file_path": file_path} - # 生成文件ID - file_id = f"custom_{name}" - return file_id - except Exception as e: - raise Exception(f"Failed to create APL file: {e}") + def create_apl_file(self, file_data: dict[str, Any]) -> str: + """创建新的APL文件。 + + Args: + file_data (dict[str, Any]): APL文件数据。 + + Returns: + str: 新建APL文件的标识。 + + Raises: + Exception: 当写入文件失败时抛出。 + """ + if not file_data or not isinstance(file_data, dict): + raise ValueError("文件数据不能为空且必须是字典类型") + + name = file_data.get("name", "new_apl.toml") + content = file_data.get("content", "") + if not name.endswith(".toml"): + name += ".toml" + file_path = os.path.join(COSTOM_APL_DIR, name) + os.makedirs(COSTOM_APL_DIR, exist_ok=True) + with open(file_path, "w", encoding="utf-8") as file: + file.write(content) + return f"custom_{name}" def update_apl_file(self, file_id: str, content: str) -> bool: - """更新APL文件内容""" - # 实现更新APL文件内容的逻辑 - try: - # 解析file_id获取source和相对路径 - if file_id.startswith("default_"): - # 不允许更新默认文件 - return False - elif file_id.startswith("custom_"): - rel_path = file_id[len("custom_") :] - base_dir = COSTOM_APL_DIR - else: - return False + """更新APL文件内容。 - file_path = os.path.join(base_dir, rel_path) + Args: + file_id (str): APL文件标识。 + content (str): 文件内容。 - if os.path.exists(file_path): - with open(file_path, "w", encoding="utf-8") as f: - f.write(content) - return True - else: - return False - except Exception as e: - print(f"Error updating APL file {file_id}: {e}") + Returns: + bool: 更新成功返回True,否则False。 + """ + if not file_id or not isinstance(file_id, str): + return False + if content is None or not isinstance(content, str): + return False + if file_id.startswith("default_"): + return False + if not file_id.startswith("custom_"): + return False + + rel_path = file_id[len("custom_") :] + file_path = os.path.join(COSTOM_APL_DIR, rel_path) + if not os.path.exists(file_path): return False + with open(file_path, "w", encoding="utf-8") as file: + file.write(content) + return True + def delete_apl_file(self, file_id: str) -> bool: - """删除APL文件""" - # 实现删除APL文件的逻辑 - try: - # 解析file_id获取source和相对路径 - if file_id.startswith("default_"): - # 不允许删除默认文件 - return False - elif file_id.startswith("custom_"): - rel_path = file_id[len("custom_") :] - base_dir = COSTOM_APL_DIR - else: - return False + """删除APL文件。 - file_path = os.path.join(base_dir, rel_path) + Args: + file_id (str): APL文件标识。 - if os.path.exists(file_path): - os.remove(file_path) - return True - else: - return False - except Exception as e: - print(f"Error deleting APL file {file_id}: {e}") + Returns: + bool: 删除成功返回True,否则False。 + """ + if not file_id or not isinstance(file_id, str): + return False + if file_id.startswith("default_"): + return False + if not file_id.startswith("custom_"): + return False + + rel_path = file_id[len("custom_") :] + file_path = os.path.join(COSTOM_APL_DIR, rel_path) + if not os.path.exists(file_path): return False + os.remove(file_path) + return True + def _get_apl_from_dir(self, apl_dir: str, source_type: str) -> list[dict[str, Any]]: - """从指定目录获取APL模板""" - apl_list = [] + """从指定目录获取APL模板。 + + Args: + apl_dir (str): 目录路径。 + source_type (str): 模板来源标识。 + + Returns: + list[dict[str, Any]]: 模板列表。 + """ + apl_list: list[dict[str, Any]] = [] if not os.path.exists(apl_dir): return apl_list - for root, _, files in os.walk(apl_dir): - for file in files: - if file.endswith(".toml"): - file_path = os.path.join(root, file) - try: - with open(file_path, "rb") as f: - apl_data = tomllib.load(f) - - # 提取基本信息 - general_info = apl_data.get("general", {}) - apl_info = { - "id": f"{source_type}_{os.path.relpath(file_path, apl_dir).replace(os.sep, '_')}", - "title": general_info.get("title", ""), - "author": general_info.get("author", ""), - "comment": general_info.get("comment", ""), - "create_time": general_info.get("create_time", ""), - "latest_change_time": general_info.get("latest_change_time", ""), - "source": source_type, - "file_path": file_path, - } - apl_list.append(apl_info) - except Exception as e: - # 记录错误但继续处理其他文件 - print(f"Error loading APL file {file_path}: {e}") - + for file_name in files: + if not file_name.endswith(".toml"): + continue + file_path = os.path.join(root, file_name) + with open(file_path, "rb") as file: + apl_data = tomllib.load(file) + general_info = apl_data.get("general", {}) + apl_list.append( + { + "id": f"{source_type}_{os.path.relpath(file_path, apl_dir).replace(os.sep, '_')}", + "title": general_info.get("title", ""), + "author": general_info.get("author", ""), + "comment": general_info.get("comment", ""), + "create_time": general_info.get("create_time", ""), + "latest_change_time": general_info.get( + "latest_change_time", "" + ), + "source": source_type, + "file_path": file_path, + } + ) return apl_list - def _get_apl_files_from_dir(self, apl_dir: str, source_type: str) -> list[dict[str, Any]]: - """从指定目录获取APL文件列表""" - file_list = [] + def _get_apl_files_from_dir( + self, apl_dir: str, source_type: str + ) -> list[dict[str, Any]]: + """从指定目录获取APL文件列表。 + + Args: + apl_dir (str): 目录路径。 + source_type (str): 模板来源标识。 + Returns: + list[dict[str, Any]]: 文件信息列表。 + """ + + file_list: list[dict[str, Any]] = [] if not os.path.exists(apl_dir): return file_list - for root, _, files in os.walk(apl_dir): - for file in files: - if file.endswith(".toml"): - file_path = os.path.join(root, file) - rel_path = os.path.relpath(file_path, apl_dir) - - file_info = { + for file_name in files: + if not file_name.endswith(".toml"): + continue + file_path = os.path.join(root, file_name) + rel_path = os.path.relpath(file_path, apl_dir) + file_list.append( + { "id": f"{source_type}_{rel_path.replace(os.sep, '_')}", - "name": file, + "name": file_name, "path": rel_path, "source": source_type, "full_path": file_path, } - file_list.append(file_info) - + ) return file_list diff --git a/zsim/api_src/services/database/character_db.py b/zsim/api_src/services/database/character_db.py index 58e1d9e1..ecbb3df0 100644 --- a/zsim/api_src/services/database/character_db.py +++ b/zsim/api_src/services/database/character_db.py @@ -1,266 +1,312 @@ -import aiosqlite -from typing import Any -from zsim.define import SQLITE_PATH -from zsim.models.character.character_config import CharacterConfig +"""角色配置数据库访问层""" + +from __future__ import annotations + from datetime import datetime +from typing import Any +from sqlalchemy import Boolean, Float, Integer, String, Text, delete, select +from sqlalchemy.exc import SQLAlchemyError +from sqlalchemy.orm import Mapped, mapped_column + +from zsim.api_src.services.database.orm import Base, get_async_engine, get_async_session +from zsim.models.character.character_config import CharacterConfig _character_db: "CharacterDB | None" = None +class CharacterConfigORM(Base): + + __tablename__ = "character_configs" + + config_id: Mapped[str] = mapped_column(String(128), primary_key=True) + name: Mapped[str] = mapped_column(String(255), nullable=False) + config_name: Mapped[str] = mapped_column(String(255), nullable=False) + weapon: Mapped[str] = mapped_column(String(255), nullable=False) + weapon_level: Mapped[int] = mapped_column(Integer, nullable=False) + cinema: Mapped[int] = mapped_column(Integer, nullable=False) + crit_balancing: Mapped[bool] = mapped_column(Boolean, nullable=False) + crit_rate_limit: Mapped[float] = mapped_column(Float, nullable=False) + scATK_percent: Mapped[int] = mapped_column(Integer, nullable=False) + scATK: Mapped[int] = mapped_column(Integer, nullable=False) + scHP_percent: Mapped[int] = mapped_column(Integer, nullable=False) + scHP: Mapped[int] = mapped_column(Integer, nullable=False) + scDEF_percent: Mapped[int] = mapped_column(Integer, nullable=False) + scDEF: Mapped[int] = mapped_column(Integer, nullable=False) + scAnomalyProficiency: Mapped[int] = mapped_column(Integer, nullable=False) + scPEN: Mapped[int] = mapped_column(Integer, nullable=False) + scCRIT: Mapped[int] = mapped_column(Integer, nullable=False) + scCRIT_DMG: Mapped[int] = mapped_column(Integer, nullable=False) + drive4: Mapped[str] = mapped_column(Text, nullable=False) + drive5: Mapped[str] = mapped_column(Text, nullable=False) + drive6: Mapped[str] = mapped_column(Text, nullable=False) + equip_style: Mapped[str] = mapped_column(String(255), nullable=False) + equip_set4: Mapped[str | None] = mapped_column(String(255), nullable=True) + equip_set2_a: Mapped[str | None] = mapped_column(String(255), nullable=True) + equip_set2_b: Mapped[str | None] = mapped_column(String(255), nullable=True) + equip_set2_c: Mapped[str | None] = mapped_column(String(255), nullable=True) + create_time: Mapped[str] = mapped_column(String(32), nullable=False) + update_time: Mapped[str] = mapped_column(String(32), nullable=False) + + class CharacterDB: - def __init__(self): + """角色配置数据库访问对象""" + + def __init__(self) -> None: + """初始化数据库访问对象""" self._cache: dict[str, Any] = {} - self._db_init: bool = False + self._db_init = False async def _init_db(self) -> None: - """初始化数据库,创建 character_configs 表""" + """确保数据库表结构已建立""" if self._db_init: return - async with aiosqlite.connect(SQLITE_PATH) as db: - await db.execute( - """CREATE TABLE IF NOT EXISTS character_configs ( - config_id TEXT PRIMARY KEY, - name TEXT NOT NULL, - config_name TEXT NOT NULL, - weapon TEXT NOT NULL, - weapon_level INTEGER NOT NULL, - cinema INTEGER NOT NULL, - crit_balancing BOOLEAN NOT NULL, - crit_rate_limit REAL NOT NULL, - scATK_percent INTEGER NOT NULL, - scATK INTEGER NOT NULL, - scHP_percent INTEGER NOT NULL, - scHP INTEGER NOT NULL, - scDEF_percent INTEGER NOT NULL, - scDEF INTEGER NOT NULL, - scAnomalyProficiency INTEGER NOT NULL, - scPEN INTEGER NOT NULL, - scCRIT INTEGER NOT NULL, - scCRIT_DMG INTEGER NOT NULL, - drive4 TEXT NOT NULL, - drive5 TEXT NOT NULL, - drive6 TEXT NOT NULL, - equip_style TEXT NOT NULL, - equip_set4 TEXT, - equip_set2_a TEXT, - equip_set2_b TEXT, - equip_set2_c TEXT, - create_time TEXT NOT NULL, - update_time TEXT NOT NULL - )""" - ) - await db.commit() + async with get_async_engine().begin() as conn: + await conn.run_sync(Base.metadata.create_all) self._db_init = True async def add_character_config(self, config: CharacterConfig) -> None: - """添加一个新的角色配置到数据库""" + """添加一个新的角色配置。 + + Args: + config (CharacterConfig): 角色配置数据。 + + Raises: + SQLAlchemyError: 当数据库写入失败时抛出。 + """ + await self._init_db() - # 设置config_id if not config.config_id: config.config_id = f"{config.name}_{config.config_name}" - - # 更新时间戳 config.update_time = datetime.now() - async with aiosqlite.connect(SQLITE_PATH) as db: - await db.execute( - """INSERT INTO character_configs ( - config_id, name, config_name, weapon, weapon_level, cinema, crit_balancing, crit_rate_limit, - scATK_percent, scATK, scHP_percent, scHP, scDEF_percent, scDEF, scAnomalyProficiency, - scPEN, scCRIT, scCRIT_DMG, drive4, drive5, drive6, equip_style, equip_set4, - equip_set2_a, equip_set2_b, equip_set2_c, create_time, update_time - ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)""", - ( - config.config_id, - config.name, - config.config_name, - config.weapon, - config.weapon_level, - config.cinema, - config.crit_balancing, - config.crit_rate_limit, - config.scATK_percent, - config.scATK, - config.scHP_percent, - config.scHP, - config.scDEF_percent, - config.scDEF, - config.scAnomalyProficiency, - config.scPEN, - config.scCRIT, - config.scCRIT_DMG, - config.drive4, - config.drive5, - config.drive6, - config.equip_style, - config.equip_set4, - config.equip_set2_a, - config.equip_set2_b, - config.equip_set2_c, - config.create_time.isoformat(), - config.update_time.isoformat(), - ), + async with get_async_session() as session: + session.add( + CharacterConfigORM( + config_id=config.config_id, + name=config.name, + config_name=config.config_name, + weapon=config.weapon, + weapon_level=config.weapon_level, + cinema=config.cinema, + crit_balancing=config.crit_balancing, + crit_rate_limit=config.crit_rate_limit, + scATK_percent=config.scATK_percent, + scATK=config.scATK, + scHP_percent=config.scHP_percent, + scHP=config.scHP, + scDEF_percent=config.scDEF_percent, + scDEF=config.scDEF, + scAnomalyProficiency=config.scAnomalyProficiency, + scPEN=config.scPEN, + scCRIT=config.scCRIT, + scCRIT_DMG=config.scCRIT_DMG, + drive4=config.drive4, + drive5=config.drive5, + drive6=config.drive6, + equip_style=config.equip_style, + equip_set4=config.equip_set4, + equip_set2_a=config.equip_set2_a, + equip_set2_b=config.equip_set2_b, + equip_set2_c=config.equip_set2_c, + create_time=config.create_time.isoformat(), + update_time=config.update_time.isoformat(), + ) ) - await db.commit() + try: + await session.commit() + except SQLAlchemyError as exc: # noqa: BLE001 + await session.rollback() + raise exc + + async def get_character_config( + self, name: str, config_name: str + ) -> CharacterConfig | None: + """根据角色名称和配置名称获取角色配置。 + + Args: + name (str): 角色名称。 + config_name (str): 配置名称。 + + Returns: + CharacterConfig | None: 匹配的角色配置,未找到时返回None。 + """ - async def get_character_config(self, name: str, config_name: str) -> CharacterConfig | None: - """根据角色名称和配置名称从数据库获取角色配置""" await self._init_db() config_id = f"{name}_{config_name}" - async with aiosqlite.connect(SQLITE_PATH) as db: - cursor = await db.execute( - """SELECT config_id, name, config_name, weapon, weapon_level, cinema, crit_balancing, crit_rate_limit, - scATK_percent, scATK, scHP_percent, scHP, scDEF_percent, scDEF, scAnomalyProficiency, - scPEN, scCRIT, scCRIT_DMG, drive4, drive5, drive6, equip_style, equip_set4, - equip_set2_a, equip_set2_b, equip_set2_c, create_time, update_time - FROM character_configs - WHERE config_id = ?""", - (config_id,), - ) - row = await cursor.fetchone() - if row: - return CharacterConfig( - config_id=row[0], - name=row[1], - config_name=row[2], - weapon=row[3], - weapon_level=row[4], - cinema=row[5], - crit_balancing=row[6], - crit_rate_limit=row[7], - scATK_percent=row[8], - scATK=row[9], - scHP_percent=row[10], - scHP=row[11], - scDEF_percent=row[12], - scDEF=row[13], - scAnomalyProficiency=row[14], - scPEN=row[15], - scCRIT=row[16], - scCRIT_DMG=row[17], - drive4=row[18], - drive5=row[19], - drive6=row[20], - equip_style=row[21], - equip_set4=row[22], - equip_set2_a=row[23], - equip_set2_b=row[24], - equip_set2_c=row[25], - create_time=datetime.fromisoformat(row[26]), - update_time=datetime.fromisoformat(row[27]), + async with get_async_session() as session: + result = await session.execute( + select(CharacterConfigORM).where( + CharacterConfigORM.config_id == config_id ) - return None + ) + record = result.scalar_one_or_none() + if record is None: + return None + return CharacterConfig( + config_id=record.config_id, + name=record.name, + config_name=record.config_name, + weapon=record.weapon, + weapon_level=record.weapon_level, + cinema=record.cinema, + crit_balancing=record.crit_balancing, + crit_rate_limit=record.crit_rate_limit, + scATK_percent=record.scATK_percent, + scATK=record.scATK, + scHP_percent=record.scHP_percent, + scHP=record.scHP, + scDEF_percent=record.scDEF_percent, + scDEF=record.scDEF, + scAnomalyProficiency=record.scAnomalyProficiency, + scPEN=record.scPEN, + scCRIT=record.scCRIT, + scCRIT_DMG=record.scCRIT_DMG, + drive4=record.drive4, + drive5=record.drive5, + drive6=record.drive6, + equip_style=record.equip_style, + equip_set4=record.equip_set4, + equip_set2_a=record.equip_set2_a, + equip_set2_b=record.equip_set2_b, + equip_set2_c=record.equip_set2_c, + create_time=datetime.fromisoformat(record.create_time), + update_time=datetime.fromisoformat(record.update_time), + ) async def update_character_config(self, config: CharacterConfig) -> None: - """更新数据库中的角色配置""" + """更新数据库中的角色配置。 + + Args: + config (CharacterConfig): 新的角色配置信息。 + + Raises: + SQLAlchemyError: 当数据库写入失败时抛出。 + """ + await self._init_db() - # 更新时间戳 config.update_time = datetime.now() - async with aiosqlite.connect(SQLITE_PATH) as db: - await db.execute( - """UPDATE character_configs SET - name = ?, config_name = ?, weapon = ?, weapon_level = ?, cinema = ?, crit_balancing = ?, crit_rate_limit = ?, - scATK_percent = ?, scATK = ?, scHP_percent = ?, scHP = ?, scDEF_percent = ?, scDEF = ?, - scAnomalyProficiency = ?, scPEN = ?, scCRIT = ?, scCRIT_DMG = ?, drive4 = ?, drive5 = ?, - drive6 = ?, equip_style = ?, equip_set4 = ?, equip_set2_a = ?, equip_set2_b = ?, - equip_set2_c = ?, update_time = ? - WHERE config_id = ?""", - ( - config.name, - config.config_name, - config.weapon, - config.weapon_level, - config.cinema, - config.crit_balancing, - config.crit_rate_limit, - config.scATK_percent, - config.scATK, - config.scHP_percent, - config.scHP, - config.scDEF_percent, - config.scDEF, - config.scAnomalyProficiency, - config.scPEN, - config.scCRIT, - config.scCRIT_DMG, - config.drive4, - config.drive5, - config.drive6, - config.equip_style, - config.equip_set4, - config.equip_set2_a, - config.equip_set2_b, - config.equip_set2_c, - config.update_time.isoformat(), - config.config_id, - ), + async with get_async_session() as session: + result = await session.execute( + select(CharacterConfigORM).where( + CharacterConfigORM.config_id == config.config_id + ) ) - await db.commit() + record = result.scalar_one_or_none() + if record is None: + return + record.name = config.name + record.config_name = config.config_name + record.weapon = config.weapon + record.weapon_level = config.weapon_level + record.cinema = config.cinema + record.crit_balancing = config.crit_balancing + record.crit_rate_limit = config.crit_rate_limit + record.scATK_percent = config.scATK_percent + record.scATK = config.scATK + record.scHP_percent = config.scHP_percent + record.scHP = config.scHP + record.scDEF_percent = config.scDEF_percent + record.scDEF = config.scDEF + record.scAnomalyProficiency = config.scAnomalyProficiency + record.scPEN = config.scPEN + record.scCRIT = config.scCRIT + record.scCRIT_DMG = config.scCRIT_DMG + record.drive4 = config.drive4 + record.drive5 = config.drive5 + record.drive6 = config.drive6 + record.equip_style = config.equip_style + record.equip_set4 = config.equip_set4 + record.equip_set2_a = config.equip_set2_a + record.equip_set2_b = config.equip_set2_b + record.equip_set2_c = config.equip_set2_c + record.update_time = config.update_time.isoformat() + await session.flush() + try: + await session.commit() + except SQLAlchemyError as exc: # noqa: BLE001 + await session.rollback() + raise exc async def delete_character_config(self, name: str, config_name: str) -> None: - """从数据库删除角色配置""" + """删除指定角色的配置。 + + Args: + name (str): 角色名称。 + config_name (str): 配置名称。 + """ + await self._init_db() config_id = f"{name}_{config_name}" - async with aiosqlite.connect(SQLITE_PATH) as db: - await db.execute("DELETE FROM character_configs WHERE config_id = ?", (config_id,)) - await db.commit() + async with get_async_session() as session: + await session.execute( + delete(CharacterConfigORM).where( + CharacterConfigORM.config_id == config_id + ) + ) + await session.commit() async def list_character_configs(self, name: str) -> list[CharacterConfig]: - """从数据库获取指定角色的所有配置列表""" + """获取指定角色的所有配置列表。 + + Args: + name (str): 角色名称。 + + Returns: + list[CharacterConfig]: 角色配置列表。 + """ + await self._init_db() - configs = [] - async with aiosqlite.connect(SQLITE_PATH) as db: - cursor = await db.execute( - """SELECT config_id, name, config_name, weapon, weapon_level, cinema, crit_balancing, crit_rate_limit, - scATK_percent, scATK, scHP_percent, scHP, scDEF_percent, scDEF, scAnomalyProficiency, - scPEN, scCRIT, scCRIT_DMG, drive4, drive5, drive6, equip_style, equip_set4, - equip_set2_a, equip_set2_b, equip_set2_c, create_time, update_time - FROM character_configs - WHERE name = ? - ORDER BY config_name""", - (name,), + async with get_async_session() as session: + result = await session.execute( + select(CharacterConfigORM) + .where(CharacterConfigORM.name == name) + .order_by(CharacterConfigORM.config_name) ) - rows = await cursor.fetchall() - for row in rows: - configs.append( - CharacterConfig( - config_id=row[0], - name=row[1], - config_name=row[2], - weapon=row[3], - weapon_level=row[4], - cinema=row[5], - crit_balancing=row[6], - crit_rate_limit=row[7], - scATK_percent=row[8], - scATK=row[9], - scHP_percent=row[10], - scHP=row[11], - scDEF_percent=row[12], - scDEF=row[13], - scAnomalyProficiency=row[14], - scPEN=row[15], - scCRIT=row[16], - scCRIT_DMG=row[17], - drive4=row[18], - drive5=row[19], - drive6=row[20], - equip_style=row[21], - equip_set4=row[22], - equip_set2_a=row[23], - equip_set2_b=row[24], - equip_set2_c=row[25], - create_time=datetime.fromisoformat(row[26]), - update_time=datetime.fromisoformat(row[27]), - ) - ) - return configs + records = result.scalars().all() + return [ + CharacterConfig( + config_id=record.config_id, + name=record.name, + config_name=record.config_name, + weapon=record.weapon, + weapon_level=record.weapon_level, + cinema=record.cinema, + crit_balancing=record.crit_balancing, + crit_rate_limit=record.crit_rate_limit, + scATK_percent=record.scATK_percent, + scATK=record.scATK, + scHP_percent=record.scHP_percent, + scHP=record.scHP, + scDEF_percent=record.scDEF_percent, + scDEF=record.scDEF, + scAnomalyProficiency=record.scAnomalyProficiency, + scPEN=record.scPEN, + scCRIT=record.scCRIT, + scCRIT_DMG=record.scCRIT_DMG, + drive4=record.drive4, + drive5=record.drive5, + drive6=record.drive6, + equip_style=record.equip_style, + equip_set4=record.equip_set4, + equip_set2_a=record.equip_set2_a, + equip_set2_b=record.equip_set2_b, + equip_set2_c=record.equip_set2_c, + create_time=datetime.fromisoformat(record.create_time), + update_time=datetime.fromisoformat(record.update_time), + ) + for record in records + ] async def get_character_db() -> CharacterDB: - """便捷函数:获取 CharacterDB 的单例实例""" + """获取CharacterDB单例。 + + Returns: + CharacterDB: 单例数据库访问对象。 + """ + global _character_db if _character_db is None: _character_db = CharacterDB() diff --git a/zsim/api_src/services/database/enemy_db.py b/zsim/api_src/services/database/enemy_db.py index b89f2b02..bf9adb05 100644 --- a/zsim/api_src/services/database/enemy_db.py +++ b/zsim/api_src/services/database/enemy_db.py @@ -1,125 +1,177 @@ +"""敌人配置数据库访问层""" + +from __future__ import annotations + import json -import aiosqlite -from typing import Any -from zsim.define import SQLITE_PATH -from zsim.models.enemy.enemy_config import EnemyConfig from datetime import datetime +from sqlalchemy import Integer, String, Text, delete, select +from sqlalchemy.exc import SQLAlchemyError +from sqlalchemy.orm import Mapped, mapped_column + +from zsim.api_src.services.database.orm import Base, get_async_engine, get_async_session +from zsim.models.enemy.enemy_config import EnemyConfig _enemy_db: "EnemyDB | None" = None +class EnemyConfigORM(Base): + + __tablename__ = "enemy_configs" + + config_id: Mapped[str] = mapped_column(String(128), primary_key=True) + enemy_index: Mapped[int] = mapped_column(Integer, nullable=False) + enemy_adjust: Mapped[str] = mapped_column(Text, nullable=False) + create_time: Mapped[str] = mapped_column(String(32), nullable=False) + update_time: Mapped[str] = mapped_column(String(32), nullable=False) + + class EnemyDB: - def __init__(self): - self._cache: dict[str, Any] = {} - self._db_init: bool = False + """敌人配置数据库访问对象""" + + def __init__(self) -> None: + """初始化数据库访问对象""" + self._db_init = False async def _init_db(self) -> None: - """初始化数据库,创建 enemy_configs 表""" + """确保数据库表结构已建立""" if self._db_init: return - async with aiosqlite.connect(SQLITE_PATH) as db: - await db.execute( - """CREATE TABLE IF NOT EXISTS enemy_configs ( - config_id TEXT PRIMARY KEY, - enemy_index INTEGER NOT NULL, - enemy_adjust TEXT NOT NULL, - create_time TEXT NOT NULL, - update_time TEXT NOT NULL - )""" - ) - await db.commit() + async with get_async_engine().begin() as conn: + await conn.run_sync(Base.metadata.create_all) self._db_init = True async def add_enemy_config(self, config: EnemyConfig) -> None: - """添加一个新的敌人配置到数据库""" + """添加敌人配置。 + + Args: + config (EnemyConfig): 敌人配置数据。 + + Raises: + SQLAlchemyError: 当数据库写入失败时抛出。 + """ + await self._init_db() - # 更新时间戳 config.update_time = datetime.now() - - async with aiosqlite.connect(SQLITE_PATH) as db: - await db.execute( - "INSERT INTO enemy_configs (config_id, enemy_index, enemy_adjust, create_time, update_time) VALUES (?, ?, ?, ?, ?)", - ( - config.config_id, - config.enemy_index, - json.dumps(config.enemy_adjust), - config.create_time.isoformat(), - config.update_time.isoformat(), - ), + async with get_async_session() as session: + session.add( + EnemyConfigORM( + config_id=config.config_id, + enemy_index=config.enemy_index, + enemy_adjust=json.dumps(config.enemy_adjust), + create_time=config.create_time.isoformat(), + update_time=config.update_time.isoformat(), + ) ) - await db.commit() + try: + await session.commit() + except SQLAlchemyError as exc: # noqa: BLE001 + await session.rollback() + raise exc async def get_enemy_config(self, config_id: str) -> EnemyConfig | None: - """根据配置ID从数据库获取敌人配置""" + """根据配置ID获取敌人配置。 + + Args: + config_id (str): 敌人配置ID。 + + Returns: + EnemyConfig | None: 匹配的敌人配置,未找到时返回None。 + """ + await self._init_db() - async with aiosqlite.connect(SQLITE_PATH) as db: - cursor = await db.execute( - "SELECT config_id, enemy_index, enemy_adjust, create_time, update_time FROM enemy_configs WHERE config_id = ?", - (config_id,), + async with get_async_session() as session: + result = await session.execute( + select(EnemyConfigORM).where(EnemyConfigORM.config_id == config_id) + ) + record = result.scalar_one_or_none() + if record is None: + return None + return EnemyConfig( + config_id=record.config_id, + enemy_index=record.enemy_index, + enemy_adjust=json.loads(record.enemy_adjust), + create_time=datetime.fromisoformat(record.create_time), + update_time=datetime.fromisoformat(record.update_time), ) - row = await cursor.fetchone() - if row: - return EnemyConfig( - config_id=row[0], - enemy_index=row[1], - enemy_adjust=json.loads(row[2]), - create_time=datetime.fromisoformat(row[3]), - update_time=datetime.fromisoformat(row[4]), - ) - return None async def update_enemy_config(self, config: EnemyConfig) -> None: - """更新数据库中的敌人配置""" + """更新敌人配置。 + + Args: + config (EnemyConfig): 敌人配置数据。 + + Raises: + SQLAlchemyError: 当数据库写入失败时抛出。 + """ + await self._init_db() - # 更新时间戳 config.update_time = datetime.now() - - async with aiosqlite.connect(SQLITE_PATH) as db: - await db.execute( - """UPDATE enemy_configs SET - enemy_index = ?, enemy_adjust = ?, update_time = ? - WHERE config_id = ?""", - ( - config.enemy_index, - json.dumps(config.enemy_adjust), - config.update_time.isoformat(), - config.config_id, - ), + async with get_async_session() as session: + result = await session.execute( + select(EnemyConfigORM).where( + EnemyConfigORM.config_id == config.config_id + ) ) - await db.commit() + record = result.scalar_one_or_none() + if record is None: + return + record.enemy_index = config.enemy_index + record.enemy_adjust = json.dumps(config.enemy_adjust) + record.update_time = config.update_time.isoformat() + await session.flush() + try: + await session.commit() + except SQLAlchemyError as exc: # noqa: BLE001 + await session.rollback() + raise exc async def delete_enemy_config(self, config_id: str) -> None: - """从数据库删除敌人配置""" + """删除敌人配置。 + + Args: + config_id (str): 敌人配置ID。 + """ + await self._init_db() - async with aiosqlite.connect(SQLITE_PATH) as db: - await db.execute("DELETE FROM enemy_configs WHERE config_id = ?", (config_id,)) - await db.commit() + async with get_async_session() as session: + await session.execute( + delete(EnemyConfigORM).where(EnemyConfigORM.config_id == config_id) + ) + await session.commit() async def list_enemy_configs(self) -> list[EnemyConfig]: - """从数据库获取所有敌人配置列表""" + """列出所有敌人配置。 + + Returns: + list[EnemyConfig]: 敌人配置列表。 + """ + await self._init_db() - configs = [] - async with aiosqlite.connect(SQLITE_PATH) as db: - cursor = await db.execute( - "SELECT config_id, enemy_index, enemy_adjust, create_time, update_time FROM enemy_configs ORDER BY config_id" + async with get_async_session() as session: + result = await session.execute( + select(EnemyConfigORM).order_by(EnemyConfigORM.config_id) ) - rows = await cursor.fetchall() - for row in rows: - configs.append( - EnemyConfig( - config_id=row[0], - enemy_index=row[1], - enemy_adjust=json.loads(row[2]), - create_time=datetime.fromisoformat(row[3]), - update_time=datetime.fromisoformat(row[4]), - ) - ) - return configs + records = result.scalars().all() + return [ + EnemyConfig( + config_id=record.config_id, + enemy_index=record.enemy_index, + enemy_adjust=json.loads(record.enemy_adjust), + create_time=datetime.fromisoformat(record.create_time), + update_time=datetime.fromisoformat(record.update_time), + ) + for record in records + ] async def get_enemy_db() -> EnemyDB: - """便捷函数:获取 EnemyDB 的单例实例""" + """获取EnemyDB单例。 + + Returns: + EnemyDB: 敌人数据库访问对象。 + """ + global _enemy_db if _enemy_db is None: _enemy_db = EnemyDB() diff --git a/zsim/api_src/services/database/orm.py b/zsim/api_src/services/database/orm.py new file mode 100644 index 00000000..c958500f --- /dev/null +++ b/zsim/api_src/services/database/orm.py @@ -0,0 +1,97 @@ +"""SQLAlchemy基础设施定义""" + +from __future__ import annotations + +from collections.abc import AsyncIterator +from contextlib import asynccontextmanager +from pathlib import Path + +from sqlalchemy.ext.asyncio import ( + AsyncEngine, + AsyncSession, + async_sessionmaker, + create_async_engine, +) +from sqlalchemy.orm import DeclarativeBase + +from zsim.define import SQLITE_PATH + + +class Base(DeclarativeBase): + """声明式基类""" + + +def _database_path() -> Path: + """返回SQLite数据库路径。 + + Returns: + Path: 数据库文件的绝对路径。 + """ + + path = Path(SQLITE_PATH).expanduser() + path.parent.mkdir(parents=True, exist_ok=True) + return path.resolve() + + +def get_async_database_url() -> str: + """获取异步模式下的数据库URL。 + + Returns: + str: 适用于异步SQLAlchemy引擎的数据库URL。 + """ + + return f"sqlite+aiosqlite:///{_database_path().as_posix()}" + + +def get_sync_database_url() -> str: + """获取同步模式下的数据库URL。 + + Returns: + str: 适用于同步SQLAlchemy引擎(如Alembic)的数据库URL。 + """ + + return f"sqlite:///{_database_path().as_posix()}" + + +_async_engine: AsyncEngine = create_async_engine(get_async_database_url(), future=True) +_async_session_factory = async_sessionmaker(_async_engine, expire_on_commit=False) + + +def get_async_engine() -> AsyncEngine: + """返回复用的异步SQLAlchemy引擎实例。 + + Returns: + AsyncEngine: 进程范围内复用的异步引擎。 + """ + + return _async_engine + + +@asynccontextmanager +async def get_async_session() -> AsyncIterator[AsyncSession]: + """获取一个SQLAlchemy异步会话。 + + Returns: + AsyncIterator[AsyncSession]: SQLAlchemy异步会话上下文管理器。 + + Raises: + RuntimeError: 当执行过程中出现数据库错误时抛出。 + """ + + session = _async_session_factory() + try: + yield session + except Exception as exc: + await session.rollback() + raise RuntimeError("异步数据库会话执行失败") from exc + finally: + await session.close() + + +__all__ = [ + "Base", + "get_async_engine", + "get_async_session", + "get_async_database_url", + "get_sync_database_url", +] diff --git a/zsim/api_src/services/database/session_db.py b/zsim/api_src/services/database/session_db.py index 44521bea..8e169d48 100644 --- a/zsim/api_src/services/database/session_db.py +++ b/zsim/api_src/services/database/session_db.py @@ -1,152 +1,215 @@ +"""模拟会话数据库访问层""" + +from __future__ import annotations + import json +from datetime import datetime from typing import Any -import aiosqlite +from sqlalchemy import String, Text, delete, select +from sqlalchemy.exc import SQLAlchemyError +from sqlalchemy.orm import Mapped, mapped_column -from zsim.define import SQLITE_PATH +from zsim.api_src.services.database.orm import Base, get_async_engine, get_async_session from zsim.models.session.session_create import Session -_session_db: "SessionDB | None" = None # 单例实例 +_session_db: "SessionDB | None" = None + + +class SessionORM(Base): + """模拟会话ORM模型""" + + __tablename__ = "sessions" + + session_id: Mapped[str] = mapped_column(String(128), primary_key=True) + session_name: Mapped[str] = mapped_column(String(255), nullable=False, default="") + create_time: Mapped[str] = mapped_column(String(32), nullable=False) + status: Mapped[str] = mapped_column(String(32), nullable=False) + session_run: Mapped[str | None] = mapped_column(Text, nullable=True) + session_result: Mapped[str | None] = mapped_column(Text, nullable=True) class SessionDB: - def __init__(self): + """会话数据库访问对象""" + + def __init__(self) -> None: + """初始化数据库访问对象""" self._cache: dict[str, Any] = {} - self._db_init: bool = False + self._db_init = False async def _init_db(self) -> None: - """初始化数据库,创建 sessions 表""" + """确保数据库表结构已建立""" if self._db_init: return - async with aiosqlite.connect(SQLITE_PATH) as db: - # Check if the table exists and has the session_name column - cursor = await db.execute("PRAGMA table_info(sessions)") - columns = await cursor.fetchall() - column_names = [column[1] for column in columns] - - if not columns: - # Table doesn't exist, create it with all columns - await db.execute( - """CREATE TABLE sessions ( - session_id TEXT PRIMARY KEY, - session_name TEXT NOT NULL DEFAULT '', - create_time TEXT NOT NULL, - status TEXT NOT NULL, - session_run TEXT, - session_result TEXT - )""" - ) - elif "session_name" not in column_names: - # Table exists but doesn't have session_name column, add it - await db.execute( - "ALTER TABLE sessions ADD COLUMN session_name TEXT NOT NULL DEFAULT ''" - ) - - await db.commit() + async with get_async_engine().begin() as conn: + await conn.run_sync(Base.metadata.create_all) self._db_init = True - async def add_session(self, session: Session) -> None: - """添加一个新的会话到数据库""" + async def add_session(self, session_data: Session) -> None: + """添加一个新的模拟会话。 + + Args: + session_data (Session): 会话数据。 + + Raises: + SQLAlchemyError: 当数据库写入失败时抛出。 + """ + await self._init_db() - async with aiosqlite.connect(SQLITE_PATH) as db: - await db.execute( - "INSERT INTO sessions (session_id, session_name, create_time, status, session_run, session_result) VALUES (?, ?, ?, ?, ?, ?)", - ( - session.session_id, - session.session_name, - session.create_time.isoformat(), - session.status, - session.session_run.model_dump_json(indent=4) if session.session_run else None, - json.dumps([r.model_dump() for r in session.session_result]) - if session.session_result - else None, - ), + async with get_async_session() as session: + session.add( + SessionORM( + session_id=session_data.session_id, + session_name=session_data.session_name, + create_time=session_data.create_time.isoformat(), + status=session_data.status, + session_run=( + session_data.session_run.model_dump_json(indent=4) + if session_data.session_run + else None + ), + session_result=( + json.dumps( + [ + result.model_dump() + for result in session_data.session_result + ] + ) + if session_data.session_result + else None + ), + ) ) - await db.commit() + try: + await session.commit() + except SQLAlchemyError as exc: # noqa: BLE001 + await session.rollback() + raise exc async def get_session(self, session_id: str) -> Session | None: - """根据 session_id 从数据库获取会话""" - await self._init_db() - async with aiosqlite.connect(SQLITE_PATH) as db: - cursor = await db.execute("SELECT * FROM sessions WHERE session_id = ?", (session_id,)) - row = await cursor.fetchone() - if row: - # Get column names to ensure correct indexing - column_names = [description[0] for description in cursor.description] - row_dict = dict(zip(column_names, row)) - - return Session( - session_id=row_dict["session_id"], - session_name=row_dict["session_name"], - create_time=row_dict["create_time"], - status=row_dict["status"], - session_run=json.loads(row_dict["session_run"]) - if row_dict["session_run"] - else None, - session_result=json.loads(row_dict["session_result"]) - if row_dict["session_result"] - else None, - ) - return None + """根据ID获取模拟会话。 + + Args: + session_id (str): 会话ID。 + + Returns: + Session | None: 匹配的会话数据,未找到时返回None。 + """ - async def update_session(self, session: Session) -> None: - """更新数据库中的会话""" await self._init_db() - async with aiosqlite.connect(SQLITE_PATH) as db: - await db.execute( - """UPDATE sessions - SET session_name = ?, create_time = ?, status = ?, session_run = ?, session_result = ? - WHERE session_id = ? - """, - ( - session.session_name, - session.create_time.isoformat(), - session.status, - session.session_run.model_dump_json(indent=4) if session.session_run else None, - json.dumps([r.model_dump() for r in session.session_result]) - if session.session_result - else None, - session.session_id, + async with get_async_session() as session: + result = await session.execute( + select(SessionORM).where(SessionORM.session_id == session_id) + ) + record = result.scalar_one_or_none() + if record is None: + return None + return Session( + session_id=record.session_id, + session_name=record.session_name, + create_time=datetime.fromisoformat(record.create_time), + status=record.status, + session_run=( + json.loads(record.session_run) if record.session_run else None + ), + session_result=( + json.loads(record.session_result) if record.session_result else None ), ) - await db.commit() + + async def update_session(self, session_data: Session) -> None: + """更新模拟会话。 + + Args: + session_data (Session): 会话数据。 + + Raises: + SQLAlchemyError: 当数据库写入失败时抛出。 + """ + + await self._init_db() + async with get_async_session() as session: + result = await session.execute( + select(SessionORM).where( + SessionORM.session_id == session_data.session_id + ) + ) + record = result.scalar_one_or_none() + if record is None: + return + record.session_name = session_data.session_name + record.create_time = session_data.create_time.isoformat() + record.status = session_data.status + record.session_run = ( + session_data.session_run.model_dump_json(indent=4) + if session_data.session_run + else None + ) + record.session_result = ( + json.dumps( + [result.model_dump() for result in session_data.session_result] + ) + if session_data.session_result + else None + ) + await session.flush() + try: + await session.commit() + except SQLAlchemyError as exc: # noqa: BLE001 + await session.rollback() + raise exc async def delete_session(self, session_id: str) -> None: - """从数据库删除会话""" + """删除模拟会话。 + + Args: + session_id (str): 会话ID。 + """ + await self._init_db() - async with aiosqlite.connect(SQLITE_PATH) as db: - await db.execute("DELETE FROM sessions WHERE session_id = ?", (session_id,)) - await db.commit() + async with get_async_session() as session: + await session.execute( + delete(SessionORM).where(SessionORM.session_id == session_id) + ) + await session.commit() async def list_sessions(self) -> list[Session]: - """从数据库获取所有会话列表""" + """列出所有模拟会话。 + + Returns: + list[Session]: 会话数据列表。 + """ + await self._init_db() - sessions = [] - async with aiosqlite.connect(SQLITE_PATH) as db: - cursor = await db.execute("SELECT * FROM sessions ORDER BY create_time DESC") - rows = await cursor.fetchall() - column_names = [description[0] for description in cursor.description] - for row in rows: - row_dict = dict(zip(column_names, row)) - sessions.append( - Session( - session_id=row_dict["session_id"], - session_name=row_dict["session_name"], - create_time=row_dict["create_time"], - status=row_dict["status"], - session_run=json.loads(row_dict["session_run"]) - if row_dict["session_run"] - else None, - session_result=json.loads(row_dict["session_result"]) - if row_dict["session_result"] - else None, - ) - ) - return sessions + async with get_async_session() as session: + result = await session.execute( + select(SessionORM).order_by(SessionORM.create_time.desc()) + ) + records = result.scalars().all() + return [ + Session( + session_id=record.session_id, + session_name=record.session_name, + create_time=datetime.fromisoformat(record.create_time), + status=record.status, + session_run=( + json.loads(record.session_run) if record.session_run else None + ), + session_result=( + json.loads(record.session_result) if record.session_result else None + ), + ) + for record in records + ] async def get_session_db() -> SessionDB: - """便捷函数:获取 SessionDB 的单例实例""" + """获取SessionDB单例。 + + Returns: + SessionDB: 会话数据库访问对象。 + """ + global _session_db if _session_db is None: _session_db = SessionDB() diff --git a/zsim/api_src/services/sim_controller/sim_controller.py b/zsim/api_src/services/sim_controller/sim_controller.py index 7eab4a4a..7d09a992 100644 --- a/zsim/api_src/services/sim_controller/sim_controller.py +++ b/zsim/api_src/services/sim_controller/sim_controller.py @@ -5,20 +5,6 @@ from typing import TYPE_CHECKING, Any, Iterator, Literal from zsim.api_src.services.database.session_db import get_session_db -from zsim.utils.constants import stats_trans_mapping -from zsim.utils.process_buff_result import ( - prepare_buff_data_and_cache as process_buff, -) -from zsim.utils.process_dmg_result import ( - prepare_dmg_data_and_cache as process_dmg, -) -from zsim.utils.process_parallel_data import ( - judge_parallel_result, - merge_parallel_dmg_data, -) -from zsim.utils.process_parallel_data import ( - prepare_parallel_data_and_cache as prepare_parallel_cache, -) from zsim.models.session.session_create import Session from zsim.models.session.session_result import ( AttrCurvePayload, @@ -43,6 +29,20 @@ SimulationConfig as SimCfg, ) from zsim.simulator import Simulator +from zsim.utils.constants import stats_trans_mapping +from zsim.utils.process_buff_result import ( + prepare_buff_data_and_cache as process_buff, +) +from zsim.utils.process_dmg_result import ( + prepare_dmg_data_and_cache as process_dmg, +) +from zsim.utils.process_parallel_data import ( + judge_parallel_result, + merge_parallel_dmg_data, +) +from zsim.utils.process_parallel_data import ( + prepare_parallel_data_and_cache as prepare_parallel_cache, +) if TYPE_CHECKING: from zsim.simulator.simulator_class import Confirmation @@ -144,21 +144,27 @@ async def execute_simulation(self) -> None: else session.session_run.stop_tick ) if stop_tick is None: - logger.warning(f"会话 {session_id} 未设置 stop_tick,使用默认值 3600") + logger.warning( + f"会话 {session_id} 未设置 stop_tick,使用默认值 3600" + ) stop_tick = 3600 def run_simulator( _common_cfg: CommonCfg, _sim_cfg: SimCfg | None, _stop_tick: int ) -> "Confirmation": simulator = Simulator() - return simulator.api_run_simulator(_common_cfg, _sim_cfg, _stop_tick) + return simulator.api_run_simulator( + _common_cfg, _sim_cfg, _stop_tick + ) # 创建模拟器实例并提交任务 future: asyncio.Future["Confirmation"] = event_loop.run_in_executor( self.executor, run_simulator, common_cfg, sim_cfg, stop_tick ) self._running_tasks.add(future) - future.add_done_callback(lambda f: self._task_done_callback(f, session_id)) + future.add_done_callback( + lambda f: self._task_done_callback(f, session_id) + ) # 让出控制权给其他协程 await asyncio.sleep(0) @@ -201,14 +207,18 @@ async def execute_simulation_test(self, max_tasks: int = 1) -> list[str]: else session.session_run.stop_tick ) if stop_tick is None: - logger.warning(f"会话 {session_id} 未设置 stop_tick,使用默认值 3600") + logger.warning( + f"会话 {session_id} 未设置 stop_tick,使用默认值 3600" + ) stop_tick = 3600 def run_simulator( _common_cfg: CommonCfg, _sim_cfg: SimCfg | None, _stop_tick: int ) -> "Confirmation": simulator = Simulator() - return simulator.api_run_simulator(_common_cfg, _sim_cfg, _stop_tick) + return simulator.api_run_simulator( + _common_cfg, _sim_cfg, _stop_tick + ) # 使用 ThreadPoolExecutor 避免序列化问题 with ThreadPoolExecutor() as thread_executor: @@ -265,7 +275,10 @@ async def execute_simulation_test_parallel( return completed_sessions def run_simulator( - session_id_inner: str, common_cfg: CommonCfg, sim_cfg: SimCfg | None, stop_tick: int + session_id_inner: str, + common_cfg: CommonCfg, + sim_cfg: SimCfg | None, + stop_tick: int, ) -> tuple[str, "Confirmation"]: simulator = Simulator() result = simulator.api_run_simulator(common_cfg, sim_cfg, stop_tick) @@ -289,7 +302,12 @@ def run_simulator( stop_tick = 1000 future = event_loop.run_in_executor( - thread_executor, run_simulator, session_id_inner, common_cfg, sim_cfg, stop_tick + thread_executor, + run_simulator, + session_id_inner, + common_cfg, + sim_cfg, + stop_tick, ) futures.append(future) @@ -348,7 +366,9 @@ def _run_simulator() -> "Confirmation": else: return await loop.run_in_executor(self.executor, _run_simulator) - def _task_done_callback(self, future: asyncio.Future["Confirmation"], session_id: str) -> None: + def _task_done_callback( + self, future: asyncio.Future["Confirmation"], session_id: str + ) -> None: """ 任务完成时的回调函数。 @@ -377,13 +397,16 @@ async def _update_session_status( # 处理模拟结果确认信息 if isinstance(result, dict) and "run_turn_uuid" in result: - processed_result: ( - NormalModeResult | ParallelModeResult - ) = await self._process_simulation_result(result) + processed_result: NormalModeResult | ParallelModeResult = ( + await self._process_simulation_result(result) + ) try: session.session_result = [processed_result] except Exception as e: - logger.error(f"TODO: 模拟任务 {session_id} 结果处理: {repr(e)}", exc_info=True) + logger.error( + f"TODO: 模拟任务 {session_id} 结果处理: {repr(e)}", + exc_info=True, + ) except Exception as e: logger.error(f"模拟任务 {session_id} 执行失败: {e}", exc_info=True) @@ -502,9 +525,13 @@ def generate_parallel_args( func_cfg = parallel_cfg.func_config if func == "attr_curve" and isinstance(func_cfg, ParallelCfg.AttrCurveConfig): - yield from self._generate_attr_curve_args(func_cfg, parallel_cfg, stop_tick, session_id) + yield from self._generate_attr_curve_args( + func_cfg, parallel_cfg, stop_tick, session_id + ) elif func == "weapon" and isinstance(func_cfg, ParallelCfg.WeaponConfig): - yield from self._generate_weapon_args(func_cfg, parallel_cfg, stop_tick, session_id) + yield from self._generate_weapon_args( + func_cfg, parallel_cfg, stop_tick, session_id + ) else: error_msg = f"未知的func类型: {func}, 完整配置: {parallel_cfg}" logger.error(error_msg) diff --git a/zsim_api.spec b/zsim_api.spec index 8aba5e1f..2f2e35d8 100644 --- a/zsim_api.spec +++ b/zsim_api.spec @@ -72,6 +72,9 @@ hiddenimports = [ "fastapi", "uvicorn", "aiosqlite", + "sqlalchemy", + "alembic", + "greenlet", "httpx", "zsim", "plotly",