-
Notifications
You must be signed in to change notification settings - Fork 1
/
demo1.component.ts
122 lines (97 loc) · 3.64 KB
/
demo1.component.ts
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
import {Component, ViewChild} from '@angular/core';
import {DEFAULT_LEARNING_RATE, DEFAULT_NN_LAYERS, Point} from "../../workers/demo1/nn.worker.consts";
import * as fileInteraction from '../../utils/file-interaction';
import {PlotDrawerComponent} from "../../components/plot-drawer/plot-drawer.component";
import {NeuralNetworkDrawerComponent} from "../../components/neural-network-drawer/neural-network-drawer.component";
/**
 * Page component for the "demo1" neural-network playground.
 *
 * The user places two kinds of points on a plot. A dedicated Web Worker
 * (`workers/demo1/nn.worker`) owns the network; this component only forwards
 * user actions to it (`add_point`, `set_points`, `refresh` messages) and renders
 * the `training_data` snapshots the worker posts back.
 */
@Component({
  selector: 'demo1-page',
  templateUrl: './demo1.component.html',
  styleUrls: ['./demo1.component.css']
})
export class Demo1Component {
  /** Plot canvas: shows the user's points and the worker's state snapshots. */
  @ViewChild('plotDrawer')
  plotDrawer!: PlotDrawerComponent;

  /** Network visualisation, fed from the worker's `nnSnapshot`. */
  @ViewChild('nnDrawer')
  nnDrawer!: NeuralNetworkDrawerComponent;

  /** Worker holding the network; one per component instance, terminated in ngOnDestroy. */
  private readonly nnWorker: Worker;

  /** Point type used for a plain left click; onKeyEvent sets it to 1 while Alt is held. */
  defaultPointType: number = 0;
  /** Layer sizes as space-separated text; parsed and sent to the worker by refresh(). */
  layersConfig: string = DEFAULT_NN_LAYERS.join(" ");
  learningRate: number = DEFAULT_LEARNING_RATE;
  // Latest training statistics reported by the worker via `training_data` messages.
  currentEpoch: number = 0;
  currentLoss: number = 1;
  currentAccuracy: number = 0;
  training: boolean = false;
  /** Points currently on the plot; every change is also posted to the worker. */
  points: any[] = [];

  constructor() {
    // Keep the `new Worker(new URL(..., import.meta.url))` form literal: the build
    // tooling detects and bundles worker scripts by statically matching this pattern.
    this.nnWorker = new Worker(new URL('../../workers/demo1/nn.worker', import.meta.url));
    this.nnWorker.onmessage = ({data}) => {
      switch (data.type) {
        case "training_data":
          this.currentEpoch = data.epoch;
          this.currentLoss = data.loss;
          this.currentAccuracy = data.accuracy;
          this.training = data.isTraining;
          // @ViewChild references are only populated once the view is initialised;
          // a snapshot arriving before that must not throw.
          this.plotDrawer?.drawSnapshot(data.state, data.width, data.height);
          this.nnDrawer?.drawSnapshot(data.nnSnapshot);
          break;
      }
    };
  }

  /**
   * Angular lifecycle hook: stops the worker when the page is destroyed.
   * Without this, each visit to the page would leave another worker alive.
   */
  ngOnDestroy(): void {
    this.nnWorker.onmessage = null;
    this.nnWorker.terminate();
  }

  /**
   * Adds a point at the mouse position and sends it to the worker.
   *
   * Coordinates are the offset from the target's bounding rect divided by the
   * target's client width/height. A non-primary button or Alt+click yields a
   * type-1 point; otherwise `defaultPointType` is used.
   */
  handleMouseEvent($event: MouseEvent) {
    const element = $event.target as Element;
    const rect = element.getBoundingClientRect();
    const x = $event.clientX - rect.left;
    const y = $event.clientY - rect.top;
    const pointType = +($event.button > 0 || $event.altKey) || this.defaultPointType;
    const point = {
      x: x / element.clientWidth,
      y: y / element.clientHeight,
      type: pointType
    };
    this.plotDrawer.addPoint(point);
    this.points.push(point);
    this.nnWorker.postMessage({type: 'add_point', point: point});
  }

  /**
   * Sends the current learning rate and layer configuration to the worker.
   *
   * `layersConfig` is split on any whitespace; tokens that are not positive
   * base-10 integers are ignored.
   */
  refresh() {
    const newLayersConfig = this.layersConfig
      .trim()
      .split(/\s+/)
      .map(v => Number.parseInt(v, 10))
      .filter(v => Number.isInteger(v) && v > 0);
    this.nnWorker.postMessage({
      type: "refresh", config: {
        learningRate: this.learningRate,
        layers: newLayersConfig
      }
    });
  }

  /** Downloads the current points as `points.json`. */
  savePoints() {
    fileInteraction.saveFile(JSON.stringify(this.points), 'points.json', 'application/json');
  }

  /**
   * Lets the user pick a JSON file and replaces the current points with its content.
   *
   * Unreadable/invalid JSON, or anything other than an array of `{x, y, type}`
   * numeric objects, is reported via `console.error` and leaves the current
   * points unchanged (an exception here would otherwise only surface as an
   * unhandled promise rejection when nobody awaits this method).
   */
  async loadPoints() {
    const file = await fileInteraction.openFile('application/json', false) as File;
    if (!file) {
      return;
    }
    let parsed: unknown;
    try {
      parsed = JSON.parse(await file.text());
    } catch (e: unknown) {
      console.error('Failed to read points file:', e);
      return;
    }
    if (!Array.isArray(parsed) || !parsed.every(p => Demo1Component.isPointLike(p))) {
      console.error('Points file must contain a JSON array of {x, y, type} objects.');
      return;
    }
    this.setPoints(parsed as Point[]);
  }

  /** Clears all points from the plot and the worker. */
  removePoints() {
    this.setPoints([]);
  }

  /**
   * Replaces the point set on the plot and in the worker, and resets the
   * displayed training statistics to their initial values.
   */
  setPoints(points: Point[]) {
    // Copy so later pushes from handleMouseEvent don't mutate the caller's array.
    this.points = [...points];
    this.plotDrawer.clearPoints();
    for (const point of this.points) {
      this.plotDrawer.addPoint(point);
    }
    this.training = false;
    this.currentEpoch = 0;
    this.currentLoss = 1;
    this.currentAccuracy = 0;
    this.nnWorker.postMessage({type: "set_points", points: this.points});
  }

  /** Tracks the Alt key: while held, plain clicks create type-1 points. */
  onKeyEvent($event: KeyboardEvent) {
    if ($event.key !== "Alt") {
      return;
    }
    if ($event.type === "keydown") {
      this.defaultPointType = 1;
    } else if ($event.type === "keyup") {
      this.defaultPointType = 0;
    }
  }

  /** True when `value` has numeric `x`, `y` and `type` fields, as points created by this page do. */
  private static isPointLike(value: unknown): boolean {
    if (typeof value !== 'object' || value === null) {
      return false;
    }
    const candidate = value as Record<string, unknown>;
    return typeof candidate['x'] === 'number'
      && typeof candidate['y'] === 'number'
      && typeof candidate['type'] === 'number';
  }
}