diff --git a/posttroll/publisher.py b/posttroll/publisher.py index 501e0fd..36276e8 100644 --- a/posttroll/publisher.py +++ b/posttroll/publisher.py @@ -28,6 +28,8 @@ import socket from datetime import datetime, timedelta from threading import Lock +import time + from six.moves.urllib.parse import urlsplit, urlunsplit import six import zmq @@ -37,6 +39,8 @@ from posttroll.message_broadcaster import sendaddressservice LOGGER = logging.getLogger(__name__) +BIND_RETRIES = max(0, int(os.environ.get("PYTROLL_BIND_RETRIES", 5))) +BIND_RETRY_TIMEOUT = float(os.environ.get("PYTROLL_BIND_RETRY_TIMEOUT", 0.1)) def get_own_ip(): @@ -121,7 +125,7 @@ def __init__(self, address, name="", min_port=None, max_port=None): self.destination = urlunsplit((u__.scheme, netloc, u__.path, u__.query, u__.fragment)) else: - self.publish.bind(self.destination) + self._bind_destination() self.port_number = port LOGGER.info("publisher started on port %s", str(self.port_number)) @@ -130,6 +134,18 @@ def __init__(self, address, name="", min_port=None, max_port=None): self._heartbeat = None self._pub_lock = Lock() + def _bind_destination(self): + """Bind publish destination.""" + last_error = "" + for _ in range(BIND_RETRIES + 1): + try: + self.publish.bind(self.destination) + return + except zmq.error.ZMQError as err: + last_error = err.strerror + time.sleep(BIND_RETRY_TIMEOUT) + raise OSError("Could not bind %s - %s" % (self.destination, last_error)) + def send(self, msg): """Send the given message. """ diff --git a/posttroll/tests/test_pubsub.py b/posttroll/tests/test_pubsub.py index 0b842df..bbc9f35 100644 --- a/posttroll/tests/test_pubsub.py +++ b/posttroll/tests/test_pubsub.py @@ -1,27 +1,27 @@ #!/usr/bin/env python # -*- coding: utf-8 -*- - +# # Copyright (c) 2014 Martin Raspaud - +# # Author(s): - +# # Martin Raspaud - +# Panu Lahtinen +# # This program is free software: you can redistribute it and/or modify # it under the terms of the GNU General Public License as published by # the Free Software Foundation, either version 3 of the License, or # (at your option) any later version. - +# # This program is distributed in the hope that it will be useful, # but WITHOUT ANY WARRANTY; without even the implied warranty of # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the # GNU General Public License for more details. - +# # You should have received a copy of the GNU General Public License # along with this program. If not, see . -"""Test the publishing and subscribing facilities. -""" +"""Test the publishing and subscribing facilities.""" import unittest from unittest import mock from datetime import timedelta @@ -29,6 +29,7 @@ import time import six +import pytest test_lock = Lock() @@ -314,6 +315,56 @@ def test_pub_minmax_port(self): self.assertEqual(res, port) break + @mock.patch("posttroll.publisher.get_context") + def test_bind_retries(self, get_context): + """Test that the destination bind is retried on failure.""" + from zmq.error import ZMQError + from posttroll.publisher import Publish, BIND_RETRIES + context = mock.MagicMock() + context.bind.side_effect = ZMQError("mocked failure") + get_context.return_value.socket.return_value = context + + with pytest.raises(OSError) as err: + with Publish("test_bind_retries", port=50000): + pass + assert context.bind.call_count == BIND_RETRIES + 1 + assert "Could not bind" in err.value.args[0] + assert "50000" in err.value.args[0] + + @mock.patch("posttroll.publisher.BIND_RETRIES", 0) + @mock.patch("posttroll.publisher.get_context") + def test_bind_no_retries(self, get_context): + """Test that the destination bind retries can be turned off.""" + from zmq.error import ZMQError + from posttroll.publisher import Publish, BIND_RETRIES + # Just ensure the mock sets the variable correctly + assert BIND_RETRIES == 0 + context = mock.MagicMock() + context.bind.side_effect = ZMQError("mocked failure") + get_context.return_value.socket.return_value = context + + with pytest.raises(OSError) as err: + with Publish("test_bind_retries", port=50000): + pass + assert context.bind.call_count == 1 + + def test_bind_retries_env_variable(self): + """Test that the retry count env variable is handled correctly.""" + import os + os.environ["PYTROLL_BIND_RETRIES"] = "-1" + from posttroll.publisher import BIND_RETRIES + + assert BIND_RETRIES == 0 + + def test_bind_retry_timeout_env_variable(self): + """Test that the retry timeout env variable is handled correctly.""" + import os + val = 0.3 + os.environ["PYTROLL_BIND_RETRY_TIMEOUT"] = str(val) + from posttroll.publisher import BIND_RETRY_TIMEOUT + + assert BIND_RETRY_TIMEOUT == val + def _get_port(min_port=None, max_port=None): from zmq.error import ZMQError @@ -400,4 +451,6 @@ def suite(): mysuite.addTest(loader.loadTestsFromTestCase(TestListenerContainer)) mysuite.addTest(loader.loadTestsFromTestCase(TestPub)) mysuite.addTest(loader.loadTestsFromTestCase(TestAddressReceiver)) + mysuite.addTest(loader.loadTestsFromTestCase(TestBindRetryEnvVariable)) + mysuite.addTest(loader.loadTestsFromTestCase(TestBindRetryTimeoutEnvVariable)) return mysuite