diff --git a/main.py b/main.py
index e6aabc5..f10c5b7 100644
--- a/main.py
+++ b/main.py
@@ -359,8 +359,8 @@ def _fetch_folder_rules(folder_id: str):
     folders = list_existing_folders(client, profile_id)
 
     # Parallelize fetching rules from folders.
-    # Using 5 workers to be safe with rate limits, though GETs are usually cheaper.
-    with concurrent.futures.ThreadPoolExecutor(max_workers=5) as executor:
+    # Using 10 workers to speed up rule fetching (GETs are cheaper).
+    with concurrent.futures.ThreadPoolExecutor(max_workers=10) as executor:
         futures = [
             executor.submit(_fetch_folder_rules, folder_id)
             for folder_name, folder_id in folders.items()
@@ -489,11 +489,11 @@ def push_rules(
         log.info(f"Folder {sanitize_for_log(folder_name)} - no new rules to push after filtering duplicates")
         return True
 
-    successful_batches = 0
-    total_batches = len(range(0, len(filtered_hostnames), BATCH_SIZE))
+    total_batches = (len(filtered_hostnames) + BATCH_SIZE - 1) // BATCH_SIZE
 
-    for i, start in enumerate(range(0, len(filtered_hostnames), BATCH_SIZE), 1):
-        batch = filtered_hostnames[start : start + BATCH_SIZE]
+    # Helper for processing a single batch
+    def _push_batch(batch_idx: int, start_idx: int) -> bool:
+        batch = filtered_hostnames[start_idx : start_idx + BATCH_SIZE]
         data = {
             "do": str(do),
             "status": str(status),
@@ -506,18 +506,47 @@ def push_rules(
             _api_post_form(client, f"{API_BASE}/{profile_id}/rules", data=data)
             log.info(
                 "Folder %s – batch %d: added %d rules",
-                sanitize_for_log(folder_name), i, len(batch)
+                sanitize_for_log(folder_name), batch_idx, len(batch)
             )
-            successful_batches += 1
+
+            # Thread-safe update of existing rules
             if existing_rules_lock:
                 with existing_rules_lock:
                     existing_rules.update(batch)
             else:
                 existing_rules.update(batch)
+            return True
         except httpx.HTTPError as e:
-            log.error(f"Failed to push batch {i} for folder {sanitize_for_log(folder_name)}: {sanitize_for_log(e)}")
+            log.error(f"Failed to push batch {batch_idx} for folder {sanitize_for_log(folder_name)}: {sanitize_for_log(e)}")
             if hasattr(e, 'response') and e.response is not None:
                 log.debug(f"Response content: {e.response.text}")
+            return False
+
+    # Parallelize batch uploads to speed up large folders
+    # Using 5 workers to be safe with POST rate limits while getting significant speedup
+    successful_batches = 0
+    with concurrent.futures.ThreadPoolExecutor(max_workers=5) as executor:
+        futures = {
+            executor.submit(_push_batch, i, start): i
+            for i, start in enumerate(range(0, len(filtered_hostnames), BATCH_SIZE), 1)
+        }
+
+        for future in concurrent.futures.as_completed(futures):
+            batch_idx = futures[future]
+            try:
+                result = future.result()
+            except Exception as e:
+                log.error(
+                    "Unexpected error while processing batch %d for folder %s: %s",
+                    batch_idx,
+                    sanitize_for_log(folder_name),
+                    sanitize_for_log(e),
+                )
+                log.debug("Unexpected exception details", exc_info=True)
+                continue
+
+            if result:
+                successful_batches += 1
     if successful_batches == total_batches:
         log.info("Folder %s – finished (%d new rules added)",
                  sanitize_for_log(folder_name), len(filtered_hostnames))
diff --git a/tests/test_performance.py b/tests/test_performance.py
new file mode 100644
index 0000000..4217799
--- /dev/null
+++ b/tests/test_performance.py
@@ -0,0 +1,118 @@
+import unittest
+from unittest.mock import MagicMock, patch
+import time
+import threading
+from main import push_rules, BATCH_SIZE
+import httpx
+
+class TestPushRulesPerformance(unittest.TestCase):
+    def setUp(self):
+        self.client = MagicMock()
+        self.profile_id = "test-profile"
+        self.folder_name = "test-folder"
+        self.folder_id = "test-folder-id"
+        self.do = 1
+        self.status = 1
+        self.existing_rules = set()
+
+    @patch('main._api_post_form')
+    def test_push_rules_parallel_with_lock(self, mock_post):
+        # Create enough hostnames for 5 batches
+        num_batches = 5
+        hostnames = [f"host-{i}.com" for i in range(BATCH_SIZE * num_batches)]
+
+        # Mock success
+        mock_post.return_value = MagicMock(status_code=200)
+
+        lock = threading.Lock()
+
+        start_time = time.time()
+        success = push_rules(
+            self.profile_id,
+            self.folder_name,
+            self.folder_id,
+            self.do,
+            self.status,
+            hostnames,
+            self.existing_rules,
+            self.client,
+            existing_rules_lock=lock
+        )
+        duration = time.time() - start_time
+
+        self.assertTrue(success)
+        self.assertEqual(mock_post.call_count, num_batches)
+        self.assertEqual(len(self.existing_rules), len(hostnames))
+
+        print(f"\n[Parallel with Lock] Duration: {duration:.4f}s")
+
+    @patch('main._api_post_form')
+    def test_push_rules_concurrency(self, mock_post):
+        # Create enough hostnames for 10 batches
+        num_batches = 10
+        hostnames = [f"host-{i}.com" for i in range(BATCH_SIZE * num_batches)]
+
+        # Mock delay to simulate network latency
+        def delayed_post(*args, **kwargs):
+            time.sleep(0.1)
+            return MagicMock(status_code=200)
+
+        mock_post.side_effect = delayed_post
+
+        start_time = time.time()
+        success = push_rules(
+            self.profile_id,
+            self.folder_name,
+            self.folder_id,
+            self.do,
+            self.status,
+            hostnames,
+            self.existing_rules,
+            self.client
+        )
+        duration = time.time() - start_time
+
+        self.assertTrue(success)
+        self.assertEqual(mock_post.call_count, num_batches)
+
+        print(f"\n[Performance Test] Duration for {num_batches} batches with 0.1s latency: {duration:.4f}s")
+
+    @patch('main._api_post_form')
+    def test_push_rules_partial_failure(self, mock_post):
+        # Create enough hostnames for 5 batches
+        num_batches = 5
+        hostnames = [f"host-{i}.com" for i in range(BATCH_SIZE * num_batches)]
+
+        # Mock failure for some batches
+        call_count = 0
+        def partial_failure(*args, **kwargs):
+            nonlocal call_count
+            call_count += 1
+            # Fail batches 2 and 4
+            if call_count in [2, 4]:
+                raise httpx.HTTPError("Simulated API failure")
+            return MagicMock(status_code=200)
+
+        mock_post.side_effect = partial_failure
+
+        success = push_rules(
+            self.profile_id,
+            self.folder_name,
+            self.folder_id,
+            self.do,
+            self.status,
+            hostnames,
+            self.existing_rules,
+            self.client
+        )
+
+        # Should return False when some batches fail
+        self.assertFalse(success)
+        self.assertEqual(mock_post.call_count, num_batches)
+        # Only 3 batches should have succeeded and updated existing_rules
+        self.assertEqual(len(self.existing_rules), BATCH_SIZE * 3)
+
+        print(f"\n[Partial Failure Test] {mock_post.call_count} batches attempted, 3 succeeded")
+
+if __name__ == '__main__':
+    unittest.main()