Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 28 additions & 3 deletions app/api/v2/handlers/payload_api.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
import asyncio
import itertools
import logging
import os
import pathlib
import re
from io import IOBase
from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes
from typing import Optional

import aiohttp_apispec
Expand All @@ -12,6 +14,9 @@
from app.api.v2.handlers.base_api import BaseApi
from app.api.v2.schemas.payload_schemas import PayloadQuerySchema, PayloadSchema, PayloadCreateRequestSchema, \
PayloadDeleteRequestSchema
from app.service.file_svc import USER_PAYLOAD_ENCRYPTION_FLAG

PAYLOAD_API_LOGGER_NAME = 'payload_api_handler'


class PayloadApi(BaseApi):
Expand Down Expand Up @@ -87,7 +92,7 @@ async def post_payloads(self, request: web.Request):
# Save the file to a temporary location first
temp_file_path = pathlib.Path(file_path).parent / f"temp_{file_name}"
loop: asyncio.AbstractEventLoop = asyncio.get_event_loop()
await loop.run_in_executor(None, self.__save_file, str(temp_file_path), file_field.file)
await loop.run_in_executor(None, self._save_file, str(temp_file_path), file_field.file)

# Validate the saved file to ensure it is not a symbolic link
if temp_file_path.is_symlink():
Expand Down Expand Up @@ -162,24 +167,44 @@ async def __generate_file_name_and_path(cls, sanitized_filename: str) -> [str, s
return file_name, file_path

@staticmethod
def __save_file(target_file_path: str, io_base_src: IOBase):
def _save_file(target_file_path: str, io_base_src: IOBase):
"""
Save an uploaded file content into a targeted file path.
To prevent unintended server-side execution of payloads, user-provided
payloads will be encrypted using a randomly generated key and IV.

The on-disk file format is as follows:
USER_PAYLOAD_ENCRYPTION_FLAG + key + IV + ciphertext

Note this method calls blocking methods and must be run into a dedicated thread.

:param target_file_path: The destination path to write to.
:param io_base_src: The stream with file content to read from.
"""
size: int = 0
read_chunk: bool = True
key = os.urandom(32)
iv = os.urandom(16)
cipher = Cipher(algorithms.AES(key), modes.CTR(iv))
encryptor = cipher.encryptor()

with open(target_file_path, 'wb') as buffered_io_base_dest:
# Write flag, key, and IV prior to ciphertext.
buffered_io_base_dest.write(USER_PAYLOAD_ENCRYPTION_FLAG + key + iv)

# Encrypt each chunk prior to appending it to the file
# We use CTR mode to take advantage of the stream cipher.
while read_chunk:
chunk: bytes = io_base_src.read(8192)
if chunk:
size += len(chunk)
buffered_io_base_dest.write(chunk)
buffered_io_base_dest.write(encryptor.update(chunk))
else:
read_chunk = False
buffered_io_base_dest.write(encryptor.finalize())
logging.getLogger(PAYLOAD_API_LOGGER_NAME).debug(
f'Encrypted {size} bytes of user payload and wrote to disk at {target_file_path}'
)

@staticmethod
def validate_and_canonicalize_path(input_path: str, base_directory: str = "data/payloads/") -> str:
Expand Down
13 changes: 12 additions & 1 deletion app/service/file_svc.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from cryptography.fernet import Fernet, InvalidToken
from cryptography.hazmat.backends import default_backend
from cryptography.hazmat.primitives import hashes
from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes
from cryptography.hazmat.primitives.kdf.pbkdf2 import PBKDF2HMAC

from app.service.interfaces.i_file_svc import FileServiceInterface
Expand All @@ -28,6 +29,7 @@
'http': URL_SANITIZATION_REGEX,
'socket': re.compile(r'^[\w\-\.:]+$')
}
USER_PAYLOAD_ENCRYPTION_FLAG = bytes('%userencryptedpayload%', encoding='utf-8')


class FileSvc(FileServiceInterface, BaseService):
Expand Down Expand Up @@ -294,7 +296,16 @@ def _save(self, filename, content, encrypt=True):
def _read(self, filename):
with open(filename, 'rb') as f:
buf = f.read()
if self.encryptor and buf.startswith(bytes(FILE_ENCRYPTION_FLAG, encoding='utf-8')):
if buf.startswith(USER_PAYLOAD_ENCRYPTION_FLAG):
# Handle encrypted user-uploaded payloads
buf = buf[len(USER_PAYLOAD_ENCRYPTION_FLAG):]
key = buf[0:32]
iv = buf[32:48]
ciphertext = buf[48:]
cipher = Cipher(algorithms.AES(key), modes.CTR(iv))
decryptor = cipher.decryptor()
buf = decryptor.update(ciphertext) + decryptor.finalize()
elif self.encryptor and buf.startswith(bytes(FILE_ENCRYPTION_FLAG, encoding='utf-8')):
try:
buf = self.encryptor.decrypt(buf[len(FILE_ENCRYPTION_FLAG):])
except InvalidToken:
Expand Down
69 changes: 68 additions & 1 deletion tests/api/v2/handlers/test_payloads_api.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,13 @@
import io
import os
import pathlib
import tempfile
from aiohttp import FormData
from http import HTTPStatus
from unittest import mock

from app.api.v2.handlers.payload_api import PayloadApi
from app.utility.base_service import BaseService

import pytest

Expand Down Expand Up @@ -36,8 +42,12 @@ def expected_payload_file_names(expected_payload_file_paths):
return {os.path.basename(path) for path in expected_payload_file_paths}


class TestPayloadsApi:
# Return n 0x01 bytes
def _mock_urandom(n):
return b'\x01' * n


class TestPayloadsApi:
async def test_get_payloads(self, api_v2_client, api_cookies, expected_payload_file_names):
resp = await api_v2_client.get('/api/v2/payloads', cookies=api_cookies)
payload_file_names = await resp.json()
Expand Down Expand Up @@ -82,3 +92,60 @@ async def test_get_payloads_name_filter_with_sort_and_add_path(
async def test_unauthorized_get_payloads(self, api_v2_client):
resp = await api_v2_client.get('/api/v2/payloads')
assert resp.status == HTTPStatus.UNAUTHORIZED

@mock.patch.object(pathlib.Path, 'rename')
async def test_post_payloads(self, mock_rename, api_v2_client, api_cookies):
file_data = bytes([0x01, 0x23, 0x45, 0x67, 0x89, 0xab, 0xcd, 0xef])
header_data = b'%userencryptedpayload%' + (b'\x01' * 48) # 32-byte key + 16-byte IV
ciphertext = bytes([0x9d, 0x8f, 0xd1, 0xa1, 0x3d, 0x2e, 0xac, 0x17])
with tempfile.TemporaryFile(mode='w+b') as tmp_file:
tmp_file.write(file_data)
tmp_file.flush()
tmp_file.seek(0)

m = mock.mock_open()
with mock.patch.object(os, 'urandom', wraps=_mock_urandom):
with mock.patch('builtins.open', m):
upload_data = FormData()
upload_data.add_field('file', tmp_file, filename='testpostpayload')
resp = await api_v2_client.post('/api/v2/payloads',
data=upload_data)
assert resp.status == HTTPStatus.OK
assert await resp.json() == dict(payloads=['testpostpayload'])
mock_rename.assert_called_with('data/payloads/testpostpayload')
m.assert_called_with('data/payloads/temp_testpostpayload', 'wb')
m().write.assert_any_call(header_data)
m().write.assert_any_call(ciphertext)
m().write.assert_called_with(b'')

def test_save_file(self):
original_data = os.urandom(24*1024)
with tempfile.NamedTemporaryFile() as fp:
PayloadApi._save_file(fp.name, io.BytesIO(original_data))
decrypted = BaseService.get_service('file_svc')._read(fp.name)
assert decrypted == original_data

async def test_delete_payloads(self, api_v2_client, api_cookies):
want_path = pathlib.Path('data/payloads/testtodelete').resolve()
with mock.patch.object(os, 'remove') as mock_remove:
resp = await api_v2_client.delete('/api/v2/payloads/testtodelete')
mock_remove.assert_called_once_with(want_path)
assert resp.status == 204

# Test ValueError
with mock.patch.object(os, 'remove', side_effect=ValueError('testvalueerror')) as mock_remove:
resp = await api_v2_client.delete('/api/v2/payloads/testtodelete')
assert resp.status == 404
assert resp.reason == 'testvalueerror'

# Test FileNotFoundError
with mock.patch.object(os, 'remove', side_effect=FileNotFoundError()) as mock_remove:
resp = await api_v2_client.delete('/api/v2/payloads/testtodelete')
assert resp.status == 404
assert resp.reason == 'Not Found'

# Test PermissionError
with mock.patch.object(os, 'remove', side_effect=PermissionError()) as mock_remove:
resp = await api_v2_client.delete('/api/v2/payloads/testtodelete')
assert resp.status == 403
assert resp.reason == 'Permission denied.'
Loading