Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 17 additions & 1 deletion posttroll/publisher.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,8 @@
import socket
from datetime import datetime, timedelta
from threading import Lock
import time

from six.moves.urllib.parse import urlsplit, urlunsplit
import six
import zmq
Expand All @@ -37,6 +39,8 @@
from posttroll.message_broadcaster import sendaddressservice

LOGGER = logging.getLogger(__name__)
BIND_RETRIES = max(0, int(os.environ.get("PYTROLL_BIND_RETRIES", 5)))
BIND_RETRY_TIMEOUT = float(os.environ.get("PYTROLL_BIND_RETRY_TIMEOUT", 0.1))


def get_own_ip():
Expand Down Expand Up @@ -121,7 +125,7 @@ def __init__(self, address, name="", min_port=None, max_port=None):
self.destination = urlunsplit((u__.scheme, netloc, u__.path,
u__.query, u__.fragment))
else:
self.publish.bind(self.destination)
self._bind_destination()
self.port_number = port

LOGGER.info("publisher started on port %s", str(self.port_number))
Expand All @@ -130,6 +134,18 @@ def __init__(self, address, name="", min_port=None, max_port=None):
self._heartbeat = None
self._pub_lock = Lock()

def _bind_destination(self):
"""Bind publish destination."""
last_error = ""
for _ in range(BIND_RETRIES + 1):
try:
self.publish.bind(self.destination)
return
except zmq.error.ZMQError as err:
last_error = err.strerror
time.sleep(BIND_RETRY_TIMEOUT)
raise OSError("Could not bind %s - %s" % (self.destination, last_error))

def send(self, msg):
"""Send the given message.
"""
Expand Down
69 changes: 61 additions & 8 deletions posttroll/tests/test_pubsub.py
Original file line number Diff line number Diff line change
@@ -1,34 +1,35 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-

#
# Copyright (c) 2014 Martin Raspaud

#
# Author(s):

#
# Martin Raspaud <martin.raspaud@smhi.se>

# Panu Lahtinen <panu.lahtinen@fmi.fi>
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.

#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.

#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.

"""Test the publishing and subscribing facilities.
"""
"""Test the publishing and subscribing facilities."""
import unittest
from unittest import mock
from datetime import timedelta
from threading import Thread, Lock
import time

import six
import pytest

test_lock = Lock()

Expand Down Expand Up @@ -314,6 +315,56 @@ def test_pub_minmax_port(self):
self.assertEqual(res, port)
break

@mock.patch("posttroll.publisher.get_context")
def test_bind_retries(self, get_context):
"""Test that the destination bind is retried on failure."""
from zmq.error import ZMQError
from posttroll.publisher import Publish, BIND_RETRIES
context = mock.MagicMock()
context.bind.side_effect = ZMQError("mocked failure")
get_context.return_value.socket.return_value = context

with pytest.raises(OSError) as err:
with Publish("test_bind_retries", port=50000):
pass
assert context.bind.call_count == BIND_RETRIES + 1
assert "Could not bind" in err.value.args[0]
assert "50000" in err.value.args[0]

@mock.patch("posttroll.publisher.BIND_RETRIES", 0)
@mock.patch("posttroll.publisher.get_context")
def test_bind_no_retries(self, get_context):
"""Test that the destination bind retries can be turned off."""
from zmq.error import ZMQError
from posttroll.publisher import Publish, BIND_RETRIES
# Just ensure the mock sets the variable correctly
assert BIND_RETRIES == 0
context = mock.MagicMock()
context.bind.side_effect = ZMQError("mocked failure")
get_context.return_value.socket.return_value = context

with pytest.raises(OSError) as err:
with Publish("test_bind_retries", port=50000):
pass
assert context.bind.call_count == 1

def test_bind_retries_env_variable(self):
"""Test that the retry count env variable is handled correctly."""
import os
os.environ["PYTROLL_BIND_RETRIES"] = "-1"
from posttroll.publisher import BIND_RETRIES

assert BIND_RETRIES == 0

def test_bind_retry_timeout_env_variable(self):
"""Test that the retry timeout env variable is handled correctly."""
import os
val = 0.3
os.environ["PYTROLL_BIND_RETRY_TIMEOUT"] = str(val)
from posttroll.publisher import BIND_RETRY_TIMEOUT

assert BIND_RETRY_TIMEOUT == val


def _get_port(min_port=None, max_port=None):
from zmq.error import ZMQError
Expand Down Expand Up @@ -400,4 +451,6 @@ def suite():
mysuite.addTest(loader.loadTestsFromTestCase(TestListenerContainer))
mysuite.addTest(loader.loadTestsFromTestCase(TestPub))
mysuite.addTest(loader.loadTestsFromTestCase(TestAddressReceiver))
mysuite.addTest(loader.loadTestsFromTestCase(TestBindRetryEnvVariable))
mysuite.addTest(loader.loadTestsFromTestCase(TestBindRetryTimeoutEnvVariable))
return mysuite