From 49684b165eae6b5e4f663525ad92e51bb2e808c7 Mon Sep 17 00:00:00 2001 From: Alexander Date: Tue, 19 Sep 2023 17:39:03 +0500 Subject: [PATCH] Refactor serialization --- src/app/neural-network/engine/activations.ts | 2 +- src/app/neural-network/engine/layers.ts | 2 +- src/app/neural-network/engine/loss.ts | 2 +- src/app/neural-network/engine/optimizers.ts | 2 +- src/app/neural-network/neural-network.ts | 4 +- src/app/neural-network/serialization.ts | 432 ------------------ src/app/neural-network/serialization/base.ts | 77 ++++ src/app/neural-network/serialization/chain.ts | 49 ++ src/app/neural-network/serialization/gan.ts | 46 ++ src/app/neural-network/serialization/index.ts | 4 + src/app/neural-network/serialization/model.ts | 167 +++++++ .../neural-network/serialization/universal.ts | 24 + src/app/neural-network/serialization/utils.ts | 90 ++++ 13 files changed, 463 insertions(+), 438 deletions(-) delete mode 100644 src/app/neural-network/serialization.ts create mode 100644 src/app/neural-network/serialization/base.ts create mode 100644 src/app/neural-network/serialization/chain.ts create mode 100644 src/app/neural-network/serialization/gan.ts create mode 100644 src/app/neural-network/serialization/index.ts create mode 100644 src/app/neural-network/serialization/model.ts create mode 100644 src/app/neural-network/serialization/universal.ts create mode 100644 src/app/neural-network/serialization/utils.ts diff --git a/src/app/neural-network/engine/activations.ts b/src/app/neural-network/engine/activations.ts index 5c2c4f0..ae55ed3 100644 --- a/src/app/neural-network/engine/activations.ts +++ b/src/app/neural-network/engine/activations.ts @@ -1,6 +1,6 @@ import {IActivation, ISingleValueActivation} from "./base"; import {Matrix1D} from "./matrix"; -import {Param} from "../serialization"; +import {Param} from "../serialization/base"; import * as Matrix from "./matrix"; import * as Iter from "./iter"; diff --git a/src/app/neural-network/engine/layers.ts b/src/app/neural-network/engine/layers.ts index 2e9084e..52ddafc 100644 --- a/src/app/neural-network/engine/layers.ts +++ b/src/app/neural-network/engine/layers.ts @@ -3,7 +3,7 @@ import * as matrix from "./matrix"; import {IActivation, ILayer, InitializerFn} from "./base"; import {ActivationsMap, ActivationT} from "./activations"; import {Initializers, InitializerT} from "./initializers"; -import {Param} from "../serialization"; +import {Param} from "../serialization/base"; import {Matrix1D} from "./matrix"; import * as Matrix from "./matrix"; diff --git a/src/app/neural-network/engine/loss.ts b/src/app/neural-network/engine/loss.ts index fda9a20..44d808b 100644 --- a/src/app/neural-network/engine/loss.ts +++ b/src/app/neural-network/engine/loss.ts @@ -1,5 +1,5 @@ import {ILoss} from "./base"; -import {Param} from "../serialization"; +import {Param} from "../serialization/base"; import {Matrix1D} from "./matrix"; import * as Matrix from "./matrix"; import * as CommonUtils from "../utils/common" diff --git a/src/app/neural-network/engine/optimizers.ts b/src/app/neural-network/engine/optimizers.ts index 43ef2ad..e47a5b8 100644 --- a/src/app/neural-network/engine/optimizers.ts +++ b/src/app/neural-network/engine/optimizers.ts @@ -1,4 +1,4 @@ -import {Param} from "../serialization"; +import {Param} from "../serialization/base"; import {ILayer, IOptimizer} from "./base"; import {Matrix1D, Matrix2D} from "./matrix"; import * as matrix from "./matrix"; diff --git a/src/app/neural-network/neural-network.ts b/src/app/neural-network/neural-network.ts index 2dadd44..8f66b77 100644 --- a/src/app/neural-network/neural-network.ts +++ b/src/app/neural-network/neural-network.ts @@ -53,8 +53,8 @@ export { } from "./chart" export { - ModelSerialization, GanSerialization, ChainSerialization, UniversalModelSerializer -} from "./serialization" + ModelSerialization, ChainSerialization, UniversalModelSerializer, GanSerialization +} from "./serialization"; export { ParallelModelWrapper, ParallelUtils diff --git a/src/app/neural-network/serialization.ts b/src/app/neural-network/serialization.ts deleted file mode 100644 index a5ad4ff..0000000 --- a/src/app/neural-network/serialization.ts +++ /dev/null @@ -1,432 +0,0 @@ -import {IActivation, ILayer, ILoss, IModel} from "./engine/base"; -import { - Activations, - Initializers, - Layers, - Loss, - Matrix, - Models, - Optimizers, - SequentialModel, - GenerativeAdversarialModel, ChainModel, ComplexModels, -} from "./neural-network"; -import {AbstractMomentAcceleratedOptimizer, MomentCacheT} from "./engine/optimizers"; -import {InitializerMapping} from "./engine/initializers"; - -const SerializationConfig = new Map; - -type Constructor = new (...args: any[]) => T; -type Function = (...args: any[]) => T; - -type SerializedParams = { [key: string]: any }; -type SerializationEntry = { key: keyof T, params: SerializedParams }; - -type AliasesObject = { [key: string]: R } -type Alias, R> = { key: keyof A, type: R }; -type ClassAlias>, R> = Alias>; -type FunctionAlias>, R> = Alias>; - -export function Param(path?: string) { - return function (target: any, propertyKey: string) { - let entries = SerializationConfig.get(target.constructor); - if (!entries) { - entries = []; - SerializationConfig.set(target.constructor, entries); - } - - - const propPath = [path?.split(".") ?? null, propertyKey]; - entries.push(propPath.filter(p => p).join(".")); - }; -} - -type LayerSerializationEntry = { - key: keyof typeof Layers, - size: number, - activation: SerializationEntry, - weightInitializer: keyof typeof InitializerMapping, - biasInitializer: keyof typeof InitializerMapping, - weights: Matrix.Matrix2D, - biases: Matrix.Matrix1D, - params: SerializedParams -} - -type OptimizerSerializationEntry = { - key: keyof typeof Optimizers, - params: SerializedParams, - moments?: object[] -} - -export type ModelSerialized = { - model: keyof typeof Models, - optimizer: OptimizerSerializationEntry, - loss: SerializationEntry, - layers: LayerSerializationEntry[], - epoch: number -} - -export class ModelSerialization { - public static save(model: IModel): ModelSerialized { - if (!model.isCompiled) throw new Error("Model should be compiled"); - - return { - model: SerializationUtils.getTypeAlias(Models, model).key, - optimizer: this.saveOptimizer(model), - loss: this.saveLoss(model.loss), - layers: model.layers.map(l => this.saveLayer(l)), - epoch: model.epoch - } - } - - public static load(data: ModelSerialized): IModel { - const modelT = Models[data.model]; - - if (!modelT) throw new Error(`Invalid model: ${data.model}`); - - const optimizerT = Optimizers[data.optimizer.key]; - if (!optimizerT) throw new Error(`Invalid optimizer: ${data.optimizer.key}`); - - const lossT = Loss[data.loss.key]; - if (!lossT) throw new Error(`Invalid loss: ${data.loss.key}`); - - const optimizer = new optimizerT(data.optimizer.params); - const model = new modelT(optimizer, new lossT(data.loss.params)); - - let layerIndex = 0; - for (const layerConf of data.layers) { - const layerT = Layers[layerConf.key]; - if (!layerT) throw new Error(`Invalid layer: ${layerConf.key}`); - - const activationT = Activations[layerConf.activation.key]; - if (!activationT) throw new Error(`Invalid activation: ${layerConf.activation.key}`); - - const layer = new layerT(layerConf.size, { - ...layerConf.params, - activation: new activationT(layerConf.activation.params), - weightInitializer: InitializerMapping[layerConf.weightInitializer], - biasInitializer: InitializerMapping[layerConf.biasInitializer], - }); - - layer.skipWeightsInitialization = true; - - if (layerIndex > 0) { - if (!(layerConf.biases?.length > 0) - || typeof layerConf.biases[0] !== "number") { - throw new Error("Invalid layer biases") - } - if (!(layerConf.weights?.length > 0) - || !(layerConf.weights[0] instanceof Array) - || typeof layerConf.weights[0][0] !== "number") { - throw new Error("Invalid layer weights"); - } - - layer.biases = Matrix.copy(layerConf.biases); - layer.weights = Matrix.copy_2d(layerConf.weights); - } - - model.addLayer(layer); - ++layerIndex; - } - - if (optimizer instanceof AbstractMomentAcceleratedOptimizer && data.optimizer.moments) { - this.loadMoments(model, optimizer, data.optimizer.moments as any); - } - - // @ts-ignore - //TODO: ? - model._epoch = data.epoch; - - model.compile(); - return model; - } - - public static loadMoments( - model: IModel, - optimizer: AbstractMomentAcceleratedOptimizer, - moments: T[] - ) { - for (let i = 0; i < model.layers.length; i++) { - const moment = moments[i]; - if (moment) { - optimizer.moments.set(model.layers[i], moment) - } - } - } - - public static saveOptimizer(model: IModel): SerializationEntry { - const optimizer = model.optimizer; - - const type = SerializationUtils.getTypeAlias(Optimizers, optimizer); - const params = SerializationUtils.getSerializableParams(optimizer); - - const result: OptimizerSerializationEntry = {key: type.key, params} - - if (optimizer instanceof AbstractMomentAcceleratedOptimizer) { - result.moments = this.saveOptimizerMoments(model, optimizer); - } - - return result; - } - - public static saveOptimizerMoments(model: IModel, optimizer: AbstractMomentAcceleratedOptimizer): T[] { - const result = new Array(model.layers.length).fill(undefined); - for (let i = 0; i < model.layers.length; i++) { - const layerCache = optimizer.moments.get(model.layers[i]); - - if (layerCache) { - result[i] = {} as T; - for (const [key, value] of Object.entries(layerCache)) { - if (value[0] instanceof Array) { - result[i][key] = Matrix.copy_2d(value as Matrix.Matrix2D); - } else { - result[i][key] = Matrix.copy(value as Matrix.Matrix1D); - } - } - } - } - - return result; - } - - public static saveLoss(loss: ILoss): SerializationEntry { - const type = SerializationUtils.getTypeAlias(Loss, loss); - const params = SerializationUtils.getSerializableParams(loss); - - return { - key: type.key, - params, - } - } - - public static saveLayer(layer: ILayer): LayerSerializationEntry { - const type = SerializationUtils.getTypeAlias(Layers, layer); - const params = SerializationUtils.getSerializableParams(layer); - - return { - key: type.key, - size: layer.size, - activation: this.saveActivation(layer.activation), - weightInitializer: SerializationUtils.getFnAlias(Initializers, layer.weightInitializer).key, - biasInitializer: SerializationUtils.getFnAlias(Initializers, layer.biasInitializer).key, - weights: Matrix.copy_2d(layer.weights), - biases: Matrix.copy(layer.biases), - params, - } - } - - public static saveActivation(activation: IActivation): SerializationEntry { - const type = SerializationUtils.getTypeAlias(Activations, activation); - const params = SerializationUtils.getSerializableParams(activation); - - return { - key: type.key, - params, - } - } -} - -class SerializationUtils { - static getTypeAlias>, T>( - aliases: A, instance: T - ): ClassAlias { - if (!instance) { - throw new Error("Instance can't be empty") - } - - const instanceType = Object.entries(aliases).find(([, type]) => instance instanceof type); - if (!instanceType) { - throw new Error(`Unsupported type: ${instance.constructor.name}`); - } - - return { - key: instanceType[0], - type: instanceType[1], - }; - } - - static getFnAlias, R>( - aliases: AliasesObject, fn: T - ): FunctionAlias { - if (!fn) { - throw new Error("Function can't be empty"); - } - - const instanceFn = Object.entries(aliases).find(([, f]) => fn === f); - if (!instanceFn) { - throw new Error(`Unsupported function: ${fn.constructor.name}`); - } - - return { - key: instanceFn[0], - type: instanceFn[1], - }; - } - - static getSerializableParams(instance: T): SerializedParams { - let result = {}; - - let type = instance.constructor as Constructor; - while (type && type.constructor?.name) { - const classParams = this.getTypeSerializableParams(instance, type); - if (classParams) { - result = {...classParams, ...result} - } - - type = Object.getPrototypeOf(type); - } - - return result; - } - - static getTypeSerializableParams(instance: T, type: Constructor): SerializedParams { - const config = SerializationConfig.get(type); - let params: { [key: string]: any } = {}; - - if (config) { - for (const path of config) { - this.storePropertyValue(instance, path, params); - } - } - - return params; - } - - static storePropertyValue(instance: any, path: string, out: any) { - const parts = path.split("."); - - let cOut = out; - for (let i = 0; i < parts.length - 1; i++) { - const part = parts[i]; - - if (cOut[part] === undefined) cOut[part] = {}; - cOut = cOut[part]; - } - - const part = parts[parts.length - 1]; - cOut[part] = instance[part] - } -} - -export type GanSerialized = { - generator: ModelSerialized, - discriminator: ModelSerialized, - epoch: number, - optimizer: OptimizerSerializationEntry, - loss: SerializationEntry, -} - -export class GanSerialization { - public static save(gan: GenerativeAdversarialModel): GanSerialized { - return { - generator: ModelSerialization.save(gan.generator), - discriminator: ModelSerialization.save(gan.discriminator), - - epoch: gan.ganChain.epoch, - optimizer: ModelSerialization.saveOptimizer(gan.ganChain), - loss: ModelSerialization.saveLoss(gan.ganChain.loss), - } - } - - public static load(data: GanSerialized): GenerativeAdversarialModel { - const generator = ModelSerialization.load(data.generator); - const discriminator = ModelSerialization.load(data.discriminator); - - const optimizerT = Optimizers[data.optimizer.key]; - if (!optimizerT) throw new Error(`Invalid optimizer: ${data.optimizer.key}`); - - const lossT = Loss[data.loss.key]; - if (!lossT) throw new Error(`Invalid loss: ${data.loss.key}`); - - const optimizer = new optimizerT(data.optimizer.params); - const model = new GenerativeAdversarialModel( - generator as SequentialModel, - discriminator as SequentialModel, - optimizer, - new lossT(data.loss.params), - ); - - if (optimizer instanceof AbstractMomentAcceleratedOptimizer && data.optimizer.moments) { - ModelSerialization.loadMoments(model.ganChain, optimizer, data.optimizer.moments as any); - } - - // @ts-ignore - model.ganChain._epoch = data.epoch; - - return model; - } -} - -export type ChainSerialized = { - model: keyof typeof ComplexModels, - models: ModelSerialized[], - trainable: boolean[], - epoch: number, - optimizer: OptimizerSerializationEntry, - loss: SerializationEntry, -} - -export class ChainSerialization { - public static save(chain: ChainModel): ChainSerialized { - if (!chain.isCompiled) throw new Error("Model should be compiled"); - - return { - model: SerializationUtils.getTypeAlias(ComplexModels, chain as any).key, - models: chain.models.map(model => ModelSerialization.save(model)), - trainable: chain.trainable.concat(), - - epoch: chain.epoch, - optimizer: ModelSerialization.saveOptimizer(chain), - loss: ModelSerialization.saveLoss(chain.loss), - } - } - - public static load(data: ChainSerialized): ChainModel { - const optimizerT = Optimizers[data.optimizer.key]; - if (!optimizerT) throw new Error(`Invalid optimizer: ${data.optimizer.key}`); - - const lossT = Loss[data.loss.key]; - if (!lossT) throw new Error(`Invalid loss: ${data.loss.key}`); - - const optimizer = new optimizerT(data.optimizer.params); - const model = new ChainModel(optimizer, new lossT(data.loss.params),); - - for (let i = 0; i < data.models.length; i++) { - model.addModel( - ModelSerialization.load(data.models[i]), - data.trainable[i] - ); - } - - if (optimizer instanceof AbstractMomentAcceleratedOptimizer && data.optimizer.moments) { - ModelSerialization.loadMoments(model, optimizer, data.optimizer.moments as any); - } - - // @ts-ignore - model._epoch = data.epoch; - model.compile(); - - return model; - } -} - -type ISerializedModel = { - model: string -} - -export class UniversalModelSerializer { - static save(model: IModel) { - if (model instanceof ChainModel) { - return ChainSerialization.save(model); - } else { - return ModelSerialization.save(model); - } - } - - static load(data: ISerializedModel) { - if (data.model === "Chain") { - return ChainSerialization.load(data as any); - } else { - return ModelSerialization.load(data as any); - } - } -} \ No newline at end of file diff --git a/src/app/neural-network/serialization/base.ts b/src/app/neural-network/serialization/base.ts new file mode 100644 index 0000000..6554f49 --- /dev/null +++ b/src/app/neural-network/serialization/base.ts @@ -0,0 +1,77 @@ +import {Layers} from "../engine/layers"; +import {Activations} from "../engine/activations"; +import {InitializerMapping} from "../engine/initializers"; +import {ComplexModels, Loss, Matrix, Models, Optimizers} from "../neural-network"; + +export const SerializationConfig = new Map; + +export type Constructor = new (...args: any[]) => T; +export type Function = (...args: any[]) => T; + +export type SerializedParams = { [key: string]: any }; +export type SerializationEntry = { key: keyof T, params: SerializedParams }; + +export type AliasesObject = { [key: string]: R } +export type Alias, R> = { key: keyof A, type: R }; +export type ClassAlias>, R> = Alias>; +export type FunctionAlias>, R> = Alias>; + +export function Param(path?: string) { + return function (target: any, propertyKey: string) { + let entries = SerializationConfig.get(target.constructor); + if (!entries) { + entries = []; + SerializationConfig.set(target.constructor, entries); + } + + + const propPath = [path?.split(".") ?? null, propertyKey]; + entries.push(propPath.filter(p => p).join(".")); + }; +} + +export type LayerSerializationEntry = { + key: keyof typeof Layers, + size: number, + activation: SerializationEntry, + weightInitializer: keyof typeof InitializerMapping, + biasInitializer: keyof typeof InitializerMapping, + weights: Matrix.Matrix2D, + biases: Matrix.Matrix1D, + params: SerializedParams +} + +export type OptimizerSerializationEntry = { + key: keyof typeof Optimizers, + params: SerializedParams, + moments?: object[] +} + +export type ModelSerialized = { + model: keyof typeof Models, + optimizer: OptimizerSerializationEntry, + loss: SerializationEntry, + layers: LayerSerializationEntry[], + epoch: number +} + +export type ISerializedModel = { + model: string +} + +export type ChainSerialized = { + model: keyof typeof ComplexModels, + models: ModelSerialized[], + trainable: boolean[], + epoch: number, + optimizer: OptimizerSerializationEntry, + loss: SerializationEntry, +} + +export type GanSerialized = { + generator: ModelSerialized, + discriminator: ModelSerialized, + epoch: number, + optimizer: OptimizerSerializationEntry, + loss: SerializationEntry, +} \ No newline at end of file diff --git a/src/app/neural-network/serialization/chain.ts b/src/app/neural-network/serialization/chain.ts new file mode 100644 index 0000000..b052e7b --- /dev/null +++ b/src/app/neural-network/serialization/chain.ts @@ -0,0 +1,49 @@ +import {AbstractMomentAcceleratedOptimizer} from "../engine/optimizers"; +import {ChainModel, ComplexModels, Loss, Optimizers} from "../neural-network"; +import {ChainSerialized} from "./base"; +import {SerializationUtils} from "./utils"; +import {ModelSerialization} from "./model"; + +export class ChainSerialization { + public static save(chain: ChainModel): ChainSerialized { + if (!chain.isCompiled) throw new Error("Model should be compiled"); + + return { + model: SerializationUtils.getTypeAlias(ComplexModels, chain as any).key, + models: chain.models.map(model => ModelSerialization.save(model)), + trainable: chain.trainable.concat(), + + epoch: chain.epoch, + optimizer: ModelSerialization.saveOptimizer(chain), + loss: ModelSerialization.saveLoss(chain.loss), + } + } + + public static load(data: ChainSerialized): ChainModel { + const optimizerT = Optimizers[data.optimizer.key]; + if (!optimizerT) throw new Error(`Invalid optimizer: ${data.optimizer.key}`); + + const lossT = Loss[data.loss.key]; + if (!lossT) throw new Error(`Invalid loss: ${data.loss.key}`); + + const optimizer = new optimizerT(data.optimizer.params); + const model = new ChainModel(optimizer, new lossT(data.loss.params),); + + for (let i = 0; i < data.models.length; i++) { + model.addModel( + ModelSerialization.load(data.models[i]), + data.trainable[i] + ); + } + + if (optimizer instanceof AbstractMomentAcceleratedOptimizer && data.optimizer.moments) { + ModelSerialization.loadMoments(model, optimizer, data.optimizer.moments as any); + } + + // @ts-ignore + model._epoch = data.epoch; + model.compile(); + + return model; + } +} \ No newline at end of file diff --git a/src/app/neural-network/serialization/gan.ts b/src/app/neural-network/serialization/gan.ts new file mode 100644 index 0000000..33c5963 --- /dev/null +++ b/src/app/neural-network/serialization/gan.ts @@ -0,0 +1,46 @@ +import {GenerativeAdversarialModel, SequentialModel, Optimizers, Loss} from "../neural-network"; +import {AbstractMomentAcceleratedOptimizer} from "../engine/optimizers"; + +import {GanSerialized} from "./base"; +import {ModelSerialization} from "./model"; + +export class GanSerialization { + public static save(gan: GenerativeAdversarialModel): GanSerialized { + return { + generator: ModelSerialization.save(gan.generator), + discriminator: ModelSerialization.save(gan.discriminator), + + epoch: gan.ganChain.epoch, + optimizer: ModelSerialization.saveOptimizer(gan.ganChain), + loss: ModelSerialization.saveLoss(gan.ganChain.loss), + } + } + + public static load(data: GanSerialized): GenerativeAdversarialModel { + const generator = ModelSerialization.load(data.generator); + const discriminator = ModelSerialization.load(data.discriminator); + + const optimizerT = Optimizers[data.optimizer.key]; + if (!optimizerT) throw new Error(`Invalid optimizer: ${data.optimizer.key}`); + + const lossT = Loss[data.loss.key]; + if (!lossT) throw new Error(`Invalid loss: ${data.loss.key}`); + + const optimizer = new optimizerT(data.optimizer.params); + const model = new GenerativeAdversarialModel( + generator as SequentialModel, + discriminator as SequentialModel, + optimizer, + new lossT(data.loss.params), + ); + + if (optimizer instanceof AbstractMomentAcceleratedOptimizer && data.optimizer.moments) { + ModelSerialization.loadMoments(model.ganChain, optimizer, data.optimizer.moments as any); + } + + // @ts-ignore + model.ganChain._epoch = data.epoch; + + return model; + } +} \ No newline at end of file diff --git a/src/app/neural-network/serialization/index.ts b/src/app/neural-network/serialization/index.ts new file mode 100644 index 0000000..164df56 --- /dev/null +++ b/src/app/neural-network/serialization/index.ts @@ -0,0 +1,4 @@ +export {ModelSerialization} from "./model" +export {ChainSerialization} from "./chain" +export {GanSerialization} from "./gan" +export {UniversalModelSerializer} from "./universal" \ No newline at end of file diff --git a/src/app/neural-network/serialization/model.ts b/src/app/neural-network/serialization/model.ts new file mode 100644 index 0000000..aed9188 --- /dev/null +++ b/src/app/neural-network/serialization/model.ts @@ -0,0 +1,167 @@ +import {IActivation, ILayer, ILoss, IModel} from "../engine/base"; +import {Models, Loss, Layers, Activations, Initializers, Optimizers, Matrix} from "../neural-network"; +import {InitializerMapping} from "../engine/initializers"; +import {AbstractMomentAcceleratedOptimizer, MomentCacheT} from "../engine/optimizers"; + +import {LayerSerializationEntry, ModelSerialized, OptimizerSerializationEntry, SerializationEntry} from "./base"; +import {SerializationUtils} from "./utils"; + +export class ModelSerialization { + public static save(model: IModel): ModelSerialized { + if (!model.isCompiled) throw new Error("Model should be compiled"); + + return { + model: SerializationUtils.getTypeAlias(Models, model).key, + optimizer: this.saveOptimizer(model), + loss: this.saveLoss(model.loss), + layers: model.layers.map(l => this.saveLayer(l)), + epoch: model.epoch + } + } + + public static load(data: ModelSerialized): IModel { + const modelT = Models[data.model]; + + if (!modelT) throw new Error(`Invalid model: ${data.model}`); + + const optimizerT = Optimizers[data.optimizer.key]; + if (!optimizerT) throw new Error(`Invalid optimizer: ${data.optimizer.key}`); + + const lossT = Loss[data.loss.key]; + if (!lossT) throw new Error(`Invalid loss: ${data.loss.key}`); + + const optimizer = new optimizerT(data.optimizer.params); + const model = new modelT(optimizer, new lossT(data.loss.params)); + + let layerIndex = 0; + for (const layerConf of data.layers) { + const layerT = Layers[layerConf.key]; + if (!layerT) throw new Error(`Invalid layer: ${layerConf.key}`); + + const activationT = Activations[layerConf.activation.key]; + if (!activationT) throw new Error(`Invalid activation: ${layerConf.activation.key}`); + + const layer = new layerT(layerConf.size, { + ...layerConf.params, + activation: new activationT(layerConf.activation.params), + weightInitializer: InitializerMapping[layerConf.weightInitializer], + biasInitializer: InitializerMapping[layerConf.biasInitializer], + }); + + layer.skipWeightsInitialization = true; + + if (layerIndex > 0) { + if (!(layerConf.biases?.length > 0) + || typeof layerConf.biases[0] !== "number") { + throw new Error("Invalid layer biases") + } + if (!(layerConf.weights?.length > 0) + || !(layerConf.weights[0] instanceof Array) + || typeof layerConf.weights[0][0] !== "number") { + throw new Error("Invalid layer weights"); + } + + layer.biases = Matrix.copy(layerConf.biases); + layer.weights = Matrix.copy_2d(layerConf.weights); + } + + model.addLayer(layer); + ++layerIndex; + } + + if (optimizer instanceof AbstractMomentAcceleratedOptimizer && data.optimizer.moments) { + this.loadMoments(model, optimizer, data.optimizer.moments as any); + } + + // @ts-ignore + //TODO: ? + model._epoch = data.epoch; + + model.compile(); + return model; + } + + public static loadMoments( + model: IModel, + optimizer: AbstractMomentAcceleratedOptimizer, + moments: T[] + ) { + for (let i = 0; i < model.layers.length; i++) { + const moment = moments[i]; + if (moment) { + optimizer.moments.set(model.layers[i], moment) + } + } + } + + public static saveOptimizer(model: IModel): SerializationEntry { + const optimizer = model.optimizer; + + const type = SerializationUtils.getTypeAlias(Optimizers, optimizer); + const params = SerializationUtils.getSerializableParams(optimizer); + + const result: OptimizerSerializationEntry = {key: type.key, params} + + if (optimizer instanceof AbstractMomentAcceleratedOptimizer) { + result.moments = this.saveOptimizerMoments(model, optimizer); + } + + return result; + } + + public static saveOptimizerMoments(model: IModel, optimizer: AbstractMomentAcceleratedOptimizer): T[] { + const result = new Array(model.layers.length).fill(undefined); + for (let i = 0; i < model.layers.length; i++) { + const layerCache = optimizer.moments.get(model.layers[i]); + + if (layerCache) { + result[i] = {} as T; + for (const [key, value] of Object.entries(layerCache)) { + if (value[0] instanceof Array) { + result[i][key] = Matrix.copy_2d(value as Matrix.Matrix2D); + } else { + result[i][key] = Matrix.copy(value as Matrix.Matrix1D); + } + } + } + } + + return result; + } + + public static saveLoss(loss: ILoss): SerializationEntry { + const type = SerializationUtils.getTypeAlias(Loss, loss); + const params = SerializationUtils.getSerializableParams(loss); + + return { + key: type.key, + params, + } + } + + public static saveLayer(layer: ILayer): LayerSerializationEntry { + const type = SerializationUtils.getTypeAlias(Layers, layer); + const params = SerializationUtils.getSerializableParams(layer); + + return { + key: type.key, + size: layer.size, + activation: this.saveActivation(layer.activation), + weightInitializer: SerializationUtils.getFnAlias(Initializers, layer.weightInitializer).key, + biasInitializer: SerializationUtils.getFnAlias(Initializers, layer.biasInitializer).key, + weights: Matrix.copy_2d(layer.weights), + biases: Matrix.copy(layer.biases), + params, + } + } + + public static saveActivation(activation: IActivation): SerializationEntry { + const type = SerializationUtils.getTypeAlias(Activations, activation); + const params = SerializationUtils.getSerializableParams(activation); + + return { + key: type.key, + params, + } + } +} \ No newline at end of file diff --git a/src/app/neural-network/serialization/universal.ts b/src/app/neural-network/serialization/universal.ts new file mode 100644 index 0000000..3b8675d --- /dev/null +++ b/src/app/neural-network/serialization/universal.ts @@ -0,0 +1,24 @@ +import {IModel} from "../engine/base"; +import {ChainModel} from "../neural-network"; + +import {ISerializedModel} from "./base"; +import {ChainSerialization} from "./chain"; +import {ModelSerialization} from "./model"; + +export class UniversalModelSerializer { + static save(model: IModel) { + if (model instanceof ChainModel) { + return ChainSerialization.save(model); + } else { + return ModelSerialization.save(model); + } + } + + static load(data: ISerializedModel) { + if (data.model === "Chain") { + return ChainSerialization.load(data as any); + } else { + return ModelSerialization.load(data as any); + } + } +} \ No newline at end of file diff --git a/src/app/neural-network/serialization/utils.ts b/src/app/neural-network/serialization/utils.ts new file mode 100644 index 0000000..b054e89 --- /dev/null +++ b/src/app/neural-network/serialization/utils.ts @@ -0,0 +1,90 @@ +import { + AliasesObject, + ClassAlias, + Constructor, + Function, + FunctionAlias, + SerializationConfig, + SerializedParams +} from "./base"; + +export class SerializationUtils { + static getTypeAlias>, T>( + aliases: A, instance: T + ): ClassAlias { + if (!instance) { + throw new Error("Instance can't be empty") + } + + const instanceType = Object.entries(aliases).find(([, type]) => instance instanceof type); + if (!instanceType) { + throw new Error(`Unsupported type: ${instance.constructor.name}`); + } + + return { + key: instanceType[0], + type: instanceType[1], + }; + } + + static getFnAlias, R>( + aliases: AliasesObject, fn: T + ): FunctionAlias { + if (!fn) { + throw new Error("Function can't be empty"); + } + + const instanceFn = Object.entries(aliases).find(([, f]) => fn === f); + if (!instanceFn) { + throw new Error(`Unsupported function: ${fn.constructor.name}`); + } + + return { + key: instanceFn[0], + type: instanceFn[1], + }; + } + + static getSerializableParams(instance: T): SerializedParams { + let result = {}; + + let type = instance.constructor as Constructor; + while (type && type.constructor?.name) { + const classParams = this.getTypeSerializableParams(instance, SerializationConfig.get(type)!); + if (classParams) { + result = {...classParams, ...result} + } + + type = Object.getPrototypeOf(type); + } + + return result; + } + + static getTypeSerializableParams(instance: T, typeConfig: string[]): SerializedParams { + let params: { [key: string]: any } = {}; + + if (typeConfig) { + for (const path of typeConfig) { + this.storePropertyValue(instance, path, params); + } + } + + return params; + } + + static storePropertyValue(instance: any, path: string, out: any) { + const parts = path.split("."); + + let cOut = out; + for (let i = 0; i < parts.length - 1; i++) { + const part = parts[i]; + + if (cOut[part] === undefined) cOut[part] = {}; + cOut = cOut[part]; + } + + const part = parts[parts.length - 1]; + cOut[part] = instance[part] + } +} \ No newline at end of file