Skip to content
Open
53 changes: 22 additions & 31 deletions common/gratia/common/sandbox_mgmt.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import glob
import time
import shutil
import tempfile
import tarfile

from gratia.common.config import ConfigProxy
Expand Down Expand Up @@ -409,18 +410,8 @@ def SearchOutstandingRecord():

def GenerateFilename(prefix, current_dir):
'''Generate a filename of the for current_dir/prefix.$pid.ConfigFragment.gratia.xml__Unique'''
filename = prefix + str(global_state.RecordPid) + '.' + Config.get_GratiaExtension() \
+ '__XXXXXXXXXX'
filename = os.path.join(current_dir, filename)
mktemp_pipe = os.popen('mktemp -q "' + filename + '"')
if mktemp_pipe != None:
filename = mktemp_pipe.readline()
mktemp_pipe.close()
filename = filename.strip()
if filename != r'':
return filename

raise IOError
fn_prefix = f'{prefix}.{global_state.RecordPid}.{Config.get_GratiaExtension()}__'
return tempfile.NamedTemporaryFile(prefix=fn_prefix, dir=current_dir, delete=False, mode='w')

def UncompressOutbox(staging_name, target_dir):

Expand Down Expand Up @@ -487,20 +478,21 @@ def CompressOutbox(probe_dir, outbox, outfiles):
DebugPrint(0, msg + ':' + exc)
raise InternalError(msg) from exc

staging_name = GenerateFilename('tz.', staged_store)
DebugPrint(1, 'Compressing outbox in tar.bz2 file: ' + staging_name)
with GenerateFilename('tz', staged_store) as temp_tarfile:
staging_name = temp_tarfile.name
DebugPrint(1, 'Compressing outbox in tar.bz2 file: ' + staging_name)

try:
tar = tarfile.open(staging_name, 'w:bz2')
except KeyboardInterrupt:
raise
except SystemExit:
raise
except Exception as e:
DebugPrint(0, 'Warning: Exception caught while opening tar.bz2 file: ' + staging_name + ':')
DebugPrint(0, 'Caught exception: ', e)
DebugPrintTraceback()
return False
try:
tar = tarfile.open(staging_name, 'w:bz2')
except KeyboardInterrupt:
raise
except SystemExit:
raise
except Exception as e:
DebugPrint(0, 'Warning: Exception caught while opening tar.bz2 file: ' + staging_name + ':')
DebugPrint(0, 'Caught exception: ', e)
DebugPrintTraceback()
return False

try:
for f in outfiles:
Expand Down Expand Up @@ -599,12 +591,11 @@ def OpenNewRecordFile(dirIndex):
raise InternalError(msg) from exc

try:
filename = GenerateFilename('r.', working_dir)
DebugPrint(3, 'Creating file:', filename)
outstandingRecordCount += 1
f = open(filename, 'w')
dirIndex = index
return (f, dirIndex)
with GenerateFilename('r', working_dir) as recordfile:
DebugPrint(3, 'Creating file:', recordfile.name)
outstandingRecordCount += 1
dirIndex = index
return (recordfile, dirIndex)
except Exception as exc:
msg = 'ERROR: Caught exception while creating file'
DebugPrint(0, msg + ': ', exc)
Expand Down
6 changes: 2 additions & 4 deletions common/gratia/common/xml_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -280,10 +280,8 @@ def UsageCheckXmldoc(xmlDoc, external, resourceType=None):
subdir = os.path.join(Config.get_DataFolder(), "quarantine", 'subdir.' + Config.getFilenameFragment())
if not os.path.exists(subdir):
os.mkdir(subdir)
fn = sandbox_mgmt.GenerateFilename("r.", subdir)
writer = open(fn, 'w')
usageRecord.writexml(writer)
writer.close()
with sandbox_mgmt.GenerateFilename("r", subdir) as writer:
usageRecord.writexml(writer)
usageRecord.unlink()
continue

Expand Down
153 changes: 153 additions & 0 deletions test/test_sandbox_mgmt.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,153 @@
#!/bin/env python

import glob
import os
import shutil
import tarfile
import tempfile
import unittest
from unittest.mock import patch, PropertyMock
from unittest import TextTestRunner

from common.gratia.common import sandbox_mgmt

class SandboxMgmtTests(unittest.TestCase):

@patch('gratia.common.config.ConfigProxy.get_GratiaExtension', create=True, return_value='test-extension')
def test_GenerateFilename(self, mock_config):
"""GenerateFilename creates a temporary file and returns the path to the file
"""
prefix = 'test-prefix'
temp_dir = '/tmp'

try:
with sandbox_mgmt.GenerateFilename(prefix, temp_dir) as filename:
self.assertTrue(os.path.exists(filename.name),
f'Failed to create temporary file ({filename.name})')
self.assertEqual(temp_dir.rstrip('/'),
os.path.dirname(filename.name),
f'Temporary file {filename.name} placed in the wrong directory')
self.assertRegex(filename.name,
rf'{temp_dir}/*{prefix}\.\d+\.{mock_config.return_value}__\w+',
'Unexpected file name format')
finally:
try:
filename.close()
os.remove(filename.name)
except (FileNotFoundError, NameError):
# don't need to clean up what's not there
pass

class CompressOutboxTests(unittest.TestCase):
def setUp(self):
# provision test environment
gratia_ex = patch('gratia.common.config.ConfigProxy.get_GratiaExtension',
create=True, return_value='test-extension')
file_frag = patch('gratia.common.config.ConfigProxy.getFilenameFragment',
create=True, return_value='test-filename')

self.mock_gratia_ex = gratia_ex.start()
self.mock_file_frag = file_frag.start()

self.probe_dir = tempfile.mkdtemp()
self.outbox = os.path.join(self.probe_dir, 'outbox')
os.makedirs(self.outbox, exist_ok=True)
self.outfiles = ['testfile1', 'testfile2']

# add content to the files
for testfile in self.outfiles:
content = testfile + ' contains this content'
with open(os.path.join(self.outbox, testfile), 'w', encoding="utf-8") as test:
test.write(content)

self.addCleanup(gratia_ex.stop)
self.addCleanup(file_frag.stop)

def tearDown(self):
# Remove probe_dir after test
shutil.rmtree(self.probe_dir)

def get_tarball_location(self, path_to_tarball, tarball):
"""
Attempts to return exact location of tarball
"""
try:
tarball_location = os.path.join(f'{path_to_tarball}', tarball[0])
except IndexError as notarball:
print("Tarball does not exist!")
self.fail(notarball)

return tarball_location

def test_compress_outbox(self):
"""CompressOutbox compresses the files in the outbox directory
and stores the resulting tarball in probe_dir/staged.
"""
# Assert that CompressOutbox returns True
result = sandbox_mgmt.CompressOutbox(self.probe_dir, self.outbox, self.outfiles)
self.assertTrue(result)
Comment thread
brianhlin marked this conversation as resolved.

def test_tarball_creation(self):
"""
Assert that tarball is created in the correct location
"""
sandbox_mgmt.CompressOutbox(self.probe_dir, self.outbox, self.outfiles)

path_to_tarball = f'{self.probe_dir}/staged/store'
# Finds exactly one tarball that matches GenerateFilename function output
tarball = glob.glob("tz.*.test-extension__*", root_dir=path_to_tarball)

# Where tarball exists
tarball_location = self.get_tarball_location(path_to_tarball, tarball)

# Counts files in the directory that the tarball should be in
tarball_count = len((tarball))

self.assertTrue(os.path.exists(tarball_location),
'Tarball not created in correct location')
self.assertEqual(tarball_count, 1,
f'Expected 1 tarball, found {tarball_count}')

def test_tarball_contents(self):
"""
Assert that unpacked tarball contains files from outfiles
"""
sandbox_mgmt.CompressOutbox(self.probe_dir, self.outbox, self.outfiles)
path_to_tarball = f'{self.probe_dir}/staged/store'

# Finds tarball that matches GenerateFilename() output
tarball = glob.glob("tz.*.test-extension__*", root_dir=path_to_tarball)

# Where tarball exists
tarball_location = self.get_tarball_location(path_to_tarball, tarball)

# Gets names of files within tarball
with tarfile.open(tarball_location, "r") as names:

# Names of files in tarball
namelist = names.getnames()

# Sort both lists to ensure order-independent comparison
namelist.sort()
self.outfiles.sort()

# Open files in outfiles
expected_files1 = open(os.path.join(f'{self.outbox}/{self.outfiles[0]}'), 'rb')
expected_files2 = open(os.path.join(f'{self.outbox}/{self.outfiles[1]}'), 'rb')

# Extract the contents from testfile1
file1 = names.extractfile(namelist[0])
file1_contents = file1.readlines()
expected_results_f1 = expected_files1.readlines()

# Extract the contents from testfile2
file2 = names.extractfile(namelist[1])
file2_contents = file2.readlines()
expected_results_f2 = expected_files2.readlines()

self.assertListEqual(namelist, self.outfiles,
'Unexpected file names in tarball')
self.assertEqual(file1_contents, expected_results_f1,
'Unexpected content in file1')
self.assertEqual(file2_contents, expected_results_f2,
'Unexpected content in file2')