Skip to content

Commit

Permalink
Allow to select multiple models in Demo 3
Browse files Browse the repository at this point in the history
  • Loading branch information
DrA1ex committed Sep 15, 2023
1 parent 70559ba commit 5638d27
Show file tree
Hide file tree
Showing 2 changed files with 44 additions and 19 deletions.
1 change: 1 addition & 0 deletions src/app/pages/demo3/demo3.component.css
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
canvas,
binary-image-drawer {
display: block;
image-rendering: pixelated;
}
62 changes: 43 additions & 19 deletions src/app/pages/demo3/demo3.component.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,15 @@ import {AfterViewInit, Component, ElementRef, ViewChild} from '@angular/core';
import * as FileInteraction from "../../utils/file-interaction";
import {IModel} from "../../neural-network/engine/base";
import {FileAsyncReader, ObservableStreamLoader} from "../../neural-network/utils/fetch";
import {CommonUtils, ImageUtils, UniversalModelSerializer} from "../../neural-network/neural-network";
import {
ChainModel,
CommonUtils,
ImageUtils,
ProgressUtils,
UniversalModelSerializer
} from "../../neural-network/neural-network";
import {BinaryImageDrawerComponent} from "../../components/binary-image-drawer/binary-image-drawer.component";
import * as ColorUtils from "../../neural-network/utils/color";

@Component({
selector: 'app-demo3',
Expand All @@ -27,13 +34,16 @@ export class Demo3Component implements AfterViewInit {

modelDetails: string = "Load model to view details";
progress = {
offset: 0,
loaded: 0,
total: 0,
progressFn: (loaded: number, total: number) => {
this.progress.loaded = loaded;
this.progress.total = total;
},
progressFn: ProgressUtils.throttle(
(loaded: number, _: number) => {
this.progress.loaded = this.progress.offset + loaded;
}, ProgressUtils.ValueLimit.inclusive, 300
),
reset: () => {
this.progress.offset = 0
this.progress.loaded = 0;
this.progress.total = 0;
},
Expand All @@ -52,7 +62,8 @@ export class Demo3Component implements AfterViewInit {
}

draw(event: MouseEvent) {
if (!this.isDrawing) return;
if (!this.isDrawing || event.button !== 0) return;

const canvas = this.drawingCanvasRef.nativeElement;
const rect = canvas.getBoundingClientRect();

Expand Down Expand Up @@ -95,20 +106,31 @@ export class Demo3Component implements AfterViewInit {

async loadModel() {
try {
const file = await FileInteraction.openFile('application/json', false) as File;
if (!file) return;
const files = await FileInteraction.openFile('application/json', true) as File[];
if (!files?.length) return;

const reader = new FileAsyncReader(file);
const loader = new ObservableStreamLoader(reader, this.progress.progressFn);
this.progress.total = files.reduce((p, c) => p + c.size, 0);

const data = await loader.load();
const config = JSON.parse(new TextDecoder().decode(data));
this.model = UniversalModelSerializer.load(config);
const chain = new ChainModel();
for (const file of files) {
const reader = new FileAsyncReader(file);
const loader = new ObservableStreamLoader(reader, this.progress.progressFn);

const data = await loader.load();
const config = JSON.parse(new TextDecoder().decode(data));
const model = UniversalModelSerializer.load(config);
chain.addModel(model);

this.progress.offset += file.size;
}

this.model = chain;
this.model.compile();

const sizes = this.model.layers.map(l => {
const size = Math.sqrt(l.size);
if (Number.isInteger(size)) {
return `${size}x${size}`;
return `${size}²`;
}

return l.size.toString();
Expand All @@ -133,13 +155,15 @@ export class Demo3Component implements AfterViewInit {
const size = Math.sqrt(this.model.inputSize);
const data = this.getImageDataFromCanvas(this.drawingCanvasRef.nativeElement, size, size);

const input = Array.from(data).map(value => value / 127.5 - 1);
const output = ImageUtils.processMultiChannelData(this.model, input, 4);
const input = ColorUtils.transformChannelCount(Array.from(data), 4, 3);
ColorUtils.transformColorSpace(ColorUtils.rgbToTanh, input, 3, input);

const mappedOutput = output.map((value, i) => i % 4 !== 3 ? (value + 1) * 127.5 : 255);
const outData = new Uint8ClampedArray(mappedOutput);
const outSize = Math.sqrt(this.model.outputSize);
const output3 = ImageUtils.processMultiChannelData(this.model, input, 3);
ColorUtils.transformColorSpace(ColorUtils.tanhToRgb, output3, 3, output3);

const output4 = ColorUtils.transformChannelCount(output3, 3, 4);
const outData = new Uint8ClampedArray(output4);
const outSize = Math.sqrt(this.model.outputSize);
this.generatedImage.draw(outData.buffer, outSize, outSize);
}

Expand Down

0 comments on commit 5638d27

Please sign in to comment.