-
Notifications
You must be signed in to change notification settings - Fork 1
/
train.js
77 lines (66 loc) · 2.08 KB
/
train.js
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
const { BayesClassifier } = require('natural');
const { readJson } = require('fs-extra');
const { join } = require('path');
const { hashtags, dataPath } = require('./local-constants');
/**
 * Script entry point: loads the tweets for every configured hashtag, trains a
 * classifier on them, classifies the same tweets again and prints the hit
 * counts together with accuracy and error percentages.
 *
 * Note that the classifier is evaluated on the very data it was trained on,
 * so the printed figures are a training-set (resubstitution) estimate.
 *
 * @returns {Promise<void>} Resolves once the report has been printed.
 */
async function main() {
  const data = await loadData(...hashtags);
  const classifier = train(data);
  const results = predict(classifier, data);
  console.log(results);

  const total = results.positives + results.negatives;
  if (total === 0) {
    // Nothing was classified: dividing by zero would only print "NaN%".
    console.log('No samples were classified; accuracy cannot be computed.');
    return;
  }

  // Named *Ratio so the locals do not shadow the module-level `accuracy` function.
  const accuracyRatio = results.positives / total;
  const errorRatio = (total - results.positives) / total;

  // Template literals keep the two fixed decimals; the `%f` console specifier
  // re-parses its argument with parseFloat and would print "90" for "90.00".
  console.log(`Accuracy: ${(accuracyRatio * 100).toFixed(2)}%`);
  console.log(`Error: ${(errorRatio * 100).toFixed(2)}%`);
}
/**
 * Reads `<dataPath>/<tag>.json` for each given hashtag, one file after the
 * other, and returns all entries in a single array. Every entry is tagged
 * in place with a `label` property holding the hashtag whose file it came from.
 *
 * @param {...string} hashtags - Hashtag names; each maps to one JSON file.
 * @returns {Promise<Array<object>>} Labelled entries, in file order then entry order.
 */
async function loadData(...hashtags) {
  const chunks = [];
  for (const tag of hashtags) {
    const file = join(dataPath, `${tag}.json`);
    const entries = await readJson(file);
    const labelled = entries.map((entry) => {
      entry['label'] = tag;
      return entry;
    });
    chunks.push(labelled);
  }
  // One concat over the per-tag chunks instead of re-copying the result per tag.
  return [].concat(...chunks);
}
/**
 * Builds a new BayesClassifier, registers every sample's `text` under its
 * `label` via `addDocument`, then calls `train()` on it.
 *
 * @param {Iterable<{text: string, label: string}>} data - Labelled samples.
 * @returns {BayesClassifier} The classifier after `train()` has been invoked.
 */
function train(data) {
  const model = new BayesClassifier();
  for (const sample of data) {
    model.addDocument(sample.text, sample.label);
  }
  model.train();
  return model;
}
/**
 * Estimates classification quality by cross-validation: the data is cut into
 * consecutive folds of `ceil(count / 10)` samples; for each fold a fresh
 * classifier is trained on everything outside the fold and evaluated on the
 * fold itself. Hit and miss counts are summed over all folds.
 *
 * @param {*} classifier - Unused; a new classifier is trained for every fold.
 *   Kept so the signature stays compatible with existing callers.
 * @param {Array<{text: string, label: string}>} data - Labelled samples.
 * @returns {{positives: number, negatives: number}} Totals of correct
 *   (`positives`) and incorrect (`negatives`) predictions. Both are 0 for
 *   empty input.
 */
function accuracy(classifier, data) {
  const count = data.length;
  const sets = 10;
  const step = Math.ceil(count / sets);
  const results = {
    positives: 0,
    negatives: 0,
  };

  // `start < count` (not `<=`): with empty data `step` is 0 and the old bound
  // never terminated; with an exact multiple it ran one extra, empty fold.
  for (let start = 0; start < count; start += step) {
    const end = start + step;
    const trainingSet = data.slice(0, start).concat(data.slice(end));
    const testingSet = data.slice(start, end);
    const foldClassifier = train(trainingSet);
    const foldResults = predict(foldClassifier, testingSet);
    results.positives += foldResults.positives;
    results.negatives += foldResults.negatives;
  }
  return results;
}
/**
 * Classifies every sample's `text` with the given classifier and counts how
 * often the prediction matches the sample's `label`.
 *
 * @param {{classify: function(string): *}} classifier - Object exposing `classify(text)`.
 * @param {Iterable<{text: string, label: *}>} testingSet - Labelled samples to score.
 * @returns {{positives: number, negatives: number}} Number of matching
 *   (`positives`) and non-matching (`negatives`) predictions.
 */
function predict(classifier, testingSet) {
  let hits = 0;
  let misses = 0;
  for (const sample of testingSet) {
    const predicted = classifier.classify(sample.text);
    // Loose equality is kept on purpose: a label differing from the
    // prediction only by type (e.g. 1 vs '1') still counts as a match.
    const isMatch = predicted == sample.label;
    if (isMatch) {
      hits += 1;
    } else {
      misses += 1;
    }
  }
  return { positives: hits, negatives: misses };
}
// Run the script. The promise returned by main() is handled explicitly so a
// failure (e.g. a missing data file) is logged and reflected in the exit code
// instead of being left as a floating rejection.
main().catch((err) => {
  console.error(err);
  process.exitCode = 1;
});