Skip to content

Commit

Permalink
fix for ts-node
Browse files Browse the repository at this point in the history
  • Loading branch information
phenylshima committed Oct 21, 2023
1 parent 38a1845 commit 2643164
Show file tree
Hide file tree
Showing 4 changed files with 55 additions and 69 deletions.
22 changes: 22 additions & 0 deletions src/synthesis/common.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
import { AltJTalkConfig, SynthesisOption } from "node-altjtalk-binding";

export type Task = {
inputText: string;
option: SynthesisOption;
};

export type Result = {
type: "task";
data: Int16Array;
};

export function isAltJTalkConfigValid(arg: unknown): arg is AltJTalkConfig {
return (
typeof arg === "object" &&
arg !== null &&
"dictionary" in arg &&
"model" in arg &&
typeof arg.dictionary === "string" &&
typeof arg.model === "string"
);
}
27 changes: 6 additions & 21 deletions src/synthesis/index.ts
Original file line number Diff line number Diff line change
@@ -1,45 +1,30 @@
import { Readable } from "node:stream";
import * as util from "node:util";
import { Worker } from "node:worker_threads";
import {
AudioResource,
StreamType,
createAudioResource,
} from "@discordjs/voice";
import { Message } from "discord.js";
import { AltJTalkConfig, SynthesisOption } from "node-altjtalk-binding";
import { Result, Task } from "./task";
import { Result, Task } from "./common";
import WorkerPool from "./worker-pool";

class SynthesizeWorkerPool extends WorkerPool<Task, Result> {
constructor(
private config: AltJTalkConfig,
numThreads?: number,
) {
super(new URL("task.js", import.meta.url), numThreads ?? 1);
}

protected override prepareWorker(worker: Worker): void {
worker.postMessage({
type: "setup",
config: this.config,
} satisfies Task);
class SynthesizeWorkerPool extends WorkerPool<Task, Result, AltJTalkConfig> {
constructor(config: AltJTalkConfig, numThreads?: number) {
super(new URL("task", import.meta.url), config, numThreads ?? 1);
}

public async synthesize(
inputText: string,
option: SynthesisOption,
): Promise<Int16Array> {
const result = await util.promisify(this.runTask.bind(this))({
type: "task",
inputText,
option,
});
if (!result) throw new Error("Task returned error!");
if (result.type !== "task")
throw new Error("Task returned wrong type of response!");

return result?.data;
if (result) return result?.data;
else throw new Error("Task returned error!");
}
}

Expand Down
48 changes: 11 additions & 37 deletions src/synthesis/task.ts
Original file line number Diff line number Diff line change
@@ -1,45 +1,19 @@
import { parentPort } from "node:worker_threads";
import {
AltJTalk,
AltJTalkConfig,
SynthesisOption,
} from "node-altjtalk-binding";
import { parentPort, workerData } from "node:worker_threads";
import { AltJTalk } from "node-altjtalk-binding";
import { Result, Task, isAltJTalkConfigValid } from "./common";

export type Task =
| {
type: "setup";
config: AltJTalkConfig;
}
| {
type: "task";
inputText: string;
option: SynthesisOption;
};
if (!isAltJTalkConfigValid(workerData))
throw new Error("AltJTalk config is invalid.");

export type Result = {
type: "task";
data: Int16Array;
};

let synthesizer: AltJTalk | undefined = undefined;
const synthesizer: AltJTalk = AltJTalk.fromConfig(workerData);

if (parentPort) {
parentPort.on("message", (task: Task) => {
if (!parentPort) return;

switch (task.type) {
case "setup":
synthesizer = AltJTalk.fromConfig(task.config);
break;
case "task": {
if (!synthesizer) throw new Error("Synthesizer is not initialized!");
const data = synthesizer.synthesize(task.inputText, task.option);
parentPort.postMessage({
type: "task",
data,
} satisfies Result);
break;
}
}
const data = synthesizer.synthesize(task.inputText, task.option);
parentPort.postMessage({
type: "task",
data,
} satisfies Result);
});
}
27 changes: 16 additions & 11 deletions src/synthesis/worker-pool.ts
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,11 @@ import { AsyncResource } from "node:async_hooks";
import { EventEmitter } from "node:events";
import { Worker } from "node:worker_threads";

const isTsNode = () => {
const tsNodeSymbol = Symbol.for("ts-node.register.instance");
return tsNodeSymbol in process && !!process[tsNodeSymbol];
};

const kWorkerFreedEvent = Symbol("kWorkerFreedEvent");

type Callback<R> = (err: Error | null, result: R | null) => void;
Expand All @@ -21,14 +26,15 @@ class WorkerPoolTaskInfo<R> extends AsyncResource {
}
}

export default class WorkerPool<T, R> extends EventEmitter {
workers: Worker[];
freeWorkers: Worker[];
workerInfo = new Map<number, WorkerPoolTaskInfo<R>>();
tasks: { task: T; callback: Callback<R> }[];
export default class WorkerPool<T, R, W> extends EventEmitter {
private workers: Worker[];
private freeWorkers: Worker[];
private workerInfo = new Map<number, WorkerPoolTaskInfo<R>>();
private tasks: { task: T; callback: Callback<R> }[];

constructor(
private workerPath: URL,
protected workerPath: string | URL,
protected workerData: W,
numThreads: number,
) {
super();
Expand All @@ -50,8 +56,10 @@ export default class WorkerPool<T, R> extends EventEmitter {
}

protected addNewWorker() {
const worker = new Worker(this.workerPath);
this.prepareWorker(worker);
const worker = new Worker(this.workerPath, {
execArgv: isTsNode() ? ["--loader", "ts-node/esm"] : undefined,
workerData: this.workerData,
});
worker.on("message", (result: R) => {
// In case of success: Call the callback that was passed to `runTask`,
// remove the `TaskInfo` associated with the Worker, and mark it as free
Expand Down Expand Up @@ -81,9 +89,6 @@ export default class WorkerPool<T, R> extends EventEmitter {
this.emit(kWorkerFreedEvent);
}

// eslint-disable-next-line @typescript-eslint/no-unused-vars
protected prepareWorker(worker: Worker) {}

protected runTask(task: T, callback: Callback<R>) {
const worker = this.freeWorkers.pop();
if (!worker) {
Expand Down

0 comments on commit 2643164

Please sign in to comment.