-
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathgenerate.js
83 lines (72 loc) · 3.24 KB
/
generate.js
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
// -------------
// -- imports --
// -------------
import { combineTopicEmbeddings } from "./modules/embedding.js";
import { labels } from "./labels-config.js";
import fs from 'fs';
console.log('\n\n\n\n');
// ------------------------------------------
// -- Clean the topic_embeddings directory --
// ------------------------------------------
// Remove any stale .json embeddings so this run starts from a clean slate.
const topicEmbeddingsDir = 'data/topic_embeddings';
// Create the directory on first run instead of crashing on ENOENT.
if (!fs.existsSync(topicEmbeddingsDir)) {
    fs.mkdirSync(topicEmbeddingsDir, { recursive: true });
}
fs.readdirSync(topicEmbeddingsDir)
    .filter(file => file.endsWith('.json'))
    .forEach(file => {
        fs.unlinkSync(`${topicEmbeddingsDir}/${file}`);
        console.log(`Deleted: ${file}`);
    });
console.log('\nCleaned topic_embeddings directory\n');
// ---------------------------------------------------------------------------------
// -- Load `data/training_data.jsonl` and get all the phrases for the label --
// ---------------------------------------------------------------------------------
// JSONL format: one JSON record per line (expected shape: { label, text, ... }).
// Blank lines (including the guaranteed trailing newline) are skipped up front;
// genuinely malformed lines are reported rather than silently dropped.
const allTrainPositives = fs.readFileSync('data/training_data.jsonl', 'utf8');
const allTrainPositivesArray = allTrainPositives.split('\n')
    .map(line => line.trim())
    .filter(line => line.length > 0) // skip blanks so only real records hit JSON.parse
    .map(line => {
        try {
            return JSON.parse(line);
        } catch (e) {
            console.warn(`Skipping malformed JSONL line: ${line.slice(0, 100)}...`);
            return null;
        }
    })
    .filter(item => item !== null); // drop records that failed to parse
// ---------------------------------------------------
// -- Generate the topic average weighted embedding --
// ---------------------------------------------------
/**
 * Builds the averaged embedding for one configured label and writes it to
 * `data/topic_embeddings/<label>.json`.
 *
 * @param {{ label: string, threshold: number }} label - one entry from labels-config
 * @returns {Promise<void>} resolves when the file is written, or immediately
 *   when the label has no training phrases; exits the process on failure.
 */
async function generateTopicEmbedding(label) {
    // No previously persisted embedding to merge with — always a fresh build.
    const existingEmbedding = null;
    const existingCount = 0;
    const topicName = label.label;
    const threshold = label.threshold;
    // Case-insensitive match against training records. `?.` guards records
    // with a missing `label` field — previously `item.label.toLowerCase()`
    // threw a TypeError here, outside the try/catch below.
    const newPhrases = allTrainPositivesArray
        .filter(item => item.label?.toLowerCase() === topicName.toLowerCase())
        .map(item => item.text);
    if (newPhrases.length === 0) {
        console.log(`No training data found for topic "${topicName}" - skipping embedding generation`);
        return;
    }
    try {
        const topicEmbedding = await combineTopicEmbeddings(existingEmbedding, existingCount, newPhrases);
        const dataObject = {
            topic: topicName,
            threshold: threshold,
            numPhrases: newPhrases.length,
            embeddingModel: process.env.ONNX_EMBEDDING_MODEL,
            modelPrecision: process.env.ONNX_EMBEDDING_MODEL_PRECISION,
            // Normalize typed-array / object-shaped embeddings to a plain array for JSON.
            embedding: Array.isArray(topicEmbedding) ? topicEmbedding : Object.values(topicEmbedding)
        };
        const dataString = JSON.stringify(dataObject, null, 2);
        fs.writeFileSync(`data/topic_embeddings/${topicName}.json`, dataString, { flag: 'w' });
        console.log(`Topic embedding for ${topicName} generated successfully`);
    } catch (error) {
        // Fatal: downstream classification depends on every configured topic existing.
        console.error(`Error generating topic embedding for "${topicName}":`, error);
        process.exit(1);
    }
}
// ------------------------------------------------------------------
// -- Loop through labels and generate average weighted embeddings --
// ------------------------------------------------------------------
// Await each generation (top-level await; this file is an ES module).
// The previous fire-and-forget calls left floating promises: the script
// could finish before any embedding was written, file writes could
// interleave, and rejections went unhandled. Sequential await also keeps
// each topic's log lines grouped together.
for (const label of labels) {
    await generateTopicEmbedding(label);
}