Skip to content

Commit

Permalink
refactor worker pool
Browse files Browse the repository at this point in the history
  • Loading branch information
phenylshima committed Oct 20, 2023
1 parent 83f3fb4 commit 2ecb7f1
Show file tree
Hide file tree
Showing 2 changed files with 80 additions and 64 deletions.
84 changes: 60 additions & 24 deletions src/synthesis/index.ts
Original file line number Diff line number Diff line change
@@ -1,38 +1,46 @@
import { Readable } from "node:stream";
import * as util from "node:util";
import { Worker } from "node:worker_threads";
import {
AudioResource,
StreamType,
createAudioResource,
} from "@discordjs/voice";
import { Message } from "discord.js";
import SynthesizeWorkerPool from "./worker-pool";
import { AltJTalkConfig, SynthesisOption } from "node-altjtalk-binding";
import WorkerPool from "./worker-pool";
import { Result, Task } from "./worker-task";

const pool =
process.env.DICTIONARY && process.env.MODEL
? new SynthesizeWorkerPool(
{
dictionary: process.env.DICTIONARY,
model: process.env.MODEL,
},
process.env.NUM_THREADS ? Number(process.env.NUM_THREADS) : undefined,
)
: undefined;

export async function synthesis(message: Message): Promise<AudioResource> {
if (!pool) throw new Error("Please provide path to the dictionary and model");
class SynthesizeWorkerPool extends WorkerPool<Task, Result> {
constructor(
private config: AltJTalkConfig,
numThreads?: number,
) {
super(new URL("worker-task.js", import.meta.url), numThreads ?? 1);
}

const content =
message.cleanContent.length > 200
? `${message.cleanContent.slice(0, 190)} 以下略`
: message.cleanContent;
protected override prepareWorker(worker: Worker): void {
worker.postMessage({
type: "setup",
config: this.config,
} satisfies Task);
}

const data = await pool.synthesize(content, {
samplingFrequency: 48000,
});
public async synthesize(
inputText: string,
option: SynthesisOption,
): Promise<Int16Array> {
const result = await util.promisify(this.runTask.bind(this))({
type: "task",
inputText,
option,
});
if (!result) throw new Error("Task returned error!");
if (result.type !== "task")
throw new Error("Task returned wrong type of response!");

return createAudioResource(new SynthesizedSoundStream(data), {
inputType: StreamType.Raw,
});
return result?.data;
}
}

class SynthesizedSoundStream extends Readable {
Expand Down Expand Up @@ -70,3 +78,31 @@ class SynthesizedSoundStream extends Readable {
this.buf = null;
}
}

const pool =
process.env.DICTIONARY && process.env.MODEL
? new SynthesizeWorkerPool(
{
dictionary: process.env.DICTIONARY,
model: process.env.MODEL,
},
process.env.NUM_THREADS ? Number(process.env.NUM_THREADS) : undefined,
)
: undefined;

export async function synthesis(message: Message): Promise<AudioResource> {
if (!pool) throw new Error("Please provide path to the dictionary and model");

const content =
message.cleanContent.length > 200
? `${message.cleanContent.slice(0, 190)} 以下略`
: message.cleanContent;

const data = await pool.synthesize(content, {
samplingFrequency: 48000,
});

return createAudioResource(new SynthesizedSoundStream(data), {
inputType: StreamType.Raw,
});
}
60 changes: 20 additions & 40 deletions src/synthesis/worker-pool.ts
Original file line number Diff line number Diff line change
@@ -1,45 +1,40 @@
import { AsyncResource } from "node:async_hooks";
import { EventEmitter } from "node:events";
import * as util from "node:util";
import { Worker } from "node:worker_threads";
import { AltJTalkConfig, SynthesisOption } from "node-altjtalk-binding";
import { Result, Task } from "./worker-task";

const kWorkerFreedEvent = Symbol("kWorkerFreedEvent");

type Callback = (err: Error | null, result: Result | null) => void;
type Callback<R> = (err: Error | null, result: R | null) => void;

class WorkerPoolTaskInfo extends AsyncResource {
constructor(private callback: Callback) {
class WorkerPoolTaskInfo<R> extends AsyncResource {
constructor(private callback: Callback<R>) {
super("WorkerPoolTaskInfo");
}

done(err: Error, result: null): void;
done(err: null, result: Result): void;
done(err: Error | null, result: Result | null) {
done(err: null, result: R): void;
done(err: Error | null, result: R | null) {
this.runInAsyncScope(this.callback, null, err, result);
this.emitDestroy();
}
}

export default class SynthesizeWorkerPool extends EventEmitter {
export default class WorkerPool<T, R> extends EventEmitter {
workers: Worker[];
freeWorkers: Worker[];
workerInfo = new Map<number, WorkerPoolTaskInfo>();
tasks: { task: Task; callback: Callback }[];
workerInfo = new Map<number, WorkerPoolTaskInfo<R>>();
tasks: { task: T; callback: Callback<R> }[];

constructor(
private config: AltJTalkConfig,
numThreads?: number,
private workerPath: URL,
numThreads: number,
) {
super();
this.workers = [];
this.freeWorkers = [];
this.tasks = [];

const threads = numThreads ?? 1;

for (let i = 0; i < threads; i++) this.addNewWorker();
for (let i = 0; i < numThreads; i++) this.addNewWorker();

// Any time the kWorkerFreedEvent is emitted, dispatch
// the next task pending in the queue, if any.
Expand All @@ -52,14 +47,10 @@ export default class SynthesizeWorkerPool extends EventEmitter {
});
}

addNewWorker() {
const worker = new Worker(new URL("worker-task.js", import.meta.url));
worker.postMessage({
type: "setup",
config: this.config,
} satisfies Task);

worker.on("message", (result: Result) => {
protected addNewWorker() {
const worker = new Worker(this.workerPath);
this.prepareWorker(worker);
worker.on("message", (result: R) => {
// In case of success: Call the callback that was passed to `runTask`,
// remove the `TaskInfo` associated with the Worker, and mark it as free
// again.
Expand Down Expand Up @@ -88,7 +79,10 @@ export default class SynthesizeWorkerPool extends EventEmitter {
this.emit(kWorkerFreedEvent);
}

private runTask(task: Task, callback: Callback) {
// eslint-disable-next-line @typescript-eslint/no-unused-vars
protected prepareWorker(worker: Worker) {}

protected runTask(task: T, callback: Callback<R>) {
const worker = this.freeWorkers.pop();
if (!worker) {
// No free threads, wait until a worker thread becomes free.
Expand All @@ -101,21 +95,7 @@ export default class SynthesizeWorkerPool extends EventEmitter {
worker.postMessage(task);
}

public async synthesize(
inputText: string,
option: SynthesisOption,
): Promise<Int16Array> {
const result = await util.promisify(this.runTask.bind(this))({
type: "task",
inputText,
option,
});
if (!result) throw new Error("Task returned error!");

return result?.data;
}

async close() {
public async close() {
for (const worker of this.workers) await worker.terminate();
}
}

0 comments on commit 2ecb7f1

Please sign in to comment.