diff --git a/__tests__/index.test.ts b/__tests__/index.test.ts index 227970c..4b5a8df 100644 --- a/__tests__/index.test.ts +++ b/__tests__/index.test.ts @@ -5,7 +5,7 @@ import { expect, it, describe, inject } from "vitest" import { deserialize, serialize, RpcSession, type RpcSessionOptions, RpcTransport, RpcTarget, RpcStub, newWebSocketRpcSession, newMessagePortRpcSession, - newHttpBatchRpcSession} from "../src/index.js" + newHttpBatchRpcSession, type WireFormat, jsonFormat} from "../src/index.js" import { Counter, TestTarget } from "./test-util.js"; let SERIALIZE_TEST_CASES: Record = { @@ -111,15 +111,17 @@ class TestTransport implements RpcTransport { } } - private queue: string[] = []; + private queue: (string | ArrayBuffer)[] = []; private waiter?: () => void; private aborter?: (err: any) => void; public log = false; - async send(message: string): Promise { + async send(message: string | ArrayBuffer): Promise { // HACK: If the string "$remove$" appears in the message, remove it. This is used in some // tests to hack the RPC protocol. - message = message.replaceAll("$remove$", ""); + if (typeof message === "string") { + message = message.replaceAll("$remove$", ""); + } if (this.log) console.log(`${this.name}: ${message}`); this.partner!.queue.push(message); @@ -130,7 +132,7 @@ class TestTransport implements RpcTransport { } } - async receive(): Promise { + async receive(): Promise { if (this.queue.length == 0) { await new Promise((resolve, reject) => { this.waiter = resolve; @@ -161,11 +163,11 @@ class TestHarness { stub: RpcStub; - constructor(target: T, serverOptions?: RpcSessionOptions) { + constructor(target: T, serverOptions?: RpcSessionOptions, clientOptions?: RpcSessionOptions) { this.clientTransport = new TestTransport("client"); this.serverTransport = new TestTransport("server", this.clientTransport); - this.client = new RpcSession(this.clientTransport); + this.client = new RpcSession(this.clientTransport, undefined, clientOptions); // TODO: If I remove `` here, I get a TypeScript error about the instantiation being // excessively deep and possibly infinite. Why? `` is supposed to be the default. @@ -1509,3 +1511,58 @@ describe("MessagePorts", () => { new Error("Peer closed MessagePort connection.")); }); }); + +describe("WireFormat", () => { + it("works with the default jsonFormat", async () => { + let fmtOpts: RpcSessionOptions = { format: jsonFormat }; + await using harness = new TestHarness(new TestTarget(), fmtOpts, fmtOpts); + expect(await harness.stub.square(5)).toBe(25); + }); + + it("works with a custom identity format", async () => { + // A trivial format that still uses JSON under the hood but proves the plumbing works. + let encodeCount = 0; + let decodeCount = 0; + let customFormat: WireFormat = { + encode(value: unknown): string { + encodeCount++; + return JSON.stringify(value); + }, + decode(data: string | ArrayBuffer): unknown { + decodeCount++; + if (typeof data !== "string") throw new Error("expected string"); + return JSON.parse(data); + }, + }; + + let fmtOpts: RpcSessionOptions = { format: customFormat }; + await using harness = new TestHarness(new TestTarget(), fmtOpts, fmtOpts); + expect(await harness.stub.square(4)).toBe(16); + expect(encodeCount).toBeGreaterThan(0); + expect(decodeCount).toBeGreaterThan(0); + }); + + it("works with an ArrayBuffer-based format", async () => { + // Format that encodes to ArrayBuffer via TextEncoder/TextDecoder (proves binary path). + let encoder = new TextEncoder(); + let decoder = new TextDecoder(); + let binaryFormat: WireFormat = { + encode(value: unknown): ArrayBuffer { + return encoder.encode(JSON.stringify(value)).buffer as ArrayBuffer; + }, + decode(data: string | ArrayBuffer): unknown { + if (data instanceof ArrayBuffer) { + return JSON.parse(decoder.decode(data)); + } + return JSON.parse(data as string); + }, + }; + + let fmtOpts: RpcSessionOptions = { format: binaryFormat }; + await using harness = new TestHarness(new TestTarget(), fmtOpts, fmtOpts); + expect(await harness.stub.square(6)).toBe(36); + + using counter = await harness.stub.makeCounter(10); + expect(await counter.increment(5)).toBe(15); + }); +}); diff --git a/src/batch.ts b/src/batch.ts index 4f423c7..4557f7a 100644 --- a/src/batch.ts +++ b/src/batch.ts @@ -19,16 +19,16 @@ class BatchClientTransport implements RpcTransport { #batchToSend: string[] | null = []; #batchToReceive: string[] | null = null; - async send(message: string): Promise { + async send(message: string | ArrayBuffer): Promise { // If the batch was already sent, we just ignore the message, because throwing may cause the // RPC system to abort prematurely. Once the last receive() is done then we'll throw an error // that aborts the RPC system at the right time and will propagate to all other requests. if (this.#batchToSend !== null) { - this.#batchToSend.push(message); + this.#batchToSend.push(message as string); } } - async receive(): Promise { + async receive(): Promise { if (!this.#batchToReceive) { await this.#promise; } @@ -98,11 +98,11 @@ class BatchServerTransport implements RpcTransport { #batchToReceive: string[]; #allReceived: PromiseWithResolvers = Promise.withResolvers(); - async send(message: string): Promise { - this.#batchToSend.push(message); + async send(message: string | ArrayBuffer): Promise { + this.#batchToSend.push(message as string); } - async receive(): Promise { + async receive(): Promise { let msg = this.#batchToReceive!.shift(); if (msg !== undefined) { return msg; diff --git a/src/index.ts b/src/index.ts index 1ff376e..39c2ded 100644 --- a/src/index.ts +++ b/src/index.ts @@ -4,7 +4,8 @@ import { RpcTarget as RpcTargetImpl, RpcStub as RpcStubImpl, RpcPromise as RpcPromiseImpl } from "./core.js"; import { serialize, deserialize } from "./serialize.js"; -import { RpcTransport, RpcSession as RpcSessionImpl, RpcSessionOptions } from "./rpc.js"; +import { RpcTransport, RpcSession as RpcSessionImpl, RpcSessionOptions, WireFormat, + jsonFormat } from "./rpc.js"; import { RpcTargetBranded, RpcCompatible, Stub, Stubify, __RPC_TARGET_BRAND } from "./types.js"; import { newWebSocketRpcSession as newWebSocketRpcSessionImpl, newWorkersWebSocketRpcResponse } from "./websocket.js"; @@ -17,8 +18,8 @@ forceInitMap(); // Re-export public API types. export { serialize, deserialize, newWorkersWebSocketRpcResponse, newHttpBatchRpcResponse, - nodeHttpBatchRpcResponse }; -export type { RpcTransport, RpcSessionOptions, RpcCompatible }; + nodeHttpBatchRpcResponse, jsonFormat }; +export type { RpcTransport, RpcSessionOptions, RpcCompatible, WireFormat }; // Hack the type system to make RpcStub's types work nicely! /** diff --git a/src/messageport.ts b/src/messageport.ts index f9b5a3c..206c5bc 100644 --- a/src/messageport.ts +++ b/src/messageport.ts @@ -29,7 +29,8 @@ class MessagePortTransport implements RpcTransport { } else if (event.data === null) { // Peer is signaling that they're closing the connection this.#receivedError(new Error("Peer closed MessagePort connection.")); - } else if (typeof event.data === "string") { + } else if (typeof event.data === "string" || + event.data instanceof ArrayBuffer) { if (this.#receiveResolver) { this.#receiveResolver(event.data); this.#receiveResolver = undefined; @@ -48,25 +49,25 @@ class MessagePortTransport implements RpcTransport { } #port: MessagePort; - #receiveResolver?: (message: string) => void; + #receiveResolver?: (message: string | ArrayBuffer) => void; #receiveRejecter?: (err: any) => void; - #receiveQueue: string[] = []; + #receiveQueue: (string | ArrayBuffer)[] = []; #error?: any; - async send(message: string): Promise { + async send(message: string | ArrayBuffer): Promise { if (this.#error) { throw this.#error; } this.#port.postMessage(message); } - async receive(): Promise { + async receive(): Promise { if (this.#receiveQueue.length > 0) { return this.#receiveQueue.shift()!; } else if (this.#error) { throw this.#error; } else { - return new Promise((resolve, reject) => { + return new Promise((resolve, reject) => { this.#receiveResolver = resolve; this.#receiveRejecter = reject; }); diff --git a/src/rpc.ts b/src/rpc.ts index c0f80cd..cbd87d9 100644 --- a/src/rpc.ts +++ b/src/rpc.ts @@ -5,6 +5,28 @@ import { StubHook, RpcPayload, RpcStub, PropertyPath, PayloadStubHook, ErrorStubHook, RpcTarget, unwrapStubAndPath } from "./core.js"; import { Devaluator, Evaluator, ExportId, ImportId, Exporter, Importer, serialize } from "./serialize.js"; +/** + * Pluggable wire format for encoding/decoding RPC messages. Implement this interface to use a + * binary format (e.g. CBOR) instead of JSON. + */ +export interface WireFormat { + encode(value: unknown): string | ArrayBuffer; + decode(data: string | ArrayBuffer): unknown; +} + +/** Default wire format that uses JSON text encoding. */ +export const jsonFormat: WireFormat = { + encode(value: unknown): string { + return JSON.stringify(value); + }, + decode(data: string | ArrayBuffer): unknown { + if (typeof data !== "string") { + throw new TypeError("jsonFormat received non-string data"); + } + return JSON.parse(data); + }, +}; + /** * Interface for an RPC transport, which is a simple bidirectional message stream. Implement this * interface if the built-in transports (e.g. for HTTP batch and WebSocket) don't meet your needs. @@ -13,7 +35,7 @@ export interface RpcTransport { /** * Sends a message to the other end. */ - send(message: string): Promise; + send(message: string | ArrayBuffer): Promise; /** * Receives a message sent by the other end. @@ -23,7 +45,7 @@ export interface RpcTransport { * If there are no outstanding calls (and none are made in the future), then the error does not * propagate anywhere -- this is considered a "clean" shutdown. */ - receive(): Promise; + receive(): Promise; /** * Indicates that the RPC system has suffered an error that prevents the session from continuing. @@ -298,6 +320,15 @@ export type RpcSessionOptions = { * to serialize the error with the stack omitted. */ onSendError?: (error: Error) => Error | void; + + /** Wire format for encoding/decoding messages. Defaults to `jsonFormat` (JSON text). */ + format?: WireFormat; + + /** + * When true, `Uint8Array` values are passed through raw instead of being base64-encoded. + * Only useful with a binary wire format that natively supports byte arrays. + */ + binaryBytes?: boolean; }; class RpcSessionImpl implements Importer, Exporter { @@ -322,8 +353,13 @@ class RpcSessionImpl implements Importer, Exporter { // may be deleted from the middle (hence leaving the array sparse). onBrokenCallbacks: ((error: any) => void)[] = []; + private format: WireFormat; + private binaryBytes: boolean; + constructor(private transport: RpcTransport, mainHook: StubHook, private options: RpcSessionOptions) { + this.format = options.format ?? jsonFormat; + this.binaryBytes = options.binaryBytes ?? false; // Export zero is automatically the bootstrap object. this.exports.push({hook: mainHook, refcount: 1}); @@ -440,18 +476,21 @@ class RpcSessionImpl implements Importer, Exporter { payload => { // We don't transfer ownership of stubs in the payload since the payload // belongs to the hook which sticks around to handle pipelined requests. - let value = Devaluator.devaluate(payload.value, undefined, this, payload); + let value = Devaluator.devaluate( + payload.value, undefined, this, payload, this.binaryBytes); this.send(["resolve", exportId, value]); }, error => { - this.send(["reject", exportId, Devaluator.devaluate(error, undefined, this)]); + this.send(["reject", exportId, + Devaluator.devaluate(error, undefined, this, undefined, this.binaryBytes)]); } ).catch( error => { // If serialization failed, report the serialization error, which should // itself always be serializable. try { - this.send(["reject", exportId, Devaluator.devaluate(error, undefined, this)]); + this.send(["reject", exportId, + Devaluator.devaluate(error, undefined, this, undefined, this.binaryBytes)]); } catch (error2) { // TODO: Shouldn't happen, now what? this.abort(error2); @@ -511,17 +550,17 @@ class RpcSessionImpl implements Importer, Exporter { return; } - let msgText: string; + let encoded: string | ArrayBuffer; try { - msgText = JSON.stringify(msg); + encoded = this.format.encode(msg); } catch (err) { - // If JSON stringification failed, there's something wrong with the devaluator, as it should - // not allow non-JSONable values to be injected in the first place. + // If encoding failed, there's something wrong with the devaluator, as it should + // not allow non-encodable values to be injected in the first place. try { this.abort(err); } catch (err2) {} throw err; } - this.transport.send(msgText) + this.transport.send(encoded) // If send fails, abort the connection, but don't try to send an abort message since // that'll probably also fail. .catch(err => this.abort(err, false)); @@ -532,7 +571,7 @@ class RpcSessionImpl implements Importer, Exporter { let value: Array = ["pipeline", id, path]; if (args) { - let devalue = Devaluator.devaluate(args.value, undefined, this, args); + let devalue = Devaluator.devaluate(args.value, undefined, this, args, this.binaryBytes); // HACK: Since the args is an array, devaluator will wrap in a second array. Need to unwrap. // TODO: Clean this up somehow. @@ -596,8 +635,8 @@ class RpcSessionImpl implements Importer, Exporter { if (trySendAbortMessage) { try { - this.transport.send(JSON.stringify(["abort", Devaluator - .devaluate(error, undefined, this)])) + this.transport.send(this.format.encode(["abort", Devaluator + .devaluate(error, undefined, this, undefined, this.binaryBytes)])) .catch(err => {}); } catch (err) { // ignore, probably the whole reason we're aborting is because the transport is broken @@ -644,7 +683,8 @@ class RpcSessionImpl implements Importer, Exporter { private async readLoop(abortPromise: Promise) { while (!this.abortReason) { - let msg = JSON.parse(await Promise.race([this.transport.receive(), abortPromise])); + let msg: any = this.format.decode( + await Promise.race([this.transport.receive(), abortPromise])); if (this.abortReason) break; // check again before processing if (msg instanceof Array) { diff --git a/src/serialize.ts b/src/serialize.ts index a4aa042..2e1f7b4 100644 --- a/src/serialize.ts +++ b/src/serialize.ts @@ -65,7 +65,10 @@ interface FromBase64 { // actually converting to a string. (The name is meant to be the opposite of "Evaluator", which // implements the opposite direction.) export class Devaluator { - private constructor(private exporter: Exporter, private source: RpcPayload | undefined) {} + private constructor( + private exporter: Exporter, + private source: RpcPayload | undefined, + private binaryBytes: boolean) {} // Devaluate the given value. // * value: The value to devaluate. @@ -73,12 +76,14 @@ export class Devaluator { // as a function. // * exporter: Callbacks to the RPC session for exporting capabilities found in this message. // * source: The RpcPayload which contains the value, and therefore owns stubs within. + // * binaryBytes: When true, Uint8Array values are passed through raw instead of base64-encoding. // - // Returns: The devaluated value, ready to be JSON-serialized. + // Returns: The devaluated value, ready to be encoded by the wire format. public static devaluate( - value: unknown, parent?: object, exporter: Exporter = NULL_EXPORTER, source?: RpcPayload) + value: unknown, parent?: object, exporter: Exporter = NULL_EXPORTER, source?: RpcPayload, + binaryBytes: boolean = false) : unknown { - let devaluator = new Devaluator(exporter, source); + let devaluator = new Devaluator(exporter, source, binaryBytes); try { return devaluator.devaluateImpl(value, parent, 0); } catch (err) { @@ -155,6 +160,9 @@ export class Devaluator { case "bytes": { let bytes = value as Uint8Array; + if (this.binaryBytes) { + return ["bytes", bytes]; + } if (bytes.toBase64) { return ["bytes", bytes.toBase64({omitPadding: true})]; } else { @@ -326,6 +334,9 @@ export class Evaluator { } break; case "bytes": { + if (value[1] instanceof Uint8Array) { + return value[1]; + } let b64 = Uint8Array as FromBase64; if (typeof value[1] == "string") { if (b64.fromBase64) { diff --git a/src/websocket.ts b/src/websocket.ts index 32dbefa..f955c03 100644 --- a/src/websocket.ts +++ b/src/websocket.ts @@ -59,7 +59,8 @@ class WebSocketTransport implements RpcTransport { webSocket.addEventListener("message", (event: MessageEvent) => { if (this.#error) { // Ignore further messages. - } else if (typeof event.data === "string") { + } else if (typeof event.data === "string" || + event.data instanceof ArrayBuffer) { if (this.#receiveResolver) { this.#receiveResolver(event.data); this.#receiveResolver = undefined; @@ -67,8 +68,20 @@ class WebSocketTransport implements RpcTransport { } else { this.#receiveQueue.push(event.data); } + } else if (ArrayBuffer.isView(event.data)) { + // Convert typed arrays (e.g. Node Buffer) to ArrayBuffer. + let view = event.data; + let buf = (view.buffer as ArrayBuffer).slice( + view.byteOffset, view.byteOffset + view.byteLength); + if (this.#receiveResolver) { + this.#receiveResolver(buf); + this.#receiveResolver = undefined; + this.#receiveRejecter = undefined; + } else { + this.#receiveQueue.push(buf); + } } else { - this.#receivedError(new TypeError("Received non-string message from WebSocket.")); + this.#receivedError(new TypeError("Received unsupported message type from WebSocket.")); } }); @@ -82,13 +95,13 @@ class WebSocketTransport implements RpcTransport { } #webSocket: WebSocket; - #sendQueue?: string[]; // only if not opened yet - #receiveResolver?: (message: string) => void; + #sendQueue?: (string | ArrayBuffer)[]; // only if not opened yet + #receiveResolver?: (message: string | ArrayBuffer) => void; #receiveRejecter?: (err: any) => void; - #receiveQueue: string[] = []; + #receiveQueue: (string | ArrayBuffer)[] = []; #error?: any; - async send(message: string): Promise { + async send(message: string | ArrayBuffer): Promise { if (this.#sendQueue === undefined) { this.#webSocket.send(message); } else { @@ -97,13 +110,13 @@ class WebSocketTransport implements RpcTransport { } } - async receive(): Promise { + async receive(): Promise { if (this.#receiveQueue.length > 0) { return this.#receiveQueue.shift()!; } else if (this.#error) { throw this.#error; } else { - return new Promise((resolve, reject) => { + return new Promise((resolve, reject) => { this.#receiveResolver = resolve; this.#receiveRejecter = reject; });