diff --git a/pyaqsapi/helperfunctions.py b/pyaqsapi/helperfunctions.py index a6aaac3..c73ea48 100644 --- a/pyaqsapi/helperfunctions.py +++ b/pyaqsapi/helperfunctions.py @@ -1,21 +1,74 @@ """pyaqsapi core functions.""" +import threading +import time from collections.abc import Iterable from datetime import date from itertools import starmap - -# from time import sleep __aqs_ratelimit() is deprecated, use ratelimit package instead from typing import Any, cast, no_type_check from warnings import warn + from certifi import where from pandas import DataFrame, concat -from ratelimit import limits, sleep_and_retry from requests import get from requests.exceptions import ConnectionError, HTTPError, Timeout # noqa # pylint: disable=wrong-spelling-in-comment AQS_user: str | None = None AQS_key: str | None = None -ONE_MINUTE = 60 # set 60 second period for ratelimit package decorator +ONE_MINUTE = 60 # 60 second period for rate limiting + +# Rate limiting configuration: 10 calls per minute = 1 call every 6 seconds +RATE_LIMIT_CALLS = 10 +RATE_LIMIT_PERIOD = ONE_MINUTE +MIN_INTERVAL_BETWEEN_CALLS = RATE_LIMIT_PERIOD / RATE_LIMIT_CALLS # 6 seconds + + +class TokenBucketRateLimiter: + """ + A token bucket rate limiter that evenly spaces out API calls. + + This implementation ensures that calls are spaced evenly over time + rather than allowing bursts followed by long waits. This is similar + to httr2's req_throttle behavior in R. + """ + + def __init__(self, calls: int = RATE_LIMIT_CALLS, period: float = RATE_LIMIT_PERIOD) -> None: + """ + Initialize the rate limiter. + + Parameters + ---------- + calls : int + Maximum number of calls allowed in the period + period : float + Time period in seconds + """ + self.calls = calls + self.period = period + self.min_interval = period / calls + self.last_call_time: float = 0.0 + self.lock = threading.Lock() + + def acquire(self) -> None: + """ + Acquire permission to make an API call. + + This method will block if necessary to ensure calls are spaced + evenly according to the rate limit. + """ + with self.lock: + current_time = time.time() + time_since_last_call = current_time - self.last_call_time + + if time_since_last_call < self.min_interval: + sleep_time = self.min_interval - time_since_last_call + time.sleep(sleep_time) + + self.last_call_time = time.time() + + +# Global rate limiter instance +_rate_limiter = TokenBucketRateLimiter(calls=RATE_LIMIT_CALLS, period=RATE_LIMIT_PERIOD) class AQSAPI_V2: @@ -185,8 +238,6 @@ def get_request_time(self) -> str: return str(self._request_time) @no_type_check - @sleep_and_retry - @limits(calls=10, period=ONE_MINUTE) def __aqs( self, service: str | None = None, @@ -227,6 +278,9 @@ def __aqs( (AQSAPI_v2) An AQSAPI_V2 instance containing the data requested. """ + # Apply rate limiting before making the API call + _rate_limiter.acquire() + user_agent = "pyAQSAPI module for python3" # server = ":AQSDatamartAPI:" # check if either aqs_username or aqs_key are None @@ -300,8 +354,6 @@ def __aqs( ) else: warn(category=UserWarning, message="pyaqsapi experienced an error:" + f"{newline} {exception}") - # finally: - # self.__aqs_ratelimit() # use ratelimit package instead return self def _aqs_services_by_site( diff --git a/tests/test_helperfunctions.py b/tests/test_helperfunctions.py index 4f6abeb..eaad9fd 100644 --- a/tests/test_helperfunctions.py +++ b/tests/test_helperfunctions.py @@ -4,7 +4,7 @@ import pytest from pandas import DataFrame -from pyaqsapi.helperfunctions import aqs_credentials +from pyaqsapi.helperfunctions import aqs_credentials, TokenBucketRateLimiter, RATE_LIMIT_CALLS, RATE_LIMIT_PERIOD from pyaqsapi import listfunctions @@ -35,3 +35,13 @@ def setuppyaqsapi(autouse=True): def test_aqs_removeheader(setuppyaqsapi): returnvalue = listfunctions.aqs_knownissues(return_header=False) assert isinstance(returnvalue, DataFrame) + + +def test_rate_limiter_exists(): + """Test that TokenBucketRateLimiter class exists and has correct defaults.""" + limiter = TokenBucketRateLimiter() + assert limiter.calls == RATE_LIMIT_CALLS + assert limiter.period == RATE_LIMIT_PERIOD + assert limiter.min_interval == RATE_LIMIT_PERIOD / RATE_LIMIT_CALLS + # Test that acquire method exists and can be called + limiter.acquire() # Should not raise an exception diff --git a/tests/test_rate_limiter.py b/tests/test_rate_limiter.py new file mode 100644 index 0000000..532ebc8 --- /dev/null +++ b/tests/test_rate_limiter.py @@ -0,0 +1,94 @@ +"""Unit tests for TokenBucketRateLimiter.""" + +import time +import threading + +import pytest + +from pyaqsapi.helperfunctions import TokenBucketRateLimiter, RATE_LIMIT_CALLS, RATE_LIMIT_PERIOD + + +def test_rate_limiter_basic_functionality(): + """Test that rate limiter enforces minimum interval between calls.""" + limiter = TokenBucketRateLimiter(calls=10, period=60) + + start_time = time.time() + limiter.acquire() + first_call_time = time.time() + + limiter.acquire() + second_call_time = time.time() + + # Should have waited at least 6 seconds (60/10) between calls + elapsed = second_call_time - first_call_time + assert elapsed >= 5.9, f"Expected at least 6 seconds, got {elapsed}" # Allow small margin for timing + + +def test_rate_limiter_multiple_calls(): + """Test that rate limiter properly spaces multiple calls.""" + limiter = TokenBucketRateLimiter(calls=5, period=10) # 5 calls per 10 seconds = 2 seconds per call + + times = [] + for _ in range(3): + limiter.acquire() + times.append(time.time()) + + # Check intervals between calls + interval1 = times[1] - times[0] + interval2 = times[2] - times[1] + + # Each interval should be approximately 2 seconds (allow 0.1s margin) + assert interval1 >= 1.9, f"First interval should be ~2s, got {interval1}" + assert interval2 >= 1.9, f"Second interval should be ~2s, got {interval2}" + + +def test_rate_limiter_thread_safety(): + """Test that rate limiter is thread-safe.""" + limiter = TokenBucketRateLimiter(calls=10, period=10) # 10 calls per 10 seconds = 1 second per call + + call_times = [] + lock = threading.Lock() + + def make_call(): + limiter.acquire() + with lock: + call_times.append(time.time()) + + # Create 5 threads that all try to make calls simultaneously + threads = [threading.Thread(target=make_call) for _ in range(5)] + + start_time = time.time() + for thread in threads: + thread.start() + + for thread in threads: + thread.join() + + end_time = time.time() + + # All calls should complete, and they should be spaced out + assert len(call_times) == 5 + # Total time should be at least 4 seconds (5 calls with 1s intervals) + assert (end_time - start_time) >= 3.9, f"Expected at least 4 seconds for 5 calls, got {end_time - start_time}" + + +def test_rate_limiter_default_values(): + """Test that rate limiter uses correct default values.""" + limiter = TokenBucketRateLimiter() + + assert limiter.calls == RATE_LIMIT_CALLS + assert limiter.period == RATE_LIMIT_PERIOD + assert limiter.min_interval == RATE_LIMIT_PERIOD / RATE_LIMIT_CALLS + + +def test_rate_limiter_first_call_immediate(): + """Test that first call doesn't wait.""" + limiter = TokenBucketRateLimiter(calls=10, period=60) + + start_time = time.time() + limiter.acquire() + elapsed = time.time() - start_time + + # First call should be immediate (less than 0.1 seconds) + assert elapsed < 0.1, f"First call should be immediate, took {elapsed} seconds" +