Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 20 additions & 1 deletion trollmoves/movers.py
Original file line number Diff line number Diff line change
Expand Up @@ -358,6 +358,7 @@ def move(self):
def copy(self):
"""Upload the file."""
from scp import SCPClient
from paramiko import SSHException

ssh_connection = self.get_connection(self.destination.hostname,
self.destination.port or 22,
Expand All @@ -371,7 +372,23 @@ def copy(self):
raise

try:
scp.put(self.origin, self.destination.path)
destination = self.destination.path
remote_tmp = self.attrs.get("remote_tmp", None)
if remote_tmp:
destination = os.path.join(destination, '.' + os.path.basename(self.origin))
scp.put(self.origin, destination)

if remote_tmp:
timeout = self.attrs.get("ssh_connection_timeout", None)
_remote_orig = os.path.join(self.destination.path, os.path.basename(self.origin))
_cmd = f"mv {destination} {_remote_orig}"
(_, out_ret, err_ret) = ssh_connection.exec_command(_cmd, timeout=timeout)
out_lines = out_ret.readlines()
for line in out_lines:
LOGGER.debug("Remote rename stdout: %s ", str(line))
err_lines = err_ret.readlines()
for line in err_lines:
LOGGER.error("Remote rename stderr: %s ", str(line))
except OSError as osex:
if osex.errno == 2:
LOGGER.error("No such file or directory. File not transfered: "
Expand All @@ -380,6 +397,8 @@ def copy(self):
else:
LOGGER.error("OSError in scp.put: %s", str(osex))
raise
except SSHException as sshe:
LOGGER.exception("Failed to rename from tmp name: %s", str(sshe))
except Exception as err:
LOGGER.error("Something went wrong with scp: %s", str(err))
LOGGER.error("Exception name %s", type(err).__name__)
Expand Down
89 changes: 87 additions & 2 deletions trollmoves/tests/test_ssh_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,9 @@
# along with this program. If not, see <http://www.gnu.org/licenses/>.
"""Test the ssh server."""

import os
import sys
import logging
import shutil
from unittest.mock import Mock, MagicMock, patch
import unittest
Expand All @@ -30,14 +33,24 @@

from paramiko import SSHException
import pytest
import logging
import socket
import sys

import trollmoves

logger = logging.getLogger()


class MockChannel:
def __init__(self, content=None):
self.content = [] if not content else [content]

def __str__(self) -> str:
return str(self.content)

def readlines(self):
return self.content


class TestSSHMovers(unittest.TestCase):
"""Tests for SSH Mover."""

Expand Down Expand Up @@ -280,6 +293,78 @@ def test_scp_move(self, mock_scp_client, mock_sshclient):

mocked_scp_client.put.assert_called_once_with(self.origin, urlparse(self.destination_no_port).path)

@patch('trollmoves.movers.ScpMover.get_connection')
@patch('paramiko.SSHClient.connect')
@patch('scp.SCPClient', autospec=True)
def test_scp_copy_via_remote_tmp2(self, mock_scp_client, mock_sshconnect, mock_sshexec):
"""Check scp copy using remote temporary file."""
from trollmoves.movers import ScpMover

mocked_scp_client = MagicMock()
mock_scp_client.return_value = mocked_scp_client
mock_sshexec.return_value.exec_command.return_value = [(None), (MockChannel()), (MockChannel())]
scp_mover = ScpMover(self.origin, self.destination_no_port, attrs={'remote_tmp': True})
scp_mover.copy()

tmp_bn = os.path.join(urlparse(self.destination_no_port).path,
"." + os.path.basename(self.origin))
mocked_scp_client.put.assert_called_once_with(self.origin, tmp_bn)
final_remote = os.path.join(urlparse(self.destination_no_port).path,
os.path.basename(self.origin))
_cmd = f"mv {tmp_bn} {final_remote}"
mock_sshexec.return_value.exec_command.assert_called_once_with(_cmd, timeout=None)

@patch('trollmoves.movers.ScpMover.get_connection')
@patch('paramiko.SSHClient.connect')
@patch('scp.SCPClient', autospec=True)
def test_scp_copy_via_remote_tmp_return_values(self, mock_scp_client, mock_sshconnect, mock_sshexec):
"""Check scp copy using remote temporary file."""
from trollmoves.movers import ScpMover
stream_handler = logging.StreamHandler(sys.stdout)
logger.addHandler(stream_handler)
logger.setLevel(logging.INFO)

mocked_scp_client = MagicMock()
mock_scp_client.return_value = mocked_scp_client
mock_sshexec.return_value.exec_command.return_value = [(None), (MockChannel("stdout")), (MockChannel("stderr"))]
try:
with self.assertLogs(logger, level=logging.DEBUG) as lc:
scp_mover = ScpMover(self.origin, self.destination_no_port, attrs={'remote_tmp': True})
scp_mover.copy()
self.assertIn("Remote rename stdout: stdout", "".join(lc.output))
self.assertIn("Remote rename stderr: stderr", "".join(lc.output))
finally:
logger.removeHandler(stream_handler)

tmp_bn = os.path.join(urlparse(self.destination_no_port).path,
"." + os.path.basename(self.origin))
mocked_scp_client.put.assert_called_once_with(self.origin, tmp_bn)
final_remote = os.path.join(urlparse(self.destination_no_port).path,
os.path.basename(self.origin))
_cmd = f"mv {tmp_bn} {final_remote}"
mock_sshexec.return_value.exec_command.assert_called_once_with(_cmd, timeout=None)

@patch('trollmoves.movers.ScpMover.get_connection')
@patch('paramiko.SSHClient.connect')
@patch('scp.SCPClient', autospec=True)
def test_scp_copy_via_remote_tmp_exception(self, mock_scp_client, mock_sshconnect, mock_sshexec):
"""Check scp copy using remote temporary file."""
from trollmoves.movers import ScpMover
stream_handler = logging.StreamHandler(sys.stdout)
logger.addHandler(stream_handler)
logger.setLevel(logging.INFO)

mocked_scp_client = MagicMock()
mock_scp_client.return_value = mocked_scp_client
mock_sshexec.return_value.exec_command.side_effect = MagicMock(side_effect=SSHException)
try:
with self.assertLogs(logger, level=logging.DEBUG) as lc:
scp_mover = ScpMover(self.origin, self.destination_no_port, attrs={'remote_tmp': True})
scp_mover.copy()
self.assertIn("Failed to rename from tmp name:", "".join(lc.output))
finally:
logger.removeHandler(stream_handler)


if __name__ == '__main__':
unittest.main()