forked from JulianMH/NounPhraseJS
-
Notifications
You must be signed in to change notification settings - Fork 0
/
getSentenceConfiguration.js
executable file
·110 lines (89 loc) · 5.44 KB
/
getSentenceConfiguration.js
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
"use strict";
define(["nounphrasejs"], function(nounphrasejs) {
// Fallback value for every option that getSentenceConfiguration reads.
// Any key the caller leaves out of `options` is copied from this table.
var DEFAULT_OPTIONS = {
    // --- network shape ---
    hidden_unit_count: 30,          // num_neurons of the fully connected layer
    filter_count: 30,               // conv filters; also the maxout group_size
    convolution_radius: 2,          // padding rows added on each side of a sentence
    max_sentence_width: 60,         // sentence rows before padding is added
    output_labels: 3,               // num_classes of the softmax layer

    // --- word vector (lookup table) updates ---
    lookup_table_learn_rate: 0.1,   // step size used when adjusting word vectors

    // --- values handed to convnetjs.Trainer ---
    trainer_learn_rate: 0.001,
    trainer_l1_decay: 0,
    trainer_l2_decay: 0.0001,
    trainer_momentum: 0.5,
    trainer_batch_size: 1
};
return function getSentenceConfiguration(options, dictionary) {
// Build a fresh network from `options`, or restore one from a serialized
// configuration, then wrap it in an SGD trainer.
//
// `options` may be:
//   - a JSON string with `options`, `dictionary` and `netJSON` fields, in which
//     case all three are restored from it and the `dictionary` argument is
//     replaced by the revived one;
//   - an options object (or null/undefined), whose missing keys are filled in
//     from DEFAULT_OPTIONS. The object is completed in place, so the caller
//     sees the defaults and the derived width fields afterwards.
var network = new convnetjs.Net();
if (typeof options === 'string') {
    var savedConfiguration = JSON.parse(options);
    options = savedConfiguration.options;
    dictionary = nounphrasejs.reviveDictionary(savedConfiguration.dictionary);
    network.fromJSON(savedConfiguration.netJSON);
} else {
    // Every option has a default, so omitting the argument is allowed.
    if (options === null || options === undefined)
        options = {};

    // Borrow hasOwnProperty from Object.prototype so the check also works for
    // option objects that do not inherit it or that shadow it.
    var hasOwn = Object.prototype.hasOwnProperty;
    for (var key in DEFAULT_OPTIONS) {
        if (hasOwn.call(DEFAULT_OPTIONS, key) && !hasOwn.call(options, key))
            options[key] = DEFAULT_OPTIONS[key];
    }

    // Derived sizes: the sentence is padded by `convolution_radius` rows on
    // both ends.
    options.width_with_padding = options.max_sentence_width + (options.convolution_radius * 2);
    // NOTE(review): a window of radius r centred on a word would normally span
    // 2 * r + 1 rows; the filter height used here is 2 * r. Left unchanged
    // because altering it changes the network architecture - confirm intent.
    options.convolution_width = (options.convolution_radius * 2);

    // Each input row is a word vector plus three extra per-word features.
    var inputWidth = dictionary.wordFeatureCount + 3;
    network.makeLayers([
        { type: 'input', out_sx: inputWidth, out_sy: options.width_with_padding, out_depth: 1 },
        { type: 'conv', sx: inputWidth, sy: options.convolution_width, filters: options.filter_count, stride: 1 },
        { type: 'fc', num_neurons: options.hidden_unit_count, group_size: options.filter_count, activation: 'maxout' },
        //{ type: 'fc', num_neurons: options.hidden_unit_count },
        //{ type: 'fc', num_neurons: options.hidden_unit_count, activation: 'tanh' },
        //{ type: 'fc', num_neurons: options.hidden_unit_count }),
        { type: 'softmax', num_classes: options.output_labels }
    ]);
}

// Plain SGD; all hyper-parameters come from the trainer_* options.
var trainer = new convnetjs.Trainer(network, {
    method: 'sgd',
    learning_rate: options.trainer_learn_rate,
    l1_decay: options.trainer_l1_decay,
    l2_decay: options.trainer_l2_decay,
    momentum: options.trainer_momentum,
    batch_size: options.trainer_batch_size
});
/**
 * Builds the network input used to label one word of one sentence.
 *
 * The result is a convnetjs.Vol with `options.width_with_padding` rows (one
 * per padded sentence position) and `dictionary.wordFeatureCount + 3` columns:
 *   - columns 0 .. wordFeatureCount - 1: the word's vector, copied from
 *     `textCorpus.dictionary.wordIndicesToVols`; rows outside the sentence use
 *     `textCorpus.dictionary.paddingVol` instead;
 *   - column wordFeatureCount: 1 if the word is capitalised, otherwise 0;
 *   - column wordFeatureCount + 1: 0 on the row of the word being labelled,
 *     1 on every other row;
 *   - column wordFeatureCount + 2: signed offset from the word being labelled,
 *     divided by the sentence length and clamped to [-1, 1].
 *
 * NOTE(review): the column count comes from the configuration's `dictionary`
 * while the vectors are read from `textCorpus.dictionary`; presumably both are
 * the same dictionary - confirm against callers.
 *
 * @param {Object} textCorpus - Provides `sentences` (arrays of words carrying
 *     `index` and `isCapitalised`) and `dictionary`.
 * @param {number} sentenceIndex - Index into `textCorpus.sentences`.
 * @param {number} wordIndex - Position, within the sentence, of the word to label.
 * @returns {Object} The filled convnetjs.Vol.
 */
function getSentenceWithLabelForWord(textCorpus, sentenceIndex, wordIndex) {
    var sentence = textCorpus.sentences[sentenceIndex];
    var corpusDictionary = textCorpus.dictionary;
    var featureCount = dictionary.wordFeatureCount;
    var rowCount = options.width_with_padding;

    // An empty sentence would make the distance below 0 / 0 on the labelled
    // row and feed NaN into the network; dividing by at least 1 avoids that
    // and leaves every non-empty sentence unchanged.
    var distanceScale = Math.max(sentence.length, 1);

    var sentenceFeatures = new convnetjs.Vol(featureCount + 3, rowCount, 1);
    for (var i = 0; i < rowCount; ++i) {
        // Row i corresponds to this word position; negative or past-the-end
        // positions are padding rows.
        var currentWordPosition = i - options.convolution_radius;

        var rawDistance = (currentWordPosition - wordIndex) / distanceScale;
        var distanceToLabeledWord = Math.max(-1, Math.min(1, rawDistance));

        var isCapitalised = false;
        var currentWordVol = corpusDictionary.paddingVol;
        if (currentWordPosition >= 0 && currentWordPosition < sentence.length) {
            var currentWord = sentence[currentWordPosition];
            isCapitalised = currentWord.isCapitalised;
            currentWordVol = corpusDictionary.wordIndicesToVols[currentWord.index];
        }

        for (var j = 0; j < featureCount; ++j)
            sentenceFeatures.set(j, i, 0, currentWordVol.get(j, 0, 0));

        sentenceFeatures.set(featureCount, i, 0, isCapitalised ? 1 : 0);
        // Loose equality is kept on purpose so a numeric-string wordIndex
        // still matches, as it always has.
        sentenceFeatures.set(featureCount + 1, i, 0, currentWordPosition == wordIndex ? 0 : 1);
        sentenceFeatures.set(featureCount + 2, i, 0, distanceToLabeledWord);
    }
    return sentenceFeatures;
}
/**
 * Applies one gradient step to the word vectors of a sentence, using the
 * input gradients stored on `sentenceWithLabel`.
 *
 * For every row that falls inside the sentence, the first
 * `dictionary.wordFeatureCount` components of that word's vector in
 * `textCorpus.dictionary.wordIndicesToVols` are moved by
 * `-options.lookup_table_learn_rate * gradient`. Padding rows are skipped, so
 * the padding vector is never changed.
 *
 * Each step starts from the vector's current value rather than from the copy
 * held in `sentenceWithLabel`. When a word occurs more than once in the
 * sentence its occurrences share one vector; stepping from the copy made every
 * later occurrence overwrite the earlier update, so only the last gradient
 * survived. Stepping from the current value accumulates all of them.
 *
 * @param {Object} textCorpus - Provides `sentences` and `dictionary`.
 * @param {number} sentenceIndex - Index into `textCorpus.sentences`.
 * @param {number} wordIndex - Not used here; kept so the signature stays the same.
 * @param {Object} sentenceWithLabel - Network input Vol whose gradients
 *     (`get_grad`) are read.
 */
function learnFromSentenceWithLabelForWordGradients(textCorpus, sentenceIndex, wordIndex, sentenceWithLabel) {
    var sentence = textCorpus.sentences[sentenceIndex];
    var wordVols = textCorpus.dictionary.wordIndicesToVols;
    var featureCount = dictionary.wordFeatureCount;
    var learnRate = options.lookup_table_learn_rate;

    for (var i = 0; i < options.width_with_padding; ++i) {
        var currentWordPosition = i - options.convolution_radius;
        // Rows before the first or after the last word are padding.
        if (currentWordPosition < 0 || currentWordPosition >= sentence.length)
            continue;

        var volToChange = wordVols[sentence[currentWordPosition].index];
        for (var j = 0; j < featureCount; ++j) {
            volToChange.set(j, 0, 0, volToChange.get(j, 0, 0) - (learnRate * sentenceWithLabel.get_grad(j, i, 0)));
        }
    }
}
// Bundle the network, its trainer, the feature-building and word-vector-update
// functions, the dictionary and the completed options into a NetworkConfiguration.
return new nounphrasejs.NetworkConfiguration(network, trainer, getSentenceWithLabelForWord, learnFromSentenceWithLabelForWordGradients, dictionary, options);
};
});