diff --git a/packages/keyring-controller/CHANGELOG.md b/packages/keyring-controller/CHANGELOG.md index c6dad52923f..0a2910c9aef 100644 --- a/packages/keyring-controller/CHANGELOG.md +++ b/packages/keyring-controller/CHANGELOG.md @@ -9,6 +9,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Added +- Add `withController` action to run atomic operations on multiple keyrings (within a single transaction) ([#8416](https://github.com/MetaMask/core/pull/8416)) + - This action uses a `RestrictedController` object that exposes `addNewKeyring` and `removeKeyring` methods to add and remove keyring during the transaction (atomic) call. - Expose `KeyringController:signTransaction` method through `KeyringController` messenger ([#8408](https://github.com/MetaMask/core/pull/8408)) - Persist vault when keyring state changes during unlock ([#8415](https://github.com/MetaMask/core/pull/8415)) - If a keyring's serialized state differs after deserialization (e.g. a migration ran, or metadata was missing), the vault is now re-persisted so the change is not lost on the next unlock. diff --git a/packages/keyring-controller/src/KeyringController-method-action-types.ts b/packages/keyring-controller/src/KeyringController-method-action-types.ts index 6fb608c050d..3f1bd62fee4 100644 --- a/packages/keyring-controller/src/KeyringController-method-action-types.ts +++ b/packages/keyring-controller/src/KeyringController-method-action-types.ts @@ -375,6 +375,24 @@ export type KeyringControllerWithKeyringV2UnsafeAction = { handler: KeyringController['withKeyringV2Unsafe']; }; +/** + * Execute an operation against all keyrings as a mutually exclusive atomic + * operation. The operation receives a {@link RestrictedController} instance + * that exposes a read-only live view of all keyrings as well as + * `addNewKeyring` and `removeKeyring` methods to stage mutations. + * + * The method automatically persists changes at the end of the function + * execution, or rolls back the changes if an error is thrown. + * + * @param operation - Function to execute with the restricted controller. + * @returns Promise resolving to the result of the function execution. + * @template CallbackResult - The type of the value resolved by the callback function. + */ +export type KeyringControllerWithControllerAction = { + type: `KeyringController:withController`; + handler: KeyringController['withController']; +}; + /** * Union of all KeyringController action types. */ @@ -401,4 +419,5 @@ export type KeyringControllerMethodActions = | KeyringControllerWithKeyringAction | KeyringControllerWithKeyringUnsafeAction | KeyringControllerWithKeyringV2Action - | KeyringControllerWithKeyringV2UnsafeAction; + | KeyringControllerWithKeyringV2UnsafeAction + | KeyringControllerWithControllerAction; diff --git a/packages/keyring-controller/src/KeyringController.test.ts b/packages/keyring-controller/src/KeyringController.test.ts index 57cd39afe3f..b596774b040 100644 --- a/packages/keyring-controller/src/KeyringController.test.ts +++ b/packages/keyring-controller/src/KeyringController.test.ts @@ -4068,6 +4068,265 @@ describe('KeyringController', () => { }); }); + describe('withController', () => { + it('throws if the controller is locked', async () => { + await withController( + { skipVaultCreation: true }, + async ({ controller }) => { + await expect(controller.withController(jest.fn())).rejects.toThrow( + KeyringControllerErrorMessage.ControllerLocked, + ); + }, + ); + }); + + it('provides the current keyrings to the callback', async () => { + await withController(async ({ controller, initialState }) => { + await controller.withController(async (restrictedController) => { + expect(restrictedController.keyrings).toHaveLength(1); + expect(restrictedController.keyrings[0].metadata).toStrictEqual( + initialState.keyrings[0].metadata, + ); + }); + }); + }); + + it('returns the result of the callback', async () => { + await withController(async ({ controller }) => { + const result = await controller.withController(async () => 'hello'); + expect(result).toBe('hello'); + }); + }); + + it('throws if the callback returns a raw keyring instance', async () => { + await withController(async ({ controller }) => { + await expect( + controller.withController(async (restrictedController) => { + return restrictedController.keyrings[0].keyring; + }), + ).rejects.toThrow( + KeyringControllerErrorMessage.UnsafeDirectKeyringAccess, + ); + }); + }); + + it('throws if the callback returns a raw keyring (v2) instance', async () => { + await withController(async ({ controller }) => { + await expect( + controller.withController(async (restrictedController) => { + return restrictedController.keyrings[0].keyringV2; + }), + ).rejects.toThrow( + KeyringControllerErrorMessage.UnsafeDirectKeyringAccess, + ); + }); + }); + + describe('addNewKeyring', () => { + it('creates an initialized keyring and stages it for commit', async () => { + const mockAddress = '0x4584d2B4905087A100420AFfCe1b2d73fC69B8E4'; + stubKeyringClassWithAccount(MockKeyring, mockAddress); + + await withController( + { keyringBuilders: [keyringBuilderFactory(MockKeyring)] }, + async ({ controller }) => { + await controller.withController(async (restrictedController) => { + const entry = await restrictedController.addNewKeyring( + MockKeyring.type, + ); + + expect(entry.keyring).toBeInstanceOf(MockKeyring); + expect(entry.metadata.id).toBeDefined(); + }); + + expect(controller.state.keyrings).toHaveLength(2); + }, + ); + }); + + it('populates keyringV2 when a V2 builder is registered for the type', async () => { + await withController(async ({ controller }) => { + await controller.withController(async (restrictedController) => { + const entry = await restrictedController.addNewKeyring( + KeyringTypes.simple, + ); + + expect(entry.keyringV2).toBeDefined(); + }); + }); + }); + + it('appears immediately in restrictedController.keyrings', async () => { + const mockAddress = '0x4584d2B4905087A100420AFfCe1b2d73fC69B8E4'; + stubKeyringClassWithAccount(MockKeyring, mockAddress); + + await withController( + { keyringBuilders: [keyringBuilderFactory(MockKeyring)] }, + async ({ controller }) => { + await controller.withController(async (restrictedController) => { + expect(restrictedController.keyrings).toHaveLength(1); + await restrictedController.addNewKeyring(MockKeyring.type); + expect(restrictedController.keyrings).toHaveLength(2); + }); + }, + ); + }); + + it('destroys created keyrings and does not commit them if the callback throws', async () => { + const mockAddress = '0x4584d2B4905087A100420AFfCe1b2d73fC69B8E4'; + stubKeyringClassWithAccount(MockKeyring, mockAddress); + const destroySpy = jest + .spyOn(MockKeyring.prototype, 'destroy') + .mockResolvedValue(undefined); + + await withController( + { keyringBuilders: [keyringBuilderFactory(MockKeyring)] }, + async ({ controller }) => { + await expect( + controller.withController(async (restrictedController) => { + await restrictedController.addNewKeyring(MockKeyring.type); + throw new Error('Oops'); + }), + ).rejects.toThrow('Oops'); + + expect(destroySpy).toHaveBeenCalledTimes(1); + expect(controller.state.keyrings).toHaveLength(1); + }, + ); + }); + }); + + describe('removeKeyring', () => { + it('removes a keyring by id and commits the removal', async () => { + const mockAddress = '0x4584d2B4905087A100420AFfCe1b2d73fC69B8E4'; + stubKeyringClassWithAccount(MockKeyring, mockAddress); + + await withController( + { keyringBuilders: [keyringBuilderFactory(MockKeyring)] }, + async ({ controller }) => { + await controller.addNewKeyring(MockKeyring.type); + const idToRemove = controller.state.keyrings[1].metadata.id; + + await controller.withController(async (restrictedController) => { + await restrictedController.removeKeyring(idToRemove); + }); + + expect(controller.state.keyrings).toHaveLength(1); + expect( + controller.state.keyrings.find( + (k) => k.metadata.id === idToRemove, + ), + ).toBeUndefined(); + }, + ); + }); + + it('disappears from restrictedController.keyrings immediately', async () => { + const mockAddress = '0x4584d2B4905087A100420AFfCe1b2d73fC69B8E4'; + stubKeyringClassWithAccount(MockKeyring, mockAddress); + + await withController( + { keyringBuilders: [keyringBuilderFactory(MockKeyring)] }, + async ({ controller }) => { + await controller.addNewKeyring(MockKeyring.type); + const idToRemove = controller.state.keyrings[1].metadata.id; + + await controller.withController(async (restrictedController) => { + expect(restrictedController.keyrings).toHaveLength(2); + await restrictedController.removeKeyring(idToRemove); + expect(restrictedController.keyrings).toHaveLength(1); + }); + }, + ); + }); + + it('destroys the removed keyring', async () => { + const mockAddress = '0x4584d2B4905087A100420AFfCe1b2d73fC69B8E4'; + stubKeyringClassWithAccount(MockKeyring, mockAddress); + const destroySpy = jest + .spyOn(MockKeyring.prototype, 'destroy') + .mockResolvedValue(undefined); + + await withController( + { keyringBuilders: [keyringBuilderFactory(MockKeyring)] }, + async ({ controller }) => { + await controller.addNewKeyring(MockKeyring.type); + const idToRemove = controller.state.keyrings[1].metadata.id; + + await controller.withController(async (restrictedController) => { + await restrictedController.removeKeyring(idToRemove); + }); + + expect(destroySpy).toHaveBeenCalledTimes(1); + }, + ); + }); + + it('throws KeyringNotFound for an unknown id', async () => { + await withController(async ({ controller }) => { + await expect( + controller.withController(async (restrictedController) => { + await restrictedController.removeKeyring('non-existent-id'); + }), + ).rejects.toThrow(KeyringControllerErrorMessage.KeyringNotFound); + }); + }); + + it('destroys a keyring that was created then removed within the same callback', async () => { + const mockAddress = '0x4584d2B4905087A100420AFfCe1b2d73fC69B8E4'; + stubKeyringClassWithAccount(MockKeyring, mockAddress); + const destroySpy = jest + .spyOn(MockKeyring.prototype, 'destroy') + .mockResolvedValue(undefined); + + await withController( + { keyringBuilders: [keyringBuilderFactory(MockKeyring)] }, + async ({ controller }) => { + await controller.withController(async (restrictedController) => { + const { metadata } = await restrictedController.addNewKeyring( + MockKeyring.type, + ); + await restrictedController.removeKeyring(metadata.id); + }); + + expect(destroySpy).toHaveBeenCalledTimes(1); + expect(controller.state.keyrings).toHaveLength(1); + }, + ); + }); + }); + + it('rolls back on error', async () => { + await withController(async ({ controller, initialState }) => { + await expect( + controller.withController(async (restrictedController) => { + await restrictedController.addNewKeyring(KeyringTypes.simple); + throw new Error('Oops'); + }), + ).rejects.toThrow('Oops'); + + expect(controller.state.keyrings).toHaveLength( + initialState.keyrings.length, + ); + expect(await controller.getAccounts()).toStrictEqual( + initialState.keyrings[0].accounts, + ); + }); + }); + + it('does not update the vault if no keyrings change', async () => { + await withController(async ({ controller, encryptor }) => { + const encryptSpy = jest.spyOn(encryptor, 'encrypt'); + + await controller.withController(async () => { + // no-op + }); + + expect(encryptSpy).not.toHaveBeenCalled(); + }); + }); + }); + describe('withKeyringUnsafe', () => { it('calls the given function without acquiring the lock', async () => { await withController(async ({ controller }) => { @@ -5025,6 +5284,28 @@ describe('KeyringController', () => { }); }); + describe('withController', () => { + it('should call withController', async () => { + await withController(async ({ messenger }) => { + const operation = jest.fn().mockResolvedValue('result'); + + const actionReturnValue = await messenger.call( + 'KeyringController:withController', + operation, + ); + + expect(operation).toHaveBeenCalledWith( + expect.objectContaining({ + keyrings: expect.any(Array), + addNewKeyring: expect.any(Function), + removeKeyring: expect.any(Function), + }), + ); + expect(actionReturnValue).toBe('result'); + }); + }); + }); + describe('addNewKeyring', () => { it('should call addNewKeyring', async () => { const mockKeyringMetadata: KeyringMetadata = { diff --git a/packages/keyring-controller/src/KeyringController.ts b/packages/keyring-controller/src/KeyringController.ts index 415a912b9d6..85e1f0f0eff 100644 --- a/packages/keyring-controller/src/KeyringController.ts +++ b/packages/keyring-controller/src/KeyringController.ts @@ -67,6 +67,7 @@ const MESSENGER_EXPOSED_METHODS = [ 'patchUserOperation', 'signUserOperation', 'addNewAccount', + 'withController', 'withKeyring', 'withKeyringUnsafe', 'withKeyringV2', @@ -232,7 +233,10 @@ export type KeyringMetadata = { name: string; }; -type KeyringEntry = { +/** + * A keyring entry, including the keyring instance (+ v2 instance) and its metadata. + */ +export type KeyringEntry = { /** * The keyring instance. */ @@ -249,6 +253,37 @@ type KeyringEntry = { metadata: KeyringMetadata; }; +/** + * A restricted view of the {@link KeyringController} exposed to the callback + * passed to {@link KeyringController.withController}. + * + * It provides a read-only live view of all keyrings and the ability to stage + * keyring additions and removals atomically within a single transaction. + */ +export type RestrictedController = { + /** + * Read-only live view of all keyrings in the current transaction (original + * keyrings plus any added, minus any removed so far in this callback). + */ + readonly keyrings: readonly KeyringEntry[]; + /** + * Create a new keyring of the given type and stage it for commit. The new + * entry is immediately visible in {@link RestrictedController.keyrings}. + * + * @param type - The type of keyring to create. + * @param opts - Optional data to pass to the keyring builder. + * @returns The newly created `{ keyring, metadata }` entry. + */ + addNewKeyring(type: string, opts?: unknown): Promise; + /** + * Stage the keyring with the given id for removal. The keyring is + * immediately removed from {@link RestrictedController.keyrings}. + * + * @param id - The id of the keyring to remove. + */ + removeKeyring(id: string): Promise; +}; + /** * A strategy for importing an account */ @@ -668,10 +703,7 @@ function isSerializedKeyringsArray( async function displayForKeyring({ keyring, metadata, -}: { - keyring: EthKeyring; - metadata: KeyringMetadata; -}): Promise { +}: KeyringEntry): Promise { const accounts = await keyring.getAccounts(); return { @@ -1765,13 +1797,7 @@ export class KeyringController< CallbackResult = void, >( selector: KeyringSelector, - operation: ({ - keyring, - metadata, - }: { - keyring: SelectedKeyring; - metadata: KeyringMetadata; - }) => Promise, + operation: ({ keyring, metadata }: KeyringEntry) => Promise, // eslint-disable-next-line @typescript-eslint/unified-signatures options: | { createIfMissing?: false } @@ -1798,13 +1824,7 @@ export class KeyringController< CallbackResult = void, >( selector: KeyringSelector, - operation: ({ - keyring, - metadata, - }: { - keyring: SelectedKeyring; - metadata: KeyringMetadata; - }) => Promise, + operation: ({ keyring, metadata }: KeyringEntry) => Promise, ): Promise; async withKeyring< @@ -2069,6 +2089,117 @@ export class KeyringController< ); } + /** + * Execute an operation against all keyrings as a mutually exclusive atomic + * operation. The operation receives a {@link RestrictedController} instance + * that exposes a read-only live view of all keyrings as well as + * `addNewKeyring` and `removeKeyring` methods to stage mutations. + * + * The method automatically persists changes at the end of the function + * execution, or rolls back the changes if an error is thrown. + * + * @param operation - Function to execute with the restricted controller. + * @returns Promise resolving to the result of the function execution. + * @template CallbackResult - The type of the value resolved by the callback function. + */ + async withController( + operation: ( + restrictedController: RestrictedController, + ) => Promise, + ): Promise { + this.#assertIsUnlocked(); + + return this.#persistOrRollback(async () => { + // Track created and removed keyrings during the operation execution. + const createdEntries = new Set(); + const removedEntries = new Set(); + + // Copy of the current keyrings that is mutated during the operation execution. + const restrictedEntries = [...this.#keyrings]; + + // The restricted controller proxies the current keyrings and allows staging + // mutations that are only applied to the real keyrings if the operation + // completes successfully. This allows us to have a single source of truth + // for the keyrings during the operation execution, and to automatically + // roll back any changes if an error is thrown. + const restrictedController: RestrictedController = { + // We freeze the array to prevent direct mutations, but the keyring instances + // themselves are not frozen, allowing safe read-only access. + get keyrings() { + return Object.freeze([...restrictedEntries]); + }, + + // Method to create a new keyring and adds it to the restricted entries. + addNewKeyring: async (type: string, opts?: unknown) => { + const entry = await this.#createKeyring(type, opts); + + restrictedEntries.push(entry); + createdEntries.add(entry); + + return entry; + }, + + // Method to remove a keyring from the restricted entries. + removeKeyring: async (id: string) => { + const index = restrictedEntries.findIndex( + (entry) => entry.metadata.id === id, + ); + if (index === -1) { + throw new KeyringControllerError( + KeyringControllerErrorMessage.KeyringNotFound, + ); + } + + const [removed] = restrictedEntries.splice(index, 1) as [ + KeyringEntry, + ]; + removedEntries.add(removed); + }, + }; + + const destroyKeyrings = async ( + entries: Iterable, + ): Promise => { + await Promise.all( + [...entries].map(({ keyring, keyringV2 }) => + this.#destroyKeyring(keyring, keyringV2), + ), + ); + }; + + let result: CallbackResult; + try { + result = await operation(restrictedController); + } catch (error) { + await destroyKeyrings(createdEntries); + + throw error; + } + + await destroyKeyrings(removedEntries); + + // We update the real keyrings only after the operation completes successfully, so that + // they will be persisted in the vault. + this.#keyrings = restrictedEntries; + + // As usual, we want to prevent returning direct references to keyring instances, so we check + // the result for any unsafe direct access before returning. + for (const { keyring, keyringV2 } of [ + ...this.#keyrings, + // We also check for keyrings that got removed during the operation, since the result could + // still have references to them. + ...removedEntries, + ]) { + this.#assertNoUnsafeDirectKeyringAccess(result, keyring); + if (keyringV2) { + this.#assertNoUnsafeDirectKeyringAccess(result, keyringV2); + } + } + + return result; + }); + } + async getAccountKeyringType(account: string): Promise { this.#assertIsUnlocked(); diff --git a/packages/keyring-controller/src/index.ts b/packages/keyring-controller/src/index.ts index 9ac85f95897..4d570f0c98e 100644 --- a/packages/keyring-controller/src/index.ts +++ b/packages/keyring-controller/src/index.ts @@ -19,6 +19,7 @@ export type { KeyringControllerPrepareUserOperationAction, KeyringControllerPatchUserOperationAction, KeyringControllerSignUserOperationAction, + KeyringControllerWithControllerAction, KeyringControllerWithKeyringAction, KeyringControllerWithKeyringUnsafeAction, KeyringControllerWithKeyringV2Action, diff --git a/packages/keyring-controller/tests/mocks/mockKeyring.ts b/packages/keyring-controller/tests/mocks/mockKeyring.ts index 4bfbeb77c52..89bff0f1abd 100644 --- a/packages/keyring-controller/tests/mocks/mockKeyring.ts +++ b/packages/keyring-controller/tests/mocks/mockKeyring.ts @@ -32,4 +32,8 @@ export class MockKeyring implements EthKeyring { async deserialize(_: unknown): Promise { return Promise.resolve(); } + + async destroy(): Promise { + return Promise.resolve(); + } }