diff --git a/src/synthesis/common.ts b/src/synthesis/common.ts new file mode 100644 index 00000000..029ead19 --- /dev/null +++ b/src/synthesis/common.ts @@ -0,0 +1,22 @@ +import { AltJTalkConfig, SynthesisOption } from "node-altjtalk-binding"; + +export type Task = { + inputText: string; + option: SynthesisOption; +}; + +export type Result = { + type: "task"; + data: Int16Array; +}; + +export function isAltJTalkConfigValid(arg: unknown): arg is AltJTalkConfig { + return ( + typeof arg === "object" && + arg !== null && + "dictionary" in arg && + "model" in arg && + typeof arg.dictionary === "string" && + typeof arg.model === "string" + ); +} diff --git a/src/synthesis/index.ts b/src/synthesis/index.ts index 4606a7ab..26cf6e9b 100644 --- a/src/synthesis/index.ts +++ b/src/synthesis/index.ts @@ -1,6 +1,5 @@ import { Readable } from "node:stream"; import * as util from "node:util"; -import { Worker } from "node:worker_threads"; import { AudioResource, StreamType, @@ -8,22 +7,12 @@ import { } from "@discordjs/voice"; import { Message } from "discord.js"; import { AltJTalkConfig, SynthesisOption } from "node-altjtalk-binding"; -import { Result, Task } from "./task"; +import { Result, Task } from "./common"; import WorkerPool from "./worker-pool"; -class SynthesizeWorkerPool extends WorkerPool { - constructor( - private config: AltJTalkConfig, - numThreads?: number, - ) { - super(new URL("task.js", import.meta.url), numThreads ?? 1); - } - - protected override prepareWorker(worker: Worker): void { - worker.postMessage({ - type: "setup", - config: this.config, - } satisfies Task); +class SynthesizeWorkerPool extends WorkerPool { + constructor(config: AltJTalkConfig, numThreads?: number) { + super(new URL("task", import.meta.url), config, numThreads ?? 1); } public async synthesize( @@ -31,15 +20,11 @@ class SynthesizeWorkerPool extends WorkerPool { option: SynthesisOption, ): Promise { const result = await util.promisify(this.runTask.bind(this))({ - type: "task", inputText, option, }); - if (!result) throw new Error("Task returned error!"); - if (result.type !== "task") - throw new Error("Task returned wrong type of response!"); - - return result?.data; + if (result) return result?.data; + else throw new Error("Task returned error!"); } } diff --git a/src/synthesis/task.ts b/src/synthesis/task.ts index 36d6e3d8..28b9acd3 100644 --- a/src/synthesis/task.ts +++ b/src/synthesis/task.ts @@ -1,45 +1,19 @@ -import { parentPort } from "node:worker_threads"; -import { - AltJTalk, - AltJTalkConfig, - SynthesisOption, -} from "node-altjtalk-binding"; +import { parentPort, workerData } from "node:worker_threads"; +import { AltJTalk } from "node-altjtalk-binding"; +import { Result, Task, isAltJTalkConfigValid } from "./common"; -export type Task = - | { - type: "setup"; - config: AltJTalkConfig; - } - | { - type: "task"; - inputText: string; - option: SynthesisOption; - }; +if (!isAltJTalkConfigValid(workerData)) + throw new Error("AltJTalk config is invalid."); -export type Result = { - type: "task"; - data: Int16Array; -}; - -let synthesizer: AltJTalk | undefined = undefined; +const synthesizer: AltJTalk = AltJTalk.fromConfig(workerData); if (parentPort) { parentPort.on("message", (task: Task) => { if (!parentPort) return; - - switch (task.type) { - case "setup": - synthesizer = AltJTalk.fromConfig(task.config); - break; - case "task": { - if (!synthesizer) throw new Error("Synthesizer is not initialized!"); - const data = synthesizer.synthesize(task.inputText, task.option); - parentPort.postMessage({ - type: "task", - data, - } satisfies Result); - break; - } - } + const data = synthesizer.synthesize(task.inputText, task.option); + parentPort.postMessage({ + type: "task", + data, + } satisfies Result); }); } diff --git a/src/synthesis/worker-pool.ts b/src/synthesis/worker-pool.ts index bc7d8d7a..a0ab9548 100644 --- a/src/synthesis/worker-pool.ts +++ b/src/synthesis/worker-pool.ts @@ -4,6 +4,11 @@ import { AsyncResource } from "node:async_hooks"; import { EventEmitter } from "node:events"; import { Worker } from "node:worker_threads"; +const isTsNode = () => { + const tsNodeSymbol = Symbol.for("ts-node.register.instance"); + return tsNodeSymbol in process && !!process[tsNodeSymbol]; +}; + const kWorkerFreedEvent = Symbol("kWorkerFreedEvent"); type Callback = (err: Error | null, result: R | null) => void; @@ -21,14 +26,15 @@ class WorkerPoolTaskInfo extends AsyncResource { } } -export default class WorkerPool extends EventEmitter { - workers: Worker[]; - freeWorkers: Worker[]; - workerInfo = new Map>(); - tasks: { task: T; callback: Callback }[]; +export default class WorkerPool extends EventEmitter { + private workers: Worker[]; + private freeWorkers: Worker[]; + private workerInfo = new Map>(); + private tasks: { task: T; callback: Callback }[]; constructor( - private workerPath: URL, + protected workerPath: string | URL, + protected workerData: W, numThreads: number, ) { super(); @@ -50,8 +56,10 @@ export default class WorkerPool extends EventEmitter { } protected addNewWorker() { - const worker = new Worker(this.workerPath); - this.prepareWorker(worker); + const worker = new Worker(this.workerPath, { + execArgv: isTsNode() ? ["--loader", "ts-node/esm"] : undefined, + workerData: this.workerData, + }); worker.on("message", (result: R) => { // In case of success: Call the callback that was passed to `runTask`, // remove the `TaskInfo` associated with the Worker, and mark it as free @@ -81,9 +89,6 @@ export default class WorkerPool extends EventEmitter { this.emit(kWorkerFreedEvent); } - // eslint-disable-next-line @typescript-eslint/no-unused-vars - protected prepareWorker(worker: Worker) {} - protected runTask(task: T, callback: Callback) { const worker = this.freeWorkers.pop(); if (!worker) {