From 63af1f957e5a38a2066b2b826945cfc269bff142 Mon Sep 17 00:00:00 2001
From: femshima <49227365+femshima@users.noreply.github.com>
Date: Fri, 20 Oct 2023 19:11:41 +0900
Subject: [PATCH] add worker pool and task

---
Review notes (text in this section is ignored by `git am`):

* Line breaks and indentation were reconstructed from a copy whose newlines
  had been collapsed; the original hunk sizes agree with this layout.
* The generic type arguments `Promise<AudioResource>`,
  `Map<number, WorkerPoolTaskInfo>` and `Promise<Int16Array>` were missing
  from that copy and are restored on the assumption that they were lost in
  transit. TODO: confirm against the repository, in particular the removed
  `synthesis` signature, which must match the pre-image for this to apply.
* worker-pool: a worker that errors is now removed from `freeWorkers` as well
  as `workers`, its task info is deleted, list removal no longer splices at
  index -1, and bookkeeping happens before "error" is emitted. `close()`
  terminates all workers concurrently.
* synthesis: `_destroy` now invokes the stream callback, the end-of-data check
  uses `===`, and the Int16Array view is bounded to the chunk being written.
* The blob ids on the `index` lines are carried over unchanged and no longer
  describe the post-image.
* Not addressed here: a worker that fails on every setup is still respawned
  without limit.

 build.js                       |  10 +++
 src/synthesis.ts               |  90 ++++++++++++++++++++--
 src/synthesizer/worker-pool.ts | 137 +++++++++++++++++++++++++++++++++
 src/synthesizer/worker-task.ts |  47 ++++++++++++
 4 files changed, 279 insertions(+), 5 deletions(-)
 create mode 100644 src/synthesizer/worker-pool.ts
 create mode 100644 src/synthesizer/worker-task.ts

diff --git a/build.js b/build.js
index 8d4a2ee5..205c760a 100644
--- a/build.js
+++ b/build.js
@@ -10,3 +10,13 @@ await build({
   external: Object.keys(packageJson.dependencies),
   minify: true,
 });
+
+await build({
+  entryPoints: ["src/synthesizer/worker-task.ts"],
+  bundle: true,
+  outfile: "dist/worker-task.js",
+  platform: "node",
+  format: "esm",
+  external: Object.keys(packageJson.dependencies),
+  minify: true,
+});
diff --git a/src/synthesis.ts b/src/synthesis.ts
index a89b6abd..0cd62930 100644
--- a/src/synthesis.ts
+++ b/src/synthesis.ts
@@ -1,8 +1,88 @@
-import { Readable } from "stream";
-import { AudioResource, createAudioResource } from "@discordjs/voice";
+import { Readable } from "node:stream";
+import {
+  AudioResource,
+  StreamType,
+  createAudioResource,
+} from "@discordjs/voice";
 import { Message } from "discord.js";
+import SynthesizeWorkerPool from "./synthesizer/worker-pool";
 
-// eslint-disable-next-line @typescript-eslint/require-await
-export async function synthesis(_message: Message): Promise<AudioResource> {
-  return createAudioResource(new Readable());
+// Created only when both DICTIONARY and MODEL are set; NUM_THREADS, when
+// set, is passed to the pool as its thread count.
+const pool =
+  process.env.DICTIONARY && process.env.MODEL
+    ? new SynthesizeWorkerPool(
+        {
+          dictionary: process.env.DICTIONARY,
+          model: process.env.MODEL,
+        },
+        process.env.NUM_THREADS ? Number(process.env.NUM_THREADS) : undefined,
+      )
+    : undefined;
+
+/**
+ * Synthesizes the message's clean content and returns it as a raw audio
+ * resource. Content longer than 200 characters is cut to its first 190
+ * characters followed by " 以下略".
+ *
+ * @throws Error when the dictionary or model path was not provided.
+ */
+export async function synthesis(message: Message): Promise<AudioResource> {
+  if (!pool) throw new Error("Please provide path to the dictionary and model");
+
+  const content =
+    message.cleanContent.length > 200
+      ? `${message.cleanContent.slice(0, 190)} 以下略`
+      : message.cleanContent;
+
+  const data = await pool.synthesize(content, {
+    samplingFrequency: 48000,
+  });
+
+  return createAudioResource(new SynthesizedSoundStream(data), {
+    inputType: StreamType.Raw,
+  });
+}
+
+/**
+ * Readable stream over synthesized 16-bit samples. Every input sample is
+ * written twice, so each one occupies 4 bytes of output.
+ */
+class SynthesizedSoundStream extends Readable {
+  private pos = 0;
+  private buf: Int16Array | null;
+  constructor(buf: Int16Array) {
+    super();
+    this.buf = buf;
+  }
+  _read(size: number = ((48000 * 2 * 2) / 1000) * 20) {
+    if (!this.buf) {
+      throw new Error("Stream ended");
+    }
+
+    const offset = this.pos;
+    let end = Math.ceil(size / 4);
+    if (end + offset > this.buf.length) {
+      end = this.buf.length - offset;
+    }
+    const buf = Buffer.alloc(end * 4);
+    // View exactly this chunk's bytes rather than the whole backing buffer.
+    const dst = new Int16Array(buf.buffer, buf.byteOffset, end * 2);
+    for (let i = 0; i < end; ++i) {
+      const elem = this.buf[i + offset];
+      dst[i * 2] = elem;
+      dst[i * 2 + 1] = elem;
+    }
+    this.push(buf);
+    this.pos += end;
+    if (this.pos === this.buf.length) {
+      this.buf = null;
+      this.push(null);
+    }
+  }
+  _destroy(error: Error | null, callback: (error?: Error | null) => void) {
+    this.buf = null;
+    // Node requires the callback to be invoked for destruction to complete.
+    callback(error);
+  }
 }
diff --git a/src/synthesizer/worker-pool.ts b/src/synthesizer/worker-pool.ts
new file mode 100644
index 00000000..f62aa878
--- /dev/null
+++ b/src/synthesizer/worker-pool.ts
@@ -0,0 +1,137 @@
+import { AsyncResource } from "node:async_hooks";
+import { EventEmitter } from "node:events";
+import * as util from "node:util";
+import { Worker } from "node:worker_threads";
+import { AltJTalkConfig, SynthesisOption } from "node-altjtalk-binding";
+import { Result, Task } from "./worker-task";
+
+const kWorkerFreedEvent = Symbol("kWorkerFreedEvent");
+
+type Callback = (err: Error | null, result: Result | null) => void;
+
+/** Removes `item` from `list` in place; does nothing when it is absent. */
+function removeFrom<T>(list: T[], item: T): void {
+  const index = list.indexOf(item);
+  if (index !== -1) list.splice(index, 1);
+}
+
+/** Async resource that invokes a task's callback in its own async scope. */
+class WorkerPoolTaskInfo extends AsyncResource {
+  constructor(private callback: Callback) {
+    super("WorkerPoolTaskInfo");
+  }
+
+  done(err: Error, result: null): void;
+  done(err: null, result: Result): void;
+  done(err: Error | null, result: Result | null) {
+    this.runInAsyncScope(this.callback, null, err, result);
+    this.emitDestroy();
+  }
+}
+
+/**
+ * Pool of worker threads running `worker-task.js`. Tasks submitted while no
+ * worker is free are queued and dispatched as workers become available.
+ */
+export default class SynthesizeWorkerPool extends EventEmitter {
+  workers: Worker[];
+  freeWorkers: Worker[];
+  workerInfo = new Map<number, WorkerPoolTaskInfo>();
+  tasks: { task: Task; callback: Callback }[];
+
+  constructor(
+    private config: AltJTalkConfig,
+    numThreads?: number,
+  ) {
+    super();
+    this.workers = [];
+    this.freeWorkers = [];
+    this.tasks = [];
+
+    const threads = numThreads ?? 1;
+
+    for (let i = 0; i < threads; i++) this.addNewWorker();
+
+    // Any time the kWorkerFreedEvent is emitted, dispatch
+    // the next task pending in the queue, if any.
+    this.on(kWorkerFreedEvent, () => {
+      const nextTask = this.tasks.shift();
+      if (nextTask) {
+        const { task, callback } = nextTask;
+        this.runTask(task, callback);
+      }
+    });
+  }
+
+  addNewWorker() {
+    const worker = new Worker(new URL("worker-task.js", import.meta.url));
+    // Read once so every handler below uses the same map key.
+    const threadId = worker.threadId;
+    worker.postMessage({
+      type: "setup",
+      config: this.config,
+    } satisfies Task);
+
+    worker.on("message", (result: Result) => {
+      // In case of success: Call the callback that was passed to `runTask`,
+      // remove the `TaskInfo` associated with the Worker, and mark it as free
+      // again.
+      const info = this.workerInfo.get(threadId);
+      if (!info) return;
+
+      info.done(null, result);
+      this.workerInfo.delete(threadId);
+
+      this.freeWorkers.push(worker);
+      this.emit(kWorkerFreedEvent);
+    });
+    worker.on("error", (err) => {
+      const info = this.workerInfo.get(threadId);
+      this.workerInfo.delete(threadId);
+      // Drop the worker from both lists (it may have failed while idle, e.g.
+      // during setup) so no later task is posted to it, then replace it.
+      removeFrom(this.workers, worker);
+      removeFrom(this.freeWorkers, worker);
+      this.addNewWorker();
+      // In case of an uncaught exception: Call the callback that was passed to
+      // `runTask` with the error. Done last because emitting "error" throws
+      // when no listener is attached.
+      if (info) info.done(err, null);
+      else this.emit("error", err);
+    });
+    this.workers.push(worker);
+    this.freeWorkers.push(worker);
+    this.emit(kWorkerFreedEvent);
+  }
+
+  private runTask(task: Task, callback: Callback) {
+    const worker = this.freeWorkers.pop();
+    if (!worker) {
+      // No free threads, wait until a worker thread becomes free.
+      this.tasks.push({ task, callback });
+      return;
+    }
+
+    const info = new WorkerPoolTaskInfo(callback);
+    this.workerInfo.set(worker.threadId, info);
+    worker.postMessage(task);
+  }
+
+  public async synthesize(
+    inputText: string,
+    option: SynthesisOption,
+  ): Promise<Int16Array> {
+    const result = await util.promisify(this.runTask.bind(this))({
+      type: "task",
+      inputText,
+      option,
+    });
+    if (!result) throw new Error("Task returned error!");
+
+    return result.data;
+  }
+
+  async close() {
+    await Promise.all(this.workers.map((worker) => worker.terminate()));
+  }
+}
diff --git a/src/synthesizer/worker-task.ts b/src/synthesizer/worker-task.ts
new file mode 100644
index 00000000..36d6e3d8
--- /dev/null
+++ b/src/synthesizer/worker-task.ts
@@ -0,0 +1,47 @@
+import { parentPort } from "node:worker_threads";
+import {
+  AltJTalk,
+  AltJTalkConfig,
+  SynthesisOption,
+} from "node-altjtalk-binding";
+
+/** Messages accepted by the worker: a "setup" message or a synthesis "task". */
+export type Task =
+  | {
+      type: "setup";
+      config: AltJTalkConfig;
+    }
+  | {
+      type: "task";
+      inputText: string;
+      option: SynthesisOption;
+    };
+
+/** Message posted back to the parent after a "task" has been synthesized. */
+export type Result = {
+  type: "task";
+  data: Int16Array;
+};
+
+let synthesizer: AltJTalk | undefined = undefined;
+
+if (parentPort) {
+  parentPort.on("message", (task: Task) => {
+    if (!parentPort) return;
+
+    switch (task.type) {
+      case "setup":
+        synthesizer = AltJTalk.fromConfig(task.config);
+        break;
+      case "task": {
+        if (!synthesizer) throw new Error("Synthesizer is not initialized!");
+        const data = synthesizer.synthesize(task.inputText, task.option);
+        parentPort.postMessage({
+          type: "task",
+          data,
+        } satisfies Result);
+        break;
+      }
+    }
+  });
+}