From 8830ba7637b505315a5c39ba959ac8c9db51f0f5 Mon Sep 17 00:00:00 2001 From: Alexander Date: Wed, 20 Sep 2023 18:32:16 +0500 Subject: [PATCH] Improve Demo 3, some changes in Demo 2 --- package.json | 1 + src/app/app.module.ts | 2 + .../loading-screen.component.css | 26 ++++ .../loading-screen.component.html | 11 ++ .../loading-screen.component.ts | 18 +++ src/app/pages/demo2/demo2.component.html | 13 +- src/app/pages/demo2/demo2.component.ts | 3 + src/app/pages/demo3/demo3.component.html | 22 +-- src/app/pages/demo3/demo3.component.ts | 123 ++++++++------- src/app/workers/demo3/demo3.worker.ts | 140 ++++++++++++++++++ src/styles.css | 2 +- 11 files changed, 269 insertions(+), 92 deletions(-) create mode 100644 src/app/components/loading-screen/loading-screen.component.css create mode 100644 src/app/components/loading-screen/loading-screen.component.html create mode 100644 src/app/components/loading-screen/loading-screen.component.ts create mode 100644 src/app/workers/demo3/demo3.worker.ts diff --git a/package.json b/package.json index a892521..7f159a4 100644 --- a/package.json +++ b/package.json @@ -26,6 +26,7 @@ "pngjs": "^6.0.0", "rxjs": "~6.6.0", "text-graph.js": "^1.0.3", + "threads": "^1.7.0", "tslib": "^2.6.1", "zone.js": "~0.13.1" }, diff --git a/src/app/app.module.ts b/src/app/app.module.ts index 865a0a4..67cafc6 100644 --- a/src/app/app.module.ts +++ b/src/app/app.module.ts @@ -14,6 +14,7 @@ import {PlotDrawerComponent} from './components/plot-drawer/plot-drawer.componen import {NeuralNetworkDrawerComponent} from './components/neural-network-drawer/neural-network-drawer.component'; import {BinaryImageDrawerComponent} from './components/binary-image-drawer/binary-image-drawer.component'; import {ColorSelectorComponent} from './components/color-selector/color-selector.component'; +import { LoadingScreenComponent } from './components/loading-screen/loading-screen.component'; const routes: Routes = [ {path: 'demo1', component: Demo1Component}, @@ -33,6 +34,7 @@ const routes: Routes = [ Demo2Component, Demo3Component, ColorSelectorComponent, + LoadingScreenComponent, ], imports: [ BrowserModule, diff --git a/src/app/components/loading-screen/loading-screen.component.css b/src/app/components/loading-screen/loading-screen.component.css new file mode 100644 index 0000000..cfc7295 --- /dev/null +++ b/src/app/components/loading-screen/loading-screen.component.css @@ -0,0 +1,26 @@ +.loading-screen { + display: none; + position: fixed; + left: 0; + top: 0; + bottom: 0; + right: 0; + background: rgba(0, 0, 0, 0.4); + z-index: 100000; +} + +.loading-screen.visible { + display: block; +} + +.loading-screen .loading-screen-progress { + position: absolute; + top: 50%; + left: 50%; + width: 33.3%; + transform: translate(-50%, -50%) !important; +} + +.loading-screen .progress { + height: 30px; +} \ No newline at end of file diff --git a/src/app/components/loading-screen/loading-screen.component.html b/src/app/components/loading-screen/loading-screen.component.html new file mode 100644 index 0000000..34eae11 --- /dev/null +++ b/src/app/components/loading-screen/loading-screen.component.html @@ -0,0 +1,11 @@ +
+
+

{{label}}

+
+
+ {{converter(current)}}/{{converter(total)}} +
+
+
+
\ No newline at end of file diff --git a/src/app/components/loading-screen/loading-screen.component.ts b/src/app/components/loading-screen/loading-screen.component.ts new file mode 100644 index 0000000..5d2f66b --- /dev/null +++ b/src/app/components/loading-screen/loading-screen.component.ts @@ -0,0 +1,18 @@ +import {Component, Input} from '@angular/core'; + +export type LoadingConverterFn = (value: number) => string; + +@Component({ + selector: 'loading-screen', + templateUrl: './loading-screen.component.html', + styleUrls: ['./loading-screen.component.css'] +}) +export class LoadingScreenComponent { + @Input() label = "Loading..."; + @Input() visible: boolean = false; + + @Input() current: number = 0 + @Input() total: number = 0; + + @Input() converter: LoadingConverterFn = (value) => value.toString(); +} diff --git a/src/app/pages/demo2/demo2.component.html b/src/app/pages/demo2/demo2.component.html index c830d56..9099ab9 100644 --- a/src/app/pages/demo2/demo2.component.html +++ b/src/app/pages/demo2/demo2.component.html @@ -75,14 +75,5 @@

Generated

-
-
-

Loading training data...

-
-
- {{fileProcessingCurrent}}/{{fileProcessingTotal}} -
-
-
-
\ No newline at end of file + + \ No newline at end of file diff --git a/src/app/pages/demo2/demo2.component.ts b/src/app/pages/demo2/demo2.component.ts index 0c712ab..81a2d8d 100644 --- a/src/app/pages/demo2/demo2.component.ts +++ b/src/app/pages/demo2/demo2.component.ts @@ -9,6 +9,7 @@ import * as image from "../../utils/image"; import * as matrix from "../../neural-network/engine/matrix"; import {HttpClient} from "@angular/common/http"; import {AsyncSubject} from "rxjs"; +import {ProgressUtils} from "../../neural-network/neural-network"; @Component({ selector: 'app-demo2', @@ -126,6 +127,7 @@ export class Demo2Component { this.nnWorker.postMessage({type: "set_data", data: result}); } finally { this.fileLoading = false; + this.fileProcessingTotal = 0; } } @@ -173,4 +175,5 @@ export class Demo2Component { this.nnWorker.postMessage({type: "load_dump", dump}); } + protected readonly ProgressUtils = ProgressUtils; } \ No newline at end of file diff --git a/src/app/pages/demo3/demo3.component.html b/src/app/pages/demo3/demo3.component.html index bf059a0..5ecfca1 100644 --- a/src/app/pages/demo3/demo3.component.html +++ b/src/app/pages/demo3/demo3.component.html @@ -1,6 +1,6 @@
-
+

Source image

Source image
-

Generated Image

+

+ Generated Image + +

{{modelDetails}}

@@ -28,16 +31,5 @@

Generated Image

- -
-
-

Loading...

-
-
- {{progress.converter(progress.loaded)}}/{{progress.converter(progress.total)}} -
-
-
-
\ No newline at end of file + \ No newline at end of file diff --git a/src/app/pages/demo3/demo3.component.ts b/src/app/pages/demo3/demo3.component.ts index 7e736c6..9384c00 100644 --- a/src/app/pages/demo3/demo3.component.ts +++ b/src/app/pages/demo3/demo3.component.ts @@ -1,17 +1,15 @@ import {AfterViewInit, Component, ElementRef, ViewChild} from '@angular/core'; -import * as FileInteraction from "../../utils/file-interaction"; -import {IModel} from "../../neural-network/engine/base"; -import {FileAsyncReader, ObservableStreamLoader} from "../../neural-network/utils/fetch"; +import {spawn} from "threads" + import { - ChainModel, - CommonUtils, - ImageUtils, - ProgressUtils, - UniversalModelSerializer + FileAsyncReader, + ObservableStreamLoader, + CommonUtils, ProgressUtils, } from "../../neural-network/neural-network"; + +import {ModelParams, WorkerT} from "../../workers/demo3/demo3.worker"; import {BinaryImageDrawerComponent} from "../../components/binary-image-drawer/binary-image-drawer.component"; -import * as ColorUtils from "../../neural-network/utils/color"; -import {BinarySerializer} from "../../neural-network/serialization/binary"; +import * as FileInteraction from "../../utils/file-interaction"; @Component({ selector: 'app-demo3', @@ -22,29 +20,32 @@ export class Demo3Component implements AfterViewInit { @ViewChild('drawingCanvas', {static: true}) drawingCanvasRef!: ElementRef; - @ViewChild("generatedImage") generatedImage!: BinaryImageDrawerComponent; private drawingContext!: CanvasRenderingContext2D; - private model?: IModel; - private isDrawing = false; + private modelParams?: ModelParams; + + public isDrawing = false; + public isRendering = false; + + private renderingRequested = false; + + worker!: WorkerT; brushColor!: string; brushSize!: number modelDetails: string = "Load model to view details"; progress = { - offset: 0, loaded: 0, total: 0, - progressFn: ProgressUtils.throttle( - (loaded: number, _: number) => { - this.progress.loaded = this.progress.offset + loaded; - }, ProgressUtils.ValueLimit.inclusive, 300 - ), + progressFn: (loaded: number, total: number) => { + this.progress.loaded = loaded; + this.progress.total = total; + }, + reset: () => { - this.progress.offset = 0 this.progress.loaded = 0; this.progress.total = 0; }, @@ -63,7 +64,8 @@ export class Demo3Component implements AfterViewInit { } draw(event: MouseEvent) { - if (!this.isDrawing || event.button !== 0) return; + if (!this.isDrawing) return; + if (event.buttons !== 1) return this.stopDrawing(); const canvas = this.drawingCanvasRef.nativeElement; const rect = canvas.getBoundingClientRect(); @@ -79,6 +81,8 @@ export class Demo3Component implements AfterViewInit { } stopDrawing() { + if (!this.isDrawing) return; + this.isDrawing = false; this.updateModelPrediction(); } @@ -106,44 +110,23 @@ export class Demo3Component implements AfterViewInit { } async loadModel() { + if (!this.worker) { + this.worker = await spawn( + new Worker(new URL('../../workers/demo3/demo3.worker', import.meta.url)) + ); + } + + const progressSub = this.worker.progress() + .subscribe((data: any) => this.progress.progressFn(data.current, data.total)); + try { const files = await FileInteraction.openFile('application/json,.bin', true) as File[]; if (!files?.length) return; - this.progress.total = files.reduce((p, c) => p + c.size, 0); - - const chain = new ChainModel(); - for (const file of files) { - const reader = new FileAsyncReader(file); - const loader = new ObservableStreamLoader(reader, this.progress.progressFn); - - const data = await loader.load(); + const res = await this.worker.loadModel(files); + this.modelParams = res; + this.modelDetails = res.description; - if (file.name.endsWith(".json")) { - const config = JSON.parse(new TextDecoder().decode(data)); - const model = UniversalModelSerializer.load(config, true); - chain.addModel(model); - } else { - const model = BinarySerializer.load(data) - chain.addModel(model); - } - - this.progress.offset += file.size; - } - - this.model = chain; - this.model.compile(); - - const sizes = this.model.layers.map(l => { - const size = Math.sqrt(l.size); - if (Number.isInteger(size)) { - return `${size}²`; - } - - return l.size.toString(); - }); - - this.modelDetails = `${this.model.constructor.name} (${sizes.join(" -> ")})`; this.updateModelPrediction(); } catch (err) { if (err instanceof Error) { @@ -153,25 +136,34 @@ export class Demo3Component implements AfterViewInit { } } finally { this.progress.reset(); + progressSub.unsubscribe(); } } updateModelPrediction() { - if (!this.model) return; + if (!this.modelParams) return; - const size = Math.sqrt(this.model.inputSize); - const data = this.getImageDataFromCanvas(this.drawingCanvasRef.nativeElement, size, size); + if (!this.isRendering) { + this.isRendering = true; + this._updateModelPredictionImpl().finally(() => { + this.isRendering = false; - const input = ColorUtils.transformChannelCount(Array.from(data), 4, 3); - ColorUtils.transformColorSpace(ColorUtils.rgbToTanh, input, 3, input); + if (this.renderingRequested) { + this.renderingRequested = false; + this.updateModelPrediction(); + } + }); + } else if (!this.renderingRequested) { + this.renderingRequested = true; + } + } - const output3 = ImageUtils.processMultiChannelData(this.model, input, 3); - ColorUtils.transformColorSpace(ColorUtils.tanhToRgb, output3, 3, output3); + async _updateModelPredictionImpl() { + const size = Math.sqrt(this.modelParams!.inputSize); + const data = this.getImageDataFromCanvas(this.drawingCanvasRef.nativeElement, size, size); - const output4 = ColorUtils.transformChannelCount(output3, 3, 4); - const outData = new Uint8ClampedArray(output4); - const outSize = Math.sqrt(this.model.outputSize); - this.generatedImage.draw(outData.buffer, outSize, outSize); + const result = await this.worker.compute(data); + this.generatedImage.draw(result.buffer, result.size, result.size); } getImageDataFromCanvas(canvas: HTMLCanvasElement, targetWidth: number, targetHeight: number) { @@ -203,4 +195,5 @@ export class Demo3Component implements AfterViewInit { image.src = url; }); } + protected readonly ProgressUtils = ProgressUtils; } \ No newline at end of file diff --git a/src/app/workers/demo3/demo3.worker.ts b/src/app/workers/demo3/demo3.worker.ts new file mode 100644 index 0000000..75c716c --- /dev/null +++ b/src/app/workers/demo3/demo3.worker.ts @@ -0,0 +1,140 @@ +import {expose} from "threads/worker"; +import {Observable, Subject} from "threads/observable"; + +import {IModel} from "../../neural-network/engine/base"; +import {ProgressFn} from "../../neural-network/utils/fetch"; +import { + ChainModel, + UniversalModelSerializer, + BinarySerializer, + FileAsyncReader, + ObservableStreamLoader, + ImageUtils, + ColorUtils, + ProgressUtils +} from "../../neural-network/neural-network"; + +class Progress { + private _current = 0; + private _total = 0; + + private readonly _throttledSendProgress: ProgressFn; + + readonly subject = new Subject(); + offset = 0; + + get current() {return this._current;} + set current(value: number) { + this._current = value; + this.refresh(); + } + + get total() {return this._total;} + set total(value: number) { + this._total = value; + this.refresh(); + } + + constructor(delay = 100) { + this._throttledSendProgress = ProgressUtils.throttle( + (current, total) => this.subject.next({current, total}), + ProgressUtils.ValueLimit.inclusive, + delay + ); + } + + reset() { + this.offset = 0; + this._current = 0; + this._total = 0; + + this.refresh(); + } + + refresh() { + this._throttledSendProgress(this.offset + this.current, this.total); + } + + progressFn(current: number, _: number) { + this.current = current; + } +} + +const progress = new Progress(); +let model: IModel; + +export type ModelParams = { + inputSize: number + outputSize: number + description: string +} + +const WorkerImpl = { + async loadModel(files: File[]): Promise { + progress.total = files.reduce((p, c) => p + c.size, 0); + + try { + const chain = new ChainModel(); + for (const file of files) { + const reader = new FileAsyncReader(file); + const loader = new ObservableStreamLoader(reader, progress.progressFn.bind(progress)); + + const data = await loader.load(); + if (file.name.endsWith(".json")) { + const config = JSON.parse(new TextDecoder().decode(data)); + const model = UniversalModelSerializer.load(config, true); + chain.addModel(model); + } else { + const model = BinarySerializer.load(data) + chain.addModel(model); + } + + progress.offset += file.size; + } + + model = chain; + model.compile(); + + const sizes = model.layers.map(l => { + const size = Math.sqrt(l.size); + if (Number.isInteger(size)) { + return `${size}²`; + } + + return l.size.toString(); + }); + + return { + inputSize: model.inputSize, + outputSize: model.outputSize, + description: `${sizes.join(" -> ")}` + }; + } finally { + progress.reset(); + } + }, + + async compute(data: Uint8ClampedArray) { + if (!model) throw new Error("Model is not loaded"); + + const input = ColorUtils.transformChannelCount(Array.from(data), 4, 3); + ColorUtils.transformColorSpace(ColorUtils.rgbToTanh, input, 3, input); + + const output3 = ImageUtils.processMultiChannelData(model, input, 3); + ColorUtils.transformColorSpace(ColorUtils.tanhToRgb, output3, 3, output3); + + const output4 = ColorUtils.transformChannelCount(output3, 3, 4); + const outData = new Uint8ClampedArray(output4); + const outSize = Math.sqrt(model.outputSize); + + return {buffer: outData.buffer, size: outSize}; + }, + + progress() { + return Observable.from(progress.subject); + } +} + +expose(WorkerImpl); + +export type WorkerT = typeof WorkerImpl; \ No newline at end of file diff --git a/src/styles.css b/src/styles.css index 88190d3..658825e 100644 --- a/src/styles.css +++ b/src/styles.css @@ -8,4 +8,4 @@ .pointer:disabled { cursor: not-allowed; -} +} \ No newline at end of file