From a94b06c6db591965a6f57bce72dfdcc7ccf923bf Mon Sep 17 00:00:00 2001 From: aether <71261542+aetherpw@users.noreply.github.com> Date: Tue, 9 Jun 2026 12:06:43 +0200 Subject: [PATCH] feat: indomitable strategy --- package.json | 4 +- pnpm-lock.yaml | 3 +- src/client/ShardClientUtil.ts | 11 + src/index.ts | 1 + .../IndomitableShardingStrategyInterface.ts | 14 + .../IndomitableWorkerShardingStrategy.ts | 403 ++++++++++++++++++ src/strategies/defaultWorker.ts | 10 + src/utils/worker.ts | 186 ++++++++ 8 files changed, 630 insertions(+), 2 deletions(-) create mode 100644 src/strategies/IndomitableShardingStrategyInterface.ts create mode 100644 src/strategies/IndomitableWorkerShardingStrategy.ts create mode 100644 src/strategies/defaultWorker.ts create mode 100644 src/utils/worker.ts diff --git a/package.json b/package.json index 788aa17..cbe1b0d 100644 --- a/package.json +++ b/package.json @@ -51,8 +51,10 @@ "lint:fix": "eslint src --fix", "prepack": "pnpm run build" }, + "dependencies": { + "@discordjs/ws": "^2.0.4" + }, "devDependencies": { - "@discordjs/ws": "^2.0.4", "@types/node": "^25.9.1", "@types/ws": "^8.18.1", "discord.js": "^14.26.4", diff --git a/pnpm-lock.yaml b/pnpm-lock.yaml index 7895fe9..dd3b0d8 100644 --- a/pnpm-lock.yaml +++ b/pnpm-lock.yaml @@ -8,10 +8,11 @@ settings: importers: .: - devDependencies: + dependencies: '@discordjs/ws': specifier: ^2.0.4 version: 2.0.4 + devDependencies: '@types/node': specifier: ^25.9.1 version: 25.9.1 diff --git a/src/client/ShardClientUtil.ts b/src/client/ShardClientUtil.ts index 8367da6..c1c22fd 100644 --- a/src/client/ShardClientUtil.ts +++ b/src/client/ShardClientUtil.ts @@ -1,11 +1,13 @@ import EventEmitter from "node:events"; import process from "node:process"; import { clearTimeout } from "node:timers"; +import type { WebSocketManager } from "@discordjs/ws"; import type { Client } from "discord.js"; import type { Indomitable } from "../Indomitable.js"; import type { AbortableData, InternalOpsData, Message, SessionObject, Transportable } from "../Util.js"; import { EnvProcessData, MakeAbortableRequest, InternalOps } from "../Util.js"; import { ClientWorker } from "../ipc/ClientWorker.js"; +import type { IndomitableShardingStrategyInterface } from "../strategies/IndomitableShardingStrategyInterface.js"; export interface ShardClientUtilEvents { message: [message: Message]; @@ -53,6 +55,15 @@ export class ShardClientUtil extends EventEmitter { return Number(BigInt(end) - start); } + /** + * Returns the current WebSocket sharding strategy. + */ + public get strategy(): IndomitableShardingStrategyInterface { + // @ts-expect-error internal field + // eslint-disable-next-line @typescript-eslint/dot-notation + return (this.client.ws["_ws"] as WebSocketManager).strategy; + } + /** * Evaluates a script or function on all clusters in the context of the client * diff --git a/src/index.ts b/src/index.ts index 94e8dab..21118b2 100644 --- a/src/index.ts +++ b/src/index.ts @@ -1,6 +1,7 @@ export * from "./client/ShardClient.js"; export * from "./client/ShardClientUtil.js"; export * from "./concurrency/ConcurrencyManager.js"; +export * from "./strategies/IndomitableWorkerShardingStrategy.js"; export * from "./ipc/BaseIpc.js"; export * from "./ipc/MainWorker.js"; export * from "./ipc/BaseWorker.js"; diff --git a/src/strategies/IndomitableShardingStrategyInterface.ts b/src/strategies/IndomitableShardingStrategyInterface.ts new file mode 100644 index 0000000..6199a9b --- /dev/null +++ b/src/strategies/IndomitableShardingStrategyInterface.ts @@ -0,0 +1,14 @@ +/** + * This file is adapted from discord.js and includes additional modifications. + * + * Original Apache 2.0 license: + * https://github.com/discordjs/discord.js/blob/3d6121589f9c0d91f7cf4976307e8be07053a277/LICENSE + */ +import type { IShardingStrategy, WebSocketShardDestroyOptions } from "@discordjs/ws"; +import type { Awaitable } from "discord.js"; +import type { WebsocketShardState } from "./IndomitableWorkerShardingStrategy"; + +export interface IndomitableShardingStrategyInterface extends Omit { + destroyShards(shardIds: number[], options?: WebSocketShardDestroyOptions): Awaitable; + fetchStatus(): Awaitable>; +} diff --git a/src/strategies/IndomitableWorkerShardingStrategy.ts b/src/strategies/IndomitableWorkerShardingStrategy.ts new file mode 100644 index 0000000..a9eaf0c --- /dev/null +++ b/src/strategies/IndomitableWorkerShardingStrategy.ts @@ -0,0 +1,403 @@ +/** + * This file is adapted from discord.js and includes additional modifications. + * + * Original Apache 2.0 license: + * https://github.com/discordjs/discord.js/blob/3d6121589f9c0d91f7cf4976307e8be07053a277/LICENSE + */ +import { once } from "node:events"; +import { join, isAbsolute, resolve } from "node:path"; +import { Worker } from "node:worker_threads"; +import { + managerToFetchingStrategyOptions, + type FetchingStrategyOptions, + type IIdentifyThrottler, + type SessionInfo, + type WebSocketManager, + type WebSocketShardDestroyOptions, + type WebSocketShardEvents, + type WebSocketShardStatus, +} from "@discordjs/ws"; +import type { GatewaySendPayload } from "discord.js"; +import type { IndomitableShardingStrategyInterface } from "./IndomitableShardingStrategyInterface.js"; + +export interface WebsocketShardState { + lastHeartbeatAt: number; + sendQueueRemaining: number; + state: WebSocketShardStatus; +} +export interface WorkerData extends FetchingStrategyOptions { + shardIds: number[]; +} + +export enum WorkerSendPayloadOp { + Connect, + Destroy, + Send, + SessionInfoResponse, + ShardIdentifyResponse, + FetchStatus, +} + +export type WorkerSendPayload = + | { nonce: number; ok: boolean; op: WorkerSendPayloadOp.ShardIdentifyResponse } + | { nonce: number; op: WorkerSendPayloadOp.FetchStatus; shardId: number } + | { nonce: number; op: WorkerSendPayloadOp.SessionInfoResponse; session: SessionInfo | null } + | { op: WorkerSendPayloadOp.Connect; shardId: number } + | { op: WorkerSendPayloadOp.Destroy; options?: WebSocketShardDestroyOptions; shardId: number } + | { op: WorkerSendPayloadOp.Send; payload: GatewaySendPayload; shardId: number }; + +export enum WorkerReceivePayloadOp { + Connected, + Destroyed, + ShardDestroyed, + Event, + RetrieveSessionInfo, + UpdateSessionInfo, + WaitForIdentify, + FetchStatusResponse, + WorkerReady, + CancelIdentify, +} + +export type WorkerReceivePayload = + | { + nonce: number; + op: WorkerReceivePayloadOp.FetchStatusResponse; + status: WebsocketShardState; + } + | { data: any[]; event: WebSocketShardEvents; op: WorkerReceivePayloadOp.Event; shardId: number } + | { nonce: number; op: WorkerReceivePayloadOp.CancelIdentify } + | { nonce: number; op: WorkerReceivePayloadOp.RetrieveSessionInfo; shardId: number } + | { nonce: number; op: WorkerReceivePayloadOp.WaitForIdentify; shardId: number } + | { op: WorkerReceivePayloadOp.Connected; shardId: number } + | { op: WorkerReceivePayloadOp.Destroyed; shardId: number } + | { op: WorkerReceivePayloadOp.UpdateSessionInfo; session: SessionInfo | null; shardId: number } + | { op: WorkerReceivePayloadOp.WorkerReady }; + +/** + * Options for a {@link WorkerShardingStrategy} + */ +export interface WorkerShardingStrategyOptions { + /** + * Dictates how many shards should be spawned per worker thread. + */ + shardsPerWorker: number | "all"; + /** + * Handles a payload not recognized by the handler. + */ + unknownPayloadHandler?(payload: any): unknown; + /** + * Path to the worker file to use. The worker requires quite a bit of setup, it is recommended you leverage the {@link WorkerBootstrapper} class. + */ + workerPath?: string; +} + +/** + * Strategy used to spawn threads in worker_threads + */ +export class WorkerShardingStrategy implements IndomitableShardingStrategyInterface { + #workers: Worker[] = []; + + readonly #workerByShardId = new Map(); + + private readonly connectPromises = new Map void>(); + + private readonly destroyPromises = new Map void>(); + + private readonly fetchStatusPromises = new Map void>(); + + private readonly waitForIdentifyControllers = new Map(); + + private readonly destroyShardPromises = new Map void>(); + + private throttler?: IIdentifyThrottler; + + public constructor( + private readonly manager: WebSocketManager, + private readonly options: WorkerShardingStrategyOptions, + ) {} + + public async spawn(shardIds: number[]) { + const shardsPerWorker = this.options.shardsPerWorker === "all" ? shardIds.length : this.options.shardsPerWorker; + const strategyOptions = await managerToFetchingStrategyOptions(this.manager); + + const loops = Math.ceil(shardIds.length / shardsPerWorker); + const promises: Promise[] = []; + + for (let idx = 0; idx < loops; idx++) { + const slice = shardIds.slice(idx * shardsPerWorker, (idx + 1) * shardsPerWorker); + const workerData: WorkerData = { + ...strategyOptions, + shardIds: slice, + }; + + promises.push(this.setupWorker(workerData)); + } + + await Promise.all(promises); + } + + public async connect() { + const promises = []; + + for (const [shardId, worker] of this.#workerByShardId.entries()) { + const payload: WorkerSendPayload = { + op: WorkerSendPayloadOp.Connect, + shardId, + }; + + // eslint-disable-next-line no-promise-executor-return + const promise = new Promise((resolve) => this.connectPromises.set(shardId, resolve)); + worker.postMessage(payload); + promises.push(promise); + } + + await Promise.all(promises); + } + + public async destroy(options: Omit = {}) { + const promises = []; + + for (const [shardId, worker] of this.#workerByShardId.entries()) { + const payload: WorkerSendPayload = { + op: WorkerSendPayloadOp.Destroy, + shardId, + options, + }; + + promises.push( + // eslint-disable-next-line no-promise-executor-return, promise/prefer-await-to-then + new Promise((resolve) => this.destroyPromises.set(shardId, resolve)).then(async () => worker.terminate()), + ); + worker.postMessage(payload); + } + + this.#workers = []; + this.#workerByShardId.clear(); + + await Promise.all(promises); + } + + public async destroyShards(shardIds: number[], options?: WebSocketShardDestroyOptions): Promise { + const promises: Promise[] = []; + + for (const shardId of shardIds) { + const worker = this.#workerByShardId.get(shardId); + if (!worker) { + throw new RangeError(`Shard ${shardId} not found`); + } + + const payload: WorkerSendPayload = { + op: WorkerSendPayloadOp.Destroy, + shardId, + options, + }; + + // eslint-disable-next-line no-promise-executor-return + promises.push(new Promise((resolve) => this.destroyShardPromises.set(shardId, resolve))); + worker.postMessage(payload); + } + + await Promise.all(promises); + + if (options?.recover === undefined) { + for (const shardId of shardIds) { + this.#workerByShardId.delete(shardId); + } + } + } + + public send(shardId: number, data: GatewaySendPayload) { + const worker = this.#workerByShardId.get(shardId); + if (!worker) { + throw new Error(`No worker found for shard ${shardId}`); + } + + const payload: WorkerSendPayload = { + op: WorkerSendPayloadOp.Send, + shardId, + payload: data, + }; + worker.postMessage(payload); + } + + public async fetchStatus() { + const statuses = new Map(); + + for (const [shardId, worker] of this.#workerByShardId.entries()) { + const nonce = Math.random(); + const payload: WorkerSendPayload = { + op: WorkerSendPayloadOp.FetchStatus, + shardId, + nonce, + }; + + const promise = new Promise((resolve) => { + this.fetchStatusPromises.set(nonce, resolve); + }); + + worker.postMessage(payload); + + const status = await promise; + statuses.set(shardId, status); + } + + return statuses; + } + + private async setupWorker(workerData: WorkerData) { + const worker = new Worker(this.resolveWorkerPath(), { workerData }); + + await once(worker, "online"); + // We do this in case the user has any potentially long running code in their worker + await this.waitForWorkerReady(worker); + + worker + .on("error", (err) => { + throw err; + }) + .on("messageerror", (err) => { + throw err; + }) + .on("message", async (payload: any) => { + if ("op" in payload) { + await this.onMessage(worker, payload); + } else { + await this.options.unknownPayloadHandler?.(payload); + } + }); + + this.#workers.push(worker); + for (const shardId of workerData.shardIds) { + this.#workerByShardId.set(shardId, worker); + } + } + + private resolveWorkerPath(): string { + const path = this.options.workerPath; + + if (!path) { + return join(__dirname, "defaultWorker.js"); + } + + if (isAbsolute(path)) { + return path; + } + + if (/^\.\.?[/\\]/.test(path)) { + return resolve(path); + } + + try { + return require.resolve(path); + } catch { + return resolve(path); + } + } + + private async waitForWorkerReady(worker: Worker): Promise { + return new Promise((resolve) => { + const handler = (payload: WorkerReceivePayload) => { + if (payload.op === WorkerReceivePayloadOp.WorkerReady) { + resolve(); + worker.off("message", handler); + } + }; + + worker.on("message", handler); + }); + } + + private async onMessage(worker: Worker, payload: WorkerReceivePayload) { + switch (payload.op) { + case WorkerReceivePayloadOp.Connected: { + this.connectPromises.get(payload.shardId)?.(); + this.connectPromises.delete(payload.shardId); + break; + } + + case WorkerReceivePayloadOp.Destroyed: { + this.destroyPromises.get(payload.shardId)?.(); + this.destroyPromises.delete(payload.shardId); + break; + } + + case WorkerReceivePayloadOp.Event: { + // @ts-expect-error Event props can't be resolved properly, but they are correct + this.manager.emit(payload.event, ...payload.data, payload.shardId); + break; + } + + case WorkerReceivePayloadOp.RetrieveSessionInfo: { + const session = await this.manager.options.retrieveSessionInfo(payload.shardId); + const response: WorkerSendPayload = { + op: WorkerSendPayloadOp.SessionInfoResponse, + nonce: payload.nonce, + session, + }; + worker.postMessage(response); + break; + } + + case WorkerReceivePayloadOp.UpdateSessionInfo: { + await this.manager.options.updateSessionInfo(payload.shardId, payload.session); + break; + } + + case WorkerReceivePayloadOp.WaitForIdentify: { + const throttler = await this.ensureThrottler(); + + // If this rejects it means we aborted, in which case we reply elsewhere. + try { + const controller = new AbortController(); + this.waitForIdentifyControllers.set(payload.nonce, controller); + await throttler.waitForIdentify(payload.shardId, controller.signal); + } catch { + return; + } + + const response: WorkerSendPayload = { + op: WorkerSendPayloadOp.ShardIdentifyResponse, + nonce: payload.nonce, + ok: true, + }; + worker.postMessage(response); + break; + } + + case WorkerReceivePayloadOp.FetchStatusResponse: { + this.fetchStatusPromises.get(payload.nonce)?.(payload.status); + this.fetchStatusPromises.delete(payload.nonce); + break; + } + + case WorkerReceivePayloadOp.WorkerReady: { + break; + } + + case WorkerReceivePayloadOp.CancelIdentify: { + this.waitForIdentifyControllers.get(payload.nonce)?.abort(); + this.waitForIdentifyControllers.delete(payload.nonce); + + const response: WorkerSendPayload = { + op: WorkerSendPayloadOp.ShardIdentifyResponse, + nonce: payload.nonce, + ok: false, + }; + worker.postMessage(response); + + break; + } + + default: { + await this.options.unknownPayloadHandler?.(payload); + break; + } + } + } + + private async ensureThrottler(): Promise { + this.throttler ??= await this.manager.options.buildIdentifyThrottler(this.manager); + return this.throttler; + } +} diff --git a/src/strategies/defaultWorker.ts b/src/strategies/defaultWorker.ts new file mode 100644 index 0000000..39aba7d --- /dev/null +++ b/src/strategies/defaultWorker.ts @@ -0,0 +1,10 @@ +/** + * This file is adapted from discord.js and includes additional modifications. + * + * Original Apache 2.0 license: + * https://github.com/discordjs/discord.js/blob/3d6121589f9c0d91f7cf4976307e8be07053a277/LICENSE + */ +import { WorkerBootstrapper } from "../utils/worker.js"; + +const bootstrapper = new WorkerBootstrapper(); +void bootstrapper.bootstrap(); diff --git a/src/utils/worker.ts b/src/utils/worker.ts new file mode 100644 index 0000000..f6f50a1 --- /dev/null +++ b/src/utils/worker.ts @@ -0,0 +1,186 @@ +/** + * This file is adapted from discord.js and includes additional modifications. + * + * Original Apache 2.0 license: + * https://github.com/discordjs/discord.js/blob/3d6121589f9c0d91f7cf4976307e8be07053a277/LICENSE + */ +import { isMainThread, parentPort, workerData } from "node:worker_threads"; +import type { WebSocketShardDestroyOptions, WorkerData, WorkerSendPayload } from "@discordjs/ws"; +import { + WebSocketShard, + WebSocketShardEvents, + WorkerContextFetchingStrategy, + WorkerSendPayloadOp, +} from "@discordjs/ws"; +import { type Awaitable } from "discord.js"; +import type { WorkerReceivePayload } from "../strategies/IndomitableWorkerShardingStrategy.js"; +import { WorkerReceivePayloadOp } from "../strategies/IndomitableWorkerShardingStrategy.js"; + +/** + * Options for bootstrapping the worker + */ +export interface BootstrapOptions { + /** + * Shard events to just arbitrarily forward to the parent thread for the manager to emit + * Note: By default, this will include ALL events + * you most likely want to handle dispatch within the worker itself + */ + forwardEvents?: WebSocketShardEvents[]; + /** + * Function to call when a shard is created for additional setup + */ + shardCallback?(shard: WebSocketShard): Awaitable; +} + +/** + * Utility class for bootstrapping a worker thread to be used for sharding + */ +export class WorkerBootstrapper { + /** + * The data passed to the worker thread + */ + protected readonly data = workerData as WorkerData; + + /** + * The shards that are managed by this worker + */ + protected readonly shards = new Map(); + + public constructor() { + if (isMainThread) { + throw new Error("Expected WorkerBootstrap to not be used within the main thread"); + } + } + + /** + * Helper method to initiate a shard's connection process + */ + protected async connect(shardId: number): Promise { + const shard = this.shards.get(shardId); + if (!shard) { + throw new RangeError(`Shard ${shardId} does not exist`); + } + + await shard.connect(); + } + + /** + * Helper method to destroy a shard + */ + protected async destroy(shardId: number, options?: WebSocketShardDestroyOptions): Promise { + const shard = this.shards.get(shardId); + if (!shard) { + throw new RangeError(`Shard ${shardId} does not exist`); + } + + await shard.destroy(options); + } + + /** + * Helper method to attach event listeners to the parentPort + */ + protected setupThreadEvents(): void { + parentPort! + .on("messageerror", (err) => { + throw err; + }) + .on("message", async (payload: WorkerSendPayload) => { + switch (payload.op) { + case WorkerSendPayloadOp.Connect: { + await this.connect(payload.shardId); + const response: WorkerReceivePayload = { + op: WorkerReceivePayloadOp.Connected, + shardId: payload.shardId, + }; + parentPort!.postMessage(response); + break; + } + + case WorkerSendPayloadOp.Destroy: { + await this.destroy(payload.shardId, payload.options); + const response: WorkerReceivePayload = { + op: WorkerReceivePayloadOp.Destroyed, + shardId: payload.shardId, + }; + + parentPort!.postMessage(response); + break; + } + + case WorkerSendPayloadOp.Send: { + const shard = this.shards.get(payload.shardId); + if (!shard) { + throw new RangeError(`Shard ${payload.shardId} does not exist`); + } + + await shard.send(payload.payload); + break; + } + + case WorkerSendPayloadOp.SessionInfoResponse: { + break; + } + + case WorkerSendPayloadOp.ShardIdentifyResponse: { + break; + } + + case WorkerSendPayloadOp.FetchStatus: { + const shard = this.shards.get(payload.shardId); + if (!shard) { + throw new Error(`Shard ${payload.shardId} does not exist`); + } + + const response: WorkerReceivePayload = { + op: WorkerReceivePayloadOp.FetchStatusResponse, + status: { + state: shard.status, + // eslint-disable-next-line @typescript-eslint/dot-notation + lastHeartbeatAt: shard["lastHeartbeatAt"], + // eslint-disable-next-line @typescript-eslint/dot-notation + sendQueueRemaining: shard["sendQueue"].remaining, + }, + nonce: payload.nonce, + }; + + parentPort!.postMessage(response); + break; + } + } + }); + } + + /** + * Bootstraps the worker thread with the provided options + */ + public async bootstrap(options: Readonly = {}): Promise { + // Start by initializing the shards + for (const shardId of this.data.shardIds) { + const shard = new WebSocketShard(new WorkerContextFetchingStrategy(this.data), shardId); + for (const event of options.forwardEvents ?? Object.values(WebSocketShardEvents)) { + shard.on(event, (...args: unknown[]) => { + const payload: WorkerReceivePayload = { + op: WorkerReceivePayloadOp.Event, + event, + data: args, + shardId, + }; + + parentPort!.postMessage(payload); + }); + } + + // Any additional setup the user might want to do + await options.shardCallback?.(shard); + this.shards.set(shardId, shard); + } + + // Lastly, start listening to messages from the parent thread + this.setupThreadEvents(); + + const message: WorkerReceivePayload = { + op: WorkerReceivePayloadOp.WorkerReady, + }; + parentPort!.postMessage(message); + } +}