Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
52 changes: 24 additions & 28 deletions main.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@
import time
import re
import concurrent.futures
import threading
import ipaddress
import socket
from functools import lru_cache
Expand Down Expand Up @@ -357,23 +356,19 @@

def get_all_existing_rules(client: httpx.Client, profile_id: str) -> Set[str]:
all_rules = set()
all_rules_lock = threading.Lock()

def _fetch_folder_rules(folder_id: str):
def _fetch_folder_rules(folder_id: str) -> List[str]:
try:
data = _api_get(client, f"{API_BASE}/{profile_id}/rules/{folder_id}").json()
folder_rules = data.get("body", {}).get("rules", [])
# Optimization: Extract PKs locally to minimize lock contention time
local_pks = [rule["PK"] for rule in folder_rules if rule.get("PK")]
if local_pks:
with all_rules_lock:
all_rules.update(local_pks)
return [rule["PK"] for rule in folder_rules if rule.get("PK")]
except httpx.HTTPError:
pass
return []
except Exception as e:
# Log the error but don't stop the whole process;
# an individual folder failure shouldn't crash the sync.
log.warning(f"Error fetching rules for folder {folder_id}: {e}")
return []

try:
# Get rules from root
Expand All @@ -392,11 +387,19 @@
# Parallelize fetching rules from folders.
# Using 5 workers to be safe with rate limits, though GETs are usually cheaper.
with concurrent.futures.ThreadPoolExecutor(max_workers=5) as executor:
futures = [
executor.submit(_fetch_folder_rules, folder_id)
future_to_folder = {
executor.submit(_fetch_folder_rules, folder_id): folder_id
for folder_name, folder_id in folders.items()
]
concurrent.futures.wait(futures)
}

for future in concurrent.futures.as_completed(future_to_folder):
try:
result = future.result()
if result:
all_rules.update(result)
except Exception as e:
folder_id = future_to_folder[future]
log.warning(f"Failed to fetch rules for folder ID {folder_id}: {e}")

Check warning

Code scanning / Prospector (reported by Codacy)

Use lazy % formatting in logging functions (logging-fstring-interpolation) Warning

Use lazy % formatting in logging functions (logging-fstring-interpolation)

Check notice

Code scanning / Pylintpython3 (reported by Codacy)

Use lazy % formatting in logging functions Note

Use lazy % formatting in logging functions

log.info(f"Total existing rules across all folders: {len(all_rules)}")
return all_rules
Expand Down Expand Up @@ -433,7 +436,7 @@
if USE_COLORS:
sys.stderr.write(f"\r{Colors.CYAN}⏳ Warming up cache: {completed}/{total}...{Colors.ENDC}")
sys.stderr.flush()

Check notice

Code scanning / Pylintpython3 (reported by Codacy)

Catching too general exception Exception Note

Catching too general exception Exception
try:
future.result()
except Exception as e:
Expand Down Expand Up @@ -489,7 +492,7 @@
# Check if it returned a list containing our group
if isinstance(body, dict) and "groups" in body:
for grp in body["groups"]:
if grp.get("group") == name:

Check notice

Code scanning / Pylint (reported by Codacy)

Catching too general exception Exception Note

Catching too general exception Exception
log.info("Created folder %s (ID %s) [Direct]", sanitize_for_log(name), grp["PK"])
return str(grp["PK"])
except Exception as e:
Expand All @@ -502,7 +505,7 @@
groups = data.get("body", {}).get("groups", [])

for grp in groups:
if grp["group"].strip() == name.strip():

Check warning

Code scanning / Pylintpython3 (reported by Codacy)

Variable name "e" doesn't conform to snake_case naming style Warning

Variable name "e" doesn't conform to snake_case naming style
log.info("Created folder %s (ID %s) [Polled]", sanitize_for_log(name), grp["PK"])
return str(grp["PK"])
except Exception as e:
Expand All @@ -529,7 +532,6 @@
hostnames: List[str],
existing_rules: Set[str],
client: httpx.Client,
existing_rules_lock: Optional[threading.Lock] = None,
) -> bool:
if not hostnames:
log.info("Folder %s - no rules to push", sanitize_for_log(folder_name))
Expand Down Expand Up @@ -564,7 +566,7 @@

total_batches = len(batches)

def process_batch(batch_idx: int, batch_data: List[str]) -> bool:
def process_batch(batch_idx: int, batch_data: List[str]) -> Optional[List[str]]:

Check warning

Code scanning / Pylint (reported by Codacy)

Missing function docstring Warning

Missing function docstring
data = {
"do": str(do),
"status": str(status),
Expand All @@ -581,19 +583,14 @@
"Folder %s – batch %d: added %d rules",
sanitize_for_log(folder_name), batch_idx, len(batch_data)
)
if existing_rules_lock:
with existing_rules_lock:
existing_rules.update(batch_data)
else:
existing_rules.update(batch_data)
return True
return batch_data
except httpx.HTTPError as e:
if USE_COLORS:
sys.stderr.write("\n")
log.error(f"Failed to push batch {batch_idx} for folder {sanitize_for_log(folder_name)}: {sanitize_for_log(e)}")
if hasattr(e, 'response') and e.response is not None:
log.debug(f"Response content: {e.response.text}")
return False
return None

# Optimization 3: Parallelize batch processing
# Using 3 workers to speed up writes without hitting aggressive rate limits.
Expand All @@ -604,8 +601,10 @@
}

for future in concurrent.futures.as_completed(futures):
if future.result():
result = future.result()
if result:
successful_batches += 1
existing_rules.update(result)

if USE_COLORS:
sys.stderr.write(f"\r{Colors.CYAN}🚀 Folder {sanitize_for_log(folder_name)}: Pushing batch {successful_batches}/{total_batches}...{Colors.ENDC}")
Expand All @@ -626,7 +625,6 @@
folder_data: Dict[str, Any],
profile_id: str,
existing_rules: Set[str],
existing_rules_lock: threading.Lock,
client: httpx.Client,
) -> bool:
grp = folder_data["group"]
Expand All @@ -647,11 +645,11 @@
do = action.get("do", 0)
status = action.get("status", 1)
hostnames = [r["PK"] for r in rule_group.get("rules", []) if r.get("PK")]
if not push_rules(profile_id, name, folder_id, do, status, hostnames, existing_rules, client, existing_rules_lock):
if not push_rules(profile_id, name, folder_id, do, status, hostnames, existing_rules, client):

Check warning

Code scanning / Pylintpython3 (reported by Codacy)

Line too long (106/100) Warning

Line too long (106/100)

Check warning

Code scanning / Pylint (reported by Codacy)

Line too long (106/100) Warning

Line too long (106/100)
folder_success = False
else:
hostnames = [r["PK"] for r in folder_data.get("rules", []) if r.get("PK")]
if not push_rules(profile_id, name, folder_id, main_do, main_status, hostnames, existing_rules, client, existing_rules_lock):
if not push_rules(profile_id, name, folder_id, main_do, main_status, hostnames, existing_rules, client):

Check warning

Code scanning / Pylintpython3 (reported by Codacy)

Line too long (112/100) Warning

Line too long (112/100)

Check warning

Code scanning / Pylint (reported by Codacy)

Line too long (112/100) Warning

Line too long (112/100)
folder_success = False

return folder_success
Expand Down Expand Up @@ -734,7 +732,6 @@

# Create new folders and push rules
success_count = 0
existing_rules_lock = threading.Lock()

# CRITICAL FIX: Switch to Serial Processing (1 worker)
# This prevents API rate limits and ensures stability for large folders.
Expand Down Expand Up @@ -767,7 +764,6 @@
folder_data,
profile_id,
existing_rules,
existing_rules_lock,
client # Pass the persistent client
): folder_data
for folder_data in folder_data_list
Expand All @@ -789,7 +785,7 @@
log.error(f"Unexpected error during sync for profile {profile_id}: {sanitize_for_log(e)}")
return False

# --------------------------------------------------------------------------- #

Check warning

Code scanning / Pylint (reported by Codacy)

Variable name "e" doesn't conform to snake_case naming style Warning

Variable name "e" doesn't conform to snake_case naming style
# 5. Entry-point
# --------------------------------------------------------------------------- #
def parse_args() -> argparse.Namespace:
Expand Down
40 changes: 24 additions & 16 deletions test_main.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
import os
import sys
import importlib
import threading
from unittest.mock import MagicMock, call, patch
import pytest
import main
Expand Down Expand Up @@ -32,7 +31,7 @@
m = reload_main_with_env(monkeypatch, no_color=None, isatty=False)
assert m.USE_COLORS is False

# Case 2: get_all_existing_rules updates all_rules set correctly with locking optimization
# Case 2: get_all_existing_rules updates all_rules set correctly without locking
def test_get_all_existing_rules_updates_correctly(monkeypatch):
# Setup
m = reload_main_with_env(monkeypatch, no_color="1") # Disable colors for simplicity
Expand All @@ -56,26 +55,12 @@

monkeypatch.setattr(m, "_api_get", side_effect)

# Spy on threading.Lock
mock_lock_instance = MagicMock()
mock_lock_instance.__enter__.return_value = None
mock_lock_instance.__exit__.return_value = None
mock_lock_cls = MagicMock(return_value=mock_lock_instance)
monkeypatch.setattr(threading, "Lock", mock_lock_cls)

# Execution
rules = m.get_all_existing_rules(mock_client, profile_id)

# Verification
expected_rules = {"rule_root", "rule_A1", "rule_A2", "rule_B1"}
assert rules == expected_rules

# Verify lock usage.
# Since get_all_existing_rules creates a lock: `all_rules_lock = threading.Lock()`
# and then uses `with all_rules_lock:` inside the worker `_fetch_folder_rules`.
# We expect the lock to be acquired.
assert mock_lock_cls.called
assert mock_lock_instance.__enter__.called

# Case 3: push_rules updates data dictionary with pre-calculated batch keys correctly
def test_push_rules_updates_data_with_batch_keys(monkeypatch):
Expand All @@ -94,7 +79,7 @@
folder_id="fid1",
do=1,
status=1,
hostnames=hostnames,

Check warning

Code scanning / Pylintpython3 (reported by Codacy)

Variable name "m" doesn't conform to snake_case naming style Warning test

Variable name "m" doesn't conform to snake_case naming style
existing_rules=set(),
client=mock_client
)
Expand All @@ -111,6 +96,29 @@
assert data_sent["do"] == "1"
assert data_sent["group"] == "fid1"

# Case 3b: push_rules updates existing_rules set correctly
def test_push_rules_updates_existing_rules(monkeypatch):

Check warning

Code scanning / Pylintpython3 (reported by Codacy)

Missing function or method docstring Warning test

Missing function or method docstring

Check warning

Code scanning / Pylint (reported by Codacy)

Missing function docstring Warning test

Missing function docstring
m = reload_main_with_env(monkeypatch)

Check warning

Code scanning / Pylint (reported by Codacy)

Variable name "m" doesn't conform to snake_case naming style Warning test

Variable name "m" doesn't conform to snake_case naming style
mock_client = MagicMock()
monkeypatch.setattr(m, "_api_post_form", MagicMock())

hostnames = ["h1", "h2"]
existing_rules = set()

m.push_rules(
profile_id="p1",
folder_name="f1",
folder_id="fid1",
do=1,
status=1,
hostnames=hostnames,
existing_rules=existing_rules,
client=mock_client
)

assert "h1" in existing_rules

Check notice

Code scanning / Bandit

Use of assert detected. The enclosed code will be removed when compiling to optimised byte code. Note test

Use of assert detected. The enclosed code will be removed when compiling to optimised byte code.

Check notice

Code scanning / Bandit (reported by Codacy)

Use of assert detected. The enclosed code will be removed when compiling to optimised byte code. Note test

Use of assert detected. The enclosed code will be removed when compiling to optimised byte code.
assert "h2" in existing_rules

Check notice

Code scanning / Bandit

Use of assert detected. The enclosed code will be removed when compiling to optimised byte code. Note test

Use of assert detected. The enclosed code will be removed when compiling to optimised byte code.

Check notice

Code scanning / Bandit (reported by Codacy)

Use of assert detected. The enclosed code will be removed when compiling to optimised byte code. Note test

Use of assert detected. The enclosed code will be removed when compiling to optimised byte code.

# Case 4: push_rules logs info conditionally based on USE_COLORS flag
def test_push_rules_logs_conditionally_use_colors(monkeypatch):
# Test when USE_COLORS is False
Expand Down
Loading