diff --git a/.env.example b/.env.example index caa630b..a39a3a7 100644 --- a/.env.example +++ b/.env.example @@ -3,12 +3,13 @@ DB_PORT=5432 DB_USER=postgres DB_PASSWORD=CHANGE_ME_STRONG_PASSWORD DB_ADMIN_DB=postgres -DB_NEWS=news_db +NEWS_DB=news_db NEWSAPI_KEY=YOUR_REAL_NEWSAPI_KEY NEWSAPI_URL=https://newsapi.org/v2/everything NEWSAPI_DEFAULT_LANGUAGE=ru NEWSAPI_SORT_BY=publishedAt +NEWS_API_KEY_ENCRYPTION_SECRET=replace_me_with_32+_random_chars REQUEST_TIMEOUT_SECONDS=15 REQUEST_MAX_RETRIES=3 diff --git a/config/config.py b/config/config.py index c423705..5459021 100644 --- a/config/config.py +++ b/config/config.py @@ -95,12 +95,13 @@ class Settings: db_user: str db_password: str db_admin_db: str - db_news: str + news_db: str newsapi_key: str news_url: str default_language: str sort_by: str + news_api_key_encryption_secret: str request_timeout_seconds: float request_max_retries: int @@ -145,11 +146,12 @@ def build_settings() -> Settings: db_user=_get_env_str("DB_USER", "postgres"), db_password=_get_env_str("DB_PASSWORD", "postgres"), db_admin_db=_get_env_str("DB_ADMIN_DB", "postgres"), - db_news=_get_env_str("DB_NEWS", "news_db"), + news_db=_get_env_str("NEWS_DB", "news_db"), newsapi_key=_get_env_str("NEWSAPI_KEY", ""), news_url=_get_env_str("NEWSAPI_URL", "https://newsapi.org/v2/everything"), default_language=_get_env_str("NEWSAPI_DEFAULT_LANGUAGE", "ru"), sort_by=sort_by, + news_api_key_encryption_secret=_get_env_str("NEWS_API_KEY_ENCRYPTION_SECRET", ""), request_timeout_seconds=_get_env_float("REQUEST_TIMEOUT_SECONDS", 15.0, min_value=1.0), request_max_retries=_get_env_int("REQUEST_MAX_RETRIES", 3, min_value=0), request_backoff_factor=_get_env_float("REQUEST_BACKOFF_FACTOR", 1.0, min_value=0.0), diff --git a/main.py b/main.py index 30500cf..ba87c75 100644 --- a/main.py +++ b/main.py @@ -13,6 +13,7 @@ create_request_stats_table, create_search_requests_table, create_user_news_table, + create_users_keys_table, ensure_tables_exist, init_database, 
run_debug_pipeline, @@ -99,6 +100,7 @@ def init_all_tables(debug: bool) -> None: create_articles_table() create_user_news_table() create_request_stats_table() + create_users_keys_table() if debug: create_news_tables() @@ -123,9 +125,9 @@ def _validate_web_context(user_id: int, search_request_id: int) -> None: def _ensure_runtime_schema(debug_mode: bool) -> None: - if not database_exists(settings.db_news): + if not database_exists(settings.news_db): raise RuntimeError( - f"Database '{settings.db_news}' does not exist. " + f"Database '{settings.news_db}' does not exist. " "Run `python main.py --init-only` once or re-run with `--bootstrap`." ) @@ -134,11 +136,12 @@ def _ensure_runtime_schema(debug_mode: bool) -> None: "search_requests", "articles", "user_news", + "users_keys", "request_stats", ] - ensure_tables_exist(settings.db_news, required_tables) + ensure_tables_exist(settings.news_db, required_tables) - if debug_mode and not table_exists(settings.db_news, "bad_news_bears"): + if debug_mode and not table_exists(settings.news_db, "bad_news_bears"): create_news_tables() diff --git a/requirements.txt b/requirements.txt index 5711f57..9011296 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,3 +1,4 @@ python-dotenv==1.2.2 psycopg2-binary>=2.9,<3.0 requests==2.32.5 +cryptography>=42.0,<46.0 \ No newline at end of file diff --git a/src/__init__.py b/src/__init__.py index aa41ac2..4e88663 100644 --- a/src/__init__.py +++ b/src/__init__.py @@ -8,6 +8,7 @@ create_request_stats_table, create_search_requests_table, create_user_news_table, + create_users_keys_table, database_exists, ensure_databases_exists, ensure_tables_exist, @@ -34,6 +35,7 @@ "create_request_stats_table", "create_search_requests_table", "create_user_news_table", + "create_users_keys_table", "database_exists", "ensure_databases_exists", "ensure_tables_exist", diff --git a/src/db.py b/src/db.py index cfe3789..3e921a0 100644 --- a/src/db.py +++ b/src/db.py @@ -161,7 +161,7 @@ def 
create_database_if_not_exists(db_name: str) -> None: def init_database() -> None: - create_database_if_not_exists(settings.db_news) + create_database_if_not_exists(settings.news_db) def create_search_requests_table() -> None: @@ -192,7 +192,7 @@ def create_search_requests_table() -> None: """, ] - with get_cursor(settings.db_news) as (conn, cur): + with get_cursor(settings.news_db) as (conn, cur): cur.execute(query) for index_query in index_list: cur.execute(index_query) @@ -218,7 +218,7 @@ def create_articles_table() -> None: ON articles(published_at DESC) """ - with get_cursor(settings.db_news) as (conn, cur): + with get_cursor(settings.news_db) as (conn, cur): cur.execute(query) cur.execute(index) conn.commit() @@ -252,7 +252,7 @@ def create_user_news_table() -> None: """, ] - with get_cursor(settings.db_news) as (conn, cur): + with get_cursor(settings.news_db) as (conn, cur): cur.execute(query) for index_query in index_list: cur.execute(index_query) @@ -298,7 +298,7 @@ def create_request_stats_table() -> None: END $$; """ - with get_cursor(settings.db_news) as (conn, cur): + with get_cursor(settings.news_db) as (conn, cur): cur.execute(query) cur.execute(trigger_function) cur.execute(trigger) @@ -318,7 +318,7 @@ def create_app_users_table() -> None: ) """ - with get_cursor(settings.db_news) as (conn, cur): + with get_cursor(settings.news_db) as (conn, cur): cur.execute(query) conn.commit() @@ -338,10 +338,53 @@ def create_news_tables() -> None: ) """ - with get_cursor(settings.db_news) as (conn, cur): + with get_cursor(settings.news_db) as (conn, cur): cur.execute(query) conn.commit() +def create_users_keys_table() -> None: + query = """ + CREATE TABLE IF NOT EXISTS users_keys( + id BIGSERIAL PRIMARY KEY, + user_id BIGINT NOT NULL REFERENCES app_users(id) ON DELETE CASCADE, + service VARCHAR(50) NOT NULL, + encrypted_key TEXT NOT NULL, + iv TEXT NOT NULL, + auth_tag TEXT NOT NULL, + key_last4 VARCHAR(4) NOT NULL, + uploaded_at TIMESTAMPTZ NOT NULL DEFAULT 
NOW(), + updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW(), + UNIQUE (user_id, service) + ); + """ + trigger_function = """ + CREATE OR REPLACE FUNCTION set_users_keys_updated_at() + RETURNS TRIGGER AS $$ + BEGIN + NEW.updated_at = NOW(); + RETURN NEW; + END; + $$ LANGUAGE plpgsql; + """ + trigger = """ + DO $$ + BEGIN + IF NOT EXISTS ( + SELECT 1 FROM pg_trigger WHERE tgname = 'trg_users_keys_updated_at' + ) THEN + CREATE TRIGGER trg_users_keys_updated_at + BEFORE UPDATE ON users_keys + FOR EACH ROW + EXECUTE FUNCTION set_users_keys_updated_at(); + END IF; + END $$; + """ + with get_cursor(settings.news_db) as (conn, cur): + cur.execute(query) + cur.execute(trigger_function) + cur.execute(trigger) + conn.commit() + def claim_next_search_request() -> dict | None: query = """ @@ -363,7 +406,7 @@ def claim_next_search_request() -> dict | None: RETURNING sr.id, sr.user_id, sr.keyword, sr.language, sr.limit_count, sr.page_size """ - with get_cursor(settings.db_news) as (conn, cur): + with get_cursor(settings.news_db) as (conn, cur): cur.execute(query) row = cur.fetchone() conn.commit() @@ -372,20 +415,20 @@ def claim_next_search_request() -> dict | None: def search_request_exists(search_request_id: int) -> bool: query = "SELECT 1 FROM search_requests WHERE id = %s" - with get_cursor(settings.db_news, autocommit=True) as (_, cur): + with get_cursor(settings.news_db, autocommit=True) as (_, cur): cur.execute(query, (search_request_id,)) return cur.fetchone() is not None def app_user_exists(user_id: int) -> bool: query = "SELECT 1 FROM app_users WHERE id = %s" - with get_cursor(settings.db_news, autocommit=True) as (_, cur): + with get_cursor(settings.news_db, autocommit=True) as (_, cur): cur.execute(query, (user_id,)) return cur.fetchone() is not None def search_request_belongs_to_user(search_request_id: int, user_id: int) -> bool: query = "SELECT 1 FROM search_requests WHERE id = %s AND user_id = %s" - with get_cursor(settings.db_news, autocommit=True) as (_, cur): + with 
get_cursor(settings.news_db, autocommit=True) as (_, cur): cur.execute(query, (search_request_id, user_id)) return cur.fetchone() is not None diff --git a/src/extract.py b/src/extract.py index 0389fe2..30cd3b8 100644 --- a/src/extract.py +++ b/src/extract.py @@ -84,6 +84,7 @@ def _fetch_payload( page: int, page_size: int, language: str, + news_api_key: str | None = None, ) -> tuple[dict[str, Any], int]: if page <= 0: raise ValueError("page must be > 0") @@ -93,7 +94,7 @@ def _fetch_payload( normalized_page_size = min(page_size, settings.request_page_size_max) params = { - "apiKey": _require_newsapi_key(), + "apiKey": news_api_key or _require_newsapi_key(), "language": language, "q": key_word, "pageSize": normalized_page_size, @@ -189,8 +190,9 @@ def make_extract_web( page: int = 1, page_size: int = 20, language: str = "ru", + news_api_key: str | None = None, ) -> tuple[dict[str, Any], int]: - payload, articles_count = _fetch_payload(key_word, page, page_size, language) + payload, articles_count = _fetch_payload(key_word, page, page_size, language, news_api_key) if articles_count == 0: logger.info("There are no more articles for keyword=%s", key_word) return payload, articles_count diff --git a/src/load.py b/src/load.py index a2f52b8..bd0a0e8 100644 --- a/src/load.py +++ b/src/load.py @@ -42,7 +42,7 @@ def load_news(clean_news: str, max_rows: int | None = None) -> int: rows = data if max_rows is None else data[:max_rows] loaded_count = 0 - with get_cursor(settings.db_news) as (conn, cur): + with get_cursor(settings.news_db) as (conn, cur): for news in rows: cur.execute( query, @@ -150,7 +150,7 @@ def load_request_stats(search_request_id: int, stats: dict[str, Any]) -> None: prime_reasons = EXCLUDED.prime_reasons """ - with get_cursor(settings.db_news) as (conn, cur): + with get_cursor(settings.news_db) as (conn, cur): cur.execute( query, ( @@ -182,7 +182,7 @@ def load_web_pipeline( rows = clean_data loaded_count = 0 - with get_cursor(settings.db_news) as (conn, 
cur): + with get_cursor(settings.news_db) as (conn, cur): for article in rows: keyword = article["key_word"] fetched_at = article["fetched_at"] diff --git a/src/pipeline.py b/src/pipeline.py index 846ef7a..58434d8 100644 --- a/src/pipeline.py +++ b/src/pipeline.py @@ -44,6 +44,7 @@ def run_pipeline_for_web_user( limit: int, page_size: int, language: str | None = None, + news_api_key: str | None = None, ) -> int: if limit <= 0: raise ValueError("limit must be > 0") @@ -60,7 +61,7 @@ def run_pipeline_for_web_user( try: while loaded_total < limit and page <= max_pages: - payload, raw_articles_count = make_extract_web(key_word, page, effective_page_size, language_to_use) + payload, raw_articles_count = make_extract_web(key_word, page, effective_page_size, language_to_use, news_api_key=news_api_key) if page == 1: max_pages = _resolve_max_pages(payload, effective_page_size) diff --git a/src/user_news_api_key.py b/src/user_news_api_key.py new file mode 100644 index 0000000..a6b206d --- /dev/null +++ b/src/user_news_api_key.py @@ -0,0 +1,47 @@ +from __future__ import annotations +import base64 +import hashlib +from dataclasses import dataclass +from cryptography.hazmat.primitives.ciphers.aead import AESGCM +from config.config import settings +from .db import get_cursor + +@dataclass(frozen=True, slots=True) +class EncryptedNewsApiKey: + encrypted_key: str + iv: str + auth_tag: str + +def _get_encryption_key() -> bytes: + secret = settings.news_api_key_encryption_secret.strip() + if not secret: + raise RuntimeError("NEWS_API_KEY_ENCRYPTION_SECRET is not in config") + return hashlib.sha256(secret.encode("utf-8")).digest() + +def decrypt_news_api_key(row: EncryptedNewsApiKey) -> str: + aesgcm = AESGCM(_get_encryption_key()) + ciphertext = base64.b64decode(row.encrypted_key) + iv = base64.b64decode(row.iv) + auth_tag = base64.b64decode(row.auth_tag) + plaintext = aesgcm.decrypt(iv, ciphertext + auth_tag, None) + return plaintext.decode("utf-8") + +def 
get_decrypted_news_api_key_for_user(user_id: int) -> str | None: + query = """ + SELECT encrypted_key, iv, auth_tag + FROM users_keys + WHERE user_id = %s AND service = %s + LIMIT 1 + """ + with get_cursor(settings.news_db, autocommit=True) as (_, cur): + cur.execute(query, (user_id, "news_api")) + row = cur.fetchone() + if not row: + return None + + return decrypt_news_api_key( + EncryptedNewsApiKey( + encrypted_key=row["encrypted_key"], + iv=row["iv"], + auth_tag=row["auth_tag"],) + ) \ No newline at end of file diff --git a/src/worker.py b/src/worker.py index a750171..0921b67 100644 --- a/src/worker.py +++ b/src/worker.py @@ -7,6 +7,7 @@ from .db import claim_next_search_request, get_cursor from .pipeline import run_pipeline_for_web_user +from .user_news_api_key import get_decrypted_news_api_key_for_user logging.basicConfig( level=logging.INFO, @@ -25,7 +26,7 @@ def mark_as_success(search_request_id: int) -> None: WHERE id = %s """ - with get_cursor(settings.db_news) as (conn, cur): + with get_cursor(settings.news_db) as (conn, cur): cur.execute(query, (search_request_id,)) if cur.rowcount != 1: logger.warning("Search request %s was not marked as success", search_request_id) @@ -42,7 +43,7 @@ def mark_as_error(search_request_id: int, error_text: str) -> None: WHERE id = %s """ - with get_cursor(settings.db_news) as (conn, cur): + with get_cursor(settings.news_db) as (conn, cur): cur.execute(query, (error_text[:2000], search_request_id)) if cur.rowcount != 1: logger.warning("Search request %s was not marked as failed", search_request_id) @@ -69,6 +70,10 @@ def one_request() -> bool: ) try: + news_api_key = get_decrypted_news_api_key_for_user(user_id) + if not news_api_key: + raise RuntimeError("User has no NEWSAPI key yet.") + amount_of_articles = run_pipeline_for_web_user( user_id=user_id, search_request_id=search_request_id, @@ -76,6 +81,7 @@ def one_request() -> bool: limit=limit_count, page_size=page_size, language=language, + news_api_key=news_api_key, 
) mark_as_success(search_request_id)