From e689fbb41d17af73c83d61bd9b190e7b27501c36 Mon Sep 17 00:00:00 2001
From: "github-actions[bot]"
<41898282+github-actions[bot]@users.noreply.github.com>
Date: Fri, 24 May 2024 11:30:50 +0200
Subject: [PATCH] DOCS: shimmer API docs for v0.5.1 (#86)
Co-authored-by: github-actions
---
docs/api/v0.5.1/index.html | 249 +
docs/api/v0.5.1/search.js | 46 +
.../v0.5.1/shimmer/cli/ckpt_migration.html | 319 ++
docs/api/v0.5.1/shimmer/dataset.html | 448 ++
.../shimmer/modules/contrastive_loss.html | 797 +++
docs/api/v0.5.1/shimmer/modules/domain.html | 1220 +++++
.../shimmer/modules/global_workspace.html | 4260 +++++++++++++++++
.../api/v0.5.1/shimmer/modules/gw_module.html | 2893 +++++++++++
docs/api/v0.5.1/shimmer/modules/losses.html | 3888 +++++++++++++++
.../api/v0.5.1/shimmer/modules/selection.html | 2151 +++++++++
docs/api/v0.5.1/shimmer/modules/utils.html | 760 +++
docs/api/v0.5.1/shimmer/modules/vae.html | 1288 +++++
docs/api/v0.5.1/shimmer/types.html | 872 ++++
docs/api/v0.5.1/shimmer/utils.html | 701 +++
14 files changed, 19892 insertions(+)
create mode 100644 docs/api/v0.5.1/index.html
create mode 100644 docs/api/v0.5.1/search.js
create mode 100644 docs/api/v0.5.1/shimmer/cli/ckpt_migration.html
create mode 100644 docs/api/v0.5.1/shimmer/dataset.html
create mode 100644 docs/api/v0.5.1/shimmer/modules/contrastive_loss.html
create mode 100644 docs/api/v0.5.1/shimmer/modules/domain.html
create mode 100644 docs/api/v0.5.1/shimmer/modules/global_workspace.html
create mode 100644 docs/api/v0.5.1/shimmer/modules/gw_module.html
create mode 100644 docs/api/v0.5.1/shimmer/modules/losses.html
create mode 100644 docs/api/v0.5.1/shimmer/modules/selection.html
create mode 100644 docs/api/v0.5.1/shimmer/modules/utils.html
create mode 100644 docs/api/v0.5.1/shimmer/modules/vae.html
create mode 100644 docs/api/v0.5.1/shimmer/types.html
create mode 100644 docs/api/v0.5.1/shimmer/utils.html
diff --git a/docs/api/v0.5.1/index.html b/docs/api/v0.5.1/index.html
new file mode 100644
index 00000000..b62372d7
--- /dev/null
+++ b/docs/api/v0.5.1/index.html
@@ -0,0 +1,249 @@
+
+
+
+
+
+
+ Module List – pdoc 14.4.0
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
\ No newline at end of file
diff --git a/docs/api/v0.5.1/search.js b/docs/api/v0.5.1/search.js
new file mode 100644
index 00000000..6b97e191
--- /dev/null
+++ b/docs/api/v0.5.1/search.js
@@ -0,0 +1,46 @@
+window.pdocSearch = (function(){
+/** elasticlunr - http://weixsong.github.io * Copyright (C) 2017 Oliver Nightingale * Copyright (C) 2017 Wei Song * MIT Licensed */!function(){function e(e){if(null===e||"object"!=typeof e)return e;var t=e.constructor();for(var n in e)e.hasOwnProperty(n)&&(t[n]=e[n]);return t}var t=function(e){var n=new t.Index;return n.pipeline.add(t.trimmer,t.stopWordFilter,t.stemmer),e&&e.call(n,n),n};t.version="0.9.5",lunr=t,t.utils={},t.utils.warn=function(e){return function(t){e.console&&console.warn&&console.warn(t)}}(this),t.utils.toString=function(e){return void 0===e||null===e?"":e.toString()},t.EventEmitter=function(){this.events={}},t.EventEmitter.prototype.addListener=function(){var e=Array.prototype.slice.call(arguments),t=e.pop(),n=e;if("function"!=typeof t)throw new TypeError("last argument must be a function");n.forEach(function(e){this.hasHandler(e)||(this.events[e]=[]),this.events[e].push(t)},this)},t.EventEmitter.prototype.removeListener=function(e,t){if(this.hasHandler(e)){var n=this.events[e].indexOf(t);-1!==n&&(this.events[e].splice(n,1),0==this.events[e].length&&delete this.events[e])}},t.EventEmitter.prototype.emit=function(e){if(this.hasHandler(e)){var t=Array.prototype.slice.call(arguments,1);this.events[e].forEach(function(e){e.apply(void 0,t)},this)}},t.EventEmitter.prototype.hasHandler=function(e){return e in this.events},t.tokenizer=function(e){if(!arguments.length||null===e||void 0===e)return[];if(Array.isArray(e)){var n=e.filter(function(e){return null===e||void 0===e?!1:!0});n=n.map(function(e){return t.utils.toString(e).toLowerCase()});var i=[];return n.forEach(function(e){var n=e.split(t.tokenizer.seperator);i=i.concat(n)},this),i}return e.toString().trim().toLowerCase().split(t.tokenizer.seperator)},t.tokenizer.defaultSeperator=/[\s\-]+/,t.tokenizer.seperator=t.tokenizer.defaultSeperator,t.tokenizer.setSeperator=function(e){null!==e&&void 0!==e&&"object"==typeof 
e&&(t.tokenizer.seperator=e)},t.tokenizer.resetSeperator=function(){t.tokenizer.seperator=t.tokenizer.defaultSeperator},t.tokenizer.getSeperator=function(){return t.tokenizer.seperator},t.Pipeline=function(){this._queue=[]},t.Pipeline.registeredFunctions={},t.Pipeline.registerFunction=function(e,n){n in t.Pipeline.registeredFunctions&&t.utils.warn("Overwriting existing registered function: "+n),e.label=n,t.Pipeline.registeredFunctions[n]=e},t.Pipeline.getRegisteredFunction=function(e){return e in t.Pipeline.registeredFunctions!=!0?null:t.Pipeline.registeredFunctions[e]},t.Pipeline.warnIfFunctionNotRegistered=function(e){var n=e.label&&e.label in this.registeredFunctions;n||t.utils.warn("Function is not registered with pipeline. This may cause problems when serialising the index.\n",e)},t.Pipeline.load=function(e){var n=new t.Pipeline;return e.forEach(function(e){var i=t.Pipeline.getRegisteredFunction(e);if(!i)throw new Error("Cannot load un-registered function: "+e);n.add(i)}),n},t.Pipeline.prototype.add=function(){var e=Array.prototype.slice.call(arguments);e.forEach(function(e){t.Pipeline.warnIfFunctionNotRegistered(e),this._queue.push(e)},this)},t.Pipeline.prototype.after=function(e,n){t.Pipeline.warnIfFunctionNotRegistered(n);var i=this._queue.indexOf(e);if(-1===i)throw new Error("Cannot find existingFn");this._queue.splice(i+1,0,n)},t.Pipeline.prototype.before=function(e,n){t.Pipeline.warnIfFunctionNotRegistered(n);var i=this._queue.indexOf(e);if(-1===i)throw new Error("Cannot find existingFn");this._queue.splice(i,0,n)},t.Pipeline.prototype.remove=function(e){var t=this._queue.indexOf(e);-1!==t&&this._queue.splice(t,1)},t.Pipeline.prototype.run=function(e){for(var t=[],n=e.length,i=this._queue.length,o=0;n>o;o++){for(var r=e[o],s=0;i>s&&(r=this._queue[s](r,o,e),void 0!==r&&null!==r);s++);void 0!==r&&null!==r&&t.push(r)}return t},t.Pipeline.prototype.reset=function(){this._queue=[]},t.Pipeline.prototype.get=function(){return 
this._queue},t.Pipeline.prototype.toJSON=function(){return this._queue.map(function(e){return t.Pipeline.warnIfFunctionNotRegistered(e),e.label})},t.Index=function(){this._fields=[],this._ref="id",this.pipeline=new t.Pipeline,this.documentStore=new t.DocumentStore,this.index={},this.eventEmitter=new t.EventEmitter,this._idfCache={},this.on("add","remove","update",function(){this._idfCache={}}.bind(this))},t.Index.prototype.on=function(){var e=Array.prototype.slice.call(arguments);return this.eventEmitter.addListener.apply(this.eventEmitter,e)},t.Index.prototype.off=function(e,t){return this.eventEmitter.removeListener(e,t)},t.Index.load=function(e){e.version!==t.version&&t.utils.warn("version mismatch: current "+t.version+" importing "+e.version);var n=new this;n._fields=e.fields,n._ref=e.ref,n.documentStore=t.DocumentStore.load(e.documentStore),n.pipeline=t.Pipeline.load(e.pipeline),n.index={};for(var i in e.index)n.index[i]=t.InvertedIndex.load(e.index[i]);return n},t.Index.prototype.addField=function(e){return this._fields.push(e),this.index[e]=new t.InvertedIndex,this},t.Index.prototype.setRef=function(e){return this._ref=e,this},t.Index.prototype.saveDocument=function(e){return this.documentStore=new t.DocumentStore(e),this},t.Index.prototype.addDoc=function(e,n){if(e){var n=void 0===n?!0:n,i=e[this._ref];this.documentStore.addDoc(i,e),this._fields.forEach(function(n){var o=this.pipeline.run(t.tokenizer(e[n]));this.documentStore.addFieldLength(i,n,o.length);var r={};o.forEach(function(e){e in r?r[e]+=1:r[e]=1},this);for(var s in r){var u=r[s];u=Math.sqrt(u),this.index[n].addToken(s,{ref:i,tf:u})}},this),n&&this.eventEmitter.emit("add",e,this)}},t.Index.prototype.removeDocByRef=function(e){if(e&&this.documentStore.isDocStored()!==!1&&this.documentStore.hasDoc(e)){var t=this.documentStore.getDoc(e);this.removeDoc(t,!1)}},t.Index.prototype.removeDoc=function(e,n){if(e){var n=void 
0===n?!0:n,i=e[this._ref];this.documentStore.hasDoc(i)&&(this.documentStore.removeDoc(i),this._fields.forEach(function(n){var o=this.pipeline.run(t.tokenizer(e[n]));o.forEach(function(e){this.index[n].removeToken(e,i)},this)},this),n&&this.eventEmitter.emit("remove",e,this))}},t.Index.prototype.updateDoc=function(e,t){var t=void 0===t?!0:t;this.removeDocByRef(e[this._ref],!1),this.addDoc(e,!1),t&&this.eventEmitter.emit("update",e,this)},t.Index.prototype.idf=function(e,t){var n="@"+t+"/"+e;if(Object.prototype.hasOwnProperty.call(this._idfCache,n))return this._idfCache[n];var i=this.index[t].getDocFreq(e),o=1+Math.log(this.documentStore.length/(i+1));return this._idfCache[n]=o,o},t.Index.prototype.getFields=function(){return this._fields.slice()},t.Index.prototype.search=function(e,n){if(!e)return[];e="string"==typeof e?{any:e}:JSON.parse(JSON.stringify(e));var i=null;null!=n&&(i=JSON.stringify(n));for(var o=new t.Configuration(i,this.getFields()).get(),r={},s=Object.keys(e),u=0;u0&&t.push(e);for(var i in n)"docs"!==i&&"df"!==i&&this.expandToken(e+i,t,n[i]);return t},t.InvertedIndex.prototype.toJSON=function(){return{root:this.root}},t.Configuration=function(e,n){var e=e||"";if(void 0==n||null==n)throw new Error("fields should not be null");this.config={};var i;try{i=JSON.parse(e),this.buildUserConfig(i,n)}catch(o){t.utils.warn("user configuration parse failed, will use default configuration"),this.buildDefaultConfig(n)}},t.Configuration.prototype.buildDefaultConfig=function(e){this.reset(),e.forEach(function(e){this.config[e]={boost:1,bool:"OR",expand:!1}},this)},t.Configuration.prototype.buildUserConfig=function(e,n){var i="OR",o=!1;if(this.reset(),"bool"in e&&(i=e.bool||i),"expand"in e&&(o=e.expand||o),"fields"in e)for(var r in e.fields)if(n.indexOf(r)>-1){var s=e.fields[r],u=o;void 0!=s.expand&&(u=s.expand),this.config[r]={boost:s.boost||0===s.boost?s.boost:1,bool:s.bool||i,expand:u}}else t.utils.warn("field name in user configuration not found in index instance 
fields");else this.addAllFields2UserConfig(i,o,n)},t.Configuration.prototype.addAllFields2UserConfig=function(e,t,n){n.forEach(function(n){this.config[n]={boost:1,bool:e,expand:t}},this)},t.Configuration.prototype.get=function(){return this.config},t.Configuration.prototype.reset=function(){this.config={}},lunr.SortedSet=function(){this.length=0,this.elements=[]},lunr.SortedSet.load=function(e){var t=new this;return t.elements=e,t.length=e.length,t},lunr.SortedSet.prototype.add=function(){var e,t;for(e=0;e1;){if(r===e)return o;e>r&&(t=o),r>e&&(n=o),i=n-t,o=t+Math.floor(i/2),r=this.elements[o]}return r===e?o:-1},lunr.SortedSet.prototype.locationFor=function(e){for(var t=0,n=this.elements.length,i=n-t,o=t+Math.floor(i/2),r=this.elements[o];i>1;)e>r&&(t=o),r>e&&(n=o),i=n-t,o=t+Math.floor(i/2),r=this.elements[o];return r>e?o:e>r?o+1:void 0},lunr.SortedSet.prototype.intersect=function(e){for(var t=new lunr.SortedSet,n=0,i=0,o=this.length,r=e.length,s=this.elements,u=e.elements;;){if(n>o-1||i>r-1)break;s[n]!==u[i]?s[n]u[i]&&i++:(t.add(s[n]),n++,i++)}return t},lunr.SortedSet.prototype.clone=function(){var e=new lunr.SortedSet;return e.elements=this.toArray(),e.length=e.elements.length,e},lunr.SortedSet.prototype.union=function(e){var t,n,i;this.length>=e.length?(t=this,n=e):(t=e,n=this),i=t.clone();for(var o=0,r=n.toArray();o
\n"}, "shimmer.types.RawDomainGroupT": {"fullname": "shimmer.types.RawDomainGroupT", "modulename": "shimmer.types", "qualname": "RawDomainGroupT", "kind": "variable", "doc": "Matched raw unimodal data from multiple domains.\nKeys of the mapping are domains names and values are the domain data.
\n\nAll values in the mapping should be matched and represent the same information.
\n\nExample: \n\n\n \n
def fun ( domain_group : RawDomainGroupT ): ... \n\n\nx = { \n "vision" : PIL . Image . Image ( "path/to/dog/picture.png" ), \n "language" : "This is a picture of a dog." , \n} \n\nfun ( x ) \n
\n
\n \n\nNote: \n\n\n This type uses collections.abc.Mapping
and is used for functions' inputs.\n Use RawDomainGroupDT
for functions' outputs.
\n \n This allows to be more generic and allow passing other mappings.
\n \n", "default_value": "collections.abc.Mapping[str, typing.Any]"}, "shimmer.types.RawDomainGroupDT": {"fullname": "shimmer.types.RawDomainGroupDT", "modulename": "shimmer.types", "qualname": "RawDomainGroupDT", "kind": "variable", "doc": "Output type version of RawDomainGroupT
.\nMatched raw unimodal data from multiple domains.\nKeys of the mapping are domains names and values are the domain data.
\n\nExample: \n\n\n \n
def fun () -> RawDomainGroupDT : \n return { \n "vision" : PIL . Image . Image ( "path/to/dog/picture.png" ), \n "language" : "This is a picture of a dog." , \n } \n
\n
\n \n\nNote: \n\n\n This type uses dict
s and is used for functions' outputs.\n Use RawDomainGroupT
for functions' inputs.
\n \n", "default_value": "dict[str, typing.Any]"}, "shimmer.types.LatentsDomainGroupT": {"fullname": "shimmer.types.LatentsDomainGroupT", "modulename": "shimmer.types", "qualname": "LatentsDomainGroupT", "kind": "variable", "doc": "Matched unimodal latent representations from multiple domains.\nKeys of the mapping are domains names and values are torch.Tensor
latent\nrepresentation of the domain.
\n\nExample: \n\n\n \n
def fun ( domain_group : LatentsDomainGroupT ): ... \n\n\nx = { \n "vision" : torch . Tensor ([ 0.0 , 1.0 , 0.0 , ... ]), \n "language" : torch . Tensor ([ 0.0 , 0.3 , 0.2 , ... ]), \n} \n\nfun ( x ) \n
\n
\n \n\nNote: \n\n\n This type uses collections.abc.Mapping
and is used for functions' inputs.\n Use LatentsDomainGroupDT
for functions' outputs.
\n \n This allows to be more generic and allow passing other mappings.
\n \n", "default_value": "collections.abc.Mapping[str, torch.Tensor]"}, "shimmer.types.LatentsDomainGroupDT": {"fullname": "shimmer.types.LatentsDomainGroupDT", "modulename": "shimmer.types", "qualname": "LatentsDomainGroupDT", "kind": "variable", "doc": "Matched unimodal latent representations from multiple domains.\nKeys of the dict are domains names and values are torch.Tensor
latent\nrepresentation of the domain.
\n\nExample: \n\n\n \n
def fun () -> LatentsDomainGroupDT : \n return { \n "vision" : torch . Tensor ([ 0.0 , 1.0 , 0.0 , ... ]), \n "language" : torch . Tensor ([ 0.0 , 0.3 , 0.2 , ... ]), \n } \n
\n
\n \n\nNote: \n\n\n This type uses dict
s and is used for functions' outputs.\n Use LatentsDomainGroupT
for functions' inputs.
\n \n", "default_value": "dict[str, torch.Tensor]"}, "shimmer.types.RawDomainGroupsT": {"fullname": "shimmer.types.RawDomainGroupsT", "modulename": "shimmer.types", "qualname": "RawDomainGroupsT", "kind": "variable", "doc": "Mapping of RawDomainGroupT
. Keys are frozenset of domains matched in the group.\nEach group is independent and contains different data (unpaired).
\n\nExample: \n\n\n \n
def fun () -> RawDomainGroupsDT : \n return { \n frozenset ([ "vision" ]): { \n "vision" : PIL . Image . Image ( "path/to/cat/picture.png" ), \n }, \n frozenset ([ "language" ]): { \n "language" : "This is a picture of a rabbit." , \n }, \n frozenset ([ "vision" , "language" ]): { \n "vision" : PIL . Image . Image ( "path/to/dog/picture.png" ), \n "language" : "This is a picture of a dog." , \n }, \n } \n
\n
\n \n\nNote: \n\n\n This type uses dict
s and is used for functions' outputs.\n Use RawDomainGroupsT
for functions' inputs.
\n \n", "default_value": "collections.abc.Mapping[frozenset[str], collections.abc.Mapping[str, typing.Any]]"}, "shimmer.types.RawDomainGroupsDT": {"fullname": "shimmer.types.RawDomainGroupsDT", "modulename": "shimmer.types", "qualname": "RawDomainGroupsDT", "kind": "variable", "doc": "Mapping of RawDomainGroupT
. Keys are frozenset of domains matched in the group.\nEach group is independent and contains different data (unpaired).
\n\nExample: \n\n\n \n
def fun () -> RawDomainGroupsDT : \n return { \n frozenset ([ "vision" ]): { \n "vision" : PIL . Image . Image ( "path/to/cat/picture.png" ), \n }, \n frozenset ([ "language" ]): { \n "language" : "This is a picture of a rabbit." , \n }, \n frozenset ([ "vision" , "language" ]): { \n "vision" : PIL . Image . Image ( "path/to/dog/picture.png" ), \n "language" : "This is a picture of a dog." , \n }, \n } \n
\n
\n \n\nNote: \n\n\n This type uses dict
s and is used for functions' outputs.\n Use RawDomainGroupsT
for functions' inputs.
\n \n", "default_value": "dict[frozenset[str], dict[str, typing.Any]]"}, "shimmer.types.LatentsDomainGroupsT": {"fullname": "shimmer.types.LatentsDomainGroupsT", "modulename": "shimmer.types", "qualname": "LatentsDomainGroupsT", "kind": "variable", "doc": "Mapping of LatentsDomainGroupT
. Keys are frozenset of domains matched in the group.\nEach group is independent and contains different data (unpaired).
\n\nExample: \n\n\n \n
def fun ( domain_group : LatentsDomainGroupsT ): ... \n\n\nx = { \n frozenset ([ "vision" ]): { \n "vision" : torch . Tensor ([ 1.0 , 0.0 , 0.3 , ... ]), \n }, \n frozenset ([ "language" ]): { \n "language" : torch . Tensor ([ 1.0 , 0.2 , 0.9 , ... ]), \n }, \n frozenset ([ "vision" , "language" ]): { \n "vision" : torch . Tensor ([ 0.0 , 1.0 , 0.0 , ... ]), \n "language" : torch . Tensor ([ 0.0 , 0.3 , 0.2 , ... ]), \n }, \n} \n\nfun ( x ) \n
\n
\n \n\nNote: \n\n\n This type uses collections.abc.Mapping
and is used for functions' inputs.\n Use LatentsDomainGroupsDT
for functions' outputs.
\n \n This allows to be more generic and allow passing other mappings.
\n \n", "default_value": "collections.abc.Mapping[frozenset[str], collections.abc.Mapping[str, torch.Tensor]]"}, "shimmer.types.LatentsDomainGroupsDT": {"fullname": "shimmer.types.LatentsDomainGroupsDT", "modulename": "shimmer.types", "qualname": "LatentsDomainGroupsDT", "kind": "variable", "doc": "Mapping of LatentsDomainGroupDT
.\nKeys are frozenset of domains matched in the group.\nEach group is independent and contains different data (unpaired).
\n\nExample: \n\n\n \n
def fun () -> LatentsDomainGroupsDT : \n return { \n frozenset ([ "vision" ]): { \n "vision" : torch . Tensor ([ 1.0 , 0.0 , 0.3 , ... ]), \n }, \n frozenset ([ "language" ]): { \n "language" : torch . Tensor ([ 1.0 , 0.2 , 0.9 , ... ]), \n }, \n frozenset ([ "vision" , "language" ]): { \n "vision" : torch . Tensor ([ 0.0 , 1.0 , 0.0 , ... ]), \n "language" : torch . Tensor ([ 0.0 , 0.3 , 0.2 , ... ]), \n }, \n } \n
\n
\n \n\nNote: \n\n\n This type uses dict
s and is used for functions' outputs.\n Use LatentsDomainGroupT
for functions' inputs.
\n \n", "default_value": "dict[frozenset[str], dict[str, torch.Tensor]]"}, "shimmer.types.ModelModeT": {"fullname": "shimmer.types.ModelModeT", "modulename": "shimmer.types", "qualname": "ModelModeT", "kind": "variable", "doc": "Mode used by pytorch lightning (train/val, ...).
\n\nWhen validating or testing in out-of-distribution data, \"val/ood\" or \"test/ood\" mode is\nused.
\n", "default_value": "typing.Literal['train', 'val', 'test', 'val/ood', 'test/ood']"}, "shimmer.modules.global_workspace": {"fullname": "shimmer.modules.global_workspace", "modulename": "shimmer.modules.global_workspace", "kind": "module", "doc": "
\n"}, "shimmer.modules.global_workspace.SchedulerArgs": {"fullname": "shimmer.modules.global_workspace.SchedulerArgs", "modulename": "shimmer.modules.global_workspace", "qualname": "SchedulerArgs", "kind": "class", "doc": "TypedDict of arguments passed to the OneCycle scheduler
\n", "bases": "typing.TypedDict"}, "shimmer.modules.global_workspace.SchedulerArgs.max_lr": {"fullname": "shimmer.modules.global_workspace.SchedulerArgs.max_lr", "modulename": "shimmer.modules.global_workspace", "qualname": "SchedulerArgs.max_lr", "kind": "variable", "doc": "Maximum learning rate
\n", "annotation": ": float"}, "shimmer.modules.global_workspace.SchedulerArgs.total_steps": {"fullname": "shimmer.modules.global_workspace.SchedulerArgs.total_steps", "modulename": "shimmer.modules.global_workspace", "qualname": "SchedulerArgs.total_steps", "kind": "variable", "doc": "Total number of steps
\n", "annotation": ": int"}, "shimmer.modules.global_workspace.GWPredictionsBase": {"fullname": "shimmer.modules.global_workspace.GWPredictionsBase", "modulename": "shimmer.modules.global_workspace", "qualname": "GWPredictionsBase", "kind": "class", "doc": "TypedDict of the output given when calling GlobalWorkspaceBase.predict
\n", "bases": "typing.TypedDict"}, "shimmer.modules.global_workspace.GWPredictionsBase.states": {"fullname": "shimmer.modules.global_workspace.GWPredictionsBase.states", "modulename": "shimmer.modules.global_workspace", "qualname": "GWPredictionsBase.states", "kind": "variable", "doc": "GW state representation from domain groups with only one domain.\nThe key represent the domain's name.
\n", "annotation": ": dict[str, torch.Tensor]"}, "shimmer.modules.global_workspace.GlobalWorkspaceBase": {"fullname": "shimmer.modules.global_workspace.GlobalWorkspaceBase", "modulename": "shimmer.modules.global_workspace", "qualname": "GlobalWorkspaceBase", "kind": "class", "doc": "Global Workspace Lightning Module.
\n\nThis is the base class to build the Global Workspace.
\n", "bases": "typing.Generic[~_T_gw_mod, ~_T_selection_mod, ~_T_loss_mod], lightning.pytorch.core.module.LightningModule"}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.gw_mod": {"fullname": "shimmer.modules.global_workspace.GlobalWorkspaceBase.gw_mod", "modulename": "shimmer.modules.global_workspace", "qualname": "GlobalWorkspaceBase.gw_mod", "kind": "variable", "doc": "a GWModuleBase
implementation.
\n"}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.selection_mod": {"fullname": "shimmer.modules.global_workspace.GlobalWorkspaceBase.selection_mod", "modulename": "shimmer.modules.global_workspace", "qualname": "GlobalWorkspaceBase.selection_mod", "kind": "variable", "doc": "A SelectionBase
implementation.
\n"}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.loss_mod": {"fullname": "shimmer.modules.global_workspace.GlobalWorkspaceBase.loss_mod", "modulename": "shimmer.modules.global_workspace", "qualname": "GlobalWorkspaceBase.loss_mod", "kind": "variable", "doc": "The module that computes losses of the GW
\n"}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.optim_lr": {"fullname": "shimmer.modules.global_workspace.GlobalWorkspaceBase.optim_lr", "modulename": "shimmer.modules.global_workspace", "qualname": "GlobalWorkspaceBase.optim_lr", "kind": "variable", "doc": "
\n"}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.optim_weight_decay": {"fullname": "shimmer.modules.global_workspace.GlobalWorkspaceBase.optim_weight_decay", "modulename": "shimmer.modules.global_workspace", "qualname": "GlobalWorkspaceBase.optim_weight_decay", "kind": "variable", "doc": "
\n"}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.scheduler_args": {"fullname": "shimmer.modules.global_workspace.GlobalWorkspaceBase.scheduler_args", "modulename": "shimmer.modules.global_workspace", "qualname": "GlobalWorkspaceBase.scheduler_args", "kind": "variable", "doc": "
\n"}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.domain_mods": {"fullname": "shimmer.modules.global_workspace.GlobalWorkspaceBase.domain_mods", "modulename": "shimmer.modules.global_workspace", "qualname": "GlobalWorkspaceBase.domain_mods", "kind": "variable", "doc": "
\n", "annotation": ": collections.abc.Mapping[str, shimmer.modules.domain.DomainModule]"}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.workspace_dim": {"fullname": "shimmer.modules.global_workspace.GlobalWorkspaceBase.workspace_dim", "modulename": "shimmer.modules.global_workspace", "qualname": "GlobalWorkspaceBase.workspace_dim", "kind": "variable", "doc": "Dimension of the GW.
\n", "annotation": ": int"}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.encode_and_fuse": {"fullname": "shimmer.modules.global_workspace.GlobalWorkspaceBase.encode_and_fuse", "modulename": "shimmer.modules.global_workspace", "qualname": "GlobalWorkspaceBase.encode_and_fuse", "kind": "function", "doc": "Encode a group of latent representations into the GW representation.
\n\nArguments: \n\n\nx (LatentsDomainGroupsT
): the input domain representations. \nselection_scores (Mapping[str, torch.Tensor]
): \n \n\nReturns: \n\n\n dict[frozenset[str], torch.Tensor]
: the GW representations.
\n \n", "signature": "(\tself , \tx : collections . abc . Mapping [ frozenset [ str ], collections . abc . Mapping [ str , torch . Tensor ]] , \tselection_module : shimmer . modules . selection . SelectionBase ) -> dict [ frozenset [ str ], torch . Tensor ] : ", "funcdef": "def"}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.encode": {"fullname": "shimmer.modules.global_workspace.GlobalWorkspaceBase.encode", "modulename": "shimmer.modules.global_workspace", "qualname": "GlobalWorkspaceBase.encode", "kind": "function", "doc": "Encode a group of latent representations into the pre-fusion GW representation.
\n\nArguments: \n\n\nx (LatentsDomainGroupsT
): the input domain representations. \n \n\nReturns: \n\n\n LatensDomainGroupsDT
: the GW representations.
\n \n", "signature": "(\tself , \tx : collections . abc . Mapping [ frozenset [ str ], collections . abc . Mapping [ str , torch . Tensor ]] ) -> dict [ frozenset [ str ], dict [ str , torch . Tensor ]] : ", "funcdef": "def"}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.fuse": {"fullname": "shimmer.modules.global_workspace.GlobalWorkspaceBase.fuse", "modulename": "shimmer.modules.global_workspace", "qualname": "GlobalWorkspaceBase.fuse", "kind": "function", "doc": "Fuses a group of latent representations into the GW representation.
\n\nArguments: \n\n\nx (LatentsDomainGroupsT
): the pre-fusion latent representations \nselection_scores (Mapping[frozenset[str], Mapping[str, torch.Tensor]]
): selection scores for each group \n \n\nReturns: \n\n\n dict[frozenset[str], torch.Tensor]
: GW representation of each group
\n \n", "signature": "(\tself , \tx : collections . abc . Mapping [ frozenset [ str ], collections . abc . Mapping [ str , torch . Tensor ]] , \tselection_scores : collections . abc . Mapping [ frozenset [ str ], collections . abc . Mapping [ str , torch . Tensor ]] ) -> dict [ frozenset [ str ], torch . Tensor ] : ", "funcdef": "def"}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.decode": {"fullname": "shimmer.modules.global_workspace.GlobalWorkspaceBase.decode", "modulename": "shimmer.modules.global_workspace", "qualname": "GlobalWorkspaceBase.decode", "kind": "function", "doc": "Decode the group GW representation into given domains
.
\n\nArguments: \n\n\nz (torch.Tensor
): the GW representation. \ndomains (Iterable[str]
): iterable of domains to decode. \n \n\nReturns: \n\n\n dict[str, torch.Tensor]
: the decoded unimodal representations.
\n \n", "signature": "(\tself , \tz : collections . abc . Mapping [ frozenset [ str ], torch . Tensor ] , \tdomains : collections . abc . Iterable [ str ] | None = None ) -> dict [ frozenset [ str ], dict [ str , torch . Tensor ]] : ", "funcdef": "def"}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.batch_gw_states": {"fullname": "shimmer.modules.global_workspace.GlobalWorkspaceBase.batch_gw_states", "modulename": "shimmer.modules.global_workspace", "qualname": "GlobalWorkspaceBase.batch_gw_states", "kind": "function", "doc": "Comptues GW states of a batch of groups of domains.
\n\nArguments: \n\n\nlatent_domains (LatentsT
): the batch of groups of domains \n \n\nReturns: \n\n\n dict[str, torch.Tensor]
: states for each domain.
\n \n", "signature": "(\tself , \tlatent_domains : collections . abc . Mapping [ frozenset [ str ], collections . abc . Mapping [ str , torch . Tensor ]] ) -> dict [ str , torch . Tensor ] : ", "funcdef": "def"}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.encode_domain": {"fullname": "shimmer.modules.global_workspace.GlobalWorkspaceBase.encode_domain", "modulename": "shimmer.modules.global_workspace", "qualname": "GlobalWorkspaceBase.encode_domain", "kind": "function", "doc": "Encodes a domain from the domain data into the unimodal representation.
\n\nThis is a convenient proxy for the DomainModule.encode
method and is\nequivalent to:
\n\n\n
self . domain_mods [ name ] . encode ( domain ) \n
\n
\n\nArguments: \n\n\ndomain (Any
): the domain data \nname (str
): domain name to encode \n \n\nReturns: \n\n\n torch.Tensor
: the domain's unimodal representation.
\n \n", "signature": "(self , domain : Any , name : str ) -> torch . Tensor : ", "funcdef": "def"}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.encode_domains": {"fullname": "shimmer.modules.global_workspace.GlobalWorkspaceBase.encode_domains", "modulename": "shimmer.modules.global_workspace", "qualname": "GlobalWorkspaceBase.encode_domains", "kind": "function", "doc": "Encode all domains in the batch.
\n\nArguments: \n\n\nbatch (RawDomainGroupsT
): the batch of\ndomain groups with raw unimodal data to encode into groups of latent\nrepresentations. \n \n\nReturns: \n\n\n LatentsDomainGroupsDT
: the domains' unimodal representations.
\n \n", "signature": "(\tself , \tbatch : collections . abc . Mapping [ frozenset [ str ], collections . abc . Mapping [ str , typing . Any ]] ) -> dict [ frozenset [ str ], dict [ str , torch . Tensor ]] : ", "funcdef": "def"}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.decode_domain": {"fullname": "shimmer.modules.global_workspace.GlobalWorkspaceBase.decode_domain", "modulename": "shimmer.modules.global_workspace", "qualname": "GlobalWorkspaceBase.decode_domain", "kind": "function", "doc": "Decodes a domain from the unimodal representation into the domain data.
\n\nThis is a convenient proxy for the DomainModule.encode
method and is\nequivalent to:
\n\n\n
self . domain_mods [ name ] . decode ( domain ) \n
\n
\n\nArguments: \n\n\ndomain (torch.Tensor
): the domain data \nname (str
): domain name to encode \n \n\nReturns: \n\n\n Any
: the domain's raw data.
\n \n", "signature": "(self , domain : torch . Tensor , name : str ) -> Any : ", "funcdef": "def"}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.decode_domains": {"fullname": "shimmer.modules.global_workspace.GlobalWorkspaceBase.decode_domains", "modulename": "shimmer.modules.global_workspace", "qualname": "GlobalWorkspaceBase.decode_domains", "kind": "function", "doc": "Decodes all domains in the batch.
\n\nArguments: \n\n\nbatch (LatentsDomainGroupsT
): the batch of\ndomain groups with unimodal latent representation to decode into\ngroups of raw data. \n \n\nReturns: \n\n\n LatentsDomainGroupsDT
: the domains' raw data.
\n \n", "signature": "(\tself , \tlatents_domain : collections . abc . Mapping [ frozenset [ str ], collections . abc . Mapping [ str , torch . Tensor ]] ) -> dict [ frozenset [ str ], dict [ str , typing . Any ]] : ", "funcdef": "def"}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.generic_step": {"fullname": "shimmer.modules.global_workspace.GlobalWorkspaceBase.generic_step", "modulename": "shimmer.modules.global_workspace", "qualname": "GlobalWorkspaceBase.generic_step", "kind": "function", "doc": "The generic step used in training_step
, validation_step
and\ntest_step
.
\n\nArguments: \n\n\nbatch (RawDomainGroupsT
): the batch of groups of raw unimodal data. \nmode (ModelModeT
): \n \n\nReturns: \n\n\n torch.Tensor
: the loss to train on.
\n \n", "signature": "(\tself , \tbatch : collections . abc . Mapping [ frozenset [ str ], collections . abc . Mapping [ str , typing . Any ]] , \tmode : Literal [ 'train' , 'val' , 'test' , 'val/ood' , 'test/ood' ] ) -> torch . Tensor : ", "funcdef": "def"}, "shimmer.modules.global_workspace.freeze_domain_modules": {"fullname": "shimmer.modules.global_workspace.freeze_domain_modules", "modulename": "shimmer.modules.global_workspace", "qualname": "freeze_domain_modules", "kind": "function", "doc": "Freezes weights and set to eval mode the domain modules.
\n\n\n\n
The output is casted as dict[str, DomainModule]
type for better\nauto-completion, but is actually a torch ModuleDict
.
\n\n
\n\nArguments: \n\n\ndomain_mods (Mapping[str, DomainModule]
): mapping of domain modules to freeze \n \n\nReturns: \n\n\n ModuleDict
: frozen modules.
\n \n", "signature": "(\tdomain_mods : collections . abc . Mapping [ str , shimmer . modules . domain . DomainModule ] ) -> dict [ str , shimmer . modules . domain . DomainModule ] : ", "funcdef": "def"}, "shimmer.modules.global_workspace.GWPredictions": {"fullname": "shimmer.modules.global_workspace.GWPredictions", "modulename": "shimmer.modules.global_workspace", "qualname": "GWPredictions", "kind": "class", "doc": "TypedDict of the output given when calling GlobalWorkspaceBase.predict
\n", "bases": "builtins.dict"}, "shimmer.modules.global_workspace.GWPredictions.demi_cycles": {"fullname": "shimmer.modules.global_workspace.GWPredictions.demi_cycles", "modulename": "shimmer.modules.global_workspace", "qualname": "GWPredictions.demi_cycles", "kind": "variable", "doc": "Demi-cycle predictions of the model for each domain. Only computed on domain\ngroups with only one domain.
\n", "annotation": ": dict[str, torch.Tensor]"}, "shimmer.modules.global_workspace.GWPredictions.cycles": {"fullname": "shimmer.modules.global_workspace.GWPredictions.cycles", "modulename": "shimmer.modules.global_workspace", "qualname": "GWPredictions.cycles", "kind": "variable", "doc": "Cycle predictions of the model from one domain through another one.\nOnly computed on domain groups with more than one domain.\nThe keys are tuple with start domain and intermediary domain.
\n", "annotation": ": dict[tuple[str, str], torch.Tensor]"}, "shimmer.modules.global_workspace.GWPredictions.translations": {"fullname": "shimmer.modules.global_workspace.GWPredictions.translations", "modulename": "shimmer.modules.global_workspace", "qualname": "GWPredictions.translations", "kind": "variable", "doc": "Translation predictions of the model from one domain through another one.
\n\nOnly computed on domain groups with more than one domain.\nThe keys are tuples with start domain and target domain.
\n", "annotation": ": dict[tuple[str, str], torch.Tensor]"}, "shimmer.modules.global_workspace.GWPredictions.states": {"fullname": "shimmer.modules.global_workspace.GWPredictions.states", "modulename": "shimmer.modules.global_workspace", "qualname": "GWPredictions.states", "kind": "variable", "doc": "
\n", "annotation": ": dict[str, torch.Tensor]"}, "shimmer.modules.global_workspace.GlobalWorkspace2Domains": {"fullname": "shimmer.modules.global_workspace.GlobalWorkspace2Domains", "modulename": "shimmer.modules.global_workspace", "qualname": "GlobalWorkspace2Domains", "kind": "class", "doc": "A simple 2-domains max flavor of GlobalWorkspaceBase.
\n\nThis is used to simplify a Global Workspace instanciation and only overrides the\n__init__
method.
\n", "bases": "shimmer.modules.global_workspace.GlobalWorkspaceBase[shimmer.modules.gw_module.GWModule, shimmer.modules.selection.SingleDomainSelection, shimmer.modules.losses.GWLosses2Domains]"}, "shimmer.modules.global_workspace.GlobalWorkspace2Domains.__init__": {"fullname": "shimmer.modules.global_workspace.GlobalWorkspace2Domains.__init__", "modulename": "shimmer.modules.global_workspace", "qualname": "GlobalWorkspace2Domains.__init__", "kind": "function", "doc": "Initializes a Global Workspace
\n\nArguments: \n\n\ndomain_mods (Mapping[str, DomainModule]
): mapping of the domains\nconnected to the GW. Keys are domain names, values are the\nDomainModule
. \ngw_encoders (Mapping[str, torch.nn.Module]
): mapping for each domain\nname to a torch.nn.Module
class which role is to encode a\nunimodal latent representations into a GW representation (pre fusion). \ngw_decoders (Mapping[str, torch.nn.Module]
): mapping for each domain\nname to a torch.nn.Module
class which role is to decode a\nGW representation into a unimodal latent representations. \nworkspace_dim (int
): dimension of the GW. \nloss_coefs (LossCoefs
): loss coefficients \noptim_lr (float
): learning rate \noptim_weight_decay (float
): weight decay \nscheduler_args (SchedulerArgs | None
): optimization scheduler's arguments \nlearn_logit_scale (bool
): whether to learn the contrastive learning\ncontrastive loss when using the default contrastive loss. \ncontrastive_loss (ContrastiveLossType | None
): a contrastive loss\nfunction used for alignment. learn_logit_scale
will not affect custom\ncontrastive losses. \n \n", "signature": "(\tdomain_mods : collections . abc . Mapping [ str , shimmer . modules . domain . DomainModule ] , \tgw_encoders : collections . abc . Mapping [ str , torch . nn . modules . module . Module ] , \tgw_decoders : collections . abc . Mapping [ str , torch . nn . modules . module . Module ] , \tworkspace_dim : int , \tloss_coefs : shimmer . modules . losses . LossCoefs , \toptim_lr : float = 0.001 , \toptim_weight_decay : float = 0.0 , \tscheduler_args : shimmer . modules . global_workspace . SchedulerArgs | None = None , \tlearn_logit_scale : bool = False , \tcontrastive_loss : collections . abc . Callable [[ torch . Tensor , torch . Tensor ], shimmer . modules . domain . LossOutput ] | None = None ) "}, "shimmer.modules.global_workspace.GlobalWorkspace2Domains.forward": {"fullname": "shimmer.modules.global_workspace.GlobalWorkspace2Domains.forward", "modulename": "shimmer.modules.global_workspace", "qualname": "GlobalWorkspace2Domains.forward", "kind": "function", "doc": "Computes demi-cycles, cycles, and translations.
\n\nArguments: \n\n\nlatent_domains (LatentsT
): Groups of domains for the computation. \n \n\nReturns: \n\n\n GWPredictions
: the predictions on the batch.
\n \n", "signature": "(\tself , \tlatent_domains : collections . abc . Mapping [ frozenset [ str ], collections . abc . Mapping [ str , torch . Tensor ]] ) -> shimmer . modules . global_workspace . GWPredictions : ", "funcdef": "def"}, "shimmer.modules.global_workspace.GlobalWorkspace": {"fullname": "shimmer.modules.global_workspace.GlobalWorkspace", "modulename": "shimmer.modules.global_workspace", "qualname": "GlobalWorkspace", "kind": "class", "doc": "The 2-domain fusion (with broadcast loss) flavor of GlobalWorkspaceBase.
\n\nThis is used to simplify a Global Workspace instanciation and only overrides the\n__init__
method.
\n", "bases": "shimmer.modules.global_workspace.GlobalWorkspaceBase[shimmer.modules.gw_module.GWModule, shimmer.modules.selection.RandomSelection, shimmer.modules.losses.GWLosses]"}, "shimmer.modules.global_workspace.GlobalWorkspace.__init__": {"fullname": "shimmer.modules.global_workspace.GlobalWorkspace.__init__", "modulename": "shimmer.modules.global_workspace", "qualname": "GlobalWorkspace.__init__", "kind": "function", "doc": "Initializes a Global Workspace
\n\nArguments: \n\n\ndomain_mods (Mapping[str, DomainModule]
): mapping of the domains\nconnected to the GW. Keys are domain names, values are the\nDomainModule
. \ngw_encoders (Mapping[str, torch.nn.Module]
): mapping for each domain\nname to a torch.nn.Module
class which role is to encode a\nunimodal latent representations into a GW representation (pre fusion). \ngw_decoders (Mapping[str, torch.nn.Module]
): mapping for each domain\nname to a torch.nn.Module
class which role is to decode a\nGW representation into a unimodal latent representations. \nworkspace_dim (int
): dimension of the GW. \nloss_coefs (BroadcastLossCoefs
): loss coefs for the losses. \nselection_temperature (float
): temperature value for the RandomSelection\nmodule. \noptim_lr (float
): learning rate \noptim_weight_decay (float
): weight decay \nscheduler_args (SchedulerArgs | None
): optimization scheduler's arguments \nlearn_logit_scale (bool
): whether to learn the contrastive learning\ncontrastive loss when using the default contrastive loss. \ncontrastive_loss (ContrastiveLossType | None
): a contrastive loss\nfunction used for alignment. learn_logit_scale
will not affect custom\ncontrastive losses. \n \n", "signature": "(\tdomain_mods : collections . abc . Mapping [ str , shimmer . modules . domain . DomainModule ] , \tgw_encoders : collections . abc . Mapping [ str , torch . nn . modules . module . Module ] , \tgw_decoders : collections . abc . Mapping [ str , torch . nn . modules . module . Module ] , \tworkspace_dim : int , \tloss_coefs : shimmer . modules . losses . BroadcastLossCoefs , \tselection_temperature : float = 0.2 , \toptim_lr : float = 0.001 , \toptim_weight_decay : float = 0.0 , \tscheduler_args : shimmer . modules . global_workspace . SchedulerArgs | None = None , \tlearn_logit_scale : bool = False , \tcontrastive_loss : collections . abc . Callable [[ torch . Tensor , torch . Tensor ], shimmer . modules . domain . LossOutput ] | None = None ) "}, "shimmer.modules.global_workspace.GlobalWorkspace.forward": {"fullname": "shimmer.modules.global_workspace.GlobalWorkspace.forward", "modulename": "shimmer.modules.global_workspace", "qualname": "GlobalWorkspace.forward", "kind": "function", "doc": "Computes demi-cycles, cycles, and translations.
\n\nArguments: \n\n\nlatent_domains (LatentsT
): Groups of domains for the computation. \n \n\nReturns: \n\n\n GWPredictions
: the predictions on the batch.
\n \n", "signature": "(\tself , \tlatent_domains : collections . abc . Mapping [ frozenset [ str ], collections . abc . Mapping [ str , torch . Tensor ]] ) -> shimmer . modules . global_workspace . GWPredictions : ", "funcdef": "def"}, "shimmer.modules.global_workspace.GlobalWorkspaceBayesian": {"fullname": "shimmer.modules.global_workspace.GlobalWorkspaceBayesian", "modulename": "shimmer.modules.global_workspace", "qualname": "GlobalWorkspaceBayesian", "kind": "class", "doc": "A simple 2-domains max GlobalWorkspaceBase with a Bayesian base uncertainty\nprediction.
\n\nThis is used to simplify a Global Workspace instanciation and only overrides the\n__init__
method.
\n", "bases": "shimmer.modules.global_workspace.GlobalWorkspaceBase[shimmer.modules.gw_module.GWModuleBayesian, shimmer.modules.selection.FixedSharedSelection, shimmer.modules.losses.GWLossesBayesian]"}, "shimmer.modules.global_workspace.GlobalWorkspaceBayesian.__init__": {"fullname": "shimmer.modules.global_workspace.GlobalWorkspaceBayesian.__init__", "modulename": "shimmer.modules.global_workspace", "qualname": "GlobalWorkspaceBayesian.__init__", "kind": "function", "doc": "Initializes a Global Workspace
\n\nArguments: \n\n\ndomain_mods (Mapping[str, DomainModule]
): mapping of the domains\nconnected to the GW. Keys are domain names, values are the\nDomainModule
. \ngw_encoders (Mapping[str, torch.nn.Module]
): mapping for each domain\nname to a torch.nn.Module
class which role is to encode a\nunimodal latent representations into a GW representation (pre fusion). \ngw_decoders (Mapping[str, torch.nn.Module]
): mapping for each domain\nname to a torch.nn.Module
class which role is to decode a\nGW representation into a unimodal latent representations. \nworkspace_dim (int
): dimension of the GW. \nloss_coefs (LossCoefs
): loss coefficients \nsensitivity_selection (float
): sensivity coef $c'_1$ \nsensitivity_precision (float
): sensitivity coef $c'_2$ \noptim_lr (float
): learning rate \noptim_weight_decay (float
): weight decay \nscheduler_args (SchedulerArgs | None
): optimization scheduler's arguments \nlearn_logit_scale (bool
): whether to learn the contrastive learning\ncontrastive loss when using the default contrastive loss. \nuse_normalized_constrastive (bool
): whether to use the normalized cont\nloss by the precision coefs \ncontrastive_loss (ContrastiveLossType | None
): a contrastive loss\nfunction used for alignment. learn_logit_scale
will not affect custom\ncontrastive losses. \nprecision_softmax_temp (float
): temperature to use in softmax of\nprecision \n \n", "signature": "(\tdomain_mods : collections . abc . Mapping [ str , shimmer . modules . domain . DomainModule ] , \tgw_encoders : collections . abc . Mapping [ str , torch . nn . modules . module . Module ] , \tgw_decoders : collections . abc . Mapping [ str , torch . nn . modules . module . Module ] , \tworkspace_dim : int , \tloss_coefs : shimmer . modules . losses . BroadcastLossCoefs , \tsensitivity_selection : float = 1 , \tsensitivity_precision : float = 1 , \toptim_lr : float = 0.001 , \toptim_weight_decay : float = 0.0 , \tscheduler_args : shimmer . modules . global_workspace . SchedulerArgs | None = None , \tlearn_logit_scale : bool = False , \tuse_normalized_constrastive : bool = True , \tcontrastive_loss : collections . abc . Callable [[ torch . Tensor , torch . Tensor ], shimmer . modules . domain . LossOutput ] | None = None , \tprecision_softmax_temp : float = 0.01 ) "}, "shimmer.modules.global_workspace.GlobalWorkspaceBayesian.forward": {"fullname": "shimmer.modules.global_workspace.GlobalWorkspaceBayesian.forward", "modulename": "shimmer.modules.global_workspace", "qualname": "GlobalWorkspaceBayesian.forward", "kind": "function", "doc": "Computes demi-cycles, cycles, and translations.
\n\nArguments: \n\n\nlatent_domains (LatentsT
): Groups of domains for the computation. \n \n\nReturns: \n\n\n GWPredictions
: the predictions on the batch.
\n \n", "signature": "(\tself , \tlatent_domains : collections . abc . Mapping [ frozenset [ str ], collections . abc . Mapping [ str , torch . Tensor ]] ) -> shimmer . modules . global_workspace . GWPredictions : ", "funcdef": "def"}, "shimmer.modules.global_workspace.pretrained_global_workspace": {"fullname": "shimmer.modules.global_workspace.pretrained_global_workspace", "modulename": "shimmer.modules.global_workspace", "qualname": "pretrained_global_workspace", "kind": "function", "doc": "Load a GlobalWorkspace
flavor of GlobalWorkspaceBase
from a checkpoint.
\n\nArguments: \n\n\ncheckpoint_path (str | Path
): path to checkpoint \ndomain_mods (Mapping[str, DomainModule]
): mapping of the domains\nconnected to the GW. Keys are domain names, values are the\nDomainModule
. \ngw_encoders (Mapping[str, torch.nn.Module]
): mapping for each domain\nname to a torch.nn.Module
class which role is to encode a\nunimodal latent representations into a GW representation (pre fusion). \ngw_decoders (Mapping[str, torch.nn.Module]
): mapping for each domain\nname to a torch.nn.Module
class which role is to decode a\nGW representation into a unimodal latent representations. \nworkspace_dim (int
): dimension of the GW. \nloss_coefs (LossCoefs
): loss coefficients \ncontrastive_loss (ContrastiveLossType
): a contrastive loss\nfunction used for alignment. learn_logit_scale
will not affect custom\ncontrastive losses. \n**kwargs: additional arguments to pass to\nGlobalWorkspace.load_from_checkpoint
. \n \n\nReturns: \n\n\n GlobalWorkspace
: the pretrained GlobalWorkspace
.
\n \n\nRaises: \n\n\nTypeError
: if loaded type is not GlobalWorkspace
. \n \n", "signature": "(\tcheckpoint_path : str | pathlib . Path , \tdomain_mods : collections . abc . Mapping [ str , shimmer . modules . domain . DomainModule ] , \tgw_encoders : collections . abc . Mapping [ str , torch . nn . modules . module . Module ] , \tgw_decoders : collections . abc . Mapping [ str , torch . nn . modules . module . Module ] , \tworkspace_dim : int , \tloss_coefs : shimmer . modules . losses . LossCoefs , \tcontrastive_fn : collections . abc . Callable [[ torch . Tensor , torch . Tensor ], shimmer . modules . domain . LossOutput ] , \t** kwargs ) -> shimmer . modules . global_workspace . GlobalWorkspace2Domains : ", "funcdef": "def"}, "shimmer.modules.domain": {"fullname": "shimmer.modules.domain", "modulename": "shimmer.modules.domain", "kind": "module", "doc": "
\n"}, "shimmer.modules.domain.LossOutput": {"fullname": "shimmer.modules.domain.LossOutput", "modulename": "shimmer.modules.domain", "qualname": "LossOutput", "kind": "class", "doc": "This is a python dataclass use as a returned value for losses.\nIt keeps track of what is used for training (loss
) and what is used\nonly for logging (metrics
).
\n"}, "shimmer.modules.domain.LossOutput.__init__": {"fullname": "shimmer.modules.domain.LossOutput.__init__", "modulename": "shimmer.modules.domain", "qualname": "LossOutput.__init__", "kind": "function", "doc": "
\n", "signature": "(loss : torch . Tensor , metrics : dict [ str , torch . Tensor ] = < factory > ) "}, "shimmer.modules.domain.LossOutput.loss": {"fullname": "shimmer.modules.domain.LossOutput.loss", "modulename": "shimmer.modules.domain", "qualname": "LossOutput.loss", "kind": "variable", "doc": "Loss used during training.
\n", "annotation": ": torch.Tensor"}, "shimmer.modules.domain.LossOutput.metrics": {"fullname": "shimmer.modules.domain.LossOutput.metrics", "modulename": "shimmer.modules.domain", "qualname": "LossOutput.metrics", "kind": "variable", "doc": "Some additional metrics to log (not used during training).
\n", "annotation": ": dict[str, torch.Tensor]"}, "shimmer.modules.domain.LossOutput.all": {"fullname": "shimmer.modules.domain.LossOutput.all", "modulename": "shimmer.modules.domain", "qualname": "LossOutput.all", "kind": "variable", "doc": "Returns a dict with all metrics and loss with \"loss\" key.
\n", "annotation": ": dict[str, torch.Tensor]"}, "shimmer.modules.domain.DomainModule": {"fullname": "shimmer.modules.domain.DomainModule", "modulename": "shimmer.modules.domain", "qualname": "DomainModule", "kind": "class", "doc": "Base class for a DomainModule that defines domain specific modules of the GW.
\n", "bases": "lightning.pytorch.core.module.LightningModule"}, "shimmer.modules.domain.DomainModule.__init__": {"fullname": "shimmer.modules.domain.DomainModule.__init__", "modulename": "shimmer.modules.domain", "qualname": "DomainModule.__init__", "kind": "function", "doc": "Initializes a DomainModule.
\n\nArguments: \n\n\nlatent_dim (int
): latent dimension of the unimodal module \n \n", "signature": "(latent_dim : int ) "}, "shimmer.modules.domain.DomainModule.latent_dim": {"fullname": "shimmer.modules.domain.DomainModule.latent_dim", "modulename": "shimmer.modules.domain", "qualname": "DomainModule.latent_dim", "kind": "variable", "doc": "The latent dimension of the module.
\n"}, "shimmer.modules.domain.DomainModule.encode": {"fullname": "shimmer.modules.domain.DomainModule.encode", "modulename": "shimmer.modules.domain", "qualname": "DomainModule.encode", "kind": "function", "doc": "Encode the domain data into a unimodal representation.
\n\nArguments: \n\n\nx (Any
): data of the domain. \n \n\nReturns: \n\n\n torch.Tensor
: a unimodal representation.
\n \n", "signature": "(self , x : Any ) -> torch . Tensor : ", "funcdef": "def"}, "shimmer.modules.domain.DomainModule.decode": {"fullname": "shimmer.modules.domain.DomainModule.decode", "modulename": "shimmer.modules.domain", "qualname": "DomainModule.decode", "kind": "function", "doc": "Decode data from unimodal representation back to the domain data.
\n\nArguments: \n\n\nz (torch.Tensor
): unimodal representation of the domain. \n \n\nReturns: \n\n\n Any
: the original domain data.
\n \n", "signature": "(self , z : torch . Tensor ) -> Any : ", "funcdef": "def"}, "shimmer.modules.domain.DomainModule.compute_loss": {"fullname": "shimmer.modules.domain.DomainModule.compute_loss", "modulename": "shimmer.modules.domain", "qualname": "DomainModule.compute_loss", "kind": "function", "doc": "Generic loss computation the modality.
\n\nArguments: \n\n\npred (torch.Tensor
): prediction of the model \ntarget (torch.Tensor
): target tensor \n \n\nResults: \n\n\n LossOutput
: LossOuput with training loss and additional metrics.
\n \n", "signature": "(\tself , \tpred : torch . Tensor , \ttarget : torch . Tensor ) -> shimmer . modules . domain . LossOutput : ", "funcdef": "def"}, "shimmer.modules.domain.DomainModule.compute_dcy_loss": {"fullname": "shimmer.modules.domain.DomainModule.compute_dcy_loss", "modulename": "shimmer.modules.domain", "qualname": "DomainModule.compute_dcy_loss", "kind": "function", "doc": "Computes the loss for a demi-cycle. Override if the demi-cycle loss is\ndifferent that the generic loss.
\n\nArguments: \n\n\npred (torch.Tensor
): prediction of the model \ntarget (torch.Tensor
): target tensor \n \n\nResults: \n\n\n LossOutput
: LossOuput with training loss and additional metrics.
\n \n", "signature": "(\tself , \tpred : torch . Tensor , \ttarget : torch . Tensor ) -> shimmer . modules . domain . LossOutput : ", "funcdef": "def"}, "shimmer.modules.domain.DomainModule.compute_cy_loss": {"fullname": "shimmer.modules.domain.DomainModule.compute_cy_loss", "modulename": "shimmer.modules.domain", "qualname": "DomainModule.compute_cy_loss", "kind": "function", "doc": "Computes the loss for a cycle. Override if the cycle loss is\ndifferent that the generic loss.
\n\nArguments: \n\n\npred (torch.Tensor
): prediction of the model \ntarget (torch.Tensor
): target tensor \n \n\nResults: \n\n\n LossOutput
: LossOuput with training loss and additional metrics.
\n \n", "signature": "(\tself , \tpred : torch . Tensor , \ttarget : torch . Tensor ) -> shimmer . modules . domain . LossOutput : ", "funcdef": "def"}, "shimmer.modules.domain.DomainModule.compute_tr_loss": {"fullname": "shimmer.modules.domain.DomainModule.compute_tr_loss", "modulename": "shimmer.modules.domain", "qualname": "DomainModule.compute_tr_loss", "kind": "function", "doc": "Computes the loss for a translation. Override if the translation loss is\ndifferent that the generic loss.
\n\nArguments: \n\n\npred (torch.Tensor
): prediction of the model \ntarget (torch.Tensor
): target tensor \n \n\nResults: \n\n\n LossOutput
: LossOuput with training loss and additional metrics.
\n \n", "signature": "(\tself , \tpred : torch . Tensor , \ttarget : torch . Tensor ) -> shimmer . modules . domain . LossOutput : ", "funcdef": "def"}, "shimmer.modules.domain.DomainModule.compute_broadcast_loss": {"fullname": "shimmer.modules.domain.DomainModule.compute_broadcast_loss", "modulename": "shimmer.modules.domain", "qualname": "DomainModule.compute_broadcast_loss", "kind": "function", "doc": "Computes the loss for a broadcast (fusion). Override if the broadcast loss is\ndifferent that the generic loss.
\n\nArguments: \n\n\npred (torch.Tensor
): prediction of the model \ntarget (torch.Tensor
): target tensor \n \n\nResults: \n\n\n LossOutput
: LossOuput with training loss and additional metrics.
\n \n", "signature": "(\tself , \tpred : torch . Tensor , \ttarget : torch . Tensor ) -> shimmer . modules . domain . LossOutput : ", "funcdef": "def"}, "shimmer.modules.gw_module": {"fullname": "shimmer.modules.gw_module", "modulename": "shimmer.modules.gw_module", "kind": "module", "doc": "
\n"}, "shimmer.modules.gw_module.get_n_layers": {"fullname": "shimmer.modules.gw_module.get_n_layers", "modulename": "shimmer.modules.gw_module", "qualname": "get_n_layers", "kind": "function", "doc": "Makes a list of n_layers
nn.Linear
layers with nn.ReLU
.
\n\nArguments: \n\n\nn_layers (int
): number of layers \nhidden_dim (int
): size of the hidden dimension \n \n\nReturns: \n\n\n list[nn.Module]
: list of linear and relu layers.
\n \n", "signature": "(n_layers : int , hidden_dim : int ) -> list [ torch . nn . modules . module . Module ] : ", "funcdef": "def"}, "shimmer.modules.gw_module.GWDecoder": {"fullname": "shimmer.modules.gw_module.GWDecoder", "modulename": "shimmer.modules.gw_module", "qualname": "GWDecoder", "kind": "class", "doc": "A Decoder network for GWModules.
\n", "bases": "torch.nn.modules.container.Sequential"}, "shimmer.modules.gw_module.GWDecoder.__init__": {"fullname": "shimmer.modules.gw_module.GWDecoder.__init__", "modulename": "shimmer.modules.gw_module", "qualname": "GWDecoder.__init__", "kind": "function", "doc": "Initializes the decoder.
\n\nArguments: \n\n\nin_dim (int
): input dimension \nhidden_dim (int
): hidden dimension \nout_dim (int
): output dimension \nn_layers (int
): number of hidden layers. The total number of layers\nwill be n_layers
+ 2 (one before, one after). \n \n", "signature": "(in_dim : int , hidden_dim : int , out_dim : int , n_layers : int ) "}, "shimmer.modules.gw_module.GWDecoder.in_dim": {"fullname": "shimmer.modules.gw_module.GWDecoder.in_dim", "modulename": "shimmer.modules.gw_module", "qualname": "GWDecoder.in_dim", "kind": "variable", "doc": "input dimension
\n"}, "shimmer.modules.gw_module.GWDecoder.hidden_dim": {"fullname": "shimmer.modules.gw_module.GWDecoder.hidden_dim", "modulename": "shimmer.modules.gw_module", "qualname": "GWDecoder.hidden_dim", "kind": "variable", "doc": "hidden dimension
\n"}, "shimmer.modules.gw_module.GWDecoder.out_dim": {"fullname": "shimmer.modules.gw_module.GWDecoder.out_dim", "modulename": "shimmer.modules.gw_module", "qualname": "GWDecoder.out_dim", "kind": "variable", "doc": "output dimension
\n"}, "shimmer.modules.gw_module.GWDecoder.n_layers": {"fullname": "shimmer.modules.gw_module.GWDecoder.n_layers", "modulename": "shimmer.modules.gw_module", "qualname": "GWDecoder.n_layers", "kind": "variable", "doc": "number of hidden layers. The total number of layers\n will be n_layers
+ 2 (one before, one after).
\n"}, "shimmer.modules.gw_module.GWEncoder": {"fullname": "shimmer.modules.gw_module.GWEncoder", "modulename": "shimmer.modules.gw_module", "qualname": "GWEncoder", "kind": "class", "doc": "An Encoder network used in GWModules.
\n\nThis is similar to the decoder, but adds a tanh non-linearity at the end.
\n", "bases": "GWDecoder"}, "shimmer.modules.gw_module.GWEncoder.__init__": {"fullname": "shimmer.modules.gw_module.GWEncoder.__init__", "modulename": "shimmer.modules.gw_module", "qualname": "GWEncoder.__init__", "kind": "function", "doc": "Initializes the encoder.
\n\nArguments: \n\n\nin_dim (int
): input dimension \nhidden_dim (int
): hidden dimension \nout_dim (int
): output dimension \nn_layers (int
): number of hidden layers. The total number of layers\nwill be n_layers
+ 2 (one before, one after). \n \n", "signature": "(in_dim : int , hidden_dim : int , out_dim : int , n_layers : int ) "}, "shimmer.modules.gw_module.GWEncoder.forward": {"fullname": "shimmer.modules.gw_module.GWEncoder.forward", "modulename": "shimmer.modules.gw_module", "qualname": "GWEncoder.forward", "kind": "function", "doc": "Define the computation performed at every call.
\n\nShould be overridden by all subclasses.
\n\n\n\n
Although the recipe for forward pass needs to be defined within\nthis function, one should call the Module
instance afterwards\ninstead of this since the former takes care of running the\nregistered hooks while the latter silently ignores them.
\n\n
\n", "signature": "(self , input : torch . Tensor ) -> torch . Tensor : ", "funcdef": "def"}, "shimmer.modules.gw_module.GWEncoderLinear": {"fullname": "shimmer.modules.gw_module.GWEncoderLinear", "modulename": "shimmer.modules.gw_module", "qualname": "GWEncoderLinear", "kind": "class", "doc": "A linear Encoder network used in GWModules.
\n", "bases": "torch.nn.modules.linear.Linear"}, "shimmer.modules.gw_module.GWEncoderLinear.forward": {"fullname": "shimmer.modules.gw_module.GWEncoderLinear.forward", "modulename": "shimmer.modules.gw_module", "qualname": "GWEncoderLinear.forward", "kind": "function", "doc": "Define the computation performed at every call.
\n\nShould be overridden by all subclasses.
\n\n\n\n
Although the recipe for forward pass needs to be defined within\nthis function, one should call the Module
instance afterwards\ninstead of this since the former takes care of running the\nregistered hooks while the latter silently ignores them.
\n\n
\n", "signature": "(self , input : torch . Tensor ) -> torch . Tensor : ", "funcdef": "def"}, "shimmer.modules.gw_module.GWModuleBase": {"fullname": "shimmer.modules.gw_module.GWModuleBase", "modulename": "shimmer.modules.gw_module", "qualname": "GWModuleBase", "kind": "class", "doc": "Base class for GWModule.
\n\nGWModule handles encoding, decoding the unimodal representations\nusing the gw_encoders
andgw_decoders
, and define\nsome common operations in GW like cycles and translations.
\n\nThis is an abstract class and should be implemented.\nFor an implemented interface, see GWModule
.
\n", "bases": "torch.nn.modules.module.Module, abc.ABC"}, "shimmer.modules.gw_module.GWModuleBase.__init__": {"fullname": "shimmer.modules.gw_module.GWModuleBase.__init__", "modulename": "shimmer.modules.gw_module", "qualname": "GWModuleBase.__init__", "kind": "function", "doc": "Initializes the GWModule.
\n\nArguments: \n\n\ndomain_modules (Mapping[str, DomainModule]
): the domain modules. \nworkspace_dim (int
): dimension of the GW. \n \n", "signature": "(\tdomain_mods : collections . abc . Mapping [ str , shimmer . modules . domain . DomainModule ] , \tworkspace_dim : int , \t* args , \t** kwargs ) "}, "shimmer.modules.gw_module.GWModuleBase.domain_mods": {"fullname": "shimmer.modules.gw_module.GWModuleBase.domain_mods", "modulename": "shimmer.modules.gw_module", "qualname": "GWModuleBase.domain_mods", "kind": "variable", "doc": "The unimodal domain modules.
\n"}, "shimmer.modules.gw_module.GWModuleBase.workspace_dim": {"fullname": "shimmer.modules.gw_module.GWModuleBase.workspace_dim", "modulename": "shimmer.modules.gw_module", "qualname": "GWModuleBase.workspace_dim", "kind": "variable", "doc": "Dimension of the GW
\n"}, "shimmer.modules.gw_module.GWModuleBase.fuse": {"fullname": "shimmer.modules.gw_module.GWModuleBase.fuse", "modulename": "shimmer.modules.gw_module", "qualname": "GWModuleBase.fuse", "kind": "function", "doc": "Merge function used to combine domains.
\n\nArguments: \n\n\nx (LatentsDomainGroupT
): the group of latent representation. \nselection_score (Mapping[str, torch.Tensor]
): attention scores to\nuse to encode the reprensetation. \n \n\nReturns: \n\n\n torch.Tensor
: The merged representation.
\n \n", "signature": "(\tself , \tx : collections . abc . Mapping [ str , torch . Tensor ] , \tselection_scores : collections . abc . Mapping [ str , torch . Tensor ] ) -> torch . Tensor : ", "funcdef": "def"}, "shimmer.modules.gw_module.GWModuleBase.encode": {"fullname": "shimmer.modules.gw_module.GWModuleBase.encode", "modulename": "shimmer.modules.gw_module", "qualname": "GWModuleBase.encode", "kind": "function", "doc": "Encode the latent representation infos to the pre-fusion GW representation.
\n\nArguments: \n\n\nx (LatentsDomainGroupT
): the input domain representations \n \n\nReturns: \n\n\n LatentsDomainGroupT
: pre-fusion GW representations
\n \n", "signature": "(\tself , \tx : collections . abc . Mapping [ str , torch . Tensor ] ) -> dict [ str , torch . Tensor ] : ", "funcdef": "def"}, "shimmer.modules.gw_module.GWModuleBase.encode_and_fuse": {"fullname": "shimmer.modules.gw_module.GWModuleBase.encode_and_fuse", "modulename": "shimmer.modules.gw_module", "qualname": "GWModuleBase.encode_and_fuse", "kind": "function", "doc": "Encode the latent representation infos to the final GW representation.\nIt combines the encode and fuse methods.
\n\nArguments: \n\n\nx (LatentsDomainGroupT
): the input domain representations \nselection_score (Mapping[str, torch.Tensor]
): attention scores to\nuse to encode the reprensetation. \n \n\nReturns: \n\n\n torch.Tensor
: The merged representation.
\n \n", "signature": "(\tself , \tx : collections . abc . Mapping [ str , torch . Tensor ] , \tselection_module : shimmer . modules . selection . SelectionBase ) -> torch . Tensor : ", "funcdef": "def"}, "shimmer.modules.gw_module.GWModuleBase.decode": {"fullname": "shimmer.modules.gw_module.GWModuleBase.decode", "modulename": "shimmer.modules.gw_module", "qualname": "GWModuleBase.decode", "kind": "function", "doc": "Decode the GW representation into given domains
.
\n\nArguments: \n\n\nz (torch.Tensor
): the GW representation. \ndomains (Iterable[str]
): iterable of domains to decode. \n \n\nReturns: \n\n\n LatentsDomainGroupDT
: the decoded unimodal representations.
\n \n", "signature": "(\tself , \tz : torch . Tensor , \tdomains : collections . abc . Iterable [ str ] | None = None ) -> dict [ str , torch . Tensor ] : ", "funcdef": "def"}, "shimmer.modules.gw_module.GWModule": {"fullname": "shimmer.modules.gw_module.GWModule", "modulename": "shimmer.modules.gw_module", "qualname": "GWModule", "kind": "class", "doc": "GW nn.Module. Implements GWModuleBase
.
\n", "bases": "GWModuleBase"}, "shimmer.modules.gw_module.GWModule.__init__": {"fullname": "shimmer.modules.gw_module.GWModule.__init__", "modulename": "shimmer.modules.gw_module", "qualname": "GWModule.__init__", "kind": "function", "doc": "Initializes the GWModule.
\n\nArguments: \n\n\ndomain_modules (Mapping[str, DomainModule]
): the domain modules. \nworkspace_dim (int
): dimension of the GW. \ngw_encoders (Mapping[str, torch.nn.Module]
): mapping for each domain\nname to a an torch.nn.Module class that encodes a\nunimodal latent representations into a GW representation (pre fusion). \ngw_decoders (Mapping[str, torch.nn.Module]
): mapping for each domain\nname to a an torch.nn.Module class that decodes a\n GW representation to a unimodal latent representation. \n \n", "signature": "(\tdomain_modules : collections . abc . Mapping [ str , shimmer . modules . domain . DomainModule ] , \tworkspace_dim : int , \tgw_encoders : collections . abc . Mapping [ str , torch . nn . modules . module . Module ] , \tgw_decoders : collections . abc . Mapping [ str , torch . nn . modules . module . Module ] ) "}, "shimmer.modules.gw_module.GWModule.gw_encoders": {"fullname": "shimmer.modules.gw_module.GWModule.gw_encoders", "modulename": "shimmer.modules.gw_module", "qualname": "GWModule.gw_encoders", "kind": "variable", "doc": "The module's encoders
\n"}, "shimmer.modules.gw_module.GWModule.gw_decoders": {"fullname": "shimmer.modules.gw_module.GWModule.gw_decoders", "modulename": "shimmer.modules.gw_module", "qualname": "GWModule.gw_decoders", "kind": "variable", "doc": "The module's decoders
\n"}, "shimmer.modules.gw_module.GWModule.fuse": {"fullname": "shimmer.modules.gw_module.GWModule.fuse", "modulename": "shimmer.modules.gw_module", "qualname": "GWModule.fuse", "kind": "function", "doc": "Merge function used to combine domains.
\n\nArguments: \n\n\nx (LatentsDomainGroupT
): the group of latent representation. \nselection_score (Mapping[str, torch.Tensor]
): attention scores to\nuse to encode the reprensetation. \n \n\nReturns: \n\n\n torch.Tensor
: The merged representation.
\n \n", "signature": "(\tself , \tx : collections . abc . Mapping [ str , torch . Tensor ] , \tselection_scores : collections . abc . Mapping [ str , torch . Tensor ] ) -> torch . Tensor : ", "funcdef": "def"}, "shimmer.modules.gw_module.GWModule.encode": {"fullname": "shimmer.modules.gw_module.GWModule.encode", "modulename": "shimmer.modules.gw_module", "qualname": "GWModule.encode", "kind": "function", "doc": "Encode the latent representation infos to the pre-fusion GW representation.
\n\nArguments: \n\n\nx (LatentsDomainGroupT
): the input domain representations. \n \n\nReturns: \n\n\n LatentsDomainGroupT
: pre-fusion representation
\n \n", "signature": "(\tself , \tx : collections . abc . Mapping [ str , torch . Tensor ] ) -> dict [ str , torch . Tensor ] : ", "funcdef": "def"}, "shimmer.modules.gw_module.GWModule.decode": {"fullname": "shimmer.modules.gw_module.GWModule.decode", "modulename": "shimmer.modules.gw_module", "qualname": "GWModule.decode", "kind": "function", "doc": "Decodes a GW representation to multiple domains.
\n\nArguments: \n\n\nz (torch.Tensor
): the GW representation \ndomains (Iterable[str] | None
): the domains to decode to. Defaults to\nuse keys in gw_interfaces
(all domains). \n \n\nReturns: \n\n\n LatentsDomainGroupDT
: decoded unimodal representation
\n \n", "signature": "(\tself , \tz : torch . Tensor , \tdomains : collections . abc . Iterable [ str ] | None = None ) -> dict [ str , torch . Tensor ] : ", "funcdef": "def"}, "shimmer.modules.gw_module.compute_fusion_scores": {"fullname": "shimmer.modules.gw_module.compute_fusion_scores", "modulename": "shimmer.modules.gw_module", "qualname": "compute_fusion_scores", "kind": "function", "doc": "Combine precision scores using std summation in quadrature
\n\nThe two scores should have the same dimension.
\n\nArguments: \n\n\nscore_1 (torch.Tensor
): First scores. \nscore_2 (torch.Tensor
): Second scores. \nsensitivity_1 (float
): sensitivity for the first score \nsensitivity_2 (float
): sensitivity for the second score \neps (float
): a value added to avoid numerical unstability. \n \n\nReturns: \n\n\n torch.Tensor
: the combined scores
\n \n", "signature": "(\tscore_1 : torch . Tensor , \tscore_2 : torch . Tensor , \tsensitivity_1 : float = 1.0 , \tsensitivity_2 : float = 1.0 , \teps : float = 1e-06 ) -> torch . Tensor : ", "funcdef": "def"}, "shimmer.modules.gw_module.GWModuleBayesian": {"fullname": "shimmer.modules.gw_module.GWModuleBayesian", "modulename": "shimmer.modules.gw_module", "qualname": "GWModuleBayesian", "kind": "class", "doc": "GWModule
with a Bayesian based uncertainty prediction.
\n", "bases": "GWModule"}, "shimmer.modules.gw_module.GWModuleBayesian.__init__": {"fullname": "shimmer.modules.gw_module.GWModuleBayesian.__init__", "modulename": "shimmer.modules.gw_module", "qualname": "GWModuleBayesian.__init__", "kind": "function", "doc": "Initializes the GWModuleBayesian.
\n\nArguments: \n\n\ndomain_modules (Mapping[str, DomainModule]
): the domain modules. \nworkspace_dim (int
): dimension of the GW. \ngw_encoders (Mapping[str, torch.nn.Module]
): mapping for each domain\nname to a an torch.nn.Module class that encodes a\nunimodal latent representations into a GW representation (pre fusion). \ngw_decoders (Mapping[str, torch.nn.Module]
): mapping for each domain\nname to a an torch.nn.Module class that decodes a\n GW representation to a unimodal latent representation. \nsensitivity_selection (float
): sensivity coef $c'_1$ \nsensitivity_precision (float
): sensitivity coef $c'_2$ \nprecision_softmax_temp (float
): temperature to use in softmax of\nprecision \n \n", "signature": "(\tdomain_modules : collections . abc . Mapping [ str , shimmer . modules . domain . DomainModule ] , \tworkspace_dim : int , \tgw_encoders : collections . abc . Mapping [ str , torch . nn . modules . module . Module ] , \tgw_decoders : collections . abc . Mapping [ str , torch . nn . modules . module . Module ] , \tsensitivity_selection : float = 1 , \tsensitivity_precision : float = 1 , \tprecision_softmax_temp : float = 0.01 ) "}, "shimmer.modules.gw_module.GWModuleBayesian.precisions": {"fullname": "shimmer.modules.gw_module.GWModuleBayesian.precisions", "modulename": "shimmer.modules.gw_module", "qualname": "GWModuleBayesian.precisions", "kind": "variable", "doc": "Precision at the neuron level for every domain.
\n"}, "shimmer.modules.gw_module.GWModuleBayesian.sensitivity_selection": {"fullname": "shimmer.modules.gw_module.GWModuleBayesian.sensitivity_selection", "modulename": "shimmer.modules.gw_module", "qualname": "GWModuleBayesian.sensitivity_selection", "kind": "variable", "doc": "
\n"}, "shimmer.modules.gw_module.GWModuleBayesian.sensitivity_precision": {"fullname": "shimmer.modules.gw_module.GWModuleBayesian.sensitivity_precision", "modulename": "shimmer.modules.gw_module", "qualname": "GWModuleBayesian.sensitivity_precision", "kind": "variable", "doc": "
\n"}, "shimmer.modules.gw_module.GWModuleBayesian.precision_softmax_temp": {"fullname": "shimmer.modules.gw_module.GWModuleBayesian.precision_softmax_temp", "modulename": "shimmer.modules.gw_module", "qualname": "GWModuleBayesian.precision_softmax_temp", "kind": "variable", "doc": "
\n"}, "shimmer.modules.gw_module.GWModuleBayesian.get_precision": {"fullname": "shimmer.modules.gw_module.GWModuleBayesian.get_precision", "modulename": "shimmer.modules.gw_module", "qualname": "GWModuleBayesian.get_precision", "kind": "function", "doc": "Get the precision vector of given domain and batch
\n\nArguments: \n\n\ndomain (str
): \nx (torch.Tensor
): batch of inputs \n \n\nReturns: \n\n\n torch.Tensor
: batch of precision
\n \n", "signature": "(self , domain : str , x : torch . Tensor ) -> torch . Tensor : ", "funcdef": "def"}, "shimmer.modules.gw_module.GWModuleBayesian.fuse": {"fullname": "shimmer.modules.gw_module.GWModuleBayesian.fuse", "modulename": "shimmer.modules.gw_module", "qualname": "GWModuleBayesian.fuse", "kind": "function", "doc": "Merge function used to combine domains.
\n\nIn the following, $D$ is the number of domains, $N$ the batch size, and $d$ the\ndimension of the Global Workspace.
\n\nThis function needs to merge two kind of scores:
\n\n\nthe selection scores $a\\in [0,1]^{D\\times N}$; \nthe precision scores $b \\in [0,1]^{D\\times N \\times d}$. \n \n\n\n\n
The precision score is obtained by predicting logits and using a softmax
\n\n
\n\nWe can obtain associated uncertainties to the scores by introducing a std\nvariable and using bayesian integration:
\n\n$$a_k = \\frac{M_1}{\\sigma_k^2}$$\nwhere $M_1 = \\frac{1}{\\sum_{i=1}^D \\frac{1}{\\sigma_i^2}}$.
\n\nSimilarly,\n$$b_k = \\frac{M_2}{\\mu_k^2}$$\nwhere $M_2 = \\frac{1}{\\sum_{i=1}^D \\frac{1}{\\mu_i^2}}$.
\n\nThe we can sum the variances to obtain the final uncertainty (squared) $\\xi$:\n$$\\xi_k^2 = c_1 \\sigma_k^2 + c_2 \\mu_k^2$$
\n\nwhich, in terms of $a_k$ and $b_k$ yields:\n$$\\xi_k^2 = \\frac{c'_1}{a_k} + \\frac{c'_2}{b_k}$$\nwhere $c'_1 = c_1 \\cdot M_1$ and $c'_2 = c_2 \\cdot M_2$.
\n\nFinally, the finale combined coefficient is\n$$\\lambda_k = \\frac{M_3}{\\frac{c'_1}{a_k} + \\frac{c'_2}{b_k}}$$\nwhere\n$$M_3 = \\frac{1}{\\sum_{i=1}^D\n \\frac{1}{\\frac{c'_1}{a_i} + \\frac{c'_2}{b_i}}}$$
\n\nArguments: \n\n\nx (LatentsDomainGroupT
): the group of latent representation. \nselection_score (Mapping[str, torch.Tensor]
): attention scores to\nuse to encode the reprensetation. \n \n\nReturns: \n\n\n torch.Tensor
: The merged representation.
\n \n", "signature": "(\tself , \tx : collections . abc . Mapping [ str , torch . Tensor ] , \tselection_scores : collections . abc . Mapping [ str , torch . Tensor ] ) -> torch . Tensor : ", "funcdef": "def"}, "shimmer.modules.selection": {"fullname": "shimmer.modules.selection", "modulename": "shimmer.modules.selection", "kind": "module", "doc": "
\n"}, "shimmer.modules.selection.SelectionBase": {"fullname": "shimmer.modules.selection.SelectionBase", "modulename": "shimmer.modules.selection", "qualname": "SelectionBase", "kind": "class", "doc": "This is the base class for the selection mechanism.\nThe selection mechanisms handles the \"competition\" between modules and selects \nfusion coefficients for the domains.
\n", "bases": "torch.nn.modules.module.Module, abc.ABC"}, "shimmer.modules.selection.SelectionBase.update_gw_state": {"fullname": "shimmer.modules.selection.SelectionBase.update_gw_state", "modulename": "shimmer.modules.selection", "qualname": "SelectionBase.update_gw_state", "kind": "function", "doc": "Update the internal copy of the previous GW state.\nBy default, this is not implemented and will raise an error if used.
\n\n:note..\n This is not defined as an abstractmethod as some selection method may\n not need it.
\n\nArguments: \n\n\ngw_state (torch.Tensor
): the previous GW state \n \n", "signature": "(self , gw_state : torch . Tensor ) -> None : ", "funcdef": "def"}, "shimmer.modules.selection.SelectionBase.forward": {"fullname": "shimmer.modules.selection.SelectionBase.forward", "modulename": "shimmer.modules.selection", "qualname": "SelectionBase.forward", "kind": "function", "doc": "Forward pass of the selection method.
\n\nArguments: \n\n\ndomains (LatentsDomainGroupT
): Group of unimodal latent representations. \n \n\nReturns: \n\n\n dict[str, torch.Tensor]
: for each domain in the group, the fusion\n coefficient for each item in the batch.
\n \n\nExample: \n\n\n \n
>>> SomeSelectionImplementation () . forward ( \n... { "v" : torch . randn ( 3 , 4 ), "t" : torch . randn ( 3 , 8 )} \n... ) \n{"v": torch.Tensor([0.0, 0.4, 1.0]), "t": torch.Tensor([1.0, 0.6, 0.0])} \n
\n
\n \n", "signature": "(\tself , \tdomains : collections . abc . Mapping [ str , torch . Tensor ] , \tencodings_pre_fusion : collections . abc . Mapping [ str , torch . Tensor ] ) -> dict [ str , torch . Tensor ] : ", "funcdef": "def"}, "shimmer.modules.selection.SingleDomainSelection": {"fullname": "shimmer.modules.selection.SingleDomainSelection", "modulename": "shimmer.modules.selection", "qualname": "SingleDomainSelection", "kind": "class", "doc": "This selection mechanism handles groups that can have multiple domains, but always\nreturn a selection of 1 domain from the group with a uniform distribution.
\n\nFor example, if the group has 2 domains, there is a 50% chance of selecting each\ndomain.
\n", "bases": "SelectionBase"}, "shimmer.modules.selection.SingleDomainSelection.forward": {"fullname": "shimmer.modules.selection.SingleDomainSelection.forward", "modulename": "shimmer.modules.selection", "qualname": "SingleDomainSelection.forward", "kind": "function", "doc": "Forward pass of the module.
\n\nArguments: \n\n\ndomains (LatentsDomainGroupT
): input unimodal latent representations \ngw_state (torch.Tensor
): the previous GW state \n \n\nReturns: \n\n\n dict[str, torch.Tensor]
: whether the domain is selected for each input\n in the batch.
\n \n", "signature": "(\tself , \tdomains : collections . abc . Mapping [ str , torch . Tensor ] , \tencodings_pre_fusion : collections . abc . Mapping [ str , torch . Tensor ] ) -> dict [ str , torch . Tensor ] : ", "funcdef": "def"}, "shimmer.modules.selection.FixedSharedSelection": {"fullname": "shimmer.modules.selection.FixedSharedSelection", "modulename": "shimmer.modules.selection", "qualname": "FixedSharedSelection", "kind": "class", "doc": "This selection mechanism is deterministic and always shares the weights equally\nbetween domains.
\n\nFor example, if 2 domains, it gives 0.5 for each; 3 domains, 1/3 for each...
\n", "bases": "SelectionBase"}, "shimmer.modules.selection.FixedSharedSelection.forward": {"fullname": "shimmer.modules.selection.FixedSharedSelection.forward", "modulename": "shimmer.modules.selection", "qualname": "FixedSharedSelection.forward", "kind": "function", "doc": "Forward pass of the module.
\n\nArguments: \n\n\ndomains (LatentsDomainGroupT
): input unimodal latent representations \ngw_state (torch.Tensor
): the previous GW state \n \n\nReturns: \n\n\n dict[str, torch.Tensor]
: whether the domain is selected for each input\n in the batch.
\n \n", "signature": "(\tself , \tdomains : collections . abc . Mapping [ str , torch . Tensor ] , \tencodings_pre_fusion : collections . abc . Mapping [ str , torch . Tensor ] ) -> dict [ str , torch . Tensor ] : ", "funcdef": "def"}, "shimmer.modules.selection.KQFixedQSelection": {"fullname": "shimmer.modules.selection.KQFixedQSelection", "modulename": "shimmer.modules.selection", "qualname": "KQFixedQSelection", "kind": "class", "doc": "Key-Query attention with a fixed gw vector.
\n", "bases": "SelectionBase"}, "shimmer.modules.selection.KQFixedQSelection.__init__": {"fullname": "shimmer.modules.selection.KQFixedQSelection.__init__", "modulename": "shimmer.modules.selection", "qualname": "KQFixedQSelection.__init__", "kind": "function", "doc": "Arguments: \n\n\nhead_size (int
) : dimension of the key and query vectors. \ndomain_dim (int
) : dimension of the input dims (assumed to be the same\nfor now) \ndomain_names (Iterable[str]
) : list of input domains \n \n", "signature": "(\thead_size : int , \tdomain_dim : int , \tdomain_names : collections . abc . Iterable [ str ] ) "}, "shimmer.modules.selection.KQFixedQSelection.head_size": {"fullname": "shimmer.modules.selection.KQFixedQSelection.head_size", "modulename": "shimmer.modules.selection", "qualname": "KQFixedQSelection.head_size", "kind": "variable", "doc": "
\n"}, "shimmer.modules.selection.KQFixedQSelection.query_layer": {"fullname": "shimmer.modules.selection.KQFixedQSelection.query_layer", "modulename": "shimmer.modules.selection", "qualname": "KQFixedQSelection.query_layer", "kind": "variable", "doc": "
\n"}, "shimmer.modules.selection.KQFixedQSelection.key_layers": {"fullname": "shimmer.modules.selection.KQFixedQSelection.key_layers", "modulename": "shimmer.modules.selection", "qualname": "KQFixedQSelection.key_layers", "kind": "variable", "doc": "
\n"}, "shimmer.modules.selection.KQFixedQSelection.forward": {"fullname": "shimmer.modules.selection.KQFixedQSelection.forward", "modulename": "shimmer.modules.selection", "qualname": "KQFixedQSelection.forward", "kind": "function", "doc": "Compute keys and queries, match them with dot product and softmax.\nDoes this twice, once with the static query and once with a dynamic query.
\n\nArguments: \n\n\ndomains (LatentsDomainGroupT
): Group of unimodal latent representations. \nencodings (LatentsDomainGroupT
): Group of pre-fusion encodings. \n \n\nReturns: \n\n\n dict[str, torch.Tensor]
: the attention scores for each domain in the\n group.
\n \n", "signature": "(\tself , \tdomains : collections . abc . Mapping [ str , torch . Tensor ] , \tencodings_pre_fusion : collections . abc . Mapping [ str , torch . Tensor ] ) -> dict [ str , torch . Tensor ] : ", "funcdef": "def"}, "shimmer.modules.selection.RandomSelection": {"fullname": "shimmer.modules.selection.RandomSelection", "modulename": "shimmer.modules.selection", "qualname": "RandomSelection", "kind": "class", "doc": "Modified random attention to only utilize uniform-softmax scores across modalities.\nThis version omits the binary scaling factors and focuses on generating attention\ncoefficients using a uniform distribution followed by a domain-wise softmax.
\n", "bases": "SelectionBase"}, "shimmer.modules.selection.RandomSelection.__init__": {"fullname": "shimmer.modules.selection.RandomSelection.__init__", "modulename": "shimmer.modules.selection", "qualname": "RandomSelection.__init__", "kind": "function", "doc": "Arguments: \n\n\ntemperature (float
): Temperature of the softmax applied to uniform\nscaling factors. \n \n", "signature": "(temperature : float ) "}, "shimmer.modules.selection.RandomSelection.temperature": {"fullname": "shimmer.modules.selection.RandomSelection.temperature", "modulename": "shimmer.modules.selection", "qualname": "RandomSelection.temperature", "kind": "variable", "doc": "
\n"}, "shimmer.modules.selection.RandomSelection.forward": {"fullname": "shimmer.modules.selection.RandomSelection.forward", "modulename": "shimmer.modules.selection", "qualname": "RandomSelection.forward", "kind": "function", "doc": "Generate uniform-then-domain-wise-softmaxed samples for each domain.
\n\nArguments: \n\n\ndomains (LatentsDomainGroupT
): Group of unimodal latent representations.\nThis is not used in the function directly but determines the structure\nof the returned attention coefficients. \n \n\nReturns: \n\n\n dict[str, torch.Tensor]
: For each domain in the group, the fusion\n coefficient for each item in the batch, based solely on\n uniform-softmax scores.
\n \n", "signature": "(\tself , \tdomains : collections . abc . Mapping [ str , torch . Tensor ] , \tencodings_pre_fusion : collections . abc . Mapping [ str , torch . Tensor ] ) -> dict [ str , torch . Tensor ] : ", "funcdef": "def"}, "shimmer.modules.selection.DynamicQueryAttention": {"fullname": "shimmer.modules.selection.DynamicQueryAttention", "modulename": "shimmer.modules.selection", "qualname": "DynamicQueryAttention", "kind": "class", "doc": "Key-Query attention with a dynamic gw vector.\nThe query is updated based on the scaled gw vector.
\n", "bases": "SelectionBase"}, "shimmer.modules.selection.DynamicQueryAttention.__init__": {"fullname": "shimmer.modules.selection.DynamicQueryAttention.__init__", "modulename": "shimmer.modules.selection", "qualname": "DynamicQueryAttention.__init__", "kind": "function", "doc": "Arguments: \n\n\nhead_size (int
) : dimension of the key and query vectors. \ndomain_dim (int
) : dimension of the input dims (assumed to be the same\nfor now) \ndomain_names (Iterable[str]
) : list of input domains \n \n", "signature": "(\thead_size : int , \tdomain_dim : int , \tdomain_names : collections . abc . Iterable [ str ] ) "}, "shimmer.modules.selection.DynamicQueryAttention.head_size": {"fullname": "shimmer.modules.selection.DynamicQueryAttention.head_size", "modulename": "shimmer.modules.selection", "qualname": "DynamicQueryAttention.head_size", "kind": "variable", "doc": "
\n"}, "shimmer.modules.selection.DynamicQueryAttention.query_layer": {"fullname": "shimmer.modules.selection.DynamicQueryAttention.query_layer", "modulename": "shimmer.modules.selection", "qualname": "DynamicQueryAttention.query_layer", "kind": "variable", "doc": "
\n"}, "shimmer.modules.selection.DynamicQueryAttention.key_layers": {"fullname": "shimmer.modules.selection.DynamicQueryAttention.key_layers", "modulename": "shimmer.modules.selection", "qualname": "DynamicQueryAttention.key_layers", "kind": "variable", "doc": "
\n"}, "shimmer.modules.selection.DynamicQueryAttention.fuse_weighted_encodings": {"fullname": "shimmer.modules.selection.DynamicQueryAttention.fuse_weighted_encodings", "modulename": "shimmer.modules.selection", "qualname": "DynamicQueryAttention.fuse_weighted_encodings", "kind": "function", "doc": "Fuse the weighted encodings using the attention scores.
\n\nArguments: \n\n\nencodings (LatentsDomainGroupT
): Unimodal latent representation \nattention_dict (dict[str, torch.Tensor]
): The attention scores for each\ndomain in the group. \n \n\nReturns: \n\n\n torch.Tensor
: The fused tensor.
\n \n", "signature": "(\tself , \tencodings : collections . abc . Mapping [ str , torch . Tensor ] , \tattention_dict : dict [ str , torch . Tensor ] ) -> torch . Tensor : ", "funcdef": "def"}, "shimmer.modules.selection.DynamicQueryAttention.forward": {"fullname": "shimmer.modules.selection.DynamicQueryAttention.forward", "modulename": "shimmer.modules.selection", "qualname": "DynamicQueryAttention.forward", "kind": "function", "doc": "Compute keys and queries, match them with dot product and softmax.\nDoes this twice, once with the static query and once with a dynamic query.
\n\nArguments: \n\n\ndomains (LatentsDomainGroupT
): Group of unimodal latent representations. \nencodings (LatentsDomainGroupT
): Group of pre-fusion encodings. \n \n\nReturns: \n\n\n dict[str, torch.Tensor]
: the attention scores for each domain in the\n group.
\n \n", "signature": "(\tself , \tdomains : collections . abc . Mapping [ str , torch . Tensor ] , \tencodings_pre_fusion : collections . abc . Mapping [ str , torch . Tensor ] ) -> dict [ str , torch . Tensor ] : ", "funcdef": "def"}, "shimmer.modules.losses": {"fullname": "shimmer.modules.losses", "modulename": "shimmer.modules.losses", "kind": "module", "doc": "
\n"}, "shimmer.modules.losses.GWLossesBase": {"fullname": "shimmer.modules.losses.GWLossesBase", "modulename": "shimmer.modules.losses", "qualname": "GWLossesBase", "kind": "class", "doc": "Base Abstract Class for Global Workspace (GW) losses. This module is used\nto compute the different losses of the GW (typically translation, cycle,\ndemi-cycle, contrastive losses).
\n", "bases": "torch.nn.modules.module.Module, abc.ABC"}, "shimmer.modules.losses.GWLossesBase.step": {"fullname": "shimmer.modules.losses.GWLossesBase.step", "modulename": "shimmer.modules.losses", "qualname": "GWLossesBase.step", "kind": "function", "doc": "Computes the losses.
\n\nArguments: \n\n\ndomain_latents (LatentsDomainGroupsT
): All latent groups \nmode (Literal[\"train\", \"val\", \"test\", \"val/ood\", \"test/ood\"]
): model mode \n \n\nReturns: \n\n\n LossOutput
: the losses
\n \n", "signature": "(\tself , \tdomain_latents : collections . abc . Mapping [ frozenset [ str ], collections . abc . Mapping [ str , torch . Tensor ]] , \tmode : Literal [ 'train' , 'val' , 'test' , 'val/ood' , 'test/ood' ] ) -> shimmer . modules . domain . LossOutput : ", "funcdef": "def"}, "shimmer.modules.losses.demi_cycle_loss": {"fullname": "shimmer.modules.losses.demi_cycle_loss", "modulename": "shimmer.modules.losses", "qualname": "demi_cycle_loss", "kind": "function", "doc": "Computes the demi-cycle loss.
\n\nThis return multiple metrics: \n\n\n \n demi_cycle_{domain_name}
with the demi-cycle of a particular domain; \n demi_cycle_{domain_name}_{metric}
with additional metrics provided by\n the domain_mod's compute_dcy_loss
output; \n demi_cycles
with the average value of all demi_cycle_{domain_name}
values. \n \n \n\nArguments: \n\n\ngw_mod (shimmer.modules.gw_module.GWModuleBase
): The GWModule to use \nselection_mod (shimmer.modules.selection.SelectionBase
): Selection mod to use \ndomain_mods (Mapping[str, DomainModule]
): the domain modules \nlatent_domains (shimmer.types.LatentsDomainGroupsT
): the latent unimodal\ngroups \n \n\nReturns: \n\n\n dict[str, torch.Tensor]
: a dict of metrics.
\n \n", "signature": "(\tgw_mod : shimmer . modules . gw_module . GWModuleBase , \tselection_mod : shimmer . modules . selection . SelectionBase , \tdomain_mods : collections . abc . Mapping [ str , shimmer . modules . domain . DomainModule ] , \tlatent_domains : collections . abc . Mapping [ frozenset [ str ], collections . abc . Mapping [ str , torch . Tensor ]] ) -> dict [ str , torch . Tensor ] : ", "funcdef": "def"}, "shimmer.modules.losses.cycle_loss": {"fullname": "shimmer.modules.losses.cycle_loss", "modulename": "shimmer.modules.losses", "qualname": "cycle_loss", "kind": "function", "doc": "Computes the cycle loss.
\n\nThis return multiple metrics: \n\n\n \n cycle_{domain_source}_through_{domain_target}
with the cycle of\n a particular domain; \n cycle_{domain_source}_through_{domain_target}_{metric}
with additional\n metrics provided by the domain_mod's compute_cy_loss
output; \n cycles
with the average value of all\n cycle_{domain_source}_through_{domain_target}
values. \n \n \n\nArguments: \n\n\ngw_mod (GWModuleBase
): The GWModule to use \nselection_mod (shimmer.modules.selection.SelectionBase
): Selection mod to use \ndomain_mods (Mapping[str, DomainModule]
): the domain modules \nlatent_domains (LatentsDomainGroupsT
): the latent unimodal groups \n \n\nReturns: \n\n\n dict[str, torch.Tensor]
: a dict of metrics.
\n \n", "signature": "(\tgw_mod : shimmer . modules . gw_module . GWModuleBase , \tselection_mod : shimmer . modules . selection . SelectionBase , \tdomain_mods : collections . abc . Mapping [ str , shimmer . modules . domain . DomainModule ] , \tlatent_domains : collections . abc . Mapping [ frozenset [ str ], collections . abc . Mapping [ str , torch . Tensor ]] ) -> dict [ str , torch . Tensor ] : ", "funcdef": "def"}, "shimmer.modules.losses.translation_loss": {"fullname": "shimmer.modules.losses.translation_loss", "modulename": "shimmer.modules.losses", "qualname": "translation_loss", "kind": "function", "doc": "Computes the translation loss.
\n\nThis return multiple metrics: \n\n\n \n translation_{domain_source}_to_{domain_target}
with the translation\n from a domain source to a domain target; \n translation_{domain_source}_to_{domain_target}_{metric}
with\n additional metrics provided by the domain_mod's\n compute_tr_loss
output; \n translations
with the average value of all\n translation_{domain_source}_to_{domain_target}
values. \n \n \n\nArguments: \n\n\ngw_mod (GWModuleBase
): The GWModule to use \ndomain_mods (Mapping[str, DomainModule]
): the domain modules \nlatent_domains (LatentsDomainGroupsT
): the latent unimodal groups \n \n\nReturns: \n\n\n dict[str, torch.Tensor]
: a dict of metrics.
\n \n", "signature": "(\tgw_mod : shimmer . modules . gw_module . GWModuleBase , \tselection_mod : shimmer . modules . selection . SelectionBase , \tdomain_mods : collections . abc . Mapping [ str , shimmer . modules . domain . DomainModule ] , \tlatent_domains : collections . abc . Mapping [ frozenset [ str ], collections . abc . Mapping [ str , torch . Tensor ]] ) -> dict [ str , torch . Tensor ] : ", "funcdef": "def"}, "shimmer.modules.losses.contrastive_loss": {"fullname": "shimmer.modules.losses.contrastive_loss", "modulename": "shimmer.modules.losses", "qualname": "contrastive_loss", "kind": "function", "doc": "Computes the contrastive loss.
\n\nThis return multiple metrics: \n\n\n \n contrastive_{domain_1}_and_{domain_2}
with the contrastive\n between 2 domains; \n contrastive_{domain_1}_and_{domain_2}_{metric}
with\n additional metrics provided by the domain_mod's\n compute_cont_loss
output; \n contrastives
with the average value of all\n contrastive_{domain_1}_and_{domain_2}
values. \n \n \n\nArguments: \n\n\ngw_mod (GWModuleBase
): The GWModule to use \nlatent_domains (LatentsDomainGroupsT
): the latent unimodal groups \ncontrastive_fn (ContrastiveLossType
): the contrastive function to apply \n \n\nReturns: \n\n\n dict[str, torch.Tensor]
: a dict of metrics.
\n \n", "signature": "(\tgw_mod : shimmer . modules . gw_module . GWModuleBase , \tlatent_domains : collections . abc . Mapping [ frozenset [ str ], collections . abc . Mapping [ str , torch . Tensor ]] , \tcontrastive_fn : collections . abc . Callable [[ torch . Tensor , torch . Tensor ], shimmer . modules . domain . LossOutput ] ) -> dict [ str , torch . Tensor ] : ", "funcdef": "def"}, "shimmer.modules.losses.contrastive_loss_bayesian": {"fullname": "shimmer.modules.losses.contrastive_loss_bayesian", "modulename": "shimmer.modules.losses", "qualname": "contrastive_loss_bayesian", "kind": "function", "doc": "Computes the contrastive loss with a Bayesian based uncertainty prediction.
\n\nThis return multiple metrics: \n\n\n \n contrastive_{domain_1}_and_{domain_2}
with the contrastive\n between 2 domains; \n contrastive_{domain_1}_and_{domain_2}_{metric}
with\n additional metrics provided by the domain_mod's\n compute_cont_loss
output; \n contrastives
with the average value of all\n contrastive_{domain_1}_and_{domain_2}
values. \n \n \n\nArguments: \n\n\ngw_mod (GWModuleBayesian
): The GWModule to use \nlatent_domains (LatentsDomainGroupsT
): the latent unimodal groups \ncontrastive_fn (ContrastiveLossBayesianType
): the contrastive function\nto apply \n \n\nReturns: \n\n\n dict[str, torch.Tensor]
: a dict of metrics.
\n \n", "signature": "(\tgw_mod : shimmer . modules . gw_module . GWModuleBayesian , \tlatent_domains : collections . abc . Mapping [ frozenset [ str ], collections . abc . Mapping [ str , torch . Tensor ]] , \tcontrastive_fn : collections . abc . Callable [[ torch . Tensor , torch . Tensor ], shimmer . modules . domain . LossOutput ] ) -> dict [ str , torch . Tensor ] : ", "funcdef": "def"}, "shimmer.modules.losses.LossCoefs": {"fullname": "shimmer.modules.losses.LossCoefs", "modulename": "shimmer.modules.losses", "qualname": "LossCoefs", "kind": "class", "doc": "Dict of loss coefficients used in the GWLosses.
\n\nIf one is not provided, the coefficient is assumed to be 0 and will not be logged.\nIf the loss is excplicitely set to 0, it will be logged, but not take part in\nthe total loss.
\n", "bases": "typing.TypedDict"}, "shimmer.modules.losses.LossCoefs.demi_cycles": {"fullname": "shimmer.modules.losses.LossCoefs.demi_cycles", "modulename": "shimmer.modules.losses", "qualname": "LossCoefs.demi_cycles", "kind": "variable", "doc": "Demi-cycle loss coefficient.
\n", "annotation": ": float"}, "shimmer.modules.losses.LossCoefs.cycles": {"fullname": "shimmer.modules.losses.LossCoefs.cycles", "modulename": "shimmer.modules.losses", "qualname": "LossCoefs.cycles", "kind": "variable", "doc": "Cycle loss coefficient.
\n", "annotation": ": float"}, "shimmer.modules.losses.LossCoefs.translations": {"fullname": "shimmer.modules.losses.LossCoefs.translations", "modulename": "shimmer.modules.losses", "qualname": "LossCoefs.translations", "kind": "variable", "doc": "Translation loss coefficient.
\n", "annotation": ": float"}, "shimmer.modules.losses.LossCoefs.contrastives": {"fullname": "shimmer.modules.losses.LossCoefs.contrastives", "modulename": "shimmer.modules.losses", "qualname": "LossCoefs.contrastives", "kind": "variable", "doc": "Contrastive loss coefficient.
\n", "annotation": ": float"}, "shimmer.modules.losses.GWLosses2Domains": {"fullname": "shimmer.modules.losses.GWLosses2Domains", "modulename": "shimmer.modules.losses", "qualname": "GWLosses2Domains", "kind": "class", "doc": "Implementation of GWLossesBase
used for GWModule
.
\n", "bases": "GWLossesBase"}, "shimmer.modules.losses.GWLosses2Domains.__init__": {"fullname": "shimmer.modules.losses.GWLosses2Domains.__init__", "modulename": "shimmer.modules.losses", "qualname": "GWLosses2Domains.__init__", "kind": "function", "doc": "Main loss module to use with the GlobalWorkspace
\n\nArguments: \n\n\ngw_mod (GWModule
): the GWModule \nselection_mod (SelectionBase
): selection module \ndomain_mods (dict[str, DomainModule]
): a dict where the key is the\ndomain name and value is the DomainModule \nloss_coefs (LossCoefs
): loss coefficients. LossCoefs object, or a\nmapping to float with correct keys. \ncontrastive_fn (ContrastiveLossType
): the contrastive function to use\nin contrastive loss \n \n", "signature": "(\tgw_mod : shimmer . modules . gw_module . GWModule , \tselection_mod : shimmer . modules . selection . SelectionBase , \tdomain_mods : dict [ str , shimmer . modules . domain . DomainModule ] , \tloss_coefs : shimmer . modules . losses . LossCoefs , \tcontrastive_fn : collections . abc . Callable [[ torch . Tensor , torch . Tensor ], shimmer . modules . domain . LossOutput ] ) "}, "shimmer.modules.losses.GWLosses2Domains.gw_mod": {"fullname": "shimmer.modules.losses.GWLosses2Domains.gw_mod", "modulename": "shimmer.modules.losses", "qualname": "GWLosses2Domains.gw_mod", "kind": "variable", "doc": "
\n"}, "shimmer.modules.losses.GWLosses2Domains.selection_mod": {"fullname": "shimmer.modules.losses.GWLosses2Domains.selection_mod", "modulename": "shimmer.modules.losses", "qualname": "GWLosses2Domains.selection_mod", "kind": "variable", "doc": "
\n"}, "shimmer.modules.losses.GWLosses2Domains.domain_mods": {"fullname": "shimmer.modules.losses.GWLosses2Domains.domain_mods", "modulename": "shimmer.modules.losses", "qualname": "GWLosses2Domains.domain_mods", "kind": "variable", "doc": "
\n"}, "shimmer.modules.losses.GWLosses2Domains.loss_coefs": {"fullname": "shimmer.modules.losses.GWLosses2Domains.loss_coefs", "modulename": "shimmer.modules.losses", "qualname": "GWLosses2Domains.loss_coefs", "kind": "variable", "doc": "
\n"}, "shimmer.modules.losses.GWLosses2Domains.contrastive_fn": {"fullname": "shimmer.modules.losses.GWLosses2Domains.contrastive_fn", "modulename": "shimmer.modules.losses", "qualname": "GWLosses2Domains.contrastive_fn", "kind": "variable", "doc": "
\n"}, "shimmer.modules.losses.GWLosses2Domains.demi_cycle_loss": {"fullname": "shimmer.modules.losses.GWLosses2Domains.demi_cycle_loss", "modulename": "shimmer.modules.losses", "qualname": "GWLosses2Domains.demi_cycle_loss", "kind": "function", "doc": "Computes the demi-cycle loss.
\n\nSee shimmer.modules.losses.demi_cycle_loss
.
\n\nArguments: \n\n\nlatent_domains (LatentsDomainGroupsT
): the latent unimodal groups \n \n\nReturns: \n\n\n dict[str, torch.Tensor]
: a dict of metrics.
\n \n", "signature": "(\tself , \tlatent_domains : collections . abc . Mapping [ frozenset [ str ], collections . abc . Mapping [ str , torch . Tensor ]] ) -> dict [ str , torch . Tensor ] : ", "funcdef": "def"}, "shimmer.modules.losses.GWLosses2Domains.cycle_loss": {"fullname": "shimmer.modules.losses.GWLosses2Domains.cycle_loss", "modulename": "shimmer.modules.losses", "qualname": "GWLosses2Domains.cycle_loss", "kind": "function", "doc": "Computes the cycle loss.
\n\nSee shimmer.modules.losses.cycle_loss
.
\n\nArguments: \n\n\nlatent_domains (LatentsDomainGroupsT
): the latent unimodal groups \n \n\nReturns: \n\n\n dict[str, torch.Tensor]
: a dict of metrics.
\n \n", "signature": "(\tself , \tlatent_domains : collections . abc . Mapping [ frozenset [ str ], collections . abc . Mapping [ str , torch . Tensor ]] ) -> dict [ str , torch . Tensor ] : ", "funcdef": "def"}, "shimmer.modules.losses.GWLosses2Domains.translation_loss": {"fullname": "shimmer.modules.losses.GWLosses2Domains.translation_loss", "modulename": "shimmer.modules.losses", "qualname": "GWLosses2Domains.translation_loss", "kind": "function", "doc": "Computes the translation loss.
\n\nSee shimmer.modules.losses.translation_loss
.
\n\nArguments: \n\n\nlatent_domains (LatentsDomainGroupsT
): the latent unimodal groups \n \n\nReturns: \n\n\n dict[str, torch.Tensor]
: a dict of metrics.
\n \n", "signature": "(\tself , \tlatent_domains : collections . abc . Mapping [ frozenset [ str ], collections . abc . Mapping [ str , torch . Tensor ]] ) -> dict [ str , torch . Tensor ] : ", "funcdef": "def"}, "shimmer.modules.losses.GWLosses2Domains.contrastive_loss": {"fullname": "shimmer.modules.losses.GWLosses2Domains.contrastive_loss", "modulename": "shimmer.modules.losses", "qualname": "GWLosses2Domains.contrastive_loss", "kind": "function", "doc": "Computes the contrastive loss.
\n\nSee shimmer.modules.losses.contrastive_loss
.
\n\nArguments: \n\n\nlatent_domains (LatentsDomainGroupsT
): the latent unimodal groups \n \n\nReturns: \n\n\n dict[str, torch.Tensor]
: a dict of metrics.
\n \n", "signature": "(\tself , \tlatent_domains : collections . abc . Mapping [ frozenset [ str ], collections . abc . Mapping [ str , torch . Tensor ]] ) -> dict [ str , torch . Tensor ] : ", "funcdef": "def"}, "shimmer.modules.losses.GWLosses2Domains.step": {"fullname": "shimmer.modules.losses.GWLosses2Domains.step", "modulename": "shimmer.modules.losses", "qualname": "GWLosses2Domains.step", "kind": "function", "doc": "Computes and returns the losses
\n\nContains: \n\n\n \n Demi-cycle metrics (see GWLosses.demi_cycle_loss
) \n Cycle metrics (see GWLosses.cycle_loss
) \n Translation metrics (see GWLosses.translation_loss
) \n Contrastive metrics (see GWLosses.contrastive_loss
) \n \n \n\nArguments: \n\n\ndomain_latents (LatentsDomainGroupsT
): All latent groups \nmode (ModelModeT
): model mode \n \n\nReturns: \n\n\n LossOutput
: the losses
\n \n", "signature": "(\tself , \tdomain_latents : collections . abc . Mapping [ frozenset [ str ], collections . abc . Mapping [ str , torch . Tensor ]] , \tmode : Literal [ 'train' , 'val' , 'test' , 'val/ood' , 'test/ood' ] ) -> shimmer . modules . domain . LossOutput : ", "funcdef": "def"}, "shimmer.modules.losses.generate_partitions": {"fullname": "shimmer.modules.losses.generate_partitions", "modulename": "shimmer.modules.losses", "qualname": "generate_partitions", "kind": "function", "doc": "Generates all possible partitions of zeros and ones for n
elements,\nexcluding the all-zeros partition.
\n\nArguments: \n\n\nn (int
): The number of modalities to generate partitions for. \n \n\nYields: \n\n\n tuple[int, ...]
: A partition of zeros and ones, excluding the\n all-zeros partition.
\n \n", "signature": "(n : int ) -> collections . abc . Generator [ tuple [ int , ... ], None , None ] : ", "funcdef": "def"}, "shimmer.modules.losses.broadcast_loss": {"fullname": "shimmer.modules.losses.broadcast_loss", "modulename": "shimmer.modules.losses", "qualname": "broadcast_loss", "kind": "function", "doc": "Computes broadcast loss including demi-cycle, cycle, and translation losses.
\n\nArguments: \n\n\ngw_mod (shimmer.modules.gw_module.GWModuleBase
): The GWModule to use \nselection_mod (shimmer.modules.selection.SelectionBase
): Selection mod to use \ndomain_mods (Mapping[str, DomainModule]
): the domain modules \nlatent_domains: The latent domain representations. \n \n\nReturns: \n\n\n A dictionary with the total loss and additional metrics.
\n \n", "signature": "(\tgw_mod : shimmer . modules . gw_module . GWModuleBase , \tselection_mod : shimmer . modules . selection . SelectionBase , \tdomain_mods : collections . abc . Mapping [ str , shimmer . modules . domain . DomainModule ] , \tlatent_domains : collections . abc . Mapping [ frozenset [ str ], collections . abc . Mapping [ str , torch . Tensor ]] ) -> dict [ str , torch . Tensor ] : ", "funcdef": "def"}, "shimmer.modules.losses.BroadcastLossCoefs": {"fullname": "shimmer.modules.losses.BroadcastLossCoefs", "modulename": "shimmer.modules.losses", "qualname": "BroadcastLossCoefs", "kind": "class", "doc": "Dict of loss coefficients used in the GWLossesFusion.
\n\nIf one is not provided, the coefficient is assumed to be 0 and will not be logged.\nIf the loss is excplicitely set to 0, it will be logged, but not take part in\nthe total loss.
\n", "bases": "typing.TypedDict"}, "shimmer.modules.losses.BroadcastLossCoefs.contrastives": {"fullname": "shimmer.modules.losses.BroadcastLossCoefs.contrastives", "modulename": "shimmer.modules.losses", "qualname": "BroadcastLossCoefs.contrastives", "kind": "variable", "doc": "Contrastive loss coefficient.
\n", "annotation": ": float"}, "shimmer.modules.losses.BroadcastLossCoefs.fused": {"fullname": "shimmer.modules.losses.BroadcastLossCoefs.fused", "modulename": "shimmer.modules.losses", "qualname": "BroadcastLossCoefs.fused", "kind": "variable", "doc": "fused loss coefficient (encode multiple domains and decode to one of them).
\n", "annotation": ": float"}, "shimmer.modules.losses.BroadcastLossCoefs.demi_cycles": {"fullname": "shimmer.modules.losses.BroadcastLossCoefs.demi_cycles", "modulename": "shimmer.modules.losses", "qualname": "BroadcastLossCoefs.demi_cycles", "kind": "variable", "doc": "demi_cycles loss coefficient. Demi-cycles are always one-to-one
\n", "annotation": ": float"}, "shimmer.modules.losses.BroadcastLossCoefs.cycles": {"fullname": "shimmer.modules.losses.BroadcastLossCoefs.cycles", "modulename": "shimmer.modules.losses", "qualname": "BroadcastLossCoefs.cycles", "kind": "variable", "doc": "cycles loss coefficient. Cycles can be many-to-one
\n", "annotation": ": float"}, "shimmer.modules.losses.BroadcastLossCoefs.translations": {"fullname": "shimmer.modules.losses.BroadcastLossCoefs.translations", "modulename": "shimmer.modules.losses", "qualname": "BroadcastLossCoefs.translations", "kind": "variable", "doc": "translation loss coefficient. Translation, like cycles, can be many-to-one.
\n", "annotation": ": float"}, "shimmer.modules.losses.GWLosses": {"fullname": "shimmer.modules.losses.GWLosses", "modulename": "shimmer.modules.losses", "qualname": "GWLosses", "kind": "class", "doc": "Implementation of GWLossesBase
for fusion-based models.
\n", "bases": "GWLossesBase"}, "shimmer.modules.losses.GWLosses.__init__": {"fullname": "shimmer.modules.losses.GWLosses.__init__", "modulename": "shimmer.modules.losses", "qualname": "GWLosses.__init__", "kind": "function", "doc": "Initializes the loss computation module for a Global Workspace Fusion model.
\n\nArguments: \n\n\ngw_mod: The GWModule for the global workspace. \nselection_mod: The selection mechanism for the model. \ndomain_mods: A mapping of domain names to their respective DomainModule. \nloss_coefs (BroadcastLossCoefs
): coefs for the losses \ncontrastive_fn: The function used for computing contrastive loss. \n \n", "signature": "(\tgw_mod : shimmer . modules . gw_module . GWModule , \tselection_mod : shimmer . modules . selection . SelectionBase , \tdomain_mods : dict [ str , shimmer . modules . domain . DomainModule ] , \tloss_coefs : shimmer . modules . losses . BroadcastLossCoefs , \tcontrastive_fn : collections . abc . Callable [[ torch . Tensor , torch . Tensor ], shimmer . modules . domain . LossOutput ] ) "}, "shimmer.modules.losses.GWLosses.gw_mod": {"fullname": "shimmer.modules.losses.GWLosses.gw_mod", "modulename": "shimmer.modules.losses", "qualname": "GWLosses.gw_mod", "kind": "variable", "doc": "
\n"}, "shimmer.modules.losses.GWLosses.selection_mod": {"fullname": "shimmer.modules.losses.GWLosses.selection_mod", "modulename": "shimmer.modules.losses", "qualname": "GWLosses.selection_mod", "kind": "variable", "doc": "
\n"}, "shimmer.modules.losses.GWLosses.domain_mods": {"fullname": "shimmer.modules.losses.GWLosses.domain_mods", "modulename": "shimmer.modules.losses", "qualname": "GWLosses.domain_mods", "kind": "variable", "doc": "
\n"}, "shimmer.modules.losses.GWLosses.loss_coefs": {"fullname": "shimmer.modules.losses.GWLosses.loss_coefs", "modulename": "shimmer.modules.losses", "qualname": "GWLosses.loss_coefs", "kind": "variable", "doc": "
\n"}, "shimmer.modules.losses.GWLosses.contrastive_fn": {"fullname": "shimmer.modules.losses.GWLosses.contrastive_fn", "modulename": "shimmer.modules.losses", "qualname": "GWLosses.contrastive_fn", "kind": "variable", "doc": "
\n"}, "shimmer.modules.losses.GWLosses.contrastive_loss": {"fullname": "shimmer.modules.losses.GWLosses.contrastive_loss", "modulename": "shimmer.modules.losses", "qualname": "GWLosses.contrastive_loss", "kind": "function", "doc": "Computes the contrastive loss for the given latent domains.
\n\nArguments: \n\n\nlatent_domains: The latent domain representations. \n \n\nReturns: \n\n\n A dictionary of contrastive loss metrics.
\n \n", "signature": "(\tself , \tlatent_domains : collections . abc . Mapping [ frozenset [ str ], collections . abc . Mapping [ str , torch . Tensor ]] ) -> dict [ str , torch . Tensor ] : ", "funcdef": "def"}, "shimmer.modules.losses.GWLosses.broadcast_loss": {"fullname": "shimmer.modules.losses.GWLosses.broadcast_loss", "modulename": "shimmer.modules.losses", "qualname": "GWLosses.broadcast_loss", "kind": "function", "doc": "
\n", "signature": "(\tself , \tlatent_domains : collections . abc . Mapping [ frozenset [ str ], collections . abc . Mapping [ str , torch . Tensor ]] ) -> dict [ str , torch . Tensor ] : ", "funcdef": "def"}, "shimmer.modules.losses.GWLosses.step": {"fullname": "shimmer.modules.losses.GWLosses.step", "modulename": "shimmer.modules.losses", "qualname": "GWLosses.step", "kind": "function", "doc": "Performs a step of loss computation.
\n\nArguments: \n\n\ndomain_latents: Latent representations for all domains. \nmode: The mode in which the model is currently operating. \n \n\nReturns: \n\n\n A LossOutput object containing the loss and metrics for this step.
\n \n", "signature": "(\tself , \tdomain_latents : collections . abc . Mapping [ frozenset [ str ], collections . abc . Mapping [ str , torch . Tensor ]] , \tmode : Literal [ 'train' , 'val' , 'test' , 'val/ood' , 'test/ood' ] ) -> shimmer . modules . domain . LossOutput : ", "funcdef": "def"}, "shimmer.modules.losses.GWLossesBayesian": {"fullname": "shimmer.modules.losses.GWLossesBayesian", "modulename": "shimmer.modules.losses", "qualname": "GWLossesBayesian", "kind": "class", "doc": "Implementation of GWLossesBase
used for GWModuleBayesian
.
\n", "bases": "GWLossesBase"}, "shimmer.modules.losses.GWLossesBayesian.__init__": {"fullname": "shimmer.modules.losses.GWLossesBayesian.__init__", "modulename": "shimmer.modules.losses", "qualname": "GWLossesBayesian.__init__", "kind": "function", "doc": "Loss module with uncertainty prediction to use with the GlobalWorkspaceBayesian
\n\nArguments: \n\n\ngw_mod (GWModuleBayesian
): the GWModule \nselection_mod (SelectionBase
): selection module \ndomain_mods (dict[str, DomainModule]
): a dict where the key is the\ndomain name and value is the DomainModule \nloss_coefs (BroadcastLossCoefs
): loss coefficients \ncontrastive_fn (ContrastiveLossType
): the contrastive function\nto use in contrastive loss \nuse_normalized_constrastive (bool
): whether to use the normalized cont\nloss by the precision coefs \n \n", "signature": "(\tgw_mod : shimmer . modules . gw_module . GWModuleBayesian , \tselection_mod : shimmer . modules . selection . SelectionBase , \tdomain_mods : dict [ str , shimmer . modules . domain . DomainModule ] , \tloss_coefs : shimmer . modules . losses . BroadcastLossCoefs , \tcontrastive_fn : collections . abc . Callable [[ torch . Tensor , torch . Tensor ], shimmer . modules . domain . LossOutput ] , \tuse_normalized_constrastive : bool = True ) "}, "shimmer.modules.losses.GWLossesBayesian.gw_mod": {"fullname": "shimmer.modules.losses.GWLossesBayesian.gw_mod", "modulename": "shimmer.modules.losses", "qualname": "GWLossesBayesian.gw_mod", "kind": "variable", "doc": "The GWModule.
\n"}, "shimmer.modules.losses.GWLossesBayesian.selection_mod": {"fullname": "shimmer.modules.losses.GWLossesBayesian.selection_mod", "modulename": "shimmer.modules.losses", "qualname": "GWLossesBayesian.selection_mod", "kind": "variable", "doc": "Selection module
\n"}, "shimmer.modules.losses.GWLossesBayesian.domain_mods": {"fullname": "shimmer.modules.losses.GWLossesBayesian.domain_mods", "modulename": "shimmer.modules.losses", "qualname": "GWLossesBayesian.domain_mods", "kind": "variable", "doc": "Domain modules linked to the GW.
\n"}, "shimmer.modules.losses.GWLossesBayesian.loss_coefs": {"fullname": "shimmer.modules.losses.GWLossesBayesian.loss_coefs", "modulename": "shimmer.modules.losses", "qualname": "GWLossesBayesian.loss_coefs", "kind": "variable", "doc": "The loss coefficients.
\n"}, "shimmer.modules.losses.GWLossesBayesian.contrastive_fn": {"fullname": "shimmer.modules.losses.GWLossesBayesian.contrastive_fn", "modulename": "shimmer.modules.losses", "qualname": "GWLossesBayesian.contrastive_fn", "kind": "variable", "doc": "Contrastive loss to use.
\n"}, "shimmer.modules.losses.GWLossesBayesian.use_normalized_constrastive": {"fullname": "shimmer.modules.losses.GWLossesBayesian.use_normalized_constrastive", "modulename": "shimmer.modules.losses", "qualname": "GWLossesBayesian.use_normalized_constrastive", "kind": "variable", "doc": "
\n"}, "shimmer.modules.losses.GWLossesBayesian.contrastive_loss": {"fullname": "shimmer.modules.losses.GWLossesBayesian.contrastive_loss", "modulename": "shimmer.modules.losses", "qualname": "GWLossesBayesian.contrastive_loss", "kind": "function", "doc": "Contrastive loss.
\n\nArguments: \n\n\nlatent_domains (LatentsDomainGroupsT
): the latent unimodal groups \n \n\nReturns: \n\n\n dict[str, torch.Tensor]
: a dict of metrics.
\n \n", "signature": "(\tself , \tlatent_domains : collections . abc . Mapping [ frozenset [ str ], collections . abc . Mapping [ str , torch . Tensor ]] ) -> dict [ str , torch . Tensor ] : ", "funcdef": "def"}, "shimmer.modules.losses.GWLossesBayesian.broadcast_loss": {"fullname": "shimmer.modules.losses.GWLossesBayesian.broadcast_loss", "modulename": "shimmer.modules.losses", "qualname": "GWLossesBayesian.broadcast_loss", "kind": "function", "doc": "
\n", "signature": "(\tself , \tlatent_domains : collections . abc . Mapping [ frozenset [ str ], collections . abc . Mapping [ str , torch . Tensor ]] ) -> dict [ str , torch . Tensor ] : ", "funcdef": "def"}, "shimmer.modules.losses.GWLossesBayesian.step": {"fullname": "shimmer.modules.losses.GWLossesBayesian.step", "modulename": "shimmer.modules.losses", "qualname": "GWLossesBayesian.step", "kind": "function", "doc": "Performs a step of loss computation.
\n\nArguments: \n\n\ndomain_latents: Latent representations for all domains. \nmode: The mode in which the model is currently operating. \n \n\nReturns: \n\n\n A LossOutput object containing the loss and metrics for this step.
\n \n", "signature": "(\tself , \tdomain_latents : collections . abc . Mapping [ frozenset [ str ], collections . abc . Mapping [ str , torch . Tensor ]] , \tmode : Literal [ 'train' , 'val' , 'test' , 'val/ood' , 'test/ood' ] ) -> shimmer . modules . domain . LossOutput : ", "funcdef": "def"}, "shimmer.modules.contrastive_loss": {"fullname": "shimmer.modules.contrastive_loss", "modulename": "shimmer.modules.contrastive_loss", "kind": "module", "doc": "Various contrastive loss definitions
\n"}, "shimmer.modules.contrastive_loss.ContrastiveLossType": {"fullname": "shimmer.modules.contrastive_loss.ContrastiveLossType", "modulename": "shimmer.modules.contrastive_loss", "qualname": "ContrastiveLossType", "kind": "variable", "doc": "Contrastive loss function type.
\n\nA function taking the prediction and targets and returning a LossOutput.
\n", "default_value": "collections.abc.Callable[[torch.Tensor, torch.Tensor], shimmer.modules.domain.LossOutput]"}, "shimmer.modules.contrastive_loss.ContrastiveLossBayesianType": {"fullname": "shimmer.modules.contrastive_loss.ContrastiveLossBayesianType", "modulename": "shimmer.modules.contrastive_loss", "qualname": "ContrastiveLossBayesianType", "kind": "variable", "doc": "Contrastive loss function type for GlobalWorkspaceBayesian.
\n\nA function taking the prediction mean, prediction std, target mean and target std and\n returns a LossOutput.
\n", "default_value": "collections.abc.Callable[[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor], shimmer.modules.domain.LossOutput]"}, "shimmer.modules.contrastive_loss.info_nce": {"fullname": "shimmer.modules.contrastive_loss.info_nce", "modulename": "shimmer.modules.contrastive_loss", "qualname": "info_nce", "kind": "function", "doc": "InfoNCE loss
\n\nArguments: \n\n\nx (torch.Tensor
): prediction \ny (torch.Tensor
): target \nlogit_scale (torch.Tensor
): logit scale \nreduction (Literal[\"mean\", \"sum\", \"none\"]
): reduction to apply \n \n\nReturns: the InfoNCE loss
\n", "signature": "(\tx : torch . Tensor , \ty : torch . Tensor , \tlogit_scale : torch . Tensor , \treduction : Literal [ 'mean' , 'sum' , 'none' ] = 'mean' ) -> torch . Tensor : ", "funcdef": "def"}, "shimmer.modules.contrastive_loss.contrastive_loss": {"fullname": "shimmer.modules.contrastive_loss.contrastive_loss", "modulename": "shimmer.modules.contrastive_loss", "qualname": "contrastive_loss", "kind": "function", "doc": "CLIP-like contrastive loss
\n\nArguments: \n\n\nx (torch.Tensor
): prediction \ny (torch.Tensor
): target \nlogit_scale (torch.Tensor
): logit scale \nreduction (Literal[\"mean\", \"sum\", \"none\"]
): reduction to apply \n \n\nReturns: the contrastive loss
\n", "signature": "(\tx : torch . Tensor , \ty : torch . Tensor , \tlogit_scale : torch . Tensor , \treduction : Literal [ 'mean' , 'sum' , 'none' ] = 'mean' ) -> torch . Tensor : ", "funcdef": "def"}, "shimmer.modules.contrastive_loss.ContrastiveLoss": {"fullname": "shimmer.modules.contrastive_loss.ContrastiveLoss", "modulename": "shimmer.modules.contrastive_loss", "qualname": "ContrastiveLoss", "kind": "class", "doc": "CLIP-like ContrastiveLoss torch module.
\n", "bases": "torch.nn.modules.module.Module"}, "shimmer.modules.contrastive_loss.ContrastiveLoss.__init__": {"fullname": "shimmer.modules.contrastive_loss.ContrastiveLoss.__init__", "modulename": "shimmer.modules.contrastive_loss", "qualname": "ContrastiveLoss.__init__", "kind": "function", "doc": "Initializes a contrastive loss.
\n\nArguments: \n\n\nlogit_scale (torch.Tensor
): logit_scale tensor. \nreduction (Literal[\"mean\", \"sum\", \"none\"]
): reduction to apply to the\nloss. Defaults to \"mean\"
. \nlearn_logit_scale (torch.Tensor
): whether to learn the logit_scale
\nparameter. Defaults to False
. \n \n", "signature": "(\tlogit_scale : torch . Tensor , \treduction : Literal [ 'mean' , 'sum' , 'none' ] = 'mean' , \tlearn_logit_scale : bool = False ) "}, "shimmer.modules.contrastive_loss.ContrastiveLoss.learn_logit_scale": {"fullname": "shimmer.modules.contrastive_loss.ContrastiveLoss.learn_logit_scale", "modulename": "shimmer.modules.contrastive_loss", "qualname": "ContrastiveLoss.learn_logit_scale", "kind": "variable", "doc": "
\n"}, "shimmer.modules.contrastive_loss.ContrastiveLoss.reduction": {"fullname": "shimmer.modules.contrastive_loss.ContrastiveLoss.reduction", "modulename": "shimmer.modules.contrastive_loss", "qualname": "ContrastiveLoss.reduction", "kind": "variable", "doc": "
\n", "annotation": ": Literal['mean', 'sum', 'none']"}, "shimmer.modules.contrastive_loss.ContrastiveLoss.forward": {"fullname": "shimmer.modules.contrastive_loss.ContrastiveLoss.forward", "modulename": "shimmer.modules.contrastive_loss", "qualname": "ContrastiveLoss.forward", "kind": "function", "doc": "Computes the loss.
\n\nArguments: \n\n\nx (torch.Tensor
): prediction \ny (torch.Tensor
): target \n \n\nReturns: \n\n\n LossOutput of the loss. Contains a logit_scale
metric.
\n \n", "signature": "(\tself , \tx : torch . Tensor , \ty : torch . Tensor ) -> shimmer . modules . domain . LossOutput : ", "funcdef": "def"}, "shimmer.dataset": {"fullname": "shimmer.dataset", "modulename": "shimmer.dataset", "kind": "module", "doc": "
\n"}, "shimmer.dataset.RepeatedDataset": {"fullname": "shimmer.dataset.RepeatedDataset", "modulename": "shimmer.dataset", "qualname": "RepeatedDataset", "kind": "class", "doc": "Dataset that cycles through its items to have a size of at least min size.\nIf drop_last is True, the size will be exactly min_size. If drop_last is False,\nthe min_size \u2264 size < min_size + len(dataset).
\n", "bases": "typing.Generic[+T_co]"}, "shimmer.dataset.RepeatedDataset.__init__": {"fullname": "shimmer.dataset.RepeatedDataset.__init__", "modulename": "shimmer.dataset", "qualname": "RepeatedDataset.__init__", "kind": "function", "doc": "Arguments: \n\n\ndataset (SizedDataset): dataset to repeat. The dataset should have a size\n(where __len__
is defined). \nmin_size (int): minimum size of the final dataset \ndrop_last (bool): whether to remove overflow when repeating the\ndataset. \n \n", "signature": "(\tdataset : shimmer . dataset . _SizedDataset , \tmin_size : int , \tdrop_last : bool = False ) "}, "shimmer.dataset.RepeatedDataset.dataset": {"fullname": "shimmer.dataset.RepeatedDataset.dataset", "modulename": "shimmer.dataset", "qualname": "RepeatedDataset.dataset", "kind": "variable", "doc": "
\n"}, "shimmer.dataset.RepeatedDataset.dataset_size": {"fullname": "shimmer.dataset.RepeatedDataset.dataset_size", "modulename": "shimmer.dataset", "qualname": "RepeatedDataset.dataset_size", "kind": "variable", "doc": "
\n"}, "shimmer.modules.vae": {"fullname": "shimmer.modules.vae", "modulename": "shimmer.modules.vae", "kind": "module", "doc": "
\n"}, "shimmer.modules.vae.reparameterize": {"fullname": "shimmer.modules.vae.reparameterize", "modulename": "shimmer.modules.vae", "qualname": "reparameterize", "kind": "function", "doc": "Reparameterization trick for VAE
\n\nArguments: \n\n\nmean (torch.Tensor
): predicted means \nlogvar (torch.Tensor
): predicted log variance \n \n\nReturns: \n\n\n torch.Tensor
: a sample from normal distribution with provided\n parameters, sampled using the reparameterization trick.
\n \n", "signature": "(mean : torch . Tensor , logvar : torch . Tensor ) -> torch . Tensor : ", "funcdef": "def"}, "shimmer.modules.vae.kl_divergence_loss": {"fullname": "shimmer.modules.vae.kl_divergence_loss", "modulename": "shimmer.modules.vae", "qualname": "kl_divergence_loss", "kind": "function", "doc": "Computes the KL divergence loss used in VAE.
\n\nArguments: \n\n\nmean (torch.Tensor
): predicted means \nlogvar (torch.Tensor
): predicted logvars \n \n\nReturns: \n\n\n torch.Tensor
: the loss
\n \n", "signature": "(mean : torch . Tensor , logvar : torch . Tensor ) -> torch . Tensor : ", "funcdef": "def"}, "shimmer.modules.vae.gaussian_nll": {"fullname": "shimmer.modules.vae.gaussian_nll", "modulename": "shimmer.modules.vae", "qualname": "gaussian_nll", "kind": "function", "doc": "Computes gaussian nll loss used in VAE.
\n\nArguments: \n\n\nmu (torch.Tensor
): predictions \nlog_sigma (torch.Tensor
): log sigma \nx (torch.Tensor
): targets \n \n\nReturns: \n\n\n torch.Tensor
: the Gaussian NLL loss
\n \n", "signature": "(\tmu : torch . Tensor , \tlog_sigma : torch . Tensor , \tx : torch . Tensor ) -> torch . Tensor : ", "funcdef": "def"}, "shimmer.modules.vae.VAEEncoder": {"fullname": "shimmer.modules.vae.VAEEncoder", "modulename": "shimmer.modules.vae", "qualname": "VAEEncoder", "kind": "class", "doc": "Base class for a VAE encoder.
\n", "bases": "torch.nn.modules.module.Module, abc.ABC"}, "shimmer.modules.vae.VAEEncoder.forward": {"fullname": "shimmer.modules.vae.VAEEncoder.forward", "modulename": "shimmer.modules.vae", "qualname": "VAEEncoder.forward", "kind": "function", "doc": "Encode representation with VAE.
\n\nArguments: \n\n\nx (Any
): Some input value \n \n\nReturns: \n\n\n tuple[torch.Tensor, torch.Tensor]
: the mean and log variance
\n \n", "signature": "(self , x : Any ) -> tuple [ torch . Tensor , torch . Tensor ] : ", "funcdef": "def"}, "shimmer.modules.vae.VAEDecoder": {"fullname": "shimmer.modules.vae.VAEDecoder", "modulename": "shimmer.modules.vae", "qualname": "VAEDecoder", "kind": "class", "doc": "Base class for a VAE decoder.
\n", "bases": "torch.nn.modules.module.Module, abc.ABC"}, "shimmer.modules.vae.VAEDecoder.forward": {"fullname": "shimmer.modules.vae.VAEDecoder.forward", "modulename": "shimmer.modules.vae", "qualname": "VAEDecoder.forward", "kind": "function", "doc": "Decode representation with VAE
\n\nArguments: \n\n\nx (torch.Tensor
): VAE latent representation \n \n\nReturns: \n\n\n Any
: the reconstructed input
\n \n", "signature": "(self , x : torch . Tensor ) -> Any : ", "funcdef": "def"}, "shimmer.modules.vae.VAE": {"fullname": "shimmer.modules.vae.VAE", "modulename": "shimmer.modules.vae", "qualname": "VAE", "kind": "class", "doc": "VAE module
\n", "bases": "torch.nn.modules.module.Module"}, "shimmer.modules.vae.VAE.__init__": {"fullname": "shimmer.modules.vae.VAE.__init__", "modulename": "shimmer.modules.vae", "qualname": "VAE.__init__", "kind": "function", "doc": "Initializes a VAE.
\n\nArguments: \n\n\nencoder (VAEEncoder
): VAE encoder \ndecoder (VAEDecoder
): VAE decoder \nbeta (float
): beta value for Beta-VAE. Defaults to 1. \n \n", "signature": "(\tencoder : shimmer . modules . vae . VAEEncoder , \tdecoder : shimmer . modules . vae . VAEDecoder , \tbeta : float = 1 ) "}, "shimmer.modules.vae.VAE.beta": {"fullname": "shimmer.modules.vae.VAE.beta", "modulename": "shimmer.modules.vae", "qualname": "VAE.beta", "kind": "variable", "doc": "Beta value for Beta-VAEs
\n"}, "shimmer.modules.vae.VAE.encoder": {"fullname": "shimmer.modules.vae.VAE.encoder", "modulename": "shimmer.modules.vae", "qualname": "VAE.encoder", "kind": "variable", "doc": "The encoder
\n"}, "shimmer.modules.vae.VAE.decoder": {"fullname": "shimmer.modules.vae.VAE.decoder", "modulename": "shimmer.modules.vae", "qualname": "VAE.decoder", "kind": "variable", "doc": "The decoder
\n"}, "shimmer.modules.vae.VAE.encode": {"fullname": "shimmer.modules.vae.VAE.encode", "modulename": "shimmer.modules.vae", "qualname": "VAE.encode", "kind": "function", "doc": "Encode the representation and returns the mean prediction of VAE.
\n\nArguments: \n\n\nx (Any
): Some input value \n \n\nReturns: \n\n\n torch.Tensor
: The mean representation.
\n \n", "signature": "(self , x : Any ) -> torch . Tensor : ", "funcdef": "def"}, "shimmer.modules.vae.VAE.decode": {"fullname": "shimmer.modules.vae.VAE.decode", "modulename": "shimmer.modules.vae", "qualname": "VAE.decode", "kind": "function", "doc": "Decode the VAE latent representation into input value.
\n\nArguments: \n\n\nz (torch.Tensor
): the VAE latent representation. \n \n\nReturns: \n\n\n Any
: the reconstructed input.
\n \n", "signature": "(self , z : torch . Tensor ) -> Any : ", "funcdef": "def"}, "shimmer.modules.vae.VAE.forward": {"fullname": "shimmer.modules.vae.VAE.forward", "modulename": "shimmer.modules.vae", "qualname": "VAE.forward", "kind": "function", "doc": "Encode and decodes from x.
\n\nArguments: \n\n\nx (Any
): the input data \n \n\nReturns: \n\n\n tuple[tuple[torch.Tensor, torch.Tensor], Any]
: The\n first tuple contains the mean and logvar of the encoded input,\n the second item is the reconstructed input.
\n \n", "signature": "(self , x : Any ) -> tuple [ tuple [ torch . Tensor , torch . Tensor ], typing . Any ] : ", "funcdef": "def"}, "shimmer.modules.utils": {"fullname": "shimmer.modules.utils", "modulename": "shimmer.modules.utils", "kind": "module", "doc": "
\n"}, "shimmer.modules.utils.translation": {"fullname": "shimmer.modules.utils.translation", "modulename": "shimmer.modules.utils", "qualname": "translation", "kind": "function", "doc": "Translate from multiple domains to one domain.
\n\nArguments: \n\n\ngw_module (GWModuleBase
): GWModule to perform the translation over \nselection_mod (SelectionBase
): selection module \nx (LatentsDomainGroupT
): the group of latent representations \nto (str
): the domain name to encode to \n \n\nReturns: \n\n\n torch.Tensor
: the translated unimodal representation\n of the provided domain.
\n \n", "signature": "(\tgw_module : shimmer . modules . gw_module . GWModuleBase , \tselection_mod : shimmer . modules . selection . SelectionBase , \tx : collections . abc . Mapping [ str , torch . Tensor ] , \tto : str ) -> torch . Tensor : ", "funcdef": "def"}, "shimmer.modules.utils.cycle": {"fullname": "shimmer.modules.utils.cycle", "modulename": "shimmer.modules.utils", "qualname": "cycle", "kind": "function", "doc": "Do a full cycle from a group of representation through one domain.
\n\n[Original domains] -> [GW] -> [through] -> [GW] -> [Original domains]
\n\nArguments: \n\n\ngw_module (GWModuleBase
): GWModule to perform the translation over \nselection_mod (SelectionBase
): selection module \nx (LatentsDomainGroupT
): group of unimodal latent representation \nthrough (str
): domain name to cycle through \n \n\nReturns: \n\n\n LatentsDomainGroupDT
: group of unimodal latent representation after\n cycling.
\n \n", "signature": "(\tgw_module : shimmer . modules . gw_module . GWModuleBase , \tselection_mod : shimmer . modules . selection . SelectionBase , \tx : collections . abc . Mapping [ str , torch . Tensor ] , \tthrough : str ) -> dict [ str , torch . Tensor ] : ", "funcdef": "def"}, "shimmer.modules.utils.batch_demi_cycles": {"fullname": "shimmer.modules.utils.batch_demi_cycles", "modulename": "shimmer.modules.utils", "qualname": "batch_demi_cycles", "kind": "function", "doc": "Computes demi-cycles of a batch of groups of domains.
\n\nArguments: \n\n\ngw_mod (GWModuleBase
): the GWModuleBase \nselection_mod (SelectionBase
): selection module \nlatent_domains (LatentsT
): the batch of groups of domains \n \n\nReturns: \n\n\n dict[str, torch.Tensor]
: demi-cycles predictions for each domain.
\n \n", "signature": "(\tgw_mod : shimmer . modules . gw_module . GWModuleBase , \tselection_mod : shimmer . modules . selection . SelectionBase , \tlatent_domains : collections . abc . Mapping [ frozenset [ str ], collections . abc . Mapping [ str , torch . Tensor ]] ) -> dict [ str , torch . Tensor ] : ", "funcdef": "def"}, "shimmer.modules.utils.batch_cycles": {"fullname": "shimmer.modules.utils.batch_cycles", "modulename": "shimmer.modules.utils", "qualname": "batch_cycles", "kind": "function", "doc": "Computes cycles of a batch of groups of domains.
\n\nArguments: \n\n\ngw_mod (GWModuleBase
): GWModule to use for the cycle \nselection_mod (SelectionBase
): selection module \nlatent_domains (LatentsT
): the batch of groups of domains \nout_domains (Iterable[str]
): iterable of domain names to do the cycle through.\nEach domain will be done separetely. \n \n\nReturns: \n\n\n dict[tuple[str, str], torch.Tensor]
: cycles predictions for each\n couple of (start domain, intermediary domain).
\n \n", "signature": "(\tgw_mod : shimmer . modules . gw_module . GWModuleBase , \tselection_mod : shimmer . modules . selection . SelectionBase , \tlatent_domains : collections . abc . Mapping [ frozenset [ str ], collections . abc . Mapping [ str , torch . Tensor ]] , \tthrough_domains : collections . abc . Iterable [ str ] ) -> dict [ tuple [ str , str ], torch . Tensor ] : ", "funcdef": "def"}, "shimmer.modules.utils.batch_translations": {"fullname": "shimmer.modules.utils.batch_translations", "modulename": "shimmer.modules.utils", "qualname": "batch_translations", "kind": "function", "doc": "Computes translations of a batch of groups of domains.
\n\nArguments: \n\n\ngw_mod (GWModuleBase
): GWModule to do the translation \nselection_mod (SelectionBase
): selection module \nlatent_domains (LatentsT
): the batch of groups of domains \n \n\nReturns: \n\n\n dict[tuple[str, str], torch.Tensor]
: translation predictions for each\n couple of (start domain, target domain).
\n \n", "signature": "(\tgw_mod : shimmer . modules . gw_module . GWModuleBase , \tselection_mod : shimmer . modules . selection . SelectionBase , \tlatent_domains : collections . abc . Mapping [ frozenset [ str ], collections . abc . Mapping [ str , torch . Tensor ]] ) -> dict [ tuple [ str , str ], torch . Tensor ] : ", "funcdef": "def"}, "shimmer.utils": {"fullname": "shimmer.utils", "modulename": "shimmer.utils", "kind": "module", "doc": "
\n"}, "shimmer.utils.MIGRATION_DIR": {"fullname": "shimmer.utils.MIGRATION_DIR", "modulename": "shimmer.utils", "qualname": "MIGRATION_DIR", "kind": "variable", "doc": "
\n", "default_value": "PosixPath('/home/runner/work/shimmer/shimmer/shimmer/ckpt_migrations')"}, "shimmer.utils.group_batch_size": {"fullname": "shimmer.utils.group_batch_size", "modulename": "shimmer.utils", "qualname": "group_batch_size", "kind": "function", "doc": "
\n", "signature": "(x : collections . abc . Mapping [ str , torch . Tensor ] ) -> int : ", "funcdef": "def"}, "shimmer.utils.groups_batch_size": {"fullname": "shimmer.utils.groups_batch_size", "modulename": "shimmer.utils", "qualname": "groups_batch_size", "kind": "function", "doc": "Get the batch size of the batch.
\n\nArguments: \n\n\ndomain_latents (LatentsDomainGroupsT
): the batch of groups. \n \n\nReturns: \n\n\n int: the batch size.
\n \n", "signature": "(\tdomain_latents : collections . abc . Mapping [ frozenset [ str ], collections . abc . Mapping [ str , torch . Tensor ]] ) -> int : ", "funcdef": "def"}, "shimmer.utils.groups_device": {"fullname": "shimmer.utils.groups_device", "modulename": "shimmer.utils", "qualname": "groups_device", "kind": "function", "doc": "Get the batch size of the batch.
\n\nArguments: \n\n\ndomain_latents (LatentsDomainGroupsT
): the batch of groups. \n \n\nReturns: \n\n\n int: the batch size.
\n \n", "signature": "(\tdomain_latents : collections . abc . Mapping [ frozenset [ str ], collections . abc . Mapping [ str , torch . Tensor ]] ) -> int : ", "funcdef": "def"}, "shimmer.utils.group_device": {"fullname": "shimmer.utils.group_device", "modulename": "shimmer.utils", "qualname": "group_device", "kind": "function", "doc": "
\n", "signature": "(x : collections . abc . Mapping [ str , torch . Tensor ] ) -> torch . device : ", "funcdef": "def"}, "shimmer.utils.migrate_model": {"fullname": "shimmer.utils.migrate_model", "modulename": "shimmer.utils", "qualname": "migrate_model", "kind": "function", "doc": "Migrates a model checkpoint
\n\nAfter the migration, the given checkpoint will be migrated.\nOther versions of the checkpoint will be saved under the stem-version.suffix.
\n\nArguments: \n\n\nckpt_path (str | PathLike
): path to checkpoint \ntorch_load_kwargs: additional args given to torch.load. \n \n", "signature": "(ckpt_path : str | os . PathLike , ** torch_load_kwargs ): ", "funcdef": "def"}, "shimmer.utils.SaveMigrations": {"fullname": "shimmer.utils.SaveMigrations", "modulename": "shimmer.utils", "qualname": "SaveMigrations", "kind": "class", "doc": "Abstract base class used to build new callbacks.
\n\nSubclass this class and override any of the relevant hooks
\n", "bases": "lightning.pytorch.callbacks.callback.Callback"}, "shimmer.utils.SaveMigrations.migrations": {"fullname": "shimmer.utils.SaveMigrations.migrations", "modulename": "shimmer.utils", "qualname": "SaveMigrations.migrations", "kind": "variable", "doc": "
\n"}, "shimmer.utils.SaveMigrations.on_save_checkpoint": {"fullname": "shimmer.utils.SaveMigrations.on_save_checkpoint", "modulename": "shimmer.utils", "qualname": "SaveMigrations.on_save_checkpoint", "kind": "function", "doc": "Called when saving a checkpoint to give you a chance to store anything else you might want to save.
\n\nArguments: \n\n\ntrainer: the current ~lightning.pytorch.trainer.trainer.Trainer
instance. \npl_module: the current ~lightning.pytorch.core.LightningModule
instance. \ncheckpoint: the checkpoint dictionary that will be saved. \n \n", "signature": "(\tself , \ttrainer : lightning . pytorch . trainer . trainer . Trainer , \tpl_module : lightning . pytorch . core . module . LightningModule , \tcheckpoint : dict [ str , typing . Any ] ): ", "funcdef": "def"}, "shimmer.cli.ckpt_migration": {"fullname": "shimmer.cli.ckpt_migration", "modulename": "shimmer.cli.ckpt_migration", "kind": "module", "doc": "
\n"}, "shimmer.cli.ckpt_migration.migrate_ckpt": {"fullname": "shimmer.cli.ckpt_migration.migrate_ckpt", "modulename": "shimmer.cli.ckpt_migration", "qualname": "migrate_ckpt", "kind": "variable", "doc": "Script to migrate a list of checkpoints.\nThis can be called with:
\n\n\n
shimmer migrate-ckpt PATH_1 PATH_2 ... PATH_N\n
\n
\n\nwhere paths point to checkpoints.
\n\nInternally, this calls shimmer.utils.migrate_model
for each of the given paths.
\n", "default_value": "<Command migrate-ckpt>"}}, "docInfo": {"shimmer.types": {"qualname": 0, "fullname": 2, "annotation": 0, "default_value": 0, "signature": 0, "bases": 0, "doc": 3}, "shimmer.types.RawDomainGroupT": {"qualname": 1, "fullname": 3, "annotation": 0, "default_value": 5, "signature": 0, "bases": 0, "doc": 210}, "shimmer.types.RawDomainGroupDT": {"qualname": 1, "fullname": 3, "annotation": 0, "default_value": 3, "signature": 0, "bases": 0, "doc": 165}, "shimmer.types.LatentsDomainGroupT": {"qualname": 1, "fullname": 3, "annotation": 0, "default_value": 5, "signature": 0, "bases": 0, "doc": 234}, "shimmer.types.LatentsDomainGroupDT": {"qualname": 1, "fullname": 3, "annotation": 0, "default_value": 3, "signature": 0, "bases": 0, "doc": 198}, "shimmer.types.RawDomainGroupsT": {"qualname": 1, "fullname": 3, "annotation": 0, "default_value": 8, "signature": 0, "bases": 0, "doc": 297}, "shimmer.types.RawDomainGroupsDT": {"qualname": 1, "fullname": 3, "annotation": 0, "default_value": 4, "signature": 0, "bases": 0, "doc": 297}, "shimmer.types.LatentsDomainGroupsT": {"qualname": 1, "fullname": 3, "annotation": 0, "default_value": 8, "signature": 0, "bases": 0, "doc": 401}, "shimmer.types.LatentsDomainGroupsDT": {"qualname": 1, "fullname": 3, "annotation": 0, "default_value": 4, "signature": 0, "bases": 0, "doc": 365}, "shimmer.types.ModelModeT": {"qualname": 1, "fullname": 3, "annotation": 0, "default_value": 18, "signature": 0, "bases": 0, "doc": 27}, "shimmer.modules.global_workspace": {"qualname": 0, "fullname": 4, "annotation": 0, "default_value": 0, "signature": 0, "bases": 0, "doc": 3}, "shimmer.modules.global_workspace.SchedulerArgs": {"qualname": 1, "fullname": 5, "annotation": 0, "default_value": 0, "signature": 0, "bases": 2, "doc": 10}, "shimmer.modules.global_workspace.SchedulerArgs.max_lr": {"qualname": 3, "fullname": 7, "annotation": 2, "default_value": 0, "signature": 0, "bases": 0, "doc": 5}, 
"shimmer.modules.global_workspace.SchedulerArgs.total_steps": {"qualname": 3, "fullname": 7, "annotation": 2, "default_value": 0, "signature": 0, "bases": 0, "doc": 6}, "shimmer.modules.global_workspace.GWPredictionsBase": {"qualname": 1, "fullname": 5, "annotation": 0, "default_value": 0, "signature": 0, "bases": 2, "doc": 13}, "shimmer.modules.global_workspace.GWPredictionsBase.states": {"qualname": 2, "fullname": 6, "annotation": 4, "default_value": 0, "signature": 0, "bases": 0, "doc": 20}, "shimmer.modules.global_workspace.GlobalWorkspaceBase": {"qualname": 1, "fullname": 5, "annotation": 0, "default_value": 0, "signature": 0, "bases": 18, "doc": 20}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.gw_mod": {"qualname": 3, "fullname": 7, "annotation": 0, "default_value": 0, "signature": 0, "bases": 0, "doc": 8}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.selection_mod": {"qualname": 3, "fullname": 7, "annotation": 0, "default_value": 0, "signature": 0, "bases": 0, "doc": 8}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.loss_mod": {"qualname": 3, "fullname": 7, "annotation": 0, "default_value": 0, "signature": 0, "bases": 0, "doc": 10}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.optim_lr": {"qualname": 3, "fullname": 7, "annotation": 0, "default_value": 0, "signature": 0, "bases": 0, "doc": 3}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.optim_weight_decay": {"qualname": 4, "fullname": 8, "annotation": 0, "default_value": 0, "signature": 0, "bases": 0, "doc": 3}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.scheduler_args": {"qualname": 3, "fullname": 7, "annotation": 0, "default_value": 0, "signature": 0, "bases": 0, "doc": 3}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.domain_mods": {"qualname": 3, "fullname": 7, "annotation": 8, "default_value": 0, "signature": 0, "bases": 0, "doc": 3}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.workspace_dim": {"qualname": 3, "fullname": 7, 
"annotation": 2, "default_value": 0, "signature": 0, "bases": 0, "doc": 7}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.encode_and_fuse": {"qualname": 4, "fullname": 8, "annotation": 0, "default_value": 0, "signature": 125, "bases": 0, "doc": 65}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.encode": {"qualname": 2, "fullname": 6, "annotation": 0, "default_value": 0, "signature": 108, "bases": 0, "doc": 52}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.fuse": {"qualname": 2, "fullname": 6, "annotation": 0, "default_value": 0, "signature": 161, "bases": 0, "doc": 71}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.decode": {"qualname": 2, "fullname": 6, "annotation": 0, "default_value": 0, "signature": 128, "bases": 0, "doc": 68}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.batch_gw_states": {"qualname": 4, "fullname": 8, "annotation": 0, "default_value": 0, "signature": 95, "bases": 0, "doc": 55}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.encode_domain": {"qualname": 3, "fullname": 7, "annotation": 0, "default_value": 0, "signature": 39, "bases": 0, "doc": 122}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.encode_domains": {"qualname": 3, "fullname": 7, "annotation": 0, "default_value": 0, "signature": 108, "bases": 0, "doc": 59}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.decode_domain": {"qualname": 3, "fullname": 7, "annotation": 0, "default_value": 0, "signature": 39, "bases": 0, "doc": 122}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.decode_domains": {"qualname": 3, "fullname": 7, "annotation": 0, "default_value": 0, "signature": 109, "bases": 0, "doc": 59}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.generic_step": {"qualname": 3, "fullname": 7, "annotation": 0, "default_value": 0, "signature": 143, "bases": 0, "doc": 75}, "shimmer.modules.global_workspace.freeze_domain_modules": {"qualname": 3, "fullname": 7, "annotation": 0, "default_value": 0, "signature": 85, 
"bases": 0, "doc": 81}, "shimmer.modules.global_workspace.GWPredictions": {"qualname": 1, "fullname": 5, "annotation": 0, "default_value": 0, "signature": 0, "bases": 2, "doc": 13}, "shimmer.modules.global_workspace.GWPredictions.demi_cycles": {"qualname": 3, "fullname": 7, "annotation": 4, "default_value": 0, "signature": 0, "bases": 0, "doc": 21}, "shimmer.modules.global_workspace.GWPredictions.cycles": {"qualname": 2, "fullname": 6, "annotation": 5, "default_value": 0, "signature": 0, "bases": 0, "doc": 34}, "shimmer.modules.global_workspace.GWPredictions.translations": {"qualname": 2, "fullname": 6, "annotation": 5, "default_value": 0, "signature": 0, "bases": 0, "doc": 37}, "shimmer.modules.global_workspace.GWPredictions.states": {"qualname": 2, "fullname": 6, "annotation": 4, "default_value": 0, "signature": 0, "bases": 0, "doc": 3}, "shimmer.modules.global_workspace.GlobalWorkspace2Domains": {"qualname": 1, "fullname": 5, "annotation": 0, "default_value": 0, "signature": 0, "bases": 17, "doc": 33}, "shimmer.modules.global_workspace.GlobalWorkspace2Domains.__init__": {"qualname": 3, "fullname": 7, "annotation": 0, "default_value": 0, "signature": 380, "bases": 0, "doc": 250}, "shimmer.modules.global_workspace.GlobalWorkspace2Domains.forward": {"qualname": 2, "fullname": 6, "annotation": 0, "default_value": 0, "signature": 94, "bases": 0, "doc": 51}, "shimmer.modules.global_workspace.GlobalWorkspace": {"qualname": 1, "fullname": 5, "annotation": 0, "default_value": 0, "signature": 0, "bases": 17, "doc": 35}, "shimmer.modules.global_workspace.GlobalWorkspace.__init__": {"qualname": 3, "fullname": 7, "annotation": 0, "default_value": 0, "signature": 400, "bases": 0, "doc": 271}, "shimmer.modules.global_workspace.GlobalWorkspace.forward": {"qualname": 2, "fullname": 6, "annotation": 0, "default_value": 0, "signature": 94, "bases": 0, "doc": 51}, "shimmer.modules.global_workspace.GlobalWorkspaceBayesian": {"qualname": 1, "fullname": 5, "annotation": 0, 
"default_value": 0, "signature": 0, "bases": 17, "doc": 37}, "shimmer.modules.global_workspace.GlobalWorkspaceBayesian.__init__": {"qualname": 3, "fullname": 7, "annotation": 0, "default_value": 0, "signature": 459, "bases": 0, "doc": 318}, "shimmer.modules.global_workspace.GlobalWorkspaceBayesian.forward": {"qualname": 2, "fullname": 6, "annotation": 0, "default_value": 0, "signature": 94, "bases": 0, "doc": 51}, "shimmer.modules.global_workspace.pretrained_global_workspace": {"qualname": 3, "fullname": 7, "annotation": 0, "default_value": 0, "signature": 317, "bases": 0, "doc": 264}, "shimmer.modules.domain": {"qualname": 0, "fullname": 3, "annotation": 0, "default_value": 0, "signature": 0, "bases": 0, "doc": 3}, "shimmer.modules.domain.LossOutput": {"qualname": 1, "fullname": 4, "annotation": 0, "default_value": 0, "signature": 0, "bases": 0, "doc": 37}, "shimmer.modules.domain.LossOutput.__init__": {"qualname": 3, "fullname": 6, "annotation": 0, "default_value": 0, "signature": 61, "bases": 0, "doc": 3}, "shimmer.modules.domain.LossOutput.loss": {"qualname": 2, "fullname": 5, "annotation": 3, "default_value": 0, "signature": 0, "bases": 0, "doc": 7}, "shimmer.modules.domain.LossOutput.metrics": {"qualname": 2, "fullname": 5, "annotation": 4, "default_value": 0, "signature": 0, "bases": 0, "doc": 12}, "shimmer.modules.domain.LossOutput.all": {"qualname": 2, "fullname": 5, "annotation": 4, "default_value": 0, "signature": 0, "bases": 0, "doc": 14}, "shimmer.modules.domain.DomainModule": {"qualname": 1, "fullname": 4, "annotation": 0, "default_value": 0, "signature": 0, "bases": 5, "doc": 16}, "shimmer.modules.domain.DomainModule.__init__": {"qualname": 3, "fullname": 6, "annotation": 0, "default_value": 0, "signature": 15, "bases": 0, "doc": 29}, "shimmer.modules.domain.DomainModule.latent_dim": {"qualname": 3, "fullname": 6, "annotation": 0, "default_value": 0, "signature": 0, "bases": 0, "doc": 9}, "shimmer.modules.domain.DomainModule.encode": {"qualname": 2, 
"fullname": 5, "annotation": 0, "default_value": 0, "signature": 29, "bases": 0, "doc": 49}, "shimmer.modules.domain.DomainModule.decode": {"qualname": 2, "fullname": 5, "annotation": 0, "default_value": 0, "signature": 29, "bases": 0, "doc": 53}, "shimmer.modules.domain.DomainModule.compute_loss": {"qualname": 3, "fullname": 6, "annotation": 0, "default_value": 0, "signature": 62, "bases": 0, "doc": 61}, "shimmer.modules.domain.DomainModule.compute_dcy_loss": {"qualname": 4, "fullname": 7, "annotation": 0, "default_value": 0, "signature": 62, "bases": 0, "doc": 75}, "shimmer.modules.domain.DomainModule.compute_cy_loss": {"qualname": 4, "fullname": 7, "annotation": 0, "default_value": 0, "signature": 62, "bases": 0, "doc": 73}, "shimmer.modules.domain.DomainModule.compute_tr_loss": {"qualname": 4, "fullname": 7, "annotation": 0, "default_value": 0, "signature": 62, "bases": 0, "doc": 73}, "shimmer.modules.domain.DomainModule.compute_broadcast_loss": {"qualname": 4, "fullname": 7, "annotation": 0, "default_value": 0, "signature": 62, "bases": 0, "doc": 74}, "shimmer.modules.gw_module": {"qualname": 0, "fullname": 4, "annotation": 0, "default_value": 0, "signature": 0, "bases": 0, "doc": 3}, "shimmer.modules.gw_module.get_n_layers": {"qualname": 3, "fullname": 7, "annotation": 0, "default_value": 0, "signature": 57, "bases": 0, "doc": 76}, "shimmer.modules.gw_module.GWDecoder": {"qualname": 1, "fullname": 5, "annotation": 0, "default_value": 0, "signature": 0, "bases": 5, "doc": 8}, "shimmer.modules.gw_module.GWDecoder.__init__": {"qualname": 3, "fullname": 7, "annotation": 0, "default_value": 0, "signature": 48, "bases": 0, "doc": 81}, "shimmer.modules.gw_module.GWDecoder.in_dim": {"qualname": 3, "fullname": 7, "annotation": 0, "default_value": 0, "signature": 0, "bases": 0, "doc": 4}, "shimmer.modules.gw_module.GWDecoder.hidden_dim": {"qualname": 3, "fullname": 7, "annotation": 0, "default_value": 0, "signature": 0, "bases": 0, "doc": 4}, 
"shimmer.modules.gw_module.GWDecoder.out_dim": {"qualname": 3, "fullname": 7, "annotation": 0, "default_value": 0, "signature": 0, "bases": 0, "doc": 4}, "shimmer.modules.gw_module.GWDecoder.n_layers": {"qualname": 3, "fullname": 7, "annotation": 0, "default_value": 0, "signature": 0, "bases": 0, "doc": 24}, "shimmer.modules.gw_module.GWEncoder": {"qualname": 1, "fullname": 5, "annotation": 0, "default_value": 0, "signature": 0, "bases": 1, "doc": 27}, "shimmer.modules.gw_module.GWEncoder.__init__": {"qualname": 3, "fullname": 7, "annotation": 0, "default_value": 0, "signature": 48, "bases": 0, "doc": 81}, "shimmer.modules.gw_module.GWEncoder.forward": {"qualname": 2, "fullname": 6, "annotation": 0, "default_value": 0, "signature": 34, "bases": 0, "doc": 67}, "shimmer.modules.gw_module.GWEncoderLinear": {"qualname": 1, "fullname": 5, "annotation": 0, "default_value": 0, "signature": 0, "bases": 5, "doc": 10}, "shimmer.modules.gw_module.GWEncoderLinear.forward": {"qualname": 2, "fullname": 6, "annotation": 0, "default_value": 0, "signature": 34, "bases": 0, "doc": 67}, "shimmer.modules.gw_module.GWModuleBase": {"qualname": 1, "fullname": 5, "annotation": 0, "default_value": 0, "signature": 0, "bases": 7, "doc": 58}, "shimmer.modules.gw_module.GWModuleBase.__init__": {"qualname": 3, "fullname": 7, "annotation": 0, "default_value": 0, "signature": 81, "bases": 0, "doc": 43}, "shimmer.modules.gw_module.GWModuleBase.domain_mods": {"qualname": 3, "fullname": 7, "annotation": 0, "default_value": 0, "signature": 0, "bases": 0, "doc": 7}, "shimmer.modules.gw_module.GWModuleBase.workspace_dim": {"qualname": 3, "fullname": 7, "annotation": 0, "default_value": 0, "signature": 0, "bases": 0, "doc": 6}, "shimmer.modules.gw_module.GWModuleBase.fuse": {"qualname": 2, "fullname": 6, "annotation": 0, "default_value": 0, "signature": 97, "bases": 0, "doc": 69}, "shimmer.modules.gw_module.GWModuleBase.encode": {"qualname": 2, "fullname": 6, "annotation": 0, "default_value": 0, 
"signature": 70, "bases": 0, "doc": 50}, "shimmer.modules.gw_module.GWModuleBase.encode_and_fuse": {"qualname": 4, "fullname": 8, "annotation": 0, "default_value": 0, "signature": 85, "bases": 0, "doc": 78}, "shimmer.modules.gw_module.GWModuleBase.decode": {"qualname": 2, "fullname": 6, "annotation": 0, "default_value": 0, "signature": 88, "bases": 0, "doc": 65}, "shimmer.modules.gw_module.GWModule": {"qualname": 1, "fullname": 5, "annotation": 0, "default_value": 0, "signature": 0, "bases": 1, "doc": 10}, "shimmer.modules.gw_module.GWModule.__init__": {"qualname": 3, "fullname": 7, "annotation": 0, "default_value": 0, "signature": 173, "bases": 0, "doc": 117}, "shimmer.modules.gw_module.GWModule.gw_encoders": {"qualname": 3, "fullname": 7, "annotation": 0, "default_value": 0, "signature": 0, "bases": 0, "doc": 6}, "shimmer.modules.gw_module.GWModule.gw_decoders": {"qualname": 3, "fullname": 7, "annotation": 0, "default_value": 0, "signature": 0, "bases": 0, "doc": 6}, "shimmer.modules.gw_module.GWModule.fuse": {"qualname": 2, "fullname": 6, "annotation": 0, "default_value": 0, "signature": 97, "bases": 0, "doc": 69}, "shimmer.modules.gw_module.GWModule.encode": {"qualname": 2, "fullname": 6, "annotation": 0, "default_value": 0, "signature": 70, "bases": 0, "doc": 50}, "shimmer.modules.gw_module.GWModule.decode": {"qualname": 2, "fullname": 6, "annotation": 0, "default_value": 0, "signature": 88, "bases": 0, "doc": 73}, "shimmer.modules.gw_module.compute_fusion_scores": {"qualname": 3, "fullname": 7, "annotation": 0, "default_value": 0, "signature": 107, "bases": 0, "doc": 119}, "shimmer.modules.gw_module.GWModuleBayesian": {"qualname": 1, "fullname": 5, "annotation": 0, "default_value": 0, "signature": 0, "bases": 1, "doc": 12}, "shimmer.modules.gw_module.GWModuleBayesian.__init__": {"qualname": 3, "fullname": 7, "annotation": 0, "default_value": 0, "signature": 232, "bases": 0, "doc": 163}, "shimmer.modules.gw_module.GWModuleBayesian.precisions": {"qualname": 2, 
"fullname": 6, "annotation": 0, "default_value": 0, "signature": 0, "bases": 0, "doc": 11}, "shimmer.modules.gw_module.GWModuleBayesian.sensitivity_selection": {"qualname": 3, "fullname": 7, "annotation": 0, "default_value": 0, "signature": 0, "bases": 0, "doc": 3}, "shimmer.modules.gw_module.GWModuleBayesian.sensitivity_precision": {"qualname": 3, "fullname": 7, "annotation": 0, "default_value": 0, "signature": 0, "bases": 0, "doc": 3}, "shimmer.modules.gw_module.GWModuleBayesian.precision_softmax_temp": {"qualname": 4, "fullname": 8, "annotation": 0, "default_value": 0, "signature": 0, "bases": 0, "doc": 3}, "shimmer.modules.gw_module.GWModuleBayesian.get_precision": {"qualname": 3, "fullname": 7, "annotation": 0, "default_value": 0, "signature": 44, "bases": 0, "doc": 57}, "shimmer.modules.gw_module.GWModuleBayesian.fuse": {"qualname": 2, "fullname": 6, "annotation": 0, "default_value": 0, "signature": 97, "bases": 0, "doc": 296}, "shimmer.modules.selection": {"qualname": 0, "fullname": 3, "annotation": 0, "default_value": 0, "signature": 0, "bases": 0, "doc": 3}, "shimmer.modules.selection.SelectionBase": {"qualname": 1, "fullname": 4, "annotation": 0, "default_value": 0, "signature": 0, "bases": 7, "doc": 29}, "shimmer.modules.selection.SelectionBase.update_gw_state": {"qualname": 4, "fullname": 7, "annotation": 0, "default_value": 0, "signature": 30, "bases": 0, "doc": 66}, "shimmer.modules.selection.SelectionBase.forward": {"qualname": 2, "fullname": 5, "annotation": 0, "default_value": 0, "signature": 110, "bases": 0, "doc": 202}, "shimmer.modules.selection.SingleDomainSelection": {"qualname": 1, "fullname": 4, "annotation": 0, "default_value": 0, "signature": 0, "bases": 1, "doc": 48}, "shimmer.modules.selection.SingleDomainSelection.forward": {"qualname": 2, "fullname": 5, "annotation": 0, "default_value": 0, "signature": 110, "bases": 0, "doc": 69}, "shimmer.modules.selection.FixedSharedSelection": {"qualname": 1, "fullname": 4, "annotation": 0, 
"default_value": 0, "signature": 0, "bases": 1, "doc": 35}, "shimmer.modules.selection.FixedSharedSelection.forward": {"qualname": 2, "fullname": 5, "annotation": 0, "default_value": 0, "signature": 110, "bases": 0, "doc": 69}, "shimmer.modules.selection.KQFixedQSelection": {"qualname": 1, "fullname": 4, "annotation": 0, "default_value": 0, "signature": 0, "bases": 1, "doc": 11}, "shimmer.modules.selection.KQFixedQSelection.__init__": {"qualname": 3, "fullname": 6, "annotation": 0, "default_value": 0, "signature": 56, "bases": 0, "doc": 62}, "shimmer.modules.selection.KQFixedQSelection.head_size": {"qualname": 3, "fullname": 6, "annotation": 0, "default_value": 0, "signature": 0, "bases": 0, "doc": 3}, "shimmer.modules.selection.KQFixedQSelection.query_layer": {"qualname": 3, "fullname": 6, "annotation": 0, "default_value": 0, "signature": 0, "bases": 0, "doc": 3}, "shimmer.modules.selection.KQFixedQSelection.key_layers": {"qualname": 3, "fullname": 6, "annotation": 0, "default_value": 0, "signature": 0, "bases": 0, "doc": 3}, "shimmer.modules.selection.KQFixedQSelection.forward": {"qualname": 2, "fullname": 5, "annotation": 0, "default_value": 0, "signature": 110, "bases": 0, "doc": 89}, "shimmer.modules.selection.RandomSelection": {"qualname": 1, "fullname": 4, "annotation": 0, "default_value": 0, "signature": 0, "bases": 1, "doc": 37}, "shimmer.modules.selection.RandomSelection.__init__": {"qualname": 3, "fullname": 6, "annotation": 0, "default_value": 0, "signature": 14, "bases": 0, "doc": 26}, "shimmer.modules.selection.RandomSelection.temperature": {"qualname": 2, "fullname": 5, "annotation": 0, "default_value": 0, "signature": 0, "bases": 0, "doc": 3}, "shimmer.modules.selection.RandomSelection.forward": {"qualname": 2, "fullname": 5, "annotation": 0, "default_value": 0, "signature": 110, "bases": 0, "doc": 88}, "shimmer.modules.selection.DynamicQueryAttention": {"qualname": 1, "fullname": 4, "annotation": 0, "default_value": 0, "signature": 0, "bases": 1, 
"doc": 21}, "shimmer.modules.selection.DynamicQueryAttention.__init__": {"qualname": 3, "fullname": 6, "annotation": 0, "default_value": 0, "signature": 56, "bases": 0, "doc": 62}, "shimmer.modules.selection.DynamicQueryAttention.head_size": {"qualname": 3, "fullname": 6, "annotation": 0, "default_value": 0, "signature": 0, "bases": 0, "doc": 3}, "shimmer.modules.selection.DynamicQueryAttention.query_layer": {"qualname": 3, "fullname": 6, "annotation": 0, "default_value": 0, "signature": 0, "bases": 0, "doc": 3}, "shimmer.modules.selection.DynamicQueryAttention.key_layers": {"qualname": 3, "fullname": 6, "annotation": 0, "default_value": 0, "signature": 0, "bases": 0, "doc": 3}, "shimmer.modules.selection.DynamicQueryAttention.fuse_weighted_encodings": {"qualname": 4, "fullname": 7, "annotation": 0, "default_value": 0, "signature": 87, "bases": 0, "doc": 69}, "shimmer.modules.selection.DynamicQueryAttention.forward": {"qualname": 2, "fullname": 5, "annotation": 0, "default_value": 0, "signature": 110, "bases": 0, "doc": 89}, "shimmer.modules.losses": {"qualname": 0, "fullname": 3, "annotation": 0, "default_value": 0, "signature": 0, "bases": 0, "doc": 3}, "shimmer.modules.losses.GWLossesBase": {"qualname": 1, "fullname": 4, "annotation": 0, "default_value": 0, "signature": 0, "bases": 7, "doc": 30}, "shimmer.modules.losses.GWLossesBase.step": {"qualname": 2, "fullname": 5, "annotation": 0, "default_value": 0, "signature": 154, "bases": 0, "doc": 57}, "shimmer.modules.losses.demi_cycle_loss": {"qualname": 3, "fullname": 6, "annotation": 0, "default_value": 0, "signature": 193, "bases": 0, "doc": 175}, "shimmer.modules.losses.cycle_loss": {"qualname": 2, "fullname": 5, "annotation": 0, "default_value": 0, "signature": 193, "bases": 0, "doc": 172}, "shimmer.modules.losses.translation_loss": {"qualname": 2, "fullname": 5, "annotation": 0, "default_value": 0, "signature": 193, "bases": 0, "doc": 159}, "shimmer.modules.losses.contrastive_loss": {"qualname": 2, 
"fullname": 5, "annotation": 0, "default_value": 0, "signature": 182, "bases": 0, "doc": 155}, "shimmer.modules.losses.contrastive_loss_bayesian": {"qualname": 3, "fullname": 6, "annotation": 0, "default_value": 0, "signature": 182, "bases": 0, "doc": 161}, "shimmer.modules.losses.LossCoefs": {"qualname": 1, "fullname": 4, "annotation": 0, "default_value": 0, "signature": 0, "bases": 2, "doc": 51}, "shimmer.modules.losses.LossCoefs.demi_cycles": {"qualname": 3, "fullname": 6, "annotation": 2, "default_value": 0, "signature": 0, "bases": 0, "doc": 7}, "shimmer.modules.losses.LossCoefs.cycles": {"qualname": 2, "fullname": 5, "annotation": 2, "default_value": 0, "signature": 0, "bases": 0, "doc": 6}, "shimmer.modules.losses.LossCoefs.translations": {"qualname": 2, "fullname": 5, "annotation": 2, "default_value": 0, "signature": 0, "bases": 0, "doc": 6}, "shimmer.modules.losses.LossCoefs.contrastives": {"qualname": 2, "fullname": 5, "annotation": 2, "default_value": 0, "signature": 0, "bases": 0, "doc": 6}, "shimmer.modules.losses.GWLosses2Domains": {"qualname": 1, "fullname": 4, "annotation": 0, "default_value": 0, "signature": 0, "bases": 1, "doc": 13}, "shimmer.modules.losses.GWLosses2Domains.__init__": {"qualname": 3, "fullname": 6, "annotation": 0, "default_value": 0, "signature": 190, "bases": 0, "doc": 107}, "shimmer.modules.losses.GWLosses2Domains.gw_mod": {"qualname": 3, "fullname": 6, "annotation": 0, "default_value": 0, "signature": 0, "bases": 0, "doc": 3}, "shimmer.modules.losses.GWLosses2Domains.selection_mod": {"qualname": 3, "fullname": 6, "annotation": 0, "default_value": 0, "signature": 0, "bases": 0, "doc": 3}, "shimmer.modules.losses.GWLosses2Domains.domain_mods": {"qualname": 3, "fullname": 6, "annotation": 0, "default_value": 0, "signature": 0, "bases": 0, "doc": 3}, "shimmer.modules.losses.GWLosses2Domains.loss_coefs": {"qualname": 3, "fullname": 6, "annotation": 0, "default_value": 0, "signature": 0, "bases": 0, "doc": 3}, 
"shimmer.modules.losses.GWLosses2Domains.contrastive_fn": {"qualname": 3, "fullname": 6, "annotation": 0, "default_value": 0, "signature": 0, "bases": 0, "doc": 3}, "shimmer.modules.losses.GWLosses2Domains.demi_cycle_loss": {"qualname": 4, "fullname": 7, "annotation": 0, "default_value": 0, "signature": 95, "bases": 0, "doc": 60}, "shimmer.modules.losses.GWLosses2Domains.cycle_loss": {"qualname": 3, "fullname": 6, "annotation": 0, "default_value": 0, "signature": 95, "bases": 0, "doc": 58}, "shimmer.modules.losses.GWLosses2Domains.translation_loss": {"qualname": 3, "fullname": 6, "annotation": 0, "default_value": 0, "signature": 95, "bases": 0, "doc": 58}, "shimmer.modules.losses.GWLosses2Domains.contrastive_loss": {"qualname": 3, "fullname": 6, "annotation": 0, "default_value": 0, "signature": 95, "bases": 0, "doc": 58}, "shimmer.modules.losses.GWLosses2Domains.step": {"qualname": 2, "fullname": 5, "annotation": 0, "default_value": 0, "signature": 154, "bases": 0, "doc": 109}, "shimmer.modules.losses.generate_partitions": {"qualname": 2, "fullname": 5, "annotation": 0, "default_value": 0, "signature": 58, "bases": 0, "doc": 71}, "shimmer.modules.losses.broadcast_loss": {"qualname": 2, "fullname": 5, "annotation": 0, "default_value": 0, "signature": 193, "bases": 0, "doc": 99}, "shimmer.modules.losses.BroadcastLossCoefs": {"qualname": 1, "fullname": 4, "annotation": 0, "default_value": 0, "signature": 0, "bases": 2, "doc": 51}, "shimmer.modules.losses.BroadcastLossCoefs.contrastives": {"qualname": 2, "fullname": 5, "annotation": 2, "default_value": 0, "signature": 0, "bases": 0, "doc": 6}, "shimmer.modules.losses.BroadcastLossCoefs.fused": {"qualname": 2, "fullname": 5, "annotation": 2, "default_value": 0, "signature": 0, "bases": 0, "doc": 15}, "shimmer.modules.losses.BroadcastLossCoefs.demi_cycles": {"qualname": 3, "fullname": 6, "annotation": 2, "default_value": 0, "signature": 0, "bases": 0, "doc": 13}, "shimmer.modules.losses.BroadcastLossCoefs.cycles": 
{"qualname": 2, "fullname": 5, "annotation": 2, "default_value": 0, "signature": 0, "bases": 0, "doc": 11}, "shimmer.modules.losses.BroadcastLossCoefs.translations": {"qualname": 2, "fullname": 5, "annotation": 2, "default_value": 0, "signature": 0, "bases": 0, "doc": 14}, "shimmer.modules.losses.GWLosses": {"qualname": 1, "fullname": 4, "annotation": 0, "default_value": 0, "signature": 0, "bases": 1, "doc": 12}, "shimmer.modules.losses.GWLosses.__init__": {"qualname": 3, "fullname": 6, "annotation": 0, "default_value": 0, "signature": 190, "bases": 0, "doc": 91}, "shimmer.modules.losses.GWLosses.gw_mod": {"qualname": 3, "fullname": 6, "annotation": 0, "default_value": 0, "signature": 0, "bases": 0, "doc": 3}, "shimmer.modules.losses.GWLosses.selection_mod": {"qualname": 3, "fullname": 6, "annotation": 0, "default_value": 0, "signature": 0, "bases": 0, "doc": 3}, "shimmer.modules.losses.GWLosses.domain_mods": {"qualname": 3, "fullname": 6, "annotation": 0, "default_value": 0, "signature": 0, "bases": 0, "doc": 3}, "shimmer.modules.losses.GWLosses.loss_coefs": {"qualname": 3, "fullname": 6, "annotation": 0, "default_value": 0, "signature": 0, "bases": 0, "doc": 3}, "shimmer.modules.losses.GWLosses.contrastive_fn": {"qualname": 3, "fullname": 6, "annotation": 0, "default_value": 0, "signature": 0, "bases": 0, "doc": 3}, "shimmer.modules.losses.GWLosses.contrastive_loss": {"qualname": 3, "fullname": 6, "annotation": 0, "default_value": 0, "signature": 95, "bases": 0, "doc": 46}, "shimmer.modules.losses.GWLosses.broadcast_loss": {"qualname": 3, "fullname": 6, "annotation": 0, "default_value": 0, "signature": 95, "bases": 0, "doc": 3}, "shimmer.modules.losses.GWLosses.step": {"qualname": 2, "fullname": 5, "annotation": 0, "default_value": 0, "signature": 154, "bases": 0, "doc": 64}, "shimmer.modules.losses.GWLossesBayesian": {"qualname": 1, "fullname": 4, "annotation": 0, "default_value": 0, "signature": 0, "bases": 1, "doc": 13}, 
"shimmer.modules.losses.GWLossesBayesian.__init__": {"qualname": 3, "fullname": 6, "annotation": 0, "default_value": 0, "signature": 210, "bases": 0, "doc": 120}, "shimmer.modules.losses.GWLossesBayesian.gw_mod": {"qualname": 3, "fullname": 6, "annotation": 0, "default_value": 0, "signature": 0, "bases": 0, "doc": 5}, "shimmer.modules.losses.GWLossesBayesian.selection_mod": {"qualname": 3, "fullname": 6, "annotation": 0, "default_value": 0, "signature": 0, "bases": 0, "doc": 4}, "shimmer.modules.losses.GWLossesBayesian.domain_mods": {"qualname": 3, "fullname": 6, "annotation": 0, "default_value": 0, "signature": 0, "bases": 0, "doc": 9}, "shimmer.modules.losses.GWLossesBayesian.loss_coefs": {"qualname": 3, "fullname": 6, "annotation": 0, "default_value": 0, "signature": 0, "bases": 0, "doc": 6}, "shimmer.modules.losses.GWLossesBayesian.contrastive_fn": {"qualname": 3, "fullname": 6, "annotation": 0, "default_value": 0, "signature": 0, "bases": 0, "doc": 7}, "shimmer.modules.losses.GWLossesBayesian.use_normalized_constrastive": {"qualname": 4, "fullname": 7, "annotation": 0, "default_value": 0, "signature": 0, "bases": 0, "doc": 3}, "shimmer.modules.losses.GWLossesBayesian.contrastive_loss": {"qualname": 3, "fullname": 6, "annotation": 0, "default_value": 0, "signature": 95, "bases": 0, "doc": 45}, "shimmer.modules.losses.GWLossesBayesian.broadcast_loss": {"qualname": 3, "fullname": 6, "annotation": 0, "default_value": 0, "signature": 95, "bases": 0, "doc": 3}, "shimmer.modules.losses.GWLossesBayesian.step": {"qualname": 2, "fullname": 5, "annotation": 0, "default_value": 0, "signature": 154, "bases": 0, "doc": 64}, "shimmer.modules.contrastive_loss": {"qualname": 0, "fullname": 4, "annotation": 0, "default_value": 0, "signature": 0, "bases": 0, "doc": 6}, "shimmer.modules.contrastive_loss.ContrastiveLossType": {"qualname": 1, "fullname": 5, "annotation": 0, "default_value": 10, "signature": 0, "bases": 0, "doc": 21}, 
"shimmer.modules.contrastive_loss.ContrastiveLossBayesianType": {"qualname": 1, "fullname": 5, "annotation": 0, "default_value": 14, "signature": 0, "bases": 0, "doc": 29}, "shimmer.modules.contrastive_loss.info_nce": {"qualname": 2, "fullname": 6, "annotation": 0, "default_value": 0, "signature": 115, "bases": 0, "doc": 68}, "shimmer.modules.contrastive_loss.contrastive_loss": {"qualname": 2, "fullname": 6, "annotation": 0, "default_value": 0, "signature": 115, "bases": 0, "doc": 70}, "shimmer.modules.contrastive_loss.ContrastiveLoss": {"qualname": 1, "fullname": 5, "annotation": 0, "default_value": 0, "signature": 0, "bases": 5, "doc": 8}, "shimmer.modules.contrastive_loss.ContrastiveLoss.__init__": {"qualname": 3, "fullname": 7, "annotation": 0, "default_value": 0, "signature": 93, "bases": 0, "doc": 83}, "shimmer.modules.contrastive_loss.ContrastiveLoss.learn_logit_scale": {"qualname": 4, "fullname": 8, "annotation": 0, "default_value": 0, "signature": 0, "bases": 0, "doc": 3}, "shimmer.modules.contrastive_loss.ContrastiveLoss.reduction": {"qualname": 2, "fullname": 6, "annotation": 12, "default_value": 0, "signature": 0, "bases": 0, "doc": 3}, "shimmer.modules.contrastive_loss.ContrastiveLoss.forward": {"qualname": 2, "fullname": 6, "annotation": 0, "default_value": 0, "signature": 62, "bases": 0, "doc": 56}, "shimmer.dataset": {"qualname": 0, "fullname": 2, "annotation": 0, "default_value": 0, "signature": 0, "bases": 0, "doc": 3}, "shimmer.dataset.RepeatedDataset": {"qualname": 1, "fullname": 3, "annotation": 0, "default_value": 0, "signature": 0, "bases": 3, "doc": 46}, "shimmer.dataset.RepeatedDataset.__init__": {"qualname": 3, "fullname": 5, "annotation": 0, "default_value": 0, "signature": 57, "bases": 0, "doc": 63}, "shimmer.dataset.RepeatedDataset.dataset": {"qualname": 2, "fullname": 4, "annotation": 0, "default_value": 0, "signature": 0, "bases": 0, "doc": 3}, "shimmer.dataset.RepeatedDataset.dataset_size": {"qualname": 3, "fullname": 5, 
"annotation": 0, "default_value": 0, "signature": 0, "bases": 0, "doc": 3}, "shimmer.modules.vae": {"qualname": 0, "fullname": 3, "annotation": 0, "default_value": 0, "signature": 0, "bases": 0, "doc": 3}, "shimmer.modules.vae.reparameterize": {"qualname": 1, "fullname": 4, "annotation": 0, "default_value": 0, "signature": 44, "bases": 0, "doc": 65}, "shimmer.modules.vae.kl_divergence_loss": {"qualname": 3, "fullname": 6, "annotation": 0, "default_value": 0, "signature": 44, "bases": 0, "doc": 57}, "shimmer.modules.vae.gaussian_nll": {"qualname": 2, "fullname": 5, "annotation": 0, "default_value": 0, "signature": 63, "bases": 0, "doc": 69}, "shimmer.modules.vae.VAEEncoder": {"qualname": 1, "fullname": 4, "annotation": 0, "default_value": 0, "signature": 0, "bases": 7, "doc": 9}, "shimmer.modules.vae.VAEEncoder.forward": {"qualname": 2, "fullname": 5, "annotation": 0, "default_value": 0, "signature": 46, "bases": 0, "doc": 46}, "shimmer.modules.vae.VAEDecoder": {"qualname": 1, "fullname": 4, "annotation": 0, "default_value": 0, "signature": 0, "bases": 7, "doc": 9}, "shimmer.modules.vae.VAEDecoder.forward": {"qualname": 2, "fullname": 5, "annotation": 0, "default_value": 0, "signature": 29, "bases": 0, "doc": 42}, "shimmer.modules.vae.VAE": {"qualname": 1, "fullname": 4, "annotation": 0, "default_value": 0, "signature": 0, "bases": 5, "doc": 4}, "shimmer.modules.vae.VAE.__init__": {"qualname": 3, "fullname": 6, "annotation": 0, "default_value": 0, "signature": 74, "bases": 0, "doc": 53}, "shimmer.modules.vae.VAE.beta": {"qualname": 2, "fullname": 5, "annotation": 0, "default_value": 0, "signature": 0, "bases": 0, "doc": 7}, "shimmer.modules.vae.VAE.encoder": {"qualname": 2, "fullname": 5, "annotation": 0, "default_value": 0, "signature": 0, "bases": 0, "doc": 4}, "shimmer.modules.vae.VAE.decoder": {"qualname": 2, "fullname": 5, "annotation": 0, "default_value": 0, "signature": 0, "bases": 0, "doc": 4}, "shimmer.modules.vae.VAE.encode": {"qualname": 2, "fullname": 5, 
"annotation": 0, "default_value": 0, "signature": 29, "bases": 0, "doc": 49}, "shimmer.modules.vae.VAE.decode": {"qualname": 2, "fullname": 5, "annotation": 0, "default_value": 0, "signature": 29, "bases": 0, "doc": 49}, "shimmer.modules.vae.VAE.forward": {"qualname": 2, "fullname": 5, "annotation": 0, "default_value": 0, "signature": 61, "bases": 0, "doc": 63}, "shimmer.modules.utils": {"qualname": 0, "fullname": 3, "annotation": 0, "default_value": 0, "signature": 0, "bases": 0, "doc": 3}, "shimmer.modules.utils.translation": {"qualname": 1, "fullname": 4, "annotation": 0, "default_value": 0, "signature": 118, "bases": 0, "doc": 96}, "shimmer.modules.utils.cycle": {"qualname": 1, "fullname": 4, "annotation": 0, "default_value": 0, "signature": 130, "bases": 0, "doc": 111}, "shimmer.modules.utils.batch_demi_cycles": {"qualname": 3, "fullname": 6, "annotation": 0, "default_value": 0, "signature": 144, "bases": 0, "doc": 81}, "shimmer.modules.utils.batch_cycles": {"qualname": 2, "fullname": 5, "annotation": 0, "default_value": 0, "signature": 182, "bases": 0, "doc": 115}, "shimmer.modules.utils.batch_translations": {"qualname": 2, "fullname": 5, "annotation": 0, "default_value": 0, "signature": 154, "bases": 0, "doc": 88}, "shimmer.utils": {"qualname": 0, "fullname": 2, "annotation": 0, "default_value": 0, "signature": 0, "bases": 0, "doc": 3}, "shimmer.utils.MIGRATION_DIR": {"qualname": 2, "fullname": 4, "annotation": 0, "default_value": 6, "signature": 0, "bases": 0, "doc": 3}, "shimmer.utils.group_batch_size": {"qualname": 3, "fullname": 5, "annotation": 0, "default_value": 0, "signature": 46, "bases": 0, "doc": 3}, "shimmer.utils.groups_batch_size": {"qualname": 3, "fullname": 5, "annotation": 0, "default_value": 0, "signature": 72, "bases": 0, "doc": 46}, "shimmer.utils.groups_device": {"qualname": 2, "fullname": 4, "annotation": 0, "default_value": 0, "signature": 72, "bases": 0, "doc": 46}, "shimmer.utils.group_device": {"qualname": 2, "fullname": 4, 
"annotation": 0, "default_value": 0, "signature": 51, "bases": 0, "doc": 3}, "shimmer.utils.migrate_model": {"qualname": 2, "fullname": 4, "annotation": 0, "default_value": 0, "signature": 37, "bases": 0, "doc": 67}, "shimmer.utils.SaveMigrations": {"qualname": 1, "fullname": 3, "annotation": 0, "default_value": 0, "signature": 0, "bases": 5, "doc": 23}, "shimmer.utils.SaveMigrations.migrations": {"qualname": 2, "fullname": 4, "annotation": 0, "default_value": 0, "signature": 0, "bases": 0, "doc": 3}, "shimmer.utils.SaveMigrations.on_save_checkpoint": {"qualname": 4, "fullname": 6, "annotation": 0, "default_value": 0, "signature": 103, "bases": 0, "doc": 74}, "shimmer.cli.ckpt_migration": {"qualname": 0, "fullname": 4, "annotation": 0, "default_value": 0, "signature": 0, "bases": 0, "doc": 3}, "shimmer.cli.ckpt_migration.migrate_ckpt": {"qualname": 2, "fullname": 6, "annotation": 0, "default_value": 7, "signature": 0, "bases": 0, "doc": 72}}, "length": 232, "save": true}, "index": {"qualname": {"root": {"docs": {"shimmer.modules.global_workspace.GlobalWorkspace2Domains.__init__": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspace.__init__": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBayesian.__init__": {"tf": 1}, "shimmer.modules.domain.LossOutput.__init__": {"tf": 1}, "shimmer.modules.domain.DomainModule.__init__": {"tf": 1}, "shimmer.modules.gw_module.GWDecoder.__init__": {"tf": 1}, "shimmer.modules.gw_module.GWEncoder.__init__": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBase.__init__": {"tf": 1}, "shimmer.modules.gw_module.GWModule.__init__": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBayesian.__init__": {"tf": 1}, "shimmer.modules.selection.KQFixedQSelection.__init__": {"tf": 1}, "shimmer.modules.selection.RandomSelection.__init__": {"tf": 1}, "shimmer.modules.selection.DynamicQueryAttention.__init__": {"tf": 1}, "shimmer.modules.losses.GWLosses2Domains.__init__": {"tf": 1}, "shimmer.modules.losses.GWLosses.__init__": {"tf": 
1}, "shimmer.modules.losses.GWLossesBayesian.__init__": {"tf": 1}, "shimmer.modules.contrastive_loss.ContrastiveLoss.__init__": {"tf": 1}, "shimmer.dataset.RepeatedDataset.__init__": {"tf": 1}, "shimmer.modules.vae.VAE.__init__": {"tf": 1}}, "df": 19, "r": {"docs": {}, "df": 0, "a": {"docs": {}, "df": 0, "w": {"docs": {}, "df": 0, "d": {"docs": {}, "df": 0, "o": {"docs": {}, "df": 0, "m": {"docs": {}, "df": 0, "a": {"docs": {}, "df": 0, "i": {"docs": {}, "df": 0, "n": {"docs": {}, "df": 0, "g": {"docs": {}, "df": 0, "r": {"docs": {}, "df": 0, "o": {"docs": {}, "df": 0, "u": {"docs": {}, "df": 0, "p": {"docs": {}, "df": 0, "t": {"docs": {"shimmer.types.RawDomainGroupT": {"tf": 1}}, "df": 1}, "d": {"docs": {}, "df": 0, "t": {"docs": {"shimmer.types.RawDomainGroupDT": {"tf": 1}}, "df": 1}}, "s": {"docs": {}, "df": 0, "t": {"docs": {"shimmer.types.RawDomainGroupsT": {"tf": 1}}, "df": 1}, "d": {"docs": {}, "df": 0, "t": {"docs": {"shimmer.types.RawDomainGroupsDT": {"tf": 1}}, "df": 1}}}}}}}}}}}}}}}, "n": {"docs": {}, "df": 0, "d": {"docs": {}, "df": 0, "o": {"docs": {}, "df": 0, "m": {"docs": {}, "df": 0, "s": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "l": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "c": {"docs": {}, "df": 0, "t": {"docs": {}, "df": 0, "i": {"docs": {}, "df": 0, "o": {"docs": {}, "df": 0, "n": {"docs": {"shimmer.modules.selection.RandomSelection": {"tf": 1}, "shimmer.modules.selection.RandomSelection.__init__": {"tf": 1}, "shimmer.modules.selection.RandomSelection.temperature": {"tf": 1}, "shimmer.modules.selection.RandomSelection.forward": {"tf": 1}}, "df": 4}}}}}}}}}}}}}}, "e": {"docs": {}, "df": 0, "d": {"docs": {}, "df": 0, "u": {"docs": {}, "df": 0, "c": {"docs": {}, "df": 0, "t": {"docs": {}, "df": 0, "i": {"docs": {}, "df": 0, "o": {"docs": {}, "df": 0, "n": {"docs": {"shimmer.modules.contrastive_loss.ContrastiveLoss.reduction": {"tf": 1}}, "df": 1}}}}}}}, "p": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "a": {"docs": {}, "df": 
0, "t": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "d": {"docs": {}, "df": 0, "d": {"docs": {}, "df": 0, "a": {"docs": {}, "df": 0, "t": {"docs": {}, "df": 0, "a": {"docs": {}, "df": 0, "s": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "t": {"docs": {"shimmer.dataset.RepeatedDataset": {"tf": 1}, "shimmer.dataset.RepeatedDataset.__init__": {"tf": 1}, "shimmer.dataset.RepeatedDataset.dataset": {"tf": 1}, "shimmer.dataset.RepeatedDataset.dataset_size": {"tf": 1}}, "df": 4}}}}}}}}}}}}, "a": {"docs": {}, "df": 0, "r": {"docs": {}, "df": 0, "a": {"docs": {}, "df": 0, "m": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "t": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "r": {"docs": {}, "df": 0, "i": {"docs": {}, "df": 0, "z": {"docs": {}, "df": 0, "e": {"docs": {"shimmer.modules.vae.reparameterize": {"tf": 1}}, "df": 1}}}}}}}}}}}}}}, "l": {"docs": {}, "df": 0, "a": {"docs": {}, "df": 0, "t": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "n": {"docs": {}, "df": 0, "t": {"docs": {"shimmer.modules.domain.DomainModule.latent_dim": {"tf": 1}}, "df": 1, "s": {"docs": {}, "df": 0, "d": {"docs": {}, "df": 0, "o": {"docs": {}, "df": 0, "m": {"docs": {}, "df": 0, "a": {"docs": {}, "df": 0, "i": {"docs": {}, "df": 0, "n": {"docs": {}, "df": 0, "g": {"docs": {}, "df": 0, "r": {"docs": {}, "df": 0, "o": {"docs": {}, "df": 0, "u": {"docs": {}, "df": 0, "p": {"docs": {}, "df": 0, "t": {"docs": {"shimmer.types.LatentsDomainGroupT": {"tf": 1}}, "df": 1}, "d": {"docs": {}, "df": 0, "t": {"docs": {"shimmer.types.LatentsDomainGroupDT": {"tf": 1}}, "df": 1}}, "s": {"docs": {}, "df": 0, "t": {"docs": {"shimmer.types.LatentsDomainGroupsT": {"tf": 1}}, "df": 1}, "d": {"docs": {}, "df": 0, "t": {"docs": {"shimmer.types.LatentsDomainGroupsDT": {"tf": 1}}, "df": 1}}}}}}}}}}}}}}}}}}}, "y": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "r": {"docs": {"shimmer.modules.selection.KQFixedQSelection.query_layer": {"tf": 1}, 
"shimmer.modules.selection.DynamicQueryAttention.query_layer": {"tf": 1}}, "df": 2, "s": {"docs": {"shimmer.modules.gw_module.get_n_layers": {"tf": 1}, "shimmer.modules.gw_module.GWDecoder.n_layers": {"tf": 1}, "shimmer.modules.selection.KQFixedQSelection.key_layers": {"tf": 1}, "shimmer.modules.selection.DynamicQueryAttention.key_layers": {"tf": 1}}, "df": 4}}}}}, "r": {"docs": {"shimmer.modules.global_workspace.SchedulerArgs.max_lr": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.optim_lr": {"tf": 1}}, "df": 2}, "o": {"docs": {}, "df": 0, "s": {"docs": {}, "df": 0, "s": {"docs": {"shimmer.modules.global_workspace.GlobalWorkspaceBase.loss_mod": {"tf": 1}, "shimmer.modules.domain.LossOutput.loss": {"tf": 1}, "shimmer.modules.domain.DomainModule.compute_loss": {"tf": 1}, "shimmer.modules.domain.DomainModule.compute_dcy_loss": {"tf": 1}, "shimmer.modules.domain.DomainModule.compute_cy_loss": {"tf": 1}, "shimmer.modules.domain.DomainModule.compute_tr_loss": {"tf": 1}, "shimmer.modules.domain.DomainModule.compute_broadcast_loss": {"tf": 1}, "shimmer.modules.losses.demi_cycle_loss": {"tf": 1}, "shimmer.modules.losses.cycle_loss": {"tf": 1}, "shimmer.modules.losses.translation_loss": {"tf": 1}, "shimmer.modules.losses.contrastive_loss": {"tf": 1}, "shimmer.modules.losses.contrastive_loss_bayesian": {"tf": 1}, "shimmer.modules.losses.GWLosses2Domains.loss_coefs": {"tf": 1}, "shimmer.modules.losses.GWLosses2Domains.demi_cycle_loss": {"tf": 1}, "shimmer.modules.losses.GWLosses2Domains.cycle_loss": {"tf": 1}, "shimmer.modules.losses.GWLosses2Domains.translation_loss": {"tf": 1}, "shimmer.modules.losses.GWLosses2Domains.contrastive_loss": {"tf": 1}, "shimmer.modules.losses.broadcast_loss": {"tf": 1}, "shimmer.modules.losses.GWLosses.loss_coefs": {"tf": 1}, "shimmer.modules.losses.GWLosses.contrastive_loss": {"tf": 1}, "shimmer.modules.losses.GWLosses.broadcast_loss": {"tf": 1}, "shimmer.modules.losses.GWLossesBayesian.loss_coefs": {"tf": 1}, 
"shimmer.modules.losses.GWLossesBayesian.contrastive_loss": {"tf": 1}, "shimmer.modules.losses.GWLossesBayesian.broadcast_loss": {"tf": 1}, "shimmer.modules.contrastive_loss.contrastive_loss": {"tf": 1}, "shimmer.modules.vae.kl_divergence_loss": {"tf": 1}}, "df": 26, "o": {"docs": {}, "df": 0, "u": {"docs": {}, "df": 0, "t": {"docs": {}, "df": 0, "p": {"docs": {}, "df": 0, "u": {"docs": {}, "df": 0, "t": {"docs": {"shimmer.modules.domain.LossOutput": {"tf": 1}, "shimmer.modules.domain.LossOutput.__init__": {"tf": 1}, "shimmer.modules.domain.LossOutput.loss": {"tf": 1}, "shimmer.modules.domain.LossOutput.metrics": {"tf": 1}, "shimmer.modules.domain.LossOutput.all": {"tf": 1}}, "df": 5}}}}}}, "c": {"docs": {}, "df": 0, "o": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "f": {"docs": {}, "df": 0, "s": {"docs": {"shimmer.modules.losses.LossCoefs": {"tf": 1}, "shimmer.modules.losses.LossCoefs.demi_cycles": {"tf": 1}, "shimmer.modules.losses.LossCoefs.cycles": {"tf": 1}, "shimmer.modules.losses.LossCoefs.translations": {"tf": 1}, "shimmer.modules.losses.LossCoefs.contrastives": {"tf": 1}}, "df": 5}}}}}}}, "g": {"docs": {}, "df": 0, "i": {"docs": {}, "df": 0, "t": {"docs": {"shimmer.modules.contrastive_loss.ContrastiveLoss.learn_logit_scale": {"tf": 1}}, "df": 1}}}}, "e": {"docs": {}, "df": 0, "a": {"docs": {}, "df": 0, "r": {"docs": {}, "df": 0, "n": {"docs": {"shimmer.modules.contrastive_loss.ContrastiveLoss.learn_logit_scale": {"tf": 1}}, "df": 1}}}}}, "m": {"docs": {}, "df": 0, "o": {"docs": {}, "df": 0, "d": {"docs": {"shimmer.modules.global_workspace.GlobalWorkspaceBase.gw_mod": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.selection_mod": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.loss_mod": {"tf": 1}, "shimmer.modules.losses.GWLosses2Domains.gw_mod": {"tf": 1}, "shimmer.modules.losses.GWLosses2Domains.selection_mod": {"tf": 1}, "shimmer.modules.losses.GWLosses.gw_mod": {"tf": 1}, 
"shimmer.modules.losses.GWLosses.selection_mod": {"tf": 1}, "shimmer.modules.losses.GWLossesBayesian.gw_mod": {"tf": 1}, "shimmer.modules.losses.GWLossesBayesian.selection_mod": {"tf": 1}}, "df": 9, "e": {"docs": {}, "df": 0, "l": {"docs": {"shimmer.utils.migrate_model": {"tf": 1}}, "df": 1, "m": {"docs": {}, "df": 0, "o": {"docs": {}, "df": 0, "d": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "t": {"docs": {"shimmer.types.ModelModeT": {"tf": 1}}, "df": 1}}}}}}}, "s": {"docs": {"shimmer.modules.global_workspace.GlobalWorkspaceBase.domain_mods": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBase.domain_mods": {"tf": 1}, "shimmer.modules.losses.GWLosses2Domains.domain_mods": {"tf": 1}, "shimmer.modules.losses.GWLosses.domain_mods": {"tf": 1}, "shimmer.modules.losses.GWLossesBayesian.domain_mods": {"tf": 1}}, "df": 5}, "u": {"docs": {}, "df": 0, "l": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "s": {"docs": {"shimmer.modules.global_workspace.freeze_domain_modules": {"tf": 1}}, "df": 1}}}}}}, "a": {"docs": {}, "df": 0, "x": {"docs": {"shimmer.modules.global_workspace.SchedulerArgs.max_lr": {"tf": 1}}, "df": 1}}, "e": {"docs": {}, "df": 0, "t": {"docs": {}, "df": 0, "r": {"docs": {}, "df": 0, "i": {"docs": {}, "df": 0, "c": {"docs": {}, "df": 0, "s": {"docs": {"shimmer.modules.domain.LossOutput.metrics": {"tf": 1}}, "df": 1}}}}}}, "i": {"docs": {}, "df": 0, "g": {"docs": {}, "df": 0, "r": {"docs": {}, "df": 0, "a": {"docs": {}, "df": 0, "t": {"docs": {}, "df": 0, "i": {"docs": {}, "df": 0, "o": {"docs": {}, "df": 0, "n": {"docs": {"shimmer.utils.MIGRATION_DIR": {"tf": 1}}, "df": 1, "s": {"docs": {"shimmer.utils.SaveMigrations.migrations": {"tf": 1}}, "df": 1}}}}, "e": {"docs": {"shimmer.utils.migrate_model": {"tf": 1}, "shimmer.cli.ckpt_migration.migrate_ckpt": {"tf": 1}}, "df": 2}}}}}}}, "s": {"docs": {}, "df": 0, "c": {"docs": {}, "df": 0, "h": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "d": {"docs": {}, "df": 0, "u": {"docs": {}, "df": 0, "l": {"docs": 
{}, "df": 0, "e": {"docs": {}, "df": 0, "r": {"docs": {"shimmer.modules.global_workspace.GlobalWorkspaceBase.scheduler_args": {"tf": 1}}, "df": 1, "a": {"docs": {}, "df": 0, "r": {"docs": {}, "df": 0, "g": {"docs": {}, "df": 0, "s": {"docs": {"shimmer.modules.global_workspace.SchedulerArgs": {"tf": 1}, "shimmer.modules.global_workspace.SchedulerArgs.max_lr": {"tf": 1}, "shimmer.modules.global_workspace.SchedulerArgs.total_steps": {"tf": 1}}, "df": 3}}}}}}}}}}}, "o": {"docs": {}, "df": 0, "r": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "s": {"docs": {"shimmer.modules.gw_module.compute_fusion_scores": {"tf": 1}}, "df": 1}}}}, "a": {"docs": {}, "df": 0, "l": {"docs": {}, "df": 0, "e": {"docs": {"shimmer.modules.contrastive_loss.ContrastiveLoss.learn_logit_scale": {"tf": 1}}, "df": 1}}}}, "t": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "p": {"docs": {"shimmer.modules.global_workspace.GlobalWorkspaceBase.generic_step": {"tf": 1}, "shimmer.modules.losses.GWLossesBase.step": {"tf": 1}, "shimmer.modules.losses.GWLosses2Domains.step": {"tf": 1}, "shimmer.modules.losses.GWLosses.step": {"tf": 1}, "shimmer.modules.losses.GWLossesBayesian.step": {"tf": 1}}, "df": 5, "s": {"docs": {"shimmer.modules.global_workspace.SchedulerArgs.total_steps": {"tf": 1}}, "df": 1}}}, "a": {"docs": {}, "df": 0, "t": {"docs": {}, "df": 0, "e": {"docs": {"shimmer.modules.selection.SelectionBase.update_gw_state": {"tf": 1}}, "df": 1, "s": {"docs": {"shimmer.modules.global_workspace.GWPredictionsBase.states": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.batch_gw_states": {"tf": 1}, "shimmer.modules.global_workspace.GWPredictions.states": {"tf": 1}}, "df": 3}}}}}, "e": {"docs": {}, "df": 0, "l": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "c": {"docs": {}, "df": 0, "t": {"docs": {}, "df": 0, "i": {"docs": {}, "df": 0, "o": {"docs": {}, "df": 0, "n": {"docs": {"shimmer.modules.global_workspace.GlobalWorkspaceBase.selection_mod": {"tf": 1}, 
"shimmer.modules.gw_module.GWModuleBayesian.sensitivity_selection": {"tf": 1}, "shimmer.modules.losses.GWLosses2Domains.selection_mod": {"tf": 1}, "shimmer.modules.losses.GWLosses.selection_mod": {"tf": 1}, "shimmer.modules.losses.GWLossesBayesian.selection_mod": {"tf": 1}}, "df": 5, "b": {"docs": {}, "df": 0, "a": {"docs": {}, "df": 0, "s": {"docs": {}, "df": 0, "e": {"docs": {"shimmer.modules.selection.SelectionBase": {"tf": 1}, "shimmer.modules.selection.SelectionBase.update_gw_state": {"tf": 1}, "shimmer.modules.selection.SelectionBase.forward": {"tf": 1}}, "df": 3}}}}}}}}}}}, "n": {"docs": {}, "df": 0, "s": {"docs": {}, "df": 0, "i": {"docs": {}, "df": 0, "t": {"docs": {}, "df": 0, "i": {"docs": {}, "df": 0, "v": {"docs": {}, "df": 0, "i": {"docs": {}, "df": 0, "t": {"docs": {}, "df": 0, "y": {"docs": {"shimmer.modules.gw_module.GWModuleBayesian.sensitivity_selection": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBayesian.sensitivity_precision": {"tf": 1}}, "df": 2}}}}}}}}}}, "o": {"docs": {}, "df": 0, "f": {"docs": {}, "df": 0, "t": {"docs": {}, "df": 0, "m": {"docs": {}, "df": 0, "a": {"docs": {}, "df": 0, "x": {"docs": {"shimmer.modules.gw_module.GWModuleBayesian.precision_softmax_temp": {"tf": 1}}, "df": 1}}}}}}, "i": {"docs": {}, "df": 0, "n": {"docs": {}, "df": 0, "g": {"docs": {}, "df": 0, "l": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "d": {"docs": {}, "df": 0, "o": {"docs": {}, "df": 0, "m": {"docs": {}, "df": 0, "a": {"docs": {}, "df": 0, "i": {"docs": {}, "df": 0, "n": {"docs": {}, "df": 0, "s": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "l": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "c": {"docs": {}, "df": 0, "t": {"docs": {}, "df": 0, "i": {"docs": {}, "df": 0, "o": {"docs": {}, "df": 0, "n": {"docs": {"shimmer.modules.selection.SingleDomainSelection": {"tf": 1}, "shimmer.modules.selection.SingleDomainSelection.forward": {"tf": 1}}, "df": 2}}}}}}}}}}}}}}}}}}}, "z": {"docs": {}, "df": 0, "e": {"docs": 
{"shimmer.modules.selection.KQFixedQSelection.head_size": {"tf": 1}, "shimmer.modules.selection.DynamicQueryAttention.head_size": {"tf": 1}, "shimmer.dataset.RepeatedDataset.dataset_size": {"tf": 1}, "shimmer.utils.group_batch_size": {"tf": 1}, "shimmer.utils.groups_batch_size": {"tf": 1}}, "df": 5}}}, "a": {"docs": {}, "df": 0, "v": {"docs": {}, "df": 0, "e": {"docs": {"shimmer.utils.SaveMigrations.on_save_checkpoint": {"tf": 1}}, "df": 1, "m": {"docs": {}, "df": 0, "i": {"docs": {}, "df": 0, "g": {"docs": {}, "df": 0, "r": {"docs": {}, "df": 0, "a": {"docs": {}, "df": 0, "t": {"docs": {}, "df": 0, "i": {"docs": {}, "df": 0, "o": {"docs": {}, "df": 0, "n": {"docs": {}, "df": 0, "s": {"docs": {"shimmer.utils.SaveMigrations": {"tf": 1}, "shimmer.utils.SaveMigrations.migrations": {"tf": 1}, "shimmer.utils.SaveMigrations.on_save_checkpoint": {"tf": 1}}, "df": 3}}}}}}}}}}}}}}, "t": {"docs": {}, "df": 0, "o": {"docs": {}, "df": 0, "t": {"docs": {}, "df": 0, "a": {"docs": {}, "df": 0, "l": {"docs": {"shimmer.modules.global_workspace.SchedulerArgs.total_steps": {"tf": 1}}, "df": 1}}}}, "r": {"docs": {"shimmer.modules.domain.DomainModule.compute_tr_loss": {"tf": 1}}, "df": 1, "a": {"docs": {}, "df": 0, "n": {"docs": {}, "df": 0, "s": {"docs": {}, "df": 0, "l": {"docs": {}, "df": 0, "a": {"docs": {}, "df": 0, "t": {"docs": {}, "df": 0, "i": {"docs": {}, "df": 0, "o": {"docs": {}, "df": 0, "n": {"docs": {"shimmer.modules.losses.translation_loss": {"tf": 1}, "shimmer.modules.losses.GWLosses2Domains.translation_loss": {"tf": 1}, "shimmer.modules.utils.translation": {"tf": 1}}, "df": 3, "s": {"docs": {"shimmer.modules.global_workspace.GWPredictions.translations": {"tf": 1}, "shimmer.modules.losses.LossCoefs.translations": {"tf": 1}, "shimmer.modules.losses.BroadcastLossCoefs.translations": {"tf": 1}, "shimmer.modules.utils.batch_translations": {"tf": 1}}, "df": 4}}}}}}}}}}}, "e": {"docs": {}, "df": 0, "m": {"docs": {}, "df": 0, "p": {"docs": 
{"shimmer.modules.gw_module.GWModuleBayesian.precision_softmax_temp": {"tf": 1}}, "df": 1, "e": {"docs": {}, "df": 0, "r": {"docs": {}, "df": 0, "a": {"docs": {}, "df": 0, "t": {"docs": {}, "df": 0, "u": {"docs": {}, "df": 0, "r": {"docs": {}, "df": 0, "e": {"docs": {"shimmer.modules.selection.RandomSelection.temperature": {"tf": 1}}, "df": 1}}}}}}}}}}}, "g": {"docs": {}, "df": 0, "w": {"docs": {"shimmer.modules.global_workspace.GlobalWorkspaceBase.gw_mod": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.batch_gw_states": {"tf": 1}, "shimmer.modules.gw_module.GWModule.gw_encoders": {"tf": 1}, "shimmer.modules.gw_module.GWModule.gw_decoders": {"tf": 1}, "shimmer.modules.selection.SelectionBase.update_gw_state": {"tf": 1}, "shimmer.modules.losses.GWLosses2Domains.gw_mod": {"tf": 1}, "shimmer.modules.losses.GWLosses.gw_mod": {"tf": 1}, "shimmer.modules.losses.GWLossesBayesian.gw_mod": {"tf": 1}}, "df": 8, "p": {"docs": {}, "df": 0, "r": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "d": {"docs": {}, "df": 0, "i": {"docs": {}, "df": 0, "c": {"docs": {}, "df": 0, "t": {"docs": {}, "df": 0, "i": {"docs": {}, "df": 0, "o": {"docs": {}, "df": 0, "n": {"docs": {}, "df": 0, "s": {"docs": {"shimmer.modules.global_workspace.GWPredictions": {"tf": 1}, "shimmer.modules.global_workspace.GWPredictions.demi_cycles": {"tf": 1}, "shimmer.modules.global_workspace.GWPredictions.cycles": {"tf": 1}, "shimmer.modules.global_workspace.GWPredictions.translations": {"tf": 1}, "shimmer.modules.global_workspace.GWPredictions.states": {"tf": 1}}, "df": 5, "b": {"docs": {}, "df": 0, "a": {"docs": {}, "df": 0, "s": {"docs": {}, "df": 0, "e": {"docs": {"shimmer.modules.global_workspace.GWPredictionsBase": {"tf": 1}, "shimmer.modules.global_workspace.GWPredictionsBase.states": {"tf": 1}}, "df": 2}}}}}}}}}}}}}}}, "d": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "c": {"docs": {}, "df": 0, "o": {"docs": {}, "df": 0, "d": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "r": 
{"docs": {"shimmer.modules.gw_module.GWDecoder": {"tf": 1}, "shimmer.modules.gw_module.GWDecoder.__init__": {"tf": 1}, "shimmer.modules.gw_module.GWDecoder.in_dim": {"tf": 1}, "shimmer.modules.gw_module.GWDecoder.hidden_dim": {"tf": 1}, "shimmer.modules.gw_module.GWDecoder.out_dim": {"tf": 1}, "shimmer.modules.gw_module.GWDecoder.n_layers": {"tf": 1}}, "df": 6}}}}}}}, "e": {"docs": {}, "df": 0, "n": {"docs": {}, "df": 0, "c": {"docs": {}, "df": 0, "o": {"docs": {}, "df": 0, "d": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "r": {"docs": {"shimmer.modules.gw_module.GWEncoder": {"tf": 1}, "shimmer.modules.gw_module.GWEncoder.__init__": {"tf": 1}, "shimmer.modules.gw_module.GWEncoder.forward": {"tf": 1}}, "df": 3, "l": {"docs": {}, "df": 0, "i": {"docs": {}, "df": 0, "n": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "a": {"docs": {}, "df": 0, "r": {"docs": {"shimmer.modules.gw_module.GWEncoderLinear": {"tf": 1}, "shimmer.modules.gw_module.GWEncoderLinear.forward": {"tf": 1}}, "df": 2}}}}}}}}}}}}}, "m": {"docs": {}, "df": 0, "o": {"docs": {}, "df": 0, "d": {"docs": {}, "df": 0, "u": {"docs": {}, "df": 0, "l": {"docs": {}, "df": 0, "e": {"docs": {"shimmer.modules.gw_module.GWModule": {"tf": 1}, "shimmer.modules.gw_module.GWModule.__init__": {"tf": 1}, "shimmer.modules.gw_module.GWModule.gw_encoders": {"tf": 1}, "shimmer.modules.gw_module.GWModule.gw_decoders": {"tf": 1}, "shimmer.modules.gw_module.GWModule.fuse": {"tf": 1}, "shimmer.modules.gw_module.GWModule.encode": {"tf": 1}, "shimmer.modules.gw_module.GWModule.decode": {"tf": 1}}, "df": 7, "b": {"docs": {}, "df": 0, "a": {"docs": {}, "df": 0, "s": {"docs": {}, "df": 0, "e": {"docs": {"shimmer.modules.gw_module.GWModuleBase": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBase.__init__": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBase.domain_mods": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBase.workspace_dim": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBase.fuse": {"tf": 1}, 
"shimmer.modules.gw_module.GWModuleBase.encode": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBase.encode_and_fuse": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBase.decode": {"tf": 1}}, "df": 8}}, "y": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "s": {"docs": {}, "df": 0, "i": {"docs": {}, "df": 0, "a": {"docs": {}, "df": 0, "n": {"docs": {"shimmer.modules.gw_module.GWModuleBayesian": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBayesian.__init__": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBayesian.precisions": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBayesian.sensitivity_selection": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBayesian.sensitivity_precision": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBayesian.precision_softmax_temp": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBayesian.get_precision": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBayesian.fuse": {"tf": 1}}, "df": 8}}}}}}}}}}}}}}, "l": {"docs": {}, "df": 0, "o": {"docs": {}, "df": 0, "s": {"docs": {}, "df": 0, "s": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "s": {"2": {"docs": {}, "df": 0, "d": {"docs": {}, "df": 0, "o": {"docs": {}, "df": 0, "m": {"docs": {}, "df": 0, "a": {"docs": {}, "df": 0, "i": {"docs": {}, "df": 0, "n": {"docs": {}, "df": 0, "s": {"docs": {"shimmer.modules.losses.GWLosses2Domains": {"tf": 1}, "shimmer.modules.losses.GWLosses2Domains.__init__": {"tf": 1}, "shimmer.modules.losses.GWLosses2Domains.gw_mod": {"tf": 1}, "shimmer.modules.losses.GWLosses2Domains.selection_mod": {"tf": 1}, "shimmer.modules.losses.GWLosses2Domains.domain_mods": {"tf": 1}, "shimmer.modules.losses.GWLosses2Domains.loss_coefs": {"tf": 1}, "shimmer.modules.losses.GWLosses2Domains.contrastive_fn": {"tf": 1}, "shimmer.modules.losses.GWLosses2Domains.demi_cycle_loss": {"tf": 1}, "shimmer.modules.losses.GWLosses2Domains.cycle_loss": {"tf": 1}, "shimmer.modules.losses.GWLosses2Domains.translation_loss": {"tf": 1}, "shimmer.modules.losses.GWLosses2Domains.contrastive_loss": 
{"tf": 1}, "shimmer.modules.losses.GWLosses2Domains.step": {"tf": 1}}, "df": 12}}}}}}}}, "docs": {"shimmer.modules.losses.GWLosses": {"tf": 1}, "shimmer.modules.losses.GWLosses.__init__": {"tf": 1}, "shimmer.modules.losses.GWLosses.gw_mod": {"tf": 1}, "shimmer.modules.losses.GWLosses.selection_mod": {"tf": 1}, "shimmer.modules.losses.GWLosses.domain_mods": {"tf": 1}, "shimmer.modules.losses.GWLosses.loss_coefs": {"tf": 1}, "shimmer.modules.losses.GWLosses.contrastive_fn": {"tf": 1}, "shimmer.modules.losses.GWLosses.contrastive_loss": {"tf": 1}, "shimmer.modules.losses.GWLosses.broadcast_loss": {"tf": 1}, "shimmer.modules.losses.GWLosses.step": {"tf": 1}}, "df": 10, "b": {"docs": {}, "df": 0, "a": {"docs": {}, "df": 0, "s": {"docs": {}, "df": 0, "e": {"docs": {"shimmer.modules.losses.GWLossesBase": {"tf": 1}, "shimmer.modules.losses.GWLossesBase.step": {"tf": 1}}, "df": 2}}, "y": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "s": {"docs": {}, "df": 0, "i": {"docs": {}, "df": 0, "a": {"docs": {}, "df": 0, "n": {"docs": {"shimmer.modules.losses.GWLossesBayesian": {"tf": 1}, "shimmer.modules.losses.GWLossesBayesian.__init__": {"tf": 1}, "shimmer.modules.losses.GWLossesBayesian.gw_mod": {"tf": 1}, "shimmer.modules.losses.GWLossesBayesian.selection_mod": {"tf": 1}, "shimmer.modules.losses.GWLossesBayesian.domain_mods": {"tf": 1}, "shimmer.modules.losses.GWLossesBayesian.loss_coefs": {"tf": 1}, "shimmer.modules.losses.GWLossesBayesian.contrastive_fn": {"tf": 1}, "shimmer.modules.losses.GWLossesBayesian.use_normalized_constrastive": {"tf": 1}, "shimmer.modules.losses.GWLossesBayesian.contrastive_loss": {"tf": 1}, "shimmer.modules.losses.GWLossesBayesian.broadcast_loss": {"tf": 1}, "shimmer.modules.losses.GWLossesBayesian.step": {"tf": 1}}, "df": 11}}}}}}}}}}}}}}}, "l": {"docs": {}, "df": 0, "o": {"docs": {}, "df": 0, "b": {"docs": {}, "df": 0, "a": {"docs": {}, "df": 0, "l": {"docs": {"shimmer.modules.global_workspace.pretrained_global_workspace": {"tf": 1}}, "df": 1, 
"w": {"docs": {}, "df": 0, "o": {"docs": {}, "df": 0, "r": {"docs": {}, "df": 0, "k": {"docs": {}, "df": 0, "s": {"docs": {}, "df": 0, "p": {"docs": {}, "df": 0, "a": {"docs": {}, "df": 0, "c": {"docs": {}, "df": 0, "e": {"2": {"docs": {}, "df": 0, "d": {"docs": {}, "df": 0, "o": {"docs": {}, "df": 0, "m": {"docs": {}, "df": 0, "a": {"docs": {}, "df": 0, "i": {"docs": {}, "df": 0, "n": {"docs": {}, "df": 0, "s": {"docs": {"shimmer.modules.global_workspace.GlobalWorkspace2Domains": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspace2Domains.__init__": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspace2Domains.forward": {"tf": 1}}, "df": 3}}}}}}}}, "docs": {"shimmer.modules.global_workspace.GlobalWorkspace": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspace.__init__": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspace.forward": {"tf": 1}}, "df": 3, "b": {"docs": {}, "df": 0, "a": {"docs": {}, "df": 0, "s": {"docs": {}, "df": 0, "e": {"docs": {"shimmer.modules.global_workspace.GlobalWorkspaceBase": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.gw_mod": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.selection_mod": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.loss_mod": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.optim_lr": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.optim_weight_decay": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.scheduler_args": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.domain_mods": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.workspace_dim": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.encode_and_fuse": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.encode": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.fuse": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.decode": {"tf": 1}, 
"shimmer.modules.global_workspace.GlobalWorkspaceBase.batch_gw_states": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.encode_domain": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.encode_domains": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.decode_domain": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.decode_domains": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.generic_step": {"tf": 1}}, "df": 19}}, "y": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "s": {"docs": {}, "df": 0, "i": {"docs": {}, "df": 0, "a": {"docs": {}, "df": 0, "n": {"docs": {"shimmer.modules.global_workspace.GlobalWorkspaceBayesian": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBayesian.__init__": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBayesian.forward": {"tf": 1}}, "df": 3}}}}}}}}}}}}}}}}}}}}}}, "e": {"docs": {}, "df": 0, "n": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "r": {"docs": {}, "df": 0, "i": {"docs": {}, "df": 0, "c": {"docs": {"shimmer.modules.global_workspace.GlobalWorkspaceBase.generic_step": {"tf": 1}}, "df": 1}}, "a": {"docs": {}, "df": 0, "t": {"docs": {}, "df": 0, "e": {"docs": {"shimmer.modules.losses.generate_partitions": {"tf": 1}}, "df": 1}}}}}}, "t": {"docs": {"shimmer.modules.gw_module.get_n_layers": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBayesian.get_precision": {"tf": 1}}, "df": 2}}, "a": {"docs": {}, "df": 0, "u": {"docs": {}, "df": 0, "s": {"docs": {}, "df": 0, "s": {"docs": {}, "df": 0, "i": {"docs": {}, "df": 0, "a": {"docs": {}, "df": 0, "n": {"docs": {"shimmer.modules.vae.gaussian_nll": {"tf": 1}}, "df": 1}}}}}}}, "r": {"docs": {}, "df": 0, "o": {"docs": {}, "df": 0, "u": {"docs": {}, "df": 0, "p": {"docs": {"shimmer.utils.group_batch_size": {"tf": 1}, "shimmer.utils.group_device": {"tf": 1}}, "df": 2, "s": {"docs": {"shimmer.utils.groups_batch_size": {"tf": 1}, "shimmer.utils.groups_device": {"tf": 1}}, "df": 
2}}}}}}, "o": {"docs": {}, "df": 0, "p": {"docs": {}, "df": 0, "t": {"docs": {}, "df": 0, "i": {"docs": {}, "df": 0, "m": {"docs": {"shimmer.modules.global_workspace.GlobalWorkspaceBase.optim_lr": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.optim_weight_decay": {"tf": 1}}, "df": 2}}}}, "u": {"docs": {}, "df": 0, "t": {"docs": {"shimmer.modules.gw_module.GWDecoder.out_dim": {"tf": 1}}, "df": 1}}, "n": {"docs": {"shimmer.utils.SaveMigrations.on_save_checkpoint": {"tf": 1}}, "df": 1}}, "w": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "i": {"docs": {}, "df": 0, "g": {"docs": {}, "df": 0, "h": {"docs": {}, "df": 0, "t": {"docs": {"shimmer.modules.global_workspace.GlobalWorkspaceBase.optim_weight_decay": {"tf": 1}}, "df": 1, "e": {"docs": {}, "df": 0, "d": {"docs": {"shimmer.modules.selection.DynamicQueryAttention.fuse_weighted_encodings": {"tf": 1}}, "df": 1}}}}}}}, "o": {"docs": {}, "df": 0, "r": {"docs": {}, "df": 0, "k": {"docs": {}, "df": 0, "s": {"docs": {}, "df": 0, "p": {"docs": {}, "df": 0, "a": {"docs": {}, "df": 0, "c": {"docs": {}, "df": 0, "e": {"docs": {"shimmer.modules.global_workspace.GlobalWorkspaceBase.workspace_dim": {"tf": 1}, "shimmer.modules.global_workspace.pretrained_global_workspace": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBase.workspace_dim": {"tf": 1}}, "df": 3}}}}}}}}}, "d": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "c": {"docs": {}, "df": 0, "a": {"docs": {}, "df": 0, "y": {"docs": {"shimmer.modules.global_workspace.GlobalWorkspaceBase.optim_weight_decay": {"tf": 1}}, "df": 1}}, "o": {"docs": {}, "df": 0, "d": {"docs": {}, "df": 0, "e": {"docs": {"shimmer.modules.global_workspace.GlobalWorkspaceBase.decode": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.decode_domain": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.decode_domains": {"tf": 1}, "shimmer.modules.domain.DomainModule.decode": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBase.decode": {"tf": 1}, 
"shimmer.modules.gw_module.GWModule.decode": {"tf": 1}, "shimmer.modules.vae.VAE.decode": {"tf": 1}}, "df": 7, "r": {"docs": {"shimmer.modules.vae.VAE.decoder": {"tf": 1}}, "df": 1, "s": {"docs": {"shimmer.modules.gw_module.GWModule.gw_decoders": {"tf": 1}}, "df": 1}}}}}}, "m": {"docs": {}, "df": 0, "i": {"docs": {"shimmer.modules.global_workspace.GWPredictions.demi_cycles": {"tf": 1}, "shimmer.modules.losses.demi_cycle_loss": {"tf": 1}, "shimmer.modules.losses.LossCoefs.demi_cycles": {"tf": 1}, "shimmer.modules.losses.GWLosses2Domains.demi_cycle_loss": {"tf": 1}, "shimmer.modules.losses.BroadcastLossCoefs.demi_cycles": {"tf": 1}, "shimmer.modules.utils.batch_demi_cycles": {"tf": 1}}, "df": 6}}, "v": {"docs": {}, "df": 0, "i": {"docs": {}, "df": 0, "c": {"docs": {}, "df": 0, "e": {"docs": {"shimmer.utils.groups_device": {"tf": 1}, "shimmer.utils.group_device": {"tf": 1}}, "df": 2}}}}}, "o": {"docs": {}, "df": 0, "m": {"docs": {}, "df": 0, "a": {"docs": {}, "df": 0, "i": {"docs": {}, "df": 0, "n": {"docs": {"shimmer.modules.global_workspace.GlobalWorkspaceBase.domain_mods": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.encode_domain": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.decode_domain": {"tf": 1}, "shimmer.modules.global_workspace.freeze_domain_modules": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBase.domain_mods": {"tf": 1}, "shimmer.modules.losses.GWLosses2Domains.domain_mods": {"tf": 1}, "shimmer.modules.losses.GWLosses.domain_mods": {"tf": 1}, "shimmer.modules.losses.GWLossesBayesian.domain_mods": {"tf": 1}}, "df": 8, "s": {"docs": {"shimmer.modules.global_workspace.GlobalWorkspaceBase.encode_domains": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.decode_domains": {"tf": 1}}, "df": 2}, "m": {"docs": {}, "df": 0, "o": {"docs": {}, "df": 0, "d": {"docs": {}, "df": 0, "u": {"docs": {}, "df": 0, "l": {"docs": {}, "df": 0, "e": {"docs": {"shimmer.modules.domain.DomainModule": {"tf": 1}, 
"shimmer.modules.domain.DomainModule.__init__": {"tf": 1}, "shimmer.modules.domain.DomainModule.latent_dim": {"tf": 1}, "shimmer.modules.domain.DomainModule.encode": {"tf": 1}, "shimmer.modules.domain.DomainModule.decode": {"tf": 1}, "shimmer.modules.domain.DomainModule.compute_loss": {"tf": 1}, "shimmer.modules.domain.DomainModule.compute_dcy_loss": {"tf": 1}, "shimmer.modules.domain.DomainModule.compute_cy_loss": {"tf": 1}, "shimmer.modules.domain.DomainModule.compute_tr_loss": {"tf": 1}, "shimmer.modules.domain.DomainModule.compute_broadcast_loss": {"tf": 1}}, "df": 10}}}}}}}}}}}, "i": {"docs": {}, "df": 0, "m": {"docs": {"shimmer.modules.global_workspace.GlobalWorkspaceBase.workspace_dim": {"tf": 1}, "shimmer.modules.domain.DomainModule.latent_dim": {"tf": 1}, "shimmer.modules.gw_module.GWDecoder.in_dim": {"tf": 1}, "shimmer.modules.gw_module.GWDecoder.hidden_dim": {"tf": 1}, "shimmer.modules.gw_module.GWDecoder.out_dim": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBase.workspace_dim": {"tf": 1}}, "df": 6}, "v": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "r": {"docs": {}, "df": 0, "g": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "n": {"docs": {}, "df": 0, "c": {"docs": {}, "df": 0, "e": {"docs": {"shimmer.modules.vae.kl_divergence_loss": {"tf": 1}}, "df": 1}}}}}}}}, "r": {"docs": {"shimmer.utils.MIGRATION_DIR": {"tf": 1}}, "df": 1}}, "c": {"docs": {}, "df": 0, "y": {"docs": {"shimmer.modules.domain.DomainModule.compute_dcy_loss": {"tf": 1}}, "df": 1}}, "y": {"docs": {}, "df": 0, "n": {"docs": {}, "df": 0, "a": {"docs": {}, "df": 0, "m": {"docs": {}, "df": 0, "i": {"docs": {}, "df": 0, "c": {"docs": {}, "df": 0, "q": {"docs": {}, "df": 0, "u": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "r": {"docs": {}, "df": 0, "y": {"docs": {}, "df": 0, "a": {"docs": {}, "df": 0, "t": {"docs": {}, "df": 0, "t": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "n": {"docs": {}, "df": 0, "t": {"docs": {}, "df": 0, "i": {"docs": {}, "df": 0, "o": {"docs": {}, 
"df": 0, "n": {"docs": {"shimmer.modules.selection.DynamicQueryAttention": {"tf": 1}, "shimmer.modules.selection.DynamicQueryAttention.__init__": {"tf": 1}, "shimmer.modules.selection.DynamicQueryAttention.head_size": {"tf": 1}, "shimmer.modules.selection.DynamicQueryAttention.query_layer": {"tf": 1}, "shimmer.modules.selection.DynamicQueryAttention.key_layers": {"tf": 1}, "shimmer.modules.selection.DynamicQueryAttention.fuse_weighted_encodings": {"tf": 1}, "shimmer.modules.selection.DynamicQueryAttention.forward": {"tf": 1}}, "df": 7}}}}}}}}}}}}}}}}}}}}, "a": {"docs": {}, "df": 0, "t": {"docs": {}, "df": 0, "a": {"docs": {}, "df": 0, "s": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "t": {"docs": {"shimmer.dataset.RepeatedDataset.dataset": {"tf": 1}, "shimmer.dataset.RepeatedDataset.dataset_size": {"tf": 1}}, "df": 2}}}}}}}, "a": {"docs": {}, "df": 0, "r": {"docs": {}, "df": 0, "g": {"docs": {}, "df": 0, "s": {"docs": {"shimmer.modules.global_workspace.GlobalWorkspaceBase.scheduler_args": {"tf": 1}}, "df": 1}}}, "n": {"docs": {}, "df": 0, "d": {"docs": {"shimmer.modules.global_workspace.GlobalWorkspaceBase.encode_and_fuse": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBase.encode_and_fuse": {"tf": 1}}, "df": 2}}, "l": {"docs": {}, "df": 0, "l": {"docs": {"shimmer.modules.domain.LossOutput.all": {"tf": 1}}, "df": 1}}}, "e": {"docs": {}, "df": 0, "n": {"docs": {}, "df": 0, "c": {"docs": {}, "df": 0, "o": {"docs": {}, "df": 0, "d": {"docs": {}, "df": 0, "e": {"docs": {"shimmer.modules.global_workspace.GlobalWorkspaceBase.encode_and_fuse": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.encode": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.encode_domain": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.encode_domains": {"tf": 1}, "shimmer.modules.domain.DomainModule.encode": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBase.encode": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBase.encode_and_fuse": {"tf": 
1}, "shimmer.modules.gw_module.GWModule.encode": {"tf": 1}, "shimmer.modules.vae.VAE.encode": {"tf": 1}}, "df": 9, "r": {"docs": {"shimmer.modules.vae.VAE.encoder": {"tf": 1}}, "df": 1, "s": {"docs": {"shimmer.modules.gw_module.GWModule.gw_encoders": {"tf": 1}}, "df": 1}}}, "i": {"docs": {}, "df": 0, "n": {"docs": {}, "df": 0, "g": {"docs": {}, "df": 0, "s": {"docs": {"shimmer.modules.selection.DynamicQueryAttention.fuse_weighted_encodings": {"tf": 1}}, "df": 1}}}}}}}}}, "f": {"docs": {}, "df": 0, "u": {"docs": {}, "df": 0, "s": {"docs": {}, "df": 0, "e": {"docs": {"shimmer.modules.global_workspace.GlobalWorkspaceBase.encode_and_fuse": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.fuse": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBase.fuse": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBase.encode_and_fuse": {"tf": 1}, "shimmer.modules.gw_module.GWModule.fuse": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBayesian.fuse": {"tf": 1}, "shimmer.modules.selection.DynamicQueryAttention.fuse_weighted_encodings": {"tf": 1}}, "df": 7, "d": {"docs": {"shimmer.modules.losses.BroadcastLossCoefs.fused": {"tf": 1}}, "df": 1}}, "i": {"docs": {}, "df": 0, "o": {"docs": {}, "df": 0, "n": {"docs": {"shimmer.modules.gw_module.compute_fusion_scores": {"tf": 1}}, "df": 1}}}}}, "r": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "z": {"docs": {}, "df": 0, "e": {"docs": {"shimmer.modules.global_workspace.freeze_domain_modules": {"tf": 1}}, "df": 1}}}}}, "o": {"docs": {}, "df": 0, "r": {"docs": {}, "df": 0, "w": {"docs": {}, "df": 0, "a": {"docs": {}, "df": 0, "r": {"docs": {}, "df": 0, "d": {"docs": {"shimmer.modules.global_workspace.GlobalWorkspace2Domains.forward": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspace.forward": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBayesian.forward": {"tf": 1}, "shimmer.modules.gw_module.GWEncoder.forward": {"tf": 1}, "shimmer.modules.gw_module.GWEncoderLinear.forward": 
{"tf": 1}, "shimmer.modules.selection.SelectionBase.forward": {"tf": 1}, "shimmer.modules.selection.SingleDomainSelection.forward": {"tf": 1}, "shimmer.modules.selection.FixedSharedSelection.forward": {"tf": 1}, "shimmer.modules.selection.KQFixedQSelection.forward": {"tf": 1}, "shimmer.modules.selection.RandomSelection.forward": {"tf": 1}, "shimmer.modules.selection.DynamicQueryAttention.forward": {"tf": 1}, "shimmer.modules.contrastive_loss.ContrastiveLoss.forward": {"tf": 1}, "shimmer.modules.vae.VAEEncoder.forward": {"tf": 1}, "shimmer.modules.vae.VAEDecoder.forward": {"tf": 1}, "shimmer.modules.vae.VAE.forward": {"tf": 1}}, "df": 15}}}}}}, "i": {"docs": {}, "df": 0, "x": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "d": {"docs": {}, "df": 0, "s": {"docs": {}, "df": 0, "h": {"docs": {}, "df": 0, "a": {"docs": {}, "df": 0, "r": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "d": {"docs": {}, "df": 0, "s": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "l": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "c": {"docs": {}, "df": 0, "t": {"docs": {}, "df": 0, "i": {"docs": {}, "df": 0, "o": {"docs": {}, "df": 0, "n": {"docs": {"shimmer.modules.selection.FixedSharedSelection": {"tf": 1}, "shimmer.modules.selection.FixedSharedSelection.forward": {"tf": 1}}, "df": 2}}}}}}}}}}}}}}}}}}}, "n": {"docs": {"shimmer.modules.losses.GWLosses2Domains.contrastive_fn": {"tf": 1}, "shimmer.modules.losses.GWLosses.contrastive_fn": {"tf": 1}, "shimmer.modules.losses.GWLossesBayesian.contrastive_fn": {"tf": 1}}, "df": 3}}, "b": {"docs": {}, "df": 0, "a": {"docs": {}, "df": 0, "t": {"docs": {}, "df": 0, "c": {"docs": {}, "df": 0, "h": {"docs": {"shimmer.modules.global_workspace.GlobalWorkspaceBase.batch_gw_states": {"tf": 1}, "shimmer.modules.utils.batch_demi_cycles": {"tf": 1}, "shimmer.modules.utils.batch_cycles": {"tf": 1}, "shimmer.modules.utils.batch_translations": {"tf": 1}, "shimmer.utils.group_batch_size": {"tf": 1}, "shimmer.utils.groups_batch_size": {"tf": 1}}, "df": 
6}}}, "y": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "s": {"docs": {}, "df": 0, "i": {"docs": {}, "df": 0, "a": {"docs": {}, "df": 0, "n": {"docs": {"shimmer.modules.losses.contrastive_loss_bayesian": {"tf": 1}}, "df": 1}}}}}}}, "r": {"docs": {}, "df": 0, "o": {"docs": {}, "df": 0, "a": {"docs": {}, "df": 0, "d": {"docs": {}, "df": 0, "c": {"docs": {}, "df": 0, "a": {"docs": {}, "df": 0, "s": {"docs": {}, "df": 0, "t": {"docs": {"shimmer.modules.domain.DomainModule.compute_broadcast_loss": {"tf": 1}, "shimmer.modules.losses.broadcast_loss": {"tf": 1}, "shimmer.modules.losses.GWLosses.broadcast_loss": {"tf": 1}, "shimmer.modules.losses.GWLossesBayesian.broadcast_loss": {"tf": 1}}, "df": 4, "l": {"docs": {}, "df": 0, "o": {"docs": {}, "df": 0, "s": {"docs": {}, "df": 0, "s": {"docs": {}, "df": 0, "c": {"docs": {}, "df": 0, "o": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "f": {"docs": {}, "df": 0, "s": {"docs": {"shimmer.modules.losses.BroadcastLossCoefs": {"tf": 1}, "shimmer.modules.losses.BroadcastLossCoefs.contrastives": {"tf": 1}, "shimmer.modules.losses.BroadcastLossCoefs.fused": {"tf": 1}, "shimmer.modules.losses.BroadcastLossCoefs.demi_cycles": {"tf": 1}, "shimmer.modules.losses.BroadcastLossCoefs.cycles": {"tf": 1}, "shimmer.modules.losses.BroadcastLossCoefs.translations": {"tf": 1}}, "df": 6}}}}}}}}}}}}}}}}}, "e": {"docs": {}, "df": 0, "t": {"docs": {}, "df": 0, "a": {"docs": {"shimmer.modules.vae.VAE.beta": {"tf": 1}}, "df": 1}}}}, "c": {"docs": {}, "df": 0, "y": {"docs": {"shimmer.modules.domain.DomainModule.compute_cy_loss": {"tf": 1}}, "df": 1, "c": {"docs": {}, "df": 0, "l": {"docs": {}, "df": 0, "e": {"docs": {"shimmer.modules.losses.demi_cycle_loss": {"tf": 1}, "shimmer.modules.losses.cycle_loss": {"tf": 1}, "shimmer.modules.losses.GWLosses2Domains.demi_cycle_loss": {"tf": 1}, "shimmer.modules.losses.GWLosses2Domains.cycle_loss": {"tf": 1}, "shimmer.modules.utils.cycle": {"tf": 1}}, "df": 5, "s": {"docs": 
{"shimmer.modules.global_workspace.GWPredictions.demi_cycles": {"tf": 1}, "shimmer.modules.global_workspace.GWPredictions.cycles": {"tf": 1}, "shimmer.modules.losses.LossCoefs.demi_cycles": {"tf": 1}, "shimmer.modules.losses.LossCoefs.cycles": {"tf": 1}, "shimmer.modules.losses.BroadcastLossCoefs.demi_cycles": {"tf": 1}, "shimmer.modules.losses.BroadcastLossCoefs.cycles": {"tf": 1}, "shimmer.modules.utils.batch_demi_cycles": {"tf": 1}, "shimmer.modules.utils.batch_cycles": {"tf": 1}}, "df": 8}}}}}, "o": {"docs": {}, "df": 0, "m": {"docs": {}, "df": 0, "p": {"docs": {}, "df": 0, "u": {"docs": {}, "df": 0, "t": {"docs": {}, "df": 0, "e": {"docs": {"shimmer.modules.domain.DomainModule.compute_loss": {"tf": 1}, "shimmer.modules.domain.DomainModule.compute_dcy_loss": {"tf": 1}, "shimmer.modules.domain.DomainModule.compute_cy_loss": {"tf": 1}, "shimmer.modules.domain.DomainModule.compute_tr_loss": {"tf": 1}, "shimmer.modules.domain.DomainModule.compute_broadcast_loss": {"tf": 1}, "shimmer.modules.gw_module.compute_fusion_scores": {"tf": 1}}, "df": 6}}}}}, "n": {"docs": {}, "df": 0, "t": {"docs": {}, "df": 0, "r": {"docs": {}, "df": 0, "a": {"docs": {}, "df": 0, "s": {"docs": {}, "df": 0, "t": {"docs": {}, "df": 0, "i": {"docs": {}, "df": 0, "v": {"docs": {}, "df": 0, "e": {"docs": {"shimmer.modules.losses.contrastive_loss": {"tf": 1}, "shimmer.modules.losses.contrastive_loss_bayesian": {"tf": 1}, "shimmer.modules.losses.GWLosses2Domains.contrastive_fn": {"tf": 1}, "shimmer.modules.losses.GWLosses2Domains.contrastive_loss": {"tf": 1}, "shimmer.modules.losses.GWLosses.contrastive_fn": {"tf": 1}, "shimmer.modules.losses.GWLosses.contrastive_loss": {"tf": 1}, "shimmer.modules.losses.GWLossesBayesian.contrastive_fn": {"tf": 1}, "shimmer.modules.losses.GWLossesBayesian.contrastive_loss": {"tf": 1}, "shimmer.modules.contrastive_loss.contrastive_loss": {"tf": 1}}, "df": 9, "s": {"docs": {"shimmer.modules.losses.LossCoefs.contrastives": {"tf": 1}, 
"shimmer.modules.losses.BroadcastLossCoefs.contrastives": {"tf": 1}}, "df": 2}, "l": {"docs": {}, "df": 0, "o": {"docs": {}, "df": 0, "s": {"docs": {}, "df": 0, "s": {"docs": {"shimmer.modules.contrastive_loss.ContrastiveLoss": {"tf": 1}, "shimmer.modules.contrastive_loss.ContrastiveLoss.__init__": {"tf": 1}, "shimmer.modules.contrastive_loss.ContrastiveLoss.learn_logit_scale": {"tf": 1}, "shimmer.modules.contrastive_loss.ContrastiveLoss.reduction": {"tf": 1}, "shimmer.modules.contrastive_loss.ContrastiveLoss.forward": {"tf": 1}}, "df": 5, "t": {"docs": {}, "df": 0, "y": {"docs": {}, "df": 0, "p": {"docs": {}, "df": 0, "e": {"docs": {"shimmer.modules.contrastive_loss.ContrastiveLossType": {"tf": 1}}, "df": 1}}}}, "b": {"docs": {}, "df": 0, "a": {"docs": {}, "df": 0, "y": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "s": {"docs": {}, "df": 0, "i": {"docs": {}, "df": 0, "a": {"docs": {}, "df": 0, "n": {"docs": {}, "df": 0, "t": {"docs": {}, "df": 0, "y": {"docs": {}, "df": 0, "p": {"docs": {}, "df": 0, "e": {"docs": {"shimmer.modules.contrastive_loss.ContrastiveLossBayesianType": {"tf": 1}}, "df": 1}}}}}}}}}}}}}}}}}}}}}}}}, "s": {"docs": {}, "df": 0, "t": {"docs": {}, "df": 0, "r": {"docs": {}, "df": 0, "a": {"docs": {}, "df": 0, "s": {"docs": {}, "df": 0, "t": {"docs": {}, "df": 0, "i": {"docs": {}, "df": 0, "v": {"docs": {}, "df": 0, "e": {"docs": {"shimmer.modules.losses.GWLossesBayesian.use_normalized_constrastive": {"tf": 1}}, "df": 1}}}}}}}}}}, "e": {"docs": {}, "df": 0, "f": {"docs": {}, "df": 0, "s": {"docs": {"shimmer.modules.losses.GWLosses2Domains.loss_coefs": {"tf": 1}, "shimmer.modules.losses.GWLosses.loss_coefs": {"tf": 1}, "shimmer.modules.losses.GWLossesBayesian.loss_coefs": {"tf": 1}}, "df": 3}}}}, "h": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "c": {"docs": {}, "df": 0, "k": {"docs": {}, "df": 0, "p": {"docs": {}, "df": 0, "o": {"docs": {}, "df": 0, "i": {"docs": {}, "df": 0, "n": {"docs": {}, "df": 0, "t": {"docs": 
{"shimmer.utils.SaveMigrations.on_save_checkpoint": {"tf": 1}}, "df": 1}}}}}}}}}, "k": {"docs": {}, "df": 0, "p": {"docs": {}, "df": 0, "t": {"docs": {"shimmer.cli.ckpt_migration.migrate_ckpt": {"tf": 1}}, "df": 1}}}}, "i": {"docs": {}, "df": 0, "n": {"docs": {"shimmer.modules.gw_module.GWDecoder.in_dim": {"tf": 1}}, "df": 1, "i": {"docs": {}, "df": 0, "t": {"docs": {"shimmer.modules.global_workspace.GlobalWorkspace2Domains.__init__": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspace.__init__": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBayesian.__init__": {"tf": 1}, "shimmer.modules.domain.LossOutput.__init__": {"tf": 1}, "shimmer.modules.domain.DomainModule.__init__": {"tf": 1}, "shimmer.modules.gw_module.GWDecoder.__init__": {"tf": 1}, "shimmer.modules.gw_module.GWEncoder.__init__": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBase.__init__": {"tf": 1}, "shimmer.modules.gw_module.GWModule.__init__": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBayesian.__init__": {"tf": 1}, "shimmer.modules.selection.KQFixedQSelection.__init__": {"tf": 1}, "shimmer.modules.selection.RandomSelection.__init__": {"tf": 1}, "shimmer.modules.selection.DynamicQueryAttention.__init__": {"tf": 1}, "shimmer.modules.losses.GWLosses2Domains.__init__": {"tf": 1}, "shimmer.modules.losses.GWLosses.__init__": {"tf": 1}, "shimmer.modules.losses.GWLossesBayesian.__init__": {"tf": 1}, "shimmer.modules.contrastive_loss.ContrastiveLoss.__init__": {"tf": 1}, "shimmer.dataset.RepeatedDataset.__init__": {"tf": 1}, "shimmer.modules.vae.VAE.__init__": {"tf": 1}}, "df": 19}}, "f": {"docs": {}, "df": 0, "o": {"docs": {"shimmer.modules.contrastive_loss.info_nce": {"tf": 1}}, "df": 1}}}}, "p": {"docs": {}, "df": 0, "r": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "t": {"docs": {}, "df": 0, "r": {"docs": {}, "df": 0, "a": {"docs": {}, "df": 0, "i": {"docs": {}, "df": 0, "n": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "d": {"docs": 
{"shimmer.modules.global_workspace.pretrained_global_workspace": {"tf": 1}}, "df": 1}}}}}}}, "c": {"docs": {}, "df": 0, "i": {"docs": {}, "df": 0, "s": {"docs": {}, "df": 0, "i": {"docs": {}, "df": 0, "o": {"docs": {}, "df": 0, "n": {"docs": {"shimmer.modules.gw_module.GWModuleBayesian.sensitivity_precision": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBayesian.precision_softmax_temp": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBayesian.get_precision": {"tf": 1}}, "df": 3, "s": {"docs": {"shimmer.modules.gw_module.GWModuleBayesian.precisions": {"tf": 1}}, "df": 1}}}}}}}}}, "a": {"docs": {}, "df": 0, "r": {"docs": {}, "df": 0, "t": {"docs": {}, "df": 0, "i": {"docs": {}, "df": 0, "t": {"docs": {}, "df": 0, "i": {"docs": {}, "df": 0, "o": {"docs": {}, "df": 0, "n": {"docs": {}, "df": 0, "s": {"docs": {"shimmer.modules.losses.generate_partitions": {"tf": 1}}, "df": 1}}}}}}}}}}, "n": {"docs": {"shimmer.modules.gw_module.get_n_layers": {"tf": 1}, "shimmer.modules.gw_module.GWDecoder.n_layers": {"tf": 1}}, "df": 2, "o": {"docs": {}, "df": 0, "r": {"docs": {}, "df": 0, "m": {"docs": {}, "df": 0, "a": {"docs": {}, "df": 0, "l": {"docs": {}, "df": 0, "i": {"docs": {}, "df": 0, "z": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "d": {"docs": {"shimmer.modules.losses.GWLossesBayesian.use_normalized_constrastive": {"tf": 1}}, "df": 1}}}}}}}}}, "c": {"docs": {}, "df": 0, "e": {"docs": {"shimmer.modules.contrastive_loss.info_nce": {"tf": 1}}, "df": 1}}, "l": {"docs": {}, "df": 0, "l": {"docs": {"shimmer.modules.vae.gaussian_nll": {"tf": 1}}, "df": 1}}}, "h": {"docs": {}, "df": 0, "i": {"docs": {}, "df": 0, "d": {"docs": {}, "df": 0, "d": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "n": {"docs": {"shimmer.modules.gw_module.GWDecoder.hidden_dim": {"tf": 1}}, "df": 1}}}}}, "e": {"docs": {}, "df": 0, "a": {"docs": {}, "df": 0, "d": {"docs": {"shimmer.modules.selection.KQFixedQSelection.head_size": {"tf": 1}, "shimmer.modules.selection.DynamicQueryAttention.head_size": 
{"tf": 1}}, "df": 2}}}}, "u": {"docs": {}, "df": 0, "p": {"docs": {}, "df": 0, "d": {"docs": {}, "df": 0, "a": {"docs": {}, "df": 0, "t": {"docs": {}, "df": 0, "e": {"docs": {"shimmer.modules.selection.SelectionBase.update_gw_state": {"tf": 1}}, "df": 1}}}}}, "s": {"docs": {}, "df": 0, "e": {"docs": {"shimmer.modules.losses.GWLossesBayesian.use_normalized_constrastive": {"tf": 1}}, "df": 1}}}, "k": {"docs": {}, "df": 0, "q": {"docs": {}, "df": 0, "f": {"docs": {}, "df": 0, "i": {"docs": {}, "df": 0, "x": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "d": {"docs": {}, "df": 0, "q": {"docs": {}, "df": 0, "s": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "l": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "c": {"docs": {}, "df": 0, "t": {"docs": {}, "df": 0, "i": {"docs": {}, "df": 0, "o": {"docs": {}, "df": 0, "n": {"docs": {"shimmer.modules.selection.KQFixedQSelection": {"tf": 1}, "shimmer.modules.selection.KQFixedQSelection.__init__": {"tf": 1}, "shimmer.modules.selection.KQFixedQSelection.head_size": {"tf": 1}, "shimmer.modules.selection.KQFixedQSelection.query_layer": {"tf": 1}, "shimmer.modules.selection.KQFixedQSelection.key_layers": {"tf": 1}, "shimmer.modules.selection.KQFixedQSelection.forward": {"tf": 1}}, "df": 6}}}}}}}}}}}}}}}}, "e": {"docs": {}, "df": 0, "y": {"docs": {"shimmer.modules.selection.KQFixedQSelection.key_layers": {"tf": 1}, "shimmer.modules.selection.DynamicQueryAttention.key_layers": {"tf": 1}}, "df": 2}}, "l": {"docs": {"shimmer.modules.vae.kl_divergence_loss": {"tf": 1}}, "df": 1}}, "q": {"docs": {}, "df": 0, "u": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "r": {"docs": {}, "df": 0, "y": {"docs": {"shimmer.modules.selection.KQFixedQSelection.query_layer": {"tf": 1}, "shimmer.modules.selection.DynamicQueryAttention.query_layer": {"tf": 1}}, "df": 2}}}}}, "v": {"docs": {}, "df": 0, "a": {"docs": {}, "df": 0, "e": {"docs": {"shimmer.modules.vae.VAE": {"tf": 1}, "shimmer.modules.vae.VAE.__init__": {"tf": 1}, 
"shimmer.modules.vae.VAE.beta": {"tf": 1}, "shimmer.modules.vae.VAE.encoder": {"tf": 1}, "shimmer.modules.vae.VAE.decoder": {"tf": 1}, "shimmer.modules.vae.VAE.encode": {"tf": 1}, "shimmer.modules.vae.VAE.decode": {"tf": 1}, "shimmer.modules.vae.VAE.forward": {"tf": 1}}, "df": 8, "e": {"docs": {}, "df": 0, "n": {"docs": {}, "df": 0, "c": {"docs": {}, "df": 0, "o": {"docs": {}, "df": 0, "d": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "r": {"docs": {"shimmer.modules.vae.VAEEncoder": {"tf": 1}, "shimmer.modules.vae.VAEEncoder.forward": {"tf": 1}}, "df": 2}}}}}}}, "d": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "c": {"docs": {}, "df": 0, "o": {"docs": {}, "df": 0, "d": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "r": {"docs": {"shimmer.modules.vae.VAEDecoder": {"tf": 1}, "shimmer.modules.vae.VAEDecoder.forward": {"tf": 1}}, "df": 2}}}}}}}}}}}}, "fullname": {"root": {"docs": {"shimmer.modules.global_workspace.GlobalWorkspace2Domains.__init__": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspace.__init__": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBayesian.__init__": {"tf": 1}, "shimmer.modules.domain.LossOutput.__init__": {"tf": 1}, "shimmer.modules.domain.DomainModule.__init__": {"tf": 1}, "shimmer.modules.gw_module.GWDecoder.__init__": {"tf": 1}, "shimmer.modules.gw_module.GWEncoder.__init__": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBase.__init__": {"tf": 1}, "shimmer.modules.gw_module.GWModule.__init__": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBayesian.__init__": {"tf": 1}, "shimmer.modules.selection.KQFixedQSelection.__init__": {"tf": 1}, "shimmer.modules.selection.RandomSelection.__init__": {"tf": 1}, "shimmer.modules.selection.DynamicQueryAttention.__init__": {"tf": 1}, "shimmer.modules.losses.GWLosses2Domains.__init__": {"tf": 1}, "shimmer.modules.losses.GWLosses.__init__": {"tf": 1}, "shimmer.modules.losses.GWLossesBayesian.__init__": {"tf": 1}, "shimmer.modules.contrastive_loss.ContrastiveLoss.__init__": 
{"tf": 1}, "shimmer.dataset.RepeatedDataset.__init__": {"tf": 1}, "shimmer.modules.vae.VAE.__init__": {"tf": 1}}, "df": 19, "s": {"docs": {}, "df": 0, "h": {"docs": {}, "df": 0, "i": {"docs": {}, "df": 0, "m": {"docs": {}, "df": 0, "m": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "r": {"docs": {"shimmer.types": {"tf": 1}, "shimmer.types.RawDomainGroupT": {"tf": 1}, "shimmer.types.RawDomainGroupDT": {"tf": 1}, "shimmer.types.LatentsDomainGroupT": {"tf": 1}, "shimmer.types.LatentsDomainGroupDT": {"tf": 1}, "shimmer.types.RawDomainGroupsT": {"tf": 1}, "shimmer.types.RawDomainGroupsDT": {"tf": 1}, "shimmer.types.LatentsDomainGroupsT": {"tf": 1}, "shimmer.types.LatentsDomainGroupsDT": {"tf": 1}, "shimmer.types.ModelModeT": {"tf": 1}, "shimmer.modules.global_workspace": {"tf": 1}, "shimmer.modules.global_workspace.SchedulerArgs": {"tf": 1}, "shimmer.modules.global_workspace.SchedulerArgs.max_lr": {"tf": 1}, "shimmer.modules.global_workspace.SchedulerArgs.total_steps": {"tf": 1}, "shimmer.modules.global_workspace.GWPredictionsBase": {"tf": 1}, "shimmer.modules.global_workspace.GWPredictionsBase.states": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBase": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.gw_mod": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.selection_mod": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.loss_mod": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.optim_lr": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.optim_weight_decay": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.scheduler_args": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.domain_mods": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.workspace_dim": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.encode_and_fuse": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.encode": {"tf": 1}, 
"shimmer.modules.global_workspace.GlobalWorkspaceBase.fuse": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.decode": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.batch_gw_states": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.encode_domain": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.encode_domains": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.decode_domain": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.decode_domains": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.generic_step": {"tf": 1}, "shimmer.modules.global_workspace.freeze_domain_modules": {"tf": 1}, "shimmer.modules.global_workspace.GWPredictions": {"tf": 1}, "shimmer.modules.global_workspace.GWPredictions.demi_cycles": {"tf": 1}, "shimmer.modules.global_workspace.GWPredictions.cycles": {"tf": 1}, "shimmer.modules.global_workspace.GWPredictions.translations": {"tf": 1}, "shimmer.modules.global_workspace.GWPredictions.states": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspace2Domains": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspace2Domains.__init__": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspace2Domains.forward": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspace": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspace.__init__": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspace.forward": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBayesian": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBayesian.__init__": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBayesian.forward": {"tf": 1}, "shimmer.modules.global_workspace.pretrained_global_workspace": {"tf": 1}, "shimmer.modules.domain": {"tf": 1}, "shimmer.modules.domain.LossOutput": {"tf": 1}, "shimmer.modules.domain.LossOutput.__init__": {"tf": 1}, "shimmer.modules.domain.LossOutput.loss": {"tf": 1}, 
"shimmer.modules.domain.LossOutput.metrics": {"tf": 1}, "shimmer.modules.domain.LossOutput.all": {"tf": 1}, "shimmer.modules.domain.DomainModule": {"tf": 1}, "shimmer.modules.domain.DomainModule.__init__": {"tf": 1}, "shimmer.modules.domain.DomainModule.latent_dim": {"tf": 1}, "shimmer.modules.domain.DomainModule.encode": {"tf": 1}, "shimmer.modules.domain.DomainModule.decode": {"tf": 1}, "shimmer.modules.domain.DomainModule.compute_loss": {"tf": 1}, "shimmer.modules.domain.DomainModule.compute_dcy_loss": {"tf": 1}, "shimmer.modules.domain.DomainModule.compute_cy_loss": {"tf": 1}, "shimmer.modules.domain.DomainModule.compute_tr_loss": {"tf": 1}, "shimmer.modules.domain.DomainModule.compute_broadcast_loss": {"tf": 1}, "shimmer.modules.gw_module": {"tf": 1}, "shimmer.modules.gw_module.get_n_layers": {"tf": 1}, "shimmer.modules.gw_module.GWDecoder": {"tf": 1}, "shimmer.modules.gw_module.GWDecoder.__init__": {"tf": 1}, "shimmer.modules.gw_module.GWDecoder.in_dim": {"tf": 1}, "shimmer.modules.gw_module.GWDecoder.hidden_dim": {"tf": 1}, "shimmer.modules.gw_module.GWDecoder.out_dim": {"tf": 1}, "shimmer.modules.gw_module.GWDecoder.n_layers": {"tf": 1}, "shimmer.modules.gw_module.GWEncoder": {"tf": 1}, "shimmer.modules.gw_module.GWEncoder.__init__": {"tf": 1}, "shimmer.modules.gw_module.GWEncoder.forward": {"tf": 1}, "shimmer.modules.gw_module.GWEncoderLinear": {"tf": 1}, "shimmer.modules.gw_module.GWEncoderLinear.forward": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBase": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBase.__init__": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBase.domain_mods": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBase.workspace_dim": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBase.fuse": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBase.encode": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBase.encode_and_fuse": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBase.decode": {"tf": 1}, "shimmer.modules.gw_module.GWModule": {"tf": 1}, 
"shimmer.modules.gw_module.GWModule.__init__": {"tf": 1}, "shimmer.modules.gw_module.GWModule.gw_encoders": {"tf": 1}, "shimmer.modules.gw_module.GWModule.gw_decoders": {"tf": 1}, "shimmer.modules.gw_module.GWModule.fuse": {"tf": 1}, "shimmer.modules.gw_module.GWModule.encode": {"tf": 1}, "shimmer.modules.gw_module.GWModule.decode": {"tf": 1}, "shimmer.modules.gw_module.compute_fusion_scores": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBayesian": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBayesian.__init__": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBayesian.precisions": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBayesian.sensitivity_selection": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBayesian.sensitivity_precision": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBayesian.precision_softmax_temp": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBayesian.get_precision": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBayesian.fuse": {"tf": 1}, "shimmer.modules.selection": {"tf": 1}, "shimmer.modules.selection.SelectionBase": {"tf": 1}, "shimmer.modules.selection.SelectionBase.update_gw_state": {"tf": 1}, "shimmer.modules.selection.SelectionBase.forward": {"tf": 1}, "shimmer.modules.selection.SingleDomainSelection": {"tf": 1}, "shimmer.modules.selection.SingleDomainSelection.forward": {"tf": 1}, "shimmer.modules.selection.FixedSharedSelection": {"tf": 1}, "shimmer.modules.selection.FixedSharedSelection.forward": {"tf": 1}, "shimmer.modules.selection.KQFixedQSelection": {"tf": 1}, "shimmer.modules.selection.KQFixedQSelection.__init__": {"tf": 1}, "shimmer.modules.selection.KQFixedQSelection.head_size": {"tf": 1}, "shimmer.modules.selection.KQFixedQSelection.query_layer": {"tf": 1}, "shimmer.modules.selection.KQFixedQSelection.key_layers": {"tf": 1}, "shimmer.modules.selection.KQFixedQSelection.forward": {"tf": 1}, "shimmer.modules.selection.RandomSelection": {"tf": 1}, "shimmer.modules.selection.RandomSelection.__init__": {"tf": 1}, 
"shimmer.modules.selection.RandomSelection.temperature": {"tf": 1}, "shimmer.modules.selection.RandomSelection.forward": {"tf": 1}, "shimmer.modules.selection.DynamicQueryAttention": {"tf": 1}, "shimmer.modules.selection.DynamicQueryAttention.__init__": {"tf": 1}, "shimmer.modules.selection.DynamicQueryAttention.head_size": {"tf": 1}, "shimmer.modules.selection.DynamicQueryAttention.query_layer": {"tf": 1}, "shimmer.modules.selection.DynamicQueryAttention.key_layers": {"tf": 1}, "shimmer.modules.selection.DynamicQueryAttention.fuse_weighted_encodings": {"tf": 1}, "shimmer.modules.selection.DynamicQueryAttention.forward": {"tf": 1}, "shimmer.modules.losses": {"tf": 1}, "shimmer.modules.losses.GWLossesBase": {"tf": 1}, "shimmer.modules.losses.GWLossesBase.step": {"tf": 1}, "shimmer.modules.losses.demi_cycle_loss": {"tf": 1}, "shimmer.modules.losses.cycle_loss": {"tf": 1}, "shimmer.modules.losses.translation_loss": {"tf": 1}, "shimmer.modules.losses.contrastive_loss": {"tf": 1}, "shimmer.modules.losses.contrastive_loss_bayesian": {"tf": 1}, "shimmer.modules.losses.LossCoefs": {"tf": 1}, "shimmer.modules.losses.LossCoefs.demi_cycles": {"tf": 1}, "shimmer.modules.losses.LossCoefs.cycles": {"tf": 1}, "shimmer.modules.losses.LossCoefs.translations": {"tf": 1}, "shimmer.modules.losses.LossCoefs.contrastives": {"tf": 1}, "shimmer.modules.losses.GWLosses2Domains": {"tf": 1}, "shimmer.modules.losses.GWLosses2Domains.__init__": {"tf": 1}, "shimmer.modules.losses.GWLosses2Domains.gw_mod": {"tf": 1}, "shimmer.modules.losses.GWLosses2Domains.selection_mod": {"tf": 1}, "shimmer.modules.losses.GWLosses2Domains.domain_mods": {"tf": 1}, "shimmer.modules.losses.GWLosses2Domains.loss_coefs": {"tf": 1}, "shimmer.modules.losses.GWLosses2Domains.contrastive_fn": {"tf": 1}, "shimmer.modules.losses.GWLosses2Domains.demi_cycle_loss": {"tf": 1}, "shimmer.modules.losses.GWLosses2Domains.cycle_loss": {"tf": 1}, "shimmer.modules.losses.GWLosses2Domains.translation_loss": {"tf": 1}, 
"shimmer.modules.losses.GWLosses2Domains.contrastive_loss": {"tf": 1}, "shimmer.modules.losses.GWLosses2Domains.step": {"tf": 1}, "shimmer.modules.losses.generate_partitions": {"tf": 1}, "shimmer.modules.losses.broadcast_loss": {"tf": 1}, "shimmer.modules.losses.BroadcastLossCoefs": {"tf": 1}, "shimmer.modules.losses.BroadcastLossCoefs.contrastives": {"tf": 1}, "shimmer.modules.losses.BroadcastLossCoefs.fused": {"tf": 1}, "shimmer.modules.losses.BroadcastLossCoefs.demi_cycles": {"tf": 1}, "shimmer.modules.losses.BroadcastLossCoefs.cycles": {"tf": 1}, "shimmer.modules.losses.BroadcastLossCoefs.translations": {"tf": 1}, "shimmer.modules.losses.GWLosses": {"tf": 1}, "shimmer.modules.losses.GWLosses.__init__": {"tf": 1}, "shimmer.modules.losses.GWLosses.gw_mod": {"tf": 1}, "shimmer.modules.losses.GWLosses.selection_mod": {"tf": 1}, "shimmer.modules.losses.GWLosses.domain_mods": {"tf": 1}, "shimmer.modules.losses.GWLosses.loss_coefs": {"tf": 1}, "shimmer.modules.losses.GWLosses.contrastive_fn": {"tf": 1}, "shimmer.modules.losses.GWLosses.contrastive_loss": {"tf": 1}, "shimmer.modules.losses.GWLosses.broadcast_loss": {"tf": 1}, "shimmer.modules.losses.GWLosses.step": {"tf": 1}, "shimmer.modules.losses.GWLossesBayesian": {"tf": 1}, "shimmer.modules.losses.GWLossesBayesian.__init__": {"tf": 1}, "shimmer.modules.losses.GWLossesBayesian.gw_mod": {"tf": 1}, "shimmer.modules.losses.GWLossesBayesian.selection_mod": {"tf": 1}, "shimmer.modules.losses.GWLossesBayesian.domain_mods": {"tf": 1}, "shimmer.modules.losses.GWLossesBayesian.loss_coefs": {"tf": 1}, "shimmer.modules.losses.GWLossesBayesian.contrastive_fn": {"tf": 1}, "shimmer.modules.losses.GWLossesBayesian.use_normalized_constrastive": {"tf": 1}, "shimmer.modules.losses.GWLossesBayesian.contrastive_loss": {"tf": 1}, "shimmer.modules.losses.GWLossesBayesian.broadcast_loss": {"tf": 1}, "shimmer.modules.losses.GWLossesBayesian.step": {"tf": 1}, "shimmer.modules.contrastive_loss": {"tf": 1}, 
"shimmer.modules.contrastive_loss.ContrastiveLossType": {"tf": 1}, "shimmer.modules.contrastive_loss.ContrastiveLossBayesianType": {"tf": 1}, "shimmer.modules.contrastive_loss.info_nce": {"tf": 1}, "shimmer.modules.contrastive_loss.contrastive_loss": {"tf": 1}, "shimmer.modules.contrastive_loss.ContrastiveLoss": {"tf": 1}, "shimmer.modules.contrastive_loss.ContrastiveLoss.__init__": {"tf": 1}, "shimmer.modules.contrastive_loss.ContrastiveLoss.learn_logit_scale": {"tf": 1}, "shimmer.modules.contrastive_loss.ContrastiveLoss.reduction": {"tf": 1}, "shimmer.modules.contrastive_loss.ContrastiveLoss.forward": {"tf": 1}, "shimmer.dataset": {"tf": 1}, "shimmer.dataset.RepeatedDataset": {"tf": 1}, "shimmer.dataset.RepeatedDataset.__init__": {"tf": 1}, "shimmer.dataset.RepeatedDataset.dataset": {"tf": 1}, "shimmer.dataset.RepeatedDataset.dataset_size": {"tf": 1}, "shimmer.modules.vae": {"tf": 1}, "shimmer.modules.vae.reparameterize": {"tf": 1}, "shimmer.modules.vae.kl_divergence_loss": {"tf": 1}, "shimmer.modules.vae.gaussian_nll": {"tf": 1}, "shimmer.modules.vae.VAEEncoder": {"tf": 1}, "shimmer.modules.vae.VAEEncoder.forward": {"tf": 1}, "shimmer.modules.vae.VAEDecoder": {"tf": 1}, "shimmer.modules.vae.VAEDecoder.forward": {"tf": 1}, "shimmer.modules.vae.VAE": {"tf": 1}, "shimmer.modules.vae.VAE.__init__": {"tf": 1}, "shimmer.modules.vae.VAE.beta": {"tf": 1}, "shimmer.modules.vae.VAE.encoder": {"tf": 1}, "shimmer.modules.vae.VAE.decoder": {"tf": 1}, "shimmer.modules.vae.VAE.encode": {"tf": 1}, "shimmer.modules.vae.VAE.decode": {"tf": 1}, "shimmer.modules.vae.VAE.forward": {"tf": 1}, "shimmer.modules.utils": {"tf": 1}, "shimmer.modules.utils.translation": {"tf": 1}, "shimmer.modules.utils.cycle": {"tf": 1}, "shimmer.modules.utils.batch_demi_cycles": {"tf": 1}, "shimmer.modules.utils.batch_cycles": {"tf": 1}, "shimmer.modules.utils.batch_translations": {"tf": 1}, "shimmer.utils": {"tf": 1}, "shimmer.utils.MIGRATION_DIR": {"tf": 1}, "shimmer.utils.group_batch_size": {"tf": 1}, 
"shimmer.utils.groups_batch_size": {"tf": 1}, "shimmer.utils.groups_device": {"tf": 1}, "shimmer.utils.group_device": {"tf": 1}, "shimmer.utils.migrate_model": {"tf": 1}, "shimmer.utils.SaveMigrations": {"tf": 1}, "shimmer.utils.SaveMigrations.migrations": {"tf": 1}, "shimmer.utils.SaveMigrations.on_save_checkpoint": {"tf": 1}, "shimmer.cli.ckpt_migration": {"tf": 1}, "shimmer.cli.ckpt_migration.migrate_ckpt": {"tf": 1}}, "df": 232}}}}}}, "c": {"docs": {}, "df": 0, "h": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "d": {"docs": {}, "df": 0, "u": {"docs": {}, "df": 0, "l": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "r": {"docs": {"shimmer.modules.global_workspace.GlobalWorkspaceBase.scheduler_args": {"tf": 1}}, "df": 1, "a": {"docs": {}, "df": 0, "r": {"docs": {}, "df": 0, "g": {"docs": {}, "df": 0, "s": {"docs": {"shimmer.modules.global_workspace.SchedulerArgs": {"tf": 1}, "shimmer.modules.global_workspace.SchedulerArgs.max_lr": {"tf": 1}, "shimmer.modules.global_workspace.SchedulerArgs.total_steps": {"tf": 1}}, "df": 3}}}}}}}}}}}, "o": {"docs": {}, "df": 0, "r": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "s": {"docs": {"shimmer.modules.gw_module.compute_fusion_scores": {"tf": 1}}, "df": 1}}}}, "a": {"docs": {}, "df": 0, "l": {"docs": {}, "df": 0, "e": {"docs": {"shimmer.modules.contrastive_loss.ContrastiveLoss.learn_logit_scale": {"tf": 1}}, "df": 1}}}}, "t": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "p": {"docs": {"shimmer.modules.global_workspace.GlobalWorkspaceBase.generic_step": {"tf": 1}, "shimmer.modules.losses.GWLossesBase.step": {"tf": 1}, "shimmer.modules.losses.GWLosses2Domains.step": {"tf": 1}, "shimmer.modules.losses.GWLosses.step": {"tf": 1}, "shimmer.modules.losses.GWLossesBayesian.step": {"tf": 1}}, "df": 5, "s": {"docs": {"shimmer.modules.global_workspace.SchedulerArgs.total_steps": {"tf": 1}}, "df": 1}}}, "a": {"docs": {}, "df": 0, "t": {"docs": {}, "df": 0, "e": {"docs": 
{"shimmer.modules.selection.SelectionBase.update_gw_state": {"tf": 1}}, "df": 1, "s": {"docs": {"shimmer.modules.global_workspace.GWPredictionsBase.states": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.batch_gw_states": {"tf": 1}, "shimmer.modules.global_workspace.GWPredictions.states": {"tf": 1}}, "df": 3}}}}}, "e": {"docs": {}, "df": 0, "l": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "c": {"docs": {}, "df": 0, "t": {"docs": {}, "df": 0, "i": {"docs": {}, "df": 0, "o": {"docs": {}, "df": 0, "n": {"docs": {"shimmer.modules.global_workspace.GlobalWorkspaceBase.selection_mod": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBayesian.sensitivity_selection": {"tf": 1}, "shimmer.modules.selection": {"tf": 1}, "shimmer.modules.selection.SelectionBase": {"tf": 1}, "shimmer.modules.selection.SelectionBase.update_gw_state": {"tf": 1}, "shimmer.modules.selection.SelectionBase.forward": {"tf": 1}, "shimmer.modules.selection.SingleDomainSelection": {"tf": 1}, "shimmer.modules.selection.SingleDomainSelection.forward": {"tf": 1}, "shimmer.modules.selection.FixedSharedSelection": {"tf": 1}, "shimmer.modules.selection.FixedSharedSelection.forward": {"tf": 1}, "shimmer.modules.selection.KQFixedQSelection": {"tf": 1}, "shimmer.modules.selection.KQFixedQSelection.__init__": {"tf": 1}, "shimmer.modules.selection.KQFixedQSelection.head_size": {"tf": 1}, "shimmer.modules.selection.KQFixedQSelection.query_layer": {"tf": 1}, "shimmer.modules.selection.KQFixedQSelection.key_layers": {"tf": 1}, "shimmer.modules.selection.KQFixedQSelection.forward": {"tf": 1}, "shimmer.modules.selection.RandomSelection": {"tf": 1}, "shimmer.modules.selection.RandomSelection.__init__": {"tf": 1}, "shimmer.modules.selection.RandomSelection.temperature": {"tf": 1}, "shimmer.modules.selection.RandomSelection.forward": {"tf": 1}, "shimmer.modules.selection.DynamicQueryAttention": {"tf": 1}, "shimmer.modules.selection.DynamicQueryAttention.__init__": {"tf": 1}, 
"shimmer.modules.selection.DynamicQueryAttention.head_size": {"tf": 1}, "shimmer.modules.selection.DynamicQueryAttention.query_layer": {"tf": 1}, "shimmer.modules.selection.DynamicQueryAttention.key_layers": {"tf": 1}, "shimmer.modules.selection.DynamicQueryAttention.fuse_weighted_encodings": {"tf": 1}, "shimmer.modules.selection.DynamicQueryAttention.forward": {"tf": 1}, "shimmer.modules.losses.GWLosses2Domains.selection_mod": {"tf": 1}, "shimmer.modules.losses.GWLosses.selection_mod": {"tf": 1}, "shimmer.modules.losses.GWLossesBayesian.selection_mod": {"tf": 1}}, "df": 30, "b": {"docs": {}, "df": 0, "a": {"docs": {}, "df": 0, "s": {"docs": {}, "df": 0, "e": {"docs": {"shimmer.modules.selection.SelectionBase": {"tf": 1}, "shimmer.modules.selection.SelectionBase.update_gw_state": {"tf": 1}, "shimmer.modules.selection.SelectionBase.forward": {"tf": 1}}, "df": 3}}}}}}}}}}}, "n": {"docs": {}, "df": 0, "s": {"docs": {}, "df": 0, "i": {"docs": {}, "df": 0, "t": {"docs": {}, "df": 0, "i": {"docs": {}, "df": 0, "v": {"docs": {}, "df": 0, "i": {"docs": {}, "df": 0, "t": {"docs": {}, "df": 0, "y": {"docs": {"shimmer.modules.gw_module.GWModuleBayesian.sensitivity_selection": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBayesian.sensitivity_precision": {"tf": 1}}, "df": 2}}}}}}}}}}, "o": {"docs": {}, "df": 0, "f": {"docs": {}, "df": 0, "t": {"docs": {}, "df": 0, "m": {"docs": {}, "df": 0, "a": {"docs": {}, "df": 0, "x": {"docs": {"shimmer.modules.gw_module.GWModuleBayesian.precision_softmax_temp": {"tf": 1}}, "df": 1}}}}}}, "i": {"docs": {}, "df": 0, "n": {"docs": {}, "df": 0, "g": {"docs": {}, "df": 0, "l": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "d": {"docs": {}, "df": 0, "o": {"docs": {}, "df": 0, "m": {"docs": {}, "df": 0, "a": {"docs": {}, "df": 0, "i": {"docs": {}, "df": 0, "n": {"docs": {}, "df": 0, "s": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "l": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "c": {"docs": {}, "df": 0, "t": {"docs": {}, "df": 0, 
"i": {"docs": {}, "df": 0, "o": {"docs": {}, "df": 0, "n": {"docs": {"shimmer.modules.selection.SingleDomainSelection": {"tf": 1}, "shimmer.modules.selection.SingleDomainSelection.forward": {"tf": 1}}, "df": 2}}}}}}}}}}}}}}}}}}}, "z": {"docs": {}, "df": 0, "e": {"docs": {"shimmer.modules.selection.KQFixedQSelection.head_size": {"tf": 1}, "shimmer.modules.selection.DynamicQueryAttention.head_size": {"tf": 1}, "shimmer.dataset.RepeatedDataset.dataset_size": {"tf": 1}, "shimmer.utils.group_batch_size": {"tf": 1}, "shimmer.utils.groups_batch_size": {"tf": 1}}, "df": 5}}}, "a": {"docs": {}, "df": 0, "v": {"docs": {}, "df": 0, "e": {"docs": {"shimmer.utils.SaveMigrations.on_save_checkpoint": {"tf": 1}}, "df": 1, "m": {"docs": {}, "df": 0, "i": {"docs": {}, "df": 0, "g": {"docs": {}, "df": 0, "r": {"docs": {}, "df": 0, "a": {"docs": {}, "df": 0, "t": {"docs": {}, "df": 0, "i": {"docs": {}, "df": 0, "o": {"docs": {}, "df": 0, "n": {"docs": {}, "df": 0, "s": {"docs": {"shimmer.utils.SaveMigrations": {"tf": 1}, "shimmer.utils.SaveMigrations.migrations": {"tf": 1}, "shimmer.utils.SaveMigrations.on_save_checkpoint": {"tf": 1}}, "df": 3}}}}}}}}}}}}}}, "t": {"docs": {}, "df": 0, "y": {"docs": {}, "df": 0, "p": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "s": {"docs": {"shimmer.types": {"tf": 1}, "shimmer.types.RawDomainGroupT": {"tf": 1}, "shimmer.types.RawDomainGroupDT": {"tf": 1}, "shimmer.types.LatentsDomainGroupT": {"tf": 1}, "shimmer.types.LatentsDomainGroupDT": {"tf": 1}, "shimmer.types.RawDomainGroupsT": {"tf": 1}, "shimmer.types.RawDomainGroupsDT": {"tf": 1}, "shimmer.types.LatentsDomainGroupsT": {"tf": 1}, "shimmer.types.LatentsDomainGroupsDT": {"tf": 1}, "shimmer.types.ModelModeT": {"tf": 1}}, "df": 10}}}}, "o": {"docs": {}, "df": 0, "t": {"docs": {}, "df": 0, "a": {"docs": {}, "df": 0, "l": {"docs": {"shimmer.modules.global_workspace.SchedulerArgs.total_steps": {"tf": 1}}, "df": 1}}}}, "r": {"docs": {"shimmer.modules.domain.DomainModule.compute_tr_loss": {"tf": 
1}}, "df": 1, "a": {"docs": {}, "df": 0, "n": {"docs": {}, "df": 0, "s": {"docs": {}, "df": 0, "l": {"docs": {}, "df": 0, "a": {"docs": {}, "df": 0, "t": {"docs": {}, "df": 0, "i": {"docs": {}, "df": 0, "o": {"docs": {}, "df": 0, "n": {"docs": {"shimmer.modules.losses.translation_loss": {"tf": 1}, "shimmer.modules.losses.GWLosses2Domains.translation_loss": {"tf": 1}, "shimmer.modules.utils.translation": {"tf": 1}}, "df": 3, "s": {"docs": {"shimmer.modules.global_workspace.GWPredictions.translations": {"tf": 1}, "shimmer.modules.losses.LossCoefs.translations": {"tf": 1}, "shimmer.modules.losses.BroadcastLossCoefs.translations": {"tf": 1}, "shimmer.modules.utils.batch_translations": {"tf": 1}}, "df": 4}}}}}}}}}}}, "e": {"docs": {}, "df": 0, "m": {"docs": {}, "df": 0, "p": {"docs": {"shimmer.modules.gw_module.GWModuleBayesian.precision_softmax_temp": {"tf": 1}}, "df": 1, "e": {"docs": {}, "df": 0, "r": {"docs": {}, "df": 0, "a": {"docs": {}, "df": 0, "t": {"docs": {}, "df": 0, "u": {"docs": {}, "df": 0, "r": {"docs": {}, "df": 0, "e": {"docs": {"shimmer.modules.selection.RandomSelection.temperature": {"tf": 1}}, "df": 1}}}}}}}}}}}, "r": {"docs": {}, "df": 0, "a": {"docs": {}, "df": 0, "w": {"docs": {}, "df": 0, "d": {"docs": {}, "df": 0, "o": {"docs": {}, "df": 0, "m": {"docs": {}, "df": 0, "a": {"docs": {}, "df": 0, "i": {"docs": {}, "df": 0, "n": {"docs": {}, "df": 0, "g": {"docs": {}, "df": 0, "r": {"docs": {}, "df": 0, "o": {"docs": {}, "df": 0, "u": {"docs": {}, "df": 0, "p": {"docs": {}, "df": 0, "t": {"docs": {"shimmer.types.RawDomainGroupT": {"tf": 1}}, "df": 1}, "d": {"docs": {}, "df": 0, "t": {"docs": {"shimmer.types.RawDomainGroupDT": {"tf": 1}}, "df": 1}}, "s": {"docs": {}, "df": 0, "t": {"docs": {"shimmer.types.RawDomainGroupsT": {"tf": 1}}, "df": 1}, "d": {"docs": {}, "df": 0, "t": {"docs": {"shimmer.types.RawDomainGroupsDT": {"tf": 1}}, "df": 1}}}}}}}}}}}}}}}, "n": {"docs": {}, "df": 0, "d": {"docs": {}, "df": 0, "o": {"docs": {}, "df": 0, "m": {"docs": 
{}, "df": 0, "s": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "l": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "c": {"docs": {}, "df": 0, "t": {"docs": {}, "df": 0, "i": {"docs": {}, "df": 0, "o": {"docs": {}, "df": 0, "n": {"docs": {"shimmer.modules.selection.RandomSelection": {"tf": 1}, "shimmer.modules.selection.RandomSelection.__init__": {"tf": 1}, "shimmer.modules.selection.RandomSelection.temperature": {"tf": 1}, "shimmer.modules.selection.RandomSelection.forward": {"tf": 1}}, "df": 4}}}}}}}}}}}}}}, "e": {"docs": {}, "df": 0, "d": {"docs": {}, "df": 0, "u": {"docs": {}, "df": 0, "c": {"docs": {}, "df": 0, "t": {"docs": {}, "df": 0, "i": {"docs": {}, "df": 0, "o": {"docs": {}, "df": 0, "n": {"docs": {"shimmer.modules.contrastive_loss.ContrastiveLoss.reduction": {"tf": 1}}, "df": 1}}}}}}}, "p": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "a": {"docs": {}, "df": 0, "t": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "d": {"docs": {}, "df": 0, "d": {"docs": {}, "df": 0, "a": {"docs": {}, "df": 0, "t": {"docs": {}, "df": 0, "a": {"docs": {}, "df": 0, "s": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "t": {"docs": {"shimmer.dataset.RepeatedDataset": {"tf": 1}, "shimmer.dataset.RepeatedDataset.__init__": {"tf": 1}, "shimmer.dataset.RepeatedDataset.dataset": {"tf": 1}, "shimmer.dataset.RepeatedDataset.dataset_size": {"tf": 1}}, "df": 4}}}}}}}}}}}}, "a": {"docs": {}, "df": 0, "r": {"docs": {}, "df": 0, "a": {"docs": {}, "df": 0, "m": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "t": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "r": {"docs": {}, "df": 0, "i": {"docs": {}, "df": 0, "z": {"docs": {}, "df": 0, "e": {"docs": {"shimmer.modules.vae.reparameterize": {"tf": 1}}, "df": 1}}}}}}}}}}}}}}, "l": {"docs": {}, "df": 0, "a": {"docs": {}, "df": 0, "t": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "n": {"docs": {}, "df": 0, "t": {"docs": {"shimmer.modules.domain.DomainModule.latent_dim": {"tf": 1}}, "df": 1, "s": {"docs": {}, "df": 0, 
"d": {"docs": {}, "df": 0, "o": {"docs": {}, "df": 0, "m": {"docs": {}, "df": 0, "a": {"docs": {}, "df": 0, "i": {"docs": {}, "df": 0, "n": {"docs": {}, "df": 0, "g": {"docs": {}, "df": 0, "r": {"docs": {}, "df": 0, "o": {"docs": {}, "df": 0, "u": {"docs": {}, "df": 0, "p": {"docs": {}, "df": 0, "t": {"docs": {"shimmer.types.LatentsDomainGroupT": {"tf": 1}}, "df": 1}, "d": {"docs": {}, "df": 0, "t": {"docs": {"shimmer.types.LatentsDomainGroupDT": {"tf": 1}}, "df": 1}}, "s": {"docs": {}, "df": 0, "t": {"docs": {"shimmer.types.LatentsDomainGroupsT": {"tf": 1}}, "df": 1}, "d": {"docs": {}, "df": 0, "t": {"docs": {"shimmer.types.LatentsDomainGroupsDT": {"tf": 1}}, "df": 1}}}}}}}}}}}}}}}}}}}, "y": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "r": {"docs": {"shimmer.modules.selection.KQFixedQSelection.query_layer": {"tf": 1}, "shimmer.modules.selection.DynamicQueryAttention.query_layer": {"tf": 1}}, "df": 2, "s": {"docs": {"shimmer.modules.gw_module.get_n_layers": {"tf": 1}, "shimmer.modules.gw_module.GWDecoder.n_layers": {"tf": 1}, "shimmer.modules.selection.KQFixedQSelection.key_layers": {"tf": 1}, "shimmer.modules.selection.DynamicQueryAttention.key_layers": {"tf": 1}}, "df": 4}}}}}, "r": {"docs": {"shimmer.modules.global_workspace.SchedulerArgs.max_lr": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.optim_lr": {"tf": 1}}, "df": 2}, "o": {"docs": {}, "df": 0, "s": {"docs": {}, "df": 0, "s": {"docs": {"shimmer.modules.global_workspace.GlobalWorkspaceBase.loss_mod": {"tf": 1}, "shimmer.modules.domain.LossOutput.loss": {"tf": 1}, "shimmer.modules.domain.DomainModule.compute_loss": {"tf": 1}, "shimmer.modules.domain.DomainModule.compute_dcy_loss": {"tf": 1}, "shimmer.modules.domain.DomainModule.compute_cy_loss": {"tf": 1}, "shimmer.modules.domain.DomainModule.compute_tr_loss": {"tf": 1}, "shimmer.modules.domain.DomainModule.compute_broadcast_loss": {"tf": 1}, "shimmer.modules.losses.demi_cycle_loss": {"tf": 1}, "shimmer.modules.losses.cycle_loss": 
{"tf": 1}, "shimmer.modules.losses.translation_loss": {"tf": 1}, "shimmer.modules.losses.contrastive_loss": {"tf": 1}, "shimmer.modules.losses.contrastive_loss_bayesian": {"tf": 1}, "shimmer.modules.losses.GWLosses2Domains.loss_coefs": {"tf": 1}, "shimmer.modules.losses.GWLosses2Domains.demi_cycle_loss": {"tf": 1}, "shimmer.modules.losses.GWLosses2Domains.cycle_loss": {"tf": 1}, "shimmer.modules.losses.GWLosses2Domains.translation_loss": {"tf": 1}, "shimmer.modules.losses.GWLosses2Domains.contrastive_loss": {"tf": 1}, "shimmer.modules.losses.broadcast_loss": {"tf": 1}, "shimmer.modules.losses.GWLosses.loss_coefs": {"tf": 1}, "shimmer.modules.losses.GWLosses.contrastive_loss": {"tf": 1}, "shimmer.modules.losses.GWLosses.broadcast_loss": {"tf": 1}, "shimmer.modules.losses.GWLossesBayesian.loss_coefs": {"tf": 1}, "shimmer.modules.losses.GWLossesBayesian.contrastive_loss": {"tf": 1}, "shimmer.modules.losses.GWLossesBayesian.broadcast_loss": {"tf": 1}, "shimmer.modules.contrastive_loss": {"tf": 1}, "shimmer.modules.contrastive_loss.ContrastiveLossType": {"tf": 1}, "shimmer.modules.contrastive_loss.ContrastiveLossBayesianType": {"tf": 1}, "shimmer.modules.contrastive_loss.info_nce": {"tf": 1}, "shimmer.modules.contrastive_loss.contrastive_loss": {"tf": 1.4142135623730951}, "shimmer.modules.contrastive_loss.ContrastiveLoss": {"tf": 1}, "shimmer.modules.contrastive_loss.ContrastiveLoss.__init__": {"tf": 1}, "shimmer.modules.contrastive_loss.ContrastiveLoss.learn_logit_scale": {"tf": 1}, "shimmer.modules.contrastive_loss.ContrastiveLoss.reduction": {"tf": 1}, "shimmer.modules.contrastive_loss.ContrastiveLoss.forward": {"tf": 1}, "shimmer.modules.vae.kl_divergence_loss": {"tf": 1}}, "df": 35, "o": {"docs": {}, "df": 0, "u": {"docs": {}, "df": 0, "t": {"docs": {}, "df": 0, "p": {"docs": {}, "df": 0, "u": {"docs": {}, "df": 0, "t": {"docs": {"shimmer.modules.domain.LossOutput": {"tf": 1}, "shimmer.modules.domain.LossOutput.__init__": {"tf": 1}, 
"shimmer.modules.domain.LossOutput.loss": {"tf": 1}, "shimmer.modules.domain.LossOutput.metrics": {"tf": 1}, "shimmer.modules.domain.LossOutput.all": {"tf": 1}}, "df": 5}}}}}}, "e": {"docs": {}, "df": 0, "s": {"docs": {"shimmer.modules.losses": {"tf": 1}, "shimmer.modules.losses.GWLossesBase": {"tf": 1}, "shimmer.modules.losses.GWLossesBase.step": {"tf": 1}, "shimmer.modules.losses.demi_cycle_loss": {"tf": 1}, "shimmer.modules.losses.cycle_loss": {"tf": 1}, "shimmer.modules.losses.translation_loss": {"tf": 1}, "shimmer.modules.losses.contrastive_loss": {"tf": 1}, "shimmer.modules.losses.contrastive_loss_bayesian": {"tf": 1}, "shimmer.modules.losses.LossCoefs": {"tf": 1}, "shimmer.modules.losses.LossCoefs.demi_cycles": {"tf": 1}, "shimmer.modules.losses.LossCoefs.cycles": {"tf": 1}, "shimmer.modules.losses.LossCoefs.translations": {"tf": 1}, "shimmer.modules.losses.LossCoefs.contrastives": {"tf": 1}, "shimmer.modules.losses.GWLosses2Domains": {"tf": 1}, "shimmer.modules.losses.GWLosses2Domains.__init__": {"tf": 1}, "shimmer.modules.losses.GWLosses2Domains.gw_mod": {"tf": 1}, "shimmer.modules.losses.GWLosses2Domains.selection_mod": {"tf": 1}, "shimmer.modules.losses.GWLosses2Domains.domain_mods": {"tf": 1}, "shimmer.modules.losses.GWLosses2Domains.loss_coefs": {"tf": 1}, "shimmer.modules.losses.GWLosses2Domains.contrastive_fn": {"tf": 1}, "shimmer.modules.losses.GWLosses2Domains.demi_cycle_loss": {"tf": 1}, "shimmer.modules.losses.GWLosses2Domains.cycle_loss": {"tf": 1}, "shimmer.modules.losses.GWLosses2Domains.translation_loss": {"tf": 1}, "shimmer.modules.losses.GWLosses2Domains.contrastive_loss": {"tf": 1}, "shimmer.modules.losses.GWLosses2Domains.step": {"tf": 1}, "shimmer.modules.losses.generate_partitions": {"tf": 1}, "shimmer.modules.losses.broadcast_loss": {"tf": 1}, "shimmer.modules.losses.BroadcastLossCoefs": {"tf": 1}, "shimmer.modules.losses.BroadcastLossCoefs.contrastives": {"tf": 1}, "shimmer.modules.losses.BroadcastLossCoefs.fused": {"tf": 1}, 
"shimmer.modules.losses.BroadcastLossCoefs.demi_cycles": {"tf": 1}, "shimmer.modules.losses.BroadcastLossCoefs.cycles": {"tf": 1}, "shimmer.modules.losses.BroadcastLossCoefs.translations": {"tf": 1}, "shimmer.modules.losses.GWLosses": {"tf": 1}, "shimmer.modules.losses.GWLosses.__init__": {"tf": 1}, "shimmer.modules.losses.GWLosses.gw_mod": {"tf": 1}, "shimmer.modules.losses.GWLosses.selection_mod": {"tf": 1}, "shimmer.modules.losses.GWLosses.domain_mods": {"tf": 1}, "shimmer.modules.losses.GWLosses.loss_coefs": {"tf": 1}, "shimmer.modules.losses.GWLosses.contrastive_fn": {"tf": 1}, "shimmer.modules.losses.GWLosses.contrastive_loss": {"tf": 1}, "shimmer.modules.losses.GWLosses.broadcast_loss": {"tf": 1}, "shimmer.modules.losses.GWLosses.step": {"tf": 1}, "shimmer.modules.losses.GWLossesBayesian": {"tf": 1}, "shimmer.modules.losses.GWLossesBayesian.__init__": {"tf": 1}, "shimmer.modules.losses.GWLossesBayesian.gw_mod": {"tf": 1}, "shimmer.modules.losses.GWLossesBayesian.selection_mod": {"tf": 1}, "shimmer.modules.losses.GWLossesBayesian.domain_mods": {"tf": 1}, "shimmer.modules.losses.GWLossesBayesian.loss_coefs": {"tf": 1}, "shimmer.modules.losses.GWLossesBayesian.contrastive_fn": {"tf": 1}, "shimmer.modules.losses.GWLossesBayesian.use_normalized_constrastive": {"tf": 1}, "shimmer.modules.losses.GWLossesBayesian.contrastive_loss": {"tf": 1}, "shimmer.modules.losses.GWLossesBayesian.broadcast_loss": {"tf": 1}, "shimmer.modules.losses.GWLossesBayesian.step": {"tf": 1}}, "df": 54}}, "c": {"docs": {}, "df": 0, "o": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "f": {"docs": {}, "df": 0, "s": {"docs": {"shimmer.modules.losses.LossCoefs": {"tf": 1}, "shimmer.modules.losses.LossCoefs.demi_cycles": {"tf": 1}, "shimmer.modules.losses.LossCoefs.cycles": {"tf": 1}, "shimmer.modules.losses.LossCoefs.translations": {"tf": 1}, "shimmer.modules.losses.LossCoefs.contrastives": {"tf": 1}}, "df": 5}}}}}}}, "g": {"docs": {}, "df": 0, "i": {"docs": {}, "df": 0, "t": {"docs": 
{"shimmer.modules.contrastive_loss.ContrastiveLoss.learn_logit_scale": {"tf": 1}}, "df": 1}}}}, "e": {"docs": {}, "df": 0, "a": {"docs": {}, "df": 0, "r": {"docs": {}, "df": 0, "n": {"docs": {"shimmer.modules.contrastive_loss.ContrastiveLoss.learn_logit_scale": {"tf": 1}}, "df": 1}}}}}, "m": {"docs": {}, "df": 0, "o": {"docs": {}, "df": 0, "d": {"docs": {"shimmer.modules.global_workspace.GlobalWorkspaceBase.gw_mod": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.selection_mod": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.loss_mod": {"tf": 1}, "shimmer.modules.losses.GWLosses2Domains.gw_mod": {"tf": 1}, "shimmer.modules.losses.GWLosses2Domains.selection_mod": {"tf": 1}, "shimmer.modules.losses.GWLosses.gw_mod": {"tf": 1}, "shimmer.modules.losses.GWLosses.selection_mod": {"tf": 1}, "shimmer.modules.losses.GWLossesBayesian.gw_mod": {"tf": 1}, "shimmer.modules.losses.GWLossesBayesian.selection_mod": {"tf": 1}}, "df": 9, "e": {"docs": {}, "df": 0, "l": {"docs": {"shimmer.utils.migrate_model": {"tf": 1}}, "df": 1, "m": {"docs": {}, "df": 0, "o": {"docs": {}, "df": 0, "d": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "t": {"docs": {"shimmer.types.ModelModeT": {"tf": 1}}, "df": 1}}}}}}}, "u": {"docs": {}, "df": 0, "l": {"docs": {}, "df": 0, "e": {"docs": {"shimmer.modules.gw_module": {"tf": 1}, "shimmer.modules.gw_module.get_n_layers": {"tf": 1}, "shimmer.modules.gw_module.GWDecoder": {"tf": 1}, "shimmer.modules.gw_module.GWDecoder.__init__": {"tf": 1}, "shimmer.modules.gw_module.GWDecoder.in_dim": {"tf": 1}, "shimmer.modules.gw_module.GWDecoder.hidden_dim": {"tf": 1}, "shimmer.modules.gw_module.GWDecoder.out_dim": {"tf": 1}, "shimmer.modules.gw_module.GWDecoder.n_layers": {"tf": 1}, "shimmer.modules.gw_module.GWEncoder": {"tf": 1}, "shimmer.modules.gw_module.GWEncoder.__init__": {"tf": 1}, "shimmer.modules.gw_module.GWEncoder.forward": {"tf": 1}, "shimmer.modules.gw_module.GWEncoderLinear": {"tf": 1}, 
"shimmer.modules.gw_module.GWEncoderLinear.forward": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBase": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBase.__init__": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBase.domain_mods": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBase.workspace_dim": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBase.fuse": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBase.encode": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBase.encode_and_fuse": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBase.decode": {"tf": 1}, "shimmer.modules.gw_module.GWModule": {"tf": 1}, "shimmer.modules.gw_module.GWModule.__init__": {"tf": 1}, "shimmer.modules.gw_module.GWModule.gw_encoders": {"tf": 1}, "shimmer.modules.gw_module.GWModule.gw_decoders": {"tf": 1}, "shimmer.modules.gw_module.GWModule.fuse": {"tf": 1}, "shimmer.modules.gw_module.GWModule.encode": {"tf": 1}, "shimmer.modules.gw_module.GWModule.decode": {"tf": 1}, "shimmer.modules.gw_module.compute_fusion_scores": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBayesian": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBayesian.__init__": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBayesian.precisions": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBayesian.sensitivity_selection": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBayesian.sensitivity_precision": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBayesian.precision_softmax_temp": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBayesian.get_precision": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBayesian.fuse": {"tf": 1}}, "df": 37, "s": {"docs": {"shimmer.modules.global_workspace": {"tf": 1}, "shimmer.modules.global_workspace.SchedulerArgs": {"tf": 1}, "shimmer.modules.global_workspace.SchedulerArgs.max_lr": {"tf": 1}, "shimmer.modules.global_workspace.SchedulerArgs.total_steps": {"tf": 1}, "shimmer.modules.global_workspace.GWPredictionsBase": {"tf": 1}, "shimmer.modules.global_workspace.GWPredictionsBase.states": {"tf": 1}, 
"shimmer.modules.global_workspace.GlobalWorkspaceBase": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.gw_mod": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.selection_mod": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.loss_mod": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.optim_lr": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.optim_weight_decay": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.scheduler_args": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.domain_mods": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.workspace_dim": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.encode_and_fuse": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.encode": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.fuse": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.decode": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.batch_gw_states": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.encode_domain": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.encode_domains": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.decode_domain": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.decode_domains": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.generic_step": {"tf": 1}, "shimmer.modules.global_workspace.freeze_domain_modules": {"tf": 1.4142135623730951}, "shimmer.modules.global_workspace.GWPredictions": {"tf": 1}, "shimmer.modules.global_workspace.GWPredictions.demi_cycles": {"tf": 1}, "shimmer.modules.global_workspace.GWPredictions.cycles": {"tf": 1}, "shimmer.modules.global_workspace.GWPredictions.translations": {"tf": 1}, "shimmer.modules.global_workspace.GWPredictions.states": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspace2Domains": {"tf": 1}, 
"shimmer.modules.global_workspace.GlobalWorkspace2Domains.__init__": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspace2Domains.forward": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspace": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspace.__init__": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspace.forward": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBayesian": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBayesian.__init__": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBayesian.forward": {"tf": 1}, "shimmer.modules.global_workspace.pretrained_global_workspace": {"tf": 1}, "shimmer.modules.domain": {"tf": 1}, "shimmer.modules.domain.LossOutput": {"tf": 1}, "shimmer.modules.domain.LossOutput.__init__": {"tf": 1}, "shimmer.modules.domain.LossOutput.loss": {"tf": 1}, "shimmer.modules.domain.LossOutput.metrics": {"tf": 1}, "shimmer.modules.domain.LossOutput.all": {"tf": 1}, "shimmer.modules.domain.DomainModule": {"tf": 1}, "shimmer.modules.domain.DomainModule.__init__": {"tf": 1}, "shimmer.modules.domain.DomainModule.latent_dim": {"tf": 1}, "shimmer.modules.domain.DomainModule.encode": {"tf": 1}, "shimmer.modules.domain.DomainModule.decode": {"tf": 1}, "shimmer.modules.domain.DomainModule.compute_loss": {"tf": 1}, "shimmer.modules.domain.DomainModule.compute_dcy_loss": {"tf": 1}, "shimmer.modules.domain.DomainModule.compute_cy_loss": {"tf": 1}, "shimmer.modules.domain.DomainModule.compute_tr_loss": {"tf": 1}, "shimmer.modules.domain.DomainModule.compute_broadcast_loss": {"tf": 1}, "shimmer.modules.gw_module": {"tf": 1}, "shimmer.modules.gw_module.get_n_layers": {"tf": 1}, "shimmer.modules.gw_module.GWDecoder": {"tf": 1}, "shimmer.modules.gw_module.GWDecoder.__init__": {"tf": 1}, "shimmer.modules.gw_module.GWDecoder.in_dim": {"tf": 1}, "shimmer.modules.gw_module.GWDecoder.hidden_dim": {"tf": 1}, "shimmer.modules.gw_module.GWDecoder.out_dim": {"tf": 1}, 
"shimmer.modules.gw_module.GWDecoder.n_layers": {"tf": 1}, "shimmer.modules.gw_module.GWEncoder": {"tf": 1}, "shimmer.modules.gw_module.GWEncoder.__init__": {"tf": 1}, "shimmer.modules.gw_module.GWEncoder.forward": {"tf": 1}, "shimmer.modules.gw_module.GWEncoderLinear": {"tf": 1}, "shimmer.modules.gw_module.GWEncoderLinear.forward": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBase": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBase.__init__": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBase.domain_mods": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBase.workspace_dim": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBase.fuse": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBase.encode": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBase.encode_and_fuse": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBase.decode": {"tf": 1}, "shimmer.modules.gw_module.GWModule": {"tf": 1}, "shimmer.modules.gw_module.GWModule.__init__": {"tf": 1}, "shimmer.modules.gw_module.GWModule.gw_encoders": {"tf": 1}, "shimmer.modules.gw_module.GWModule.gw_decoders": {"tf": 1}, "shimmer.modules.gw_module.GWModule.fuse": {"tf": 1}, "shimmer.modules.gw_module.GWModule.encode": {"tf": 1}, "shimmer.modules.gw_module.GWModule.decode": {"tf": 1}, "shimmer.modules.gw_module.compute_fusion_scores": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBayesian": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBayesian.__init__": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBayesian.precisions": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBayesian.sensitivity_selection": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBayesian.sensitivity_precision": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBayesian.precision_softmax_temp": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBayesian.get_precision": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBayesian.fuse": {"tf": 1}, "shimmer.modules.selection": {"tf": 1}, "shimmer.modules.selection.SelectionBase": {"tf": 1}, 
"shimmer.modules.selection.SelectionBase.update_gw_state": {"tf": 1}, "shimmer.modules.selection.SelectionBase.forward": {"tf": 1}, "shimmer.modules.selection.SingleDomainSelection": {"tf": 1}, "shimmer.modules.selection.SingleDomainSelection.forward": {"tf": 1}, "shimmer.modules.selection.FixedSharedSelection": {"tf": 1}, "shimmer.modules.selection.FixedSharedSelection.forward": {"tf": 1}, "shimmer.modules.selection.KQFixedQSelection": {"tf": 1}, "shimmer.modules.selection.KQFixedQSelection.__init__": {"tf": 1}, "shimmer.modules.selection.KQFixedQSelection.head_size": {"tf": 1}, "shimmer.modules.selection.KQFixedQSelection.query_layer": {"tf": 1}, "shimmer.modules.selection.KQFixedQSelection.key_layers": {"tf": 1}, "shimmer.modules.selection.KQFixedQSelection.forward": {"tf": 1}, "shimmer.modules.selection.RandomSelection": {"tf": 1}, "shimmer.modules.selection.RandomSelection.__init__": {"tf": 1}, "shimmer.modules.selection.RandomSelection.temperature": {"tf": 1}, "shimmer.modules.selection.RandomSelection.forward": {"tf": 1}, "shimmer.modules.selection.DynamicQueryAttention": {"tf": 1}, "shimmer.modules.selection.DynamicQueryAttention.__init__": {"tf": 1}, "shimmer.modules.selection.DynamicQueryAttention.head_size": {"tf": 1}, "shimmer.modules.selection.DynamicQueryAttention.query_layer": {"tf": 1}, "shimmer.modules.selection.DynamicQueryAttention.key_layers": {"tf": 1}, "shimmer.modules.selection.DynamicQueryAttention.fuse_weighted_encodings": {"tf": 1}, "shimmer.modules.selection.DynamicQueryAttention.forward": {"tf": 1}, "shimmer.modules.losses": {"tf": 1}, "shimmer.modules.losses.GWLossesBase": {"tf": 1}, "shimmer.modules.losses.GWLossesBase.step": {"tf": 1}, "shimmer.modules.losses.demi_cycle_loss": {"tf": 1}, "shimmer.modules.losses.cycle_loss": {"tf": 1}, "shimmer.modules.losses.translation_loss": {"tf": 1}, "shimmer.modules.losses.contrastive_loss": {"tf": 1}, "shimmer.modules.losses.contrastive_loss_bayesian": {"tf": 1}, 
"shimmer.modules.losses.LossCoefs": {"tf": 1}, "shimmer.modules.losses.LossCoefs.demi_cycles": {"tf": 1}, "shimmer.modules.losses.LossCoefs.cycles": {"tf": 1}, "shimmer.modules.losses.LossCoefs.translations": {"tf": 1}, "shimmer.modules.losses.LossCoefs.contrastives": {"tf": 1}, "shimmer.modules.losses.GWLosses2Domains": {"tf": 1}, "shimmer.modules.losses.GWLosses2Domains.__init__": {"tf": 1}, "shimmer.modules.losses.GWLosses2Domains.gw_mod": {"tf": 1}, "shimmer.modules.losses.GWLosses2Domains.selection_mod": {"tf": 1}, "shimmer.modules.losses.GWLosses2Domains.domain_mods": {"tf": 1}, "shimmer.modules.losses.GWLosses2Domains.loss_coefs": {"tf": 1}, "shimmer.modules.losses.GWLosses2Domains.contrastive_fn": {"tf": 1}, "shimmer.modules.losses.GWLosses2Domains.demi_cycle_loss": {"tf": 1}, "shimmer.modules.losses.GWLosses2Domains.cycle_loss": {"tf": 1}, "shimmer.modules.losses.GWLosses2Domains.translation_loss": {"tf": 1}, "shimmer.modules.losses.GWLosses2Domains.contrastive_loss": {"tf": 1}, "shimmer.modules.losses.GWLosses2Domains.step": {"tf": 1}, "shimmer.modules.losses.generate_partitions": {"tf": 1}, "shimmer.modules.losses.broadcast_loss": {"tf": 1}, "shimmer.modules.losses.BroadcastLossCoefs": {"tf": 1}, "shimmer.modules.losses.BroadcastLossCoefs.contrastives": {"tf": 1}, "shimmer.modules.losses.BroadcastLossCoefs.fused": {"tf": 1}, "shimmer.modules.losses.BroadcastLossCoefs.demi_cycles": {"tf": 1}, "shimmer.modules.losses.BroadcastLossCoefs.cycles": {"tf": 1}, "shimmer.modules.losses.BroadcastLossCoefs.translations": {"tf": 1}, "shimmer.modules.losses.GWLosses": {"tf": 1}, "shimmer.modules.losses.GWLosses.__init__": {"tf": 1}, "shimmer.modules.losses.GWLosses.gw_mod": {"tf": 1}, "shimmer.modules.losses.GWLosses.selection_mod": {"tf": 1}, "shimmer.modules.losses.GWLosses.domain_mods": {"tf": 1}, "shimmer.modules.losses.GWLosses.loss_coefs": {"tf": 1}, "shimmer.modules.losses.GWLosses.contrastive_fn": {"tf": 1}, "shimmer.modules.losses.GWLosses.contrastive_loss": 
{"tf": 1}, "shimmer.modules.losses.GWLosses.broadcast_loss": {"tf": 1}, "shimmer.modules.losses.GWLosses.step": {"tf": 1}, "shimmer.modules.losses.GWLossesBayesian": {"tf": 1}, "shimmer.modules.losses.GWLossesBayesian.__init__": {"tf": 1}, "shimmer.modules.losses.GWLossesBayesian.gw_mod": {"tf": 1}, "shimmer.modules.losses.GWLossesBayesian.selection_mod": {"tf": 1}, "shimmer.modules.losses.GWLossesBayesian.domain_mods": {"tf": 1}, "shimmer.modules.losses.GWLossesBayesian.loss_coefs": {"tf": 1}, "shimmer.modules.losses.GWLossesBayesian.contrastive_fn": {"tf": 1}, "shimmer.modules.losses.GWLossesBayesian.use_normalized_constrastive": {"tf": 1}, "shimmer.modules.losses.GWLossesBayesian.contrastive_loss": {"tf": 1}, "shimmer.modules.losses.GWLossesBayesian.broadcast_loss": {"tf": 1}, "shimmer.modules.losses.GWLossesBayesian.step": {"tf": 1}, "shimmer.modules.contrastive_loss": {"tf": 1}, "shimmer.modules.contrastive_loss.ContrastiveLossType": {"tf": 1}, "shimmer.modules.contrastive_loss.ContrastiveLossBayesianType": {"tf": 1}, "shimmer.modules.contrastive_loss.info_nce": {"tf": 1}, "shimmer.modules.contrastive_loss.contrastive_loss": {"tf": 1}, "shimmer.modules.contrastive_loss.ContrastiveLoss": {"tf": 1}, "shimmer.modules.contrastive_loss.ContrastiveLoss.__init__": {"tf": 1}, "shimmer.modules.contrastive_loss.ContrastiveLoss.learn_logit_scale": {"tf": 1}, "shimmer.modules.contrastive_loss.ContrastiveLoss.reduction": {"tf": 1}, "shimmer.modules.contrastive_loss.ContrastiveLoss.forward": {"tf": 1}, "shimmer.modules.vae": {"tf": 1}, "shimmer.modules.vae.reparameterize": {"tf": 1}, "shimmer.modules.vae.kl_divergence_loss": {"tf": 1}, "shimmer.modules.vae.gaussian_nll": {"tf": 1}, "shimmer.modules.vae.VAEEncoder": {"tf": 1}, "shimmer.modules.vae.VAEEncoder.forward": {"tf": 1}, "shimmer.modules.vae.VAEDecoder": {"tf": 1}, "shimmer.modules.vae.VAEDecoder.forward": {"tf": 1}, "shimmer.modules.vae.VAE": {"tf": 1}, "shimmer.modules.vae.VAE.__init__": {"tf": 1}, 
"shimmer.modules.vae.VAE.beta": {"tf": 1}, "shimmer.modules.vae.VAE.encoder": {"tf": 1}, "shimmer.modules.vae.VAE.decoder": {"tf": 1}, "shimmer.modules.vae.VAE.encode": {"tf": 1}, "shimmer.modules.vae.VAE.decode": {"tf": 1}, "shimmer.modules.vae.VAE.forward": {"tf": 1}, "shimmer.modules.utils": {"tf": 1}, "shimmer.modules.utils.translation": {"tf": 1}, "shimmer.modules.utils.cycle": {"tf": 1}, "shimmer.modules.utils.batch_demi_cycles": {"tf": 1}, "shimmer.modules.utils.batch_cycles": {"tf": 1}, "shimmer.modules.utils.batch_translations": {"tf": 1}}, "df": 205}}}}, "s": {"docs": {"shimmer.modules.global_workspace.GlobalWorkspaceBase.domain_mods": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBase.domain_mods": {"tf": 1}, "shimmer.modules.losses.GWLosses2Domains.domain_mods": {"tf": 1}, "shimmer.modules.losses.GWLosses.domain_mods": {"tf": 1}, "shimmer.modules.losses.GWLossesBayesian.domain_mods": {"tf": 1}}, "df": 5}}}, "a": {"docs": {}, "df": 0, "x": {"docs": {"shimmer.modules.global_workspace.SchedulerArgs.max_lr": {"tf": 1}}, "df": 1}}, "e": {"docs": {}, "df": 0, "t": {"docs": {}, "df": 0, "r": {"docs": {}, "df": 0, "i": {"docs": {}, "df": 0, "c": {"docs": {}, "df": 0, "s": {"docs": {"shimmer.modules.domain.LossOutput.metrics": {"tf": 1}}, "df": 1}}}}}}, "i": {"docs": {}, "df": 0, "g": {"docs": {}, "df": 0, "r": {"docs": {}, "df": 0, "a": {"docs": {}, "df": 0, "t": {"docs": {}, "df": 0, "i": {"docs": {}, "df": 0, "o": {"docs": {}, "df": 0, "n": {"docs": {"shimmer.utils.MIGRATION_DIR": {"tf": 1}, "shimmer.cli.ckpt_migration": {"tf": 1}, "shimmer.cli.ckpt_migration.migrate_ckpt": {"tf": 1}}, "df": 3, "s": {"docs": {"shimmer.utils.SaveMigrations.migrations": {"tf": 1}}, "df": 1}}}}, "e": {"docs": {"shimmer.utils.migrate_model": {"tf": 1}, "shimmer.cli.ckpt_migration.migrate_ckpt": {"tf": 1}}, "df": 2}}}}}}}, "g": {"docs": {}, "df": 0, "l": {"docs": {}, "df": 0, "o": {"docs": {}, "df": 0, "b": {"docs": {}, "df": 0, "a": {"docs": {}, "df": 0, "l": {"docs": 
{"shimmer.modules.global_workspace": {"tf": 1}, "shimmer.modules.global_workspace.SchedulerArgs": {"tf": 1}, "shimmer.modules.global_workspace.SchedulerArgs.max_lr": {"tf": 1}, "shimmer.modules.global_workspace.SchedulerArgs.total_steps": {"tf": 1}, "shimmer.modules.global_workspace.GWPredictionsBase": {"tf": 1}, "shimmer.modules.global_workspace.GWPredictionsBase.states": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBase": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.gw_mod": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.selection_mod": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.loss_mod": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.optim_lr": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.optim_weight_decay": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.scheduler_args": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.domain_mods": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.workspace_dim": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.encode_and_fuse": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.encode": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.fuse": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.decode": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.batch_gw_states": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.encode_domain": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.encode_domains": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.decode_domain": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.decode_domains": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.generic_step": {"tf": 1}, "shimmer.modules.global_workspace.freeze_domain_modules": {"tf": 1}, "shimmer.modules.global_workspace.GWPredictions": 
{"tf": 1}, "shimmer.modules.global_workspace.GWPredictions.demi_cycles": {"tf": 1}, "shimmer.modules.global_workspace.GWPredictions.cycles": {"tf": 1}, "shimmer.modules.global_workspace.GWPredictions.translations": {"tf": 1}, "shimmer.modules.global_workspace.GWPredictions.states": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspace2Domains": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspace2Domains.__init__": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspace2Domains.forward": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspace": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspace.__init__": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspace.forward": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBayesian": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBayesian.__init__": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBayesian.forward": {"tf": 1}, "shimmer.modules.global_workspace.pretrained_global_workspace": {"tf": 1.4142135623730951}}, "df": 41, "w": {"docs": {}, "df": 0, "o": {"docs": {}, "df": 0, "r": {"docs": {}, "df": 0, "k": {"docs": {}, "df": 0, "s": {"docs": {}, "df": 0, "p": {"docs": {}, "df": 0, "a": {"docs": {}, "df": 0, "c": {"docs": {}, "df": 0, "e": {"2": {"docs": {}, "df": 0, "d": {"docs": {}, "df": 0, "o": {"docs": {}, "df": 0, "m": {"docs": {}, "df": 0, "a": {"docs": {}, "df": 0, "i": {"docs": {}, "df": 0, "n": {"docs": {}, "df": 0, "s": {"docs": {"shimmer.modules.global_workspace.GlobalWorkspace2Domains": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspace2Domains.__init__": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspace2Domains.forward": {"tf": 1}}, "df": 3}}}}}}}}, "docs": {"shimmer.modules.global_workspace.GlobalWorkspace": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspace.__init__": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspace.forward": {"tf": 1}}, "df": 3, "b": {"docs": {}, "df": 0, "a": 
{"docs": {}, "df": 0, "s": {"docs": {}, "df": 0, "e": {"docs": {"shimmer.modules.global_workspace.GlobalWorkspaceBase": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.gw_mod": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.selection_mod": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.loss_mod": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.optim_lr": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.optim_weight_decay": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.scheduler_args": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.domain_mods": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.workspace_dim": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.encode_and_fuse": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.encode": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.fuse": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.decode": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.batch_gw_states": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.encode_domain": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.encode_domains": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.decode_domain": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.decode_domains": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.generic_step": {"tf": 1}}, "df": 19}}, "y": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "s": {"docs": {}, "df": 0, "i": {"docs": {}, "df": 0, "a": {"docs": {}, "df": 0, "n": {"docs": {"shimmer.modules.global_workspace.GlobalWorkspaceBayesian": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBayesian.__init__": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBayesian.forward": {"tf": 1}}, "df": 3}}}}}}}}}}}}}}}}}}}}}}, "w": {"docs": 
{"shimmer.modules.global_workspace.GlobalWorkspaceBase.gw_mod": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.batch_gw_states": {"tf": 1}, "shimmer.modules.gw_module": {"tf": 1}, "shimmer.modules.gw_module.get_n_layers": {"tf": 1}, "shimmer.modules.gw_module.GWDecoder": {"tf": 1}, "shimmer.modules.gw_module.GWDecoder.__init__": {"tf": 1}, "shimmer.modules.gw_module.GWDecoder.in_dim": {"tf": 1}, "shimmer.modules.gw_module.GWDecoder.hidden_dim": {"tf": 1}, "shimmer.modules.gw_module.GWDecoder.out_dim": {"tf": 1}, "shimmer.modules.gw_module.GWDecoder.n_layers": {"tf": 1}, "shimmer.modules.gw_module.GWEncoder": {"tf": 1}, "shimmer.modules.gw_module.GWEncoder.__init__": {"tf": 1}, "shimmer.modules.gw_module.GWEncoder.forward": {"tf": 1}, "shimmer.modules.gw_module.GWEncoderLinear": {"tf": 1}, "shimmer.modules.gw_module.GWEncoderLinear.forward": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBase": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBase.__init__": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBase.domain_mods": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBase.workspace_dim": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBase.fuse": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBase.encode": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBase.encode_and_fuse": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBase.decode": {"tf": 1}, "shimmer.modules.gw_module.GWModule": {"tf": 1}, "shimmer.modules.gw_module.GWModule.__init__": {"tf": 1}, "shimmer.modules.gw_module.GWModule.gw_encoders": {"tf": 1.4142135623730951}, "shimmer.modules.gw_module.GWModule.gw_decoders": {"tf": 1.4142135623730951}, "shimmer.modules.gw_module.GWModule.fuse": {"tf": 1}, "shimmer.modules.gw_module.GWModule.encode": {"tf": 1}, "shimmer.modules.gw_module.GWModule.decode": {"tf": 1}, "shimmer.modules.gw_module.compute_fusion_scores": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBayesian": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBayesian.__init__": {"tf": 1}, 
"shimmer.modules.gw_module.GWModuleBayesian.precisions": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBayesian.sensitivity_selection": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBayesian.sensitivity_precision": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBayesian.precision_softmax_temp": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBayesian.get_precision": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBayesian.fuse": {"tf": 1}, "shimmer.modules.selection.SelectionBase.update_gw_state": {"tf": 1}, "shimmer.modules.losses.GWLosses2Domains.gw_mod": {"tf": 1}, "shimmer.modules.losses.GWLosses.gw_mod": {"tf": 1}, "shimmer.modules.losses.GWLossesBayesian.gw_mod": {"tf": 1}}, "df": 43, "p": {"docs": {}, "df": 0, "r": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "d": {"docs": {}, "df": 0, "i": {"docs": {}, "df": 0, "c": {"docs": {}, "df": 0, "t": {"docs": {}, "df": 0, "i": {"docs": {}, "df": 0, "o": {"docs": {}, "df": 0, "n": {"docs": {}, "df": 0, "s": {"docs": {"shimmer.modules.global_workspace.GWPredictions": {"tf": 1}, "shimmer.modules.global_workspace.GWPredictions.demi_cycles": {"tf": 1}, "shimmer.modules.global_workspace.GWPredictions.cycles": {"tf": 1}, "shimmer.modules.global_workspace.GWPredictions.translations": {"tf": 1}, "shimmer.modules.global_workspace.GWPredictions.states": {"tf": 1}}, "df": 5, "b": {"docs": {}, "df": 0, "a": {"docs": {}, "df": 0, "s": {"docs": {}, "df": 0, "e": {"docs": {"shimmer.modules.global_workspace.GWPredictionsBase": {"tf": 1}, "shimmer.modules.global_workspace.GWPredictionsBase.states": {"tf": 1}}, "df": 2}}}}}}}}}}}}}}}, "d": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "c": {"docs": {}, "df": 0, "o": {"docs": {}, "df": 0, "d": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "r": {"docs": {"shimmer.modules.gw_module.GWDecoder": {"tf": 1}, "shimmer.modules.gw_module.GWDecoder.__init__": {"tf": 1}, "shimmer.modules.gw_module.GWDecoder.in_dim": {"tf": 1}, "shimmer.modules.gw_module.GWDecoder.hidden_dim": {"tf": 1}, 
"shimmer.modules.gw_module.GWDecoder.out_dim": {"tf": 1}, "shimmer.modules.gw_module.GWDecoder.n_layers": {"tf": 1}}, "df": 6}}}}}}}, "e": {"docs": {}, "df": 0, "n": {"docs": {}, "df": 0, "c": {"docs": {}, "df": 0, "o": {"docs": {}, "df": 0, "d": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "r": {"docs": {"shimmer.modules.gw_module.GWEncoder": {"tf": 1}, "shimmer.modules.gw_module.GWEncoder.__init__": {"tf": 1}, "shimmer.modules.gw_module.GWEncoder.forward": {"tf": 1}}, "df": 3, "l": {"docs": {}, "df": 0, "i": {"docs": {}, "df": 0, "n": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "a": {"docs": {}, "df": 0, "r": {"docs": {"shimmer.modules.gw_module.GWEncoderLinear": {"tf": 1}, "shimmer.modules.gw_module.GWEncoderLinear.forward": {"tf": 1}}, "df": 2}}}}}}}}}}}}}, "m": {"docs": {}, "df": 0, "o": {"docs": {}, "df": 0, "d": {"docs": {}, "df": 0, "u": {"docs": {}, "df": 0, "l": {"docs": {}, "df": 0, "e": {"docs": {"shimmer.modules.gw_module.GWModule": {"tf": 1}, "shimmer.modules.gw_module.GWModule.__init__": {"tf": 1}, "shimmer.modules.gw_module.GWModule.gw_encoders": {"tf": 1}, "shimmer.modules.gw_module.GWModule.gw_decoders": {"tf": 1}, "shimmer.modules.gw_module.GWModule.fuse": {"tf": 1}, "shimmer.modules.gw_module.GWModule.encode": {"tf": 1}, "shimmer.modules.gw_module.GWModule.decode": {"tf": 1}}, "df": 7, "b": {"docs": {}, "df": 0, "a": {"docs": {}, "df": 0, "s": {"docs": {}, "df": 0, "e": {"docs": {"shimmer.modules.gw_module.GWModuleBase": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBase.__init__": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBase.domain_mods": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBase.workspace_dim": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBase.fuse": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBase.encode": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBase.encode_and_fuse": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBase.decode": {"tf": 1}}, "df": 8}}, "y": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "s": 
{"docs": {}, "df": 0, "i": {"docs": {}, "df": 0, "a": {"docs": {}, "df": 0, "n": {"docs": {"shimmer.modules.gw_module.GWModuleBayesian": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBayesian.__init__": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBayesian.precisions": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBayesian.sensitivity_selection": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBayesian.sensitivity_precision": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBayesian.precision_softmax_temp": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBayesian.get_precision": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBayesian.fuse": {"tf": 1}}, "df": 8}}}}}}}}}}}}}}, "l": {"docs": {}, "df": 0, "o": {"docs": {}, "df": 0, "s": {"docs": {}, "df": 0, "s": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "s": {"2": {"docs": {}, "df": 0, "d": {"docs": {}, "df": 0, "o": {"docs": {}, "df": 0, "m": {"docs": {}, "df": 0, "a": {"docs": {}, "df": 0, "i": {"docs": {}, "df": 0, "n": {"docs": {}, "df": 0, "s": {"docs": {"shimmer.modules.losses.GWLosses2Domains": {"tf": 1}, "shimmer.modules.losses.GWLosses2Domains.__init__": {"tf": 1}, "shimmer.modules.losses.GWLosses2Domains.gw_mod": {"tf": 1}, "shimmer.modules.losses.GWLosses2Domains.selection_mod": {"tf": 1}, "shimmer.modules.losses.GWLosses2Domains.domain_mods": {"tf": 1}, "shimmer.modules.losses.GWLosses2Domains.loss_coefs": {"tf": 1}, "shimmer.modules.losses.GWLosses2Domains.contrastive_fn": {"tf": 1}, "shimmer.modules.losses.GWLosses2Domains.demi_cycle_loss": {"tf": 1}, "shimmer.modules.losses.GWLosses2Domains.cycle_loss": {"tf": 1}, "shimmer.modules.losses.GWLosses2Domains.translation_loss": {"tf": 1}, "shimmer.modules.losses.GWLosses2Domains.contrastive_loss": {"tf": 1}, "shimmer.modules.losses.GWLosses2Domains.step": {"tf": 1}}, "df": 12}}}}}}}}, "docs": {"shimmer.modules.losses.GWLosses": {"tf": 1}, "shimmer.modules.losses.GWLosses.__init__": {"tf": 1}, "shimmer.modules.losses.GWLosses.gw_mod": {"tf": 1}, 
"shimmer.modules.losses.GWLosses.selection_mod": {"tf": 1}, "shimmer.modules.losses.GWLosses.domain_mods": {"tf": 1}, "shimmer.modules.losses.GWLosses.loss_coefs": {"tf": 1}, "shimmer.modules.losses.GWLosses.contrastive_fn": {"tf": 1}, "shimmer.modules.losses.GWLosses.contrastive_loss": {"tf": 1}, "shimmer.modules.losses.GWLosses.broadcast_loss": {"tf": 1}, "shimmer.modules.losses.GWLosses.step": {"tf": 1}}, "df": 10, "b": {"docs": {}, "df": 0, "a": {"docs": {}, "df": 0, "s": {"docs": {}, "df": 0, "e": {"docs": {"shimmer.modules.losses.GWLossesBase": {"tf": 1}, "shimmer.modules.losses.GWLossesBase.step": {"tf": 1}}, "df": 2}}, "y": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "s": {"docs": {}, "df": 0, "i": {"docs": {}, "df": 0, "a": {"docs": {}, "df": 0, "n": {"docs": {"shimmer.modules.losses.GWLossesBayesian": {"tf": 1}, "shimmer.modules.losses.GWLossesBayesian.__init__": {"tf": 1}, "shimmer.modules.losses.GWLossesBayesian.gw_mod": {"tf": 1}, "shimmer.modules.losses.GWLossesBayesian.selection_mod": {"tf": 1}, "shimmer.modules.losses.GWLossesBayesian.domain_mods": {"tf": 1}, "shimmer.modules.losses.GWLossesBayesian.loss_coefs": {"tf": 1}, "shimmer.modules.losses.GWLossesBayesian.contrastive_fn": {"tf": 1}, "shimmer.modules.losses.GWLossesBayesian.use_normalized_constrastive": {"tf": 1}, "shimmer.modules.losses.GWLossesBayesian.contrastive_loss": {"tf": 1}, "shimmer.modules.losses.GWLossesBayesian.broadcast_loss": {"tf": 1}, "shimmer.modules.losses.GWLossesBayesian.step": {"tf": 1}}, "df": 11}}}}}}}}}}}}}}}, "e": {"docs": {}, "df": 0, "n": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "r": {"docs": {}, "df": 0, "i": {"docs": {}, "df": 0, "c": {"docs": {"shimmer.modules.global_workspace.GlobalWorkspaceBase.generic_step": {"tf": 1}}, "df": 1}}, "a": {"docs": {}, "df": 0, "t": {"docs": {}, "df": 0, "e": {"docs": {"shimmer.modules.losses.generate_partitions": {"tf": 1}}, "df": 1}}}}}}, "t": {"docs": {"shimmer.modules.gw_module.get_n_layers": {"tf": 1}, 
"shimmer.modules.gw_module.GWModuleBayesian.get_precision": {"tf": 1}}, "df": 2}}, "a": {"docs": {}, "df": 0, "u": {"docs": {}, "df": 0, "s": {"docs": {}, "df": 0, "s": {"docs": {}, "df": 0, "i": {"docs": {}, "df": 0, "a": {"docs": {}, "df": 0, "n": {"docs": {"shimmer.modules.vae.gaussian_nll": {"tf": 1}}, "df": 1}}}}}}}, "r": {"docs": {}, "df": 0, "o": {"docs": {}, "df": 0, "u": {"docs": {}, "df": 0, "p": {"docs": {"shimmer.utils.group_batch_size": {"tf": 1}, "shimmer.utils.group_device": {"tf": 1}}, "df": 2, "s": {"docs": {"shimmer.utils.groups_batch_size": {"tf": 1}, "shimmer.utils.groups_device": {"tf": 1}}, "df": 2}}}}}}, "w": {"docs": {}, "df": 0, "o": {"docs": {}, "df": 0, "r": {"docs": {}, "df": 0, "k": {"docs": {}, "df": 0, "s": {"docs": {}, "df": 0, "p": {"docs": {}, "df": 0, "a": {"docs": {}, "df": 0, "c": {"docs": {}, "df": 0, "e": {"docs": {"shimmer.modules.global_workspace": {"tf": 1}, "shimmer.modules.global_workspace.SchedulerArgs": {"tf": 1}, "shimmer.modules.global_workspace.SchedulerArgs.max_lr": {"tf": 1}, "shimmer.modules.global_workspace.SchedulerArgs.total_steps": {"tf": 1}, "shimmer.modules.global_workspace.GWPredictionsBase": {"tf": 1}, "shimmer.modules.global_workspace.GWPredictionsBase.states": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBase": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.gw_mod": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.selection_mod": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.loss_mod": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.optim_lr": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.optim_weight_decay": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.scheduler_args": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.domain_mods": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.workspace_dim": {"tf": 1.4142135623730951}, 
"shimmer.modules.global_workspace.GlobalWorkspaceBase.encode_and_fuse": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.encode": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.fuse": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.decode": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.batch_gw_states": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.encode_domain": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.encode_domains": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.decode_domain": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.decode_domains": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.generic_step": {"tf": 1}, "shimmer.modules.global_workspace.freeze_domain_modules": {"tf": 1}, "shimmer.modules.global_workspace.GWPredictions": {"tf": 1}, "shimmer.modules.global_workspace.GWPredictions.demi_cycles": {"tf": 1}, "shimmer.modules.global_workspace.GWPredictions.cycles": {"tf": 1}, "shimmer.modules.global_workspace.GWPredictions.translations": {"tf": 1}, "shimmer.modules.global_workspace.GWPredictions.states": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspace2Domains": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspace2Domains.__init__": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspace2Domains.forward": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspace": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspace.__init__": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspace.forward": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBayesian": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBayesian.__init__": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBayesian.forward": {"tf": 1}, "shimmer.modules.global_workspace.pretrained_global_workspace": {"tf": 1.4142135623730951}, 
"shimmer.modules.gw_module.GWModuleBase.workspace_dim": {"tf": 1}}, "df": 42}}}}}}}}, "e": {"docs": {}, "df": 0, "i": {"docs": {}, "df": 0, "g": {"docs": {}, "df": 0, "h": {"docs": {}, "df": 0, "t": {"docs": {"shimmer.modules.global_workspace.GlobalWorkspaceBase.optim_weight_decay": {"tf": 1}}, "df": 1, "e": {"docs": {}, "df": 0, "d": {"docs": {"shimmer.modules.selection.DynamicQueryAttention.fuse_weighted_encodings": {"tf": 1}}, "df": 1}}}}}}}}, "o": {"docs": {}, "df": 0, "p": {"docs": {}, "df": 0, "t": {"docs": {}, "df": 0, "i": {"docs": {}, "df": 0, "m": {"docs": {"shimmer.modules.global_workspace.GlobalWorkspaceBase.optim_lr": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.optim_weight_decay": {"tf": 1}}, "df": 2}}}}, "u": {"docs": {}, "df": 0, "t": {"docs": {"shimmer.modules.gw_module.GWDecoder.out_dim": {"tf": 1}}, "df": 1}}, "n": {"docs": {"shimmer.utils.SaveMigrations.on_save_checkpoint": {"tf": 1}}, "df": 1}}, "d": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "c": {"docs": {}, "df": 0, "a": {"docs": {}, "df": 0, "y": {"docs": {"shimmer.modules.global_workspace.GlobalWorkspaceBase.optim_weight_decay": {"tf": 1}}, "df": 1}}, "o": {"docs": {}, "df": 0, "d": {"docs": {}, "df": 0, "e": {"docs": {"shimmer.modules.global_workspace.GlobalWorkspaceBase.decode": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.decode_domain": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.decode_domains": {"tf": 1}, "shimmer.modules.domain.DomainModule.decode": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBase.decode": {"tf": 1}, "shimmer.modules.gw_module.GWModule.decode": {"tf": 1}, "shimmer.modules.vae.VAE.decode": {"tf": 1}}, "df": 7, "r": {"docs": {"shimmer.modules.vae.VAE.decoder": {"tf": 1}}, "df": 1, "s": {"docs": {"shimmer.modules.gw_module.GWModule.gw_decoders": {"tf": 1}}, "df": 1}}}}}}, "m": {"docs": {}, "df": 0, "i": {"docs": {"shimmer.modules.global_workspace.GWPredictions.demi_cycles": {"tf": 1}, 
"shimmer.modules.losses.demi_cycle_loss": {"tf": 1}, "shimmer.modules.losses.LossCoefs.demi_cycles": {"tf": 1}, "shimmer.modules.losses.GWLosses2Domains.demi_cycle_loss": {"tf": 1}, "shimmer.modules.losses.BroadcastLossCoefs.demi_cycles": {"tf": 1}, "shimmer.modules.utils.batch_demi_cycles": {"tf": 1}}, "df": 6}}, "v": {"docs": {}, "df": 0, "i": {"docs": {}, "df": 0, "c": {"docs": {}, "df": 0, "e": {"docs": {"shimmer.utils.groups_device": {"tf": 1}, "shimmer.utils.group_device": {"tf": 1}}, "df": 2}}}}}, "o": {"docs": {}, "df": 0, "m": {"docs": {}, "df": 0, "a": {"docs": {}, "df": 0, "i": {"docs": {}, "df": 0, "n": {"docs": {"shimmer.modules.global_workspace.GlobalWorkspaceBase.domain_mods": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.encode_domain": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.decode_domain": {"tf": 1}, "shimmer.modules.global_workspace.freeze_domain_modules": {"tf": 1}, "shimmer.modules.domain": {"tf": 1}, "shimmer.modules.domain.LossOutput": {"tf": 1}, "shimmer.modules.domain.LossOutput.__init__": {"tf": 1}, "shimmer.modules.domain.LossOutput.loss": {"tf": 1}, "shimmer.modules.domain.LossOutput.metrics": {"tf": 1}, "shimmer.modules.domain.LossOutput.all": {"tf": 1}, "shimmer.modules.domain.DomainModule": {"tf": 1}, "shimmer.modules.domain.DomainModule.__init__": {"tf": 1}, "shimmer.modules.domain.DomainModule.latent_dim": {"tf": 1}, "shimmer.modules.domain.DomainModule.encode": {"tf": 1}, "shimmer.modules.domain.DomainModule.decode": {"tf": 1}, "shimmer.modules.domain.DomainModule.compute_loss": {"tf": 1}, "shimmer.modules.domain.DomainModule.compute_dcy_loss": {"tf": 1}, "shimmer.modules.domain.DomainModule.compute_cy_loss": {"tf": 1}, "shimmer.modules.domain.DomainModule.compute_tr_loss": {"tf": 1}, "shimmer.modules.domain.DomainModule.compute_broadcast_loss": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBase.domain_mods": {"tf": 1}, "shimmer.modules.losses.GWLosses2Domains.domain_mods": {"tf": 1}, 
"shimmer.modules.losses.GWLosses.domain_mods": {"tf": 1}, "shimmer.modules.losses.GWLossesBayesian.domain_mods": {"tf": 1}}, "df": 24, "s": {"docs": {"shimmer.modules.global_workspace.GlobalWorkspaceBase.encode_domains": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.decode_domains": {"tf": 1}}, "df": 2}, "m": {"docs": {}, "df": 0, "o": {"docs": {}, "df": 0, "d": {"docs": {}, "df": 0, "u": {"docs": {}, "df": 0, "l": {"docs": {}, "df": 0, "e": {"docs": {"shimmer.modules.domain.DomainModule": {"tf": 1}, "shimmer.modules.domain.DomainModule.__init__": {"tf": 1}, "shimmer.modules.domain.DomainModule.latent_dim": {"tf": 1}, "shimmer.modules.domain.DomainModule.encode": {"tf": 1}, "shimmer.modules.domain.DomainModule.decode": {"tf": 1}, "shimmer.modules.domain.DomainModule.compute_loss": {"tf": 1}, "shimmer.modules.domain.DomainModule.compute_dcy_loss": {"tf": 1}, "shimmer.modules.domain.DomainModule.compute_cy_loss": {"tf": 1}, "shimmer.modules.domain.DomainModule.compute_tr_loss": {"tf": 1}, "shimmer.modules.domain.DomainModule.compute_broadcast_loss": {"tf": 1}}, "df": 10}}}}}}}}}}}, "i": {"docs": {}, "df": 0, "m": {"docs": {"shimmer.modules.global_workspace.GlobalWorkspaceBase.workspace_dim": {"tf": 1}, "shimmer.modules.domain.DomainModule.latent_dim": {"tf": 1}, "shimmer.modules.gw_module.GWDecoder.in_dim": {"tf": 1}, "shimmer.modules.gw_module.GWDecoder.hidden_dim": {"tf": 1}, "shimmer.modules.gw_module.GWDecoder.out_dim": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBase.workspace_dim": {"tf": 1}}, "df": 6}, "v": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "r": {"docs": {}, "df": 0, "g": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "n": {"docs": {}, "df": 0, "c": {"docs": {}, "df": 0, "e": {"docs": {"shimmer.modules.vae.kl_divergence_loss": {"tf": 1}}, "df": 1}}}}}}}}, "r": {"docs": {"shimmer.utils.MIGRATION_DIR": {"tf": 1}}, "df": 1}}, "c": {"docs": {}, "df": 0, "y": {"docs": {"shimmer.modules.domain.DomainModule.compute_dcy_loss": 
{"tf": 1}}, "df": 1}}, "y": {"docs": {}, "df": 0, "n": {"docs": {}, "df": 0, "a": {"docs": {}, "df": 0, "m": {"docs": {}, "df": 0, "i": {"docs": {}, "df": 0, "c": {"docs": {}, "df": 0, "q": {"docs": {}, "df": 0, "u": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "r": {"docs": {}, "df": 0, "y": {"docs": {}, "df": 0, "a": {"docs": {}, "df": 0, "t": {"docs": {}, "df": 0, "t": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "n": {"docs": {}, "df": 0, "t": {"docs": {}, "df": 0, "i": {"docs": {}, "df": 0, "o": {"docs": {}, "df": 0, "n": {"docs": {"shimmer.modules.selection.DynamicQueryAttention": {"tf": 1}, "shimmer.modules.selection.DynamicQueryAttention.__init__": {"tf": 1}, "shimmer.modules.selection.DynamicQueryAttention.head_size": {"tf": 1}, "shimmer.modules.selection.DynamicQueryAttention.query_layer": {"tf": 1}, "shimmer.modules.selection.DynamicQueryAttention.key_layers": {"tf": 1}, "shimmer.modules.selection.DynamicQueryAttention.fuse_weighted_encodings": {"tf": 1}, "shimmer.modules.selection.DynamicQueryAttention.forward": {"tf": 1}}, "df": 7}}}}}}}}}}}}}}}}}}}}, "a": {"docs": {}, "df": 0, "t": {"docs": {}, "df": 0, "a": {"docs": {}, "df": 0, "s": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "t": {"docs": {"shimmer.dataset": {"tf": 1}, "shimmer.dataset.RepeatedDataset": {"tf": 1}, "shimmer.dataset.RepeatedDataset.__init__": {"tf": 1}, "shimmer.dataset.RepeatedDataset.dataset": {"tf": 1.4142135623730951}, "shimmer.dataset.RepeatedDataset.dataset_size": {"tf": 1.4142135623730951}}, "df": 5}}}}}}}, "a": {"docs": {}, "df": 0, "r": {"docs": {}, "df": 0, "g": {"docs": {}, "df": 0, "s": {"docs": {"shimmer.modules.global_workspace.GlobalWorkspaceBase.scheduler_args": {"tf": 1}}, "df": 1}}}, "n": {"docs": {}, "df": 0, "d": {"docs": {"shimmer.modules.global_workspace.GlobalWorkspaceBase.encode_and_fuse": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBase.encode_and_fuse": {"tf": 1}}, "df": 2}}, "l": {"docs": {}, "df": 0, "l": {"docs": 
{"shimmer.modules.domain.LossOutput.all": {"tf": 1}}, "df": 1}}}, "e": {"docs": {}, "df": 0, "n": {"docs": {}, "df": 0, "c": {"docs": {}, "df": 0, "o": {"docs": {}, "df": 0, "d": {"docs": {}, "df": 0, "e": {"docs": {"shimmer.modules.global_workspace.GlobalWorkspaceBase.encode_and_fuse": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.encode": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.encode_domain": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.encode_domains": {"tf": 1}, "shimmer.modules.domain.DomainModule.encode": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBase.encode": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBase.encode_and_fuse": {"tf": 1}, "shimmer.modules.gw_module.GWModule.encode": {"tf": 1}, "shimmer.modules.vae.VAE.encode": {"tf": 1}}, "df": 9, "r": {"docs": {"shimmer.modules.vae.VAE.encoder": {"tf": 1}}, "df": 1, "s": {"docs": {"shimmer.modules.gw_module.GWModule.gw_encoders": {"tf": 1}}, "df": 1}}}, "i": {"docs": {}, "df": 0, "n": {"docs": {}, "df": 0, "g": {"docs": {}, "df": 0, "s": {"docs": {"shimmer.modules.selection.DynamicQueryAttention.fuse_weighted_encodings": {"tf": 1}}, "df": 1}}}}}}}}}, "f": {"docs": {}, "df": 0, "u": {"docs": {}, "df": 0, "s": {"docs": {}, "df": 0, "e": {"docs": {"shimmer.modules.global_workspace.GlobalWorkspaceBase.encode_and_fuse": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.fuse": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBase.fuse": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBase.encode_and_fuse": {"tf": 1}, "shimmer.modules.gw_module.GWModule.fuse": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBayesian.fuse": {"tf": 1}, "shimmer.modules.selection.DynamicQueryAttention.fuse_weighted_encodings": {"tf": 1}}, "df": 7, "d": {"docs": {"shimmer.modules.losses.BroadcastLossCoefs.fused": {"tf": 1}}, "df": 1}}, "i": {"docs": {}, "df": 0, "o": {"docs": {}, "df": 0, "n": {"docs": {"shimmer.modules.gw_module.compute_fusion_scores": {"tf": 
1}}, "df": 1}}}}}, "r": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "z": {"docs": {}, "df": 0, "e": {"docs": {"shimmer.modules.global_workspace.freeze_domain_modules": {"tf": 1}}, "df": 1}}}}}, "o": {"docs": {}, "df": 0, "r": {"docs": {}, "df": 0, "w": {"docs": {}, "df": 0, "a": {"docs": {}, "df": 0, "r": {"docs": {}, "df": 0, "d": {"docs": {"shimmer.modules.global_workspace.GlobalWorkspace2Domains.forward": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspace.forward": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBayesian.forward": {"tf": 1}, "shimmer.modules.gw_module.GWEncoder.forward": {"tf": 1}, "shimmer.modules.gw_module.GWEncoderLinear.forward": {"tf": 1}, "shimmer.modules.selection.SelectionBase.forward": {"tf": 1}, "shimmer.modules.selection.SingleDomainSelection.forward": {"tf": 1}, "shimmer.modules.selection.FixedSharedSelection.forward": {"tf": 1}, "shimmer.modules.selection.KQFixedQSelection.forward": {"tf": 1}, "shimmer.modules.selection.RandomSelection.forward": {"tf": 1}, "shimmer.modules.selection.DynamicQueryAttention.forward": {"tf": 1}, "shimmer.modules.contrastive_loss.ContrastiveLoss.forward": {"tf": 1}, "shimmer.modules.vae.VAEEncoder.forward": {"tf": 1}, "shimmer.modules.vae.VAEDecoder.forward": {"tf": 1}, "shimmer.modules.vae.VAE.forward": {"tf": 1}}, "df": 15}}}}}}, "i": {"docs": {}, "df": 0, "x": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "d": {"docs": {}, "df": 0, "s": {"docs": {}, "df": 0, "h": {"docs": {}, "df": 0, "a": {"docs": {}, "df": 0, "r": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "d": {"docs": {}, "df": 0, "s": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "l": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "c": {"docs": {}, "df": 0, "t": {"docs": {}, "df": 0, "i": {"docs": {}, "df": 0, "o": {"docs": {}, "df": 0, "n": {"docs": {"shimmer.modules.selection.FixedSharedSelection": {"tf": 1}, "shimmer.modules.selection.FixedSharedSelection.forward": {"tf": 1}}, 
"df": 2}}}}}}}}}}}}}}}}}}}, "n": {"docs": {"shimmer.modules.losses.GWLosses2Domains.contrastive_fn": {"tf": 1}, "shimmer.modules.losses.GWLosses.contrastive_fn": {"tf": 1}, "shimmer.modules.losses.GWLossesBayesian.contrastive_fn": {"tf": 1}}, "df": 3}}, "b": {"docs": {}, "df": 0, "a": {"docs": {}, "df": 0, "t": {"docs": {}, "df": 0, "c": {"docs": {}, "df": 0, "h": {"docs": {"shimmer.modules.global_workspace.GlobalWorkspaceBase.batch_gw_states": {"tf": 1}, "shimmer.modules.utils.batch_demi_cycles": {"tf": 1}, "shimmer.modules.utils.batch_cycles": {"tf": 1}, "shimmer.modules.utils.batch_translations": {"tf": 1}, "shimmer.utils.group_batch_size": {"tf": 1}, "shimmer.utils.groups_batch_size": {"tf": 1}}, "df": 6}}}, "y": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "s": {"docs": {}, "df": 0, "i": {"docs": {}, "df": 0, "a": {"docs": {}, "df": 0, "n": {"docs": {"shimmer.modules.losses.contrastive_loss_bayesian": {"tf": 1}}, "df": 1}}}}}}}, "r": {"docs": {}, "df": 0, "o": {"docs": {}, "df": 0, "a": {"docs": {}, "df": 0, "d": {"docs": {}, "df": 0, "c": {"docs": {}, "df": 0, "a": {"docs": {}, "df": 0, "s": {"docs": {}, "df": 0, "t": {"docs": {"shimmer.modules.domain.DomainModule.compute_broadcast_loss": {"tf": 1}, "shimmer.modules.losses.broadcast_loss": {"tf": 1}, "shimmer.modules.losses.GWLosses.broadcast_loss": {"tf": 1}, "shimmer.modules.losses.GWLossesBayesian.broadcast_loss": {"tf": 1}}, "df": 4, "l": {"docs": {}, "df": 0, "o": {"docs": {}, "df": 0, "s": {"docs": {}, "df": 0, "s": {"docs": {}, "df": 0, "c": {"docs": {}, "df": 0, "o": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "f": {"docs": {}, "df": 0, "s": {"docs": {"shimmer.modules.losses.BroadcastLossCoefs": {"tf": 1}, "shimmer.modules.losses.BroadcastLossCoefs.contrastives": {"tf": 1}, "shimmer.modules.losses.BroadcastLossCoefs.fused": {"tf": 1}, "shimmer.modules.losses.BroadcastLossCoefs.demi_cycles": {"tf": 1}, "shimmer.modules.losses.BroadcastLossCoefs.cycles": {"tf": 1}, 
"shimmer.modules.losses.BroadcastLossCoefs.translations": {"tf": 1}}, "df": 6}}}}}}}}}}}}}}}}}, "e": {"docs": {}, "df": 0, "t": {"docs": {}, "df": 0, "a": {"docs": {"shimmer.modules.vae.VAE.beta": {"tf": 1}}, "df": 1}}}}, "c": {"docs": {}, "df": 0, "y": {"docs": {"shimmer.modules.domain.DomainModule.compute_cy_loss": {"tf": 1}}, "df": 1, "c": {"docs": {}, "df": 0, "l": {"docs": {}, "df": 0, "e": {"docs": {"shimmer.modules.losses.demi_cycle_loss": {"tf": 1}, "shimmer.modules.losses.cycle_loss": {"tf": 1}, "shimmer.modules.losses.GWLosses2Domains.demi_cycle_loss": {"tf": 1}, "shimmer.modules.losses.GWLosses2Domains.cycle_loss": {"tf": 1}, "shimmer.modules.utils.cycle": {"tf": 1}}, "df": 5, "s": {"docs": {"shimmer.modules.global_workspace.GWPredictions.demi_cycles": {"tf": 1}, "shimmer.modules.global_workspace.GWPredictions.cycles": {"tf": 1}, "shimmer.modules.losses.LossCoefs.demi_cycles": {"tf": 1}, "shimmer.modules.losses.LossCoefs.cycles": {"tf": 1}, "shimmer.modules.losses.BroadcastLossCoefs.demi_cycles": {"tf": 1}, "shimmer.modules.losses.BroadcastLossCoefs.cycles": {"tf": 1}, "shimmer.modules.utils.batch_demi_cycles": {"tf": 1}, "shimmer.modules.utils.batch_cycles": {"tf": 1}}, "df": 8}}}}}, "o": {"docs": {}, "df": 0, "m": {"docs": {}, "df": 0, "p": {"docs": {}, "df": 0, "u": {"docs": {}, "df": 0, "t": {"docs": {}, "df": 0, "e": {"docs": {"shimmer.modules.domain.DomainModule.compute_loss": {"tf": 1}, "shimmer.modules.domain.DomainModule.compute_dcy_loss": {"tf": 1}, "shimmer.modules.domain.DomainModule.compute_cy_loss": {"tf": 1}, "shimmer.modules.domain.DomainModule.compute_tr_loss": {"tf": 1}, "shimmer.modules.domain.DomainModule.compute_broadcast_loss": {"tf": 1}, "shimmer.modules.gw_module.compute_fusion_scores": {"tf": 1}}, "df": 6}}}}}, "n": {"docs": {}, "df": 0, "t": {"docs": {}, "df": 0, "r": {"docs": {}, "df": 0, "a": {"docs": {}, "df": 0, "s": {"docs": {}, "df": 0, "t": {"docs": {}, "df": 0, "i": {"docs": {}, "df": 0, "v": {"docs": {}, "df": 0, "e": 
{"docs": {"shimmer.modules.losses.contrastive_loss": {"tf": 1}, "shimmer.modules.losses.contrastive_loss_bayesian": {"tf": 1}, "shimmer.modules.losses.GWLosses2Domains.contrastive_fn": {"tf": 1}, "shimmer.modules.losses.GWLosses2Domains.contrastive_loss": {"tf": 1}, "shimmer.modules.losses.GWLosses.contrastive_fn": {"tf": 1}, "shimmer.modules.losses.GWLosses.contrastive_loss": {"tf": 1}, "shimmer.modules.losses.GWLossesBayesian.contrastive_fn": {"tf": 1}, "shimmer.modules.losses.GWLossesBayesian.contrastive_loss": {"tf": 1}, "shimmer.modules.contrastive_loss": {"tf": 1}, "shimmer.modules.contrastive_loss.ContrastiveLossType": {"tf": 1}, "shimmer.modules.contrastive_loss.ContrastiveLossBayesianType": {"tf": 1}, "shimmer.modules.contrastive_loss.info_nce": {"tf": 1}, "shimmer.modules.contrastive_loss.contrastive_loss": {"tf": 1.4142135623730951}, "shimmer.modules.contrastive_loss.ContrastiveLoss": {"tf": 1}, "shimmer.modules.contrastive_loss.ContrastiveLoss.__init__": {"tf": 1}, "shimmer.modules.contrastive_loss.ContrastiveLoss.learn_logit_scale": {"tf": 1}, "shimmer.modules.contrastive_loss.ContrastiveLoss.reduction": {"tf": 1}, "shimmer.modules.contrastive_loss.ContrastiveLoss.forward": {"tf": 1}}, "df": 18, "s": {"docs": {"shimmer.modules.losses.LossCoefs.contrastives": {"tf": 1}, "shimmer.modules.losses.BroadcastLossCoefs.contrastives": {"tf": 1}}, "df": 2}, "l": {"docs": {}, "df": 0, "o": {"docs": {}, "df": 0, "s": {"docs": {}, "df": 0, "s": {"docs": {"shimmer.modules.contrastive_loss.ContrastiveLoss": {"tf": 1}, "shimmer.modules.contrastive_loss.ContrastiveLoss.__init__": {"tf": 1}, "shimmer.modules.contrastive_loss.ContrastiveLoss.learn_logit_scale": {"tf": 1}, "shimmer.modules.contrastive_loss.ContrastiveLoss.reduction": {"tf": 1}, "shimmer.modules.contrastive_loss.ContrastiveLoss.forward": {"tf": 1}}, "df": 5, "t": {"docs": {}, "df": 0, "y": {"docs": {}, "df": 0, "p": {"docs": {}, "df": 0, "e": {"docs": 
{"shimmer.modules.contrastive_loss.ContrastiveLossType": {"tf": 1}}, "df": 1}}}}, "b": {"docs": {}, "df": 0, "a": {"docs": {}, "df": 0, "y": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "s": {"docs": {}, "df": 0, "i": {"docs": {}, "df": 0, "a": {"docs": {}, "df": 0, "n": {"docs": {}, "df": 0, "t": {"docs": {}, "df": 0, "y": {"docs": {}, "df": 0, "p": {"docs": {}, "df": 0, "e": {"docs": {"shimmer.modules.contrastive_loss.ContrastiveLossBayesianType": {"tf": 1}}, "df": 1}}}}}}}}}}}}}}}}}}}}}}}}, "s": {"docs": {}, "df": 0, "t": {"docs": {}, "df": 0, "r": {"docs": {}, "df": 0, "a": {"docs": {}, "df": 0, "s": {"docs": {}, "df": 0, "t": {"docs": {}, "df": 0, "i": {"docs": {}, "df": 0, "v": {"docs": {}, "df": 0, "e": {"docs": {"shimmer.modules.losses.GWLossesBayesian.use_normalized_constrastive": {"tf": 1}}, "df": 1}}}}}}}}}}, "e": {"docs": {}, "df": 0, "f": {"docs": {}, "df": 0, "s": {"docs": {"shimmer.modules.losses.GWLosses2Domains.loss_coefs": {"tf": 1}, "shimmer.modules.losses.GWLosses.loss_coefs": {"tf": 1}, "shimmer.modules.losses.GWLossesBayesian.loss_coefs": {"tf": 1}}, "df": 3}}}}, "h": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "c": {"docs": {}, "df": 0, "k": {"docs": {}, "df": 0, "p": {"docs": {}, "df": 0, "o": {"docs": {}, "df": 0, "i": {"docs": {}, "df": 0, "n": {"docs": {}, "df": 0, "t": {"docs": {"shimmer.utils.SaveMigrations.on_save_checkpoint": {"tf": 1}}, "df": 1}}}}}}}}}, "l": {"docs": {}, "df": 0, "i": {"docs": {"shimmer.cli.ckpt_migration": {"tf": 1}, "shimmer.cli.ckpt_migration.migrate_ckpt": {"tf": 1}}, "df": 2}}, "k": {"docs": {}, "df": 0, "p": {"docs": {}, "df": 0, "t": {"docs": {"shimmer.cli.ckpt_migration": {"tf": 1}, "shimmer.cli.ckpt_migration.migrate_ckpt": {"tf": 1.4142135623730951}}, "df": 2}}}}, "i": {"docs": {}, "df": 0, "n": {"docs": {"shimmer.modules.gw_module.GWDecoder.in_dim": {"tf": 1}}, "df": 1, "i": {"docs": {}, "df": 0, "t": {"docs": {"shimmer.modules.global_workspace.GlobalWorkspace2Domains.__init__": {"tf": 1}, 
"shimmer.modules.global_workspace.GlobalWorkspace.__init__": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBayesian.__init__": {"tf": 1}, "shimmer.modules.domain.LossOutput.__init__": {"tf": 1}, "shimmer.modules.domain.DomainModule.__init__": {"tf": 1}, "shimmer.modules.gw_module.GWDecoder.__init__": {"tf": 1}, "shimmer.modules.gw_module.GWEncoder.__init__": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBase.__init__": {"tf": 1}, "shimmer.modules.gw_module.GWModule.__init__": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBayesian.__init__": {"tf": 1}, "shimmer.modules.selection.KQFixedQSelection.__init__": {"tf": 1}, "shimmer.modules.selection.RandomSelection.__init__": {"tf": 1}, "shimmer.modules.selection.DynamicQueryAttention.__init__": {"tf": 1}, "shimmer.modules.losses.GWLosses2Domains.__init__": {"tf": 1}, "shimmer.modules.losses.GWLosses.__init__": {"tf": 1}, "shimmer.modules.losses.GWLossesBayesian.__init__": {"tf": 1}, "shimmer.modules.contrastive_loss.ContrastiveLoss.__init__": {"tf": 1}, "shimmer.dataset.RepeatedDataset.__init__": {"tf": 1}, "shimmer.modules.vae.VAE.__init__": {"tf": 1}}, "df": 19}}, "f": {"docs": {}, "df": 0, "o": {"docs": {"shimmer.modules.contrastive_loss.info_nce": {"tf": 1}}, "df": 1}}}}, "p": {"docs": {}, "df": 0, "r": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "t": {"docs": {}, "df": 0, "r": {"docs": {}, "df": 0, "a": {"docs": {}, "df": 0, "i": {"docs": {}, "df": 0, "n": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "d": {"docs": {"shimmer.modules.global_workspace.pretrained_global_workspace": {"tf": 1}}, "df": 1}}}}}}}, "c": {"docs": {}, "df": 0, "i": {"docs": {}, "df": 0, "s": {"docs": {}, "df": 0, "i": {"docs": {}, "df": 0, "o": {"docs": {}, "df": 0, "n": {"docs": {"shimmer.modules.gw_module.GWModuleBayesian.sensitivity_precision": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBayesian.precision_softmax_temp": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBayesian.get_precision": {"tf": 1}}, "df": 3, "s": 
{"docs": {"shimmer.modules.gw_module.GWModuleBayesian.precisions": {"tf": 1}}, "df": 1}}}}}}}}}, "a": {"docs": {}, "df": 0, "r": {"docs": {}, "df": 0, "t": {"docs": {}, "df": 0, "i": {"docs": {}, "df": 0, "t": {"docs": {}, "df": 0, "i": {"docs": {}, "df": 0, "o": {"docs": {}, "df": 0, "n": {"docs": {}, "df": 0, "s": {"docs": {"shimmer.modules.losses.generate_partitions": {"tf": 1}}, "df": 1}}}}}}}}}}, "n": {"docs": {"shimmer.modules.gw_module.get_n_layers": {"tf": 1}, "shimmer.modules.gw_module.GWDecoder.n_layers": {"tf": 1}}, "df": 2, "o": {"docs": {}, "df": 0, "r": {"docs": {}, "df": 0, "m": {"docs": {}, "df": 0, "a": {"docs": {}, "df": 0, "l": {"docs": {}, "df": 0, "i": {"docs": {}, "df": 0, "z": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "d": {"docs": {"shimmer.modules.losses.GWLossesBayesian.use_normalized_constrastive": {"tf": 1}}, "df": 1}}}}}}}}}, "c": {"docs": {}, "df": 0, "e": {"docs": {"shimmer.modules.contrastive_loss.info_nce": {"tf": 1}}, "df": 1}}, "l": {"docs": {}, "df": 0, "l": {"docs": {"shimmer.modules.vae.gaussian_nll": {"tf": 1}}, "df": 1}}}, "h": {"docs": {}, "df": 0, "i": {"docs": {}, "df": 0, "d": {"docs": {}, "df": 0, "d": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "n": {"docs": {"shimmer.modules.gw_module.GWDecoder.hidden_dim": {"tf": 1}}, "df": 1}}}}}, "e": {"docs": {}, "df": 0, "a": {"docs": {}, "df": 0, "d": {"docs": {"shimmer.modules.selection.KQFixedQSelection.head_size": {"tf": 1}, "shimmer.modules.selection.DynamicQueryAttention.head_size": {"tf": 1}}, "df": 2}}}}, "u": {"docs": {}, "df": 0, "p": {"docs": {}, "df": 0, "d": {"docs": {}, "df": 0, "a": {"docs": {}, "df": 0, "t": {"docs": {}, "df": 0, "e": {"docs": {"shimmer.modules.selection.SelectionBase.update_gw_state": {"tf": 1}}, "df": 1}}}}}, "s": {"docs": {}, "df": 0, "e": {"docs": {"shimmer.modules.losses.GWLossesBayesian.use_normalized_constrastive": {"tf": 1}}, "df": 1}}, "t": {"docs": {}, "df": 0, "i": {"docs": {}, "df": 0, "l": {"docs": {}, "df": 0, "s": 
{"docs": {"shimmer.modules.utils": {"tf": 1}, "shimmer.modules.utils.translation": {"tf": 1}, "shimmer.modules.utils.cycle": {"tf": 1}, "shimmer.modules.utils.batch_demi_cycles": {"tf": 1}, "shimmer.modules.utils.batch_cycles": {"tf": 1}, "shimmer.modules.utils.batch_translations": {"tf": 1}, "shimmer.utils": {"tf": 1}, "shimmer.utils.MIGRATION_DIR": {"tf": 1}, "shimmer.utils.group_batch_size": {"tf": 1}, "shimmer.utils.groups_batch_size": {"tf": 1}, "shimmer.utils.groups_device": {"tf": 1}, "shimmer.utils.group_device": {"tf": 1}, "shimmer.utils.migrate_model": {"tf": 1}, "shimmer.utils.SaveMigrations": {"tf": 1}, "shimmer.utils.SaveMigrations.migrations": {"tf": 1}, "shimmer.utils.SaveMigrations.on_save_checkpoint": {"tf": 1}}, "df": 16}}}}}, "k": {"docs": {}, "df": 0, "q": {"docs": {}, "df": 0, "f": {"docs": {}, "df": 0, "i": {"docs": {}, "df": 0, "x": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "d": {"docs": {}, "df": 0, "q": {"docs": {}, "df": 0, "s": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "l": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "c": {"docs": {}, "df": 0, "t": {"docs": {}, "df": 0, "i": {"docs": {}, "df": 0, "o": {"docs": {}, "df": 0, "n": {"docs": {"shimmer.modules.selection.KQFixedQSelection": {"tf": 1}, "shimmer.modules.selection.KQFixedQSelection.__init__": {"tf": 1}, "shimmer.modules.selection.KQFixedQSelection.head_size": {"tf": 1}, "shimmer.modules.selection.KQFixedQSelection.query_layer": {"tf": 1}, "shimmer.modules.selection.KQFixedQSelection.key_layers": {"tf": 1}, "shimmer.modules.selection.KQFixedQSelection.forward": {"tf": 1}}, "df": 6}}}}}}}}}}}}}}}}, "e": {"docs": {}, "df": 0, "y": {"docs": {"shimmer.modules.selection.KQFixedQSelection.key_layers": {"tf": 1}, "shimmer.modules.selection.DynamicQueryAttention.key_layers": {"tf": 1}}, "df": 2}}, "l": {"docs": {"shimmer.modules.vae.kl_divergence_loss": {"tf": 1}}, "df": 1}}, "q": {"docs": {}, "df": 0, "u": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "r": {"docs": 
{}, "df": 0, "y": {"docs": {"shimmer.modules.selection.KQFixedQSelection.query_layer": {"tf": 1}, "shimmer.modules.selection.DynamicQueryAttention.query_layer": {"tf": 1}}, "df": 2}}}}}, "v": {"docs": {}, "df": 0, "a": {"docs": {}, "df": 0, "e": {"docs": {"shimmer.modules.vae": {"tf": 1}, "shimmer.modules.vae.reparameterize": {"tf": 1}, "shimmer.modules.vae.kl_divergence_loss": {"tf": 1}, "shimmer.modules.vae.gaussian_nll": {"tf": 1}, "shimmer.modules.vae.VAEEncoder": {"tf": 1}, "shimmer.modules.vae.VAEEncoder.forward": {"tf": 1}, "shimmer.modules.vae.VAEDecoder": {"tf": 1}, "shimmer.modules.vae.VAEDecoder.forward": {"tf": 1}, "shimmer.modules.vae.VAE": {"tf": 1.4142135623730951}, "shimmer.modules.vae.VAE.__init__": {"tf": 1.4142135623730951}, "shimmer.modules.vae.VAE.beta": {"tf": 1.4142135623730951}, "shimmer.modules.vae.VAE.encoder": {"tf": 1.4142135623730951}, "shimmer.modules.vae.VAE.decoder": {"tf": 1.4142135623730951}, "shimmer.modules.vae.VAE.encode": {"tf": 1.4142135623730951}, "shimmer.modules.vae.VAE.decode": {"tf": 1.4142135623730951}, "shimmer.modules.vae.VAE.forward": {"tf": 1.4142135623730951}}, "df": 16, "e": {"docs": {}, "df": 0, "n": {"docs": {}, "df": 0, "c": {"docs": {}, "df": 0, "o": {"docs": {}, "df": 0, "d": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "r": {"docs": {"shimmer.modules.vae.VAEEncoder": {"tf": 1}, "shimmer.modules.vae.VAEEncoder.forward": {"tf": 1}}, "df": 2}}}}}}}, "d": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "c": {"docs": {}, "df": 0, "o": {"docs": {}, "df": 0, "d": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "r": {"docs": {"shimmer.modules.vae.VAEDecoder": {"tf": 1}, "shimmer.modules.vae.VAEDecoder.forward": {"tf": 1}}, "df": 2}}}}}}}}}}}}, "annotation": {"root": {"docs": {"shimmer.modules.global_workspace.SchedulerArgs.max_lr": {"tf": 1}, "shimmer.modules.global_workspace.SchedulerArgs.total_steps": {"tf": 1}, "shimmer.modules.global_workspace.GWPredictionsBase.states": {"tf": 1}, 
"shimmer.modules.global_workspace.GlobalWorkspaceBase.domain_mods": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.workspace_dim": {"tf": 1}, "shimmer.modules.global_workspace.GWPredictions.demi_cycles": {"tf": 1}, "shimmer.modules.global_workspace.GWPredictions.cycles": {"tf": 1}, "shimmer.modules.global_workspace.GWPredictions.translations": {"tf": 1}, "shimmer.modules.global_workspace.GWPredictions.states": {"tf": 1}, "shimmer.modules.domain.LossOutput.loss": {"tf": 1}, "shimmer.modules.domain.LossOutput.metrics": {"tf": 1}, "shimmer.modules.domain.LossOutput.all": {"tf": 1}, "shimmer.modules.losses.LossCoefs.demi_cycles": {"tf": 1}, "shimmer.modules.losses.LossCoefs.cycles": {"tf": 1}, "shimmer.modules.losses.LossCoefs.translations": {"tf": 1}, "shimmer.modules.losses.LossCoefs.contrastives": {"tf": 1}, "shimmer.modules.losses.BroadcastLossCoefs.contrastives": {"tf": 1}, "shimmer.modules.losses.BroadcastLossCoefs.fused": {"tf": 1}, "shimmer.modules.losses.BroadcastLossCoefs.demi_cycles": {"tf": 1}, "shimmer.modules.losses.BroadcastLossCoefs.cycles": {"tf": 1}, "shimmer.modules.losses.BroadcastLossCoefs.translations": {"tf": 1}, "shimmer.modules.contrastive_loss.ContrastiveLoss.reduction": {"tf": 1.4142135623730951}}, "df": 22, "f": {"docs": {}, "df": 0, "l": {"docs": {}, "df": 0, "o": {"docs": {}, "df": 0, "a": {"docs": {}, "df": 0, "t": {"docs": {"shimmer.modules.global_workspace.SchedulerArgs.max_lr": {"tf": 1}, "shimmer.modules.losses.LossCoefs.demi_cycles": {"tf": 1}, "shimmer.modules.losses.LossCoefs.cycles": {"tf": 1}, "shimmer.modules.losses.LossCoefs.translations": {"tf": 1}, "shimmer.modules.losses.LossCoefs.contrastives": {"tf": 1}, "shimmer.modules.losses.BroadcastLossCoefs.contrastives": {"tf": 1}, "shimmer.modules.losses.BroadcastLossCoefs.fused": {"tf": 1}, "shimmer.modules.losses.BroadcastLossCoefs.demi_cycles": {"tf": 1}, "shimmer.modules.losses.BroadcastLossCoefs.cycles": {"tf": 1}, 
"shimmer.modules.losses.BroadcastLossCoefs.translations": {"tf": 1}}, "df": 10}}}}}, "i": {"docs": {}, "df": 0, "n": {"docs": {}, "df": 0, "t": {"docs": {"shimmer.modules.global_workspace.SchedulerArgs.total_steps": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.workspace_dim": {"tf": 1}}, "df": 2}}}, "d": {"docs": {}, "df": 0, "i": {"docs": {}, "df": 0, "c": {"docs": {}, "df": 0, "t": {"docs": {}, "df": 0, "[": {"docs": {}, "df": 0, "s": {"docs": {}, "df": 0, "t": {"docs": {}, "df": 0, "r": {"docs": {"shimmer.modules.global_workspace.GWPredictionsBase.states": {"tf": 1}, "shimmer.modules.global_workspace.GWPredictions.demi_cycles": {"tf": 1}, "shimmer.modules.global_workspace.GWPredictions.states": {"tf": 1}, "shimmer.modules.domain.LossOutput.metrics": {"tf": 1}, "shimmer.modules.domain.LossOutput.all": {"tf": 1}}, "df": 5}}}, "t": {"docs": {}, "df": 0, "u": {"docs": {}, "df": 0, "p": {"docs": {}, "df": 0, "l": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "[": {"docs": {}, "df": 0, "s": {"docs": {}, "df": 0, "t": {"docs": {}, "df": 0, "r": {"docs": {"shimmer.modules.global_workspace.GWPredictions.cycles": {"tf": 1}, "shimmer.modules.global_workspace.GWPredictions.translations": {"tf": 1}}, "df": 2}}}}}}}}}}}}}, "o": {"docs": {}, "df": 0, "m": {"docs": {}, "df": 0, "a": {"docs": {}, "df": 0, "i": {"docs": {}, "df": 0, "n": {"docs": {"shimmer.modules.global_workspace.GlobalWorkspaceBase.domain_mods": {"tf": 1}}, "df": 1, "m": {"docs": {}, "df": 0, "o": {"docs": {}, "df": 0, "d": {"docs": {}, "df": 0, "u": {"docs": {}, "df": 0, "l": {"docs": {}, "df": 0, "e": {"docs": {"shimmer.modules.global_workspace.GlobalWorkspaceBase.domain_mods": {"tf": 1}}, "df": 1}}}}}}}}}}}}, "t": {"docs": {}, "df": 0, "o": {"docs": {}, "df": 0, "r": {"docs": {}, "df": 0, "c": {"docs": {}, "df": 0, "h": {"docs": {"shimmer.modules.global_workspace.GWPredictionsBase.states": {"tf": 1}, "shimmer.modules.global_workspace.GWPredictions.demi_cycles": {"tf": 1}, 
"shimmer.modules.global_workspace.GWPredictions.cycles": {"tf": 1}, "shimmer.modules.global_workspace.GWPredictions.translations": {"tf": 1}, "shimmer.modules.global_workspace.GWPredictions.states": {"tf": 1}, "shimmer.modules.domain.LossOutput.loss": {"tf": 1}, "shimmer.modules.domain.LossOutput.metrics": {"tf": 1}, "shimmer.modules.domain.LossOutput.all": {"tf": 1}}, "df": 8}}}}, "e": {"docs": {}, "df": 0, "n": {"docs": {}, "df": 0, "s": {"docs": {}, "df": 0, "o": {"docs": {}, "df": 0, "r": {"docs": {"shimmer.modules.global_workspace.GWPredictionsBase.states": {"tf": 1}, "shimmer.modules.global_workspace.GWPredictions.demi_cycles": {"tf": 1}, "shimmer.modules.global_workspace.GWPredictions.cycles": {"tf": 1}, "shimmer.modules.global_workspace.GWPredictions.translations": {"tf": 1}, "shimmer.modules.global_workspace.GWPredictions.states": {"tf": 1}, "shimmer.modules.domain.LossOutput.loss": {"tf": 1}, "shimmer.modules.domain.LossOutput.metrics": {"tf": 1}, "shimmer.modules.domain.LossOutput.all": {"tf": 1}}, "df": 8}}}}}}, "c": {"docs": {}, "df": 0, "o": {"docs": {}, "df": 0, "l": {"docs": {}, "df": 0, "l": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "c": {"docs": {}, "df": 0, "t": {"docs": {}, "df": 0, "i": {"docs": {}, "df": 0, "o": {"docs": {}, "df": 0, "n": {"docs": {}, "df": 0, "s": {"docs": {"shimmer.modules.global_workspace.GlobalWorkspaceBase.domain_mods": {"tf": 1}}, "df": 1}}}}}}}}}}}, "a": {"docs": {}, "df": 0, "b": {"docs": {}, "df": 0, "c": {"docs": {"shimmer.modules.global_workspace.GlobalWorkspaceBase.domain_mods": {"tf": 1}}, "df": 1}}}, "m": {"docs": {}, "df": 0, "a": {"docs": {}, "df": 0, "p": {"docs": {}, "df": 0, "p": {"docs": {}, "df": 0, "i": {"docs": {}, "df": 0, "n": {"docs": {}, "df": 0, "g": {"docs": {}, "df": 0, "[": {"docs": {}, "df": 0, "s": {"docs": {}, "df": 0, "t": {"docs": {}, "df": 0, "r": {"docs": {"shimmer.modules.global_workspace.GlobalWorkspaceBase.domain_mods": {"tf": 1}}, "df": 1}}}}}}}}}}, "o": {"docs": {}, "df": 0, 
"d": {"docs": {}, "df": 0, "u": {"docs": {}, "df": 0, "l": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "s": {"docs": {"shimmer.modules.global_workspace.GlobalWorkspaceBase.domain_mods": {"tf": 1}}, "df": 1}}}}}}, "e": {"docs": {}, "df": 0, "a": {"docs": {}, "df": 0, "n": {"docs": {"shimmer.modules.contrastive_loss.ContrastiveLoss.reduction": {"tf": 1}}, "df": 1}}}}, "s": {"docs": {}, "df": 0, "h": {"docs": {}, "df": 0, "i": {"docs": {}, "df": 0, "m": {"docs": {}, "df": 0, "m": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "r": {"docs": {"shimmer.modules.global_workspace.GlobalWorkspaceBase.domain_mods": {"tf": 1}}, "df": 1}}}}}}, "t": {"docs": {}, "df": 0, "r": {"docs": {"shimmer.modules.global_workspace.GWPredictions.cycles": {"tf": 1}, "shimmer.modules.global_workspace.GWPredictions.translations": {"tf": 1}}, "df": 2}}, "u": {"docs": {}, "df": 0, "m": {"docs": {"shimmer.modules.contrastive_loss.ContrastiveLoss.reduction": {"tf": 1}}, "df": 1}}}, "l": {"docs": {}, "df": 0, "i": {"docs": {}, "df": 0, "t": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "r": {"docs": {}, "df": 0, "a": {"docs": {}, "df": 0, "l": {"docs": {"shimmer.modules.contrastive_loss.ContrastiveLoss.reduction": {"tf": 1}}, "df": 1}}}}}}}, "x": {"2": {"7": {"docs": {"shimmer.modules.contrastive_loss.ContrastiveLoss.reduction": {"tf": 2.449489742783178}}, "df": 1}, "docs": {}, "df": 0}, "docs": {}, "df": 0}, "n": {"docs": {}, "df": 0, "o": {"docs": {}, "df": 0, "n": {"docs": {}, "df": 0, "e": {"docs": {"shimmer.modules.contrastive_loss.ContrastiveLoss.reduction": {"tf": 1}}, "df": 1}}}}}}, "default_value": {"root": {"docs": {"shimmer.types.ModelModeT": {"tf": 1}, "shimmer.utils.MIGRATION_DIR": {"tf": 1}, "shimmer.cli.ckpt_migration.migrate_ckpt": {"tf": 1.4142135623730951}}, "df": 3, "c": {"docs": {}, "df": 0, "o": {"docs": {}, "df": 0, "l": {"docs": {}, "df": 0, "l": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "c": {"docs": {}, "df": 0, "t": {"docs": {}, "df": 0, "i": {"docs": {}, 
"df": 0, "o": {"docs": {}, "df": 0, "n": {"docs": {}, "df": 0, "s": {"docs": {"shimmer.types.RawDomainGroupT": {"tf": 1}, "shimmer.types.LatentsDomainGroupT": {"tf": 1}, "shimmer.types.RawDomainGroupsT": {"tf": 1.4142135623730951}, "shimmer.types.LatentsDomainGroupsT": {"tf": 1.4142135623730951}, "shimmer.modules.contrastive_loss.ContrastiveLossType": {"tf": 1}, "shimmer.modules.contrastive_loss.ContrastiveLossBayesianType": {"tf": 1}}, "df": 6}}}}}}}}}, "m": {"docs": {}, "df": 0, "m": {"docs": {}, "df": 0, "a": {"docs": {}, "df": 0, "n": {"docs": {}, "df": 0, "d": {"docs": {"shimmer.cli.ckpt_migration.migrate_ckpt": {"tf": 1}}, "df": 1}}}}}}, "a": {"docs": {}, "df": 0, "l": {"docs": {}, "df": 0, "l": {"docs": {}, "df": 0, "a": {"docs": {}, "df": 0, "b": {"docs": {}, "df": 0, "l": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "[": {"docs": {}, "df": 0, "[": {"docs": {}, "df": 0, "t": {"docs": {}, "df": 0, "o": {"docs": {}, "df": 0, "r": {"docs": {}, "df": 0, "c": {"docs": {}, "df": 0, "h": {"docs": {"shimmer.modules.contrastive_loss.ContrastiveLossType": {"tf": 1}, "shimmer.modules.contrastive_loss.ContrastiveLossBayesianType": {"tf": 1}}, "df": 2}}}}}}}}}}}}}}, "k": {"docs": {}, "df": 0, "p": {"docs": {}, "df": 0, "t": {"docs": {"shimmer.cli.ckpt_migration.migrate_ckpt": {"tf": 1}}, "df": 1}}}}, "a": {"docs": {}, "df": 0, "b": {"docs": {}, "df": 0, "c": {"docs": {"shimmer.types.RawDomainGroupT": {"tf": 1}, "shimmer.types.LatentsDomainGroupT": {"tf": 1}, "shimmer.types.RawDomainGroupsT": {"tf": 1.4142135623730951}, "shimmer.types.LatentsDomainGroupsT": {"tf": 1.4142135623730951}, "shimmer.modules.contrastive_loss.ContrastiveLossType": {"tf": 1}, "shimmer.modules.contrastive_loss.ContrastiveLossBayesianType": {"tf": 1}}, "df": 6}}, "n": {"docs": {}, "df": 0, "y": {"docs": {"shimmer.types.RawDomainGroupT": {"tf": 1}, "shimmer.types.RawDomainGroupDT": {"tf": 1}, "shimmer.types.RawDomainGroupsT": {"tf": 1}, "shimmer.types.RawDomainGroupsDT": {"tf": 1}}, "df": 4}}}, 
"m": {"docs": {}, "df": 0, "a": {"docs": {}, "df": 0, "p": {"docs": {}, "df": 0, "p": {"docs": {}, "df": 0, "i": {"docs": {}, "df": 0, "n": {"docs": {}, "df": 0, "g": {"docs": {}, "df": 0, "[": {"docs": {}, "df": 0, "s": {"docs": {}, "df": 0, "t": {"docs": {}, "df": 0, "r": {"docs": {"shimmer.types.RawDomainGroupT": {"tf": 1}, "shimmer.types.LatentsDomainGroupT": {"tf": 1}, "shimmer.types.RawDomainGroupsT": {"tf": 1}, "shimmer.types.LatentsDomainGroupsT": {"tf": 1}}, "df": 4}}}, "f": {"docs": {}, "df": 0, "r": {"docs": {}, "df": 0, "o": {"docs": {}, "df": 0, "z": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "n": {"docs": {}, "df": 0, "s": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "t": {"docs": {}, "df": 0, "[": {"docs": {}, "df": 0, "s": {"docs": {}, "df": 0, "t": {"docs": {}, "df": 0, "r": {"docs": {"shimmer.types.RawDomainGroupsT": {"tf": 1}, "shimmer.types.LatentsDomainGroupsT": {"tf": 1}}, "df": 2}}}}}}}}}}}}}}}}}}}}, "o": {"docs": {}, "df": 0, "d": {"docs": {}, "df": 0, "u": {"docs": {}, "df": 0, "l": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "s": {"docs": {"shimmer.modules.contrastive_loss.ContrastiveLossType": {"tf": 1}, "shimmer.modules.contrastive_loss.ContrastiveLossBayesianType": {"tf": 1}}, "df": 2}}}}}}, "i": {"docs": {}, "df": 0, "g": {"docs": {}, "df": 0, "r": {"docs": {}, "df": 0, "a": {"docs": {}, "df": 0, "t": {"docs": {}, "df": 0, "i": {"docs": {}, "df": 0, "o": {"docs": {}, "df": 0, "n": {"docs": {}, "df": 0, "s": {"docs": {"shimmer.utils.MIGRATION_DIR": {"tf": 1}}, "df": 1}}}}, "e": {"docs": {"shimmer.cli.ckpt_migration.migrate_ckpt": {"tf": 1}}, "df": 1}}}}}}}, "t": {"docs": {}, "df": 0, "y": {"docs": {}, "df": 0, "p": {"docs": {}, "df": 0, "i": {"docs": {}, "df": 0, "n": {"docs": {}, "df": 0, "g": {"docs": {"shimmer.types.RawDomainGroupT": {"tf": 1}, "shimmer.types.RawDomainGroupDT": {"tf": 1}, "shimmer.types.RawDomainGroupsT": {"tf": 1}, "shimmer.types.RawDomainGroupsDT": {"tf": 1}, "shimmer.types.ModelModeT": {"tf": 1}}, 
"df": 5}}}}}, "o": {"docs": {}, "df": 0, "r": {"docs": {}, "df": 0, "c": {"docs": {}, "df": 0, "h": {"docs": {"shimmer.types.LatentsDomainGroupT": {"tf": 1}, "shimmer.types.LatentsDomainGroupDT": {"tf": 1}, "shimmer.types.LatentsDomainGroupsT": {"tf": 1}, "shimmer.types.LatentsDomainGroupsDT": {"tf": 1}, "shimmer.modules.contrastive_loss.ContrastiveLossType": {"tf": 1}, "shimmer.modules.contrastive_loss.ContrastiveLossBayesianType": {"tf": 1.7320508075688772}}, "df": 6}}}}, "e": {"docs": {}, "df": 0, "n": {"docs": {}, "df": 0, "s": {"docs": {}, "df": 0, "o": {"docs": {}, "df": 0, "r": {"docs": {"shimmer.types.LatentsDomainGroupT": {"tf": 1}, "shimmer.types.LatentsDomainGroupDT": {"tf": 1}, "shimmer.types.LatentsDomainGroupsT": {"tf": 1}, "shimmer.types.LatentsDomainGroupsDT": {"tf": 1}, "shimmer.modules.contrastive_loss.ContrastiveLossType": {"tf": 1.4142135623730951}, "shimmer.modules.contrastive_loss.ContrastiveLossBayesianType": {"tf": 2}}, "df": 6}}}}, "s": {"docs": {}, "df": 0, "t": {"docs": {"shimmer.types.ModelModeT": {"tf": 1}}, "df": 1, "/": {"docs": {}, "df": 0, "o": {"docs": {}, "df": 0, "o": {"docs": {}, "df": 0, "d": {"docs": {"shimmer.types.ModelModeT": {"tf": 1}}, "df": 1}}}}}}}, "r": {"docs": {}, "df": 0, "a": {"docs": {}, "df": 0, "i": {"docs": {}, "df": 0, "n": {"docs": {"shimmer.types.ModelModeT": {"tf": 1}}, "df": 1}}}}}, "d": {"docs": {}, "df": 0, "i": {"docs": {}, "df": 0, "c": {"docs": {}, "df": 0, "t": {"docs": {}, "df": 0, "[": {"docs": {}, "df": 0, "s": {"docs": {}, "df": 0, "t": {"docs": {}, "df": 0, "r": {"docs": {"shimmer.types.RawDomainGroupDT": {"tf": 1}, "shimmer.types.LatentsDomainGroupDT": {"tf": 1}, "shimmer.types.RawDomainGroupsDT": {"tf": 1}, "shimmer.types.LatentsDomainGroupsDT": {"tf": 1}}, "df": 4}}}, "f": {"docs": {}, "df": 0, "r": {"docs": {}, "df": 0, "o": {"docs": {}, "df": 0, "z": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "n": {"docs": {}, "df": 0, "s": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "t": 
{"docs": {}, "df": 0, "[": {"docs": {}, "df": 0, "s": {"docs": {}, "df": 0, "t": {"docs": {}, "df": 0, "r": {"docs": {"shimmer.types.RawDomainGroupsDT": {"tf": 1}, "shimmer.types.LatentsDomainGroupsDT": {"tf": 1}}, "df": 2}}}}}}}}}}}}}}}}}, "o": {"docs": {}, "df": 0, "m": {"docs": {}, "df": 0, "a": {"docs": {}, "df": 0, "i": {"docs": {}, "df": 0, "n": {"docs": {"shimmer.modules.contrastive_loss.ContrastiveLossType": {"tf": 1}, "shimmer.modules.contrastive_loss.ContrastiveLossBayesianType": {"tf": 1}}, "df": 2}}}}}}, "l": {"docs": {}, "df": 0, "i": {"docs": {}, "df": 0, "t": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "r": {"docs": {}, "df": 0, "a": {"docs": {}, "df": 0, "l": {"docs": {"shimmer.types.ModelModeT": {"tf": 1}}, "df": 1}}}}}}, "o": {"docs": {}, "df": 0, "s": {"docs": {}, "df": 0, "s": {"docs": {}, "df": 0, "o": {"docs": {}, "df": 0, "u": {"docs": {}, "df": 0, "t": {"docs": {}, "df": 0, "p": {"docs": {}, "df": 0, "u": {"docs": {}, "df": 0, "t": {"docs": {"shimmer.modules.contrastive_loss.ContrastiveLossType": {"tf": 1}, "shimmer.modules.contrastive_loss.ContrastiveLossBayesianType": {"tf": 1}}, "df": 2}}}}}}}}}, "t": {"docs": {"shimmer.cli.ckpt_migration.migrate_ckpt": {"tf": 1}}, "df": 1}}, "x": {"2": {"7": {"docs": {"shimmer.types.ModelModeT": {"tf": 3.1622776601683795}, "shimmer.utils.MIGRATION_DIR": {"tf": 1.4142135623730951}}, "df": 2}, "docs": {}, "df": 0}, "docs": {}, "df": 0}, "v": {"docs": {}, "df": 0, "a": {"docs": {}, "df": 0, "l": {"docs": {"shimmer.types.ModelModeT": {"tf": 1}}, "df": 1, "/": {"docs": {}, "df": 0, "o": {"docs": {}, "df": 0, "o": {"docs": {}, "df": 0, "d": {"docs": {"shimmer.types.ModelModeT": {"tf": 1}}, "df": 1}}}}}}}, "s": {"docs": {}, "df": 0, "h": {"docs": {}, "df": 0, "i": {"docs": {}, "df": 0, "m": {"docs": {}, "df": 0, "m": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "r": {"docs": {"shimmer.modules.contrastive_loss.ContrastiveLossType": {"tf": 1}, 
"shimmer.modules.contrastive_loss.ContrastiveLossBayesianType": {"tf": 1}}, "df": 2}}}}}}}, "p": {"docs": {}, "df": 0, "o": {"docs": {}, "df": 0, "s": {"docs": {}, "df": 0, "i": {"docs": {}, "df": 0, "x": {"docs": {}, "df": 0, "p": {"docs": {}, "df": 0, "a": {"docs": {}, "df": 0, "t": {"docs": {}, "df": 0, "h": {"docs": {"shimmer.utils.MIGRATION_DIR": {"tf": 1}}, "df": 1}}}}}}}}}, "h": {"docs": {}, "df": 0, "o": {"docs": {}, "df": 0, "m": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "/": {"docs": {}, "df": 0, "r": {"docs": {}, "df": 0, "u": {"docs": {}, "df": 0, "n": {"docs": {}, "df": 0, "n": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "r": {"docs": {}, "df": 0, "/": {"docs": {}, "df": 0, "w": {"docs": {}, "df": 0, "o": {"docs": {}, "df": 0, "r": {"docs": {}, "df": 0, "k": {"docs": {}, "df": 0, "/": {"docs": {}, "df": 0, "s": {"docs": {}, "df": 0, "h": {"docs": {}, "df": 0, "i": {"docs": {}, "df": 0, "m": {"docs": {}, "df": 0, "m": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "r": {"docs": {}, "df": 0, "/": {"docs": {}, "df": 0, "s": {"docs": {}, "df": 0, "h": {"docs": {}, "df": 0, "i": {"docs": {}, "df": 0, "m": {"docs": {}, "df": 0, "m": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "r": {"docs": {}, "df": 0, "/": {"docs": {}, "df": 0, "s": {"docs": {}, "df": 0, "h": {"docs": {}, "df": 0, "i": {"docs": {}, "df": 0, "m": {"docs": {}, "df": 0, "m": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "r": {"docs": {}, "df": 0, "/": {"docs": {}, "df": 0, "c": {"docs": {}, "df": 0, "k": {"docs": {}, "df": 0, "p": {"docs": {}, "df": 0, "t": {"docs": {"shimmer.utils.MIGRATION_DIR": {"tf": 1}}, "df": 1}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}, "g": {"docs": {}, "df": 0, "t": {"docs": {"shimmer.cli.ckpt_migration.migrate_ckpt": {"tf": 1}}, "df": 1}}}}, "signature": {"root": {"0": {"0": {"1": {"docs": {"shimmer.modules.global_workspace.GlobalWorkspace2Domains.__init__": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspace.__init__": {"tf": 1}, 
"shimmer.modules.global_workspace.GlobalWorkspaceBayesian.__init__": {"tf": 1}}, "df": 3}, "docs": {}, "df": 0}, "1": {"docs": {"shimmer.modules.global_workspace.GlobalWorkspaceBayesian.__init__": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBayesian.__init__": {"tf": 1}}, "df": 2}, "6": {"docs": {"shimmer.modules.gw_module.compute_fusion_scores": {"tf": 1}}, "df": 1}, "docs": {"shimmer.modules.global_workspace.GlobalWorkspace2Domains.__init__": {"tf": 1.7320508075688772}, "shimmer.modules.global_workspace.GlobalWorkspace.__init__": {"tf": 2}, "shimmer.modules.global_workspace.GlobalWorkspaceBayesian.__init__": {"tf": 2}, "shimmer.modules.gw_module.compute_fusion_scores": {"tf": 1.4142135623730951}, "shimmer.modules.gw_module.GWModuleBayesian.__init__": {"tf": 1}}, "df": 5}, "1": {"docs": {"shimmer.modules.global_workspace.GlobalWorkspaceBayesian.__init__": {"tf": 1.4142135623730951}, "shimmer.modules.gw_module.compute_fusion_scores": {"tf": 2}, "shimmer.modules.gw_module.GWModuleBayesian.__init__": {"tf": 1.4142135623730951}, "shimmer.modules.vae.VAE.__init__": {"tf": 1}}, "df": 4, "e": {"docs": {"shimmer.modules.gw_module.compute_fusion_scores": {"tf": 1}}, "df": 1}}, "2": {"docs": {"shimmer.modules.global_workspace.GlobalWorkspace.__init__": {"tf": 1}, "shimmer.modules.gw_module.compute_fusion_scores": {"tf": 1.4142135623730951}}, "df": 2}, "3": {"9": {"docs": {"shimmer.modules.global_workspace.GlobalWorkspaceBase.generic_step": {"tf": 3.1622776601683795}, "shimmer.modules.losses.GWLossesBase.step": {"tf": 3.1622776601683795}, "shimmer.modules.losses.GWLosses2Domains.step": {"tf": 3.1622776601683795}, "shimmer.modules.losses.GWLosses.step": {"tf": 3.1622776601683795}, "shimmer.modules.losses.GWLossesBayesian.step": {"tf": 3.1622776601683795}, "shimmer.modules.contrastive_loss.info_nce": {"tf": 2.8284271247461903}, "shimmer.modules.contrastive_loss.contrastive_loss": {"tf": 2.8284271247461903}, "shimmer.modules.contrastive_loss.ContrastiveLoss.__init__": {"tf": 
2.8284271247461903}}, "df": 8}, "docs": {}, "df": 0}, "docs": {"shimmer.modules.global_workspace.GlobalWorkspaceBase.encode_and_fuse": {"tf": 10.04987562112089}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.encode": {"tf": 9.38083151964686}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.fuse": {"tf": 11.40175425099138}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.decode": {"tf": 10.246950765959598}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.batch_gw_states": {"tf": 8.774964387392123}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.encode_domain": {"tf": 5.656854249492381}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.encode_domains": {"tf": 9.38083151964686}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.decode_domain": {"tf": 5.656854249492381}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.decode_domains": {"tf": 9.38083151964686}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.generic_step": {"tf": 10.535653752852738}, "shimmer.modules.global_workspace.freeze_domain_modules": {"tf": 8.306623862918075}, "shimmer.modules.global_workspace.GlobalWorkspace2Domains.__init__": {"tf": 17.291616465790582}, "shimmer.modules.global_workspace.GlobalWorkspace2Domains.forward": {"tf": 8.660254037844387}, "shimmer.modules.global_workspace.GlobalWorkspace.__init__": {"tf": 17.72004514666935}, "shimmer.modules.global_workspace.GlobalWorkspace.forward": {"tf": 8.660254037844387}, "shimmer.modules.global_workspace.GlobalWorkspaceBayesian.__init__": {"tf": 18.947295321496416}, "shimmer.modules.global_workspace.GlobalWorkspaceBayesian.forward": {"tf": 8.660254037844387}, "shimmer.modules.global_workspace.pretrained_global_workspace": {"tf": 15.874507866387544}, "shimmer.modules.domain.LossOutput.__init__": {"tf": 7.0710678118654755}, "shimmer.modules.domain.DomainModule.__init__": {"tf": 3.4641016151377544}, "shimmer.modules.domain.DomainModule.encode": {"tf": 4.898979485566356}, 
"shimmer.modules.domain.DomainModule.decode": {"tf": 4.898979485566356}, "shimmer.modules.domain.DomainModule.compute_loss": {"tf": 7.14142842854285}, "shimmer.modules.domain.DomainModule.compute_dcy_loss": {"tf": 7.14142842854285}, "shimmer.modules.domain.DomainModule.compute_cy_loss": {"tf": 7.14142842854285}, "shimmer.modules.domain.DomainModule.compute_tr_loss": {"tf": 7.14142842854285}, "shimmer.modules.domain.DomainModule.compute_broadcast_loss": {"tf": 7.14142842854285}, "shimmer.modules.gw_module.get_n_layers": {"tf": 6.708203932499369}, "shimmer.modules.gw_module.GWDecoder.__init__": {"tf": 6}, "shimmer.modules.gw_module.GWEncoder.__init__": {"tf": 6}, "shimmer.modules.gw_module.GWEncoder.forward": {"tf": 5.291502622129181}, "shimmer.modules.gw_module.GWEncoderLinear.forward": {"tf": 5.291502622129181}, "shimmer.modules.gw_module.GWModuleBase.__init__": {"tf": 8.12403840463596}, "shimmer.modules.gw_module.GWModuleBase.fuse": {"tf": 8.888194417315589}, "shimmer.modules.gw_module.GWModuleBase.encode": {"tf": 7.615773105863909}, "shimmer.modules.gw_module.GWModuleBase.encode_and_fuse": {"tf": 8.306623862918075}, "shimmer.modules.gw_module.GWModuleBase.decode": {"tf": 8.54400374531753}, "shimmer.modules.gw_module.GWModule.__init__": {"tf": 11.74734012447073}, "shimmer.modules.gw_module.GWModule.fuse": {"tf": 8.888194417315589}, "shimmer.modules.gw_module.GWModule.encode": {"tf": 7.615773105863909}, "shimmer.modules.gw_module.GWModule.decode": {"tf": 8.54400374531753}, "shimmer.modules.gw_module.compute_fusion_scores": {"tf": 9.1104335791443}, "shimmer.modules.gw_module.GWModuleBayesian.__init__": {"tf": 13.527749258468683}, "shimmer.modules.gw_module.GWModuleBayesian.get_precision": {"tf": 6}, "shimmer.modules.gw_module.GWModuleBayesian.fuse": {"tf": 8.888194417315589}, "shimmer.modules.selection.SelectionBase.update_gw_state": {"tf": 4.898979485566356}, "shimmer.modules.selection.SelectionBase.forward": {"tf": 9.433981132056603}, 
"shimmer.modules.selection.SingleDomainSelection.forward": {"tf": 9.433981132056603}, "shimmer.modules.selection.FixedSharedSelection.forward": {"tf": 9.433981132056603}, "shimmer.modules.selection.KQFixedQSelection.__init__": {"tf": 6.6332495807108}, "shimmer.modules.selection.KQFixedQSelection.forward": {"tf": 9.433981132056603}, "shimmer.modules.selection.RandomSelection.__init__": {"tf": 3.4641016151377544}, "shimmer.modules.selection.RandomSelection.forward": {"tf": 9.433981132056603}, "shimmer.modules.selection.DynamicQueryAttention.__init__": {"tf": 6.6332495807108}, "shimmer.modules.selection.DynamicQueryAttention.fuse_weighted_encodings": {"tf": 8.426149773176359}, "shimmer.modules.selection.DynamicQueryAttention.forward": {"tf": 9.433981132056603}, "shimmer.modules.losses.GWLossesBase.step": {"tf": 10.908712114635714}, "shimmer.modules.losses.demi_cycle_loss": {"tf": 12.36931687685298}, "shimmer.modules.losses.cycle_loss": {"tf": 12.36931687685298}, "shimmer.modules.losses.translation_loss": {"tf": 12.36931687685298}, "shimmer.modules.losses.contrastive_loss": {"tf": 12.041594578792296}, "shimmer.modules.losses.contrastive_loss_bayesian": {"tf": 12.041594578792296}, "shimmer.modules.losses.GWLosses2Domains.__init__": {"tf": 12.24744871391589}, "shimmer.modules.losses.GWLosses2Domains.demi_cycle_loss": {"tf": 8.774964387392123}, "shimmer.modules.losses.GWLosses2Domains.cycle_loss": {"tf": 8.774964387392123}, "shimmer.modules.losses.GWLosses2Domains.translation_loss": {"tf": 8.774964387392123}, "shimmer.modules.losses.GWLosses2Domains.contrastive_loss": {"tf": 8.774964387392123}, "shimmer.modules.losses.GWLosses2Domains.step": {"tf": 10.908712114635714}, "shimmer.modules.losses.generate_partitions": {"tf": 7}, "shimmer.modules.losses.broadcast_loss": {"tf": 12.36931687685298}, "shimmer.modules.losses.GWLosses.__init__": {"tf": 12.24744871391589}, "shimmer.modules.losses.GWLosses.contrastive_loss": {"tf": 8.774964387392123}, 
"shimmer.modules.losses.GWLosses.broadcast_loss": {"tf": 8.774964387392123}, "shimmer.modules.losses.GWLosses.step": {"tf": 10.908712114635714}, "shimmer.modules.losses.GWLossesBayesian.__init__": {"tf": 12.84523257866513}, "shimmer.modules.losses.GWLossesBayesian.contrastive_loss": {"tf": 8.774964387392123}, "shimmer.modules.losses.GWLossesBayesian.broadcast_loss": {"tf": 8.774964387392123}, "shimmer.modules.losses.GWLossesBayesian.step": {"tf": 10.908712114635714}, "shimmer.modules.contrastive_loss.info_nce": {"tf": 9.433981132056603}, "shimmer.modules.contrastive_loss.contrastive_loss": {"tf": 9.433981132056603}, "shimmer.modules.contrastive_loss.ContrastiveLoss.__init__": {"tf": 8.366600265340756}, "shimmer.modules.contrastive_loss.ContrastiveLoss.forward": {"tf": 7.14142842854285}, "shimmer.dataset.RepeatedDataset.__init__": {"tf": 6.782329983125268}, "shimmer.modules.vae.reparameterize": {"tf": 6}, "shimmer.modules.vae.kl_divergence_loss": {"tf": 6}, "shimmer.modules.vae.gaussian_nll": {"tf": 7.14142842854285}, "shimmer.modules.vae.VAEEncoder.forward": {"tf": 6.164414002968976}, "shimmer.modules.vae.VAEDecoder.forward": {"tf": 4.898979485566356}, "shimmer.modules.vae.VAE.__init__": {"tf": 7.810249675906654}, "shimmer.modules.vae.VAE.encode": {"tf": 4.898979485566356}, "shimmer.modules.vae.VAE.decode": {"tf": 4.898979485566356}, "shimmer.modules.vae.VAE.forward": {"tf": 7.0710678118654755}, "shimmer.modules.utils.translation": {"tf": 9.695359714832659}, "shimmer.modules.utils.cycle": {"tf": 10.198039027185569}, "shimmer.modules.utils.batch_demi_cycles": {"tf": 10.677078252031311}, "shimmer.modules.utils.batch_cycles": {"tf": 12}, "shimmer.modules.utils.batch_translations": {"tf": 11.045361017187261}, "shimmer.utils.group_batch_size": {"tf": 6.164414002968976}, "shimmer.utils.groups_batch_size": {"tf": 7.615773105863909}, "shimmer.utils.groups_device": {"tf": 7.615773105863909}, "shimmer.utils.group_device": {"tf": 6.48074069840786}, 
"shimmer.utils.migrate_model": {"tf": 5.385164807134504}, "shimmer.utils.SaveMigrations.on_save_checkpoint": {"tf": 9.16515138991168}}, "df": 103, "s": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "l": {"docs": {}, "df": 0, "f": {"docs": {"shimmer.modules.global_workspace.GlobalWorkspaceBase.encode_and_fuse": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.encode": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.fuse": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.decode": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.batch_gw_states": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.encode_domain": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.encode_domains": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.decode_domain": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.decode_domains": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.generic_step": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspace2Domains.forward": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspace.forward": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBayesian.forward": {"tf": 1}, "shimmer.modules.domain.DomainModule.encode": {"tf": 1}, "shimmer.modules.domain.DomainModule.decode": {"tf": 1}, "shimmer.modules.domain.DomainModule.compute_loss": {"tf": 1}, "shimmer.modules.domain.DomainModule.compute_dcy_loss": {"tf": 1}, "shimmer.modules.domain.DomainModule.compute_cy_loss": {"tf": 1}, "shimmer.modules.domain.DomainModule.compute_tr_loss": {"tf": 1}, "shimmer.modules.domain.DomainModule.compute_broadcast_loss": {"tf": 1}, "shimmer.modules.gw_module.GWEncoder.forward": {"tf": 1}, "shimmer.modules.gw_module.GWEncoderLinear.forward": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBase.fuse": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBase.encode": {"tf": 1}, 
"shimmer.modules.gw_module.GWModuleBase.encode_and_fuse": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBase.decode": {"tf": 1}, "shimmer.modules.gw_module.GWModule.fuse": {"tf": 1}, "shimmer.modules.gw_module.GWModule.encode": {"tf": 1}, "shimmer.modules.gw_module.GWModule.decode": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBayesian.get_precision": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBayesian.fuse": {"tf": 1}, "shimmer.modules.selection.SelectionBase.update_gw_state": {"tf": 1}, "shimmer.modules.selection.SelectionBase.forward": {"tf": 1}, "shimmer.modules.selection.SingleDomainSelection.forward": {"tf": 1}, "shimmer.modules.selection.FixedSharedSelection.forward": {"tf": 1}, "shimmer.modules.selection.KQFixedQSelection.forward": {"tf": 1}, "shimmer.modules.selection.RandomSelection.forward": {"tf": 1}, "shimmer.modules.selection.DynamicQueryAttention.fuse_weighted_encodings": {"tf": 1}, "shimmer.modules.selection.DynamicQueryAttention.forward": {"tf": 1}, "shimmer.modules.losses.GWLossesBase.step": {"tf": 1}, "shimmer.modules.losses.GWLosses2Domains.demi_cycle_loss": {"tf": 1}, "shimmer.modules.losses.GWLosses2Domains.cycle_loss": {"tf": 1}, "shimmer.modules.losses.GWLosses2Domains.translation_loss": {"tf": 1}, "shimmer.modules.losses.GWLosses2Domains.contrastive_loss": {"tf": 1}, "shimmer.modules.losses.GWLosses2Domains.step": {"tf": 1}, "shimmer.modules.losses.GWLosses.contrastive_loss": {"tf": 1}, "shimmer.modules.losses.GWLosses.broadcast_loss": {"tf": 1}, "shimmer.modules.losses.GWLosses.step": {"tf": 1}, "shimmer.modules.losses.GWLossesBayesian.contrastive_loss": {"tf": 1}, "shimmer.modules.losses.GWLossesBayesian.broadcast_loss": {"tf": 1}, "shimmer.modules.losses.GWLossesBayesian.step": {"tf": 1}, "shimmer.modules.contrastive_loss.ContrastiveLoss.forward": {"tf": 1}, "shimmer.modules.vae.VAEEncoder.forward": {"tf": 1}, "shimmer.modules.vae.VAEDecoder.forward": {"tf": 1}, "shimmer.modules.vae.VAE.encode": {"tf": 1}, 
"shimmer.modules.vae.VAE.decode": {"tf": 1}, "shimmer.modules.vae.VAE.forward": {"tf": 1}, "shimmer.utils.SaveMigrations.on_save_checkpoint": {"tf": 1}}, "df": 58}, "e": {"docs": {}, "df": 0, "c": {"docs": {}, "df": 0, "t": {"docs": {}, "df": 0, "i": {"docs": {}, "df": 0, "o": {"docs": {}, "df": 0, "n": {"docs": {"shimmer.modules.global_workspace.GlobalWorkspaceBase.encode_and_fuse": {"tf": 1.4142135623730951}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.fuse": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspace.__init__": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBayesian.__init__": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBase.fuse": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBase.encode_and_fuse": {"tf": 1.4142135623730951}, "shimmer.modules.gw_module.GWModule.fuse": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBayesian.__init__": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBayesian.fuse": {"tf": 1}, "shimmer.modules.losses.demi_cycle_loss": {"tf": 1.4142135623730951}, "shimmer.modules.losses.cycle_loss": {"tf": 1.4142135623730951}, "shimmer.modules.losses.translation_loss": {"tf": 1.4142135623730951}, "shimmer.modules.losses.GWLosses2Domains.__init__": {"tf": 1.4142135623730951}, "shimmer.modules.losses.broadcast_loss": {"tf": 1.4142135623730951}, "shimmer.modules.losses.GWLosses.__init__": {"tf": 1.4142135623730951}, "shimmer.modules.losses.GWLossesBayesian.__init__": {"tf": 1.4142135623730951}, "shimmer.modules.utils.translation": {"tf": 1.4142135623730951}, "shimmer.modules.utils.cycle": {"tf": 1.4142135623730951}, "shimmer.modules.utils.batch_demi_cycles": {"tf": 1.4142135623730951}, "shimmer.modules.utils.batch_cycles": {"tf": 1.4142135623730951}, "shimmer.modules.utils.batch_translations": {"tf": 1.4142135623730951}}, "df": 21, "b": {"docs": {}, "df": 0, "a": {"docs": {}, "df": 0, "s": {"docs": {}, "df": 0, "e": {"docs": {"shimmer.modules.global_workspace.GlobalWorkspaceBase.encode_and_fuse": {"tf": 1}, 
"shimmer.modules.gw_module.GWModuleBase.encode_and_fuse": {"tf": 1}, "shimmer.modules.losses.demi_cycle_loss": {"tf": 1}, "shimmer.modules.losses.cycle_loss": {"tf": 1}, "shimmer.modules.losses.translation_loss": {"tf": 1}, "shimmer.modules.losses.GWLosses2Domains.__init__": {"tf": 1}, "shimmer.modules.losses.broadcast_loss": {"tf": 1}, "shimmer.modules.losses.GWLosses.__init__": {"tf": 1}, "shimmer.modules.losses.GWLossesBayesian.__init__": {"tf": 1}, "shimmer.modules.utils.translation": {"tf": 1}, "shimmer.modules.utils.cycle": {"tf": 1}, "shimmer.modules.utils.batch_demi_cycles": {"tf": 1}, "shimmer.modules.utils.batch_cycles": {"tf": 1}, "shimmer.modules.utils.batch_translations": {"tf": 1}}, "df": 14}}}}}}}}}}}, "n": {"docs": {}, "df": 0, "s": {"docs": {}, "df": 0, "i": {"docs": {}, "df": 0, "t": {"docs": {}, "df": 0, "i": {"docs": {}, "df": 0, "v": {"docs": {}, "df": 0, "i": {"docs": {}, "df": 0, "t": {"docs": {}, "df": 0, "y": {"docs": {"shimmer.modules.global_workspace.GlobalWorkspaceBayesian.__init__": {"tf": 1.4142135623730951}, "shimmer.modules.gw_module.compute_fusion_scores": {"tf": 1.4142135623730951}, "shimmer.modules.gw_module.GWModuleBayesian.__init__": {"tf": 1.4142135623730951}}, "df": 3}}}}}}}}}}, "t": {"docs": {}, "df": 0, "r": {"docs": {"shimmer.modules.global_workspace.GlobalWorkspaceBase.encode_and_fuse": {"tf": 1.7320508075688772}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.encode": {"tf": 2}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.fuse": {"tf": 2.23606797749979}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.decode": {"tf": 2}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.batch_gw_states": {"tf": 1.7320508075688772}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.encode_domain": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.encode_domains": {"tf": 2}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.decode_domain": {"tf": 1}, 
"shimmer.modules.global_workspace.GlobalWorkspaceBase.decode_domains": {"tf": 2}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.generic_step": {"tf": 1.4142135623730951}, "shimmer.modules.global_workspace.freeze_domain_modules": {"tf": 1.4142135623730951}, "shimmer.modules.global_workspace.GlobalWorkspace2Domains.__init__": {"tf": 1.7320508075688772}, "shimmer.modules.global_workspace.GlobalWorkspace2Domains.forward": {"tf": 1.4142135623730951}, "shimmer.modules.global_workspace.GlobalWorkspace.__init__": {"tf": 1.7320508075688772}, "shimmer.modules.global_workspace.GlobalWorkspace.forward": {"tf": 1.4142135623730951}, "shimmer.modules.global_workspace.GlobalWorkspaceBayesian.__init__": {"tf": 1.7320508075688772}, "shimmer.modules.global_workspace.GlobalWorkspaceBayesian.forward": {"tf": 1.4142135623730951}, "shimmer.modules.global_workspace.pretrained_global_workspace": {"tf": 2}, "shimmer.modules.domain.LossOutput.__init__": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBase.__init__": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBase.fuse": {"tf": 1.4142135623730951}, "shimmer.modules.gw_module.GWModuleBase.encode": {"tf": 1.4142135623730951}, "shimmer.modules.gw_module.GWModuleBase.encode_and_fuse": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBase.decode": {"tf": 1.4142135623730951}, "shimmer.modules.gw_module.GWModule.__init__": {"tf": 1.7320508075688772}, "shimmer.modules.gw_module.GWModule.fuse": {"tf": 1.4142135623730951}, "shimmer.modules.gw_module.GWModule.encode": {"tf": 1.4142135623730951}, "shimmer.modules.gw_module.GWModule.decode": {"tf": 1.4142135623730951}, "shimmer.modules.gw_module.GWModuleBayesian.__init__": {"tf": 1.7320508075688772}, "shimmer.modules.gw_module.GWModuleBayesian.get_precision": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBayesian.fuse": {"tf": 1.4142135623730951}, "shimmer.modules.selection.SelectionBase.forward": {"tf": 1.7320508075688772}, "shimmer.modules.selection.SingleDomainSelection.forward": {"tf": 
1.7320508075688772}, "shimmer.modules.selection.FixedSharedSelection.forward": {"tf": 1.7320508075688772}, "shimmer.modules.selection.KQFixedQSelection.__init__": {"tf": 1}, "shimmer.modules.selection.KQFixedQSelection.forward": {"tf": 1.7320508075688772}, "shimmer.modules.selection.RandomSelection.forward": {"tf": 1.7320508075688772}, "shimmer.modules.selection.DynamicQueryAttention.__init__": {"tf": 1}, "shimmer.modules.selection.DynamicQueryAttention.fuse_weighted_encodings": {"tf": 1.4142135623730951}, "shimmer.modules.selection.DynamicQueryAttention.forward": {"tf": 1.7320508075688772}, "shimmer.modules.losses.GWLossesBase.step": {"tf": 1.4142135623730951}, "shimmer.modules.losses.demi_cycle_loss": {"tf": 2}, "shimmer.modules.losses.cycle_loss": {"tf": 2}, "shimmer.modules.losses.translation_loss": {"tf": 2}, "shimmer.modules.losses.contrastive_loss": {"tf": 1.7320508075688772}, "shimmer.modules.losses.contrastive_loss_bayesian": {"tf": 1.7320508075688772}, "shimmer.modules.losses.GWLosses2Domains.__init__": {"tf": 1}, "shimmer.modules.losses.GWLosses2Domains.demi_cycle_loss": {"tf": 1.7320508075688772}, "shimmer.modules.losses.GWLosses2Domains.cycle_loss": {"tf": 1.7320508075688772}, "shimmer.modules.losses.GWLosses2Domains.translation_loss": {"tf": 1.7320508075688772}, "shimmer.modules.losses.GWLosses2Domains.contrastive_loss": {"tf": 1.7320508075688772}, "shimmer.modules.losses.GWLosses2Domains.step": {"tf": 1.4142135623730951}, "shimmer.modules.losses.broadcast_loss": {"tf": 2}, "shimmer.modules.losses.GWLosses.__init__": {"tf": 1}, "shimmer.modules.losses.GWLosses.contrastive_loss": {"tf": 1.7320508075688772}, "shimmer.modules.losses.GWLosses.broadcast_loss": {"tf": 1.7320508075688772}, "shimmer.modules.losses.GWLosses.step": {"tf": 1.4142135623730951}, "shimmer.modules.losses.GWLossesBayesian.__init__": {"tf": 1}, "shimmer.modules.losses.GWLossesBayesian.contrastive_loss": {"tf": 1.7320508075688772}, 
"shimmer.modules.losses.GWLossesBayesian.broadcast_loss": {"tf": 1.7320508075688772}, "shimmer.modules.losses.GWLossesBayesian.step": {"tf": 1.4142135623730951}, "shimmer.modules.utils.translation": {"tf": 1.4142135623730951}, "shimmer.modules.utils.cycle": {"tf": 1.7320508075688772}, "shimmer.modules.utils.batch_demi_cycles": {"tf": 1.7320508075688772}, "shimmer.modules.utils.batch_cycles": {"tf": 2.23606797749979}, "shimmer.modules.utils.batch_translations": {"tf": 2}, "shimmer.utils.group_batch_size": {"tf": 1}, "shimmer.utils.groups_batch_size": {"tf": 1.4142135623730951}, "shimmer.utils.groups_device": {"tf": 1.4142135623730951}, "shimmer.utils.group_device": {"tf": 1}, "shimmer.utils.migrate_model": {"tf": 1}, "shimmer.utils.SaveMigrations.on_save_checkpoint": {"tf": 1}}, "df": 72}, "a": {"docs": {}, "df": 0, "t": {"docs": {}, "df": 0, "e": {"docs": {"shimmer.modules.selection.SelectionBase.update_gw_state": {"tf": 1}}, "df": 1}}}}, "h": {"docs": {}, "df": 0, "i": {"docs": {}, "df": 0, "m": {"docs": {}, "df": 0, "m": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "r": {"docs": {"shimmer.modules.global_workspace.GlobalWorkspaceBase.encode_and_fuse": {"tf": 1}, "shimmer.modules.global_workspace.freeze_domain_modules": {"tf": 1.4142135623730951}, "shimmer.modules.global_workspace.GlobalWorkspace2Domains.__init__": {"tf": 2}, "shimmer.modules.global_workspace.GlobalWorkspace2Domains.forward": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspace.__init__": {"tf": 2}, "shimmer.modules.global_workspace.GlobalWorkspace.forward": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBayesian.__init__": {"tf": 2}, "shimmer.modules.global_workspace.GlobalWorkspaceBayesian.forward": {"tf": 1}, "shimmer.modules.global_workspace.pretrained_global_workspace": {"tf": 2}, "shimmer.modules.domain.DomainModule.compute_loss": {"tf": 1}, "shimmer.modules.domain.DomainModule.compute_dcy_loss": {"tf": 1}, "shimmer.modules.domain.DomainModule.compute_cy_loss": {"tf": 
1}, "shimmer.modules.domain.DomainModule.compute_tr_loss": {"tf": 1}, "shimmer.modules.domain.DomainModule.compute_broadcast_loss": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBase.__init__": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBase.encode_and_fuse": {"tf": 1}, "shimmer.modules.gw_module.GWModule.__init__": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBayesian.__init__": {"tf": 1}, "shimmer.modules.losses.GWLossesBase.step": {"tf": 1}, "shimmer.modules.losses.demi_cycle_loss": {"tf": 1.7320508075688772}, "shimmer.modules.losses.cycle_loss": {"tf": 1.7320508075688772}, "shimmer.modules.losses.translation_loss": {"tf": 1.7320508075688772}, "shimmer.modules.losses.contrastive_loss": {"tf": 1.4142135623730951}, "shimmer.modules.losses.contrastive_loss_bayesian": {"tf": 1.4142135623730951}, "shimmer.modules.losses.GWLosses2Domains.__init__": {"tf": 2.23606797749979}, "shimmer.modules.losses.GWLosses2Domains.step": {"tf": 1}, "shimmer.modules.losses.broadcast_loss": {"tf": 1.7320508075688772}, "shimmer.modules.losses.GWLosses.__init__": {"tf": 2.23606797749979}, "shimmer.modules.losses.GWLosses.step": {"tf": 1}, "shimmer.modules.losses.GWLossesBayesian.__init__": {"tf": 2.23606797749979}, "shimmer.modules.losses.GWLossesBayesian.step": {"tf": 1}, "shimmer.modules.contrastive_loss.ContrastiveLoss.forward": {"tf": 1}, "shimmer.dataset.RepeatedDataset.__init__": {"tf": 1}, "shimmer.modules.vae.VAE.__init__": {"tf": 1.4142135623730951}, "shimmer.modules.utils.translation": {"tf": 1.4142135623730951}, "shimmer.modules.utils.cycle": {"tf": 1.4142135623730951}, "shimmer.modules.utils.batch_demi_cycles": {"tf": 1.4142135623730951}, "shimmer.modules.utils.batch_cycles": {"tf": 1.4142135623730951}, "shimmer.modules.utils.batch_translations": {"tf": 1.4142135623730951}}, "df": 39}}}}}}, "c": {"docs": {}, "df": 0, "o": {"docs": {}, "df": 0, "r": {"docs": {}, "df": 0, "e": {"docs": {"shimmer.modules.gw_module.compute_fusion_scores": {"tf": 1.4142135623730951}}, "df": 1, 
"s": {"docs": {"shimmer.modules.global_workspace.GlobalWorkspaceBase.fuse": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBase.fuse": {"tf": 1}, "shimmer.modules.gw_module.GWModule.fuse": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBayesian.fuse": {"tf": 1}}, "df": 4}}}}, "h": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "d": {"docs": {}, "df": 0, "u": {"docs": {}, "df": 0, "l": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "r": {"docs": {"shimmer.modules.global_workspace.GlobalWorkspace2Domains.__init__": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspace.__init__": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBayesian.__init__": {"tf": 1}}, "df": 3, "a": {"docs": {}, "df": 0, "r": {"docs": {}, "df": 0, "g": {"docs": {}, "df": 0, "s": {"docs": {"shimmer.modules.global_workspace.GlobalWorkspace2Domains.__init__": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspace.__init__": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBayesian.__init__": {"tf": 1}}, "df": 3}}}}}}}}}}}, "a": {"docs": {}, "df": 0, "l": {"docs": {}, "df": 0, "e": {"docs": {"shimmer.modules.global_workspace.GlobalWorkspace2Domains.__init__": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspace.__init__": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBayesian.__init__": {"tf": 1}, "shimmer.modules.contrastive_loss.info_nce": {"tf": 1}, "shimmer.modules.contrastive_loss.contrastive_loss": {"tf": 1}, "shimmer.modules.contrastive_loss.ContrastiveLoss.__init__": {"tf": 1.4142135623730951}}, "df": 6}}}}, "o": {"docs": {}, "df": 0, "f": {"docs": {}, "df": 0, "t": {"docs": {}, "df": 0, "m": {"docs": {}, "df": 0, "a": {"docs": {}, "df": 0, "x": {"docs": {"shimmer.modules.global_workspace.GlobalWorkspaceBayesian.__init__": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBayesian.__init__": {"tf": 1}}, "df": 2}}}}}}, "i": {"docs": {}, "df": 0, "z": {"docs": {}, "df": 0, "e": {"docs": 
{"shimmer.modules.selection.KQFixedQSelection.__init__": {"tf": 1}, "shimmer.modules.selection.DynamicQueryAttention.__init__": {"tf": 1}, "shimmer.dataset.RepeatedDataset.__init__": {"tf": 1}}, "df": 3, "d": {"docs": {}, "df": 0, "d": {"docs": {}, "df": 0, "a": {"docs": {}, "df": 0, "t": {"docs": {}, "df": 0, "a": {"docs": {}, "df": 0, "s": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "t": {"docs": {"shimmer.dataset.RepeatedDataset.__init__": {"tf": 1}}, "df": 1}}}}}}}}}}, "g": {"docs": {}, "df": 0, "m": {"docs": {}, "df": 0, "a": {"docs": {"shimmer.modules.vae.gaussian_nll": {"tf": 1}}, "df": 1}}}}, "u": {"docs": {}, "df": 0, "m": {"docs": {"shimmer.modules.contrastive_loss.info_nce": {"tf": 1}, "shimmer.modules.contrastive_loss.contrastive_loss": {"tf": 1}, "shimmer.modules.contrastive_loss.ContrastiveLoss.__init__": {"tf": 1}}, "df": 3}}}, "x": {"docs": {"shimmer.modules.global_workspace.GlobalWorkspaceBase.encode_and_fuse": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.encode": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.fuse": {"tf": 1}, "shimmer.modules.domain.DomainModule.encode": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBase.fuse": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBase.encode": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBase.encode_and_fuse": {"tf": 1}, "shimmer.modules.gw_module.GWModule.fuse": {"tf": 1}, "shimmer.modules.gw_module.GWModule.encode": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBayesian.get_precision": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBayesian.fuse": {"tf": 1}, "shimmer.modules.contrastive_loss.info_nce": {"tf": 1}, "shimmer.modules.contrastive_loss.contrastive_loss": {"tf": 1}, "shimmer.modules.contrastive_loss.ContrastiveLoss.forward": {"tf": 1}, "shimmer.modules.vae.gaussian_nll": {"tf": 1}, "shimmer.modules.vae.VAEEncoder.forward": {"tf": 1}, "shimmer.modules.vae.VAEDecoder.forward": {"tf": 1}, "shimmer.modules.vae.VAE.encode": {"tf": 1}, 
"shimmer.modules.vae.VAE.forward": {"tf": 1}, "shimmer.modules.utils.translation": {"tf": 1}, "shimmer.modules.utils.cycle": {"tf": 1}, "shimmer.utils.group_batch_size": {"tf": 1}, "shimmer.utils.group_device": {"tf": 1}}, "df": 23}, "c": {"docs": {}, "df": 0, "o": {"docs": {}, "df": 0, "l": {"docs": {}, "df": 0, "l": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "c": {"docs": {}, "df": 0, "t": {"docs": {}, "df": 0, "i": {"docs": {}, "df": 0, "o": {"docs": {}, "df": 0, "n": {"docs": {}, "df": 0, "s": {"docs": {"shimmer.modules.global_workspace.GlobalWorkspaceBase.encode_and_fuse": {"tf": 1.4142135623730951}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.encode": {"tf": 1.4142135623730951}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.fuse": {"tf": 2}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.decode": {"tf": 1.4142135623730951}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.batch_gw_states": {"tf": 1.4142135623730951}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.encode_domains": {"tf": 1.4142135623730951}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.decode_domains": {"tf": 1.4142135623730951}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.generic_step": {"tf": 1.4142135623730951}, "shimmer.modules.global_workspace.freeze_domain_modules": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspace2Domains.__init__": {"tf": 2}, "shimmer.modules.global_workspace.GlobalWorkspace2Domains.forward": {"tf": 1.4142135623730951}, "shimmer.modules.global_workspace.GlobalWorkspace.__init__": {"tf": 2}, "shimmer.modules.global_workspace.GlobalWorkspace.forward": {"tf": 1.4142135623730951}, "shimmer.modules.global_workspace.GlobalWorkspaceBayesian.__init__": {"tf": 2}, "shimmer.modules.global_workspace.GlobalWorkspaceBayesian.forward": {"tf": 1.4142135623730951}, "shimmer.modules.global_workspace.pretrained_global_workspace": {"tf": 2}, "shimmer.modules.gw_module.GWModuleBase.__init__": {"tf": 1}, 
"shimmer.modules.gw_module.GWModuleBase.fuse": {"tf": 1.4142135623730951}, "shimmer.modules.gw_module.GWModuleBase.encode": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBase.encode_and_fuse": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBase.decode": {"tf": 1}, "shimmer.modules.gw_module.GWModule.__init__": {"tf": 1.7320508075688772}, "shimmer.modules.gw_module.GWModule.fuse": {"tf": 1.4142135623730951}, "shimmer.modules.gw_module.GWModule.encode": {"tf": 1}, "shimmer.modules.gw_module.GWModule.decode": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBayesian.__init__": {"tf": 1.7320508075688772}, "shimmer.modules.gw_module.GWModuleBayesian.fuse": {"tf": 1.4142135623730951}, "shimmer.modules.selection.SelectionBase.forward": {"tf": 1.4142135623730951}, "shimmer.modules.selection.SingleDomainSelection.forward": {"tf": 1.4142135623730951}, "shimmer.modules.selection.FixedSharedSelection.forward": {"tf": 1.4142135623730951}, "shimmer.modules.selection.KQFixedQSelection.__init__": {"tf": 1}, "shimmer.modules.selection.KQFixedQSelection.forward": {"tf": 1.4142135623730951}, "shimmer.modules.selection.RandomSelection.forward": {"tf": 1.4142135623730951}, "shimmer.modules.selection.DynamicQueryAttention.__init__": {"tf": 1}, "shimmer.modules.selection.DynamicQueryAttention.fuse_weighted_encodings": {"tf": 1}, "shimmer.modules.selection.DynamicQueryAttention.forward": {"tf": 1.4142135623730951}, "shimmer.modules.losses.GWLossesBase.step": {"tf": 1.4142135623730951}, "shimmer.modules.losses.demi_cycle_loss": {"tf": 1.7320508075688772}, "shimmer.modules.losses.cycle_loss": {"tf": 1.7320508075688772}, "shimmer.modules.losses.translation_loss": {"tf": 1.7320508075688772}, "shimmer.modules.losses.contrastive_loss": {"tf": 1.7320508075688772}, "shimmer.modules.losses.contrastive_loss_bayesian": {"tf": 1.7320508075688772}, "shimmer.modules.losses.GWLosses2Domains.__init__": {"tf": 1}, "shimmer.modules.losses.GWLosses2Domains.demi_cycle_loss": {"tf": 1.4142135623730951}, 
"shimmer.modules.losses.GWLosses2Domains.cycle_loss": {"tf": 1.4142135623730951}, "shimmer.modules.losses.GWLosses2Domains.translation_loss": {"tf": 1.4142135623730951}, "shimmer.modules.losses.GWLosses2Domains.contrastive_loss": {"tf": 1.4142135623730951}, "shimmer.modules.losses.GWLosses2Domains.step": {"tf": 1.4142135623730951}, "shimmer.modules.losses.generate_partitions": {"tf": 1}, "shimmer.modules.losses.broadcast_loss": {"tf": 1.7320508075688772}, "shimmer.modules.losses.GWLosses.__init__": {"tf": 1}, "shimmer.modules.losses.GWLosses.contrastive_loss": {"tf": 1.4142135623730951}, "shimmer.modules.losses.GWLosses.broadcast_loss": {"tf": 1.4142135623730951}, "shimmer.modules.losses.GWLosses.step": {"tf": 1.4142135623730951}, "shimmer.modules.losses.GWLossesBayesian.__init__": {"tf": 1}, "shimmer.modules.losses.GWLossesBayesian.contrastive_loss": {"tf": 1.4142135623730951}, "shimmer.modules.losses.GWLossesBayesian.broadcast_loss": {"tf": 1.4142135623730951}, "shimmer.modules.losses.GWLossesBayesian.step": {"tf": 1.4142135623730951}, "shimmer.modules.utils.translation": {"tf": 1}, "shimmer.modules.utils.cycle": {"tf": 1}, "shimmer.modules.utils.batch_demi_cycles": {"tf": 1.4142135623730951}, "shimmer.modules.utils.batch_cycles": {"tf": 1.7320508075688772}, "shimmer.modules.utils.batch_translations": {"tf": 1.4142135623730951}, "shimmer.utils.group_batch_size": {"tf": 1}, "shimmer.utils.groups_batch_size": {"tf": 1.4142135623730951}, "shimmer.utils.groups_device": {"tf": 1.4142135623730951}, "shimmer.utils.group_device": {"tf": 1}}, "df": 67}}}}}}}}}, "e": {"docs": {}, "df": 0, "f": {"docs": {}, "df": 0, "s": {"docs": {"shimmer.modules.global_workspace.GlobalWorkspace2Domains.__init__": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspace.__init__": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBayesian.__init__": {"tf": 1}, "shimmer.modules.global_workspace.pretrained_global_workspace": {"tf": 1}, 
"shimmer.modules.losses.GWLosses2Domains.__init__": {"tf": 1}, "shimmer.modules.losses.GWLosses.__init__": {"tf": 1}, "shimmer.modules.losses.GWLossesBayesian.__init__": {"tf": 1}}, "df": 7}}}, "n": {"docs": {}, "df": 0, "t": {"docs": {}, "df": 0, "r": {"docs": {}, "df": 0, "a": {"docs": {}, "df": 0, "s": {"docs": {}, "df": 0, "t": {"docs": {}, "df": 0, "i": {"docs": {}, "df": 0, "v": {"docs": {}, "df": 0, "e": {"docs": {"shimmer.modules.global_workspace.GlobalWorkspace2Domains.__init__": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspace.__init__": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBayesian.__init__": {"tf": 1}, "shimmer.modules.global_workspace.pretrained_global_workspace": {"tf": 1}, "shimmer.modules.losses.contrastive_loss": {"tf": 1}, "shimmer.modules.losses.contrastive_loss_bayesian": {"tf": 1}, "shimmer.modules.losses.GWLosses2Domains.__init__": {"tf": 1}, "shimmer.modules.losses.GWLosses.__init__": {"tf": 1}, "shimmer.modules.losses.GWLossesBayesian.__init__": {"tf": 1}}, "df": 9}}}}}}}}, "s": {"docs": {}, "df": 0, "t": {"docs": {}, "df": 0, "r": {"docs": {}, "df": 0, "a": {"docs": {}, "df": 0, "s": {"docs": {}, "df": 0, "t": {"docs": {}, "df": 0, "i": {"docs": {}, "df": 0, "v": {"docs": {}, "df": 0, "e": {"docs": {"shimmer.modules.global_workspace.GlobalWorkspaceBayesian.__init__": {"tf": 1}, "shimmer.modules.losses.GWLossesBayesian.__init__": {"tf": 1}}, "df": 2}}}}}}}}}}, "r": {"docs": {}, "df": 0, "e": {"docs": {"shimmer.utils.SaveMigrations.on_save_checkpoint": {"tf": 1}}, "df": 1}}}, "a": {"docs": {}, "df": 0, "l": {"docs": {}, "df": 0, "l": {"docs": {}, "df": 0, "a": {"docs": {}, "df": 0, "b": {"docs": {}, "df": 0, "l": {"docs": {}, "df": 0, "e": {"docs": {"shimmer.modules.global_workspace.GlobalWorkspace2Domains.__init__": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspace.__init__": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBayesian.__init__": {"tf": 1}, 
"shimmer.modules.global_workspace.pretrained_global_workspace": {"tf": 1}, "shimmer.modules.losses.contrastive_loss": {"tf": 1}, "shimmer.modules.losses.contrastive_loss_bayesian": {"tf": 1}, "shimmer.modules.losses.GWLosses2Domains.__init__": {"tf": 1}, "shimmer.modules.losses.GWLosses.__init__": {"tf": 1}, "shimmer.modules.losses.GWLossesBayesian.__init__": {"tf": 1}}, "df": 9}}}}}}}, "h": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "c": {"docs": {}, "df": 0, "k": {"docs": {}, "df": 0, "p": {"docs": {}, "df": 0, "o": {"docs": {}, "df": 0, "i": {"docs": {}, "df": 0, "n": {"docs": {}, "df": 0, "t": {"docs": {"shimmer.modules.global_workspace.pretrained_global_workspace": {"tf": 1}, "shimmer.utils.SaveMigrations.on_save_checkpoint": {"tf": 1}}, "df": 2}}}}}}}}}, "k": {"docs": {}, "df": 0, "p": {"docs": {}, "df": 0, "t": {"docs": {"shimmer.utils.migrate_model": {"tf": 1}}, "df": 1}}}}, "a": {"docs": {}, "df": 0, "b": {"docs": {}, "df": 0, "c": {"docs": {"shimmer.modules.global_workspace.GlobalWorkspaceBase.encode_and_fuse": {"tf": 1.4142135623730951}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.encode": {"tf": 1.4142135623730951}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.fuse": {"tf": 2}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.decode": {"tf": 1.4142135623730951}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.batch_gw_states": {"tf": 1.4142135623730951}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.encode_domains": {"tf": 1.4142135623730951}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.decode_domains": {"tf": 1.4142135623730951}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.generic_step": {"tf": 1.4142135623730951}, "shimmer.modules.global_workspace.freeze_domain_modules": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspace2Domains.__init__": {"tf": 2}, "shimmer.modules.global_workspace.GlobalWorkspace2Domains.forward": {"tf": 1.4142135623730951}, 
"shimmer.modules.global_workspace.GlobalWorkspace.__init__": {"tf": 2}, "shimmer.modules.global_workspace.GlobalWorkspace.forward": {"tf": 1.4142135623730951}, "shimmer.modules.global_workspace.GlobalWorkspaceBayesian.__init__": {"tf": 2}, "shimmer.modules.global_workspace.GlobalWorkspaceBayesian.forward": {"tf": 1.4142135623730951}, "shimmer.modules.global_workspace.pretrained_global_workspace": {"tf": 2}, "shimmer.modules.gw_module.GWModuleBase.__init__": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBase.fuse": {"tf": 1.4142135623730951}, "shimmer.modules.gw_module.GWModuleBase.encode": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBase.encode_and_fuse": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBase.decode": {"tf": 1}, "shimmer.modules.gw_module.GWModule.__init__": {"tf": 1.7320508075688772}, "shimmer.modules.gw_module.GWModule.fuse": {"tf": 1.4142135623730951}, "shimmer.modules.gw_module.GWModule.encode": {"tf": 1}, "shimmer.modules.gw_module.GWModule.decode": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBayesian.__init__": {"tf": 1.7320508075688772}, "shimmer.modules.gw_module.GWModuleBayesian.fuse": {"tf": 1.4142135623730951}, "shimmer.modules.selection.SelectionBase.forward": {"tf": 1.4142135623730951}, "shimmer.modules.selection.SingleDomainSelection.forward": {"tf": 1.4142135623730951}, "shimmer.modules.selection.FixedSharedSelection.forward": {"tf": 1.4142135623730951}, "shimmer.modules.selection.KQFixedQSelection.__init__": {"tf": 1}, "shimmer.modules.selection.KQFixedQSelection.forward": {"tf": 1.4142135623730951}, "shimmer.modules.selection.RandomSelection.forward": {"tf": 1.4142135623730951}, "shimmer.modules.selection.DynamicQueryAttention.__init__": {"tf": 1}, "shimmer.modules.selection.DynamicQueryAttention.fuse_weighted_encodings": {"tf": 1}, "shimmer.modules.selection.DynamicQueryAttention.forward": {"tf": 1.4142135623730951}, "shimmer.modules.losses.GWLossesBase.step": {"tf": 1.4142135623730951}, "shimmer.modules.losses.demi_cycle_loss": 
{"tf": 1.7320508075688772}, "shimmer.modules.losses.cycle_loss": {"tf": 1.7320508075688772}, "shimmer.modules.losses.translation_loss": {"tf": 1.7320508075688772}, "shimmer.modules.losses.contrastive_loss": {"tf": 1.7320508075688772}, "shimmer.modules.losses.contrastive_loss_bayesian": {"tf": 1.7320508075688772}, "shimmer.modules.losses.GWLosses2Domains.__init__": {"tf": 1}, "shimmer.modules.losses.GWLosses2Domains.demi_cycle_loss": {"tf": 1.4142135623730951}, "shimmer.modules.losses.GWLosses2Domains.cycle_loss": {"tf": 1.4142135623730951}, "shimmer.modules.losses.GWLosses2Domains.translation_loss": {"tf": 1.4142135623730951}, "shimmer.modules.losses.GWLosses2Domains.contrastive_loss": {"tf": 1.4142135623730951}, "shimmer.modules.losses.GWLosses2Domains.step": {"tf": 1.4142135623730951}, "shimmer.modules.losses.generate_partitions": {"tf": 1}, "shimmer.modules.losses.broadcast_loss": {"tf": 1.7320508075688772}, "shimmer.modules.losses.GWLosses.__init__": {"tf": 1}, "shimmer.modules.losses.GWLosses.contrastive_loss": {"tf": 1.4142135623730951}, "shimmer.modules.losses.GWLosses.broadcast_loss": {"tf": 1.4142135623730951}, "shimmer.modules.losses.GWLosses.step": {"tf": 1.4142135623730951}, "shimmer.modules.losses.GWLossesBayesian.__init__": {"tf": 1}, "shimmer.modules.losses.GWLossesBayesian.contrastive_loss": {"tf": 1.4142135623730951}, "shimmer.modules.losses.GWLossesBayesian.broadcast_loss": {"tf": 1.4142135623730951}, "shimmer.modules.losses.GWLossesBayesian.step": {"tf": 1.4142135623730951}, "shimmer.modules.utils.translation": {"tf": 1}, "shimmer.modules.utils.cycle": {"tf": 1}, "shimmer.modules.utils.batch_demi_cycles": {"tf": 1.4142135623730951}, "shimmer.modules.utils.batch_cycles": {"tf": 1.7320508075688772}, "shimmer.modules.utils.batch_translations": {"tf": 1.4142135623730951}, "shimmer.utils.group_batch_size": {"tf": 1}, "shimmer.utils.groups_batch_size": {"tf": 1.4142135623730951}, "shimmer.utils.groups_device": {"tf": 1.4142135623730951}, 
"shimmer.utils.group_device": {"tf": 1}}, "df": 67}}, "n": {"docs": {}, "df": 0, "y": {"docs": {"shimmer.modules.global_workspace.GlobalWorkspaceBase.encode_domain": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.encode_domains": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.decode_domain": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.decode_domains": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.generic_step": {"tf": 1}, "shimmer.modules.domain.DomainModule.encode": {"tf": 1}, "shimmer.modules.domain.DomainModule.decode": {"tf": 1}, "shimmer.modules.vae.VAEEncoder.forward": {"tf": 1}, "shimmer.modules.vae.VAEDecoder.forward": {"tf": 1}, "shimmer.modules.vae.VAE.encode": {"tf": 1}, "shimmer.modules.vae.VAE.decode": {"tf": 1}, "shimmer.modules.vae.VAE.forward": {"tf": 1.4142135623730951}, "shimmer.utils.SaveMigrations.on_save_checkpoint": {"tf": 1}}, "df": 13}}, "r": {"docs": {}, "df": 0, "g": {"docs": {}, "df": 0, "s": {"docs": {"shimmer.modules.global_workspace.GlobalWorkspace2Domains.__init__": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspace.__init__": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBayesian.__init__": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBase.__init__": {"tf": 1}}, "df": 4}}}, "t": {"docs": {}, "df": 0, "t": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "n": {"docs": {}, "df": 0, "t": {"docs": {}, "df": 0, "i": {"docs": {}, "df": 0, "o": {"docs": {}, "df": 0, "n": {"docs": {"shimmer.modules.selection.DynamicQueryAttention.fuse_weighted_encodings": {"tf": 1}}, "df": 1}}}}}}}}}, "m": {"docs": {}, "df": 0, "a": {"docs": {}, "df": 0, "p": {"docs": {}, "df": 0, "p": {"docs": {}, "df": 0, "i": {"docs": {}, "df": 0, "n": {"docs": {}, "df": 0, "g": {"docs": {"shimmer.modules.global_workspace.GlobalWorkspaceBase.encode_and_fuse": {"tf": 1.4142135623730951}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.encode": {"tf": 
1.4142135623730951}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.fuse": {"tf": 2}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.decode": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.batch_gw_states": {"tf": 1.4142135623730951}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.encode_domains": {"tf": 1.4142135623730951}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.decode_domains": {"tf": 1.4142135623730951}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.generic_step": {"tf": 1.4142135623730951}, "shimmer.modules.global_workspace.freeze_domain_modules": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspace2Domains.__init__": {"tf": 1.7320508075688772}, "shimmer.modules.global_workspace.GlobalWorkspace2Domains.forward": {"tf": 1.4142135623730951}, "shimmer.modules.global_workspace.GlobalWorkspace.__init__": {"tf": 1.7320508075688772}, "shimmer.modules.global_workspace.GlobalWorkspace.forward": {"tf": 1.4142135623730951}, "shimmer.modules.global_workspace.GlobalWorkspaceBayesian.__init__": {"tf": 1.7320508075688772}, "shimmer.modules.global_workspace.GlobalWorkspaceBayesian.forward": {"tf": 1.4142135623730951}, "shimmer.modules.global_workspace.pretrained_global_workspace": {"tf": 1.7320508075688772}, "shimmer.modules.gw_module.GWModuleBase.__init__": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBase.fuse": {"tf": 1.4142135623730951}, "shimmer.modules.gw_module.GWModuleBase.encode": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBase.encode_and_fuse": {"tf": 1}, "shimmer.modules.gw_module.GWModule.__init__": {"tf": 1.7320508075688772}, "shimmer.modules.gw_module.GWModule.fuse": {"tf": 1.4142135623730951}, "shimmer.modules.gw_module.GWModule.encode": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBayesian.__init__": {"tf": 1.7320508075688772}, "shimmer.modules.gw_module.GWModuleBayesian.fuse": {"tf": 1.4142135623730951}, "shimmer.modules.selection.SelectionBase.forward": {"tf": 1.4142135623730951}, 
"shimmer.modules.selection.SingleDomainSelection.forward": {"tf": 1.4142135623730951}, "shimmer.modules.selection.FixedSharedSelection.forward": {"tf": 1.4142135623730951}, "shimmer.modules.selection.KQFixedQSelection.forward": {"tf": 1.4142135623730951}, "shimmer.modules.selection.RandomSelection.forward": {"tf": 1.4142135623730951}, "shimmer.modules.selection.DynamicQueryAttention.fuse_weighted_encodings": {"tf": 1}, "shimmer.modules.selection.DynamicQueryAttention.forward": {"tf": 1.4142135623730951}, "shimmer.modules.losses.GWLossesBase.step": {"tf": 1.4142135623730951}, "shimmer.modules.losses.demi_cycle_loss": {"tf": 1.7320508075688772}, "shimmer.modules.losses.cycle_loss": {"tf": 1.7320508075688772}, "shimmer.modules.losses.translation_loss": {"tf": 1.7320508075688772}, "shimmer.modules.losses.contrastive_loss": {"tf": 1.4142135623730951}, "shimmer.modules.losses.contrastive_loss_bayesian": {"tf": 1.4142135623730951}, "shimmer.modules.losses.GWLosses2Domains.demi_cycle_loss": {"tf": 1.4142135623730951}, "shimmer.modules.losses.GWLosses2Domains.cycle_loss": {"tf": 1.4142135623730951}, "shimmer.modules.losses.GWLosses2Domains.translation_loss": {"tf": 1.4142135623730951}, "shimmer.modules.losses.GWLosses2Domains.contrastive_loss": {"tf": 1.4142135623730951}, "shimmer.modules.losses.GWLosses2Domains.step": {"tf": 1.4142135623730951}, "shimmer.modules.losses.broadcast_loss": {"tf": 1.7320508075688772}, "shimmer.modules.losses.GWLosses.contrastive_loss": {"tf": 1.4142135623730951}, "shimmer.modules.losses.GWLosses.broadcast_loss": {"tf": 1.4142135623730951}, "shimmer.modules.losses.GWLosses.step": {"tf": 1.4142135623730951}, "shimmer.modules.losses.GWLossesBayesian.contrastive_loss": {"tf": 1.4142135623730951}, "shimmer.modules.losses.GWLossesBayesian.broadcast_loss": {"tf": 1.4142135623730951}, "shimmer.modules.losses.GWLossesBayesian.step": {"tf": 1.4142135623730951}, "shimmer.modules.utils.translation": {"tf": 1}, "shimmer.modules.utils.cycle": {"tf": 1}, 
"shimmer.modules.utils.batch_demi_cycles": {"tf": 1.4142135623730951}, "shimmer.modules.utils.batch_cycles": {"tf": 1.4142135623730951}, "shimmer.modules.utils.batch_translations": {"tf": 1.4142135623730951}, "shimmer.utils.group_batch_size": {"tf": 1}, "shimmer.utils.groups_batch_size": {"tf": 1.4142135623730951}, "shimmer.utils.groups_device": {"tf": 1.4142135623730951}, "shimmer.utils.group_device": {"tf": 1}}, "df": 59}}}}}}, "o": {"docs": {}, "df": 0, "d": {"docs": {"shimmer.modules.losses.demi_cycle_loss": {"tf": 1.4142135623730951}, "shimmer.modules.losses.cycle_loss": {"tf": 1.4142135623730951}, "shimmer.modules.losses.translation_loss": {"tf": 1.4142135623730951}, "shimmer.modules.losses.contrastive_loss": {"tf": 1}, "shimmer.modules.losses.contrastive_loss_bayesian": {"tf": 1}, "shimmer.modules.losses.GWLosses2Domains.__init__": {"tf": 1.4142135623730951}, "shimmer.modules.losses.broadcast_loss": {"tf": 1.4142135623730951}, "shimmer.modules.losses.GWLosses.__init__": {"tf": 1.4142135623730951}, "shimmer.modules.losses.GWLossesBayesian.__init__": {"tf": 1.4142135623730951}, "shimmer.modules.utils.translation": {"tf": 1}, "shimmer.modules.utils.cycle": {"tf": 1}, "shimmer.modules.utils.batch_demi_cycles": {"tf": 1.4142135623730951}, "shimmer.modules.utils.batch_cycles": {"tf": 1.4142135623730951}, "shimmer.modules.utils.batch_translations": {"tf": 1.4142135623730951}}, "df": 14, "u": {"docs": {}, "df": 0, "l": {"docs": {}, "df": 0, "e": {"docs": {"shimmer.modules.global_workspace.GlobalWorkspaceBase.encode_and_fuse": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspace2Domains.__init__": {"tf": 2}, "shimmer.modules.global_workspace.GlobalWorkspace.__init__": {"tf": 2}, "shimmer.modules.global_workspace.GlobalWorkspaceBayesian.__init__": {"tf": 2}, "shimmer.modules.global_workspace.pretrained_global_workspace": {"tf": 2}, "shimmer.modules.gw_module.get_n_layers": {"tf": 1.4142135623730951}, "shimmer.modules.gw_module.GWModuleBase.encode_and_fuse": 
{"tf": 1}, "shimmer.modules.gw_module.GWModule.__init__": {"tf": 2}, "shimmer.modules.gw_module.GWModuleBayesian.__init__": {"tf": 2}, "shimmer.modules.losses.demi_cycle_loss": {"tf": 1}, "shimmer.modules.losses.cycle_loss": {"tf": 1}, "shimmer.modules.losses.translation_loss": {"tf": 1}, "shimmer.modules.losses.contrastive_loss": {"tf": 1}, "shimmer.modules.losses.contrastive_loss_bayesian": {"tf": 1}, "shimmer.modules.losses.GWLosses2Domains.__init__": {"tf": 1}, "shimmer.modules.losses.broadcast_loss": {"tf": 1}, "shimmer.modules.losses.GWLosses.__init__": {"tf": 1}, "shimmer.modules.losses.GWLossesBayesian.__init__": {"tf": 1}, "shimmer.modules.utils.translation": {"tf": 1.4142135623730951}, "shimmer.modules.utils.cycle": {"tf": 1.4142135623730951}, "shimmer.modules.utils.batch_demi_cycles": {"tf": 1}, "shimmer.modules.utils.batch_cycles": {"tf": 1}, "shimmer.modules.utils.batch_translations": {"tf": 1}, "shimmer.utils.SaveMigrations.on_save_checkpoint": {"tf": 1.4142135623730951}}, "df": 24, "s": {"docs": {"shimmer.modules.global_workspace.GlobalWorkspaceBase.encode_and_fuse": {"tf": 1}, "shimmer.modules.global_workspace.freeze_domain_modules": {"tf": 1.4142135623730951}, "shimmer.modules.global_workspace.GlobalWorkspace2Domains.__init__": {"tf": 2.449489742783178}, "shimmer.modules.global_workspace.GlobalWorkspace2Domains.forward": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspace.__init__": {"tf": 2.449489742783178}, "shimmer.modules.global_workspace.GlobalWorkspace.forward": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBayesian.__init__": {"tf": 2.449489742783178}, "shimmer.modules.global_workspace.GlobalWorkspaceBayesian.forward": {"tf": 1}, "shimmer.modules.global_workspace.pretrained_global_workspace": {"tf": 2.449489742783178}, "shimmer.modules.domain.DomainModule.compute_loss": {"tf": 1}, "shimmer.modules.domain.DomainModule.compute_dcy_loss": {"tf": 1}, "shimmer.modules.domain.DomainModule.compute_cy_loss": {"tf": 1}, 
"shimmer.modules.domain.DomainModule.compute_tr_loss": {"tf": 1}, "shimmer.modules.domain.DomainModule.compute_broadcast_loss": {"tf": 1}, "shimmer.modules.gw_module.get_n_layers": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBase.__init__": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBase.encode_and_fuse": {"tf": 1}, "shimmer.modules.gw_module.GWModule.__init__": {"tf": 2}, "shimmer.modules.gw_module.GWModuleBayesian.__init__": {"tf": 2}, "shimmer.modules.losses.GWLossesBase.step": {"tf": 1}, "shimmer.modules.losses.demi_cycle_loss": {"tf": 1.7320508075688772}, "shimmer.modules.losses.cycle_loss": {"tf": 1.7320508075688772}, "shimmer.modules.losses.translation_loss": {"tf": 1.7320508075688772}, "shimmer.modules.losses.contrastive_loss": {"tf": 1.4142135623730951}, "shimmer.modules.losses.contrastive_loss_bayesian": {"tf": 1.4142135623730951}, "shimmer.modules.losses.GWLosses2Domains.__init__": {"tf": 2.23606797749979}, "shimmer.modules.losses.GWLosses2Domains.step": {"tf": 1}, "shimmer.modules.losses.broadcast_loss": {"tf": 1.7320508075688772}, "shimmer.modules.losses.GWLosses.__init__": {"tf": 2.23606797749979}, "shimmer.modules.losses.GWLosses.step": {"tf": 1}, "shimmer.modules.losses.GWLossesBayesian.__init__": {"tf": 2.23606797749979}, "shimmer.modules.losses.GWLossesBayesian.step": {"tf": 1}, "shimmer.modules.contrastive_loss.ContrastiveLoss.forward": {"tf": 1}, "shimmer.modules.vae.VAE.__init__": {"tf": 1.4142135623730951}, "shimmer.modules.utils.translation": {"tf": 1.4142135623730951}, "shimmer.modules.utils.cycle": {"tf": 1.4142135623730951}, "shimmer.modules.utils.batch_demi_cycles": {"tf": 1.4142135623730951}, "shimmer.modules.utils.batch_cycles": {"tf": 1.4142135623730951}, "shimmer.modules.utils.batch_translations": {"tf": 1.4142135623730951}}, "df": 39}}}}, "e": {"docs": {"shimmer.modules.global_workspace.GlobalWorkspaceBase.generic_step": {"tf": 1}, "shimmer.modules.losses.GWLossesBase.step": {"tf": 1}, 
"shimmer.modules.losses.GWLosses2Domains.step": {"tf": 1}, "shimmer.modules.losses.GWLosses.step": {"tf": 1}, "shimmer.modules.losses.GWLossesBayesian.step": {"tf": 1}}, "df": 5}, "s": {"docs": {"shimmer.modules.global_workspace.freeze_domain_modules": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspace2Domains.__init__": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspace.__init__": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBayesian.__init__": {"tf": 1}, "shimmer.modules.global_workspace.pretrained_global_workspace": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBase.__init__": {"tf": 1}, "shimmer.modules.losses.demi_cycle_loss": {"tf": 1}, "shimmer.modules.losses.cycle_loss": {"tf": 1}, "shimmer.modules.losses.translation_loss": {"tf": 1}, "shimmer.modules.losses.GWLosses2Domains.__init__": {"tf": 1}, "shimmer.modules.losses.broadcast_loss": {"tf": 1}, "shimmer.modules.losses.GWLosses.__init__": {"tf": 1}, "shimmer.modules.losses.GWLossesBayesian.__init__": {"tf": 1}}, "df": 13}}}, "e": {"docs": {}, "df": 0, "t": {"docs": {}, "df": 0, "r": {"docs": {}, "df": 0, "i": {"docs": {}, "df": 0, "c": {"docs": {}, "df": 0, "s": {"docs": {"shimmer.modules.domain.LossOutput.__init__": {"tf": 1}}, "df": 1}}}}}, "a": {"docs": {}, "df": 0, "n": {"docs": {"shimmer.modules.contrastive_loss.info_nce": {"tf": 1.4142135623730951}, "shimmer.modules.contrastive_loss.contrastive_loss": {"tf": 1.4142135623730951}, "shimmer.modules.contrastive_loss.ContrastiveLoss.__init__": {"tf": 1.4142135623730951}, "shimmer.modules.vae.reparameterize": {"tf": 1}, "shimmer.modules.vae.kl_divergence_loss": {"tf": 1}}, "df": 5}}}, "i": {"docs": {}, "df": 0, "n": {"docs": {"shimmer.dataset.RepeatedDataset.__init__": {"tf": 1}}, "df": 1}}, "u": {"docs": {"shimmer.modules.vae.gaussian_nll": {"tf": 1}}, "df": 1}}, "f": {"docs": {}, "df": 0, "r": {"docs": {}, "df": 0, "o": {"docs": {}, "df": 0, "z": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "n": {"docs": {}, "df": 
0, "s": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "t": {"docs": {"shimmer.modules.global_workspace.GlobalWorkspaceBase.encode_and_fuse": {"tf": 1.4142135623730951}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.encode": {"tf": 1.4142135623730951}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.fuse": {"tf": 1.7320508075688772}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.decode": {"tf": 1.4142135623730951}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.batch_gw_states": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.encode_domains": {"tf": 1.4142135623730951}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.decode_domains": {"tf": 1.4142135623730951}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.generic_step": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspace2Domains.forward": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspace.forward": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBayesian.forward": {"tf": 1}, "shimmer.modules.losses.GWLossesBase.step": {"tf": 1}, "shimmer.modules.losses.demi_cycle_loss": {"tf": 1}, "shimmer.modules.losses.cycle_loss": {"tf": 1}, "shimmer.modules.losses.translation_loss": {"tf": 1}, "shimmer.modules.losses.contrastive_loss": {"tf": 1}, "shimmer.modules.losses.contrastive_loss_bayesian": {"tf": 1}, "shimmer.modules.losses.GWLosses2Domains.demi_cycle_loss": {"tf": 1}, "shimmer.modules.losses.GWLosses2Domains.cycle_loss": {"tf": 1}, "shimmer.modules.losses.GWLosses2Domains.translation_loss": {"tf": 1}, "shimmer.modules.losses.GWLosses2Domains.contrastive_loss": {"tf": 1}, "shimmer.modules.losses.GWLosses2Domains.step": {"tf": 1}, "shimmer.modules.losses.broadcast_loss": {"tf": 1}, "shimmer.modules.losses.GWLosses.contrastive_loss": {"tf": 1}, "shimmer.modules.losses.GWLosses.broadcast_loss": {"tf": 1}, "shimmer.modules.losses.GWLosses.step": {"tf": 1}, "shimmer.modules.losses.GWLossesBayesian.contrastive_loss": {"tf": 1}, 
"shimmer.modules.losses.GWLossesBayesian.broadcast_loss": {"tf": 1}, "shimmer.modules.losses.GWLossesBayesian.step": {"tf": 1}, "shimmer.modules.utils.batch_demi_cycles": {"tf": 1}, "shimmer.modules.utils.batch_cycles": {"tf": 1}, "shimmer.modules.utils.batch_translations": {"tf": 1}, "shimmer.utils.groups_batch_size": {"tf": 1}, "shimmer.utils.groups_device": {"tf": 1}}, "df": 34}}}}}}}}, "l": {"docs": {}, "df": 0, "o": {"docs": {}, "df": 0, "a": {"docs": {}, "df": 0, "t": {"docs": {"shimmer.modules.global_workspace.GlobalWorkspace2Domains.__init__": {"tf": 1.4142135623730951}, "shimmer.modules.global_workspace.GlobalWorkspace.__init__": {"tf": 1.7320508075688772}, "shimmer.modules.global_workspace.GlobalWorkspaceBayesian.__init__": {"tf": 2.23606797749979}, "shimmer.modules.gw_module.compute_fusion_scores": {"tf": 1.7320508075688772}, "shimmer.modules.gw_module.GWModuleBayesian.__init__": {"tf": 1.7320508075688772}, "shimmer.modules.selection.RandomSelection.__init__": {"tf": 1}, "shimmer.modules.vae.VAE.__init__": {"tf": 1}}, "df": 7}}}}, "a": {"docs": {}, "df": 0, "l": {"docs": {}, "df": 0, "s": {"docs": {}, "df": 0, "e": {"docs": {"shimmer.modules.global_workspace.GlobalWorkspace2Domains.__init__": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspace.__init__": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBayesian.__init__": {"tf": 1}, "shimmer.modules.contrastive_loss.ContrastiveLoss.__init__": {"tf": 1}, "shimmer.dataset.RepeatedDataset.__init__": {"tf": 1}}, "df": 5}}}, "c": {"docs": {}, "df": 0, "t": {"docs": {}, "df": 0, "o": {"docs": {}, "df": 0, "r": {"docs": {}, "df": 0, "y": {"docs": {"shimmer.modules.domain.LossOutput.__init__": {"tf": 1}}, "df": 1}}}}}}, "n": {"docs": {"shimmer.modules.global_workspace.pretrained_global_workspace": {"tf": 1}, "shimmer.modules.losses.contrastive_loss": {"tf": 1}, "shimmer.modules.losses.contrastive_loss_bayesian": {"tf": 1}, "shimmer.modules.losses.GWLosses2Domains.__init__": {"tf": 1}, 
"shimmer.modules.losses.GWLosses.__init__": {"tf": 1}, "shimmer.modules.losses.GWLossesBayesian.__init__": {"tf": 1}}, "df": 6}, "u": {"docs": {}, "df": 0, "s": {"docs": {}, "df": 0, "i": {"docs": {}, "df": 0, "o": {"docs": {}, "df": 0, "n": {"docs": {"shimmer.modules.selection.SelectionBase.forward": {"tf": 1}, "shimmer.modules.selection.SingleDomainSelection.forward": {"tf": 1}, "shimmer.modules.selection.FixedSharedSelection.forward": {"tf": 1}, "shimmer.modules.selection.KQFixedQSelection.forward": {"tf": 1}, "shimmer.modules.selection.RandomSelection.forward": {"tf": 1}, "shimmer.modules.selection.DynamicQueryAttention.forward": {"tf": 1}}, "df": 6}}}}}}, "t": {"docs": {}, "df": 0, "o": {"docs": {"shimmer.modules.utils.translation": {"tf": 1}}, "df": 1, "r": {"docs": {}, "df": 0, "c": {"docs": {}, "df": 0, "h": {"docs": {"shimmer.modules.global_workspace.GlobalWorkspaceBase.encode_and_fuse": {"tf": 1.4142135623730951}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.encode": {"tf": 1.4142135623730951}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.fuse": {"tf": 1.7320508075688772}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.decode": {"tf": 1.4142135623730951}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.batch_gw_states": {"tf": 1.4142135623730951}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.encode_domain": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.encode_domains": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.decode_domain": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.decode_domains": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.generic_step": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspace2Domains.__init__": {"tf": 2}, "shimmer.modules.global_workspace.GlobalWorkspace2Domains.forward": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspace.__init__": {"tf": 2}, 
"shimmer.modules.global_workspace.GlobalWorkspace.forward": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBayesian.__init__": {"tf": 2}, "shimmer.modules.global_workspace.GlobalWorkspaceBayesian.forward": {"tf": 1}, "shimmer.modules.global_workspace.pretrained_global_workspace": {"tf": 2}, "shimmer.modules.domain.LossOutput.__init__": {"tf": 1.4142135623730951}, "shimmer.modules.domain.DomainModule.encode": {"tf": 1}, "shimmer.modules.domain.DomainModule.decode": {"tf": 1}, "shimmer.modules.domain.DomainModule.compute_loss": {"tf": 1.4142135623730951}, "shimmer.modules.domain.DomainModule.compute_dcy_loss": {"tf": 1.4142135623730951}, "shimmer.modules.domain.DomainModule.compute_cy_loss": {"tf": 1.4142135623730951}, "shimmer.modules.domain.DomainModule.compute_tr_loss": {"tf": 1.4142135623730951}, "shimmer.modules.domain.DomainModule.compute_broadcast_loss": {"tf": 1.4142135623730951}, "shimmer.modules.gw_module.get_n_layers": {"tf": 1}, "shimmer.modules.gw_module.GWEncoder.forward": {"tf": 1.4142135623730951}, "shimmer.modules.gw_module.GWEncoderLinear.forward": {"tf": 1.4142135623730951}, "shimmer.modules.gw_module.GWModuleBase.fuse": {"tf": 1.7320508075688772}, "shimmer.modules.gw_module.GWModuleBase.encode": {"tf": 1.4142135623730951}, "shimmer.modules.gw_module.GWModuleBase.encode_and_fuse": {"tf": 1.4142135623730951}, "shimmer.modules.gw_module.GWModuleBase.decode": {"tf": 1.4142135623730951}, "shimmer.modules.gw_module.GWModule.__init__": {"tf": 1.4142135623730951}, "shimmer.modules.gw_module.GWModule.fuse": {"tf": 1.7320508075688772}, "shimmer.modules.gw_module.GWModule.encode": {"tf": 1.4142135623730951}, "shimmer.modules.gw_module.GWModule.decode": {"tf": 1.4142135623730951}, "shimmer.modules.gw_module.compute_fusion_scores": {"tf": 1.7320508075688772}, "shimmer.modules.gw_module.GWModuleBayesian.__init__": {"tf": 1.4142135623730951}, "shimmer.modules.gw_module.GWModuleBayesian.get_precision": {"tf": 1.4142135623730951}, 
"shimmer.modules.gw_module.GWModuleBayesian.fuse": {"tf": 1.7320508075688772}, "shimmer.modules.selection.SelectionBase.update_gw_state": {"tf": 1}, "shimmer.modules.selection.SelectionBase.forward": {"tf": 1.7320508075688772}, "shimmer.modules.selection.SingleDomainSelection.forward": {"tf": 1.7320508075688772}, "shimmer.modules.selection.FixedSharedSelection.forward": {"tf": 1.7320508075688772}, "shimmer.modules.selection.KQFixedQSelection.forward": {"tf": 1.7320508075688772}, "shimmer.modules.selection.RandomSelection.forward": {"tf": 1.7320508075688772}, "shimmer.modules.selection.DynamicQueryAttention.fuse_weighted_encodings": {"tf": 1.7320508075688772}, "shimmer.modules.selection.DynamicQueryAttention.forward": {"tf": 1.7320508075688772}, "shimmer.modules.losses.GWLossesBase.step": {"tf": 1}, "shimmer.modules.losses.demi_cycle_loss": {"tf": 1.4142135623730951}, "shimmer.modules.losses.cycle_loss": {"tf": 1.4142135623730951}, "shimmer.modules.losses.translation_loss": {"tf": 1.4142135623730951}, "shimmer.modules.losses.contrastive_loss": {"tf": 2}, "shimmer.modules.losses.contrastive_loss_bayesian": {"tf": 2}, "shimmer.modules.losses.GWLosses2Domains.__init__": {"tf": 1.4142135623730951}, "shimmer.modules.losses.GWLosses2Domains.demi_cycle_loss": {"tf": 1.4142135623730951}, "shimmer.modules.losses.GWLosses2Domains.cycle_loss": {"tf": 1.4142135623730951}, "shimmer.modules.losses.GWLosses2Domains.translation_loss": {"tf": 1.4142135623730951}, "shimmer.modules.losses.GWLosses2Domains.contrastive_loss": {"tf": 1.4142135623730951}, "shimmer.modules.losses.GWLosses2Domains.step": {"tf": 1}, "shimmer.modules.losses.broadcast_loss": {"tf": 1.4142135623730951}, "shimmer.modules.losses.GWLosses.__init__": {"tf": 1.4142135623730951}, "shimmer.modules.losses.GWLosses.contrastive_loss": {"tf": 1.4142135623730951}, "shimmer.modules.losses.GWLosses.broadcast_loss": {"tf": 1.4142135623730951}, "shimmer.modules.losses.GWLosses.step": {"tf": 1}, 
"shimmer.modules.losses.GWLossesBayesian.__init__": {"tf": 1.4142135623730951}, "shimmer.modules.losses.GWLossesBayesian.contrastive_loss": {"tf": 1.4142135623730951}, "shimmer.modules.losses.GWLossesBayesian.broadcast_loss": {"tf": 1.4142135623730951}, "shimmer.modules.losses.GWLossesBayesian.step": {"tf": 1}, "shimmer.modules.contrastive_loss.info_nce": {"tf": 2}, "shimmer.modules.contrastive_loss.contrastive_loss": {"tf": 2}, "shimmer.modules.contrastive_loss.ContrastiveLoss.__init__": {"tf": 1}, "shimmer.modules.contrastive_loss.ContrastiveLoss.forward": {"tf": 1.4142135623730951}, "shimmer.modules.vae.reparameterize": {"tf": 1.7320508075688772}, "shimmer.modules.vae.kl_divergence_loss": {"tf": 1.7320508075688772}, "shimmer.modules.vae.gaussian_nll": {"tf": 2}, "shimmer.modules.vae.VAEEncoder.forward": {"tf": 1.4142135623730951}, "shimmer.modules.vae.VAEDecoder.forward": {"tf": 1}, "shimmer.modules.vae.VAE.encode": {"tf": 1}, "shimmer.modules.vae.VAE.decode": {"tf": 1}, "shimmer.modules.vae.VAE.forward": {"tf": 1.4142135623730951}, "shimmer.modules.utils.translation": {"tf": 1.4142135623730951}, "shimmer.modules.utils.cycle": {"tf": 1.4142135623730951}, "shimmer.modules.utils.batch_demi_cycles": {"tf": 1.4142135623730951}, "shimmer.modules.utils.batch_cycles": {"tf": 1.4142135623730951}, "shimmer.modules.utils.batch_translations": {"tf": 1.4142135623730951}, "shimmer.utils.group_batch_size": {"tf": 1}, "shimmer.utils.groups_batch_size": {"tf": 1}, "shimmer.utils.groups_device": {"tf": 1}, "shimmer.utils.group_device": {"tf": 1.4142135623730951}, "shimmer.utils.migrate_model": {"tf": 1}}, "df": 91}}}}, "e": {"docs": {}, "df": 0, "n": {"docs": {}, "df": 0, "s": {"docs": {}, "df": 0, "o": {"docs": {}, "df": 0, "r": {"docs": {"shimmer.modules.global_workspace.GlobalWorkspaceBase.encode_and_fuse": {"tf": 1.4142135623730951}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.encode": {"tf": 1.4142135623730951}, 
"shimmer.modules.global_workspace.GlobalWorkspaceBase.fuse": {"tf": 1.7320508075688772}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.decode": {"tf": 1.4142135623730951}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.batch_gw_states": {"tf": 1.4142135623730951}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.encode_domain": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.encode_domains": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.decode_domain": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.decode_domains": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.generic_step": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspace2Domains.__init__": {"tf": 1.4142135623730951}, "shimmer.modules.global_workspace.GlobalWorkspace2Domains.forward": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspace.__init__": {"tf": 1.4142135623730951}, "shimmer.modules.global_workspace.GlobalWorkspace.forward": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBayesian.__init__": {"tf": 1.4142135623730951}, "shimmer.modules.global_workspace.GlobalWorkspaceBayesian.forward": {"tf": 1}, "shimmer.modules.global_workspace.pretrained_global_workspace": {"tf": 1.4142135623730951}, "shimmer.modules.domain.LossOutput.__init__": {"tf": 1.4142135623730951}, "shimmer.modules.domain.DomainModule.encode": {"tf": 1}, "shimmer.modules.domain.DomainModule.decode": {"tf": 1}, "shimmer.modules.domain.DomainModule.compute_loss": {"tf": 1.4142135623730951}, "shimmer.modules.domain.DomainModule.compute_dcy_loss": {"tf": 1.4142135623730951}, "shimmer.modules.domain.DomainModule.compute_cy_loss": {"tf": 1.4142135623730951}, "shimmer.modules.domain.DomainModule.compute_tr_loss": {"tf": 1.4142135623730951}, "shimmer.modules.domain.DomainModule.compute_broadcast_loss": {"tf": 1.4142135623730951}, "shimmer.modules.gw_module.GWEncoder.forward": {"tf": 1.4142135623730951}, 
"shimmer.modules.gw_module.GWEncoderLinear.forward": {"tf": 1.4142135623730951}, "shimmer.modules.gw_module.GWModuleBase.fuse": {"tf": 1.7320508075688772}, "shimmer.modules.gw_module.GWModuleBase.encode": {"tf": 1.4142135623730951}, "shimmer.modules.gw_module.GWModuleBase.encode_and_fuse": {"tf": 1.4142135623730951}, "shimmer.modules.gw_module.GWModuleBase.decode": {"tf": 1.4142135623730951}, "shimmer.modules.gw_module.GWModule.fuse": {"tf": 1.7320508075688772}, "shimmer.modules.gw_module.GWModule.encode": {"tf": 1.4142135623730951}, "shimmer.modules.gw_module.GWModule.decode": {"tf": 1.4142135623730951}, "shimmer.modules.gw_module.compute_fusion_scores": {"tf": 1.7320508075688772}, "shimmer.modules.gw_module.GWModuleBayesian.get_precision": {"tf": 1.4142135623730951}, "shimmer.modules.gw_module.GWModuleBayesian.fuse": {"tf": 1.7320508075688772}, "shimmer.modules.selection.SelectionBase.update_gw_state": {"tf": 1}, "shimmer.modules.selection.SelectionBase.forward": {"tf": 1.7320508075688772}, "shimmer.modules.selection.SingleDomainSelection.forward": {"tf": 1.7320508075688772}, "shimmer.modules.selection.FixedSharedSelection.forward": {"tf": 1.7320508075688772}, "shimmer.modules.selection.KQFixedQSelection.forward": {"tf": 1.7320508075688772}, "shimmer.modules.selection.RandomSelection.forward": {"tf": 1.7320508075688772}, "shimmer.modules.selection.DynamicQueryAttention.fuse_weighted_encodings": {"tf": 1.7320508075688772}, "shimmer.modules.selection.DynamicQueryAttention.forward": {"tf": 1.7320508075688772}, "shimmer.modules.losses.GWLossesBase.step": {"tf": 1}, "shimmer.modules.losses.demi_cycle_loss": {"tf": 1.4142135623730951}, "shimmer.modules.losses.cycle_loss": {"tf": 1.4142135623730951}, "shimmer.modules.losses.translation_loss": {"tf": 1.4142135623730951}, "shimmer.modules.losses.contrastive_loss": {"tf": 2}, "shimmer.modules.losses.contrastive_loss_bayesian": {"tf": 2}, "shimmer.modules.losses.GWLosses2Domains.__init__": {"tf": 1.4142135623730951}, 
"shimmer.modules.losses.GWLosses2Domains.demi_cycle_loss": {"tf": 1.4142135623730951}, "shimmer.modules.losses.GWLosses2Domains.cycle_loss": {"tf": 1.4142135623730951}, "shimmer.modules.losses.GWLosses2Domains.translation_loss": {"tf": 1.4142135623730951}, "shimmer.modules.losses.GWLosses2Domains.contrastive_loss": {"tf": 1.4142135623730951}, "shimmer.modules.losses.GWLosses2Domains.step": {"tf": 1}, "shimmer.modules.losses.broadcast_loss": {"tf": 1.4142135623730951}, "shimmer.modules.losses.GWLosses.__init__": {"tf": 1.4142135623730951}, "shimmer.modules.losses.GWLosses.contrastive_loss": {"tf": 1.4142135623730951}, "shimmer.modules.losses.GWLosses.broadcast_loss": {"tf": 1.4142135623730951}, "shimmer.modules.losses.GWLosses.step": {"tf": 1}, "shimmer.modules.losses.GWLossesBayesian.__init__": {"tf": 1.4142135623730951}, "shimmer.modules.losses.GWLossesBayesian.contrastive_loss": {"tf": 1.4142135623730951}, "shimmer.modules.losses.GWLossesBayesian.broadcast_loss": {"tf": 1.4142135623730951}, "shimmer.modules.losses.GWLossesBayesian.step": {"tf": 1}, "shimmer.modules.contrastive_loss.info_nce": {"tf": 2}, "shimmer.modules.contrastive_loss.contrastive_loss": {"tf": 2}, "shimmer.modules.contrastive_loss.ContrastiveLoss.__init__": {"tf": 1}, "shimmer.modules.contrastive_loss.ContrastiveLoss.forward": {"tf": 1.4142135623730951}, "shimmer.modules.vae.reparameterize": {"tf": 1.7320508075688772}, "shimmer.modules.vae.kl_divergence_loss": {"tf": 1.7320508075688772}, "shimmer.modules.vae.gaussian_nll": {"tf": 2}, "shimmer.modules.vae.VAEEncoder.forward": {"tf": 1.4142135623730951}, "shimmer.modules.vae.VAEDecoder.forward": {"tf": 1}, "shimmer.modules.vae.VAE.encode": {"tf": 1}, "shimmer.modules.vae.VAE.decode": {"tf": 1}, "shimmer.modules.vae.VAE.forward": {"tf": 1.4142135623730951}, "shimmer.modules.utils.translation": {"tf": 1.4142135623730951}, "shimmer.modules.utils.cycle": {"tf": 1.4142135623730951}, "shimmer.modules.utils.batch_demi_cycles": {"tf": 
1.4142135623730951}, "shimmer.modules.utils.batch_cycles": {"tf": 1.4142135623730951}, "shimmer.modules.utils.batch_translations": {"tf": 1.4142135623730951}, "shimmer.utils.group_batch_size": {"tf": 1}, "shimmer.utils.groups_batch_size": {"tf": 1}, "shimmer.utils.groups_device": {"tf": 1}, "shimmer.utils.group_device": {"tf": 1}}, "df": 87}}}}, "s": {"docs": {}, "df": 0, "t": {"docs": {"shimmer.modules.global_workspace.GlobalWorkspaceBase.generic_step": {"tf": 1}, "shimmer.modules.losses.GWLossesBase.step": {"tf": 1}, "shimmer.modules.losses.GWLosses2Domains.step": {"tf": 1}, "shimmer.modules.losses.GWLosses.step": {"tf": 1}, "shimmer.modules.losses.GWLossesBayesian.step": {"tf": 1}}, "df": 5, "/": {"docs": {}, "df": 0, "o": {"docs": {}, "df": 0, "o": {"docs": {}, "df": 0, "d": {"docs": {"shimmer.modules.global_workspace.GlobalWorkspaceBase.generic_step": {"tf": 1}, "shimmer.modules.losses.GWLossesBase.step": {"tf": 1}, "shimmer.modules.losses.GWLosses2Domains.step": {"tf": 1}, "shimmer.modules.losses.GWLosses.step": {"tf": 1}, "shimmer.modules.losses.GWLossesBayesian.step": {"tf": 1}}, "df": 5}}}}}}, "m": {"docs": {}, "df": 0, "p": {"docs": {"shimmer.modules.global_workspace.GlobalWorkspaceBayesian.__init__": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBayesian.__init__": {"tf": 1}}, "df": 2, "e": {"docs": {}, "df": 0, "r": {"docs": {}, "df": 0, "a": {"docs": {}, "df": 0, "t": {"docs": {}, "df": 0, "u": {"docs": {}, "df": 0, "r": {"docs": {}, "df": 0, "e": {"docs": {"shimmer.modules.global_workspace.GlobalWorkspace.__init__": {"tf": 1}, "shimmer.modules.selection.RandomSelection.__init__": {"tf": 1}}, "df": 2}}}}}}}}}}, "y": {"docs": {}, "df": 0, "p": {"docs": {}, "df": 0, "i": {"docs": {}, "df": 0, "n": {"docs": {}, "df": 0, "g": {"docs": {"shimmer.modules.global_workspace.GlobalWorkspaceBase.encode_domains": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.decode_domains": {"tf": 1}, 
"shimmer.modules.global_workspace.GlobalWorkspaceBase.generic_step": {"tf": 1}, "shimmer.modules.vae.VAE.forward": {"tf": 1}, "shimmer.utils.SaveMigrations.on_save_checkpoint": {"tf": 1}}, "df": 5}}}}}, "r": {"docs": {}, "df": 0, "a": {"docs": {}, "df": 0, "i": {"docs": {}, "df": 0, "n": {"docs": {"shimmer.modules.global_workspace.GlobalWorkspaceBase.generic_step": {"tf": 1}, "shimmer.modules.losses.GWLossesBase.step": {"tf": 1}, "shimmer.modules.losses.GWLosses2Domains.step": {"tf": 1}, "shimmer.modules.losses.GWLosses.step": {"tf": 1}, "shimmer.modules.losses.GWLossesBayesian.step": {"tf": 1}}, "df": 5, "e": {"docs": {}, "df": 0, "r": {"docs": {"shimmer.utils.SaveMigrations.on_save_checkpoint": {"tf": 2}}, "df": 1}}}}}, "u": {"docs": {}, "df": 0, "e": {"docs": {"shimmer.modules.global_workspace.GlobalWorkspaceBayesian.__init__": {"tf": 1}, "shimmer.modules.losses.GWLossesBayesian.__init__": {"tf": 1}}, "df": 2}}}, "a": {"docs": {}, "df": 0, "r": {"docs": {}, "df": 0, "g": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "t": {"docs": {"shimmer.modules.domain.DomainModule.compute_loss": {"tf": 1}, "shimmer.modules.domain.DomainModule.compute_dcy_loss": {"tf": 1}, "shimmer.modules.domain.DomainModule.compute_cy_loss": {"tf": 1}, "shimmer.modules.domain.DomainModule.compute_tr_loss": {"tf": 1}, "shimmer.modules.domain.DomainModule.compute_broadcast_loss": {"tf": 1}}, "df": 5}}}}}, "u": {"docs": {}, "df": 0, "p": {"docs": {}, "df": 0, "l": {"docs": {}, "df": 0, "e": {"docs": {"shimmer.modules.losses.generate_partitions": {"tf": 1}, "shimmer.modules.vae.VAEEncoder.forward": {"tf": 1}, "shimmer.modules.vae.VAE.forward": {"tf": 1.4142135623730951}, "shimmer.modules.utils.batch_cycles": {"tf": 1}, "shimmer.modules.utils.batch_translations": {"tf": 1}}, "df": 5}}}}, "h": {"docs": {}, "df": 0, "r": {"docs": {}, "df": 0, "o": {"docs": {}, "df": 0, "u": {"docs": {}, "df": 0, "g": {"docs": {}, "df": 0, "h": {"docs": {"shimmer.modules.utils.cycle": {"tf": 1}, 
"shimmer.modules.utils.batch_cycles": {"tf": 1}}, "df": 2}}}}}}}, "d": {"docs": {}, "df": 0, "i": {"docs": {}, "df": 0, "c": {"docs": {}, "df": 0, "t": {"docs": {"shimmer.modules.global_workspace.GlobalWorkspaceBase.encode_and_fuse": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.encode": {"tf": 1.4142135623730951}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.fuse": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.decode": {"tf": 1.4142135623730951}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.batch_gw_states": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.encode_domains": {"tf": 1.4142135623730951}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.decode_domains": {"tf": 1.4142135623730951}, "shimmer.modules.global_workspace.freeze_domain_modules": {"tf": 1}, "shimmer.modules.domain.LossOutput.__init__": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBase.encode": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBase.decode": {"tf": 1}, "shimmer.modules.gw_module.GWModule.encode": {"tf": 1}, "shimmer.modules.gw_module.GWModule.decode": {"tf": 1}, "shimmer.modules.selection.SelectionBase.forward": {"tf": 1}, "shimmer.modules.selection.SingleDomainSelection.forward": {"tf": 1}, "shimmer.modules.selection.FixedSharedSelection.forward": {"tf": 1}, "shimmer.modules.selection.KQFixedQSelection.forward": {"tf": 1}, "shimmer.modules.selection.RandomSelection.forward": {"tf": 1}, "shimmer.modules.selection.DynamicQueryAttention.fuse_weighted_encodings": {"tf": 1.4142135623730951}, "shimmer.modules.selection.DynamicQueryAttention.forward": {"tf": 1}, "shimmer.modules.losses.demi_cycle_loss": {"tf": 1}, "shimmer.modules.losses.cycle_loss": {"tf": 1}, "shimmer.modules.losses.translation_loss": {"tf": 1}, "shimmer.modules.losses.contrastive_loss": {"tf": 1}, "shimmer.modules.losses.contrastive_loss_bayesian": {"tf": 1}, "shimmer.modules.losses.GWLosses2Domains.__init__": {"tf": 1}, 
"shimmer.modules.losses.GWLosses2Domains.demi_cycle_loss": {"tf": 1}, "shimmer.modules.losses.GWLosses2Domains.cycle_loss": {"tf": 1}, "shimmer.modules.losses.GWLosses2Domains.translation_loss": {"tf": 1}, "shimmer.modules.losses.GWLosses2Domains.contrastive_loss": {"tf": 1}, "shimmer.modules.losses.broadcast_loss": {"tf": 1}, "shimmer.modules.losses.GWLosses.__init__": {"tf": 1}, "shimmer.modules.losses.GWLosses.contrastive_loss": {"tf": 1}, "shimmer.modules.losses.GWLosses.broadcast_loss": {"tf": 1}, "shimmer.modules.losses.GWLossesBayesian.__init__": {"tf": 1}, "shimmer.modules.losses.GWLossesBayesian.contrastive_loss": {"tf": 1}, "shimmer.modules.losses.GWLossesBayesian.broadcast_loss": {"tf": 1}, "shimmer.modules.utils.cycle": {"tf": 1}, "shimmer.modules.utils.batch_demi_cycles": {"tf": 1}, "shimmer.modules.utils.batch_cycles": {"tf": 1}, "shimmer.modules.utils.batch_translations": {"tf": 1}, "shimmer.utils.SaveMigrations.on_save_checkpoint": {"tf": 1}}, "df": 42}}, "m": {"docs": {"shimmer.modules.global_workspace.GlobalWorkspace2Domains.__init__": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspace.__init__": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBayesian.__init__": {"tf": 1}, "shimmer.modules.global_workspace.pretrained_global_workspace": {"tf": 1}, "shimmer.modules.domain.DomainModule.__init__": {"tf": 1}, "shimmer.modules.gw_module.get_n_layers": {"tf": 1}, "shimmer.modules.gw_module.GWDecoder.__init__": {"tf": 1.7320508075688772}, "shimmer.modules.gw_module.GWEncoder.__init__": {"tf": 1.7320508075688772}, "shimmer.modules.gw_module.GWModuleBase.__init__": {"tf": 1}, "shimmer.modules.gw_module.GWModule.__init__": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBayesian.__init__": {"tf": 1}, "shimmer.modules.selection.KQFixedQSelection.__init__": {"tf": 1}, "shimmer.modules.selection.DynamicQueryAttention.__init__": {"tf": 1}}, "df": 13}}, "o": {"docs": {}, "df": 0, "m": {"docs": {}, "df": 0, "a": {"docs": {}, "df": 0, "i": 
{"docs": {}, "df": 0, "n": {"docs": {"shimmer.modules.global_workspace.GlobalWorkspaceBase.encode_domain": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.decode_domain": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.decode_domains": {"tf": 1}, "shimmer.modules.global_workspace.freeze_domain_modules": {"tf": 1.7320508075688772}, "shimmer.modules.global_workspace.GlobalWorkspace2Domains.__init__": {"tf": 1.7320508075688772}, "shimmer.modules.global_workspace.GlobalWorkspace.__init__": {"tf": 1.7320508075688772}, "shimmer.modules.global_workspace.GlobalWorkspaceBayesian.__init__": {"tf": 1.7320508075688772}, "shimmer.modules.global_workspace.pretrained_global_workspace": {"tf": 1.7320508075688772}, "shimmer.modules.domain.DomainModule.compute_loss": {"tf": 1}, "shimmer.modules.domain.DomainModule.compute_dcy_loss": {"tf": 1}, "shimmer.modules.domain.DomainModule.compute_cy_loss": {"tf": 1}, "shimmer.modules.domain.DomainModule.compute_tr_loss": {"tf": 1}, "shimmer.modules.domain.DomainModule.compute_broadcast_loss": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBase.__init__": {"tf": 1.4142135623730951}, "shimmer.modules.gw_module.GWModule.__init__": {"tf": 1.4142135623730951}, "shimmer.modules.gw_module.GWModuleBayesian.__init__": {"tf": 1.4142135623730951}, "shimmer.modules.gw_module.GWModuleBayesian.get_precision": {"tf": 1}, "shimmer.modules.selection.KQFixedQSelection.__init__": {"tf": 1.4142135623730951}, "shimmer.modules.selection.DynamicQueryAttention.__init__": {"tf": 1.4142135623730951}, "shimmer.modules.losses.GWLossesBase.step": {"tf": 1.4142135623730951}, "shimmer.modules.losses.demi_cycle_loss": {"tf": 1.4142135623730951}, "shimmer.modules.losses.cycle_loss": {"tf": 1.4142135623730951}, "shimmer.modules.losses.translation_loss": {"tf": 1.4142135623730951}, "shimmer.modules.losses.contrastive_loss": {"tf": 1}, "shimmer.modules.losses.contrastive_loss_bayesian": {"tf": 1}, 
"shimmer.modules.losses.GWLosses2Domains.__init__": {"tf": 1.7320508075688772}, "shimmer.modules.losses.GWLosses2Domains.step": {"tf": 1.4142135623730951}, "shimmer.modules.losses.broadcast_loss": {"tf": 1.4142135623730951}, "shimmer.modules.losses.GWLosses.__init__": {"tf": 1.7320508075688772}, "shimmer.modules.losses.GWLosses.step": {"tf": 1.4142135623730951}, "shimmer.modules.losses.GWLossesBayesian.__init__": {"tf": 1.7320508075688772}, "shimmer.modules.losses.GWLossesBayesian.step": {"tf": 1.4142135623730951}, "shimmer.modules.contrastive_loss.ContrastiveLoss.forward": {"tf": 1}, "shimmer.utils.groups_batch_size": {"tf": 1}, "shimmer.utils.groups_device": {"tf": 1}}, "df": 35, "s": {"docs": {"shimmer.modules.global_workspace.GlobalWorkspaceBase.decode": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.batch_gw_states": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspace2Domains.forward": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspace.forward": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBayesian.forward": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBase.decode": {"tf": 1}, "shimmer.modules.gw_module.GWModule.decode": {"tf": 1}, "shimmer.modules.selection.SelectionBase.forward": {"tf": 1}, "shimmer.modules.selection.SingleDomainSelection.forward": {"tf": 1}, "shimmer.modules.selection.FixedSharedSelection.forward": {"tf": 1}, "shimmer.modules.selection.KQFixedQSelection.forward": {"tf": 1}, "shimmer.modules.selection.RandomSelection.forward": {"tf": 1}, "shimmer.modules.selection.DynamicQueryAttention.forward": {"tf": 1}, "shimmer.modules.losses.demi_cycle_loss": {"tf": 1}, "shimmer.modules.losses.cycle_loss": {"tf": 1}, "shimmer.modules.losses.translation_loss": {"tf": 1}, "shimmer.modules.losses.contrastive_loss": {"tf": 1}, "shimmer.modules.losses.contrastive_loss_bayesian": {"tf": 1}, "shimmer.modules.losses.GWLosses2Domains.demi_cycle_loss": {"tf": 1}, 
"shimmer.modules.losses.GWLosses2Domains.cycle_loss": {"tf": 1}, "shimmer.modules.losses.GWLosses2Domains.translation_loss": {"tf": 1}, "shimmer.modules.losses.GWLosses2Domains.contrastive_loss": {"tf": 1}, "shimmer.modules.losses.broadcast_loss": {"tf": 1}, "shimmer.modules.losses.GWLosses.contrastive_loss": {"tf": 1}, "shimmer.modules.losses.GWLosses.broadcast_loss": {"tf": 1}, "shimmer.modules.losses.GWLossesBayesian.contrastive_loss": {"tf": 1}, "shimmer.modules.losses.GWLossesBayesian.broadcast_loss": {"tf": 1}, "shimmer.modules.utils.batch_demi_cycles": {"tf": 1}, "shimmer.modules.utils.batch_cycles": {"tf": 1.4142135623730951}, "shimmer.modules.utils.batch_translations": {"tf": 1}}, "df": 30}, "m": {"docs": {}, "df": 0, "o": {"docs": {}, "df": 0, "d": {"docs": {}, "df": 0, "u": {"docs": {}, "df": 0, "l": {"docs": {}, "df": 0, "e": {"docs": {"shimmer.modules.global_workspace.freeze_domain_modules": {"tf": 1.4142135623730951}, "shimmer.modules.global_workspace.GlobalWorkspace2Domains.__init__": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspace.__init__": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBayesian.__init__": {"tf": 1}, "shimmer.modules.global_workspace.pretrained_global_workspace": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBase.__init__": {"tf": 1}, "shimmer.modules.gw_module.GWModule.__init__": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBayesian.__init__": {"tf": 1}, "shimmer.modules.losses.demi_cycle_loss": {"tf": 1}, "shimmer.modules.losses.cycle_loss": {"tf": 1}, "shimmer.modules.losses.translation_loss": {"tf": 1}, "shimmer.modules.losses.GWLosses2Domains.__init__": {"tf": 1}, "shimmer.modules.losses.broadcast_loss": {"tf": 1}, "shimmer.modules.losses.GWLosses.__init__": {"tf": 1}, "shimmer.modules.losses.GWLossesBayesian.__init__": {"tf": 1}}, "df": 15}}}}}}}}}}}, "e": {"docs": {}, "df": 0, "c": {"docs": {}, "df": 0, "o": {"docs": {}, "df": 0, "d": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "r": {"docs": 
{"shimmer.modules.vae.VAE.__init__": {"tf": 1}}, "df": 1, "s": {"docs": {"shimmer.modules.global_workspace.GlobalWorkspace2Domains.__init__": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspace.__init__": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBayesian.__init__": {"tf": 1}, "shimmer.modules.global_workspace.pretrained_global_workspace": {"tf": 1}, "shimmer.modules.gw_module.GWModule.__init__": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBayesian.__init__": {"tf": 1}}, "df": 6}}}}}, "a": {"docs": {}, "df": 0, "y": {"docs": {"shimmer.modules.global_workspace.GlobalWorkspace2Domains.__init__": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspace.__init__": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBayesian.__init__": {"tf": 1}}, "df": 3}}}, "v": {"docs": {}, "df": 0, "i": {"docs": {}, "df": 0, "c": {"docs": {}, "df": 0, "e": {"docs": {"shimmer.utils.group_device": {"tf": 1}}, "df": 1}}}}}, "a": {"docs": {}, "df": 0, "t": {"docs": {}, "df": 0, "a": {"docs": {}, "df": 0, "s": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "t": {"docs": {"shimmer.dataset.RepeatedDataset.__init__": {"tf": 1.4142135623730951}}, "df": 1}}}}}}, "r": {"docs": {}, "df": 0, "o": {"docs": {}, "df": 0, "p": {"docs": {"shimmer.dataset.RepeatedDataset.__init__": {"tf": 1}}, "df": 1}}}}, "z": {"docs": {"shimmer.modules.global_workspace.GlobalWorkspaceBase.decode": {"tf": 1}, "shimmer.modules.domain.DomainModule.decode": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBase.decode": {"tf": 1}, "shimmer.modules.gw_module.GWModule.decode": {"tf": 1}, "shimmer.modules.vae.VAE.decode": {"tf": 1}}, "df": 5}, "i": {"docs": {}, "df": 0, "t": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "r": {"docs": {}, "df": 0, "a": {"docs": {}, "df": 0, "b": {"docs": {}, "df": 0, "l": {"docs": {}, "df": 0, "e": {"docs": {"shimmer.modules.global_workspace.GlobalWorkspaceBase.decode": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBase.decode": {"tf": 1}, 
"shimmer.modules.gw_module.GWModule.decode": {"tf": 1}, "shimmer.modules.selection.KQFixedQSelection.__init__": {"tf": 1}, "shimmer.modules.selection.DynamicQueryAttention.__init__": {"tf": 1}, "shimmer.modules.utils.batch_cycles": {"tf": 1}}, "df": 6}}}}}}}, "n": {"docs": {"shimmer.modules.gw_module.GWDecoder.__init__": {"tf": 1}, "shimmer.modules.gw_module.GWEncoder.__init__": {"tf": 1}}, "df": 2, "t": {"docs": {"shimmer.modules.global_workspace.GlobalWorkspace2Domains.__init__": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspace.__init__": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBayesian.__init__": {"tf": 1}, "shimmer.modules.global_workspace.pretrained_global_workspace": {"tf": 1}, "shimmer.modules.domain.DomainModule.__init__": {"tf": 1}, "shimmer.modules.gw_module.get_n_layers": {"tf": 1.4142135623730951}, "shimmer.modules.gw_module.GWDecoder.__init__": {"tf": 2}, "shimmer.modules.gw_module.GWEncoder.__init__": {"tf": 2}, "shimmer.modules.gw_module.GWModuleBase.__init__": {"tf": 1}, "shimmer.modules.gw_module.GWModule.__init__": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBayesian.__init__": {"tf": 1}, "shimmer.modules.selection.KQFixedQSelection.__init__": {"tf": 1.4142135623730951}, "shimmer.modules.selection.DynamicQueryAttention.__init__": {"tf": 1.4142135623730951}, "shimmer.modules.losses.generate_partitions": {"tf": 1.4142135623730951}, "shimmer.dataset.RepeatedDataset.__init__": {"tf": 1}, "shimmer.utils.group_batch_size": {"tf": 1}, "shimmer.utils.groups_batch_size": {"tf": 1}, "shimmer.utils.groups_device": {"tf": 1}}, "df": 18}, "p": {"docs": {}, "df": 0, "u": {"docs": {}, "df": 0, "t": {"docs": {"shimmer.modules.gw_module.GWEncoder.forward": {"tf": 1}, "shimmer.modules.gw_module.GWEncoderLinear.forward": {"tf": 1}}, "df": 2}}}}}, "n": {"docs": {"shimmer.modules.gw_module.get_n_layers": {"tf": 1}, "shimmer.modules.gw_module.GWDecoder.__init__": {"tf": 1}, "shimmer.modules.gw_module.GWEncoder.__init__": {"tf": 1}, 
"shimmer.modules.losses.generate_partitions": {"tf": 1}}, "df": 4, "o": {"docs": {}, "df": 0, "n": {"docs": {}, "df": 0, "e": {"docs": {"shimmer.modules.global_workspace.GlobalWorkspaceBase.decode": {"tf": 1.4142135623730951}, "shimmer.modules.global_workspace.GlobalWorkspace2Domains.__init__": {"tf": 2}, "shimmer.modules.global_workspace.GlobalWorkspace.__init__": {"tf": 2}, "shimmer.modules.global_workspace.GlobalWorkspaceBayesian.__init__": {"tf": 2}, "shimmer.modules.gw_module.GWModuleBase.decode": {"tf": 1.4142135623730951}, "shimmer.modules.gw_module.GWModule.decode": {"tf": 1.4142135623730951}, "shimmer.modules.selection.SelectionBase.update_gw_state": {"tf": 1}, "shimmer.modules.losses.generate_partitions": {"tf": 1.4142135623730951}, "shimmer.modules.contrastive_loss.info_nce": {"tf": 1}, "shimmer.modules.contrastive_loss.contrastive_loss": {"tf": 1}, "shimmer.modules.contrastive_loss.ContrastiveLoss.__init__": {"tf": 1}}, "df": 11}}, "r": {"docs": {}, "df": 0, "m": {"docs": {}, "df": 0, "a": {"docs": {}, "df": 0, "l": {"docs": {}, "df": 0, "i": {"docs": {}, "df": 0, "z": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "d": {"docs": {"shimmer.modules.global_workspace.GlobalWorkspaceBayesian.__init__": {"tf": 1}, "shimmer.modules.losses.GWLossesBayesian.__init__": {"tf": 1}}, "df": 2}}}}}}}}}, "a": {"docs": {}, "df": 0, "m": {"docs": {}, "df": 0, "e": {"docs": {"shimmer.modules.global_workspace.GlobalWorkspaceBase.encode_domain": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.decode_domain": {"tf": 1}}, "df": 2, "s": {"docs": {"shimmer.modules.selection.KQFixedQSelection.__init__": {"tf": 1}, "shimmer.modules.selection.DynamicQueryAttention.__init__": {"tf": 1}}, "df": 2}}}}, "n": {"docs": {"shimmer.modules.global_workspace.GlobalWorkspace2Domains.__init__": {"tf": 1.4142135623730951}, "shimmer.modules.global_workspace.GlobalWorkspace.__init__": {"tf": 1.4142135623730951}, 
"shimmer.modules.global_workspace.GlobalWorkspaceBayesian.__init__": {"tf": 1.4142135623730951}, "shimmer.modules.global_workspace.pretrained_global_workspace": {"tf": 1.4142135623730951}, "shimmer.modules.gw_module.get_n_layers": {"tf": 1}, "shimmer.modules.gw_module.GWModule.__init__": {"tf": 1.4142135623730951}, "shimmer.modules.gw_module.GWModuleBayesian.__init__": {"tf": 1.4142135623730951}}, "df": 7}}, "l": {"docs": {}, "df": 0, "a": {"docs": {}, "df": 0, "t": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "n": {"docs": {}, "df": 0, "t": {"docs": {"shimmer.modules.global_workspace.GlobalWorkspaceBase.batch_gw_states": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspace2Domains.forward": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspace.forward": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBayesian.forward": {"tf": 1}, "shimmer.modules.domain.DomainModule.__init__": {"tf": 1}, "shimmer.modules.losses.demi_cycle_loss": {"tf": 1}, "shimmer.modules.losses.cycle_loss": {"tf": 1}, "shimmer.modules.losses.translation_loss": {"tf": 1}, "shimmer.modules.losses.contrastive_loss": {"tf": 1}, "shimmer.modules.losses.contrastive_loss_bayesian": {"tf": 1}, "shimmer.modules.losses.GWLosses2Domains.demi_cycle_loss": {"tf": 1}, "shimmer.modules.losses.GWLosses2Domains.cycle_loss": {"tf": 1}, "shimmer.modules.losses.GWLosses2Domains.translation_loss": {"tf": 1}, "shimmer.modules.losses.GWLosses2Domains.contrastive_loss": {"tf": 1}, "shimmer.modules.losses.broadcast_loss": {"tf": 1}, "shimmer.modules.losses.GWLosses.contrastive_loss": {"tf": 1}, "shimmer.modules.losses.GWLosses.broadcast_loss": {"tf": 1}, "shimmer.modules.losses.GWLossesBayesian.contrastive_loss": {"tf": 1}, "shimmer.modules.losses.GWLossesBayesian.broadcast_loss": {"tf": 1}, "shimmer.modules.utils.batch_demi_cycles": {"tf": 1}, "shimmer.modules.utils.batch_cycles": {"tf": 1}, "shimmer.modules.utils.batch_translations": {"tf": 1}}, "df": 22, "s": {"docs": 
{"shimmer.modules.global_workspace.GlobalWorkspaceBase.decode_domains": {"tf": 1}, "shimmer.modules.losses.GWLossesBase.step": {"tf": 1}, "shimmer.modules.losses.GWLosses2Domains.step": {"tf": 1}, "shimmer.modules.losses.GWLosses.step": {"tf": 1}, "shimmer.modules.losses.GWLossesBayesian.step": {"tf": 1}, "shimmer.utils.groups_batch_size": {"tf": 1}, "shimmer.utils.groups_device": {"tf": 1}}, "df": 7}}}}}, "y": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "r": {"docs": {}, "df": 0, "s": {"docs": {"shimmer.modules.gw_module.get_n_layers": {"tf": 1}, "shimmer.modules.gw_module.GWDecoder.__init__": {"tf": 1}, "shimmer.modules.gw_module.GWEncoder.__init__": {"tf": 1}}, "df": 3}}}}, "s": {"docs": {}, "df": 0, "t": {"docs": {"shimmer.dataset.RepeatedDataset.__init__": {"tf": 1}}, "df": 1}}}, "i": {"docs": {}, "df": 0, "t": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "r": {"docs": {}, "df": 0, "a": {"docs": {}, "df": 0, "l": {"docs": {"shimmer.modules.global_workspace.GlobalWorkspaceBase.generic_step": {"tf": 1}, "shimmer.modules.losses.GWLossesBase.step": {"tf": 1}, "shimmer.modules.losses.GWLosses2Domains.step": {"tf": 1}, "shimmer.modules.losses.GWLosses.step": {"tf": 1}, "shimmer.modules.losses.GWLossesBayesian.step": {"tf": 1}, "shimmer.modules.contrastive_loss.info_nce": {"tf": 1}, "shimmer.modules.contrastive_loss.contrastive_loss": {"tf": 1}, "shimmer.modules.contrastive_loss.ContrastiveLoss.__init__": {"tf": 1}}, "df": 8}}}}}, "s": {"docs": {}, "df": 0, "t": {"docs": {"shimmer.modules.gw_module.get_n_layers": {"tf": 1}}, "df": 1}}, "g": {"docs": {}, "df": 0, "h": {"docs": {}, "df": 0, "t": {"docs": {}, "df": 0, "n": {"docs": {}, "df": 0, "i": {"docs": {}, "df": 0, "n": {"docs": {}, "df": 0, "g": {"docs": {"shimmer.utils.SaveMigrations.on_save_checkpoint": {"tf": 1.4142135623730951}}, "df": 1, "m": {"docs": {}, "df": 0, "o": {"docs": {}, "df": 0, "d": {"docs": {}, "df": 0, "u": {"docs": {}, "df": 0, "l": {"docs": {}, "df": 0, "e": {"docs": 
{"shimmer.utils.SaveMigrations.on_save_checkpoint": {"tf": 1}}, "df": 1}}}}}}}}}}}}}}, "o": {"docs": {}, "df": 0, "s": {"docs": {}, "df": 0, "s": {"docs": {"shimmer.modules.global_workspace.GlobalWorkspace2Domains.__init__": {"tf": 1.4142135623730951}, "shimmer.modules.global_workspace.GlobalWorkspace.__init__": {"tf": 1.4142135623730951}, "shimmer.modules.global_workspace.GlobalWorkspaceBayesian.__init__": {"tf": 1.4142135623730951}, "shimmer.modules.global_workspace.pretrained_global_workspace": {"tf": 1}, "shimmer.modules.domain.LossOutput.__init__": {"tf": 1}, "shimmer.modules.losses.GWLosses2Domains.__init__": {"tf": 1}, "shimmer.modules.losses.GWLosses.__init__": {"tf": 1}, "shimmer.modules.losses.GWLossesBayesian.__init__": {"tf": 1}}, "df": 8, "e": {"docs": {}, "df": 0, "s": {"docs": {"shimmer.modules.global_workspace.GlobalWorkspace2Domains.__init__": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspace.__init__": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBayesian.__init__": {"tf": 1}, "shimmer.modules.global_workspace.pretrained_global_workspace": {"tf": 1}, "shimmer.modules.losses.GWLosses2Domains.__init__": {"tf": 1}, "shimmer.modules.losses.GWLosses.__init__": {"tf": 1}, "shimmer.modules.losses.GWLossesBayesian.__init__": {"tf": 1}}, "df": 7}}, "c": {"docs": {}, "df": 0, "o": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "f": {"docs": {}, "df": 0, "s": {"docs": {"shimmer.modules.global_workspace.GlobalWorkspace2Domains.__init__": {"tf": 1}, "shimmer.modules.global_workspace.pretrained_global_workspace": {"tf": 1}, "shimmer.modules.losses.GWLosses2Domains.__init__": {"tf": 1}}, "df": 3}}}}}, "o": {"docs": {}, "df": 0, "u": {"docs": {}, "df": 0, "t": {"docs": {}, "df": 0, "p": {"docs": {}, "df": 0, "u": {"docs": {}, "df": 0, "t": {"docs": {"shimmer.modules.global_workspace.GlobalWorkspace2Domains.__init__": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspace.__init__": {"tf": 1}, 
"shimmer.modules.global_workspace.GlobalWorkspaceBayesian.__init__": {"tf": 1}, "shimmer.modules.global_workspace.pretrained_global_workspace": {"tf": 1}, "shimmer.modules.domain.DomainModule.compute_loss": {"tf": 1}, "shimmer.modules.domain.DomainModule.compute_dcy_loss": {"tf": 1}, "shimmer.modules.domain.DomainModule.compute_cy_loss": {"tf": 1}, "shimmer.modules.domain.DomainModule.compute_tr_loss": {"tf": 1}, "shimmer.modules.domain.DomainModule.compute_broadcast_loss": {"tf": 1}, "shimmer.modules.losses.GWLossesBase.step": {"tf": 1}, "shimmer.modules.losses.contrastive_loss": {"tf": 1}, "shimmer.modules.losses.contrastive_loss_bayesian": {"tf": 1}, "shimmer.modules.losses.GWLosses2Domains.__init__": {"tf": 1}, "shimmer.modules.losses.GWLosses2Domains.step": {"tf": 1}, "shimmer.modules.losses.GWLosses.__init__": {"tf": 1}, "shimmer.modules.losses.GWLosses.step": {"tf": 1}, "shimmer.modules.losses.GWLossesBayesian.__init__": {"tf": 1}, "shimmer.modules.losses.GWLossesBayesian.step": {"tf": 1}, "shimmer.modules.contrastive_loss.ContrastiveLoss.forward": {"tf": 1}}, "df": 19}}}}}}}}, "g": {"docs": {"shimmer.modules.vae.gaussian_nll": {"tf": 1}}, "df": 1, "i": {"docs": {}, "df": 0, "t": {"docs": {"shimmer.modules.global_workspace.GlobalWorkspace2Domains.__init__": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspace.__init__": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBayesian.__init__": {"tf": 1}, "shimmer.modules.contrastive_loss.info_nce": {"tf": 1}, "shimmer.modules.contrastive_loss.contrastive_loss": {"tf": 1}, "shimmer.modules.contrastive_loss.ContrastiveLoss.__init__": {"tf": 1.4142135623730951}}, "df": 6}}, "v": {"docs": {}, "df": 0, "a": {"docs": {}, "df": 0, "r": {"docs": {"shimmer.modules.vae.reparameterize": {"tf": 1}, "shimmer.modules.vae.kl_divergence_loss": {"tf": 1}}, "df": 2}}}}, "a": {"docs": {}, "df": 0, "d": {"docs": {"shimmer.utils.migrate_model": {"tf": 1}}, "df": 1}}}, "r": {"docs": 
{"shimmer.modules.global_workspace.GlobalWorkspace2Domains.__init__": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspace.__init__": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBayesian.__init__": {"tf": 1}}, "df": 3}, "e": {"docs": {}, "df": 0, "a": {"docs": {}, "df": 0, "r": {"docs": {}, "df": 0, "n": {"docs": {"shimmer.modules.global_workspace.GlobalWorkspace2Domains.__init__": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspace.__init__": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBayesian.__init__": {"tf": 1}, "shimmer.modules.contrastive_loss.ContrastiveLoss.__init__": {"tf": 1}}, "df": 4}}}}, "t": {"docs": {"shimmer.modules.domain.LossOutput.__init__": {"tf": 1}}, "df": 1}}, "b": {"docs": {}, "df": 0, "a": {"docs": {}, "df": 0, "t": {"docs": {}, "df": 0, "c": {"docs": {}, "df": 0, "h": {"docs": {"shimmer.modules.global_workspace.GlobalWorkspaceBase.encode_domains": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.generic_step": {"tf": 1}}, "df": 2}}}}, "o": {"docs": {}, "df": 0, "o": {"docs": {}, "df": 0, "l": {"docs": {"shimmer.modules.global_workspace.GlobalWorkspace2Domains.__init__": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspace.__init__": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBayesian.__init__": {"tf": 1.4142135623730951}, "shimmer.modules.losses.GWLossesBayesian.__init__": {"tf": 1}, "shimmer.modules.contrastive_loss.ContrastiveLoss.__init__": {"tf": 1}, "shimmer.dataset.RepeatedDataset.__init__": {"tf": 1}}, "df": 6}}}, "r": {"docs": {}, "df": 0, "o": {"docs": {}, "df": 0, "a": {"docs": {}, "df": 0, "d": {"docs": {}, "df": 0, "c": {"docs": {}, "df": 0, "a": {"docs": {}, "df": 0, "s": {"docs": {}, "df": 0, "t": {"docs": {}, "df": 0, "l": {"docs": {}, "df": 0, "o": {"docs": {}, "df": 0, "s": {"docs": {}, "df": 0, "s": {"docs": {}, "df": 0, "c": {"docs": {}, "df": 0, "o": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "f": {"docs": {}, "df": 0, "s": 
{"docs": {"shimmer.modules.global_workspace.GlobalWorkspace.__init__": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBayesian.__init__": {"tf": 1}, "shimmer.modules.losses.GWLosses.__init__": {"tf": 1}, "shimmer.modules.losses.GWLossesBayesian.__init__": {"tf": 1}}, "df": 4}}}}}}}}}}}}}}}}}, "e": {"docs": {}, "df": 0, "t": {"docs": {}, "df": 0, "a": {"docs": {"shimmer.modules.vae.VAE.__init__": {"tf": 1}}, "df": 1}}}}, "v": {"docs": {}, "df": 0, "a": {"docs": {}, "df": 0, "l": {"docs": {"shimmer.modules.global_workspace.GlobalWorkspaceBase.generic_step": {"tf": 1}, "shimmer.modules.losses.GWLossesBase.step": {"tf": 1}, "shimmer.modules.losses.GWLosses2Domains.step": {"tf": 1}, "shimmer.modules.losses.GWLosses.step": {"tf": 1}, "shimmer.modules.losses.GWLossesBayesian.step": {"tf": 1}}, "df": 5, "/": {"docs": {}, "df": 0, "o": {"docs": {}, "df": 0, "o": {"docs": {}, "df": 0, "d": {"docs": {"shimmer.modules.global_workspace.GlobalWorkspaceBase.generic_step": {"tf": 1}, "shimmer.modules.losses.GWLossesBase.step": {"tf": 1}, "shimmer.modules.losses.GWLosses2Domains.step": {"tf": 1}, "shimmer.modules.losses.GWLosses.step": {"tf": 1}, "shimmer.modules.losses.GWLossesBayesian.step": {"tf": 1}}, "df": 5}}}}}, "e": {"docs": {"shimmer.modules.vae.VAE.__init__": {"tf": 1.4142135623730951}}, "df": 1, "e": {"docs": {}, "df": 0, "n": {"docs": {}, "df": 0, "c": {"docs": {}, "df": 0, "o": {"docs": {}, "df": 0, "d": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "r": {"docs": {"shimmer.modules.vae.VAE.__init__": {"tf": 1}}, "df": 1}}}}}}}, "d": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "c": {"docs": {}, "df": 0, "o": {"docs": {}, "df": 0, "d": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "r": {"docs": {"shimmer.modules.vae.VAE.__init__": {"tf": 1}}, "df": 1}}}}}}}}}}, "g": {"docs": {}, "df": 0, "w": {"docs": {"shimmer.modules.global_workspace.GlobalWorkspace2Domains.__init__": {"tf": 1.4142135623730951}, 
"shimmer.modules.global_workspace.GlobalWorkspace.__init__": {"tf": 1.4142135623730951}, "shimmer.modules.global_workspace.GlobalWorkspaceBayesian.__init__": {"tf": 1.4142135623730951}, "shimmer.modules.global_workspace.pretrained_global_workspace": {"tf": 1.4142135623730951}, "shimmer.modules.gw_module.GWModule.__init__": {"tf": 1.4142135623730951}, "shimmer.modules.gw_module.GWModuleBayesian.__init__": {"tf": 1.4142135623730951}, "shimmer.modules.selection.SelectionBase.update_gw_state": {"tf": 1}, "shimmer.modules.losses.demi_cycle_loss": {"tf": 1.4142135623730951}, "shimmer.modules.losses.cycle_loss": {"tf": 1.4142135623730951}, "shimmer.modules.losses.translation_loss": {"tf": 1.4142135623730951}, "shimmer.modules.losses.contrastive_loss": {"tf": 1.4142135623730951}, "shimmer.modules.losses.contrastive_loss_bayesian": {"tf": 1.4142135623730951}, "shimmer.modules.losses.GWLosses2Domains.__init__": {"tf": 1.4142135623730951}, "shimmer.modules.losses.broadcast_loss": {"tf": 1.4142135623730951}, "shimmer.modules.losses.GWLosses.__init__": {"tf": 1.4142135623730951}, "shimmer.modules.losses.GWLossesBayesian.__init__": {"tf": 1.4142135623730951}, "shimmer.modules.utils.translation": {"tf": 1.4142135623730951}, "shimmer.modules.utils.cycle": {"tf": 1.4142135623730951}, "shimmer.modules.utils.batch_demi_cycles": {"tf": 1.4142135623730951}, "shimmer.modules.utils.batch_cycles": {"tf": 1.4142135623730951}, "shimmer.modules.utils.batch_translations": {"tf": 1.4142135623730951}}, "df": 21, "p": {"docs": {}, "df": 0, "r": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "d": {"docs": {}, "df": 0, "i": {"docs": {}, "df": 0, "c": {"docs": {}, "df": 0, "t": {"docs": {}, "df": 0, "i": {"docs": {}, "df": 0, "o": {"docs": {}, "df": 0, "n": {"docs": {}, "df": 0, "s": {"docs": {"shimmer.modules.global_workspace.GlobalWorkspace2Domains.forward": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspace.forward": {"tf": 1}, 
"shimmer.modules.global_workspace.GlobalWorkspaceBayesian.forward": {"tf": 1}}, "df": 3}}}}}}}}}}}, "m": {"docs": {}, "df": 0, "o": {"docs": {}, "df": 0, "d": {"docs": {}, "df": 0, "u": {"docs": {}, "df": 0, "l": {"docs": {}, "df": 0, "e": {"docs": {"shimmer.modules.losses.GWLosses2Domains.__init__": {"tf": 1}, "shimmer.modules.losses.GWLosses.__init__": {"tf": 1}}, "df": 2, "b": {"docs": {}, "df": 0, "a": {"docs": {}, "df": 0, "s": {"docs": {}, "df": 0, "e": {"docs": {"shimmer.modules.losses.demi_cycle_loss": {"tf": 1}, "shimmer.modules.losses.cycle_loss": {"tf": 1}, "shimmer.modules.losses.translation_loss": {"tf": 1}, "shimmer.modules.losses.contrastive_loss": {"tf": 1}, "shimmer.modules.losses.broadcast_loss": {"tf": 1}, "shimmer.modules.utils.translation": {"tf": 1}, "shimmer.modules.utils.cycle": {"tf": 1}, "shimmer.modules.utils.batch_demi_cycles": {"tf": 1}, "shimmer.modules.utils.batch_cycles": {"tf": 1}, "shimmer.modules.utils.batch_translations": {"tf": 1}}, "df": 10}}, "y": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "s": {"docs": {}, "df": 0, "i": {"docs": {}, "df": 0, "a": {"docs": {}, "df": 0, "n": {"docs": {"shimmer.modules.losses.contrastive_loss_bayesian": {"tf": 1}, "shimmer.modules.losses.GWLossesBayesian.__init__": {"tf": 1}}, "df": 2}}}}}}}}}}}}}}}, "l": {"docs": {}, "df": 0, "o": {"docs": {}, "df": 0, "b": {"docs": {}, "df": 0, "a": {"docs": {}, "df": 0, "l": {"docs": {"shimmer.modules.global_workspace.GlobalWorkspace2Domains.__init__": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspace2Domains.forward": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspace.__init__": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspace.forward": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBayesian.__init__": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBayesian.forward": {"tf": 1}, "shimmer.modules.global_workspace.pretrained_global_workspace": {"tf": 1}}, "df": 7, "w": {"docs": {}, "df": 0, "o": 
{"docs": {}, "df": 0, "r": {"docs": {}, "df": 0, "k": {"docs": {}, "df": 0, "s": {"docs": {}, "df": 0, "p": {"docs": {}, "df": 0, "a": {"docs": {}, "df": 0, "c": {"docs": {}, "df": 0, "e": {"2": {"docs": {}, "df": 0, "d": {"docs": {}, "df": 0, "o": {"docs": {}, "df": 0, "m": {"docs": {}, "df": 0, "a": {"docs": {}, "df": 0, "i": {"docs": {}, "df": 0, "n": {"docs": {}, "df": 0, "s": {"docs": {"shimmer.modules.global_workspace.pretrained_global_workspace": {"tf": 1}}, "df": 1}}}}}}}}, "docs": {}, "df": 0}}}}}}}}}}}}}}, "t": {"docs": {"shimmer.modules.domain.LossOutput.__init__": {"tf": 1}}, "df": 1}, "e": {"docs": {}, "df": 0, "n": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "r": {"docs": {}, "df": 0, "a": {"docs": {}, "df": 0, "t": {"docs": {}, "df": 0, "o": {"docs": {}, "df": 0, "r": {"docs": {"shimmer.modules.losses.generate_partitions": {"tf": 1}}, "df": 1}}}}}}}}}, "e": {"docs": {}, "df": 0, "n": {"docs": {}, "df": 0, "c": {"docs": {}, "df": 0, "o": {"docs": {}, "df": 0, "d": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "r": {"docs": {"shimmer.modules.vae.VAE.__init__": {"tf": 1}}, "df": 1, "s": {"docs": {"shimmer.modules.global_workspace.GlobalWorkspace2Domains.__init__": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspace.__init__": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBayesian.__init__": {"tf": 1}, "shimmer.modules.global_workspace.pretrained_global_workspace": {"tf": 1}, "shimmer.modules.gw_module.GWModule.__init__": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBayesian.__init__": {"tf": 1}}, "df": 6}}}, "i": {"docs": {}, "df": 0, "n": {"docs": {}, "df": 0, "g": {"docs": {}, "df": 0, "s": {"docs": {"shimmer.modules.selection.SelectionBase.forward": {"tf": 1}, "shimmer.modules.selection.SingleDomainSelection.forward": {"tf": 1}, "shimmer.modules.selection.FixedSharedSelection.forward": {"tf": 1}, "shimmer.modules.selection.KQFixedQSelection.forward": {"tf": 1}, "shimmer.modules.selection.RandomSelection.forward": 
{"tf": 1}, "shimmer.modules.selection.DynamicQueryAttention.fuse_weighted_encodings": {"tf": 1}, "shimmer.modules.selection.DynamicQueryAttention.forward": {"tf": 1}}, "df": 7}}}}}}}}, "p": {"docs": {}, "df": 0, "s": {"docs": {"shimmer.modules.gw_module.compute_fusion_scores": {"tf": 1}}, "df": 1}}}, "w": {"docs": {}, "df": 0, "o": {"docs": {}, "df": 0, "r": {"docs": {}, "df": 0, "k": {"docs": {}, "df": 0, "s": {"docs": {}, "df": 0, "p": {"docs": {}, "df": 0, "a": {"docs": {}, "df": 0, "c": {"docs": {}, "df": 0, "e": {"docs": {"shimmer.modules.global_workspace.GlobalWorkspace2Domains.__init__": {"tf": 1.4142135623730951}, "shimmer.modules.global_workspace.GlobalWorkspace2Domains.forward": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspace.__init__": {"tf": 1.4142135623730951}, "shimmer.modules.global_workspace.GlobalWorkspace.forward": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBayesian.__init__": {"tf": 1.4142135623730951}, "shimmer.modules.global_workspace.GlobalWorkspaceBayesian.forward": {"tf": 1}, "shimmer.modules.global_workspace.pretrained_global_workspace": {"tf": 1.4142135623730951}, "shimmer.modules.gw_module.GWModuleBase.__init__": {"tf": 1}, "shimmer.modules.gw_module.GWModule.__init__": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBayesian.__init__": {"tf": 1}}, "df": 10}}}}}}}}, "e": {"docs": {}, "df": 0, "i": {"docs": {}, "df": 0, "g": {"docs": {}, "df": 0, "h": {"docs": {}, "df": 0, "t": {"docs": {"shimmer.modules.global_workspace.GlobalWorkspace2Domains.__init__": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspace.__init__": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBayesian.__init__": {"tf": 1}}, "df": 3}}}}}}, "o": {"docs": {}, "df": 0, "p": {"docs": {}, "df": 0, "t": {"docs": {}, "df": 0, "i": {"docs": {}, "df": 0, "m": {"docs": {"shimmer.modules.global_workspace.GlobalWorkspace2Domains.__init__": {"tf": 1.4142135623730951}, "shimmer.modules.global_workspace.GlobalWorkspace.__init__": 
{"tf": 1.4142135623730951}, "shimmer.modules.global_workspace.GlobalWorkspaceBayesian.__init__": {"tf": 1.4142135623730951}}, "df": 3}}}}, "u": {"docs": {}, "df": 0, "t": {"docs": {"shimmer.modules.gw_module.GWDecoder.__init__": {"tf": 1}, "shimmer.modules.gw_module.GWEncoder.__init__": {"tf": 1}}, "df": 2}}, "s": {"docs": {"shimmer.utils.migrate_model": {"tf": 1}}, "df": 1}}, "p": {"docs": {}, "df": 0, "r": {"docs": {}, "df": 0, "e": {"docs": {"shimmer.modules.selection.SelectionBase.forward": {"tf": 1}, "shimmer.modules.selection.SingleDomainSelection.forward": {"tf": 1}, "shimmer.modules.selection.FixedSharedSelection.forward": {"tf": 1}, "shimmer.modules.selection.KQFixedQSelection.forward": {"tf": 1}, "shimmer.modules.selection.RandomSelection.forward": {"tf": 1}, "shimmer.modules.selection.DynamicQueryAttention.forward": {"tf": 1}}, "df": 6, "c": {"docs": {}, "df": 0, "i": {"docs": {}, "df": 0, "s": {"docs": {}, "df": 0, "i": {"docs": {}, "df": 0, "o": {"docs": {}, "df": 0, "n": {"docs": {"shimmer.modules.global_workspace.GlobalWorkspaceBayesian.__init__": {"tf": 1.4142135623730951}, "shimmer.modules.gw_module.GWModuleBayesian.__init__": {"tf": 1.4142135623730951}}, "df": 2}}}}}}, "d": {"docs": {"shimmer.modules.domain.DomainModule.compute_loss": {"tf": 1}, "shimmer.modules.domain.DomainModule.compute_dcy_loss": {"tf": 1}, "shimmer.modules.domain.DomainModule.compute_cy_loss": {"tf": 1}, "shimmer.modules.domain.DomainModule.compute_tr_loss": {"tf": 1}, "shimmer.modules.domain.DomainModule.compute_broadcast_loss": {"tf": 1}}, "df": 5}}}, "a": {"docs": {}, "df": 0, "t": {"docs": {}, "df": 0, "h": {"docs": {"shimmer.modules.global_workspace.pretrained_global_workspace": {"tf": 1.4142135623730951}, "shimmer.utils.migrate_model": {"tf": 1}}, "df": 2, "l": {"docs": {}, "df": 0, "i": {"docs": {}, "df": 0, "b": {"docs": {"shimmer.modules.global_workspace.pretrained_global_workspace": {"tf": 1}}, "df": 1}, "k": {"docs": {}, "df": 0, "e": {"docs": 
{"shimmer.utils.migrate_model": {"tf": 1}}, "df": 1}}}}}}}, "y": {"docs": {}, "df": 0, "t": {"docs": {}, "df": 0, "o": {"docs": {}, "df": 0, "r": {"docs": {}, "df": 0, "c": {"docs": {}, "df": 0, "h": {"docs": {"shimmer.utils.SaveMigrations.on_save_checkpoint": {"tf": 1.4142135623730951}}, "df": 1}}}}}}, "l": {"docs": {"shimmer.utils.SaveMigrations.on_save_checkpoint": {"tf": 1}}, "df": 1}}, "u": {"docs": {}, "df": 0, "s": {"docs": {}, "df": 0, "e": {"docs": {"shimmer.modules.global_workspace.GlobalWorkspaceBayesian.__init__": {"tf": 1}, "shimmer.modules.losses.GWLossesBayesian.__init__": {"tf": 1}}, "df": 2}}}, "k": {"docs": {}, "df": 0, "w": {"docs": {}, "df": 0, "a": {"docs": {}, "df": 0, "r": {"docs": {}, "df": 0, "g": {"docs": {}, "df": 0, "s": {"docs": {"shimmer.modules.global_workspace.pretrained_global_workspace": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBase.__init__": {"tf": 1}, "shimmer.utils.migrate_model": {"tf": 1}}, "df": 3}}}}}}, "h": {"docs": {}, "df": 0, "i": {"docs": {}, "df": 0, "d": {"docs": {}, "df": 0, "d": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "n": {"docs": {"shimmer.modules.gw_module.get_n_layers": {"tf": 1}, "shimmer.modules.gw_module.GWDecoder.__init__": {"tf": 1}, "shimmer.modules.gw_module.GWEncoder.__init__": {"tf": 1}}, "df": 3}}}}}, "e": {"docs": {}, "df": 0, "a": {"docs": {}, "df": 0, "d": {"docs": {"shimmer.modules.selection.KQFixedQSelection.__init__": {"tf": 1}, "shimmer.modules.selection.DynamicQueryAttention.__init__": {"tf": 1}}, "df": 2}}}}, "y": {"docs": {"shimmer.modules.contrastive_loss.info_nce": {"tf": 1}, "shimmer.modules.contrastive_loss.contrastive_loss": {"tf": 1}, "shimmer.modules.contrastive_loss.ContrastiveLoss.forward": {"tf": 1}}, "df": 3}, "r": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "d": {"docs": {}, "df": 0, "u": {"docs": {}, "df": 0, "c": {"docs": {}, "df": 0, "t": {"docs": {}, "df": 0, "i": {"docs": {}, "df": 0, "o": {"docs": {}, "df": 0, "n": {"docs": 
{"shimmer.modules.contrastive_loss.info_nce": {"tf": 1}, "shimmer.modules.contrastive_loss.contrastive_loss": {"tf": 1}, "shimmer.modules.contrastive_loss.ContrastiveLoss.__init__": {"tf": 1}}, "df": 3}}}}}}}}}}}, "bases": {"root": {"docs": {"shimmer.modules.global_workspace.GlobalWorkspaceBase": {"tf": 1.4142135623730951}}, "df": 1, "t": {"docs": {"shimmer.modules.global_workspace.GlobalWorkspaceBase": {"tf": 1.7320508075688772}}, "df": 1, "y": {"docs": {}, "df": 0, "p": {"docs": {}, "df": 0, "i": {"docs": {}, "df": 0, "n": {"docs": {}, "df": 0, "g": {"docs": {"shimmer.modules.global_workspace.SchedulerArgs": {"tf": 1}, "shimmer.modules.global_workspace.GWPredictionsBase": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBase": {"tf": 1}, "shimmer.modules.losses.LossCoefs": {"tf": 1}, "shimmer.modules.losses.BroadcastLossCoefs": {"tf": 1}, "shimmer.dataset.RepeatedDataset": {"tf": 1}}, "df": 6}}}, "e": {"docs": {}, "df": 0, "d": {"docs": {}, "df": 0, "d": {"docs": {}, "df": 0, "i": {"docs": {}, "df": 0, "c": {"docs": {}, "df": 0, "t": {"docs": {"shimmer.modules.global_workspace.SchedulerArgs": {"tf": 1}, "shimmer.modules.global_workspace.GWPredictionsBase": {"tf": 1}, "shimmer.modules.losses.LossCoefs": {"tf": 1}, "shimmer.modules.losses.BroadcastLossCoefs": {"tf": 1}}, "df": 4}}}}}}}}, "o": {"docs": {}, "df": 0, "r": {"docs": {}, "df": 0, "c": {"docs": {}, "df": 0, "h": {"docs": {"shimmer.modules.gw_module.GWDecoder": {"tf": 1}, "shimmer.modules.gw_module.GWEncoderLinear": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBase": {"tf": 1}, "shimmer.modules.selection.SelectionBase": {"tf": 1}, "shimmer.modules.losses.GWLossesBase": {"tf": 1}, "shimmer.modules.contrastive_loss.ContrastiveLoss": {"tf": 1}, "shimmer.modules.vae.VAEEncoder": {"tf": 1}, "shimmer.modules.vae.VAEDecoder": {"tf": 1}, "shimmer.modules.vae.VAE": {"tf": 1}}, "df": 9}}}}}, "g": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "n": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "r": 
{"docs": {}, "df": 0, "i": {"docs": {}, "df": 0, "c": {"docs": {"shimmer.modules.global_workspace.GlobalWorkspaceBase": {"tf": 1}}, "df": 1, "[": {"docs": {}, "df": 0, "+": {"docs": {}, "df": 0, "t": {"docs": {"shimmer.dataset.RepeatedDataset": {"tf": 1}}, "df": 1}}}}}}}}}, "w": {"docs": {"shimmer.modules.global_workspace.GlobalWorkspaceBase": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspace2Domains": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspace": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBayesian": {"tf": 1}}, "df": 4, "m": {"docs": {}, "df": 0, "o": {"docs": {}, "df": 0, "d": {"docs": {}, "df": 0, "u": {"docs": {}, "df": 0, "l": {"docs": {}, "df": 0, "e": {"docs": {"shimmer.modules.global_workspace.GlobalWorkspace2Domains": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspace": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBayesian": {"tf": 1}}, "df": 3, "b": {"docs": {}, "df": 0, "a": {"docs": {}, "df": 0, "y": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "s": {"docs": {}, "df": 0, "i": {"docs": {}, "df": 0, "a": {"docs": {}, "df": 0, "n": {"docs": {"shimmer.modules.global_workspace.GlobalWorkspaceBayesian": {"tf": 1}}, "df": 1}}}}}}, "s": {"docs": {}, "df": 0, "e": {"docs": {"shimmer.modules.gw_module.GWModule": {"tf": 1}}, "df": 1}}}}}}}}}}, "l": {"docs": {}, "df": 0, "o": {"docs": {}, "df": 0, "s": {"docs": {}, "df": 0, "s": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "s": {"2": {"docs": {}, "df": 0, "d": {"docs": {}, "df": 0, "o": {"docs": {}, "df": 0, "m": {"docs": {}, "df": 0, "a": {"docs": {}, "df": 0, "i": {"docs": {}, "df": 0, "n": {"docs": {}, "df": 0, "s": {"docs": {"shimmer.modules.global_workspace.GlobalWorkspace2Domains": {"tf": 1}}, "df": 1}}}}}}}}, "docs": {"shimmer.modules.global_workspace.GlobalWorkspace": {"tf": 1}}, "df": 1, "b": {"docs": {}, "df": 0, "a": {"docs": {}, "df": 0, "y": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "s": {"docs": {}, "df": 0, "i": {"docs": {}, "df": 
0, "a": {"docs": {}, "df": 0, "n": {"docs": {"shimmer.modules.global_workspace.GlobalWorkspaceBayesian": {"tf": 1}}, "df": 1}}}}}}, "s": {"docs": {}, "df": 0, "e": {"docs": {"shimmer.modules.losses.GWLosses2Domains": {"tf": 1}, "shimmer.modules.losses.GWLosses": {"tf": 1}, "shimmer.modules.losses.GWLossesBayesian": {"tf": 1}}, "df": 3}}}}}}}}}}, "d": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "c": {"docs": {}, "df": 0, "o": {"docs": {}, "df": 0, "d": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "r": {"docs": {"shimmer.modules.gw_module.GWEncoder": {"tf": 1}}, "df": 1}}}}}}}}, "l": {"docs": {}, "df": 0, "o": {"docs": {}, "df": 0, "b": {"docs": {}, "df": 0, "a": {"docs": {}, "df": 0, "l": {"docs": {"shimmer.modules.global_workspace.GlobalWorkspace2Domains": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspace": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBayesian": {"tf": 1}}, "df": 3, "w": {"docs": {}, "df": 0, "o": {"docs": {}, "df": 0, "r": {"docs": {}, "df": 0, "k": {"docs": {}, "df": 0, "s": {"docs": {}, "df": 0, "p": {"docs": {}, "df": 0, "a": {"docs": {}, "df": 0, "c": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "b": {"docs": {}, "df": 0, "a": {"docs": {}, "df": 0, "s": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "[": {"docs": {}, "df": 0, "s": {"docs": {}, "df": 0, "h": {"docs": {}, "df": 0, "i": {"docs": {}, "df": 0, "m": {"docs": {}, "df": 0, "m": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "r": {"docs": {"shimmer.modules.global_workspace.GlobalWorkspace2Domains": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspace": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBayesian": {"tf": 1}}, "df": 3}}}}}}}}}}}}}}}}}}}}}}}}}}}, "m": {"docs": {}, "df": 0, "o": {"docs": {}, "df": 0, "d": {"docs": {"shimmer.modules.global_workspace.GlobalWorkspaceBase": {"tf": 1.7320508075688772}}, "df": 1, "u": {"docs": {}, "df": 0, "l": {"docs": {}, "df": 0, "e": {"docs": 
{"shimmer.modules.global_workspace.GlobalWorkspaceBase": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspace2Domains": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspace": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBayesian": {"tf": 1}, "shimmer.modules.domain.DomainModule": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBase": {"tf": 1.4142135623730951}, "shimmer.modules.selection.SelectionBase": {"tf": 1.4142135623730951}, "shimmer.modules.losses.GWLossesBase": {"tf": 1.4142135623730951}, "shimmer.modules.contrastive_loss.ContrastiveLoss": {"tf": 1.4142135623730951}, "shimmer.modules.vae.VAEEncoder": {"tf": 1.4142135623730951}, "shimmer.modules.vae.VAEDecoder": {"tf": 1.4142135623730951}, "shimmer.modules.vae.VAE": {"tf": 1.4142135623730951}}, "df": 12, "s": {"docs": {"shimmer.modules.global_workspace.GlobalWorkspace2Domains": {"tf": 2}, "shimmer.modules.global_workspace.GlobalWorkspace": {"tf": 2}, "shimmer.modules.global_workspace.GlobalWorkspaceBayesian": {"tf": 2}, "shimmer.modules.gw_module.GWDecoder": {"tf": 1}, "shimmer.modules.gw_module.GWEncoderLinear": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBase": {"tf": 1}, "shimmer.modules.selection.SelectionBase": {"tf": 1}, "shimmer.modules.losses.GWLossesBase": {"tf": 1}, "shimmer.modules.contrastive_loss.ContrastiveLoss": {"tf": 1}, "shimmer.modules.vae.VAEEncoder": {"tf": 1}, "shimmer.modules.vae.VAEDecoder": {"tf": 1}, "shimmer.modules.vae.VAE": {"tf": 1}}, "df": 12}}}}}}}, "s": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "l": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "c": {"docs": {}, "df": 0, "t": {"docs": {}, "df": 0, "i": {"docs": {}, "df": 0, "o": {"docs": {}, "df": 0, "n": {"docs": {"shimmer.modules.global_workspace.GlobalWorkspaceBase": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspace2Domains": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspace": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBayesian": {"tf": 1}}, 
"df": 4, "b": {"docs": {}, "df": 0, "a": {"docs": {}, "df": 0, "s": {"docs": {}, "df": 0, "e": {"docs": {"shimmer.modules.selection.SingleDomainSelection": {"tf": 1}, "shimmer.modules.selection.FixedSharedSelection": {"tf": 1}, "shimmer.modules.selection.KQFixedQSelection": {"tf": 1}, "shimmer.modules.selection.RandomSelection": {"tf": 1}, "shimmer.modules.selection.DynamicQueryAttention": {"tf": 1}}, "df": 5}}}}}}}}}}}, "q": {"docs": {}, "df": 0, "u": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "n": {"docs": {}, "df": 0, "t": {"docs": {}, "df": 0, "i": {"docs": {}, "df": 0, "a": {"docs": {}, "df": 0, "l": {"docs": {"shimmer.modules.gw_module.GWDecoder": {"tf": 1}}, "df": 1}}}}}}}}}, "h": {"docs": {}, "df": 0, "i": {"docs": {}, "df": 0, "m": {"docs": {}, "df": 0, "m": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "r": {"docs": {"shimmer.modules.global_workspace.GlobalWorkspace2Domains": {"tf": 1.7320508075688772}, "shimmer.modules.global_workspace.GlobalWorkspace": {"tf": 1.7320508075688772}, "shimmer.modules.global_workspace.GlobalWorkspaceBayesian": {"tf": 1.7320508075688772}}, "df": 3}}}}}}, "i": {"docs": {}, "df": 0, "n": {"docs": {}, "df": 0, "g": {"docs": {}, "df": 0, "l": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "d": {"docs": {}, "df": 0, "o": {"docs": {}, "df": 0, "m": {"docs": {}, "df": 0, "a": {"docs": {}, "df": 0, "i": {"docs": {}, "df": 0, "n": {"docs": {}, "df": 0, "s": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "l": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "c": {"docs": {}, "df": 0, "t": {"docs": {}, "df": 0, "i": {"docs": {}, "df": 0, "o": {"docs": {}, "df": 0, "n": {"docs": {"shimmer.modules.global_workspace.GlobalWorkspace2Domains": {"tf": 1}}, "df": 1}}}}}}}}}}}}}}}}}}}}}, "l": {"docs": {}, "df": 0, "o": {"docs": {}, "df": 0, "s": {"docs": {}, "df": 0, "s": {"docs": {"shimmer.modules.global_workspace.GlobalWorkspaceBase": {"tf": 1}}, "df": 1, "e": {"docs": {}, "df": 0, "s": {"docs": 
{"shimmer.modules.global_workspace.GlobalWorkspace2Domains": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspace": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBayesian": {"tf": 1}}, "df": 3}}}}}, "i": {"docs": {}, "df": 0, "g": {"docs": {}, "df": 0, "h": {"docs": {}, "df": 0, "t": {"docs": {}, "df": 0, "n": {"docs": {}, "df": 0, "i": {"docs": {}, "df": 0, "n": {"docs": {}, "df": 0, "g": {"docs": {"shimmer.modules.global_workspace.GlobalWorkspaceBase": {"tf": 1}, "shimmer.modules.domain.DomainModule": {"tf": 1}, "shimmer.utils.SaveMigrations": {"tf": 1}}, "df": 3, "m": {"docs": {}, "df": 0, "o": {"docs": {}, "df": 0, "d": {"docs": {}, "df": 0, "u": {"docs": {}, "df": 0, "l": {"docs": {}, "df": 0, "e": {"docs": {"shimmer.modules.global_workspace.GlobalWorkspaceBase": {"tf": 1}, "shimmer.modules.domain.DomainModule": {"tf": 1}}, "df": 2}}}}}}}}}}}}}, "n": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "a": {"docs": {}, "df": 0, "r": {"docs": {"shimmer.modules.gw_module.GWEncoderLinear": {"tf": 1.4142135623730951}}, "df": 1}}}}}}, "p": {"docs": {}, "df": 0, "y": {"docs": {}, "df": 0, "t": {"docs": {}, "df": 0, "o": {"docs": {}, "df": 0, "r": {"docs": {}, "df": 0, "c": {"docs": {}, "df": 0, "h": {"docs": {"shimmer.modules.global_workspace.GlobalWorkspaceBase": {"tf": 1}, "shimmer.modules.domain.DomainModule": {"tf": 1}, "shimmer.utils.SaveMigrations": {"tf": 1}}, "df": 3}}}}}}}, "c": {"docs": {}, "df": 0, "o": {"docs": {"shimmer.dataset.RepeatedDataset": {"tf": 1}}, "df": 1, "r": {"docs": {}, "df": 0, "e": {"docs": {"shimmer.modules.global_workspace.GlobalWorkspaceBase": {"tf": 1}, "shimmer.modules.domain.DomainModule": {"tf": 1}}, "df": 2}}, "n": {"docs": {}, "df": 0, "t": {"docs": {}, "df": 0, "a": {"docs": {}, "df": 0, "i": {"docs": {}, "df": 0, "n": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "r": {"docs": {"shimmer.modules.gw_module.GWDecoder": {"tf": 1}}, "df": 1}}}}}}}}, "a": {"docs": {}, "df": 0, "l": {"docs": {}, "df": 0, "l": 
{"docs": {}, "df": 0, "b": {"docs": {}, "df": 0, "a": {"docs": {}, "df": 0, "c": {"docs": {}, "df": 0, "k": {"docs": {"shimmer.utils.SaveMigrations": {"tf": 1.4142135623730951}}, "df": 1, "s": {"docs": {"shimmer.utils.SaveMigrations": {"tf": 1}}, "df": 1}}}}}}}}}, "b": {"docs": {}, "df": 0, "u": {"docs": {}, "df": 0, "i": {"docs": {}, "df": 0, "l": {"docs": {}, "df": 0, "t": {"docs": {}, "df": 0, "i": {"docs": {}, "df": 0, "n": {"docs": {}, "df": 0, "s": {"docs": {"shimmer.modules.global_workspace.GWPredictions": {"tf": 1}}, "df": 1}}}}}}}}, "d": {"docs": {}, "df": 0, "i": {"docs": {}, "df": 0, "c": {"docs": {}, "df": 0, "t": {"docs": {"shimmer.modules.global_workspace.GWPredictions": {"tf": 1}}, "df": 1}}}}, "w": {"docs": {}, "df": 0, "o": {"docs": {}, "df": 0, "r": {"docs": {}, "df": 0, "k": {"docs": {}, "df": 0, "s": {"docs": {}, "df": 0, "p": {"docs": {}, "df": 0, "a": {"docs": {}, "df": 0, "c": {"docs": {}, "df": 0, "e": {"docs": {"shimmer.modules.global_workspace.GlobalWorkspace2Domains": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspace": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBayesian": {"tf": 1}}, "df": 3}}}}}}}}}, "r": {"docs": {}, "df": 0, "a": {"docs": {}, "df": 0, "n": {"docs": {}, "df": 0, "d": {"docs": {}, "df": 0, "o": {"docs": {}, "df": 0, "m": {"docs": {}, "df": 0, "s": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "l": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "c": {"docs": {}, "df": 0, "t": {"docs": {}, "df": 0, "i": {"docs": {}, "df": 0, "o": {"docs": {}, "df": 0, "n": {"docs": {"shimmer.modules.global_workspace.GlobalWorkspace": {"tf": 1}}, "df": 1}}}}}}}}}}}}}}}, "f": {"docs": {}, "df": 0, "i": {"docs": {}, "df": 0, "x": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "d": {"docs": {}, "df": 0, "s": {"docs": {}, "df": 0, "h": {"docs": {}, "df": 0, "a": {"docs": {}, "df": 0, "r": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "d": {"docs": {}, "df": 0, "s": {"docs": {}, "df": 0, "e": {"docs": {}, 
"df": 0, "l": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "c": {"docs": {}, "df": 0, "t": {"docs": {}, "df": 0, "i": {"docs": {}, "df": 0, "o": {"docs": {}, "df": 0, "n": {"docs": {"shimmer.modules.global_workspace.GlobalWorkspaceBayesian": {"tf": 1}}, "df": 1}}}}}}}}}}}}}}}}}}}}, "n": {"docs": {}, "df": 0, "n": {"docs": {"shimmer.modules.gw_module.GWDecoder": {"tf": 1}, "shimmer.modules.gw_module.GWEncoderLinear": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBase": {"tf": 1}, "shimmer.modules.selection.SelectionBase": {"tf": 1}, "shimmer.modules.losses.GWLossesBase": {"tf": 1}, "shimmer.modules.contrastive_loss.ContrastiveLoss": {"tf": 1}, "shimmer.modules.vae.VAEEncoder": {"tf": 1}, "shimmer.modules.vae.VAEDecoder": {"tf": 1}, "shimmer.modules.vae.VAE": {"tf": 1}}, "df": 9}}, "a": {"docs": {}, "df": 0, "b": {"docs": {}, "df": 0, "c": {"docs": {"shimmer.modules.gw_module.GWModuleBase": {"tf": 1.4142135623730951}, "shimmer.modules.selection.SelectionBase": {"tf": 1.4142135623730951}, "shimmer.modules.losses.GWLossesBase": {"tf": 1.4142135623730951}, "shimmer.modules.vae.VAEEncoder": {"tf": 1.4142135623730951}, "shimmer.modules.vae.VAEDecoder": {"tf": 1.4142135623730951}}, "df": 5}}}}}, "doc": {"root": {"0": {"docs": {"shimmer.types.LatentsDomainGroupT": {"tf": 3}, "shimmer.types.LatentsDomainGroupDT": {"tf": 3}, "shimmer.types.LatentsDomainGroupsT": {"tf": 4}, "shimmer.types.LatentsDomainGroupsDT": {"tf": 4}, "shimmer.modules.gw_module.GWModuleBayesian.fuse": {"tf": 1.4142135623730951}, "shimmer.modules.selection.SelectionBase.forward": {"tf": 2.8284271247461903}, "shimmer.modules.selection.FixedSharedSelection": {"tf": 1}, "shimmer.modules.losses.LossCoefs": {"tf": 1.4142135623730951}, "shimmer.modules.losses.BroadcastLossCoefs": {"tf": 1.4142135623730951}}, "df": 9}, "1": {"docs": {"shimmer.types.LatentsDomainGroupT": {"tf": 1}, "shimmer.types.LatentsDomainGroupDT": {"tf": 1}, "shimmer.types.LatentsDomainGroupsT": {"tf": 1.7320508075688772}, 
"shimmer.types.LatentsDomainGroupsDT": {"tf": 1.7320508075688772}, "shimmer.modules.global_workspace.GlobalWorkspaceBayesian.__init__": {"tf": 1}, "shimmer.modules.gw_module.compute_fusion_scores": {"tf": 1.4142135623730951}, "shimmer.modules.gw_module.GWModuleBayesian.__init__": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBayesian.fuse": {"tf": 2.23606797749979}, "shimmer.modules.selection.SelectionBase.forward": {"tf": 1.4142135623730951}, "shimmer.modules.selection.SingleDomainSelection": {"tf": 1}, "shimmer.modules.losses.contrastive_loss": {"tf": 1.7320508075688772}, "shimmer.modules.losses.contrastive_loss_bayesian": {"tf": 1.7320508075688772}, "shimmer.modules.vae.VAE.__init__": {"tf": 1}, "shimmer.cli.ckpt_migration.migrate_ckpt": {"tf": 1}}, "df": 14, "]": {"docs": {}, "df": 0, "^": {"docs": {}, "df": 0, "{": {"docs": {}, "df": 0, "d": {"docs": {}, "df": 0, "\\": {"docs": {}, "df": 0, "t": {"docs": {}, "df": 0, "i": {"docs": {}, "df": 0, "m": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "s": {"docs": {"shimmer.modules.gw_module.GWModuleBayesian.fuse": {"tf": 1.4142135623730951}}, "df": 1}}}}}}}}}}, "}": {"docs": {}, "df": 0, "{": {"docs": {}, "df": 0, "\\": {"docs": {}, "df": 0, "s": {"docs": {}, "df": 0, "i": {"docs": {}, "df": 0, "g": {"docs": {}, "df": 0, "m": {"docs": {}, "df": 0, "a": {"docs": {"shimmer.modules.gw_module.GWModuleBayesian.fuse": {"tf": 1}}, "df": 1}}}}}}, "a": {"docs": {"shimmer.modules.gw_module.GWModuleBayesian.fuse": {"tf": 1.7320508075688772}}, "df": 1}}, "^": {"docs": {}, "df": 0, "d": {"docs": {"shimmer.modules.gw_module.GWModuleBayesian.fuse": {"tf": 1.7320508075688772}}, "df": 1}}}, "/": {"3": {"docs": {"shimmer.modules.selection.FixedSharedSelection": {"tf": 1}}, "df": 1}, "docs": {}, "df": 0}}, "2": {"docs": {"shimmer.types.LatentsDomainGroupT": {"tf": 1}, "shimmer.types.LatentsDomainGroupDT": {"tf": 1}, "shimmer.types.LatentsDomainGroupsT": {"tf": 1.4142135623730951}, "shimmer.types.LatentsDomainGroupsDT": {"tf": 
1.4142135623730951}, "shimmer.modules.global_workspace.GlobalWorkspace2Domains": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspace": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBayesian": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBayesian.__init__": {"tf": 1}, "shimmer.modules.gw_module.GWDecoder.__init__": {"tf": 1}, "shimmer.modules.gw_module.GWDecoder.n_layers": {"tf": 1}, "shimmer.modules.gw_module.GWEncoder.__init__": {"tf": 1}, "shimmer.modules.gw_module.compute_fusion_scores": {"tf": 1.4142135623730951}, "shimmer.modules.gw_module.GWModuleBayesian.__init__": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBayesian.fuse": {"tf": 2.23606797749979}, "shimmer.modules.selection.SingleDomainSelection": {"tf": 1}, "shimmer.modules.selection.FixedSharedSelection": {"tf": 1}, "shimmer.modules.losses.contrastive_loss": {"tf": 2}, "shimmer.modules.losses.contrastive_loss_bayesian": {"tf": 2}, "shimmer.cli.ckpt_migration.migrate_ckpt": {"tf": 1}}, "df": 19, "}": {"docs": {}, "df": 0, "{": {"docs": {}, "df": 0, "\\": {"docs": {}, "df": 0, "m": {"docs": {}, "df": 0, "u": {"docs": {"shimmer.modules.gw_module.GWModuleBayesian.fuse": {"tf": 1}}, "df": 1}}}, "b": {"docs": {"shimmer.modules.gw_module.GWModuleBayesian.fuse": {"tf": 1.7320508075688772}}, "df": 1}}}}, "3": {"docs": {"shimmer.types.LatentsDomainGroupT": {"tf": 1}, "shimmer.types.LatentsDomainGroupDT": {"tf": 1}, "shimmer.types.LatentsDomainGroupsT": {"tf": 1.4142135623730951}, "shimmer.types.LatentsDomainGroupsDT": {"tf": 1.4142135623730951}, "shimmer.modules.gw_module.GWModuleBayesian.fuse": {"tf": 1}, "shimmer.modules.selection.SelectionBase.forward": {"tf": 1.4142135623730951}, "shimmer.modules.selection.FixedSharedSelection": {"tf": 1}}, "df": 7, "}": {"docs": {}, "df": 0, "{": {"docs": {}, "df": 0, "\\": {"docs": {}, "df": 0, "f": {"docs": {}, "df": 0, "r": {"docs": {}, "df": 0, "a": {"docs": {}, "df": 0, "c": {"docs": {}, "df": 0, "{": {"docs": {}, "df": 0, "c": 
{"docs": {"shimmer.modules.gw_module.GWModuleBayesian.fuse": {"tf": 1}}, "df": 1}}}}}}}}}}, "4": {"docs": {"shimmer.modules.selection.SelectionBase.forward": {"tf": 1.4142135623730951}}, "df": 1}, "5": {"0": {"docs": {"shimmer.modules.selection.SingleDomainSelection": {"tf": 1}}, "df": 1}, "docs": {"shimmer.modules.selection.FixedSharedSelection": {"tf": 1}}, "df": 1}, "6": {"docs": {"shimmer.modules.selection.SelectionBase.forward": {"tf": 1}}, "df": 1}, "8": {"docs": {"shimmer.modules.selection.SelectionBase.forward": {"tf": 1}}, "df": 1}, "9": {"docs": {"shimmer.types.LatentsDomainGroupsT": {"tf": 1}, "shimmer.types.LatentsDomainGroupsDT": {"tf": 1}}, "df": 2}, "docs": {"shimmer.types": {"tf": 1.7320508075688772}, "shimmer.types.RawDomainGroupT": {"tf": 10.816653826391969}, "shimmer.types.RawDomainGroupDT": {"tf": 9.746794344808963}, "shimmer.types.LatentsDomainGroupT": {"tf": 12.24744871391589}, "shimmer.types.LatentsDomainGroupDT": {"tf": 11.357816691600547}, "shimmer.types.RawDomainGroupsT": {"tf": 13.92838827718412}, "shimmer.types.RawDomainGroupsDT": {"tf": 13.92838827718412}, "shimmer.types.LatentsDomainGroupsT": {"tf": 16.822603841260722}, "shimmer.types.LatentsDomainGroupsDT": {"tf": 16.186414056238647}, "shimmer.types.ModelModeT": {"tf": 2.449489742783178}, "shimmer.modules.global_workspace": {"tf": 1.7320508075688772}, "shimmer.modules.global_workspace.SchedulerArgs": {"tf": 1.4142135623730951}, "shimmer.modules.global_workspace.SchedulerArgs.max_lr": {"tf": 1.4142135623730951}, "shimmer.modules.global_workspace.SchedulerArgs.total_steps": {"tf": 1.4142135623730951}, "shimmer.modules.global_workspace.GWPredictionsBase": {"tf": 2}, "shimmer.modules.global_workspace.GWPredictionsBase.states": {"tf": 1.7320508075688772}, "shimmer.modules.global_workspace.GlobalWorkspaceBase": {"tf": 2.449489742783178}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.gw_mod": {"tf": 2.23606797749979}, 
"shimmer.modules.global_workspace.GlobalWorkspaceBase.selection_mod": {"tf": 2.23606797749979}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.loss_mod": {"tf": 1.4142135623730951}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.optim_lr": {"tf": 1.7320508075688772}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.optim_weight_decay": {"tf": 1.7320508075688772}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.scheduler_args": {"tf": 1.7320508075688772}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.domain_mods": {"tf": 1.7320508075688772}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.workspace_dim": {"tf": 1.7320508075688772}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.encode_and_fuse": {"tf": 6}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.encode": {"tf": 5.291502622129181}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.fuse": {"tf": 5.744562646538029}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.decode": {"tf": 6.164414002968976}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.batch_gw_states": {"tf": 5.196152422706632}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.encode_domain": {"tf": 8.426149773176359}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.encode_domains": {"tf": 5.291502622129181}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.decode_domain": {"tf": 8.426149773176359}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.decode_domains": {"tf": 5.291502622129181}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.generic_step": {"tf": 6.48074069840786}, "shimmer.modules.global_workspace.freeze_domain_modules": {"tf": 6.164414002968976}, "shimmer.modules.global_workspace.GWPredictions": {"tf": 2}, "shimmer.modules.global_workspace.GWPredictions.demi_cycles": {"tf": 1.7320508075688772}, "shimmer.modules.global_workspace.GWPredictions.cycles": {"tf": 1.7320508075688772}, "shimmer.modules.global_workspace.GWPredictions.translations": 
{"tf": 2.449489742783178}, "shimmer.modules.global_workspace.GWPredictions.states": {"tf": 1.7320508075688772}, "shimmer.modules.global_workspace.GlobalWorkspace2Domains": {"tf": 3.1622776601683795}, "shimmer.modules.global_workspace.GlobalWorkspace2Domains.__init__": {"tf": 9.695359714832659}, "shimmer.modules.global_workspace.GlobalWorkspace2Domains.forward": {"tf": 5.291502622129181}, "shimmer.modules.global_workspace.GlobalWorkspace": {"tf": 3.1622776601683795}, "shimmer.modules.global_workspace.GlobalWorkspace.__init__": {"tf": 10.14889156509222}, "shimmer.modules.global_workspace.GlobalWorkspace.forward": {"tf": 5.291502622129181}, "shimmer.modules.global_workspace.GlobalWorkspaceBayesian": {"tf": 3.1622776601683795}, "shimmer.modules.global_workspace.GlobalWorkspaceBayesian.__init__": {"tf": 11.045361017187261}, "shimmer.modules.global_workspace.GlobalWorkspaceBayesian.forward": {"tf": 5.291502622129181}, "shimmer.modules.global_workspace.pretrained_global_workspace": {"tf": 10.535653752852738}, "shimmer.modules.domain": {"tf": 1.7320508075688772}, "shimmer.modules.domain.LossOutput": {"tf": 2.6457513110645907}, "shimmer.modules.domain.LossOutput.__init__": {"tf": 1.7320508075688772}, "shimmer.modules.domain.LossOutput.loss": {"tf": 1.7320508075688772}, "shimmer.modules.domain.LossOutput.metrics": {"tf": 1.7320508075688772}, "shimmer.modules.domain.LossOutput.all": {"tf": 1.7320508075688772}, "shimmer.modules.domain.DomainModule": {"tf": 1.7320508075688772}, "shimmer.modules.domain.DomainModule.__init__": {"tf": 4}, "shimmer.modules.domain.DomainModule.latent_dim": {"tf": 1.7320508075688772}, "shimmer.modules.domain.DomainModule.encode": {"tf": 5.291502622129181}, "shimmer.modules.domain.DomainModule.decode": {"tf": 5.291502622129181}, "shimmer.modules.domain.DomainModule.compute_loss": {"tf": 5.830951894845301}, "shimmer.modules.domain.DomainModule.compute_dcy_loss": {"tf": 5.830951894845301}, "shimmer.modules.domain.DomainModule.compute_cy_loss": {"tf": 
5.830951894845301}, "shimmer.modules.domain.DomainModule.compute_tr_loss": {"tf": 5.830951894845301}, "shimmer.modules.domain.DomainModule.compute_broadcast_loss": {"tf": 5.830951894845301}, "shimmer.modules.gw_module": {"tf": 1.7320508075688772}, "shimmer.modules.gw_module.get_n_layers": {"tf": 6.324555320336759}, "shimmer.modules.gw_module.GWDecoder": {"tf": 1.7320508075688772}, "shimmer.modules.gw_module.GWDecoder.__init__": {"tf": 6.4031242374328485}, "shimmer.modules.gw_module.GWDecoder.in_dim": {"tf": 1.4142135623730951}, "shimmer.modules.gw_module.GWDecoder.hidden_dim": {"tf": 1.4142135623730951}, "shimmer.modules.gw_module.GWDecoder.out_dim": {"tf": 1.4142135623730951}, "shimmer.modules.gw_module.GWDecoder.n_layers": {"tf": 2.449489742783178}, "shimmer.modules.gw_module.GWEncoder": {"tf": 2.449489742783178}, "shimmer.modules.gw_module.GWEncoder.__init__": {"tf": 6.4031242374328485}, "shimmer.modules.gw_module.GWEncoder.forward": {"tf": 3.872983346207417}, "shimmer.modules.gw_module.GWEncoderLinear": {"tf": 1.7320508075688772}, "shimmer.modules.gw_module.GWEncoderLinear.forward": {"tf": 3.872983346207417}, "shimmer.modules.gw_module.GWModuleBase": {"tf": 3.7416573867739413}, "shimmer.modules.gw_module.GWModuleBase.__init__": {"tf": 5}, "shimmer.modules.gw_module.GWModuleBase.domain_mods": {"tf": 1.7320508075688772}, "shimmer.modules.gw_module.GWModuleBase.workspace_dim": {"tf": 1.4142135623730951}, "shimmer.modules.gw_module.GWModuleBase.fuse": {"tf": 6}, "shimmer.modules.gw_module.GWModuleBase.encode": {"tf": 5.0990195135927845}, "shimmer.modules.gw_module.GWModuleBase.encode_and_fuse": {"tf": 5.916079783099616}, "shimmer.modules.gw_module.GWModuleBase.decode": {"tf": 6.164414002968976}, "shimmer.modules.gw_module.GWModule": {"tf": 2.23606797749979}, "shimmer.modules.gw_module.GWModule.__init__": {"tf": 6.4031242374328485}, "shimmer.modules.gw_module.GWModule.gw_encoders": {"tf": 1.4142135623730951}, "shimmer.modules.gw_module.GWModule.gw_decoders": {"tf": 
1.4142135623730951}, "shimmer.modules.gw_module.GWModule.fuse": {"tf": 6}, "shimmer.modules.gw_module.GWModule.encode": {"tf": 5.196152422706632}, "shimmer.modules.gw_module.GWModule.decode": {"tf": 6.082762530298219}, "shimmer.modules.gw_module.compute_fusion_scores": {"tf": 7.681145747868608}, "shimmer.modules.gw_module.GWModuleBayesian": {"tf": 2.23606797749979}, "shimmer.modules.gw_module.GWModuleBayesian.__init__": {"tf": 7.874007874011811}, "shimmer.modules.gw_module.GWModuleBayesian.precisions": {"tf": 1.7320508075688772}, "shimmer.modules.gw_module.GWModuleBayesian.sensitivity_selection": {"tf": 1.7320508075688772}, "shimmer.modules.gw_module.GWModuleBayesian.sensitivity_precision": {"tf": 1.7320508075688772}, "shimmer.modules.gw_module.GWModuleBayesian.precision_softmax_temp": {"tf": 1.7320508075688772}, "shimmer.modules.gw_module.GWModuleBayesian.get_precision": {"tf": 5.744562646538029}, "shimmer.modules.gw_module.GWModuleBayesian.fuse": {"tf": 8.717797887081348}, "shimmer.modules.selection": {"tf": 1.7320508075688772}, "shimmer.modules.selection.SelectionBase": {"tf": 2.23606797749979}, "shimmer.modules.selection.SelectionBase.update_gw_state": {"tf": 4.358898943540674}, "shimmer.modules.selection.SelectionBase.forward": {"tf": 11.269427669584644}, "shimmer.modules.selection.SingleDomainSelection": {"tf": 2.449489742783178}, "shimmer.modules.selection.SingleDomainSelection.forward": {"tf": 5.830951894845301}, "shimmer.modules.selection.FixedSharedSelection": {"tf": 2.449489742783178}, "shimmer.modules.selection.FixedSharedSelection.forward": {"tf": 5.830951894845301}, "shimmer.modules.selection.KQFixedQSelection": {"tf": 1.7320508075688772}, "shimmer.modules.selection.KQFixedQSelection.__init__": {"tf": 5.385164807134504}, "shimmer.modules.selection.KQFixedQSelection.head_size": {"tf": 1.7320508075688772}, "shimmer.modules.selection.KQFixedQSelection.query_layer": {"tf": 1.7320508075688772}, "shimmer.modules.selection.KQFixedQSelection.key_layers": 
{"tf": 1.7320508075688772}, "shimmer.modules.selection.KQFixedQSelection.forward": {"tf": 6}, "shimmer.modules.selection.RandomSelection": {"tf": 1.7320508075688772}, "shimmer.modules.selection.RandomSelection.__init__": {"tf": 3.7416573867739413}, "shimmer.modules.selection.RandomSelection.temperature": {"tf": 1.7320508075688772}, "shimmer.modules.selection.RandomSelection.forward": {"tf": 5.291502622129181}, "shimmer.modules.selection.DynamicQueryAttention": {"tf": 1.7320508075688772}, "shimmer.modules.selection.DynamicQueryAttention.__init__": {"tf": 5.385164807134504}, "shimmer.modules.selection.DynamicQueryAttention.head_size": {"tf": 1.7320508075688772}, "shimmer.modules.selection.DynamicQueryAttention.query_layer": {"tf": 1.7320508075688772}, "shimmer.modules.selection.DynamicQueryAttention.key_layers": {"tf": 1.7320508075688772}, "shimmer.modules.selection.DynamicQueryAttention.fuse_weighted_encodings": {"tf": 5.916079783099616}, "shimmer.modules.selection.DynamicQueryAttention.forward": {"tf": 6}, "shimmer.modules.losses": {"tf": 1.7320508075688772}, "shimmer.modules.losses.GWLossesBase": {"tf": 1.7320508075688772}, "shimmer.modules.losses.GWLossesBase.step": {"tf": 5.830951894845301}, "shimmer.modules.losses.demi_cycle_loss": {"tf": 8.774964387392123}, "shimmer.modules.losses.cycle_loss": {"tf": 8.774964387392123}, "shimmer.modules.losses.translation_loss": {"tf": 8.366600265340756}, "shimmer.modules.losses.contrastive_loss": {"tf": 8.366600265340756}, "shimmer.modules.losses.contrastive_loss_bayesian": {"tf": 8.366600265340756}, "shimmer.modules.losses.LossCoefs": {"tf": 2.449489742783178}, "shimmer.modules.losses.LossCoefs.demi_cycles": {"tf": 1.7320508075688772}, "shimmer.modules.losses.LossCoefs.cycles": {"tf": 1.7320508075688772}, "shimmer.modules.losses.LossCoefs.translations": {"tf": 1.7320508075688772}, "shimmer.modules.losses.LossCoefs.contrastives": {"tf": 1.7320508075688772}, "shimmer.modules.losses.GWLosses2Domains": {"tf": 
2.6457513110645907}, "shimmer.modules.losses.GWLosses2Domains.__init__": {"tf": 6.6332495807108}, "shimmer.modules.losses.GWLosses2Domains.gw_mod": {"tf": 1.7320508075688772}, "shimmer.modules.losses.GWLosses2Domains.selection_mod": {"tf": 1.7320508075688772}, "shimmer.modules.losses.GWLosses2Domains.domain_mods": {"tf": 1.7320508075688772}, "shimmer.modules.losses.GWLosses2Domains.loss_coefs": {"tf": 1.7320508075688772}, "shimmer.modules.losses.GWLosses2Domains.contrastive_fn": {"tf": 1.7320508075688772}, "shimmer.modules.losses.GWLosses2Domains.demi_cycle_loss": {"tf": 5.656854249492381}, "shimmer.modules.losses.GWLosses2Domains.cycle_loss": {"tf": 5.656854249492381}, "shimmer.modules.losses.GWLosses2Domains.translation_loss": {"tf": 5.656854249492381}, "shimmer.modules.losses.GWLosses2Domains.contrastive_loss": {"tf": 5.656854249492381}, "shimmer.modules.losses.GWLosses2Domains.step": {"tf": 7.874007874011811}, "shimmer.modules.losses.generate_partitions": {"tf": 5.5677643628300215}, "shimmer.modules.losses.broadcast_loss": {"tf": 6.6332495807108}, "shimmer.modules.losses.BroadcastLossCoefs": {"tf": 2.449489742783178}, "shimmer.modules.losses.BroadcastLossCoefs.contrastives": {"tf": 1.7320508075688772}, "shimmer.modules.losses.BroadcastLossCoefs.fused": {"tf": 1.7320508075688772}, "shimmer.modules.losses.BroadcastLossCoefs.demi_cycles": {"tf": 1.4142135623730951}, "shimmer.modules.losses.BroadcastLossCoefs.cycles": {"tf": 1.4142135623730951}, "shimmer.modules.losses.BroadcastLossCoefs.translations": {"tf": 1.7320508075688772}, "shimmer.modules.losses.GWLosses": {"tf": 2.23606797749979}, "shimmer.modules.losses.GWLosses.__init__": {"tf": 6}, "shimmer.modules.losses.GWLosses.gw_mod": {"tf": 1.7320508075688772}, "shimmer.modules.losses.GWLosses.selection_mod": {"tf": 1.7320508075688772}, "shimmer.modules.losses.GWLosses.domain_mods": {"tf": 1.7320508075688772}, "shimmer.modules.losses.GWLosses.loss_coefs": {"tf": 1.7320508075688772}, 
"shimmer.modules.losses.GWLosses.contrastive_fn": {"tf": 1.7320508075688772}, "shimmer.modules.losses.GWLosses.contrastive_loss": {"tf": 4.795831523312719}, "shimmer.modules.losses.GWLosses.broadcast_loss": {"tf": 1.7320508075688772}, "shimmer.modules.losses.GWLosses.step": {"tf": 5.291502622129181}, "shimmer.modules.losses.GWLossesBayesian": {"tf": 2.6457513110645907}, "shimmer.modules.losses.GWLossesBayesian.__init__": {"tf": 7.0710678118654755}, "shimmer.modules.losses.GWLossesBayesian.gw_mod": {"tf": 1.7320508075688772}, "shimmer.modules.losses.GWLossesBayesian.selection_mod": {"tf": 1.4142135623730951}, "shimmer.modules.losses.GWLossesBayesian.domain_mods": {"tf": 1.7320508075688772}, "shimmer.modules.losses.GWLossesBayesian.loss_coefs": {"tf": 1.7320508075688772}, "shimmer.modules.losses.GWLossesBayesian.contrastive_fn": {"tf": 1.7320508075688772}, "shimmer.modules.losses.GWLossesBayesian.use_normalized_constrastive": {"tf": 1.7320508075688772}, "shimmer.modules.losses.GWLossesBayesian.contrastive_loss": {"tf": 5.196152422706632}, "shimmer.modules.losses.GWLossesBayesian.broadcast_loss": {"tf": 1.7320508075688772}, "shimmer.modules.losses.GWLossesBayesian.step": {"tf": 5.291502622129181}, "shimmer.modules.contrastive_loss": {"tf": 1.4142135623730951}, "shimmer.modules.contrastive_loss.ContrastiveLossType": {"tf": 2.449489742783178}, "shimmer.modules.contrastive_loss.ContrastiveLossBayesianType": {"tf": 2.449489742783178}, "shimmer.modules.contrastive_loss.info_nce": {"tf": 6.244997998398398}, "shimmer.modules.contrastive_loss.contrastive_loss": {"tf": 6.244997998398398}, "shimmer.modules.contrastive_loss.ContrastiveLoss": {"tf": 1.7320508075688772}, "shimmer.modules.contrastive_loss.ContrastiveLoss.__init__": {"tf": 6.48074069840786}, "shimmer.modules.contrastive_loss.ContrastiveLoss.learn_logit_scale": {"tf": 1.7320508075688772}, "shimmer.modules.contrastive_loss.ContrastiveLoss.reduction": {"tf": 1.7320508075688772}, 
"shimmer.modules.contrastive_loss.ContrastiveLoss.forward": {"tf": 5.830951894845301}, "shimmer.dataset": {"tf": 1.7320508075688772}, "shimmer.dataset.RepeatedDataset": {"tf": 2.23606797749979}, "shimmer.dataset.RepeatedDataset.__init__": {"tf": 5.196152422706632}, "shimmer.dataset.RepeatedDataset.dataset": {"tf": 1.7320508075688772}, "shimmer.dataset.RepeatedDataset.dataset_size": {"tf": 1.7320508075688772}, "shimmer.modules.vae": {"tf": 1.7320508075688772}, "shimmer.modules.vae.reparameterize": {"tf": 5.744562646538029}, "shimmer.modules.vae.kl_divergence_loss": {"tf": 5.744562646538029}, "shimmer.modules.vae.gaussian_nll": {"tf": 6.324555320336759}, "shimmer.modules.vae.VAEEncoder": {"tf": 1.7320508075688772}, "shimmer.modules.vae.VAEEncoder.forward": {"tf": 5.0990195135927845}, "shimmer.modules.vae.VAEDecoder": {"tf": 1.7320508075688772}, "shimmer.modules.vae.VAEDecoder.forward": {"tf": 5}, "shimmer.modules.vae.VAE": {"tf": 1.4142135623730951}, "shimmer.modules.vae.VAE.__init__": {"tf": 5.5677643628300215}, "shimmer.modules.vae.VAE.beta": {"tf": 1.4142135623730951}, "shimmer.modules.vae.VAE.encoder": {"tf": 1.4142135623730951}, "shimmer.modules.vae.VAE.decoder": {"tf": 1.4142135623730951}, "shimmer.modules.vae.VAE.encode": {"tf": 5.196152422706632}, "shimmer.modules.vae.VAE.decode": {"tf": 5.291502622129181}, "shimmer.modules.vae.VAE.forward": {"tf": 5.196152422706632}, "shimmer.modules.utils": {"tf": 1.7320508075688772}, "shimmer.modules.utils.translation": {"tf": 6.928203230275509}, "shimmer.modules.utils.cycle": {"tf": 7.3484692283495345}, "shimmer.modules.utils.batch_demi_cycles": {"tf": 6.4031242374328485}, "shimmer.modules.utils.batch_cycles": {"tf": 7}, "shimmer.modules.utils.batch_translations": {"tf": 6.4031242374328485}, "shimmer.utils": {"tf": 1.7320508075688772}, "shimmer.utils.MIGRATION_DIR": {"tf": 1.7320508075688772}, "shimmer.utils.group_batch_size": {"tf": 1.7320508075688772}, "shimmer.utils.groups_batch_size": {"tf": 5.0990195135927845}, 
"shimmer.utils.groups_device": {"tf": 5.0990195135927845}, "shimmer.utils.group_device": {"tf": 1.7320508075688772}, "shimmer.utils.migrate_model": {"tf": 4.898979485566356}, "shimmer.utils.SaveMigrations": {"tf": 2.23606797749979}, "shimmer.utils.SaveMigrations.migrations": {"tf": 1.7320508075688772}, "shimmer.utils.SaveMigrations.on_save_checkpoint": {"tf": 5.291502622129181}, "shimmer.cli.ckpt_migration": {"tf": 1.7320508075688772}, "shimmer.cli.ckpt_migration.migrate_ckpt": {"tf": 5.744562646538029}}, "df": 232, "m": {"docs": {"shimmer.modules.gw_module.GWModuleBayesian.fuse": {"tf": 2.23606797749979}}, "df": 1, "a": {"docs": {}, "df": 0, "t": {"docs": {}, "df": 0, "c": {"docs": {}, "df": 0, "h": {"docs": {"shimmer.modules.selection.KQFixedQSelection.forward": {"tf": 1}, "shimmer.modules.selection.DynamicQueryAttention.forward": {"tf": 1}}, "df": 2, "e": {"docs": {}, "df": 0, "d": {"docs": {"shimmer.types.RawDomainGroupT": {"tf": 1.4142135623730951}, "shimmer.types.RawDomainGroupDT": {"tf": 1}, "shimmer.types.LatentsDomainGroupT": {"tf": 1}, "shimmer.types.LatentsDomainGroupDT": {"tf": 1}, "shimmer.types.RawDomainGroupsT": {"tf": 1}, "shimmer.types.RawDomainGroupsDT": {"tf": 1}, "shimmer.types.LatentsDomainGroupsT": {"tf": 1}, "shimmer.types.LatentsDomainGroupsDT": {"tf": 1}}, "df": 8}}}}}, "p": {"docs": {}, "df": 0, "p": {"docs": {}, "df": 0, "i": {"docs": {}, "df": 0, "n": {"docs": {}, "df": 0, "g": {"docs": {"shimmer.types.RawDomainGroupT": {"tf": 1.7320508075688772}, "shimmer.types.RawDomainGroupDT": {"tf": 1}, "shimmer.types.LatentsDomainGroupT": {"tf": 1.4142135623730951}, "shimmer.types.RawDomainGroupsT": {"tf": 1}, "shimmer.types.RawDomainGroupsDT": {"tf": 1}, "shimmer.types.LatentsDomainGroupsT": {"tf": 1.4142135623730951}, "shimmer.types.LatentsDomainGroupsDT": {"tf": 1}, "shimmer.modules.global_workspace.freeze_domain_modules": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspace2Domains.__init__": {"tf": 1.7320508075688772}, 
"shimmer.modules.global_workspace.GlobalWorkspace.__init__": {"tf": 1.7320508075688772}, "shimmer.modules.global_workspace.GlobalWorkspaceBayesian.__init__": {"tf": 1.7320508075688772}, "shimmer.modules.global_workspace.pretrained_global_workspace": {"tf": 1.7320508075688772}, "shimmer.modules.gw_module.GWModule.__init__": {"tf": 1.4142135623730951}, "shimmer.modules.gw_module.GWModuleBayesian.__init__": {"tf": 1.4142135623730951}, "shimmer.modules.losses.GWLosses2Domains.__init__": {"tf": 1}, "shimmer.modules.losses.GWLosses.__init__": {"tf": 1}}, "df": 16, "s": {"docs": {"shimmer.types.RawDomainGroupT": {"tf": 1}, "shimmer.types.LatentsDomainGroupT": {"tf": 1}, "shimmer.types.LatentsDomainGroupsT": {"tf": 1}}, "df": 3}, "[": {"docs": {}, "df": 0, "s": {"docs": {}, "df": 0, "t": {"docs": {}, "df": 0, "r": {"docs": {"shimmer.modules.global_workspace.GlobalWorkspaceBase.encode_and_fuse": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.fuse": {"tf": 1}, "shimmer.modules.global_workspace.freeze_domain_modules": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspace2Domains.__init__": {"tf": 1.7320508075688772}, "shimmer.modules.global_workspace.GlobalWorkspace.__init__": {"tf": 1.7320508075688772}, "shimmer.modules.global_workspace.GlobalWorkspaceBayesian.__init__": {"tf": 1.7320508075688772}, "shimmer.modules.global_workspace.pretrained_global_workspace": {"tf": 1.7320508075688772}, "shimmer.modules.gw_module.GWModuleBase.__init__": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBase.fuse": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBase.encode_and_fuse": {"tf": 1}, "shimmer.modules.gw_module.GWModule.__init__": {"tf": 1.7320508075688772}, "shimmer.modules.gw_module.GWModule.fuse": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBayesian.__init__": {"tf": 1.7320508075688772}, "shimmer.modules.gw_module.GWModuleBayesian.fuse": {"tf": 1}, "shimmer.modules.losses.demi_cycle_loss": {"tf": 1}, "shimmer.modules.losses.cycle_loss": {"tf": 1}, 
"shimmer.modules.losses.translation_loss": {"tf": 1}, "shimmer.modules.losses.broadcast_loss": {"tf": 1}}, "df": 18}}}, "f": {"docs": {}, "df": 0, "r": {"docs": {}, "df": 0, "o": {"docs": {}, "df": 0, "z": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "n": {"docs": {}, "df": 0, "s": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "t": {"docs": {}, "df": 0, "[": {"docs": {}, "df": 0, "s": {"docs": {}, "df": 0, "t": {"docs": {}, "df": 0, "r": {"docs": {"shimmer.modules.global_workspace.GlobalWorkspaceBase.fuse": {"tf": 1}}, "df": 1}}}}}}}}}}}}}}}}}}}, "x": {"docs": {"shimmer.modules.global_workspace.GlobalWorkspace2Domains": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBayesian": {"tf": 1}}, "df": 2, "i": {"docs": {}, "df": 0, "m": {"docs": {}, "df": 0, "u": {"docs": {}, "df": 0, "m": {"docs": {"shimmer.modules.global_workspace.SchedulerArgs.max_lr": {"tf": 1}}, "df": 1}}}}}, "k": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "s": {"docs": {"shimmer.modules.gw_module.get_n_layers": {"tf": 1}}, "df": 1}}}, "y": {"docs": {"shimmer.modules.selection.SelectionBase.update_gw_state": {"tf": 1}}, "df": 1}, "i": {"docs": {}, "df": 0, "n": {"docs": {"shimmer.modules.losses.GWLosses2Domains.__init__": {"tf": 1}}, "df": 1}}, "n": {"docs": {}, "df": 0, "y": {"docs": {"shimmer.modules.losses.BroadcastLossCoefs.cycles": {"tf": 1}, "shimmer.modules.losses.BroadcastLossCoefs.translations": {"tf": 1}}, "df": 2}}}, "u": {"docs": {"shimmer.modules.gw_module.GWModuleBayesian.fuse": {"tf": 1}, "shimmer.modules.vae.gaussian_nll": {"tf": 1}}, "df": 2, "l": {"docs": {}, "df": 0, "t": {"docs": {}, "df": 0, "i": {"docs": {}, "df": 0, "p": {"docs": {}, "df": 0, "l": {"docs": {}, "df": 0, "e": {"docs": {"shimmer.types.RawDomainGroupT": {"tf": 1}, "shimmer.types.RawDomainGroupDT": {"tf": 1}, "shimmer.types.LatentsDomainGroupT": {"tf": 1}, "shimmer.types.LatentsDomainGroupDT": {"tf": 1}, "shimmer.modules.gw_module.GWModule.decode": {"tf": 1}, 
"shimmer.modules.selection.SingleDomainSelection": {"tf": 1}, "shimmer.modules.losses.demi_cycle_loss": {"tf": 1}, "shimmer.modules.losses.cycle_loss": {"tf": 1}, "shimmer.modules.losses.translation_loss": {"tf": 1}, "shimmer.modules.losses.contrastive_loss": {"tf": 1}, "shimmer.modules.losses.contrastive_loss_bayesian": {"tf": 1}, "shimmer.modules.losses.BroadcastLossCoefs.fused": {"tf": 1}, "shimmer.modules.utils.translation": {"tf": 1}}, "df": 13}}}}}}}, "o": {"docs": {}, "df": 0, "r": {"docs": {}, "df": 0, "e": {"docs": {"shimmer.types.RawDomainGroupT": {"tf": 1}, "shimmer.types.LatentsDomainGroupT": {"tf": 1}, "shimmer.types.LatentsDomainGroupsT": {"tf": 1}, "shimmer.modules.global_workspace.GWPredictions.cycles": {"tf": 1}, "shimmer.modules.global_workspace.GWPredictions.translations": {"tf": 1}}, "df": 5}}, "d": {"docs": {"shimmer.modules.losses.demi_cycle_loss": {"tf": 2}, "shimmer.modules.losses.cycle_loss": {"tf": 2}, "shimmer.modules.losses.translation_loss": {"tf": 1.4142135623730951}, "shimmer.modules.losses.contrastive_loss": {"tf": 1.4142135623730951}, "shimmer.modules.losses.contrastive_loss_bayesian": {"tf": 1.4142135623730951}, "shimmer.modules.losses.GWLosses2Domains.__init__": {"tf": 1.4142135623730951}, "shimmer.modules.losses.broadcast_loss": {"tf": 1.7320508075688772}, "shimmer.modules.losses.GWLosses.__init__": {"tf": 1.4142135623730951}, "shimmer.modules.losses.GWLossesBayesian.__init__": {"tf": 1.4142135623730951}, "shimmer.modules.utils.translation": {"tf": 1}, "shimmer.modules.utils.cycle": {"tf": 1}, "shimmer.modules.utils.batch_demi_cycles": {"tf": 1.4142135623730951}, "shimmer.modules.utils.batch_cycles": {"tf": 1.4142135623730951}, "shimmer.modules.utils.batch_translations": {"tf": 1.4142135623730951}}, "df": 14, "e": {"docs": {"shimmer.types.ModelModeT": {"tf": 1.4142135623730951}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.generic_step": {"tf": 1}, "shimmer.modules.global_workspace.freeze_domain_modules": {"tf": 1}, 
"shimmer.modules.losses.GWLossesBase.step": {"tf": 1.4142135623730951}, "shimmer.modules.losses.GWLosses2Domains.step": {"tf": 1.4142135623730951}, "shimmer.modules.losses.GWLosses.step": {"tf": 1.4142135623730951}, "shimmer.modules.losses.GWLossesBayesian.step": {"tf": 1.4142135623730951}}, "df": 7, "l": {"docs": {"shimmer.modules.global_workspace.GWPredictions.demi_cycles": {"tf": 1}, "shimmer.modules.global_workspace.GWPredictions.cycles": {"tf": 1}, "shimmer.modules.global_workspace.GWPredictions.translations": {"tf": 1}, "shimmer.modules.domain.DomainModule.compute_loss": {"tf": 1}, "shimmer.modules.domain.DomainModule.compute_dcy_loss": {"tf": 1}, "shimmer.modules.domain.DomainModule.compute_cy_loss": {"tf": 1}, "shimmer.modules.domain.DomainModule.compute_tr_loss": {"tf": 1}, "shimmer.modules.domain.DomainModule.compute_broadcast_loss": {"tf": 1}, "shimmer.modules.losses.GWLossesBase.step": {"tf": 1}, "shimmer.modules.losses.GWLosses2Domains.step": {"tf": 1}, "shimmer.modules.losses.GWLosses.__init__": {"tf": 1.4142135623730951}, "shimmer.modules.losses.GWLosses.step": {"tf": 1}, "shimmer.modules.losses.GWLossesBayesian.step": {"tf": 1}, "shimmer.utils.migrate_model": {"tf": 1}, "shimmer.cli.ckpt_migration.migrate_ckpt": {"tf": 1}}, "df": 15, "m": {"docs": {}, "df": 0, "o": {"docs": {}, "df": 0, "d": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "t": {"docs": {"shimmer.modules.global_workspace.GlobalWorkspaceBase.generic_step": {"tf": 1}, "shimmer.modules.losses.GWLosses2Domains.step": {"tf": 1}}, "df": 2}}}}}, "s": {"docs": {"shimmer.modules.losses.GWLosses": {"tf": 1}}, "df": 1}}}, "u": {"docs": {}, "df": 0, "l": {"docs": {}, "df": 0, "e": {"docs": {"shimmer.modules.global_workspace.GlobalWorkspaceBase": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.loss_mod": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspace2Domains.__init__": {"tf": 2}, "shimmer.modules.global_workspace.GlobalWorkspace.__init__": {"tf": 2.23606797749979}, 
"shimmer.modules.global_workspace.GlobalWorkspaceBayesian.__init__": {"tf": 2}, "shimmer.modules.global_workspace.pretrained_global_workspace": {"tf": 2}, "shimmer.modules.domain.DomainModule.__init__": {"tf": 1}, "shimmer.modules.domain.DomainModule.latent_dim": {"tf": 1}, "shimmer.modules.gw_module.get_n_layers": {"tf": 1}, "shimmer.modules.gw_module.GWEncoder.forward": {"tf": 1}, "shimmer.modules.gw_module.GWEncoderLinear.forward": {"tf": 1}, "shimmer.modules.gw_module.GWModule": {"tf": 1}, "shimmer.modules.gw_module.GWModule.__init__": {"tf": 2}, "shimmer.modules.gw_module.GWModule.gw_encoders": {"tf": 1}, "shimmer.modules.gw_module.GWModule.gw_decoders": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBayesian.__init__": {"tf": 2}, "shimmer.modules.selection.SingleDomainSelection.forward": {"tf": 1}, "shimmer.modules.selection.FixedSharedSelection.forward": {"tf": 1}, "shimmer.modules.losses.GWLossesBase": {"tf": 1}, "shimmer.modules.losses.demi_cycle_loss": {"tf": 1}, "shimmer.modules.losses.GWLosses2Domains.__init__": {"tf": 1.4142135623730951}, "shimmer.modules.losses.broadcast_loss": {"tf": 1}, "shimmer.modules.losses.GWLosses.__init__": {"tf": 1}, "shimmer.modules.losses.GWLossesBayesian.__init__": {"tf": 1.4142135623730951}, "shimmer.modules.losses.GWLossesBayesian.selection_mod": {"tf": 1}, "shimmer.modules.contrastive_loss.ContrastiveLoss": {"tf": 1}, "shimmer.modules.vae.VAE": {"tf": 1}, "shimmer.modules.utils.translation": {"tf": 1.4142135623730951}, "shimmer.modules.utils.cycle": {"tf": 1.4142135623730951}, "shimmer.modules.utils.batch_demi_cycles": {"tf": 1}, "shimmer.modules.utils.batch_cycles": {"tf": 1}, "shimmer.modules.utils.batch_translations": {"tf": 1}, "shimmer.utils.SaveMigrations.on_save_checkpoint": {"tf": 1}}, "df": 33, "s": {"docs": {"shimmer.modules.global_workspace.freeze_domain_modules": {"tf": 1.7320508075688772}, "shimmer.modules.domain.DomainModule": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBase.__init__": {"tf": 
1.4142135623730951}, "shimmer.modules.gw_module.GWModuleBase.domain_mods": {"tf": 1}, "shimmer.modules.gw_module.GWModule.__init__": {"tf": 1.4142135623730951}, "shimmer.modules.gw_module.GWModuleBayesian.__init__": {"tf": 1.4142135623730951}, "shimmer.modules.selection.SelectionBase": {"tf": 1}, "shimmer.modules.losses.demi_cycle_loss": {"tf": 1.7320508075688772}, "shimmer.modules.losses.cycle_loss": {"tf": 1.4142135623730951}, "shimmer.modules.losses.translation_loss": {"tf": 1}, "shimmer.modules.losses.GWLosses2Domains.demi_cycle_loss": {"tf": 1}, "shimmer.modules.losses.GWLosses2Domains.cycle_loss": {"tf": 1}, "shimmer.modules.losses.GWLosses2Domains.translation_loss": {"tf": 1}, "shimmer.modules.losses.GWLosses2Domains.contrastive_loss": {"tf": 1}, "shimmer.modules.losses.broadcast_loss": {"tf": 1.7320508075688772}, "shimmer.modules.losses.GWLossesBayesian.domain_mods": {"tf": 1}}, "df": 16}, "d": {"docs": {}, "df": 0, "i": {"docs": {}, "df": 0, "c": {"docs": {}, "df": 0, "t": {"docs": {"shimmer.modules.global_workspace.freeze_domain_modules": {"tf": 1.4142135623730951}}, "df": 1}}}}}}}, "s": {"docs": {"shimmer.modules.global_workspace.GlobalWorkspaceBase.encode_domain": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.decode_domain": {"tf": 1}, "shimmer.modules.global_workspace.freeze_domain_modules": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspace2Domains.__init__": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspace.__init__": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBayesian.__init__": {"tf": 1}, "shimmer.modules.global_workspace.pretrained_global_workspace": {"tf": 1}, "shimmer.modules.losses.demi_cycle_loss": {"tf": 1}, "shimmer.modules.losses.cycle_loss": {"tf": 1}, "shimmer.modules.losses.translation_loss": {"tf": 1}, "shimmer.modules.losses.GWLosses2Domains.__init__": {"tf": 1}, "shimmer.modules.losses.broadcast_loss": {"tf": 1}, "shimmer.modules.losses.GWLosses.__init__": {"tf": 1}, 
"shimmer.modules.losses.GWLossesBayesian.__init__": {"tf": 1}}, "df": 14}, "a": {"docs": {}, "df": 0, "l": {"docs": {}, "df": 0, "i": {"docs": {}, "df": 0, "t": {"docs": {}, "df": 0, "y": {"docs": {"shimmer.modules.domain.DomainModule.compute_loss": {"tf": 1}}, "df": 1}, "i": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "s": {"docs": {"shimmer.modules.selection.RandomSelection": {"tf": 1}, "shimmer.modules.losses.generate_partitions": {"tf": 1}}, "df": 2}}}}}}}, "i": {"docs": {}, "df": 0, "f": {"docs": {}, "df": 0, "i": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "d": {"docs": {"shimmer.modules.selection.RandomSelection": {"tf": 1}}, "df": 1}}}}}}}, "e": {"docs": {}, "df": 0, "t": {"docs": {}, "df": 0, "h": {"docs": {}, "df": 0, "o": {"docs": {}, "df": 0, "d": {"docs": {"shimmer.modules.global_workspace.GlobalWorkspaceBase.encode_domain": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.decode_domain": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspace2Domains": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspace": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBayesian": {"tf": 1}, "shimmer.modules.selection.SelectionBase.update_gw_state": {"tf": 1}, "shimmer.modules.selection.SelectionBase.forward": {"tf": 1}}, "df": 7, "s": {"docs": {"shimmer.modules.gw_module.GWModuleBase.encode_and_fuse": {"tf": 1}}, "df": 1}}}}, "r": {"docs": {}, "df": 0, "i": {"docs": {}, "df": 0, "c": {"docs": {"shimmer.modules.losses.demi_cycle_loss": {"tf": 1}, "shimmer.modules.losses.cycle_loss": {"tf": 1}, "shimmer.modules.losses.translation_loss": {"tf": 1}, "shimmer.modules.losses.contrastive_loss": {"tf": 1}, "shimmer.modules.losses.contrastive_loss_bayesian": {"tf": 1}, "shimmer.modules.contrastive_loss.ContrastiveLoss.forward": {"tf": 1}}, "df": 6, "s": {"docs": {"shimmer.modules.domain.LossOutput": {"tf": 1}, "shimmer.modules.domain.LossOutput.metrics": {"tf": 1}, "shimmer.modules.domain.LossOutput.all": {"tf": 1}, 
"shimmer.modules.domain.DomainModule.compute_loss": {"tf": 1}, "shimmer.modules.domain.DomainModule.compute_dcy_loss": {"tf": 1}, "shimmer.modules.domain.DomainModule.compute_cy_loss": {"tf": 1}, "shimmer.modules.domain.DomainModule.compute_tr_loss": {"tf": 1}, "shimmer.modules.domain.DomainModule.compute_broadcast_loss": {"tf": 1}, "shimmer.modules.losses.demi_cycle_loss": {"tf": 1.7320508075688772}, "shimmer.modules.losses.cycle_loss": {"tf": 1.7320508075688772}, "shimmer.modules.losses.translation_loss": {"tf": 1.7320508075688772}, "shimmer.modules.losses.contrastive_loss": {"tf": 1.7320508075688772}, "shimmer.modules.losses.contrastive_loss_bayesian": {"tf": 1.7320508075688772}, "shimmer.modules.losses.GWLosses2Domains.demi_cycle_loss": {"tf": 1}, "shimmer.modules.losses.GWLosses2Domains.cycle_loss": {"tf": 1}, "shimmer.modules.losses.GWLosses2Domains.translation_loss": {"tf": 1}, "shimmer.modules.losses.GWLosses2Domains.contrastive_loss": {"tf": 1}, "shimmer.modules.losses.GWLosses2Domains.step": {"tf": 2}, "shimmer.modules.losses.broadcast_loss": {"tf": 1}, "shimmer.modules.losses.GWLosses.contrastive_loss": {"tf": 1}, "shimmer.modules.losses.GWLosses.step": {"tf": 1}, "shimmer.modules.losses.GWLossesBayesian.contrastive_loss": {"tf": 1}, "shimmer.modules.losses.GWLossesBayesian.step": {"tf": 1}}, "df": 23}}}}}, "r": {"docs": {}, "df": 0, "g": {"docs": {}, "df": 0, "e": {"docs": {"shimmer.modules.gw_module.GWModuleBase.fuse": {"tf": 1}, "shimmer.modules.gw_module.GWModule.fuse": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBayesian.fuse": {"tf": 1.4142135623730951}}, "df": 3, "d": {"docs": {"shimmer.modules.gw_module.GWModuleBase.fuse": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBase.encode_and_fuse": {"tf": 1}, "shimmer.modules.gw_module.GWModule.fuse": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBayesian.fuse": {"tf": 1}}, "df": 4}}}}, "c": {"docs": {}, "df": 0, "h": {"docs": {}, "df": 0, "a": {"docs": {}, "df": 0, "n": {"docs": {}, "df": 0, "i": 
{"docs": {}, "df": 0, "s": {"docs": {}, "df": 0, "m": {"docs": {"shimmer.modules.selection.SelectionBase": {"tf": 1}, "shimmer.modules.selection.SingleDomainSelection": {"tf": 1}, "shimmer.modules.selection.FixedSharedSelection": {"tf": 1}, "shimmer.modules.losses.GWLosses.__init__": {"tf": 1}}, "df": 4, "s": {"docs": {"shimmer.modules.selection.SelectionBase": {"tf": 1}}, "df": 1}}}}}}}}, "a": {"docs": {}, "df": 0, "n": {"docs": {"shimmer.modules.contrastive_loss.ContrastiveLossBayesianType": {"tf": 1.4142135623730951}, "shimmer.modules.contrastive_loss.info_nce": {"tf": 1}, "shimmer.modules.contrastive_loss.contrastive_loss": {"tf": 1}, "shimmer.modules.contrastive_loss.ContrastiveLoss.__init__": {"tf": 1.4142135623730951}, "shimmer.modules.vae.reparameterize": {"tf": 1}, "shimmer.modules.vae.kl_divergence_loss": {"tf": 1}, "shimmer.modules.vae.VAEEncoder.forward": {"tf": 1}, "shimmer.modules.vae.VAE.encode": {"tf": 1.4142135623730951}, "shimmer.modules.vae.VAE.forward": {"tf": 1}}, "df": 9, "s": {"docs": {"shimmer.modules.vae.reparameterize": {"tf": 1}, "shimmer.modules.vae.kl_divergence_loss": {"tf": 1}}, "df": 2}}}}, "i": {"docs": {}, "df": 0, "n": {"docs": {"shimmer.dataset.RepeatedDataset": {"tf": 2}, "shimmer.dataset.RepeatedDataset.__init__": {"tf": 1}}, "df": 2, "i": {"docs": {}, "df": 0, "m": {"docs": {}, "df": 0, "u": {"docs": {}, "df": 0, "m": {"docs": {"shimmer.dataset.RepeatedDataset.__init__": {"tf": 1}}, "df": 1}}}}}, "g": {"docs": {}, "df": 0, "r": {"docs": {}, "df": 0, "a": {"docs": {}, "df": 0, "t": {"docs": {}, "df": 0, "e": {"docs": {"shimmer.cli.ckpt_migration.migrate_ckpt": {"tf": 1.7320508075688772}}, "df": 1, "s": {"docs": {"shimmer.utils.migrate_model": {"tf": 1}}, "df": 1}, "d": {"docs": {"shimmer.utils.migrate_model": {"tf": 1}}, "df": 1}}, "i": {"docs": {}, "df": 0, "o": {"docs": {}, "df": 0, "n": {"docs": {"shimmer.utils.migrate_model": {"tf": 1}}, "df": 1}}}}}}, "h": {"docs": {}, "df": 0, "t": {"docs": 
{"shimmer.utils.SaveMigrations.on_save_checkpoint": {"tf": 1}}, "df": 1}}}}}, "r": {"docs": {}, "df": 0, "a": {"docs": {}, "df": 0, "w": {"docs": {"shimmer.types.RawDomainGroupT": {"tf": 1}, "shimmer.types.RawDomainGroupDT": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.encode_domains": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.decode_domain": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.decode_domains": {"tf": 1.4142135623730951}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.generic_step": {"tf": 1}}, "df": 6, "d": {"docs": {}, "df": 0, "o": {"docs": {}, "df": 0, "m": {"docs": {}, "df": 0, "a": {"docs": {}, "df": 0, "i": {"docs": {}, "df": 0, "n": {"docs": {}, "df": 0, "g": {"docs": {}, "df": 0, "r": {"docs": {}, "df": 0, "o": {"docs": {}, "df": 0, "u": {"docs": {}, "df": 0, "p": {"docs": {}, "df": 0, "t": {"docs": {"shimmer.types.RawDomainGroupT": {"tf": 1}, "shimmer.types.RawDomainGroupDT": {"tf": 1.4142135623730951}, "shimmer.types.RawDomainGroupsT": {"tf": 1}, "shimmer.types.RawDomainGroupsDT": {"tf": 1}}, "df": 4}, "d": {"docs": {}, "df": 0, "t": {"docs": {"shimmer.types.RawDomainGroupT": {"tf": 1}, "shimmer.types.RawDomainGroupDT": {"tf": 1}}, "df": 2}}, "s": {"docs": {}, "df": 0, "d": {"docs": {}, "df": 0, "t": {"docs": {"shimmer.types.RawDomainGroupsT": {"tf": 1}, "shimmer.types.RawDomainGroupsDT": {"tf": 1}}, "df": 2}}, "t": {"docs": {"shimmer.types.RawDomainGroupsT": {"tf": 1}, "shimmer.types.RawDomainGroupsDT": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.encode_domains": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.generic_step": {"tf": 1}}, "df": 4}}}}}}}}}}}}}}, "b": {"docs": {}, "df": 0, "b": {"docs": {}, "df": 0, "i": {"docs": {}, "df": 0, "t": {"docs": {"shimmer.types.RawDomainGroupsT": {"tf": 1}, "shimmer.types.RawDomainGroupsDT": {"tf": 1}}, "df": 2}}}}, "t": {"docs": {}, "df": 0, "e": {"docs": 
{"shimmer.modules.global_workspace.SchedulerArgs.max_lr": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspace2Domains.__init__": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspace.__init__": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBayesian.__init__": {"tf": 1}}, "df": 4}}, "n": {"docs": {}, "df": 0, "d": {"docs": {}, "df": 0, "o": {"docs": {}, "df": 0, "m": {"docs": {"shimmer.modules.selection.RandomSelection": {"tf": 1}}, "df": 1, "s": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "l": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "c": {"docs": {}, "df": 0, "t": {"docs": {}, "df": 0, "i": {"docs": {}, "df": 0, "o": {"docs": {}, "df": 0, "n": {"docs": {"shimmer.modules.global_workspace.GlobalWorkspace.__init__": {"tf": 1}}, "df": 1}}}}}}}}}}}, "n": {"docs": {"shimmer.modules.selection.SelectionBase.forward": {"tf": 1.4142135623730951}}, "df": 1}}}, "i": {"docs": {}, "df": 0, "s": {"docs": {}, "df": 0, "e": {"docs": {"shimmer.modules.selection.SelectionBase.update_gw_state": {"tf": 1}}, "df": 1, "s": {"docs": {"shimmer.modules.global_workspace.pretrained_global_workspace": {"tf": 1}}, "df": 1}}}}}, "e": {"docs": {}, "df": 0, "p": {"docs": {}, "df": 0, "r": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "s": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "n": {"docs": {}, "df": 0, "t": {"docs": {"shimmer.types.RawDomainGroupT": {"tf": 1}, "shimmer.modules.global_workspace.GWPredictionsBase.states": {"tf": 1}}, "df": 2, "a": {"docs": {}, "df": 0, "t": {"docs": {}, "df": 0, "i": {"docs": {}, "df": 0, "o": {"docs": {}, "df": 0, "n": {"docs": {"shimmer.types.LatentsDomainGroupT": {"tf": 1}, "shimmer.types.LatentsDomainGroupDT": {"tf": 1}, "shimmer.modules.global_workspace.GWPredictionsBase.states": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.encode_and_fuse": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.encode": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.fuse": 
{"tf": 1.4142135623730951}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.decode": {"tf": 1.4142135623730951}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.encode_domain": {"tf": 1.4142135623730951}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.decode_domain": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.decode_domains": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspace2Domains.__init__": {"tf": 1.4142135623730951}, "shimmer.modules.global_workspace.GlobalWorkspace.__init__": {"tf": 1.4142135623730951}, "shimmer.modules.global_workspace.GlobalWorkspaceBayesian.__init__": {"tf": 1.4142135623730951}, "shimmer.modules.global_workspace.pretrained_global_workspace": {"tf": 1.4142135623730951}, "shimmer.modules.domain.DomainModule.encode": {"tf": 1.4142135623730951}, "shimmer.modules.domain.DomainModule.decode": {"tf": 1.4142135623730951}, "shimmer.modules.gw_module.GWModuleBase.fuse": {"tf": 1.4142135623730951}, "shimmer.modules.gw_module.GWModuleBase.encode": {"tf": 1.4142135623730951}, "shimmer.modules.gw_module.GWModuleBase.encode_and_fuse": {"tf": 1.7320508075688772}, "shimmer.modules.gw_module.GWModuleBase.decode": {"tf": 1.4142135623730951}, "shimmer.modules.gw_module.GWModule.__init__": {"tf": 1.7320508075688772}, "shimmer.modules.gw_module.GWModule.fuse": {"tf": 1.4142135623730951}, "shimmer.modules.gw_module.GWModule.encode": {"tf": 1.7320508075688772}, "shimmer.modules.gw_module.GWModule.decode": {"tf": 1.7320508075688772}, "shimmer.modules.gw_module.GWModuleBayesian.__init__": {"tf": 1.7320508075688772}, "shimmer.modules.gw_module.GWModuleBayesian.fuse": {"tf": 1.4142135623730951}, "shimmer.modules.selection.DynamicQueryAttention.fuse_weighted_encodings": {"tf": 1}, "shimmer.modules.vae.VAEEncoder.forward": {"tf": 1}, "shimmer.modules.vae.VAEDecoder.forward": {"tf": 1.7320508075688772}, "shimmer.modules.vae.VAE.encode": {"tf": 1.4142135623730951}, "shimmer.modules.vae.VAE.decode": {"tf": 
1.4142135623730951}, "shimmer.modules.utils.translation": {"tf": 1}, "shimmer.modules.utils.cycle": {"tf": 1.7320508075688772}}, "df": 33, "s": {"docs": {"shimmer.types.LatentsDomainGroupT": {"tf": 1}, "shimmer.types.LatentsDomainGroupDT": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.encode_and_fuse": {"tf": 1.7320508075688772}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.encode": {"tf": 1.7320508075688772}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.fuse": {"tf": 1.4142135623730951}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.decode": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.encode_domains": {"tf": 1.4142135623730951}, "shimmer.modules.global_workspace.GlobalWorkspace2Domains.__init__": {"tf": 1.4142135623730951}, "shimmer.modules.global_workspace.GlobalWorkspace.__init__": {"tf": 1.4142135623730951}, "shimmer.modules.global_workspace.GlobalWorkspaceBayesian.__init__": {"tf": 1.4142135623730951}, "shimmer.modules.global_workspace.pretrained_global_workspace": {"tf": 1.4142135623730951}, "shimmer.modules.gw_module.GWModuleBase": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBase.encode": {"tf": 1.4142135623730951}, "shimmer.modules.gw_module.GWModuleBase.encode_and_fuse": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBase.decode": {"tf": 1}, "shimmer.modules.gw_module.GWModule.__init__": {"tf": 1}, "shimmer.modules.gw_module.GWModule.encode": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBayesian.__init__": {"tf": 1}, "shimmer.modules.selection.SelectionBase.forward": {"tf": 1}, "shimmer.modules.selection.SingleDomainSelection.forward": {"tf": 1}, "shimmer.modules.selection.FixedSharedSelection.forward": {"tf": 1}, "shimmer.modules.selection.KQFixedQSelection.forward": {"tf": 1}, "shimmer.modules.selection.RandomSelection.forward": {"tf": 1}, "shimmer.modules.selection.DynamicQueryAttention.forward": {"tf": 1}, "shimmer.modules.losses.broadcast_loss": {"tf": 1}, 
"shimmer.modules.losses.GWLosses.contrastive_loss": {"tf": 1}, "shimmer.modules.losses.GWLosses.step": {"tf": 1}, "shimmer.modules.losses.GWLossesBayesian.step": {"tf": 1}, "shimmer.modules.utils.translation": {"tf": 1}}, "df": 29}}}}}}}}}}, "n": {"docs": {}, "df": 0, "s": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "t": {"docs": {}, "df": 0, "a": {"docs": {}, "df": 0, "t": {"docs": {}, "df": 0, "i": {"docs": {}, "df": 0, "o": {"docs": {}, "df": 0, "n": {"docs": {"shimmer.modules.gw_module.GWModuleBase.fuse": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBase.encode_and_fuse": {"tf": 1}, "shimmer.modules.gw_module.GWModule.fuse": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBayesian.fuse": {"tf": 1}}, "df": 4}}}}}}}}}}}, "e": {"docs": {}, "df": 0, "a": {"docs": {}, "df": 0, "t": {"docs": {"shimmer.dataset.RepeatedDataset.__init__": {"tf": 1}}, "df": 1, "i": {"docs": {}, "df": 0, "n": {"docs": {}, "df": 0, "g": {"docs": {"shimmer.dataset.RepeatedDataset.__init__": {"tf": 1}}, "df": 1}}}}}}, "a": {"docs": {}, "df": 0, "r": {"docs": {}, "df": 0, "a": {"docs": {}, "df": 0, "m": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "t": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "r": {"docs": {}, "df": 0, "i": {"docs": {}, "df": 0, "z": {"docs": {}, "df": 0, "a": {"docs": {}, "df": 0, "t": {"docs": {}, "df": 0, "i": {"docs": {}, "df": 0, "o": {"docs": {}, "df": 0, "n": {"docs": {"shimmer.modules.vae.reparameterize": {"tf": 1.4142135623730951}}, "df": 1}}}}}}}}}}}}}}}}, "t": {"docs": {}, "df": 0, "u": {"docs": {}, "df": 0, "r": {"docs": {}, "df": 0, "n": {"docs": {"shimmer.types.RawDomainGroupDT": {"tf": 1}, "shimmer.types.LatentsDomainGroupDT": {"tf": 1}, "shimmer.types.RawDomainGroupsT": {"tf": 1}, "shimmer.types.RawDomainGroupsDT": {"tf": 1}, "shimmer.types.LatentsDomainGroupsDT": {"tf": 1}, "shimmer.modules.selection.SingleDomainSelection": {"tf": 1}, "shimmer.modules.losses.demi_cycle_loss": {"tf": 1}, "shimmer.modules.losses.cycle_loss": {"tf": 1}, 
"shimmer.modules.losses.translation_loss": {"tf": 1}, "shimmer.modules.losses.contrastive_loss": {"tf": 1}, "shimmer.modules.losses.contrastive_loss_bayesian": {"tf": 1}}, "df": 11, "s": {"docs": {"shimmer.modules.global_workspace.GlobalWorkspaceBase.encode_and_fuse": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.encode": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.fuse": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.decode": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.batch_gw_states": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.encode_domain": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.encode_domains": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.decode_domain": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.decode_domains": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.generic_step": {"tf": 1}, "shimmer.modules.global_workspace.freeze_domain_modules": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspace2Domains.forward": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspace.forward": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBayesian.forward": {"tf": 1}, "shimmer.modules.global_workspace.pretrained_global_workspace": {"tf": 1}, "shimmer.modules.domain.LossOutput.all": {"tf": 1}, "shimmer.modules.domain.DomainModule.encode": {"tf": 1}, "shimmer.modules.domain.DomainModule.decode": {"tf": 1}, "shimmer.modules.gw_module.get_n_layers": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBase.fuse": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBase.encode": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBase.encode_and_fuse": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBase.decode": {"tf": 1}, "shimmer.modules.gw_module.GWModule.fuse": {"tf": 1}, "shimmer.modules.gw_module.GWModule.encode": {"tf": 1}, "shimmer.modules.gw_module.GWModule.decode": {"tf": 1}, 
"shimmer.modules.gw_module.compute_fusion_scores": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBayesian.get_precision": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBayesian.fuse": {"tf": 1}, "shimmer.modules.selection.SelectionBase.forward": {"tf": 1}, "shimmer.modules.selection.SingleDomainSelection.forward": {"tf": 1}, "shimmer.modules.selection.FixedSharedSelection.forward": {"tf": 1}, "shimmer.modules.selection.KQFixedQSelection.forward": {"tf": 1}, "shimmer.modules.selection.RandomSelection.forward": {"tf": 1}, "shimmer.modules.selection.DynamicQueryAttention.fuse_weighted_encodings": {"tf": 1}, "shimmer.modules.selection.DynamicQueryAttention.forward": {"tf": 1}, "shimmer.modules.losses.GWLossesBase.step": {"tf": 1}, "shimmer.modules.losses.demi_cycle_loss": {"tf": 1}, "shimmer.modules.losses.cycle_loss": {"tf": 1}, "shimmer.modules.losses.translation_loss": {"tf": 1}, "shimmer.modules.losses.contrastive_loss": {"tf": 1}, "shimmer.modules.losses.contrastive_loss_bayesian": {"tf": 1}, "shimmer.modules.losses.GWLosses2Domains.demi_cycle_loss": {"tf": 1}, "shimmer.modules.losses.GWLosses2Domains.cycle_loss": {"tf": 1}, "shimmer.modules.losses.GWLosses2Domains.translation_loss": {"tf": 1}, "shimmer.modules.losses.GWLosses2Domains.contrastive_loss": {"tf": 1}, "shimmer.modules.losses.GWLosses2Domains.step": {"tf": 1.4142135623730951}, "shimmer.modules.losses.broadcast_loss": {"tf": 1}, "shimmer.modules.losses.GWLosses.contrastive_loss": {"tf": 1}, "shimmer.modules.losses.GWLosses.step": {"tf": 1}, "shimmer.modules.losses.GWLossesBayesian.contrastive_loss": {"tf": 1}, "shimmer.modules.losses.GWLossesBayesian.step": {"tf": 1}, "shimmer.modules.contrastive_loss.ContrastiveLossBayesianType": {"tf": 1}, "shimmer.modules.contrastive_loss.info_nce": {"tf": 1}, "shimmer.modules.contrastive_loss.contrastive_loss": {"tf": 1}, "shimmer.modules.contrastive_loss.ContrastiveLoss.forward": {"tf": 1}, "shimmer.modules.vae.reparameterize": {"tf": 1}, 
"shimmer.modules.vae.kl_divergence_loss": {"tf": 1}, "shimmer.modules.vae.gaussian_nll": {"tf": 1}, "shimmer.modules.vae.VAEEncoder.forward": {"tf": 1}, "shimmer.modules.vae.VAEDecoder.forward": {"tf": 1}, "shimmer.modules.vae.VAE.encode": {"tf": 1.4142135623730951}, "shimmer.modules.vae.VAE.decode": {"tf": 1}, "shimmer.modules.vae.VAE.forward": {"tf": 1}, "shimmer.modules.utils.translation": {"tf": 1}, "shimmer.modules.utils.cycle": {"tf": 1}, "shimmer.modules.utils.batch_demi_cycles": {"tf": 1}, "shimmer.modules.utils.batch_cycles": {"tf": 1}, "shimmer.modules.utils.batch_translations": {"tf": 1}, "shimmer.utils.groups_batch_size": {"tf": 1}, "shimmer.utils.groups_device": {"tf": 1}}, "df": 71}, "e": {"docs": {}, "df": 0, "d": {"docs": {"shimmer.modules.domain.LossOutput": {"tf": 1}, "shimmer.modules.selection.RandomSelection.forward": {"tf": 1}}, "df": 2}}, "i": {"docs": {}, "df": 0, "n": {"docs": {}, "df": 0, "g": {"docs": {"shimmer.modules.contrastive_loss.ContrastiveLossType": {"tf": 1}}, "df": 1}}}}}}}, "s": {"docs": {}, "df": 0, "u": {"docs": {}, "df": 0, "l": {"docs": {}, "df": 0, "t": {"docs": {}, "df": 0, "s": {"docs": {"shimmer.modules.domain.DomainModule.compute_loss": {"tf": 1}, "shimmer.modules.domain.DomainModule.compute_dcy_loss": {"tf": 1}, "shimmer.modules.domain.DomainModule.compute_cy_loss": {"tf": 1}, "shimmer.modules.domain.DomainModule.compute_tr_loss": {"tf": 1}, "shimmer.modules.domain.DomainModule.compute_broadcast_loss": {"tf": 1}}, "df": 5}}}}, "p": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "c": {"docs": {}, "df": 0, "t": {"docs": {}, "df": 0, "i": {"docs": {}, "df": 0, "v": {"docs": {}, "df": 0, "e": {"docs": {"shimmer.modules.losses.GWLosses.__init__": {"tf": 1}}, "df": 1}}}}}}}}, "l": {"docs": {}, "df": 0, "u": {"docs": {"shimmer.modules.gw_module.get_n_layers": {"tf": 1.4142135623730951}}, "df": 1}, "e": {"docs": {}, "df": 0, "v": {"docs": {}, "df": 0, "a": {"docs": {}, "df": 0, "n": {"docs": {}, "df": 0, "t": {"docs": 
{"shimmer.utils.SaveMigrations": {"tf": 1}}, "df": 1}}}}}}, "c": {"docs": {}, "df": 0, "i": {"docs": {}, "df": 0, "p": {"docs": {}, "df": 0, "e": {"docs": {"shimmer.modules.gw_module.GWEncoder.forward": {"tf": 1}, "shimmer.modules.gw_module.GWEncoderLinear.forward": {"tf": 1}}, "df": 2}}}, "o": {"docs": {}, "df": 0, "n": {"docs": {}, "df": 0, "s": {"docs": {}, "df": 0, "t": {"docs": {}, "df": 0, "r": {"docs": {}, "df": 0, "u": {"docs": {}, "df": 0, "c": {"docs": {}, "df": 0, "t": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "d": {"docs": {"shimmer.modules.vae.VAEDecoder.forward": {"tf": 1}, "shimmer.modules.vae.VAE.decode": {"tf": 1}, "shimmer.modules.vae.VAE.forward": {"tf": 1}}, "df": 3}}}}}}}}}}}, "g": {"docs": {}, "df": 0, "i": {"docs": {}, "df": 0, "s": {"docs": {}, "df": 0, "t": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "r": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "d": {"docs": {"shimmer.modules.gw_module.GWEncoder.forward": {"tf": 1}, "shimmer.modules.gw_module.GWEncoderLinear.forward": {"tf": 1}}, "df": 2}}}}}}}}, "d": {"docs": {}, "df": 0, "u": {"docs": {}, "df": 0, "c": {"docs": {}, "df": 0, "t": {"docs": {}, "df": 0, "i": {"docs": {}, "df": 0, "o": {"docs": {}, "df": 0, "n": {"docs": {"shimmer.modules.contrastive_loss.info_nce": {"tf": 1.4142135623730951}, "shimmer.modules.contrastive_loss.contrastive_loss": {"tf": 1.4142135623730951}, "shimmer.modules.contrastive_loss.ContrastiveLoss.__init__": {"tf": 1.4142135623730951}}, "df": 3}}}}}}}, "m": {"docs": {}, "df": 0, "o": {"docs": {}, "df": 0, "v": {"docs": {}, "df": 0, "e": {"docs": {"shimmer.dataset.RepeatedDataset.__init__": {"tf": 1}}, "df": 1}}}}}, "o": {"docs": {}, "df": 0, "l": {"docs": {}, "df": 0, "e": {"docs": {"shimmer.modules.global_workspace.GlobalWorkspace2Domains.__init__": {"tf": 1.4142135623730951}, "shimmer.modules.global_workspace.GlobalWorkspace.__init__": {"tf": 1.4142135623730951}, "shimmer.modules.global_workspace.GlobalWorkspaceBayesian.__init__": {"tf": 
1.4142135623730951}, "shimmer.modules.global_workspace.pretrained_global_workspace": {"tf": 1.4142135623730951}}, "df": 4}}}, "u": {"docs": {}, "df": 0, "n": {"docs": {}, "df": 0, "n": {"docs": {}, "df": 0, "i": {"docs": {}, "df": 0, "n": {"docs": {}, "df": 0, "g": {"docs": {"shimmer.modules.gw_module.GWEncoder.forward": {"tf": 1}, "shimmer.modules.gw_module.GWEncoderLinear.forward": {"tf": 1}}, "df": 2}}}}}}}, "u": {"docs": {}, "df": 0, "n": {"docs": {}, "df": 0, "i": {"docs": {}, "df": 0, "m": {"docs": {}, "df": 0, "o": {"docs": {}, "df": 0, "d": {"docs": {}, "df": 0, "a": {"docs": {}, "df": 0, "l": {"docs": {"shimmer.types.RawDomainGroupT": {"tf": 1}, "shimmer.types.RawDomainGroupDT": {"tf": 1}, "shimmer.types.LatentsDomainGroupT": {"tf": 1}, "shimmer.types.LatentsDomainGroupDT": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.decode": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.encode_domain": {"tf": 1.4142135623730951}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.encode_domains": {"tf": 1.4142135623730951}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.decode_domain": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.decode_domains": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.generic_step": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspace2Domains.__init__": {"tf": 1.4142135623730951}, "shimmer.modules.global_workspace.GlobalWorkspace.__init__": {"tf": 1.4142135623730951}, "shimmer.modules.global_workspace.GlobalWorkspaceBayesian.__init__": {"tf": 1.4142135623730951}, "shimmer.modules.global_workspace.pretrained_global_workspace": {"tf": 1.4142135623730951}, "shimmer.modules.domain.DomainModule.__init__": {"tf": 1}, "shimmer.modules.domain.DomainModule.encode": {"tf": 1.4142135623730951}, "shimmer.modules.domain.DomainModule.decode": {"tf": 1.4142135623730951}, "shimmer.modules.gw_module.GWModuleBase": {"tf": 1}, 
"shimmer.modules.gw_module.GWModuleBase.domain_mods": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBase.decode": {"tf": 1}, "shimmer.modules.gw_module.GWModule.__init__": {"tf": 1.4142135623730951}, "shimmer.modules.gw_module.GWModule.decode": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBayesian.__init__": {"tf": 1.4142135623730951}, "shimmer.modules.selection.SelectionBase.forward": {"tf": 1}, "shimmer.modules.selection.SingleDomainSelection.forward": {"tf": 1}, "shimmer.modules.selection.FixedSharedSelection.forward": {"tf": 1}, "shimmer.modules.selection.KQFixedQSelection.forward": {"tf": 1}, "shimmer.modules.selection.RandomSelection.forward": {"tf": 1}, "shimmer.modules.selection.DynamicQueryAttention.fuse_weighted_encodings": {"tf": 1}, "shimmer.modules.selection.DynamicQueryAttention.forward": {"tf": 1}, "shimmer.modules.losses.demi_cycle_loss": {"tf": 1}, "shimmer.modules.losses.cycle_loss": {"tf": 1}, "shimmer.modules.losses.translation_loss": {"tf": 1}, "shimmer.modules.losses.contrastive_loss": {"tf": 1}, "shimmer.modules.losses.contrastive_loss_bayesian": {"tf": 1}, "shimmer.modules.losses.GWLosses2Domains.demi_cycle_loss": {"tf": 1}, "shimmer.modules.losses.GWLosses2Domains.cycle_loss": {"tf": 1}, "shimmer.modules.losses.GWLosses2Domains.translation_loss": {"tf": 1}, "shimmer.modules.losses.GWLosses2Domains.contrastive_loss": {"tf": 1}, "shimmer.modules.losses.GWLossesBayesian.contrastive_loss": {"tf": 1}, "shimmer.modules.utils.translation": {"tf": 1}, "shimmer.modules.utils.cycle": {"tf": 1.4142135623730951}}, "df": 42}}}}}, "f": {"docs": {}, "df": 0, "o": {"docs": {}, "df": 0, "r": {"docs": {}, "df": 0, "m": {"docs": {"shimmer.modules.selection.SingleDomainSelection": {"tf": 1}, "shimmer.modules.selection.RandomSelection": {"tf": 1.4142135623730951}, "shimmer.modules.selection.RandomSelection.__init__": {"tf": 1}, "shimmer.modules.selection.RandomSelection.forward": {"tf": 1.4142135623730951}}, "df": 4}}}}}, "p": {"docs": {}, "df": 0, "a": 
{"docs": {}, "df": 0, "i": {"docs": {}, "df": 0, "r": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "d": {"docs": {"shimmer.types.RawDomainGroupsT": {"tf": 1}, "shimmer.types.RawDomainGroupsDT": {"tf": 1}, "shimmer.types.LatentsDomainGroupsT": {"tf": 1}, "shimmer.types.LatentsDomainGroupsDT": {"tf": 1}}, "df": 4}}}}}}, "c": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "r": {"docs": {}, "df": 0, "t": {"docs": {}, "df": 0, "a": {"docs": {}, "df": 0, "i": {"docs": {}, "df": 0, "n": {"docs": {}, "df": 0, "t": {"docs": {}, "df": 0, "y": {"docs": {"shimmer.modules.global_workspace.GlobalWorkspaceBayesian": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBayesian": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBayesian.fuse": {"tf": 1}, "shimmer.modules.losses.contrastive_loss_bayesian": {"tf": 1}, "shimmer.modules.losses.GWLossesBayesian.__init__": {"tf": 1}}, "df": 5}, "i": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "s": {"docs": {"shimmer.modules.gw_module.GWModuleBayesian.fuse": {"tf": 1}}, "df": 1}}}}}}}}}}}, "s": {"docs": {}, "df": 0, "t": {"docs": {}, "df": 0, "a": {"docs": {}, "df": 0, "b": {"docs": {}, "df": 0, "i": {"docs": {}, "df": 0, "l": {"docs": {}, "df": 0, "i": {"docs": {}, "df": 0, "t": {"docs": {}, "df": 0, "y": {"docs": {"shimmer.modules.gw_module.compute_fusion_scores": {"tf": 1}}, "df": 1}}}}}}}}}, "d": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "r": {"docs": {"shimmer.utils.migrate_model": {"tf": 1}}, "df": 1}}}}, "s": {"docs": {}, "df": 0, "e": {"docs": {"shimmer.types.RawDomainGroupT": {"tf": 1}, "shimmer.types.RawDomainGroupDT": {"tf": 1}, "shimmer.types.LatentsDomainGroupT": {"tf": 1}, "shimmer.types.LatentsDomainGroupDT": {"tf": 1}, "shimmer.types.RawDomainGroupsT": {"tf": 1}, "shimmer.types.RawDomainGroupsDT": {"tf": 1}, "shimmer.types.LatentsDomainGroupsT": {"tf": 1}, "shimmer.types.LatentsDomainGroupsDT": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBayesian.__init__": {"tf": 1.7320508075688772}, 
"shimmer.modules.domain.LossOutput": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBase.fuse": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBase.encode_and_fuse": {"tf": 1}, "shimmer.modules.gw_module.GWModule.fuse": {"tf": 1}, "shimmer.modules.gw_module.GWModule.decode": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBayesian.__init__": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBayesian.fuse": {"tf": 1}, "shimmer.modules.losses.demi_cycle_loss": {"tf": 1.4142135623730951}, "shimmer.modules.losses.cycle_loss": {"tf": 1.4142135623730951}, "shimmer.modules.losses.translation_loss": {"tf": 1}, "shimmer.modules.losses.contrastive_loss": {"tf": 1}, "shimmer.modules.losses.contrastive_loss_bayesian": {"tf": 1}, "shimmer.modules.losses.GWLosses2Domains.__init__": {"tf": 1.4142135623730951}, "shimmer.modules.losses.broadcast_loss": {"tf": 1.4142135623730951}, "shimmer.modules.losses.GWLossesBayesian.__init__": {"tf": 2}, "shimmer.modules.losses.GWLossesBayesian.contrastive_fn": {"tf": 1}, "shimmer.modules.utils.batch_cycles": {"tf": 1}}, "df": 26, "s": {"docs": {"shimmer.types.RawDomainGroupT": {"tf": 1}, "shimmer.types.RawDomainGroupDT": {"tf": 1}, "shimmer.types.LatentsDomainGroupT": {"tf": 1}, "shimmer.types.LatentsDomainGroupDT": {"tf": 1}, "shimmer.types.RawDomainGroupsT": {"tf": 1}, "shimmer.types.RawDomainGroupsDT": {"tf": 1}, "shimmer.types.LatentsDomainGroupsT": {"tf": 1}, "shimmer.types.LatentsDomainGroupsDT": {"tf": 1}}, "df": 8}, "d": {"docs": {"shimmer.types.RawDomainGroupT": {"tf": 1}, "shimmer.types.RawDomainGroupDT": {"tf": 1}, "shimmer.types.LatentsDomainGroupT": {"tf": 1}, "shimmer.types.LatentsDomainGroupDT": {"tf": 1}, "shimmer.types.RawDomainGroupsT": {"tf": 1}, "shimmer.types.RawDomainGroupsDT": {"tf": 1}, "shimmer.types.LatentsDomainGroupsT": {"tf": 1}, "shimmer.types.LatentsDomainGroupsDT": {"tf": 1}, "shimmer.types.ModelModeT": {"tf": 1.4142135623730951}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.generic_step": {"tf": 1}, 
"shimmer.modules.global_workspace.GlobalWorkspace2Domains": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspace2Domains.__init__": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspace": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspace.__init__": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBayesian": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBayesian.__init__": {"tf": 1}, "shimmer.modules.global_workspace.pretrained_global_workspace": {"tf": 1}, "shimmer.modules.domain.LossOutput": {"tf": 1.4142135623730951}, "shimmer.modules.domain.LossOutput.loss": {"tf": 1}, "shimmer.modules.domain.LossOutput.metrics": {"tf": 1}, "shimmer.modules.gw_module.GWEncoder": {"tf": 1}, "shimmer.modules.gw_module.GWEncoderLinear": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBase.fuse": {"tf": 1}, "shimmer.modules.gw_module.GWModule.fuse": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBayesian.fuse": {"tf": 1}, "shimmer.modules.selection.SelectionBase.update_gw_state": {"tf": 1}, "shimmer.modules.selection.RandomSelection.forward": {"tf": 1}, "shimmer.modules.losses.GWLossesBase": {"tf": 1}, "shimmer.modules.losses.LossCoefs": {"tf": 1}, "shimmer.modules.losses.GWLosses2Domains": {"tf": 1}, "shimmer.modules.losses.BroadcastLossCoefs": {"tf": 1}, "shimmer.modules.losses.GWLosses.__init__": {"tf": 1}, "shimmer.modules.losses.GWLossesBayesian": {"tf": 1}, "shimmer.modules.vae.kl_divergence_loss": {"tf": 1}, "shimmer.modules.vae.gaussian_nll": {"tf": 1}, "shimmer.utils.SaveMigrations": {"tf": 1}}, "df": 36}}, "i": {"docs": {}, "df": 0, "n": {"docs": {}, "df": 0, "g": {"docs": {"shimmer.modules.global_workspace.GlobalWorkspace2Domains.__init__": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspace.__init__": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBayesian.__init__": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBase": {"tf": 1}, "shimmer.modules.gw_module.compute_fusion_scores": {"tf": 1}, 
"shimmer.modules.gw_module.GWModuleBayesian.fuse": {"tf": 1.4142135623730951}, "shimmer.modules.selection.RandomSelection": {"tf": 1}, "shimmer.modules.selection.DynamicQueryAttention.fuse_weighted_encodings": {"tf": 1}, "shimmer.modules.vae.reparameterize": {"tf": 1}}, "df": 9}}}}, "p": {"docs": {}, "df": 0, "d": {"docs": {}, "df": 0, "a": {"docs": {}, "df": 0, "t": {"docs": {}, "df": 0, "e": {"docs": {"shimmer.modules.selection.SelectionBase.update_gw_state": {"tf": 1}}, "df": 1, "d": {"docs": {"shimmer.modules.selection.DynamicQueryAttention": {"tf": 1}}, "df": 1}}}}}}, "t": {"docs": {}, "df": 0, "i": {"docs": {}, "df": 0, "l": {"docs": {}, "df": 0, "i": {"docs": {}, "df": 0, "z": {"docs": {}, "df": 0, "e": {"docs": {"shimmer.modules.selection.RandomSelection": {"tf": 1}}, "df": 1}}}, "s": {"docs": {"shimmer.cli.ckpt_migration.migrate_ckpt": {"tf": 1}}, "df": 1}}}}}, "d": {"docs": {"shimmer.modules.gw_module.GWModuleBayesian.fuse": {"tf": 1.7320508075688772}}, "df": 1, "a": {"docs": {}, "df": 0, "t": {"docs": {}, "df": 0, "a": {"docs": {"shimmer.types.RawDomainGroupT": {"tf": 1.4142135623730951}, "shimmer.types.RawDomainGroupDT": {"tf": 1.4142135623730951}, "shimmer.types.RawDomainGroupsT": {"tf": 1}, "shimmer.types.RawDomainGroupsDT": {"tf": 1}, "shimmer.types.LatentsDomainGroupsT": {"tf": 1}, "shimmer.types.LatentsDomainGroupsDT": {"tf": 1}, "shimmer.types.ModelModeT": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.encode_domain": {"tf": 1.4142135623730951}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.encode_domains": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.decode_domain": {"tf": 1.7320508075688772}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.decode_domains": {"tf": 1.4142135623730951}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.generic_step": {"tf": 1}, "shimmer.modules.domain.DomainModule.encode": {"tf": 1.4142135623730951}, "shimmer.modules.domain.DomainModule.decode": {"tf": 
1.7320508075688772}, "shimmer.modules.vae.VAE.forward": {"tf": 1}}, "df": 15, "c": {"docs": {}, "df": 0, "l": {"docs": {}, "df": 0, "a": {"docs": {}, "df": 0, "s": {"docs": {}, "df": 0, "s": {"docs": {"shimmer.modules.domain.LossOutput": {"tf": 1}}, "df": 1}}}}}, "s": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "t": {"docs": {"shimmer.dataset.RepeatedDataset": {"tf": 1.4142135623730951}, "shimmer.dataset.RepeatedDataset.__init__": {"tf": 2.23606797749979}}, "df": 2}}}}}}, "o": {"docs": {"shimmer.modules.utils.cycle": {"tf": 1}, "shimmer.modules.utils.batch_cycles": {"tf": 1}, "shimmer.modules.utils.batch_translations": {"tf": 1}}, "df": 3, "m": {"docs": {}, "df": 0, "a": {"docs": {}, "df": 0, "i": {"docs": {}, "df": 0, "n": {"docs": {"shimmer.types.RawDomainGroupT": {"tf": 1.4142135623730951}, "shimmer.types.RawDomainGroupDT": {"tf": 1}, "shimmer.types.LatentsDomainGroupT": {"tf": 1.4142135623730951}, "shimmer.types.LatentsDomainGroupDT": {"tf": 1}, "shimmer.types.LatentsDomainGroupsT": {"tf": 1}, "shimmer.modules.global_workspace.GWPredictionsBase.states": {"tf": 1.7320508075688772}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.encode_and_fuse": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.encode": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.batch_gw_states": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.encode_domain": {"tf": 2.8284271247461903}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.encode_domains": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.decode_domain": {"tf": 2.8284271247461903}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.decode_domains": {"tf": 1}, "shimmer.modules.global_workspace.freeze_domain_modules": {"tf": 1.7320508075688772}, "shimmer.modules.global_workspace.GWPredictions.demi_cycles": {"tf": 1.7320508075688772}, "shimmer.modules.global_workspace.GWPredictions.cycles": {"tf": 2.23606797749979}, 
"shimmer.modules.global_workspace.GWPredictions.translations": {"tf": 2.23606797749979}, "shimmer.modules.global_workspace.GlobalWorkspace2Domains.__init__": {"tf": 2}, "shimmer.modules.global_workspace.GlobalWorkspace": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspace.__init__": {"tf": 2}, "shimmer.modules.global_workspace.GlobalWorkspaceBayesian.__init__": {"tf": 2}, "shimmer.modules.global_workspace.pretrained_global_workspace": {"tf": 2}, "shimmer.modules.domain.DomainModule": {"tf": 1}, "shimmer.modules.domain.DomainModule.encode": {"tf": 1.4142135623730951}, "shimmer.modules.domain.DomainModule.decode": {"tf": 1.7320508075688772}, "shimmer.modules.gw_module.GWModuleBase.__init__": {"tf": 1.4142135623730951}, "shimmer.modules.gw_module.GWModuleBase.domain_mods": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBase.encode": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBase.encode_and_fuse": {"tf": 1}, "shimmer.modules.gw_module.GWModule.__init__": {"tf": 2}, "shimmer.modules.gw_module.GWModule.encode": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBayesian.__init__": {"tf": 2}, "shimmer.modules.gw_module.GWModuleBayesian.precisions": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBayesian.get_precision": {"tf": 1.4142135623730951}, "shimmer.modules.selection.SelectionBase.forward": {"tf": 1}, "shimmer.modules.selection.SingleDomainSelection": {"tf": 1.4142135623730951}, "shimmer.modules.selection.SingleDomainSelection.forward": {"tf": 1}, "shimmer.modules.selection.FixedSharedSelection.forward": {"tf": 1}, "shimmer.modules.selection.KQFixedQSelection.__init__": {"tf": 1.4142135623730951}, "shimmer.modules.selection.KQFixedQSelection.forward": {"tf": 1}, "shimmer.modules.selection.RandomSelection": {"tf": 1}, "shimmer.modules.selection.RandomSelection.forward": {"tf": 1.7320508075688772}, "shimmer.modules.selection.DynamicQueryAttention.__init__": {"tf": 1.4142135623730951}, "shimmer.modules.selection.DynamicQueryAttention.fuse_weighted_encodings": 
{"tf": 1}, "shimmer.modules.selection.DynamicQueryAttention.forward": {"tf": 1}, "shimmer.modules.losses.GWLossesBase.step": {"tf": 1}, "shimmer.modules.losses.demi_cycle_loss": {"tf": 2.6457513110645907}, "shimmer.modules.losses.cycle_loss": {"tf": 3.1622776601683795}, "shimmer.modules.losses.translation_loss": {"tf": 3.3166247903554}, "shimmer.modules.losses.contrastive_loss": {"tf": 2.6457513110645907}, "shimmer.modules.losses.contrastive_loss_bayesian": {"tf": 2.6457513110645907}, "shimmer.modules.losses.GWLosses2Domains.__init__": {"tf": 1.4142135623730951}, "shimmer.modules.losses.GWLosses2Domains.step": {"tf": 1}, "shimmer.modules.losses.broadcast_loss": {"tf": 1.7320508075688772}, "shimmer.modules.losses.GWLosses.__init__": {"tf": 1.4142135623730951}, "shimmer.modules.losses.GWLosses.contrastive_loss": {"tf": 1}, "shimmer.modules.losses.GWLosses.step": {"tf": 1}, "shimmer.modules.losses.GWLossesBayesian.__init__": {"tf": 1.4142135623730951}, "shimmer.modules.losses.GWLossesBayesian.domain_mods": {"tf": 1}, "shimmer.modules.losses.GWLossesBayesian.step": {"tf": 1}, "shimmer.modules.utils.translation": {"tf": 1.7320508075688772}, "shimmer.modules.utils.cycle": {"tf": 1.4142135623730951}, "shimmer.modules.utils.batch_demi_cycles": {"tf": 1}, "shimmer.modules.utils.batch_cycles": {"tf": 2}, "shimmer.modules.utils.batch_translations": {"tf": 1.4142135623730951}, "shimmer.utils.groups_batch_size": {"tf": 1}, "shimmer.utils.groups_device": {"tf": 1}}, "df": 67, "s": {"docs": {"shimmer.types.RawDomainGroupT": {"tf": 1.4142135623730951}, "shimmer.types.RawDomainGroupDT": {"tf": 1.4142135623730951}, "shimmer.types.LatentsDomainGroupT": {"tf": 1.4142135623730951}, "shimmer.types.LatentsDomainGroupDT": {"tf": 1.4142135623730951}, "shimmer.types.RawDomainGroupsT": {"tf": 1}, "shimmer.types.RawDomainGroupsDT": {"tf": 1}, "shimmer.types.LatentsDomainGroupsT": {"tf": 1}, "shimmer.types.LatentsDomainGroupsDT": {"tf": 1}, 
"shimmer.modules.global_workspace.GlobalWorkspaceBase.decode": {"tf": 1.7320508075688772}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.batch_gw_states": {"tf": 1.7320508075688772}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.encode_domains": {"tf": 1.4142135623730951}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.decode_domains": {"tf": 1.4142135623730951}, "shimmer.modules.global_workspace.GlobalWorkspace2Domains": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspace2Domains.__init__": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspace2Domains.forward": {"tf": 1.4142135623730951}, "shimmer.modules.global_workspace.GlobalWorkspace.__init__": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspace.forward": {"tf": 1.4142135623730951}, "shimmer.modules.global_workspace.GlobalWorkspaceBayesian": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBayesian.__init__": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBayesian.forward": {"tf": 1.4142135623730951}, "shimmer.modules.global_workspace.pretrained_global_workspace": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBase.fuse": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBase.decode": {"tf": 1.7320508075688772}, "shimmer.modules.gw_module.GWModule.fuse": {"tf": 1}, "shimmer.modules.gw_module.GWModule.decode": {"tf": 2}, "shimmer.modules.gw_module.GWModuleBayesian.fuse": {"tf": 1.4142135623730951}, "shimmer.modules.selection.SelectionBase": {"tf": 1}, "shimmer.modules.selection.SelectionBase.forward": {"tf": 1}, "shimmer.modules.selection.SingleDomainSelection": {"tf": 1.4142135623730951}, "shimmer.modules.selection.SingleDomainSelection.forward": {"tf": 1}, "shimmer.modules.selection.FixedSharedSelection": {"tf": 1.7320508075688772}, "shimmer.modules.selection.FixedSharedSelection.forward": {"tf": 1}, "shimmer.modules.selection.KQFixedQSelection.__init__": {"tf": 1}, "shimmer.modules.selection.KQFixedQSelection.forward": {"tf": 1}, 
"shimmer.modules.selection.RandomSelection.forward": {"tf": 1}, "shimmer.modules.selection.DynamicQueryAttention.__init__": {"tf": 1}, "shimmer.modules.selection.DynamicQueryAttention.forward": {"tf": 1}, "shimmer.modules.losses.demi_cycle_loss": {"tf": 1}, "shimmer.modules.losses.cycle_loss": {"tf": 1}, "shimmer.modules.losses.translation_loss": {"tf": 1}, "shimmer.modules.losses.contrastive_loss": {"tf": 1.4142135623730951}, "shimmer.modules.losses.contrastive_loss_bayesian": {"tf": 1.4142135623730951}, "shimmer.modules.losses.GWLosses2Domains.demi_cycle_loss": {"tf": 1}, "shimmer.modules.losses.GWLosses2Domains.cycle_loss": {"tf": 1}, "shimmer.modules.losses.GWLosses2Domains.translation_loss": {"tf": 1}, "shimmer.modules.losses.GWLosses2Domains.contrastive_loss": {"tf": 1}, "shimmer.modules.losses.broadcast_loss": {"tf": 1}, "shimmer.modules.losses.BroadcastLossCoefs.fused": {"tf": 1}, "shimmer.modules.losses.GWLosses.contrastive_loss": {"tf": 1.4142135623730951}, "shimmer.modules.losses.GWLosses.step": {"tf": 1}, "shimmer.modules.losses.GWLossesBayesian.contrastive_loss": {"tf": 1}, "shimmer.modules.losses.GWLossesBayesian.step": {"tf": 1}, "shimmer.modules.utils.translation": {"tf": 1}, "shimmer.modules.utils.cycle": {"tf": 1.4142135623730951}, "shimmer.modules.utils.batch_demi_cycles": {"tf": 1.7320508075688772}, "shimmer.modules.utils.batch_cycles": {"tf": 2}, "shimmer.modules.utils.batch_translations": {"tf": 1.7320508075688772}}, "df": 57}, "m": {"docs": {}, "df": 0, "o": {"docs": {}, "df": 0, "d": {"docs": {}, "df": 0, "u": {"docs": {}, "df": 0, "l": {"docs": {}, "df": 0, "e": {"docs": {"shimmer.modules.global_workspace.GlobalWorkspaceBase.encode_domain": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.decode_domain": {"tf": 1}, "shimmer.modules.global_workspace.freeze_domain_modules": {"tf": 1.4142135623730951}, "shimmer.modules.global_workspace.GlobalWorkspace2Domains.__init__": {"tf": 1.4142135623730951}, 
"shimmer.modules.global_workspace.GlobalWorkspace.__init__": {"tf": 1.4142135623730951}, "shimmer.modules.global_workspace.GlobalWorkspaceBayesian.__init__": {"tf": 1.4142135623730951}, "shimmer.modules.global_workspace.pretrained_global_workspace": {"tf": 1.4142135623730951}, "shimmer.modules.domain.DomainModule": {"tf": 1}, "shimmer.modules.domain.DomainModule.__init__": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBase.__init__": {"tf": 1}, "shimmer.modules.gw_module.GWModule.__init__": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBayesian.__init__": {"tf": 1}, "shimmer.modules.losses.demi_cycle_loss": {"tf": 1}, "shimmer.modules.losses.cycle_loss": {"tf": 1}, "shimmer.modules.losses.translation_loss": {"tf": 1}, "shimmer.modules.losses.GWLosses2Domains.__init__": {"tf": 1.4142135623730951}, "shimmer.modules.losses.broadcast_loss": {"tf": 1}, "shimmer.modules.losses.GWLosses.__init__": {"tf": 1}, "shimmer.modules.losses.GWLossesBayesian.__init__": {"tf": 1.4142135623730951}}, "df": 19}}}}}}}}}}, "g": {"docs": {"shimmer.types.RawDomainGroupT": {"tf": 1}, "shimmer.types.RawDomainGroupDT": {"tf": 1}, "shimmer.types.RawDomainGroupsT": {"tf": 1}, "shimmer.types.RawDomainGroupsDT": {"tf": 1}}, "df": 4}, "t": {"docs": {"shimmer.modules.selection.KQFixedQSelection.forward": {"tf": 1}, "shimmer.modules.selection.DynamicQueryAttention.forward": {"tf": 1}}, "df": 2}, "e": {"docs": {}, "df": 0, "s": {"docs": {"shimmer.modules.selection.KQFixedQSelection.forward": {"tf": 1}, "shimmer.modules.selection.DynamicQueryAttention.forward": {"tf": 1}}, "df": 2}}, "n": {"docs": {}, "df": 0, "e": {"docs": {"shimmer.modules.utils.batch_cycles": {"tf": 1}}, "df": 1}}}, "e": {"docs": {}, "df": 0, "f": {"docs": {"shimmer.types.RawDomainGroupT": {"tf": 1}, "shimmer.types.RawDomainGroupDT": {"tf": 1}, "shimmer.types.LatentsDomainGroupT": {"tf": 1}, "shimmer.types.LatentsDomainGroupDT": {"tf": 1}, "shimmer.types.RawDomainGroupsT": {"tf": 1}, "shimmer.types.RawDomainGroupsDT": {"tf": 1}, 
"shimmer.types.LatentsDomainGroupsT": {"tf": 1}, "shimmer.types.LatentsDomainGroupsDT": {"tf": 1}}, "df": 8, "a": {"docs": {}, "df": 0, "u": {"docs": {}, "df": 0, "l": {"docs": {}, "df": 0, "t": {"docs": {"shimmer.modules.global_workspace.GlobalWorkspace2Domains.__init__": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspace.__init__": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBayesian.__init__": {"tf": 1}, "shimmer.modules.selection.SelectionBase.update_gw_state": {"tf": 1}}, "df": 4, "s": {"docs": {"shimmer.modules.gw_module.GWModule.decode": {"tf": 1}, "shimmer.modules.contrastive_loss.ContrastiveLoss.__init__": {"tf": 1.4142135623730951}, "shimmer.modules.vae.VAE.__init__": {"tf": 1}}, "df": 3}}}}}, "i": {"docs": {}, "df": 0, "n": {"docs": {}, "df": 0, "e": {"docs": {"shimmer.modules.gw_module.GWEncoder.forward": {"tf": 1}, "shimmer.modules.gw_module.GWEncoderLinear.forward": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBase": {"tf": 1}}, "df": 3, "s": {"docs": {"shimmer.modules.domain.DomainModule": {"tf": 1}}, "df": 1}, "d": {"docs": {"shimmer.modules.gw_module.GWEncoder.forward": {"tf": 1}, "shimmer.modules.gw_module.GWEncoderLinear.forward": {"tf": 1}, "shimmer.modules.selection.SelectionBase.update_gw_state": {"tf": 1}, "shimmer.dataset.RepeatedDataset.__init__": {"tf": 1}}, "df": 4}}, "i": {"docs": {}, "df": 0, "t": {"docs": {}, "df": 0, "i": {"docs": {}, "df": 0, "o": {"docs": {}, "df": 0, "n": {"docs": {}, "df": 0, "s": {"docs": {"shimmer.modules.contrastive_loss": {"tf": 1}}, "df": 1}}}}}}}}}, "c": {"docs": {}, "df": 0, "o": {"docs": {}, "df": 0, "d": {"docs": {}, "df": 0, "e": {"docs": {"shimmer.modules.global_workspace.GlobalWorkspaceBase.decode": {"tf": 1.4142135623730951}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.decode_domain": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.decode_domains": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspace2Domains.__init__": {"tf": 1}, 
"shimmer.modules.global_workspace.GlobalWorkspace.__init__": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBayesian.__init__": {"tf": 1}, "shimmer.modules.global_workspace.pretrained_global_workspace": {"tf": 1}, "shimmer.modules.domain.DomainModule.decode": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBase.decode": {"tf": 1.4142135623730951}, "shimmer.modules.gw_module.GWModule.decode": {"tf": 1}, "shimmer.modules.losses.BroadcastLossCoefs.fused": {"tf": 1}, "shimmer.modules.vae.VAEDecoder.forward": {"tf": 1}, "shimmer.modules.vae.VAE.decode": {"tf": 1}}, "df": 13, "d": {"docs": {"shimmer.modules.global_workspace.GlobalWorkspaceBase.decode": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBase.decode": {"tf": 1}, "shimmer.modules.gw_module.GWModule.decode": {"tf": 1}}, "df": 3}, "s": {"docs": {"shimmer.modules.global_workspace.GlobalWorkspaceBase.decode_domain": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.decode_domains": {"tf": 1}, "shimmer.modules.gw_module.GWModule.__init__": {"tf": 1}, "shimmer.modules.gw_module.GWModule.decode": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBayesian.__init__": {"tf": 1}, "shimmer.modules.vae.VAE.forward": {"tf": 1}}, "df": 6}, "r": {"docs": {"shimmer.modules.gw_module.GWDecoder": {"tf": 1}, "shimmer.modules.gw_module.GWDecoder.__init__": {"tf": 1}, "shimmer.modules.gw_module.GWEncoder": {"tf": 1}, "shimmer.modules.vae.VAEDecoder": {"tf": 1}, "shimmer.modules.vae.VAE.__init__": {"tf": 1.4142135623730951}, "shimmer.modules.vae.VAE.decoder": {"tf": 1}}, "df": 6, "s": {"docs": {"shimmer.modules.global_workspace.GlobalWorkspace2Domains.__init__": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspace.__init__": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBayesian.__init__": {"tf": 1}, "shimmer.modules.global_workspace.pretrained_global_workspace": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBase": {"tf": 1}, "shimmer.modules.gw_module.GWModule.__init__": {"tf": 1}, 
"shimmer.modules.gw_module.GWModule.gw_decoders": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBayesian.__init__": {"tf": 1}}, "df": 8}}}, "i": {"docs": {}, "df": 0, "n": {"docs": {}, "df": 0, "g": {"docs": {"shimmer.modules.gw_module.GWModuleBase": {"tf": 1}}, "df": 1}}}}}, "a": {"docs": {}, "df": 0, "y": {"docs": {"shimmer.modules.global_workspace.GlobalWorkspace2Domains.__init__": {"tf": 1.4142135623730951}, "shimmer.modules.global_workspace.GlobalWorkspace.__init__": {"tf": 1.4142135623730951}, "shimmer.modules.global_workspace.GlobalWorkspaceBayesian.__init__": {"tf": 1.4142135623730951}}, "df": 3}}}, "m": {"docs": {}, "df": 0, "i": {"docs": {"shimmer.modules.global_workspace.GWPredictions.demi_cycles": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspace2Domains.forward": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspace.forward": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBayesian.forward": {"tf": 1}, "shimmer.modules.domain.DomainModule.compute_dcy_loss": {"tf": 1.4142135623730951}, "shimmer.modules.losses.GWLossesBase": {"tf": 1}, "shimmer.modules.losses.demi_cycle_loss": {"tf": 2.449489742783178}, "shimmer.modules.losses.LossCoefs.demi_cycles": {"tf": 1}, "shimmer.modules.losses.GWLosses2Domains.demi_cycle_loss": {"tf": 1.4142135623730951}, "shimmer.modules.losses.GWLosses2Domains.step": {"tf": 1.4142135623730951}, "shimmer.modules.losses.broadcast_loss": {"tf": 1}, "shimmer.modules.losses.BroadcastLossCoefs.demi_cycles": {"tf": 1.4142135623730951}, "shimmer.modules.utils.batch_demi_cycles": {"tf": 1.4142135623730951}}, "df": 13}}, "t": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "r": {"docs": {}, "df": 0, "m": {"docs": {}, "df": 0, "i": {"docs": {}, "df": 0, "n": {"docs": {}, "df": 0, "i": {"docs": {}, "df": 0, "s": {"docs": {}, "df": 0, "t": {"docs": {}, "df": 0, "i": {"docs": {}, "df": 0, "c": {"docs": {"shimmer.modules.selection.FixedSharedSelection": {"tf": 1}}, "df": 1}}}}}, "e": {"docs": {}, "df": 0, "s": 
{"docs": {"shimmer.modules.selection.RandomSelection.forward": {"tf": 1}}, "df": 1}}}}}}}}}, "i": {"docs": {}, "df": 0, "c": {"docs": {}, "df": 0, "t": {"docs": {"shimmer.types.RawDomainGroupDT": {"tf": 1}, "shimmer.types.LatentsDomainGroupDT": {"tf": 1.4142135623730951}, "shimmer.types.RawDomainGroupsT": {"tf": 1}, "shimmer.types.RawDomainGroupsDT": {"tf": 1}, "shimmer.types.LatentsDomainGroupsDT": {"tf": 1}, "shimmer.modules.domain.LossOutput.all": {"tf": 1}, "shimmer.modules.selection.DynamicQueryAttention.fuse_weighted_encodings": {"tf": 1}, "shimmer.modules.losses.demi_cycle_loss": {"tf": 1}, "shimmer.modules.losses.cycle_loss": {"tf": 1}, "shimmer.modules.losses.translation_loss": {"tf": 1}, "shimmer.modules.losses.contrastive_loss": {"tf": 1}, "shimmer.modules.losses.contrastive_loss_bayesian": {"tf": 1}, "shimmer.modules.losses.LossCoefs": {"tf": 1}, "shimmer.modules.losses.GWLosses2Domains.__init__": {"tf": 1}, "shimmer.modules.losses.GWLosses2Domains.demi_cycle_loss": {"tf": 1}, "shimmer.modules.losses.GWLosses2Domains.cycle_loss": {"tf": 1}, "shimmer.modules.losses.GWLosses2Domains.translation_loss": {"tf": 1}, "shimmer.modules.losses.GWLosses2Domains.contrastive_loss": {"tf": 1}, "shimmer.modules.losses.BroadcastLossCoefs": {"tf": 1}, "shimmer.modules.losses.GWLossesBayesian.__init__": {"tf": 1}, "shimmer.modules.losses.GWLossesBayesian.contrastive_loss": {"tf": 1}}, "df": 21, "[": {"docs": {}, "df": 0, "f": {"docs": {}, "df": 0, "r": {"docs": {}, "df": 0, "o": {"docs": {}, "df": 0, "z": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "n": {"docs": {}, "df": 0, "s": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "t": {"docs": {}, "df": 0, "[": {"docs": {}, "df": 0, "s": {"docs": {}, "df": 0, "t": {"docs": {}, "df": 0, "r": {"docs": {"shimmer.modules.global_workspace.GlobalWorkspaceBase.encode_and_fuse": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.fuse": {"tf": 1}}, "df": 2}}}}}}}}}}}}}, "s": {"docs": {}, "df": 0, "t": {"docs": 
{}, "df": 0, "r": {"docs": {"shimmer.modules.global_workspace.GlobalWorkspaceBase.decode": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.batch_gw_states": {"tf": 1}, "shimmer.modules.global_workspace.freeze_domain_modules": {"tf": 1}, "shimmer.modules.selection.SelectionBase.forward": {"tf": 1}, "shimmer.modules.selection.SingleDomainSelection.forward": {"tf": 1}, "shimmer.modules.selection.FixedSharedSelection.forward": {"tf": 1}, "shimmer.modules.selection.KQFixedQSelection.forward": {"tf": 1}, "shimmer.modules.selection.RandomSelection.forward": {"tf": 1}, "shimmer.modules.selection.DynamicQueryAttention.fuse_weighted_encodings": {"tf": 1}, "shimmer.modules.selection.DynamicQueryAttention.forward": {"tf": 1}, "shimmer.modules.losses.demi_cycle_loss": {"tf": 1}, "shimmer.modules.losses.cycle_loss": {"tf": 1}, "shimmer.modules.losses.translation_loss": {"tf": 1}, "shimmer.modules.losses.contrastive_loss": {"tf": 1}, "shimmer.modules.losses.contrastive_loss_bayesian": {"tf": 1}, "shimmer.modules.losses.GWLosses2Domains.__init__": {"tf": 1}, "shimmer.modules.losses.GWLosses2Domains.demi_cycle_loss": {"tf": 1}, "shimmer.modules.losses.GWLosses2Domains.cycle_loss": {"tf": 1}, "shimmer.modules.losses.GWLosses2Domains.translation_loss": {"tf": 1}, "shimmer.modules.losses.GWLosses2Domains.contrastive_loss": {"tf": 1}, "shimmer.modules.losses.GWLossesBayesian.__init__": {"tf": 1}, "shimmer.modules.losses.GWLossesBayesian.contrastive_loss": {"tf": 1}, "shimmer.modules.utils.batch_demi_cycles": {"tf": 1}}, "df": 23}}}, "t": {"docs": {}, "df": 0, "u": {"docs": {}, "df": 0, "p": {"docs": {}, "df": 0, "l": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "[": {"docs": {}, "df": 0, "s": {"docs": {}, "df": 0, "t": {"docs": {}, "df": 0, "r": {"docs": {"shimmer.modules.utils.batch_cycles": {"tf": 1}, "shimmer.modules.utils.batch_translations": {"tf": 1}}, "df": 2}}}}}}}}}}, "i": {"docs": {}, "df": 0, "o": {"docs": {}, "df": 0, "n": {"docs": {}, "df": 0, "a": 
{"docs": {}, "df": 0, "r": {"docs": {}, "df": 0, "y": {"docs": {"shimmer.modules.losses.broadcast_loss": {"tf": 1}, "shimmer.modules.losses.GWLosses.contrastive_loss": {"tf": 1}, "shimmer.utils.SaveMigrations.on_save_checkpoint": {"tf": 1}}, "df": 3}}}}}}}}, "f": {"docs": {}, "df": 0, "f": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "r": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "n": {"docs": {}, "df": 0, "t": {"docs": {"shimmer.types.RawDomainGroupsT": {"tf": 1}, "shimmer.types.RawDomainGroupsDT": {"tf": 1}, "shimmer.types.LatentsDomainGroupsT": {"tf": 1}, "shimmer.types.LatentsDomainGroupsDT": {"tf": 1}, "shimmer.modules.domain.DomainModule.compute_dcy_loss": {"tf": 1}, "shimmer.modules.domain.DomainModule.compute_cy_loss": {"tf": 1}, "shimmer.modules.domain.DomainModule.compute_tr_loss": {"tf": 1}, "shimmer.modules.domain.DomainModule.compute_broadcast_loss": {"tf": 1}, "shimmer.modules.losses.GWLossesBase": {"tf": 1}}, "df": 9}}}}}}}, "s": {"docs": {}, "df": 0, "t": {"docs": {}, "df": 0, "r": {"docs": {}, "df": 0, "i": {"docs": {}, "df": 0, "b": {"docs": {}, "df": 0, "u": {"docs": {}, "df": 0, "t": {"docs": {}, "df": 0, "i": {"docs": {}, "df": 0, "o": {"docs": {}, "df": 0, "n": {"docs": {"shimmer.types.ModelModeT": {"tf": 1}, "shimmer.modules.selection.SingleDomainSelection": {"tf": 1}, "shimmer.modules.selection.RandomSelection": {"tf": 1}, "shimmer.modules.vae.reparameterize": {"tf": 1}}, "df": 4}}}}}}}}}}, "m": {"docs": {"shimmer.modules.global_workspace.GlobalWorkspace2Domains.__init__": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspace.__init__": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBayesian.__init__": {"tf": 1}, "shimmer.modules.global_workspace.pretrained_global_workspace": {"tf": 1}, "shimmer.modules.domain.DomainModule.__init__": {"tf": 1}, "shimmer.modules.gw_module.get_n_layers": {"tf": 1}, "shimmer.modules.gw_module.GWDecoder.__init__": {"tf": 1.7320508075688772}, 
"shimmer.modules.gw_module.GWEncoder.__init__": {"tf": 1.7320508075688772}, "shimmer.modules.gw_module.GWModuleBase.__init__": {"tf": 1}, "shimmer.modules.gw_module.GWModule.__init__": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBayesian.__init__": {"tf": 1}, "shimmer.modules.selection.KQFixedQSelection.__init__": {"tf": 1}, "shimmer.modules.selection.DynamicQueryAttention.__init__": {"tf": 1}}, "df": 13, "e": {"docs": {}, "df": 0, "n": {"docs": {}, "df": 0, "s": {"docs": {}, "df": 0, "i": {"docs": {}, "df": 0, "o": {"docs": {}, "df": 0, "n": {"docs": {"shimmer.modules.global_workspace.GlobalWorkspaceBase.workspace_dim": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspace2Domains.__init__": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspace.__init__": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBayesian.__init__": {"tf": 1}, "shimmer.modules.global_workspace.pretrained_global_workspace": {"tf": 1}, "shimmer.modules.domain.DomainModule.__init__": {"tf": 1}, "shimmer.modules.domain.DomainModule.latent_dim": {"tf": 1}, "shimmer.modules.gw_module.get_n_layers": {"tf": 1}, "shimmer.modules.gw_module.GWDecoder.__init__": {"tf": 1.7320508075688772}, "shimmer.modules.gw_module.GWDecoder.in_dim": {"tf": 1}, "shimmer.modules.gw_module.GWDecoder.hidden_dim": {"tf": 1}, "shimmer.modules.gw_module.GWDecoder.out_dim": {"tf": 1}, "shimmer.modules.gw_module.GWEncoder.__init__": {"tf": 1.7320508075688772}, "shimmer.modules.gw_module.GWModuleBase.__init__": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBase.workspace_dim": {"tf": 1}, "shimmer.modules.gw_module.GWModule.__init__": {"tf": 1}, "shimmer.modules.gw_module.compute_fusion_scores": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBayesian.__init__": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBayesian.fuse": {"tf": 1}, "shimmer.modules.selection.KQFixedQSelection.__init__": {"tf": 1.4142135623730951}, "shimmer.modules.selection.DynamicQueryAttention.__init__": {"tf": 1.4142135623730951}}, 
"df": 21}}}}}}, "s": {"docs": {"shimmer.modules.selection.KQFixedQSelection.__init__": {"tf": 1}, "shimmer.modules.selection.DynamicQueryAttention.__init__": {"tf": 1}}, "df": 2}}, "r": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "c": {"docs": {}, "df": 0, "t": {"docs": {}, "df": 0, "l": {"docs": {}, "df": 0, "y": {"docs": {"shimmer.modules.selection.RandomSelection.forward": {"tf": 1}}, "df": 1}}}}}}, "v": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "r": {"docs": {}, "df": 0, "g": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "n": {"docs": {}, "df": 0, "c": {"docs": {}, "df": 0, "e": {"docs": {"shimmer.modules.vae.kl_divergence_loss": {"tf": 1}}, "df": 1}}}}}}}}}, "u": {"docs": {}, "df": 0, "r": {"docs": {}, "df": 0, "i": {"docs": {}, "df": 0, "n": {"docs": {}, "df": 0, "g": {"docs": {"shimmer.modules.domain.LossOutput.loss": {"tf": 1}, "shimmer.modules.domain.LossOutput.metrics": {"tf": 1}}, "df": 2}}}}}, "y": {"docs": {}, "df": 0, "n": {"docs": {}, "df": 0, "a": {"docs": {}, "df": 0, "m": {"docs": {}, "df": 0, "i": {"docs": {}, "df": 0, "c": {"docs": {"shimmer.modules.selection.KQFixedQSelection.forward": {"tf": 1}, "shimmer.modules.selection.DynamicQueryAttention": {"tf": 1}, "shimmer.modules.selection.DynamicQueryAttention.forward": {"tf": 1}}, "df": 3}}}}}}, "c": {"docs": {}, "df": 0, "y": {"docs": {"shimmer.modules.losses.demi_cycle_loss": {"tf": 1}}, "df": 1}}, "r": {"docs": {}, "df": 0, "o": {"docs": {}, "df": 0, "p": {"docs": {"shimmer.dataset.RepeatedDataset": {"tf": 1.4142135623730951}, "shimmer.dataset.RepeatedDataset.__init__": {"tf": 1}}, "df": 2}}}}, "f": {"docs": {}, "df": 0, "r": {"docs": {}, "df": 0, "o": {"docs": {}, "df": 0, "m": {"docs": {"shimmer.types.RawDomainGroupT": {"tf": 1}, "shimmer.types.RawDomainGroupDT": {"tf": 1}, "shimmer.types.LatentsDomainGroupT": {"tf": 1}, "shimmer.types.LatentsDomainGroupDT": {"tf": 1}, "shimmer.modules.global_workspace.GWPredictionsBase.states": {"tf": 1}, 
"shimmer.modules.global_workspace.GlobalWorkspaceBase.encode_domain": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.decode_domain": {"tf": 1}, "shimmer.modules.global_workspace.GWPredictions.cycles": {"tf": 1}, "shimmer.modules.global_workspace.GWPredictions.translations": {"tf": 1}, "shimmer.modules.global_workspace.pretrained_global_workspace": {"tf": 1.4142135623730951}, "shimmer.modules.domain.DomainModule.decode": {"tf": 1}, "shimmer.modules.selection.SingleDomainSelection": {"tf": 1}, "shimmer.modules.losses.translation_loss": {"tf": 1}, "shimmer.modules.vae.reparameterize": {"tf": 1}, "shimmer.modules.vae.VAE.forward": {"tf": 1}, "shimmer.modules.utils.translation": {"tf": 1}, "shimmer.modules.utils.cycle": {"tf": 1}}, "df": 17}, "z": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "n": {"docs": {"shimmer.modules.global_workspace.freeze_domain_modules": {"tf": 1}}, "df": 1, "s": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "t": {"docs": {"shimmer.types.RawDomainGroupsT": {"tf": 2}, "shimmer.types.RawDomainGroupsDT": {"tf": 2}, "shimmer.types.LatentsDomainGroupsT": {"tf": 2}, "shimmer.types.LatentsDomainGroupsDT": {"tf": 2}}, "df": 4}}}}}}}, "e": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "z": {"docs": {}, "df": 0, "e": {"docs": {"shimmer.modules.global_workspace.freeze_domain_modules": {"tf": 1}}, "df": 1, "s": {"docs": {"shimmer.modules.global_workspace.freeze_domain_modules": {"tf": 1}}, "df": 1}}}}}, "a": {"docs": {}, "df": 0, "c": {"docs": {}, "df": 0, "{": {"1": {"docs": {}, "df": 0, "}": {"docs": {}, "df": 0, "{": {"docs": {}, "df": 0, "\\": {"docs": {}, "df": 0, "s": {"docs": {}, "df": 0, "u": {"docs": {}, "df": 0, "m": {"docs": {"shimmer.modules.gw_module.GWModuleBayesian.fuse": {"tf": 1.7320508075688772}}, "df": 1}}, "i": {"docs": {}, "df": 0, "g": {"docs": {}, "df": 0, "m": {"docs": {}, "df": 0, "a": {"docs": {"shimmer.modules.gw_module.GWModuleBayesian.fuse": {"tf": 1}}, "df": 1}}}}}, "m": {"docs": {}, "df": 0, "u": 
{"docs": {"shimmer.modules.gw_module.GWModuleBayesian.fuse": {"tf": 1}}, "df": 1}}, "f": {"docs": {}, "df": 0, "r": {"docs": {}, "df": 0, "a": {"docs": {}, "df": 0, "c": {"docs": {}, "df": 0, "{": {"docs": {}, "df": 0, "c": {"docs": {"shimmer.modules.gw_module.GWModuleBayesian.fuse": {"tf": 1}}, "df": 1}}}}}}}}}}, "docs": {}, "df": 0, "m": {"docs": {"shimmer.modules.gw_module.GWModuleBayesian.fuse": {"tf": 1.7320508075688772}}, "df": 1}, "c": {"docs": {"shimmer.modules.gw_module.GWModuleBayesian.fuse": {"tf": 2}}, "df": 1}}}}}, "u": {"docs": {}, "df": 0, "n": {"docs": {"shimmer.types.RawDomainGroupT": {"tf": 1.4142135623730951}, "shimmer.types.RawDomainGroupDT": {"tf": 1}, "shimmer.types.LatentsDomainGroupT": {"tf": 1.4142135623730951}, "shimmer.types.LatentsDomainGroupDT": {"tf": 1}, "shimmer.types.RawDomainGroupsT": {"tf": 1}, "shimmer.types.RawDomainGroupsDT": {"tf": 1}, "shimmer.types.LatentsDomainGroupsT": {"tf": 1.4142135623730951}, "shimmer.types.LatentsDomainGroupsDT": {"tf": 1}}, "df": 8, "c": {"docs": {}, "df": 0, "t": {"docs": {}, "df": 0, "i": {"docs": {}, "df": 0, "o": {"docs": {}, "df": 0, "n": {"docs": {"shimmer.modules.global_workspace.GlobalWorkspace2Domains.__init__": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspace.__init__": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBayesian.__init__": {"tf": 1}, "shimmer.modules.global_workspace.pretrained_global_workspace": {"tf": 1}, "shimmer.modules.gw_module.GWEncoder.forward": {"tf": 1}, "shimmer.modules.gw_module.GWEncoderLinear.forward": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBase.fuse": {"tf": 1}, "shimmer.modules.gw_module.GWModule.fuse": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBayesian.fuse": {"tf": 1.4142135623730951}, "shimmer.modules.selection.RandomSelection.forward": {"tf": 1}, "shimmer.modules.losses.contrastive_loss": {"tf": 1}, "shimmer.modules.losses.contrastive_loss_bayesian": {"tf": 1}, "shimmer.modules.losses.GWLosses2Domains.__init__": {"tf": 1}, 
"shimmer.modules.losses.GWLosses.__init__": {"tf": 1}, "shimmer.modules.losses.GWLossesBayesian.__init__": {"tf": 1}, "shimmer.modules.contrastive_loss.ContrastiveLossType": {"tf": 1.4142135623730951}, "shimmer.modules.contrastive_loss.ContrastiveLossBayesianType": {"tf": 1.4142135623730951}}, "df": 17, "s": {"docs": {"shimmer.types.RawDomainGroupT": {"tf": 1.4142135623730951}, "shimmer.types.RawDomainGroupDT": {"tf": 1.4142135623730951}, "shimmer.types.LatentsDomainGroupT": {"tf": 1.4142135623730951}, "shimmer.types.LatentsDomainGroupDT": {"tf": 1.4142135623730951}, "shimmer.types.RawDomainGroupsT": {"tf": 1.4142135623730951}, "shimmer.types.RawDomainGroupsDT": {"tf": 1.4142135623730951}, "shimmer.types.LatentsDomainGroupsT": {"tf": 1.4142135623730951}, "shimmer.types.LatentsDomainGroupsDT": {"tf": 1.4142135623730951}}, "df": 8}}}}}}}, "s": {"docs": {}, "df": 0, "i": {"docs": {}, "df": 0, "o": {"docs": {}, "df": 0, "n": {"docs": {"shimmer.modules.global_workspace.GlobalWorkspaceBase.encode": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.fuse": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspace2Domains.__init__": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspace": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspace.__init__": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBayesian.__init__": {"tf": 1}, "shimmer.modules.global_workspace.pretrained_global_workspace": {"tf": 1}, "shimmer.modules.domain.DomainModule.compute_broadcast_loss": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBase.encode": {"tf": 1.4142135623730951}, "shimmer.modules.gw_module.GWModule.__init__": {"tf": 1}, "shimmer.modules.gw_module.GWModule.encode": {"tf": 1.4142135623730951}, "shimmer.modules.gw_module.GWModuleBayesian.__init__": {"tf": 1}, "shimmer.modules.selection.SelectionBase": {"tf": 1}, "shimmer.modules.selection.SelectionBase.forward": {"tf": 1}, "shimmer.modules.selection.KQFixedQSelection.forward": {"tf": 1}, 
"shimmer.modules.selection.RandomSelection.forward": {"tf": 1}, "shimmer.modules.selection.DynamicQueryAttention.forward": {"tf": 1}, "shimmer.modules.losses.GWLosses": {"tf": 1}, "shimmer.modules.losses.GWLosses.__init__": {"tf": 1}}, "df": 19}}}, "e": {"docs": {"shimmer.modules.gw_module.GWModuleBase.encode_and_fuse": {"tf": 1}, "shimmer.modules.selection.DynamicQueryAttention.fuse_weighted_encodings": {"tf": 1}}, "df": 2, "s": {"docs": {"shimmer.modules.global_workspace.GlobalWorkspaceBase.fuse": {"tf": 1}}, "df": 1}, "d": {"docs": {"shimmer.modules.selection.DynamicQueryAttention.fuse_weighted_encodings": {"tf": 1}, "shimmer.modules.losses.BroadcastLossCoefs.fused": {"tf": 1}}, "df": 2}}}, "l": {"docs": {}, "df": 0, "l": {"docs": {"shimmer.modules.utils.cycle": {"tf": 1}}, "df": 1}}}, "o": {"docs": {}, "df": 0, "r": {"docs": {"shimmer.types.RawDomainGroupT": {"tf": 1.4142135623730951}, "shimmer.types.RawDomainGroupDT": {"tf": 1.4142135623730951}, "shimmer.types.LatentsDomainGroupT": {"tf": 1.4142135623730951}, "shimmer.types.LatentsDomainGroupDT": {"tf": 1.4142135623730951}, "shimmer.types.RawDomainGroupsT": {"tf": 1.4142135623730951}, "shimmer.types.RawDomainGroupsDT": {"tf": 1.4142135623730951}, "shimmer.types.LatentsDomainGroupsT": {"tf": 1.4142135623730951}, "shimmer.types.LatentsDomainGroupsDT": {"tf": 1.4142135623730951}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.fuse": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.batch_gw_states": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.encode_domain": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.decode_domain": {"tf": 1}, "shimmer.modules.global_workspace.freeze_domain_modules": {"tf": 1}, "shimmer.modules.global_workspace.GWPredictions.demi_cycles": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspace2Domains.__init__": {"tf": 1.7320508075688772}, "shimmer.modules.global_workspace.GlobalWorkspace2Domains.forward": {"tf": 1}, 
"shimmer.modules.global_workspace.GlobalWorkspace.__init__": {"tf": 2.23606797749979}, "shimmer.modules.global_workspace.GlobalWorkspace.forward": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBayesian.__init__": {"tf": 1.7320508075688772}, "shimmer.modules.global_workspace.GlobalWorkspaceBayesian.forward": {"tf": 1}, "shimmer.modules.global_workspace.pretrained_global_workspace": {"tf": 1.7320508075688772}, "shimmer.modules.domain.LossOutput": {"tf": 1.7320508075688772}, "shimmer.modules.domain.DomainModule": {"tf": 1}, "shimmer.modules.domain.DomainModule.compute_dcy_loss": {"tf": 1}, "shimmer.modules.domain.DomainModule.compute_cy_loss": {"tf": 1}, "shimmer.modules.domain.DomainModule.compute_tr_loss": {"tf": 1}, "shimmer.modules.domain.DomainModule.compute_broadcast_loss": {"tf": 1}, "shimmer.modules.gw_module.GWDecoder": {"tf": 1}, "shimmer.modules.gw_module.GWEncoder.forward": {"tf": 1}, "shimmer.modules.gw_module.GWEncoderLinear.forward": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBase": {"tf": 1.4142135623730951}, "shimmer.modules.gw_module.GWModule.__init__": {"tf": 1.4142135623730951}, "shimmer.modules.gw_module.compute_fusion_scores": {"tf": 1.4142135623730951}, "shimmer.modules.gw_module.GWModuleBayesian.__init__": {"tf": 1.4142135623730951}, "shimmer.modules.gw_module.GWModuleBayesian.precisions": {"tf": 1}, "shimmer.modules.selection.SelectionBase": {"tf": 1.4142135623730951}, "shimmer.modules.selection.SelectionBase.forward": {"tf": 1.4142135623730951}, "shimmer.modules.selection.SingleDomainSelection": {"tf": 1}, "shimmer.modules.selection.SingleDomainSelection.forward": {"tf": 1}, "shimmer.modules.selection.FixedSharedSelection": {"tf": 1.7320508075688772}, "shimmer.modules.selection.FixedSharedSelection.forward": {"tf": 1}, "shimmer.modules.selection.KQFixedQSelection.__init__": {"tf": 1}, "shimmer.modules.selection.KQFixedQSelection.forward": {"tf": 1}, "shimmer.modules.selection.RandomSelection.forward": {"tf": 
1.7320508075688772}, "shimmer.modules.selection.DynamicQueryAttention.__init__": {"tf": 1}, "shimmer.modules.selection.DynamicQueryAttention.fuse_weighted_encodings": {"tf": 1}, "shimmer.modules.selection.DynamicQueryAttention.forward": {"tf": 1}, "shimmer.modules.losses.GWLossesBase": {"tf": 1}, "shimmer.modules.losses.GWLosses2Domains": {"tf": 1}, "shimmer.modules.losses.generate_partitions": {"tf": 1.4142135623730951}, "shimmer.modules.losses.GWLosses": {"tf": 1}, "shimmer.modules.losses.GWLosses.__init__": {"tf": 2.23606797749979}, "shimmer.modules.losses.GWLosses.contrastive_loss": {"tf": 1}, "shimmer.modules.losses.GWLosses.step": {"tf": 1.4142135623730951}, "shimmer.modules.losses.GWLossesBayesian": {"tf": 1}, "shimmer.modules.losses.GWLossesBayesian.step": {"tf": 1.4142135623730951}, "shimmer.modules.contrastive_loss.ContrastiveLossBayesianType": {"tf": 1}, "shimmer.modules.vae.reparameterize": {"tf": 1}, "shimmer.modules.vae.VAEEncoder": {"tf": 1}, "shimmer.modules.vae.VAEDecoder": {"tf": 1}, "shimmer.modules.vae.VAE.__init__": {"tf": 1}, "shimmer.modules.vae.VAE.beta": {"tf": 1}, "shimmer.modules.utils.batch_demi_cycles": {"tf": 1}, "shimmer.modules.utils.batch_cycles": {"tf": 1.4142135623730951}, "shimmer.modules.utils.batch_translations": {"tf": 1}, "shimmer.cli.ckpt_migration.migrate_ckpt": {"tf": 1}}, "df": 66, "w": {"docs": {}, "df": 0, "a": {"docs": {}, "df": 0, "r": {"docs": {}, "df": 0, "d": {"docs": {"shimmer.modules.gw_module.GWEncoder.forward": {"tf": 1}, "shimmer.modules.gw_module.GWEncoderLinear.forward": {"tf": 1}, "shimmer.modules.selection.SelectionBase.forward": {"tf": 1.4142135623730951}, "shimmer.modules.selection.SingleDomainSelection.forward": {"tf": 1}, "shimmer.modules.selection.FixedSharedSelection.forward": {"tf": 1}}, "df": 5}}}}, "m": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "r": {"docs": {"shimmer.modules.gw_module.GWEncoder.forward": {"tf": 1}, "shimmer.modules.gw_module.GWEncoderLinear.forward": {"tf": 1}}, "df": 
2}}}}, "l": {"docs": {}, "df": 0, "l": {"docs": {}, "df": 0, "o": {"docs": {}, "df": 0, "w": {"docs": {}, "df": 0, "i": {"docs": {}, "df": 0, "n": {"docs": {}, "df": 0, "g": {"docs": {"shimmer.modules.gw_module.GWModuleBayesian.fuse": {"tf": 1}}, "df": 1}}}, "e": {"docs": {}, "df": 0, "d": {"docs": {"shimmer.modules.selection.RandomSelection": {"tf": 1}}, "df": 1}}}}}}, "c": {"docs": {}, "df": 0, "u": {"docs": {}, "df": 0, "s": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "s": {"docs": {"shimmer.modules.selection.RandomSelection": {"tf": 1}}, "df": 1}}}}}}, "l": {"docs": {}, "df": 0, "a": {"docs": {}, "df": 0, "v": {"docs": {}, "df": 0, "o": {"docs": {}, "df": 0, "r": {"docs": {"shimmer.modules.global_workspace.GlobalWorkspace2Domains": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspace": {"tf": 1}, "shimmer.modules.global_workspace.pretrained_global_workspace": {"tf": 1}}, "df": 3}}}}, "o": {"docs": {}, "df": 0, "a": {"docs": {}, "df": 0, "t": {"docs": {"shimmer.modules.global_workspace.GlobalWorkspace2Domains.__init__": {"tf": 1.4142135623730951}, "shimmer.modules.global_workspace.GlobalWorkspace.__init__": {"tf": 1.7320508075688772}, "shimmer.modules.global_workspace.GlobalWorkspaceBayesian.__init__": {"tf": 2.23606797749979}, "shimmer.modules.gw_module.compute_fusion_scores": {"tf": 1.7320508075688772}, "shimmer.modules.gw_module.GWModuleBayesian.__init__": {"tf": 1.7320508075688772}, "shimmer.modules.selection.RandomSelection.__init__": {"tf": 1}, "shimmer.modules.losses.GWLosses2Domains.__init__": {"tf": 1}, "shimmer.modules.vae.VAE.__init__": {"tf": 1}}, "df": 8}}}}, "i": {"docs": {}, "df": 0, "n": {"docs": {}, "df": 0, "a": {"docs": {}, "df": 0, "l": {"docs": {"shimmer.modules.gw_module.GWModuleBase.encode_and_fuse": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBayesian.fuse": {"tf": 1}, "shimmer.dataset.RepeatedDataset.__init__": {"tf": 1}}, "df": 3, "l": {"docs": {}, "df": 0, "y": {"docs": {"shimmer.modules.gw_module.GWModuleBayesian.fuse": 
{"tf": 1}}, "df": 1}}, "e": {"docs": {"shimmer.modules.gw_module.GWModuleBayesian.fuse": {"tf": 1}}, "df": 1}}}}, "r": {"docs": {}, "df": 0, "s": {"docs": {}, "df": 0, "t": {"docs": {"shimmer.modules.gw_module.compute_fusion_scores": {"tf": 1.4142135623730951}, "shimmer.modules.vae.VAE.forward": {"tf": 1}}, "df": 2}}}, "x": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "d": {"docs": {"shimmer.modules.selection.KQFixedQSelection": {"tf": 1}}, "df": 1}}}}, "a": {"docs": {}, "df": 0, "c": {"docs": {}, "df": 0, "t": {"docs": {}, "df": 0, "o": {"docs": {}, "df": 0, "r": {"docs": {}, "df": 0, "s": {"docs": {"shimmer.modules.selection.RandomSelection": {"tf": 1}, "shimmer.modules.selection.RandomSelection.__init__": {"tf": 1}}, "df": 2}}}}}, "l": {"docs": {}, "df": 0, "s": {"docs": {}, "df": 0, "e": {"docs": {"shimmer.modules.contrastive_loss.ContrastiveLoss.__init__": {"tf": 1}, "shimmer.dataset.RepeatedDataset": {"tf": 1}}, "df": 2}}}}, "n": {"docs": {"shimmer.modules.losses.contrastive_loss": {"tf": 1}, "shimmer.modules.losses.contrastive_loss_bayesian": {"tf": 1}, "shimmer.modules.losses.GWLosses2Domains.__init__": {"tf": 1}, "shimmer.modules.losses.GWLosses.__init__": {"tf": 1}, "shimmer.modules.losses.GWLossesBayesian.__init__": {"tf": 1}}, "df": 5}}, "k": {"docs": {"shimmer.modules.gw_module.GWModuleBayesian.fuse": {"tf": 3}}, "df": 1, "e": {"docs": {}, "df": 0, "y": {"docs": {"shimmer.modules.global_workspace.GWPredictionsBase.states": {"tf": 1}, "shimmer.modules.domain.LossOutput.all": {"tf": 1}, "shimmer.modules.selection.KQFixedQSelection": {"tf": 1}, "shimmer.modules.selection.KQFixedQSelection.__init__": {"tf": 1}, "shimmer.modules.selection.DynamicQueryAttention": {"tf": 1}, "shimmer.modules.selection.DynamicQueryAttention.__init__": {"tf": 1}, "shimmer.modules.losses.GWLosses2Domains.__init__": {"tf": 1}, "shimmer.modules.losses.GWLossesBayesian.__init__": {"tf": 1}}, "df": 8, "s": {"docs": {"shimmer.types.RawDomainGroupT": {"tf": 1}, 
"shimmer.types.RawDomainGroupDT": {"tf": 1}, "shimmer.types.LatentsDomainGroupT": {"tf": 1}, "shimmer.types.LatentsDomainGroupDT": {"tf": 1}, "shimmer.types.RawDomainGroupsT": {"tf": 1}, "shimmer.types.RawDomainGroupsDT": {"tf": 1}, "shimmer.types.LatentsDomainGroupsT": {"tf": 1}, "shimmer.types.LatentsDomainGroupsDT": {"tf": 1}, "shimmer.modules.global_workspace.GWPredictions.cycles": {"tf": 1}, "shimmer.modules.global_workspace.GWPredictions.translations": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspace2Domains.__init__": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspace.__init__": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBayesian.__init__": {"tf": 1}, "shimmer.modules.global_workspace.pretrained_global_workspace": {"tf": 1}, "shimmer.modules.gw_module.GWModule.decode": {"tf": 1}, "shimmer.modules.selection.KQFixedQSelection.forward": {"tf": 1}, "shimmer.modules.selection.DynamicQueryAttention.forward": {"tf": 1}, "shimmer.modules.losses.GWLosses2Domains.__init__": {"tf": 1}}, "df": 18}}, "e": {"docs": {}, "df": 0, "p": {"docs": {}, "df": 0, "s": {"docs": {"shimmer.modules.domain.LossOutput": {"tf": 1}}, "df": 1}}}}, "w": {"docs": {}, "df": 0, "a": {"docs": {}, "df": 0, "r": {"docs": {}, "df": 0, "g": {"docs": {}, "df": 0, "s": {"docs": {"shimmer.modules.global_workspace.pretrained_global_workspace": {"tf": 1}, "shimmer.utils.migrate_model": {"tf": 1}}, "df": 2}}}}}, "i": {"docs": {}, "df": 0, "n": {"docs": {}, "df": 0, "d": {"docs": {"shimmer.modules.gw_module.GWModuleBayesian.fuse": {"tf": 1}}, "df": 1}}}, "^": {"2": {"docs": {"shimmer.modules.gw_module.GWModuleBayesian.fuse": {"tf": 2.449489742783178}}, "df": 1}, "docs": {}, "df": 0}, "l": {"docs": {"shimmer.modules.vae.kl_divergence_loss": {"tf": 1}}, "df": 1}}, "o": {"docs": {}, "df": 0, "f": {"docs": {"shimmer.types.RawDomainGroupT": {"tf": 1.4142135623730951}, "shimmer.types.RawDomainGroupDT": {"tf": 1.7320508075688772}, "shimmer.types.LatentsDomainGroupT": {"tf": 
1.4142135623730951}, "shimmer.types.LatentsDomainGroupDT": {"tf": 1.4142135623730951}, "shimmer.types.RawDomainGroupsT": {"tf": 2}, "shimmer.types.RawDomainGroupsDT": {"tf": 2}, "shimmer.types.LatentsDomainGroupsT": {"tf": 1.4142135623730951}, "shimmer.types.LatentsDomainGroupsDT": {"tf": 1.4142135623730951}, "shimmer.types.ModelModeT": {"tf": 1}, "shimmer.modules.global_workspace.SchedulerArgs": {"tf": 1}, "shimmer.modules.global_workspace.SchedulerArgs.total_steps": {"tf": 1}, "shimmer.modules.global_workspace.GWPredictionsBase": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.loss_mod": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.workspace_dim": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.encode_and_fuse": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.encode": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.fuse": {"tf": 1.4142135623730951}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.decode": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.batch_gw_states": {"tf": 2.23606797749979}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.encode_domains": {"tf": 1.4142135623730951}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.decode_domains": {"tf": 1.4142135623730951}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.generic_step": {"tf": 1.4142135623730951}, "shimmer.modules.global_workspace.freeze_domain_modules": {"tf": 1}, "shimmer.modules.global_workspace.GWPredictions": {"tf": 1}, "shimmer.modules.global_workspace.GWPredictions.demi_cycles": {"tf": 1}, "shimmer.modules.global_workspace.GWPredictions.cycles": {"tf": 1}, "shimmer.modules.global_workspace.GWPredictions.translations": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspace2Domains": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspace2Domains.__init__": {"tf": 1.4142135623730951}, "shimmer.modules.global_workspace.GlobalWorkspace2Domains.forward": 
{"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspace": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspace.__init__": {"tf": 1.4142135623730951}, "shimmer.modules.global_workspace.GlobalWorkspace.forward": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBayesian.__init__": {"tf": 1.7320508075688772}, "shimmer.modules.global_workspace.GlobalWorkspaceBayesian.forward": {"tf": 1}, "shimmer.modules.global_workspace.pretrained_global_workspace": {"tf": 1.7320508075688772}, "shimmer.modules.domain.LossOutput": {"tf": 1}, "shimmer.modules.domain.DomainModule": {"tf": 1}, "shimmer.modules.domain.DomainModule.__init__": {"tf": 1}, "shimmer.modules.domain.DomainModule.latent_dim": {"tf": 1}, "shimmer.modules.domain.DomainModule.encode": {"tf": 1}, "shimmer.modules.domain.DomainModule.decode": {"tf": 1}, "shimmer.modules.domain.DomainModule.compute_loss": {"tf": 1}, "shimmer.modules.domain.DomainModule.compute_dcy_loss": {"tf": 1}, "shimmer.modules.domain.DomainModule.compute_cy_loss": {"tf": 1}, "shimmer.modules.domain.DomainModule.compute_tr_loss": {"tf": 1}, "shimmer.modules.domain.DomainModule.compute_broadcast_loss": {"tf": 1}, "shimmer.modules.gw_module.get_n_layers": {"tf": 2}, "shimmer.modules.gw_module.GWDecoder.__init__": {"tf": 1.4142135623730951}, "shimmer.modules.gw_module.GWDecoder.n_layers": {"tf": 1.4142135623730951}, "shimmer.modules.gw_module.GWEncoder.__init__": {"tf": 1.4142135623730951}, "shimmer.modules.gw_module.GWEncoder.forward": {"tf": 1.4142135623730951}, "shimmer.modules.gw_module.GWEncoderLinear.forward": {"tf": 1.4142135623730951}, "shimmer.modules.gw_module.GWModuleBase.__init__": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBase.workspace_dim": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBase.fuse": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBase.decode": {"tf": 1}, "shimmer.modules.gw_module.GWModule.__init__": {"tf": 1}, "shimmer.modules.gw_module.GWModule.fuse": {"tf": 1}, 
"shimmer.modules.gw_module.GWModuleBayesian.__init__": {"tf": 1.4142135623730951}, "shimmer.modules.gw_module.GWModuleBayesian.get_precision": {"tf": 1.7320508075688772}, "shimmer.modules.gw_module.GWModuleBayesian.fuse": {"tf": 2.23606797749979}, "shimmer.modules.selection.SelectionBase.update_gw_state": {"tf": 1}, "shimmer.modules.selection.SelectionBase.forward": {"tf": 1.4142135623730951}, "shimmer.modules.selection.SingleDomainSelection": {"tf": 1.4142135623730951}, "shimmer.modules.selection.SingleDomainSelection.forward": {"tf": 1}, "shimmer.modules.selection.FixedSharedSelection.forward": {"tf": 1}, "shimmer.modules.selection.KQFixedQSelection.__init__": {"tf": 1.7320508075688772}, "shimmer.modules.selection.KQFixedQSelection.forward": {"tf": 1.4142135623730951}, "shimmer.modules.selection.RandomSelection.__init__": {"tf": 1}, "shimmer.modules.selection.RandomSelection.forward": {"tf": 1.4142135623730951}, "shimmer.modules.selection.DynamicQueryAttention.__init__": {"tf": 1.7320508075688772}, "shimmer.modules.selection.DynamicQueryAttention.forward": {"tf": 1.4142135623730951}, "shimmer.modules.losses.GWLossesBase": {"tf": 1}, "shimmer.modules.losses.demi_cycle_loss": {"tf": 1.7320508075688772}, "shimmer.modules.losses.cycle_loss": {"tf": 1.7320508075688772}, "shimmer.modules.losses.translation_loss": {"tf": 1.4142135623730951}, "shimmer.modules.losses.contrastive_loss": {"tf": 1.4142135623730951}, "shimmer.modules.losses.contrastive_loss_bayesian": {"tf": 1.4142135623730951}, "shimmer.modules.losses.LossCoefs": {"tf": 1}, "shimmer.modules.losses.GWLosses2Domains": {"tf": 1}, "shimmer.modules.losses.GWLosses2Domains.demi_cycle_loss": {"tf": 1}, "shimmer.modules.losses.GWLosses2Domains.cycle_loss": {"tf": 1}, "shimmer.modules.losses.GWLosses2Domains.translation_loss": {"tf": 1}, "shimmer.modules.losses.GWLosses2Domains.contrastive_loss": {"tf": 1}, "shimmer.modules.losses.generate_partitions": {"tf": 1.7320508075688772}, 
"shimmer.modules.losses.BroadcastLossCoefs": {"tf": 1}, "shimmer.modules.losses.BroadcastLossCoefs.fused": {"tf": 1}, "shimmer.modules.losses.GWLosses": {"tf": 1}, "shimmer.modules.losses.GWLosses.__init__": {"tf": 1}, "shimmer.modules.losses.GWLosses.contrastive_loss": {"tf": 1}, "shimmer.modules.losses.GWLosses.step": {"tf": 1}, "shimmer.modules.losses.GWLossesBayesian": {"tf": 1}, "shimmer.modules.losses.GWLossesBayesian.contrastive_loss": {"tf": 1}, "shimmer.modules.losses.GWLossesBayesian.step": {"tf": 1}, "shimmer.modules.contrastive_loss.ContrastiveLoss.forward": {"tf": 1}, "shimmer.dataset.RepeatedDataset": {"tf": 1}, "shimmer.dataset.RepeatedDataset.__init__": {"tf": 1}, "shimmer.modules.vae.VAE.encode": {"tf": 1}, "shimmer.modules.vae.VAE.forward": {"tf": 1}, "shimmer.modules.utils.translation": {"tf": 1.4142135623730951}, "shimmer.modules.utils.cycle": {"tf": 1.7320508075688772}, "shimmer.modules.utils.batch_demi_cycles": {"tf": 2.23606797749979}, "shimmer.modules.utils.batch_cycles": {"tf": 2.6457513110645907}, "shimmer.modules.utils.batch_translations": {"tf": 2.449489742783178}, "shimmer.utils.groups_batch_size": {"tf": 1.4142135623730951}, "shimmer.utils.groups_device": {"tf": 1.4142135623730951}, "shimmer.utils.migrate_model": {"tf": 1}, "shimmer.utils.SaveMigrations": {"tf": 1}, "shimmer.cli.ckpt_migration.migrate_ckpt": {"tf": 1.4142135623730951}}, "df": 110}, "u": {"docs": {}, "df": 0, "t": {"docs": {"shimmer.types.ModelModeT": {"tf": 1}, "shimmer.modules.gw_module.GWDecoder.__init__": {"tf": 1}, "shimmer.modules.gw_module.GWEncoder.__init__": {"tf": 1}, "shimmer.modules.utils.batch_cycles": {"tf": 1}}, "df": 4, "p": {"docs": {}, "df": 0, "u": {"docs": {}, "df": 0, "t": {"docs": {"shimmer.types.RawDomainGroupDT": {"tf": 1}, "shimmer.modules.global_workspace.GWPredictionsBase": {"tf": 1}, "shimmer.modules.global_workspace.freeze_domain_modules": {"tf": 1}, "shimmer.modules.global_workspace.GWPredictions": {"tf": 1}, 
"shimmer.modules.gw_module.GWDecoder.__init__": {"tf": 1}, "shimmer.modules.gw_module.GWDecoder.out_dim": {"tf": 1}, "shimmer.modules.gw_module.GWEncoder.__init__": {"tf": 1}, "shimmer.modules.losses.demi_cycle_loss": {"tf": 1}, "shimmer.modules.losses.cycle_loss": {"tf": 1}, "shimmer.modules.losses.translation_loss": {"tf": 1}, "shimmer.modules.losses.contrastive_loss": {"tf": 1}, "shimmer.modules.losses.contrastive_loss_bayesian": {"tf": 1}}, "df": 12, "s": {"docs": {"shimmer.types.RawDomainGroupT": {"tf": 1}, "shimmer.types.RawDomainGroupDT": {"tf": 1}, "shimmer.types.LatentsDomainGroupT": {"tf": 1}, "shimmer.types.LatentsDomainGroupDT": {"tf": 1}, "shimmer.types.RawDomainGroupsT": {"tf": 1}, "shimmer.types.RawDomainGroupsDT": {"tf": 1}, "shimmer.types.LatentsDomainGroupsT": {"tf": 1}, "shimmer.types.LatentsDomainGroupsDT": {"tf": 1}}, "df": 8}}}}}}, "t": {"docs": {}, "df": 0, "h": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "r": {"docs": {"shimmer.types.RawDomainGroupT": {"tf": 1}, "shimmer.types.LatentsDomainGroupT": {"tf": 1}, "shimmer.types.LatentsDomainGroupsT": {"tf": 1}, "shimmer.utils.migrate_model": {"tf": 1}}, "df": 4}}}}, "r": {"docs": {"shimmer.types.ModelModeT": {"tf": 1.4142135623730951}, "shimmer.modules.losses.GWLosses2Domains.__init__": {"tf": 1}}, "df": 2, "i": {"docs": {}, "df": 0, "g": {"docs": {}, "df": 0, "i": {"docs": {}, "df": 0, "n": {"docs": {}, "df": 0, "a": {"docs": {}, "df": 0, "l": {"docs": {"shimmer.modules.domain.DomainModule.decode": {"tf": 1}, "shimmer.modules.utils.cycle": {"tf": 1.4142135623730951}}, "df": 2}}}}}}}, "n": {"docs": {"shimmer.modules.global_workspace.GlobalWorkspaceBase.generic_step": {"tf": 1}, "shimmer.modules.global_workspace.GWPredictions.demi_cycles": {"tf": 1}, "shimmer.modules.global_workspace.GWPredictions.cycles": {"tf": 1}, "shimmer.modules.global_workspace.GWPredictions.translations": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspace2Domains.forward": {"tf": 1}, 
"shimmer.modules.global_workspace.GlobalWorkspace.forward": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBayesian.forward": {"tf": 1}, "shimmer.modules.selection.RandomSelection": {"tf": 1}, "shimmer.modules.selection.RandomSelection.forward": {"tf": 1}, "shimmer.modules.selection.DynamicQueryAttention": {"tf": 1}}, "df": 10, "e": {"docs": {"shimmer.modules.global_workspace.GWPredictionsBase.states": {"tf": 1}, "shimmer.modules.global_workspace.GWPredictions.demi_cycles": {"tf": 1}, "shimmer.modules.global_workspace.GWPredictions.cycles": {"tf": 1.7320508075688772}, "shimmer.modules.global_workspace.GWPredictions.translations": {"tf": 1.7320508075688772}, "shimmer.modules.gw_module.GWDecoder.__init__": {"tf": 1.4142135623730951}, "shimmer.modules.gw_module.GWDecoder.n_layers": {"tf": 1.4142135623730951}, "shimmer.modules.gw_module.GWEncoder.__init__": {"tf": 1.4142135623730951}, "shimmer.modules.gw_module.GWEncoder.forward": {"tf": 1}, "shimmer.modules.gw_module.GWEncoderLinear.forward": {"tf": 1}, "shimmer.modules.losses.LossCoefs": {"tf": 1}, "shimmer.modules.losses.BroadcastLossCoefs": {"tf": 1}, "shimmer.modules.losses.BroadcastLossCoefs.fused": {"tf": 1}, "shimmer.modules.losses.BroadcastLossCoefs.demi_cycles": {"tf": 1.4142135623730951}, "shimmer.modules.losses.BroadcastLossCoefs.cycles": {"tf": 1}, "shimmer.modules.losses.BroadcastLossCoefs.translations": {"tf": 1}, "shimmer.modules.utils.translation": {"tf": 1}, "shimmer.modules.utils.cycle": {"tf": 1}}, "df": 17, "c": {"docs": {}, "df": 0, "y": {"docs": {}, "df": 0, "c": {"docs": {}, "df": 0, "l": {"docs": {}, "df": 0, "e": {"docs": {"shimmer.modules.global_workspace.SchedulerArgs": {"tf": 1}}, "df": 1}}}}}, "s": {"docs": {"shimmer.modules.losses.generate_partitions": {"tf": 1.4142135623730951}}, "df": 1}}, "l": {"docs": {}, "df": 0, "y": {"docs": {"shimmer.modules.global_workspace.GWPredictionsBase.states": {"tf": 1}, "shimmer.modules.global_workspace.GWPredictions.demi_cycles": {"tf": 
1.4142135623730951}, "shimmer.modules.global_workspace.GWPredictions.cycles": {"tf": 1}, "shimmer.modules.global_workspace.GWPredictions.translations": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspace2Domains": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspace": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBayesian": {"tf": 1}, "shimmer.modules.domain.LossOutput": {"tf": 1}, "shimmer.modules.selection.RandomSelection": {"tf": 1}}, "df": 9}}, "c": {"docs": {}, "df": 0, "e": {"docs": {"shimmer.modules.selection.KQFixedQSelection.forward": {"tf": 1.4142135623730951}, "shimmer.modules.selection.DynamicQueryAttention.forward": {"tf": 1.4142135623730951}}, "df": 2}}}, "v": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "r": {"docs": {"shimmer.modules.utils.translation": {"tf": 1}, "shimmer.modules.utils.cycle": {"tf": 1}}, "df": 2, "r": {"docs": {}, "df": 0, "i": {"docs": {}, "df": 0, "d": {"docs": {}, "df": 0, "e": {"docs": {"shimmer.modules.domain.DomainModule.compute_dcy_loss": {"tf": 1}, "shimmer.modules.domain.DomainModule.compute_cy_loss": {"tf": 1}, "shimmer.modules.domain.DomainModule.compute_tr_loss": {"tf": 1}, "shimmer.modules.domain.DomainModule.compute_broadcast_loss": {"tf": 1}, "shimmer.utils.SaveMigrations": {"tf": 1}}, "df": 5, "s": {"docs": {"shimmer.modules.global_workspace.GlobalWorkspace2Domains": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspace": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBayesian": {"tf": 1}}, "df": 3}}, "d": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "n": {"docs": {"shimmer.modules.gw_module.GWEncoder.forward": {"tf": 1}, "shimmer.modules.gw_module.GWEncoderLinear.forward": {"tf": 1}}, "df": 2}}}}}}, "f": {"docs": {}, "df": 0, "l": {"docs": {}, "df": 0, "o": {"docs": {}, "df": 0, "w": {"docs": {"shimmer.dataset.RepeatedDataset.__init__": {"tf": 1}}, "df": 1}}}}}}}, "p": {"docs": {}, "df": 0, "t": {"docs": {}, "df": 0, "i": {"docs": {}, "df": 0, "m": 
{"docs": {"shimmer.modules.global_workspace.GlobalWorkspace2Domains.__init__": {"tf": 1.4142135623730951}, "shimmer.modules.global_workspace.GlobalWorkspace.__init__": {"tf": 1.4142135623730951}, "shimmer.modules.global_workspace.GlobalWorkspaceBayesian.__init__": {"tf": 1.4142135623730951}}, "df": 3, "i": {"docs": {}, "df": 0, "z": {"docs": {}, "df": 0, "a": {"docs": {}, "df": 0, "t": {"docs": {}, "df": 0, "i": {"docs": {}, "df": 0, "o": {"docs": {}, "df": 0, "n": {"docs": {"shimmer.modules.global_workspace.GlobalWorkspace2Domains.__init__": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspace.__init__": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBayesian.__init__": {"tf": 1}}, "df": 3}}}}}}}}}}, "e": {"docs": {}, "df": 0, "r": {"docs": {}, "df": 0, "a": {"docs": {}, "df": 0, "t": {"docs": {}, "df": 0, "i": {"docs": {}, "df": 0, "o": {"docs": {}, "df": 0, "n": {"docs": {}, "df": 0, "s": {"docs": {"shimmer.modules.gw_module.GWModuleBase": {"tf": 1}}, "df": 1}}}, "n": {"docs": {}, "df": 0, "g": {"docs": {"shimmer.modules.losses.GWLosses.step": {"tf": 1}, "shimmer.modules.losses.GWLossesBayesian.step": {"tf": 1}}, "df": 2}}}}}}}}, "b": {"docs": {}, "df": 0, "t": {"docs": {}, "df": 0, "a": {"docs": {}, "df": 0, "i": {"docs": {}, "df": 0, "n": {"docs": {"shimmer.modules.gw_module.GWModuleBayesian.fuse": {"tf": 1.4142135623730951}}, "df": 1, "e": {"docs": {}, "df": 0, "d": {"docs": {"shimmer.modules.gw_module.GWModuleBayesian.fuse": {"tf": 1}}, "df": 1}}}}}}, "j": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "c": {"docs": {}, "df": 0, "t": {"docs": {"shimmer.modules.losses.GWLosses2Domains.__init__": {"tf": 1}, "shimmer.modules.losses.GWLosses.step": {"tf": 1}, "shimmer.modules.losses.GWLossesBayesian.step": {"tf": 1}}, "df": 3}}}}}, "m": {"docs": {}, "df": 0, "i": {"docs": {}, "df": 0, "t": {"docs": {}, "df": 0, "s": {"docs": {"shimmer.modules.selection.RandomSelection": {"tf": 1}}, "df": 1}}}}}, "t": {"docs": 
{"shimmer.modules.selection.SelectionBase.forward": {"tf": 1.4142135623730951}}, "df": 1, "h": {"docs": {}, "df": 0, "e": {"docs": {"shimmer.types.RawDomainGroupT": {"tf": 2}, "shimmer.types.RawDomainGroupDT": {"tf": 1.4142135623730951}, "shimmer.types.LatentsDomainGroupT": {"tf": 1.4142135623730951}, "shimmer.types.LatentsDomainGroupDT": {"tf": 1.4142135623730951}, "shimmer.types.RawDomainGroupsT": {"tf": 1}, "shimmer.types.RawDomainGroupsDT": {"tf": 1}, "shimmer.types.LatentsDomainGroupsT": {"tf": 1}, "shimmer.types.LatentsDomainGroupsDT": {"tf": 1}, "shimmer.modules.global_workspace.SchedulerArgs": {"tf": 1}, "shimmer.modules.global_workspace.GWPredictionsBase": {"tf": 1}, "shimmer.modules.global_workspace.GWPredictionsBase.states": {"tf": 1.4142135623730951}, "shimmer.modules.global_workspace.GlobalWorkspaceBase": {"tf": 1.4142135623730951}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.loss_mod": {"tf": 1.4142135623730951}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.workspace_dim": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.encode_and_fuse": {"tf": 1.7320508075688772}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.encode": {"tf": 1.7320508075688772}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.fuse": {"tf": 1.4142135623730951}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.decode": {"tf": 1.7320508075688772}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.batch_gw_states": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.encode_domain": {"tf": 2.23606797749979}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.encode_domains": {"tf": 1.7320508075688772}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.decode_domain": {"tf": 2.23606797749979}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.decode_domains": {"tf": 1.7320508075688772}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.generic_step": {"tf": 1.7320508075688772}, 
"shimmer.modules.global_workspace.freeze_domain_modules": {"tf": 1.4142135623730951}, "shimmer.modules.global_workspace.GWPredictions": {"tf": 1}, "shimmer.modules.global_workspace.GWPredictions.demi_cycles": {"tf": 1}, "shimmer.modules.global_workspace.GWPredictions.cycles": {"tf": 1.4142135623730951}, "shimmer.modules.global_workspace.GWPredictions.translations": {"tf": 1.4142135623730951}, "shimmer.modules.global_workspace.GlobalWorkspace2Domains": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspace2Domains.__init__": {"tf": 2.449489742783178}, "shimmer.modules.global_workspace.GlobalWorkspace2Domains.forward": {"tf": 1.7320508075688772}, "shimmer.modules.global_workspace.GlobalWorkspace": {"tf": 1.4142135623730951}, "shimmer.modules.global_workspace.GlobalWorkspace.__init__": {"tf": 2.8284271247461903}, "shimmer.modules.global_workspace.GlobalWorkspace.forward": {"tf": 1.7320508075688772}, "shimmer.modules.global_workspace.GlobalWorkspaceBayesian": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBayesian.__init__": {"tf": 2.8284271247461903}, "shimmer.modules.global_workspace.GlobalWorkspaceBayesian.forward": {"tf": 1.7320508075688772}, "shimmer.modules.global_workspace.pretrained_global_workspace": {"tf": 2.23606797749979}, "shimmer.modules.domain.DomainModule": {"tf": 1}, "shimmer.modules.domain.DomainModule.__init__": {"tf": 1}, "shimmer.modules.domain.DomainModule.latent_dim": {"tf": 1.4142135623730951}, "shimmer.modules.domain.DomainModule.encode": {"tf": 1.4142135623730951}, "shimmer.modules.domain.DomainModule.decode": {"tf": 1.7320508075688772}, "shimmer.modules.domain.DomainModule.compute_loss": {"tf": 1.4142135623730951}, "shimmer.modules.domain.DomainModule.compute_dcy_loss": {"tf": 2}, "shimmer.modules.domain.DomainModule.compute_cy_loss": {"tf": 2}, "shimmer.modules.domain.DomainModule.compute_tr_loss": {"tf": 2}, "shimmer.modules.domain.DomainModule.compute_broadcast_loss": {"tf": 2}, "shimmer.modules.gw_module.get_n_layers": 
{"tf": 1}, "shimmer.modules.gw_module.GWDecoder.__init__": {"tf": 1.4142135623730951}, "shimmer.modules.gw_module.GWDecoder.n_layers": {"tf": 1}, "shimmer.modules.gw_module.GWEncoder": {"tf": 1.4142135623730951}, "shimmer.modules.gw_module.GWEncoder.__init__": {"tf": 1.4142135623730951}, "shimmer.modules.gw_module.GWEncoder.forward": {"tf": 2.449489742783178}, "shimmer.modules.gw_module.GWEncoderLinear.forward": {"tf": 2.449489742783178}, "shimmer.modules.gw_module.GWModuleBase": {"tf": 1.4142135623730951}, "shimmer.modules.gw_module.GWModuleBase.__init__": {"tf": 1.7320508075688772}, "shimmer.modules.gw_module.GWModuleBase.domain_mods": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBase.workspace_dim": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBase.fuse": {"tf": 1.7320508075688772}, "shimmer.modules.gw_module.GWModuleBase.encode": {"tf": 1.7320508075688772}, "shimmer.modules.gw_module.GWModuleBase.encode_and_fuse": {"tf": 2.449489742783178}, "shimmer.modules.gw_module.GWModuleBase.decode": {"tf": 1.7320508075688772}, "shimmer.modules.gw_module.GWModule.__init__": {"tf": 1.7320508075688772}, "shimmer.modules.gw_module.GWModule.gw_encoders": {"tf": 1}, "shimmer.modules.gw_module.GWModule.gw_decoders": {"tf": 1}, "shimmer.modules.gw_module.GWModule.fuse": {"tf": 1.7320508075688772}, "shimmer.modules.gw_module.GWModule.encode": {"tf": 1.7320508075688772}, "shimmer.modules.gw_module.GWModule.decode": {"tf": 1.4142135623730951}, "shimmer.modules.gw_module.compute_fusion_scores": {"tf": 2.23606797749979}, "shimmer.modules.gw_module.GWModuleBayesian.__init__": {"tf": 1.7320508075688772}, "shimmer.modules.gw_module.GWModuleBayesian.precisions": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBayesian.get_precision": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBayesian.fuse": {"tf": 4}, "shimmer.modules.selection.SelectionBase": {"tf": 2.23606797749979}, "shimmer.modules.selection.SelectionBase.update_gw_state": {"tf": 1.7320508075688772}, 
"shimmer.modules.selection.SelectionBase.forward": {"tf": 2}, "shimmer.modules.selection.SingleDomainSelection": {"tf": 1.4142135623730951}, "shimmer.modules.selection.SingleDomainSelection.forward": {"tf": 2}, "shimmer.modules.selection.FixedSharedSelection": {"tf": 1}, "shimmer.modules.selection.FixedSharedSelection.forward": {"tf": 2}, "shimmer.modules.selection.KQFixedQSelection.__init__": {"tf": 1.7320508075688772}, "shimmer.modules.selection.KQFixedQSelection.forward": {"tf": 1.7320508075688772}, "shimmer.modules.selection.RandomSelection": {"tf": 1}, "shimmer.modules.selection.RandomSelection.__init__": {"tf": 1}, "shimmer.modules.selection.RandomSelection.forward": {"tf": 2.449489742783178}, "shimmer.modules.selection.DynamicQueryAttention": {"tf": 1.4142135623730951}, "shimmer.modules.selection.DynamicQueryAttention.__init__": {"tf": 1.7320508075688772}, "shimmer.modules.selection.DynamicQueryAttention.fuse_weighted_encodings": {"tf": 2.23606797749979}, "shimmer.modules.selection.DynamicQueryAttention.forward": {"tf": 1.7320508075688772}, "shimmer.modules.losses.GWLossesBase": {"tf": 1.4142135623730951}, "shimmer.modules.losses.GWLossesBase.step": {"tf": 1.4142135623730951}, "shimmer.modules.losses.demi_cycle_loss": {"tf": 2.6457513110645907}, "shimmer.modules.losses.cycle_loss": {"tf": 2.6457513110645907}, "shimmer.modules.losses.translation_loss": {"tf": 2.6457513110645907}, "shimmer.modules.losses.contrastive_loss": {"tf": 2.6457513110645907}, "shimmer.modules.losses.contrastive_loss_bayesian": {"tf": 2.6457513110645907}, "shimmer.modules.losses.LossCoefs": {"tf": 2}, "shimmer.modules.losses.GWLosses2Domains.__init__": {"tf": 2.449489742783178}, "shimmer.modules.losses.GWLosses2Domains.demi_cycle_loss": {"tf": 1.4142135623730951}, "shimmer.modules.losses.GWLosses2Domains.cycle_loss": {"tf": 1.4142135623730951}, "shimmer.modules.losses.GWLosses2Domains.translation_loss": {"tf": 1.4142135623730951}, 
"shimmer.modules.losses.GWLosses2Domains.contrastive_loss": {"tf": 1.4142135623730951}, "shimmer.modules.losses.GWLosses2Domains.step": {"tf": 1.4142135623730951}, "shimmer.modules.losses.generate_partitions": {"tf": 1.7320508075688772}, "shimmer.modules.losses.broadcast_loss": {"tf": 2}, "shimmer.modules.losses.BroadcastLossCoefs": {"tf": 2}, "shimmer.modules.losses.GWLosses.__init__": {"tf": 2.6457513110645907}, "shimmer.modules.losses.GWLosses.contrastive_loss": {"tf": 1.7320508075688772}, "shimmer.modules.losses.GWLosses.step": {"tf": 1.7320508075688772}, "shimmer.modules.losses.GWLossesBayesian.__init__": {"tf": 2.8284271247461903}, "shimmer.modules.losses.GWLossesBayesian.gw_mod": {"tf": 1}, "shimmer.modules.losses.GWLossesBayesian.domain_mods": {"tf": 1}, "shimmer.modules.losses.GWLossesBayesian.loss_coefs": {"tf": 1}, "shimmer.modules.losses.GWLossesBayesian.contrastive_loss": {"tf": 1}, "shimmer.modules.losses.GWLossesBayesian.step": {"tf": 1.7320508075688772}, "shimmer.modules.contrastive_loss.ContrastiveLossType": {"tf": 1}, "shimmer.modules.contrastive_loss.ContrastiveLossBayesianType": {"tf": 1}, "shimmer.modules.contrastive_loss.info_nce": {"tf": 1}, "shimmer.modules.contrastive_loss.contrastive_loss": {"tf": 1}, "shimmer.modules.contrastive_loss.ContrastiveLoss.__init__": {"tf": 1.4142135623730951}, "shimmer.modules.contrastive_loss.ContrastiveLoss.forward": {"tf": 1.4142135623730951}, "shimmer.dataset.RepeatedDataset": {"tf": 1.4142135623730951}, "shimmer.dataset.RepeatedDataset.__init__": {"tf": 1.7320508075688772}, "shimmer.modules.vae.reparameterize": {"tf": 1}, "shimmer.modules.vae.kl_divergence_loss": {"tf": 1.4142135623730951}, "shimmer.modules.vae.gaussian_nll": {"tf": 1}, "shimmer.modules.vae.VAEEncoder.forward": {"tf": 1}, "shimmer.modules.vae.VAEDecoder.forward": {"tf": 1}, "shimmer.modules.vae.VAE.encoder": {"tf": 1}, "shimmer.modules.vae.VAE.decoder": {"tf": 1}, "shimmer.modules.vae.VAE.encode": {"tf": 1.7320508075688772}, 
"shimmer.modules.vae.VAE.decode": {"tf": 1.7320508075688772}, "shimmer.modules.vae.VAE.forward": {"tf": 2.449489742783178}, "shimmer.modules.utils.translation": {"tf": 2.23606797749979}, "shimmer.modules.utils.cycle": {"tf": 1}, "shimmer.modules.utils.batch_demi_cycles": {"tf": 1.4142135623730951}, "shimmer.modules.utils.batch_cycles": {"tf": 1.7320508075688772}, "shimmer.modules.utils.batch_translations": {"tf": 1.4142135623730951}, "shimmer.utils.groups_batch_size": {"tf": 2}, "shimmer.utils.groups_device": {"tf": 2}, "shimmer.utils.migrate_model": {"tf": 2}, "shimmer.utils.SaveMigrations": {"tf": 1}, "shimmer.utils.SaveMigrations.on_save_checkpoint": {"tf": 1.7320508075688772}, "shimmer.cli.ckpt_migration.migrate_ckpt": {"tf": 1}}, "df": 146, "m": {"docs": {"shimmer.modules.gw_module.GWEncoder.forward": {"tf": 1}, "shimmer.modules.gw_module.GWEncoderLinear.forward": {"tf": 1}, "shimmer.modules.selection.KQFixedQSelection.forward": {"tf": 1}, "shimmer.modules.selection.DynamicQueryAttention.forward": {"tf": 1}, "shimmer.modules.losses.BroadcastLossCoefs.fused": {"tf": 1}}, "df": 5}, "r": {"docs": {}, "df": 0, "e": {"docs": {"shimmer.modules.selection.SingleDomainSelection": {"tf": 1}}, "df": 1}}, "n": {"docs": {"shimmer.modules.selection.RandomSelection.forward": {"tf": 1}}, "df": 1}, "i": {"docs": {}, "df": 0, "r": {"docs": {"shimmer.modules.losses.GWLosses.__init__": {"tf": 1}}, "df": 1}}}, "i": {"docs": {}, "df": 0, "s": {"docs": {"shimmer.types.RawDomainGroupT": {"tf": 1.7320508075688772}, "shimmer.types.RawDomainGroupDT": {"tf": 1.4142135623730951}, "shimmer.types.LatentsDomainGroupT": {"tf": 1.4142135623730951}, "shimmer.types.LatentsDomainGroupDT": {"tf": 1}, "shimmer.types.RawDomainGroupsT": {"tf": 1.7320508075688772}, "shimmer.types.RawDomainGroupsDT": {"tf": 1.7320508075688772}, "shimmer.types.LatentsDomainGroupsT": {"tf": 1.4142135623730951}, "shimmer.types.LatentsDomainGroupsDT": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBase": 
{"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.encode_domain": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.decode_domain": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspace2Domains": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspace": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBayesian": {"tf": 1}, "shimmer.modules.domain.LossOutput": {"tf": 1}, "shimmer.modules.gw_module.GWEncoder": {"tf": 1}, "shimmer.modules.gw_module.GWEncoder.forward": {"tf": 1.4142135623730951}, "shimmer.modules.gw_module.GWEncoderLinear.forward": {"tf": 1.4142135623730951}, "shimmer.modules.gw_module.GWModuleBase": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBayesian.fuse": {"tf": 1}, "shimmer.modules.selection.SelectionBase": {"tf": 1}, "shimmer.modules.selection.SelectionBase.update_gw_state": {"tf": 1.4142135623730951}, "shimmer.modules.selection.SingleDomainSelection": {"tf": 1}, "shimmer.modules.selection.FixedSharedSelection": {"tf": 1}, "shimmer.modules.selection.KQFixedQSelection.forward": {"tf": 1}, "shimmer.modules.selection.RandomSelection": {"tf": 1}, "shimmer.modules.selection.RandomSelection.forward": {"tf": 1}, "shimmer.modules.selection.DynamicQueryAttention.forward": {"tf": 1}, "shimmer.modules.losses.GWLossesBase": {"tf": 1}, "shimmer.modules.losses.demi_cycle_loss": {"tf": 1}, "shimmer.modules.losses.cycle_loss": {"tf": 1}, "shimmer.modules.losses.translation_loss": {"tf": 1}, "shimmer.modules.losses.contrastive_loss": {"tf": 1}, "shimmer.modules.losses.contrastive_loss_bayesian": {"tf": 1}, "shimmer.modules.losses.GWLosses.step": {"tf": 1}, "shimmer.modules.losses.GWLossesBayesian.step": {"tf": 1}, "shimmer.utils.SaveMigrations": {"tf": 1}, "shimmer.cli.ckpt_migration.migrate_ckpt": {"tf": 1.4142135623730951}}, "df": 38}}, "a": {"docs": {}, "df": 0, "t": {"docs": {"shimmer.modules.global_workspace.GlobalWorkspaceBase.loss_mod": {"tf": 1}, "shimmer.modules.domain.DomainModule": {"tf": 1}, 
"shimmer.modules.domain.DomainModule.compute_dcy_loss": {"tf": 1}, "shimmer.modules.domain.DomainModule.compute_cy_loss": {"tf": 1}, "shimmer.modules.domain.DomainModule.compute_tr_loss": {"tf": 1}, "shimmer.modules.domain.DomainModule.compute_broadcast_loss": {"tf": 1}, "shimmer.modules.gw_module.GWModule.__init__": {"tf": 1.4142135623730951}, "shimmer.modules.gw_module.GWModuleBayesian.__init__": {"tf": 1.4142135623730951}, "shimmer.modules.selection.SingleDomainSelection": {"tf": 1}, "shimmer.dataset.RepeatedDataset": {"tf": 1}, "shimmer.utils.SaveMigrations.on_save_checkpoint": {"tf": 1}}, "df": 11}, "n": {"docs": {"shimmer.modules.global_workspace.GWPredictions.cycles": {"tf": 1}, "shimmer.modules.global_workspace.GWPredictions.translations": {"tf": 1}}, "df": 2}}, "r": {"docs": {}, "df": 0, "o": {"docs": {}, "df": 0, "u": {"docs": {}, "df": 0, "g": {"docs": {}, "df": 0, "h": {"docs": {"shimmer.modules.global_workspace.GWPredictions.cycles": {"tf": 1}, "shimmer.modules.global_workspace.GWPredictions.translations": {"tf": 1}, "shimmer.modules.losses.cycle_loss": {"tf": 1.7320508075688772}, "shimmer.dataset.RepeatedDataset": {"tf": 1}, "shimmer.modules.utils.cycle": {"tf": 2}, "shimmer.modules.utils.batch_cycles": {"tf": 1}}, "df": 6}}}}}}, "y": {"docs": {}, "df": 0, "p": {"docs": {}, "df": 0, "e": {"docs": {"shimmer.types.RawDomainGroupT": {"tf": 1}, "shimmer.types.RawDomainGroupDT": {"tf": 1.4142135623730951}, "shimmer.types.LatentsDomainGroupT": {"tf": 1}, "shimmer.types.LatentsDomainGroupDT": {"tf": 1}, "shimmer.types.RawDomainGroupsT": {"tf": 1}, "shimmer.types.RawDomainGroupsDT": {"tf": 1}, "shimmer.types.LatentsDomainGroupsT": {"tf": 1}, "shimmer.types.LatentsDomainGroupsDT": {"tf": 1}, "shimmer.modules.global_workspace.freeze_domain_modules": {"tf": 1}, "shimmer.modules.global_workspace.pretrained_global_workspace": {"tf": 1}, "shimmer.modules.contrastive_loss.ContrastiveLossType": {"tf": 1}, 
"shimmer.modules.contrastive_loss.ContrastiveLossBayesianType": {"tf": 1}}, "df": 12, "d": {"docs": {}, "df": 0, "d": {"docs": {}, "df": 0, "i": {"docs": {}, "df": 0, "c": {"docs": {}, "df": 0, "t": {"docs": {"shimmer.modules.global_workspace.SchedulerArgs": {"tf": 1}, "shimmer.modules.global_workspace.GWPredictionsBase": {"tf": 1}, "shimmer.modules.global_workspace.GWPredictions": {"tf": 1}}, "df": 3}}}}}, "e": {"docs": {}, "df": 0, "r": {"docs": {}, "df": 0, "r": {"docs": {}, "df": 0, "o": {"docs": {}, "df": 0, "r": {"docs": {"shimmer.modules.global_workspace.pretrained_global_workspace": {"tf": 1}}, "df": 1}}}}}, "s": {"docs": {"shimmer.modules.losses.demi_cycle_loss": {"tf": 1}}, "df": 1}}, "i": {"docs": {}, "df": 0, "c": {"docs": {}, "df": 0, "a": {"docs": {}, "df": 0, "l": {"docs": {}, "df": 0, "l": {"docs": {}, "df": 0, "y": {"docs": {"shimmer.modules.losses.GWLossesBase": {"tf": 1}}, "df": 1}}}}}}}}, "o": {"docs": {"shimmer.types.RawDomainGroupT": {"tf": 1}, "shimmer.types.LatentsDomainGroupT": {"tf": 1}, "shimmer.types.LatentsDomainGroupsT": {"tf": 1}, "shimmer.modules.global_workspace.SchedulerArgs": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBase": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.decode": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.encode_domain": {"tf": 1.4142135623730951}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.encode_domains": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.decode_domain": {"tf": 1.4142135623730951}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.decode_domains": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.generic_step": {"tf": 1}, "shimmer.modules.global_workspace.freeze_domain_modules": {"tf": 1.4142135623730951}, "shimmer.modules.global_workspace.GlobalWorkspace2Domains": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspace2Domains.__init__": {"tf": 2.449489742783178}, 
"shimmer.modules.global_workspace.GlobalWorkspace": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspace.__init__": {"tf": 2.449489742783178}, "shimmer.modules.global_workspace.GlobalWorkspaceBayesian": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBayesian.__init__": {"tf": 2.8284271247461903}, "shimmer.modules.global_workspace.pretrained_global_workspace": {"tf": 2.8284271247461903}, "shimmer.modules.domain.LossOutput.metrics": {"tf": 1}, "shimmer.modules.domain.DomainModule.decode": {"tf": 1}, "shimmer.modules.gw_module.GWEncoder": {"tf": 1}, "shimmer.modules.gw_module.GWEncoder.forward": {"tf": 1}, "shimmer.modules.gw_module.GWEncoderLinear.forward": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBase.fuse": {"tf": 1.7320508075688772}, "shimmer.modules.gw_module.GWModuleBase.encode": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBase.encode_and_fuse": {"tf": 1.7320508075688772}, "shimmer.modules.gw_module.GWModuleBase.decode": {"tf": 1}, "shimmer.modules.gw_module.GWModule.__init__": {"tf": 1.7320508075688772}, "shimmer.modules.gw_module.GWModule.fuse": {"tf": 1.7320508075688772}, "shimmer.modules.gw_module.GWModule.encode": {"tf": 1}, "shimmer.modules.gw_module.GWModule.decode": {"tf": 2}, "shimmer.modules.gw_module.compute_fusion_scores": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBayesian.__init__": {"tf": 2}, "shimmer.modules.gw_module.GWModuleBayesian.fuse": {"tf": 2.449489742783178}, "shimmer.modules.selection.KQFixedQSelection.__init__": {"tf": 1}, "shimmer.modules.selection.RandomSelection": {"tf": 1}, "shimmer.modules.selection.RandomSelection.__init__": {"tf": 1}, "shimmer.modules.selection.DynamicQueryAttention.__init__": {"tf": 1}, "shimmer.modules.losses.GWLossesBase": {"tf": 1}, "shimmer.modules.losses.demi_cycle_loss": {"tf": 1.4142135623730951}, "shimmer.modules.losses.cycle_loss": {"tf": 1.4142135623730951}, "shimmer.modules.losses.translation_loss": {"tf": 2.23606797749979}, "shimmer.modules.losses.contrastive_loss": 
{"tf": 1.4142135623730951}, "shimmer.modules.losses.contrastive_loss_bayesian": {"tf": 1.4142135623730951}, "shimmer.modules.losses.LossCoefs": {"tf": 1.4142135623730951}, "shimmer.modules.losses.GWLosses2Domains.__init__": {"tf": 1.7320508075688772}, "shimmer.modules.losses.generate_partitions": {"tf": 1}, "shimmer.modules.losses.broadcast_loss": {"tf": 1.4142135623730951}, "shimmer.modules.losses.BroadcastLossCoefs": {"tf": 1.4142135623730951}, "shimmer.modules.losses.BroadcastLossCoefs.fused": {"tf": 1}, "shimmer.modules.losses.BroadcastLossCoefs.demi_cycles": {"tf": 1}, "shimmer.modules.losses.BroadcastLossCoefs.cycles": {"tf": 1}, "shimmer.modules.losses.BroadcastLossCoefs.translations": {"tf": 1}, "shimmer.modules.losses.GWLosses.__init__": {"tf": 1}, "shimmer.modules.losses.GWLossesBayesian.__init__": {"tf": 1.7320508075688772}, "shimmer.modules.losses.GWLossesBayesian.domain_mods": {"tf": 1}, "shimmer.modules.losses.GWLossesBayesian.contrastive_fn": {"tf": 1}, "shimmer.modules.contrastive_loss.info_nce": {"tf": 1}, "shimmer.modules.contrastive_loss.contrastive_loss": {"tf": 1}, "shimmer.modules.contrastive_loss.ContrastiveLoss.__init__": {"tf": 2.23606797749979}, "shimmer.dataset.RepeatedDataset": {"tf": 1}, "shimmer.dataset.RepeatedDataset.__init__": {"tf": 1.4142135623730951}, "shimmer.modules.vae.VAE.__init__": {"tf": 1}, "shimmer.modules.utils.translation": {"tf": 2.23606797749979}, "shimmer.modules.utils.cycle": {"tf": 1.4142135623730951}, "shimmer.modules.utils.batch_cycles": {"tf": 1.4142135623730951}, "shimmer.modules.utils.batch_translations": {"tf": 1}, "shimmer.utils.migrate_model": {"tf": 1.4142135623730951}, "shimmer.utils.SaveMigrations": {"tf": 1}, "shimmer.utils.SaveMigrations.on_save_checkpoint": {"tf": 1.7320508075688772}, "shimmer.cli.ckpt_migration.migrate_ckpt": {"tf": 1.4142135623730951}}, "df": 72, "r": {"docs": {}, "df": 0, "c": {"docs": {}, "df": 0, "h": {"docs": {"shimmer.types.LatentsDomainGroupT": {"tf": 1.7320508075688772}, 
"shimmer.types.LatentsDomainGroupDT": {"tf": 1.7320508075688772}, "shimmer.types.LatentsDomainGroupsT": {"tf": 2}, "shimmer.types.LatentsDomainGroupsDT": {"tf": 2}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.encode_and_fuse": {"tf": 1.4142135623730951}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.fuse": {"tf": 1.4142135623730951}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.decode": {"tf": 1.4142135623730951}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.batch_gw_states": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.encode_domain": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.decode_domain": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.generic_step": {"tf": 1}, "shimmer.modules.global_workspace.freeze_domain_modules": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspace2Domains.__init__": {"tf": 2}, "shimmer.modules.global_workspace.GlobalWorkspace.__init__": {"tf": 2}, "shimmer.modules.global_workspace.GlobalWorkspaceBayesian.__init__": {"tf": 2}, "shimmer.modules.global_workspace.pretrained_global_workspace": {"tf": 2}, "shimmer.modules.domain.DomainModule.encode": {"tf": 1}, "shimmer.modules.domain.DomainModule.decode": {"tf": 1}, "shimmer.modules.domain.DomainModule.compute_loss": {"tf": 1.4142135623730951}, "shimmer.modules.domain.DomainModule.compute_dcy_loss": {"tf": 1.4142135623730951}, "shimmer.modules.domain.DomainModule.compute_cy_loss": {"tf": 1.4142135623730951}, "shimmer.modules.domain.DomainModule.compute_tr_loss": {"tf": 1.4142135623730951}, "shimmer.modules.domain.DomainModule.compute_broadcast_loss": {"tf": 1.4142135623730951}, "shimmer.modules.gw_module.GWModuleBase.fuse": {"tf": 1.4142135623730951}, "shimmer.modules.gw_module.GWModuleBase.encode_and_fuse": {"tf": 1.4142135623730951}, "shimmer.modules.gw_module.GWModuleBase.decode": {"tf": 1}, "shimmer.modules.gw_module.GWModule.__init__": {"tf": 2}, 
"shimmer.modules.gw_module.GWModule.fuse": {"tf": 1.4142135623730951}, "shimmer.modules.gw_module.GWModule.decode": {"tf": 1}, "shimmer.modules.gw_module.compute_fusion_scores": {"tf": 1.7320508075688772}, "shimmer.modules.gw_module.GWModuleBayesian.__init__": {"tf": 2}, "shimmer.modules.gw_module.GWModuleBayesian.get_precision": {"tf": 1.4142135623730951}, "shimmer.modules.gw_module.GWModuleBayesian.fuse": {"tf": 1.4142135623730951}, "shimmer.modules.selection.SelectionBase.update_gw_state": {"tf": 1}, "shimmer.modules.selection.SelectionBase.forward": {"tf": 2.23606797749979}, "shimmer.modules.selection.SingleDomainSelection.forward": {"tf": 1.4142135623730951}, "shimmer.modules.selection.FixedSharedSelection.forward": {"tf": 1.4142135623730951}, "shimmer.modules.selection.KQFixedQSelection.forward": {"tf": 1}, "shimmer.modules.selection.RandomSelection.forward": {"tf": 1}, "shimmer.modules.selection.DynamicQueryAttention.fuse_weighted_encodings": {"tf": 1.4142135623730951}, "shimmer.modules.selection.DynamicQueryAttention.forward": {"tf": 1}, "shimmer.modules.losses.demi_cycle_loss": {"tf": 1}, "shimmer.modules.losses.cycle_loss": {"tf": 1}, "shimmer.modules.losses.translation_loss": {"tf": 1}, "shimmer.modules.losses.contrastive_loss": {"tf": 1}, "shimmer.modules.losses.contrastive_loss_bayesian": {"tf": 1}, "shimmer.modules.losses.GWLosses2Domains.demi_cycle_loss": {"tf": 1}, "shimmer.modules.losses.GWLosses2Domains.cycle_loss": {"tf": 1}, "shimmer.modules.losses.GWLosses2Domains.translation_loss": {"tf": 1}, "shimmer.modules.losses.GWLosses2Domains.contrastive_loss": {"tf": 1}, "shimmer.modules.losses.GWLossesBayesian.contrastive_loss": {"tf": 1}, "shimmer.modules.contrastive_loss.info_nce": {"tf": 1.7320508075688772}, "shimmer.modules.contrastive_loss.contrastive_loss": {"tf": 1.7320508075688772}, "shimmer.modules.contrastive_loss.ContrastiveLoss": {"tf": 1}, "shimmer.modules.contrastive_loss.ContrastiveLoss.__init__": {"tf": 1.4142135623730951}, 
"shimmer.modules.contrastive_loss.ContrastiveLoss.forward": {"tf": 1.4142135623730951}, "shimmer.modules.vae.reparameterize": {"tf": 1.7320508075688772}, "shimmer.modules.vae.kl_divergence_loss": {"tf": 1.7320508075688772}, "shimmer.modules.vae.gaussian_nll": {"tf": 2}, "shimmer.modules.vae.VAEEncoder.forward": {"tf": 1}, "shimmer.modules.vae.VAEDecoder.forward": {"tf": 1}, "shimmer.modules.vae.VAE.encode": {"tf": 1}, "shimmer.modules.vae.VAE.decode": {"tf": 1}, "shimmer.modules.vae.VAE.forward": {"tf": 1}, "shimmer.modules.utils.translation": {"tf": 1}, "shimmer.modules.utils.batch_demi_cycles": {"tf": 1}, "shimmer.modules.utils.batch_cycles": {"tf": 1}, "shimmer.modules.utils.batch_translations": {"tf": 1}, "shimmer.utils.migrate_model": {"tf": 1.4142135623730951}}, "df": 69}}}, "t": {"docs": {}, "df": 0, "a": {"docs": {}, "df": 0, "l": {"docs": {"shimmer.modules.global_workspace.SchedulerArgs.total_steps": {"tf": 1}, "shimmer.modules.gw_module.GWDecoder.__init__": {"tf": 1}, "shimmer.modules.gw_module.GWDecoder.n_layers": {"tf": 1}, "shimmer.modules.gw_module.GWEncoder.__init__": {"tf": 1}, "shimmer.modules.losses.LossCoefs": {"tf": 1}, "shimmer.modules.losses.broadcast_loss": {"tf": 1}, "shimmer.modules.losses.BroadcastLossCoefs": {"tf": 1}}, "df": 7}}}}, "e": {"docs": {}, "df": 0, "n": {"docs": {}, "df": 0, "s": {"docs": {}, "df": 0, "o": {"docs": {}, "df": 0, "r": {"docs": {"shimmer.types.LatentsDomainGroupT": {"tf": 1.7320508075688772}, "shimmer.types.LatentsDomainGroupDT": {"tf": 1.7320508075688772}, "shimmer.types.LatentsDomainGroupsT": {"tf": 2}, "shimmer.types.LatentsDomainGroupsDT": {"tf": 2}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.encode_and_fuse": {"tf": 1.4142135623730951}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.fuse": {"tf": 1.4142135623730951}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.decode": {"tf": 1.4142135623730951}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.batch_gw_states": {"tf": 1}, 
"shimmer.modules.global_workspace.GlobalWorkspaceBase.encode_domain": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.decode_domain": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.generic_step": {"tf": 1}, "shimmer.modules.domain.DomainModule.encode": {"tf": 1}, "shimmer.modules.domain.DomainModule.decode": {"tf": 1}, "shimmer.modules.domain.DomainModule.compute_loss": {"tf": 1.7320508075688772}, "shimmer.modules.domain.DomainModule.compute_dcy_loss": {"tf": 1.7320508075688772}, "shimmer.modules.domain.DomainModule.compute_cy_loss": {"tf": 1.7320508075688772}, "shimmer.modules.domain.DomainModule.compute_tr_loss": {"tf": 1.7320508075688772}, "shimmer.modules.domain.DomainModule.compute_broadcast_loss": {"tf": 1.7320508075688772}, "shimmer.modules.gw_module.GWModuleBase.fuse": {"tf": 1.4142135623730951}, "shimmer.modules.gw_module.GWModuleBase.encode_and_fuse": {"tf": 1.4142135623730951}, "shimmer.modules.gw_module.GWModuleBase.decode": {"tf": 1}, "shimmer.modules.gw_module.GWModule.fuse": {"tf": 1.4142135623730951}, "shimmer.modules.gw_module.GWModule.decode": {"tf": 1}, "shimmer.modules.gw_module.compute_fusion_scores": {"tf": 1.7320508075688772}, "shimmer.modules.gw_module.GWModuleBayesian.get_precision": {"tf": 1.4142135623730951}, "shimmer.modules.gw_module.GWModuleBayesian.fuse": {"tf": 1.4142135623730951}, "shimmer.modules.selection.SelectionBase.update_gw_state": {"tf": 1}, "shimmer.modules.selection.SelectionBase.forward": {"tf": 1.7320508075688772}, "shimmer.modules.selection.SingleDomainSelection.forward": {"tf": 1.4142135623730951}, "shimmer.modules.selection.FixedSharedSelection.forward": {"tf": 1.4142135623730951}, "shimmer.modules.selection.KQFixedQSelection.forward": {"tf": 1}, "shimmer.modules.selection.RandomSelection.forward": {"tf": 1}, "shimmer.modules.selection.DynamicQueryAttention.fuse_weighted_encodings": {"tf": 1.7320508075688772}, "shimmer.modules.selection.DynamicQueryAttention.forward": {"tf": 1}, 
"shimmer.modules.losses.demi_cycle_loss": {"tf": 1}, "shimmer.modules.losses.cycle_loss": {"tf": 1}, "shimmer.modules.losses.translation_loss": {"tf": 1}, "shimmer.modules.losses.contrastive_loss": {"tf": 1}, "shimmer.modules.losses.contrastive_loss_bayesian": {"tf": 1}, "shimmer.modules.losses.GWLosses2Domains.demi_cycle_loss": {"tf": 1}, "shimmer.modules.losses.GWLosses2Domains.cycle_loss": {"tf": 1}, "shimmer.modules.losses.GWLosses2Domains.translation_loss": {"tf": 1}, "shimmer.modules.losses.GWLosses2Domains.contrastive_loss": {"tf": 1}, "shimmer.modules.losses.GWLossesBayesian.contrastive_loss": {"tf": 1}, "shimmer.modules.contrastive_loss.info_nce": {"tf": 1.7320508075688772}, "shimmer.modules.contrastive_loss.contrastive_loss": {"tf": 1.7320508075688772}, "shimmer.modules.contrastive_loss.ContrastiveLoss.__init__": {"tf": 1.7320508075688772}, "shimmer.modules.contrastive_loss.ContrastiveLoss.forward": {"tf": 1.4142135623730951}, "shimmer.modules.vae.reparameterize": {"tf": 1.7320508075688772}, "shimmer.modules.vae.kl_divergence_loss": {"tf": 1.7320508075688772}, "shimmer.modules.vae.gaussian_nll": {"tf": 2}, "shimmer.modules.vae.VAEEncoder.forward": {"tf": 1.4142135623730951}, "shimmer.modules.vae.VAEDecoder.forward": {"tf": 1}, "shimmer.modules.vae.VAE.encode": {"tf": 1}, "shimmer.modules.vae.VAE.decode": {"tf": 1}, "shimmer.modules.vae.VAE.forward": {"tf": 1.4142135623730951}, "shimmer.modules.utils.translation": {"tf": 1}, "shimmer.modules.utils.batch_demi_cycles": {"tf": 1}, "shimmer.modules.utils.batch_cycles": {"tf": 1}, "shimmer.modules.utils.batch_translations": {"tf": 1}}, "df": 60}}}}, "s": {"docs": {}, "df": 0, "t": {"docs": {"shimmer.modules.global_workspace.GlobalWorkspaceBase.generic_step": {"tf": 1}, "shimmer.modules.losses.GWLossesBase.step": {"tf": 1}}, "df": 2, "i": {"docs": {}, "df": 0, "n": {"docs": {}, "df": 0, "g": {"docs": {"shimmer.types.ModelModeT": {"tf": 1}}, "df": 1}}}, "/": {"docs": {}, "df": 0, "o": {"docs": {}, "df": 0, "o": 
{"docs": {}, "df": 0, "d": {"docs": {"shimmer.types.ModelModeT": {"tf": 1}, "shimmer.modules.losses.GWLossesBase.step": {"tf": 1}}, "df": 2}}}}}}, "m": {"docs": {}, "df": 0, "p": {"docs": {"shimmer.modules.global_workspace.GlobalWorkspaceBayesian.__init__": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBayesian.__init__": {"tf": 1}}, "df": 2, "e": {"docs": {}, "df": 0, "r": {"docs": {}, "df": 0, "a": {"docs": {}, "df": 0, "t": {"docs": {}, "df": 0, "u": {"docs": {}, "df": 0, "r": {"docs": {}, "df": 0, "e": {"docs": {"shimmer.modules.global_workspace.GlobalWorkspace.__init__": {"tf": 1.4142135623730951}, "shimmer.modules.global_workspace.GlobalWorkspaceBayesian.__init__": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBayesian.__init__": {"tf": 1}, "shimmer.modules.selection.RandomSelection.__init__": {"tf": 1.4142135623730951}}, "df": 4}}}}}}}}}, "r": {"docs": {}, "df": 0, "m": {"docs": {}, "df": 0, "s": {"docs": {"shimmer.modules.gw_module.GWModuleBayesian.fuse": {"tf": 1}}, "df": 1}}}}, "r": {"docs": {"shimmer.modules.losses.translation_loss": {"tf": 1}}, "df": 1, "a": {"docs": {}, "df": 0, "i": {"docs": {}, "df": 0, "n": {"docs": {"shimmer.modules.global_workspace.GlobalWorkspaceBase.generic_step": {"tf": 1}, "shimmer.modules.losses.GWLossesBase.step": {"tf": 1}}, "df": 2, "/": {"docs": {}, "df": 0, "v": {"docs": {}, "df": 0, "a": {"docs": {}, "df": 0, "l": {"docs": {"shimmer.types.ModelModeT": {"tf": 1}}, "df": 1}}}}, "i": {"docs": {}, "df": 0, "n": {"docs": {}, "df": 0, "g": {"docs": {"shimmer.modules.global_workspace.GlobalWorkspaceBase.generic_step": {"tf": 1}, "shimmer.modules.domain.LossOutput": {"tf": 1}, "shimmer.modules.domain.LossOutput.loss": {"tf": 1}, "shimmer.modules.domain.LossOutput.metrics": {"tf": 1}, "shimmer.modules.domain.DomainModule.compute_loss": {"tf": 1}, "shimmer.modules.domain.DomainModule.compute_dcy_loss": {"tf": 1}, "shimmer.modules.domain.DomainModule.compute_cy_loss": {"tf": 1}, 
"shimmer.modules.domain.DomainModule.compute_tr_loss": {"tf": 1}, "shimmer.modules.domain.DomainModule.compute_broadcast_loss": {"tf": 1}}, "df": 9}}}, "e": {"docs": {}, "df": 0, "r": {"docs": {"shimmer.utils.SaveMigrations.on_save_checkpoint": {"tf": 2}}, "df": 1}}}}, "n": {"docs": {}, "df": 0, "s": {"docs": {}, "df": 0, "l": {"docs": {}, "df": 0, "a": {"docs": {}, "df": 0, "t": {"docs": {}, "df": 0, "i": {"docs": {}, "df": 0, "o": {"docs": {}, "df": 0, "n": {"docs": {"shimmer.modules.global_workspace.GWPredictions.translations": {"tf": 1}, "shimmer.modules.domain.DomainModule.compute_tr_loss": {"tf": 1.4142135623730951}, "shimmer.modules.losses.GWLossesBase": {"tf": 1}, "shimmer.modules.losses.translation_loss": {"tf": 2.23606797749979}, "shimmer.modules.losses.LossCoefs.translations": {"tf": 1}, "shimmer.modules.losses.GWLosses2Domains.translation_loss": {"tf": 1.4142135623730951}, "shimmer.modules.losses.GWLosses2Domains.step": {"tf": 1.4142135623730951}, "shimmer.modules.losses.broadcast_loss": {"tf": 1}, "shimmer.modules.losses.BroadcastLossCoefs.translations": {"tf": 1.4142135623730951}, "shimmer.modules.utils.translation": {"tf": 1}, "shimmer.modules.utils.cycle": {"tf": 1}, "shimmer.modules.utils.batch_translations": {"tf": 1.4142135623730951}}, "df": 12, "s": {"docs": {"shimmer.modules.global_workspace.GlobalWorkspace2Domains.forward": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspace.forward": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBayesian.forward": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBase": {"tf": 1}, "shimmer.modules.losses.translation_loss": {"tf": 1}, "shimmer.modules.utils.batch_translations": {"tf": 1}}, "df": 6}}}}, "e": {"docs": {"shimmer.modules.utils.translation": {"tf": 1}}, "df": 1, "d": {"docs": {"shimmer.modules.utils.translation": {"tf": 1}}, "df": 1}}}}}}}, "c": {"docs": {}, "df": 0, "k": {"docs": {"shimmer.modules.domain.LossOutput": {"tf": 1}}, "df": 1}}}, "u": {"docs": {}, "df": 0, "e": 
{"docs": {"shimmer.dataset.RepeatedDataset": {"tf": 1}}, "df": 1}}, "i": {"docs": {}, "df": 0, "c": {"docs": {}, "df": 0, "k": {"docs": {"shimmer.modules.vae.reparameterize": {"tf": 1.4142135623730951}}, "df": 1}}}}, "u": {"docs": {}, "df": 0, "p": {"docs": {}, "df": 0, "l": {"docs": {}, "df": 0, "e": {"docs": {"shimmer.modules.global_workspace.GWPredictions.cycles": {"tf": 1}, "shimmer.modules.vae.VAE.forward": {"tf": 1}}, "df": 2, "s": {"docs": {"shimmer.modules.global_workspace.GWPredictions.translations": {"tf": 1}}, "df": 1}, "[": {"docs": {}, "df": 0, "i": {"docs": {}, "df": 0, "n": {"docs": {}, "df": 0, "t": {"docs": {"shimmer.modules.losses.generate_partitions": {"tf": 1}}, "df": 1}}}, "t": {"docs": {}, "df": 0, "o": {"docs": {}, "df": 0, "r": {"docs": {}, "df": 0, "c": {"docs": {}, "df": 0, "h": {"docs": {"shimmer.modules.vae.VAEEncoder.forward": {"tf": 1}}, "df": 1}}}}, "u": {"docs": {}, "df": 0, "p": {"docs": {}, "df": 0, "l": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "[": {"docs": {}, "df": 0, "t": {"docs": {}, "df": 0, "o": {"docs": {}, "df": 0, "r": {"docs": {}, "df": 0, "c": {"docs": {}, "df": 0, "h": {"docs": {"shimmer.modules.vae.VAE.forward": {"tf": 1}}, "df": 1}}}}}}}}}}}}}}}}, "a": {"docs": {}, "df": 0, "r": {"docs": {}, "df": 0, "g": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "t": {"docs": {"shimmer.modules.global_workspace.GWPredictions.translations": {"tf": 1}, "shimmer.modules.domain.DomainModule.compute_loss": {"tf": 1.4142135623730951}, "shimmer.modules.domain.DomainModule.compute_dcy_loss": {"tf": 1.4142135623730951}, "shimmer.modules.domain.DomainModule.compute_cy_loss": {"tf": 1.4142135623730951}, "shimmer.modules.domain.DomainModule.compute_tr_loss": {"tf": 1.4142135623730951}, "shimmer.modules.domain.DomainModule.compute_broadcast_loss": {"tf": 1.4142135623730951}, "shimmer.modules.losses.cycle_loss": {"tf": 1.7320508075688772}, "shimmer.modules.losses.translation_loss": {"tf": 2}, 
"shimmer.modules.contrastive_loss.ContrastiveLossBayesianType": {"tf": 1.4142135623730951}, "shimmer.modules.contrastive_loss.info_nce": {"tf": 1}, "shimmer.modules.contrastive_loss.contrastive_loss": {"tf": 1}, "shimmer.modules.contrastive_loss.ContrastiveLoss.forward": {"tf": 1}, "shimmer.modules.utils.batch_translations": {"tf": 1}}, "df": 13, "s": {"docs": {"shimmer.modules.contrastive_loss.ContrastiveLossType": {"tf": 1}, "shimmer.modules.vae.gaussian_nll": {"tf": 1}}, "df": 2}}}}}, "n": {"docs": {}, "df": 0, "h": {"docs": {"shimmer.modules.gw_module.GWEncoder": {"tf": 1}}, "df": 1}}, "k": {"docs": {}, "df": 0, "e": {"docs": {"shimmer.modules.losses.LossCoefs": {"tf": 1}, "shimmer.modules.losses.BroadcastLossCoefs": {"tf": 1}}, "df": 2, "s": {"docs": {"shimmer.modules.gw_module.GWEncoder.forward": {"tf": 1}, "shimmer.modules.gw_module.GWEncoderLinear.forward": {"tf": 1}}, "df": 2}}, "i": {"docs": {}, "df": 0, "n": {"docs": {}, "df": 0, "g": {"docs": {"shimmer.modules.contrastive_loss.ContrastiveLossType": {"tf": 1}, "shimmer.modules.contrastive_loss.ContrastiveLossBayesianType": {"tf": 1}}, "df": 2}}}}}, "w": {"docs": {}, "df": 0, "o": {"docs": {"shimmer.modules.gw_module.compute_fusion_scores": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBayesian.fuse": {"tf": 1}}, "df": 2}, "i": {"docs": {}, "df": 0, "c": {"docs": {}, "df": 0, "e": {"docs": {"shimmer.modules.selection.KQFixedQSelection.forward": {"tf": 1}, "shimmer.modules.selection.DynamicQueryAttention.forward": {"tf": 1}}, "df": 2}}}}, "i": {"docs": {}, "df": 0, "m": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "s": {"docs": {"shimmer.modules.gw_module.GWModuleBayesian.fuse": {"tf": 1}}, "df": 1}}}}}, "a": {"docs": {"shimmer.types.RawDomainGroupT": {"tf": 1.4142135623730951}, "shimmer.types.RawDomainGroupDT": {"tf": 1.4142135623730951}, "shimmer.types.RawDomainGroupsT": {"tf": 2}, "shimmer.types.RawDomainGroupsDT": {"tf": 2}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.gw_mod": {"tf": 1}, 
"shimmer.modules.global_workspace.GlobalWorkspaceBase.selection_mod": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.encode_and_fuse": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.encode": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.fuse": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.batch_gw_states": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.encode_domain": {"tf": 1.4142135623730951}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.decode_domain": {"tf": 1.4142135623730951}, "shimmer.modules.global_workspace.freeze_domain_modules": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspace2Domains": {"tf": 1.4142135623730951}, "shimmer.modules.global_workspace.GlobalWorkspace2Domains.__init__": {"tf": 2.8284271247461903}, "shimmer.modules.global_workspace.GlobalWorkspace": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspace.__init__": {"tf": 2.8284271247461903}, "shimmer.modules.global_workspace.GlobalWorkspaceBayesian": {"tf": 1.7320508075688772}, "shimmer.modules.global_workspace.GlobalWorkspaceBayesian.__init__": {"tf": 2.8284271247461903}, "shimmer.modules.global_workspace.pretrained_global_workspace": {"tf": 3}, "shimmer.modules.domain.LossOutput": {"tf": 1.4142135623730951}, "shimmer.modules.domain.LossOutput.all": {"tf": 1}, "shimmer.modules.domain.DomainModule": {"tf": 1}, "shimmer.modules.domain.DomainModule.__init__": {"tf": 1}, "shimmer.modules.domain.DomainModule.encode": {"tf": 1.4142135623730951}, "shimmer.modules.domain.DomainModule.compute_dcy_loss": {"tf": 1}, "shimmer.modules.domain.DomainModule.compute_cy_loss": {"tf": 1}, "shimmer.modules.domain.DomainModule.compute_tr_loss": {"tf": 1}, "shimmer.modules.domain.DomainModule.compute_broadcast_loss": {"tf": 1}, "shimmer.modules.gw_module.get_n_layers": {"tf": 1}, "shimmer.modules.gw_module.GWDecoder": {"tf": 1}, "shimmer.modules.gw_module.GWEncoder": {"tf": 1}, 
"shimmer.modules.gw_module.GWEncoderLinear": {"tf": 1}, "shimmer.modules.gw_module.GWModule.__init__": {"tf": 2.449489742783178}, "shimmer.modules.gw_module.GWModule.decode": {"tf": 1}, "shimmer.modules.gw_module.compute_fusion_scores": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBayesian": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBayesian.__init__": {"tf": 2.449489742783178}, "shimmer.modules.gw_module.GWModuleBayesian.fuse": {"tf": 2}, "shimmer.modules.selection.SingleDomainSelection": {"tf": 1.7320508075688772}, "shimmer.modules.selection.KQFixedQSelection": {"tf": 1}, "shimmer.modules.selection.KQFixedQSelection.forward": {"tf": 1}, "shimmer.modules.selection.RandomSelection": {"tf": 1.4142135623730951}, "shimmer.modules.selection.DynamicQueryAttention": {"tf": 1}, "shimmer.modules.selection.DynamicQueryAttention.forward": {"tf": 1}, "shimmer.modules.losses.demi_cycle_loss": {"tf": 1.4142135623730951}, "shimmer.modules.losses.cycle_loss": {"tf": 1.4142135623730951}, "shimmer.modules.losses.translation_loss": {"tf": 1.7320508075688772}, "shimmer.modules.losses.contrastive_loss": {"tf": 1}, "shimmer.modules.losses.contrastive_loss_bayesian": {"tf": 1.4142135623730951}, "shimmer.modules.losses.GWLosses2Domains.__init__": {"tf": 1.4142135623730951}, "shimmer.modules.losses.GWLosses2Domains.demi_cycle_loss": {"tf": 1}, "shimmer.modules.losses.GWLosses2Domains.cycle_loss": {"tf": 1}, "shimmer.modules.losses.GWLosses2Domains.translation_loss": {"tf": 1}, "shimmer.modules.losses.GWLosses2Domains.contrastive_loss": {"tf": 1}, "shimmer.modules.losses.generate_partitions": {"tf": 1}, "shimmer.modules.losses.broadcast_loss": {"tf": 1}, "shimmer.modules.losses.GWLosses.__init__": {"tf": 1.4142135623730951}, "shimmer.modules.losses.GWLosses.contrastive_loss": {"tf": 1}, "shimmer.modules.losses.GWLosses.step": {"tf": 1.4142135623730951}, "shimmer.modules.losses.GWLossesBayesian.__init__": {"tf": 1}, "shimmer.modules.losses.GWLossesBayesian.contrastive_loss": {"tf": 1}, 
"shimmer.modules.losses.GWLossesBayesian.step": {"tf": 1.4142135623730951}, "shimmer.modules.contrastive_loss.ContrastiveLossType": {"tf": 1.4142135623730951}, "shimmer.modules.contrastive_loss.ContrastiveLossBayesianType": {"tf": 1.4142135623730951}, "shimmer.modules.contrastive_loss.ContrastiveLoss.__init__": {"tf": 1}, "shimmer.modules.contrastive_loss.ContrastiveLoss.forward": {"tf": 1}, "shimmer.dataset.RepeatedDataset": {"tf": 1}, "shimmer.dataset.RepeatedDataset.__init__": {"tf": 1}, "shimmer.modules.vae.reparameterize": {"tf": 1}, "shimmer.modules.vae.VAEEncoder": {"tf": 1}, "shimmer.modules.vae.VAEDecoder": {"tf": 1}, "shimmer.modules.vae.VAE.__init__": {"tf": 1}, "shimmer.modules.utils.cycle": {"tf": 1.4142135623730951}, "shimmer.modules.utils.batch_demi_cycles": {"tf": 1}, "shimmer.modules.utils.batch_cycles": {"tf": 1}, "shimmer.modules.utils.batch_translations": {"tf": 1}, "shimmer.utils.migrate_model": {"tf": 1}, "shimmer.utils.SaveMigrations.on_save_checkpoint": {"tf": 1.4142135623730951}, "shimmer.cli.ckpt_migration.migrate_ckpt": {"tf": 1}}, "df": 80, "r": {"docs": {}, "df": 0, "e": {"docs": {"shimmer.types.RawDomainGroupT": {"tf": 1.4142135623730951}, "shimmer.types.RawDomainGroupDT": {"tf": 1.4142135623730951}, "shimmer.types.LatentsDomainGroupT": {"tf": 1.4142135623730951}, "shimmer.types.LatentsDomainGroupDT": {"tf": 1.4142135623730951}, "shimmer.types.RawDomainGroupsT": {"tf": 1}, "shimmer.types.RawDomainGroupsDT": {"tf": 1}, "shimmer.types.LatentsDomainGroupsT": {"tf": 1}, "shimmer.types.LatentsDomainGroupsDT": {"tf": 1}, "shimmer.modules.global_workspace.GWPredictions.cycles": {"tf": 1}, "shimmer.modules.global_workspace.GWPredictions.translations": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspace2Domains.__init__": {"tf": 1.4142135623730951}, "shimmer.modules.global_workspace.GlobalWorkspace.__init__": {"tf": 1.4142135623730951}, "shimmer.modules.global_workspace.GlobalWorkspaceBayesian.__init__": {"tf": 1.4142135623730951}, 
"shimmer.modules.global_workspace.pretrained_global_workspace": {"tf": 1.4142135623730951}, "shimmer.modules.losses.BroadcastLossCoefs.demi_cycles": {"tf": 1}}, "df": 15}, "g": {"docs": {}, "df": 0, "u": {"docs": {}, "df": 0, "m": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "n": {"docs": {}, "df": 0, "t": {"docs": {}, "df": 0, "s": {"docs": {"shimmer.modules.global_workspace.SchedulerArgs": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.encode_and_fuse": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.encode": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.fuse": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.decode": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.batch_gw_states": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.encode_domain": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.encode_domains": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.decode_domain": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.decode_domains": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.generic_step": {"tf": 1}, "shimmer.modules.global_workspace.freeze_domain_modules": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspace2Domains.__init__": {"tf": 1.4142135623730951}, "shimmer.modules.global_workspace.GlobalWorkspace2Domains.forward": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspace.__init__": {"tf": 1.4142135623730951}, "shimmer.modules.global_workspace.GlobalWorkspace.forward": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBayesian.__init__": {"tf": 1.4142135623730951}, "shimmer.modules.global_workspace.GlobalWorkspaceBayesian.forward": {"tf": 1}, "shimmer.modules.global_workspace.pretrained_global_workspace": {"tf": 1.4142135623730951}, "shimmer.modules.domain.DomainModule.__init__": {"tf": 1}, "shimmer.modules.domain.DomainModule.encode": {"tf": 1}, 
"shimmer.modules.domain.DomainModule.decode": {"tf": 1}, "shimmer.modules.domain.DomainModule.compute_loss": {"tf": 1}, "shimmer.modules.domain.DomainModule.compute_dcy_loss": {"tf": 1}, "shimmer.modules.domain.DomainModule.compute_cy_loss": {"tf": 1}, "shimmer.modules.domain.DomainModule.compute_tr_loss": {"tf": 1}, "shimmer.modules.domain.DomainModule.compute_broadcast_loss": {"tf": 1}, "shimmer.modules.gw_module.get_n_layers": {"tf": 1}, "shimmer.modules.gw_module.GWDecoder.__init__": {"tf": 1}, "shimmer.modules.gw_module.GWEncoder.__init__": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBase.__init__": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBase.fuse": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBase.encode": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBase.encode_and_fuse": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBase.decode": {"tf": 1}, "shimmer.modules.gw_module.GWModule.__init__": {"tf": 1}, "shimmer.modules.gw_module.GWModule.fuse": {"tf": 1}, "shimmer.modules.gw_module.GWModule.encode": {"tf": 1}, "shimmer.modules.gw_module.GWModule.decode": {"tf": 1}, "shimmer.modules.gw_module.compute_fusion_scores": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBayesian.__init__": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBayesian.get_precision": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBayesian.fuse": {"tf": 1}, "shimmer.modules.selection.SelectionBase.update_gw_state": {"tf": 1}, "shimmer.modules.selection.SelectionBase.forward": {"tf": 1}, "shimmer.modules.selection.SingleDomainSelection.forward": {"tf": 1}, "shimmer.modules.selection.FixedSharedSelection.forward": {"tf": 1}, "shimmer.modules.selection.KQFixedQSelection.__init__": {"tf": 1}, "shimmer.modules.selection.KQFixedQSelection.forward": {"tf": 1}, "shimmer.modules.selection.RandomSelection.__init__": {"tf": 1}, "shimmer.modules.selection.RandomSelection.forward": {"tf": 1}, "shimmer.modules.selection.DynamicQueryAttention.__init__": {"tf": 1}, 
"shimmer.modules.selection.DynamicQueryAttention.fuse_weighted_encodings": {"tf": 1}, "shimmer.modules.selection.DynamicQueryAttention.forward": {"tf": 1}, "shimmer.modules.losses.GWLossesBase.step": {"tf": 1}, "shimmer.modules.losses.demi_cycle_loss": {"tf": 1}, "shimmer.modules.losses.cycle_loss": {"tf": 1}, "shimmer.modules.losses.translation_loss": {"tf": 1}, "shimmer.modules.losses.contrastive_loss": {"tf": 1}, "shimmer.modules.losses.contrastive_loss_bayesian": {"tf": 1}, "shimmer.modules.losses.GWLosses2Domains.__init__": {"tf": 1}, "shimmer.modules.losses.GWLosses2Domains.demi_cycle_loss": {"tf": 1}, "shimmer.modules.losses.GWLosses2Domains.cycle_loss": {"tf": 1}, "shimmer.modules.losses.GWLosses2Domains.translation_loss": {"tf": 1}, "shimmer.modules.losses.GWLosses2Domains.contrastive_loss": {"tf": 1}, "shimmer.modules.losses.GWLosses2Domains.step": {"tf": 1}, "shimmer.modules.losses.generate_partitions": {"tf": 1}, "shimmer.modules.losses.broadcast_loss": {"tf": 1}, "shimmer.modules.losses.GWLosses.__init__": {"tf": 1}, "shimmer.modules.losses.GWLosses.contrastive_loss": {"tf": 1}, "shimmer.modules.losses.GWLosses.step": {"tf": 1}, "shimmer.modules.losses.GWLossesBayesian.__init__": {"tf": 1}, "shimmer.modules.losses.GWLossesBayesian.contrastive_loss": {"tf": 1}, "shimmer.modules.losses.GWLossesBayesian.step": {"tf": 1}, "shimmer.modules.contrastive_loss.info_nce": {"tf": 1}, "shimmer.modules.contrastive_loss.contrastive_loss": {"tf": 1}, "shimmer.modules.contrastive_loss.ContrastiveLoss.__init__": {"tf": 1}, "shimmer.modules.contrastive_loss.ContrastiveLoss.forward": {"tf": 1}, "shimmer.dataset.RepeatedDataset.__init__": {"tf": 1}, "shimmer.modules.vae.reparameterize": {"tf": 1}, "shimmer.modules.vae.kl_divergence_loss": {"tf": 1}, "shimmer.modules.vae.gaussian_nll": {"tf": 1}, "shimmer.modules.vae.VAEEncoder.forward": {"tf": 1}, "shimmer.modules.vae.VAEDecoder.forward": {"tf": 1}, "shimmer.modules.vae.VAE.__init__": {"tf": 1}, 
"shimmer.modules.vae.VAE.encode": {"tf": 1}, "shimmer.modules.vae.VAE.decode": {"tf": 1}, "shimmer.modules.vae.VAE.forward": {"tf": 1}, "shimmer.modules.utils.translation": {"tf": 1}, "shimmer.modules.utils.cycle": {"tf": 1}, "shimmer.modules.utils.batch_demi_cycles": {"tf": 1}, "shimmer.modules.utils.batch_cycles": {"tf": 1}, "shimmer.modules.utils.batch_translations": {"tf": 1}, "shimmer.utils.groups_batch_size": {"tf": 1}, "shimmer.utils.groups_device": {"tf": 1}, "shimmer.utils.migrate_model": {"tf": 1}, "shimmer.utils.SaveMigrations.on_save_checkpoint": {"tf": 1}}, "df": 97}}}}}}, "s": {"docs": {"shimmer.modules.global_workspace.GlobalWorkspace2Domains.__init__": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspace.__init__": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBayesian.__init__": {"tf": 1}, "shimmer.utils.migrate_model": {"tf": 1}}, "df": 4}}}, "n": {"docs": {"shimmer.modules.gw_module.GWEncoder": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBase": {"tf": 1.4142135623730951}, "shimmer.modules.gw_module.GWModule.__init__": {"tf": 1.4142135623730951}, "shimmer.modules.gw_module.GWModuleBayesian.__init__": {"tf": 1.4142135623730951}, "shimmer.modules.selection.SelectionBase.update_gw_state": {"tf": 1.4142135623730951}}, "df": 5, "d": {"docs": {"shimmer.types.RawDomainGroupT": {"tf": 2}, "shimmer.types.RawDomainGroupDT": {"tf": 1.4142135623730951}, "shimmer.types.LatentsDomainGroupT": {"tf": 1.7320508075688772}, "shimmer.types.LatentsDomainGroupDT": {"tf": 1.4142135623730951}, "shimmer.types.RawDomainGroupsT": {"tf": 1.4142135623730951}, "shimmer.types.RawDomainGroupsDT": {"tf": 1.4142135623730951}, "shimmer.types.LatentsDomainGroupsT": {"tf": 1.7320508075688772}, "shimmer.types.LatentsDomainGroupsDT": {"tf": 1.4142135623730951}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.encode_domain": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.decode_domain": {"tf": 1}, 
"shimmer.modules.global_workspace.GlobalWorkspaceBase.generic_step": {"tf": 1}, "shimmer.modules.global_workspace.freeze_domain_modules": {"tf": 1}, "shimmer.modules.global_workspace.GWPredictions.cycles": {"tf": 1}, "shimmer.modules.global_workspace.GWPredictions.translations": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspace2Domains": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspace2Domains.forward": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspace": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspace.forward": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBayesian": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBayesian.forward": {"tf": 1}, "shimmer.modules.domain.LossOutput": {"tf": 1}, "shimmer.modules.domain.LossOutput.all": {"tf": 1}, "shimmer.modules.domain.DomainModule.compute_loss": {"tf": 1}, "shimmer.modules.domain.DomainModule.compute_dcy_loss": {"tf": 1}, "shimmer.modules.domain.DomainModule.compute_cy_loss": {"tf": 1}, "shimmer.modules.domain.DomainModule.compute_tr_loss": {"tf": 1}, "shimmer.modules.domain.DomainModule.compute_broadcast_loss": {"tf": 1}, "shimmer.modules.gw_module.get_n_layers": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBase": {"tf": 2}, "shimmer.modules.gw_module.GWModuleBase.encode_and_fuse": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBayesian.get_precision": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBayesian.fuse": {"tf": 2.23606797749979}, "shimmer.modules.selection.SelectionBase": {"tf": 1}, "shimmer.modules.selection.SelectionBase.update_gw_state": {"tf": 1}, "shimmer.modules.selection.FixedSharedSelection": {"tf": 1}, "shimmer.modules.selection.KQFixedQSelection.__init__": {"tf": 1}, "shimmer.modules.selection.KQFixedQSelection.forward": {"tf": 1.7320508075688772}, "shimmer.modules.selection.RandomSelection": {"tf": 1}, "shimmer.modules.selection.DynamicQueryAttention.__init__": {"tf": 1}, 
"shimmer.modules.selection.DynamicQueryAttention.forward": {"tf": 1.7320508075688772}, "shimmer.modules.losses.contrastive_loss": {"tf": 1.7320508075688772}, "shimmer.modules.losses.contrastive_loss_bayesian": {"tf": 1.7320508075688772}, "shimmer.modules.losses.LossCoefs": {"tf": 1}, "shimmer.modules.losses.GWLosses2Domains.__init__": {"tf": 1}, "shimmer.modules.losses.GWLosses2Domains.step": {"tf": 1}, "shimmer.modules.losses.generate_partitions": {"tf": 1.4142135623730951}, "shimmer.modules.losses.broadcast_loss": {"tf": 1.4142135623730951}, "shimmer.modules.losses.BroadcastLossCoefs": {"tf": 1}, "shimmer.modules.losses.BroadcastLossCoefs.fused": {"tf": 1}, "shimmer.modules.losses.GWLosses.step": {"tf": 1}, "shimmer.modules.losses.GWLossesBayesian.__init__": {"tf": 1}, "shimmer.modules.losses.GWLossesBayesian.step": {"tf": 1}, "shimmer.modules.contrastive_loss.ContrastiveLossType": {"tf": 1.4142135623730951}, "shimmer.modules.contrastive_loss.ContrastiveLossBayesianType": {"tf": 1.4142135623730951}, "shimmer.modules.vae.VAEEncoder.forward": {"tf": 1}, "shimmer.modules.vae.VAE.encode": {"tf": 1}, "shimmer.modules.vae.VAE.forward": {"tf": 1.4142135623730951}, "shimmer.utils.SaveMigrations": {"tf": 1}}, "df": 58}, "y": {"docs": {"shimmer.modules.global_workspace.GlobalWorkspaceBase.encode_domain": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.decode_domain": {"tf": 1}, "shimmer.modules.domain.DomainModule.encode": {"tf": 1}, "shimmer.modules.domain.DomainModule.decode": {"tf": 1}, "shimmer.modules.vae.VAEEncoder.forward": {"tf": 1}, "shimmer.modules.vae.VAEDecoder.forward": {"tf": 1}, "shimmer.modules.vae.VAE.encode": {"tf": 1}, "shimmer.modules.vae.VAE.decode": {"tf": 1}, "shimmer.modules.vae.VAE.forward": {"tf": 1.4142135623730951}, "shimmer.utils.SaveMigrations": {"tf": 1}}, "df": 10, "t": {"docs": {}, "df": 0, "h": {"docs": {}, "df": 0, "i": {"docs": {}, "df": 0, "n": {"docs": {}, "df": 0, "g": {"docs": 
{"shimmer.utils.SaveMigrations.on_save_checkpoint": {"tf": 1}}, "df": 1}}}}}}, "o": {"docs": {}, "df": 0, "t": {"docs": {}, "df": 0, "h": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "r": {"docs": {"shimmer.modules.global_workspace.GWPredictions.cycles": {"tf": 1}, "shimmer.modules.global_workspace.GWPredictions.translations": {"tf": 1}}, "df": 2}}}}}}, "l": {"docs": {}, "df": 0, "l": {"docs": {"shimmer.types.RawDomainGroupT": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.encode_domains": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.decode_domains": {"tf": 1}, "shimmer.modules.domain.LossOutput.all": {"tf": 1}, "shimmer.modules.gw_module.GWEncoder.forward": {"tf": 1}, "shimmer.modules.gw_module.GWEncoderLinear.forward": {"tf": 1}, "shimmer.modules.gw_module.GWModule.decode": {"tf": 1}, "shimmer.modules.losses.GWLossesBase.step": {"tf": 1}, "shimmer.modules.losses.demi_cycle_loss": {"tf": 1}, "shimmer.modules.losses.cycle_loss": {"tf": 1}, "shimmer.modules.losses.translation_loss": {"tf": 1}, "shimmer.modules.losses.contrastive_loss": {"tf": 1}, "shimmer.modules.losses.contrastive_loss_bayesian": {"tf": 1}, "shimmer.modules.losses.GWLosses2Domains.step": {"tf": 1}, "shimmer.modules.losses.generate_partitions": {"tf": 1.7320508075688772}, "shimmer.modules.losses.GWLosses.step": {"tf": 1}, "shimmer.modules.losses.GWLossesBayesian.step": {"tf": 1}}, "df": 17, "o": {"docs": {}, "df": 0, "w": {"docs": {"shimmer.types.RawDomainGroupT": {"tf": 1}, "shimmer.types.LatentsDomainGroupT": {"tf": 1}, "shimmer.types.LatentsDomainGroupsT": {"tf": 1}}, "df": 3, "s": {"docs": {"shimmer.types.RawDomainGroupT": {"tf": 1}, "shimmer.types.LatentsDomainGroupT": {"tf": 1}, "shimmer.types.LatentsDomainGroupsT": {"tf": 1}}, "df": 3}}}}, "i": {"docs": {}, "df": 0, "g": {"docs": {}, "df": 0, "n": {"docs": {}, "df": 0, "m": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "n": {"docs": {}, "df": 0, "t": {"docs": 
{"shimmer.modules.global_workspace.GlobalWorkspace2Domains.__init__": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspace.__init__": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBayesian.__init__": {"tf": 1}, "shimmer.modules.global_workspace.pretrained_global_workspace": {"tf": 1}}, "df": 4}}}}}}}, "t": {"docs": {}, "df": 0, "h": {"docs": {}, "df": 0, "o": {"docs": {}, "df": 0, "u": {"docs": {}, "df": 0, "g": {"docs": {}, "df": 0, "h": {"docs": {"shimmer.modules.gw_module.GWEncoder.forward": {"tf": 1}, "shimmer.modules.gw_module.GWEncoderLinear.forward": {"tf": 1}}, "df": 2}}}}}}, "w": {"docs": {}, "df": 0, "a": {"docs": {}, "df": 0, "y": {"docs": {}, "df": 0, "s": {"docs": {"shimmer.modules.selection.SingleDomainSelection": {"tf": 1}, "shimmer.modules.selection.FixedSharedSelection": {"tf": 1}, "shimmer.modules.losses.BroadcastLossCoefs.demi_cycles": {"tf": 1}}, "df": 3}}}}}, "b": {"docs": {}, "df": 0, "c": {"docs": {"shimmer.types.RawDomainGroupT": {"tf": 1}, "shimmer.types.LatentsDomainGroupT": {"tf": 1}, "shimmer.types.LatentsDomainGroupsT": {"tf": 1}}, "df": 3}, "s": {"docs": {}, "df": 0, "t": {"docs": {}, "df": 0, "r": {"docs": {}, "df": 0, "a": {"docs": {}, "df": 0, "c": {"docs": {}, "df": 0, "t": {"docs": {"shimmer.modules.gw_module.GWModuleBase": {"tf": 1}, "shimmer.modules.losses.GWLossesBase": {"tf": 1}, "shimmer.utils.SaveMigrations": {"tf": 1}}, "df": 3, "m": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "t": {"docs": {}, "df": 0, "h": {"docs": {}, "df": 0, "o": {"docs": {}, "df": 0, "d": {"docs": {"shimmer.modules.selection.SelectionBase.update_gw_state": {"tf": 1}}, "df": 1}}}}}}}}}}}}}, "s": {"docs": {"shimmer.modules.global_workspace.freeze_domain_modules": {"tf": 1}, "shimmer.modules.domain.LossOutput": {"tf": 1}, "shimmer.modules.selection.SelectionBase.update_gw_state": {"tf": 1.4142135623730951}}, "df": 3, "s": {"docs": {}, "df": 0, "o": {"docs": {}, "df": 0, "c": {"docs": {}, "df": 0, "i": {"docs": {}, "df": 0, "a": 
{"docs": {}, "df": 0, "t": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "d": {"docs": {"shimmer.modules.gw_module.GWModuleBayesian.fuse": {"tf": 1}}, "df": 1}}}}}}}, "u": {"docs": {}, "df": 0, "m": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "d": {"docs": {"shimmer.modules.selection.KQFixedQSelection.__init__": {"tf": 1}, "shimmer.modules.selection.DynamicQueryAttention.__init__": {"tf": 1}, "shimmer.modules.losses.LossCoefs": {"tf": 1}, "shimmer.modules.losses.BroadcastLossCoefs": {"tf": 1}}, "df": 4}}}}}}, "u": {"docs": {}, "df": 0, "t": {"docs": {}, "df": 0, "o": {"docs": {"shimmer.modules.global_workspace.freeze_domain_modules": {"tf": 1}}, "df": 1}}}, "c": {"docs": {}, "df": 0, "t": {"docs": {}, "df": 0, "u": {"docs": {}, "df": 0, "a": {"docs": {}, "df": 0, "l": {"docs": {}, "df": 0, "l": {"docs": {}, "df": 0, "y": {"docs": {"shimmer.modules.global_workspace.freeze_domain_modules": {"tf": 1}}, "df": 1}}}}}}, "r": {"docs": {}, "df": 0, "o": {"docs": {}, "df": 0, "s": {"docs": {}, "df": 0, "s": {"docs": {"shimmer.modules.selection.RandomSelection": {"tf": 1}}, "df": 1}}}}}, "f": {"docs": {}, "df": 0, "f": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "c": {"docs": {}, "df": 0, "t": {"docs": {"shimmer.modules.global_workspace.GlobalWorkspace2Domains.__init__": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspace.__init__": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBayesian.__init__": {"tf": 1}, "shimmer.modules.global_workspace.pretrained_global_workspace": {"tf": 1}}, "df": 4}}}}, "t": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "r": {"docs": {"shimmer.modules.gw_module.GWDecoder.__init__": {"tf": 1}, "shimmer.modules.gw_module.GWDecoder.n_layers": {"tf": 1}, "shimmer.modules.gw_module.GWEncoder.__init__": {"tf": 1}, "shimmer.modules.utils.cycle": {"tf": 1}, "shimmer.utils.migrate_model": {"tf": 1}}, "df": 5, "w": {"docs": {}, "df": 0, "a": {"docs": {}, "df": 0, "r": {"docs": {}, "df": 0, "d": {"docs": {}, "df": 0, "s": 
{"docs": {"shimmer.modules.gw_module.GWEncoder.forward": {"tf": 1}, "shimmer.modules.gw_module.GWEncoderLinear.forward": {"tf": 1}}, "df": 2}}}}}}}}}, "d": {"docs": {}, "df": 0, "d": {"docs": {}, "df": 0, "i": {"docs": {}, "df": 0, "t": {"docs": {}, "df": 0, "i": {"docs": {}, "df": 0, "o": {"docs": {}, "df": 0, "n": {"docs": {}, "df": 0, "a": {"docs": {}, "df": 0, "l": {"docs": {"shimmer.modules.global_workspace.pretrained_global_workspace": {"tf": 1}, "shimmer.modules.domain.LossOutput.metrics": {"tf": 1}, "shimmer.modules.domain.DomainModule.compute_loss": {"tf": 1}, "shimmer.modules.domain.DomainModule.compute_dcy_loss": {"tf": 1}, "shimmer.modules.domain.DomainModule.compute_cy_loss": {"tf": 1}, "shimmer.modules.domain.DomainModule.compute_tr_loss": {"tf": 1}, "shimmer.modules.domain.DomainModule.compute_broadcast_loss": {"tf": 1}, "shimmer.modules.losses.demi_cycle_loss": {"tf": 1}, "shimmer.modules.losses.cycle_loss": {"tf": 1}, "shimmer.modules.losses.translation_loss": {"tf": 1}, "shimmer.modules.losses.contrastive_loss": {"tf": 1}, "shimmer.modules.losses.contrastive_loss_bayesian": {"tf": 1}, "shimmer.modules.losses.broadcast_loss": {"tf": 1}, "shimmer.utils.migrate_model": {"tf": 1}}, "df": 14}}}}}}}, "s": {"docs": {"shimmer.modules.gw_module.GWEncoder": {"tf": 1}}, "df": 1}, "e": {"docs": {}, "df": 0, "d": {"docs": {"shimmer.modules.gw_module.compute_fusion_scores": {"tf": 1}}, "df": 1}}}}, "t": {"docs": {"shimmer.modules.gw_module.GWEncoder": {"tf": 1}, "shimmer.modules.gw_module.GWEncoder.forward": {"tf": 1}, "shimmer.modules.gw_module.GWEncoderLinear.forward": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBayesian.precisions": {"tf": 1}, "shimmer.dataset.RepeatedDataset": {"tf": 1}}, "df": 5, "t": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "n": {"docs": {}, "df": 0, "t": {"docs": {}, "df": 0, "i": {"docs": {}, "df": 0, "o": {"docs": {}, "df": 0, "n": {"docs": {"shimmer.modules.gw_module.GWModuleBase.fuse": {"tf": 1}, 
"shimmer.modules.gw_module.GWModuleBase.encode_and_fuse": {"tf": 1}, "shimmer.modules.gw_module.GWModule.fuse": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBayesian.fuse": {"tf": 1}, "shimmer.modules.selection.KQFixedQSelection": {"tf": 1}, "shimmer.modules.selection.KQFixedQSelection.forward": {"tf": 1}, "shimmer.modules.selection.RandomSelection": {"tf": 1.4142135623730951}, "shimmer.modules.selection.RandomSelection.forward": {"tf": 1}, "shimmer.modules.selection.DynamicQueryAttention": {"tf": 1}, "shimmer.modules.selection.DynamicQueryAttention.fuse_weighted_encodings": {"tf": 1.7320508075688772}, "shimmer.modules.selection.DynamicQueryAttention.forward": {"tf": 1}}, "df": 11}}}}}}}}, "v": {"docs": {}, "df": 0, "o": {"docs": {}, "df": 0, "i": {"docs": {}, "df": 0, "d": {"docs": {"shimmer.modules.gw_module.compute_fusion_scores": {"tf": 1}}, "df": 1}}}, "e": {"docs": {}, "df": 0, "r": {"docs": {}, "df": 0, "a": {"docs": {}, "df": 0, "g": {"docs": {}, "df": 0, "e": {"docs": {"shimmer.modules.losses.demi_cycle_loss": {"tf": 1}, "shimmer.modules.losses.cycle_loss": {"tf": 1}, "shimmer.modules.losses.translation_loss": {"tf": 1}, "shimmer.modules.losses.contrastive_loss": {"tf": 1}, "shimmer.modules.losses.contrastive_loss_bayesian": {"tf": 1}}, "df": 5}}}}}}, "\\": {"docs": {}, "df": 0, "i": {"docs": {}, "df": 0, "n": {"docs": {"shimmer.modules.gw_module.GWModuleBayesian.fuse": {"tf": 1}}, "df": 1}}}, "p": {"docs": {}, "df": 0, "p": {"docs": {}, "df": 0, "l": {"docs": {}, "df": 0, "i": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "d": {"docs": {"shimmer.modules.selection.RandomSelection.__init__": {"tf": 1}}, "df": 1}}}, "y": {"docs": {"shimmer.modules.losses.contrastive_loss": {"tf": 1}, "shimmer.modules.losses.contrastive_loss_bayesian": {"tf": 1}, "shimmer.modules.contrastive_loss.info_nce": {"tf": 1}, "shimmer.modules.contrastive_loss.contrastive_loss": {"tf": 1}, "shimmer.modules.contrastive_loss.ContrastiveLoss.__init__": {"tf": 1}}, "df": 5}}}}}, "n": 
{"docs": {"shimmer.modules.gw_module.get_n_layers": {"tf": 1.4142135623730951}, "shimmer.modules.gw_module.GWDecoder.__init__": {"tf": 1.4142135623730951}, "shimmer.modules.gw_module.GWDecoder.n_layers": {"tf": 1}, "shimmer.modules.gw_module.GWEncoder.__init__": {"tf": 1.4142135623730951}, "shimmer.modules.gw_module.GWModuleBayesian.fuse": {"tf": 1.7320508075688772}, "shimmer.modules.losses.generate_partitions": {"tf": 1.4142135623730951}, "shimmer.cli.ckpt_migration.migrate_ckpt": {"tf": 1}}, "df": 7, "a": {"docs": {}, "df": 0, "m": {"docs": {}, "df": 0, "e": {"docs": {"shimmer.modules.global_workspace.GWPredictionsBase.states": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.encode_domain": {"tf": 1.7320508075688772}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.decode_domain": {"tf": 1.7320508075688772}, "shimmer.modules.global_workspace.GlobalWorkspace2Domains.__init__": {"tf": 1.4142135623730951}, "shimmer.modules.global_workspace.GlobalWorkspace.__init__": {"tf": 1.4142135623730951}, "shimmer.modules.global_workspace.GlobalWorkspaceBayesian.__init__": {"tf": 1.4142135623730951}, "shimmer.modules.global_workspace.pretrained_global_workspace": {"tf": 1.4142135623730951}, "shimmer.modules.gw_module.GWModule.__init__": {"tf": 1.4142135623730951}, "shimmer.modules.gw_module.GWModuleBayesian.__init__": {"tf": 1.4142135623730951}, "shimmer.modules.losses.demi_cycle_loss": {"tf": 1.7320508075688772}, "shimmer.modules.losses.GWLosses2Domains.__init__": {"tf": 1}, "shimmer.modules.losses.GWLossesBayesian.__init__": {"tf": 1}, "shimmer.modules.utils.translation": {"tf": 1}, "shimmer.modules.utils.cycle": {"tf": 1}}, "df": 14, "s": {"docs": {"shimmer.types.RawDomainGroupT": {"tf": 1}, "shimmer.types.RawDomainGroupDT": {"tf": 1}, "shimmer.types.LatentsDomainGroupT": {"tf": 1}, "shimmer.types.LatentsDomainGroupDT": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspace2Domains.__init__": {"tf": 1}, 
"shimmer.modules.global_workspace.GlobalWorkspace.__init__": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBayesian.__init__": {"tf": 1}, "shimmer.modules.global_workspace.pretrained_global_workspace": {"tf": 1}, "shimmer.modules.selection.KQFixedQSelection.__init__": {"tf": 1}, "shimmer.modules.selection.DynamicQueryAttention.__init__": {"tf": 1}, "shimmer.modules.losses.GWLosses.__init__": {"tf": 1}, "shimmer.modules.utils.batch_cycles": {"tf": 1}}, "df": 12}}}}, "o": {"docs": {}, "df": 0, "t": {"docs": {"shimmer.modules.global_workspace.GlobalWorkspace2Domains.__init__": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspace.__init__": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBayesian.__init__": {"tf": 1}, "shimmer.modules.global_workspace.pretrained_global_workspace": {"tf": 1.4142135623730951}, "shimmer.modules.domain.LossOutput.metrics": {"tf": 1}, "shimmer.modules.selection.SelectionBase.update_gw_state": {"tf": 1.7320508075688772}, "shimmer.modules.selection.RandomSelection.forward": {"tf": 1}, "shimmer.modules.losses.LossCoefs": {"tf": 1.7320508075688772}, "shimmer.modules.losses.BroadcastLossCoefs": {"tf": 1.7320508075688772}}, "df": 9, "e": {"docs": {"shimmer.types.RawDomainGroupT": {"tf": 1}, "shimmer.types.RawDomainGroupDT": {"tf": 1}, "shimmer.types.LatentsDomainGroupT": {"tf": 1}, "shimmer.types.LatentsDomainGroupDT": {"tf": 1}, "shimmer.types.RawDomainGroupsT": {"tf": 1}, "shimmer.types.RawDomainGroupsDT": {"tf": 1}, "shimmer.types.LatentsDomainGroupsT": {"tf": 1}, "shimmer.types.LatentsDomainGroupsDT": {"tf": 1}, "shimmer.modules.selection.SelectionBase.update_gw_state": {"tf": 1}}, "df": 9}}, "n": {"docs": {"shimmer.modules.gw_module.GWEncoder": {"tf": 1}}, "df": 1, "e": {"docs": {"shimmer.modules.global_workspace.GlobalWorkspace2Domains.__init__": {"tf": 1.4142135623730951}, "shimmer.modules.global_workspace.GlobalWorkspace.__init__": {"tf": 1.4142135623730951}, 
"shimmer.modules.global_workspace.GlobalWorkspaceBayesian.__init__": {"tf": 1.4142135623730951}, "shimmer.modules.gw_module.GWModule.decode": {"tf": 1}, "shimmer.modules.contrastive_loss.info_nce": {"tf": 1}, "shimmer.modules.contrastive_loss.contrastive_loss": {"tf": 1}, "shimmer.modules.contrastive_loss.ContrastiveLoss.__init__": {"tf": 1}}, "df": 7}}, "r": {"docs": {}, "df": 0, "m": {"docs": {}, "df": 0, "a": {"docs": {}, "df": 0, "l": {"docs": {"shimmer.modules.vae.reparameterize": {"tf": 1}}, "df": 1, "i": {"docs": {}, "df": 0, "z": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "d": {"docs": {"shimmer.modules.global_workspace.GlobalWorkspaceBayesian.__init__": {"tf": 1.4142135623730951}, "shimmer.modules.losses.GWLossesBayesian.__init__": {"tf": 1.4142135623730951}}, "df": 2}}}}}}}}, "w": {"docs": {"shimmer.modules.selection.KQFixedQSelection.__init__": {"tf": 1}, "shimmer.modules.selection.DynamicQueryAttention.__init__": {"tf": 1}}, "df": 2}}, "u": {"docs": {}, "df": 0, "m": {"docs": {}, "df": 0, "b": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "r": {"docs": {"shimmer.modules.global_workspace.SchedulerArgs.total_steps": {"tf": 1}, "shimmer.modules.gw_module.get_n_layers": {"tf": 1}, "shimmer.modules.gw_module.GWDecoder.__init__": {"tf": 1.4142135623730951}, "shimmer.modules.gw_module.GWDecoder.n_layers": {"tf": 1.4142135623730951}, "shimmer.modules.gw_module.GWEncoder.__init__": {"tf": 1.4142135623730951}, "shimmer.modules.gw_module.GWModuleBayesian.fuse": {"tf": 1}, "shimmer.modules.losses.generate_partitions": {"tf": 1}}, "df": 7}}}, "e": {"docs": {}, "df": 0, "r": {"docs": {}, "df": 0, "i": {"docs": {}, "df": 0, "c": {"docs": {}, "df": 0, "a": {"docs": {}, "df": 0, "l": {"docs": {"shimmer.modules.gw_module.compute_fusion_scores": {"tf": 1}}, "df": 1}}}}}}}}, "n": {"docs": {"shimmer.modules.global_workspace.GlobalWorkspace2Domains.__init__": {"tf": 2}, "shimmer.modules.global_workspace.GlobalWorkspace.__init__": {"tf": 2}, 
"shimmer.modules.global_workspace.GlobalWorkspaceBayesian.__init__": {"tf": 2}, "shimmer.modules.global_workspace.pretrained_global_workspace": {"tf": 2}, "shimmer.modules.gw_module.get_n_layers": {"tf": 1.4142135623730951}, "shimmer.modules.gw_module.GWModule": {"tf": 1}, "shimmer.modules.gw_module.GWModule.__init__": {"tf": 2}, "shimmer.modules.gw_module.GWModuleBayesian.__init__": {"tf": 2}}, "df": 8}, "e": {"docs": {}, "df": 0, "t": {"docs": {}, "df": 0, "w": {"docs": {}, "df": 0, "o": {"docs": {}, "df": 0, "r": {"docs": {}, "df": 0, "k": {"docs": {"shimmer.modules.gw_module.GWDecoder": {"tf": 1}, "shimmer.modules.gw_module.GWEncoder": {"tf": 1}, "shimmer.modules.gw_module.GWEncoderLinear": {"tf": 1}}, "df": 3}}}}}, "e": {"docs": {}, "df": 0, "d": {"docs": {"shimmer.modules.selection.SelectionBase.update_gw_state": {"tf": 1}}, "df": 1, "s": {"docs": {"shimmer.modules.gw_module.GWEncoder.forward": {"tf": 1}, "shimmer.modules.gw_module.GWEncoderLinear.forward": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBayesian.fuse": {"tf": 1}}, "df": 3}}}, "u": {"docs": {}, "df": 0, "r": {"docs": {}, "df": 0, "o": {"docs": {}, "df": 0, "n": {"docs": {"shimmer.modules.gw_module.GWModuleBayesian.precisions": {"tf": 1}}, "df": 1}}}}, "w": {"docs": {"shimmer.utils.SaveMigrations": {"tf": 1}}, "df": 1}}, "l": {"docs": {}, "df": 0, "l": {"docs": {"shimmer.modules.vae.gaussian_nll": {"tf": 1.4142135623730951}}, "df": 1}}}, "v": {"docs": {"shimmer.modules.selection.SelectionBase.forward": {"tf": 1.4142135623730951}}, "df": 1, "a": {"docs": {}, "df": 0, "l": {"docs": {"shimmer.modules.losses.GWLossesBase.step": {"tf": 1}}, "df": 1, "u": {"docs": {}, "df": 0, "e": {"docs": {"shimmer.modules.global_workspace.GlobalWorkspace.__init__": {"tf": 1}, "shimmer.modules.domain.LossOutput": {"tf": 1}, "shimmer.modules.gw_module.compute_fusion_scores": {"tf": 1}, "shimmer.modules.losses.demi_cycle_loss": {"tf": 1}, "shimmer.modules.losses.cycle_loss": {"tf": 1}, 
"shimmer.modules.losses.translation_loss": {"tf": 1}, "shimmer.modules.losses.contrastive_loss": {"tf": 1}, "shimmer.modules.losses.contrastive_loss_bayesian": {"tf": 1}, "shimmer.modules.losses.GWLosses2Domains.__init__": {"tf": 1}, "shimmer.modules.losses.GWLossesBayesian.__init__": {"tf": 1}, "shimmer.modules.vae.VAEEncoder.forward": {"tf": 1}, "shimmer.modules.vae.VAE.__init__": {"tf": 1}, "shimmer.modules.vae.VAE.beta": {"tf": 1}, "shimmer.modules.vae.VAE.encode": {"tf": 1}, "shimmer.modules.vae.VAE.decode": {"tf": 1}}, "df": 15, "s": {"docs": {"shimmer.types.RawDomainGroupT": {"tf": 1.4142135623730951}, "shimmer.types.RawDomainGroupDT": {"tf": 1}, "shimmer.types.LatentsDomainGroupT": {"tf": 1}, "shimmer.types.LatentsDomainGroupDT": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspace2Domains.__init__": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspace.__init__": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBayesian.__init__": {"tf": 1}, "shimmer.modules.global_workspace.pretrained_global_workspace": {"tf": 1}, "shimmer.modules.losses.demi_cycle_loss": {"tf": 1}, "shimmer.modules.losses.cycle_loss": {"tf": 1}, "shimmer.modules.losses.translation_loss": {"tf": 1}, "shimmer.modules.losses.contrastive_loss": {"tf": 1}, "shimmer.modules.losses.contrastive_loss_bayesian": {"tf": 1}}, "df": 13}}}, "i": {"docs": {}, "df": 0, "d": {"docs": {}, "df": 0, "a": {"docs": {}, "df": 0, "t": {"docs": {}, "df": 0, "i": {"docs": {}, "df": 0, "n": {"docs": {}, "df": 0, "g": {"docs": {"shimmer.types.ModelModeT": {"tf": 1}}, "df": 1}}, "o": {"docs": {}, "df": 0, "n": {"docs": {"shimmer.modules.global_workspace.GlobalWorkspaceBase.generic_step": {"tf": 1}}, "df": 1}}}}}}}, "/": {"docs": {}, "df": 0, "o": {"docs": {}, "df": 0, "o": {"docs": {}, "df": 0, "d": {"docs": {"shimmer.types.ModelModeT": {"tf": 1}, "shimmer.modules.losses.GWLossesBase.step": {"tf": 1}}, "df": 2}}}}}, "r": {"docs": {}, "df": 0, "i": {"docs": {}, "df": 0, "a": {"docs": {}, 
"df": 0, "b": {"docs": {}, "df": 0, "l": {"docs": {}, "df": 0, "e": {"docs": {"shimmer.modules.gw_module.GWModuleBayesian.fuse": {"tf": 1}}, "df": 1}}}, "n": {"docs": {}, "df": 0, "c": {"docs": {}, "df": 0, "e": {"docs": {"shimmer.modules.vae.reparameterize": {"tf": 1}, "shimmer.modules.vae.VAEEncoder.forward": {"tf": 1}}, "df": 2, "s": {"docs": {"shimmer.modules.gw_module.GWModuleBayesian.fuse": {"tf": 1}}, "df": 1}}}}}, "o": {"docs": {}, "df": 0, "u": {"docs": {}, "df": 0, "s": {"docs": {"shimmer.modules.contrastive_loss": {"tf": 1}}, "df": 1}}}}}, "e": {"docs": {"shimmer.modules.vae.reparameterize": {"tf": 1}, "shimmer.modules.vae.kl_divergence_loss": {"tf": 1}, "shimmer.modules.vae.gaussian_nll": {"tf": 1}, "shimmer.modules.vae.VAEEncoder": {"tf": 1}, "shimmer.modules.vae.VAEEncoder.forward": {"tf": 1}, "shimmer.modules.vae.VAEDecoder": {"tf": 1}, "shimmer.modules.vae.VAEDecoder.forward": {"tf": 1.4142135623730951}, "shimmer.modules.vae.VAE": {"tf": 1}, "shimmer.modules.vae.VAE.__init__": {"tf": 2}, "shimmer.modules.vae.VAE.encode": {"tf": 1}, "shimmer.modules.vae.VAE.decode": {"tf": 1.4142135623730951}}, "df": 11, "e": {"docs": {}, "df": 0, "n": {"docs": {}, "df": 0, "c": {"docs": {}, "df": 0, "o": {"docs": {}, "df": 0, "d": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "r": {"docs": {"shimmer.modules.vae.VAE.__init__": {"tf": 1}}, "df": 1}}}}}}}, "d": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "c": {"docs": {}, "df": 0, "o": {"docs": {}, "df": 0, "d": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "r": {"docs": {"shimmer.modules.vae.VAE.__init__": {"tf": 1}}, "df": 1}}}}}}}, "s": {"docs": {"shimmer.modules.vae.VAE.beta": {"tf": 1}}, "df": 1}}}, "i": {"docs": {}, "df": 0, "s": {"docs": {}, "df": 0, "i": {"docs": {}, "df": 0, "o": {"docs": {}, "df": 0, "n": {"docs": {"shimmer.types.RawDomainGroupT": {"tf": 1}, "shimmer.types.RawDomainGroupDT": {"tf": 1}, "shimmer.types.LatentsDomainGroupT": {"tf": 1}, "shimmer.types.LatentsDomainGroupDT": {"tf": 1}, 
"shimmer.types.RawDomainGroupsT": {"tf": 2}, "shimmer.types.RawDomainGroupsDT": {"tf": 2}, "shimmer.types.LatentsDomainGroupsT": {"tf": 2}, "shimmer.types.LatentsDomainGroupsDT": {"tf": 2}}, "df": 8}}}}}, "e": {"docs": {}, "df": 0, "r": {"docs": {}, "df": 0, "s": {"docs": {}, "df": 0, "i": {"docs": {}, "df": 0, "o": {"docs": {}, "df": 0, "n": {"docs": {"shimmer.types.RawDomainGroupDT": {"tf": 1}, "shimmer.modules.selection.RandomSelection": {"tf": 1}, "shimmer.utils.migrate_model": {"tf": 1}}, "df": 3, "s": {"docs": {"shimmer.utils.migrate_model": {"tf": 1}}, "df": 1}}}}}}, "c": {"docs": {}, "df": 0, "t": {"docs": {}, "df": 0, "o": {"docs": {}, "df": 0, "r": {"docs": {"shimmer.modules.gw_module.GWModuleBayesian.get_precision": {"tf": 1}, "shimmer.modules.selection.KQFixedQSelection": {"tf": 1}, "shimmer.modules.selection.DynamicQueryAttention": {"tf": 1.4142135623730951}}, "df": 3, "s": {"docs": {"shimmer.modules.selection.KQFixedQSelection.__init__": {"tf": 1}, "shimmer.modules.selection.DynamicQueryAttention.__init__": {"tf": 1}}, "df": 2}}}}}}}, "i": {"docs": {"shimmer.modules.gw_module.GWModuleBayesian.fuse": {"tf": 2.23606797749979}}, "df": 1, "n": {"docs": {"shimmer.types.RawDomainGroupT": {"tf": 1}, "shimmer.types.RawDomainGroupsT": {"tf": 1}, "shimmer.types.RawDomainGroupsDT": {"tf": 1}, "shimmer.types.LatentsDomainGroupsT": {"tf": 1}, "shimmer.types.LatentsDomainGroupsDT": {"tf": 1}, "shimmer.types.ModelModeT": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.encode_domains": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.decode_domains": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.generic_step": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBayesian.__init__": {"tf": 1}, "shimmer.modules.gw_module.GWDecoder.__init__": {"tf": 1}, "shimmer.modules.gw_module.GWEncoder": {"tf": 1}, "shimmer.modules.gw_module.GWEncoder.__init__": {"tf": 1}, "shimmer.modules.gw_module.GWEncoderLinear": {"tf": 
1}, "shimmer.modules.gw_module.GWModuleBase": {"tf": 1}, "shimmer.modules.gw_module.GWModule.decode": {"tf": 1}, "shimmer.modules.gw_module.compute_fusion_scores": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBayesian.__init__": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBayesian.fuse": {"tf": 1.7320508075688772}, "shimmer.modules.selection.SelectionBase.forward": {"tf": 1.4142135623730951}, "shimmer.modules.selection.SingleDomainSelection.forward": {"tf": 1}, "shimmer.modules.selection.FixedSharedSelection.forward": {"tf": 1}, "shimmer.modules.selection.KQFixedQSelection.forward": {"tf": 1}, "shimmer.modules.selection.RandomSelection.forward": {"tf": 1.7320508075688772}, "shimmer.modules.selection.DynamicQueryAttention.fuse_weighted_encodings": {"tf": 1}, "shimmer.modules.selection.DynamicQueryAttention.forward": {"tf": 1}, "shimmer.modules.losses.LossCoefs": {"tf": 1.4142135623730951}, "shimmer.modules.losses.GWLosses2Domains.__init__": {"tf": 1}, "shimmer.modules.losses.BroadcastLossCoefs": {"tf": 1.4142135623730951}, "shimmer.modules.losses.GWLosses.step": {"tf": 1}, "shimmer.modules.losses.GWLossesBayesian.__init__": {"tf": 1}, "shimmer.modules.losses.GWLossesBayesian.step": {"tf": 1}, "shimmer.modules.vae.kl_divergence_loss": {"tf": 1}, "shimmer.modules.vae.gaussian_nll": {"tf": 1}}, "df": 34, "f": {"docs": {}, "df": 0, "o": {"docs": {}, "df": 0, "r": {"docs": {}, "df": 0, "m": {"docs": {}, "df": 0, "a": {"docs": {}, "df": 0, "t": {"docs": {}, "df": 0, "i": {"docs": {}, "df": 0, "o": {"docs": {}, "df": 0, "n": {"docs": {"shimmer.types.RawDomainGroupT": {"tf": 1}}, "df": 1}}}}}}}, "s": {"docs": {"shimmer.modules.gw_module.GWModuleBase.encode": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBase.encode_and_fuse": {"tf": 1}, "shimmer.modules.gw_module.GWModule.encode": {"tf": 1}}, "df": 3}, "n": {"docs": {}, "df": 0, "c": {"docs": {}, "df": 0, "e": {"docs": {"shimmer.modules.contrastive_loss.info_nce": {"tf": 1.4142135623730951}}, "df": 1}}}}}, "p": {"docs": 
{}, "df": 0, "u": {"docs": {}, "df": 0, "t": {"docs": {"shimmer.modules.global_workspace.GlobalWorkspaceBase.encode_and_fuse": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.encode": {"tf": 1}, "shimmer.modules.gw_module.GWDecoder.__init__": {"tf": 1}, "shimmer.modules.gw_module.GWDecoder.in_dim": {"tf": 1}, "shimmer.modules.gw_module.GWEncoder.__init__": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBase.encode": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBase.encode_and_fuse": {"tf": 1}, "shimmer.modules.gw_module.GWModule.encode": {"tf": 1}, "shimmer.modules.selection.SingleDomainSelection.forward": {"tf": 1.4142135623730951}, "shimmer.modules.selection.FixedSharedSelection.forward": {"tf": 1.4142135623730951}, "shimmer.modules.selection.KQFixedQSelection.__init__": {"tf": 1.4142135623730951}, "shimmer.modules.selection.DynamicQueryAttention.__init__": {"tf": 1.4142135623730951}, "shimmer.modules.vae.VAEEncoder.forward": {"tf": 1}, "shimmer.modules.vae.VAEDecoder.forward": {"tf": 1}, "shimmer.modules.vae.VAE.encode": {"tf": 1}, "shimmer.modules.vae.VAE.decode": {"tf": 1.4142135623730951}, "shimmer.modules.vae.VAE.forward": {"tf": 1.7320508075688772}}, "df": 17, "s": {"docs": {"shimmer.types.RawDomainGroupT": {"tf": 1}, "shimmer.types.RawDomainGroupDT": {"tf": 1}, "shimmer.types.LatentsDomainGroupT": {"tf": 1}, "shimmer.types.LatentsDomainGroupDT": {"tf": 1}, "shimmer.types.RawDomainGroupsT": {"tf": 1}, "shimmer.types.RawDomainGroupsDT": {"tf": 1}, "shimmer.types.LatentsDomainGroupsT": {"tf": 1}, "shimmer.types.LatentsDomainGroupsDT": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBayesian.get_precision": {"tf": 1}}, "df": 9}}}}, "d": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "p": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "n": {"docs": {}, "df": 0, "d": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "n": {"docs": {}, "df": 0, "t": {"docs": {"shimmer.types.RawDomainGroupsT": {"tf": 1}, "shimmer.types.RawDomainGroupsDT": {"tf": 
1}, "shimmer.types.LatentsDomainGroupsT": {"tf": 1}, "shimmer.types.LatentsDomainGroupsDT": {"tf": 1}}, "df": 4}}}}}}}}}, "t": {"docs": {"shimmer.modules.global_workspace.GlobalWorkspace2Domains.__init__": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspace.__init__": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBayesian.__init__": {"tf": 1}, "shimmer.modules.global_workspace.pretrained_global_workspace": {"tf": 1}, "shimmer.modules.domain.DomainModule.__init__": {"tf": 1}, "shimmer.modules.gw_module.get_n_layers": {"tf": 1.4142135623730951}, "shimmer.modules.gw_module.GWDecoder.__init__": {"tf": 2}, "shimmer.modules.gw_module.GWEncoder.__init__": {"tf": 2}, "shimmer.modules.gw_module.GWModuleBase.__init__": {"tf": 1}, "shimmer.modules.gw_module.GWModule.__init__": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBayesian.__init__": {"tf": 1}, "shimmer.modules.selection.KQFixedQSelection.__init__": {"tf": 1.4142135623730951}, "shimmer.modules.selection.DynamicQueryAttention.__init__": {"tf": 1.4142135623730951}, "shimmer.modules.losses.generate_partitions": {"tf": 1}, "shimmer.dataset.RepeatedDataset.__init__": {"tf": 1}, "shimmer.utils.groups_batch_size": {"tf": 1}, "shimmer.utils.groups_device": {"tf": 1}}, "df": 17, "o": {"docs": {"shimmer.modules.global_workspace.GlobalWorkspaceBase.encode_and_fuse": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.encode": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.fuse": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.decode": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.encode_domain": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.encode_domains": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.decode_domain": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.decode_domains": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspace2Domains.__init__": {"tf": 
1.4142135623730951}, "shimmer.modules.global_workspace.GlobalWorkspace.__init__": {"tf": 1.4142135623730951}, "shimmer.modules.global_workspace.GlobalWorkspaceBayesian.__init__": {"tf": 1.4142135623730951}, "shimmer.modules.global_workspace.pretrained_global_workspace": {"tf": 1.4142135623730951}, "shimmer.modules.domain.DomainModule.encode": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBase.decode": {"tf": 1}, "shimmer.modules.gw_module.GWModule.__init__": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBayesian.__init__": {"tf": 1}, "shimmer.modules.vae.VAE.decode": {"tf": 1}}, "df": 17}, "e": {"docs": {}, "df": 0, "r": {"docs": {}, "df": 0, "m": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "d": {"docs": {}, "df": 0, "i": {"docs": {}, "df": 0, "a": {"docs": {}, "df": 0, "r": {"docs": {}, "df": 0, "y": {"docs": {"shimmer.modules.global_workspace.GWPredictions.cycles": {"tf": 1}, "shimmer.modules.utils.batch_cycles": {"tf": 1}}, "df": 2}}}}}}}, "f": {"docs": {}, "df": 0, "a": {"docs": {}, "df": 0, "c": {"docs": {}, "df": 0, "e": {"docs": {"shimmer.modules.gw_module.GWModuleBase": {"tf": 1}}, "df": 1, "s": {"docs": {"shimmer.modules.gw_module.GWModule.decode": {"tf": 1}}, "df": 1}}}}}, "n": {"docs": {}, "df": 0, "a": {"docs": {}, "df": 0, "l": {"docs": {"shimmer.modules.selection.SelectionBase.update_gw_state": {"tf": 1}}, "df": 1, "l": {"docs": {}, "df": 0, "y": {"docs": {"shimmer.cli.ckpt_migration.migrate_ckpt": {"tf": 1}}, "df": 1}}}}}}, "g": {"docs": {}, "df": 0, "r": {"docs": {}, "df": 0, "a": {"docs": {}, "df": 0, "t": {"docs": {}, "df": 0, "i": {"docs": {}, "df": 0, "o": {"docs": {}, "df": 0, "n": {"docs": {"shimmer.modules.gw_module.GWModuleBayesian.fuse": {"tf": 1}}, "df": 1}}}}}}}}, "r": {"docs": {}, "df": 0, "o": {"docs": {}, "df": 0, "d": {"docs": {}, "df": 0, "u": {"docs": {}, "df": 0, "c": {"docs": {}, "df": 0, "i": {"docs": {}, "df": 0, "n": {"docs": {}, "df": 0, "g": {"docs": {"shimmer.modules.gw_module.GWModuleBayesian.fuse": {"tf": 1}}, "df": 
1}}}}}}}}}, "s": {"docs": {}, "df": 0, "t": {"docs": {}, "df": 0, "a": {"docs": {}, "df": 0, "n": {"docs": {}, "df": 0, "c": {"docs": {}, "df": 0, "i": {"docs": {}, "df": 0, "a": {"docs": {}, "df": 0, "t": {"docs": {}, "df": 0, "i": {"docs": {}, "df": 0, "o": {"docs": {}, "df": 0, "n": {"docs": {"shimmer.modules.global_workspace.GlobalWorkspace2Domains": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspace": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBayesian": {"tf": 1}}, "df": 3}}}}}}, "e": {"docs": {"shimmer.modules.gw_module.GWEncoder.forward": {"tf": 1}, "shimmer.modules.gw_module.GWEncoderLinear.forward": {"tf": 1}, "shimmer.utils.SaveMigrations.on_save_checkpoint": {"tf": 1.4142135623730951}}, "df": 3}}}}, "e": {"docs": {}, "df": 0, "a": {"docs": {}, "df": 0, "d": {"docs": {"shimmer.modules.gw_module.GWEncoder.forward": {"tf": 1}, "shimmer.modules.gw_module.GWEncoderLinear.forward": {"tf": 1}}, "df": 2}}}}}, "i": {"docs": {}, "df": 0, "t": {"docs": {"shimmer.modules.global_workspace.GlobalWorkspace2Domains": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspace": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBayesian": {"tf": 1}}, "df": 3, "i": {"docs": {}, "df": 0, "a": {"docs": {}, "df": 0, "l": {"docs": {}, "df": 0, "i": {"docs": {}, "df": 0, "z": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "s": {"docs": {"shimmer.modules.global_workspace.GlobalWorkspace2Domains.__init__": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspace.__init__": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBayesian.__init__": {"tf": 1}, "shimmer.modules.domain.DomainModule.__init__": {"tf": 1}, "shimmer.modules.gw_module.GWDecoder.__init__": {"tf": 1}, "shimmer.modules.gw_module.GWEncoder.__init__": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBase.__init__": {"tf": 1}, "shimmer.modules.gw_module.GWModule.__init__": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBayesian.__init__": {"tf": 1}, 
"shimmer.modules.losses.GWLosses.__init__": {"tf": 1}, "shimmer.modules.contrastive_loss.ContrastiveLoss.__init__": {"tf": 1}, "shimmer.modules.vae.VAE.__init__": {"tf": 1}}, "df": 12}}}}}}}}}, "c": {"docs": {}, "df": 0, "l": {"docs": {}, "df": 0, "u": {"docs": {}, "df": 0, "d": {"docs": {}, "df": 0, "i": {"docs": {}, "df": 0, "n": {"docs": {}, "df": 0, "g": {"docs": {"shimmer.modules.losses.broadcast_loss": {"tf": 1}}, "df": 1}}}}}}}}, "m": {"docs": {}, "df": 0, "a": {"docs": {}, "df": 0, "g": {"docs": {}, "df": 0, "e": {"docs": {"shimmer.types.RawDomainGroupT": {"tf": 1.4142135623730951}, "shimmer.types.RawDomainGroupDT": {"tf": 1.4142135623730951}, "shimmer.types.RawDomainGroupsT": {"tf": 2}, "shimmer.types.RawDomainGroupsDT": {"tf": 2}}, "df": 4}}}, "p": {"docs": {}, "df": 0, "l": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "m": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "n": {"docs": {}, "df": 0, "t": {"docs": {}, "df": 0, "a": {"docs": {}, "df": 0, "t": {"docs": {}, "df": 0, "i": {"docs": {}, "df": 0, "o": {"docs": {}, "df": 0, "n": {"docs": {"shimmer.modules.global_workspace.GlobalWorkspaceBase.gw_mod": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.selection_mod": {"tf": 1}, "shimmer.modules.losses.GWLosses2Domains": {"tf": 1}, "shimmer.modules.losses.GWLosses": {"tf": 1}, "shimmer.modules.losses.GWLossesBayesian": {"tf": 1}}, "df": 5}}}}}, "e": {"docs": {}, "df": 0, "d": {"docs": {"shimmer.modules.gw_module.GWModuleBase": {"tf": 1.4142135623730951}, "shimmer.modules.selection.SelectionBase.update_gw_state": {"tf": 1}}, "df": 2}}, "s": {"docs": {"shimmer.modules.gw_module.GWModule": {"tf": 1}}, "df": 1}}}}}}}}}, "s": {"docs": {"shimmer.types.RawDomainGroupT": {"tf": 1.4142135623730951}, "shimmer.types.RawDomainGroupDT": {"tf": 1.4142135623730951}, "shimmer.types.LatentsDomainGroupT": {"tf": 1}, "shimmer.types.LatentsDomainGroupDT": {"tf": 1}, "shimmer.types.RawDomainGroupsT": {"tf": 2}, "shimmer.types.RawDomainGroupsDT": {"tf": 
2}, "shimmer.types.LatentsDomainGroupsT": {"tf": 1.4142135623730951}, "shimmer.types.LatentsDomainGroupsDT": {"tf": 1.4142135623730951}, "shimmer.types.ModelModeT": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBase": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.encode_domain": {"tf": 1.4142135623730951}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.decode_domain": {"tf": 1.4142135623730951}, "shimmer.modules.global_workspace.freeze_domain_modules": {"tf": 1.4142135623730951}, "shimmer.modules.global_workspace.GlobalWorkspace2Domains": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspace2Domains.__init__": {"tf": 1.4142135623730951}, "shimmer.modules.global_workspace.GlobalWorkspace": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspace.__init__": {"tf": 1.4142135623730951}, "shimmer.modules.global_workspace.GlobalWorkspaceBayesian": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBayesian.__init__": {"tf": 1.4142135623730951}, "shimmer.modules.global_workspace.pretrained_global_workspace": {"tf": 1.7320508075688772}, "shimmer.modules.domain.LossOutput": {"tf": 1.7320508075688772}, "shimmer.modules.domain.DomainModule.compute_dcy_loss": {"tf": 1}, "shimmer.modules.domain.DomainModule.compute_cy_loss": {"tf": 1}, "shimmer.modules.domain.DomainModule.compute_tr_loss": {"tf": 1}, "shimmer.modules.domain.DomainModule.compute_broadcast_loss": {"tf": 1}, "shimmer.modules.gw_module.GWEncoder": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBase": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBayesian.fuse": {"tf": 1.7320508075688772}, "shimmer.modules.selection.SelectionBase": {"tf": 1}, "shimmer.modules.selection.SelectionBase.update_gw_state": {"tf": 1.4142135623730951}, "shimmer.modules.selection.SingleDomainSelection": {"tf": 1}, "shimmer.modules.selection.SingleDomainSelection.forward": {"tf": 1}, "shimmer.modules.selection.FixedSharedSelection": {"tf": 1}, 
"shimmer.modules.selection.FixedSharedSelection.forward": {"tf": 1}, "shimmer.modules.selection.RandomSelection.forward": {"tf": 1}, "shimmer.modules.selection.DynamicQueryAttention": {"tf": 1}, "shimmer.modules.losses.GWLossesBase": {"tf": 1}, "shimmer.modules.losses.LossCoefs": {"tf": 1.7320508075688772}, "shimmer.modules.losses.GWLosses2Domains.__init__": {"tf": 1.4142135623730951}, "shimmer.modules.losses.BroadcastLossCoefs": {"tf": 1.7320508075688772}, "shimmer.modules.losses.GWLosses.step": {"tf": 1}, "shimmer.modules.losses.GWLossesBayesian.__init__": {"tf": 1.4142135623730951}, "shimmer.modules.losses.GWLossesBayesian.step": {"tf": 1}, "shimmer.dataset.RepeatedDataset": {"tf": 1.4142135623730951}, "shimmer.dataset.RepeatedDataset.__init__": {"tf": 1}, "shimmer.modules.vae.VAE.forward": {"tf": 1}}, "df": 46}, "t": {"docs": {"shimmer.modules.domain.LossOutput": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBase.encode_and_fuse": {"tf": 1}, "shimmer.modules.selection.SelectionBase.update_gw_state": {"tf": 1}, "shimmer.modules.selection.FixedSharedSelection": {"tf": 1}, "shimmer.modules.losses.LossCoefs": {"tf": 1}, "shimmer.modules.losses.BroadcastLossCoefs": {"tf": 1}}, "df": 6, "e": {"docs": {}, "df": 0, "r": {"docs": {}, "df": 0, "a": {"docs": {}, "df": 0, "b": {"docs": {}, "df": 0, "l": {"docs": {}, "df": 0, "e": {"docs": {"shimmer.modules.global_workspace.GlobalWorkspaceBase.decode": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBase.decode": {"tf": 1}, "shimmer.modules.utils.batch_cycles": {"tf": 1}}, "df": 3, "[": {"docs": {}, "df": 0, "s": {"docs": {}, "df": 0, "t": {"docs": {}, "df": 0, "r": {"docs": {"shimmer.modules.global_workspace.GlobalWorkspaceBase.decode": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBase.decode": {"tf": 1}, "shimmer.modules.gw_module.GWModule.decode": {"tf": 1}, "shimmer.modules.selection.KQFixedQSelection.__init__": {"tf": 1}, "shimmer.modules.selection.DynamicQueryAttention.__init__": {"tf": 1}, 
"shimmer.modules.utils.batch_cycles": {"tf": 1}}, "df": 6}}}}}}}}}, "m": {"docs": {"shimmer.modules.selection.SelectionBase.forward": {"tf": 1}, "shimmer.modules.selection.RandomSelection.forward": {"tf": 1}, "shimmer.modules.vae.VAE.forward": {"tf": 1}}, "df": 3, "s": {"docs": {"shimmer.dataset.RepeatedDataset": {"tf": 1}}, "df": 1}}}, "s": {"docs": {"shimmer.dataset.RepeatedDataset": {"tf": 1}}, "df": 1}}, "f": {"docs": {"shimmer.modules.global_workspace.pretrained_global_workspace": {"tf": 1}, "shimmer.modules.domain.DomainModule.compute_dcy_loss": {"tf": 1}, "shimmer.modules.domain.DomainModule.compute_cy_loss": {"tf": 1}, "shimmer.modules.domain.DomainModule.compute_tr_loss": {"tf": 1}, "shimmer.modules.domain.DomainModule.compute_broadcast_loss": {"tf": 1}, "shimmer.modules.selection.SelectionBase.update_gw_state": {"tf": 1}, "shimmer.modules.selection.SingleDomainSelection": {"tf": 1}, "shimmer.modules.selection.FixedSharedSelection": {"tf": 1}, "shimmer.modules.losses.LossCoefs": {"tf": 1.4142135623730951}, "shimmer.modules.losses.BroadcastLossCoefs": {"tf": 1.4142135623730951}, "shimmer.dataset.RepeatedDataset": {"tf": 1.4142135623730951}}, "df": 11}, "g": {"docs": {}, "df": 0, "n": {"docs": {}, "df": 0, "o": {"docs": {}, "df": 0, "r": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "s": {"docs": {"shimmer.modules.gw_module.GWEncoder.forward": {"tf": 1}, "shimmer.modules.gw_module.GWEncoderLinear.forward": {"tf": 1}}, "df": 2}}}}}}, "^": {"2": {"docs": {"shimmer.modules.gw_module.GWModuleBayesian.fuse": {"tf": 1.4142135623730951}}, "df": 1}, "docs": {}, "df": 0}}, "s": {"docs": {"shimmer.types.RawDomainGroupDT": {"tf": 1}, "shimmer.types.LatentsDomainGroupDT": {"tf": 1}, "shimmer.types.RawDomainGroupsT": {"tf": 1}, "shimmer.types.RawDomainGroupsDT": {"tf": 1}, "shimmer.types.LatentsDomainGroupsDT": {"tf": 1}, "shimmer.modules.global_workspace.GWPredictionsBase.states": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.encode_domain": {"tf": 
1}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.decode_domain": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspace2Domains.__init__": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspace.__init__": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBayesian.__init__": {"tf": 1}, "shimmer.modules.gw_module.GWModule.gw_encoders": {"tf": 1}, "shimmer.modules.gw_module.GWModule.gw_decoders": {"tf": 1}, "shimmer.modules.losses.demi_cycle_loss": {"tf": 1}, "shimmer.modules.losses.cycle_loss": {"tf": 1}, "shimmer.modules.losses.translation_loss": {"tf": 1}, "shimmer.modules.losses.contrastive_loss": {"tf": 1}, "shimmer.modules.losses.contrastive_loss_bayesian": {"tf": 1}}, "df": 18, "h": {"docs": {}, "df": 0, "o": {"docs": {}, "df": 0, "u": {"docs": {}, "df": 0, "l": {"docs": {}, "df": 0, "d": {"docs": {"shimmer.types.RawDomainGroupT": {"tf": 1}, "shimmer.modules.gw_module.GWEncoder.forward": {"tf": 1.4142135623730951}, "shimmer.modules.gw_module.GWEncoderLinear.forward": {"tf": 1.4142135623730951}, "shimmer.modules.gw_module.GWModuleBase": {"tf": 1}, "shimmer.modules.gw_module.compute_fusion_scores": {"tf": 1}, "shimmer.dataset.RepeatedDataset.__init__": {"tf": 1}}, "df": 6}}}}, "a": {"docs": {}, "df": 0, "r": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "s": {"docs": {"shimmer.modules.selection.FixedSharedSelection": {"tf": 1}}, "df": 1}}}}, "i": {"docs": {}, "df": 0, "m": {"docs": {}, "df": 0, "m": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "r": {"docs": {"shimmer.modules.losses.demi_cycle_loss": {"tf": 1.7320508075688772}, "shimmer.modules.losses.cycle_loss": {"tf": 1}, "shimmer.modules.losses.GWLosses2Domains.demi_cycle_loss": {"tf": 1}, "shimmer.modules.losses.GWLosses2Domains.cycle_loss": {"tf": 1}, "shimmer.modules.losses.GWLosses2Domains.translation_loss": {"tf": 1}, "shimmer.modules.losses.GWLosses2Domains.contrastive_loss": {"tf": 1}, "shimmer.modules.losses.broadcast_loss": {"tf": 1.4142135623730951}, 
"shimmer.cli.ckpt_migration.migrate_ckpt": {"tf": 1.4142135623730951}}, "df": 8}}}}}}, "a": {"docs": {}, "df": 0, "m": {"docs": {}, "df": 0, "e": {"docs": {"shimmer.types.RawDomainGroupT": {"tf": 1}, "shimmer.modules.gw_module.compute_fusion_scores": {"tf": 1}, "shimmer.modules.selection.KQFixedQSelection.__init__": {"tf": 1}, "shimmer.modules.selection.DynamicQueryAttention.__init__": {"tf": 1}}, "df": 4}, "p": {"docs": {}, "df": 0, "l": {"docs": {}, "df": 0, "e": {"docs": {"shimmer.modules.vae.reparameterize": {"tf": 1}}, "df": 1, "s": {"docs": {"shimmer.modules.selection.RandomSelection.forward": {"tf": 1}}, "df": 1}, "d": {"docs": {"shimmer.modules.vae.reparameterize": {"tf": 1}}, "df": 1}}}}}, "v": {"docs": {}, "df": 0, "e": {"docs": {"shimmer.utils.SaveMigrations.on_save_checkpoint": {"tf": 1}}, "df": 1, "d": {"docs": {"shimmer.utils.migrate_model": {"tf": 1}, "shimmer.utils.SaveMigrations.on_save_checkpoint": {"tf": 1}}, "df": 2}}, "i": {"docs": {}, "df": 0, "n": {"docs": {}, "df": 0, "g": {"docs": {"shimmer.utils.SaveMigrations.on_save_checkpoint": {"tf": 1}}, "df": 1}}}}}, "c": {"docs": {}, "df": 0, "h": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "d": {"docs": {}, "df": 0, "u": {"docs": {}, "df": 0, "l": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "r": {"docs": {"shimmer.modules.global_workspace.SchedulerArgs": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspace2Domains.__init__": {"tf": 1.4142135623730951}, "shimmer.modules.global_workspace.GlobalWorkspace.__init__": {"tf": 1.4142135623730951}, "shimmer.modules.global_workspace.GlobalWorkspaceBayesian.__init__": {"tf": 1.4142135623730951}}, "df": 4, "a": {"docs": {}, "df": 0, "r": {"docs": {}, "df": 0, "g": {"docs": {}, "df": 0, "s": {"docs": {"shimmer.modules.global_workspace.GlobalWorkspace2Domains.__init__": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspace.__init__": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBayesian.__init__": {"tf": 1}}, "df": 
3}}}}}}}}}}}, "o": {"docs": {}, "df": 0, "r": {"docs": {}, "df": 0, "e": {"docs": {"shimmer.modules.gw_module.GWModuleBase.fuse": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBase.encode_and_fuse": {"tf": 1}, "shimmer.modules.gw_module.GWModule.fuse": {"tf": 1}, "shimmer.modules.gw_module.compute_fusion_scores": {"tf": 2}, "shimmer.modules.gw_module.GWModuleBayesian.fuse": {"tf": 1.4142135623730951}}, "df": 5, "s": {"docs": {"shimmer.modules.global_workspace.GlobalWorkspaceBase.encode_and_fuse": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.fuse": {"tf": 1.4142135623730951}, "shimmer.modules.gw_module.GWModuleBase.fuse": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBase.encode_and_fuse": {"tf": 1}, "shimmer.modules.gw_module.GWModule.fuse": {"tf": 1}, "shimmer.modules.gw_module.compute_fusion_scores": {"tf": 2.23606797749979}, "shimmer.modules.gw_module.GWModuleBayesian.fuse": {"tf": 2.23606797749979}, "shimmer.modules.selection.KQFixedQSelection.forward": {"tf": 1}, "shimmer.modules.selection.RandomSelection": {"tf": 1}, "shimmer.modules.selection.RandomSelection.forward": {"tf": 1}, "shimmer.modules.selection.DynamicQueryAttention.fuse_weighted_encodings": {"tf": 1.4142135623730951}, "shimmer.modules.selection.DynamicQueryAttention.forward": {"tf": 1}}, "df": 12}}}}, "a": {"docs": {}, "df": 0, "l": {"docs": {}, "df": 0, "e": {"docs": {"shimmer.modules.global_workspace.GlobalWorkspace2Domains.__init__": {"tf": 1.4142135623730951}, "shimmer.modules.global_workspace.GlobalWorkspace.__init__": {"tf": 1.4142135623730951}, "shimmer.modules.global_workspace.GlobalWorkspaceBayesian.__init__": {"tf": 1.4142135623730951}, "shimmer.modules.global_workspace.pretrained_global_workspace": {"tf": 1}, "shimmer.modules.contrastive_loss.info_nce": {"tf": 1.4142135623730951}, "shimmer.modules.contrastive_loss.contrastive_loss": {"tf": 1.4142135623730951}, "shimmer.modules.contrastive_loss.ContrastiveLoss.__init__": {"tf": 2}, 
"shimmer.modules.contrastive_loss.ContrastiveLoss.forward": {"tf": 1}}, "df": 8, "d": {"docs": {"shimmer.modules.selection.DynamicQueryAttention": {"tf": 1}}, "df": 1}}, "i": {"docs": {}, "df": 0, "n": {"docs": {}, "df": 0, "g": {"docs": {"shimmer.modules.selection.RandomSelection": {"tf": 1}, "shimmer.modules.selection.RandomSelection.__init__": {"tf": 1}}, "df": 2}}}}}, "r": {"docs": {}, "df": 0, "i": {"docs": {}, "df": 0, "p": {"docs": {}, "df": 0, "t": {"docs": {"shimmer.cli.ckpt_migration.migrate_ckpt": {"tf": 1}}, "df": 1}}}}}, "t": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "p": {"docs": {"shimmer.modules.global_workspace.GlobalWorkspaceBase.generic_step": {"tf": 2}, "shimmer.modules.losses.GWLosses.step": {"tf": 1.4142135623730951}, "shimmer.modules.losses.GWLossesBayesian.step": {"tf": 1.4142135623730951}}, "df": 3, "s": {"docs": {"shimmer.modules.global_workspace.SchedulerArgs.total_steps": {"tf": 1}}, "df": 1}}, "m": {"docs": {"shimmer.utils.migrate_model": {"tf": 1}}, "df": 1}}, "a": {"docs": {}, "df": 0, "t": {"docs": {}, "df": 0, "e": {"docs": {"shimmer.modules.global_workspace.GWPredictionsBase.states": {"tf": 1}, "shimmer.modules.selection.SelectionBase.update_gw_state": {"tf": 1.7320508075688772}, "shimmer.modules.selection.SingleDomainSelection.forward": {"tf": 1.4142135623730951}, "shimmer.modules.selection.FixedSharedSelection.forward": {"tf": 1.4142135623730951}}, "df": 4, "s": {"docs": {"shimmer.modules.global_workspace.GlobalWorkspaceBase.batch_gw_states": {"tf": 1.4142135623730951}}, "df": 1}}, "i": {"docs": {}, "df": 0, "c": {"docs": {"shimmer.modules.selection.KQFixedQSelection.forward": {"tf": 1}, "shimmer.modules.selection.DynamicQueryAttention.forward": {"tf": 1}}, "df": 2}}}, "r": {"docs": {}, "df": 0, "t": {"docs": {"shimmer.modules.global_workspace.GWPredictions.cycles": {"tf": 1}, "shimmer.modules.global_workspace.GWPredictions.translations": {"tf": 1}, "shimmer.modules.utils.batch_cycles": {"tf": 1}, 
"shimmer.modules.utils.batch_translations": {"tf": 1}}, "df": 4}}}, "r": {"docs": {"shimmer.modules.global_workspace.GlobalWorkspaceBase.encode_domain": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.decode_domain": {"tf": 1}, "shimmer.modules.global_workspace.pretrained_global_workspace": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBayesian.get_precision": {"tf": 1}, "shimmer.modules.utils.translation": {"tf": 1}, "shimmer.modules.utils.cycle": {"tf": 1}, "shimmer.modules.utils.batch_cycles": {"tf": 1}, "shimmer.modules.utils.batch_translations": {"tf": 1}, "shimmer.utils.migrate_model": {"tf": 1}}, "df": 9, "u": {"docs": {}, "df": 0, "c": {"docs": {}, "df": 0, "t": {"docs": {}, "df": 0, "u": {"docs": {}, "df": 0, "r": {"docs": {}, "df": 0, "e": {"docs": {"shimmer.modules.selection.RandomSelection.forward": {"tf": 1}}, "df": 1}}}}}}}, "d": {"docs": {"shimmer.modules.gw_module.compute_fusion_scores": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBayesian.fuse": {"tf": 1}, "shimmer.modules.contrastive_loss.ContrastiveLossBayesianType": {"tf": 1.4142135623730951}}, "df": 3}, "o": {"docs": {}, "df": 0, "r": {"docs": {}, "df": 0, "e": {"docs": {"shimmer.utils.SaveMigrations.on_save_checkpoint": {"tf": 1}}, "df": 1}}}}, "e": {"docs": {}, "df": 0, "l": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "c": {"docs": {}, "df": 0, "t": {"docs": {}, "df": 0, "i": {"docs": {}, "df": 0, "o": {"docs": {}, "df": 0, "n": {"docs": {"shimmer.modules.global_workspace.GlobalWorkspaceBase.encode_and_fuse": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.fuse": {"tf": 1.4142135623730951}, "shimmer.modules.global_workspace.GlobalWorkspace.__init__": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBayesian.__init__": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBase.fuse": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBase.encode_and_fuse": {"tf": 1}, "shimmer.modules.gw_module.GWModule.fuse": {"tf": 1}, 
"shimmer.modules.gw_module.GWModuleBayesian.__init__": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBayesian.fuse": {"tf": 1.4142135623730951}, "shimmer.modules.selection.SelectionBase": {"tf": 1.4142135623730951}, "shimmer.modules.selection.SelectionBase.update_gw_state": {"tf": 1}, "shimmer.modules.selection.SelectionBase.forward": {"tf": 1}, "shimmer.modules.selection.SingleDomainSelection": {"tf": 1.4142135623730951}, "shimmer.modules.selection.FixedSharedSelection": {"tf": 1}, "shimmer.modules.losses.demi_cycle_loss": {"tf": 1.7320508075688772}, "shimmer.modules.losses.cycle_loss": {"tf": 1.7320508075688772}, "shimmer.modules.losses.GWLosses2Domains.__init__": {"tf": 1.4142135623730951}, "shimmer.modules.losses.broadcast_loss": {"tf": 1.7320508075688772}, "shimmer.modules.losses.GWLosses.__init__": {"tf": 1.4142135623730951}, "shimmer.modules.losses.GWLossesBayesian.__init__": {"tf": 1.4142135623730951}, "shimmer.modules.losses.GWLossesBayesian.selection_mod": {"tf": 1}, "shimmer.modules.utils.translation": {"tf": 1.4142135623730951}, "shimmer.modules.utils.cycle": {"tf": 1.4142135623730951}, "shimmer.modules.utils.batch_demi_cycles": {"tf": 1.4142135623730951}, "shimmer.modules.utils.batch_cycles": {"tf": 1.4142135623730951}, "shimmer.modules.utils.batch_translations": {"tf": 1.4142135623730951}}, "df": 26, "b": {"docs": {}, "df": 0, "a": {"docs": {}, "df": 0, "s": {"docs": {}, "df": 0, "e": {"docs": {"shimmer.modules.global_workspace.GlobalWorkspaceBase.selection_mod": {"tf": 1}, "shimmer.modules.losses.demi_cycle_loss": {"tf": 1}, "shimmer.modules.losses.cycle_loss": {"tf": 1}, "shimmer.modules.losses.GWLosses2Domains.__init__": {"tf": 1}, "shimmer.modules.losses.broadcast_loss": {"tf": 1}, "shimmer.modules.losses.GWLossesBayesian.__init__": {"tf": 1}, "shimmer.modules.utils.translation": {"tf": 1}, "shimmer.modules.utils.cycle": {"tf": 1}, "shimmer.modules.utils.batch_demi_cycles": {"tf": 1}, "shimmer.modules.utils.batch_cycles": {"tf": 1}, 
"shimmer.modules.utils.batch_translations": {"tf": 1}}, "df": 11}}}}}}, "n": {"docs": {}, "df": 0, "g": {"docs": {"shimmer.modules.selection.SingleDomainSelection": {"tf": 1}}, "df": 1}}}, "s": {"docs": {"shimmer.modules.selection.SelectionBase": {"tf": 1}}, "df": 1}, "e": {"docs": {}, "df": 0, "d": {"docs": {"shimmer.modules.selection.SingleDomainSelection.forward": {"tf": 1}, "shimmer.modules.selection.FixedSharedSelection.forward": {"tf": 1}}, "df": 2}}}}}, "f": {"docs": {"shimmer.modules.global_workspace.GlobalWorkspaceBase.encode_domain": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.decode_domain": {"tf": 1}}, "df": 2}}, "t": {"docs": {"shimmer.modules.global_workspace.freeze_domain_modules": {"tf": 1}, "shimmer.modules.losses.LossCoefs": {"tf": 1}, "shimmer.modules.losses.BroadcastLossCoefs": {"tf": 1}}, "df": 3}, "n": {"docs": {}, "df": 0, "s": {"docs": {}, "df": 0, "i": {"docs": {}, "df": 0, "t": {"docs": {}, "df": 0, "i": {"docs": {}, "df": 0, "v": {"docs": {}, "df": 0, "i": {"docs": {}, "df": 0, "t": {"docs": {}, "df": 0, "y": {"docs": {"shimmer.modules.global_workspace.GlobalWorkspaceBayesian.__init__": {"tf": 1.7320508075688772}, "shimmer.modules.gw_module.compute_fusion_scores": {"tf": 2}, "shimmer.modules.gw_module.GWModuleBayesian.__init__": {"tf": 1.7320508075688772}}, "df": 3}}}}}}, "v": {"docs": {}, "df": 0, "i": {"docs": {}, "df": 0, "t": {"docs": {}, "df": 0, "y": {"docs": {"shimmer.modules.global_workspace.GlobalWorkspaceBayesian.__init__": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBayesian.__init__": {"tf": 1}}, "df": 2}}}}}}}, "e": {"docs": {"shimmer.modules.gw_module.GWModuleBase": {"tf": 1}, "shimmer.modules.losses.GWLosses2Domains.demi_cycle_loss": {"tf": 1}, "shimmer.modules.losses.GWLosses2Domains.cycle_loss": {"tf": 1}, "shimmer.modules.losses.GWLosses2Domains.translation_loss": {"tf": 1}, "shimmer.modules.losses.GWLosses2Domains.contrastive_loss": {"tf": 1}, "shimmer.modules.losses.GWLosses2Domains.step": {"tf": 
2}}, "df": 6}, "c": {"docs": {}, "df": 0, "o": {"docs": {}, "df": 0, "n": {"docs": {}, "df": 0, "d": {"docs": {"shimmer.modules.gw_module.compute_fusion_scores": {"tf": 1.4142135623730951}, "shimmer.modules.vae.VAE.forward": {"tf": 1}}, "df": 2}}}}, "p": {"docs": {}, "df": 0, "a": {"docs": {}, "df": 0, "r": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "t": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "l": {"docs": {}, "df": 0, "y": {"docs": {"shimmer.modules.utils.batch_cycles": {"tf": 1}}, "df": 1}}}}}}}}}, "i": {"docs": {}, "df": 0, "m": {"docs": {}, "df": 0, "p": {"docs": {}, "df": 0, "l": {"docs": {}, "df": 0, "e": {"docs": {"shimmer.modules.global_workspace.GlobalWorkspace2Domains": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBayesian": {"tf": 1}}, "df": 2}, "i": {"docs": {}, "df": 0, "f": {"docs": {}, "df": 0, "y": {"docs": {"shimmer.modules.global_workspace.GlobalWorkspace2Domains": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspace": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBayesian": {"tf": 1}}, "df": 3}}}}}, "i": {"docs": {}, "df": 0, "l": {"docs": {}, "df": 0, "a": {"docs": {}, "df": 0, "r": {"docs": {"shimmer.modules.gw_module.GWEncoder": {"tf": 1}}, "df": 1, "l": {"docs": {}, "df": 0, "y": {"docs": {"shimmer.modules.gw_module.GWModuleBayesian.fuse": {"tf": 1}}, "df": 1}}}}}}}, "z": {"docs": {}, "df": 0, "e": {"docs": {"shimmer.modules.gw_module.get_n_layers": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBayesian.fuse": {"tf": 1}, "shimmer.modules.selection.KQFixedQSelection.__init__": {"tf": 1}, "shimmer.modules.selection.DynamicQueryAttention.__init__": {"tf": 1}, "shimmer.dataset.RepeatedDataset": {"tf": 2.6457513110645907}, "shimmer.dataset.RepeatedDataset.__init__": {"tf": 1.7320508075688772}, "shimmer.utils.groups_batch_size": {"tf": 1.4142135623730951}, "shimmer.utils.groups_device": {"tf": 1.4142135623730951}}, "df": 8, "d": {"docs": {}, "df": 0, "d": {"docs": {}, "df": 0, "a": {"docs": {}, 
"df": 0, "t": {"docs": {}, "df": 0, "a": {"docs": {}, "df": 0, "s": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "t": {"docs": {"shimmer.dataset.RepeatedDataset.__init__": {"tf": 1}}, "df": 1}}}}}}}}}}, "n": {"docs": {}, "df": 0, "c": {"docs": {}, "df": 0, "e": {"docs": {"shimmer.modules.gw_module.GWEncoder.forward": {"tf": 1}, "shimmer.modules.gw_module.GWEncoderLinear.forward": {"tf": 1}}, "df": 2}}}, "l": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "n": {"docs": {}, "df": 0, "t": {"docs": {}, "df": 0, "l": {"docs": {}, "df": 0, "y": {"docs": {"shimmer.modules.gw_module.GWEncoder.forward": {"tf": 1}, "shimmer.modules.gw_module.GWEncoderLinear.forward": {"tf": 1}}, "df": 2}}}}}}, "g": {"docs": {}, "df": 0, "m": {"docs": {}, "df": 0, "a": {"docs": {"shimmer.modules.gw_module.GWModuleBayesian.fuse": {"tf": 1}, "shimmer.modules.vae.gaussian_nll": {"tf": 1.4142135623730951}}, "df": 2}}}}, "o": {"docs": {}, "df": 0, "f": {"docs": {}, "df": 0, "t": {"docs": {}, "df": 0, "m": {"docs": {}, "df": 0, "a": {"docs": {}, "df": 0, "x": {"docs": {"shimmer.modules.global_workspace.GlobalWorkspaceBayesian.__init__": {"tf": 1.4142135623730951}, "shimmer.modules.gw_module.GWModuleBayesian.__init__": {"tf": 1.4142135623730951}, "shimmer.modules.gw_module.GWModuleBayesian.fuse": {"tf": 1}, "shimmer.modules.selection.KQFixedQSelection.forward": {"tf": 1}, "shimmer.modules.selection.RandomSelection": {"tf": 1.4142135623730951}, "shimmer.modules.selection.RandomSelection.__init__": {"tf": 1}, "shimmer.modules.selection.RandomSelection.forward": {"tf": 1}, "shimmer.modules.selection.DynamicQueryAttention.forward": {"tf": 1}}, "df": 8, "e": {"docs": {}, "df": 0, "d": {"docs": {"shimmer.modules.selection.RandomSelection.forward": {"tf": 1}}, "df": 1}}}}}}}, "m": {"docs": {}, "df": 0, "e": {"docs": {"shimmer.modules.domain.LossOutput.metrics": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBase": {"tf": 1}, "shimmer.modules.selection.SelectionBase.update_gw_state": {"tf": 1}, 
"shimmer.modules.vae.VAEEncoder.forward": {"tf": 1}, "shimmer.modules.vae.VAE.encode": {"tf": 1}}, "df": 5, "s": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "l": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "c": {"docs": {}, "df": 0, "t": {"docs": {}, "df": 0, "i": {"docs": {}, "df": 0, "o": {"docs": {}, "df": 0, "n": {"docs": {}, "df": 0, "i": {"docs": {}, "df": 0, "m": {"docs": {}, "df": 0, "p": {"docs": {}, "df": 0, "l": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "m": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "n": {"docs": {}, "df": 0, "t": {"docs": {}, "df": 0, "a": {"docs": {}, "df": 0, "t": {"docs": {}, "df": 0, "i": {"docs": {}, "df": 0, "o": {"docs": {}, "df": 0, "n": {"docs": {"shimmer.modules.selection.SelectionBase.forward": {"tf": 1}}, "df": 1}}}}}}}}}}}}}}}}}}}}}}}}}, "l": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "l": {"docs": {}, "df": 0, "y": {"docs": {"shimmer.modules.selection.RandomSelection.forward": {"tf": 1}}, "df": 1}}}}, "u": {"docs": {}, "df": 0, "r": {"docs": {}, "df": 0, "c": {"docs": {}, "df": 0, "e": {"docs": {"shimmer.modules.losses.cycle_loss": {"tf": 1.7320508075688772}, "shimmer.modules.losses.translation_loss": {"tf": 2}}, "df": 2}}}}}, "p": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "c": {"docs": {}, "df": 0, "i": {"docs": {}, "df": 0, "f": {"docs": {}, "df": 0, "i": {"docs": {}, "df": 0, "c": {"docs": {"shimmer.modules.domain.DomainModule": {"tf": 1}}, "df": 1}}}}}}}, "u": {"docs": {}, "df": 0, "b": {"docs": {}, "df": 0, "c": {"docs": {}, "df": 0, "l": {"docs": {}, "df": 0, "a": {"docs": {}, "df": 0, "s": {"docs": {}, "df": 0, "s": {"docs": {"shimmer.utils.SaveMigrations": {"tf": 1}}, "df": 1, "e": {"docs": {}, "df": 0, "s": {"docs": {"shimmer.modules.gw_module.GWEncoder.forward": {"tf": 1}, "shimmer.modules.gw_module.GWEncoderLinear.forward": {"tf": 1}}, "df": 2}}}}}}}}, "m": {"docs": {"shimmer.modules.gw_module.GWModuleBayesian.fuse": {"tf": 1}, "shimmer.modules.contrastive_loss.info_nce": 
{"tf": 1}, "shimmer.modules.contrastive_loss.contrastive_loss": {"tf": 1}, "shimmer.modules.contrastive_loss.ContrastiveLoss.__init__": {"tf": 1}}, "df": 4, "m": {"docs": {}, "df": 0, "a": {"docs": {}, "df": 0, "t": {"docs": {}, "df": 0, "i": {"docs": {}, "df": 0, "o": {"docs": {}, "df": 0, "n": {"docs": {"shimmer.modules.gw_module.compute_fusion_scores": {"tf": 1}}, "df": 1}}}}}}}, "f": {"docs": {}, "df": 0, "f": {"docs": {}, "df": 0, "i": {"docs": {}, "df": 0, "x": {"docs": {"shimmer.utils.migrate_model": {"tf": 1}}, "df": 1}}}}}, "q": {"docs": {}, "df": 0, "u": {"docs": {}, "df": 0, "a": {"docs": {}, "df": 0, "r": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "d": {"docs": {"shimmer.modules.gw_module.GWModuleBayesian.fuse": {"tf": 1}}, "df": 1}}}}}}}, "b": {"docs": {"shimmer.modules.gw_module.GWModuleBayesian.fuse": {"tf": 1.7320508075688772}}, "df": 1, "e": {"docs": {"shimmer.types.RawDomainGroupT": {"tf": 1.4142135623730951}, "shimmer.types.LatentsDomainGroupT": {"tf": 1}, "shimmer.types.LatentsDomainGroupsT": {"tf": 1}, "shimmer.modules.gw_module.GWDecoder.__init__": {"tf": 1}, "shimmer.modules.gw_module.GWDecoder.n_layers": {"tf": 1}, "shimmer.modules.gw_module.GWEncoder.__init__": {"tf": 1}, "shimmer.modules.gw_module.GWEncoder.forward": {"tf": 1.4142135623730951}, "shimmer.modules.gw_module.GWEncoderLinear.forward": {"tf": 1.4142135623730951}, "shimmer.modules.gw_module.GWModuleBase": {"tf": 1}, "shimmer.modules.selection.KQFixedQSelection.__init__": {"tf": 1}, "shimmer.modules.selection.DynamicQueryAttention.__init__": {"tf": 1}, "shimmer.modules.losses.LossCoefs": {"tf": 1.7320508075688772}, "shimmer.modules.losses.BroadcastLossCoefs": {"tf": 1.7320508075688772}, "shimmer.modules.losses.BroadcastLossCoefs.cycles": {"tf": 1}, "shimmer.modules.losses.BroadcastLossCoefs.translations": {"tf": 1}, "shimmer.dataset.RepeatedDataset": {"tf": 1}, "shimmer.modules.utils.batch_cycles": {"tf": 1}, "shimmer.utils.migrate_model": {"tf": 1.4142135623730951}, 
"shimmer.utils.SaveMigrations.on_save_checkpoint": {"tf": 1}, "shimmer.cli.ckpt_migration.migrate_ckpt": {"tf": 1}}, "df": 20, "t": {"docs": {}, "df": 0, "t": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "r": {"docs": {"shimmer.modules.global_workspace.freeze_domain_modules": {"tf": 1}}, "df": 1}}}, "w": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "n": {"docs": {"shimmer.modules.selection.SelectionBase": {"tf": 1}, "shimmer.modules.selection.FixedSharedSelection": {"tf": 1}, "shimmer.modules.losses.contrastive_loss": {"tf": 1}, "shimmer.modules.losses.contrastive_loss_bayesian": {"tf": 1}}, "df": 4}}}}, "a": {"docs": {"shimmer.modules.vae.VAE.__init__": {"tf": 1.7320508075688772}, "shimmer.modules.vae.VAE.beta": {"tf": 1.4142135623730951}}, "df": 2}}, "f": {"docs": {}, "df": 0, "o": {"docs": {}, "df": 0, "r": {"docs": {}, "df": 0, "e": {"docs": {"shimmer.modules.gw_module.GWDecoder.__init__": {"tf": 1}, "shimmer.modules.gw_module.GWDecoder.n_layers": {"tf": 1}, "shimmer.modules.gw_module.GWEncoder.__init__": {"tf": 1}}, "df": 3}}}}}, "y": {"docs": {"shimmer.types.ModelModeT": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBayesian.__init__": {"tf": 1}, "shimmer.modules.gw_module.GWEncoder.forward": {"tf": 1}, "shimmer.modules.gw_module.GWEncoderLinear.forward": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBayesian.fuse": {"tf": 1.4142135623730951}, "shimmer.modules.selection.SelectionBase.update_gw_state": {"tf": 1}, "shimmer.modules.selection.RandomSelection": {"tf": 1}, "shimmer.modules.losses.demi_cycle_loss": {"tf": 1}, "shimmer.modules.losses.cycle_loss": {"tf": 1}, "shimmer.modules.losses.translation_loss": {"tf": 1}, "shimmer.modules.losses.contrastive_loss": {"tf": 1}, "shimmer.modules.losses.contrastive_loss_bayesian": {"tf": 1}, "shimmer.modules.losses.GWLossesBayesian.__init__": {"tf": 1}}, "df": 13}, "a": {"docs": {}, "df": 0, "s": {"docs": {}, "df": 0, "e": {"docs": 
{"shimmer.modules.global_workspace.GlobalWorkspaceBase": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBayesian": {"tf": 1}, "shimmer.modules.domain.DomainModule": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBase": {"tf": 1}, "shimmer.modules.selection.SelectionBase": {"tf": 1}, "shimmer.modules.losses.GWLossesBase": {"tf": 1}, "shimmer.modules.vae.VAEEncoder": {"tf": 1}, "shimmer.modules.vae.VAEDecoder": {"tf": 1}, "shimmer.utils.SaveMigrations": {"tf": 1}}, "df": 9, "d": {"docs": {"shimmer.modules.gw_module.GWModuleBayesian": {"tf": 1}, "shimmer.modules.selection.RandomSelection.forward": {"tf": 1}, "shimmer.modules.selection.DynamicQueryAttention": {"tf": 1}, "shimmer.modules.losses.contrastive_loss_bayesian": {"tf": 1}, "shimmer.modules.losses.GWLosses": {"tf": 1}}, "df": 5}}}, "t": {"docs": {}, "df": 0, "c": {"docs": {}, "df": 0, "h": {"docs": {"shimmer.modules.global_workspace.GlobalWorkspaceBase.batch_gw_states": {"tf": 1.4142135623730951}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.encode_domains": {"tf": 1.7320508075688772}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.decode_domains": {"tf": 1.7320508075688772}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.generic_step": {"tf": 1.4142135623730951}, "shimmer.modules.global_workspace.GlobalWorkspace2Domains.forward": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspace.forward": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBayesian.forward": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBayesian.get_precision": {"tf": 1.7320508075688772}, "shimmer.modules.gw_module.GWModuleBayesian.fuse": {"tf": 1}, "shimmer.modules.selection.SelectionBase.forward": {"tf": 1}, "shimmer.modules.selection.SingleDomainSelection.forward": {"tf": 1}, "shimmer.modules.selection.FixedSharedSelection.forward": {"tf": 1}, "shimmer.modules.selection.RandomSelection.forward": {"tf": 1}, "shimmer.modules.utils.batch_demi_cycles": {"tf": 1.4142135623730951}, 
"shimmer.modules.utils.batch_cycles": {"tf": 1.4142135623730951}, "shimmer.modules.utils.batch_translations": {"tf": 1.4142135623730951}, "shimmer.utils.groups_batch_size": {"tf": 2}, "shimmer.utils.groups_device": {"tf": 2}}, "df": 18}}}, "y": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "s": {"docs": {}, "df": 0, "i": {"docs": {}, "df": 0, "a": {"docs": {}, "df": 0, "n": {"docs": {"shimmer.modules.global_workspace.GlobalWorkspaceBayesian": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBayesian": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBayesian.fuse": {"tf": 1}, "shimmer.modules.losses.contrastive_loss_bayesian": {"tf": 1}}, "df": 4}}}}}}, "c": {"docs": {}, "df": 0, "k": {"docs": {"shimmer.modules.domain.DomainModule.decode": {"tf": 1}}, "df": 1}}}, "u": {"docs": {}, "df": 0, "i": {"docs": {}, "df": 0, "l": {"docs": {}, "df": 0, "d": {"docs": {"shimmer.modules.global_workspace.GlobalWorkspaceBase": {"tf": 1}, "shimmer.utils.SaveMigrations": {"tf": 1}}, "df": 2}}}, "t": {"docs": {"shimmer.modules.global_workspace.freeze_domain_modules": {"tf": 1}, "shimmer.modules.gw_module.GWEncoder": {"tf": 1}, "shimmer.modules.selection.SingleDomainSelection": {"tf": 1}, "shimmer.modules.selection.RandomSelection.forward": {"tf": 1}, "shimmer.modules.losses.LossCoefs": {"tf": 1}, "shimmer.modules.losses.BroadcastLossCoefs": {"tf": 1}}, "df": 6}}, "o": {"docs": {}, "df": 0, "o": {"docs": {}, "df": 0, "l": {"docs": {"shimmer.modules.global_workspace.GlobalWorkspace2Domains.__init__": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspace.__init__": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBayesian.__init__": {"tf": 1.4142135623730951}, "shimmer.modules.losses.GWLossesBayesian.__init__": {"tf": 1}, "shimmer.dataset.RepeatedDataset.__init__": {"tf": 1}}, "df": 5}}}, "r": {"docs": {}, "df": 0, "o": {"docs": {}, "df": 0, "a": {"docs": {}, "df": 0, "d": {"docs": {}, "df": 0, "c": {"docs": {}, "df": 0, "a": {"docs": {}, "df": 0, "s": {"docs": {}, "df": 
0, "t": {"docs": {"shimmer.modules.global_workspace.GlobalWorkspace": {"tf": 1}, "shimmer.modules.domain.DomainModule.compute_broadcast_loss": {"tf": 1.4142135623730951}, "shimmer.modules.losses.broadcast_loss": {"tf": 1}}, "df": 3, "l": {"docs": {}, "df": 0, "o": {"docs": {}, "df": 0, "s": {"docs": {}, "df": 0, "s": {"docs": {}, "df": 0, "c": {"docs": {}, "df": 0, "o": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "f": {"docs": {}, "df": 0, "s": {"docs": {"shimmer.modules.global_workspace.GlobalWorkspace.__init__": {"tf": 1}, "shimmer.modules.losses.GWLosses.__init__": {"tf": 1}, "shimmer.modules.losses.GWLossesBayesian.__init__": {"tf": 1}}, "df": 3}}}}}}}}}}}}}}}}}, "i": {"docs": {}, "df": 0, "n": {"docs": {}, "df": 0, "a": {"docs": {}, "df": 0, "r": {"docs": {}, "df": 0, "y": {"docs": {"shimmer.modules.selection.RandomSelection": {"tf": 1}}, "df": 1}}}}}}, "e": {"docs": {}, "df": 0, "x": {"docs": {}, "df": 0, "a": {"docs": {}, "df": 0, "m": {"docs": {}, "df": 0, "p": {"docs": {}, "df": 0, "l": {"docs": {}, "df": 0, "e": {"docs": {"shimmer.types.RawDomainGroupT": {"tf": 1}, "shimmer.types.RawDomainGroupDT": {"tf": 1}, "shimmer.types.LatentsDomainGroupT": {"tf": 1}, "shimmer.types.LatentsDomainGroupDT": {"tf": 1}, "shimmer.types.RawDomainGroupsT": {"tf": 1}, "shimmer.types.RawDomainGroupsDT": {"tf": 1}, "shimmer.types.LatentsDomainGroupsT": {"tf": 1}, "shimmer.types.LatentsDomainGroupsDT": {"tf": 1}, "shimmer.modules.selection.SelectionBase.forward": {"tf": 1}, "shimmer.modules.selection.SingleDomainSelection": {"tf": 1}, "shimmer.modules.selection.FixedSharedSelection": {"tf": 1}}, "df": 11}}}}, "c": {"docs": {}, "df": 0, "l": {"docs": {}, "df": 0, "t": {"docs": {}, "df": 0, "y": {"docs": {"shimmer.dataset.RepeatedDataset": {"tf": 1}}, "df": 1}}}}}, "c": {"docs": {}, "df": 0, "p": {"docs": {}, "df": 0, "l": {"docs": {}, "df": 0, "i": {"docs": {}, "df": 0, "c": {"docs": {}, "df": 0, "i": {"docs": {}, "df": 0, "t": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 
0, "l": {"docs": {}, "df": 0, "y": {"docs": {"shimmer.modules.losses.LossCoefs": {"tf": 1}, "shimmer.modules.losses.BroadcastLossCoefs": {"tf": 1}}, "df": 2}}}}}}}}}, "l": {"docs": {}, "df": 0, "u": {"docs": {}, "df": 0, "d": {"docs": {}, "df": 0, "i": {"docs": {}, "df": 0, "n": {"docs": {}, "df": 0, "g": {"docs": {"shimmer.modules.losses.generate_partitions": {"tf": 1.4142135623730951}}, "df": 1}}}}}}}}, "a": {"docs": {}, "df": 0, "c": {"docs": {}, "df": 0, "h": {"docs": {"shimmer.types.RawDomainGroupsT": {"tf": 1}, "shimmer.types.RawDomainGroupsDT": {"tf": 1}, "shimmer.types.LatentsDomainGroupsT": {"tf": 1}, "shimmer.types.LatentsDomainGroupsDT": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.fuse": {"tf": 1.4142135623730951}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.batch_gw_states": {"tf": 1}, "shimmer.modules.global_workspace.GWPredictions.demi_cycles": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspace2Domains.__init__": {"tf": 1.4142135623730951}, "shimmer.modules.global_workspace.GlobalWorkspace.__init__": {"tf": 1.4142135623730951}, "shimmer.modules.global_workspace.GlobalWorkspaceBayesian.__init__": {"tf": 1.4142135623730951}, "shimmer.modules.global_workspace.pretrained_global_workspace": {"tf": 1.4142135623730951}, "shimmer.modules.gw_module.GWModule.__init__": {"tf": 1.4142135623730951}, "shimmer.modules.gw_module.GWModuleBayesian.__init__": {"tf": 1.4142135623730951}, "shimmer.modules.selection.SelectionBase.forward": {"tf": 1.4142135623730951}, "shimmer.modules.selection.SingleDomainSelection": {"tf": 1}, "shimmer.modules.selection.SingleDomainSelection.forward": {"tf": 1}, "shimmer.modules.selection.FixedSharedSelection": {"tf": 1.4142135623730951}, "shimmer.modules.selection.FixedSharedSelection.forward": {"tf": 1}, "shimmer.modules.selection.KQFixedQSelection.forward": {"tf": 1}, "shimmer.modules.selection.RandomSelection.forward": {"tf": 1.7320508075688772}, 
"shimmer.modules.selection.DynamicQueryAttention.fuse_weighted_encodings": {"tf": 1}, "shimmer.modules.selection.DynamicQueryAttention.forward": {"tf": 1}, "shimmer.modules.utils.batch_demi_cycles": {"tf": 1}, "shimmer.modules.utils.batch_cycles": {"tf": 1.4142135623730951}, "shimmer.modules.utils.batch_translations": {"tf": 1}, "shimmer.cli.ckpt_migration.migrate_ckpt": {"tf": 1}}, "df": 26}}}, "n": {"docs": {}, "df": 0, "c": {"docs": {}, "df": 0, "o": {"docs": {}, "df": 0, "d": {"docs": {}, "df": 0, "e": {"docs": {"shimmer.modules.global_workspace.GlobalWorkspaceBase.encode_and_fuse": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.encode": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.encode_domain": {"tf": 1.7320508075688772}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.encode_domains": {"tf": 1.4142135623730951}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.decode_domain": {"tf": 1.4142135623730951}, "shimmer.modules.global_workspace.GlobalWorkspace2Domains.__init__": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspace.__init__": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBayesian.__init__": {"tf": 1}, "shimmer.modules.global_workspace.pretrained_global_workspace": {"tf": 1}, "shimmer.modules.domain.DomainModule.encode": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBase.fuse": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBase.encode": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBase.encode_and_fuse": {"tf": 1.7320508075688772}, "shimmer.modules.gw_module.GWModule.fuse": {"tf": 1}, "shimmer.modules.gw_module.GWModule.encode": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBayesian.fuse": {"tf": 1}, "shimmer.modules.losses.BroadcastLossCoefs.fused": {"tf": 1}, "shimmer.modules.vae.VAEEncoder.forward": {"tf": 1}, "shimmer.modules.vae.VAE.__init__": {"tf": 1}, "shimmer.modules.vae.VAE.encode": {"tf": 1}, "shimmer.modules.vae.VAE.forward": {"tf": 1}, 
"shimmer.modules.utils.translation": {"tf": 1}}, "df": 22, "s": {"docs": {"shimmer.modules.global_workspace.GlobalWorkspaceBase.encode_domain": {"tf": 1}, "shimmer.modules.gw_module.GWModule.__init__": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBayesian.__init__": {"tf": 1}}, "df": 3}, "r": {"docs": {"shimmer.modules.gw_module.GWEncoder": {"tf": 1}, "shimmer.modules.gw_module.GWEncoder.__init__": {"tf": 1}, "shimmer.modules.gw_module.GWEncoderLinear": {"tf": 1}, "shimmer.modules.vae.VAEEncoder": {"tf": 1}, "shimmer.modules.vae.VAE.__init__": {"tf": 1}, "shimmer.modules.vae.VAE.encoder": {"tf": 1}}, "df": 6, "s": {"docs": {"shimmer.modules.global_workspace.GlobalWorkspace2Domains.__init__": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspace.__init__": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBayesian.__init__": {"tf": 1}, "shimmer.modules.global_workspace.pretrained_global_workspace": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBase": {"tf": 1}, "shimmer.modules.gw_module.GWModule.__init__": {"tf": 1}, "shimmer.modules.gw_module.GWModule.gw_encoders": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBayesian.__init__": {"tf": 1}}, "df": 8}}, "d": {"docs": {"shimmer.modules.vae.VAE.forward": {"tf": 1}}, "df": 1}}, "i": {"docs": {}, "df": 0, "n": {"docs": {}, "df": 0, "g": {"docs": {"shimmer.modules.gw_module.GWModuleBase": {"tf": 1}}, "df": 1, "s": {"docs": {"shimmer.modules.selection.KQFixedQSelection.forward": {"tf": 1.4142135623730951}, "shimmer.modules.selection.DynamicQueryAttention.fuse_weighted_encodings": {"tf": 1.4142135623730951}, "shimmer.modules.selection.DynamicQueryAttention.forward": {"tf": 1.4142135623730951}}, "df": 3}}}}}}}, "d": {"docs": {"shimmer.modules.gw_module.GWEncoder": {"tf": 1}}, "df": 1}}, "q": {"docs": {}, "df": 0, "u": {"docs": {}, "df": 0, "i": {"docs": {}, "df": 0, "v": {"docs": {}, "df": 0, "a": {"docs": {}, "df": 0, "l": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "n": {"docs": {}, "df": 0, "t": 
{"docs": {"shimmer.modules.global_workspace.GlobalWorkspaceBase.encode_domain": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.decode_domain": {"tf": 1}}, "df": 2}}}}}}}, "a": {"docs": {}, "df": 0, "l": {"docs": {}, "df": 0, "l": {"docs": {}, "df": 0, "y": {"docs": {"shimmer.modules.selection.FixedSharedSelection": {"tf": 1}}, "df": 1}}}}}}, "v": {"docs": {}, "df": 0, "a": {"docs": {}, "df": 0, "l": {"docs": {"shimmer.modules.global_workspace.freeze_domain_modules": {"tf": 1}}, "df": 1}}, "e": {"docs": {}, "df": 0, "r": {"docs": {}, "df": 0, "y": {"docs": {"shimmer.modules.gw_module.GWEncoder.forward": {"tf": 1}, "shimmer.modules.gw_module.GWEncoderLinear.forward": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBayesian.precisions": {"tf": 1}}, "df": 3}}}}, "p": {"docs": {}, "df": 0, "s": {"docs": {"shimmer.modules.gw_module.compute_fusion_scores": {"tf": 1}}, "df": 1}}, "r": {"docs": {}, "df": 0, "r": {"docs": {}, "df": 0, "o": {"docs": {}, "df": 0, "r": {"docs": {"shimmer.modules.selection.SelectionBase.update_gw_state": {"tf": 1}}, "df": 1}}}}, "l": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "m": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "n": {"docs": {}, "df": 0, "t": {"docs": {}, "df": 0, "s": {"docs": {"shimmer.modules.losses.generate_partitions": {"tf": 1}}, "df": 1}}}}}}, "s": {"docs": {}, "df": 0, "e": {"docs": {"shimmer.utils.SaveMigrations.on_save_checkpoint": {"tf": 1}}, "df": 1}}}}, "g": {"docs": {}, "df": 0, "r": {"docs": {}, "df": 0, "o": {"docs": {}, "df": 0, "u": {"docs": {}, "df": 0, "p": {"docs": {"shimmer.types.RawDomainGroupT": {"tf": 1}, "shimmer.types.LatentsDomainGroupT": {"tf": 1}, "shimmer.types.RawDomainGroupsT": {"tf": 1.4142135623730951}, "shimmer.types.RawDomainGroupsDT": {"tf": 1.4142135623730951}, "shimmer.types.LatentsDomainGroupsT": {"tf": 1.7320508075688772}, "shimmer.types.LatentsDomainGroupsDT": {"tf": 1.4142135623730951}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.encode_and_fuse": 
{"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.encode": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.fuse": {"tf": 1.7320508075688772}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.decode": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBase.fuse": {"tf": 1}, "shimmer.modules.gw_module.GWModule.fuse": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBayesian.fuse": {"tf": 1}, "shimmer.modules.selection.SelectionBase.forward": {"tf": 1.4142135623730951}, "shimmer.modules.selection.SingleDomainSelection": {"tf": 1.4142135623730951}, "shimmer.modules.selection.KQFixedQSelection.forward": {"tf": 1.7320508075688772}, "shimmer.modules.selection.RandomSelection.forward": {"tf": 1.4142135623730951}, "shimmer.modules.selection.DynamicQueryAttention.fuse_weighted_encodings": {"tf": 1}, "shimmer.modules.selection.DynamicQueryAttention.forward": {"tf": 1.7320508075688772}, "shimmer.modules.utils.translation": {"tf": 1}, "shimmer.modules.utils.cycle": {"tf": 1.7320508075688772}}, "df": 21, "s": {"docs": {"shimmer.modules.global_workspace.GWPredictionsBase.states": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.batch_gw_states": {"tf": 1.4142135623730951}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.encode_domains": {"tf": 1.4142135623730951}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.decode_domains": {"tf": 1.4142135623730951}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.generic_step": {"tf": 1}, "shimmer.modules.global_workspace.GWPredictions.demi_cycles": {"tf": 1}, "shimmer.modules.global_workspace.GWPredictions.cycles": {"tf": 1}, "shimmer.modules.global_workspace.GWPredictions.translations": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspace2Domains.forward": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspace.forward": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBayesian.forward": {"tf": 1}, "shimmer.modules.selection.SingleDomainSelection": 
{"tf": 1}, "shimmer.modules.losses.GWLossesBase.step": {"tf": 1}, "shimmer.modules.losses.demi_cycle_loss": {"tf": 1}, "shimmer.modules.losses.cycle_loss": {"tf": 1}, "shimmer.modules.losses.translation_loss": {"tf": 1}, "shimmer.modules.losses.contrastive_loss": {"tf": 1}, "shimmer.modules.losses.contrastive_loss_bayesian": {"tf": 1}, "shimmer.modules.losses.GWLosses2Domains.demi_cycle_loss": {"tf": 1}, "shimmer.modules.losses.GWLosses2Domains.cycle_loss": {"tf": 1}, "shimmer.modules.losses.GWLosses2Domains.translation_loss": {"tf": 1}, "shimmer.modules.losses.GWLosses2Domains.contrastive_loss": {"tf": 1}, "shimmer.modules.losses.GWLosses2Domains.step": {"tf": 1}, "shimmer.modules.losses.GWLossesBayesian.contrastive_loss": {"tf": 1}, "shimmer.modules.utils.batch_demi_cycles": {"tf": 1.4142135623730951}, "shimmer.modules.utils.batch_cycles": {"tf": 1.4142135623730951}, "shimmer.modules.utils.batch_translations": {"tf": 1.4142135623730951}, "shimmer.utils.groups_batch_size": {"tf": 1}, "shimmer.utils.groups_device": {"tf": 1}}, "df": 29}}}}}, "e": {"docs": {}, "df": 0, "n": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "r": {"docs": {}, "df": 0, "i": {"docs": {}, "df": 0, "c": {"docs": {"shimmer.types.RawDomainGroupT": {"tf": 1}, "shimmer.types.LatentsDomainGroupT": {"tf": 1}, "shimmer.types.LatentsDomainGroupsT": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.generic_step": {"tf": 1}, "shimmer.modules.domain.DomainModule.compute_loss": {"tf": 1}, "shimmer.modules.domain.DomainModule.compute_dcy_loss": {"tf": 1}, "shimmer.modules.domain.DomainModule.compute_cy_loss": {"tf": 1}, "shimmer.modules.domain.DomainModule.compute_tr_loss": {"tf": 1}, "shimmer.modules.domain.DomainModule.compute_broadcast_loss": {"tf": 1}}, "df": 9}}, "a": {"docs": {}, "df": 0, "t": {"docs": {}, "df": 0, "i": {"docs": {}, "df": 0, "n": {"docs": {}, "df": 0, "g": {"docs": {"shimmer.modules.selection.RandomSelection": {"tf": 1}}, "df": 1}}}, "e": {"docs": 
{"shimmer.modules.selection.RandomSelection.forward": {"tf": 1}, "shimmer.modules.losses.generate_partitions": {"tf": 1}}, "df": 2, "s": {"docs": {"shimmer.modules.losses.generate_partitions": {"tf": 1}}, "df": 1}}}}}}}, "t": {"docs": {"shimmer.modules.gw_module.GWModuleBayesian.get_precision": {"tf": 1}, "shimmer.utils.groups_batch_size": {"tf": 1}, "shimmer.utils.groups_device": {"tf": 1}}, "df": 3}}, "t": {"docs": {"shimmer.types.RawDomainGroupDT": {"tf": 1}, "shimmer.types.LatentsDomainGroupDT": {"tf": 1}, "shimmer.types.RawDomainGroupsT": {"tf": 1}, "shimmer.types.RawDomainGroupsDT": {"tf": 1}, "shimmer.types.LatentsDomainGroupsDT": {"tf": 1}, "shimmer.modules.selection.SelectionBase.forward": {"tf": 1.7320508075688772}}, "df": 6}, "i": {"docs": {}, "df": 0, "v": {"docs": {}, "df": 0, "e": {"docs": {"shimmer.utils.SaveMigrations.on_save_checkpoint": {"tf": 1}}, "df": 1, "n": {"docs": {"shimmer.modules.global_workspace.GWPredictionsBase": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.decode": {"tf": 1}, "shimmer.modules.global_workspace.GWPredictions": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBase.decode": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBayesian.get_precision": {"tf": 1}, "shimmer.modules.losses.GWLosses.contrastive_loss": {"tf": 1}, "shimmer.utils.migrate_model": {"tf": 1.4142135623730951}, "shimmer.cli.ckpt_migration.migrate_ckpt": {"tf": 1}}, "df": 8}, "s": {"docs": {"shimmer.modules.selection.FixedSharedSelection": {"tf": 1}}, "df": 1}}}}, "l": {"docs": {}, "df": 0, "o": {"docs": {}, "df": 0, "b": {"docs": {}, "df": 0, "a": {"docs": {}, "df": 0, "l": {"docs": {"shimmer.modules.global_workspace.GlobalWorkspaceBase": {"tf": 1.4142135623730951}, "shimmer.modules.global_workspace.GlobalWorkspace2Domains": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspace2Domains.__init__": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspace": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspace.__init__": {"tf": 
1}, "shimmer.modules.global_workspace.GlobalWorkspaceBayesian": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBayesian.__init__": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBayesian.fuse": {"tf": 1}, "shimmer.modules.losses.GWLossesBase": {"tf": 1}, "shimmer.modules.losses.GWLosses.__init__": {"tf": 1.4142135623730951}}, "df": 10, "w": {"docs": {}, "df": 0, "o": {"docs": {}, "df": 0, "r": {"docs": {}, "df": 0, "k": {"docs": {}, "df": 0, "s": {"docs": {}, "df": 0, "p": {"docs": {}, "df": 0, "a": {"docs": {}, "df": 0, "c": {"docs": {}, "df": 0, "e": {"docs": {"shimmer.modules.global_workspace.pretrained_global_workspace": {"tf": 2.23606797749979}, "shimmer.modules.losses.GWLosses2Domains.__init__": {"tf": 1}}, "df": 2, "b": {"docs": {}, "df": 0, "a": {"docs": {}, "df": 0, "s": {"docs": {}, "df": 0, "e": {"docs": {"shimmer.modules.global_workspace.GWPredictionsBase": {"tf": 1}, "shimmer.modules.global_workspace.GWPredictions": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspace2Domains": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspace": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBayesian": {"tf": 1}, "shimmer.modules.global_workspace.pretrained_global_workspace": {"tf": 1}}, "df": 6}}, "y": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "s": {"docs": {}, "df": 0, "i": {"docs": {}, "df": 0, "a": {"docs": {}, "df": 0, "n": {"docs": {"shimmer.modules.losses.GWLossesBayesian.__init__": {"tf": 1}, "shimmer.modules.contrastive_loss.ContrastiveLossBayesianType": {"tf": 1}}, "df": 2}}}}}}}}}}}}}}}}}}}}}}, "w": {"docs": {"shimmer.modules.global_workspace.GWPredictionsBase.states": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.loss_mod": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.workspace_dim": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.encode_and_fuse": {"tf": 1.4142135623730951}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.encode": {"tf": 
1.4142135623730951}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.fuse": {"tf": 1.4142135623730951}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.decode": {"tf": 1.4142135623730951}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.batch_gw_states": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspace2Domains.__init__": {"tf": 2.449489742783178}, "shimmer.modules.global_workspace.GlobalWorkspace.__init__": {"tf": 2.449489742783178}, "shimmer.modules.global_workspace.GlobalWorkspaceBayesian.__init__": {"tf": 2.449489742783178}, "shimmer.modules.global_workspace.pretrained_global_workspace": {"tf": 2.449489742783178}, "shimmer.modules.domain.DomainModule": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBase": {"tf": 1.7320508075688772}, "shimmer.modules.gw_module.GWModuleBase.__init__": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBase.workspace_dim": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBase.encode": {"tf": 1.4142135623730951}, "shimmer.modules.gw_module.GWModuleBase.encode_and_fuse": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBase.decode": {"tf": 1.4142135623730951}, "shimmer.modules.gw_module.GWModule": {"tf": 1}, "shimmer.modules.gw_module.GWModule.__init__": {"tf": 2.23606797749979}, "shimmer.modules.gw_module.GWModule.encode": {"tf": 1}, "shimmer.modules.gw_module.GWModule.decode": {"tf": 1.7320508075688772}, "shimmer.modules.gw_module.GWModuleBayesian.__init__": {"tf": 2.23606797749979}, "shimmer.modules.selection.SelectionBase.update_gw_state": {"tf": 1.7320508075688772}, "shimmer.modules.selection.SingleDomainSelection.forward": {"tf": 1.4142135623730951}, "shimmer.modules.selection.FixedSharedSelection.forward": {"tf": 1.4142135623730951}, "shimmer.modules.selection.KQFixedQSelection": {"tf": 1}, "shimmer.modules.selection.DynamicQueryAttention": {"tf": 1.4142135623730951}, "shimmer.modules.losses.GWLossesBase": {"tf": 1.4142135623730951}, "shimmer.modules.losses.demi_cycle_loss": {"tf": 1.4142135623730951}, 
"shimmer.modules.losses.cycle_loss": {"tf": 1}, "shimmer.modules.losses.translation_loss": {"tf": 1}, "shimmer.modules.losses.contrastive_loss": {"tf": 1}, "shimmer.modules.losses.contrastive_loss_bayesian": {"tf": 1}, "shimmer.modules.losses.GWLosses2Domains.__init__": {"tf": 1}, "shimmer.modules.losses.broadcast_loss": {"tf": 1.4142135623730951}, "shimmer.modules.losses.GWLosses.__init__": {"tf": 1}, "shimmer.modules.losses.GWLossesBayesian.__init__": {"tf": 1}, "shimmer.modules.losses.GWLossesBayesian.domain_mods": {"tf": 1}, "shimmer.modules.utils.translation": {"tf": 1}, "shimmer.modules.utils.cycle": {"tf": 1.7320508075688772}, "shimmer.modules.utils.batch_demi_cycles": {"tf": 1}, "shimmer.modules.utils.batch_cycles": {"tf": 1}, "shimmer.modules.utils.batch_translations": {"tf": 1}}, "df": 45, "m": {"docs": {}, "df": 0, "o": {"docs": {}, "df": 0, "d": {"docs": {}, "df": 0, "u": {"docs": {}, "df": 0, "l": {"docs": {}, "df": 0, "e": {"docs": {"shimmer.modules.gw_module.GWModuleBase": {"tf": 1.7320508075688772}, "shimmer.modules.gw_module.GWModuleBase.__init__": {"tf": 1}, "shimmer.modules.gw_module.GWModule.__init__": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBayesian": {"tf": 1}, "shimmer.modules.losses.demi_cycle_loss": {"tf": 1}, "shimmer.modules.losses.cycle_loss": {"tf": 1}, "shimmer.modules.losses.translation_loss": {"tf": 1}, "shimmer.modules.losses.contrastive_loss": {"tf": 1}, "shimmer.modules.losses.contrastive_loss_bayesian": {"tf": 1}, "shimmer.modules.losses.GWLosses2Domains": {"tf": 1}, "shimmer.modules.losses.GWLosses2Domains.__init__": {"tf": 1.4142135623730951}, "shimmer.modules.losses.broadcast_loss": {"tf": 1}, "shimmer.modules.losses.GWLosses.__init__": {"tf": 1}, "shimmer.modules.losses.GWLossesBayesian.__init__": {"tf": 1}, "shimmer.modules.losses.GWLossesBayesian.gw_mod": {"tf": 1}, "shimmer.modules.utils.translation": {"tf": 1}, "shimmer.modules.utils.cycle": {"tf": 1}, "shimmer.modules.utils.batch_cycles": {"tf": 1}, 
"shimmer.modules.utils.batch_translations": {"tf": 1}}, "df": 19, "b": {"docs": {}, "df": 0, "a": {"docs": {}, "df": 0, "s": {"docs": {}, "df": 0, "e": {"docs": {"shimmer.modules.global_workspace.GlobalWorkspaceBase.gw_mod": {"tf": 1}, "shimmer.modules.gw_module.GWModule": {"tf": 1}, "shimmer.modules.losses.demi_cycle_loss": {"tf": 1}, "shimmer.modules.losses.cycle_loss": {"tf": 1}, "shimmer.modules.losses.translation_loss": {"tf": 1}, "shimmer.modules.losses.contrastive_loss": {"tf": 1}, "shimmer.modules.losses.broadcast_loss": {"tf": 1}, "shimmer.modules.utils.translation": {"tf": 1}, "shimmer.modules.utils.cycle": {"tf": 1}, "shimmer.modules.utils.batch_demi_cycles": {"tf": 1.4142135623730951}, "shimmer.modules.utils.batch_cycles": {"tf": 1}, "shimmer.modules.utils.batch_translations": {"tf": 1}}, "df": 12}}, "y": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "s": {"docs": {}, "df": 0, "i": {"docs": {}, "df": 0, "a": {"docs": {}, "df": 0, "n": {"docs": {"shimmer.modules.gw_module.GWModuleBayesian.__init__": {"tf": 1}, "shimmer.modules.losses.contrastive_loss_bayesian": {"tf": 1}, "shimmer.modules.losses.GWLossesBayesian": {"tf": 1}, "shimmer.modules.losses.GWLossesBayesian.__init__": {"tf": 1}}, "df": 4}}}}}}}}, "s": {"docs": {"shimmer.modules.gw_module.GWDecoder": {"tf": 1}, "shimmer.modules.gw_module.GWEncoder": {"tf": 1}, "shimmer.modules.gw_module.GWEncoderLinear": {"tf": 1}}, "df": 3}}}}}}}, "p": {"docs": {}, "df": 0, "r": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "d": {"docs": {}, "df": 0, "i": {"docs": {}, "df": 0, "c": {"docs": {}, "df": 0, "t": {"docs": {}, "df": 0, "i": {"docs": {}, "df": 0, "o": {"docs": {}, "df": 0, "n": {"docs": {}, "df": 0, "s": {"docs": {"shimmer.modules.global_workspace.GlobalWorkspace2Domains.forward": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspace.forward": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBayesian.forward": {"tf": 1}}, "df": 3}}}}}}}}}}}, "l": {"docs": {}, "df": 0, "o": 
{"docs": {}, "df": 0, "s": {"docs": {}, "df": 0, "s": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "s": {"docs": {"shimmer.modules.losses.LossCoefs": {"tf": 1}, "shimmer.modules.losses.GWLosses2Domains.step": {"tf": 2}}, "df": 2, "b": {"docs": {}, "df": 0, "a": {"docs": {}, "df": 0, "s": {"docs": {}, "df": 0, "e": {"docs": {"shimmer.modules.losses.GWLosses2Domains": {"tf": 1}, "shimmer.modules.losses.GWLosses": {"tf": 1}, "shimmer.modules.losses.GWLossesBayesian": {"tf": 1}}, "df": 3}}}}, "f": {"docs": {}, "df": 0, "u": {"docs": {}, "df": 0, "s": {"docs": {}, "df": 0, "i": {"docs": {}, "df": 0, "o": {"docs": {}, "df": 0, "n": {"docs": {"shimmer.modules.losses.BroadcastLossCoefs": {"tf": 1}}, "df": 1}}}}}}}}}}}}}, "a": {"docs": {}, "df": 0, "u": {"docs": {}, "df": 0, "s": {"docs": {}, "df": 0, "s": {"docs": {}, "df": 0, "i": {"docs": {}, "df": 0, "a": {"docs": {}, "df": 0, "n": {"docs": {"shimmer.modules.vae.gaussian_nll": {"tf": 1.4142135623730951}}, "df": 1}}}}}}}}, "x": {"docs": {"shimmer.types.RawDomainGroupT": {"tf": 1.4142135623730951}, "shimmer.types.LatentsDomainGroupT": {"tf": 1.4142135623730951}, "shimmer.types.LatentsDomainGroupsT": {"tf": 1.4142135623730951}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.encode_and_fuse": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.encode": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.fuse": {"tf": 1}, "shimmer.modules.domain.DomainModule.encode": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBase.fuse": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBase.encode": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBase.encode_and_fuse": {"tf": 1}, "shimmer.modules.gw_module.GWModule.fuse": {"tf": 1}, "shimmer.modules.gw_module.GWModule.encode": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBayesian.get_precision": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBayesian.fuse": {"tf": 1}, "shimmer.modules.contrastive_loss.info_nce": {"tf": 1}, 
"shimmer.modules.contrastive_loss.contrastive_loss": {"tf": 1}, "shimmer.modules.contrastive_loss.ContrastiveLoss.forward": {"tf": 1}, "shimmer.modules.vae.gaussian_nll": {"tf": 1}, "shimmer.modules.vae.VAEEncoder.forward": {"tf": 1}, "shimmer.modules.vae.VAEDecoder.forward": {"tf": 1}, "shimmer.modules.vae.VAE.encode": {"tf": 1}, "shimmer.modules.vae.VAE.forward": {"tf": 1.4142135623730951}, "shimmer.modules.utils.translation": {"tf": 1}, "shimmer.modules.utils.cycle": {"tf": 1}}, "df": 24, "i": {"docs": {"shimmer.modules.gw_module.GWModuleBayesian.fuse": {"tf": 1.7320508075688772}}, "df": 1}}, "q": {"docs": {}, "df": 0, "u": {"docs": {}, "df": 0, "o": {"docs": {}, "df": 0, "t": {"docs": {"shimmer.types.RawDomainGroupT": {"tf": 2.8284271247461903}, "shimmer.types.RawDomainGroupDT": {"tf": 2.8284271247461903}, "shimmer.types.LatentsDomainGroupT": {"tf": 2}, "shimmer.types.LatentsDomainGroupDT": {"tf": 2}, "shimmer.types.RawDomainGroupsT": {"tf": 4.898979485566356}, "shimmer.types.RawDomainGroupsDT": {"tf": 4.898979485566356}, "shimmer.types.LatentsDomainGroupsT": {"tf": 4}, "shimmer.types.LatentsDomainGroupsDT": {"tf": 4}, "shimmer.modules.selection.SelectionBase.forward": {"tf": 2.8284271247461903}}, "df": 9}}, "a": {"docs": {}, "df": 0, "d": {"docs": {}, "df": 0, "r": {"docs": {}, "df": 0, "a": {"docs": {}, "df": 0, "t": {"docs": {}, "df": 0, "u": {"docs": {}, "df": 0, "r": {"docs": {}, "df": 0, "e": {"docs": {"shimmer.modules.gw_module.compute_fusion_scores": {"tf": 1}}, "df": 1}}}}}}}}, "e": {"docs": {}, "df": 0, "r": {"docs": {}, "df": 0, "y": {"docs": {"shimmer.modules.selection.KQFixedQSelection": {"tf": 1}, "shimmer.modules.selection.KQFixedQSelection.__init__": {"tf": 1}, "shimmer.modules.selection.KQFixedQSelection.forward": {"tf": 1.4142135623730951}, "shimmer.modules.selection.DynamicQueryAttention": {"tf": 1.4142135623730951}, "shimmer.modules.selection.DynamicQueryAttention.__init__": {"tf": 1}, 
"shimmer.modules.selection.DynamicQueryAttention.forward": {"tf": 1.4142135623730951}}, "df": 6}, "i": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "s": {"docs": {"shimmer.modules.selection.KQFixedQSelection.forward": {"tf": 1}, "shimmer.modules.selection.DynamicQueryAttention.forward": {"tf": 1}}, "df": 2}}}}}}}, "p": {"docs": {}, "df": 0, "i": {"docs": {}, "df": 0, "l": {"docs": {"shimmer.types.RawDomainGroupT": {"tf": 1}, "shimmer.types.RawDomainGroupDT": {"tf": 1}, "shimmer.types.RawDomainGroupsT": {"tf": 1.4142135623730951}, "shimmer.types.RawDomainGroupsDT": {"tf": 1.4142135623730951}}, "df": 4}, "c": {"docs": {}, "df": 0, "t": {"docs": {}, "df": 0, "u": {"docs": {}, "df": 0, "r": {"docs": {}, "df": 0, "e": {"docs": {"shimmer.types.RawDomainGroupT": {"tf": 1}, "shimmer.types.RawDomainGroupDT": {"tf": 1}, "shimmer.types.RawDomainGroupsT": {"tf": 1.4142135623730951}, "shimmer.types.RawDomainGroupsDT": {"tf": 1.4142135623730951}}, "df": 4}}}}}}, "a": {"docs": {}, "df": 0, "t": {"docs": {}, "df": 0, "h": {"docs": {"shimmer.modules.global_workspace.pretrained_global_workspace": {"tf": 1.7320508075688772}, "shimmer.utils.migrate_model": {"tf": 1.4142135623730951}, "shimmer.cli.ckpt_migration.migrate_ckpt": {"tf": 1.7320508075688772}}, "df": 3, "/": {"docs": {}, "df": 0, "t": {"docs": {}, "df": 0, "o": {"docs": {}, "df": 0, "/": {"docs": {}, "df": 0, "d": {"docs": {}, "df": 0, "o": {"docs": {}, "df": 0, "g": {"docs": {}, "df": 0, "/": {"docs": {}, "df": 0, "p": {"docs": {}, "df": 0, "i": {"docs": {}, "df": 0, "c": {"docs": {}, "df": 0, "t": {"docs": {}, "df": 0, "u": {"docs": {}, "df": 0, "r": {"docs": {}, "df": 0, "e": {"docs": {"shimmer.types.RawDomainGroupT": {"tf": 1}, "shimmer.types.RawDomainGroupDT": {"tf": 1}, "shimmer.types.RawDomainGroupsT": {"tf": 1}, "shimmer.types.RawDomainGroupsDT": {"tf": 1}}, "df": 4}}}}}}}}}}}, "c": {"docs": {}, "df": 0, "a": {"docs": {}, "df": 0, "t": {"docs": {}, "df": 0, "/": {"docs": {}, "df": 0, "p": {"docs": {}, "df": 0, 
"i": {"docs": {}, "df": 0, "c": {"docs": {}, "df": 0, "t": {"docs": {}, "df": 0, "u": {"docs": {}, "df": 0, "r": {"docs": {}, "df": 0, "e": {"docs": {"shimmer.types.RawDomainGroupsT": {"tf": 1}, "shimmer.types.RawDomainGroupsDT": {"tf": 1}}, "df": 2}}}}}}}}}}}}}}}, "l": {"docs": {}, "df": 0, "i": {"docs": {}, "df": 0, "k": {"docs": {}, "df": 0, "e": {"docs": {"shimmer.utils.migrate_model": {"tf": 1}}, "df": 1}}}}, "s": {"docs": {"shimmer.cli.ckpt_migration.migrate_ckpt": {"tf": 1.4142135623730951}}, "df": 1}}}, "s": {"docs": {}, "df": 0, "s": {"docs": {"shimmer.modules.global_workspace.pretrained_global_workspace": {"tf": 1}, "shimmer.modules.gw_module.GWEncoder.forward": {"tf": 1}, "shimmer.modules.gw_module.GWEncoderLinear.forward": {"tf": 1}, "shimmer.modules.selection.SelectionBase.forward": {"tf": 1}, "shimmer.modules.selection.SingleDomainSelection.forward": {"tf": 1}, "shimmer.modules.selection.FixedSharedSelection.forward": {"tf": 1}}, "df": 6, "i": {"docs": {}, "df": 0, "n": {"docs": {}, "df": 0, "g": {"docs": {"shimmer.types.RawDomainGroupT": {"tf": 1}, "shimmer.types.LatentsDomainGroupT": {"tf": 1}, "shimmer.types.LatentsDomainGroupsT": {"tf": 1}}, "df": 3}}}, "e": {"docs": {}, "df": 0, "d": {"docs": {"shimmer.modules.global_workspace.SchedulerArgs": {"tf": 1}}, "df": 1}}}}, "r": {"docs": {}, "df": 0, "t": {"docs": {"shimmer.modules.losses.LossCoefs": {"tf": 1}, "shimmer.modules.losses.BroadcastLossCoefs": {"tf": 1}}, "df": 2, "i": {"docs": {}, "df": 0, "c": {"docs": {}, "df": 0, "u": {"docs": {}, "df": 0, "l": {"docs": {}, "df": 0, "a": {"docs": {}, "df": 0, "r": {"docs": {"shimmer.modules.losses.demi_cycle_loss": {"tf": 1}, "shimmer.modules.losses.cycle_loss": {"tf": 1}}, "df": 2}}}}}, "t": {"docs": {}, "df": 0, "i": {"docs": {}, "df": 0, "o": {"docs": {}, "df": 0, "n": {"docs": {"shimmer.modules.losses.generate_partitions": {"tf": 1.7320508075688772}}, "df": 1, "s": {"docs": {"shimmer.modules.losses.generate_partitions": {"tf": 1.4142135623730951}}, 
"df": 1}}}}}}}, "a": {"docs": {}, "df": 0, "m": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "t": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "r": {"docs": {"shimmer.modules.contrastive_loss.ContrastiveLoss.__init__": {"tf": 1}}, "df": 1, "s": {"docs": {"shimmer.modules.vae.reparameterize": {"tf": 1}}, "df": 1}}}}}}}}}, "n": {"docs": {}, "df": 0, "g": {"docs": {"shimmer.types.RawDomainGroupT": {"tf": 1}, "shimmer.types.RawDomainGroupDT": {"tf": 1}, "shimmer.types.RawDomainGroupsT": {"tf": 1.4142135623730951}, "shimmer.types.RawDomainGroupsDT": {"tf": 1.4142135623730951}}, "df": 4}}, "y": {"docs": {}, "df": 0, "t": {"docs": {}, "df": 0, "o": {"docs": {}, "df": 0, "r": {"docs": {}, "df": 0, "c": {"docs": {}, "df": 0, "h": {"docs": {"shimmer.types.ModelModeT": {"tf": 1}, "shimmer.utils.SaveMigrations.on_save_checkpoint": {"tf": 1.4142135623730951}}, "df": 2}}}}, "h": {"docs": {}, "df": 0, "o": {"docs": {}, "df": 0, "n": {"docs": {"shimmer.modules.domain.LossOutput": {"tf": 1}}, "df": 1}}}}}, "r": {"docs": {}, "df": 0, "e": {"docs": {"shimmer.modules.global_workspace.GlobalWorkspaceBase.encode": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.fuse": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspace2Domains.__init__": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspace.__init__": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBayesian.__init__": {"tf": 1}, "shimmer.modules.global_workspace.pretrained_global_workspace": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBase.encode": {"tf": 1.4142135623730951}, "shimmer.modules.gw_module.GWModule.__init__": {"tf": 1}, "shimmer.modules.gw_module.GWModule.encode": {"tf": 1.4142135623730951}, "shimmer.modules.gw_module.GWModuleBayesian.__init__": {"tf": 1}, "shimmer.modules.selection.KQFixedQSelection.forward": {"tf": 1}, "shimmer.modules.selection.DynamicQueryAttention.forward": {"tf": 1}}, "df": 12, "d": {"docs": {"shimmer.modules.domain.DomainModule.compute_loss": 
{"tf": 1}, "shimmer.modules.domain.DomainModule.compute_dcy_loss": {"tf": 1}, "shimmer.modules.domain.DomainModule.compute_cy_loss": {"tf": 1}, "shimmer.modules.domain.DomainModule.compute_tr_loss": {"tf": 1}, "shimmer.modules.domain.DomainModule.compute_broadcast_loss": {"tf": 1}}, "df": 5, "i": {"docs": {}, "df": 0, "c": {"docs": {}, "df": 0, "t": {"docs": {"shimmer.modules.global_workspace.GWPredictionsBase": {"tf": 1}, "shimmer.modules.global_workspace.GWPredictions": {"tf": 1}}, "df": 2, "i": {"docs": {}, "df": 0, "o": {"docs": {}, "df": 0, "n": {"docs": {"shimmer.modules.global_workspace.GlobalWorkspaceBayesian": {"tf": 1}, "shimmer.modules.domain.DomainModule.compute_loss": {"tf": 1}, "shimmer.modules.domain.DomainModule.compute_dcy_loss": {"tf": 1}, "shimmer.modules.domain.DomainModule.compute_cy_loss": {"tf": 1}, "shimmer.modules.domain.DomainModule.compute_tr_loss": {"tf": 1}, "shimmer.modules.domain.DomainModule.compute_broadcast_loss": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBayesian": {"tf": 1}, "shimmer.modules.losses.contrastive_loss_bayesian": {"tf": 1}, "shimmer.modules.losses.GWLossesBayesian.__init__": {"tf": 1}, "shimmer.modules.contrastive_loss.ContrastiveLossType": {"tf": 1}, "shimmer.modules.contrastive_loss.ContrastiveLossBayesianType": {"tf": 1.4142135623730951}, "shimmer.modules.contrastive_loss.info_nce": {"tf": 1}, "shimmer.modules.contrastive_loss.contrastive_loss": {"tf": 1}, "shimmer.modules.contrastive_loss.ContrastiveLoss.forward": {"tf": 1}, "shimmer.modules.vae.VAE.encode": {"tf": 1}}, "df": 15, "s": {"docs": {"shimmer.modules.global_workspace.GWPredictions.demi_cycles": {"tf": 1}, "shimmer.modules.global_workspace.GWPredictions.cycles": {"tf": 1}, "shimmer.modules.global_workspace.GWPredictions.translations": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspace2Domains.forward": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspace.forward": {"tf": 1}, 
"shimmer.modules.global_workspace.GlobalWorkspaceBayesian.forward": {"tf": 1}, "shimmer.modules.vae.gaussian_nll": {"tf": 1}, "shimmer.modules.utils.batch_demi_cycles": {"tf": 1}, "shimmer.modules.utils.batch_cycles": {"tf": 1}, "shimmer.modules.utils.batch_translations": {"tf": 1}}, "df": 10}}}, "n": {"docs": {}, "df": 0, "g": {"docs": {"shimmer.modules.gw_module.GWModuleBayesian.fuse": {"tf": 1}}, "df": 1}}}, "e": {"docs": {}, "df": 0, "d": {"docs": {"shimmer.modules.vae.reparameterize": {"tf": 1.4142135623730951}, "shimmer.modules.vae.kl_divergence_loss": {"tf": 1.4142135623730951}}, "df": 2}}}}}}, "c": {"docs": {}, "df": 0, "i": {"docs": {}, "df": 0, "s": {"docs": {}, "df": 0, "i": {"docs": {}, "df": 0, "o": {"docs": {}, "df": 0, "n": {"docs": {"shimmer.modules.global_workspace.GlobalWorkspaceBayesian.__init__": {"tf": 2}, "shimmer.modules.gw_module.compute_fusion_scores": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBayesian.__init__": {"tf": 1.7320508075688772}, "shimmer.modules.gw_module.GWModuleBayesian.precisions": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBayesian.get_precision": {"tf": 1.4142135623730951}, "shimmer.modules.gw_module.GWModuleBayesian.fuse": {"tf": 1.4142135623730951}, "shimmer.modules.losses.GWLossesBayesian.__init__": {"tf": 1}}, "df": 7}}}}}}, "t": {"docs": {}, "df": 0, "r": {"docs": {}, "df": 0, "a": {"docs": {}, "df": 0, "i": {"docs": {}, "df": 0, "n": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "d": {"docs": {"shimmer.modules.global_workspace.pretrained_global_workspace": {"tf": 1}}, "df": 1}}}}}}}, "v": {"docs": {}, "df": 0, "i": {"docs": {}, "df": 0, "o": {"docs": {}, "df": 0, "u": {"docs": {}, "df": 0, "s": {"docs": {"shimmer.modules.selection.SelectionBase.update_gw_state": {"tf": 1.4142135623730951}, "shimmer.modules.selection.SingleDomainSelection.forward": {"tf": 1}, "shimmer.modules.selection.FixedSharedSelection.forward": {"tf": 1}}, "df": 3}}}}}}, "o": {"docs": {}, "df": 0, "x": {"docs": {}, "df": 0, "y": {"docs": 
{"shimmer.modules.global_workspace.GlobalWorkspaceBase.encode_domain": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.decode_domain": {"tf": 1}}, "df": 2}}, "d": {"docs": {}, "df": 0, "u": {"docs": {}, "df": 0, "c": {"docs": {}, "df": 0, "t": {"docs": {"shimmer.modules.selection.KQFixedQSelection.forward": {"tf": 1}, "shimmer.modules.selection.DynamicQueryAttention.forward": {"tf": 1}}, "df": 2}}}}, "v": {"docs": {}, "df": 0, "i": {"docs": {}, "df": 0, "d": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "d": {"docs": {"shimmer.modules.losses.demi_cycle_loss": {"tf": 1}, "shimmer.modules.losses.cycle_loss": {"tf": 1}, "shimmer.modules.losses.translation_loss": {"tf": 1}, "shimmer.modules.losses.contrastive_loss": {"tf": 1}, "shimmer.modules.losses.contrastive_loss_bayesian": {"tf": 1}, "shimmer.modules.losses.LossCoefs": {"tf": 1}, "shimmer.modules.losses.BroadcastLossCoefs": {"tf": 1}, "shimmer.modules.vae.reparameterize": {"tf": 1}, "shimmer.modules.utils.translation": {"tf": 1}}, "df": 9}}}}}}}, "e": {"docs": {}, "df": 0, "r": {"docs": {}, "df": 0, "f": {"docs": {}, "df": 0, "o": {"docs": {}, "df": 0, "r": {"docs": {}, "df": 0, "m": {"docs": {"shimmer.modules.utils.translation": {"tf": 1}, "shimmer.modules.utils.cycle": {"tf": 1}}, "df": 2, "e": {"docs": {}, "df": 0, "d": {"docs": {"shimmer.modules.gw_module.GWEncoder.forward": {"tf": 1}, "shimmer.modules.gw_module.GWEncoderLinear.forward": {"tf": 1}}, "df": 2}}, "s": {"docs": {"shimmer.modules.losses.GWLosses.step": {"tf": 1}, "shimmer.modules.losses.GWLossesBayesian.step": {"tf": 1}}, "df": 2}}}}}}}, "o": {"docs": {}, "df": 0, "s": {"docs": {}, "df": 0, "s": {"docs": {}, "df": 0, "i": {"docs": {}, "df": 0, "b": {"docs": {}, "df": 0, "l": {"docs": {}, "df": 0, "e": {"docs": {"shimmer.modules.losses.generate_partitions": {"tf": 1}}, "df": 1}}}}}}, "i": {"docs": {}, "df": 0, "n": {"docs": {}, "df": 0, "t": {"docs": {"shimmer.cli.ckpt_migration.migrate_ckpt": {"tf": 1}}, "df": 1}}}}, "l": 
{"docs": {"shimmer.utils.SaveMigrations.on_save_checkpoint": {"tf": 1}}, "df": 1}}, "l": {"docs": {}, "df": 0, "a": {"docs": {}, "df": 0, "n": {"docs": {}, "df": 0, "g": {"docs": {}, "df": 0, "u": {"docs": {}, "df": 0, "a": {"docs": {}, "df": 0, "g": {"docs": {}, "df": 0, "e": {"docs": {"shimmer.types.RawDomainGroupT": {"tf": 1}, "shimmer.types.RawDomainGroupDT": {"tf": 1}, "shimmer.types.LatentsDomainGroupT": {"tf": 1}, "shimmer.types.LatentsDomainGroupDT": {"tf": 1}, "shimmer.types.RawDomainGroupsT": {"tf": 2}, "shimmer.types.RawDomainGroupsDT": {"tf": 2}, "shimmer.types.LatentsDomainGroupsT": {"tf": 2}, "shimmer.types.LatentsDomainGroupsDT": {"tf": 2}}, "df": 8}}}}}}, "t": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "n": {"docs": {}, "df": 0, "t": {"docs": {"shimmer.types.LatentsDomainGroupT": {"tf": 1.4142135623730951}, "shimmer.types.LatentsDomainGroupDT": {"tf": 1.4142135623730951}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.encode_and_fuse": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.encode": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.fuse": {"tf": 1.4142135623730951}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.batch_gw_states": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.encode_domains": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.decode_domains": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspace2Domains.__init__": {"tf": 1.4142135623730951}, "shimmer.modules.global_workspace.GlobalWorkspace2Domains.forward": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspace.__init__": {"tf": 1.4142135623730951}, "shimmer.modules.global_workspace.GlobalWorkspace.forward": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBayesian.__init__": {"tf": 1.4142135623730951}, "shimmer.modules.global_workspace.GlobalWorkspaceBayesian.forward": {"tf": 1}, "shimmer.modules.global_workspace.pretrained_global_workspace": {"tf": 
1.4142135623730951}, "shimmer.modules.domain.DomainModule.__init__": {"tf": 1.4142135623730951}, "shimmer.modules.domain.DomainModule.latent_dim": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBase.fuse": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBase.encode": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBase.encode_and_fuse": {"tf": 1}, "shimmer.modules.gw_module.GWModule.__init__": {"tf": 1.4142135623730951}, "shimmer.modules.gw_module.GWModule.fuse": {"tf": 1}, "shimmer.modules.gw_module.GWModule.encode": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBayesian.__init__": {"tf": 1.4142135623730951}, "shimmer.modules.gw_module.GWModuleBayesian.fuse": {"tf": 1}, "shimmer.modules.selection.SelectionBase.forward": {"tf": 1}, "shimmer.modules.selection.SingleDomainSelection.forward": {"tf": 1}, "shimmer.modules.selection.FixedSharedSelection.forward": {"tf": 1}, "shimmer.modules.selection.KQFixedQSelection.forward": {"tf": 1}, "shimmer.modules.selection.RandomSelection.forward": {"tf": 1}, "shimmer.modules.selection.DynamicQueryAttention.fuse_weighted_encodings": {"tf": 1}, "shimmer.modules.selection.DynamicQueryAttention.forward": {"tf": 1}, "shimmer.modules.losses.GWLossesBase.step": {"tf": 1}, "shimmer.modules.losses.demi_cycle_loss": {"tf": 1.4142135623730951}, "shimmer.modules.losses.cycle_loss": {"tf": 1.4142135623730951}, "shimmer.modules.losses.translation_loss": {"tf": 1.4142135623730951}, "shimmer.modules.losses.contrastive_loss": {"tf": 1.4142135623730951}, "shimmer.modules.losses.contrastive_loss_bayesian": {"tf": 1.4142135623730951}, "shimmer.modules.losses.GWLosses2Domains.demi_cycle_loss": {"tf": 1.4142135623730951}, "shimmer.modules.losses.GWLosses2Domains.cycle_loss": {"tf": 1.4142135623730951}, "shimmer.modules.losses.GWLosses2Domains.translation_loss": {"tf": 1.4142135623730951}, "shimmer.modules.losses.GWLosses2Domains.contrastive_loss": {"tf": 1.4142135623730951}, "shimmer.modules.losses.GWLosses2Domains.step": {"tf": 1}, 
"shimmer.modules.losses.broadcast_loss": {"tf": 1.4142135623730951}, "shimmer.modules.losses.GWLosses.contrastive_loss": {"tf": 1.7320508075688772}, "shimmer.modules.losses.GWLosses.step": {"tf": 1}, "shimmer.modules.losses.GWLossesBayesian.contrastive_loss": {"tf": 1.4142135623730951}, "shimmer.modules.losses.GWLossesBayesian.step": {"tf": 1}, "shimmer.modules.vae.VAEDecoder.forward": {"tf": 1}, "shimmer.modules.vae.VAE.decode": {"tf": 1.4142135623730951}, "shimmer.modules.utils.translation": {"tf": 1}, "shimmer.modules.utils.cycle": {"tf": 1.4142135623730951}, "shimmer.modules.utils.batch_demi_cycles": {"tf": 1}, "shimmer.modules.utils.batch_cycles": {"tf": 1}, "shimmer.modules.utils.batch_translations": {"tf": 1}}, "df": 55, "s": {"docs": {"shimmer.modules.losses.GWLossesBase.step": {"tf": 1}, "shimmer.modules.losses.GWLosses2Domains.step": {"tf": 1}, "shimmer.modules.losses.GWLosses.step": {"tf": 1}, "shimmer.modules.losses.GWLossesBayesian.step": {"tf": 1}, "shimmer.utils.groups_batch_size": {"tf": 1}, "shimmer.utils.groups_device": {"tf": 1}}, "df": 6, "d": {"docs": {}, "df": 0, "o": {"docs": {}, "df": 0, "m": {"docs": {}, "df": 0, "a": {"docs": {}, "df": 0, "i": {"docs": {}, "df": 0, "n": {"docs": {}, "df": 0, "g": {"docs": {}, "df": 0, "r": {"docs": {}, "df": 0, "o": {"docs": {}, "df": 0, "u": {"docs": {}, "df": 0, "p": {"docs": {}, "df": 0, "t": {"docs": {"shimmer.types.LatentsDomainGroupT": {"tf": 1}, "shimmer.types.LatentsDomainGroupDT": {"tf": 1}, "shimmer.types.LatentsDomainGroupsT": {"tf": 1}, "shimmer.types.LatentsDomainGroupsDT": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBase.fuse": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBase.encode": {"tf": 1.4142135623730951}, "shimmer.modules.gw_module.GWModuleBase.encode_and_fuse": {"tf": 1}, "shimmer.modules.gw_module.GWModule.fuse": {"tf": 1}, "shimmer.modules.gw_module.GWModule.encode": {"tf": 1.4142135623730951}, "shimmer.modules.gw_module.GWModuleBayesian.fuse": {"tf": 1}, 
"shimmer.modules.selection.SelectionBase.forward": {"tf": 1}, "shimmer.modules.selection.SingleDomainSelection.forward": {"tf": 1}, "shimmer.modules.selection.FixedSharedSelection.forward": {"tf": 1}, "shimmer.modules.selection.KQFixedQSelection.forward": {"tf": 1.4142135623730951}, "shimmer.modules.selection.RandomSelection.forward": {"tf": 1}, "shimmer.modules.selection.DynamicQueryAttention.fuse_weighted_encodings": {"tf": 1}, "shimmer.modules.selection.DynamicQueryAttention.forward": {"tf": 1.4142135623730951}, "shimmer.modules.utils.translation": {"tf": 1}, "shimmer.modules.utils.cycle": {"tf": 1}}, "df": 19}, "d": {"docs": {}, "df": 0, "t": {"docs": {"shimmer.types.LatentsDomainGroupT": {"tf": 1}, "shimmer.types.LatentsDomainGroupDT": {"tf": 1}, "shimmer.types.LatentsDomainGroupsDT": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBase.decode": {"tf": 1}, "shimmer.modules.gw_module.GWModule.decode": {"tf": 1}, "shimmer.modules.utils.cycle": {"tf": 1}}, "df": 6}}, "s": {"docs": {}, "df": 0, "t": {"docs": {"shimmer.types.LatentsDomainGroupsT": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.encode_and_fuse": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.encode": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.fuse": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.decode_domains": {"tf": 1}, "shimmer.modules.losses.GWLossesBase.step": {"tf": 1}, "shimmer.modules.losses.demi_cycle_loss": {"tf": 1}, "shimmer.modules.losses.cycle_loss": {"tf": 1}, "shimmer.modules.losses.translation_loss": {"tf": 1}, "shimmer.modules.losses.contrastive_loss": {"tf": 1}, "shimmer.modules.losses.contrastive_loss_bayesian": {"tf": 1}, "shimmer.modules.losses.GWLosses2Domains.demi_cycle_loss": {"tf": 1}, "shimmer.modules.losses.GWLosses2Domains.cycle_loss": {"tf": 1}, "shimmer.modules.losses.GWLosses2Domains.translation_loss": {"tf": 1}, "shimmer.modules.losses.GWLosses2Domains.contrastive_loss": {"tf": 1}, 
"shimmer.modules.losses.GWLosses2Domains.step": {"tf": 1}, "shimmer.modules.losses.GWLossesBayesian.contrastive_loss": {"tf": 1}, "shimmer.utils.groups_batch_size": {"tf": 1}, "shimmer.utils.groups_device": {"tf": 1}}, "df": 19}, "d": {"docs": {}, "df": 0, "t": {"docs": {"shimmer.types.LatentsDomainGroupsT": {"tf": 1}, "shimmer.types.LatentsDomainGroupsDT": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.encode_domains": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.decode_domains": {"tf": 1}}, "df": 4}}}}}}}}}}}}}}, "t": {"docs": {"shimmer.modules.global_workspace.GlobalWorkspaceBase.batch_gw_states": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspace2Domains.forward": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspace.forward": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBayesian.forward": {"tf": 1}, "shimmer.modules.utils.batch_demi_cycles": {"tf": 1}, "shimmer.modules.utils.batch_cycles": {"tf": 1}, "shimmer.modules.utils.batch_translations": {"tf": 1}}, "df": 7}}}, "s": {"docs": {}, "df": 0, "d": {"docs": {}, "df": 0, "o": {"docs": {}, "df": 0, "m": {"docs": {}, "df": 0, "a": {"docs": {}, "df": 0, "i": {"docs": {}, "df": 0, "n": {"docs": {}, "df": 0, "g": {"docs": {}, "df": 0, "r": {"docs": {}, "df": 0, "o": {"docs": {}, "df": 0, "u": {"docs": {}, "df": 0, "p": {"docs": {}, "df": 0, "s": {"docs": {}, "df": 0, "d": {"docs": {}, "df": 0, "t": {"docs": {"shimmer.modules.global_workspace.GlobalWorkspaceBase.encode": {"tf": 1}}, "df": 1}}}}}}}}}}}}}}}}}, "t": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "r": {"docs": {"shimmer.modules.gw_module.GWEncoder.forward": {"tf": 1}, "shimmer.modules.gw_module.GWEncoderLinear.forward": {"tf": 1}}, "df": 2}}}}, "y": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "r": {"docs": {}, "df": 0, "s": {"docs": {"shimmer.modules.gw_module.get_n_layers": {"tf": 2.23606797749979}, "shimmer.modules.gw_module.GWDecoder.__init__": {"tf": 2}, 
"shimmer.modules.gw_module.GWDecoder.n_layers": {"tf": 1.7320508075688772}, "shimmer.modules.gw_module.GWEncoder.__init__": {"tf": 2}}, "df": 4}}}}, "m": {"docs": {}, "df": 0, "b": {"docs": {}, "df": 0, "d": {"docs": {}, "df": 0, "a": {"docs": {"shimmer.modules.gw_module.GWModuleBayesian.fuse": {"tf": 1}}, "df": 1}}}}, "s": {"docs": {}, "df": 0, "t": {"docs": {"shimmer.dataset.RepeatedDataset": {"tf": 1.4142135623730951}, "shimmer.dataset.RepeatedDataset.__init__": {"tf": 1}}, "df": 2}}}, "i": {"docs": {}, "df": 0, "g": {"docs": {}, "df": 0, "h": {"docs": {}, "df": 0, "t": {"docs": {}, "df": 0, "n": {"docs": {}, "df": 0, "i": {"docs": {}, "df": 0, "n": {"docs": {}, "df": 0, "g": {"docs": {"shimmer.types.ModelModeT": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBase": {"tf": 1}, "shimmer.utils.SaveMigrations.on_save_checkpoint": {"tf": 1.4142135623730951}}, "df": 3, "m": {"docs": {}, "df": 0, "o": {"docs": {}, "df": 0, "d": {"docs": {}, "df": 0, "u": {"docs": {}, "df": 0, "l": {"docs": {}, "df": 0, "e": {"docs": {"shimmer.utils.SaveMigrations.on_save_checkpoint": {"tf": 1}}, "df": 1}}}}}}}}}}}}}, "s": {"docs": {}, "df": 0, "t": {"docs": {"shimmer.modules.gw_module.get_n_layers": {"tf": 1.4142135623730951}, "shimmer.modules.selection.KQFixedQSelection.__init__": {"tf": 1}, "shimmer.modules.selection.DynamicQueryAttention.__init__": {"tf": 1}, "shimmer.cli.ckpt_migration.migrate_ckpt": {"tf": 1}}, "df": 4, "[": {"docs": {}, "df": 0, "n": {"docs": {}, "df": 0, "n": {"docs": {"shimmer.modules.gw_module.get_n_layers": {"tf": 1}}, "df": 1}}}}}, "n": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "a": {"docs": {}, "df": 0, "r": {"docs": {"shimmer.modules.gw_module.get_n_layers": {"tf": 1.4142135623730951}, "shimmer.modules.gw_module.GWEncoderLinear": {"tf": 1}}, "df": 2, "i": {"docs": {}, "df": 0, "t": {"docs": {}, "df": 0, "y": {"docs": {"shimmer.modules.gw_module.GWEncoder": {"tf": 1}}, "df": 1}}}}}}, "k": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, 
"d": {"docs": {"shimmer.modules.losses.GWLossesBayesian.domain_mods": {"tf": 1}}, "df": 1}}}}, "k": {"docs": {}, "df": 0, "e": {"docs": {"shimmer.modules.gw_module.GWModuleBase": {"tf": 1}, "shimmer.modules.losses.BroadcastLossCoefs.translations": {"tf": 1}, "shimmer.modules.contrastive_loss.contrastive_loss": {"tf": 1}, "shimmer.modules.contrastive_loss.ContrastiveLoss": {"tf": 1}}, "df": 4}}, "t": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "r": {"docs": {}, "df": 0, "a": {"docs": {}, "df": 0, "l": {"docs": {"shimmer.modules.losses.GWLossesBase.step": {"tf": 1}, "shimmer.modules.contrastive_loss.info_nce": {"tf": 1}, "shimmer.modules.contrastive_loss.contrastive_loss": {"tf": 1}, "shimmer.modules.contrastive_loss.ContrastiveLoss.__init__": {"tf": 1}}, "df": 4}}}}}}, "e": {"docs": {}, "df": 0, "a": {"docs": {}, "df": 0, "r": {"docs": {}, "df": 0, "n": {"docs": {"shimmer.modules.global_workspace.GlobalWorkspace2Domains.__init__": {"tf": 1.7320508075688772}, "shimmer.modules.global_workspace.GlobalWorkspace.__init__": {"tf": 1.7320508075688772}, "shimmer.modules.global_workspace.GlobalWorkspaceBayesian.__init__": {"tf": 1.7320508075688772}, "shimmer.modules.global_workspace.pretrained_global_workspace": {"tf": 1}, "shimmer.modules.contrastive_loss.ContrastiveLoss.__init__": {"tf": 1.4142135623730951}}, "df": 5, "i": {"docs": {}, "df": 0, "n": {"docs": {}, "df": 0, "g": {"docs": {"shimmer.modules.global_workspace.SchedulerArgs.max_lr": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspace2Domains.__init__": {"tf": 1.4142135623730951}, "shimmer.modules.global_workspace.GlobalWorkspace.__init__": {"tf": 1.4142135623730951}, "shimmer.modules.global_workspace.GlobalWorkspaceBayesian.__init__": {"tf": 1.4142135623730951}}, "df": 4}}}}}, "s": {"docs": {}, "df": 0, "t": {"docs": {"shimmer.dataset.RepeatedDataset": {"tf": 1}}, "df": 1}}}, "v": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "l": {"docs": {"shimmer.modules.gw_module.GWModuleBayesian.precisions": 
{"tf": 1}}, "df": 1}}}, "n": {"docs": {"shimmer.dataset.RepeatedDataset": {"tf": 1}, "shimmer.dataset.RepeatedDataset.__init__": {"tf": 1}}, "df": 2}}, "o": {"docs": {}, "df": 0, "s": {"docs": {}, "df": 0, "s": {"docs": {"shimmer.modules.global_workspace.GlobalWorkspaceBase.generic_step": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspace2Domains.__init__": {"tf": 2.449489742783178}, "shimmer.modules.global_workspace.GlobalWorkspace": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspace.__init__": {"tf": 2.449489742783178}, "shimmer.modules.global_workspace.GlobalWorkspaceBayesian.__init__": {"tf": 2.6457513110645907}, "shimmer.modules.global_workspace.pretrained_global_workspace": {"tf": 2}, "shimmer.modules.domain.LossOutput": {"tf": 1}, "shimmer.modules.domain.LossOutput.loss": {"tf": 1}, "shimmer.modules.domain.LossOutput.all": {"tf": 1.4142135623730951}, "shimmer.modules.domain.DomainModule.compute_loss": {"tf": 1.4142135623730951}, "shimmer.modules.domain.DomainModule.compute_dcy_loss": {"tf": 2}, "shimmer.modules.domain.DomainModule.compute_cy_loss": {"tf": 2}, "shimmer.modules.domain.DomainModule.compute_tr_loss": {"tf": 2}, "shimmer.modules.domain.DomainModule.compute_broadcast_loss": {"tf": 2}, "shimmer.modules.losses.demi_cycle_loss": {"tf": 1.4142135623730951}, "shimmer.modules.losses.cycle_loss": {"tf": 1.4142135623730951}, "shimmer.modules.losses.translation_loss": {"tf": 1.4142135623730951}, "shimmer.modules.losses.contrastive_loss": {"tf": 1.4142135623730951}, "shimmer.modules.losses.contrastive_loss_bayesian": {"tf": 1.4142135623730951}, "shimmer.modules.losses.LossCoefs": {"tf": 1.7320508075688772}, "shimmer.modules.losses.LossCoefs.demi_cycles": {"tf": 1}, "shimmer.modules.losses.LossCoefs.cycles": {"tf": 1}, "shimmer.modules.losses.LossCoefs.translations": {"tf": 1}, "shimmer.modules.losses.LossCoefs.contrastives": {"tf": 1}, "shimmer.modules.losses.GWLosses2Domains.__init__": {"tf": 2}, 
"shimmer.modules.losses.GWLosses2Domains.demi_cycle_loss": {"tf": 1.4142135623730951}, "shimmer.modules.losses.GWLosses2Domains.cycle_loss": {"tf": 1.4142135623730951}, "shimmer.modules.losses.GWLosses2Domains.translation_loss": {"tf": 1.4142135623730951}, "shimmer.modules.losses.GWLosses2Domains.contrastive_loss": {"tf": 1.4142135623730951}, "shimmer.modules.losses.GWLosses2Domains.step": {"tf": 2}, "shimmer.modules.losses.broadcast_loss": {"tf": 1.4142135623730951}, "shimmer.modules.losses.BroadcastLossCoefs": {"tf": 1.7320508075688772}, "shimmer.modules.losses.BroadcastLossCoefs.contrastives": {"tf": 1}, "shimmer.modules.losses.BroadcastLossCoefs.fused": {"tf": 1}, "shimmer.modules.losses.BroadcastLossCoefs.demi_cycles": {"tf": 1}, "shimmer.modules.losses.BroadcastLossCoefs.cycles": {"tf": 1}, "shimmer.modules.losses.BroadcastLossCoefs.translations": {"tf": 1}, "shimmer.modules.losses.GWLosses.__init__": {"tf": 1.7320508075688772}, "shimmer.modules.losses.GWLosses.contrastive_loss": {"tf": 1.4142135623730951}, "shimmer.modules.losses.GWLosses.step": {"tf": 1.4142135623730951}, "shimmer.modules.losses.GWLossesBayesian.__init__": {"tf": 2.23606797749979}, "shimmer.modules.losses.GWLossesBayesian.loss_coefs": {"tf": 1}, "shimmer.modules.losses.GWLossesBayesian.contrastive_fn": {"tf": 1}, "shimmer.modules.losses.GWLossesBayesian.contrastive_loss": {"tf": 1}, "shimmer.modules.losses.GWLossesBayesian.step": {"tf": 1.4142135623730951}, "shimmer.modules.contrastive_loss": {"tf": 1}, "shimmer.modules.contrastive_loss.ContrastiveLossType": {"tf": 1}, "shimmer.modules.contrastive_loss.ContrastiveLossBayesianType": {"tf": 1}, "shimmer.modules.contrastive_loss.info_nce": {"tf": 1.4142135623730951}, "shimmer.modules.contrastive_loss.contrastive_loss": {"tf": 1.4142135623730951}, "shimmer.modules.contrastive_loss.ContrastiveLoss.__init__": {"tf": 1.4142135623730951}, "shimmer.modules.contrastive_loss.ContrastiveLoss.forward": {"tf": 1.4142135623730951}, 
"shimmer.modules.vae.kl_divergence_loss": {"tf": 1.4142135623730951}, "shimmer.modules.vae.gaussian_nll": {"tf": 1.4142135623730951}}, "df": 54, "e": {"docs": {}, "df": 0, "s": {"docs": {"shimmer.modules.global_workspace.GlobalWorkspaceBase.loss_mod": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspace2Domains.__init__": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspace.__init__": {"tf": 1.4142135623730951}, "shimmer.modules.global_workspace.GlobalWorkspaceBayesian.__init__": {"tf": 1}, "shimmer.modules.global_workspace.pretrained_global_workspace": {"tf": 1}, "shimmer.modules.domain.LossOutput": {"tf": 1}, "shimmer.modules.losses.GWLossesBase": {"tf": 1.7320508075688772}, "shimmer.modules.losses.GWLossesBase.step": {"tf": 1.4142135623730951}, "shimmer.modules.losses.GWLosses2Domains.demi_cycle_loss": {"tf": 1}, "shimmer.modules.losses.GWLosses2Domains.cycle_loss": {"tf": 1}, "shimmer.modules.losses.GWLosses2Domains.translation_loss": {"tf": 1}, "shimmer.modules.losses.GWLosses2Domains.contrastive_loss": {"tf": 1}, "shimmer.modules.losses.GWLosses2Domains.step": {"tf": 1.4142135623730951}, "shimmer.modules.losses.broadcast_loss": {"tf": 1}, "shimmer.modules.losses.GWLosses.__init__": {"tf": 1}}, "df": 15}}, "c": {"docs": {}, "df": 0, "o": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "f": {"docs": {}, "df": 0, "s": {"docs": {"shimmer.modules.global_workspace.GlobalWorkspace2Domains.__init__": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBayesian.__init__": {"tf": 1}, "shimmer.modules.global_workspace.pretrained_global_workspace": {"tf": 1}, "shimmer.modules.losses.GWLosses2Domains.__init__": {"tf": 1.4142135623730951}}, "df": 4}}}}}, "o": {"docs": {}, "df": 0, "u": {"docs": {}, "df": 0, "t": {"docs": {}, "df": 0, "p": {"docs": {}, "df": 0, "u": {"docs": {}, "df": 0, "t": {"docs": {"shimmer.modules.domain.DomainModule.compute_loss": {"tf": 1}, "shimmer.modules.domain.DomainModule.compute_dcy_loss": {"tf": 1}, 
"shimmer.modules.domain.DomainModule.compute_cy_loss": {"tf": 1}, "shimmer.modules.domain.DomainModule.compute_tr_loss": {"tf": 1}, "shimmer.modules.domain.DomainModule.compute_broadcast_loss": {"tf": 1}, "shimmer.modules.losses.GWLossesBase.step": {"tf": 1}, "shimmer.modules.losses.GWLosses2Domains.step": {"tf": 1}, "shimmer.modules.losses.GWLosses.step": {"tf": 1}, "shimmer.modules.losses.GWLossesBayesian.step": {"tf": 1}, "shimmer.modules.contrastive_loss.ContrastiveLossType": {"tf": 1}, "shimmer.modules.contrastive_loss.ContrastiveLossBayesianType": {"tf": 1}, "shimmer.modules.contrastive_loss.ContrastiveLoss.forward": {"tf": 1}}, "df": 12}}}}, "p": {"docs": {}, "df": 0, "u": {"docs": {}, "df": 0, "t": {"docs": {"shimmer.modules.domain.DomainModule.compute_loss": {"tf": 1}, "shimmer.modules.domain.DomainModule.compute_dcy_loss": {"tf": 1}, "shimmer.modules.domain.DomainModule.compute_cy_loss": {"tf": 1}, "shimmer.modules.domain.DomainModule.compute_tr_loss": {"tf": 1}, "shimmer.modules.domain.DomainModule.compute_broadcast_loss": {"tf": 1}}, "df": 5}}}}}}}, "g": {"docs": {"shimmer.modules.domain.LossOutput.metrics": {"tf": 1}, "shimmer.modules.vae.reparameterize": {"tf": 1}, "shimmer.modules.vae.gaussian_nll": {"tf": 1.4142135623730951}, "shimmer.modules.vae.VAEEncoder.forward": {"tf": 1}}, "df": 4, "i": {"docs": {}, "df": 0, "t": {"docs": {"shimmer.modules.global_workspace.GlobalWorkspace2Domains.__init__": {"tf": 1.4142135623730951}, "shimmer.modules.global_workspace.GlobalWorkspace.__init__": {"tf": 1.4142135623730951}, "shimmer.modules.global_workspace.GlobalWorkspaceBayesian.__init__": {"tf": 1.4142135623730951}, "shimmer.modules.global_workspace.pretrained_global_workspace": {"tf": 1}, "shimmer.modules.contrastive_loss.info_nce": {"tf": 1.4142135623730951}, "shimmer.modules.contrastive_loss.contrastive_loss": {"tf": 1.4142135623730951}, "shimmer.modules.contrastive_loss.ContrastiveLoss.__init__": {"tf": 2}, 
"shimmer.modules.contrastive_loss.ContrastiveLoss.forward": {"tf": 1}}, "df": 8, "s": {"docs": {"shimmer.modules.gw_module.GWModuleBayesian.fuse": {"tf": 1}}, "df": 1}}}, "g": {"docs": {}, "df": 0, "i": {"docs": {}, "df": 0, "n": {"docs": {}, "df": 0, "g": {"docs": {"shimmer.modules.domain.LossOutput": {"tf": 1}}, "df": 1}}}, "e": {"docs": {}, "df": 0, "d": {"docs": {"shimmer.modules.losses.LossCoefs": {"tf": 1.4142135623730951}, "shimmer.modules.losses.BroadcastLossCoefs": {"tf": 1.4142135623730951}}, "df": 2}}}, "v": {"docs": {}, "df": 0, "a": {"docs": {}, "df": 0, "r": {"docs": {"shimmer.modules.vae.reparameterize": {"tf": 1}, "shimmer.modules.vae.kl_divergence_loss": {"tf": 1}, "shimmer.modules.vae.VAE.forward": {"tf": 1}}, "df": 3, "s": {"docs": {"shimmer.modules.vae.kl_divergence_loss": {"tf": 1}}, "df": 1}}}}}, "a": {"docs": {}, "df": 0, "d": {"docs": {"shimmer.modules.global_workspace.pretrained_global_workspace": {"tf": 1.4142135623730951}, "shimmer.utils.migrate_model": {"tf": 1.4142135623730951}}, "df": 2, "e": {"docs": {}, "df": 0, "d": {"docs": {"shimmer.modules.global_workspace.pretrained_global_workspace": {"tf": 1}}, "df": 1}}}}}, "r": {"docs": {"shimmer.modules.global_workspace.GlobalWorkspace2Domains.__init__": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspace.__init__": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBayesian.__init__": {"tf": 1}}, "df": 3}, "t": {"docs": {"shimmer.dataset.RepeatedDataset": {"tf": 1}}, "df": 1}}, "c": {"docs": {"shimmer.modules.global_workspace.GlobalWorkspaceBayesian.__init__": {"tf": 1.4142135623730951}, "shimmer.modules.gw_module.GWModuleBayesian.__init__": {"tf": 1.4142135623730951}, "shimmer.modules.gw_module.GWModuleBayesian.fuse": {"tf": 2.449489742783178}}, "df": 3, "o": {"docs": {}, "df": 0, "l": {"docs": {}, "df": 0, "l": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "c": {"docs": {}, "df": 0, "t": {"docs": {}, "df": 0, "i": {"docs": {}, "df": 0, "o": {"docs": {}, "df": 0, "n": 
{"docs": {}, "df": 0, "s": {"docs": {"shimmer.types.RawDomainGroupT": {"tf": 1}, "shimmer.types.LatentsDomainGroupT": {"tf": 1}, "shimmer.types.LatentsDomainGroupsT": {"tf": 1}}, "df": 3}}}}}}}}}, "n": {"docs": {}, "df": 0, "t": {"docs": {"shimmer.modules.global_workspace.GlobalWorkspaceBayesian.__init__": {"tf": 1}, "shimmer.modules.losses.contrastive_loss": {"tf": 1}, "shimmer.modules.losses.contrastive_loss_bayesian": {"tf": 1}, "shimmer.modules.losses.GWLossesBayesian.__init__": {"tf": 1}}, "df": 4, "a": {"docs": {}, "df": 0, "i": {"docs": {}, "df": 0, "n": {"docs": {}, "df": 0, "s": {"docs": {"shimmer.types.RawDomainGroupsT": {"tf": 1}, "shimmer.types.RawDomainGroupsDT": {"tf": 1}, "shimmer.types.LatentsDomainGroupsT": {"tf": 1}, "shimmer.types.LatentsDomainGroupsDT": {"tf": 1}, "shimmer.modules.losses.GWLosses2Domains.step": {"tf": 1}, "shimmer.modules.contrastive_loss.ContrastiveLoss.forward": {"tf": 1}, "shimmer.modules.vae.VAE.forward": {"tf": 1}}, "df": 7}, "i": {"docs": {}, "df": 0, "n": {"docs": {}, "df": 0, "g": {"docs": {"shimmer.modules.losses.GWLosses.step": {"tf": 1}, "shimmer.modules.losses.GWLossesBayesian.step": {"tf": 1}}, "df": 2}}}}}}, "r": {"docs": {}, "df": 0, "a": {"docs": {}, "df": 0, "s": {"docs": {}, "df": 0, "t": {"docs": {}, "df": 0, "i": {"docs": {}, "df": 0, "v": {"docs": {}, "df": 0, "e": {"docs": {"shimmer.modules.global_workspace.GlobalWorkspace2Domains.__init__": {"tf": 2.449489742783178}, "shimmer.modules.global_workspace.GlobalWorkspace.__init__": {"tf": 2.449489742783178}, "shimmer.modules.global_workspace.GlobalWorkspaceBayesian.__init__": {"tf": 2.449489742783178}, "shimmer.modules.global_workspace.pretrained_global_workspace": {"tf": 1.7320508075688772}, "shimmer.modules.losses.GWLossesBase": {"tf": 1}, "shimmer.modules.losses.contrastive_loss": {"tf": 2.6457513110645907}, "shimmer.modules.losses.contrastive_loss_bayesian": {"tf": 2.6457513110645907}, "shimmer.modules.losses.LossCoefs.contrastives": {"tf": 1}, 
"shimmer.modules.losses.GWLosses2Domains.__init__": {"tf": 1.7320508075688772}, "shimmer.modules.losses.GWLosses2Domains.contrastive_loss": {"tf": 1.4142135623730951}, "shimmer.modules.losses.GWLosses2Domains.step": {"tf": 1.4142135623730951}, "shimmer.modules.losses.BroadcastLossCoefs.contrastives": {"tf": 1}, "shimmer.modules.losses.GWLosses.__init__": {"tf": 1.4142135623730951}, "shimmer.modules.losses.GWLosses.contrastive_loss": {"tf": 1.4142135623730951}, "shimmer.modules.losses.GWLossesBayesian.__init__": {"tf": 1.7320508075688772}, "shimmer.modules.losses.GWLossesBayesian.contrastive_fn": {"tf": 1}, "shimmer.modules.losses.GWLossesBayesian.contrastive_loss": {"tf": 1}, "shimmer.modules.contrastive_loss": {"tf": 1}, "shimmer.modules.contrastive_loss.ContrastiveLossType": {"tf": 1}, "shimmer.modules.contrastive_loss.ContrastiveLossBayesianType": {"tf": 1}, "shimmer.modules.contrastive_loss.contrastive_loss": {"tf": 1.4142135623730951}, "shimmer.modules.contrastive_loss.ContrastiveLoss.__init__": {"tf": 1}}, "df": 22, "l": {"docs": {}, "df": 0, "o": {"docs": {}, "df": 0, "s": {"docs": {}, "df": 0, "s": {"docs": {"shimmer.modules.contrastive_loss.ContrastiveLoss": {"tf": 1}}, "df": 1, "t": {"docs": {}, "df": 0, "y": {"docs": {}, "df": 0, "p": {"docs": {}, "df": 0, "e": {"docs": {"shimmer.modules.global_workspace.GlobalWorkspace2Domains.__init__": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspace.__init__": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBayesian.__init__": {"tf": 1}, "shimmer.modules.global_workspace.pretrained_global_workspace": {"tf": 1}, "shimmer.modules.losses.contrastive_loss": {"tf": 1}, "shimmer.modules.losses.GWLosses2Domains.__init__": {"tf": 1}, "shimmer.modules.losses.GWLossesBayesian.__init__": {"tf": 1}}, "df": 7}}}}, "b": {"docs": {}, "df": 0, "a": {"docs": {}, "df": 0, "y": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "s": {"docs": {}, "df": 0, "i": {"docs": {}, "df": 0, "a": {"docs": {}, "df": 0, "n": 
{"docs": {}, "df": 0, "t": {"docs": {}, "df": 0, "y": {"docs": {}, "df": 0, "p": {"docs": {}, "df": 0, "e": {"docs": {"shimmer.modules.losses.contrastive_loss_bayesian": {"tf": 1}}, "df": 1}}}}}}}}}}}}}}}}, "s": {"docs": {"shimmer.modules.losses.contrastive_loss": {"tf": 1}, "shimmer.modules.losses.contrastive_loss_bayesian": {"tf": 1}}, "df": 2}}}}}}}}}, "v": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "n": {"docs": {}, "df": 0, "i": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "n": {"docs": {}, "df": 0, "t": {"docs": {"shimmer.modules.global_workspace.GlobalWorkspaceBase.encode_domain": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.decode_domain": {"tf": 1}}, "df": 2}}}}}}}, "n": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "c": {"docs": {}, "df": 0, "t": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "d": {"docs": {"shimmer.modules.global_workspace.GlobalWorkspace2Domains.__init__": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspace.__init__": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBayesian.__init__": {"tf": 1}, "shimmer.modules.global_workspace.pretrained_global_workspace": {"tf": 1}}, "df": 4}}}}}}, "s": {"docs": {}, "df": 0, "t": {"docs": {}, "df": 0, "r": {"docs": {}, "df": 0, "a": {"docs": {}, "df": 0, "s": {"docs": {}, "df": 0, "t": {"docs": {}, "df": 0, "i": {"docs": {}, "df": 0, "v": {"docs": {}, "df": 0, "e": {"docs": {"shimmer.modules.global_workspace.GlobalWorkspaceBayesian.__init__": {"tf": 1}, "shimmer.modules.losses.GWLossesBayesian.__init__": {"tf": 1}}, "df": 2}}}}}}}}}}, "m": {"docs": {}, "df": 0, "p": {"docs": {}, "df": 0, "u": {"docs": {}, "df": 0, "t": {"docs": {}, "df": 0, "e": {"docs": {"shimmer.modules.selection.KQFixedQSelection.forward": {"tf": 1}, "shimmer.modules.selection.DynamicQueryAttention.forward": {"tf": 1}, "shimmer.modules.losses.GWLossesBase": {"tf": 1}, "shimmer.modules.losses.demi_cycle_loss": {"tf": 1}, "shimmer.modules.losses.cycle_loss": {"tf": 1}, 
"shimmer.modules.losses.translation_loss": {"tf": 1}, "shimmer.modules.losses.contrastive_loss": {"tf": 1}, "shimmer.modules.losses.contrastive_loss_bayesian": {"tf": 1}}, "df": 8, "s": {"docs": {"shimmer.modules.global_workspace.GlobalWorkspaceBase.loss_mod": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspace2Domains.forward": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspace.forward": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBayesian.forward": {"tf": 1}, "shimmer.modules.domain.DomainModule.compute_dcy_loss": {"tf": 1}, "shimmer.modules.domain.DomainModule.compute_cy_loss": {"tf": 1}, "shimmer.modules.domain.DomainModule.compute_tr_loss": {"tf": 1}, "shimmer.modules.domain.DomainModule.compute_broadcast_loss": {"tf": 1}, "shimmer.modules.losses.GWLossesBase.step": {"tf": 1}, "shimmer.modules.losses.demi_cycle_loss": {"tf": 1}, "shimmer.modules.losses.cycle_loss": {"tf": 1}, "shimmer.modules.losses.translation_loss": {"tf": 1}, "shimmer.modules.losses.contrastive_loss": {"tf": 1}, "shimmer.modules.losses.contrastive_loss_bayesian": {"tf": 1}, "shimmer.modules.losses.GWLosses2Domains.demi_cycle_loss": {"tf": 1}, "shimmer.modules.losses.GWLosses2Domains.cycle_loss": {"tf": 1}, "shimmer.modules.losses.GWLosses2Domains.translation_loss": {"tf": 1}, "shimmer.modules.losses.GWLosses2Domains.contrastive_loss": {"tf": 1}, "shimmer.modules.losses.GWLosses2Domains.step": {"tf": 1}, "shimmer.modules.losses.broadcast_loss": {"tf": 1}, "shimmer.modules.losses.GWLosses.contrastive_loss": {"tf": 1}, "shimmer.modules.contrastive_loss.ContrastiveLoss.forward": {"tf": 1}, "shimmer.modules.vae.kl_divergence_loss": {"tf": 1}, "shimmer.modules.vae.gaussian_nll": {"tf": 1}, "shimmer.modules.utils.batch_demi_cycles": {"tf": 1}, "shimmer.modules.utils.batch_cycles": {"tf": 1}, "shimmer.modules.utils.batch_translations": {"tf": 1}}, "df": 27}, "d": {"docs": {"shimmer.modules.global_workspace.GWPredictions.demi_cycles": {"tf": 1}, 
"shimmer.modules.global_workspace.GWPredictions.cycles": {"tf": 1}, "shimmer.modules.global_workspace.GWPredictions.translations": {"tf": 1}}, "df": 3}}, "a": {"docs": {}, "df": 0, "t": {"docs": {}, "df": 0, "i": {"docs": {}, "df": 0, "o": {"docs": {}, "df": 0, "n": {"docs": {"shimmer.modules.global_workspace.GlobalWorkspace2Domains.forward": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspace.forward": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBayesian.forward": {"tf": 1}, "shimmer.modules.domain.DomainModule.compute_loss": {"tf": 1}, "shimmer.modules.gw_module.GWEncoder.forward": {"tf": 1}, "shimmer.modules.gw_module.GWEncoderLinear.forward": {"tf": 1}, "shimmer.modules.losses.GWLosses.__init__": {"tf": 1}, "shimmer.modules.losses.GWLosses.step": {"tf": 1}, "shimmer.modules.losses.GWLossesBayesian.step": {"tf": 1}}, "df": 9}}}}}, "i": {"docs": {}, "df": 0, "n": {"docs": {}, "df": 0, "g": {"docs": {"shimmer.modules.losses.GWLosses.__init__": {"tf": 1}}, "df": 1}}}}}, "t": {"docs": {}, "df": 0, "u": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "s": {"docs": {"shimmer.modules.global_workspace.GlobalWorkspaceBase.batch_gw_states": {"tf": 1}}, "df": 1}}}}, "l": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "t": {"docs": {}, "df": 0, "i": {"docs": {}, "df": 0, "o": {"docs": {}, "df": 0, "n": {"docs": {"shimmer.modules.global_workspace.freeze_domain_modules": {"tf": 1}}, "df": 1}}}}}}, "e": {"docs": {}, "df": 0, "t": {"docs": {}, "df": 0, "i": {"docs": {}, "df": 0, "t": {"docs": {}, "df": 0, "i": {"docs": {}, "df": 0, "o": {"docs": {}, "df": 0, "n": {"docs": {"shimmer.modules.selection.SelectionBase": {"tf": 1}}, "df": 1}}}}}}}}, "m": {"docs": {}, "df": 0, "o": {"docs": {}, "df": 0, "n": {"docs": {"shimmer.modules.gw_module.GWModuleBase": {"tf": 1}}, "df": 1}}}, "b": {"docs": {}, "df": 0, "i": {"docs": {}, "df": 0, "n": {"docs": {}, "df": 0, "e": {"docs": {"shimmer.modules.gw_module.GWModuleBase.fuse": {"tf": 1}, 
"shimmer.modules.gw_module.GWModule.fuse": {"tf": 1}, "shimmer.modules.gw_module.compute_fusion_scores": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBayesian.fuse": {"tf": 1}}, "df": 4, "s": {"docs": {"shimmer.modules.gw_module.GWModuleBase.encode_and_fuse": {"tf": 1}}, "df": 1}, "d": {"docs": {"shimmer.modules.gw_module.compute_fusion_scores": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBayesian.fuse": {"tf": 1}}, "df": 2}}}}}}, "e": {"docs": {}, "df": 0, "f": {"docs": {"shimmer.modules.global_workspace.GlobalWorkspaceBayesian.__init__": {"tf": 1.4142135623730951}, "shimmer.modules.gw_module.GWModuleBayesian.__init__": {"tf": 1.4142135623730951}}, "df": 2, "s": {"docs": {"shimmer.modules.global_workspace.GlobalWorkspace2Domains.__init__": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspace.__init__": {"tf": 1.4142135623730951}, "shimmer.modules.global_workspace.GlobalWorkspaceBayesian.__init__": {"tf": 1.4142135623730951}, "shimmer.modules.global_workspace.pretrained_global_workspace": {"tf": 1}, "shimmer.modules.losses.GWLosses2Domains.__init__": {"tf": 1}, "shimmer.modules.losses.GWLosses.__init__": {"tf": 1.4142135623730951}, "shimmer.modules.losses.GWLossesBayesian.__init__": {"tf": 1.4142135623730951}}, "df": 7}, "f": {"docs": {}, "df": 0, "i": {"docs": {}, "df": 0, "c": {"docs": {}, "df": 0, "i": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "n": {"docs": {}, "df": 0, "t": {"docs": {"shimmer.modules.gw_module.GWModuleBayesian.fuse": {"tf": 1}, "shimmer.modules.selection.SelectionBase.forward": {"tf": 1}, "shimmer.modules.selection.RandomSelection.forward": {"tf": 1}, "shimmer.modules.losses.LossCoefs": {"tf": 1}, "shimmer.modules.losses.LossCoefs.demi_cycles": {"tf": 1}, "shimmer.modules.losses.LossCoefs.cycles": {"tf": 1}, "shimmer.modules.losses.LossCoefs.translations": {"tf": 1}, "shimmer.modules.losses.LossCoefs.contrastives": {"tf": 1}, "shimmer.modules.losses.BroadcastLossCoefs": {"tf": 1}, 
"shimmer.modules.losses.BroadcastLossCoefs.contrastives": {"tf": 1}, "shimmer.modules.losses.BroadcastLossCoefs.fused": {"tf": 1}, "shimmer.modules.losses.BroadcastLossCoefs.demi_cycles": {"tf": 1}, "shimmer.modules.losses.BroadcastLossCoefs.cycles": {"tf": 1}, "shimmer.modules.losses.BroadcastLossCoefs.translations": {"tf": 1}}, "df": 14, "s": {"docs": {"shimmer.modules.global_workspace.GlobalWorkspace2Domains.__init__": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBayesian.__init__": {"tf": 1}, "shimmer.modules.global_workspace.pretrained_global_workspace": {"tf": 1}, "shimmer.modules.selection.SelectionBase": {"tf": 1}, "shimmer.modules.selection.RandomSelection": {"tf": 1}, "shimmer.modules.selection.RandomSelection.forward": {"tf": 1}, "shimmer.modules.losses.LossCoefs": {"tf": 1}, "shimmer.modules.losses.GWLosses2Domains.__init__": {"tf": 1}, "shimmer.modules.losses.BroadcastLossCoefs": {"tf": 1}, "shimmer.modules.losses.GWLossesBayesian.__init__": {"tf": 1}, "shimmer.modules.losses.GWLossesBayesian.loss_coefs": {"tf": 1}}, "df": 11}}}}}}}}}}, "p": {"docs": {}, "df": 0, "y": {"docs": {"shimmer.modules.selection.SelectionBase.update_gw_state": {"tf": 1}}, "df": 1}}, "r": {"docs": {}, "df": 0, "r": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "c": {"docs": {}, "df": 0, "t": {"docs": {"shimmer.modules.losses.GWLosses2Domains.__init__": {"tf": 1}}, "df": 1}}}}, "e": {"docs": {"shimmer.utils.SaveMigrations.on_save_checkpoint": {"tf": 1}}, "df": 1}}, "u": {"docs": {}, "df": 0, "p": {"docs": {}, "df": 0, "l": {"docs": {}, "df": 0, "e": {"docs": {"shimmer.modules.utils.batch_cycles": {"tf": 1}, "shimmer.modules.utils.batch_translations": {"tf": 1}}, "df": 2}}}}}, "a": {"docs": {}, "df": 0, "l": {"docs": {}, "df": 0, "l": {"docs": {"shimmer.modules.gw_module.GWEncoder.forward": {"tf": 1.4142135623730951}, "shimmer.modules.gw_module.GWEncoderLinear.forward": {"tf": 1.4142135623730951}}, "df": 2, "i": {"docs": {}, "df": 0, "n": {"docs": {}, "df": 0, 
"g": {"docs": {"shimmer.modules.global_workspace.GWPredictionsBase": {"tf": 1}, "shimmer.modules.global_workspace.GWPredictions": {"tf": 1}}, "df": 2}}}, "b": {"docs": {}, "df": 0, "a": {"docs": {}, "df": 0, "c": {"docs": {}, "df": 0, "k": {"docs": {}, "df": 0, "s": {"docs": {"shimmer.utils.SaveMigrations": {"tf": 1}}, "df": 1}}}}}, "e": {"docs": {}, "df": 0, "d": {"docs": {"shimmer.utils.SaveMigrations.on_save_checkpoint": {"tf": 1}, "shimmer.cli.ckpt_migration.migrate_ckpt": {"tf": 1}}, "df": 2}}, "s": {"docs": {"shimmer.cli.ckpt_migration.migrate_ckpt": {"tf": 1}}, "df": 1}}}, "s": {"docs": {}, "df": 0, "t": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "d": {"docs": {"shimmer.modules.global_workspace.freeze_domain_modules": {"tf": 1}}, "df": 1}}}}, "r": {"docs": {}, "df": 0, "e": {"docs": {"shimmer.modules.gw_module.GWEncoder.forward": {"tf": 1}, "shimmer.modules.gw_module.GWEncoderLinear.forward": {"tf": 1}}, "df": 2}}, "n": {"docs": {"shimmer.modules.gw_module.GWModuleBayesian.fuse": {"tf": 1.4142135623730951}, "shimmer.modules.selection.SingleDomainSelection": {"tf": 1}, "shimmer.modules.losses.BroadcastLossCoefs.cycles": {"tf": 1}, "shimmer.modules.losses.BroadcastLossCoefs.translations": {"tf": 1}, "shimmer.cli.ckpt_migration.migrate_ckpt": {"tf": 1}}, "df": 5}}, "l": {"docs": {}, "df": 0, "a": {"docs": {}, "df": 0, "s": {"docs": {}, "df": 0, "s": {"docs": {"shimmer.modules.global_workspace.GlobalWorkspaceBase": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspace2Domains.__init__": {"tf": 1.4142135623730951}, "shimmer.modules.global_workspace.GlobalWorkspace.__init__": {"tf": 1.4142135623730951}, "shimmer.modules.global_workspace.GlobalWorkspaceBayesian.__init__": {"tf": 1.4142135623730951}, "shimmer.modules.global_workspace.pretrained_global_workspace": {"tf": 1.4142135623730951}, "shimmer.modules.domain.DomainModule": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBase": {"tf": 1.4142135623730951}, 
"shimmer.modules.gw_module.GWModule.__init__": {"tf": 1.4142135623730951}, "shimmer.modules.gw_module.GWModuleBayesian.__init__": {"tf": 1.4142135623730951}, "shimmer.modules.selection.SelectionBase": {"tf": 1}, "shimmer.modules.losses.GWLossesBase": {"tf": 1}, "shimmer.modules.vae.VAEEncoder": {"tf": 1}, "shimmer.modules.vae.VAEDecoder": {"tf": 1}, "shimmer.utils.SaveMigrations": {"tf": 1.4142135623730951}}, "df": 14}}}, "i": {"docs": {}, "df": 0, "p": {"docs": {"shimmer.modules.contrastive_loss.contrastive_loss": {"tf": 1}, "shimmer.modules.contrastive_loss.ContrastiveLoss": {"tf": 1}}, "df": 2}}}, "y": {"docs": {"shimmer.modules.losses.cycle_loss": {"tf": 1}}, "df": 1, "c": {"docs": {}, "df": 0, "l": {"docs": {}, "df": 0, "e": {"docs": {"shimmer.modules.global_workspace.GWPredictions.demi_cycles": {"tf": 1}, "shimmer.modules.global_workspace.GWPredictions.cycles": {"tf": 1}, "shimmer.modules.domain.DomainModule.compute_dcy_loss": {"tf": 1.4142135623730951}, "shimmer.modules.domain.DomainModule.compute_cy_loss": {"tf": 1.4142135623730951}, "shimmer.modules.losses.GWLossesBase": {"tf": 1.4142135623730951}, "shimmer.modules.losses.demi_cycle_loss": {"tf": 2.23606797749979}, "shimmer.modules.losses.cycle_loss": {"tf": 2.23606797749979}, "shimmer.modules.losses.LossCoefs.demi_cycles": {"tf": 1}, "shimmer.modules.losses.LossCoefs.cycles": {"tf": 1}, "shimmer.modules.losses.GWLosses2Domains.demi_cycle_loss": {"tf": 1.4142135623730951}, "shimmer.modules.losses.GWLosses2Domains.cycle_loss": {"tf": 1.4142135623730951}, "shimmer.modules.losses.GWLosses2Domains.step": {"tf": 2}, "shimmer.modules.losses.broadcast_loss": {"tf": 1.4142135623730951}, "shimmer.modules.utils.cycle": {"tf": 1.4142135623730951}, "shimmer.modules.utils.batch_cycles": {"tf": 1.4142135623730951}}, "df": 15, "s": {"docs": {"shimmer.modules.global_workspace.GlobalWorkspace2Domains.forward": {"tf": 1.4142135623730951}, "shimmer.modules.global_workspace.GlobalWorkspace.forward": {"tf": 
1.4142135623730951}, "shimmer.modules.global_workspace.GlobalWorkspaceBayesian.forward": {"tf": 1.4142135623730951}, "shimmer.modules.gw_module.GWModuleBase": {"tf": 1}, "shimmer.modules.losses.demi_cycle_loss": {"tf": 1}, "shimmer.modules.losses.cycle_loss": {"tf": 1}, "shimmer.modules.losses.BroadcastLossCoefs.demi_cycles": {"tf": 1.4142135623730951}, "shimmer.modules.losses.BroadcastLossCoefs.cycles": {"tf": 1.4142135623730951}, "shimmer.modules.losses.BroadcastLossCoefs.translations": {"tf": 1}, "shimmer.dataset.RepeatedDataset": {"tf": 1}, "shimmer.modules.utils.batch_demi_cycles": {"tf": 1.4142135623730951}, "shimmer.modules.utils.batch_cycles": {"tf": 1.4142135623730951}}, "df": 12}}, "i": {"docs": {}, "df": 0, "n": {"docs": {}, "df": 0, "g": {"docs": {"shimmer.modules.utils.cycle": {"tf": 1}}, "df": 1}}}}}}, "u": {"docs": {}, "df": 0, "s": {"docs": {}, "df": 0, "t": {"docs": {}, "df": 0, "o": {"docs": {}, "df": 0, "m": {"docs": {"shimmer.modules.global_workspace.GlobalWorkspace2Domains.__init__": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspace.__init__": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBayesian.__init__": {"tf": 1}, "shimmer.modules.global_workspace.pretrained_global_workspace": {"tf": 1}}, "df": 4}}}}, "r": {"docs": {}, "df": 0, "r": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "n": {"docs": {}, "df": 0, "t": {"docs": {"shimmer.utils.SaveMigrations.on_save_checkpoint": {"tf": 1.4142135623730951}}, "df": 1, "l": {"docs": {}, "df": 0, "y": {"docs": {"shimmer.modules.losses.GWLosses.step": {"tf": 1}, "shimmer.modules.losses.GWLossesBayesian.step": {"tf": 1}}, "df": 2}}}}}}}}, "h": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "c": {"docs": {}, "df": 0, "k": {"docs": {}, "df": 0, "p": {"docs": {}, "df": 0, "o": {"docs": {}, "df": 0, "i": {"docs": {}, "df": 0, "n": {"docs": {}, "df": 0, "t": {"docs": {"shimmer.modules.global_workspace.pretrained_global_workspace": {"tf": 2}, "shimmer.utils.migrate_model": {"tf": 
2}, "shimmer.utils.SaveMigrations.on_save_checkpoint": {"tf": 1.7320508075688772}}, "df": 3, "s": {"docs": {"shimmer.cli.ckpt_migration.migrate_ckpt": {"tf": 1.4142135623730951}}, "df": 1}}}}}}}}}, "a": {"docs": {}, "df": 0, "n": {"docs": {}, "df": 0, "c": {"docs": {}, "df": 0, "e": {"docs": {"shimmer.modules.selection.SingleDomainSelection": {"tf": 1}, "shimmer.utils.SaveMigrations.on_save_checkpoint": {"tf": 1}}, "df": 2}}}}}, "d": {"docs": {}, "df": 0, "o": {"docs": {}, "df": 0, "t": {"docs": {"shimmer.modules.gw_module.GWModuleBayesian.fuse": {"tf": 1.4142135623730951}}, "df": 1}}}, "k": {"docs": {}, "df": 0, "p": {"docs": {}, "df": 0, "t": {"docs": {"shimmer.utils.migrate_model": {"tf": 1}, "shimmer.cli.ckpt_migration.migrate_ckpt": {"tf": 1}}, "df": 2}}}}, "w": {"docs": {}, "df": 0, "h": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "n": {"docs": {"shimmer.types.ModelModeT": {"tf": 1}, "shimmer.modules.global_workspace.GWPredictionsBase": {"tf": 1}, "shimmer.modules.global_workspace.GWPredictions": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspace2Domains.__init__": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspace.__init__": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBayesian.__init__": {"tf": 1}, "shimmer.dataset.RepeatedDataset.__init__": {"tf": 1}, "shimmer.utils.SaveMigrations.on_save_checkpoint": {"tf": 1}}, "df": 8}, "t": {"docs": {}, "df": 0, "h": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "r": {"docs": {"shimmer.modules.global_workspace.GlobalWorkspace2Domains.__init__": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspace.__init__": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBayesian.__init__": {"tf": 1.4142135623730951}, "shimmer.modules.selection.SingleDomainSelection.forward": {"tf": 1}, "shimmer.modules.selection.FixedSharedSelection.forward": {"tf": 1}, "shimmer.modules.losses.GWLossesBayesian.__init__": {"tf": 1}, 
"shimmer.modules.contrastive_loss.ContrastiveLoss.__init__": {"tf": 1}, "shimmer.dataset.RepeatedDataset.__init__": {"tf": 1}}, "df": 8}}}}, "r": {"docs": {}, "df": 0, "e": {"docs": {"shimmer.modules.gw_module.GWModuleBayesian.fuse": {"tf": 2}, "shimmer.modules.losses.GWLosses2Domains.__init__": {"tf": 1}, "shimmer.modules.losses.GWLossesBayesian.__init__": {"tf": 1}, "shimmer.dataset.RepeatedDataset.__init__": {"tf": 1}, "shimmer.cli.ckpt_migration.migrate_ckpt": {"tf": 1}}, "df": 5}}}, "i": {"docs": {}, "df": 0, "c": {"docs": {}, "df": 0, "h": {"docs": {"shimmer.modules.global_workspace.GlobalWorkspace2Domains.__init__": {"tf": 1.4142135623730951}, "shimmer.modules.global_workspace.GlobalWorkspace.__init__": {"tf": 1.4142135623730951}, "shimmer.modules.global_workspace.GlobalWorkspaceBayesian.__init__": {"tf": 1.4142135623730951}, "shimmer.modules.global_workspace.pretrained_global_workspace": {"tf": 1.4142135623730951}, "shimmer.modules.gw_module.GWModuleBayesian.fuse": {"tf": 1}, "shimmer.modules.losses.GWLosses.step": {"tf": 1}, "shimmer.modules.losses.GWLossesBayesian.step": {"tf": 1}}, "df": 7}}, "l": {"docs": {}, "df": 0, "e": {"docs": {"shimmer.modules.gw_module.GWEncoder.forward": {"tf": 1}, "shimmer.modules.gw_module.GWEncoderLinear.forward": {"tf": 1}}, "df": 2}}}, "a": {"docs": {}, "df": 0, "t": {"docs": {"shimmer.modules.domain.LossOutput": {"tf": 1.4142135623730951}}, "df": 1}}}, "i": {"docs": {}, "df": 0, "t": {"docs": {}, "df": 0, "h": {"docs": {"shimmer.modules.global_workspace.GWPredictionsBase.states": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.encode_domains": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.decode_domains": {"tf": 1}, "shimmer.modules.global_workspace.GWPredictions.demi_cycles": {"tf": 1}, "shimmer.modules.global_workspace.GWPredictions.cycles": {"tf": 1.4142135623730951}, "shimmer.modules.global_workspace.GWPredictions.translations": {"tf": 1.4142135623730951}, 
"shimmer.modules.global_workspace.GlobalWorkspace": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBayesian": {"tf": 1}, "shimmer.modules.domain.LossOutput.all": {"tf": 1.4142135623730951}, "shimmer.modules.domain.DomainModule.compute_loss": {"tf": 1}, "shimmer.modules.domain.DomainModule.compute_dcy_loss": {"tf": 1}, "shimmer.modules.domain.DomainModule.compute_cy_loss": {"tf": 1}, "shimmer.modules.domain.DomainModule.compute_tr_loss": {"tf": 1}, "shimmer.modules.domain.DomainModule.compute_broadcast_loss": {"tf": 1}, "shimmer.modules.gw_module.get_n_layers": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBayesian": {"tf": 1}, "shimmer.modules.selection.SingleDomainSelection": {"tf": 1}, "shimmer.modules.selection.KQFixedQSelection": {"tf": 1}, "shimmer.modules.selection.KQFixedQSelection.forward": {"tf": 1.7320508075688772}, "shimmer.modules.selection.DynamicQueryAttention": {"tf": 1}, "shimmer.modules.selection.DynamicQueryAttention.forward": {"tf": 1.7320508075688772}, "shimmer.modules.losses.demi_cycle_loss": {"tf": 1.7320508075688772}, "shimmer.modules.losses.cycle_loss": {"tf": 1.7320508075688772}, "shimmer.modules.losses.translation_loss": {"tf": 1.7320508075688772}, "shimmer.modules.losses.contrastive_loss": {"tf": 1.7320508075688772}, "shimmer.modules.losses.contrastive_loss_bayesian": {"tf": 2}, "shimmer.modules.losses.GWLosses2Domains.__init__": {"tf": 1.4142135623730951}, "shimmer.modules.losses.broadcast_loss": {"tf": 1}, "shimmer.modules.losses.GWLossesBayesian.__init__": {"tf": 1.4142135623730951}, "shimmer.modules.vae.reparameterize": {"tf": 1}, "shimmer.modules.vae.VAEEncoder.forward": {"tf": 1}, "shimmer.modules.vae.VAEDecoder.forward": {"tf": 1}, "shimmer.cli.ckpt_migration.migrate_ckpt": {"tf": 1}}, "df": 33, "i": {"docs": {}, "df": 0, "n": {"docs": {"shimmer.modules.gw_module.GWEncoder.forward": {"tf": 1}, "shimmer.modules.gw_module.GWEncoderLinear.forward": {"tf": 1}}, "df": 2}}}}, "l": {"docs": {}, "df": 0, "l": {"docs": 
{"shimmer.modules.global_workspace.GlobalWorkspace2Domains.__init__": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspace.__init__": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBayesian.__init__": {"tf": 1}, "shimmer.modules.global_workspace.pretrained_global_workspace": {"tf": 1}, "shimmer.modules.gw_module.GWDecoder.__init__": {"tf": 1}, "shimmer.modules.gw_module.GWDecoder.n_layers": {"tf": 1}, "shimmer.modules.gw_module.GWEncoder.__init__": {"tf": 1}, "shimmer.modules.selection.SelectionBase.update_gw_state": {"tf": 1}, "shimmer.modules.losses.LossCoefs": {"tf": 1.4142135623730951}, "shimmer.modules.losses.BroadcastLossCoefs": {"tf": 1.4142135623730951}, "shimmer.dataset.RepeatedDataset": {"tf": 1}, "shimmer.modules.utils.batch_cycles": {"tf": 1}, "shimmer.utils.migrate_model": {"tf": 1.4142135623730951}, "shimmer.utils.SaveMigrations.on_save_checkpoint": {"tf": 1}}, "df": 14}}, "s": {"docs": {}, "df": 0, "e": {"docs": {"shimmer.modules.selection.RandomSelection": {"tf": 1}, "shimmer.modules.selection.RandomSelection.forward": {"tf": 1}}, "df": 2}}}, "o": {"docs": {}, "df": 0, "r": {"docs": {}, "df": 0, "k": {"docs": {}, "df": 0, "s": {"docs": {}, "df": 0, "p": {"docs": {}, "df": 0, "a": {"docs": {}, "df": 0, "c": {"docs": {}, "df": 0, "e": {"docs": {"shimmer.modules.global_workspace.GlobalWorkspaceBase": {"tf": 1.4142135623730951}, "shimmer.modules.global_workspace.GlobalWorkspace2Domains": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspace2Domains.__init__": {"tf": 1.4142135623730951}, "shimmer.modules.global_workspace.GlobalWorkspace": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspace.__init__": {"tf": 1.4142135623730951}, "shimmer.modules.global_workspace.GlobalWorkspaceBayesian": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBayesian.__init__": {"tf": 1.4142135623730951}, "shimmer.modules.global_workspace.pretrained_global_workspace": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBase.__init__": 
{"tf": 1}, "shimmer.modules.gw_module.GWModule.__init__": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBayesian.__init__": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBayesian.fuse": {"tf": 1}, "shimmer.modules.losses.GWLossesBase": {"tf": 1}, "shimmer.modules.losses.GWLosses.__init__": {"tf": 1.4142135623730951}}, "df": 14}}}}}}}}, "e": {"docs": {"shimmer.modules.gw_module.GWModuleBayesian.fuse": {"tf": 1.4142135623730951}}, "df": 1, "i": {"docs": {}, "df": 0, "g": {"docs": {}, "df": 0, "h": {"docs": {}, "df": 0, "t": {"docs": {"shimmer.modules.global_workspace.GlobalWorkspace2Domains.__init__": {"tf": 1.4142135623730951}, "shimmer.modules.global_workspace.GlobalWorkspace.__init__": {"tf": 1.4142135623730951}, "shimmer.modules.global_workspace.GlobalWorkspaceBayesian.__init__": {"tf": 1.4142135623730951}}, "df": 3, "s": {"docs": {"shimmer.modules.global_workspace.freeze_domain_modules": {"tf": 1}, "shimmer.modules.selection.FixedSharedSelection": {"tf": 1}}, "df": 2}, "e": {"docs": {}, "df": 0, "d": {"docs": {"shimmer.modules.selection.DynamicQueryAttention.fuse_weighted_encodings": {"tf": 1}}, "df": 1}}}}}}}, "a": {"docs": {}, "df": 0, "n": {"docs": {}, "df": 0, "t": {"docs": {"shimmer.utils.SaveMigrations.on_save_checkpoint": {"tf": 1}}, "df": 1}}}}, "z": {"docs": {"shimmer.modules.global_workspace.GlobalWorkspaceBase.decode": {"tf": 1}, "shimmer.modules.domain.DomainModule.decode": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBase.decode": {"tf": 1}, "shimmer.modules.gw_module.GWModule.decode": {"tf": 1}, "shimmer.modules.vae.VAE.decode": {"tf": 1}}, "df": 5, "e": {"docs": {}, "df": 0, "r": {"docs": {}, "df": 0, "o": {"docs": {}, "df": 0, "s": {"docs": {"shimmer.modules.losses.generate_partitions": {"tf": 2}}, "df": 1}}}}}, "h": {"docs": {}, "df": 0, "i": {"docs": {}, "df": 0, "d": {"docs": {}, "df": 0, "d": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "n": {"docs": {"shimmer.modules.gw_module.get_n_layers": {"tf": 1.4142135623730951}, 
"shimmer.modules.gw_module.GWDecoder.__init__": {"tf": 1.7320508075688772}, "shimmer.modules.gw_module.GWDecoder.hidden_dim": {"tf": 1}, "shimmer.modules.gw_module.GWDecoder.n_layers": {"tf": 1}, "shimmer.modules.gw_module.GWEncoder.__init__": {"tf": 1.7320508075688772}}, "df": 5}}}}}, "o": {"docs": {}, "df": 0, "o": {"docs": {}, "df": 0, "k": {"docs": {}, "df": 0, "s": {"docs": {"shimmer.modules.gw_module.GWEncoder.forward": {"tf": 1}, "shimmer.modules.gw_module.GWEncoderLinear.forward": {"tf": 1}, "shimmer.utils.SaveMigrations": {"tf": 1}}, "df": 3}}}}, "a": {"docs": {}, "df": 0, "n": {"docs": {}, "df": 0, "d": {"docs": {}, "df": 0, "l": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "s": {"docs": {"shimmer.modules.gw_module.GWModuleBase": {"tf": 1}, "shimmer.modules.selection.SelectionBase": {"tf": 1}, "shimmer.modules.selection.SingleDomainSelection": {"tf": 1}}, "df": 3}}}}}, "v": {"docs": {}, "df": 0, "e": {"docs": {"shimmer.modules.gw_module.compute_fusion_scores": {"tf": 1}, "shimmer.modules.selection.SingleDomainSelection": {"tf": 1}, "shimmer.dataset.RepeatedDataset": {"tf": 1}, "shimmer.dataset.RepeatedDataset.__init__": {"tf": 1}}, "df": 4}}, "s": {"docs": {"shimmer.modules.selection.SingleDomainSelection": {"tf": 1}}, "df": 1}}, "e": {"docs": {}, "df": 0, "a": {"docs": {}, "df": 0, "d": {"docs": {"shimmer.modules.selection.KQFixedQSelection.__init__": {"tf": 1}, "shimmer.modules.selection.DynamicQueryAttention.__init__": {"tf": 1}}, "df": 2}}}}, "y": {"docs": {"shimmer.modules.contrastive_loss.info_nce": {"tf": 1}, "shimmer.modules.contrastive_loss.contrastive_loss": {"tf": 1}, "shimmer.modules.contrastive_loss.ContrastiveLoss.forward": {"tf": 1}}, "df": 3, "i": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "l": {"docs": {}, "df": 0, "d": {"docs": {}, "df": 0, "s": {"docs": {"shimmer.modules.gw_module.GWModuleBayesian.fuse": {"tf": 1}, "shimmer.modules.losses.generate_partitions": {"tf": 1}}, "df": 2}}}}}, "o": {"docs": {}, "df": 0, "u": {"docs": 
{"shimmer.utils.SaveMigrations.on_save_checkpoint": {"tf": 1.4142135623730951}}, "df": 1}}}}}}, "pipeline": ["trimmer"], "_isPrebuiltIndex": true};
+
+ // mirrored in build-search-index.js (part 1)
+ // Also split on html tags. this is a cheap heuristic, but good enough.
+ elasticlunr.tokenizer.setSeperator(/[\s\-.;&_'"=,()]+|<[^>]*>/);
+
+ let searchIndex;
+ if (docs._isPrebuiltIndex) {
+ console.info("using precompiled search index");
+ searchIndex = elasticlunr.Index.load(docs);
+ } else {
+ console.time("building search index");
+ // mirrored in build-search-index.js (part 2)
+ searchIndex = elasticlunr(function () {
+ this.pipeline.remove(elasticlunr.stemmer);
+ this.pipeline.remove(elasticlunr.stopWordFilter);
+ this.addField("qualname");
+ this.addField("fullname");
+ this.addField("annotation");
+ this.addField("default_value");
+ this.addField("signature");
+ this.addField("bases");
+ this.addField("doc");
+ this.setRef("fullname");
+ });
+ for (let doc of docs) {
+ searchIndex.addDoc(doc);
+ }
+ console.timeEnd("building search index");
+ }
+
+ return (term) => searchIndex.search(term, {
+ fields: {
+ qualname: {boost: 4},
+ fullname: {boost: 2},
+ annotation: {boost: 2},
+ default_value: {boost: 2},
+ signature: {boost: 2},
+ bases: {boost: 2},
+ doc: {boost: 1},
+ },
+ expand: true
+ });
+})();
\ No newline at end of file
diff --git a/docs/api/v0.5.1/shimmer/cli/ckpt_migration.html b/docs/api/v0.5.1/shimmer/cli/ckpt_migration.html
new file mode 100644
index 00000000..a6f4201a
--- /dev/null
+++ b/docs/api/v0.5.1/shimmer/cli/ckpt_migration.html
@@ -0,0 +1,319 @@
+
+
+
+
+
+
+ shimmer.cli.ckpt_migration API documentation
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+ migrate_ckpt =
+<Command migrate-ckpt>
+
+
+
+
+
+ Script to migrate a list of checkpoints.
+This can be called with:
+
+
+
shimmer migrate-ckpt PATH_1 PATH_2 ... PATH_N
+
+
+
+
where paths point to checkpoints.
+
+
Internally, this calls shimmer.utils.migrate_model
for each of the given paths.
+
+
+
+
+
+
+
\ No newline at end of file
diff --git a/docs/api/v0.5.1/shimmer/dataset.html b/docs/api/v0.5.1/shimmer/dataset.html
new file mode 100644
index 00000000..1ba30a7d
--- /dev/null
+++ b/docs/api/v0.5.1/shimmer/dataset.html
@@ -0,0 +1,448 @@
+
+
+
+
+
+
+ shimmer.dataset API documentation
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+ class
+ RepeatedDataset (typing.Generic[+T_co] ):
+
+ View Source
+
+
+
+ 13 class RepeatedDataset ( Dataset ):
+14 """
+15 Dataset that cycles through its items to have a size of at least min size.
+16 If drop_last is True, the size will be exaclty min_size. If drop_last is False,
+17 the min_size ≤ size < min_size + len(dataset).
+18 """
+19
+20 def __init__ ( self , dataset : _SizedDataset , min_size : int , drop_last : bool = False ):
+21 """
+22 Args:
+23 dataset (SizedDataset): dataset to repeat. The dataset should have a size
+24 (where `__len__` is defined).
+25 min_size (int): minimum size of the final dataset
+26 drop_last (bool): whether to remove overflow when repeating the
+27 dataset.
+28 """
+29 self . dataset = dataset
+30 assert min_size >= len ( self . dataset )
+31 self . dataset_size = len ( self . dataset )
+32 if drop_last :
+33 self . total_size = min_size
+34 else :
+35 self . total_size = (
+36 min_size // self . dataset_size + int ( min_size % self . dataset_size > 0 )
+37 ) * self . dataset_size
+38
+39 def __len__ ( self ) -> int :
+40 """
+41 Size of the dataset. Will be min_size if drop_last is True.
+42 Otherwise, min_size ≤ size < min_size + len(dataset).
+43 """
+44 return self . total_size
+45
+46 def __getitem__ ( self , index : int ) -> Any :
+47 return self . dataset [ index % self . dataset_size ]
+
+
+
+ Dataset that cycles through its items to have a size of at least min size.
+If drop_last is True, the size will be exaclty min_size. If drop_last is False,
+the min_size ≤ size < min_size + len(dataset).
+
+
+
+
+
+
+
+ RepeatedDataset ( dataset : shimmer . dataset . _SizedDataset , min_size : int , drop_last : bool = False )
+
+ View Source
+
+
+
+
20 def __init__ ( self , dataset : _SizedDataset , min_size : int , drop_last : bool = False ):
+21 """
+22 Args:
+23 dataset (SizedDataset): dataset to repeat. The dataset should have a size
+24 (where `__len__` is defined).
+25 min_size (int): minimum size of the final dataset
+26 drop_last (bool): whether to remove overflow when repeating the
+27 dataset.
+28 """
+29 self . dataset = dataset
+30 assert min_size >= len ( self . dataset )
+31 self . dataset_size = len ( self . dataset )
+32 if drop_last :
+33 self . total_size = min_size
+34 else :
+35 self . total_size = (
+36 min_size // self . dataset_size + int ( min_size % self . dataset_size > 0 )
+37 ) * self . dataset_size
+
+
+
+
Arguments:
+
+
+dataset (SizedDataset): dataset to repeat. The dataset should have a size
+(where __len__
is defined).
+min_size (int): minimum size of the final dataset
+drop_last (bool): whether to remove overflow when repeating the
+dataset.
+
+
+
+
+
+
+
+ dataset
+
+
+
+
+
+
+
+
+
+
+ dataset_size
+
+
+
+
+
+
+
+
+
+
+
+
\ No newline at end of file
diff --git a/docs/api/v0.5.1/shimmer/modules/contrastive_loss.html b/docs/api/v0.5.1/shimmer/modules/contrastive_loss.html
new file mode 100644
index 00000000..55fb2e87
--- /dev/null
+++ b/docs/api/v0.5.1/shimmer/modules/contrastive_loss.html
@@ -0,0 +1,797 @@
+
+
+
+
+
+
+ shimmer.modules.contrastive_loss API documentation
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+shimmer.modules.contrastive_loss
+
+ Various contrastive loss definitions
+
+
+
+
+ View Source
+
+ 1 """Various contrastive loss definitions"""
+ 2
+ 3 from collections.abc import Callable
+ 4 from typing import Literal
+ 5
+ 6 import torch
+ 7 from torch.nn.functional import cross_entropy , normalize
+ 8
+ 9 from shimmer.modules.domain import LossOutput
+ 10
+ 11 ContrastiveLossType = Callable [[ torch . Tensor , torch . Tensor ], LossOutput ]
+ 12 """
+ 13 Contrastive loss function type.
+ 14
+ 15 A function taking the prediction and targets and returning a LossOutput.
+ 16 """
+ 17
+ 18 ContrastiveLossBayesianType = Callable [
+ 19 [ torch . Tensor , torch . Tensor , torch . Tensor , torch . Tensor ], LossOutput
+ 20 ]
+ 21 """
+ 22 Contrastive loss function type for GlobalWorkspaceBayesian.
+ 23
+ 24 A function taking the prediction mean, prediction std, target mean and target std and
+ 25 returns a LossOutput.
+ 26 """
+ 27
+ 28
+ 29 def info_nce (
+ 30 x : torch . Tensor ,
+ 31 y : torch . Tensor ,
+ 32 logit_scale : torch . Tensor ,
+ 33 reduction : Literal [ "mean" , "sum" , "none" ] = "mean" ,
+ 34 ) -> torch . Tensor :
+ 35 """
+ 36 InfoNCE loss
+ 37
+ 38 Args:
+ 39 x (`torch.Tensor`): prediction
+ 40 y (`torch.Tensor`): target
+ 41 logit_scale (`torch.Tensor`): logit scale
+ 42 reduction (`Literal["mean", "sum", "none"]`): reduction to apply
+ 43
+ 44 Returns: the InfoNCE loss
+ 45 """
+ 46 xn = normalize ( x )
+ 47 yn = normalize ( y )
+ 48 logits = torch . clamp ( logit_scale . exp (), max = 100 ) * xn @ yn . t ()
+ 49 labels = torch . arange ( xn . size ( 0 )) . to ( logits . device )
+ 50 return cross_entropy ( logits , labels , reduction = reduction )
+ 51
+ 52
+ 53 def contrastive_loss (
+ 54 x : torch . Tensor ,
+ 55 y : torch . Tensor ,
+ 56 logit_scale : torch . Tensor ,
+ 57 reduction : Literal [ "mean" , "sum" , "none" ] = "mean" ,
+ 58 ) -> torch . Tensor :
+ 59 """
+ 60 CLIP-like contrastive loss
+ 61
+ 62 Args:
+ 63 x (`torch.Tensor`): prediction
+ 64 y (`torch.Tensor`): target
+ 65 logit_scale (`torch.Tensor`): logit scale
+ 66 reduction (`Literal["mean", "sum", "none"]`): reduction to apply
+ 67
+ 68 Returns: the contrastive loss
+ 69 """
+ 70 xn = normalize ( x )
+ 71 yn = normalize ( y )
+ 72 logits = torch . clamp ( logit_scale . exp (), max = 100 ) * xn @ yn . t ()
+ 73 labels = torch . arange ( xn . size ( 0 )) . to ( logits . device )
+ 74 ce = cross_entropy ( logits , labels , reduction = reduction )
+ 75 ce_t = cross_entropy ( logits . t (), labels , reduction = reduction )
+ 76 return 0.5 * ( ce + ce_t )
+ 77
+ 78
+ 79 class ContrastiveLoss ( torch . nn . Module ):
+ 80 """CLIP-like ContrastiveLoss torch module."""
+ 81
+ 82 def __init__ (
+ 83 self ,
+ 84 logit_scale : torch . Tensor ,
+ 85 reduction : Literal [ "mean" , "sum" , "none" ] = "mean" ,
+ 86 learn_logit_scale : bool = False ,
+ 87 ) -> None :
+ 88 """
+ 89 Initializes a contrastive loss.
+ 90
+ 91 Args:
+ 92 logit_scale (`torch.Tensor`): logit_scale tensor.
+ 93 reduction (`Literal["mean", "sum", "none"]`): reduction to apply to the
+ 94 loss. Defaults to `"mean"`.
+ 95 learn_logit_scale (`torch.Tensor`): whether to learn the `logit_scale`
+ 96 parameter. Defaults to `False`.
+ 97 """
+ 98 super () . __init__ ()
+ 99
+100 if learn_logit_scale :
+101 self . logit_scale = torch . nn . Parameter ( logit_scale )
+102 else :
+103 self . register_buffer ( "logit_scale" , logit_scale )
+104 self . learn_logit_scale = learn_logit_scale
+105 self . reduction : Literal [ "mean" , "sum" , "none" ] = reduction
+106
+107 def forward ( self , x : torch . Tensor , y : torch . Tensor ) -> LossOutput :
+108 """
+109 Computes the loss.
+110
+111 Args:
+112 x (`torch.Tensor`): prediction
+113 y (`torch.Tensor`): target
+114
+115 Returns:
+116 LossOutput of the loss. Contains a `logit_scale` metric.
+117 """
+118 return LossOutput (
+119 contrastive_loss ( x , y , self . logit_scale , self . reduction ),
+120 { "logit_scale" : self . logit_scale . exp ()},
+121 )
+
+
+
+
+
+
+
+
+ Contrastive loss function type.
+
+
A function taking the prediction and targets and returning a LossOutput.
+
+
+
+
+
+
+
+
+ Contrastive loss function type for GlobalWorkspaceBayesian.
+
+
A function taking the prediction mean, prediction std, target mean and target std and
+ returns a LossOutput.
+
+
+
+
+
+
+
+
+ def
+ info_nce ( x : torch . Tensor , y : torch . Tensor , logit_scale : torch . Tensor , reduction : Literal [ 'mean' , 'sum' , 'none' ] = 'mean' ) -> torch . Tensor :
+
+ View Source
+
+
+
+ 30 def info_nce (
+31 x : torch . Tensor ,
+32 y : torch . Tensor ,
+33 logit_scale : torch . Tensor ,
+34 reduction : Literal [ "mean" , "sum" , "none" ] = "mean" ,
+35 ) -> torch . Tensor :
+36 """
+37 InfoNCE loss
+38
+39 Args:
+40 x (`torch.Tensor`): prediction
+41 y (`torch.Tensor`): target
+42 logit_scale (`torch.Tensor`): logit scale
+43 reduction (`Literal["mean", "sum", "none"]`): reduction to apply
+44
+45 Returns: the InfoNCE loss
+46 """
+47 xn = normalize ( x )
+48 yn = normalize ( y )
+49 logits = torch . clamp ( logit_scale . exp (), max = 100 ) * xn @ yn . t ()
+50 labels = torch . arange ( xn . size ( 0 )) . to ( logits . device )
+51 return cross_entropy ( logits , labels , reduction = reduction )
+
+
+
+ InfoNCE loss
+
+
Arguments:
+
+
+x (torch.Tensor
): prediction
+y (torch.Tensor
): target
+logit_scale (torch.Tensor
): logit scale
+reduction (Literal["mean", "sum", "none"]
): reduction to apply
+
+
+
Returns: the InfoNCE loss
+
+
+
+
+
+
+
+
+ def
+ contrastive_loss ( x : torch . Tensor , y : torch . Tensor , logit_scale : torch . Tensor , reduction : Literal [ 'mean' , 'sum' , 'none' ] = 'mean' ) -> torch . Tensor :
+
+ View Source
+
+
+
+ 54 def contrastive_loss (
+55 x : torch . Tensor ,
+56 y : torch . Tensor ,
+57 logit_scale : torch . Tensor ,
+58 reduction : Literal [ "mean" , "sum" , "none" ] = "mean" ,
+59 ) -> torch . Tensor :
+60 """
+61 CLIP-like contrastive loss
+62
+63 Args:
+64 x (`torch.Tensor`): prediction
+65 y (`torch.Tensor`): target
+66 logit_scale (`torch.Tensor`): logit scale
+67 reduction (`Literal["mean", "sum", "none"]`): reduction to apply
+68
+69 Returns: the contrastive loss
+70 """
+71 xn = normalize ( x )
+72 yn = normalize ( y )
+73 logits = torch . clamp ( logit_scale . exp (), max = 100 ) * xn @ yn . t ()
+74 labels = torch . arange ( xn . size ( 0 )) . to ( logits . device )
+75 ce = cross_entropy ( logits , labels , reduction = reduction )
+76 ce_t = cross_entropy ( logits . t (), labels , reduction = reduction )
+77 return 0.5 * ( ce + ce_t )
+
+
+
+ CLIP-like contrastive loss
+
+
Arguments:
+
+
+x (torch.Tensor
): prediction
+y (torch.Tensor
): target
+logit_scale (torch.Tensor
): logit scale
+reduction (Literal["mean", "sum", "none"]
): reduction to apply
+
+
+
Returns: the contrastive loss
+
+
+
+
+
+
+
+
+ class
+ ContrastiveLoss (torch.nn.modules.module.Module ):
+
+ View Source
+
+
+
+ 80 class ContrastiveLoss ( torch . nn . Module ):
+ 81 """CLIP-like ContrastiveLoss torch module."""
+ 82
+ 83 def __init__ (
+ 84 self ,
+ 85 logit_scale : torch . Tensor ,
+ 86 reduction : Literal [ "mean" , "sum" , "none" ] = "mean" ,
+ 87 learn_logit_scale : bool = False ,
+ 88 ) -> None :
+ 89 """
+ 90 Initializes a contrastive loss.
+ 91
+ 92 Args:
+ 93 logit_scale (`torch.Tensor`): logit_scale tensor.
+ 94 reduction (`Literal["mean", "sum", "none"]`): reduction to apply to the
+ 95 loss. Defaults to `"mean"`.
+ 96 learn_logit_scale (`torch.Tensor`): whether to learn the `logit_scale`
+ 97 parameter. Defaults to `False`.
+ 98 """
+ 99 super () . __init__ ()
+100
+101 if learn_logit_scale :
+102 self . logit_scale = torch . nn . Parameter ( logit_scale )
+103 else :
+104 self . register_buffer ( "logit_scale" , logit_scale )
+105 self . learn_logit_scale = learn_logit_scale
+106 self . reduction : Literal [ "mean" , "sum" , "none" ] = reduction
+107
+108 def forward ( self , x : torch . Tensor , y : torch . Tensor ) -> LossOutput :
+109 """
+110 Computes the loss.
+111
+112 Args:
+113 x (`torch.Tensor`): prediction
+114 y (`torch.Tensor`): target
+115
+116 Returns:
+117 LossOutput of the loss. Contains a `logit_scale` metric.
+118 """
+119 return LossOutput (
+120 contrastive_loss ( x , y , self . logit_scale , self . reduction ),
+121 { "logit_scale" : self . logit_scale . exp ()},
+122 )
+
+
+
+ CLIP-like ContrastiveLoss torch module.
+
+
+
+
+
+
+
+ ContrastiveLoss ( logit_scale : torch . Tensor , reduction : Literal [ 'mean' , 'sum' , 'none' ] = 'mean' , learn_logit_scale : bool = False )
+
+ View Source
+
+
+
+
83 def __init__ (
+ 84 self ,
+ 85 logit_scale : torch . Tensor ,
+ 86 reduction : Literal [ "mean" , "sum" , "none" ] = "mean" ,
+ 87 learn_logit_scale : bool = False ,
+ 88 ) -> None :
+ 89 """
+ 90 Initializes a contrastive loss.
+ 91
+ 92 Args:
+ 93 logit_scale (`torch.Tensor`): logit_scale tensor.
+ 94 reduction (`Literal["mean", "sum", "none"]`): reduction to apply to the
+ 95 loss. Defaults to `"mean"`.
+ 96 learn_logit_scale (`torch.Tensor`): whether to learn the `logit_scale`
+ 97 parameter. Defaults to `False`.
+ 98 """
+ 99 super () . __init__ ()
+100
+101 if learn_logit_scale :
+102 self . logit_scale = torch . nn . Parameter ( logit_scale )
+103 else :
+104 self . register_buffer ( "logit_scale" , logit_scale )
+105 self . learn_logit_scale = learn_logit_scale
+106 self . reduction : Literal [ "mean" , "sum" , "none" ] = reduction
+
+
+
+
Initializes a contrastive loss.
+
+
Arguments:
+
+
+logit_scale (torch.Tensor
): logit_scale tensor.
+reduction (Literal["mean", "sum", "none"]
): reduction to apply to the
+loss. Defaults to "mean"
.
+learn_logit_scale (torch.Tensor
): whether to learn the logit_scale
+parameter. Defaults to False
.
+
+
+
+
+
+
+
+ learn_logit_scale
+
+
+
+
+
+
+
+
+
+
+ reduction : Literal['mean', 'sum', 'none']
+
+
+
+
+
+
+
+
+
+
+
+
+
108 def forward ( self , x : torch . Tensor , y : torch . Tensor ) -> LossOutput :
+109 """
+110 Computes the loss.
+111
+112 Args:
+113 x (`torch.Tensor`): prediction
+114 y (`torch.Tensor`): target
+115
+116 Returns:
+117 LossOutput of the loss. Contains a `logit_scale` metric.
+118 """
+119 return LossOutput (
+120 contrastive_loss ( x , y , self . logit_scale , self . reduction ),
+121 { "logit_scale" : self . logit_scale . exp ()},
+122 )
+
+
+
+
Computes the loss.
+
+
Arguments:
+
+
+x (torch.Tensor
): prediction
+y (torch.Tensor
): target
+
+
+
Returns:
+
+
+ LossOutput of the loss. Contains a logit_scale
metric.
+
+
+
+
+
+
+
Inherited Members
+
+
torch.nn.modules.module.Module
+ dump_patches
+ training
+ call_super_init
+ register_buffer
+ register_parameter
+ add_module
+ register_module
+ get_submodule
+ get_parameter
+ get_buffer
+
+
+ apply
+ cuda
+ ipu
+ xpu
+ cpu
+ type
+ float
+ double
+ half
+ bfloat16
+ to_empty
+ to
+ register_full_backward_pre_hook
+ register_backward_hook
+ register_full_backward_hook
+ register_forward_pre_hook
+ register_forward_hook
+ register_state_dict_pre_hook
+ state_dict
+ register_load_state_dict_post_hook
+ load_state_dict
+ parameters
+ named_parameters
+ buffers
+ named_buffers
+ children
+ named_children
+ modules
+ named_modules
+ train
+ eval
+ requires_grad_
+ zero_grad
+ share_memory
+
+ compile
+
+
+
+
+
+
+
+
\ No newline at end of file
diff --git a/docs/api/v0.5.1/shimmer/modules/domain.html b/docs/api/v0.5.1/shimmer/modules/domain.html
new file mode 100644
index 00000000..5462b500
--- /dev/null
+++ b/docs/api/v0.5.1/shimmer/modules/domain.html
@@ -0,0 +1,1220 @@
+
+
+
+
+
+
+ shimmer.modules.domain API documentation
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+shimmer.modules.domain
+
+
+
+
+ View Source
+
+ 1 from dataclasses import dataclass , field
+ 2 from typing import Any
+ 3
+ 4 import lightning.pytorch as pl
+ 5 import torch
+ 6
+ 7
+ 8 @dataclass
+ 9 class LossOutput :
+ 10 """
+ 11 This is a python dataclass use as a returned value for losses.
+ 12 It keeps track of what is used for training (`loss`) and what is used
+ 13 only for logging (`metrics`).
+ 14 """
+ 15
+ 16 loss : torch . Tensor
+ 17 """Loss used during training."""
+ 18
+ 19 metrics : dict [ str , torch . Tensor ] = field ( default_factory = dict )
+ 20 """Some additional metrics to log (not used during training)."""
+ 21
+ 22 def __post_init__ ( self ):
+ 23 if "loss" in self . metrics :
+ 24 raise ValueError ( "'loss' cannot be a key of metrics." )
+ 25
+ 26 @property
+ 27 def all ( self ) -> dict [ str , torch . Tensor ]:
+ 28 """
+ 29 Returns a dict with all metrics and loss with "loss" key.
+ 30 """
+ 31 return { ** self . metrics , "loss" : self . loss }
+ 32
+ 33
+ 34 class DomainModule ( pl . LightningModule ):
+ 35 """
+ 36 Base class for a DomainModule that defines domain specific modules of the GW.
+ 37 """
+ 38
+ 39 def __init__ (
+ 40 self ,
+ 41 latent_dim : int ,
+ 42 ) -> None :
+ 43 """
+ 44 Initializes a DomainModule.
+ 45
+ 46 Args:
+ 47 latent_dim (`int`): latent dimension of the unimodal module
+ 48 """
+ 49 super () . __init__ ()
+ 50
+ 51 self . latent_dim = latent_dim
+ 52 """The latent dimension of the module."""
+ 53
+ 54 def encode ( self , x : Any ) -> torch . Tensor :
+ 55 """
+ 56 Encode the domain data into a unimodal representation.
+ 57
+ 58 Args:
+ 59 x (`Any`): data of the domain.
+ 60 Returns:
+ 61 `torch.Tensor`: a unimodal representation.
+ 62 """
+ 63 raise NotImplementedError
+ 64
+ 65 def decode ( self , z : torch . Tensor ) -> Any :
+ 66 """
+ 67 Decode data from unimodal representation back to the domain data.
+ 68
+ 69 Args:
+ 70 z (`torch.Tensor`): unimodal representation of the domain.
+ 71 Returns:
+ 72 `Any`: the original domain data.
+ 73 """
+ 74 raise NotImplementedError
+ 75
+ 76 def compute_loss ( self , pred : torch . Tensor , target : torch . Tensor ) -> LossOutput :
+ 77 """
+ 78 Generic loss computation the modality.
+ 79
+ 80 Args:
+ 81 pred (`torch.Tensor`): prediction of the model
+ 82 target (`torch.Tensor`): target tensor
+ 83 Results:
+ 84 `LossOutput`: LossOuput with training loss and additional metrics.
+ 85 """
+ 86 raise NotImplementedError
+ 87
+ 88 def compute_dcy_loss ( self , pred : torch . Tensor , target : torch . Tensor ) -> LossOutput :
+ 89 """
+ 90 Computes the loss for a demi-cycle. Override if the demi-cycle loss is
+ 91 different that the generic loss.
+ 92
+ 93 Args:
+ 94 pred (`torch.Tensor`): prediction of the model
+ 95 target (`torch.Tensor`): target tensor
+ 96 Results:
+ 97 `LossOutput`: LossOuput with training loss and additional metrics.
+ 98 """
+ 99 return self . compute_loss ( pred , target )
+100
+101 def compute_cy_loss ( self , pred : torch . Tensor , target : torch . Tensor ) -> LossOutput :
+102 """
+103 Computes the loss for a cycle. Override if the cycle loss is
+104 different that the generic loss.
+105
+106 Args:
+107 pred (`torch.Tensor`): prediction of the model
+108 target (`torch.Tensor`): target tensor
+109 Results:
+110 `LossOutput`: LossOuput with training loss and additional metrics.
+111 """
+112 return self . compute_loss ( pred , target )
+113
+114 def compute_tr_loss ( self , pred : torch . Tensor , target : torch . Tensor ) -> LossOutput :
+115 """
+116 Computes the loss for a translation. Override if the translation loss is
+117 different that the generic loss.
+118
+119 Args:
+120 pred (`torch.Tensor`): prediction of the model
+121 target (`torch.Tensor`): target tensor
+122 Results:
+123 `LossOutput`: LossOuput with training loss and additional metrics.
+124 """
+125 return self . compute_loss ( pred , target )
+126
+127 def compute_broadcast_loss (
+128 self , pred : torch . Tensor , target : torch . Tensor
+129 ) -> LossOutput :
+130 """
+131 Computes the loss for a broadcast (fusion). Override if the broadcast loss is
+132 different that the generic loss.
+133
+134 Args:
+135 pred (`torch.Tensor`): prediction of the model
+136 target (`torch.Tensor`): target tensor
+137 Results:
+138 `LossOutput`: LossOuput with training loss and additional metrics.
+139 """
+140 return self . compute_loss ( pred , target )
+
+
+
+
+
+
+
+
@dataclass
+
+
class
+
LossOutput :
+
+
View Source
+
+
+
+ 9 @dataclass
+10 class LossOutput :
+11 """
+12 This is a python dataclass use as a returned value for losses.
+13 It keeps track of what is used for training (`loss`) and what is used
+14 only for logging (`metrics`).
+15 """
+16
+17 loss : torch . Tensor
+18 """Loss used during training."""
+19
+20 metrics : dict [ str , torch . Tensor ] = field ( default_factory = dict )
+21 """Some additional metrics to log (not used during training)."""
+22
+23 def __post_init__ ( self ):
+24 if "loss" in self . metrics :
+25 raise ValueError ( "'loss' cannot be a key of metrics." )
+26
+27 @property
+28 def all ( self ) -> dict [ str , torch . Tensor ]:
+29 """
+30 Returns a dict with all metrics and loss with "loss" key.
+31 """
+32 return { ** self . metrics , "loss" : self . loss }
+
+
+
+ This is a python dataclass use as a returned value for losses.
+It keeps track of what is used for training (loss
) and what is used
+only for logging (metrics
).
+
+
+
+
+
+
+ LossOutput (loss : torch . Tensor , metrics : dict [ str , torch . Tensor ] = < factory > )
+
+
+
+
+
+
+
+
+
+
+ loss : torch.Tensor
+
+
+
+
+
+
Loss used during training.
+
+
+
+
+
+
+ metrics : dict[str, torch.Tensor]
+
+
+
+
+
+
Some additional metrics to log (not used during training).
+
+
+
+
+
+
+
+ all : dict[str, torch.Tensor]
+
+ View Source
+
+
+
+
27 @property
+28 def all ( self ) -> dict [ str , torch . Tensor ]:
+29 """
+30 Returns a dict with all metrics and loss with "loss" key.
+31 """
+32 return { ** self . metrics , "loss" : self . loss }
+
+
+
+
Returns a dict with all metrics and loss with "loss" key.
+
+
+
+
+
+
+
+
+
+ class
+ DomainModule (lightning.pytorch.core.module.LightningModule ):
+
+ View Source
+
+
+
+ 35 class DomainModule ( pl . LightningModule ):
+ 36 """
+ 37 Base class for a DomainModule that defines domain specific modules of the GW.
+ 38 """
+ 39
+ 40 def __init__ (
+ 41 self ,
+ 42 latent_dim : int ,
+ 43 ) -> None :
+ 44 """
+ 45 Initializes a DomainModule.
+ 46
+ 47 Args:
+ 48 latent_dim (`int`): latent dimension of the unimodal module
+ 49 """
+ 50 super () . __init__ ()
+ 51
+ 52 self . latent_dim = latent_dim
+ 53 """The latent dimension of the module."""
+ 54
+ 55 def encode ( self , x : Any ) -> torch . Tensor :
+ 56 """
+ 57 Encode the domain data into a unimodal representation.
+ 58
+ 59 Args:
+ 60 x (`Any`): data of the domain.
+ 61 Returns:
+ 62 `torch.Tensor`: a unimodal representation.
+ 63 """
+ 64 raise NotImplementedError
+ 65
+ 66 def decode ( self , z : torch . Tensor ) -> Any :
+ 67 """
+ 68 Decode data from unimodal representation back to the domain data.
+ 69
+ 70 Args:
+ 71 z (`torch.Tensor`): unimodal representation of the domain.
+ 72 Returns:
+ 73 `Any`: the original domain data.
+ 74 """
+ 75 raise NotImplementedError
+ 76
+ 77 def compute_loss ( self , pred : torch . Tensor , target : torch . Tensor ) -> LossOutput :
+ 78 """
+ 79 Generic loss computation the modality.
+ 80
+ 81 Args:
+ 82 pred (`torch.Tensor`): prediction of the model
+ 83 target (`torch.Tensor`): target tensor
+ 84 Results:
+ 85 `LossOutput`: LossOuput with training loss and additional metrics.
+ 86 """
+ 87 raise NotImplementedError
+ 88
+ 89 def compute_dcy_loss ( self , pred : torch . Tensor , target : torch . Tensor ) -> LossOutput :
+ 90 """
+ 91 Computes the loss for a demi-cycle. Override if the demi-cycle loss is
+ 92 different that the generic loss.
+ 93
+ 94 Args:
+ 95 pred (`torch.Tensor`): prediction of the model
+ 96 target (`torch.Tensor`): target tensor
+ 97 Results:
+ 98 `LossOutput`: LossOuput with training loss and additional metrics.
+ 99 """
+100 return self . compute_loss ( pred , target )
+101
+102 def compute_cy_loss ( self , pred : torch . Tensor , target : torch . Tensor ) -> LossOutput :
+103 """
+104 Computes the loss for a cycle. Override if the cycle loss is
+105 different that the generic loss.
+106
+107 Args:
+108 pred (`torch.Tensor`): prediction of the model
+109 target (`torch.Tensor`): target tensor
+110 Results:
+111 `LossOutput`: LossOuput with training loss and additional metrics.
+112 """
+113 return self . compute_loss ( pred , target )
+114
+115 def compute_tr_loss ( self , pred : torch . Tensor , target : torch . Tensor ) -> LossOutput :
+116 """
+117 Computes the loss for a translation. Override if the translation loss is
+118 different that the generic loss.
+119
+120 Args:
+121 pred (`torch.Tensor`): prediction of the model
+122 target (`torch.Tensor`): target tensor
+123 Results:
+124 `LossOutput`: LossOuput with training loss and additional metrics.
+125 """
+126 return self . compute_loss ( pred , target )
+127
+128 def compute_broadcast_loss (
+129 self , pred : torch . Tensor , target : torch . Tensor
+130 ) -> LossOutput :
+131 """
+132 Computes the loss for a broadcast (fusion). Override if the broadcast loss is
+133 different that the generic loss.
+134
+135 Args:
+136 pred (`torch.Tensor`): prediction of the model
+137 target (`torch.Tensor`): target tensor
+138 Results:
+139 `LossOutput`: LossOuput with training loss and additional metrics.
+140 """
+141 return self . compute_loss ( pred , target )
+
+
+
+ Base class for a DomainModule that defines domain specific modules of the GW.
+
+
+
+
+
+
+
+ DomainModule (latent_dim : int )
+
+ View Source
+
+
+
+
40 def __init__ (
+41 self ,
+42 latent_dim : int ,
+43 ) -> None :
+44 """
+45 Initializes a DomainModule.
+46
+47 Args:
+48 latent_dim (`int`): latent dimension of the unimodal module
+49 """
+50 super () . __init__ ()
+51
+52 self . latent_dim = latent_dim
+53 """The latent dimension of the module."""
+
+
+
+
Initializes a DomainModule.
+
+
Arguments:
+
+
+latent_dim (int
): latent dimension of the unimodal module
+
+
+
+
+
+
+
+ latent_dim
+
+
+
+
+
+
The latent dimension of the module.
+
+
+
+
+
+
+
+
+ def
+ encode (self , x : Any ) -> torch . Tensor :
+
+ View Source
+
+
+
+
55 def encode ( self , x : Any ) -> torch . Tensor :
+56 """
+57 Encode the domain data into a unimodal representation.
+58
+59 Args:
+60 x (`Any`): data of the domain.
+61 Returns:
+62 `torch.Tensor`: a unimodal representation.
+63 """
+64 raise NotImplementedError
+
+
+
+
Encode the domain data into a unimodal representation.
+
+
Arguments:
+
+
+x (Any
): data of the domain.
+
+
+
Returns:
+
+
+ torch.Tensor
: a unimodal representation.
+
+
+
+
+
+
+
+
+
+ def
+ decode (self , z : torch . Tensor ) -> Any :
+
+ View Source
+
+
+
+
66 def decode ( self , z : torch . Tensor ) -> Any :
+67 """
+68 Decode data from unimodal representation back to the domain data.
+69
+70 Args:
+71 z (`torch.Tensor`): unimodal representation of the domain.
+72 Returns:
+73 `Any`: the original domain data.
+74 """
+75 raise NotImplementedError
+
+
+
+
Decode data from unimodal representation back to the domain data.
+
+
Arguments:
+
+
+z (torch.Tensor
): unimodal representation of the domain.
+
+
+
Returns:
+
+
+ Any
: the original domain data.
+
+
+
+
+
+
+
+
+
+
def
+
compute_loss ( self , pred : torch . Tensor , target : torch . Tensor ) -> LossOutput :
+
+
View Source
+
+
+
+
77 def compute_loss ( self , pred : torch . Tensor , target : torch . Tensor ) -> LossOutput :
+78 """
+79 Generic loss computation the modality.
+80
+81 Args:
+82 pred (`torch.Tensor`): prediction of the model
+83 target (`torch.Tensor`): target tensor
+84 Results:
+85 `LossOutput`: LossOuput with training loss and additional metrics.
+86 """
+87 raise NotImplementedError
+
+
+
+
Generic loss computation the modality.
+
+
Arguments:
+
+
+pred (torch.Tensor
): prediction of the model
+target (torch.Tensor
): target tensor
+
+
+
Results:
+
+
+ LossOutput
: LossOuput with training loss and additional metrics.
+
+
+
+
+
+
+
+
+
+
def
+
compute_dcy_loss ( self , pred : torch . Tensor , target : torch . Tensor ) -> LossOutput :
+
+
View Source
+
+
+
+
89 def compute_dcy_loss ( self , pred : torch . Tensor , target : torch . Tensor ) -> LossOutput :
+ 90 """
+ 91 Computes the loss for a demi-cycle. Override if the demi-cycle loss is
+ 92 different that the generic loss.
+ 93
+ 94 Args:
+ 95 pred (`torch.Tensor`): prediction of the model
+ 96 target (`torch.Tensor`): target tensor
+ 97 Results:
+ 98 `LossOutput`: LossOuput with training loss and additional metrics.
+ 99 """
+100 return self . compute_loss ( pred , target )
+
+
+
+
Computes the loss for a demi-cycle. Override if the demi-cycle loss is
+different that the generic loss.
+
+
Arguments:
+
+
+pred (torch.Tensor
): prediction of the model
+target (torch.Tensor
): target tensor
+
+
+
Results:
+
+
+ LossOutput
: LossOuput with training loss and additional metrics.
+
+
+
+
+
+
+
+
+
+
def
+
compute_cy_loss ( self , pred : torch . Tensor , target : torch . Tensor ) -> LossOutput :
+
+
View Source
+
+
+
+
102 def compute_cy_loss ( self , pred : torch . Tensor , target : torch . Tensor ) -> LossOutput :
+103 """
+104 Computes the loss for a cycle. Override if the cycle loss is
+105 different that the generic loss.
+106
+107 Args:
+108 pred (`torch.Tensor`): prediction of the model
+109 target (`torch.Tensor`): target tensor
+110 Results:
+111 `LossOutput`: LossOuput with training loss and additional metrics.
+112 """
+113 return self . compute_loss ( pred , target )
+
+
+
+
Computes the loss for a cycle. Override if the cycle loss is
+different that the generic loss.
+
+
Arguments:
+
+
+pred (torch.Tensor
): prediction of the model
+target (torch.Tensor
): target tensor
+
+
+
Results:
+
+
+ LossOutput
: LossOuput with training loss and additional metrics.
+
+
+
+
+
+
+
+
+
+
def
+
compute_tr_loss ( self , pred : torch . Tensor , target : torch . Tensor ) -> LossOutput :
+
+
View Source
+
+
+
+
115 def compute_tr_loss ( self , pred : torch . Tensor , target : torch . Tensor ) -> LossOutput :
+116 """
+117 Computes the loss for a translation. Override if the translation loss is
+118 different that the generic loss.
+119
+120 Args:
+121 pred (`torch.Tensor`): prediction of the model
+122 target (`torch.Tensor`): target tensor
+123 Results:
+124 `LossOutput`: LossOuput with training loss and additional metrics.
+125 """
+126 return self . compute_loss ( pred , target )
+
+
+
+
Computes the loss for a translation. Override if the translation loss is
+different that the generic loss.
+
+
Arguments:
+
+
+pred (torch.Tensor
): prediction of the model
+target (torch.Tensor
): target tensor
+
+
+
Results:
+
+
+ LossOutput
: LossOuput with training loss and additional metrics.
+
+
+
+
+
+
+
+
+
+
def
+
compute_broadcast_loss ( self , pred : torch . Tensor , target : torch . Tensor ) -> LossOutput :
+
+
View Source
+
+
+
+
128 def compute_broadcast_loss (
+129 self , pred : torch . Tensor , target : torch . Tensor
+130 ) -> LossOutput :
+131 """
+132 Computes the loss for a broadcast (fusion). Override if the broadcast loss is
+133 different that the generic loss.
+134
+135 Args:
+136 pred (`torch.Tensor`): prediction of the model
+137 target (`torch.Tensor`): target tensor
+138 Results:
+139 `LossOutput`: LossOuput with training loss and additional metrics.
+140 """
+141 return self . compute_loss ( pred , target )
+
+
+
+
Computes the loss for a broadcast (fusion). Override if the broadcast loss is
+different that the generic loss.
+
+
Arguments:
+
+
+pred (torch.Tensor
): prediction of the model
+target (torch.Tensor
): target tensor
+
+
+
Results:
+
+
+ LossOutput
: LossOuput with training loss and additional metrics.
+
+
+
+
+
+
+
Inherited Members
+
+
lightning.pytorch.core.module.LightningModule
+ CHECKPOINT_HYPER_PARAMS_KEY
+ CHECKPOINT_HYPER_PARAMS_NAME
+ CHECKPOINT_HYPER_PARAMS_TYPE
+ optimizers
+ lr_schedulers
+ trainer
+ fabric
+ example_input_array
+ current_epoch
+ global_step
+ global_rank
+ local_rank
+ on_gpu
+ automatic_optimization
+ strict_loading
+ logger
+ loggers
+ print
+ log
+ log_dict
+ all_gather
+ forward
+ training_step
+ validation_step
+ test_step
+ predict_step
+ configure_callbacks
+ configure_optimizers
+ manual_backward
+ backward
+ toggle_optimizer
+ untoggle_optimizer
+ clip_gradients
+ configure_gradient_clipping
+ lr_scheduler_step
+ optimizer_step
+ optimizer_zero_grad
+ freeze
+ unfreeze
+ to_onnx
+ to_torchscript
+ load_from_checkpoint
+
+
+
lightning.fabric.utilities.device_dtype_mixin._DeviceDtypeModuleMixin
+ dtype
+ device
+ to
+ cuda
+ cpu
+ type
+ float
+ double
+ half
+
+
+
lightning.pytorch.core.mixins.hparams_mixin.HyperparametersMixin
+ save_hyperparameters
+ hparams
+ hparams_initial
+
+
+
lightning.pytorch.core.hooks.ModelHooks
+ on_fit_start
+ on_fit_end
+ on_train_start
+ on_train_end
+ on_validation_start
+ on_validation_end
+ on_test_start
+ on_test_end
+ on_predict_start
+ on_predict_end
+ on_train_batch_start
+ on_train_batch_end
+ on_validation_batch_start
+ on_validation_batch_end
+ on_test_batch_start
+ on_test_batch_end
+ on_predict_batch_start
+ on_predict_batch_end
+ on_validation_model_zero_grad
+ on_validation_model_eval
+ on_validation_model_train
+ on_test_model_eval
+ on_test_model_train
+ on_predict_model_eval
+ on_train_epoch_start
+ on_train_epoch_end
+ on_validation_epoch_start
+ on_validation_epoch_end
+ on_test_epoch_start
+ on_test_epoch_end
+ on_predict_epoch_start
+ on_predict_epoch_end
+ on_before_zero_grad
+ on_before_backward
+ on_after_backward
+ on_before_optimizer_step
+ configure_sharded_model
+ configure_model
+
+
+
lightning.pytorch.core.hooks.DataHooks
+ prepare_data_per_node
+ allow_zero_length_dataloader_with_multiple_devices
+ prepare_data
+ setup
+ teardown
+ train_dataloader
+ test_dataloader
+ val_dataloader
+ predict_dataloader
+ transfer_batch_to_device
+ on_before_batch_transfer
+ on_after_batch_transfer
+
+
+
lightning.pytorch.core.hooks.CheckpointHooks
+ on_load_checkpoint
+ on_save_checkpoint
+
+
+
torch.nn.modules.module.Module
+ dump_patches
+ training
+ call_super_init
+ register_buffer
+ register_parameter
+ add_module
+ register_module
+ get_submodule
+ get_parameter
+ get_buffer
+ get_extra_state
+ set_extra_state
+ apply
+ ipu
+ xpu
+ bfloat16
+ to_empty
+ register_full_backward_pre_hook
+ register_backward_hook
+ register_full_backward_hook
+ register_forward_pre_hook
+ register_forward_hook
+ register_state_dict_pre_hook
+ state_dict
+ register_load_state_dict_post_hook
+ load_state_dict
+ parameters
+ named_parameters
+ buffers
+ named_buffers
+ children
+ named_children
+ modules
+ named_modules
+ train
+ eval
+ requires_grad_
+ zero_grad
+ share_memory
+ extra_repr
+ compile
+
+
+
+
+
+
+
+
\ No newline at end of file
diff --git a/docs/api/v0.5.1/shimmer/modules/global_workspace.html b/docs/api/v0.5.1/shimmer/modules/global_workspace.html
new file mode 100644
index 00000000..03727560
--- /dev/null
+++ b/docs/api/v0.5.1/shimmer/modules/global_workspace.html
@@ -0,0 +1,4260 @@
+
+
+
+
+
+
+ shimmer.modules.global_workspace API documentation
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+shimmer.modules.global_workspace
+
+
+
+
+ View Source
+
+ 1 from collections.abc import Iterable , Mapping
+ 2 from pathlib import Path
+ 3 from typing import Any , Generic , TypedDict , TypeVar , cast
+ 4
+ 5 import torch
+ 6 from lightning.pytorch import LightningModule
+ 7 from lightning.pytorch.utilities.types import OptimizerLRSchedulerConfig
+ 8 from torch.nn import Module , ModuleDict
+ 9 from torch.optim.lr_scheduler import OneCycleLR
+ 10
+ 11 from shimmer.modules.contrastive_loss import ContrastiveLoss , ContrastiveLossType
+ 12 from shimmer.modules.domain import DomainModule
+ 13 from shimmer.modules.gw_module import (
+ 14 GWModule ,
+ 15 GWModuleBase ,
+ 16 GWModuleBayesian ,
+ 17 )
+ 18 from shimmer.modules.losses import (
+ 19 BroadcastLossCoefs ,
+ 20 GWLosses ,
+ 21 GWLosses2Domains ,
+ 22 GWLossesBase ,
+ 23 GWLossesBayesian ,
+ 24 LossCoefs ,
+ 25 )
+ 26 from shimmer.modules.selection import (
+ 27 FixedSharedSelection ,
+ 28 RandomSelection ,
+ 29 SelectionBase ,
+ 30 SingleDomainSelection ,
+ 31 )
+ 32 from shimmer.modules.utils import batch_cycles , batch_demi_cycles , batch_translations
+ 33 from shimmer.types import (
+ 34 LatentsDomainGroupsDT ,
+ 35 LatentsDomainGroupsT ,
+ 36 ModelModeT ,
+ 37 RawDomainGroupsDT ,
+ 38 RawDomainGroupsT ,
+ 39 RawDomainGroupT ,
+ 40 )
+ 41 from shimmer.utils import groups_batch_size
+ 42
+ 43
+ 44 class SchedulerArgs ( TypedDict , total = False ):
+ 45 """TypedDict of arguments passed to the OneCycle scheduler"""
+ 46
+ 47 max_lr : float
+ 48 """Maximum learning rate"""
+ 49
+ 50 total_steps : int
+ 51 """Total number of steps"""
+ 52
+ 53
+ 54 class GWPredictionsBase ( TypedDict ):
+ 55 """TypedDict of the output given when calling `GlobalWorkspaceBase.predict`"""
+ 56
+ 57 states : dict [ str , torch . Tensor ]
+ 58 """
+ 59 GW state representation from domain groups with only one domain.
+ 60 The key represent the domain's name.
+ 61 """
+ 62
+ 63
+ 64 _T_gw_mod = TypeVar ( "_T_gw_mod" , bound = GWModuleBase )
+ 65 _T_selection_mod = TypeVar ( "_T_selection_mod" , bound = SelectionBase )
+ 66 _T_loss_mod = TypeVar ( "_T_loss_mod" , bound = GWLossesBase )
+ 67
+ 68
+ 69 class GlobalWorkspaceBase (
+ 70 Generic [ _T_gw_mod , _T_selection_mod , _T_loss_mod ], LightningModule
+ 71 ):
+ 72 """
+ 73 Global Workspace Lightning Module.
+ 74
+ 75 This is the base class to build the Global Workspace.
+ 76 """
+ 77
+ 78 def __init__ (
+ 79 self ,
+ 80 gw_mod : _T_gw_mod ,
+ 81 selection_mod : _T_selection_mod ,
+ 82 loss_mod : _T_loss_mod ,
+ 83 optim_lr : float = 1e-3 ,
+ 84 optim_weight_decay : float = 0.0 ,
+ 85 scheduler_args : SchedulerArgs | None = None ,
+ 86 ) -> None :
+ 87 """
+ 88 Initializes a GW
+ 89
+ 90 Args:
+ 91 gw_mod (`GWModuleBase`): the GWModule
+ 92 selection_mod (`SelectionBase`): selection module
+ 93 loss_mod (`GWLossesBase`): module to compute the GW losses.
+ 94 optim_lr (`float`): learning rate
+ 95 optim_weight_decay (`float`): weight decay
+ 96 scheduler_args (`SchedulerArgs`): `SchedulerArgs` instance to define
+ 97 scheduler parameters.
+ 98 """
+ 99 super () . __init__ ()
+100 self . save_hyperparameters (
+101 ignore = [
+102 "gw_mod" ,
+103 "selection_mod" ,
+104 "domain_mods" ,
+105 "loss_mod" ,
+106 "domain_descriptions" ,
+107 "contrastive_loss" ,
+108 "cont_loss_bayesian" ,
+109 "gw_encoders" ,
+110 "gw_decoders" ,
+111 ]
+112 )
+113
+114 self . gw_mod = gw_mod
+115 """ a `GWModuleBase` implementation."""
+116
+117 self . selection_mod = selection_mod
+118 """A `SelectionBase` implementation."""
+119
+120 self . loss_mod = loss_mod
+121 """The module that computes losses of the GW"""
+122
+123 self . optim_lr = optim_lr
+124 self . optim_weight_decay = optim_weight_decay
+125 self . scheduler_args = SchedulerArgs ( max_lr = optim_lr , total_steps = 1 )
+126 if scheduler_args is not None :
+127 self . scheduler_args . update ( scheduler_args )
+128
+129 @property
+130 def domain_mods ( self ) -> Mapping [ str , DomainModule ]:
+131 return self . gw_mod . domain_mods
+132
+133 @property
+134 def workspace_dim ( self ) -> int :
+135 """Dimension of the GW."""
+136 return self . gw_mod . workspace_dim
+137
+138 def encode_and_fuse (
+139 self , x : LatentsDomainGroupsT , selection_module : SelectionBase
+140 ) -> dict [ frozenset [ str ], torch . Tensor ]:
+141 """
+142 Encode a group of latent representations into the GW representation.
+143
+144 Args:
+145 x (`LatentsDomainGroupsT`): the input domain representations.
+146 selection_scores (`Mapping[str, torch.Tensor]`):
+147
+148 Returns:
+149 `dict[frozenset[str], torch.Tensor]`: the GW representations.
+150 """
+151 return {
+152 domains : self . gw_mod . encode_and_fuse ( latents , selection_module )
+153 for domains , latents in x . items ()
+154 }
+155
+156 def encode ( self , x : LatentsDomainGroupsT ) -> LatentsDomainGroupsDT :
+157 """
+158 Encode a group of latent representations into the pre-fusion GW representation.
+159
+160 Args:
+161 x (`LatentsDomainGroupsT`): the input domain representations.
+162
+163 Returns:
+164 `LatensDomainGroupsDT`: the GW representations.
+165 """
+166 return { domains : self . gw_mod . encode ( latents ) for domains , latents in x . items ()}
+167
+168 def fuse (
+169 self ,
+170 x : LatentsDomainGroupsT ,
+171 selection_scores : Mapping [ frozenset [ str ], Mapping [ str , torch . Tensor ]],
+172 ) -> dict [ frozenset [ str ], torch . Tensor ]:
+173 """
+174 Fuses a group of latent representations into the GW representation.
+175
+176 Args:
+177 x (`LatentsDomainGroupsT`): the pre-fusion latent representations
+178 selection_scores (`Mapping[frozenset[str], Mapping[str, torch.Tensor]]`):
+179 selection scores for each group
+180
+181 Returns:
+182 `dict[frozenset[str], torch.Tensor]`: GW representation of each group
+183 """
+184 return {
+185 domains : self . gw_mod . fuse ( latents , selection_scores [ domains ])
+186 for domains , latents in x . items ()
+187 }
+188
+189 def decode (
+190 self ,
+191 z : Mapping [ frozenset [ str ], torch . Tensor ],
+192 domains : Iterable [ str ] | None = None ,
+193 ) -> LatentsDomainGroupsDT :
+194 """
+195 Decode the group GW representation into given `domains`.
+196
+197 Args:
+198 z (`torch.Tensor`): the GW representation.
+199 domains (`Iterable[str]`): iterable of domains to decode.
+200
+201 Returns:
+202 `dict[str, torch.Tensor]`: the decoded unimodal representations.
+203 """
+204 return {
+205 domain_names : self . gw_mod . decode ( gw_rep , domains )
+206 for domain_names , gw_rep in z . items ()
+207 }
+208
+209 def forward ( # type: ignore
+210 self ,
+211 latent_domains : LatentsDomainGroupsT ,
+212 ) -> GWPredictionsBase :
+213 """
+214 Computes demi-cycles, cycles, and translations.
+215
+216 Args:
+217 latent_domains (`LatentsT`): Groups of domains for the computation.
+218
+219 Returns:
+220 `GWPredictionsBase`: the predictions on the batch.
+221 """
+222
+223 return GWPredictionsBase ( states = self . batch_gw_states ( latent_domains ))
+224
+225 def batch_gw_states (
+226 self , latent_domains : LatentsDomainGroupsT
+227 ) -> dict [ str , torch . Tensor ]:
+228 """
+229 Comptues GW states of a batch of groups of domains.
+230
+231 Args:
+232 latent_domains (`LatentsT`): the batch of groups of domains
+233
+234 Returns:
+235 `dict[str, torch.Tensor]`: states for each domain.
+236 """
+237 predictions : dict [ str , torch . Tensor ] = {}
+238 for domains , latents in latent_domains . items ():
+239 if len ( domains ) > 1 :
+240 continue
+241 domain_name = list ( domains )[ 0 ]
+242 z = self . gw_mod . encode_and_fuse (
+243 latents , selection_module = self . selection_mod
+244 )
+245 predictions [ domain_name ] = z
+246 return predictions
+247
+248 def encode_domain ( self , domain : Any , name : str ) -> torch . Tensor :
+249 """
+250 Encodes a domain from the domain data into the unimodal representation.
+251
+252 This is a convenient proxy for the `DomainModule.encode` method and is
+253 equivalent to:
+254 ```python
+255 self.domain_mods[name].encode(domain)
+256 ```
+257
+258 Args:
+259 domain (`Any`): the domain data
+260 name (`str`): domain name to encode
+261
+262 Returns:
+263 `torch.Tensor`: the domain's unimodal representation.
+264 """
+265 return self . domain_mods [ name ] . encode ( domain )
+266
+267 def encode_domains ( self , batch : RawDomainGroupsT ) -> LatentsDomainGroupsDT :
+268 """
+269 Encode all domains in the batch.
+270
+271 Args:
+272 batch (`RawDomainGroupsT`): the batch of
+273 domain groups with raw unimodal data to encode into groups of latent
+274 representations.
+275
+276 Returns:
+277 `LatentsDomainGroupsDT`: the domains' unimodal representations.
+278 """
+279 return {
+280 domains : {
+281 name : self . domain_mods [ name ] . encode ( domain )
+282 for name , domain in data . items ()
+283 }
+284 for domains , data in batch . items ()
+285 }
+286
+287 def decode_domain ( self , domain : torch . Tensor , name : str ) -> Any :
+288 """
+289 Decodes a domain from the unimodal representation into the domain data.
+290
+291 This is a convenient proxy for the `DomainModule.encode` method and is
+292 equivalent to:
+293 ```python
+294 self.domain_mods[name].decode(domain)
+295 ```
+296
+297 Args:
+298 domain (`torch.Tensor`): the domain data
+299 name (`str`): domain name to encode
+300
+301 Returns:
+302 `Any`: the domain's raw data.
+303 """
+304 return self . domain_mods [ name ] . decode ( domain )
+305
+306 def decode_domains ( self , latents_domain : LatentsDomainGroupsT ) -> RawDomainGroupsDT :
+307 """
+308 Decodes all domains in the batch.
+309
+310 Args:
+311 batch (`LatentsDomainGroupsT`): the batch of
+312 domain groups with unimodal latent representation to decode into
+313 groups of raw data.
+314
+315 Returns:
+316 `LatentsDomainGroupsDT`: the domains' raw data.
+317 """
+318 return {
+319 domains : {
+320 name : self . domain_mods [ name ] . decode ( domain )
+321 for name , domain in latents . items ()
+322 }
+323 for domains , latents in latents_domain . items ()
+324 }
+325
+326 def generic_step ( self , batch : RawDomainGroupsT , mode : ModelModeT ) -> torch . Tensor :
+327 """
+328 The generic step used in `training_step`, `validation_step` and
+329 `test_step`.
+330
+331 Args:
+332 batch (`RawDomainGroupsT`): the batch of groups of raw unimodal data.
+333 mode (`ModelModeT`):
+334
+335 Returns:
+336 `torch.Tensor`: the loss to train on.
+337 """
+338 domain_latents = self . encode_domains ( batch )
+339 batch_size = groups_batch_size ( domain_latents )
+340
+341 loss_output = self . loss_mod . step ( domain_latents , mode )
+342
+343 for name , metric in loss_output . all . items ():
+344 self . log (
+345 f " { mode } / { name } " ,
+346 metric ,
+347 batch_size = batch_size ,
+348 add_dataloader_idx = False ,
+349 )
+350
+351 return loss_output . loss
+352
+353 def validation_step ( # type: ignore
+354 self , data : RawDomainGroupT , batch_idx : int , dataloader_idx : int = 0
+355 ) -> torch . Tensor :
+356 """Validation step used by lightning"""
+357
+358 batch = { frozenset ( data . keys ()): data }
+359 for domain in data :
+360 batch [ frozenset ([ domain ])] = { domain : data [ domain ]}
+361 if dataloader_idx == 0 :
+362 return self . generic_step ( batch , mode = "val" )
+363 return self . generic_step ( batch , mode = "val/ood" )
+364
+365 def test_step ( # type: ignore
+366 self , data : Mapping [ str , Any ], batch_idx : int , dataloader_idx : int = 0
+367 ) -> torch . Tensor :
+368 """Test step used by lightning"""
+369
+370 batch = { frozenset ( data . keys ()): data }
+371 for domain in data :
+372 batch [ frozenset ([ domain ])] = { domain : data [ domain ]}
+373 if dataloader_idx == 0 :
+374 return self . generic_step ( batch , mode = "test" )
+375 return self . generic_step ( batch , mode = "test/ood" )
+376
+377 def training_step ( # type: ignore
+378 self , batch : Mapping [ frozenset [ str ], Mapping [ str , Any ]], batch_idx : int
+379 ) -> torch . Tensor :
+380 """Training step used by lightning"""
+381
+382 return self . generic_step ( batch , mode = "train" )
+383
+384 def predict_step ( # type: ignore
+385 self , data : Mapping [ str , Any ], batch_idx : int
+386 ) -> GWPredictionsBase :
+387 """Predict step used by lightning"""
+388
+389 batch = { frozenset ( data . keys ()): data }
+390 for domain in data :
+391 batch [ frozenset ([ domain ])] = { domain : data [ domain ]}
+392
+393 domain_latents = self . encode_domains ( batch )
+394 return self . forward ( domain_latents )
+395
+396 def configure_optimizers ( self ) -> OptimizerLRSchedulerConfig :
+397 """
+398 Configure models optimizers.
+399
+400 Here we use `AdamW` for the optimizer and `OneCycleLR` for the learning-rate
+401 scheduler.
+402 """
+403
+404 optimizer = torch . optim . AdamW (
+405 self . parameters (),
+406 lr = self . optim_lr ,
+407 weight_decay = self . optim_weight_decay ,
+408 )
+409
+410 lr_scheduler = OneCycleLR ( optimizer , ** self . scheduler_args )
+411
+412 return {
+413 "optimizer" : optimizer ,
+414 "lr_scheduler" : {
+415 "scheduler" : lr_scheduler ,
+416 "interval" : "step" ,
+417 },
+418 }
+419
+420
+421 def freeze_domain_modules (
+422 domain_mods : Mapping [ str , DomainModule ],
+423 ) -> dict [ str , DomainModule ]:
+424 """
+425 Freezes weights and set to eval mode the domain modules.
+426
+427 .. note::
+428 The output is casted as `dict[str, DomainModule]` type for better
+429 auto-completion, but is actually a torch `ModuleDict`.
+430
+431 Args:
+432 domain_mods (`Mapping[str, DomainModule]`): mapping of domain modules to freeze
+433
+434 Returns:
+435 `ModuleDict`: frozen modules.
+436 """
+437
+438 for mod in domain_mods . values ():
+439 mod . freeze ()
+440 # Cast for better auto-completion at the expense of ModuleDict
+441 return cast ( dict [ str , DomainModule ], ModuleDict ( domain_mods ))
+442
+443
+444 class GWPredictions ( GWPredictionsBase ):
+445 """TypedDict of the output given when calling `GlobalWorkspaceBase.predict`"""
+446
+447 demi_cycles : dict [ str , torch . Tensor ]
+448 """
+449 Demi-cycle predictions of the model for each domain. Only computed on domain
+450 groups with only one domain.
+451 """
+452
+453 cycles : dict [ tuple [ str , str ], torch . Tensor ]
+454 """
+455 Cycle predictions of the model from one domain through another one.
+456 Only computed on domain groups with more than one domain.
+457 The keys are tuple with start domain and intermediary domain.
+458 """
+459
+460 translations : dict [ tuple [ str , str ], torch . Tensor ]
+461 """
+462 Translation predictions of the model from one domain through another one.
+463
+464 Only computed on domain groups with more than one domain.
+465 The keys are tuples with start domain and target domain.
+466 """
+467
+468
+469 class GlobalWorkspace2Domains (
+470 GlobalWorkspaceBase [ GWModule , SingleDomainSelection , GWLosses2Domains ]
+471 ):
+472 """
+473 A simple 2-domains max flavor of GlobalWorkspaceBase.
+474
+475 This is used to simplify a Global Workspace instanciation and only overrides the
+476 `__init__` method.
+477 """
+478
+479 def __init__ (
+480 self ,
+481 domain_mods : Mapping [ str , DomainModule ],
+482 gw_encoders : Mapping [ str , Module ],
+483 gw_decoders : Mapping [ str , Module ],
+484 workspace_dim : int ,
+485 loss_coefs : LossCoefs ,
+486 optim_lr : float = 1e-3 ,
+487 optim_weight_decay : float = 0.0 ,
+488 scheduler_args : SchedulerArgs | None = None ,
+489 learn_logit_scale : bool = False ,
+490 contrastive_loss : ContrastiveLossType | None = None ,
+491 ) -> None :
+492 """
+493 Initializes a Global Workspace
+494
+495 Args:
+496 domain_mods (`Mapping[str, DomainModule]`): mapping of the domains
+497 connected to the GW. Keys are domain names, values are the
+498 `DomainModule`.
+499 gw_encoders (`Mapping[str, torch.nn.Module]`): mapping for each domain
+500 name to a `torch.nn.Module` class which role is to encode a
+501 unimodal latent representations into a GW representation (pre fusion).
+502 gw_decoders (`Mapping[str, torch.nn.Module]`): mapping for each domain
+503 name to a `torch.nn.Module` class which role is to decode a
+504 GW representation into a unimodal latent representations.
+505 workspace_dim (`int`): dimension of the GW.
+506 loss_coefs (`LossCoefs`): loss coefficients
+507 optim_lr (`float`): learning rate
+508 optim_weight_decay (`float`): weight decay
+509 scheduler_args (`SchedulerArgs | None`): optimization scheduler's arguments
+510 learn_logit_scale (`bool`): whether to learn the contrastive learning
+511 contrastive loss when using the default contrastive loss.
+512 contrastive_loss (`ContrastiveLossType | None`): a contrastive loss
+513 function used for alignment. `learn_logit_scale` will not affect custom
+514 contrastive losses.
+515 """
+516 domain_mods = freeze_domain_modules ( domain_mods )
+517
+518 gw_mod = GWModule ( domain_mods , workspace_dim , gw_encoders , gw_decoders )
+519 if contrastive_loss is None :
+520 contrastive_loss = ContrastiveLoss (
+521 torch . tensor ([ 1 / 0.07 ]) . log (), "mean" , learn_logit_scale
+522 )
+523 selection_mod = SingleDomainSelection ()
+524 loss_mod = GWLosses2Domains (
+525 gw_mod , selection_mod , domain_mods , loss_coefs , contrastive_loss
+526 )
+527
+528 super () . __init__ (
+529 gw_mod ,
+530 selection_mod ,
+531 loss_mod ,
+532 optim_lr ,
+533 optim_weight_decay ,
+534 scheduler_args ,
+535 )
+536
+537 def forward ( # type: ignore
+538 self ,
+539 latent_domains : LatentsDomainGroupsT ,
+540 ) -> GWPredictions :
+541 """
+542 Computes demi-cycles, cycles, and translations.
+543
+544 Args:
+545 latent_domains (`LatentsT`): Groups of domains for the computation.
+546
+547 Returns:
+548 `GWPredictions`: the predictions on the batch.
+549 """
+550 return GWPredictions (
+551 demi_cycles = batch_demi_cycles (
+552 self . gw_mod , self . selection_mod , latent_domains
+553 ),
+554 cycles = batch_cycles (
+555 self . gw_mod , self . selection_mod , latent_domains , self . domain_mods . keys ()
+556 ),
+557 translations = batch_translations (
+558 self . gw_mod , self . selection_mod , latent_domains
+559 ),
+560 ** super () . forward ( latent_domains ),
+561 )
+562
+563
+564 class GlobalWorkspace ( GlobalWorkspaceBase [ GWModule , RandomSelection , GWLosses ]):
+565 """The 2-domain fusion (with broadcast loss) flavor of GlobalWorkspaceBase.
+566
+567 This is used to simplify a Global Workspace instanciation and only overrides the
+568 `__init__` method.
+569 """
+570
+571 def __init__ (
+572 self ,
+573 domain_mods : Mapping [ str , DomainModule ],
+574 gw_encoders : Mapping [ str , Module ],
+575 gw_decoders : Mapping [ str , Module ],
+576 workspace_dim : int ,
+577 loss_coefs : BroadcastLossCoefs ,
+578 selection_temperature : float = 0.2 ,
+579 optim_lr : float = 1e-3 ,
+580 optim_weight_decay : float = 0.0 ,
+581 scheduler_args : SchedulerArgs | None = None ,
+582 learn_logit_scale : bool = False ,
+583 contrastive_loss : ContrastiveLossType | None = None ,
+584 ) -> None :
+585 """
+586 Initializes a Global Workspace
+587
+588 Args:
+589 domain_mods (`Mapping[str, DomainModule]`): mapping of the domains
+590 connected to the GW. Keys are domain names, values are the
+591 `DomainModule`.
+592 gw_encoders (`Mapping[str, torch.nn.Module]`): mapping for each domain
+593 name to a `torch.nn.Module` class which role is to encode a
+594 unimodal latent representations into a GW representation (pre fusion).
+595 gw_decoders (`Mapping[str, torch.nn.Module]`): mapping for each domain
+596 name to a `torch.nn.Module` class which role is to decode a
+597 GW representation into a unimodal latent representations.
+598 workspace_dim (`int`): dimension of the GW.
+599 loss_coefs (`BroadcastLossCoefs`): loss coefs for the losses.
+600 selection_temperature (`float`): temperature value for the RandomSelection
+601 module.
+602 optim_lr (`float`): learning rate
+603 optim_weight_decay (`float`): weight decay
+604 scheduler_args (`SchedulerArgs | None`): optimization scheduler's arguments
+605 learn_logit_scale (`bool`): whether to learn the contrastive learning
+606 contrastive loss when using the default contrastive loss.
+607 contrastive_loss (`ContrastiveLossType | None`): a contrastive loss
+608 function used for alignment. `learn_logit_scale` will not affect custom
+609 contrastive losses.
+610 """
+611 domain_mods = freeze_domain_modules ( domain_mods )
+612 gw_mod = GWModule ( domain_mods , workspace_dim , gw_encoders , gw_decoders )
+613
+614 if contrastive_loss is None :
+615 contrastive_loss = ContrastiveLoss (
+616 torch . tensor ([ 1 / 0.07 ]) . log (), "mean" , learn_logit_scale
+617 )
+618
+619 selection_mod = RandomSelection ( selection_temperature )
+620 loss_mod = GWLosses (
+621 gw_mod , selection_mod , domain_mods , loss_coefs , contrastive_loss
+622 )
+623
+624 super () . __init__ (
+625 gw_mod ,
+626 selection_mod ,
+627 loss_mod ,
+628 optim_lr ,
+629 optim_weight_decay ,
+630 scheduler_args ,
+631 )
+632
+633 def forward ( # type: ignore
+634 self ,
+635 latent_domains : LatentsDomainGroupsT ,
+636 ) -> GWPredictions :
+637 """
+638 Computes demi-cycles, cycles, and translations.
+639
+640 Args:
+641 latent_domains (`LatentsT`): Groups of domains for the computation.
+642
+643 Returns:
+644 `GWPredictions`: the predictions on the batch.
+645 """
+646 return GWPredictions (
+647 demi_cycles = batch_demi_cycles (
+648 self . gw_mod , self . selection_mod , latent_domains
+649 ),
+650 cycles = batch_cycles (
+651 self . gw_mod , self . selection_mod , latent_domains , self . domain_mods . keys ()
+652 ),
+653 translations = batch_translations (
+654 self . gw_mod , self . selection_mod , latent_domains
+655 ),
+656 # TODO: add other combinations
+657 ** super () . forward ( latent_domains ),
+658 )
+659
+660
+661 class GlobalWorkspaceBayesian (
+662 GlobalWorkspaceBase [ GWModuleBayesian , FixedSharedSelection , GWLossesBayesian ]
+663 ):
+664 """
+665 A simple 2-domains max GlobalWorkspaceBase with a Bayesian base uncertainty
+666 prediction.
+667
+668 This is used to simplify a Global Workspace instanciation and only overrides the
+669 `__init__` method.
+670 """
+671
+672 def __init__ (
+673 self ,
+674 domain_mods : Mapping [ str , DomainModule ],
+675 gw_encoders : Mapping [ str , Module ],
+676 gw_decoders : Mapping [ str , Module ],
+677 workspace_dim : int ,
+678 loss_coefs : BroadcastLossCoefs ,
+679 sensitivity_selection : float = 1 ,
+680 sensitivity_precision : float = 1 ,
+681 optim_lr : float = 1e-3 ,
+682 optim_weight_decay : float = 0.0 ,
+683 scheduler_args : SchedulerArgs | None = None ,
+684 learn_logit_scale : bool = False ,
+685 use_normalized_constrastive : bool = True ,
+686 contrastive_loss : ContrastiveLossType | None = None ,
+687 precision_softmax_temp : float = 0.01 ,
+688 ) -> None :
+689 """
+690 Initializes a Global Workspace
+691
+692 Args:
+693 domain_mods (`Mapping[str, DomainModule]`): mapping of the domains
+694 connected to the GW. Keys are domain names, values are the
+695 `DomainModule`.
+696 gw_encoders (`Mapping[str, torch.nn.Module]`): mapping for each domain
+697 name to a `torch.nn.Module` class which role is to encode a
+698 unimodal latent representations into a GW representation (pre fusion).
+699 gw_decoders (`Mapping[str, torch.nn.Module]`): mapping for each domain
+700 name to a `torch.nn.Module` class which role is to decode a
+701 GW representation into a unimodal latent representations.
+702 workspace_dim (`int`): dimension of the GW.
+703 loss_coefs (`LossCoefs`): loss coefficients
+704 sensitivity_selection (`float`): sensivity coef $c'_1$
+705 sensitivity_precision (`float`): sensitivity coef $c'_2$
+706 optim_lr (`float`): learning rate
+707 optim_weight_decay (`float`): weight decay
+708 scheduler_args (`SchedulerArgs | None`): optimization scheduler's arguments
+709 learn_logit_scale (`bool`): whether to learn the contrastive learning
+710 contrastive loss when using the default contrastive loss.
+711 use_normalized_constrastive (`bool`): whether to use the normalized cont
+712 loss by the precision coefs
+713 contrastive_loss (`ContrastiveLossType | None`): a contrastive loss
+714 function used for alignment. `learn_logit_scale` will not affect custom
+715 contrastive losses.
+716 precision_softmax_temp (`float`): temperature to use in softmax of
+717 precision
+718 """
+719 domain_mods = freeze_domain_modules ( domain_mods )
+720
+721 gw_mod = GWModuleBayesian (
+722 domain_mods ,
+723 workspace_dim ,
+724 gw_encoders ,
+725 gw_decoders ,
+726 sensitivity_selection ,
+727 sensitivity_precision ,
+728 precision_softmax_temp ,
+729 )
+730
+731 selection_mod = FixedSharedSelection ()
+732
+733 contrastive_loss = ContrastiveLoss (
+734 torch . tensor ([ 1 ]) . log (), "mean" , learn_logit_scale
+735 )
+736
+737 loss_mod = GWLossesBayesian (
+738 gw_mod ,
+739 selection_mod ,
+740 domain_mods ,
+741 loss_coefs ,
+742 contrastive_loss ,
+743 use_normalized_constrastive ,
+744 )
+745
+746 super () . __init__ (
+747 gw_mod ,
+748 selection_mod ,
+749 loss_mod ,
+750 optim_lr ,
+751 optim_weight_decay ,
+752 scheduler_args ,
+753 )
+754
+755 def forward ( # type: ignore
+756 self ,
+757 latent_domains : LatentsDomainGroupsT ,
+758 ) -> GWPredictions :
+759 """
+760 Computes demi-cycles, cycles, and translations.
+761
+762 Args:
+763 latent_domains (`LatentsT`): Groups of domains for the computation.
+764
+765 Returns:
+766 `GWPredictions`: the predictions on the batch.
+767 """
+768 return GWPredictions (
+769 demi_cycles = batch_demi_cycles (
+770 self . gw_mod , self . selection_mod , latent_domains
+771 ),
+772 cycles = batch_cycles (
+773 self . gw_mod , self . selection_mod , latent_domains , self . domain_mods . keys ()
+774 ),
+775 translations = batch_translations (
+776 self . gw_mod , self . selection_mod , latent_domains
+777 ),
+778 ** super () . forward ( latent_domains ),
+779 )
+780
+781
+782 def pretrained_global_workspace (
+783 checkpoint_path : str | Path ,
+784 domain_mods : Mapping [ str , DomainModule ],
+785 gw_encoders : Mapping [ str , Module ],
+786 gw_decoders : Mapping [ str , Module ],
+787 workspace_dim : int ,
+788 loss_coefs : LossCoefs ,
+789 contrastive_fn : ContrastiveLossType ,
+790 ** kwargs ,
+791 ) -> GlobalWorkspace2Domains :
+792 """
+793 Load a `GlobalWorkspace` flavor of `GlobalWorkspaceBase` from a checkpoint.
+794
+795 Args:
+796 checkpoint_path (`str | Path`): path to checkpoint
+797 domain_mods (`Mapping[str, DomainModule]`): mapping of the domains
+798 connected to the GW. Keys are domain names, values are the
+799 `DomainModule`.
+800 gw_encoders (`Mapping[str, torch.nn.Module]`): mapping for each domain
+801 name to a `torch.nn.Module` class which role is to encode a
+802 unimodal latent representations into a GW representation (pre fusion).
+803 gw_decoders (`Mapping[str, torch.nn.Module]`): mapping for each domain
+804 name to a `torch.nn.Module` class which role is to decode a
+805 GW representation into a unimodal latent representations.
+806 workspace_dim (`int`): dimension of the GW.
+807 loss_coefs (`LossCoefs`): loss coefficients
+808 contrastive_loss (`ContrastiveLossType`): a contrastive loss
+809 function used for alignment. `learn_logit_scale` will not affect custom
+810 contrastive losses.
+811 **kwargs: additional arguments to pass to
+812 `GlobalWorkspace.load_from_checkpoint`.
+813
+814 Returns:
+815 `GlobalWorkspace`: the pretrained `GlobalWorkspace`.
+816
+817 Raises:
+818 `TypeError`: if loaded type is not `GlobalWorkspace`.
+819 """
+820 domain_mods = freeze_domain_modules ( domain_mods )
+821 gw_mod = GWModule ( domain_mods , workspace_dim , gw_encoders , gw_decoders )
+822 selection_mod = SingleDomainSelection ()
+823 loss_mod = GWLosses2Domains (
+824 gw_mod , selection_mod , domain_mods , loss_coefs , contrastive_fn
+825 )
+826
+827 gw = GlobalWorkspace2Domains . load_from_checkpoint (
+828 checkpoint_path ,
+829 gw_mod = gw_mod ,
+830 selection_mid = selection_mod ,
+831 loss_coefs = loss_coefs ,
+832 loss_mod = loss_mod ,
+833 ** kwargs ,
+834 )
+835 if not isinstance ( gw , GlobalWorkspace2Domains ):
+836 raise TypeError ( "model should be of type GlobalWorkspace" )
+837 return gw
+
+
+
+
+
+
+
+
+ class
+ SchedulerArgs (typing.TypedDict ):
+
+ View Source
+
+
+
+ 45 class SchedulerArgs ( TypedDict , total = False ):
+46 """TypedDict of arguments passed to the OneCycle scheduler"""
+47
+48 max_lr : float
+49 """Maximum learning rate"""
+50
+51 total_steps : int
+52 """Total number of steps"""
+
+
+
+ TypedDict of arguments passed to the OneCycle scheduler
+
+
+
+
+
+ max_lr : float
+
+
+
+
+
+
+
+
+
+
+
+ total_steps : int
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+ class
+ GWPredictionsBase (typing.TypedDict ):
+
+ View Source
+
+
+
+ 55 class GWPredictionsBase ( TypedDict ):
+56 """TypedDict of the output given when calling `GlobalWorkspaceBase.predict`"""
+57
+58 states : dict [ str , torch . Tensor ]
+59 """
+60 GW state representation from domain groups with only one domain.
+61 The key represent the domain's name.
+62 """
+
+
+
+ TypedDict of the output given when calling GlobalWorkspaceBase.predict
+
+
+
+
+
+ states : dict[str, torch.Tensor]
+
+
+
+
+
+
GW state representation from domain groups with only one domain.
+The key represents the domain's name.
+
+
+
+
+
+
+
+
+
+ class
+ GlobalWorkspaceBase (typing.Generic[~_T_gw_mod, ~_T_selection_mod, ~_T_loss_mod] , lightning.pytorch.core.module.LightningModule ):
+
+ View Source
+
+
+
+ 70 class GlobalWorkspaceBase (
+ 71 Generic [ _T_gw_mod , _T_selection_mod , _T_loss_mod ], LightningModule
+ 72 ):
+ 73 """
+ 74 Global Workspace Lightning Module.
+ 75
+ 76 This is the base class to build the Global Workspace.
+ 77 """
+ 78
+ 79 def __init__ (
+ 80 self ,
+ 81 gw_mod : _T_gw_mod ,
+ 82 selection_mod : _T_selection_mod ,
+ 83 loss_mod : _T_loss_mod ,
+ 84 optim_lr : float = 1e-3 ,
+ 85 optim_weight_decay : float = 0.0 ,
+ 86 scheduler_args : SchedulerArgs | None = None ,
+ 87 ) -> None :
+ 88 """
+ 89 Initializes a GW
+ 90
+ 91 Args:
+ 92 gw_mod (`GWModuleBase`): the GWModule
+ 93 selection_mod (`SelectionBase`): selection module
+ 94 loss_mod (`GWLossesBase`): module to compute the GW losses.
+ 95 optim_lr (`float`): learning rate
+ 96 optim_weight_decay (`float`): weight decay
+ 97 scheduler_args (`SchedulerArgs`): `SchedulerArgs` instance to define
+ 98 scheduler parameters.
+ 99 """
+100 super () . __init__ ()
+101 self . save_hyperparameters (
+102 ignore = [
+103 "gw_mod" ,
+104 "selection_mod" ,
+105 "domain_mods" ,
+106 "loss_mod" ,
+107 "domain_descriptions" ,
+108 "contrastive_loss" ,
+109 "cont_loss_bayesian" ,
+110 "gw_encoders" ,
+111 "gw_decoders" ,
+112 ]
+113 )
+114
+115 self . gw_mod = gw_mod
+116 """ a `GWModuleBase` implementation."""
+117
+118 self . selection_mod = selection_mod
+119 """A `SelectionBase` implementation."""
+120
+121 self . loss_mod = loss_mod
+122 """The module that computes losses of the GW"""
+123
+124 self . optim_lr = optim_lr
+125 self . optim_weight_decay = optim_weight_decay
+126 self . scheduler_args = SchedulerArgs ( max_lr = optim_lr , total_steps = 1 )
+127 if scheduler_args is not None :
+128 self . scheduler_args . update ( scheduler_args )
+129
+130 @property
+131 def domain_mods ( self ) -> Mapping [ str , DomainModule ]:
+132 return self . gw_mod . domain_mods
+133
+134 @property
+135 def workspace_dim ( self ) -> int :
+136 """Dimension of the GW."""
+137 return self . gw_mod . workspace_dim
+138
+139 def encode_and_fuse (
+140 self , x : LatentsDomainGroupsT , selection_module : SelectionBase
+141 ) -> dict [ frozenset [ str ], torch . Tensor ]:
+142 """
+143 Encode a group of latent representations into the GW representation.
+144
+145 Args:
+146 x (`LatentsDomainGroupsT`): the input domain representations.
+147 selection_scores (`Mapping[str, torch.Tensor]`):
+148
+149 Returns:
+150 `dict[frozenset[str], torch.Tensor]`: the GW representations.
+151 """
+152 return {
+153 domains : self . gw_mod . encode_and_fuse ( latents , selection_module )
+154 for domains , latents in x . items ()
+155 }
+156
+157 def encode ( self , x : LatentsDomainGroupsT ) -> LatentsDomainGroupsDT :
+158 """
+159 Encode a group of latent representations into the pre-fusion GW representation.
+160
+161 Args:
+162 x (`LatentsDomainGroupsT`): the input domain representations.
+163
+164 Returns:
+165 `LatentsDomainGroupsDT`: the GW representations.
+166 """
+167 return { domains : self . gw_mod . encode ( latents ) for domains , latents in x . items ()}
+168
+169 def fuse (
+170 self ,
+171 x : LatentsDomainGroupsT ,
+172 selection_scores : Mapping [ frozenset [ str ], Mapping [ str , torch . Tensor ]],
+173 ) -> dict [ frozenset [ str ], torch . Tensor ]:
+174 """
+175 Fuses a group of latent representations into the GW representation.
+176
+177 Args:
+178 x (`LatentsDomainGroupsT`): the pre-fusion latent representations
+179 selection_scores (`Mapping[frozenset[str], Mapping[str, torch.Tensor]]`):
+180 selection scores for each group
+181
+182 Returns:
+183 `dict[frozenset[str], torch.Tensor]`: GW representation of each group
+184 """
+185 return {
+186 domains : self . gw_mod . fuse ( latents , selection_scores [ domains ])
+187 for domains , latents in x . items ()
+188 }
+189
+190 def decode (
+191 self ,
+192 z : Mapping [ frozenset [ str ], torch . Tensor ],
+193 domains : Iterable [ str ] | None = None ,
+194 ) -> LatentsDomainGroupsDT :
+195 """
+196 Decode the group GW representation into given `domains`.
+197
+198 Args:
+199 z (`torch.Tensor`): the GW representation.
+200 domains (`Iterable[str]`): iterable of domains to decode.
+201
+202 Returns:
+203 `dict[str, torch.Tensor]`: the decoded unimodal representations.
+204 """
+205 return {
+206 domain_names : self . gw_mod . decode ( gw_rep , domains )
+207 for domain_names , gw_rep in z . items ()
+208 }
+209
+210 def forward ( # type: ignore
+211 self ,
+212 latent_domains : LatentsDomainGroupsT ,
+213 ) -> GWPredictionsBase :
+214 """
+215 Computes demi-cycles, cycles, and translations.
+216
+217 Args:
+218 latent_domains (`LatentsT`): Groups of domains for the computation.
+219
+220 Returns:
+221 `GWPredictionsBase`: the predictions on the batch.
+222 """
+223
+224 return GWPredictionsBase ( states = self . batch_gw_states ( latent_domains ))
+225
+226 def batch_gw_states (
+227 self , latent_domains : LatentsDomainGroupsT
+228 ) -> dict [ str , torch . Tensor ]:
+229 """
+230 Computes GW states of a batch of groups of domains.
+231
+232 Args:
+233 latent_domains (`LatentsT`): the batch of groups of domains
+234
+235 Returns:
+236 `dict[str, torch.Tensor]`: states for each domain.
+237 """
+238 predictions : dict [ str , torch . Tensor ] = {}
+239 for domains , latents in latent_domains . items ():
+240 if len ( domains ) > 1 :
+241 continue
+242 domain_name = list ( domains )[ 0 ]
+243 z = self . gw_mod . encode_and_fuse (
+244 latents , selection_module = self . selection_mod
+245 )
+246 predictions [ domain_name ] = z
+247 return predictions
+248
+249 def encode_domain ( self , domain : Any , name : str ) -> torch . Tensor :
+250 """
+251 Encodes a domain from the domain data into the unimodal representation.
+252
+253 This is a convenient proxy for the `DomainModule.encode` method and is
+254 equivalent to:
+255 ```python
+256 self.domain_mods[name].encode(domain)
+257 ```
+258
+259 Args:
+260 domain (`Any`): the domain data
+261 name (`str`): domain name to encode
+262
+263 Returns:
+264 `torch.Tensor`: the domain's unimodal representation.
+265 """
+266 return self . domain_mods [ name ] . encode ( domain )
+267
+268 def encode_domains ( self , batch : RawDomainGroupsT ) -> LatentsDomainGroupsDT :
+269 """
+270 Encode all domains in the batch.
+271
+272 Args:
+273 batch (`RawDomainGroupsT`): the batch of
+274 domain groups with raw unimodal data to encode into groups of latent
+275 representations.
+276
+277 Returns:
+278 `LatentsDomainGroupsDT`: the domains' unimodal representations.
+279 """
+280 return {
+281 domains : {
+282 name : self . domain_mods [ name ] . encode ( domain )
+283 for name , domain in data . items ()
+284 }
+285 for domains , data in batch . items ()
+286 }
+287
+288 def decode_domain ( self , domain : torch . Tensor , name : str ) -> Any :
+289 """
+290 Decodes a domain from the unimodal representation into the domain data.
+291
+292 This is a convenient proxy for the `DomainModule.encode` method and is
+293 equivalent to:
+294 ```python
+295 self.domain_mods[name].decode(domain)
+296 ```
+297
+298 Args:
+299 domain (`torch.Tensor`): the domain data
+300 name (`str`): domain name to encode
+301
+302 Returns:
+303 `Any`: the domain's raw data.
+304 """
+305 return self . domain_mods [ name ] . decode ( domain )
+306
+307 def decode_domains ( self , latents_domain : LatentsDomainGroupsT ) -> RawDomainGroupsDT :
+308 """
+309 Decodes all domains in the batch.
+310
+311 Args:
+312 batch (`LatentsDomainGroupsT`): the batch of
+313 domain groups with unimodal latent representation to decode into
+314 groups of raw data.
+315
+316 Returns:
+317 `LatentsDomainGroupsDT`: the domains' raw data.
+318 """
+319 return {
+320 domains : {
+321 name : self . domain_mods [ name ] . decode ( domain )
+322 for name , domain in latents . items ()
+323 }
+324 for domains , latents in latents_domain . items ()
+325 }
+326
+327 def generic_step ( self , batch : RawDomainGroupsT , mode : ModelModeT ) -> torch . Tensor :
+328 """
+329 The generic step used in `training_step`, `validation_step` and
+330 `test_step`.
+331
+332 Args:
+333 batch (`RawDomainGroupsT`): the batch of groups of raw unimodal data.
+334 mode (`ModelModeT`):
+335
+336 Returns:
+337 `torch.Tensor`: the loss to train on.
+338 """
+339 domain_latents = self . encode_domains ( batch )
+340 batch_size = groups_batch_size ( domain_latents )
+341
+342 loss_output = self . loss_mod . step ( domain_latents , mode )
+343
+344 for name , metric in loss_output . all . items ():
+345 self . log (
+346 f " { mode } / { name } " ,
+347 metric ,
+348 batch_size = batch_size ,
+349 add_dataloader_idx = False ,
+350 )
+351
+352 return loss_output . loss
+353
+354 def validation_step ( # type: ignore
+355 self , data : RawDomainGroupT , batch_idx : int , dataloader_idx : int = 0
+356 ) -> torch . Tensor :
+357 """Validation step used by lightning"""
+358
+359 batch = { frozenset ( data . keys ()): data }
+360 for domain in data :
+361 batch [ frozenset ([ domain ])] = { domain : data [ domain ]}
+362 if dataloader_idx == 0 :
+363 return self . generic_step ( batch , mode = "val" )
+364 return self . generic_step ( batch , mode = "val/ood" )
+365
+366 def test_step ( # type: ignore
+367 self , data : Mapping [ str , Any ], batch_idx : int , dataloader_idx : int = 0
+368 ) -> torch . Tensor :
+369 """Test step used by lightning"""
+370
+371 batch = { frozenset ( data . keys ()): data }
+372 for domain in data :
+373 batch [ frozenset ([ domain ])] = { domain : data [ domain ]}
+374 if dataloader_idx == 0 :
+375 return self . generic_step ( batch , mode = "test" )
+376 return self . generic_step ( batch , mode = "test/ood" )
+377
+378 def training_step ( # type: ignore
+379 self , batch : Mapping [ frozenset [ str ], Mapping [ str , Any ]], batch_idx : int
+380 ) -> torch . Tensor :
+381 """Training step used by lightning"""
+382
+383 return self . generic_step ( batch , mode = "train" )
+384
+385 def predict_step ( # type: ignore
+386 self , data : Mapping [ str , Any ], batch_idx : int
+387 ) -> GWPredictionsBase :
+388 """Predict step used by lightning"""
+389
+390 batch = { frozenset ( data . keys ()): data }
+391 for domain in data :
+392 batch [ frozenset ([ domain ])] = { domain : data [ domain ]}
+393
+394 domain_latents = self . encode_domains ( batch )
+395 return self . forward ( domain_latents )
+396
+397 def configure_optimizers ( self ) -> OptimizerLRSchedulerConfig :
+398 """
+399 Configure models optimizers.
+400
+401 Here we use `AdamW` for the optimizer and `OneCycleLR` for the learning-rate
+402 scheduler.
+403 """
+404
+405 optimizer = torch . optim . AdamW (
+406 self . parameters (),
+407 lr = self . optim_lr ,
+408 weight_decay = self . optim_weight_decay ,
+409 )
+410
+411 lr_scheduler = OneCycleLR ( optimizer , ** self . scheduler_args )
+412
+413 return {
+414 "optimizer" : optimizer ,
+415 "lr_scheduler" : {
+416 "scheduler" : lr_scheduler ,
+417 "interval" : "step" ,
+418 },
+419 }
+
+
+
+ Global Workspace Lightning Module.
+
+
This is the base class to build the Global Workspace.
+
+
+
+
+
+ gw_mod
+
+
+
+
+
+
a GWModuleBase
implementation.
+
+
+
+
+
+
+ selection_mod
+
+
+
+
+
+
A SelectionBase
implementation.
+
+
+
+
+
+
+ loss_mod
+
+
+
+
+
+
The module that computes losses of the GW
+
+
+
+
+
+
+ optim_lr
+
+
+
+
+
+
+
+
+
+
+ optim_weight_decay
+
+
+
+
+
+
+
+
+
+
+ scheduler_args
+
+
+
+
+
+
+
+
+
+
+
+
+
130 @property
+131 def domain_mods ( self ) -> Mapping [ str , DomainModule ]:
+132 return self . gw_mod . domain_mods
+
+
+
+
+
+
+
+
+
+ workspace_dim : int
+
+ View Source
+
+
+
+
134 @property
+135 def workspace_dim ( self ) -> int :
+136 """Dimension of the GW."""
+137 return self . gw_mod . workspace_dim
+
+
+
+
+
+
+
+
+
+
+
+
def
+
encode_and_fuse ( self , x : collections . abc . Mapping [ frozenset [ str ], collections . abc . Mapping [ str , torch . Tensor ]] , selection_module : shimmer.modules.selection.SelectionBase ) -> dict [ frozenset [ str ], torch . Tensor ] :
+
+
View Source
+
+
+
+
139 def encode_and_fuse (
+140 self , x : LatentsDomainGroupsT , selection_module : SelectionBase
+141 ) -> dict [ frozenset [ str ], torch . Tensor ]:
+142 """
+143 Encode a group of latent representations into the GW representation.
+144
+145 Args:
+146 x (`LatentsDomainGroupsT`): the input domain representations.
+147 selection_scores (`Mapping[str, torch.Tensor]`):
+148
+149 Returns:
+150 `dict[frozenset[str], torch.Tensor]`: the GW representations.
+151 """
+152 return {
+153 domains : self . gw_mod . encode_and_fuse ( latents , selection_module )
+154 for domains , latents in x . items ()
+155 }
+
+
+
+
Encode a group of latent representations into the GW representation.
+
+
Arguments:
+
+
+x (LatentsDomainGroupsT
): the input domain representations.
+selection_scores (Mapping[str, torch.Tensor]
):
+
+
+
Returns:
+
+
+ dict[frozenset[str], torch.Tensor]
: the GW representations.
+
+
+
+
+
+
+
+
+
+ def
+ encode ( self , x : collections . abc . Mapping [ frozenset [ str ], collections . abc . Mapping [ str , torch . Tensor ]] ) -> dict [ frozenset [ str ], dict [ str , torch . Tensor ]] :
+
+ View Source
+
+
+
+
157 def encode ( self , x : LatentsDomainGroupsT ) -> LatentsDomainGroupsDT :
+158 """
+159 Encode a group of latent representations into the pre-fusion GW representation.
+160
+161 Args:
+162 x (`LatentsDomainGroupsT`): the input domain representations.
+163
+164 Returns:
+165 `LatentsDomainGroupsDT`: the GW representations.
+166 """
+167 return { domains : self . gw_mod . encode ( latents ) for domains , latents in x . items ()}
+
+
+
+
Encode a group of latent representations into the pre-fusion GW representation.
+
+
Arguments:
+
+
+x (LatentsDomainGroupsT
): the input domain representations.
+
+
+
Returns:
+
+
+ LatentsDomainGroupsDT
: the GW representations.
+
+
+
+
+
+
+
+
+
+ def
+ fuse ( self , x : collections . abc . Mapping [ frozenset [ str ], collections . abc . Mapping [ str , torch . Tensor ]] , selection_scores : collections . abc . Mapping [ frozenset [ str ], collections . abc . Mapping [ str , torch . Tensor ]] ) -> dict [ frozenset [ str ], torch . Tensor ] :
+
+ View Source
+
+
+
+
169 def fuse (
+170 self ,
+171 x : LatentsDomainGroupsT ,
+172 selection_scores : Mapping [ frozenset [ str ], Mapping [ str , torch . Tensor ]],
+173 ) -> dict [ frozenset [ str ], torch . Tensor ]:
+174 """
+175 Fuses a group of latent representations into the GW representation.
+176
+177 Args:
+178 x (`LatentsDomainGroupsT`): the pre-fusion latent representations
+179 selection_scores (`Mapping[frozenset[str], Mapping[str, torch.Tensor]]`):
+180 selection scores for each group
+181
+182 Returns:
+183 `dict[frozenset[str], torch.Tensor]`: GW representation of each group
+184 """
+185 return {
+186 domains : self . gw_mod . fuse ( latents , selection_scores [ domains ])
+187 for domains , latents in x . items ()
+188 }
+
+
+
+
Fuses a group of latent representations into the GW representation.
+
+
Arguments:
+
+
+x (LatentsDomainGroupsT
): the pre-fusion latent representations
+selection_scores (Mapping[frozenset[str], Mapping[str, torch.Tensor]]
): selection scores for each group
+
+
+
Returns:
+
+
+ dict[frozenset[str], torch.Tensor]
: GW representation of each group
+
+
+
+
+
+
+
+
+
+ def
+ decode ( self , z : collections . abc . Mapping [ frozenset [ str ], torch . Tensor ] , domains : collections . abc . Iterable [ str ] | None = None ) -> dict [ frozenset [ str ], dict [ str , torch . Tensor ]] :
+
+ View Source
+
+
+
+
190 def decode (
+191 self ,
+192 z : Mapping [ frozenset [ str ], torch . Tensor ],
+193 domains : Iterable [ str ] | None = None ,
+194 ) -> LatentsDomainGroupsDT :
+195 """
+196 Decode the group GW representation into given `domains`.
+197
+198 Args:
+199 z (`torch.Tensor`): the GW representation.
+200 domains (`Iterable[str]`): iterable of domains to decode.
+201
+202 Returns:
+203 `dict[str, torch.Tensor]`: the decoded unimodal representations.
+204 """
+205 return {
+206 domain_names : self . gw_mod . decode ( gw_rep , domains )
+207 for domain_names , gw_rep in z . items ()
+208 }
+
+
+
+
Decode the group GW representation into given domains
.
+
+
Arguments:
+
+
+z (torch.Tensor
): the GW representation.
+domains (Iterable[str]
): iterable of domains to decode.
+
+
+
Returns:
+
+
+ dict[str, torch.Tensor]
: the decoded unimodal representations.
+
+
+
+
+
+
+
+
+
+ def
+ batch_gw_states ( self , latent_domains : collections . abc . Mapping [ frozenset [ str ], collections . abc . Mapping [ str , torch . Tensor ]] ) -> dict [ str , torch . Tensor ] :
+
+ View Source
+
+
+
+
226 def batch_gw_states (
+227 self , latent_domains : LatentsDomainGroupsT
+228 ) -> dict [ str , torch . Tensor ]:
+229 """
+230 Computes GW states of a batch of groups of domains.
+231
+232 Args:
+233 latent_domains (`LatentsT`): the batch of groups of domains
+234
+235 Returns:
+236 `dict[str, torch.Tensor]`: states for each domain.
+237 """
+238 predictions : dict [ str , torch . Tensor ] = {}
+239 for domains , latents in latent_domains . items ():
+240 if len ( domains ) > 1 :
+241 continue
+242 domain_name = list ( domains )[ 0 ]
+243 z = self . gw_mod . encode_and_fuse (
+244 latents , selection_module = self . selection_mod
+245 )
+246 predictions [ domain_name ] = z
+247 return predictions
+
+
+
+
Computes GW states of a batch of groups of domains.
+
+
Arguments:
+
+
+latent_domains (LatentsT
): the batch of groups of domains
+
+
+
Returns:
+
+
+ dict[str, torch.Tensor]
: states for each domain.
+
+
+
+
+
+
+
+
+
+ def
+ encode_domain (self , domain : Any , name : str ) -> torch . Tensor :
+
+ View Source
+
+
+
+
249 def encode_domain ( self , domain : Any , name : str ) -> torch . Tensor :
+250 """
+251 Encodes a domain from the domain data into the unimodal representation.
+252
+253 This is a convenient proxy for the `DomainModule.encode` method and is
+254 equivalent to:
+255 ```python
+256 self.domain_mods[name].encode(domain)
+257 ```
+258
+259 Args:
+260 domain (`Any`): the domain data
+261 name (`str`): domain name to encode
+262
+263 Returns:
+264 `torch.Tensor`: the domain's unimodal representation.
+265 """
+266 return self . domain_mods [ name ] . encode ( domain )
+
+
+
+
Encodes a domain from the domain data into the unimodal representation.
+
+
This is a convenient proxy for the DomainModule.encode
method and is
+equivalent to:
+
+
+
self . domain_mods [ name ] . encode ( domain )
+
+
+
+
Arguments:
+
+
+domain (Any
): the domain data
+name (str
): domain name to encode
+
+
+
Returns:
+
+
+ torch.Tensor
: the domain's unimodal representation.
+
+
+
+
+
+
+
+
+
+ def
+ encode_domains ( self , batch : collections . abc . Mapping [ frozenset [ str ], collections . abc . Mapping [ str , typing . Any ]] ) -> dict [ frozenset [ str ], dict [ str , torch . Tensor ]] :
+
+ View Source
+
+
+
+
268 def encode_domains ( self , batch : RawDomainGroupsT ) -> LatentsDomainGroupsDT :
+269 """
+270 Encode all domains in the batch.
+271
+272 Args:
+273 batch (`RawDomainGroupsT`): the batch of
+274 domain groups with raw unimodal data to encode into groups of latent
+275 representations.
+276
+277 Returns:
+278 `LatentsDomainGroupsDT`: the domains' unimodal representations.
+279 """
+280 return {
+281 domains : {
+282 name : self . domain_mods [ name ] . encode ( domain )
+283 for name , domain in data . items ()
+284 }
+285 for domains , data in batch . items ()
+286 }
+
+
+
+
Encode all domains in the batch.
+
+
Arguments:
+
+
+batch (RawDomainGroupsT
): the batch of
+domain groups with raw unimodal data to encode into groups of latent
+representations.
+
+
+
Returns:
+
+
+ LatentsDomainGroupsDT
: the domains' unimodal representations.
+
+
+
+
+
+
+
+
+
+ def
+ decode_domain (self , domain : torch . Tensor , name : str ) -> Any :
+
+ View Source
+
+
+
+
288 def decode_domain ( self , domain : torch . Tensor , name : str ) -> Any :
+289 """
+290 Decodes a domain from the unimodal representation into the domain data.
+291
+292 This is a convenient proxy for the `DomainModule.encode` method and is
+293 equivalent to:
+294 ```python
+295 self.domain_mods[name].decode(domain)
+296 ```
+297
+298 Args:
+299 domain (`torch.Tensor`): the domain data
+300 name (`str`): domain name to encode
+301
+302 Returns:
+303 `Any`: the domain's raw data.
+304 """
+305 return self . domain_mods [ name ] . decode ( domain )
+
+
+
+
Decodes a domain from the unimodal representation into the domain data.
+
+
This is a convenient proxy for the DomainModule.encode
method and is
+equivalent to:
+
+
+
self . domain_mods [ name ] . decode ( domain )
+
+
+
+
Arguments:
+
+
+domain (torch.Tensor
): the domain data
+name (str
): domain name to encode
+
+
+
Returns:
+
+
+ Any
: the domain's raw data.
+
+
+
+
+
+
+
+
+
+ def
+ decode_domains ( self , latents_domain : collections . abc . Mapping [ frozenset [ str ], collections . abc . Mapping [ str , torch . Tensor ]] ) -> dict [ frozenset [ str ], dict [ str , typing . Any ]] :
+
+ View Source
+
+
+
+
307 def decode_domains ( self , latents_domain : LatentsDomainGroupsT ) -> RawDomainGroupsDT :
+308 """
+309 Decodes all domains in the batch.
+310
+311 Args:
+312 batch (`LatentsDomainGroupsT`): the batch of
+313 domain groups with unimodal latent representation to decode into
+314 groups of raw data.
+315
+316 Returns:
+317 `LatentsDomainGroupsDT`: the domains' raw data.
+318 """
+319 return {
+320 domains : {
+321 name : self . domain_mods [ name ] . decode ( domain )
+322 for name , domain in latents . items ()
+323 }
+324 for domains , latents in latents_domain . items ()
+325 }
+
+
+
+
Decodes all domains in the batch.
+
+
Arguments:
+
+
+batch (LatentsDomainGroupsT
): the batch of
+domain groups with unimodal latent representation to decode into
+groups of raw data.
+
+
+
Returns:
+
+
+ LatentsDomainGroupsDT
: the domains' raw data.
+
+
+
+
+
+
+
+
+
+ def
+ generic_step ( self , batch : collections . abc . Mapping [ frozenset [ str ], collections . abc . Mapping [ str , typing . Any ]] , mode : Literal [ 'train' , 'val' , 'test' , 'val/ood' , 'test/ood' ] ) -> torch . Tensor :
+
+ View Source
+
+
+
+
327 def generic_step ( self , batch : RawDomainGroupsT , mode : ModelModeT ) -> torch . Tensor :
+328 """
+329 The generic step used in `training_step`, `validation_step` and
+330 `test_step`.
+331
+332 Args:
+333 batch (`RawDomainGroupsT`): the batch of groups of raw unimodal data.
+334 mode (`ModelModeT`):
+335
+336 Returns:
+337 `torch.Tensor`: the loss to train on.
+338 """
+339 domain_latents = self . encode_domains ( batch )
+340 batch_size = groups_batch_size ( domain_latents )
+341
+342 loss_output = self . loss_mod . step ( domain_latents , mode )
+343
+344 for name , metric in loss_output . all . items ():
+345 self . log (
+346 f " { mode } / { name } " ,
+347 metric ,
+348 batch_size = batch_size ,
+349 add_dataloader_idx = False ,
+350 )
+351
+352 return loss_output . loss
+
+
+
+
The generic step used in training_step
, validation_step
and
+test_step
.
+
+
Arguments:
+
+
+batch (RawDomainGroupsT
): the batch of groups of raw unimodal data.
+mode (ModelModeT
):
+
+
+
Returns:
+
+
+ torch.Tensor
: the loss to train on.
+
+
+
+
+
+
+
Inherited Members
+
+
lightning.pytorch.core.module.LightningModule
+ LightningModule
+ CHECKPOINT_HYPER_PARAMS_KEY
+ CHECKPOINT_HYPER_PARAMS_NAME
+ CHECKPOINT_HYPER_PARAMS_TYPE
+ optimizers
+ lr_schedulers
+ trainer
+ fabric
+ example_input_array
+ current_epoch
+ global_step
+ global_rank
+ local_rank
+ on_gpu
+ automatic_optimization
+ strict_loading
+ logger
+ loggers
+ print
+ log
+ log_dict
+ all_gather
+ forward
+ training_step
+ validation_step
+ test_step
+ predict_step
+ configure_callbacks
+ configure_optimizers
+ manual_backward
+ backward
+ toggle_optimizer
+ untoggle_optimizer
+ clip_gradients
+ configure_gradient_clipping
+ lr_scheduler_step
+ optimizer_step
+ optimizer_zero_grad
+ freeze
+ unfreeze
+ to_onnx
+ to_torchscript
+ load_from_checkpoint
+
+
+
lightning.fabric.utilities.device_dtype_mixin._DeviceDtypeModuleMixin
+ dtype
+ device
+ to
+ cuda
+ cpu
+ type
+ float
+ double
+ half
+
+
+
lightning.pytorch.core.mixins.hparams_mixin.HyperparametersMixin
+ save_hyperparameters
+ hparams
+ hparams_initial
+
+
+
lightning.pytorch.core.hooks.ModelHooks
+ on_fit_start
+ on_fit_end
+ on_train_start
+ on_train_end
+ on_validation_start
+ on_validation_end
+ on_test_start
+ on_test_end
+ on_predict_start
+ on_predict_end
+ on_train_batch_start
+ on_train_batch_end
+ on_validation_batch_start
+ on_validation_batch_end
+ on_test_batch_start
+ on_test_batch_end
+ on_predict_batch_start
+ on_predict_batch_end
+ on_validation_model_zero_grad
+ on_validation_model_eval
+ on_validation_model_train
+ on_test_model_eval
+ on_test_model_train
+ on_predict_model_eval
+ on_train_epoch_start
+ on_train_epoch_end
+ on_validation_epoch_start
+ on_validation_epoch_end
+ on_test_epoch_start
+ on_test_epoch_end
+ on_predict_epoch_start
+ on_predict_epoch_end
+ on_before_zero_grad
+ on_before_backward
+ on_after_backward
+ on_before_optimizer_step
+ configure_sharded_model
+ configure_model
+
+
+
lightning.pytorch.core.hooks.DataHooks
+ prepare_data_per_node
+ allow_zero_length_dataloader_with_multiple_devices
+ prepare_data
+ setup
+ teardown
+ train_dataloader
+ test_dataloader
+ val_dataloader
+ predict_dataloader
+ transfer_batch_to_device
+ on_before_batch_transfer
+ on_after_batch_transfer
+
+
+
lightning.pytorch.core.hooks.CheckpointHooks
+ on_load_checkpoint
+ on_save_checkpoint
+
+
+
torch.nn.modules.module.Module
+ dump_patches
+ training
+ call_super_init
+ register_buffer
+ register_parameter
+ add_module
+ register_module
+ get_submodule
+ get_parameter
+ get_buffer
+
+
+ apply
+ ipu
+ xpu
+ bfloat16
+ to_empty
+ register_full_backward_pre_hook
+ register_backward_hook
+ register_full_backward_hook
+ register_forward_pre_hook
+ register_forward_hook
+ register_state_dict_pre_hook
+ state_dict
+ register_load_state_dict_post_hook
+ load_state_dict
+ parameters
+ named_parameters
+ buffers
+ named_buffers
+ children
+ named_children
+ modules
+ named_modules
+ train
+ eval
+ requires_grad_
+ zero_grad
+ share_memory
+
+ compile
+
+
+
+
+
+
+
+
+
+ 422 def freeze_domain_modules (
+423 domain_mods : Mapping [ str , DomainModule ],
+424 ) -> dict [ str , DomainModule ]:
+425 """
+426 Freezes weights and sets to eval mode the domain modules.
+427
+428 .. note::
+429 The output is cast as `dict[str, DomainModule]` type for better
+430 auto-completion, but is actually a torch `ModuleDict`.
+431
+432 Args:
+433 domain_mods (`Mapping[str, DomainModule]`): mapping of domain modules to freeze
+434
+435 Returns:
+436 `ModuleDict`: frozen modules.
+437 """
+438
+439 for mod in domain_mods . values ():
+440 mod . freeze ()
+441 # Cast for better auto-completion at the expense of ModuleDict
+442 return cast ( dict [ str , DomainModule ], ModuleDict ( domain_mods ))
+
+
+
Freezes weights and sets to eval mode the domain modules.
+
+
+
+
The output is cast as dict[str, DomainModule]
type for better
+auto-completion, but is actually a torch ModuleDict
.
+
+
+
+
Arguments:
+
+
+domain_mods (Mapping[str, DomainModule]
): mapping of domain modules to freeze
+
+
+
Returns:
+
+
+ ModuleDict
: frozen modules.
+
+
+
+
+
+
+
+
+
+ class
+ GWPredictions (builtins.dict ):
+
+ View Source
+
+
+
+ 445 class GWPredictions ( GWPredictionsBase ):
+446 """TypedDict of the output given when calling `GlobalWorkspaceBase.predict`"""
+447
+448 demi_cycles : dict [ str , torch . Tensor ]
+449 """
+450 Demi-cycle predictions of the model for each domain. Only computed on domain
+451 groups with only one domain.
+452 """
+453
+454 cycles : dict [ tuple [ str , str ], torch . Tensor ]
+455 """
+456 Cycle predictions of the model from one domain through another one.
+457 Only computed on domain groups with more than one domain.
+458 The keys are tuple with start domain and intermediary domain.
+459 """
+460
+461 translations : dict [ tuple [ str , str ], torch . Tensor ]
+462 """
+463 Translation predictions of the model from one domain through another one.
+464
+465 Only computed on domain groups with more than one domain.
+466 The keys are tuples with start domain and target domain.
+467 """
+
+
+
+ TypedDict of the output given when calling GlobalWorkspaceBase.predict
+
+
+
+
+
+ demi_cycles : dict[str, torch.Tensor]
+
+
+
+
+
+
Demi-cycle predictions of the model for each domain. Only computed on domain
+groups with only one domain.
+
+
+
+
+
+
+ cycles : dict[tuple[str, str], torch.Tensor]
+
+
+
+
+
+
Cycle predictions of the model from one domain through another one.
+Only computed on domain groups with more than one domain.
+The keys are tuple with start domain and intermediary domain.
+
+
+
+
+
+
+ translations : dict[tuple[str, str], torch.Tensor]
+
+
+
+
+
+
Translation predictions of the model from one domain through another one.
+
+
Only computed on domain groups with more than one domain.
+The keys are tuples with start domain and target domain.
+
+
+
+
+
+
+ states : dict[str, torch.Tensor]
+
+
+
+
+
+
+
+
+
+
Inherited Members
+
+
builtins.dict
+ get
+ setdefault
+ pop
+ popitem
+ keys
+ items
+ values
+ update
+ fromkeys
+ clear
+ copy
+
+
+
+
+
+
+
+
+
+ 470 class GlobalWorkspace2Domains (
+471 GlobalWorkspaceBase [ GWModule , SingleDomainSelection , GWLosses2Domains ]
+472 ):
+473 """
+474 A simple 2-domains max flavor of GlobalWorkspaceBase.
+475
+476 This is used to simplify a Global Workspace instantiation and only overrides the
+477 `__init__` method.
+478 """
+479
+480 def __init__ (
+481 self ,
+482 domain_mods : Mapping [ str , DomainModule ],
+483 gw_encoders : Mapping [ str , Module ],
+484 gw_decoders : Mapping [ str , Module ],
+485 workspace_dim : int ,
+486 loss_coefs : LossCoefs ,
+487 optim_lr : float = 1e-3 ,
+488 optim_weight_decay : float = 0.0 ,
+489 scheduler_args : SchedulerArgs | None = None ,
+490 learn_logit_scale : bool = False ,
+491 contrastive_loss : ContrastiveLossType | None = None ,
+492 ) -> None :
+493 """
+494 Initializes a Global Workspace
+495
+496 Args:
+497 domain_mods (`Mapping[str, DomainModule]`): mapping of the domains
+498 connected to the GW. Keys are domain names, values are the
+499 `DomainModule`.
+500 gw_encoders (`Mapping[str, torch.nn.Module]`): mapping for each domain
+501 name to a `torch.nn.Module` class which role is to encode a
+502 unimodal latent representations into a GW representation (pre fusion).
+503 gw_decoders (`Mapping[str, torch.nn.Module]`): mapping for each domain
+504 name to a `torch.nn.Module` class which role is to decode a
+505 GW representation into a unimodal latent representations.
+506 workspace_dim (`int`): dimension of the GW.
+507 loss_coefs (`LossCoefs`): loss coefficients
+508 optim_lr (`float`): learning rate
+509 optim_weight_decay (`float`): weight decay
+510 scheduler_args (`SchedulerArgs | None`): optimization scheduler's arguments
+511 learn_logit_scale (`bool`): whether to learn the logit scale of the
+512 contrastive loss when using the default contrastive loss.
+513 contrastive_loss (`ContrastiveLossType | None`): a contrastive loss
+514 function used for alignment. `learn_logit_scale` will not affect custom
+515 contrastive losses.
+516 """
+517 domain_mods = freeze_domain_modules ( domain_mods )
+518
+519 gw_mod = GWModule ( domain_mods , workspace_dim , gw_encoders , gw_decoders )
+520 if contrastive_loss is None :
+521 contrastive_loss = ContrastiveLoss (
+522 torch . tensor ([ 1 / 0.07 ]) . log (), "mean" , learn_logit_scale
+523 )
+524 selection_mod = SingleDomainSelection ()
+525 loss_mod = GWLosses2Domains (
+526 gw_mod , selection_mod , domain_mods , loss_coefs , contrastive_loss
+527 )
+528
+529 super () . __init__ (
+530 gw_mod ,
+531 selection_mod ,
+532 loss_mod ,
+533 optim_lr ,
+534 optim_weight_decay ,
+535 scheduler_args ,
+536 )
+537
+538 def forward ( # type: ignore
+539 self ,
+540 latent_domains : LatentsDomainGroupsT ,
+541 ) -> GWPredictions :
+542 """
+543 Computes demi-cycles, cycles, and translations.
+544
+545 Args:
+546 latent_domains (`LatentsT`): Groups of domains for the computation.
+547
+548 Returns:
+549 `GWPredictions`: the predictions on the batch.
+550 """
+551 return GWPredictions (
+552 demi_cycles = batch_demi_cycles (
+553 self . gw_mod , self . selection_mod , latent_domains
+554 ),
+555 cycles = batch_cycles (
+556 self . gw_mod , self . selection_mod , latent_domains , self . domain_mods . keys ()
+557 ),
+558 translations = batch_translations (
+559 self . gw_mod , self . selection_mod , latent_domains
+560 ),
+561 ** super () . forward ( latent_domains ),
+562 )
+
+
+
+ A simple 2-domains max flavor of GlobalWorkspaceBase.
+
+
This is used to simplify a Global Workspace instantiation and only overrides the
+__init__
method.
+
+
+
+
+
+
+
+
GlobalWorkspace2Domains ( domain_mods : collections . abc . Mapping [ str , shimmer.modules.domain.DomainModule ] , gw_encoders : collections . abc . Mapping [ str , torch . nn . modules . module . Module ] , gw_decoders : collections . abc . Mapping [ str , torch . nn . modules . module . Module ] , workspace_dim : int , loss_coefs : shimmer.modules.losses.LossCoefs , optim_lr : float = 0.001 , optim_weight_decay : float = 0.0 , scheduler_args : SchedulerArgs | None = None , learn_logit_scale : bool = False , contrastive_loss : collections . abc . Callable [[ torch . Tensor , torch . Tensor ], shimmer.modules.domain.LossOutput ] | None = None )
+
+
View Source
+
+
+
+
480 def __init__ (
+481 self ,
+482 domain_mods : Mapping [ str , DomainModule ],
+483 gw_encoders : Mapping [ str , Module ],
+484 gw_decoders : Mapping [ str , Module ],
+485 workspace_dim : int ,
+486 loss_coefs : LossCoefs ,
+487 optim_lr : float = 1e-3 ,
+488 optim_weight_decay : float = 0.0 ,
+489 scheduler_args : SchedulerArgs | None = None ,
+490 learn_logit_scale : bool = False ,
+491 contrastive_loss : ContrastiveLossType | None = None ,
+492 ) -> None :
+493 """
+494 Initializes a Global Workspace
+495
+496 Args:
+497 domain_mods (`Mapping[str, DomainModule]`): mapping of the domains
+498 connected to the GW. Keys are domain names, values are the
+499 `DomainModule`.
+500 gw_encoders (`Mapping[str, torch.nn.Module]`): mapping for each domain
+501 name to a `torch.nn.Module` class which role is to encode a
+502 unimodal latent representations into a GW representation (pre fusion).
+503 gw_decoders (`Mapping[str, torch.nn.Module]`): mapping for each domain
+504 name to a `torch.nn.Module` class which role is to decode a
+505 GW representation into a unimodal latent representations.
+506 workspace_dim (`int`): dimension of the GW.
+507 loss_coefs (`LossCoefs`): loss coefficients
+508 optim_lr (`float`): learning rate
+509 optim_weight_decay (`float`): weight decay
+510 scheduler_args (`SchedulerArgs | None`): optimization scheduler's arguments
+511 learn_logit_scale (`bool`): whether to learn the contrastive learning
+512 contrastive loss when using the default contrastive loss.
+513 contrastive_loss (`ContrastiveLossType | None`): a contrastive loss
+514 function used for alignment. `learn_logit_scale` will not affect custom
+515 contrastive losses.
+516 """
+517 domain_mods = freeze_domain_modules ( domain_mods )
+518
+519 gw_mod = GWModule ( domain_mods , workspace_dim , gw_encoders , gw_decoders )
+520 if contrastive_loss is None :
+521 contrastive_loss = ContrastiveLoss (
+522 torch . tensor ([ 1 / 0.07 ]) . log (), "mean" , learn_logit_scale
+523 )
+524 selection_mod = SingleDomainSelection ()
+525 loss_mod = GWLosses2Domains (
+526 gw_mod , selection_mod , domain_mods , loss_coefs , contrastive_loss
+527 )
+528
+529 super () . __init__ (
+530 gw_mod ,
+531 selection_mod ,
+532 loss_mod ,
+533 optim_lr ,
+534 optim_weight_decay ,
+535 scheduler_args ,
+536 )
+
+
+
+
Initializes a Global Workspace
+
+
Arguments:
+
+
+domain_mods (Mapping[str, DomainModule]
): mapping of the domains
+connected to the GW. Keys are domain names, values are the
+DomainModule
.
+gw_encoders (Mapping[str, torch.nn.Module]
): mapping for each domain
+name to a torch.nn.Module
class whose role is to encode a
+unimodal latent representation into a GW representation (pre fusion).
+gw_decoders (Mapping[str, torch.nn.Module]
): mapping for each domain
+name to a torch.nn.Module
class whose role is to decode a
+GW representation into a unimodal latent representation.
+workspace_dim (int
): dimension of the GW.
+loss_coefs (LossCoefs
): loss coefficients
+optim_lr (float
): learning rate
+optim_weight_decay (float
): weight decay
+scheduler_args (SchedulerArgs | None
): optimization scheduler's arguments
+learn_logit_scale (bool
): whether to learn the logit scale of the
+contrastive loss when using the default contrastive loss.
+contrastive_loss (ContrastiveLossType | None
): a contrastive loss
+function used for alignment. learn_logit_scale
will not affect custom
+contrastive losses.
+
+
+
+
+
+
+
+
+
+
def
+
forward ( self , latent_domains : collections . abc . Mapping [ frozenset [ str ], collections . abc . Mapping [ str , torch . Tensor ]] ) -> GWPredictions :
+
+
View Source
+
+
+
+
538 def forward ( # type: ignore
+539 self ,
+540 latent_domains : LatentsDomainGroupsT ,
+541 ) -> GWPredictions :
+542 """
+543 Computes demi-cycles, cycles, and translations.
+544
+545 Args:
+546 latent_domains (`LatentsT`): Groups of domains for the computation.
+547
+548 Returns:
+549 `GWPredictions`: the predictions on the batch.
+550 """
+551 return GWPredictions (
+552 demi_cycles = batch_demi_cycles (
+553 self . gw_mod , self . selection_mod , latent_domains
+554 ),
+555 cycles = batch_cycles (
+556 self . gw_mod , self . selection_mod , latent_domains , self . domain_mods . keys ()
+557 ),
+558 translations = batch_translations (
+559 self . gw_mod , self . selection_mod , latent_domains
+560 ),
+561 ** super () . forward ( latent_domains ),
+562 )
+
+
+
+
Computes demi-cycles, cycles, and translations.
+
+
Arguments:
+
+
+latent_domains (LatentsT
): Groups of domains for the computation.
+
+
+
Returns:
+
+
+ GWPredictions
: the predictions on the batch.
+
+
+
+
+
+
+
Inherited Members
+
+
+
lightning.pytorch.core.module.LightningModule
+ CHECKPOINT_HYPER_PARAMS_KEY
+ CHECKPOINT_HYPER_PARAMS_NAME
+ CHECKPOINT_HYPER_PARAMS_TYPE
+ optimizers
+ lr_schedulers
+ trainer
+ fabric
+ example_input_array
+ current_epoch
+ global_step
+ global_rank
+ local_rank
+ on_gpu
+ automatic_optimization
+ strict_loading
+ logger
+ loggers
+ print
+ log
+ log_dict
+ all_gather
+ configure_callbacks
+ manual_backward
+ backward
+ toggle_optimizer
+ untoggle_optimizer
+ clip_gradients
+ configure_gradient_clipping
+ lr_scheduler_step
+ optimizer_step
+ optimizer_zero_grad
+ freeze
+ unfreeze
+ to_onnx
+ to_torchscript
+ load_from_checkpoint
+
+
+
lightning.fabric.utilities.device_dtype_mixin._DeviceDtypeModuleMixin
+ dtype
+ device
+ to
+ cuda
+ cpu
+ type
+ float
+ double
+ half
+
+
+
lightning.pytorch.core.mixins.hparams_mixin.HyperparametersMixin
+ save_hyperparameters
+ hparams
+ hparams_initial
+
+
+
lightning.pytorch.core.hooks.ModelHooks
+ on_fit_start
+ on_fit_end
+ on_train_start
+ on_train_end
+ on_validation_start
+ on_validation_end
+ on_test_start
+ on_test_end
+ on_predict_start
+ on_predict_end
+ on_train_batch_start
+ on_train_batch_end
+ on_validation_batch_start
+ on_validation_batch_end
+ on_test_batch_start
+ on_test_batch_end
+ on_predict_batch_start
+ on_predict_batch_end
+ on_validation_model_zero_grad
+ on_validation_model_eval
+ on_validation_model_train
+ on_test_model_eval
+ on_test_model_train
+ on_predict_model_eval
+ on_train_epoch_start
+ on_train_epoch_end
+ on_validation_epoch_start
+ on_validation_epoch_end
+ on_test_epoch_start
+ on_test_epoch_end
+ on_predict_epoch_start
+ on_predict_epoch_end
+ on_before_zero_grad
+ on_before_backward
+ on_after_backward
+ on_before_optimizer_step
+ configure_sharded_model
+ configure_model
+
+
+
lightning.pytorch.core.hooks.DataHooks
+ prepare_data_per_node
+ allow_zero_length_dataloader_with_multiple_devices
+ prepare_data
+ setup
+ teardown
+ train_dataloader
+ test_dataloader
+ val_dataloader
+ predict_dataloader
+ transfer_batch_to_device
+ on_before_batch_transfer
+ on_after_batch_transfer
+
+
+
lightning.pytorch.core.hooks.CheckpointHooks
+ on_load_checkpoint
+ on_save_checkpoint
+
+
+
torch.nn.modules.module.Module
+ dump_patches
+ training
+ call_super_init
+ register_buffer
+ register_parameter
+ add_module
+ register_module
+ get_submodule
+ get_parameter
+ get_buffer
+ get_extra_state
+ set_extra_state
+ apply
+ ipu
+ xpu
+ bfloat16
+ to_empty
+ register_full_backward_pre_hook
+ register_backward_hook
+ register_full_backward_hook
+ register_forward_pre_hook
+ register_forward_hook
+ register_state_dict_pre_hook
+ state_dict
+ register_load_state_dict_post_hook
+ load_state_dict
+ parameters
+ named_parameters
+ buffers
+ named_buffers
+ children
+ named_children
+ modules
+ named_modules
+ train
+ eval
+ requires_grad_
+ zero_grad
+ share_memory
+ extra_repr
+ compile
+
+
+
+
+
+
+
+
+
+ 565 class GlobalWorkspace ( GlobalWorkspaceBase [ GWModule , RandomSelection , GWLosses ]):
+566 """The 2-domain fusion (with broadcast loss) flavor of GlobalWorkspaceBase.
+567
+568 This is used to simplify a Global Workspace instanciation and only overrides the
+569 `__init__` method.
+570 """
+571
+572 def __init__ (
+573 self ,
+574 domain_mods : Mapping [ str , DomainModule ],
+575 gw_encoders : Mapping [ str , Module ],
+576 gw_decoders : Mapping [ str , Module ],
+577 workspace_dim : int ,
+578 loss_coefs : BroadcastLossCoefs ,
+579 selection_temperature : float = 0.2 ,
+580 optim_lr : float = 1e-3 ,
+581 optim_weight_decay : float = 0.0 ,
+582 scheduler_args : SchedulerArgs | None = None ,
+583 learn_logit_scale : bool = False ,
+584 contrastive_loss : ContrastiveLossType | None = None ,
+585 ) -> None :
+586 """
+587 Initializes a Global Workspace
+588
+589 Args:
+590 domain_mods (`Mapping[str, DomainModule]`): mapping of the domains
+591 connected to the GW. Keys are domain names, values are the
+592 `DomainModule`.
+593 gw_encoders (`Mapping[str, torch.nn.Module]`): mapping for each domain
+594 name to a `torch.nn.Module` class which role is to encode a
+595 unimodal latent representations into a GW representation (pre fusion).
+596 gw_decoders (`Mapping[str, torch.nn.Module]`): mapping for each domain
+597 name to a `torch.nn.Module` class which role is to decode a
+598 GW representation into a unimodal latent representations.
+599 workspace_dim (`int`): dimension of the GW.
+600 loss_coefs (`BroadcastLossCoefs`): loss coefs for the losses.
+601 selection_temperature (`float`): temperature value for the RandomSelection
+602 module.
+603 optim_lr (`float`): learning rate
+604 optim_weight_decay (`float`): weight decay
+605 scheduler_args (`SchedulerArgs | None`): optimization scheduler's arguments
+606 learn_logit_scale (`bool`): whether to learn the contrastive learning
+607 contrastive loss when using the default contrastive loss.
+608 contrastive_loss (`ContrastiveLossType | None`): a contrastive loss
+609 function used for alignment. `learn_logit_scale` will not affect custom
+610 contrastive losses.
+611 """
+612 domain_mods = freeze_domain_modules ( domain_mods )
+613 gw_mod = GWModule ( domain_mods , workspace_dim , gw_encoders , gw_decoders )
+614
+615 if contrastive_loss is None :
+616 contrastive_loss = ContrastiveLoss (
+617 torch . tensor ([ 1 / 0.07 ]) . log (), "mean" , learn_logit_scale
+618 )
+619
+620 selection_mod = RandomSelection ( selection_temperature )
+621 loss_mod = GWLosses (
+622 gw_mod , selection_mod , domain_mods , loss_coefs , contrastive_loss
+623 )
+624
+625 super () . __init__ (
+626 gw_mod ,
+627 selection_mod ,
+628 loss_mod ,
+629 optim_lr ,
+630 optim_weight_decay ,
+631 scheduler_args ,
+632 )
+633
+634 def forward ( # type: ignore
+635 self ,
+636 latent_domains : LatentsDomainGroupsT ,
+637 ) -> GWPredictions :
+638 """
+639 Computes demi-cycles, cycles, and translations.
+640
+641 Args:
+642 latent_domains (`LatentsT`): Groups of domains for the computation.
+643
+644 Returns:
+645 `GWPredictions`: the predictions on the batch.
+646 """
+647 return GWPredictions (
+648 demi_cycles = batch_demi_cycles (
+649 self . gw_mod , self . selection_mod , latent_domains
+650 ),
+651 cycles = batch_cycles (
+652 self . gw_mod , self . selection_mod , latent_domains , self . domain_mods . keys ()
+653 ),
+654 translations = batch_translations (
+655 self . gw_mod , self . selection_mod , latent_domains
+656 ),
+657 # TODO: add other combinations
+658 ** super () . forward ( latent_domains ),
+659 )
+
+
+
+ The 2-domain fusion (with broadcast loss) flavor of GlobalWorkspaceBase.
+
+
This is used to simplify a Global Workspace instantiation and only overrides the
+__init__
method.
+
+
+
+
+
+
+
+
GlobalWorkspace ( domain_mods : collections . abc . Mapping [ str , shimmer.modules.domain.DomainModule ] , gw_encoders : collections . abc . Mapping [ str , torch . nn . modules . module . Module ] , gw_decoders : collections . abc . Mapping [ str , torch . nn . modules . module . Module ] , workspace_dim : int , loss_coefs : shimmer.modules.losses.BroadcastLossCoefs , selection_temperature : float = 0.2 , optim_lr : float = 0.001 , optim_weight_decay : float = 0.0 , scheduler_args : SchedulerArgs | None = None , learn_logit_scale : bool = False , contrastive_loss : collections . abc . Callable [[ torch . Tensor , torch . Tensor ], shimmer.modules.domain.LossOutput ] | None = None )
+
+
View Source
+
+
+
+
572 def __init__ (
+573 self ,
+574 domain_mods : Mapping [ str , DomainModule ],
+575 gw_encoders : Mapping [ str , Module ],
+576 gw_decoders : Mapping [ str , Module ],
+577 workspace_dim : int ,
+578 loss_coefs : BroadcastLossCoefs ,
+579 selection_temperature : float = 0.2 ,
+580 optim_lr : float = 1e-3 ,
+581 optim_weight_decay : float = 0.0 ,
+582 scheduler_args : SchedulerArgs | None = None ,
+583 learn_logit_scale : bool = False ,
+584 contrastive_loss : ContrastiveLossType | None = None ,
+585 ) -> None :
+586 """
+587 Initializes a Global Workspace
+588
+589 Args:
+590 domain_mods (`Mapping[str, DomainModule]`): mapping of the domains
+591 connected to the GW. Keys are domain names, values are the
+592 `DomainModule`.
+593 gw_encoders (`Mapping[str, torch.nn.Module]`): mapping for each domain
+594 name to a `torch.nn.Module` class which role is to encode a
+595 unimodal latent representations into a GW representation (pre fusion).
+596 gw_decoders (`Mapping[str, torch.nn.Module]`): mapping for each domain
+597 name to a `torch.nn.Module` class which role is to decode a
+598 GW representation into a unimodal latent representations.
+599 workspace_dim (`int`): dimension of the GW.
+600 loss_coefs (`BroadcastLossCoefs`): loss coefs for the losses.
+601 selection_temperature (`float`): temperature value for the RandomSelection
+602 module.
+603 optim_lr (`float`): learning rate
+604 optim_weight_decay (`float`): weight decay
+605 scheduler_args (`SchedulerArgs | None`): optimization scheduler's arguments
+606 learn_logit_scale (`bool`): whether to learn the contrastive learning
+607 contrastive loss when using the default contrastive loss.
+608 contrastive_loss (`ContrastiveLossType | None`): a contrastive loss
+609 function used for alignment. `learn_logit_scale` will not affect custom
+610 contrastive losses.
+611 """
+612 domain_mods = freeze_domain_modules ( domain_mods )
+613 gw_mod = GWModule ( domain_mods , workspace_dim , gw_encoders , gw_decoders )
+614
+615 if contrastive_loss is None :
+616 contrastive_loss = ContrastiveLoss (
+617 torch . tensor ([ 1 / 0.07 ]) . log (), "mean" , learn_logit_scale
+618 )
+619
+620 selection_mod = RandomSelection ( selection_temperature )
+621 loss_mod = GWLosses (
+622 gw_mod , selection_mod , domain_mods , loss_coefs , contrastive_loss
+623 )
+624
+625 super () . __init__ (
+626 gw_mod ,
+627 selection_mod ,
+628 loss_mod ,
+629 optim_lr ,
+630 optim_weight_decay ,
+631 scheduler_args ,
+632 )
+
+
+
+
Initializes a Global Workspace
+
+
Arguments:
+
+
+domain_mods (Mapping[str, DomainModule]
): mapping of the domains
+connected to the GW. Keys are domain names, values are the
+DomainModule
.
+gw_encoders (Mapping[str, torch.nn.Module]
): mapping for each domain
+name to a torch.nn.Module
class whose role is to encode a
+unimodal latent representation into a GW representation (pre fusion).
+gw_decoders (Mapping[str, torch.nn.Module]
): mapping for each domain
+name to a torch.nn.Module
class whose role is to decode a
+GW representation into a unimodal latent representation.
+workspace_dim (int
): dimension of the GW.
+loss_coefs (BroadcastLossCoefs
): loss coefs for the losses.
+selection_temperature (float
): temperature value for the RandomSelection
+module.
+optim_lr (float
): learning rate
+optim_weight_decay (float
): weight decay
+scheduler_args (SchedulerArgs | None
): optimization scheduler's arguments
+learn_logit_scale (bool
): whether to learn the logit scale of the
+contrastive loss when using the default contrastive loss.
+contrastive_loss (ContrastiveLossType | None
): a contrastive loss
+function used for alignment. learn_logit_scale
will not affect custom
+contrastive losses.
+
+
+
+
+
+
+
+
+
+
def
+
forward ( self , latent_domains : collections . abc . Mapping [ frozenset [ str ], collections . abc . Mapping [ str , torch . Tensor ]] ) -> GWPredictions :
+
+
View Source
+
+
+
+
634 def forward ( # type: ignore
+635 self ,
+636 latent_domains : LatentsDomainGroupsT ,
+637 ) -> GWPredictions :
+638 """
+639 Computes demi-cycles, cycles, and translations.
+640
+641 Args:
+642 latent_domains (`LatentsT`): Groups of domains for the computation.
+643
+644 Returns:
+645 `GWPredictions`: the predictions on the batch.
+646 """
+647 return GWPredictions (
+648 demi_cycles = batch_demi_cycles (
+649 self . gw_mod , self . selection_mod , latent_domains
+650 ),
+651 cycles = batch_cycles (
+652 self . gw_mod , self . selection_mod , latent_domains , self . domain_mods . keys ()
+653 ),
+654 translations = batch_translations (
+655 self . gw_mod , self . selection_mod , latent_domains
+656 ),
+657 # TODO: add other combinations
+658 ** super () . forward ( latent_domains ),
+659 )
+
+
+
+
Computes demi-cycles, cycles, and translations.
+
+
Arguments:
+
+
+latent_domains (LatentsT
): Groups of domains for the computation.
+
+
+
Returns:
+
+
+ GWPredictions
: the predictions on the batch.
+
+
+
+
+
+
+
Inherited Members
+
+
+
lightning.pytorch.core.module.LightningModule
+ CHECKPOINT_HYPER_PARAMS_KEY
+ CHECKPOINT_HYPER_PARAMS_NAME
+ CHECKPOINT_HYPER_PARAMS_TYPE
+ optimizers
+ lr_schedulers
+ trainer
+ fabric
+ example_input_array
+ current_epoch
+ global_step
+ global_rank
+ local_rank
+ on_gpu
+ automatic_optimization
+ strict_loading
+ logger
+ loggers
+ print
+ log
+ log_dict
+ all_gather
+ configure_callbacks
+ manual_backward
+ backward
+ toggle_optimizer
+ untoggle_optimizer
+ clip_gradients
+ configure_gradient_clipping
+ lr_scheduler_step
+ optimizer_step
+ optimizer_zero_grad
+ freeze
+ unfreeze
+ to_onnx
+ to_torchscript
+ load_from_checkpoint
+
+
+
lightning.fabric.utilities.device_dtype_mixin._DeviceDtypeModuleMixin
+ dtype
+ device
+ to
+ cuda
+ cpu
+ type
+ float
+ double
+ half
+
+
+
lightning.pytorch.core.mixins.hparams_mixin.HyperparametersMixin
+ save_hyperparameters
+ hparams
+ hparams_initial
+
+
+
lightning.pytorch.core.hooks.ModelHooks
+ on_fit_start
+ on_fit_end
+ on_train_start
+ on_train_end
+ on_validation_start
+ on_validation_end
+ on_test_start
+ on_test_end
+ on_predict_start
+ on_predict_end
+ on_train_batch_start
+ on_train_batch_end
+ on_validation_batch_start
+ on_validation_batch_end
+ on_test_batch_start
+ on_test_batch_end
+ on_predict_batch_start
+ on_predict_batch_end
+ on_validation_model_zero_grad
+ on_validation_model_eval
+ on_validation_model_train
+ on_test_model_eval
+ on_test_model_train
+ on_predict_model_eval
+ on_train_epoch_start
+ on_train_epoch_end
+ on_validation_epoch_start
+ on_validation_epoch_end
+ on_test_epoch_start
+ on_test_epoch_end
+ on_predict_epoch_start
+ on_predict_epoch_end
+ on_before_zero_grad
+ on_before_backward
+ on_after_backward
+ on_before_optimizer_step
+ configure_sharded_model
+ configure_model
+
+
+
lightning.pytorch.core.hooks.DataHooks
+ prepare_data_per_node
+ allow_zero_length_dataloader_with_multiple_devices
+ prepare_data
+ setup
+ teardown
+ train_dataloader
+ test_dataloader
+ val_dataloader
+ predict_dataloader
+ transfer_batch_to_device
+ on_before_batch_transfer
+ on_after_batch_transfer
+
+
+
lightning.pytorch.core.hooks.CheckpointHooks
+ on_load_checkpoint
+ on_save_checkpoint
+
+
+
torch.nn.modules.module.Module
+ dump_patches
+ training
+ call_super_init
+ register_buffer
+ register_parameter
+ add_module
+ register_module
+ get_submodule
+ get_parameter
+ get_buffer
+
+
+ apply
+ ipu
+ xpu
+ bfloat16
+ to_empty
+ register_full_backward_pre_hook
+ register_backward_hook
+ register_full_backward_hook
+ register_forward_pre_hook
+ register_forward_hook
+ register_state_dict_pre_hook
+ state_dict
+ register_load_state_dict_post_hook
+ load_state_dict
+ parameters
+ named_parameters
+ buffers
+ named_buffers
+ children
+ named_children
+ modules
+ named_modules
+ train
+ eval
+ requires_grad_
+ zero_grad
+ share_memory
+
+ compile
+
+
+
+
+
+
+
+
+
+ 662 class GlobalWorkspaceBayesian (
+663 GlobalWorkspaceBase [ GWModuleBayesian , FixedSharedSelection , GWLossesBayesian ]
+664 ):
+665 """
+666 A simple 2-domains max GlobalWorkspaceBase with a Bayesian base uncertainty
+667 prediction.
+668
+669 This is used to simplify a Global Workspace instanciation and only overrides the
+670 `__init__` method.
+671 """
+672
+673 def __init__ (
+674 self ,
+675 domain_mods : Mapping [ str , DomainModule ],
+676 gw_encoders : Mapping [ str , Module ],
+677 gw_decoders : Mapping [ str , Module ],
+678 workspace_dim : int ,
+679 loss_coefs : BroadcastLossCoefs ,
+680 sensitivity_selection : float = 1 ,
+681 sensitivity_precision : float = 1 ,
+682 optim_lr : float = 1e-3 ,
+683 optim_weight_decay : float = 0.0 ,
+684 scheduler_args : SchedulerArgs | None = None ,
+685 learn_logit_scale : bool = False ,
+686 use_normalized_constrastive : bool = True ,
+687 contrastive_loss : ContrastiveLossType | None = None ,
+688 precision_softmax_temp : float = 0.01 ,
+689 ) -> None :
+690 """
+691 Initializes a Global Workspace
+692
+693 Args:
+694 domain_mods (`Mapping[str, DomainModule]`): mapping of the domains
+695 connected to the GW. Keys are domain names, values are the
+696 `DomainModule`.
+697 gw_encoders (`Mapping[str, torch.nn.Module]`): mapping for each domain
+698 name to a `torch.nn.Module` class which role is to encode a
+699 unimodal latent representations into a GW representation (pre fusion).
+700 gw_decoders (`Mapping[str, torch.nn.Module]`): mapping for each domain
+701 name to a `torch.nn.Module` class which role is to decode a
+702 GW representation into a unimodal latent representations.
+703 workspace_dim (`int`): dimension of the GW.
+704 loss_coefs (`LossCoefs`): loss coefficients
+705 sensitivity_selection (`float`): sensivity coef $c'_1$
+706 sensitivity_precision (`float`): sensitivity coef $c'_2$
+707 optim_lr (`float`): learning rate
+708 optim_weight_decay (`float`): weight decay
+709 scheduler_args (`SchedulerArgs | None`): optimization scheduler's arguments
+710 learn_logit_scale (`bool`): whether to learn the contrastive learning
+711 contrastive loss when using the default contrastive loss.
+712 use_normalized_constrastive (`bool`): whether to use the normalized cont
+713 loss by the precision coefs
+714 contrastive_loss (`ContrastiveLossType | None`): a contrastive loss
+715 function used for alignment. `learn_logit_scale` will not affect custom
+716 contrastive losses.
+717 precision_softmax_temp (`float`): temperature to use in softmax of
+718 precision
+719 """
+720 domain_mods = freeze_domain_modules ( domain_mods )
+721
+722 gw_mod = GWModuleBayesian (
+723 domain_mods ,
+724 workspace_dim ,
+725 gw_encoders ,
+726 gw_decoders ,
+727 sensitivity_selection ,
+728 sensitivity_precision ,
+729 precision_softmax_temp ,
+730 )
+731
+732 selection_mod = FixedSharedSelection ()
+733
+734 contrastive_loss = ContrastiveLoss (
+735 torch . tensor ([ 1 ]) . log (), "mean" , learn_logit_scale
+736 )
+737
+738 loss_mod = GWLossesBayesian (
+739 gw_mod ,
+740 selection_mod ,
+741 domain_mods ,
+742 loss_coefs ,
+743 contrastive_loss ,
+744 use_normalized_constrastive ,
+745 )
+746
+747 super () . __init__ (
+748 gw_mod ,
+749 selection_mod ,
+750 loss_mod ,
+751 optim_lr ,
+752 optim_weight_decay ,
+753 scheduler_args ,
+754 )
+755
+756 def forward ( # type: ignore
+757 self ,
+758 latent_domains : LatentsDomainGroupsT ,
+759 ) -> GWPredictions :
+760 """
+761 Computes demi-cycles, cycles, and translations.
+762
+763 Args:
+764 latent_domains (`LatentsT`): Groups of domains for the computation.
+765
+766 Returns:
+767 `GWPredictions`: the predictions on the batch.
+768 """
+769 return GWPredictions (
+770 demi_cycles = batch_demi_cycles (
+771 self . gw_mod , self . selection_mod , latent_domains
+772 ),
+773 cycles = batch_cycles (
+774 self . gw_mod , self . selection_mod , latent_domains , self . domain_mods . keys ()
+775 ),
+776 translations = batch_translations (
+777 self . gw_mod , self . selection_mod , latent_domains
+778 ),
+779 ** super () . forward ( latent_domains ),
+780 )
+
+
+
+ A simple 2-domains max GlobalWorkspaceBase with a Bayesian base uncertainty
+prediction.
+
+
This is used to simplify a Global Workspace instantiation and only overrides the
+__init__
method.
+
+
+
+
+
+
+
+
GlobalWorkspaceBayesian ( domain_mods : collections . abc . Mapping [ str , shimmer.modules.domain.DomainModule ] , gw_encoders : collections . abc . Mapping [ str , torch . nn . modules . module . Module ] , gw_decoders : collections . abc . Mapping [ str , torch . nn . modules . module . Module ] , workspace_dim : int , loss_coefs : shimmer.modules.losses.BroadcastLossCoefs , sensitivity_selection : float = 1 , sensitivity_precision : float = 1 , optim_lr : float = 0.001 , optim_weight_decay : float = 0.0 , scheduler_args : SchedulerArgs | None = None , learn_logit_scale : bool = False , use_normalized_constrastive : bool = True , contrastive_loss : collections . abc . Callable [[ torch . Tensor , torch . Tensor ], shimmer.modules.domain.LossOutput ] | None = None , precision_softmax_temp : float = 0.01 )
+
+
View Source
+
+
+
+
673 def __init__ (
+674 self ,
+675 domain_mods : Mapping [ str , DomainModule ],
+676 gw_encoders : Mapping [ str , Module ],
+677 gw_decoders : Mapping [ str , Module ],
+678 workspace_dim : int ,
+679 loss_coefs : BroadcastLossCoefs ,
+680 sensitivity_selection : float = 1 ,
+681 sensitivity_precision : float = 1 ,
+682 optim_lr : float = 1e-3 ,
+683 optim_weight_decay : float = 0.0 ,
+684 scheduler_args : SchedulerArgs | None = None ,
+685 learn_logit_scale : bool = False ,
+686 use_normalized_constrastive : bool = True ,
+687 contrastive_loss : ContrastiveLossType | None = None ,
+688 precision_softmax_temp : float = 0.01 ,
+689 ) -> None :
+690 """
+691 Initializes a Global Workspace
+692
+693 Args:
+694 domain_mods (`Mapping[str, DomainModule]`): mapping of the domains
+695 connected to the GW. Keys are domain names, values are the
+696 `DomainModule`.
+697 gw_encoders (`Mapping[str, torch.nn.Module]`): mapping for each domain
+698 name to a `torch.nn.Module` class which role is to encode a
+699 unimodal latent representations into a GW representation (pre fusion).
+700 gw_decoders (`Mapping[str, torch.nn.Module]`): mapping for each domain
+701 name to a `torch.nn.Module` class which role is to decode a
+702 GW representation into a unimodal latent representations.
+703 workspace_dim (`int`): dimension of the GW.
+704 loss_coefs (`LossCoefs`): loss coefficients
+705 sensitivity_selection (`float`): sensivity coef $c'_1$
+706 sensitivity_precision (`float`): sensitivity coef $c'_2$
+707 optim_lr (`float`): learning rate
+708 optim_weight_decay (`float`): weight decay
+709 scheduler_args (`SchedulerArgs | None`): optimization scheduler's arguments
+710 learn_logit_scale (`bool`): whether to learn the contrastive learning
+711 contrastive loss when using the default contrastive loss.
+712 use_normalized_constrastive (`bool`): whether to use the normalized cont
+713 loss by the precision coefs
+714 contrastive_loss (`ContrastiveLossType | None`): a contrastive loss
+715 function used for alignment. `learn_logit_scale` will not affect custom
+716 contrastive losses.
+717 precision_softmax_temp (`float`): temperature to use in softmax of
+718 precision
+719 """
+720 domain_mods = freeze_domain_modules ( domain_mods )
+721
+722 gw_mod = GWModuleBayesian (
+723 domain_mods ,
+724 workspace_dim ,
+725 gw_encoders ,
+726 gw_decoders ,
+727 sensitivity_selection ,
+728 sensitivity_precision ,
+729 precision_softmax_temp ,
+730 )
+731
+732 selection_mod = FixedSharedSelection ()
+733
+734 contrastive_loss = ContrastiveLoss (
+735 torch . tensor ([ 1 ]) . log (), "mean" , learn_logit_scale
+736 )
+737
+738 loss_mod = GWLossesBayesian (
+739 gw_mod ,
+740 selection_mod ,
+741 domain_mods ,
+742 loss_coefs ,
+743 contrastive_loss ,
+744 use_normalized_constrastive ,
+745 )
+746
+747 super () . __init__ (
+748 gw_mod ,
+749 selection_mod ,
+750 loss_mod ,
+751 optim_lr ,
+752 optim_weight_decay ,
+753 scheduler_args ,
+754 )
+
+
+
+
Initializes a Global Workspace
+
+
Arguments:
+
+
+domain_mods (Mapping[str, DomainModule]
): mapping of the domains
+connected to the GW. Keys are domain names, values are the
+DomainModule
.
+gw_encoders (Mapping[str, torch.nn.Module]
): mapping for each domain
+name to a torch.nn.Module
class whose role is to encode a
+unimodal latent representation into a GW representation (pre fusion).
+gw_decoders (Mapping[str, torch.nn.Module]
): mapping for each domain
+name to a torch.nn.Module
class which role is to decode a
+GW representation into a unimodal latent representations.
+workspace_dim (int
): dimension of the GW.
+loss_coefs (LossCoefs
): loss coefficients
+sensitivity_selection (float
): sensivity coef $c'_1$
+sensitivity_precision (float
): sensitivity coef $c'_2$
+optim_lr (float
): learning rate
+optim_weight_decay (float
): weight decay
+scheduler_args (SchedulerArgs | None
): optimization scheduler's arguments
+learn_logit_scale (bool
): whether to learn the contrastive learning
+contrastive loss when using the default contrastive loss.
+use_normalized_constrastive (bool
): whether to use the normalized cont
+loss by the precision coefs
+contrastive_loss (ContrastiveLossType | None
): a contrastive loss
+function used for alignment. learn_logit_scale
will not affect custom
+contrastive losses.
+precision_softmax_temp (float
): temperature to use in softmax of
+precision
+
+
+
+
+
+
+
+
+
+
def
+
forward ( self , latent_domains : collections . abc . Mapping [ frozenset [ str ], collections . abc . Mapping [ str , torch . Tensor ]] ) -> GWPredictions :
+
+
View Source
+
+
+
+
756 def forward ( # type: ignore
+757 self ,
+758 latent_domains : LatentsDomainGroupsT ,
+759 ) -> GWPredictions :
+760 """
+761 Computes demi-cycles, cycles, and translations.
+762
+763 Args:
+764 latent_domains (`LatentsT`): Groups of domains for the computation.
+765
+766 Returns:
+767 `GWPredictions`: the predictions on the batch.
+768 """
+769 return GWPredictions (
+770 demi_cycles = batch_demi_cycles (
+771 self . gw_mod , self . selection_mod , latent_domains
+772 ),
+773 cycles = batch_cycles (
+774 self . gw_mod , self . selection_mod , latent_domains , self . domain_mods . keys ()
+775 ),
+776 translations = batch_translations (
+777 self . gw_mod , self . selection_mod , latent_domains
+778 ),
+779 ** super () . forward ( latent_domains ),
+780 )
+
+
+
+
Computes demi-cycles, cycles, and translations.
+
+
Arguments:
+
+
+latent_domains (LatentsT
): Groups of domains for the computation.
+
+
+
Returns:
+
+
+ GWPredictions
: the predictions on the batch.
+
+
+
+
+
+
+
Inherited Members
+
+
+
lightning.pytorch.core.module.LightningModule
+ CHECKPOINT_HYPER_PARAMS_KEY
+ CHECKPOINT_HYPER_PARAMS_NAME
+ CHECKPOINT_HYPER_PARAMS_TYPE
+ optimizers
+ lr_schedulers
+ trainer
+ fabric
+ example_input_array
+ current_epoch
+ global_step
+ global_rank
+ local_rank
+ on_gpu
+ automatic_optimization
+ strict_loading
+ logger
+ loggers
+ print
+ log
+ log_dict
+ all_gather
+ configure_callbacks
+ manual_backward
+ backward
+ toggle_optimizer
+ untoggle_optimizer
+ clip_gradients
+ configure_gradient_clipping
+ lr_scheduler_step
+ optimizer_step
+ optimizer_zero_grad
+ freeze
+ unfreeze
+ to_onnx
+ to_torchscript
+ load_from_checkpoint
+
+
+
lightning.fabric.utilities.device_dtype_mixin._DeviceDtypeModuleMixin
+ dtype
+ device
+ to
+ cuda
+ cpu
+ type
+ float
+ double
+ half
+
+
+
lightning.pytorch.core.mixins.hparams_mixin.HyperparametersMixin
+ save_hyperparameters
+ hparams
+ hparams_initial
+
+
+
lightning.pytorch.core.hooks.ModelHooks
+ on_fit_start
+ on_fit_end
+ on_train_start
+ on_train_end
+ on_validation_start
+ on_validation_end
+ on_test_start
+ on_test_end
+ on_predict_start
+ on_predict_end
+ on_train_batch_start
+ on_train_batch_end
+ on_validation_batch_start
+ on_validation_batch_end
+ on_test_batch_start
+ on_test_batch_end
+ on_predict_batch_start
+ on_predict_batch_end
+ on_validation_model_zero_grad
+ on_validation_model_eval
+ on_validation_model_train
+ on_test_model_eval
+ on_test_model_train
+ on_predict_model_eval
+ on_train_epoch_start
+ on_train_epoch_end
+ on_validation_epoch_start
+ on_validation_epoch_end
+ on_test_epoch_start
+ on_test_epoch_end
+ on_predict_epoch_start
+ on_predict_epoch_end
+ on_before_zero_grad
+ on_before_backward
+ on_after_backward
+ on_before_optimizer_step
+ configure_sharded_model
+ configure_model
+
+
+
lightning.pytorch.core.hooks.DataHooks
+ prepare_data_per_node
+ allow_zero_length_dataloader_with_multiple_devices
+ prepare_data
+ setup
+ teardown
+ train_dataloader
+ test_dataloader
+ val_dataloader
+ predict_dataloader
+ transfer_batch_to_device
+ on_before_batch_transfer
+ on_after_batch_transfer
+
+
+
lightning.pytorch.core.hooks.CheckpointHooks
+ on_load_checkpoint
+ on_save_checkpoint
+
+
+
torch.nn.modules.module.Module
+ dump_patches
+ training
+ call_super_init
+ register_buffer
+ register_parameter
+ add_module
+ register_module
+ get_submodule
+ get_parameter
+ get_buffer
+
+
+ apply
+ ipu
+ xpu
+ bfloat16
+ to_empty
+ register_full_backward_pre_hook
+ register_backward_hook
+ register_full_backward_hook
+ register_forward_pre_hook
+ register_forward_hook
+ register_state_dict_pre_hook
+ state_dict
+ register_load_state_dict_post_hook
+ load_state_dict
+ parameters
+ named_parameters
+ buffers
+ named_buffers
+ children
+ named_children
+ modules
+ named_modules
+ train
+ eval
+ requires_grad_
+ zero_grad
+ share_memory
+
+ compile
+
+
+
+
+
+
+
+
+
+
def
+
pretrained_global_workspace ( checkpoint_path : str | pathlib . Path , domain_mods : collections . abc . Mapping [ str , shimmer.modules.domain.DomainModule ] , gw_encoders : collections . abc . Mapping [ str , torch . nn . modules . module . Module ] , gw_decoders : collections . abc . Mapping [ str , torch . nn . modules . module . Module ] , workspace_dim : int , loss_coefs : shimmer.modules.losses.LossCoefs , contrastive_fn : collections . abc . Callable [[ torch . Tensor , torch . Tensor ], shimmer.modules.domain.LossOutput ] , ** kwargs ) -> GlobalWorkspace2Domains :
+
+
View Source
+
+
+
+ 783 def pretrained_global_workspace (
+784 checkpoint_path : str | Path ,
+785 domain_mods : Mapping [ str , DomainModule ],
+786 gw_encoders : Mapping [ str , Module ],
+787 gw_decoders : Mapping [ str , Module ],
+788 workspace_dim : int ,
+789 loss_coefs : LossCoefs ,
+790 contrastive_fn : ContrastiveLossType ,
+791 ** kwargs ,
+792 ) -> GlobalWorkspace2Domains :
+793 """
+794 Load a `GlobalWorkspace` flavor of `GlobalWorkspaceBase` from a checkpoint.
+795
+796 Args:
+797 checkpoint_path (`str | Path`): path to checkpoint
+798 domain_mods (`Mapping[str, DomainModule]`): mapping of the domains
+799 connected to the GW. Keys are domain names, values are the
+800 `DomainModule`.
+801 gw_encoders (`Mapping[str, torch.nn.Module]`): mapping for each domain
+802 name to a `torch.nn.Module` class which role is to encode a
+803 unimodal latent representations into a GW representation (pre fusion).
+804 gw_decoders (`Mapping[str, torch.nn.Module]`): mapping for each domain
+805 name to a `torch.nn.Module` class which role is to decode a
+806 GW representation into a unimodal latent representations.
+807 workspace_dim (`int`): dimension of the GW.
+808 loss_coefs (`LossCoefs`): loss coefficients
+809 contrastive_loss (`ContrastiveLossType`): a contrastive loss
+810 function used for alignment. `learn_logit_scale` will not affect custom
+811 contrastive losses.
+812 **kwargs: additional arguments to pass to
+813 `GlobalWorkspace.load_from_checkpoint`.
+814
+815 Returns:
+816 `GlobalWorkspace`: the pretrained `GlobalWorkspace`.
+817
+818 Raises:
+819 `TypeError`: if loaded type is not `GlobalWorkspace`.
+820 """
+821 domain_mods = freeze_domain_modules ( domain_mods )
+822 gw_mod = GWModule ( domain_mods , workspace_dim , gw_encoders , gw_decoders )
+823 selection_mod = SingleDomainSelection ()
+824 loss_mod = GWLosses2Domains (
+825 gw_mod , selection_mod , domain_mods , loss_coefs , contrastive_fn
+826 )
+827
+828 gw = GlobalWorkspace2Domains . load_from_checkpoint (
+829 checkpoint_path ,
+830 gw_mod = gw_mod ,
+831 selection_mid = selection_mod ,
+832 loss_coefs = loss_coefs ,
+833 loss_mod = loss_mod ,
+834 ** kwargs ,
+835 )
+836 if not isinstance ( gw , GlobalWorkspace2Domains ):
+837 raise TypeError ( "model should be of type GlobalWorkspace" )
+838 return gw
+
+
+
+ Load a GlobalWorkspace
flavor of GlobalWorkspaceBase
from a checkpoint.
+
+
Arguments:
+
+
+checkpoint_path (str | Path
): path to checkpoint
+domain_mods (Mapping[str, DomainModule]
): mapping of the domains
+connected to the GW. Keys are domain names, values are the
+DomainModule
.
+gw_encoders (Mapping[str, torch.nn.Module]
): mapping for each domain
+name to a torch.nn.Module
class which role is to encode a
+unimodal latent representations into a GW representation (pre fusion).
+gw_decoders (Mapping[str, torch.nn.Module]
): mapping for each domain
+name to a torch.nn.Module
class which role is to decode a
+GW representation into a unimodal latent representations.
+workspace_dim (int
): dimension of the GW.
+loss_coefs (LossCoefs
): loss coefficients
+contrastive_loss (ContrastiveLossType
): a contrastive loss
+function used for alignment. learn_logit_scale
will not affect custom
+contrastive losses.
+**kwargs: additional arguments to pass to
+GlobalWorkspace.load_from_checkpoint
.
+
+
+
Returns:
+
+
+ GlobalWorkspace
: the pretrained GlobalWorkspace
.
+
+
+
Raises:
+
+
+
+
+
+
+
+
+
\ No newline at end of file
diff --git a/docs/api/v0.5.1/shimmer/modules/gw_module.html b/docs/api/v0.5.1/shimmer/modules/gw_module.html
new file mode 100644
index 00000000..bf539184
--- /dev/null
+++ b/docs/api/v0.5.1/shimmer/modules/gw_module.html
@@ -0,0 +1,2893 @@
+
+
+
+
+
+
+ shimmer.modules.gw_module API documentation
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+shimmer.modules.gw_module
+
+
+
+
+ View Source
+
+ 1 from abc import ABC , abstractmethod
+ 2 from collections.abc import Iterable , Mapping
+ 3 from typing import cast
+ 4
+ 5 import torch
+ 6 from torch import nn
+ 7
+ 8 from shimmer.modules.domain import DomainModule
+ 9 from shimmer.modules.selection import SelectionBase
+ 10 from shimmer.types import LatentsDomainGroupDT , LatentsDomainGroupT
+ 11
+ 12
+ 13 def get_n_layers ( n_layers : int , hidden_dim : int ) -> list [ nn . Module ]:
+ 14 """
+ 15 Makes a list of `n_layers` `nn.Linear` layers with `nn.ReLU`.
+ 16
+ 17 Args:
+ 18 n_layers (`int`): number of layers
+ 19 hidden_dim (`int`): size of the hidden dimension
+ 20
+ 21 Returns:
+ 22 `list[nn.Module]`: list of linear and relu layers.
+ 23 """
+ 24 layers : list [ nn . Module ] = []
+ 25 for _ in range ( n_layers ):
+ 26 layers . extend ([ nn . Linear ( hidden_dim , hidden_dim ), nn . ReLU ()])
+ 27 return layers
+ 28
+ 29
+ 30 class GWDecoder ( nn . Sequential ):
+ 31 """A Decoder network for GWModules."""
+ 32
+ 33 def __init__ (
+ 34 self ,
+ 35 in_dim : int ,
+ 36 hidden_dim : int ,
+ 37 out_dim : int ,
+ 38 n_layers : int ,
+ 39 ):
+ 40 """
+ 41 Initializes the decoder.
+ 42
+ 43 Args:
+ 44 in_dim (`int`): input dimension
+ 45 hidden_dim (`int`): hidden dimension
+ 46 out_dim (`int`): output dimension
+ 47 n_layers (`int`): number of hidden layers. The total number of layers
+ 48 will be `n_layers` + 2 (one before, one after).
+ 49 """
+ 50
+ 51 self . in_dim = in_dim
+ 52 """input dimension"""
+ 53
+ 54 self . hidden_dim = hidden_dim
+ 55 """hidden dimension"""
+ 56
+ 57 self . out_dim = out_dim
+ 58 """output dimension"""
+ 59
+ 60 self . n_layers = n_layers
+ 61 """
+ 62 number of hidden layers. The total number of layers
+ 63 will be `n_layers` + 2 (one before, one after)."""
+ 64
+ 65 super () . __init__ (
+ 66 nn . Linear ( self . in_dim , self . hidden_dim ),
+ 67 nn . ReLU (),
+ 68 * get_n_layers ( n_layers , self . hidden_dim ),
+ 69 nn . Linear ( self . hidden_dim , self . out_dim ),
+ 70 )
+ 71
+ 72
+ 73 class GWEncoder ( GWDecoder ):
+ 74 """
+ 75 An Encoder network used in GWModules.
+ 76
+ 77 This is similar to the decoder, but adds a tanh non-linearity at the end.
+ 78 """
+ 79
+ 80 def __init__ (
+ 81 self ,
+ 82 in_dim : int ,
+ 83 hidden_dim : int ,
+ 84 out_dim : int ,
+ 85 n_layers : int ,
+ 86 ):
+ 87 """
+ 88 Initializes the encoder.
+ 89
+ 90 Args:
+ 91 in_dim (`int`): input dimension
+ 92 hidden_dim (`int`): hidden dimension
+ 93 out_dim (`int`): output dimension
+ 94 n_layers (`int`): number of hidden layers. The total number of layers
+ 95 will be `n_layers` + 2 (one before, one after).
+ 96 """
+ 97 super () . __init__ ( in_dim , hidden_dim , out_dim , n_layers )
+ 98
+ 99 def forward ( self , input : torch . Tensor ) -> torch . Tensor :
+100 return super () . forward ( input )
+101
+102
+103 class GWEncoderLinear ( nn . Linear ):
+104 """A linear Encoder network used in GWModules."""
+105
+106 def forward ( self , input : torch . Tensor ) -> torch . Tensor :
+107 return torch . tanh ( super () . forward ( input ))
+108
+109
+110 class GWModuleBase ( nn . Module , ABC ):
+111 """
+112 Base class for GWModule.
+113
+114 GWModule handles encoding, decoding the unimodal representations
+115 using the `gw_encoders` and`gw_decoders`, and define
+116 some common operations in GW like cycles and translations.
+117
+118 This is an abstract class and should be implemented.
+119 For an implemented interface, see `GWModule`.
+120 """
+121
+122 def __init__ (
+123 self ,
+124 domain_mods : Mapping [ str , DomainModule ],
+125 workspace_dim : int ,
+126 * args ,
+127 ** kwargs ,
+128 ) -> None :
+129 """
+130 Initializes the GWModule.
+131
+132 Args:
+133 domain_modules (`Mapping[str, DomainModule]`): the domain modules.
+134 workspace_dim (`int`): dimension of the GW.
+135 """
+136 super () . __init__ ()
+137
+138 self . domain_mods = domain_mods
+139 """The unimodal domain modules."""
+140
+141 self . workspace_dim = workspace_dim
+142 """Dimension of the GW"""
+143
+144 @abstractmethod
+145 def fuse (
+146 self , x : LatentsDomainGroupT , selection_scores : Mapping [ str , torch . Tensor ]
+147 ) -> torch . Tensor :
+148 """
+149 Merge function used to combine domains.
+150
+151 Args:
+152 x (`LatentsDomainGroupT`): the group of latent representation.
+153 selection_score (`Mapping[str, torch.Tensor]`): attention scores to
+154 use to encode the reprensetation.
+155 Returns:
+156 `torch.Tensor`: The merged representation.
+157 """
+158 ...
+159
+160 @abstractmethod
+161 def encode ( self , x : LatentsDomainGroupT ) -> LatentsDomainGroupDT :
+162 """
+163 Encode the latent representation infos to the pre-fusion GW representation.
+164
+165 Args:
+166 x (`LatentsDomainGroupT`): the input domain representations
+167
+168 Returns:
+169 `LatentsDomainGroupT`: pre-fusion GW representations
+170 """
+171 ...
+172
+173 def encode_and_fuse (
+174 self , x : LatentsDomainGroupT , selection_module : SelectionBase
+175 ) -> torch . Tensor :
+176 """
+177 Encode the latent representation infos to the final GW representation.
+178 It combines the encode and fuse methods.
+179
+180 Args:
+181 x (`LatentsDomainGroupT`): the input domain representations
+182 selection_score (`Mapping[str, torch.Tensor]`): attention scores to
+183 use to encode the reprensetation.
+184
+185 Returns:
+186 `torch.Tensor`: The merged representation.
+187 """
+188 encodings = self . encode ( x )
+189 selection_scores = selection_module ( x , encodings )
+190 return self . fuse ( encodings , selection_scores )
+191
+192 @abstractmethod
+193 def decode (
+194 self , z : torch . Tensor , domains : Iterable [ str ] | None = None
+195 ) -> LatentsDomainGroupDT :
+196 """
+197 Decode the GW representation into given `domains`.
+198
+199 Args:
+200 z (`torch.Tensor`): the GW representation.
+201 domains (`Iterable[str]`): iterable of domains to decode.
+202
+203 Returns:
+204 `LatentsDomainGroupDT`: the decoded unimodal representations.
+205 """
+206 ...
+207
+208
+209 class GWModule ( GWModuleBase ):
+210 """GW nn.Module. Implements `GWModuleBase`."""
+211
+212 def __init__ (
+213 self ,
+214 domain_modules : Mapping [ str , DomainModule ],
+215 workspace_dim : int ,
+216 gw_encoders : Mapping [ str , nn . Module ],
+217 gw_decoders : Mapping [ str , nn . Module ],
+218 ) -> None :
+219 """
+220 Initializes the GWModule.
+221
+222 Args:
+223 domain_modules (`Mapping[str, DomainModule]`): the domain modules.
+224 workspace_dim (`int`): dimension of the GW.
+225 gw_encoders (`Mapping[str, torch.nn.Module]`): mapping for each domain
+226 name to a an torch.nn.Module class that encodes a
+227 unimodal latent representations into a GW representation (pre fusion).
+228 gw_decoders (`Mapping[str, torch.nn.Module]`): mapping for each domain
+229 name to a an torch.nn.Module class that decodes a
+230 GW representation to a unimodal latent representation.
+231 """
+232 super () . __init__ ( domain_modules , workspace_dim )
+233
+234 self . gw_encoders = nn . ModuleDict ( gw_encoders )
+235 """The module's encoders"""
+236
+237 self . gw_decoders = nn . ModuleDict ( gw_decoders )
+238 """The module's decoders"""
+239
+240 def fuse (
+241 self ,
+242 x : LatentsDomainGroupT ,
+243 selection_scores : Mapping [ str , torch . Tensor ],
+244 ) -> torch . Tensor :
+245 """
+246 Merge function used to combine domains.
+247
+248 Args:
+249 x (`LatentsDomainGroupT`): the group of latent representation.
+250 selection_score (`Mapping[str, torch.Tensor]`): attention scores to
+251 use to encode the reprensetation.
+252 Returns:
+253 `torch.Tensor`: The merged representation.
+254 """
+255 return torch . tanh (
+256 torch . sum (
+257 torch . stack (
+258 [
+259 selection_scores [ domain ] . unsqueeze ( 1 ) * x [ domain ]
+260 for domain in selection_scores
+261 ]
+262 ),
+263 dim = 0 ,
+264 )
+265 )
+266
+267 def encode ( self , x : LatentsDomainGroupT ) -> LatentsDomainGroupDT :
+268 """
+269 Encode the latent representation infos to the pre-fusion GW representation.
+270
+271 Args:
+272 x (`LatentsDomainGroupT`): the input domain representations.
+273
+274 Returns:
+275 `LatentsDomainGroupT`: pre-fusion representation
+276 """
+277 return {
+278 domain_name : self . gw_encoders [ domain_name ]( domain )
+279 for domain_name , domain in x . items ()
+280 }
+281
+282 def decode (
+283 self , z : torch . Tensor , domains : Iterable [ str ] | None = None
+284 ) -> LatentsDomainGroupDT :
+285 """
+286 Decodes a GW representation to multiple domains.
+287
+288 Args:
+289 z (`torch.Tensor`): the GW representation
+290 domains (`Iterable[str] | None`): the domains to decode to. Defaults to
+291 use keys in `gw_interfaces` (all domains).
+292 Returns:
+293 `LatentsDomainGroupDT`: decoded unimodal representation
+294 """
+295 return {
+296 domain : self . gw_decoders [ domain ]( z )
+297 for domain in domains or self . gw_decoders . keys ()
+298 }
+299
+300
+301 def compute_fusion_scores (
+302 score_1 : torch . Tensor ,
+303 score_2 : torch . Tensor ,
+304 sensitivity_1 : float = 1.0 ,
+305 sensitivity_2 : float = 1.0 ,
+306 eps : float = 1e-6 ,
+307 ) -> torch . Tensor :
+308 """
+309 Combine precision scores using std summation in quadrature
+310
+311 The two scores should have the same dimension.
+312
+313 Args:
+314 score_1 (`torch.Tensor`): First scores.
+315 score_2 (`torch.Tensor`): Second scores.
+316 sensitivity_1 (`float`): sensitivity for the first score
+317 sensitivity_2 (`float`): sensitivity for the second score
+318 eps (`float`): a value added to avoid numerical unstability.
+319
+320 Returns:
+321 `torch.Tensor`: the combined scores
+322 """
+323 total_uncertainty = sensitivity_1 / ( eps + score_1 ) + sensitivity_2 / (
+324 eps + score_2
+325 )
+326 final_scores = 1 / ( eps + total_uncertainty )
+327 return final_scores / final_scores . sum ( dim = 0 , keepdim = True )
+328
+329
+330 class GWModuleBayesian ( GWModule ):
+331 """`GWModule` with a Bayesian based uncertainty prediction."""
+332
+333 def __init__ (
+334 self ,
+335 domain_modules : Mapping [ str , DomainModule ],
+336 workspace_dim : int ,
+337 gw_encoders : Mapping [ str , nn . Module ],
+338 gw_decoders : Mapping [ str , nn . Module ],
+339 sensitivity_selection : float = 1 ,
+340 sensitivity_precision : float = 1 ,
+341 precision_softmax_temp : float = 0.01 ,
+342 ) -> None :
+343 """
+344 Initializes the GWModuleBayesian.
+345
+346 Args:
+347 domain_modules (`Mapping[str, DomainModule]`): the domain modules.
+348 workspace_dim (`int`): dimension of the GW.
+349 gw_encoders (`Mapping[str, torch.nn.Module]`): mapping for each domain
+350 name to a an torch.nn.Module class that encodes a
+351 unimodal latent representations into a GW representation (pre fusion).
+352 gw_decoders (`Mapping[str, torch.nn.Module]`): mapping for each domain
+353 name to a an torch.nn.Module class that decodes a
+354 GW representation to a unimodal latent representation.
+355 sensitivity_selection (`float`): sensivity coef $c'_1$
+356 sensitivity_precision (`float`): sensitivity coef $c'_2$
+357 precision_softmax_temp (`float`): temperature to use in softmax of
+358 precision
+359 """
+360 super () . __init__ ( domain_modules , workspace_dim , gw_encoders , gw_decoders )
+361
+362 self . precisions = cast (
+363 dict [ str , torch . Tensor ],
+364 nn . ParameterDict (
+365 { domain : torch . randn ( workspace_dim ) for domain in gw_encoders }
+366 ),
+367 )
+368 """Precision at the neuron level for every domain."""
+369
+370 self . sensitivity_selection = sensitivity_selection
+371 self . sensitivity_precision = sensitivity_precision
+372 self . precision_softmax_temp = precision_softmax_temp
+373
+374 def get_precision ( self , domain : str , x : torch . Tensor ) -> torch . Tensor :
+375 """
+376 Get the precision vector of given domain and batch
+377
+378 Args:
+379 domain (`str`):
+380 x (`torch.Tensor`): batch of inputs
+381
+382 Returns:
+383 `torch.Tensor`: batch of precision
+384 """
+385 return self . precisions [ domain ] . unsqueeze ( 0 ) . expand ( x . size ( 0 ), - 1 )
+386
+387 def fuse (
+388 self ,
+389 x : LatentsDomainGroupT ,
+390 selection_scores : Mapping [ str , torch . Tensor ],
+391 ) -> torch . Tensor :
+392 """
+393 Merge function used to combine domains.
+394
+395 In the following, $D$ is the number of domains, $N$ the batch size, and $d$ the
+396 dimension of the Global Workspace.
+397
+398 This function needs to merge two kind of scores:
+399 * the selection scores $a\\in [0,1]^{D\\times N}$;
+400 * the precision scores $b \\in [0,1]^{D\\times N \\times d}$.
+401
+402 .. note::
+403 The precision score is obtained by predicting logits and using a softmax
+404
+405 We can obtain associated uncertainties to the scores by introducing a std
+406 variable and using bayesian integration:
+407
+408 $$a_k = \\frac{M_1}{\\sigma_k^2}$$
+409 where $M_1 = \\frac{1}{\\sum_{i=1}^D \\frac{1}{\\sigma_i^2}}$.
+410
+411 Similarly,
+412 $$b_k = \\frac{M_2}{\\mu_k^2}$$
+413 where $M_2 = \\frac{1}{\\sum_{i=1}^D \\frac{1}{\\mu_i^2}}$.
+414
+415 The we can sum the variances to obtain the final uncertainty (squared) $\\xi$:
+416 $$\\xi_k^2 = c_1 \\sigma_k^2 + c_2 \\mu_k^2$$
+417
+418 which, in terms of $a_k$ and $b_k$ yields:
+419 $$\\xi_k^2 = \\frac{c'_1}{a_k} + \\frac{c'_2}{b_k}$$
+420 where $c'_1 = c_1 \\cdot M_1$ and $c'_2 = c_2 \\cdot M_2$.
+421
+422 Finally, the finale combined coefficient is
+423 $$\\lambda_k = \\frac{M_3}{\\frac{c'_1}{a_k} + \\frac{c'_2}{b_k}}$$
+424 where
+425 $$M_3 = \\frac{1}{\\sum_{i=1}^D
+426 \\frac{1}{\\frac{c'_1}{a_i} + \\frac{c'_2}{b_i}}}$$
+427
+428 Args:
+429 x (`LatentsDomainGroupT`): the group of latent representation.
+430 selection_score (`Mapping[str, torch.Tensor]`): attention scores to
+431 use to encode the reprensetation.
+432 Returns:
+433 `torch.Tensor`: The merged representation.
+434 """
+435 scores : list [ torch . Tensor ] = []
+436 precisions : list [ torch . Tensor ] = []
+437 domains : list [ torch . Tensor ] = []
+438 for domain , score in selection_scores . items ():
+439 scores . append ( score )
+440 precisions . append ( self . get_precision ( domain , x [ domain ]))
+441 domains . append ( x [ domain ])
+442 combined_scores = compute_fusion_scores (
+443 torch . stack ( scores ) . unsqueeze ( - 1 ),
+444 torch . softmax (
+445 torch . tanh ( torch . stack ( precisions )) * self . precision_softmax_temp , dim = 0
+446 ),
+447 self . sensitivity_selection ,
+448 self . sensitivity_precision ,
+449 )
+450 return torch . tanh (
+451 torch . sum (
+452 combined_scores * torch . stack ( domains ),
+453 dim = 0 ,
+454 )
+455 )
+
+
+
+
+
+
+
+
+ def
+ get_n_layers (n_layers : int , hidden_dim : int ) -> list [ torch . nn . modules . module . Module ] :
+
+ View Source
+
+
+
+ 14 def get_n_layers ( n_layers : int , hidden_dim : int ) -> list [ nn . Module ]:
+15 """
+16 Makes a list of `n_layers` `nn.Linear` layers with `nn.ReLU`.
+17
+18 Args:
+19 n_layers (`int`): number of layers
+20 hidden_dim (`int`): size of the hidden dimension
+21
+22 Returns:
+23 `list[nn.Module]`: list of linear and relu layers.
+24 """
+25 layers : list [ nn . Module ] = []
+26 for _ in range ( n_layers ):
+27 layers . extend ([ nn . Linear ( hidden_dim , hidden_dim ), nn . ReLU ()])
+28 return layers
+
+
+
+ Makes a list of n_layers
nn.Linear
layers with nn.ReLU
.
+
+
Arguments:
+
+
+n_layers (int
): number of layers
+hidden_dim (int
): size of the hidden dimension
+
+
+
Returns:
+
+
+ list[nn.Module]
: list of linear and relu layers.
+
+
+
+
+
+
+
+
+
+ class
+ GWDecoder (torch.nn.modules.container.Sequential ):
+
+ View Source
+
+
+
+ 31 class GWDecoder ( nn . Sequential ):
+32 """A Decoder network for GWModules."""
+33
+34 def __init__ (
+35 self ,
+36 in_dim : int ,
+37 hidden_dim : int ,
+38 out_dim : int ,
+39 n_layers : int ,
+40 ):
+41 """
+42 Initializes the decoder.
+43
+44 Args:
+45 in_dim (`int`): input dimension
+46 hidden_dim (`int`): hidden dimension
+47 out_dim (`int`): output dimension
+48 n_layers (`int`): number of hidden layers. The total number of layers
+49 will be `n_layers` + 2 (one before, one after).
+50 """
+51
+52 self . in_dim = in_dim
+53 """input dimension"""
+54
+55 self . hidden_dim = hidden_dim
+56 """hidden dimension"""
+57
+58 self . out_dim = out_dim
+59 """output dimension"""
+60
+61 self . n_layers = n_layers
+62 """
+63 number of hidden layers. The total number of layers
+64 will be `n_layers` + 2 (one before, one after)."""
+65
+66 super () . __init__ (
+67 nn . Linear ( self . in_dim , self . hidden_dim ),
+68 nn . ReLU (),
+69 * get_n_layers ( n_layers , self . hidden_dim ),
+70 nn . Linear ( self . hidden_dim , self . out_dim ),
+71 )
+
+
+
+ A Decoder network for GWModules.
+
+
+
+
+
+
+
+ GWDecoder (in_dim : int , hidden_dim : int , out_dim : int , n_layers : int )
+
+ View Source
+
+
+
+
34 def __init__ (
+35 self ,
+36 in_dim : int ,
+37 hidden_dim : int ,
+38 out_dim : int ,
+39 n_layers : int ,
+40 ):
+41 """
+42 Initializes the decoder.
+43
+44 Args:
+45 in_dim (`int`): input dimension
+46 hidden_dim (`int`): hidden dimension
+47 out_dim (`int`): output dimension
+48 n_layers (`int`): number of hidden layers. The total number of layers
+49 will be `n_layers` + 2 (one before, one after).
+50 """
+51
+52 self . in_dim = in_dim
+53 """input dimension"""
+54
+55 self . hidden_dim = hidden_dim
+56 """hidden dimension"""
+57
+58 self . out_dim = out_dim
+59 """output dimension"""
+60
+61 self . n_layers = n_layers
+62 """
+63 number of hidden layers. The total number of layers
+64 will be `n_layers` + 2 (one before, one after)."""
+65
+66 super () . __init__ (
+67 nn . Linear ( self . in_dim , self . hidden_dim ),
+68 nn . ReLU (),
+69 * get_n_layers ( n_layers , self . hidden_dim ),
+70 nn . Linear ( self . hidden_dim , self . out_dim ),
+71 )
+
+
+
+
Initializes the decoder.
+
+
Arguments:
+
+
+in_dim (int
): input dimension
+hidden_dim (int
): hidden dimension
+out_dim (int
): output dimension
+n_layers (int
): number of hidden layers. The total number of layers
+will be n_layers
+ 2 (one before, one after).
+
+
+
+
+
+
+
+ in_dim
+
+
+
+
+
+
+
+
+
+
+
+ hidden_dim
+
+
+
+
+
+
+
+
+
+
+
+ out_dim
+
+
+
+
+
+
+
+
+
+
+
+ n_layers
+
+
+
+
+
+
number of hidden layers. The total number of layers
+ will be n_layers
+ 2 (one before, one after).
+
+
+
+
+
+
Inherited Members
+
+
torch.nn.modules.container.Sequential
+ pop
+ forward
+ append
+ insert
+ extend
+
+
+
torch.nn.modules.module.Module
+ dump_patches
+ training
+ call_super_init
+ register_buffer
+ register_parameter
+ add_module
+ register_module
+ get_submodule
+ get_parameter
+ get_buffer
+
+
+ apply
+ cuda
+ ipu
+ xpu
+ cpu
+ type
+ float
+ double
+ half
+ bfloat16
+ to_empty
+ to
+ register_full_backward_pre_hook
+ register_backward_hook
+ register_full_backward_hook
+ register_forward_pre_hook
+ register_forward_hook
+ register_state_dict_pre_hook
+ state_dict
+ register_load_state_dict_post_hook
+ load_state_dict
+ parameters
+ named_parameters
+ buffers
+ named_buffers
+ children
+ named_children
+ modules
+ named_modules
+ train
+ eval
+ requires_grad_
+ zero_grad
+ share_memory
+
+ compile
+
+
+
+
+
+
+
+
+
+
class
+
GWEncoder (GWDecoder ):
+
+ View Source
+
+
+
+ 74 class GWEncoder ( GWDecoder ):
+ 75 """
+ 76 An Encoder network used in GWModules.
+ 77
+ 78 This is similar to the decoder, but adds a tanh non-linearity at the end.
+ 79 """
+ 80
+ 81 def __init__ (
+ 82 self ,
+ 83 in_dim : int ,
+ 84 hidden_dim : int ,
+ 85 out_dim : int ,
+ 86 n_layers : int ,
+ 87 ):
+ 88 """
+ 89 Initializes the encoder.
+ 90
+ 91 Args:
+ 92 in_dim (`int`): input dimension
+ 93 hidden_dim (`int`): hidden dimension
+ 94 out_dim (`int`): output dimension
+ 95 n_layers (`int`): number of hidden layers. The total number of layers
+ 96 will be `n_layers` + 2 (one before, one after).
+ 97 """
+ 98 super () . __init__ ( in_dim , hidden_dim , out_dim , n_layers )
+ 99
+100 def forward ( self , input : torch . Tensor ) -> torch . Tensor :
+101 return super () . forward ( input )
+
+
+
+ An Encoder network used in GWModules.
+
+
This is similar to the decoder, but adds a tanh non-linearity at the end.
+
+
+
+
+
+
+
+ GWEncoder (in_dim : int , hidden_dim : int , out_dim : int , n_layers : int )
+
+ View Source
+
+
+
+
81 def __init__ (
+82 self ,
+83 in_dim : int ,
+84 hidden_dim : int ,
+85 out_dim : int ,
+86 n_layers : int ,
+87 ):
+88 """
+89 Initializes the encoder.
+90
+91 Args:
+92 in_dim (`int`): input dimension
+93 hidden_dim (`int`): hidden dimension
+94 out_dim (`int`): output dimension
+95 n_layers (`int`): number of hidden layers. The total number of layers
+96 will be `n_layers` + 2 (one before, one after).
+97 """
+98 super () . __init__ ( in_dim , hidden_dim , out_dim , n_layers )
+
+
+
+
Initializes the encoder.
+
+
Arguments:
+
+
+in_dim (int
): input dimension
+hidden_dim (int
): hidden dimension
+out_dim (int
): output dimension
+n_layers (int
): number of hidden layers. The total number of layers
+will be n_layers
+ 2 (one before, one after).
+
+
+
+
+
+
+
+
+
+ def
+ forward (self , input : torch . Tensor ) -> torch . Tensor :
+
+ View Source
+
+
+
+
100 def forward ( self , input : torch . Tensor ) -> torch . Tensor :
+101 return super () . forward ( input )
+
+
+
+
Define the computation performed at every call.
+
+
Should be overridden by all subclasses.
+
+
+
+
Although the recipe for forward pass needs to be defined within
+this function, one should call the Module
instance afterwards
+instead of this since the former takes care of running the
+registered hooks while the latter silently ignores them.
+
+
+
+
+
+
+
+
Inherited Members
+
+
+
torch.nn.modules.container.Sequential
+ pop
+ append
+ insert
+ extend
+
+
+
torch.nn.modules.module.Module
+ dump_patches
+ training
+ call_super_init
+ register_buffer
+ register_parameter
+ add_module
+ register_module
+ get_submodule
+ get_parameter
+ get_buffer
+
+
+ apply
+ cuda
+ ipu
+ xpu
+ cpu
+ type
+ float
+ double
+ half
+ bfloat16
+ to_empty
+ to
+ register_full_backward_pre_hook
+ register_backward_hook
+ register_full_backward_hook
+ register_forward_pre_hook
+ register_forward_hook
+ register_state_dict_pre_hook
+ state_dict
+ register_load_state_dict_post_hook
+ load_state_dict
+ parameters
+ named_parameters
+ buffers
+ named_buffers
+ children
+ named_children
+ modules
+ named_modules
+ train
+ eval
+ requires_grad_
+ zero_grad
+ share_memory
+
+ compile
+
+
+
+
+
+
+
+
+
+ class
+ GWEncoderLinear (torch.nn.modules.linear.Linear ):
+
+ View Source
+
+
+
+ 104 class GWEncoderLinear ( nn . Linear ):
+105 """A linear Encoder network used in GWModules."""
+106
+107 def forward ( self , input : torch . Tensor ) -> torch . Tensor :
+108 return torch . tanh ( super () . forward ( input ))
+
+
+
+ A linear Encoder network used in GWModules.
+
+
+
+
+
+
+
+ def
+ forward (self , input : torch . Tensor ) -> torch . Tensor :
+
+ View Source
+
+
+
+
107 def forward ( self , input : torch . Tensor ) -> torch . Tensor :
+108 return torch . tanh ( super () . forward ( input ))
+
+
+
+
Define the computation performed at every call.
+
+
Should be overridden by all subclasses.
+
+
+
+
Although the recipe for forward pass needs to be defined within
+this function, one should call the Module
instance afterwards
+instead of this since the former takes care of running the
+registered hooks while the latter silently ignores them.
+
+
+
+
+
+
+
+
Inherited Members
+
+
torch.nn.modules.linear.Linear
+ Linear
+ in_features
+ out_features
+ weight
+ reset_parameters
+
+
+
+
torch.nn.modules.module.Module
+ dump_patches
+ training
+ call_super_init
+ register_buffer
+ register_parameter
+ add_module
+ register_module
+ get_submodule
+ get_parameter
+ get_buffer
+
+
+ apply
+ cuda
+ ipu
+ xpu
+ cpu
+ type
+ float
+ double
+ half
+ bfloat16
+ to_empty
+ to
+ register_full_backward_pre_hook
+ register_backward_hook
+ register_full_backward_hook
+ register_forward_pre_hook
+ register_forward_hook
+ register_state_dict_pre_hook
+ state_dict
+ register_load_state_dict_post_hook
+ load_state_dict
+ parameters
+ named_parameters
+ buffers
+ named_buffers
+ children
+ named_children
+ modules
+ named_modules
+ train
+ eval
+ requires_grad_
+ zero_grad
+ share_memory
+ compile
+
+
+
+
+
+
+
+
+
+ class
+ GWModuleBase (torch.nn.modules.module.Module , abc.ABC ):
+
+ View Source
+
+
+
+ 111 class GWModuleBase ( nn . Module , ABC ):
+112 """
+113 Base class for GWModule.
+114
+115 GWModule handles encoding, decoding the unimodal representations
+116 using the `gw_encoders` and`gw_decoders`, and define
+117 some common operations in GW like cycles and translations.
+118
+119 This is an abstract class and should be implemented.
+120 For an implemented interface, see `GWModule`.
+121 """
+122
+123 def __init__ (
+124 self ,
+125 domain_mods : Mapping [ str , DomainModule ],
+126 workspace_dim : int ,
+127 * args ,
+128 ** kwargs ,
+129 ) -> None :
+130 """
+131 Initializes the GWModule.
+132
+133 Args:
+134 domain_modules (`Mapping[str, DomainModule]`): the domain modules.
+135 workspace_dim (`int`): dimension of the GW.
+136 """
+137 super () . __init__ ()
+138
+139 self . domain_mods = domain_mods
+140 """The unimodal domain modules."""
+141
+142 self . workspace_dim = workspace_dim
+143 """Dimension of the GW"""
+144
+145 @abstractmethod
+146 def fuse (
+147 self , x : LatentsDomainGroupT , selection_scores : Mapping [ str , torch . Tensor ]
+148 ) -> torch . Tensor :
+149 """
+150 Merge function used to combine domains.
+151
+152 Args:
+153 x (`LatentsDomainGroupT`): the group of latent representation.
+154 selection_score (`Mapping[str, torch.Tensor]`): attention scores to
+155 use to encode the reprensetation.
+156 Returns:
+157 `torch.Tensor`: The merged representation.
+158 """
+159 ...
+160
+161 @abstractmethod
+162 def encode ( self , x : LatentsDomainGroupT ) -> LatentsDomainGroupDT :
+163 """
+164 Encode the latent representation infos to the pre-fusion GW representation.
+165
+166 Args:
+167 x (`LatentsDomainGroupT`): the input domain representations
+168
+169 Returns:
+170 `LatentsDomainGroupT`: pre-fusion GW representations
+171 """
+172 ...
+173
+174 def encode_and_fuse (
+175 self , x : LatentsDomainGroupT , selection_module : SelectionBase
+176 ) -> torch . Tensor :
+177 """
+178 Encode the latent representation infos to the final GW representation.
+179 It combines the encode and fuse methods.
+180
+181 Args:
+182 x (`LatentsDomainGroupT`): the input domain representations
+183 selection_score (`Mapping[str, torch.Tensor]`): attention scores to
+184 use to encode the reprensetation.
+185
+186 Returns:
+187 `torch.Tensor`: The merged representation.
+188 """
+189 encodings = self . encode ( x )
+190 selection_scores = selection_module ( x , encodings )
+191 return self . fuse ( encodings , selection_scores )
+192
+193 @abstractmethod
+194 def decode (
+195 self , z : torch . Tensor , domains : Iterable [ str ] | None = None
+196 ) -> LatentsDomainGroupDT :
+197 """
+198 Decode the GW representation into given `domains`.
+199
+200 Args:
+201 z (`torch.Tensor`): the GW representation.
+202 domains (`Iterable[str]`): iterable of domains to decode.
+203
+204 Returns:
+205 `LatentsDomainGroupDT`: the decoded unimodal representations.
+206 """
+207 ...
+
+
+
+ Base class for GWModule.
+
+
GWModule handles encoding, decoding the unimodal representations
+using the gw_encoders
and gw_decoders
, and define
+some common operations in GW like cycles and translations.
+
+
This is an abstract class and should be implemented.
+For an implemented interface, see GWModule
.
+
+
+
+
+
+
+
+
123 def __init__ (
+124 self ,
+125 domain_mods : Mapping [ str , DomainModule ],
+126 workspace_dim : int ,
+127 * args ,
+128 ** kwargs ,
+129 ) -> None :
+130 """
+131 Initializes the GWModule.
+132
+133 Args:
+134 domain_modules (`Mapping[str, DomainModule]`): the domain modules.
+135 workspace_dim (`int`): dimension of the GW.
+136 """
+137 super () . __init__ ()
+138
+139 self . domain_mods = domain_mods
+140 """The unimodal domain modules."""
+141
+142 self . workspace_dim = workspace_dim
+143 """Dimension of the GW"""
+
+
+
+
Initializes the GWModule.
+
+
Arguments:
+
+
+domain_modules (Mapping[str, DomainModule]
): the domain modules.
+workspace_dim (int
): dimension of the GW.
+
+
+
+
+
+
+
+ domain_mods
+
+
+
+
+
+
The unimodal domain modules.
+
+
+
+
+
+
+ workspace_dim
+
+
+
+
+
+
+
+
+
+
+
+
+
@abstractmethod
+
+
def
+
fuse ( self , x : collections . abc . Mapping [ str , torch . Tensor ] , selection_scores : collections . abc . Mapping [ str , torch . Tensor ] ) -> torch . Tensor :
+
+
View Source
+
+
+
+
145 @abstractmethod
+146 def fuse (
+147 self , x : LatentsDomainGroupT , selection_scores : Mapping [ str , torch . Tensor ]
+148 ) -> torch . Tensor :
+149 """
+150 Merge function used to combine domains.
+151
+152 Args:
+153 x (`LatentsDomainGroupT`): the group of latent representation.
+154 selection_score (`Mapping[str, torch.Tensor]`): attention scores to
+155 use to encode the reprensetation.
+156 Returns:
+157 `torch.Tensor`: The merged representation.
+158 """
+159 ...
+
+
+
+
Merge function used to combine domains.
+
+
Arguments:
+
+
+x (LatentsDomainGroupT
): the group of latent representation.
+selection_score (Mapping[str, torch.Tensor]
): attention scores to
+use to encode the representation.
+
+
+
Returns:
+
+
+ torch.Tensor
: The merged representation.
+
+
+
+
+
+
+
+
+
@abstractmethod
+
+
def
+
encode ( self , x : collections . abc . Mapping [ str , torch . Tensor ] ) -> dict [ str , torch . Tensor ] :
+
+
View Source
+
+
+
+
161 @abstractmethod
+162 def encode ( self , x : LatentsDomainGroupT ) -> LatentsDomainGroupDT :
+163 """
+164 Encode the latent representation infos to the pre-fusion GW representation.
+165
+166 Args:
+167 x (`LatentsDomainGroupT`): the input domain representations
+168
+169 Returns:
+170 `LatentsDomainGroupT`: pre-fusion GW representations
+171 """
+172 ...
+
+
+
+
Encode the latent representation infos to the pre-fusion GW representation.
+
+
Arguments:
+
+
+x (LatentsDomainGroupT
): the input domain representations
+
+
+
Returns:
+
+
+ LatentsDomainGroupT
: pre-fusion GW representations
+
+
+
+
+
+
+
+
+
+
174 def encode_and_fuse (
+175 self , x : LatentsDomainGroupT , selection_module : SelectionBase
+176 ) -> torch . Tensor :
+177 """
+178 Encode the latent representation infos to the final GW representation.
+179 It combines the encode and fuse methods.
+180
+181 Args:
+182 x (`LatentsDomainGroupT`): the input domain representations
+183 selection_score (`Mapping[str, torch.Tensor]`): attention scores to
+184 use to encode the reprensetation.
+185
+186 Returns:
+187 `torch.Tensor`: The merged representation.
+188 """
+189 encodings = self . encode ( x )
+190 selection_scores = selection_module ( x , encodings )
+191 return self . fuse ( encodings , selection_scores )
+
+
+
+
Encode the latent representation infos to the final GW representation.
+It combines the encode and fuse methods.
+
+
Arguments:
+
+
+x (LatentsDomainGroupT
): the input domain representations
+selection_score (Mapping[str, torch.Tensor]
): attention scores to
+use to encode the representation.
+
+
+
Returns:
+
+
+ torch.Tensor
: The merged representation.
+
+
+
+
+
+
+
+
+
@abstractmethod
+
+
def
+
decode ( self , z : torch . Tensor , domains : collections . abc . Iterable [ str ] | None = None ) -> dict [ str , torch . Tensor ] :
+
+
View Source
+
+
+
+
193 @abstractmethod
+194 def decode (
+195 self , z : torch . Tensor , domains : Iterable [ str ] | None = None
+196 ) -> LatentsDomainGroupDT :
+197 """
+198 Decode the GW representation into given `domains`.
+199
+200 Args:
+201 z (`torch.Tensor`): the GW representation.
+202 domains (`Iterable[str]`): iterable of domains to decode.
+203
+204 Returns:
+205 `LatentsDomainGroupDT`: the decoded unimodal representations.
+206 """
+207 ...
+
+
+
+
Decode the GW representation into given domains
.
+
+
Arguments:
+
+
+z (torch.Tensor
): the GW representation.
+domains (Iterable[str]
): iterable of domains to decode.
+
+
+
Returns:
+
+
+ LatentsDomainGroupDT
: the decoded unimodal representations.
+
+
+
+
+
+
+
Inherited Members
+
+
torch.nn.modules.module.Module
+ dump_patches
+ training
+ call_super_init
+ forward
+ register_buffer
+ register_parameter
+ add_module
+ register_module
+ get_submodule
+ get_parameter
+ get_buffer
+
+
+ apply
+ cuda
+ ipu
+ xpu
+ cpu
+ type
+ float
+ double
+ half
+ bfloat16
+ to_empty
+ to
+ register_full_backward_pre_hook
+ register_backward_hook
+ register_full_backward_hook
+ register_forward_pre_hook
+ register_forward_hook
+ register_state_dict_pre_hook
+ state_dict
+ register_load_state_dict_post_hook
+ load_state_dict
+ parameters
+ named_parameters
+ buffers
+ named_buffers
+ children
+ named_children
+ modules
+ named_modules
+ train
+ eval
+ requires_grad_
+ zero_grad
+ share_memory
+
+ compile
+
+
+
+
+
+
+
+
+
+ 210 class GWModule ( GWModuleBase ):
+211 """GW nn.Module. Implements `GWModuleBase`."""
+212
+213 def __init__ (
+214 self ,
+215 domain_modules : Mapping [ str , DomainModule ],
+216 workspace_dim : int ,
+217 gw_encoders : Mapping [ str , nn . Module ],
+218 gw_decoders : Mapping [ str , nn . Module ],
+219 ) -> None :
+220 """
+221 Initializes the GWModule.
+222
+223 Args:
+224 domain_modules (`Mapping[str, DomainModule]`): the domain modules.
+225 workspace_dim (`int`): dimension of the GW.
+226 gw_encoders (`Mapping[str, torch.nn.Module]`): mapping for each domain
+227 name to a an torch.nn.Module class that encodes a
+228 unimodal latent representations into a GW representation (pre fusion).
+229 gw_decoders (`Mapping[str, torch.nn.Module]`): mapping for each domain
+230 name to a an torch.nn.Module class that decodes a
+231 GW representation to a unimodal latent representation.
+232 """
+233 super () . __init__ ( domain_modules , workspace_dim )
+234
+235 self . gw_encoders = nn . ModuleDict ( gw_encoders )
+236 """The module's encoders"""
+237
+238 self . gw_decoders = nn . ModuleDict ( gw_decoders )
+239 """The module's decoders"""
+240
+241 def fuse (
+242 self ,
+243 x : LatentsDomainGroupT ,
+244 selection_scores : Mapping [ str , torch . Tensor ],
+245 ) -> torch . Tensor :
+246 """
+247 Merge function used to combine domains.
+248
+249 Args:
+250 x (`LatentsDomainGroupT`): the group of latent representation.
+251 selection_score (`Mapping[str, torch.Tensor]`): attention scores to
+252 use to encode the reprensetation.
+253 Returns:
+254 `torch.Tensor`: The merged representation.
+255 """
+256 return torch . tanh (
+257 torch . sum (
+258 torch . stack (
+259 [
+260 selection_scores [ domain ] . unsqueeze ( 1 ) * x [ domain ]
+261 for domain in selection_scores
+262 ]
+263 ),
+264 dim = 0 ,
+265 )
+266 )
+267
+268 def encode ( self , x : LatentsDomainGroupT ) -> LatentsDomainGroupDT :
+269 """
+270 Encode the latent representation infos to the pre-fusion GW representation.
+271
+272 Args:
+273 x (`LatentsDomainGroupT`): the input domain representations.
+274
+275 Returns:
+276 `LatentsDomainGroupT`: pre-fusion representation
+277 """
+278 return {
+279 domain_name : self . gw_encoders [ domain_name ]( domain )
+280 for domain_name , domain in x . items ()
+281 }
+282
+283 def decode (
+284 self , z : torch . Tensor , domains : Iterable [ str ] | None = None
+285 ) -> LatentsDomainGroupDT :
+286 """
+287 Decodes a GW representation to multiple domains.
+288
+289 Args:
+290 z (`torch.Tensor`): the GW representation
+291 domains (`Iterable[str] | None`): the domains to decode to. Defaults to
+292 use keys in `gw_interfaces` (all domains).
+293 Returns:
+294 `LatentsDomainGroupDT`: decoded unimodal representation
+295 """
+296 return {
+297 domain : self . gw_decoders [ domain ]( z )
+298 for domain in domains or self . gw_decoders . keys ()
+299 }
+
+
+
+
+
+
+
+
+
+
+
GWModule ( domain_modules : collections . abc . Mapping [ str , shimmer.modules.domain.DomainModule ] , workspace_dim : int , gw_encoders : collections . abc . Mapping [ str , torch . nn . modules . module . Module ] , gw_decoders : collections . abc . Mapping [ str , torch . nn . modules . module . Module ] )
+
+
View Source
+
+
+
+
213 def __init__ (
+214 self ,
+215 domain_modules : Mapping [ str , DomainModule ],
+216 workspace_dim : int ,
+217 gw_encoders : Mapping [ str , nn . Module ],
+218 gw_decoders : Mapping [ str , nn . Module ],
+219 ) -> None :
+220 """
+221 Initializes the GWModule.
+222
+223 Args:
+224 domain_modules (`Mapping[str, DomainModule]`): the domain modules.
+225 workspace_dim (`int`): dimension of the GW.
+226 gw_encoders (`Mapping[str, torch.nn.Module]`): mapping for each domain
+227 name to a an torch.nn.Module class that encodes a
+228 unimodal latent representations into a GW representation (pre fusion).
+229 gw_decoders (`Mapping[str, torch.nn.Module]`): mapping for each domain
+230 name to a an torch.nn.Module class that decodes a
+231 GW representation to a unimodal latent representation.
+232 """
+233 super () . __init__ ( domain_modules , workspace_dim )
+234
+235 self . gw_encoders = nn . ModuleDict ( gw_encoders )
+236 """The module's encoders"""
+237
+238 self . gw_decoders = nn . ModuleDict ( gw_decoders )
+239 """The module's decoders"""
+
+
+
+
Initializes the GWModule.
+
+
Arguments:
+
+
+domain_modules (Mapping[str, DomainModule]
): the domain modules.
+workspace_dim (int
): dimension of the GW.
+gw_encoders (Mapping[str, torch.nn.Module]
): mapping for each domain
+name to a torch.nn.Module class that encodes a
+unimodal latent representations into a GW representation (pre fusion).
+gw_decoders (Mapping[str, torch.nn.Module]
): mapping for each domain
+name to a torch.nn.Module class that decodes a
+ GW representation to a unimodal latent representation.
+
+
+
+
+
+
+
+ gw_encoders
+
+
+
+
+
+
+
+
+
+
+
+ gw_decoders
+
+
+
+
+
+
+
+
+
+
+
+
+
+ def
+ fuse ( self , x : collections . abc . Mapping [ str , torch . Tensor ] , selection_scores : collections . abc . Mapping [ str , torch . Tensor ] ) -> torch . Tensor :
+
+ View Source
+
+
+
+
241 def fuse (
+242 self ,
+243 x : LatentsDomainGroupT ,
+244 selection_scores : Mapping [ str , torch . Tensor ],
+245 ) -> torch . Tensor :
+246 """
+247 Merge function used to combine domains.
+248
+249 Args:
+250 x (`LatentsDomainGroupT`): the group of latent representation.
+251 selection_score (`Mapping[str, torch.Tensor]`): attention scores to
+252 use to encode the reprensetation.
+253 Returns:
+254 `torch.Tensor`: The merged representation.
+255 """
+256 return torch . tanh (
+257 torch . sum (
+258 torch . stack (
+259 [
+260 selection_scores [ domain ] . unsqueeze ( 1 ) * x [ domain ]
+261 for domain in selection_scores
+262 ]
+263 ),
+264 dim = 0 ,
+265 )
+266 )
+
+
+
+
Merge function used to combine domains.
+
+
Arguments:
+
+
+x (LatentsDomainGroupT
): the group of latent representation.
+selection_score (Mapping[str, torch.Tensor]
): attention scores to
+use to encode the representation.
+
+
+
Returns:
+
+
+ torch.Tensor
: The merged representation.
+
+
+
+
+
+
+
+
+
+ def
+ encode ( self , x : collections . abc . Mapping [ str , torch . Tensor ] ) -> dict [ str , torch . Tensor ] :
+
+ View Source
+
+
+
+
268 def encode ( self , x : LatentsDomainGroupT ) -> LatentsDomainGroupDT :
+269 """
+270 Encode the latent representation infos to the pre-fusion GW representation.
+271
+272 Args:
+273 x (`LatentsDomainGroupT`): the input domain representations.
+274
+275 Returns:
+276 `LatentsDomainGroupT`: pre-fusion representation
+277 """
+278 return {
+279 domain_name : self . gw_encoders [ domain_name ]( domain )
+280 for domain_name , domain in x . items ()
+281 }
+
+
+
+
Encode the latent representation infos to the pre-fusion GW representation.
+
+
Arguments:
+
+
+x (LatentsDomainGroupT
): the input domain representations.
+
+
+
Returns:
+
+
+ LatentsDomainGroupT
: pre-fusion representation
+
+
+
+
+
+
+
+
+
+ def
+ decode ( self , z : torch . Tensor , domains : collections . abc . Iterable [ str ] | None = None ) -> dict [ str , torch . Tensor ] :
+
+ View Source
+
+
+
+
283 def decode (
+284 self , z : torch . Tensor , domains : Iterable [ str ] | None = None
+285 ) -> LatentsDomainGroupDT :
+286 """
+287 Decodes a GW representation to multiple domains.
+288
+289 Args:
+290 z (`torch.Tensor`): the GW representation
+291 domains (`Iterable[str] | None`): the domains to decode to. Defaults to
+292 use keys in `gw_interfaces` (all domains).
+293 Returns:
+294 `LatentsDomainGroupDT`: decoded unimodal representation
+295 """
+296 return {
+297 domain : self . gw_decoders [ domain ]( z )
+298 for domain in domains or self . gw_decoders . keys ()
+299 }
+
+
+
+
Decodes a GW representation to multiple domains.
+
+
Arguments:
+
+
+z (torch.Tensor
): the GW representation
+domains (Iterable[str] | None
): the domains to decode to. Defaults to
+use keys in gw_interfaces
(all domains).
+
+
+
Returns:
+
+
+ LatentsDomainGroupDT
: decoded unimodal representation
+
+
+
+
+
+
+
Inherited Members
+
+
+
torch.nn.modules.module.Module
+ dump_patches
+ training
+ call_super_init
+ forward
+ register_buffer
+ register_parameter
+ add_module
+ register_module
+ get_submodule
+ get_parameter
+ get_buffer
+
+
+ apply
+ cuda
+ ipu
+ xpu
+ cpu
+ type
+ float
+ double
+ half
+ bfloat16
+ to_empty
+ to
+ register_full_backward_pre_hook
+ register_backward_hook
+ register_full_backward_hook
+ register_forward_pre_hook
+ register_forward_hook
+ register_state_dict_pre_hook
+ state_dict
+ register_load_state_dict_post_hook
+ load_state_dict
+ parameters
+ named_parameters
+ buffers
+ named_buffers
+ children
+ named_children
+ modules
+ named_modules
+ train
+ eval
+ requires_grad_
+ zero_grad
+ share_memory
+
+ compile
+
+
+
+
+
+
+
+
+
+ def
+ compute_fusion_scores ( score_1 : torch . Tensor , score_2 : torch . Tensor , sensitivity_1 : float = 1.0 , sensitivity_2 : float = 1.0 , eps : float = 1e-06 ) -> torch . Tensor :
+
+ View Source
+
+
+
+ 302 def compute_fusion_scores (
+303 score_1 : torch . Tensor ,
+304 score_2 : torch . Tensor ,
+305 sensitivity_1 : float = 1.0 ,
+306 sensitivity_2 : float = 1.0 ,
+307 eps : float = 1e-6 ,
+308 ) -> torch . Tensor :
+309 """
+310 Combine precision scores using std summation in quadrature
+311
+312 The two scores should have the same dimension.
+313
+314 Args:
+315 score_1 (`torch.Tensor`): First scores.
+316 score_2 (`torch.Tensor`): Second scores.
+317 sensitivity_1 (`float`): sensitivity for the first score
+318 sensitivity_2 (`float`): sensitivity for the second score
+319 eps (`float`): a value added to avoid numerical unstability.
+320
+321 Returns:
+322 `torch.Tensor`: the combined scores
+323 """
+324 total_uncertainty = sensitivity_1 / ( eps + score_1 ) + sensitivity_2 / (
+325 eps + score_2
+326 )
+327 final_scores = 1 / ( eps + total_uncertainty )
+328 return final_scores / final_scores . sum ( dim = 0 , keepdim = True )
+
+
+
+ Combine precision scores using std summation in quadrature
+
+
The two scores should have the same dimension.
+
+
Arguments:
+
+
+score_1 (torch.Tensor
): First scores.
+score_2 (torch.Tensor
): Second scores.
+sensitivity_1 (float
): sensitivity for the first score
+sensitivity_2 (float
): sensitivity for the second score
+eps (float
): a value added to avoid numerical instability.
+
+
+
Returns:
+
+
+ torch.Tensor
: the combined scores
+
+
+
+
+
+
+
+
+
+
class
+
GWModuleBayesian (GWModule ):
+
+ View Source
+
+
+
+ 331 class GWModuleBayesian ( GWModule ):
+332 """`GWModule` with a Bayesian based uncertainty prediction."""
+333
+334 def __init__ (
+335 self ,
+336 domain_modules : Mapping [ str , DomainModule ],
+337 workspace_dim : int ,
+338 gw_encoders : Mapping [ str , nn . Module ],
+339 gw_decoders : Mapping [ str , nn . Module ],
+340 sensitivity_selection : float = 1 ,
+341 sensitivity_precision : float = 1 ,
+342 precision_softmax_temp : float = 0.01 ,
+343 ) -> None :
+344 """
+345 Initializes the GWModuleBayesian.
+346
+347 Args:
+348 domain_modules (`Mapping[str, DomainModule]`): the domain modules.
+349 workspace_dim (`int`): dimension of the GW.
+350 gw_encoders (`Mapping[str, torch.nn.Module]`): mapping for each domain
+351 name to a an torch.nn.Module class that encodes a
+352 unimodal latent representations into a GW representation (pre fusion).
+353 gw_decoders (`Mapping[str, torch.nn.Module]`): mapping for each domain
+354 name to a an torch.nn.Module class that decodes a
+355 GW representation to a unimodal latent representation.
+356 sensitivity_selection (`float`): sensivity coef $c'_1$
+357 sensitivity_precision (`float`): sensitivity coef $c'_2$
+358 precision_softmax_temp (`float`): temperature to use in softmax of
+359 precision
+360 """
+361 super () . __init__ ( domain_modules , workspace_dim , gw_encoders , gw_decoders )
+362
+363 self . precisions = cast (
+364 dict [ str , torch . Tensor ],
+365 nn . ParameterDict (
+366 { domain : torch . randn ( workspace_dim ) for domain in gw_encoders }
+367 ),
+368 )
+369 """Precision at the neuron level for every domain."""
+370
+371 self . sensitivity_selection = sensitivity_selection
+372 self . sensitivity_precision = sensitivity_precision
+373 self . precision_softmax_temp = precision_softmax_temp
+374
+375 def get_precision ( self , domain : str , x : torch . Tensor ) -> torch . Tensor :
+376 """
+377 Get the precision vector of given domain and batch
+378
+379 Args:
+380 domain (`str`):
+381 x (`torch.Tensor`): batch of inputs
+382
+383 Returns:
+384 `torch.Tensor`: batch of precision
+385 """
+386 return self . precisions [ domain ] . unsqueeze ( 0 ) . expand ( x . size ( 0 ), - 1 )
+387
+388 def fuse (
+389 self ,
+390 x : LatentsDomainGroupT ,
+391 selection_scores : Mapping [ str , torch . Tensor ],
+392 ) -> torch . Tensor :
+393 """
+394 Merge function used to combine domains.
+395
+396 In the following, $D$ is the number of domains, $N$ the batch size, and $d$ the
+397 dimension of the Global Workspace.
+398
+399 This function needs to merge two kind of scores:
+400 * the selection scores $a\\in [0,1]^{D\\times N}$;
+401 * the precision scores $b \\in [0,1]^{D\\times N \\times d}$.
+402
+403 .. note::
+404 The precision score is obtained by predicting logits and using a softmax
+405
+406 We can obtain associated uncertainties to the scores by introducing a std
+407 variable and using bayesian integration:
+408
+409 $$a_k = \\frac{M_1}{\\sigma_k^2}$$
+410 where $M_1 = \\frac{1}{\\sum_{i=1}^D \\frac{1}{\\sigma_i^2}}$.
+411
+412 Similarly,
+413 $$b_k = \\frac{M_2}{\\mu_k^2}$$
+414 where $M_2 = \\frac{1}{\\sum_{i=1}^D \\frac{1}{\\mu_i^2}}$.
+415
+416 The we can sum the variances to obtain the final uncertainty (squared) $\\xi$:
+417 $$\\xi_k^2 = c_1 \\sigma_k^2 + c_2 \\mu_k^2$$
+418
+419 which, in terms of $a_k$ and $b_k$ yields:
+420 $$\\xi_k^2 = \\frac{c'_1}{a_k} + \\frac{c'_2}{b_k}$$
+421 where $c'_1 = c_1 \\cdot M_1$ and $c'_2 = c_2 \\cdot M_2$.
+422
+423 Finally, the finale combined coefficient is
+424 $$\\lambda_k = \\frac{M_3}{\\frac{c'_1}{a_k} + \\frac{c'_2}{b_k}}$$
+425 where
+426 $$M_3 = \\frac{1}{\\sum_{i=1}^D
+427 \\frac{1}{\\frac{c'_1}{a_i} + \\frac{c'_2}{b_i}}}$$
+428
+429 Args:
+430 x (`LatentsDomainGroupT`): the group of latent representation.
+431 selection_score (`Mapping[str, torch.Tensor]`): attention scores to
+432 use to encode the reprensetation.
+433 Returns:
+434 `torch.Tensor`: The merged representation.
+435 """
+436 scores : list [ torch . Tensor ] = []
+437 precisions : list [ torch . Tensor ] = []
+438 domains : list [ torch . Tensor ] = []
+439 for domain , score in selection_scores . items ():
+440 scores . append ( score )
+441 precisions . append ( self . get_precision ( domain , x [ domain ]))
+442 domains . append ( x [ domain ])
+443 combined_scores = compute_fusion_scores (
+444 torch . stack ( scores ) . unsqueeze ( - 1 ),
+445 torch . softmax (
+446 torch . tanh ( torch . stack ( precisions )) * self . precision_softmax_temp , dim = 0
+447 ),
+448 self . sensitivity_selection ,
+449 self . sensitivity_precision ,
+450 )
+451 return torch . tanh (
+452 torch . sum (
+453 combined_scores * torch . stack ( domains ),
+454 dim = 0 ,
+455 )
+456 )
+
+
+
+ GWModule
with a Bayesian based uncertainty prediction.
+
+
+
+
+
+
+
+
GWModuleBayesian ( domain_modules : collections . abc . Mapping [ str , shimmer.modules.domain.DomainModule ] , workspace_dim : int , gw_encoders : collections . abc . Mapping [ str , torch . nn . modules . module . Module ] , gw_decoders : collections . abc . Mapping [ str , torch . nn . modules . module . Module ] , sensitivity_selection : float = 1 , sensitivity_precision : float = 1 , precision_softmax_temp : float = 0.01 )
+
+
View Source
+
+
+
+
334 def __init__ (
+335 self ,
+336 domain_modules : Mapping [ str , DomainModule ],
+337 workspace_dim : int ,
+338 gw_encoders : Mapping [ str , nn . Module ],
+339 gw_decoders : Mapping [ str , nn . Module ],
+340 sensitivity_selection : float = 1 ,
+341 sensitivity_precision : float = 1 ,
+342 precision_softmax_temp : float = 0.01 ,
+343 ) -> None :
+344 """
+345 Initializes the GWModuleBayesian.
+346
+347 Args:
+348 domain_modules (`Mapping[str, DomainModule]`): the domain modules.
+349 workspace_dim (`int`): dimension of the GW.
+350 gw_encoders (`Mapping[str, torch.nn.Module]`): mapping for each domain
+351 name to a an torch.nn.Module class that encodes a
+352 unimodal latent representations into a GW representation (pre fusion).
+353 gw_decoders (`Mapping[str, torch.nn.Module]`): mapping for each domain
+354 name to a an torch.nn.Module class that decodes a
+355 GW representation to a unimodal latent representation.
+356 sensitivity_selection (`float`): sensivity coef $c'_1$
+357 sensitivity_precision (`float`): sensitivity coef $c'_2$
+358 precision_softmax_temp (`float`): temperature to use in softmax of
+359 precision
+360 """
+361 super () . __init__ ( domain_modules , workspace_dim , gw_encoders , gw_decoders )
+362
+363 self . precisions = cast (
+364 dict [ str , torch . Tensor ],
+365 nn . ParameterDict (
+366 { domain : torch . randn ( workspace_dim ) for domain in gw_encoders }
+367 ),
+368 )
+369 """Precision at the neuron level for every domain."""
+370
+371 self . sensitivity_selection = sensitivity_selection
+372 self . sensitivity_precision = sensitivity_precision
+373 self . precision_softmax_temp = precision_softmax_temp
+
+
+
+
Initializes the GWModuleBayesian.
+
+
Arguments:
+
+
+domain_modules (Mapping[str, DomainModule]
): the domain modules.
+workspace_dim (int
): dimension of the GW.
+gw_encoders (Mapping[str, torch.nn.Module]
): mapping for each domain
+name to a torch.nn.Module class that encodes a
+unimodal latent representations into a GW representation (pre fusion).
+gw_decoders (Mapping[str, torch.nn.Module]
): mapping for each domain
+name to a torch.nn.Module class that decodes a
+ GW representation to a unimodal latent representation.
+sensitivity_selection (float
): sensitivity coef $c'_1$
+sensitivity_precision (float
): sensitivity coef $c'_2$
+precision_softmax_temp (float
): temperature to use in softmax of
+precision
+
+
+
+
+
+
+
+ precisions
+
+
+
+
+
+
Precision at the neuron level for every domain.
+
+
+
+
+
+
+ sensitivity_selection
+
+
+
+
+
+
+
+
+
+
+ sensitivity_precision
+
+
+
+
+
+
+
+
+
+
+ precision_softmax_temp
+
+
+
+
+
+
+
+
+
+
+
+
+ def
+ get_precision (self , domain : str , x : torch . Tensor ) -> torch . Tensor :
+
+ View Source
+
+
+
+
375 def get_precision ( self , domain : str , x : torch . Tensor ) -> torch . Tensor :
+376 """
+377 Get the precision vector of given domain and batch
+378
+379 Args:
+380 domain (`str`):
+381 x (`torch.Tensor`): batch of inputs
+382
+383 Returns:
+384 `torch.Tensor`: batch of precision
+385 """
+386 return self . precisions [ domain ] . unsqueeze ( 0 ) . expand ( x . size ( 0 ), - 1 )
+
+
+
+
Get the precision vector of given domain and batch
+
+
Arguments:
+
+
+domain (str
):
+x (torch.Tensor
): batch of inputs
+
+
+
Returns:
+
+
+ torch.Tensor
: batch of precision
+
+
+
+
+
+
+
+
+
+ def
+ fuse ( self , x : collections . abc . Mapping [ str , torch . Tensor ] , selection_scores : collections . abc . Mapping [ str , torch . Tensor ] ) -> torch . Tensor :
+
+ View Source
+
+
+
+
388 def fuse (
+389 self ,
+390 x : LatentsDomainGroupT ,
+391 selection_scores : Mapping [ str , torch . Tensor ],
+392 ) -> torch . Tensor :
+393 """
+394 Merge function used to combine domains.
+395
+396 In the following, $D$ is the number of domains, $N$ the batch size, and $d$ the
+397 dimension of the Global Workspace.
+398
+399 This function needs to merge two kind of scores:
+400 * the selection scores $a\\in [0,1]^{D\\times N}$;
+401 * the precision scores $b \\in [0,1]^{D\\times N \\times d}$.
+402
+403 .. note::
+404 The precision score is obtained by predicting logits and using a softmax
+405
+406 We can obtain associated uncertainties to the scores by introducing a std
+407 variable and using bayesian integration:
+408
+409 $$a_k = \\frac{M_1}{\\sigma_k^2}$$
+410 where $M_1 = \\frac{1}{\\sum_{i=1}^D \\frac{1}{\\sigma_i^2}}$.
+411
+412 Similarly,
+413 $$b_k = \\frac{M_2}{\\mu_k^2}$$
+414 where $M_2 = \\frac{1}{\\sum_{i=1}^D \\frac{1}{\\mu_i^2}}$.
+415
+416 The we can sum the variances to obtain the final uncertainty (squared) $\\xi$:
+417 $$\\xi_k^2 = c_1 \\sigma_k^2 + c_2 \\mu_k^2$$
+418
+419 which, in terms of $a_k$ and $b_k$ yields:
+420 $$\\xi_k^2 = \\frac{c'_1}{a_k} + \\frac{c'_2}{b_k}$$
+421 where $c'_1 = c_1 \\cdot M_1$ and $c'_2 = c_2 \\cdot M_2$.
+422
+423 Finally, the finale combined coefficient is
+424 $$\\lambda_k = \\frac{M_3}{\\frac{c'_1}{a_k} + \\frac{c'_2}{b_k}}$$
+425 where
+426 $$M_3 = \\frac{1}{\\sum_{i=1}^D
+427 \\frac{1}{\\frac{c'_1}{a_i} + \\frac{c'_2}{b_i}}}$$
+428
+429 Args:
+430 x (`LatentsDomainGroupT`): the group of latent representation.
+431 selection_score (`Mapping[str, torch.Tensor]`): attention scores to
+432 use to encode the reprensetation.
+433 Returns:
+434 `torch.Tensor`: The merged representation.
+435 """
+436 scores : list [ torch . Tensor ] = []
+437 precisions : list [ torch . Tensor ] = []
+438 domains : list [ torch . Tensor ] = []
+439 for domain , score in selection_scores . items ():
+440 scores . append ( score )
+441 precisions . append ( self . get_precision ( domain , x [ domain ]))
+442 domains . append ( x [ domain ])
+443 combined_scores = compute_fusion_scores (
+444 torch . stack ( scores ) . unsqueeze ( - 1 ),
+445 torch . softmax (
+446 torch . tanh ( torch . stack ( precisions )) * self . precision_softmax_temp , dim = 0
+447 ),
+448 self . sensitivity_selection ,
+449 self . sensitivity_precision ,
+450 )
+451 return torch . tanh (
+452 torch . sum (
+453 combined_scores * torch . stack ( domains ),
+454 dim = 0 ,
+455 )
+456 )
+
+
+
+
Merge function used to combine domains.
+
+
In the following, $D$ is the number of domains, $N$ the batch size, and $d$ the
+dimension of the Global Workspace.
+
+
This function needs to merge two kinds of scores:
+
+
+the selection scores $a\in [0,1]^{D\times N}$;
+the precision scores $b \in [0,1]^{D\times N \times d}$.
+
+
+
+
+
The precision score is obtained by predicting logits and using a softmax
+
+
+
+
We can obtain associated uncertainties to the scores by introducing a std
+variable and using bayesian integration:
+
+
$$a_k = \frac{M_1}{\sigma_k^2}$$
+where $M_1 = \frac{1}{\sum_{i=1}^D \frac{1}{\sigma_i^2}}$.
+
+
Similarly,
+$$b_k = \frac{M_2}{\mu_k^2}$$
+where $M_2 = \frac{1}{\sum_{i=1}^D \frac{1}{\mu_i^2}}$.
+
+
Then we can sum the variances to obtain the final uncertainty (squared) $\xi$:
+$$\xi_k^2 = c_1 \sigma_k^2 + c_2 \mu_k^2$$
+
+
which, in terms of $a_k$ and $b_k$ yields:
+$$\xi_k^2 = \frac{c'_1}{a_k} + \frac{c'_2}{b_k}$$
+where $c'_1 = c_1 \cdot M_1$ and $c'_2 = c_2 \cdot M_2$.
+
+
Finally, the final combined coefficient is
+$$\lambda_k = \frac{M_3}{\frac{c'_1}{a_k} + \frac{c'_2}{b_k}}$$
+where
+$$M_3 = \frac{1}{\sum_{i=1}^D
+ \frac{1}{\frac{c'_1}{a_i} + \frac{c'_2}{b_i}}}$$
+
+
Arguments:
+
+
+x (LatentsDomainGroupT
): the group of latent representation.
+selection_score (Mapping[str, torch.Tensor]
): attention scores to
+use to encode the representation.
+
+
+
Returns:
+
+
+ torch.Tensor
: The merged representation.
+
+
+
+
+
+
+
Inherited Members
+
+
+
+
torch.nn.modules.module.Module
+ dump_patches
+ training
+ call_super_init
+ forward
+ register_buffer
+ register_parameter
+ add_module
+ register_module
+ get_submodule
+ get_parameter
+ get_buffer
+
+
+ apply
+ cuda
+ ipu
+ xpu
+ cpu
+ type
+ float
+ double
+ half
+ bfloat16
+ to_empty
+ to
+ register_full_backward_pre_hook
+ register_backward_hook
+ register_full_backward_hook
+ register_forward_pre_hook
+ register_forward_hook
+ register_state_dict_pre_hook
+ state_dict
+ register_load_state_dict_post_hook
+ load_state_dict
+ parameters
+ named_parameters
+ buffers
+ named_buffers
+ children
+ named_children
+ modules
+ named_modules
+ train
+ eval
+ requires_grad_
+ zero_grad
+ share_memory
+
+ compile
+
+
+
+
+
+
+
+
\ No newline at end of file
diff --git a/docs/api/v0.5.1/shimmer/modules/losses.html b/docs/api/v0.5.1/shimmer/modules/losses.html
new file mode 100644
index 00000000..4d306516
--- /dev/null
+++ b/docs/api/v0.5.1/shimmer/modules/losses.html
@@ -0,0 +1,3888 @@
+
+
+
+
+
+
+ shimmer.modules.losses API documentation
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+shimmer.modules.losses
+
+
+
+
+ View Source
+
+ 1 from abc import ABC , abstractmethod
+ 2 from collections.abc import Generator , Mapping
+ 3 from itertools import product
+ 4 from typing import TypedDict
+ 5
+ 6 import torch
+ 7
+ 8 from shimmer.modules.contrastive_loss import ContrastiveLossType
+ 9 from shimmer.modules.domain import DomainModule , LossOutput
+ 10 from shimmer.modules.gw_module import (
+ 11 GWModule ,
+ 12 GWModuleBase ,
+ 13 GWModuleBayesian ,
+ 14 )
+ 15 from shimmer.modules.selection import SelectionBase
+ 16 from shimmer.types import LatentsDomainGroupsT , ModelModeT
+ 17
+ 18
+ 19 class GWLossesBase ( torch . nn . Module , ABC ):
+ 20 """
+ 21 Base Abstract Class for Global Workspace (GW) losses. This module is used
+ 22 to compute the different losses of the GW (typically translation, cycle,
+ 23 demi-cycle, contrastive losses).
+ 24 """
+ 25
+ 26 @abstractmethod
+ 27 def step (
+ 28 self ,
+ 29 domain_latents : LatentsDomainGroupsT ,
+ 30 mode : ModelModeT ,
+ 31 ) -> LossOutput :
+ 32 """
+ 33 Computes the losses.
+ 34
+ 35 Args:
+ 36 domain_latents (`LatentsDomainGroupsT`): All latent groups
+ 37 mode (`Literal["train", "val", "test", "val/ood", "test/ood"]`): model mode
+ 38 Returns:
+ 39 `LossOutput`: the losses
+ 40 """
+ 41 ...
+ 42
+ 43
+ 44 def demi_cycle_loss (
+ 45 gw_mod : GWModuleBase ,
+ 46 selection_mod : SelectionBase ,
+ 47 domain_mods : Mapping [ str , DomainModule ],
+ 48 latent_domains : LatentsDomainGroupsT ,
+ 49 ) -> dict [ str , torch . Tensor ]:
+ 50 """
+ 51 Computes the demi-cycle loss.
+ 52
+ 53 This returns multiple metrics:
+ 54 * `demi_cycle_{domain_name}` with the demi-cycle of a particular domain;
+ 55 * `demi_cycle_{domain_name}_{metric}` with additional metrics provided by
+ 56 the domain_mod's `compute_dcy_loss` output;
+ 57 * `demi_cycles` with the average value of all `demi_cycle_{domain_name}` values.
+ 58
+ 59 Args:
+ 60 gw_mod (`shimmer.modules.gw_module.GWModuleBase`): The GWModule to use
+ 61 selection_mod (`shimmer.modules.selection.SelectionBase`): Selection mod to use
+ 62 domain_mods (`Mapping[str, DomainModule]`): the domain modules
+ 63 latent_domains (`shimmer.types.LatentsDomainGroupsT`): the latent unimodal
+ 64 groups
+ 65
+ 66 Returns:
+ 67 `dict[str, torch.Tensor]`: a dict of metrics.
+ 68 """
+ 69 losses : dict [ str , torch . Tensor ] = {}
+ 70 metrics : dict [ str , torch . Tensor ] = {}
+ 71 for domains , latents in latent_domains . items ():
+ 72 if len ( domains ) > 1 :
+ 73 continue
+ 74 domain_name = next ( iter ( domains ))
+ 75
+ 76 domain_mod = domain_mods [ domain_name ]
+ 77 x_recons = gw_mod . decode (
+ 78 gw_mod . encode_and_fuse ( latents , selection_mod ), domains = { domain_name }
+ 79 )[ domain_name ]
+ 80 loss_output = domain_mod . compute_dcy_loss ( x_recons , latents [ domain_name ])
+ 81 losses [ f "demi_cycle_ { domain_name } " ] = loss_output . loss
+ 82 metrics . update (
+ 83 { f "demi_cycle_ { domain_name } _ { k } " : v for k , v in loss_output . metrics . items ()}
+ 84 )
+ 85 losses [ "demi_cycles" ] = torch . stack ( list ( losses . values ()), dim = 0 ) . mean ()
+ 86 losses . update ( metrics )
+ 87 return losses
+ 88
+ 89
+ 90 def cycle_loss (
+ 91 gw_mod : GWModuleBase ,
+ 92 selection_mod : SelectionBase ,
+ 93 domain_mods : Mapping [ str , DomainModule ],
+ 94 latent_domains : LatentsDomainGroupsT ,
+ 95 ) -> dict [ str , torch . Tensor ]:
+ 96 """
+ 97 Computes the cycle loss.
+ 98
+ 99 This returns multiple metrics:
+100 * `cycle_{domain_source}_through_{domain_target}` with the cycle of
+101 a particular domain;
+102 * `cycle_{domain_source}_through_{domain_target}_{metric}` with additional
+103 metrics provided by the domain_mod's `compute_cy_loss` output;
+104 * `cycles` with the average value of all
+105 `cycle_{domain_source}_through_{domain_target}` values.
+106
+107 Args:
+108 gw_mod (`GWModuleBase`): The GWModule to use
+109 selection_mod (`shimmer.modules.selection.SelectionBase`): Selection mod to use
+110 domain_mods (`Mapping[str, DomainModule]`): the domain modules
+111 latent_domains (`LatentsDomainGroupsT`): the latent unimodal groups
+112
+113 Returns:
+114 `dict[str, torch.Tensor]`: a dict of metrics.
+115 """
+116 losses : dict [ str , torch . Tensor ] = {}
+117 metrics : dict [ str , torch . Tensor ] = {}
+118 for domains_source , latents_source in latent_domains . items ():
+119 if len ( domains_source ) > 1 :
+120 continue
+121 domain_name_source = list ( domains_source )[ 0 ]
+122
+123 domain_mod = domain_mods [ domain_name_source ]
+124 z = gw_mod . encode_and_fuse ( latents_source , selection_mod )
+125 for domain_name_target in domain_mods :
+126 if domain_name_target == domain_name_source :
+127 continue
+128
+129 x_pred = gw_mod . decode ( z , domains = { domain_name_target })
+130
+131 x_recons = gw_mod . decode (
+132 gw_mod . encode_and_fuse ( x_pred , selection_mod ),
+133 domains = { domain_name_source },
+134 )
+135
+136 loss_name = f " { domain_name_source } _through_ { domain_name_target } "
+137 loss_output = domain_mod . compute_cy_loss (
+138 x_recons [ domain_name_source ],
+139 latents_source [ domain_name_source ],
+140 )
+141 metrics . update (
+142 { f "cycle_ { loss_name } _ { k } " : v for k , v in loss_output . metrics . items ()}
+143 )
+144 losses [ f "cycle_ { loss_name } " ] = loss_output . loss
+145 losses [ "cycles" ] = torch . stack ( list ( losses . values ()), dim = 0 ) . mean ()
+146 losses . update ( metrics )
+147 return losses
+148
+149
+150 def translation_loss (
+151 gw_mod : GWModuleBase ,
+152 selection_mod : SelectionBase ,
+153 domain_mods : Mapping [ str , DomainModule ],
+154 latent_domains : LatentsDomainGroupsT ,
+155 ) -> dict [ str , torch . Tensor ]:
+156 """
+157 Computes the translation loss.
+158
+159 This returns multiple metrics:
+160 * `translation_{domain_source}_to_{domain_target}` with the translation
+161 from a domain source to a domain target;
+162 * `translation_{domain_source}_to_{domain_target}_{metric}` with
+163 additional metrics provided by the domain_mod's
+164 `compute_tr_loss` output;
+165 * `translations` with the average value of all
+166 `translation_{domain_source}_to_{domain_target}` values.
+167
+168 Args:
+169 gw_mod (`GWModuleBase`): The GWModule to use
+170 domain_mods (`Mapping[str, DomainModule]`): the domain modules
+171 latent_domains (`LatentsDomainGroupsT`): the latent unimodal groups
+172
+173 Returns:
+174 `dict[str, torch.Tensor]`: a dict of metrics.
+175 """
+176 losses : dict [ str , torch . Tensor ] = {}
+177 metrics : dict [ str , torch . Tensor ] = {}
+178 for domains , latents in latent_domains . items ():
+179 if len ( domains ) < 2 :
+180 continue
+181 for domain_name_target in domains :
+182 domain_sources = {
+183 domain : latents [ domain ]
+184 for domain in domains
+185 if domain != domain_name_target
+186 }
+187
+188 z = gw_mod . encode_and_fuse ( domain_sources , selection_mod )
+189 mod = domain_mods [ domain_name_target ]
+190
+191 domain_source_names = "/" . join ( domain_sources . keys ())
+192 loss_name = f " { domain_source_names } _to_ { domain_name_target } "
+193 if loss_name in losses :
+194 raise ValueError ( f " { loss_name } is already computed." )
+195
+196 prediction = gw_mod . decode ( z , domains = { domain_name_target })[
+197 domain_name_target
+198 ]
+199 loss_output = mod . compute_tr_loss (
+200 prediction ,
+201 latents [ domain_name_target ],
+202 )
+203 losses [ f "translation_ { loss_name } " ] = loss_output . loss
+204 metrics . update (
+205 {
+206 f "translation_ { loss_name } _ { k } " : v
+207 for k , v in loss_output . metrics . items ()
+208 }
+209 )
+210 losses [ "translations" ] = torch . stack ( list ( losses . values ()), dim = 0 ) . mean ()
+211 losses . update ( metrics )
+212 return losses
+213
+214
+215 def contrastive_loss (
+216 gw_mod : GWModuleBase ,
+217 latent_domains : LatentsDomainGroupsT ,
+218 contrastive_fn : ContrastiveLossType ,
+219 ) -> dict [ str , torch . Tensor ]:
+220 """
+221 Computes the contrastive loss.
+222
+223 This returns multiple metrics:
+224 * `contrastive_{domain_1}_and_{domain_2}` with the contrastive
+225 between 2 domains;
+226 * `contrastive_{domain_1}_and_{domain_2}_{metric}` with
+227 additional metrics provided by the domain_mod's
+228 `compute_cont_loss` output;
+229 * `contrastives` with the average value of all
+230 `contrastive_{domain_1}_and_{domain_2}` values.
+231
+232 Args:
+233 gw_mod (`GWModuleBase`): The GWModule to use
+234 latent_domains (`LatentsDomainGroupsT`): the latent unimodal groups
+235 contrastive_fn (`ContrastiveLossType`): the contrastive function to apply
+236
+237 Returns:
+238 `dict[str, torch.Tensor]`: a dict of metrics.
+239 """
+240 losses : dict [ str , torch . Tensor ] = {}
+241 metrics : dict [ str , torch . Tensor ] = {}
+242 keys : list [ set [ str ]] = []
+243
+244 for latents in latent_domains . values ():
+245 if len ( latents ) != 2 :
+246 continue
+247
+248 cont_latents = gw_mod . encode ( latents )
+249 for domain1 , z1 in cont_latents . items ():
+250 for domain2 , z2 in cont_latents . items ():
+251 selected_domains = { domain1 , domain2 }
+252 if domain1 == domain2 or selected_domains in keys :
+253 continue
+254
+255 keys . append ( selected_domains )
+256
+257 loss_name = f "contrastive_ { domain1 } _and_ { domain2 } "
+258 loss_output = contrastive_fn ( z1 , z2 )
+259 losses [ loss_name ] = loss_output . loss
+260 metrics . update (
+261 { f " { loss_name } _ { k } " : v for k , v in loss_output . metrics . items ()}
+262 )
+263
+264 losses [ "contrastives" ] = torch . stack ( list ( losses . values ()), dim = 0 ) . mean ()
+265 losses . update ( metrics )
+266 return losses
+267
+268
+269 def contrastive_loss_bayesian (
+270 gw_mod : GWModuleBayesian ,
+271 latent_domains : LatentsDomainGroupsT ,
+272 contrastive_fn : ContrastiveLossType ,
+273 ) -> dict [ str , torch . Tensor ]:
+274 """
+275 Computes the contrastive loss with a Bayesian based uncertainty prediction.
+276
+277 This returns multiple metrics:
+278 * `contrastive_{domain_1}_and_{domain_2}` with the contrastive
+279 between 2 domains;
+280 * `contrastive_{domain_1}_and_{domain_2}_{metric}` with
+281 additional metrics provided by the domain_mod's
+282 `compute_cont_loss` output;
+283 * `contrastives` with the average value of all
+284 `contrastive_{domain_1}_and_{domain_2}` values.
+285
+286 Args:
+287 gw_mod (`GWModuleBayesian`): The GWModule to use
+288 latent_domains (`LatentsDomainGroupsT`): the latent unimodal groups
+289 contrastive_fn (`ContrastiveLossBayesianType`): the contrastive function
+290 to apply
+291
+292 Returns:
+293 `dict[str, torch.Tensor]`: a dict of metrics.
+294 """
+295 losses : dict [ str , torch . Tensor ] = {}
+296 metrics : dict [ str , torch . Tensor ] = {}
+297 keys : list [ set [ str ]] = []
+298
+299 for latents in latent_domains . values ():
+300 if len ( latents ) < 2 :
+301 continue
+302 for domain1_name , domain1 in latents . items ():
+303 z1 = gw_mod . encode ({ domain1_name : domain1 })[ domain1_name ]
+304 z1_precision = gw_mod . get_precision ( domain1_name , domain1 )
+305 for domain2_name , domain2 in latents . items ():
+306 selected_domains = { domain1_name , domain2_name }
+307 if domain1_name == domain2_name or selected_domains in keys :
+308 continue
+309
+310 keys . append ( selected_domains )
+311
+312 loss_name = f "contrastive_ { domain1_name } _and_ { domain2_name } "
+313 z2 = gw_mod . encode ({ domain2_name : domain2 })[ domain2_name ]
+314 z2_precision = gw_mod . get_precision ( domain2_name , domain2 )
+315 coef = torch . softmax (
+316 gw_mod . precision_softmax_temp
+317 * torch . stack ([ z1_precision , z2_precision ]),
+318 dim = 0 ,
+319 )
+320 norm = torch . sqrt ( coef [ 0 ] * coef [ 1 ])
+321 loss_output = contrastive_fn ( z1 * norm , z2 * norm )
+322 loss_output_no_norm = contrastive_fn ( z1 , z2 )
+323 losses [ loss_name ] = loss_output . loss
+324 metrics . update (
+325 { f " { loss_name } _ { k } " : v for k , v in loss_output . metrics . items ()}
+326 )
+327 metrics [ f "unnorm_ { loss_name } " ] = loss_output_no_norm . loss
+328
+329 losses [ "contrastives" ] = torch . stack ( list ( losses . values ()), dim = 0 ) . mean ()
+330 losses . update ( metrics )
+331 return losses
+332
+333
+334 class LossCoefs ( TypedDict , total = False ):
+335 """
+336 Dict of loss coefficients used in the GWLosses.
+337
+338 If one is not provided, the coefficient is assumed to be 0 and will not be logged.
+339 If the loss is explicitly set to 0, it will be logged, but not take part in
+340 the total loss.
+341 """
+342
+343 demi_cycles : float
+344 """Demi-cycle loss coefficient."""
+345
+346 cycles : float
+347 """Cycle loss coefficient."""
+348
+349 translations : float
+350 """Translation loss coefficient."""
+351
+352 contrastives : float
+353 """Contrastive loss coefficient."""
+354
+355
+356 class GWLosses2Domains ( GWLossesBase ):
+357 """
+358 Implementation of `GWLossesBase` used for `GWModule`.
+359 """
+360
+361 def __init__ (
+362 self ,
+363 gw_mod : GWModule ,
+364 selection_mod : SelectionBase ,
+365 domain_mods : dict [ str , DomainModule ],
+366 loss_coefs : LossCoefs ,
+367 contrastive_fn : ContrastiveLossType ,
+368 ):
+369 """
+370 Main loss module to use with the GlobalWorkspace
+371
+372 Args:
+373 gw_mod (`GWModule`): the GWModule
+374 selection_mod (`SelectionBase`): selection module
+375 domain_mods (`dict[str, DomainModule]`): a dict where the key is the
+376 domain name and value is the DomainModule
+377 loss_coefs (`LossCoefs`): loss coefficients. LossCoefs object, or a
+378 mapping to float with correct keys.
+379 contrastive_fn (`ContrastiveLossType`): the contrastive function to use
+380 in contrastive loss
+381 """
+382
+383 super () . __init__ ()
+384 self . gw_mod = gw_mod
+385 self . selection_mod = selection_mod
+386 self . domain_mods = domain_mods
+387 self . loss_coefs = loss_coefs
+388 self . contrastive_fn = contrastive_fn
+389
+390 def demi_cycle_loss (
+391 self , latent_domains : LatentsDomainGroupsT
+392 ) -> dict [ str , torch . Tensor ]:
+393 """
+394 Computes the demi-cycle loss.
+395
+396 See `shimmer.modules.losses.demi_cycle_loss`.
+397
+398 Args:
+399 latent_domains (`LatentsDomainGroupsT`): the latent unimodal groups
+400
+401 Returns:
+402 `dict[str, torch.Tensor]`: a dict of metrics.
+403 """
+404 return demi_cycle_loss (
+405 self . gw_mod , self . selection_mod , self . domain_mods , latent_domains
+406 )
+407
+408 def cycle_loss (
+409 self , latent_domains : LatentsDomainGroupsT
+410 ) -> dict [ str , torch . Tensor ]:
+411 """
+412 Computes the cycle loss.
+413
+414 See `shimmer.modules.losses.cycle_loss`.
+415
+416 Args:
+417 latent_domains (`LatentsDomainGroupsT`): the latent unimodal groups
+418
+419 Returns:
+420 `dict[str, torch.Tensor]`: a dict of metrics.
+421 """
+422 return cycle_loss (
+423 self . gw_mod , self . selection_mod , self . domain_mods , latent_domains
+424 )
+425
+426 def translation_loss (
+427 self , latent_domains : LatentsDomainGroupsT
+428 ) -> dict [ str , torch . Tensor ]:
+429 """
+430 Computes the translation loss.
+431
+432 See `shimmer.modules.losses.translation_loss`.
+433
+434 Args:
+435 latent_domains (`LatentsDomainGroupsT`): the latent unimodal groups
+436
+437 Returns:
+438 `dict[str, torch.Tensor]`: a dict of metrics.
+439 """
+440 return translation_loss (
+441 self . gw_mod , self . selection_mod , self . domain_mods , latent_domains
+442 )
+443
+444 def contrastive_loss (
+445 self , latent_domains : LatentsDomainGroupsT
+446 ) -> dict [ str , torch . Tensor ]:
+447 """
+448 Computes the contrastive loss.
+449
+450 See `shimmer.modules.losses.contrastive_loss`.
+451
+452 Args:
+453 latent_domains (`LatentsDomainGroupsT`): the latent unimodal groups
+454
+455 Returns:
+456 `dict[str, torch.Tensor]`: a dict of metrics.
+457 """
+458 return contrastive_loss ( self . gw_mod , latent_domains , self . contrastive_fn )
+459
+460 def step (
+461 self , domain_latents : LatentsDomainGroupsT , mode : ModelModeT
+462 ) -> LossOutput :
+463 """
+464 Computes and returns the losses
+465
+466 Contains:
+467 - Demi-cycle metrics (see `GWLosses.demi_cycle_loss`)
+468 - Cycle metrics (see `GWLosses.cycle_loss`)
+469 - Translation metrics (see `GWLosses.translation_loss`)
+470 - Contrastive metrics (see `GWLosses.contrastive_loss`)
+471
+472 Args:
+473 domain_latents (`LatentsDomainGroupsT`): All latent groups
+474 mode (`ModelModeT`): model mode
+475 Returns:
+476 `LossOutput`: the losses
+477 """
+478 metrics : dict [ str , torch . Tensor ] = {}
+479
+480 metrics . update ( self . demi_cycle_loss ( domain_latents ))
+481 metrics . update ( self . cycle_loss ( domain_latents ))
+482 metrics . update ( self . translation_loss ( domain_latents ))
+483 metrics . update ( self . contrastive_loss ( domain_latents ))
+484
+485 loss = torch . stack (
+486 [
+487 metrics [ name ] * coef
+488 for name , coef in self . loss_coefs . items ()
+489 if isinstance ( coef , float ) and coef > 0
+490 ],
+491 dim = 0 ,
+492 ) . mean ()
+493
+494 return LossOutput ( loss , metrics )
+495
+496
+497 def generate_partitions ( n : int ) -> Generator [ tuple [ int , ... ], None , None ]:
+498 """
+499 Generates all possible partitions of zeros and ones for `n` elements,
+500 excluding the all-zeros partition.
+501
+502 Args:
+503 n (`int`): The number of modalities to generate partitions for.
+504
+505 Yields:
+506 `tuple[int, ...]`: A partition of zeros and ones, excluding the
+507 all-zeros partition.
+508 """
+509 for perm in product ([ 0 , 1 ], repeat = n ):
+510 if any ( perm ):
+511 yield perm
+512
+513
+514 def broadcast_loss (
+515 gw_mod : GWModuleBase ,
+516 selection_mod : SelectionBase ,
+517 domain_mods : Mapping [ str , DomainModule ],
+518 latent_domains : LatentsDomainGroupsT ,
+519 ) -> dict [ str , torch . Tensor ]:
+520 """
+521 Computes broadcast loss including demi-cycle, cycle, and translation losses.
+522
+523 Args:
+524 gw_mod (`shimmer.modules.gw_module.GWModuleBase`): The GWModule to use
+525 selection_mod (`shimmer.modules.selection.SelectionBase`): Selection mod to use
+526 domain_mods (`Mapping[str, DomainModule]`): the domain modules
+527 latent_domains: The latent domain representations.
+528
+529 Returns:
+530 A dictionary with the total loss and additional metrics.
+531 """
+532 losses : dict [ str , torch . Tensor ] = {}
+533 metrics : dict [ str , torch . Tensor ] = {}
+534
+535 demi_cycle_losses : list [ str ] = []
+536 cycle_losses : list [ str ] = []
+537 translation_losses : list [ str ] = []
+538 fused_losses : list [ str ] = []
+539
+540 for group_domains , latents in latent_domains . items ():
+541 encoded_latents = gw_mod . encode ( latents )
+542 partitions = generate_partitions ( len ( group_domains ))
+543 domain_names = list ( latents )
+544 group_name = "-" . join ( group_domains )
+545
+546 for partition in partitions :
+547 selected_latents = {
+548 domain : latents [ domain ]
+549 for domain , present in zip ( domain_names , partition , strict = True )
+550 if present
+551 }
+552 selected_encoded_latents = {
+553 domain : encoded_latents [ domain ] for domain in selected_latents
+554 }
+555 selected_group_label = "{" + ", " . join ( sorted ( selected_latents )) + "}"
+556
+557 selection_scores = selection_mod ( selected_latents , selected_encoded_latents )
+558 fused_latents = gw_mod . fuse ( selected_encoded_latents , selection_scores )
+559 decoded_latents = gw_mod . decode ( fused_latents )
+560
+561 num_active_domains = sum ( partition )
+562 num_total_domains = len ( decoded_latents )
+563
+564 for domain , pred in decoded_latents . items ():
+565 if domain not in group_domains : # if we don't have ground truth
+566 continue
+567 ground_truth = latents [ domain ]
+568 loss_output = domain_mods [ domain ] . compute_loss ( pred , ground_truth )
+569 loss_label = f "from_ { selected_group_label } _to_ { domain } "
+570 losses [ loss_label + "_loss" ] = loss_output . loss
+571 metrics . update (
+572 { f " { loss_label } _ { k } " : v for k , v in loss_output . metrics . items ()}
+573 )
+574
+575 if num_active_domains == 1 and domain in selected_latents :
+576 demi_cycle_losses . append ( loss_label + "_loss" )
+577 elif domain not in selected_latents :
+578 translation_losses . append ( loss_label + "_loss" )
+579 else : # fused loss
+580 fused_losses . append ( loss_label + "_loss" )
+581
+582 if num_active_domains < num_total_domains :
+583 inverse_selected_latents = {
+584 domain : decoded_latents [ domain ]
+585 for domain in decoded_latents
+586 if domain not in selected_latents
+587 }
+588
+589 inverse_selected_group_label = (
+590 "{" + ", " . join ( sorted ( inverse_selected_latents )) + "}"
+591 )
+592
+593 re_encoded_latents = gw_mod . encode ( inverse_selected_latents )
+594 re_selection_scores = selection_mod (
+595 inverse_selected_latents , re_encoded_latents
+596 )
+597 re_fused_latents = gw_mod . fuse ( re_encoded_latents , re_selection_scores )
+598 re_decoded_latents = gw_mod . decode (
+599 re_fused_latents , domains = selected_latents . keys ()
+600 )
+601
+602 for domain in selected_latents :
+603 re_ground_truth = latents [ domain ]
+604 re_loss_output = domain_mods [ domain ] . compute_loss (
+605 re_decoded_latents [ domain ], re_ground_truth
+606 )
+607 loss_label = (
+608 f "from_ { selected_group_label } _"
+609 f "through_ { inverse_selected_group_label } _to_ { domain } _"
+610 f "case_ { group_name } "
+611 )
+612 losses [ loss_label + "_loss" ] = re_loss_output . loss
+613 metrics . update (
+614 {
+615 f " { loss_label } _ { k } " : v
+616 for k , v in re_loss_output . metrics . items ()
+617 }
+618 )
+619 cycle_losses . append ( loss_label + "_loss" )
+620
+621 if demi_cycle_losses :
+622 metrics [ "demi_cycles" ] = torch . mean (
+623 torch . stack ([ losses [ loss_name ] for loss_name in demi_cycle_losses ])
+624 )
+625 if cycle_losses :
+626 metrics [ "cycles" ] = torch . mean (
+627 torch . stack ([ losses [ loss_name ] for loss_name in cycle_losses ])
+628 )
+629 if translation_losses :
+630 metrics [ "translations" ] = torch . mean (
+631 torch . stack ([ losses [ loss_name ] for loss_name in translation_losses ])
+632 )
+633 if fused_losses :
+634 metrics [ "fused" ] = torch . mean (
+635 torch . stack ([ losses [ loss_name ] for loss_name in fused_losses ])
+636 )
+637
+638 metrics . update ( losses )
+639 return metrics
+640
+641
+642 class BroadcastLossCoefs ( TypedDict , total = False ):
+643 """
+644 Dict of loss coefficients used in the GWLossesFusion.
+645
+646 If one is not provided, the coefficient is assumed to be 0 and will not be logged.
+647 If the loss is explicitly set to 0, it will be logged, but not take part in
+648 the total loss.
+649 """
+650
+651 contrastives : float
+652 """Contrastive loss coefficient."""
+653
+654 fused : float
+655 """fused loss coefficient (encode multiple domains and decode to one of them)."""
+656
+657 demi_cycles : float
+658 """demi_cycles loss coefficient. Demi-cycles are always one-to-one"""
+659
+660 cycles : float
+661 """cycles loss coefficient. Cycles can be many-to-one"""
+662
+663 translations : float
+664 """translation loss coefficient. Translation, like cycles, can be many-to-one."""
+665
+666
+667 class GWLosses ( GWLossesBase ):
+668 """
+669 Implementation of `GWLossesBase` for fusion-based models.
+670 """
+671
+672 def __init__ (
+673 self ,
+674 gw_mod : GWModule ,
+675 selection_mod : SelectionBase ,
+676 domain_mods : dict [ str , DomainModule ],
+677 loss_coefs : BroadcastLossCoefs ,
+678 contrastive_fn : ContrastiveLossType ,
+679 ):
+680 """
+681 Initializes the loss computation module for a Global Workspace Fusion model.
+682
+683 Args:
+684 gw_mod: The GWModule for the global workspace.
+685 selection_mod: The selection mechanism for the model.
+686 domain_mods: A mapping of domain names to their respective DomainModule.
+687 loss_coefs (`BroadcastLossCoefs`): coefs for the losses
+688 contrastive_fn: The function used for computing contrastive loss.
+689 """
+690 super () . __init__ ()
+691 self . gw_mod = gw_mod
+692 self . selection_mod = selection_mod
+693 self . domain_mods = domain_mods
+694 self . loss_coefs = loss_coefs
+695 self . contrastive_fn = contrastive_fn
+696
+697 def contrastive_loss (
+698 self , latent_domains : LatentsDomainGroupsT
+699 ) -> dict [ str , torch . Tensor ]:
+700 """
+701 Computes the contrastive loss for the given latent domains.
+702
+703 Args:
+704 latent_domains: The latent domain representations.
+705
+706 Returns:
+707 A dictionary of contrastive loss metrics.
+708 """
+709
+710 return contrastive_loss ( self . gw_mod , latent_domains , self . contrastive_fn )
+711
+712 def broadcast_loss (
+713 self , latent_domains : LatentsDomainGroupsT
+714 ) -> dict [ str , torch . Tensor ]:
+715 return broadcast_loss (
+716 self . gw_mod , self . selection_mod , self . domain_mods , latent_domains
+717 )
+718
+719 def step (
+720 self , domain_latents : LatentsDomainGroupsT , mode : ModelModeT
+721 ) -> LossOutput :
+722 """
+723 Performs a step of loss computation.
+724
+725 Args:
+726 domain_latents: Latent representations for all domains.
+727 mode: The mode in which the model is currently operating.
+728
+729 Returns:
+730 A LossOutput object containing the loss and metrics for this step.
+731 """
+732
+733 metrics : dict [ str , torch . Tensor ] = {}
+734
+735 metrics . update ( self . contrastive_loss ( domain_latents ))
+736 metrics . update ( self . broadcast_loss ( domain_latents ))
+737
+738 loss = torch . stack (
+739 [
+740 metrics [ name ] * coef
+741 for name , coef in self . loss_coefs . items ()
+742 if isinstance ( coef , float ) and coef > 0
+743 ],
+744 dim = 0 ,
+745 ) . mean ()
+746
+747 metrics [ "broadcast_loss" ] = torch . stack (
+748 [
+749 metrics [ name ]
+750 for name , coef in self . loss_coefs . items ()
+751 if isinstance ( coef , float ) and coef > 0 and name != "contrastives"
+752 ],
+753 dim = 0 ,
+754 ) . mean ()
+755
+756 return LossOutput ( loss , metrics )
+757
+758
+759 class GWLossesBayesian ( GWLossesBase ):
+760 """
+761 Implementation of `GWLossesBase` used for `GWModuleBayesian`.
+762 """
+763
+764 def __init__ (
+765 self ,
+766 gw_mod : GWModuleBayesian ,
+767 selection_mod : SelectionBase ,
+768 domain_mods : dict [ str , DomainModule ],
+769 loss_coefs : BroadcastLossCoefs ,
+770 contrastive_fn : ContrastiveLossType ,
+771 use_normalized_constrastive : bool = True ,
+772 ):
+773 """
+774 Loss module with uncertainty prediction to use with the GlobalWorkspaceBayesian
+775
+776 Args:
+777 gw_mod (`GWModuleBayesian`): the GWModule
+778 selection_mod (`SelectionBase`): selection module
+779 domain_mods (`dict[str, DomainModule]`): a dict where the key is the
+780 domain name and value is the DomainModule
+781 loss_coefs (`BroadcastLossCoefs`): loss coefficients
+782 contrastive_fn (`ContrastiveLossType`): the contrastive function
+783 to use in contrastive loss
+784 use_normalized_constrastive (`bool`): whether to use the normalized cont
+785 loss by the precision coefs
+786 """
+787 super () . __init__ ()
+788
+789 self . gw_mod = gw_mod
+790 """The GWModule."""
+791
+792 self . selection_mod = selection_mod
+793 """Selection module"""
+794
+795 self . domain_mods = domain_mods
+796 """Domain modules linked to the GW."""
+797
+798 self . loss_coefs = loss_coefs
+799 """The loss coefficients."""
+800
+801 self . contrastive_fn = contrastive_fn
+802 """
+803 Contrastive loss to use.
+804 """
+805
+806 self . use_normalized_constrastive = use_normalized_constrastive
+807
+808 def contrastive_loss (
+809 self , latent_domains : LatentsDomainGroupsT
+810 ) -> dict [ str , torch . Tensor ]:
+811 """
+812 Contrastive loss.
+813
+814 Args:
+815 latent_domains (`LatentsDomainGroupsT`): the latent unimodal groups
+816
+817 Returns:
+818 `dict[str, torch.Tensor]`: a dict of metrics.
+819 """
+820 if self . use_normalized_constrastive :
+821 return contrastive_loss_bayesian (
+822 self . gw_mod , latent_domains , self . contrastive_fn
+823 )
+824 return contrastive_loss ( self . gw_mod , latent_domains , self . contrastive_fn )
+825
+826 def broadcast_loss (
+827 self , latent_domains : LatentsDomainGroupsT
+828 ) -> dict [ str , torch . Tensor ]:
+829 return broadcast_loss (
+830 self . gw_mod , self . selection_mod , self . domain_mods , latent_domains
+831 )
+832
+833 def step (
+834 self , domain_latents : LatentsDomainGroupsT , mode : ModelModeT
+835 ) -> LossOutput :
+836 """
+837 Performs a step of loss computation.
+838
+839 Args:
+840 domain_latents: Latent representations for all domains.
+841 mode: The mode in which the model is currently operating.
+842
+843 Returns:
+844 A LossOutput object containing the loss and metrics for this step.
+845 """
+846
+847 metrics : dict [ str , torch . Tensor ] = {}
+848
+849 metrics . update ( self . contrastive_loss ( domain_latents ))
+850 metrics . update ( self . broadcast_loss ( domain_latents ))
+851
+852 loss = torch . stack (
+853 [
+854 metrics [ name ] * coef
+855 for name , coef in self . loss_coefs . items ()
+856 if isinstance ( coef , float ) and coef > 0
+857 ],
+858 dim = 0 ,
+859 ) . mean ()
+860
+861 metrics [ "broadcast_loss" ] = torch . stack (
+862 [
+863 metrics [ name ]
+864 for name , coef in self . loss_coefs . items ()
+865 if isinstance ( coef , float ) and coef > 0 and name != "contrastives"
+866 ],
+867 dim = 0 ,
+868 ) . mean ()
+869
+870 return LossOutput ( loss , metrics )
+
+
+
+
+
+
+
+
+ class
+ GWLossesBase (torch.nn.modules.module.Module , abc.ABC ):
+
+ View Source
+
+
+
+ 20 class GWLossesBase ( torch . nn . Module , ABC ):
+21 """
+22 Base Abstract Class for Global Workspace (GW) losses. This module is used
+23 to compute the different losses of the GW (typically translation, cycle,
+24 demi-cycle, contrastive losses).
+25 """
+26
+27 @abstractmethod
+28 def step (
+29 self ,
+30 domain_latents : LatentsDomainGroupsT ,
+31 mode : ModelModeT ,
+32 ) -> LossOutput :
+33 """
+34 Computes the losses.
+35
+36 Args:
+37 domain_latents (`LatentsDomainGroupsT`): All latent groups
+38 mode (`Literal["train", "val", "test", "val/ood", "test/ood"]`): model mode
+39 Returns:
+40 `LossOutput`: the losses
+41 """
+42 ...
+
+
+
+ Base Abstract Class for Global Workspace (GW) losses. This module is used
+to compute the different losses of the GW (typically translation, cycle,
+demi-cycle, contrastive losses).
+
+
+
+
+
+
+
@abstractmethod
+
+
def
+
step ( self , domain_latents : collections . abc . Mapping [ frozenset [ str ], collections . abc . Mapping [ str , torch . Tensor ]] , mode : Literal [ 'train' , 'val' , 'test' , 'val/ood' , 'test/ood' ] ) -> shimmer.modules.domain.LossOutput :
+
+
View Source
+
+
+
+
27 @abstractmethod
+28 def step (
+29 self ,
+30 domain_latents : LatentsDomainGroupsT ,
+31 mode : ModelModeT ,
+32 ) -> LossOutput :
+33 """
+34 Computes the losses.
+35
+36 Args:
+37 domain_latents (`LatentsDomainGroupsT`): All latent groups
+38 mode (`Literal["train", "val", "test", "val/ood", "test/ood"]`): model mode
+39 Returns:
+40 `LossOutput`: the losses
+41 """
+42 ...
+
+
+
+
Computes the losses.
+
+
Arguments:
+
+
+domain_latents (LatentsDomainGroupsT
): All latent groups
+mode (Literal["train", "val", "test", "val/ood", "test/ood"]
): model mode
+
+
+
Returns:
+
+
+ LossOutput
: the losses
+
+
+
+
+
+
+
Inherited Members
+
+
torch.nn.modules.module.Module
+ Module
+ dump_patches
+ training
+ call_super_init
+ forward
+ register_buffer
+ register_parameter
+ add_module
+ register_module
+ get_submodule
+ get_parameter
+ get_buffer
+
+
+ apply
+ cuda
+ ipu
+ xpu
+ cpu
+ type
+ float
+ double
+ half
+ bfloat16
+ to_empty
+ to
+ register_full_backward_pre_hook
+ register_backward_hook
+ register_full_backward_hook
+ register_forward_pre_hook
+ register_forward_hook
+ register_state_dict_pre_hook
+ state_dict
+ register_load_state_dict_post_hook
+ load_state_dict
+ parameters
+ named_parameters
+ buffers
+ named_buffers
+ children
+ named_children
+ modules
+ named_modules
+ train
+ eval
+ requires_grad_
+ zero_grad
+ share_memory
+
+ compile
+
+
+
+
+
+
+
+
+
+ 45 def demi_cycle_loss (
+46 gw_mod : GWModuleBase ,
+47 selection_mod : SelectionBase ,
+48 domain_mods : Mapping [ str , DomainModule ],
+49 latent_domains : LatentsDomainGroupsT ,
+50 ) -> dict [ str , torch . Tensor ]:
+51 """
+52 Computes the demi-cycle loss.
+53
+54 This return multiple metrics:
+55 * `demi_cycle_{domain_name}` with the demi-cycle of a particular domain;
+56 * `demi_cycle_{domain_name}_{metric}` with additional metrics provided by
+57 the domain_mod's `compute_dcy_loss` output;
+58 * `demi_cycles` with the average value of all `demi_cycle_{domain_name}` values.
+59
+60 Args:
+61 gw_mod (`shimmer.modules.gw_module.GWModuleBase`): The GWModule to use
+62 selection_mod (`shimmer.modules.selection.SelectionBase`): Selection mod to use
+63 domain_mods (`Mapping[str, DomainModule]`): the domain modules
+64 latent_domains (`shimmer.types.LatentsDomainGroupsT`): the latent unimodal
+65 groups
+66
+67 Returns:
+68 `dict[str, torch.Tensor]`: a dict of metrics.
+69 """
+70 losses : dict [ str , torch . Tensor ] = {}
+71 metrics : dict [ str , torch . Tensor ] = {}
+72 for domains , latents in latent_domains . items ():
+73 if len ( domains ) > 1 :
+74 continue
+75 domain_name = next ( iter ( domains ))
+76
+77 domain_mod = domain_mods [ domain_name ]
+78 x_recons = gw_mod . decode (
+79 gw_mod . encode_and_fuse ( latents , selection_mod ), domains = { domain_name }
+80 )[ domain_name ]
+81 loss_output = domain_mod . compute_dcy_loss ( x_recons , latents [ domain_name ])
+82 losses [ f "demi_cycle_ { domain_name } " ] = loss_output . loss
+83 metrics . update (
+84 { f "demi_cycle_ { domain_name } _ { k } " : v for k , v in loss_output . metrics . items ()}
+85 )
+86 losses [ "demi_cycles" ] = torch . stack ( list ( losses . values ()), dim = 0 ) . mean ()
+87 losses . update ( metrics )
+88 return losses
+
+
+
+ Computes the demi-cycle loss.
+
+
This returns multiple metrics:
+
+
+
+ demi_cycle_{domain_name}
with the demi-cycle of a particular domain;
+ demi_cycle_{domain_name}_{metric}
with additional metrics provided by
+ the domain_mod's compute_dcy_loss
output;
+ demi_cycles
with the average value of all demi_cycle_{domain_name}
values.
+
+
+
+
Arguments:
+
+
+
+
Returns:
+
+
+ dict[str, torch.Tensor]
: a dict of metrics.
+
+
+
+
+
+
+
+
+
+ 91 def cycle_loss (
+ 92 gw_mod : GWModuleBase ,
+ 93 selection_mod : SelectionBase ,
+ 94 domain_mods : Mapping [ str , DomainModule ],
+ 95 latent_domains : LatentsDomainGroupsT ,
+ 96 ) -> dict [ str , torch . Tensor ]:
+ 97 """
+ 98 Computes the cycle loss.
+ 99
+100 This return multiple metrics:
+101 * `cycle_{domain_source}_through_{domain_target}` with the cycle of
+102 a particular domain;
+103 * `cycle_{domain_source}_through_{domain_target}_{metric}` with additional
+104 metrics provided by the domain_mod's `compute_cy_loss` output;
+105 * `cycles` with the average value of all
+106 `cycle_{domain_source}_through_{domain_target}` values.
+107
+108 Args:
+109 gw_mod (`GWModuleBase`): The GWModule to use
+110 selection_mod (`shimmer.modules.selection.SelectionBase`): Selection mod to use
+111 domain_mods (`Mapping[str, DomainModule]`): the domain modules
+112 latent_domains (`LatentsDomainGroupsT`): the latent unimodal groups
+113
+114 Returns:
+115 `dict[str, torch.Tensor]`: a dict of metrics.
+116 """
+117 losses : dict [ str , torch . Tensor ] = {}
+118 metrics : dict [ str , torch . Tensor ] = {}
+119 for domains_source , latents_source in latent_domains . items ():
+120 if len ( domains_source ) > 1 :
+121 continue
+122 domain_name_source = list ( domains_source )[ 0 ]
+123
+124 domain_mod = domain_mods [ domain_name_source ]
+125 z = gw_mod . encode_and_fuse ( latents_source , selection_mod )
+126 for domain_name_target in domain_mods :
+127 if domain_name_target == domain_name_source :
+128 continue
+129
+130 x_pred = gw_mod . decode ( z , domains = { domain_name_target })
+131
+132 x_recons = gw_mod . decode (
+133 gw_mod . encode_and_fuse ( x_pred , selection_mod ),
+134 domains = { domain_name_source },
+135 )
+136
+137 loss_name = f " { domain_name_source } _through_ { domain_name_target } "
+138 loss_output = domain_mod . compute_cy_loss (
+139 x_recons [ domain_name_source ],
+140 latents_source [ domain_name_source ],
+141 )
+142 metrics . update (
+143 { f "cycle_ { loss_name } _ { k } " : v for k , v in loss_output . metrics . items ()}
+144 )
+145 losses [ f "cycle_ { loss_name } " ] = loss_output . loss
+146 losses [ "cycles" ] = torch . stack ( list ( losses . values ()), dim = 0 ) . mean ()
+147 losses . update ( metrics )
+148 return losses
+
+
+
+ Computes the cycle loss.
+
+
This returns multiple metrics:
+
+
+
+ cycle_{domain_source}_through_{domain_target}
with the cycle of
+ a particular domain;
+ cycle_{domain_source}_through_{domain_target}_{metric}
with additional
+ metrics provided by the domain_mod's compute_cy_loss
output;
+ cycles
with the average value of all
+ cycle_{domain_source}_through_{domain_target}
values.
+
+
+
+
Arguments:
+
+
+gw_mod (GWModuleBase
): The GWModule to use
+selection_mod (shimmer.modules.selection.SelectionBase
): Selection mod to use
+domain_mods (Mapping[str, DomainModule]
): the domain modules
+latent_domains (LatentsDomainGroupsT
): the latent unimodal groups
+
+
+
Returns:
+
+
+ dict[str, torch.Tensor]
: a dict of metrics.
+
+
+
+
+
+
+
+
+
+ 151 def translation_loss (
+152 gw_mod : GWModuleBase ,
+153 selection_mod : SelectionBase ,
+154 domain_mods : Mapping [ str , DomainModule ],
+155 latent_domains : LatentsDomainGroupsT ,
+156 ) -> dict [ str , torch . Tensor ]:
+157 """
+158 Computes the translation loss.
+159
+160 This return multiple metrics:
+161 * `translation_{domain_source}_to_{domain_target}` with the translation
+162 from a domain source to a domain target;
+163 * `translation_{domain_source}_to_{domain_target}_{metric}` with
+164 additional metrics provided by the domain_mod's
+165 `compute_tr_loss` output;
+166 * `translations` with the average value of all
+167 `translation_{domain_source}_to_{domain_target}` values.
+168
+169 Args:
+170 gw_mod (`GWModuleBase`): The GWModule to use
+171 domain_mods (`Mapping[str, DomainModule]`): the domain modules
+172 latent_domains (`LatentsDomainGroupsT`): the latent unimodal groups
+173
+174 Returns:
+175 `dict[str, torch.Tensor]`: a dict of metrics.
+176 """
+177 losses : dict [ str , torch . Tensor ] = {}
+178 metrics : dict [ str , torch . Tensor ] = {}
+179 for domains , latents in latent_domains . items ():
+180 if len ( domains ) < 2 :
+181 continue
+182 for domain_name_target in domains :
+183 domain_sources = {
+184 domain : latents [ domain ]
+185 for domain in domains
+186 if domain != domain_name_target
+187 }
+188
+189 z = gw_mod . encode_and_fuse ( domain_sources , selection_mod )
+190 mod = domain_mods [ domain_name_target ]
+191
+192 domain_source_names = "/" . join ( domain_sources . keys ())
+193 loss_name = f " { domain_source_names } _to_ { domain_name_target } "
+194 if loss_name in losses :
+195 raise ValueError ( f " { loss_name } is already computed." )
+196
+197 prediction = gw_mod . decode ( z , domains = { domain_name_target })[
+198 domain_name_target
+199 ]
+200 loss_output = mod . compute_tr_loss (
+201 prediction ,
+202 latents [ domain_name_target ],
+203 )
+204 losses [ f "translation_ { loss_name } " ] = loss_output . loss
+205 metrics . update (
+206 {
+207 f "translation_ { loss_name } _ { k } " : v
+208 for k , v in loss_output . metrics . items ()
+209 }
+210 )
+211 losses [ "translations" ] = torch . stack ( list ( losses . values ()), dim = 0 ) . mean ()
+212 losses . update ( metrics )
+213 return losses
+
+
+
+ Computes the translation loss.
+
+
This returns multiple metrics:
+
+
+
+ translation_{domain_source}_to_{domain_target}
with the translation
+ from a domain source to a domain target;
+ translation_{domain_source}_to_{domain_target}_{metric}
with
+ additional metrics provided by the domain_mod's
+ compute_tr_loss
output;
+ translations
with the average value of all
+ translation_{domain_source}_to_{domain_target}
values.
+
+
+
+
Arguments:
+
+
+gw_mod (GWModuleBase
): The GWModule to use
+domain_mods (Mapping[str, DomainModule]
): the domain modules
+latent_domains (LatentsDomainGroupsT
): the latent unimodal groups
+
+
+
Returns:
+
+
+ dict[str, torch.Tensor]
: a dict of metrics.
+
+
+
+
+
+
+
+
+
+ 216 def contrastive_loss (
+217 gw_mod : GWModuleBase ,
+218 latent_domains : LatentsDomainGroupsT ,
+219 contrastive_fn : ContrastiveLossType ,
+220 ) -> dict [ str , torch . Tensor ]:
+221 """
+222 Computes the contrastive loss.
+223
+224 This return multiple metrics:
+225 * `contrastive_{domain_1}_and_{domain_2}` with the contrastive
+226 between 2 domains;
+227 * `contrastive_{domain_1}_and_{domain_2}_{metric}` with
+228 additional metrics provided by the domain_mod's
+229 `compute_cont_loss` output;
+230 * `contrastives` with the average value of all
+231 `contrastive_{domain_1}_and_{domain_2}` values.
+232
+233 Args:
+234 gw_mod (`GWModuleBase`): The GWModule to use
+235 latent_domains (`LatentsDomainGroupsT`): the latent unimodal groups
+236 contrastive_fn (`ContrastiveLossType`): the contrastive function to apply
+237
+238 Returns:
+239 `dict[str, torch.Tensor]`: a dict of metrics.
+240 """
+241 losses : dict [ str , torch . Tensor ] = {}
+242 metrics : dict [ str , torch . Tensor ] = {}
+243 keys : list [ set [ str ]] = []
+244
+245 for latents in latent_domains . values ():
+246 if len ( latents ) != 2 :
+247 continue
+248
+249 cont_latents = gw_mod . encode ( latents )
+250 for domain1 , z1 in cont_latents . items ():
+251 for domain2 , z2 in cont_latents . items ():
+252 selected_domains = { domain1 , domain2 }
+253 if domain1 == domain2 or selected_domains in keys :
+254 continue
+255
+256 keys . append ( selected_domains )
+257
+258 loss_name = f "contrastive_ { domain1 } _and_ { domain2 } "
+259 loss_output = contrastive_fn ( z1 , z2 )
+260 losses [ loss_name ] = loss_output . loss
+261 metrics . update (
+262 { f " { loss_name } _ { k } " : v for k , v in loss_output . metrics . items ()}
+263 )
+264
+265 losses [ "contrastives" ] = torch . stack ( list ( losses . values ()), dim = 0 ) . mean ()
+266 losses . update ( metrics )
+267 return losses
+
+
+
+ Computes the contrastive loss.
+
+
This returns multiple metrics:
+
+
+
+ contrastive_{domain_1}_and_{domain_2}
with the contrastive
+ between 2 domains;
+ contrastive_{domain_1}_and_{domain_2}_{metric}
with
+ additional metrics provided by the domain_mod's
+ compute_cont_loss
output;
+ contrastives
with the average value of all
+ contrastive_{domain_1}_and_{domain_2}
values.
+
+
+
+
Arguments:
+
+
+gw_mod (GWModuleBase
): The GWModule to use
+latent_domains (LatentsDomainGroupsT
): the latent unimodal groups
+contrastive_fn (ContrastiveLossType
): the contrastive function to apply
+
+
+
Returns:
+
+
+ dict[str, torch.Tensor]
: a dict of metrics.
+
+
+
+
+
+
+
+
+
+ 270 def contrastive_loss_bayesian (
+271 gw_mod : GWModuleBayesian ,
+272 latent_domains : LatentsDomainGroupsT ,
+273 contrastive_fn : ContrastiveLossType ,
+274 ) -> dict [ str , torch . Tensor ]:
+275 """
+276 Computes the contrastive loss with a Bayesian based uncertainty prediction.
+277
+278 This return multiple metrics:
+279 * `contrastive_{domain_1}_and_{domain_2}` with the contrastive
+280 between 2 domains;
+281 * `contrastive_{domain_1}_and_{domain_2}_{metric}` with
+282 additional metrics provided by the domain_mod's
+283 `compute_cont_loss` output;
+284 * `contrastives` with the average value of all
+285 `contrastive_{domain_1}_and_{domain_2}` values.
+286
+287 Args:
+288 gw_mod (`GWModuleBayesian`): The GWModule to use
+289 latent_domains (`LatentsDomainGroupsT`): the latent unimodal groups
+290 contrastive_fn (`ContrastiveLossBayesianType`): the contrastive function
+291 to apply
+292
+293 Returns:
+294 `dict[str, torch.Tensor]`: a dict of metrics.
+295 """
+296 losses : dict [ str , torch . Tensor ] = {}
+297 metrics : dict [ str , torch . Tensor ] = {}
+298 keys : list [ set [ str ]] = []
+299
+300 for latents in latent_domains . values ():
+301 if len ( latents ) < 2 :
+302 continue
+303 for domain1_name , domain1 in latents . items ():
+304 z1 = gw_mod . encode ({ domain1_name : domain1 })[ domain1_name ]
+305 z1_precision = gw_mod . get_precision ( domain1_name , domain1 )
+306 for domain2_name , domain2 in latents . items ():
+307 selected_domains = { domain1_name , domain2_name }
+308 if domain1_name == domain2_name or selected_domains in keys :
+309 continue
+310
+311 keys . append ( selected_domains )
+312
+313 loss_name = f "contrastive_ { domain1_name } _and_ { domain2_name } "
+314 z2 = gw_mod . encode ({ domain2_name : domain2 })[ domain2_name ]
+315 z2_precision = gw_mod . get_precision ( domain2_name , domain2 )
+316 coef = torch . softmax (
+317 gw_mod . precision_softmax_temp
+318 * torch . stack ([ z1_precision , z2_precision ]),
+319 dim = 0 ,
+320 )
+321 norm = torch . sqrt ( coef [ 0 ] * coef [ 1 ])
+322 loss_output = contrastive_fn ( z1 * norm , z2 * norm )
+323 loss_output_no_norm = contrastive_fn ( z1 , z2 )
+324 losses [ loss_name ] = loss_output . loss
+325 metrics . update (
+326 { f " { loss_name } _ { k } " : v for k , v in loss_output . metrics . items ()}
+327 )
+328 metrics [ f "unnorm_ { loss_name } " ] = loss_output_no_norm . loss
+329
+330 losses [ "contrastives" ] = torch . stack ( list ( losses . values ()), dim = 0 ) . mean ()
+331 losses . update ( metrics )
+332 return losses
+
+
+
+ Computes the contrastive loss with a Bayesian based uncertainty prediction.
+
+
This returns multiple metrics:
+
+
+
+ contrastive_{domain_1}_and_{domain_2}
with the contrastive
+ between 2 domains;
+ contrastive_{domain_1}_and_{domain_2}_{metric}
with
+ additional metrics provided by the domain_mod's
+ compute_cont_loss
output;
+ contrastives
with the average value of all
+ contrastive_{domain_1}_and_{domain_2}
values.
+
+
+
+
Arguments:
+
+
+gw_mod (GWModuleBayesian
): The GWModule to use
+latent_domains (LatentsDomainGroupsT
): the latent unimodal groups
+contrastive_fn (ContrastiveLossBayesianType
): the contrastive function
+to apply
+
+
+
Returns:
+
+
+ dict[str, torch.Tensor]
: a dict of metrics.
+
+
+
+
+
+
+
+
+
+ class
+ LossCoefs (typing.TypedDict ):
+
+ View Source
+
+
+
+ 335 class LossCoefs ( TypedDict , total = False ):
+336 """
+337 Dict of loss coefficients used in the GWLosses.
+338
+339 If one is not provided, the coefficient is assumed to be 0 and will not be logged.
+340 If the loss is excplicitely set to 0, it will be logged, but not take part in
+341 the total loss.
+342 """
+343
+344 demi_cycles : float
+345 """Demi-cycle loss coefficient."""
+346
+347 cycles : float
+348 """Cycle loss coefficient."""
+349
+350 translations : float
+351 """Translation loss coefficient."""
+352
+353 contrastives : float
+354 """Contrastive loss coefficient."""
+
+
+
+ Dict of loss coefficients used in the GWLosses.
+
+
If one is not provided, the coefficient is assumed to be 0 and will not be logged.
+If the loss is explicitly set to 0, it will be logged, but not take part in
+the total loss.
+
+
+
+
+
+ demi_cycles : float
+
+
+
+
+
+
Demi-cycle loss coefficient.
+
+
+
+
+
+
+ cycles : float
+
+
+
+
+
+
Cycle loss coefficient.
+
+
+
+
+
+
+ translations : float
+
+
+
+
+
+
Translation loss coefficient.
+
+
+
+
+
+
+ contrastives : float
+
+
+
+
+
+
Contrastive loss coefficient.
+
+
+
+
+
+
+
+
+
+
class
+
GWLosses2Domains (GWLossesBase ):
+
+ View Source
+
+
+
+ 357 class GWLosses2Domains ( GWLossesBase ):
+358 """
+359 Implementation of `GWLossesBase` used for `GWModule`.
+360 """
+361
+362 def __init__ (
+363 self ,
+364 gw_mod : GWModule ,
+365 selection_mod : SelectionBase ,
+366 domain_mods : dict [ str , DomainModule ],
+367 loss_coefs : LossCoefs ,
+368 contrastive_fn : ContrastiveLossType ,
+369 ):
+370 """
+371 Main loss module to use with the GlobalWorkspace
+372
+373 Args:
+374 gw_mod (`GWModule`): the GWModule
+375 selection_mod (`SelectionBase`): selection module
+376 domain_mods (`dict[str, DomainModule]`): a dict where the key is the
+377 domain name and value is the DomainModule
+378 loss_coefs (`LossCoefs`): loss coefficients. LossCoefs object, or a
+379 mapping to float with correct keys.
+380 contrastive_fn (`ContrastiveLossType`): the contrastive function to use
+381 in contrastive loss
+382 """
+383
+384 super () . __init__ ()
+385 self . gw_mod = gw_mod
+386 self . selection_mod = selection_mod
+387 self . domain_mods = domain_mods
+388 self . loss_coefs = loss_coefs
+389 self . contrastive_fn = contrastive_fn
+390
+391 def demi_cycle_loss (
+392 self , latent_domains : LatentsDomainGroupsT
+393 ) -> dict [ str , torch . Tensor ]:
+394 """
+395 Computes the demi-cycle loss.
+396
+397 See `shimmer.modules.losses.demi_cycle_loss`.
+398
+399 Args:
+400 latent_domains (`LatentsDomainGroupsT`): the latent unimodal groups
+401
+402 Returns:
+403 `dict[str, torch.Tensor]`: a dict of metrics.
+404 """
+405 return demi_cycle_loss (
+406 self . gw_mod , self . selection_mod , self . domain_mods , latent_domains
+407 )
+408
+409 def cycle_loss (
+410 self , latent_domains : LatentsDomainGroupsT
+411 ) -> dict [ str , torch . Tensor ]:
+412 """
+413 Computes the cycle loss.
+414
+415 See `shimmer.modules.losses.cycle_loss`.
+416
+417 Args:
+418 latent_domains (`LatentsDomainGroupsT`): the latent unimodal groups
+419
+420 Returns:
+421 `dict[str, torch.Tensor]`: a dict of metrics.
+422 """
+423 return cycle_loss (
+424 self . gw_mod , self . selection_mod , self . domain_mods , latent_domains
+425 )
+426
+427 def translation_loss (
+428 self , latent_domains : LatentsDomainGroupsT
+429 ) -> dict [ str , torch . Tensor ]:
+430 """
+431 Computes the translation loss.
+432
+433 See `shimmer.modules.losses.translation_loss`.
+434
+435 Args:
+436 latent_domains (`LatentsDomainGroupsT`): the latent unimodal groups
+437
+438 Returns:
+439 `dict[str, torch.Tensor]`: a dict of metrics.
+440 """
+441 return translation_loss (
+442 self . gw_mod , self . selection_mod , self . domain_mods , latent_domains
+443 )
+444
+445 def contrastive_loss (
+446 self , latent_domains : LatentsDomainGroupsT
+447 ) -> dict [ str , torch . Tensor ]:
+448 """
+449 Computes the contrastive loss.
+450
+451 See `shimmer.modules.losses.contrastive_loss`.
+452
+453 Args:
+454 latent_domains (`LatentsDomainGroupsT`): the latent unimodal groups
+455
+456 Returns:
+457 `dict[str, torch.Tensor]`: a dict of metrics.
+458 """
+459 return contrastive_loss ( self . gw_mod , latent_domains , self . contrastive_fn )
+460
+461 def step (
+462 self , domain_latents : LatentsDomainGroupsT , mode : ModelModeT
+463 ) -> LossOutput :
+464 """
+465 Computes and returns the losses
+466
+467 Contains:
+468 - Demi-cycle metrics (see `GWLosses.demi_cycle_loss`)
+469 - Cycle metrics (see `GWLosses.cycle_loss`)
+470 - Translation metrics (see `GWLosses.translation_loss`)
+471 - Contrastive metrics (see `GWLosses.contrastive_loss`)
+472
+473 Args:
+474 domain_latents (`LatentsDomainGroupsT`): All latent groups
+475 mode (`ModelModeT`): model mode
+476 Returns:
+477 `LossOutput`: the losses
+478 """
+479 metrics : dict [ str , torch . Tensor ] = {}
+480
+481 metrics . update ( self . demi_cycle_loss ( domain_latents ))
+482 metrics . update ( self . cycle_loss ( domain_latents ))
+483 metrics . update ( self . translation_loss ( domain_latents ))
+484 metrics . update ( self . contrastive_loss ( domain_latents ))
+485
+486 loss = torch . stack (
+487 [
+488 metrics [ name ] * coef
+489 for name , coef in self . loss_coefs . items ()
+490 if isinstance ( coef , float ) and coef > 0
+491 ],
+492 dim = 0 ,
+493 ) . mean ()
+494
+495 return LossOutput ( loss , metrics )
+
+
+
+
+
+
+
+
+
+
+
362 def __init__ (
+363 self ,
+364 gw_mod : GWModule ,
+365 selection_mod : SelectionBase ,
+366 domain_mods : dict [ str , DomainModule ],
+367 loss_coefs : LossCoefs ,
+368 contrastive_fn : ContrastiveLossType ,
+369 ):
+370 """
+371 Main loss module to use with the GlobalWorkspace
+372
+373 Args:
+374 gw_mod (`GWModule`): the GWModule
+375 selection_mod (`SelectionBase`): selection module
+376 domain_mods (`dict[str, DomainModule]`): a dict where the key is the
+377 domain name and value is the DomainModule
+378 loss_coefs (`LossCoefs`): loss coefficients. LossCoefs object, or a
+379 mapping to float with correct keys.
+380 contrastive_fn (`ContrastiveLossType`): the contrastive function to use
+381 in contrastive loss
+382 """
+383
+384 super () . __init__ ()
+385 self . gw_mod = gw_mod
+386 self . selection_mod = selection_mod
+387 self . domain_mods = domain_mods
+388 self . loss_coefs = loss_coefs
+389 self . contrastive_fn = contrastive_fn
+
+
+
+
Main loss module to use with the GlobalWorkspace
+
+
Arguments:
+
+
+gw_mod (GWModule
): the GWModule
+selection_mod (SelectionBase
): selection module
+domain_mods (dict[str, DomainModule]
): a dict where the key is the
+domain name and value is the DomainModule
+loss_coefs (LossCoefs
): loss coefficients. LossCoefs object, or a
+mapping to float with correct keys.
+contrastive_fn (ContrastiveLossType
): the contrastive function to use
+in contrastive loss
+
+
+
+
+
+
+
+ gw_mod
+
+
+
+
+
+
+
+
+
+
+ selection_mod
+
+
+
+
+
+
+
+
+
+
+ domain_mods
+
+
+
+
+
+
+
+
+
+
+ loss_coefs
+
+
+
+
+
+
+
+
+
+
+ contrastive_fn
+
+
+
+
+
+
+
+
+
+
+
+
+ def
+ demi_cycle_loss ( self , latent_domains : collections . abc . Mapping [ frozenset [ str ], collections . abc . Mapping [ str , torch . Tensor ]] ) -> dict [ str , torch . Tensor ] :
+
+ View Source
+
+
+
+
391 def demi_cycle_loss (
+392 self , latent_domains : LatentsDomainGroupsT
+393 ) -> dict [ str , torch . Tensor ]:
+394 """
+395 Computes the demi-cycle loss.
+396
+397 See `shimmer.modules.losses.demi_cycle_loss`.
+398
+399 Args:
+400 latent_domains (`LatentsDomainGroupsT`): the latent unimodal groups
+401
+402 Returns:
+403 `dict[str, torch.Tensor]`: a dict of metrics.
+404 """
+405 return demi_cycle_loss (
+406 self . gw_mod , self . selection_mod , self . domain_mods , latent_domains
+407 )
+
+
+
+
Computes the demi-cycle loss.
+
+
See demi_cycle_loss
.
+
+
Arguments:
+
+
+latent_domains (LatentsDomainGroupsT
): the latent unimodal groups
+
+
+
Returns:
+
+
+ dict[str, torch.Tensor]
: a dict of metrics.
+
+
+
+
+
+
+
+
+
+ def
+ cycle_loss ( self , latent_domains : collections . abc . Mapping [ frozenset [ str ], collections . abc . Mapping [ str , torch . Tensor ]] ) -> dict [ str , torch . Tensor ] :
+
+ View Source
+
+
+
+
409 def cycle_loss (
+410 self , latent_domains : LatentsDomainGroupsT
+411 ) -> dict [ str , torch . Tensor ]:
+412 """
+413 Computes the cycle loss.
+414
+415 See `shimmer.modules.losses.cycle_loss`.
+416
+417 Args:
+418 latent_domains (`LatentsDomainGroupsT`): the latent unimodal groups
+419
+420 Returns:
+421 `dict[str, torch.Tensor]`: a dict of metrics.
+422 """
+423 return cycle_loss (
+424 self . gw_mod , self . selection_mod , self . domain_mods , latent_domains
+425 )
+
+
+
+
Computes the cycle loss.
+
+
See cycle_loss
.
+
+
Arguments:
+
+
+latent_domains (LatentsDomainGroupsT
): the latent unimodal groups
+
+
+
Returns:
+
+
+ dict[str, torch.Tensor]
: a dict of metrics.
+
+
+
+
+
+
+
+
+
+ def
+ translation_loss ( self , latent_domains : collections . abc . Mapping [ frozenset [ str ], collections . abc . Mapping [ str , torch . Tensor ]] ) -> dict [ str , torch . Tensor ] :
+
+ View Source
+
+
+
+
427 def translation_loss (
+428 self , latent_domains : LatentsDomainGroupsT
+429 ) -> dict [ str , torch . Tensor ]:
+430 """
+431 Computes the translation loss.
+432
+433 See `shimmer.modules.losses.translation_loss`.
+434
+435 Args:
+436 latent_domains (`LatentsDomainGroupsT`): the latent unimodal groups
+437
+438 Returns:
+439 `dict[str, torch.Tensor]`: a dict of metrics.
+440 """
+441 return translation_loss (
+442 self . gw_mod , self . selection_mod , self . domain_mods , latent_domains
+443 )
+
+
+
+
Computes the translation loss.
+
+
See translation_loss
.
+
+
Arguments:
+
+
+latent_domains (LatentsDomainGroupsT
): the latent unimodal groups
+
+
+
Returns:
+
+
+ dict[str, torch.Tensor]
: a dict of metrics.
+
+
+
+
+
+
+
+
+
+ def
+ contrastive_loss ( self , latent_domains : collections . abc . Mapping [ frozenset [ str ], collections . abc . Mapping [ str , torch . Tensor ]] ) -> dict [ str , torch . Tensor ] :
+
+ View Source
+
+
+
+
445 def contrastive_loss (
+446 self , latent_domains : LatentsDomainGroupsT
+447 ) -> dict [ str , torch . Tensor ]:
+448 """
+449 Computes the contrastive loss.
+450
+451 See `shimmer.modules.losses.contrastive_loss`.
+452
+453 Args:
+454 latent_domains (`LatentsDomainGroupsT`): the latent unimodal groups
+455
+456 Returns:
+457 `dict[str, torch.Tensor]`: a dict of metrics.
+458 """
+459 return contrastive_loss ( self . gw_mod , latent_domains , self . contrastive_fn )
+
+
+
+
Computes the contrastive loss.
+
+
See contrastive_loss
.
+
+
Arguments:
+
+
+latent_domains (LatentsDomainGroupsT
): the latent unimodal groups
+
+
+
Returns:
+
+
+ dict[str, torch.Tensor]
: a dict of metrics.
+
+
+
+
+
+
+
+
+
+
def
+
step ( self , domain_latents : collections . abc . Mapping [ frozenset [ str ], collections . abc . Mapping [ str , torch . Tensor ]] , mode : Literal [ 'train' , 'val' , 'test' , 'val/ood' , 'test/ood' ] ) -> shimmer.modules.domain.LossOutput :
+
+
View Source
+
+
+
+
461 def step (
+462 self , domain_latents : LatentsDomainGroupsT , mode : ModelModeT
+463 ) -> LossOutput :
+464 """
+465 Computes and returns the losses
+466
+467 Contains:
+468 - Demi-cycle metrics (see `GWLosses.demi_cycle_loss`)
+469 - Cycle metrics (see `GWLosses.cycle_loss`)
+470 - Translation metrics (see `GWLosses.translation_loss`)
+471 - Contrastive metrics (see `GWLosses.contrastive_loss`)
+472
+473 Args:
+474 domain_latents (`LatentsDomainGroupsT`): All latent groups
+475 mode (`ModelModeT`): model mode
+476 Returns:
+477 `LossOutput`: the losses
+478 """
+479 metrics : dict [ str , torch . Tensor ] = {}
+480
+481 metrics . update ( self . demi_cycle_loss ( domain_latents ))
+482 metrics . update ( self . cycle_loss ( domain_latents ))
+483 metrics . update ( self . translation_loss ( domain_latents ))
+484 metrics . update ( self . contrastive_loss ( domain_latents ))
+485
+486 loss = torch . stack (
+487 [
+488 metrics [ name ] * coef
+489 for name , coef in self . loss_coefs . items ()
+490 if isinstance ( coef , float ) and coef > 0
+491 ],
+492 dim = 0 ,
+493 ) . mean ()
+494
+495 return LossOutput ( loss , metrics )
+
+
+
+
Computes and returns the losses
+
+
Contains:
+
+
+
+ Demi-cycle metrics (see GWLosses.demi_cycle_loss
)
+ Cycle metrics (see GWLosses.cycle_loss
)
+ Translation metrics (see GWLosses.translation_loss
)
+ Contrastive metrics (see GWLosses.contrastive_loss
)
+
+
+
+
Arguments:
+
+
+domain_latents (LatentsDomainGroupsT
): All latent groups
+mode (ModelModeT
): model mode
+
+
+
Returns:
+
+
+ LossOutput
: the losses
+
+
+
+
+
+
+
Inherited Members
+
+
torch.nn.modules.module.Module
+ dump_patches
+ training
+ call_super_init
+ forward
+ register_buffer
+ register_parameter
+ add_module
+ register_module
+ get_submodule
+ get_parameter
+ get_buffer
+ get_extra_state
+ set_extra_state
+ apply
+ cuda
+ ipu
+ xpu
+ cpu
+ type
+ float
+ double
+ half
+ bfloat16
+ to_empty
+ to
+ register_full_backward_pre_hook
+ register_backward_hook
+ register_full_backward_hook
+ register_forward_pre_hook
+ register_forward_hook
+ register_state_dict_pre_hook
+ state_dict
+ register_load_state_dict_post_hook
+ load_state_dict
+ parameters
+ named_parameters
+ buffers
+ named_buffers
+ children
+ named_children
+ modules
+ named_modules
+ train
+ eval
+ requires_grad_
+ zero_grad
+ share_memory
+ extra_repr
+ compile
+
+
+
+
+
+
+
+
+
+ def
+ generate_partitions (n : int ) -> collections . abc . Generator [ tuple [ int , ... ], None , None ] :
+
+ View Source
+
+
+
+ 498 def generate_partitions ( n : int ) -> Generator [ tuple [ int , ... ], None , None ]:
+499 """
+500 Generates all possible partitions of zeros and ones for `n` elements,
+501 excluding the all-zeros partition.
+502
+503 Args:
+504 n (`int`): The number of modalities to generate partitions for.
+505
+506 Yields:
+507 `tuple[int, ...]`: A partition of zeros and ones, excluding the
+508 all-zeros partition.
+509 """
+510 for perm in product ([ 0 , 1 ], repeat = n ):
+511 if any ( perm ):
+512 yield perm
+
+
+
+ Generates all possible partitions of zeros and ones for n
elements,
+excluding the all-zeros partition.
+
+
Arguments:
+
+
+n (int
): The number of modalities to generate partitions for.
+
+
+
Yields:
+
+
+ tuple[int, ...]
: A partition of zeros and ones, excluding the
+ all-zeros partition.
+
+
+
+
+
+
+
+
+
+ 515 def broadcast_loss (
+516 gw_mod : GWModuleBase ,
+517 selection_mod : SelectionBase ,
+518 domain_mods : Mapping [ str , DomainModule ],
+519 latent_domains : LatentsDomainGroupsT ,
+520 ) -> dict [ str , torch . Tensor ]:
+521 """
+522 Computes broadcast loss including demi-cycle, cycle, and translation losses.
+523
+524 Args:
+525 gw_mod (`shimmer.modules.gw_module.GWModuleBase`): The GWModule to use
+526 selection_mod (`shimmer.modules.selection.SelectionBase`): Selection mod to use
+527 domain_mods (`Mapping[str, DomainModule]`): the domain modules
+528 latent_domains: The latent domain representations.
+529
+530 Returns:
+531 A dictionary with the total loss and additional metrics.
+532 """
+533 losses : dict [ str , torch . Tensor ] = {}
+534 metrics : dict [ str , torch . Tensor ] = {}
+535
+536 demi_cycle_losses : list [ str ] = []
+537 cycle_losses : list [ str ] = []
+538 translation_losses : list [ str ] = []
+539 fused_losses : list [ str ] = []
+540
+541 for group_domains , latents in latent_domains . items ():
+542 encoded_latents = gw_mod . encode ( latents )
+543 partitions = generate_partitions ( len ( group_domains ))
+544 domain_names = list ( latents )
+545 group_name = "-" . join ( group_domains )
+546
+547 for partition in partitions :
+548 selected_latents = {
+549 domain : latents [ domain ]
+550 for domain , present in zip ( domain_names , partition , strict = True )
+551 if present
+552 }
+553 selected_encoded_latents = {
+554 domain : encoded_latents [ domain ] for domain in selected_latents
+555 }
+556 selected_group_label = "{" + ", " . join ( sorted ( selected_latents )) + "}"
+557
+558 selection_scores = selection_mod ( selected_latents , selected_encoded_latents )
+559 fused_latents = gw_mod . fuse ( selected_encoded_latents , selection_scores )
+560 decoded_latents = gw_mod . decode ( fused_latents )
+561
+562 num_active_domains = sum ( partition )
+563 num_total_domains = len ( decoded_latents )
+564
+565 for domain , pred in decoded_latents . items ():
+566 if domain not in group_domains : # if we don't have ground truth
+567 continue
+568 ground_truth = latents [ domain ]
+569 loss_output = domain_mods [ domain ] . compute_loss ( pred , ground_truth )
+570 loss_label = f "from_ { selected_group_label } _to_ { domain } "
+571 losses [ loss_label + "_loss" ] = loss_output . loss
+572 metrics . update (
+573 { f " { loss_label } _ { k } " : v for k , v in loss_output . metrics . items ()}
+574 )
+575
+576 if num_active_domains == 1 and domain in selected_latents :
+577 demi_cycle_losses . append ( loss_label + "_loss" )
+578 elif domain not in selected_latents :
+579 translation_losses . append ( loss_label + "_loss" )
+580 else : # fused loss
+581 fused_losses . append ( loss_label + "_loss" )
+582
+583 if num_active_domains < num_total_domains :
+584 inverse_selected_latents = {
+585 domain : decoded_latents [ domain ]
+586 for domain in decoded_latents
+587 if domain not in selected_latents
+588 }
+589
+590 inverse_selected_group_label = (
+591 "{" + ", " . join ( sorted ( inverse_selected_latents )) + "}"
+592 )
+593
+594 re_encoded_latents = gw_mod . encode ( inverse_selected_latents )
+595 re_selection_scores = selection_mod (
+596 inverse_selected_latents , re_encoded_latents
+597 )
+598 re_fused_latents = gw_mod . fuse ( re_encoded_latents , re_selection_scores )
+599 re_decoded_latents = gw_mod . decode (
+600 re_fused_latents , domains = selected_latents . keys ()
+601 )
+602
+603 for domain in selected_latents :
+604 re_ground_truth = latents [ domain ]
+605 re_loss_output = domain_mods [ domain ] . compute_loss (
+606 re_decoded_latents [ domain ], re_ground_truth
+607 )
+608 loss_label = (
+609 f "from_ { selected_group_label } _"
+610 f "through_ { inverse_selected_group_label } _to_ { domain } _"
+611 f "case_ { group_name } "
+612 )
+613 losses [ loss_label + "_loss" ] = re_loss_output . loss
+614 metrics . update (
+615 {
+616 f " { loss_label } _ { k } " : v
+617 for k , v in re_loss_output . metrics . items ()
+618 }
+619 )
+620 cycle_losses . append ( loss_label + "_loss" )
+621
+622 if demi_cycle_losses :
+623 metrics [ "demi_cycles" ] = torch . mean (
+624 torch . stack ([ losses [ loss_name ] for loss_name in demi_cycle_losses ])
+625 )
+626 if cycle_losses :
+627 metrics [ "cycles" ] = torch . mean (
+628 torch . stack ([ losses [ loss_name ] for loss_name in cycle_losses ])
+629 )
+630 if translation_losses :
+631 metrics [ "translations" ] = torch . mean (
+632 torch . stack ([ losses [ loss_name ] for loss_name in translation_losses ])
+633 )
+634 if fused_losses :
+635 metrics [ "fused" ] = torch . mean (
+636 torch . stack ([ losses [ loss_name ] for loss_name in fused_losses ])
+637 )
+638
+639 metrics . update ( losses )
+640 return metrics
+
+
+
+ Computes broadcast loss including demi-cycle, cycle, and translation losses.
+
+
Arguments:
+
+
+
+
Returns:
+
+
+ A dictionary with the total loss and additional metrics.
+
+
+
+
+
+
+
+
+
+ class
+ BroadcastLossCoefs (typing.TypedDict ):
+
+ View Source
+
+
+
+ 643 class BroadcastLossCoefs ( TypedDict , total = False ):
+644 """
+645 Dict of loss coefficients used in the GWLossesFusion.
+646
+647 If one is not provided, the coefficient is assumed to be 0 and will not be logged.
+648 If the loss is excplicitely set to 0, it will be logged, but not take part in
+649 the total loss.
+650 """
+651
+652 contrastives : float
+653 """Contrastive loss coefficient."""
+654
+655 fused : float
+656 """fused loss coefficient (encode multiple domains and decode to one of them)."""
+657
+658 demi_cycles : float
+659 """demi_cycles loss coefficient. Demi-cycles are always one-to-one"""
+660
+661 cycles : float
+662 """cycles loss coefficient. Cycles can be many-to-one"""
+663
+664 translations : float
+665 """translation loss coefficient. Translation, like cycles, can be many-to-one."""
+
+
+
+ Dict of loss coefficients used in the GWLossesFusion.
+
+
If one is not provided, the coefficient is assumed to be 0 and will not be logged.
+If the loss is excplicitely set to 0, it will be logged, but not take part in
+the total loss.
+
+
+
+
+
+ contrastives : float
+
+
+
+
+
+
Contrastive loss coefficient.
+
+
+
+
+
+
+ fused : float
+
+
+
+
+
+
fused loss coefficient (encode multiple domains and decode to one of them).
+
+
+
+
+
+
+ demi_cycles : float
+
+
+
+
+
+
demi_cycles loss coefficient. Demi-cycles are always one-to-one
+
+
+
+
+
+
+ cycles : float
+
+
+
+
+
+
cycles loss coefficient. Cycles can be many-to-one
+
+
+
+
+
+
+ translations : float
+
+
+
+
+
+
translation loss coefficient. Translation, like cycles, can be many-to-one.
+
+
+
+
+
+
+
+
+
+ 668 class GWLosses ( GWLossesBase ):
+669 """
+670 Implementation of `GWLossesBase` for fusion-based models.
+671 """
+672
+673 def __init__ (
+674 self ,
+675 gw_mod : GWModule ,
+676 selection_mod : SelectionBase ,
+677 domain_mods : dict [ str , DomainModule ],
+678 loss_coefs : BroadcastLossCoefs ,
+679 contrastive_fn : ContrastiveLossType ,
+680 ):
+681 """
+682 Initializes the loss computation module for a Global Workspace Fusion model.
+683
+684 Args:
+685 gw_mod: The GWModule for the global workspace.
+686 selection_mod: The selection mechanism for the model.
+687 domain_mods: A mapping of domain names to their respective DomainModule.
+688 loss_coefs (`BroadcastLossCoefs`): coefs for the losses
+689 contrastive_fn: The function used for computing contrastive loss.
+690 """
+691 super () . __init__ ()
+692 self . gw_mod = gw_mod
+693 self . selection_mod = selection_mod
+694 self . domain_mods = domain_mods
+695 self . loss_coefs = loss_coefs
+696 self . contrastive_fn = contrastive_fn
+697
+698 def contrastive_loss (
+699 self , latent_domains : LatentsDomainGroupsT
+700 ) -> dict [ str , torch . Tensor ]:
+701 """
+702 Computes the contrastive loss for the given latent domains.
+703
+704 Args:
+705 latent_domains: The latent domain representations.
+706
+707 Returns:
+708 A dictionary of contrastive loss metrics.
+709 """
+710
+711 return contrastive_loss ( self . gw_mod , latent_domains , self . contrastive_fn )
+712
+713 def broadcast_loss (
+714 self , latent_domains : LatentsDomainGroupsT
+715 ) -> dict [ str , torch . Tensor ]:
+716 return broadcast_loss (
+717 self . gw_mod , self . selection_mod , self . domain_mods , latent_domains
+718 )
+719
+720 def step (
+721 self , domain_latents : LatentsDomainGroupsT , mode : ModelModeT
+722 ) -> LossOutput :
+723 """
+724 Performs a step of loss computation.
+725
+726 Args:
+727 domain_latents: Latent representations for all domains.
+728 mode: The mode in which the model is currently operating.
+729
+730 Returns:
+731 A LossOutput object containing the loss and metrics for this step.
+732 """
+733
+734 metrics : dict [ str , torch . Tensor ] = {}
+735
+736 metrics . update ( self . contrastive_loss ( domain_latents ))
+737 metrics . update ( self . broadcast_loss ( domain_latents ))
+738
+739 loss = torch . stack (
+740 [
+741 metrics [ name ] * coef
+742 for name , coef in self . loss_coefs . items ()
+743 if isinstance ( coef , float ) and coef > 0
+744 ],
+745 dim = 0 ,
+746 ) . mean ()
+747
+748 metrics [ "broadcast_loss" ] = torch . stack (
+749 [
+750 metrics [ name ]
+751 for name , coef in self . loss_coefs . items ()
+752 if isinstance ( coef , float ) and coef > 0 and name != "contrastives"
+753 ],
+754 dim = 0 ,
+755 ) . mean ()
+756
+757 return LossOutput ( loss , metrics )
+
+
+
+
+
+
+
+
+
+
+
673 def __init__ (
+674 self ,
+675 gw_mod : GWModule ,
+676 selection_mod : SelectionBase ,
+677 domain_mods : dict [ str , DomainModule ],
+678 loss_coefs : BroadcastLossCoefs ,
+679 contrastive_fn : ContrastiveLossType ,
+680 ):
+681 """
+682 Initializes the loss computation module for a Global Workspace Fusion model.
+683
+684 Args:
+685 gw_mod: The GWModule for the global workspace.
+686 selection_mod: The selection mechanism for the model.
+687 domain_mods: A mapping of domain names to their respective DomainModule.
+688 loss_coefs (`BroadcastLossCoefs`): coefs for the losses
+689 contrastive_fn: The function used for computing contrastive loss.
+690 """
+691 super () . __init__ ()
+692 self . gw_mod = gw_mod
+693 self . selection_mod = selection_mod
+694 self . domain_mods = domain_mods
+695 self . loss_coefs = loss_coefs
+696 self . contrastive_fn = contrastive_fn
+
+
+
+
Initializes the loss computation module for a Global Workspace Fusion model.
+
+
Arguments:
+
+
+gw_mod: The GWModule for the global workspace.
+selection_mod: The selection mechanism for the model.
+domain_mods: A mapping of domain names to their respective DomainModule.
+loss_coefs (BroadcastLossCoefs
): coefs for the losses
+contrastive_fn: The function used for computing contrastive loss.
+
+
+
+
+
+
+
+ gw_mod
+
+
+
+
+
+
+
+
+
+
+ selection_mod
+
+
+
+
+
+
+
+
+
+
+ domain_mods
+
+
+
+
+
+
+
+
+
+
+ loss_coefs
+
+
+
+
+
+
+
+
+
+
+ contrastive_fn
+
+
+
+
+
+
+
+
+
+
+
+
+ def
+ contrastive_loss ( self , latent_domains : collections . abc . Mapping [ frozenset [ str ], collections . abc . Mapping [ str , torch . Tensor ]] ) -> dict [ str , torch . Tensor ] :
+
+ View Source
+
+
+
+
698 def contrastive_loss (
+699 self , latent_domains : LatentsDomainGroupsT
+700 ) -> dict [ str , torch . Tensor ]:
+701 """
+702 Computes the contrastive loss for the given latent domains.
+703
+704 Args:
+705 latent_domains: The latent domain representations.
+706
+707 Returns:
+708 A dictionary of contrastive loss metrics.
+709 """
+710
+711 return contrastive_loss ( self . gw_mod , latent_domains , self . contrastive_fn )
+
+
+
+
Computes the contrastive loss for the given latent domains.
+
+
Arguments:
+
+
+latent_domains: The latent domain representations.
+
+
+
Returns:
+
+
+ A dictionary of contrastive loss metrics.
+
+
+
+
+
+
+
+
+
+ def
+ broadcast_loss ( self , latent_domains : collections . abc . Mapping [ frozenset [ str ], collections . abc . Mapping [ str , torch . Tensor ]] ) -> dict [ str , torch . Tensor ] :
+
+ View Source
+
+
+
+
713 def broadcast_loss (
+714 self , latent_domains : LatentsDomainGroupsT
+715 ) -> dict [ str , torch . Tensor ]:
+716 return broadcast_loss (
+717 self . gw_mod , self . selection_mod , self . domain_mods , latent_domains
+718 )
+
+
+
+
+
+
+
+
+
+
+
def
+
step ( self , domain_latents : collections . abc . Mapping [ frozenset [ str ], collections . abc . Mapping [ str , torch . Tensor ]] , mode : Literal [ 'train' , 'val' , 'test' , 'val/ood' , 'test/ood' ] ) -> shimmer.modules.domain.LossOutput :
+
+
View Source
+
+
+
+
720 def step (
+721 self , domain_latents : LatentsDomainGroupsT , mode : ModelModeT
+722 ) -> LossOutput :
+723 """
+724 Performs a step of loss computation.
+725
+726 Args:
+727 domain_latents: Latent representations for all domains.
+728 mode: The mode in which the model is currently operating.
+729
+730 Returns:
+731 A LossOutput object containing the loss and metrics for this step.
+732 """
+733
+734 metrics : dict [ str , torch . Tensor ] = {}
+735
+736 metrics . update ( self . contrastive_loss ( domain_latents ))
+737 metrics . update ( self . broadcast_loss ( domain_latents ))
+738
+739 loss = torch . stack (
+740 [
+741 metrics [ name ] * coef
+742 for name , coef in self . loss_coefs . items ()
+743 if isinstance ( coef , float ) and coef > 0
+744 ],
+745 dim = 0 ,
+746 ) . mean ()
+747
+748 metrics [ "broadcast_loss" ] = torch . stack (
+749 [
+750 metrics [ name ]
+751 for name , coef in self . loss_coefs . items ()
+752 if isinstance ( coef , float ) and coef > 0 and name != "contrastives"
+753 ],
+754 dim = 0 ,
+755 ) . mean ()
+756
+757 return LossOutput ( loss , metrics )
+
+
+
+
Performs a step of loss computation.
+
+
Arguments:
+
+
+domain_latents: Latent representations for all domains.
+mode: The mode in which the model is currently operating.
+
+
+
Returns:
+
+
+ A LossOutput object containing the loss and metrics for this step.
+
+
+
+
+
+
+
Inherited Members
+
+
torch.nn.modules.module.Module
+ dump_patches
+ training
+ call_super_init
+ forward
+ register_buffer
+ register_parameter
+ add_module
+ register_module
+ get_submodule
+ get_parameter
+ get_buffer
+
+
+ apply
+ cuda
+ ipu
+ xpu
+ cpu
+ type
+ float
+ double
+ half
+ bfloat16
+ to_empty
+ to
+ register_full_backward_pre_hook
+ register_backward_hook
+ register_full_backward_hook
+ register_forward_pre_hook
+ register_forward_hook
+ register_state_dict_pre_hook
+ state_dict
+ register_load_state_dict_post_hook
+ load_state_dict
+ parameters
+ named_parameters
+ buffers
+ named_buffers
+ children
+ named_children
+ modules
+ named_modules
+ train
+ eval
+ requires_grad_
+ zero_grad
+ share_memory
+
+ compile
+
+
+
+
+
+
+
+
+
+
class
+
GWLossesBayesian (GWLossesBase ):
+
+ View Source
+
+
+
+ 760 class GWLossesBayesian ( GWLossesBase ):
+761 """
+762 Implementation of `GWLossesBase` used for `GWModuleBayesian`.
+763 """
+764
+765 def __init__ (
+766 self ,
+767 gw_mod : GWModuleBayesian ,
+768 selection_mod : SelectionBase ,
+769 domain_mods : dict [ str , DomainModule ],
+770 loss_coefs : BroadcastLossCoefs ,
+771 contrastive_fn : ContrastiveLossType ,
+772 use_normalized_constrastive : bool = True ,
+773 ):
+774 """
+775 Loss module with uncertainty prediction to use with the GlobalWorkspaceBayesian
+776
+777 Args:
+778 gw_mod (`GWModuleBayesian`): the GWModule
+779 selection_mod (`SelectionBase`): selection module
+780 domain_mods (`dict[str, DomainModule]`): a dict where the key is the
+781 domain name and value is the DomainModule
+782 loss_coefs (`BroadcastLossCoefs`): loss coefficients
+783 contrastive_fn (`ContrastiveLossType`): the contrastive function
+784 to use in contrastive loss
+785 use_normalized_constrastive (`bool`): whether to use the normalized cont
+786 loss by the precision coefs
+787 """
+788 super () . __init__ ()
+789
+790 self . gw_mod = gw_mod
+791 """The GWModule."""
+792
+793 self . selection_mod = selection_mod
+794 """Selection module"""
+795
+796 self . domain_mods = domain_mods
+797 """Domain modules linked to the GW."""
+798
+799 self . loss_coefs = loss_coefs
+800 """The loss coefficients."""
+801
+802 self . contrastive_fn = contrastive_fn
+803 """
+804 Contrastive loss to use.
+805 """
+806
+807 self . use_normalized_constrastive = use_normalized_constrastive
+808
+809 def contrastive_loss (
+810 self , latent_domains : LatentsDomainGroupsT
+811 ) -> dict [ str , torch . Tensor ]:
+812 """
+813 Contrastive loss.
+814
+815 Args:
+816 latent_domains (`LatentsDomainGroupsT`): the latent unimodal groups
+817
+818 Returns:
+819 `dict[str, torch.Tensor]`: a dict of metrics.
+820 """
+821 if self . use_normalized_constrastive :
+822 return contrastive_loss_bayesian (
+823 self . gw_mod , latent_domains , self . contrastive_fn
+824 )
+825 return contrastive_loss ( self . gw_mod , latent_domains , self . contrastive_fn )
+826
+827 def broadcast_loss (
+828 self , latent_domains : LatentsDomainGroupsT
+829 ) -> dict [ str , torch . Tensor ]:
+830 return broadcast_loss (
+831 self . gw_mod , self . selection_mod , self . domain_mods , latent_domains
+832 )
+833
+834 def step (
+835 self , domain_latents : LatentsDomainGroupsT , mode : ModelModeT
+836 ) -> LossOutput :
+837 """
+838 Performs a step of loss computation.
+839
+840 Args:
+841 domain_latents: Latent representations for all domains.
+842 mode: The mode in which the model is currently operating.
+843
+844 Returns:
+845 A LossOutput object containing the loss and metrics for this step.
+846 """
+847
+848 metrics : dict [ str , torch . Tensor ] = {}
+849
+850 metrics . update ( self . contrastive_loss ( domain_latents ))
+851 metrics . update ( self . broadcast_loss ( domain_latents ))
+852
+853 loss = torch . stack (
+854 [
+855 metrics [ name ] * coef
+856 for name , coef in self . loss_coefs . items ()
+857 if isinstance ( coef , float ) and coef > 0
+858 ],
+859 dim = 0 ,
+860 ) . mean ()
+861
+862 metrics [ "broadcast_loss" ] = torch . stack (
+863 [
+864 metrics [ name ]
+865 for name , coef in self . loss_coefs . items ()
+866 if isinstance ( coef , float ) and coef > 0 and name != "contrastives"
+867 ],
+868 dim = 0 ,
+869 ) . mean ()
+870
+871 return LossOutput ( loss , metrics )
+
+
+
+
+
+
+
+
+
+
+
765 def __init__ (
+766 self ,
+767 gw_mod : GWModuleBayesian ,
+768 selection_mod : SelectionBase ,
+769 domain_mods : dict [ str , DomainModule ],
+770 loss_coefs : BroadcastLossCoefs ,
+771 contrastive_fn : ContrastiveLossType ,
+772 use_normalized_constrastive : bool = True ,
+773 ):
+774 """
+775 Loss module with uncertainty prediction to use with the GlobalWorkspaceBayesian
+776
+777 Args:
+778 gw_mod (`GWModuleBayesian`): the GWModule
+779 selection_mod (`SelectionBase`): selection module
+780 domain_mods (`dict[str, DomainModule]`): a dict where the key is the
+781 domain name and value is the DomainModule
+782 loss_coefs (`BroadcastLossCoefs`): loss coefficients
+783 contrastive_fn (`ContrastiveLossType`): the contrastive function
+784 to use in contrastive loss
+785 use_normalized_constrastive (`bool`): whether to use the normalized cont
+786 loss by the precision coefs
+787 """
+788 super () . __init__ ()
+789
+790 self . gw_mod = gw_mod
+791 """The GWModule."""
+792
+793 self . selection_mod = selection_mod
+794 """Selection module"""
+795
+796 self . domain_mods = domain_mods
+797 """Domain modules linked to the GW."""
+798
+799 self . loss_coefs = loss_coefs
+800 """The loss coefficients."""
+801
+802 self . contrastive_fn = contrastive_fn
+803 """
+804 Contrastive loss to use.
+805 """
+806
+807 self . use_normalized_constrastive = use_normalized_constrastive
+
+
+
+
Loss module with uncertainty prediction to use with the GlobalWorkspaceBayesian
+
+
Arguments:
+
+
+gw_mod (GWModuleBayesian
): the GWModule
+selection_mod (SelectionBase
): selection module
+domain_mods (dict[str, DomainModule]
): a dict where the key is the
+domain name and value is the DomainModule
+loss_coefs (BroadcastLossCoefs
): loss coefficients
+contrastive_fn (ContrastiveLossType
): the contrastive function
+to use in contrastive loss
+use_normalized_constrastive (bool
): whether to use the normalized cont
+loss by the precision coefs
+
+
+
+
+
+
+
+ gw_mod
+
+
+
+
+
+
+
+
+
+
+
+ selection_mod
+
+
+
+
+
+
+
+
+
+
+
+ domain_mods
+
+
+
+
+
+
Domain modules linked to the GW.
+
+
+
+
+
+
+ loss_coefs
+
+
+
+
+
+
+
+
+
+
+
+ contrastive_fn
+
+
+
+
+
+
Contrastive loss to use.
+
+
+
+
+
+
+ use_normalized_constrastive
+
+
+
+
+
+
+
+
+
+
+
+
+ def
+ contrastive_loss ( self , latent_domains : collections . abc . Mapping [ frozenset [ str ], collections . abc . Mapping [ str , torch . Tensor ]] ) -> dict [ str , torch . Tensor ] :
+
+ View Source
+
+
+
+
809 def contrastive_loss (
+810 self , latent_domains : LatentsDomainGroupsT
+811 ) -> dict [ str , torch . Tensor ]:
+812 """
+813 Contrastive loss.
+814
+815 Args:
+816 latent_domains (`LatentsDomainGroupsT`): the latent unimodal groups
+817
+818 Returns:
+819 `dict[str, torch.Tensor]`: a dict of metrics.
+820 """
+821 if self . use_normalized_constrastive :
+822 return contrastive_loss_bayesian (
+823 self . gw_mod , latent_domains , self . contrastive_fn
+824 )
+825 return contrastive_loss ( self . gw_mod , latent_domains , self . contrastive_fn )
+
+
+
+
Contrastive loss.
+
+
Arguments:
+
+
+latent_domains (LatentsDomainGroupsT
): the latent unimodal groups
+
+
+
Returns:
+
+
+ dict[str, torch.Tensor]
: a dict of metrics.
+
+
+
+
+
+
+
+
+
+ def
+ broadcast_loss ( self , latent_domains : collections . abc . Mapping [ frozenset [ str ], collections . abc . Mapping [ str , torch . Tensor ]] ) -> dict [ str , torch . Tensor ] :
+
+ View Source
+
+
+
+
827 def broadcast_loss (
+828 self , latent_domains : LatentsDomainGroupsT
+829 ) -> dict [ str , torch . Tensor ]:
+830 return broadcast_loss (
+831 self . gw_mod , self . selection_mod , self . domain_mods , latent_domains
+832 )
+
+
+
+
+
+
+
+
+
+
+
def
+
step ( self , domain_latents : collections . abc . Mapping [ frozenset [ str ], collections . abc . Mapping [ str , torch . Tensor ]] , mode : Literal [ 'train' , 'val' , 'test' , 'val/ood' , 'test/ood' ] ) -> shimmer.modules.domain.LossOutput :
+
+
View Source
+
+
+
+
834 def step (
+835 self , domain_latents : LatentsDomainGroupsT , mode : ModelModeT
+836 ) -> LossOutput :
+837 """
+838 Performs a step of loss computation.
+839
+840 Args:
+841 domain_latents: Latent representations for all domains.
+842 mode: The mode in which the model is currently operating.
+843
+844 Returns:
+845 A LossOutput object containing the loss and metrics for this step.
+846 """
+847
+848 metrics : dict [ str , torch . Tensor ] = {}
+849
+850 metrics . update ( self . contrastive_loss ( domain_latents ))
+851 metrics . update ( self . broadcast_loss ( domain_latents ))
+852
+853 loss = torch . stack (
+854 [
+855 metrics [ name ] * coef
+856 for name , coef in self . loss_coefs . items ()
+857 if isinstance ( coef , float ) and coef > 0
+858 ],
+859 dim = 0 ,
+860 ) . mean ()
+861
+862 metrics [ "broadcast_loss" ] = torch . stack (
+863 [
+864 metrics [ name ]
+865 for name , coef in self . loss_coefs . items ()
+866 if isinstance ( coef , float ) and coef > 0 and name != "contrastives"
+867 ],
+868 dim = 0 ,
+869 ) . mean ()
+870
+871 return LossOutput ( loss , metrics )
+
+
+
+
Performs a step of loss computation.
+
+
Arguments:
+
+
+domain_latents: Latent representations for all domains.
+mode: The mode in which the model is currently operating.
+
+
+
Returns:
+
+
+ A LossOutput object containing the loss and metrics for this step.
+
+
+
+
+
+
+
Inherited Members
+
+
torch.nn.modules.module.Module
+ dump_patches
+ training
+ call_super_init
+ forward
+ register_buffer
+ register_parameter
+ add_module
+ register_module
+ get_submodule
+ get_parameter
+ get_buffer
+
+
+ apply
+ cuda
+ ipu
+ xpu
+ cpu
+ type
+ float
+ double
+ half
+ bfloat16
+ to_empty
+ to
+ register_full_backward_pre_hook
+ register_backward_hook
+ register_full_backward_hook
+ register_forward_pre_hook
+ register_forward_hook
+ register_state_dict_pre_hook
+ state_dict
+ register_load_state_dict_post_hook
+ load_state_dict
+ parameters
+ named_parameters
+ buffers
+ named_buffers
+ children
+ named_children
+ modules
+ named_modules
+ train
+ eval
+ requires_grad_
+ zero_grad
+ share_memory
+
+ compile
+
+
+
+
+
+
+
+
\ No newline at end of file
diff --git a/docs/api/v0.5.1/shimmer/modules/selection.html b/docs/api/v0.5.1/shimmer/modules/selection.html
new file mode 100644
index 00000000..6a4bb9c0
--- /dev/null
+++ b/docs/api/v0.5.1/shimmer/modules/selection.html
@@ -0,0 +1,2151 @@
+
+
+
+
+
+
+ shimmer.modules.selection API documentation
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+shimmer.modules.selection
+
+
+
+
+ View Source
+
+ 1 from abc import ABC , abstractmethod
+ 2 from collections.abc import Iterable
+ 3
+ 4 import torch
+ 5 import torch.nn as nn
+ 6
+ 7 from shimmer.types import LatentsDomainGroupT
+ 8 from shimmer.utils import group_batch_size , group_device
+ 9
+ 10
+ 11 class SelectionBase ( torch . nn . Module , ABC ):
+ 12 """
+ 13 This is the base class for the selection mechanism.
+ 14 The selection mechanisms handles the "competition" between modules and *selects*
+ 15 fusion coefficients for the domains.
+ 16 """
+ 17
+ 18 def update_gw_state ( self , gw_state : torch . Tensor ) -> None :
+ 19 """
+ 20 Update the internal copy of the previous GW state.
+ 21 By default, this is not implemented and will raise an error if used.
+ 22
+ 23 :note..
+ 24 This is not defined as an abstractmethod as some selection method may
+ 25 not need it.
+ 26
+ 27 Args:
+ 28 gw_state (`torch.Tensor`): the previous GW state
+ 29 """
+ 30 pass
+ 31
+ 32 @abstractmethod
+ 33 def forward (
+ 34 self , domains : LatentsDomainGroupT , encodings_pre_fusion : LatentsDomainGroupT
+ 35 ) -> dict [ str , torch . Tensor ]:
+ 36 """
+ 37 Forward pass of the selection method.
+ 38
+ 39 Args:
+ 40 domains (`LatentsDomainGroupT`): Group of unimodal latent representations.
+ 41
+ 42 Returns:
+ 43 `dict[str, torch.Tensor]`: for each domain in the group, the fusion
+ 44 coefficient for each item in the batch.
+ 45
+ 46 Example:
+ 47 >>> SomeSelectionImplementation().forward(
+ 48 ... {"v": torch.randn(3, 4), "t": torch.randn(3, 8)}
+ 49 ... )
+ 50 {"v": torch.Tensor([0.0, 0.4, 1.0]), "t": torch.Tensor([1.0, 0.6, 0.0])}
+ 51 """
+ 52 ...
+ 53
+ 54 # This is just for proper auto-completion...
+ 55 def __call__ (
+ 56 self , domains : LatentsDomainGroupT , encodings_pre_fusion : LatentsDomainGroupT
+ 57 ) -> dict [ str , torch . Tensor ]:
+ 58 return super () . __call__ ( domains , encodings_pre_fusion )
+ 59
+ 60
+ 61 class SingleDomainSelection ( SelectionBase ):
+ 62 """
+ 63 This selection mechanism handles groups that can have multiple domains, but always
+ 64 return a selection of 1 domain from the group with a uniform distribution.
+ 65
+ 66 For example, if the group has 2 domains, there is a 50% chance of selecting each
+ 67 domain.
+ 68 """
+ 69
+ 70 def forward (
+ 71 self , domains : LatentsDomainGroupT , encodings_pre_fusion : LatentsDomainGroupT
+ 72 ) -> dict [ str , torch . Tensor ]:
+ 73 """
+ 74 Forward pass of the module.
+ 75
+ 76 Args:
+ 77 domains (`LatentsDomainGroupT`): input unimodal latent representations
+ 78 gw_state (`torch.Tensor`): the previous GW state
+ 79
+ 80 Returns:
+ 81 `dict[str, torch.Tensor]`: whether the domain is selected for each input
+ 82 in the batch.
+ 83 """
+ 84 selection : dict [ str , torch . Tensor ] = {}
+ 85 bs = group_batch_size ( domains )
+ 86 choice = torch . randint ( len ( domains ), size = ( bs ,), device = group_device ( domains ))
+ 87 for k , domain in enumerate ( domains . keys ()):
+ 88 selection [ domain ] = ( choice == k ) . to ( torch . float32 )
+ 89 return selection
+ 90
+ 91
+ 92 class FixedSharedSelection ( SelectionBase ):
+ 93 """
+ 94 This selection mechanism is deterministic and always shares the weights equally
+ 95 between domains.
+ 96
+ 97 For example, if 2 domains, it gives 0.5 for each; 3 domains, 1/3 for each...
+ 98 """
+ 99
+100 def forward (
+101 self , domains : LatentsDomainGroupT , encodings_pre_fusion : LatentsDomainGroupT
+102 ) -> dict [ str , torch . Tensor ]:
+103 """
+104 Forward pass of the module.
+105
+106 Args:
+107 domains (`LatentsDomainGroupT`): input unimodal latent representations
+108 gw_state (`torch.Tensor`): the previous GW state
+109
+110 Returns:
+111 `dict[str, torch.Tensor]`: whether the domain is selected for each input
+112 in the batch.
+113 """
+114 selection : dict [ str , torch . Tensor ] = {}
+115 bs = group_batch_size ( domains )
+116 coef = torch . full (( bs ,), 1.0 / len ( domains ), device = group_device ( domains ))
+117 for domain in domains :
+118 selection [ domain ] = coef . clone ()
+119 return selection
+120
+121
+122 def _calculate_attention_dict (
+123 domains : LatentsDomainGroupT ,
+124 keys : dict [ str , torch . Tensor ],
+125 query : torch . Tensor ,
+126 ) -> dict [ str , torch . Tensor ]:
+127 """
+128 Args:
+129 domains (`LatentsDomainGroupT`): Group of unimodal latent representations.
+130 keys (`dict[str, torch.Tensor]`): The keys for each domain.
+131 query (`torch.Tensor`): The query tensor.
+132
+133 Returns:
+134 `dict[str, torch.Tensor]`: The attention scores for each domain.
+135 """
+136 dot_products = {
+137 domain : torch . bmm ( key . unsqueeze ( 1 ), query . unsqueeze ( 2 )) . squeeze ()
+138 for domain , key in keys . items ()
+139 }
+140
+141 dot_products_tensor = torch . stack ( list ( dot_products . values ()), dim = 1 )
+142
+143 attention_scores = torch . softmax ( dot_products_tensor , dim = 1 )
+144
+145 attention_dict = {
+146 domain : attention_scores [:, i ] for i , domain in enumerate ( domains )
+147 }
+148 return attention_dict
+149
+150
+151 class KQFixedQSelection ( SelectionBase ):
+152 """
+153 Key-Query attention with a fixed gw vector.
+154 """
+155
+156 def __init__ ( self , head_size : int , domain_dim : int , domain_names : Iterable [ str ]):
+157 """
+158 Args:
+159 head_size (`int`) : dimension of the key and query vectors.
+160 domain_dim (`int`) : dimension of the input dims (assumed to be the same
+161 for now)
+162 domain_names (`Iterable[str]`) : list of input domains
+163 """
+164 super () . __init__ ()
+165 self . head_size = head_size
+166 self . query_layer = nn . Linear ( domain_dim , head_size )
+167 self . key_layers = nn . ModuleDict (
+168 { domain : nn . Linear ( domain_dim , head_size ) for domain in domain_names }
+169 )
+170 # Start with a random gw state
+171 self . register_buffer ( "initial_gw_state" , torch . rand ( domain_dim ))
+172
+173 def forward (
+174 self , domains : LatentsDomainGroupT , encodings_pre_fusion : LatentsDomainGroupT
+175 ) -> dict [ str , torch . Tensor ]:
+176 """
+177 Compute keys and queries, match them with dot product and softmax.
+178 Does this twice, once with the static query and once with a dynamic query.
+179
+180 Args:
+181 domains (`LatentsDomainGroupT`): Group of unimodal latent representations.
+182 encodings (`LatentsDomainGroupT`): Group of pre-fusion encodings.
+183
+184 Returns:
+185 `dict[str, torch.Tensor]`: the attention scores for each domain in the
+186 group.
+187 """
+188
+189 keys = {
+190 domain : self . key_layers [ domain ]( encoding )
+191 for domain , encoding in domains . items ()
+192 }
+193
+194 batch_size = group_batch_size ( domains )
+195
+196 # Retrieve random query
+197 query = self . query_layer ( self . initial_gw_state . expand ( batch_size , - 1 ))
+198
+199 # Calculate the attention scores
+200 return _calculate_attention_dict ( domains , keys , query )
+201
+202
+203 class RandomSelection ( SelectionBase ):
+204 """
+205 Modified random attention to only utilize uniform-softmax scores across modalities.
+206 This version omits the binary scaling factors and focuses on generating attention
+207 coefficients using a uniform distribution followed by a domain-wise softmax.
+208 """
+209
+210 def __init__ ( self , temperature : float ):
+211 """
+212 Args:
+213 temperature (`float`): Temperature of the softmax applied to uniform
+214 scaling factors.
+215 """
+216 super () . __init__ ()
+217 self . temperature = temperature
+218
+219 def forward (
+220 self , domains : LatentsDomainGroupT , encodings_pre_fusion : LatentsDomainGroupT
+221 ) -> dict [ str , torch . Tensor ]:
+222 """
+223 Generate uniform-then-domain-wise-softmaxed samples for each domain.
+224
+225 Args:
+226 domains (`LatentsDomainGroupT`): Group of unimodal latent representations.
+227 This is not used in the function directly but determines the structure
+228 of the returned attention coefficients.
+229
+230 Returns:
+231 `dict[str, torch.Tensor]`: For each domain in the group, the fusion
+232 coefficient for each item in the batch, based solely on
+233 uniform-softmax scores.
+234 """
+235 num_domains = len ( domains )
+236 batch_size = group_batch_size ( domains )
+237 device = group_device ( domains )
+238
+239 # Generate uniform scores
+240 uniform_scores = torch . rand ( batch_size , num_domains , device = device )
+241
+242 # Apply softmax across domains with temperature scaling
+243 softmax_scores = torch . softmax ( uniform_scores / self . temperature , dim = 1 )
+244 # Create attention dictionary for each domain
+245 attention_dict = {
+246 domain : softmax_scores [:, i ] for i , domain in enumerate ( domains )
+247 }
+248
+249 return attention_dict
+250
+251
+252 class DynamicQueryAttention ( SelectionBase ):
+253 """
+254 Key-Query attention with a dynamic gw vector.
+255 The query is updated based on the scaled gw vector.
+256 """
+257
+258 def __init__ ( self , head_size : int , domain_dim : int , domain_names : Iterable [ str ]):
+259 """
+260 Args:
+261 head_size (`int`) : dimension of the key and query vectors.
+262 domain_dim (`int`) : dimension of the input dims (assumed to be the same
+263 for now)
+264 domain_names (`Iterable[str]`) : list of input domains
+265 """
+266 super () . __init__ ()
+267 self . head_size = head_size
+268 self . query_layer = nn . Linear ( domain_dim , head_size )
+269 self . key_layers = nn . ModuleDict (
+270 { domain : nn . Linear ( domain_dim , head_size ) for domain in domain_names }
+271 )
+272 # Start with a random gw state
+273 self . register_buffer ( "initial_gw_state" , torch . rand ( domain_dim ))
+274
+275 def fuse_weighted_encodings (
+276 self , encodings : LatentsDomainGroupT , attention_dict : dict [ str , torch . Tensor ]
+277 ) -> torch . Tensor :
+278 """
+279 Fuse the weighted encodings using the attention scores.
+280
+281 Args:
+282 encodings (`LatentsDomainGroupT`): Unimodal latent representation
+283 attention_dict (`dict[str, torch.Tensor]`): The attention scores for each
+284 domain in the group.
+285
+286 Returns:
+287 `torch.Tensor`: The fused tensor.
+288 """
+289 # Apply attention scores to the encodings
+290 weighted_encodings = {}
+291 for key in attention_dict :
+292 if key in encodings :
+293 # Perform element-wise multiplication
+294 weighted_encodings [ key ] = (
+295 attention_dict [ key ] . unsqueeze ( 1 ) * encodings [ key ]
+296 )
+297
+298 # Stack the tensors along a new dimension (dimension 0)
+299 stacked_tensors = torch . stack ( list ( weighted_encodings . values ()))
+300
+301 # Apply fusion by summing along the newly created dimension
+302 summed_tensor = torch . sum ( stacked_tensors , dim = 0 )
+303 return summed_tensor
+304
+305 def forward (
+306 self , domains : LatentsDomainGroupT , encodings_pre_fusion : LatentsDomainGroupT
+307 ) -> dict [ str , torch . Tensor ]:
+308 """
+309 Compute keys and queries, match them with dot product and softmax.
+310 Does this twice, once with the static query and once with a dynamic query.
+311
+312 Args:
+313 domains (`LatentsDomainGroupT`): Group of unimodal latent representations.
+314 encodings (`LatentsDomainGroupT`): Group of pre-fusion encodings.
+315
+316 Returns:
+317 `dict[str, torch.Tensor]`: the attention scores for each domain in the
+318 group.
+319 """
+320
+321 keys = {
+322 domain : self . key_layers [ domain ]( encoding )
+323 for domain , encoding in domains . items ()
+324 }
+325
+326 batch_size = group_batch_size ( domains )
+327
+328 # Retrieve random query
+329 query = self . query_layer ( self . initial_gw_state . expand ( batch_size , - 1 ))
+330
+331 # Calculate the attention scores
+332 static_attention_dict = _calculate_attention_dict ( domains , keys , query )
+333
+334 # Apply the attention scores to the encodings
+335 summed_tensor = self . fuse_weighted_encodings (
+336 encodings_pre_fusion , static_attention_dict
+337 )
+338
+339 # Retrieve query (now it is dependent on the new gw state)
+340 query = self . query_layer ( summed_tensor )
+341
+342 # Calculate the attention scores again
+343 dynamic_attention_dict = _calculate_attention_dict ( domains , keys , query )
+344
+345 return dynamic_attention_dict
+
+
+
+
+
+
+
+
+ class
+ SelectionBase (torch.nn.modules.module.Module , abc.ABC ):
+
+ View Source
+
+
+
+ 12 class SelectionBase ( torch . nn . Module , ABC ):
+13 """
+14 This is the base class for the selection mechanism.
+15 The selection mechanisms handles the "competition" between modules and *selects*
+16 fusion coefficients for the domains.
+17 """
+18
+19 def update_gw_state ( self , gw_state : torch . Tensor ) -> None :
+20 """
+21 Update the internal copy of the previous GW state.
+22 By default, this is not implemented and will raise an error if used.
+23
+24 :note..
+25 This is not defined as an abstractmethod as some selection method may
+26 not need it.
+27
+28 Args:
+29 gw_state (`torch.Tensor`): the previous GW state
+30 """
+31 pass
+32
+33 @abstractmethod
+34 def forward (
+35 self , domains : LatentsDomainGroupT , encodings_pre_fusion : LatentsDomainGroupT
+36 ) -> dict [ str , torch . Tensor ]:
+37 """
+38 Forward pass of the selection method.
+39
+40 Args:
+41 domains (`LatentsDomainGroupT`): Group of unimodal latent representations.
+42
+43 Returns:
+44 `dict[str, torch.Tensor]`: for each domain in the group, the fusion
+45 coefficient for each item in the batch.
+46
+47 Example:
+48 >>> SomeSelectionImplementation().forward(
+49 ... {"v": torch.randn(3, 4), "t": torch.randn(3, 8)}
+50 ... )
+51 {"v": torch.Tensor([0.0, 0.4, 1.0]), "t": torch.Tensor([1.0, 0.6, 0.0])}
+52 """
+53 ...
+54
+55 # This is just for proper auto-completion...
+56 def __call__ (
+57 self , domains : LatentsDomainGroupT , encodings_pre_fusion : LatentsDomainGroupT
+58 ) -> dict [ str , torch . Tensor ]:
+59 return super () . __call__ ( domains , encodings_pre_fusion )
+
+
+
 This is the base class for the selection mechanism.
+The selection mechanism handles the "competition" between modules and selects
+fusion coefficients for the domains.
+
+
+
+
+
+
+
+ def
+ update_gw_state (self , gw_state : torch . Tensor ) -> None :
+
+ View Source
+
+
+
+
19 def update_gw_state ( self , gw_state : torch . Tensor ) -> None :
+20 """
+21 Update the internal copy of the previous GW state.
+22 By default, this is not implemented and will raise an error if used.
+23
+24 :note..
+25 This is not defined as an abstractmethod as some selection method may
+26 not need it.
+27
+28 Args:
+29 gw_state (`torch.Tensor`): the previous GW state
+30 """
+31 pass
+
+
+
+
Update the internal copy of the previous GW state.
+By default, this is a no-op (the base implementation does nothing).
+
+
.. note::
+    This is not defined as an abstractmethod as some selection method may
+    not need it.
+
+
Arguments:
+
+
+gw_state (torch.Tensor
): the previous GW state
+
+
+
+
+
+
+
+
+
@abstractmethod
+
+
def
+
forward ( self , domains : collections . abc . Mapping [ str , torch . Tensor ] , encodings_pre_fusion : collections . abc . Mapping [ str , torch . Tensor ] ) -> dict [ str , torch . Tensor ] :
+
+
View Source
+
+
+
+
33 @abstractmethod
+34 def forward (
+35 self , domains : LatentsDomainGroupT , encodings_pre_fusion : LatentsDomainGroupT
+36 ) -> dict [ str , torch . Tensor ]:
+37 """
+38 Forward pass of the selection method.
+39
+40 Args:
+41 domains (`LatentsDomainGroupT`): Group of unimodal latent representations.
+42
+43 Returns:
+44 `dict[str, torch.Tensor]`: for each domain in the group, the fusion
+45 coefficient for each item in the batch.
+46
+47 Example:
+48 >>> SomeSelectionImplementation().forward(
+49 ... {"v": torch.randn(3, 4), "t": torch.randn(3, 8)}
+50 ... )
+51 {"v": torch.Tensor([0.0, 0.4, 1.0]), "t": torch.Tensor([1.0, 0.6, 0.0])}
+52 """
+53 ...
+
+
+
+
Forward pass of the selection method.
+
+
Arguments:
+
+
+domains (LatentsDomainGroupT
): Group of unimodal latent representations.
+
+
+
Returns:
+
+
+ dict[str, torch.Tensor]
: for each domain in the group, the fusion
+ coefficient for each item in the batch.
+
+
+
Example:
+
+
+
+
>>> SomeSelectionImplementation () . forward (
+... { "v" : torch . randn ( 3 , 4 ), "t" : torch . randn ( 3 , 8 )}
+... )
+{"v": torch.Tensor([0.0, 0.4, 1.0]), "t": torch.Tensor([1.0, 0.6, 0.0])}
+
+
+
+
+
+
+
+
+
Inherited Members
+
+
torch.nn.modules.module.Module
+ Module
+ dump_patches
+ training
+ call_super_init
+ register_buffer
+ register_parameter
+ add_module
+ register_module
+ get_submodule
+ get_parameter
+ get_buffer
+
+
+ apply
+ cuda
+ ipu
+ xpu
+ cpu
+ type
+ float
+ double
+ half
+ bfloat16
+ to_empty
+ to
+ register_full_backward_pre_hook
+ register_backward_hook
+ register_full_backward_hook
+ register_forward_pre_hook
+ register_forward_hook
+ register_state_dict_pre_hook
+ state_dict
+ register_load_state_dict_post_hook
+ load_state_dict
+ parameters
+ named_parameters
+ buffers
+ named_buffers
+ children
+ named_children
+ modules
+ named_modules
+ train
+ eval
+ requires_grad_
+ zero_grad
+ share_memory
+
+ compile
+
+
+
+
+
+
+
+
+
+
class
+
SingleDomainSelection (SelectionBase ):
+
+ View Source
+
+
+
+ 62 class SingleDomainSelection ( SelectionBase ):
+63 """
+64 This selection mechanism handles groups that can have multiple domains, but always
+65 return a selection of 1 domain from the group with a uniform distribution.
+66
+67 For example, if the group has 2 domains, there is a 50% chance of selecting each
+68 domain.
+69 """
+70
+71 def forward (
+72 self , domains : LatentsDomainGroupT , encodings_pre_fusion : LatentsDomainGroupT
+73 ) -> dict [ str , torch . Tensor ]:
+74 """
+75 Forward pass of the module.
+76
+77 Args:
+78 domains (`LatentsDomainGroupT`): input unimodal latent representations
+79 gw_state (`torch.Tensor`): the previous GW state
+80
+81 Returns:
+82 `dict[str, torch.Tensor]`: whether the domain is selected for each input
+83 in the batch.
+84 """
+85 selection : dict [ str , torch . Tensor ] = {}
+86 bs = group_batch_size ( domains )
+87 choice = torch . randint ( len ( domains ), size = ( bs ,), device = group_device ( domains ))
+88 for k , domain in enumerate ( domains . keys ()):
+89 selection [ domain ] = ( choice == k ) . to ( torch . float32 )
+90 return selection
+
+
+
+ This selection mechanism handles groups that can have multiple domains, but always
+return a selection of 1 domain from the group with a uniform distribution.
+
+
For example, if the group has 2 domains, there is a 50% chance of selecting each
+domain.
+
+
+
+
+
+
+
+ def
+ forward ( self , domains : collections . abc . Mapping [ str , torch . Tensor ] , encodings_pre_fusion : collections . abc . Mapping [ str , torch . Tensor ] ) -> dict [ str , torch . Tensor ] :
+
+ View Source
+
+
+
+
71 def forward (
+72 self , domains : LatentsDomainGroupT , encodings_pre_fusion : LatentsDomainGroupT
+73 ) -> dict [ str , torch . Tensor ]:
+74 """
+75 Forward pass of the module.
+76
+77 Args:
+78 domains (`LatentsDomainGroupT`): input unimodal latent representations
+79 gw_state (`torch.Tensor`): the previous GW state
+80
+81 Returns:
+82 `dict[str, torch.Tensor]`: whether the domain is selected for each input
+83 in the batch.
+84 """
+85 selection : dict [ str , torch . Tensor ] = {}
+86 bs = group_batch_size ( domains )
+87 choice = torch . randint ( len ( domains ), size = ( bs ,), device = group_device ( domains ))
+88 for k , domain in enumerate ( domains . keys ()):
+89 selection [ domain ] = ( choice == k ) . to ( torch . float32 )
+90 return selection
+
+
+
+
Forward pass of the module.
+
+
Arguments:
+
+
+domains (LatentsDomainGroupT
): input unimodal latent representations
+gw_state (torch.Tensor
): the previous GW state
+
+
+
Returns:
+
+
+ dict[str, torch.Tensor]
: whether the domain is selected for each input
+ in the batch.
+
+
+
+
+
+
+
Inherited Members
+
+
torch.nn.modules.module.Module
+ Module
+ dump_patches
+ training
+ call_super_init
+ register_buffer
+ register_parameter
+ add_module
+ register_module
+ get_submodule
+ get_parameter
+ get_buffer
+ get_extra_state
+ set_extra_state
+ apply
+ cuda
+ ipu
+ xpu
+ cpu
+ type
+ float
+ double
+ half
+ bfloat16
+ to_empty
+ to
+ register_full_backward_pre_hook
+ register_backward_hook
+ register_full_backward_hook
+ register_forward_pre_hook
+ register_forward_hook
+ register_state_dict_pre_hook
+ state_dict
+ register_load_state_dict_post_hook
+ load_state_dict
+ parameters
+ named_parameters
+ buffers
+ named_buffers
+ children
+ named_children
+ modules
+ named_modules
+ train
+ eval
+ requires_grad_
+ zero_grad
+ share_memory
+ extra_repr
+ compile
+
+
+
+
+
+
+
+
+
+
+
class
+
FixedSharedSelection (SelectionBase ):
+
+ View Source
+
+
+
+ 93 class FixedSharedSelection ( SelectionBase ):
+ 94 """
+ 95 This selection mechanism is deterministic and always shares the weights equally
+ 96 between domains.
+ 97
+ 98 For example, if 2 domains, it gives 0.5 for each; 3 domains, 1/3 for each...
+ 99 """
+100
+101 def forward (
+102 self , domains : LatentsDomainGroupT , encodings_pre_fusion : LatentsDomainGroupT
+103 ) -> dict [ str , torch . Tensor ]:
+104 """
+105 Forward pass of the module.
+106
+107 Args:
+108 domains (`LatentsDomainGroupT`): input unimodal latent representations
+109 gw_state (`torch.Tensor`): the previous GW state
+110
+111 Returns:
+112 `dict[str, torch.Tensor]`: whether the domain is selected for each input
+113 in the batch.
+114 """
+115 selection : dict [ str , torch . Tensor ] = {}
+116 bs = group_batch_size ( domains )
+117 coef = torch . full (( bs ,), 1.0 / len ( domains ), device = group_device ( domains ))
+118 for domain in domains :
+119 selection [ domain ] = coef . clone ()
+120 return selection
+
+
+
+ This selection mechanism is deterministic and always shares the weights equally
+between domains.
+
+
For example, if 2 domains, it gives 0.5 for each; 3 domains, 1/3 for each...
+
+
+
+
+
+
+
+ def
+ forward ( self , domains : collections . abc . Mapping [ str , torch . Tensor ] , encodings_pre_fusion : collections . abc . Mapping [ str , torch . Tensor ] ) -> dict [ str , torch . Tensor ] :
+
+ View Source
+
+
+
+
101 def forward (
+102 self , domains : LatentsDomainGroupT , encodings_pre_fusion : LatentsDomainGroupT
+103 ) -> dict [ str , torch . Tensor ]:
+104 """
+105 Forward pass of the module.
+106
+107 Args:
+108 domains (`LatentsDomainGroupT`): input unimodal latent representations
+109 gw_state (`torch.Tensor`): the previous GW state
+110
+111 Returns:
+112 `dict[str, torch.Tensor]`: whether the domain is selected for each input
+113 in the batch.
+114 """
+115 selection : dict [ str , torch . Tensor ] = {}
+116 bs = group_batch_size ( domains )
+117 coef = torch . full (( bs ,), 1.0 / len ( domains ), device = group_device ( domains ))
+118 for domain in domains :
+119 selection [ domain ] = coef . clone ()
+120 return selection
+
+
+
+
Forward pass of the module.
+
+
Arguments:
+
+
+domains (LatentsDomainGroupT
): input unimodal latent representations
+gw_state (torch.Tensor
): the previous GW state
+
+
+
Returns:
+
+
+ dict[str, torch.Tensor]
: whether the domain is selected for each input
+ in the batch.
+
+
+
+
+
+
+
Inherited Members
+
+
torch.nn.modules.module.Module
+ Module
+ dump_patches
+ training
+ call_super_init
+ register_buffer
+ register_parameter
+ add_module
+ register_module
+ get_submodule
+ get_parameter
+ get_buffer
+
+
+ apply
+ cuda
+ ipu
+ xpu
+ cpu
+ type
+ float
+ double
+ half
+ bfloat16
+ to_empty
+ to
+ register_full_backward_pre_hook
+ register_backward_hook
+ register_full_backward_hook
+ register_forward_pre_hook
+ register_forward_hook
+ register_state_dict_pre_hook
+ state_dict
+ register_load_state_dict_post_hook
+ load_state_dict
+ parameters
+ named_parameters
+ buffers
+ named_buffers
+ children
+ named_children
+ modules
+ named_modules
+ train
+ eval
+ requires_grad_
+ zero_grad
+ share_memory
+
+ compile
+
+
+
+
+
+
+
+
+
+
+
class
+
KQFixedQSelection (SelectionBase ):
+
+ View Source
+
+
+
+ 152 class KQFixedQSelection ( SelectionBase ):
+153 """
+154 Key-Query attention with a fixed gw vector.
+155 """
+156
+157 def __init__ ( self , head_size : int , domain_dim : int , domain_names : Iterable [ str ]):
+158 """
+159 Args:
+160 head_size (`int`) : dimension of the key and query vectors.
+161 domain_dim (`int`) : dimension of the input dims (assumed to be the same
+162 for now)
+163 domain_names (`Iterable[str]`) : list of input domains
+164 """
+165 super () . __init__ ()
+166 self . head_size = head_size
+167 self . query_layer = nn . Linear ( domain_dim , head_size )
+168 self . key_layers = nn . ModuleDict (
+169 { domain : nn . Linear ( domain_dim , head_size ) for domain in domain_names }
+170 )
+171 # Start with a random gw state
+172 self . register_buffer ( "initial_gw_state" , torch . rand ( domain_dim ))
+173
+174 def forward (
+175 self , domains : LatentsDomainGroupT , encodings_pre_fusion : LatentsDomainGroupT
+176 ) -> dict [ str , torch . Tensor ]:
+177 """
+178 Compute keys and queries, match them with dot product and softmax.
+179 Does this twice, once with the static query and once with a dynamic query.
+180
+181 Args:
+182 domains (`LatentsDomainGroupT`): Group of unimodal latent representations.
+183 encodings (`LatentsDomainGroupT`): Group of pre-fusion encodings.
+184
+185 Returns:
+186 `dict[str, torch.Tensor]`: the attention scores for each domain in the
+187 group.
+188 """
+189
+190 keys = {
+191 domain : self . key_layers [ domain ]( encoding )
+192 for domain , encoding in domains . items ()
+193 }
+194
+195 batch_size = group_batch_size ( domains )
+196
+197 # Retrieve random query
+198 query = self . query_layer ( self . initial_gw_state . expand ( batch_size , - 1 ))
+199
+200 # Calculate the attention scores
+201 return _calculate_attention_dict ( domains , keys , query )
+
+
+
+ Key-Query attention with a fixed gw vector.
+
+
+
+
+
+
+
+ KQFixedQSelection ( head_size : int , domain_dim : int , domain_names : collections . abc . Iterable [ str ] )
+
+ View Source
+
+
+
+
157 def __init__ ( self , head_size : int , domain_dim : int , domain_names : Iterable [ str ]):
+158 """
+159 Args:
+160 head_size (`int`) : dimension of the key and query vectors.
+161 domain_dim (`int`) : dimension of the input dims (assumed to be the same
+162 for now)
+163 domain_names (`Iterable[str]`) : list of input domains
+164 """
+165 super () . __init__ ()
+166 self . head_size = head_size
+167 self . query_layer = nn . Linear ( domain_dim , head_size )
+168 self . key_layers = nn . ModuleDict (
+169 { domain : nn . Linear ( domain_dim , head_size ) for domain in domain_names }
+170 )
+171 # Start with a random gw state
+172 self . register_buffer ( "initial_gw_state" , torch . rand ( domain_dim ))
+
+
+
+
Arguments:
+
+
+head_size (int
) : dimension of the key and query vectors.
+domain_dim (int
) : dimension of the input dims (assumed to be the same
+for now)
+domain_names (Iterable[str]
) : list of input domains
+
+
+
+
+
+
+
+ head_size
+
+
+
+
+
+
+
+
+
+
+ query_layer
+
+
+
+
+
+
+
+
+
+
+ key_layers
+
+
+
+
+
+
+
+
+
+
+
+
+ def
+ forward ( self , domains : collections . abc . Mapping [ str , torch . Tensor ] , encodings_pre_fusion : collections . abc . Mapping [ str , torch . Tensor ] ) -> dict [ str , torch . Tensor ] :
+
+ View Source
+
+
+
+
174 def forward (
+175 self , domains : LatentsDomainGroupT , encodings_pre_fusion : LatentsDomainGroupT
+176 ) -> dict [ str , torch . Tensor ]:
+177 """
+178 Compute keys and queries, match them with dot product and softmax.
+179 Does this twice, once with the static query and once with a dynamic query.
+180
+181 Args:
+182 domains (`LatentsDomainGroupT`): Group of unimodal latent representations.
+183 encodings (`LatentsDomainGroupT`): Group of pre-fusion encodings.
+184
+185 Returns:
+186 `dict[str, torch.Tensor]`: the attention scores for each domain in the
+187 group.
+188 """
+189
+190 keys = {
+191 domain : self . key_layers [ domain ]( encoding )
+192 for domain , encoding in domains . items ()
+193 }
+194
+195 batch_size = group_batch_size ( domains )
+196
+197 # Retrieve random query
+198 query = self . query_layer ( self . initial_gw_state . expand ( batch_size , - 1 ))
+199
+200 # Calculate the attention scores
+201 return _calculate_attention_dict ( domains , keys , query )
+
+
+
+
Compute keys and queries, match them with dot product and softmax.
+Uses a single fixed (static) query derived from the initial GW state.
+
+
Arguments:
+
+
+domains (LatentsDomainGroupT
): Group of unimodal latent representations.
+encodings (LatentsDomainGroupT
): Group of pre-fusion encodings.
+
+
+
Returns:
+
+
+ dict[str, torch.Tensor]
: the attention scores for each domain in the
+ group.
+
+
+
+
+
+
+
Inherited Members
+
+
+
torch.nn.modules.module.Module
+ dump_patches
+ training
+ call_super_init
+ register_buffer
+ register_parameter
+ add_module
+ register_module
+ get_submodule
+ get_parameter
+ get_buffer
+
+
+ apply
+ cuda
+ ipu
+ xpu
+ cpu
+ type
+ float
+ double
+ half
+ bfloat16
+ to_empty
+ to
+ register_full_backward_pre_hook
+ register_backward_hook
+ register_full_backward_hook
+ register_forward_pre_hook
+ register_forward_hook
+ register_state_dict_pre_hook
+ state_dict
+ register_load_state_dict_post_hook
+ load_state_dict
+ parameters
+ named_parameters
+ buffers
+ named_buffers
+ children
+ named_children
+ modules
+ named_modules
+ train
+ eval
+ requires_grad_
+ zero_grad
+ share_memory
+
+ compile
+
+
+
+
+
+
+
+
+
+ 204 class RandomSelection ( SelectionBase ):
+205 """
+206 Modified random attention to only utilize uniform-softmax scores across modalities.
+207 This version omits the binary scaling factors and focuses on generating attention
+208 coefficients using a uniform distribution followed by a domain-wise softmax.
+209 """
+210
+211 def __init__ ( self , temperature : float ):
+212 """
+213 Args:
+214 temperature (`float`): Temperature of the softmax applied to uniform
+215 scaling factors.
+216 """
+217 super () . __init__ ()
+218 self . temperature = temperature
+219
+220 def forward (
+221 self , domains : LatentsDomainGroupT , encodings_pre_fusion : LatentsDomainGroupT
+222 ) -> dict [ str , torch . Tensor ]:
+223 """
+224 Generate uniform-then-domain-wise-softmaxed samples for each domain.
+225
+226 Args:
+227 domains (`LatentsDomainGroupT`): Group of unimodal latent representations.
+228 This is not used in the function directly but determines the structure
+229 of the returned attention coefficients.
+230
+231 Returns:
+232 `dict[str, torch.Tensor]`: For each domain in the group, the fusion
+233 coefficient for each item in the batch, based solely on
+234 uniform-softmax scores.
+235 """
+236 num_domains = len ( domains )
+237 batch_size = group_batch_size ( domains )
+238 device = group_device ( domains )
+239
+240 # Generate uniform scores
+241 uniform_scores = torch . rand ( batch_size , num_domains , device = device )
+242
+243 # Apply softmax across domains with temperature scaling
+244 softmax_scores = torch . softmax ( uniform_scores / self . temperature , dim = 1 )
+245 # Create attention dictionary for each domain
+246 attention_dict = {
+247 domain : softmax_scores [:, i ] for i , domain in enumerate ( domains )
+248 }
+249
+250 return attention_dict
+
+
+
+ Modified random attention to only utilize uniform-softmax scores across modalities.
+This version omits the binary scaling factors and focuses on generating attention
+coefficients using a uniform distribution followed by a domain-wise softmax.
+
+
+
+
+
+
+
+ RandomSelection (temperature : float )
+
+ View Source
+
+
+
+
211 def __init__ ( self , temperature : float ):
+212 """
+213 Args:
+214 temperature (`float`): Temperature of the softmax applied to uniform
+215 scaling factors.
+216 """
+217 super () . __init__ ()
+218 self . temperature = temperature
+
+
+
+
Arguments:
+
+
+temperature (float
): Temperature of the softmax applied to uniform
+scaling factors.
+
+
+
+
+
+
+
+ temperature
+
+
+
+
+
+
+
+
+
+
+
+
+ def
+ forward ( self , domains : collections . abc . Mapping [ str , torch . Tensor ] , encodings_pre_fusion : collections . abc . Mapping [ str , torch . Tensor ] ) -> dict [ str , torch . Tensor ] :
+
+ View Source
+
+
+
+
220 def forward (
+221 self , domains : LatentsDomainGroupT , encodings_pre_fusion : LatentsDomainGroupT
+222 ) -> dict [ str , torch . Tensor ]:
+223 """
+224 Generate uniform-then-domain-wise-softmaxed samples for each domain.
+225
+226 Args:
+227 domains (`LatentsDomainGroupT`): Group of unimodal latent representations.
+228 This is not used in the function directly but determines the structure
+229 of the returned attention coefficients.
+230
+231 Returns:
+232 `dict[str, torch.Tensor]`: For each domain in the group, the fusion
+233 coefficient for each item in the batch, based solely on
+234 uniform-softmax scores.
+235 """
+236 num_domains = len ( domains )
+237 batch_size = group_batch_size ( domains )
+238 device = group_device ( domains )
+239
+240 # Generate uniform scores
+241 uniform_scores = torch . rand ( batch_size , num_domains , device = device )
+242
+243 # Apply softmax across domains with temperature scaling
+244 softmax_scores = torch . softmax ( uniform_scores / self . temperature , dim = 1 )
+245 # Create attention dictionary for each domain
+246 attention_dict = {
+247 domain : softmax_scores [:, i ] for i , domain in enumerate ( domains )
+248 }
+249
+250 return attention_dict
+
+
+
+
Generate uniform-then-domain-wise-softmaxed samples for each domain.
+
+
Arguments:
+
+
+domains (LatentsDomainGroupT
): Group of unimodal latent representations.
+This is not used in the function directly but determines the structure
+of the returned attention coefficients.
+
+
+
Returns:
+
+
+ dict[str, torch.Tensor]
: For each domain in the group, the fusion
+ coefficient for each item in the batch, based solely on
+ uniform-softmax scores.
+
+
+
+
+
+
+
Inherited Members
+
+
+
torch.nn.modules.module.Module
+ dump_patches
+ training
+ call_super_init
+ register_buffer
+ register_parameter
+ add_module
+ register_module
+ get_submodule
+ get_parameter
+ get_buffer
+ get_extra_state
+ set_extra_state
+ apply
+ cuda
+ ipu
+ xpu
+ cpu
+ type
+ float
+ double
+ half
+ bfloat16
+ to_empty
+ to
+ register_full_backward_pre_hook
+ register_backward_hook
+ register_full_backward_hook
+ register_forward_pre_hook
+ register_forward_hook
+ register_state_dict_pre_hook
+ state_dict
+ register_load_state_dict_post_hook
+ load_state_dict
+ parameters
+ named_parameters
+ buffers
+ named_buffers
+ children
+ named_children
+ modules
+ named_modules
+ train
+ eval
+ requires_grad_
+ zero_grad
+ share_memory
+ extra_repr
+ compile
+
+
+
+
+
+
+
+
+
+
class
+
DynamicQueryAttention (SelectionBase ):
+
+ View Source
+
+
+
+ 253 class DynamicQueryAttention ( SelectionBase ):
+254 """
+255 Key-Query attention with a dynamic gw vector.
+256 The query is updated based on the scaled gw vector.
+257 """
+258
+259 def __init__ ( self , head_size : int , domain_dim : int , domain_names : Iterable [ str ]):
+260 """
+261 Args:
+262 head_size (`int`) : dimension of the key and query vectors.
+263 domain_dim (`int`) : dimension of the input dims (assumed to be the same
+264 for now)
+265 domain_names (`Iterable[str]`) : list of input domains
+266 """
+267 super () . __init__ ()
+268 self . head_size = head_size
+269 self . query_layer = nn . Linear ( domain_dim , head_size )
+270 self . key_layers = nn . ModuleDict (
+271 { domain : nn . Linear ( domain_dim , head_size ) for domain in domain_names }
+272 )
+273 # Start with a random gw state
+274 self . register_buffer ( "initial_gw_state" , torch . rand ( domain_dim ))
+275
+276 def fuse_weighted_encodings (
+277 self , encodings : LatentsDomainGroupT , attention_dict : dict [ str , torch . Tensor ]
+278 ) -> torch . Tensor :
+279 """
+280 Fuse the weighted encodings using the attention scores.
+281
+282 Args:
+283 encodings (`LatentsDomainGroupT`): Unimodal latent representation
+284 attention_dict (`dict[str, torch.Tensor]`): The attention scores for each
+285 domain in the group.
+286
+287 Returns:
+288 `torch.Tensor`: The fused tensor.
+289 """
+290 # Apply attention scores to the encodings
+291 weighted_encodings = {}
+292 for key in attention_dict :
+293 if key in encodings :
+294 # Perform element-wise multiplication
+295 weighted_encodings [ key ] = (
+296 attention_dict [ key ] . unsqueeze ( 1 ) * encodings [ key ]
+297 )
+298
+299 # Stack the tensors along a new dimension (dimension 0)
+300 stacked_tensors = torch . stack ( list ( weighted_encodings . values ()))
+301
+302 # Apply fusion by summing along the newly created dimension
+303 summed_tensor = torch . sum ( stacked_tensors , dim = 0 )
+304 return summed_tensor
+305
+306 def forward (
+307 self , domains : LatentsDomainGroupT , encodings_pre_fusion : LatentsDomainGroupT
+308 ) -> dict [ str , torch . Tensor ]:
+309 """
+310 Compute keys and queries, match them with dot product and softmax.
+311 Does this twice, once with the static query and once with a dynamic query.
+312
+313 Args:
+314 domains (`LatentsDomainGroupT`): Group of unimodal latent representations.
+315 encodings (`LatentsDomainGroupT`): Group of pre-fusion encodings.
+316
+317 Returns:
+318 `dict[str, torch.Tensor]`: the attention scores for each domain in the
+319 group.
+320 """
+321
+322 keys = {
+323 domain : self . key_layers [ domain ]( encoding )
+324 for domain , encoding in domains . items ()
+325 }
+326
+327 batch_size = group_batch_size ( domains )
+328
+329 # Retrieve random query
+330 query = self . query_layer ( self . initial_gw_state . expand ( batch_size , - 1 ))
+331
+332 # Calculate the attention scores
+333 static_attention_dict = _calculate_attention_dict ( domains , keys , query )
+334
+335 # Apply the attention scores to the encodings
+336 summed_tensor = self . fuse_weighted_encodings (
+337 encodings_pre_fusion , static_attention_dict
+338 )
+339
+340 # Retrieve query (now it is dependent on the new gw state)
+341 query = self . query_layer ( summed_tensor )
+342
+343 # Calculate the attention scores again
+344 dynamic_attention_dict = _calculate_attention_dict ( domains , keys , query )
+345
+346 return dynamic_attention_dict
+
+
+
+ Key-Query attention with a dynamic gw vector.
+The query is updated based on the scaled gw vector.
+
+
+
+
+
+
+
+ DynamicQueryAttention ( head_size : int , domain_dim : int , domain_names : collections . abc . Iterable [ str ] )
+
+ View Source
+
+
+
+
259 def __init__ ( self , head_size : int , domain_dim : int , domain_names : Iterable [ str ]):
+260 """
+261 Args:
+262 head_size (`int`) : dimension of the key and query vectors.
+263 domain_dim (`int`) : dimension of the input dims (assumed to be the same
+264 for now)
+265 domain_names (`Iterable[str]`) : list of input domains
+266 """
+267 super () . __init__ ()
+268 self . head_size = head_size
+269 self . query_layer = nn . Linear ( domain_dim , head_size )
+270 self . key_layers = nn . ModuleDict (
+271 { domain : nn . Linear ( domain_dim , head_size ) for domain in domain_names }
+272 )
+273 # Start with a random gw state
+274 self . register_buffer ( "initial_gw_state" , torch . rand ( domain_dim ))
+
+
+
+
Arguments:
+
+
+head_size (int
) : dimension of the key and query vectors.
+domain_dim (int
) : dimension of the input dims (assumed to be the same
+for now)
+domain_names (Iterable[str]
) : list of input domains
+
+
+
+
+
+
+
+ head_size
+
+
+
+
+
+
+
+
+
+
+ query_layer
+
+
+
+
+
+
+
+
+
+
+ key_layers
+
+
+
+
+
+
+
+
+
+
+
+
+ def
+ fuse_weighted_encodings ( self , encodings : collections . abc . Mapping [ str , torch . Tensor ] , attention_dict : dict [ str , torch . Tensor ] ) -> torch . Tensor :
+
+ View Source
+
+
+
+
276 def fuse_weighted_encodings (
+277 self , encodings : LatentsDomainGroupT , attention_dict : dict [ str , torch . Tensor ]
+278 ) -> torch . Tensor :
+279 """
+280 Fuse the weighted encodings using the attention scores.
+281
+282 Args:
+283 encodings (`LatentsDomainGroupT`): Unimodal latent representation
+284 attention_dict (`dict[str, torch.Tensor]`): The attention scores for each
+285 domain in the group.
+286
+287 Returns:
+288 `torch.Tensor`: The fused tensor.
+289 """
+290 # Apply attention scores to the encodings
+291 weighted_encodings = {}
+292 for key in attention_dict :
+293 if key in encodings :
+294 # Perform element-wise multiplication
+295 weighted_encodings [ key ] = (
+296 attention_dict [ key ] . unsqueeze ( 1 ) * encodings [ key ]
+297 )
+298
+299 # Stack the tensors along a new dimension (dimension 0)
+300 stacked_tensors = torch . stack ( list ( weighted_encodings . values ()))
+301
+302 # Apply fusion by summing along the newly created dimension
+303 summed_tensor = torch . sum ( stacked_tensors , dim = 0 )
+304 return summed_tensor
+
+
+
+
Fuse the weighted encodings using the attention scores.
+
+
Arguments:
+
+
+encodings (LatentsDomainGroupT
): Unimodal latent representation
+attention_dict (dict[str, torch.Tensor]
): The attention scores for each
+domain in the group.
+
+
+
Returns:
+
+
+ torch.Tensor
: The fused tensor.
+
+
+
+
+
+
+
+
+
+ def
+ forward ( self , domains : collections . abc . Mapping [ str , torch . Tensor ] , encodings_pre_fusion : collections . abc . Mapping [ str , torch . Tensor ] ) -> dict [ str , torch . Tensor ] :
+
+ View Source
+
+
+
+
306 def forward (
+307 self , domains : LatentsDomainGroupT , encodings_pre_fusion : LatentsDomainGroupT
+308 ) -> dict [ str , torch . Tensor ]:
+309 """
+310 Compute keys and queries, match them with dot product and softmax.
+311 Does this twice, once with the static query and once with a dynamic query.
+312
+313 Args:
+314 domains (`LatentsDomainGroupT`): Group of unimodal latent representations.
+315 encodings (`LatentsDomainGroupT`): Group of pre-fusion encodings.
+316
+317 Returns:
+318 `dict[str, torch.Tensor]`: the attention scores for each domain in the
+319 group.
+320 """
+321
+322 keys = {
+323 domain : self . key_layers [ domain ]( encoding )
+324 for domain , encoding in domains . items ()
+325 }
+326
+327 batch_size = group_batch_size ( domains )
+328
+329 # Retrieve random query
+330 query = self . query_layer ( self . initial_gw_state . expand ( batch_size , - 1 ))
+331
+332 # Calculate the attention scores
+333 static_attention_dict = _calculate_attention_dict ( domains , keys , query )
+334
+335 # Apply the attention scores to the encodings
+336 summed_tensor = self . fuse_weighted_encodings (
+337 encodings_pre_fusion , static_attention_dict
+338 )
+339
+340 # Retrieve query (now it is dependent on the new gw state)
+341 query = self . query_layer ( summed_tensor )
+342
+343 # Calculate the attention scores again
+344 dynamic_attention_dict = _calculate_attention_dict ( domains , keys , query )
+345
+346 return dynamic_attention_dict
+
+
+
+
Compute keys and queries, match them with dot product and softmax.
+Does this twice, once with the static query and once with a dynamic query.
+
+
Arguments:
+
+
+domains (LatentsDomainGroupT
): Group of unimodal latent representations.
+encodings (LatentsDomainGroupT
): Group of pre-fusion encodings.
+
+
+
Returns:
+
+
+ dict[str, torch.Tensor]
: the attention scores for each domain in the
+ group.
+
+
+
+
+
+
+
Inherited Members
+
+
+
torch.nn.modules.module.Module
+ dump_patches
+ training
+ call_super_init
+ register_buffer
+ register_parameter
+ add_module
+ register_module
+ get_submodule
+ get_parameter
+ get_buffer
+
+
+ apply
+ cuda
+ ipu
+ xpu
+ cpu
+ type
+ float
+ double
+ half
+ bfloat16
+ to_empty
+ to
+ register_full_backward_pre_hook
+ register_backward_hook
+ register_full_backward_hook
+ register_forward_pre_hook
+ register_forward_hook
+ register_state_dict_pre_hook
+ state_dict
+ register_load_state_dict_post_hook
+ load_state_dict
+ parameters
+ named_parameters
+ buffers
+ named_buffers
+ children
+ named_children
+ modules
+ named_modules
+ train
+ eval
+ requires_grad_
+ zero_grad
+ share_memory
+
+ compile
+
+
+
+
+
+
+
+
\ No newline at end of file
diff --git a/docs/api/v0.5.1/shimmer/modules/utils.html b/docs/api/v0.5.1/shimmer/modules/utils.html
new file mode 100644
index 00000000..790987da
--- /dev/null
+++ b/docs/api/v0.5.1/shimmer/modules/utils.html
@@ -0,0 +1,760 @@
+
+
+
+
+
+
+ shimmer.modules.utils API documentation
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+shimmer.modules.utils
+
+
+
+
+ View Source
+
+ 1 from collections.abc import Iterable
+ 2
+ 3 import torch
+ 4
+ 5 from shimmer.modules.gw_module import GWModuleBase
+ 6 from shimmer.modules.selection import SelectionBase
+ 7 from shimmer.types import (
+ 8 LatentsDomainGroupDT ,
+ 9 LatentsDomainGroupsT ,
+ 10 LatentsDomainGroupT ,
+ 11 )
+ 12
+ 13
+ 14 def translation (
+ 15 gw_module : GWModuleBase ,
+ 16 selection_mod : SelectionBase ,
+ 17 x : LatentsDomainGroupT ,
+ 18 to : str ,
+ 19 ) -> torch . Tensor :
+ 20 """
+ 21 Translate from multiple domains to one domain.
+ 22
+ 23 Args:
+ 24 gw_module (`GWModuleBase`): GWModule to perform the translation over
+ 25 selection_mod (`SelectionBase`): selection module
+ 26 x (`LatentsDomainGroupT`): the group of latent representations
+ 27 to (`str`): the domain name to encode to
+ 28
+ 29 Returns:
+ 30 `torch.Tensor`: the translated unimodal representation
+ 31 of the provided domain.
+ 32 """
+ 33 return gw_module . decode ( gw_module . encode_and_fuse ( x , selection_mod ), domains = { to })[
+ 34 to
+ 35 ]
+ 36
+ 37
+ 38 def cycle (
+ 39 gw_module : GWModuleBase ,
+ 40 selection_mod : SelectionBase ,
+ 41 x : LatentsDomainGroupT ,
+ 42 through : str ,
+ 43 ) -> LatentsDomainGroupDT :
+ 44 """
+ 45 Do a full cycle from a group of representation through one domain.
+ 46
+ 47 [Original domains] -> [GW] -> [through] -> [GW] -> [Original domains]
+ 48
+ 49 Args:
+ 50 gw_module (`GWModuleBase`): GWModule to perform the translation over
+ 51 selection_mod (`SelectionBase`): selection module
+ 52 x (`LatentsDomainGroupT`): group of unimodal latent representation
+ 53 through (`str`): domain name to cycle through
+ 54 Returns:
+ 55 `LatentsDomainGroupDT`: group of unimodal latent representation after
+ 56 cycling.
+ 57 """
+ 58 return {
+ 59 domain : translation (
+ 60 gw_module ,
+ 61 selection_mod ,
+ 62 { through : translation ( gw_module , selection_mod , x , through )},
+ 63 domain ,
+ 64 )
+ 65 for domain in x
+ 66 }
+ 67
+ 68
+ 69 def batch_demi_cycles (
+ 70 gw_mod : GWModuleBase ,
+ 71 selection_mod : SelectionBase ,
+ 72 latent_domains : LatentsDomainGroupsT ,
+ 73 ) -> dict [ str , torch . Tensor ]:
+ 74 """
+ 75 Computes demi-cycles of a batch of groups of domains.
+ 76
+ 77 Args:
+ 78 gw_mod (`GWModuleBase`): the GWModuleBase
+ 79 selection_mod (`SelectionBase`): selection module
+ 80 latent_domains (`LatentsT`): the batch of groups of domains
+ 81
+ 82 Returns:
+ 83 `dict[str, torch.Tensor]`: demi-cycles predictions for each domain.
+ 84 """
+ 85 predictions : dict [ str , torch . Tensor ] = {}
+ 86 for domains , latents in latent_domains . items ():
+ 87 if len ( domains ) > 1 :
+ 88 continue
+ 89 domain_name = list ( domains )[ 0 ]
+ 90 z = translation ( gw_mod , selection_mod , latents , to = domain_name )
+ 91 predictions [ domain_name ] = z
+ 92 return predictions
+ 93
+ 94
+ 95 def batch_cycles (
+ 96 gw_mod : GWModuleBase ,
+ 97 selection_mod : SelectionBase ,
+ 98 latent_domains : LatentsDomainGroupsT ,
+ 99 through_domains : Iterable [ str ],
+100 ) -> dict [ tuple [ str , str ], torch . Tensor ]:
+101 """
+102 Computes cycles of a batch of groups of domains.
+103
+104 Args:
+105 gw_mod (`GWModuleBase`): GWModule to use for the cycle
+106 selection_mod (`SelectionBase`): selection module
+107 latent_domains (`LatentsT`): the batch of groups of domains
+108 out_domains (`Iterable[str]`): iterable of domain names to do the cycle through.
+109 Each domain will be done separetely.
+110
+111 Returns:
+112 `dict[tuple[str, str], torch.Tensor]`: cycles predictions for each
+113 couple of (start domain, intermediary domain).
+114 """
+115 predictions : dict [ tuple [ str , str ], torch . Tensor ] = {}
+116 for domains_source , latents_source in latent_domains . items ():
+117 if len ( domains_source ) > 1 :
+118 continue
+119 domain_name_source = next ( iter ( domains_source ))
+120 for domain_name_through in through_domains :
+121 if domain_name_source == domain_name_through :
+122 continue
+123 z = cycle (
+124 gw_mod , selection_mod , latents_source , through = domain_name_through
+125 )
+126 domains = ( domain_name_source , domain_name_through )
+127 predictions [ domains ] = z [ domain_name_source ]
+128 return predictions
+129
+130
+131 def batch_translations (
+132 gw_mod : GWModuleBase ,
+133 selection_mod : SelectionBase ,
+134 latent_domains : LatentsDomainGroupsT ,
+135 ) -> dict [ tuple [ str , str ], torch . Tensor ]:
+136 """
+137 Computes translations of a batch of groups of domains.
+138
+139 Args:
+140 gw_mod (`GWModuleBase`): GWModule to do the translation
+141 selection_mod (`SelectionBase`): selection module
+142 latent_domains (`LatentsT`): the batch of groups of domains
+143
+144 Returns:
+145 `dict[tuple[str, str], torch.Tensor]`: translation predictions for each
+146 couple of (start domain, target domain).
+147 """
+148 predictions : dict [ tuple [ str , str ], torch . Tensor ] = {}
+149 for domains , latents in latent_domains . items ():
+150 if len ( domains ) < 2 :
+151 continue
+152 for domain_name_source in domains :
+153 for domain_name_target in domains :
+154 if domain_name_source == domain_name_target :
+155 continue
+156 prediction = translation (
+157 gw_mod ,
+158 selection_mod ,
+159 { domain_name_source : latents [ domain_name_source ]},
+160 to = domain_name_target ,
+161 )
+162 predictions [( domain_name_source , domain_name_target )] = prediction
+163 return predictions
+
+
+
+
+
+
+
+
+ 15 def translation (
+16 gw_module : GWModuleBase ,
+17 selection_mod : SelectionBase ,
+18 x : LatentsDomainGroupT ,
+19 to : str ,
+20 ) -> torch . Tensor :
+21 """
+22 Translate from multiple domains to one domain.
+23
+24 Args:
+25 gw_module (`GWModuleBase`): GWModule to perform the translation over
+26 selection_mod (`SelectionBase`): selection module
+27 x (`LatentsDomainGroupT`): the group of latent representations
+28 to (`str`): the domain name to encode to
+29
+30 Returns:
+31 `torch.Tensor`: the translated unimodal representation
+32 of the provided domain.
+33 """
+34 return gw_module . decode ( gw_module . encode_and_fuse ( x , selection_mod ), domains = { to })[
+35 to
+36 ]
+
+
+
+ Translate from multiple domains to one domain.
+
+
Arguments:
+
+
+gw_module (GWModuleBase
): GWModule to perform the translation over
+selection_mod (SelectionBase
): selection module
+x (LatentsDomainGroupT
): the group of latent representations
+to (str
): the domain name to encode to
+
+
+
Returns:
+
+
+ torch.Tensor
: the translated unimodal representation
+ of the provided domain.
+
+
+
+
+
+
+
+
+
+ 39 def cycle (
+40 gw_module : GWModuleBase ,
+41 selection_mod : SelectionBase ,
+42 x : LatentsDomainGroupT ,
+43 through : str ,
+44 ) -> LatentsDomainGroupDT :
+45 """
+46 Do a full cycle from a group of representation through one domain.
+47
+48 [Original domains] -> [GW] -> [through] -> [GW] -> [Original domains]
+49
+50 Args:
+51 gw_module (`GWModuleBase`): GWModule to perform the translation over
+52 selection_mod (`SelectionBase`): selection module
+53 x (`LatentsDomainGroupT`): group of unimodal latent representation
+54 through (`str`): domain name to cycle through
+55 Returns:
+56 `LatentsDomainGroupDT`: group of unimodal latent representation after
+57 cycling.
+58 """
+59 return {
+60 domain : translation (
+61 gw_module ,
+62 selection_mod ,
+63 { through : translation ( gw_module , selection_mod , x , through )},
+64 domain ,
+65 )
+66 for domain in x
+67 }
+
+
+
+ Do a full cycle from a group of representation through one domain.
+
+
[Original domains] -> [GW] -> [through] -> [GW] -> [Original domains]
+
+
Arguments:
+
+
+gw_module (GWModuleBase
): GWModule to perform the translation over
+selection_mod (SelectionBase
): selection module
+x (LatentsDomainGroupT
): group of unimodal latent representation
+through (str
): domain name to cycle through
+
+
+
Returns:
+
+
+ LatentsDomainGroupDT
: group of unimodal latent representation after
+ cycling.
+
+
+
+
+
+
+
+
+
+ 70 def batch_demi_cycles (
+71 gw_mod : GWModuleBase ,
+72 selection_mod : SelectionBase ,
+73 latent_domains : LatentsDomainGroupsT ,
+74 ) -> dict [ str , torch . Tensor ]:
+75 """
+76 Computes demi-cycles of a batch of groups of domains.
+77
+78 Args:
+79 gw_mod (`GWModuleBase`): the GWModuleBase
+80 selection_mod (`SelectionBase`): selection module
+81 latent_domains (`LatentsT`): the batch of groups of domains
+82
+83 Returns:
+84 `dict[str, torch.Tensor]`: demi-cycles predictions for each domain.
+85 """
+86 predictions : dict [ str , torch . Tensor ] = {}
+87 for domains , latents in latent_domains . items ():
+88 if len ( domains ) > 1 :
+89 continue
+90 domain_name = list ( domains )[ 0 ]
+91 z = translation ( gw_mod , selection_mod , latents , to = domain_name )
+92 predictions [ domain_name ] = z
+93 return predictions
+
+
+
+ Computes demi-cycles of a batch of groups of domains.
+
+
Arguments:
+
+
+gw_mod (GWModuleBase
): the GWModuleBase
+selection_mod (SelectionBase
): selection module
+latent_domains (LatentsT
): the batch of groups of domains
+
+
+
Returns:
+
+
+ dict[str, torch.Tensor]
: demi-cycles predictions for each domain.
+
+
+
+
+
+
+
+
+
+ 96 def batch_cycles (
+ 97 gw_mod : GWModuleBase ,
+ 98 selection_mod : SelectionBase ,
+ 99 latent_domains : LatentsDomainGroupsT ,
+100 through_domains : Iterable [ str ],
+101 ) -> dict [ tuple [ str , str ], torch . Tensor ]:
+102 """
+103 Computes cycles of a batch of groups of domains.
+104
+105 Args:
+106 gw_mod (`GWModuleBase`): GWModule to use for the cycle
+107 selection_mod (`SelectionBase`): selection module
+108 latent_domains (`LatentsT`): the batch of groups of domains
+109 out_domains (`Iterable[str]`): iterable of domain names to do the cycle through.
+110 Each domain will be done separetely.
+111
+112 Returns:
+113 `dict[tuple[str, str], torch.Tensor]`: cycles predictions for each
+114 couple of (start domain, intermediary domain).
+115 """
+116 predictions : dict [ tuple [ str , str ], torch . Tensor ] = {}
+117 for domains_source , latents_source in latent_domains . items ():
+118 if len ( domains_source ) > 1 :
+119 continue
+120 domain_name_source = next ( iter ( domains_source ))
+121 for domain_name_through in through_domains :
+122 if domain_name_source == domain_name_through :
+123 continue
+124 z = cycle (
+125 gw_mod , selection_mod , latents_source , through = domain_name_through
+126 )
+127 domains = ( domain_name_source , domain_name_through )
+128 predictions [ domains ] = z [ domain_name_source ]
+129 return predictions
+
+
+
+ Computes cycles of a batch of groups of domains.
+
+
Arguments:
+
+
+gw_mod (GWModuleBase
): GWModule to use for the cycle
+selection_mod (SelectionBase
): selection module
+latent_domains (LatentsT
): the batch of groups of domains
+out_domains (Iterable[str]
): iterable of domain names to do the cycle through.
+Each domain will be done separetely.
+
+
+
Returns:
+
+
+ dict[tuple[str, str], torch.Tensor]
: cycles predictions for each
+ couple of (start domain, intermediary domain).
+
+
+
+
+
+
+
+
+
+ 132 def batch_translations (
+133 gw_mod : GWModuleBase ,
+134 selection_mod : SelectionBase ,
+135 latent_domains : LatentsDomainGroupsT ,
+136 ) -> dict [ tuple [ str , str ], torch . Tensor ]:
+137 """
+138 Computes translations of a batch of groups of domains.
+139
+140 Args:
+141 gw_mod (`GWModuleBase`): GWModule to do the translation
+142 selection_mod (`SelectionBase`): selection module
+143 latent_domains (`LatentsT`): the batch of groups of domains
+144
+145 Returns:
+146 `dict[tuple[str, str], torch.Tensor]`: translation predictions for each
+147 couple of (start domain, target domain).
+148 """
+149 predictions : dict [ tuple [ str , str ], torch . Tensor ] = {}
+150 for domains , latents in latent_domains . items ():
+151 if len ( domains ) < 2 :
+152 continue
+153 for domain_name_source in domains :
+154 for domain_name_target in domains :
+155 if domain_name_source == domain_name_target :
+156 continue
+157 prediction = translation (
+158 gw_mod ,
+159 selection_mod ,
+160 { domain_name_source : latents [ domain_name_source ]},
+161 to = domain_name_target ,
+162 )
+163 predictions [( domain_name_source , domain_name_target )] = prediction
+164 return predictions
+
+
+
+ Computes translations of a batch of groups of domains.
+
+
Arguments:
+
+
+gw_mod (GWModuleBase
): GWModule to do the translation
+selection_mod (SelectionBase
): selection module
+latent_domains (LatentsT
): the batch of groups of domains
+
+
+
Returns:
+
+
+ dict[tuple[str, str], torch.Tensor]
: translation predictions for each
+ couple of (start domain, target domain).
+
+
+
+
+
+
+
+
\ No newline at end of file
diff --git a/docs/api/v0.5.1/shimmer/modules/vae.html b/docs/api/v0.5.1/shimmer/modules/vae.html
new file mode 100644
index 00000000..c251cdcb
--- /dev/null
+++ b/docs/api/v0.5.1/shimmer/modules/vae.html
@@ -0,0 +1,1288 @@
+
+
+
+
+
+
+ shimmer.modules.vae API documentation
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+shimmer.modules.vae
+
+
+
+
+ View Source
+
+ 1 import math
+ 2 from abc import ABC , abstractmethod
+ 3 from typing import Any
+ 4
+ 5 import torch
+ 6 from torch import nn
+ 7
+ 8
+ 9 def reparameterize ( mean : torch . Tensor , logvar : torch . Tensor ) -> torch . Tensor :
+ 10 """
+ 11 Reparameterization trick for VAE
+ 12
+ 13 Args:
+ 14 mean (`torch.Tensor`): predicted means
+ 15 logvar (`torch.Tensor`): predicted log variance
+ 16
+ 17 Returns:
+ 18 `torch.Tensor`: a sample from normal distribution with provided
+ 19 parameters, sampled using the reparameterization trick.
+ 20 """
+ 21 std = ( 0.5 * logvar ) . exp ()
+ 22 eps = torch . randn_like ( std )
+ 23 return std * eps + mean
+ 24
+ 25
+ 26 def kl_divergence_loss ( mean : torch . Tensor , logvar : torch . Tensor ) -> torch . Tensor :
+ 27 """
+ 28 Computes the KL divergence loss used in VAE.
+ 29
+ 30 Args:
+ 31 mean (`torch.Tensor`): predicted means
+ 32 logvar (`torch.Tensor`): predicted logvars
+ 33
+ 34 Returns:
+ 35 `torch.Tensor`: the loss
+ 36 """
+ 37 kl = - 0.5 * torch . sum ( 1 + logvar - mean . pow ( 2 ) - logvar . exp ())
+ 38 return kl
+ 39
+ 40
+ 41 def gaussian_nll (
+ 42 mu : torch . Tensor , log_sigma : torch . Tensor , x : torch . Tensor
+ 43 ) -> torch . Tensor :
+ 44 """
+ 45 Computes gaussian nll loss used in VAE.
+ 46
+ 47 Args:
+ 48 mu (`torch.Tensor`): predictions
+ 49 log_sigma (`torch.Tensor`): log sigma
+ 50 x (`torch.Tensor`): targets
+ 51
+ 52 Returns:
+ 53 `torch.Tensor`: the Gaussian NLL loss
+ 54 """
+ 55 return (
+ 56 0.5 * torch . pow (( x - mu ) / log_sigma . exp (), 2 )
+ 57 + log_sigma
+ 58 + 0.5 * math . log ( 2 * math . pi )
+ 59 )
+ 60
+ 61
+ 62 class VAEEncoder ( nn . Module , ABC ):
+ 63 """
+ 64 Base class for a VAE encoder.
+ 65 """
+ 66
+ 67 @abstractmethod
+ 68 def forward ( self , x : Any ) -> tuple [ torch . Tensor , torch . Tensor ]:
+ 69 """
+ 70 Encode representation with VAE.
+ 71
+ 72
+ 73 Args:
+ 74 x (`Any`): Some input value
+ 75
+ 76 Returns:
+ 77 `tuple[torch.Tensor, torch.Tensor]`: the mean and log variance
+ 78 """
+ 79 ...
+ 80
+ 81
+ 82 class VAEDecoder ( nn . Module , ABC ):
+ 83 """
+ 84 Base class for a VAE decoder.
+ 85 """
+ 86
+ 87 @abstractmethod
+ 88 def forward ( self , x : torch . Tensor ) -> Any :
+ 89 """
+ 90 Decode representation with VAE
+ 91
+ 92 Args:
+ 93 x (`torch.Tensor`): VAE latent representation representation
+ 94
+ 95 Returns:
+ 96 `Any`: the reconstructed input
+ 97 """
+ 98 ...
+ 99
+100
+101 class VAE ( nn . Module ):
+102 """VAE module"""
+103
+104 def __init__ (
+105 self ,
+106 encoder : VAEEncoder ,
+107 decoder : VAEDecoder ,
+108 beta : float = 1 ,
+109 ):
+110 """
+111 Initializes a VAE.
+112
+113 Args:
+114 encoder (`VAEEncoder`): VAE encode
+115 decoder (`VAEDecoder`): VAE decoder
+116 beta (`float`): beta value for Beta-VAE. Defaults to 1.
+117 """
+118 super () . __init__ ()
+119
+120 assert beta >= 0
+121
+122 self . beta = beta
+123 """Beta value for Beta-VAEs"""
+124
+125 self . encoder = encoder
+126 """The encoder"""
+127
+128 self . decoder = decoder
+129 """The decoder"""
+130
+131 def encode ( self , x : Any ) -> torch . Tensor :
+132 """
+133 Encode the representation and returns the mean prediction of VAE.
+134
+135 Args:
+136 x (`Any`): Some input value
+137
+138 Returns:
+139 `torch.Tensor`: The mean representation.
+140 """
+141 mean_z , _ = self . encoder ( x )
+142 return mean_z
+143
+144 def decode ( self , z : torch . Tensor ) -> Any :
+145 """
+146 Decode the VAE latent representation into input value.
+147
+148 Args:
+149 z (`torch.Tensor`): the VAE latent representation.
+150
+151 Returns:
+152 `Any`: the reconstructed input.
+153 """
+154 return self . decoder ( z )
+155
+156 def forward ( self , x : Any ) -> tuple [ tuple [ torch . Tensor , torch . Tensor ], Any ]:
+157 """
+158 Encode and decodes from x.
+159
+160 Args:
+161 x (`Any`): the input data
+162
+163 Returns:
+164 `tuple[tuple[torch.Tensor, torch.Tensor], Any]`: The
+165 first tuple contains the mean and logvar of the encoded input,
+166 the second item is the reconstructed input.
+167 """
+168 mean , logvar = self . encoder ( x )
+169 z = reparameterize ( mean , logvar )
+170
+171 x_reconstructed = self . decoder ( z )
+172
+173 return ( mean , logvar ), x_reconstructed
+
+
+
+
+
+
+
+
+ def
+ reparameterize (mean : torch . Tensor , logvar : torch . Tensor ) -> torch . Tensor :
+
+ View Source
+
+
+
+ 10 def reparameterize ( mean : torch . Tensor , logvar : torch . Tensor ) -> torch . Tensor :
+11 """
+12 Reparameterization trick for VAE
+13
+14 Args:
+15 mean (`torch.Tensor`): predicted means
+16 logvar (`torch.Tensor`): predicted log variance
+17
+18 Returns:
+19 `torch.Tensor`: a sample from normal distribution with provided
+20 parameters, sampled using the reparameterization trick.
+21 """
+22 std = ( 0.5 * logvar ) . exp ()
+23 eps = torch . randn_like ( std )
+24 return std * eps + mean
+
+
+
+ Reparameterization trick for VAE
+
+
Arguments:
+
+
+mean (torch.Tensor
): predicted means
+logvar (torch.Tensor
): predicted log variance
+
+
+
Returns:
+
+
+ torch.Tensor
: a sample from normal distribution with provided
+ parameters, sampled using the reparameterization trick.
+
+
+
+
+
+
+
+
+
+ def
+ kl_divergence_loss (mean : torch . Tensor , logvar : torch . Tensor ) -> torch . Tensor :
+
+ View Source
+
+
+
+ 27 def kl_divergence_loss ( mean : torch . Tensor , logvar : torch . Tensor ) -> torch . Tensor :
+28 """
+29 Computes the KL divergence loss used in VAE.
+30
+31 Args:
+32 mean (`torch.Tensor`): predicted means
+33 logvar (`torch.Tensor`): predicted logvars
+34
+35 Returns:
+36 `torch.Tensor`: the loss
+37 """
+38 kl = - 0.5 * torch . sum ( 1 + logvar - mean . pow ( 2 ) - logvar . exp ())
+39 return kl
+
+
+
+ Computes the KL divergence loss used in VAE.
+
+
Arguments:
+
+
+mean (torch.Tensor
): predicted means
+logvar (torch.Tensor
): predicted logvars
+
+
+
Returns:
+
+
+ torch.Tensor
: the loss
+
+
+
+
+
+
+
+
+
+ def
+ gaussian_nll ( mu : torch . Tensor , log_sigma : torch . Tensor , x : torch . Tensor ) -> torch . Tensor :
+
+ View Source
+
+
+
+ 42 def gaussian_nll (
+43 mu : torch . Tensor , log_sigma : torch . Tensor , x : torch . Tensor
+44 ) -> torch . Tensor :
+45 """
+46 Computes gaussian nll loss used in VAE.
+47
+48 Args:
+49 mu (`torch.Tensor`): predictions
+50 log_sigma (`torch.Tensor`): log sigma
+51 x (`torch.Tensor`): targets
+52
+53 Returns:
+54 `torch.Tensor`: the Gaussian NLL loss
+55 """
+56 return (
+57 0.5 * torch . pow (( x - mu ) / log_sigma . exp (), 2 )
+58 + log_sigma
+59 + 0.5 * math . log ( 2 * math . pi )
+60 )
+
+
+
+ Computes gaussian nll loss used in VAE.
+
+
Arguments:
+
+
+mu (torch.Tensor
): predictions
+log_sigma (torch.Tensor
): log sigma
+x (torch.Tensor
): targets
+
+
+
Returns:
+
+
+ torch.Tensor
: the Gaussian NLL loss
+
+
+
+
+
+
+
+
+
+ class
+ VAEEncoder (torch.nn.modules.module.Module , abc.ABC ):
+
+ View Source
+
+
+
+ 63 class VAEEncoder ( nn . Module , ABC ):
+64 """
+65 Base class for a VAE encoder.
+66 """
+67
+68 @abstractmethod
+69 def forward ( self , x : Any ) -> tuple [ torch . Tensor , torch . Tensor ]:
+70 """
+71 Encode representation with VAE.
+72
+73
+74 Args:
+75 x (`Any`): Some input value
+76
+77 Returns:
+78 `tuple[torch.Tensor, torch.Tensor]`: the mean and log variance
+79 """
+80 ...
+
+
+
+ Base class for a VAE encoder.
+
+
+
+
+
+
+
@abstractmethod
+
+
def
+
forward (self , x : Any ) -> tuple [ torch . Tensor , torch . Tensor ] :
+
+
View Source
+
+
+
+
68 @abstractmethod
+69 def forward ( self , x : Any ) -> tuple [ torch . Tensor , torch . Tensor ]:
+70 """
+71 Encode representation with VAE.
+72
+73
+74 Args:
+75 x (`Any`): Some input value
+76
+77 Returns:
+78 `tuple[torch.Tensor, torch.Tensor]`: the mean and log variance
+79 """
+80 ...
+
+
+
+
Encode representation with VAE.
+
+
Arguments:
+
+
+x (Any
): Some input value
+
+
+
Returns:
+
+
+ tuple[torch.Tensor, torch.Tensor]
: the mean and log variance
+
+
+
+
+
+
+
Inherited Members
+
+
torch.nn.modules.module.Module
+ Module
+ dump_patches
+ training
+ call_super_init
+ register_buffer
+ register_parameter
+ add_module
+ register_module
+ get_submodule
+ get_parameter
+ get_buffer
+
+
+ apply
+ cuda
+ ipu
+ xpu
+ cpu
+ type
+ float
+ double
+ half
+ bfloat16
+ to_empty
+ to
+ register_full_backward_pre_hook
+ register_backward_hook
+ register_full_backward_hook
+ register_forward_pre_hook
+ register_forward_hook
+ register_state_dict_pre_hook
+ state_dict
+ register_load_state_dict_post_hook
+ load_state_dict
+ parameters
+ named_parameters
+ buffers
+ named_buffers
+ children
+ named_children
+ modules
+ named_modules
+ train
+ eval
+ requires_grad_
+ zero_grad
+ share_memory
+
+ compile
+
+
+
+
+
+
+
+
+
+ class
+ VAEDecoder (torch.nn.modules.module.Module , abc.ABC ):
+
+ View Source
+
+
+
+ 83 class VAEDecoder ( nn . Module , ABC ):
+84 """
+85 Base class for a VAE decoder.
+86 """
+87
+88 @abstractmethod
+89 def forward ( self , x : torch . Tensor ) -> Any :
+90 """
+91 Decode representation with VAE
+92
+93 Args:
+94 x (`torch.Tensor`): VAE latent representation representation
+95
+96 Returns:
+97 `Any`: the reconstructed input
+98 """
+99 ...
+
+
+
+ Base class for a VAE decoder.
+
+
+
+
+
+
+
@abstractmethod
+
+
def
+
forward (self , x : torch . Tensor ) -> Any :
+
+
View Source
+
+
+
+
88 @abstractmethod
+89 def forward ( self , x : torch . Tensor ) -> Any :
+90 """
+91 Decode representation with VAE
+92
+93 Args:
+94 x (`torch.Tensor`): VAE latent representation representation
+95
+96 Returns:
+97 `Any`: the reconstructed input
+98 """
+99 ...
+
+
+
+
Decode representation with VAE
+
+
Arguments:
+
+
+x (torch.Tensor
): VAE latent representation representation
+
+
+
Returns:
+
+
+ Any
: the reconstructed input
+
+
+
+
+
+
+
Inherited Members
+
+
torch.nn.modules.module.Module
+ Module
+ dump_patches
+ training
+ call_super_init
+ register_buffer
+ register_parameter
+ add_module
+ register_module
+ get_submodule
+ get_parameter
+ get_buffer
+
+
+ apply
+ cuda
+ ipu
+ xpu
+ cpu
+ type
+ float
+ double
+ half
+ bfloat16
+ to_empty
+ to
+ register_full_backward_pre_hook
+ register_backward_hook
+ register_full_backward_hook
+ register_forward_pre_hook
+ register_forward_hook
+ register_state_dict_pre_hook
+ state_dict
+ register_load_state_dict_post_hook
+ load_state_dict
+ parameters
+ named_parameters
+ buffers
+ named_buffers
+ children
+ named_children
+ modules
+ named_modules
+ train
+ eval
+ requires_grad_
+ zero_grad
+ share_memory
+
+ compile
+
+
+
+
+
+
+
+
+
+ class
+ VAE (torch.nn.modules.module.Module ):
+
+ View Source
+
+
+
+ 102 class VAE ( nn . Module ):
+103 """VAE module"""
+104
+105 def __init__ (
+106 self ,
+107 encoder : VAEEncoder ,
+108 decoder : VAEDecoder ,
+109 beta : float = 1 ,
+110 ):
+111 """
+112 Initializes a VAE.
+113
+114 Args:
+115 encoder (`VAEEncoder`): VAE encode
+116 decoder (`VAEDecoder`): VAE decoder
+117 beta (`float`): beta value for Beta-VAE. Defaults to 1.
+118 """
+119 super () . __init__ ()
+120
+121 assert beta >= 0
+122
+123 self . beta = beta
+124 """Beta value for Beta-VAEs"""
+125
+126 self . encoder = encoder
+127 """The encoder"""
+128
+129 self . decoder = decoder
+130 """The decoder"""
+131
+132 def encode ( self , x : Any ) -> torch . Tensor :
+133 """
+134 Encode the representation and returns the mean prediction of VAE.
+135
+136 Args:
+137 x (`Any`): Some input value
+138
+139 Returns:
+140 `torch.Tensor`: The mean representation.
+141 """
+142 mean_z , _ = self . encoder ( x )
+143 return mean_z
+144
+145 def decode ( self , z : torch . Tensor ) -> Any :
+146 """
+147 Decode the VAE latent representation into input value.
+148
+149 Args:
+150 z (`torch.Tensor`): the VAE latent representation.
+151
+152 Returns:
+153 `Any`: the reconstructed input.
+154 """
+155 return self . decoder ( z )
+156
+157 def forward ( self , x : Any ) -> tuple [ tuple [ torch . Tensor , torch . Tensor ], Any ]:
+158 """
+159 Encode and decodes from x.
+160
+161 Args:
+162 x (`Any`): the input data
+163
+164 Returns:
+165 `tuple[tuple[torch.Tensor, torch.Tensor], Any]`: The
+166 first tuple contains the mean and logvar of the encoded input,
+167 the second item is the reconstructed input.
+168 """
+169 mean , logvar = self . encoder ( x )
+170 z = reparameterize ( mean , logvar )
+171
+172 x_reconstructed = self . decoder ( z )
+173
+174 return ( mean , logvar ), x_reconstructed
+
+
+
+
+
+
+
+
+
+
+
105 def __init__ (
+106 self ,
+107 encoder : VAEEncoder ,
+108 decoder : VAEDecoder ,
+109 beta : float = 1 ,
+110 ):
+111 """
+112 Initializes a VAE.
+113
+114 Args:
+115 encoder (`VAEEncoder`): VAE encode
+116 decoder (`VAEDecoder`): VAE decoder
+117 beta (`float`): beta value for Beta-VAE. Defaults to 1.
+118 """
+119 super () . __init__ ()
+120
+121 assert beta >= 0
+122
+123 self . beta = beta
+124 """Beta value for Beta-VAEs"""
+125
+126 self . encoder = encoder
+127 """The encoder"""
+128
+129 self . decoder = decoder
+130 """The decoder"""
+
+
+
+
Initializes a VAE.
+
+
Arguments:
+
+
+encoder (VAEEncoder
): VAE encode
+decoder (VAEDecoder
): VAE decoder
+beta (float
): beta value for Beta-VAE. Defaults to 1.
+
+
+
+
+
+
+
+ beta
+
+
+
+
+
+
Beta value for Beta-VAEs
+
+
+
+
+
+
+ encoder
+
+
+
+
+
+
+
+
+
+
+
+ decoder
+
+
+
+
+
+
+
+
+
+
+
+
+
+ def
+ encode (self , x : Any ) -> torch . Tensor :
+
+ View Source
+
+
+
+
132 def encode ( self , x : Any ) -> torch . Tensor :
+133 """
+134 Encode the representation and returns the mean prediction of VAE.
+135
+136 Args:
+137 x (`Any`): Some input value
+138
+139 Returns:
+140 `torch.Tensor`: The mean representation.
+141 """
+142 mean_z , _ = self . encoder ( x )
+143 return mean_z
+
+
+
+
Encode the representation and returns the mean prediction of VAE.
+
+
Arguments:
+
+
+x (Any
): Some input value
+
+
+
Returns:
+
+
+ torch.Tensor
: The mean representation.
+
+
+
+
+
+
+
+
+
+ def
+ decode (self , z : torch . Tensor ) -> Any :
+
+ View Source
+
+
+
+
145 def decode ( self , z : torch . Tensor ) -> Any :
+146 """
+147 Decode the VAE latent representation into input value.
+148
+149 Args:
+150 z (`torch.Tensor`): the VAE latent representation.
+151
+152 Returns:
+153 `Any`: the reconstructed input.
+154 """
+155 return self . decoder ( z )
+
+
+
+
Decode the VAE latent representation into input value.
+
+
Arguments:
+
+
+z (torch.Tensor
): the VAE latent representation.
+
+
+
Returns:
+
+
+ Any
: the reconstructed input.
+
+
+
+
+
+
+
+
+
+ def
+ forward (self , x : Any ) -> tuple [ tuple [ torch . Tensor , torch . Tensor ], typing . Any ] :
+
+ View Source
+
+
+
+
157 def forward ( self , x : Any ) -> tuple [ tuple [ torch . Tensor , torch . Tensor ], Any ]:
+158 """
+159 Encode and decodes from x.
+160
+161 Args:
+162 x (`Any`): the input data
+163
+164 Returns:
+165 `tuple[tuple[torch.Tensor, torch.Tensor], Any]`: The
+166 first tuple contains the mean and logvar of the encoded input,
+167 the second item is the reconstructed input.
+168 """
+169 mean , logvar = self . encoder ( x )
+170 z = reparameterize ( mean , logvar )
+171
+172 x_reconstructed = self . decoder ( z )
+173
+174 return ( mean , logvar ), x_reconstructed
+
+
+
+
Encode and decodes from x.
+
+
Arguments:
+
+
+x (Any
): the input data
+
+
+
Returns:
+
+
+ tuple[tuple[torch.Tensor, torch.Tensor], Any]
: The
+ first tuple contains the mean and logvar of the encoded input,
+ the second item is the reconstructed input.
+
+
+
+
+
+
+
Inherited Members
+
+
torch.nn.modules.module.Module
+ dump_patches
+ training
+ call_super_init
+ register_buffer
+ register_parameter
+ add_module
+ register_module
+ get_submodule
+ get_parameter
+ get_buffer
+
+
+ apply
+ cuda
+ ipu
+ xpu
+ cpu
+ type
+ float
+ double
+ half
+ bfloat16
+ to_empty
+ to
+ register_full_backward_pre_hook
+ register_backward_hook
+ register_full_backward_hook
+ register_forward_pre_hook
+ register_forward_hook
+ register_state_dict_pre_hook
+ state_dict
+ register_load_state_dict_post_hook
+ load_state_dict
+ parameters
+ named_parameters
+ buffers
+ named_buffers
+ children
+ named_children
+ modules
+ named_modules
+ train
+ eval
+ requires_grad_
+ zero_grad
+ share_memory
+
+ compile
+
+
+
+
+
+
+
+
\ No newline at end of file
diff --git a/docs/api/v0.5.1/shimmer/types.html b/docs/api/v0.5.1/shimmer/types.html
new file mode 100644
index 00000000..a1057714
--- /dev/null
+++ b/docs/api/v0.5.1/shimmer/types.html
@@ -0,0 +1,872 @@
+
+
+
+
+
+
+ shimmer.types API documentation
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+shimmer.types
+
+
+
+
+ View Source
+
+ 1 from collections.abc import Mapping
+ 2 from typing import Any , Literal
+ 3
+ 4 import torch
+ 5
+ 6 RawDomainGroupT = Mapping [ str , Any ]
+ 7 """
+ 8 Matched raw unimodal data from multiple domains.
+ 9 Keys of the mapping are domains names and values are the domain data.
+ 10
+ 11 All values in the mapping should be matched and represent the same information.
+ 12
+ 13 Example:
+ 14 ```python
+ 15 def fun(domain_group: RawDomainGroupT): ...
+ 16
+ 17
+ 18 x = {
+ 19 "vision": PIL.Image.Image("path/to/dog/picture.png"),
+ 20 "language": "This is a picture of a dog.",
+ 21 }
+ 22
+ 23 fun(x)
+ 24 ```
+ 25
+ 26 Note:
+ 27 This type uses `collections.abc.Mapping` and is used for functions' inputs.
+ 28 Use `RawDomainGroupDT` for functions' outputs.
+ 29
+ 30 This allows to be more generic and allow passing other mappings.
+ 31 """
+ 32
+ 33 RawDomainGroupDT = dict [ str , Any ]
+ 34 """
+ 35 Output type version of `RawDomainGroupT`.
+ 36 Matched raw unimodal data from multiple domains.
+ 37 Keys of the mapping are domains names and values are the domain data.
+ 38
+ 39 Example:
+ 40 ```python
+ 41 def fun() -> RawDomainGroupDT:
+ 42 return {
+ 43 "vision": PIL.Image.Image("path/to/dog/picture.png"),
+ 44 "language": "This is a picture of a dog.",
+ 45 }
+ 46
+ 47 ```
+ 48
+ 49 Note:
+ 50 This type uses `dict`s and is used for functions' outputs.
+ 51 Use `RawDomainGroupT` for functions' inputs.
+ 52
+ 53 """
+ 54
+ 55 LatentsDomainGroupT = Mapping [ str , torch . Tensor ]
+ 56 """
+ 57 Matched unimodal latent representations from multiple domains.
+ 58 Keys of the mapping are domains names and values are `torch.Tensor` latent
+ 59 representation of the domain.
+ 60
+ 61 Example:
+ 62 ```python
+ 63 def fun(domain_group: LatentsDomainGroupT): ...
+ 64
+ 65
+ 66 x = {
+ 67 "vision": torch.Tensor([0.0, 1.0, 0.0, ...]),
+ 68 "language": torch.Tensor([0.0, 0.3, 0.2, ...]),
+ 69 }
+ 70
+ 71 fun(x)
+ 72 ```
+ 73
+ 74 Note:
+ 75 This type uses `collections.abc.Mapping` and is used for functions' inputs.
+ 76 Use `LatentsDomainGroupDT` for functions' outputs.
+ 77
+ 78 This allows to be more generic and allow passing other mappings.
+ 79 """
+ 80
+ 81 LatentsDomainGroupDT = dict [ str , torch . Tensor ]
+ 82 """
+ 83 Matched unimodal latent representations from multiple domains.
+ 84 Keys of the dict are domains names and values are `torch.Tensor` latent
+ 85 representation of the domain.
+ 86
+ 87 Example:
+ 88 ```python
+ 89 def fun() -> LatentsDomainGroupDT:
+ 90 return {
+ 91 "vision": torch.Tensor([0.0, 1.0, 0.0, ...]),
+ 92 "language": torch.Tensor([0.0, 0.3, 0.2, ...]),
+ 93 }
+ 94
+ 95 ```
+ 96
+ 97 Note:
+ 98 This type uses `dict`s and is used for functions' outputs.
+ 99 Use `LatentsDomainGroupT` for functions' inputs.
+100 """
+101
+102 RawDomainGroupsT = Mapping [ frozenset [ str ], RawDomainGroupT ]
+103 """
+104 Mapping of `RawDomainGroupT`. Keys are frozenset of domains matched in the group.
+105 Each group is independent and contains different data (unpaired).
+106
+107 Example:
+108 ```python
+109 def fun() -> RawDomainGroupsDT:
+110 return {
+111 frozenset(["vision"]): {
+112 "vision": PIL.Image.Image("path/to/cat/picture.png"),
+113 },
+114 frozenset(["language"]): {
+115 "language": "This is a picture of a rabbit.",
+116 },
+117 frozenset(["vision", "language"]): {
+118 "vision": PIL.Image.Image("path/to/dog/picture.png"),
+119 "language": "This is a picture of a dog.",
+120 },
+121 }
+122
+123 ```
+124
+125 Note:
+126     This type uses `collections.abc.Mapping` and is used for functions' inputs.
+127     Use `RawDomainGroupsDT` for functions' outputs.
+128 """
+129
+130 RawDomainGroupsDT = dict [ frozenset [ str ], RawDomainGroupDT ]
+131 """
+132 Mapping of `RawDomainGroupT`. Keys are frozenset of domains matched in the group.
+133 Each group is independent and contains different data (unpaired).
+134
+135 Example:
+136 ```python
+137 def fun() -> RawDomainGroupsDT:
+138 return {
+139 frozenset(["vision"]): {
+140 "vision": PIL.Image.Image("path/to/cat/picture.png"),
+141 },
+142 frozenset(["language"]): {
+143 "language": "This is a picture of a rabbit.",
+144 },
+145 frozenset(["vision", "language"]): {
+146 "vision": PIL.Image.Image("path/to/dog/picture.png"),
+147 "language": "This is a picture of a dog.",
+148 },
+149 }
+150
+151 ```
+152
+153 Note:
+154 This type uses `dict`s and is used for functions' outputs.
+155 Use `RawDomainGroupsT` for functions' inputs.
+156 """
+157
+158 LatentsDomainGroupsT = Mapping [ frozenset [ str ], LatentsDomainGroupT ]
+159 """
+160 Mapping of `LatentsDomainGroupT`. Keys are frozenset of domains matched in the group.
+161 Each group is independent and contains different data (unpaired).
+162
+163 Example:
+164 ```python
+165 def fun(domain_group: LatentsDomainGroupsT): ...
+166
+167
+168 x = {
+169 frozenset(["vision"]): {
+170 "vision": torch.Tensor([1.0, 0.0, 0.3, ...]),
+171 },
+172 frozenset(["language"]): {
+173 "language": torch.Tensor([1.0, 0.2, 0.9, ...]),
+174 },
+175 frozenset(["vision", "language"]): {
+176 "vision": torch.Tensor([0.0, 1.0, 0.0, ...]),
+177 "language": torch.Tensor([0.0, 0.3, 0.2, ...]),
+178 },
+179 }
+180
+181 fun(x)
+182 ```
+183 Note:
+184 This type uses `collections.abc.Mapping` and is used for functions' inputs.
+185 Use `LatentsDomainGroupsDT` for functions' outputs.
+186
+187 This allows to be more generic and allow passing other mappings.
+188
+189 """
+190
+191 LatentsDomainGroupsDT = dict [ frozenset [ str ], LatentsDomainGroupDT ]
+192 """
+193 Mapping of `LatentsDomainGroupDT`.
+194 Keys are frozenset of domains matched in the group.
+195 Each group is independent and contains different data (unpaired).
+196
+197 Example:
+198 ```python
+199 def fun() -> LatentsDomainGroupsDT:
+200 return {
+201 frozenset(["vision"]): {
+202 "vision": torch.Tensor([1.0, 0.0, 0.3, ...]),
+203 },
+204 frozenset(["language"]): {
+205 "language": torch.Tensor([1.0, 0.2, 0.9, ...]),
+206 },
+207 frozenset(["vision", "language"]): {
+208 "vision": torch.Tensor([0.0, 1.0, 0.0, ...]),
+209 "language": torch.Tensor([0.0, 0.3, 0.2, ...]),
+210 },
+211 }
+212
+213 ```
+214
+215 Note:
+216 This type uses `dict`s and is used for functions' outputs.
+217 Use `LatentsDomainGroupsT` for functions' inputs.
+218 """
+219
+220
+221 ModelModeT = Literal [ "train" , "val" , "test" , "val/ood" , "test/ood" ]
+222 """
+223 Mode used by pytorch lightning (train/val, ...).
+224
+225 When validating or testing in out-of-distribution data, "val/ood" or "test/ood" mode is
+226 used.
+227 """
+
+
+
+
+
+
+ RawDomainGroupT =
+collections.abc.Mapping[str, typing.Any]
+
+
+
+
+
+ Matched raw unimodal data from multiple domains.
+Keys of the mapping are domains names and values are the domain data.
+
+
All values in the mapping should be matched and represent the same information.
+
+
Example:
+
+
+
+
def fun ( domain_group : RawDomainGroupT ): ...
+
+
+x = {
+ "vision" : PIL . Image . Image ( "path/to/dog/picture.png" ),
+ "language" : "This is a picture of a dog." ,
+}
+
+fun ( x )
+
+
+
+
+
Note:
+
+
+ This type uses collections.abc.Mapping
and is used for functions' inputs.
+ Use RawDomainGroupDT
for functions' outputs.
+
+ This allows to be more generic and allow passing other mappings.
+
+
+
+
+
+
+
+ RawDomainGroupDT =
+dict[str, typing.Any]
+
+
+
+
+
+ Output type version of RawDomainGroupT
.
+Matched raw unimodal data from multiple domains.
+Keys of the mapping are domains names and values are the domain data.
+
+
Example:
+
+
+
+
def fun () -> RawDomainGroupDT :
+ return {
+ "vision" : PIL . Image . Image ( "path/to/dog/picture.png" ),
+ "language" : "This is a picture of a dog." ,
+ }
+
+
+
+
+
Note:
+
+
+ This type uses dict
s and is used for functions' outputs.
+ Use RawDomainGroupT
for functions' inputs.
+
+
+
+
+
+
+
+ LatentsDomainGroupT =
+collections.abc.Mapping[str, torch.Tensor]
+
+
+
+
+
+ Matched unimodal latent representations from multiple domains.
+Keys of the mapping are domains names and values are torch.Tensor
latent
+representation of the domain.
+
+
Example:
+
+
+
+
def fun ( domain_group : LatentsDomainGroupT ): ...
+
+
+x = {
+ "vision" : torch . Tensor ([ 0.0 , 1.0 , 0.0 , ... ]),
+ "language" : torch . Tensor ([ 0.0 , 0.3 , 0.2 , ... ]),
+}
+
+fun ( x )
+
+
+
+
+
Note:
+
+
+ This type uses collections.abc.Mapping
and is used for functions' inputs.
+ Use LatentsDomainGroupDT
for functions' outputs.
+
+ This allows to be more generic and allow passing other mappings.
+
+
+
+
+
+
+
+ LatentsDomainGroupDT =
+dict[str, torch.Tensor]
+
+
+
+
+
+ Matched unimodal latent representations from multiple domains.
+Keys of the dict are domains names and values are torch.Tensor
latent
+representation of the domain.
+
+
Example:
+
+
+
+
def fun () -> LatentsDomainGroupDT :
+ return {
+ "vision" : torch . Tensor ([ 0.0 , 1.0 , 0.0 , ... ]),
+ "language" : torch . Tensor ([ 0.0 , 0.3 , 0.2 , ... ]),
+ }
+
+
+
+
+
Note:
+
+
+ This type uses dict
s and is used for functions' outputs.
+ Use LatentsDomainGroupT
for functions' inputs.
+
+
+
+
+
+
+
+ RawDomainGroupsT =
+collections.abc.Mapping[frozenset[str], collections.abc.Mapping[str, typing.Any]]
+
+
+
+
+
+ Mapping of RawDomainGroupT
. Keys are frozenset of domains matched in the group.
+Each group is independent and contains different data (unpaired).
+
+
Example:
+
+
+
+
def fun () -> RawDomainGroupsDT :
+ return {
+ frozenset ([ "vision" ]): {
+ "vision" : PIL . Image . Image ( "path/to/cat/picture.png" ),
+ },
+ frozenset ([ "language" ]): {
+ "language" : "This is a picture of a rabbit." ,
+ },
+ frozenset ([ "vision" , "language" ]): {
+ "vision" : PIL . Image . Image ( "path/to/dog/picture.png" ),
+ "language" : "This is a picture of a dog." ,
+ },
+ }
+
+
+
+
+
Note:
+
+
+  This type uses  collections.abc.Mapping
  and is used for functions' inputs.
+  Use  RawDomainGroupsDT
  for functions' outputs.
+
+
+
+
+
+
+
+ RawDomainGroupsDT =
+dict[frozenset[str], dict[str, typing.Any]]
+
+
+
+
+
+ Mapping of RawDomainGroupT
. Keys are frozenset of domains matched in the group.
+Each group is independent and contains different data (unpaired).
+
+
Example:
+
+
+
+
def fun () -> RawDomainGroupsDT :
+ return {
+ frozenset ([ "vision" ]): {
+ "vision" : PIL . Image . Image ( "path/to/cat/picture.png" ),
+ },
+ frozenset ([ "language" ]): {
+ "language" : "This is a picture of a rabbit." ,
+ },
+ frozenset ([ "vision" , "language" ]): {
+ "vision" : PIL . Image . Image ( "path/to/dog/picture.png" ),
+ "language" : "This is a picture of a dog." ,
+ },
+ }
+
+
+
+
+
Note:
+
+
+ This type uses dict
s and is used for functions' outputs.
+ Use RawDomainGroupsT
for functions' inputs.
+
+
+
+
+
+
+
+ LatentsDomainGroupsT =
+collections.abc.Mapping[frozenset[str], collections.abc.Mapping[str, torch.Tensor]]
+
+
+
+
+
+ Mapping of LatentsDomainGroupT
. Keys are frozenset of domains matched in the group.
+Each group is independent and contains different data (unpaired).
+
+
Example:
+
+
+
+
def fun ( domain_group : LatentsDomainGroupsT ): ...
+
+
+x = {
+ frozenset ([ "vision" ]): {
+ "vision" : torch . Tensor ([ 1.0 , 0.0 , 0.3 , ... ]),
+ },
+ frozenset ([ "language" ]): {
+ "language" : torch . Tensor ([ 1.0 , 0.2 , 0.9 , ... ]),
+ },
+ frozenset ([ "vision" , "language" ]): {
+ "vision" : torch . Tensor ([ 0.0 , 1.0 , 0.0 , ... ]),
+ "language" : torch . Tensor ([ 0.0 , 0.3 , 0.2 , ... ]),
+ },
+}
+
+fun ( x )
+
+
+
+
+
Note:
+
+
+ This type uses collections.abc.Mapping
and is used for functions' inputs.
+ Use LatentsDomainGroupsDT
for functions' outputs.
+
+ This allows to be more generic and allow passing other mappings.
+
+
+
+
+
+
+
+ LatentsDomainGroupsDT =
+dict[frozenset[str], dict[str, torch.Tensor]]
+
+
+
+
+
+ Mapping of LatentsDomainGroupDT
.
+Keys are frozenset of domains matched in the group.
+Each group is independent and contains different data (unpaired).
+
+
Example:
+
+
+
+
def fun () -> LatentsDomainGroupsDT :
+ return {
+ frozenset ([ "vision" ]): {
+ "vision" : torch . Tensor ([ 1.0 , 0.0 , 0.3 , ... ]),
+ },
+ frozenset ([ "language" ]): {
+ "language" : torch . Tensor ([ 1.0 , 0.2 , 0.9 , ... ]),
+ },
+ frozenset ([ "vision" , "language" ]): {
+ "vision" : torch . Tensor ([ 0.0 , 1.0 , 0.0 , ... ]),
+ "language" : torch . Tensor ([ 0.0 , 0.3 , 0.2 , ... ]),
+ },
+ }
+
+
+
+
+
Note:
+
+
+ This type uses dict
s and is used for functions' outputs.
+  Use  LatentsDomainGroupsT
for functions' inputs.
+
+
+
+
+
+
+
+ ModelModeT =
+typing.Literal['train', 'val', 'test', 'val/ood', 'test/ood']
+
+
+
+
+
+ Mode used by pytorch lightning (train/val, ...).
+
+
When validating or testing in out-of-distribution data, "val/ood" or "test/ood" mode is
+used.
+
+
+
+
+
+
+
\ No newline at end of file
diff --git a/docs/api/v0.5.1/shimmer/utils.html b/docs/api/v0.5.1/shimmer/utils.html
new file mode 100644
index 00000000..87a2b3bd
--- /dev/null
+++ b/docs/api/v0.5.1/shimmer/utils.html
@@ -0,0 +1,701 @@
+
+
+
+
+
+
+ shimmer.utils API documentation
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+shimmer.utils
+
+
+
+
+ View Source
+
+ 1 from os import PathLike
+ 2 from pathlib import Path
+ 3 from typing import Any
+ 4
+ 5 import torch
+ 6 from lightning.pytorch import Callback , LightningModule , Trainer
+ 7 from migrate_ckpt import (
+ 8 ckpt_migration_key ,
+ 9 get_folder_migrations ,
+10 migrate_from_folder ,
+11 )
+12
+13 from shimmer.types import LatentsDomainGroupsT , LatentsDomainGroupT
+14
+15 MIGRATION_DIR = Path ( __file__ ) . parent / "ckpt_migrations"
+16
+17
+18 def group_batch_size ( x : LatentsDomainGroupT ) -> int :
+19 for val in x . values ():
+20 return val . size ( 0 )
+21 raise ValueError ( "Got empty group." )
+22
+23
+24 def groups_batch_size ( domain_latents : LatentsDomainGroupsT ) -> int :
+25 """
+26 Get the batch size of the batch.
+27
+28 Args:
+29 domain_latents (`LatentsDomainGroupsT`): the batch of groups.
+30
+31 Returns:
+32 int: the batch size.
+33 """
+34 for data in domain_latents . values ():
+35 for tensor in data . values ():
+36 return tensor . size ( 0 )
+37 raise ValueError ( "Empty batch." )
+38
+39
+40 def groups_device ( domain_latents : LatentsDomainGroupsT ) -> int :
+41 """
+42 Get the batch size of the batch.
+43
+44 Args:
+45 domain_latents (`LatentsDomainGroupsT`): the batch of groups.
+46
+47 Returns:
+48 int: the batch size.
+49 """
+50 for data in domain_latents . values ():
+51 for tensor in data . values ():
+52 return tensor . size ( 0 )
+53 raise ValueError ( "Empty batch." )
+54
+55
+56 def group_device ( x : LatentsDomainGroupT ) -> torch . device :
+57 for val in x . values ():
+58 return val . device
+59 raise ValueError ( "Got empty group." )
+60
+61
+62 def migrate_model ( ckpt_path : str | PathLike , ** torch_load_kwargs ):
+63 """
+64 Migrates a model checkpoint
+65
+66 After the migration, the given checkpoint will be migrated.
+67 Other versions of the checkpoint will be saved under the stem-version.suffix.
+68
+69 Args:
+70 ckpt_path (`str | PathLike`): path to checkpoint
+71 torch_load_kwargs: additional args given to torch.load.
+72 """
+73 ckpt_path = Path ( ckpt_path )
+74 ckpt = torch . load ( ckpt_path , ** torch_load_kwargs )
+75 new_ckpt , done_migrations = migrate_from_folder ( ckpt , MIGRATION_DIR )
+76 done_migration_log = ", " . join ( map ( lambda x : x . name , done_migrations ))
+77 print ( f "Migrating: { done_migration_log } " )
+78 if len ( done_migrations ) or ckpt_migration_key not in ckpt :
+79 version = 0
+80 if ckpt_migration_key in ckpt :
+81 version = len ( ckpt [ ckpt_migration_key ])
+82 torch . save ( ckpt , ckpt_path . with_stem ( f " { ckpt_path . stem } - { version } " ))
+83 torch . save ( new_ckpt , ckpt_path )
+84
+85
+86 class SaveMigrations ( Callback ):
+87 def __init__ ( self ):
+88 self . migrations = get_folder_migrations ( MIGRATION_DIR )
+89
+90 def on_save_checkpoint (
+91 self , trainer : Trainer , pl_module : LightningModule , checkpoint : dict [ str , Any ]
+92 ):
+93 checkpoint [ ckpt_migration_key ] = [ mig . name for mig in self . migrations ]
+
+
+
+
+
+
+ MIGRATION_DIR =
+PosixPath('/home/runner/work/shimmer/shimmer/shimmer/ckpt_migrations')
+
+
+
+
+
+
+
+
+
+
+
+
+
+ def
+ groups_batch_size ( domain_latents : collections . abc . Mapping [ frozenset [ str ], collections . abc . Mapping [ str , torch . Tensor ]] ) -> int :
+
+ View Source
+
+
+
+ 25 def groups_batch_size ( domain_latents : LatentsDomainGroupsT ) -> int :
+26 """
+27 Get the batch size of the batch.
+28
+29 Args:
+30 domain_latents (`LatentsDomainGroupsT`): the batch of groups.
+31
+32 Returns:
+33 int: the batch size.
+34 """
+35 for data in domain_latents . values ():
+36 for tensor in data . values ():
+37 return tensor . size ( 0 )
+38 raise ValueError ( "Empty batch." )
+
+
+
+ Get the batch size of the batch.
+
+
Arguments:
+
+
+domain_latents (LatentsDomainGroupsT
): the batch of groups.
+
+
+
Returns:
+
+
+ int: the batch size.
+
+
+
+
+
+
+
+
+
+ def
+ groups_device ( domain_latents : collections . abc . Mapping [ frozenset [ str ], collections . abc . Mapping [ str , torch . Tensor ]] ) -> int :
+
+ View Source
+
+
+
+ 41 def groups_device ( domain_latents : LatentsDomainGroupsT ) -> int :
+42 """
+43 Get the batch size of the batch.
+44
+45 Args:
+46 domain_latents (`LatentsDomainGroupsT`): the batch of groups.
+47
+48 Returns:
+49 int: the batch size.
+50 """
+51 for data in domain_latents . values ():
+52 for tensor in data . values ():
+53 return tensor . size ( 0 )
+54 raise ValueError ( "Empty batch." )
+
+
+
+ Get the batch size of the batch.
+
+
Arguments:
+
+
+domain_latents (LatentsDomainGroupsT
): the batch of groups.
+
+
+
Returns:
+
+
+ int: the batch size.
+
+
+
+
+
+
+
+
+
+
+ def
+ migrate_model (ckpt_path : str | os . PathLike , ** torch_load_kwargs ):
+
+ View Source
+
+
+
+ 63 def migrate_model ( ckpt_path : str | PathLike , ** torch_load_kwargs ):
+64 """
+65 Migrates a model checkpoint
+66
+67 After the migration, the given checkpoint will be migrated.
+68 Other versions of the checkpoint will be saved under the stem-version.suffix.
+69
+70 Args:
+71 ckpt_path (`str | PathLike`): path to checkpoint
+72 torch_load_kwargs: additional args given to torch.load.
+73 """
+74 ckpt_path = Path ( ckpt_path )
+75 ckpt = torch . load ( ckpt_path , ** torch_load_kwargs )
+76 new_ckpt , done_migrations = migrate_from_folder ( ckpt , MIGRATION_DIR )
+77 done_migration_log = ", " . join ( map ( lambda x : x . name , done_migrations ))
+78 print ( f "Migrating: { done_migration_log } " )
+79 if len ( done_migrations ) or ckpt_migration_key not in ckpt :
+80 version = 0
+81 if ckpt_migration_key in ckpt :
+82 version = len ( ckpt [ ckpt_migration_key ])
+83 torch . save ( ckpt , ckpt_path . with_stem ( f " { ckpt_path . stem } - { version } " ))
+84 torch . save ( new_ckpt , ckpt_path )
+
+
+
+ Migrates a model checkpoint
+
+
After the migration, the given checkpoint will be migrated.
+Other versions of the checkpoint will be saved under the stem-version.suffix.
+
+
Arguments:
+
+
+ckpt_path (str | PathLike
): path to checkpoint
+torch_load_kwargs: additional args given to torch.load.
+
+
+
+
+
+
+
+
+
+ class
+ SaveMigrations (lightning.pytorch.callbacks.callback.Callback ):
+
+ View Source
+
+
+
+ 87 class SaveMigrations ( Callback ):
+88 def __init__ ( self ):
+89 self . migrations = get_folder_migrations ( MIGRATION_DIR )
+90
+91 def on_save_checkpoint (
+92 self , trainer : Trainer , pl_module : LightningModule , checkpoint : dict [ str , Any ]
+93 ):
+94 checkpoint [ ckpt_migration_key ] = [ mig . name for mig in self . migrations ]
+
+
+
+ Abstract base class used to build new callbacks.
+
+
Subclass this class and override any of the relevant hooks
+
+
+
+
+
+ migrations
+
+
+
+
+
+
+
+
+
+
+
+
+ def
+ on_save_checkpoint ( self , trainer : lightning . pytorch . trainer . trainer . Trainer , pl_module : lightning . pytorch . core . module . LightningModule , checkpoint : dict [ str , typing . Any ] ):
+
+ View Source
+
+
+
+
91 def on_save_checkpoint (
+92 self , trainer : Trainer , pl_module : LightningModule , checkpoint : dict [ str , Any ]
+93 ):
+94 checkpoint [ ckpt_migration_key ] = [ mig . name for mig in self . migrations ]
+
+
+
+
Called when saving a checkpoint to give you a chance to store anything else you might want to save.
+
+
Arguments:
+
+
+trainer: the current ~lightning.pytorch.trainer.trainer.Trainer
instance.
+pl_module: the current ~lightning.pytorch.core.LightningModule
instance.
+checkpoint: the checkpoint dictionary that will be saved.
+
+
+
+
+
+
+
Inherited Members
+
+
lightning.pytorch.callbacks.callback.Callback
+ state_key
+ setup
+ teardown
+ on_fit_start
+ on_fit_end
+ on_sanity_check_start
+ on_sanity_check_end
+ on_train_batch_start
+ on_train_batch_end
+ on_train_epoch_start
+ on_train_epoch_end
+ on_validation_epoch_start
+ on_validation_epoch_end
+ on_test_epoch_start
+ on_test_epoch_end
+ on_predict_epoch_start
+ on_predict_epoch_end
+ on_validation_batch_start
+ on_validation_batch_end
+ on_test_batch_start
+ on_test_batch_end
+ on_predict_batch_start
+ on_predict_batch_end
+ on_train_start
+ on_train_end
+ on_validation_start
+ on_validation_end
+ on_test_start
+ on_test_end
+ on_predict_start
+ on_predict_end
+ on_exception
+ state_dict
+ load_state_dict
+ on_load_checkpoint
+ on_before_backward
+ on_after_backward
+ on_before_optimizer_step
+ on_before_zero_grad
+
+
+
+
+
+
+
+
\ No newline at end of file