diff --git a/common/gratia/common/sandbox_mgmt.py b/common/gratia/common/sandbox_mgmt.py index 90906a9f..753c8218 100644 --- a/common/gratia/common/sandbox_mgmt.py +++ b/common/gratia/common/sandbox_mgmt.py @@ -5,6 +5,7 @@ import glob import time import shutil +import tempfile import tarfile from gratia.common.config import ConfigProxy @@ -409,18 +410,8 @@ def SearchOutstandingRecord(): def GenerateFilename(prefix, current_dir): '''Generate a filename of the for current_dir/prefix.$pid.ConfigFragment.gratia.xml__Unique''' - filename = prefix + str(global_state.RecordPid) + '.' + Config.get_GratiaExtension() \ - + '__XXXXXXXXXX' - filename = os.path.join(current_dir, filename) - mktemp_pipe = os.popen('mktemp -q "' + filename + '"') - if mktemp_pipe != None: - filename = mktemp_pipe.readline() - mktemp_pipe.close() - filename = filename.strip() - if filename != r'': - return filename - - raise IOError + fn_prefix = f'{prefix}.{global_state.RecordPid}.{Config.get_GratiaExtension()}__' + return tempfile.NamedTemporaryFile(prefix=fn_prefix, dir=current_dir, delete=False, mode='w') def UncompressOutbox(staging_name, target_dir): @@ -487,20 +478,21 @@ def CompressOutbox(probe_dir, outbox, outfiles): DebugPrint(0, msg + ':' + exc) raise InternalError(msg) from exc - staging_name = GenerateFilename('tz.', staged_store) - DebugPrint(1, 'Compressing outbox in tar.bz2 file: ' + staging_name) + with GenerateFilename('tz', staged_store) as temp_tarfile: + staging_name = temp_tarfile.name + DebugPrint(1, 'Compressing outbox in tar.bz2 file: ' + staging_name) - try: - tar = tarfile.open(staging_name, 'w:bz2') - except KeyboardInterrupt: - raise - except SystemExit: - raise - except Exception as e: - DebugPrint(0, 'Warning: Exception caught while opening tar.bz2 file: ' + staging_name + ':') - DebugPrint(0, 'Caught exception: ', e) - DebugPrintTraceback() - return False + try: + tar = tarfile.open(staging_name, 'w:bz2') + except KeyboardInterrupt: + raise + except SystemExit: + raise + except Exception as e: + DebugPrint(0, 'Warning: Exception caught while opening tar.bz2 file: ' + staging_name + ':') + DebugPrint(0, 'Caught exception: ', e) + DebugPrintTraceback() + return False try: for f in outfiles: @@ -599,12 +591,11 @@ def OpenNewRecordFile(dirIndex): raise InternalError(msg) from exc try: - filename = GenerateFilename('r.', working_dir) - DebugPrint(3, 'Creating file:', filename) - outstandingRecordCount += 1 - f = open(filename, 'w') - dirIndex = index - return (f, dirIndex) + with GenerateFilename('r', working_dir) as recordfile: + DebugPrint(3, 'Creating file:', recordfile.name) + outstandingRecordCount += 1 + dirIndex = index + return (recordfile, dirIndex) except Exception as exc: msg = 'ERROR: Caught exception while creating file' DebugPrint(0, msg + ': ', exc) diff --git a/common/gratia/common/xml_utils.py b/common/gratia/common/xml_utils.py index c7c74868..d4e74599 100644 --- a/common/gratia/common/xml_utils.py +++ b/common/gratia/common/xml_utils.py @@ -280,10 +280,8 @@ def UsageCheckXmldoc(xmlDoc, external, resourceType=None): subdir = os.path.join(Config.get_DataFolder(), "quarantine", 'subdir.' + Config.getFilenameFragment()) if not os.path.exists(subdir): os.mkdir(subdir) - fn = sandbox_mgmt.GenerateFilename("r.", subdir) - writer = open(fn, 'w') - usageRecord.writexml(writer) - writer.close() + with sandbox_mgmt.GenerateFilename("r", subdir) as writer: + usageRecord.writexml(writer) usageRecord.unlink() continue diff --git a/test/test_sandbox_mgmt.py b/test/test_sandbox_mgmt.py new file mode 100644 index 00000000..4647810a --- /dev/null +++ b/test/test_sandbox_mgmt.py @@ -0,0 +1,153 @@ +#!/bin/env python + +import glob +import os +import shutil +import tarfile +import tempfile +import unittest +from unittest.mock import patch, PropertyMock +from unittest import TextTestRunner + +from common.gratia.common import sandbox_mgmt + +class SandboxMgmtTests(unittest.TestCase): + + @patch('gratia.common.config.ConfigProxy.get_GratiaExtension', create=True, return_value='test-extension') + def test_GenerateFilename(self, mock_config): + """GenerateFilename creates a temporary file and returns the path to the file + """ + prefix = 'test-prefix' + temp_dir = '/tmp' + + try: + with sandbox_mgmt.GenerateFilename(prefix, temp_dir) as filename: + self.assertTrue(os.path.exists(filename.name), + f'Failed to create temporary file ({filename.name})') + self.assertEqual(temp_dir.rstrip('/'), + os.path.dirname(filename.name), + f'Temporary file {filename.name} placed in the wrong directory') + self.assertRegex(filename.name, + rf'{temp_dir}/*{prefix}\.\d+\.{mock_config.return_value}__\w+', + 'Unexpected file name format') + finally: + try: + filename.close() + os.remove(filename.name) + except (FileNotFoundError, NameError): + # don't need to clean up what's not there + pass + +class CompressOutboxTests(unittest.TestCase): + def setUp(self): + # provision test environment + gratia_ex = patch('gratia.common.config.ConfigProxy.get_GratiaExtension', + create=True, return_value='test-extension') + file_frag = patch('gratia.common.config.ConfigProxy.getFilenameFragment', + create=True, return_value='test-filename') + + self.mock_gratia_ex = gratia_ex.start() + self.mock_file_frag = file_frag.start() + + self.probe_dir = tempfile.mkdtemp() + self.outbox = os.path.join(self.probe_dir, 'outbox') + os.makedirs(self.outbox, exist_ok=True) + self.outfiles = ['testfile1', 'testfile2'] + + # add content to the files + for testfile in self.outfiles: + content = testfile + ' contains this content' + with open(os.path.join(self.outbox, testfile), 'w', encoding="utf-8") as test: + test.write(content) + + self.addCleanup(gratia_ex.stop) + self.addCleanup(file_frag.stop) + + def tearDown(self): + # Remove probe_dir after test + shutil.rmtree(self.probe_dir) + + def get_tarball_location(self, path_to_tarball, tarball): + """ + Attempts to return exact location of tarball + """ + try: + tarball_location = os.path.join(f'{path_to_tarball}', tarball[0]) + except IndexError as notarball: + print("Tarball does not exist!") + self.fail(notarball) + + return tarball_location + + def test_compress_outbox(self): + """CompressOutbox compresses the files in the outbox directory + and stores the resulting tarball in probe_dir/staged. + """ + # Assert that CompressOutbox returns True + result = sandbox_mgmt.CompressOutbox(self.probe_dir, self.outbox, self.outfiles) + self.assertTrue(result) + + def test_tarball_creation(self): + """ + Assert that tarball is created in the correct location + """ + sandbox_mgmt.CompressOutbox(self.probe_dir, self.outbox, self.outfiles) + + path_to_tarball = f'{self.probe_dir}/staged/store' + # Finds exactly one tarball that matches GenerateFilename function output + tarball = glob.glob("tz.*.test-extension__*", root_dir=path_to_tarball) + + # Where tarball exists + tarball_location = self.get_tarball_location(path_to_tarball, tarball) + + # Counts files in the directory that the tarball should be in + tarball_count = len((tarball)) + + self.assertTrue(os.path.exists(tarball_location), + 'Tarball not created in correct location') + self.assertEqual(tarball_count, 1, + f'Expected 1 tarball, found {tarball_count}') + + def test_tarball_contents(self): + """ + Assert that unpacked tarball contains files from outfiles + """ + sandbox_mgmt.CompressOutbox(self.probe_dir, self.outbox, self.outfiles) + path_to_tarball = f'{self.probe_dir}/staged/store' + + # Finds tarball that matches GenerateFilename() output + tarball = glob.glob("tz.*.test-extension__*", root_dir=path_to_tarball) + + # Where tarball exists + tarball_location = self.get_tarball_location(path_to_tarball, tarball) + + # Gets names of files within tarball + with tarfile.open(tarball_location, "r") as names: + + # Names of files in tarball + namelist = names.getnames() + + # Sort both lists to ensure order-independent comparison + namelist.sort() + self.outfiles.sort() + + # Open files in outfiles + expected_files1 = open(os.path.join(f'{self.outbox}/{self.outfiles[0]}'), 'rb') + expected_files2 = open(os.path.join(f'{self.outbox}/{self.outfiles[1]}'), 'rb') + + # Extract the contents from testfile1 + file1 = names.extractfile(namelist[0]) + file1_contents = file1.readlines() + expected_results_f1 = expected_files1.readlines() + + # Extract the contents from testfile2 + file2 = names.extractfile(namelist[1]) + file2_contents = file2.readlines() + expected_results_f2 = expected_files2.readlines() + + self.assertListEqual(namelist, self.outfiles, + 'Unexpected file names in tarball') + self.assertEqual(file1_contents, expected_results_f1, + 'Unexpected content in file1') + self.assertEqual(file2_contents, expected_results_f2, + 'Unexpected content in file2')