Skip to content

Commit

Permalink
Added options to PollingController to help handle cases where the Con…
Browse files Browse the repository at this point in the history
…troller need more than just networkClientId (#1776)

## Explanation
Adds `options` parameter passed around to allow a unique key to be
derived per networkClientId + option combination. This handles cases
where a Controller need more than just networkClientId to poll on, like
an `address` for example.

Fixes https://github.com/MetaMask/MetaMask-planning/issues/1406

## Changelog

### `@metamask/polling-controller`
- **ADDED**: options to PollingController Mixin to help handle cases
where a Controller need more than just networkClientId

---------

Co-authored-by: Alex Donesky <[email protected]>
  • Loading branch information
shanejonas and adonesky1 authored Oct 11, 2023
1 parent 033bba6 commit d53d83a
Show file tree
Hide file tree
Showing 4 changed files with 120 additions and 54 deletions.
1 change: 1 addition & 0 deletions packages/polling-controller/package.json
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
"@metamask/network-controller": "^14.0.0",
"@metamask/utils": "^8.1.0",
"@types/uuid": "^8.3.0",
"fast-json-stable-stringify": "^2.1.0",
"uuid": "^8.3.2"
},
"devDependencies": {
Expand Down
68 changes: 52 additions & 16 deletions packages/polling-controller/src/PollingController.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -211,6 +211,40 @@ describe('PollingController', () => {
await Promise.resolve();
expect(controller.executePoll).toHaveBeenCalledTimes(2);
});
it('should start and stop polling sessions for different networkClientIds with the same options', async () => {
jest.useFakeTimers();

class MyGasFeeController extends PollingController<any, any, any> {
executePoll = createExecutePollMock();
}
const mockMessenger = new ControllerMessenger<any, any>();

const controller = new MyGasFeeController({
messenger: mockMessenger,
metadata: {},
name: 'PollingController',
state: { foo: 'bar' },
});
const pollToken1 = controller.startPollingByNetworkClientId('mainnet', {
address: '0x1',
});
controller.startPollingByNetworkClientId('mainnet', { address: '0x2' });
controller.startPollingByNetworkClientId('sepolia', { address: '0x2' });
jest.advanceTimersByTime(TICK_TIME);
await Promise.resolve();
expect(controller.executePoll).toHaveBeenCalledTimes(3);
controller.stopPollingByNetworkClientId(pollToken1);
jest.advanceTimersByTime(TICK_TIME);
await Promise.resolve();
expect(controller.executePoll).toHaveBeenCalledTimes(5);
expect(controller.executePoll.mock.calls).toMatchObject([
['mainnet', { address: '0x1' }],
['mainnet', { address: '0x2' }],
['sepolia', { address: '0x2' }],
['mainnet', { address: '0x2' }],
['sepolia', { address: '0x2' }],
]);
});
});
describe('multiple networkClientIds', () => {
it('should poll for each networkClientId', async () => {
Expand All @@ -231,16 +265,16 @@ describe('PollingController', () => {
jest.advanceTimersByTime(TICK_TIME);
await Promise.resolve();
expect(controller.executePoll.mock.calls).toMatchObject([
['mainnet'],
['rinkeby'],
['mainnet', {}],
['rinkeby', {}],
]);
jest.advanceTimersByTime(TICK_TIME);
await Promise.resolve();
expect(controller.executePoll.mock.calls).toMatchObject([
['mainnet'],
['rinkeby'],
['mainnet'],
['rinkeby'],
['mainnet', {}],
['rinkeby', {}],
['mainnet', {}],
['rinkeby', {}],
]);
controller.stopAllPolling();
});
Expand All @@ -267,27 +301,29 @@ describe('PollingController', () => {
expect(controller.executePoll.mock.calls).toMatchObject([]);
jest.advanceTimersByTime(TICK_TIME);
await Promise.resolve();
expect(controller.executePoll.mock.calls).toMatchObject([['mainnet']]);
expect(controller.executePoll.mock.calls).toMatchObject([
['mainnet', {}],
]);
jest.advanceTimersByTime(TICK_TIME);
await Promise.resolve();
expect(controller.executePoll.mock.calls).toMatchObject([
['mainnet'],
['sepolia'],
['mainnet', {}],
['sepolia', {}],
]);
jest.advanceTimersByTime(TICK_TIME);
await Promise.resolve();
expect(controller.executePoll.mock.calls).toMatchObject([
['mainnet'],
['sepolia'],
['mainnet'],
['mainnet', {}],
['sepolia', {}],
['mainnet', {}],
]);
jest.advanceTimersByTime(TICK_TIME);
await Promise.resolve();
expect(controller.executePoll.mock.calls).toMatchObject([
['mainnet'],
['sepolia'],
['mainnet'],
['sepolia'],
['mainnet', {}],
['sepolia', {}],
['mainnet', {}],
['sepolia', {}],
]);
});
});
Expand Down
104 changes: 66 additions & 38 deletions packages/polling-controller/src/PollingController.ts
Original file line number Diff line number Diff line change
@@ -1,11 +1,25 @@
import { BaseController, BaseControllerV2 } from '@metamask/base-controller';
import type { NetworkClientId } from '@metamask/network-controller';
import type { Json } from '@metamask/utils';
import stringify from 'fast-json-stable-stringify';
import { v4 as random } from 'uuid';

// Mixin classes require a constructor with an `...any[]` parameter
// See TS2545
type Constructor = new (...args: any[]) => object;

/**
* Returns a unique key for a networkClientId and options. This is used to group networkClientId polls with the same options
* @param networkClientId - The networkClientId to get a key for
* @param options - The options used to group the polling events
* @returns The unique key
*/
export const getKey = (
networkClientId: NetworkClientId,
options: Json,
): PollingGroupId => `${networkClientId}:${stringify(options)}`;

type PollingGroupId = `${NetworkClientId}:${string}`;
/**
* PollingControllerMixin
*
Expand All @@ -20,10 +34,9 @@ function PollingControllerMixin<TBase extends Constructor>(Base: TBase) {
*
*/
abstract class PollingControllerBase extends Base {
readonly #networkClientIdTokensMap: Map<NetworkClientId, Set<string>> =
new Map();
readonly #pollingTokenSets: Map<PollingGroupId, Set<string>> = new Map();

readonly #intervalIds: Record<NetworkClientId, NodeJS.Timeout> = {};
readonly #intervalIds: Record<PollingGroupId, NodeJS.Timeout> = {};

#callbacks: Map<
NetworkClientId,
Expand All @@ -49,28 +62,35 @@ function PollingControllerMixin<TBase extends Constructor>(Base: TBase) {
* Starts polling for a networkClientId
*
* @param networkClientId - The networkClientId to start polling for
* @param options - The options used to group the polling events
* @returns void
*/
startPollingByNetworkClientId(networkClientId: NetworkClientId) {
const innerPollToken = random();
if (this.#networkClientIdTokensMap.has(networkClientId)) {
const set = this.#networkClientIdTokensMap.get(networkClientId);
set?.add(innerPollToken);
startPollingByNetworkClientId(
networkClientId: NetworkClientId,
options: Json = {},
) {
const pollToken = random();

const key = getKey(networkClientId, options);

const pollingTokenSet = this.#pollingTokenSets.get(key);
if (pollingTokenSet) {
pollingTokenSet.add(pollToken);
} else {
const set = new Set<string>();
set.add(innerPollToken);
this.#networkClientIdTokensMap.set(networkClientId, set);
set.add(pollToken);
this.#pollingTokenSets.set(key, set);
}
this.#poll(networkClientId);
return innerPollToken;
this.#poll(networkClientId, options);
return pollToken;
}

/**
* Stops polling for all networkClientIds
*/
stopAllPolling() {
this.#networkClientIdTokensMap.forEach((tokens, _networkClientId) => {
tokens.forEach((token) => {
this.#pollingTokenSets.forEach((tokenSet, _networkClientId) => {
tokenSet.forEach((token) => {
this.stopPollingByNetworkClientId(token);
});
});
Expand All @@ -86,20 +106,18 @@ function PollingControllerMixin<TBase extends Constructor>(Base: TBase) {
throw new Error('pollingToken required');
}
let found = false;
this.#networkClientIdTokensMap.forEach((tokens, networkClientId) => {
if (tokens.has(pollingToken)) {
this.#pollingTokenSets.forEach((tokenSet, key) => {
if (tokenSet.has(pollingToken)) {
found = true;
this.#networkClientIdTokensMap
.get(networkClientId)
?.delete(pollingToken);
if (this.#networkClientIdTokensMap.get(networkClientId)?.size === 0) {
clearTimeout(this.#intervalIds[networkClientId]);
delete this.#intervalIds[networkClientId];
this.#networkClientIdTokensMap.delete(networkClientId);
this.#callbacks.get(networkClientId)?.forEach((callback) => {
callback(networkClientId);
tokenSet.delete(pollingToken);
if (tokenSet.size === 0) {
clearTimeout(this.#intervalIds[key]);
delete this.#intervalIds[key];
this.#pollingTokenSets.delete(key);
this.#callbacks.get(key)?.forEach((callback) => {
callback(key);
});
this.#callbacks.get(networkClientId)?.clear();
this.#callbacks.get(key)?.clear();
}
}
});
Expand All @@ -112,24 +130,29 @@ function PollingControllerMixin<TBase extends Constructor>(Base: TBase) {
* Executes the poll for a networkClientId
*
* @param networkClientId - The networkClientId to execute the poll for
* @param options - The options passed to startPollingByNetworkClientId
*/
abstract executePoll(networkClientId: NetworkClientId): Promise<void>;
abstract executePoll(
networkClientId: NetworkClientId,
options: Json,
): Promise<void>;

#poll(networkClientId: NetworkClientId) {
if (this.#intervalIds[networkClientId]) {
clearTimeout(this.#intervalIds[networkClientId]);
delete this.#intervalIds[networkClientId];
#poll(networkClientId: NetworkClientId, options: Json) {
const key = getKey(networkClientId, options);
if (this.#intervalIds[key]) {
clearTimeout(this.#intervalIds[key]);
delete this.#intervalIds[key];
}
// setTimeout is not `await`ing this async function, which is expected
// We're just using async here for improved stack traces
// eslint-disable-next-line @typescript-eslint/no-misused-promises
this.#intervalIds[networkClientId] = setTimeout(async () => {
this.#intervalIds[key] = setTimeout(async () => {
try {
await this.executePoll(networkClientId);
await this.executePoll(networkClientId, options);
} catch (error) {
console.error(error);
}
this.#poll(networkClientId);
this.#poll(networkClientId, options);
}, this.#intervalLength);
}

Expand All @@ -138,17 +161,22 @@ function PollingControllerMixin<TBase extends Constructor>(Base: TBase) {
*
* @param networkClientId - The networkClientId to listen for polling complete events
* @param callback - The callback to execute when polling is complete
* @param options - The options used to group the polling events
*/
onPollingCompleteByNetworkClientId(
networkClientId: NetworkClientId,
callback: (networkClientId: NetworkClientId) => void,
options: Json = {},
) {
if (this.#callbacks.has(networkClientId)) {
this.#callbacks.get(networkClientId)?.add(callback);
} else {
const key = getKey(networkClientId, options);
const callbacks = this.#callbacks.get(key);

if (callbacks === undefined) {
const set = new Set<typeof callback>();
set.add(callback);
this.#callbacks.set(networkClientId, set);
this.#callbacks.set(key, set);
} else {
callbacks.add(callback);
}
}
}
Expand Down
1 change: 1 addition & 0 deletions yarn.lock
Original file line number Diff line number Diff line change
Expand Up @@ -2164,6 +2164,7 @@ __metadata:
"@types/jest": ^27.4.1
"@types/uuid": ^8.3.0
deepmerge: ^4.2.2
fast-json-stable-stringify: ^2.1.0
jest: ^27.5.1
ts-jest: ^27.1.4
typedoc: ^0.24.8
Expand Down

0 comments on commit d53d83a

Please sign in to comment.