forked from tensorflow/tfjs-examples
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathgen_node.js
126 lines (109 loc) · 3.98 KB
/
gen_node.js
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
/**
* @license
* Copyright 2019 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
* Use a trained next-character prediction model to generate some text.
*/
import * as fs from 'fs';
import * as os from 'os';
import * as path from 'path';
import * as argparse from 'argparse';
import * as tf from '@tensorflow/tfjs';
import {maybeDownload, TextData, TEXT_DATA_URLS} from './data';
import {generateText} from './model';
/**
 * Define and parse the command-line arguments for this script.
 *
 * The single positional argument is interpreted in one of two ways:
 * as the name of a built-in dataset (a key of `TEXT_DATA_URLS`), in
 * which case it is moved to `args.textDatasetName`, or as a path to an
 * existing text file, in which case it is moved to
 * `args.textDatasetPath`. If it is neither, the parser reports an
 * error and exits.
 *
 * @returns {Object} Parsed arguments, with exactly one of
 *   `textDatasetName` or `textDatasetPath` set.
 */
function parseArgs() {
  const cli = argparse.ArgumentParser({
    description: 'Train an lstm-text-generation model.'
  });
  cli.addArgument('textDatasetNameOrPath', {
    type: 'string',
    help: 'Name of the text dataset (one of ' +
    Object.keys(TEXT_DATA_URLS).join(', ') +
    ') or the path to a text file containing a custom dataset'
  });
  cli.addArgument('modelJSONPath', {
    type: 'string',
    help: 'Path to the trained next-char prediction model saved on disk ' +
    '(e.g., ./my-model/model.json)'
  });
  cli.addArgument('--genLength', {
    type: 'int',
    defaultValue: 200,
    help: 'Length of the text to generate.'
  });
  cli.addArgument('--temperature', {
    type: 'float',
    defaultValue: 0.5,
    help: 'Temperature value to use for text generation. Higher values ' +
    'lead to more random-looking generation results.'
  });
  cli.addArgument('--gpu', {
    action: 'storeTrue',
    help: 'Use CUDA GPU for training.'
  });
  cli.addArgument('--sampleStep', {
    type: 'int',
    defaultValue: 3,
    help: 'Step length: how many characters to skip between one example ' +
    'extracted from the text data to the next.'
  });
  const args = cli.parseArgs();

  // Classify the positional argument: known dataset name, or a real file.
  const datasetEntry = TEXT_DATA_URLS[args.textDatasetNameOrPath];
  const pointsToFile = fs.existsSync(args.textDatasetNameOrPath)
      && fs.statSync(args.textDatasetNameOrPath).isFile();
  if (datasetEntry) {
    args.textDatasetName = args.textDatasetNameOrPath;
    delete args.textDatasetNameOrPath;
    return args;
  }
  if (pointsToFile) {
    args.textDatasetPath = args.textDatasetNameOrPath;
    delete args.textDatasetNameOrPath;
    return args;
  }
  // Neither a dataset name nor an existing file: report and exit.
  cli.error('Argument should be one of ' +
      Object.keys(TEXT_DATA_URLS).join(', ') +
      ' or the path to a dataset text file');
  return args;
}
/**
 * Script entry point: load a trained next-character model and use it to
 * generate text.
 *
 * Steps, in order:
 *   1. Select the tfjs-node backend (GPU if `--gpu` was given).
 *   2. Load the Keras-style model from `modelJSONPath`; the model's input
 *      shape determines the sample (seed) length.
 *   3. Resolve the text corpus: download a named dataset to the OS temp
 *      directory, or use the user-supplied file path directly.
 *   4. Draw a random seed slice from the corpus and generate
 *      `--genLength` characters at the given `--temperature`.
 */
async function main() {
  const args = parseArgs();

  // Register the appropriate Node.js backend before any tensor work.
  if (args.gpu) {
    console.log('Using GPU');
    require('@tensorflow/tfjs-node-gpu');
  } else {
    console.log('Using CPU');
    require('@tensorflow/tfjs-node');
  }

  // Load the model. `tf.loadModel` exists in older tfjs releases;
  // newer ones expose `tf.loadLayersModel` instead.
  const modelLoader = tf.loadModel || tf.loadLayersModel;
  const model = await modelLoader(`file://${args.modelJSONPath}`);
  // The second input dimension is the sequence (seed) length.
  const sampleLen = model.inputs[0].shape[1];

  // Resolve the path to the text corpus, downloading a named dataset
  // into the temp directory if necessary.
  let dataFilePath = args.textDatasetPath;
  if (args.textDatasetName) {
    const textDataURL = TEXT_DATA_URLS[args.textDatasetName].url;
    dataFilePath = path.join(os.tmpdir(), path.basename(textDataURL));
    await maybeDownload(textDataURL, dataFilePath);
  }
  const corpus = fs.readFileSync(dataFilePath, {encoding: 'utf-8'});
  const textData =
      new TextData('text-data', corpus, sampleLen, args.sampleStep);

  // Pick a random seed slice from the corpus and generate from it.
  const [seed, seedIndices] = textData.getRandomSlice();
  console.log(`Seed text:\n"${seed}"\n`);
  const generated = await generateText(
      model, textData, seedIndices, args.genLength, args.temperature);
  console.log(`Generated text:\n"${generated}"\n`);
}
// Run the script; report any async failure explicitly rather than leaving
// the returned promise floating (an unhandled rejection), and signal
// failure to the shell via a nonzero exit code.
main().catch((err) => {
  console.error(err);
  process.exitCode = 1;
});