Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
116 changes: 88 additions & 28 deletions Sources/Subprocess/API.swift
Original file line number Diff line number Diff line change
Expand Up @@ -442,6 +442,7 @@ public func run<
standardError: Error.OutputType
)

let executionContext = ExecutionContext(configuration)
let customInput = CustomWriteInput()

let result = try await configuration.run(
Expand All @@ -454,11 +455,20 @@ public func run<
var errorIOBox: IODescriptor? = consume errorIO

// Write input, capture output and error in parallel
async let stdout = try output.captureOutput(from: outputIOBox.take())
async let stderr = try error.captureOutput(from: errorIOBox.take())
async let stdout = try output.captureOutput(
from: outputIOBox.take(),
executionContext: executionContext
)
async let stderr = try error.captureOutput(
from: errorIOBox.take(),
executionContext: executionContext
)
// Write span at the same isolation
if let writeFd = inputIOBox.take() {
let writer = StandardInputWriter(diskIO: writeFd)
let writer = StandardInputWriter(
diskIO: writeFd,
executionContext: executionContext
)
_ = try await writer.write(input._bytes)
try await writer.finish()
}
Expand Down Expand Up @@ -501,6 +511,7 @@ public func run<
standardOutput: Output.OutputType,
standardError: Error.OutputType
)
let executionContext = ExecutionContext(configuration)
let inputPipe = try input.createPipe()
let outputPipe = try output.createPipe()
let errorPipe = try error.createPipe(from: outputPipe)
Expand All @@ -523,21 +534,26 @@ public func run<
var errorIOContainer: IODescriptor? = errorIOBox.take()
group.addTask {
if let writeFd = inputIOContainer.take() {
let writer = StandardInputWriter(diskIO: writeFd)
let writer = StandardInputWriter(
diskIO: writeFd,
executionContext: executionContext
)
try await input.write(with: writer)
try await writer.finish()
}
return nil
}
group.addTask {
let stdout = try await output.captureOutput(
from: outputIOContainer.take()
from: outputIOContainer.take(),
executionContext: executionContext
)
return .standardOutputCaptured(stdout)
}
group.addTask {
let stderr = try await error.captureOutput(
from: errorIOContainer.take()
from: errorIOContainer.take(),
executionContext: executionContext
)
return .standardErrorCaptured(stderr)
}
Expand All @@ -561,14 +577,22 @@ public func run<
standardOutput: stdout,
standardError: stderror
)
} catch {
if let underlying = error as? SubprocessError.UnderlyingError {
throw SubprocessError.asyncIOFailed(
reason: "Failed to capture output",
underlyingError: underlying
)
}
} catch let error as SubprocessError {
// Inner I/O layers (`StandardInputWriter`, `OutputProtocol.captureOutput`)
// already attached `executionContext` at the throw site.
// Rethrow as-is. The outer wrap in `Configuration.run()` is
// a no-op due to the idempotency check in
// `SubprocessError.withExecutionContext(_:)`.
throw error
} catch {
// Should be unreachable. Every child task throws `SubprocessError`.
// If a future change causes a non-`SubprocessError` to escape,
// fall back to wrapping it as `asyncIOFailed`, and the outer
// wrap will populate context.
throw SubprocessError.asyncIOFailed(
reason: "Failed to capture output",
underlyingError: error as? SubprocessError.UnderlyingError
)
}
}
}
Expand Down Expand Up @@ -605,6 +629,7 @@ public func run<
error: Error = .discarded,
body: (_ execution: Execution) async throws -> Result
) async throws -> ExecutionOutcome<Result> where Error.OutputType == Void {
let executionContext = ExecutionContext(configuration)
let inputPipe = try input.createPipe()
let outputPipe = try output.createPipe()
let errorPipe = try error.createPipe(from: outputPipe)
Expand All @@ -627,7 +652,10 @@ public func run<
var inputIOContainer: IODescriptor? = inputIOBox.take()
group.addTask {
if let inputIO = inputIOContainer.take() {
let writer = StandardInputWriter(diskIO: inputIO)
let writer = StandardInputWriter(
diskIO: inputIO,
executionContext: executionContext
)
try await input.write(with: writer)
try await writer.finish()
}
Expand Down Expand Up @@ -667,6 +695,7 @@ public func run<
_ outputSequence: AsyncBufferSequence
) async throws -> Result
) async throws -> ExecutionOutcome<Result> where Error.OutputType == Void {
let executionContext = ExecutionContext(configuration)
let output = SequenceOutput()
let inputPipe = try input.createPipe()
let outputPipe = try output.createPipe()
Expand All @@ -689,15 +718,19 @@ public func run<
var inputIOContainer: IODescriptor? = inputIOBox.take()
group.addTask {
if let inputIO = inputIOContainer.take() {
let writer = StandardInputWriter(diskIO: inputIO)
let writer = StandardInputWriter(
diskIO: inputIO,
executionContext: executionContext
)
try await input.write(with: writer)
try await writer.finish()
}
}

// Body runs in the same isolation
let outputSequence = AsyncBufferSequence(
diskIO: outputIOBox!.consumeDescriptor()
diskIO: outputIOBox!.consumeDescriptor(),
executionContext: executionContext
)

let result = try await body(execution, outputSequence)
Expand Down Expand Up @@ -728,6 +761,7 @@ public func run<Result, Input: InputProtocol, Output: OutputProtocol>(
_ errorSequence: AsyncBufferSequence
) async throws -> Result
) async throws -> ExecutionOutcome<Result> where Output.OutputType == Void {
let executionContext = ExecutionContext(configuration)
let error = SequenceOutput()

return try await configuration.run(
Expand All @@ -747,13 +781,17 @@ public func run<Result, Input: InputProtocol, Output: OutputProtocol>(
var inputIOContainer: IODescriptor? = inputIOBox.take()
group.addTask {
if let inputIO = inputIOContainer.take() {
let writer = StandardInputWriter(diskIO: inputIO)
let writer = StandardInputWriter(
diskIO: inputIO,
executionContext: executionContext
)
try await input.write(with: writer)
try await writer.finish()
}
}
let errorSequence = AsyncBufferSequence(
diskIO: errorIOBox!.consumeDescriptor()
diskIO: errorIOBox!.consumeDescriptor(),
executionContext: executionContext
)
// Body runs in the same isolation
let result = try await body(execution, errorSequence)
Expand Down Expand Up @@ -793,6 +831,7 @@ public func run<Result, Input: InputProtocol>(
_ errorSequence: AsyncBufferSequence
) async throws -> Result
) async throws -> ExecutionOutcome<Result> {
let executionContext = ExecutionContext(configuration)
let output = SequenceOutput()
let error = SequenceOutput()

Expand All @@ -812,19 +851,24 @@ public func run<Result, Input: InputProtocol>(
var inputIOContainer: IODescriptor? = inputIOBox.take()
group.addTask {
if let inputIO = inputIOContainer.take() {
let writer = StandardInputWriter(diskIO: inputIO)
let writer = StandardInputWriter(
diskIO: inputIO,
executionContext: executionContext
)
try await input.write(with: writer)
try await writer.finish()
}
}

// Body runs in the same isolation
let outputSequence = AsyncBufferSequence(
diskIO: outputIOBox!.consumeDescriptor()
diskIO: outputIOBox!.consumeDescriptor(),
executionContext: executionContext
)

let errorSequence = AsyncBufferSequence(
diskIO: errorIOBox!.consumeDescriptor()
diskIO: errorIOBox!.consumeDescriptor(),
executionContext: executionContext
)

let result = try await body(execution, outputSequence, errorSequence)
Expand Down Expand Up @@ -856,6 +900,7 @@ public func run<Result, Error: ErrorOutputProtocol>(
_ outputSequence: AsyncBufferSequence
) async throws -> Result
) async throws -> ExecutionOutcome<Result> where Error.OutputType == Void {
let executionContext = ExecutionContext(configuration)
let input = CustomWriteInput()
let output = SequenceOutput()
let inputPipe = try input.createPipe()
Expand All @@ -871,9 +916,13 @@ public func run<Result, Error: ErrorOutputProtocol>(
var errorIOBox = consume errorIO
try errorIOBox?.safelyClose()

let writer = StandardInputWriter(diskIO: inputIO!)
let writer = StandardInputWriter(
diskIO: inputIO!,
executionContext: executionContext
)
let outputSequence = AsyncBufferSequence(
diskIO: outputIOBox!.consumeDescriptor()
diskIO: outputIOBox!.consumeDescriptor(),
executionContext: executionContext
)

let result = try await body(execution, writer, outputSequence)
Expand Down Expand Up @@ -904,6 +953,7 @@ public func run<Result, Output: OutputProtocol>(
_ errorSequence: AsyncBufferSequence
) async throws -> Result
) async throws -> ExecutionOutcome<Result> where Output.OutputType == Void {
let executionContext = ExecutionContext(configuration)
let input = CustomWriteInput()
let error = SequenceOutput()

Expand All @@ -916,9 +966,13 @@ public func run<Result, Output: OutputProtocol>(
var errorIOBox = consume errorIO
try outputIOBox?.safelyClose()

let writer = StandardInputWriter(diskIO: inputIO!)
let writer = StandardInputWriter(
diskIO: inputIO!,
executionContext: executionContext
)
let errorSequence = AsyncBufferSequence(
diskIO: errorIOBox!.consumeDescriptor()
diskIO: errorIOBox!.consumeDescriptor(),
executionContext: executionContext
)
let bodyResult = try await body(execution, writer, errorSequence)
try await writer.finish()
Expand Down Expand Up @@ -948,6 +1002,7 @@ public func run<Result>(
_ errorSequence: AsyncBufferSequence
) async throws -> Result
) async throws -> ExecutionOutcome<Result> {
let executionContext = ExecutionContext(configuration)
let input = CustomWriteInput()
let output = SequenceOutput()
let error = SequenceOutput()
Expand All @@ -960,12 +1015,17 @@ public func run<Result>(
var outputIOBox = consume outputIO
var errorIOBox = consume errorIO

let writer = StandardInputWriter(diskIO: inputIO!)
let writer = StandardInputWriter(
diskIO: inputIO!,
executionContext: executionContext
)
let outputSequence = AsyncBufferSequence(
diskIO: outputIOBox!.consumeDescriptor()
diskIO: outputIOBox!.consumeDescriptor(),
executionContext: executionContext
)
let errorSequence = AsyncBufferSequence(
diskIO: errorIOBox!.consumeDescriptor()
diskIO: errorIOBox!.consumeDescriptor(),
executionContext: executionContext
)
let result = try await body(execution, writer, outputSequence, errorSequence)
try await writer.finish()
Expand Down
Loading