Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 28 additions & 1 deletion aw_client/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import os
import socket
import threading
import warnings
from collections import namedtuple
from datetime import datetime
from time import sleep
Expand Down Expand Up @@ -100,6 +101,7 @@ def __init__(
self.request_queue = RequestQueue(self)
# Dict of each last heartbeat in each bucket
self.last_heartbeat = {} # type: Dict[str, Event]
self._warned_queue_before_connect = False

#
# Get/Post base requests
Expand Down Expand Up @@ -243,6 +245,7 @@ def heartbeat(
_commit_interval = commit_interval or self.commit_interval

if queued:
self._warn_queue_before_connect()
# Pre-merge heartbeats
if bucket_id not in self.last_heartbeat:
self.last_heartbeat[bucket_id] = event
Expand Down Expand Up @@ -278,6 +281,7 @@ def get_buckets(self) -> dict:

def create_bucket(self, bucket_id: str, event_type: str, queued=False):
if queued:
self._warn_queue_before_connect()
self.request_queue.register_bucket(bucket_id, event_type)
else:
endpoint = f"buckets/{bucket_id}"
Expand Down Expand Up @@ -380,6 +384,8 @@ def disconnect(self):

# Throw away old thread object, create new one since same thread cannot be started twice
self.request_queue = RequestQueue(self)
# Reset so warn-before-connect fires again if user calls queued ops before reconnecting
self._warned_queue_before_connect = False

def wait_for_start(self, timeout: int = 10) -> None:
"""Wait for the server to start by trying to get the server info."""
Expand All @@ -395,6 +401,18 @@ def wait_for_start(self, timeout: int = 10) -> None:
else:
raise Exception(f"Server at {self.server_address} did not start in time")

def _warn_queue_before_connect(self) -> None:
if self._warned_queue_before_connect or self.request_queue.is_alive():
return

warnings.warn(
"Queued requests require calling connect() or using `with client:` "
"before buckets can be created and queued events can flush.",
UserWarning,
stacklevel=3,
)
self._warned_queue_before_connect = True


QueuedRequest = namedtuple("QueuedRequest", ["endpoint", "data"])
Bucket = namedtuple("Bucket", ["id", "type"])
Expand Down Expand Up @@ -554,4 +572,13 @@ def add_request(self, endpoint: str, data: dict) -> None:
self._persistqueue.put(QueuedRequest(endpoint, data))

def register_bucket(self, bucket_id: str, event_type: str) -> None:
self._registered_buckets.append(Bucket(bucket_id, event_type))
bucket = Bucket(bucket_id, event_type)
self._registered_buckets.append(bucket)

if not self.connected:
return

try:
self.client.create_bucket(bucket_id, event_type)
except req.RequestException:
self.connected = False
22 changes: 22 additions & 0 deletions tests/test_client.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
#!/usr/bin/env python3
import time
import warnings
from random import random
from datetime import datetime, timedelta, timezone
from requests.exceptions import HTTPError
Expand Down Expand Up @@ -106,3 +107,24 @@ def test_full():

# Delete bucket
client.delete_bucket(bucket_name)


def test_queued_usage_warns_once_before_connect():
client = ActivityWatchClient(f"aw-test-client-{random()}", testing=True)

with warnings.catch_warnings(record=True) as caught:
warnings.simplefilter("always")
client.create_bucket("test-bucket", "test", queued=True)
client.heartbeat(
"test-bucket",
create_unique_event(),
pulsetime=1,
queued=True,
)

queue_warnings = [
warning
for warning in caught
if "connect()" in str(warning.message) or "with client:" in str(warning.message)
]
assert len(queue_warnings) == 1
28 changes: 28 additions & 0 deletions tests/test_requestqueue.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,12 +21,14 @@ class MockClient:

def __init__(self):
self.testing = True
self.create_bucket_calls = []

def get_buckets(self, *args, **kwargs):
print("Called get_buckets")
return [{"id": "test", "name": "Test"}]

def create_bucket(self, *args, **kwargs):
self.create_bucket_calls.append((args, kwargs))
print("Called create_bucket")

def _post(self, *args, **kwargs):
Expand Down Expand Up @@ -60,3 +62,29 @@ def test_complex():
sleep(1)
rq.stop()
rq.join()


def test_register_bucket_creates_immediately_when_connected():
client = MockClient()
rq = RequestQueue(client) # type: ignore
rq.connected = True

rq.register_bucket("test-bucket", "test-type")

assert client.create_bucket_calls == [(("test-bucket", "test-type"), {})]


def test_register_bucket_marks_queue_disconnected_on_create_failure():
class FailingClient(MockClient):
def create_bucket(self, *args, **kwargs):
super().create_bucket(*args, **kwargs)
raise requests.exceptions.ConnectionError()

client = FailingClient()
rq = RequestQueue(client) # type: ignore
rq.connected = True

rq.register_bucket("test-bucket", "test-type")

assert rq.connected is False
assert client.create_bucket_calls == [(("test-bucket", "test-type"), {})]
Loading