diff --git a/python_tests/test_copy_prefix.py b/python_tests/test_copy_prefix.py index 703877e7..2b37169d 100644 --- a/python_tests/test_copy_prefix.py +++ b/python_tests/test_copy_prefix.py @@ -271,3 +271,17 @@ def validate_copy(copy_id, expected_len = None, expected_min_len = None): for i in range(copy_id): last_len = validate_copy(i, expected_min_len = last_len) print(f"--- Copy {i} valid with {last_len} objects") + +def test_copy_prefix_throws_on_path_passed(db0_fixture): + path = "./invalid-dir/nonexistent/-copy/" + # remove path if it exists + if os.path.exists(path): + os.rmdir(path) + + root = MemoTestSingleton([]) + for _ in range(50): + root.value.append(MemoTestClass("a" * 1024)) # 1 KB string + db0.commit() + + with pytest.raises(OSError) as excinfo: + db0.copy_prefix(path) diff --git a/src/dbzero/bindings/python/PyInternalAPI.cpp b/src/dbzero/bindings/python/PyInternalAPI.cpp index 82276b66..02c3fe4a 100644 --- a/src/dbzero/bindings/python/PyInternalAPI.cpp +++ b/src/dbzero/bindings/python/PyInternalAPI.cpp @@ -936,9 +936,15 @@ namespace db0::python PyObject *tryCopyPrefixImpl(BDevStorage &src_storage, const std::string &output_file_name, std::optional page_io_step_size, std::optional meta_io_step_size) { + // make sure output is file doesn't point to a directory + if (output_file_name.back() == std::filesystem::path::preferred_separator) { + PyErr_Format(PyExc_OSError, "Output file points to a directory: '%s'", output_file_name); + return nullptr; + } // make sure output file does not exist if (db0::CFile::exists(output_file_name)) { - THROWF(db0::IOException) << "Output file already exists: " << output_file_name; + PyErr_Format(PyExc_OSError, "Output file already exists: '%s'", output_file_name); + return nullptr; } // use either explicit step size, input step size (if > 1) or default = 4MB @@ -1033,6 +1039,9 @@ namespace db0::python auto result = Py_OWN(tryCopyPrefixImpl(*storage, output_file_name, page_io_step_size, meta_io_step_size)); storage->close(); return result.steal(); + if (!result) { + return nullptr; + } } catch (...) { if (storage) { storage->close();