diff --git a/ADVANCED.md b/ADVANCED.md index 194a6f1..774846c 100644 --- a/ADVANCED.md +++ b/ADVANCED.md @@ -150,7 +150,7 @@ if response.raw_data: - Raw data is stored in the `raw_data` attribute of response objects - This increases memory usage as both parsed and raw data are kept - Only enable when needed for debugging - disable for better performance -- All response models support `raw_data` when this flag is enabled +- Response models support `raw_data` when this flag is enabled (except StatusCodeSearchResponse) **Example: Debugging a missing field** @@ -180,6 +180,7 @@ All warnings inherit from `USPTODataWarning`: - `USPTOBooleanParseWarning`: Y/N boolean string parsing failures - `USPTOTimezoneWarning`: Timezone-related issues - `USPTOEnumParseWarning`: Enum value parsing failures +- `USPTODataMismatchWarning`: API returns data with different identifier than requested **Controlling Warnings** @@ -190,7 +191,8 @@ from pyUSPTO.warnings import ( USPTODateParseWarning, USPTOBooleanParseWarning, USPTOTimezoneWarning, - USPTOEnumParseWarning + USPTOEnumParseWarning, + USPTODataMismatchWarning ) # Suppress all pyUSPTO data warnings diff --git a/README.md b/README.md index e749b37..2199223 100644 --- a/README.md +++ b/README.md @@ -24,10 +24,11 @@ pip install pyUSPTO > You must have an API key for the [USPTO Open Data Portal API](https://data.uspto.gov/myodp/landing). 
```python -from pyUSPTO import PatentDataClient +from pyUSPTO import PatentDataClient, USPTOConfig -# Initialize with your API key -client = PatentDataClient(api_key="your_api_key_here") +# Initialize with config +config = USPTOConfig(api_key="your_api_key_here") +client = PatentDataClient(config=config) # Search for patent applications results = client.search_applications(inventor_name_q="Smith", limit=10) @@ -36,32 +37,9 @@ print(f"Found {results.count} applications") ## Configuration -All clients can be configured using one of three methods: - -### Method 1: Direct API Key Initialization - -> [!NOTE] -> This method is convenient for quick scripts but not recommended. Consider using environment variables instead. - -```python -from pyUSPTO import ( - BulkDataClient, - PatentDataClient, - FinalPetitionDecisionsClient, - PTABTrialsClient, - PTABAppealsClient, - PTABInterferencesClient -) - -patent_client = PatentDataClient(api_key="your_api_key_here") -bulk_client = BulkDataClient(api_key="your_api_key_here") -petition_client = FinalPetitionDecisionsClient(api_key="your_api_key_here") -trials_client = PTABTrialsClient(api_key="your_api_key_here") -appeals_client = PTABAppealsClient(api_key="your_api_key_here") -interferences_client = PTABInterferencesClient(api_key="your_api_key_here") -``` +All clients require a `USPTOConfig` object for configuration. 
There are two methods: -### Method 2: Using USPTOConfig +### Method 1: Using USPTOConfig ```python from pyUSPTO import ( @@ -85,7 +63,7 @@ appeals_client = PTABAppealsClient(config=config) interferences_client = PTABInterferencesClient(config=config) ``` -### Method 3: Environment Variables (Recommended) +### Method 2: Environment Variables (Recommended) Set the environment variable in your shell: @@ -127,7 +105,8 @@ interferences_client = PTABInterferencesClient(config=config) ```python -from pyUSPTO import PatentDataClient +from pyUSPTO import PatentDataClient, USPTOConfig -client = PatentDataClient(api_key="your_api_key_here") +config = USPTOConfig(api_key="your_api_key_here") +client = PatentDataClient(config=config) # Search for applications by inventor name response = client.search_applications(inventor_name_q="Smith", limit=2) @@ -146,7 +125,8 @@ See [`examples/patent_data_example.py`](examples/patent_data_example.py) for det ```python -from pyUSPTO import FinalPetitionDecisionsClient +from pyUSPTO import FinalPetitionDecisionsClient, USPTOConfig -client = FinalPetitionDecisionsClient(api_key="your_api_key_here") +config = USPTOConfig(api_key="your_api_key_here") +client = FinalPetitionDecisionsClient(config=config) # Search for petition decisions response = client.search_decisions( @@ -169,7 +149,8 @@ See [`examples/petition_decisions_example.py`](examples/petition_decisions_examp ```python -from pyUSPTO import PTABTrialsClient +from pyUSPTO import PTABTrialsClient, USPTOConfig -client = PTABTrialsClient(api_key="your_api_key_here") +config = USPTOConfig(api_key="your_api_key_here") +client = PTABTrialsClient(config=config) # Search for IPR proceedings response = client.search_proceedings( @@ -196,7 +177,8 @@ See [`examples/ptab_trials_example.py`](examples/ptab_trials_example.py) for det ```python -from pyUSPTO import PTABAppealsClient +from pyUSPTO import PTABAppealsClient, USPTOConfig -client = PTABAppealsClient(api_key="your_api_key_here") +config = USPTOConfig(api_key="your_api_key_here") +client = PTABAppealsClient(config=config) # Search for appeal decisions response = client.search_decisions( @@ -215,7 +197,8 @@ See
[`examples/ptab_appeals_example.py`](examples/ptab_appeals_example.py) for d ```python -from pyUSPTO import PTABInterferencesClient +from pyUSPTO import PTABInterferencesClient, USPTOConfig -client = PTABInterferencesClient(api_key="your_api_key_here") +config = USPTOConfig(api_key="your_api_key_here") +client = PTABInterferencesClient(config=config) # Search for interference decisions response = client.search_decisions( @@ -247,11 +230,11 @@ The library uses Python dataclasses to represent API responses. All data models - `PatentDataResponse`: Top-level response from the API - `PatentFileWrapper`: Information about a patent application - `ApplicationMetaData`: Metadata about a patent application -- `Address`: Represents an address in the patent data - `Person`, `Applicant`, `Inventor`, `Attorney`: Person-related data classes - `Assignment`, `Assignor`, `Assignee`: Assignment-related data classes - `Continuity`, `ParentContinuity`, `ChildContinuity`: Continuity-related data classes - `PatentTermAdjustmentData`: Patent term adjustment information +- `DocumentBag`, `EntityStatus`, `RecordAttorney`: Additional data classes for patent data - And many more specialized classes for different aspects of patent data #### Final Petition Decisions API @@ -259,7 +242,6 @@ The library uses Python dataclasses to represent API responses. All data models - `PetitionDecisionResponse`: Top-level response from the API - `PetitionDecision`: Complete information about a petition decision - `PetitionDecisionDocument`: Document associated with a petition decision -- `DocumentDownloadOption`: Download options for petition documents - `DecisionTypeCode`: Enum for petition decision types - `DocumentDirectionCategory`: Enum for document direction categories @@ -267,26 +249,31 @@ The library uses Python dataclasses to represent API responses.
All data models - `PTABTrialProceedingResponse`: Top-level response from the API - `PTABTrialProceeding`: Information about a PTAB trial proceeding (IPR, PGR, CBM, DER) +- `PTABTrialDocumentResponse`: Response containing trial documents - `PTABTrialDocument`: Document associated with a trial proceeding -- `PTABTrialDecision`: Decision information for a trial proceeding +- `TrialDecisionData`: Decision information for a trial proceeding +- `TrialDocumentData`: Document metadata for trial documents +- `TrialMetaData`: Trial metadata and status information - `RegularPetitionerData`, `RespondentData`, `DerivationPetitionerData`: Party data for different trial types -- `PTABTrialMetaData`: Trial metadata and status information #### PTAB Appeals API - `PTABAppealResponse`: Top-level response from the API - `PTABAppealDecision`: Ex parte appeal decision information - `AppellantData`: Appellant information and application details -- `PTABAppealMetaData`: Appeal metadata and filing information -- `PTABAppealDocumentData`: Document and decision details +- `AppealMetaData`: Appeal metadata and filing information +- `AppealDocumentData`: Document and decision details #### PTAB Interferences API - `PTABInterferenceResponse`: Top-level response from the API - `PTABInterferenceDecision`: Interference proceeding decision information - `SeniorPartyData`, `JuniorPartyData`, `AdditionalPartyData`: Party data classes -- `PTABInterferenceMetaData`: Interference metadata and status information -- `PTABInterferenceDocumentData`: Document and outcome details +- `InterferenceMetaData`: Interference metadata and status information +- `InterferenceDocumentData`: Document and outcome details +- `DecisionData`: Decision information for interference proceedings + +For a complete list of all data models, see the [API Reference documentation](https://pyuspto.readthedocs.io/en/latest/api/models/index.html).
## Advanced Topics diff --git a/examples/bulk_data_example.py b/examples/bulk_data_example.py index f798f5a..ca05569 100644 --- a/examples/bulk_data_example.py +++ b/examples/bulk_data_example.py @@ -1,4 +1,4 @@ -"""Example usage of the BulkDataClient. +"""Example usage of pyUSPTO for the BulkDataClient. This example demonstrates how to use the BulkDataClient to interact with the USPTO Bulk Data API. It shows how to search for products, retrieve product details, and download files. @@ -37,18 +37,13 @@ def format_size(size_bytes: int | float) -> str: # Client Initialization Methods # ============================================================================ -# Method 1: Initialize with API key directly -print("Method 1: Initialize with direct API key") -api_key = "YOUR_API_KEY_HERE" # Replace with your actual API key -client = BulkDataClient(api_key=api_key) - -# Method 2: Initialize with USPTOConfig object -print("\nMethod 2: Initialize with USPTOConfig") +# Method 1: Initialize with USPTOConfig object +print("\nMethod 1: Initialize with USPTOConfig") config = USPTOConfig(api_key="YOUR_API_KEY_HERE") client = BulkDataClient(config=config) -# Method 3: Initialize from environment variables (recommended) -print("\nMethod 3: Initialize from environment variables") +# Method 2: Initialize from environment variables (recommended) +print("\nMethod 2: Initialize from environment variables") os.environ["USPTO_API_KEY"] = "YOUR_API_KEY_HERE" # Set this outside your script config_from_env = USPTOConfig.from_env() client = BulkDataClient(config=config_from_env) diff --git a/examples/error_handling_example.py b/examples/error_handling_example.py index 3188ef1..2b75aff 100644 --- a/examples/error_handling_example.py +++ b/examples/error_handling_example.py @@ -19,13 +19,14 @@ # Initialize client api_key = os.environ.get("USPTO_API_KEY", "YOUR_API_KEY_HERE") -client = PatentDataClient(api_key=api_key) +config = USPTOConfig(api_key=api_key) +client = 
PatentDataClient(config=config) # Example 1: Handle authentication errors print("Example 1: Authentication errors") try: # This will fail with invalid API key - bad_client = PatentDataClient(api_key="invalid_key") + bad_client = PatentDataClient(config=USPTOConfig(api_key="invalid_key")) bad_client.search_applications(limit=1) except USPTOApiAuthError as e: print(f"Authentication failed: {e}") diff --git a/examples/ifw_example.py b/examples/ifw_example.py index 0f64a8b..9e185e8 100644 --- a/examples/ifw_example.py +++ b/examples/ifw_example.py @@ -8,14 +8,15 @@ import os from pyUSPTO.clients.patent_data import PatentDataClient +from pyUSPTO.config import USPTOConfig api_key = os.environ.get("USPTO_API_KEY", "YOUR_API_KEY_HERE") if api_key == "YOUR_API_KEY_HERE": raise ValueError( "WARNING: API key is not set. Please replace 'YOUR_API_KEY_HERE' or set USPTO_API_KEY environment variable." ) - -client = PatentDataClient(api_key=api_key) +config = USPTOConfig(api_key=api_key) +client = PatentDataClient(config=config) print("\nBeginning API requests with configured client:") diff --git a/examples/patent_data_example.py b/examples/patent_data_example.py index 1ca30be..48e0998 100644 --- a/examples/patent_data_example.py +++ b/examples/patent_data_example.py @@ -1,4 +1,4 @@ -"""Example usage of the uspto_api module for patent data. +"""Example usage of the pyUSPTO module for patent data. This example demonstrates how to use the PatentDataClient to interact with the USPTO Patent Data API. It shows how to retrieve patent applications, search for patents by various criteria, and access @@ -9,17 +9,19 @@ import os from pyUSPTO.clients.patent_data import PatentDataClient +from pyUSPTO.config import USPTOConfig from pyUSPTO.models.patent_data import ApplicationContinuityData # --- Initialization --- # Initialize the client with API key from ENV Var.
-print("Initialize with direct API key") +print("Initialize with config") api_key = os.environ.get("USPTO_API_KEY", "YOUR_API_KEY_HERE") if api_key == "YOUR_API_KEY_HERE": raise ValueError( "WARNING: API key is not set. Please replace 'YOUR_API_KEY_HERE' or set USPTO_API_KEY environment variable." ) -client = PatentDataClient(api_key=api_key) +config = USPTOConfig(api_key=api_key) +client = PatentDataClient(config=config) DEST_PATH = "./download-example" diff --git a/examples/petition_decisions_example.py b/examples/petition_decisions_example.py index 62368b2..f22ffdb 100644 --- a/examples/petition_decisions_example.py +++ b/examples/petition_decisions_example.py @@ -10,6 +10,7 @@ import os from pyUSPTO.clients import FinalPetitionDecisionsClient +from pyUSPTO.config import USPTOConfig from pyUSPTO.models.petition_decisions import PetitionDecisionDownloadResponse # --- Initialization --- @@ -20,7 +21,8 @@ raise ValueError( "WARNING: API key is not set. Please replace 'YOUR_API_KEY_HERE' or set USPTO_API_KEY environment variable." ) -client = FinalPetitionDecisionsClient(api_key=api_key) +config = USPTOConfig(api_key=api_key) +client = FinalPetitionDecisionsClient(config=config) DEST_PATH = "./download-example" diff --git a/examples/ptab_appeals_example.py b/examples/ptab_appeals_example.py index 24c098b..9c2a31f 100644 --- a/examples/ptab_appeals_example.py +++ b/examples/ptab_appeals_example.py @@ -10,6 +10,7 @@ import os from pyUSPTO import PTABAppealsClient +from pyUSPTO.config import USPTOConfig # --- Initialization --- # Initialize the client with direct API key @@ -19,7 +20,8 @@ raise ValueError( "WARNING: API key is not set. Please replace 'YOUR_API_KEY_HERE' or set USPTO_API_KEY environment variable." 
) -client = PTABAppealsClient(api_key=api_key) +config = USPTOConfig(api_key=api_key) +client = PTABAppealsClient(config=config) print("\nBeginning PTAB Appeals API requests with configured client:") diff --git a/examples/ptab_interferences_example.py b/examples/ptab_interferences_example.py index d09dbbd..9f90892 100644 --- a/examples/ptab_interferences_example.py +++ b/examples/ptab_interferences_example.py @@ -11,6 +11,7 @@ import os from pyUSPTO import PTABInterferencesClient +from pyUSPTO.config import USPTOConfig # --- Initialization --- # Initialize the client with direct API key @@ -20,7 +21,8 @@ raise ValueError( "WARNING: API key is not set. Please replace 'YOUR_API_KEY_HERE' or set USPTO_API_KEY environment variable." ) -client = PTABInterferencesClient(api_key=api_key) +config = USPTOConfig(api_key=api_key) +client = PTABInterferencesClient(config=config) print("\nBeginning PTAB Interferences API requests with configured client:") diff --git a/examples/ptab_trials_example.py b/examples/ptab_trials_example.py index 779eeeb..dd21e84 100644 --- a/examples/ptab_trials_example.py +++ b/examples/ptab_trials_example.py @@ -14,6 +14,7 @@ import os from pyUSPTO import PTABTrialsClient +from pyUSPTO.config import USPTOConfig # --- Initialization --- # Initialize the client with direct API key @@ -23,7 +24,9 @@ raise ValueError( "WARNING: API key is not set. Please replace 'YOUR_API_KEY_HERE' or set USPTO_API_KEY environment variable." ) -client = PTABTrialsClient(api_key=api_key) + +config = USPTOConfig(api_key=api_key) +client = PTABTrialsClient(config=config) print("\nBeginning PTAB Trials API requests with configured client:") diff --git a/pyproject.toml b/pyproject.toml index 19ad63a..6b86858 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta" [project] name = "pyUSPTO" -description = "A Modern Python client for accessing United Stated Patent and Trademark Office (USPTO) Open Data Portal (ODP) APIs." 
+description = "A Modern Python client for accessing the United States Patent and Trademark Office (USPTO) Open Data Portal (ODP) APIs." authors = [ { name = "Andrew Piechocki", email = "apiechocki@dunlapcodding.com" }, ] @@ -46,7 +46,7 @@ keywords = ["uspto", "patent", "odp", "client", "bulk data", "patent data"] dependencies = [ "requests>=2.32.5", "typing-extensions>=4.15.0; python_version < '3.11'", - "tzdata>=2025.2", + "tzdata>=2025.3", ] dynamic = ["version"] @@ -67,7 +67,7 @@ dev = [ "mypy>=1.19.0", "types-requests>=2.32.4", # Code quality and formatting - "ruff>=0.8.0", + "ruff>=0.15.0", ] [project.urls] diff --git a/requirements-dev.txt b/requirements-dev.txt index 109a2a9..ba124ff 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -37,7 +37,7 @@ jinja2==3.1.6 # via # myst-parser # sphinx -librt==0.8.0 +librt==0.8.1 # via mypy markdown-it-py==3.0.0 # via @@ -95,7 +95,7 @@ requests==2.32.5 # pyUSPTO (pyproject.toml) # sphinx # sphinx-immaterial -ruff==0.15.1 +ruff==0.15.2 # via pyUSPTO (pyproject.toml) snowballstemmer==3.0.1 # via sphinx diff --git a/src/pyUSPTO/clients/base.py b/src/pyUSPTO/clients/base.py index 670e132..e85ff46 100644 --- a/src/pyUSPTO/clients/base.py +++ b/src/pyUSPTO/clients/base.py @@ -15,8 +15,6 @@ ) import requests -from requests.adapters import HTTPAdapter -from urllib3.util.retry import Retry from pyUSPTO.config import USPTOConfig from pyUSPTO.exceptions import ( @@ -26,7 +24,6 @@ USPTOTimeout, get_api_exception, ) -from pyUSPTO.http_config import ALLOWED_METHODS @runtime_checkable @@ -48,124 +45,61 @@ class BaseUSPTOClient(Generic[T]): def __init__( self, - api_key: str | None = None, base_url: str = "", config: USPTOConfig | None = None, ): """Initialize the BaseUSPTOClient. Args: - api_key: API key for authentication base_url: The base URL of the API - config: Optional USPTOConfig instance. 
When multiple clients share the same - config object, they automatically share an HTTP session for better - performance and connection pooling. + config: USPTOConfig instance containing API key and HTTP settings. + When multiple clients share the same config object, they automatically + share an HTTP session for better performance and connection pooling. + If not provided, creates a default config (requires USPTO_API_KEY + environment variable). """ - # Handle config if provided - if config: - self.config = config - self._api_key = api_key or config.api_key + # Use provided config or create default from environment + if config is None: + self.config = USPTOConfig.from_env() else: - # Backward compatibility: create minimal config - self.config = USPTOConfig(api_key=api_key) - self._api_key = api_key - - self.base_url = base_url.rstrip("/") + self.config = config - # Extract HTTP config for session creation + # Store API key and HTTP config from config + self._api_key = self.config.api_key self.http_config = self.config.http_config - # Use shared session from config if available, otherwise create new one - if self.config._shared_session is not None: - # Reuse existing shared session - self.session = self.config._shared_session - self._owns_session = False - # Still apply API key headers in case this client has a different key - self._apply_session_headers() - else: - # Create new session and store in config for sharing - self.session = self._create_session() - self.config._shared_session = self.session - self._owns_session = True - - def _apply_session_headers(self) -> None: - """Apply API key and custom headers to the session. - - This is separated from _create_session so it can be used when - a session is injected from outside. 
- """ - # Set API key and default headers - if self._api_key: - self.session.headers.update( - {"X-API-KEY": self._api_key, "content-type": "application/json"} - ) + self.base_url = base_url.rstrip("/") - # Apply custom headers from HTTP config - if self.http_config.custom_headers: - self.session.headers.update(self.http_config.custom_headers) + # No session creation here - clients use config's session + # Session is accessed via property: self.session -> self.config.session - def _create_session(self) -> requests.Session: - """Create configured HTTP session from HTTPConfig settings. + @property + def session(self) -> requests.Session: + """Get the HTTP session from config. Returns: - Configured requests.Session instance + Session: The requests Session configured by this client's config. """ - session = requests.Session() - self.session = session - - # Apply headers using shared helper - self._apply_session_headers() - - # Configure retry strategy from HTTP config - retry_strategy = Retry( - total=self.http_config.max_retries, - backoff_factor=self.http_config.backoff_factor, - status_forcelist=( - self.http_config.retry_status_codes - if self.http_config.max_retries > 0 - else [] - ), - allowed_methods=ALLOWED_METHODS, - ) - - # Create adapter with retry and connection pool settings - adapter = HTTPAdapter( - max_retries=retry_strategy, - pool_connections=self.http_config.pool_connections, - pool_maxsize=self.http_config.pool_maxsize, - ) - - session.mount("http://", adapter) - session.mount("https://", adapter) - - return session + return self.config.session def close(self) -> None: - """Close the HTTP session and release connection pool resources. - - This method should be called when you're done using the client to ensure - proper cleanup of connection pools and resources. Alternatively, use the - client as a context manager for automatic cleanup. + """Close the client and release resources. 
- Note: If a session was provided via the `session` parameter during - initialization, this method will NOT close it, as the client does not - own the session lifecycle. Only sessions created by the client are closed. + Note: This method does NOT close the HTTP session, as clients do not + own sessions. To close the session, call close() on the USPTOConfig + object that was used to create this client. Example: - client = PatentDataClient(api_key="...") + config = USPTOConfig(api_key="...") + client = PatentDataClient(config=config) try: # Use client pass finally: - client.close() + config.close() # Close config, not client """ - if hasattr(self, "_owns_session") and self._owns_session: - if hasattr(self, "session") and self.session: - self.session.close() - elif not hasattr(self, "_owns_session"): - # Backward compatibility: if _owns_session not set, close anyway - if hasattr(self, "session") and self.session: - self.session.close() + # Nothing to do - client doesn't own any resources + pass def __enter__(self) -> "BaseUSPTOClient[T]": """Enter context manager, returning the client instance. @@ -174,10 +108,10 @@ def __enter__(self) -> "BaseUSPTOClient[T]": Self for use in with statements Example: - with PatentDataClient(api_key="...") as client: + config = USPTOConfig(api_key="...") + with PatentDataClient(config=config) as client: response = client.search_applications(...) """ - USPTOConfig._active_clients += 1 return self def __exit__( @@ -186,17 +120,13 @@ def __exit__( exc_val: BaseException | None, exc_tb: Any | None, ) -> None: - """Exit context manager, ensuring session cleanup. + """Exit context manager. - Args: - exc_type: Exception type if an exception occurred - exc_val: Exception value if an exception occurred - exc_tb: Exception traceback if an exception occurred + Note: Does not close the session. Use context manager on + USPTOConfig instead if session cleanup is needed. 
""" - USPTOConfig._active_clients -= 1 - if USPTOConfig._active_clients == 0: - USPTOConfig._shared_session = None - self.close() + # Nothing to close - client doesn't own session + pass def _parse_json_response( self, response: requests.Response, url: str @@ -651,24 +581,51 @@ def _save_response_to_file( return str(final_path) + def _is_safe_path(self, base_dir: Path, target_path: Path) -> bool: + """Check if target_path is within base_dir (prevents path traversal). + + Args: + base_dir: The intended base directory + target_path: The path to validate + + Returns: + True if target_path is safely within base_dir, False otherwise + """ + # Resolve both paths to absolute paths + base_resolved = base_dir.resolve() + target_resolved = target_path.resolve() + + # Check if base_dir is in the target's parent chain + return ( + base_resolved in target_resolved.parents or base_resolved == target_resolved + ) + def _extract_archive( self, archive_path: Path, extract_to: Path | None = None, remove_archive: bool = False, + max_size: int | None = None, ) -> str: - """Extract TAR or ZIP archive. + """Extract TAR or ZIP archive with security protections. + + Protects against path traversal attacks. Optional protection against + zip bombs via max_size parameter. Args: archive_path: Path to archive file extract_to: Directory to extract to (default: archive_path.stem) remove_archive: Delete archive after extraction + max_size: Optional maximum total extracted size in bytes. If None (default), + no size limit is enforced. Set this to protect against zip bombs. 
Returns: Path to extracted content (single file: path to file, multiple files: directory path) Raises: - ValueError: If file is not a valid TAR or ZIP archive + ValueError: If file is not a valid TAR or ZIP archive, or if archive + contains paths that would extract outside the target directory, + or if max_size is set and extraction would exceed it """ import tarfile import zipfile @@ -679,14 +636,64 @@ def _extract_archive( extract_to.mkdir(parents=True, exist_ok=True) extracted_items = [] + total_size = 0 + if tarfile.is_tarfile(archive_path): with tarfile.open(archive_path, "r:*") as tar: - tar.extractall(path=extract_to) - extracted_items = [m.name for m in tar.getmembers() if m.isfile()] + # Extract members one by one with validation + for member in tar.getmembers(): + # Skip directories + if member.isdir(): + continue + + # Path traversal check + member_path = extract_to / member.name + if not self._is_safe_path(extract_to, member_path): + raise ValueError( + f"Archive contains unsafe path that would extract outside target directory: {member.name}" + ) + + # Optional zip bomb check + if max_size is not None: + total_size += member.size + if total_size > max_size: + raise ValueError( + f"Archive extraction aborted: total size ({total_size} bytes) " + f"exceeds maximum allowed ({max_size} bytes)" + ) + + # Extract individual member + tar.extract(member, path=extract_to) + extracted_items.append(member.name) + elif zipfile.is_zipfile(archive_path): with zipfile.ZipFile(archive_path, "r") as zip_ref: - zip_ref.extractall(extract_to) - extracted_items = [n for n in zip_ref.namelist() if not n.endswith("/")] + # Extract members one by one with validation + for zip_info in zip_ref.infolist(): + # Skip directories + if zip_info.is_dir(): + continue + + # Path traversal check + member_path = extract_to / zip_info.filename + if not self._is_safe_path(extract_to, member_path): + raise ValueError( + f"Archive contains unsafe path that would extract outside target 
directory: {zip_info.filename}" + ) + + # Optional zip bomb check + if max_size is not None: + total_size += zip_info.file_size + if total_size > max_size: + raise ValueError( + f"Archive extraction aborted: total size ({total_size} bytes) " + f"exceeds maximum allowed ({max_size} bytes)" + ) + + # Extract individual member + zip_ref.extract(zip_info, path=extract_to) + extracted_items.append(zip_info.filename) + else: raise ValueError(f"Not a valid TAR/ZIP archive: {archive_path}") diff --git a/src/pyUSPTO/clients/bulk_data.py b/src/pyUSPTO/clients/bulk_data.py index 8823700..bb6bc52 100644 --- a/src/pyUSPTO/clients/bulk_data.py +++ b/src/pyUSPTO/clients/bulk_data.py @@ -28,27 +28,31 @@ class BulkDataClient(BaseUSPTOClient[BulkDataResponse]): def __init__( self, - api_key: str | None = None, - base_url: str | None = None, config: USPTOConfig | None = None, + base_url: str | None = None, ): """Initialize the BulkDataClient. Args: - api_key: Optional API key for authentication - base_url: The base URL of the API, defaults to config.bulk_data_base_url or "https://api.uspto.gov/api/v1/datasets" - config: Optional USPTOConfig instance + config: USPTOConfig instance containing API key and settings. If not provided, + creates config from environment variables (requires USPTO_API_KEY). + base_url: Optional base URL override for the USPTO Bulk Data API. + If not provided, uses config.bulk_data_base_url or default. 
""" - # Use config if provided, otherwise create default config - self.config = config or USPTOConfig(api_key=api_key) - - # Use provided API key or get from config - api_key = api_key or self.config.api_key + # Use provided config or create from environment + if config is None: + self.config = USPTOConfig.from_env() + else: + self.config = config - # Use provided base_url or get from config - base_url = base_url or self.config.bulk_data_base_url + # Determine effective base URL + effective_base_url = base_url or self.config.bulk_data_base_url - super().__init__(api_key=api_key, base_url=base_url, config=self.config) + # Initialize base client + super().__init__( + base_url=effective_base_url, + config=self.config, + ) def get_product_by_id( self, @@ -134,11 +138,11 @@ def download_file( destination: str | None = None, file_name: str | None = None, overwrite: bool = False, - extract: bool = True, + extract: bool = False, ) -> str: """Download a file from the bulk data API. - Automatically extracts archives (tar.gz, zip) by default. The download + Does not extract archives (tar.gz, zip) by default. The download uses base class helpers for consistent behavior across all clients. Args: @@ -146,7 +150,7 @@ def download_file( destination: Directory to save/extract to. Defaults to current directory. file_name: Override filename. Defaults to file_data.file_name. overwrite: Whether to overwrite existing files. Defaults to False. - extract: Whether to auto-extract archives. Defaults to True. + extract: Whether to auto-extract archives. Defaults to False. Returns: str: Path to downloaded file or extracted directory. 
@@ -156,22 +160,22 @@ def download_file( Examples: Download and extract a file: >>> product = client.get_product_by_id("product-123", include_files=True) >>> file_data = product.product_file_bag.file_data_bag[0] - >>> path = client.download_file(file_data, destination="./downloads") + >>> path = client.download_file(file_data, destination="./downloads", extract=True) Download without extraction: >>> path = client.download_file(file_data, extract=False) """ - # Resolve filename default_file_name = file_name or file_data.file_name - - # Construct URL from endpoint - endpoint = self.ENDPOINTS["download_file"].format( - productIdentifier=file_data.product_identifier, - fileName=default_file_name, - ) - download_url = f"{self.base_url}/{endpoint}" + if file_data.file_download_uri: + download_url = file_data.file_download_uri + else: + endpoint = self.ENDPOINTS["download_file"].format( + productIdentifier=file_data.product_identifier, + fileName=default_file_name, + ) + download_url = f"{self.base_url}/{endpoint}" # Delegate to base class helpers if extract: @@ -189,16 +193,11 @@ def download_file( overwrite=overwrite, ) - def paginate_products( - self, post_body: dict[str, Any] | None = None, **kwargs: Any - ) -> Iterator[BulkDataProduct]: + def paginate_products(self, **kwargs: Any) -> Iterator[BulkDataProduct]: """Paginate through all products matching the search criteria. - Supports both GET and POST requests. 
- Args: - post_body: Optional POST body for complex search queries - **kwargs: Keyword arguments for GET-based pagination + **kwargs: Keyword arguments passed to search_products Yields: BulkDataProduct objects @@ -206,7 +205,6 @@ def paginate_products( return self.paginate_results( method_name="search_products", response_container_attr="bulk_data_product_bag", - post_body=post_body, **kwargs, ) diff --git a/src/pyUSPTO/clients/patent_data.py b/src/pyUSPTO/clients/patent_data.py index b88879f..93125e9 100644 --- a/src/pyUSPTO/clients/patent_data.py +++ b/src/pyUSPTO/clients/patent_data.py @@ -53,24 +53,30 @@ class PatentDataClient(BaseUSPTOClient[PatentDataResponse]): def __init__( self, - api_key: str | None = None, - base_url: str | None = None, config: USPTOConfig | None = None, + base_url: str | None = None, ): """Initialize the PatentDataClient. Args: - api_key: USPTO API key. If not provided, uses key from config or environment. - base_url: Base URL for the USPTO Patent Data API. Defaults to https://api.uspto.gov. - config: USPTOConfig instance. If not provided, creates one with the given api_key. + config: USPTOConfig instance containing API key and settings. If not provided, + creates config from environment variables (requires USPTO_API_KEY). + base_url: Optional base URL override for the USPTO Patent Data API. + If not provided, uses config.patent_data_base_url or default. 
""" - self.config = config or USPTOConfig(api_key=api_key) - api_key_to_use = api_key or self.config.api_key + # Use provided config or create from environment + if config is None: + self.config = USPTOConfig.from_env() + else: + self.config = config + + # Determine effective base URL effective_base_url = ( base_url or self.config.patent_data_base_url or "https://api.uspto.gov" ) + + # Initialize base client super().__init__( - api_key=api_key_to_use, base_url=effective_base_url, config=self.config, ) diff --git a/src/pyUSPTO/clients/petition_decisions.py b/src/pyUSPTO/clients/petition_decisions.py index c1e492d..612db65 100644 --- a/src/pyUSPTO/clients/petition_decisions.py +++ b/src/pyUSPTO/clients/petition_decisions.py @@ -40,26 +40,32 @@ class FinalPetitionDecisionsClient(BaseUSPTOClient[PetitionDecisionResponse]): def __init__( self, - api_key: str | None = None, - base_url: str | None = None, config: USPTOConfig | None = None, + base_url: str | None = None, ): """Initialize the FinalPetitionDecisionsClient. Args: - api_key: Optional API key for authentication. - base_url: Optional base URL override for the API. - config: Optional USPTOConfig instance for configuration. + config: USPTOConfig instance containing API key and settings. If not provided, + creates config from environment variables (requires USPTO_API_KEY). + base_url: Optional base URL override for the USPTO Final Petition Decisions API. + If not provided, uses config.petition_decisions_base_url or default. 
""" - self.config = config or USPTOConfig(api_key=api_key) - api_key_to_use = api_key or self.config.api_key + # Use provided config or create from environment + if config is None: + self.config = USPTOConfig.from_env() + else: + self.config = config + + # Determine effective base URL effective_base_url = ( base_url or self.config.petition_decisions_base_url or "https://api.uspto.gov" ) + + # Initialize base client super().__init__( - api_key=api_key_to_use, base_url=effective_base_url, config=self.config, ) diff --git a/src/pyUSPTO/clients/ptab_appeals.py b/src/pyUSPTO/clients/ptab_appeals.py index 7b393b8..0ca6117 100644 --- a/src/pyUSPTO/clients/ptab_appeals.py +++ b/src/pyUSPTO/clients/ptab_appeals.py @@ -33,24 +33,30 @@ class PTABAppealsClient(BaseUSPTOClient[PTABAppealResponse]): def __init__( self, - api_key: str | None = None, - base_url: str | None = None, config: USPTOConfig | None = None, + base_url: str | None = None, ): """Initialize the PTABAppealsClient. Args: - api_key: Optional API key for authentication. - base_url: Optional base URL override for the API. - config: Optional USPTOConfig instance for configuration. + config: USPTOConfig instance containing API key and settings. If not provided, + creates config from environment variables (requires USPTO_API_KEY). + base_url: Optional base URL override for the USPTO PTAB API. + If not provided, uses config.ptab_base_url or default. 
""" - self.config = config or USPTOConfig(api_key=api_key) - api_key_to_use = api_key or self.config.api_key + # Use provided config or create from environment + if config is None: + self.config = USPTOConfig.from_env() + else: + self.config = config + + # Determine effective base URL effective_base_url = ( base_url or self.config.ptab_base_url or "https://api.uspto.gov" ) + + # Initialize base client super().__init__( - api_key=api_key_to_use, base_url=effective_base_url, config=self.config, ) diff --git a/src/pyUSPTO/clients/ptab_interferences.py b/src/pyUSPTO/clients/ptab_interferences.py index f27c20b..795ac9b 100644 --- a/src/pyUSPTO/clients/ptab_interferences.py +++ b/src/pyUSPTO/clients/ptab_interferences.py @@ -33,24 +33,30 @@ class PTABInterferencesClient(BaseUSPTOClient[PTABInterferenceResponse]): def __init__( self, - api_key: str | None = None, - base_url: str | None = None, config: USPTOConfig | None = None, + base_url: str | None = None, ): """Initialize the PTABInterferencesClient. Args: - api_key: Optional API key for authentication. - base_url: Optional base URL override for the API. - config: Optional USPTOConfig instance for configuration. + config: USPTOConfig instance containing API key and settings. If not provided, + creates config from environment variables (requires USPTO_API_KEY). + base_url: Optional base URL override for the USPTO PTAB API. + If not provided, uses config.ptab_base_url or default. 
""" - self.config = config or USPTOConfig(api_key=api_key) - api_key_to_use = api_key or self.config.api_key + # Use provided config or create from environment + if config is None: + self.config = USPTOConfig.from_env() + else: + self.config = config + + # Determine effective base URL effective_base_url = ( base_url or self.config.ptab_base_url or "https://api.uspto.gov" ) + + # Initialize base client super().__init__( - api_key=api_key_to_use, base_url=effective_base_url, config=self.config, ) diff --git a/src/pyUSPTO/clients/ptab_trials.py b/src/pyUSPTO/clients/ptab_trials.py index 4ea8fd4..c8ed371 100644 --- a/src/pyUSPTO/clients/ptab_trials.py +++ b/src/pyUSPTO/clients/ptab_trials.py @@ -39,24 +39,30 @@ class PTABTrialsClient( def __init__( self, - api_key: str | None = None, - base_url: str | None = None, config: USPTOConfig | None = None, + base_url: str | None = None, ): """Initialize the PTABTrialsClient. Args: - api_key: Optional API key for authentication. - base_url: Optional base URL override for the API. - config: Optional USPTOConfig instance for configuration. + config: USPTOConfig instance containing API key and settings. If not provided, + creates config from environment variables (requires USPTO_API_KEY). + base_url: Optional base URL override for the USPTO PTAB API. + If not provided, uses config.ptab_base_url or default. 
""" - self.config = config or USPTOConfig(api_key=api_key) - api_key_to_use = api_key or self.config.api_key + # Use provided config or create from environment + if config is None: + self.config = USPTOConfig.from_env() + else: + self.config = config + + # Determine effective base URL effective_base_url = ( base_url or self.config.ptab_base_url or "https://api.uspto.gov" ) + + # Initialize base client super().__init__( - api_key=api_key_to_use, base_url=effective_base_url, config=self.config, ) diff --git a/src/pyUSPTO/config.py b/src/pyUSPTO/config.py index a61211c..944d69c 100644 --- a/src/pyUSPTO/config.py +++ b/src/pyUSPTO/config.py @@ -20,9 +20,6 @@ class USPTOConfig: accepts HTTP transport configuration via HTTPConfig. """ - _shared_session: "requests.Session | None" = None - _active_clients: int = 0 - def __init__( self, api_key: str | None = None, @@ -59,8 +56,8 @@ def __init__( # Control whether to include raw JSON data in response objects self.include_raw_data = include_raw_data - # Shared session for all clients using this config (created lazily) - self._shared_session: requests.Session | None = None + # Session for all clients using this config (created lazily) + self._session: requests.Session | None = None @classmethod def from_env(cls) -> "USPTOConfig": @@ -86,3 +83,90 @@ def from_env(cls) -> "USPTOConfig": # Also read HTTP config from environment http_config=HTTPConfig.from_env(), ) + + @property + def session(self) -> "requests.Session": + """Get the HTTP session for this config, creating it if needed. + + The session is created lazily on first access and reused for all + subsequent requests. All clients sharing this config will use the + same session for connection pooling. + + Returns: + Session: The requests Session object with configured adapters. + """ + if self._session is None: + self._session = self._create_session() + return self._session + + def _create_session(self) -> "requests.Session": + """Create and configure a new requests Session. 
+ + Returns: + Session: Configured session with retry logic and connection pooling. + """ + from requests import Session + from requests.adapters import HTTPAdapter + from urllib3.util.retry import Retry + + session = Session() + + # Set API key header + if self.api_key: + session.headers["X-API-KEY"] = self.api_key + + # Apply custom headers from HTTP config + if self.http_config.custom_headers: + session.headers.update(self.http_config.custom_headers) + + # Configure retry strategy + retry_strategy = Retry( + total=self.http_config.max_retries, + backoff_factor=self.http_config.backoff_factor, + status_forcelist=self.http_config.retry_status_codes, + ) + + # Configure connection pooling + adapter = HTTPAdapter( + max_retries=retry_strategy, + pool_connections=self.http_config.pool_connections, + pool_maxsize=self.http_config.pool_maxsize, + ) + + session.mount("http://", adapter) + session.mount("https://", adapter) + + return session + + def close(self) -> None: + """Close the HTTP session and release resources. + + This should be called when you're done using this config and all + clients created from it. After calling close(), the session will + be recreated if accessed again. 
+ + Example: + config = USPTOConfig(api_key="...") + client = PatentDataClient(config=config) + try: + # Use client + pass + finally: + config.close() + """ + if self._session is not None: + self._session.close() + self._session = None + + def __enter__(self) -> "USPTOConfig": + """Enter context manager.""" + return self + + def __exit__( + self, + exc_type: type[BaseException] | None, + exc_val: BaseException | None, + exc_tb: object | None, + ) -> None: + """Exit context manager, closing the session.""" + self.close() diff --git a/src/pyUSPTO/models/bulk_data.py b/src/pyUSPTO/models/bulk_data.py index 7164fc4..0b391af 100644 --- a/src/pyUSPTO/models/bulk_data.py +++ b/src/pyUSPTO/models/bulk_data.py @@ -83,7 +83,7 @@ class FileData: file_data_from_date: Start date of data covered in the file. file_data_to_date: End date of data covered in the file. file_type_text: Description of the file type. - file_release_date: Date when the file was released. + file_release_date: Datetime when the file was released. file_download_uri: URL for downloading the file. file_date: Additional file date information. file_last_modified_date_time: Last modification timestamp. 
@@ -96,9 +96,9 @@ class FileData: file_data_from_date: date | None file_data_to_date: date | None file_type_text: str - file_release_date: date | None + file_release_date: datetime | None file_download_uri: str | None = None - file_date: date | None = None + file_date: datetime | None = None file_last_modified_date_time: datetime | None = None raw_data: str | None = field(default=None, compare=False, repr=False) @@ -126,9 +126,9 @@ def from_dict( file_data_from_date=parse_to_date(data.get("fileDataFromDate")), file_data_to_date=parse_to_date(data.get("fileDataToDate")), file_type_text=data.get("fileTypeText", ""), - file_release_date=parse_to_date(data.get("fileReleaseDate")), + file_release_date=parse_to_datetime_utc(data.get("fileReleaseDate")), file_download_uri=data.get("fileDownloadURI"), - file_date=parse_to_date(data.get("fileDate")), + file_date=parse_to_datetime_utc(data.get("fileDate")), file_last_modified_date_time=parse_to_datetime_utc( data.get("fileLastModifiedDateTime") ), @@ -147,9 +147,15 @@ def to_dict(self) -> dict[str, Any]: "fileDataFromDate": serialize_date(self.file_data_from_date), "fileDataToDate": serialize_date(self.file_data_to_date), "fileTypeText": self.file_type_text, - "fileReleaseDate": serialize_date(self.file_release_date), + "fileReleaseDate": ( + serialize_datetime_as_naive(self.file_release_date) + if self.file_release_date + else None + ), "fileDownloadURI": self.file_download_uri, - "fileDate": serialize_date(self.file_date), + "fileDate": serialize_datetime_as_naive(self.file_date) + if self.file_date + else None, "fileLastModifiedDateTime": ( serialize_datetime_as_naive(self.file_last_modified_date_time) if self.file_last_modified_date_time diff --git a/tests/clients/test_base.py b/tests/clients/test_base.py index 6f59b9b..d257415 100644 --- a/tests/clients/test_base.py +++ b/tests/clients/test_base.py @@ -13,6 +13,7 @@ import pyUSPTO.models.base as BaseModels from pyUSPTO.clients.base import BaseUSPTOClient +from 
pyUSPTO.config import USPTOConfig from pyUSPTO.exceptions import ( USPTOApiAuthError, USPTOApiBadRequestError, @@ -27,6 +28,15 @@ ) +@pytest.fixture +def uspto_config() -> USPTOConfig: + """Provides a USPTOConfig instance with test API key.""" + mock_session = MagicMock() + config = USPTOConfig(api_key="test_key") + config._session = mock_session + return config + + class TestModelsBase: """Test classes from models.base.""" @@ -98,7 +108,7 @@ def test_init(self) -> None: """Test initialization of the BaseUSPTOClient.""" # Test with API key client: BaseUSPTOClient[Any] = BaseUSPTOClient( - api_key="test_key", base_url="https://api.test.com" + config=USPTOConfig(api_key="test_key"), base_url="https://api.test.com" ) assert client._api_key == "test_key" assert client.api_key == "********" # API key is masked @@ -143,7 +153,7 @@ def test_make_request_get(self, mock_session: MagicMock) -> None: """Test _make_request method with GET.""" # Setup client: BaseUSPTOClient[Any] = BaseUSPTOClient(base_url="https://api.test.com") - client.session = mock_session + client.config._session = mock_session mock_response = MagicMock() mock_response.json.return_value = {"key": "value"} @@ -167,7 +177,7 @@ def test_make_request_post(self, mock_session: MagicMock) -> None: """Test _make_request method with POST.""" # Setup client: BaseUSPTOClient[Any] = BaseUSPTOClient(base_url="https://api.test.com") - client.session = mock_session + client.config._session = mock_session mock_response = MagicMock() mock_response.json.return_value = {"key": "value"} @@ -195,7 +205,7 @@ def test_make_request_with_response_class(self, mock_session: MagicMock) -> None """Test _make_request method with response_class.""" # Setup client: BaseUSPTOClient[Any] = BaseUSPTOClient(base_url="https://api.test.com") - client.session = mock_session + client.config._session = mock_session mock_response = MagicMock() mock_response.json.return_value = {"key": "value"} @@ -216,7 +226,7 @@ def 
test_make_request_with_custom_base_url(self, mock_session: MagicMock) -> Non """Test _make_request method with custom_base_url.""" # Setup client: BaseUSPTOClient[Any] = BaseUSPTOClient(base_url="https://api.test.com") - client.session = mock_session + client.config._session = mock_session mock_response = MagicMock() mock_response.json.return_value = {"key": "value"} @@ -242,7 +252,7 @@ def test_make_request_with_stream(self, mock_session: MagicMock) -> None: """Test _make_request method with stream=True.""" # Setup client: BaseUSPTOClient[Any] = BaseUSPTOClient(base_url="https://api.test.com") - client.session = mock_session + client.config._session = mock_session mock_response = MagicMock() mock_session.get.return_value = mock_response @@ -291,7 +301,7 @@ def test_make_request_http_errors(self, mock_session: MagicMock) -> None: """Test _make_request method with HTTP errors.""" # Setup client: BaseUSPTOClient[Any] = BaseUSPTOClient(base_url="https://api.test.com") - client.session = mock_session + client.config._session = mock_session # Test 400 error (Bad Request) mock_response = MagicMock() @@ -409,7 +419,7 @@ def test_make_request_post_error_includes_body( """Test that POST request errors include the request body in the error message.""" # Setup client: BaseUSPTOClient[Any] = BaseUSPTOClient(base_url="https://api.test.com") - client.session = mock_session + client.config._session = mock_session # Mock a 400 Bad Request error mock_response = MagicMock() @@ -441,7 +451,7 @@ def test_make_request_connection_error(self, mock_session: MagicMock) -> None: """Test _make_request method with connection error.""" # Setup client: BaseUSPTOClient[Any] = BaseUSPTOClient(base_url="https://api.test.com") - client.session = mock_session + client.config._session = mock_session # Test connection error mock_session.get.side_effect = requests.exceptions.ConnectionError( @@ -460,7 +470,7 @@ def test_make_request_timeout_error(self, mock_session: MagicMock) -> None: """Test 
_make_request method with timeout error.""" # Setup client: BaseUSPTOClient[Any] = BaseUSPTOClient(base_url="https://api.test.com") - client.session = mock_session + client.config._session = mock_session # Test timeout error mock_session.get.side_effect = requests.exceptions.Timeout("Request timed out") @@ -479,7 +489,7 @@ def test_make_request_generic_request_exception( """Test _make_request method with generic request exception.""" # Setup client: BaseUSPTOClient[Any] = BaseUSPTOClient(base_url="https://api.test.com") - client.session = mock_session + client.config._session = mock_session # Test generic RequestException (not Timeout or ConnectionError) mock_session.get.side_effect = requests.exceptions.RequestException( @@ -495,7 +505,7 @@ def test_make_request_json_parse_error(self, mock_session: MagicMock) -> None: """Test _make_request method with JSON parsing error.""" # Setup client: BaseUSPTOClient[Any] = BaseUSPTOClient(base_url="https://api.test.com") - client.session = mock_session + client.config._session = mock_session # Mock successful HTTP response but with non-JSON content mock_response = MagicMock() @@ -524,7 +534,7 @@ def test_make_request_json_parse_error_with_response_class( """Test _make_request with JSON parsing error when using response_class.""" # Setup client: BaseUSPTOClient[Any] = BaseUSPTOClient(base_url="https://api.test.com") - client.session = mock_session + client.config._session = mock_session # Mock successful HTTP response but with invalid JSON mock_response = MagicMock() @@ -553,7 +563,7 @@ def test_paginate_results(self, mock_session: MagicMock) -> None: """Test paginate_results method.""" # Setup client: BaseUSPTOClient[Any] = BaseUSPTOClient(base_url="https://api.test.com") - client.session = mock_session + client.config._session = mock_session # Create mock responses first_response = MagicMock() @@ -582,7 +592,7 @@ def test_method(self, **kwargs: Any) -> Any: # Use our test client test_client = 
TestClient(base_url="https://api.test.com") - test_client.session = mock_session + test_client.config._session = mock_session # Spy on the test_method to verify calls with patch.object( @@ -616,7 +626,7 @@ def test_method(self, **kwargs: Any) -> Any: # Use our test client for partial results test_partial_client = TestPartialClient(base_url="https://api.test.com") - test_partial_client.session = mock_session + test_partial_client.config._session = mock_session # Test paginate_results with early return results = list( @@ -641,7 +651,7 @@ def test_method(self, **kwargs: Any) -> Any: # Use our test client for empty results test_empty_client = TestEmptyClient(base_url="https://api.test.com") - test_empty_client.session = mock_session + test_empty_client.config._session = mock_session # Test paginate_results with empty response results = list( @@ -661,7 +671,7 @@ def test_paginate_results_with_nested_pagination( """Test paginate_results handles nested pagination structure correctly.""" # Setup client client: BaseUSPTOClient[Any] = BaseUSPTOClient(base_url="https://api.test.com") - client.session = mock_session + client.config._session = mock_session # Create mock responses # response.count is the TOTAL count across all pages @@ -750,7 +760,7 @@ def test_paginate_results_with_flat_pagination( """Test paginate_results still works with flat (top-level) pagination structure.""" # Setup client client: BaseUSPTOClient[Any] = BaseUSPTOClient(base_url="https://api.test.com") - client.session = mock_session + client.config._session = mock_session # Create mock responses # First page: 2 items, total count shows there's only 1 item total (less than limit) @@ -801,7 +811,7 @@ def test_paginate_results_rejects_offset_in_nested_pagination( ) -> None: """Test that offset is rejected when provided in nested pagination.""" client: BaseUSPTOClient[Any] = BaseUSPTOClient(base_url="https://api.test.com") - client.session = mock_session + client.config._session = mock_session post_body = { 
"q": "test", @@ -825,7 +835,7 @@ def test_paginate_results_missing_count_attribute( ) -> None: """Test pagination raises AttributeError when response missing count.""" client: BaseUSPTOClient[Any] = BaseUSPTOClient(base_url="https://api.test.com") - client.session = mock_session + client.config._session = mock_session # Create response without count attribute mock_response = MagicMock() @@ -836,7 +846,7 @@ def test_method(self, **kwargs: Any) -> Any: return mock_response test_client = TestClient(base_url="https://api.test.com") - test_client.session = mock_session + test_client.config._session = mock_session with pytest.raises( AttributeError, match="missing required 'count' attribute for pagination" @@ -852,7 +862,7 @@ def test_paginate_results_missing_container_attribute( ) -> None: """Test pagination raises AttributeError when response missing container.""" client: BaseUSPTOClient[Any] = BaseUSPTOClient(base_url="https://api.test.com") - client.session = mock_session + client.config._session = mock_session # Create response with count but without items mock_response = MagicMock() @@ -864,7 +874,7 @@ def test_method(self, **kwargs: Any) -> Any: return mock_response test_client = TestClient(base_url="https://api.test.com") - test_client.session = mock_session + test_client.config._session = mock_session with pytest.raises( AttributeError, match="missing required 'items' attribute for pagination" @@ -878,7 +888,7 @@ def test_method(self, **kwargs: Any) -> Any: def test_paginate_results_count_none(self, mock_session: MagicMock) -> None: """Test pagination stops gracefully when count is None.""" client: BaseUSPTOClient[Any] = BaseUSPTOClient(base_url="https://api.test.com") - client.session = mock_session + client.config._session = mock_session mock_response = MagicMock() mock_response.count = None @@ -888,7 +898,7 @@ def test_method(self, **kwargs: Any) -> Any: return mock_response test_client = TestClient(base_url="https://api.test.com") - test_client.session = 
mock_session + test_client.config._session = mock_session # Should return empty list without error results = list( @@ -901,7 +911,7 @@ def test_method(self, **kwargs: Any) -> Any: def test_paginate_results_container_none(self, mock_session: MagicMock) -> None: """Test pagination raises ValueError when container is None.""" client: BaseUSPTOClient[Any] = BaseUSPTOClient(base_url="https://api.test.com") - client.session = mock_session + client.config._session = mock_session mock_response = MagicMock() mock_response.count = 10 @@ -912,7 +922,7 @@ def test_method(self, **kwargs: Any) -> Any: return mock_response test_client = TestClient(base_url="https://api.test.com") - test_client.session = mock_session + test_client.config._session = mock_session with pytest.raises(ValueError, match="Container 'items' is None"): list( @@ -924,7 +934,7 @@ def test_method(self, **kwargs: Any) -> Any: def test_save_response_to_file(self, mock_session: MagicMock) -> None: """Test _save_response_to_file raises FileExistsError.""" client: BaseUSPTOClient[Any] = BaseUSPTOClient(base_url="https://api.test.com") - client.session = mock_session + client.config._session = mock_session mock_response = MagicMock() mock_response.headers = {} mock_response.url = "https://api.test.com/file" @@ -970,7 +980,7 @@ def test_base_client_with_http_config(self) -> None: def test_base_client_backward_compatibility(self) -> None: """Test client works without HTTPConfig (backward compatibility)""" client: BaseUSPTOClient[Any] = BaseUSPTOClient( - api_key="test", base_url="https://test.com" + config=USPTOConfig(api_key="test"), base_url="https://test.com" ) # Should create default HTTPConfig automatically @@ -989,7 +999,7 @@ def test_base_client_timeout_applied(self, mock_session: MagicMock) -> None: client: BaseUSPTOClient[Any] = BaseUSPTOClient( config=config, base_url="https://api.test.com" ) - client.session = mock_session + client.config._session = mock_session mock_session.get.return_value.status_code = 
200 mock_session.get.return_value.json.return_value = {"test": "data"} @@ -1016,80 +1026,46 @@ def test_base_client_with_config_object(self) -> None: assert client._api_key == "config_key" assert client.config is config - def test_base_client_api_key_priority(self) -> None: - """Test API key priority: explicit > config""" - from pyUSPTO.config import USPTOConfig - - config = USPTOConfig(api_key="config_key") - client: BaseUSPTOClient[Any] = BaseUSPTOClient( - api_key="explicit_key", config=config, base_url="https://test.com" - ) - - # Explicit api_key should take precedence - assert client._api_key == "explicit_key" - def test_context_manager_enters_and_exits(self, mock_session: MagicMock) -> None: """Test that context manager __enter__ and __exit__ work correctly.""" client: BaseUSPTOClient[Any] = BaseUSPTOClient(base_url="https://api.test.com") - client.session = mock_session + client.config._session = mock_session # Test __enter__ returns self with client as ctx_client: assert ctx_client is client - # Test __exit__ was called (which calls close) - # Since we're using mock_session, we need to verify close was called on it - mock_session.close.assert_called_once() + # Verify session not closed by client context manager + mock_session.close.assert_not_called() - def test_close_when_session_is_owned(self, mock_session: MagicMock) -> None: - """Test close() closes session when client owns it.""" + def test_close_does_not_close_session(self, mock_session: MagicMock) -> None: + """Test close() does NOT close session.""" client: BaseUSPTOClient[Any] = BaseUSPTOClient(base_url="https://api.test.com") - # Client creates its own session, so it owns it - assert client._owns_session is True - - # Replace with mock for testing - client.session = mock_session - - # Close should close the session - client.close() - mock_session.close.assert_called_once() - - def test_close_when_session_is_shared(self, mock_session: MagicMock) -> None: - """Test close() does NOT close session 
when it's shared via config.""" - from pyUSPTO.config import USPTOConfig - - # Create config with existing shared session - config = USPTOConfig(api_key="test") - config._shared_session = mock_session - - # Create client - it should reuse the shared session and not own it - client: BaseUSPTOClient[Any] = BaseUSPTOClient( - base_url="https://api.test.com", config=config - ) - assert client._owns_session is False + client.config._session = mock_session - # Close should NOT close the shared session client.close() mock_session.close.assert_not_called() - def test_close_backward_compatibility(self, mock_session: MagicMock) -> None: - """Test close() works when _owns_session attribute doesn't exist (backward compat).""" - client: BaseUSPTOClient[Any] = BaseUSPTOClient(base_url="https://api.test.com") - client.session = mock_session + def test_config_context_manager(self) -> None: + """Test USPTOConfig context manager.""" + config = USPTOConfig(api_key="test") + with config as ctx_config: + assert ctx_config is config + assert config.session is not None - # Simulate old client without _owns_session attribute - delattr(client, "_owns_session") - - # Close should still close the session for backward compatibility - client.close() - mock_session.close.assert_called_once() + def test_config_close_with_session(self, uspto_config: USPTOConfig) -> None: + """Test USPTOConfig.close() when session exists.""" + config = uspto_config + assert config._session is not None + config.close() + assert config._session is None def test_paginate_results_rejects_offset_in_flat_post_body( self, mock_session: MagicMock ) -> None: """Test that offset is rejected when provided in flat POST body.""" client: BaseUSPTOClient[Any] = BaseUSPTOClient(base_url="https://api.test.com") - client.session = mock_session + client.config._session = mock_session # Flat structure with user-provided offset - should raise post_body = {"q": "test", "offset": 10, "limit": 50} @@ -1227,7 +1203,7 @@ def 
test_save_to_directory_with_content_disposition( # Create a test client client: BaseUSPTOClient[Any] = BaseUSPTOClient( - api_key="test", base_url="https://test.com" + config=USPTOConfig(api_key="test"), base_url="https://test.com" ) # Mock response with Content-Disposition header @@ -1251,7 +1227,7 @@ def test_save_without_extension_uses_content_type_pdf( ) -> None: """Test saving file without extension adds extension from Content-Type (PDF).""" client: BaseUSPTOClient[Any] = BaseUSPTOClient( - api_key="test", base_url="https://test.com" + config=USPTOConfig(api_key="test"), base_url="https://test.com" ) # Mock response with Content-Type but no Content-Disposition @@ -1274,7 +1250,7 @@ def test_save_url_without_extension_uses_content_type_tiff( ) -> None: """Test filename from URL without extension gets extension from Content-Type (TIFF).""" client: BaseUSPTOClient[Any] = BaseUSPTOClient( - api_key="test", base_url="https://test.com" + config=USPTOConfig(api_key="test"), base_url="https://test.com" ) # Mock response with TIFF Content-Type, filename extracted from URL @@ -1297,7 +1273,7 @@ def test_save_url_with_existing_extension_ignores_content_type( ) -> None: """Test filename from URL with extension ignores Content-Type.""" client: BaseUSPTOClient[Any] = BaseUSPTOClient( - api_key="test", base_url="https://test.com" + config=USPTOConfig(api_key="test"), base_url="https://test.com" ) # Mock response with Content-Type @@ -1322,7 +1298,7 @@ def test_save_url_without_extension_unmapped_mime_type( ) -> None: """Test filename from URL with unmapped MIME type saves without extension.""" client: BaseUSPTOClient[Any] = BaseUSPTOClient( - api_key="test", base_url="https://test.com" + config=USPTOConfig(api_key="test"), base_url="https://test.com" ) # Mock response with unmapped Content-Type @@ -1345,7 +1321,7 @@ def test_save_url_without_extension_no_content_type( ) -> None: """Test filename from URL without Content-Type header saves without extension.""" client: 
BaseUSPTOClient[Any] = BaseUSPTOClient( - api_key="test", base_url="https://test.com" + config=USPTOConfig(api_key="test"), base_url="https://test.com" ) # Mock response without Content-Type header @@ -1368,7 +1344,7 @@ def test_save_content_disposition_takes_precedence_over_content_type( ) -> None: """Test Content-Disposition filename takes precedence over Content-Type extension.""" client: BaseUSPTOClient[Any] = BaseUSPTOClient( - api_key="test", base_url="https://test.com" + config=USPTOConfig(api_key="test"), base_url="https://test.com" ) # Mock response with both Content-Disposition and Content-Type @@ -1394,7 +1370,7 @@ def test_save_fallback_to_download_filename( ) -> None: """Test fallback to 'download' filename when no filename can be determined.""" client: BaseUSPTOClient[Any] = BaseUSPTOClient( - api_key="test", base_url="https://test.com" + config=USPTOConfig(api_key="test"), base_url="https://test.com" ) # Mock response with no Content-Disposition, no URL path, no extension @@ -1419,7 +1395,7 @@ def test_save_to_current_directory_when_no_destination( from pathlib import Path client: BaseUSPTOClient[Any] = BaseUSPTOClient( - api_key="test", base_url="https://test.com" + config=USPTOConfig(api_key="test"), base_url="https://test.com" ) # Mock response @@ -1448,7 +1424,7 @@ def test_extract_tar_file(self, tmp_path: Any) -> None: import tarfile client: BaseUSPTOClient[Any] = BaseUSPTOClient( - api_key="test", base_url="https://test.com" + config=USPTOConfig(api_key="test"), base_url="https://test.com" ) # Create a test TAR file @@ -1466,14 +1442,16 @@ def test_extract_tar_file(self, tmp_path: Any) -> None: # Verify extraction assert (extract_to / "test_file.txt").exists() - assert result == str(extract_to / "test_file.txt") # Single file returns file path + assert result == str( + extract_to / "test_file.txt" + ) # Single file returns file path def test_extract_tar_gz_file(self, tmp_path: Any) -> None: """Test extracting a TAR.GZ file.""" import tarfile 
client: BaseUSPTOClient[Any] = BaseUSPTOClient( - api_key="test", base_url="https://test.com" + config=USPTOConfig(api_key="test"), base_url="https://test.com" ) # Create a test TAR.GZ file @@ -1497,7 +1475,7 @@ def test_extract_zip_file(self, tmp_path: Any) -> None: import zipfile client: BaseUSPTOClient[Any] = BaseUSPTOClient( - api_key="test", base_url="https://test.com" + config=USPTOConfig(api_key="test"), base_url="https://test.com" ) # Create a test ZIP file @@ -1519,7 +1497,7 @@ def test_extract_multiple_files_returns_directory(self, tmp_path: Any) -> None: import tarfile client: BaseUSPTOClient[Any] = BaseUSPTOClient( - api_key="test", base_url="https://test.com" + config=USPTOConfig(api_key="test"), base_url="https://test.com" ) # Create archive with multiple files @@ -1547,7 +1525,7 @@ def test_extract_with_remove_archive(self, tmp_path: Any) -> None: import tarfile client: BaseUSPTOClient[Any] = BaseUSPTOClient( - api_key="test", base_url="https://test.com" + config=USPTOConfig(api_key="test"), base_url="https://test.com" ) # Create a test archive @@ -1569,7 +1547,7 @@ def test_extract_with_remove_archive(self, tmp_path: Any) -> None: def test_extract_invalid_archive_raises_error(self, tmp_path: Any) -> None: """Test extracting invalid archive raises ValueError.""" client: BaseUSPTOClient[Any] = BaseUSPTOClient( - api_key="test", base_url="https://test.com" + config=USPTOConfig(api_key="test"), base_url="https://test.com" ) # Create a non-archive file @@ -1585,7 +1563,7 @@ def test_extract_default_extract_to_path(self, tmp_path: Any) -> None: import tarfile client: BaseUSPTOClient[Any] = BaseUSPTOClient( - api_key="test", base_url="https://test.com" + config=USPTOConfig(api_key="test"), base_url="https://test.com" ) # Create archive @@ -1604,6 +1582,133 @@ def test_extract_default_extract_to_path(self, tmp_path: Any) -> None: assert (expected_dir / "test.txt").exists() assert result == str(expected_dir / "test.txt") + def 
test_tar_path_traversal_blocked(self, tmp_path: Any) -> None: + """Test that tar archives with ../ paths are rejected.""" + import io + import tarfile + + client: BaseUSPTOClient[Any] = BaseUSPTOClient( + config=USPTOConfig(api_key="test"), base_url="https://test.com" + ) + + archive_path = tmp_path / "malicious.tar" + with tarfile.open(archive_path, "w") as tar: + info = tarfile.TarInfo(name="../../etc/passwd") + info.size = 10 + tar.addfile(info, io.BytesIO(b"malicious!")) + + extract_to = tmp_path / "safe_dir" + extract_to.mkdir() + + with pytest.raises(ValueError, match="unsafe path"): + client._extract_archive(archive_path, extract_to=extract_to) + + def test_zip_path_traversal_blocked(self, tmp_path: Any) -> None: + """Test that zip archives with ../ paths are rejected.""" + import zipfile + + client: BaseUSPTOClient[Any] = BaseUSPTOClient( + config=USPTOConfig(api_key="test"), base_url="https://test.com" + ) + + archive_path = tmp_path / "malicious.zip" + with zipfile.ZipFile(archive_path, "w") as zip_ref: + zip_ref.writestr("../../etc/passwd", "malicious!") + + extract_to = tmp_path / "safe_dir" + extract_to.mkdir() + + with pytest.raises(ValueError, match="unsafe path"): + client._extract_archive(archive_path, extract_to=extract_to) + + def test_max_size_enforcement(self, tmp_path: Any) -> None: + """Test that max_size parameter prevents extracting large archives.""" + import tarfile + + client: BaseUSPTOClient[Any] = BaseUSPTOClient( + config=USPTOConfig(api_key="test"), base_url="https://test.com" + ) + + # Create archive with known size + archive_path = tmp_path / "large.tar" + with tarfile.open(archive_path, "w") as tar: + temp_file = tmp_path / "large_file.txt" + temp_file.write_text("x" * 1000) # 1000 bytes + tar.add(temp_file, arcname="large_file.txt") + + extract_to = tmp_path / "extracted" + + # Should fail with max_size=500 + with pytest.raises(ValueError, match="exceeds maximum allowed"): + client._extract_archive(archive_path, 
extract_to=extract_to, max_size=500) + + # Should succeed with max_size=2000 + result = client._extract_archive( + archive_path, extract_to=extract_to, max_size=2000 + ) + assert (extract_to / "large_file.txt").exists() + + def test_max_size_enforcement_zip(self, tmp_path: Any) -> None: + """Test that max_size parameter prevents extracting large zip archives.""" + import zipfile + + client: BaseUSPTOClient[Any] = BaseUSPTOClient( + config=USPTOConfig(api_key="test"), base_url="https://test.com" + ) + + archive_path = tmp_path / "large.zip" + with zipfile.ZipFile(archive_path, "w") as zip_ref: + zip_ref.writestr("large_file.txt", "x" * 1000) + + extract_to = tmp_path / "extracted" + + # Should fail with max_size=500 + with pytest.raises(ValueError, match="exceeds maximum allowed"): + client._extract_archive(archive_path, extract_to=extract_to, max_size=500) + + def test_tar_with_directories(self, tmp_path: Any) -> None: + """Test that tar archives with directories extract correctly.""" + import tarfile + + client: BaseUSPTOClient[Any] = BaseUSPTOClient( + config=USPTOConfig(api_key="test"), base_url="https://test.com" + ) + + archive_path = tmp_path / "with_dirs.tar" + with tarfile.open(archive_path, "w") as tar: + # Add a directory and a file + temp_dir = tmp_path / "testdir" + temp_dir.mkdir() + temp_file = temp_dir / "file.txt" + temp_file.write_text("content") + tar.add(temp_dir, arcname="testdir") + + extract_to = tmp_path / "extracted" + result = client._extract_archive(archive_path, extract_to=extract_to) + + # Directory entries are skipped, but files are extracted + assert (extract_to / "testdir" / "file.txt").exists() + + def test_zip_with_directories(self, tmp_path: Any) -> None: + """Test that zip archives with directories extract correctly.""" + import zipfile + + client: BaseUSPTOClient[Any] = BaseUSPTOClient( + config=USPTOConfig(api_key="test"), base_url="https://test.com" + ) + + archive_path = tmp_path / "with_dirs.zip" + with 
zipfile.ZipFile(archive_path, "w") as zip_ref: + # Add directory entry and file + zip_ref.writestr("testdir/", "") + zip_ref.writestr("testdir/file.txt", "content") + + extract_to = tmp_path / "extracted" + result = client._extract_archive(archive_path, extract_to=extract_to) + + # Directory entries are skipped, but files are extracted + assert (extract_to / "testdir" / "file.txt").exists() + class TestDownloadAndExtract: """Tests for _download_and_extract method.""" @@ -1613,7 +1718,7 @@ def test_download_and_extract_tar_file(self, tmp_path: Any) -> None: import tarfile client: BaseUSPTOClient[Any] = BaseUSPTOClient( - api_key="test", base_url="https://test.com" + config=USPTOConfig(api_key="test"), base_url="https://test.com" ) # Create a real TAR file to simulate download @@ -1639,7 +1744,7 @@ def test_download_and_extract_zip_file(self, tmp_path: Any) -> None: import zipfile client: BaseUSPTOClient[Any] = BaseUSPTOClient( - api_key="test", base_url="https://test.com" + config=USPTOConfig(api_key="test"), base_url="https://test.com" ) # Create a real ZIP file @@ -1661,7 +1766,7 @@ def test_download_and_extract_zip_file(self, tmp_path: Any) -> None: def test_download_non_archive_returns_file_path(self, tmp_path: Any) -> None: """Test downloading non-archive file returns downloaded file path.""" client: BaseUSPTOClient[Any] = BaseUSPTOClient( - api_key="test", base_url="https://test.com" + config=USPTOConfig(api_key="test"), base_url="https://test.com" ) # Create a non-archive file diff --git a/tests/clients/test_bulk_data_clients.py b/tests/clients/test_bulk_data_clients.py index 142474e..ff62cee 100644 --- a/tests/clients/test_bulk_data_clients.py +++ b/tests/clients/test_bulk_data_clients.py @@ -5,13 +5,11 @@ model handling, edge cases, and response handling. 
""" -import os -from datetime import date +from datetime import date, datetime, timezone from typing import Any -from unittest.mock import MagicMock, mock_open, patch +from unittest.mock import MagicMock, patch import pytest -import requests from pyUSPTO.clients import BulkDataClient from pyUSPTO.config import USPTOConfig @@ -22,6 +20,26 @@ ProductFileBag, ) +# --- Fixtures --- + + +@pytest.fixture +def api_key_fixture() -> str: + """Provides a test API key.""" + return "test_key" + + +@pytest.fixture +def uspto_config(api_key_fixture: str) -> USPTOConfig: + """Provides a USPTOConfig instance with test API key.""" + return USPTOConfig(api_key=api_key_fixture) + + +@pytest.fixture +def bulk_data_client(uspto_config: USPTOConfig) -> BulkDataClient: + """Provides a BulkDataClient instance initialized with a test config.""" + return BulkDataClient(config=uspto_config) + class TestBulkDataModels: """Tests for the bulk data model classes.""" @@ -34,9 +52,9 @@ def test_file_data_from_dict(self) -> None: "fileDataFromDate": "2023-01-01", "fileDataToDate": "2023-12-31", "fileTypeText": "ZIP", - "fileReleaseDate": "2024-01-01", + "fileReleaseDate": "2024-01-01 00:00:00", "fileDownloadURI": "https://example.com/test.zip", - "fileDate": "2023-12-31", + "fileDate": "2023-12-31 00:00:00", "fileLastModifiedDateTime": "2023-12-31T23:59:59", } @@ -48,10 +66,10 @@ def test_file_data_from_dict(self) -> None: assert file_data.file_data_from_date == date(2023, 1, 1) assert file_data.file_data_to_date == date(2023, 12, 31) assert file_data.file_type_text == "ZIP" - assert file_data.file_release_date == date(2024, 1, 1) + assert file_data.file_release_date is not None assert file_data.file_download_uri == "https://example.com/test.zip" - assert file_data.file_date == date(2023, 12, 31) - assert file_data.file_last_modified_date_time is not None # Datetime object, not string + assert file_data.file_date is not None + assert file_data.file_last_modified_date_time is not None def 
test_product_file_bag_from_dict(self) -> None: """Test ProductFileBag.from_dict method.""" @@ -64,7 +82,7 @@ def test_product_file_bag_from_dict(self) -> None: "fileDataFromDate": "2023-01-01", "fileDataToDate": "2023-06-30", "fileTypeText": "ZIP", - "fileReleaseDate": "2023-07-01", + "fileReleaseDate": "2023-07-01 00:00:00", }, { "fileName": "test2.zip", @@ -72,7 +90,7 @@ def test_product_file_bag_from_dict(self) -> None: "fileDataFromDate": "2023-07-01", "fileDataToDate": "2023-12-31", "fileTypeText": "ZIP", - "fileReleaseDate": "2024-01-01", + "fileReleaseDate": "2024-01-01 00:00:00", }, ], } @@ -140,7 +158,9 @@ def test_bulk_data_product_from_dict(self) -> None: assert product.product_to_date == date(2023, 12, 31) assert product.product_total_file_size == 1024 assert product.product_file_total_quantity == 2 - assert product.last_modified_date_time is not None # Datetime object, not string + assert ( + product.last_modified_date_time is not None + ) # Datetime object, not string assert product.mime_type_identifier_array_text == ["application/zip"] assert product.product_file_bag is not None assert product.product_file_bag.count == 2 @@ -205,26 +225,25 @@ def test_bulk_data_response_from_dict(self) -> None: assert response.bulk_data_product_bag[1].product_identifier == "PRODUCT2" assert response.bulk_data_product_bag[0].product_file_bag is not None assert response.bulk_data_product_bag[0].product_file_bag.count == 1 - assert len(response.bulk_data_product_bag[0].product_file_bag.file_data_bag) == 1 + assert ( + len(response.bulk_data_product_bag[0].product_file_bag.file_data_bag) == 1 + ) assert response.bulk_data_product_bag[1].product_file_bag is None class TestBulkDataClientInit: """Tests for the initialization of the BulkDataClient class.""" - def test_init_with_api_key(self) -> None: + def test_init_with_api_key(self, uspto_config: USPTOConfig) -> None: """Test initialization with direct API key.""" - client = BulkDataClient(api_key="test_key") - assert 
client._api_key == "test_key" + client = BulkDataClient(config=uspto_config) assert client.base_url == "https://api.uspto.gov" assert client.config is not None assert client.config.api_key == "test_key" - def test_init_with_custom_base_url(self) -> None: + def test_init_with_custom_base_url(self, uspto_config: USPTOConfig) -> None: """Test initialization with custom base URL.""" - client = BulkDataClient( - api_key="test_key", base_url="https://custom.api.test.com" - ) + client = BulkDataClient(uspto_config, base_url="https://custom.api.test.com") assert client.base_url == "https://custom.api.test.com" def test_init_with_config(self) -> None: @@ -234,19 +253,14 @@ def test_init_with_config(self) -> None: bulk_data_base_url="https://config.api.test.com", ) client = BulkDataClient(config=config) - assert client._api_key == "config_key" + assert client.config.api_key == "config_key" assert client.base_url == "https://config.api.test.com" - assert client.config is config - def test_init_with_api_key_and_config(self) -> None: - """Test initialization with both API key and config.""" - config = USPTOConfig( - api_key="config_key", - bulk_data_base_url="https://config.api.test.com", - ) - client = BulkDataClient(api_key="direct_key", config=config) - assert client._api_key == "direct_key" - assert client.base_url == "https://config.api.test.com" + def test_init_without_config(self, monkeypatch: Any) -> None: + """Test initialization without config uses environment.""" + monkeypatch.setenv("USPTO_API_KEY", "env_key") + client = BulkDataClient() + assert client.config.api_key == "env_key" class TestBulkDataClientCore: @@ -265,7 +279,7 @@ def test_search_products_basic( mock_session.get.return_value = mock_response # Replace the client's session with our mock - mock_bulk_data_client.session = mock_session + mock_bulk_data_client.config._session = mock_session # Test search_products with basic query response = mock_bulk_data_client.search_products(query="Patent") @@ -296,7 
+310,7 @@ def test_get_product_by_id( mock_session.get.return_value = mock_response # Replace the client's session with our mock - mock_bulk_data_client.session = mock_session + mock_bulk_data_client.config._session = mock_session # Test get_product_by_id product = mock_bulk_data_client.get_product_by_id( @@ -336,20 +350,19 @@ def test_download_file(self, mock_bulk_data_client: BulkDataClient) -> None: file_data_from_date=date(2023, 1, 1), file_data_to_date=date(2023, 12, 31), file_type_text="TAR", - file_release_date=date(2024, 1, 1), + file_release_date=datetime(2024, 1, 1, tzinfo=timezone.utc), ) destination = "./downloads" - # Mock the _download_and_extract method with patch.object( - mock_bulk_data_client, "_download_and_extract", return_value="./downloads/extracted" + mock_bulk_data_client, + "_download_file", + return_value="./downloads/test.tar.gz", ) as mock_download: - # Test download_file with extraction (default) file_path = mock_bulk_data_client.download_file( file_data=file_data, destination=destination ) - # Verify expected_url = f"{mock_bulk_data_client.base_url}/api/v1/datasets/products/files/PRODUCT1/test.tar.gz" mock_download.assert_called_once_with( url=expected_url, @@ -357,7 +370,7 @@ def test_download_file(self, mock_bulk_data_client: BulkDataClient) -> None: file_name="test.tar.gz", overwrite=False, ) - assert file_path == "./downloads/extracted" + assert file_path == "./downloads/test.tar.gz" def test_download_file_without_extraction( self, mock_bulk_data_client: BulkDataClient @@ -371,7 +384,7 @@ def test_download_file_without_extraction( file_data_from_date=date(2023, 1, 1), file_data_to_date=date(2023, 12, 31), file_type_text="ZIP", - file_release_date=date(2024, 1, 1), + file_release_date=datetime(2024, 1, 1, tzinfo=timezone.utc), ) destination = "./downloads" @@ -406,7 +419,7 @@ def test_search_products_all_params( mock_session.get.return_value = mock_response # Replace the client's session with our mock - mock_bulk_data_client.session 
= mock_session + mock_bulk_data_client.config._session = mock_session # Test search_products with all available parameters response = mock_bulk_data_client.search_products( @@ -452,7 +465,6 @@ def test_paginate_products(self, mock_bulk_data_client: BulkDataClient) -> None: mock_paginate_results.assert_called_once_with( method_name="search_products", response_container_attr="bulk_data_product_bag", - post_body=None, param="value", ) @@ -463,7 +475,7 @@ class TestBulkDataClientEdgeCases: def test_get_product_by_id_not_found(self) -> None: """Test get_product_by_id when product is not in response.""" # Setup - client = BulkDataClient(api_key="test_key") + client = BulkDataClient(config=USPTOConfig(api_key="test_key")) # Mock _make_request to return an empty BulkDataResponse empty_response = BulkDataResponse(count=0, bulk_data_product_bag=[]) @@ -475,7 +487,7 @@ def test_get_product_by_id_not_found(self) -> None: def test_get_product_by_id_wrong_product_returned(self) -> None: """Test get_product_by_id when API returns wrong product.""" # Setup - client = BulkDataClient(api_key="test_key") + client = BulkDataClient(config=USPTOConfig(api_key="test_key")) # Create response with different product ID wrong_product = BulkDataProduct( @@ -488,7 +500,9 @@ def test_get_product_by_id_wrong_product_returned(self) -> None: with patch.object(client, "_make_request", return_value=response): # Should still return the product but issue a warning - with pytest.warns(match="API returned product 'WRONG_ID' but requested 'TEST'"): + with pytest.warns( + match="API returned product 'WRONG_ID' but requested 'TEST'" + ): product = client.get_product_by_id(product_id="TEST") assert product.product_identifier == "WRONG_ID" @@ -504,13 +518,15 @@ def test_download_file_with_custom_filename( file_data_from_date=date(2023, 1, 1), file_data_to_date=date(2023, 12, 31), file_type_text="ZIP", - file_release_date=date(2024, 1, 1), + file_release_date=datetime(2024, 1, 1, tzinfo=timezone.utc), ) 
destination = "./downloads" # Mock the _download_file method with patch.object( - mock_bulk_data_client, "_download_file", return_value="./downloads/custom.zip" + mock_bulk_data_client, + "_download_file", + return_value="./downloads/custom.zip", ) as mock_download: file_path = mock_bulk_data_client.download_file( file_data=file_data, @@ -541,14 +557,17 @@ def test_download_file_with_overwrite( file_data_from_date=date(2023, 1, 1), file_data_to_date=date(2023, 12, 31), file_type_text="ZIP", - file_release_date=date(2024, 1, 1), + file_release_date=datetime(2024, 1, 1, tzinfo=timezone.utc), + file_download_uri="https://example.com/direct/test.zip", ) # Mock the _download_and_extract method with patch.object( mock_bulk_data_client, "_download_and_extract", return_value="./test.zip" ) as mock_download: - mock_bulk_data_client.download_file(file_data=file_data, overwrite=True) + mock_bulk_data_client.download_file( + file_data=file_data, overwrite=True, extract=True + ) # Verify overwrite is passed through assert mock_download.call_args[1]["overwrite"] is True @@ -570,7 +589,7 @@ def test_search_products_with_query_and_limit( mock_session.get.return_value = mock_response # Replace the client's session with our mock - mock_bulk_data_client.session = mock_session + mock_bulk_data_client.config._session = mock_session # Test search_products with query and limit response = mock_bulk_data_client.search_products( @@ -599,7 +618,7 @@ def test_search_products_with_offset_and_facets( mock_session.get.return_value = mock_response # Replace the client's session with our mock - mock_bulk_data_client.session = mock_session + mock_bulk_data_client.config._session = mock_session # Test search_products with offset and facets response = mock_bulk_data_client.search_products( @@ -625,7 +644,7 @@ def test_search_products_with_fields( mock_response.json.return_value = bulk_data_sample mock_session = MagicMock() mock_session.get.return_value = mock_response - mock_bulk_data_client.session 
= mock_session + mock_bulk_data_client.config._session = mock_session # Test search_products with fields response = mock_bulk_data_client.search_products( diff --git a/tests/clients/test_patent_data_clients.py b/tests/clients/test_patent_data_clients.py index 273b465..93bfca7 100644 --- a/tests/clients/test_patent_data_clients.py +++ b/tests/clients/test_patent_data_clients.py @@ -58,9 +58,15 @@ def api_key_fixture() -> str: @pytest.fixture -def patent_data_client(api_key_fixture: str) -> PatentDataClient: - """Provides a PatentDataClient instance initialized with a test API key.""" - return PatentDataClient(api_key=api_key_fixture) +def uspto_config(api_key_fixture: str) -> USPTOConfig: + """Provides a USPTOConfig instance with test API key.""" + return USPTOConfig(api_key=api_key_fixture) + + +@pytest.fixture +def patent_data_client(uspto_config: USPTOConfig) -> PatentDataClient: + """Provides a PatentDataClient instance initialized with a test config.""" + return PatentDataClient(config=uspto_config) @pytest.fixture @@ -248,44 +254,37 @@ def mock_requests_response() -> MagicMock: class TestPatentDataClientInit: """Tests for the initialization of the PatentDataClient.""" - def test_init_with_api_key(self, api_key_fixture: str) -> None: - """Test initialization with API key.""" - client = PatentDataClient(api_key=api_key_fixture) - assert client._api_key == api_key_fixture - assert client.base_url == "https://api.uspto.gov" + def test_init_with_config(self, uspto_config: USPTOConfig) -> None: + """Test initialization with config.""" + client = PatentDataClient(config=uspto_config) + assert client._api_key == uspto_config.api_key + assert client.base_url == uspto_config.patent_data_base_url - def test_init_with_custom_base_url(self, api_key_fixture: str) -> None: + def test_init_with_custom_base_url(self, uspto_config: USPTOConfig) -> None: """Test initialization with custom base URL.""" custom_url = "https://custom.api.test.com" - client = 
PatentDataClient(api_key=api_key_fixture, base_url=custom_url) - assert client._api_key == api_key_fixture + client = PatentDataClient(config=uspto_config, base_url=custom_url) assert client.base_url == custom_url - def test_init_with_config(self) -> None: - """Test initialization with config object.""" + def test_init_with_config_base_url(self, uspto_config: USPTOConfig) -> None: + """Test initialization with config containing base URL.""" config_key = "config_key" config_url = "https://config.api.test.com" config = USPTOConfig(api_key=config_key, patent_data_base_url=config_url) client = PatentDataClient(config=config) - assert client._api_key == config_key assert client.base_url == config_url assert client.config is config - def test_init_with_api_key_and_config(self, api_key_fixture: str) -> None: - """Test initialization with both API key and config.""" - config = USPTOConfig( - api_key="config_key", patent_data_base_url="https://config.api.test.com" - ) - client = PatentDataClient(api_key=api_key_fixture, config=config) - assert client._api_key == api_key_fixture - assert client.base_url == "https://config.api.test.com" - custom_url = "https://custom.url.com" - client_custom_url = PatentDataClient( - api_key=api_key_fixture, base_url=custom_url, config=config - ) + client_custom_url = PatentDataClient(config=uspto_config, base_url=custom_url) assert client_custom_url.base_url == custom_url + def test_init_without_config(self, monkeypatch: Any) -> None: + """Test initialization without config uses environment.""" + monkeypatch.setenv("USPTO_API_KEY", "env_key") + client = PatentDataClient() + assert client.config.api_key == "env_key" + class TestPatentApplicationSearch: """Tests for patent application search functionalities using the new search_applications method.""" @@ -2546,11 +2545,10 @@ def test_raw_data_disabled_by_default( assert result.raw_data is None def test_raw_data_enabled_via_config( - self, mock_patent_file_wrapper: PatentFileWrapper + self, 
uspto_config: USPTOConfig, mock_patent_file_wrapper: PatentFileWrapper ) -> None: """Test that raw_data is populated when config.include_raw_data=True.""" - config = USPTOConfig(api_key="test_key", include_raw_data=True) - PatentDataClient(config=config) + PatentDataClient(config=uspto_config) # Create a response with raw_data enabled test_data = { diff --git a/tests/clients/test_petition_decision_clients.py b/tests/clients/test_petition_decision_clients.py index 597613e..0ee303b 100644 --- a/tests/clients/test_petition_decision_clients.py +++ b/tests/clients/test_petition_decision_clients.py @@ -21,9 +21,8 @@ ) from pyUSPTO.warnings import USPTODataMismatchWarning -# --- Fixtures --- - +# --- Fixtures --- @pytest.fixture def api_key_fixture() -> str: """Provides a test API key.""" @@ -31,9 +30,17 @@ def api_key_fixture() -> str: @pytest.fixture -def petition_client(api_key_fixture: str) -> Iterator[FinalPetitionDecisionsClient]: +def uspto_config(api_key_fixture: str) -> USPTOConfig: + """Provides a USPTOConfig instance with test API key.""" + return USPTOConfig(api_key=api_key_fixture) + + +@pytest.fixture +def petition_client( + uspto_config: USPTOConfig, +) -> Iterator[FinalPetitionDecisionsClient]: """Provides a FinalPetitionDecisionsClient instance.""" - client = FinalPetitionDecisionsClient(api_key=api_key_fixture) + client = FinalPetitionDecisionsClient(config=uspto_config) with patch.object(client, "_download_and_extract") as mock_download: mock_download.return_value = "/tmp/document.pdf" yield client @@ -98,19 +105,18 @@ def mock_download_option() -> DocumentDownloadOption: class TestFinalPetitionDecisionsClientInit: """Tests for initialization of FinalPetitionDecisionsClient.""" - def test_init_with_api_key(self, api_key_fixture: str) -> None: + def test_init_with_api_key( + self, petition_client: FinalPetitionDecisionsClient, uspto_config: USPTOConfig + ) -> None: """Test initialization with API key.""" - client = 
FinalPetitionDecisionsClient(api_key=api_key_fixture) - assert client._api_key == api_key_fixture + client = petition_client + assert client._api_key == uspto_config.api_key assert client.base_url == "https://api.uspto.gov" - def test_init_with_custom_base_url(self, api_key_fixture: str) -> None: + def test_init_with_custom_base_url(self, uspto_config: USPTOConfig) -> None: """Test initialization with custom base URL.""" custom_url = "https://custom.api.test.com" - client = FinalPetitionDecisionsClient( - api_key=api_key_fixture, base_url=custom_url - ) - assert client._api_key == api_key_fixture + client = FinalPetitionDecisionsClient(config=uspto_config, base_url=custom_url) assert client.base_url == custom_url def test_init_with_config(self) -> None: @@ -123,29 +129,11 @@ def test_init_with_config(self) -> None: assert client.base_url == config_url assert client.config is config - def test_init_with_api_key_and_config(self, api_key_fixture: str) -> None: - """Test initialization with both API key and config.""" - config = USPTOConfig( - api_key="config_key", - petition_decisions_base_url="https://config.api.test.com", - ) - client = FinalPetitionDecisionsClient(api_key=api_key_fixture, config=config) - # API key parameter takes precedence - assert client._api_key == api_key_fixture - # But base_url comes from config - assert client.base_url == "https://config.api.test.com" - - def test_init_base_url_precedence(self, api_key_fixture: str) -> None: - """Test that explicit base_url takes precedence over config.""" - config = USPTOConfig( - api_key="config_key", - petition_decisions_base_url="https://config.api.test.com", - ) - custom_url = "https://custom.url.com" - client = FinalPetitionDecisionsClient( - api_key=api_key_fixture, base_url=custom_url, config=config - ) - assert client.base_url == custom_url + def test_init_without_config(self, monkeypatch: Any) -> None: + """Test initialization without config uses environment.""" + 
monkeypatch.setenv("USPTO_API_KEY", "env_key") + client = FinalPetitionDecisionsClient() + assert client.config.api_key == "env_key" class TestFinalPetitionDecisionsClientSearch: @@ -812,7 +800,9 @@ def test_download_document_file_exists( # Mock _download_and_extract to raise FileExistsError with patch.object(petition_client, "_download_and_extract") as mock_dl: - mock_dl.side_effect = FileExistsError(f"File exists: {existing_file}. Use overwrite=True") + mock_dl.side_effect = FileExistsError( + f"File exists: {existing_file}. Use overwrite=True" + ) with pytest.raises(FileExistsError, match="File exists"): petition_client.download_petition_document( diff --git a/tests/clients/test_ptab_appeals_client.py b/tests/clients/test_ptab_appeals_client.py index 7d81714..240806e 100644 --- a/tests/clients/test_ptab_appeals_client.py +++ b/tests/clients/test_ptab_appeals_client.py @@ -9,15 +9,22 @@ import pytest from pyUSPTO import PTABAppealsClient, USPTOConfig -from pyUSPTO.models.ptab import PTABAppealResponse +from pyUSPTO.models.ptab import AppealDocumentData, PTABAppealResponse +# --- Fixtures --- @pytest.fixture def api_key_fixture() -> str: - """Fixture for test API key.""" + """Provides a test API key.""" return "test_key" +@pytest.fixture +def uspto_config(api_key_fixture: str) -> USPTOConfig: + """Provides a USPTOConfig instance with test API key.""" + return USPTOConfig(api_key=api_key_fixture) + + @pytest.fixture def appeal_decision_sample() -> dict[str, Any]: """Sample appeal decision data for testing.""" @@ -56,9 +63,9 @@ def appeal_decision_sample() -> dict[str, Any]: @pytest.fixture -def mock_ptab_appeals_client(api_key_fixture: str) -> PTABAppealsClient: +def mock_ptab_appeals_client(uspto_config: USPTOConfig) -> PTABAppealsClient: """Fixture for mock PTABAppealsClient.""" - return PTABAppealsClient(api_key=api_key_fixture) + return PTABAppealsClient(config=uspto_config) class TestPTABAppealsClientInit: @@ -66,38 +73,23 @@ class TestPTABAppealsClientInit: 
def test_init_with_api_key(self, api_key_fixture: str) -> None: """Test initialization with API key.""" - client = PTABAppealsClient(api_key=api_key_fixture) - assert client._api_key == api_key_fixture + client = PTABAppealsClient(config=USPTOConfig(api_key=api_key_fixture)) assert client.base_url == "https://api.uspto.gov" - def test_init_with_custom_base_url(self, api_key_fixture: str) -> None: + def test_init_with_custom_base_url( + self, api_key_fixture: str, uspto_config: USPTOConfig + ) -> None: """Test initialization with custom base URL.""" custom_url = "https://custom.api.test.com" - client = PTABAppealsClient(api_key=api_key_fixture, base_url=custom_url) + client = PTABAppealsClient(config=uspto_config, base_url=custom_url) assert client._api_key == api_key_fixture assert client.base_url == custom_url - def test_init_with_config(self) -> None: - """Test initialization with config object.""" - config_key = "config_key" - config_url = "https://config.api.test.com" - config = USPTOConfig(api_key=config_key, ptab_base_url=config_url) - client = PTABAppealsClient(config=config) - assert client._api_key == config_key - assert client.base_url == config_url - assert client.config is config - - def test_init_with_api_key_and_config(self, api_key_fixture: str) -> None: - """Test initialization with both API key and config.""" - config = USPTOConfig( - api_key="config_key", - ptab_base_url="https://config.api.test.com", - ) - client = PTABAppealsClient(api_key=api_key_fixture, config=config) - # API key parameter takes precedence - assert client._api_key == api_key_fixture - # But base_url comes from config - assert client.base_url == "https://config.api.test.com" + def test_init_without_config(self, monkeypatch: Any) -> None: + """Test initialization without config uses environment.""" + monkeypatch.setenv("USPTO_API_KEY", "env_key") + client = PTABAppealsClient() + assert client.config.api_key == "env_key" class TestPTABAppealsClientSearchDecisions: @@ -114,7 
+106,7 @@ def test_search_decisions_get_with_query( mock_response = MagicMock() mock_response.json.return_value = appeal_decision_sample mock_session.get.return_value = mock_response - mock_ptab_appeals_client.session = mock_session + mock_ptab_appeals_client.config._session = mock_session # Test result = mock_ptab_appeals_client.search_decisions( @@ -141,7 +133,7 @@ def test_search_decisions_get_with_convenience_params( mock_response = MagicMock() mock_response.json.return_value = appeal_decision_sample mock_session.get.return_value = mock_response - mock_ptab_appeals_client.session = mock_session + mock_ptab_appeals_client.config._session = mock_session # Test result = mock_ptab_appeals_client.search_decisions( @@ -178,7 +170,7 @@ def test_search_decisions_get_with_date_from_only( mock_response = MagicMock() mock_response.json.return_value = appeal_decision_sample mock_session.get.return_value = mock_response - mock_ptab_appeals_client.session = mock_session + mock_ptab_appeals_client.config._session = mock_session # Test result = mock_ptab_appeals_client.search_decisions( @@ -202,7 +194,7 @@ def test_search_decisions_get_with_date_to_only( mock_response = MagicMock() mock_response.json.return_value = appeal_decision_sample mock_session.get.return_value = mock_response - mock_ptab_appeals_client.session = mock_session + mock_ptab_appeals_client.config._session = mock_session # Test result = mock_ptab_appeals_client.search_decisions( @@ -226,7 +218,7 @@ def test_search_decisions_get_with_all_convenience_params( mock_response = MagicMock() mock_response.json.return_value = appeal_decision_sample mock_session.get.return_value = mock_response - mock_ptab_appeals_client.session = mock_session + mock_ptab_appeals_client.config._session = mock_session # Test result = mock_ptab_appeals_client.search_decisions( @@ -262,7 +254,7 @@ def test_search_decisions_post_with_body( mock_response = MagicMock() mock_response.json.return_value = appeal_decision_sample 
mock_session.post.return_value = mock_response - mock_ptab_appeals_client.session = mock_session + mock_ptab_appeals_client.config._session = mock_session post_body = {"q": "technologyCenterNumber:3600", "limit": 100} @@ -286,7 +278,7 @@ def test_search_decisions_with_optional_params( mock_response = MagicMock() mock_response.json.return_value = appeal_decision_sample mock_session.get.return_value = mock_response - mock_ptab_appeals_client.session = mock_session + mock_ptab_appeals_client.config._session = mock_session # Test result = mock_ptab_appeals_client.search_decisions( @@ -427,43 +419,49 @@ def test_paginate_decisions_with_multiple_params( class TestPTABAppealsDownloadMethods: """Tests for PTAB Appeals download methods.""" - def test_download_appeal_archive_missing_uri_raises_error(self) -> None: + def test_download_appeal_archive_missing_uri_raises_error( + self, uspto_config: USPTOConfig + ) -> None: """Test download_appeal_archive raises ValueError when file_download_uri is None.""" from pyUSPTO.models.ptab import AppealMetaData - client = PTABAppealsClient(api_key="test") - + client = PTABAppealsClient(config=uspto_config) # Create AppealMetaData without file_download_uri meta_data = AppealMetaData(file_download_uri=None) with pytest.raises(ValueError, match="AppealMetaData has no file_download_uri"): client.download_appeal_archive(meta_data) - def test_download_appeal_archive_with_uri(self) -> None: + def test_download_appeal_archive_with_uri(self, uspto_config: USPTOConfig) -> None: """Test download_appeal_archive calls _download_file with URI.""" from unittest.mock import patch from pyUSPTO.models.ptab import AppealMetaData - client = PTABAppealsClient(api_key="test") + client = PTABAppealsClient(config=uspto_config) meta_data = AppealMetaData(file_download_uri="https://test.com/appeal.tar") - with patch.object(client, "_download_file", return_value="/path/to/file") as mock_download: - result = client.download_appeal_archive(meta_data, 
destination="/dest", file_name="custom.tar", overwrite=True) + with patch.object( + client, "_download_file", return_value="/path/to/file" + ) as mock_download: + result = client.download_appeal_archive( + meta_data, destination="/dest", file_name="custom.tar", overwrite=True + ) mock_download.assert_called_once_with( url="https://test.com/appeal.tar", destination="/dest", file_name="custom.tar", - overwrite=True + overwrite=True, ) assert result == "/path/to/file" - def test_download_appeal_documents_missing_uri_raises_error(self) -> None: + def test_download_appeal_documents_missing_uri_raises_error( + self, uspto_config: USPTOConfig + ) -> None: """Test download_appeal_documents raises ValueError when file_download_uri is None.""" from pyUSPTO.models.ptab import AppealMetaData - client = PTABAppealsClient(api_key="test") - + client = PTABAppealsClient(config=uspto_config) # Create AppealMetaData without file_download_uri meta_data = AppealMetaData(file_download_uri=None) @@ -476,45 +474,51 @@ def test_download_appeal_documents_with_uri(self) -> None: from pyUSPTO.models.ptab import AppealMetaData - client = PTABAppealsClient(api_key="test") + client = PTABAppealsClient(config=USPTOConfig(api_key="test")) meta_data = AppealMetaData(file_download_uri="https://test.com/appeal.tar") - with patch.object(client, "_download_and_extract", return_value="/path/to/extracted") as mock_extract: - result = client.download_appeal_documents(meta_data, destination="/dest", overwrite=True) + with patch.object( + client, "_download_and_extract", return_value="/path/to/extracted" + ) as mock_extract: + result = client.download_appeal_documents( + meta_data, destination="/dest", overwrite=True + ) mock_extract.assert_called_once_with( - url="https://test.com/appeal.tar", - destination="/dest", - overwrite=True + url="https://test.com/appeal.tar", destination="/dest", overwrite=True ) assert result == "/path/to/extracted" - def 
test_download_appeal_document_missing_uri_raises_error(self) -> None: + def test_download_appeal_document_missing_uri_raises_error( + self, uspto_config: USPTOConfig + ) -> None: """Test download_appeal_document raises ValueError when file_download_uri is None.""" from pyUSPTO.models.ptab import AppealDocumentData - client = PTABAppealsClient(api_key="test") - + client = PTABAppealsClient(config=uspto_config) # Create AppealDocumentData without file_download_uri document_data = AppealDocumentData(file_download_uri=None) - with pytest.raises(ValueError, match="AppealDocumentData has no file_download_uri"): + with pytest.raises( + ValueError, match="AppealDocumentData has no file_download_uri" + ): client.download_appeal_document(document_data) - def test_download_appeal_document_with_uri(self) -> None: + def test_download_appeal_document_with_uri(self, uspto_config: USPTOConfig) -> None: """Test download_appeal_document calls _download_and_extract with URI.""" - from unittest.mock import patch - from pyUSPTO.models.ptab import AppealDocumentData - - client = PTABAppealsClient(api_key="test") + client = PTABAppealsClient(config=uspto_config) document_data = AppealDocumentData(file_download_uri="https://test.com/doc.pdf") - with patch.object(client, "_download_and_extract", return_value="/path/to/doc.pdf") as mock_extract: - result = client.download_appeal_document(document_data, destination="/dest", file_name="doc.pdf", overwrite=True) + with patch.object( + client, "_download_and_extract", return_value="/path/to/doc.pdf" + ) as mock_extract: + result = client.download_appeal_document( + document_data, destination="/dest", file_name="doc.pdf", overwrite=True + ) mock_extract.assert_called_once_with( url="https://test.com/doc.pdf", destination="/dest", file_name="doc.pdf", - overwrite=True + overwrite=True, ) assert result == "/path/to/doc.pdf" diff --git a/tests/clients/test_ptab_interferences_client.py b/tests/clients/test_ptab_interferences_client.py index 
9f12a16..122e3a7 100644 --- a/tests/clients/test_ptab_interferences_client.py +++ b/tests/clients/test_ptab_interferences_client.py @@ -10,7 +10,7 @@ import pytest from pyUSPTO import PTABInterferencesClient, USPTOConfig -from pyUSPTO.models.ptab import PTABInterferenceResponse +from pyUSPTO.models.ptab import InterferenceDocumentData, PTABInterferenceResponse @pytest.fixture @@ -19,6 +19,12 @@ def api_key_fixture() -> str: return "test_key" +@pytest.fixture +def uspto_config(api_key_fixture: str) -> USPTOConfig: + """Provides a USPTOConfig instance with test API key.""" + return USPTOConfig(api_key=api_key_fixture) + + @pytest.fixture def interference_decision_sample() -> dict[str, Any]: """Sample interference decision data for testing.""" @@ -109,48 +115,38 @@ def interference_decision_sample() -> dict[str, Any]: @pytest.fixture -def mock_ptab_interferences_client(api_key_fixture: str) -> PTABInterferencesClient: +def mock_ptab_interferences_client( + uspto_config: USPTOConfig, +) -> PTABInterferencesClient: """Fixture for mock PTABInterferencesClient.""" - return PTABInterferencesClient(api_key=api_key_fixture) + return PTABInterferencesClient(config=uspto_config) class TestPTABInterferencesClientInit: """Tests for initialization of PTABInterferencesClient.""" - def test_init_with_api_key(self, api_key_fixture: str) -> None: + def test_init_with_api_key( + self, api_key_fixture: str, uspto_config: USPTOConfig + ) -> None: """Test initialization with API key.""" - client = PTABInterferencesClient(api_key=api_key_fixture) + client = PTABInterferencesClient(config=uspto_config) assert client._api_key == api_key_fixture assert client.base_url == "https://api.uspto.gov" - def test_init_with_custom_base_url(self, api_key_fixture: str) -> None: + def test_init_with_custom_base_url( + self, api_key_fixture: str, uspto_config: USPTOConfig + ) -> None: """Test initialization with custom base URL.""" custom_url = "https://custom.api.test.com" - client = 
PTABInterferencesClient(api_key=api_key_fixture, base_url=custom_url) + client = PTABInterferencesClient(config=uspto_config, base_url=custom_url) assert client._api_key == api_key_fixture assert client.base_url == custom_url - def test_init_with_config(self) -> None: - """Test initialization with config object.""" - config_key = "config_key" - config_url = "https://config.api.test.com" - config = USPTOConfig(api_key=config_key, ptab_base_url=config_url) - client = PTABInterferencesClient(config=config) - assert client._api_key == config_key - assert client.base_url == config_url - assert client.config is config - - def test_init_with_api_key_and_config(self, api_key_fixture: str) -> None: - """Test initialization with both API key and config.""" - config = USPTOConfig( - api_key="config_key", - ptab_base_url="https://config.api.test.com", - ) - client = PTABInterferencesClient(api_key=api_key_fixture, config=config) - # API key parameter takes precedence - assert client._api_key == api_key_fixture - # But base_url comes from config - assert client.base_url == "https://config.api.test.com" + def test_init_without_config(self, monkeypatch: Any) -> None: + """Test initialization without config uses environment.""" + monkeypatch.setenv("USPTO_API_KEY", "env_key") + client = PTABInterferencesClient() + assert client.config.api_key == "env_key" class TestPTABInterferencesClientSearchDecisions: @@ -167,7 +163,7 @@ def test_search_decisions_get_with_query( mock_response = MagicMock() mock_response.json.return_value = interference_decision_sample mock_session.get.return_value = mock_response - mock_ptab_interferences_client.session = mock_session + mock_ptab_interferences_client.config._session = mock_session # Test result = mock_ptab_interferences_client.search_decisions( @@ -194,7 +190,7 @@ def test_search_decisions_get_with_convenience_params( mock_response = MagicMock() mock_response.json.return_value = interference_decision_sample mock_session.get.return_value = 
mock_response - mock_ptab_interferences_client.session = mock_session + mock_ptab_interferences_client.config._session = mock_session # Test result = mock_ptab_interferences_client.search_decisions( @@ -238,7 +234,7 @@ def test_search_decisions_get_with_date_from_only( mock_response = MagicMock() mock_response.json.return_value = interference_decision_sample mock_session.get.return_value = mock_response - mock_ptab_interferences_client.session = mock_session + mock_ptab_interferences_client.config._session = mock_session # Test result = mock_ptab_interferences_client.search_decisions( @@ -262,7 +258,7 @@ def test_search_decisions_get_with_date_to_only( mock_response = MagicMock() mock_response.json.return_value = interference_decision_sample mock_session.get.return_value = mock_response - mock_ptab_interferences_client.session = mock_session + mock_ptab_interferences_client.config._session = mock_session # Test result = mock_ptab_interferences_client.search_decisions( @@ -286,7 +282,7 @@ def test_search_decisions_get_with_all_convenience_params( mock_response = MagicMock() mock_response.json.return_value = interference_decision_sample mock_session.get.return_value = mock_response - mock_ptab_interferences_client.session = mock_session + mock_ptab_interferences_client.config._session = mock_session # Test result = mock_ptab_interferences_client.search_decisions( @@ -327,7 +323,7 @@ def test_search_decisions_with_real_party_in_interest_q( mock_response = MagicMock() mock_response.json.return_value = interference_decision_sample mock_session.get.return_value = mock_response - mock_ptab_interferences_client.session = mock_session + mock_ptab_interferences_client.config._session = mock_session # Test result = mock_ptab_interferences_client.search_decisions( @@ -354,7 +350,7 @@ def test_search_decisions_post_with_body( mock_response = MagicMock() mock_response.json.return_value = interference_decision_sample mock_session.post.return_value = mock_response - 
mock_ptab_interferences_client.session = mock_session + mock_ptab_interferences_client.config._session = mock_session post_body = { "q": "interferenceOutcomeCategory:Priority to Senior Party", @@ -381,7 +377,7 @@ def test_search_decisions_with_optional_params( mock_response = MagicMock() mock_response.json.return_value = interference_decision_sample mock_session.get.return_value = mock_response - mock_ptab_interferences_client.session = mock_session + mock_ptab_interferences_client.config._session = mock_session # Test result = mock_ptab_interferences_client.search_decisions( @@ -530,94 +526,124 @@ def test_paginate_decisions_with_multiple_params( class TestPTABInterferencesDownloadMethods: """Tests for PTAB Interferences download methods.""" - def test_download_interference_archive_missing_uri_raises_error(self) -> None: + def test_download_interference_archive_missing_uri_raises_error( + self, uspto_config: USPTOConfig + ) -> None: """Test download_interference_archive raises ValueError when file_download_uri is None.""" from pyUSPTO.models.ptab import InterferenceMetaData - client = PTABInterferencesClient(api_key="test") - + client = PTABInterferencesClient(config=uspto_config) # Create InterferenceMetaData without file_download_uri meta_data = InterferenceMetaData(file_download_uri=None) - with pytest.raises(ValueError, match="InterferenceMetaData has no file_download_uri"): + with pytest.raises( + ValueError, match="InterferenceMetaData has no file_download_uri" + ): client.download_interference_archive(meta_data) - def test_download_interference_archive_with_uri(self) -> None: + def test_download_interference_archive_with_uri( + self, uspto_config: USPTOConfig + ) -> None: """Test download_interference_archive calls _download_file with URI.""" from unittest.mock import patch from pyUSPTO.models.ptab import InterferenceMetaData - client = PTABInterferencesClient(api_key="test") - meta_data = 
InterferenceMetaData(file_download_uri="https://test.com/interference.tar") + client = PTABInterferencesClient(config=uspto_config) + meta_data = InterferenceMetaData( + file_download_uri="https://test.com/interference.tar" + ) - with patch.object(client, "_download_file", return_value="/path/to/file") as mock_download: - result = client.download_interference_archive(meta_data, destination="/dest", file_name="custom.tar", overwrite=True) + with patch.object( + client, "_download_file", return_value="/path/to/file" + ) as mock_download: + result = client.download_interference_archive( + meta_data, destination="/dest", file_name="custom.tar", overwrite=True + ) mock_download.assert_called_once_with( url="https://test.com/interference.tar", destination="/dest", file_name="custom.tar", - overwrite=True + overwrite=True, ) assert result == "/path/to/file" - def test_download_interference_documents_missing_uri_raises_error(self) -> None: + def test_download_interference_documents_missing_uri_raises_error( + self, uspto_config: USPTOConfig + ) -> None: """Test download_interference_documents raises ValueError when file_download_uri is None.""" from pyUSPTO.models.ptab import InterferenceMetaData - client = PTABInterferencesClient(api_key="test") - + client = PTABInterferencesClient(config=uspto_config) # Create InterferenceMetaData without file_download_uri meta_data = InterferenceMetaData(file_download_uri=None) - with pytest.raises(ValueError, match="InterferenceMetaData has no file_download_uri"): + with pytest.raises( + ValueError, match="InterferenceMetaData has no file_download_uri" + ): client.download_interference_documents(meta_data) - def test_download_interference_documents_with_uri(self) -> None: + def test_download_interference_documents_with_uri( + self, uspto_config: USPTOConfig + ) -> None: """Test download_interference_documents calls _download_and_extract with URI.""" from unittest.mock import patch from pyUSPTO.models.ptab import InterferenceMetaData - 
client = PTABInterferencesClient(api_key="test") - meta_data = InterferenceMetaData(file_download_uri="https://test.com/interference.tar") + client = PTABInterferencesClient(config=uspto_config) + meta_data = InterferenceMetaData( + file_download_uri="https://test.com/interference.tar" + ) - with patch.object(client, "_download_and_extract", return_value="/path/to/extracted") as mock_extract: - result = client.download_interference_documents(meta_data, destination="/dest", overwrite=True) + with patch.object( + client, "_download_and_extract", return_value="/path/to/extracted" + ) as mock_extract: + result = client.download_interference_documents( + meta_data, destination="/dest", overwrite=True + ) mock_extract.assert_called_once_with( url="https://test.com/interference.tar", destination="/dest", - overwrite=True + overwrite=True, ) assert result == "/path/to/extracted" - def test_download_interference_document_missing_uri_raises_error(self) -> None: + def test_download_interference_document_missing_uri_raises_error( + self, uspto_config: USPTOConfig + ) -> None: """Test download_interference_document raises ValueError when file_download_uri is None.""" from pyUSPTO.models.ptab import InterferenceDocumentData - client = PTABInterferencesClient(api_key="test") - + client = PTABInterferencesClient(config=uspto_config) # Create InterferenceDocumentData without file_download_uri document_data = InterferenceDocumentData(file_download_uri=None) - with pytest.raises(ValueError, match="InterferenceDocumentData has no file_download_uri"): + with pytest.raises( + ValueError, match="InterferenceDocumentData has no file_download_uri" + ): client.download_interference_document(document_data) - def test_download_interference_document_with_uri(self) -> None: + def test_download_interference_document_with_uri( + self, uspto_config: USPTOConfig + ) -> None: """Test download_interference_document calls _download_and_extract with URI.""" - from unittest.mock import patch - - from 
pyUSPTO.models.ptab import InterferenceDocumentData - client = PTABInterferencesClient(api_key="test") - document_data = InterferenceDocumentData(file_download_uri="https://test.com/doc.pdf") + client = PTABInterferencesClient(config=uspto_config) + document_data = InterferenceDocumentData( + file_download_uri="https://test.com/doc.pdf" + ) - with patch.object(client, "_download_and_extract", return_value="/path/to/doc.pdf") as mock_extract: - result = client.download_interference_document(document_data, destination="/dest", file_name="doc.pdf", overwrite=True) + with patch.object( + client, "_download_and_extract", return_value="/path/to/doc.pdf" + ) as mock_extract: + result = client.download_interference_document( + document_data, destination="/dest", file_name="doc.pdf", overwrite=True + ) mock_extract.assert_called_once_with( url="https://test.com/doc.pdf", destination="/dest", file_name="doc.pdf", - overwrite=True + overwrite=True, ) assert result == "/path/to/doc.pdf" diff --git a/tests/clients/test_ptab_trials_client.py b/tests/clients/test_ptab_trials_client.py index 01fbc8e..7f72b55 100644 --- a/tests/clients/test_ptab_trials_client.py +++ b/tests/clients/test_ptab_trials_client.py @@ -13,15 +13,30 @@ from pyUSPTO.models.ptab import ( PTABTrialDocumentResponse, PTABTrialProceedingResponse, + TrialDocumentData, + TrialMetaData, ) +# --- Fixtures --- @pytest.fixture def api_key_fixture() -> str: - """Fixture for test API key.""" + """Provides a test API key.""" return "test_key" +@pytest.fixture +def uspto_config(api_key_fixture: str) -> USPTOConfig: + """Provides a USPTOConfig instance with test API key.""" + return USPTOConfig(api_key=api_key_fixture) + + +@pytest.fixture +def mock_ptab_trials_client(uspto_config: USPTOConfig) -> PTABTrialsClient: + """Fixture for mock PTABTrialsClient.""" + return PTABTrialsClient(config=uspto_config) + + @pytest.fixture def trial_proceeding_sample() -> dict[str, Any]: """Sample trial proceeding data for testing.""" @@ 
-88,49 +103,31 @@ def trial_document_sample() -> dict[str, Any]: } -@pytest.fixture -def mock_ptab_trials_client(api_key_fixture: str) -> PTABTrialsClient: - """Fixture for mock PTABTrialsClient.""" - return PTABTrialsClient(api_key=api_key_fixture) - - class TestPTABTrialsClientInit: """Tests for initialization of PTABTrialsClient.""" - def test_init_with_api_key(self, api_key_fixture: str) -> None: + def test_init_with_api_key( + self, api_key_fixture: str, uspto_config: USPTOConfig + ) -> None: """Test initialization with API key.""" - client = PTABTrialsClient(api_key=api_key_fixture) + client = PTABTrialsClient(config=uspto_config) assert client._api_key == api_key_fixture assert client.base_url == "https://api.uspto.gov" - def test_init_with_custom_base_url(self, api_key_fixture: str) -> None: + def test_init_with_custom_base_url( + self, api_key_fixture: str, uspto_config: USPTOConfig + ) -> None: """Test initialization with custom base URL.""" custom_url = "https://custom.api.test.com" - client = PTABTrialsClient(api_key=api_key_fixture, base_url=custom_url) + client = PTABTrialsClient(config=uspto_config, base_url=custom_url) assert client._api_key == api_key_fixture assert client.base_url == custom_url - def test_init_with_config(self) -> None: - """Test initialization with config object.""" - config_key = "config_key" - config_url = "https://config.api.test.com" - config = USPTOConfig(api_key=config_key, ptab_base_url=config_url) - client = PTABTrialsClient(config=config) - assert client._api_key == config_key - assert client.base_url == config_url - assert client.config is config - - def test_init_with_api_key_and_config(self, api_key_fixture: str) -> None: - """Test initialization with both API key and config.""" - config = USPTOConfig( - api_key="config_key", - ptab_base_url="https://config.api.test.com", - ) - client = PTABTrialsClient(api_key=api_key_fixture, config=config) - # API key parameter takes precedence - assert client._api_key == 
api_key_fixture - # But base_url comes from config - assert client.base_url == "https://config.api.test.com" + def test_init_without_config(self, monkeypatch: Any) -> None: + """Test initialization without config uses environment.""" + monkeypatch.setenv("USPTO_API_KEY", "env_key") + client = PTABTrialsClient() + assert client.config.api_key == "env_key" class TestPTABTrialsClientSearchProceedings: @@ -147,7 +144,7 @@ def test_search_proceedings_get_with_query( mock_response = MagicMock() mock_response.json.return_value = trial_proceeding_sample mock_session.get.return_value = mock_response - mock_ptab_trials_client.session = mock_session + mock_ptab_trials_client.config._session = mock_session # Test result = mock_ptab_trials_client.search_proceedings( @@ -174,7 +171,7 @@ def test_search_proceedings_get_with_convenience_params( mock_response = MagicMock() mock_response.json.return_value = trial_proceeding_sample mock_session.get.return_value = mock_response - mock_ptab_trials_client.session = mock_session + mock_ptab_trials_client.config._session = mock_session # Test result = mock_ptab_trials_client.search_proceedings( @@ -207,7 +204,7 @@ def test_search_proceedings_with_all_convenience_params( mock_response = MagicMock() mock_response.json.return_value = trial_proceeding_sample mock_session.get.return_value = mock_response - mock_ptab_trials_client.session = mock_session + mock_ptab_trials_client.config._session = mock_session # Test result = mock_ptab_trials_client.search_proceedings( @@ -244,7 +241,7 @@ def test_search_proceedings_with_date_from_only( mock_response = MagicMock() mock_response.json.return_value = trial_proceeding_sample mock_session.get.return_value = mock_response - mock_ptab_trials_client.session = mock_session + mock_ptab_trials_client.config._session = mock_session # Test result = mock_ptab_trials_client.search_proceedings( @@ -268,7 +265,7 @@ def test_search_proceedings_with_date_to_only( mock_response = MagicMock() 
mock_response.json.return_value = trial_proceeding_sample mock_session.get.return_value = mock_response - mock_ptab_trials_client.session = mock_session + mock_ptab_trials_client.config._session = mock_session # Test result = mock_ptab_trials_client.search_proceedings( @@ -292,7 +289,7 @@ def test_search_proceedings_with_optional_params( mock_response = MagicMock() mock_response.json.return_value = trial_proceeding_sample mock_session.get.return_value = mock_response - mock_ptab_trials_client.session = mock_session + mock_ptab_trials_client.config._session = mock_session # Test result = mock_ptab_trials_client.search_proceedings( @@ -331,7 +328,7 @@ def test_search_proceedings_post_with_body( mock_response = MagicMock() mock_response.json.return_value = trial_proceeding_sample mock_session.post.return_value = mock_response - mock_ptab_trials_client.session = mock_session + mock_ptab_trials_client.config._session = mock_session post_body = {"q": "trialTypeCode:IPR", "limit": 100} @@ -359,7 +356,7 @@ def test_search_documents_get_with_query( mock_response = MagicMock() mock_response.json.return_value = trial_document_sample mock_session.get.return_value = mock_response - mock_ptab_trials_client.session = mock_session + mock_ptab_trials_client.config._session = mock_session # Test result = mock_ptab_trials_client.search_documents( @@ -381,7 +378,7 @@ def test_search_documents_with_convenience_params( mock_response = MagicMock() mock_response.json.return_value = trial_document_sample mock_session.get.return_value = mock_response - mock_ptab_trials_client.session = mock_session + mock_ptab_trials_client.config._session = mock_session # Test result = mock_ptab_trials_client.search_documents( @@ -411,7 +408,7 @@ def test_search_documents_with_all_convenience_params( mock_response = MagicMock() mock_response.json.return_value = trial_document_sample mock_session.get.return_value = mock_response - mock_ptab_trials_client.session = mock_session + 
mock_ptab_trials_client.config._session = mock_session # Test result = mock_ptab_trials_client.search_documents( @@ -462,7 +459,7 @@ def test_search_documents_with_date_from_only( mock_response = MagicMock() mock_response.json.return_value = trial_document_sample mock_session.get.return_value = mock_response - mock_ptab_trials_client.session = mock_session + mock_ptab_trials_client.config._session = mock_session # Test result = mock_ptab_trials_client.search_documents( @@ -486,7 +483,7 @@ def test_search_documents_with_date_to_only( mock_response = MagicMock() mock_response.json.return_value = trial_document_sample mock_session.get.return_value = mock_response - mock_ptab_trials_client.session = mock_session + mock_ptab_trials_client.config._session = mock_session # Test result = mock_ptab_trials_client.search_documents(filing_date_to_q="2023-12-31") @@ -508,7 +505,7 @@ def test_search_documents_post_with_body( mock_response = MagicMock() mock_response.json.return_value = trial_document_sample mock_session.post.return_value = mock_response - mock_ptab_trials_client.session = mock_session + mock_ptab_trials_client.config._session = mock_session post_body = {"q": "documentCategory:Paper", "limit": 100} @@ -532,7 +529,7 @@ def test_search_documents_with_optional_params( mock_response = MagicMock() mock_response.json.return_value = trial_document_sample mock_session.get.return_value = mock_response - mock_ptab_trials_client.session = mock_session + mock_ptab_trials_client.config._session = mock_session # Test result = mock_ptab_trials_client.search_documents( @@ -575,7 +572,7 @@ def test_search_decisions_get_with_query( mock_response = MagicMock() mock_response.json.return_value = trial_document_sample mock_session.get.return_value = mock_response - mock_ptab_trials_client.session = mock_session + mock_ptab_trials_client.config._session = mock_session # Test result = mock_ptab_trials_client.search_decisions( @@ -597,7 +594,7 @@ def 
test_search_decisions_with_convenience_params( mock_response = MagicMock() mock_response.json.return_value = trial_document_sample mock_session.get.return_value = mock_response - mock_ptab_trials_client.session = mock_session + mock_ptab_trials_client.config._session = mock_session # Test result = mock_ptab_trials_client.search_decisions( @@ -624,7 +621,7 @@ def test_search_decisions_with_all_convenience_params( mock_response = MagicMock() mock_response.json.return_value = trial_document_sample mock_session.get.return_value = mock_response - mock_ptab_trials_client.session = mock_session + mock_ptab_trials_client.config._session = mock_session # Test result = mock_ptab_trials_client.search_decisions( @@ -676,7 +673,7 @@ def test_search_decisions_with_date_from_only( mock_response = MagicMock() mock_response.json.return_value = trial_document_sample mock_session.get.return_value = mock_response - mock_ptab_trials_client.session = mock_session + mock_ptab_trials_client.config._session = mock_session # Test result = mock_ptab_trials_client.search_decisions( @@ -700,7 +697,7 @@ def test_search_decisions_with_date_to_only( mock_response = MagicMock() mock_response.json.return_value = trial_document_sample mock_session.get.return_value = mock_response - mock_ptab_trials_client.session = mock_session + mock_ptab_trials_client.config._session = mock_session # Test result = mock_ptab_trials_client.search_decisions( @@ -724,7 +721,7 @@ def test_search_decisions_with_document_type_description_q( mock_response = MagicMock() mock_response.json.return_value = trial_document_sample mock_session.get.return_value = mock_response - mock_ptab_trials_client.session = mock_session + mock_ptab_trials_client.config._session = mock_session # Test result = mock_ptab_trials_client.search_decisions( @@ -751,7 +748,7 @@ def test_search_decisions_post_with_body( mock_response = MagicMock() mock_response.json.return_value = trial_document_sample mock_session.post.return_value = mock_response - 
mock_ptab_trials_client.session = mock_session + mock_ptab_trials_client.config._session = mock_session post_body = {"q": "decisionTypeCategory:Final Written Decision", "limit": 100} @@ -775,7 +772,7 @@ def test_search_decisions_with_optional_params( mock_response = MagicMock() mock_response.json.return_value = trial_document_sample mock_session.get.return_value = mock_response - mock_ptab_trials_client.session = mock_session + mock_ptab_trials_client.config._session = mock_session # Test result = mock_ptab_trials_client.search_decisions( @@ -870,94 +867,92 @@ def test_paginate_proceedings_rejects_offset_in_kwargs( class TestPTABTrialsDownloadMethods: """Tests for PTAB Trials download methods.""" - def test_download_trial_archive_missing_uri_raises_error(self) -> None: + def test_download_trial_archive_missing_uri_raises_error( + self, uspto_config: USPTOConfig + ) -> None: """Test download_trial_archive raises ValueError when file_download_uri is None.""" - from pyUSPTO.models.ptab import TrialMetaData - - client = PTABTrialsClient(api_key="test") - + client = PTABTrialsClient(config=uspto_config) # Create TrialMetaData without file_download_uri meta_data = TrialMetaData(file_download_uri=None) with pytest.raises(ValueError, match="TrialMetaData has no file_download_uri"): client.download_trial_archive(meta_data) - def test_download_trial_archive_with_uri(self) -> None: + def test_download_trial_archive_with_uri(self, uspto_config: USPTOConfig) -> None: """Test download_trial_archive calls _download_file with URI.""" - from unittest.mock import patch - - from pyUSPTO.models.ptab import TrialMetaData - - client = PTABTrialsClient(api_key="test") + client = PTABTrialsClient(config=uspto_config) meta_data = TrialMetaData(file_download_uri="https://test.com/trial.tar") - with patch.object(client, "_download_file", return_value="/path/to/file") as mock_download: - result = client.download_trial_archive(meta_data, destination="/dest", file_name="custom.tar", 
overwrite=True) + with patch.object( + client, "_download_file", return_value="/path/to/file" + ) as mock_download: + result = client.download_trial_archive( + meta_data, destination="/dest", file_name="custom.tar", overwrite=True + ) mock_download.assert_called_once_with( url="https://test.com/trial.tar", destination="/dest", file_name="custom.tar", - overwrite=True + overwrite=True, ) assert result == "/path/to/file" - def test_download_trial_documents_missing_uri_raises_error(self) -> None: + def test_download_trial_documents_missing_uri_raises_error( + self, uspto_config: USPTOConfig + ) -> None: """Test download_trial_documents raises ValueError when file_download_uri is None.""" - from pyUSPTO.models.ptab import TrialMetaData - - client = PTABTrialsClient(api_key="test") + client = PTABTrialsClient(config=uspto_config) # Create TrialMetaData without file_download_uri meta_data = TrialMetaData(file_download_uri=None) with pytest.raises(ValueError, match="TrialMetaData has no file_download_uri"): client.download_trial_documents(meta_data) - def test_download_trial_documents_with_uri(self) -> None: + def test_download_trial_documents_with_uri(self, uspto_config: USPTOConfig) -> None: """Test download_trial_documents calls _download_and_extract with URI.""" - from unittest.mock import patch - - from pyUSPTO.models.ptab import TrialMetaData - - client = PTABTrialsClient(api_key="test") + client = PTABTrialsClient(config=uspto_config) meta_data = TrialMetaData(file_download_uri="https://test.com/trial.tar") - with patch.object(client, "_download_and_extract", return_value="/path/to/extracted") as mock_extract: - result = client.download_trial_documents(meta_data, destination="/dest", overwrite=True) + with patch.object( + client, "_download_and_extract", return_value="/path/to/extracted" + ) as mock_extract: + result = client.download_trial_documents( + meta_data, destination="/dest", overwrite=True + ) mock_extract.assert_called_once_with( - 
url="https://test.com/trial.tar", - destination="/dest", - overwrite=True + url="https://test.com/trial.tar", destination="/dest", overwrite=True ) assert result == "/path/to/extracted" - def test_download_trial_document_missing_uri_raises_error(self) -> None: + def test_download_trial_document_missing_uri_raises_error( + self, uspto_config: USPTOConfig + ) -> None: """Test download_trial_document raises ValueError when file_download_uri is None.""" - from pyUSPTO.models.ptab import TrialDocumentData - - client = PTABTrialsClient(api_key="test") - + client = PTABTrialsClient(config=uspto_config) # Create TrialDocumentData without file_download_uri document_data = TrialDocumentData(file_download_uri=None) - with pytest.raises(ValueError, match="TrialDocumentData has no file_download_uri"): + with pytest.raises( + ValueError, match="TrialDocumentData has no file_download_uri" + ): client.download_trial_document(document_data) - def test_download_trial_document_with_uri(self) -> None: + def test_download_trial_document_with_uri(self, uspto_config: USPTOConfig) -> None: """Test download_trial_document calls _download_and_extract with URI.""" - from unittest.mock import patch - - from pyUSPTO.models.ptab import TrialDocumentData - - client = PTABTrialsClient(api_key="test") + client = PTABTrialsClient(config=uspto_config) document_data = TrialDocumentData(file_download_uri="https://test.com/doc.pdf") - with patch.object(client, "_download_and_extract", return_value="/path/to/doc.pdf") as mock_extract: - result = client.download_trial_document(document_data, destination="/dest", file_name="doc.pdf", overwrite=True) + with patch.object( + client, "_download_and_extract", return_value="/path/to/doc.pdf" + ) as mock_extract: + result = client.download_trial_document( + document_data, destination="/dest", file_name="doc.pdf", overwrite=True + ) mock_extract.assert_called_once_with( url="https://test.com/doc.pdf", destination="/dest", file_name="doc.pdf", - overwrite=True + 
overwrite=True, ) assert result == "/path/to/doc.pdf" diff --git a/tests/conftest.py b/tests/conftest.py index fec2b62..e59f1f3 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -90,7 +90,7 @@ def bulk_data_sample() -> dict[str, Any]: "fileDataFromDate": "2023-01-01", "fileDataToDate": "2023-06-30", "fileTypeText": "ZIP", - "fileReleaseDate": "2023-07-01", + "fileReleaseDate": "2023-07-01 00:00:00", "fileDownloadURI": "https://example.com/test1.zip", }, { @@ -99,7 +99,7 @@ def bulk_data_sample() -> dict[str, Any]: "fileDataFromDate": "2023-07-01", "fileDataToDate": "2023-12-31", "fileTypeText": "ZIP", - "fileReleaseDate": "2024-01-01", + "fileReleaseDate": "2024-01-01 00:00:00", "fileDownloadURI": "https://example.com/test2.zip", }, ], @@ -128,7 +128,7 @@ def bulk_data_sample() -> dict[str, Any]: "fileDataFromDate": "2023-01-01", "fileDataToDate": "2023-12-31", "fileTypeText": "ZIP", - "fileReleaseDate": "2024-01-01", + "fileReleaseDate": "2024-01-01 00:00:00", "fileDownloadURI": "https://example.com/test3.zip", } ], @@ -244,7 +244,7 @@ def mock_bulk_data_client( BulkDataClient: A client with a mocked session """ client = BulkDataClient(config=uspto_config) - client.session = mock_session + client.config._session = mock_session return client @@ -263,5 +263,5 @@ def mock_patent_data_client( PatentDataClient: A client with a mocked session """ client = PatentDataClient(config=uspto_config) - client.session = mock_session + client.config._session = mock_session return client diff --git a/tests/test_config.py b/tests/test_config.py index 149d368..b5b38e6 100644 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -1,5 +1,7 @@ """Tests for USPTOConfig""" +from pyUSPTO.clients.bulk_data import BulkDataClient +from pyUSPTO.clients.patent_data import PatentDataClient from pyUSPTO.config import USPTOConfig from pyUSPTO.http_config import HTTPConfig @@ -92,3 +94,30 @@ def test_http_config_sharing(self): assert config1.http_config is config2.http_config assert 
config1.http_config.timeout == 90.0 assert config2.http_config.timeout == 90.0 + + def test_session_lifecycle(self): + """Test session sharing, lazy creation, reuse, and cleanup behavior""" + + # Test lazy session creation + config = USPTOConfig(api_key="test") + assert config._session is None # Session not created until first access + + # Test session sharing across multiple clients + client1 = PatentDataClient(config=config) + client2 = BulkDataClient(config=config) + assert client1.session is client2.session + + # Test session reuse + session1 = config.session + session2 = config.session + assert session1 is session2 + + # Test clients don't close shared config sessions + shared_session = client1.session + client1.close() + assert config._session is not None + assert config.session is shared_session # Session still alive + + # Test config.close() clears session reference + config.close() + assert config._session is None