From e689fbb41d17af73c83d61bd9b190e7b27501c36 Mon Sep 17 00:00:00 2001 From: "github-actions[bot]" <41898282+github-actions[bot]@users.noreply.github.com> Date: Fri, 24 May 2024 11:30:50 +0200 Subject: [PATCH] DOCS: shimmer API docs for v0.5.1 (#86) Co-authored-by: github-actions --- docs/api/v0.5.1/index.html | 249 + docs/api/v0.5.1/search.js | 46 + .../v0.5.1/shimmer/cli/ckpt_migration.html | 319 ++ docs/api/v0.5.1/shimmer/dataset.html | 448 ++ .../shimmer/modules/contrastive_loss.html | 797 +++ docs/api/v0.5.1/shimmer/modules/domain.html | 1220 +++++ .../shimmer/modules/global_workspace.html | 4260 +++++++++++++++++ .../api/v0.5.1/shimmer/modules/gw_module.html | 2893 +++++++++++ docs/api/v0.5.1/shimmer/modules/losses.html | 3888 +++++++++++++++ .../api/v0.5.1/shimmer/modules/selection.html | 2151 +++++++++ docs/api/v0.5.1/shimmer/modules/utils.html | 760 +++ docs/api/v0.5.1/shimmer/modules/vae.html | 1288 +++++ docs/api/v0.5.1/shimmer/types.html | 872 ++++ docs/api/v0.5.1/shimmer/utils.html | 701 +++ 14 files changed, 19892 insertions(+) create mode 100644 docs/api/v0.5.1/index.html create mode 100644 docs/api/v0.5.1/search.js create mode 100644 docs/api/v0.5.1/shimmer/cli/ckpt_migration.html create mode 100644 docs/api/v0.5.1/shimmer/dataset.html create mode 100644 docs/api/v0.5.1/shimmer/modules/contrastive_loss.html create mode 100644 docs/api/v0.5.1/shimmer/modules/domain.html create mode 100644 docs/api/v0.5.1/shimmer/modules/global_workspace.html create mode 100644 docs/api/v0.5.1/shimmer/modules/gw_module.html create mode 100644 docs/api/v0.5.1/shimmer/modules/losses.html create mode 100644 docs/api/v0.5.1/shimmer/modules/selection.html create mode 100644 docs/api/v0.5.1/shimmer/modules/utils.html create mode 100644 docs/api/v0.5.1/shimmer/modules/vae.html create mode 100644 docs/api/v0.5.1/shimmer/types.html create mode 100644 docs/api/v0.5.1/shimmer/utils.html diff --git a/docs/api/v0.5.1/index.html b/docs/api/v0.5.1/index.html new file mode 100644 
index 00000000..b62372d7 --- /dev/null +++ b/docs/api/v0.5.1/index.html @@ -0,0 +1,249 @@ + + + + + + + Module List – pdoc 14.4.0 + + + + + + + + + + + + + + +
+ + pdoc + + +
+
+ + \ No newline at end of file diff --git a/docs/api/v0.5.1/search.js b/docs/api/v0.5.1/search.js new file mode 100644 index 00000000..6b97e191 --- /dev/null +++ b/docs/api/v0.5.1/search.js @@ -0,0 +1,46 @@ +window.pdocSearch = (function(){ +/** elasticlunr - http://weixsong.github.io * Copyright (C) 2017 Oliver Nightingale * Copyright (C) 2017 Wei Song * MIT Licensed */!function(){function e(e){if(null===e||"object"!=typeof e)return e;var t=e.constructor();for(var n in e)e.hasOwnProperty(n)&&(t[n]=e[n]);return t}var t=function(e){var n=new t.Index;return n.pipeline.add(t.trimmer,t.stopWordFilter,t.stemmer),e&&e.call(n,n),n};t.version="0.9.5",lunr=t,t.utils={},t.utils.warn=function(e){return function(t){e.console&&console.warn&&console.warn(t)}}(this),t.utils.toString=function(e){return void 0===e||null===e?"":e.toString()},t.EventEmitter=function(){this.events={}},t.EventEmitter.prototype.addListener=function(){var e=Array.prototype.slice.call(arguments),t=e.pop(),n=e;if("function"!=typeof t)throw new TypeError("last argument must be a function");n.forEach(function(e){this.hasHandler(e)||(this.events[e]=[]),this.events[e].push(t)},this)},t.EventEmitter.prototype.removeListener=function(e,t){if(this.hasHandler(e)){var n=this.events[e].indexOf(t);-1!==n&&(this.events[e].splice(n,1),0==this.events[e].length&&delete this.events[e])}},t.EventEmitter.prototype.emit=function(e){if(this.hasHandler(e)){var t=Array.prototype.slice.call(arguments,1);this.events[e].forEach(function(e){e.apply(void 0,t)},this)}},t.EventEmitter.prototype.hasHandler=function(e){return e in this.events},t.tokenizer=function(e){if(!arguments.length||null===e||void 0===e)return[];if(Array.isArray(e)){var n=e.filter(function(e){return null===e||void 0===e?!1:!0});n=n.map(function(e){return t.utils.toString(e).toLowerCase()});var i=[];return n.forEach(function(e){var n=e.split(t.tokenizer.seperator);i=i.concat(n)},this),i}return 
e.toString().trim().toLowerCase().split(t.tokenizer.seperator)},t.tokenizer.defaultSeperator=/[\s\-]+/,t.tokenizer.seperator=t.tokenizer.defaultSeperator,t.tokenizer.setSeperator=function(e){null!==e&&void 0!==e&&"object"==typeof e&&(t.tokenizer.seperator=e)},t.tokenizer.resetSeperator=function(){t.tokenizer.seperator=t.tokenizer.defaultSeperator},t.tokenizer.getSeperator=function(){return t.tokenizer.seperator},t.Pipeline=function(){this._queue=[]},t.Pipeline.registeredFunctions={},t.Pipeline.registerFunction=function(e,n){n in t.Pipeline.registeredFunctions&&t.utils.warn("Overwriting existing registered function: "+n),e.label=n,t.Pipeline.registeredFunctions[n]=e},t.Pipeline.getRegisteredFunction=function(e){return e in t.Pipeline.registeredFunctions!=!0?null:t.Pipeline.registeredFunctions[e]},t.Pipeline.warnIfFunctionNotRegistered=function(e){var n=e.label&&e.label in this.registeredFunctions;n||t.utils.warn("Function is not registered with pipeline. This may cause problems when serialising the index.\n",e)},t.Pipeline.load=function(e){var n=new t.Pipeline;return e.forEach(function(e){var i=t.Pipeline.getRegisteredFunction(e);if(!i)throw new Error("Cannot load un-registered function: "+e);n.add(i)}),n},t.Pipeline.prototype.add=function(){var e=Array.prototype.slice.call(arguments);e.forEach(function(e){t.Pipeline.warnIfFunctionNotRegistered(e),this._queue.push(e)},this)},t.Pipeline.prototype.after=function(e,n){t.Pipeline.warnIfFunctionNotRegistered(n);var i=this._queue.indexOf(e);if(-1===i)throw new Error("Cannot find existingFn");this._queue.splice(i+1,0,n)},t.Pipeline.prototype.before=function(e,n){t.Pipeline.warnIfFunctionNotRegistered(n);var i=this._queue.indexOf(e);if(-1===i)throw new Error("Cannot find existingFn");this._queue.splice(i,0,n)},t.Pipeline.prototype.remove=function(e){var t=this._queue.indexOf(e);-1!==t&&this._queue.splice(t,1)},t.Pipeline.prototype.run=function(e){for(var t=[],n=e.length,i=this._queue.length,o=0;n>o;o++){for(var 
r=e[o],s=0;i>s&&(r=this._queue[s](r,o,e),void 0!==r&&null!==r);s++);void 0!==r&&null!==r&&t.push(r)}return t},t.Pipeline.prototype.reset=function(){this._queue=[]},t.Pipeline.prototype.get=function(){return this._queue},t.Pipeline.prototype.toJSON=function(){return this._queue.map(function(e){return t.Pipeline.warnIfFunctionNotRegistered(e),e.label})},t.Index=function(){this._fields=[],this._ref="id",this.pipeline=new t.Pipeline,this.documentStore=new t.DocumentStore,this.index={},this.eventEmitter=new t.EventEmitter,this._idfCache={},this.on("add","remove","update",function(){this._idfCache={}}.bind(this))},t.Index.prototype.on=function(){var e=Array.prototype.slice.call(arguments);return this.eventEmitter.addListener.apply(this.eventEmitter,e)},t.Index.prototype.off=function(e,t){return this.eventEmitter.removeListener(e,t)},t.Index.load=function(e){e.version!==t.version&&t.utils.warn("version mismatch: current "+t.version+" importing "+e.version);var n=new this;n._fields=e.fields,n._ref=e.ref,n.documentStore=t.DocumentStore.load(e.documentStore),n.pipeline=t.Pipeline.load(e.pipeline),n.index={};for(var i in e.index)n.index[i]=t.InvertedIndex.load(e.index[i]);return n},t.Index.prototype.addField=function(e){return this._fields.push(e),this.index[e]=new t.InvertedIndex,this},t.Index.prototype.setRef=function(e){return this._ref=e,this},t.Index.prototype.saveDocument=function(e){return this.documentStore=new t.DocumentStore(e),this},t.Index.prototype.addDoc=function(e,n){if(e){var n=void 0===n?!0:n,i=e[this._ref];this.documentStore.addDoc(i,e),this._fields.forEach(function(n){var o=this.pipeline.run(t.tokenizer(e[n]));this.documentStore.addFieldLength(i,n,o.length);var r={};o.forEach(function(e){e in r?r[e]+=1:r[e]=1},this);for(var s in r){var 
u=r[s];u=Math.sqrt(u),this.index[n].addToken(s,{ref:i,tf:u})}},this),n&&this.eventEmitter.emit("add",e,this)}},t.Index.prototype.removeDocByRef=function(e){if(e&&this.documentStore.isDocStored()!==!1&&this.documentStore.hasDoc(e)){var t=this.documentStore.getDoc(e);this.removeDoc(t,!1)}},t.Index.prototype.removeDoc=function(e,n){if(e){var n=void 0===n?!0:n,i=e[this._ref];this.documentStore.hasDoc(i)&&(this.documentStore.removeDoc(i),this._fields.forEach(function(n){var o=this.pipeline.run(t.tokenizer(e[n]));o.forEach(function(e){this.index[n].removeToken(e,i)},this)},this),n&&this.eventEmitter.emit("remove",e,this))}},t.Index.prototype.updateDoc=function(e,t){var t=void 0===t?!0:t;this.removeDocByRef(e[this._ref],!1),this.addDoc(e,!1),t&&this.eventEmitter.emit("update",e,this)},t.Index.prototype.idf=function(e,t){var n="@"+t+"/"+e;if(Object.prototype.hasOwnProperty.call(this._idfCache,n))return this._idfCache[n];var i=this.index[t].getDocFreq(e),o=1+Math.log(this.documentStore.length/(i+1));return this._idfCache[n]=o,o},t.Index.prototype.getFields=function(){return this._fields.slice()},t.Index.prototype.search=function(e,n){if(!e)return[];e="string"==typeof e?{any:e}:JSON.parse(JSON.stringify(e));var i=null;null!=n&&(i=JSON.stringify(n));for(var o=new t.Configuration(i,this.getFields()).get(),r={},s=Object.keys(e),u=0;u0&&t.push(e);for(var i in n)"docs"!==i&&"df"!==i&&this.expandToken(e+i,t,n[i]);return t},t.InvertedIndex.prototype.toJSON=function(){return{root:this.root}},t.Configuration=function(e,n){var e=e||"";if(void 0==n||null==n)throw new Error("fields should not be null");this.config={};var i;try{i=JSON.parse(e),this.buildUserConfig(i,n)}catch(o){t.utils.warn("user configuration parse failed, will use default configuration"),this.buildDefaultConfig(n)}},t.Configuration.prototype.buildDefaultConfig=function(e){this.reset(),e.forEach(function(e){this.config[e]={boost:1,bool:"OR",expand:!1}},this)},t.Configuration.prototype.buildUserConfig=function(e,n){var 
i="OR",o=!1;if(this.reset(),"bool"in e&&(i=e.bool||i),"expand"in e&&(o=e.expand||o),"fields"in e)for(var r in e.fields)if(n.indexOf(r)>-1){var s=e.fields[r],u=o;void 0!=s.expand&&(u=s.expand),this.config[r]={boost:s.boost||0===s.boost?s.boost:1,bool:s.bool||i,expand:u}}else t.utils.warn("field name in user configuration not found in index instance fields");else this.addAllFields2UserConfig(i,o,n)},t.Configuration.prototype.addAllFields2UserConfig=function(e,t,n){n.forEach(function(n){this.config[n]={boost:1,bool:e,expand:t}},this)},t.Configuration.prototype.get=function(){return this.config},t.Configuration.prototype.reset=function(){this.config={}},lunr.SortedSet=function(){this.length=0,this.elements=[]},lunr.SortedSet.load=function(e){var t=new this;return t.elements=e,t.length=e.length,t},lunr.SortedSet.prototype.add=function(){var e,t;for(e=0;e1;){if(r===e)return o;e>r&&(t=o),r>e&&(n=o),i=n-t,o=t+Math.floor(i/2),r=this.elements[o]}return r===e?o:-1},lunr.SortedSet.prototype.locationFor=function(e){for(var t=0,n=this.elements.length,i=n-t,o=t+Math.floor(i/2),r=this.elements[o];i>1;)e>r&&(t=o),r>e&&(n=o),i=n-t,o=t+Math.floor(i/2),r=this.elements[o];return r>e?o:e>r?o+1:void 0},lunr.SortedSet.prototype.intersect=function(e){for(var t=new lunr.SortedSet,n=0,i=0,o=this.length,r=e.length,s=this.elements,u=e.elements;;){if(n>o-1||i>r-1)break;s[n]!==u[i]?s[n]u[i]&&i++:(t.add(s[n]),n++,i++)}return t},lunr.SortedSet.prototype.clone=function(){var e=new lunr.SortedSet;return e.elements=this.toArray(),e.length=e.elements.length,e},lunr.SortedSet.prototype.union=function(e){var t,n,i;this.length>=e.length?(t=this,n=e):(t=e,n=this),i=t.clone();for(var o=0,r=n.toArray();o

\n"}, "shimmer.types.RawDomainGroupT": {"fullname": "shimmer.types.RawDomainGroupT", "modulename": "shimmer.types", "qualname": "RawDomainGroupT", "kind": "variable", "doc": "

Matched raw unimodal data from multiple domains.\nKeys of the mapping are domains names and values are the domain data.

\n\n

All values in the mapping should be matched and represent the same information.

\n\n
Example:
\n\n
\n
\n
def fun(domain_group: RawDomainGroupT): ...\n\n\nx = {\n    "vision": PIL.Image.Image("path/to/dog/picture.png"),\n    "language": "This is a picture of a dog.",\n}\n\nfun(x)\n
\n
\n
\n\n
Note:
\n\n
\n

This type uses collections.abc.Mapping and is used for functions' inputs.\n Use RawDomainGroupDT for functions' outputs.

\n \n

This allows to be more generic and allow passing other mappings.

\n
\n", "default_value": "collections.abc.Mapping[str, typing.Any]"}, "shimmer.types.RawDomainGroupDT": {"fullname": "shimmer.types.RawDomainGroupDT", "modulename": "shimmer.types", "qualname": "RawDomainGroupDT", "kind": "variable", "doc": "

Output type version of RawDomainGroupT.\nMatched raw unimodal data from multiple domains.\nKeys of the mapping are domains names and values are the domain data.

\n\n
Example:
\n\n
\n
\n
def fun() -> RawDomainGroupDT:\n    return {\n        "vision": PIL.Image.Image("path/to/dog/picture.png"),\n        "language": "This is a picture of a dog.",\n    }\n
\n
\n
\n\n
Note:
\n\n
\n

This type uses dicts and is used for functions' outputs.\n Use RawDomainGroupT for functions' inputs.

\n
\n", "default_value": "dict[str, typing.Any]"}, "shimmer.types.LatentsDomainGroupT": {"fullname": "shimmer.types.LatentsDomainGroupT", "modulename": "shimmer.types", "qualname": "LatentsDomainGroupT", "kind": "variable", "doc": "

Matched unimodal latent representations from multiple domains.\nKeys of the mapping are domains names and values are torch.Tensor latent\nrepresentation of the domain.

\n\n
Example:
\n\n
\n
\n
def fun(domain_group: LatentsDomainGroupT): ...\n\n\nx = {\n    "vision": torch.Tensor([0.0, 1.0, 0.0, ...]),\n    "language": torch.Tensor([0.0, 0.3, 0.2, ...]),\n}\n\nfun(x)\n
\n
\n
\n\n
Note:
\n\n
\n

This type uses collections.abc.Mapping and is used for functions' inputs.\n Use LatentsDomainGroupDT for functions' outputs.

\n \n

This allows to be more generic and allow passing other mappings.

\n
\n", "default_value": "collections.abc.Mapping[str, torch.Tensor]"}, "shimmer.types.LatentsDomainGroupDT": {"fullname": "shimmer.types.LatentsDomainGroupDT", "modulename": "shimmer.types", "qualname": "LatentsDomainGroupDT", "kind": "variable", "doc": "

Matched unimodal latent representations from multiple domains.\nKeys of the dict are domains names and values are torch.Tensor latent\nrepresentation of the domain.

\n\n
Example:
\n\n
\n
\n
def fun() -> LatentsDomainGroupDT:\n    return {\n        "vision": torch.Tensor([0.0, 1.0, 0.0, ...]),\n        "language": torch.Tensor([0.0, 0.3, 0.2, ...]),\n    }\n
\n
\n
\n\n
Note:
\n\n
\n

This type uses dicts and is used for functions' outputs.\n Use LatentsDomainGroupT for functions' inputs.

\n
\n", "default_value": "dict[str, torch.Tensor]"}, "shimmer.types.RawDomainGroupsT": {"fullname": "shimmer.types.RawDomainGroupsT", "modulename": "shimmer.types", "qualname": "RawDomainGroupsT", "kind": "variable", "doc": "

Mapping of RawDomainGroupT. Keys are frozenset of domains matched in the group.\nEach group is independent and contains different data (unpaired).

\n\n
Example:
\n\n
\n
\n
def fun() -> RawDomainGroupsDT:\n    return {\n        frozenset(["vision"]): {\n            "vision": PIL.Image.Image("path/to/cat/picture.png"),\n        },\n        frozenset(["language"]): {\n            "language": "This is a picture of a rabbit.",\n        },\n        frozenset(["vision", "language"]): {\n            "vision": PIL.Image.Image("path/to/dog/picture.png"),\n            "language": "This is a picture of a dog.",\n        },\n    }\n
\n
\n
\n\n
Note:
\n\n
\n

This type uses dicts and is used for functions' outputs.\n Use RawDomainGroupsT for functions' inputs.

\n
\n", "default_value": "collections.abc.Mapping[frozenset[str], collections.abc.Mapping[str, typing.Any]]"}, "shimmer.types.RawDomainGroupsDT": {"fullname": "shimmer.types.RawDomainGroupsDT", "modulename": "shimmer.types", "qualname": "RawDomainGroupsDT", "kind": "variable", "doc": "

Mapping of RawDomainGroupT. Keys are frozenset of domains matched in the group.\nEach group is independent and contains different data (unpaired).

\n\n
Example:
\n\n
\n
\n
def fun() -> RawDomainGroupsDT:\n    return {\n        frozenset(["vision"]): {\n            "vision": PIL.Image.Image("path/to/cat/picture.png"),\n        },\n        frozenset(["language"]): {\n            "language": "This is a picture of a rabbit.",\n        },\n        frozenset(["vision", "language"]): {\n            "vision": PIL.Image.Image("path/to/dog/picture.png"),\n            "language": "This is a picture of a dog.",\n        },\n    }\n
\n
\n
\n\n
Note:
\n\n
\n

This type uses dicts and is used for functions' outputs.\n Use RawDomainGroupsT for functions' inputs.

\n
\n", "default_value": "dict[frozenset[str], dict[str, typing.Any]]"}, "shimmer.types.LatentsDomainGroupsT": {"fullname": "shimmer.types.LatentsDomainGroupsT", "modulename": "shimmer.types", "qualname": "LatentsDomainGroupsT", "kind": "variable", "doc": "

Mapping of LatentsDomainGroupT. Keys are frozenset of domains matched in the group.\nEach group is independent and contains different data (unpaired).

\n\n
Example:
\n\n
\n
\n
def fun(domain_group: LatentsDomainGroupsT): ...\n\n\nx = {\n    frozenset(["vision"]): {\n        "vision": torch.Tensor([1.0, 0.0, 0.3, ...]),\n    },\n    frozenset(["language"]): {\n        "language": torch.Tensor([1.0, 0.2, 0.9, ...]),\n    },\n    frozenset(["vision", "language"]): {\n        "vision": torch.Tensor([0.0, 1.0, 0.0, ...]),\n        "language": torch.Tensor([0.0, 0.3, 0.2, ...]),\n    },\n}\n\nfun(x)\n
\n
\n
\n\n
Note:
\n\n
\n

This type uses collections.abc.Mapping and is used for functions' inputs.\n Use LatentsDomainGroupsDT for functions' outputs.

\n \n

This allows to be more generic and allow passing other mappings.

\n
\n", "default_value": "collections.abc.Mapping[frozenset[str], collections.abc.Mapping[str, torch.Tensor]]"}, "shimmer.types.LatentsDomainGroupsDT": {"fullname": "shimmer.types.LatentsDomainGroupsDT", "modulename": "shimmer.types", "qualname": "LatentsDomainGroupsDT", "kind": "variable", "doc": "

Mapping of LatentsDomainGroupDT.\nKeys are frozenset of domains matched in the group.\nEach group is independent and contains different data (unpaired).

\n\n
Example:
\n\n
\n
\n
def fun() -> LatentsDomainGroupsDT:\n    return {\n        frozenset(["vision"]): {\n            "vision": torch.Tensor([1.0, 0.0, 0.3, ...]),\n        },\n        frozenset(["language"]): {\n            "language": torch.Tensor([1.0, 0.2, 0.9, ...]),\n        },\n        frozenset(["vision", "language"]): {\n            "vision": torch.Tensor([0.0, 1.0, 0.0, ...]),\n            "language": torch.Tensor([0.0, 0.3, 0.2, ...]),\n        },\n    }\n
\n
\n
\n\n
Note:
\n\n
\n

This type uses dicts and is used for functions' outputs.\n Use LatentsDomainGroupT for functions' inputs.

\n
\n", "default_value": "dict[frozenset[str], dict[str, torch.Tensor]]"}, "shimmer.types.ModelModeT": {"fullname": "shimmer.types.ModelModeT", "modulename": "shimmer.types", "qualname": "ModelModeT", "kind": "variable", "doc": "

Mode used by pytorch lightning (train/val, ...).

\n\n

When validating or testing in out-of-distribution data, \"val/ood\" or \"test/ood\" mode is\nused.

\n", "default_value": "typing.Literal['train', 'val', 'test', 'val/ood', 'test/ood']"}, "shimmer.modules.global_workspace": {"fullname": "shimmer.modules.global_workspace", "modulename": "shimmer.modules.global_workspace", "kind": "module", "doc": "

\n"}, "shimmer.modules.global_workspace.SchedulerArgs": {"fullname": "shimmer.modules.global_workspace.SchedulerArgs", "modulename": "shimmer.modules.global_workspace", "qualname": "SchedulerArgs", "kind": "class", "doc": "

TypedDict of arguments passed to the OneCycle scheduler

\n", "bases": "typing.TypedDict"}, "shimmer.modules.global_workspace.SchedulerArgs.max_lr": {"fullname": "shimmer.modules.global_workspace.SchedulerArgs.max_lr", "modulename": "shimmer.modules.global_workspace", "qualname": "SchedulerArgs.max_lr", "kind": "variable", "doc": "

Maximum learning rate

\n", "annotation": ": float"}, "shimmer.modules.global_workspace.SchedulerArgs.total_steps": {"fullname": "shimmer.modules.global_workspace.SchedulerArgs.total_steps", "modulename": "shimmer.modules.global_workspace", "qualname": "SchedulerArgs.total_steps", "kind": "variable", "doc": "

Total number of steps

\n", "annotation": ": int"}, "shimmer.modules.global_workspace.GWPredictionsBase": {"fullname": "shimmer.modules.global_workspace.GWPredictionsBase", "modulename": "shimmer.modules.global_workspace", "qualname": "GWPredictionsBase", "kind": "class", "doc": "

TypedDict of the output given when calling GlobalWorkspaceBase.predict

\n", "bases": "typing.TypedDict"}, "shimmer.modules.global_workspace.GWPredictionsBase.states": {"fullname": "shimmer.modules.global_workspace.GWPredictionsBase.states", "modulename": "shimmer.modules.global_workspace", "qualname": "GWPredictionsBase.states", "kind": "variable", "doc": "

GW state representation from domain groups with only one domain.\nThe key represent the domain's name.

\n", "annotation": ": dict[str, torch.Tensor]"}, "shimmer.modules.global_workspace.GlobalWorkspaceBase": {"fullname": "shimmer.modules.global_workspace.GlobalWorkspaceBase", "modulename": "shimmer.modules.global_workspace", "qualname": "GlobalWorkspaceBase", "kind": "class", "doc": "

Global Workspace Lightning Module.

\n\n

This is the base class to build the Global Workspace.

\n", "bases": "typing.Generic[~_T_gw_mod, ~_T_selection_mod, ~_T_loss_mod], lightning.pytorch.core.module.LightningModule"}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.gw_mod": {"fullname": "shimmer.modules.global_workspace.GlobalWorkspaceBase.gw_mod", "modulename": "shimmer.modules.global_workspace", "qualname": "GlobalWorkspaceBase.gw_mod", "kind": "variable", "doc": "

a GWModuleBase implementation.

\n"}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.selection_mod": {"fullname": "shimmer.modules.global_workspace.GlobalWorkspaceBase.selection_mod", "modulename": "shimmer.modules.global_workspace", "qualname": "GlobalWorkspaceBase.selection_mod", "kind": "variable", "doc": "

A SelectionBase implementation.

\n"}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.loss_mod": {"fullname": "shimmer.modules.global_workspace.GlobalWorkspaceBase.loss_mod", "modulename": "shimmer.modules.global_workspace", "qualname": "GlobalWorkspaceBase.loss_mod", "kind": "variable", "doc": "

The module that computes losses of the GW

\n"}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.optim_lr": {"fullname": "shimmer.modules.global_workspace.GlobalWorkspaceBase.optim_lr", "modulename": "shimmer.modules.global_workspace", "qualname": "GlobalWorkspaceBase.optim_lr", "kind": "variable", "doc": "

\n"}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.optim_weight_decay": {"fullname": "shimmer.modules.global_workspace.GlobalWorkspaceBase.optim_weight_decay", "modulename": "shimmer.modules.global_workspace", "qualname": "GlobalWorkspaceBase.optim_weight_decay", "kind": "variable", "doc": "

\n"}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.scheduler_args": {"fullname": "shimmer.modules.global_workspace.GlobalWorkspaceBase.scheduler_args", "modulename": "shimmer.modules.global_workspace", "qualname": "GlobalWorkspaceBase.scheduler_args", "kind": "variable", "doc": "

\n"}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.domain_mods": {"fullname": "shimmer.modules.global_workspace.GlobalWorkspaceBase.domain_mods", "modulename": "shimmer.modules.global_workspace", "qualname": "GlobalWorkspaceBase.domain_mods", "kind": "variable", "doc": "

\n", "annotation": ": collections.abc.Mapping[str, shimmer.modules.domain.DomainModule]"}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.workspace_dim": {"fullname": "shimmer.modules.global_workspace.GlobalWorkspaceBase.workspace_dim", "modulename": "shimmer.modules.global_workspace", "qualname": "GlobalWorkspaceBase.workspace_dim", "kind": "variable", "doc": "

Dimension of the GW.

\n", "annotation": ": int"}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.encode_and_fuse": {"fullname": "shimmer.modules.global_workspace.GlobalWorkspaceBase.encode_and_fuse", "modulename": "shimmer.modules.global_workspace", "qualname": "GlobalWorkspaceBase.encode_and_fuse", "kind": "function", "doc": "

Encode a group of latent representations into the GW representation.

\n\n
Arguments:
\n\n
    \n
  • x (LatentsDomainGroupsT): the input domain representations.
  • \n
  • selection_scores (Mapping[str, torch.Tensor]):
  • \n
\n\n
Returns:
\n\n
\n

dict[frozenset[str], torch.Tensor]: the GW representations.

\n
\n", "signature": "(\tself,\tx: collections.abc.Mapping[frozenset[str], collections.abc.Mapping[str, torch.Tensor]],\tselection_module: shimmer.modules.selection.SelectionBase) -> dict[frozenset[str], torch.Tensor]:", "funcdef": "def"}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.encode": {"fullname": "shimmer.modules.global_workspace.GlobalWorkspaceBase.encode", "modulename": "shimmer.modules.global_workspace", "qualname": "GlobalWorkspaceBase.encode", "kind": "function", "doc": "

Encode a group of latent representations into the pre-fusion GW representation.

\n\n
Arguments:
\n\n
    \n
  • x (LatentsDomainGroupsT): the input domain representations.
  • \n
\n\n
Returns:
\n\n
\n

LatensDomainGroupsDT: the GW representations.

\n
\n", "signature": "(\tself,\tx: collections.abc.Mapping[frozenset[str], collections.abc.Mapping[str, torch.Tensor]]) -> dict[frozenset[str], dict[str, torch.Tensor]]:", "funcdef": "def"}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.fuse": {"fullname": "shimmer.modules.global_workspace.GlobalWorkspaceBase.fuse", "modulename": "shimmer.modules.global_workspace", "qualname": "GlobalWorkspaceBase.fuse", "kind": "function", "doc": "

Fuses a group of latent representations into the GW representation.

\n\n
Arguments:
\n\n
    \n
  • x (LatentsDomainGroupsT): the pre-fusion latent representations
  • \n
  • selection_scores (Mapping[frozenset[str], Mapping[str, torch.Tensor]]): selection scores for each group
  • \n
\n\n
Returns:
\n\n
\n

dict[frozenset[str], torch.Tensor]: GW representation of each group

\n
\n", "signature": "(\tself,\tx: collections.abc.Mapping[frozenset[str], collections.abc.Mapping[str, torch.Tensor]],\tselection_scores: collections.abc.Mapping[frozenset[str], collections.abc.Mapping[str, torch.Tensor]]) -> dict[frozenset[str], torch.Tensor]:", "funcdef": "def"}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.decode": {"fullname": "shimmer.modules.global_workspace.GlobalWorkspaceBase.decode", "modulename": "shimmer.modules.global_workspace", "qualname": "GlobalWorkspaceBase.decode", "kind": "function", "doc": "

Decode the group GW representation into given domains.

\n\n
Arguments:
\n\n
    \n
  • z (torch.Tensor): the GW representation.
  • \n
  • domains (Iterable[str]): iterable of domains to decode.
  • \n
\n\n
Returns:
\n\n
\n

dict[str, torch.Tensor]: the decoded unimodal representations.

\n
\n", "signature": "(\tself,\tz: collections.abc.Mapping[frozenset[str], torch.Tensor],\tdomains: collections.abc.Iterable[str] | None = None) -> dict[frozenset[str], dict[str, torch.Tensor]]:", "funcdef": "def"}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.batch_gw_states": {"fullname": "shimmer.modules.global_workspace.GlobalWorkspaceBase.batch_gw_states", "modulename": "shimmer.modules.global_workspace", "qualname": "GlobalWorkspaceBase.batch_gw_states", "kind": "function", "doc": "

Comptues GW states of a batch of groups of domains.

\n\n
Arguments:
\n\n
    \n
  • latent_domains (LatentsT): the batch of groups of domains
  • \n
\n\n
Returns:
\n\n
\n

dict[str, torch.Tensor]: states for each domain.

\n
\n", "signature": "(\tself,\tlatent_domains: collections.abc.Mapping[frozenset[str], collections.abc.Mapping[str, torch.Tensor]]) -> dict[str, torch.Tensor]:", "funcdef": "def"}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.encode_domain": {"fullname": "shimmer.modules.global_workspace.GlobalWorkspaceBase.encode_domain", "modulename": "shimmer.modules.global_workspace", "qualname": "GlobalWorkspaceBase.encode_domain", "kind": "function", "doc": "

Encodes a domain from the domain data into the unimodal representation.

\n\n

This is a convenient proxy for the DomainModule.encode method and is\nequivalent to:

\n\n
\n
self.domain_mods[name].encode(domain)\n
\n
\n\n
Arguments:
\n\n
    \n
  • domain (Any): the domain data
  • \n
  • name (str): domain name to encode
  • \n
\n\n
Returns:
\n\n
\n

torch.Tensor: the domain's unimodal representation.

\n
\n", "signature": "(self, domain: Any, name: str) -> torch.Tensor:", "funcdef": "def"}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.encode_domains": {"fullname": "shimmer.modules.global_workspace.GlobalWorkspaceBase.encode_domains", "modulename": "shimmer.modules.global_workspace", "qualname": "GlobalWorkspaceBase.encode_domains", "kind": "function", "doc": "

Encode all domains in the batch.

\n\n
Arguments:
\n\n
    \n
  • batch (RawDomainGroupsT): the batch of\ndomain groups with raw unimodal data to encode into groups of latent\nrepresentations.
  • \n
\n\n
Returns:
\n\n
\n

LatentsDomainGroupsDT: the domains' unimodal representations.

\n
\n", "signature": "(\tself,\tbatch: collections.abc.Mapping[frozenset[str], collections.abc.Mapping[str, typing.Any]]) -> dict[frozenset[str], dict[str, torch.Tensor]]:", "funcdef": "def"}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.decode_domain": {"fullname": "shimmer.modules.global_workspace.GlobalWorkspaceBase.decode_domain", "modulename": "shimmer.modules.global_workspace", "qualname": "GlobalWorkspaceBase.decode_domain", "kind": "function", "doc": "

Decodes a domain from the unimodal representation into the domain data.

\n\n

This is a convenient proxy for the DomainModule.encode method and is\nequivalent to:

\n\n
\n
self.domain_mods[name].decode(domain)\n
\n
\n\n
Arguments:
\n\n
    \n
  • domain (torch.Tensor): the domain data
  • \n
  • name (str): domain name to encode
  • \n
\n\n
Returns:
\n\n
\n

Any: the domain's raw data.

\n
\n", "signature": "(self, domain: torch.Tensor, name: str) -> Any:", "funcdef": "def"}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.decode_domains": {"fullname": "shimmer.modules.global_workspace.GlobalWorkspaceBase.decode_domains", "modulename": "shimmer.modules.global_workspace", "qualname": "GlobalWorkspaceBase.decode_domains", "kind": "function", "doc": "

Decodes all domains in the batch.

\n\n
Arguments:
\n\n
    \n
  • batch (LatentsDomainGroupsT): the batch of\ndomain groups with unimodal latent representation to decode into\ngroups of raw data.
  • \n
\n\n
Returns:
\n\n
\n

LatentsDomainGroupsDT: the domains' raw data.

\n
\n", "signature": "(\tself,\tlatents_domain: collections.abc.Mapping[frozenset[str], collections.abc.Mapping[str, torch.Tensor]]) -> dict[frozenset[str], dict[str, typing.Any]]:", "funcdef": "def"}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.generic_step": {"fullname": "shimmer.modules.global_workspace.GlobalWorkspaceBase.generic_step", "modulename": "shimmer.modules.global_workspace", "qualname": "GlobalWorkspaceBase.generic_step", "kind": "function", "doc": "

The generic step used in training_step, validation_step and\ntest_step.

\n\n
Arguments:
\n\n
    \n
  • batch (RawDomainGroupsT): the batch of groups of raw unimodal data.
  • \n
  • mode (ModelModeT):
  • \n
\n\n
Returns:
\n\n
\n

torch.Tensor: the loss to train on.

\n
\n", "signature": "(\tself,\tbatch: collections.abc.Mapping[frozenset[str], collections.abc.Mapping[str, typing.Any]],\tmode: Literal['train', 'val', 'test', 'val/ood', 'test/ood']) -> torch.Tensor:", "funcdef": "def"}, "shimmer.modules.global_workspace.freeze_domain_modules": {"fullname": "shimmer.modules.global_workspace.freeze_domain_modules", "modulename": "shimmer.modules.global_workspace", "qualname": "freeze_domain_modules", "kind": "function", "doc": "

Freezes weights and set to eval mode the domain modules.

\n\n
\n\n

The output is casted as dict[str, DomainModule] type for better\nauto-completion, but is actually a torch ModuleDict.

\n\n
\n\n
Arguments:
\n\n
    \n
  • domain_mods (Mapping[str, DomainModule]): mapping of domain modules to freeze
  • \n
\n\n
Returns:
\n\n
\n

ModuleDict: frozen modules.

\n
\n", "signature": "(\tdomain_mods: collections.abc.Mapping[str, shimmer.modules.domain.DomainModule]) -> dict[str, shimmer.modules.domain.DomainModule]:", "funcdef": "def"}, "shimmer.modules.global_workspace.GWPredictions": {"fullname": "shimmer.modules.global_workspace.GWPredictions", "modulename": "shimmer.modules.global_workspace", "qualname": "GWPredictions", "kind": "class", "doc": "

TypedDict of the output given when calling GlobalWorkspaceBase.predict

\n", "bases": "builtins.dict"}, "shimmer.modules.global_workspace.GWPredictions.demi_cycles": {"fullname": "shimmer.modules.global_workspace.GWPredictions.demi_cycles", "modulename": "shimmer.modules.global_workspace", "qualname": "GWPredictions.demi_cycles", "kind": "variable", "doc": "

Demi-cycle predictions of the model for each domain. Only computed on domain\ngroups with only one domain.

\n", "annotation": ": dict[str, torch.Tensor]"}, "shimmer.modules.global_workspace.GWPredictions.cycles": {"fullname": "shimmer.modules.global_workspace.GWPredictions.cycles", "modulename": "shimmer.modules.global_workspace", "qualname": "GWPredictions.cycles", "kind": "variable", "doc": "

Cycle predictions of the model from one domain through another one.\nOnly computed on domain groups with more than one domain.\nThe keys are tuple with start domain and intermediary domain.

\n", "annotation": ": dict[tuple[str, str], torch.Tensor]"}, "shimmer.modules.global_workspace.GWPredictions.translations": {"fullname": "shimmer.modules.global_workspace.GWPredictions.translations", "modulename": "shimmer.modules.global_workspace", "qualname": "GWPredictions.translations", "kind": "variable", "doc": "

Translation predictions of the model from one domain through another one.

\n\n

Only computed on domain groups with more than one domain.\nThe keys are tuples with start domain and target domain.

\n", "annotation": ": dict[tuple[str, str], torch.Tensor]"}, "shimmer.modules.global_workspace.GWPredictions.states": {"fullname": "shimmer.modules.global_workspace.GWPredictions.states", "modulename": "shimmer.modules.global_workspace", "qualname": "GWPredictions.states", "kind": "variable", "doc": "

\n", "annotation": ": dict[str, torch.Tensor]"}, "shimmer.modules.global_workspace.GlobalWorkspace2Domains": {"fullname": "shimmer.modules.global_workspace.GlobalWorkspace2Domains", "modulename": "shimmer.modules.global_workspace", "qualname": "GlobalWorkspace2Domains", "kind": "class", "doc": "

A simple 2-domains max flavor of GlobalWorkspaceBase.

\n\n

This is used to simplify a Global Workspace instanciation and only overrides the\n__init__ method.

\n", "bases": "shimmer.modules.global_workspace.GlobalWorkspaceBase[shimmer.modules.gw_module.GWModule, shimmer.modules.selection.SingleDomainSelection, shimmer.modules.losses.GWLosses2Domains]"}, "shimmer.modules.global_workspace.GlobalWorkspace2Domains.__init__": {"fullname": "shimmer.modules.global_workspace.GlobalWorkspace2Domains.__init__", "modulename": "shimmer.modules.global_workspace", "qualname": "GlobalWorkspace2Domains.__init__", "kind": "function", "doc": "

Initializes a Global Workspace

\n\n
Arguments:
\n\n
    \n
  • domain_mods (Mapping[str, DomainModule]): mapping of the domains\nconnected to the GW. Keys are domain names, values are the\nDomainModule.
  • \n
  • gw_encoders (Mapping[str, torch.nn.Module]): mapping for each domain\nname to a torch.nn.Module class which role is to encode a\nunimodal latent representations into a GW representation (pre fusion).
  • \n
  • gw_decoders (Mapping[str, torch.nn.Module]): mapping for each domain\nname to a torch.nn.Module class which role is to decode a\nGW representation into a unimodal latent representations.
  • \n
  • workspace_dim (int): dimension of the GW.
  • \n
  • loss_coefs (LossCoefs): loss coefficients
  • \n
  • optim_lr (float): learning rate
  • \n
  • optim_weight_decay (float): weight decay
  • \n
  • scheduler_args (SchedulerArgs | None): optimization scheduler's arguments
  • \n
  • learn_logit_scale (bool): whether to learn the contrastive learning\ncontrastive loss when using the default contrastive loss.
  • \n
  • contrastive_loss (ContrastiveLossType | None): a contrastive loss\nfunction used for alignment. learn_logit_scale will not affect custom\ncontrastive losses.
  • \n
\n", "signature": "(\tdomain_mods: collections.abc.Mapping[str, shimmer.modules.domain.DomainModule],\tgw_encoders: collections.abc.Mapping[str, torch.nn.modules.module.Module],\tgw_decoders: collections.abc.Mapping[str, torch.nn.modules.module.Module],\tworkspace_dim: int,\tloss_coefs: shimmer.modules.losses.LossCoefs,\toptim_lr: float = 0.001,\toptim_weight_decay: float = 0.0,\tscheduler_args: shimmer.modules.global_workspace.SchedulerArgs | None = None,\tlearn_logit_scale: bool = False,\tcontrastive_loss: collections.abc.Callable[[torch.Tensor, torch.Tensor], shimmer.modules.domain.LossOutput] | None = None)"}, "shimmer.modules.global_workspace.GlobalWorkspace2Domains.forward": {"fullname": "shimmer.modules.global_workspace.GlobalWorkspace2Domains.forward", "modulename": "shimmer.modules.global_workspace", "qualname": "GlobalWorkspace2Domains.forward", "kind": "function", "doc": "

Computes demi-cycles, cycles, and translations.

\n\n
Arguments:
\n\n
    \n
  • latent_domains (LatentsT): Groups of domains for the computation.
  • \n
\n\n
Returns:
\n\n
\n

GWPredictions: the predictions on the batch.

\n
\n", "signature": "(\tself,\tlatent_domains: collections.abc.Mapping[frozenset[str], collections.abc.Mapping[str, torch.Tensor]]) -> shimmer.modules.global_workspace.GWPredictions:", "funcdef": "def"}, "shimmer.modules.global_workspace.GlobalWorkspace": {"fullname": "shimmer.modules.global_workspace.GlobalWorkspace", "modulename": "shimmer.modules.global_workspace", "qualname": "GlobalWorkspace", "kind": "class", "doc": "

The 2-domain fusion (with broadcast loss) flavor of GlobalWorkspaceBase.

\n\n

This is used to simplify a Global Workspace instanciation and only overrides the\n__init__ method.

\n", "bases": "shimmer.modules.global_workspace.GlobalWorkspaceBase[shimmer.modules.gw_module.GWModule, shimmer.modules.selection.RandomSelection, shimmer.modules.losses.GWLosses]"}, "shimmer.modules.global_workspace.GlobalWorkspace.__init__": {"fullname": "shimmer.modules.global_workspace.GlobalWorkspace.__init__", "modulename": "shimmer.modules.global_workspace", "qualname": "GlobalWorkspace.__init__", "kind": "function", "doc": "

Initializes a Global Workspace

\n\n
Arguments:
\n\n
    \n
  • domain_mods (Mapping[str, DomainModule]): mapping of the domains\nconnected to the GW. Keys are domain names, values are the\nDomainModule.
  • \n
  • gw_encoders (Mapping[str, torch.nn.Module]): mapping for each domain\nname to a torch.nn.Module class which role is to encode a\nunimodal latent representations into a GW representation (pre fusion).
  • \n
  • gw_decoders (Mapping[str, torch.nn.Module]): mapping for each domain\nname to a torch.nn.Module class which role is to decode a\nGW representation into a unimodal latent representations.
  • \n
  • workspace_dim (int): dimension of the GW.
  • \n
  • loss_coefs (BroadcastLossCoefs): loss coefs for the losses.
  • \n
  • selection_temperature (float): temperature value for the RandomSelection\nmodule.
  • \n
  • optim_lr (float): learning rate
  • \n
  • optim_weight_decay (float): weight decay
  • \n
  • scheduler_args (SchedulerArgs | None): optimization scheduler's arguments
  • \n
  • learn_logit_scale (bool): whether to learn the contrastive learning\ncontrastive loss when using the default contrastive loss.
  • \n
  • contrastive_loss (ContrastiveLossType | None): a contrastive loss\nfunction used for alignment. learn_logit_scale will not affect custom\ncontrastive losses.
  • \n
\n", "signature": "(\tdomain_mods: collections.abc.Mapping[str, shimmer.modules.domain.DomainModule],\tgw_encoders: collections.abc.Mapping[str, torch.nn.modules.module.Module],\tgw_decoders: collections.abc.Mapping[str, torch.nn.modules.module.Module],\tworkspace_dim: int,\tloss_coefs: shimmer.modules.losses.BroadcastLossCoefs,\tselection_temperature: float = 0.2,\toptim_lr: float = 0.001,\toptim_weight_decay: float = 0.0,\tscheduler_args: shimmer.modules.global_workspace.SchedulerArgs | None = None,\tlearn_logit_scale: bool = False,\tcontrastive_loss: collections.abc.Callable[[torch.Tensor, torch.Tensor], shimmer.modules.domain.LossOutput] | None = None)"}, "shimmer.modules.global_workspace.GlobalWorkspace.forward": {"fullname": "shimmer.modules.global_workspace.GlobalWorkspace.forward", "modulename": "shimmer.modules.global_workspace", "qualname": "GlobalWorkspace.forward", "kind": "function", "doc": "

Computes demi-cycles, cycles, and translations.

\n\n
Arguments:
\n\n
    \n
  • latent_domains (LatentsT): Groups of domains for the computation.
  • \n
\n\n
Returns:
\n\n
\n

GWPredictions: the predictions on the batch.

\n
\n", "signature": "(\tself,\tlatent_domains: collections.abc.Mapping[frozenset[str], collections.abc.Mapping[str, torch.Tensor]]) -> shimmer.modules.global_workspace.GWPredictions:", "funcdef": "def"}, "shimmer.modules.global_workspace.GlobalWorkspaceBayesian": {"fullname": "shimmer.modules.global_workspace.GlobalWorkspaceBayesian", "modulename": "shimmer.modules.global_workspace", "qualname": "GlobalWorkspaceBayesian", "kind": "class", "doc": "

A simple 2-domains max GlobalWorkspaceBase with a Bayesian base uncertainty\nprediction.

\n\n

This is used to simplify a Global Workspace instanciation and only overrides the\n__init__ method.

\n", "bases": "shimmer.modules.global_workspace.GlobalWorkspaceBase[shimmer.modules.gw_module.GWModuleBayesian, shimmer.modules.selection.FixedSharedSelection, shimmer.modules.losses.GWLossesBayesian]"}, "shimmer.modules.global_workspace.GlobalWorkspaceBayesian.__init__": {"fullname": "shimmer.modules.global_workspace.GlobalWorkspaceBayesian.__init__", "modulename": "shimmer.modules.global_workspace", "qualname": "GlobalWorkspaceBayesian.__init__", "kind": "function", "doc": "

Initializes a Global Workspace

\n\n
Arguments:
\n\n
    \n
  • domain_mods (Mapping[str, DomainModule]): mapping of the domains\nconnected to the GW. Keys are domain names, values are the\nDomainModule.
  • \n
  • gw_encoders (Mapping[str, torch.nn.Module]): mapping for each domain\nname to a torch.nn.Module class which role is to encode a\nunimodal latent representations into a GW representation (pre fusion).
  • \n
  • gw_decoders (Mapping[str, torch.nn.Module]): mapping for each domain\nname to a torch.nn.Module class which role is to decode a\nGW representation into a unimodal latent representations.
  • \n
  • workspace_dim (int): dimension of the GW.
  • \n
  • loss_coefs (LossCoefs): loss coefficients
  • \n
  • sensitivity_selection (float): sensivity coef $c'_1$
  • \n
  • sensitivity_precision (float): sensitivity coef $c'_2$
  • \n
  • optim_lr (float): learning rate
  • \n
  • optim_weight_decay (float): weight decay
  • \n
  • scheduler_args (SchedulerArgs | None): optimization scheduler's arguments
  • \n
  • learn_logit_scale (bool): whether to learn the contrastive learning\ncontrastive loss when using the default contrastive loss.
  • \n
  • use_normalized_constrastive (bool): whether to use the normalized cont\nloss by the precision coefs
  • \n
  • contrastive_loss (ContrastiveLossType | None): a contrastive loss\nfunction used for alignment. learn_logit_scale will not affect custom\ncontrastive losses.
  • \n
  • precision_softmax_temp (float): temperature to use in softmax of\nprecision
  • \n
\n", "signature": "(\tdomain_mods: collections.abc.Mapping[str, shimmer.modules.domain.DomainModule],\tgw_encoders: collections.abc.Mapping[str, torch.nn.modules.module.Module],\tgw_decoders: collections.abc.Mapping[str, torch.nn.modules.module.Module],\tworkspace_dim: int,\tloss_coefs: shimmer.modules.losses.BroadcastLossCoefs,\tsensitivity_selection: float = 1,\tsensitivity_precision: float = 1,\toptim_lr: float = 0.001,\toptim_weight_decay: float = 0.0,\tscheduler_args: shimmer.modules.global_workspace.SchedulerArgs | None = None,\tlearn_logit_scale: bool = False,\tuse_normalized_constrastive: bool = True,\tcontrastive_loss: collections.abc.Callable[[torch.Tensor, torch.Tensor], shimmer.modules.domain.LossOutput] | None = None,\tprecision_softmax_temp: float = 0.01)"}, "shimmer.modules.global_workspace.GlobalWorkspaceBayesian.forward": {"fullname": "shimmer.modules.global_workspace.GlobalWorkspaceBayesian.forward", "modulename": "shimmer.modules.global_workspace", "qualname": "GlobalWorkspaceBayesian.forward", "kind": "function", "doc": "

Computes demi-cycles, cycles, and translations.

\n\n
Arguments:
\n\n
    \n
  • latent_domains (LatentsT): Groups of domains for the computation.
  • \n
\n\n
Returns:
\n\n
\n

GWPredictions: the predictions on the batch.

\n
\n", "signature": "(\tself,\tlatent_domains: collections.abc.Mapping[frozenset[str], collections.abc.Mapping[str, torch.Tensor]]) -> shimmer.modules.global_workspace.GWPredictions:", "funcdef": "def"}, "shimmer.modules.global_workspace.pretrained_global_workspace": {"fullname": "shimmer.modules.global_workspace.pretrained_global_workspace", "modulename": "shimmer.modules.global_workspace", "qualname": "pretrained_global_workspace", "kind": "function", "doc": "

Load a GlobalWorkspace flavor of GlobalWorkspaceBase from a checkpoint.

\n\n
Arguments:
\n\n
    \n
  • checkpoint_path (str | Path): path to checkpoint
  • \n
  • domain_mods (Mapping[str, DomainModule]): mapping of the domains\nconnected to the GW. Keys are domain names, values are the\nDomainModule.
  • \n
  • gw_encoders (Mapping[str, torch.nn.Module]): mapping for each domain\nname to a torch.nn.Module class which role is to encode a\nunimodal latent representations into a GW representation (pre fusion).
  • \n
  • gw_decoders (Mapping[str, torch.nn.Module]): mapping for each domain\nname to a torch.nn.Module class which role is to decode a\nGW representation into a unimodal latent representations.
  • \n
  • workspace_dim (int): dimension of the GW.
  • \n
  • loss_coefs (LossCoefs): loss coefficients
  • \n
  • contrastive_loss (ContrastiveLossType): a contrastive loss\nfunction used for alignment. learn_logit_scale will not affect custom\ncontrastive losses.
  • \n
  • **kwargs: additional arguments to pass to\nGlobalWorkspace.load_from_checkpoint.
  • \n
\n\n
Returns:
\n\n
\n

GlobalWorkspace: the pretrained GlobalWorkspace.

\n
\n\n
Raises:
\n\n
    \n
  • TypeError: if loaded type is not GlobalWorkspace.
  • \n
\n", "signature": "(\tcheckpoint_path: str | pathlib.Path,\tdomain_mods: collections.abc.Mapping[str, shimmer.modules.domain.DomainModule],\tgw_encoders: collections.abc.Mapping[str, torch.nn.modules.module.Module],\tgw_decoders: collections.abc.Mapping[str, torch.nn.modules.module.Module],\tworkspace_dim: int,\tloss_coefs: shimmer.modules.losses.LossCoefs,\tcontrastive_fn: collections.abc.Callable[[torch.Tensor, torch.Tensor], shimmer.modules.domain.LossOutput],\t**kwargs) -> shimmer.modules.global_workspace.GlobalWorkspace2Domains:", "funcdef": "def"}, "shimmer.modules.domain": {"fullname": "shimmer.modules.domain", "modulename": "shimmer.modules.domain", "kind": "module", "doc": "

\n"}, "shimmer.modules.domain.LossOutput": {"fullname": "shimmer.modules.domain.LossOutput", "modulename": "shimmer.modules.domain", "qualname": "LossOutput", "kind": "class", "doc": "

This is a python dataclass use as a returned value for losses.\nIt keeps track of what is used for training (loss) and what is used\nonly for logging (metrics).

\n"}, "shimmer.modules.domain.LossOutput.__init__": {"fullname": "shimmer.modules.domain.LossOutput.__init__", "modulename": "shimmer.modules.domain", "qualname": "LossOutput.__init__", "kind": "function", "doc": "

\n", "signature": "(loss: torch.Tensor, metrics: dict[str, torch.Tensor] = <factory>)"}, "shimmer.modules.domain.LossOutput.loss": {"fullname": "shimmer.modules.domain.LossOutput.loss", "modulename": "shimmer.modules.domain", "qualname": "LossOutput.loss", "kind": "variable", "doc": "

Loss used during training.

\n", "annotation": ": torch.Tensor"}, "shimmer.modules.domain.LossOutput.metrics": {"fullname": "shimmer.modules.domain.LossOutput.metrics", "modulename": "shimmer.modules.domain", "qualname": "LossOutput.metrics", "kind": "variable", "doc": "

Some additional metrics to log (not used during training).

\n", "annotation": ": dict[str, torch.Tensor]"}, "shimmer.modules.domain.LossOutput.all": {"fullname": "shimmer.modules.domain.LossOutput.all", "modulename": "shimmer.modules.domain", "qualname": "LossOutput.all", "kind": "variable", "doc": "

Returns a dict with all metrics and loss with \"loss\" key.

\n", "annotation": ": dict[str, torch.Tensor]"}, "shimmer.modules.domain.DomainModule": {"fullname": "shimmer.modules.domain.DomainModule", "modulename": "shimmer.modules.domain", "qualname": "DomainModule", "kind": "class", "doc": "

Base class for a DomainModule that defines domain specific modules of the GW.

\n", "bases": "lightning.pytorch.core.module.LightningModule"}, "shimmer.modules.domain.DomainModule.__init__": {"fullname": "shimmer.modules.domain.DomainModule.__init__", "modulename": "shimmer.modules.domain", "qualname": "DomainModule.__init__", "kind": "function", "doc": "

Initializes a DomainModule.

\n\n
Arguments:
\n\n
    \n
  • latent_dim (int): latent dimension of the unimodal module
  • \n
\n", "signature": "(latent_dim: int)"}, "shimmer.modules.domain.DomainModule.latent_dim": {"fullname": "shimmer.modules.domain.DomainModule.latent_dim", "modulename": "shimmer.modules.domain", "qualname": "DomainModule.latent_dim", "kind": "variable", "doc": "

The latent dimension of the module.

\n"}, "shimmer.modules.domain.DomainModule.encode": {"fullname": "shimmer.modules.domain.DomainModule.encode", "modulename": "shimmer.modules.domain", "qualname": "DomainModule.encode", "kind": "function", "doc": "

Encode the domain data into a unimodal representation.

\n\n
Arguments:
\n\n
    \n
  • x (Any): data of the domain.
  • \n
\n\n
Returns:
\n\n
\n

torch.Tensor: a unimodal representation.

\n
\n", "signature": "(self, x: Any) -> torch.Tensor:", "funcdef": "def"}, "shimmer.modules.domain.DomainModule.decode": {"fullname": "shimmer.modules.domain.DomainModule.decode", "modulename": "shimmer.modules.domain", "qualname": "DomainModule.decode", "kind": "function", "doc": "

Decode data from unimodal representation back to the domain data.

\n\n
Arguments:
\n\n
    \n
  • z (torch.Tensor): unimodal representation of the domain.
  • \n
\n\n
Returns:
\n\n
\n

Any: the original domain data.

\n
\n", "signature": "(self, z: torch.Tensor) -> Any:", "funcdef": "def"}, "shimmer.modules.domain.DomainModule.compute_loss": {"fullname": "shimmer.modules.domain.DomainModule.compute_loss", "modulename": "shimmer.modules.domain", "qualname": "DomainModule.compute_loss", "kind": "function", "doc": "

Generic loss computation the modality.

\n\n
Arguments:
\n\n
    \n
  • pred (torch.Tensor): prediction of the model
  • \n
  • target (torch.Tensor): target tensor
  • \n
\n\n
Results:
\n\n
\n

LossOutput: LossOuput with training loss and additional metrics.

\n
\n", "signature": "(\tself,\tpred: torch.Tensor,\ttarget: torch.Tensor) -> shimmer.modules.domain.LossOutput:", "funcdef": "def"}, "shimmer.modules.domain.DomainModule.compute_dcy_loss": {"fullname": "shimmer.modules.domain.DomainModule.compute_dcy_loss", "modulename": "shimmer.modules.domain", "qualname": "DomainModule.compute_dcy_loss", "kind": "function", "doc": "

Computes the loss for a demi-cycle. Override if the demi-cycle loss is\ndifferent that the generic loss.

\n\n
Arguments:
\n\n
    \n
  • pred (torch.Tensor): prediction of the model
  • \n
  • target (torch.Tensor): target tensor
  • \n
\n\n
Results:
\n\n
\n

LossOutput: LossOuput with training loss and additional metrics.

\n
\n", "signature": "(\tself,\tpred: torch.Tensor,\ttarget: torch.Tensor) -> shimmer.modules.domain.LossOutput:", "funcdef": "def"}, "shimmer.modules.domain.DomainModule.compute_cy_loss": {"fullname": "shimmer.modules.domain.DomainModule.compute_cy_loss", "modulename": "shimmer.modules.domain", "qualname": "DomainModule.compute_cy_loss", "kind": "function", "doc": "

Computes the loss for a cycle. Override if the cycle loss is\ndifferent that the generic loss.

\n\n
Arguments:
\n\n
    \n
  • pred (torch.Tensor): prediction of the model
  • \n
  • target (torch.Tensor): target tensor
  • \n
\n\n
Results:
\n\n
\n

LossOutput: LossOuput with training loss and additional metrics.

\n
\n", "signature": "(\tself,\tpred: torch.Tensor,\ttarget: torch.Tensor) -> shimmer.modules.domain.LossOutput:", "funcdef": "def"}, "shimmer.modules.domain.DomainModule.compute_tr_loss": {"fullname": "shimmer.modules.domain.DomainModule.compute_tr_loss", "modulename": "shimmer.modules.domain", "qualname": "DomainModule.compute_tr_loss", "kind": "function", "doc": "

Computes the loss for a translation. Override if the translation loss is\ndifferent that the generic loss.

\n\n
Arguments:
\n\n
    \n
  • pred (torch.Tensor): prediction of the model
  • \n
  • target (torch.Tensor): target tensor
  • \n
\n\n
Results:
\n\n
\n

LossOutput: LossOuput with training loss and additional metrics.

\n
\n", "signature": "(\tself,\tpred: torch.Tensor,\ttarget: torch.Tensor) -> shimmer.modules.domain.LossOutput:", "funcdef": "def"}, "shimmer.modules.domain.DomainModule.compute_broadcast_loss": {"fullname": "shimmer.modules.domain.DomainModule.compute_broadcast_loss", "modulename": "shimmer.modules.domain", "qualname": "DomainModule.compute_broadcast_loss", "kind": "function", "doc": "

Computes the loss for a broadcast (fusion). Override if the broadcast loss is\ndifferent that the generic loss.

\n\n
Arguments:
\n\n
    \n
  • pred (torch.Tensor): prediction of the model
  • \n
  • target (torch.Tensor): target tensor
  • \n
\n\n
Results:
\n\n
\n

LossOutput: LossOuput with training loss and additional metrics.

\n
\n", "signature": "(\tself,\tpred: torch.Tensor,\ttarget: torch.Tensor) -> shimmer.modules.domain.LossOutput:", "funcdef": "def"}, "shimmer.modules.gw_module": {"fullname": "shimmer.modules.gw_module", "modulename": "shimmer.modules.gw_module", "kind": "module", "doc": "

\n"}, "shimmer.modules.gw_module.get_n_layers": {"fullname": "shimmer.modules.gw_module.get_n_layers", "modulename": "shimmer.modules.gw_module", "qualname": "get_n_layers", "kind": "function", "doc": "

Makes a list of n_layers nn.Linear layers with nn.ReLU.

\n\n
Arguments:
\n\n
    \n
  • n_layers (int): number of layers
  • \n
  • hidden_dim (int): size of the hidden dimension
  • \n
\n\n
Returns:
\n\n
\n

list[nn.Module]: list of linear and relu layers.

\n
\n", "signature": "(n_layers: int, hidden_dim: int) -> list[torch.nn.modules.module.Module]:", "funcdef": "def"}, "shimmer.modules.gw_module.GWDecoder": {"fullname": "shimmer.modules.gw_module.GWDecoder", "modulename": "shimmer.modules.gw_module", "qualname": "GWDecoder", "kind": "class", "doc": "

A Decoder network for GWModules.

\n", "bases": "torch.nn.modules.container.Sequential"}, "shimmer.modules.gw_module.GWDecoder.__init__": {"fullname": "shimmer.modules.gw_module.GWDecoder.__init__", "modulename": "shimmer.modules.gw_module", "qualname": "GWDecoder.__init__", "kind": "function", "doc": "

Initializes the decoder.

\n\n
Arguments:
\n\n
    \n
  • in_dim (int): input dimension
  • \n
  • hidden_dim (int): hidden dimension
  • \n
  • out_dim (int): output dimension
  • \n
  • n_layers (int): number of hidden layers. The total number of layers\nwill be n_layers + 2 (one before, one after).
  • \n
\n", "signature": "(in_dim: int, hidden_dim: int, out_dim: int, n_layers: int)"}, "shimmer.modules.gw_module.GWDecoder.in_dim": {"fullname": "shimmer.modules.gw_module.GWDecoder.in_dim", "modulename": "shimmer.modules.gw_module", "qualname": "GWDecoder.in_dim", "kind": "variable", "doc": "

input dimension

\n"}, "shimmer.modules.gw_module.GWDecoder.hidden_dim": {"fullname": "shimmer.modules.gw_module.GWDecoder.hidden_dim", "modulename": "shimmer.modules.gw_module", "qualname": "GWDecoder.hidden_dim", "kind": "variable", "doc": "

hidden dimension

\n"}, "shimmer.modules.gw_module.GWDecoder.out_dim": {"fullname": "shimmer.modules.gw_module.GWDecoder.out_dim", "modulename": "shimmer.modules.gw_module", "qualname": "GWDecoder.out_dim", "kind": "variable", "doc": "

output dimension

\n"}, "shimmer.modules.gw_module.GWDecoder.n_layers": {"fullname": "shimmer.modules.gw_module.GWDecoder.n_layers", "modulename": "shimmer.modules.gw_module", "qualname": "GWDecoder.n_layers", "kind": "variable", "doc": "

number of hidden layers. The total number of layers\n will be n_layers + 2 (one before, one after).

\n"}, "shimmer.modules.gw_module.GWEncoder": {"fullname": "shimmer.modules.gw_module.GWEncoder", "modulename": "shimmer.modules.gw_module", "qualname": "GWEncoder", "kind": "class", "doc": "

An Encoder network used in GWModules.

\n\n

This is similar to the decoder, but adds a tanh non-linearity at the end.

\n", "bases": "GWDecoder"}, "shimmer.modules.gw_module.GWEncoder.__init__": {"fullname": "shimmer.modules.gw_module.GWEncoder.__init__", "modulename": "shimmer.modules.gw_module", "qualname": "GWEncoder.__init__", "kind": "function", "doc": "

Initializes the encoder.

\n\n
Arguments:
\n\n
    \n
  • in_dim (int): input dimension
  • \n
  • hidden_dim (int): hidden dimension
  • \n
  • out_dim (int): output dimension
  • \n
  • n_layers (int): number of hidden layers. The total number of layers\nwill be n_layers + 2 (one before, one after).
  • \n
\n", "signature": "(in_dim: int, hidden_dim: int, out_dim: int, n_layers: int)"}, "shimmer.modules.gw_module.GWEncoder.forward": {"fullname": "shimmer.modules.gw_module.GWEncoder.forward", "modulename": "shimmer.modules.gw_module", "qualname": "GWEncoder.forward", "kind": "function", "doc": "

Define the computation performed at every call.

\n\n

Should be overridden by all subclasses.

\n\n
\n\n

Although the recipe for forward pass needs to be defined within\nthis function, one should call the Module instance afterwards\ninstead of this since the former takes care of running the\nregistered hooks while the latter silently ignores them.

\n\n
\n", "signature": "(self, input: torch.Tensor) -> torch.Tensor:", "funcdef": "def"}, "shimmer.modules.gw_module.GWEncoderLinear": {"fullname": "shimmer.modules.gw_module.GWEncoderLinear", "modulename": "shimmer.modules.gw_module", "qualname": "GWEncoderLinear", "kind": "class", "doc": "

A linear Encoder network used in GWModules.

\n", "bases": "torch.nn.modules.linear.Linear"}, "shimmer.modules.gw_module.GWEncoderLinear.forward": {"fullname": "shimmer.modules.gw_module.GWEncoderLinear.forward", "modulename": "shimmer.modules.gw_module", "qualname": "GWEncoderLinear.forward", "kind": "function", "doc": "

Define the computation performed at every call.

\n\n

Should be overridden by all subclasses.

\n\n
\n\n

Although the recipe for forward pass needs to be defined within\nthis function, one should call the Module instance afterwards\ninstead of this since the former takes care of running the\nregistered hooks while the latter silently ignores them.

\n\n
\n", "signature": "(self, input: torch.Tensor) -> torch.Tensor:", "funcdef": "def"}, "shimmer.modules.gw_module.GWModuleBase": {"fullname": "shimmer.modules.gw_module.GWModuleBase", "modulename": "shimmer.modules.gw_module", "qualname": "GWModuleBase", "kind": "class", "doc": "

Base class for GWModule.

\n\n

GWModule handles encoding, decoding the unimodal representations\nusing the gw_encoders andgw_decoders, and define\nsome common operations in GW like cycles and translations.

\n\n

This is an abstract class and should be implemented.\nFor an implemented interface, see GWModule.

\n", "bases": "torch.nn.modules.module.Module, abc.ABC"}, "shimmer.modules.gw_module.GWModuleBase.__init__": {"fullname": "shimmer.modules.gw_module.GWModuleBase.__init__", "modulename": "shimmer.modules.gw_module", "qualname": "GWModuleBase.__init__", "kind": "function", "doc": "

Initializes the GWModule.

\n\n
Arguments:
\n\n
    \n
  • domain_modules (Mapping[str, DomainModule]): the domain modules.
  • \n
  • workspace_dim (int): dimension of the GW.
  • \n
\n", "signature": "(\tdomain_mods: collections.abc.Mapping[str, shimmer.modules.domain.DomainModule],\tworkspace_dim: int,\t*args,\t**kwargs)"}, "shimmer.modules.gw_module.GWModuleBase.domain_mods": {"fullname": "shimmer.modules.gw_module.GWModuleBase.domain_mods", "modulename": "shimmer.modules.gw_module", "qualname": "GWModuleBase.domain_mods", "kind": "variable", "doc": "

The unimodal domain modules.

\n"}, "shimmer.modules.gw_module.GWModuleBase.workspace_dim": {"fullname": "shimmer.modules.gw_module.GWModuleBase.workspace_dim", "modulename": "shimmer.modules.gw_module", "qualname": "GWModuleBase.workspace_dim", "kind": "variable", "doc": "

Dimension of the GW

\n"}, "shimmer.modules.gw_module.GWModuleBase.fuse": {"fullname": "shimmer.modules.gw_module.GWModuleBase.fuse", "modulename": "shimmer.modules.gw_module", "qualname": "GWModuleBase.fuse", "kind": "function", "doc": "

Merge function used to combine domains.

\n\n
Arguments:
\n\n
    \n
  • x (LatentsDomainGroupT): the group of latent representation.
  • \n
  • selection_score (Mapping[str, torch.Tensor]): attention scores to\nuse to encode the reprensetation.
  • \n
\n\n
Returns:
\n\n
\n

torch.Tensor: The merged representation.

\n
\n", "signature": "(\tself,\tx: collections.abc.Mapping[str, torch.Tensor],\tselection_scores: collections.abc.Mapping[str, torch.Tensor]) -> torch.Tensor:", "funcdef": "def"}, "shimmer.modules.gw_module.GWModuleBase.encode": {"fullname": "shimmer.modules.gw_module.GWModuleBase.encode", "modulename": "shimmer.modules.gw_module", "qualname": "GWModuleBase.encode", "kind": "function", "doc": "

Encode the latent representation infos to the pre-fusion GW representation.

\n\n
Arguments:
\n\n
    \n
  • x (LatentsDomainGroupT): the input domain representations
  • \n
\n\n
Returns:
\n\n
\n

LatentsDomainGroupT: pre-fusion GW representations

\n
\n", "signature": "(\tself,\tx: collections.abc.Mapping[str, torch.Tensor]) -> dict[str, torch.Tensor]:", "funcdef": "def"}, "shimmer.modules.gw_module.GWModuleBase.encode_and_fuse": {"fullname": "shimmer.modules.gw_module.GWModuleBase.encode_and_fuse", "modulename": "shimmer.modules.gw_module", "qualname": "GWModuleBase.encode_and_fuse", "kind": "function", "doc": "

Encode the latent representation infos to the final GW representation.\nIt combines the encode and fuse methods.

\n\n
Arguments:
\n\n
    \n
  • x (LatentsDomainGroupT): the input domain representations
  • \n
  • selection_score (Mapping[str, torch.Tensor]): attention scores to\nuse to encode the reprensetation.
  • \n
\n\n
Returns:
\n\n
\n

torch.Tensor: The merged representation.

\n
\n", "signature": "(\tself,\tx: collections.abc.Mapping[str, torch.Tensor],\tselection_module: shimmer.modules.selection.SelectionBase) -> torch.Tensor:", "funcdef": "def"}, "shimmer.modules.gw_module.GWModuleBase.decode": {"fullname": "shimmer.modules.gw_module.GWModuleBase.decode", "modulename": "shimmer.modules.gw_module", "qualname": "GWModuleBase.decode", "kind": "function", "doc": "

Decode the GW representation into given domains.

\n\n
Arguments:
\n\n
    \n
  • z (torch.Tensor): the GW representation.
  • \n
  • domains (Iterable[str]): iterable of domains to decode.
  • \n
\n\n
Returns:
\n\n
\n

LatentsDomainGroupDT: the decoded unimodal representations.

\n
\n", "signature": "(\tself,\tz: torch.Tensor,\tdomains: collections.abc.Iterable[str] | None = None) -> dict[str, torch.Tensor]:", "funcdef": "def"}, "shimmer.modules.gw_module.GWModule": {"fullname": "shimmer.modules.gw_module.GWModule", "modulename": "shimmer.modules.gw_module", "qualname": "GWModule", "kind": "class", "doc": "

GW nn.Module. Implements GWModuleBase.

\n", "bases": "GWModuleBase"}, "shimmer.modules.gw_module.GWModule.__init__": {"fullname": "shimmer.modules.gw_module.GWModule.__init__", "modulename": "shimmer.modules.gw_module", "qualname": "GWModule.__init__", "kind": "function", "doc": "

Initializes the GWModule.

\n\n
Arguments:
\n\n
    \n
  • domain_modules (Mapping[str, DomainModule]): the domain modules.
  • \n
  • workspace_dim (int): dimension of the GW.
  • \n
  • gw_encoders (Mapping[str, torch.nn.Module]): mapping for each domain\nname to a an torch.nn.Module class that encodes a\nunimodal latent representations into a GW representation (pre fusion).
  • \n
  • gw_decoders (Mapping[str, torch.nn.Module]): mapping for each domain\nname to a an torch.nn.Module class that decodes a\n GW representation to a unimodal latent representation.
  • \n
\n", "signature": "(\tdomain_modules: collections.abc.Mapping[str, shimmer.modules.domain.DomainModule],\tworkspace_dim: int,\tgw_encoders: collections.abc.Mapping[str, torch.nn.modules.module.Module],\tgw_decoders: collections.abc.Mapping[str, torch.nn.modules.module.Module])"}, "shimmer.modules.gw_module.GWModule.gw_encoders": {"fullname": "shimmer.modules.gw_module.GWModule.gw_encoders", "modulename": "shimmer.modules.gw_module", "qualname": "GWModule.gw_encoders", "kind": "variable", "doc": "

The module's encoders

\n"}, "shimmer.modules.gw_module.GWModule.gw_decoders": {"fullname": "shimmer.modules.gw_module.GWModule.gw_decoders", "modulename": "shimmer.modules.gw_module", "qualname": "GWModule.gw_decoders", "kind": "variable", "doc": "

The module's decoders

\n"}, "shimmer.modules.gw_module.GWModule.fuse": {"fullname": "shimmer.modules.gw_module.GWModule.fuse", "modulename": "shimmer.modules.gw_module", "qualname": "GWModule.fuse", "kind": "function", "doc": "

Merge function used to combine domains.

\n\n
Arguments:
\n\n
    \n
  • x (LatentsDomainGroupT): the group of latent representation.
  • \n
  • selection_score (Mapping[str, torch.Tensor]): attention scores to\nuse to encode the reprensetation.
  • \n
\n\n
Returns:
\n\n
\n

torch.Tensor: The merged representation.

\n
\n", "signature": "(\tself,\tx: collections.abc.Mapping[str, torch.Tensor],\tselection_scores: collections.abc.Mapping[str, torch.Tensor]) -> torch.Tensor:", "funcdef": "def"}, "shimmer.modules.gw_module.GWModule.encode": {"fullname": "shimmer.modules.gw_module.GWModule.encode", "modulename": "shimmer.modules.gw_module", "qualname": "GWModule.encode", "kind": "function", "doc": "

Encode the latent representation infos to the pre-fusion GW representation.

\n\n
Arguments:
\n\n
    \n
  • x (LatentsDomainGroupT): the input domain representations.
  • \n
\n\n
Returns:
\n\n
\n

LatentsDomainGroupT: pre-fusion representation

\n
\n", "signature": "(\tself,\tx: collections.abc.Mapping[str, torch.Tensor]) -> dict[str, torch.Tensor]:", "funcdef": "def"}, "shimmer.modules.gw_module.GWModule.decode": {"fullname": "shimmer.modules.gw_module.GWModule.decode", "modulename": "shimmer.modules.gw_module", "qualname": "GWModule.decode", "kind": "function", "doc": "

Decodes a GW representation to multiple domains.

\n\n
Arguments:
\n\n
    \n
  • z (torch.Tensor): the GW representation
  • \n
  • domains (Iterable[str] | None): the domains to decode to. Defaults to\nuse keys in gw_interfaces (all domains).
  • \n
\n\n
Returns:
\n\n
\n

LatentsDomainGroupDT: decoded unimodal representation

\n
\n", "signature": "(\tself,\tz: torch.Tensor,\tdomains: collections.abc.Iterable[str] | None = None) -> dict[str, torch.Tensor]:", "funcdef": "def"}, "shimmer.modules.gw_module.compute_fusion_scores": {"fullname": "shimmer.modules.gw_module.compute_fusion_scores", "modulename": "shimmer.modules.gw_module", "qualname": "compute_fusion_scores", "kind": "function", "doc": "

Combine precision scores using std summation in quadrature

\n\n

The two scores should have the same dimension.

\n\n
Arguments:
\n\n
    \n
  • score_1 (torch.Tensor): First scores.
  • \n
  • score_2 (torch.Tensor): Second scores.
  • \n
  • sensitivity_1 (float): sensitivity for the first score
  • \n
  • sensitivity_2 (float): sensitivity for the second score
  • \n
  • eps (float): a value added to avoid numerical unstability.
  • \n
\n\n
Returns:
\n\n
\n

torch.Tensor: the combined scores

\n
\n", "signature": "(\tscore_1: torch.Tensor,\tscore_2: torch.Tensor,\tsensitivity_1: float = 1.0,\tsensitivity_2: float = 1.0,\teps: float = 1e-06) -> torch.Tensor:", "funcdef": "def"}, "shimmer.modules.gw_module.GWModuleBayesian": {"fullname": "shimmer.modules.gw_module.GWModuleBayesian", "modulename": "shimmer.modules.gw_module", "qualname": "GWModuleBayesian", "kind": "class", "doc": "

GWModule with a Bayesian based uncertainty prediction.

\n", "bases": "GWModule"}, "shimmer.modules.gw_module.GWModuleBayesian.__init__": {"fullname": "shimmer.modules.gw_module.GWModuleBayesian.__init__", "modulename": "shimmer.modules.gw_module", "qualname": "GWModuleBayesian.__init__", "kind": "function", "doc": "

Initializes the GWModuleBayesian.

\n\n
Arguments:
\n\n
    \n
  • domain_modules (Mapping[str, DomainModule]): the domain modules.
  • \n
  • workspace_dim (int): dimension of the GW.
  • \n
  • gw_encoders (Mapping[str, torch.nn.Module]): mapping for each domain\nname to a an torch.nn.Module class that encodes a\nunimodal latent representations into a GW representation (pre fusion).
  • \n
  • gw_decoders (Mapping[str, torch.nn.Module]): mapping for each domain\nname to a an torch.nn.Module class that decodes a\n GW representation to a unimodal latent representation.
  • \n
  • sensitivity_selection (float): sensivity coef $c'_1$
  • \n
  • sensitivity_precision (float): sensitivity coef $c'_2$
  • \n
  • precision_softmax_temp (float): temperature to use in softmax of\nprecision
  • \n
\n", "signature": "(\tdomain_modules: collections.abc.Mapping[str, shimmer.modules.domain.DomainModule],\tworkspace_dim: int,\tgw_encoders: collections.abc.Mapping[str, torch.nn.modules.module.Module],\tgw_decoders: collections.abc.Mapping[str, torch.nn.modules.module.Module],\tsensitivity_selection: float = 1,\tsensitivity_precision: float = 1,\tprecision_softmax_temp: float = 0.01)"}, "shimmer.modules.gw_module.GWModuleBayesian.precisions": {"fullname": "shimmer.modules.gw_module.GWModuleBayesian.precisions", "modulename": "shimmer.modules.gw_module", "qualname": "GWModuleBayesian.precisions", "kind": "variable", "doc": "

Precision at the neuron level for every domain.

\n"}, "shimmer.modules.gw_module.GWModuleBayesian.sensitivity_selection": {"fullname": "shimmer.modules.gw_module.GWModuleBayesian.sensitivity_selection", "modulename": "shimmer.modules.gw_module", "qualname": "GWModuleBayesian.sensitivity_selection", "kind": "variable", "doc": "

\n"}, "shimmer.modules.gw_module.GWModuleBayesian.sensitivity_precision": {"fullname": "shimmer.modules.gw_module.GWModuleBayesian.sensitivity_precision", "modulename": "shimmer.modules.gw_module", "qualname": "GWModuleBayesian.sensitivity_precision", "kind": "variable", "doc": "

\n"}, "shimmer.modules.gw_module.GWModuleBayesian.precision_softmax_temp": {"fullname": "shimmer.modules.gw_module.GWModuleBayesian.precision_softmax_temp", "modulename": "shimmer.modules.gw_module", "qualname": "GWModuleBayesian.precision_softmax_temp", "kind": "variable", "doc": "

\n"}, "shimmer.modules.gw_module.GWModuleBayesian.get_precision": {"fullname": "shimmer.modules.gw_module.GWModuleBayesian.get_precision", "modulename": "shimmer.modules.gw_module", "qualname": "GWModuleBayesian.get_precision", "kind": "function", "doc": "

Get the precision vector of given domain and batch

\n\n
Arguments:
\n\n
    \n
  • domain (str):
  • \n
  • x (torch.Tensor): batch of inputs
  • \n
\n\n
Returns:
\n\n
\n

torch.Tensor: batch of precision

\n
\n", "signature": "(self, domain: str, x: torch.Tensor) -> torch.Tensor:", "funcdef": "def"}, "shimmer.modules.gw_module.GWModuleBayesian.fuse": {"fullname": "shimmer.modules.gw_module.GWModuleBayesian.fuse", "modulename": "shimmer.modules.gw_module", "qualname": "GWModuleBayesian.fuse", "kind": "function", "doc": "

Merge function used to combine domains.

\n\n

In the following, $D$ is the number of domains, $N$ the batch size, and $d$ the\ndimension of the Global Workspace.

\n\n

This function needs to merge two kind of scores:

\n\n
    \n
  • the selection scores $a\\in [0,1]^{D\\times N}$;
  • \n
  • the precision scores $b \\in [0,1]^{D\\times N \\times d}$.
  • \n
\n\n
\n\n

The precision score is obtained by predicting logits and using a softmax

\n\n
\n\n

We can obtain associated uncertainties to the scores by introducing a std\nvariable and using bayesian integration:

\n\n

$$a_k = \\frac{M_1}{\\sigma_k^2}$$\nwhere $M_1 = \\frac{1}{\\sum_{i=1}^D \\frac{1}{\\sigma_i^2}}$.

\n\n

Similarly,\n$$b_k = \\frac{M_2}{\\mu_k^2}$$\nwhere $M_2 = \\frac{1}{\\sum_{i=1}^D \\frac{1}{\\mu_i^2}}$.

\n\n

The we can sum the variances to obtain the final uncertainty (squared) $\\xi$:\n$$\\xi_k^2 = c_1 \\sigma_k^2 + c_2 \\mu_k^2$$

\n\n

which, in terms of $a_k$ and $b_k$ yields:\n$$\\xi_k^2 = \\frac{c'_1}{a_k} + \\frac{c'_2}{b_k}$$\nwhere $c'_1 = c_1 \\cdot M_1$ and $c'_2 = c_2 \\cdot M_2$.

\n\n

Finally, the finale combined coefficient is\n$$\\lambda_k = \\frac{M_3}{\\frac{c'_1}{a_k} + \\frac{c'_2}{b_k}}$$\nwhere\n$$M_3 = \\frac{1}{\\sum_{i=1}^D\n \\frac{1}{\\frac{c'_1}{a_i} + \\frac{c'_2}{b_i}}}$$

\n\n
Arguments:
\n\n
    \n
  • x (LatentsDomainGroupT): the group of latent representation.
  • \n
  • selection_score (Mapping[str, torch.Tensor]): attention scores to\nuse to encode the reprensetation.
  • \n
\n\n
Returns:
\n\n
\n

torch.Tensor: The merged representation.

\n
\n", "signature": "(\tself,\tx: collections.abc.Mapping[str, torch.Tensor],\tselection_scores: collections.abc.Mapping[str, torch.Tensor]) -> torch.Tensor:", "funcdef": "def"}, "shimmer.modules.selection": {"fullname": "shimmer.modules.selection", "modulename": "shimmer.modules.selection", "kind": "module", "doc": "

\n"}, "shimmer.modules.selection.SelectionBase": {"fullname": "shimmer.modules.selection.SelectionBase", "modulename": "shimmer.modules.selection", "qualname": "SelectionBase", "kind": "class", "doc": "

This is the base class for the selection mechanism.\nThe selection mechanisms handles the \"competition\" between modules and selects\nfusion coefficients for the domains.

\n", "bases": "torch.nn.modules.module.Module, abc.ABC"}, "shimmer.modules.selection.SelectionBase.update_gw_state": {"fullname": "shimmer.modules.selection.SelectionBase.update_gw_state", "modulename": "shimmer.modules.selection", "qualname": "SelectionBase.update_gw_state", "kind": "function", "doc": "

Update the internal copy of the previous GW state.\nBy default, this is not implemented and will raise an error if used.

\n\n

:note..\n This is not defined as an abstractmethod as some selection method may\n not need it.

\n\n
Arguments:
\n\n
    \n
  • gw_state (torch.Tensor): the previous GW state
  • \n
\n", "signature": "(self, gw_state: torch.Tensor) -> None:", "funcdef": "def"}, "shimmer.modules.selection.SelectionBase.forward": {"fullname": "shimmer.modules.selection.SelectionBase.forward", "modulename": "shimmer.modules.selection", "qualname": "SelectionBase.forward", "kind": "function", "doc": "

Forward pass of the selection method.

\n\n
Arguments:
\n\n
    \n
  • domains (LatentsDomainGroupT): Group of unimodal latent representations.
  • \n
\n\n
Returns:
\n\n
\n

dict[str, torch.Tensor]: for each domain in the group, the fusion\n coefficient for each item in the batch.

\n
\n\n
Example:
\n\n
\n
\n
>>> SomeSelectionImplementation().forward(\n...     {"v": torch.randn(3, 4), "t": torch.randn(3, 8)}\n... )\n{"v": torch.Tensor([0.0, 0.4, 1.0]), "t": torch.Tensor([1.0, 0.6, 0.0])}\n
\n
\n
\n", "signature": "(\tself,\tdomains: collections.abc.Mapping[str, torch.Tensor],\tencodings_pre_fusion: collections.abc.Mapping[str, torch.Tensor]) -> dict[str, torch.Tensor]:", "funcdef": "def"}, "shimmer.modules.selection.SingleDomainSelection": {"fullname": "shimmer.modules.selection.SingleDomainSelection", "modulename": "shimmer.modules.selection", "qualname": "SingleDomainSelection", "kind": "class", "doc": "

This selection mechanism handles groups that can have multiple domains, but always\nreturn a selection of 1 domain from the group with a uniform distribution.

\n\n

For example, if the group has 2 domains, there is a 50% chance of selecting each\ndomain.

\n", "bases": "SelectionBase"}, "shimmer.modules.selection.SingleDomainSelection.forward": {"fullname": "shimmer.modules.selection.SingleDomainSelection.forward", "modulename": "shimmer.modules.selection", "qualname": "SingleDomainSelection.forward", "kind": "function", "doc": "

Forward pass of the module.

\n\n
Arguments:
\n\n
    \n
  • domains (LatentsDomainGroupT): input unimodal latent representations
  • \n
  • gw_state (torch.Tensor): the previous GW state
  • \n
\n\n
Returns:
\n\n
\n

dict[str, torch.Tensor]: whether the domain is selected for each input\n in the batch.

\n
\n", "signature": "(\tself,\tdomains: collections.abc.Mapping[str, torch.Tensor],\tencodings_pre_fusion: collections.abc.Mapping[str, torch.Tensor]) -> dict[str, torch.Tensor]:", "funcdef": "def"}, "shimmer.modules.selection.FixedSharedSelection": {"fullname": "shimmer.modules.selection.FixedSharedSelection", "modulename": "shimmer.modules.selection", "qualname": "FixedSharedSelection", "kind": "class", "doc": "

This selection mechanism is deterministic and always shares the weights equally\nbetween domains.

\n\n

For example, if 2 domains, it gives 0.5 for each; 3 domains, 1/3 for each...

\n", "bases": "SelectionBase"}, "shimmer.modules.selection.FixedSharedSelection.forward": {"fullname": "shimmer.modules.selection.FixedSharedSelection.forward", "modulename": "shimmer.modules.selection", "qualname": "FixedSharedSelection.forward", "kind": "function", "doc": "

Forward pass of the module.

\n\n
Arguments:
\n\n
    \n
  • domains (LatentsDomainGroupT): input unimodal latent representations
  • \n
  • gw_state (torch.Tensor): the previous GW state
  • \n
\n\n
Returns:
\n\n
\n

dict[str, torch.Tensor]: whether the domain is selected for each input\n in the batch.

\n
\n", "signature": "(\tself,\tdomains: collections.abc.Mapping[str, torch.Tensor],\tencodings_pre_fusion: collections.abc.Mapping[str, torch.Tensor]) -> dict[str, torch.Tensor]:", "funcdef": "def"}, "shimmer.modules.selection.KQFixedQSelection": {"fullname": "shimmer.modules.selection.KQFixedQSelection", "modulename": "shimmer.modules.selection", "qualname": "KQFixedQSelection", "kind": "class", "doc": "

Key-Query attention with a fixed gw vector.

\n", "bases": "SelectionBase"}, "shimmer.modules.selection.KQFixedQSelection.__init__": {"fullname": "shimmer.modules.selection.KQFixedQSelection.__init__", "modulename": "shimmer.modules.selection", "qualname": "KQFixedQSelection.__init__", "kind": "function", "doc": "
Arguments:
\n\n
    \n
  • head_size (int) : dimension of the key and query vectors.
  • \n
  • domain_dim (int) : dimension of the input dims (assumed to be the same\nfor now)
  • \n
  • domain_names (Iterable[str]) : list of input domains
  • \n
\n", "signature": "(\thead_size: int,\tdomain_dim: int,\tdomain_names: collections.abc.Iterable[str])"}, "shimmer.modules.selection.KQFixedQSelection.head_size": {"fullname": "shimmer.modules.selection.KQFixedQSelection.head_size", "modulename": "shimmer.modules.selection", "qualname": "KQFixedQSelection.head_size", "kind": "variable", "doc": "

\n"}, "shimmer.modules.selection.KQFixedQSelection.query_layer": {"fullname": "shimmer.modules.selection.KQFixedQSelection.query_layer", "modulename": "shimmer.modules.selection", "qualname": "KQFixedQSelection.query_layer", "kind": "variable", "doc": "

\n"}, "shimmer.modules.selection.KQFixedQSelection.key_layers": {"fullname": "shimmer.modules.selection.KQFixedQSelection.key_layers", "modulename": "shimmer.modules.selection", "qualname": "KQFixedQSelection.key_layers", "kind": "variable", "doc": "

\n"}, "shimmer.modules.selection.KQFixedQSelection.forward": {"fullname": "shimmer.modules.selection.KQFixedQSelection.forward", "modulename": "shimmer.modules.selection", "qualname": "KQFixedQSelection.forward", "kind": "function", "doc": "

Compute keys and queries, match them with dot product and softmax.\nDoes this twice, once with the static query and once with a dynamic query.

\n\n
Arguments:
\n\n
    \n
  • domains (LatentsDomainGroupT): Group of unimodal latent representations.
  • \n
  • encodings (LatentsDomainGroupT): Group of pre-fusion encodings.
  • \n
\n\n
Returns:
\n\n
\n

dict[str, torch.Tensor]: the attention scores for each domain in the\n group.

\n
\n", "signature": "(\tself,\tdomains: collections.abc.Mapping[str, torch.Tensor],\tencodings_pre_fusion: collections.abc.Mapping[str, torch.Tensor]) -> dict[str, torch.Tensor]:", "funcdef": "def"}, "shimmer.modules.selection.RandomSelection": {"fullname": "shimmer.modules.selection.RandomSelection", "modulename": "shimmer.modules.selection", "qualname": "RandomSelection", "kind": "class", "doc": "

Modified random attention to only utilize uniform-softmax scores across modalities.\nThis version omits the binary scaling factors and focuses on generating attention\ncoefficients using a uniform distribution followed by a domain-wise softmax.

\n", "bases": "SelectionBase"}, "shimmer.modules.selection.RandomSelection.__init__": {"fullname": "shimmer.modules.selection.RandomSelection.__init__", "modulename": "shimmer.modules.selection", "qualname": "RandomSelection.__init__", "kind": "function", "doc": "
Arguments:
\n\n
    \n
  • temperature (float): Temperature of the softmax applied to uniform\nscaling factors.
  • \n
\n", "signature": "(temperature: float)"}, "shimmer.modules.selection.RandomSelection.temperature": {"fullname": "shimmer.modules.selection.RandomSelection.temperature", "modulename": "shimmer.modules.selection", "qualname": "RandomSelection.temperature", "kind": "variable", "doc": "

\n"}, "shimmer.modules.selection.RandomSelection.forward": {"fullname": "shimmer.modules.selection.RandomSelection.forward", "modulename": "shimmer.modules.selection", "qualname": "RandomSelection.forward", "kind": "function", "doc": "

Generate uniform-then-domain-wise-softmaxed samples for each domain.

\n\n
Arguments:
\n\n
    \n
  • domains (LatentsDomainGroupT): Group of unimodal latent representations.\nThis is not used in the function directly but determines the structure\nof the returned attention coefficients.
  • \n
\n\n
Returns:
\n\n
\n

dict[str, torch.Tensor]: For each domain in the group, the fusion\n coefficient for each item in the batch, based solely on\n uniform-softmax scores.

\n
\n", "signature": "(\tself,\tdomains: collections.abc.Mapping[str, torch.Tensor],\tencodings_pre_fusion: collections.abc.Mapping[str, torch.Tensor]) -> dict[str, torch.Tensor]:", "funcdef": "def"}, "shimmer.modules.selection.DynamicQueryAttention": {"fullname": "shimmer.modules.selection.DynamicQueryAttention", "modulename": "shimmer.modules.selection", "qualname": "DynamicQueryAttention", "kind": "class", "doc": "

Key-Query attention with a dynamic gw vector.\nThe query is updated based on the scaled gw vector.

\n", "bases": "SelectionBase"}, "shimmer.modules.selection.DynamicQueryAttention.__init__": {"fullname": "shimmer.modules.selection.DynamicQueryAttention.__init__", "modulename": "shimmer.modules.selection", "qualname": "DynamicQueryAttention.__init__", "kind": "function", "doc": "
Arguments:
\n\n
    \n
  • head_size (int) : dimension of the key and query vectors.
  • \n
  • domain_dim (int) : dimension of the input dims (assumed to be the same\nfor now)
  • \n
  • domain_names (Iterable[str]) : list of input domains
  • \n
\n", "signature": "(\thead_size: int,\tdomain_dim: int,\tdomain_names: collections.abc.Iterable[str])"}, "shimmer.modules.selection.DynamicQueryAttention.head_size": {"fullname": "shimmer.modules.selection.DynamicQueryAttention.head_size", "modulename": "shimmer.modules.selection", "qualname": "DynamicQueryAttention.head_size", "kind": "variable", "doc": "

\n"}, "shimmer.modules.selection.DynamicQueryAttention.query_layer": {"fullname": "shimmer.modules.selection.DynamicQueryAttention.query_layer", "modulename": "shimmer.modules.selection", "qualname": "DynamicQueryAttention.query_layer", "kind": "variable", "doc": "

\n"}, "shimmer.modules.selection.DynamicQueryAttention.key_layers": {"fullname": "shimmer.modules.selection.DynamicQueryAttention.key_layers", "modulename": "shimmer.modules.selection", "qualname": "DynamicQueryAttention.key_layers", "kind": "variable", "doc": "

\n"}, "shimmer.modules.selection.DynamicQueryAttention.fuse_weighted_encodings": {"fullname": "shimmer.modules.selection.DynamicQueryAttention.fuse_weighted_encodings", "modulename": "shimmer.modules.selection", "qualname": "DynamicQueryAttention.fuse_weighted_encodings", "kind": "function", "doc": "

Fuse the weighted encodings using the attention scores.

\n\n
Arguments:
\n\n
    \n
  • encodings (LatentsDomainGroupT): Unimodal latent representation
  • \n
  • attention_dict (dict[str, torch.Tensor]): The attention scores for each\ndomain in the group.
  • \n
\n\n
Returns:
\n\n
\n

torch.Tensor: The fused tensor.

\n
\n", "signature": "(\tself,\tencodings: collections.abc.Mapping[str, torch.Tensor],\tattention_dict: dict[str, torch.Tensor]) -> torch.Tensor:", "funcdef": "def"}, "shimmer.modules.selection.DynamicQueryAttention.forward": {"fullname": "shimmer.modules.selection.DynamicQueryAttention.forward", "modulename": "shimmer.modules.selection", "qualname": "DynamicQueryAttention.forward", "kind": "function", "doc": "

Compute keys and queries, match them with dot product and softmax.\nDoes this twice, once with the static query and once with a dynamic query.

\n\n
Arguments:
\n\n
    \n
  • domains (LatentsDomainGroupT): Group of unimodal latent representations.
  • \n
  • encodings (LatentsDomainGroupT): Group of pre-fusion encodings.
  • \n
\n\n
Returns:
\n\n
\n

dict[str, torch.Tensor]: the attention scores for each domain in the\n group.

\n
\n", "signature": "(\tself,\tdomains: collections.abc.Mapping[str, torch.Tensor],\tencodings_pre_fusion: collections.abc.Mapping[str, torch.Tensor]) -> dict[str, torch.Tensor]:", "funcdef": "def"}, "shimmer.modules.losses": {"fullname": "shimmer.modules.losses", "modulename": "shimmer.modules.losses", "kind": "module", "doc": "

\n"}, "shimmer.modules.losses.GWLossesBase": {"fullname": "shimmer.modules.losses.GWLossesBase", "modulename": "shimmer.modules.losses", "qualname": "GWLossesBase", "kind": "class", "doc": "

Base Abstract Class for Global Workspace (GW) losses. This module is used\nto compute the different losses of the GW (typically translation, cycle,\ndemi-cycle, contrastive losses).

\n", "bases": "torch.nn.modules.module.Module, abc.ABC"}, "shimmer.modules.losses.GWLossesBase.step": {"fullname": "shimmer.modules.losses.GWLossesBase.step", "modulename": "shimmer.modules.losses", "qualname": "GWLossesBase.step", "kind": "function", "doc": "

Computes the losses.

\n\n
Arguments:
\n\n
    \n
  • domain_latents (LatentsDomainGroupsT): All latent groups
  • \n
  • mode (Literal[\"train\", \"val\", \"test\", \"val/ood\", \"test/ood\"]): model mode
  • \n
\n\n
Returns:
\n\n
\n

LossOutput: the losses

\n
\n", "signature": "(\tself,\tdomain_latents: collections.abc.Mapping[frozenset[str], collections.abc.Mapping[str, torch.Tensor]],\tmode: Literal['train', 'val', 'test', 'val/ood', 'test/ood']) -> shimmer.modules.domain.LossOutput:", "funcdef": "def"}, "shimmer.modules.losses.demi_cycle_loss": {"fullname": "shimmer.modules.losses.demi_cycle_loss", "modulename": "shimmer.modules.losses", "qualname": "demi_cycle_loss", "kind": "function", "doc": "

Computes the demi-cycle loss.

\n\n
This return multiple metrics:
\n\n
\n
    \n
  • demi_cycle_{domain_name} with the demi-cycle of a particular domain;
  • \n
  • demi_cycle_{domain_name}_{metric} with additional metrics provided by\n the domain_mod's compute_dcy_loss output;
  • \n
  • demi_cycles with the average value of all demi_cycle_{domain_name} values.
  • \n
\n
\n\n
Arguments:
\n\n
    \n
  • gw_mod (shimmer.modules.gw_module.GWModuleBase): The GWModule to use
  • \n
  • selection_mod (shimmer.modules.selection.SelectionBase): Selection mod to use
  • \n
  • domain_mods (Mapping[str, DomainModule]): the domain modules
  • \n
  • latent_domains (shimmer.types.LatentsDomainGroupsT): the latent unimodal\ngroups
  • \n
\n\n
Returns:
\n\n
\n

dict[str, torch.Tensor]: a dict of metrics.

\n
\n", "signature": "(\tgw_mod: shimmer.modules.gw_module.GWModuleBase,\tselection_mod: shimmer.modules.selection.SelectionBase,\tdomain_mods: collections.abc.Mapping[str, shimmer.modules.domain.DomainModule],\tlatent_domains: collections.abc.Mapping[frozenset[str], collections.abc.Mapping[str, torch.Tensor]]) -> dict[str, torch.Tensor]:", "funcdef": "def"}, "shimmer.modules.losses.cycle_loss": {"fullname": "shimmer.modules.losses.cycle_loss", "modulename": "shimmer.modules.losses", "qualname": "cycle_loss", "kind": "function", "doc": "

Computes the cycle loss.

\n\n
This return multiple metrics:
\n\n
\n
    \n
  • cycle_{domain_source}_through_{domain_target} with the cycle of\n a particular domain;
  • \n
  • cycle_{domain_source}_through_{domain_target}_{metric} with additional\n metrics provided by the domain_mod's compute_cy_loss output;
  • \n
  • cycles with the average value of all\n cycle_{domain_source}_through_{domain_target} values.
  • \n
\n
\n\n
Arguments:
\n\n
    \n
  • gw_mod (GWModuleBase): The GWModule to use
  • \n
  • selection_mod (shimmer.modules.selection.SelectionBase): Selection mod to use
  • \n
  • domain_mods (Mapping[str, DomainModule]): the domain modules
  • \n
  • latent_domains (LatentsDomainGroupsT): the latent unimodal groups
  • \n
\n\n
Returns:
\n\n
\n

dict[str, torch.Tensor]: a dict of metrics.

\n
\n", "signature": "(\tgw_mod: shimmer.modules.gw_module.GWModuleBase,\tselection_mod: shimmer.modules.selection.SelectionBase,\tdomain_mods: collections.abc.Mapping[str, shimmer.modules.domain.DomainModule],\tlatent_domains: collections.abc.Mapping[frozenset[str], collections.abc.Mapping[str, torch.Tensor]]) -> dict[str, torch.Tensor]:", "funcdef": "def"}, "shimmer.modules.losses.translation_loss": {"fullname": "shimmer.modules.losses.translation_loss", "modulename": "shimmer.modules.losses", "qualname": "translation_loss", "kind": "function", "doc": "

Computes the translation loss.

\n\n
This return multiple metrics:
\n\n
\n
    \n
  • translation_{domain_source}_to_{domain_target} with the translation\n from a domain source to a domain target;
  • \n
  • translation_{domain_source}_to_{domain_target}_{metric} with\n additional metrics provided by the domain_mod's\n compute_tr_loss output;
  • \n
  • translations with the average value of all\n translation_{domain_source}_to_{domain_target} values.
  • \n
\n
\n\n
Arguments:
\n\n
    \n
  • gw_mod (GWModuleBase): The GWModule to use
  • \n
  • domain_mods (Mapping[str, DomainModule]): the domain modules
  • \n
  • latent_domains (LatentsDomainGroupsT): the latent unimodal groups
  • \n
\n\n
Returns:
\n\n
\n

dict[str, torch.Tensor]: a dict of metrics.

\n
\n", "signature": "(\tgw_mod: shimmer.modules.gw_module.GWModuleBase,\tselection_mod: shimmer.modules.selection.SelectionBase,\tdomain_mods: collections.abc.Mapping[str, shimmer.modules.domain.DomainModule],\tlatent_domains: collections.abc.Mapping[frozenset[str], collections.abc.Mapping[str, torch.Tensor]]) -> dict[str, torch.Tensor]:", "funcdef": "def"}, "shimmer.modules.losses.contrastive_loss": {"fullname": "shimmer.modules.losses.contrastive_loss", "modulename": "shimmer.modules.losses", "qualname": "contrastive_loss", "kind": "function", "doc": "

Computes the contrastive loss.

\n\n
This return multiple metrics:
\n\n
\n
    \n
  • contrastive_{domain_1}_and_{domain_2} with the contrastive\n between 2 domains;
  • \n
  • contrastive_{domain_1}_and_{domain_2}_{metric} with\n additional metrics provided by the domain_mod's\n compute_cont_loss output;
  • \n
  • contrastives with the average value of all\n contrastive_{domain_1}_and_{domain_2} values.
  • \n
\n
\n\n
Arguments:
\n\n
    \n
  • gw_mod (GWModuleBase): The GWModule to use
  • \n
  • latent_domains (LatentsDomainGroupsT): the latent unimodal groups
  • \n
  • contrastive_fn (ContrastiveLossType): the contrastive function to apply
  • \n
\n\n
Returns:
\n\n
\n

dict[str, torch.Tensor]: a dict of metrics.

\n
\n", "signature": "(\tgw_mod: shimmer.modules.gw_module.GWModuleBase,\tlatent_domains: collections.abc.Mapping[frozenset[str], collections.abc.Mapping[str, torch.Tensor]],\tcontrastive_fn: collections.abc.Callable[[torch.Tensor, torch.Tensor], shimmer.modules.domain.LossOutput]) -> dict[str, torch.Tensor]:", "funcdef": "def"}, "shimmer.modules.losses.contrastive_loss_bayesian": {"fullname": "shimmer.modules.losses.contrastive_loss_bayesian", "modulename": "shimmer.modules.losses", "qualname": "contrastive_loss_bayesian", "kind": "function", "doc": "

Computes the contrastive loss with a Bayesian based uncertainty prediction.

\n\n
This return multiple metrics:
\n\n
\n
    \n
  • contrastive_{domain_1}_and_{domain_2} with the contrastive\n between 2 domains;
  • \n
  • contrastive_{domain_1}_and_{domain_2}_{metric} with\n additional metrics provided by the domain_mod's\n compute_cont_loss output;
  • \n
  • contrastives with the average value of all\n contrastive_{domain_1}_and_{domain_2} values.
  • \n
\n
\n\n
Arguments:
\n\n
    \n
  • gw_mod (GWModuleBayesian): The GWModule to use
  • \n
  • latent_domains (LatentsDomainGroupsT): the latent unimodal groups
  • \n
  • contrastive_fn (ContrastiveLossBayesianType): the contrastive function\nto apply
  • \n
\n\n
Returns:
\n\n
\n

dict[str, torch.Tensor]: a dict of metrics.

\n
\n", "signature": "(\tgw_mod: shimmer.modules.gw_module.GWModuleBayesian,\tlatent_domains: collections.abc.Mapping[frozenset[str], collections.abc.Mapping[str, torch.Tensor]],\tcontrastive_fn: collections.abc.Callable[[torch.Tensor, torch.Tensor], shimmer.modules.domain.LossOutput]) -> dict[str, torch.Tensor]:", "funcdef": "def"}, "shimmer.modules.losses.LossCoefs": {"fullname": "shimmer.modules.losses.LossCoefs", "modulename": "shimmer.modules.losses", "qualname": "LossCoefs", "kind": "class", "doc": "

Dict of loss coefficients used in the GWLosses.

\n\n

If one is not provided, the coefficient is assumed to be 0 and will not be logged.\nIf the loss is excplicitely set to 0, it will be logged, but not take part in\nthe total loss.

\n", "bases": "typing.TypedDict"}, "shimmer.modules.losses.LossCoefs.demi_cycles": {"fullname": "shimmer.modules.losses.LossCoefs.demi_cycles", "modulename": "shimmer.modules.losses", "qualname": "LossCoefs.demi_cycles", "kind": "variable", "doc": "

Demi-cycle loss coefficient.

\n", "annotation": ": float"}, "shimmer.modules.losses.LossCoefs.cycles": {"fullname": "shimmer.modules.losses.LossCoefs.cycles", "modulename": "shimmer.modules.losses", "qualname": "LossCoefs.cycles", "kind": "variable", "doc": "

Cycle loss coefficient.

\n", "annotation": ": float"}, "shimmer.modules.losses.LossCoefs.translations": {"fullname": "shimmer.modules.losses.LossCoefs.translations", "modulename": "shimmer.modules.losses", "qualname": "LossCoefs.translations", "kind": "variable", "doc": "

Translation loss coefficient.

\n", "annotation": ": float"}, "shimmer.modules.losses.LossCoefs.contrastives": {"fullname": "shimmer.modules.losses.LossCoefs.contrastives", "modulename": "shimmer.modules.losses", "qualname": "LossCoefs.contrastives", "kind": "variable", "doc": "

Contrastive loss coefficient.

\n", "annotation": ": float"}, "shimmer.modules.losses.GWLosses2Domains": {"fullname": "shimmer.modules.losses.GWLosses2Domains", "modulename": "shimmer.modules.losses", "qualname": "GWLosses2Domains", "kind": "class", "doc": "

Implementation of GWLossesBase used for GWModule.

\n", "bases": "GWLossesBase"}, "shimmer.modules.losses.GWLosses2Domains.__init__": {"fullname": "shimmer.modules.losses.GWLosses2Domains.__init__", "modulename": "shimmer.modules.losses", "qualname": "GWLosses2Domains.__init__", "kind": "function", "doc": "

Main loss module to use with the GlobalWorkspace

\n\n
Arguments:
\n\n
    \n
  • gw_mod (GWModule): the GWModule
  • \n
  • selection_mod (SelectionBase): selection module
  • \n
  • domain_mods (dict[str, DomainModule]): a dict where the key is the\ndomain name and value is the DomainModule
  • \n
  • loss_coefs (LossCoefs): loss coefficients. LossCoefs object, or a\nmapping to float with correct keys.
  • \n
  • contrastive_fn (ContrastiveLossType): the contrastive function to use\nin contrastive loss
  • \n
\n", "signature": "(\tgw_mod: shimmer.modules.gw_module.GWModule,\tselection_mod: shimmer.modules.selection.SelectionBase,\tdomain_mods: dict[str, shimmer.modules.domain.DomainModule],\tloss_coefs: shimmer.modules.losses.LossCoefs,\tcontrastive_fn: collections.abc.Callable[[torch.Tensor, torch.Tensor], shimmer.modules.domain.LossOutput])"}, "shimmer.modules.losses.GWLosses2Domains.gw_mod": {"fullname": "shimmer.modules.losses.GWLosses2Domains.gw_mod", "modulename": "shimmer.modules.losses", "qualname": "GWLosses2Domains.gw_mod", "kind": "variable", "doc": "

\n"}, "shimmer.modules.losses.GWLosses2Domains.selection_mod": {"fullname": "shimmer.modules.losses.GWLosses2Domains.selection_mod", "modulename": "shimmer.modules.losses", "qualname": "GWLosses2Domains.selection_mod", "kind": "variable", "doc": "

\n"}, "shimmer.modules.losses.GWLosses2Domains.domain_mods": {"fullname": "shimmer.modules.losses.GWLosses2Domains.domain_mods", "modulename": "shimmer.modules.losses", "qualname": "GWLosses2Domains.domain_mods", "kind": "variable", "doc": "

\n"}, "shimmer.modules.losses.GWLosses2Domains.loss_coefs": {"fullname": "shimmer.modules.losses.GWLosses2Domains.loss_coefs", "modulename": "shimmer.modules.losses", "qualname": "GWLosses2Domains.loss_coefs", "kind": "variable", "doc": "

\n"}, "shimmer.modules.losses.GWLosses2Domains.contrastive_fn": {"fullname": "shimmer.modules.losses.GWLosses2Domains.contrastive_fn", "modulename": "shimmer.modules.losses", "qualname": "GWLosses2Domains.contrastive_fn", "kind": "variable", "doc": "

\n"}, "shimmer.modules.losses.GWLosses2Domains.demi_cycle_loss": {"fullname": "shimmer.modules.losses.GWLosses2Domains.demi_cycle_loss", "modulename": "shimmer.modules.losses", "qualname": "GWLosses2Domains.demi_cycle_loss", "kind": "function", "doc": "

Computes the demi-cycle loss.

\n\n

See shimmer.modules.losses.demi_cycle_loss.

\n\n
Arguments:
\n\n
    \n
  • latent_domains (LatentsDomainGroupsT): the latent unimodal groups
  • \n
\n\n
Returns:
\n\n
\n

dict[str, torch.Tensor]: a dict of metrics.

\n
\n", "signature": "(\tself,\tlatent_domains: collections.abc.Mapping[frozenset[str], collections.abc.Mapping[str, torch.Tensor]]) -> dict[str, torch.Tensor]:", "funcdef": "def"}, "shimmer.modules.losses.GWLosses2Domains.cycle_loss": {"fullname": "shimmer.modules.losses.GWLosses2Domains.cycle_loss", "modulename": "shimmer.modules.losses", "qualname": "GWLosses2Domains.cycle_loss", "kind": "function", "doc": "

Computes the cycle loss.

\n\n

See shimmer.modules.losses.cycle_loss.

\n\n
Arguments:
\n\n
    \n
  • latent_domains (LatentsDomainGroupsT): the latent unimodal groups
  • \n
\n\n
Returns:
\n\n
\n

dict[str, torch.Tensor]: a dict of metrics.

\n
\n", "signature": "(\tself,\tlatent_domains: collections.abc.Mapping[frozenset[str], collections.abc.Mapping[str, torch.Tensor]]) -> dict[str, torch.Tensor]:", "funcdef": "def"}, "shimmer.modules.losses.GWLosses2Domains.translation_loss": {"fullname": "shimmer.modules.losses.GWLosses2Domains.translation_loss", "modulename": "shimmer.modules.losses", "qualname": "GWLosses2Domains.translation_loss", "kind": "function", "doc": "

Computes the translation loss.

\n\n

See shimmer.modules.losses.translation_loss.

\n\n
Arguments:
\n\n
    \n
  • latent_domains (LatentsDomainGroupsT): the latent unimodal groups
  • \n
\n\n
Returns:
\n\n
\n

dict[str, torch.Tensor]: a dict of metrics.

\n
\n", "signature": "(\tself,\tlatent_domains: collections.abc.Mapping[frozenset[str], collections.abc.Mapping[str, torch.Tensor]]) -> dict[str, torch.Tensor]:", "funcdef": "def"}, "shimmer.modules.losses.GWLosses2Domains.contrastive_loss": {"fullname": "shimmer.modules.losses.GWLosses2Domains.contrastive_loss", "modulename": "shimmer.modules.losses", "qualname": "GWLosses2Domains.contrastive_loss", "kind": "function", "doc": "

Computes the contrastive loss.

\n\n

See shimmer.modules.losses.contrastive_loss.

\n\n
Arguments:
\n\n
    \n
  • latent_domains (LatentsDomainGroupsT): the latent unimodal groups
  • \n
\n\n
Returns:
\n\n
\n

dict[str, torch.Tensor]: a dict of metrics.

\n
\n", "signature": "(\tself,\tlatent_domains: collections.abc.Mapping[frozenset[str], collections.abc.Mapping[str, torch.Tensor]]) -> dict[str, torch.Tensor]:", "funcdef": "def"}, "shimmer.modules.losses.GWLosses2Domains.step": {"fullname": "shimmer.modules.losses.GWLosses2Domains.step", "modulename": "shimmer.modules.losses", "qualname": "GWLosses2Domains.step", "kind": "function", "doc": "

Computes and returns the losses

\n\n
Contains:
\n\n
\n
    \n
  • Demi-cycle metrics (see GWLosses.demi_cycle_loss)
  • \n
  • Cycle metrics (see GWLosses.cycle_loss)
  • \n
  • Translation metrics (see GWLosses.translation_loss)
  • \n
  • Contrastive metrics (see GWLosses.contrastive_loss)
  • \n
\n
\n\n
Arguments:
\n\n
    \n
  • domain_latents (LatentsDomainGroupsT): All latent groups
  • \n
  • mode (ModelModeT): model mode
  • \n
\n\n
Returns:
\n\n
\n

LossOutput: the losses

\n
\n", "signature": "(\tself,\tdomain_latents: collections.abc.Mapping[frozenset[str], collections.abc.Mapping[str, torch.Tensor]],\tmode: Literal['train', 'val', 'test', 'val/ood', 'test/ood']) -> shimmer.modules.domain.LossOutput:", "funcdef": "def"}, "shimmer.modules.losses.generate_partitions": {"fullname": "shimmer.modules.losses.generate_partitions", "modulename": "shimmer.modules.losses", "qualname": "generate_partitions", "kind": "function", "doc": "

Generates all possible partitions of zeros and ones for n elements,\nexcluding the all-zeros partition.

\n\n
Arguments:
\n\n
    \n
  • n (int): The number of modalities to generate partitions for.
  • \n
\n\n
Yields:
\n\n
\n

tuple[int, ...]: A partition of zeros and ones, excluding the\n all-zeros partition.

\n
\n", "signature": "(n: int) -> collections.abc.Generator[tuple[int, ...], None, None]:", "funcdef": "def"}, "shimmer.modules.losses.broadcast_loss": {"fullname": "shimmer.modules.losses.broadcast_loss", "modulename": "shimmer.modules.losses", "qualname": "broadcast_loss", "kind": "function", "doc": "

Computes broadcast loss including demi-cycle, cycle, and translation losses.

\n\n
Arguments:
\n\n
    \n
  • gw_mod (shimmer.modules.gw_module.GWModuleBase): The GWModule to use
  • \n
  • selection_mod (shimmer.modules.selection.SelectionBase): Selection mod to use
  • \n
  • domain_mods (Mapping[str, DomainModule]): the domain modules
  • \n
  • latent_domains: The latent domain representations.
  • \n
\n\n
Returns:
\n\n
\n

A dictionary with the total loss and additional metrics.

\n
\n", "signature": "(\tgw_mod: shimmer.modules.gw_module.GWModuleBase,\tselection_mod: shimmer.modules.selection.SelectionBase,\tdomain_mods: collections.abc.Mapping[str, shimmer.modules.domain.DomainModule],\tlatent_domains: collections.abc.Mapping[frozenset[str], collections.abc.Mapping[str, torch.Tensor]]) -> dict[str, torch.Tensor]:", "funcdef": "def"}, "shimmer.modules.losses.BroadcastLossCoefs": {"fullname": "shimmer.modules.losses.BroadcastLossCoefs", "modulename": "shimmer.modules.losses", "qualname": "BroadcastLossCoefs", "kind": "class", "doc": "

Dict of loss coefficients used in the GWLossesFusion.

\n\n

If one is not provided, the coefficient is assumed to be 0 and will not be logged.\nIf the loss is explicitly set to 0, it will be logged, but not take part in\nthe total loss.

\n", "bases": "typing.TypedDict"}, "shimmer.modules.losses.BroadcastLossCoefs.contrastives": {"fullname": "shimmer.modules.losses.BroadcastLossCoefs.contrastives", "modulename": "shimmer.modules.losses", "qualname": "BroadcastLossCoefs.contrastives", "kind": "variable", "doc": "

Contrastive loss coefficient.

\n", "annotation": ": float"}, "shimmer.modules.losses.BroadcastLossCoefs.fused": {"fullname": "shimmer.modules.losses.BroadcastLossCoefs.fused", "modulename": "shimmer.modules.losses", "qualname": "BroadcastLossCoefs.fused", "kind": "variable", "doc": "

fused loss coefficient (encode multiple domains and decode to one of them).

\n", "annotation": ": float"}, "shimmer.modules.losses.BroadcastLossCoefs.demi_cycles": {"fullname": "shimmer.modules.losses.BroadcastLossCoefs.demi_cycles", "modulename": "shimmer.modules.losses", "qualname": "BroadcastLossCoefs.demi_cycles", "kind": "variable", "doc": "

demi_cycles loss coefficient. Demi-cycles are always one-to-one

\n", "annotation": ": float"}, "shimmer.modules.losses.BroadcastLossCoefs.cycles": {"fullname": "shimmer.modules.losses.BroadcastLossCoefs.cycles", "modulename": "shimmer.modules.losses", "qualname": "BroadcastLossCoefs.cycles", "kind": "variable", "doc": "

cycles loss coefficient. Cycles can be many-to-one

\n", "annotation": ": float"}, "shimmer.modules.losses.BroadcastLossCoefs.translations": {"fullname": "shimmer.modules.losses.BroadcastLossCoefs.translations", "modulename": "shimmer.modules.losses", "qualname": "BroadcastLossCoefs.translations", "kind": "variable", "doc": "

translation loss coefficient. Translation, like cycles, can be many-to-one.

\n", "annotation": ": float"}, "shimmer.modules.losses.GWLosses": {"fullname": "shimmer.modules.losses.GWLosses", "modulename": "shimmer.modules.losses", "qualname": "GWLosses", "kind": "class", "doc": "

Implementation of GWLossesBase for fusion-based models.

\n", "bases": "GWLossesBase"}, "shimmer.modules.losses.GWLosses.__init__": {"fullname": "shimmer.modules.losses.GWLosses.__init__", "modulename": "shimmer.modules.losses", "qualname": "GWLosses.__init__", "kind": "function", "doc": "

Initializes the loss computation module for a Global Workspace Fusion model.

\n\n
Arguments:
\n\n
    \n
  • gw_mod: The GWModule for the global workspace.
  • \n
  • selection_mod: The selection mechanism for the model.
  • \n
  • domain_mods: A mapping of domain names to their respective DomainModule.
  • \n
  • loss_coefs (BroadcastLossCoefs): coefs for the losses
  • \n
  • contrastive_fn: The function used for computing contrastive loss.
  • \n
\n", "signature": "(\tgw_mod: shimmer.modules.gw_module.GWModule,\tselection_mod: shimmer.modules.selection.SelectionBase,\tdomain_mods: dict[str, shimmer.modules.domain.DomainModule],\tloss_coefs: shimmer.modules.losses.BroadcastLossCoefs,\tcontrastive_fn: collections.abc.Callable[[torch.Tensor, torch.Tensor], shimmer.modules.domain.LossOutput])"}, "shimmer.modules.losses.GWLosses.gw_mod": {"fullname": "shimmer.modules.losses.GWLosses.gw_mod", "modulename": "shimmer.modules.losses", "qualname": "GWLosses.gw_mod", "kind": "variable", "doc": "

\n"}, "shimmer.modules.losses.GWLosses.selection_mod": {"fullname": "shimmer.modules.losses.GWLosses.selection_mod", "modulename": "shimmer.modules.losses", "qualname": "GWLosses.selection_mod", "kind": "variable", "doc": "

\n"}, "shimmer.modules.losses.GWLosses.domain_mods": {"fullname": "shimmer.modules.losses.GWLosses.domain_mods", "modulename": "shimmer.modules.losses", "qualname": "GWLosses.domain_mods", "kind": "variable", "doc": "

\n"}, "shimmer.modules.losses.GWLosses.loss_coefs": {"fullname": "shimmer.modules.losses.GWLosses.loss_coefs", "modulename": "shimmer.modules.losses", "qualname": "GWLosses.loss_coefs", "kind": "variable", "doc": "

\n"}, "shimmer.modules.losses.GWLosses.contrastive_fn": {"fullname": "shimmer.modules.losses.GWLosses.contrastive_fn", "modulename": "shimmer.modules.losses", "qualname": "GWLosses.contrastive_fn", "kind": "variable", "doc": "

\n"}, "shimmer.modules.losses.GWLosses.contrastive_loss": {"fullname": "shimmer.modules.losses.GWLosses.contrastive_loss", "modulename": "shimmer.modules.losses", "qualname": "GWLosses.contrastive_loss", "kind": "function", "doc": "

Computes the contrastive loss for the given latent domains.

\n\n
Arguments:
\n\n
    \n
  • latent_domains: The latent domain representations.
  • \n
\n\n
Returns:
\n\n
\n

A dictionary of contrastive loss metrics.

\n
\n", "signature": "(\tself,\tlatent_domains: collections.abc.Mapping[frozenset[str], collections.abc.Mapping[str, torch.Tensor]]) -> dict[str, torch.Tensor]:", "funcdef": "def"}, "shimmer.modules.losses.GWLosses.broadcast_loss": {"fullname": "shimmer.modules.losses.GWLosses.broadcast_loss", "modulename": "shimmer.modules.losses", "qualname": "GWLosses.broadcast_loss", "kind": "function", "doc": "

\n", "signature": "(\tself,\tlatent_domains: collections.abc.Mapping[frozenset[str], collections.abc.Mapping[str, torch.Tensor]]) -> dict[str, torch.Tensor]:", "funcdef": "def"}, "shimmer.modules.losses.GWLosses.step": {"fullname": "shimmer.modules.losses.GWLosses.step", "modulename": "shimmer.modules.losses", "qualname": "GWLosses.step", "kind": "function", "doc": "

Performs a step of loss computation.

\n\n
Arguments:
\n\n
    \n
  • domain_latents: Latent representations for all domains.
  • \n
  • mode: The mode in which the model is currently operating.
  • \n
\n\n
Returns:
\n\n
\n

A LossOutput object containing the loss and metrics for this step.

\n
\n", "signature": "(\tself,\tdomain_latents: collections.abc.Mapping[frozenset[str], collections.abc.Mapping[str, torch.Tensor]],\tmode: Literal['train', 'val', 'test', 'val/ood', 'test/ood']) -> shimmer.modules.domain.LossOutput:", "funcdef": "def"}, "shimmer.modules.losses.GWLossesBayesian": {"fullname": "shimmer.modules.losses.GWLossesBayesian", "modulename": "shimmer.modules.losses", "qualname": "GWLossesBayesian", "kind": "class", "doc": "

Implementation of GWLossesBase used for GWModuleBayesian.

\n", "bases": "GWLossesBase"}, "shimmer.modules.losses.GWLossesBayesian.__init__": {"fullname": "shimmer.modules.losses.GWLossesBayesian.__init__", "modulename": "shimmer.modules.losses", "qualname": "GWLossesBayesian.__init__", "kind": "function", "doc": "

Loss module with uncertainty prediction to use with the GlobalWorkspaceBayesian

\n\n
Arguments:
\n\n
    \n
  • gw_mod (GWModuleBayesian): the GWModule
  • \n
  • selection_mod (SelectionBase): selection module
  • \n
  • domain_mods (dict[str, DomainModule]): a dict where the key is the\ndomain name and value is the DomainModule
  • \n
  • loss_coefs (BroadcastLossCoefs): loss coefficients
  • \n
  • contrastive_fn (ContrastiveLossType): the contrastive function\nto use in contrastive loss
  • \n
  • use_normalized_constrastive (bool): whether to use the normalized cont\nloss by the precision coefs
  • \n
\n", "signature": "(\tgw_mod: shimmer.modules.gw_module.GWModuleBayesian,\tselection_mod: shimmer.modules.selection.SelectionBase,\tdomain_mods: dict[str, shimmer.modules.domain.DomainModule],\tloss_coefs: shimmer.modules.losses.BroadcastLossCoefs,\tcontrastive_fn: collections.abc.Callable[[torch.Tensor, torch.Tensor], shimmer.modules.domain.LossOutput],\tuse_normalized_constrastive: bool = True)"}, "shimmer.modules.losses.GWLossesBayesian.gw_mod": {"fullname": "shimmer.modules.losses.GWLossesBayesian.gw_mod", "modulename": "shimmer.modules.losses", "qualname": "GWLossesBayesian.gw_mod", "kind": "variable", "doc": "

The GWModule.

\n"}, "shimmer.modules.losses.GWLossesBayesian.selection_mod": {"fullname": "shimmer.modules.losses.GWLossesBayesian.selection_mod", "modulename": "shimmer.modules.losses", "qualname": "GWLossesBayesian.selection_mod", "kind": "variable", "doc": "

Selection module

\n"}, "shimmer.modules.losses.GWLossesBayesian.domain_mods": {"fullname": "shimmer.modules.losses.GWLossesBayesian.domain_mods", "modulename": "shimmer.modules.losses", "qualname": "GWLossesBayesian.domain_mods", "kind": "variable", "doc": "

Domain modules linked to the GW.

\n"}, "shimmer.modules.losses.GWLossesBayesian.loss_coefs": {"fullname": "shimmer.modules.losses.GWLossesBayesian.loss_coefs", "modulename": "shimmer.modules.losses", "qualname": "GWLossesBayesian.loss_coefs", "kind": "variable", "doc": "

The loss coefficients.

\n"}, "shimmer.modules.losses.GWLossesBayesian.contrastive_fn": {"fullname": "shimmer.modules.losses.GWLossesBayesian.contrastive_fn", "modulename": "shimmer.modules.losses", "qualname": "GWLossesBayesian.contrastive_fn", "kind": "variable", "doc": "

Contrastive loss to use.

\n"}, "shimmer.modules.losses.GWLossesBayesian.use_normalized_constrastive": {"fullname": "shimmer.modules.losses.GWLossesBayesian.use_normalized_constrastive", "modulename": "shimmer.modules.losses", "qualname": "GWLossesBayesian.use_normalized_constrastive", "kind": "variable", "doc": "

\n"}, "shimmer.modules.losses.GWLossesBayesian.contrastive_loss": {"fullname": "shimmer.modules.losses.GWLossesBayesian.contrastive_loss", "modulename": "shimmer.modules.losses", "qualname": "GWLossesBayesian.contrastive_loss", "kind": "function", "doc": "

Contrastive loss.

\n\n
Arguments:
\n\n
    \n
  • latent_domains (LatentsDomainGroupsT): the latent unimodal groups
  • \n
\n\n
Returns:
\n\n
\n

dict[str, torch.Tensor]: a dict of metrics.

\n
\n", "signature": "(\tself,\tlatent_domains: collections.abc.Mapping[frozenset[str], collections.abc.Mapping[str, torch.Tensor]]) -> dict[str, torch.Tensor]:", "funcdef": "def"}, "shimmer.modules.losses.GWLossesBayesian.broadcast_loss": {"fullname": "shimmer.modules.losses.GWLossesBayesian.broadcast_loss", "modulename": "shimmer.modules.losses", "qualname": "GWLossesBayesian.broadcast_loss", "kind": "function", "doc": "

\n", "signature": "(\tself,\tlatent_domains: collections.abc.Mapping[frozenset[str], collections.abc.Mapping[str, torch.Tensor]]) -> dict[str, torch.Tensor]:", "funcdef": "def"}, "shimmer.modules.losses.GWLossesBayesian.step": {"fullname": "shimmer.modules.losses.GWLossesBayesian.step", "modulename": "shimmer.modules.losses", "qualname": "GWLossesBayesian.step", "kind": "function", "doc": "

Performs a step of loss computation.

\n\n
Arguments:
\n\n
    \n
  • domain_latents: Latent representations for all domains.
  • \n
  • mode: The mode in which the model is currently operating.
  • \n
\n\n
Returns:
\n\n
\n

A LossOutput object containing the loss and metrics for this step.

\n
\n", "signature": "(\tself,\tdomain_latents: collections.abc.Mapping[frozenset[str], collections.abc.Mapping[str, torch.Tensor]],\tmode: Literal['train', 'val', 'test', 'val/ood', 'test/ood']) -> shimmer.modules.domain.LossOutput:", "funcdef": "def"}, "shimmer.modules.contrastive_loss": {"fullname": "shimmer.modules.contrastive_loss", "modulename": "shimmer.modules.contrastive_loss", "kind": "module", "doc": "

Various contrastive loss definitions

\n"}, "shimmer.modules.contrastive_loss.ContrastiveLossType": {"fullname": "shimmer.modules.contrastive_loss.ContrastiveLossType", "modulename": "shimmer.modules.contrastive_loss", "qualname": "ContrastiveLossType", "kind": "variable", "doc": "

Contrastive loss function type.

\n\n

A function taking the prediction and targets and returning a LossOutput.

\n", "default_value": "collections.abc.Callable[[torch.Tensor, torch.Tensor], shimmer.modules.domain.LossOutput]"}, "shimmer.modules.contrastive_loss.ContrastiveLossBayesianType": {"fullname": "shimmer.modules.contrastive_loss.ContrastiveLossBayesianType", "modulename": "shimmer.modules.contrastive_loss", "qualname": "ContrastiveLossBayesianType", "kind": "variable", "doc": "

Contrastive loss function type for GlobalWorkspaceBayesian.

\n\n

A function taking the prediction mean, prediction std, target mean and target std and\n returns a LossOutput.

\n", "default_value": "collections.abc.Callable[[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor], shimmer.modules.domain.LossOutput]"}, "shimmer.modules.contrastive_loss.info_nce": {"fullname": "shimmer.modules.contrastive_loss.info_nce", "modulename": "shimmer.modules.contrastive_loss", "qualname": "info_nce", "kind": "function", "doc": "

InfoNCE loss

\n\n
Arguments:
\n\n
    \n
  • x (torch.Tensor): prediction
  • \n
  • y (torch.Tensor): target
  • \n
  • logit_scale (torch.Tensor): logit scale
  • \n
  • reduction (Literal[\"mean\", \"sum\", \"none\"]): reduction to apply
  • \n
\n\n

Returns: the InfoNCE loss

\n", "signature": "(\tx: torch.Tensor,\ty: torch.Tensor,\tlogit_scale: torch.Tensor,\treduction: Literal['mean', 'sum', 'none'] = 'mean') -> torch.Tensor:", "funcdef": "def"}, "shimmer.modules.contrastive_loss.contrastive_loss": {"fullname": "shimmer.modules.contrastive_loss.contrastive_loss", "modulename": "shimmer.modules.contrastive_loss", "qualname": "contrastive_loss", "kind": "function", "doc": "

CLIP-like contrastive loss

\n\n
Arguments:
\n\n
    \n
  • x (torch.Tensor): prediction
  • \n
  • y (torch.Tensor): target
  • \n
  • logit_scale (torch.Tensor): logit scale
  • \n
  • reduction (Literal[\"mean\", \"sum\", \"none\"]): reduction to apply
  • \n
\n\n

Returns: the contrastive loss

\n", "signature": "(\tx: torch.Tensor,\ty: torch.Tensor,\tlogit_scale: torch.Tensor,\treduction: Literal['mean', 'sum', 'none'] = 'mean') -> torch.Tensor:", "funcdef": "def"}, "shimmer.modules.contrastive_loss.ContrastiveLoss": {"fullname": "shimmer.modules.contrastive_loss.ContrastiveLoss", "modulename": "shimmer.modules.contrastive_loss", "qualname": "ContrastiveLoss", "kind": "class", "doc": "

CLIP-like ContrastiveLoss torch module.

\n", "bases": "torch.nn.modules.module.Module"}, "shimmer.modules.contrastive_loss.ContrastiveLoss.__init__": {"fullname": "shimmer.modules.contrastive_loss.ContrastiveLoss.__init__", "modulename": "shimmer.modules.contrastive_loss", "qualname": "ContrastiveLoss.__init__", "kind": "function", "doc": "

Initializes a contrastive loss.

\n\n
Arguments:
\n\n
    \n
  • logit_scale (torch.Tensor): logit_scale tensor.
  • \n
  • reduction (Literal[\"mean\", \"sum\", \"none\"]): reduction to apply to the\nloss. Defaults to \"mean\".
  • \n
  • learn_logit_scale (torch.Tensor): whether to learn the logit_scale\nparameter. Defaults to False.
  • \n
\n", "signature": "(\tlogit_scale: torch.Tensor,\treduction: Literal['mean', 'sum', 'none'] = 'mean',\tlearn_logit_scale: bool = False)"}, "shimmer.modules.contrastive_loss.ContrastiveLoss.learn_logit_scale": {"fullname": "shimmer.modules.contrastive_loss.ContrastiveLoss.learn_logit_scale", "modulename": "shimmer.modules.contrastive_loss", "qualname": "ContrastiveLoss.learn_logit_scale", "kind": "variable", "doc": "

\n"}, "shimmer.modules.contrastive_loss.ContrastiveLoss.reduction": {"fullname": "shimmer.modules.contrastive_loss.ContrastiveLoss.reduction", "modulename": "shimmer.modules.contrastive_loss", "qualname": "ContrastiveLoss.reduction", "kind": "variable", "doc": "

\n", "annotation": ": Literal['mean', 'sum', 'none']"}, "shimmer.modules.contrastive_loss.ContrastiveLoss.forward": {"fullname": "shimmer.modules.contrastive_loss.ContrastiveLoss.forward", "modulename": "shimmer.modules.contrastive_loss", "qualname": "ContrastiveLoss.forward", "kind": "function", "doc": "

Computes the loss.

\n\n
Arguments:
\n\n
    \n
  • x (torch.Tensor): prediction
  • \n
  • y (torch.Tensor): target
  • \n
\n\n
Returns:
\n\n
\n

LossOutput of the loss. Contains a logit_scale metric.

\n
\n", "signature": "(\tself,\tx: torch.Tensor,\ty: torch.Tensor) -> shimmer.modules.domain.LossOutput:", "funcdef": "def"}, "shimmer.dataset": {"fullname": "shimmer.dataset", "modulename": "shimmer.dataset", "kind": "module", "doc": "

\n"}, "shimmer.dataset.RepeatedDataset": {"fullname": "shimmer.dataset.RepeatedDataset", "modulename": "shimmer.dataset", "qualname": "RepeatedDataset", "kind": "class", "doc": "

Dataset that cycles through its items to have a size of at least min size.\nIf drop_last is True, the size will be exactly min_size. If drop_last is False,\nthen min_size \u2264 size < min_size + len(dataset).

\n", "bases": "typing.Generic[+T_co]"}, "shimmer.dataset.RepeatedDataset.__init__": {"fullname": "shimmer.dataset.RepeatedDataset.__init__", "modulename": "shimmer.dataset", "qualname": "RepeatedDataset.__init__", "kind": "function", "doc": "
Arguments:
\n\n
    \n
  • dataset (SizedDataset): dataset to repeat. The dataset should have a size\n(where __len__ is defined).
  • \n
  • min_size (int): minimum size of the final dataset
  • \n
  • drop_last (bool): whether to remove overflow when repeating the\ndataset.
  • \n
\n", "signature": "(\tdataset: shimmer.dataset._SizedDataset,\tmin_size: int,\tdrop_last: bool = False)"}, "shimmer.dataset.RepeatedDataset.dataset": {"fullname": "shimmer.dataset.RepeatedDataset.dataset", "modulename": "shimmer.dataset", "qualname": "RepeatedDataset.dataset", "kind": "variable", "doc": "

\n"}, "shimmer.dataset.RepeatedDataset.dataset_size": {"fullname": "shimmer.dataset.RepeatedDataset.dataset_size", "modulename": "shimmer.dataset", "qualname": "RepeatedDataset.dataset_size", "kind": "variable", "doc": "

\n"}, "shimmer.modules.vae": {"fullname": "shimmer.modules.vae", "modulename": "shimmer.modules.vae", "kind": "module", "doc": "

\n"}, "shimmer.modules.vae.reparameterize": {"fullname": "shimmer.modules.vae.reparameterize", "modulename": "shimmer.modules.vae", "qualname": "reparameterize", "kind": "function", "doc": "

Reparameterization trick for VAE

\n\n
Arguments:
\n\n
    \n
  • mean (torch.Tensor): predicted means
  • \n
  • logvar (torch.Tensor): predicted log variance
  • \n
\n\n
Returns:
\n\n
\n

torch.Tensor: a sample from normal distribution with provided\n parameters, sampled using the reparameterization trick.

\n
\n", "signature": "(mean: torch.Tensor, logvar: torch.Tensor) -> torch.Tensor:", "funcdef": "def"}, "shimmer.modules.vae.kl_divergence_loss": {"fullname": "shimmer.modules.vae.kl_divergence_loss", "modulename": "shimmer.modules.vae", "qualname": "kl_divergence_loss", "kind": "function", "doc": "

Computes the KL divergence loss used in VAE.

\n\n
Arguments:
\n\n
    \n
  • mean (torch.Tensor): predicted means
  • \n
  • logvar (torch.Tensor): predicted logvars
  • \n
\n\n
Returns:
\n\n
\n

torch.Tensor: the loss

\n
\n", "signature": "(mean: torch.Tensor, logvar: torch.Tensor) -> torch.Tensor:", "funcdef": "def"}, "shimmer.modules.vae.gaussian_nll": {"fullname": "shimmer.modules.vae.gaussian_nll", "modulename": "shimmer.modules.vae", "qualname": "gaussian_nll", "kind": "function", "doc": "

Computes gaussian nll loss used in VAE.

\n\n
Arguments:
\n\n
    \n
  • mu (torch.Tensor): predictions
  • \n
  • log_sigma (torch.Tensor): log sigma
  • \n
  • x (torch.Tensor): targets
  • \n
\n\n
Returns:
\n\n
\n

torch.Tensor: the Gaussian NLL loss

\n
\n", "signature": "(\tmu: torch.Tensor,\tlog_sigma: torch.Tensor,\tx: torch.Tensor) -> torch.Tensor:", "funcdef": "def"}, "shimmer.modules.vae.VAEEncoder": {"fullname": "shimmer.modules.vae.VAEEncoder", "modulename": "shimmer.modules.vae", "qualname": "VAEEncoder", "kind": "class", "doc": "

Base class for a VAE encoder.

\n", "bases": "torch.nn.modules.module.Module, abc.ABC"}, "shimmer.modules.vae.VAEEncoder.forward": {"fullname": "shimmer.modules.vae.VAEEncoder.forward", "modulename": "shimmer.modules.vae", "qualname": "VAEEncoder.forward", "kind": "function", "doc": "

Encode representation with VAE.

\n\n
Arguments:
\n\n
    \n
  • x (Any): Some input value
  • \n
\n\n
Returns:
\n\n
\n

tuple[torch.Tensor, torch.Tensor]: the mean and log variance

\n
\n", "signature": "(self, x: Any) -> tuple[torch.Tensor, torch.Tensor]:", "funcdef": "def"}, "shimmer.modules.vae.VAEDecoder": {"fullname": "shimmer.modules.vae.VAEDecoder", "modulename": "shimmer.modules.vae", "qualname": "VAEDecoder", "kind": "class", "doc": "

Base class for a VAE decoder.

\n", "bases": "torch.nn.modules.module.Module, abc.ABC"}, "shimmer.modules.vae.VAEDecoder.forward": {"fullname": "shimmer.modules.vae.VAEDecoder.forward", "modulename": "shimmer.modules.vae", "qualname": "VAEDecoder.forward", "kind": "function", "doc": "

Decode representation with VAE

\n\n
Arguments:
\n\n
    \n
  • x (torch.Tensor): VAE latent representation
  • \n
\n\n
Returns:
\n\n
\n

Any: the reconstructed input

\n
\n", "signature": "(self, x: torch.Tensor) -> Any:", "funcdef": "def"}, "shimmer.modules.vae.VAE": {"fullname": "shimmer.modules.vae.VAE", "modulename": "shimmer.modules.vae", "qualname": "VAE", "kind": "class", "doc": "

VAE module

\n", "bases": "torch.nn.modules.module.Module"}, "shimmer.modules.vae.VAE.__init__": {"fullname": "shimmer.modules.vae.VAE.__init__", "modulename": "shimmer.modules.vae", "qualname": "VAE.__init__", "kind": "function", "doc": "

Initializes a VAE.

\n\n
Arguments:
\n\n
    \n
  • encoder (VAEEncoder): VAE encoder
  • \n
  • decoder (VAEDecoder): VAE decoder
  • \n
  • beta (float): beta value for Beta-VAE. Defaults to 1.
  • \n
\n", "signature": "(\tencoder: shimmer.modules.vae.VAEEncoder,\tdecoder: shimmer.modules.vae.VAEDecoder,\tbeta: float = 1)"}, "shimmer.modules.vae.VAE.beta": {"fullname": "shimmer.modules.vae.VAE.beta", "modulename": "shimmer.modules.vae", "qualname": "VAE.beta", "kind": "variable", "doc": "

Beta value for Beta-VAEs

\n"}, "shimmer.modules.vae.VAE.encoder": {"fullname": "shimmer.modules.vae.VAE.encoder", "modulename": "shimmer.modules.vae", "qualname": "VAE.encoder", "kind": "variable", "doc": "

The encoder

\n"}, "shimmer.modules.vae.VAE.decoder": {"fullname": "shimmer.modules.vae.VAE.decoder", "modulename": "shimmer.modules.vae", "qualname": "VAE.decoder", "kind": "variable", "doc": "

The decoder

\n"}, "shimmer.modules.vae.VAE.encode": {"fullname": "shimmer.modules.vae.VAE.encode", "modulename": "shimmer.modules.vae", "qualname": "VAE.encode", "kind": "function", "doc": "

Encode the representation and returns the mean prediction of VAE.

\n\n
Arguments:
\n\n
    \n
  • x (Any): Some input value
  • \n
\n\n
Returns:
\n\n
\n

torch.Tensor: The mean representation.

\n
\n", "signature": "(self, x: Any) -> torch.Tensor:", "funcdef": "def"}, "shimmer.modules.vae.VAE.decode": {"fullname": "shimmer.modules.vae.VAE.decode", "modulename": "shimmer.modules.vae", "qualname": "VAE.decode", "kind": "function", "doc": "

Decode the VAE latent representation into input value.

\n\n
Arguments:
\n\n
    \n
  • z (torch.Tensor): the VAE latent representation.
  • \n
\n\n
Returns:
\n\n
\n

Any: the reconstructed input.

\n
\n", "signature": "(self, z: torch.Tensor) -> Any:", "funcdef": "def"}, "shimmer.modules.vae.VAE.forward": {"fullname": "shimmer.modules.vae.VAE.forward", "modulename": "shimmer.modules.vae", "qualname": "VAE.forward", "kind": "function", "doc": "

Encode and decodes from x.

\n\n
Arguments:
\n\n
    \n
  • x (Any): the input data
  • \n
\n\n
Returns:
\n\n
\n

tuple[tuple[torch.Tensor, torch.Tensor], Any]: The\n first tuple contains the mean and logvar of the encoded input,\n the second item is the reconstructed input.

\n
\n", "signature": "(self, x: Any) -> tuple[tuple[torch.Tensor, torch.Tensor], typing.Any]:", "funcdef": "def"}, "shimmer.modules.utils": {"fullname": "shimmer.modules.utils", "modulename": "shimmer.modules.utils", "kind": "module", "doc": "

\n"}, "shimmer.modules.utils.translation": {"fullname": "shimmer.modules.utils.translation", "modulename": "shimmer.modules.utils", "qualname": "translation", "kind": "function", "doc": "

Translate from multiple domains to one domain.

\n\n
Arguments:
\n\n
    \n
  • gw_module (GWModuleBase): GWModule to perform the translation over
  • \n
  • selection_mod (SelectionBase): selection module
  • \n
  • x (LatentsDomainGroupT): the group of latent representations
  • \n
  • to (str): the domain name to encode to
  • \n
\n\n
Returns:
\n\n
\n

torch.Tensor: the translated unimodal representation\n of the provided domain.

\n
\n", "signature": "(\tgw_module: shimmer.modules.gw_module.GWModuleBase,\tselection_mod: shimmer.modules.selection.SelectionBase,\tx: collections.abc.Mapping[str, torch.Tensor],\tto: str) -> torch.Tensor:", "funcdef": "def"}, "shimmer.modules.utils.cycle": {"fullname": "shimmer.modules.utils.cycle", "modulename": "shimmer.modules.utils", "qualname": "cycle", "kind": "function", "doc": "

Do a full cycle from a group of representation through one domain.

\n\n

[Original domains] -> [GW] -> [through] -> [GW] -> [Original domains]

\n\n
Arguments:
\n\n
    \n
  • gw_module (GWModuleBase): GWModule to perform the translation over
  • \n
  • selection_mod (SelectionBase): selection module
  • \n
  • x (LatentsDomainGroupT): group of unimodal latent representation
  • \n
  • through (str): domain name to cycle through
  • \n
\n\n
Returns:
\n\n
\n

LatentsDomainGroupDT: group of unimodal latent representation after\n cycling.

\n
\n", "signature": "(\tgw_module: shimmer.modules.gw_module.GWModuleBase,\tselection_mod: shimmer.modules.selection.SelectionBase,\tx: collections.abc.Mapping[str, torch.Tensor],\tthrough: str) -> dict[str, torch.Tensor]:", "funcdef": "def"}, "shimmer.modules.utils.batch_demi_cycles": {"fullname": "shimmer.modules.utils.batch_demi_cycles", "modulename": "shimmer.modules.utils", "qualname": "batch_demi_cycles", "kind": "function", "doc": "

Computes demi-cycles of a batch of groups of domains.

\n\n
Arguments:
\n\n
    \n
  • gw_mod (GWModuleBase): the GWModuleBase
  • \n
  • selection_mod (SelectionBase): selection module
  • \n
  • latent_domains (LatentsT): the batch of groups of domains
  • \n
\n\n
Returns:
\n\n
\n

dict[str, torch.Tensor]: demi-cycles predictions for each domain.

\n
\n", "signature": "(\tgw_mod: shimmer.modules.gw_module.GWModuleBase,\tselection_mod: shimmer.modules.selection.SelectionBase,\tlatent_domains: collections.abc.Mapping[frozenset[str], collections.abc.Mapping[str, torch.Tensor]]) -> dict[str, torch.Tensor]:", "funcdef": "def"}, "shimmer.modules.utils.batch_cycles": {"fullname": "shimmer.modules.utils.batch_cycles", "modulename": "shimmer.modules.utils", "qualname": "batch_cycles", "kind": "function", "doc": "

Computes cycles of a batch of groups of domains.

\n\n
Arguments:
\n\n
    \n
  • gw_mod (GWModuleBase): GWModule to use for the cycle
  • \n
  • selection_mod (SelectionBase): selection module
  • \n
  • latent_domains (LatentsT): the batch of groups of domains
  • \n
  • out_domains (Iterable[str]): iterable of domain names to do the cycle through.\nEach domain will be done separetely.
  • \n
\n\n
Returns:
\n\n
\n

dict[tuple[str, str], torch.Tensor]: cycles predictions for each\n couple of (start domain, intermediary domain).

\n
\n", "signature": "(\tgw_mod: shimmer.modules.gw_module.GWModuleBase,\tselection_mod: shimmer.modules.selection.SelectionBase,\tlatent_domains: collections.abc.Mapping[frozenset[str], collections.abc.Mapping[str, torch.Tensor]],\tthrough_domains: collections.abc.Iterable[str]) -> dict[tuple[str, str], torch.Tensor]:", "funcdef": "def"}, "shimmer.modules.utils.batch_translations": {"fullname": "shimmer.modules.utils.batch_translations", "modulename": "shimmer.modules.utils", "qualname": "batch_translations", "kind": "function", "doc": "

Computes translations of a batch of groups of domains.

\n\n
Arguments:
\n\n
    \n
  • gw_mod (GWModuleBase): GWModule to do the translation
  • \n
  • selection_mod (SelectionBase): selection module
  • \n
  • latent_domains (LatentsT): the batch of groups of domains
  • \n
\n\n
Returns:
\n\n
\n

dict[tuple[str, str], torch.Tensor]: translation predictions for each\n couple of (start domain, target domain).

\n
\n", "signature": "(\tgw_mod: shimmer.modules.gw_module.GWModuleBase,\tselection_mod: shimmer.modules.selection.SelectionBase,\tlatent_domains: collections.abc.Mapping[frozenset[str], collections.abc.Mapping[str, torch.Tensor]]) -> dict[tuple[str, str], torch.Tensor]:", "funcdef": "def"}, "shimmer.utils": {"fullname": "shimmer.utils", "modulename": "shimmer.utils", "kind": "module", "doc": "

\n"}, "shimmer.utils.MIGRATION_DIR": {"fullname": "shimmer.utils.MIGRATION_DIR", "modulename": "shimmer.utils", "qualname": "MIGRATION_DIR", "kind": "variable", "doc": "

\n", "default_value": "PosixPath('/home/runner/work/shimmer/shimmer/shimmer/ckpt_migrations')"}, "shimmer.utils.group_batch_size": {"fullname": "shimmer.utils.group_batch_size", "modulename": "shimmer.utils", "qualname": "group_batch_size", "kind": "function", "doc": "

\n", "signature": "(x: collections.abc.Mapping[str, torch.Tensor]) -> int:", "funcdef": "def"}, "shimmer.utils.groups_batch_size": {"fullname": "shimmer.utils.groups_batch_size", "modulename": "shimmer.utils", "qualname": "groups_batch_size", "kind": "function", "doc": "

Get the batch size of the batch.

\n\n
Arguments:
\n\n
    \n
  • domain_latents (LatentsDomainGroupsT): the batch of groups.
  • \n
\n\n
Returns:
\n\n
\n

int: the batch size.

\n
\n", "signature": "(\tdomain_latents: collections.abc.Mapping[frozenset[str], collections.abc.Mapping[str, torch.Tensor]]) -> int:", "funcdef": "def"}, "shimmer.utils.groups_device": {"fullname": "shimmer.utils.groups_device", "modulename": "shimmer.utils", "qualname": "groups_device", "kind": "function", "doc": "

Get the batch size of the batch.

\n\n
Arguments:
\n\n
    \n
  • domain_latents (LatentsDomainGroupsT): the batch of groups.
  • \n
\n\n
Returns:
\n\n
\n

int: the batch size.

\n
\n", "signature": "(\tdomain_latents: collections.abc.Mapping[frozenset[str], collections.abc.Mapping[str, torch.Tensor]]) -> int:", "funcdef": "def"}, "shimmer.utils.group_device": {"fullname": "shimmer.utils.group_device", "modulename": "shimmer.utils", "qualname": "group_device", "kind": "function", "doc": "

\n", "signature": "(x: collections.abc.Mapping[str, torch.Tensor]) -> torch.device:", "funcdef": "def"}, "shimmer.utils.migrate_model": {"fullname": "shimmer.utils.migrate_model", "modulename": "shimmer.utils", "qualname": "migrate_model", "kind": "function", "doc": "

Migrates a model checkpoint

\n\n

After the migration, the given checkpoint will be migrated.\nOther versions of the checkpoint will be saved under the stem-version.suffix.

\n\n
Arguments:
\n\n
    \n
  • ckpt_path (str | PathLike): path to checkpoint
  • \n
  • torch_load_kwargs: additional args given to torch.load.
  • \n
\n", "signature": "(ckpt_path: str | os.PathLike, **torch_load_kwargs):", "funcdef": "def"}, "shimmer.utils.SaveMigrations": {"fullname": "shimmer.utils.SaveMigrations", "modulename": "shimmer.utils", "qualname": "SaveMigrations", "kind": "class", "doc": "

Abstract base class used to build new callbacks.

\n\n

Subclass this class and override any of the relevant hooks

\n", "bases": "lightning.pytorch.callbacks.callback.Callback"}, "shimmer.utils.SaveMigrations.migrations": {"fullname": "shimmer.utils.SaveMigrations.migrations", "modulename": "shimmer.utils", "qualname": "SaveMigrations.migrations", "kind": "variable", "doc": "

\n"}, "shimmer.utils.SaveMigrations.on_save_checkpoint": {"fullname": "shimmer.utils.SaveMigrations.on_save_checkpoint", "modulename": "shimmer.utils", "qualname": "SaveMigrations.on_save_checkpoint", "kind": "function", "doc": "

Called when saving a checkpoint to give you a chance to store anything else you might want to save.

\n\n
Arguments:
\n\n
    \n
  • trainer: the current ~lightning.pytorch.trainer.trainer.Trainer instance.
  • \n
  • pl_module: the current ~lightning.pytorch.core.LightningModule instance.
  • \n
  • checkpoint: the checkpoint dictionary that will be saved.
  • \n
\n", "signature": "(\tself,\ttrainer: lightning.pytorch.trainer.trainer.Trainer,\tpl_module: lightning.pytorch.core.module.LightningModule,\tcheckpoint: dict[str, typing.Any]):", "funcdef": "def"}, "shimmer.cli.ckpt_migration": {"fullname": "shimmer.cli.ckpt_migration", "modulename": "shimmer.cli.ckpt_migration", "kind": "module", "doc": "

\n"}, "shimmer.cli.ckpt_migration.migrate_ckpt": {"fullname": "shimmer.cli.ckpt_migration.migrate_ckpt", "modulename": "shimmer.cli.ckpt_migration", "qualname": "migrate_ckpt", "kind": "variable", "doc": "

Script to migrate a list of checkpoints.\nThis can be called with:

\n\n
\n
shimmer migrate-ckpt PATH_1 PATH_2 ... PATH_N\n
\n
\n\n

where paths point to checkpoints.

\n\n

Internally, this calls shimmer.utils.migrate_model for each of the given paths.

\n", "default_value": "<Command migrate-ckpt>"}}, "docInfo": {"shimmer.types": {"qualname": 0, "fullname": 2, "annotation": 0, "default_value": 0, "signature": 0, "bases": 0, "doc": 3}, "shimmer.types.RawDomainGroupT": {"qualname": 1, "fullname": 3, "annotation": 0, "default_value": 5, "signature": 0, "bases": 0, "doc": 210}, "shimmer.types.RawDomainGroupDT": {"qualname": 1, "fullname": 3, "annotation": 0, "default_value": 3, "signature": 0, "bases": 0, "doc": 165}, "shimmer.types.LatentsDomainGroupT": {"qualname": 1, "fullname": 3, "annotation": 0, "default_value": 5, "signature": 0, "bases": 0, "doc": 234}, "shimmer.types.LatentsDomainGroupDT": {"qualname": 1, "fullname": 3, "annotation": 0, "default_value": 3, "signature": 0, "bases": 0, "doc": 198}, "shimmer.types.RawDomainGroupsT": {"qualname": 1, "fullname": 3, "annotation": 0, "default_value": 8, "signature": 0, "bases": 0, "doc": 297}, "shimmer.types.RawDomainGroupsDT": {"qualname": 1, "fullname": 3, "annotation": 0, "default_value": 4, "signature": 0, "bases": 0, "doc": 297}, "shimmer.types.LatentsDomainGroupsT": {"qualname": 1, "fullname": 3, "annotation": 0, "default_value": 8, "signature": 0, "bases": 0, "doc": 401}, "shimmer.types.LatentsDomainGroupsDT": {"qualname": 1, "fullname": 3, "annotation": 0, "default_value": 4, "signature": 0, "bases": 0, "doc": 365}, "shimmer.types.ModelModeT": {"qualname": 1, "fullname": 3, "annotation": 0, "default_value": 18, "signature": 0, "bases": 0, "doc": 27}, "shimmer.modules.global_workspace": {"qualname": 0, "fullname": 4, "annotation": 0, "default_value": 0, "signature": 0, "bases": 0, "doc": 3}, "shimmer.modules.global_workspace.SchedulerArgs": {"qualname": 1, "fullname": 5, "annotation": 0, "default_value": 0, "signature": 0, "bases": 2, "doc": 10}, "shimmer.modules.global_workspace.SchedulerArgs.max_lr": {"qualname": 3, "fullname": 7, "annotation": 2, "default_value": 0, "signature": 0, "bases": 0, "doc": 5}, 
"shimmer.modules.global_workspace.SchedulerArgs.total_steps": {"qualname": 3, "fullname": 7, "annotation": 2, "default_value": 0, "signature": 0, "bases": 0, "doc": 6}, "shimmer.modules.global_workspace.GWPredictionsBase": {"qualname": 1, "fullname": 5, "annotation": 0, "default_value": 0, "signature": 0, "bases": 2, "doc": 13}, "shimmer.modules.global_workspace.GWPredictionsBase.states": {"qualname": 2, "fullname": 6, "annotation": 4, "default_value": 0, "signature": 0, "bases": 0, "doc": 20}, "shimmer.modules.global_workspace.GlobalWorkspaceBase": {"qualname": 1, "fullname": 5, "annotation": 0, "default_value": 0, "signature": 0, "bases": 18, "doc": 20}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.gw_mod": {"qualname": 3, "fullname": 7, "annotation": 0, "default_value": 0, "signature": 0, "bases": 0, "doc": 8}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.selection_mod": {"qualname": 3, "fullname": 7, "annotation": 0, "default_value": 0, "signature": 0, "bases": 0, "doc": 8}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.loss_mod": {"qualname": 3, "fullname": 7, "annotation": 0, "default_value": 0, "signature": 0, "bases": 0, "doc": 10}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.optim_lr": {"qualname": 3, "fullname": 7, "annotation": 0, "default_value": 0, "signature": 0, "bases": 0, "doc": 3}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.optim_weight_decay": {"qualname": 4, "fullname": 8, "annotation": 0, "default_value": 0, "signature": 0, "bases": 0, "doc": 3}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.scheduler_args": {"qualname": 3, "fullname": 7, "annotation": 0, "default_value": 0, "signature": 0, "bases": 0, "doc": 3}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.domain_mods": {"qualname": 3, "fullname": 7, "annotation": 8, "default_value": 0, "signature": 0, "bases": 0, "doc": 3}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.workspace_dim": {"qualname": 3, "fullname": 7, 
"annotation": 2, "default_value": 0, "signature": 0, "bases": 0, "doc": 7}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.encode_and_fuse": {"qualname": 4, "fullname": 8, "annotation": 0, "default_value": 0, "signature": 125, "bases": 0, "doc": 65}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.encode": {"qualname": 2, "fullname": 6, "annotation": 0, "default_value": 0, "signature": 108, "bases": 0, "doc": 52}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.fuse": {"qualname": 2, "fullname": 6, "annotation": 0, "default_value": 0, "signature": 161, "bases": 0, "doc": 71}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.decode": {"qualname": 2, "fullname": 6, "annotation": 0, "default_value": 0, "signature": 128, "bases": 0, "doc": 68}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.batch_gw_states": {"qualname": 4, "fullname": 8, "annotation": 0, "default_value": 0, "signature": 95, "bases": 0, "doc": 55}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.encode_domain": {"qualname": 3, "fullname": 7, "annotation": 0, "default_value": 0, "signature": 39, "bases": 0, "doc": 122}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.encode_domains": {"qualname": 3, "fullname": 7, "annotation": 0, "default_value": 0, "signature": 108, "bases": 0, "doc": 59}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.decode_domain": {"qualname": 3, "fullname": 7, "annotation": 0, "default_value": 0, "signature": 39, "bases": 0, "doc": 122}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.decode_domains": {"qualname": 3, "fullname": 7, "annotation": 0, "default_value": 0, "signature": 109, "bases": 0, "doc": 59}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.generic_step": {"qualname": 3, "fullname": 7, "annotation": 0, "default_value": 0, "signature": 143, "bases": 0, "doc": 75}, "shimmer.modules.global_workspace.freeze_domain_modules": {"qualname": 3, "fullname": 7, "annotation": 0, "default_value": 0, "signature": 85, 
"bases": 0, "doc": 81}, "shimmer.modules.global_workspace.GWPredictions": {"qualname": 1, "fullname": 5, "annotation": 0, "default_value": 0, "signature": 0, "bases": 2, "doc": 13}, "shimmer.modules.global_workspace.GWPredictions.demi_cycles": {"qualname": 3, "fullname": 7, "annotation": 4, "default_value": 0, "signature": 0, "bases": 0, "doc": 21}, "shimmer.modules.global_workspace.GWPredictions.cycles": {"qualname": 2, "fullname": 6, "annotation": 5, "default_value": 0, "signature": 0, "bases": 0, "doc": 34}, "shimmer.modules.global_workspace.GWPredictions.translations": {"qualname": 2, "fullname": 6, "annotation": 5, "default_value": 0, "signature": 0, "bases": 0, "doc": 37}, "shimmer.modules.global_workspace.GWPredictions.states": {"qualname": 2, "fullname": 6, "annotation": 4, "default_value": 0, "signature": 0, "bases": 0, "doc": 3}, "shimmer.modules.global_workspace.GlobalWorkspace2Domains": {"qualname": 1, "fullname": 5, "annotation": 0, "default_value": 0, "signature": 0, "bases": 17, "doc": 33}, "shimmer.modules.global_workspace.GlobalWorkspace2Domains.__init__": {"qualname": 3, "fullname": 7, "annotation": 0, "default_value": 0, "signature": 380, "bases": 0, "doc": 250}, "shimmer.modules.global_workspace.GlobalWorkspace2Domains.forward": {"qualname": 2, "fullname": 6, "annotation": 0, "default_value": 0, "signature": 94, "bases": 0, "doc": 51}, "shimmer.modules.global_workspace.GlobalWorkspace": {"qualname": 1, "fullname": 5, "annotation": 0, "default_value": 0, "signature": 0, "bases": 17, "doc": 35}, "shimmer.modules.global_workspace.GlobalWorkspace.__init__": {"qualname": 3, "fullname": 7, "annotation": 0, "default_value": 0, "signature": 400, "bases": 0, "doc": 271}, "shimmer.modules.global_workspace.GlobalWorkspace.forward": {"qualname": 2, "fullname": 6, "annotation": 0, "default_value": 0, "signature": 94, "bases": 0, "doc": 51}, "shimmer.modules.global_workspace.GlobalWorkspaceBayesian": {"qualname": 1, "fullname": 5, "annotation": 0, 
"default_value": 0, "signature": 0, "bases": 17, "doc": 37}, "shimmer.modules.global_workspace.GlobalWorkspaceBayesian.__init__": {"qualname": 3, "fullname": 7, "annotation": 0, "default_value": 0, "signature": 459, "bases": 0, "doc": 318}, "shimmer.modules.global_workspace.GlobalWorkspaceBayesian.forward": {"qualname": 2, "fullname": 6, "annotation": 0, "default_value": 0, "signature": 94, "bases": 0, "doc": 51}, "shimmer.modules.global_workspace.pretrained_global_workspace": {"qualname": 3, "fullname": 7, "annotation": 0, "default_value": 0, "signature": 317, "bases": 0, "doc": 264}, "shimmer.modules.domain": {"qualname": 0, "fullname": 3, "annotation": 0, "default_value": 0, "signature": 0, "bases": 0, "doc": 3}, "shimmer.modules.domain.LossOutput": {"qualname": 1, "fullname": 4, "annotation": 0, "default_value": 0, "signature": 0, "bases": 0, "doc": 37}, "shimmer.modules.domain.LossOutput.__init__": {"qualname": 3, "fullname": 6, "annotation": 0, "default_value": 0, "signature": 61, "bases": 0, "doc": 3}, "shimmer.modules.domain.LossOutput.loss": {"qualname": 2, "fullname": 5, "annotation": 3, "default_value": 0, "signature": 0, "bases": 0, "doc": 7}, "shimmer.modules.domain.LossOutput.metrics": {"qualname": 2, "fullname": 5, "annotation": 4, "default_value": 0, "signature": 0, "bases": 0, "doc": 12}, "shimmer.modules.domain.LossOutput.all": {"qualname": 2, "fullname": 5, "annotation": 4, "default_value": 0, "signature": 0, "bases": 0, "doc": 14}, "shimmer.modules.domain.DomainModule": {"qualname": 1, "fullname": 4, "annotation": 0, "default_value": 0, "signature": 0, "bases": 5, "doc": 16}, "shimmer.modules.domain.DomainModule.__init__": {"qualname": 3, "fullname": 6, "annotation": 0, "default_value": 0, "signature": 15, "bases": 0, "doc": 29}, "shimmer.modules.domain.DomainModule.latent_dim": {"qualname": 3, "fullname": 6, "annotation": 0, "default_value": 0, "signature": 0, "bases": 0, "doc": 9}, "shimmer.modules.domain.DomainModule.encode": {"qualname": 2, 
"fullname": 5, "annotation": 0, "default_value": 0, "signature": 29, "bases": 0, "doc": 49}, "shimmer.modules.domain.DomainModule.decode": {"qualname": 2, "fullname": 5, "annotation": 0, "default_value": 0, "signature": 29, "bases": 0, "doc": 53}, "shimmer.modules.domain.DomainModule.compute_loss": {"qualname": 3, "fullname": 6, "annotation": 0, "default_value": 0, "signature": 62, "bases": 0, "doc": 61}, "shimmer.modules.domain.DomainModule.compute_dcy_loss": {"qualname": 4, "fullname": 7, "annotation": 0, "default_value": 0, "signature": 62, "bases": 0, "doc": 75}, "shimmer.modules.domain.DomainModule.compute_cy_loss": {"qualname": 4, "fullname": 7, "annotation": 0, "default_value": 0, "signature": 62, "bases": 0, "doc": 73}, "shimmer.modules.domain.DomainModule.compute_tr_loss": {"qualname": 4, "fullname": 7, "annotation": 0, "default_value": 0, "signature": 62, "bases": 0, "doc": 73}, "shimmer.modules.domain.DomainModule.compute_broadcast_loss": {"qualname": 4, "fullname": 7, "annotation": 0, "default_value": 0, "signature": 62, "bases": 0, "doc": 74}, "shimmer.modules.gw_module": {"qualname": 0, "fullname": 4, "annotation": 0, "default_value": 0, "signature": 0, "bases": 0, "doc": 3}, "shimmer.modules.gw_module.get_n_layers": {"qualname": 3, "fullname": 7, "annotation": 0, "default_value": 0, "signature": 57, "bases": 0, "doc": 76}, "shimmer.modules.gw_module.GWDecoder": {"qualname": 1, "fullname": 5, "annotation": 0, "default_value": 0, "signature": 0, "bases": 5, "doc": 8}, "shimmer.modules.gw_module.GWDecoder.__init__": {"qualname": 3, "fullname": 7, "annotation": 0, "default_value": 0, "signature": 48, "bases": 0, "doc": 81}, "shimmer.modules.gw_module.GWDecoder.in_dim": {"qualname": 3, "fullname": 7, "annotation": 0, "default_value": 0, "signature": 0, "bases": 0, "doc": 4}, "shimmer.modules.gw_module.GWDecoder.hidden_dim": {"qualname": 3, "fullname": 7, "annotation": 0, "default_value": 0, "signature": 0, "bases": 0, "doc": 4}, 
"shimmer.modules.gw_module.GWDecoder.out_dim": {"qualname": 3, "fullname": 7, "annotation": 0, "default_value": 0, "signature": 0, "bases": 0, "doc": 4}, "shimmer.modules.gw_module.GWDecoder.n_layers": {"qualname": 3, "fullname": 7, "annotation": 0, "default_value": 0, "signature": 0, "bases": 0, "doc": 24}, "shimmer.modules.gw_module.GWEncoder": {"qualname": 1, "fullname": 5, "annotation": 0, "default_value": 0, "signature": 0, "bases": 1, "doc": 27}, "shimmer.modules.gw_module.GWEncoder.__init__": {"qualname": 3, "fullname": 7, "annotation": 0, "default_value": 0, "signature": 48, "bases": 0, "doc": 81}, "shimmer.modules.gw_module.GWEncoder.forward": {"qualname": 2, "fullname": 6, "annotation": 0, "default_value": 0, "signature": 34, "bases": 0, "doc": 67}, "shimmer.modules.gw_module.GWEncoderLinear": {"qualname": 1, "fullname": 5, "annotation": 0, "default_value": 0, "signature": 0, "bases": 5, "doc": 10}, "shimmer.modules.gw_module.GWEncoderLinear.forward": {"qualname": 2, "fullname": 6, "annotation": 0, "default_value": 0, "signature": 34, "bases": 0, "doc": 67}, "shimmer.modules.gw_module.GWModuleBase": {"qualname": 1, "fullname": 5, "annotation": 0, "default_value": 0, "signature": 0, "bases": 7, "doc": 58}, "shimmer.modules.gw_module.GWModuleBase.__init__": {"qualname": 3, "fullname": 7, "annotation": 0, "default_value": 0, "signature": 81, "bases": 0, "doc": 43}, "shimmer.modules.gw_module.GWModuleBase.domain_mods": {"qualname": 3, "fullname": 7, "annotation": 0, "default_value": 0, "signature": 0, "bases": 0, "doc": 7}, "shimmer.modules.gw_module.GWModuleBase.workspace_dim": {"qualname": 3, "fullname": 7, "annotation": 0, "default_value": 0, "signature": 0, "bases": 0, "doc": 6}, "shimmer.modules.gw_module.GWModuleBase.fuse": {"qualname": 2, "fullname": 6, "annotation": 0, "default_value": 0, "signature": 97, "bases": 0, "doc": 69}, "shimmer.modules.gw_module.GWModuleBase.encode": {"qualname": 2, "fullname": 6, "annotation": 0, "default_value": 0, 
"signature": 70, "bases": 0, "doc": 50}, "shimmer.modules.gw_module.GWModuleBase.encode_and_fuse": {"qualname": 4, "fullname": 8, "annotation": 0, "default_value": 0, "signature": 85, "bases": 0, "doc": 78}, "shimmer.modules.gw_module.GWModuleBase.decode": {"qualname": 2, "fullname": 6, "annotation": 0, "default_value": 0, "signature": 88, "bases": 0, "doc": 65}, "shimmer.modules.gw_module.GWModule": {"qualname": 1, "fullname": 5, "annotation": 0, "default_value": 0, "signature": 0, "bases": 1, "doc": 10}, "shimmer.modules.gw_module.GWModule.__init__": {"qualname": 3, "fullname": 7, "annotation": 0, "default_value": 0, "signature": 173, "bases": 0, "doc": 117}, "shimmer.modules.gw_module.GWModule.gw_encoders": {"qualname": 3, "fullname": 7, "annotation": 0, "default_value": 0, "signature": 0, "bases": 0, "doc": 6}, "shimmer.modules.gw_module.GWModule.gw_decoders": {"qualname": 3, "fullname": 7, "annotation": 0, "default_value": 0, "signature": 0, "bases": 0, "doc": 6}, "shimmer.modules.gw_module.GWModule.fuse": {"qualname": 2, "fullname": 6, "annotation": 0, "default_value": 0, "signature": 97, "bases": 0, "doc": 69}, "shimmer.modules.gw_module.GWModule.encode": {"qualname": 2, "fullname": 6, "annotation": 0, "default_value": 0, "signature": 70, "bases": 0, "doc": 50}, "shimmer.modules.gw_module.GWModule.decode": {"qualname": 2, "fullname": 6, "annotation": 0, "default_value": 0, "signature": 88, "bases": 0, "doc": 73}, "shimmer.modules.gw_module.compute_fusion_scores": {"qualname": 3, "fullname": 7, "annotation": 0, "default_value": 0, "signature": 107, "bases": 0, "doc": 119}, "shimmer.modules.gw_module.GWModuleBayesian": {"qualname": 1, "fullname": 5, "annotation": 0, "default_value": 0, "signature": 0, "bases": 1, "doc": 12}, "shimmer.modules.gw_module.GWModuleBayesian.__init__": {"qualname": 3, "fullname": 7, "annotation": 0, "default_value": 0, "signature": 232, "bases": 0, "doc": 163}, "shimmer.modules.gw_module.GWModuleBayesian.precisions": {"qualname": 2, 
"fullname": 6, "annotation": 0, "default_value": 0, "signature": 0, "bases": 0, "doc": 11}, "shimmer.modules.gw_module.GWModuleBayesian.sensitivity_selection": {"qualname": 3, "fullname": 7, "annotation": 0, "default_value": 0, "signature": 0, "bases": 0, "doc": 3}, "shimmer.modules.gw_module.GWModuleBayesian.sensitivity_precision": {"qualname": 3, "fullname": 7, "annotation": 0, "default_value": 0, "signature": 0, "bases": 0, "doc": 3}, "shimmer.modules.gw_module.GWModuleBayesian.precision_softmax_temp": {"qualname": 4, "fullname": 8, "annotation": 0, "default_value": 0, "signature": 0, "bases": 0, "doc": 3}, "shimmer.modules.gw_module.GWModuleBayesian.get_precision": {"qualname": 3, "fullname": 7, "annotation": 0, "default_value": 0, "signature": 44, "bases": 0, "doc": 57}, "shimmer.modules.gw_module.GWModuleBayesian.fuse": {"qualname": 2, "fullname": 6, "annotation": 0, "default_value": 0, "signature": 97, "bases": 0, "doc": 296}, "shimmer.modules.selection": {"qualname": 0, "fullname": 3, "annotation": 0, "default_value": 0, "signature": 0, "bases": 0, "doc": 3}, "shimmer.modules.selection.SelectionBase": {"qualname": 1, "fullname": 4, "annotation": 0, "default_value": 0, "signature": 0, "bases": 7, "doc": 29}, "shimmer.modules.selection.SelectionBase.update_gw_state": {"qualname": 4, "fullname": 7, "annotation": 0, "default_value": 0, "signature": 30, "bases": 0, "doc": 66}, "shimmer.modules.selection.SelectionBase.forward": {"qualname": 2, "fullname": 5, "annotation": 0, "default_value": 0, "signature": 110, "bases": 0, "doc": 202}, "shimmer.modules.selection.SingleDomainSelection": {"qualname": 1, "fullname": 4, "annotation": 0, "default_value": 0, "signature": 0, "bases": 1, "doc": 48}, "shimmer.modules.selection.SingleDomainSelection.forward": {"qualname": 2, "fullname": 5, "annotation": 0, "default_value": 0, "signature": 110, "bases": 0, "doc": 69}, "shimmer.modules.selection.FixedSharedSelection": {"qualname": 1, "fullname": 4, "annotation": 0, 
"default_value": 0, "signature": 0, "bases": 1, "doc": 35}, "shimmer.modules.selection.FixedSharedSelection.forward": {"qualname": 2, "fullname": 5, "annotation": 0, "default_value": 0, "signature": 110, "bases": 0, "doc": 69}, "shimmer.modules.selection.KQFixedQSelection": {"qualname": 1, "fullname": 4, "annotation": 0, "default_value": 0, "signature": 0, "bases": 1, "doc": 11}, "shimmer.modules.selection.KQFixedQSelection.__init__": {"qualname": 3, "fullname": 6, "annotation": 0, "default_value": 0, "signature": 56, "bases": 0, "doc": 62}, "shimmer.modules.selection.KQFixedQSelection.head_size": {"qualname": 3, "fullname": 6, "annotation": 0, "default_value": 0, "signature": 0, "bases": 0, "doc": 3}, "shimmer.modules.selection.KQFixedQSelection.query_layer": {"qualname": 3, "fullname": 6, "annotation": 0, "default_value": 0, "signature": 0, "bases": 0, "doc": 3}, "shimmer.modules.selection.KQFixedQSelection.key_layers": {"qualname": 3, "fullname": 6, "annotation": 0, "default_value": 0, "signature": 0, "bases": 0, "doc": 3}, "shimmer.modules.selection.KQFixedQSelection.forward": {"qualname": 2, "fullname": 5, "annotation": 0, "default_value": 0, "signature": 110, "bases": 0, "doc": 89}, "shimmer.modules.selection.RandomSelection": {"qualname": 1, "fullname": 4, "annotation": 0, "default_value": 0, "signature": 0, "bases": 1, "doc": 37}, "shimmer.modules.selection.RandomSelection.__init__": {"qualname": 3, "fullname": 6, "annotation": 0, "default_value": 0, "signature": 14, "bases": 0, "doc": 26}, "shimmer.modules.selection.RandomSelection.temperature": {"qualname": 2, "fullname": 5, "annotation": 0, "default_value": 0, "signature": 0, "bases": 0, "doc": 3}, "shimmer.modules.selection.RandomSelection.forward": {"qualname": 2, "fullname": 5, "annotation": 0, "default_value": 0, "signature": 110, "bases": 0, "doc": 88}, "shimmer.modules.selection.DynamicQueryAttention": {"qualname": 1, "fullname": 4, "annotation": 0, "default_value": 0, "signature": 0, "bases": 1, 
"doc": 21}, "shimmer.modules.selection.DynamicQueryAttention.__init__": {"qualname": 3, "fullname": 6, "annotation": 0, "default_value": 0, "signature": 56, "bases": 0, "doc": 62}, "shimmer.modules.selection.DynamicQueryAttention.head_size": {"qualname": 3, "fullname": 6, "annotation": 0, "default_value": 0, "signature": 0, "bases": 0, "doc": 3}, "shimmer.modules.selection.DynamicQueryAttention.query_layer": {"qualname": 3, "fullname": 6, "annotation": 0, "default_value": 0, "signature": 0, "bases": 0, "doc": 3}, "shimmer.modules.selection.DynamicQueryAttention.key_layers": {"qualname": 3, "fullname": 6, "annotation": 0, "default_value": 0, "signature": 0, "bases": 0, "doc": 3}, "shimmer.modules.selection.DynamicQueryAttention.fuse_weighted_encodings": {"qualname": 4, "fullname": 7, "annotation": 0, "default_value": 0, "signature": 87, "bases": 0, "doc": 69}, "shimmer.modules.selection.DynamicQueryAttention.forward": {"qualname": 2, "fullname": 5, "annotation": 0, "default_value": 0, "signature": 110, "bases": 0, "doc": 89}, "shimmer.modules.losses": {"qualname": 0, "fullname": 3, "annotation": 0, "default_value": 0, "signature": 0, "bases": 0, "doc": 3}, "shimmer.modules.losses.GWLossesBase": {"qualname": 1, "fullname": 4, "annotation": 0, "default_value": 0, "signature": 0, "bases": 7, "doc": 30}, "shimmer.modules.losses.GWLossesBase.step": {"qualname": 2, "fullname": 5, "annotation": 0, "default_value": 0, "signature": 154, "bases": 0, "doc": 57}, "shimmer.modules.losses.demi_cycle_loss": {"qualname": 3, "fullname": 6, "annotation": 0, "default_value": 0, "signature": 193, "bases": 0, "doc": 175}, "shimmer.modules.losses.cycle_loss": {"qualname": 2, "fullname": 5, "annotation": 0, "default_value": 0, "signature": 193, "bases": 0, "doc": 172}, "shimmer.modules.losses.translation_loss": {"qualname": 2, "fullname": 5, "annotation": 0, "default_value": 0, "signature": 193, "bases": 0, "doc": 159}, "shimmer.modules.losses.contrastive_loss": {"qualname": 2, 
"fullname": 5, "annotation": 0, "default_value": 0, "signature": 182, "bases": 0, "doc": 155}, "shimmer.modules.losses.contrastive_loss_bayesian": {"qualname": 3, "fullname": 6, "annotation": 0, "default_value": 0, "signature": 182, "bases": 0, "doc": 161}, "shimmer.modules.losses.LossCoefs": {"qualname": 1, "fullname": 4, "annotation": 0, "default_value": 0, "signature": 0, "bases": 2, "doc": 51}, "shimmer.modules.losses.LossCoefs.demi_cycles": {"qualname": 3, "fullname": 6, "annotation": 2, "default_value": 0, "signature": 0, "bases": 0, "doc": 7}, "shimmer.modules.losses.LossCoefs.cycles": {"qualname": 2, "fullname": 5, "annotation": 2, "default_value": 0, "signature": 0, "bases": 0, "doc": 6}, "shimmer.modules.losses.LossCoefs.translations": {"qualname": 2, "fullname": 5, "annotation": 2, "default_value": 0, "signature": 0, "bases": 0, "doc": 6}, "shimmer.modules.losses.LossCoefs.contrastives": {"qualname": 2, "fullname": 5, "annotation": 2, "default_value": 0, "signature": 0, "bases": 0, "doc": 6}, "shimmer.modules.losses.GWLosses2Domains": {"qualname": 1, "fullname": 4, "annotation": 0, "default_value": 0, "signature": 0, "bases": 1, "doc": 13}, "shimmer.modules.losses.GWLosses2Domains.__init__": {"qualname": 3, "fullname": 6, "annotation": 0, "default_value": 0, "signature": 190, "bases": 0, "doc": 107}, "shimmer.modules.losses.GWLosses2Domains.gw_mod": {"qualname": 3, "fullname": 6, "annotation": 0, "default_value": 0, "signature": 0, "bases": 0, "doc": 3}, "shimmer.modules.losses.GWLosses2Domains.selection_mod": {"qualname": 3, "fullname": 6, "annotation": 0, "default_value": 0, "signature": 0, "bases": 0, "doc": 3}, "shimmer.modules.losses.GWLosses2Domains.domain_mods": {"qualname": 3, "fullname": 6, "annotation": 0, "default_value": 0, "signature": 0, "bases": 0, "doc": 3}, "shimmer.modules.losses.GWLosses2Domains.loss_coefs": {"qualname": 3, "fullname": 6, "annotation": 0, "default_value": 0, "signature": 0, "bases": 0, "doc": 3}, 
"shimmer.modules.losses.GWLosses2Domains.contrastive_fn": {"qualname": 3, "fullname": 6, "annotation": 0, "default_value": 0, "signature": 0, "bases": 0, "doc": 3}, "shimmer.modules.losses.GWLosses2Domains.demi_cycle_loss": {"qualname": 4, "fullname": 7, "annotation": 0, "default_value": 0, "signature": 95, "bases": 0, "doc": 60}, "shimmer.modules.losses.GWLosses2Domains.cycle_loss": {"qualname": 3, "fullname": 6, "annotation": 0, "default_value": 0, "signature": 95, "bases": 0, "doc": 58}, "shimmer.modules.losses.GWLosses2Domains.translation_loss": {"qualname": 3, "fullname": 6, "annotation": 0, "default_value": 0, "signature": 95, "bases": 0, "doc": 58}, "shimmer.modules.losses.GWLosses2Domains.contrastive_loss": {"qualname": 3, "fullname": 6, "annotation": 0, "default_value": 0, "signature": 95, "bases": 0, "doc": 58}, "shimmer.modules.losses.GWLosses2Domains.step": {"qualname": 2, "fullname": 5, "annotation": 0, "default_value": 0, "signature": 154, "bases": 0, "doc": 109}, "shimmer.modules.losses.generate_partitions": {"qualname": 2, "fullname": 5, "annotation": 0, "default_value": 0, "signature": 58, "bases": 0, "doc": 71}, "shimmer.modules.losses.broadcast_loss": {"qualname": 2, "fullname": 5, "annotation": 0, "default_value": 0, "signature": 193, "bases": 0, "doc": 99}, "shimmer.modules.losses.BroadcastLossCoefs": {"qualname": 1, "fullname": 4, "annotation": 0, "default_value": 0, "signature": 0, "bases": 2, "doc": 51}, "shimmer.modules.losses.BroadcastLossCoefs.contrastives": {"qualname": 2, "fullname": 5, "annotation": 2, "default_value": 0, "signature": 0, "bases": 0, "doc": 6}, "shimmer.modules.losses.BroadcastLossCoefs.fused": {"qualname": 2, "fullname": 5, "annotation": 2, "default_value": 0, "signature": 0, "bases": 0, "doc": 15}, "shimmer.modules.losses.BroadcastLossCoefs.demi_cycles": {"qualname": 3, "fullname": 6, "annotation": 2, "default_value": 0, "signature": 0, "bases": 0, "doc": 13}, "shimmer.modules.losses.BroadcastLossCoefs.cycles": 
{"qualname": 2, "fullname": 5, "annotation": 2, "default_value": 0, "signature": 0, "bases": 0, "doc": 11}, "shimmer.modules.losses.BroadcastLossCoefs.translations": {"qualname": 2, "fullname": 5, "annotation": 2, "default_value": 0, "signature": 0, "bases": 0, "doc": 14}, "shimmer.modules.losses.GWLosses": {"qualname": 1, "fullname": 4, "annotation": 0, "default_value": 0, "signature": 0, "bases": 1, "doc": 12}, "shimmer.modules.losses.GWLosses.__init__": {"qualname": 3, "fullname": 6, "annotation": 0, "default_value": 0, "signature": 190, "bases": 0, "doc": 91}, "shimmer.modules.losses.GWLosses.gw_mod": {"qualname": 3, "fullname": 6, "annotation": 0, "default_value": 0, "signature": 0, "bases": 0, "doc": 3}, "shimmer.modules.losses.GWLosses.selection_mod": {"qualname": 3, "fullname": 6, "annotation": 0, "default_value": 0, "signature": 0, "bases": 0, "doc": 3}, "shimmer.modules.losses.GWLosses.domain_mods": {"qualname": 3, "fullname": 6, "annotation": 0, "default_value": 0, "signature": 0, "bases": 0, "doc": 3}, "shimmer.modules.losses.GWLosses.loss_coefs": {"qualname": 3, "fullname": 6, "annotation": 0, "default_value": 0, "signature": 0, "bases": 0, "doc": 3}, "shimmer.modules.losses.GWLosses.contrastive_fn": {"qualname": 3, "fullname": 6, "annotation": 0, "default_value": 0, "signature": 0, "bases": 0, "doc": 3}, "shimmer.modules.losses.GWLosses.contrastive_loss": {"qualname": 3, "fullname": 6, "annotation": 0, "default_value": 0, "signature": 95, "bases": 0, "doc": 46}, "shimmer.modules.losses.GWLosses.broadcast_loss": {"qualname": 3, "fullname": 6, "annotation": 0, "default_value": 0, "signature": 95, "bases": 0, "doc": 3}, "shimmer.modules.losses.GWLosses.step": {"qualname": 2, "fullname": 5, "annotation": 0, "default_value": 0, "signature": 154, "bases": 0, "doc": 64}, "shimmer.modules.losses.GWLossesBayesian": {"qualname": 1, "fullname": 4, "annotation": 0, "default_value": 0, "signature": 0, "bases": 1, "doc": 13}, 
"shimmer.modules.losses.GWLossesBayesian.__init__": {"qualname": 3, "fullname": 6, "annotation": 0, "default_value": 0, "signature": 210, "bases": 0, "doc": 120}, "shimmer.modules.losses.GWLossesBayesian.gw_mod": {"qualname": 3, "fullname": 6, "annotation": 0, "default_value": 0, "signature": 0, "bases": 0, "doc": 5}, "shimmer.modules.losses.GWLossesBayesian.selection_mod": {"qualname": 3, "fullname": 6, "annotation": 0, "default_value": 0, "signature": 0, "bases": 0, "doc": 4}, "shimmer.modules.losses.GWLossesBayesian.domain_mods": {"qualname": 3, "fullname": 6, "annotation": 0, "default_value": 0, "signature": 0, "bases": 0, "doc": 9}, "shimmer.modules.losses.GWLossesBayesian.loss_coefs": {"qualname": 3, "fullname": 6, "annotation": 0, "default_value": 0, "signature": 0, "bases": 0, "doc": 6}, "shimmer.modules.losses.GWLossesBayesian.contrastive_fn": {"qualname": 3, "fullname": 6, "annotation": 0, "default_value": 0, "signature": 0, "bases": 0, "doc": 7}, "shimmer.modules.losses.GWLossesBayesian.use_normalized_constrastive": {"qualname": 4, "fullname": 7, "annotation": 0, "default_value": 0, "signature": 0, "bases": 0, "doc": 3}, "shimmer.modules.losses.GWLossesBayesian.contrastive_loss": {"qualname": 3, "fullname": 6, "annotation": 0, "default_value": 0, "signature": 95, "bases": 0, "doc": 45}, "shimmer.modules.losses.GWLossesBayesian.broadcast_loss": {"qualname": 3, "fullname": 6, "annotation": 0, "default_value": 0, "signature": 95, "bases": 0, "doc": 3}, "shimmer.modules.losses.GWLossesBayesian.step": {"qualname": 2, "fullname": 5, "annotation": 0, "default_value": 0, "signature": 154, "bases": 0, "doc": 64}, "shimmer.modules.contrastive_loss": {"qualname": 0, "fullname": 4, "annotation": 0, "default_value": 0, "signature": 0, "bases": 0, "doc": 6}, "shimmer.modules.contrastive_loss.ContrastiveLossType": {"qualname": 1, "fullname": 5, "annotation": 0, "default_value": 10, "signature": 0, "bases": 0, "doc": 21}, 
"shimmer.modules.contrastive_loss.ContrastiveLossBayesianType": {"qualname": 1, "fullname": 5, "annotation": 0, "default_value": 14, "signature": 0, "bases": 0, "doc": 29}, "shimmer.modules.contrastive_loss.info_nce": {"qualname": 2, "fullname": 6, "annotation": 0, "default_value": 0, "signature": 115, "bases": 0, "doc": 68}, "shimmer.modules.contrastive_loss.contrastive_loss": {"qualname": 2, "fullname": 6, "annotation": 0, "default_value": 0, "signature": 115, "bases": 0, "doc": 70}, "shimmer.modules.contrastive_loss.ContrastiveLoss": {"qualname": 1, "fullname": 5, "annotation": 0, "default_value": 0, "signature": 0, "bases": 5, "doc": 8}, "shimmer.modules.contrastive_loss.ContrastiveLoss.__init__": {"qualname": 3, "fullname": 7, "annotation": 0, "default_value": 0, "signature": 93, "bases": 0, "doc": 83}, "shimmer.modules.contrastive_loss.ContrastiveLoss.learn_logit_scale": {"qualname": 4, "fullname": 8, "annotation": 0, "default_value": 0, "signature": 0, "bases": 0, "doc": 3}, "shimmer.modules.contrastive_loss.ContrastiveLoss.reduction": {"qualname": 2, "fullname": 6, "annotation": 12, "default_value": 0, "signature": 0, "bases": 0, "doc": 3}, "shimmer.modules.contrastive_loss.ContrastiveLoss.forward": {"qualname": 2, "fullname": 6, "annotation": 0, "default_value": 0, "signature": 62, "bases": 0, "doc": 56}, "shimmer.dataset": {"qualname": 0, "fullname": 2, "annotation": 0, "default_value": 0, "signature": 0, "bases": 0, "doc": 3}, "shimmer.dataset.RepeatedDataset": {"qualname": 1, "fullname": 3, "annotation": 0, "default_value": 0, "signature": 0, "bases": 3, "doc": 46}, "shimmer.dataset.RepeatedDataset.__init__": {"qualname": 3, "fullname": 5, "annotation": 0, "default_value": 0, "signature": 57, "bases": 0, "doc": 63}, "shimmer.dataset.RepeatedDataset.dataset": {"qualname": 2, "fullname": 4, "annotation": 0, "default_value": 0, "signature": 0, "bases": 0, "doc": 3}, "shimmer.dataset.RepeatedDataset.dataset_size": {"qualname": 3, "fullname": 5, 
"annotation": 0, "default_value": 0, "signature": 0, "bases": 0, "doc": 3}, "shimmer.modules.vae": {"qualname": 0, "fullname": 3, "annotation": 0, "default_value": 0, "signature": 0, "bases": 0, "doc": 3}, "shimmer.modules.vae.reparameterize": {"qualname": 1, "fullname": 4, "annotation": 0, "default_value": 0, "signature": 44, "bases": 0, "doc": 65}, "shimmer.modules.vae.kl_divergence_loss": {"qualname": 3, "fullname": 6, "annotation": 0, "default_value": 0, "signature": 44, "bases": 0, "doc": 57}, "shimmer.modules.vae.gaussian_nll": {"qualname": 2, "fullname": 5, "annotation": 0, "default_value": 0, "signature": 63, "bases": 0, "doc": 69}, "shimmer.modules.vae.VAEEncoder": {"qualname": 1, "fullname": 4, "annotation": 0, "default_value": 0, "signature": 0, "bases": 7, "doc": 9}, "shimmer.modules.vae.VAEEncoder.forward": {"qualname": 2, "fullname": 5, "annotation": 0, "default_value": 0, "signature": 46, "bases": 0, "doc": 46}, "shimmer.modules.vae.VAEDecoder": {"qualname": 1, "fullname": 4, "annotation": 0, "default_value": 0, "signature": 0, "bases": 7, "doc": 9}, "shimmer.modules.vae.VAEDecoder.forward": {"qualname": 2, "fullname": 5, "annotation": 0, "default_value": 0, "signature": 29, "bases": 0, "doc": 42}, "shimmer.modules.vae.VAE": {"qualname": 1, "fullname": 4, "annotation": 0, "default_value": 0, "signature": 0, "bases": 5, "doc": 4}, "shimmer.modules.vae.VAE.__init__": {"qualname": 3, "fullname": 6, "annotation": 0, "default_value": 0, "signature": 74, "bases": 0, "doc": 53}, "shimmer.modules.vae.VAE.beta": {"qualname": 2, "fullname": 5, "annotation": 0, "default_value": 0, "signature": 0, "bases": 0, "doc": 7}, "shimmer.modules.vae.VAE.encoder": {"qualname": 2, "fullname": 5, "annotation": 0, "default_value": 0, "signature": 0, "bases": 0, "doc": 4}, "shimmer.modules.vae.VAE.decoder": {"qualname": 2, "fullname": 5, "annotation": 0, "default_value": 0, "signature": 0, "bases": 0, "doc": 4}, "shimmer.modules.vae.VAE.encode": {"qualname": 2, "fullname": 5, 
"annotation": 0, "default_value": 0, "signature": 29, "bases": 0, "doc": 49}, "shimmer.modules.vae.VAE.decode": {"qualname": 2, "fullname": 5, "annotation": 0, "default_value": 0, "signature": 29, "bases": 0, "doc": 49}, "shimmer.modules.vae.VAE.forward": {"qualname": 2, "fullname": 5, "annotation": 0, "default_value": 0, "signature": 61, "bases": 0, "doc": 63}, "shimmer.modules.utils": {"qualname": 0, "fullname": 3, "annotation": 0, "default_value": 0, "signature": 0, "bases": 0, "doc": 3}, "shimmer.modules.utils.translation": {"qualname": 1, "fullname": 4, "annotation": 0, "default_value": 0, "signature": 118, "bases": 0, "doc": 96}, "shimmer.modules.utils.cycle": {"qualname": 1, "fullname": 4, "annotation": 0, "default_value": 0, "signature": 130, "bases": 0, "doc": 111}, "shimmer.modules.utils.batch_demi_cycles": {"qualname": 3, "fullname": 6, "annotation": 0, "default_value": 0, "signature": 144, "bases": 0, "doc": 81}, "shimmer.modules.utils.batch_cycles": {"qualname": 2, "fullname": 5, "annotation": 0, "default_value": 0, "signature": 182, "bases": 0, "doc": 115}, "shimmer.modules.utils.batch_translations": {"qualname": 2, "fullname": 5, "annotation": 0, "default_value": 0, "signature": 154, "bases": 0, "doc": 88}, "shimmer.utils": {"qualname": 0, "fullname": 2, "annotation": 0, "default_value": 0, "signature": 0, "bases": 0, "doc": 3}, "shimmer.utils.MIGRATION_DIR": {"qualname": 2, "fullname": 4, "annotation": 0, "default_value": 6, "signature": 0, "bases": 0, "doc": 3}, "shimmer.utils.group_batch_size": {"qualname": 3, "fullname": 5, "annotation": 0, "default_value": 0, "signature": 46, "bases": 0, "doc": 3}, "shimmer.utils.groups_batch_size": {"qualname": 3, "fullname": 5, "annotation": 0, "default_value": 0, "signature": 72, "bases": 0, "doc": 46}, "shimmer.utils.groups_device": {"qualname": 2, "fullname": 4, "annotation": 0, "default_value": 0, "signature": 72, "bases": 0, "doc": 46}, "shimmer.utils.group_device": {"qualname": 2, "fullname": 4, 
"annotation": 0, "default_value": 0, "signature": 51, "bases": 0, "doc": 3}, "shimmer.utils.migrate_model": {"qualname": 2, "fullname": 4, "annotation": 0, "default_value": 0, "signature": 37, "bases": 0, "doc": 67}, "shimmer.utils.SaveMigrations": {"qualname": 1, "fullname": 3, "annotation": 0, "default_value": 0, "signature": 0, "bases": 5, "doc": 23}, "shimmer.utils.SaveMigrations.migrations": {"qualname": 2, "fullname": 4, "annotation": 0, "default_value": 0, "signature": 0, "bases": 0, "doc": 3}, "shimmer.utils.SaveMigrations.on_save_checkpoint": {"qualname": 4, "fullname": 6, "annotation": 0, "default_value": 0, "signature": 103, "bases": 0, "doc": 74}, "shimmer.cli.ckpt_migration": {"qualname": 0, "fullname": 4, "annotation": 0, "default_value": 0, "signature": 0, "bases": 0, "doc": 3}, "shimmer.cli.ckpt_migration.migrate_ckpt": {"qualname": 2, "fullname": 6, "annotation": 0, "default_value": 7, "signature": 0, "bases": 0, "doc": 72}}, "length": 232, "save": true}, "index": {"qualname": {"root": {"docs": {"shimmer.modules.global_workspace.GlobalWorkspace2Domains.__init__": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspace.__init__": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBayesian.__init__": {"tf": 1}, "shimmer.modules.domain.LossOutput.__init__": {"tf": 1}, "shimmer.modules.domain.DomainModule.__init__": {"tf": 1}, "shimmer.modules.gw_module.GWDecoder.__init__": {"tf": 1}, "shimmer.modules.gw_module.GWEncoder.__init__": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBase.__init__": {"tf": 1}, "shimmer.modules.gw_module.GWModule.__init__": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBayesian.__init__": {"tf": 1}, "shimmer.modules.selection.KQFixedQSelection.__init__": {"tf": 1}, "shimmer.modules.selection.RandomSelection.__init__": {"tf": 1}, "shimmer.modules.selection.DynamicQueryAttention.__init__": {"tf": 1}, "shimmer.modules.losses.GWLosses2Domains.__init__": {"tf": 1}, "shimmer.modules.losses.GWLosses.__init__": {"tf": 
1}, "shimmer.modules.losses.GWLossesBayesian.__init__": {"tf": 1}, "shimmer.modules.contrastive_loss.ContrastiveLoss.__init__": {"tf": 1}, "shimmer.dataset.RepeatedDataset.__init__": {"tf": 1}, "shimmer.modules.vae.VAE.__init__": {"tf": 1}}, "df": 19, "r": {"docs": {}, "df": 0, "a": {"docs": {}, "df": 0, "w": {"docs": {}, "df": 0, "d": {"docs": {}, "df": 0, "o": {"docs": {}, "df": 0, "m": {"docs": {}, "df": 0, "a": {"docs": {}, "df": 0, "i": {"docs": {}, "df": 0, "n": {"docs": {}, "df": 0, "g": {"docs": {}, "df": 0, "r": {"docs": {}, "df": 0, "o": {"docs": {}, "df": 0, "u": {"docs": {}, "df": 0, "p": {"docs": {}, "df": 0, "t": {"docs": {"shimmer.types.RawDomainGroupT": {"tf": 1}}, "df": 1}, "d": {"docs": {}, "df": 0, "t": {"docs": {"shimmer.types.RawDomainGroupDT": {"tf": 1}}, "df": 1}}, "s": {"docs": {}, "df": 0, "t": {"docs": {"shimmer.types.RawDomainGroupsT": {"tf": 1}}, "df": 1}, "d": {"docs": {}, "df": 0, "t": {"docs": {"shimmer.types.RawDomainGroupsDT": {"tf": 1}}, "df": 1}}}}}}}}}}}}}}}, "n": {"docs": {}, "df": 0, "d": {"docs": {}, "df": 0, "o": {"docs": {}, "df": 0, "m": {"docs": {}, "df": 0, "s": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "l": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "c": {"docs": {}, "df": 0, "t": {"docs": {}, "df": 0, "i": {"docs": {}, "df": 0, "o": {"docs": {}, "df": 0, "n": {"docs": {"shimmer.modules.selection.RandomSelection": {"tf": 1}, "shimmer.modules.selection.RandomSelection.__init__": {"tf": 1}, "shimmer.modules.selection.RandomSelection.temperature": {"tf": 1}, "shimmer.modules.selection.RandomSelection.forward": {"tf": 1}}, "df": 4}}}}}}}}}}}}}}, "e": {"docs": {}, "df": 0, "d": {"docs": {}, "df": 0, "u": {"docs": {}, "df": 0, "c": {"docs": {}, "df": 0, "t": {"docs": {}, "df": 0, "i": {"docs": {}, "df": 0, "o": {"docs": {}, "df": 0, "n": {"docs": {"shimmer.modules.contrastive_loss.ContrastiveLoss.reduction": {"tf": 1}}, "df": 1}}}}}}}, "p": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "a": {"docs": {}, "df": 
0, "t": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "d": {"docs": {}, "df": 0, "d": {"docs": {}, "df": 0, "a": {"docs": {}, "df": 0, "t": {"docs": {}, "df": 0, "a": {"docs": {}, "df": 0, "s": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "t": {"docs": {"shimmer.dataset.RepeatedDataset": {"tf": 1}, "shimmer.dataset.RepeatedDataset.__init__": {"tf": 1}, "shimmer.dataset.RepeatedDataset.dataset": {"tf": 1}, "shimmer.dataset.RepeatedDataset.dataset_size": {"tf": 1}}, "df": 4}}}}}}}}}}}}, "a": {"docs": {}, "df": 0, "r": {"docs": {}, "df": 0, "a": {"docs": {}, "df": 0, "m": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "t": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "r": {"docs": {}, "df": 0, "i": {"docs": {}, "df": 0, "z": {"docs": {}, "df": 0, "e": {"docs": {"shimmer.modules.vae.reparameterize": {"tf": 1}}, "df": 1}}}}}}}}}}}}}}, "l": {"docs": {}, "df": 0, "a": {"docs": {}, "df": 0, "t": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "n": {"docs": {}, "df": 0, "t": {"docs": {"shimmer.modules.domain.DomainModule.latent_dim": {"tf": 1}}, "df": 1, "s": {"docs": {}, "df": 0, "d": {"docs": {}, "df": 0, "o": {"docs": {}, "df": 0, "m": {"docs": {}, "df": 0, "a": {"docs": {}, "df": 0, "i": {"docs": {}, "df": 0, "n": {"docs": {}, "df": 0, "g": {"docs": {}, "df": 0, "r": {"docs": {}, "df": 0, "o": {"docs": {}, "df": 0, "u": {"docs": {}, "df": 0, "p": {"docs": {}, "df": 0, "t": {"docs": {"shimmer.types.LatentsDomainGroupT": {"tf": 1}}, "df": 1}, "d": {"docs": {}, "df": 0, "t": {"docs": {"shimmer.types.LatentsDomainGroupDT": {"tf": 1}}, "df": 1}}, "s": {"docs": {}, "df": 0, "t": {"docs": {"shimmer.types.LatentsDomainGroupsT": {"tf": 1}}, "df": 1}, "d": {"docs": {}, "df": 0, "t": {"docs": {"shimmer.types.LatentsDomainGroupsDT": {"tf": 1}}, "df": 1}}}}}}}}}}}}}}}}}}}, "y": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "r": {"docs": {"shimmer.modules.selection.KQFixedQSelection.query_layer": {"tf": 1}, 
"shimmer.modules.selection.DynamicQueryAttention.query_layer": {"tf": 1}}, "df": 2, "s": {"docs": {"shimmer.modules.gw_module.get_n_layers": {"tf": 1}, "shimmer.modules.gw_module.GWDecoder.n_layers": {"tf": 1}, "shimmer.modules.selection.KQFixedQSelection.key_layers": {"tf": 1}, "shimmer.modules.selection.DynamicQueryAttention.key_layers": {"tf": 1}}, "df": 4}}}}}, "r": {"docs": {"shimmer.modules.global_workspace.SchedulerArgs.max_lr": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.optim_lr": {"tf": 1}}, "df": 2}, "o": {"docs": {}, "df": 0, "s": {"docs": {}, "df": 0, "s": {"docs": {"shimmer.modules.global_workspace.GlobalWorkspaceBase.loss_mod": {"tf": 1}, "shimmer.modules.domain.LossOutput.loss": {"tf": 1}, "shimmer.modules.domain.DomainModule.compute_loss": {"tf": 1}, "shimmer.modules.domain.DomainModule.compute_dcy_loss": {"tf": 1}, "shimmer.modules.domain.DomainModule.compute_cy_loss": {"tf": 1}, "shimmer.modules.domain.DomainModule.compute_tr_loss": {"tf": 1}, "shimmer.modules.domain.DomainModule.compute_broadcast_loss": {"tf": 1}, "shimmer.modules.losses.demi_cycle_loss": {"tf": 1}, "shimmer.modules.losses.cycle_loss": {"tf": 1}, "shimmer.modules.losses.translation_loss": {"tf": 1}, "shimmer.modules.losses.contrastive_loss": {"tf": 1}, "shimmer.modules.losses.contrastive_loss_bayesian": {"tf": 1}, "shimmer.modules.losses.GWLosses2Domains.loss_coefs": {"tf": 1}, "shimmer.modules.losses.GWLosses2Domains.demi_cycle_loss": {"tf": 1}, "shimmer.modules.losses.GWLosses2Domains.cycle_loss": {"tf": 1}, "shimmer.modules.losses.GWLosses2Domains.translation_loss": {"tf": 1}, "shimmer.modules.losses.GWLosses2Domains.contrastive_loss": {"tf": 1}, "shimmer.modules.losses.broadcast_loss": {"tf": 1}, "shimmer.modules.losses.GWLosses.loss_coefs": {"tf": 1}, "shimmer.modules.losses.GWLosses.contrastive_loss": {"tf": 1}, "shimmer.modules.losses.GWLosses.broadcast_loss": {"tf": 1}, "shimmer.modules.losses.GWLossesBayesian.loss_coefs": {"tf": 1}, 
"shimmer.modules.losses.GWLossesBayesian.contrastive_loss": {"tf": 1}, "shimmer.modules.losses.GWLossesBayesian.broadcast_loss": {"tf": 1}, "shimmer.modules.contrastive_loss.contrastive_loss": {"tf": 1}, "shimmer.modules.vae.kl_divergence_loss": {"tf": 1}}, "df": 26, "o": {"docs": {}, "df": 0, "u": {"docs": {}, "df": 0, "t": {"docs": {}, "df": 0, "p": {"docs": {}, "df": 0, "u": {"docs": {}, "df": 0, "t": {"docs": {"shimmer.modules.domain.LossOutput": {"tf": 1}, "shimmer.modules.domain.LossOutput.__init__": {"tf": 1}, "shimmer.modules.domain.LossOutput.loss": {"tf": 1}, "shimmer.modules.domain.LossOutput.metrics": {"tf": 1}, "shimmer.modules.domain.LossOutput.all": {"tf": 1}}, "df": 5}}}}}}, "c": {"docs": {}, "df": 0, "o": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "f": {"docs": {}, "df": 0, "s": {"docs": {"shimmer.modules.losses.LossCoefs": {"tf": 1}, "shimmer.modules.losses.LossCoefs.demi_cycles": {"tf": 1}, "shimmer.modules.losses.LossCoefs.cycles": {"tf": 1}, "shimmer.modules.losses.LossCoefs.translations": {"tf": 1}, "shimmer.modules.losses.LossCoefs.contrastives": {"tf": 1}}, "df": 5}}}}}}}, "g": {"docs": {}, "df": 0, "i": {"docs": {}, "df": 0, "t": {"docs": {"shimmer.modules.contrastive_loss.ContrastiveLoss.learn_logit_scale": {"tf": 1}}, "df": 1}}}}, "e": {"docs": {}, "df": 0, "a": {"docs": {}, "df": 0, "r": {"docs": {}, "df": 0, "n": {"docs": {"shimmer.modules.contrastive_loss.ContrastiveLoss.learn_logit_scale": {"tf": 1}}, "df": 1}}}}}, "m": {"docs": {}, "df": 0, "o": {"docs": {}, "df": 0, "d": {"docs": {"shimmer.modules.global_workspace.GlobalWorkspaceBase.gw_mod": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.selection_mod": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.loss_mod": {"tf": 1}, "shimmer.modules.losses.GWLosses2Domains.gw_mod": {"tf": 1}, "shimmer.modules.losses.GWLosses2Domains.selection_mod": {"tf": 1}, "shimmer.modules.losses.GWLosses.gw_mod": {"tf": 1}, 
"shimmer.modules.losses.GWLosses.selection_mod": {"tf": 1}, "shimmer.modules.losses.GWLossesBayesian.gw_mod": {"tf": 1}, "shimmer.modules.losses.GWLossesBayesian.selection_mod": {"tf": 1}}, "df": 9, "e": {"docs": {}, "df": 0, "l": {"docs": {"shimmer.utils.migrate_model": {"tf": 1}}, "df": 1, "m": {"docs": {}, "df": 0, "o": {"docs": {}, "df": 0, "d": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "t": {"docs": {"shimmer.types.ModelModeT": {"tf": 1}}, "df": 1}}}}}}}, "s": {"docs": {"shimmer.modules.global_workspace.GlobalWorkspaceBase.domain_mods": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBase.domain_mods": {"tf": 1}, "shimmer.modules.losses.GWLosses2Domains.domain_mods": {"tf": 1}, "shimmer.modules.losses.GWLosses.domain_mods": {"tf": 1}, "shimmer.modules.losses.GWLossesBayesian.domain_mods": {"tf": 1}}, "df": 5}, "u": {"docs": {}, "df": 0, "l": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "s": {"docs": {"shimmer.modules.global_workspace.freeze_domain_modules": {"tf": 1}}, "df": 1}}}}}}, "a": {"docs": {}, "df": 0, "x": {"docs": {"shimmer.modules.global_workspace.SchedulerArgs.max_lr": {"tf": 1}}, "df": 1}}, "e": {"docs": {}, "df": 0, "t": {"docs": {}, "df": 0, "r": {"docs": {}, "df": 0, "i": {"docs": {}, "df": 0, "c": {"docs": {}, "df": 0, "s": {"docs": {"shimmer.modules.domain.LossOutput.metrics": {"tf": 1}}, "df": 1}}}}}}, "i": {"docs": {}, "df": 0, "g": {"docs": {}, "df": 0, "r": {"docs": {}, "df": 0, "a": {"docs": {}, "df": 0, "t": {"docs": {}, "df": 0, "i": {"docs": {}, "df": 0, "o": {"docs": {}, "df": 0, "n": {"docs": {"shimmer.utils.MIGRATION_DIR": {"tf": 1}}, "df": 1, "s": {"docs": {"shimmer.utils.SaveMigrations.migrations": {"tf": 1}}, "df": 1}}}}, "e": {"docs": {"shimmer.utils.migrate_model": {"tf": 1}, "shimmer.cli.ckpt_migration.migrate_ckpt": {"tf": 1}}, "df": 2}}}}}}}, "s": {"docs": {}, "df": 0, "c": {"docs": {}, "df": 0, "h": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "d": {"docs": {}, "df": 0, "u": {"docs": {}, "df": 0, "l": {"docs": 
{}, "df": 0, "e": {"docs": {}, "df": 0, "r": {"docs": {"shimmer.modules.global_workspace.GlobalWorkspaceBase.scheduler_args": {"tf": 1}}, "df": 1, "a": {"docs": {}, "df": 0, "r": {"docs": {}, "df": 0, "g": {"docs": {}, "df": 0, "s": {"docs": {"shimmer.modules.global_workspace.SchedulerArgs": {"tf": 1}, "shimmer.modules.global_workspace.SchedulerArgs.max_lr": {"tf": 1}, "shimmer.modules.global_workspace.SchedulerArgs.total_steps": {"tf": 1}}, "df": 3}}}}}}}}}}}, "o": {"docs": {}, "df": 0, "r": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "s": {"docs": {"shimmer.modules.gw_module.compute_fusion_scores": {"tf": 1}}, "df": 1}}}}, "a": {"docs": {}, "df": 0, "l": {"docs": {}, "df": 0, "e": {"docs": {"shimmer.modules.contrastive_loss.ContrastiveLoss.learn_logit_scale": {"tf": 1}}, "df": 1}}}}, "t": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "p": {"docs": {"shimmer.modules.global_workspace.GlobalWorkspaceBase.generic_step": {"tf": 1}, "shimmer.modules.losses.GWLossesBase.step": {"tf": 1}, "shimmer.modules.losses.GWLosses2Domains.step": {"tf": 1}, "shimmer.modules.losses.GWLosses.step": {"tf": 1}, "shimmer.modules.losses.GWLossesBayesian.step": {"tf": 1}}, "df": 5, "s": {"docs": {"shimmer.modules.global_workspace.SchedulerArgs.total_steps": {"tf": 1}}, "df": 1}}}, "a": {"docs": {}, "df": 0, "t": {"docs": {}, "df": 0, "e": {"docs": {"shimmer.modules.selection.SelectionBase.update_gw_state": {"tf": 1}}, "df": 1, "s": {"docs": {"shimmer.modules.global_workspace.GWPredictionsBase.states": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.batch_gw_states": {"tf": 1}, "shimmer.modules.global_workspace.GWPredictions.states": {"tf": 1}}, "df": 3}}}}}, "e": {"docs": {}, "df": 0, "l": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "c": {"docs": {}, "df": 0, "t": {"docs": {}, "df": 0, "i": {"docs": {}, "df": 0, "o": {"docs": {}, "df": 0, "n": {"docs": {"shimmer.modules.global_workspace.GlobalWorkspaceBase.selection_mod": {"tf": 1}, 
"shimmer.modules.gw_module.GWModuleBayesian.sensitivity_selection": {"tf": 1}, "shimmer.modules.losses.GWLosses2Domains.selection_mod": {"tf": 1}, "shimmer.modules.losses.GWLosses.selection_mod": {"tf": 1}, "shimmer.modules.losses.GWLossesBayesian.selection_mod": {"tf": 1}}, "df": 5, "b": {"docs": {}, "df": 0, "a": {"docs": {}, "df": 0, "s": {"docs": {}, "df": 0, "e": {"docs": {"shimmer.modules.selection.SelectionBase": {"tf": 1}, "shimmer.modules.selection.SelectionBase.update_gw_state": {"tf": 1}, "shimmer.modules.selection.SelectionBase.forward": {"tf": 1}}, "df": 3}}}}}}}}}}}, "n": {"docs": {}, "df": 0, "s": {"docs": {}, "df": 0, "i": {"docs": {}, "df": 0, "t": {"docs": {}, "df": 0, "i": {"docs": {}, "df": 0, "v": {"docs": {}, "df": 0, "i": {"docs": {}, "df": 0, "t": {"docs": {}, "df": 0, "y": {"docs": {"shimmer.modules.gw_module.GWModuleBayesian.sensitivity_selection": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBayesian.sensitivity_precision": {"tf": 1}}, "df": 2}}}}}}}}}}, "o": {"docs": {}, "df": 0, "f": {"docs": {}, "df": 0, "t": {"docs": {}, "df": 0, "m": {"docs": {}, "df": 0, "a": {"docs": {}, "df": 0, "x": {"docs": {"shimmer.modules.gw_module.GWModuleBayesian.precision_softmax_temp": {"tf": 1}}, "df": 1}}}}}}, "i": {"docs": {}, "df": 0, "n": {"docs": {}, "df": 0, "g": {"docs": {}, "df": 0, "l": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "d": {"docs": {}, "df": 0, "o": {"docs": {}, "df": 0, "m": {"docs": {}, "df": 0, "a": {"docs": {}, "df": 0, "i": {"docs": {}, "df": 0, "n": {"docs": {}, "df": 0, "s": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "l": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "c": {"docs": {}, "df": 0, "t": {"docs": {}, "df": 0, "i": {"docs": {}, "df": 0, "o": {"docs": {}, "df": 0, "n": {"docs": {"shimmer.modules.selection.SingleDomainSelection": {"tf": 1}, "shimmer.modules.selection.SingleDomainSelection.forward": {"tf": 1}}, "df": 2}}}}}}}}}}}}}}}}}}}, "z": {"docs": {}, "df": 0, "e": {"docs": 
{"shimmer.modules.selection.KQFixedQSelection.head_size": {"tf": 1}, "shimmer.modules.selection.DynamicQueryAttention.head_size": {"tf": 1}, "shimmer.dataset.RepeatedDataset.dataset_size": {"tf": 1}, "shimmer.utils.group_batch_size": {"tf": 1}, "shimmer.utils.groups_batch_size": {"tf": 1}}, "df": 5}}}, "a": {"docs": {}, "df": 0, "v": {"docs": {}, "df": 0, "e": {"docs": {"shimmer.utils.SaveMigrations.on_save_checkpoint": {"tf": 1}}, "df": 1, "m": {"docs": {}, "df": 0, "i": {"docs": {}, "df": 0, "g": {"docs": {}, "df": 0, "r": {"docs": {}, "df": 0, "a": {"docs": {}, "df": 0, "t": {"docs": {}, "df": 0, "i": {"docs": {}, "df": 0, "o": {"docs": {}, "df": 0, "n": {"docs": {}, "df": 0, "s": {"docs": {"shimmer.utils.SaveMigrations": {"tf": 1}, "shimmer.utils.SaveMigrations.migrations": {"tf": 1}, "shimmer.utils.SaveMigrations.on_save_checkpoint": {"tf": 1}}, "df": 3}}}}}}}}}}}}}}, "t": {"docs": {}, "df": 0, "o": {"docs": {}, "df": 0, "t": {"docs": {}, "df": 0, "a": {"docs": {}, "df": 0, "l": {"docs": {"shimmer.modules.global_workspace.SchedulerArgs.total_steps": {"tf": 1}}, "df": 1}}}}, "r": {"docs": {"shimmer.modules.domain.DomainModule.compute_tr_loss": {"tf": 1}}, "df": 1, "a": {"docs": {}, "df": 0, "n": {"docs": {}, "df": 0, "s": {"docs": {}, "df": 0, "l": {"docs": {}, "df": 0, "a": {"docs": {}, "df": 0, "t": {"docs": {}, "df": 0, "i": {"docs": {}, "df": 0, "o": {"docs": {}, "df": 0, "n": {"docs": {"shimmer.modules.losses.translation_loss": {"tf": 1}, "shimmer.modules.losses.GWLosses2Domains.translation_loss": {"tf": 1}, "shimmer.modules.utils.translation": {"tf": 1}}, "df": 3, "s": {"docs": {"shimmer.modules.global_workspace.GWPredictions.translations": {"tf": 1}, "shimmer.modules.losses.LossCoefs.translations": {"tf": 1}, "shimmer.modules.losses.BroadcastLossCoefs.translations": {"tf": 1}, "shimmer.modules.utils.batch_translations": {"tf": 1}}, "df": 4}}}}}}}}}}}, "e": {"docs": {}, "df": 0, "m": {"docs": {}, "df": 0, "p": {"docs": 
{"shimmer.modules.gw_module.GWModuleBayesian.precision_softmax_temp": {"tf": 1}}, "df": 1, "e": {"docs": {}, "df": 0, "r": {"docs": {}, "df": 0, "a": {"docs": {}, "df": 0, "t": {"docs": {}, "df": 0, "u": {"docs": {}, "df": 0, "r": {"docs": {}, "df": 0, "e": {"docs": {"shimmer.modules.selection.RandomSelection.temperature": {"tf": 1}}, "df": 1}}}}}}}}}}}, "g": {"docs": {}, "df": 0, "w": {"docs": {"shimmer.modules.global_workspace.GlobalWorkspaceBase.gw_mod": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.batch_gw_states": {"tf": 1}, "shimmer.modules.gw_module.GWModule.gw_encoders": {"tf": 1}, "shimmer.modules.gw_module.GWModule.gw_decoders": {"tf": 1}, "shimmer.modules.selection.SelectionBase.update_gw_state": {"tf": 1}, "shimmer.modules.losses.GWLosses2Domains.gw_mod": {"tf": 1}, "shimmer.modules.losses.GWLosses.gw_mod": {"tf": 1}, "shimmer.modules.losses.GWLossesBayesian.gw_mod": {"tf": 1}}, "df": 8, "p": {"docs": {}, "df": 0, "r": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "d": {"docs": {}, "df": 0, "i": {"docs": {}, "df": 0, "c": {"docs": {}, "df": 0, "t": {"docs": {}, "df": 0, "i": {"docs": {}, "df": 0, "o": {"docs": {}, "df": 0, "n": {"docs": {}, "df": 0, "s": {"docs": {"shimmer.modules.global_workspace.GWPredictions": {"tf": 1}, "shimmer.modules.global_workspace.GWPredictions.demi_cycles": {"tf": 1}, "shimmer.modules.global_workspace.GWPredictions.cycles": {"tf": 1}, "shimmer.modules.global_workspace.GWPredictions.translations": {"tf": 1}, "shimmer.modules.global_workspace.GWPredictions.states": {"tf": 1}}, "df": 5, "b": {"docs": {}, "df": 0, "a": {"docs": {}, "df": 0, "s": {"docs": {}, "df": 0, "e": {"docs": {"shimmer.modules.global_workspace.GWPredictionsBase": {"tf": 1}, "shimmer.modules.global_workspace.GWPredictionsBase.states": {"tf": 1}}, "df": 2}}}}}}}}}}}}}}}, "d": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "c": {"docs": {}, "df": 0, "o": {"docs": {}, "df": 0, "d": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "r": 
{"docs": {"shimmer.modules.gw_module.GWDecoder": {"tf": 1}, "shimmer.modules.gw_module.GWDecoder.__init__": {"tf": 1}, "shimmer.modules.gw_module.GWDecoder.in_dim": {"tf": 1}, "shimmer.modules.gw_module.GWDecoder.hidden_dim": {"tf": 1}, "shimmer.modules.gw_module.GWDecoder.out_dim": {"tf": 1}, "shimmer.modules.gw_module.GWDecoder.n_layers": {"tf": 1}}, "df": 6}}}}}}}, "e": {"docs": {}, "df": 0, "n": {"docs": {}, "df": 0, "c": {"docs": {}, "df": 0, "o": {"docs": {}, "df": 0, "d": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "r": {"docs": {"shimmer.modules.gw_module.GWEncoder": {"tf": 1}, "shimmer.modules.gw_module.GWEncoder.__init__": {"tf": 1}, "shimmer.modules.gw_module.GWEncoder.forward": {"tf": 1}}, "df": 3, "l": {"docs": {}, "df": 0, "i": {"docs": {}, "df": 0, "n": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "a": {"docs": {}, "df": 0, "r": {"docs": {"shimmer.modules.gw_module.GWEncoderLinear": {"tf": 1}, "shimmer.modules.gw_module.GWEncoderLinear.forward": {"tf": 1}}, "df": 2}}}}}}}}}}}}}, "m": {"docs": {}, "df": 0, "o": {"docs": {}, "df": 0, "d": {"docs": {}, "df": 0, "u": {"docs": {}, "df": 0, "l": {"docs": {}, "df": 0, "e": {"docs": {"shimmer.modules.gw_module.GWModule": {"tf": 1}, "shimmer.modules.gw_module.GWModule.__init__": {"tf": 1}, "shimmer.modules.gw_module.GWModule.gw_encoders": {"tf": 1}, "shimmer.modules.gw_module.GWModule.gw_decoders": {"tf": 1}, "shimmer.modules.gw_module.GWModule.fuse": {"tf": 1}, "shimmer.modules.gw_module.GWModule.encode": {"tf": 1}, "shimmer.modules.gw_module.GWModule.decode": {"tf": 1}}, "df": 7, "b": {"docs": {}, "df": 0, "a": {"docs": {}, "df": 0, "s": {"docs": {}, "df": 0, "e": {"docs": {"shimmer.modules.gw_module.GWModuleBase": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBase.__init__": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBase.domain_mods": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBase.workspace_dim": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBase.fuse": {"tf": 1}, 
"shimmer.modules.gw_module.GWModuleBase.encode": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBase.encode_and_fuse": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBase.decode": {"tf": 1}}, "df": 8}}, "y": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "s": {"docs": {}, "df": 0, "i": {"docs": {}, "df": 0, "a": {"docs": {}, "df": 0, "n": {"docs": {"shimmer.modules.gw_module.GWModuleBayesian": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBayesian.__init__": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBayesian.precisions": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBayesian.sensitivity_selection": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBayesian.sensitivity_precision": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBayesian.precision_softmax_temp": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBayesian.get_precision": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBayesian.fuse": {"tf": 1}}, "df": 8}}}}}}}}}}}}}}, "l": {"docs": {}, "df": 0, "o": {"docs": {}, "df": 0, "s": {"docs": {}, "df": 0, "s": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "s": {"2": {"docs": {}, "df": 0, "d": {"docs": {}, "df": 0, "o": {"docs": {}, "df": 0, "m": {"docs": {}, "df": 0, "a": {"docs": {}, "df": 0, "i": {"docs": {}, "df": 0, "n": {"docs": {}, "df": 0, "s": {"docs": {"shimmer.modules.losses.GWLosses2Domains": {"tf": 1}, "shimmer.modules.losses.GWLosses2Domains.__init__": {"tf": 1}, "shimmer.modules.losses.GWLosses2Domains.gw_mod": {"tf": 1}, "shimmer.modules.losses.GWLosses2Domains.selection_mod": {"tf": 1}, "shimmer.modules.losses.GWLosses2Domains.domain_mods": {"tf": 1}, "shimmer.modules.losses.GWLosses2Domains.loss_coefs": {"tf": 1}, "shimmer.modules.losses.GWLosses2Domains.contrastive_fn": {"tf": 1}, "shimmer.modules.losses.GWLosses2Domains.demi_cycle_loss": {"tf": 1}, "shimmer.modules.losses.GWLosses2Domains.cycle_loss": {"tf": 1}, "shimmer.modules.losses.GWLosses2Domains.translation_loss": {"tf": 1}, "shimmer.modules.losses.GWLosses2Domains.contrastive_loss": 
{"tf": 1}, "shimmer.modules.losses.GWLosses2Domains.step": {"tf": 1}}, "df": 12}}}}}}}}, "docs": {"shimmer.modules.losses.GWLosses": {"tf": 1}, "shimmer.modules.losses.GWLosses.__init__": {"tf": 1}, "shimmer.modules.losses.GWLosses.gw_mod": {"tf": 1}, "shimmer.modules.losses.GWLosses.selection_mod": {"tf": 1}, "shimmer.modules.losses.GWLosses.domain_mods": {"tf": 1}, "shimmer.modules.losses.GWLosses.loss_coefs": {"tf": 1}, "shimmer.modules.losses.GWLosses.contrastive_fn": {"tf": 1}, "shimmer.modules.losses.GWLosses.contrastive_loss": {"tf": 1}, "shimmer.modules.losses.GWLosses.broadcast_loss": {"tf": 1}, "shimmer.modules.losses.GWLosses.step": {"tf": 1}}, "df": 10, "b": {"docs": {}, "df": 0, "a": {"docs": {}, "df": 0, "s": {"docs": {}, "df": 0, "e": {"docs": {"shimmer.modules.losses.GWLossesBase": {"tf": 1}, "shimmer.modules.losses.GWLossesBase.step": {"tf": 1}}, "df": 2}}, "y": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "s": {"docs": {}, "df": 0, "i": {"docs": {}, "df": 0, "a": {"docs": {}, "df": 0, "n": {"docs": {"shimmer.modules.losses.GWLossesBayesian": {"tf": 1}, "shimmer.modules.losses.GWLossesBayesian.__init__": {"tf": 1}, "shimmer.modules.losses.GWLossesBayesian.gw_mod": {"tf": 1}, "shimmer.modules.losses.GWLossesBayesian.selection_mod": {"tf": 1}, "shimmer.modules.losses.GWLossesBayesian.domain_mods": {"tf": 1}, "shimmer.modules.losses.GWLossesBayesian.loss_coefs": {"tf": 1}, "shimmer.modules.losses.GWLossesBayesian.contrastive_fn": {"tf": 1}, "shimmer.modules.losses.GWLossesBayesian.use_normalized_constrastive": {"tf": 1}, "shimmer.modules.losses.GWLossesBayesian.contrastive_loss": {"tf": 1}, "shimmer.modules.losses.GWLossesBayesian.broadcast_loss": {"tf": 1}, "shimmer.modules.losses.GWLossesBayesian.step": {"tf": 1}}, "df": 11}}}}}}}}}}}}}}}, "l": {"docs": {}, "df": 0, "o": {"docs": {}, "df": 0, "b": {"docs": {}, "df": 0, "a": {"docs": {}, "df": 0, "l": {"docs": {"shimmer.modules.global_workspace.pretrained_global_workspace": {"tf": 1}}, "df": 1, 
"w": {"docs": {}, "df": 0, "o": {"docs": {}, "df": 0, "r": {"docs": {}, "df": 0, "k": {"docs": {}, "df": 0, "s": {"docs": {}, "df": 0, "p": {"docs": {}, "df": 0, "a": {"docs": {}, "df": 0, "c": {"docs": {}, "df": 0, "e": {"2": {"docs": {}, "df": 0, "d": {"docs": {}, "df": 0, "o": {"docs": {}, "df": 0, "m": {"docs": {}, "df": 0, "a": {"docs": {}, "df": 0, "i": {"docs": {}, "df": 0, "n": {"docs": {}, "df": 0, "s": {"docs": {"shimmer.modules.global_workspace.GlobalWorkspace2Domains": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspace2Domains.__init__": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspace2Domains.forward": {"tf": 1}}, "df": 3}}}}}}}}, "docs": {"shimmer.modules.global_workspace.GlobalWorkspace": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspace.__init__": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspace.forward": {"tf": 1}}, "df": 3, "b": {"docs": {}, "df": 0, "a": {"docs": {}, "df": 0, "s": {"docs": {}, "df": 0, "e": {"docs": {"shimmer.modules.global_workspace.GlobalWorkspaceBase": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.gw_mod": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.selection_mod": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.loss_mod": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.optim_lr": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.optim_weight_decay": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.scheduler_args": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.domain_mods": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.workspace_dim": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.encode_and_fuse": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.encode": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.fuse": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.decode": {"tf": 1}, 
"shimmer.modules.global_workspace.GlobalWorkspaceBase.batch_gw_states": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.encode_domain": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.encode_domains": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.decode_domain": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.decode_domains": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.generic_step": {"tf": 1}}, "df": 19}}, "y": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "s": {"docs": {}, "df": 0, "i": {"docs": {}, "df": 0, "a": {"docs": {}, "df": 0, "n": {"docs": {"shimmer.modules.global_workspace.GlobalWorkspaceBayesian": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBayesian.__init__": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBayesian.forward": {"tf": 1}}, "df": 3}}}}}}}}}}}}}}}}}}}}}}, "e": {"docs": {}, "df": 0, "n": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "r": {"docs": {}, "df": 0, "i": {"docs": {}, "df": 0, "c": {"docs": {"shimmer.modules.global_workspace.GlobalWorkspaceBase.generic_step": {"tf": 1}}, "df": 1}}, "a": {"docs": {}, "df": 0, "t": {"docs": {}, "df": 0, "e": {"docs": {"shimmer.modules.losses.generate_partitions": {"tf": 1}}, "df": 1}}}}}}, "t": {"docs": {"shimmer.modules.gw_module.get_n_layers": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBayesian.get_precision": {"tf": 1}}, "df": 2}}, "a": {"docs": {}, "df": 0, "u": {"docs": {}, "df": 0, "s": {"docs": {}, "df": 0, "s": {"docs": {}, "df": 0, "i": {"docs": {}, "df": 0, "a": {"docs": {}, "df": 0, "n": {"docs": {"shimmer.modules.vae.gaussian_nll": {"tf": 1}}, "df": 1}}}}}}}, "r": {"docs": {}, "df": 0, "o": {"docs": {}, "df": 0, "u": {"docs": {}, "df": 0, "p": {"docs": {"shimmer.utils.group_batch_size": {"tf": 1}, "shimmer.utils.group_device": {"tf": 1}}, "df": 2, "s": {"docs": {"shimmer.utils.groups_batch_size": {"tf": 1}, "shimmer.utils.groups_device": {"tf": 1}}, "df": 
2}}}}}}, "o": {"docs": {}, "df": 0, "p": {"docs": {}, "df": 0, "t": {"docs": {}, "df": 0, "i": {"docs": {}, "df": 0, "m": {"docs": {"shimmer.modules.global_workspace.GlobalWorkspaceBase.optim_lr": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.optim_weight_decay": {"tf": 1}}, "df": 2}}}}, "u": {"docs": {}, "df": 0, "t": {"docs": {"shimmer.modules.gw_module.GWDecoder.out_dim": {"tf": 1}}, "df": 1}}, "n": {"docs": {"shimmer.utils.SaveMigrations.on_save_checkpoint": {"tf": 1}}, "df": 1}}, "w": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "i": {"docs": {}, "df": 0, "g": {"docs": {}, "df": 0, "h": {"docs": {}, "df": 0, "t": {"docs": {"shimmer.modules.global_workspace.GlobalWorkspaceBase.optim_weight_decay": {"tf": 1}}, "df": 1, "e": {"docs": {}, "df": 0, "d": {"docs": {"shimmer.modules.selection.DynamicQueryAttention.fuse_weighted_encodings": {"tf": 1}}, "df": 1}}}}}}}, "o": {"docs": {}, "df": 0, "r": {"docs": {}, "df": 0, "k": {"docs": {}, "df": 0, "s": {"docs": {}, "df": 0, "p": {"docs": {}, "df": 0, "a": {"docs": {}, "df": 0, "c": {"docs": {}, "df": 0, "e": {"docs": {"shimmer.modules.global_workspace.GlobalWorkspaceBase.workspace_dim": {"tf": 1}, "shimmer.modules.global_workspace.pretrained_global_workspace": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBase.workspace_dim": {"tf": 1}}, "df": 3}}}}}}}}}, "d": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "c": {"docs": {}, "df": 0, "a": {"docs": {}, "df": 0, "y": {"docs": {"shimmer.modules.global_workspace.GlobalWorkspaceBase.optim_weight_decay": {"tf": 1}}, "df": 1}}, "o": {"docs": {}, "df": 0, "d": {"docs": {}, "df": 0, "e": {"docs": {"shimmer.modules.global_workspace.GlobalWorkspaceBase.decode": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.decode_domain": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.decode_domains": {"tf": 1}, "shimmer.modules.domain.DomainModule.decode": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBase.decode": {"tf": 1}, 
"shimmer.modules.gw_module.GWModule.decode": {"tf": 1}, "shimmer.modules.vae.VAE.decode": {"tf": 1}}, "df": 7, "r": {"docs": {"shimmer.modules.vae.VAE.decoder": {"tf": 1}}, "df": 1, "s": {"docs": {"shimmer.modules.gw_module.GWModule.gw_decoders": {"tf": 1}}, "df": 1}}}}}}, "m": {"docs": {}, "df": 0, "i": {"docs": {"shimmer.modules.global_workspace.GWPredictions.demi_cycles": {"tf": 1}, "shimmer.modules.losses.demi_cycle_loss": {"tf": 1}, "shimmer.modules.losses.LossCoefs.demi_cycles": {"tf": 1}, "shimmer.modules.losses.GWLosses2Domains.demi_cycle_loss": {"tf": 1}, "shimmer.modules.losses.BroadcastLossCoefs.demi_cycles": {"tf": 1}, "shimmer.modules.utils.batch_demi_cycles": {"tf": 1}}, "df": 6}}, "v": {"docs": {}, "df": 0, "i": {"docs": {}, "df": 0, "c": {"docs": {}, "df": 0, "e": {"docs": {"shimmer.utils.groups_device": {"tf": 1}, "shimmer.utils.group_device": {"tf": 1}}, "df": 2}}}}}, "o": {"docs": {}, "df": 0, "m": {"docs": {}, "df": 0, "a": {"docs": {}, "df": 0, "i": {"docs": {}, "df": 0, "n": {"docs": {"shimmer.modules.global_workspace.GlobalWorkspaceBase.domain_mods": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.encode_domain": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.decode_domain": {"tf": 1}, "shimmer.modules.global_workspace.freeze_domain_modules": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBase.domain_mods": {"tf": 1}, "shimmer.modules.losses.GWLosses2Domains.domain_mods": {"tf": 1}, "shimmer.modules.losses.GWLosses.domain_mods": {"tf": 1}, "shimmer.modules.losses.GWLossesBayesian.domain_mods": {"tf": 1}}, "df": 8, "s": {"docs": {"shimmer.modules.global_workspace.GlobalWorkspaceBase.encode_domains": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.decode_domains": {"tf": 1}}, "df": 2}, "m": {"docs": {}, "df": 0, "o": {"docs": {}, "df": 0, "d": {"docs": {}, "df": 0, "u": {"docs": {}, "df": 0, "l": {"docs": {}, "df": 0, "e": {"docs": {"shimmer.modules.domain.DomainModule": {"tf": 1}, 
"shimmer.modules.domain.DomainModule.__init__": {"tf": 1}, "shimmer.modules.domain.DomainModule.latent_dim": {"tf": 1}, "shimmer.modules.domain.DomainModule.encode": {"tf": 1}, "shimmer.modules.domain.DomainModule.decode": {"tf": 1}, "shimmer.modules.domain.DomainModule.compute_loss": {"tf": 1}, "shimmer.modules.domain.DomainModule.compute_dcy_loss": {"tf": 1}, "shimmer.modules.domain.DomainModule.compute_cy_loss": {"tf": 1}, "shimmer.modules.domain.DomainModule.compute_tr_loss": {"tf": 1}, "shimmer.modules.domain.DomainModule.compute_broadcast_loss": {"tf": 1}}, "df": 10}}}}}}}}}}}, "i": {"docs": {}, "df": 0, "m": {"docs": {"shimmer.modules.global_workspace.GlobalWorkspaceBase.workspace_dim": {"tf": 1}, "shimmer.modules.domain.DomainModule.latent_dim": {"tf": 1}, "shimmer.modules.gw_module.GWDecoder.in_dim": {"tf": 1}, "shimmer.modules.gw_module.GWDecoder.hidden_dim": {"tf": 1}, "shimmer.modules.gw_module.GWDecoder.out_dim": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBase.workspace_dim": {"tf": 1}}, "df": 6}, "v": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "r": {"docs": {}, "df": 0, "g": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "n": {"docs": {}, "df": 0, "c": {"docs": {}, "df": 0, "e": {"docs": {"shimmer.modules.vae.kl_divergence_loss": {"tf": 1}}, "df": 1}}}}}}}}, "r": {"docs": {"shimmer.utils.MIGRATION_DIR": {"tf": 1}}, "df": 1}}, "c": {"docs": {}, "df": 0, "y": {"docs": {"shimmer.modules.domain.DomainModule.compute_dcy_loss": {"tf": 1}}, "df": 1}}, "y": {"docs": {}, "df": 0, "n": {"docs": {}, "df": 0, "a": {"docs": {}, "df": 0, "m": {"docs": {}, "df": 0, "i": {"docs": {}, "df": 0, "c": {"docs": {}, "df": 0, "q": {"docs": {}, "df": 0, "u": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "r": {"docs": {}, "df": 0, "y": {"docs": {}, "df": 0, "a": {"docs": {}, "df": 0, "t": {"docs": {}, "df": 0, "t": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "n": {"docs": {}, "df": 0, "t": {"docs": {}, "df": 0, "i": {"docs": {}, "df": 0, "o": {"docs": {}, 
"df": 0, "n": {"docs": {"shimmer.modules.selection.DynamicQueryAttention": {"tf": 1}, "shimmer.modules.selection.DynamicQueryAttention.__init__": {"tf": 1}, "shimmer.modules.selection.DynamicQueryAttention.head_size": {"tf": 1}, "shimmer.modules.selection.DynamicQueryAttention.query_layer": {"tf": 1}, "shimmer.modules.selection.DynamicQueryAttention.key_layers": {"tf": 1}, "shimmer.modules.selection.DynamicQueryAttention.fuse_weighted_encodings": {"tf": 1}, "shimmer.modules.selection.DynamicQueryAttention.forward": {"tf": 1}}, "df": 7}}}}}}}}}}}}}}}}}}}}, "a": {"docs": {}, "df": 0, "t": {"docs": {}, "df": 0, "a": {"docs": {}, "df": 0, "s": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "t": {"docs": {"shimmer.dataset.RepeatedDataset.dataset": {"tf": 1}, "shimmer.dataset.RepeatedDataset.dataset_size": {"tf": 1}}, "df": 2}}}}}}}, "a": {"docs": {}, "df": 0, "r": {"docs": {}, "df": 0, "g": {"docs": {}, "df": 0, "s": {"docs": {"shimmer.modules.global_workspace.GlobalWorkspaceBase.scheduler_args": {"tf": 1}}, "df": 1}}}, "n": {"docs": {}, "df": 0, "d": {"docs": {"shimmer.modules.global_workspace.GlobalWorkspaceBase.encode_and_fuse": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBase.encode_and_fuse": {"tf": 1}}, "df": 2}}, "l": {"docs": {}, "df": 0, "l": {"docs": {"shimmer.modules.domain.LossOutput.all": {"tf": 1}}, "df": 1}}}, "e": {"docs": {}, "df": 0, "n": {"docs": {}, "df": 0, "c": {"docs": {}, "df": 0, "o": {"docs": {}, "df": 0, "d": {"docs": {}, "df": 0, "e": {"docs": {"shimmer.modules.global_workspace.GlobalWorkspaceBase.encode_and_fuse": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.encode": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.encode_domain": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.encode_domains": {"tf": 1}, "shimmer.modules.domain.DomainModule.encode": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBase.encode": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBase.encode_and_fuse": {"tf": 
1}, "shimmer.modules.gw_module.GWModule.encode": {"tf": 1}, "shimmer.modules.vae.VAE.encode": {"tf": 1}}, "df": 9, "r": {"docs": {"shimmer.modules.vae.VAE.encoder": {"tf": 1}}, "df": 1, "s": {"docs": {"shimmer.modules.gw_module.GWModule.gw_encoders": {"tf": 1}}, "df": 1}}}, "i": {"docs": {}, "df": 0, "n": {"docs": {}, "df": 0, "g": {"docs": {}, "df": 0, "s": {"docs": {"shimmer.modules.selection.DynamicQueryAttention.fuse_weighted_encodings": {"tf": 1}}, "df": 1}}}}}}}}}, "f": {"docs": {}, "df": 0, "u": {"docs": {}, "df": 0, "s": {"docs": {}, "df": 0, "e": {"docs": {"shimmer.modules.global_workspace.GlobalWorkspaceBase.encode_and_fuse": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.fuse": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBase.fuse": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBase.encode_and_fuse": {"tf": 1}, "shimmer.modules.gw_module.GWModule.fuse": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBayesian.fuse": {"tf": 1}, "shimmer.modules.selection.DynamicQueryAttention.fuse_weighted_encodings": {"tf": 1}}, "df": 7, "d": {"docs": {"shimmer.modules.losses.BroadcastLossCoefs.fused": {"tf": 1}}, "df": 1}}, "i": {"docs": {}, "df": 0, "o": {"docs": {}, "df": 0, "n": {"docs": {"shimmer.modules.gw_module.compute_fusion_scores": {"tf": 1}}, "df": 1}}}}}, "r": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "z": {"docs": {}, "df": 0, "e": {"docs": {"shimmer.modules.global_workspace.freeze_domain_modules": {"tf": 1}}, "df": 1}}}}}, "o": {"docs": {}, "df": 0, "r": {"docs": {}, "df": 0, "w": {"docs": {}, "df": 0, "a": {"docs": {}, "df": 0, "r": {"docs": {}, "df": 0, "d": {"docs": {"shimmer.modules.global_workspace.GlobalWorkspace2Domains.forward": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspace.forward": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBayesian.forward": {"tf": 1}, "shimmer.modules.gw_module.GWEncoder.forward": {"tf": 1}, "shimmer.modules.gw_module.GWEncoderLinear.forward": 
{"tf": 1}, "shimmer.modules.selection.SelectionBase.forward": {"tf": 1}, "shimmer.modules.selection.SingleDomainSelection.forward": {"tf": 1}, "shimmer.modules.selection.FixedSharedSelection.forward": {"tf": 1}, "shimmer.modules.selection.KQFixedQSelection.forward": {"tf": 1}, "shimmer.modules.selection.RandomSelection.forward": {"tf": 1}, "shimmer.modules.selection.DynamicQueryAttention.forward": {"tf": 1}, "shimmer.modules.contrastive_loss.ContrastiveLoss.forward": {"tf": 1}, "shimmer.modules.vae.VAEEncoder.forward": {"tf": 1}, "shimmer.modules.vae.VAEDecoder.forward": {"tf": 1}, "shimmer.modules.vae.VAE.forward": {"tf": 1}}, "df": 15}}}}}}, "i": {"docs": {}, "df": 0, "x": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "d": {"docs": {}, "df": 0, "s": {"docs": {}, "df": 0, "h": {"docs": {}, "df": 0, "a": {"docs": {}, "df": 0, "r": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "d": {"docs": {}, "df": 0, "s": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "l": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "c": {"docs": {}, "df": 0, "t": {"docs": {}, "df": 0, "i": {"docs": {}, "df": 0, "o": {"docs": {}, "df": 0, "n": {"docs": {"shimmer.modules.selection.FixedSharedSelection": {"tf": 1}, "shimmer.modules.selection.FixedSharedSelection.forward": {"tf": 1}}, "df": 2}}}}}}}}}}}}}}}}}}}, "n": {"docs": {"shimmer.modules.losses.GWLosses2Domains.contrastive_fn": {"tf": 1}, "shimmer.modules.losses.GWLosses.contrastive_fn": {"tf": 1}, "shimmer.modules.losses.GWLossesBayesian.contrastive_fn": {"tf": 1}}, "df": 3}}, "b": {"docs": {}, "df": 0, "a": {"docs": {}, "df": 0, "t": {"docs": {}, "df": 0, "c": {"docs": {}, "df": 0, "h": {"docs": {"shimmer.modules.global_workspace.GlobalWorkspaceBase.batch_gw_states": {"tf": 1}, "shimmer.modules.utils.batch_demi_cycles": {"tf": 1}, "shimmer.modules.utils.batch_cycles": {"tf": 1}, "shimmer.modules.utils.batch_translations": {"tf": 1}, "shimmer.utils.group_batch_size": {"tf": 1}, "shimmer.utils.groups_batch_size": {"tf": 1}}, "df": 
6}}}, "y": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "s": {"docs": {}, "df": 0, "i": {"docs": {}, "df": 0, "a": {"docs": {}, "df": 0, "n": {"docs": {"shimmer.modules.losses.contrastive_loss_bayesian": {"tf": 1}}, "df": 1}}}}}}}, "r": {"docs": {}, "df": 0, "o": {"docs": {}, "df": 0, "a": {"docs": {}, "df": 0, "d": {"docs": {}, "df": 0, "c": {"docs": {}, "df": 0, "a": {"docs": {}, "df": 0, "s": {"docs": {}, "df": 0, "t": {"docs": {"shimmer.modules.domain.DomainModule.compute_broadcast_loss": {"tf": 1}, "shimmer.modules.losses.broadcast_loss": {"tf": 1}, "shimmer.modules.losses.GWLosses.broadcast_loss": {"tf": 1}, "shimmer.modules.losses.GWLossesBayesian.broadcast_loss": {"tf": 1}}, "df": 4, "l": {"docs": {}, "df": 0, "o": {"docs": {}, "df": 0, "s": {"docs": {}, "df": 0, "s": {"docs": {}, "df": 0, "c": {"docs": {}, "df": 0, "o": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "f": {"docs": {}, "df": 0, "s": {"docs": {"shimmer.modules.losses.BroadcastLossCoefs": {"tf": 1}, "shimmer.modules.losses.BroadcastLossCoefs.contrastives": {"tf": 1}, "shimmer.modules.losses.BroadcastLossCoefs.fused": {"tf": 1}, "shimmer.modules.losses.BroadcastLossCoefs.demi_cycles": {"tf": 1}, "shimmer.modules.losses.BroadcastLossCoefs.cycles": {"tf": 1}, "shimmer.modules.losses.BroadcastLossCoefs.translations": {"tf": 1}}, "df": 6}}}}}}}}}}}}}}}}}, "e": {"docs": {}, "df": 0, "t": {"docs": {}, "df": 0, "a": {"docs": {"shimmer.modules.vae.VAE.beta": {"tf": 1}}, "df": 1}}}}, "c": {"docs": {}, "df": 0, "y": {"docs": {"shimmer.modules.domain.DomainModule.compute_cy_loss": {"tf": 1}}, "df": 1, "c": {"docs": {}, "df": 0, "l": {"docs": {}, "df": 0, "e": {"docs": {"shimmer.modules.losses.demi_cycle_loss": {"tf": 1}, "shimmer.modules.losses.cycle_loss": {"tf": 1}, "shimmer.modules.losses.GWLosses2Domains.demi_cycle_loss": {"tf": 1}, "shimmer.modules.losses.GWLosses2Domains.cycle_loss": {"tf": 1}, "shimmer.modules.utils.cycle": {"tf": 1}}, "df": 5, "s": {"docs": 
{"shimmer.modules.global_workspace.GWPredictions.demi_cycles": {"tf": 1}, "shimmer.modules.global_workspace.GWPredictions.cycles": {"tf": 1}, "shimmer.modules.losses.LossCoefs.demi_cycles": {"tf": 1}, "shimmer.modules.losses.LossCoefs.cycles": {"tf": 1}, "shimmer.modules.losses.BroadcastLossCoefs.demi_cycles": {"tf": 1}, "shimmer.modules.losses.BroadcastLossCoefs.cycles": {"tf": 1}, "shimmer.modules.utils.batch_demi_cycles": {"tf": 1}, "shimmer.modules.utils.batch_cycles": {"tf": 1}}, "df": 8}}}}}, "o": {"docs": {}, "df": 0, "m": {"docs": {}, "df": 0, "p": {"docs": {}, "df": 0, "u": {"docs": {}, "df": 0, "t": {"docs": {}, "df": 0, "e": {"docs": {"shimmer.modules.domain.DomainModule.compute_loss": {"tf": 1}, "shimmer.modules.domain.DomainModule.compute_dcy_loss": {"tf": 1}, "shimmer.modules.domain.DomainModule.compute_cy_loss": {"tf": 1}, "shimmer.modules.domain.DomainModule.compute_tr_loss": {"tf": 1}, "shimmer.modules.domain.DomainModule.compute_broadcast_loss": {"tf": 1}, "shimmer.modules.gw_module.compute_fusion_scores": {"tf": 1}}, "df": 6}}}}}, "n": {"docs": {}, "df": 0, "t": {"docs": {}, "df": 0, "r": {"docs": {}, "df": 0, "a": {"docs": {}, "df": 0, "s": {"docs": {}, "df": 0, "t": {"docs": {}, "df": 0, "i": {"docs": {}, "df": 0, "v": {"docs": {}, "df": 0, "e": {"docs": {"shimmer.modules.losses.contrastive_loss": {"tf": 1}, "shimmer.modules.losses.contrastive_loss_bayesian": {"tf": 1}, "shimmer.modules.losses.GWLosses2Domains.contrastive_fn": {"tf": 1}, "shimmer.modules.losses.GWLosses2Domains.contrastive_loss": {"tf": 1}, "shimmer.modules.losses.GWLosses.contrastive_fn": {"tf": 1}, "shimmer.modules.losses.GWLosses.contrastive_loss": {"tf": 1}, "shimmer.modules.losses.GWLossesBayesian.contrastive_fn": {"tf": 1}, "shimmer.modules.losses.GWLossesBayesian.contrastive_loss": {"tf": 1}, "shimmer.modules.contrastive_loss.contrastive_loss": {"tf": 1}}, "df": 9, "s": {"docs": {"shimmer.modules.losses.LossCoefs.contrastives": {"tf": 1}, 
"shimmer.modules.losses.BroadcastLossCoefs.contrastives": {"tf": 1}}, "df": 2}, "l": {"docs": {}, "df": 0, "o": {"docs": {}, "df": 0, "s": {"docs": {}, "df": 0, "s": {"docs": {"shimmer.modules.contrastive_loss.ContrastiveLoss": {"tf": 1}, "shimmer.modules.contrastive_loss.ContrastiveLoss.__init__": {"tf": 1}, "shimmer.modules.contrastive_loss.ContrastiveLoss.learn_logit_scale": {"tf": 1}, "shimmer.modules.contrastive_loss.ContrastiveLoss.reduction": {"tf": 1}, "shimmer.modules.contrastive_loss.ContrastiveLoss.forward": {"tf": 1}}, "df": 5, "t": {"docs": {}, "df": 0, "y": {"docs": {}, "df": 0, "p": {"docs": {}, "df": 0, "e": {"docs": {"shimmer.modules.contrastive_loss.ContrastiveLossType": {"tf": 1}}, "df": 1}}}}, "b": {"docs": {}, "df": 0, "a": {"docs": {}, "df": 0, "y": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "s": {"docs": {}, "df": 0, "i": {"docs": {}, "df": 0, "a": {"docs": {}, "df": 0, "n": {"docs": {}, "df": 0, "t": {"docs": {}, "df": 0, "y": {"docs": {}, "df": 0, "p": {"docs": {}, "df": 0, "e": {"docs": {"shimmer.modules.contrastive_loss.ContrastiveLossBayesianType": {"tf": 1}}, "df": 1}}}}}}}}}}}}}}}}}}}}}}}}, "s": {"docs": {}, "df": 0, "t": {"docs": {}, "df": 0, "r": {"docs": {}, "df": 0, "a": {"docs": {}, "df": 0, "s": {"docs": {}, "df": 0, "t": {"docs": {}, "df": 0, "i": {"docs": {}, "df": 0, "v": {"docs": {}, "df": 0, "e": {"docs": {"shimmer.modules.losses.GWLossesBayesian.use_normalized_constrastive": {"tf": 1}}, "df": 1}}}}}}}}}}, "e": {"docs": {}, "df": 0, "f": {"docs": {}, "df": 0, "s": {"docs": {"shimmer.modules.losses.GWLosses2Domains.loss_coefs": {"tf": 1}, "shimmer.modules.losses.GWLosses.loss_coefs": {"tf": 1}, "shimmer.modules.losses.GWLossesBayesian.loss_coefs": {"tf": 1}}, "df": 3}}}}, "h": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "c": {"docs": {}, "df": 0, "k": {"docs": {}, "df": 0, "p": {"docs": {}, "df": 0, "o": {"docs": {}, "df": 0, "i": {"docs": {}, "df": 0, "n": {"docs": {}, "df": 0, "t": {"docs": 
{"shimmer.utils.SaveMigrations.on_save_checkpoint": {"tf": 1}}, "df": 1}}}}}}}}}, "k": {"docs": {}, "df": 0, "p": {"docs": {}, "df": 0, "t": {"docs": {"shimmer.cli.ckpt_migration.migrate_ckpt": {"tf": 1}}, "df": 1}}}}, "i": {"docs": {}, "df": 0, "n": {"docs": {"shimmer.modules.gw_module.GWDecoder.in_dim": {"tf": 1}}, "df": 1, "i": {"docs": {}, "df": 0, "t": {"docs": {"shimmer.modules.global_workspace.GlobalWorkspace2Domains.__init__": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspace.__init__": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBayesian.__init__": {"tf": 1}, "shimmer.modules.domain.LossOutput.__init__": {"tf": 1}, "shimmer.modules.domain.DomainModule.__init__": {"tf": 1}, "shimmer.modules.gw_module.GWDecoder.__init__": {"tf": 1}, "shimmer.modules.gw_module.GWEncoder.__init__": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBase.__init__": {"tf": 1}, "shimmer.modules.gw_module.GWModule.__init__": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBayesian.__init__": {"tf": 1}, "shimmer.modules.selection.KQFixedQSelection.__init__": {"tf": 1}, "shimmer.modules.selection.RandomSelection.__init__": {"tf": 1}, "shimmer.modules.selection.DynamicQueryAttention.__init__": {"tf": 1}, "shimmer.modules.losses.GWLosses2Domains.__init__": {"tf": 1}, "shimmer.modules.losses.GWLosses.__init__": {"tf": 1}, "shimmer.modules.losses.GWLossesBayesian.__init__": {"tf": 1}, "shimmer.modules.contrastive_loss.ContrastiveLoss.__init__": {"tf": 1}, "shimmer.dataset.RepeatedDataset.__init__": {"tf": 1}, "shimmer.modules.vae.VAE.__init__": {"tf": 1}}, "df": 19}}, "f": {"docs": {}, "df": 0, "o": {"docs": {"shimmer.modules.contrastive_loss.info_nce": {"tf": 1}}, "df": 1}}}}, "p": {"docs": {}, "df": 0, "r": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "t": {"docs": {}, "df": 0, "r": {"docs": {}, "df": 0, "a": {"docs": {}, "df": 0, "i": {"docs": {}, "df": 0, "n": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "d": {"docs": 
{"shimmer.modules.global_workspace.pretrained_global_workspace": {"tf": 1}}, "df": 1}}}}}}}, "c": {"docs": {}, "df": 0, "i": {"docs": {}, "df": 0, "s": {"docs": {}, "df": 0, "i": {"docs": {}, "df": 0, "o": {"docs": {}, "df": 0, "n": {"docs": {"shimmer.modules.gw_module.GWModuleBayesian.sensitivity_precision": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBayesian.precision_softmax_temp": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBayesian.get_precision": {"tf": 1}}, "df": 3, "s": {"docs": {"shimmer.modules.gw_module.GWModuleBayesian.precisions": {"tf": 1}}, "df": 1}}}}}}}}}, "a": {"docs": {}, "df": 0, "r": {"docs": {}, "df": 0, "t": {"docs": {}, "df": 0, "i": {"docs": {}, "df": 0, "t": {"docs": {}, "df": 0, "i": {"docs": {}, "df": 0, "o": {"docs": {}, "df": 0, "n": {"docs": {}, "df": 0, "s": {"docs": {"shimmer.modules.losses.generate_partitions": {"tf": 1}}, "df": 1}}}}}}}}}}, "n": {"docs": {"shimmer.modules.gw_module.get_n_layers": {"tf": 1}, "shimmer.modules.gw_module.GWDecoder.n_layers": {"tf": 1}}, "df": 2, "o": {"docs": {}, "df": 0, "r": {"docs": {}, "df": 0, "m": {"docs": {}, "df": 0, "a": {"docs": {}, "df": 0, "l": {"docs": {}, "df": 0, "i": {"docs": {}, "df": 0, "z": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "d": {"docs": {"shimmer.modules.losses.GWLossesBayesian.use_normalized_constrastive": {"tf": 1}}, "df": 1}}}}}}}}}, "c": {"docs": {}, "df": 0, "e": {"docs": {"shimmer.modules.contrastive_loss.info_nce": {"tf": 1}}, "df": 1}}, "l": {"docs": {}, "df": 0, "l": {"docs": {"shimmer.modules.vae.gaussian_nll": {"tf": 1}}, "df": 1}}}, "h": {"docs": {}, "df": 0, "i": {"docs": {}, "df": 0, "d": {"docs": {}, "df": 0, "d": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "n": {"docs": {"shimmer.modules.gw_module.GWDecoder.hidden_dim": {"tf": 1}}, "df": 1}}}}}, "e": {"docs": {}, "df": 0, "a": {"docs": {}, "df": 0, "d": {"docs": {"shimmer.modules.selection.KQFixedQSelection.head_size": {"tf": 1}, "shimmer.modules.selection.DynamicQueryAttention.head_size": 
{"tf": 1}}, "df": 2}}}}, "u": {"docs": {}, "df": 0, "p": {"docs": {}, "df": 0, "d": {"docs": {}, "df": 0, "a": {"docs": {}, "df": 0, "t": {"docs": {}, "df": 0, "e": {"docs": {"shimmer.modules.selection.SelectionBase.update_gw_state": {"tf": 1}}, "df": 1}}}}}, "s": {"docs": {}, "df": 0, "e": {"docs": {"shimmer.modules.losses.GWLossesBayesian.use_normalized_constrastive": {"tf": 1}}, "df": 1}}}, "k": {"docs": {}, "df": 0, "q": {"docs": {}, "df": 0, "f": {"docs": {}, "df": 0, "i": {"docs": {}, "df": 0, "x": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "d": {"docs": {}, "df": 0, "q": {"docs": {}, "df": 0, "s": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "l": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "c": {"docs": {}, "df": 0, "t": {"docs": {}, "df": 0, "i": {"docs": {}, "df": 0, "o": {"docs": {}, "df": 0, "n": {"docs": {"shimmer.modules.selection.KQFixedQSelection": {"tf": 1}, "shimmer.modules.selection.KQFixedQSelection.__init__": {"tf": 1}, "shimmer.modules.selection.KQFixedQSelection.head_size": {"tf": 1}, "shimmer.modules.selection.KQFixedQSelection.query_layer": {"tf": 1}, "shimmer.modules.selection.KQFixedQSelection.key_layers": {"tf": 1}, "shimmer.modules.selection.KQFixedQSelection.forward": {"tf": 1}}, "df": 6}}}}}}}}}}}}}}}}, "e": {"docs": {}, "df": 0, "y": {"docs": {"shimmer.modules.selection.KQFixedQSelection.key_layers": {"tf": 1}, "shimmer.modules.selection.DynamicQueryAttention.key_layers": {"tf": 1}}, "df": 2}}, "l": {"docs": {"shimmer.modules.vae.kl_divergence_loss": {"tf": 1}}, "df": 1}}, "q": {"docs": {}, "df": 0, "u": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "r": {"docs": {}, "df": 0, "y": {"docs": {"shimmer.modules.selection.KQFixedQSelection.query_layer": {"tf": 1}, "shimmer.modules.selection.DynamicQueryAttention.query_layer": {"tf": 1}}, "df": 2}}}}}, "v": {"docs": {}, "df": 0, "a": {"docs": {}, "df": 0, "e": {"docs": {"shimmer.modules.vae.VAE": {"tf": 1}, "shimmer.modules.vae.VAE.__init__": {"tf": 1}, 
"shimmer.modules.vae.VAE.beta": {"tf": 1}, "shimmer.modules.vae.VAE.encoder": {"tf": 1}, "shimmer.modules.vae.VAE.decoder": {"tf": 1}, "shimmer.modules.vae.VAE.encode": {"tf": 1}, "shimmer.modules.vae.VAE.decode": {"tf": 1}, "shimmer.modules.vae.VAE.forward": {"tf": 1}}, "df": 8, "e": {"docs": {}, "df": 0, "n": {"docs": {}, "df": 0, "c": {"docs": {}, "df": 0, "o": {"docs": {}, "df": 0, "d": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "r": {"docs": {"shimmer.modules.vae.VAEEncoder": {"tf": 1}, "shimmer.modules.vae.VAEEncoder.forward": {"tf": 1}}, "df": 2}}}}}}}, "d": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "c": {"docs": {}, "df": 0, "o": {"docs": {}, "df": 0, "d": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "r": {"docs": {"shimmer.modules.vae.VAEDecoder": {"tf": 1}, "shimmer.modules.vae.VAEDecoder.forward": {"tf": 1}}, "df": 2}}}}}}}}}}}}, "fullname": {"root": {"docs": {"shimmer.modules.global_workspace.GlobalWorkspace2Domains.__init__": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspace.__init__": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBayesian.__init__": {"tf": 1}, "shimmer.modules.domain.LossOutput.__init__": {"tf": 1}, "shimmer.modules.domain.DomainModule.__init__": {"tf": 1}, "shimmer.modules.gw_module.GWDecoder.__init__": {"tf": 1}, "shimmer.modules.gw_module.GWEncoder.__init__": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBase.__init__": {"tf": 1}, "shimmer.modules.gw_module.GWModule.__init__": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBayesian.__init__": {"tf": 1}, "shimmer.modules.selection.KQFixedQSelection.__init__": {"tf": 1}, "shimmer.modules.selection.RandomSelection.__init__": {"tf": 1}, "shimmer.modules.selection.DynamicQueryAttention.__init__": {"tf": 1}, "shimmer.modules.losses.GWLosses2Domains.__init__": {"tf": 1}, "shimmer.modules.losses.GWLosses.__init__": {"tf": 1}, "shimmer.modules.losses.GWLossesBayesian.__init__": {"tf": 1}, "shimmer.modules.contrastive_loss.ContrastiveLoss.__init__": 
{"tf": 1}, "shimmer.dataset.RepeatedDataset.__init__": {"tf": 1}, "shimmer.modules.vae.VAE.__init__": {"tf": 1}}, "df": 19, "s": {"docs": {}, "df": 0, "h": {"docs": {}, "df": 0, "i": {"docs": {}, "df": 0, "m": {"docs": {}, "df": 0, "m": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "r": {"docs": {"shimmer.types": {"tf": 1}, "shimmer.types.RawDomainGroupT": {"tf": 1}, "shimmer.types.RawDomainGroupDT": {"tf": 1}, "shimmer.types.LatentsDomainGroupT": {"tf": 1}, "shimmer.types.LatentsDomainGroupDT": {"tf": 1}, "shimmer.types.RawDomainGroupsT": {"tf": 1}, "shimmer.types.RawDomainGroupsDT": {"tf": 1}, "shimmer.types.LatentsDomainGroupsT": {"tf": 1}, "shimmer.types.LatentsDomainGroupsDT": {"tf": 1}, "shimmer.types.ModelModeT": {"tf": 1}, "shimmer.modules.global_workspace": {"tf": 1}, "shimmer.modules.global_workspace.SchedulerArgs": {"tf": 1}, "shimmer.modules.global_workspace.SchedulerArgs.max_lr": {"tf": 1}, "shimmer.modules.global_workspace.SchedulerArgs.total_steps": {"tf": 1}, "shimmer.modules.global_workspace.GWPredictionsBase": {"tf": 1}, "shimmer.modules.global_workspace.GWPredictionsBase.states": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBase": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.gw_mod": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.selection_mod": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.loss_mod": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.optim_lr": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.optim_weight_decay": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.scheduler_args": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.domain_mods": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.workspace_dim": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.encode_and_fuse": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.encode": {"tf": 1}, 
"shimmer.modules.global_workspace.GlobalWorkspaceBase.fuse": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.decode": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.batch_gw_states": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.encode_domain": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.encode_domains": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.decode_domain": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.decode_domains": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.generic_step": {"tf": 1}, "shimmer.modules.global_workspace.freeze_domain_modules": {"tf": 1}, "shimmer.modules.global_workspace.GWPredictions": {"tf": 1}, "shimmer.modules.global_workspace.GWPredictions.demi_cycles": {"tf": 1}, "shimmer.modules.global_workspace.GWPredictions.cycles": {"tf": 1}, "shimmer.modules.global_workspace.GWPredictions.translations": {"tf": 1}, "shimmer.modules.global_workspace.GWPredictions.states": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspace2Domains": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspace2Domains.__init__": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspace2Domains.forward": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspace": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspace.__init__": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspace.forward": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBayesian": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBayesian.__init__": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBayesian.forward": {"tf": 1}, "shimmer.modules.global_workspace.pretrained_global_workspace": {"tf": 1}, "shimmer.modules.domain": {"tf": 1}, "shimmer.modules.domain.LossOutput": {"tf": 1}, "shimmer.modules.domain.LossOutput.__init__": {"tf": 1}, "shimmer.modules.domain.LossOutput.loss": {"tf": 1}, 
"shimmer.modules.domain.LossOutput.metrics": {"tf": 1}, "shimmer.modules.domain.LossOutput.all": {"tf": 1}, "shimmer.modules.domain.DomainModule": {"tf": 1}, "shimmer.modules.domain.DomainModule.__init__": {"tf": 1}, "shimmer.modules.domain.DomainModule.latent_dim": {"tf": 1}, "shimmer.modules.domain.DomainModule.encode": {"tf": 1}, "shimmer.modules.domain.DomainModule.decode": {"tf": 1}, "shimmer.modules.domain.DomainModule.compute_loss": {"tf": 1}, "shimmer.modules.domain.DomainModule.compute_dcy_loss": {"tf": 1}, "shimmer.modules.domain.DomainModule.compute_cy_loss": {"tf": 1}, "shimmer.modules.domain.DomainModule.compute_tr_loss": {"tf": 1}, "shimmer.modules.domain.DomainModule.compute_broadcast_loss": {"tf": 1}, "shimmer.modules.gw_module": {"tf": 1}, "shimmer.modules.gw_module.get_n_layers": {"tf": 1}, "shimmer.modules.gw_module.GWDecoder": {"tf": 1}, "shimmer.modules.gw_module.GWDecoder.__init__": {"tf": 1}, "shimmer.modules.gw_module.GWDecoder.in_dim": {"tf": 1}, "shimmer.modules.gw_module.GWDecoder.hidden_dim": {"tf": 1}, "shimmer.modules.gw_module.GWDecoder.out_dim": {"tf": 1}, "shimmer.modules.gw_module.GWDecoder.n_layers": {"tf": 1}, "shimmer.modules.gw_module.GWEncoder": {"tf": 1}, "shimmer.modules.gw_module.GWEncoder.__init__": {"tf": 1}, "shimmer.modules.gw_module.GWEncoder.forward": {"tf": 1}, "shimmer.modules.gw_module.GWEncoderLinear": {"tf": 1}, "shimmer.modules.gw_module.GWEncoderLinear.forward": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBase": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBase.__init__": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBase.domain_mods": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBase.workspace_dim": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBase.fuse": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBase.encode": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBase.encode_and_fuse": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBase.decode": {"tf": 1}, "shimmer.modules.gw_module.GWModule": {"tf": 1}, 
"shimmer.modules.gw_module.GWModule.__init__": {"tf": 1}, "shimmer.modules.gw_module.GWModule.gw_encoders": {"tf": 1}, "shimmer.modules.gw_module.GWModule.gw_decoders": {"tf": 1}, "shimmer.modules.gw_module.GWModule.fuse": {"tf": 1}, "shimmer.modules.gw_module.GWModule.encode": {"tf": 1}, "shimmer.modules.gw_module.GWModule.decode": {"tf": 1}, "shimmer.modules.gw_module.compute_fusion_scores": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBayesian": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBayesian.__init__": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBayesian.precisions": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBayesian.sensitivity_selection": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBayesian.sensitivity_precision": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBayesian.precision_softmax_temp": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBayesian.get_precision": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBayesian.fuse": {"tf": 1}, "shimmer.modules.selection": {"tf": 1}, "shimmer.modules.selection.SelectionBase": {"tf": 1}, "shimmer.modules.selection.SelectionBase.update_gw_state": {"tf": 1}, "shimmer.modules.selection.SelectionBase.forward": {"tf": 1}, "shimmer.modules.selection.SingleDomainSelection": {"tf": 1}, "shimmer.modules.selection.SingleDomainSelection.forward": {"tf": 1}, "shimmer.modules.selection.FixedSharedSelection": {"tf": 1}, "shimmer.modules.selection.FixedSharedSelection.forward": {"tf": 1}, "shimmer.modules.selection.KQFixedQSelection": {"tf": 1}, "shimmer.modules.selection.KQFixedQSelection.__init__": {"tf": 1}, "shimmer.modules.selection.KQFixedQSelection.head_size": {"tf": 1}, "shimmer.modules.selection.KQFixedQSelection.query_layer": {"tf": 1}, "shimmer.modules.selection.KQFixedQSelection.key_layers": {"tf": 1}, "shimmer.modules.selection.KQFixedQSelection.forward": {"tf": 1}, "shimmer.modules.selection.RandomSelection": {"tf": 1}, "shimmer.modules.selection.RandomSelection.__init__": {"tf": 1}, 
"shimmer.modules.selection.RandomSelection.temperature": {"tf": 1}, "shimmer.modules.selection.RandomSelection.forward": {"tf": 1}, "shimmer.modules.selection.DynamicQueryAttention": {"tf": 1}, "shimmer.modules.selection.DynamicQueryAttention.__init__": {"tf": 1}, "shimmer.modules.selection.DynamicQueryAttention.head_size": {"tf": 1}, "shimmer.modules.selection.DynamicQueryAttention.query_layer": {"tf": 1}, "shimmer.modules.selection.DynamicQueryAttention.key_layers": {"tf": 1}, "shimmer.modules.selection.DynamicQueryAttention.fuse_weighted_encodings": {"tf": 1}, "shimmer.modules.selection.DynamicQueryAttention.forward": {"tf": 1}, "shimmer.modules.losses": {"tf": 1}, "shimmer.modules.losses.GWLossesBase": {"tf": 1}, "shimmer.modules.losses.GWLossesBase.step": {"tf": 1}, "shimmer.modules.losses.demi_cycle_loss": {"tf": 1}, "shimmer.modules.losses.cycle_loss": {"tf": 1}, "shimmer.modules.losses.translation_loss": {"tf": 1}, "shimmer.modules.losses.contrastive_loss": {"tf": 1}, "shimmer.modules.losses.contrastive_loss_bayesian": {"tf": 1}, "shimmer.modules.losses.LossCoefs": {"tf": 1}, "shimmer.modules.losses.LossCoefs.demi_cycles": {"tf": 1}, "shimmer.modules.losses.LossCoefs.cycles": {"tf": 1}, "shimmer.modules.losses.LossCoefs.translations": {"tf": 1}, "shimmer.modules.losses.LossCoefs.contrastives": {"tf": 1}, "shimmer.modules.losses.GWLosses2Domains": {"tf": 1}, "shimmer.modules.losses.GWLosses2Domains.__init__": {"tf": 1}, "shimmer.modules.losses.GWLosses2Domains.gw_mod": {"tf": 1}, "shimmer.modules.losses.GWLosses2Domains.selection_mod": {"tf": 1}, "shimmer.modules.losses.GWLosses2Domains.domain_mods": {"tf": 1}, "shimmer.modules.losses.GWLosses2Domains.loss_coefs": {"tf": 1}, "shimmer.modules.losses.GWLosses2Domains.contrastive_fn": {"tf": 1}, "shimmer.modules.losses.GWLosses2Domains.demi_cycle_loss": {"tf": 1}, "shimmer.modules.losses.GWLosses2Domains.cycle_loss": {"tf": 1}, "shimmer.modules.losses.GWLosses2Domains.translation_loss": {"tf": 1}, 
"shimmer.modules.losses.GWLosses2Domains.contrastive_loss": {"tf": 1}, "shimmer.modules.losses.GWLosses2Domains.step": {"tf": 1}, "shimmer.modules.losses.generate_partitions": {"tf": 1}, "shimmer.modules.losses.broadcast_loss": {"tf": 1}, "shimmer.modules.losses.BroadcastLossCoefs": {"tf": 1}, "shimmer.modules.losses.BroadcastLossCoefs.contrastives": {"tf": 1}, "shimmer.modules.losses.BroadcastLossCoefs.fused": {"tf": 1}, "shimmer.modules.losses.BroadcastLossCoefs.demi_cycles": {"tf": 1}, "shimmer.modules.losses.BroadcastLossCoefs.cycles": {"tf": 1}, "shimmer.modules.losses.BroadcastLossCoefs.translations": {"tf": 1}, "shimmer.modules.losses.GWLosses": {"tf": 1}, "shimmer.modules.losses.GWLosses.__init__": {"tf": 1}, "shimmer.modules.losses.GWLosses.gw_mod": {"tf": 1}, "shimmer.modules.losses.GWLosses.selection_mod": {"tf": 1}, "shimmer.modules.losses.GWLosses.domain_mods": {"tf": 1}, "shimmer.modules.losses.GWLosses.loss_coefs": {"tf": 1}, "shimmer.modules.losses.GWLosses.contrastive_fn": {"tf": 1}, "shimmer.modules.losses.GWLosses.contrastive_loss": {"tf": 1}, "shimmer.modules.losses.GWLosses.broadcast_loss": {"tf": 1}, "shimmer.modules.losses.GWLosses.step": {"tf": 1}, "shimmer.modules.losses.GWLossesBayesian": {"tf": 1}, "shimmer.modules.losses.GWLossesBayesian.__init__": {"tf": 1}, "shimmer.modules.losses.GWLossesBayesian.gw_mod": {"tf": 1}, "shimmer.modules.losses.GWLossesBayesian.selection_mod": {"tf": 1}, "shimmer.modules.losses.GWLossesBayesian.domain_mods": {"tf": 1}, "shimmer.modules.losses.GWLossesBayesian.loss_coefs": {"tf": 1}, "shimmer.modules.losses.GWLossesBayesian.contrastive_fn": {"tf": 1}, "shimmer.modules.losses.GWLossesBayesian.use_normalized_constrastive": {"tf": 1}, "shimmer.modules.losses.GWLossesBayesian.contrastive_loss": {"tf": 1}, "shimmer.modules.losses.GWLossesBayesian.broadcast_loss": {"tf": 1}, "shimmer.modules.losses.GWLossesBayesian.step": {"tf": 1}, "shimmer.modules.contrastive_loss": {"tf": 1}, 
"shimmer.modules.contrastive_loss.ContrastiveLossType": {"tf": 1}, "shimmer.modules.contrastive_loss.ContrastiveLossBayesianType": {"tf": 1}, "shimmer.modules.contrastive_loss.info_nce": {"tf": 1}, "shimmer.modules.contrastive_loss.contrastive_loss": {"tf": 1}, "shimmer.modules.contrastive_loss.ContrastiveLoss": {"tf": 1}, "shimmer.modules.contrastive_loss.ContrastiveLoss.__init__": {"tf": 1}, "shimmer.modules.contrastive_loss.ContrastiveLoss.learn_logit_scale": {"tf": 1}, "shimmer.modules.contrastive_loss.ContrastiveLoss.reduction": {"tf": 1}, "shimmer.modules.contrastive_loss.ContrastiveLoss.forward": {"tf": 1}, "shimmer.dataset": {"tf": 1}, "shimmer.dataset.RepeatedDataset": {"tf": 1}, "shimmer.dataset.RepeatedDataset.__init__": {"tf": 1}, "shimmer.dataset.RepeatedDataset.dataset": {"tf": 1}, "shimmer.dataset.RepeatedDataset.dataset_size": {"tf": 1}, "shimmer.modules.vae": {"tf": 1}, "shimmer.modules.vae.reparameterize": {"tf": 1}, "shimmer.modules.vae.kl_divergence_loss": {"tf": 1}, "shimmer.modules.vae.gaussian_nll": {"tf": 1}, "shimmer.modules.vae.VAEEncoder": {"tf": 1}, "shimmer.modules.vae.VAEEncoder.forward": {"tf": 1}, "shimmer.modules.vae.VAEDecoder": {"tf": 1}, "shimmer.modules.vae.VAEDecoder.forward": {"tf": 1}, "shimmer.modules.vae.VAE": {"tf": 1}, "shimmer.modules.vae.VAE.__init__": {"tf": 1}, "shimmer.modules.vae.VAE.beta": {"tf": 1}, "shimmer.modules.vae.VAE.encoder": {"tf": 1}, "shimmer.modules.vae.VAE.decoder": {"tf": 1}, "shimmer.modules.vae.VAE.encode": {"tf": 1}, "shimmer.modules.vae.VAE.decode": {"tf": 1}, "shimmer.modules.vae.VAE.forward": {"tf": 1}, "shimmer.modules.utils": {"tf": 1}, "shimmer.modules.utils.translation": {"tf": 1}, "shimmer.modules.utils.cycle": {"tf": 1}, "shimmer.modules.utils.batch_demi_cycles": {"tf": 1}, "shimmer.modules.utils.batch_cycles": {"tf": 1}, "shimmer.modules.utils.batch_translations": {"tf": 1}, "shimmer.utils": {"tf": 1}, "shimmer.utils.MIGRATION_DIR": {"tf": 1}, "shimmer.utils.group_batch_size": {"tf": 1}, 
"shimmer.utils.groups_batch_size": {"tf": 1}, "shimmer.utils.groups_device": {"tf": 1}, "shimmer.utils.group_device": {"tf": 1}, "shimmer.utils.migrate_model": {"tf": 1}, "shimmer.utils.SaveMigrations": {"tf": 1}, "shimmer.utils.SaveMigrations.migrations": {"tf": 1}, "shimmer.utils.SaveMigrations.on_save_checkpoint": {"tf": 1}, "shimmer.cli.ckpt_migration": {"tf": 1}, "shimmer.cli.ckpt_migration.migrate_ckpt": {"tf": 1}}, "df": 232}}}}}}, "c": {"docs": {}, "df": 0, "h": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "d": {"docs": {}, "df": 0, "u": {"docs": {}, "df": 0, "l": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "r": {"docs": {"shimmer.modules.global_workspace.GlobalWorkspaceBase.scheduler_args": {"tf": 1}}, "df": 1, "a": {"docs": {}, "df": 0, "r": {"docs": {}, "df": 0, "g": {"docs": {}, "df": 0, "s": {"docs": {"shimmer.modules.global_workspace.SchedulerArgs": {"tf": 1}, "shimmer.modules.global_workspace.SchedulerArgs.max_lr": {"tf": 1}, "shimmer.modules.global_workspace.SchedulerArgs.total_steps": {"tf": 1}}, "df": 3}}}}}}}}}}}, "o": {"docs": {}, "df": 0, "r": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "s": {"docs": {"shimmer.modules.gw_module.compute_fusion_scores": {"tf": 1}}, "df": 1}}}}, "a": {"docs": {}, "df": 0, "l": {"docs": {}, "df": 0, "e": {"docs": {"shimmer.modules.contrastive_loss.ContrastiveLoss.learn_logit_scale": {"tf": 1}}, "df": 1}}}}, "t": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "p": {"docs": {"shimmer.modules.global_workspace.GlobalWorkspaceBase.generic_step": {"tf": 1}, "shimmer.modules.losses.GWLossesBase.step": {"tf": 1}, "shimmer.modules.losses.GWLosses2Domains.step": {"tf": 1}, "shimmer.modules.losses.GWLosses.step": {"tf": 1}, "shimmer.modules.losses.GWLossesBayesian.step": {"tf": 1}}, "df": 5, "s": {"docs": {"shimmer.modules.global_workspace.SchedulerArgs.total_steps": {"tf": 1}}, "df": 1}}}, "a": {"docs": {}, "df": 0, "t": {"docs": {}, "df": 0, "e": {"docs": 
{"shimmer.modules.selection.SelectionBase.update_gw_state": {"tf": 1}}, "df": 1, "s": {"docs": {"shimmer.modules.global_workspace.GWPredictionsBase.states": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.batch_gw_states": {"tf": 1}, "shimmer.modules.global_workspace.GWPredictions.states": {"tf": 1}}, "df": 3}}}}}, "e": {"docs": {}, "df": 0, "l": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "c": {"docs": {}, "df": 0, "t": {"docs": {}, "df": 0, "i": {"docs": {}, "df": 0, "o": {"docs": {}, "df": 0, "n": {"docs": {"shimmer.modules.global_workspace.GlobalWorkspaceBase.selection_mod": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBayesian.sensitivity_selection": {"tf": 1}, "shimmer.modules.selection": {"tf": 1}, "shimmer.modules.selection.SelectionBase": {"tf": 1}, "shimmer.modules.selection.SelectionBase.update_gw_state": {"tf": 1}, "shimmer.modules.selection.SelectionBase.forward": {"tf": 1}, "shimmer.modules.selection.SingleDomainSelection": {"tf": 1}, "shimmer.modules.selection.SingleDomainSelection.forward": {"tf": 1}, "shimmer.modules.selection.FixedSharedSelection": {"tf": 1}, "shimmer.modules.selection.FixedSharedSelection.forward": {"tf": 1}, "shimmer.modules.selection.KQFixedQSelection": {"tf": 1}, "shimmer.modules.selection.KQFixedQSelection.__init__": {"tf": 1}, "shimmer.modules.selection.KQFixedQSelection.head_size": {"tf": 1}, "shimmer.modules.selection.KQFixedQSelection.query_layer": {"tf": 1}, "shimmer.modules.selection.KQFixedQSelection.key_layers": {"tf": 1}, "shimmer.modules.selection.KQFixedQSelection.forward": {"tf": 1}, "shimmer.modules.selection.RandomSelection": {"tf": 1}, "shimmer.modules.selection.RandomSelection.__init__": {"tf": 1}, "shimmer.modules.selection.RandomSelection.temperature": {"tf": 1}, "shimmer.modules.selection.RandomSelection.forward": {"tf": 1}, "shimmer.modules.selection.DynamicQueryAttention": {"tf": 1}, "shimmer.modules.selection.DynamicQueryAttention.__init__": {"tf": 1}, 
"shimmer.modules.selection.DynamicQueryAttention.head_size": {"tf": 1}, "shimmer.modules.selection.DynamicQueryAttention.query_layer": {"tf": 1}, "shimmer.modules.selection.DynamicQueryAttention.key_layers": {"tf": 1}, "shimmer.modules.selection.DynamicQueryAttention.fuse_weighted_encodings": {"tf": 1}, "shimmer.modules.selection.DynamicQueryAttention.forward": {"tf": 1}, "shimmer.modules.losses.GWLosses2Domains.selection_mod": {"tf": 1}, "shimmer.modules.losses.GWLosses.selection_mod": {"tf": 1}, "shimmer.modules.losses.GWLossesBayesian.selection_mod": {"tf": 1}}, "df": 30, "b": {"docs": {}, "df": 0, "a": {"docs": {}, "df": 0, "s": {"docs": {}, "df": 0, "e": {"docs": {"shimmer.modules.selection.SelectionBase": {"tf": 1}, "shimmer.modules.selection.SelectionBase.update_gw_state": {"tf": 1}, "shimmer.modules.selection.SelectionBase.forward": {"tf": 1}}, "df": 3}}}}}}}}}}}, "n": {"docs": {}, "df": 0, "s": {"docs": {}, "df": 0, "i": {"docs": {}, "df": 0, "t": {"docs": {}, "df": 0, "i": {"docs": {}, "df": 0, "v": {"docs": {}, "df": 0, "i": {"docs": {}, "df": 0, "t": {"docs": {}, "df": 0, "y": {"docs": {"shimmer.modules.gw_module.GWModuleBayesian.sensitivity_selection": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBayesian.sensitivity_precision": {"tf": 1}}, "df": 2}}}}}}}}}}, "o": {"docs": {}, "df": 0, "f": {"docs": {}, "df": 0, "t": {"docs": {}, "df": 0, "m": {"docs": {}, "df": 0, "a": {"docs": {}, "df": 0, "x": {"docs": {"shimmer.modules.gw_module.GWModuleBayesian.precision_softmax_temp": {"tf": 1}}, "df": 1}}}}}}, "i": {"docs": {}, "df": 0, "n": {"docs": {}, "df": 0, "g": {"docs": {}, "df": 0, "l": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "d": {"docs": {}, "df": 0, "o": {"docs": {}, "df": 0, "m": {"docs": {}, "df": 0, "a": {"docs": {}, "df": 0, "i": {"docs": {}, "df": 0, "n": {"docs": {}, "df": 0, "s": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "l": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "c": {"docs": {}, "df": 0, "t": {"docs": {}, "df": 0, 
"i": {"docs": {}, "df": 0, "o": {"docs": {}, "df": 0, "n": {"docs": {"shimmer.modules.selection.SingleDomainSelection": {"tf": 1}, "shimmer.modules.selection.SingleDomainSelection.forward": {"tf": 1}}, "df": 2}}}}}}}}}}}}}}}}}}}, "z": {"docs": {}, "df": 0, "e": {"docs": {"shimmer.modules.selection.KQFixedQSelection.head_size": {"tf": 1}, "shimmer.modules.selection.DynamicQueryAttention.head_size": {"tf": 1}, "shimmer.dataset.RepeatedDataset.dataset_size": {"tf": 1}, "shimmer.utils.group_batch_size": {"tf": 1}, "shimmer.utils.groups_batch_size": {"tf": 1}}, "df": 5}}}, "a": {"docs": {}, "df": 0, "v": {"docs": {}, "df": 0, "e": {"docs": {"shimmer.utils.SaveMigrations.on_save_checkpoint": {"tf": 1}}, "df": 1, "m": {"docs": {}, "df": 0, "i": {"docs": {}, "df": 0, "g": {"docs": {}, "df": 0, "r": {"docs": {}, "df": 0, "a": {"docs": {}, "df": 0, "t": {"docs": {}, "df": 0, "i": {"docs": {}, "df": 0, "o": {"docs": {}, "df": 0, "n": {"docs": {}, "df": 0, "s": {"docs": {"shimmer.utils.SaveMigrations": {"tf": 1}, "shimmer.utils.SaveMigrations.migrations": {"tf": 1}, "shimmer.utils.SaveMigrations.on_save_checkpoint": {"tf": 1}}, "df": 3}}}}}}}}}}}}}}, "t": {"docs": {}, "df": 0, "y": {"docs": {}, "df": 0, "p": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "s": {"docs": {"shimmer.types": {"tf": 1}, "shimmer.types.RawDomainGroupT": {"tf": 1}, "shimmer.types.RawDomainGroupDT": {"tf": 1}, "shimmer.types.LatentsDomainGroupT": {"tf": 1}, "shimmer.types.LatentsDomainGroupDT": {"tf": 1}, "shimmer.types.RawDomainGroupsT": {"tf": 1}, "shimmer.types.RawDomainGroupsDT": {"tf": 1}, "shimmer.types.LatentsDomainGroupsT": {"tf": 1}, "shimmer.types.LatentsDomainGroupsDT": {"tf": 1}, "shimmer.types.ModelModeT": {"tf": 1}}, "df": 10}}}}, "o": {"docs": {}, "df": 0, "t": {"docs": {}, "df": 0, "a": {"docs": {}, "df": 0, "l": {"docs": {"shimmer.modules.global_workspace.SchedulerArgs.total_steps": {"tf": 1}}, "df": 1}}}}, "r": {"docs": {"shimmer.modules.domain.DomainModule.compute_tr_loss": {"tf": 
1}}, "df": 1, "a": {"docs": {}, "df": 0, "n": {"docs": {}, "df": 0, "s": {"docs": {}, "df": 0, "l": {"docs": {}, "df": 0, "a": {"docs": {}, "df": 0, "t": {"docs": {}, "df": 0, "i": {"docs": {}, "df": 0, "o": {"docs": {}, "df": 0, "n": {"docs": {"shimmer.modules.losses.translation_loss": {"tf": 1}, "shimmer.modules.losses.GWLosses2Domains.translation_loss": {"tf": 1}, "shimmer.modules.utils.translation": {"tf": 1}}, "df": 3, "s": {"docs": {"shimmer.modules.global_workspace.GWPredictions.translations": {"tf": 1}, "shimmer.modules.losses.LossCoefs.translations": {"tf": 1}, "shimmer.modules.losses.BroadcastLossCoefs.translations": {"tf": 1}, "shimmer.modules.utils.batch_translations": {"tf": 1}}, "df": 4}}}}}}}}}}}, "e": {"docs": {}, "df": 0, "m": {"docs": {}, "df": 0, "p": {"docs": {"shimmer.modules.gw_module.GWModuleBayesian.precision_softmax_temp": {"tf": 1}}, "df": 1, "e": {"docs": {}, "df": 0, "r": {"docs": {}, "df": 0, "a": {"docs": {}, "df": 0, "t": {"docs": {}, "df": 0, "u": {"docs": {}, "df": 0, "r": {"docs": {}, "df": 0, "e": {"docs": {"shimmer.modules.selection.RandomSelection.temperature": {"tf": 1}}, "df": 1}}}}}}}}}}}, "r": {"docs": {}, "df": 0, "a": {"docs": {}, "df": 0, "w": {"docs": {}, "df": 0, "d": {"docs": {}, "df": 0, "o": {"docs": {}, "df": 0, "m": {"docs": {}, "df": 0, "a": {"docs": {}, "df": 0, "i": {"docs": {}, "df": 0, "n": {"docs": {}, "df": 0, "g": {"docs": {}, "df": 0, "r": {"docs": {}, "df": 0, "o": {"docs": {}, "df": 0, "u": {"docs": {}, "df": 0, "p": {"docs": {}, "df": 0, "t": {"docs": {"shimmer.types.RawDomainGroupT": {"tf": 1}}, "df": 1}, "d": {"docs": {}, "df": 0, "t": {"docs": {"shimmer.types.RawDomainGroupDT": {"tf": 1}}, "df": 1}}, "s": {"docs": {}, "df": 0, "t": {"docs": {"shimmer.types.RawDomainGroupsT": {"tf": 1}}, "df": 1}, "d": {"docs": {}, "df": 0, "t": {"docs": {"shimmer.types.RawDomainGroupsDT": {"tf": 1}}, "df": 1}}}}}}}}}}}}}}}, "n": {"docs": {}, "df": 0, "d": {"docs": {}, "df": 0, "o": {"docs": {}, "df": 0, "m": {"docs": 
{}, "df": 0, "s": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "l": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "c": {"docs": {}, "df": 0, "t": {"docs": {}, "df": 0, "i": {"docs": {}, "df": 0, "o": {"docs": {}, "df": 0, "n": {"docs": {"shimmer.modules.selection.RandomSelection": {"tf": 1}, "shimmer.modules.selection.RandomSelection.__init__": {"tf": 1}, "shimmer.modules.selection.RandomSelection.temperature": {"tf": 1}, "shimmer.modules.selection.RandomSelection.forward": {"tf": 1}}, "df": 4}}}}}}}}}}}}}}, "e": {"docs": {}, "df": 0, "d": {"docs": {}, "df": 0, "u": {"docs": {}, "df": 0, "c": {"docs": {}, "df": 0, "t": {"docs": {}, "df": 0, "i": {"docs": {}, "df": 0, "o": {"docs": {}, "df": 0, "n": {"docs": {"shimmer.modules.contrastive_loss.ContrastiveLoss.reduction": {"tf": 1}}, "df": 1}}}}}}}, "p": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "a": {"docs": {}, "df": 0, "t": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "d": {"docs": {}, "df": 0, "d": {"docs": {}, "df": 0, "a": {"docs": {}, "df": 0, "t": {"docs": {}, "df": 0, "a": {"docs": {}, "df": 0, "s": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "t": {"docs": {"shimmer.dataset.RepeatedDataset": {"tf": 1}, "shimmer.dataset.RepeatedDataset.__init__": {"tf": 1}, "shimmer.dataset.RepeatedDataset.dataset": {"tf": 1}, "shimmer.dataset.RepeatedDataset.dataset_size": {"tf": 1}}, "df": 4}}}}}}}}}}}}, "a": {"docs": {}, "df": 0, "r": {"docs": {}, "df": 0, "a": {"docs": {}, "df": 0, "m": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "t": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "r": {"docs": {}, "df": 0, "i": {"docs": {}, "df": 0, "z": {"docs": {}, "df": 0, "e": {"docs": {"shimmer.modules.vae.reparameterize": {"tf": 1}}, "df": 1}}}}}}}}}}}}}}, "l": {"docs": {}, "df": 0, "a": {"docs": {}, "df": 0, "t": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "n": {"docs": {}, "df": 0, "t": {"docs": {"shimmer.modules.domain.DomainModule.latent_dim": {"tf": 1}}, "df": 1, "s": {"docs": {}, "df": 0, 
"d": {"docs": {}, "df": 0, "o": {"docs": {}, "df": 0, "m": {"docs": {}, "df": 0, "a": {"docs": {}, "df": 0, "i": {"docs": {}, "df": 0, "n": {"docs": {}, "df": 0, "g": {"docs": {}, "df": 0, "r": {"docs": {}, "df": 0, "o": {"docs": {}, "df": 0, "u": {"docs": {}, "df": 0, "p": {"docs": {}, "df": 0, "t": {"docs": {"shimmer.types.LatentsDomainGroupT": {"tf": 1}}, "df": 1}, "d": {"docs": {}, "df": 0, "t": {"docs": {"shimmer.types.LatentsDomainGroupDT": {"tf": 1}}, "df": 1}}, "s": {"docs": {}, "df": 0, "t": {"docs": {"shimmer.types.LatentsDomainGroupsT": {"tf": 1}}, "df": 1}, "d": {"docs": {}, "df": 0, "t": {"docs": {"shimmer.types.LatentsDomainGroupsDT": {"tf": 1}}, "df": 1}}}}}}}}}}}}}}}}}}}, "y": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "r": {"docs": {"shimmer.modules.selection.KQFixedQSelection.query_layer": {"tf": 1}, "shimmer.modules.selection.DynamicQueryAttention.query_layer": {"tf": 1}}, "df": 2, "s": {"docs": {"shimmer.modules.gw_module.get_n_layers": {"tf": 1}, "shimmer.modules.gw_module.GWDecoder.n_layers": {"tf": 1}, "shimmer.modules.selection.KQFixedQSelection.key_layers": {"tf": 1}, "shimmer.modules.selection.DynamicQueryAttention.key_layers": {"tf": 1}}, "df": 4}}}}}, "r": {"docs": {"shimmer.modules.global_workspace.SchedulerArgs.max_lr": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.optim_lr": {"tf": 1}}, "df": 2}, "o": {"docs": {}, "df": 0, "s": {"docs": {}, "df": 0, "s": {"docs": {"shimmer.modules.global_workspace.GlobalWorkspaceBase.loss_mod": {"tf": 1}, "shimmer.modules.domain.LossOutput.loss": {"tf": 1}, "shimmer.modules.domain.DomainModule.compute_loss": {"tf": 1}, "shimmer.modules.domain.DomainModule.compute_dcy_loss": {"tf": 1}, "shimmer.modules.domain.DomainModule.compute_cy_loss": {"tf": 1}, "shimmer.modules.domain.DomainModule.compute_tr_loss": {"tf": 1}, "shimmer.modules.domain.DomainModule.compute_broadcast_loss": {"tf": 1}, "shimmer.modules.losses.demi_cycle_loss": {"tf": 1}, "shimmer.modules.losses.cycle_loss": 
{"tf": 1}, "shimmer.modules.losses.translation_loss": {"tf": 1}, "shimmer.modules.losses.contrastive_loss": {"tf": 1}, "shimmer.modules.losses.contrastive_loss_bayesian": {"tf": 1}, "shimmer.modules.losses.GWLosses2Domains.loss_coefs": {"tf": 1}, "shimmer.modules.losses.GWLosses2Domains.demi_cycle_loss": {"tf": 1}, "shimmer.modules.losses.GWLosses2Domains.cycle_loss": {"tf": 1}, "shimmer.modules.losses.GWLosses2Domains.translation_loss": {"tf": 1}, "shimmer.modules.losses.GWLosses2Domains.contrastive_loss": {"tf": 1}, "shimmer.modules.losses.broadcast_loss": {"tf": 1}, "shimmer.modules.losses.GWLosses.loss_coefs": {"tf": 1}, "shimmer.modules.losses.GWLosses.contrastive_loss": {"tf": 1}, "shimmer.modules.losses.GWLosses.broadcast_loss": {"tf": 1}, "shimmer.modules.losses.GWLossesBayesian.loss_coefs": {"tf": 1}, "shimmer.modules.losses.GWLossesBayesian.contrastive_loss": {"tf": 1}, "shimmer.modules.losses.GWLossesBayesian.broadcast_loss": {"tf": 1}, "shimmer.modules.contrastive_loss": {"tf": 1}, "shimmer.modules.contrastive_loss.ContrastiveLossType": {"tf": 1}, "shimmer.modules.contrastive_loss.ContrastiveLossBayesianType": {"tf": 1}, "shimmer.modules.contrastive_loss.info_nce": {"tf": 1}, "shimmer.modules.contrastive_loss.contrastive_loss": {"tf": 1.4142135623730951}, "shimmer.modules.contrastive_loss.ContrastiveLoss": {"tf": 1}, "shimmer.modules.contrastive_loss.ContrastiveLoss.__init__": {"tf": 1}, "shimmer.modules.contrastive_loss.ContrastiveLoss.learn_logit_scale": {"tf": 1}, "shimmer.modules.contrastive_loss.ContrastiveLoss.reduction": {"tf": 1}, "shimmer.modules.contrastive_loss.ContrastiveLoss.forward": {"tf": 1}, "shimmer.modules.vae.kl_divergence_loss": {"tf": 1}}, "df": 35, "o": {"docs": {}, "df": 0, "u": {"docs": {}, "df": 0, "t": {"docs": {}, "df": 0, "p": {"docs": {}, "df": 0, "u": {"docs": {}, "df": 0, "t": {"docs": {"shimmer.modules.domain.LossOutput": {"tf": 1}, "shimmer.modules.domain.LossOutput.__init__": {"tf": 1}, 
"shimmer.modules.domain.LossOutput.loss": {"tf": 1}, "shimmer.modules.domain.LossOutput.metrics": {"tf": 1}, "shimmer.modules.domain.LossOutput.all": {"tf": 1}}, "df": 5}}}}}}, "e": {"docs": {}, "df": 0, "s": {"docs": {"shimmer.modules.losses": {"tf": 1}, "shimmer.modules.losses.GWLossesBase": {"tf": 1}, "shimmer.modules.losses.GWLossesBase.step": {"tf": 1}, "shimmer.modules.losses.demi_cycle_loss": {"tf": 1}, "shimmer.modules.losses.cycle_loss": {"tf": 1}, "shimmer.modules.losses.translation_loss": {"tf": 1}, "shimmer.modules.losses.contrastive_loss": {"tf": 1}, "shimmer.modules.losses.contrastive_loss_bayesian": {"tf": 1}, "shimmer.modules.losses.LossCoefs": {"tf": 1}, "shimmer.modules.losses.LossCoefs.demi_cycles": {"tf": 1}, "shimmer.modules.losses.LossCoefs.cycles": {"tf": 1}, "shimmer.modules.losses.LossCoefs.translations": {"tf": 1}, "shimmer.modules.losses.LossCoefs.contrastives": {"tf": 1}, "shimmer.modules.losses.GWLosses2Domains": {"tf": 1}, "shimmer.modules.losses.GWLosses2Domains.__init__": {"tf": 1}, "shimmer.modules.losses.GWLosses2Domains.gw_mod": {"tf": 1}, "shimmer.modules.losses.GWLosses2Domains.selection_mod": {"tf": 1}, "shimmer.modules.losses.GWLosses2Domains.domain_mods": {"tf": 1}, "shimmer.modules.losses.GWLosses2Domains.loss_coefs": {"tf": 1}, "shimmer.modules.losses.GWLosses2Domains.contrastive_fn": {"tf": 1}, "shimmer.modules.losses.GWLosses2Domains.demi_cycle_loss": {"tf": 1}, "shimmer.modules.losses.GWLosses2Domains.cycle_loss": {"tf": 1}, "shimmer.modules.losses.GWLosses2Domains.translation_loss": {"tf": 1}, "shimmer.modules.losses.GWLosses2Domains.contrastive_loss": {"tf": 1}, "shimmer.modules.losses.GWLosses2Domains.step": {"tf": 1}, "shimmer.modules.losses.generate_partitions": {"tf": 1}, "shimmer.modules.losses.broadcast_loss": {"tf": 1}, "shimmer.modules.losses.BroadcastLossCoefs": {"tf": 1}, "shimmer.modules.losses.BroadcastLossCoefs.contrastives": {"tf": 1}, "shimmer.modules.losses.BroadcastLossCoefs.fused": {"tf": 1}, 
"shimmer.modules.losses.BroadcastLossCoefs.demi_cycles": {"tf": 1}, "shimmer.modules.losses.BroadcastLossCoefs.cycles": {"tf": 1}, "shimmer.modules.losses.BroadcastLossCoefs.translations": {"tf": 1}, "shimmer.modules.losses.GWLosses": {"tf": 1}, "shimmer.modules.losses.GWLosses.__init__": {"tf": 1}, "shimmer.modules.losses.GWLosses.gw_mod": {"tf": 1}, "shimmer.modules.losses.GWLosses.selection_mod": {"tf": 1}, "shimmer.modules.losses.GWLosses.domain_mods": {"tf": 1}, "shimmer.modules.losses.GWLosses.loss_coefs": {"tf": 1}, "shimmer.modules.losses.GWLosses.contrastive_fn": {"tf": 1}, "shimmer.modules.losses.GWLosses.contrastive_loss": {"tf": 1}, "shimmer.modules.losses.GWLosses.broadcast_loss": {"tf": 1}, "shimmer.modules.losses.GWLosses.step": {"tf": 1}, "shimmer.modules.losses.GWLossesBayesian": {"tf": 1}, "shimmer.modules.losses.GWLossesBayesian.__init__": {"tf": 1}, "shimmer.modules.losses.GWLossesBayesian.gw_mod": {"tf": 1}, "shimmer.modules.losses.GWLossesBayesian.selection_mod": {"tf": 1}, "shimmer.modules.losses.GWLossesBayesian.domain_mods": {"tf": 1}, "shimmer.modules.losses.GWLossesBayesian.loss_coefs": {"tf": 1}, "shimmer.modules.losses.GWLossesBayesian.contrastive_fn": {"tf": 1}, "shimmer.modules.losses.GWLossesBayesian.use_normalized_constrastive": {"tf": 1}, "shimmer.modules.losses.GWLossesBayesian.contrastive_loss": {"tf": 1}, "shimmer.modules.losses.GWLossesBayesian.broadcast_loss": {"tf": 1}, "shimmer.modules.losses.GWLossesBayesian.step": {"tf": 1}}, "df": 54}}, "c": {"docs": {}, "df": 0, "o": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "f": {"docs": {}, "df": 0, "s": {"docs": {"shimmer.modules.losses.LossCoefs": {"tf": 1}, "shimmer.modules.losses.LossCoefs.demi_cycles": {"tf": 1}, "shimmer.modules.losses.LossCoefs.cycles": {"tf": 1}, "shimmer.modules.losses.LossCoefs.translations": {"tf": 1}, "shimmer.modules.losses.LossCoefs.contrastives": {"tf": 1}}, "df": 5}}}}}}}, "g": {"docs": {}, "df": 0, "i": {"docs": {}, "df": 0, "t": {"docs": 
{"shimmer.modules.contrastive_loss.ContrastiveLoss.learn_logit_scale": {"tf": 1}}, "df": 1}}}}, "e": {"docs": {}, "df": 0, "a": {"docs": {}, "df": 0, "r": {"docs": {}, "df": 0, "n": {"docs": {"shimmer.modules.contrastive_loss.ContrastiveLoss.learn_logit_scale": {"tf": 1}}, "df": 1}}}}}, "m": {"docs": {}, "df": 0, "o": {"docs": {}, "df": 0, "d": {"docs": {"shimmer.modules.global_workspace.GlobalWorkspaceBase.gw_mod": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.selection_mod": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.loss_mod": {"tf": 1}, "shimmer.modules.losses.GWLosses2Domains.gw_mod": {"tf": 1}, "shimmer.modules.losses.GWLosses2Domains.selection_mod": {"tf": 1}, "shimmer.modules.losses.GWLosses.gw_mod": {"tf": 1}, "shimmer.modules.losses.GWLosses.selection_mod": {"tf": 1}, "shimmer.modules.losses.GWLossesBayesian.gw_mod": {"tf": 1}, "shimmer.modules.losses.GWLossesBayesian.selection_mod": {"tf": 1}}, "df": 9, "e": {"docs": {}, "df": 0, "l": {"docs": {"shimmer.utils.migrate_model": {"tf": 1}}, "df": 1, "m": {"docs": {}, "df": 0, "o": {"docs": {}, "df": 0, "d": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "t": {"docs": {"shimmer.types.ModelModeT": {"tf": 1}}, "df": 1}}}}}}}, "u": {"docs": {}, "df": 0, "l": {"docs": {}, "df": 0, "e": {"docs": {"shimmer.modules.gw_module": {"tf": 1}, "shimmer.modules.gw_module.get_n_layers": {"tf": 1}, "shimmer.modules.gw_module.GWDecoder": {"tf": 1}, "shimmer.modules.gw_module.GWDecoder.__init__": {"tf": 1}, "shimmer.modules.gw_module.GWDecoder.in_dim": {"tf": 1}, "shimmer.modules.gw_module.GWDecoder.hidden_dim": {"tf": 1}, "shimmer.modules.gw_module.GWDecoder.out_dim": {"tf": 1}, "shimmer.modules.gw_module.GWDecoder.n_layers": {"tf": 1}, "shimmer.modules.gw_module.GWEncoder": {"tf": 1}, "shimmer.modules.gw_module.GWEncoder.__init__": {"tf": 1}, "shimmer.modules.gw_module.GWEncoder.forward": {"tf": 1}, "shimmer.modules.gw_module.GWEncoderLinear": {"tf": 1}, 
"shimmer.modules.gw_module.GWEncoderLinear.forward": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBase": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBase.__init__": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBase.domain_mods": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBase.workspace_dim": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBase.fuse": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBase.encode": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBase.encode_and_fuse": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBase.decode": {"tf": 1}, "shimmer.modules.gw_module.GWModule": {"tf": 1}, "shimmer.modules.gw_module.GWModule.__init__": {"tf": 1}, "shimmer.modules.gw_module.GWModule.gw_encoders": {"tf": 1}, "shimmer.modules.gw_module.GWModule.gw_decoders": {"tf": 1}, "shimmer.modules.gw_module.GWModule.fuse": {"tf": 1}, "shimmer.modules.gw_module.GWModule.encode": {"tf": 1}, "shimmer.modules.gw_module.GWModule.decode": {"tf": 1}, "shimmer.modules.gw_module.compute_fusion_scores": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBayesian": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBayesian.__init__": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBayesian.precisions": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBayesian.sensitivity_selection": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBayesian.sensitivity_precision": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBayesian.precision_softmax_temp": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBayesian.get_precision": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBayesian.fuse": {"tf": 1}}, "df": 37, "s": {"docs": {"shimmer.modules.global_workspace": {"tf": 1}, "shimmer.modules.global_workspace.SchedulerArgs": {"tf": 1}, "shimmer.modules.global_workspace.SchedulerArgs.max_lr": {"tf": 1}, "shimmer.modules.global_workspace.SchedulerArgs.total_steps": {"tf": 1}, "shimmer.modules.global_workspace.GWPredictionsBase": {"tf": 1}, "shimmer.modules.global_workspace.GWPredictionsBase.states": {"tf": 1}, 
"shimmer.modules.global_workspace.GlobalWorkspaceBase": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.gw_mod": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.selection_mod": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.loss_mod": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.optim_lr": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.optim_weight_decay": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.scheduler_args": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.domain_mods": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.workspace_dim": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.encode_and_fuse": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.encode": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.fuse": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.decode": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.batch_gw_states": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.encode_domain": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.encode_domains": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.decode_domain": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.decode_domains": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.generic_step": {"tf": 1}, "shimmer.modules.global_workspace.freeze_domain_modules": {"tf": 1.4142135623730951}, "shimmer.modules.global_workspace.GWPredictions": {"tf": 1}, "shimmer.modules.global_workspace.GWPredictions.demi_cycles": {"tf": 1}, "shimmer.modules.global_workspace.GWPredictions.cycles": {"tf": 1}, "shimmer.modules.global_workspace.GWPredictions.translations": {"tf": 1}, "shimmer.modules.global_workspace.GWPredictions.states": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspace2Domains": {"tf": 1}, 
"shimmer.modules.global_workspace.GlobalWorkspace2Domains.__init__": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspace2Domains.forward": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspace": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspace.__init__": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspace.forward": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBayesian": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBayesian.__init__": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBayesian.forward": {"tf": 1}, "shimmer.modules.global_workspace.pretrained_global_workspace": {"tf": 1}, "shimmer.modules.domain": {"tf": 1}, "shimmer.modules.domain.LossOutput": {"tf": 1}, "shimmer.modules.domain.LossOutput.__init__": {"tf": 1}, "shimmer.modules.domain.LossOutput.loss": {"tf": 1}, "shimmer.modules.domain.LossOutput.metrics": {"tf": 1}, "shimmer.modules.domain.LossOutput.all": {"tf": 1}, "shimmer.modules.domain.DomainModule": {"tf": 1}, "shimmer.modules.domain.DomainModule.__init__": {"tf": 1}, "shimmer.modules.domain.DomainModule.latent_dim": {"tf": 1}, "shimmer.modules.domain.DomainModule.encode": {"tf": 1}, "shimmer.modules.domain.DomainModule.decode": {"tf": 1}, "shimmer.modules.domain.DomainModule.compute_loss": {"tf": 1}, "shimmer.modules.domain.DomainModule.compute_dcy_loss": {"tf": 1}, "shimmer.modules.domain.DomainModule.compute_cy_loss": {"tf": 1}, "shimmer.modules.domain.DomainModule.compute_tr_loss": {"tf": 1}, "shimmer.modules.domain.DomainModule.compute_broadcast_loss": {"tf": 1}, "shimmer.modules.gw_module": {"tf": 1}, "shimmer.modules.gw_module.get_n_layers": {"tf": 1}, "shimmer.modules.gw_module.GWDecoder": {"tf": 1}, "shimmer.modules.gw_module.GWDecoder.__init__": {"tf": 1}, "shimmer.modules.gw_module.GWDecoder.in_dim": {"tf": 1}, "shimmer.modules.gw_module.GWDecoder.hidden_dim": {"tf": 1}, "shimmer.modules.gw_module.GWDecoder.out_dim": {"tf": 1}, 
"shimmer.modules.gw_module.GWDecoder.n_layers": {"tf": 1}, "shimmer.modules.gw_module.GWEncoder": {"tf": 1}, "shimmer.modules.gw_module.GWEncoder.__init__": {"tf": 1}, "shimmer.modules.gw_module.GWEncoder.forward": {"tf": 1}, "shimmer.modules.gw_module.GWEncoderLinear": {"tf": 1}, "shimmer.modules.gw_module.GWEncoderLinear.forward": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBase": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBase.__init__": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBase.domain_mods": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBase.workspace_dim": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBase.fuse": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBase.encode": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBase.encode_and_fuse": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBase.decode": {"tf": 1}, "shimmer.modules.gw_module.GWModule": {"tf": 1}, "shimmer.modules.gw_module.GWModule.__init__": {"tf": 1}, "shimmer.modules.gw_module.GWModule.gw_encoders": {"tf": 1}, "shimmer.modules.gw_module.GWModule.gw_decoders": {"tf": 1}, "shimmer.modules.gw_module.GWModule.fuse": {"tf": 1}, "shimmer.modules.gw_module.GWModule.encode": {"tf": 1}, "shimmer.modules.gw_module.GWModule.decode": {"tf": 1}, "shimmer.modules.gw_module.compute_fusion_scores": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBayesian": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBayesian.__init__": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBayesian.precisions": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBayesian.sensitivity_selection": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBayesian.sensitivity_precision": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBayesian.precision_softmax_temp": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBayesian.get_precision": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBayesian.fuse": {"tf": 1}, "shimmer.modules.selection": {"tf": 1}, "shimmer.modules.selection.SelectionBase": {"tf": 1}, 
"shimmer.modules.selection.SelectionBase.update_gw_state": {"tf": 1}, "shimmer.modules.selection.SelectionBase.forward": {"tf": 1}, "shimmer.modules.selection.SingleDomainSelection": {"tf": 1}, "shimmer.modules.selection.SingleDomainSelection.forward": {"tf": 1}, "shimmer.modules.selection.FixedSharedSelection": {"tf": 1}, "shimmer.modules.selection.FixedSharedSelection.forward": {"tf": 1}, "shimmer.modules.selection.KQFixedQSelection": {"tf": 1}, "shimmer.modules.selection.KQFixedQSelection.__init__": {"tf": 1}, "shimmer.modules.selection.KQFixedQSelection.head_size": {"tf": 1}, "shimmer.modules.selection.KQFixedQSelection.query_layer": {"tf": 1}, "shimmer.modules.selection.KQFixedQSelection.key_layers": {"tf": 1}, "shimmer.modules.selection.KQFixedQSelection.forward": {"tf": 1}, "shimmer.modules.selection.RandomSelection": {"tf": 1}, "shimmer.modules.selection.RandomSelection.__init__": {"tf": 1}, "shimmer.modules.selection.RandomSelection.temperature": {"tf": 1}, "shimmer.modules.selection.RandomSelection.forward": {"tf": 1}, "shimmer.modules.selection.DynamicQueryAttention": {"tf": 1}, "shimmer.modules.selection.DynamicQueryAttention.__init__": {"tf": 1}, "shimmer.modules.selection.DynamicQueryAttention.head_size": {"tf": 1}, "shimmer.modules.selection.DynamicQueryAttention.query_layer": {"tf": 1}, "shimmer.modules.selection.DynamicQueryAttention.key_layers": {"tf": 1}, "shimmer.modules.selection.DynamicQueryAttention.fuse_weighted_encodings": {"tf": 1}, "shimmer.modules.selection.DynamicQueryAttention.forward": {"tf": 1}, "shimmer.modules.losses": {"tf": 1}, "shimmer.modules.losses.GWLossesBase": {"tf": 1}, "shimmer.modules.losses.GWLossesBase.step": {"tf": 1}, "shimmer.modules.losses.demi_cycle_loss": {"tf": 1}, "shimmer.modules.losses.cycle_loss": {"tf": 1}, "shimmer.modules.losses.translation_loss": {"tf": 1}, "shimmer.modules.losses.contrastive_loss": {"tf": 1}, "shimmer.modules.losses.contrastive_loss_bayesian": {"tf": 1}, 
"shimmer.modules.losses.LossCoefs": {"tf": 1}, "shimmer.modules.losses.LossCoefs.demi_cycles": {"tf": 1}, "shimmer.modules.losses.LossCoefs.cycles": {"tf": 1}, "shimmer.modules.losses.LossCoefs.translations": {"tf": 1}, "shimmer.modules.losses.LossCoefs.contrastives": {"tf": 1}, "shimmer.modules.losses.GWLosses2Domains": {"tf": 1}, "shimmer.modules.losses.GWLosses2Domains.__init__": {"tf": 1}, "shimmer.modules.losses.GWLosses2Domains.gw_mod": {"tf": 1}, "shimmer.modules.losses.GWLosses2Domains.selection_mod": {"tf": 1}, "shimmer.modules.losses.GWLosses2Domains.domain_mods": {"tf": 1}, "shimmer.modules.losses.GWLosses2Domains.loss_coefs": {"tf": 1}, "shimmer.modules.losses.GWLosses2Domains.contrastive_fn": {"tf": 1}, "shimmer.modules.losses.GWLosses2Domains.demi_cycle_loss": {"tf": 1}, "shimmer.modules.losses.GWLosses2Domains.cycle_loss": {"tf": 1}, "shimmer.modules.losses.GWLosses2Domains.translation_loss": {"tf": 1}, "shimmer.modules.losses.GWLosses2Domains.contrastive_loss": {"tf": 1}, "shimmer.modules.losses.GWLosses2Domains.step": {"tf": 1}, "shimmer.modules.losses.generate_partitions": {"tf": 1}, "shimmer.modules.losses.broadcast_loss": {"tf": 1}, "shimmer.modules.losses.BroadcastLossCoefs": {"tf": 1}, "shimmer.modules.losses.BroadcastLossCoefs.contrastives": {"tf": 1}, "shimmer.modules.losses.BroadcastLossCoefs.fused": {"tf": 1}, "shimmer.modules.losses.BroadcastLossCoefs.demi_cycles": {"tf": 1}, "shimmer.modules.losses.BroadcastLossCoefs.cycles": {"tf": 1}, "shimmer.modules.losses.BroadcastLossCoefs.translations": {"tf": 1}, "shimmer.modules.losses.GWLosses": {"tf": 1}, "shimmer.modules.losses.GWLosses.__init__": {"tf": 1}, "shimmer.modules.losses.GWLosses.gw_mod": {"tf": 1}, "shimmer.modules.losses.GWLosses.selection_mod": {"tf": 1}, "shimmer.modules.losses.GWLosses.domain_mods": {"tf": 1}, "shimmer.modules.losses.GWLosses.loss_coefs": {"tf": 1}, "shimmer.modules.losses.GWLosses.contrastive_fn": {"tf": 1}, "shimmer.modules.losses.GWLosses.contrastive_loss": 
{"tf": 1}, "shimmer.modules.losses.GWLosses.broadcast_loss": {"tf": 1}, "shimmer.modules.losses.GWLosses.step": {"tf": 1}, "shimmer.modules.losses.GWLossesBayesian": {"tf": 1}, "shimmer.modules.losses.GWLossesBayesian.__init__": {"tf": 1}, "shimmer.modules.losses.GWLossesBayesian.gw_mod": {"tf": 1}, "shimmer.modules.losses.GWLossesBayesian.selection_mod": {"tf": 1}, "shimmer.modules.losses.GWLossesBayesian.domain_mods": {"tf": 1}, "shimmer.modules.losses.GWLossesBayesian.loss_coefs": {"tf": 1}, "shimmer.modules.losses.GWLossesBayesian.contrastive_fn": {"tf": 1}, "shimmer.modules.losses.GWLossesBayesian.use_normalized_constrastive": {"tf": 1}, "shimmer.modules.losses.GWLossesBayesian.contrastive_loss": {"tf": 1}, "shimmer.modules.losses.GWLossesBayesian.broadcast_loss": {"tf": 1}, "shimmer.modules.losses.GWLossesBayesian.step": {"tf": 1}, "shimmer.modules.contrastive_loss": {"tf": 1}, "shimmer.modules.contrastive_loss.ContrastiveLossType": {"tf": 1}, "shimmer.modules.contrastive_loss.ContrastiveLossBayesianType": {"tf": 1}, "shimmer.modules.contrastive_loss.info_nce": {"tf": 1}, "shimmer.modules.contrastive_loss.contrastive_loss": {"tf": 1}, "shimmer.modules.contrastive_loss.ContrastiveLoss": {"tf": 1}, "shimmer.modules.contrastive_loss.ContrastiveLoss.__init__": {"tf": 1}, "shimmer.modules.contrastive_loss.ContrastiveLoss.learn_logit_scale": {"tf": 1}, "shimmer.modules.contrastive_loss.ContrastiveLoss.reduction": {"tf": 1}, "shimmer.modules.contrastive_loss.ContrastiveLoss.forward": {"tf": 1}, "shimmer.modules.vae": {"tf": 1}, "shimmer.modules.vae.reparameterize": {"tf": 1}, "shimmer.modules.vae.kl_divergence_loss": {"tf": 1}, "shimmer.modules.vae.gaussian_nll": {"tf": 1}, "shimmer.modules.vae.VAEEncoder": {"tf": 1}, "shimmer.modules.vae.VAEEncoder.forward": {"tf": 1}, "shimmer.modules.vae.VAEDecoder": {"tf": 1}, "shimmer.modules.vae.VAEDecoder.forward": {"tf": 1}, "shimmer.modules.vae.VAE": {"tf": 1}, "shimmer.modules.vae.VAE.__init__": {"tf": 1}, 
"shimmer.modules.vae.VAE.beta": {"tf": 1}, "shimmer.modules.vae.VAE.encoder": {"tf": 1}, "shimmer.modules.vae.VAE.decoder": {"tf": 1}, "shimmer.modules.vae.VAE.encode": {"tf": 1}, "shimmer.modules.vae.VAE.decode": {"tf": 1}, "shimmer.modules.vae.VAE.forward": {"tf": 1}, "shimmer.modules.utils": {"tf": 1}, "shimmer.modules.utils.translation": {"tf": 1}, "shimmer.modules.utils.cycle": {"tf": 1}, "shimmer.modules.utils.batch_demi_cycles": {"tf": 1}, "shimmer.modules.utils.batch_cycles": {"tf": 1}, "shimmer.modules.utils.batch_translations": {"tf": 1}}, "df": 205}}}}, "s": {"docs": {"shimmer.modules.global_workspace.GlobalWorkspaceBase.domain_mods": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBase.domain_mods": {"tf": 1}, "shimmer.modules.losses.GWLosses2Domains.domain_mods": {"tf": 1}, "shimmer.modules.losses.GWLosses.domain_mods": {"tf": 1}, "shimmer.modules.losses.GWLossesBayesian.domain_mods": {"tf": 1}}, "df": 5}}}, "a": {"docs": {}, "df": 0, "x": {"docs": {"shimmer.modules.global_workspace.SchedulerArgs.max_lr": {"tf": 1}}, "df": 1}}, "e": {"docs": {}, "df": 0, "t": {"docs": {}, "df": 0, "r": {"docs": {}, "df": 0, "i": {"docs": {}, "df": 0, "c": {"docs": {}, "df": 0, "s": {"docs": {"shimmer.modules.domain.LossOutput.metrics": {"tf": 1}}, "df": 1}}}}}}, "i": {"docs": {}, "df": 0, "g": {"docs": {}, "df": 0, "r": {"docs": {}, "df": 0, "a": {"docs": {}, "df": 0, "t": {"docs": {}, "df": 0, "i": {"docs": {}, "df": 0, "o": {"docs": {}, "df": 0, "n": {"docs": {"shimmer.utils.MIGRATION_DIR": {"tf": 1}, "shimmer.cli.ckpt_migration": {"tf": 1}, "shimmer.cli.ckpt_migration.migrate_ckpt": {"tf": 1}}, "df": 3, "s": {"docs": {"shimmer.utils.SaveMigrations.migrations": {"tf": 1}}, "df": 1}}}}, "e": {"docs": {"shimmer.utils.migrate_model": {"tf": 1}, "shimmer.cli.ckpt_migration.migrate_ckpt": {"tf": 1}}, "df": 2}}}}}}}, "g": {"docs": {}, "df": 0, "l": {"docs": {}, "df": 0, "o": {"docs": {}, "df": 0, "b": {"docs": {}, "df": 0, "a": {"docs": {}, "df": 0, "l": {"docs": 
{"shimmer.modules.global_workspace": {"tf": 1}, "shimmer.modules.global_workspace.SchedulerArgs": {"tf": 1}, "shimmer.modules.global_workspace.SchedulerArgs.max_lr": {"tf": 1}, "shimmer.modules.global_workspace.SchedulerArgs.total_steps": {"tf": 1}, "shimmer.modules.global_workspace.GWPredictionsBase": {"tf": 1}, "shimmer.modules.global_workspace.GWPredictionsBase.states": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBase": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.gw_mod": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.selection_mod": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.loss_mod": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.optim_lr": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.optim_weight_decay": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.scheduler_args": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.domain_mods": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.workspace_dim": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.encode_and_fuse": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.encode": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.fuse": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.decode": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.batch_gw_states": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.encode_domain": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.encode_domains": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.decode_domain": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.decode_domains": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.generic_step": {"tf": 1}, "shimmer.modules.global_workspace.freeze_domain_modules": {"tf": 1}, "shimmer.modules.global_workspace.GWPredictions": 
{"tf": 1}, "shimmer.modules.global_workspace.GWPredictions.demi_cycles": {"tf": 1}, "shimmer.modules.global_workspace.GWPredictions.cycles": {"tf": 1}, "shimmer.modules.global_workspace.GWPredictions.translations": {"tf": 1}, "shimmer.modules.global_workspace.GWPredictions.states": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspace2Domains": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspace2Domains.__init__": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspace2Domains.forward": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspace": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspace.__init__": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspace.forward": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBayesian": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBayesian.__init__": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBayesian.forward": {"tf": 1}, "shimmer.modules.global_workspace.pretrained_global_workspace": {"tf": 1.4142135623730951}}, "df": 41, "w": {"docs": {}, "df": 0, "o": {"docs": {}, "df": 0, "r": {"docs": {}, "df": 0, "k": {"docs": {}, "df": 0, "s": {"docs": {}, "df": 0, "p": {"docs": {}, "df": 0, "a": {"docs": {}, "df": 0, "c": {"docs": {}, "df": 0, "e": {"2": {"docs": {}, "df": 0, "d": {"docs": {}, "df": 0, "o": {"docs": {}, "df": 0, "m": {"docs": {}, "df": 0, "a": {"docs": {}, "df": 0, "i": {"docs": {}, "df": 0, "n": {"docs": {}, "df": 0, "s": {"docs": {"shimmer.modules.global_workspace.GlobalWorkspace2Domains": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspace2Domains.__init__": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspace2Domains.forward": {"tf": 1}}, "df": 3}}}}}}}}, "docs": {"shimmer.modules.global_workspace.GlobalWorkspace": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspace.__init__": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspace.forward": {"tf": 1}}, "df": 3, "b": {"docs": {}, "df": 0, "a": 
{"docs": {}, "df": 0, "s": {"docs": {}, "df": 0, "e": {"docs": {"shimmer.modules.global_workspace.GlobalWorkspaceBase": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.gw_mod": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.selection_mod": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.loss_mod": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.optim_lr": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.optim_weight_decay": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.scheduler_args": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.domain_mods": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.workspace_dim": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.encode_and_fuse": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.encode": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.fuse": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.decode": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.batch_gw_states": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.encode_domain": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.encode_domains": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.decode_domain": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.decode_domains": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.generic_step": {"tf": 1}}, "df": 19}}, "y": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "s": {"docs": {}, "df": 0, "i": {"docs": {}, "df": 0, "a": {"docs": {}, "df": 0, "n": {"docs": {"shimmer.modules.global_workspace.GlobalWorkspaceBayesian": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBayesian.__init__": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBayesian.forward": {"tf": 1}}, "df": 3}}}}}}}}}}}}}}}}}}}}}}, "w": {"docs": 
{"shimmer.modules.global_workspace.GlobalWorkspaceBase.gw_mod": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.batch_gw_states": {"tf": 1}, "shimmer.modules.gw_module": {"tf": 1}, "shimmer.modules.gw_module.get_n_layers": {"tf": 1}, "shimmer.modules.gw_module.GWDecoder": {"tf": 1}, "shimmer.modules.gw_module.GWDecoder.__init__": {"tf": 1}, "shimmer.modules.gw_module.GWDecoder.in_dim": {"tf": 1}, "shimmer.modules.gw_module.GWDecoder.hidden_dim": {"tf": 1}, "shimmer.modules.gw_module.GWDecoder.out_dim": {"tf": 1}, "shimmer.modules.gw_module.GWDecoder.n_layers": {"tf": 1}, "shimmer.modules.gw_module.GWEncoder": {"tf": 1}, "shimmer.modules.gw_module.GWEncoder.__init__": {"tf": 1}, "shimmer.modules.gw_module.GWEncoder.forward": {"tf": 1}, "shimmer.modules.gw_module.GWEncoderLinear": {"tf": 1}, "shimmer.modules.gw_module.GWEncoderLinear.forward": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBase": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBase.__init__": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBase.domain_mods": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBase.workspace_dim": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBase.fuse": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBase.encode": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBase.encode_and_fuse": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBase.decode": {"tf": 1}, "shimmer.modules.gw_module.GWModule": {"tf": 1}, "shimmer.modules.gw_module.GWModule.__init__": {"tf": 1}, "shimmer.modules.gw_module.GWModule.gw_encoders": {"tf": 1.4142135623730951}, "shimmer.modules.gw_module.GWModule.gw_decoders": {"tf": 1.4142135623730951}, "shimmer.modules.gw_module.GWModule.fuse": {"tf": 1}, "shimmer.modules.gw_module.GWModule.encode": {"tf": 1}, "shimmer.modules.gw_module.GWModule.decode": {"tf": 1}, "shimmer.modules.gw_module.compute_fusion_scores": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBayesian": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBayesian.__init__": {"tf": 1}, 
"shimmer.modules.gw_module.GWModuleBayesian.precisions": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBayesian.sensitivity_selection": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBayesian.sensitivity_precision": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBayesian.precision_softmax_temp": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBayesian.get_precision": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBayesian.fuse": {"tf": 1}, "shimmer.modules.selection.SelectionBase.update_gw_state": {"tf": 1}, "shimmer.modules.losses.GWLosses2Domains.gw_mod": {"tf": 1}, "shimmer.modules.losses.GWLosses.gw_mod": {"tf": 1}, "shimmer.modules.losses.GWLossesBayesian.gw_mod": {"tf": 1}}, "df": 43, "p": {"docs": {}, "df": 0, "r": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "d": {"docs": {}, "df": 0, "i": {"docs": {}, "df": 0, "c": {"docs": {}, "df": 0, "t": {"docs": {}, "df": 0, "i": {"docs": {}, "df": 0, "o": {"docs": {}, "df": 0, "n": {"docs": {}, "df": 0, "s": {"docs": {"shimmer.modules.global_workspace.GWPredictions": {"tf": 1}, "shimmer.modules.global_workspace.GWPredictions.demi_cycles": {"tf": 1}, "shimmer.modules.global_workspace.GWPredictions.cycles": {"tf": 1}, "shimmer.modules.global_workspace.GWPredictions.translations": {"tf": 1}, "shimmer.modules.global_workspace.GWPredictions.states": {"tf": 1}}, "df": 5, "b": {"docs": {}, "df": 0, "a": {"docs": {}, "df": 0, "s": {"docs": {}, "df": 0, "e": {"docs": {"shimmer.modules.global_workspace.GWPredictionsBase": {"tf": 1}, "shimmer.modules.global_workspace.GWPredictionsBase.states": {"tf": 1}}, "df": 2}}}}}}}}}}}}}}}, "d": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "c": {"docs": {}, "df": 0, "o": {"docs": {}, "df": 0, "d": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "r": {"docs": {"shimmer.modules.gw_module.GWDecoder": {"tf": 1}, "shimmer.modules.gw_module.GWDecoder.__init__": {"tf": 1}, "shimmer.modules.gw_module.GWDecoder.in_dim": {"tf": 1}, "shimmer.modules.gw_module.GWDecoder.hidden_dim": {"tf": 1}, 
"shimmer.modules.gw_module.GWDecoder.out_dim": {"tf": 1}, "shimmer.modules.gw_module.GWDecoder.n_layers": {"tf": 1}}, "df": 6}}}}}}}, "e": {"docs": {}, "df": 0, "n": {"docs": {}, "df": 0, "c": {"docs": {}, "df": 0, "o": {"docs": {}, "df": 0, "d": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "r": {"docs": {"shimmer.modules.gw_module.GWEncoder": {"tf": 1}, "shimmer.modules.gw_module.GWEncoder.__init__": {"tf": 1}, "shimmer.modules.gw_module.GWEncoder.forward": {"tf": 1}}, "df": 3, "l": {"docs": {}, "df": 0, "i": {"docs": {}, "df": 0, "n": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "a": {"docs": {}, "df": 0, "r": {"docs": {"shimmer.modules.gw_module.GWEncoderLinear": {"tf": 1}, "shimmer.modules.gw_module.GWEncoderLinear.forward": {"tf": 1}}, "df": 2}}}}}}}}}}}}}, "m": {"docs": {}, "df": 0, "o": {"docs": {}, "df": 0, "d": {"docs": {}, "df": 0, "u": {"docs": {}, "df": 0, "l": {"docs": {}, "df": 0, "e": {"docs": {"shimmer.modules.gw_module.GWModule": {"tf": 1}, "shimmer.modules.gw_module.GWModule.__init__": {"tf": 1}, "shimmer.modules.gw_module.GWModule.gw_encoders": {"tf": 1}, "shimmer.modules.gw_module.GWModule.gw_decoders": {"tf": 1}, "shimmer.modules.gw_module.GWModule.fuse": {"tf": 1}, "shimmer.modules.gw_module.GWModule.encode": {"tf": 1}, "shimmer.modules.gw_module.GWModule.decode": {"tf": 1}}, "df": 7, "b": {"docs": {}, "df": 0, "a": {"docs": {}, "df": 0, "s": {"docs": {}, "df": 0, "e": {"docs": {"shimmer.modules.gw_module.GWModuleBase": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBase.__init__": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBase.domain_mods": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBase.workspace_dim": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBase.fuse": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBase.encode": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBase.encode_and_fuse": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBase.decode": {"tf": 1}}, "df": 8}}, "y": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "s": 
{"docs": {}, "df": 0, "i": {"docs": {}, "df": 0, "a": {"docs": {}, "df": 0, "n": {"docs": {"shimmer.modules.gw_module.GWModuleBayesian": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBayesian.__init__": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBayesian.precisions": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBayesian.sensitivity_selection": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBayesian.sensitivity_precision": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBayesian.precision_softmax_temp": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBayesian.get_precision": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBayesian.fuse": {"tf": 1}}, "df": 8}}}}}}}}}}}}}}, "l": {"docs": {}, "df": 0, "o": {"docs": {}, "df": 0, "s": {"docs": {}, "df": 0, "s": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "s": {"2": {"docs": {}, "df": 0, "d": {"docs": {}, "df": 0, "o": {"docs": {}, "df": 0, "m": {"docs": {}, "df": 0, "a": {"docs": {}, "df": 0, "i": {"docs": {}, "df": 0, "n": {"docs": {}, "df": 0, "s": {"docs": {"shimmer.modules.losses.GWLosses2Domains": {"tf": 1}, "shimmer.modules.losses.GWLosses2Domains.__init__": {"tf": 1}, "shimmer.modules.losses.GWLosses2Domains.gw_mod": {"tf": 1}, "shimmer.modules.losses.GWLosses2Domains.selection_mod": {"tf": 1}, "shimmer.modules.losses.GWLosses2Domains.domain_mods": {"tf": 1}, "shimmer.modules.losses.GWLosses2Domains.loss_coefs": {"tf": 1}, "shimmer.modules.losses.GWLosses2Domains.contrastive_fn": {"tf": 1}, "shimmer.modules.losses.GWLosses2Domains.demi_cycle_loss": {"tf": 1}, "shimmer.modules.losses.GWLosses2Domains.cycle_loss": {"tf": 1}, "shimmer.modules.losses.GWLosses2Domains.translation_loss": {"tf": 1}, "shimmer.modules.losses.GWLosses2Domains.contrastive_loss": {"tf": 1}, "shimmer.modules.losses.GWLosses2Domains.step": {"tf": 1}}, "df": 12}}}}}}}}, "docs": {"shimmer.modules.losses.GWLosses": {"tf": 1}, "shimmer.modules.losses.GWLosses.__init__": {"tf": 1}, "shimmer.modules.losses.GWLosses.gw_mod": {"tf": 1}, 
"shimmer.modules.losses.GWLosses.selection_mod": {"tf": 1}, "shimmer.modules.losses.GWLosses.domain_mods": {"tf": 1}, "shimmer.modules.losses.GWLosses.loss_coefs": {"tf": 1}, "shimmer.modules.losses.GWLosses.contrastive_fn": {"tf": 1}, "shimmer.modules.losses.GWLosses.contrastive_loss": {"tf": 1}, "shimmer.modules.losses.GWLosses.broadcast_loss": {"tf": 1}, "shimmer.modules.losses.GWLosses.step": {"tf": 1}}, "df": 10, "b": {"docs": {}, "df": 0, "a": {"docs": {}, "df": 0, "s": {"docs": {}, "df": 0, "e": {"docs": {"shimmer.modules.losses.GWLossesBase": {"tf": 1}, "shimmer.modules.losses.GWLossesBase.step": {"tf": 1}}, "df": 2}}, "y": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "s": {"docs": {}, "df": 0, "i": {"docs": {}, "df": 0, "a": {"docs": {}, "df": 0, "n": {"docs": {"shimmer.modules.losses.GWLossesBayesian": {"tf": 1}, "shimmer.modules.losses.GWLossesBayesian.__init__": {"tf": 1}, "shimmer.modules.losses.GWLossesBayesian.gw_mod": {"tf": 1}, "shimmer.modules.losses.GWLossesBayesian.selection_mod": {"tf": 1}, "shimmer.modules.losses.GWLossesBayesian.domain_mods": {"tf": 1}, "shimmer.modules.losses.GWLossesBayesian.loss_coefs": {"tf": 1}, "shimmer.modules.losses.GWLossesBayesian.contrastive_fn": {"tf": 1}, "shimmer.modules.losses.GWLossesBayesian.use_normalized_constrastive": {"tf": 1}, "shimmer.modules.losses.GWLossesBayesian.contrastive_loss": {"tf": 1}, "shimmer.modules.losses.GWLossesBayesian.broadcast_loss": {"tf": 1}, "shimmer.modules.losses.GWLossesBayesian.step": {"tf": 1}}, "df": 11}}}}}}}}}}}}}}}, "e": {"docs": {}, "df": 0, "n": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "r": {"docs": {}, "df": 0, "i": {"docs": {}, "df": 0, "c": {"docs": {"shimmer.modules.global_workspace.GlobalWorkspaceBase.generic_step": {"tf": 1}}, "df": 1}}, "a": {"docs": {}, "df": 0, "t": {"docs": {}, "df": 0, "e": {"docs": {"shimmer.modules.losses.generate_partitions": {"tf": 1}}, "df": 1}}}}}}, "t": {"docs": {"shimmer.modules.gw_module.get_n_layers": {"tf": 1}, 
"shimmer.modules.gw_module.GWModuleBayesian.get_precision": {"tf": 1}}, "df": 2}}, "a": {"docs": {}, "df": 0, "u": {"docs": {}, "df": 0, "s": {"docs": {}, "df": 0, "s": {"docs": {}, "df": 0, "i": {"docs": {}, "df": 0, "a": {"docs": {}, "df": 0, "n": {"docs": {"shimmer.modules.vae.gaussian_nll": {"tf": 1}}, "df": 1}}}}}}}, "r": {"docs": {}, "df": 0, "o": {"docs": {}, "df": 0, "u": {"docs": {}, "df": 0, "p": {"docs": {"shimmer.utils.group_batch_size": {"tf": 1}, "shimmer.utils.group_device": {"tf": 1}}, "df": 2, "s": {"docs": {"shimmer.utils.groups_batch_size": {"tf": 1}, "shimmer.utils.groups_device": {"tf": 1}}, "df": 2}}}}}}, "w": {"docs": {}, "df": 0, "o": {"docs": {}, "df": 0, "r": {"docs": {}, "df": 0, "k": {"docs": {}, "df": 0, "s": {"docs": {}, "df": 0, "p": {"docs": {}, "df": 0, "a": {"docs": {}, "df": 0, "c": {"docs": {}, "df": 0, "e": {"docs": {"shimmer.modules.global_workspace": {"tf": 1}, "shimmer.modules.global_workspace.SchedulerArgs": {"tf": 1}, "shimmer.modules.global_workspace.SchedulerArgs.max_lr": {"tf": 1}, "shimmer.modules.global_workspace.SchedulerArgs.total_steps": {"tf": 1}, "shimmer.modules.global_workspace.GWPredictionsBase": {"tf": 1}, "shimmer.modules.global_workspace.GWPredictionsBase.states": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBase": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.gw_mod": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.selection_mod": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.loss_mod": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.optim_lr": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.optim_weight_decay": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.scheduler_args": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.domain_mods": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.workspace_dim": {"tf": 1.4142135623730951}, 
"shimmer.modules.global_workspace.GlobalWorkspaceBase.encode_and_fuse": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.encode": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.fuse": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.decode": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.batch_gw_states": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.encode_domain": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.encode_domains": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.decode_domain": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.decode_domains": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.generic_step": {"tf": 1}, "shimmer.modules.global_workspace.freeze_domain_modules": {"tf": 1}, "shimmer.modules.global_workspace.GWPredictions": {"tf": 1}, "shimmer.modules.global_workspace.GWPredictions.demi_cycles": {"tf": 1}, "shimmer.modules.global_workspace.GWPredictions.cycles": {"tf": 1}, "shimmer.modules.global_workspace.GWPredictions.translations": {"tf": 1}, "shimmer.modules.global_workspace.GWPredictions.states": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspace2Domains": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspace2Domains.__init__": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspace2Domains.forward": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspace": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspace.__init__": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspace.forward": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBayesian": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBayesian.__init__": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBayesian.forward": {"tf": 1}, "shimmer.modules.global_workspace.pretrained_global_workspace": {"tf": 1.4142135623730951}, 
"shimmer.modules.gw_module.GWModuleBase.workspace_dim": {"tf": 1}}, "df": 42}}}}}}}}, "e": {"docs": {}, "df": 0, "i": {"docs": {}, "df": 0, "g": {"docs": {}, "df": 0, "h": {"docs": {}, "df": 0, "t": {"docs": {"shimmer.modules.global_workspace.GlobalWorkspaceBase.optim_weight_decay": {"tf": 1}}, "df": 1, "e": {"docs": {}, "df": 0, "d": {"docs": {"shimmer.modules.selection.DynamicQueryAttention.fuse_weighted_encodings": {"tf": 1}}, "df": 1}}}}}}}}, "o": {"docs": {}, "df": 0, "p": {"docs": {}, "df": 0, "t": {"docs": {}, "df": 0, "i": {"docs": {}, "df": 0, "m": {"docs": {"shimmer.modules.global_workspace.GlobalWorkspaceBase.optim_lr": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.optim_weight_decay": {"tf": 1}}, "df": 2}}}}, "u": {"docs": {}, "df": 0, "t": {"docs": {"shimmer.modules.gw_module.GWDecoder.out_dim": {"tf": 1}}, "df": 1}}, "n": {"docs": {"shimmer.utils.SaveMigrations.on_save_checkpoint": {"tf": 1}}, "df": 1}}, "d": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "c": {"docs": {}, "df": 0, "a": {"docs": {}, "df": 0, "y": {"docs": {"shimmer.modules.global_workspace.GlobalWorkspaceBase.optim_weight_decay": {"tf": 1}}, "df": 1}}, "o": {"docs": {}, "df": 0, "d": {"docs": {}, "df": 0, "e": {"docs": {"shimmer.modules.global_workspace.GlobalWorkspaceBase.decode": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.decode_domain": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.decode_domains": {"tf": 1}, "shimmer.modules.domain.DomainModule.decode": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBase.decode": {"tf": 1}, "shimmer.modules.gw_module.GWModule.decode": {"tf": 1}, "shimmer.modules.vae.VAE.decode": {"tf": 1}}, "df": 7, "r": {"docs": {"shimmer.modules.vae.VAE.decoder": {"tf": 1}}, "df": 1, "s": {"docs": {"shimmer.modules.gw_module.GWModule.gw_decoders": {"tf": 1}}, "df": 1}}}}}}, "m": {"docs": {}, "df": 0, "i": {"docs": {"shimmer.modules.global_workspace.GWPredictions.demi_cycles": {"tf": 1}, 
"shimmer.modules.losses.demi_cycle_loss": {"tf": 1}, "shimmer.modules.losses.LossCoefs.demi_cycles": {"tf": 1}, "shimmer.modules.losses.GWLosses2Domains.demi_cycle_loss": {"tf": 1}, "shimmer.modules.losses.BroadcastLossCoefs.demi_cycles": {"tf": 1}, "shimmer.modules.utils.batch_demi_cycles": {"tf": 1}}, "df": 6}}, "v": {"docs": {}, "df": 0, "i": {"docs": {}, "df": 0, "c": {"docs": {}, "df": 0, "e": {"docs": {"shimmer.utils.groups_device": {"tf": 1}, "shimmer.utils.group_device": {"tf": 1}}, "df": 2}}}}}, "o": {"docs": {}, "df": 0, "m": {"docs": {}, "df": 0, "a": {"docs": {}, "df": 0, "i": {"docs": {}, "df": 0, "n": {"docs": {"shimmer.modules.global_workspace.GlobalWorkspaceBase.domain_mods": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.encode_domain": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.decode_domain": {"tf": 1}, "shimmer.modules.global_workspace.freeze_domain_modules": {"tf": 1}, "shimmer.modules.domain": {"tf": 1}, "shimmer.modules.domain.LossOutput": {"tf": 1}, "shimmer.modules.domain.LossOutput.__init__": {"tf": 1}, "shimmer.modules.domain.LossOutput.loss": {"tf": 1}, "shimmer.modules.domain.LossOutput.metrics": {"tf": 1}, "shimmer.modules.domain.LossOutput.all": {"tf": 1}, "shimmer.modules.domain.DomainModule": {"tf": 1}, "shimmer.modules.domain.DomainModule.__init__": {"tf": 1}, "shimmer.modules.domain.DomainModule.latent_dim": {"tf": 1}, "shimmer.modules.domain.DomainModule.encode": {"tf": 1}, "shimmer.modules.domain.DomainModule.decode": {"tf": 1}, "shimmer.modules.domain.DomainModule.compute_loss": {"tf": 1}, "shimmer.modules.domain.DomainModule.compute_dcy_loss": {"tf": 1}, "shimmer.modules.domain.DomainModule.compute_cy_loss": {"tf": 1}, "shimmer.modules.domain.DomainModule.compute_tr_loss": {"tf": 1}, "shimmer.modules.domain.DomainModule.compute_broadcast_loss": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBase.domain_mods": {"tf": 1}, "shimmer.modules.losses.GWLosses2Domains.domain_mods": {"tf": 1}, 
"shimmer.modules.losses.GWLosses.domain_mods": {"tf": 1}, "shimmer.modules.losses.GWLossesBayesian.domain_mods": {"tf": 1}}, "df": 24, "s": {"docs": {"shimmer.modules.global_workspace.GlobalWorkspaceBase.encode_domains": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.decode_domains": {"tf": 1}}, "df": 2}, "m": {"docs": {}, "df": 0, "o": {"docs": {}, "df": 0, "d": {"docs": {}, "df": 0, "u": {"docs": {}, "df": 0, "l": {"docs": {}, "df": 0, "e": {"docs": {"shimmer.modules.domain.DomainModule": {"tf": 1}, "shimmer.modules.domain.DomainModule.__init__": {"tf": 1}, "shimmer.modules.domain.DomainModule.latent_dim": {"tf": 1}, "shimmer.modules.domain.DomainModule.encode": {"tf": 1}, "shimmer.modules.domain.DomainModule.decode": {"tf": 1}, "shimmer.modules.domain.DomainModule.compute_loss": {"tf": 1}, "shimmer.modules.domain.DomainModule.compute_dcy_loss": {"tf": 1}, "shimmer.modules.domain.DomainModule.compute_cy_loss": {"tf": 1}, "shimmer.modules.domain.DomainModule.compute_tr_loss": {"tf": 1}, "shimmer.modules.domain.DomainModule.compute_broadcast_loss": {"tf": 1}}, "df": 10}}}}}}}}}}}, "i": {"docs": {}, "df": 0, "m": {"docs": {"shimmer.modules.global_workspace.GlobalWorkspaceBase.workspace_dim": {"tf": 1}, "shimmer.modules.domain.DomainModule.latent_dim": {"tf": 1}, "shimmer.modules.gw_module.GWDecoder.in_dim": {"tf": 1}, "shimmer.modules.gw_module.GWDecoder.hidden_dim": {"tf": 1}, "shimmer.modules.gw_module.GWDecoder.out_dim": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBase.workspace_dim": {"tf": 1}}, "df": 6}, "v": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "r": {"docs": {}, "df": 0, "g": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "n": {"docs": {}, "df": 0, "c": {"docs": {}, "df": 0, "e": {"docs": {"shimmer.modules.vae.kl_divergence_loss": {"tf": 1}}, "df": 1}}}}}}}}, "r": {"docs": {"shimmer.utils.MIGRATION_DIR": {"tf": 1}}, "df": 1}}, "c": {"docs": {}, "df": 0, "y": {"docs": {"shimmer.modules.domain.DomainModule.compute_dcy_loss": 
{"tf": 1}}, "df": 1}}, "y": {"docs": {}, "df": 0, "n": {"docs": {}, "df": 0, "a": {"docs": {}, "df": 0, "m": {"docs": {}, "df": 0, "i": {"docs": {}, "df": 0, "c": {"docs": {}, "df": 0, "q": {"docs": {}, "df": 0, "u": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "r": {"docs": {}, "df": 0, "y": {"docs": {}, "df": 0, "a": {"docs": {}, "df": 0, "t": {"docs": {}, "df": 0, "t": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "n": {"docs": {}, "df": 0, "t": {"docs": {}, "df": 0, "i": {"docs": {}, "df": 0, "o": {"docs": {}, "df": 0, "n": {"docs": {"shimmer.modules.selection.DynamicQueryAttention": {"tf": 1}, "shimmer.modules.selection.DynamicQueryAttention.__init__": {"tf": 1}, "shimmer.modules.selection.DynamicQueryAttention.head_size": {"tf": 1}, "shimmer.modules.selection.DynamicQueryAttention.query_layer": {"tf": 1}, "shimmer.modules.selection.DynamicQueryAttention.key_layers": {"tf": 1}, "shimmer.modules.selection.DynamicQueryAttention.fuse_weighted_encodings": {"tf": 1}, "shimmer.modules.selection.DynamicQueryAttention.forward": {"tf": 1}}, "df": 7}}}}}}}}}}}}}}}}}}}}, "a": {"docs": {}, "df": 0, "t": {"docs": {}, "df": 0, "a": {"docs": {}, "df": 0, "s": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "t": {"docs": {"shimmer.dataset": {"tf": 1}, "shimmer.dataset.RepeatedDataset": {"tf": 1}, "shimmer.dataset.RepeatedDataset.__init__": {"tf": 1}, "shimmer.dataset.RepeatedDataset.dataset": {"tf": 1.4142135623730951}, "shimmer.dataset.RepeatedDataset.dataset_size": {"tf": 1.4142135623730951}}, "df": 5}}}}}}}, "a": {"docs": {}, "df": 0, "r": {"docs": {}, "df": 0, "g": {"docs": {}, "df": 0, "s": {"docs": {"shimmer.modules.global_workspace.GlobalWorkspaceBase.scheduler_args": {"tf": 1}}, "df": 1}}}, "n": {"docs": {}, "df": 0, "d": {"docs": {"shimmer.modules.global_workspace.GlobalWorkspaceBase.encode_and_fuse": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBase.encode_and_fuse": {"tf": 1}}, "df": 2}}, "l": {"docs": {}, "df": 0, "l": {"docs": 
{"shimmer.modules.domain.LossOutput.all": {"tf": 1}}, "df": 1}}}, "e": {"docs": {}, "df": 0, "n": {"docs": {}, "df": 0, "c": {"docs": {}, "df": 0, "o": {"docs": {}, "df": 0, "d": {"docs": {}, "df": 0, "e": {"docs": {"shimmer.modules.global_workspace.GlobalWorkspaceBase.encode_and_fuse": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.encode": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.encode_domain": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.encode_domains": {"tf": 1}, "shimmer.modules.domain.DomainModule.encode": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBase.encode": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBase.encode_and_fuse": {"tf": 1}, "shimmer.modules.gw_module.GWModule.encode": {"tf": 1}, "shimmer.modules.vae.VAE.encode": {"tf": 1}}, "df": 9, "r": {"docs": {"shimmer.modules.vae.VAE.encoder": {"tf": 1}}, "df": 1, "s": {"docs": {"shimmer.modules.gw_module.GWModule.gw_encoders": {"tf": 1}}, "df": 1}}}, "i": {"docs": {}, "df": 0, "n": {"docs": {}, "df": 0, "g": {"docs": {}, "df": 0, "s": {"docs": {"shimmer.modules.selection.DynamicQueryAttention.fuse_weighted_encodings": {"tf": 1}}, "df": 1}}}}}}}}}, "f": {"docs": {}, "df": 0, "u": {"docs": {}, "df": 0, "s": {"docs": {}, "df": 0, "e": {"docs": {"shimmer.modules.global_workspace.GlobalWorkspaceBase.encode_and_fuse": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.fuse": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBase.fuse": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBase.encode_and_fuse": {"tf": 1}, "shimmer.modules.gw_module.GWModule.fuse": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBayesian.fuse": {"tf": 1}, "shimmer.modules.selection.DynamicQueryAttention.fuse_weighted_encodings": {"tf": 1}}, "df": 7, "d": {"docs": {"shimmer.modules.losses.BroadcastLossCoefs.fused": {"tf": 1}}, "df": 1}}, "i": {"docs": {}, "df": 0, "o": {"docs": {}, "df": 0, "n": {"docs": {"shimmer.modules.gw_module.compute_fusion_scores": {"tf": 
1}}, "df": 1}}}}}, "r": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "z": {"docs": {}, "df": 0, "e": {"docs": {"shimmer.modules.global_workspace.freeze_domain_modules": {"tf": 1}}, "df": 1}}}}}, "o": {"docs": {}, "df": 0, "r": {"docs": {}, "df": 0, "w": {"docs": {}, "df": 0, "a": {"docs": {}, "df": 0, "r": {"docs": {}, "df": 0, "d": {"docs": {"shimmer.modules.global_workspace.GlobalWorkspace2Domains.forward": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspace.forward": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBayesian.forward": {"tf": 1}, "shimmer.modules.gw_module.GWEncoder.forward": {"tf": 1}, "shimmer.modules.gw_module.GWEncoderLinear.forward": {"tf": 1}, "shimmer.modules.selection.SelectionBase.forward": {"tf": 1}, "shimmer.modules.selection.SingleDomainSelection.forward": {"tf": 1}, "shimmer.modules.selection.FixedSharedSelection.forward": {"tf": 1}, "shimmer.modules.selection.KQFixedQSelection.forward": {"tf": 1}, "shimmer.modules.selection.RandomSelection.forward": {"tf": 1}, "shimmer.modules.selection.DynamicQueryAttention.forward": {"tf": 1}, "shimmer.modules.contrastive_loss.ContrastiveLoss.forward": {"tf": 1}, "shimmer.modules.vae.VAEEncoder.forward": {"tf": 1}, "shimmer.modules.vae.VAEDecoder.forward": {"tf": 1}, "shimmer.modules.vae.VAE.forward": {"tf": 1}}, "df": 15}}}}}}, "i": {"docs": {}, "df": 0, "x": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "d": {"docs": {}, "df": 0, "s": {"docs": {}, "df": 0, "h": {"docs": {}, "df": 0, "a": {"docs": {}, "df": 0, "r": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "d": {"docs": {}, "df": 0, "s": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "l": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "c": {"docs": {}, "df": 0, "t": {"docs": {}, "df": 0, "i": {"docs": {}, "df": 0, "o": {"docs": {}, "df": 0, "n": {"docs": {"shimmer.modules.selection.FixedSharedSelection": {"tf": 1}, "shimmer.modules.selection.FixedSharedSelection.forward": {"tf": 1}}, 
"df": 2}}}}}}}}}}}}}}}}}}}, "n": {"docs": {"shimmer.modules.losses.GWLosses2Domains.contrastive_fn": {"tf": 1}, "shimmer.modules.losses.GWLosses.contrastive_fn": {"tf": 1}, "shimmer.modules.losses.GWLossesBayesian.contrastive_fn": {"tf": 1}}, "df": 3}}, "b": {"docs": {}, "df": 0, "a": {"docs": {}, "df": 0, "t": {"docs": {}, "df": 0, "c": {"docs": {}, "df": 0, "h": {"docs": {"shimmer.modules.global_workspace.GlobalWorkspaceBase.batch_gw_states": {"tf": 1}, "shimmer.modules.utils.batch_demi_cycles": {"tf": 1}, "shimmer.modules.utils.batch_cycles": {"tf": 1}, "shimmer.modules.utils.batch_translations": {"tf": 1}, "shimmer.utils.group_batch_size": {"tf": 1}, "shimmer.utils.groups_batch_size": {"tf": 1}}, "df": 6}}}, "y": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "s": {"docs": {}, "df": 0, "i": {"docs": {}, "df": 0, "a": {"docs": {}, "df": 0, "n": {"docs": {"shimmer.modules.losses.contrastive_loss_bayesian": {"tf": 1}}, "df": 1}}}}}}}, "r": {"docs": {}, "df": 0, "o": {"docs": {}, "df": 0, "a": {"docs": {}, "df": 0, "d": {"docs": {}, "df": 0, "c": {"docs": {}, "df": 0, "a": {"docs": {}, "df": 0, "s": {"docs": {}, "df": 0, "t": {"docs": {"shimmer.modules.domain.DomainModule.compute_broadcast_loss": {"tf": 1}, "shimmer.modules.losses.broadcast_loss": {"tf": 1}, "shimmer.modules.losses.GWLosses.broadcast_loss": {"tf": 1}, "shimmer.modules.losses.GWLossesBayesian.broadcast_loss": {"tf": 1}}, "df": 4, "l": {"docs": {}, "df": 0, "o": {"docs": {}, "df": 0, "s": {"docs": {}, "df": 0, "s": {"docs": {}, "df": 0, "c": {"docs": {}, "df": 0, "o": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "f": {"docs": {}, "df": 0, "s": {"docs": {"shimmer.modules.losses.BroadcastLossCoefs": {"tf": 1}, "shimmer.modules.losses.BroadcastLossCoefs.contrastives": {"tf": 1}, "shimmer.modules.losses.BroadcastLossCoefs.fused": {"tf": 1}, "shimmer.modules.losses.BroadcastLossCoefs.demi_cycles": {"tf": 1}, "shimmer.modules.losses.BroadcastLossCoefs.cycles": {"tf": 1}, 
"shimmer.modules.losses.BroadcastLossCoefs.translations": {"tf": 1}}, "df": 6}}}}}}}}}}}}}}}}}, "e": {"docs": {}, "df": 0, "t": {"docs": {}, "df": 0, "a": {"docs": {"shimmer.modules.vae.VAE.beta": {"tf": 1}}, "df": 1}}}}, "c": {"docs": {}, "df": 0, "y": {"docs": {"shimmer.modules.domain.DomainModule.compute_cy_loss": {"tf": 1}}, "df": 1, "c": {"docs": {}, "df": 0, "l": {"docs": {}, "df": 0, "e": {"docs": {"shimmer.modules.losses.demi_cycle_loss": {"tf": 1}, "shimmer.modules.losses.cycle_loss": {"tf": 1}, "shimmer.modules.losses.GWLosses2Domains.demi_cycle_loss": {"tf": 1}, "shimmer.modules.losses.GWLosses2Domains.cycle_loss": {"tf": 1}, "shimmer.modules.utils.cycle": {"tf": 1}}, "df": 5, "s": {"docs": {"shimmer.modules.global_workspace.GWPredictions.demi_cycles": {"tf": 1}, "shimmer.modules.global_workspace.GWPredictions.cycles": {"tf": 1}, "shimmer.modules.losses.LossCoefs.demi_cycles": {"tf": 1}, "shimmer.modules.losses.LossCoefs.cycles": {"tf": 1}, "shimmer.modules.losses.BroadcastLossCoefs.demi_cycles": {"tf": 1}, "shimmer.modules.losses.BroadcastLossCoefs.cycles": {"tf": 1}, "shimmer.modules.utils.batch_demi_cycles": {"tf": 1}, "shimmer.modules.utils.batch_cycles": {"tf": 1}}, "df": 8}}}}}, "o": {"docs": {}, "df": 0, "m": {"docs": {}, "df": 0, "p": {"docs": {}, "df": 0, "u": {"docs": {}, "df": 0, "t": {"docs": {}, "df": 0, "e": {"docs": {"shimmer.modules.domain.DomainModule.compute_loss": {"tf": 1}, "shimmer.modules.domain.DomainModule.compute_dcy_loss": {"tf": 1}, "shimmer.modules.domain.DomainModule.compute_cy_loss": {"tf": 1}, "shimmer.modules.domain.DomainModule.compute_tr_loss": {"tf": 1}, "shimmer.modules.domain.DomainModule.compute_broadcast_loss": {"tf": 1}, "shimmer.modules.gw_module.compute_fusion_scores": {"tf": 1}}, "df": 6}}}}}, "n": {"docs": {}, "df": 0, "t": {"docs": {}, "df": 0, "r": {"docs": {}, "df": 0, "a": {"docs": {}, "df": 0, "s": {"docs": {}, "df": 0, "t": {"docs": {}, "df": 0, "i": {"docs": {}, "df": 0, "v": {"docs": {}, "df": 0, "e": 
{"docs": {"shimmer.modules.losses.contrastive_loss": {"tf": 1}, "shimmer.modules.losses.contrastive_loss_bayesian": {"tf": 1}, "shimmer.modules.losses.GWLosses2Domains.contrastive_fn": {"tf": 1}, "shimmer.modules.losses.GWLosses2Domains.contrastive_loss": {"tf": 1}, "shimmer.modules.losses.GWLosses.contrastive_fn": {"tf": 1}, "shimmer.modules.losses.GWLosses.contrastive_loss": {"tf": 1}, "shimmer.modules.losses.GWLossesBayesian.contrastive_fn": {"tf": 1}, "shimmer.modules.losses.GWLossesBayesian.contrastive_loss": {"tf": 1}, "shimmer.modules.contrastive_loss": {"tf": 1}, "shimmer.modules.contrastive_loss.ContrastiveLossType": {"tf": 1}, "shimmer.modules.contrastive_loss.ContrastiveLossBayesianType": {"tf": 1}, "shimmer.modules.contrastive_loss.info_nce": {"tf": 1}, "shimmer.modules.contrastive_loss.contrastive_loss": {"tf": 1.4142135623730951}, "shimmer.modules.contrastive_loss.ContrastiveLoss": {"tf": 1}, "shimmer.modules.contrastive_loss.ContrastiveLoss.__init__": {"tf": 1}, "shimmer.modules.contrastive_loss.ContrastiveLoss.learn_logit_scale": {"tf": 1}, "shimmer.modules.contrastive_loss.ContrastiveLoss.reduction": {"tf": 1}, "shimmer.modules.contrastive_loss.ContrastiveLoss.forward": {"tf": 1}}, "df": 18, "s": {"docs": {"shimmer.modules.losses.LossCoefs.contrastives": {"tf": 1}, "shimmer.modules.losses.BroadcastLossCoefs.contrastives": {"tf": 1}}, "df": 2}, "l": {"docs": {}, "df": 0, "o": {"docs": {}, "df": 0, "s": {"docs": {}, "df": 0, "s": {"docs": {"shimmer.modules.contrastive_loss.ContrastiveLoss": {"tf": 1}, "shimmer.modules.contrastive_loss.ContrastiveLoss.__init__": {"tf": 1}, "shimmer.modules.contrastive_loss.ContrastiveLoss.learn_logit_scale": {"tf": 1}, "shimmer.modules.contrastive_loss.ContrastiveLoss.reduction": {"tf": 1}, "shimmer.modules.contrastive_loss.ContrastiveLoss.forward": {"tf": 1}}, "df": 5, "t": {"docs": {}, "df": 0, "y": {"docs": {}, "df": 0, "p": {"docs": {}, "df": 0, "e": {"docs": 
{"shimmer.modules.contrastive_loss.ContrastiveLossType": {"tf": 1}}, "df": 1}}}}, "b": {"docs": {}, "df": 0, "a": {"docs": {}, "df": 0, "y": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "s": {"docs": {}, "df": 0, "i": {"docs": {}, "df": 0, "a": {"docs": {}, "df": 0, "n": {"docs": {}, "df": 0, "t": {"docs": {}, "df": 0, "y": {"docs": {}, "df": 0, "p": {"docs": {}, "df": 0, "e": {"docs": {"shimmer.modules.contrastive_loss.ContrastiveLossBayesianType": {"tf": 1}}, "df": 1}}}}}}}}}}}}}}}}}}}}}}}}, "s": {"docs": {}, "df": 0, "t": {"docs": {}, "df": 0, "r": {"docs": {}, "df": 0, "a": {"docs": {}, "df": 0, "s": {"docs": {}, "df": 0, "t": {"docs": {}, "df": 0, "i": {"docs": {}, "df": 0, "v": {"docs": {}, "df": 0, "e": {"docs": {"shimmer.modules.losses.GWLossesBayesian.use_normalized_constrastive": {"tf": 1}}, "df": 1}}}}}}}}}}, "e": {"docs": {}, "df": 0, "f": {"docs": {}, "df": 0, "s": {"docs": {"shimmer.modules.losses.GWLosses2Domains.loss_coefs": {"tf": 1}, "shimmer.modules.losses.GWLosses.loss_coefs": {"tf": 1}, "shimmer.modules.losses.GWLossesBayesian.loss_coefs": {"tf": 1}}, "df": 3}}}}, "h": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "c": {"docs": {}, "df": 0, "k": {"docs": {}, "df": 0, "p": {"docs": {}, "df": 0, "o": {"docs": {}, "df": 0, "i": {"docs": {}, "df": 0, "n": {"docs": {}, "df": 0, "t": {"docs": {"shimmer.utils.SaveMigrations.on_save_checkpoint": {"tf": 1}}, "df": 1}}}}}}}}}, "l": {"docs": {}, "df": 0, "i": {"docs": {"shimmer.cli.ckpt_migration": {"tf": 1}, "shimmer.cli.ckpt_migration.migrate_ckpt": {"tf": 1}}, "df": 2}}, "k": {"docs": {}, "df": 0, "p": {"docs": {}, "df": 0, "t": {"docs": {"shimmer.cli.ckpt_migration": {"tf": 1}, "shimmer.cli.ckpt_migration.migrate_ckpt": {"tf": 1.4142135623730951}}, "df": 2}}}}, "i": {"docs": {}, "df": 0, "n": {"docs": {"shimmer.modules.gw_module.GWDecoder.in_dim": {"tf": 1}}, "df": 1, "i": {"docs": {}, "df": 0, "t": {"docs": {"shimmer.modules.global_workspace.GlobalWorkspace2Domains.__init__": {"tf": 1}, 
"shimmer.modules.global_workspace.GlobalWorkspace.__init__": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBayesian.__init__": {"tf": 1}, "shimmer.modules.domain.LossOutput.__init__": {"tf": 1}, "shimmer.modules.domain.DomainModule.__init__": {"tf": 1}, "shimmer.modules.gw_module.GWDecoder.__init__": {"tf": 1}, "shimmer.modules.gw_module.GWEncoder.__init__": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBase.__init__": {"tf": 1}, "shimmer.modules.gw_module.GWModule.__init__": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBayesian.__init__": {"tf": 1}, "shimmer.modules.selection.KQFixedQSelection.__init__": {"tf": 1}, "shimmer.modules.selection.RandomSelection.__init__": {"tf": 1}, "shimmer.modules.selection.DynamicQueryAttention.__init__": {"tf": 1}, "shimmer.modules.losses.GWLosses2Domains.__init__": {"tf": 1}, "shimmer.modules.losses.GWLosses.__init__": {"tf": 1}, "shimmer.modules.losses.GWLossesBayesian.__init__": {"tf": 1}, "shimmer.modules.contrastive_loss.ContrastiveLoss.__init__": {"tf": 1}, "shimmer.dataset.RepeatedDataset.__init__": {"tf": 1}, "shimmer.modules.vae.VAE.__init__": {"tf": 1}}, "df": 19}}, "f": {"docs": {}, "df": 0, "o": {"docs": {"shimmer.modules.contrastive_loss.info_nce": {"tf": 1}}, "df": 1}}}}, "p": {"docs": {}, "df": 0, "r": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "t": {"docs": {}, "df": 0, "r": {"docs": {}, "df": 0, "a": {"docs": {}, "df": 0, "i": {"docs": {}, "df": 0, "n": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "d": {"docs": {"shimmer.modules.global_workspace.pretrained_global_workspace": {"tf": 1}}, "df": 1}}}}}}}, "c": {"docs": {}, "df": 0, "i": {"docs": {}, "df": 0, "s": {"docs": {}, "df": 0, "i": {"docs": {}, "df": 0, "o": {"docs": {}, "df": 0, "n": {"docs": {"shimmer.modules.gw_module.GWModuleBayesian.sensitivity_precision": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBayesian.precision_softmax_temp": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBayesian.get_precision": {"tf": 1}}, "df": 3, "s": 
{"docs": {"shimmer.modules.gw_module.GWModuleBayesian.precisions": {"tf": 1}}, "df": 1}}}}}}}}}, "a": {"docs": {}, "df": 0, "r": {"docs": {}, "df": 0, "t": {"docs": {}, "df": 0, "i": {"docs": {}, "df": 0, "t": {"docs": {}, "df": 0, "i": {"docs": {}, "df": 0, "o": {"docs": {}, "df": 0, "n": {"docs": {}, "df": 0, "s": {"docs": {"shimmer.modules.losses.generate_partitions": {"tf": 1}}, "df": 1}}}}}}}}}}, "n": {"docs": {"shimmer.modules.gw_module.get_n_layers": {"tf": 1}, "shimmer.modules.gw_module.GWDecoder.n_layers": {"tf": 1}}, "df": 2, "o": {"docs": {}, "df": 0, "r": {"docs": {}, "df": 0, "m": {"docs": {}, "df": 0, "a": {"docs": {}, "df": 0, "l": {"docs": {}, "df": 0, "i": {"docs": {}, "df": 0, "z": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "d": {"docs": {"shimmer.modules.losses.GWLossesBayesian.use_normalized_constrastive": {"tf": 1}}, "df": 1}}}}}}}}}, "c": {"docs": {}, "df": 0, "e": {"docs": {"shimmer.modules.contrastive_loss.info_nce": {"tf": 1}}, "df": 1}}, "l": {"docs": {}, "df": 0, "l": {"docs": {"shimmer.modules.vae.gaussian_nll": {"tf": 1}}, "df": 1}}}, "h": {"docs": {}, "df": 0, "i": {"docs": {}, "df": 0, "d": {"docs": {}, "df": 0, "d": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "n": {"docs": {"shimmer.modules.gw_module.GWDecoder.hidden_dim": {"tf": 1}}, "df": 1}}}}}, "e": {"docs": {}, "df": 0, "a": {"docs": {}, "df": 0, "d": {"docs": {"shimmer.modules.selection.KQFixedQSelection.head_size": {"tf": 1}, "shimmer.modules.selection.DynamicQueryAttention.head_size": {"tf": 1}}, "df": 2}}}}, "u": {"docs": {}, "df": 0, "p": {"docs": {}, "df": 0, "d": {"docs": {}, "df": 0, "a": {"docs": {}, "df": 0, "t": {"docs": {}, "df": 0, "e": {"docs": {"shimmer.modules.selection.SelectionBase.update_gw_state": {"tf": 1}}, "df": 1}}}}}, "s": {"docs": {}, "df": 0, "e": {"docs": {"shimmer.modules.losses.GWLossesBayesian.use_normalized_constrastive": {"tf": 1}}, "df": 1}}, "t": {"docs": {}, "df": 0, "i": {"docs": {}, "df": 0, "l": {"docs": {}, "df": 0, "s": 
{"docs": {"shimmer.modules.utils": {"tf": 1}, "shimmer.modules.utils.translation": {"tf": 1}, "shimmer.modules.utils.cycle": {"tf": 1}, "shimmer.modules.utils.batch_demi_cycles": {"tf": 1}, "shimmer.modules.utils.batch_cycles": {"tf": 1}, "shimmer.modules.utils.batch_translations": {"tf": 1}, "shimmer.utils": {"tf": 1}, "shimmer.utils.MIGRATION_DIR": {"tf": 1}, "shimmer.utils.group_batch_size": {"tf": 1}, "shimmer.utils.groups_batch_size": {"tf": 1}, "shimmer.utils.groups_device": {"tf": 1}, "shimmer.utils.group_device": {"tf": 1}, "shimmer.utils.migrate_model": {"tf": 1}, "shimmer.utils.SaveMigrations": {"tf": 1}, "shimmer.utils.SaveMigrations.migrations": {"tf": 1}, "shimmer.utils.SaveMigrations.on_save_checkpoint": {"tf": 1}}, "df": 16}}}}}, "k": {"docs": {}, "df": 0, "q": {"docs": {}, "df": 0, "f": {"docs": {}, "df": 0, "i": {"docs": {}, "df": 0, "x": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "d": {"docs": {}, "df": 0, "q": {"docs": {}, "df": 0, "s": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "l": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "c": {"docs": {}, "df": 0, "t": {"docs": {}, "df": 0, "i": {"docs": {}, "df": 0, "o": {"docs": {}, "df": 0, "n": {"docs": {"shimmer.modules.selection.KQFixedQSelection": {"tf": 1}, "shimmer.modules.selection.KQFixedQSelection.__init__": {"tf": 1}, "shimmer.modules.selection.KQFixedQSelection.head_size": {"tf": 1}, "shimmer.modules.selection.KQFixedQSelection.query_layer": {"tf": 1}, "shimmer.modules.selection.KQFixedQSelection.key_layers": {"tf": 1}, "shimmer.modules.selection.KQFixedQSelection.forward": {"tf": 1}}, "df": 6}}}}}}}}}}}}}}}}, "e": {"docs": {}, "df": 0, "y": {"docs": {"shimmer.modules.selection.KQFixedQSelection.key_layers": {"tf": 1}, "shimmer.modules.selection.DynamicQueryAttention.key_layers": {"tf": 1}}, "df": 2}}, "l": {"docs": {"shimmer.modules.vae.kl_divergence_loss": {"tf": 1}}, "df": 1}}, "q": {"docs": {}, "df": 0, "u": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "r": {"docs": 
{}, "df": 0, "y": {"docs": {"shimmer.modules.selection.KQFixedQSelection.query_layer": {"tf": 1}, "shimmer.modules.selection.DynamicQueryAttention.query_layer": {"tf": 1}}, "df": 2}}}}}, "v": {"docs": {}, "df": 0, "a": {"docs": {}, "df": 0, "e": {"docs": {"shimmer.modules.vae": {"tf": 1}, "shimmer.modules.vae.reparameterize": {"tf": 1}, "shimmer.modules.vae.kl_divergence_loss": {"tf": 1}, "shimmer.modules.vae.gaussian_nll": {"tf": 1}, "shimmer.modules.vae.VAEEncoder": {"tf": 1}, "shimmer.modules.vae.VAEEncoder.forward": {"tf": 1}, "shimmer.modules.vae.VAEDecoder": {"tf": 1}, "shimmer.modules.vae.VAEDecoder.forward": {"tf": 1}, "shimmer.modules.vae.VAE": {"tf": 1.4142135623730951}, "shimmer.modules.vae.VAE.__init__": {"tf": 1.4142135623730951}, "shimmer.modules.vae.VAE.beta": {"tf": 1.4142135623730951}, "shimmer.modules.vae.VAE.encoder": {"tf": 1.4142135623730951}, "shimmer.modules.vae.VAE.decoder": {"tf": 1.4142135623730951}, "shimmer.modules.vae.VAE.encode": {"tf": 1.4142135623730951}, "shimmer.modules.vae.VAE.decode": {"tf": 1.4142135623730951}, "shimmer.modules.vae.VAE.forward": {"tf": 1.4142135623730951}}, "df": 16, "e": {"docs": {}, "df": 0, "n": {"docs": {}, "df": 0, "c": {"docs": {}, "df": 0, "o": {"docs": {}, "df": 0, "d": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "r": {"docs": {"shimmer.modules.vae.VAEEncoder": {"tf": 1}, "shimmer.modules.vae.VAEEncoder.forward": {"tf": 1}}, "df": 2}}}}}}}, "d": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "c": {"docs": {}, "df": 0, "o": {"docs": {}, "df": 0, "d": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "r": {"docs": {"shimmer.modules.vae.VAEDecoder": {"tf": 1}, "shimmer.modules.vae.VAEDecoder.forward": {"tf": 1}}, "df": 2}}}}}}}}}}}}, "annotation": {"root": {"docs": {"shimmer.modules.global_workspace.SchedulerArgs.max_lr": {"tf": 1}, "shimmer.modules.global_workspace.SchedulerArgs.total_steps": {"tf": 1}, "shimmer.modules.global_workspace.GWPredictionsBase.states": {"tf": 1}, 
"shimmer.modules.global_workspace.GlobalWorkspaceBase.domain_mods": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.workspace_dim": {"tf": 1}, "shimmer.modules.global_workspace.GWPredictions.demi_cycles": {"tf": 1}, "shimmer.modules.global_workspace.GWPredictions.cycles": {"tf": 1}, "shimmer.modules.global_workspace.GWPredictions.translations": {"tf": 1}, "shimmer.modules.global_workspace.GWPredictions.states": {"tf": 1}, "shimmer.modules.domain.LossOutput.loss": {"tf": 1}, "shimmer.modules.domain.LossOutput.metrics": {"tf": 1}, "shimmer.modules.domain.LossOutput.all": {"tf": 1}, "shimmer.modules.losses.LossCoefs.demi_cycles": {"tf": 1}, "shimmer.modules.losses.LossCoefs.cycles": {"tf": 1}, "shimmer.modules.losses.LossCoefs.translations": {"tf": 1}, "shimmer.modules.losses.LossCoefs.contrastives": {"tf": 1}, "shimmer.modules.losses.BroadcastLossCoefs.contrastives": {"tf": 1}, "shimmer.modules.losses.BroadcastLossCoefs.fused": {"tf": 1}, "shimmer.modules.losses.BroadcastLossCoefs.demi_cycles": {"tf": 1}, "shimmer.modules.losses.BroadcastLossCoefs.cycles": {"tf": 1}, "shimmer.modules.losses.BroadcastLossCoefs.translations": {"tf": 1}, "shimmer.modules.contrastive_loss.ContrastiveLoss.reduction": {"tf": 1.4142135623730951}}, "df": 22, "f": {"docs": {}, "df": 0, "l": {"docs": {}, "df": 0, "o": {"docs": {}, "df": 0, "a": {"docs": {}, "df": 0, "t": {"docs": {"shimmer.modules.global_workspace.SchedulerArgs.max_lr": {"tf": 1}, "shimmer.modules.losses.LossCoefs.demi_cycles": {"tf": 1}, "shimmer.modules.losses.LossCoefs.cycles": {"tf": 1}, "shimmer.modules.losses.LossCoefs.translations": {"tf": 1}, "shimmer.modules.losses.LossCoefs.contrastives": {"tf": 1}, "shimmer.modules.losses.BroadcastLossCoefs.contrastives": {"tf": 1}, "shimmer.modules.losses.BroadcastLossCoefs.fused": {"tf": 1}, "shimmer.modules.losses.BroadcastLossCoefs.demi_cycles": {"tf": 1}, "shimmer.modules.losses.BroadcastLossCoefs.cycles": {"tf": 1}, 
"shimmer.modules.losses.BroadcastLossCoefs.translations": {"tf": 1}}, "df": 10}}}}}, "i": {"docs": {}, "df": 0, "n": {"docs": {}, "df": 0, "t": {"docs": {"shimmer.modules.global_workspace.SchedulerArgs.total_steps": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.workspace_dim": {"tf": 1}}, "df": 2}}}, "d": {"docs": {}, "df": 0, "i": {"docs": {}, "df": 0, "c": {"docs": {}, "df": 0, "t": {"docs": {}, "df": 0, "[": {"docs": {}, "df": 0, "s": {"docs": {}, "df": 0, "t": {"docs": {}, "df": 0, "r": {"docs": {"shimmer.modules.global_workspace.GWPredictionsBase.states": {"tf": 1}, "shimmer.modules.global_workspace.GWPredictions.demi_cycles": {"tf": 1}, "shimmer.modules.global_workspace.GWPredictions.states": {"tf": 1}, "shimmer.modules.domain.LossOutput.metrics": {"tf": 1}, "shimmer.modules.domain.LossOutput.all": {"tf": 1}}, "df": 5}}}, "t": {"docs": {}, "df": 0, "u": {"docs": {}, "df": 0, "p": {"docs": {}, "df": 0, "l": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "[": {"docs": {}, "df": 0, "s": {"docs": {}, "df": 0, "t": {"docs": {}, "df": 0, "r": {"docs": {"shimmer.modules.global_workspace.GWPredictions.cycles": {"tf": 1}, "shimmer.modules.global_workspace.GWPredictions.translations": {"tf": 1}}, "df": 2}}}}}}}}}}}}}, "o": {"docs": {}, "df": 0, "m": {"docs": {}, "df": 0, "a": {"docs": {}, "df": 0, "i": {"docs": {}, "df": 0, "n": {"docs": {"shimmer.modules.global_workspace.GlobalWorkspaceBase.domain_mods": {"tf": 1}}, "df": 1, "m": {"docs": {}, "df": 0, "o": {"docs": {}, "df": 0, "d": {"docs": {}, "df": 0, "u": {"docs": {}, "df": 0, "l": {"docs": {}, "df": 0, "e": {"docs": {"shimmer.modules.global_workspace.GlobalWorkspaceBase.domain_mods": {"tf": 1}}, "df": 1}}}}}}}}}}}}, "t": {"docs": {}, "df": 0, "o": {"docs": {}, "df": 0, "r": {"docs": {}, "df": 0, "c": {"docs": {}, "df": 0, "h": {"docs": {"shimmer.modules.global_workspace.GWPredictionsBase.states": {"tf": 1}, "shimmer.modules.global_workspace.GWPredictions.demi_cycles": {"tf": 1}, 
"shimmer.modules.global_workspace.GWPredictions.cycles": {"tf": 1}, "shimmer.modules.global_workspace.GWPredictions.translations": {"tf": 1}, "shimmer.modules.global_workspace.GWPredictions.states": {"tf": 1}, "shimmer.modules.domain.LossOutput.loss": {"tf": 1}, "shimmer.modules.domain.LossOutput.metrics": {"tf": 1}, "shimmer.modules.domain.LossOutput.all": {"tf": 1}}, "df": 8}}}}, "e": {"docs": {}, "df": 0, "n": {"docs": {}, "df": 0, "s": {"docs": {}, "df": 0, "o": {"docs": {}, "df": 0, "r": {"docs": {"shimmer.modules.global_workspace.GWPredictionsBase.states": {"tf": 1}, "shimmer.modules.global_workspace.GWPredictions.demi_cycles": {"tf": 1}, "shimmer.modules.global_workspace.GWPredictions.cycles": {"tf": 1}, "shimmer.modules.global_workspace.GWPredictions.translations": {"tf": 1}, "shimmer.modules.global_workspace.GWPredictions.states": {"tf": 1}, "shimmer.modules.domain.LossOutput.loss": {"tf": 1}, "shimmer.modules.domain.LossOutput.metrics": {"tf": 1}, "shimmer.modules.domain.LossOutput.all": {"tf": 1}}, "df": 8}}}}}}, "c": {"docs": {}, "df": 0, "o": {"docs": {}, "df": 0, "l": {"docs": {}, "df": 0, "l": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "c": {"docs": {}, "df": 0, "t": {"docs": {}, "df": 0, "i": {"docs": {}, "df": 0, "o": {"docs": {}, "df": 0, "n": {"docs": {}, "df": 0, "s": {"docs": {"shimmer.modules.global_workspace.GlobalWorkspaceBase.domain_mods": {"tf": 1}}, "df": 1}}}}}}}}}}}, "a": {"docs": {}, "df": 0, "b": {"docs": {}, "df": 0, "c": {"docs": {"shimmer.modules.global_workspace.GlobalWorkspaceBase.domain_mods": {"tf": 1}}, "df": 1}}}, "m": {"docs": {}, "df": 0, "a": {"docs": {}, "df": 0, "p": {"docs": {}, "df": 0, "p": {"docs": {}, "df": 0, "i": {"docs": {}, "df": 0, "n": {"docs": {}, "df": 0, "g": {"docs": {}, "df": 0, "[": {"docs": {}, "df": 0, "s": {"docs": {}, "df": 0, "t": {"docs": {}, "df": 0, "r": {"docs": {"shimmer.modules.global_workspace.GlobalWorkspaceBase.domain_mods": {"tf": 1}}, "df": 1}}}}}}}}}}, "o": {"docs": {}, "df": 0, 
"d": {"docs": {}, "df": 0, "u": {"docs": {}, "df": 0, "l": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "s": {"docs": {"shimmer.modules.global_workspace.GlobalWorkspaceBase.domain_mods": {"tf": 1}}, "df": 1}}}}}}, "e": {"docs": {}, "df": 0, "a": {"docs": {}, "df": 0, "n": {"docs": {"shimmer.modules.contrastive_loss.ContrastiveLoss.reduction": {"tf": 1}}, "df": 1}}}}, "s": {"docs": {}, "df": 0, "h": {"docs": {}, "df": 0, "i": {"docs": {}, "df": 0, "m": {"docs": {}, "df": 0, "m": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "r": {"docs": {"shimmer.modules.global_workspace.GlobalWorkspaceBase.domain_mods": {"tf": 1}}, "df": 1}}}}}}, "t": {"docs": {}, "df": 0, "r": {"docs": {"shimmer.modules.global_workspace.GWPredictions.cycles": {"tf": 1}, "shimmer.modules.global_workspace.GWPredictions.translations": {"tf": 1}}, "df": 2}}, "u": {"docs": {}, "df": 0, "m": {"docs": {"shimmer.modules.contrastive_loss.ContrastiveLoss.reduction": {"tf": 1}}, "df": 1}}}, "l": {"docs": {}, "df": 0, "i": {"docs": {}, "df": 0, "t": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "r": {"docs": {}, "df": 0, "a": {"docs": {}, "df": 0, "l": {"docs": {"shimmer.modules.contrastive_loss.ContrastiveLoss.reduction": {"tf": 1}}, "df": 1}}}}}}}, "x": {"2": {"7": {"docs": {"shimmer.modules.contrastive_loss.ContrastiveLoss.reduction": {"tf": 2.449489742783178}}, "df": 1}, "docs": {}, "df": 0}, "docs": {}, "df": 0}, "n": {"docs": {}, "df": 0, "o": {"docs": {}, "df": 0, "n": {"docs": {}, "df": 0, "e": {"docs": {"shimmer.modules.contrastive_loss.ContrastiveLoss.reduction": {"tf": 1}}, "df": 1}}}}}}, "default_value": {"root": {"docs": {"shimmer.types.ModelModeT": {"tf": 1}, "shimmer.utils.MIGRATION_DIR": {"tf": 1}, "shimmer.cli.ckpt_migration.migrate_ckpt": {"tf": 1.4142135623730951}}, "df": 3, "c": {"docs": {}, "df": 0, "o": {"docs": {}, "df": 0, "l": {"docs": {}, "df": 0, "l": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "c": {"docs": {}, "df": 0, "t": {"docs": {}, "df": 0, "i": {"docs": {}, 
"df": 0, "o": {"docs": {}, "df": 0, "n": {"docs": {}, "df": 0, "s": {"docs": {"shimmer.types.RawDomainGroupT": {"tf": 1}, "shimmer.types.LatentsDomainGroupT": {"tf": 1}, "shimmer.types.RawDomainGroupsT": {"tf": 1.4142135623730951}, "shimmer.types.LatentsDomainGroupsT": {"tf": 1.4142135623730951}, "shimmer.modules.contrastive_loss.ContrastiveLossType": {"tf": 1}, "shimmer.modules.contrastive_loss.ContrastiveLossBayesianType": {"tf": 1}}, "df": 6}}}}}}}}}, "m": {"docs": {}, "df": 0, "m": {"docs": {}, "df": 0, "a": {"docs": {}, "df": 0, "n": {"docs": {}, "df": 0, "d": {"docs": {"shimmer.cli.ckpt_migration.migrate_ckpt": {"tf": 1}}, "df": 1}}}}}}, "a": {"docs": {}, "df": 0, "l": {"docs": {}, "df": 0, "l": {"docs": {}, "df": 0, "a": {"docs": {}, "df": 0, "b": {"docs": {}, "df": 0, "l": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "[": {"docs": {}, "df": 0, "[": {"docs": {}, "df": 0, "t": {"docs": {}, "df": 0, "o": {"docs": {}, "df": 0, "r": {"docs": {}, "df": 0, "c": {"docs": {}, "df": 0, "h": {"docs": {"shimmer.modules.contrastive_loss.ContrastiveLossType": {"tf": 1}, "shimmer.modules.contrastive_loss.ContrastiveLossBayesianType": {"tf": 1}}, "df": 2}}}}}}}}}}}}}}, "k": {"docs": {}, "df": 0, "p": {"docs": {}, "df": 0, "t": {"docs": {"shimmer.cli.ckpt_migration.migrate_ckpt": {"tf": 1}}, "df": 1}}}}, "a": {"docs": {}, "df": 0, "b": {"docs": {}, "df": 0, "c": {"docs": {"shimmer.types.RawDomainGroupT": {"tf": 1}, "shimmer.types.LatentsDomainGroupT": {"tf": 1}, "shimmer.types.RawDomainGroupsT": {"tf": 1.4142135623730951}, "shimmer.types.LatentsDomainGroupsT": {"tf": 1.4142135623730951}, "shimmer.modules.contrastive_loss.ContrastiveLossType": {"tf": 1}, "shimmer.modules.contrastive_loss.ContrastiveLossBayesianType": {"tf": 1}}, "df": 6}}, "n": {"docs": {}, "df": 0, "y": {"docs": {"shimmer.types.RawDomainGroupT": {"tf": 1}, "shimmer.types.RawDomainGroupDT": {"tf": 1}, "shimmer.types.RawDomainGroupsT": {"tf": 1}, "shimmer.types.RawDomainGroupsDT": {"tf": 1}}, "df": 4}}}, 
"m": {"docs": {}, "df": 0, "a": {"docs": {}, "df": 0, "p": {"docs": {}, "df": 0, "p": {"docs": {}, "df": 0, "i": {"docs": {}, "df": 0, "n": {"docs": {}, "df": 0, "g": {"docs": {}, "df": 0, "[": {"docs": {}, "df": 0, "s": {"docs": {}, "df": 0, "t": {"docs": {}, "df": 0, "r": {"docs": {"shimmer.types.RawDomainGroupT": {"tf": 1}, "shimmer.types.LatentsDomainGroupT": {"tf": 1}, "shimmer.types.RawDomainGroupsT": {"tf": 1}, "shimmer.types.LatentsDomainGroupsT": {"tf": 1}}, "df": 4}}}, "f": {"docs": {}, "df": 0, "r": {"docs": {}, "df": 0, "o": {"docs": {}, "df": 0, "z": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "n": {"docs": {}, "df": 0, "s": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "t": {"docs": {}, "df": 0, "[": {"docs": {}, "df": 0, "s": {"docs": {}, "df": 0, "t": {"docs": {}, "df": 0, "r": {"docs": {"shimmer.types.RawDomainGroupsT": {"tf": 1}, "shimmer.types.LatentsDomainGroupsT": {"tf": 1}}, "df": 2}}}}}}}}}}}}}}}}}}}}, "o": {"docs": {}, "df": 0, "d": {"docs": {}, "df": 0, "u": {"docs": {}, "df": 0, "l": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "s": {"docs": {"shimmer.modules.contrastive_loss.ContrastiveLossType": {"tf": 1}, "shimmer.modules.contrastive_loss.ContrastiveLossBayesianType": {"tf": 1}}, "df": 2}}}}}}, "i": {"docs": {}, "df": 0, "g": {"docs": {}, "df": 0, "r": {"docs": {}, "df": 0, "a": {"docs": {}, "df": 0, "t": {"docs": {}, "df": 0, "i": {"docs": {}, "df": 0, "o": {"docs": {}, "df": 0, "n": {"docs": {}, "df": 0, "s": {"docs": {"shimmer.utils.MIGRATION_DIR": {"tf": 1}}, "df": 1}}}}, "e": {"docs": {"shimmer.cli.ckpt_migration.migrate_ckpt": {"tf": 1}}, "df": 1}}}}}}}, "t": {"docs": {}, "df": 0, "y": {"docs": {}, "df": 0, "p": {"docs": {}, "df": 0, "i": {"docs": {}, "df": 0, "n": {"docs": {}, "df": 0, "g": {"docs": {"shimmer.types.RawDomainGroupT": {"tf": 1}, "shimmer.types.RawDomainGroupDT": {"tf": 1}, "shimmer.types.RawDomainGroupsT": {"tf": 1}, "shimmer.types.RawDomainGroupsDT": {"tf": 1}, "shimmer.types.ModelModeT": {"tf": 1}}, 
"df": 5}}}}}, "o": {"docs": {}, "df": 0, "r": {"docs": {}, "df": 0, "c": {"docs": {}, "df": 0, "h": {"docs": {"shimmer.types.LatentsDomainGroupT": {"tf": 1}, "shimmer.types.LatentsDomainGroupDT": {"tf": 1}, "shimmer.types.LatentsDomainGroupsT": {"tf": 1}, "shimmer.types.LatentsDomainGroupsDT": {"tf": 1}, "shimmer.modules.contrastive_loss.ContrastiveLossType": {"tf": 1}, "shimmer.modules.contrastive_loss.ContrastiveLossBayesianType": {"tf": 1.7320508075688772}}, "df": 6}}}}, "e": {"docs": {}, "df": 0, "n": {"docs": {}, "df": 0, "s": {"docs": {}, "df": 0, "o": {"docs": {}, "df": 0, "r": {"docs": {"shimmer.types.LatentsDomainGroupT": {"tf": 1}, "shimmer.types.LatentsDomainGroupDT": {"tf": 1}, "shimmer.types.LatentsDomainGroupsT": {"tf": 1}, "shimmer.types.LatentsDomainGroupsDT": {"tf": 1}, "shimmer.modules.contrastive_loss.ContrastiveLossType": {"tf": 1.4142135623730951}, "shimmer.modules.contrastive_loss.ContrastiveLossBayesianType": {"tf": 2}}, "df": 6}}}}, "s": {"docs": {}, "df": 0, "t": {"docs": {"shimmer.types.ModelModeT": {"tf": 1}}, "df": 1, "/": {"docs": {}, "df": 0, "o": {"docs": {}, "df": 0, "o": {"docs": {}, "df": 0, "d": {"docs": {"shimmer.types.ModelModeT": {"tf": 1}}, "df": 1}}}}}}}, "r": {"docs": {}, "df": 0, "a": {"docs": {}, "df": 0, "i": {"docs": {}, "df": 0, "n": {"docs": {"shimmer.types.ModelModeT": {"tf": 1}}, "df": 1}}}}}, "d": {"docs": {}, "df": 0, "i": {"docs": {}, "df": 0, "c": {"docs": {}, "df": 0, "t": {"docs": {}, "df": 0, "[": {"docs": {}, "df": 0, "s": {"docs": {}, "df": 0, "t": {"docs": {}, "df": 0, "r": {"docs": {"shimmer.types.RawDomainGroupDT": {"tf": 1}, "shimmer.types.LatentsDomainGroupDT": {"tf": 1}, "shimmer.types.RawDomainGroupsDT": {"tf": 1}, "shimmer.types.LatentsDomainGroupsDT": {"tf": 1}}, "df": 4}}}, "f": {"docs": {}, "df": 0, "r": {"docs": {}, "df": 0, "o": {"docs": {}, "df": 0, "z": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "n": {"docs": {}, "df": 0, "s": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "t": 
{"docs": {}, "df": 0, "[": {"docs": {}, "df": 0, "s": {"docs": {}, "df": 0, "t": {"docs": {}, "df": 0, "r": {"docs": {"shimmer.types.RawDomainGroupsDT": {"tf": 1}, "shimmer.types.LatentsDomainGroupsDT": {"tf": 1}}, "df": 2}}}}}}}}}}}}}}}}}, "o": {"docs": {}, "df": 0, "m": {"docs": {}, "df": 0, "a": {"docs": {}, "df": 0, "i": {"docs": {}, "df": 0, "n": {"docs": {"shimmer.modules.contrastive_loss.ContrastiveLossType": {"tf": 1}, "shimmer.modules.contrastive_loss.ContrastiveLossBayesianType": {"tf": 1}}, "df": 2}}}}}}, "l": {"docs": {}, "df": 0, "i": {"docs": {}, "df": 0, "t": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "r": {"docs": {}, "df": 0, "a": {"docs": {}, "df": 0, "l": {"docs": {"shimmer.types.ModelModeT": {"tf": 1}}, "df": 1}}}}}}, "o": {"docs": {}, "df": 0, "s": {"docs": {}, "df": 0, "s": {"docs": {}, "df": 0, "o": {"docs": {}, "df": 0, "u": {"docs": {}, "df": 0, "t": {"docs": {}, "df": 0, "p": {"docs": {}, "df": 0, "u": {"docs": {}, "df": 0, "t": {"docs": {"shimmer.modules.contrastive_loss.ContrastiveLossType": {"tf": 1}, "shimmer.modules.contrastive_loss.ContrastiveLossBayesianType": {"tf": 1}}, "df": 2}}}}}}}}}, "t": {"docs": {"shimmer.cli.ckpt_migration.migrate_ckpt": {"tf": 1}}, "df": 1}}, "x": {"2": {"7": {"docs": {"shimmer.types.ModelModeT": {"tf": 3.1622776601683795}, "shimmer.utils.MIGRATION_DIR": {"tf": 1.4142135623730951}}, "df": 2}, "docs": {}, "df": 0}, "docs": {}, "df": 0}, "v": {"docs": {}, "df": 0, "a": {"docs": {}, "df": 0, "l": {"docs": {"shimmer.types.ModelModeT": {"tf": 1}}, "df": 1, "/": {"docs": {}, "df": 0, "o": {"docs": {}, "df": 0, "o": {"docs": {}, "df": 0, "d": {"docs": {"shimmer.types.ModelModeT": {"tf": 1}}, "df": 1}}}}}}}, "s": {"docs": {}, "df": 0, "h": {"docs": {}, "df": 0, "i": {"docs": {}, "df": 0, "m": {"docs": {}, "df": 0, "m": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "r": {"docs": {"shimmer.modules.contrastive_loss.ContrastiveLossType": {"tf": 1}, 
"shimmer.modules.contrastive_loss.ContrastiveLossBayesianType": {"tf": 1}}, "df": 2}}}}}}}, "p": {"docs": {}, "df": 0, "o": {"docs": {}, "df": 0, "s": {"docs": {}, "df": 0, "i": {"docs": {}, "df": 0, "x": {"docs": {}, "df": 0, "p": {"docs": {}, "df": 0, "a": {"docs": {}, "df": 0, "t": {"docs": {}, "df": 0, "h": {"docs": {"shimmer.utils.MIGRATION_DIR": {"tf": 1}}, "df": 1}}}}}}}}}, "h": {"docs": {}, "df": 0, "o": {"docs": {}, "df": 0, "m": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "/": {"docs": {}, "df": 0, "r": {"docs": {}, "df": 0, "u": {"docs": {}, "df": 0, "n": {"docs": {}, "df": 0, "n": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "r": {"docs": {}, "df": 0, "/": {"docs": {}, "df": 0, "w": {"docs": {}, "df": 0, "o": {"docs": {}, "df": 0, "r": {"docs": {}, "df": 0, "k": {"docs": {}, "df": 0, "/": {"docs": {}, "df": 0, "s": {"docs": {}, "df": 0, "h": {"docs": {}, "df": 0, "i": {"docs": {}, "df": 0, "m": {"docs": {}, "df": 0, "m": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "r": {"docs": {}, "df": 0, "/": {"docs": {}, "df": 0, "s": {"docs": {}, "df": 0, "h": {"docs": {}, "df": 0, "i": {"docs": {}, "df": 0, "m": {"docs": {}, "df": 0, "m": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "r": {"docs": {}, "df": 0, "/": {"docs": {}, "df": 0, "s": {"docs": {}, "df": 0, "h": {"docs": {}, "df": 0, "i": {"docs": {}, "df": 0, "m": {"docs": {}, "df": 0, "m": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "r": {"docs": {}, "df": 0, "/": {"docs": {}, "df": 0, "c": {"docs": {}, "df": 0, "k": {"docs": {}, "df": 0, "p": {"docs": {}, "df": 0, "t": {"docs": {"shimmer.utils.MIGRATION_DIR": {"tf": 1}}, "df": 1}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}}, "g": {"docs": {}, "df": 0, "t": {"docs": {"shimmer.cli.ckpt_migration.migrate_ckpt": {"tf": 1}}, "df": 1}}}}, "signature": {"root": {"0": {"0": {"1": {"docs": {"shimmer.modules.global_workspace.GlobalWorkspace2Domains.__init__": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspace.__init__": {"tf": 1}, 
"shimmer.modules.global_workspace.GlobalWorkspaceBayesian.__init__": {"tf": 1}}, "df": 3}, "docs": {}, "df": 0}, "1": {"docs": {"shimmer.modules.global_workspace.GlobalWorkspaceBayesian.__init__": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBayesian.__init__": {"tf": 1}}, "df": 2}, "6": {"docs": {"shimmer.modules.gw_module.compute_fusion_scores": {"tf": 1}}, "df": 1}, "docs": {"shimmer.modules.global_workspace.GlobalWorkspace2Domains.__init__": {"tf": 1.7320508075688772}, "shimmer.modules.global_workspace.GlobalWorkspace.__init__": {"tf": 2}, "shimmer.modules.global_workspace.GlobalWorkspaceBayesian.__init__": {"tf": 2}, "shimmer.modules.gw_module.compute_fusion_scores": {"tf": 1.4142135623730951}, "shimmer.modules.gw_module.GWModuleBayesian.__init__": {"tf": 1}}, "df": 5}, "1": {"docs": {"shimmer.modules.global_workspace.GlobalWorkspaceBayesian.__init__": {"tf": 1.4142135623730951}, "shimmer.modules.gw_module.compute_fusion_scores": {"tf": 2}, "shimmer.modules.gw_module.GWModuleBayesian.__init__": {"tf": 1.4142135623730951}, "shimmer.modules.vae.VAE.__init__": {"tf": 1}}, "df": 4, "e": {"docs": {"shimmer.modules.gw_module.compute_fusion_scores": {"tf": 1}}, "df": 1}}, "2": {"docs": {"shimmer.modules.global_workspace.GlobalWorkspace.__init__": {"tf": 1}, "shimmer.modules.gw_module.compute_fusion_scores": {"tf": 1.4142135623730951}}, "df": 2}, "3": {"9": {"docs": {"shimmer.modules.global_workspace.GlobalWorkspaceBase.generic_step": {"tf": 3.1622776601683795}, "shimmer.modules.losses.GWLossesBase.step": {"tf": 3.1622776601683795}, "shimmer.modules.losses.GWLosses2Domains.step": {"tf": 3.1622776601683795}, "shimmer.modules.losses.GWLosses.step": {"tf": 3.1622776601683795}, "shimmer.modules.losses.GWLossesBayesian.step": {"tf": 3.1622776601683795}, "shimmer.modules.contrastive_loss.info_nce": {"tf": 2.8284271247461903}, "shimmer.modules.contrastive_loss.contrastive_loss": {"tf": 2.8284271247461903}, "shimmer.modules.contrastive_loss.ContrastiveLoss.__init__": {"tf": 
2.8284271247461903}}, "df": 8}, "docs": {}, "df": 0}, "docs": {"shimmer.modules.global_workspace.GlobalWorkspaceBase.encode_and_fuse": {"tf": 10.04987562112089}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.encode": {"tf": 9.38083151964686}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.fuse": {"tf": 11.40175425099138}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.decode": {"tf": 10.246950765959598}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.batch_gw_states": {"tf": 8.774964387392123}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.encode_domain": {"tf": 5.656854249492381}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.encode_domains": {"tf": 9.38083151964686}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.decode_domain": {"tf": 5.656854249492381}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.decode_domains": {"tf": 9.38083151964686}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.generic_step": {"tf": 10.535653752852738}, "shimmer.modules.global_workspace.freeze_domain_modules": {"tf": 8.306623862918075}, "shimmer.modules.global_workspace.GlobalWorkspace2Domains.__init__": {"tf": 17.291616465790582}, "shimmer.modules.global_workspace.GlobalWorkspace2Domains.forward": {"tf": 8.660254037844387}, "shimmer.modules.global_workspace.GlobalWorkspace.__init__": {"tf": 17.72004514666935}, "shimmer.modules.global_workspace.GlobalWorkspace.forward": {"tf": 8.660254037844387}, "shimmer.modules.global_workspace.GlobalWorkspaceBayesian.__init__": {"tf": 18.947295321496416}, "shimmer.modules.global_workspace.GlobalWorkspaceBayesian.forward": {"tf": 8.660254037844387}, "shimmer.modules.global_workspace.pretrained_global_workspace": {"tf": 15.874507866387544}, "shimmer.modules.domain.LossOutput.__init__": {"tf": 7.0710678118654755}, "shimmer.modules.domain.DomainModule.__init__": {"tf": 3.4641016151377544}, "shimmer.modules.domain.DomainModule.encode": {"tf": 4.898979485566356}, 
"shimmer.modules.domain.DomainModule.decode": {"tf": 4.898979485566356}, "shimmer.modules.domain.DomainModule.compute_loss": {"tf": 7.14142842854285}, "shimmer.modules.domain.DomainModule.compute_dcy_loss": {"tf": 7.14142842854285}, "shimmer.modules.domain.DomainModule.compute_cy_loss": {"tf": 7.14142842854285}, "shimmer.modules.domain.DomainModule.compute_tr_loss": {"tf": 7.14142842854285}, "shimmer.modules.domain.DomainModule.compute_broadcast_loss": {"tf": 7.14142842854285}, "shimmer.modules.gw_module.get_n_layers": {"tf": 6.708203932499369}, "shimmer.modules.gw_module.GWDecoder.__init__": {"tf": 6}, "shimmer.modules.gw_module.GWEncoder.__init__": {"tf": 6}, "shimmer.modules.gw_module.GWEncoder.forward": {"tf": 5.291502622129181}, "shimmer.modules.gw_module.GWEncoderLinear.forward": {"tf": 5.291502622129181}, "shimmer.modules.gw_module.GWModuleBase.__init__": {"tf": 8.12403840463596}, "shimmer.modules.gw_module.GWModuleBase.fuse": {"tf": 8.888194417315589}, "shimmer.modules.gw_module.GWModuleBase.encode": {"tf": 7.615773105863909}, "shimmer.modules.gw_module.GWModuleBase.encode_and_fuse": {"tf": 8.306623862918075}, "shimmer.modules.gw_module.GWModuleBase.decode": {"tf": 8.54400374531753}, "shimmer.modules.gw_module.GWModule.__init__": {"tf": 11.74734012447073}, "shimmer.modules.gw_module.GWModule.fuse": {"tf": 8.888194417315589}, "shimmer.modules.gw_module.GWModule.encode": {"tf": 7.615773105863909}, "shimmer.modules.gw_module.GWModule.decode": {"tf": 8.54400374531753}, "shimmer.modules.gw_module.compute_fusion_scores": {"tf": 9.1104335791443}, "shimmer.modules.gw_module.GWModuleBayesian.__init__": {"tf": 13.527749258468683}, "shimmer.modules.gw_module.GWModuleBayesian.get_precision": {"tf": 6}, "shimmer.modules.gw_module.GWModuleBayesian.fuse": {"tf": 8.888194417315589}, "shimmer.modules.selection.SelectionBase.update_gw_state": {"tf": 4.898979485566356}, "shimmer.modules.selection.SelectionBase.forward": {"tf": 9.433981132056603}, 
"shimmer.modules.selection.SingleDomainSelection.forward": {"tf": 9.433981132056603}, "shimmer.modules.selection.FixedSharedSelection.forward": {"tf": 9.433981132056603}, "shimmer.modules.selection.KQFixedQSelection.__init__": {"tf": 6.6332495807108}, "shimmer.modules.selection.KQFixedQSelection.forward": {"tf": 9.433981132056603}, "shimmer.modules.selection.RandomSelection.__init__": {"tf": 3.4641016151377544}, "shimmer.modules.selection.RandomSelection.forward": {"tf": 9.433981132056603}, "shimmer.modules.selection.DynamicQueryAttention.__init__": {"tf": 6.6332495807108}, "shimmer.modules.selection.DynamicQueryAttention.fuse_weighted_encodings": {"tf": 8.426149773176359}, "shimmer.modules.selection.DynamicQueryAttention.forward": {"tf": 9.433981132056603}, "shimmer.modules.losses.GWLossesBase.step": {"tf": 10.908712114635714}, "shimmer.modules.losses.demi_cycle_loss": {"tf": 12.36931687685298}, "shimmer.modules.losses.cycle_loss": {"tf": 12.36931687685298}, "shimmer.modules.losses.translation_loss": {"tf": 12.36931687685298}, "shimmer.modules.losses.contrastive_loss": {"tf": 12.041594578792296}, "shimmer.modules.losses.contrastive_loss_bayesian": {"tf": 12.041594578792296}, "shimmer.modules.losses.GWLosses2Domains.__init__": {"tf": 12.24744871391589}, "shimmer.modules.losses.GWLosses2Domains.demi_cycle_loss": {"tf": 8.774964387392123}, "shimmer.modules.losses.GWLosses2Domains.cycle_loss": {"tf": 8.774964387392123}, "shimmer.modules.losses.GWLosses2Domains.translation_loss": {"tf": 8.774964387392123}, "shimmer.modules.losses.GWLosses2Domains.contrastive_loss": {"tf": 8.774964387392123}, "shimmer.modules.losses.GWLosses2Domains.step": {"tf": 10.908712114635714}, "shimmer.modules.losses.generate_partitions": {"tf": 7}, "shimmer.modules.losses.broadcast_loss": {"tf": 12.36931687685298}, "shimmer.modules.losses.GWLosses.__init__": {"tf": 12.24744871391589}, "shimmer.modules.losses.GWLosses.contrastive_loss": {"tf": 8.774964387392123}, 
"shimmer.modules.losses.GWLosses.broadcast_loss": {"tf": 8.774964387392123}, "shimmer.modules.losses.GWLosses.step": {"tf": 10.908712114635714}, "shimmer.modules.losses.GWLossesBayesian.__init__": {"tf": 12.84523257866513}, "shimmer.modules.losses.GWLossesBayesian.contrastive_loss": {"tf": 8.774964387392123}, "shimmer.modules.losses.GWLossesBayesian.broadcast_loss": {"tf": 8.774964387392123}, "shimmer.modules.losses.GWLossesBayesian.step": {"tf": 10.908712114635714}, "shimmer.modules.contrastive_loss.info_nce": {"tf": 9.433981132056603}, "shimmer.modules.contrastive_loss.contrastive_loss": {"tf": 9.433981132056603}, "shimmer.modules.contrastive_loss.ContrastiveLoss.__init__": {"tf": 8.366600265340756}, "shimmer.modules.contrastive_loss.ContrastiveLoss.forward": {"tf": 7.14142842854285}, "shimmer.dataset.RepeatedDataset.__init__": {"tf": 6.782329983125268}, "shimmer.modules.vae.reparameterize": {"tf": 6}, "shimmer.modules.vae.kl_divergence_loss": {"tf": 6}, "shimmer.modules.vae.gaussian_nll": {"tf": 7.14142842854285}, "shimmer.modules.vae.VAEEncoder.forward": {"tf": 6.164414002968976}, "shimmer.modules.vae.VAEDecoder.forward": {"tf": 4.898979485566356}, "shimmer.modules.vae.VAE.__init__": {"tf": 7.810249675906654}, "shimmer.modules.vae.VAE.encode": {"tf": 4.898979485566356}, "shimmer.modules.vae.VAE.decode": {"tf": 4.898979485566356}, "shimmer.modules.vae.VAE.forward": {"tf": 7.0710678118654755}, "shimmer.modules.utils.translation": {"tf": 9.695359714832659}, "shimmer.modules.utils.cycle": {"tf": 10.198039027185569}, "shimmer.modules.utils.batch_demi_cycles": {"tf": 10.677078252031311}, "shimmer.modules.utils.batch_cycles": {"tf": 12}, "shimmer.modules.utils.batch_translations": {"tf": 11.045361017187261}, "shimmer.utils.group_batch_size": {"tf": 6.164414002968976}, "shimmer.utils.groups_batch_size": {"tf": 7.615773105863909}, "shimmer.utils.groups_device": {"tf": 7.615773105863909}, "shimmer.utils.group_device": {"tf": 6.48074069840786}, 
"shimmer.utils.migrate_model": {"tf": 5.385164807134504}, "shimmer.utils.SaveMigrations.on_save_checkpoint": {"tf": 9.16515138991168}}, "df": 103, "s": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "l": {"docs": {}, "df": 0, "f": {"docs": {"shimmer.modules.global_workspace.GlobalWorkspaceBase.encode_and_fuse": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.encode": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.fuse": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.decode": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.batch_gw_states": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.encode_domain": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.encode_domains": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.decode_domain": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.decode_domains": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.generic_step": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspace2Domains.forward": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspace.forward": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBayesian.forward": {"tf": 1}, "shimmer.modules.domain.DomainModule.encode": {"tf": 1}, "shimmer.modules.domain.DomainModule.decode": {"tf": 1}, "shimmer.modules.domain.DomainModule.compute_loss": {"tf": 1}, "shimmer.modules.domain.DomainModule.compute_dcy_loss": {"tf": 1}, "shimmer.modules.domain.DomainModule.compute_cy_loss": {"tf": 1}, "shimmer.modules.domain.DomainModule.compute_tr_loss": {"tf": 1}, "shimmer.modules.domain.DomainModule.compute_broadcast_loss": {"tf": 1}, "shimmer.modules.gw_module.GWEncoder.forward": {"tf": 1}, "shimmer.modules.gw_module.GWEncoderLinear.forward": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBase.fuse": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBase.encode": {"tf": 1}, 
"shimmer.modules.gw_module.GWModuleBase.encode_and_fuse": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBase.decode": {"tf": 1}, "shimmer.modules.gw_module.GWModule.fuse": {"tf": 1}, "shimmer.modules.gw_module.GWModule.encode": {"tf": 1}, "shimmer.modules.gw_module.GWModule.decode": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBayesian.get_precision": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBayesian.fuse": {"tf": 1}, "shimmer.modules.selection.SelectionBase.update_gw_state": {"tf": 1}, "shimmer.modules.selection.SelectionBase.forward": {"tf": 1}, "shimmer.modules.selection.SingleDomainSelection.forward": {"tf": 1}, "shimmer.modules.selection.FixedSharedSelection.forward": {"tf": 1}, "shimmer.modules.selection.KQFixedQSelection.forward": {"tf": 1}, "shimmer.modules.selection.RandomSelection.forward": {"tf": 1}, "shimmer.modules.selection.DynamicQueryAttention.fuse_weighted_encodings": {"tf": 1}, "shimmer.modules.selection.DynamicQueryAttention.forward": {"tf": 1}, "shimmer.modules.losses.GWLossesBase.step": {"tf": 1}, "shimmer.modules.losses.GWLosses2Domains.demi_cycle_loss": {"tf": 1}, "shimmer.modules.losses.GWLosses2Domains.cycle_loss": {"tf": 1}, "shimmer.modules.losses.GWLosses2Domains.translation_loss": {"tf": 1}, "shimmer.modules.losses.GWLosses2Domains.contrastive_loss": {"tf": 1}, "shimmer.modules.losses.GWLosses2Domains.step": {"tf": 1}, "shimmer.modules.losses.GWLosses.contrastive_loss": {"tf": 1}, "shimmer.modules.losses.GWLosses.broadcast_loss": {"tf": 1}, "shimmer.modules.losses.GWLosses.step": {"tf": 1}, "shimmer.modules.losses.GWLossesBayesian.contrastive_loss": {"tf": 1}, "shimmer.modules.losses.GWLossesBayesian.broadcast_loss": {"tf": 1}, "shimmer.modules.losses.GWLossesBayesian.step": {"tf": 1}, "shimmer.modules.contrastive_loss.ContrastiveLoss.forward": {"tf": 1}, "shimmer.modules.vae.VAEEncoder.forward": {"tf": 1}, "shimmer.modules.vae.VAEDecoder.forward": {"tf": 1}, "shimmer.modules.vae.VAE.encode": {"tf": 1}, 
"shimmer.modules.vae.VAE.decode": {"tf": 1}, "shimmer.modules.vae.VAE.forward": {"tf": 1}, "shimmer.utils.SaveMigrations.on_save_checkpoint": {"tf": 1}}, "df": 58}, "e": {"docs": {}, "df": 0, "c": {"docs": {}, "df": 0, "t": {"docs": {}, "df": 0, "i": {"docs": {}, "df": 0, "o": {"docs": {}, "df": 0, "n": {"docs": {"shimmer.modules.global_workspace.GlobalWorkspaceBase.encode_and_fuse": {"tf": 1.4142135623730951}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.fuse": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspace.__init__": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBayesian.__init__": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBase.fuse": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBase.encode_and_fuse": {"tf": 1.4142135623730951}, "shimmer.modules.gw_module.GWModule.fuse": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBayesian.__init__": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBayesian.fuse": {"tf": 1}, "shimmer.modules.losses.demi_cycle_loss": {"tf": 1.4142135623730951}, "shimmer.modules.losses.cycle_loss": {"tf": 1.4142135623730951}, "shimmer.modules.losses.translation_loss": {"tf": 1.4142135623730951}, "shimmer.modules.losses.GWLosses2Domains.__init__": {"tf": 1.4142135623730951}, "shimmer.modules.losses.broadcast_loss": {"tf": 1.4142135623730951}, "shimmer.modules.losses.GWLosses.__init__": {"tf": 1.4142135623730951}, "shimmer.modules.losses.GWLossesBayesian.__init__": {"tf": 1.4142135623730951}, "shimmer.modules.utils.translation": {"tf": 1.4142135623730951}, "shimmer.modules.utils.cycle": {"tf": 1.4142135623730951}, "shimmer.modules.utils.batch_demi_cycles": {"tf": 1.4142135623730951}, "shimmer.modules.utils.batch_cycles": {"tf": 1.4142135623730951}, "shimmer.modules.utils.batch_translations": {"tf": 1.4142135623730951}}, "df": 21, "b": {"docs": {}, "df": 0, "a": {"docs": {}, "df": 0, "s": {"docs": {}, "df": 0, "e": {"docs": {"shimmer.modules.global_workspace.GlobalWorkspaceBase.encode_and_fuse": {"tf": 1}, 
"shimmer.modules.gw_module.GWModuleBase.encode_and_fuse": {"tf": 1}, "shimmer.modules.losses.demi_cycle_loss": {"tf": 1}, "shimmer.modules.losses.cycle_loss": {"tf": 1}, "shimmer.modules.losses.translation_loss": {"tf": 1}, "shimmer.modules.losses.GWLosses2Domains.__init__": {"tf": 1}, "shimmer.modules.losses.broadcast_loss": {"tf": 1}, "shimmer.modules.losses.GWLosses.__init__": {"tf": 1}, "shimmer.modules.losses.GWLossesBayesian.__init__": {"tf": 1}, "shimmer.modules.utils.translation": {"tf": 1}, "shimmer.modules.utils.cycle": {"tf": 1}, "shimmer.modules.utils.batch_demi_cycles": {"tf": 1}, "shimmer.modules.utils.batch_cycles": {"tf": 1}, "shimmer.modules.utils.batch_translations": {"tf": 1}}, "df": 14}}}}}}}}}}}, "n": {"docs": {}, "df": 0, "s": {"docs": {}, "df": 0, "i": {"docs": {}, "df": 0, "t": {"docs": {}, "df": 0, "i": {"docs": {}, "df": 0, "v": {"docs": {}, "df": 0, "i": {"docs": {}, "df": 0, "t": {"docs": {}, "df": 0, "y": {"docs": {"shimmer.modules.global_workspace.GlobalWorkspaceBayesian.__init__": {"tf": 1.4142135623730951}, "shimmer.modules.gw_module.compute_fusion_scores": {"tf": 1.4142135623730951}, "shimmer.modules.gw_module.GWModuleBayesian.__init__": {"tf": 1.4142135623730951}}, "df": 3}}}}}}}}}}, "t": {"docs": {}, "df": 0, "r": {"docs": {"shimmer.modules.global_workspace.GlobalWorkspaceBase.encode_and_fuse": {"tf": 1.7320508075688772}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.encode": {"tf": 2}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.fuse": {"tf": 2.23606797749979}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.decode": {"tf": 2}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.batch_gw_states": {"tf": 1.7320508075688772}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.encode_domain": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.encode_domains": {"tf": 2}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.decode_domain": {"tf": 1}, 
"shimmer.modules.global_workspace.GlobalWorkspaceBase.decode_domains": {"tf": 2}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.generic_step": {"tf": 1.4142135623730951}, "shimmer.modules.global_workspace.freeze_domain_modules": {"tf": 1.4142135623730951}, "shimmer.modules.global_workspace.GlobalWorkspace2Domains.__init__": {"tf": 1.7320508075688772}, "shimmer.modules.global_workspace.GlobalWorkspace2Domains.forward": {"tf": 1.4142135623730951}, "shimmer.modules.global_workspace.GlobalWorkspace.__init__": {"tf": 1.7320508075688772}, "shimmer.modules.global_workspace.GlobalWorkspace.forward": {"tf": 1.4142135623730951}, "shimmer.modules.global_workspace.GlobalWorkspaceBayesian.__init__": {"tf": 1.7320508075688772}, "shimmer.modules.global_workspace.GlobalWorkspaceBayesian.forward": {"tf": 1.4142135623730951}, "shimmer.modules.global_workspace.pretrained_global_workspace": {"tf": 2}, "shimmer.modules.domain.LossOutput.__init__": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBase.__init__": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBase.fuse": {"tf": 1.4142135623730951}, "shimmer.modules.gw_module.GWModuleBase.encode": {"tf": 1.4142135623730951}, "shimmer.modules.gw_module.GWModuleBase.encode_and_fuse": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBase.decode": {"tf": 1.4142135623730951}, "shimmer.modules.gw_module.GWModule.__init__": {"tf": 1.7320508075688772}, "shimmer.modules.gw_module.GWModule.fuse": {"tf": 1.4142135623730951}, "shimmer.modules.gw_module.GWModule.encode": {"tf": 1.4142135623730951}, "shimmer.modules.gw_module.GWModule.decode": {"tf": 1.4142135623730951}, "shimmer.modules.gw_module.GWModuleBayesian.__init__": {"tf": 1.7320508075688772}, "shimmer.modules.gw_module.GWModuleBayesian.get_precision": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBayesian.fuse": {"tf": 1.4142135623730951}, "shimmer.modules.selection.SelectionBase.forward": {"tf": 1.7320508075688772}, "shimmer.modules.selection.SingleDomainSelection.forward": {"tf": 
1.7320508075688772}, "shimmer.modules.selection.FixedSharedSelection.forward": {"tf": 1.7320508075688772}, "shimmer.modules.selection.KQFixedQSelection.__init__": {"tf": 1}, "shimmer.modules.selection.KQFixedQSelection.forward": {"tf": 1.7320508075688772}, "shimmer.modules.selection.RandomSelection.forward": {"tf": 1.7320508075688772}, "shimmer.modules.selection.DynamicQueryAttention.__init__": {"tf": 1}, "shimmer.modules.selection.DynamicQueryAttention.fuse_weighted_encodings": {"tf": 1.4142135623730951}, "shimmer.modules.selection.DynamicQueryAttention.forward": {"tf": 1.7320508075688772}, "shimmer.modules.losses.GWLossesBase.step": {"tf": 1.4142135623730951}, "shimmer.modules.losses.demi_cycle_loss": {"tf": 2}, "shimmer.modules.losses.cycle_loss": {"tf": 2}, "shimmer.modules.losses.translation_loss": {"tf": 2}, "shimmer.modules.losses.contrastive_loss": {"tf": 1.7320508075688772}, "shimmer.modules.losses.contrastive_loss_bayesian": {"tf": 1.7320508075688772}, "shimmer.modules.losses.GWLosses2Domains.__init__": {"tf": 1}, "shimmer.modules.losses.GWLosses2Domains.demi_cycle_loss": {"tf": 1.7320508075688772}, "shimmer.modules.losses.GWLosses2Domains.cycle_loss": {"tf": 1.7320508075688772}, "shimmer.modules.losses.GWLosses2Domains.translation_loss": {"tf": 1.7320508075688772}, "shimmer.modules.losses.GWLosses2Domains.contrastive_loss": {"tf": 1.7320508075688772}, "shimmer.modules.losses.GWLosses2Domains.step": {"tf": 1.4142135623730951}, "shimmer.modules.losses.broadcast_loss": {"tf": 2}, "shimmer.modules.losses.GWLosses.__init__": {"tf": 1}, "shimmer.modules.losses.GWLosses.contrastive_loss": {"tf": 1.7320508075688772}, "shimmer.modules.losses.GWLosses.broadcast_loss": {"tf": 1.7320508075688772}, "shimmer.modules.losses.GWLosses.step": {"tf": 1.4142135623730951}, "shimmer.modules.losses.GWLossesBayesian.__init__": {"tf": 1}, "shimmer.modules.losses.GWLossesBayesian.contrastive_loss": {"tf": 1.7320508075688772}, 
"shimmer.modules.losses.GWLossesBayesian.broadcast_loss": {"tf": 1.7320508075688772}, "shimmer.modules.losses.GWLossesBayesian.step": {"tf": 1.4142135623730951}, "shimmer.modules.utils.translation": {"tf": 1.4142135623730951}, "shimmer.modules.utils.cycle": {"tf": 1.7320508075688772}, "shimmer.modules.utils.batch_demi_cycles": {"tf": 1.7320508075688772}, "shimmer.modules.utils.batch_cycles": {"tf": 2.23606797749979}, "shimmer.modules.utils.batch_translations": {"tf": 2}, "shimmer.utils.group_batch_size": {"tf": 1}, "shimmer.utils.groups_batch_size": {"tf": 1.4142135623730951}, "shimmer.utils.groups_device": {"tf": 1.4142135623730951}, "shimmer.utils.group_device": {"tf": 1}, "shimmer.utils.migrate_model": {"tf": 1}, "shimmer.utils.SaveMigrations.on_save_checkpoint": {"tf": 1}}, "df": 72}, "a": {"docs": {}, "df": 0, "t": {"docs": {}, "df": 0, "e": {"docs": {"shimmer.modules.selection.SelectionBase.update_gw_state": {"tf": 1}}, "df": 1}}}}, "h": {"docs": {}, "df": 0, "i": {"docs": {}, "df": 0, "m": {"docs": {}, "df": 0, "m": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "r": {"docs": {"shimmer.modules.global_workspace.GlobalWorkspaceBase.encode_and_fuse": {"tf": 1}, "shimmer.modules.global_workspace.freeze_domain_modules": {"tf": 1.4142135623730951}, "shimmer.modules.global_workspace.GlobalWorkspace2Domains.__init__": {"tf": 2}, "shimmer.modules.global_workspace.GlobalWorkspace2Domains.forward": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspace.__init__": {"tf": 2}, "shimmer.modules.global_workspace.GlobalWorkspace.forward": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBayesian.__init__": {"tf": 2}, "shimmer.modules.global_workspace.GlobalWorkspaceBayesian.forward": {"tf": 1}, "shimmer.modules.global_workspace.pretrained_global_workspace": {"tf": 2}, "shimmer.modules.domain.DomainModule.compute_loss": {"tf": 1}, "shimmer.modules.domain.DomainModule.compute_dcy_loss": {"tf": 1}, "shimmer.modules.domain.DomainModule.compute_cy_loss": {"tf": 
1}, "shimmer.modules.domain.DomainModule.compute_tr_loss": {"tf": 1}, "shimmer.modules.domain.DomainModule.compute_broadcast_loss": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBase.__init__": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBase.encode_and_fuse": {"tf": 1}, "shimmer.modules.gw_module.GWModule.__init__": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBayesian.__init__": {"tf": 1}, "shimmer.modules.losses.GWLossesBase.step": {"tf": 1}, "shimmer.modules.losses.demi_cycle_loss": {"tf": 1.7320508075688772}, "shimmer.modules.losses.cycle_loss": {"tf": 1.7320508075688772}, "shimmer.modules.losses.translation_loss": {"tf": 1.7320508075688772}, "shimmer.modules.losses.contrastive_loss": {"tf": 1.4142135623730951}, "shimmer.modules.losses.contrastive_loss_bayesian": {"tf": 1.4142135623730951}, "shimmer.modules.losses.GWLosses2Domains.__init__": {"tf": 2.23606797749979}, "shimmer.modules.losses.GWLosses2Domains.step": {"tf": 1}, "shimmer.modules.losses.broadcast_loss": {"tf": 1.7320508075688772}, "shimmer.modules.losses.GWLosses.__init__": {"tf": 2.23606797749979}, "shimmer.modules.losses.GWLosses.step": {"tf": 1}, "shimmer.modules.losses.GWLossesBayesian.__init__": {"tf": 2.23606797749979}, "shimmer.modules.losses.GWLossesBayesian.step": {"tf": 1}, "shimmer.modules.contrastive_loss.ContrastiveLoss.forward": {"tf": 1}, "shimmer.dataset.RepeatedDataset.__init__": {"tf": 1}, "shimmer.modules.vae.VAE.__init__": {"tf": 1.4142135623730951}, "shimmer.modules.utils.translation": {"tf": 1.4142135623730951}, "shimmer.modules.utils.cycle": {"tf": 1.4142135623730951}, "shimmer.modules.utils.batch_demi_cycles": {"tf": 1.4142135623730951}, "shimmer.modules.utils.batch_cycles": {"tf": 1.4142135623730951}, "shimmer.modules.utils.batch_translations": {"tf": 1.4142135623730951}}, "df": 39}}}}}}, "c": {"docs": {}, "df": 0, "o": {"docs": {}, "df": 0, "r": {"docs": {}, "df": 0, "e": {"docs": {"shimmer.modules.gw_module.compute_fusion_scores": {"tf": 1.4142135623730951}}, "df": 1, 
"s": {"docs": {"shimmer.modules.global_workspace.GlobalWorkspaceBase.fuse": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBase.fuse": {"tf": 1}, "shimmer.modules.gw_module.GWModule.fuse": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBayesian.fuse": {"tf": 1}}, "df": 4}}}}, "h": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "d": {"docs": {}, "df": 0, "u": {"docs": {}, "df": 0, "l": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "r": {"docs": {"shimmer.modules.global_workspace.GlobalWorkspace2Domains.__init__": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspace.__init__": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBayesian.__init__": {"tf": 1}}, "df": 3, "a": {"docs": {}, "df": 0, "r": {"docs": {}, "df": 0, "g": {"docs": {}, "df": 0, "s": {"docs": {"shimmer.modules.global_workspace.GlobalWorkspace2Domains.__init__": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspace.__init__": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBayesian.__init__": {"tf": 1}}, "df": 3}}}}}}}}}}}, "a": {"docs": {}, "df": 0, "l": {"docs": {}, "df": 0, "e": {"docs": {"shimmer.modules.global_workspace.GlobalWorkspace2Domains.__init__": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspace.__init__": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBayesian.__init__": {"tf": 1}, "shimmer.modules.contrastive_loss.info_nce": {"tf": 1}, "shimmer.modules.contrastive_loss.contrastive_loss": {"tf": 1}, "shimmer.modules.contrastive_loss.ContrastiveLoss.__init__": {"tf": 1.4142135623730951}}, "df": 6}}}}, "o": {"docs": {}, "df": 0, "f": {"docs": {}, "df": 0, "t": {"docs": {}, "df": 0, "m": {"docs": {}, "df": 0, "a": {"docs": {}, "df": 0, "x": {"docs": {"shimmer.modules.global_workspace.GlobalWorkspaceBayesian.__init__": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBayesian.__init__": {"tf": 1}}, "df": 2}}}}}}, "i": {"docs": {}, "df": 0, "z": {"docs": {}, "df": 0, "e": {"docs": 
{"shimmer.modules.selection.KQFixedQSelection.__init__": {"tf": 1}, "shimmer.modules.selection.DynamicQueryAttention.__init__": {"tf": 1}, "shimmer.dataset.RepeatedDataset.__init__": {"tf": 1}}, "df": 3, "d": {"docs": {}, "df": 0, "d": {"docs": {}, "df": 0, "a": {"docs": {}, "df": 0, "t": {"docs": {}, "df": 0, "a": {"docs": {}, "df": 0, "s": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "t": {"docs": {"shimmer.dataset.RepeatedDataset.__init__": {"tf": 1}}, "df": 1}}}}}}}}}}, "g": {"docs": {}, "df": 0, "m": {"docs": {}, "df": 0, "a": {"docs": {"shimmer.modules.vae.gaussian_nll": {"tf": 1}}, "df": 1}}}}, "u": {"docs": {}, "df": 0, "m": {"docs": {"shimmer.modules.contrastive_loss.info_nce": {"tf": 1}, "shimmer.modules.contrastive_loss.contrastive_loss": {"tf": 1}, "shimmer.modules.contrastive_loss.ContrastiveLoss.__init__": {"tf": 1}}, "df": 3}}}, "x": {"docs": {"shimmer.modules.global_workspace.GlobalWorkspaceBase.encode_and_fuse": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.encode": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.fuse": {"tf": 1}, "shimmer.modules.domain.DomainModule.encode": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBase.fuse": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBase.encode": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBase.encode_and_fuse": {"tf": 1}, "shimmer.modules.gw_module.GWModule.fuse": {"tf": 1}, "shimmer.modules.gw_module.GWModule.encode": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBayesian.get_precision": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBayesian.fuse": {"tf": 1}, "shimmer.modules.contrastive_loss.info_nce": {"tf": 1}, "shimmer.modules.contrastive_loss.contrastive_loss": {"tf": 1}, "shimmer.modules.contrastive_loss.ContrastiveLoss.forward": {"tf": 1}, "shimmer.modules.vae.gaussian_nll": {"tf": 1}, "shimmer.modules.vae.VAEEncoder.forward": {"tf": 1}, "shimmer.modules.vae.VAEDecoder.forward": {"tf": 1}, "shimmer.modules.vae.VAE.encode": {"tf": 1}, 
"shimmer.modules.vae.VAE.forward": {"tf": 1}, "shimmer.modules.utils.translation": {"tf": 1}, "shimmer.modules.utils.cycle": {"tf": 1}, "shimmer.utils.group_batch_size": {"tf": 1}, "shimmer.utils.group_device": {"tf": 1}}, "df": 23}, "c": {"docs": {}, "df": 0, "o": {"docs": {}, "df": 0, "l": {"docs": {}, "df": 0, "l": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "c": {"docs": {}, "df": 0, "t": {"docs": {}, "df": 0, "i": {"docs": {}, "df": 0, "o": {"docs": {}, "df": 0, "n": {"docs": {}, "df": 0, "s": {"docs": {"shimmer.modules.global_workspace.GlobalWorkspaceBase.encode_and_fuse": {"tf": 1.4142135623730951}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.encode": {"tf": 1.4142135623730951}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.fuse": {"tf": 2}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.decode": {"tf": 1.4142135623730951}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.batch_gw_states": {"tf": 1.4142135623730951}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.encode_domains": {"tf": 1.4142135623730951}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.decode_domains": {"tf": 1.4142135623730951}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.generic_step": {"tf": 1.4142135623730951}, "shimmer.modules.global_workspace.freeze_domain_modules": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspace2Domains.__init__": {"tf": 2}, "shimmer.modules.global_workspace.GlobalWorkspace2Domains.forward": {"tf": 1.4142135623730951}, "shimmer.modules.global_workspace.GlobalWorkspace.__init__": {"tf": 2}, "shimmer.modules.global_workspace.GlobalWorkspace.forward": {"tf": 1.4142135623730951}, "shimmer.modules.global_workspace.GlobalWorkspaceBayesian.__init__": {"tf": 2}, "shimmer.modules.global_workspace.GlobalWorkspaceBayesian.forward": {"tf": 1.4142135623730951}, "shimmer.modules.global_workspace.pretrained_global_workspace": {"tf": 2}, "shimmer.modules.gw_module.GWModuleBase.__init__": {"tf": 1}, 
"shimmer.modules.gw_module.GWModuleBase.fuse": {"tf": 1.4142135623730951}, "shimmer.modules.gw_module.GWModuleBase.encode": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBase.encode_and_fuse": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBase.decode": {"tf": 1}, "shimmer.modules.gw_module.GWModule.__init__": {"tf": 1.7320508075688772}, "shimmer.modules.gw_module.GWModule.fuse": {"tf": 1.4142135623730951}, "shimmer.modules.gw_module.GWModule.encode": {"tf": 1}, "shimmer.modules.gw_module.GWModule.decode": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBayesian.__init__": {"tf": 1.7320508075688772}, "shimmer.modules.gw_module.GWModuleBayesian.fuse": {"tf": 1.4142135623730951}, "shimmer.modules.selection.SelectionBase.forward": {"tf": 1.4142135623730951}, "shimmer.modules.selection.SingleDomainSelection.forward": {"tf": 1.4142135623730951}, "shimmer.modules.selection.FixedSharedSelection.forward": {"tf": 1.4142135623730951}, "shimmer.modules.selection.KQFixedQSelection.__init__": {"tf": 1}, "shimmer.modules.selection.KQFixedQSelection.forward": {"tf": 1.4142135623730951}, "shimmer.modules.selection.RandomSelection.forward": {"tf": 1.4142135623730951}, "shimmer.modules.selection.DynamicQueryAttention.__init__": {"tf": 1}, "shimmer.modules.selection.DynamicQueryAttention.fuse_weighted_encodings": {"tf": 1}, "shimmer.modules.selection.DynamicQueryAttention.forward": {"tf": 1.4142135623730951}, "shimmer.modules.losses.GWLossesBase.step": {"tf": 1.4142135623730951}, "shimmer.modules.losses.demi_cycle_loss": {"tf": 1.7320508075688772}, "shimmer.modules.losses.cycle_loss": {"tf": 1.7320508075688772}, "shimmer.modules.losses.translation_loss": {"tf": 1.7320508075688772}, "shimmer.modules.losses.contrastive_loss": {"tf": 1.7320508075688772}, "shimmer.modules.losses.contrastive_loss_bayesian": {"tf": 1.7320508075688772}, "shimmer.modules.losses.GWLosses2Domains.__init__": {"tf": 1}, "shimmer.modules.losses.GWLosses2Domains.demi_cycle_loss": {"tf": 1.4142135623730951}, 
"shimmer.modules.losses.GWLosses2Domains.cycle_loss": {"tf": 1.4142135623730951}, "shimmer.modules.losses.GWLosses2Domains.translation_loss": {"tf": 1.4142135623730951}, "shimmer.modules.losses.GWLosses2Domains.contrastive_loss": {"tf": 1.4142135623730951}, "shimmer.modules.losses.GWLosses2Domains.step": {"tf": 1.4142135623730951}, "shimmer.modules.losses.generate_partitions": {"tf": 1}, "shimmer.modules.losses.broadcast_loss": {"tf": 1.7320508075688772}, "shimmer.modules.losses.GWLosses.__init__": {"tf": 1}, "shimmer.modules.losses.GWLosses.contrastive_loss": {"tf": 1.4142135623730951}, "shimmer.modules.losses.GWLosses.broadcast_loss": {"tf": 1.4142135623730951}, "shimmer.modules.losses.GWLosses.step": {"tf": 1.4142135623730951}, "shimmer.modules.losses.GWLossesBayesian.__init__": {"tf": 1}, "shimmer.modules.losses.GWLossesBayesian.contrastive_loss": {"tf": 1.4142135623730951}, "shimmer.modules.losses.GWLossesBayesian.broadcast_loss": {"tf": 1.4142135623730951}, "shimmer.modules.losses.GWLossesBayesian.step": {"tf": 1.4142135623730951}, "shimmer.modules.utils.translation": {"tf": 1}, "shimmer.modules.utils.cycle": {"tf": 1}, "shimmer.modules.utils.batch_demi_cycles": {"tf": 1.4142135623730951}, "shimmer.modules.utils.batch_cycles": {"tf": 1.7320508075688772}, "shimmer.modules.utils.batch_translations": {"tf": 1.4142135623730951}, "shimmer.utils.group_batch_size": {"tf": 1}, "shimmer.utils.groups_batch_size": {"tf": 1.4142135623730951}, "shimmer.utils.groups_device": {"tf": 1.4142135623730951}, "shimmer.utils.group_device": {"tf": 1}}, "df": 67}}}}}}}}}, "e": {"docs": {}, "df": 0, "f": {"docs": {}, "df": 0, "s": {"docs": {"shimmer.modules.global_workspace.GlobalWorkspace2Domains.__init__": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspace.__init__": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBayesian.__init__": {"tf": 1}, "shimmer.modules.global_workspace.pretrained_global_workspace": {"tf": 1}, 
"shimmer.modules.losses.GWLosses2Domains.__init__": {"tf": 1}, "shimmer.modules.losses.GWLosses.__init__": {"tf": 1}, "shimmer.modules.losses.GWLossesBayesian.__init__": {"tf": 1}}, "df": 7}}}, "n": {"docs": {}, "df": 0, "t": {"docs": {}, "df": 0, "r": {"docs": {}, "df": 0, "a": {"docs": {}, "df": 0, "s": {"docs": {}, "df": 0, "t": {"docs": {}, "df": 0, "i": {"docs": {}, "df": 0, "v": {"docs": {}, "df": 0, "e": {"docs": {"shimmer.modules.global_workspace.GlobalWorkspace2Domains.__init__": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspace.__init__": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBayesian.__init__": {"tf": 1}, "shimmer.modules.global_workspace.pretrained_global_workspace": {"tf": 1}, "shimmer.modules.losses.contrastive_loss": {"tf": 1}, "shimmer.modules.losses.contrastive_loss_bayesian": {"tf": 1}, "shimmer.modules.losses.GWLosses2Domains.__init__": {"tf": 1}, "shimmer.modules.losses.GWLosses.__init__": {"tf": 1}, "shimmer.modules.losses.GWLossesBayesian.__init__": {"tf": 1}}, "df": 9}}}}}}}}, "s": {"docs": {}, "df": 0, "t": {"docs": {}, "df": 0, "r": {"docs": {}, "df": 0, "a": {"docs": {}, "df": 0, "s": {"docs": {}, "df": 0, "t": {"docs": {}, "df": 0, "i": {"docs": {}, "df": 0, "v": {"docs": {}, "df": 0, "e": {"docs": {"shimmer.modules.global_workspace.GlobalWorkspaceBayesian.__init__": {"tf": 1}, "shimmer.modules.losses.GWLossesBayesian.__init__": {"tf": 1}}, "df": 2}}}}}}}}}}, "r": {"docs": {}, "df": 0, "e": {"docs": {"shimmer.utils.SaveMigrations.on_save_checkpoint": {"tf": 1}}, "df": 1}}}, "a": {"docs": {}, "df": 0, "l": {"docs": {}, "df": 0, "l": {"docs": {}, "df": 0, "a": {"docs": {}, "df": 0, "b": {"docs": {}, "df": 0, "l": {"docs": {}, "df": 0, "e": {"docs": {"shimmer.modules.global_workspace.GlobalWorkspace2Domains.__init__": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspace.__init__": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBayesian.__init__": {"tf": 1}, 
"shimmer.modules.global_workspace.pretrained_global_workspace": {"tf": 1}, "shimmer.modules.losses.contrastive_loss": {"tf": 1}, "shimmer.modules.losses.contrastive_loss_bayesian": {"tf": 1}, "shimmer.modules.losses.GWLosses2Domains.__init__": {"tf": 1}, "shimmer.modules.losses.GWLosses.__init__": {"tf": 1}, "shimmer.modules.losses.GWLossesBayesian.__init__": {"tf": 1}}, "df": 9}}}}}}}, "h": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "c": {"docs": {}, "df": 0, "k": {"docs": {}, "df": 0, "p": {"docs": {}, "df": 0, "o": {"docs": {}, "df": 0, "i": {"docs": {}, "df": 0, "n": {"docs": {}, "df": 0, "t": {"docs": {"shimmer.modules.global_workspace.pretrained_global_workspace": {"tf": 1}, "shimmer.utils.SaveMigrations.on_save_checkpoint": {"tf": 1}}, "df": 2}}}}}}}}}, "k": {"docs": {}, "df": 0, "p": {"docs": {}, "df": 0, "t": {"docs": {"shimmer.utils.migrate_model": {"tf": 1}}, "df": 1}}}}, "a": {"docs": {}, "df": 0, "b": {"docs": {}, "df": 0, "c": {"docs": {"shimmer.modules.global_workspace.GlobalWorkspaceBase.encode_and_fuse": {"tf": 1.4142135623730951}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.encode": {"tf": 1.4142135623730951}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.fuse": {"tf": 2}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.decode": {"tf": 1.4142135623730951}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.batch_gw_states": {"tf": 1.4142135623730951}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.encode_domains": {"tf": 1.4142135623730951}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.decode_domains": {"tf": 1.4142135623730951}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.generic_step": {"tf": 1.4142135623730951}, "shimmer.modules.global_workspace.freeze_domain_modules": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspace2Domains.__init__": {"tf": 2}, "shimmer.modules.global_workspace.GlobalWorkspace2Domains.forward": {"tf": 1.4142135623730951}, 
"shimmer.modules.global_workspace.GlobalWorkspace.__init__": {"tf": 2}, "shimmer.modules.global_workspace.GlobalWorkspace.forward": {"tf": 1.4142135623730951}, "shimmer.modules.global_workspace.GlobalWorkspaceBayesian.__init__": {"tf": 2}, "shimmer.modules.global_workspace.GlobalWorkspaceBayesian.forward": {"tf": 1.4142135623730951}, "shimmer.modules.global_workspace.pretrained_global_workspace": {"tf": 2}, "shimmer.modules.gw_module.GWModuleBase.__init__": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBase.fuse": {"tf": 1.4142135623730951}, "shimmer.modules.gw_module.GWModuleBase.encode": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBase.encode_and_fuse": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBase.decode": {"tf": 1}, "shimmer.modules.gw_module.GWModule.__init__": {"tf": 1.7320508075688772}, "shimmer.modules.gw_module.GWModule.fuse": {"tf": 1.4142135623730951}, "shimmer.modules.gw_module.GWModule.encode": {"tf": 1}, "shimmer.modules.gw_module.GWModule.decode": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBayesian.__init__": {"tf": 1.7320508075688772}, "shimmer.modules.gw_module.GWModuleBayesian.fuse": {"tf": 1.4142135623730951}, "shimmer.modules.selection.SelectionBase.forward": {"tf": 1.4142135623730951}, "shimmer.modules.selection.SingleDomainSelection.forward": {"tf": 1.4142135623730951}, "shimmer.modules.selection.FixedSharedSelection.forward": {"tf": 1.4142135623730951}, "shimmer.modules.selection.KQFixedQSelection.__init__": {"tf": 1}, "shimmer.modules.selection.KQFixedQSelection.forward": {"tf": 1.4142135623730951}, "shimmer.modules.selection.RandomSelection.forward": {"tf": 1.4142135623730951}, "shimmer.modules.selection.DynamicQueryAttention.__init__": {"tf": 1}, "shimmer.modules.selection.DynamicQueryAttention.fuse_weighted_encodings": {"tf": 1}, "shimmer.modules.selection.DynamicQueryAttention.forward": {"tf": 1.4142135623730951}, "shimmer.modules.losses.GWLossesBase.step": {"tf": 1.4142135623730951}, "shimmer.modules.losses.demi_cycle_loss": 
{"tf": 1.7320508075688772}, "shimmer.modules.losses.cycle_loss": {"tf": 1.7320508075688772}, "shimmer.modules.losses.translation_loss": {"tf": 1.7320508075688772}, "shimmer.modules.losses.contrastive_loss": {"tf": 1.7320508075688772}, "shimmer.modules.losses.contrastive_loss_bayesian": {"tf": 1.7320508075688772}, "shimmer.modules.losses.GWLosses2Domains.__init__": {"tf": 1}, "shimmer.modules.losses.GWLosses2Domains.demi_cycle_loss": {"tf": 1.4142135623730951}, "shimmer.modules.losses.GWLosses2Domains.cycle_loss": {"tf": 1.4142135623730951}, "shimmer.modules.losses.GWLosses2Domains.translation_loss": {"tf": 1.4142135623730951}, "shimmer.modules.losses.GWLosses2Domains.contrastive_loss": {"tf": 1.4142135623730951}, "shimmer.modules.losses.GWLosses2Domains.step": {"tf": 1.4142135623730951}, "shimmer.modules.losses.generate_partitions": {"tf": 1}, "shimmer.modules.losses.broadcast_loss": {"tf": 1.7320508075688772}, "shimmer.modules.losses.GWLosses.__init__": {"tf": 1}, "shimmer.modules.losses.GWLosses.contrastive_loss": {"tf": 1.4142135623730951}, "shimmer.modules.losses.GWLosses.broadcast_loss": {"tf": 1.4142135623730951}, "shimmer.modules.losses.GWLosses.step": {"tf": 1.4142135623730951}, "shimmer.modules.losses.GWLossesBayesian.__init__": {"tf": 1}, "shimmer.modules.losses.GWLossesBayesian.contrastive_loss": {"tf": 1.4142135623730951}, "shimmer.modules.losses.GWLossesBayesian.broadcast_loss": {"tf": 1.4142135623730951}, "shimmer.modules.losses.GWLossesBayesian.step": {"tf": 1.4142135623730951}, "shimmer.modules.utils.translation": {"tf": 1}, "shimmer.modules.utils.cycle": {"tf": 1}, "shimmer.modules.utils.batch_demi_cycles": {"tf": 1.4142135623730951}, "shimmer.modules.utils.batch_cycles": {"tf": 1.7320508075688772}, "shimmer.modules.utils.batch_translations": {"tf": 1.4142135623730951}, "shimmer.utils.group_batch_size": {"tf": 1}, "shimmer.utils.groups_batch_size": {"tf": 1.4142135623730951}, "shimmer.utils.groups_device": {"tf": 1.4142135623730951}, 
"shimmer.utils.group_device": {"tf": 1}}, "df": 67}}, "n": {"docs": {}, "df": 0, "y": {"docs": {"shimmer.modules.global_workspace.GlobalWorkspaceBase.encode_domain": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.encode_domains": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.decode_domain": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.decode_domains": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.generic_step": {"tf": 1}, "shimmer.modules.domain.DomainModule.encode": {"tf": 1}, "shimmer.modules.domain.DomainModule.decode": {"tf": 1}, "shimmer.modules.vae.VAEEncoder.forward": {"tf": 1}, "shimmer.modules.vae.VAEDecoder.forward": {"tf": 1}, "shimmer.modules.vae.VAE.encode": {"tf": 1}, "shimmer.modules.vae.VAE.decode": {"tf": 1}, "shimmer.modules.vae.VAE.forward": {"tf": 1.4142135623730951}, "shimmer.utils.SaveMigrations.on_save_checkpoint": {"tf": 1}}, "df": 13}}, "r": {"docs": {}, "df": 0, "g": {"docs": {}, "df": 0, "s": {"docs": {"shimmer.modules.global_workspace.GlobalWorkspace2Domains.__init__": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspace.__init__": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBayesian.__init__": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBase.__init__": {"tf": 1}}, "df": 4}}}, "t": {"docs": {}, "df": 0, "t": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "n": {"docs": {}, "df": 0, "t": {"docs": {}, "df": 0, "i": {"docs": {}, "df": 0, "o": {"docs": {}, "df": 0, "n": {"docs": {"shimmer.modules.selection.DynamicQueryAttention.fuse_weighted_encodings": {"tf": 1}}, "df": 1}}}}}}}}}, "m": {"docs": {}, "df": 0, "a": {"docs": {}, "df": 0, "p": {"docs": {}, "df": 0, "p": {"docs": {}, "df": 0, "i": {"docs": {}, "df": 0, "n": {"docs": {}, "df": 0, "g": {"docs": {"shimmer.modules.global_workspace.GlobalWorkspaceBase.encode_and_fuse": {"tf": 1.4142135623730951}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.encode": {"tf": 
1.4142135623730951}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.fuse": {"tf": 2}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.decode": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.batch_gw_states": {"tf": 1.4142135623730951}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.encode_domains": {"tf": 1.4142135623730951}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.decode_domains": {"tf": 1.4142135623730951}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.generic_step": {"tf": 1.4142135623730951}, "shimmer.modules.global_workspace.freeze_domain_modules": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspace2Domains.__init__": {"tf": 1.7320508075688772}, "shimmer.modules.global_workspace.GlobalWorkspace2Domains.forward": {"tf": 1.4142135623730951}, "shimmer.modules.global_workspace.GlobalWorkspace.__init__": {"tf": 1.7320508075688772}, "shimmer.modules.global_workspace.GlobalWorkspace.forward": {"tf": 1.4142135623730951}, "shimmer.modules.global_workspace.GlobalWorkspaceBayesian.__init__": {"tf": 1.7320508075688772}, "shimmer.modules.global_workspace.GlobalWorkspaceBayesian.forward": {"tf": 1.4142135623730951}, "shimmer.modules.global_workspace.pretrained_global_workspace": {"tf": 1.7320508075688772}, "shimmer.modules.gw_module.GWModuleBase.__init__": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBase.fuse": {"tf": 1.4142135623730951}, "shimmer.modules.gw_module.GWModuleBase.encode": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBase.encode_and_fuse": {"tf": 1}, "shimmer.modules.gw_module.GWModule.__init__": {"tf": 1.7320508075688772}, "shimmer.modules.gw_module.GWModule.fuse": {"tf": 1.4142135623730951}, "shimmer.modules.gw_module.GWModule.encode": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBayesian.__init__": {"tf": 1.7320508075688772}, "shimmer.modules.gw_module.GWModuleBayesian.fuse": {"tf": 1.4142135623730951}, "shimmer.modules.selection.SelectionBase.forward": {"tf": 1.4142135623730951}, 
"shimmer.modules.selection.SingleDomainSelection.forward": {"tf": 1.4142135623730951}, "shimmer.modules.selection.FixedSharedSelection.forward": {"tf": 1.4142135623730951}, "shimmer.modules.selection.KQFixedQSelection.forward": {"tf": 1.4142135623730951}, "shimmer.modules.selection.RandomSelection.forward": {"tf": 1.4142135623730951}, "shimmer.modules.selection.DynamicQueryAttention.fuse_weighted_encodings": {"tf": 1}, "shimmer.modules.selection.DynamicQueryAttention.forward": {"tf": 1.4142135623730951}, "shimmer.modules.losses.GWLossesBase.step": {"tf": 1.4142135623730951}, "shimmer.modules.losses.demi_cycle_loss": {"tf": 1.7320508075688772}, "shimmer.modules.losses.cycle_loss": {"tf": 1.7320508075688772}, "shimmer.modules.losses.translation_loss": {"tf": 1.7320508075688772}, "shimmer.modules.losses.contrastive_loss": {"tf": 1.4142135623730951}, "shimmer.modules.losses.contrastive_loss_bayesian": {"tf": 1.4142135623730951}, "shimmer.modules.losses.GWLosses2Domains.demi_cycle_loss": {"tf": 1.4142135623730951}, "shimmer.modules.losses.GWLosses2Domains.cycle_loss": {"tf": 1.4142135623730951}, "shimmer.modules.losses.GWLosses2Domains.translation_loss": {"tf": 1.4142135623730951}, "shimmer.modules.losses.GWLosses2Domains.contrastive_loss": {"tf": 1.4142135623730951}, "shimmer.modules.losses.GWLosses2Domains.step": {"tf": 1.4142135623730951}, "shimmer.modules.losses.broadcast_loss": {"tf": 1.7320508075688772}, "shimmer.modules.losses.GWLosses.contrastive_loss": {"tf": 1.4142135623730951}, "shimmer.modules.losses.GWLosses.broadcast_loss": {"tf": 1.4142135623730951}, "shimmer.modules.losses.GWLosses.step": {"tf": 1.4142135623730951}, "shimmer.modules.losses.GWLossesBayesian.contrastive_loss": {"tf": 1.4142135623730951}, "shimmer.modules.losses.GWLossesBayesian.broadcast_loss": {"tf": 1.4142135623730951}, "shimmer.modules.losses.GWLossesBayesian.step": {"tf": 1.4142135623730951}, "shimmer.modules.utils.translation": {"tf": 1}, "shimmer.modules.utils.cycle": {"tf": 1}, 
"shimmer.modules.utils.batch_demi_cycles": {"tf": 1.4142135623730951}, "shimmer.modules.utils.batch_cycles": {"tf": 1.4142135623730951}, "shimmer.modules.utils.batch_translations": {"tf": 1.4142135623730951}, "shimmer.utils.group_batch_size": {"tf": 1}, "shimmer.utils.groups_batch_size": {"tf": 1.4142135623730951}, "shimmer.utils.groups_device": {"tf": 1.4142135623730951}, "shimmer.utils.group_device": {"tf": 1}}, "df": 59}}}}}}, "o": {"docs": {}, "df": 0, "d": {"docs": {"shimmer.modules.losses.demi_cycle_loss": {"tf": 1.4142135623730951}, "shimmer.modules.losses.cycle_loss": {"tf": 1.4142135623730951}, "shimmer.modules.losses.translation_loss": {"tf": 1.4142135623730951}, "shimmer.modules.losses.contrastive_loss": {"tf": 1}, "shimmer.modules.losses.contrastive_loss_bayesian": {"tf": 1}, "shimmer.modules.losses.GWLosses2Domains.__init__": {"tf": 1.4142135623730951}, "shimmer.modules.losses.broadcast_loss": {"tf": 1.4142135623730951}, "shimmer.modules.losses.GWLosses.__init__": {"tf": 1.4142135623730951}, "shimmer.modules.losses.GWLossesBayesian.__init__": {"tf": 1.4142135623730951}, "shimmer.modules.utils.translation": {"tf": 1}, "shimmer.modules.utils.cycle": {"tf": 1}, "shimmer.modules.utils.batch_demi_cycles": {"tf": 1.4142135623730951}, "shimmer.modules.utils.batch_cycles": {"tf": 1.4142135623730951}, "shimmer.modules.utils.batch_translations": {"tf": 1.4142135623730951}}, "df": 14, "u": {"docs": {}, "df": 0, "l": {"docs": {}, "df": 0, "e": {"docs": {"shimmer.modules.global_workspace.GlobalWorkspaceBase.encode_and_fuse": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspace2Domains.__init__": {"tf": 2}, "shimmer.modules.global_workspace.GlobalWorkspace.__init__": {"tf": 2}, "shimmer.modules.global_workspace.GlobalWorkspaceBayesian.__init__": {"tf": 2}, "shimmer.modules.global_workspace.pretrained_global_workspace": {"tf": 2}, "shimmer.modules.gw_module.get_n_layers": {"tf": 1.4142135623730951}, "shimmer.modules.gw_module.GWModuleBase.encode_and_fuse": 
{"tf": 1}, "shimmer.modules.gw_module.GWModule.__init__": {"tf": 2}, "shimmer.modules.gw_module.GWModuleBayesian.__init__": {"tf": 2}, "shimmer.modules.losses.demi_cycle_loss": {"tf": 1}, "shimmer.modules.losses.cycle_loss": {"tf": 1}, "shimmer.modules.losses.translation_loss": {"tf": 1}, "shimmer.modules.losses.contrastive_loss": {"tf": 1}, "shimmer.modules.losses.contrastive_loss_bayesian": {"tf": 1}, "shimmer.modules.losses.GWLosses2Domains.__init__": {"tf": 1}, "shimmer.modules.losses.broadcast_loss": {"tf": 1}, "shimmer.modules.losses.GWLosses.__init__": {"tf": 1}, "shimmer.modules.losses.GWLossesBayesian.__init__": {"tf": 1}, "shimmer.modules.utils.translation": {"tf": 1.4142135623730951}, "shimmer.modules.utils.cycle": {"tf": 1.4142135623730951}, "shimmer.modules.utils.batch_demi_cycles": {"tf": 1}, "shimmer.modules.utils.batch_cycles": {"tf": 1}, "shimmer.modules.utils.batch_translations": {"tf": 1}, "shimmer.utils.SaveMigrations.on_save_checkpoint": {"tf": 1.4142135623730951}}, "df": 24, "s": {"docs": {"shimmer.modules.global_workspace.GlobalWorkspaceBase.encode_and_fuse": {"tf": 1}, "shimmer.modules.global_workspace.freeze_domain_modules": {"tf": 1.4142135623730951}, "shimmer.modules.global_workspace.GlobalWorkspace2Domains.__init__": {"tf": 2.449489742783178}, "shimmer.modules.global_workspace.GlobalWorkspace2Domains.forward": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspace.__init__": {"tf": 2.449489742783178}, "shimmer.modules.global_workspace.GlobalWorkspace.forward": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBayesian.__init__": {"tf": 2.449489742783178}, "shimmer.modules.global_workspace.GlobalWorkspaceBayesian.forward": {"tf": 1}, "shimmer.modules.global_workspace.pretrained_global_workspace": {"tf": 2.449489742783178}, "shimmer.modules.domain.DomainModule.compute_loss": {"tf": 1}, "shimmer.modules.domain.DomainModule.compute_dcy_loss": {"tf": 1}, "shimmer.modules.domain.DomainModule.compute_cy_loss": {"tf": 1}, 
"shimmer.modules.domain.DomainModule.compute_tr_loss": {"tf": 1}, "shimmer.modules.domain.DomainModule.compute_broadcast_loss": {"tf": 1}, "shimmer.modules.gw_module.get_n_layers": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBase.__init__": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBase.encode_and_fuse": {"tf": 1}, "shimmer.modules.gw_module.GWModule.__init__": {"tf": 2}, "shimmer.modules.gw_module.GWModuleBayesian.__init__": {"tf": 2}, "shimmer.modules.losses.GWLossesBase.step": {"tf": 1}, "shimmer.modules.losses.demi_cycle_loss": {"tf": 1.7320508075688772}, "shimmer.modules.losses.cycle_loss": {"tf": 1.7320508075688772}, "shimmer.modules.losses.translation_loss": {"tf": 1.7320508075688772}, "shimmer.modules.losses.contrastive_loss": {"tf": 1.4142135623730951}, "shimmer.modules.losses.contrastive_loss_bayesian": {"tf": 1.4142135623730951}, "shimmer.modules.losses.GWLosses2Domains.__init__": {"tf": 2.23606797749979}, "shimmer.modules.losses.GWLosses2Domains.step": {"tf": 1}, "shimmer.modules.losses.broadcast_loss": {"tf": 1.7320508075688772}, "shimmer.modules.losses.GWLosses.__init__": {"tf": 2.23606797749979}, "shimmer.modules.losses.GWLosses.step": {"tf": 1}, "shimmer.modules.losses.GWLossesBayesian.__init__": {"tf": 2.23606797749979}, "shimmer.modules.losses.GWLossesBayesian.step": {"tf": 1}, "shimmer.modules.contrastive_loss.ContrastiveLoss.forward": {"tf": 1}, "shimmer.modules.vae.VAE.__init__": {"tf": 1.4142135623730951}, "shimmer.modules.utils.translation": {"tf": 1.4142135623730951}, "shimmer.modules.utils.cycle": {"tf": 1.4142135623730951}, "shimmer.modules.utils.batch_demi_cycles": {"tf": 1.4142135623730951}, "shimmer.modules.utils.batch_cycles": {"tf": 1.4142135623730951}, "shimmer.modules.utils.batch_translations": {"tf": 1.4142135623730951}}, "df": 39}}}}, "e": {"docs": {"shimmer.modules.global_workspace.GlobalWorkspaceBase.generic_step": {"tf": 1}, "shimmer.modules.losses.GWLossesBase.step": {"tf": 1}, 
"shimmer.modules.losses.GWLosses2Domains.step": {"tf": 1}, "shimmer.modules.losses.GWLosses.step": {"tf": 1}, "shimmer.modules.losses.GWLossesBayesian.step": {"tf": 1}}, "df": 5}, "s": {"docs": {"shimmer.modules.global_workspace.freeze_domain_modules": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspace2Domains.__init__": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspace.__init__": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBayesian.__init__": {"tf": 1}, "shimmer.modules.global_workspace.pretrained_global_workspace": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBase.__init__": {"tf": 1}, "shimmer.modules.losses.demi_cycle_loss": {"tf": 1}, "shimmer.modules.losses.cycle_loss": {"tf": 1}, "shimmer.modules.losses.translation_loss": {"tf": 1}, "shimmer.modules.losses.GWLosses2Domains.__init__": {"tf": 1}, "shimmer.modules.losses.broadcast_loss": {"tf": 1}, "shimmer.modules.losses.GWLosses.__init__": {"tf": 1}, "shimmer.modules.losses.GWLossesBayesian.__init__": {"tf": 1}}, "df": 13}}}, "e": {"docs": {}, "df": 0, "t": {"docs": {}, "df": 0, "r": {"docs": {}, "df": 0, "i": {"docs": {}, "df": 0, "c": {"docs": {}, "df": 0, "s": {"docs": {"shimmer.modules.domain.LossOutput.__init__": {"tf": 1}}, "df": 1}}}}}, "a": {"docs": {}, "df": 0, "n": {"docs": {"shimmer.modules.contrastive_loss.info_nce": {"tf": 1.4142135623730951}, "shimmer.modules.contrastive_loss.contrastive_loss": {"tf": 1.4142135623730951}, "shimmer.modules.contrastive_loss.ContrastiveLoss.__init__": {"tf": 1.4142135623730951}, "shimmer.modules.vae.reparameterize": {"tf": 1}, "shimmer.modules.vae.kl_divergence_loss": {"tf": 1}}, "df": 5}}}, "i": {"docs": {}, "df": 0, "n": {"docs": {"shimmer.dataset.RepeatedDataset.__init__": {"tf": 1}}, "df": 1}}, "u": {"docs": {"shimmer.modules.vae.gaussian_nll": {"tf": 1}}, "df": 1}}, "f": {"docs": {}, "df": 0, "r": {"docs": {}, "df": 0, "o": {"docs": {}, "df": 0, "z": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "n": {"docs": {}, "df": 
0, "s": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "t": {"docs": {"shimmer.modules.global_workspace.GlobalWorkspaceBase.encode_and_fuse": {"tf": 1.4142135623730951}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.encode": {"tf": 1.4142135623730951}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.fuse": {"tf": 1.7320508075688772}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.decode": {"tf": 1.4142135623730951}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.batch_gw_states": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.encode_domains": {"tf": 1.4142135623730951}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.decode_domains": {"tf": 1.4142135623730951}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.generic_step": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspace2Domains.forward": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspace.forward": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBayesian.forward": {"tf": 1}, "shimmer.modules.losses.GWLossesBase.step": {"tf": 1}, "shimmer.modules.losses.demi_cycle_loss": {"tf": 1}, "shimmer.modules.losses.cycle_loss": {"tf": 1}, "shimmer.modules.losses.translation_loss": {"tf": 1}, "shimmer.modules.losses.contrastive_loss": {"tf": 1}, "shimmer.modules.losses.contrastive_loss_bayesian": {"tf": 1}, "shimmer.modules.losses.GWLosses2Domains.demi_cycle_loss": {"tf": 1}, "shimmer.modules.losses.GWLosses2Domains.cycle_loss": {"tf": 1}, "shimmer.modules.losses.GWLosses2Domains.translation_loss": {"tf": 1}, "shimmer.modules.losses.GWLosses2Domains.contrastive_loss": {"tf": 1}, "shimmer.modules.losses.GWLosses2Domains.step": {"tf": 1}, "shimmer.modules.losses.broadcast_loss": {"tf": 1}, "shimmer.modules.losses.GWLosses.contrastive_loss": {"tf": 1}, "shimmer.modules.losses.GWLosses.broadcast_loss": {"tf": 1}, "shimmer.modules.losses.GWLosses.step": {"tf": 1}, "shimmer.modules.losses.GWLossesBayesian.contrastive_loss": {"tf": 1}, 
"shimmer.modules.losses.GWLossesBayesian.broadcast_loss": {"tf": 1}, "shimmer.modules.losses.GWLossesBayesian.step": {"tf": 1}, "shimmer.modules.utils.batch_demi_cycles": {"tf": 1}, "shimmer.modules.utils.batch_cycles": {"tf": 1}, "shimmer.modules.utils.batch_translations": {"tf": 1}, "shimmer.utils.groups_batch_size": {"tf": 1}, "shimmer.utils.groups_device": {"tf": 1}}, "df": 34}}}}}}}}, "l": {"docs": {}, "df": 0, "o": {"docs": {}, "df": 0, "a": {"docs": {}, "df": 0, "t": {"docs": {"shimmer.modules.global_workspace.GlobalWorkspace2Domains.__init__": {"tf": 1.4142135623730951}, "shimmer.modules.global_workspace.GlobalWorkspace.__init__": {"tf": 1.7320508075688772}, "shimmer.modules.global_workspace.GlobalWorkspaceBayesian.__init__": {"tf": 2.23606797749979}, "shimmer.modules.gw_module.compute_fusion_scores": {"tf": 1.7320508075688772}, "shimmer.modules.gw_module.GWModuleBayesian.__init__": {"tf": 1.7320508075688772}, "shimmer.modules.selection.RandomSelection.__init__": {"tf": 1}, "shimmer.modules.vae.VAE.__init__": {"tf": 1}}, "df": 7}}}}, "a": {"docs": {}, "df": 0, "l": {"docs": {}, "df": 0, "s": {"docs": {}, "df": 0, "e": {"docs": {"shimmer.modules.global_workspace.GlobalWorkspace2Domains.__init__": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspace.__init__": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBayesian.__init__": {"tf": 1}, "shimmer.modules.contrastive_loss.ContrastiveLoss.__init__": {"tf": 1}, "shimmer.dataset.RepeatedDataset.__init__": {"tf": 1}}, "df": 5}}}, "c": {"docs": {}, "df": 0, "t": {"docs": {}, "df": 0, "o": {"docs": {}, "df": 0, "r": {"docs": {}, "df": 0, "y": {"docs": {"shimmer.modules.domain.LossOutput.__init__": {"tf": 1}}, "df": 1}}}}}}, "n": {"docs": {"shimmer.modules.global_workspace.pretrained_global_workspace": {"tf": 1}, "shimmer.modules.losses.contrastive_loss": {"tf": 1}, "shimmer.modules.losses.contrastive_loss_bayesian": {"tf": 1}, "shimmer.modules.losses.GWLosses2Domains.__init__": {"tf": 1}, 
"shimmer.modules.losses.GWLosses.__init__": {"tf": 1}, "shimmer.modules.losses.GWLossesBayesian.__init__": {"tf": 1}}, "df": 6}, "u": {"docs": {}, "df": 0, "s": {"docs": {}, "df": 0, "i": {"docs": {}, "df": 0, "o": {"docs": {}, "df": 0, "n": {"docs": {"shimmer.modules.selection.SelectionBase.forward": {"tf": 1}, "shimmer.modules.selection.SingleDomainSelection.forward": {"tf": 1}, "shimmer.modules.selection.FixedSharedSelection.forward": {"tf": 1}, "shimmer.modules.selection.KQFixedQSelection.forward": {"tf": 1}, "shimmer.modules.selection.RandomSelection.forward": {"tf": 1}, "shimmer.modules.selection.DynamicQueryAttention.forward": {"tf": 1}}, "df": 6}}}}}}, "t": {"docs": {}, "df": 0, "o": {"docs": {"shimmer.modules.utils.translation": {"tf": 1}}, "df": 1, "r": {"docs": {}, "df": 0, "c": {"docs": {}, "df": 0, "h": {"docs": {"shimmer.modules.global_workspace.GlobalWorkspaceBase.encode_and_fuse": {"tf": 1.4142135623730951}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.encode": {"tf": 1.4142135623730951}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.fuse": {"tf": 1.7320508075688772}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.decode": {"tf": 1.4142135623730951}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.batch_gw_states": {"tf": 1.4142135623730951}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.encode_domain": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.encode_domains": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.decode_domain": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.decode_domains": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.generic_step": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspace2Domains.__init__": {"tf": 2}, "shimmer.modules.global_workspace.GlobalWorkspace2Domains.forward": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspace.__init__": {"tf": 2}, 
"shimmer.modules.global_workspace.GlobalWorkspace.forward": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBayesian.__init__": {"tf": 2}, "shimmer.modules.global_workspace.GlobalWorkspaceBayesian.forward": {"tf": 1}, "shimmer.modules.global_workspace.pretrained_global_workspace": {"tf": 2}, "shimmer.modules.domain.LossOutput.__init__": {"tf": 1.4142135623730951}, "shimmer.modules.domain.DomainModule.encode": {"tf": 1}, "shimmer.modules.domain.DomainModule.decode": {"tf": 1}, "shimmer.modules.domain.DomainModule.compute_loss": {"tf": 1.4142135623730951}, "shimmer.modules.domain.DomainModule.compute_dcy_loss": {"tf": 1.4142135623730951}, "shimmer.modules.domain.DomainModule.compute_cy_loss": {"tf": 1.4142135623730951}, "shimmer.modules.domain.DomainModule.compute_tr_loss": {"tf": 1.4142135623730951}, "shimmer.modules.domain.DomainModule.compute_broadcast_loss": {"tf": 1.4142135623730951}, "shimmer.modules.gw_module.get_n_layers": {"tf": 1}, "shimmer.modules.gw_module.GWEncoder.forward": {"tf": 1.4142135623730951}, "shimmer.modules.gw_module.GWEncoderLinear.forward": {"tf": 1.4142135623730951}, "shimmer.modules.gw_module.GWModuleBase.fuse": {"tf": 1.7320508075688772}, "shimmer.modules.gw_module.GWModuleBase.encode": {"tf": 1.4142135623730951}, "shimmer.modules.gw_module.GWModuleBase.encode_and_fuse": {"tf": 1.4142135623730951}, "shimmer.modules.gw_module.GWModuleBase.decode": {"tf": 1.4142135623730951}, "shimmer.modules.gw_module.GWModule.__init__": {"tf": 1.4142135623730951}, "shimmer.modules.gw_module.GWModule.fuse": {"tf": 1.7320508075688772}, "shimmer.modules.gw_module.GWModule.encode": {"tf": 1.4142135623730951}, "shimmer.modules.gw_module.GWModule.decode": {"tf": 1.4142135623730951}, "shimmer.modules.gw_module.compute_fusion_scores": {"tf": 1.7320508075688772}, "shimmer.modules.gw_module.GWModuleBayesian.__init__": {"tf": 1.4142135623730951}, "shimmer.modules.gw_module.GWModuleBayesian.get_precision": {"tf": 1.4142135623730951}, 
"shimmer.modules.gw_module.GWModuleBayesian.fuse": {"tf": 1.7320508075688772}, "shimmer.modules.selection.SelectionBase.update_gw_state": {"tf": 1}, "shimmer.modules.selection.SelectionBase.forward": {"tf": 1.7320508075688772}, "shimmer.modules.selection.SingleDomainSelection.forward": {"tf": 1.7320508075688772}, "shimmer.modules.selection.FixedSharedSelection.forward": {"tf": 1.7320508075688772}, "shimmer.modules.selection.KQFixedQSelection.forward": {"tf": 1.7320508075688772}, "shimmer.modules.selection.RandomSelection.forward": {"tf": 1.7320508075688772}, "shimmer.modules.selection.DynamicQueryAttention.fuse_weighted_encodings": {"tf": 1.7320508075688772}, "shimmer.modules.selection.DynamicQueryAttention.forward": {"tf": 1.7320508075688772}, "shimmer.modules.losses.GWLossesBase.step": {"tf": 1}, "shimmer.modules.losses.demi_cycle_loss": {"tf": 1.4142135623730951}, "shimmer.modules.losses.cycle_loss": {"tf": 1.4142135623730951}, "shimmer.modules.losses.translation_loss": {"tf": 1.4142135623730951}, "shimmer.modules.losses.contrastive_loss": {"tf": 2}, "shimmer.modules.losses.contrastive_loss_bayesian": {"tf": 2}, "shimmer.modules.losses.GWLosses2Domains.__init__": {"tf": 1.4142135623730951}, "shimmer.modules.losses.GWLosses2Domains.demi_cycle_loss": {"tf": 1.4142135623730951}, "shimmer.modules.losses.GWLosses2Domains.cycle_loss": {"tf": 1.4142135623730951}, "shimmer.modules.losses.GWLosses2Domains.translation_loss": {"tf": 1.4142135623730951}, "shimmer.modules.losses.GWLosses2Domains.contrastive_loss": {"tf": 1.4142135623730951}, "shimmer.modules.losses.GWLosses2Domains.step": {"tf": 1}, "shimmer.modules.losses.broadcast_loss": {"tf": 1.4142135623730951}, "shimmer.modules.losses.GWLosses.__init__": {"tf": 1.4142135623730951}, "shimmer.modules.losses.GWLosses.contrastive_loss": {"tf": 1.4142135623730951}, "shimmer.modules.losses.GWLosses.broadcast_loss": {"tf": 1.4142135623730951}, "shimmer.modules.losses.GWLosses.step": {"tf": 1}, 
"shimmer.modules.losses.GWLossesBayesian.__init__": {"tf": 1.4142135623730951}, "shimmer.modules.losses.GWLossesBayesian.contrastive_loss": {"tf": 1.4142135623730951}, "shimmer.modules.losses.GWLossesBayesian.broadcast_loss": {"tf": 1.4142135623730951}, "shimmer.modules.losses.GWLossesBayesian.step": {"tf": 1}, "shimmer.modules.contrastive_loss.info_nce": {"tf": 2}, "shimmer.modules.contrastive_loss.contrastive_loss": {"tf": 2}, "shimmer.modules.contrastive_loss.ContrastiveLoss.__init__": {"tf": 1}, "shimmer.modules.contrastive_loss.ContrastiveLoss.forward": {"tf": 1.4142135623730951}, "shimmer.modules.vae.reparameterize": {"tf": 1.7320508075688772}, "shimmer.modules.vae.kl_divergence_loss": {"tf": 1.7320508075688772}, "shimmer.modules.vae.gaussian_nll": {"tf": 2}, "shimmer.modules.vae.VAEEncoder.forward": {"tf": 1.4142135623730951}, "shimmer.modules.vae.VAEDecoder.forward": {"tf": 1}, "shimmer.modules.vae.VAE.encode": {"tf": 1}, "shimmer.modules.vae.VAE.decode": {"tf": 1}, "shimmer.modules.vae.VAE.forward": {"tf": 1.4142135623730951}, "shimmer.modules.utils.translation": {"tf": 1.4142135623730951}, "shimmer.modules.utils.cycle": {"tf": 1.4142135623730951}, "shimmer.modules.utils.batch_demi_cycles": {"tf": 1.4142135623730951}, "shimmer.modules.utils.batch_cycles": {"tf": 1.4142135623730951}, "shimmer.modules.utils.batch_translations": {"tf": 1.4142135623730951}, "shimmer.utils.group_batch_size": {"tf": 1}, "shimmer.utils.groups_batch_size": {"tf": 1}, "shimmer.utils.groups_device": {"tf": 1}, "shimmer.utils.group_device": {"tf": 1.4142135623730951}, "shimmer.utils.migrate_model": {"tf": 1}}, "df": 91}}}}, "e": {"docs": {}, "df": 0, "n": {"docs": {}, "df": 0, "s": {"docs": {}, "df": 0, "o": {"docs": {}, "df": 0, "r": {"docs": {"shimmer.modules.global_workspace.GlobalWorkspaceBase.encode_and_fuse": {"tf": 1.4142135623730951}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.encode": {"tf": 1.4142135623730951}, 
"shimmer.modules.global_workspace.GlobalWorkspaceBase.fuse": {"tf": 1.7320508075688772}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.decode": {"tf": 1.4142135623730951}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.batch_gw_states": {"tf": 1.4142135623730951}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.encode_domain": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.encode_domains": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.decode_domain": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.decode_domains": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.generic_step": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspace2Domains.__init__": {"tf": 1.4142135623730951}, "shimmer.modules.global_workspace.GlobalWorkspace2Domains.forward": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspace.__init__": {"tf": 1.4142135623730951}, "shimmer.modules.global_workspace.GlobalWorkspace.forward": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBayesian.__init__": {"tf": 1.4142135623730951}, "shimmer.modules.global_workspace.GlobalWorkspaceBayesian.forward": {"tf": 1}, "shimmer.modules.global_workspace.pretrained_global_workspace": {"tf": 1.4142135623730951}, "shimmer.modules.domain.LossOutput.__init__": {"tf": 1.4142135623730951}, "shimmer.modules.domain.DomainModule.encode": {"tf": 1}, "shimmer.modules.domain.DomainModule.decode": {"tf": 1}, "shimmer.modules.domain.DomainModule.compute_loss": {"tf": 1.4142135623730951}, "shimmer.modules.domain.DomainModule.compute_dcy_loss": {"tf": 1.4142135623730951}, "shimmer.modules.domain.DomainModule.compute_cy_loss": {"tf": 1.4142135623730951}, "shimmer.modules.domain.DomainModule.compute_tr_loss": {"tf": 1.4142135623730951}, "shimmer.modules.domain.DomainModule.compute_broadcast_loss": {"tf": 1.4142135623730951}, "shimmer.modules.gw_module.GWEncoder.forward": {"tf": 1.4142135623730951}, 
"shimmer.modules.gw_module.GWEncoderLinear.forward": {"tf": 1.4142135623730951}, "shimmer.modules.gw_module.GWModuleBase.fuse": {"tf": 1.7320508075688772}, "shimmer.modules.gw_module.GWModuleBase.encode": {"tf": 1.4142135623730951}, "shimmer.modules.gw_module.GWModuleBase.encode_and_fuse": {"tf": 1.4142135623730951}, "shimmer.modules.gw_module.GWModuleBase.decode": {"tf": 1.4142135623730951}, "shimmer.modules.gw_module.GWModule.fuse": {"tf": 1.7320508075688772}, "shimmer.modules.gw_module.GWModule.encode": {"tf": 1.4142135623730951}, "shimmer.modules.gw_module.GWModule.decode": {"tf": 1.4142135623730951}, "shimmer.modules.gw_module.compute_fusion_scores": {"tf": 1.7320508075688772}, "shimmer.modules.gw_module.GWModuleBayesian.get_precision": {"tf": 1.4142135623730951}, "shimmer.modules.gw_module.GWModuleBayesian.fuse": {"tf": 1.7320508075688772}, "shimmer.modules.selection.SelectionBase.update_gw_state": {"tf": 1}, "shimmer.modules.selection.SelectionBase.forward": {"tf": 1.7320508075688772}, "shimmer.modules.selection.SingleDomainSelection.forward": {"tf": 1.7320508075688772}, "shimmer.modules.selection.FixedSharedSelection.forward": {"tf": 1.7320508075688772}, "shimmer.modules.selection.KQFixedQSelection.forward": {"tf": 1.7320508075688772}, "shimmer.modules.selection.RandomSelection.forward": {"tf": 1.7320508075688772}, "shimmer.modules.selection.DynamicQueryAttention.fuse_weighted_encodings": {"tf": 1.7320508075688772}, "shimmer.modules.selection.DynamicQueryAttention.forward": {"tf": 1.7320508075688772}, "shimmer.modules.losses.GWLossesBase.step": {"tf": 1}, "shimmer.modules.losses.demi_cycle_loss": {"tf": 1.4142135623730951}, "shimmer.modules.losses.cycle_loss": {"tf": 1.4142135623730951}, "shimmer.modules.losses.translation_loss": {"tf": 1.4142135623730951}, "shimmer.modules.losses.contrastive_loss": {"tf": 2}, "shimmer.modules.losses.contrastive_loss_bayesian": {"tf": 2}, "shimmer.modules.losses.GWLosses2Domains.__init__": {"tf": 1.4142135623730951}, 
"shimmer.modules.losses.GWLosses2Domains.demi_cycle_loss": {"tf": 1.4142135623730951}, "shimmer.modules.losses.GWLosses2Domains.cycle_loss": {"tf": 1.4142135623730951}, "shimmer.modules.losses.GWLosses2Domains.translation_loss": {"tf": 1.4142135623730951}, "shimmer.modules.losses.GWLosses2Domains.contrastive_loss": {"tf": 1.4142135623730951}, "shimmer.modules.losses.GWLosses2Domains.step": {"tf": 1}, "shimmer.modules.losses.broadcast_loss": {"tf": 1.4142135623730951}, "shimmer.modules.losses.GWLosses.__init__": {"tf": 1.4142135623730951}, "shimmer.modules.losses.GWLosses.contrastive_loss": {"tf": 1.4142135623730951}, "shimmer.modules.losses.GWLosses.broadcast_loss": {"tf": 1.4142135623730951}, "shimmer.modules.losses.GWLosses.step": {"tf": 1}, "shimmer.modules.losses.GWLossesBayesian.__init__": {"tf": 1.4142135623730951}, "shimmer.modules.losses.GWLossesBayesian.contrastive_loss": {"tf": 1.4142135623730951}, "shimmer.modules.losses.GWLossesBayesian.broadcast_loss": {"tf": 1.4142135623730951}, "shimmer.modules.losses.GWLossesBayesian.step": {"tf": 1}, "shimmer.modules.contrastive_loss.info_nce": {"tf": 2}, "shimmer.modules.contrastive_loss.contrastive_loss": {"tf": 2}, "shimmer.modules.contrastive_loss.ContrastiveLoss.__init__": {"tf": 1}, "shimmer.modules.contrastive_loss.ContrastiveLoss.forward": {"tf": 1.4142135623730951}, "shimmer.modules.vae.reparameterize": {"tf": 1.7320508075688772}, "shimmer.modules.vae.kl_divergence_loss": {"tf": 1.7320508075688772}, "shimmer.modules.vae.gaussian_nll": {"tf": 2}, "shimmer.modules.vae.VAEEncoder.forward": {"tf": 1.4142135623730951}, "shimmer.modules.vae.VAEDecoder.forward": {"tf": 1}, "shimmer.modules.vae.VAE.encode": {"tf": 1}, "shimmer.modules.vae.VAE.decode": {"tf": 1}, "shimmer.modules.vae.VAE.forward": {"tf": 1.4142135623730951}, "shimmer.modules.utils.translation": {"tf": 1.4142135623730951}, "shimmer.modules.utils.cycle": {"tf": 1.4142135623730951}, "shimmer.modules.utils.batch_demi_cycles": {"tf": 
1.4142135623730951}, "shimmer.modules.utils.batch_cycles": {"tf": 1.4142135623730951}, "shimmer.modules.utils.batch_translations": {"tf": 1.4142135623730951}, "shimmer.utils.group_batch_size": {"tf": 1}, "shimmer.utils.groups_batch_size": {"tf": 1}, "shimmer.utils.groups_device": {"tf": 1}, "shimmer.utils.group_device": {"tf": 1}}, "df": 87}}}}, "s": {"docs": {}, "df": 0, "t": {"docs": {"shimmer.modules.global_workspace.GlobalWorkspaceBase.generic_step": {"tf": 1}, "shimmer.modules.losses.GWLossesBase.step": {"tf": 1}, "shimmer.modules.losses.GWLosses2Domains.step": {"tf": 1}, "shimmer.modules.losses.GWLosses.step": {"tf": 1}, "shimmer.modules.losses.GWLossesBayesian.step": {"tf": 1}}, "df": 5, "/": {"docs": {}, "df": 0, "o": {"docs": {}, "df": 0, "o": {"docs": {}, "df": 0, "d": {"docs": {"shimmer.modules.global_workspace.GlobalWorkspaceBase.generic_step": {"tf": 1}, "shimmer.modules.losses.GWLossesBase.step": {"tf": 1}, "shimmer.modules.losses.GWLosses2Domains.step": {"tf": 1}, "shimmer.modules.losses.GWLosses.step": {"tf": 1}, "shimmer.modules.losses.GWLossesBayesian.step": {"tf": 1}}, "df": 5}}}}}}, "m": {"docs": {}, "df": 0, "p": {"docs": {"shimmer.modules.global_workspace.GlobalWorkspaceBayesian.__init__": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBayesian.__init__": {"tf": 1}}, "df": 2, "e": {"docs": {}, "df": 0, "r": {"docs": {}, "df": 0, "a": {"docs": {}, "df": 0, "t": {"docs": {}, "df": 0, "u": {"docs": {}, "df": 0, "r": {"docs": {}, "df": 0, "e": {"docs": {"shimmer.modules.global_workspace.GlobalWorkspace.__init__": {"tf": 1}, "shimmer.modules.selection.RandomSelection.__init__": {"tf": 1}}, "df": 2}}}}}}}}}}, "y": {"docs": {}, "df": 0, "p": {"docs": {}, "df": 0, "i": {"docs": {}, "df": 0, "n": {"docs": {}, "df": 0, "g": {"docs": {"shimmer.modules.global_workspace.GlobalWorkspaceBase.encode_domains": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.decode_domains": {"tf": 1}, 
"shimmer.modules.global_workspace.GlobalWorkspaceBase.generic_step": {"tf": 1}, "shimmer.modules.vae.VAE.forward": {"tf": 1}, "shimmer.utils.SaveMigrations.on_save_checkpoint": {"tf": 1}}, "df": 5}}}}}, "r": {"docs": {}, "df": 0, "a": {"docs": {}, "df": 0, "i": {"docs": {}, "df": 0, "n": {"docs": {"shimmer.modules.global_workspace.GlobalWorkspaceBase.generic_step": {"tf": 1}, "shimmer.modules.losses.GWLossesBase.step": {"tf": 1}, "shimmer.modules.losses.GWLosses2Domains.step": {"tf": 1}, "shimmer.modules.losses.GWLosses.step": {"tf": 1}, "shimmer.modules.losses.GWLossesBayesian.step": {"tf": 1}}, "df": 5, "e": {"docs": {}, "df": 0, "r": {"docs": {"shimmer.utils.SaveMigrations.on_save_checkpoint": {"tf": 2}}, "df": 1}}}}}, "u": {"docs": {}, "df": 0, "e": {"docs": {"shimmer.modules.global_workspace.GlobalWorkspaceBayesian.__init__": {"tf": 1}, "shimmer.modules.losses.GWLossesBayesian.__init__": {"tf": 1}}, "df": 2}}}, "a": {"docs": {}, "df": 0, "r": {"docs": {}, "df": 0, "g": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "t": {"docs": {"shimmer.modules.domain.DomainModule.compute_loss": {"tf": 1}, "shimmer.modules.domain.DomainModule.compute_dcy_loss": {"tf": 1}, "shimmer.modules.domain.DomainModule.compute_cy_loss": {"tf": 1}, "shimmer.modules.domain.DomainModule.compute_tr_loss": {"tf": 1}, "shimmer.modules.domain.DomainModule.compute_broadcast_loss": {"tf": 1}}, "df": 5}}}}}, "u": {"docs": {}, "df": 0, "p": {"docs": {}, "df": 0, "l": {"docs": {}, "df": 0, "e": {"docs": {"shimmer.modules.losses.generate_partitions": {"tf": 1}, "shimmer.modules.vae.VAEEncoder.forward": {"tf": 1}, "shimmer.modules.vae.VAE.forward": {"tf": 1.4142135623730951}, "shimmer.modules.utils.batch_cycles": {"tf": 1}, "shimmer.modules.utils.batch_translations": {"tf": 1}}, "df": 5}}}}, "h": {"docs": {}, "df": 0, "r": {"docs": {}, "df": 0, "o": {"docs": {}, "df": 0, "u": {"docs": {}, "df": 0, "g": {"docs": {}, "df": 0, "h": {"docs": {"shimmer.modules.utils.cycle": {"tf": 1}, 
"shimmer.modules.utils.batch_cycles": {"tf": 1}}, "df": 2}}}}}}}, "d": {"docs": {}, "df": 0, "i": {"docs": {}, "df": 0, "c": {"docs": {}, "df": 0, "t": {"docs": {"shimmer.modules.global_workspace.GlobalWorkspaceBase.encode_and_fuse": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.encode": {"tf": 1.4142135623730951}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.fuse": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.decode": {"tf": 1.4142135623730951}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.batch_gw_states": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.encode_domains": {"tf": 1.4142135623730951}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.decode_domains": {"tf": 1.4142135623730951}, "shimmer.modules.global_workspace.freeze_domain_modules": {"tf": 1}, "shimmer.modules.domain.LossOutput.__init__": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBase.encode": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBase.decode": {"tf": 1}, "shimmer.modules.gw_module.GWModule.encode": {"tf": 1}, "shimmer.modules.gw_module.GWModule.decode": {"tf": 1}, "shimmer.modules.selection.SelectionBase.forward": {"tf": 1}, "shimmer.modules.selection.SingleDomainSelection.forward": {"tf": 1}, "shimmer.modules.selection.FixedSharedSelection.forward": {"tf": 1}, "shimmer.modules.selection.KQFixedQSelection.forward": {"tf": 1}, "shimmer.modules.selection.RandomSelection.forward": {"tf": 1}, "shimmer.modules.selection.DynamicQueryAttention.fuse_weighted_encodings": {"tf": 1.4142135623730951}, "shimmer.modules.selection.DynamicQueryAttention.forward": {"tf": 1}, "shimmer.modules.losses.demi_cycle_loss": {"tf": 1}, "shimmer.modules.losses.cycle_loss": {"tf": 1}, "shimmer.modules.losses.translation_loss": {"tf": 1}, "shimmer.modules.losses.contrastive_loss": {"tf": 1}, "shimmer.modules.losses.contrastive_loss_bayesian": {"tf": 1}, "shimmer.modules.losses.GWLosses2Domains.__init__": {"tf": 1}, 
"shimmer.modules.losses.GWLosses2Domains.demi_cycle_loss": {"tf": 1}, "shimmer.modules.losses.GWLosses2Domains.cycle_loss": {"tf": 1}, "shimmer.modules.losses.GWLosses2Domains.translation_loss": {"tf": 1}, "shimmer.modules.losses.GWLosses2Domains.contrastive_loss": {"tf": 1}, "shimmer.modules.losses.broadcast_loss": {"tf": 1}, "shimmer.modules.losses.GWLosses.__init__": {"tf": 1}, "shimmer.modules.losses.GWLosses.contrastive_loss": {"tf": 1}, "shimmer.modules.losses.GWLosses.broadcast_loss": {"tf": 1}, "shimmer.modules.losses.GWLossesBayesian.__init__": {"tf": 1}, "shimmer.modules.losses.GWLossesBayesian.contrastive_loss": {"tf": 1}, "shimmer.modules.losses.GWLossesBayesian.broadcast_loss": {"tf": 1}, "shimmer.modules.utils.cycle": {"tf": 1}, "shimmer.modules.utils.batch_demi_cycles": {"tf": 1}, "shimmer.modules.utils.batch_cycles": {"tf": 1}, "shimmer.modules.utils.batch_translations": {"tf": 1}, "shimmer.utils.SaveMigrations.on_save_checkpoint": {"tf": 1}}, "df": 42}}, "m": {"docs": {"shimmer.modules.global_workspace.GlobalWorkspace2Domains.__init__": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspace.__init__": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBayesian.__init__": {"tf": 1}, "shimmer.modules.global_workspace.pretrained_global_workspace": {"tf": 1}, "shimmer.modules.domain.DomainModule.__init__": {"tf": 1}, "shimmer.modules.gw_module.get_n_layers": {"tf": 1}, "shimmer.modules.gw_module.GWDecoder.__init__": {"tf": 1.7320508075688772}, "shimmer.modules.gw_module.GWEncoder.__init__": {"tf": 1.7320508075688772}, "shimmer.modules.gw_module.GWModuleBase.__init__": {"tf": 1}, "shimmer.modules.gw_module.GWModule.__init__": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBayesian.__init__": {"tf": 1}, "shimmer.modules.selection.KQFixedQSelection.__init__": {"tf": 1}, "shimmer.modules.selection.DynamicQueryAttention.__init__": {"tf": 1}}, "df": 13}}, "o": {"docs": {}, "df": 0, "m": {"docs": {}, "df": 0, "a": {"docs": {}, "df": 0, "i": 
{"docs": {}, "df": 0, "n": {"docs": {"shimmer.modules.global_workspace.GlobalWorkspaceBase.encode_domain": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.decode_domain": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.decode_domains": {"tf": 1}, "shimmer.modules.global_workspace.freeze_domain_modules": {"tf": 1.7320508075688772}, "shimmer.modules.global_workspace.GlobalWorkspace2Domains.__init__": {"tf": 1.7320508075688772}, "shimmer.modules.global_workspace.GlobalWorkspace.__init__": {"tf": 1.7320508075688772}, "shimmer.modules.global_workspace.GlobalWorkspaceBayesian.__init__": {"tf": 1.7320508075688772}, "shimmer.modules.global_workspace.pretrained_global_workspace": {"tf": 1.7320508075688772}, "shimmer.modules.domain.DomainModule.compute_loss": {"tf": 1}, "shimmer.modules.domain.DomainModule.compute_dcy_loss": {"tf": 1}, "shimmer.modules.domain.DomainModule.compute_cy_loss": {"tf": 1}, "shimmer.modules.domain.DomainModule.compute_tr_loss": {"tf": 1}, "shimmer.modules.domain.DomainModule.compute_broadcast_loss": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBase.__init__": {"tf": 1.4142135623730951}, "shimmer.modules.gw_module.GWModule.__init__": {"tf": 1.4142135623730951}, "shimmer.modules.gw_module.GWModuleBayesian.__init__": {"tf": 1.4142135623730951}, "shimmer.modules.gw_module.GWModuleBayesian.get_precision": {"tf": 1}, "shimmer.modules.selection.KQFixedQSelection.__init__": {"tf": 1.4142135623730951}, "shimmer.modules.selection.DynamicQueryAttention.__init__": {"tf": 1.4142135623730951}, "shimmer.modules.losses.GWLossesBase.step": {"tf": 1.4142135623730951}, "shimmer.modules.losses.demi_cycle_loss": {"tf": 1.4142135623730951}, "shimmer.modules.losses.cycle_loss": {"tf": 1.4142135623730951}, "shimmer.modules.losses.translation_loss": {"tf": 1.4142135623730951}, "shimmer.modules.losses.contrastive_loss": {"tf": 1}, "shimmer.modules.losses.contrastive_loss_bayesian": {"tf": 1}, 
"shimmer.modules.losses.GWLosses2Domains.__init__": {"tf": 1.7320508075688772}, "shimmer.modules.losses.GWLosses2Domains.step": {"tf": 1.4142135623730951}, "shimmer.modules.losses.broadcast_loss": {"tf": 1.4142135623730951}, "shimmer.modules.losses.GWLosses.__init__": {"tf": 1.7320508075688772}, "shimmer.modules.losses.GWLosses.step": {"tf": 1.4142135623730951}, "shimmer.modules.losses.GWLossesBayesian.__init__": {"tf": 1.7320508075688772}, "shimmer.modules.losses.GWLossesBayesian.step": {"tf": 1.4142135623730951}, "shimmer.modules.contrastive_loss.ContrastiveLoss.forward": {"tf": 1}, "shimmer.utils.groups_batch_size": {"tf": 1}, "shimmer.utils.groups_device": {"tf": 1}}, "df": 35, "s": {"docs": {"shimmer.modules.global_workspace.GlobalWorkspaceBase.decode": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.batch_gw_states": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspace2Domains.forward": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspace.forward": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBayesian.forward": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBase.decode": {"tf": 1}, "shimmer.modules.gw_module.GWModule.decode": {"tf": 1}, "shimmer.modules.selection.SelectionBase.forward": {"tf": 1}, "shimmer.modules.selection.SingleDomainSelection.forward": {"tf": 1}, "shimmer.modules.selection.FixedSharedSelection.forward": {"tf": 1}, "shimmer.modules.selection.KQFixedQSelection.forward": {"tf": 1}, "shimmer.modules.selection.RandomSelection.forward": {"tf": 1}, "shimmer.modules.selection.DynamicQueryAttention.forward": {"tf": 1}, "shimmer.modules.losses.demi_cycle_loss": {"tf": 1}, "shimmer.modules.losses.cycle_loss": {"tf": 1}, "shimmer.modules.losses.translation_loss": {"tf": 1}, "shimmer.modules.losses.contrastive_loss": {"tf": 1}, "shimmer.modules.losses.contrastive_loss_bayesian": {"tf": 1}, "shimmer.modules.losses.GWLosses2Domains.demi_cycle_loss": {"tf": 1}, 
"shimmer.modules.losses.GWLosses2Domains.cycle_loss": {"tf": 1}, "shimmer.modules.losses.GWLosses2Domains.translation_loss": {"tf": 1}, "shimmer.modules.losses.GWLosses2Domains.contrastive_loss": {"tf": 1}, "shimmer.modules.losses.broadcast_loss": {"tf": 1}, "shimmer.modules.losses.GWLosses.contrastive_loss": {"tf": 1}, "shimmer.modules.losses.GWLosses.broadcast_loss": {"tf": 1}, "shimmer.modules.losses.GWLossesBayesian.contrastive_loss": {"tf": 1}, "shimmer.modules.losses.GWLossesBayesian.broadcast_loss": {"tf": 1}, "shimmer.modules.utils.batch_demi_cycles": {"tf": 1}, "shimmer.modules.utils.batch_cycles": {"tf": 1.4142135623730951}, "shimmer.modules.utils.batch_translations": {"tf": 1}}, "df": 30}, "m": {"docs": {}, "df": 0, "o": {"docs": {}, "df": 0, "d": {"docs": {}, "df": 0, "u": {"docs": {}, "df": 0, "l": {"docs": {}, "df": 0, "e": {"docs": {"shimmer.modules.global_workspace.freeze_domain_modules": {"tf": 1.4142135623730951}, "shimmer.modules.global_workspace.GlobalWorkspace2Domains.__init__": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspace.__init__": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBayesian.__init__": {"tf": 1}, "shimmer.modules.global_workspace.pretrained_global_workspace": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBase.__init__": {"tf": 1}, "shimmer.modules.gw_module.GWModule.__init__": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBayesian.__init__": {"tf": 1}, "shimmer.modules.losses.demi_cycle_loss": {"tf": 1}, "shimmer.modules.losses.cycle_loss": {"tf": 1}, "shimmer.modules.losses.translation_loss": {"tf": 1}, "shimmer.modules.losses.GWLosses2Domains.__init__": {"tf": 1}, "shimmer.modules.losses.broadcast_loss": {"tf": 1}, "shimmer.modules.losses.GWLosses.__init__": {"tf": 1}, "shimmer.modules.losses.GWLossesBayesian.__init__": {"tf": 1}}, "df": 15}}}}}}}}}}}, "e": {"docs": {}, "df": 0, "c": {"docs": {}, "df": 0, "o": {"docs": {}, "df": 0, "d": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "r": {"docs": 
{"shimmer.modules.vae.VAE.__init__": {"tf": 1}}, "df": 1, "s": {"docs": {"shimmer.modules.global_workspace.GlobalWorkspace2Domains.__init__": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspace.__init__": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBayesian.__init__": {"tf": 1}, "shimmer.modules.global_workspace.pretrained_global_workspace": {"tf": 1}, "shimmer.modules.gw_module.GWModule.__init__": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBayesian.__init__": {"tf": 1}}, "df": 6}}}}}, "a": {"docs": {}, "df": 0, "y": {"docs": {"shimmer.modules.global_workspace.GlobalWorkspace2Domains.__init__": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspace.__init__": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBayesian.__init__": {"tf": 1}}, "df": 3}}}, "v": {"docs": {}, "df": 0, "i": {"docs": {}, "df": 0, "c": {"docs": {}, "df": 0, "e": {"docs": {"shimmer.utils.group_device": {"tf": 1}}, "df": 1}}}}}, "a": {"docs": {}, "df": 0, "t": {"docs": {}, "df": 0, "a": {"docs": {}, "df": 0, "s": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "t": {"docs": {"shimmer.dataset.RepeatedDataset.__init__": {"tf": 1.4142135623730951}}, "df": 1}}}}}}, "r": {"docs": {}, "df": 0, "o": {"docs": {}, "df": 0, "p": {"docs": {"shimmer.dataset.RepeatedDataset.__init__": {"tf": 1}}, "df": 1}}}}, "z": {"docs": {"shimmer.modules.global_workspace.GlobalWorkspaceBase.decode": {"tf": 1}, "shimmer.modules.domain.DomainModule.decode": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBase.decode": {"tf": 1}, "shimmer.modules.gw_module.GWModule.decode": {"tf": 1}, "shimmer.modules.vae.VAE.decode": {"tf": 1}}, "df": 5}, "i": {"docs": {}, "df": 0, "t": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "r": {"docs": {}, "df": 0, "a": {"docs": {}, "df": 0, "b": {"docs": {}, "df": 0, "l": {"docs": {}, "df": 0, "e": {"docs": {"shimmer.modules.global_workspace.GlobalWorkspaceBase.decode": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBase.decode": {"tf": 1}, 
"shimmer.modules.gw_module.GWModule.decode": {"tf": 1}, "shimmer.modules.selection.KQFixedQSelection.__init__": {"tf": 1}, "shimmer.modules.selection.DynamicQueryAttention.__init__": {"tf": 1}, "shimmer.modules.utils.batch_cycles": {"tf": 1}}, "df": 6}}}}}}}, "n": {"docs": {"shimmer.modules.gw_module.GWDecoder.__init__": {"tf": 1}, "shimmer.modules.gw_module.GWEncoder.__init__": {"tf": 1}}, "df": 2, "t": {"docs": {"shimmer.modules.global_workspace.GlobalWorkspace2Domains.__init__": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspace.__init__": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBayesian.__init__": {"tf": 1}, "shimmer.modules.global_workspace.pretrained_global_workspace": {"tf": 1}, "shimmer.modules.domain.DomainModule.__init__": {"tf": 1}, "shimmer.modules.gw_module.get_n_layers": {"tf": 1.4142135623730951}, "shimmer.modules.gw_module.GWDecoder.__init__": {"tf": 2}, "shimmer.modules.gw_module.GWEncoder.__init__": {"tf": 2}, "shimmer.modules.gw_module.GWModuleBase.__init__": {"tf": 1}, "shimmer.modules.gw_module.GWModule.__init__": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBayesian.__init__": {"tf": 1}, "shimmer.modules.selection.KQFixedQSelection.__init__": {"tf": 1.4142135623730951}, "shimmer.modules.selection.DynamicQueryAttention.__init__": {"tf": 1.4142135623730951}, "shimmer.modules.losses.generate_partitions": {"tf": 1.4142135623730951}, "shimmer.dataset.RepeatedDataset.__init__": {"tf": 1}, "shimmer.utils.group_batch_size": {"tf": 1}, "shimmer.utils.groups_batch_size": {"tf": 1}, "shimmer.utils.groups_device": {"tf": 1}}, "df": 18}, "p": {"docs": {}, "df": 0, "u": {"docs": {}, "df": 0, "t": {"docs": {"shimmer.modules.gw_module.GWEncoder.forward": {"tf": 1}, "shimmer.modules.gw_module.GWEncoderLinear.forward": {"tf": 1}}, "df": 2}}}}}, "n": {"docs": {"shimmer.modules.gw_module.get_n_layers": {"tf": 1}, "shimmer.modules.gw_module.GWDecoder.__init__": {"tf": 1}, "shimmer.modules.gw_module.GWEncoder.__init__": {"tf": 1}, 
"shimmer.modules.losses.generate_partitions": {"tf": 1}}, "df": 4, "o": {"docs": {}, "df": 0, "n": {"docs": {}, "df": 0, "e": {"docs": {"shimmer.modules.global_workspace.GlobalWorkspaceBase.decode": {"tf": 1.4142135623730951}, "shimmer.modules.global_workspace.GlobalWorkspace2Domains.__init__": {"tf": 2}, "shimmer.modules.global_workspace.GlobalWorkspace.__init__": {"tf": 2}, "shimmer.modules.global_workspace.GlobalWorkspaceBayesian.__init__": {"tf": 2}, "shimmer.modules.gw_module.GWModuleBase.decode": {"tf": 1.4142135623730951}, "shimmer.modules.gw_module.GWModule.decode": {"tf": 1.4142135623730951}, "shimmer.modules.selection.SelectionBase.update_gw_state": {"tf": 1}, "shimmer.modules.losses.generate_partitions": {"tf": 1.4142135623730951}, "shimmer.modules.contrastive_loss.info_nce": {"tf": 1}, "shimmer.modules.contrastive_loss.contrastive_loss": {"tf": 1}, "shimmer.modules.contrastive_loss.ContrastiveLoss.__init__": {"tf": 1}}, "df": 11}}, "r": {"docs": {}, "df": 0, "m": {"docs": {}, "df": 0, "a": {"docs": {}, "df": 0, "l": {"docs": {}, "df": 0, "i": {"docs": {}, "df": 0, "z": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "d": {"docs": {"shimmer.modules.global_workspace.GlobalWorkspaceBayesian.__init__": {"tf": 1}, "shimmer.modules.losses.GWLossesBayesian.__init__": {"tf": 1}}, "df": 2}}}}}}}}}, "a": {"docs": {}, "df": 0, "m": {"docs": {}, "df": 0, "e": {"docs": {"shimmer.modules.global_workspace.GlobalWorkspaceBase.encode_domain": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.decode_domain": {"tf": 1}}, "df": 2, "s": {"docs": {"shimmer.modules.selection.KQFixedQSelection.__init__": {"tf": 1}, "shimmer.modules.selection.DynamicQueryAttention.__init__": {"tf": 1}}, "df": 2}}}}, "n": {"docs": {"shimmer.modules.global_workspace.GlobalWorkspace2Domains.__init__": {"tf": 1.4142135623730951}, "shimmer.modules.global_workspace.GlobalWorkspace.__init__": {"tf": 1.4142135623730951}, 
"shimmer.modules.global_workspace.GlobalWorkspaceBayesian.__init__": {"tf": 1.4142135623730951}, "shimmer.modules.global_workspace.pretrained_global_workspace": {"tf": 1.4142135623730951}, "shimmer.modules.gw_module.get_n_layers": {"tf": 1}, "shimmer.modules.gw_module.GWModule.__init__": {"tf": 1.4142135623730951}, "shimmer.modules.gw_module.GWModuleBayesian.__init__": {"tf": 1.4142135623730951}}, "df": 7}}, "l": {"docs": {}, "df": 0, "a": {"docs": {}, "df": 0, "t": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "n": {"docs": {}, "df": 0, "t": {"docs": {"shimmer.modules.global_workspace.GlobalWorkspaceBase.batch_gw_states": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspace2Domains.forward": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspace.forward": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBayesian.forward": {"tf": 1}, "shimmer.modules.domain.DomainModule.__init__": {"tf": 1}, "shimmer.modules.losses.demi_cycle_loss": {"tf": 1}, "shimmer.modules.losses.cycle_loss": {"tf": 1}, "shimmer.modules.losses.translation_loss": {"tf": 1}, "shimmer.modules.losses.contrastive_loss": {"tf": 1}, "shimmer.modules.losses.contrastive_loss_bayesian": {"tf": 1}, "shimmer.modules.losses.GWLosses2Domains.demi_cycle_loss": {"tf": 1}, "shimmer.modules.losses.GWLosses2Domains.cycle_loss": {"tf": 1}, "shimmer.modules.losses.GWLosses2Domains.translation_loss": {"tf": 1}, "shimmer.modules.losses.GWLosses2Domains.contrastive_loss": {"tf": 1}, "shimmer.modules.losses.broadcast_loss": {"tf": 1}, "shimmer.modules.losses.GWLosses.contrastive_loss": {"tf": 1}, "shimmer.modules.losses.GWLosses.broadcast_loss": {"tf": 1}, "shimmer.modules.losses.GWLossesBayesian.contrastive_loss": {"tf": 1}, "shimmer.modules.losses.GWLossesBayesian.broadcast_loss": {"tf": 1}, "shimmer.modules.utils.batch_demi_cycles": {"tf": 1}, "shimmer.modules.utils.batch_cycles": {"tf": 1}, "shimmer.modules.utils.batch_translations": {"tf": 1}}, "df": 22, "s": {"docs": 
{"shimmer.modules.global_workspace.GlobalWorkspaceBase.decode_domains": {"tf": 1}, "shimmer.modules.losses.GWLossesBase.step": {"tf": 1}, "shimmer.modules.losses.GWLosses2Domains.step": {"tf": 1}, "shimmer.modules.losses.GWLosses.step": {"tf": 1}, "shimmer.modules.losses.GWLossesBayesian.step": {"tf": 1}, "shimmer.utils.groups_batch_size": {"tf": 1}, "shimmer.utils.groups_device": {"tf": 1}}, "df": 7}}}}}, "y": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "r": {"docs": {}, "df": 0, "s": {"docs": {"shimmer.modules.gw_module.get_n_layers": {"tf": 1}, "shimmer.modules.gw_module.GWDecoder.__init__": {"tf": 1}, "shimmer.modules.gw_module.GWEncoder.__init__": {"tf": 1}}, "df": 3}}}}, "s": {"docs": {}, "df": 0, "t": {"docs": {"shimmer.dataset.RepeatedDataset.__init__": {"tf": 1}}, "df": 1}}}, "i": {"docs": {}, "df": 0, "t": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "r": {"docs": {}, "df": 0, "a": {"docs": {}, "df": 0, "l": {"docs": {"shimmer.modules.global_workspace.GlobalWorkspaceBase.generic_step": {"tf": 1}, "shimmer.modules.losses.GWLossesBase.step": {"tf": 1}, "shimmer.modules.losses.GWLosses2Domains.step": {"tf": 1}, "shimmer.modules.losses.GWLosses.step": {"tf": 1}, "shimmer.modules.losses.GWLossesBayesian.step": {"tf": 1}, "shimmer.modules.contrastive_loss.info_nce": {"tf": 1}, "shimmer.modules.contrastive_loss.contrastive_loss": {"tf": 1}, "shimmer.modules.contrastive_loss.ContrastiveLoss.__init__": {"tf": 1}}, "df": 8}}}}}, "s": {"docs": {}, "df": 0, "t": {"docs": {"shimmer.modules.gw_module.get_n_layers": {"tf": 1}}, "df": 1}}, "g": {"docs": {}, "df": 0, "h": {"docs": {}, "df": 0, "t": {"docs": {}, "df": 0, "n": {"docs": {}, "df": 0, "i": {"docs": {}, "df": 0, "n": {"docs": {}, "df": 0, "g": {"docs": {"shimmer.utils.SaveMigrations.on_save_checkpoint": {"tf": 1.4142135623730951}}, "df": 1, "m": {"docs": {}, "df": 0, "o": {"docs": {}, "df": 0, "d": {"docs": {}, "df": 0, "u": {"docs": {}, "df": 0, "l": {"docs": {}, "df": 0, "e": {"docs": 
{"shimmer.utils.SaveMigrations.on_save_checkpoint": {"tf": 1}}, "df": 1}}}}}}}}}}}}}}, "o": {"docs": {}, "df": 0, "s": {"docs": {}, "df": 0, "s": {"docs": {"shimmer.modules.global_workspace.GlobalWorkspace2Domains.__init__": {"tf": 1.4142135623730951}, "shimmer.modules.global_workspace.GlobalWorkspace.__init__": {"tf": 1.4142135623730951}, "shimmer.modules.global_workspace.GlobalWorkspaceBayesian.__init__": {"tf": 1.4142135623730951}, "shimmer.modules.global_workspace.pretrained_global_workspace": {"tf": 1}, "shimmer.modules.domain.LossOutput.__init__": {"tf": 1}, "shimmer.modules.losses.GWLosses2Domains.__init__": {"tf": 1}, "shimmer.modules.losses.GWLosses.__init__": {"tf": 1}, "shimmer.modules.losses.GWLossesBayesian.__init__": {"tf": 1}}, "df": 8, "e": {"docs": {}, "df": 0, "s": {"docs": {"shimmer.modules.global_workspace.GlobalWorkspace2Domains.__init__": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspace.__init__": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBayesian.__init__": {"tf": 1}, "shimmer.modules.global_workspace.pretrained_global_workspace": {"tf": 1}, "shimmer.modules.losses.GWLosses2Domains.__init__": {"tf": 1}, "shimmer.modules.losses.GWLosses.__init__": {"tf": 1}, "shimmer.modules.losses.GWLossesBayesian.__init__": {"tf": 1}}, "df": 7}}, "c": {"docs": {}, "df": 0, "o": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "f": {"docs": {}, "df": 0, "s": {"docs": {"shimmer.modules.global_workspace.GlobalWorkspace2Domains.__init__": {"tf": 1}, "shimmer.modules.global_workspace.pretrained_global_workspace": {"tf": 1}, "shimmer.modules.losses.GWLosses2Domains.__init__": {"tf": 1}}, "df": 3}}}}}, "o": {"docs": {}, "df": 0, "u": {"docs": {}, "df": 0, "t": {"docs": {}, "df": 0, "p": {"docs": {}, "df": 0, "u": {"docs": {}, "df": 0, "t": {"docs": {"shimmer.modules.global_workspace.GlobalWorkspace2Domains.__init__": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspace.__init__": {"tf": 1}, 
"shimmer.modules.global_workspace.GlobalWorkspaceBayesian.__init__": {"tf": 1}, "shimmer.modules.global_workspace.pretrained_global_workspace": {"tf": 1}, "shimmer.modules.domain.DomainModule.compute_loss": {"tf": 1}, "shimmer.modules.domain.DomainModule.compute_dcy_loss": {"tf": 1}, "shimmer.modules.domain.DomainModule.compute_cy_loss": {"tf": 1}, "shimmer.modules.domain.DomainModule.compute_tr_loss": {"tf": 1}, "shimmer.modules.domain.DomainModule.compute_broadcast_loss": {"tf": 1}, "shimmer.modules.losses.GWLossesBase.step": {"tf": 1}, "shimmer.modules.losses.contrastive_loss": {"tf": 1}, "shimmer.modules.losses.contrastive_loss_bayesian": {"tf": 1}, "shimmer.modules.losses.GWLosses2Domains.__init__": {"tf": 1}, "shimmer.modules.losses.GWLosses2Domains.step": {"tf": 1}, "shimmer.modules.losses.GWLosses.__init__": {"tf": 1}, "shimmer.modules.losses.GWLosses.step": {"tf": 1}, "shimmer.modules.losses.GWLossesBayesian.__init__": {"tf": 1}, "shimmer.modules.losses.GWLossesBayesian.step": {"tf": 1}, "shimmer.modules.contrastive_loss.ContrastiveLoss.forward": {"tf": 1}}, "df": 19}}}}}}}}, "g": {"docs": {"shimmer.modules.vae.gaussian_nll": {"tf": 1}}, "df": 1, "i": {"docs": {}, "df": 0, "t": {"docs": {"shimmer.modules.global_workspace.GlobalWorkspace2Domains.__init__": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspace.__init__": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBayesian.__init__": {"tf": 1}, "shimmer.modules.contrastive_loss.info_nce": {"tf": 1}, "shimmer.modules.contrastive_loss.contrastive_loss": {"tf": 1}, "shimmer.modules.contrastive_loss.ContrastiveLoss.__init__": {"tf": 1.4142135623730951}}, "df": 6}}, "v": {"docs": {}, "df": 0, "a": {"docs": {}, "df": 0, "r": {"docs": {"shimmer.modules.vae.reparameterize": {"tf": 1}, "shimmer.modules.vae.kl_divergence_loss": {"tf": 1}}, "df": 2}}}}, "a": {"docs": {}, "df": 0, "d": {"docs": {"shimmer.utils.migrate_model": {"tf": 1}}, "df": 1}}}, "r": {"docs": 
{"shimmer.modules.global_workspace.GlobalWorkspace2Domains.__init__": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspace.__init__": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBayesian.__init__": {"tf": 1}}, "df": 3}, "e": {"docs": {}, "df": 0, "a": {"docs": {}, "df": 0, "r": {"docs": {}, "df": 0, "n": {"docs": {"shimmer.modules.global_workspace.GlobalWorkspace2Domains.__init__": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspace.__init__": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBayesian.__init__": {"tf": 1}, "shimmer.modules.contrastive_loss.ContrastiveLoss.__init__": {"tf": 1}}, "df": 4}}}}, "t": {"docs": {"shimmer.modules.domain.LossOutput.__init__": {"tf": 1}}, "df": 1}}, "b": {"docs": {}, "df": 0, "a": {"docs": {}, "df": 0, "t": {"docs": {}, "df": 0, "c": {"docs": {}, "df": 0, "h": {"docs": {"shimmer.modules.global_workspace.GlobalWorkspaceBase.encode_domains": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.generic_step": {"tf": 1}}, "df": 2}}}}, "o": {"docs": {}, "df": 0, "o": {"docs": {}, "df": 0, "l": {"docs": {"shimmer.modules.global_workspace.GlobalWorkspace2Domains.__init__": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspace.__init__": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBayesian.__init__": {"tf": 1.4142135623730951}, "shimmer.modules.losses.GWLossesBayesian.__init__": {"tf": 1}, "shimmer.modules.contrastive_loss.ContrastiveLoss.__init__": {"tf": 1}, "shimmer.dataset.RepeatedDataset.__init__": {"tf": 1}}, "df": 6}}}, "r": {"docs": {}, "df": 0, "o": {"docs": {}, "df": 0, "a": {"docs": {}, "df": 0, "d": {"docs": {}, "df": 0, "c": {"docs": {}, "df": 0, "a": {"docs": {}, "df": 0, "s": {"docs": {}, "df": 0, "t": {"docs": {}, "df": 0, "l": {"docs": {}, "df": 0, "o": {"docs": {}, "df": 0, "s": {"docs": {}, "df": 0, "s": {"docs": {}, "df": 0, "c": {"docs": {}, "df": 0, "o": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "f": {"docs": {}, "df": 0, "s": 
{"docs": {"shimmer.modules.global_workspace.GlobalWorkspace.__init__": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBayesian.__init__": {"tf": 1}, "shimmer.modules.losses.GWLosses.__init__": {"tf": 1}, "shimmer.modules.losses.GWLossesBayesian.__init__": {"tf": 1}}, "df": 4}}}}}}}}}}}}}}}}}, "e": {"docs": {}, "df": 0, "t": {"docs": {}, "df": 0, "a": {"docs": {"shimmer.modules.vae.VAE.__init__": {"tf": 1}}, "df": 1}}}}, "v": {"docs": {}, "df": 0, "a": {"docs": {}, "df": 0, "l": {"docs": {"shimmer.modules.global_workspace.GlobalWorkspaceBase.generic_step": {"tf": 1}, "shimmer.modules.losses.GWLossesBase.step": {"tf": 1}, "shimmer.modules.losses.GWLosses2Domains.step": {"tf": 1}, "shimmer.modules.losses.GWLosses.step": {"tf": 1}, "shimmer.modules.losses.GWLossesBayesian.step": {"tf": 1}}, "df": 5, "/": {"docs": {}, "df": 0, "o": {"docs": {}, "df": 0, "o": {"docs": {}, "df": 0, "d": {"docs": {"shimmer.modules.global_workspace.GlobalWorkspaceBase.generic_step": {"tf": 1}, "shimmer.modules.losses.GWLossesBase.step": {"tf": 1}, "shimmer.modules.losses.GWLosses2Domains.step": {"tf": 1}, "shimmer.modules.losses.GWLosses.step": {"tf": 1}, "shimmer.modules.losses.GWLossesBayesian.step": {"tf": 1}}, "df": 5}}}}}, "e": {"docs": {"shimmer.modules.vae.VAE.__init__": {"tf": 1.4142135623730951}}, "df": 1, "e": {"docs": {}, "df": 0, "n": {"docs": {}, "df": 0, "c": {"docs": {}, "df": 0, "o": {"docs": {}, "df": 0, "d": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "r": {"docs": {"shimmer.modules.vae.VAE.__init__": {"tf": 1}}, "df": 1}}}}}}}, "d": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "c": {"docs": {}, "df": 0, "o": {"docs": {}, "df": 0, "d": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "r": {"docs": {"shimmer.modules.vae.VAE.__init__": {"tf": 1}}, "df": 1}}}}}}}}}}, "g": {"docs": {}, "df": 0, "w": {"docs": {"shimmer.modules.global_workspace.GlobalWorkspace2Domains.__init__": {"tf": 1.4142135623730951}, 
"shimmer.modules.global_workspace.GlobalWorkspace.__init__": {"tf": 1.4142135623730951}, "shimmer.modules.global_workspace.GlobalWorkspaceBayesian.__init__": {"tf": 1.4142135623730951}, "shimmer.modules.global_workspace.pretrained_global_workspace": {"tf": 1.4142135623730951}, "shimmer.modules.gw_module.GWModule.__init__": {"tf": 1.4142135623730951}, "shimmer.modules.gw_module.GWModuleBayesian.__init__": {"tf": 1.4142135623730951}, "shimmer.modules.selection.SelectionBase.update_gw_state": {"tf": 1}, "shimmer.modules.losses.demi_cycle_loss": {"tf": 1.4142135623730951}, "shimmer.modules.losses.cycle_loss": {"tf": 1.4142135623730951}, "shimmer.modules.losses.translation_loss": {"tf": 1.4142135623730951}, "shimmer.modules.losses.contrastive_loss": {"tf": 1.4142135623730951}, "shimmer.modules.losses.contrastive_loss_bayesian": {"tf": 1.4142135623730951}, "shimmer.modules.losses.GWLosses2Domains.__init__": {"tf": 1.4142135623730951}, "shimmer.modules.losses.broadcast_loss": {"tf": 1.4142135623730951}, "shimmer.modules.losses.GWLosses.__init__": {"tf": 1.4142135623730951}, "shimmer.modules.losses.GWLossesBayesian.__init__": {"tf": 1.4142135623730951}, "shimmer.modules.utils.translation": {"tf": 1.4142135623730951}, "shimmer.modules.utils.cycle": {"tf": 1.4142135623730951}, "shimmer.modules.utils.batch_demi_cycles": {"tf": 1.4142135623730951}, "shimmer.modules.utils.batch_cycles": {"tf": 1.4142135623730951}, "shimmer.modules.utils.batch_translations": {"tf": 1.4142135623730951}}, "df": 21, "p": {"docs": {}, "df": 0, "r": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "d": {"docs": {}, "df": 0, "i": {"docs": {}, "df": 0, "c": {"docs": {}, "df": 0, "t": {"docs": {}, "df": 0, "i": {"docs": {}, "df": 0, "o": {"docs": {}, "df": 0, "n": {"docs": {}, "df": 0, "s": {"docs": {"shimmer.modules.global_workspace.GlobalWorkspace2Domains.forward": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspace.forward": {"tf": 1}, 
"shimmer.modules.global_workspace.GlobalWorkspaceBayesian.forward": {"tf": 1}}, "df": 3}}}}}}}}}}}, "m": {"docs": {}, "df": 0, "o": {"docs": {}, "df": 0, "d": {"docs": {}, "df": 0, "u": {"docs": {}, "df": 0, "l": {"docs": {}, "df": 0, "e": {"docs": {"shimmer.modules.losses.GWLosses2Domains.__init__": {"tf": 1}, "shimmer.modules.losses.GWLosses.__init__": {"tf": 1}}, "df": 2, "b": {"docs": {}, "df": 0, "a": {"docs": {}, "df": 0, "s": {"docs": {}, "df": 0, "e": {"docs": {"shimmer.modules.losses.demi_cycle_loss": {"tf": 1}, "shimmer.modules.losses.cycle_loss": {"tf": 1}, "shimmer.modules.losses.translation_loss": {"tf": 1}, "shimmer.modules.losses.contrastive_loss": {"tf": 1}, "shimmer.modules.losses.broadcast_loss": {"tf": 1}, "shimmer.modules.utils.translation": {"tf": 1}, "shimmer.modules.utils.cycle": {"tf": 1}, "shimmer.modules.utils.batch_demi_cycles": {"tf": 1}, "shimmer.modules.utils.batch_cycles": {"tf": 1}, "shimmer.modules.utils.batch_translations": {"tf": 1}}, "df": 10}}, "y": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "s": {"docs": {}, "df": 0, "i": {"docs": {}, "df": 0, "a": {"docs": {}, "df": 0, "n": {"docs": {"shimmer.modules.losses.contrastive_loss_bayesian": {"tf": 1}, "shimmer.modules.losses.GWLossesBayesian.__init__": {"tf": 1}}, "df": 2}}}}}}}}}}}}}}}, "l": {"docs": {}, "df": 0, "o": {"docs": {}, "df": 0, "b": {"docs": {}, "df": 0, "a": {"docs": {}, "df": 0, "l": {"docs": {"shimmer.modules.global_workspace.GlobalWorkspace2Domains.__init__": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspace2Domains.forward": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspace.__init__": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspace.forward": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBayesian.__init__": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBayesian.forward": {"tf": 1}, "shimmer.modules.global_workspace.pretrained_global_workspace": {"tf": 1}}, "df": 7, "w": {"docs": {}, "df": 0, "o": 
{"docs": {}, "df": 0, "r": {"docs": {}, "df": 0, "k": {"docs": {}, "df": 0, "s": {"docs": {}, "df": 0, "p": {"docs": {}, "df": 0, "a": {"docs": {}, "df": 0, "c": {"docs": {}, "df": 0, "e": {"2": {"docs": {}, "df": 0, "d": {"docs": {}, "df": 0, "o": {"docs": {}, "df": 0, "m": {"docs": {}, "df": 0, "a": {"docs": {}, "df": 0, "i": {"docs": {}, "df": 0, "n": {"docs": {}, "df": 0, "s": {"docs": {"shimmer.modules.global_workspace.pretrained_global_workspace": {"tf": 1}}, "df": 1}}}}}}}}, "docs": {}, "df": 0}}}}}}}}}}}}}}, "t": {"docs": {"shimmer.modules.domain.LossOutput.__init__": {"tf": 1}}, "df": 1}, "e": {"docs": {}, "df": 0, "n": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "r": {"docs": {}, "df": 0, "a": {"docs": {}, "df": 0, "t": {"docs": {}, "df": 0, "o": {"docs": {}, "df": 0, "r": {"docs": {"shimmer.modules.losses.generate_partitions": {"tf": 1}}, "df": 1}}}}}}}}}, "e": {"docs": {}, "df": 0, "n": {"docs": {}, "df": 0, "c": {"docs": {}, "df": 0, "o": {"docs": {}, "df": 0, "d": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "r": {"docs": {"shimmer.modules.vae.VAE.__init__": {"tf": 1}}, "df": 1, "s": {"docs": {"shimmer.modules.global_workspace.GlobalWorkspace2Domains.__init__": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspace.__init__": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBayesian.__init__": {"tf": 1}, "shimmer.modules.global_workspace.pretrained_global_workspace": {"tf": 1}, "shimmer.modules.gw_module.GWModule.__init__": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBayesian.__init__": {"tf": 1}}, "df": 6}}}, "i": {"docs": {}, "df": 0, "n": {"docs": {}, "df": 0, "g": {"docs": {}, "df": 0, "s": {"docs": {"shimmer.modules.selection.SelectionBase.forward": {"tf": 1}, "shimmer.modules.selection.SingleDomainSelection.forward": {"tf": 1}, "shimmer.modules.selection.FixedSharedSelection.forward": {"tf": 1}, "shimmer.modules.selection.KQFixedQSelection.forward": {"tf": 1}, "shimmer.modules.selection.RandomSelection.forward": 
{"tf": 1}, "shimmer.modules.selection.DynamicQueryAttention.fuse_weighted_encodings": {"tf": 1}, "shimmer.modules.selection.DynamicQueryAttention.forward": {"tf": 1}}, "df": 7}}}}}}}}, "p": {"docs": {}, "df": 0, "s": {"docs": {"shimmer.modules.gw_module.compute_fusion_scores": {"tf": 1}}, "df": 1}}}, "w": {"docs": {}, "df": 0, "o": {"docs": {}, "df": 0, "r": {"docs": {}, "df": 0, "k": {"docs": {}, "df": 0, "s": {"docs": {}, "df": 0, "p": {"docs": {}, "df": 0, "a": {"docs": {}, "df": 0, "c": {"docs": {}, "df": 0, "e": {"docs": {"shimmer.modules.global_workspace.GlobalWorkspace2Domains.__init__": {"tf": 1.4142135623730951}, "shimmer.modules.global_workspace.GlobalWorkspace2Domains.forward": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspace.__init__": {"tf": 1.4142135623730951}, "shimmer.modules.global_workspace.GlobalWorkspace.forward": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBayesian.__init__": {"tf": 1.4142135623730951}, "shimmer.modules.global_workspace.GlobalWorkspaceBayesian.forward": {"tf": 1}, "shimmer.modules.global_workspace.pretrained_global_workspace": {"tf": 1.4142135623730951}, "shimmer.modules.gw_module.GWModuleBase.__init__": {"tf": 1}, "shimmer.modules.gw_module.GWModule.__init__": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBayesian.__init__": {"tf": 1}}, "df": 10}}}}}}}}, "e": {"docs": {}, "df": 0, "i": {"docs": {}, "df": 0, "g": {"docs": {}, "df": 0, "h": {"docs": {}, "df": 0, "t": {"docs": {"shimmer.modules.global_workspace.GlobalWorkspace2Domains.__init__": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspace.__init__": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBayesian.__init__": {"tf": 1}}, "df": 3}}}}}}, "o": {"docs": {}, "df": 0, "p": {"docs": {}, "df": 0, "t": {"docs": {}, "df": 0, "i": {"docs": {}, "df": 0, "m": {"docs": {"shimmer.modules.global_workspace.GlobalWorkspace2Domains.__init__": {"tf": 1.4142135623730951}, "shimmer.modules.global_workspace.GlobalWorkspace.__init__": 
{"tf": 1.4142135623730951}, "shimmer.modules.global_workspace.GlobalWorkspaceBayesian.__init__": {"tf": 1.4142135623730951}}, "df": 3}}}}, "u": {"docs": {}, "df": 0, "t": {"docs": {"shimmer.modules.gw_module.GWDecoder.__init__": {"tf": 1}, "shimmer.modules.gw_module.GWEncoder.__init__": {"tf": 1}}, "df": 2}}, "s": {"docs": {"shimmer.utils.migrate_model": {"tf": 1}}, "df": 1}}, "p": {"docs": {}, "df": 0, "r": {"docs": {}, "df": 0, "e": {"docs": {"shimmer.modules.selection.SelectionBase.forward": {"tf": 1}, "shimmer.modules.selection.SingleDomainSelection.forward": {"tf": 1}, "shimmer.modules.selection.FixedSharedSelection.forward": {"tf": 1}, "shimmer.modules.selection.KQFixedQSelection.forward": {"tf": 1}, "shimmer.modules.selection.RandomSelection.forward": {"tf": 1}, "shimmer.modules.selection.DynamicQueryAttention.forward": {"tf": 1}}, "df": 6, "c": {"docs": {}, "df": 0, "i": {"docs": {}, "df": 0, "s": {"docs": {}, "df": 0, "i": {"docs": {}, "df": 0, "o": {"docs": {}, "df": 0, "n": {"docs": {"shimmer.modules.global_workspace.GlobalWorkspaceBayesian.__init__": {"tf": 1.4142135623730951}, "shimmer.modules.gw_module.GWModuleBayesian.__init__": {"tf": 1.4142135623730951}}, "df": 2}}}}}}, "d": {"docs": {"shimmer.modules.domain.DomainModule.compute_loss": {"tf": 1}, "shimmer.modules.domain.DomainModule.compute_dcy_loss": {"tf": 1}, "shimmer.modules.domain.DomainModule.compute_cy_loss": {"tf": 1}, "shimmer.modules.domain.DomainModule.compute_tr_loss": {"tf": 1}, "shimmer.modules.domain.DomainModule.compute_broadcast_loss": {"tf": 1}}, "df": 5}}}, "a": {"docs": {}, "df": 0, "t": {"docs": {}, "df": 0, "h": {"docs": {"shimmer.modules.global_workspace.pretrained_global_workspace": {"tf": 1.4142135623730951}, "shimmer.utils.migrate_model": {"tf": 1}}, "df": 2, "l": {"docs": {}, "df": 0, "i": {"docs": {}, "df": 0, "b": {"docs": {"shimmer.modules.global_workspace.pretrained_global_workspace": {"tf": 1}}, "df": 1}, "k": {"docs": {}, "df": 0, "e": {"docs": 
{"shimmer.utils.migrate_model": {"tf": 1}}, "df": 1}}}}}}}, "y": {"docs": {}, "df": 0, "t": {"docs": {}, "df": 0, "o": {"docs": {}, "df": 0, "r": {"docs": {}, "df": 0, "c": {"docs": {}, "df": 0, "h": {"docs": {"shimmer.utils.SaveMigrations.on_save_checkpoint": {"tf": 1.4142135623730951}}, "df": 1}}}}}}, "l": {"docs": {"shimmer.utils.SaveMigrations.on_save_checkpoint": {"tf": 1}}, "df": 1}}, "u": {"docs": {}, "df": 0, "s": {"docs": {}, "df": 0, "e": {"docs": {"shimmer.modules.global_workspace.GlobalWorkspaceBayesian.__init__": {"tf": 1}, "shimmer.modules.losses.GWLossesBayesian.__init__": {"tf": 1}}, "df": 2}}}, "k": {"docs": {}, "df": 0, "w": {"docs": {}, "df": 0, "a": {"docs": {}, "df": 0, "r": {"docs": {}, "df": 0, "g": {"docs": {}, "df": 0, "s": {"docs": {"shimmer.modules.global_workspace.pretrained_global_workspace": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBase.__init__": {"tf": 1}, "shimmer.utils.migrate_model": {"tf": 1}}, "df": 3}}}}}}, "h": {"docs": {}, "df": 0, "i": {"docs": {}, "df": 0, "d": {"docs": {}, "df": 0, "d": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "n": {"docs": {"shimmer.modules.gw_module.get_n_layers": {"tf": 1}, "shimmer.modules.gw_module.GWDecoder.__init__": {"tf": 1}, "shimmer.modules.gw_module.GWEncoder.__init__": {"tf": 1}}, "df": 3}}}}}, "e": {"docs": {}, "df": 0, "a": {"docs": {}, "df": 0, "d": {"docs": {"shimmer.modules.selection.KQFixedQSelection.__init__": {"tf": 1}, "shimmer.modules.selection.DynamicQueryAttention.__init__": {"tf": 1}}, "df": 2}}}}, "y": {"docs": {"shimmer.modules.contrastive_loss.info_nce": {"tf": 1}, "shimmer.modules.contrastive_loss.contrastive_loss": {"tf": 1}, "shimmer.modules.contrastive_loss.ContrastiveLoss.forward": {"tf": 1}}, "df": 3}, "r": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "d": {"docs": {}, "df": 0, "u": {"docs": {}, "df": 0, "c": {"docs": {}, "df": 0, "t": {"docs": {}, "df": 0, "i": {"docs": {}, "df": 0, "o": {"docs": {}, "df": 0, "n": {"docs": 
{"shimmer.modules.contrastive_loss.info_nce": {"tf": 1}, "shimmer.modules.contrastive_loss.contrastive_loss": {"tf": 1}, "shimmer.modules.contrastive_loss.ContrastiveLoss.__init__": {"tf": 1}}, "df": 3}}}}}}}}}}}, "bases": {"root": {"docs": {"shimmer.modules.global_workspace.GlobalWorkspaceBase": {"tf": 1.4142135623730951}}, "df": 1, "t": {"docs": {"shimmer.modules.global_workspace.GlobalWorkspaceBase": {"tf": 1.7320508075688772}}, "df": 1, "y": {"docs": {}, "df": 0, "p": {"docs": {}, "df": 0, "i": {"docs": {}, "df": 0, "n": {"docs": {}, "df": 0, "g": {"docs": {"shimmer.modules.global_workspace.SchedulerArgs": {"tf": 1}, "shimmer.modules.global_workspace.GWPredictionsBase": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBase": {"tf": 1}, "shimmer.modules.losses.LossCoefs": {"tf": 1}, "shimmer.modules.losses.BroadcastLossCoefs": {"tf": 1}, "shimmer.dataset.RepeatedDataset": {"tf": 1}}, "df": 6}}}, "e": {"docs": {}, "df": 0, "d": {"docs": {}, "df": 0, "d": {"docs": {}, "df": 0, "i": {"docs": {}, "df": 0, "c": {"docs": {}, "df": 0, "t": {"docs": {"shimmer.modules.global_workspace.SchedulerArgs": {"tf": 1}, "shimmer.modules.global_workspace.GWPredictionsBase": {"tf": 1}, "shimmer.modules.losses.LossCoefs": {"tf": 1}, "shimmer.modules.losses.BroadcastLossCoefs": {"tf": 1}}, "df": 4}}}}}}}}, "o": {"docs": {}, "df": 0, "r": {"docs": {}, "df": 0, "c": {"docs": {}, "df": 0, "h": {"docs": {"shimmer.modules.gw_module.GWDecoder": {"tf": 1}, "shimmer.modules.gw_module.GWEncoderLinear": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBase": {"tf": 1}, "shimmer.modules.selection.SelectionBase": {"tf": 1}, "shimmer.modules.losses.GWLossesBase": {"tf": 1}, "shimmer.modules.contrastive_loss.ContrastiveLoss": {"tf": 1}, "shimmer.modules.vae.VAEEncoder": {"tf": 1}, "shimmer.modules.vae.VAEDecoder": {"tf": 1}, "shimmer.modules.vae.VAE": {"tf": 1}}, "df": 9}}}}}, "g": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "n": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "r": 
{"docs": {}, "df": 0, "i": {"docs": {}, "df": 0, "c": {"docs": {"shimmer.modules.global_workspace.GlobalWorkspaceBase": {"tf": 1}}, "df": 1, "[": {"docs": {}, "df": 0, "+": {"docs": {}, "df": 0, "t": {"docs": {"shimmer.dataset.RepeatedDataset": {"tf": 1}}, "df": 1}}}}}}}}}, "w": {"docs": {"shimmer.modules.global_workspace.GlobalWorkspaceBase": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspace2Domains": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspace": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBayesian": {"tf": 1}}, "df": 4, "m": {"docs": {}, "df": 0, "o": {"docs": {}, "df": 0, "d": {"docs": {}, "df": 0, "u": {"docs": {}, "df": 0, "l": {"docs": {}, "df": 0, "e": {"docs": {"shimmer.modules.global_workspace.GlobalWorkspace2Domains": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspace": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBayesian": {"tf": 1}}, "df": 3, "b": {"docs": {}, "df": 0, "a": {"docs": {}, "df": 0, "y": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "s": {"docs": {}, "df": 0, "i": {"docs": {}, "df": 0, "a": {"docs": {}, "df": 0, "n": {"docs": {"shimmer.modules.global_workspace.GlobalWorkspaceBayesian": {"tf": 1}}, "df": 1}}}}}}, "s": {"docs": {}, "df": 0, "e": {"docs": {"shimmer.modules.gw_module.GWModule": {"tf": 1}}, "df": 1}}}}}}}}}}, "l": {"docs": {}, "df": 0, "o": {"docs": {}, "df": 0, "s": {"docs": {}, "df": 0, "s": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "s": {"2": {"docs": {}, "df": 0, "d": {"docs": {}, "df": 0, "o": {"docs": {}, "df": 0, "m": {"docs": {}, "df": 0, "a": {"docs": {}, "df": 0, "i": {"docs": {}, "df": 0, "n": {"docs": {}, "df": 0, "s": {"docs": {"shimmer.modules.global_workspace.GlobalWorkspace2Domains": {"tf": 1}}, "df": 1}}}}}}}}, "docs": {"shimmer.modules.global_workspace.GlobalWorkspace": {"tf": 1}}, "df": 1, "b": {"docs": {}, "df": 0, "a": {"docs": {}, "df": 0, "y": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "s": {"docs": {}, "df": 0, "i": {"docs": {}, "df": 
0, "a": {"docs": {}, "df": 0, "n": {"docs": {"shimmer.modules.global_workspace.GlobalWorkspaceBayesian": {"tf": 1}}, "df": 1}}}}}}, "s": {"docs": {}, "df": 0, "e": {"docs": {"shimmer.modules.losses.GWLosses2Domains": {"tf": 1}, "shimmer.modules.losses.GWLosses": {"tf": 1}, "shimmer.modules.losses.GWLossesBayesian": {"tf": 1}}, "df": 3}}}}}}}}}}, "d": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "c": {"docs": {}, "df": 0, "o": {"docs": {}, "df": 0, "d": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "r": {"docs": {"shimmer.modules.gw_module.GWEncoder": {"tf": 1}}, "df": 1}}}}}}}}, "l": {"docs": {}, "df": 0, "o": {"docs": {}, "df": 0, "b": {"docs": {}, "df": 0, "a": {"docs": {}, "df": 0, "l": {"docs": {"shimmer.modules.global_workspace.GlobalWorkspace2Domains": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspace": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBayesian": {"tf": 1}}, "df": 3, "w": {"docs": {}, "df": 0, "o": {"docs": {}, "df": 0, "r": {"docs": {}, "df": 0, "k": {"docs": {}, "df": 0, "s": {"docs": {}, "df": 0, "p": {"docs": {}, "df": 0, "a": {"docs": {}, "df": 0, "c": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "b": {"docs": {}, "df": 0, "a": {"docs": {}, "df": 0, "s": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "[": {"docs": {}, "df": 0, "s": {"docs": {}, "df": 0, "h": {"docs": {}, "df": 0, "i": {"docs": {}, "df": 0, "m": {"docs": {}, "df": 0, "m": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "r": {"docs": {"shimmer.modules.global_workspace.GlobalWorkspace2Domains": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspace": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBayesian": {"tf": 1}}, "df": 3}}}}}}}}}}}}}}}}}}}}}}}}}}}, "m": {"docs": {}, "df": 0, "o": {"docs": {}, "df": 0, "d": {"docs": {"shimmer.modules.global_workspace.GlobalWorkspaceBase": {"tf": 1.7320508075688772}}, "df": 1, "u": {"docs": {}, "df": 0, "l": {"docs": {}, "df": 0, "e": {"docs": 
{"shimmer.modules.global_workspace.GlobalWorkspaceBase": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspace2Domains": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspace": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBayesian": {"tf": 1}, "shimmer.modules.domain.DomainModule": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBase": {"tf": 1.4142135623730951}, "shimmer.modules.selection.SelectionBase": {"tf": 1.4142135623730951}, "shimmer.modules.losses.GWLossesBase": {"tf": 1.4142135623730951}, "shimmer.modules.contrastive_loss.ContrastiveLoss": {"tf": 1.4142135623730951}, "shimmer.modules.vae.VAEEncoder": {"tf": 1.4142135623730951}, "shimmer.modules.vae.VAEDecoder": {"tf": 1.4142135623730951}, "shimmer.modules.vae.VAE": {"tf": 1.4142135623730951}}, "df": 12, "s": {"docs": {"shimmer.modules.global_workspace.GlobalWorkspace2Domains": {"tf": 2}, "shimmer.modules.global_workspace.GlobalWorkspace": {"tf": 2}, "shimmer.modules.global_workspace.GlobalWorkspaceBayesian": {"tf": 2}, "shimmer.modules.gw_module.GWDecoder": {"tf": 1}, "shimmer.modules.gw_module.GWEncoderLinear": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBase": {"tf": 1}, "shimmer.modules.selection.SelectionBase": {"tf": 1}, "shimmer.modules.losses.GWLossesBase": {"tf": 1}, "shimmer.modules.contrastive_loss.ContrastiveLoss": {"tf": 1}, "shimmer.modules.vae.VAEEncoder": {"tf": 1}, "shimmer.modules.vae.VAEDecoder": {"tf": 1}, "shimmer.modules.vae.VAE": {"tf": 1}}, "df": 12}}}}}}}, "s": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "l": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "c": {"docs": {}, "df": 0, "t": {"docs": {}, "df": 0, "i": {"docs": {}, "df": 0, "o": {"docs": {}, "df": 0, "n": {"docs": {"shimmer.modules.global_workspace.GlobalWorkspaceBase": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspace2Domains": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspace": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBayesian": {"tf": 1}}, 
"df": 4, "b": {"docs": {}, "df": 0, "a": {"docs": {}, "df": 0, "s": {"docs": {}, "df": 0, "e": {"docs": {"shimmer.modules.selection.SingleDomainSelection": {"tf": 1}, "shimmer.modules.selection.FixedSharedSelection": {"tf": 1}, "shimmer.modules.selection.KQFixedQSelection": {"tf": 1}, "shimmer.modules.selection.RandomSelection": {"tf": 1}, "shimmer.modules.selection.DynamicQueryAttention": {"tf": 1}}, "df": 5}}}}}}}}}}}, "q": {"docs": {}, "df": 0, "u": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "n": {"docs": {}, "df": 0, "t": {"docs": {}, "df": 0, "i": {"docs": {}, "df": 0, "a": {"docs": {}, "df": 0, "l": {"docs": {"shimmer.modules.gw_module.GWDecoder": {"tf": 1}}, "df": 1}}}}}}}}}, "h": {"docs": {}, "df": 0, "i": {"docs": {}, "df": 0, "m": {"docs": {}, "df": 0, "m": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "r": {"docs": {"shimmer.modules.global_workspace.GlobalWorkspace2Domains": {"tf": 1.7320508075688772}, "shimmer.modules.global_workspace.GlobalWorkspace": {"tf": 1.7320508075688772}, "shimmer.modules.global_workspace.GlobalWorkspaceBayesian": {"tf": 1.7320508075688772}}, "df": 3}}}}}}, "i": {"docs": {}, "df": 0, "n": {"docs": {}, "df": 0, "g": {"docs": {}, "df": 0, "l": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "d": {"docs": {}, "df": 0, "o": {"docs": {}, "df": 0, "m": {"docs": {}, "df": 0, "a": {"docs": {}, "df": 0, "i": {"docs": {}, "df": 0, "n": {"docs": {}, "df": 0, "s": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "l": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "c": {"docs": {}, "df": 0, "t": {"docs": {}, "df": 0, "i": {"docs": {}, "df": 0, "o": {"docs": {}, "df": 0, "n": {"docs": {"shimmer.modules.global_workspace.GlobalWorkspace2Domains": {"tf": 1}}, "df": 1}}}}}}}}}}}}}}}}}}}}}, "l": {"docs": {}, "df": 0, "o": {"docs": {}, "df": 0, "s": {"docs": {}, "df": 0, "s": {"docs": {"shimmer.modules.global_workspace.GlobalWorkspaceBase": {"tf": 1}}, "df": 1, "e": {"docs": {}, "df": 0, "s": {"docs": 
{"shimmer.modules.global_workspace.GlobalWorkspace2Domains": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspace": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBayesian": {"tf": 1}}, "df": 3}}}}}, "i": {"docs": {}, "df": 0, "g": {"docs": {}, "df": 0, "h": {"docs": {}, "df": 0, "t": {"docs": {}, "df": 0, "n": {"docs": {}, "df": 0, "i": {"docs": {}, "df": 0, "n": {"docs": {}, "df": 0, "g": {"docs": {"shimmer.modules.global_workspace.GlobalWorkspaceBase": {"tf": 1}, "shimmer.modules.domain.DomainModule": {"tf": 1}, "shimmer.utils.SaveMigrations": {"tf": 1}}, "df": 3, "m": {"docs": {}, "df": 0, "o": {"docs": {}, "df": 0, "d": {"docs": {}, "df": 0, "u": {"docs": {}, "df": 0, "l": {"docs": {}, "df": 0, "e": {"docs": {"shimmer.modules.global_workspace.GlobalWorkspaceBase": {"tf": 1}, "shimmer.modules.domain.DomainModule": {"tf": 1}}, "df": 2}}}}}}}}}}}}}, "n": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "a": {"docs": {}, "df": 0, "r": {"docs": {"shimmer.modules.gw_module.GWEncoderLinear": {"tf": 1.4142135623730951}}, "df": 1}}}}}}, "p": {"docs": {}, "df": 0, "y": {"docs": {}, "df": 0, "t": {"docs": {}, "df": 0, "o": {"docs": {}, "df": 0, "r": {"docs": {}, "df": 0, "c": {"docs": {}, "df": 0, "h": {"docs": {"shimmer.modules.global_workspace.GlobalWorkspaceBase": {"tf": 1}, "shimmer.modules.domain.DomainModule": {"tf": 1}, "shimmer.utils.SaveMigrations": {"tf": 1}}, "df": 3}}}}}}}, "c": {"docs": {}, "df": 0, "o": {"docs": {"shimmer.dataset.RepeatedDataset": {"tf": 1}}, "df": 1, "r": {"docs": {}, "df": 0, "e": {"docs": {"shimmer.modules.global_workspace.GlobalWorkspaceBase": {"tf": 1}, "shimmer.modules.domain.DomainModule": {"tf": 1}}, "df": 2}}, "n": {"docs": {}, "df": 0, "t": {"docs": {}, "df": 0, "a": {"docs": {}, "df": 0, "i": {"docs": {}, "df": 0, "n": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "r": {"docs": {"shimmer.modules.gw_module.GWDecoder": {"tf": 1}}, "df": 1}}}}}}}}, "a": {"docs": {}, "df": 0, "l": {"docs": {}, "df": 0, "l": 
{"docs": {}, "df": 0, "b": {"docs": {}, "df": 0, "a": {"docs": {}, "df": 0, "c": {"docs": {}, "df": 0, "k": {"docs": {"shimmer.utils.SaveMigrations": {"tf": 1.4142135623730951}}, "df": 1, "s": {"docs": {"shimmer.utils.SaveMigrations": {"tf": 1}}, "df": 1}}}}}}}}}, "b": {"docs": {}, "df": 0, "u": {"docs": {}, "df": 0, "i": {"docs": {}, "df": 0, "l": {"docs": {}, "df": 0, "t": {"docs": {}, "df": 0, "i": {"docs": {}, "df": 0, "n": {"docs": {}, "df": 0, "s": {"docs": {"shimmer.modules.global_workspace.GWPredictions": {"tf": 1}}, "df": 1}}}}}}}}, "d": {"docs": {}, "df": 0, "i": {"docs": {}, "df": 0, "c": {"docs": {}, "df": 0, "t": {"docs": {"shimmer.modules.global_workspace.GWPredictions": {"tf": 1}}, "df": 1}}}}, "w": {"docs": {}, "df": 0, "o": {"docs": {}, "df": 0, "r": {"docs": {}, "df": 0, "k": {"docs": {}, "df": 0, "s": {"docs": {}, "df": 0, "p": {"docs": {}, "df": 0, "a": {"docs": {}, "df": 0, "c": {"docs": {}, "df": 0, "e": {"docs": {"shimmer.modules.global_workspace.GlobalWorkspace2Domains": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspace": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBayesian": {"tf": 1}}, "df": 3}}}}}}}}}, "r": {"docs": {}, "df": 0, "a": {"docs": {}, "df": 0, "n": {"docs": {}, "df": 0, "d": {"docs": {}, "df": 0, "o": {"docs": {}, "df": 0, "m": {"docs": {}, "df": 0, "s": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "l": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "c": {"docs": {}, "df": 0, "t": {"docs": {}, "df": 0, "i": {"docs": {}, "df": 0, "o": {"docs": {}, "df": 0, "n": {"docs": {"shimmer.modules.global_workspace.GlobalWorkspace": {"tf": 1}}, "df": 1}}}}}}}}}}}}}}}, "f": {"docs": {}, "df": 0, "i": {"docs": {}, "df": 0, "x": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "d": {"docs": {}, "df": 0, "s": {"docs": {}, "df": 0, "h": {"docs": {}, "df": 0, "a": {"docs": {}, "df": 0, "r": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "d": {"docs": {}, "df": 0, "s": {"docs": {}, "df": 0, "e": {"docs": {}, 
"df": 0, "l": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "c": {"docs": {}, "df": 0, "t": {"docs": {}, "df": 0, "i": {"docs": {}, "df": 0, "o": {"docs": {}, "df": 0, "n": {"docs": {"shimmer.modules.global_workspace.GlobalWorkspaceBayesian": {"tf": 1}}, "df": 1}}}}}}}}}}}}}}}}}}}}, "n": {"docs": {}, "df": 0, "n": {"docs": {"shimmer.modules.gw_module.GWDecoder": {"tf": 1}, "shimmer.modules.gw_module.GWEncoderLinear": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBase": {"tf": 1}, "shimmer.modules.selection.SelectionBase": {"tf": 1}, "shimmer.modules.losses.GWLossesBase": {"tf": 1}, "shimmer.modules.contrastive_loss.ContrastiveLoss": {"tf": 1}, "shimmer.modules.vae.VAEEncoder": {"tf": 1}, "shimmer.modules.vae.VAEDecoder": {"tf": 1}, "shimmer.modules.vae.VAE": {"tf": 1}}, "df": 9}}, "a": {"docs": {}, "df": 0, "b": {"docs": {}, "df": 0, "c": {"docs": {"shimmer.modules.gw_module.GWModuleBase": {"tf": 1.4142135623730951}, "shimmer.modules.selection.SelectionBase": {"tf": 1.4142135623730951}, "shimmer.modules.losses.GWLossesBase": {"tf": 1.4142135623730951}, "shimmer.modules.vae.VAEEncoder": {"tf": 1.4142135623730951}, "shimmer.modules.vae.VAEDecoder": {"tf": 1.4142135623730951}}, "df": 5}}}}}, "doc": {"root": {"0": {"docs": {"shimmer.types.LatentsDomainGroupT": {"tf": 3}, "shimmer.types.LatentsDomainGroupDT": {"tf": 3}, "shimmer.types.LatentsDomainGroupsT": {"tf": 4}, "shimmer.types.LatentsDomainGroupsDT": {"tf": 4}, "shimmer.modules.gw_module.GWModuleBayesian.fuse": {"tf": 1.4142135623730951}, "shimmer.modules.selection.SelectionBase.forward": {"tf": 2.8284271247461903}, "shimmer.modules.selection.FixedSharedSelection": {"tf": 1}, "shimmer.modules.losses.LossCoefs": {"tf": 1.4142135623730951}, "shimmer.modules.losses.BroadcastLossCoefs": {"tf": 1.4142135623730951}}, "df": 9}, "1": {"docs": {"shimmer.types.LatentsDomainGroupT": {"tf": 1}, "shimmer.types.LatentsDomainGroupDT": {"tf": 1}, "shimmer.types.LatentsDomainGroupsT": {"tf": 1.7320508075688772}, 
"shimmer.types.LatentsDomainGroupsDT": {"tf": 1.7320508075688772}, "shimmer.modules.global_workspace.GlobalWorkspaceBayesian.__init__": {"tf": 1}, "shimmer.modules.gw_module.compute_fusion_scores": {"tf": 1.4142135623730951}, "shimmer.modules.gw_module.GWModuleBayesian.__init__": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBayesian.fuse": {"tf": 2.23606797749979}, "shimmer.modules.selection.SelectionBase.forward": {"tf": 1.4142135623730951}, "shimmer.modules.selection.SingleDomainSelection": {"tf": 1}, "shimmer.modules.losses.contrastive_loss": {"tf": 1.7320508075688772}, "shimmer.modules.losses.contrastive_loss_bayesian": {"tf": 1.7320508075688772}, "shimmer.modules.vae.VAE.__init__": {"tf": 1}, "shimmer.cli.ckpt_migration.migrate_ckpt": {"tf": 1}}, "df": 14, "]": {"docs": {}, "df": 0, "^": {"docs": {}, "df": 0, "{": {"docs": {}, "df": 0, "d": {"docs": {}, "df": 0, "\\": {"docs": {}, "df": 0, "t": {"docs": {}, "df": 0, "i": {"docs": {}, "df": 0, "m": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "s": {"docs": {"shimmer.modules.gw_module.GWModuleBayesian.fuse": {"tf": 1.4142135623730951}}, "df": 1}}}}}}}}}}, "}": {"docs": {}, "df": 0, "{": {"docs": {}, "df": 0, "\\": {"docs": {}, "df": 0, "s": {"docs": {}, "df": 0, "i": {"docs": {}, "df": 0, "g": {"docs": {}, "df": 0, "m": {"docs": {}, "df": 0, "a": {"docs": {"shimmer.modules.gw_module.GWModuleBayesian.fuse": {"tf": 1}}, "df": 1}}}}}}, "a": {"docs": {"shimmer.modules.gw_module.GWModuleBayesian.fuse": {"tf": 1.7320508075688772}}, "df": 1}}, "^": {"docs": {}, "df": 0, "d": {"docs": {"shimmer.modules.gw_module.GWModuleBayesian.fuse": {"tf": 1.7320508075688772}}, "df": 1}}}, "/": {"3": {"docs": {"shimmer.modules.selection.FixedSharedSelection": {"tf": 1}}, "df": 1}, "docs": {}, "df": 0}}, "2": {"docs": {"shimmer.types.LatentsDomainGroupT": {"tf": 1}, "shimmer.types.LatentsDomainGroupDT": {"tf": 1}, "shimmer.types.LatentsDomainGroupsT": {"tf": 1.4142135623730951}, "shimmer.types.LatentsDomainGroupsDT": {"tf": 
1.4142135623730951}, "shimmer.modules.global_workspace.GlobalWorkspace2Domains": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspace": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBayesian": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBayesian.__init__": {"tf": 1}, "shimmer.modules.gw_module.GWDecoder.__init__": {"tf": 1}, "shimmer.modules.gw_module.GWDecoder.n_layers": {"tf": 1}, "shimmer.modules.gw_module.GWEncoder.__init__": {"tf": 1}, "shimmer.modules.gw_module.compute_fusion_scores": {"tf": 1.4142135623730951}, "shimmer.modules.gw_module.GWModuleBayesian.__init__": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBayesian.fuse": {"tf": 2.23606797749979}, "shimmer.modules.selection.SingleDomainSelection": {"tf": 1}, "shimmer.modules.selection.FixedSharedSelection": {"tf": 1}, "shimmer.modules.losses.contrastive_loss": {"tf": 2}, "shimmer.modules.losses.contrastive_loss_bayesian": {"tf": 2}, "shimmer.cli.ckpt_migration.migrate_ckpt": {"tf": 1}}, "df": 19, "}": {"docs": {}, "df": 0, "{": {"docs": {}, "df": 0, "\\": {"docs": {}, "df": 0, "m": {"docs": {}, "df": 0, "u": {"docs": {"shimmer.modules.gw_module.GWModuleBayesian.fuse": {"tf": 1}}, "df": 1}}}, "b": {"docs": {"shimmer.modules.gw_module.GWModuleBayesian.fuse": {"tf": 1.7320508075688772}}, "df": 1}}}}, "3": {"docs": {"shimmer.types.LatentsDomainGroupT": {"tf": 1}, "shimmer.types.LatentsDomainGroupDT": {"tf": 1}, "shimmer.types.LatentsDomainGroupsT": {"tf": 1.4142135623730951}, "shimmer.types.LatentsDomainGroupsDT": {"tf": 1.4142135623730951}, "shimmer.modules.gw_module.GWModuleBayesian.fuse": {"tf": 1}, "shimmer.modules.selection.SelectionBase.forward": {"tf": 1.4142135623730951}, "shimmer.modules.selection.FixedSharedSelection": {"tf": 1}}, "df": 7, "}": {"docs": {}, "df": 0, "{": {"docs": {}, "df": 0, "\\": {"docs": {}, "df": 0, "f": {"docs": {}, "df": 0, "r": {"docs": {}, "df": 0, "a": {"docs": {}, "df": 0, "c": {"docs": {}, "df": 0, "{": {"docs": {}, "df": 0, "c": 
{"docs": {"shimmer.modules.gw_module.GWModuleBayesian.fuse": {"tf": 1}}, "df": 1}}}}}}}}}}, "4": {"docs": {"shimmer.modules.selection.SelectionBase.forward": {"tf": 1.4142135623730951}}, "df": 1}, "5": {"0": {"docs": {"shimmer.modules.selection.SingleDomainSelection": {"tf": 1}}, "df": 1}, "docs": {"shimmer.modules.selection.FixedSharedSelection": {"tf": 1}}, "df": 1}, "6": {"docs": {"shimmer.modules.selection.SelectionBase.forward": {"tf": 1}}, "df": 1}, "8": {"docs": {"shimmer.modules.selection.SelectionBase.forward": {"tf": 1}}, "df": 1}, "9": {"docs": {"shimmer.types.LatentsDomainGroupsT": {"tf": 1}, "shimmer.types.LatentsDomainGroupsDT": {"tf": 1}}, "df": 2}, "docs": {"shimmer.types": {"tf": 1.7320508075688772}, "shimmer.types.RawDomainGroupT": {"tf": 10.816653826391969}, "shimmer.types.RawDomainGroupDT": {"tf": 9.746794344808963}, "shimmer.types.LatentsDomainGroupT": {"tf": 12.24744871391589}, "shimmer.types.LatentsDomainGroupDT": {"tf": 11.357816691600547}, "shimmer.types.RawDomainGroupsT": {"tf": 13.92838827718412}, "shimmer.types.RawDomainGroupsDT": {"tf": 13.92838827718412}, "shimmer.types.LatentsDomainGroupsT": {"tf": 16.822603841260722}, "shimmer.types.LatentsDomainGroupsDT": {"tf": 16.186414056238647}, "shimmer.types.ModelModeT": {"tf": 2.449489742783178}, "shimmer.modules.global_workspace": {"tf": 1.7320508075688772}, "shimmer.modules.global_workspace.SchedulerArgs": {"tf": 1.4142135623730951}, "shimmer.modules.global_workspace.SchedulerArgs.max_lr": {"tf": 1.4142135623730951}, "shimmer.modules.global_workspace.SchedulerArgs.total_steps": {"tf": 1.4142135623730951}, "shimmer.modules.global_workspace.GWPredictionsBase": {"tf": 2}, "shimmer.modules.global_workspace.GWPredictionsBase.states": {"tf": 1.7320508075688772}, "shimmer.modules.global_workspace.GlobalWorkspaceBase": {"tf": 2.449489742783178}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.gw_mod": {"tf": 2.23606797749979}, 
"shimmer.modules.global_workspace.GlobalWorkspaceBase.selection_mod": {"tf": 2.23606797749979}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.loss_mod": {"tf": 1.4142135623730951}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.optim_lr": {"tf": 1.7320508075688772}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.optim_weight_decay": {"tf": 1.7320508075688772}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.scheduler_args": {"tf": 1.7320508075688772}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.domain_mods": {"tf": 1.7320508075688772}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.workspace_dim": {"tf": 1.7320508075688772}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.encode_and_fuse": {"tf": 6}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.encode": {"tf": 5.291502622129181}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.fuse": {"tf": 5.744562646538029}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.decode": {"tf": 6.164414002968976}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.batch_gw_states": {"tf": 5.196152422706632}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.encode_domain": {"tf": 8.426149773176359}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.encode_domains": {"tf": 5.291502622129181}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.decode_domain": {"tf": 8.426149773176359}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.decode_domains": {"tf": 5.291502622129181}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.generic_step": {"tf": 6.48074069840786}, "shimmer.modules.global_workspace.freeze_domain_modules": {"tf": 6.164414002968976}, "shimmer.modules.global_workspace.GWPredictions": {"tf": 2}, "shimmer.modules.global_workspace.GWPredictions.demi_cycles": {"tf": 1.7320508075688772}, "shimmer.modules.global_workspace.GWPredictions.cycles": {"tf": 1.7320508075688772}, "shimmer.modules.global_workspace.GWPredictions.translations": 
{"tf": 2.449489742783178}, "shimmer.modules.global_workspace.GWPredictions.states": {"tf": 1.7320508075688772}, "shimmer.modules.global_workspace.GlobalWorkspace2Domains": {"tf": 3.1622776601683795}, "shimmer.modules.global_workspace.GlobalWorkspace2Domains.__init__": {"tf": 9.695359714832659}, "shimmer.modules.global_workspace.GlobalWorkspace2Domains.forward": {"tf": 5.291502622129181}, "shimmer.modules.global_workspace.GlobalWorkspace": {"tf": 3.1622776601683795}, "shimmer.modules.global_workspace.GlobalWorkspace.__init__": {"tf": 10.14889156509222}, "shimmer.modules.global_workspace.GlobalWorkspace.forward": {"tf": 5.291502622129181}, "shimmer.modules.global_workspace.GlobalWorkspaceBayesian": {"tf": 3.1622776601683795}, "shimmer.modules.global_workspace.GlobalWorkspaceBayesian.__init__": {"tf": 11.045361017187261}, "shimmer.modules.global_workspace.GlobalWorkspaceBayesian.forward": {"tf": 5.291502622129181}, "shimmer.modules.global_workspace.pretrained_global_workspace": {"tf": 10.535653752852738}, "shimmer.modules.domain": {"tf": 1.7320508075688772}, "shimmer.modules.domain.LossOutput": {"tf": 2.6457513110645907}, "shimmer.modules.domain.LossOutput.__init__": {"tf": 1.7320508075688772}, "shimmer.modules.domain.LossOutput.loss": {"tf": 1.7320508075688772}, "shimmer.modules.domain.LossOutput.metrics": {"tf": 1.7320508075688772}, "shimmer.modules.domain.LossOutput.all": {"tf": 1.7320508075688772}, "shimmer.modules.domain.DomainModule": {"tf": 1.7320508075688772}, "shimmer.modules.domain.DomainModule.__init__": {"tf": 4}, "shimmer.modules.domain.DomainModule.latent_dim": {"tf": 1.7320508075688772}, "shimmer.modules.domain.DomainModule.encode": {"tf": 5.291502622129181}, "shimmer.modules.domain.DomainModule.decode": {"tf": 5.291502622129181}, "shimmer.modules.domain.DomainModule.compute_loss": {"tf": 5.830951894845301}, "shimmer.modules.domain.DomainModule.compute_dcy_loss": {"tf": 5.830951894845301}, "shimmer.modules.domain.DomainModule.compute_cy_loss": {"tf": 
5.830951894845301}, "shimmer.modules.domain.DomainModule.compute_tr_loss": {"tf": 5.830951894845301}, "shimmer.modules.domain.DomainModule.compute_broadcast_loss": {"tf": 5.830951894845301}, "shimmer.modules.gw_module": {"tf": 1.7320508075688772}, "shimmer.modules.gw_module.get_n_layers": {"tf": 6.324555320336759}, "shimmer.modules.gw_module.GWDecoder": {"tf": 1.7320508075688772}, "shimmer.modules.gw_module.GWDecoder.__init__": {"tf": 6.4031242374328485}, "shimmer.modules.gw_module.GWDecoder.in_dim": {"tf": 1.4142135623730951}, "shimmer.modules.gw_module.GWDecoder.hidden_dim": {"tf": 1.4142135623730951}, "shimmer.modules.gw_module.GWDecoder.out_dim": {"tf": 1.4142135623730951}, "shimmer.modules.gw_module.GWDecoder.n_layers": {"tf": 2.449489742783178}, "shimmer.modules.gw_module.GWEncoder": {"tf": 2.449489742783178}, "shimmer.modules.gw_module.GWEncoder.__init__": {"tf": 6.4031242374328485}, "shimmer.modules.gw_module.GWEncoder.forward": {"tf": 3.872983346207417}, "shimmer.modules.gw_module.GWEncoderLinear": {"tf": 1.7320508075688772}, "shimmer.modules.gw_module.GWEncoderLinear.forward": {"tf": 3.872983346207417}, "shimmer.modules.gw_module.GWModuleBase": {"tf": 3.7416573867739413}, "shimmer.modules.gw_module.GWModuleBase.__init__": {"tf": 5}, "shimmer.modules.gw_module.GWModuleBase.domain_mods": {"tf": 1.7320508075688772}, "shimmer.modules.gw_module.GWModuleBase.workspace_dim": {"tf": 1.4142135623730951}, "shimmer.modules.gw_module.GWModuleBase.fuse": {"tf": 6}, "shimmer.modules.gw_module.GWModuleBase.encode": {"tf": 5.0990195135927845}, "shimmer.modules.gw_module.GWModuleBase.encode_and_fuse": {"tf": 5.916079783099616}, "shimmer.modules.gw_module.GWModuleBase.decode": {"tf": 6.164414002968976}, "shimmer.modules.gw_module.GWModule": {"tf": 2.23606797749979}, "shimmer.modules.gw_module.GWModule.__init__": {"tf": 6.4031242374328485}, "shimmer.modules.gw_module.GWModule.gw_encoders": {"tf": 1.4142135623730951}, "shimmer.modules.gw_module.GWModule.gw_decoders": {"tf": 
1.4142135623730951}, "shimmer.modules.gw_module.GWModule.fuse": {"tf": 6}, "shimmer.modules.gw_module.GWModule.encode": {"tf": 5.196152422706632}, "shimmer.modules.gw_module.GWModule.decode": {"tf": 6.082762530298219}, "shimmer.modules.gw_module.compute_fusion_scores": {"tf": 7.681145747868608}, "shimmer.modules.gw_module.GWModuleBayesian": {"tf": 2.23606797749979}, "shimmer.modules.gw_module.GWModuleBayesian.__init__": {"tf": 7.874007874011811}, "shimmer.modules.gw_module.GWModuleBayesian.precisions": {"tf": 1.7320508075688772}, "shimmer.modules.gw_module.GWModuleBayesian.sensitivity_selection": {"tf": 1.7320508075688772}, "shimmer.modules.gw_module.GWModuleBayesian.sensitivity_precision": {"tf": 1.7320508075688772}, "shimmer.modules.gw_module.GWModuleBayesian.precision_softmax_temp": {"tf": 1.7320508075688772}, "shimmer.modules.gw_module.GWModuleBayesian.get_precision": {"tf": 5.744562646538029}, "shimmer.modules.gw_module.GWModuleBayesian.fuse": {"tf": 8.717797887081348}, "shimmer.modules.selection": {"tf": 1.7320508075688772}, "shimmer.modules.selection.SelectionBase": {"tf": 2.23606797749979}, "shimmer.modules.selection.SelectionBase.update_gw_state": {"tf": 4.358898943540674}, "shimmer.modules.selection.SelectionBase.forward": {"tf": 11.269427669584644}, "shimmer.modules.selection.SingleDomainSelection": {"tf": 2.449489742783178}, "shimmer.modules.selection.SingleDomainSelection.forward": {"tf": 5.830951894845301}, "shimmer.modules.selection.FixedSharedSelection": {"tf": 2.449489742783178}, "shimmer.modules.selection.FixedSharedSelection.forward": {"tf": 5.830951894845301}, "shimmer.modules.selection.KQFixedQSelection": {"tf": 1.7320508075688772}, "shimmer.modules.selection.KQFixedQSelection.__init__": {"tf": 5.385164807134504}, "shimmer.modules.selection.KQFixedQSelection.head_size": {"tf": 1.7320508075688772}, "shimmer.modules.selection.KQFixedQSelection.query_layer": {"tf": 1.7320508075688772}, "shimmer.modules.selection.KQFixedQSelection.key_layers": 
{"tf": 1.7320508075688772}, "shimmer.modules.selection.KQFixedQSelection.forward": {"tf": 6}, "shimmer.modules.selection.RandomSelection": {"tf": 1.7320508075688772}, "shimmer.modules.selection.RandomSelection.__init__": {"tf": 3.7416573867739413}, "shimmer.modules.selection.RandomSelection.temperature": {"tf": 1.7320508075688772}, "shimmer.modules.selection.RandomSelection.forward": {"tf": 5.291502622129181}, "shimmer.modules.selection.DynamicQueryAttention": {"tf": 1.7320508075688772}, "shimmer.modules.selection.DynamicQueryAttention.__init__": {"tf": 5.385164807134504}, "shimmer.modules.selection.DynamicQueryAttention.head_size": {"tf": 1.7320508075688772}, "shimmer.modules.selection.DynamicQueryAttention.query_layer": {"tf": 1.7320508075688772}, "shimmer.modules.selection.DynamicQueryAttention.key_layers": {"tf": 1.7320508075688772}, "shimmer.modules.selection.DynamicQueryAttention.fuse_weighted_encodings": {"tf": 5.916079783099616}, "shimmer.modules.selection.DynamicQueryAttention.forward": {"tf": 6}, "shimmer.modules.losses": {"tf": 1.7320508075688772}, "shimmer.modules.losses.GWLossesBase": {"tf": 1.7320508075688772}, "shimmer.modules.losses.GWLossesBase.step": {"tf": 5.830951894845301}, "shimmer.modules.losses.demi_cycle_loss": {"tf": 8.774964387392123}, "shimmer.modules.losses.cycle_loss": {"tf": 8.774964387392123}, "shimmer.modules.losses.translation_loss": {"tf": 8.366600265340756}, "shimmer.modules.losses.contrastive_loss": {"tf": 8.366600265340756}, "shimmer.modules.losses.contrastive_loss_bayesian": {"tf": 8.366600265340756}, "shimmer.modules.losses.LossCoefs": {"tf": 2.449489742783178}, "shimmer.modules.losses.LossCoefs.demi_cycles": {"tf": 1.7320508075688772}, "shimmer.modules.losses.LossCoefs.cycles": {"tf": 1.7320508075688772}, "shimmer.modules.losses.LossCoefs.translations": {"tf": 1.7320508075688772}, "shimmer.modules.losses.LossCoefs.contrastives": {"tf": 1.7320508075688772}, "shimmer.modules.losses.GWLosses2Domains": {"tf": 
2.6457513110645907}, "shimmer.modules.losses.GWLosses2Domains.__init__": {"tf": 6.6332495807108}, "shimmer.modules.losses.GWLosses2Domains.gw_mod": {"tf": 1.7320508075688772}, "shimmer.modules.losses.GWLosses2Domains.selection_mod": {"tf": 1.7320508075688772}, "shimmer.modules.losses.GWLosses2Domains.domain_mods": {"tf": 1.7320508075688772}, "shimmer.modules.losses.GWLosses2Domains.loss_coefs": {"tf": 1.7320508075688772}, "shimmer.modules.losses.GWLosses2Domains.contrastive_fn": {"tf": 1.7320508075688772}, "shimmer.modules.losses.GWLosses2Domains.demi_cycle_loss": {"tf": 5.656854249492381}, "shimmer.modules.losses.GWLosses2Domains.cycle_loss": {"tf": 5.656854249492381}, "shimmer.modules.losses.GWLosses2Domains.translation_loss": {"tf": 5.656854249492381}, "shimmer.modules.losses.GWLosses2Domains.contrastive_loss": {"tf": 5.656854249492381}, "shimmer.modules.losses.GWLosses2Domains.step": {"tf": 7.874007874011811}, "shimmer.modules.losses.generate_partitions": {"tf": 5.5677643628300215}, "shimmer.modules.losses.broadcast_loss": {"tf": 6.6332495807108}, "shimmer.modules.losses.BroadcastLossCoefs": {"tf": 2.449489742783178}, "shimmer.modules.losses.BroadcastLossCoefs.contrastives": {"tf": 1.7320508075688772}, "shimmer.modules.losses.BroadcastLossCoefs.fused": {"tf": 1.7320508075688772}, "shimmer.modules.losses.BroadcastLossCoefs.demi_cycles": {"tf": 1.4142135623730951}, "shimmer.modules.losses.BroadcastLossCoefs.cycles": {"tf": 1.4142135623730951}, "shimmer.modules.losses.BroadcastLossCoefs.translations": {"tf": 1.7320508075688772}, "shimmer.modules.losses.GWLosses": {"tf": 2.23606797749979}, "shimmer.modules.losses.GWLosses.__init__": {"tf": 6}, "shimmer.modules.losses.GWLosses.gw_mod": {"tf": 1.7320508075688772}, "shimmer.modules.losses.GWLosses.selection_mod": {"tf": 1.7320508075688772}, "shimmer.modules.losses.GWLosses.domain_mods": {"tf": 1.7320508075688772}, "shimmer.modules.losses.GWLosses.loss_coefs": {"tf": 1.7320508075688772}, 
"shimmer.modules.losses.GWLosses.contrastive_fn": {"tf": 1.7320508075688772}, "shimmer.modules.losses.GWLosses.contrastive_loss": {"tf": 4.795831523312719}, "shimmer.modules.losses.GWLosses.broadcast_loss": {"tf": 1.7320508075688772}, "shimmer.modules.losses.GWLosses.step": {"tf": 5.291502622129181}, "shimmer.modules.losses.GWLossesBayesian": {"tf": 2.6457513110645907}, "shimmer.modules.losses.GWLossesBayesian.__init__": {"tf": 7.0710678118654755}, "shimmer.modules.losses.GWLossesBayesian.gw_mod": {"tf": 1.7320508075688772}, "shimmer.modules.losses.GWLossesBayesian.selection_mod": {"tf": 1.4142135623730951}, "shimmer.modules.losses.GWLossesBayesian.domain_mods": {"tf": 1.7320508075688772}, "shimmer.modules.losses.GWLossesBayesian.loss_coefs": {"tf": 1.7320508075688772}, "shimmer.modules.losses.GWLossesBayesian.contrastive_fn": {"tf": 1.7320508075688772}, "shimmer.modules.losses.GWLossesBayesian.use_normalized_constrastive": {"tf": 1.7320508075688772}, "shimmer.modules.losses.GWLossesBayesian.contrastive_loss": {"tf": 5.196152422706632}, "shimmer.modules.losses.GWLossesBayesian.broadcast_loss": {"tf": 1.7320508075688772}, "shimmer.modules.losses.GWLossesBayesian.step": {"tf": 5.291502622129181}, "shimmer.modules.contrastive_loss": {"tf": 1.4142135623730951}, "shimmer.modules.contrastive_loss.ContrastiveLossType": {"tf": 2.449489742783178}, "shimmer.modules.contrastive_loss.ContrastiveLossBayesianType": {"tf": 2.449489742783178}, "shimmer.modules.contrastive_loss.info_nce": {"tf": 6.244997998398398}, "shimmer.modules.contrastive_loss.contrastive_loss": {"tf": 6.244997998398398}, "shimmer.modules.contrastive_loss.ContrastiveLoss": {"tf": 1.7320508075688772}, "shimmer.modules.contrastive_loss.ContrastiveLoss.__init__": {"tf": 6.48074069840786}, "shimmer.modules.contrastive_loss.ContrastiveLoss.learn_logit_scale": {"tf": 1.7320508075688772}, "shimmer.modules.contrastive_loss.ContrastiveLoss.reduction": {"tf": 1.7320508075688772}, 
"shimmer.modules.contrastive_loss.ContrastiveLoss.forward": {"tf": 5.830951894845301}, "shimmer.dataset": {"tf": 1.7320508075688772}, "shimmer.dataset.RepeatedDataset": {"tf": 2.23606797749979}, "shimmer.dataset.RepeatedDataset.__init__": {"tf": 5.196152422706632}, "shimmer.dataset.RepeatedDataset.dataset": {"tf": 1.7320508075688772}, "shimmer.dataset.RepeatedDataset.dataset_size": {"tf": 1.7320508075688772}, "shimmer.modules.vae": {"tf": 1.7320508075688772}, "shimmer.modules.vae.reparameterize": {"tf": 5.744562646538029}, "shimmer.modules.vae.kl_divergence_loss": {"tf": 5.744562646538029}, "shimmer.modules.vae.gaussian_nll": {"tf": 6.324555320336759}, "shimmer.modules.vae.VAEEncoder": {"tf": 1.7320508075688772}, "shimmer.modules.vae.VAEEncoder.forward": {"tf": 5.0990195135927845}, "shimmer.modules.vae.VAEDecoder": {"tf": 1.7320508075688772}, "shimmer.modules.vae.VAEDecoder.forward": {"tf": 5}, "shimmer.modules.vae.VAE": {"tf": 1.4142135623730951}, "shimmer.modules.vae.VAE.__init__": {"tf": 5.5677643628300215}, "shimmer.modules.vae.VAE.beta": {"tf": 1.4142135623730951}, "shimmer.modules.vae.VAE.encoder": {"tf": 1.4142135623730951}, "shimmer.modules.vae.VAE.decoder": {"tf": 1.4142135623730951}, "shimmer.modules.vae.VAE.encode": {"tf": 5.196152422706632}, "shimmer.modules.vae.VAE.decode": {"tf": 5.291502622129181}, "shimmer.modules.vae.VAE.forward": {"tf": 5.196152422706632}, "shimmer.modules.utils": {"tf": 1.7320508075688772}, "shimmer.modules.utils.translation": {"tf": 6.928203230275509}, "shimmer.modules.utils.cycle": {"tf": 7.3484692283495345}, "shimmer.modules.utils.batch_demi_cycles": {"tf": 6.4031242374328485}, "shimmer.modules.utils.batch_cycles": {"tf": 7}, "shimmer.modules.utils.batch_translations": {"tf": 6.4031242374328485}, "shimmer.utils": {"tf": 1.7320508075688772}, "shimmer.utils.MIGRATION_DIR": {"tf": 1.7320508075688772}, "shimmer.utils.group_batch_size": {"tf": 1.7320508075688772}, "shimmer.utils.groups_batch_size": {"tf": 5.0990195135927845}, 
"shimmer.utils.groups_device": {"tf": 5.0990195135927845}, "shimmer.utils.group_device": {"tf": 1.7320508075688772}, "shimmer.utils.migrate_model": {"tf": 4.898979485566356}, "shimmer.utils.SaveMigrations": {"tf": 2.23606797749979}, "shimmer.utils.SaveMigrations.migrations": {"tf": 1.7320508075688772}, "shimmer.utils.SaveMigrations.on_save_checkpoint": {"tf": 5.291502622129181}, "shimmer.cli.ckpt_migration": {"tf": 1.7320508075688772}, "shimmer.cli.ckpt_migration.migrate_ckpt": {"tf": 5.744562646538029}}, "df": 232, "m": {"docs": {"shimmer.modules.gw_module.GWModuleBayesian.fuse": {"tf": 2.23606797749979}}, "df": 1, "a": {"docs": {}, "df": 0, "t": {"docs": {}, "df": 0, "c": {"docs": {}, "df": 0, "h": {"docs": {"shimmer.modules.selection.KQFixedQSelection.forward": {"tf": 1}, "shimmer.modules.selection.DynamicQueryAttention.forward": {"tf": 1}}, "df": 2, "e": {"docs": {}, "df": 0, "d": {"docs": {"shimmer.types.RawDomainGroupT": {"tf": 1.4142135623730951}, "shimmer.types.RawDomainGroupDT": {"tf": 1}, "shimmer.types.LatentsDomainGroupT": {"tf": 1}, "shimmer.types.LatentsDomainGroupDT": {"tf": 1}, "shimmer.types.RawDomainGroupsT": {"tf": 1}, "shimmer.types.RawDomainGroupsDT": {"tf": 1}, "shimmer.types.LatentsDomainGroupsT": {"tf": 1}, "shimmer.types.LatentsDomainGroupsDT": {"tf": 1}}, "df": 8}}}}}, "p": {"docs": {}, "df": 0, "p": {"docs": {}, "df": 0, "i": {"docs": {}, "df": 0, "n": {"docs": {}, "df": 0, "g": {"docs": {"shimmer.types.RawDomainGroupT": {"tf": 1.7320508075688772}, "shimmer.types.RawDomainGroupDT": {"tf": 1}, "shimmer.types.LatentsDomainGroupT": {"tf": 1.4142135623730951}, "shimmer.types.RawDomainGroupsT": {"tf": 1}, "shimmer.types.RawDomainGroupsDT": {"tf": 1}, "shimmer.types.LatentsDomainGroupsT": {"tf": 1.4142135623730951}, "shimmer.types.LatentsDomainGroupsDT": {"tf": 1}, "shimmer.modules.global_workspace.freeze_domain_modules": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspace2Domains.__init__": {"tf": 1.7320508075688772}, 
"shimmer.modules.global_workspace.GlobalWorkspace.__init__": {"tf": 1.7320508075688772}, "shimmer.modules.global_workspace.GlobalWorkspaceBayesian.__init__": {"tf": 1.7320508075688772}, "shimmer.modules.global_workspace.pretrained_global_workspace": {"tf": 1.7320508075688772}, "shimmer.modules.gw_module.GWModule.__init__": {"tf": 1.4142135623730951}, "shimmer.modules.gw_module.GWModuleBayesian.__init__": {"tf": 1.4142135623730951}, "shimmer.modules.losses.GWLosses2Domains.__init__": {"tf": 1}, "shimmer.modules.losses.GWLosses.__init__": {"tf": 1}}, "df": 16, "s": {"docs": {"shimmer.types.RawDomainGroupT": {"tf": 1}, "shimmer.types.LatentsDomainGroupT": {"tf": 1}, "shimmer.types.LatentsDomainGroupsT": {"tf": 1}}, "df": 3}, "[": {"docs": {}, "df": 0, "s": {"docs": {}, "df": 0, "t": {"docs": {}, "df": 0, "r": {"docs": {"shimmer.modules.global_workspace.GlobalWorkspaceBase.encode_and_fuse": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.fuse": {"tf": 1}, "shimmer.modules.global_workspace.freeze_domain_modules": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspace2Domains.__init__": {"tf": 1.7320508075688772}, "shimmer.modules.global_workspace.GlobalWorkspace.__init__": {"tf": 1.7320508075688772}, "shimmer.modules.global_workspace.GlobalWorkspaceBayesian.__init__": {"tf": 1.7320508075688772}, "shimmer.modules.global_workspace.pretrained_global_workspace": {"tf": 1.7320508075688772}, "shimmer.modules.gw_module.GWModuleBase.__init__": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBase.fuse": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBase.encode_and_fuse": {"tf": 1}, "shimmer.modules.gw_module.GWModule.__init__": {"tf": 1.7320508075688772}, "shimmer.modules.gw_module.GWModule.fuse": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBayesian.__init__": {"tf": 1.7320508075688772}, "shimmer.modules.gw_module.GWModuleBayesian.fuse": {"tf": 1}, "shimmer.modules.losses.demi_cycle_loss": {"tf": 1}, "shimmer.modules.losses.cycle_loss": {"tf": 1}, 
"shimmer.modules.losses.translation_loss": {"tf": 1}, "shimmer.modules.losses.broadcast_loss": {"tf": 1}}, "df": 18}}}, "f": {"docs": {}, "df": 0, "r": {"docs": {}, "df": 0, "o": {"docs": {}, "df": 0, "z": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "n": {"docs": {}, "df": 0, "s": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "t": {"docs": {}, "df": 0, "[": {"docs": {}, "df": 0, "s": {"docs": {}, "df": 0, "t": {"docs": {}, "df": 0, "r": {"docs": {"shimmer.modules.global_workspace.GlobalWorkspaceBase.fuse": {"tf": 1}}, "df": 1}}}}}}}}}}}}}}}}}}}, "x": {"docs": {"shimmer.modules.global_workspace.GlobalWorkspace2Domains": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBayesian": {"tf": 1}}, "df": 2, "i": {"docs": {}, "df": 0, "m": {"docs": {}, "df": 0, "u": {"docs": {}, "df": 0, "m": {"docs": {"shimmer.modules.global_workspace.SchedulerArgs.max_lr": {"tf": 1}}, "df": 1}}}}}, "k": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "s": {"docs": {"shimmer.modules.gw_module.get_n_layers": {"tf": 1}}, "df": 1}}}, "y": {"docs": {"shimmer.modules.selection.SelectionBase.update_gw_state": {"tf": 1}}, "df": 1}, "i": {"docs": {}, "df": 0, "n": {"docs": {"shimmer.modules.losses.GWLosses2Domains.__init__": {"tf": 1}}, "df": 1}}, "n": {"docs": {}, "df": 0, "y": {"docs": {"shimmer.modules.losses.BroadcastLossCoefs.cycles": {"tf": 1}, "shimmer.modules.losses.BroadcastLossCoefs.translations": {"tf": 1}}, "df": 2}}}, "u": {"docs": {"shimmer.modules.gw_module.GWModuleBayesian.fuse": {"tf": 1}, "shimmer.modules.vae.gaussian_nll": {"tf": 1}}, "df": 2, "l": {"docs": {}, "df": 0, "t": {"docs": {}, "df": 0, "i": {"docs": {}, "df": 0, "p": {"docs": {}, "df": 0, "l": {"docs": {}, "df": 0, "e": {"docs": {"shimmer.types.RawDomainGroupT": {"tf": 1}, "shimmer.types.RawDomainGroupDT": {"tf": 1}, "shimmer.types.LatentsDomainGroupT": {"tf": 1}, "shimmer.types.LatentsDomainGroupDT": {"tf": 1}, "shimmer.modules.gw_module.GWModule.decode": {"tf": 1}, 
"shimmer.modules.selection.SingleDomainSelection": {"tf": 1}, "shimmer.modules.losses.demi_cycle_loss": {"tf": 1}, "shimmer.modules.losses.cycle_loss": {"tf": 1}, "shimmer.modules.losses.translation_loss": {"tf": 1}, "shimmer.modules.losses.contrastive_loss": {"tf": 1}, "shimmer.modules.losses.contrastive_loss_bayesian": {"tf": 1}, "shimmer.modules.losses.BroadcastLossCoefs.fused": {"tf": 1}, "shimmer.modules.utils.translation": {"tf": 1}}, "df": 13}}}}}}}, "o": {"docs": {}, "df": 0, "r": {"docs": {}, "df": 0, "e": {"docs": {"shimmer.types.RawDomainGroupT": {"tf": 1}, "shimmer.types.LatentsDomainGroupT": {"tf": 1}, "shimmer.types.LatentsDomainGroupsT": {"tf": 1}, "shimmer.modules.global_workspace.GWPredictions.cycles": {"tf": 1}, "shimmer.modules.global_workspace.GWPredictions.translations": {"tf": 1}}, "df": 5}}, "d": {"docs": {"shimmer.modules.losses.demi_cycle_loss": {"tf": 2}, "shimmer.modules.losses.cycle_loss": {"tf": 2}, "shimmer.modules.losses.translation_loss": {"tf": 1.4142135623730951}, "shimmer.modules.losses.contrastive_loss": {"tf": 1.4142135623730951}, "shimmer.modules.losses.contrastive_loss_bayesian": {"tf": 1.4142135623730951}, "shimmer.modules.losses.GWLosses2Domains.__init__": {"tf": 1.4142135623730951}, "shimmer.modules.losses.broadcast_loss": {"tf": 1.7320508075688772}, "shimmer.modules.losses.GWLosses.__init__": {"tf": 1.4142135623730951}, "shimmer.modules.losses.GWLossesBayesian.__init__": {"tf": 1.4142135623730951}, "shimmer.modules.utils.translation": {"tf": 1}, "shimmer.modules.utils.cycle": {"tf": 1}, "shimmer.modules.utils.batch_demi_cycles": {"tf": 1.4142135623730951}, "shimmer.modules.utils.batch_cycles": {"tf": 1.4142135623730951}, "shimmer.modules.utils.batch_translations": {"tf": 1.4142135623730951}}, "df": 14, "e": {"docs": {"shimmer.types.ModelModeT": {"tf": 1.4142135623730951}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.generic_step": {"tf": 1}, "shimmer.modules.global_workspace.freeze_domain_modules": {"tf": 1}, 
"shimmer.modules.losses.GWLossesBase.step": {"tf": 1.4142135623730951}, "shimmer.modules.losses.GWLosses2Domains.step": {"tf": 1.4142135623730951}, "shimmer.modules.losses.GWLosses.step": {"tf": 1.4142135623730951}, "shimmer.modules.losses.GWLossesBayesian.step": {"tf": 1.4142135623730951}}, "df": 7, "l": {"docs": {"shimmer.modules.global_workspace.GWPredictions.demi_cycles": {"tf": 1}, "shimmer.modules.global_workspace.GWPredictions.cycles": {"tf": 1}, "shimmer.modules.global_workspace.GWPredictions.translations": {"tf": 1}, "shimmer.modules.domain.DomainModule.compute_loss": {"tf": 1}, "shimmer.modules.domain.DomainModule.compute_dcy_loss": {"tf": 1}, "shimmer.modules.domain.DomainModule.compute_cy_loss": {"tf": 1}, "shimmer.modules.domain.DomainModule.compute_tr_loss": {"tf": 1}, "shimmer.modules.domain.DomainModule.compute_broadcast_loss": {"tf": 1}, "shimmer.modules.losses.GWLossesBase.step": {"tf": 1}, "shimmer.modules.losses.GWLosses2Domains.step": {"tf": 1}, "shimmer.modules.losses.GWLosses.__init__": {"tf": 1.4142135623730951}, "shimmer.modules.losses.GWLosses.step": {"tf": 1}, "shimmer.modules.losses.GWLossesBayesian.step": {"tf": 1}, "shimmer.utils.migrate_model": {"tf": 1}, "shimmer.cli.ckpt_migration.migrate_ckpt": {"tf": 1}}, "df": 15, "m": {"docs": {}, "df": 0, "o": {"docs": {}, "df": 0, "d": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "t": {"docs": {"shimmer.modules.global_workspace.GlobalWorkspaceBase.generic_step": {"tf": 1}, "shimmer.modules.losses.GWLosses2Domains.step": {"tf": 1}}, "df": 2}}}}}, "s": {"docs": {"shimmer.modules.losses.GWLosses": {"tf": 1}}, "df": 1}}}, "u": {"docs": {}, "df": 0, "l": {"docs": {}, "df": 0, "e": {"docs": {"shimmer.modules.global_workspace.GlobalWorkspaceBase": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.loss_mod": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspace2Domains.__init__": {"tf": 2}, "shimmer.modules.global_workspace.GlobalWorkspace.__init__": {"tf": 2.23606797749979}, 
"shimmer.modules.global_workspace.GlobalWorkspaceBayesian.__init__": {"tf": 2}, "shimmer.modules.global_workspace.pretrained_global_workspace": {"tf": 2}, "shimmer.modules.domain.DomainModule.__init__": {"tf": 1}, "shimmer.modules.domain.DomainModule.latent_dim": {"tf": 1}, "shimmer.modules.gw_module.get_n_layers": {"tf": 1}, "shimmer.modules.gw_module.GWEncoder.forward": {"tf": 1}, "shimmer.modules.gw_module.GWEncoderLinear.forward": {"tf": 1}, "shimmer.modules.gw_module.GWModule": {"tf": 1}, "shimmer.modules.gw_module.GWModule.__init__": {"tf": 2}, "shimmer.modules.gw_module.GWModule.gw_encoders": {"tf": 1}, "shimmer.modules.gw_module.GWModule.gw_decoders": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBayesian.__init__": {"tf": 2}, "shimmer.modules.selection.SingleDomainSelection.forward": {"tf": 1}, "shimmer.modules.selection.FixedSharedSelection.forward": {"tf": 1}, "shimmer.modules.losses.GWLossesBase": {"tf": 1}, "shimmer.modules.losses.demi_cycle_loss": {"tf": 1}, "shimmer.modules.losses.GWLosses2Domains.__init__": {"tf": 1.4142135623730951}, "shimmer.modules.losses.broadcast_loss": {"tf": 1}, "shimmer.modules.losses.GWLosses.__init__": {"tf": 1}, "shimmer.modules.losses.GWLossesBayesian.__init__": {"tf": 1.4142135623730951}, "shimmer.modules.losses.GWLossesBayesian.selection_mod": {"tf": 1}, "shimmer.modules.contrastive_loss.ContrastiveLoss": {"tf": 1}, "shimmer.modules.vae.VAE": {"tf": 1}, "shimmer.modules.utils.translation": {"tf": 1.4142135623730951}, "shimmer.modules.utils.cycle": {"tf": 1.4142135623730951}, "shimmer.modules.utils.batch_demi_cycles": {"tf": 1}, "shimmer.modules.utils.batch_cycles": {"tf": 1}, "shimmer.modules.utils.batch_translations": {"tf": 1}, "shimmer.utils.SaveMigrations.on_save_checkpoint": {"tf": 1}}, "df": 33, "s": {"docs": {"shimmer.modules.global_workspace.freeze_domain_modules": {"tf": 1.7320508075688772}, "shimmer.modules.domain.DomainModule": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBase.__init__": {"tf": 
1.4142135623730951}, "shimmer.modules.gw_module.GWModuleBase.domain_mods": {"tf": 1}, "shimmer.modules.gw_module.GWModule.__init__": {"tf": 1.4142135623730951}, "shimmer.modules.gw_module.GWModuleBayesian.__init__": {"tf": 1.4142135623730951}, "shimmer.modules.selection.SelectionBase": {"tf": 1}, "shimmer.modules.losses.demi_cycle_loss": {"tf": 1.7320508075688772}, "shimmer.modules.losses.cycle_loss": {"tf": 1.4142135623730951}, "shimmer.modules.losses.translation_loss": {"tf": 1}, "shimmer.modules.losses.GWLosses2Domains.demi_cycle_loss": {"tf": 1}, "shimmer.modules.losses.GWLosses2Domains.cycle_loss": {"tf": 1}, "shimmer.modules.losses.GWLosses2Domains.translation_loss": {"tf": 1}, "shimmer.modules.losses.GWLosses2Domains.contrastive_loss": {"tf": 1}, "shimmer.modules.losses.broadcast_loss": {"tf": 1.7320508075688772}, "shimmer.modules.losses.GWLossesBayesian.domain_mods": {"tf": 1}}, "df": 16}, "d": {"docs": {}, "df": 0, "i": {"docs": {}, "df": 0, "c": {"docs": {}, "df": 0, "t": {"docs": {"shimmer.modules.global_workspace.freeze_domain_modules": {"tf": 1.4142135623730951}}, "df": 1}}}}}}}, "s": {"docs": {"shimmer.modules.global_workspace.GlobalWorkspaceBase.encode_domain": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.decode_domain": {"tf": 1}, "shimmer.modules.global_workspace.freeze_domain_modules": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspace2Domains.__init__": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspace.__init__": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBayesian.__init__": {"tf": 1}, "shimmer.modules.global_workspace.pretrained_global_workspace": {"tf": 1}, "shimmer.modules.losses.demi_cycle_loss": {"tf": 1}, "shimmer.modules.losses.cycle_loss": {"tf": 1}, "shimmer.modules.losses.translation_loss": {"tf": 1}, "shimmer.modules.losses.GWLosses2Domains.__init__": {"tf": 1}, "shimmer.modules.losses.broadcast_loss": {"tf": 1}, "shimmer.modules.losses.GWLosses.__init__": {"tf": 1}, 
"shimmer.modules.losses.GWLossesBayesian.__init__": {"tf": 1}}, "df": 14}, "a": {"docs": {}, "df": 0, "l": {"docs": {}, "df": 0, "i": {"docs": {}, "df": 0, "t": {"docs": {}, "df": 0, "y": {"docs": {"shimmer.modules.domain.DomainModule.compute_loss": {"tf": 1}}, "df": 1}, "i": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "s": {"docs": {"shimmer.modules.selection.RandomSelection": {"tf": 1}, "shimmer.modules.losses.generate_partitions": {"tf": 1}}, "df": 2}}}}}}}, "i": {"docs": {}, "df": 0, "f": {"docs": {}, "df": 0, "i": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "d": {"docs": {"shimmer.modules.selection.RandomSelection": {"tf": 1}}, "df": 1}}}}}}}, "e": {"docs": {}, "df": 0, "t": {"docs": {}, "df": 0, "h": {"docs": {}, "df": 0, "o": {"docs": {}, "df": 0, "d": {"docs": {"shimmer.modules.global_workspace.GlobalWorkspaceBase.encode_domain": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.decode_domain": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspace2Domains": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspace": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBayesian": {"tf": 1}, "shimmer.modules.selection.SelectionBase.update_gw_state": {"tf": 1}, "shimmer.modules.selection.SelectionBase.forward": {"tf": 1}}, "df": 7, "s": {"docs": {"shimmer.modules.gw_module.GWModuleBase.encode_and_fuse": {"tf": 1}}, "df": 1}}}}, "r": {"docs": {}, "df": 0, "i": {"docs": {}, "df": 0, "c": {"docs": {"shimmer.modules.losses.demi_cycle_loss": {"tf": 1}, "shimmer.modules.losses.cycle_loss": {"tf": 1}, "shimmer.modules.losses.translation_loss": {"tf": 1}, "shimmer.modules.losses.contrastive_loss": {"tf": 1}, "shimmer.modules.losses.contrastive_loss_bayesian": {"tf": 1}, "shimmer.modules.contrastive_loss.ContrastiveLoss.forward": {"tf": 1}}, "df": 6, "s": {"docs": {"shimmer.modules.domain.LossOutput": {"tf": 1}, "shimmer.modules.domain.LossOutput.metrics": {"tf": 1}, "shimmer.modules.domain.LossOutput.all": {"tf": 1}, 
"shimmer.modules.domain.DomainModule.compute_loss": {"tf": 1}, "shimmer.modules.domain.DomainModule.compute_dcy_loss": {"tf": 1}, "shimmer.modules.domain.DomainModule.compute_cy_loss": {"tf": 1}, "shimmer.modules.domain.DomainModule.compute_tr_loss": {"tf": 1}, "shimmer.modules.domain.DomainModule.compute_broadcast_loss": {"tf": 1}, "shimmer.modules.losses.demi_cycle_loss": {"tf": 1.7320508075688772}, "shimmer.modules.losses.cycle_loss": {"tf": 1.7320508075688772}, "shimmer.modules.losses.translation_loss": {"tf": 1.7320508075688772}, "shimmer.modules.losses.contrastive_loss": {"tf": 1.7320508075688772}, "shimmer.modules.losses.contrastive_loss_bayesian": {"tf": 1.7320508075688772}, "shimmer.modules.losses.GWLosses2Domains.demi_cycle_loss": {"tf": 1}, "shimmer.modules.losses.GWLosses2Domains.cycle_loss": {"tf": 1}, "shimmer.modules.losses.GWLosses2Domains.translation_loss": {"tf": 1}, "shimmer.modules.losses.GWLosses2Domains.contrastive_loss": {"tf": 1}, "shimmer.modules.losses.GWLosses2Domains.step": {"tf": 2}, "shimmer.modules.losses.broadcast_loss": {"tf": 1}, "shimmer.modules.losses.GWLosses.contrastive_loss": {"tf": 1}, "shimmer.modules.losses.GWLosses.step": {"tf": 1}, "shimmer.modules.losses.GWLossesBayesian.contrastive_loss": {"tf": 1}, "shimmer.modules.losses.GWLossesBayesian.step": {"tf": 1}}, "df": 23}}}}}, "r": {"docs": {}, "df": 0, "g": {"docs": {}, "df": 0, "e": {"docs": {"shimmer.modules.gw_module.GWModuleBase.fuse": {"tf": 1}, "shimmer.modules.gw_module.GWModule.fuse": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBayesian.fuse": {"tf": 1.4142135623730951}}, "df": 3, "d": {"docs": {"shimmer.modules.gw_module.GWModuleBase.fuse": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBase.encode_and_fuse": {"tf": 1}, "shimmer.modules.gw_module.GWModule.fuse": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBayesian.fuse": {"tf": 1}}, "df": 4}}}}, "c": {"docs": {}, "df": 0, "h": {"docs": {}, "df": 0, "a": {"docs": {}, "df": 0, "n": {"docs": {}, "df": 0, "i": 
{"docs": {}, "df": 0, "s": {"docs": {}, "df": 0, "m": {"docs": {"shimmer.modules.selection.SelectionBase": {"tf": 1}, "shimmer.modules.selection.SingleDomainSelection": {"tf": 1}, "shimmer.modules.selection.FixedSharedSelection": {"tf": 1}, "shimmer.modules.losses.GWLosses.__init__": {"tf": 1}}, "df": 4, "s": {"docs": {"shimmer.modules.selection.SelectionBase": {"tf": 1}}, "df": 1}}}}}}}}, "a": {"docs": {}, "df": 0, "n": {"docs": {"shimmer.modules.contrastive_loss.ContrastiveLossBayesianType": {"tf": 1.4142135623730951}, "shimmer.modules.contrastive_loss.info_nce": {"tf": 1}, "shimmer.modules.contrastive_loss.contrastive_loss": {"tf": 1}, "shimmer.modules.contrastive_loss.ContrastiveLoss.__init__": {"tf": 1.4142135623730951}, "shimmer.modules.vae.reparameterize": {"tf": 1}, "shimmer.modules.vae.kl_divergence_loss": {"tf": 1}, "shimmer.modules.vae.VAEEncoder.forward": {"tf": 1}, "shimmer.modules.vae.VAE.encode": {"tf": 1.4142135623730951}, "shimmer.modules.vae.VAE.forward": {"tf": 1}}, "df": 9, "s": {"docs": {"shimmer.modules.vae.reparameterize": {"tf": 1}, "shimmer.modules.vae.kl_divergence_loss": {"tf": 1}}, "df": 2}}}}, "i": {"docs": {}, "df": 0, "n": {"docs": {"shimmer.dataset.RepeatedDataset": {"tf": 2}, "shimmer.dataset.RepeatedDataset.__init__": {"tf": 1}}, "df": 2, "i": {"docs": {}, "df": 0, "m": {"docs": {}, "df": 0, "u": {"docs": {}, "df": 0, "m": {"docs": {"shimmer.dataset.RepeatedDataset.__init__": {"tf": 1}}, "df": 1}}}}}, "g": {"docs": {}, "df": 0, "r": {"docs": {}, "df": 0, "a": {"docs": {}, "df": 0, "t": {"docs": {}, "df": 0, "e": {"docs": {"shimmer.cli.ckpt_migration.migrate_ckpt": {"tf": 1.7320508075688772}}, "df": 1, "s": {"docs": {"shimmer.utils.migrate_model": {"tf": 1}}, "df": 1}, "d": {"docs": {"shimmer.utils.migrate_model": {"tf": 1}}, "df": 1}}, "i": {"docs": {}, "df": 0, "o": {"docs": {}, "df": 0, "n": {"docs": {"shimmer.utils.migrate_model": {"tf": 1}}, "df": 1}}}}}}, "h": {"docs": {}, "df": 0, "t": {"docs": 
{"shimmer.utils.SaveMigrations.on_save_checkpoint": {"tf": 1}}, "df": 1}}}}}, "r": {"docs": {}, "df": 0, "a": {"docs": {}, "df": 0, "w": {"docs": {"shimmer.types.RawDomainGroupT": {"tf": 1}, "shimmer.types.RawDomainGroupDT": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.encode_domains": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.decode_domain": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.decode_domains": {"tf": 1.4142135623730951}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.generic_step": {"tf": 1}}, "df": 6, "d": {"docs": {}, "df": 0, "o": {"docs": {}, "df": 0, "m": {"docs": {}, "df": 0, "a": {"docs": {}, "df": 0, "i": {"docs": {}, "df": 0, "n": {"docs": {}, "df": 0, "g": {"docs": {}, "df": 0, "r": {"docs": {}, "df": 0, "o": {"docs": {}, "df": 0, "u": {"docs": {}, "df": 0, "p": {"docs": {}, "df": 0, "t": {"docs": {"shimmer.types.RawDomainGroupT": {"tf": 1}, "shimmer.types.RawDomainGroupDT": {"tf": 1.4142135623730951}, "shimmer.types.RawDomainGroupsT": {"tf": 1}, "shimmer.types.RawDomainGroupsDT": {"tf": 1}}, "df": 4}, "d": {"docs": {}, "df": 0, "t": {"docs": {"shimmer.types.RawDomainGroupT": {"tf": 1}, "shimmer.types.RawDomainGroupDT": {"tf": 1}}, "df": 2}}, "s": {"docs": {}, "df": 0, "d": {"docs": {}, "df": 0, "t": {"docs": {"shimmer.types.RawDomainGroupsT": {"tf": 1}, "shimmer.types.RawDomainGroupsDT": {"tf": 1}}, "df": 2}}, "t": {"docs": {"shimmer.types.RawDomainGroupsT": {"tf": 1}, "shimmer.types.RawDomainGroupsDT": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.encode_domains": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.generic_step": {"tf": 1}}, "df": 4}}}}}}}}}}}}}}, "b": {"docs": {}, "df": 0, "b": {"docs": {}, "df": 0, "i": {"docs": {}, "df": 0, "t": {"docs": {"shimmer.types.RawDomainGroupsT": {"tf": 1}, "shimmer.types.RawDomainGroupsDT": {"tf": 1}}, "df": 2}}}}, "t": {"docs": {}, "df": 0, "e": {"docs": 
{"shimmer.modules.global_workspace.SchedulerArgs.max_lr": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspace2Domains.__init__": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspace.__init__": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBayesian.__init__": {"tf": 1}}, "df": 4}}, "n": {"docs": {}, "df": 0, "d": {"docs": {}, "df": 0, "o": {"docs": {}, "df": 0, "m": {"docs": {"shimmer.modules.selection.RandomSelection": {"tf": 1}}, "df": 1, "s": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "l": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "c": {"docs": {}, "df": 0, "t": {"docs": {}, "df": 0, "i": {"docs": {}, "df": 0, "o": {"docs": {}, "df": 0, "n": {"docs": {"shimmer.modules.global_workspace.GlobalWorkspace.__init__": {"tf": 1}}, "df": 1}}}}}}}}}}}, "n": {"docs": {"shimmer.modules.selection.SelectionBase.forward": {"tf": 1.4142135623730951}}, "df": 1}}}, "i": {"docs": {}, "df": 0, "s": {"docs": {}, "df": 0, "e": {"docs": {"shimmer.modules.selection.SelectionBase.update_gw_state": {"tf": 1}}, "df": 1, "s": {"docs": {"shimmer.modules.global_workspace.pretrained_global_workspace": {"tf": 1}}, "df": 1}}}}}, "e": {"docs": {}, "df": 0, "p": {"docs": {}, "df": 0, "r": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "s": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "n": {"docs": {}, "df": 0, "t": {"docs": {"shimmer.types.RawDomainGroupT": {"tf": 1}, "shimmer.modules.global_workspace.GWPredictionsBase.states": {"tf": 1}}, "df": 2, "a": {"docs": {}, "df": 0, "t": {"docs": {}, "df": 0, "i": {"docs": {}, "df": 0, "o": {"docs": {}, "df": 0, "n": {"docs": {"shimmer.types.LatentsDomainGroupT": {"tf": 1}, "shimmer.types.LatentsDomainGroupDT": {"tf": 1}, "shimmer.modules.global_workspace.GWPredictionsBase.states": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.encode_and_fuse": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.encode": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.fuse": 
{"tf": 1.4142135623730951}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.decode": {"tf": 1.4142135623730951}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.encode_domain": {"tf": 1.4142135623730951}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.decode_domain": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.decode_domains": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspace2Domains.__init__": {"tf": 1.4142135623730951}, "shimmer.modules.global_workspace.GlobalWorkspace.__init__": {"tf": 1.4142135623730951}, "shimmer.modules.global_workspace.GlobalWorkspaceBayesian.__init__": {"tf": 1.4142135623730951}, "shimmer.modules.global_workspace.pretrained_global_workspace": {"tf": 1.4142135623730951}, "shimmer.modules.domain.DomainModule.encode": {"tf": 1.4142135623730951}, "shimmer.modules.domain.DomainModule.decode": {"tf": 1.4142135623730951}, "shimmer.modules.gw_module.GWModuleBase.fuse": {"tf": 1.4142135623730951}, "shimmer.modules.gw_module.GWModuleBase.encode": {"tf": 1.4142135623730951}, "shimmer.modules.gw_module.GWModuleBase.encode_and_fuse": {"tf": 1.7320508075688772}, "shimmer.modules.gw_module.GWModuleBase.decode": {"tf": 1.4142135623730951}, "shimmer.modules.gw_module.GWModule.__init__": {"tf": 1.7320508075688772}, "shimmer.modules.gw_module.GWModule.fuse": {"tf": 1.4142135623730951}, "shimmer.modules.gw_module.GWModule.encode": {"tf": 1.7320508075688772}, "shimmer.modules.gw_module.GWModule.decode": {"tf": 1.7320508075688772}, "shimmer.modules.gw_module.GWModuleBayesian.__init__": {"tf": 1.7320508075688772}, "shimmer.modules.gw_module.GWModuleBayesian.fuse": {"tf": 1.4142135623730951}, "shimmer.modules.selection.DynamicQueryAttention.fuse_weighted_encodings": {"tf": 1}, "shimmer.modules.vae.VAEEncoder.forward": {"tf": 1}, "shimmer.modules.vae.VAEDecoder.forward": {"tf": 1.7320508075688772}, "shimmer.modules.vae.VAE.encode": {"tf": 1.4142135623730951}, "shimmer.modules.vae.VAE.decode": {"tf": 
1.4142135623730951}, "shimmer.modules.utils.translation": {"tf": 1}, "shimmer.modules.utils.cycle": {"tf": 1.7320508075688772}}, "df": 33, "s": {"docs": {"shimmer.types.LatentsDomainGroupT": {"tf": 1}, "shimmer.types.LatentsDomainGroupDT": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.encode_and_fuse": {"tf": 1.7320508075688772}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.encode": {"tf": 1.7320508075688772}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.fuse": {"tf": 1.4142135623730951}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.decode": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.encode_domains": {"tf": 1.4142135623730951}, "shimmer.modules.global_workspace.GlobalWorkspace2Domains.__init__": {"tf": 1.4142135623730951}, "shimmer.modules.global_workspace.GlobalWorkspace.__init__": {"tf": 1.4142135623730951}, "shimmer.modules.global_workspace.GlobalWorkspaceBayesian.__init__": {"tf": 1.4142135623730951}, "shimmer.modules.global_workspace.pretrained_global_workspace": {"tf": 1.4142135623730951}, "shimmer.modules.gw_module.GWModuleBase": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBase.encode": {"tf": 1.4142135623730951}, "shimmer.modules.gw_module.GWModuleBase.encode_and_fuse": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBase.decode": {"tf": 1}, "shimmer.modules.gw_module.GWModule.__init__": {"tf": 1}, "shimmer.modules.gw_module.GWModule.encode": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBayesian.__init__": {"tf": 1}, "shimmer.modules.selection.SelectionBase.forward": {"tf": 1}, "shimmer.modules.selection.SingleDomainSelection.forward": {"tf": 1}, "shimmer.modules.selection.FixedSharedSelection.forward": {"tf": 1}, "shimmer.modules.selection.KQFixedQSelection.forward": {"tf": 1}, "shimmer.modules.selection.RandomSelection.forward": {"tf": 1}, "shimmer.modules.selection.DynamicQueryAttention.forward": {"tf": 1}, "shimmer.modules.losses.broadcast_loss": {"tf": 1}, 
"shimmer.modules.losses.GWLosses.contrastive_loss": {"tf": 1}, "shimmer.modules.losses.GWLosses.step": {"tf": 1}, "shimmer.modules.losses.GWLossesBayesian.step": {"tf": 1}, "shimmer.modules.utils.translation": {"tf": 1}}, "df": 29}}}}}}}}}}, "n": {"docs": {}, "df": 0, "s": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "t": {"docs": {}, "df": 0, "a": {"docs": {}, "df": 0, "t": {"docs": {}, "df": 0, "i": {"docs": {}, "df": 0, "o": {"docs": {}, "df": 0, "n": {"docs": {"shimmer.modules.gw_module.GWModuleBase.fuse": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBase.encode_and_fuse": {"tf": 1}, "shimmer.modules.gw_module.GWModule.fuse": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBayesian.fuse": {"tf": 1}}, "df": 4}}}}}}}}}}}, "e": {"docs": {}, "df": 0, "a": {"docs": {}, "df": 0, "t": {"docs": {"shimmer.dataset.RepeatedDataset.__init__": {"tf": 1}}, "df": 1, "i": {"docs": {}, "df": 0, "n": {"docs": {}, "df": 0, "g": {"docs": {"shimmer.dataset.RepeatedDataset.__init__": {"tf": 1}}, "df": 1}}}}}}, "a": {"docs": {}, "df": 0, "r": {"docs": {}, "df": 0, "a": {"docs": {}, "df": 0, "m": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "t": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "r": {"docs": {}, "df": 0, "i": {"docs": {}, "df": 0, "z": {"docs": {}, "df": 0, "a": {"docs": {}, "df": 0, "t": {"docs": {}, "df": 0, "i": {"docs": {}, "df": 0, "o": {"docs": {}, "df": 0, "n": {"docs": {"shimmer.modules.vae.reparameterize": {"tf": 1.4142135623730951}}, "df": 1}}}}}}}}}}}}}}}}, "t": {"docs": {}, "df": 0, "u": {"docs": {}, "df": 0, "r": {"docs": {}, "df": 0, "n": {"docs": {"shimmer.types.RawDomainGroupDT": {"tf": 1}, "shimmer.types.LatentsDomainGroupDT": {"tf": 1}, "shimmer.types.RawDomainGroupsT": {"tf": 1}, "shimmer.types.RawDomainGroupsDT": {"tf": 1}, "shimmer.types.LatentsDomainGroupsDT": {"tf": 1}, "shimmer.modules.selection.SingleDomainSelection": {"tf": 1}, "shimmer.modules.losses.demi_cycle_loss": {"tf": 1}, "shimmer.modules.losses.cycle_loss": {"tf": 1}, 
"shimmer.modules.losses.translation_loss": {"tf": 1}, "shimmer.modules.losses.contrastive_loss": {"tf": 1}, "shimmer.modules.losses.contrastive_loss_bayesian": {"tf": 1}}, "df": 11, "s": {"docs": {"shimmer.modules.global_workspace.GlobalWorkspaceBase.encode_and_fuse": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.encode": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.fuse": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.decode": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.batch_gw_states": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.encode_domain": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.encode_domains": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.decode_domain": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.decode_domains": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.generic_step": {"tf": 1}, "shimmer.modules.global_workspace.freeze_domain_modules": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspace2Domains.forward": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspace.forward": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBayesian.forward": {"tf": 1}, "shimmer.modules.global_workspace.pretrained_global_workspace": {"tf": 1}, "shimmer.modules.domain.LossOutput.all": {"tf": 1}, "shimmer.modules.domain.DomainModule.encode": {"tf": 1}, "shimmer.modules.domain.DomainModule.decode": {"tf": 1}, "shimmer.modules.gw_module.get_n_layers": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBase.fuse": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBase.encode": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBase.encode_and_fuse": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBase.decode": {"tf": 1}, "shimmer.modules.gw_module.GWModule.fuse": {"tf": 1}, "shimmer.modules.gw_module.GWModule.encode": {"tf": 1}, "shimmer.modules.gw_module.GWModule.decode": {"tf": 1}, 
"shimmer.modules.gw_module.compute_fusion_scores": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBayesian.get_precision": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBayesian.fuse": {"tf": 1}, "shimmer.modules.selection.SelectionBase.forward": {"tf": 1}, "shimmer.modules.selection.SingleDomainSelection.forward": {"tf": 1}, "shimmer.modules.selection.FixedSharedSelection.forward": {"tf": 1}, "shimmer.modules.selection.KQFixedQSelection.forward": {"tf": 1}, "shimmer.modules.selection.RandomSelection.forward": {"tf": 1}, "shimmer.modules.selection.DynamicQueryAttention.fuse_weighted_encodings": {"tf": 1}, "shimmer.modules.selection.DynamicQueryAttention.forward": {"tf": 1}, "shimmer.modules.losses.GWLossesBase.step": {"tf": 1}, "shimmer.modules.losses.demi_cycle_loss": {"tf": 1}, "shimmer.modules.losses.cycle_loss": {"tf": 1}, "shimmer.modules.losses.translation_loss": {"tf": 1}, "shimmer.modules.losses.contrastive_loss": {"tf": 1}, "shimmer.modules.losses.contrastive_loss_bayesian": {"tf": 1}, "shimmer.modules.losses.GWLosses2Domains.demi_cycle_loss": {"tf": 1}, "shimmer.modules.losses.GWLosses2Domains.cycle_loss": {"tf": 1}, "shimmer.modules.losses.GWLosses2Domains.translation_loss": {"tf": 1}, "shimmer.modules.losses.GWLosses2Domains.contrastive_loss": {"tf": 1}, "shimmer.modules.losses.GWLosses2Domains.step": {"tf": 1.4142135623730951}, "shimmer.modules.losses.broadcast_loss": {"tf": 1}, "shimmer.modules.losses.GWLosses.contrastive_loss": {"tf": 1}, "shimmer.modules.losses.GWLosses.step": {"tf": 1}, "shimmer.modules.losses.GWLossesBayesian.contrastive_loss": {"tf": 1}, "shimmer.modules.losses.GWLossesBayesian.step": {"tf": 1}, "shimmer.modules.contrastive_loss.ContrastiveLossBayesianType": {"tf": 1}, "shimmer.modules.contrastive_loss.info_nce": {"tf": 1}, "shimmer.modules.contrastive_loss.contrastive_loss": {"tf": 1}, "shimmer.modules.contrastive_loss.ContrastiveLoss.forward": {"tf": 1}, "shimmer.modules.vae.reparameterize": {"tf": 1}, 
"shimmer.modules.vae.kl_divergence_loss": {"tf": 1}, "shimmer.modules.vae.gaussian_nll": {"tf": 1}, "shimmer.modules.vae.VAEEncoder.forward": {"tf": 1}, "shimmer.modules.vae.VAEDecoder.forward": {"tf": 1}, "shimmer.modules.vae.VAE.encode": {"tf": 1.4142135623730951}, "shimmer.modules.vae.VAE.decode": {"tf": 1}, "shimmer.modules.vae.VAE.forward": {"tf": 1}, "shimmer.modules.utils.translation": {"tf": 1}, "shimmer.modules.utils.cycle": {"tf": 1}, "shimmer.modules.utils.batch_demi_cycles": {"tf": 1}, "shimmer.modules.utils.batch_cycles": {"tf": 1}, "shimmer.modules.utils.batch_translations": {"tf": 1}, "shimmer.utils.groups_batch_size": {"tf": 1}, "shimmer.utils.groups_device": {"tf": 1}}, "df": 71}, "e": {"docs": {}, "df": 0, "d": {"docs": {"shimmer.modules.domain.LossOutput": {"tf": 1}, "shimmer.modules.selection.RandomSelection.forward": {"tf": 1}}, "df": 2}}, "i": {"docs": {}, "df": 0, "n": {"docs": {}, "df": 0, "g": {"docs": {"shimmer.modules.contrastive_loss.ContrastiveLossType": {"tf": 1}}, "df": 1}}}}}}}, "s": {"docs": {}, "df": 0, "u": {"docs": {}, "df": 0, "l": {"docs": {}, "df": 0, "t": {"docs": {}, "df": 0, "s": {"docs": {"shimmer.modules.domain.DomainModule.compute_loss": {"tf": 1}, "shimmer.modules.domain.DomainModule.compute_dcy_loss": {"tf": 1}, "shimmer.modules.domain.DomainModule.compute_cy_loss": {"tf": 1}, "shimmer.modules.domain.DomainModule.compute_tr_loss": {"tf": 1}, "shimmer.modules.domain.DomainModule.compute_broadcast_loss": {"tf": 1}}, "df": 5}}}}, "p": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "c": {"docs": {}, "df": 0, "t": {"docs": {}, "df": 0, "i": {"docs": {}, "df": 0, "v": {"docs": {}, "df": 0, "e": {"docs": {"shimmer.modules.losses.GWLosses.__init__": {"tf": 1}}, "df": 1}}}}}}}}, "l": {"docs": {}, "df": 0, "u": {"docs": {"shimmer.modules.gw_module.get_n_layers": {"tf": 1.4142135623730951}}, "df": 1}, "e": {"docs": {}, "df": 0, "v": {"docs": {}, "df": 0, "a": {"docs": {}, "df": 0, "n": {"docs": {}, "df": 0, "t": {"docs": 
{"shimmer.utils.SaveMigrations": {"tf": 1}}, "df": 1}}}}}}, "c": {"docs": {}, "df": 0, "i": {"docs": {}, "df": 0, "p": {"docs": {}, "df": 0, "e": {"docs": {"shimmer.modules.gw_module.GWEncoder.forward": {"tf": 1}, "shimmer.modules.gw_module.GWEncoderLinear.forward": {"tf": 1}}, "df": 2}}}, "o": {"docs": {}, "df": 0, "n": {"docs": {}, "df": 0, "s": {"docs": {}, "df": 0, "t": {"docs": {}, "df": 0, "r": {"docs": {}, "df": 0, "u": {"docs": {}, "df": 0, "c": {"docs": {}, "df": 0, "t": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "d": {"docs": {"shimmer.modules.vae.VAEDecoder.forward": {"tf": 1}, "shimmer.modules.vae.VAE.decode": {"tf": 1}, "shimmer.modules.vae.VAE.forward": {"tf": 1}}, "df": 3}}}}}}}}}}}, "g": {"docs": {}, "df": 0, "i": {"docs": {}, "df": 0, "s": {"docs": {}, "df": 0, "t": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "r": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "d": {"docs": {"shimmer.modules.gw_module.GWEncoder.forward": {"tf": 1}, "shimmer.modules.gw_module.GWEncoderLinear.forward": {"tf": 1}}, "df": 2}}}}}}}}, "d": {"docs": {}, "df": 0, "u": {"docs": {}, "df": 0, "c": {"docs": {}, "df": 0, "t": {"docs": {}, "df": 0, "i": {"docs": {}, "df": 0, "o": {"docs": {}, "df": 0, "n": {"docs": {"shimmer.modules.contrastive_loss.info_nce": {"tf": 1.4142135623730951}, "shimmer.modules.contrastive_loss.contrastive_loss": {"tf": 1.4142135623730951}, "shimmer.modules.contrastive_loss.ContrastiveLoss.__init__": {"tf": 1.4142135623730951}}, "df": 3}}}}}}}, "m": {"docs": {}, "df": 0, "o": {"docs": {}, "df": 0, "v": {"docs": {}, "df": 0, "e": {"docs": {"shimmer.dataset.RepeatedDataset.__init__": {"tf": 1}}, "df": 1}}}}}, "o": {"docs": {}, "df": 0, "l": {"docs": {}, "df": 0, "e": {"docs": {"shimmer.modules.global_workspace.GlobalWorkspace2Domains.__init__": {"tf": 1.4142135623730951}, "shimmer.modules.global_workspace.GlobalWorkspace.__init__": {"tf": 1.4142135623730951}, "shimmer.modules.global_workspace.GlobalWorkspaceBayesian.__init__": {"tf": 
1.4142135623730951}, "shimmer.modules.global_workspace.pretrained_global_workspace": {"tf": 1.4142135623730951}}, "df": 4}}}, "u": {"docs": {}, "df": 0, "n": {"docs": {}, "df": 0, "n": {"docs": {}, "df": 0, "i": {"docs": {}, "df": 0, "n": {"docs": {}, "df": 0, "g": {"docs": {"shimmer.modules.gw_module.GWEncoder.forward": {"tf": 1}, "shimmer.modules.gw_module.GWEncoderLinear.forward": {"tf": 1}}, "df": 2}}}}}}}, "u": {"docs": {}, "df": 0, "n": {"docs": {}, "df": 0, "i": {"docs": {}, "df": 0, "m": {"docs": {}, "df": 0, "o": {"docs": {}, "df": 0, "d": {"docs": {}, "df": 0, "a": {"docs": {}, "df": 0, "l": {"docs": {"shimmer.types.RawDomainGroupT": {"tf": 1}, "shimmer.types.RawDomainGroupDT": {"tf": 1}, "shimmer.types.LatentsDomainGroupT": {"tf": 1}, "shimmer.types.LatentsDomainGroupDT": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.decode": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.encode_domain": {"tf": 1.4142135623730951}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.encode_domains": {"tf": 1.4142135623730951}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.decode_domain": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.decode_domains": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.generic_step": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspace2Domains.__init__": {"tf": 1.4142135623730951}, "shimmer.modules.global_workspace.GlobalWorkspace.__init__": {"tf": 1.4142135623730951}, "shimmer.modules.global_workspace.GlobalWorkspaceBayesian.__init__": {"tf": 1.4142135623730951}, "shimmer.modules.global_workspace.pretrained_global_workspace": {"tf": 1.4142135623730951}, "shimmer.modules.domain.DomainModule.__init__": {"tf": 1}, "shimmer.modules.domain.DomainModule.encode": {"tf": 1.4142135623730951}, "shimmer.modules.domain.DomainModule.decode": {"tf": 1.4142135623730951}, "shimmer.modules.gw_module.GWModuleBase": {"tf": 1}, 
"shimmer.modules.gw_module.GWModuleBase.domain_mods": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBase.decode": {"tf": 1}, "shimmer.modules.gw_module.GWModule.__init__": {"tf": 1.4142135623730951}, "shimmer.modules.gw_module.GWModule.decode": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBayesian.__init__": {"tf": 1.4142135623730951}, "shimmer.modules.selection.SelectionBase.forward": {"tf": 1}, "shimmer.modules.selection.SingleDomainSelection.forward": {"tf": 1}, "shimmer.modules.selection.FixedSharedSelection.forward": {"tf": 1}, "shimmer.modules.selection.KQFixedQSelection.forward": {"tf": 1}, "shimmer.modules.selection.RandomSelection.forward": {"tf": 1}, "shimmer.modules.selection.DynamicQueryAttention.fuse_weighted_encodings": {"tf": 1}, "shimmer.modules.selection.DynamicQueryAttention.forward": {"tf": 1}, "shimmer.modules.losses.demi_cycle_loss": {"tf": 1}, "shimmer.modules.losses.cycle_loss": {"tf": 1}, "shimmer.modules.losses.translation_loss": {"tf": 1}, "shimmer.modules.losses.contrastive_loss": {"tf": 1}, "shimmer.modules.losses.contrastive_loss_bayesian": {"tf": 1}, "shimmer.modules.losses.GWLosses2Domains.demi_cycle_loss": {"tf": 1}, "shimmer.modules.losses.GWLosses2Domains.cycle_loss": {"tf": 1}, "shimmer.modules.losses.GWLosses2Domains.translation_loss": {"tf": 1}, "shimmer.modules.losses.GWLosses2Domains.contrastive_loss": {"tf": 1}, "shimmer.modules.losses.GWLossesBayesian.contrastive_loss": {"tf": 1}, "shimmer.modules.utils.translation": {"tf": 1}, "shimmer.modules.utils.cycle": {"tf": 1.4142135623730951}}, "df": 42}}}}}, "f": {"docs": {}, "df": 0, "o": {"docs": {}, "df": 0, "r": {"docs": {}, "df": 0, "m": {"docs": {"shimmer.modules.selection.SingleDomainSelection": {"tf": 1}, "shimmer.modules.selection.RandomSelection": {"tf": 1.4142135623730951}, "shimmer.modules.selection.RandomSelection.__init__": {"tf": 1}, "shimmer.modules.selection.RandomSelection.forward": {"tf": 1.4142135623730951}}, "df": 4}}}}}, "p": {"docs": {}, "df": 0, "a": 
{"docs": {}, "df": 0, "i": {"docs": {}, "df": 0, "r": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "d": {"docs": {"shimmer.types.RawDomainGroupsT": {"tf": 1}, "shimmer.types.RawDomainGroupsDT": {"tf": 1}, "shimmer.types.LatentsDomainGroupsT": {"tf": 1}, "shimmer.types.LatentsDomainGroupsDT": {"tf": 1}}, "df": 4}}}}}}, "c": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "r": {"docs": {}, "df": 0, "t": {"docs": {}, "df": 0, "a": {"docs": {}, "df": 0, "i": {"docs": {}, "df": 0, "n": {"docs": {}, "df": 0, "t": {"docs": {}, "df": 0, "y": {"docs": {"shimmer.modules.global_workspace.GlobalWorkspaceBayesian": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBayesian": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBayesian.fuse": {"tf": 1}, "shimmer.modules.losses.contrastive_loss_bayesian": {"tf": 1}, "shimmer.modules.losses.GWLossesBayesian.__init__": {"tf": 1}}, "df": 5}, "i": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "s": {"docs": {"shimmer.modules.gw_module.GWModuleBayesian.fuse": {"tf": 1}}, "df": 1}}}}}}}}}}}, "s": {"docs": {}, "df": 0, "t": {"docs": {}, "df": 0, "a": {"docs": {}, "df": 0, "b": {"docs": {}, "df": 0, "i": {"docs": {}, "df": 0, "l": {"docs": {}, "df": 0, "i": {"docs": {}, "df": 0, "t": {"docs": {}, "df": 0, "y": {"docs": {"shimmer.modules.gw_module.compute_fusion_scores": {"tf": 1}}, "df": 1}}}}}}}}}, "d": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "r": {"docs": {"shimmer.utils.migrate_model": {"tf": 1}}, "df": 1}}}}, "s": {"docs": {}, "df": 0, "e": {"docs": {"shimmer.types.RawDomainGroupT": {"tf": 1}, "shimmer.types.RawDomainGroupDT": {"tf": 1}, "shimmer.types.LatentsDomainGroupT": {"tf": 1}, "shimmer.types.LatentsDomainGroupDT": {"tf": 1}, "shimmer.types.RawDomainGroupsT": {"tf": 1}, "shimmer.types.RawDomainGroupsDT": {"tf": 1}, "shimmer.types.LatentsDomainGroupsT": {"tf": 1}, "shimmer.types.LatentsDomainGroupsDT": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBayesian.__init__": {"tf": 1.7320508075688772}, 
"shimmer.modules.domain.LossOutput": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBase.fuse": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBase.encode_and_fuse": {"tf": 1}, "shimmer.modules.gw_module.GWModule.fuse": {"tf": 1}, "shimmer.modules.gw_module.GWModule.decode": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBayesian.__init__": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBayesian.fuse": {"tf": 1}, "shimmer.modules.losses.demi_cycle_loss": {"tf": 1.4142135623730951}, "shimmer.modules.losses.cycle_loss": {"tf": 1.4142135623730951}, "shimmer.modules.losses.translation_loss": {"tf": 1}, "shimmer.modules.losses.contrastive_loss": {"tf": 1}, "shimmer.modules.losses.contrastive_loss_bayesian": {"tf": 1}, "shimmer.modules.losses.GWLosses2Domains.__init__": {"tf": 1.4142135623730951}, "shimmer.modules.losses.broadcast_loss": {"tf": 1.4142135623730951}, "shimmer.modules.losses.GWLossesBayesian.__init__": {"tf": 2}, "shimmer.modules.losses.GWLossesBayesian.contrastive_fn": {"tf": 1}, "shimmer.modules.utils.batch_cycles": {"tf": 1}}, "df": 26, "s": {"docs": {"shimmer.types.RawDomainGroupT": {"tf": 1}, "shimmer.types.RawDomainGroupDT": {"tf": 1}, "shimmer.types.LatentsDomainGroupT": {"tf": 1}, "shimmer.types.LatentsDomainGroupDT": {"tf": 1}, "shimmer.types.RawDomainGroupsT": {"tf": 1}, "shimmer.types.RawDomainGroupsDT": {"tf": 1}, "shimmer.types.LatentsDomainGroupsT": {"tf": 1}, "shimmer.types.LatentsDomainGroupsDT": {"tf": 1}}, "df": 8}, "d": {"docs": {"shimmer.types.RawDomainGroupT": {"tf": 1}, "shimmer.types.RawDomainGroupDT": {"tf": 1}, "shimmer.types.LatentsDomainGroupT": {"tf": 1}, "shimmer.types.LatentsDomainGroupDT": {"tf": 1}, "shimmer.types.RawDomainGroupsT": {"tf": 1}, "shimmer.types.RawDomainGroupsDT": {"tf": 1}, "shimmer.types.LatentsDomainGroupsT": {"tf": 1}, "shimmer.types.LatentsDomainGroupsDT": {"tf": 1}, "shimmer.types.ModelModeT": {"tf": 1.4142135623730951}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.generic_step": {"tf": 1}, 
"shimmer.modules.global_workspace.GlobalWorkspace2Domains": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspace2Domains.__init__": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspace": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspace.__init__": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBayesian": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBayesian.__init__": {"tf": 1}, "shimmer.modules.global_workspace.pretrained_global_workspace": {"tf": 1}, "shimmer.modules.domain.LossOutput": {"tf": 1.4142135623730951}, "shimmer.modules.domain.LossOutput.loss": {"tf": 1}, "shimmer.modules.domain.LossOutput.metrics": {"tf": 1}, "shimmer.modules.gw_module.GWEncoder": {"tf": 1}, "shimmer.modules.gw_module.GWEncoderLinear": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBase.fuse": {"tf": 1}, "shimmer.modules.gw_module.GWModule.fuse": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBayesian.fuse": {"tf": 1}, "shimmer.modules.selection.SelectionBase.update_gw_state": {"tf": 1}, "shimmer.modules.selection.RandomSelection.forward": {"tf": 1}, "shimmer.modules.losses.GWLossesBase": {"tf": 1}, "shimmer.modules.losses.LossCoefs": {"tf": 1}, "shimmer.modules.losses.GWLosses2Domains": {"tf": 1}, "shimmer.modules.losses.BroadcastLossCoefs": {"tf": 1}, "shimmer.modules.losses.GWLosses.__init__": {"tf": 1}, "shimmer.modules.losses.GWLossesBayesian": {"tf": 1}, "shimmer.modules.vae.kl_divergence_loss": {"tf": 1}, "shimmer.modules.vae.gaussian_nll": {"tf": 1}, "shimmer.utils.SaveMigrations": {"tf": 1}}, "df": 36}}, "i": {"docs": {}, "df": 0, "n": {"docs": {}, "df": 0, "g": {"docs": {"shimmer.modules.global_workspace.GlobalWorkspace2Domains.__init__": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspace.__init__": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBayesian.__init__": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBase": {"tf": 1}, "shimmer.modules.gw_module.compute_fusion_scores": {"tf": 1}, 
"shimmer.modules.gw_module.GWModuleBayesian.fuse": {"tf": 1.4142135623730951}, "shimmer.modules.selection.RandomSelection": {"tf": 1}, "shimmer.modules.selection.DynamicQueryAttention.fuse_weighted_encodings": {"tf": 1}, "shimmer.modules.vae.reparameterize": {"tf": 1}}, "df": 9}}}}, "p": {"docs": {}, "df": 0, "d": {"docs": {}, "df": 0, "a": {"docs": {}, "df": 0, "t": {"docs": {}, "df": 0, "e": {"docs": {"shimmer.modules.selection.SelectionBase.update_gw_state": {"tf": 1}}, "df": 1, "d": {"docs": {"shimmer.modules.selection.DynamicQueryAttention": {"tf": 1}}, "df": 1}}}}}}, "t": {"docs": {}, "df": 0, "i": {"docs": {}, "df": 0, "l": {"docs": {}, "df": 0, "i": {"docs": {}, "df": 0, "z": {"docs": {}, "df": 0, "e": {"docs": {"shimmer.modules.selection.RandomSelection": {"tf": 1}}, "df": 1}}}, "s": {"docs": {"shimmer.cli.ckpt_migration.migrate_ckpt": {"tf": 1}}, "df": 1}}}}}, "d": {"docs": {"shimmer.modules.gw_module.GWModuleBayesian.fuse": {"tf": 1.7320508075688772}}, "df": 1, "a": {"docs": {}, "df": 0, "t": {"docs": {}, "df": 0, "a": {"docs": {"shimmer.types.RawDomainGroupT": {"tf": 1.4142135623730951}, "shimmer.types.RawDomainGroupDT": {"tf": 1.4142135623730951}, "shimmer.types.RawDomainGroupsT": {"tf": 1}, "shimmer.types.RawDomainGroupsDT": {"tf": 1}, "shimmer.types.LatentsDomainGroupsT": {"tf": 1}, "shimmer.types.LatentsDomainGroupsDT": {"tf": 1}, "shimmer.types.ModelModeT": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.encode_domain": {"tf": 1.4142135623730951}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.encode_domains": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.decode_domain": {"tf": 1.7320508075688772}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.decode_domains": {"tf": 1.4142135623730951}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.generic_step": {"tf": 1}, "shimmer.modules.domain.DomainModule.encode": {"tf": 1.4142135623730951}, "shimmer.modules.domain.DomainModule.decode": {"tf": 
1.7320508075688772}, "shimmer.modules.vae.VAE.forward": {"tf": 1}}, "df": 15, "c": {"docs": {}, "df": 0, "l": {"docs": {}, "df": 0, "a": {"docs": {}, "df": 0, "s": {"docs": {}, "df": 0, "s": {"docs": {"shimmer.modules.domain.LossOutput": {"tf": 1}}, "df": 1}}}}}, "s": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "t": {"docs": {"shimmer.dataset.RepeatedDataset": {"tf": 1.4142135623730951}, "shimmer.dataset.RepeatedDataset.__init__": {"tf": 2.23606797749979}}, "df": 2}}}}}}, "o": {"docs": {"shimmer.modules.utils.cycle": {"tf": 1}, "shimmer.modules.utils.batch_cycles": {"tf": 1}, "shimmer.modules.utils.batch_translations": {"tf": 1}}, "df": 3, "m": {"docs": {}, "df": 0, "a": {"docs": {}, "df": 0, "i": {"docs": {}, "df": 0, "n": {"docs": {"shimmer.types.RawDomainGroupT": {"tf": 1.4142135623730951}, "shimmer.types.RawDomainGroupDT": {"tf": 1}, "shimmer.types.LatentsDomainGroupT": {"tf": 1.4142135623730951}, "shimmer.types.LatentsDomainGroupDT": {"tf": 1}, "shimmer.types.LatentsDomainGroupsT": {"tf": 1}, "shimmer.modules.global_workspace.GWPredictionsBase.states": {"tf": 1.7320508075688772}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.encode_and_fuse": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.encode": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.batch_gw_states": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.encode_domain": {"tf": 2.8284271247461903}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.encode_domains": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.decode_domain": {"tf": 2.8284271247461903}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.decode_domains": {"tf": 1}, "shimmer.modules.global_workspace.freeze_domain_modules": {"tf": 1.7320508075688772}, "shimmer.modules.global_workspace.GWPredictions.demi_cycles": {"tf": 1.7320508075688772}, "shimmer.modules.global_workspace.GWPredictions.cycles": {"tf": 2.23606797749979}, 
"shimmer.modules.global_workspace.GWPredictions.translations": {"tf": 2.23606797749979}, "shimmer.modules.global_workspace.GlobalWorkspace2Domains.__init__": {"tf": 2}, "shimmer.modules.global_workspace.GlobalWorkspace": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspace.__init__": {"tf": 2}, "shimmer.modules.global_workspace.GlobalWorkspaceBayesian.__init__": {"tf": 2}, "shimmer.modules.global_workspace.pretrained_global_workspace": {"tf": 2}, "shimmer.modules.domain.DomainModule": {"tf": 1}, "shimmer.modules.domain.DomainModule.encode": {"tf": 1.4142135623730951}, "shimmer.modules.domain.DomainModule.decode": {"tf": 1.7320508075688772}, "shimmer.modules.gw_module.GWModuleBase.__init__": {"tf": 1.4142135623730951}, "shimmer.modules.gw_module.GWModuleBase.domain_mods": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBase.encode": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBase.encode_and_fuse": {"tf": 1}, "shimmer.modules.gw_module.GWModule.__init__": {"tf": 2}, "shimmer.modules.gw_module.GWModule.encode": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBayesian.__init__": {"tf": 2}, "shimmer.modules.gw_module.GWModuleBayesian.precisions": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBayesian.get_precision": {"tf": 1.4142135623730951}, "shimmer.modules.selection.SelectionBase.forward": {"tf": 1}, "shimmer.modules.selection.SingleDomainSelection": {"tf": 1.4142135623730951}, "shimmer.modules.selection.SingleDomainSelection.forward": {"tf": 1}, "shimmer.modules.selection.FixedSharedSelection.forward": {"tf": 1}, "shimmer.modules.selection.KQFixedQSelection.__init__": {"tf": 1.4142135623730951}, "shimmer.modules.selection.KQFixedQSelection.forward": {"tf": 1}, "shimmer.modules.selection.RandomSelection": {"tf": 1}, "shimmer.modules.selection.RandomSelection.forward": {"tf": 1.7320508075688772}, "shimmer.modules.selection.DynamicQueryAttention.__init__": {"tf": 1.4142135623730951}, "shimmer.modules.selection.DynamicQueryAttention.fuse_weighted_encodings": 
{"tf": 1}, "shimmer.modules.selection.DynamicQueryAttention.forward": {"tf": 1}, "shimmer.modules.losses.GWLossesBase.step": {"tf": 1}, "shimmer.modules.losses.demi_cycle_loss": {"tf": 2.6457513110645907}, "shimmer.modules.losses.cycle_loss": {"tf": 3.1622776601683795}, "shimmer.modules.losses.translation_loss": {"tf": 3.3166247903554}, "shimmer.modules.losses.contrastive_loss": {"tf": 2.6457513110645907}, "shimmer.modules.losses.contrastive_loss_bayesian": {"tf": 2.6457513110645907}, "shimmer.modules.losses.GWLosses2Domains.__init__": {"tf": 1.4142135623730951}, "shimmer.modules.losses.GWLosses2Domains.step": {"tf": 1}, "shimmer.modules.losses.broadcast_loss": {"tf": 1.7320508075688772}, "shimmer.modules.losses.GWLosses.__init__": {"tf": 1.4142135623730951}, "shimmer.modules.losses.GWLosses.contrastive_loss": {"tf": 1}, "shimmer.modules.losses.GWLosses.step": {"tf": 1}, "shimmer.modules.losses.GWLossesBayesian.__init__": {"tf": 1.4142135623730951}, "shimmer.modules.losses.GWLossesBayesian.domain_mods": {"tf": 1}, "shimmer.modules.losses.GWLossesBayesian.step": {"tf": 1}, "shimmer.modules.utils.translation": {"tf": 1.7320508075688772}, "shimmer.modules.utils.cycle": {"tf": 1.4142135623730951}, "shimmer.modules.utils.batch_demi_cycles": {"tf": 1}, "shimmer.modules.utils.batch_cycles": {"tf": 2}, "shimmer.modules.utils.batch_translations": {"tf": 1.4142135623730951}, "shimmer.utils.groups_batch_size": {"tf": 1}, "shimmer.utils.groups_device": {"tf": 1}}, "df": 67, "s": {"docs": {"shimmer.types.RawDomainGroupT": {"tf": 1.4142135623730951}, "shimmer.types.RawDomainGroupDT": {"tf": 1.4142135623730951}, "shimmer.types.LatentsDomainGroupT": {"tf": 1.4142135623730951}, "shimmer.types.LatentsDomainGroupDT": {"tf": 1.4142135623730951}, "shimmer.types.RawDomainGroupsT": {"tf": 1}, "shimmer.types.RawDomainGroupsDT": {"tf": 1}, "shimmer.types.LatentsDomainGroupsT": {"tf": 1}, "shimmer.types.LatentsDomainGroupsDT": {"tf": 1}, 
"shimmer.modules.global_workspace.GlobalWorkspaceBase.decode": {"tf": 1.7320508075688772}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.batch_gw_states": {"tf": 1.7320508075688772}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.encode_domains": {"tf": 1.4142135623730951}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.decode_domains": {"tf": 1.4142135623730951}, "shimmer.modules.global_workspace.GlobalWorkspace2Domains": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspace2Domains.__init__": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspace2Domains.forward": {"tf": 1.4142135623730951}, "shimmer.modules.global_workspace.GlobalWorkspace.__init__": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspace.forward": {"tf": 1.4142135623730951}, "shimmer.modules.global_workspace.GlobalWorkspaceBayesian": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBayesian.__init__": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBayesian.forward": {"tf": 1.4142135623730951}, "shimmer.modules.global_workspace.pretrained_global_workspace": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBase.fuse": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBase.decode": {"tf": 1.7320508075688772}, "shimmer.modules.gw_module.GWModule.fuse": {"tf": 1}, "shimmer.modules.gw_module.GWModule.decode": {"tf": 2}, "shimmer.modules.gw_module.GWModuleBayesian.fuse": {"tf": 1.4142135623730951}, "shimmer.modules.selection.SelectionBase": {"tf": 1}, "shimmer.modules.selection.SelectionBase.forward": {"tf": 1}, "shimmer.modules.selection.SingleDomainSelection": {"tf": 1.4142135623730951}, "shimmer.modules.selection.SingleDomainSelection.forward": {"tf": 1}, "shimmer.modules.selection.FixedSharedSelection": {"tf": 1.7320508075688772}, "shimmer.modules.selection.FixedSharedSelection.forward": {"tf": 1}, "shimmer.modules.selection.KQFixedQSelection.__init__": {"tf": 1}, "shimmer.modules.selection.KQFixedQSelection.forward": {"tf": 1}, 
"shimmer.modules.selection.RandomSelection.forward": {"tf": 1}, "shimmer.modules.selection.DynamicQueryAttention.__init__": {"tf": 1}, "shimmer.modules.selection.DynamicQueryAttention.forward": {"tf": 1}, "shimmer.modules.losses.demi_cycle_loss": {"tf": 1}, "shimmer.modules.losses.cycle_loss": {"tf": 1}, "shimmer.modules.losses.translation_loss": {"tf": 1}, "shimmer.modules.losses.contrastive_loss": {"tf": 1.4142135623730951}, "shimmer.modules.losses.contrastive_loss_bayesian": {"tf": 1.4142135623730951}, "shimmer.modules.losses.GWLosses2Domains.demi_cycle_loss": {"tf": 1}, "shimmer.modules.losses.GWLosses2Domains.cycle_loss": {"tf": 1}, "shimmer.modules.losses.GWLosses2Domains.translation_loss": {"tf": 1}, "shimmer.modules.losses.GWLosses2Domains.contrastive_loss": {"tf": 1}, "shimmer.modules.losses.broadcast_loss": {"tf": 1}, "shimmer.modules.losses.BroadcastLossCoefs.fused": {"tf": 1}, "shimmer.modules.losses.GWLosses.contrastive_loss": {"tf": 1.4142135623730951}, "shimmer.modules.losses.GWLosses.step": {"tf": 1}, "shimmer.modules.losses.GWLossesBayesian.contrastive_loss": {"tf": 1}, "shimmer.modules.losses.GWLossesBayesian.step": {"tf": 1}, "shimmer.modules.utils.translation": {"tf": 1}, "shimmer.modules.utils.cycle": {"tf": 1.4142135623730951}, "shimmer.modules.utils.batch_demi_cycles": {"tf": 1.7320508075688772}, "shimmer.modules.utils.batch_cycles": {"tf": 2}, "shimmer.modules.utils.batch_translations": {"tf": 1.7320508075688772}}, "df": 57}, "m": {"docs": {}, "df": 0, "o": {"docs": {}, "df": 0, "d": {"docs": {}, "df": 0, "u": {"docs": {}, "df": 0, "l": {"docs": {}, "df": 0, "e": {"docs": {"shimmer.modules.global_workspace.GlobalWorkspaceBase.encode_domain": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.decode_domain": {"tf": 1}, "shimmer.modules.global_workspace.freeze_domain_modules": {"tf": 1.4142135623730951}, "shimmer.modules.global_workspace.GlobalWorkspace2Domains.__init__": {"tf": 1.4142135623730951}, 
"shimmer.modules.global_workspace.GlobalWorkspace.__init__": {"tf": 1.4142135623730951}, "shimmer.modules.global_workspace.GlobalWorkspaceBayesian.__init__": {"tf": 1.4142135623730951}, "shimmer.modules.global_workspace.pretrained_global_workspace": {"tf": 1.4142135623730951}, "shimmer.modules.domain.DomainModule": {"tf": 1}, "shimmer.modules.domain.DomainModule.__init__": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBase.__init__": {"tf": 1}, "shimmer.modules.gw_module.GWModule.__init__": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBayesian.__init__": {"tf": 1}, "shimmer.modules.losses.demi_cycle_loss": {"tf": 1}, "shimmer.modules.losses.cycle_loss": {"tf": 1}, "shimmer.modules.losses.translation_loss": {"tf": 1}, "shimmer.modules.losses.GWLosses2Domains.__init__": {"tf": 1.4142135623730951}, "shimmer.modules.losses.broadcast_loss": {"tf": 1}, "shimmer.modules.losses.GWLosses.__init__": {"tf": 1}, "shimmer.modules.losses.GWLossesBayesian.__init__": {"tf": 1.4142135623730951}}, "df": 19}}}}}}}}}}, "g": {"docs": {"shimmer.types.RawDomainGroupT": {"tf": 1}, "shimmer.types.RawDomainGroupDT": {"tf": 1}, "shimmer.types.RawDomainGroupsT": {"tf": 1}, "shimmer.types.RawDomainGroupsDT": {"tf": 1}}, "df": 4}, "t": {"docs": {"shimmer.modules.selection.KQFixedQSelection.forward": {"tf": 1}, "shimmer.modules.selection.DynamicQueryAttention.forward": {"tf": 1}}, "df": 2}, "e": {"docs": {}, "df": 0, "s": {"docs": {"shimmer.modules.selection.KQFixedQSelection.forward": {"tf": 1}, "shimmer.modules.selection.DynamicQueryAttention.forward": {"tf": 1}}, "df": 2}}, "n": {"docs": {}, "df": 0, "e": {"docs": {"shimmer.modules.utils.batch_cycles": {"tf": 1}}, "df": 1}}}, "e": {"docs": {}, "df": 0, "f": {"docs": {"shimmer.types.RawDomainGroupT": {"tf": 1}, "shimmer.types.RawDomainGroupDT": {"tf": 1}, "shimmer.types.LatentsDomainGroupT": {"tf": 1}, "shimmer.types.LatentsDomainGroupDT": {"tf": 1}, "shimmer.types.RawDomainGroupsT": {"tf": 1}, "shimmer.types.RawDomainGroupsDT": {"tf": 1}, 
"shimmer.types.LatentsDomainGroupsT": {"tf": 1}, "shimmer.types.LatentsDomainGroupsDT": {"tf": 1}}, "df": 8, "a": {"docs": {}, "df": 0, "u": {"docs": {}, "df": 0, "l": {"docs": {}, "df": 0, "t": {"docs": {"shimmer.modules.global_workspace.GlobalWorkspace2Domains.__init__": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspace.__init__": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBayesian.__init__": {"tf": 1}, "shimmer.modules.selection.SelectionBase.update_gw_state": {"tf": 1}}, "df": 4, "s": {"docs": {"shimmer.modules.gw_module.GWModule.decode": {"tf": 1}, "shimmer.modules.contrastive_loss.ContrastiveLoss.__init__": {"tf": 1.4142135623730951}, "shimmer.modules.vae.VAE.__init__": {"tf": 1}}, "df": 3}}}}}, "i": {"docs": {}, "df": 0, "n": {"docs": {}, "df": 0, "e": {"docs": {"shimmer.modules.gw_module.GWEncoder.forward": {"tf": 1}, "shimmer.modules.gw_module.GWEncoderLinear.forward": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBase": {"tf": 1}}, "df": 3, "s": {"docs": {"shimmer.modules.domain.DomainModule": {"tf": 1}}, "df": 1}, "d": {"docs": {"shimmer.modules.gw_module.GWEncoder.forward": {"tf": 1}, "shimmer.modules.gw_module.GWEncoderLinear.forward": {"tf": 1}, "shimmer.modules.selection.SelectionBase.update_gw_state": {"tf": 1}, "shimmer.dataset.RepeatedDataset.__init__": {"tf": 1}}, "df": 4}}, "i": {"docs": {}, "df": 0, "t": {"docs": {}, "df": 0, "i": {"docs": {}, "df": 0, "o": {"docs": {}, "df": 0, "n": {"docs": {}, "df": 0, "s": {"docs": {"shimmer.modules.contrastive_loss": {"tf": 1}}, "df": 1}}}}}}}}}, "c": {"docs": {}, "df": 0, "o": {"docs": {}, "df": 0, "d": {"docs": {}, "df": 0, "e": {"docs": {"shimmer.modules.global_workspace.GlobalWorkspaceBase.decode": {"tf": 1.4142135623730951}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.decode_domain": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.decode_domains": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspace2Domains.__init__": {"tf": 1}, 
"shimmer.modules.global_workspace.GlobalWorkspace.__init__": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBayesian.__init__": {"tf": 1}, "shimmer.modules.global_workspace.pretrained_global_workspace": {"tf": 1}, "shimmer.modules.domain.DomainModule.decode": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBase.decode": {"tf": 1.4142135623730951}, "shimmer.modules.gw_module.GWModule.decode": {"tf": 1}, "shimmer.modules.losses.BroadcastLossCoefs.fused": {"tf": 1}, "shimmer.modules.vae.VAEDecoder.forward": {"tf": 1}, "shimmer.modules.vae.VAE.decode": {"tf": 1}}, "df": 13, "d": {"docs": {"shimmer.modules.global_workspace.GlobalWorkspaceBase.decode": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBase.decode": {"tf": 1}, "shimmer.modules.gw_module.GWModule.decode": {"tf": 1}}, "df": 3}, "s": {"docs": {"shimmer.modules.global_workspace.GlobalWorkspaceBase.decode_domain": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.decode_domains": {"tf": 1}, "shimmer.modules.gw_module.GWModule.__init__": {"tf": 1}, "shimmer.modules.gw_module.GWModule.decode": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBayesian.__init__": {"tf": 1}, "shimmer.modules.vae.VAE.forward": {"tf": 1}}, "df": 6}, "r": {"docs": {"shimmer.modules.gw_module.GWDecoder": {"tf": 1}, "shimmer.modules.gw_module.GWDecoder.__init__": {"tf": 1}, "shimmer.modules.gw_module.GWEncoder": {"tf": 1}, "shimmer.modules.vae.VAEDecoder": {"tf": 1}, "shimmer.modules.vae.VAE.__init__": {"tf": 1.4142135623730951}, "shimmer.modules.vae.VAE.decoder": {"tf": 1}}, "df": 6, "s": {"docs": {"shimmer.modules.global_workspace.GlobalWorkspace2Domains.__init__": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspace.__init__": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBayesian.__init__": {"tf": 1}, "shimmer.modules.global_workspace.pretrained_global_workspace": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBase": {"tf": 1}, "shimmer.modules.gw_module.GWModule.__init__": {"tf": 1}, 
"shimmer.modules.gw_module.GWModule.gw_decoders": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBayesian.__init__": {"tf": 1}}, "df": 8}}}, "i": {"docs": {}, "df": 0, "n": {"docs": {}, "df": 0, "g": {"docs": {"shimmer.modules.gw_module.GWModuleBase": {"tf": 1}}, "df": 1}}}}}, "a": {"docs": {}, "df": 0, "y": {"docs": {"shimmer.modules.global_workspace.GlobalWorkspace2Domains.__init__": {"tf": 1.4142135623730951}, "shimmer.modules.global_workspace.GlobalWorkspace.__init__": {"tf": 1.4142135623730951}, "shimmer.modules.global_workspace.GlobalWorkspaceBayesian.__init__": {"tf": 1.4142135623730951}}, "df": 3}}}, "m": {"docs": {}, "df": 0, "i": {"docs": {"shimmer.modules.global_workspace.GWPredictions.demi_cycles": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspace2Domains.forward": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspace.forward": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBayesian.forward": {"tf": 1}, "shimmer.modules.domain.DomainModule.compute_dcy_loss": {"tf": 1.4142135623730951}, "shimmer.modules.losses.GWLossesBase": {"tf": 1}, "shimmer.modules.losses.demi_cycle_loss": {"tf": 2.449489742783178}, "shimmer.modules.losses.LossCoefs.demi_cycles": {"tf": 1}, "shimmer.modules.losses.GWLosses2Domains.demi_cycle_loss": {"tf": 1.4142135623730951}, "shimmer.modules.losses.GWLosses2Domains.step": {"tf": 1.4142135623730951}, "shimmer.modules.losses.broadcast_loss": {"tf": 1}, "shimmer.modules.losses.BroadcastLossCoefs.demi_cycles": {"tf": 1.4142135623730951}, "shimmer.modules.utils.batch_demi_cycles": {"tf": 1.4142135623730951}}, "df": 13}}, "t": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "r": {"docs": {}, "df": 0, "m": {"docs": {}, "df": 0, "i": {"docs": {}, "df": 0, "n": {"docs": {}, "df": 0, "i": {"docs": {}, "df": 0, "s": {"docs": {}, "df": 0, "t": {"docs": {}, "df": 0, "i": {"docs": {}, "df": 0, "c": {"docs": {"shimmer.modules.selection.FixedSharedSelection": {"tf": 1}}, "df": 1}}}}}, "e": {"docs": {}, "df": 0, "s": 
{"docs": {"shimmer.modules.selection.RandomSelection.forward": {"tf": 1}}, "df": 1}}}}}}}}}, "i": {"docs": {}, "df": 0, "c": {"docs": {}, "df": 0, "t": {"docs": {"shimmer.types.RawDomainGroupDT": {"tf": 1}, "shimmer.types.LatentsDomainGroupDT": {"tf": 1.4142135623730951}, "shimmer.types.RawDomainGroupsT": {"tf": 1}, "shimmer.types.RawDomainGroupsDT": {"tf": 1}, "shimmer.types.LatentsDomainGroupsDT": {"tf": 1}, "shimmer.modules.domain.LossOutput.all": {"tf": 1}, "shimmer.modules.selection.DynamicQueryAttention.fuse_weighted_encodings": {"tf": 1}, "shimmer.modules.losses.demi_cycle_loss": {"tf": 1}, "shimmer.modules.losses.cycle_loss": {"tf": 1}, "shimmer.modules.losses.translation_loss": {"tf": 1}, "shimmer.modules.losses.contrastive_loss": {"tf": 1}, "shimmer.modules.losses.contrastive_loss_bayesian": {"tf": 1}, "shimmer.modules.losses.LossCoefs": {"tf": 1}, "shimmer.modules.losses.GWLosses2Domains.__init__": {"tf": 1}, "shimmer.modules.losses.GWLosses2Domains.demi_cycle_loss": {"tf": 1}, "shimmer.modules.losses.GWLosses2Domains.cycle_loss": {"tf": 1}, "shimmer.modules.losses.GWLosses2Domains.translation_loss": {"tf": 1}, "shimmer.modules.losses.GWLosses2Domains.contrastive_loss": {"tf": 1}, "shimmer.modules.losses.BroadcastLossCoefs": {"tf": 1}, "shimmer.modules.losses.GWLossesBayesian.__init__": {"tf": 1}, "shimmer.modules.losses.GWLossesBayesian.contrastive_loss": {"tf": 1}}, "df": 21, "[": {"docs": {}, "df": 0, "f": {"docs": {}, "df": 0, "r": {"docs": {}, "df": 0, "o": {"docs": {}, "df": 0, "z": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "n": {"docs": {}, "df": 0, "s": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "t": {"docs": {}, "df": 0, "[": {"docs": {}, "df": 0, "s": {"docs": {}, "df": 0, "t": {"docs": {}, "df": 0, "r": {"docs": {"shimmer.modules.global_workspace.GlobalWorkspaceBase.encode_and_fuse": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.fuse": {"tf": 1}}, "df": 2}}}}}}}}}}}}}, "s": {"docs": {}, "df": 0, "t": {"docs": 
{}, "df": 0, "r": {"docs": {"shimmer.modules.global_workspace.GlobalWorkspaceBase.decode": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.batch_gw_states": {"tf": 1}, "shimmer.modules.global_workspace.freeze_domain_modules": {"tf": 1}, "shimmer.modules.selection.SelectionBase.forward": {"tf": 1}, "shimmer.modules.selection.SingleDomainSelection.forward": {"tf": 1}, "shimmer.modules.selection.FixedSharedSelection.forward": {"tf": 1}, "shimmer.modules.selection.KQFixedQSelection.forward": {"tf": 1}, "shimmer.modules.selection.RandomSelection.forward": {"tf": 1}, "shimmer.modules.selection.DynamicQueryAttention.fuse_weighted_encodings": {"tf": 1}, "shimmer.modules.selection.DynamicQueryAttention.forward": {"tf": 1}, "shimmer.modules.losses.demi_cycle_loss": {"tf": 1}, "shimmer.modules.losses.cycle_loss": {"tf": 1}, "shimmer.modules.losses.translation_loss": {"tf": 1}, "shimmer.modules.losses.contrastive_loss": {"tf": 1}, "shimmer.modules.losses.contrastive_loss_bayesian": {"tf": 1}, "shimmer.modules.losses.GWLosses2Domains.__init__": {"tf": 1}, "shimmer.modules.losses.GWLosses2Domains.demi_cycle_loss": {"tf": 1}, "shimmer.modules.losses.GWLosses2Domains.cycle_loss": {"tf": 1}, "shimmer.modules.losses.GWLosses2Domains.translation_loss": {"tf": 1}, "shimmer.modules.losses.GWLosses2Domains.contrastive_loss": {"tf": 1}, "shimmer.modules.losses.GWLossesBayesian.__init__": {"tf": 1}, "shimmer.modules.losses.GWLossesBayesian.contrastive_loss": {"tf": 1}, "shimmer.modules.utils.batch_demi_cycles": {"tf": 1}}, "df": 23}}}, "t": {"docs": {}, "df": 0, "u": {"docs": {}, "df": 0, "p": {"docs": {}, "df": 0, "l": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "[": {"docs": {}, "df": 0, "s": {"docs": {}, "df": 0, "t": {"docs": {}, "df": 0, "r": {"docs": {"shimmer.modules.utils.batch_cycles": {"tf": 1}, "shimmer.modules.utils.batch_translations": {"tf": 1}}, "df": 2}}}}}}}}}}, "i": {"docs": {}, "df": 0, "o": {"docs": {}, "df": 0, "n": {"docs": {}, "df": 0, "a": 
{"docs": {}, "df": 0, "r": {"docs": {}, "df": 0, "y": {"docs": {"shimmer.modules.losses.broadcast_loss": {"tf": 1}, "shimmer.modules.losses.GWLosses.contrastive_loss": {"tf": 1}, "shimmer.utils.SaveMigrations.on_save_checkpoint": {"tf": 1}}, "df": 3}}}}}}}}, "f": {"docs": {}, "df": 0, "f": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "r": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "n": {"docs": {}, "df": 0, "t": {"docs": {"shimmer.types.RawDomainGroupsT": {"tf": 1}, "shimmer.types.RawDomainGroupsDT": {"tf": 1}, "shimmer.types.LatentsDomainGroupsT": {"tf": 1}, "shimmer.types.LatentsDomainGroupsDT": {"tf": 1}, "shimmer.modules.domain.DomainModule.compute_dcy_loss": {"tf": 1}, "shimmer.modules.domain.DomainModule.compute_cy_loss": {"tf": 1}, "shimmer.modules.domain.DomainModule.compute_tr_loss": {"tf": 1}, "shimmer.modules.domain.DomainModule.compute_broadcast_loss": {"tf": 1}, "shimmer.modules.losses.GWLossesBase": {"tf": 1}}, "df": 9}}}}}}}, "s": {"docs": {}, "df": 0, "t": {"docs": {}, "df": 0, "r": {"docs": {}, "df": 0, "i": {"docs": {}, "df": 0, "b": {"docs": {}, "df": 0, "u": {"docs": {}, "df": 0, "t": {"docs": {}, "df": 0, "i": {"docs": {}, "df": 0, "o": {"docs": {}, "df": 0, "n": {"docs": {"shimmer.types.ModelModeT": {"tf": 1}, "shimmer.modules.selection.SingleDomainSelection": {"tf": 1}, "shimmer.modules.selection.RandomSelection": {"tf": 1}, "shimmer.modules.vae.reparameterize": {"tf": 1}}, "df": 4}}}}}}}}}}, "m": {"docs": {"shimmer.modules.global_workspace.GlobalWorkspace2Domains.__init__": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspace.__init__": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBayesian.__init__": {"tf": 1}, "shimmer.modules.global_workspace.pretrained_global_workspace": {"tf": 1}, "shimmer.modules.domain.DomainModule.__init__": {"tf": 1}, "shimmer.modules.gw_module.get_n_layers": {"tf": 1}, "shimmer.modules.gw_module.GWDecoder.__init__": {"tf": 1.7320508075688772}, 
"shimmer.modules.gw_module.GWEncoder.__init__": {"tf": 1.7320508075688772}, "shimmer.modules.gw_module.GWModuleBase.__init__": {"tf": 1}, "shimmer.modules.gw_module.GWModule.__init__": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBayesian.__init__": {"tf": 1}, "shimmer.modules.selection.KQFixedQSelection.__init__": {"tf": 1}, "shimmer.modules.selection.DynamicQueryAttention.__init__": {"tf": 1}}, "df": 13, "e": {"docs": {}, "df": 0, "n": {"docs": {}, "df": 0, "s": {"docs": {}, "df": 0, "i": {"docs": {}, "df": 0, "o": {"docs": {}, "df": 0, "n": {"docs": {"shimmer.modules.global_workspace.GlobalWorkspaceBase.workspace_dim": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspace2Domains.__init__": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspace.__init__": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBayesian.__init__": {"tf": 1}, "shimmer.modules.global_workspace.pretrained_global_workspace": {"tf": 1}, "shimmer.modules.domain.DomainModule.__init__": {"tf": 1}, "shimmer.modules.domain.DomainModule.latent_dim": {"tf": 1}, "shimmer.modules.gw_module.get_n_layers": {"tf": 1}, "shimmer.modules.gw_module.GWDecoder.__init__": {"tf": 1.7320508075688772}, "shimmer.modules.gw_module.GWDecoder.in_dim": {"tf": 1}, "shimmer.modules.gw_module.GWDecoder.hidden_dim": {"tf": 1}, "shimmer.modules.gw_module.GWDecoder.out_dim": {"tf": 1}, "shimmer.modules.gw_module.GWEncoder.__init__": {"tf": 1.7320508075688772}, "shimmer.modules.gw_module.GWModuleBase.__init__": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBase.workspace_dim": {"tf": 1}, "shimmer.modules.gw_module.GWModule.__init__": {"tf": 1}, "shimmer.modules.gw_module.compute_fusion_scores": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBayesian.__init__": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBayesian.fuse": {"tf": 1}, "shimmer.modules.selection.KQFixedQSelection.__init__": {"tf": 1.4142135623730951}, "shimmer.modules.selection.DynamicQueryAttention.__init__": {"tf": 1.4142135623730951}}, 
"df": 21}}}}}}, "s": {"docs": {"shimmer.modules.selection.KQFixedQSelection.__init__": {"tf": 1}, "shimmer.modules.selection.DynamicQueryAttention.__init__": {"tf": 1}}, "df": 2}}, "r": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "c": {"docs": {}, "df": 0, "t": {"docs": {}, "df": 0, "l": {"docs": {}, "df": 0, "y": {"docs": {"shimmer.modules.selection.RandomSelection.forward": {"tf": 1}}, "df": 1}}}}}}, "v": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "r": {"docs": {}, "df": 0, "g": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "n": {"docs": {}, "df": 0, "c": {"docs": {}, "df": 0, "e": {"docs": {"shimmer.modules.vae.kl_divergence_loss": {"tf": 1}}, "df": 1}}}}}}}}}, "u": {"docs": {}, "df": 0, "r": {"docs": {}, "df": 0, "i": {"docs": {}, "df": 0, "n": {"docs": {}, "df": 0, "g": {"docs": {"shimmer.modules.domain.LossOutput.loss": {"tf": 1}, "shimmer.modules.domain.LossOutput.metrics": {"tf": 1}}, "df": 2}}}}}, "y": {"docs": {}, "df": 0, "n": {"docs": {}, "df": 0, "a": {"docs": {}, "df": 0, "m": {"docs": {}, "df": 0, "i": {"docs": {}, "df": 0, "c": {"docs": {"shimmer.modules.selection.KQFixedQSelection.forward": {"tf": 1}, "shimmer.modules.selection.DynamicQueryAttention": {"tf": 1}, "shimmer.modules.selection.DynamicQueryAttention.forward": {"tf": 1}}, "df": 3}}}}}}, "c": {"docs": {}, "df": 0, "y": {"docs": {"shimmer.modules.losses.demi_cycle_loss": {"tf": 1}}, "df": 1}}, "r": {"docs": {}, "df": 0, "o": {"docs": {}, "df": 0, "p": {"docs": {"shimmer.dataset.RepeatedDataset": {"tf": 1.4142135623730951}, "shimmer.dataset.RepeatedDataset.__init__": {"tf": 1}}, "df": 2}}}}, "f": {"docs": {}, "df": 0, "r": {"docs": {}, "df": 0, "o": {"docs": {}, "df": 0, "m": {"docs": {"shimmer.types.RawDomainGroupT": {"tf": 1}, "shimmer.types.RawDomainGroupDT": {"tf": 1}, "shimmer.types.LatentsDomainGroupT": {"tf": 1}, "shimmer.types.LatentsDomainGroupDT": {"tf": 1}, "shimmer.modules.global_workspace.GWPredictionsBase.states": {"tf": 1}, 
"shimmer.modules.global_workspace.GlobalWorkspaceBase.encode_domain": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.decode_domain": {"tf": 1}, "shimmer.modules.global_workspace.GWPredictions.cycles": {"tf": 1}, "shimmer.modules.global_workspace.GWPredictions.translations": {"tf": 1}, "shimmer.modules.global_workspace.pretrained_global_workspace": {"tf": 1.4142135623730951}, "shimmer.modules.domain.DomainModule.decode": {"tf": 1}, "shimmer.modules.selection.SingleDomainSelection": {"tf": 1}, "shimmer.modules.losses.translation_loss": {"tf": 1}, "shimmer.modules.vae.reparameterize": {"tf": 1}, "shimmer.modules.vae.VAE.forward": {"tf": 1}, "shimmer.modules.utils.translation": {"tf": 1}, "shimmer.modules.utils.cycle": {"tf": 1}}, "df": 17}, "z": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "n": {"docs": {"shimmer.modules.global_workspace.freeze_domain_modules": {"tf": 1}}, "df": 1, "s": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "t": {"docs": {"shimmer.types.RawDomainGroupsT": {"tf": 2}, "shimmer.types.RawDomainGroupsDT": {"tf": 2}, "shimmer.types.LatentsDomainGroupsT": {"tf": 2}, "shimmer.types.LatentsDomainGroupsDT": {"tf": 2}}, "df": 4}}}}}}}, "e": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "z": {"docs": {}, "df": 0, "e": {"docs": {"shimmer.modules.global_workspace.freeze_domain_modules": {"tf": 1}}, "df": 1, "s": {"docs": {"shimmer.modules.global_workspace.freeze_domain_modules": {"tf": 1}}, "df": 1}}}}}, "a": {"docs": {}, "df": 0, "c": {"docs": {}, "df": 0, "{": {"1": {"docs": {}, "df": 0, "}": {"docs": {}, "df": 0, "{": {"docs": {}, "df": 0, "\\": {"docs": {}, "df": 0, "s": {"docs": {}, "df": 0, "u": {"docs": {}, "df": 0, "m": {"docs": {"shimmer.modules.gw_module.GWModuleBayesian.fuse": {"tf": 1.7320508075688772}}, "df": 1}}, "i": {"docs": {}, "df": 0, "g": {"docs": {}, "df": 0, "m": {"docs": {}, "df": 0, "a": {"docs": {"shimmer.modules.gw_module.GWModuleBayesian.fuse": {"tf": 1}}, "df": 1}}}}}, "m": {"docs": {}, "df": 0, "u": 
{"docs": {"shimmer.modules.gw_module.GWModuleBayesian.fuse": {"tf": 1}}, "df": 1}}, "f": {"docs": {}, "df": 0, "r": {"docs": {}, "df": 0, "a": {"docs": {}, "df": 0, "c": {"docs": {}, "df": 0, "{": {"docs": {}, "df": 0, "c": {"docs": {"shimmer.modules.gw_module.GWModuleBayesian.fuse": {"tf": 1}}, "df": 1}}}}}}}}}}, "docs": {}, "df": 0, "m": {"docs": {"shimmer.modules.gw_module.GWModuleBayesian.fuse": {"tf": 1.7320508075688772}}, "df": 1}, "c": {"docs": {"shimmer.modules.gw_module.GWModuleBayesian.fuse": {"tf": 2}}, "df": 1}}}}}, "u": {"docs": {}, "df": 0, "n": {"docs": {"shimmer.types.RawDomainGroupT": {"tf": 1.4142135623730951}, "shimmer.types.RawDomainGroupDT": {"tf": 1}, "shimmer.types.LatentsDomainGroupT": {"tf": 1.4142135623730951}, "shimmer.types.LatentsDomainGroupDT": {"tf": 1}, "shimmer.types.RawDomainGroupsT": {"tf": 1}, "shimmer.types.RawDomainGroupsDT": {"tf": 1}, "shimmer.types.LatentsDomainGroupsT": {"tf": 1.4142135623730951}, "shimmer.types.LatentsDomainGroupsDT": {"tf": 1}}, "df": 8, "c": {"docs": {}, "df": 0, "t": {"docs": {}, "df": 0, "i": {"docs": {}, "df": 0, "o": {"docs": {}, "df": 0, "n": {"docs": {"shimmer.modules.global_workspace.GlobalWorkspace2Domains.__init__": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspace.__init__": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBayesian.__init__": {"tf": 1}, "shimmer.modules.global_workspace.pretrained_global_workspace": {"tf": 1}, "shimmer.modules.gw_module.GWEncoder.forward": {"tf": 1}, "shimmer.modules.gw_module.GWEncoderLinear.forward": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBase.fuse": {"tf": 1}, "shimmer.modules.gw_module.GWModule.fuse": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBayesian.fuse": {"tf": 1.4142135623730951}, "shimmer.modules.selection.RandomSelection.forward": {"tf": 1}, "shimmer.modules.losses.contrastive_loss": {"tf": 1}, "shimmer.modules.losses.contrastive_loss_bayesian": {"tf": 1}, "shimmer.modules.losses.GWLosses2Domains.__init__": {"tf": 1}, 
"shimmer.modules.losses.GWLosses.__init__": {"tf": 1}, "shimmer.modules.losses.GWLossesBayesian.__init__": {"tf": 1}, "shimmer.modules.contrastive_loss.ContrastiveLossType": {"tf": 1.4142135623730951}, "shimmer.modules.contrastive_loss.ContrastiveLossBayesianType": {"tf": 1.4142135623730951}}, "df": 17, "s": {"docs": {"shimmer.types.RawDomainGroupT": {"tf": 1.4142135623730951}, "shimmer.types.RawDomainGroupDT": {"tf": 1.4142135623730951}, "shimmer.types.LatentsDomainGroupT": {"tf": 1.4142135623730951}, "shimmer.types.LatentsDomainGroupDT": {"tf": 1.4142135623730951}, "shimmer.types.RawDomainGroupsT": {"tf": 1.4142135623730951}, "shimmer.types.RawDomainGroupsDT": {"tf": 1.4142135623730951}, "shimmer.types.LatentsDomainGroupsT": {"tf": 1.4142135623730951}, "shimmer.types.LatentsDomainGroupsDT": {"tf": 1.4142135623730951}}, "df": 8}}}}}}}, "s": {"docs": {}, "df": 0, "i": {"docs": {}, "df": 0, "o": {"docs": {}, "df": 0, "n": {"docs": {"shimmer.modules.global_workspace.GlobalWorkspaceBase.encode": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.fuse": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspace2Domains.__init__": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspace": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspace.__init__": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBayesian.__init__": {"tf": 1}, "shimmer.modules.global_workspace.pretrained_global_workspace": {"tf": 1}, "shimmer.modules.domain.DomainModule.compute_broadcast_loss": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBase.encode": {"tf": 1.4142135623730951}, "shimmer.modules.gw_module.GWModule.__init__": {"tf": 1}, "shimmer.modules.gw_module.GWModule.encode": {"tf": 1.4142135623730951}, "shimmer.modules.gw_module.GWModuleBayesian.__init__": {"tf": 1}, "shimmer.modules.selection.SelectionBase": {"tf": 1}, "shimmer.modules.selection.SelectionBase.forward": {"tf": 1}, "shimmer.modules.selection.KQFixedQSelection.forward": {"tf": 1}, 
"shimmer.modules.selection.RandomSelection.forward": {"tf": 1}, "shimmer.modules.selection.DynamicQueryAttention.forward": {"tf": 1}, "shimmer.modules.losses.GWLosses": {"tf": 1}, "shimmer.modules.losses.GWLosses.__init__": {"tf": 1}}, "df": 19}}}, "e": {"docs": {"shimmer.modules.gw_module.GWModuleBase.encode_and_fuse": {"tf": 1}, "shimmer.modules.selection.DynamicQueryAttention.fuse_weighted_encodings": {"tf": 1}}, "df": 2, "s": {"docs": {"shimmer.modules.global_workspace.GlobalWorkspaceBase.fuse": {"tf": 1}}, "df": 1}, "d": {"docs": {"shimmer.modules.selection.DynamicQueryAttention.fuse_weighted_encodings": {"tf": 1}, "shimmer.modules.losses.BroadcastLossCoefs.fused": {"tf": 1}}, "df": 2}}}, "l": {"docs": {}, "df": 0, "l": {"docs": {"shimmer.modules.utils.cycle": {"tf": 1}}, "df": 1}}}, "o": {"docs": {}, "df": 0, "r": {"docs": {"shimmer.types.RawDomainGroupT": {"tf": 1.4142135623730951}, "shimmer.types.RawDomainGroupDT": {"tf": 1.4142135623730951}, "shimmer.types.LatentsDomainGroupT": {"tf": 1.4142135623730951}, "shimmer.types.LatentsDomainGroupDT": {"tf": 1.4142135623730951}, "shimmer.types.RawDomainGroupsT": {"tf": 1.4142135623730951}, "shimmer.types.RawDomainGroupsDT": {"tf": 1.4142135623730951}, "shimmer.types.LatentsDomainGroupsT": {"tf": 1.4142135623730951}, "shimmer.types.LatentsDomainGroupsDT": {"tf": 1.4142135623730951}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.fuse": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.batch_gw_states": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.encode_domain": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.decode_domain": {"tf": 1}, "shimmer.modules.global_workspace.freeze_domain_modules": {"tf": 1}, "shimmer.modules.global_workspace.GWPredictions.demi_cycles": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspace2Domains.__init__": {"tf": 1.7320508075688772}, "shimmer.modules.global_workspace.GlobalWorkspace2Domains.forward": {"tf": 1}, 
"shimmer.modules.global_workspace.GlobalWorkspace.__init__": {"tf": 2.23606797749979}, "shimmer.modules.global_workspace.GlobalWorkspace.forward": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBayesian.__init__": {"tf": 1.7320508075688772}, "shimmer.modules.global_workspace.GlobalWorkspaceBayesian.forward": {"tf": 1}, "shimmer.modules.global_workspace.pretrained_global_workspace": {"tf": 1.7320508075688772}, "shimmer.modules.domain.LossOutput": {"tf": 1.7320508075688772}, "shimmer.modules.domain.DomainModule": {"tf": 1}, "shimmer.modules.domain.DomainModule.compute_dcy_loss": {"tf": 1}, "shimmer.modules.domain.DomainModule.compute_cy_loss": {"tf": 1}, "shimmer.modules.domain.DomainModule.compute_tr_loss": {"tf": 1}, "shimmer.modules.domain.DomainModule.compute_broadcast_loss": {"tf": 1}, "shimmer.modules.gw_module.GWDecoder": {"tf": 1}, "shimmer.modules.gw_module.GWEncoder.forward": {"tf": 1}, "shimmer.modules.gw_module.GWEncoderLinear.forward": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBase": {"tf": 1.4142135623730951}, "shimmer.modules.gw_module.GWModule.__init__": {"tf": 1.4142135623730951}, "shimmer.modules.gw_module.compute_fusion_scores": {"tf": 1.4142135623730951}, "shimmer.modules.gw_module.GWModuleBayesian.__init__": {"tf": 1.4142135623730951}, "shimmer.modules.gw_module.GWModuleBayesian.precisions": {"tf": 1}, "shimmer.modules.selection.SelectionBase": {"tf": 1.4142135623730951}, "shimmer.modules.selection.SelectionBase.forward": {"tf": 1.4142135623730951}, "shimmer.modules.selection.SingleDomainSelection": {"tf": 1}, "shimmer.modules.selection.SingleDomainSelection.forward": {"tf": 1}, "shimmer.modules.selection.FixedSharedSelection": {"tf": 1.7320508075688772}, "shimmer.modules.selection.FixedSharedSelection.forward": {"tf": 1}, "shimmer.modules.selection.KQFixedQSelection.__init__": {"tf": 1}, "shimmer.modules.selection.KQFixedQSelection.forward": {"tf": 1}, "shimmer.modules.selection.RandomSelection.forward": {"tf": 
1.7320508075688772}, "shimmer.modules.selection.DynamicQueryAttention.__init__": {"tf": 1}, "shimmer.modules.selection.DynamicQueryAttention.fuse_weighted_encodings": {"tf": 1}, "shimmer.modules.selection.DynamicQueryAttention.forward": {"tf": 1}, "shimmer.modules.losses.GWLossesBase": {"tf": 1}, "shimmer.modules.losses.GWLosses2Domains": {"tf": 1}, "shimmer.modules.losses.generate_partitions": {"tf": 1.4142135623730951}, "shimmer.modules.losses.GWLosses": {"tf": 1}, "shimmer.modules.losses.GWLosses.__init__": {"tf": 2.23606797749979}, "shimmer.modules.losses.GWLosses.contrastive_loss": {"tf": 1}, "shimmer.modules.losses.GWLosses.step": {"tf": 1.4142135623730951}, "shimmer.modules.losses.GWLossesBayesian": {"tf": 1}, "shimmer.modules.losses.GWLossesBayesian.step": {"tf": 1.4142135623730951}, "shimmer.modules.contrastive_loss.ContrastiveLossBayesianType": {"tf": 1}, "shimmer.modules.vae.reparameterize": {"tf": 1}, "shimmer.modules.vae.VAEEncoder": {"tf": 1}, "shimmer.modules.vae.VAEDecoder": {"tf": 1}, "shimmer.modules.vae.VAE.__init__": {"tf": 1}, "shimmer.modules.vae.VAE.beta": {"tf": 1}, "shimmer.modules.utils.batch_demi_cycles": {"tf": 1}, "shimmer.modules.utils.batch_cycles": {"tf": 1.4142135623730951}, "shimmer.modules.utils.batch_translations": {"tf": 1}, "shimmer.cli.ckpt_migration.migrate_ckpt": {"tf": 1}}, "df": 66, "w": {"docs": {}, "df": 0, "a": {"docs": {}, "df": 0, "r": {"docs": {}, "df": 0, "d": {"docs": {"shimmer.modules.gw_module.GWEncoder.forward": {"tf": 1}, "shimmer.modules.gw_module.GWEncoderLinear.forward": {"tf": 1}, "shimmer.modules.selection.SelectionBase.forward": {"tf": 1.4142135623730951}, "shimmer.modules.selection.SingleDomainSelection.forward": {"tf": 1}, "shimmer.modules.selection.FixedSharedSelection.forward": {"tf": 1}}, "df": 5}}}}, "m": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "r": {"docs": {"shimmer.modules.gw_module.GWEncoder.forward": {"tf": 1}, "shimmer.modules.gw_module.GWEncoderLinear.forward": {"tf": 1}}, "df": 
2}}}}, "l": {"docs": {}, "df": 0, "l": {"docs": {}, "df": 0, "o": {"docs": {}, "df": 0, "w": {"docs": {}, "df": 0, "i": {"docs": {}, "df": 0, "n": {"docs": {}, "df": 0, "g": {"docs": {"shimmer.modules.gw_module.GWModuleBayesian.fuse": {"tf": 1}}, "df": 1}}}, "e": {"docs": {}, "df": 0, "d": {"docs": {"shimmer.modules.selection.RandomSelection": {"tf": 1}}, "df": 1}}}}}}, "c": {"docs": {}, "df": 0, "u": {"docs": {}, "df": 0, "s": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "s": {"docs": {"shimmer.modules.selection.RandomSelection": {"tf": 1}}, "df": 1}}}}}}, "l": {"docs": {}, "df": 0, "a": {"docs": {}, "df": 0, "v": {"docs": {}, "df": 0, "o": {"docs": {}, "df": 0, "r": {"docs": {"shimmer.modules.global_workspace.GlobalWorkspace2Domains": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspace": {"tf": 1}, "shimmer.modules.global_workspace.pretrained_global_workspace": {"tf": 1}}, "df": 3}}}}, "o": {"docs": {}, "df": 0, "a": {"docs": {}, "df": 0, "t": {"docs": {"shimmer.modules.global_workspace.GlobalWorkspace2Domains.__init__": {"tf": 1.4142135623730951}, "shimmer.modules.global_workspace.GlobalWorkspace.__init__": {"tf": 1.7320508075688772}, "shimmer.modules.global_workspace.GlobalWorkspaceBayesian.__init__": {"tf": 2.23606797749979}, "shimmer.modules.gw_module.compute_fusion_scores": {"tf": 1.7320508075688772}, "shimmer.modules.gw_module.GWModuleBayesian.__init__": {"tf": 1.7320508075688772}, "shimmer.modules.selection.RandomSelection.__init__": {"tf": 1}, "shimmer.modules.losses.GWLosses2Domains.__init__": {"tf": 1}, "shimmer.modules.vae.VAE.__init__": {"tf": 1}}, "df": 8}}}}, "i": {"docs": {}, "df": 0, "n": {"docs": {}, "df": 0, "a": {"docs": {}, "df": 0, "l": {"docs": {"shimmer.modules.gw_module.GWModuleBase.encode_and_fuse": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBayesian.fuse": {"tf": 1}, "shimmer.dataset.RepeatedDataset.__init__": {"tf": 1}}, "df": 3, "l": {"docs": {}, "df": 0, "y": {"docs": {"shimmer.modules.gw_module.GWModuleBayesian.fuse": 
{"tf": 1}}, "df": 1}}, "e": {"docs": {"shimmer.modules.gw_module.GWModuleBayesian.fuse": {"tf": 1}}, "df": 1}}}}, "r": {"docs": {}, "df": 0, "s": {"docs": {}, "df": 0, "t": {"docs": {"shimmer.modules.gw_module.compute_fusion_scores": {"tf": 1.4142135623730951}, "shimmer.modules.vae.VAE.forward": {"tf": 1}}, "df": 2}}}, "x": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "d": {"docs": {"shimmer.modules.selection.KQFixedQSelection": {"tf": 1}}, "df": 1}}}}, "a": {"docs": {}, "df": 0, "c": {"docs": {}, "df": 0, "t": {"docs": {}, "df": 0, "o": {"docs": {}, "df": 0, "r": {"docs": {}, "df": 0, "s": {"docs": {"shimmer.modules.selection.RandomSelection": {"tf": 1}, "shimmer.modules.selection.RandomSelection.__init__": {"tf": 1}}, "df": 2}}}}}, "l": {"docs": {}, "df": 0, "s": {"docs": {}, "df": 0, "e": {"docs": {"shimmer.modules.contrastive_loss.ContrastiveLoss.__init__": {"tf": 1}, "shimmer.dataset.RepeatedDataset": {"tf": 1}}, "df": 2}}}}, "n": {"docs": {"shimmer.modules.losses.contrastive_loss": {"tf": 1}, "shimmer.modules.losses.contrastive_loss_bayesian": {"tf": 1}, "shimmer.modules.losses.GWLosses2Domains.__init__": {"tf": 1}, "shimmer.modules.losses.GWLosses.__init__": {"tf": 1}, "shimmer.modules.losses.GWLossesBayesian.__init__": {"tf": 1}}, "df": 5}}, "k": {"docs": {"shimmer.modules.gw_module.GWModuleBayesian.fuse": {"tf": 3}}, "df": 1, "e": {"docs": {}, "df": 0, "y": {"docs": {"shimmer.modules.global_workspace.GWPredictionsBase.states": {"tf": 1}, "shimmer.modules.domain.LossOutput.all": {"tf": 1}, "shimmer.modules.selection.KQFixedQSelection": {"tf": 1}, "shimmer.modules.selection.KQFixedQSelection.__init__": {"tf": 1}, "shimmer.modules.selection.DynamicQueryAttention": {"tf": 1}, "shimmer.modules.selection.DynamicQueryAttention.__init__": {"tf": 1}, "shimmer.modules.losses.GWLosses2Domains.__init__": {"tf": 1}, "shimmer.modules.losses.GWLossesBayesian.__init__": {"tf": 1}}, "df": 8, "s": {"docs": {"shimmer.types.RawDomainGroupT": {"tf": 1}, 
"shimmer.types.RawDomainGroupDT": {"tf": 1}, "shimmer.types.LatentsDomainGroupT": {"tf": 1}, "shimmer.types.LatentsDomainGroupDT": {"tf": 1}, "shimmer.types.RawDomainGroupsT": {"tf": 1}, "shimmer.types.RawDomainGroupsDT": {"tf": 1}, "shimmer.types.LatentsDomainGroupsT": {"tf": 1}, "shimmer.types.LatentsDomainGroupsDT": {"tf": 1}, "shimmer.modules.global_workspace.GWPredictions.cycles": {"tf": 1}, "shimmer.modules.global_workspace.GWPredictions.translations": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspace2Domains.__init__": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspace.__init__": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBayesian.__init__": {"tf": 1}, "shimmer.modules.global_workspace.pretrained_global_workspace": {"tf": 1}, "shimmer.modules.gw_module.GWModule.decode": {"tf": 1}, "shimmer.modules.selection.KQFixedQSelection.forward": {"tf": 1}, "shimmer.modules.selection.DynamicQueryAttention.forward": {"tf": 1}, "shimmer.modules.losses.GWLosses2Domains.__init__": {"tf": 1}}, "df": 18}}, "e": {"docs": {}, "df": 0, "p": {"docs": {}, "df": 0, "s": {"docs": {"shimmer.modules.domain.LossOutput": {"tf": 1}}, "df": 1}}}}, "w": {"docs": {}, "df": 0, "a": {"docs": {}, "df": 0, "r": {"docs": {}, "df": 0, "g": {"docs": {}, "df": 0, "s": {"docs": {"shimmer.modules.global_workspace.pretrained_global_workspace": {"tf": 1}, "shimmer.utils.migrate_model": {"tf": 1}}, "df": 2}}}}}, "i": {"docs": {}, "df": 0, "n": {"docs": {}, "df": 0, "d": {"docs": {"shimmer.modules.gw_module.GWModuleBayesian.fuse": {"tf": 1}}, "df": 1}}}, "^": {"2": {"docs": {"shimmer.modules.gw_module.GWModuleBayesian.fuse": {"tf": 2.449489742783178}}, "df": 1}, "docs": {}, "df": 0}, "l": {"docs": {"shimmer.modules.vae.kl_divergence_loss": {"tf": 1}}, "df": 1}}, "o": {"docs": {}, "df": 0, "f": {"docs": {"shimmer.types.RawDomainGroupT": {"tf": 1.4142135623730951}, "shimmer.types.RawDomainGroupDT": {"tf": 1.7320508075688772}, "shimmer.types.LatentsDomainGroupT": {"tf": 
1.4142135623730951}, "shimmer.types.LatentsDomainGroupDT": {"tf": 1.4142135623730951}, "shimmer.types.RawDomainGroupsT": {"tf": 2}, "shimmer.types.RawDomainGroupsDT": {"tf": 2}, "shimmer.types.LatentsDomainGroupsT": {"tf": 1.4142135623730951}, "shimmer.types.LatentsDomainGroupsDT": {"tf": 1.4142135623730951}, "shimmer.types.ModelModeT": {"tf": 1}, "shimmer.modules.global_workspace.SchedulerArgs": {"tf": 1}, "shimmer.modules.global_workspace.SchedulerArgs.total_steps": {"tf": 1}, "shimmer.modules.global_workspace.GWPredictionsBase": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.loss_mod": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.workspace_dim": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.encode_and_fuse": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.encode": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.fuse": {"tf": 1.4142135623730951}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.decode": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.batch_gw_states": {"tf": 2.23606797749979}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.encode_domains": {"tf": 1.4142135623730951}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.decode_domains": {"tf": 1.4142135623730951}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.generic_step": {"tf": 1.4142135623730951}, "shimmer.modules.global_workspace.freeze_domain_modules": {"tf": 1}, "shimmer.modules.global_workspace.GWPredictions": {"tf": 1}, "shimmer.modules.global_workspace.GWPredictions.demi_cycles": {"tf": 1}, "shimmer.modules.global_workspace.GWPredictions.cycles": {"tf": 1}, "shimmer.modules.global_workspace.GWPredictions.translations": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspace2Domains": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspace2Domains.__init__": {"tf": 1.4142135623730951}, "shimmer.modules.global_workspace.GlobalWorkspace2Domains.forward": 
{"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspace": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspace.__init__": {"tf": 1.4142135623730951}, "shimmer.modules.global_workspace.GlobalWorkspace.forward": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBayesian.__init__": {"tf": 1.7320508075688772}, "shimmer.modules.global_workspace.GlobalWorkspaceBayesian.forward": {"tf": 1}, "shimmer.modules.global_workspace.pretrained_global_workspace": {"tf": 1.7320508075688772}, "shimmer.modules.domain.LossOutput": {"tf": 1}, "shimmer.modules.domain.DomainModule": {"tf": 1}, "shimmer.modules.domain.DomainModule.__init__": {"tf": 1}, "shimmer.modules.domain.DomainModule.latent_dim": {"tf": 1}, "shimmer.modules.domain.DomainModule.encode": {"tf": 1}, "shimmer.modules.domain.DomainModule.decode": {"tf": 1}, "shimmer.modules.domain.DomainModule.compute_loss": {"tf": 1}, "shimmer.modules.domain.DomainModule.compute_dcy_loss": {"tf": 1}, "shimmer.modules.domain.DomainModule.compute_cy_loss": {"tf": 1}, "shimmer.modules.domain.DomainModule.compute_tr_loss": {"tf": 1}, "shimmer.modules.domain.DomainModule.compute_broadcast_loss": {"tf": 1}, "shimmer.modules.gw_module.get_n_layers": {"tf": 2}, "shimmer.modules.gw_module.GWDecoder.__init__": {"tf": 1.4142135623730951}, "shimmer.modules.gw_module.GWDecoder.n_layers": {"tf": 1.4142135623730951}, "shimmer.modules.gw_module.GWEncoder.__init__": {"tf": 1.4142135623730951}, "shimmer.modules.gw_module.GWEncoder.forward": {"tf": 1.4142135623730951}, "shimmer.modules.gw_module.GWEncoderLinear.forward": {"tf": 1.4142135623730951}, "shimmer.modules.gw_module.GWModuleBase.__init__": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBase.workspace_dim": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBase.fuse": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBase.decode": {"tf": 1}, "shimmer.modules.gw_module.GWModule.__init__": {"tf": 1}, "shimmer.modules.gw_module.GWModule.fuse": {"tf": 1}, 
"shimmer.modules.gw_module.GWModuleBayesian.__init__": {"tf": 1.4142135623730951}, "shimmer.modules.gw_module.GWModuleBayesian.get_precision": {"tf": 1.7320508075688772}, "shimmer.modules.gw_module.GWModuleBayesian.fuse": {"tf": 2.23606797749979}, "shimmer.modules.selection.SelectionBase.update_gw_state": {"tf": 1}, "shimmer.modules.selection.SelectionBase.forward": {"tf": 1.4142135623730951}, "shimmer.modules.selection.SingleDomainSelection": {"tf": 1.4142135623730951}, "shimmer.modules.selection.SingleDomainSelection.forward": {"tf": 1}, "shimmer.modules.selection.FixedSharedSelection.forward": {"tf": 1}, "shimmer.modules.selection.KQFixedQSelection.__init__": {"tf": 1.7320508075688772}, "shimmer.modules.selection.KQFixedQSelection.forward": {"tf": 1.4142135623730951}, "shimmer.modules.selection.RandomSelection.__init__": {"tf": 1}, "shimmer.modules.selection.RandomSelection.forward": {"tf": 1.4142135623730951}, "shimmer.modules.selection.DynamicQueryAttention.__init__": {"tf": 1.7320508075688772}, "shimmer.modules.selection.DynamicQueryAttention.forward": {"tf": 1.4142135623730951}, "shimmer.modules.losses.GWLossesBase": {"tf": 1}, "shimmer.modules.losses.demi_cycle_loss": {"tf": 1.7320508075688772}, "shimmer.modules.losses.cycle_loss": {"tf": 1.7320508075688772}, "shimmer.modules.losses.translation_loss": {"tf": 1.4142135623730951}, "shimmer.modules.losses.contrastive_loss": {"tf": 1.4142135623730951}, "shimmer.modules.losses.contrastive_loss_bayesian": {"tf": 1.4142135623730951}, "shimmer.modules.losses.LossCoefs": {"tf": 1}, "shimmer.modules.losses.GWLosses2Domains": {"tf": 1}, "shimmer.modules.losses.GWLosses2Domains.demi_cycle_loss": {"tf": 1}, "shimmer.modules.losses.GWLosses2Domains.cycle_loss": {"tf": 1}, "shimmer.modules.losses.GWLosses2Domains.translation_loss": {"tf": 1}, "shimmer.modules.losses.GWLosses2Domains.contrastive_loss": {"tf": 1}, "shimmer.modules.losses.generate_partitions": {"tf": 1.7320508075688772}, 
"shimmer.modules.losses.BroadcastLossCoefs": {"tf": 1}, "shimmer.modules.losses.BroadcastLossCoefs.fused": {"tf": 1}, "shimmer.modules.losses.GWLosses": {"tf": 1}, "shimmer.modules.losses.GWLosses.__init__": {"tf": 1}, "shimmer.modules.losses.GWLosses.contrastive_loss": {"tf": 1}, "shimmer.modules.losses.GWLosses.step": {"tf": 1}, "shimmer.modules.losses.GWLossesBayesian": {"tf": 1}, "shimmer.modules.losses.GWLossesBayesian.contrastive_loss": {"tf": 1}, "shimmer.modules.losses.GWLossesBayesian.step": {"tf": 1}, "shimmer.modules.contrastive_loss.ContrastiveLoss.forward": {"tf": 1}, "shimmer.dataset.RepeatedDataset": {"tf": 1}, "shimmer.dataset.RepeatedDataset.__init__": {"tf": 1}, "shimmer.modules.vae.VAE.encode": {"tf": 1}, "shimmer.modules.vae.VAE.forward": {"tf": 1}, "shimmer.modules.utils.translation": {"tf": 1.4142135623730951}, "shimmer.modules.utils.cycle": {"tf": 1.7320508075688772}, "shimmer.modules.utils.batch_demi_cycles": {"tf": 2.23606797749979}, "shimmer.modules.utils.batch_cycles": {"tf": 2.6457513110645907}, "shimmer.modules.utils.batch_translations": {"tf": 2.449489742783178}, "shimmer.utils.groups_batch_size": {"tf": 1.4142135623730951}, "shimmer.utils.groups_device": {"tf": 1.4142135623730951}, "shimmer.utils.migrate_model": {"tf": 1}, "shimmer.utils.SaveMigrations": {"tf": 1}, "shimmer.cli.ckpt_migration.migrate_ckpt": {"tf": 1.4142135623730951}}, "df": 110}, "u": {"docs": {}, "df": 0, "t": {"docs": {"shimmer.types.ModelModeT": {"tf": 1}, "shimmer.modules.gw_module.GWDecoder.__init__": {"tf": 1}, "shimmer.modules.gw_module.GWEncoder.__init__": {"tf": 1}, "shimmer.modules.utils.batch_cycles": {"tf": 1}}, "df": 4, "p": {"docs": {}, "df": 0, "u": {"docs": {}, "df": 0, "t": {"docs": {"shimmer.types.RawDomainGroupDT": {"tf": 1}, "shimmer.modules.global_workspace.GWPredictionsBase": {"tf": 1}, "shimmer.modules.global_workspace.freeze_domain_modules": {"tf": 1}, "shimmer.modules.global_workspace.GWPredictions": {"tf": 1}, 
"shimmer.modules.gw_module.GWDecoder.__init__": {"tf": 1}, "shimmer.modules.gw_module.GWDecoder.out_dim": {"tf": 1}, "shimmer.modules.gw_module.GWEncoder.__init__": {"tf": 1}, "shimmer.modules.losses.demi_cycle_loss": {"tf": 1}, "shimmer.modules.losses.cycle_loss": {"tf": 1}, "shimmer.modules.losses.translation_loss": {"tf": 1}, "shimmer.modules.losses.contrastive_loss": {"tf": 1}, "shimmer.modules.losses.contrastive_loss_bayesian": {"tf": 1}}, "df": 12, "s": {"docs": {"shimmer.types.RawDomainGroupT": {"tf": 1}, "shimmer.types.RawDomainGroupDT": {"tf": 1}, "shimmer.types.LatentsDomainGroupT": {"tf": 1}, "shimmer.types.LatentsDomainGroupDT": {"tf": 1}, "shimmer.types.RawDomainGroupsT": {"tf": 1}, "shimmer.types.RawDomainGroupsDT": {"tf": 1}, "shimmer.types.LatentsDomainGroupsT": {"tf": 1}, "shimmer.types.LatentsDomainGroupsDT": {"tf": 1}}, "df": 8}}}}}}, "t": {"docs": {}, "df": 0, "h": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "r": {"docs": {"shimmer.types.RawDomainGroupT": {"tf": 1}, "shimmer.types.LatentsDomainGroupT": {"tf": 1}, "shimmer.types.LatentsDomainGroupsT": {"tf": 1}, "shimmer.utils.migrate_model": {"tf": 1}}, "df": 4}}}}, "r": {"docs": {"shimmer.types.ModelModeT": {"tf": 1.4142135623730951}, "shimmer.modules.losses.GWLosses2Domains.__init__": {"tf": 1}}, "df": 2, "i": {"docs": {}, "df": 0, "g": {"docs": {}, "df": 0, "i": {"docs": {}, "df": 0, "n": {"docs": {}, "df": 0, "a": {"docs": {}, "df": 0, "l": {"docs": {"shimmer.modules.domain.DomainModule.decode": {"tf": 1}, "shimmer.modules.utils.cycle": {"tf": 1.4142135623730951}}, "df": 2}}}}}}}, "n": {"docs": {"shimmer.modules.global_workspace.GlobalWorkspaceBase.generic_step": {"tf": 1}, "shimmer.modules.global_workspace.GWPredictions.demi_cycles": {"tf": 1}, "shimmer.modules.global_workspace.GWPredictions.cycles": {"tf": 1}, "shimmer.modules.global_workspace.GWPredictions.translations": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspace2Domains.forward": {"tf": 1}, 
"shimmer.modules.global_workspace.GlobalWorkspace.forward": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBayesian.forward": {"tf": 1}, "shimmer.modules.selection.RandomSelection": {"tf": 1}, "shimmer.modules.selection.RandomSelection.forward": {"tf": 1}, "shimmer.modules.selection.DynamicQueryAttention": {"tf": 1}}, "df": 10, "e": {"docs": {"shimmer.modules.global_workspace.GWPredictionsBase.states": {"tf": 1}, "shimmer.modules.global_workspace.GWPredictions.demi_cycles": {"tf": 1}, "shimmer.modules.global_workspace.GWPredictions.cycles": {"tf": 1.7320508075688772}, "shimmer.modules.global_workspace.GWPredictions.translations": {"tf": 1.7320508075688772}, "shimmer.modules.gw_module.GWDecoder.__init__": {"tf": 1.4142135623730951}, "shimmer.modules.gw_module.GWDecoder.n_layers": {"tf": 1.4142135623730951}, "shimmer.modules.gw_module.GWEncoder.__init__": {"tf": 1.4142135623730951}, "shimmer.modules.gw_module.GWEncoder.forward": {"tf": 1}, "shimmer.modules.gw_module.GWEncoderLinear.forward": {"tf": 1}, "shimmer.modules.losses.LossCoefs": {"tf": 1}, "shimmer.modules.losses.BroadcastLossCoefs": {"tf": 1}, "shimmer.modules.losses.BroadcastLossCoefs.fused": {"tf": 1}, "shimmer.modules.losses.BroadcastLossCoefs.demi_cycles": {"tf": 1.4142135623730951}, "shimmer.modules.losses.BroadcastLossCoefs.cycles": {"tf": 1}, "shimmer.modules.losses.BroadcastLossCoefs.translations": {"tf": 1}, "shimmer.modules.utils.translation": {"tf": 1}, "shimmer.modules.utils.cycle": {"tf": 1}}, "df": 17, "c": {"docs": {}, "df": 0, "y": {"docs": {}, "df": 0, "c": {"docs": {}, "df": 0, "l": {"docs": {}, "df": 0, "e": {"docs": {"shimmer.modules.global_workspace.SchedulerArgs": {"tf": 1}}, "df": 1}}}}}, "s": {"docs": {"shimmer.modules.losses.generate_partitions": {"tf": 1.4142135623730951}}, "df": 1}}, "l": {"docs": {}, "df": 0, "y": {"docs": {"shimmer.modules.global_workspace.GWPredictionsBase.states": {"tf": 1}, "shimmer.modules.global_workspace.GWPredictions.demi_cycles": {"tf": 
1.4142135623730951}, "shimmer.modules.global_workspace.GWPredictions.cycles": {"tf": 1}, "shimmer.modules.global_workspace.GWPredictions.translations": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspace2Domains": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspace": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBayesian": {"tf": 1}, "shimmer.modules.domain.LossOutput": {"tf": 1}, "shimmer.modules.selection.RandomSelection": {"tf": 1}}, "df": 9}}, "c": {"docs": {}, "df": 0, "e": {"docs": {"shimmer.modules.selection.KQFixedQSelection.forward": {"tf": 1.4142135623730951}, "shimmer.modules.selection.DynamicQueryAttention.forward": {"tf": 1.4142135623730951}}, "df": 2}}}, "v": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "r": {"docs": {"shimmer.modules.utils.translation": {"tf": 1}, "shimmer.modules.utils.cycle": {"tf": 1}}, "df": 2, "r": {"docs": {}, "df": 0, "i": {"docs": {}, "df": 0, "d": {"docs": {}, "df": 0, "e": {"docs": {"shimmer.modules.domain.DomainModule.compute_dcy_loss": {"tf": 1}, "shimmer.modules.domain.DomainModule.compute_cy_loss": {"tf": 1}, "shimmer.modules.domain.DomainModule.compute_tr_loss": {"tf": 1}, "shimmer.modules.domain.DomainModule.compute_broadcast_loss": {"tf": 1}, "shimmer.utils.SaveMigrations": {"tf": 1}}, "df": 5, "s": {"docs": {"shimmer.modules.global_workspace.GlobalWorkspace2Domains": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspace": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBayesian": {"tf": 1}}, "df": 3}}, "d": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "n": {"docs": {"shimmer.modules.gw_module.GWEncoder.forward": {"tf": 1}, "shimmer.modules.gw_module.GWEncoderLinear.forward": {"tf": 1}}, "df": 2}}}}}}, "f": {"docs": {}, "df": 0, "l": {"docs": {}, "df": 0, "o": {"docs": {}, "df": 0, "w": {"docs": {"shimmer.dataset.RepeatedDataset.__init__": {"tf": 1}}, "df": 1}}}}}}}, "p": {"docs": {}, "df": 0, "t": {"docs": {}, "df": 0, "i": {"docs": {}, "df": 0, "m": 
{"docs": {"shimmer.modules.global_workspace.GlobalWorkspace2Domains.__init__": {"tf": 1.4142135623730951}, "shimmer.modules.global_workspace.GlobalWorkspace.__init__": {"tf": 1.4142135623730951}, "shimmer.modules.global_workspace.GlobalWorkspaceBayesian.__init__": {"tf": 1.4142135623730951}}, "df": 3, "i": {"docs": {}, "df": 0, "z": {"docs": {}, "df": 0, "a": {"docs": {}, "df": 0, "t": {"docs": {}, "df": 0, "i": {"docs": {}, "df": 0, "o": {"docs": {}, "df": 0, "n": {"docs": {"shimmer.modules.global_workspace.GlobalWorkspace2Domains.__init__": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspace.__init__": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBayesian.__init__": {"tf": 1}}, "df": 3}}}}}}}}}}, "e": {"docs": {}, "df": 0, "r": {"docs": {}, "df": 0, "a": {"docs": {}, "df": 0, "t": {"docs": {}, "df": 0, "i": {"docs": {}, "df": 0, "o": {"docs": {}, "df": 0, "n": {"docs": {}, "df": 0, "s": {"docs": {"shimmer.modules.gw_module.GWModuleBase": {"tf": 1}}, "df": 1}}}, "n": {"docs": {}, "df": 0, "g": {"docs": {"shimmer.modules.losses.GWLosses.step": {"tf": 1}, "shimmer.modules.losses.GWLossesBayesian.step": {"tf": 1}}, "df": 2}}}}}}}}, "b": {"docs": {}, "df": 0, "t": {"docs": {}, "df": 0, "a": {"docs": {}, "df": 0, "i": {"docs": {}, "df": 0, "n": {"docs": {"shimmer.modules.gw_module.GWModuleBayesian.fuse": {"tf": 1.4142135623730951}}, "df": 1, "e": {"docs": {}, "df": 0, "d": {"docs": {"shimmer.modules.gw_module.GWModuleBayesian.fuse": {"tf": 1}}, "df": 1}}}}}}, "j": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "c": {"docs": {}, "df": 0, "t": {"docs": {"shimmer.modules.losses.GWLosses2Domains.__init__": {"tf": 1}, "shimmer.modules.losses.GWLosses.step": {"tf": 1}, "shimmer.modules.losses.GWLossesBayesian.step": {"tf": 1}}, "df": 3}}}}}, "m": {"docs": {}, "df": 0, "i": {"docs": {}, "df": 0, "t": {"docs": {}, "df": 0, "s": {"docs": {"shimmer.modules.selection.RandomSelection": {"tf": 1}}, "df": 1}}}}}, "t": {"docs": 
{"shimmer.modules.selection.SelectionBase.forward": {"tf": 1.4142135623730951}}, "df": 1, "h": {"docs": {}, "df": 0, "e": {"docs": {"shimmer.types.RawDomainGroupT": {"tf": 2}, "shimmer.types.RawDomainGroupDT": {"tf": 1.4142135623730951}, "shimmer.types.LatentsDomainGroupT": {"tf": 1.4142135623730951}, "shimmer.types.LatentsDomainGroupDT": {"tf": 1.4142135623730951}, "shimmer.types.RawDomainGroupsT": {"tf": 1}, "shimmer.types.RawDomainGroupsDT": {"tf": 1}, "shimmer.types.LatentsDomainGroupsT": {"tf": 1}, "shimmer.types.LatentsDomainGroupsDT": {"tf": 1}, "shimmer.modules.global_workspace.SchedulerArgs": {"tf": 1}, "shimmer.modules.global_workspace.GWPredictionsBase": {"tf": 1}, "shimmer.modules.global_workspace.GWPredictionsBase.states": {"tf": 1.4142135623730951}, "shimmer.modules.global_workspace.GlobalWorkspaceBase": {"tf": 1.4142135623730951}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.loss_mod": {"tf": 1.4142135623730951}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.workspace_dim": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.encode_and_fuse": {"tf": 1.7320508075688772}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.encode": {"tf": 1.7320508075688772}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.fuse": {"tf": 1.4142135623730951}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.decode": {"tf": 1.7320508075688772}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.batch_gw_states": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.encode_domain": {"tf": 2.23606797749979}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.encode_domains": {"tf": 1.7320508075688772}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.decode_domain": {"tf": 2.23606797749979}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.decode_domains": {"tf": 1.7320508075688772}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.generic_step": {"tf": 1.7320508075688772}, 
"shimmer.modules.global_workspace.freeze_domain_modules": {"tf": 1.4142135623730951}, "shimmer.modules.global_workspace.GWPredictions": {"tf": 1}, "shimmer.modules.global_workspace.GWPredictions.demi_cycles": {"tf": 1}, "shimmer.modules.global_workspace.GWPredictions.cycles": {"tf": 1.4142135623730951}, "shimmer.modules.global_workspace.GWPredictions.translations": {"tf": 1.4142135623730951}, "shimmer.modules.global_workspace.GlobalWorkspace2Domains": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspace2Domains.__init__": {"tf": 2.449489742783178}, "shimmer.modules.global_workspace.GlobalWorkspace2Domains.forward": {"tf": 1.7320508075688772}, "shimmer.modules.global_workspace.GlobalWorkspace": {"tf": 1.4142135623730951}, "shimmer.modules.global_workspace.GlobalWorkspace.__init__": {"tf": 2.8284271247461903}, "shimmer.modules.global_workspace.GlobalWorkspace.forward": {"tf": 1.7320508075688772}, "shimmer.modules.global_workspace.GlobalWorkspaceBayesian": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBayesian.__init__": {"tf": 2.8284271247461903}, "shimmer.modules.global_workspace.GlobalWorkspaceBayesian.forward": {"tf": 1.7320508075688772}, "shimmer.modules.global_workspace.pretrained_global_workspace": {"tf": 2.23606797749979}, "shimmer.modules.domain.DomainModule": {"tf": 1}, "shimmer.modules.domain.DomainModule.__init__": {"tf": 1}, "shimmer.modules.domain.DomainModule.latent_dim": {"tf": 1.4142135623730951}, "shimmer.modules.domain.DomainModule.encode": {"tf": 1.4142135623730951}, "shimmer.modules.domain.DomainModule.decode": {"tf": 1.7320508075688772}, "shimmer.modules.domain.DomainModule.compute_loss": {"tf": 1.4142135623730951}, "shimmer.modules.domain.DomainModule.compute_dcy_loss": {"tf": 2}, "shimmer.modules.domain.DomainModule.compute_cy_loss": {"tf": 2}, "shimmer.modules.domain.DomainModule.compute_tr_loss": {"tf": 2}, "shimmer.modules.domain.DomainModule.compute_broadcast_loss": {"tf": 2}, "shimmer.modules.gw_module.get_n_layers": 
{"tf": 1}, "shimmer.modules.gw_module.GWDecoder.__init__": {"tf": 1.4142135623730951}, "shimmer.modules.gw_module.GWDecoder.n_layers": {"tf": 1}, "shimmer.modules.gw_module.GWEncoder": {"tf": 1.4142135623730951}, "shimmer.modules.gw_module.GWEncoder.__init__": {"tf": 1.4142135623730951}, "shimmer.modules.gw_module.GWEncoder.forward": {"tf": 2.449489742783178}, "shimmer.modules.gw_module.GWEncoderLinear.forward": {"tf": 2.449489742783178}, "shimmer.modules.gw_module.GWModuleBase": {"tf": 1.4142135623730951}, "shimmer.modules.gw_module.GWModuleBase.__init__": {"tf": 1.7320508075688772}, "shimmer.modules.gw_module.GWModuleBase.domain_mods": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBase.workspace_dim": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBase.fuse": {"tf": 1.7320508075688772}, "shimmer.modules.gw_module.GWModuleBase.encode": {"tf": 1.7320508075688772}, "shimmer.modules.gw_module.GWModuleBase.encode_and_fuse": {"tf": 2.449489742783178}, "shimmer.modules.gw_module.GWModuleBase.decode": {"tf": 1.7320508075688772}, "shimmer.modules.gw_module.GWModule.__init__": {"tf": 1.7320508075688772}, "shimmer.modules.gw_module.GWModule.gw_encoders": {"tf": 1}, "shimmer.modules.gw_module.GWModule.gw_decoders": {"tf": 1}, "shimmer.modules.gw_module.GWModule.fuse": {"tf": 1.7320508075688772}, "shimmer.modules.gw_module.GWModule.encode": {"tf": 1.7320508075688772}, "shimmer.modules.gw_module.GWModule.decode": {"tf": 1.4142135623730951}, "shimmer.modules.gw_module.compute_fusion_scores": {"tf": 2.23606797749979}, "shimmer.modules.gw_module.GWModuleBayesian.__init__": {"tf": 1.7320508075688772}, "shimmer.modules.gw_module.GWModuleBayesian.precisions": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBayesian.get_precision": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBayesian.fuse": {"tf": 4}, "shimmer.modules.selection.SelectionBase": {"tf": 2.23606797749979}, "shimmer.modules.selection.SelectionBase.update_gw_state": {"tf": 1.7320508075688772}, 
"shimmer.modules.selection.SelectionBase.forward": {"tf": 2}, "shimmer.modules.selection.SingleDomainSelection": {"tf": 1.4142135623730951}, "shimmer.modules.selection.SingleDomainSelection.forward": {"tf": 2}, "shimmer.modules.selection.FixedSharedSelection": {"tf": 1}, "shimmer.modules.selection.FixedSharedSelection.forward": {"tf": 2}, "shimmer.modules.selection.KQFixedQSelection.__init__": {"tf": 1.7320508075688772}, "shimmer.modules.selection.KQFixedQSelection.forward": {"tf": 1.7320508075688772}, "shimmer.modules.selection.RandomSelection": {"tf": 1}, "shimmer.modules.selection.RandomSelection.__init__": {"tf": 1}, "shimmer.modules.selection.RandomSelection.forward": {"tf": 2.449489742783178}, "shimmer.modules.selection.DynamicQueryAttention": {"tf": 1.4142135623730951}, "shimmer.modules.selection.DynamicQueryAttention.__init__": {"tf": 1.7320508075688772}, "shimmer.modules.selection.DynamicQueryAttention.fuse_weighted_encodings": {"tf": 2.23606797749979}, "shimmer.modules.selection.DynamicQueryAttention.forward": {"tf": 1.7320508075688772}, "shimmer.modules.losses.GWLossesBase": {"tf": 1.4142135623730951}, "shimmer.modules.losses.GWLossesBase.step": {"tf": 1.4142135623730951}, "shimmer.modules.losses.demi_cycle_loss": {"tf": 2.6457513110645907}, "shimmer.modules.losses.cycle_loss": {"tf": 2.6457513110645907}, "shimmer.modules.losses.translation_loss": {"tf": 2.6457513110645907}, "shimmer.modules.losses.contrastive_loss": {"tf": 2.6457513110645907}, "shimmer.modules.losses.contrastive_loss_bayesian": {"tf": 2.6457513110645907}, "shimmer.modules.losses.LossCoefs": {"tf": 2}, "shimmer.modules.losses.GWLosses2Domains.__init__": {"tf": 2.449489742783178}, "shimmer.modules.losses.GWLosses2Domains.demi_cycle_loss": {"tf": 1.4142135623730951}, "shimmer.modules.losses.GWLosses2Domains.cycle_loss": {"tf": 1.4142135623730951}, "shimmer.modules.losses.GWLosses2Domains.translation_loss": {"tf": 1.4142135623730951}, 
"shimmer.modules.losses.GWLosses2Domains.contrastive_loss": {"tf": 1.4142135623730951}, "shimmer.modules.losses.GWLosses2Domains.step": {"tf": 1.4142135623730951}, "shimmer.modules.losses.generate_partitions": {"tf": 1.7320508075688772}, "shimmer.modules.losses.broadcast_loss": {"tf": 2}, "shimmer.modules.losses.BroadcastLossCoefs": {"tf": 2}, "shimmer.modules.losses.GWLosses.__init__": {"tf": 2.6457513110645907}, "shimmer.modules.losses.GWLosses.contrastive_loss": {"tf": 1.7320508075688772}, "shimmer.modules.losses.GWLosses.step": {"tf": 1.7320508075688772}, "shimmer.modules.losses.GWLossesBayesian.__init__": {"tf": 2.8284271247461903}, "shimmer.modules.losses.GWLossesBayesian.gw_mod": {"tf": 1}, "shimmer.modules.losses.GWLossesBayesian.domain_mods": {"tf": 1}, "shimmer.modules.losses.GWLossesBayesian.loss_coefs": {"tf": 1}, "shimmer.modules.losses.GWLossesBayesian.contrastive_loss": {"tf": 1}, "shimmer.modules.losses.GWLossesBayesian.step": {"tf": 1.7320508075688772}, "shimmer.modules.contrastive_loss.ContrastiveLossType": {"tf": 1}, "shimmer.modules.contrastive_loss.ContrastiveLossBayesianType": {"tf": 1}, "shimmer.modules.contrastive_loss.info_nce": {"tf": 1}, "shimmer.modules.contrastive_loss.contrastive_loss": {"tf": 1}, "shimmer.modules.contrastive_loss.ContrastiveLoss.__init__": {"tf": 1.4142135623730951}, "shimmer.modules.contrastive_loss.ContrastiveLoss.forward": {"tf": 1.4142135623730951}, "shimmer.dataset.RepeatedDataset": {"tf": 1.4142135623730951}, "shimmer.dataset.RepeatedDataset.__init__": {"tf": 1.7320508075688772}, "shimmer.modules.vae.reparameterize": {"tf": 1}, "shimmer.modules.vae.kl_divergence_loss": {"tf": 1.4142135623730951}, "shimmer.modules.vae.gaussian_nll": {"tf": 1}, "shimmer.modules.vae.VAEEncoder.forward": {"tf": 1}, "shimmer.modules.vae.VAEDecoder.forward": {"tf": 1}, "shimmer.modules.vae.VAE.encoder": {"tf": 1}, "shimmer.modules.vae.VAE.decoder": {"tf": 1}, "shimmer.modules.vae.VAE.encode": {"tf": 1.7320508075688772}, 
"shimmer.modules.vae.VAE.decode": {"tf": 1.7320508075688772}, "shimmer.modules.vae.VAE.forward": {"tf": 2.449489742783178}, "shimmer.modules.utils.translation": {"tf": 2.23606797749979}, "shimmer.modules.utils.cycle": {"tf": 1}, "shimmer.modules.utils.batch_demi_cycles": {"tf": 1.4142135623730951}, "shimmer.modules.utils.batch_cycles": {"tf": 1.7320508075688772}, "shimmer.modules.utils.batch_translations": {"tf": 1.4142135623730951}, "shimmer.utils.groups_batch_size": {"tf": 2}, "shimmer.utils.groups_device": {"tf": 2}, "shimmer.utils.migrate_model": {"tf": 2}, "shimmer.utils.SaveMigrations": {"tf": 1}, "shimmer.utils.SaveMigrations.on_save_checkpoint": {"tf": 1.7320508075688772}, "shimmer.cli.ckpt_migration.migrate_ckpt": {"tf": 1}}, "df": 146, "m": {"docs": {"shimmer.modules.gw_module.GWEncoder.forward": {"tf": 1}, "shimmer.modules.gw_module.GWEncoderLinear.forward": {"tf": 1}, "shimmer.modules.selection.KQFixedQSelection.forward": {"tf": 1}, "shimmer.modules.selection.DynamicQueryAttention.forward": {"tf": 1}, "shimmer.modules.losses.BroadcastLossCoefs.fused": {"tf": 1}}, "df": 5}, "r": {"docs": {}, "df": 0, "e": {"docs": {"shimmer.modules.selection.SingleDomainSelection": {"tf": 1}}, "df": 1}}, "n": {"docs": {"shimmer.modules.selection.RandomSelection.forward": {"tf": 1}}, "df": 1}, "i": {"docs": {}, "df": 0, "r": {"docs": {"shimmer.modules.losses.GWLosses.__init__": {"tf": 1}}, "df": 1}}}, "i": {"docs": {}, "df": 0, "s": {"docs": {"shimmer.types.RawDomainGroupT": {"tf": 1.7320508075688772}, "shimmer.types.RawDomainGroupDT": {"tf": 1.4142135623730951}, "shimmer.types.LatentsDomainGroupT": {"tf": 1.4142135623730951}, "shimmer.types.LatentsDomainGroupDT": {"tf": 1}, "shimmer.types.RawDomainGroupsT": {"tf": 1.7320508075688772}, "shimmer.types.RawDomainGroupsDT": {"tf": 1.7320508075688772}, "shimmer.types.LatentsDomainGroupsT": {"tf": 1.4142135623730951}, "shimmer.types.LatentsDomainGroupsDT": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBase": 
{"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.encode_domain": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.decode_domain": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspace2Domains": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspace": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBayesian": {"tf": 1}, "shimmer.modules.domain.LossOutput": {"tf": 1}, "shimmer.modules.gw_module.GWEncoder": {"tf": 1}, "shimmer.modules.gw_module.GWEncoder.forward": {"tf": 1.4142135623730951}, "shimmer.modules.gw_module.GWEncoderLinear.forward": {"tf": 1.4142135623730951}, "shimmer.modules.gw_module.GWModuleBase": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBayesian.fuse": {"tf": 1}, "shimmer.modules.selection.SelectionBase": {"tf": 1}, "shimmer.modules.selection.SelectionBase.update_gw_state": {"tf": 1.4142135623730951}, "shimmer.modules.selection.SingleDomainSelection": {"tf": 1}, "shimmer.modules.selection.FixedSharedSelection": {"tf": 1}, "shimmer.modules.selection.KQFixedQSelection.forward": {"tf": 1}, "shimmer.modules.selection.RandomSelection": {"tf": 1}, "shimmer.modules.selection.RandomSelection.forward": {"tf": 1}, "shimmer.modules.selection.DynamicQueryAttention.forward": {"tf": 1}, "shimmer.modules.losses.GWLossesBase": {"tf": 1}, "shimmer.modules.losses.demi_cycle_loss": {"tf": 1}, "shimmer.modules.losses.cycle_loss": {"tf": 1}, "shimmer.modules.losses.translation_loss": {"tf": 1}, "shimmer.modules.losses.contrastive_loss": {"tf": 1}, "shimmer.modules.losses.contrastive_loss_bayesian": {"tf": 1}, "shimmer.modules.losses.GWLosses.step": {"tf": 1}, "shimmer.modules.losses.GWLossesBayesian.step": {"tf": 1}, "shimmer.utils.SaveMigrations": {"tf": 1}, "shimmer.cli.ckpt_migration.migrate_ckpt": {"tf": 1.4142135623730951}}, "df": 38}}, "a": {"docs": {}, "df": 0, "t": {"docs": {"shimmer.modules.global_workspace.GlobalWorkspaceBase.loss_mod": {"tf": 1}, "shimmer.modules.domain.DomainModule": {"tf": 1}, 
"shimmer.modules.domain.DomainModule.compute_dcy_loss": {"tf": 1}, "shimmer.modules.domain.DomainModule.compute_cy_loss": {"tf": 1}, "shimmer.modules.domain.DomainModule.compute_tr_loss": {"tf": 1}, "shimmer.modules.domain.DomainModule.compute_broadcast_loss": {"tf": 1}, "shimmer.modules.gw_module.GWModule.__init__": {"tf": 1.4142135623730951}, "shimmer.modules.gw_module.GWModuleBayesian.__init__": {"tf": 1.4142135623730951}, "shimmer.modules.selection.SingleDomainSelection": {"tf": 1}, "shimmer.dataset.RepeatedDataset": {"tf": 1}, "shimmer.utils.SaveMigrations.on_save_checkpoint": {"tf": 1}}, "df": 11}, "n": {"docs": {"shimmer.modules.global_workspace.GWPredictions.cycles": {"tf": 1}, "shimmer.modules.global_workspace.GWPredictions.translations": {"tf": 1}}, "df": 2}}, "r": {"docs": {}, "df": 0, "o": {"docs": {}, "df": 0, "u": {"docs": {}, "df": 0, "g": {"docs": {}, "df": 0, "h": {"docs": {"shimmer.modules.global_workspace.GWPredictions.cycles": {"tf": 1}, "shimmer.modules.global_workspace.GWPredictions.translations": {"tf": 1}, "shimmer.modules.losses.cycle_loss": {"tf": 1.7320508075688772}, "shimmer.dataset.RepeatedDataset": {"tf": 1}, "shimmer.modules.utils.cycle": {"tf": 2}, "shimmer.modules.utils.batch_cycles": {"tf": 1}}, "df": 6}}}}}}, "y": {"docs": {}, "df": 0, "p": {"docs": {}, "df": 0, "e": {"docs": {"shimmer.types.RawDomainGroupT": {"tf": 1}, "shimmer.types.RawDomainGroupDT": {"tf": 1.4142135623730951}, "shimmer.types.LatentsDomainGroupT": {"tf": 1}, "shimmer.types.LatentsDomainGroupDT": {"tf": 1}, "shimmer.types.RawDomainGroupsT": {"tf": 1}, "shimmer.types.RawDomainGroupsDT": {"tf": 1}, "shimmer.types.LatentsDomainGroupsT": {"tf": 1}, "shimmer.types.LatentsDomainGroupsDT": {"tf": 1}, "shimmer.modules.global_workspace.freeze_domain_modules": {"tf": 1}, "shimmer.modules.global_workspace.pretrained_global_workspace": {"tf": 1}, "shimmer.modules.contrastive_loss.ContrastiveLossType": {"tf": 1}, 
"shimmer.modules.contrastive_loss.ContrastiveLossBayesianType": {"tf": 1}}, "df": 12, "d": {"docs": {}, "df": 0, "d": {"docs": {}, "df": 0, "i": {"docs": {}, "df": 0, "c": {"docs": {}, "df": 0, "t": {"docs": {"shimmer.modules.global_workspace.SchedulerArgs": {"tf": 1}, "shimmer.modules.global_workspace.GWPredictionsBase": {"tf": 1}, "shimmer.modules.global_workspace.GWPredictions": {"tf": 1}}, "df": 3}}}}}, "e": {"docs": {}, "df": 0, "r": {"docs": {}, "df": 0, "r": {"docs": {}, "df": 0, "o": {"docs": {}, "df": 0, "r": {"docs": {"shimmer.modules.global_workspace.pretrained_global_workspace": {"tf": 1}}, "df": 1}}}}}, "s": {"docs": {"shimmer.modules.losses.demi_cycle_loss": {"tf": 1}}, "df": 1}}, "i": {"docs": {}, "df": 0, "c": {"docs": {}, "df": 0, "a": {"docs": {}, "df": 0, "l": {"docs": {}, "df": 0, "l": {"docs": {}, "df": 0, "y": {"docs": {"shimmer.modules.losses.GWLossesBase": {"tf": 1}}, "df": 1}}}}}}}}, "o": {"docs": {"shimmer.types.RawDomainGroupT": {"tf": 1}, "shimmer.types.LatentsDomainGroupT": {"tf": 1}, "shimmer.types.LatentsDomainGroupsT": {"tf": 1}, "shimmer.modules.global_workspace.SchedulerArgs": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBase": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.decode": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.encode_domain": {"tf": 1.4142135623730951}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.encode_domains": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.decode_domain": {"tf": 1.4142135623730951}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.decode_domains": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.generic_step": {"tf": 1}, "shimmer.modules.global_workspace.freeze_domain_modules": {"tf": 1.4142135623730951}, "shimmer.modules.global_workspace.GlobalWorkspace2Domains": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspace2Domains.__init__": {"tf": 2.449489742783178}, 
"shimmer.modules.global_workspace.GlobalWorkspace": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspace.__init__": {"tf": 2.449489742783178}, "shimmer.modules.global_workspace.GlobalWorkspaceBayesian": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBayesian.__init__": {"tf": 2.8284271247461903}, "shimmer.modules.global_workspace.pretrained_global_workspace": {"tf": 2.8284271247461903}, "shimmer.modules.domain.LossOutput.metrics": {"tf": 1}, "shimmer.modules.domain.DomainModule.decode": {"tf": 1}, "shimmer.modules.gw_module.GWEncoder": {"tf": 1}, "shimmer.modules.gw_module.GWEncoder.forward": {"tf": 1}, "shimmer.modules.gw_module.GWEncoderLinear.forward": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBase.fuse": {"tf": 1.7320508075688772}, "shimmer.modules.gw_module.GWModuleBase.encode": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBase.encode_and_fuse": {"tf": 1.7320508075688772}, "shimmer.modules.gw_module.GWModuleBase.decode": {"tf": 1}, "shimmer.modules.gw_module.GWModule.__init__": {"tf": 1.7320508075688772}, "shimmer.modules.gw_module.GWModule.fuse": {"tf": 1.7320508075688772}, "shimmer.modules.gw_module.GWModule.encode": {"tf": 1}, "shimmer.modules.gw_module.GWModule.decode": {"tf": 2}, "shimmer.modules.gw_module.compute_fusion_scores": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBayesian.__init__": {"tf": 2}, "shimmer.modules.gw_module.GWModuleBayesian.fuse": {"tf": 2.449489742783178}, "shimmer.modules.selection.KQFixedQSelection.__init__": {"tf": 1}, "shimmer.modules.selection.RandomSelection": {"tf": 1}, "shimmer.modules.selection.RandomSelection.__init__": {"tf": 1}, "shimmer.modules.selection.DynamicQueryAttention.__init__": {"tf": 1}, "shimmer.modules.losses.GWLossesBase": {"tf": 1}, "shimmer.modules.losses.demi_cycle_loss": {"tf": 1.4142135623730951}, "shimmer.modules.losses.cycle_loss": {"tf": 1.4142135623730951}, "shimmer.modules.losses.translation_loss": {"tf": 2.23606797749979}, "shimmer.modules.losses.contrastive_loss": 
{"tf": 1.4142135623730951}, "shimmer.modules.losses.contrastive_loss_bayesian": {"tf": 1.4142135623730951}, "shimmer.modules.losses.LossCoefs": {"tf": 1.4142135623730951}, "shimmer.modules.losses.GWLosses2Domains.__init__": {"tf": 1.7320508075688772}, "shimmer.modules.losses.generate_partitions": {"tf": 1}, "shimmer.modules.losses.broadcast_loss": {"tf": 1.4142135623730951}, "shimmer.modules.losses.BroadcastLossCoefs": {"tf": 1.4142135623730951}, "shimmer.modules.losses.BroadcastLossCoefs.fused": {"tf": 1}, "shimmer.modules.losses.BroadcastLossCoefs.demi_cycles": {"tf": 1}, "shimmer.modules.losses.BroadcastLossCoefs.cycles": {"tf": 1}, "shimmer.modules.losses.BroadcastLossCoefs.translations": {"tf": 1}, "shimmer.modules.losses.GWLosses.__init__": {"tf": 1}, "shimmer.modules.losses.GWLossesBayesian.__init__": {"tf": 1.7320508075688772}, "shimmer.modules.losses.GWLossesBayesian.domain_mods": {"tf": 1}, "shimmer.modules.losses.GWLossesBayesian.contrastive_fn": {"tf": 1}, "shimmer.modules.contrastive_loss.info_nce": {"tf": 1}, "shimmer.modules.contrastive_loss.contrastive_loss": {"tf": 1}, "shimmer.modules.contrastive_loss.ContrastiveLoss.__init__": {"tf": 2.23606797749979}, "shimmer.dataset.RepeatedDataset": {"tf": 1}, "shimmer.dataset.RepeatedDataset.__init__": {"tf": 1.4142135623730951}, "shimmer.modules.vae.VAE.__init__": {"tf": 1}, "shimmer.modules.utils.translation": {"tf": 2.23606797749979}, "shimmer.modules.utils.cycle": {"tf": 1.4142135623730951}, "shimmer.modules.utils.batch_cycles": {"tf": 1.4142135623730951}, "shimmer.modules.utils.batch_translations": {"tf": 1}, "shimmer.utils.migrate_model": {"tf": 1.4142135623730951}, "shimmer.utils.SaveMigrations": {"tf": 1}, "shimmer.utils.SaveMigrations.on_save_checkpoint": {"tf": 1.7320508075688772}, "shimmer.cli.ckpt_migration.migrate_ckpt": {"tf": 1.4142135623730951}}, "df": 72, "r": {"docs": {}, "df": 0, "c": {"docs": {}, "df": 0, "h": {"docs": {"shimmer.types.LatentsDomainGroupT": {"tf": 1.7320508075688772}, 
"shimmer.types.LatentsDomainGroupDT": {"tf": 1.7320508075688772}, "shimmer.types.LatentsDomainGroupsT": {"tf": 2}, "shimmer.types.LatentsDomainGroupsDT": {"tf": 2}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.encode_and_fuse": {"tf": 1.4142135623730951}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.fuse": {"tf": 1.4142135623730951}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.decode": {"tf": 1.4142135623730951}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.batch_gw_states": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.encode_domain": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.decode_domain": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.generic_step": {"tf": 1}, "shimmer.modules.global_workspace.freeze_domain_modules": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspace2Domains.__init__": {"tf": 2}, "shimmer.modules.global_workspace.GlobalWorkspace.__init__": {"tf": 2}, "shimmer.modules.global_workspace.GlobalWorkspaceBayesian.__init__": {"tf": 2}, "shimmer.modules.global_workspace.pretrained_global_workspace": {"tf": 2}, "shimmer.modules.domain.DomainModule.encode": {"tf": 1}, "shimmer.modules.domain.DomainModule.decode": {"tf": 1}, "shimmer.modules.domain.DomainModule.compute_loss": {"tf": 1.4142135623730951}, "shimmer.modules.domain.DomainModule.compute_dcy_loss": {"tf": 1.4142135623730951}, "shimmer.modules.domain.DomainModule.compute_cy_loss": {"tf": 1.4142135623730951}, "shimmer.modules.domain.DomainModule.compute_tr_loss": {"tf": 1.4142135623730951}, "shimmer.modules.domain.DomainModule.compute_broadcast_loss": {"tf": 1.4142135623730951}, "shimmer.modules.gw_module.GWModuleBase.fuse": {"tf": 1.4142135623730951}, "shimmer.modules.gw_module.GWModuleBase.encode_and_fuse": {"tf": 1.4142135623730951}, "shimmer.modules.gw_module.GWModuleBase.decode": {"tf": 1}, "shimmer.modules.gw_module.GWModule.__init__": {"tf": 2}, 
"shimmer.modules.gw_module.GWModule.fuse": {"tf": 1.4142135623730951}, "shimmer.modules.gw_module.GWModule.decode": {"tf": 1}, "shimmer.modules.gw_module.compute_fusion_scores": {"tf": 1.7320508075688772}, "shimmer.modules.gw_module.GWModuleBayesian.__init__": {"tf": 2}, "shimmer.modules.gw_module.GWModuleBayesian.get_precision": {"tf": 1.4142135623730951}, "shimmer.modules.gw_module.GWModuleBayesian.fuse": {"tf": 1.4142135623730951}, "shimmer.modules.selection.SelectionBase.update_gw_state": {"tf": 1}, "shimmer.modules.selection.SelectionBase.forward": {"tf": 2.23606797749979}, "shimmer.modules.selection.SingleDomainSelection.forward": {"tf": 1.4142135623730951}, "shimmer.modules.selection.FixedSharedSelection.forward": {"tf": 1.4142135623730951}, "shimmer.modules.selection.KQFixedQSelection.forward": {"tf": 1}, "shimmer.modules.selection.RandomSelection.forward": {"tf": 1}, "shimmer.modules.selection.DynamicQueryAttention.fuse_weighted_encodings": {"tf": 1.4142135623730951}, "shimmer.modules.selection.DynamicQueryAttention.forward": {"tf": 1}, "shimmer.modules.losses.demi_cycle_loss": {"tf": 1}, "shimmer.modules.losses.cycle_loss": {"tf": 1}, "shimmer.modules.losses.translation_loss": {"tf": 1}, "shimmer.modules.losses.contrastive_loss": {"tf": 1}, "shimmer.modules.losses.contrastive_loss_bayesian": {"tf": 1}, "shimmer.modules.losses.GWLosses2Domains.demi_cycle_loss": {"tf": 1}, "shimmer.modules.losses.GWLosses2Domains.cycle_loss": {"tf": 1}, "shimmer.modules.losses.GWLosses2Domains.translation_loss": {"tf": 1}, "shimmer.modules.losses.GWLosses2Domains.contrastive_loss": {"tf": 1}, "shimmer.modules.losses.GWLossesBayesian.contrastive_loss": {"tf": 1}, "shimmer.modules.contrastive_loss.info_nce": {"tf": 1.7320508075688772}, "shimmer.modules.contrastive_loss.contrastive_loss": {"tf": 1.7320508075688772}, "shimmer.modules.contrastive_loss.ContrastiveLoss": {"tf": 1}, "shimmer.modules.contrastive_loss.ContrastiveLoss.__init__": {"tf": 1.4142135623730951}, 
"shimmer.modules.contrastive_loss.ContrastiveLoss.forward": {"tf": 1.4142135623730951}, "shimmer.modules.vae.reparameterize": {"tf": 1.7320508075688772}, "shimmer.modules.vae.kl_divergence_loss": {"tf": 1.7320508075688772}, "shimmer.modules.vae.gaussian_nll": {"tf": 2}, "shimmer.modules.vae.VAEEncoder.forward": {"tf": 1}, "shimmer.modules.vae.VAEDecoder.forward": {"tf": 1}, "shimmer.modules.vae.VAE.encode": {"tf": 1}, "shimmer.modules.vae.VAE.decode": {"tf": 1}, "shimmer.modules.vae.VAE.forward": {"tf": 1}, "shimmer.modules.utils.translation": {"tf": 1}, "shimmer.modules.utils.batch_demi_cycles": {"tf": 1}, "shimmer.modules.utils.batch_cycles": {"tf": 1}, "shimmer.modules.utils.batch_translations": {"tf": 1}, "shimmer.utils.migrate_model": {"tf": 1.4142135623730951}}, "df": 69}}}, "t": {"docs": {}, "df": 0, "a": {"docs": {}, "df": 0, "l": {"docs": {"shimmer.modules.global_workspace.SchedulerArgs.total_steps": {"tf": 1}, "shimmer.modules.gw_module.GWDecoder.__init__": {"tf": 1}, "shimmer.modules.gw_module.GWDecoder.n_layers": {"tf": 1}, "shimmer.modules.gw_module.GWEncoder.__init__": {"tf": 1}, "shimmer.modules.losses.LossCoefs": {"tf": 1}, "shimmer.modules.losses.broadcast_loss": {"tf": 1}, "shimmer.modules.losses.BroadcastLossCoefs": {"tf": 1}}, "df": 7}}}}, "e": {"docs": {}, "df": 0, "n": {"docs": {}, "df": 0, "s": {"docs": {}, "df": 0, "o": {"docs": {}, "df": 0, "r": {"docs": {"shimmer.types.LatentsDomainGroupT": {"tf": 1.7320508075688772}, "shimmer.types.LatentsDomainGroupDT": {"tf": 1.7320508075688772}, "shimmer.types.LatentsDomainGroupsT": {"tf": 2}, "shimmer.types.LatentsDomainGroupsDT": {"tf": 2}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.encode_and_fuse": {"tf": 1.4142135623730951}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.fuse": {"tf": 1.4142135623730951}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.decode": {"tf": 1.4142135623730951}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.batch_gw_states": {"tf": 1}, 
"shimmer.modules.global_workspace.GlobalWorkspaceBase.encode_domain": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.decode_domain": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.generic_step": {"tf": 1}, "shimmer.modules.domain.DomainModule.encode": {"tf": 1}, "shimmer.modules.domain.DomainModule.decode": {"tf": 1}, "shimmer.modules.domain.DomainModule.compute_loss": {"tf": 1.7320508075688772}, "shimmer.modules.domain.DomainModule.compute_dcy_loss": {"tf": 1.7320508075688772}, "shimmer.modules.domain.DomainModule.compute_cy_loss": {"tf": 1.7320508075688772}, "shimmer.modules.domain.DomainModule.compute_tr_loss": {"tf": 1.7320508075688772}, "shimmer.modules.domain.DomainModule.compute_broadcast_loss": {"tf": 1.7320508075688772}, "shimmer.modules.gw_module.GWModuleBase.fuse": {"tf": 1.4142135623730951}, "shimmer.modules.gw_module.GWModuleBase.encode_and_fuse": {"tf": 1.4142135623730951}, "shimmer.modules.gw_module.GWModuleBase.decode": {"tf": 1}, "shimmer.modules.gw_module.GWModule.fuse": {"tf": 1.4142135623730951}, "shimmer.modules.gw_module.GWModule.decode": {"tf": 1}, "shimmer.modules.gw_module.compute_fusion_scores": {"tf": 1.7320508075688772}, "shimmer.modules.gw_module.GWModuleBayesian.get_precision": {"tf": 1.4142135623730951}, "shimmer.modules.gw_module.GWModuleBayesian.fuse": {"tf": 1.4142135623730951}, "shimmer.modules.selection.SelectionBase.update_gw_state": {"tf": 1}, "shimmer.modules.selection.SelectionBase.forward": {"tf": 1.7320508075688772}, "shimmer.modules.selection.SingleDomainSelection.forward": {"tf": 1.4142135623730951}, "shimmer.modules.selection.FixedSharedSelection.forward": {"tf": 1.4142135623730951}, "shimmer.modules.selection.KQFixedQSelection.forward": {"tf": 1}, "shimmer.modules.selection.RandomSelection.forward": {"tf": 1}, "shimmer.modules.selection.DynamicQueryAttention.fuse_weighted_encodings": {"tf": 1.7320508075688772}, "shimmer.modules.selection.DynamicQueryAttention.forward": {"tf": 1}, 
"shimmer.modules.losses.demi_cycle_loss": {"tf": 1}, "shimmer.modules.losses.cycle_loss": {"tf": 1}, "shimmer.modules.losses.translation_loss": {"tf": 1}, "shimmer.modules.losses.contrastive_loss": {"tf": 1}, "shimmer.modules.losses.contrastive_loss_bayesian": {"tf": 1}, "shimmer.modules.losses.GWLosses2Domains.demi_cycle_loss": {"tf": 1}, "shimmer.modules.losses.GWLosses2Domains.cycle_loss": {"tf": 1}, "shimmer.modules.losses.GWLosses2Domains.translation_loss": {"tf": 1}, "shimmer.modules.losses.GWLosses2Domains.contrastive_loss": {"tf": 1}, "shimmer.modules.losses.GWLossesBayesian.contrastive_loss": {"tf": 1}, "shimmer.modules.contrastive_loss.info_nce": {"tf": 1.7320508075688772}, "shimmer.modules.contrastive_loss.contrastive_loss": {"tf": 1.7320508075688772}, "shimmer.modules.contrastive_loss.ContrastiveLoss.__init__": {"tf": 1.7320508075688772}, "shimmer.modules.contrastive_loss.ContrastiveLoss.forward": {"tf": 1.4142135623730951}, "shimmer.modules.vae.reparameterize": {"tf": 1.7320508075688772}, "shimmer.modules.vae.kl_divergence_loss": {"tf": 1.7320508075688772}, "shimmer.modules.vae.gaussian_nll": {"tf": 2}, "shimmer.modules.vae.VAEEncoder.forward": {"tf": 1.4142135623730951}, "shimmer.modules.vae.VAEDecoder.forward": {"tf": 1}, "shimmer.modules.vae.VAE.encode": {"tf": 1}, "shimmer.modules.vae.VAE.decode": {"tf": 1}, "shimmer.modules.vae.VAE.forward": {"tf": 1.4142135623730951}, "shimmer.modules.utils.translation": {"tf": 1}, "shimmer.modules.utils.batch_demi_cycles": {"tf": 1}, "shimmer.modules.utils.batch_cycles": {"tf": 1}, "shimmer.modules.utils.batch_translations": {"tf": 1}}, "df": 60}}}}, "s": {"docs": {}, "df": 0, "t": {"docs": {"shimmer.modules.global_workspace.GlobalWorkspaceBase.generic_step": {"tf": 1}, "shimmer.modules.losses.GWLossesBase.step": {"tf": 1}}, "df": 2, "i": {"docs": {}, "df": 0, "n": {"docs": {}, "df": 0, "g": {"docs": {"shimmer.types.ModelModeT": {"tf": 1}}, "df": 1}}}, "/": {"docs": {}, "df": 0, "o": {"docs": {}, "df": 0, "o": 
{"docs": {}, "df": 0, "d": {"docs": {"shimmer.types.ModelModeT": {"tf": 1}, "shimmer.modules.losses.GWLossesBase.step": {"tf": 1}}, "df": 2}}}}}}, "m": {"docs": {}, "df": 0, "p": {"docs": {"shimmer.modules.global_workspace.GlobalWorkspaceBayesian.__init__": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBayesian.__init__": {"tf": 1}}, "df": 2, "e": {"docs": {}, "df": 0, "r": {"docs": {}, "df": 0, "a": {"docs": {}, "df": 0, "t": {"docs": {}, "df": 0, "u": {"docs": {}, "df": 0, "r": {"docs": {}, "df": 0, "e": {"docs": {"shimmer.modules.global_workspace.GlobalWorkspace.__init__": {"tf": 1.4142135623730951}, "shimmer.modules.global_workspace.GlobalWorkspaceBayesian.__init__": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBayesian.__init__": {"tf": 1}, "shimmer.modules.selection.RandomSelection.__init__": {"tf": 1.4142135623730951}}, "df": 4}}}}}}}}}, "r": {"docs": {}, "df": 0, "m": {"docs": {}, "df": 0, "s": {"docs": {"shimmer.modules.gw_module.GWModuleBayesian.fuse": {"tf": 1}}, "df": 1}}}}, "r": {"docs": {"shimmer.modules.losses.translation_loss": {"tf": 1}}, "df": 1, "a": {"docs": {}, "df": 0, "i": {"docs": {}, "df": 0, "n": {"docs": {"shimmer.modules.global_workspace.GlobalWorkspaceBase.generic_step": {"tf": 1}, "shimmer.modules.losses.GWLossesBase.step": {"tf": 1}}, "df": 2, "/": {"docs": {}, "df": 0, "v": {"docs": {}, "df": 0, "a": {"docs": {}, "df": 0, "l": {"docs": {"shimmer.types.ModelModeT": {"tf": 1}}, "df": 1}}}}, "i": {"docs": {}, "df": 0, "n": {"docs": {}, "df": 0, "g": {"docs": {"shimmer.modules.global_workspace.GlobalWorkspaceBase.generic_step": {"tf": 1}, "shimmer.modules.domain.LossOutput": {"tf": 1}, "shimmer.modules.domain.LossOutput.loss": {"tf": 1}, "shimmer.modules.domain.LossOutput.metrics": {"tf": 1}, "shimmer.modules.domain.DomainModule.compute_loss": {"tf": 1}, "shimmer.modules.domain.DomainModule.compute_dcy_loss": {"tf": 1}, "shimmer.modules.domain.DomainModule.compute_cy_loss": {"tf": 1}, 
"shimmer.modules.domain.DomainModule.compute_tr_loss": {"tf": 1}, "shimmer.modules.domain.DomainModule.compute_broadcast_loss": {"tf": 1}}, "df": 9}}}, "e": {"docs": {}, "df": 0, "r": {"docs": {"shimmer.utils.SaveMigrations.on_save_checkpoint": {"tf": 2}}, "df": 1}}}}, "n": {"docs": {}, "df": 0, "s": {"docs": {}, "df": 0, "l": {"docs": {}, "df": 0, "a": {"docs": {}, "df": 0, "t": {"docs": {}, "df": 0, "i": {"docs": {}, "df": 0, "o": {"docs": {}, "df": 0, "n": {"docs": {"shimmer.modules.global_workspace.GWPredictions.translations": {"tf": 1}, "shimmer.modules.domain.DomainModule.compute_tr_loss": {"tf": 1.4142135623730951}, "shimmer.modules.losses.GWLossesBase": {"tf": 1}, "shimmer.modules.losses.translation_loss": {"tf": 2.23606797749979}, "shimmer.modules.losses.LossCoefs.translations": {"tf": 1}, "shimmer.modules.losses.GWLosses2Domains.translation_loss": {"tf": 1.4142135623730951}, "shimmer.modules.losses.GWLosses2Domains.step": {"tf": 1.4142135623730951}, "shimmer.modules.losses.broadcast_loss": {"tf": 1}, "shimmer.modules.losses.BroadcastLossCoefs.translations": {"tf": 1.4142135623730951}, "shimmer.modules.utils.translation": {"tf": 1}, "shimmer.modules.utils.cycle": {"tf": 1}, "shimmer.modules.utils.batch_translations": {"tf": 1.4142135623730951}}, "df": 12, "s": {"docs": {"shimmer.modules.global_workspace.GlobalWorkspace2Domains.forward": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspace.forward": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBayesian.forward": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBase": {"tf": 1}, "shimmer.modules.losses.translation_loss": {"tf": 1}, "shimmer.modules.utils.batch_translations": {"tf": 1}}, "df": 6}}}}, "e": {"docs": {"shimmer.modules.utils.translation": {"tf": 1}}, "df": 1, "d": {"docs": {"shimmer.modules.utils.translation": {"tf": 1}}, "df": 1}}}}}}}, "c": {"docs": {}, "df": 0, "k": {"docs": {"shimmer.modules.domain.LossOutput": {"tf": 1}}, "df": 1}}}, "u": {"docs": {}, "df": 0, "e": 
{"docs": {"shimmer.dataset.RepeatedDataset": {"tf": 1}}, "df": 1}}, "i": {"docs": {}, "df": 0, "c": {"docs": {}, "df": 0, "k": {"docs": {"shimmer.modules.vae.reparameterize": {"tf": 1.4142135623730951}}, "df": 1}}}}, "u": {"docs": {}, "df": 0, "p": {"docs": {}, "df": 0, "l": {"docs": {}, "df": 0, "e": {"docs": {"shimmer.modules.global_workspace.GWPredictions.cycles": {"tf": 1}, "shimmer.modules.vae.VAE.forward": {"tf": 1}}, "df": 2, "s": {"docs": {"shimmer.modules.global_workspace.GWPredictions.translations": {"tf": 1}}, "df": 1}, "[": {"docs": {}, "df": 0, "i": {"docs": {}, "df": 0, "n": {"docs": {}, "df": 0, "t": {"docs": {"shimmer.modules.losses.generate_partitions": {"tf": 1}}, "df": 1}}}, "t": {"docs": {}, "df": 0, "o": {"docs": {}, "df": 0, "r": {"docs": {}, "df": 0, "c": {"docs": {}, "df": 0, "h": {"docs": {"shimmer.modules.vae.VAEEncoder.forward": {"tf": 1}}, "df": 1}}}}, "u": {"docs": {}, "df": 0, "p": {"docs": {}, "df": 0, "l": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "[": {"docs": {}, "df": 0, "t": {"docs": {}, "df": 0, "o": {"docs": {}, "df": 0, "r": {"docs": {}, "df": 0, "c": {"docs": {}, "df": 0, "h": {"docs": {"shimmer.modules.vae.VAE.forward": {"tf": 1}}, "df": 1}}}}}}}}}}}}}}}}, "a": {"docs": {}, "df": 0, "r": {"docs": {}, "df": 0, "g": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "t": {"docs": {"shimmer.modules.global_workspace.GWPredictions.translations": {"tf": 1}, "shimmer.modules.domain.DomainModule.compute_loss": {"tf": 1.4142135623730951}, "shimmer.modules.domain.DomainModule.compute_dcy_loss": {"tf": 1.4142135623730951}, "shimmer.modules.domain.DomainModule.compute_cy_loss": {"tf": 1.4142135623730951}, "shimmer.modules.domain.DomainModule.compute_tr_loss": {"tf": 1.4142135623730951}, "shimmer.modules.domain.DomainModule.compute_broadcast_loss": {"tf": 1.4142135623730951}, "shimmer.modules.losses.cycle_loss": {"tf": 1.7320508075688772}, "shimmer.modules.losses.translation_loss": {"tf": 2}, 
"shimmer.modules.contrastive_loss.ContrastiveLossBayesianType": {"tf": 1.4142135623730951}, "shimmer.modules.contrastive_loss.info_nce": {"tf": 1}, "shimmer.modules.contrastive_loss.contrastive_loss": {"tf": 1}, "shimmer.modules.contrastive_loss.ContrastiveLoss.forward": {"tf": 1}, "shimmer.modules.utils.batch_translations": {"tf": 1}}, "df": 13, "s": {"docs": {"shimmer.modules.contrastive_loss.ContrastiveLossType": {"tf": 1}, "shimmer.modules.vae.gaussian_nll": {"tf": 1}}, "df": 2}}}}}, "n": {"docs": {}, "df": 0, "h": {"docs": {"shimmer.modules.gw_module.GWEncoder": {"tf": 1}}, "df": 1}}, "k": {"docs": {}, "df": 0, "e": {"docs": {"shimmer.modules.losses.LossCoefs": {"tf": 1}, "shimmer.modules.losses.BroadcastLossCoefs": {"tf": 1}}, "df": 2, "s": {"docs": {"shimmer.modules.gw_module.GWEncoder.forward": {"tf": 1}, "shimmer.modules.gw_module.GWEncoderLinear.forward": {"tf": 1}}, "df": 2}}, "i": {"docs": {}, "df": 0, "n": {"docs": {}, "df": 0, "g": {"docs": {"shimmer.modules.contrastive_loss.ContrastiveLossType": {"tf": 1}, "shimmer.modules.contrastive_loss.ContrastiveLossBayesianType": {"tf": 1}}, "df": 2}}}}}, "w": {"docs": {}, "df": 0, "o": {"docs": {"shimmer.modules.gw_module.compute_fusion_scores": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBayesian.fuse": {"tf": 1}}, "df": 2}, "i": {"docs": {}, "df": 0, "c": {"docs": {}, "df": 0, "e": {"docs": {"shimmer.modules.selection.KQFixedQSelection.forward": {"tf": 1}, "shimmer.modules.selection.DynamicQueryAttention.forward": {"tf": 1}}, "df": 2}}}}, "i": {"docs": {}, "df": 0, "m": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "s": {"docs": {"shimmer.modules.gw_module.GWModuleBayesian.fuse": {"tf": 1}}, "df": 1}}}}}, "a": {"docs": {"shimmer.types.RawDomainGroupT": {"tf": 1.4142135623730951}, "shimmer.types.RawDomainGroupDT": {"tf": 1.4142135623730951}, "shimmer.types.RawDomainGroupsT": {"tf": 2}, "shimmer.types.RawDomainGroupsDT": {"tf": 2}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.gw_mod": {"tf": 1}, 
"shimmer.modules.global_workspace.GlobalWorkspaceBase.selection_mod": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.encode_and_fuse": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.encode": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.fuse": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.batch_gw_states": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.encode_domain": {"tf": 1.4142135623730951}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.decode_domain": {"tf": 1.4142135623730951}, "shimmer.modules.global_workspace.freeze_domain_modules": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspace2Domains": {"tf": 1.4142135623730951}, "shimmer.modules.global_workspace.GlobalWorkspace2Domains.__init__": {"tf": 2.8284271247461903}, "shimmer.modules.global_workspace.GlobalWorkspace": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspace.__init__": {"tf": 2.8284271247461903}, "shimmer.modules.global_workspace.GlobalWorkspaceBayesian": {"tf": 1.7320508075688772}, "shimmer.modules.global_workspace.GlobalWorkspaceBayesian.__init__": {"tf": 2.8284271247461903}, "shimmer.modules.global_workspace.pretrained_global_workspace": {"tf": 3}, "shimmer.modules.domain.LossOutput": {"tf": 1.4142135623730951}, "shimmer.modules.domain.LossOutput.all": {"tf": 1}, "shimmer.modules.domain.DomainModule": {"tf": 1}, "shimmer.modules.domain.DomainModule.__init__": {"tf": 1}, "shimmer.modules.domain.DomainModule.encode": {"tf": 1.4142135623730951}, "shimmer.modules.domain.DomainModule.compute_dcy_loss": {"tf": 1}, "shimmer.modules.domain.DomainModule.compute_cy_loss": {"tf": 1}, "shimmer.modules.domain.DomainModule.compute_tr_loss": {"tf": 1}, "shimmer.modules.domain.DomainModule.compute_broadcast_loss": {"tf": 1}, "shimmer.modules.gw_module.get_n_layers": {"tf": 1}, "shimmer.modules.gw_module.GWDecoder": {"tf": 1}, "shimmer.modules.gw_module.GWEncoder": {"tf": 1}, 
"shimmer.modules.gw_module.GWEncoderLinear": {"tf": 1}, "shimmer.modules.gw_module.GWModule.__init__": {"tf": 2.449489742783178}, "shimmer.modules.gw_module.GWModule.decode": {"tf": 1}, "shimmer.modules.gw_module.compute_fusion_scores": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBayesian": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBayesian.__init__": {"tf": 2.449489742783178}, "shimmer.modules.gw_module.GWModuleBayesian.fuse": {"tf": 2}, "shimmer.modules.selection.SingleDomainSelection": {"tf": 1.7320508075688772}, "shimmer.modules.selection.KQFixedQSelection": {"tf": 1}, "shimmer.modules.selection.KQFixedQSelection.forward": {"tf": 1}, "shimmer.modules.selection.RandomSelection": {"tf": 1.4142135623730951}, "shimmer.modules.selection.DynamicQueryAttention": {"tf": 1}, "shimmer.modules.selection.DynamicQueryAttention.forward": {"tf": 1}, "shimmer.modules.losses.demi_cycle_loss": {"tf": 1.4142135623730951}, "shimmer.modules.losses.cycle_loss": {"tf": 1.4142135623730951}, "shimmer.modules.losses.translation_loss": {"tf": 1.7320508075688772}, "shimmer.modules.losses.contrastive_loss": {"tf": 1}, "shimmer.modules.losses.contrastive_loss_bayesian": {"tf": 1.4142135623730951}, "shimmer.modules.losses.GWLosses2Domains.__init__": {"tf": 1.4142135623730951}, "shimmer.modules.losses.GWLosses2Domains.demi_cycle_loss": {"tf": 1}, "shimmer.modules.losses.GWLosses2Domains.cycle_loss": {"tf": 1}, "shimmer.modules.losses.GWLosses2Domains.translation_loss": {"tf": 1}, "shimmer.modules.losses.GWLosses2Domains.contrastive_loss": {"tf": 1}, "shimmer.modules.losses.generate_partitions": {"tf": 1}, "shimmer.modules.losses.broadcast_loss": {"tf": 1}, "shimmer.modules.losses.GWLosses.__init__": {"tf": 1.4142135623730951}, "shimmer.modules.losses.GWLosses.contrastive_loss": {"tf": 1}, "shimmer.modules.losses.GWLosses.step": {"tf": 1.4142135623730951}, "shimmer.modules.losses.GWLossesBayesian.__init__": {"tf": 1}, "shimmer.modules.losses.GWLossesBayesian.contrastive_loss": {"tf": 1}, 
"shimmer.modules.losses.GWLossesBayesian.step": {"tf": 1.4142135623730951}, "shimmer.modules.contrastive_loss.ContrastiveLossType": {"tf": 1.4142135623730951}, "shimmer.modules.contrastive_loss.ContrastiveLossBayesianType": {"tf": 1.4142135623730951}, "shimmer.modules.contrastive_loss.ContrastiveLoss.__init__": {"tf": 1}, "shimmer.modules.contrastive_loss.ContrastiveLoss.forward": {"tf": 1}, "shimmer.dataset.RepeatedDataset": {"tf": 1}, "shimmer.dataset.RepeatedDataset.__init__": {"tf": 1}, "shimmer.modules.vae.reparameterize": {"tf": 1}, "shimmer.modules.vae.VAEEncoder": {"tf": 1}, "shimmer.modules.vae.VAEDecoder": {"tf": 1}, "shimmer.modules.vae.VAE.__init__": {"tf": 1}, "shimmer.modules.utils.cycle": {"tf": 1.4142135623730951}, "shimmer.modules.utils.batch_demi_cycles": {"tf": 1}, "shimmer.modules.utils.batch_cycles": {"tf": 1}, "shimmer.modules.utils.batch_translations": {"tf": 1}, "shimmer.utils.migrate_model": {"tf": 1}, "shimmer.utils.SaveMigrations.on_save_checkpoint": {"tf": 1.4142135623730951}, "shimmer.cli.ckpt_migration.migrate_ckpt": {"tf": 1}}, "df": 80, "r": {"docs": {}, "df": 0, "e": {"docs": {"shimmer.types.RawDomainGroupT": {"tf": 1.4142135623730951}, "shimmer.types.RawDomainGroupDT": {"tf": 1.4142135623730951}, "shimmer.types.LatentsDomainGroupT": {"tf": 1.4142135623730951}, "shimmer.types.LatentsDomainGroupDT": {"tf": 1.4142135623730951}, "shimmer.types.RawDomainGroupsT": {"tf": 1}, "shimmer.types.RawDomainGroupsDT": {"tf": 1}, "shimmer.types.LatentsDomainGroupsT": {"tf": 1}, "shimmer.types.LatentsDomainGroupsDT": {"tf": 1}, "shimmer.modules.global_workspace.GWPredictions.cycles": {"tf": 1}, "shimmer.modules.global_workspace.GWPredictions.translations": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspace2Domains.__init__": {"tf": 1.4142135623730951}, "shimmer.modules.global_workspace.GlobalWorkspace.__init__": {"tf": 1.4142135623730951}, "shimmer.modules.global_workspace.GlobalWorkspaceBayesian.__init__": {"tf": 1.4142135623730951}, 
"shimmer.modules.global_workspace.pretrained_global_workspace": {"tf": 1.4142135623730951}, "shimmer.modules.losses.BroadcastLossCoefs.demi_cycles": {"tf": 1}}, "df": 15}, "g": {"docs": {}, "df": 0, "u": {"docs": {}, "df": 0, "m": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "n": {"docs": {}, "df": 0, "t": {"docs": {}, "df": 0, "s": {"docs": {"shimmer.modules.global_workspace.SchedulerArgs": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.encode_and_fuse": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.encode": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.fuse": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.decode": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.batch_gw_states": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.encode_domain": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.encode_domains": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.decode_domain": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.decode_domains": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.generic_step": {"tf": 1}, "shimmer.modules.global_workspace.freeze_domain_modules": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspace2Domains.__init__": {"tf": 1.4142135623730951}, "shimmer.modules.global_workspace.GlobalWorkspace2Domains.forward": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspace.__init__": {"tf": 1.4142135623730951}, "shimmer.modules.global_workspace.GlobalWorkspace.forward": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBayesian.__init__": {"tf": 1.4142135623730951}, "shimmer.modules.global_workspace.GlobalWorkspaceBayesian.forward": {"tf": 1}, "shimmer.modules.global_workspace.pretrained_global_workspace": {"tf": 1.4142135623730951}, "shimmer.modules.domain.DomainModule.__init__": {"tf": 1}, "shimmer.modules.domain.DomainModule.encode": {"tf": 1}, 
"shimmer.modules.domain.DomainModule.decode": {"tf": 1}, "shimmer.modules.domain.DomainModule.compute_loss": {"tf": 1}, "shimmer.modules.domain.DomainModule.compute_dcy_loss": {"tf": 1}, "shimmer.modules.domain.DomainModule.compute_cy_loss": {"tf": 1}, "shimmer.modules.domain.DomainModule.compute_tr_loss": {"tf": 1}, "shimmer.modules.domain.DomainModule.compute_broadcast_loss": {"tf": 1}, "shimmer.modules.gw_module.get_n_layers": {"tf": 1}, "shimmer.modules.gw_module.GWDecoder.__init__": {"tf": 1}, "shimmer.modules.gw_module.GWEncoder.__init__": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBase.__init__": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBase.fuse": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBase.encode": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBase.encode_and_fuse": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBase.decode": {"tf": 1}, "shimmer.modules.gw_module.GWModule.__init__": {"tf": 1}, "shimmer.modules.gw_module.GWModule.fuse": {"tf": 1}, "shimmer.modules.gw_module.GWModule.encode": {"tf": 1}, "shimmer.modules.gw_module.GWModule.decode": {"tf": 1}, "shimmer.modules.gw_module.compute_fusion_scores": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBayesian.__init__": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBayesian.get_precision": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBayesian.fuse": {"tf": 1}, "shimmer.modules.selection.SelectionBase.update_gw_state": {"tf": 1}, "shimmer.modules.selection.SelectionBase.forward": {"tf": 1}, "shimmer.modules.selection.SingleDomainSelection.forward": {"tf": 1}, "shimmer.modules.selection.FixedSharedSelection.forward": {"tf": 1}, "shimmer.modules.selection.KQFixedQSelection.__init__": {"tf": 1}, "shimmer.modules.selection.KQFixedQSelection.forward": {"tf": 1}, "shimmer.modules.selection.RandomSelection.__init__": {"tf": 1}, "shimmer.modules.selection.RandomSelection.forward": {"tf": 1}, "shimmer.modules.selection.DynamicQueryAttention.__init__": {"tf": 1}, 
"shimmer.modules.selection.DynamicQueryAttention.fuse_weighted_encodings": {"tf": 1}, "shimmer.modules.selection.DynamicQueryAttention.forward": {"tf": 1}, "shimmer.modules.losses.GWLossesBase.step": {"tf": 1}, "shimmer.modules.losses.demi_cycle_loss": {"tf": 1}, "shimmer.modules.losses.cycle_loss": {"tf": 1}, "shimmer.modules.losses.translation_loss": {"tf": 1}, "shimmer.modules.losses.contrastive_loss": {"tf": 1}, "shimmer.modules.losses.contrastive_loss_bayesian": {"tf": 1}, "shimmer.modules.losses.GWLosses2Domains.__init__": {"tf": 1}, "shimmer.modules.losses.GWLosses2Domains.demi_cycle_loss": {"tf": 1}, "shimmer.modules.losses.GWLosses2Domains.cycle_loss": {"tf": 1}, "shimmer.modules.losses.GWLosses2Domains.translation_loss": {"tf": 1}, "shimmer.modules.losses.GWLosses2Domains.contrastive_loss": {"tf": 1}, "shimmer.modules.losses.GWLosses2Domains.step": {"tf": 1}, "shimmer.modules.losses.generate_partitions": {"tf": 1}, "shimmer.modules.losses.broadcast_loss": {"tf": 1}, "shimmer.modules.losses.GWLosses.__init__": {"tf": 1}, "shimmer.modules.losses.GWLosses.contrastive_loss": {"tf": 1}, "shimmer.modules.losses.GWLosses.step": {"tf": 1}, "shimmer.modules.losses.GWLossesBayesian.__init__": {"tf": 1}, "shimmer.modules.losses.GWLossesBayesian.contrastive_loss": {"tf": 1}, "shimmer.modules.losses.GWLossesBayesian.step": {"tf": 1}, "shimmer.modules.contrastive_loss.info_nce": {"tf": 1}, "shimmer.modules.contrastive_loss.contrastive_loss": {"tf": 1}, "shimmer.modules.contrastive_loss.ContrastiveLoss.__init__": {"tf": 1}, "shimmer.modules.contrastive_loss.ContrastiveLoss.forward": {"tf": 1}, "shimmer.dataset.RepeatedDataset.__init__": {"tf": 1}, "shimmer.modules.vae.reparameterize": {"tf": 1}, "shimmer.modules.vae.kl_divergence_loss": {"tf": 1}, "shimmer.modules.vae.gaussian_nll": {"tf": 1}, "shimmer.modules.vae.VAEEncoder.forward": {"tf": 1}, "shimmer.modules.vae.VAEDecoder.forward": {"tf": 1}, "shimmer.modules.vae.VAE.__init__": {"tf": 1}, 
"shimmer.modules.vae.VAE.encode": {"tf": 1}, "shimmer.modules.vae.VAE.decode": {"tf": 1}, "shimmer.modules.vae.VAE.forward": {"tf": 1}, "shimmer.modules.utils.translation": {"tf": 1}, "shimmer.modules.utils.cycle": {"tf": 1}, "shimmer.modules.utils.batch_demi_cycles": {"tf": 1}, "shimmer.modules.utils.batch_cycles": {"tf": 1}, "shimmer.modules.utils.batch_translations": {"tf": 1}, "shimmer.utils.groups_batch_size": {"tf": 1}, "shimmer.utils.groups_device": {"tf": 1}, "shimmer.utils.migrate_model": {"tf": 1}, "shimmer.utils.SaveMigrations.on_save_checkpoint": {"tf": 1}}, "df": 97}}}}}}, "s": {"docs": {"shimmer.modules.global_workspace.GlobalWorkspace2Domains.__init__": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspace.__init__": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBayesian.__init__": {"tf": 1}, "shimmer.utils.migrate_model": {"tf": 1}}, "df": 4}}}, "n": {"docs": {"shimmer.modules.gw_module.GWEncoder": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBase": {"tf": 1.4142135623730951}, "shimmer.modules.gw_module.GWModule.__init__": {"tf": 1.4142135623730951}, "shimmer.modules.gw_module.GWModuleBayesian.__init__": {"tf": 1.4142135623730951}, "shimmer.modules.selection.SelectionBase.update_gw_state": {"tf": 1.4142135623730951}}, "df": 5, "d": {"docs": {"shimmer.types.RawDomainGroupT": {"tf": 2}, "shimmer.types.RawDomainGroupDT": {"tf": 1.4142135623730951}, "shimmer.types.LatentsDomainGroupT": {"tf": 1.7320508075688772}, "shimmer.types.LatentsDomainGroupDT": {"tf": 1.4142135623730951}, "shimmer.types.RawDomainGroupsT": {"tf": 1.4142135623730951}, "shimmer.types.RawDomainGroupsDT": {"tf": 1.4142135623730951}, "shimmer.types.LatentsDomainGroupsT": {"tf": 1.7320508075688772}, "shimmer.types.LatentsDomainGroupsDT": {"tf": 1.4142135623730951}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.encode_domain": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.decode_domain": {"tf": 1}, 
"shimmer.modules.global_workspace.GlobalWorkspaceBase.generic_step": {"tf": 1}, "shimmer.modules.global_workspace.freeze_domain_modules": {"tf": 1}, "shimmer.modules.global_workspace.GWPredictions.cycles": {"tf": 1}, "shimmer.modules.global_workspace.GWPredictions.translations": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspace2Domains": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspace2Domains.forward": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspace": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspace.forward": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBayesian": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBayesian.forward": {"tf": 1}, "shimmer.modules.domain.LossOutput": {"tf": 1}, "shimmer.modules.domain.LossOutput.all": {"tf": 1}, "shimmer.modules.domain.DomainModule.compute_loss": {"tf": 1}, "shimmer.modules.domain.DomainModule.compute_dcy_loss": {"tf": 1}, "shimmer.modules.domain.DomainModule.compute_cy_loss": {"tf": 1}, "shimmer.modules.domain.DomainModule.compute_tr_loss": {"tf": 1}, "shimmer.modules.domain.DomainModule.compute_broadcast_loss": {"tf": 1}, "shimmer.modules.gw_module.get_n_layers": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBase": {"tf": 2}, "shimmer.modules.gw_module.GWModuleBase.encode_and_fuse": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBayesian.get_precision": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBayesian.fuse": {"tf": 2.23606797749979}, "shimmer.modules.selection.SelectionBase": {"tf": 1}, "shimmer.modules.selection.SelectionBase.update_gw_state": {"tf": 1}, "shimmer.modules.selection.FixedSharedSelection": {"tf": 1}, "shimmer.modules.selection.KQFixedQSelection.__init__": {"tf": 1}, "shimmer.modules.selection.KQFixedQSelection.forward": {"tf": 1.7320508075688772}, "shimmer.modules.selection.RandomSelection": {"tf": 1}, "shimmer.modules.selection.DynamicQueryAttention.__init__": {"tf": 1}, 
"shimmer.modules.selection.DynamicQueryAttention.forward": {"tf": 1.7320508075688772}, "shimmer.modules.losses.contrastive_loss": {"tf": 1.7320508075688772}, "shimmer.modules.losses.contrastive_loss_bayesian": {"tf": 1.7320508075688772}, "shimmer.modules.losses.LossCoefs": {"tf": 1}, "shimmer.modules.losses.GWLosses2Domains.__init__": {"tf": 1}, "shimmer.modules.losses.GWLosses2Domains.step": {"tf": 1}, "shimmer.modules.losses.generate_partitions": {"tf": 1.4142135623730951}, "shimmer.modules.losses.broadcast_loss": {"tf": 1.4142135623730951}, "shimmer.modules.losses.BroadcastLossCoefs": {"tf": 1}, "shimmer.modules.losses.BroadcastLossCoefs.fused": {"tf": 1}, "shimmer.modules.losses.GWLosses.step": {"tf": 1}, "shimmer.modules.losses.GWLossesBayesian.__init__": {"tf": 1}, "shimmer.modules.losses.GWLossesBayesian.step": {"tf": 1}, "shimmer.modules.contrastive_loss.ContrastiveLossType": {"tf": 1.4142135623730951}, "shimmer.modules.contrastive_loss.ContrastiveLossBayesianType": {"tf": 1.4142135623730951}, "shimmer.modules.vae.VAEEncoder.forward": {"tf": 1}, "shimmer.modules.vae.VAE.encode": {"tf": 1}, "shimmer.modules.vae.VAE.forward": {"tf": 1.4142135623730951}, "shimmer.utils.SaveMigrations": {"tf": 1}}, "df": 58}, "y": {"docs": {"shimmer.modules.global_workspace.GlobalWorkspaceBase.encode_domain": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.decode_domain": {"tf": 1}, "shimmer.modules.domain.DomainModule.encode": {"tf": 1}, "shimmer.modules.domain.DomainModule.decode": {"tf": 1}, "shimmer.modules.vae.VAEEncoder.forward": {"tf": 1}, "shimmer.modules.vae.VAEDecoder.forward": {"tf": 1}, "shimmer.modules.vae.VAE.encode": {"tf": 1}, "shimmer.modules.vae.VAE.decode": {"tf": 1}, "shimmer.modules.vae.VAE.forward": {"tf": 1.4142135623730951}, "shimmer.utils.SaveMigrations": {"tf": 1}}, "df": 10, "t": {"docs": {}, "df": 0, "h": {"docs": {}, "df": 0, "i": {"docs": {}, "df": 0, "n": {"docs": {}, "df": 0, "g": {"docs": 
{"shimmer.utils.SaveMigrations.on_save_checkpoint": {"tf": 1}}, "df": 1}}}}}}, "o": {"docs": {}, "df": 0, "t": {"docs": {}, "df": 0, "h": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "r": {"docs": {"shimmer.modules.global_workspace.GWPredictions.cycles": {"tf": 1}, "shimmer.modules.global_workspace.GWPredictions.translations": {"tf": 1}}, "df": 2}}}}}}, "l": {"docs": {}, "df": 0, "l": {"docs": {"shimmer.types.RawDomainGroupT": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.encode_domains": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.decode_domains": {"tf": 1}, "shimmer.modules.domain.LossOutput.all": {"tf": 1}, "shimmer.modules.gw_module.GWEncoder.forward": {"tf": 1}, "shimmer.modules.gw_module.GWEncoderLinear.forward": {"tf": 1}, "shimmer.modules.gw_module.GWModule.decode": {"tf": 1}, "shimmer.modules.losses.GWLossesBase.step": {"tf": 1}, "shimmer.modules.losses.demi_cycle_loss": {"tf": 1}, "shimmer.modules.losses.cycle_loss": {"tf": 1}, "shimmer.modules.losses.translation_loss": {"tf": 1}, "shimmer.modules.losses.contrastive_loss": {"tf": 1}, "shimmer.modules.losses.contrastive_loss_bayesian": {"tf": 1}, "shimmer.modules.losses.GWLosses2Domains.step": {"tf": 1}, "shimmer.modules.losses.generate_partitions": {"tf": 1.7320508075688772}, "shimmer.modules.losses.GWLosses.step": {"tf": 1}, "shimmer.modules.losses.GWLossesBayesian.step": {"tf": 1}}, "df": 17, "o": {"docs": {}, "df": 0, "w": {"docs": {"shimmer.types.RawDomainGroupT": {"tf": 1}, "shimmer.types.LatentsDomainGroupT": {"tf": 1}, "shimmer.types.LatentsDomainGroupsT": {"tf": 1}}, "df": 3, "s": {"docs": {"shimmer.types.RawDomainGroupT": {"tf": 1}, "shimmer.types.LatentsDomainGroupT": {"tf": 1}, "shimmer.types.LatentsDomainGroupsT": {"tf": 1}}, "df": 3}}}}, "i": {"docs": {}, "df": 0, "g": {"docs": {}, "df": 0, "n": {"docs": {}, "df": 0, "m": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "n": {"docs": {}, "df": 0, "t": {"docs": 
{"shimmer.modules.global_workspace.GlobalWorkspace2Domains.__init__": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspace.__init__": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBayesian.__init__": {"tf": 1}, "shimmer.modules.global_workspace.pretrained_global_workspace": {"tf": 1}}, "df": 4}}}}}}}, "t": {"docs": {}, "df": 0, "h": {"docs": {}, "df": 0, "o": {"docs": {}, "df": 0, "u": {"docs": {}, "df": 0, "g": {"docs": {}, "df": 0, "h": {"docs": {"shimmer.modules.gw_module.GWEncoder.forward": {"tf": 1}, "shimmer.modules.gw_module.GWEncoderLinear.forward": {"tf": 1}}, "df": 2}}}}}}, "w": {"docs": {}, "df": 0, "a": {"docs": {}, "df": 0, "y": {"docs": {}, "df": 0, "s": {"docs": {"shimmer.modules.selection.SingleDomainSelection": {"tf": 1}, "shimmer.modules.selection.FixedSharedSelection": {"tf": 1}, "shimmer.modules.losses.BroadcastLossCoefs.demi_cycles": {"tf": 1}}, "df": 3}}}}}, "b": {"docs": {}, "df": 0, "c": {"docs": {"shimmer.types.RawDomainGroupT": {"tf": 1}, "shimmer.types.LatentsDomainGroupT": {"tf": 1}, "shimmer.types.LatentsDomainGroupsT": {"tf": 1}}, "df": 3}, "s": {"docs": {}, "df": 0, "t": {"docs": {}, "df": 0, "r": {"docs": {}, "df": 0, "a": {"docs": {}, "df": 0, "c": {"docs": {}, "df": 0, "t": {"docs": {"shimmer.modules.gw_module.GWModuleBase": {"tf": 1}, "shimmer.modules.losses.GWLossesBase": {"tf": 1}, "shimmer.utils.SaveMigrations": {"tf": 1}}, "df": 3, "m": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "t": {"docs": {}, "df": 0, "h": {"docs": {}, "df": 0, "o": {"docs": {}, "df": 0, "d": {"docs": {"shimmer.modules.selection.SelectionBase.update_gw_state": {"tf": 1}}, "df": 1}}}}}}}}}}}}}, "s": {"docs": {"shimmer.modules.global_workspace.freeze_domain_modules": {"tf": 1}, "shimmer.modules.domain.LossOutput": {"tf": 1}, "shimmer.modules.selection.SelectionBase.update_gw_state": {"tf": 1.4142135623730951}}, "df": 3, "s": {"docs": {}, "df": 0, "o": {"docs": {}, "df": 0, "c": {"docs": {}, "df": 0, "i": {"docs": {}, "df": 0, "a": 
{"docs": {}, "df": 0, "t": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "d": {"docs": {"shimmer.modules.gw_module.GWModuleBayesian.fuse": {"tf": 1}}, "df": 1}}}}}}}, "u": {"docs": {}, "df": 0, "m": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "d": {"docs": {"shimmer.modules.selection.KQFixedQSelection.__init__": {"tf": 1}, "shimmer.modules.selection.DynamicQueryAttention.__init__": {"tf": 1}, "shimmer.modules.losses.LossCoefs": {"tf": 1}, "shimmer.modules.losses.BroadcastLossCoefs": {"tf": 1}}, "df": 4}}}}}}, "u": {"docs": {}, "df": 0, "t": {"docs": {}, "df": 0, "o": {"docs": {"shimmer.modules.global_workspace.freeze_domain_modules": {"tf": 1}}, "df": 1}}}, "c": {"docs": {}, "df": 0, "t": {"docs": {}, "df": 0, "u": {"docs": {}, "df": 0, "a": {"docs": {}, "df": 0, "l": {"docs": {}, "df": 0, "l": {"docs": {}, "df": 0, "y": {"docs": {"shimmer.modules.global_workspace.freeze_domain_modules": {"tf": 1}}, "df": 1}}}}}}, "r": {"docs": {}, "df": 0, "o": {"docs": {}, "df": 0, "s": {"docs": {}, "df": 0, "s": {"docs": {"shimmer.modules.selection.RandomSelection": {"tf": 1}}, "df": 1}}}}}, "f": {"docs": {}, "df": 0, "f": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "c": {"docs": {}, "df": 0, "t": {"docs": {"shimmer.modules.global_workspace.GlobalWorkspace2Domains.__init__": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspace.__init__": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBayesian.__init__": {"tf": 1}, "shimmer.modules.global_workspace.pretrained_global_workspace": {"tf": 1}}, "df": 4}}}}, "t": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "r": {"docs": {"shimmer.modules.gw_module.GWDecoder.__init__": {"tf": 1}, "shimmer.modules.gw_module.GWDecoder.n_layers": {"tf": 1}, "shimmer.modules.gw_module.GWEncoder.__init__": {"tf": 1}, "shimmer.modules.utils.cycle": {"tf": 1}, "shimmer.utils.migrate_model": {"tf": 1}}, "df": 5, "w": {"docs": {}, "df": 0, "a": {"docs": {}, "df": 0, "r": {"docs": {}, "df": 0, "d": {"docs": {}, "df": 0, "s": 
{"docs": {"shimmer.modules.gw_module.GWEncoder.forward": {"tf": 1}, "shimmer.modules.gw_module.GWEncoderLinear.forward": {"tf": 1}}, "df": 2}}}}}}}}}, "d": {"docs": {}, "df": 0, "d": {"docs": {}, "df": 0, "i": {"docs": {}, "df": 0, "t": {"docs": {}, "df": 0, "i": {"docs": {}, "df": 0, "o": {"docs": {}, "df": 0, "n": {"docs": {}, "df": 0, "a": {"docs": {}, "df": 0, "l": {"docs": {"shimmer.modules.global_workspace.pretrained_global_workspace": {"tf": 1}, "shimmer.modules.domain.LossOutput.metrics": {"tf": 1}, "shimmer.modules.domain.DomainModule.compute_loss": {"tf": 1}, "shimmer.modules.domain.DomainModule.compute_dcy_loss": {"tf": 1}, "shimmer.modules.domain.DomainModule.compute_cy_loss": {"tf": 1}, "shimmer.modules.domain.DomainModule.compute_tr_loss": {"tf": 1}, "shimmer.modules.domain.DomainModule.compute_broadcast_loss": {"tf": 1}, "shimmer.modules.losses.demi_cycle_loss": {"tf": 1}, "shimmer.modules.losses.cycle_loss": {"tf": 1}, "shimmer.modules.losses.translation_loss": {"tf": 1}, "shimmer.modules.losses.contrastive_loss": {"tf": 1}, "shimmer.modules.losses.contrastive_loss_bayesian": {"tf": 1}, "shimmer.modules.losses.broadcast_loss": {"tf": 1}, "shimmer.utils.migrate_model": {"tf": 1}}, "df": 14}}}}}}}, "s": {"docs": {"shimmer.modules.gw_module.GWEncoder": {"tf": 1}}, "df": 1}, "e": {"docs": {}, "df": 0, "d": {"docs": {"shimmer.modules.gw_module.compute_fusion_scores": {"tf": 1}}, "df": 1}}}}, "t": {"docs": {"shimmer.modules.gw_module.GWEncoder": {"tf": 1}, "shimmer.modules.gw_module.GWEncoder.forward": {"tf": 1}, "shimmer.modules.gw_module.GWEncoderLinear.forward": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBayesian.precisions": {"tf": 1}, "shimmer.dataset.RepeatedDataset": {"tf": 1}}, "df": 5, "t": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "n": {"docs": {}, "df": 0, "t": {"docs": {}, "df": 0, "i": {"docs": {}, "df": 0, "o": {"docs": {}, "df": 0, "n": {"docs": {"shimmer.modules.gw_module.GWModuleBase.fuse": {"tf": 1}, 
"shimmer.modules.gw_module.GWModuleBase.encode_and_fuse": {"tf": 1}, "shimmer.modules.gw_module.GWModule.fuse": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBayesian.fuse": {"tf": 1}, "shimmer.modules.selection.KQFixedQSelection": {"tf": 1}, "shimmer.modules.selection.KQFixedQSelection.forward": {"tf": 1}, "shimmer.modules.selection.RandomSelection": {"tf": 1.4142135623730951}, "shimmer.modules.selection.RandomSelection.forward": {"tf": 1}, "shimmer.modules.selection.DynamicQueryAttention": {"tf": 1}, "shimmer.modules.selection.DynamicQueryAttention.fuse_weighted_encodings": {"tf": 1.7320508075688772}, "shimmer.modules.selection.DynamicQueryAttention.forward": {"tf": 1}}, "df": 11}}}}}}}}, "v": {"docs": {}, "df": 0, "o": {"docs": {}, "df": 0, "i": {"docs": {}, "df": 0, "d": {"docs": {"shimmer.modules.gw_module.compute_fusion_scores": {"tf": 1}}, "df": 1}}}, "e": {"docs": {}, "df": 0, "r": {"docs": {}, "df": 0, "a": {"docs": {}, "df": 0, "g": {"docs": {}, "df": 0, "e": {"docs": {"shimmer.modules.losses.demi_cycle_loss": {"tf": 1}, "shimmer.modules.losses.cycle_loss": {"tf": 1}, "shimmer.modules.losses.translation_loss": {"tf": 1}, "shimmer.modules.losses.contrastive_loss": {"tf": 1}, "shimmer.modules.losses.contrastive_loss_bayesian": {"tf": 1}}, "df": 5}}}}}}, "\\": {"docs": {}, "df": 0, "i": {"docs": {}, "df": 0, "n": {"docs": {"shimmer.modules.gw_module.GWModuleBayesian.fuse": {"tf": 1}}, "df": 1}}}, "p": {"docs": {}, "df": 0, "p": {"docs": {}, "df": 0, "l": {"docs": {}, "df": 0, "i": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "d": {"docs": {"shimmer.modules.selection.RandomSelection.__init__": {"tf": 1}}, "df": 1}}}, "y": {"docs": {"shimmer.modules.losses.contrastive_loss": {"tf": 1}, "shimmer.modules.losses.contrastive_loss_bayesian": {"tf": 1}, "shimmer.modules.contrastive_loss.info_nce": {"tf": 1}, "shimmer.modules.contrastive_loss.contrastive_loss": {"tf": 1}, "shimmer.modules.contrastive_loss.ContrastiveLoss.__init__": {"tf": 1}}, "df": 5}}}}}, "n": 
{"docs": {"shimmer.modules.gw_module.get_n_layers": {"tf": 1.4142135623730951}, "shimmer.modules.gw_module.GWDecoder.__init__": {"tf": 1.4142135623730951}, "shimmer.modules.gw_module.GWDecoder.n_layers": {"tf": 1}, "shimmer.modules.gw_module.GWEncoder.__init__": {"tf": 1.4142135623730951}, "shimmer.modules.gw_module.GWModuleBayesian.fuse": {"tf": 1.7320508075688772}, "shimmer.modules.losses.generate_partitions": {"tf": 1.4142135623730951}, "shimmer.cli.ckpt_migration.migrate_ckpt": {"tf": 1}}, "df": 7, "a": {"docs": {}, "df": 0, "m": {"docs": {}, "df": 0, "e": {"docs": {"shimmer.modules.global_workspace.GWPredictionsBase.states": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.encode_domain": {"tf": 1.7320508075688772}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.decode_domain": {"tf": 1.7320508075688772}, "shimmer.modules.global_workspace.GlobalWorkspace2Domains.__init__": {"tf": 1.4142135623730951}, "shimmer.modules.global_workspace.GlobalWorkspace.__init__": {"tf": 1.4142135623730951}, "shimmer.modules.global_workspace.GlobalWorkspaceBayesian.__init__": {"tf": 1.4142135623730951}, "shimmer.modules.global_workspace.pretrained_global_workspace": {"tf": 1.4142135623730951}, "shimmer.modules.gw_module.GWModule.__init__": {"tf": 1.4142135623730951}, "shimmer.modules.gw_module.GWModuleBayesian.__init__": {"tf": 1.4142135623730951}, "shimmer.modules.losses.demi_cycle_loss": {"tf": 1.7320508075688772}, "shimmer.modules.losses.GWLosses2Domains.__init__": {"tf": 1}, "shimmer.modules.losses.GWLossesBayesian.__init__": {"tf": 1}, "shimmer.modules.utils.translation": {"tf": 1}, "shimmer.modules.utils.cycle": {"tf": 1}}, "df": 14, "s": {"docs": {"shimmer.types.RawDomainGroupT": {"tf": 1}, "shimmer.types.RawDomainGroupDT": {"tf": 1}, "shimmer.types.LatentsDomainGroupT": {"tf": 1}, "shimmer.types.LatentsDomainGroupDT": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspace2Domains.__init__": {"tf": 1}, 
"shimmer.modules.global_workspace.GlobalWorkspace.__init__": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBayesian.__init__": {"tf": 1}, "shimmer.modules.global_workspace.pretrained_global_workspace": {"tf": 1}, "shimmer.modules.selection.KQFixedQSelection.__init__": {"tf": 1}, "shimmer.modules.selection.DynamicQueryAttention.__init__": {"tf": 1}, "shimmer.modules.losses.GWLosses.__init__": {"tf": 1}, "shimmer.modules.utils.batch_cycles": {"tf": 1}}, "df": 12}}}}, "o": {"docs": {}, "df": 0, "t": {"docs": {"shimmer.modules.global_workspace.GlobalWorkspace2Domains.__init__": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspace.__init__": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBayesian.__init__": {"tf": 1}, "shimmer.modules.global_workspace.pretrained_global_workspace": {"tf": 1.4142135623730951}, "shimmer.modules.domain.LossOutput.metrics": {"tf": 1}, "shimmer.modules.selection.SelectionBase.update_gw_state": {"tf": 1.7320508075688772}, "shimmer.modules.selection.RandomSelection.forward": {"tf": 1}, "shimmer.modules.losses.LossCoefs": {"tf": 1.7320508075688772}, "shimmer.modules.losses.BroadcastLossCoefs": {"tf": 1.7320508075688772}}, "df": 9, "e": {"docs": {"shimmer.types.RawDomainGroupT": {"tf": 1}, "shimmer.types.RawDomainGroupDT": {"tf": 1}, "shimmer.types.LatentsDomainGroupT": {"tf": 1}, "shimmer.types.LatentsDomainGroupDT": {"tf": 1}, "shimmer.types.RawDomainGroupsT": {"tf": 1}, "shimmer.types.RawDomainGroupsDT": {"tf": 1}, "shimmer.types.LatentsDomainGroupsT": {"tf": 1}, "shimmer.types.LatentsDomainGroupsDT": {"tf": 1}, "shimmer.modules.selection.SelectionBase.update_gw_state": {"tf": 1}}, "df": 9}}, "n": {"docs": {"shimmer.modules.gw_module.GWEncoder": {"tf": 1}}, "df": 1, "e": {"docs": {"shimmer.modules.global_workspace.GlobalWorkspace2Domains.__init__": {"tf": 1.4142135623730951}, "shimmer.modules.global_workspace.GlobalWorkspace.__init__": {"tf": 1.4142135623730951}, 
"shimmer.modules.global_workspace.GlobalWorkspaceBayesian.__init__": {"tf": 1.4142135623730951}, "shimmer.modules.gw_module.GWModule.decode": {"tf": 1}, "shimmer.modules.contrastive_loss.info_nce": {"tf": 1}, "shimmer.modules.contrastive_loss.contrastive_loss": {"tf": 1}, "shimmer.modules.contrastive_loss.ContrastiveLoss.__init__": {"tf": 1}}, "df": 7}}, "r": {"docs": {}, "df": 0, "m": {"docs": {}, "df": 0, "a": {"docs": {}, "df": 0, "l": {"docs": {"shimmer.modules.vae.reparameterize": {"tf": 1}}, "df": 1, "i": {"docs": {}, "df": 0, "z": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "d": {"docs": {"shimmer.modules.global_workspace.GlobalWorkspaceBayesian.__init__": {"tf": 1.4142135623730951}, "shimmer.modules.losses.GWLossesBayesian.__init__": {"tf": 1.4142135623730951}}, "df": 2}}}}}}}}, "w": {"docs": {"shimmer.modules.selection.KQFixedQSelection.__init__": {"tf": 1}, "shimmer.modules.selection.DynamicQueryAttention.__init__": {"tf": 1}}, "df": 2}}, "u": {"docs": {}, "df": 0, "m": {"docs": {}, "df": 0, "b": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "r": {"docs": {"shimmer.modules.global_workspace.SchedulerArgs.total_steps": {"tf": 1}, "shimmer.modules.gw_module.get_n_layers": {"tf": 1}, "shimmer.modules.gw_module.GWDecoder.__init__": {"tf": 1.4142135623730951}, "shimmer.modules.gw_module.GWDecoder.n_layers": {"tf": 1.4142135623730951}, "shimmer.modules.gw_module.GWEncoder.__init__": {"tf": 1.4142135623730951}, "shimmer.modules.gw_module.GWModuleBayesian.fuse": {"tf": 1}, "shimmer.modules.losses.generate_partitions": {"tf": 1}}, "df": 7}}}, "e": {"docs": {}, "df": 0, "r": {"docs": {}, "df": 0, "i": {"docs": {}, "df": 0, "c": {"docs": {}, "df": 0, "a": {"docs": {}, "df": 0, "l": {"docs": {"shimmer.modules.gw_module.compute_fusion_scores": {"tf": 1}}, "df": 1}}}}}}}}, "n": {"docs": {"shimmer.modules.global_workspace.GlobalWorkspace2Domains.__init__": {"tf": 2}, "shimmer.modules.global_workspace.GlobalWorkspace.__init__": {"tf": 2}, 
"shimmer.modules.global_workspace.GlobalWorkspaceBayesian.__init__": {"tf": 2}, "shimmer.modules.global_workspace.pretrained_global_workspace": {"tf": 2}, "shimmer.modules.gw_module.get_n_layers": {"tf": 1.4142135623730951}, "shimmer.modules.gw_module.GWModule": {"tf": 1}, "shimmer.modules.gw_module.GWModule.__init__": {"tf": 2}, "shimmer.modules.gw_module.GWModuleBayesian.__init__": {"tf": 2}}, "df": 8}, "e": {"docs": {}, "df": 0, "t": {"docs": {}, "df": 0, "w": {"docs": {}, "df": 0, "o": {"docs": {}, "df": 0, "r": {"docs": {}, "df": 0, "k": {"docs": {"shimmer.modules.gw_module.GWDecoder": {"tf": 1}, "shimmer.modules.gw_module.GWEncoder": {"tf": 1}, "shimmer.modules.gw_module.GWEncoderLinear": {"tf": 1}}, "df": 3}}}}}, "e": {"docs": {}, "df": 0, "d": {"docs": {"shimmer.modules.selection.SelectionBase.update_gw_state": {"tf": 1}}, "df": 1, "s": {"docs": {"shimmer.modules.gw_module.GWEncoder.forward": {"tf": 1}, "shimmer.modules.gw_module.GWEncoderLinear.forward": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBayesian.fuse": {"tf": 1}}, "df": 3}}}, "u": {"docs": {}, "df": 0, "r": {"docs": {}, "df": 0, "o": {"docs": {}, "df": 0, "n": {"docs": {"shimmer.modules.gw_module.GWModuleBayesian.precisions": {"tf": 1}}, "df": 1}}}}, "w": {"docs": {"shimmer.utils.SaveMigrations": {"tf": 1}}, "df": 1}}, "l": {"docs": {}, "df": 0, "l": {"docs": {"shimmer.modules.vae.gaussian_nll": {"tf": 1.4142135623730951}}, "df": 1}}}, "v": {"docs": {"shimmer.modules.selection.SelectionBase.forward": {"tf": 1.4142135623730951}}, "df": 1, "a": {"docs": {}, "df": 0, "l": {"docs": {"shimmer.modules.losses.GWLossesBase.step": {"tf": 1}}, "df": 1, "u": {"docs": {}, "df": 0, "e": {"docs": {"shimmer.modules.global_workspace.GlobalWorkspace.__init__": {"tf": 1}, "shimmer.modules.domain.LossOutput": {"tf": 1}, "shimmer.modules.gw_module.compute_fusion_scores": {"tf": 1}, "shimmer.modules.losses.demi_cycle_loss": {"tf": 1}, "shimmer.modules.losses.cycle_loss": {"tf": 1}, 
"shimmer.modules.losses.translation_loss": {"tf": 1}, "shimmer.modules.losses.contrastive_loss": {"tf": 1}, "shimmer.modules.losses.contrastive_loss_bayesian": {"tf": 1}, "shimmer.modules.losses.GWLosses2Domains.__init__": {"tf": 1}, "shimmer.modules.losses.GWLossesBayesian.__init__": {"tf": 1}, "shimmer.modules.vae.VAEEncoder.forward": {"tf": 1}, "shimmer.modules.vae.VAE.__init__": {"tf": 1}, "shimmer.modules.vae.VAE.beta": {"tf": 1}, "shimmer.modules.vae.VAE.encode": {"tf": 1}, "shimmer.modules.vae.VAE.decode": {"tf": 1}}, "df": 15, "s": {"docs": {"shimmer.types.RawDomainGroupT": {"tf": 1.4142135623730951}, "shimmer.types.RawDomainGroupDT": {"tf": 1}, "shimmer.types.LatentsDomainGroupT": {"tf": 1}, "shimmer.types.LatentsDomainGroupDT": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspace2Domains.__init__": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspace.__init__": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBayesian.__init__": {"tf": 1}, "shimmer.modules.global_workspace.pretrained_global_workspace": {"tf": 1}, "shimmer.modules.losses.demi_cycle_loss": {"tf": 1}, "shimmer.modules.losses.cycle_loss": {"tf": 1}, "shimmer.modules.losses.translation_loss": {"tf": 1}, "shimmer.modules.losses.contrastive_loss": {"tf": 1}, "shimmer.modules.losses.contrastive_loss_bayesian": {"tf": 1}}, "df": 13}}}, "i": {"docs": {}, "df": 0, "d": {"docs": {}, "df": 0, "a": {"docs": {}, "df": 0, "t": {"docs": {}, "df": 0, "i": {"docs": {}, "df": 0, "n": {"docs": {}, "df": 0, "g": {"docs": {"shimmer.types.ModelModeT": {"tf": 1}}, "df": 1}}, "o": {"docs": {}, "df": 0, "n": {"docs": {"shimmer.modules.global_workspace.GlobalWorkspaceBase.generic_step": {"tf": 1}}, "df": 1}}}}}}}, "/": {"docs": {}, "df": 0, "o": {"docs": {}, "df": 0, "o": {"docs": {}, "df": 0, "d": {"docs": {"shimmer.types.ModelModeT": {"tf": 1}, "shimmer.modules.losses.GWLossesBase.step": {"tf": 1}}, "df": 2}}}}}, "r": {"docs": {}, "df": 0, "i": {"docs": {}, "df": 0, "a": {"docs": {}, 
"df": 0, "b": {"docs": {}, "df": 0, "l": {"docs": {}, "df": 0, "e": {"docs": {"shimmer.modules.gw_module.GWModuleBayesian.fuse": {"tf": 1}}, "df": 1}}}, "n": {"docs": {}, "df": 0, "c": {"docs": {}, "df": 0, "e": {"docs": {"shimmer.modules.vae.reparameterize": {"tf": 1}, "shimmer.modules.vae.VAEEncoder.forward": {"tf": 1}}, "df": 2, "s": {"docs": {"shimmer.modules.gw_module.GWModuleBayesian.fuse": {"tf": 1}}, "df": 1}}}}}, "o": {"docs": {}, "df": 0, "u": {"docs": {}, "df": 0, "s": {"docs": {"shimmer.modules.contrastive_loss": {"tf": 1}}, "df": 1}}}}}, "e": {"docs": {"shimmer.modules.vae.reparameterize": {"tf": 1}, "shimmer.modules.vae.kl_divergence_loss": {"tf": 1}, "shimmer.modules.vae.gaussian_nll": {"tf": 1}, "shimmer.modules.vae.VAEEncoder": {"tf": 1}, "shimmer.modules.vae.VAEEncoder.forward": {"tf": 1}, "shimmer.modules.vae.VAEDecoder": {"tf": 1}, "shimmer.modules.vae.VAEDecoder.forward": {"tf": 1.4142135623730951}, "shimmer.modules.vae.VAE": {"tf": 1}, "shimmer.modules.vae.VAE.__init__": {"tf": 2}, "shimmer.modules.vae.VAE.encode": {"tf": 1}, "shimmer.modules.vae.VAE.decode": {"tf": 1.4142135623730951}}, "df": 11, "e": {"docs": {}, "df": 0, "n": {"docs": {}, "df": 0, "c": {"docs": {}, "df": 0, "o": {"docs": {}, "df": 0, "d": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "r": {"docs": {"shimmer.modules.vae.VAE.__init__": {"tf": 1}}, "df": 1}}}}}}}, "d": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "c": {"docs": {}, "df": 0, "o": {"docs": {}, "df": 0, "d": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "r": {"docs": {"shimmer.modules.vae.VAE.__init__": {"tf": 1}}, "df": 1}}}}}}}, "s": {"docs": {"shimmer.modules.vae.VAE.beta": {"tf": 1}}, "df": 1}}}, "i": {"docs": {}, "df": 0, "s": {"docs": {}, "df": 0, "i": {"docs": {}, "df": 0, "o": {"docs": {}, "df": 0, "n": {"docs": {"shimmer.types.RawDomainGroupT": {"tf": 1}, "shimmer.types.RawDomainGroupDT": {"tf": 1}, "shimmer.types.LatentsDomainGroupT": {"tf": 1}, "shimmer.types.LatentsDomainGroupDT": {"tf": 1}, 
"shimmer.types.RawDomainGroupsT": {"tf": 2}, "shimmer.types.RawDomainGroupsDT": {"tf": 2}, "shimmer.types.LatentsDomainGroupsT": {"tf": 2}, "shimmer.types.LatentsDomainGroupsDT": {"tf": 2}}, "df": 8}}}}}, "e": {"docs": {}, "df": 0, "r": {"docs": {}, "df": 0, "s": {"docs": {}, "df": 0, "i": {"docs": {}, "df": 0, "o": {"docs": {}, "df": 0, "n": {"docs": {"shimmer.types.RawDomainGroupDT": {"tf": 1}, "shimmer.modules.selection.RandomSelection": {"tf": 1}, "shimmer.utils.migrate_model": {"tf": 1}}, "df": 3, "s": {"docs": {"shimmer.utils.migrate_model": {"tf": 1}}, "df": 1}}}}}}, "c": {"docs": {}, "df": 0, "t": {"docs": {}, "df": 0, "o": {"docs": {}, "df": 0, "r": {"docs": {"shimmer.modules.gw_module.GWModuleBayesian.get_precision": {"tf": 1}, "shimmer.modules.selection.KQFixedQSelection": {"tf": 1}, "shimmer.modules.selection.DynamicQueryAttention": {"tf": 1.4142135623730951}}, "df": 3, "s": {"docs": {"shimmer.modules.selection.KQFixedQSelection.__init__": {"tf": 1}, "shimmer.modules.selection.DynamicQueryAttention.__init__": {"tf": 1}}, "df": 2}}}}}}}, "i": {"docs": {"shimmer.modules.gw_module.GWModuleBayesian.fuse": {"tf": 2.23606797749979}}, "df": 1, "n": {"docs": {"shimmer.types.RawDomainGroupT": {"tf": 1}, "shimmer.types.RawDomainGroupsT": {"tf": 1}, "shimmer.types.RawDomainGroupsDT": {"tf": 1}, "shimmer.types.LatentsDomainGroupsT": {"tf": 1}, "shimmer.types.LatentsDomainGroupsDT": {"tf": 1}, "shimmer.types.ModelModeT": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.encode_domains": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.decode_domains": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.generic_step": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBayesian.__init__": {"tf": 1}, "shimmer.modules.gw_module.GWDecoder.__init__": {"tf": 1}, "shimmer.modules.gw_module.GWEncoder": {"tf": 1}, "shimmer.modules.gw_module.GWEncoder.__init__": {"tf": 1}, "shimmer.modules.gw_module.GWEncoderLinear": {"tf": 
1}, "shimmer.modules.gw_module.GWModuleBase": {"tf": 1}, "shimmer.modules.gw_module.GWModule.decode": {"tf": 1}, "shimmer.modules.gw_module.compute_fusion_scores": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBayesian.__init__": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBayesian.fuse": {"tf": 1.7320508075688772}, "shimmer.modules.selection.SelectionBase.forward": {"tf": 1.4142135623730951}, "shimmer.modules.selection.SingleDomainSelection.forward": {"tf": 1}, "shimmer.modules.selection.FixedSharedSelection.forward": {"tf": 1}, "shimmer.modules.selection.KQFixedQSelection.forward": {"tf": 1}, "shimmer.modules.selection.RandomSelection.forward": {"tf": 1.7320508075688772}, "shimmer.modules.selection.DynamicQueryAttention.fuse_weighted_encodings": {"tf": 1}, "shimmer.modules.selection.DynamicQueryAttention.forward": {"tf": 1}, "shimmer.modules.losses.LossCoefs": {"tf": 1.4142135623730951}, "shimmer.modules.losses.GWLosses2Domains.__init__": {"tf": 1}, "shimmer.modules.losses.BroadcastLossCoefs": {"tf": 1.4142135623730951}, "shimmer.modules.losses.GWLosses.step": {"tf": 1}, "shimmer.modules.losses.GWLossesBayesian.__init__": {"tf": 1}, "shimmer.modules.losses.GWLossesBayesian.step": {"tf": 1}, "shimmer.modules.vae.kl_divergence_loss": {"tf": 1}, "shimmer.modules.vae.gaussian_nll": {"tf": 1}}, "df": 34, "f": {"docs": {}, "df": 0, "o": {"docs": {}, "df": 0, "r": {"docs": {}, "df": 0, "m": {"docs": {}, "df": 0, "a": {"docs": {}, "df": 0, "t": {"docs": {}, "df": 0, "i": {"docs": {}, "df": 0, "o": {"docs": {}, "df": 0, "n": {"docs": {"shimmer.types.RawDomainGroupT": {"tf": 1}}, "df": 1}}}}}}}, "s": {"docs": {"shimmer.modules.gw_module.GWModuleBase.encode": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBase.encode_and_fuse": {"tf": 1}, "shimmer.modules.gw_module.GWModule.encode": {"tf": 1}}, "df": 3}, "n": {"docs": {}, "df": 0, "c": {"docs": {}, "df": 0, "e": {"docs": {"shimmer.modules.contrastive_loss.info_nce": {"tf": 1.4142135623730951}}, "df": 1}}}}}, "p": {"docs": 
{}, "df": 0, "u": {"docs": {}, "df": 0, "t": {"docs": {"shimmer.modules.global_workspace.GlobalWorkspaceBase.encode_and_fuse": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.encode": {"tf": 1}, "shimmer.modules.gw_module.GWDecoder.__init__": {"tf": 1}, "shimmer.modules.gw_module.GWDecoder.in_dim": {"tf": 1}, "shimmer.modules.gw_module.GWEncoder.__init__": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBase.encode": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBase.encode_and_fuse": {"tf": 1}, "shimmer.modules.gw_module.GWModule.encode": {"tf": 1}, "shimmer.modules.selection.SingleDomainSelection.forward": {"tf": 1.4142135623730951}, "shimmer.modules.selection.FixedSharedSelection.forward": {"tf": 1.4142135623730951}, "shimmer.modules.selection.KQFixedQSelection.__init__": {"tf": 1.4142135623730951}, "shimmer.modules.selection.DynamicQueryAttention.__init__": {"tf": 1.4142135623730951}, "shimmer.modules.vae.VAEEncoder.forward": {"tf": 1}, "shimmer.modules.vae.VAEDecoder.forward": {"tf": 1}, "shimmer.modules.vae.VAE.encode": {"tf": 1}, "shimmer.modules.vae.VAE.decode": {"tf": 1.4142135623730951}, "shimmer.modules.vae.VAE.forward": {"tf": 1.7320508075688772}}, "df": 17, "s": {"docs": {"shimmer.types.RawDomainGroupT": {"tf": 1}, "shimmer.types.RawDomainGroupDT": {"tf": 1}, "shimmer.types.LatentsDomainGroupT": {"tf": 1}, "shimmer.types.LatentsDomainGroupDT": {"tf": 1}, "shimmer.types.RawDomainGroupsT": {"tf": 1}, "shimmer.types.RawDomainGroupsDT": {"tf": 1}, "shimmer.types.LatentsDomainGroupsT": {"tf": 1}, "shimmer.types.LatentsDomainGroupsDT": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBayesian.get_precision": {"tf": 1}}, "df": 9}}}}, "d": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "p": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "n": {"docs": {}, "df": 0, "d": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "n": {"docs": {}, "df": 0, "t": {"docs": {"shimmer.types.RawDomainGroupsT": {"tf": 1}, "shimmer.types.RawDomainGroupsDT": {"tf": 
1}, "shimmer.types.LatentsDomainGroupsT": {"tf": 1}, "shimmer.types.LatentsDomainGroupsDT": {"tf": 1}}, "df": 4}}}}}}}}}, "t": {"docs": {"shimmer.modules.global_workspace.GlobalWorkspace2Domains.__init__": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspace.__init__": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBayesian.__init__": {"tf": 1}, "shimmer.modules.global_workspace.pretrained_global_workspace": {"tf": 1}, "shimmer.modules.domain.DomainModule.__init__": {"tf": 1}, "shimmer.modules.gw_module.get_n_layers": {"tf": 1.4142135623730951}, "shimmer.modules.gw_module.GWDecoder.__init__": {"tf": 2}, "shimmer.modules.gw_module.GWEncoder.__init__": {"tf": 2}, "shimmer.modules.gw_module.GWModuleBase.__init__": {"tf": 1}, "shimmer.modules.gw_module.GWModule.__init__": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBayesian.__init__": {"tf": 1}, "shimmer.modules.selection.KQFixedQSelection.__init__": {"tf": 1.4142135623730951}, "shimmer.modules.selection.DynamicQueryAttention.__init__": {"tf": 1.4142135623730951}, "shimmer.modules.losses.generate_partitions": {"tf": 1}, "shimmer.dataset.RepeatedDataset.__init__": {"tf": 1}, "shimmer.utils.groups_batch_size": {"tf": 1}, "shimmer.utils.groups_device": {"tf": 1}}, "df": 17, "o": {"docs": {"shimmer.modules.global_workspace.GlobalWorkspaceBase.encode_and_fuse": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.encode": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.fuse": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.decode": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.encode_domain": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.encode_domains": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.decode_domain": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.decode_domains": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspace2Domains.__init__": {"tf": 
1.4142135623730951}, "shimmer.modules.global_workspace.GlobalWorkspace.__init__": {"tf": 1.4142135623730951}, "shimmer.modules.global_workspace.GlobalWorkspaceBayesian.__init__": {"tf": 1.4142135623730951}, "shimmer.modules.global_workspace.pretrained_global_workspace": {"tf": 1.4142135623730951}, "shimmer.modules.domain.DomainModule.encode": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBase.decode": {"tf": 1}, "shimmer.modules.gw_module.GWModule.__init__": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBayesian.__init__": {"tf": 1}, "shimmer.modules.vae.VAE.decode": {"tf": 1}}, "df": 17}, "e": {"docs": {}, "df": 0, "r": {"docs": {}, "df": 0, "m": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "d": {"docs": {}, "df": 0, "i": {"docs": {}, "df": 0, "a": {"docs": {}, "df": 0, "r": {"docs": {}, "df": 0, "y": {"docs": {"shimmer.modules.global_workspace.GWPredictions.cycles": {"tf": 1}, "shimmer.modules.utils.batch_cycles": {"tf": 1}}, "df": 2}}}}}}}, "f": {"docs": {}, "df": 0, "a": {"docs": {}, "df": 0, "c": {"docs": {}, "df": 0, "e": {"docs": {"shimmer.modules.gw_module.GWModuleBase": {"tf": 1}}, "df": 1, "s": {"docs": {"shimmer.modules.gw_module.GWModule.decode": {"tf": 1}}, "df": 1}}}}}, "n": {"docs": {}, "df": 0, "a": {"docs": {}, "df": 0, "l": {"docs": {"shimmer.modules.selection.SelectionBase.update_gw_state": {"tf": 1}}, "df": 1, "l": {"docs": {}, "df": 0, "y": {"docs": {"shimmer.cli.ckpt_migration.migrate_ckpt": {"tf": 1}}, "df": 1}}}}}}, "g": {"docs": {}, "df": 0, "r": {"docs": {}, "df": 0, "a": {"docs": {}, "df": 0, "t": {"docs": {}, "df": 0, "i": {"docs": {}, "df": 0, "o": {"docs": {}, "df": 0, "n": {"docs": {"shimmer.modules.gw_module.GWModuleBayesian.fuse": {"tf": 1}}, "df": 1}}}}}}}}, "r": {"docs": {}, "df": 0, "o": {"docs": {}, "df": 0, "d": {"docs": {}, "df": 0, "u": {"docs": {}, "df": 0, "c": {"docs": {}, "df": 0, "i": {"docs": {}, "df": 0, "n": {"docs": {}, "df": 0, "g": {"docs": {"shimmer.modules.gw_module.GWModuleBayesian.fuse": {"tf": 1}}, "df": 
1}}}}}}}}}, "s": {"docs": {}, "df": 0, "t": {"docs": {}, "df": 0, "a": {"docs": {}, "df": 0, "n": {"docs": {}, "df": 0, "c": {"docs": {}, "df": 0, "i": {"docs": {}, "df": 0, "a": {"docs": {}, "df": 0, "t": {"docs": {}, "df": 0, "i": {"docs": {}, "df": 0, "o": {"docs": {}, "df": 0, "n": {"docs": {"shimmer.modules.global_workspace.GlobalWorkspace2Domains": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspace": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBayesian": {"tf": 1}}, "df": 3}}}}}}, "e": {"docs": {"shimmer.modules.gw_module.GWEncoder.forward": {"tf": 1}, "shimmer.modules.gw_module.GWEncoderLinear.forward": {"tf": 1}, "shimmer.utils.SaveMigrations.on_save_checkpoint": {"tf": 1.4142135623730951}}, "df": 3}}}}, "e": {"docs": {}, "df": 0, "a": {"docs": {}, "df": 0, "d": {"docs": {"shimmer.modules.gw_module.GWEncoder.forward": {"tf": 1}, "shimmer.modules.gw_module.GWEncoderLinear.forward": {"tf": 1}}, "df": 2}}}}}, "i": {"docs": {}, "df": 0, "t": {"docs": {"shimmer.modules.global_workspace.GlobalWorkspace2Domains": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspace": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBayesian": {"tf": 1}}, "df": 3, "i": {"docs": {}, "df": 0, "a": {"docs": {}, "df": 0, "l": {"docs": {}, "df": 0, "i": {"docs": {}, "df": 0, "z": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "s": {"docs": {"shimmer.modules.global_workspace.GlobalWorkspace2Domains.__init__": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspace.__init__": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBayesian.__init__": {"tf": 1}, "shimmer.modules.domain.DomainModule.__init__": {"tf": 1}, "shimmer.modules.gw_module.GWDecoder.__init__": {"tf": 1}, "shimmer.modules.gw_module.GWEncoder.__init__": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBase.__init__": {"tf": 1}, "shimmer.modules.gw_module.GWModule.__init__": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBayesian.__init__": {"tf": 1}, 
"shimmer.modules.losses.GWLosses.__init__": {"tf": 1}, "shimmer.modules.contrastive_loss.ContrastiveLoss.__init__": {"tf": 1}, "shimmer.modules.vae.VAE.__init__": {"tf": 1}}, "df": 12}}}}}}}}}, "c": {"docs": {}, "df": 0, "l": {"docs": {}, "df": 0, "u": {"docs": {}, "df": 0, "d": {"docs": {}, "df": 0, "i": {"docs": {}, "df": 0, "n": {"docs": {}, "df": 0, "g": {"docs": {"shimmer.modules.losses.broadcast_loss": {"tf": 1}}, "df": 1}}}}}}}}, "m": {"docs": {}, "df": 0, "a": {"docs": {}, "df": 0, "g": {"docs": {}, "df": 0, "e": {"docs": {"shimmer.types.RawDomainGroupT": {"tf": 1.4142135623730951}, "shimmer.types.RawDomainGroupDT": {"tf": 1.4142135623730951}, "shimmer.types.RawDomainGroupsT": {"tf": 2}, "shimmer.types.RawDomainGroupsDT": {"tf": 2}}, "df": 4}}}, "p": {"docs": {}, "df": 0, "l": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "m": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "n": {"docs": {}, "df": 0, "t": {"docs": {}, "df": 0, "a": {"docs": {}, "df": 0, "t": {"docs": {}, "df": 0, "i": {"docs": {}, "df": 0, "o": {"docs": {}, "df": 0, "n": {"docs": {"shimmer.modules.global_workspace.GlobalWorkspaceBase.gw_mod": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.selection_mod": {"tf": 1}, "shimmer.modules.losses.GWLosses2Domains": {"tf": 1}, "shimmer.modules.losses.GWLosses": {"tf": 1}, "shimmer.modules.losses.GWLossesBayesian": {"tf": 1}}, "df": 5}}}}}, "e": {"docs": {}, "df": 0, "d": {"docs": {"shimmer.modules.gw_module.GWModuleBase": {"tf": 1.4142135623730951}, "shimmer.modules.selection.SelectionBase.update_gw_state": {"tf": 1}}, "df": 2}}, "s": {"docs": {"shimmer.modules.gw_module.GWModule": {"tf": 1}}, "df": 1}}}}}}}}}, "s": {"docs": {"shimmer.types.RawDomainGroupT": {"tf": 1.4142135623730951}, "shimmer.types.RawDomainGroupDT": {"tf": 1.4142135623730951}, "shimmer.types.LatentsDomainGroupT": {"tf": 1}, "shimmer.types.LatentsDomainGroupDT": {"tf": 1}, "shimmer.types.RawDomainGroupsT": {"tf": 2}, "shimmer.types.RawDomainGroupsDT": {"tf": 
2}, "shimmer.types.LatentsDomainGroupsT": {"tf": 1.4142135623730951}, "shimmer.types.LatentsDomainGroupsDT": {"tf": 1.4142135623730951}, "shimmer.types.ModelModeT": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBase": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.encode_domain": {"tf": 1.4142135623730951}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.decode_domain": {"tf": 1.4142135623730951}, "shimmer.modules.global_workspace.freeze_domain_modules": {"tf": 1.4142135623730951}, "shimmer.modules.global_workspace.GlobalWorkspace2Domains": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspace2Domains.__init__": {"tf": 1.4142135623730951}, "shimmer.modules.global_workspace.GlobalWorkspace": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspace.__init__": {"tf": 1.4142135623730951}, "shimmer.modules.global_workspace.GlobalWorkspaceBayesian": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBayesian.__init__": {"tf": 1.4142135623730951}, "shimmer.modules.global_workspace.pretrained_global_workspace": {"tf": 1.7320508075688772}, "shimmer.modules.domain.LossOutput": {"tf": 1.7320508075688772}, "shimmer.modules.domain.DomainModule.compute_dcy_loss": {"tf": 1}, "shimmer.modules.domain.DomainModule.compute_cy_loss": {"tf": 1}, "shimmer.modules.domain.DomainModule.compute_tr_loss": {"tf": 1}, "shimmer.modules.domain.DomainModule.compute_broadcast_loss": {"tf": 1}, "shimmer.modules.gw_module.GWEncoder": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBase": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBayesian.fuse": {"tf": 1.7320508075688772}, "shimmer.modules.selection.SelectionBase": {"tf": 1}, "shimmer.modules.selection.SelectionBase.update_gw_state": {"tf": 1.4142135623730951}, "shimmer.modules.selection.SingleDomainSelection": {"tf": 1}, "shimmer.modules.selection.SingleDomainSelection.forward": {"tf": 1}, "shimmer.modules.selection.FixedSharedSelection": {"tf": 1}, 
"shimmer.modules.selection.FixedSharedSelection.forward": {"tf": 1}, "shimmer.modules.selection.RandomSelection.forward": {"tf": 1}, "shimmer.modules.selection.DynamicQueryAttention": {"tf": 1}, "shimmer.modules.losses.GWLossesBase": {"tf": 1}, "shimmer.modules.losses.LossCoefs": {"tf": 1.7320508075688772}, "shimmer.modules.losses.GWLosses2Domains.__init__": {"tf": 1.4142135623730951}, "shimmer.modules.losses.BroadcastLossCoefs": {"tf": 1.7320508075688772}, "shimmer.modules.losses.GWLosses.step": {"tf": 1}, "shimmer.modules.losses.GWLossesBayesian.__init__": {"tf": 1.4142135623730951}, "shimmer.modules.losses.GWLossesBayesian.step": {"tf": 1}, "shimmer.dataset.RepeatedDataset": {"tf": 1.4142135623730951}, "shimmer.dataset.RepeatedDataset.__init__": {"tf": 1}, "shimmer.modules.vae.VAE.forward": {"tf": 1}}, "df": 46}, "t": {"docs": {"shimmer.modules.domain.LossOutput": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBase.encode_and_fuse": {"tf": 1}, "shimmer.modules.selection.SelectionBase.update_gw_state": {"tf": 1}, "shimmer.modules.selection.FixedSharedSelection": {"tf": 1}, "shimmer.modules.losses.LossCoefs": {"tf": 1}, "shimmer.modules.losses.BroadcastLossCoefs": {"tf": 1}}, "df": 6, "e": {"docs": {}, "df": 0, "r": {"docs": {}, "df": 0, "a": {"docs": {}, "df": 0, "b": {"docs": {}, "df": 0, "l": {"docs": {}, "df": 0, "e": {"docs": {"shimmer.modules.global_workspace.GlobalWorkspaceBase.decode": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBase.decode": {"tf": 1}, "shimmer.modules.utils.batch_cycles": {"tf": 1}}, "df": 3, "[": {"docs": {}, "df": 0, "s": {"docs": {}, "df": 0, "t": {"docs": {}, "df": 0, "r": {"docs": {"shimmer.modules.global_workspace.GlobalWorkspaceBase.decode": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBase.decode": {"tf": 1}, "shimmer.modules.gw_module.GWModule.decode": {"tf": 1}, "shimmer.modules.selection.KQFixedQSelection.__init__": {"tf": 1}, "shimmer.modules.selection.DynamicQueryAttention.__init__": {"tf": 1}, 
"shimmer.modules.utils.batch_cycles": {"tf": 1}}, "df": 6}}}}}}}}}, "m": {"docs": {"shimmer.modules.selection.SelectionBase.forward": {"tf": 1}, "shimmer.modules.selection.RandomSelection.forward": {"tf": 1}, "shimmer.modules.vae.VAE.forward": {"tf": 1}}, "df": 3, "s": {"docs": {"shimmer.dataset.RepeatedDataset": {"tf": 1}}, "df": 1}}}, "s": {"docs": {"shimmer.dataset.RepeatedDataset": {"tf": 1}}, "df": 1}}, "f": {"docs": {"shimmer.modules.global_workspace.pretrained_global_workspace": {"tf": 1}, "shimmer.modules.domain.DomainModule.compute_dcy_loss": {"tf": 1}, "shimmer.modules.domain.DomainModule.compute_cy_loss": {"tf": 1}, "shimmer.modules.domain.DomainModule.compute_tr_loss": {"tf": 1}, "shimmer.modules.domain.DomainModule.compute_broadcast_loss": {"tf": 1}, "shimmer.modules.selection.SelectionBase.update_gw_state": {"tf": 1}, "shimmer.modules.selection.SingleDomainSelection": {"tf": 1}, "shimmer.modules.selection.FixedSharedSelection": {"tf": 1}, "shimmer.modules.losses.LossCoefs": {"tf": 1.4142135623730951}, "shimmer.modules.losses.BroadcastLossCoefs": {"tf": 1.4142135623730951}, "shimmer.dataset.RepeatedDataset": {"tf": 1.4142135623730951}}, "df": 11}, "g": {"docs": {}, "df": 0, "n": {"docs": {}, "df": 0, "o": {"docs": {}, "df": 0, "r": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "s": {"docs": {"shimmer.modules.gw_module.GWEncoder.forward": {"tf": 1}, "shimmer.modules.gw_module.GWEncoderLinear.forward": {"tf": 1}}, "df": 2}}}}}}, "^": {"2": {"docs": {"shimmer.modules.gw_module.GWModuleBayesian.fuse": {"tf": 1.4142135623730951}}, "df": 1}, "docs": {}, "df": 0}}, "s": {"docs": {"shimmer.types.RawDomainGroupDT": {"tf": 1}, "shimmer.types.LatentsDomainGroupDT": {"tf": 1}, "shimmer.types.RawDomainGroupsT": {"tf": 1}, "shimmer.types.RawDomainGroupsDT": {"tf": 1}, "shimmer.types.LatentsDomainGroupsDT": {"tf": 1}, "shimmer.modules.global_workspace.GWPredictionsBase.states": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.encode_domain": {"tf": 
1}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.decode_domain": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspace2Domains.__init__": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspace.__init__": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBayesian.__init__": {"tf": 1}, "shimmer.modules.gw_module.GWModule.gw_encoders": {"tf": 1}, "shimmer.modules.gw_module.GWModule.gw_decoders": {"tf": 1}, "shimmer.modules.losses.demi_cycle_loss": {"tf": 1}, "shimmer.modules.losses.cycle_loss": {"tf": 1}, "shimmer.modules.losses.translation_loss": {"tf": 1}, "shimmer.modules.losses.contrastive_loss": {"tf": 1}, "shimmer.modules.losses.contrastive_loss_bayesian": {"tf": 1}}, "df": 18, "h": {"docs": {}, "df": 0, "o": {"docs": {}, "df": 0, "u": {"docs": {}, "df": 0, "l": {"docs": {}, "df": 0, "d": {"docs": {"shimmer.types.RawDomainGroupT": {"tf": 1}, "shimmer.modules.gw_module.GWEncoder.forward": {"tf": 1.4142135623730951}, "shimmer.modules.gw_module.GWEncoderLinear.forward": {"tf": 1.4142135623730951}, "shimmer.modules.gw_module.GWModuleBase": {"tf": 1}, "shimmer.modules.gw_module.compute_fusion_scores": {"tf": 1}, "shimmer.dataset.RepeatedDataset.__init__": {"tf": 1}}, "df": 6}}}}, "a": {"docs": {}, "df": 0, "r": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "s": {"docs": {"shimmer.modules.selection.FixedSharedSelection": {"tf": 1}}, "df": 1}}}}, "i": {"docs": {}, "df": 0, "m": {"docs": {}, "df": 0, "m": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "r": {"docs": {"shimmer.modules.losses.demi_cycle_loss": {"tf": 1.7320508075688772}, "shimmer.modules.losses.cycle_loss": {"tf": 1}, "shimmer.modules.losses.GWLosses2Domains.demi_cycle_loss": {"tf": 1}, "shimmer.modules.losses.GWLosses2Domains.cycle_loss": {"tf": 1}, "shimmer.modules.losses.GWLosses2Domains.translation_loss": {"tf": 1}, "shimmer.modules.losses.GWLosses2Domains.contrastive_loss": {"tf": 1}, "shimmer.modules.losses.broadcast_loss": {"tf": 1.4142135623730951}, 
"shimmer.cli.ckpt_migration.migrate_ckpt": {"tf": 1.4142135623730951}}, "df": 8}}}}}}, "a": {"docs": {}, "df": 0, "m": {"docs": {}, "df": 0, "e": {"docs": {"shimmer.types.RawDomainGroupT": {"tf": 1}, "shimmer.modules.gw_module.compute_fusion_scores": {"tf": 1}, "shimmer.modules.selection.KQFixedQSelection.__init__": {"tf": 1}, "shimmer.modules.selection.DynamicQueryAttention.__init__": {"tf": 1}}, "df": 4}, "p": {"docs": {}, "df": 0, "l": {"docs": {}, "df": 0, "e": {"docs": {"shimmer.modules.vae.reparameterize": {"tf": 1}}, "df": 1, "s": {"docs": {"shimmer.modules.selection.RandomSelection.forward": {"tf": 1}}, "df": 1}, "d": {"docs": {"shimmer.modules.vae.reparameterize": {"tf": 1}}, "df": 1}}}}}, "v": {"docs": {}, "df": 0, "e": {"docs": {"shimmer.utils.SaveMigrations.on_save_checkpoint": {"tf": 1}}, "df": 1, "d": {"docs": {"shimmer.utils.migrate_model": {"tf": 1}, "shimmer.utils.SaveMigrations.on_save_checkpoint": {"tf": 1}}, "df": 2}}, "i": {"docs": {}, "df": 0, "n": {"docs": {}, "df": 0, "g": {"docs": {"shimmer.utils.SaveMigrations.on_save_checkpoint": {"tf": 1}}, "df": 1}}}}}, "c": {"docs": {}, "df": 0, "h": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "d": {"docs": {}, "df": 0, "u": {"docs": {}, "df": 0, "l": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "r": {"docs": {"shimmer.modules.global_workspace.SchedulerArgs": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspace2Domains.__init__": {"tf": 1.4142135623730951}, "shimmer.modules.global_workspace.GlobalWorkspace.__init__": {"tf": 1.4142135623730951}, "shimmer.modules.global_workspace.GlobalWorkspaceBayesian.__init__": {"tf": 1.4142135623730951}}, "df": 4, "a": {"docs": {}, "df": 0, "r": {"docs": {}, "df": 0, "g": {"docs": {}, "df": 0, "s": {"docs": {"shimmer.modules.global_workspace.GlobalWorkspace2Domains.__init__": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspace.__init__": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBayesian.__init__": {"tf": 1}}, "df": 
3}}}}}}}}}}}, "o": {"docs": {}, "df": 0, "r": {"docs": {}, "df": 0, "e": {"docs": {"shimmer.modules.gw_module.GWModuleBase.fuse": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBase.encode_and_fuse": {"tf": 1}, "shimmer.modules.gw_module.GWModule.fuse": {"tf": 1}, "shimmer.modules.gw_module.compute_fusion_scores": {"tf": 2}, "shimmer.modules.gw_module.GWModuleBayesian.fuse": {"tf": 1.4142135623730951}}, "df": 5, "s": {"docs": {"shimmer.modules.global_workspace.GlobalWorkspaceBase.encode_and_fuse": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.fuse": {"tf": 1.4142135623730951}, "shimmer.modules.gw_module.GWModuleBase.fuse": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBase.encode_and_fuse": {"tf": 1}, "shimmer.modules.gw_module.GWModule.fuse": {"tf": 1}, "shimmer.modules.gw_module.compute_fusion_scores": {"tf": 2.23606797749979}, "shimmer.modules.gw_module.GWModuleBayesian.fuse": {"tf": 2.23606797749979}, "shimmer.modules.selection.KQFixedQSelection.forward": {"tf": 1}, "shimmer.modules.selection.RandomSelection": {"tf": 1}, "shimmer.modules.selection.RandomSelection.forward": {"tf": 1}, "shimmer.modules.selection.DynamicQueryAttention.fuse_weighted_encodings": {"tf": 1.4142135623730951}, "shimmer.modules.selection.DynamicQueryAttention.forward": {"tf": 1}}, "df": 12}}}}, "a": {"docs": {}, "df": 0, "l": {"docs": {}, "df": 0, "e": {"docs": {"shimmer.modules.global_workspace.GlobalWorkspace2Domains.__init__": {"tf": 1.4142135623730951}, "shimmer.modules.global_workspace.GlobalWorkspace.__init__": {"tf": 1.4142135623730951}, "shimmer.modules.global_workspace.GlobalWorkspaceBayesian.__init__": {"tf": 1.4142135623730951}, "shimmer.modules.global_workspace.pretrained_global_workspace": {"tf": 1}, "shimmer.modules.contrastive_loss.info_nce": {"tf": 1.4142135623730951}, "shimmer.modules.contrastive_loss.contrastive_loss": {"tf": 1.4142135623730951}, "shimmer.modules.contrastive_loss.ContrastiveLoss.__init__": {"tf": 2}, 
"shimmer.modules.contrastive_loss.ContrastiveLoss.forward": {"tf": 1}}, "df": 8, "d": {"docs": {"shimmer.modules.selection.DynamicQueryAttention": {"tf": 1}}, "df": 1}}, "i": {"docs": {}, "df": 0, "n": {"docs": {}, "df": 0, "g": {"docs": {"shimmer.modules.selection.RandomSelection": {"tf": 1}, "shimmer.modules.selection.RandomSelection.__init__": {"tf": 1}}, "df": 2}}}}}, "r": {"docs": {}, "df": 0, "i": {"docs": {}, "df": 0, "p": {"docs": {}, "df": 0, "t": {"docs": {"shimmer.cli.ckpt_migration.migrate_ckpt": {"tf": 1}}, "df": 1}}}}}, "t": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "p": {"docs": {"shimmer.modules.global_workspace.GlobalWorkspaceBase.generic_step": {"tf": 2}, "shimmer.modules.losses.GWLosses.step": {"tf": 1.4142135623730951}, "shimmer.modules.losses.GWLossesBayesian.step": {"tf": 1.4142135623730951}}, "df": 3, "s": {"docs": {"shimmer.modules.global_workspace.SchedulerArgs.total_steps": {"tf": 1}}, "df": 1}}, "m": {"docs": {"shimmer.utils.migrate_model": {"tf": 1}}, "df": 1}}, "a": {"docs": {}, "df": 0, "t": {"docs": {}, "df": 0, "e": {"docs": {"shimmer.modules.global_workspace.GWPredictionsBase.states": {"tf": 1}, "shimmer.modules.selection.SelectionBase.update_gw_state": {"tf": 1.7320508075688772}, "shimmer.modules.selection.SingleDomainSelection.forward": {"tf": 1.4142135623730951}, "shimmer.modules.selection.FixedSharedSelection.forward": {"tf": 1.4142135623730951}}, "df": 4, "s": {"docs": {"shimmer.modules.global_workspace.GlobalWorkspaceBase.batch_gw_states": {"tf": 1.4142135623730951}}, "df": 1}}, "i": {"docs": {}, "df": 0, "c": {"docs": {"shimmer.modules.selection.KQFixedQSelection.forward": {"tf": 1}, "shimmer.modules.selection.DynamicQueryAttention.forward": {"tf": 1}}, "df": 2}}}, "r": {"docs": {}, "df": 0, "t": {"docs": {"shimmer.modules.global_workspace.GWPredictions.cycles": {"tf": 1}, "shimmer.modules.global_workspace.GWPredictions.translations": {"tf": 1}, "shimmer.modules.utils.batch_cycles": {"tf": 1}, 
"shimmer.modules.utils.batch_translations": {"tf": 1}}, "df": 4}}}, "r": {"docs": {"shimmer.modules.global_workspace.GlobalWorkspaceBase.encode_domain": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.decode_domain": {"tf": 1}, "shimmer.modules.global_workspace.pretrained_global_workspace": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBayesian.get_precision": {"tf": 1}, "shimmer.modules.utils.translation": {"tf": 1}, "shimmer.modules.utils.cycle": {"tf": 1}, "shimmer.modules.utils.batch_cycles": {"tf": 1}, "shimmer.modules.utils.batch_translations": {"tf": 1}, "shimmer.utils.migrate_model": {"tf": 1}}, "df": 9, "u": {"docs": {}, "df": 0, "c": {"docs": {}, "df": 0, "t": {"docs": {}, "df": 0, "u": {"docs": {}, "df": 0, "r": {"docs": {}, "df": 0, "e": {"docs": {"shimmer.modules.selection.RandomSelection.forward": {"tf": 1}}, "df": 1}}}}}}}, "d": {"docs": {"shimmer.modules.gw_module.compute_fusion_scores": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBayesian.fuse": {"tf": 1}, "shimmer.modules.contrastive_loss.ContrastiveLossBayesianType": {"tf": 1.4142135623730951}}, "df": 3}, "o": {"docs": {}, "df": 0, "r": {"docs": {}, "df": 0, "e": {"docs": {"shimmer.utils.SaveMigrations.on_save_checkpoint": {"tf": 1}}, "df": 1}}}}, "e": {"docs": {}, "df": 0, "l": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "c": {"docs": {}, "df": 0, "t": {"docs": {}, "df": 0, "i": {"docs": {}, "df": 0, "o": {"docs": {}, "df": 0, "n": {"docs": {"shimmer.modules.global_workspace.GlobalWorkspaceBase.encode_and_fuse": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.fuse": {"tf": 1.4142135623730951}, "shimmer.modules.global_workspace.GlobalWorkspace.__init__": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBayesian.__init__": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBase.fuse": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBase.encode_and_fuse": {"tf": 1}, "shimmer.modules.gw_module.GWModule.fuse": {"tf": 1}, 
"shimmer.modules.gw_module.GWModuleBayesian.__init__": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBayesian.fuse": {"tf": 1.4142135623730951}, "shimmer.modules.selection.SelectionBase": {"tf": 1.4142135623730951}, "shimmer.modules.selection.SelectionBase.update_gw_state": {"tf": 1}, "shimmer.modules.selection.SelectionBase.forward": {"tf": 1}, "shimmer.modules.selection.SingleDomainSelection": {"tf": 1.4142135623730951}, "shimmer.modules.selection.FixedSharedSelection": {"tf": 1}, "shimmer.modules.losses.demi_cycle_loss": {"tf": 1.7320508075688772}, "shimmer.modules.losses.cycle_loss": {"tf": 1.7320508075688772}, "shimmer.modules.losses.GWLosses2Domains.__init__": {"tf": 1.4142135623730951}, "shimmer.modules.losses.broadcast_loss": {"tf": 1.7320508075688772}, "shimmer.modules.losses.GWLosses.__init__": {"tf": 1.4142135623730951}, "shimmer.modules.losses.GWLossesBayesian.__init__": {"tf": 1.4142135623730951}, "shimmer.modules.losses.GWLossesBayesian.selection_mod": {"tf": 1}, "shimmer.modules.utils.translation": {"tf": 1.4142135623730951}, "shimmer.modules.utils.cycle": {"tf": 1.4142135623730951}, "shimmer.modules.utils.batch_demi_cycles": {"tf": 1.4142135623730951}, "shimmer.modules.utils.batch_cycles": {"tf": 1.4142135623730951}, "shimmer.modules.utils.batch_translations": {"tf": 1.4142135623730951}}, "df": 26, "b": {"docs": {}, "df": 0, "a": {"docs": {}, "df": 0, "s": {"docs": {}, "df": 0, "e": {"docs": {"shimmer.modules.global_workspace.GlobalWorkspaceBase.selection_mod": {"tf": 1}, "shimmer.modules.losses.demi_cycle_loss": {"tf": 1}, "shimmer.modules.losses.cycle_loss": {"tf": 1}, "shimmer.modules.losses.GWLosses2Domains.__init__": {"tf": 1}, "shimmer.modules.losses.broadcast_loss": {"tf": 1}, "shimmer.modules.losses.GWLossesBayesian.__init__": {"tf": 1}, "shimmer.modules.utils.translation": {"tf": 1}, "shimmer.modules.utils.cycle": {"tf": 1}, "shimmer.modules.utils.batch_demi_cycles": {"tf": 1}, "shimmer.modules.utils.batch_cycles": {"tf": 1}, 
"shimmer.modules.utils.batch_translations": {"tf": 1}}, "df": 11}}}}}}, "n": {"docs": {}, "df": 0, "g": {"docs": {"shimmer.modules.selection.SingleDomainSelection": {"tf": 1}}, "df": 1}}}, "s": {"docs": {"shimmer.modules.selection.SelectionBase": {"tf": 1}}, "df": 1}, "e": {"docs": {}, "df": 0, "d": {"docs": {"shimmer.modules.selection.SingleDomainSelection.forward": {"tf": 1}, "shimmer.modules.selection.FixedSharedSelection.forward": {"tf": 1}}, "df": 2}}}}}, "f": {"docs": {"shimmer.modules.global_workspace.GlobalWorkspaceBase.encode_domain": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.decode_domain": {"tf": 1}}, "df": 2}}, "t": {"docs": {"shimmer.modules.global_workspace.freeze_domain_modules": {"tf": 1}, "shimmer.modules.losses.LossCoefs": {"tf": 1}, "shimmer.modules.losses.BroadcastLossCoefs": {"tf": 1}}, "df": 3}, "n": {"docs": {}, "df": 0, "s": {"docs": {}, "df": 0, "i": {"docs": {}, "df": 0, "t": {"docs": {}, "df": 0, "i": {"docs": {}, "df": 0, "v": {"docs": {}, "df": 0, "i": {"docs": {}, "df": 0, "t": {"docs": {}, "df": 0, "y": {"docs": {"shimmer.modules.global_workspace.GlobalWorkspaceBayesian.__init__": {"tf": 1.7320508075688772}, "shimmer.modules.gw_module.compute_fusion_scores": {"tf": 2}, "shimmer.modules.gw_module.GWModuleBayesian.__init__": {"tf": 1.7320508075688772}}, "df": 3}}}}}}, "v": {"docs": {}, "df": 0, "i": {"docs": {}, "df": 0, "t": {"docs": {}, "df": 0, "y": {"docs": {"shimmer.modules.global_workspace.GlobalWorkspaceBayesian.__init__": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBayesian.__init__": {"tf": 1}}, "df": 2}}}}}}}, "e": {"docs": {"shimmer.modules.gw_module.GWModuleBase": {"tf": 1}, "shimmer.modules.losses.GWLosses2Domains.demi_cycle_loss": {"tf": 1}, "shimmer.modules.losses.GWLosses2Domains.cycle_loss": {"tf": 1}, "shimmer.modules.losses.GWLosses2Domains.translation_loss": {"tf": 1}, "shimmer.modules.losses.GWLosses2Domains.contrastive_loss": {"tf": 1}, "shimmer.modules.losses.GWLosses2Domains.step": {"tf": 
2}}, "df": 6}, "c": {"docs": {}, "df": 0, "o": {"docs": {}, "df": 0, "n": {"docs": {}, "df": 0, "d": {"docs": {"shimmer.modules.gw_module.compute_fusion_scores": {"tf": 1.4142135623730951}, "shimmer.modules.vae.VAE.forward": {"tf": 1}}, "df": 2}}}}, "p": {"docs": {}, "df": 0, "a": {"docs": {}, "df": 0, "r": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "t": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "l": {"docs": {}, "df": 0, "y": {"docs": {"shimmer.modules.utils.batch_cycles": {"tf": 1}}, "df": 1}}}}}}}}}, "i": {"docs": {}, "df": 0, "m": {"docs": {}, "df": 0, "p": {"docs": {}, "df": 0, "l": {"docs": {}, "df": 0, "e": {"docs": {"shimmer.modules.global_workspace.GlobalWorkspace2Domains": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBayesian": {"tf": 1}}, "df": 2}, "i": {"docs": {}, "df": 0, "f": {"docs": {}, "df": 0, "y": {"docs": {"shimmer.modules.global_workspace.GlobalWorkspace2Domains": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspace": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBayesian": {"tf": 1}}, "df": 3}}}}}, "i": {"docs": {}, "df": 0, "l": {"docs": {}, "df": 0, "a": {"docs": {}, "df": 0, "r": {"docs": {"shimmer.modules.gw_module.GWEncoder": {"tf": 1}}, "df": 1, "l": {"docs": {}, "df": 0, "y": {"docs": {"shimmer.modules.gw_module.GWModuleBayesian.fuse": {"tf": 1}}, "df": 1}}}}}}}, "z": {"docs": {}, "df": 0, "e": {"docs": {"shimmer.modules.gw_module.get_n_layers": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBayesian.fuse": {"tf": 1}, "shimmer.modules.selection.KQFixedQSelection.__init__": {"tf": 1}, "shimmer.modules.selection.DynamicQueryAttention.__init__": {"tf": 1}, "shimmer.dataset.RepeatedDataset": {"tf": 2.6457513110645907}, "shimmer.dataset.RepeatedDataset.__init__": {"tf": 1.7320508075688772}, "shimmer.utils.groups_batch_size": {"tf": 1.4142135623730951}, "shimmer.utils.groups_device": {"tf": 1.4142135623730951}}, "df": 8, "d": {"docs": {}, "df": 0, "d": {"docs": {}, "df": 0, "a": {"docs": {}, 
"df": 0, "t": {"docs": {}, "df": 0, "a": {"docs": {}, "df": 0, "s": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "t": {"docs": {"shimmer.dataset.RepeatedDataset.__init__": {"tf": 1}}, "df": 1}}}}}}}}}}, "n": {"docs": {}, "df": 0, "c": {"docs": {}, "df": 0, "e": {"docs": {"shimmer.modules.gw_module.GWEncoder.forward": {"tf": 1}, "shimmer.modules.gw_module.GWEncoderLinear.forward": {"tf": 1}}, "df": 2}}}, "l": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "n": {"docs": {}, "df": 0, "t": {"docs": {}, "df": 0, "l": {"docs": {}, "df": 0, "y": {"docs": {"shimmer.modules.gw_module.GWEncoder.forward": {"tf": 1}, "shimmer.modules.gw_module.GWEncoderLinear.forward": {"tf": 1}}, "df": 2}}}}}}, "g": {"docs": {}, "df": 0, "m": {"docs": {}, "df": 0, "a": {"docs": {"shimmer.modules.gw_module.GWModuleBayesian.fuse": {"tf": 1}, "shimmer.modules.vae.gaussian_nll": {"tf": 1.4142135623730951}}, "df": 2}}}}, "o": {"docs": {}, "df": 0, "f": {"docs": {}, "df": 0, "t": {"docs": {}, "df": 0, "m": {"docs": {}, "df": 0, "a": {"docs": {}, "df": 0, "x": {"docs": {"shimmer.modules.global_workspace.GlobalWorkspaceBayesian.__init__": {"tf": 1.4142135623730951}, "shimmer.modules.gw_module.GWModuleBayesian.__init__": {"tf": 1.4142135623730951}, "shimmer.modules.gw_module.GWModuleBayesian.fuse": {"tf": 1}, "shimmer.modules.selection.KQFixedQSelection.forward": {"tf": 1}, "shimmer.modules.selection.RandomSelection": {"tf": 1.4142135623730951}, "shimmer.modules.selection.RandomSelection.__init__": {"tf": 1}, "shimmer.modules.selection.RandomSelection.forward": {"tf": 1}, "shimmer.modules.selection.DynamicQueryAttention.forward": {"tf": 1}}, "df": 8, "e": {"docs": {}, "df": 0, "d": {"docs": {"shimmer.modules.selection.RandomSelection.forward": {"tf": 1}}, "df": 1}}}}}}}, "m": {"docs": {}, "df": 0, "e": {"docs": {"shimmer.modules.domain.LossOutput.metrics": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBase": {"tf": 1}, "shimmer.modules.selection.SelectionBase.update_gw_state": {"tf": 1}, 
"shimmer.modules.vae.VAEEncoder.forward": {"tf": 1}, "shimmer.modules.vae.VAE.encode": {"tf": 1}}, "df": 5, "s": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "l": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "c": {"docs": {}, "df": 0, "t": {"docs": {}, "df": 0, "i": {"docs": {}, "df": 0, "o": {"docs": {}, "df": 0, "n": {"docs": {}, "df": 0, "i": {"docs": {}, "df": 0, "m": {"docs": {}, "df": 0, "p": {"docs": {}, "df": 0, "l": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "m": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "n": {"docs": {}, "df": 0, "t": {"docs": {}, "df": 0, "a": {"docs": {}, "df": 0, "t": {"docs": {}, "df": 0, "i": {"docs": {}, "df": 0, "o": {"docs": {}, "df": 0, "n": {"docs": {"shimmer.modules.selection.SelectionBase.forward": {"tf": 1}}, "df": 1}}}}}}}}}}}}}}}}}}}}}}}}}, "l": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "l": {"docs": {}, "df": 0, "y": {"docs": {"shimmer.modules.selection.RandomSelection.forward": {"tf": 1}}, "df": 1}}}}, "u": {"docs": {}, "df": 0, "r": {"docs": {}, "df": 0, "c": {"docs": {}, "df": 0, "e": {"docs": {"shimmer.modules.losses.cycle_loss": {"tf": 1.7320508075688772}, "shimmer.modules.losses.translation_loss": {"tf": 2}}, "df": 2}}}}}, "p": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "c": {"docs": {}, "df": 0, "i": {"docs": {}, "df": 0, "f": {"docs": {}, "df": 0, "i": {"docs": {}, "df": 0, "c": {"docs": {"shimmer.modules.domain.DomainModule": {"tf": 1}}, "df": 1}}}}}}}, "u": {"docs": {}, "df": 0, "b": {"docs": {}, "df": 0, "c": {"docs": {}, "df": 0, "l": {"docs": {}, "df": 0, "a": {"docs": {}, "df": 0, "s": {"docs": {}, "df": 0, "s": {"docs": {"shimmer.utils.SaveMigrations": {"tf": 1}}, "df": 1, "e": {"docs": {}, "df": 0, "s": {"docs": {"shimmer.modules.gw_module.GWEncoder.forward": {"tf": 1}, "shimmer.modules.gw_module.GWEncoderLinear.forward": {"tf": 1}}, "df": 2}}}}}}}}, "m": {"docs": {"shimmer.modules.gw_module.GWModuleBayesian.fuse": {"tf": 1}, "shimmer.modules.contrastive_loss.info_nce": 
{"tf": 1}, "shimmer.modules.contrastive_loss.contrastive_loss": {"tf": 1}, "shimmer.modules.contrastive_loss.ContrastiveLoss.__init__": {"tf": 1}}, "df": 4, "m": {"docs": {}, "df": 0, "a": {"docs": {}, "df": 0, "t": {"docs": {}, "df": 0, "i": {"docs": {}, "df": 0, "o": {"docs": {}, "df": 0, "n": {"docs": {"shimmer.modules.gw_module.compute_fusion_scores": {"tf": 1}}, "df": 1}}}}}}}, "f": {"docs": {}, "df": 0, "f": {"docs": {}, "df": 0, "i": {"docs": {}, "df": 0, "x": {"docs": {"shimmer.utils.migrate_model": {"tf": 1}}, "df": 1}}}}}, "q": {"docs": {}, "df": 0, "u": {"docs": {}, "df": 0, "a": {"docs": {}, "df": 0, "r": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "d": {"docs": {"shimmer.modules.gw_module.GWModuleBayesian.fuse": {"tf": 1}}, "df": 1}}}}}}}, "b": {"docs": {"shimmer.modules.gw_module.GWModuleBayesian.fuse": {"tf": 1.7320508075688772}}, "df": 1, "e": {"docs": {"shimmer.types.RawDomainGroupT": {"tf": 1.4142135623730951}, "shimmer.types.LatentsDomainGroupT": {"tf": 1}, "shimmer.types.LatentsDomainGroupsT": {"tf": 1}, "shimmer.modules.gw_module.GWDecoder.__init__": {"tf": 1}, "shimmer.modules.gw_module.GWDecoder.n_layers": {"tf": 1}, "shimmer.modules.gw_module.GWEncoder.__init__": {"tf": 1}, "shimmer.modules.gw_module.GWEncoder.forward": {"tf": 1.4142135623730951}, "shimmer.modules.gw_module.GWEncoderLinear.forward": {"tf": 1.4142135623730951}, "shimmer.modules.gw_module.GWModuleBase": {"tf": 1}, "shimmer.modules.selection.KQFixedQSelection.__init__": {"tf": 1}, "shimmer.modules.selection.DynamicQueryAttention.__init__": {"tf": 1}, "shimmer.modules.losses.LossCoefs": {"tf": 1.7320508075688772}, "shimmer.modules.losses.BroadcastLossCoefs": {"tf": 1.7320508075688772}, "shimmer.modules.losses.BroadcastLossCoefs.cycles": {"tf": 1}, "shimmer.modules.losses.BroadcastLossCoefs.translations": {"tf": 1}, "shimmer.dataset.RepeatedDataset": {"tf": 1}, "shimmer.modules.utils.batch_cycles": {"tf": 1}, "shimmer.utils.migrate_model": {"tf": 1.4142135623730951}, 
"shimmer.utils.SaveMigrations.on_save_checkpoint": {"tf": 1}, "shimmer.cli.ckpt_migration.migrate_ckpt": {"tf": 1}}, "df": 20, "t": {"docs": {}, "df": 0, "t": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "r": {"docs": {"shimmer.modules.global_workspace.freeze_domain_modules": {"tf": 1}}, "df": 1}}}, "w": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "n": {"docs": {"shimmer.modules.selection.SelectionBase": {"tf": 1}, "shimmer.modules.selection.FixedSharedSelection": {"tf": 1}, "shimmer.modules.losses.contrastive_loss": {"tf": 1}, "shimmer.modules.losses.contrastive_loss_bayesian": {"tf": 1}}, "df": 4}}}}, "a": {"docs": {"shimmer.modules.vae.VAE.__init__": {"tf": 1.7320508075688772}, "shimmer.modules.vae.VAE.beta": {"tf": 1.4142135623730951}}, "df": 2}}, "f": {"docs": {}, "df": 0, "o": {"docs": {}, "df": 0, "r": {"docs": {}, "df": 0, "e": {"docs": {"shimmer.modules.gw_module.GWDecoder.__init__": {"tf": 1}, "shimmer.modules.gw_module.GWDecoder.n_layers": {"tf": 1}, "shimmer.modules.gw_module.GWEncoder.__init__": {"tf": 1}}, "df": 3}}}}}, "y": {"docs": {"shimmer.types.ModelModeT": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBayesian.__init__": {"tf": 1}, "shimmer.modules.gw_module.GWEncoder.forward": {"tf": 1}, "shimmer.modules.gw_module.GWEncoderLinear.forward": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBayesian.fuse": {"tf": 1.4142135623730951}, "shimmer.modules.selection.SelectionBase.update_gw_state": {"tf": 1}, "shimmer.modules.selection.RandomSelection": {"tf": 1}, "shimmer.modules.losses.demi_cycle_loss": {"tf": 1}, "shimmer.modules.losses.cycle_loss": {"tf": 1}, "shimmer.modules.losses.translation_loss": {"tf": 1}, "shimmer.modules.losses.contrastive_loss": {"tf": 1}, "shimmer.modules.losses.contrastive_loss_bayesian": {"tf": 1}, "shimmer.modules.losses.GWLossesBayesian.__init__": {"tf": 1}}, "df": 13}, "a": {"docs": {}, "df": 0, "s": {"docs": {}, "df": 0, "e": {"docs": 
{"shimmer.modules.global_workspace.GlobalWorkspaceBase": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBayesian": {"tf": 1}, "shimmer.modules.domain.DomainModule": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBase": {"tf": 1}, "shimmer.modules.selection.SelectionBase": {"tf": 1}, "shimmer.modules.losses.GWLossesBase": {"tf": 1}, "shimmer.modules.vae.VAEEncoder": {"tf": 1}, "shimmer.modules.vae.VAEDecoder": {"tf": 1}, "shimmer.utils.SaveMigrations": {"tf": 1}}, "df": 9, "d": {"docs": {"shimmer.modules.gw_module.GWModuleBayesian": {"tf": 1}, "shimmer.modules.selection.RandomSelection.forward": {"tf": 1}, "shimmer.modules.selection.DynamicQueryAttention": {"tf": 1}, "shimmer.modules.losses.contrastive_loss_bayesian": {"tf": 1}, "shimmer.modules.losses.GWLosses": {"tf": 1}}, "df": 5}}}, "t": {"docs": {}, "df": 0, "c": {"docs": {}, "df": 0, "h": {"docs": {"shimmer.modules.global_workspace.GlobalWorkspaceBase.batch_gw_states": {"tf": 1.4142135623730951}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.encode_domains": {"tf": 1.7320508075688772}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.decode_domains": {"tf": 1.7320508075688772}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.generic_step": {"tf": 1.4142135623730951}, "shimmer.modules.global_workspace.GlobalWorkspace2Domains.forward": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspace.forward": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBayesian.forward": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBayesian.get_precision": {"tf": 1.7320508075688772}, "shimmer.modules.gw_module.GWModuleBayesian.fuse": {"tf": 1}, "shimmer.modules.selection.SelectionBase.forward": {"tf": 1}, "shimmer.modules.selection.SingleDomainSelection.forward": {"tf": 1}, "shimmer.modules.selection.FixedSharedSelection.forward": {"tf": 1}, "shimmer.modules.selection.RandomSelection.forward": {"tf": 1}, "shimmer.modules.utils.batch_demi_cycles": {"tf": 1.4142135623730951}, 
"shimmer.modules.utils.batch_cycles": {"tf": 1.4142135623730951}, "shimmer.modules.utils.batch_translations": {"tf": 1.4142135623730951}, "shimmer.utils.groups_batch_size": {"tf": 2}, "shimmer.utils.groups_device": {"tf": 2}}, "df": 18}}}, "y": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "s": {"docs": {}, "df": 0, "i": {"docs": {}, "df": 0, "a": {"docs": {}, "df": 0, "n": {"docs": {"shimmer.modules.global_workspace.GlobalWorkspaceBayesian": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBayesian": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBayesian.fuse": {"tf": 1}, "shimmer.modules.losses.contrastive_loss_bayesian": {"tf": 1}}, "df": 4}}}}}}, "c": {"docs": {}, "df": 0, "k": {"docs": {"shimmer.modules.domain.DomainModule.decode": {"tf": 1}}, "df": 1}}}, "u": {"docs": {}, "df": 0, "i": {"docs": {}, "df": 0, "l": {"docs": {}, "df": 0, "d": {"docs": {"shimmer.modules.global_workspace.GlobalWorkspaceBase": {"tf": 1}, "shimmer.utils.SaveMigrations": {"tf": 1}}, "df": 2}}}, "t": {"docs": {"shimmer.modules.global_workspace.freeze_domain_modules": {"tf": 1}, "shimmer.modules.gw_module.GWEncoder": {"tf": 1}, "shimmer.modules.selection.SingleDomainSelection": {"tf": 1}, "shimmer.modules.selection.RandomSelection.forward": {"tf": 1}, "shimmer.modules.losses.LossCoefs": {"tf": 1}, "shimmer.modules.losses.BroadcastLossCoefs": {"tf": 1}}, "df": 6}}, "o": {"docs": {}, "df": 0, "o": {"docs": {}, "df": 0, "l": {"docs": {"shimmer.modules.global_workspace.GlobalWorkspace2Domains.__init__": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspace.__init__": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBayesian.__init__": {"tf": 1.4142135623730951}, "shimmer.modules.losses.GWLossesBayesian.__init__": {"tf": 1}, "shimmer.dataset.RepeatedDataset.__init__": {"tf": 1}}, "df": 5}}}, "r": {"docs": {}, "df": 0, "o": {"docs": {}, "df": 0, "a": {"docs": {}, "df": 0, "d": {"docs": {}, "df": 0, "c": {"docs": {}, "df": 0, "a": {"docs": {}, "df": 0, "s": {"docs": {}, "df": 
0, "t": {"docs": {"shimmer.modules.global_workspace.GlobalWorkspace": {"tf": 1}, "shimmer.modules.domain.DomainModule.compute_broadcast_loss": {"tf": 1.4142135623730951}, "shimmer.modules.losses.broadcast_loss": {"tf": 1}}, "df": 3, "l": {"docs": {}, "df": 0, "o": {"docs": {}, "df": 0, "s": {"docs": {}, "df": 0, "s": {"docs": {}, "df": 0, "c": {"docs": {}, "df": 0, "o": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "f": {"docs": {}, "df": 0, "s": {"docs": {"shimmer.modules.global_workspace.GlobalWorkspace.__init__": {"tf": 1}, "shimmer.modules.losses.GWLosses.__init__": {"tf": 1}, "shimmer.modules.losses.GWLossesBayesian.__init__": {"tf": 1}}, "df": 3}}}}}}}}}}}}}}}}}, "i": {"docs": {}, "df": 0, "n": {"docs": {}, "df": 0, "a": {"docs": {}, "df": 0, "r": {"docs": {}, "df": 0, "y": {"docs": {"shimmer.modules.selection.RandomSelection": {"tf": 1}}, "df": 1}}}}}}, "e": {"docs": {}, "df": 0, "x": {"docs": {}, "df": 0, "a": {"docs": {}, "df": 0, "m": {"docs": {}, "df": 0, "p": {"docs": {}, "df": 0, "l": {"docs": {}, "df": 0, "e": {"docs": {"shimmer.types.RawDomainGroupT": {"tf": 1}, "shimmer.types.RawDomainGroupDT": {"tf": 1}, "shimmer.types.LatentsDomainGroupT": {"tf": 1}, "shimmer.types.LatentsDomainGroupDT": {"tf": 1}, "shimmer.types.RawDomainGroupsT": {"tf": 1}, "shimmer.types.RawDomainGroupsDT": {"tf": 1}, "shimmer.types.LatentsDomainGroupsT": {"tf": 1}, "shimmer.types.LatentsDomainGroupsDT": {"tf": 1}, "shimmer.modules.selection.SelectionBase.forward": {"tf": 1}, "shimmer.modules.selection.SingleDomainSelection": {"tf": 1}, "shimmer.modules.selection.FixedSharedSelection": {"tf": 1}}, "df": 11}}}}, "c": {"docs": {}, "df": 0, "l": {"docs": {}, "df": 0, "t": {"docs": {}, "df": 0, "y": {"docs": {"shimmer.dataset.RepeatedDataset": {"tf": 1}}, "df": 1}}}}}, "c": {"docs": {}, "df": 0, "p": {"docs": {}, "df": 0, "l": {"docs": {}, "df": 0, "i": {"docs": {}, "df": 0, "c": {"docs": {}, "df": 0, "i": {"docs": {}, "df": 0, "t": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 
0, "l": {"docs": {}, "df": 0, "y": {"docs": {"shimmer.modules.losses.LossCoefs": {"tf": 1}, "shimmer.modules.losses.BroadcastLossCoefs": {"tf": 1}}, "df": 2}}}}}}}}}, "l": {"docs": {}, "df": 0, "u": {"docs": {}, "df": 0, "d": {"docs": {}, "df": 0, "i": {"docs": {}, "df": 0, "n": {"docs": {}, "df": 0, "g": {"docs": {"shimmer.modules.losses.generate_partitions": {"tf": 1.4142135623730951}}, "df": 1}}}}}}}}, "a": {"docs": {}, "df": 0, "c": {"docs": {}, "df": 0, "h": {"docs": {"shimmer.types.RawDomainGroupsT": {"tf": 1}, "shimmer.types.RawDomainGroupsDT": {"tf": 1}, "shimmer.types.LatentsDomainGroupsT": {"tf": 1}, "shimmer.types.LatentsDomainGroupsDT": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.fuse": {"tf": 1.4142135623730951}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.batch_gw_states": {"tf": 1}, "shimmer.modules.global_workspace.GWPredictions.demi_cycles": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspace2Domains.__init__": {"tf": 1.4142135623730951}, "shimmer.modules.global_workspace.GlobalWorkspace.__init__": {"tf": 1.4142135623730951}, "shimmer.modules.global_workspace.GlobalWorkspaceBayesian.__init__": {"tf": 1.4142135623730951}, "shimmer.modules.global_workspace.pretrained_global_workspace": {"tf": 1.4142135623730951}, "shimmer.modules.gw_module.GWModule.__init__": {"tf": 1.4142135623730951}, "shimmer.modules.gw_module.GWModuleBayesian.__init__": {"tf": 1.4142135623730951}, "shimmer.modules.selection.SelectionBase.forward": {"tf": 1.4142135623730951}, "shimmer.modules.selection.SingleDomainSelection": {"tf": 1}, "shimmer.modules.selection.SingleDomainSelection.forward": {"tf": 1}, "shimmer.modules.selection.FixedSharedSelection": {"tf": 1.4142135623730951}, "shimmer.modules.selection.FixedSharedSelection.forward": {"tf": 1}, "shimmer.modules.selection.KQFixedQSelection.forward": {"tf": 1}, "shimmer.modules.selection.RandomSelection.forward": {"tf": 1.7320508075688772}, 
"shimmer.modules.selection.DynamicQueryAttention.fuse_weighted_encodings": {"tf": 1}, "shimmer.modules.selection.DynamicQueryAttention.forward": {"tf": 1}, "shimmer.modules.utils.batch_demi_cycles": {"tf": 1}, "shimmer.modules.utils.batch_cycles": {"tf": 1.4142135623730951}, "shimmer.modules.utils.batch_translations": {"tf": 1}, "shimmer.cli.ckpt_migration.migrate_ckpt": {"tf": 1}}, "df": 26}}}, "n": {"docs": {}, "df": 0, "c": {"docs": {}, "df": 0, "o": {"docs": {}, "df": 0, "d": {"docs": {}, "df": 0, "e": {"docs": {"shimmer.modules.global_workspace.GlobalWorkspaceBase.encode_and_fuse": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.encode": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.encode_domain": {"tf": 1.7320508075688772}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.encode_domains": {"tf": 1.4142135623730951}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.decode_domain": {"tf": 1.4142135623730951}, "shimmer.modules.global_workspace.GlobalWorkspace2Domains.__init__": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspace.__init__": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBayesian.__init__": {"tf": 1}, "shimmer.modules.global_workspace.pretrained_global_workspace": {"tf": 1}, "shimmer.modules.domain.DomainModule.encode": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBase.fuse": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBase.encode": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBase.encode_and_fuse": {"tf": 1.7320508075688772}, "shimmer.modules.gw_module.GWModule.fuse": {"tf": 1}, "shimmer.modules.gw_module.GWModule.encode": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBayesian.fuse": {"tf": 1}, "shimmer.modules.losses.BroadcastLossCoefs.fused": {"tf": 1}, "shimmer.modules.vae.VAEEncoder.forward": {"tf": 1}, "shimmer.modules.vae.VAE.__init__": {"tf": 1}, "shimmer.modules.vae.VAE.encode": {"tf": 1}, "shimmer.modules.vae.VAE.forward": {"tf": 1}, 
"shimmer.modules.utils.translation": {"tf": 1}}, "df": 22, "s": {"docs": {"shimmer.modules.global_workspace.GlobalWorkspaceBase.encode_domain": {"tf": 1}, "shimmer.modules.gw_module.GWModule.__init__": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBayesian.__init__": {"tf": 1}}, "df": 3}, "r": {"docs": {"shimmer.modules.gw_module.GWEncoder": {"tf": 1}, "shimmer.modules.gw_module.GWEncoder.__init__": {"tf": 1}, "shimmer.modules.gw_module.GWEncoderLinear": {"tf": 1}, "shimmer.modules.vae.VAEEncoder": {"tf": 1}, "shimmer.modules.vae.VAE.__init__": {"tf": 1}, "shimmer.modules.vae.VAE.encoder": {"tf": 1}}, "df": 6, "s": {"docs": {"shimmer.modules.global_workspace.GlobalWorkspace2Domains.__init__": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspace.__init__": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBayesian.__init__": {"tf": 1}, "shimmer.modules.global_workspace.pretrained_global_workspace": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBase": {"tf": 1}, "shimmer.modules.gw_module.GWModule.__init__": {"tf": 1}, "shimmer.modules.gw_module.GWModule.gw_encoders": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBayesian.__init__": {"tf": 1}}, "df": 8}}, "d": {"docs": {"shimmer.modules.vae.VAE.forward": {"tf": 1}}, "df": 1}}, "i": {"docs": {}, "df": 0, "n": {"docs": {}, "df": 0, "g": {"docs": {"shimmer.modules.gw_module.GWModuleBase": {"tf": 1}}, "df": 1, "s": {"docs": {"shimmer.modules.selection.KQFixedQSelection.forward": {"tf": 1.4142135623730951}, "shimmer.modules.selection.DynamicQueryAttention.fuse_weighted_encodings": {"tf": 1.4142135623730951}, "shimmer.modules.selection.DynamicQueryAttention.forward": {"tf": 1.4142135623730951}}, "df": 3}}}}}}}, "d": {"docs": {"shimmer.modules.gw_module.GWEncoder": {"tf": 1}}, "df": 1}}, "q": {"docs": {}, "df": 0, "u": {"docs": {}, "df": 0, "i": {"docs": {}, "df": 0, "v": {"docs": {}, "df": 0, "a": {"docs": {}, "df": 0, "l": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "n": {"docs": {}, "df": 0, "t": 
{"docs": {"shimmer.modules.global_workspace.GlobalWorkspaceBase.encode_domain": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.decode_domain": {"tf": 1}}, "df": 2}}}}}}}, "a": {"docs": {}, "df": 0, "l": {"docs": {}, "df": 0, "l": {"docs": {}, "df": 0, "y": {"docs": {"shimmer.modules.selection.FixedSharedSelection": {"tf": 1}}, "df": 1}}}}}}, "v": {"docs": {}, "df": 0, "a": {"docs": {}, "df": 0, "l": {"docs": {"shimmer.modules.global_workspace.freeze_domain_modules": {"tf": 1}}, "df": 1}}, "e": {"docs": {}, "df": 0, "r": {"docs": {}, "df": 0, "y": {"docs": {"shimmer.modules.gw_module.GWEncoder.forward": {"tf": 1}, "shimmer.modules.gw_module.GWEncoderLinear.forward": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBayesian.precisions": {"tf": 1}}, "df": 3}}}}, "p": {"docs": {}, "df": 0, "s": {"docs": {"shimmer.modules.gw_module.compute_fusion_scores": {"tf": 1}}, "df": 1}}, "r": {"docs": {}, "df": 0, "r": {"docs": {}, "df": 0, "o": {"docs": {}, "df": 0, "r": {"docs": {"shimmer.modules.selection.SelectionBase.update_gw_state": {"tf": 1}}, "df": 1}}}}, "l": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "m": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "n": {"docs": {}, "df": 0, "t": {"docs": {}, "df": 0, "s": {"docs": {"shimmer.modules.losses.generate_partitions": {"tf": 1}}, "df": 1}}}}}}, "s": {"docs": {}, "df": 0, "e": {"docs": {"shimmer.utils.SaveMigrations.on_save_checkpoint": {"tf": 1}}, "df": 1}}}}, "g": {"docs": {}, "df": 0, "r": {"docs": {}, "df": 0, "o": {"docs": {}, "df": 0, "u": {"docs": {}, "df": 0, "p": {"docs": {"shimmer.types.RawDomainGroupT": {"tf": 1}, "shimmer.types.LatentsDomainGroupT": {"tf": 1}, "shimmer.types.RawDomainGroupsT": {"tf": 1.4142135623730951}, "shimmer.types.RawDomainGroupsDT": {"tf": 1.4142135623730951}, "shimmer.types.LatentsDomainGroupsT": {"tf": 1.7320508075688772}, "shimmer.types.LatentsDomainGroupsDT": {"tf": 1.4142135623730951}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.encode_and_fuse": 
{"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.encode": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.fuse": {"tf": 1.7320508075688772}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.decode": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBase.fuse": {"tf": 1}, "shimmer.modules.gw_module.GWModule.fuse": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBayesian.fuse": {"tf": 1}, "shimmer.modules.selection.SelectionBase.forward": {"tf": 1.4142135623730951}, "shimmer.modules.selection.SingleDomainSelection": {"tf": 1.4142135623730951}, "shimmer.modules.selection.KQFixedQSelection.forward": {"tf": 1.7320508075688772}, "shimmer.modules.selection.RandomSelection.forward": {"tf": 1.4142135623730951}, "shimmer.modules.selection.DynamicQueryAttention.fuse_weighted_encodings": {"tf": 1}, "shimmer.modules.selection.DynamicQueryAttention.forward": {"tf": 1.7320508075688772}, "shimmer.modules.utils.translation": {"tf": 1}, "shimmer.modules.utils.cycle": {"tf": 1.7320508075688772}}, "df": 21, "s": {"docs": {"shimmer.modules.global_workspace.GWPredictionsBase.states": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.batch_gw_states": {"tf": 1.4142135623730951}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.encode_domains": {"tf": 1.4142135623730951}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.decode_domains": {"tf": 1.4142135623730951}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.generic_step": {"tf": 1}, "shimmer.modules.global_workspace.GWPredictions.demi_cycles": {"tf": 1}, "shimmer.modules.global_workspace.GWPredictions.cycles": {"tf": 1}, "shimmer.modules.global_workspace.GWPredictions.translations": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspace2Domains.forward": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspace.forward": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBayesian.forward": {"tf": 1}, "shimmer.modules.selection.SingleDomainSelection": 
{"tf": 1}, "shimmer.modules.losses.GWLossesBase.step": {"tf": 1}, "shimmer.modules.losses.demi_cycle_loss": {"tf": 1}, "shimmer.modules.losses.cycle_loss": {"tf": 1}, "shimmer.modules.losses.translation_loss": {"tf": 1}, "shimmer.modules.losses.contrastive_loss": {"tf": 1}, "shimmer.modules.losses.contrastive_loss_bayesian": {"tf": 1}, "shimmer.modules.losses.GWLosses2Domains.demi_cycle_loss": {"tf": 1}, "shimmer.modules.losses.GWLosses2Domains.cycle_loss": {"tf": 1}, "shimmer.modules.losses.GWLosses2Domains.translation_loss": {"tf": 1}, "shimmer.modules.losses.GWLosses2Domains.contrastive_loss": {"tf": 1}, "shimmer.modules.losses.GWLosses2Domains.step": {"tf": 1}, "shimmer.modules.losses.GWLossesBayesian.contrastive_loss": {"tf": 1}, "shimmer.modules.utils.batch_demi_cycles": {"tf": 1.4142135623730951}, "shimmer.modules.utils.batch_cycles": {"tf": 1.4142135623730951}, "shimmer.modules.utils.batch_translations": {"tf": 1.4142135623730951}, "shimmer.utils.groups_batch_size": {"tf": 1}, "shimmer.utils.groups_device": {"tf": 1}}, "df": 29}}}}}, "e": {"docs": {}, "df": 0, "n": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "r": {"docs": {}, "df": 0, "i": {"docs": {}, "df": 0, "c": {"docs": {"shimmer.types.RawDomainGroupT": {"tf": 1}, "shimmer.types.LatentsDomainGroupT": {"tf": 1}, "shimmer.types.LatentsDomainGroupsT": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.generic_step": {"tf": 1}, "shimmer.modules.domain.DomainModule.compute_loss": {"tf": 1}, "shimmer.modules.domain.DomainModule.compute_dcy_loss": {"tf": 1}, "shimmer.modules.domain.DomainModule.compute_cy_loss": {"tf": 1}, "shimmer.modules.domain.DomainModule.compute_tr_loss": {"tf": 1}, "shimmer.modules.domain.DomainModule.compute_broadcast_loss": {"tf": 1}}, "df": 9}}, "a": {"docs": {}, "df": 0, "t": {"docs": {}, "df": 0, "i": {"docs": {}, "df": 0, "n": {"docs": {}, "df": 0, "g": {"docs": {"shimmer.modules.selection.RandomSelection": {"tf": 1}}, "df": 1}}}, "e": {"docs": 
{"shimmer.modules.selection.RandomSelection.forward": {"tf": 1}, "shimmer.modules.losses.generate_partitions": {"tf": 1}}, "df": 2, "s": {"docs": {"shimmer.modules.losses.generate_partitions": {"tf": 1}}, "df": 1}}}}}}}, "t": {"docs": {"shimmer.modules.gw_module.GWModuleBayesian.get_precision": {"tf": 1}, "shimmer.utils.groups_batch_size": {"tf": 1}, "shimmer.utils.groups_device": {"tf": 1}}, "df": 3}}, "t": {"docs": {"shimmer.types.RawDomainGroupDT": {"tf": 1}, "shimmer.types.LatentsDomainGroupDT": {"tf": 1}, "shimmer.types.RawDomainGroupsT": {"tf": 1}, "shimmer.types.RawDomainGroupsDT": {"tf": 1}, "shimmer.types.LatentsDomainGroupsDT": {"tf": 1}, "shimmer.modules.selection.SelectionBase.forward": {"tf": 1.7320508075688772}}, "df": 6}, "i": {"docs": {}, "df": 0, "v": {"docs": {}, "df": 0, "e": {"docs": {"shimmer.utils.SaveMigrations.on_save_checkpoint": {"tf": 1}}, "df": 1, "n": {"docs": {"shimmer.modules.global_workspace.GWPredictionsBase": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.decode": {"tf": 1}, "shimmer.modules.global_workspace.GWPredictions": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBase.decode": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBayesian.get_precision": {"tf": 1}, "shimmer.modules.losses.GWLosses.contrastive_loss": {"tf": 1}, "shimmer.utils.migrate_model": {"tf": 1.4142135623730951}, "shimmer.cli.ckpt_migration.migrate_ckpt": {"tf": 1}}, "df": 8}, "s": {"docs": {"shimmer.modules.selection.FixedSharedSelection": {"tf": 1}}, "df": 1}}}}, "l": {"docs": {}, "df": 0, "o": {"docs": {}, "df": 0, "b": {"docs": {}, "df": 0, "a": {"docs": {}, "df": 0, "l": {"docs": {"shimmer.modules.global_workspace.GlobalWorkspaceBase": {"tf": 1.4142135623730951}, "shimmer.modules.global_workspace.GlobalWorkspace2Domains": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspace2Domains.__init__": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspace": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspace.__init__": {"tf": 
1}, "shimmer.modules.global_workspace.GlobalWorkspaceBayesian": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBayesian.__init__": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBayesian.fuse": {"tf": 1}, "shimmer.modules.losses.GWLossesBase": {"tf": 1}, "shimmer.modules.losses.GWLosses.__init__": {"tf": 1.4142135623730951}}, "df": 10, "w": {"docs": {}, "df": 0, "o": {"docs": {}, "df": 0, "r": {"docs": {}, "df": 0, "k": {"docs": {}, "df": 0, "s": {"docs": {}, "df": 0, "p": {"docs": {}, "df": 0, "a": {"docs": {}, "df": 0, "c": {"docs": {}, "df": 0, "e": {"docs": {"shimmer.modules.global_workspace.pretrained_global_workspace": {"tf": 2.23606797749979}, "shimmer.modules.losses.GWLosses2Domains.__init__": {"tf": 1}}, "df": 2, "b": {"docs": {}, "df": 0, "a": {"docs": {}, "df": 0, "s": {"docs": {}, "df": 0, "e": {"docs": {"shimmer.modules.global_workspace.GWPredictionsBase": {"tf": 1}, "shimmer.modules.global_workspace.GWPredictions": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspace2Domains": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspace": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBayesian": {"tf": 1}, "shimmer.modules.global_workspace.pretrained_global_workspace": {"tf": 1}}, "df": 6}}, "y": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "s": {"docs": {}, "df": 0, "i": {"docs": {}, "df": 0, "a": {"docs": {}, "df": 0, "n": {"docs": {"shimmer.modules.losses.GWLossesBayesian.__init__": {"tf": 1}, "shimmer.modules.contrastive_loss.ContrastiveLossBayesianType": {"tf": 1}}, "df": 2}}}}}}}}}}}}}}}}}}}}}}, "w": {"docs": {"shimmer.modules.global_workspace.GWPredictionsBase.states": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.loss_mod": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.workspace_dim": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.encode_and_fuse": {"tf": 1.4142135623730951}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.encode": {"tf": 
1.4142135623730951}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.fuse": {"tf": 1.4142135623730951}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.decode": {"tf": 1.4142135623730951}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.batch_gw_states": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspace2Domains.__init__": {"tf": 2.449489742783178}, "shimmer.modules.global_workspace.GlobalWorkspace.__init__": {"tf": 2.449489742783178}, "shimmer.modules.global_workspace.GlobalWorkspaceBayesian.__init__": {"tf": 2.449489742783178}, "shimmer.modules.global_workspace.pretrained_global_workspace": {"tf": 2.449489742783178}, "shimmer.modules.domain.DomainModule": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBase": {"tf": 1.7320508075688772}, "shimmer.modules.gw_module.GWModuleBase.__init__": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBase.workspace_dim": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBase.encode": {"tf": 1.4142135623730951}, "shimmer.modules.gw_module.GWModuleBase.encode_and_fuse": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBase.decode": {"tf": 1.4142135623730951}, "shimmer.modules.gw_module.GWModule": {"tf": 1}, "shimmer.modules.gw_module.GWModule.__init__": {"tf": 2.23606797749979}, "shimmer.modules.gw_module.GWModule.encode": {"tf": 1}, "shimmer.modules.gw_module.GWModule.decode": {"tf": 1.7320508075688772}, "shimmer.modules.gw_module.GWModuleBayesian.__init__": {"tf": 2.23606797749979}, "shimmer.modules.selection.SelectionBase.update_gw_state": {"tf": 1.7320508075688772}, "shimmer.modules.selection.SingleDomainSelection.forward": {"tf": 1.4142135623730951}, "shimmer.modules.selection.FixedSharedSelection.forward": {"tf": 1.4142135623730951}, "shimmer.modules.selection.KQFixedQSelection": {"tf": 1}, "shimmer.modules.selection.DynamicQueryAttention": {"tf": 1.4142135623730951}, "shimmer.modules.losses.GWLossesBase": {"tf": 1.4142135623730951}, "shimmer.modules.losses.demi_cycle_loss": {"tf": 1.4142135623730951}, 
"shimmer.modules.losses.cycle_loss": {"tf": 1}, "shimmer.modules.losses.translation_loss": {"tf": 1}, "shimmer.modules.losses.contrastive_loss": {"tf": 1}, "shimmer.modules.losses.contrastive_loss_bayesian": {"tf": 1}, "shimmer.modules.losses.GWLosses2Domains.__init__": {"tf": 1}, "shimmer.modules.losses.broadcast_loss": {"tf": 1.4142135623730951}, "shimmer.modules.losses.GWLosses.__init__": {"tf": 1}, "shimmer.modules.losses.GWLossesBayesian.__init__": {"tf": 1}, "shimmer.modules.losses.GWLossesBayesian.domain_mods": {"tf": 1}, "shimmer.modules.utils.translation": {"tf": 1}, "shimmer.modules.utils.cycle": {"tf": 1.7320508075688772}, "shimmer.modules.utils.batch_demi_cycles": {"tf": 1}, "shimmer.modules.utils.batch_cycles": {"tf": 1}, "shimmer.modules.utils.batch_translations": {"tf": 1}}, "df": 45, "m": {"docs": {}, "df": 0, "o": {"docs": {}, "df": 0, "d": {"docs": {}, "df": 0, "u": {"docs": {}, "df": 0, "l": {"docs": {}, "df": 0, "e": {"docs": {"shimmer.modules.gw_module.GWModuleBase": {"tf": 1.7320508075688772}, "shimmer.modules.gw_module.GWModuleBase.__init__": {"tf": 1}, "shimmer.modules.gw_module.GWModule.__init__": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBayesian": {"tf": 1}, "shimmer.modules.losses.demi_cycle_loss": {"tf": 1}, "shimmer.modules.losses.cycle_loss": {"tf": 1}, "shimmer.modules.losses.translation_loss": {"tf": 1}, "shimmer.modules.losses.contrastive_loss": {"tf": 1}, "shimmer.modules.losses.contrastive_loss_bayesian": {"tf": 1}, "shimmer.modules.losses.GWLosses2Domains": {"tf": 1}, "shimmer.modules.losses.GWLosses2Domains.__init__": {"tf": 1.4142135623730951}, "shimmer.modules.losses.broadcast_loss": {"tf": 1}, "shimmer.modules.losses.GWLosses.__init__": {"tf": 1}, "shimmer.modules.losses.GWLossesBayesian.__init__": {"tf": 1}, "shimmer.modules.losses.GWLossesBayesian.gw_mod": {"tf": 1}, "shimmer.modules.utils.translation": {"tf": 1}, "shimmer.modules.utils.cycle": {"tf": 1}, "shimmer.modules.utils.batch_cycles": {"tf": 1}, 
"shimmer.modules.utils.batch_translations": {"tf": 1}}, "df": 19, "b": {"docs": {}, "df": 0, "a": {"docs": {}, "df": 0, "s": {"docs": {}, "df": 0, "e": {"docs": {"shimmer.modules.global_workspace.GlobalWorkspaceBase.gw_mod": {"tf": 1}, "shimmer.modules.gw_module.GWModule": {"tf": 1}, "shimmer.modules.losses.demi_cycle_loss": {"tf": 1}, "shimmer.modules.losses.cycle_loss": {"tf": 1}, "shimmer.modules.losses.translation_loss": {"tf": 1}, "shimmer.modules.losses.contrastive_loss": {"tf": 1}, "shimmer.modules.losses.broadcast_loss": {"tf": 1}, "shimmer.modules.utils.translation": {"tf": 1}, "shimmer.modules.utils.cycle": {"tf": 1}, "shimmer.modules.utils.batch_demi_cycles": {"tf": 1.4142135623730951}, "shimmer.modules.utils.batch_cycles": {"tf": 1}, "shimmer.modules.utils.batch_translations": {"tf": 1}}, "df": 12}}, "y": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "s": {"docs": {}, "df": 0, "i": {"docs": {}, "df": 0, "a": {"docs": {}, "df": 0, "n": {"docs": {"shimmer.modules.gw_module.GWModuleBayesian.__init__": {"tf": 1}, "shimmer.modules.losses.contrastive_loss_bayesian": {"tf": 1}, "shimmer.modules.losses.GWLossesBayesian": {"tf": 1}, "shimmer.modules.losses.GWLossesBayesian.__init__": {"tf": 1}}, "df": 4}}}}}}}}, "s": {"docs": {"shimmer.modules.gw_module.GWDecoder": {"tf": 1}, "shimmer.modules.gw_module.GWEncoder": {"tf": 1}, "shimmer.modules.gw_module.GWEncoderLinear": {"tf": 1}}, "df": 3}}}}}}}, "p": {"docs": {}, "df": 0, "r": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "d": {"docs": {}, "df": 0, "i": {"docs": {}, "df": 0, "c": {"docs": {}, "df": 0, "t": {"docs": {}, "df": 0, "i": {"docs": {}, "df": 0, "o": {"docs": {}, "df": 0, "n": {"docs": {}, "df": 0, "s": {"docs": {"shimmer.modules.global_workspace.GlobalWorkspace2Domains.forward": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspace.forward": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBayesian.forward": {"tf": 1}}, "df": 3}}}}}}}}}}}, "l": {"docs": {}, "df": 0, "o": 
{"docs": {}, "df": 0, "s": {"docs": {}, "df": 0, "s": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "s": {"docs": {"shimmer.modules.losses.LossCoefs": {"tf": 1}, "shimmer.modules.losses.GWLosses2Domains.step": {"tf": 2}}, "df": 2, "b": {"docs": {}, "df": 0, "a": {"docs": {}, "df": 0, "s": {"docs": {}, "df": 0, "e": {"docs": {"shimmer.modules.losses.GWLosses2Domains": {"tf": 1}, "shimmer.modules.losses.GWLosses": {"tf": 1}, "shimmer.modules.losses.GWLossesBayesian": {"tf": 1}}, "df": 3}}}}, "f": {"docs": {}, "df": 0, "u": {"docs": {}, "df": 0, "s": {"docs": {}, "df": 0, "i": {"docs": {}, "df": 0, "o": {"docs": {}, "df": 0, "n": {"docs": {"shimmer.modules.losses.BroadcastLossCoefs": {"tf": 1}}, "df": 1}}}}}}}}}}}}}, "a": {"docs": {}, "df": 0, "u": {"docs": {}, "df": 0, "s": {"docs": {}, "df": 0, "s": {"docs": {}, "df": 0, "i": {"docs": {}, "df": 0, "a": {"docs": {}, "df": 0, "n": {"docs": {"shimmer.modules.vae.gaussian_nll": {"tf": 1.4142135623730951}}, "df": 1}}}}}}}}, "x": {"docs": {"shimmer.types.RawDomainGroupT": {"tf": 1.4142135623730951}, "shimmer.types.LatentsDomainGroupT": {"tf": 1.4142135623730951}, "shimmer.types.LatentsDomainGroupsT": {"tf": 1.4142135623730951}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.encode_and_fuse": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.encode": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.fuse": {"tf": 1}, "shimmer.modules.domain.DomainModule.encode": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBase.fuse": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBase.encode": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBase.encode_and_fuse": {"tf": 1}, "shimmer.modules.gw_module.GWModule.fuse": {"tf": 1}, "shimmer.modules.gw_module.GWModule.encode": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBayesian.get_precision": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBayesian.fuse": {"tf": 1}, "shimmer.modules.contrastive_loss.info_nce": {"tf": 1}, 
"shimmer.modules.contrastive_loss.contrastive_loss": {"tf": 1}, "shimmer.modules.contrastive_loss.ContrastiveLoss.forward": {"tf": 1}, "shimmer.modules.vae.gaussian_nll": {"tf": 1}, "shimmer.modules.vae.VAEEncoder.forward": {"tf": 1}, "shimmer.modules.vae.VAEDecoder.forward": {"tf": 1}, "shimmer.modules.vae.VAE.encode": {"tf": 1}, "shimmer.modules.vae.VAE.forward": {"tf": 1.4142135623730951}, "shimmer.modules.utils.translation": {"tf": 1}, "shimmer.modules.utils.cycle": {"tf": 1}}, "df": 24, "i": {"docs": {"shimmer.modules.gw_module.GWModuleBayesian.fuse": {"tf": 1.7320508075688772}}, "df": 1}}, "q": {"docs": {}, "df": 0, "u": {"docs": {}, "df": 0, "o": {"docs": {}, "df": 0, "t": {"docs": {"shimmer.types.RawDomainGroupT": {"tf": 2.8284271247461903}, "shimmer.types.RawDomainGroupDT": {"tf": 2.8284271247461903}, "shimmer.types.LatentsDomainGroupT": {"tf": 2}, "shimmer.types.LatentsDomainGroupDT": {"tf": 2}, "shimmer.types.RawDomainGroupsT": {"tf": 4.898979485566356}, "shimmer.types.RawDomainGroupsDT": {"tf": 4.898979485566356}, "shimmer.types.LatentsDomainGroupsT": {"tf": 4}, "shimmer.types.LatentsDomainGroupsDT": {"tf": 4}, "shimmer.modules.selection.SelectionBase.forward": {"tf": 2.8284271247461903}}, "df": 9}}, "a": {"docs": {}, "df": 0, "d": {"docs": {}, "df": 0, "r": {"docs": {}, "df": 0, "a": {"docs": {}, "df": 0, "t": {"docs": {}, "df": 0, "u": {"docs": {}, "df": 0, "r": {"docs": {}, "df": 0, "e": {"docs": {"shimmer.modules.gw_module.compute_fusion_scores": {"tf": 1}}, "df": 1}}}}}}}}, "e": {"docs": {}, "df": 0, "r": {"docs": {}, "df": 0, "y": {"docs": {"shimmer.modules.selection.KQFixedQSelection": {"tf": 1}, "shimmer.modules.selection.KQFixedQSelection.__init__": {"tf": 1}, "shimmer.modules.selection.KQFixedQSelection.forward": {"tf": 1.4142135623730951}, "shimmer.modules.selection.DynamicQueryAttention": {"tf": 1.4142135623730951}, "shimmer.modules.selection.DynamicQueryAttention.__init__": {"tf": 1}, 
"shimmer.modules.selection.DynamicQueryAttention.forward": {"tf": 1.4142135623730951}}, "df": 6}, "i": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "s": {"docs": {"shimmer.modules.selection.KQFixedQSelection.forward": {"tf": 1}, "shimmer.modules.selection.DynamicQueryAttention.forward": {"tf": 1}}, "df": 2}}}}}}}, "p": {"docs": {}, "df": 0, "i": {"docs": {}, "df": 0, "l": {"docs": {"shimmer.types.RawDomainGroupT": {"tf": 1}, "shimmer.types.RawDomainGroupDT": {"tf": 1}, "shimmer.types.RawDomainGroupsT": {"tf": 1.4142135623730951}, "shimmer.types.RawDomainGroupsDT": {"tf": 1.4142135623730951}}, "df": 4}, "c": {"docs": {}, "df": 0, "t": {"docs": {}, "df": 0, "u": {"docs": {}, "df": 0, "r": {"docs": {}, "df": 0, "e": {"docs": {"shimmer.types.RawDomainGroupT": {"tf": 1}, "shimmer.types.RawDomainGroupDT": {"tf": 1}, "shimmer.types.RawDomainGroupsT": {"tf": 1.4142135623730951}, "shimmer.types.RawDomainGroupsDT": {"tf": 1.4142135623730951}}, "df": 4}}}}}}, "a": {"docs": {}, "df": 0, "t": {"docs": {}, "df": 0, "h": {"docs": {"shimmer.modules.global_workspace.pretrained_global_workspace": {"tf": 1.7320508075688772}, "shimmer.utils.migrate_model": {"tf": 1.4142135623730951}, "shimmer.cli.ckpt_migration.migrate_ckpt": {"tf": 1.7320508075688772}}, "df": 3, "/": {"docs": {}, "df": 0, "t": {"docs": {}, "df": 0, "o": {"docs": {}, "df": 0, "/": {"docs": {}, "df": 0, "d": {"docs": {}, "df": 0, "o": {"docs": {}, "df": 0, "g": {"docs": {}, "df": 0, "/": {"docs": {}, "df": 0, "p": {"docs": {}, "df": 0, "i": {"docs": {}, "df": 0, "c": {"docs": {}, "df": 0, "t": {"docs": {}, "df": 0, "u": {"docs": {}, "df": 0, "r": {"docs": {}, "df": 0, "e": {"docs": {"shimmer.types.RawDomainGroupT": {"tf": 1}, "shimmer.types.RawDomainGroupDT": {"tf": 1}, "shimmer.types.RawDomainGroupsT": {"tf": 1}, "shimmer.types.RawDomainGroupsDT": {"tf": 1}}, "df": 4}}}}}}}}}}}, "c": {"docs": {}, "df": 0, "a": {"docs": {}, "df": 0, "t": {"docs": {}, "df": 0, "/": {"docs": {}, "df": 0, "p": {"docs": {}, "df": 0, 
"i": {"docs": {}, "df": 0, "c": {"docs": {}, "df": 0, "t": {"docs": {}, "df": 0, "u": {"docs": {}, "df": 0, "r": {"docs": {}, "df": 0, "e": {"docs": {"shimmer.types.RawDomainGroupsT": {"tf": 1}, "shimmer.types.RawDomainGroupsDT": {"tf": 1}}, "df": 2}}}}}}}}}}}}}}}, "l": {"docs": {}, "df": 0, "i": {"docs": {}, "df": 0, "k": {"docs": {}, "df": 0, "e": {"docs": {"shimmer.utils.migrate_model": {"tf": 1}}, "df": 1}}}}, "s": {"docs": {"shimmer.cli.ckpt_migration.migrate_ckpt": {"tf": 1.4142135623730951}}, "df": 1}}}, "s": {"docs": {}, "df": 0, "s": {"docs": {"shimmer.modules.global_workspace.pretrained_global_workspace": {"tf": 1}, "shimmer.modules.gw_module.GWEncoder.forward": {"tf": 1}, "shimmer.modules.gw_module.GWEncoderLinear.forward": {"tf": 1}, "shimmer.modules.selection.SelectionBase.forward": {"tf": 1}, "shimmer.modules.selection.SingleDomainSelection.forward": {"tf": 1}, "shimmer.modules.selection.FixedSharedSelection.forward": {"tf": 1}}, "df": 6, "i": {"docs": {}, "df": 0, "n": {"docs": {}, "df": 0, "g": {"docs": {"shimmer.types.RawDomainGroupT": {"tf": 1}, "shimmer.types.LatentsDomainGroupT": {"tf": 1}, "shimmer.types.LatentsDomainGroupsT": {"tf": 1}}, "df": 3}}}, "e": {"docs": {}, "df": 0, "d": {"docs": {"shimmer.modules.global_workspace.SchedulerArgs": {"tf": 1}}, "df": 1}}}}, "r": {"docs": {}, "df": 0, "t": {"docs": {"shimmer.modules.losses.LossCoefs": {"tf": 1}, "shimmer.modules.losses.BroadcastLossCoefs": {"tf": 1}}, "df": 2, "i": {"docs": {}, "df": 0, "c": {"docs": {}, "df": 0, "u": {"docs": {}, "df": 0, "l": {"docs": {}, "df": 0, "a": {"docs": {}, "df": 0, "r": {"docs": {"shimmer.modules.losses.demi_cycle_loss": {"tf": 1}, "shimmer.modules.losses.cycle_loss": {"tf": 1}}, "df": 2}}}}}, "t": {"docs": {}, "df": 0, "i": {"docs": {}, "df": 0, "o": {"docs": {}, "df": 0, "n": {"docs": {"shimmer.modules.losses.generate_partitions": {"tf": 1.7320508075688772}}, "df": 1, "s": {"docs": {"shimmer.modules.losses.generate_partitions": {"tf": 1.4142135623730951}}, 
"df": 1}}}}}}}, "a": {"docs": {}, "df": 0, "m": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "t": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "r": {"docs": {"shimmer.modules.contrastive_loss.ContrastiveLoss.__init__": {"tf": 1}}, "df": 1, "s": {"docs": {"shimmer.modules.vae.reparameterize": {"tf": 1}}, "df": 1}}}}}}}}}, "n": {"docs": {}, "df": 0, "g": {"docs": {"shimmer.types.RawDomainGroupT": {"tf": 1}, "shimmer.types.RawDomainGroupDT": {"tf": 1}, "shimmer.types.RawDomainGroupsT": {"tf": 1.4142135623730951}, "shimmer.types.RawDomainGroupsDT": {"tf": 1.4142135623730951}}, "df": 4}}, "y": {"docs": {}, "df": 0, "t": {"docs": {}, "df": 0, "o": {"docs": {}, "df": 0, "r": {"docs": {}, "df": 0, "c": {"docs": {}, "df": 0, "h": {"docs": {"shimmer.types.ModelModeT": {"tf": 1}, "shimmer.utils.SaveMigrations.on_save_checkpoint": {"tf": 1.4142135623730951}}, "df": 2}}}}, "h": {"docs": {}, "df": 0, "o": {"docs": {}, "df": 0, "n": {"docs": {"shimmer.modules.domain.LossOutput": {"tf": 1}}, "df": 1}}}}}, "r": {"docs": {}, "df": 0, "e": {"docs": {"shimmer.modules.global_workspace.GlobalWorkspaceBase.encode": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.fuse": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspace2Domains.__init__": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspace.__init__": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBayesian.__init__": {"tf": 1}, "shimmer.modules.global_workspace.pretrained_global_workspace": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBase.encode": {"tf": 1.4142135623730951}, "shimmer.modules.gw_module.GWModule.__init__": {"tf": 1}, "shimmer.modules.gw_module.GWModule.encode": {"tf": 1.4142135623730951}, "shimmer.modules.gw_module.GWModuleBayesian.__init__": {"tf": 1}, "shimmer.modules.selection.KQFixedQSelection.forward": {"tf": 1}, "shimmer.modules.selection.DynamicQueryAttention.forward": {"tf": 1}}, "df": 12, "d": {"docs": {"shimmer.modules.domain.DomainModule.compute_loss": 
{"tf": 1}, "shimmer.modules.domain.DomainModule.compute_dcy_loss": {"tf": 1}, "shimmer.modules.domain.DomainModule.compute_cy_loss": {"tf": 1}, "shimmer.modules.domain.DomainModule.compute_tr_loss": {"tf": 1}, "shimmer.modules.domain.DomainModule.compute_broadcast_loss": {"tf": 1}}, "df": 5, "i": {"docs": {}, "df": 0, "c": {"docs": {}, "df": 0, "t": {"docs": {"shimmer.modules.global_workspace.GWPredictionsBase": {"tf": 1}, "shimmer.modules.global_workspace.GWPredictions": {"tf": 1}}, "df": 2, "i": {"docs": {}, "df": 0, "o": {"docs": {}, "df": 0, "n": {"docs": {"shimmer.modules.global_workspace.GlobalWorkspaceBayesian": {"tf": 1}, "shimmer.modules.domain.DomainModule.compute_loss": {"tf": 1}, "shimmer.modules.domain.DomainModule.compute_dcy_loss": {"tf": 1}, "shimmer.modules.domain.DomainModule.compute_cy_loss": {"tf": 1}, "shimmer.modules.domain.DomainModule.compute_tr_loss": {"tf": 1}, "shimmer.modules.domain.DomainModule.compute_broadcast_loss": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBayesian": {"tf": 1}, "shimmer.modules.losses.contrastive_loss_bayesian": {"tf": 1}, "shimmer.modules.losses.GWLossesBayesian.__init__": {"tf": 1}, "shimmer.modules.contrastive_loss.ContrastiveLossType": {"tf": 1}, "shimmer.modules.contrastive_loss.ContrastiveLossBayesianType": {"tf": 1.4142135623730951}, "shimmer.modules.contrastive_loss.info_nce": {"tf": 1}, "shimmer.modules.contrastive_loss.contrastive_loss": {"tf": 1}, "shimmer.modules.contrastive_loss.ContrastiveLoss.forward": {"tf": 1}, "shimmer.modules.vae.VAE.encode": {"tf": 1}}, "df": 15, "s": {"docs": {"shimmer.modules.global_workspace.GWPredictions.demi_cycles": {"tf": 1}, "shimmer.modules.global_workspace.GWPredictions.cycles": {"tf": 1}, "shimmer.modules.global_workspace.GWPredictions.translations": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspace2Domains.forward": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspace.forward": {"tf": 1}, 
"shimmer.modules.global_workspace.GlobalWorkspaceBayesian.forward": {"tf": 1}, "shimmer.modules.vae.gaussian_nll": {"tf": 1}, "shimmer.modules.utils.batch_demi_cycles": {"tf": 1}, "shimmer.modules.utils.batch_cycles": {"tf": 1}, "shimmer.modules.utils.batch_translations": {"tf": 1}}, "df": 10}}}, "n": {"docs": {}, "df": 0, "g": {"docs": {"shimmer.modules.gw_module.GWModuleBayesian.fuse": {"tf": 1}}, "df": 1}}}, "e": {"docs": {}, "df": 0, "d": {"docs": {"shimmer.modules.vae.reparameterize": {"tf": 1.4142135623730951}, "shimmer.modules.vae.kl_divergence_loss": {"tf": 1.4142135623730951}}, "df": 2}}}}}}, "c": {"docs": {}, "df": 0, "i": {"docs": {}, "df": 0, "s": {"docs": {}, "df": 0, "i": {"docs": {}, "df": 0, "o": {"docs": {}, "df": 0, "n": {"docs": {"shimmer.modules.global_workspace.GlobalWorkspaceBayesian.__init__": {"tf": 2}, "shimmer.modules.gw_module.compute_fusion_scores": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBayesian.__init__": {"tf": 1.7320508075688772}, "shimmer.modules.gw_module.GWModuleBayesian.precisions": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBayesian.get_precision": {"tf": 1.4142135623730951}, "shimmer.modules.gw_module.GWModuleBayesian.fuse": {"tf": 1.4142135623730951}, "shimmer.modules.losses.GWLossesBayesian.__init__": {"tf": 1}}, "df": 7}}}}}}, "t": {"docs": {}, "df": 0, "r": {"docs": {}, "df": 0, "a": {"docs": {}, "df": 0, "i": {"docs": {}, "df": 0, "n": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "d": {"docs": {"shimmer.modules.global_workspace.pretrained_global_workspace": {"tf": 1}}, "df": 1}}}}}}}, "v": {"docs": {}, "df": 0, "i": {"docs": {}, "df": 0, "o": {"docs": {}, "df": 0, "u": {"docs": {}, "df": 0, "s": {"docs": {"shimmer.modules.selection.SelectionBase.update_gw_state": {"tf": 1.4142135623730951}, "shimmer.modules.selection.SingleDomainSelection.forward": {"tf": 1}, "shimmer.modules.selection.FixedSharedSelection.forward": {"tf": 1}}, "df": 3}}}}}}, "o": {"docs": {}, "df": 0, "x": {"docs": {}, "df": 0, "y": {"docs": 
{"shimmer.modules.global_workspace.GlobalWorkspaceBase.encode_domain": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.decode_domain": {"tf": 1}}, "df": 2}}, "d": {"docs": {}, "df": 0, "u": {"docs": {}, "df": 0, "c": {"docs": {}, "df": 0, "t": {"docs": {"shimmer.modules.selection.KQFixedQSelection.forward": {"tf": 1}, "shimmer.modules.selection.DynamicQueryAttention.forward": {"tf": 1}}, "df": 2}}}}, "v": {"docs": {}, "df": 0, "i": {"docs": {}, "df": 0, "d": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "d": {"docs": {"shimmer.modules.losses.demi_cycle_loss": {"tf": 1}, "shimmer.modules.losses.cycle_loss": {"tf": 1}, "shimmer.modules.losses.translation_loss": {"tf": 1}, "shimmer.modules.losses.contrastive_loss": {"tf": 1}, "shimmer.modules.losses.contrastive_loss_bayesian": {"tf": 1}, "shimmer.modules.losses.LossCoefs": {"tf": 1}, "shimmer.modules.losses.BroadcastLossCoefs": {"tf": 1}, "shimmer.modules.vae.reparameterize": {"tf": 1}, "shimmer.modules.utils.translation": {"tf": 1}}, "df": 9}}}}}}}, "e": {"docs": {}, "df": 0, "r": {"docs": {}, "df": 0, "f": {"docs": {}, "df": 0, "o": {"docs": {}, "df": 0, "r": {"docs": {}, "df": 0, "m": {"docs": {"shimmer.modules.utils.translation": {"tf": 1}, "shimmer.modules.utils.cycle": {"tf": 1}}, "df": 2, "e": {"docs": {}, "df": 0, "d": {"docs": {"shimmer.modules.gw_module.GWEncoder.forward": {"tf": 1}, "shimmer.modules.gw_module.GWEncoderLinear.forward": {"tf": 1}}, "df": 2}}, "s": {"docs": {"shimmer.modules.losses.GWLosses.step": {"tf": 1}, "shimmer.modules.losses.GWLossesBayesian.step": {"tf": 1}}, "df": 2}}}}}}}, "o": {"docs": {}, "df": 0, "s": {"docs": {}, "df": 0, "s": {"docs": {}, "df": 0, "i": {"docs": {}, "df": 0, "b": {"docs": {}, "df": 0, "l": {"docs": {}, "df": 0, "e": {"docs": {"shimmer.modules.losses.generate_partitions": {"tf": 1}}, "df": 1}}}}}}, "i": {"docs": {}, "df": 0, "n": {"docs": {}, "df": 0, "t": {"docs": {"shimmer.cli.ckpt_migration.migrate_ckpt": {"tf": 1}}, "df": 1}}}}, "l": 
{"docs": {"shimmer.utils.SaveMigrations.on_save_checkpoint": {"tf": 1}}, "df": 1}}, "l": {"docs": {}, "df": 0, "a": {"docs": {}, "df": 0, "n": {"docs": {}, "df": 0, "g": {"docs": {}, "df": 0, "u": {"docs": {}, "df": 0, "a": {"docs": {}, "df": 0, "g": {"docs": {}, "df": 0, "e": {"docs": {"shimmer.types.RawDomainGroupT": {"tf": 1}, "shimmer.types.RawDomainGroupDT": {"tf": 1}, "shimmer.types.LatentsDomainGroupT": {"tf": 1}, "shimmer.types.LatentsDomainGroupDT": {"tf": 1}, "shimmer.types.RawDomainGroupsT": {"tf": 2}, "shimmer.types.RawDomainGroupsDT": {"tf": 2}, "shimmer.types.LatentsDomainGroupsT": {"tf": 2}, "shimmer.types.LatentsDomainGroupsDT": {"tf": 2}}, "df": 8}}}}}}, "t": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "n": {"docs": {}, "df": 0, "t": {"docs": {"shimmer.types.LatentsDomainGroupT": {"tf": 1.4142135623730951}, "shimmer.types.LatentsDomainGroupDT": {"tf": 1.4142135623730951}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.encode_and_fuse": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.encode": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.fuse": {"tf": 1.4142135623730951}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.batch_gw_states": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.encode_domains": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.decode_domains": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspace2Domains.__init__": {"tf": 1.4142135623730951}, "shimmer.modules.global_workspace.GlobalWorkspace2Domains.forward": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspace.__init__": {"tf": 1.4142135623730951}, "shimmer.modules.global_workspace.GlobalWorkspace.forward": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBayesian.__init__": {"tf": 1.4142135623730951}, "shimmer.modules.global_workspace.GlobalWorkspaceBayesian.forward": {"tf": 1}, "shimmer.modules.global_workspace.pretrained_global_workspace": {"tf": 
1.4142135623730951}, "shimmer.modules.domain.DomainModule.__init__": {"tf": 1.4142135623730951}, "shimmer.modules.domain.DomainModule.latent_dim": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBase.fuse": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBase.encode": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBase.encode_and_fuse": {"tf": 1}, "shimmer.modules.gw_module.GWModule.__init__": {"tf": 1.4142135623730951}, "shimmer.modules.gw_module.GWModule.fuse": {"tf": 1}, "shimmer.modules.gw_module.GWModule.encode": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBayesian.__init__": {"tf": 1.4142135623730951}, "shimmer.modules.gw_module.GWModuleBayesian.fuse": {"tf": 1}, "shimmer.modules.selection.SelectionBase.forward": {"tf": 1}, "shimmer.modules.selection.SingleDomainSelection.forward": {"tf": 1}, "shimmer.modules.selection.FixedSharedSelection.forward": {"tf": 1}, "shimmer.modules.selection.KQFixedQSelection.forward": {"tf": 1}, "shimmer.modules.selection.RandomSelection.forward": {"tf": 1}, "shimmer.modules.selection.DynamicQueryAttention.fuse_weighted_encodings": {"tf": 1}, "shimmer.modules.selection.DynamicQueryAttention.forward": {"tf": 1}, "shimmer.modules.losses.GWLossesBase.step": {"tf": 1}, "shimmer.modules.losses.demi_cycle_loss": {"tf": 1.4142135623730951}, "shimmer.modules.losses.cycle_loss": {"tf": 1.4142135623730951}, "shimmer.modules.losses.translation_loss": {"tf": 1.4142135623730951}, "shimmer.modules.losses.contrastive_loss": {"tf": 1.4142135623730951}, "shimmer.modules.losses.contrastive_loss_bayesian": {"tf": 1.4142135623730951}, "shimmer.modules.losses.GWLosses2Domains.demi_cycle_loss": {"tf": 1.4142135623730951}, "shimmer.modules.losses.GWLosses2Domains.cycle_loss": {"tf": 1.4142135623730951}, "shimmer.modules.losses.GWLosses2Domains.translation_loss": {"tf": 1.4142135623730951}, "shimmer.modules.losses.GWLosses2Domains.contrastive_loss": {"tf": 1.4142135623730951}, "shimmer.modules.losses.GWLosses2Domains.step": {"tf": 1}, 
"shimmer.modules.losses.broadcast_loss": {"tf": 1.4142135623730951}, "shimmer.modules.losses.GWLosses.contrastive_loss": {"tf": 1.7320508075688772}, "shimmer.modules.losses.GWLosses.step": {"tf": 1}, "shimmer.modules.losses.GWLossesBayesian.contrastive_loss": {"tf": 1.4142135623730951}, "shimmer.modules.losses.GWLossesBayesian.step": {"tf": 1}, "shimmer.modules.vae.VAEDecoder.forward": {"tf": 1}, "shimmer.modules.vae.VAE.decode": {"tf": 1.4142135623730951}, "shimmer.modules.utils.translation": {"tf": 1}, "shimmer.modules.utils.cycle": {"tf": 1.4142135623730951}, "shimmer.modules.utils.batch_demi_cycles": {"tf": 1}, "shimmer.modules.utils.batch_cycles": {"tf": 1}, "shimmer.modules.utils.batch_translations": {"tf": 1}}, "df": 55, "s": {"docs": {"shimmer.modules.losses.GWLossesBase.step": {"tf": 1}, "shimmer.modules.losses.GWLosses2Domains.step": {"tf": 1}, "shimmer.modules.losses.GWLosses.step": {"tf": 1}, "shimmer.modules.losses.GWLossesBayesian.step": {"tf": 1}, "shimmer.utils.groups_batch_size": {"tf": 1}, "shimmer.utils.groups_device": {"tf": 1}}, "df": 6, "d": {"docs": {}, "df": 0, "o": {"docs": {}, "df": 0, "m": {"docs": {}, "df": 0, "a": {"docs": {}, "df": 0, "i": {"docs": {}, "df": 0, "n": {"docs": {}, "df": 0, "g": {"docs": {}, "df": 0, "r": {"docs": {}, "df": 0, "o": {"docs": {}, "df": 0, "u": {"docs": {}, "df": 0, "p": {"docs": {}, "df": 0, "t": {"docs": {"shimmer.types.LatentsDomainGroupT": {"tf": 1}, "shimmer.types.LatentsDomainGroupDT": {"tf": 1}, "shimmer.types.LatentsDomainGroupsT": {"tf": 1}, "shimmer.types.LatentsDomainGroupsDT": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBase.fuse": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBase.encode": {"tf": 1.4142135623730951}, "shimmer.modules.gw_module.GWModuleBase.encode_and_fuse": {"tf": 1}, "shimmer.modules.gw_module.GWModule.fuse": {"tf": 1}, "shimmer.modules.gw_module.GWModule.encode": {"tf": 1.4142135623730951}, "shimmer.modules.gw_module.GWModuleBayesian.fuse": {"tf": 1}, 
"shimmer.modules.selection.SelectionBase.forward": {"tf": 1}, "shimmer.modules.selection.SingleDomainSelection.forward": {"tf": 1}, "shimmer.modules.selection.FixedSharedSelection.forward": {"tf": 1}, "shimmer.modules.selection.KQFixedQSelection.forward": {"tf": 1.4142135623730951}, "shimmer.modules.selection.RandomSelection.forward": {"tf": 1}, "shimmer.modules.selection.DynamicQueryAttention.fuse_weighted_encodings": {"tf": 1}, "shimmer.modules.selection.DynamicQueryAttention.forward": {"tf": 1.4142135623730951}, "shimmer.modules.utils.translation": {"tf": 1}, "shimmer.modules.utils.cycle": {"tf": 1}}, "df": 19}, "d": {"docs": {}, "df": 0, "t": {"docs": {"shimmer.types.LatentsDomainGroupT": {"tf": 1}, "shimmer.types.LatentsDomainGroupDT": {"tf": 1}, "shimmer.types.LatentsDomainGroupsDT": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBase.decode": {"tf": 1}, "shimmer.modules.gw_module.GWModule.decode": {"tf": 1}, "shimmer.modules.utils.cycle": {"tf": 1}}, "df": 6}}, "s": {"docs": {}, "df": 0, "t": {"docs": {"shimmer.types.LatentsDomainGroupsT": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.encode_and_fuse": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.encode": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.fuse": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.decode_domains": {"tf": 1}, "shimmer.modules.losses.GWLossesBase.step": {"tf": 1}, "shimmer.modules.losses.demi_cycle_loss": {"tf": 1}, "shimmer.modules.losses.cycle_loss": {"tf": 1}, "shimmer.modules.losses.translation_loss": {"tf": 1}, "shimmer.modules.losses.contrastive_loss": {"tf": 1}, "shimmer.modules.losses.contrastive_loss_bayesian": {"tf": 1}, "shimmer.modules.losses.GWLosses2Domains.demi_cycle_loss": {"tf": 1}, "shimmer.modules.losses.GWLosses2Domains.cycle_loss": {"tf": 1}, "shimmer.modules.losses.GWLosses2Domains.translation_loss": {"tf": 1}, "shimmer.modules.losses.GWLosses2Domains.contrastive_loss": {"tf": 1}, 
"shimmer.modules.losses.GWLosses2Domains.step": {"tf": 1}, "shimmer.modules.losses.GWLossesBayesian.contrastive_loss": {"tf": 1}, "shimmer.utils.groups_batch_size": {"tf": 1}, "shimmer.utils.groups_device": {"tf": 1}}, "df": 19}, "d": {"docs": {}, "df": 0, "t": {"docs": {"shimmer.types.LatentsDomainGroupsT": {"tf": 1}, "shimmer.types.LatentsDomainGroupsDT": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.encode_domains": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.decode_domains": {"tf": 1}}, "df": 4}}}}}}}}}}}}}}, "t": {"docs": {"shimmer.modules.global_workspace.GlobalWorkspaceBase.batch_gw_states": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspace2Domains.forward": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspace.forward": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBayesian.forward": {"tf": 1}, "shimmer.modules.utils.batch_demi_cycles": {"tf": 1}, "shimmer.modules.utils.batch_cycles": {"tf": 1}, "shimmer.modules.utils.batch_translations": {"tf": 1}}, "df": 7}}}, "s": {"docs": {}, "df": 0, "d": {"docs": {}, "df": 0, "o": {"docs": {}, "df": 0, "m": {"docs": {}, "df": 0, "a": {"docs": {}, "df": 0, "i": {"docs": {}, "df": 0, "n": {"docs": {}, "df": 0, "g": {"docs": {}, "df": 0, "r": {"docs": {}, "df": 0, "o": {"docs": {}, "df": 0, "u": {"docs": {}, "df": 0, "p": {"docs": {}, "df": 0, "s": {"docs": {}, "df": 0, "d": {"docs": {}, "df": 0, "t": {"docs": {"shimmer.modules.global_workspace.GlobalWorkspaceBase.encode": {"tf": 1}}, "df": 1}}}}}}}}}}}}}}}}}, "t": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "r": {"docs": {"shimmer.modules.gw_module.GWEncoder.forward": {"tf": 1}, "shimmer.modules.gw_module.GWEncoderLinear.forward": {"tf": 1}}, "df": 2}}}}, "y": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "r": {"docs": {}, "df": 0, "s": {"docs": {"shimmer.modules.gw_module.get_n_layers": {"tf": 2.23606797749979}, "shimmer.modules.gw_module.GWDecoder.__init__": {"tf": 2}, 
"shimmer.modules.gw_module.GWDecoder.n_layers": {"tf": 1.7320508075688772}, "shimmer.modules.gw_module.GWEncoder.__init__": {"tf": 2}}, "df": 4}}}}, "m": {"docs": {}, "df": 0, "b": {"docs": {}, "df": 0, "d": {"docs": {}, "df": 0, "a": {"docs": {"shimmer.modules.gw_module.GWModuleBayesian.fuse": {"tf": 1}}, "df": 1}}}}, "s": {"docs": {}, "df": 0, "t": {"docs": {"shimmer.dataset.RepeatedDataset": {"tf": 1.4142135623730951}, "shimmer.dataset.RepeatedDataset.__init__": {"tf": 1}}, "df": 2}}}, "i": {"docs": {}, "df": 0, "g": {"docs": {}, "df": 0, "h": {"docs": {}, "df": 0, "t": {"docs": {}, "df": 0, "n": {"docs": {}, "df": 0, "i": {"docs": {}, "df": 0, "n": {"docs": {}, "df": 0, "g": {"docs": {"shimmer.types.ModelModeT": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBase": {"tf": 1}, "shimmer.utils.SaveMigrations.on_save_checkpoint": {"tf": 1.4142135623730951}}, "df": 3, "m": {"docs": {}, "df": 0, "o": {"docs": {}, "df": 0, "d": {"docs": {}, "df": 0, "u": {"docs": {}, "df": 0, "l": {"docs": {}, "df": 0, "e": {"docs": {"shimmer.utils.SaveMigrations.on_save_checkpoint": {"tf": 1}}, "df": 1}}}}}}}}}}}}}, "s": {"docs": {}, "df": 0, "t": {"docs": {"shimmer.modules.gw_module.get_n_layers": {"tf": 1.4142135623730951}, "shimmer.modules.selection.KQFixedQSelection.__init__": {"tf": 1}, "shimmer.modules.selection.DynamicQueryAttention.__init__": {"tf": 1}, "shimmer.cli.ckpt_migration.migrate_ckpt": {"tf": 1}}, "df": 4, "[": {"docs": {}, "df": 0, "n": {"docs": {}, "df": 0, "n": {"docs": {"shimmer.modules.gw_module.get_n_layers": {"tf": 1}}, "df": 1}}}}}, "n": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "a": {"docs": {}, "df": 0, "r": {"docs": {"shimmer.modules.gw_module.get_n_layers": {"tf": 1.4142135623730951}, "shimmer.modules.gw_module.GWEncoderLinear": {"tf": 1}}, "df": 2, "i": {"docs": {}, "df": 0, "t": {"docs": {}, "df": 0, "y": {"docs": {"shimmer.modules.gw_module.GWEncoder": {"tf": 1}}, "df": 1}}}}}}, "k": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, 
"d": {"docs": {"shimmer.modules.losses.GWLossesBayesian.domain_mods": {"tf": 1}}, "df": 1}}}}, "k": {"docs": {}, "df": 0, "e": {"docs": {"shimmer.modules.gw_module.GWModuleBase": {"tf": 1}, "shimmer.modules.losses.BroadcastLossCoefs.translations": {"tf": 1}, "shimmer.modules.contrastive_loss.contrastive_loss": {"tf": 1}, "shimmer.modules.contrastive_loss.ContrastiveLoss": {"tf": 1}}, "df": 4}}, "t": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "r": {"docs": {}, "df": 0, "a": {"docs": {}, "df": 0, "l": {"docs": {"shimmer.modules.losses.GWLossesBase.step": {"tf": 1}, "shimmer.modules.contrastive_loss.info_nce": {"tf": 1}, "shimmer.modules.contrastive_loss.contrastive_loss": {"tf": 1}, "shimmer.modules.contrastive_loss.ContrastiveLoss.__init__": {"tf": 1}}, "df": 4}}}}}}, "e": {"docs": {}, "df": 0, "a": {"docs": {}, "df": 0, "r": {"docs": {}, "df": 0, "n": {"docs": {"shimmer.modules.global_workspace.GlobalWorkspace2Domains.__init__": {"tf": 1.7320508075688772}, "shimmer.modules.global_workspace.GlobalWorkspace.__init__": {"tf": 1.7320508075688772}, "shimmer.modules.global_workspace.GlobalWorkspaceBayesian.__init__": {"tf": 1.7320508075688772}, "shimmer.modules.global_workspace.pretrained_global_workspace": {"tf": 1}, "shimmer.modules.contrastive_loss.ContrastiveLoss.__init__": {"tf": 1.4142135623730951}}, "df": 5, "i": {"docs": {}, "df": 0, "n": {"docs": {}, "df": 0, "g": {"docs": {"shimmer.modules.global_workspace.SchedulerArgs.max_lr": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspace2Domains.__init__": {"tf": 1.4142135623730951}, "shimmer.modules.global_workspace.GlobalWorkspace.__init__": {"tf": 1.4142135623730951}, "shimmer.modules.global_workspace.GlobalWorkspaceBayesian.__init__": {"tf": 1.4142135623730951}}, "df": 4}}}}}, "s": {"docs": {}, "df": 0, "t": {"docs": {"shimmer.dataset.RepeatedDataset": {"tf": 1}}, "df": 1}}}, "v": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "l": {"docs": {"shimmer.modules.gw_module.GWModuleBayesian.precisions": 
{"tf": 1}}, "df": 1}}}, "n": {"docs": {"shimmer.dataset.RepeatedDataset": {"tf": 1}, "shimmer.dataset.RepeatedDataset.__init__": {"tf": 1}}, "df": 2}}, "o": {"docs": {}, "df": 0, "s": {"docs": {}, "df": 0, "s": {"docs": {"shimmer.modules.global_workspace.GlobalWorkspaceBase.generic_step": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspace2Domains.__init__": {"tf": 2.449489742783178}, "shimmer.modules.global_workspace.GlobalWorkspace": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspace.__init__": {"tf": 2.449489742783178}, "shimmer.modules.global_workspace.GlobalWorkspaceBayesian.__init__": {"tf": 2.6457513110645907}, "shimmer.modules.global_workspace.pretrained_global_workspace": {"tf": 2}, "shimmer.modules.domain.LossOutput": {"tf": 1}, "shimmer.modules.domain.LossOutput.loss": {"tf": 1}, "shimmer.modules.domain.LossOutput.all": {"tf": 1.4142135623730951}, "shimmer.modules.domain.DomainModule.compute_loss": {"tf": 1.4142135623730951}, "shimmer.modules.domain.DomainModule.compute_dcy_loss": {"tf": 2}, "shimmer.modules.domain.DomainModule.compute_cy_loss": {"tf": 2}, "shimmer.modules.domain.DomainModule.compute_tr_loss": {"tf": 2}, "shimmer.modules.domain.DomainModule.compute_broadcast_loss": {"tf": 2}, "shimmer.modules.losses.demi_cycle_loss": {"tf": 1.4142135623730951}, "shimmer.modules.losses.cycle_loss": {"tf": 1.4142135623730951}, "shimmer.modules.losses.translation_loss": {"tf": 1.4142135623730951}, "shimmer.modules.losses.contrastive_loss": {"tf": 1.4142135623730951}, "shimmer.modules.losses.contrastive_loss_bayesian": {"tf": 1.4142135623730951}, "shimmer.modules.losses.LossCoefs": {"tf": 1.7320508075688772}, "shimmer.modules.losses.LossCoefs.demi_cycles": {"tf": 1}, "shimmer.modules.losses.LossCoefs.cycles": {"tf": 1}, "shimmer.modules.losses.LossCoefs.translations": {"tf": 1}, "shimmer.modules.losses.LossCoefs.contrastives": {"tf": 1}, "shimmer.modules.losses.GWLosses2Domains.__init__": {"tf": 2}, 
"shimmer.modules.losses.GWLosses2Domains.demi_cycle_loss": {"tf": 1.4142135623730951}, "shimmer.modules.losses.GWLosses2Domains.cycle_loss": {"tf": 1.4142135623730951}, "shimmer.modules.losses.GWLosses2Domains.translation_loss": {"tf": 1.4142135623730951}, "shimmer.modules.losses.GWLosses2Domains.contrastive_loss": {"tf": 1.4142135623730951}, "shimmer.modules.losses.GWLosses2Domains.step": {"tf": 2}, "shimmer.modules.losses.broadcast_loss": {"tf": 1.4142135623730951}, "shimmer.modules.losses.BroadcastLossCoefs": {"tf": 1.7320508075688772}, "shimmer.modules.losses.BroadcastLossCoefs.contrastives": {"tf": 1}, "shimmer.modules.losses.BroadcastLossCoefs.fused": {"tf": 1}, "shimmer.modules.losses.BroadcastLossCoefs.demi_cycles": {"tf": 1}, "shimmer.modules.losses.BroadcastLossCoefs.cycles": {"tf": 1}, "shimmer.modules.losses.BroadcastLossCoefs.translations": {"tf": 1}, "shimmer.modules.losses.GWLosses.__init__": {"tf": 1.7320508075688772}, "shimmer.modules.losses.GWLosses.contrastive_loss": {"tf": 1.4142135623730951}, "shimmer.modules.losses.GWLosses.step": {"tf": 1.4142135623730951}, "shimmer.modules.losses.GWLossesBayesian.__init__": {"tf": 2.23606797749979}, "shimmer.modules.losses.GWLossesBayesian.loss_coefs": {"tf": 1}, "shimmer.modules.losses.GWLossesBayesian.contrastive_fn": {"tf": 1}, "shimmer.modules.losses.GWLossesBayesian.contrastive_loss": {"tf": 1}, "shimmer.modules.losses.GWLossesBayesian.step": {"tf": 1.4142135623730951}, "shimmer.modules.contrastive_loss": {"tf": 1}, "shimmer.modules.contrastive_loss.ContrastiveLossType": {"tf": 1}, "shimmer.modules.contrastive_loss.ContrastiveLossBayesianType": {"tf": 1}, "shimmer.modules.contrastive_loss.info_nce": {"tf": 1.4142135623730951}, "shimmer.modules.contrastive_loss.contrastive_loss": {"tf": 1.4142135623730951}, "shimmer.modules.contrastive_loss.ContrastiveLoss.__init__": {"tf": 1.4142135623730951}, "shimmer.modules.contrastive_loss.ContrastiveLoss.forward": {"tf": 1.4142135623730951}, 
"shimmer.modules.vae.kl_divergence_loss": {"tf": 1.4142135623730951}, "shimmer.modules.vae.gaussian_nll": {"tf": 1.4142135623730951}}, "df": 54, "e": {"docs": {}, "df": 0, "s": {"docs": {"shimmer.modules.global_workspace.GlobalWorkspaceBase.loss_mod": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspace2Domains.__init__": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspace.__init__": {"tf": 1.4142135623730951}, "shimmer.modules.global_workspace.GlobalWorkspaceBayesian.__init__": {"tf": 1}, "shimmer.modules.global_workspace.pretrained_global_workspace": {"tf": 1}, "shimmer.modules.domain.LossOutput": {"tf": 1}, "shimmer.modules.losses.GWLossesBase": {"tf": 1.7320508075688772}, "shimmer.modules.losses.GWLossesBase.step": {"tf": 1.4142135623730951}, "shimmer.modules.losses.GWLosses2Domains.demi_cycle_loss": {"tf": 1}, "shimmer.modules.losses.GWLosses2Domains.cycle_loss": {"tf": 1}, "shimmer.modules.losses.GWLosses2Domains.translation_loss": {"tf": 1}, "shimmer.modules.losses.GWLosses2Domains.contrastive_loss": {"tf": 1}, "shimmer.modules.losses.GWLosses2Domains.step": {"tf": 1.4142135623730951}, "shimmer.modules.losses.broadcast_loss": {"tf": 1}, "shimmer.modules.losses.GWLosses.__init__": {"tf": 1}}, "df": 15}}, "c": {"docs": {}, "df": 0, "o": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "f": {"docs": {}, "df": 0, "s": {"docs": {"shimmer.modules.global_workspace.GlobalWorkspace2Domains.__init__": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBayesian.__init__": {"tf": 1}, "shimmer.modules.global_workspace.pretrained_global_workspace": {"tf": 1}, "shimmer.modules.losses.GWLosses2Domains.__init__": {"tf": 1.4142135623730951}}, "df": 4}}}}}, "o": {"docs": {}, "df": 0, "u": {"docs": {}, "df": 0, "t": {"docs": {}, "df": 0, "p": {"docs": {}, "df": 0, "u": {"docs": {}, "df": 0, "t": {"docs": {"shimmer.modules.domain.DomainModule.compute_loss": {"tf": 1}, "shimmer.modules.domain.DomainModule.compute_dcy_loss": {"tf": 1}, 
"shimmer.modules.domain.DomainModule.compute_cy_loss": {"tf": 1}, "shimmer.modules.domain.DomainModule.compute_tr_loss": {"tf": 1}, "shimmer.modules.domain.DomainModule.compute_broadcast_loss": {"tf": 1}, "shimmer.modules.losses.GWLossesBase.step": {"tf": 1}, "shimmer.modules.losses.GWLosses2Domains.step": {"tf": 1}, "shimmer.modules.losses.GWLosses.step": {"tf": 1}, "shimmer.modules.losses.GWLossesBayesian.step": {"tf": 1}, "shimmer.modules.contrastive_loss.ContrastiveLossType": {"tf": 1}, "shimmer.modules.contrastive_loss.ContrastiveLossBayesianType": {"tf": 1}, "shimmer.modules.contrastive_loss.ContrastiveLoss.forward": {"tf": 1}}, "df": 12}}}}, "p": {"docs": {}, "df": 0, "u": {"docs": {}, "df": 0, "t": {"docs": {"shimmer.modules.domain.DomainModule.compute_loss": {"tf": 1}, "shimmer.modules.domain.DomainModule.compute_dcy_loss": {"tf": 1}, "shimmer.modules.domain.DomainModule.compute_cy_loss": {"tf": 1}, "shimmer.modules.domain.DomainModule.compute_tr_loss": {"tf": 1}, "shimmer.modules.domain.DomainModule.compute_broadcast_loss": {"tf": 1}}, "df": 5}}}}}}}, "g": {"docs": {"shimmer.modules.domain.LossOutput.metrics": {"tf": 1}, "shimmer.modules.vae.reparameterize": {"tf": 1}, "shimmer.modules.vae.gaussian_nll": {"tf": 1.4142135623730951}, "shimmer.modules.vae.VAEEncoder.forward": {"tf": 1}}, "df": 4, "i": {"docs": {}, "df": 0, "t": {"docs": {"shimmer.modules.global_workspace.GlobalWorkspace2Domains.__init__": {"tf": 1.4142135623730951}, "shimmer.modules.global_workspace.GlobalWorkspace.__init__": {"tf": 1.4142135623730951}, "shimmer.modules.global_workspace.GlobalWorkspaceBayesian.__init__": {"tf": 1.4142135623730951}, "shimmer.modules.global_workspace.pretrained_global_workspace": {"tf": 1}, "shimmer.modules.contrastive_loss.info_nce": {"tf": 1.4142135623730951}, "shimmer.modules.contrastive_loss.contrastive_loss": {"tf": 1.4142135623730951}, "shimmer.modules.contrastive_loss.ContrastiveLoss.__init__": {"tf": 2}, 
"shimmer.modules.contrastive_loss.ContrastiveLoss.forward": {"tf": 1}}, "df": 8, "s": {"docs": {"shimmer.modules.gw_module.GWModuleBayesian.fuse": {"tf": 1}}, "df": 1}}}, "g": {"docs": {}, "df": 0, "i": {"docs": {}, "df": 0, "n": {"docs": {}, "df": 0, "g": {"docs": {"shimmer.modules.domain.LossOutput": {"tf": 1}}, "df": 1}}}, "e": {"docs": {}, "df": 0, "d": {"docs": {"shimmer.modules.losses.LossCoefs": {"tf": 1.4142135623730951}, "shimmer.modules.losses.BroadcastLossCoefs": {"tf": 1.4142135623730951}}, "df": 2}}}, "v": {"docs": {}, "df": 0, "a": {"docs": {}, "df": 0, "r": {"docs": {"shimmer.modules.vae.reparameterize": {"tf": 1}, "shimmer.modules.vae.kl_divergence_loss": {"tf": 1}, "shimmer.modules.vae.VAE.forward": {"tf": 1}}, "df": 3, "s": {"docs": {"shimmer.modules.vae.kl_divergence_loss": {"tf": 1}}, "df": 1}}}}}, "a": {"docs": {}, "df": 0, "d": {"docs": {"shimmer.modules.global_workspace.pretrained_global_workspace": {"tf": 1.4142135623730951}, "shimmer.utils.migrate_model": {"tf": 1.4142135623730951}}, "df": 2, "e": {"docs": {}, "df": 0, "d": {"docs": {"shimmer.modules.global_workspace.pretrained_global_workspace": {"tf": 1}}, "df": 1}}}}}, "r": {"docs": {"shimmer.modules.global_workspace.GlobalWorkspace2Domains.__init__": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspace.__init__": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBayesian.__init__": {"tf": 1}}, "df": 3}, "t": {"docs": {"shimmer.dataset.RepeatedDataset": {"tf": 1}}, "df": 1}}, "c": {"docs": {"shimmer.modules.global_workspace.GlobalWorkspaceBayesian.__init__": {"tf": 1.4142135623730951}, "shimmer.modules.gw_module.GWModuleBayesian.__init__": {"tf": 1.4142135623730951}, "shimmer.modules.gw_module.GWModuleBayesian.fuse": {"tf": 2.449489742783178}}, "df": 3, "o": {"docs": {}, "df": 0, "l": {"docs": {}, "df": 0, "l": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "c": {"docs": {}, "df": 0, "t": {"docs": {}, "df": 0, "i": {"docs": {}, "df": 0, "o": {"docs": {}, "df": 0, "n": 
{"docs": {}, "df": 0, "s": {"docs": {"shimmer.types.RawDomainGroupT": {"tf": 1}, "shimmer.types.LatentsDomainGroupT": {"tf": 1}, "shimmer.types.LatentsDomainGroupsT": {"tf": 1}}, "df": 3}}}}}}}}}, "n": {"docs": {}, "df": 0, "t": {"docs": {"shimmer.modules.global_workspace.GlobalWorkspaceBayesian.__init__": {"tf": 1}, "shimmer.modules.losses.contrastive_loss": {"tf": 1}, "shimmer.modules.losses.contrastive_loss_bayesian": {"tf": 1}, "shimmer.modules.losses.GWLossesBayesian.__init__": {"tf": 1}}, "df": 4, "a": {"docs": {}, "df": 0, "i": {"docs": {}, "df": 0, "n": {"docs": {}, "df": 0, "s": {"docs": {"shimmer.types.RawDomainGroupsT": {"tf": 1}, "shimmer.types.RawDomainGroupsDT": {"tf": 1}, "shimmer.types.LatentsDomainGroupsT": {"tf": 1}, "shimmer.types.LatentsDomainGroupsDT": {"tf": 1}, "shimmer.modules.losses.GWLosses2Domains.step": {"tf": 1}, "shimmer.modules.contrastive_loss.ContrastiveLoss.forward": {"tf": 1}, "shimmer.modules.vae.VAE.forward": {"tf": 1}}, "df": 7}, "i": {"docs": {}, "df": 0, "n": {"docs": {}, "df": 0, "g": {"docs": {"shimmer.modules.losses.GWLosses.step": {"tf": 1}, "shimmer.modules.losses.GWLossesBayesian.step": {"tf": 1}}, "df": 2}}}}}}, "r": {"docs": {}, "df": 0, "a": {"docs": {}, "df": 0, "s": {"docs": {}, "df": 0, "t": {"docs": {}, "df": 0, "i": {"docs": {}, "df": 0, "v": {"docs": {}, "df": 0, "e": {"docs": {"shimmer.modules.global_workspace.GlobalWorkspace2Domains.__init__": {"tf": 2.449489742783178}, "shimmer.modules.global_workspace.GlobalWorkspace.__init__": {"tf": 2.449489742783178}, "shimmer.modules.global_workspace.GlobalWorkspaceBayesian.__init__": {"tf": 2.449489742783178}, "shimmer.modules.global_workspace.pretrained_global_workspace": {"tf": 1.7320508075688772}, "shimmer.modules.losses.GWLossesBase": {"tf": 1}, "shimmer.modules.losses.contrastive_loss": {"tf": 2.6457513110645907}, "shimmer.modules.losses.contrastive_loss_bayesian": {"tf": 2.6457513110645907}, "shimmer.modules.losses.LossCoefs.contrastives": {"tf": 1}, 
"shimmer.modules.losses.GWLosses2Domains.__init__": {"tf": 1.7320508075688772}, "shimmer.modules.losses.GWLosses2Domains.contrastive_loss": {"tf": 1.4142135623730951}, "shimmer.modules.losses.GWLosses2Domains.step": {"tf": 1.4142135623730951}, "shimmer.modules.losses.BroadcastLossCoefs.contrastives": {"tf": 1}, "shimmer.modules.losses.GWLosses.__init__": {"tf": 1.4142135623730951}, "shimmer.modules.losses.GWLosses.contrastive_loss": {"tf": 1.4142135623730951}, "shimmer.modules.losses.GWLossesBayesian.__init__": {"tf": 1.7320508075688772}, "shimmer.modules.losses.GWLossesBayesian.contrastive_fn": {"tf": 1}, "shimmer.modules.losses.GWLossesBayesian.contrastive_loss": {"tf": 1}, "shimmer.modules.contrastive_loss": {"tf": 1}, "shimmer.modules.contrastive_loss.ContrastiveLossType": {"tf": 1}, "shimmer.modules.contrastive_loss.ContrastiveLossBayesianType": {"tf": 1}, "shimmer.modules.contrastive_loss.contrastive_loss": {"tf": 1.4142135623730951}, "shimmer.modules.contrastive_loss.ContrastiveLoss.__init__": {"tf": 1}}, "df": 22, "l": {"docs": {}, "df": 0, "o": {"docs": {}, "df": 0, "s": {"docs": {}, "df": 0, "s": {"docs": {"shimmer.modules.contrastive_loss.ContrastiveLoss": {"tf": 1}}, "df": 1, "t": {"docs": {}, "df": 0, "y": {"docs": {}, "df": 0, "p": {"docs": {}, "df": 0, "e": {"docs": {"shimmer.modules.global_workspace.GlobalWorkspace2Domains.__init__": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspace.__init__": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBayesian.__init__": {"tf": 1}, "shimmer.modules.global_workspace.pretrained_global_workspace": {"tf": 1}, "shimmer.modules.losses.contrastive_loss": {"tf": 1}, "shimmer.modules.losses.GWLosses2Domains.__init__": {"tf": 1}, "shimmer.modules.losses.GWLossesBayesian.__init__": {"tf": 1}}, "df": 7}}}}, "b": {"docs": {}, "df": 0, "a": {"docs": {}, "df": 0, "y": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "s": {"docs": {}, "df": 0, "i": {"docs": {}, "df": 0, "a": {"docs": {}, "df": 0, "n": 
{"docs": {}, "df": 0, "t": {"docs": {}, "df": 0, "y": {"docs": {}, "df": 0, "p": {"docs": {}, "df": 0, "e": {"docs": {"shimmer.modules.losses.contrastive_loss_bayesian": {"tf": 1}}, "df": 1}}}}}}}}}}}}}}}}, "s": {"docs": {"shimmer.modules.losses.contrastive_loss": {"tf": 1}, "shimmer.modules.losses.contrastive_loss_bayesian": {"tf": 1}}, "df": 2}}}}}}}}}, "v": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "n": {"docs": {}, "df": 0, "i": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "n": {"docs": {}, "df": 0, "t": {"docs": {"shimmer.modules.global_workspace.GlobalWorkspaceBase.encode_domain": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.decode_domain": {"tf": 1}}, "df": 2}}}}}}}, "n": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "c": {"docs": {}, "df": 0, "t": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "d": {"docs": {"shimmer.modules.global_workspace.GlobalWorkspace2Domains.__init__": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspace.__init__": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBayesian.__init__": {"tf": 1}, "shimmer.modules.global_workspace.pretrained_global_workspace": {"tf": 1}}, "df": 4}}}}}}, "s": {"docs": {}, "df": 0, "t": {"docs": {}, "df": 0, "r": {"docs": {}, "df": 0, "a": {"docs": {}, "df": 0, "s": {"docs": {}, "df": 0, "t": {"docs": {}, "df": 0, "i": {"docs": {}, "df": 0, "v": {"docs": {}, "df": 0, "e": {"docs": {"shimmer.modules.global_workspace.GlobalWorkspaceBayesian.__init__": {"tf": 1}, "shimmer.modules.losses.GWLossesBayesian.__init__": {"tf": 1}}, "df": 2}}}}}}}}}}, "m": {"docs": {}, "df": 0, "p": {"docs": {}, "df": 0, "u": {"docs": {}, "df": 0, "t": {"docs": {}, "df": 0, "e": {"docs": {"shimmer.modules.selection.KQFixedQSelection.forward": {"tf": 1}, "shimmer.modules.selection.DynamicQueryAttention.forward": {"tf": 1}, "shimmer.modules.losses.GWLossesBase": {"tf": 1}, "shimmer.modules.losses.demi_cycle_loss": {"tf": 1}, "shimmer.modules.losses.cycle_loss": {"tf": 1}, 
"shimmer.modules.losses.translation_loss": {"tf": 1}, "shimmer.modules.losses.contrastive_loss": {"tf": 1}, "shimmer.modules.losses.contrastive_loss_bayesian": {"tf": 1}}, "df": 8, "s": {"docs": {"shimmer.modules.global_workspace.GlobalWorkspaceBase.loss_mod": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspace2Domains.forward": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspace.forward": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBayesian.forward": {"tf": 1}, "shimmer.modules.domain.DomainModule.compute_dcy_loss": {"tf": 1}, "shimmer.modules.domain.DomainModule.compute_cy_loss": {"tf": 1}, "shimmer.modules.domain.DomainModule.compute_tr_loss": {"tf": 1}, "shimmer.modules.domain.DomainModule.compute_broadcast_loss": {"tf": 1}, "shimmer.modules.losses.GWLossesBase.step": {"tf": 1}, "shimmer.modules.losses.demi_cycle_loss": {"tf": 1}, "shimmer.modules.losses.cycle_loss": {"tf": 1}, "shimmer.modules.losses.translation_loss": {"tf": 1}, "shimmer.modules.losses.contrastive_loss": {"tf": 1}, "shimmer.modules.losses.contrastive_loss_bayesian": {"tf": 1}, "shimmer.modules.losses.GWLosses2Domains.demi_cycle_loss": {"tf": 1}, "shimmer.modules.losses.GWLosses2Domains.cycle_loss": {"tf": 1}, "shimmer.modules.losses.GWLosses2Domains.translation_loss": {"tf": 1}, "shimmer.modules.losses.GWLosses2Domains.contrastive_loss": {"tf": 1}, "shimmer.modules.losses.GWLosses2Domains.step": {"tf": 1}, "shimmer.modules.losses.broadcast_loss": {"tf": 1}, "shimmer.modules.losses.GWLosses.contrastive_loss": {"tf": 1}, "shimmer.modules.contrastive_loss.ContrastiveLoss.forward": {"tf": 1}, "shimmer.modules.vae.kl_divergence_loss": {"tf": 1}, "shimmer.modules.vae.gaussian_nll": {"tf": 1}, "shimmer.modules.utils.batch_demi_cycles": {"tf": 1}, "shimmer.modules.utils.batch_cycles": {"tf": 1}, "shimmer.modules.utils.batch_translations": {"tf": 1}}, "df": 27}, "d": {"docs": {"shimmer.modules.global_workspace.GWPredictions.demi_cycles": {"tf": 1}, 
"shimmer.modules.global_workspace.GWPredictions.cycles": {"tf": 1}, "shimmer.modules.global_workspace.GWPredictions.translations": {"tf": 1}}, "df": 3}}, "a": {"docs": {}, "df": 0, "t": {"docs": {}, "df": 0, "i": {"docs": {}, "df": 0, "o": {"docs": {}, "df": 0, "n": {"docs": {"shimmer.modules.global_workspace.GlobalWorkspace2Domains.forward": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspace.forward": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBayesian.forward": {"tf": 1}, "shimmer.modules.domain.DomainModule.compute_loss": {"tf": 1}, "shimmer.modules.gw_module.GWEncoder.forward": {"tf": 1}, "shimmer.modules.gw_module.GWEncoderLinear.forward": {"tf": 1}, "shimmer.modules.losses.GWLosses.__init__": {"tf": 1}, "shimmer.modules.losses.GWLosses.step": {"tf": 1}, "shimmer.modules.losses.GWLossesBayesian.step": {"tf": 1}}, "df": 9}}}}}, "i": {"docs": {}, "df": 0, "n": {"docs": {}, "df": 0, "g": {"docs": {"shimmer.modules.losses.GWLosses.__init__": {"tf": 1}}, "df": 1}}}}}, "t": {"docs": {}, "df": 0, "u": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "s": {"docs": {"shimmer.modules.global_workspace.GlobalWorkspaceBase.batch_gw_states": {"tf": 1}}, "df": 1}}}}, "l": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "t": {"docs": {}, "df": 0, "i": {"docs": {}, "df": 0, "o": {"docs": {}, "df": 0, "n": {"docs": {"shimmer.modules.global_workspace.freeze_domain_modules": {"tf": 1}}, "df": 1}}}}}}, "e": {"docs": {}, "df": 0, "t": {"docs": {}, "df": 0, "i": {"docs": {}, "df": 0, "t": {"docs": {}, "df": 0, "i": {"docs": {}, "df": 0, "o": {"docs": {}, "df": 0, "n": {"docs": {"shimmer.modules.selection.SelectionBase": {"tf": 1}}, "df": 1}}}}}}}}, "m": {"docs": {}, "df": 0, "o": {"docs": {}, "df": 0, "n": {"docs": {"shimmer.modules.gw_module.GWModuleBase": {"tf": 1}}, "df": 1}}}, "b": {"docs": {}, "df": 0, "i": {"docs": {}, "df": 0, "n": {"docs": {}, "df": 0, "e": {"docs": {"shimmer.modules.gw_module.GWModuleBase.fuse": {"tf": 1}, 
"shimmer.modules.gw_module.GWModule.fuse": {"tf": 1}, "shimmer.modules.gw_module.compute_fusion_scores": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBayesian.fuse": {"tf": 1}}, "df": 4, "s": {"docs": {"shimmer.modules.gw_module.GWModuleBase.encode_and_fuse": {"tf": 1}}, "df": 1}, "d": {"docs": {"shimmer.modules.gw_module.compute_fusion_scores": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBayesian.fuse": {"tf": 1}}, "df": 2}}}}}}, "e": {"docs": {}, "df": 0, "f": {"docs": {"shimmer.modules.global_workspace.GlobalWorkspaceBayesian.__init__": {"tf": 1.4142135623730951}, "shimmer.modules.gw_module.GWModuleBayesian.__init__": {"tf": 1.4142135623730951}}, "df": 2, "s": {"docs": {"shimmer.modules.global_workspace.GlobalWorkspace2Domains.__init__": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspace.__init__": {"tf": 1.4142135623730951}, "shimmer.modules.global_workspace.GlobalWorkspaceBayesian.__init__": {"tf": 1.4142135623730951}, "shimmer.modules.global_workspace.pretrained_global_workspace": {"tf": 1}, "shimmer.modules.losses.GWLosses2Domains.__init__": {"tf": 1}, "shimmer.modules.losses.GWLosses.__init__": {"tf": 1.4142135623730951}, "shimmer.modules.losses.GWLossesBayesian.__init__": {"tf": 1.4142135623730951}}, "df": 7}, "f": {"docs": {}, "df": 0, "i": {"docs": {}, "df": 0, "c": {"docs": {}, "df": 0, "i": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "n": {"docs": {}, "df": 0, "t": {"docs": {"shimmer.modules.gw_module.GWModuleBayesian.fuse": {"tf": 1}, "shimmer.modules.selection.SelectionBase.forward": {"tf": 1}, "shimmer.modules.selection.RandomSelection.forward": {"tf": 1}, "shimmer.modules.losses.LossCoefs": {"tf": 1}, "shimmer.modules.losses.LossCoefs.demi_cycles": {"tf": 1}, "shimmer.modules.losses.LossCoefs.cycles": {"tf": 1}, "shimmer.modules.losses.LossCoefs.translations": {"tf": 1}, "shimmer.modules.losses.LossCoefs.contrastives": {"tf": 1}, "shimmer.modules.losses.BroadcastLossCoefs": {"tf": 1}, 
"shimmer.modules.losses.BroadcastLossCoefs.contrastives": {"tf": 1}, "shimmer.modules.losses.BroadcastLossCoefs.fused": {"tf": 1}, "shimmer.modules.losses.BroadcastLossCoefs.demi_cycles": {"tf": 1}, "shimmer.modules.losses.BroadcastLossCoefs.cycles": {"tf": 1}, "shimmer.modules.losses.BroadcastLossCoefs.translations": {"tf": 1}}, "df": 14, "s": {"docs": {"shimmer.modules.global_workspace.GlobalWorkspace2Domains.__init__": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBayesian.__init__": {"tf": 1}, "shimmer.modules.global_workspace.pretrained_global_workspace": {"tf": 1}, "shimmer.modules.selection.SelectionBase": {"tf": 1}, "shimmer.modules.selection.RandomSelection": {"tf": 1}, "shimmer.modules.selection.RandomSelection.forward": {"tf": 1}, "shimmer.modules.losses.LossCoefs": {"tf": 1}, "shimmer.modules.losses.GWLosses2Domains.__init__": {"tf": 1}, "shimmer.modules.losses.BroadcastLossCoefs": {"tf": 1}, "shimmer.modules.losses.GWLossesBayesian.__init__": {"tf": 1}, "shimmer.modules.losses.GWLossesBayesian.loss_coefs": {"tf": 1}}, "df": 11}}}}}}}}}}, "p": {"docs": {}, "df": 0, "y": {"docs": {"shimmer.modules.selection.SelectionBase.update_gw_state": {"tf": 1}}, "df": 1}}, "r": {"docs": {}, "df": 0, "r": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "c": {"docs": {}, "df": 0, "t": {"docs": {"shimmer.modules.losses.GWLosses2Domains.__init__": {"tf": 1}}, "df": 1}}}}, "e": {"docs": {"shimmer.utils.SaveMigrations.on_save_checkpoint": {"tf": 1}}, "df": 1}}, "u": {"docs": {}, "df": 0, "p": {"docs": {}, "df": 0, "l": {"docs": {}, "df": 0, "e": {"docs": {"shimmer.modules.utils.batch_cycles": {"tf": 1}, "shimmer.modules.utils.batch_translations": {"tf": 1}}, "df": 2}}}}}, "a": {"docs": {}, "df": 0, "l": {"docs": {}, "df": 0, "l": {"docs": {"shimmer.modules.gw_module.GWEncoder.forward": {"tf": 1.4142135623730951}, "shimmer.modules.gw_module.GWEncoderLinear.forward": {"tf": 1.4142135623730951}}, "df": 2, "i": {"docs": {}, "df": 0, "n": {"docs": {}, "df": 0, 
"g": {"docs": {"shimmer.modules.global_workspace.GWPredictionsBase": {"tf": 1}, "shimmer.modules.global_workspace.GWPredictions": {"tf": 1}}, "df": 2}}}, "b": {"docs": {}, "df": 0, "a": {"docs": {}, "df": 0, "c": {"docs": {}, "df": 0, "k": {"docs": {}, "df": 0, "s": {"docs": {"shimmer.utils.SaveMigrations": {"tf": 1}}, "df": 1}}}}}, "e": {"docs": {}, "df": 0, "d": {"docs": {"shimmer.utils.SaveMigrations.on_save_checkpoint": {"tf": 1}, "shimmer.cli.ckpt_migration.migrate_ckpt": {"tf": 1}}, "df": 2}}, "s": {"docs": {"shimmer.cli.ckpt_migration.migrate_ckpt": {"tf": 1}}, "df": 1}}}, "s": {"docs": {}, "df": 0, "t": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "d": {"docs": {"shimmer.modules.global_workspace.freeze_domain_modules": {"tf": 1}}, "df": 1}}}}, "r": {"docs": {}, "df": 0, "e": {"docs": {"shimmer.modules.gw_module.GWEncoder.forward": {"tf": 1}, "shimmer.modules.gw_module.GWEncoderLinear.forward": {"tf": 1}}, "df": 2}}, "n": {"docs": {"shimmer.modules.gw_module.GWModuleBayesian.fuse": {"tf": 1.4142135623730951}, "shimmer.modules.selection.SingleDomainSelection": {"tf": 1}, "shimmer.modules.losses.BroadcastLossCoefs.cycles": {"tf": 1}, "shimmer.modules.losses.BroadcastLossCoefs.translations": {"tf": 1}, "shimmer.cli.ckpt_migration.migrate_ckpt": {"tf": 1}}, "df": 5}}, "l": {"docs": {}, "df": 0, "a": {"docs": {}, "df": 0, "s": {"docs": {}, "df": 0, "s": {"docs": {"shimmer.modules.global_workspace.GlobalWorkspaceBase": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspace2Domains.__init__": {"tf": 1.4142135623730951}, "shimmer.modules.global_workspace.GlobalWorkspace.__init__": {"tf": 1.4142135623730951}, "shimmer.modules.global_workspace.GlobalWorkspaceBayesian.__init__": {"tf": 1.4142135623730951}, "shimmer.modules.global_workspace.pretrained_global_workspace": {"tf": 1.4142135623730951}, "shimmer.modules.domain.DomainModule": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBase": {"tf": 1.4142135623730951}, 
"shimmer.modules.gw_module.GWModule.__init__": {"tf": 1.4142135623730951}, "shimmer.modules.gw_module.GWModuleBayesian.__init__": {"tf": 1.4142135623730951}, "shimmer.modules.selection.SelectionBase": {"tf": 1}, "shimmer.modules.losses.GWLossesBase": {"tf": 1}, "shimmer.modules.vae.VAEEncoder": {"tf": 1}, "shimmer.modules.vae.VAEDecoder": {"tf": 1}, "shimmer.utils.SaveMigrations": {"tf": 1.4142135623730951}}, "df": 14}}}, "i": {"docs": {}, "df": 0, "p": {"docs": {"shimmer.modules.contrastive_loss.contrastive_loss": {"tf": 1}, "shimmer.modules.contrastive_loss.ContrastiveLoss": {"tf": 1}}, "df": 2}}}, "y": {"docs": {"shimmer.modules.losses.cycle_loss": {"tf": 1}}, "df": 1, "c": {"docs": {}, "df": 0, "l": {"docs": {}, "df": 0, "e": {"docs": {"shimmer.modules.global_workspace.GWPredictions.demi_cycles": {"tf": 1}, "shimmer.modules.global_workspace.GWPredictions.cycles": {"tf": 1}, "shimmer.modules.domain.DomainModule.compute_dcy_loss": {"tf": 1.4142135623730951}, "shimmer.modules.domain.DomainModule.compute_cy_loss": {"tf": 1.4142135623730951}, "shimmer.modules.losses.GWLossesBase": {"tf": 1.4142135623730951}, "shimmer.modules.losses.demi_cycle_loss": {"tf": 2.23606797749979}, "shimmer.modules.losses.cycle_loss": {"tf": 2.23606797749979}, "shimmer.modules.losses.LossCoefs.demi_cycles": {"tf": 1}, "shimmer.modules.losses.LossCoefs.cycles": {"tf": 1}, "shimmer.modules.losses.GWLosses2Domains.demi_cycle_loss": {"tf": 1.4142135623730951}, "shimmer.modules.losses.GWLosses2Domains.cycle_loss": {"tf": 1.4142135623730951}, "shimmer.modules.losses.GWLosses2Domains.step": {"tf": 2}, "shimmer.modules.losses.broadcast_loss": {"tf": 1.4142135623730951}, "shimmer.modules.utils.cycle": {"tf": 1.4142135623730951}, "shimmer.modules.utils.batch_cycles": {"tf": 1.4142135623730951}}, "df": 15, "s": {"docs": {"shimmer.modules.global_workspace.GlobalWorkspace2Domains.forward": {"tf": 1.4142135623730951}, "shimmer.modules.global_workspace.GlobalWorkspace.forward": {"tf": 
1.4142135623730951}, "shimmer.modules.global_workspace.GlobalWorkspaceBayesian.forward": {"tf": 1.4142135623730951}, "shimmer.modules.gw_module.GWModuleBase": {"tf": 1}, "shimmer.modules.losses.demi_cycle_loss": {"tf": 1}, "shimmer.modules.losses.cycle_loss": {"tf": 1}, "shimmer.modules.losses.BroadcastLossCoefs.demi_cycles": {"tf": 1.4142135623730951}, "shimmer.modules.losses.BroadcastLossCoefs.cycles": {"tf": 1.4142135623730951}, "shimmer.modules.losses.BroadcastLossCoefs.translations": {"tf": 1}, "shimmer.dataset.RepeatedDataset": {"tf": 1}, "shimmer.modules.utils.batch_demi_cycles": {"tf": 1.4142135623730951}, "shimmer.modules.utils.batch_cycles": {"tf": 1.4142135623730951}}, "df": 12}}, "i": {"docs": {}, "df": 0, "n": {"docs": {}, "df": 0, "g": {"docs": {"shimmer.modules.utils.cycle": {"tf": 1}}, "df": 1}}}}}}, "u": {"docs": {}, "df": 0, "s": {"docs": {}, "df": 0, "t": {"docs": {}, "df": 0, "o": {"docs": {}, "df": 0, "m": {"docs": {"shimmer.modules.global_workspace.GlobalWorkspace2Domains.__init__": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspace.__init__": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBayesian.__init__": {"tf": 1}, "shimmer.modules.global_workspace.pretrained_global_workspace": {"tf": 1}}, "df": 4}}}}, "r": {"docs": {}, "df": 0, "r": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "n": {"docs": {}, "df": 0, "t": {"docs": {"shimmer.utils.SaveMigrations.on_save_checkpoint": {"tf": 1.4142135623730951}}, "df": 1, "l": {"docs": {}, "df": 0, "y": {"docs": {"shimmer.modules.losses.GWLosses.step": {"tf": 1}, "shimmer.modules.losses.GWLossesBayesian.step": {"tf": 1}}, "df": 2}}}}}}}}, "h": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "c": {"docs": {}, "df": 0, "k": {"docs": {}, "df": 0, "p": {"docs": {}, "df": 0, "o": {"docs": {}, "df": 0, "i": {"docs": {}, "df": 0, "n": {"docs": {}, "df": 0, "t": {"docs": {"shimmer.modules.global_workspace.pretrained_global_workspace": {"tf": 2}, "shimmer.utils.migrate_model": {"tf": 
2}, "shimmer.utils.SaveMigrations.on_save_checkpoint": {"tf": 1.7320508075688772}}, "df": 3, "s": {"docs": {"shimmer.cli.ckpt_migration.migrate_ckpt": {"tf": 1.4142135623730951}}, "df": 1}}}}}}}}}, "a": {"docs": {}, "df": 0, "n": {"docs": {}, "df": 0, "c": {"docs": {}, "df": 0, "e": {"docs": {"shimmer.modules.selection.SingleDomainSelection": {"tf": 1}, "shimmer.utils.SaveMigrations.on_save_checkpoint": {"tf": 1}}, "df": 2}}}}}, "d": {"docs": {}, "df": 0, "o": {"docs": {}, "df": 0, "t": {"docs": {"shimmer.modules.gw_module.GWModuleBayesian.fuse": {"tf": 1.4142135623730951}}, "df": 1}}}, "k": {"docs": {}, "df": 0, "p": {"docs": {}, "df": 0, "t": {"docs": {"shimmer.utils.migrate_model": {"tf": 1}, "shimmer.cli.ckpt_migration.migrate_ckpt": {"tf": 1}}, "df": 2}}}}, "w": {"docs": {}, "df": 0, "h": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "n": {"docs": {"shimmer.types.ModelModeT": {"tf": 1}, "shimmer.modules.global_workspace.GWPredictionsBase": {"tf": 1}, "shimmer.modules.global_workspace.GWPredictions": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspace2Domains.__init__": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspace.__init__": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBayesian.__init__": {"tf": 1}, "shimmer.dataset.RepeatedDataset.__init__": {"tf": 1}, "shimmer.utils.SaveMigrations.on_save_checkpoint": {"tf": 1}}, "df": 8}, "t": {"docs": {}, "df": 0, "h": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "r": {"docs": {"shimmer.modules.global_workspace.GlobalWorkspace2Domains.__init__": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspace.__init__": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBayesian.__init__": {"tf": 1.4142135623730951}, "shimmer.modules.selection.SingleDomainSelection.forward": {"tf": 1}, "shimmer.modules.selection.FixedSharedSelection.forward": {"tf": 1}, "shimmer.modules.losses.GWLossesBayesian.__init__": {"tf": 1}, 
"shimmer.modules.contrastive_loss.ContrastiveLoss.__init__": {"tf": 1}, "shimmer.dataset.RepeatedDataset.__init__": {"tf": 1}}, "df": 8}}}}, "r": {"docs": {}, "df": 0, "e": {"docs": {"shimmer.modules.gw_module.GWModuleBayesian.fuse": {"tf": 2}, "shimmer.modules.losses.GWLosses2Domains.__init__": {"tf": 1}, "shimmer.modules.losses.GWLossesBayesian.__init__": {"tf": 1}, "shimmer.dataset.RepeatedDataset.__init__": {"tf": 1}, "shimmer.cli.ckpt_migration.migrate_ckpt": {"tf": 1}}, "df": 5}}}, "i": {"docs": {}, "df": 0, "c": {"docs": {}, "df": 0, "h": {"docs": {"shimmer.modules.global_workspace.GlobalWorkspace2Domains.__init__": {"tf": 1.4142135623730951}, "shimmer.modules.global_workspace.GlobalWorkspace.__init__": {"tf": 1.4142135623730951}, "shimmer.modules.global_workspace.GlobalWorkspaceBayesian.__init__": {"tf": 1.4142135623730951}, "shimmer.modules.global_workspace.pretrained_global_workspace": {"tf": 1.4142135623730951}, "shimmer.modules.gw_module.GWModuleBayesian.fuse": {"tf": 1}, "shimmer.modules.losses.GWLosses.step": {"tf": 1}, "shimmer.modules.losses.GWLossesBayesian.step": {"tf": 1}}, "df": 7}}, "l": {"docs": {}, "df": 0, "e": {"docs": {"shimmer.modules.gw_module.GWEncoder.forward": {"tf": 1}, "shimmer.modules.gw_module.GWEncoderLinear.forward": {"tf": 1}}, "df": 2}}}, "a": {"docs": {}, "df": 0, "t": {"docs": {"shimmer.modules.domain.LossOutput": {"tf": 1.4142135623730951}}, "df": 1}}}, "i": {"docs": {}, "df": 0, "t": {"docs": {}, "df": 0, "h": {"docs": {"shimmer.modules.global_workspace.GWPredictionsBase.states": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.encode_domains": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBase.decode_domains": {"tf": 1}, "shimmer.modules.global_workspace.GWPredictions.demi_cycles": {"tf": 1}, "shimmer.modules.global_workspace.GWPredictions.cycles": {"tf": 1.4142135623730951}, "shimmer.modules.global_workspace.GWPredictions.translations": {"tf": 1.4142135623730951}, 
"shimmer.modules.global_workspace.GlobalWorkspace": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBayesian": {"tf": 1}, "shimmer.modules.domain.LossOutput.all": {"tf": 1.4142135623730951}, "shimmer.modules.domain.DomainModule.compute_loss": {"tf": 1}, "shimmer.modules.domain.DomainModule.compute_dcy_loss": {"tf": 1}, "shimmer.modules.domain.DomainModule.compute_cy_loss": {"tf": 1}, "shimmer.modules.domain.DomainModule.compute_tr_loss": {"tf": 1}, "shimmer.modules.domain.DomainModule.compute_broadcast_loss": {"tf": 1}, "shimmer.modules.gw_module.get_n_layers": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBayesian": {"tf": 1}, "shimmer.modules.selection.SingleDomainSelection": {"tf": 1}, "shimmer.modules.selection.KQFixedQSelection": {"tf": 1}, "shimmer.modules.selection.KQFixedQSelection.forward": {"tf": 1.7320508075688772}, "shimmer.modules.selection.DynamicQueryAttention": {"tf": 1}, "shimmer.modules.selection.DynamicQueryAttention.forward": {"tf": 1.7320508075688772}, "shimmer.modules.losses.demi_cycle_loss": {"tf": 1.7320508075688772}, "shimmer.modules.losses.cycle_loss": {"tf": 1.7320508075688772}, "shimmer.modules.losses.translation_loss": {"tf": 1.7320508075688772}, "shimmer.modules.losses.contrastive_loss": {"tf": 1.7320508075688772}, "shimmer.modules.losses.contrastive_loss_bayesian": {"tf": 2}, "shimmer.modules.losses.GWLosses2Domains.__init__": {"tf": 1.4142135623730951}, "shimmer.modules.losses.broadcast_loss": {"tf": 1}, "shimmer.modules.losses.GWLossesBayesian.__init__": {"tf": 1.4142135623730951}, "shimmer.modules.vae.reparameterize": {"tf": 1}, "shimmer.modules.vae.VAEEncoder.forward": {"tf": 1}, "shimmer.modules.vae.VAEDecoder.forward": {"tf": 1}, "shimmer.cli.ckpt_migration.migrate_ckpt": {"tf": 1}}, "df": 33, "i": {"docs": {}, "df": 0, "n": {"docs": {"shimmer.modules.gw_module.GWEncoder.forward": {"tf": 1}, "shimmer.modules.gw_module.GWEncoderLinear.forward": {"tf": 1}}, "df": 2}}}}, "l": {"docs": {}, "df": 0, "l": {"docs": 
{"shimmer.modules.global_workspace.GlobalWorkspace2Domains.__init__": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspace.__init__": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBayesian.__init__": {"tf": 1}, "shimmer.modules.global_workspace.pretrained_global_workspace": {"tf": 1}, "shimmer.modules.gw_module.GWDecoder.__init__": {"tf": 1}, "shimmer.modules.gw_module.GWDecoder.n_layers": {"tf": 1}, "shimmer.modules.gw_module.GWEncoder.__init__": {"tf": 1}, "shimmer.modules.selection.SelectionBase.update_gw_state": {"tf": 1}, "shimmer.modules.losses.LossCoefs": {"tf": 1.4142135623730951}, "shimmer.modules.losses.BroadcastLossCoefs": {"tf": 1.4142135623730951}, "shimmer.dataset.RepeatedDataset": {"tf": 1}, "shimmer.modules.utils.batch_cycles": {"tf": 1}, "shimmer.utils.migrate_model": {"tf": 1.4142135623730951}, "shimmer.utils.SaveMigrations.on_save_checkpoint": {"tf": 1}}, "df": 14}}, "s": {"docs": {}, "df": 0, "e": {"docs": {"shimmer.modules.selection.RandomSelection": {"tf": 1}, "shimmer.modules.selection.RandomSelection.forward": {"tf": 1}}, "df": 2}}}, "o": {"docs": {}, "df": 0, "r": {"docs": {}, "df": 0, "k": {"docs": {}, "df": 0, "s": {"docs": {}, "df": 0, "p": {"docs": {}, "df": 0, "a": {"docs": {}, "df": 0, "c": {"docs": {}, "df": 0, "e": {"docs": {"shimmer.modules.global_workspace.GlobalWorkspaceBase": {"tf": 1.4142135623730951}, "shimmer.modules.global_workspace.GlobalWorkspace2Domains": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspace2Domains.__init__": {"tf": 1.4142135623730951}, "shimmer.modules.global_workspace.GlobalWorkspace": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspace.__init__": {"tf": 1.4142135623730951}, "shimmer.modules.global_workspace.GlobalWorkspaceBayesian": {"tf": 1}, "shimmer.modules.global_workspace.GlobalWorkspaceBayesian.__init__": {"tf": 1.4142135623730951}, "shimmer.modules.global_workspace.pretrained_global_workspace": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBase.__init__": 
{"tf": 1}, "shimmer.modules.gw_module.GWModule.__init__": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBayesian.__init__": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBayesian.fuse": {"tf": 1}, "shimmer.modules.losses.GWLossesBase": {"tf": 1}, "shimmer.modules.losses.GWLosses.__init__": {"tf": 1.4142135623730951}}, "df": 14}}}}}}}}, "e": {"docs": {"shimmer.modules.gw_module.GWModuleBayesian.fuse": {"tf": 1.4142135623730951}}, "df": 1, "i": {"docs": {}, "df": 0, "g": {"docs": {}, "df": 0, "h": {"docs": {}, "df": 0, "t": {"docs": {"shimmer.modules.global_workspace.GlobalWorkspace2Domains.__init__": {"tf": 1.4142135623730951}, "shimmer.modules.global_workspace.GlobalWorkspace.__init__": {"tf": 1.4142135623730951}, "shimmer.modules.global_workspace.GlobalWorkspaceBayesian.__init__": {"tf": 1.4142135623730951}}, "df": 3, "s": {"docs": {"shimmer.modules.global_workspace.freeze_domain_modules": {"tf": 1}, "shimmer.modules.selection.FixedSharedSelection": {"tf": 1}}, "df": 2}, "e": {"docs": {}, "df": 0, "d": {"docs": {"shimmer.modules.selection.DynamicQueryAttention.fuse_weighted_encodings": {"tf": 1}}, "df": 1}}}}}}}, "a": {"docs": {}, "df": 0, "n": {"docs": {}, "df": 0, "t": {"docs": {"shimmer.utils.SaveMigrations.on_save_checkpoint": {"tf": 1}}, "df": 1}}}}, "z": {"docs": {"shimmer.modules.global_workspace.GlobalWorkspaceBase.decode": {"tf": 1}, "shimmer.modules.domain.DomainModule.decode": {"tf": 1}, "shimmer.modules.gw_module.GWModuleBase.decode": {"tf": 1}, "shimmer.modules.gw_module.GWModule.decode": {"tf": 1}, "shimmer.modules.vae.VAE.decode": {"tf": 1}}, "df": 5, "e": {"docs": {}, "df": 0, "r": {"docs": {}, "df": 0, "o": {"docs": {}, "df": 0, "s": {"docs": {"shimmer.modules.losses.generate_partitions": {"tf": 2}}, "df": 1}}}}}, "h": {"docs": {}, "df": 0, "i": {"docs": {}, "df": 0, "d": {"docs": {}, "df": 0, "d": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "n": {"docs": {"shimmer.modules.gw_module.get_n_layers": {"tf": 1.4142135623730951}, 
"shimmer.modules.gw_module.GWDecoder.__init__": {"tf": 1.7320508075688772}, "shimmer.modules.gw_module.GWDecoder.hidden_dim": {"tf": 1}, "shimmer.modules.gw_module.GWDecoder.n_layers": {"tf": 1}, "shimmer.modules.gw_module.GWEncoder.__init__": {"tf": 1.7320508075688772}}, "df": 5}}}}}, "o": {"docs": {}, "df": 0, "o": {"docs": {}, "df": 0, "k": {"docs": {}, "df": 0, "s": {"docs": {"shimmer.modules.gw_module.GWEncoder.forward": {"tf": 1}, "shimmer.modules.gw_module.GWEncoderLinear.forward": {"tf": 1}, "shimmer.utils.SaveMigrations": {"tf": 1}}, "df": 3}}}}, "a": {"docs": {}, "df": 0, "n": {"docs": {}, "df": 0, "d": {"docs": {}, "df": 0, "l": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "s": {"docs": {"shimmer.modules.gw_module.GWModuleBase": {"tf": 1}, "shimmer.modules.selection.SelectionBase": {"tf": 1}, "shimmer.modules.selection.SingleDomainSelection": {"tf": 1}}, "df": 3}}}}}, "v": {"docs": {}, "df": 0, "e": {"docs": {"shimmer.modules.gw_module.compute_fusion_scores": {"tf": 1}, "shimmer.modules.selection.SingleDomainSelection": {"tf": 1}, "shimmer.dataset.RepeatedDataset": {"tf": 1}, "shimmer.dataset.RepeatedDataset.__init__": {"tf": 1}}, "df": 4}}, "s": {"docs": {"shimmer.modules.selection.SingleDomainSelection": {"tf": 1}}, "df": 1}}, "e": {"docs": {}, "df": 0, "a": {"docs": {}, "df": 0, "d": {"docs": {"shimmer.modules.selection.KQFixedQSelection.__init__": {"tf": 1}, "shimmer.modules.selection.DynamicQueryAttention.__init__": {"tf": 1}}, "df": 2}}}}, "y": {"docs": {"shimmer.modules.contrastive_loss.info_nce": {"tf": 1}, "shimmer.modules.contrastive_loss.contrastive_loss": {"tf": 1}, "shimmer.modules.contrastive_loss.ContrastiveLoss.forward": {"tf": 1}}, "df": 3, "i": {"docs": {}, "df": 0, "e": {"docs": {}, "df": 0, "l": {"docs": {}, "df": 0, "d": {"docs": {}, "df": 0, "s": {"docs": {"shimmer.modules.gw_module.GWModuleBayesian.fuse": {"tf": 1}, "shimmer.modules.losses.generate_partitions": {"tf": 1}}, "df": 2}}}}}, "o": {"docs": {}, "df": 0, "u": {"docs": 
{"shimmer.utils.SaveMigrations.on_save_checkpoint": {"tf": 1.4142135623730951}}, "df": 1}}}}}}, "pipeline": ["trimmer"], "_isPrebuiltIndex": true}; + + // mirrored in build-search-index.js (part 1) + // Also split on html tags. this is a cheap heuristic, but good enough. + elasticlunr.tokenizer.setSeperator(/[\s\-.;&_'"=,()]+|<[^>]*>/); + + let searchIndex; + if (docs._isPrebuiltIndex) { + console.info("using precompiled search index"); + searchIndex = elasticlunr.Index.load(docs); + } else { + console.time("building search index"); + // mirrored in build-search-index.js (part 2) + searchIndex = elasticlunr(function () { + this.pipeline.remove(elasticlunr.stemmer); + this.pipeline.remove(elasticlunr.stopWordFilter); + this.addField("qualname"); + this.addField("fullname"); + this.addField("annotation"); + this.addField("default_value"); + this.addField("signature"); + this.addField("bases"); + this.addField("doc"); + this.setRef("fullname"); + }); + for (let doc of docs) { + searchIndex.addDoc(doc); + } + console.timeEnd("building search index"); + } + + return (term) => searchIndex.search(term, { + fields: { + qualname: {boost: 4}, + fullname: {boost: 2}, + annotation: {boost: 2}, + default_value: {boost: 2}, + signature: {boost: 2}, + bases: {boost: 2}, + doc: {boost: 1}, + }, + expand: true + }); +})(); \ No newline at end of file diff --git a/docs/api/v0.5.1/shimmer/cli/ckpt_migration.html b/docs/api/v0.5.1/shimmer/cli/ckpt_migration.html new file mode 100644 index 00000000..a6f4201a --- /dev/null +++ b/docs/api/v0.5.1/shimmer/cli/ckpt_migration.html @@ -0,0 +1,319 @@ + + + + + + + shimmer.cli.ckpt_migration API documentation + + + + + + + + + + + + + +
+
+

+shimmer.cli.ckpt_migration

+ + + + + + +
 1from collections.abc import Sequence
+ 2from pathlib import Path
+ 3
+ 4import click
+ 5
+ 6from shimmer.utils import migrate_model
+ 7
+ 8
+ 9@click.command("migrate-ckpt")
+10@click.argument(
+11    "paths",
+12    nargs=-1,
+13    type=click.Path(exists=True, path_type=Path, file_okay=True, dir_okay=False),
+14)
+15def migrate_ckpt(paths: Sequence[Path]):
+16    """
+17    Script to migrate a list of checkpoints.
+18    This can be called with:
+19    ```sh
+20    shimmer migrate-ckpt PATH_1 PATH_2 ... PATH_N
+21    ```
+22    where paths point to checkpoints.
+23
+24    Internally, this calls `shimmer.utils.migrate_model` for each of the given paths.
+25    """
+26    for path in paths:
+27        migrate_model(path)
+
+ + +
+
+
+ migrate_ckpt = +<Command migrate-ckpt> + + +
+ + +

Script to migrate a list of checkpoints. +This can be called with:

+ +
+
shimmer migrate-ckpt PATH_1 PATH_2 ... PATH_N
+
+
+ +

where paths point to checkpoints.

+ +

Internally, this calls shimmer.utils.migrate_model for each of the given paths.

+
+ + +
+
+ + \ No newline at end of file diff --git a/docs/api/v0.5.1/shimmer/dataset.html b/docs/api/v0.5.1/shimmer/dataset.html new file mode 100644 index 00000000..1ba30a7d --- /dev/null +++ b/docs/api/v0.5.1/shimmer/dataset.html @@ -0,0 +1,448 @@ + + + + + + + shimmer.dataset API documentation + + + + + + + + + + + + + +
+
+

+shimmer.dataset

+ + + + + + +
 1from typing import Any, Protocol
+ 2
+ 3from torch.utils.data import Dataset
+ 4
+ 5
+ 6class _SizedDataset(Protocol):
+ 7    def __getitem__(self, k: int) -> Any: ...
+ 8
+ 9    def __len__(self) -> int: ...
+10
+11
+12class RepeatedDataset(Dataset):
+13    """
+14    Dataset that cycles through its items to have a size of at least min size.
+15    If drop_last is True, the size will be exaclty min_size. If drop_last is False,
+16    the min_size ≤ size < min_size + len(dataset).
+17    """
+18
+19    def __init__(self, dataset: _SizedDataset, min_size: int, drop_last: bool = False):
+20        """
+21        Args:
+22            dataset (SizedDataset): dataset to repeat. The dataset should have a size
+23                (where `__len__` is defined).
+24            min_size (int): minimum size of the final dataset
+25            drop_last (bool): whether to remove overflow when repeating the
+26                dataset.
+27        """
+28        self.dataset = dataset
+29        assert min_size >= len(self.dataset)
+30        self.dataset_size = len(self.dataset)
+31        if drop_last:
+32            self.total_size = min_size
+33        else:
+34            self.total_size = (
+35                min_size // self.dataset_size + int(min_size % self.dataset_size > 0)
+36            ) * self.dataset_size
+37
+38    def __len__(self) -> int:
+39        """
+40        Size of the dataset. Will be min_size if drop_last is True.
+41        Otherwise, min_size ≤ size < min_size + len(dataset).
+42        """
+43        return self.total_size
+44
+45    def __getitem__(self, index: int) -> Any:
+46        return self.dataset[index % self.dataset_size]
+
+ + +
+
+ +
+ + class + RepeatedDataset(typing.Generic[+T_co]): + + + +
+ +
13class RepeatedDataset(Dataset):
+14    """
+15    Dataset that cycles through its items to have a size of at least min size.
+16    If drop_last is True, the size will be exaclty min_size. If drop_last is False,
+17    the min_size ≤ size < min_size + len(dataset).
+18    """
+19
+20    def __init__(self, dataset: _SizedDataset, min_size: int, drop_last: bool = False):
+21        """
+22        Args:
+23            dataset (SizedDataset): dataset to repeat. The dataset should have a size
+24                (where `__len__` is defined).
+25            min_size (int): minimum size of the final dataset
+26            drop_last (bool): whether to remove overflow when repeating the
+27                dataset.
+28        """
+29        self.dataset = dataset
+30        assert min_size >= len(self.dataset)
+31        self.dataset_size = len(self.dataset)
+32        if drop_last:
+33            self.total_size = min_size
+34        else:
+35            self.total_size = (
+36                min_size // self.dataset_size + int(min_size % self.dataset_size > 0)
+37            ) * self.dataset_size
+38
+39    def __len__(self) -> int:
+40        """
+41        Size of the dataset. Will be min_size if drop_last is True.
+42        Otherwise, min_size ≤ size < min_size + len(dataset).
+43        """
+44        return self.total_size
+45
+46    def __getitem__(self, index: int) -> Any:
+47        return self.dataset[index % self.dataset_size]
+
+ + +

Dataset that cycles through its items to have a size of at least min size. +If drop_last is True, the size will be exaclty min_size. If drop_last is False, +the min_size ≤ size < min_size + len(dataset).

+
+ + +
+ +
+ + RepeatedDataset( dataset: shimmer.dataset._SizedDataset, min_size: int, drop_last: bool = False) + + + +
+ +
20    def __init__(self, dataset: _SizedDataset, min_size: int, drop_last: bool = False):
+21        """
+22        Args:
+23            dataset (SizedDataset): dataset to repeat. The dataset should have a size
+24                (where `__len__` is defined).
+25            min_size (int): minimum size of the final dataset
+26            drop_last (bool): whether to remove overflow when repeating the
+27                dataset.
+28        """
+29        self.dataset = dataset
+30        assert min_size >= len(self.dataset)
+31        self.dataset_size = len(self.dataset)
+32        if drop_last:
+33            self.total_size = min_size
+34        else:
+35            self.total_size = (
+36                min_size // self.dataset_size + int(min_size % self.dataset_size > 0)
+37            ) * self.dataset_size
+
+ + +
Arguments:
+ +
    +
  • dataset (SizedDataset): dataset to repeat. The dataset should have a size +(where __len__ is defined).
  • +
  • min_size (int): minimum size of the final dataset
  • +
  • drop_last (bool): whether to remove overflow when repeating the +dataset.
  • +
+
+ + +
+
+
+ dataset + + +
+ + + + +
+
+
+ dataset_size + + +
+ + + + +
+
+
+ + \ No newline at end of file diff --git a/docs/api/v0.5.1/shimmer/modules/contrastive_loss.html b/docs/api/v0.5.1/shimmer/modules/contrastive_loss.html new file mode 100644 index 00000000..55fb2e87 --- /dev/null +++ b/docs/api/v0.5.1/shimmer/modules/contrastive_loss.html @@ -0,0 +1,797 @@ + + + + + + + shimmer.modules.contrastive_loss API documentation + + + + + + + + + + + + + +
+
+

+shimmer.modules.contrastive_loss

+ +

Various contrastive loss definitions

+
+ + + + + +
  1"""Various contrastive loss definitions"""
+  2
+  3from collections.abc import Callable
+  4from typing import Literal
+  5
+  6import torch
+  7from torch.nn.functional import cross_entropy, normalize
+  8
+  9from shimmer.modules.domain import LossOutput
+ 10
+ 11ContrastiveLossType = Callable[[torch.Tensor, torch.Tensor], LossOutput]
+ 12"""
+ 13Contrastive loss function type.
+ 14
+ 15A function taking the prediction and targets and returning a LossOutput.
+ 16"""
+ 17
+ 18ContrastiveLossBayesianType = Callable[
+ 19    [torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor], LossOutput
+ 20]
+ 21"""
+ 22Contrastive loss function type for GlobalWorkspaceBayesian.
+ 23
+ 24A function taking the prediction mean, prediction std, target mean and target std and
+ 25    returns a LossOutput.
+ 26"""
+ 27
+ 28
+ 29def info_nce(
+ 30    x: torch.Tensor,
+ 31    y: torch.Tensor,
+ 32    logit_scale: torch.Tensor,
+ 33    reduction: Literal["mean", "sum", "none"] = "mean",
+ 34) -> torch.Tensor:
+ 35    """
+ 36    InfoNCE loss
+ 37
+ 38    Args:
+ 39        x (`torch.Tensor`): prediction
+ 40        y (`torch.Tensor`): target
+ 41        logit_scale (`torch.Tensor`): logit scale
+ 42        reduction (`Literal["mean", "sum", "none"]`): reduction to apply
+ 43
+ 44    Returns: the InfoNCE loss
+ 45    """
+ 46    xn = normalize(x)
+ 47    yn = normalize(y)
+ 48    logits = torch.clamp(logit_scale.exp(), max=100) * xn @ yn.t()
+ 49    labels = torch.arange(xn.size(0)).to(logits.device)
+ 50    return cross_entropy(logits, labels, reduction=reduction)
+ 51
+ 52
+ 53def contrastive_loss(
+ 54    x: torch.Tensor,
+ 55    y: torch.Tensor,
+ 56    logit_scale: torch.Tensor,
+ 57    reduction: Literal["mean", "sum", "none"] = "mean",
+ 58) -> torch.Tensor:
+ 59    """
+ 60    CLIP-like contrastive loss
+ 61
+ 62    Args:
+ 63        x (`torch.Tensor`): prediction
+ 64        y (`torch.Tensor`): target
+ 65        logit_scale (`torch.Tensor`): logit scale
+ 66        reduction (`Literal["mean", "sum", "none"]`): reduction to apply
+ 67
+ 68    Returns: the contrastive loss
+ 69    """
+ 70    xn = normalize(x)
+ 71    yn = normalize(y)
+ 72    logits = torch.clamp(logit_scale.exp(), max=100) * xn @ yn.t()
+ 73    labels = torch.arange(xn.size(0)).to(logits.device)
+ 74    ce = cross_entropy(logits, labels, reduction=reduction)
+ 75    ce_t = cross_entropy(logits.t(), labels, reduction=reduction)
+ 76    return 0.5 * (ce + ce_t)
+ 77
+ 78
+ 79class ContrastiveLoss(torch.nn.Module):
+ 80    """CLIP-like ContrastiveLoss torch module."""
+ 81
+ 82    def __init__(
+ 83        self,
+ 84        logit_scale: torch.Tensor,
+ 85        reduction: Literal["mean", "sum", "none"] = "mean",
+ 86        learn_logit_scale: bool = False,
+ 87    ) -> None:
+ 88        """
+ 89        Initializes a contrastive loss.
+ 90
+ 91        Args:
+ 92            logit_scale (`torch.Tensor`): logit_scale tensor.
+ 93            reduction (`Literal["mean", "sum", "none"]`): reduction to apply to the
+ 94                loss. Defaults to `"mean"`.
+ 95            learn_logit_scale (`torch.Tensor`): whether to learn the `logit_scale`
+ 96                parameter. Defaults to `False`.
+ 97        """
+ 98        super().__init__()
+ 99
+100        if learn_logit_scale:
+101            self.logit_scale = torch.nn.Parameter(logit_scale)
+102        else:
+103            self.register_buffer("logit_scale", logit_scale)
+104        self.learn_logit_scale = learn_logit_scale
+105        self.reduction: Literal["mean", "sum", "none"] = reduction
+106
+107    def forward(self, x: torch.Tensor, y: torch.Tensor) -> LossOutput:
+108        """
+109        Computes the loss.
+110
+111        Args:
+112            x (`torch.Tensor`): prediction
+113            y (`torch.Tensor`): target
+114
+115        Returns:
+116            LossOutput of the loss. Contains a `logit_scale` metric.
+117        """
+118        return LossOutput(
+119            contrastive_loss(x, y, self.logit_scale, self.reduction),
+120            {"logit_scale": self.logit_scale.exp()},
+121        )
+
+ + +
+
+
+ ContrastiveLossType = +collections.abc.Callable[[torch.Tensor, torch.Tensor], shimmer.modules.domain.LossOutput] + + +
+ + +

Contrastive loss function type.

+ +

A function taking the prediction and targets and returning a LossOutput.

+
+ + +
+
+
+ ContrastiveLossBayesianType = + + collections.abc.Callable[[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor], shimmer.modules.domain.LossOutput] + + +
+ + +

Contrastive loss function type for GlobalWorkspaceBayesian.

+ +

A function taking the prediction mean, prediction std, target mean and target std and + returns a LossOutput.

+
+ + +
+
+ +
+ + def + info_nce( x: torch.Tensor, y: torch.Tensor, logit_scale: torch.Tensor, reduction: Literal['mean', 'sum', 'none'] = 'mean') -> torch.Tensor: + + + +
+ +
30def info_nce(
+31    x: torch.Tensor,
+32    y: torch.Tensor,
+33    logit_scale: torch.Tensor,
+34    reduction: Literal["mean", "sum", "none"] = "mean",
+35) -> torch.Tensor:
+36    """
+37    InfoNCE loss
+38
+39    Args:
+40        x (`torch.Tensor`): prediction
+41        y (`torch.Tensor`): target
+42        logit_scale (`torch.Tensor`): logit scale
+43        reduction (`Literal["mean", "sum", "none"]`): reduction to apply
+44
+45    Returns: the InfoNCE loss
+46    """
+47    xn = normalize(x)
+48    yn = normalize(y)
+49    logits = torch.clamp(logit_scale.exp(), max=100) * xn @ yn.t()
+50    labels = torch.arange(xn.size(0)).to(logits.device)
+51    return cross_entropy(logits, labels, reduction=reduction)
+
+ + +

InfoNCE loss

+ +
Arguments:
+ +
    +
  • x (torch.Tensor): prediction
  • +
  • y (torch.Tensor): target
  • +
  • logit_scale (torch.Tensor): logit scale
  • +
  • reduction (Literal["mean", "sum", "none"]): reduction to apply
  • +
+ +

Returns: the InfoNCE loss

+
+ + +
+
+ +
+ + def + contrastive_loss( x: torch.Tensor, y: torch.Tensor, logit_scale: torch.Tensor, reduction: Literal['mean', 'sum', 'none'] = 'mean') -> torch.Tensor: + + + +
+ +
54def contrastive_loss(
+55    x: torch.Tensor,
+56    y: torch.Tensor,
+57    logit_scale: torch.Tensor,
+58    reduction: Literal["mean", "sum", "none"] = "mean",
+59) -> torch.Tensor:
+60    """
+61    CLIP-like contrastive loss
+62
+63    Args:
+64        x (`torch.Tensor`): prediction
+65        y (`torch.Tensor`): target
+66        logit_scale (`torch.Tensor`): logit scale
+67        reduction (`Literal["mean", "sum", "none"]`): reduction to apply
+68
+69    Returns: the contrastive loss
+70    """
+71    xn = normalize(x)
+72    yn = normalize(y)
+73    logits = torch.clamp(logit_scale.exp(), max=100) * xn @ yn.t()
+74    labels = torch.arange(xn.size(0)).to(logits.device)
+75    ce = cross_entropy(logits, labels, reduction=reduction)
+76    ce_t = cross_entropy(logits.t(), labels, reduction=reduction)
+77    return 0.5 * (ce + ce_t)
+
+ + +

CLIP-like contrastive loss

+ +
Arguments:
+ +
    +
  • x (torch.Tensor): prediction
  • +
  • y (torch.Tensor): target
  • +
  • logit_scale (torch.Tensor): logit scale
  • +
  • reduction (Literal["mean", "sum", "none"]): reduction to apply
  • +
+ +

Returns: the contrastive loss

+
+ + +
+
+ +
+ + class + ContrastiveLoss(torch.nn.modules.module.Module): + + + +
+ +
 80class ContrastiveLoss(torch.nn.Module):
+ 81    """CLIP-like ContrastiveLoss torch module."""
+ 82
+ 83    def __init__(
+ 84        self,
+ 85        logit_scale: torch.Tensor,
+ 86        reduction: Literal["mean", "sum", "none"] = "mean",
+ 87        learn_logit_scale: bool = False,
+ 88    ) -> None:
+ 89        """
+ 90        Initializes a contrastive loss.
+ 91
+ 92        Args:
+ 93            logit_scale (`torch.Tensor`): logit_scale tensor.
+ 94            reduction (`Literal["mean", "sum", "none"]`): reduction to apply to the
+ 95                loss. Defaults to `"mean"`.
+ 96            learn_logit_scale (`torch.Tensor`): whether to learn the `logit_scale`
+ 97                parameter. Defaults to `False`.
+ 98        """
+ 99        super().__init__()
+100
+101        if learn_logit_scale:
+102            self.logit_scale = torch.nn.Parameter(logit_scale)
+103        else:
+104            self.register_buffer("logit_scale", logit_scale)
+105        self.learn_logit_scale = learn_logit_scale
+106        self.reduction: Literal["mean", "sum", "none"] = reduction
+107
+108    def forward(self, x: torch.Tensor, y: torch.Tensor) -> LossOutput:
+109        """
+110        Computes the loss.
+111
+112        Args:
+113            x (`torch.Tensor`): prediction
+114            y (`torch.Tensor`): target
+115
+116        Returns:
+117            LossOutput of the loss. Contains a `logit_scale` metric.
+118        """
+119        return LossOutput(
+120            contrastive_loss(x, y, self.logit_scale, self.reduction),
+121            {"logit_scale": self.logit_scale.exp()},
+122        )
+
+ + +

CLIP-like ContrastiveLoss torch module.

+
+ + +
+ +
+ + ContrastiveLoss( logit_scale: torch.Tensor, reduction: Literal['mean', 'sum', 'none'] = 'mean', learn_logit_scale: bool = False) + + + +
+ +
 83    def __init__(
+ 84        self,
+ 85        logit_scale: torch.Tensor,
+ 86        reduction: Literal["mean", "sum", "none"] = "mean",
+ 87        learn_logit_scale: bool = False,
+ 88    ) -> None:
+ 89        """
+ 90        Initializes a contrastive loss.
+ 91
+ 92        Args:
+ 93            logit_scale (`torch.Tensor`): logit_scale tensor.
+ 94            reduction (`Literal["mean", "sum", "none"]`): reduction to apply to the
+ 95                loss. Defaults to `"mean"`.
+ 96            learn_logit_scale (`torch.Tensor`): whether to learn the `logit_scale`
+ 97                parameter. Defaults to `False`.
+ 98        """
+ 99        super().__init__()
+100
+101        if learn_logit_scale:
+102            self.logit_scale = torch.nn.Parameter(logit_scale)
+103        else:
+104            self.register_buffer("logit_scale", logit_scale)
+105        self.learn_logit_scale = learn_logit_scale
+106        self.reduction: Literal["mean", "sum", "none"] = reduction
+
+ + +

Initializes a contrastive loss.

+ +
Arguments:
+ +
    +
  • logit_scale (torch.Tensor): logit_scale tensor.
  • +
  • reduction (Literal["mean", "sum", "none"]): reduction to apply to the +loss. Defaults to "mean".
  • +
  • learn_logit_scale (torch.Tensor): whether to learn the logit_scale +parameter. Defaults to False.
  • +
+
+ + +
+
+
+ learn_logit_scale + + +
+ + + + +
+
+
+ reduction: Literal['mean', 'sum', 'none'] + + +
+ + + + +
+
+ +
+ + def + forward( self, x: torch.Tensor, y: torch.Tensor) -> shimmer.modules.domain.LossOutput: + + + +
+ +
108    def forward(self, x: torch.Tensor, y: torch.Tensor) -> LossOutput:
+109        """
+110        Computes the loss.
+111
+112        Args:
+113            x (`torch.Tensor`): prediction
+114            y (`torch.Tensor`): target
+115
+116        Returns:
+117            LossOutput of the loss. Contains a `logit_scale` metric.
+118        """
+119        return LossOutput(
+120            contrastive_loss(x, y, self.logit_scale, self.reduction),
+121            {"logit_scale": self.logit_scale.exp()},
+122        )
+
+ + +

Computes the loss.

+ +
Arguments:
+ +
    +
  • x (torch.Tensor): prediction
  • +
  • y (torch.Tensor): target
  • +
+ +
Returns:
+ +
+

LossOutput of the loss. Contains a logit_scale metric.

+
+
+ + +
+
+
Inherited Members
+
+
torch.nn.modules.module.Module
+
dump_patches
+
training
+
call_super_init
+
register_buffer
+
register_parameter
+
add_module
+
register_module
+
get_submodule
+
get_parameter
+
get_buffer
+
get_extra_state
+
set_extra_state
+
apply
+
cuda
+
ipu
+
xpu
+
cpu
+
type
+
float
+
double
+
half
+
bfloat16
+
to_empty
+
to
+
register_full_backward_pre_hook
+
register_backward_hook
+
register_full_backward_hook
+
register_forward_pre_hook
+
register_forward_hook
+
register_state_dict_pre_hook
+
state_dict
+
register_load_state_dict_post_hook
+
load_state_dict
+
parameters
+
named_parameters
+
buffers
+
named_buffers
+
children
+
named_children
+
modules
+
named_modules
+
train
+
eval
+
requires_grad_
+
zero_grad
+
share_memory
+
extra_repr
+
compile
+ +
+
+
+
+
+ + \ No newline at end of file diff --git a/docs/api/v0.5.1/shimmer/modules/domain.html b/docs/api/v0.5.1/shimmer/modules/domain.html new file mode 100644 index 00000000..5462b500 --- /dev/null +++ b/docs/api/v0.5.1/shimmer/modules/domain.html @@ -0,0 +1,1220 @@ + + + + + + + shimmer.modules.domain API documentation + + + + + + + + + + + + + +
+
+

+shimmer.modules.domain

+ + + + + + +
  1from dataclasses import dataclass, field
+  2from typing import Any
+  3
+  4import lightning.pytorch as pl
+  5import torch
+  6
+  7
+  8@dataclass
+  9class LossOutput:
+ 10    """
+ 11    This is a python dataclass use as a returned value for losses.
+ 12    It keeps track of what is used for training (`loss`) and what is used
+ 13    only for logging (`metrics`).
+ 14    """
+ 15
+ 16    loss: torch.Tensor
+ 17    """Loss used during training."""
+ 18
+ 19    metrics: dict[str, torch.Tensor] = field(default_factory=dict)
+ 20    """Some additional metrics to log (not used during training)."""
+ 21
+ 22    def __post_init__(self):
+ 23        if "loss" in self.metrics:
+ 24            raise ValueError("'loss' cannot be a key of metrics.")
+ 25
+ 26    @property
+ 27    def all(self) -> dict[str, torch.Tensor]:
+ 28        """
+ 29        Returns a dict with all metrics and loss with "loss" key.
+ 30        """
+ 31        return {**self.metrics, "loss": self.loss}
+ 32
+ 33
+ 34class DomainModule(pl.LightningModule):
+ 35    """
+ 36    Base class for a DomainModule that defines domain specific modules of the GW.
+ 37    """
+ 38
+ 39    def __init__(
+ 40        self,
+ 41        latent_dim: int,
+ 42    ) -> None:
+ 43        """
+ 44        Initializes a DomainModule.
+ 45
+ 46        Args:
+ 47            latent_dim (`int`): latent dimension of the unimodal module
+ 48        """
+ 49        super().__init__()
+ 50
+ 51        self.latent_dim = latent_dim
+ 52        """The latent dimension of the module."""
+ 53
+ 54    def encode(self, x: Any) -> torch.Tensor:
+ 55        """
+ 56        Encode the domain data into a unimodal representation.
+ 57
+ 58        Args:
+ 59            x (`Any`): data of the domain.
+ 60        Returns:
+ 61            `torch.Tensor`: a unimodal representation.
+ 62        """
+ 63        raise NotImplementedError
+ 64
+ 65    def decode(self, z: torch.Tensor) -> Any:
+ 66        """
+ 67        Decode data from unimodal representation back to the domain data.
+ 68
+ 69        Args:
+ 70            z (`torch.Tensor`): unimodal representation of the domain.
+ 71        Returns:
+ 72            `Any`: the original domain data.
+ 73        """
+ 74        raise NotImplementedError
+ 75
+ 76    def compute_loss(self, pred: torch.Tensor, target: torch.Tensor) -> LossOutput:
+ 77        """
+ 78        Generic loss computation  the modality.
+ 79
+ 80        Args:
+ 81            pred (`torch.Tensor`): prediction of the model
+ 82            target (`torch.Tensor`): target tensor
+ 83        Results:
+ 84            `LossOutput`: LossOuput with training loss and additional metrics.
+ 85        """
+ 86        raise NotImplementedError
+ 87
+ 88    def compute_dcy_loss(self, pred: torch.Tensor, target: torch.Tensor) -> LossOutput:
+ 89        """
+ 90        Computes the loss for a demi-cycle. Override if the demi-cycle loss is
+ 91        different that the generic loss.
+ 92
+ 93        Args:
+ 94            pred (`torch.Tensor`): prediction of the model
+ 95            target (`torch.Tensor`): target tensor
+ 96        Results:
+ 97            `LossOutput`: LossOuput with training loss and additional metrics.
+ 98        """
+ 99        return self.compute_loss(pred, target)
+100
+101    def compute_cy_loss(self, pred: torch.Tensor, target: torch.Tensor) -> LossOutput:
+102        """
+103        Computes the loss for a cycle. Override if the cycle loss is
+104        different that the generic loss.
+105
+106        Args:
+107            pred (`torch.Tensor`): prediction of the model
+108            target (`torch.Tensor`): target tensor
+109        Results:
+110            `LossOutput`: LossOuput with training loss and additional metrics.
+111        """
+112        return self.compute_loss(pred, target)
+113
+114    def compute_tr_loss(self, pred: torch.Tensor, target: torch.Tensor) -> LossOutput:
+115        """
+116        Computes the loss for a translation. Override if the translation loss is
+117        different that the generic loss.
+118
+119        Args:
+120            pred (`torch.Tensor`): prediction of the model
+121            target (`torch.Tensor`): target tensor
+122        Results:
+123            `LossOutput`: LossOuput with training loss and additional metrics.
+124        """
+125        return self.compute_loss(pred, target)
+126
+127    def compute_broadcast_loss(
+128        self, pred: torch.Tensor, target: torch.Tensor
+129    ) -> LossOutput:
+130        """
+131        Computes the loss for a broadcast (fusion). Override if the broadcast loss is
+132        different that the generic loss.
+133
+134        Args:
+135            pred (`torch.Tensor`): prediction of the model
+136            target (`torch.Tensor`): target tensor
+137        Results:
+138            `LossOutput`: LossOuput with training loss and additional metrics.
+139        """
+140        return self.compute_loss(pred, target)
+
+ + +
+
+ +
+
@dataclass
+ + class + LossOutput: + + + +
+ +
 9@dataclass
+10class LossOutput:
+11    """
+12    This is a python dataclass use as a returned value for losses.
+13    It keeps track of what is used for training (`loss`) and what is used
+14    only for logging (`metrics`).
+15    """
+16
+17    loss: torch.Tensor
+18    """Loss used during training."""
+19
+20    metrics: dict[str, torch.Tensor] = field(default_factory=dict)
+21    """Some additional metrics to log (not used during training)."""
+22
+23    def __post_init__(self):
+24        if "loss" in self.metrics:
+25            raise ValueError("'loss' cannot be a key of metrics.")
+26
+27    @property
+28    def all(self) -> dict[str, torch.Tensor]:
+29        """
+30        Returns a dict with all metrics and loss with "loss" key.
+31        """
+32        return {**self.metrics, "loss": self.loss}
+
+ + +

This is a python dataclass use as a returned value for losses. +It keeps track of what is used for training (loss) and what is used +only for logging (metrics).

+
+ + +
+
+ + LossOutput(loss: torch.Tensor, metrics: dict[str, torch.Tensor] = <factory>) + + +
+ + + + +
+
+
+ loss: torch.Tensor + + +
+ + +

Loss used during training.

+
+ + +
+
+
+ metrics: dict[str, torch.Tensor] + + +
+ + +

Some additional metrics to log (not used during training).

+
+ + +
+
+ +
+ all: dict[str, torch.Tensor] + + + +
+ +
27    @property
+28    def all(self) -> dict[str, torch.Tensor]:
+29        """
+30        Returns a dict with all metrics and loss with "loss" key.
+31        """
+32        return {**self.metrics, "loss": self.loss}
+
+ + +

Returns a dict with all metrics and loss with "loss" key.

+
+ + +
+
+
+ +
+ + class + DomainModule(lightning.pytorch.core.module.LightningModule): + + + +
+ +
 35class DomainModule(pl.LightningModule):
+ 36    """
+ 37    Base class for a DomainModule that defines domain specific modules of the GW.
+ 38    """
+ 39
+ 40    def __init__(
+ 41        self,
+ 42        latent_dim: int,
+ 43    ) -> None:
+ 44        """
+ 45        Initializes a DomainModule.
+ 46
+ 47        Args:
+ 48            latent_dim (`int`): latent dimension of the unimodal module
+ 49        """
+ 50        super().__init__()
+ 51
+ 52        self.latent_dim = latent_dim
+ 53        """The latent dimension of the module."""
+ 54
+ 55    def encode(self, x: Any) -> torch.Tensor:
+ 56        """
+ 57        Encode the domain data into a unimodal representation.
+ 58
+ 59        Args:
+ 60            x (`Any`): data of the domain.
+ 61        Returns:
+ 62            `torch.Tensor`: a unimodal representation.
+ 63        """
+ 64        raise NotImplementedError
+ 65
+ 66    def decode(self, z: torch.Tensor) -> Any:
+ 67        """
+ 68        Decode data from unimodal representation back to the domain data.
+ 69
+ 70        Args:
+ 71            z (`torch.Tensor`): unimodal representation of the domain.
+ 72        Returns:
+ 73            `Any`: the original domain data.
+ 74        """
+ 75        raise NotImplementedError
+ 76
+ 77    def compute_loss(self, pred: torch.Tensor, target: torch.Tensor) -> LossOutput:
+ 78        """
+ 79        Generic loss computation  the modality.
+ 80
+ 81        Args:
+ 82            pred (`torch.Tensor`): prediction of the model
+ 83            target (`torch.Tensor`): target tensor
+ 84        Results:
+ 85            `LossOutput`: LossOuput with training loss and additional metrics.
+ 86        """
+ 87        raise NotImplementedError
+ 88
+ 89    def compute_dcy_loss(self, pred: torch.Tensor, target: torch.Tensor) -> LossOutput:
+ 90        """
+ 91        Computes the loss for a demi-cycle. Override if the demi-cycle loss is
+ 92        different that the generic loss.
+ 93
+ 94        Args:
+ 95            pred (`torch.Tensor`): prediction of the model
+ 96            target (`torch.Tensor`): target tensor
+ 97        Results:
+ 98            `LossOutput`: LossOuput with training loss and additional metrics.
+ 99        """
+100        return self.compute_loss(pred, target)
+101
+102    def compute_cy_loss(self, pred: torch.Tensor, target: torch.Tensor) -> LossOutput:
+103        """
+104        Computes the loss for a cycle. Override if the cycle loss is
+105        different that the generic loss.
+106
+107        Args:
+108            pred (`torch.Tensor`): prediction of the model
+109            target (`torch.Tensor`): target tensor
+110        Results:
+111            `LossOutput`: LossOuput with training loss and additional metrics.
+112        """
+113        return self.compute_loss(pred, target)
+114
+115    def compute_tr_loss(self, pred: torch.Tensor, target: torch.Tensor) -> LossOutput:
+116        """
+117        Computes the loss for a translation. Override if the translation loss is
+118        different that the generic loss.
+119
+120        Args:
+121            pred (`torch.Tensor`): prediction of the model
+122            target (`torch.Tensor`): target tensor
+123        Results:
+124            `LossOutput`: LossOuput with training loss and additional metrics.
+125        """
+126        return self.compute_loss(pred, target)
+127
+128    def compute_broadcast_loss(
+129        self, pred: torch.Tensor, target: torch.Tensor
+130    ) -> LossOutput:
+131        """
+132        Computes the loss for a broadcast (fusion). Override if the broadcast loss is
+133        different that the generic loss.
+134
+135        Args:
+136            pred (`torch.Tensor`): prediction of the model
+137            target (`torch.Tensor`): target tensor
+138        Results:
+139            `LossOutput`: LossOuput with training loss and additional metrics.
+140        """
+141        return self.compute_loss(pred, target)
+
+ + +

Base class for a DomainModule that defines domain specific modules of the GW.

+
+ + +
+ +
+ + DomainModule(latent_dim: int) + + + +
+ +
40    def __init__(
+41        self,
+42        latent_dim: int,
+43    ) -> None:
+44        """
+45        Initializes a DomainModule.
+46
+47        Args:
+48            latent_dim (`int`): latent dimension of the unimodal module
+49        """
+50        super().__init__()
+51
+52        self.latent_dim = latent_dim
+53        """The latent dimension of the module."""
+
+ + +

Initializes a DomainModule.

+ +
Arguments:
+ +
    +
  • latent_dim (int): latent dimension of the unimodal module
  • +
+
+ + +
+
+
+ latent_dim + + +
+ + +

The latent dimension of the module.

+
+ + +
+
+ +
+ + def + encode(self, x: Any) -> torch.Tensor: + + + +
+ +
55    def encode(self, x: Any) -> torch.Tensor:
+56        """
+57        Encode the domain data into a unimodal representation.
+58
+59        Args:
+60            x (`Any`): data of the domain.
+61        Returns:
+62            `torch.Tensor`: a unimodal representation.
+63        """
+64        raise NotImplementedError
+
+ + +

Encode the domain data into a unimodal representation.

+ +
Arguments:
+ +
    +
  • x (Any): data of the domain.
  • +
+ +
Returns:
+ +
+

torch.Tensor: a unimodal representation.

+
+
+ + +
+
+ +
+ + def + decode(self, z: torch.Tensor) -> Any: + + + +
+ +
66    def decode(self, z: torch.Tensor) -> Any:
+67        """
+68        Decode data from unimodal representation back to the domain data.
+69
+70        Args:
+71            z (`torch.Tensor`): unimodal representation of the domain.
+72        Returns:
+73            `Any`: the original domain data.
+74        """
+75        raise NotImplementedError
+
+ + +

Decode data from unimodal representation back to the domain data.

+ +
Arguments:
+ +
    +
  • z (torch.Tensor): unimodal representation of the domain.
  • +
+ +
Returns:
+ +
+

Any: the original domain data.

+
+
+ + +
+
+ +
+ + def + compute_loss( self, pred: torch.Tensor, target: torch.Tensor) -> LossOutput: + + + +
+ +
77    def compute_loss(self, pred: torch.Tensor, target: torch.Tensor) -> LossOutput:
+78        """
+79        Generic loss computation  the modality.
+80
+81        Args:
+82            pred (`torch.Tensor`): prediction of the model
+83            target (`torch.Tensor`): target tensor
+84        Results:
+85            `LossOutput`: LossOuput with training loss and additional metrics.
+86        """
+87        raise NotImplementedError
+
+ + +

Generic loss computation the modality.

+ +
Arguments:
+ +
    +
  • pred (torch.Tensor): prediction of the model
  • +
  • target (torch.Tensor): target tensor
  • +
+ +
Results:
+ +
+

LossOutput: LossOuput with training loss and additional metrics.

+
+
+ + +
+
+ +
+ + def + compute_dcy_loss( self, pred: torch.Tensor, target: torch.Tensor) -> LossOutput: + + + +
+ +
 89    def compute_dcy_loss(self, pred: torch.Tensor, target: torch.Tensor) -> LossOutput:
+ 90        """
+ 91        Computes the loss for a demi-cycle. Override if the demi-cycle loss is
+ 92        different that the generic loss.
+ 93
+ 94        Args:
+ 95            pred (`torch.Tensor`): prediction of the model
+ 96            target (`torch.Tensor`): target tensor
+ 97        Results:
+ 98            `LossOutput`: LossOuput with training loss and additional metrics.
+ 99        """
+100        return self.compute_loss(pred, target)
+
+ + +

Computes the loss for a demi-cycle. Override if the demi-cycle loss is +different that the generic loss.

+ +
Arguments:
+ +
    +
  • pred (torch.Tensor): prediction of the model
  • +
  • target (torch.Tensor): target tensor
  • +
+ +
Results:
+ +
+

LossOutput: LossOuput with training loss and additional metrics.

+
+
+ + +
+
+ +
+ + def + compute_cy_loss( self, pred: torch.Tensor, target: torch.Tensor) -> LossOutput: + + + +
+ +
102    def compute_cy_loss(self, pred: torch.Tensor, target: torch.Tensor) -> LossOutput:
+103        """
+104        Computes the loss for a cycle. Override if the cycle loss is
+105        different that the generic loss.
+106
+107        Args:
+108            pred (`torch.Tensor`): prediction of the model
+109            target (`torch.Tensor`): target tensor
+110        Results:
+111            `LossOutput`: LossOuput with training loss and additional metrics.
+112        """
+113        return self.compute_loss(pred, target)
+
+ + +

Computes the loss for a cycle. Override if the cycle loss is +different that the generic loss.

+ +
Arguments:
+ +
    +
  • pred (torch.Tensor): prediction of the model
  • +
  • target (torch.Tensor): target tensor
  • +
+ +
Results:
+ +
+

LossOutput: LossOuput with training loss and additional metrics.

+
+
+ + +
+
+ +
+ + def + compute_tr_loss( self, pred: torch.Tensor, target: torch.Tensor) -> LossOutput: + + + +
+ +
115    def compute_tr_loss(self, pred: torch.Tensor, target: torch.Tensor) -> LossOutput:
+116        """
+117        Computes the loss for a translation. Override if the translation loss is
+118        different that the generic loss.
+119
+120        Args:
+121            pred (`torch.Tensor`): prediction of the model
+122            target (`torch.Tensor`): target tensor
+123        Results:
+124            `LossOutput`: LossOuput with training loss and additional metrics.
+125        """
+126        return self.compute_loss(pred, target)
+
+ + +

Computes the loss for a translation. Override if the translation loss is +different that the generic loss.

+ +
Arguments:
+ +
    +
  • pred (torch.Tensor): prediction of the model
  • +
  • target (torch.Tensor): target tensor
  • +
+ +
Results:
+ +
+

LossOutput: LossOuput with training loss and additional metrics.

+
+
+ + +
+
+ +
+ + def + compute_broadcast_loss( self, pred: torch.Tensor, target: torch.Tensor) -> LossOutput: + + + +
+ +
128    def compute_broadcast_loss(
+129        self, pred: torch.Tensor, target: torch.Tensor
+130    ) -> LossOutput:
+131        """
+132        Computes the loss for a broadcast (fusion). Override if the broadcast loss is
+133        different that the generic loss.
+134
+135        Args:
+136            pred (`torch.Tensor`): prediction of the model
+137            target (`torch.Tensor`): target tensor
+138        Results:
+139            `LossOutput`: LossOuput with training loss and additional metrics.
+140        """
+141        return self.compute_loss(pred, target)
+
+ + +

Computes the loss for a broadcast (fusion). Override if the broadcast loss is +different that the generic loss.

+ +
Arguments:
+ +
    +
  • pred (torch.Tensor): prediction of the model
  • +
  • target (torch.Tensor): target tensor
  • +
+ +
Results:
+ +
+

LossOutput: LossOuput with training loss and additional metrics.

+
+
+ + +
+
+
Inherited Members
+
+
lightning.pytorch.core.module.LightningModule
+
CHECKPOINT_HYPER_PARAMS_KEY
+
CHECKPOINT_HYPER_PARAMS_NAME
+
CHECKPOINT_HYPER_PARAMS_TYPE
+
optimizers
+
lr_schedulers
+
trainer
+
fabric
+
example_input_array
+
current_epoch
+
global_step
+
global_rank
+
local_rank
+
on_gpu
+
automatic_optimization
+
strict_loading
+
logger
+
loggers
+
print
+
log
+
log_dict
+
all_gather
+
forward
+
training_step
+
validation_step
+
test_step
+
predict_step
+
configure_callbacks
+
configure_optimizers
+
manual_backward
+
backward
+
toggle_optimizer
+
untoggle_optimizer
+
clip_gradients
+
configure_gradient_clipping
+
lr_scheduler_step
+
optimizer_step
+
optimizer_zero_grad
+
freeze
+
unfreeze
+
to_onnx
+
to_torchscript
+
load_from_checkpoint
+ +
+
lightning.fabric.utilities.device_dtype_mixin._DeviceDtypeModuleMixin
+
dtype
+
device
+
to
+
cuda
+
cpu
+
type
+
float
+
double
+
half
+ +
+
lightning.pytorch.core.mixins.hparams_mixin.HyperparametersMixin
+
save_hyperparameters
+
hparams
+
hparams_initial
+ +
+
lightning.pytorch.core.hooks.ModelHooks
+
on_fit_start
+
on_fit_end
+
on_train_start
+
on_train_end
+
on_validation_start
+
on_validation_end
+
on_test_start
+
on_test_end
+
on_predict_start
+
on_predict_end
+
on_train_batch_start
+
on_train_batch_end
+
on_validation_batch_start
+
on_validation_batch_end
+
on_test_batch_start
+
on_test_batch_end
+
on_predict_batch_start
+
on_predict_batch_end
+
on_validation_model_zero_grad
+
on_validation_model_eval
+
on_validation_model_train
+
on_test_model_eval
+
on_test_model_train
+
on_predict_model_eval
+
on_train_epoch_start
+
on_train_epoch_end
+
on_validation_epoch_start
+
on_validation_epoch_end
+
on_test_epoch_start
+
on_test_epoch_end
+
on_predict_epoch_start
+
on_predict_epoch_end
+
on_before_zero_grad
+
on_before_backward
+
on_after_backward
+
on_before_optimizer_step
+
configure_sharded_model
+
configure_model
+ +
+
lightning.pytorch.core.hooks.DataHooks
+
prepare_data_per_node
+
allow_zero_length_dataloader_with_multiple_devices
+
prepare_data
+
setup
+
teardown
+
train_dataloader
+
test_dataloader
+
val_dataloader
+
predict_dataloader
+
transfer_batch_to_device
+
on_before_batch_transfer
+
on_after_batch_transfer
+ +
+
lightning.pytorch.core.hooks.CheckpointHooks
+
on_load_checkpoint
+
on_save_checkpoint
+ +
+
torch.nn.modules.module.Module
+
dump_patches
+
training
+
call_super_init
+
register_buffer
+
register_parameter
+
add_module
+
register_module
+
get_submodule
+
get_parameter
+
get_buffer
+
get_extra_state
+
set_extra_state
+
apply
+
ipu
+
xpu
+
bfloat16
+
to_empty
+
register_full_backward_pre_hook
+
register_backward_hook
+
register_full_backward_hook
+
register_forward_pre_hook
+
register_forward_hook
+
register_state_dict_pre_hook
+
state_dict
+
register_load_state_dict_post_hook
+
load_state_dict
+
parameters
+
named_parameters
+
buffers
+
named_buffers
+
children
+
named_children
+
modules
+
named_modules
+
train
+
eval
+
requires_grad_
+
zero_grad
+
share_memory
+
extra_repr
+
compile
+ +
+
+
+
+
+ + \ No newline at end of file diff --git a/docs/api/v0.5.1/shimmer/modules/global_workspace.html b/docs/api/v0.5.1/shimmer/modules/global_workspace.html new file mode 100644 index 00000000..03727560 --- /dev/null +++ b/docs/api/v0.5.1/shimmer/modules/global_workspace.html @@ -0,0 +1,4260 @@ + + + + + + + shimmer.modules.global_workspace API documentation + + + + + + + + + + + + + +
+
+

+shimmer.modules.global_workspace

+ + + + + + +
  1from collections.abc import Iterable, Mapping
+  2from pathlib import Path
+  3from typing import Any, Generic, TypedDict, TypeVar, cast
+  4
+  5import torch
+  6from lightning.pytorch import LightningModule
+  7from lightning.pytorch.utilities.types import OptimizerLRSchedulerConfig
+  8from torch.nn import Module, ModuleDict
+  9from torch.optim.lr_scheduler import OneCycleLR
+ 10
+ 11from shimmer.modules.contrastive_loss import ContrastiveLoss, ContrastiveLossType
+ 12from shimmer.modules.domain import DomainModule
+ 13from shimmer.modules.gw_module import (
+ 14    GWModule,
+ 15    GWModuleBase,
+ 16    GWModuleBayesian,
+ 17)
+ 18from shimmer.modules.losses import (
+ 19    BroadcastLossCoefs,
+ 20    GWLosses,
+ 21    GWLosses2Domains,
+ 22    GWLossesBase,
+ 23    GWLossesBayesian,
+ 24    LossCoefs,
+ 25)
+ 26from shimmer.modules.selection import (
+ 27    FixedSharedSelection,
+ 28    RandomSelection,
+ 29    SelectionBase,
+ 30    SingleDomainSelection,
+ 31)
+ 32from shimmer.modules.utils import batch_cycles, batch_demi_cycles, batch_translations
+ 33from shimmer.types import (
+ 34    LatentsDomainGroupsDT,
+ 35    LatentsDomainGroupsT,
+ 36    ModelModeT,
+ 37    RawDomainGroupsDT,
+ 38    RawDomainGroupsT,
+ 39    RawDomainGroupT,
+ 40)
+ 41from shimmer.utils import groups_batch_size
+ 42
+ 43
+ 44class SchedulerArgs(TypedDict, total=False):
+ 45    """TypedDict of arguments passed to the OneCycle scheduler"""
+ 46
+ 47    max_lr: float
+ 48    """Maximum learning rate"""
+ 49
+ 50    total_steps: int
+ 51    """Total number of steps"""
+ 52
+ 53
+ 54class GWPredictionsBase(TypedDict):
+ 55    """TypedDict of the output given when calling `GlobalWorkspaceBase.predict`"""
+ 56
+ 57    states: dict[str, torch.Tensor]
+ 58    """
+ 59    GW state representation from domain groups with only one domain.
+ 60    The key represent the domain's name.
+ 61    """
+ 62
+ 63
+ 64_T_gw_mod = TypeVar("_T_gw_mod", bound=GWModuleBase)
+ 65_T_selection_mod = TypeVar("_T_selection_mod", bound=SelectionBase)
+ 66_T_loss_mod = TypeVar("_T_loss_mod", bound=GWLossesBase)
+ 67
+ 68
+ 69class GlobalWorkspaceBase(
+ 70    Generic[_T_gw_mod, _T_selection_mod, _T_loss_mod], LightningModule
+ 71):
+ 72    """
+ 73    Global Workspace Lightning Module.
+ 74
+ 75    This is the base class to build the Global Workspace.
+ 76    """
+ 77
+ 78    def __init__(
+ 79        self,
+ 80        gw_mod: _T_gw_mod,
+ 81        selection_mod: _T_selection_mod,
+ 82        loss_mod: _T_loss_mod,
+ 83        optim_lr: float = 1e-3,
+ 84        optim_weight_decay: float = 0.0,
+ 85        scheduler_args: SchedulerArgs | None = None,
+ 86    ) -> None:
+ 87        """
+ 88        Initializes a GW
+ 89
+ 90        Args:
+ 91            gw_mod (`GWModuleBase`): the GWModule
+ 92            selection_mod (`SelectionBase`): selection module
+ 93            loss_mod (`GWLossesBase`): module to compute the GW losses.
+ 94            optim_lr (`float`): learning rate
+ 95            optim_weight_decay (`float`): weight decay
+ 96            scheduler_args (`SchedulerArgs`): `SchedulerArgs` instance to define
+ 97                scheduler parameters.
+ 98        """
+ 99        super().__init__()
+100        self.save_hyperparameters(
+101            ignore=[
+102                "gw_mod",
+103                "selection_mod",
+104                "domain_mods",
+105                "loss_mod",
+106                "domain_descriptions",
+107                "contrastive_loss",
+108                "cont_loss_bayesian",
+109                "gw_encoders",
+110                "gw_decoders",
+111            ]
+112        )
+113
+114        self.gw_mod = gw_mod
+115        """ a `GWModuleBase` implementation."""
+116
+117        self.selection_mod = selection_mod
+118        """A `SelectionBase` implementation."""
+119
+120        self.loss_mod = loss_mod
+121        """The module that computes losses of the GW"""
+122
+123        self.optim_lr = optim_lr
+124        self.optim_weight_decay = optim_weight_decay
+125        self.scheduler_args = SchedulerArgs(max_lr=optim_lr, total_steps=1)
+126        if scheduler_args is not None:
+127            self.scheduler_args.update(scheduler_args)
+128
+129    @property
+130    def domain_mods(self) -> Mapping[str, DomainModule]:
+131        return self.gw_mod.domain_mods
+132
+133    @property
+134    def workspace_dim(self) -> int:
+135        """Dimension of the GW."""
+136        return self.gw_mod.workspace_dim
+137
+138    def encode_and_fuse(
+139        self, x: LatentsDomainGroupsT, selection_module: SelectionBase
+140    ) -> dict[frozenset[str], torch.Tensor]:
+141        """
+142        Encode a group of latent representations into the GW representation.
+143
+144        Args:
+145            x (`LatentsDomainGroupsT`): the input domain representations.
+146            selection_scores (`Mapping[str, torch.Tensor]`):
+147
+148        Returns:
+149            `dict[frozenset[str], torch.Tensor]`: the GW representations.
+150        """
+151        return {
+152            domains: self.gw_mod.encode_and_fuse(latents, selection_module)
+153            for domains, latents in x.items()
+154        }
+155
+156    def encode(self, x: LatentsDomainGroupsT) -> LatentsDomainGroupsDT:
+157        """
+158        Encode a group of latent representations into the pre-fusion GW representation.
+159
+160        Args:
+161            x (`LatentsDomainGroupsT`): the input domain representations.
+162
+163        Returns:
+164            `LatensDomainGroupsDT`: the GW representations.
+165        """
+166        return {domains: self.gw_mod.encode(latents) for domains, latents in x.items()}
+167
+168    def fuse(
+169        self,
+170        x: LatentsDomainGroupsT,
+171        selection_scores: Mapping[frozenset[str], Mapping[str, torch.Tensor]],
+172    ) -> dict[frozenset[str], torch.Tensor]:
+173        """
+174        Fuses a group of latent representations into the GW representation.
+175
+176        Args:
+177            x (`LatentsDomainGroupsT`): the pre-fusion latent representations
+178            selection_scores (`Mapping[frozenset[str], Mapping[str, torch.Tensor]]`):
+179                selection scores for each group
+180
+181        Returns:
+182            `dict[frozenset[str], torch.Tensor]`: GW representation of each group
+183        """
+184        return {
+185            domains: self.gw_mod.fuse(latents, selection_scores[domains])
+186            for domains, latents in x.items()
+187        }
+188
+189    def decode(
+190        self,
+191        z: Mapping[frozenset[str], torch.Tensor],
+192        domains: Iterable[str] | None = None,
+193    ) -> LatentsDomainGroupsDT:
+194        """
+195        Decode the group GW representation into given `domains`.
+196
+197        Args:
+198            z (`torch.Tensor`): the GW representation.
+199            domains (`Iterable[str]`): iterable of domains to decode.
+200
+201        Returns:
+202            `dict[str, torch.Tensor]`: the decoded unimodal representations.
+203        """
+204        return {
+205            domain_names: self.gw_mod.decode(gw_rep, domains)
+206            for domain_names, gw_rep in z.items()
+207        }
+208
+209    def forward(  # type: ignore
+210        self,
+211        latent_domains: LatentsDomainGroupsT,
+212    ) -> GWPredictionsBase:
+213        """
+214        Computes demi-cycles, cycles, and translations.
+215
+216        Args:
+217            latent_domains (`LatentsT`): Groups of domains for the computation.
+218
+219        Returns:
+220            `GWPredictionsBase`: the predictions on the batch.
+221        """
+222
+223        return GWPredictionsBase(states=self.batch_gw_states(latent_domains))
+224
+225    def batch_gw_states(
+226        self, latent_domains: LatentsDomainGroupsT
+227    ) -> dict[str, torch.Tensor]:
+228        """
+229        Comptues GW states of a batch of groups of domains.
+230
+231        Args:
+232            latent_domains (`LatentsT`): the batch of groups of domains
+233
+234        Returns:
+235            `dict[str, torch.Tensor]`: states for each domain.
+236        """
+237        predictions: dict[str, torch.Tensor] = {}
+238        for domains, latents in latent_domains.items():
+239            if len(domains) > 1:
+240                continue
+241            domain_name = list(domains)[0]
+242            z = self.gw_mod.encode_and_fuse(
+243                latents, selection_module=self.selection_mod
+244            )
+245            predictions[domain_name] = z
+246        return predictions
+247
+248    def encode_domain(self, domain: Any, name: str) -> torch.Tensor:
+249        """
+250        Encodes a domain from the domain data into the unimodal representation.
+251
+252        This is a convenient proxy for the `DomainModule.encode` method and is
+253        equivalent to:
+254        ```python
+255        self.domain_mods[name].encode(domain)
+256        ```
+257
+258        Args:
+259            domain (`Any`): the domain data
+260            name (`str`): domain name to encode
+261
+262        Returns:
+263            `torch.Tensor`: the domain's unimodal representation.
+264        """
+265        return self.domain_mods[name].encode(domain)
+266
+267    def encode_domains(self, batch: RawDomainGroupsT) -> LatentsDomainGroupsDT:
+268        """
+269        Encode all domains in the batch.
+270
+271        Args:
+272            batch (`RawDomainGroupsT`): the batch of
+273                domain groups with raw unimodal data to encode into groups of latent
+274                representations.
+275
+276        Returns:
+277            `LatentsDomainGroupsDT`: the domains' unimodal representations.
+278        """
+279        return {
+280            domains: {
+281                name: self.domain_mods[name].encode(domain)
+282                for name, domain in data.items()
+283            }
+284            for domains, data in batch.items()
+285        }
+286
+287    def decode_domain(self, domain: torch.Tensor, name: str) -> Any:
+288        """
+289        Decodes a domain from the unimodal representation into the domain data.
+290
+291        This is a convenient proxy for the `DomainModule.encode` method and is
+292        equivalent to:
+293        ```python
+294        self.domain_mods[name].decode(domain)
+295        ```
+296
+297        Args:
+298            domain (`torch.Tensor`): the domain data
+299            name (`str`): domain name to encode
+300
+301        Returns:
+302            `Any`: the domain's raw data.
+303        """
+304        return self.domain_mods[name].decode(domain)
+305
+306    def decode_domains(self, latents_domain: LatentsDomainGroupsT) -> RawDomainGroupsDT:
+307        """
+308        Decodes all domains in the batch.
+309
+310        Args:
+311            batch (`LatentsDomainGroupsT`): the batch of
+312                domain groups with unimodal latent representation to decode into
+313                groups of raw data.
+314
+315        Returns:
+316            `LatentsDomainGroupsDT`: the domains' raw data.
+317        """
+318        return {
+319            domains: {
+320                name: self.domain_mods[name].decode(domain)
+321                for name, domain in latents.items()
+322            }
+323            for domains, latents in latents_domain.items()
+324        }
+325
+326    def generic_step(self, batch: RawDomainGroupsT, mode: ModelModeT) -> torch.Tensor:
+327        """
+328        The generic step used in `training_step`, `validation_step` and
+329        `test_step`.
+330
+331        Args:
+332            batch (`RawDomainGroupsT`): the batch of groups of raw unimodal data.
+333            mode (`ModelModeT`):
+334
+335        Returns:
+336            `torch.Tensor`: the loss to train on.
+337        """
+338        domain_latents = self.encode_domains(batch)
+339        batch_size = groups_batch_size(domain_latents)
+340
+341        loss_output = self.loss_mod.step(domain_latents, mode)
+342
+343        for name, metric in loss_output.all.items():
+344            self.log(
+345                f"{mode}/{name}",
+346                metric,
+347                batch_size=batch_size,
+348                add_dataloader_idx=False,
+349            )
+350
+351        return loss_output.loss
+352
+353    def validation_step(  # type: ignore
+354        self, data: RawDomainGroupT, batch_idx: int, dataloader_idx: int = 0
+355    ) -> torch.Tensor:
+356        """Validation step used by lightning"""
+357
+358        batch = {frozenset(data.keys()): data}
+359        for domain in data:
+360            batch[frozenset([domain])] = {domain: data[domain]}
+361        if dataloader_idx == 0:
+362            return self.generic_step(batch, mode="val")
+363        return self.generic_step(batch, mode="val/ood")
+364
+365    def test_step(  # type: ignore
+366        self, data: Mapping[str, Any], batch_idx: int, dataloader_idx: int = 0
+367    ) -> torch.Tensor:
+368        """Test step used by lightning"""
+369
+370        batch = {frozenset(data.keys()): data}
+371        for domain in data:
+372            batch[frozenset([domain])] = {domain: data[domain]}
+373        if dataloader_idx == 0:
+374            return self.generic_step(batch, mode="test")
+375        return self.generic_step(batch, mode="test/ood")
+376
+377    def training_step(  # type: ignore
+378        self, batch: Mapping[frozenset[str], Mapping[str, Any]], batch_idx: int
+379    ) -> torch.Tensor:
+380        """Training step used by lightning"""
+381
+382        return self.generic_step(batch, mode="train")
+383
+384    def predict_step(  # type: ignore
+385        self, data: Mapping[str, Any], batch_idx: int
+386    ) -> GWPredictionsBase:
+387        """Predict step used by lightning"""
+388
+389        batch = {frozenset(data.keys()): data}
+390        for domain in data:
+391            batch[frozenset([domain])] = {domain: data[domain]}
+392
+393        domain_latents = self.encode_domains(batch)
+394        return self.forward(domain_latents)
+395
+396    def configure_optimizers(self) -> OptimizerLRSchedulerConfig:
+397        """
+398        Configure models optimizers.
+399
+400        Here we use `AdamW` for the optimizer and `OneCycleLR` for the learning-rate
+401        scheduler.
+402        """
+403
+404        optimizer = torch.optim.AdamW(
+405            self.parameters(),
+406            lr=self.optim_lr,
+407            weight_decay=self.optim_weight_decay,
+408        )
+409
+410        lr_scheduler = OneCycleLR(optimizer, **self.scheduler_args)
+411
+412        return {
+413            "optimizer": optimizer,
+414            "lr_scheduler": {
+415                "scheduler": lr_scheduler,
+416                "interval": "step",
+417            },
+418        }
+419
+420
+421def freeze_domain_modules(
+422    domain_mods: Mapping[str, DomainModule],
+423) -> dict[str, DomainModule]:
+424    """
+425    Freezes weights and set to eval mode the domain modules.
+426
+427    .. note::
+428        The output is casted as `dict[str, DomainModule]` type for better
+429        auto-completion, but is actually a torch `ModuleDict`.
+430
+431    Args:
+432        domain_mods (`Mapping[str, DomainModule]`): mapping of domain modules to freeze
+433
+434    Returns:
+435        `ModuleDict`: frozen modules.
+436    """
+437
+438    for mod in domain_mods.values():
+439        mod.freeze()
+440    # Cast for better auto-completion at the expense of ModuleDict
+441    return cast(dict[str, DomainModule], ModuleDict(domain_mods))
+442
+443
+444class GWPredictions(GWPredictionsBase):
+445    """TypedDict of the output given when calling `GlobalWorkspaceBase.predict`"""
+446
+447    demi_cycles: dict[str, torch.Tensor]
+448    """
+449    Demi-cycle predictions of the model for each domain. Only computed on domain
+450    groups with only one domain.
+451    """
+452
+453    cycles: dict[tuple[str, str], torch.Tensor]
+454    """
+455    Cycle predictions of the model from one domain through another one.
+456    Only computed on domain groups with more than one domain.
+457    The keys are tuple with start domain and intermediary domain.
+458    """
+459
+460    translations: dict[tuple[str, str], torch.Tensor]
+461    """
+462    Translation predictions of the model from one domain through another one.
+463
+464    Only computed on domain groups with more than one domain.
+465    The keys are tuples with start domain and target domain.
+466    """
+467
+468
+469class GlobalWorkspace2Domains(
+470    GlobalWorkspaceBase[GWModule, SingleDomainSelection, GWLosses2Domains]
+471):
+472    """
+473    A simple 2-domains max flavor of GlobalWorkspaceBase.
+474
+475    This is used to simplify a Global Workspace instanciation and only overrides the
+476    `__init__` method.
+477    """
+478
+479    def __init__(
+480        self,
+481        domain_mods: Mapping[str, DomainModule],
+482        gw_encoders: Mapping[str, Module],
+483        gw_decoders: Mapping[str, Module],
+484        workspace_dim: int,
+485        loss_coefs: LossCoefs,
+486        optim_lr: float = 1e-3,
+487        optim_weight_decay: float = 0.0,
+488        scheduler_args: SchedulerArgs | None = None,
+489        learn_logit_scale: bool = False,
+490        contrastive_loss: ContrastiveLossType | None = None,
+491    ) -> None:
+492        """
+493        Initializes a Global Workspace
+494
+495        Args:
+496            domain_mods (`Mapping[str, DomainModule]`): mapping of the domains
+497                connected to the GW. Keys are domain names, values are the
+498                `DomainModule`.
+499            gw_encoders (`Mapping[str, torch.nn.Module]`): mapping for each domain
+500                name to a `torch.nn.Module` class which role is to encode a
+501                unimodal latent representations into a GW representation (pre fusion).
+502            gw_decoders (`Mapping[str, torch.nn.Module]`): mapping for each domain
+503                name to a `torch.nn.Module` class which role is to decode a
+504                GW representation into a unimodal latent representations.
+505            workspace_dim (`int`): dimension of the GW.
+506            loss_coefs (`LossCoefs`): loss coefficients
+507            optim_lr (`float`): learning rate
+508            optim_weight_decay (`float`): weight decay
+509            scheduler_args (`SchedulerArgs | None`): optimization scheduler's arguments
+510            learn_logit_scale (`bool`): whether to learn the contrastive learning
+511                contrastive loss when using the default contrastive loss.
+512            contrastive_loss (`ContrastiveLossType | None`): a contrastive loss
+513                function used for alignment. `learn_logit_scale` will not affect custom
+514                contrastive losses.
+515        """
+516        domain_mods = freeze_domain_modules(domain_mods)
+517
+518        gw_mod = GWModule(domain_mods, workspace_dim, gw_encoders, gw_decoders)
+519        if contrastive_loss is None:
+520            contrastive_loss = ContrastiveLoss(
+521                torch.tensor([1 / 0.07]).log(), "mean", learn_logit_scale
+522            )
+523        selection_mod = SingleDomainSelection()
+524        loss_mod = GWLosses2Domains(
+525            gw_mod, selection_mod, domain_mods, loss_coefs, contrastive_loss
+526        )
+527
+528        super().__init__(
+529            gw_mod,
+530            selection_mod,
+531            loss_mod,
+532            optim_lr,
+533            optim_weight_decay,
+534            scheduler_args,
+535        )
+536
+537    def forward(  # type: ignore
+538        self,
+539        latent_domains: LatentsDomainGroupsT,
+540    ) -> GWPredictions:
+541        """
+542        Computes demi-cycles, cycles, and translations.
+543
+544        Args:
+545            latent_domains (`LatentsT`): Groups of domains for the computation.
+546
+547        Returns:
+548            `GWPredictions`: the predictions on the batch.
+549        """
+550        return GWPredictions(
+551            demi_cycles=batch_demi_cycles(
+552                self.gw_mod, self.selection_mod, latent_domains
+553            ),
+554            cycles=batch_cycles(
+555                self.gw_mod, self.selection_mod, latent_domains, self.domain_mods.keys()
+556            ),
+557            translations=batch_translations(
+558                self.gw_mod, self.selection_mod, latent_domains
+559            ),
+560            **super().forward(latent_domains),
+561        )
+562
+563
+564class GlobalWorkspace(GlobalWorkspaceBase[GWModule, RandomSelection, GWLosses]):
+565    """The 2-domain fusion (with broadcast loss) flavor of GlobalWorkspaceBase.
+566
+567    This is used to simplify a Global Workspace instanciation and only overrides the
+568    `__init__` method.
+569    """
+570
+571    def __init__(
+572        self,
+573        domain_mods: Mapping[str, DomainModule],
+574        gw_encoders: Mapping[str, Module],
+575        gw_decoders: Mapping[str, Module],
+576        workspace_dim: int,
+577        loss_coefs: BroadcastLossCoefs,
+578        selection_temperature: float = 0.2,
+579        optim_lr: float = 1e-3,
+580        optim_weight_decay: float = 0.0,
+581        scheduler_args: SchedulerArgs | None = None,
+582        learn_logit_scale: bool = False,
+583        contrastive_loss: ContrastiveLossType | None = None,
+584    ) -> None:
+585        """
+586        Initializes a Global Workspace
+587
+588        Args:
+589            domain_mods (`Mapping[str, DomainModule]`): mapping of the domains
+590                connected to the GW. Keys are domain names, values are the
+591                `DomainModule`.
+592            gw_encoders (`Mapping[str, torch.nn.Module]`): mapping for each domain
+593                name to a `torch.nn.Module` class which role is to encode a
+594                unimodal latent representations into a GW representation (pre fusion).
+595            gw_decoders (`Mapping[str, torch.nn.Module]`): mapping for each domain
+596                name to a `torch.nn.Module` class which role is to decode a
+597                GW representation into a unimodal latent representations.
+598            workspace_dim (`int`): dimension of the GW.
+599            loss_coefs (`BroadcastLossCoefs`): loss coefs for the losses.
+600            selection_temperature (`float`): temperature value for the RandomSelection
+601                module.
+602            optim_lr (`float`): learning rate
+603            optim_weight_decay (`float`): weight decay
+604            scheduler_args (`SchedulerArgs | None`): optimization scheduler's arguments
+605            learn_logit_scale (`bool`): whether to learn the contrastive learning
+606                contrastive loss when using the default contrastive loss.
+607            contrastive_loss (`ContrastiveLossType | None`): a contrastive loss
+608                function used for alignment. `learn_logit_scale` will not affect custom
+609                contrastive losses.
+610        """
+611        domain_mods = freeze_domain_modules(domain_mods)
+612        gw_mod = GWModule(domain_mods, workspace_dim, gw_encoders, gw_decoders)
+613
+614        if contrastive_loss is None:
+615            contrastive_loss = ContrastiveLoss(
+616                torch.tensor([1 / 0.07]).log(), "mean", learn_logit_scale
+617            )
+618
+619        selection_mod = RandomSelection(selection_temperature)
+620        loss_mod = GWLosses(
+621            gw_mod, selection_mod, domain_mods, loss_coefs, contrastive_loss
+622        )
+623
+624        super().__init__(
+625            gw_mod,
+626            selection_mod,
+627            loss_mod,
+628            optim_lr,
+629            optim_weight_decay,
+630            scheduler_args,
+631        )
+632
+633    def forward(  # type: ignore
+634        self,
+635        latent_domains: LatentsDomainGroupsT,
+636    ) -> GWPredictions:
+637        """
+638        Computes demi-cycles, cycles, and translations.
+639
+640        Args:
+641            latent_domains (`LatentsT`): Groups of domains for the computation.
+642
+643        Returns:
+644            `GWPredictions`: the predictions on the batch.
+645        """
+646        return GWPredictions(
+647            demi_cycles=batch_demi_cycles(
+648                self.gw_mod, self.selection_mod, latent_domains
+649            ),
+650            cycles=batch_cycles(
+651                self.gw_mod, self.selection_mod, latent_domains, self.domain_mods.keys()
+652            ),
+653            translations=batch_translations(
+654                self.gw_mod, self.selection_mod, latent_domains
+655            ),
+656            # TODO: add other combinations
+657            **super().forward(latent_domains),
+658        )
+659
+660
+661class GlobalWorkspaceBayesian(
+662    GlobalWorkspaceBase[GWModuleBayesian, FixedSharedSelection, GWLossesBayesian]
+663):
+664    """
+665    A simple 2-domains max GlobalWorkspaceBase with a Bayesian base uncertainty
+666    prediction.
+667
+668    This is used to simplify a Global Workspace instanciation and only overrides the
+669    `__init__` method.
+670    """
+671
+672    def __init__(
+673        self,
+674        domain_mods: Mapping[str, DomainModule],
+675        gw_encoders: Mapping[str, Module],
+676        gw_decoders: Mapping[str, Module],
+677        workspace_dim: int,
+678        loss_coefs: BroadcastLossCoefs,
+679        sensitivity_selection: float = 1,
+680        sensitivity_precision: float = 1,
+681        optim_lr: float = 1e-3,
+682        optim_weight_decay: float = 0.0,
+683        scheduler_args: SchedulerArgs | None = None,
+684        learn_logit_scale: bool = False,
+685        use_normalized_constrastive: bool = True,
+686        contrastive_loss: ContrastiveLossType | None = None,
+687        precision_softmax_temp: float = 0.01,
+688    ) -> None:
+689        """
+690        Initializes a Global Workspace
+691
+692        Args:
+693            domain_mods (`Mapping[str, DomainModule]`): mapping of the domains
+694                connected to the GW. Keys are domain names, values are the
+695                `DomainModule`.
+696            gw_encoders (`Mapping[str, torch.nn.Module]`): mapping for each domain
+697                name to a `torch.nn.Module` class which role is to encode a
+698                unimodal latent representations into a GW representation (pre fusion).
+699            gw_decoders (`Mapping[str, torch.nn.Module]`): mapping for each domain
+700                name to a `torch.nn.Module` class which role is to decode a
+701                GW representation into a unimodal latent representations.
+702            workspace_dim (`int`): dimension of the GW.
+703            loss_coefs (`LossCoefs`): loss coefficients
+704            sensitivity_selection (`float`): sensivity coef $c'_1$
+705            sensitivity_precision (`float`): sensitivity coef $c'_2$
+706            optim_lr (`float`): learning rate
+707            optim_weight_decay (`float`): weight decay
+708            scheduler_args (`SchedulerArgs | None`): optimization scheduler's arguments
+709            learn_logit_scale (`bool`): whether to learn the contrastive learning
+710                contrastive loss when using the default contrastive loss.
+711            use_normalized_constrastive (`bool`): whether to use the normalized cont
+712                loss by the precision coefs
+713            contrastive_loss (`ContrastiveLossType | None`): a contrastive loss
+714                function used for alignment. `learn_logit_scale` will not affect custom
+715                contrastive losses.
+716            precision_softmax_temp (`float`): temperature to use in softmax of
+717                precision
+718        """
+719        domain_mods = freeze_domain_modules(domain_mods)
+720
+721        gw_mod = GWModuleBayesian(
+722            domain_mods,
+723            workspace_dim,
+724            gw_encoders,
+725            gw_decoders,
+726            sensitivity_selection,
+727            sensitivity_precision,
+728            precision_softmax_temp,
+729        )
+730
+731        selection_mod = FixedSharedSelection()
+732
+733        contrastive_loss = ContrastiveLoss(
+734            torch.tensor([1]).log(), "mean", learn_logit_scale
+735        )
+736
+737        loss_mod = GWLossesBayesian(
+738            gw_mod,
+739            selection_mod,
+740            domain_mods,
+741            loss_coefs,
+742            contrastive_loss,
+743            use_normalized_constrastive,
+744        )
+745
+746        super().__init__(
+747            gw_mod,
+748            selection_mod,
+749            loss_mod,
+750            optim_lr,
+751            optim_weight_decay,
+752            scheduler_args,
+753        )
+754
+755    def forward(  # type: ignore
+756        self,
+757        latent_domains: LatentsDomainGroupsT,
+758    ) -> GWPredictions:
+759        """
+760        Computes demi-cycles, cycles, and translations.
+761
+762        Args:
+763            latent_domains (`LatentsT`): Groups of domains for the computation.
+764
+765        Returns:
+766            `GWPredictions`: the predictions on the batch.
+767        """
+768        return GWPredictions(
+769            demi_cycles=batch_demi_cycles(
+770                self.gw_mod, self.selection_mod, latent_domains
+771            ),
+772            cycles=batch_cycles(
+773                self.gw_mod, self.selection_mod, latent_domains, self.domain_mods.keys()
+774            ),
+775            translations=batch_translations(
+776                self.gw_mod, self.selection_mod, latent_domains
+777            ),
+778            **super().forward(latent_domains),
+779        )
+780
+781
+782def pretrained_global_workspace(
+783    checkpoint_path: str | Path,
+784    domain_mods: Mapping[str, DomainModule],
+785    gw_encoders: Mapping[str, Module],
+786    gw_decoders: Mapping[str, Module],
+787    workspace_dim: int,
+788    loss_coefs: LossCoefs,
+789    contrastive_fn: ContrastiveLossType,
+790    **kwargs,
+791) -> GlobalWorkspace2Domains:
+792    """
+793    Load a `GlobalWorkspace` flavor of `GlobalWorkspaceBase` from a checkpoint.
+794
+795    Args:
+796        checkpoint_path (`str | Path`): path to checkpoint
+797        domain_mods (`Mapping[str, DomainModule]`): mapping of the domains
+798            connected to the GW. Keys are domain names, values are the
+799            `DomainModule`.
+800        gw_encoders (`Mapping[str, torch.nn.Module]`): mapping for each domain
+801            name to a `torch.nn.Module` class which role is to encode a
+802            unimodal latent representations into a GW representation (pre fusion).
+803        gw_decoders (`Mapping[str, torch.nn.Module]`): mapping for each domain
+804            name to a `torch.nn.Module` class which role is to decode a
+805            GW representation into a unimodal latent representations.
+806        workspace_dim (`int`): dimension of the GW.
+807        loss_coefs (`LossCoefs`): loss coefficients
+808        contrastive_loss (`ContrastiveLossType`): a contrastive loss
+809            function used for alignment. `learn_logit_scale` will not affect custom
+810            contrastive losses.
+811        **kwargs: additional arguments to pass to
+812            `GlobalWorkspace.load_from_checkpoint`.
+813
+814    Returns:
+815        `GlobalWorkspace`: the pretrained `GlobalWorkspace`.
+816
+817    Raises:
+818        `TypeError`: if loaded type is not `GlobalWorkspace`.
+819    """
+820    domain_mods = freeze_domain_modules(domain_mods)
+821    gw_mod = GWModule(domain_mods, workspace_dim, gw_encoders, gw_decoders)
+822    selection_mod = SingleDomainSelection()
+823    loss_mod = GWLosses2Domains(
+824        gw_mod, selection_mod, domain_mods, loss_coefs, contrastive_fn
+825    )
+826
+827    gw = GlobalWorkspace2Domains.load_from_checkpoint(
+828        checkpoint_path,
+829        gw_mod=gw_mod,
+830        selection_mid=selection_mod,
+831        loss_coefs=loss_coefs,
+832        loss_mod=loss_mod,
+833        **kwargs,
+834    )
+835    if not isinstance(gw, GlobalWorkspace2Domains):
+836        raise TypeError("model should be of type GlobalWorkspace")
+837    return gw
+
+ + +
+
+ +
+ + class + SchedulerArgs(typing.TypedDict): + + + +
+ +
45class SchedulerArgs(TypedDict, total=False):
+46    """TypedDict of arguments passed to the OneCycle scheduler"""
+47
+48    max_lr: float
+49    """Maximum learning rate"""
+50
+51    total_steps: int
+52    """Total number of steps"""
+
+ + +

TypedDict of arguments passed to the OneCycle scheduler

+
+ + +
+
+ max_lr: float + + +
+ + +

Maximum learning rate

+
+ + +
+
+
+ total_steps: int + + +
+ + +

Total number of steps

+
+ + +
+
+
+ +
+ + class + GWPredictionsBase(typing.TypedDict): + + + +
+ +
55class GWPredictionsBase(TypedDict):
+56    """TypedDict of the output given when calling `GlobalWorkspaceBase.predict`"""
+57
+58    states: dict[str, torch.Tensor]
+59    """
+60    GW state representation from domain groups with only one domain.
+61    The key represent the domain's name.
+62    """
+
+ + +

TypedDict of the output given when calling GlobalWorkspaceBase.predict

+
+ + +
+
+ states: dict[str, torch.Tensor] + + +
+ + +

GW state representation from domain groups with only one domain. +The key represent the domain's name.

+
+ + +
+
+
+ +
+ + class + GlobalWorkspaceBase(typing.Generic[~_T_gw_mod, ~_T_selection_mod, ~_T_loss_mod], lightning.pytorch.core.module.LightningModule): + + + +
+ +
 70class GlobalWorkspaceBase(
+ 71    Generic[_T_gw_mod, _T_selection_mod, _T_loss_mod], LightningModule
+ 72):
+ 73    """
+ 74    Global Workspace Lightning Module.
+ 75
+ 76    This is the base class to build the Global Workspace.
+ 77    """
+ 78
+ 79    def __init__(
+ 80        self,
+ 81        gw_mod: _T_gw_mod,
+ 82        selection_mod: _T_selection_mod,
+ 83        loss_mod: _T_loss_mod,
+ 84        optim_lr: float = 1e-3,
+ 85        optim_weight_decay: float = 0.0,
+ 86        scheduler_args: SchedulerArgs | None = None,
+ 87    ) -> None:
+ 88        """
+ 89        Initializes a GW
+ 90
+ 91        Args:
+ 92            gw_mod (`GWModuleBase`): the GWModule
+ 93            selection_mod (`SelectionBase`): selection module
+ 94            loss_mod (`GWLossesBase`): module to compute the GW losses.
+ 95            optim_lr (`float`): learning rate
+ 96            optim_weight_decay (`float`): weight decay
+ 97            scheduler_args (`SchedulerArgs`): `SchedulerArgs` instance to define
+ 98                scheduler parameters.
+ 99        """
+100        super().__init__()
+101        self.save_hyperparameters(
+102            ignore=[
+103                "gw_mod",
+104                "selection_mod",
+105                "domain_mods",
+106                "loss_mod",
+107                "domain_descriptions",
+108                "contrastive_loss",
+109                "cont_loss_bayesian",
+110                "gw_encoders",
+111                "gw_decoders",
+112            ]
+113        )
+114
+115        self.gw_mod = gw_mod
+116        """ a `GWModuleBase` implementation."""
+117
+118        self.selection_mod = selection_mod
+119        """A `SelectionBase` implementation."""
+120
+121        self.loss_mod = loss_mod
+122        """The module that computes losses of the GW"""
+123
+124        self.optim_lr = optim_lr
+125        self.optim_weight_decay = optim_weight_decay
+126        self.scheduler_args = SchedulerArgs(max_lr=optim_lr, total_steps=1)
+127        if scheduler_args is not None:
+128            self.scheduler_args.update(scheduler_args)
+129
+130    @property
+131    def domain_mods(self) -> Mapping[str, DomainModule]:
+132        return self.gw_mod.domain_mods
+133
+134    @property
+135    def workspace_dim(self) -> int:
+136        """Dimension of the GW."""
+137        return self.gw_mod.workspace_dim
+138
+139    def encode_and_fuse(
+140        self, x: LatentsDomainGroupsT, selection_module: SelectionBase
+141    ) -> dict[frozenset[str], torch.Tensor]:
+142        """
+143        Encode a group of latent representations into the GW representation.
+144
+145        Args:
+146            x (`LatentsDomainGroupsT`): the input domain representations.
+147            selection_scores (`Mapping[str, torch.Tensor]`):
+148
+149        Returns:
+150            `dict[frozenset[str], torch.Tensor]`: the GW representations.
+151        """
+152        return {
+153            domains: self.gw_mod.encode_and_fuse(latents, selection_module)
+154            for domains, latents in x.items()
+155        }
+156
+157    def encode(self, x: LatentsDomainGroupsT) -> LatentsDomainGroupsDT:
+158        """
+159        Encode a group of latent representations into the pre-fusion GW representation.
+160
+161        Args:
+162            x (`LatentsDomainGroupsT`): the input domain representations.
+163
+164        Returns:
+165            `LatensDomainGroupsDT`: the GW representations.
+166        """
+167        return {domains: self.gw_mod.encode(latents) for domains, latents in x.items()}
+168
+169    def fuse(
+170        self,
+171        x: LatentsDomainGroupsT,
+172        selection_scores: Mapping[frozenset[str], Mapping[str, torch.Tensor]],
+173    ) -> dict[frozenset[str], torch.Tensor]:
+174        """
+175        Fuses a group of latent representations into the GW representation.
+176
+177        Args:
+178            x (`LatentsDomainGroupsT`): the pre-fusion latent representations
+179            selection_scores (`Mapping[frozenset[str], Mapping[str, torch.Tensor]]`):
+180                selection scores for each group
+181
+182        Returns:
+183            `dict[frozenset[str], torch.Tensor]`: GW representation of each group
+184        """
+185        return {
+186            domains: self.gw_mod.fuse(latents, selection_scores[domains])
+187            for domains, latents in x.items()
+188        }
+189
+190    def decode(
+191        self,
+192        z: Mapping[frozenset[str], torch.Tensor],
+193        domains: Iterable[str] | None = None,
+194    ) -> LatentsDomainGroupsDT:
+195        """
+196        Decode the group GW representation into given `domains`.
+197
+198        Args:
+199            z (`torch.Tensor`): the GW representation.
+200            domains (`Iterable[str]`): iterable of domains to decode.
+201
+202        Returns:
+203            `dict[str, torch.Tensor]`: the decoded unimodal representations.
+204        """
+205        return {
+206            domain_names: self.gw_mod.decode(gw_rep, domains)
+207            for domain_names, gw_rep in z.items()
+208        }
+209
+210    def forward(  # type: ignore
+211        self,
+212        latent_domains: LatentsDomainGroupsT,
+213    ) -> GWPredictionsBase:
+214        """
+215        Computes demi-cycles, cycles, and translations.
+216
+217        Args:
+218            latent_domains (`LatentsT`): Groups of domains for the computation.
+219
+220        Returns:
+221            `GWPredictionsBase`: the predictions on the batch.
+222        """
+223
+224        return GWPredictionsBase(states=self.batch_gw_states(latent_domains))
+225
+226    def batch_gw_states(
+227        self, latent_domains: LatentsDomainGroupsT
+228    ) -> dict[str, torch.Tensor]:
+229        """
+230        Comptues GW states of a batch of groups of domains.
+231
+232        Args:
+233            latent_domains (`LatentsT`): the batch of groups of domains
+234
+235        Returns:
+236            `dict[str, torch.Tensor]`: states for each domain.
+237        """
+238        predictions: dict[str, torch.Tensor] = {}
+239        for domains, latents in latent_domains.items():
+240            if len(domains) > 1:
+241                continue
+242            domain_name = list(domains)[0]
+243            z = self.gw_mod.encode_and_fuse(
+244                latents, selection_module=self.selection_mod
+245            )
+246            predictions[domain_name] = z
+247        return predictions
+248
+249    def encode_domain(self, domain: Any, name: str) -> torch.Tensor:
+250        """
+251        Encodes a domain from the domain data into the unimodal representation.
+252
+253        This is a convenient proxy for the `DomainModule.encode` method and is
+254        equivalent to:
+255        ```python
+256        self.domain_mods[name].encode(domain)
+257        ```
+258
+259        Args:
+260            domain (`Any`): the domain data
+261            name (`str`): domain name to encode
+262
+263        Returns:
+264            `torch.Tensor`: the domain's unimodal representation.
+265        """
+266        return self.domain_mods[name].encode(domain)
+267
+268    def encode_domains(self, batch: RawDomainGroupsT) -> LatentsDomainGroupsDT:
+269        """
+270        Encode all domains in the batch.
+271
+272        Args:
+273            batch (`RawDomainGroupsT`): the batch of
+274                domain groups with raw unimodal data to encode into groups of latent
+275                representations.
+276
+277        Returns:
+278            `LatentsDomainGroupsDT`: the domains' unimodal representations.
+279        """
+280        return {
+281            domains: {
+282                name: self.domain_mods[name].encode(domain)
+283                for name, domain in data.items()
+284            }
+285            for domains, data in batch.items()
+286        }
+287
+288    def decode_domain(self, domain: torch.Tensor, name: str) -> Any:
+289        """
+290        Decodes a domain from the unimodal representation into the domain data.
+291
+292        This is a convenient proxy for the `DomainModule.encode` method and is
+293        equivalent to:
+294        ```python
+295        self.domain_mods[name].decode(domain)
+296        ```
+297
+298        Args:
+299            domain (`torch.Tensor`): the domain data
+300            name (`str`): domain name to encode
+301
+302        Returns:
+303            `Any`: the domain's raw data.
+304        """
+305        return self.domain_mods[name].decode(domain)
+306
+307    def decode_domains(self, latents_domain: LatentsDomainGroupsT) -> RawDomainGroupsDT:
+308        """
+309        Decodes all domains in the batch.
+310
+311        Args:
+312            batch (`LatentsDomainGroupsT`): the batch of
+313                domain groups with unimodal latent representation to decode into
+314                groups of raw data.
+315
+316        Returns:
+317            `LatentsDomainGroupsDT`: the domains' raw data.
+318        """
+319        return {
+320            domains: {
+321                name: self.domain_mods[name].decode(domain)
+322                for name, domain in latents.items()
+323            }
+324            for domains, latents in latents_domain.items()
+325        }
+326
+327    def generic_step(self, batch: RawDomainGroupsT, mode: ModelModeT) -> torch.Tensor:
+328        """
+329        The generic step used in `training_step`, `validation_step` and
+330        `test_step`.
+331
+332        Args:
+333            batch (`RawDomainGroupsT`): the batch of groups of raw unimodal data.
+334            mode (`ModelModeT`):
+335
+336        Returns:
+337            `torch.Tensor`: the loss to train on.
+338        """
+339        domain_latents = self.encode_domains(batch)
+340        batch_size = groups_batch_size(domain_latents)
+341
+342        loss_output = self.loss_mod.step(domain_latents, mode)
+343
+344        for name, metric in loss_output.all.items():
+345            self.log(
+346                f"{mode}/{name}",
+347                metric,
+348                batch_size=batch_size,
+349                add_dataloader_idx=False,
+350            )
+351
+352        return loss_output.loss
+353
+354    def validation_step(  # type: ignore
+355        self, data: RawDomainGroupT, batch_idx: int, dataloader_idx: int = 0
+356    ) -> torch.Tensor:
+357        """Validation step used by lightning"""
+358
+359        batch = {frozenset(data.keys()): data}
+360        for domain in data:
+361            batch[frozenset([domain])] = {domain: data[domain]}
+362        if dataloader_idx == 0:
+363            return self.generic_step(batch, mode="val")
+364        return self.generic_step(batch, mode="val/ood")
+365
+366    def test_step(  # type: ignore
+367        self, data: Mapping[str, Any], batch_idx: int, dataloader_idx: int = 0
+368    ) -> torch.Tensor:
+369        """Test step used by lightning"""
+370
+371        batch = {frozenset(data.keys()): data}
+372        for domain in data:
+373            batch[frozenset([domain])] = {domain: data[domain]}
+374        if dataloader_idx == 0:
+375            return self.generic_step(batch, mode="test")
+376        return self.generic_step(batch, mode="test/ood")
+377
+378    def training_step(  # type: ignore
+379        self, batch: Mapping[frozenset[str], Mapping[str, Any]], batch_idx: int
+380    ) -> torch.Tensor:
+381        """Training step used by lightning"""
+382
+383        return self.generic_step(batch, mode="train")
+384
+385    def predict_step(  # type: ignore
+386        self, data: Mapping[str, Any], batch_idx: int
+387    ) -> GWPredictionsBase:
+388        """Predict step used by lightning"""
+389
+390        batch = {frozenset(data.keys()): data}
+391        for domain in data:
+392            batch[frozenset([domain])] = {domain: data[domain]}
+393
+394        domain_latents = self.encode_domains(batch)
+395        return self.forward(domain_latents)
+396
+397    def configure_optimizers(self) -> OptimizerLRSchedulerConfig:
+398        """
+399        Configure models optimizers.
+400
+401        Here we use `AdamW` for the optimizer and `OneCycleLR` for the learning-rate
+402        scheduler.
+403        """
+404
+405        optimizer = torch.optim.AdamW(
+406            self.parameters(),
+407            lr=self.optim_lr,
+408            weight_decay=self.optim_weight_decay,
+409        )
+410
+411        lr_scheduler = OneCycleLR(optimizer, **self.scheduler_args)
+412
+413        return {
+414            "optimizer": optimizer,
+415            "lr_scheduler": {
+416                "scheduler": lr_scheduler,
+417                "interval": "step",
+418            },
+419        }
+
+ + +

Global Workspace Lightning Module.

+ +

This is the base class to build the Global Workspace.

+
+ + +
+
+ gw_mod + + +
+ + +

a GWModuleBase implementation.

+
+ + +
+
+
+ selection_mod + + +
+ + +

A SelectionBase implementation.

+
+ + +
+
+
+ loss_mod + + +
+ + +

The module that computes losses of the GW

+
+ + +
+
+
+ optim_lr + + +
+ + + + +
+
+
+ optim_weight_decay + + +
+ + + + +
+
+
+ scheduler_args + + +
+ + + + +
+
+ +
+ domain_mods: collections.abc.Mapping[str, shimmer.modules.domain.DomainModule] + + + +
+ +
130    @property
+131    def domain_mods(self) -> Mapping[str, DomainModule]:
+132        return self.gw_mod.domain_mods
+
+ + + + +
+
+ +
+ workspace_dim: int + + + +
+ +
134    @property
+135    def workspace_dim(self) -> int:
+136        """Dimension of the GW."""
+137        return self.gw_mod.workspace_dim
+
+ + +

Dimension of the GW.

+
+ + +
+
+ +
+ + def + encode_and_fuse( self, x: collections.abc.Mapping[frozenset[str], collections.abc.Mapping[str, torch.Tensor]], selection_module: shimmer.modules.selection.SelectionBase) -> dict[frozenset[str], torch.Tensor]: + + + +
+ +
139    def encode_and_fuse(
+140        self, x: LatentsDomainGroupsT, selection_module: SelectionBase
+141    ) -> dict[frozenset[str], torch.Tensor]:
+142        """
+143        Encode a group of latent representations into the GW representation.
+144
+145        Args:
+146            x (`LatentsDomainGroupsT`): the input domain representations.
+147            selection_scores (`Mapping[str, torch.Tensor]`):
+148
+149        Returns:
+150            `dict[frozenset[str], torch.Tensor]`: the GW representations.
+151        """
+152        return {
+153            domains: self.gw_mod.encode_and_fuse(latents, selection_module)
+154            for domains, latents in x.items()
+155        }
+
+ + +

Encode a group of latent representations into the GW representation.

+ +
Arguments:
+ +
    +
  • x (LatentsDomainGroupsT): the input domain representations.
  • +
  • selection_scores (Mapping[str, torch.Tensor]):
  • +
+ +
Returns:
+ +
+

dict[frozenset[str], torch.Tensor]: the GW representations.

+
+
+ + +
+
+ +
+ + def + encode( self, x: collections.abc.Mapping[frozenset[str], collections.abc.Mapping[str, torch.Tensor]]) -> dict[frozenset[str], dict[str, torch.Tensor]]: + + + +
+ +
157    def encode(self, x: LatentsDomainGroupsT) -> LatentsDomainGroupsDT:
+158        """
+159        Encode a group of latent representations into the pre-fusion GW representation.
+160
+161        Args:
+162            x (`LatentsDomainGroupsT`): the input domain representations.
+163
+164        Returns:
+165            `LatensDomainGroupsDT`: the GW representations.
+166        """
+167        return {domains: self.gw_mod.encode(latents) for domains, latents in x.items()}
+
+ + +

Encode a group of latent representations into the pre-fusion GW representation.

+ +
Arguments:
+ +
    +
  • x (LatentsDomainGroupsT): the input domain representations.
  • +
+ +
Returns:
+ +
+

LatensDomainGroupsDT: the GW representations.

+
+
+ + +
+
+ +
+ + def + fuse( self, x: collections.abc.Mapping[frozenset[str], collections.abc.Mapping[str, torch.Tensor]], selection_scores: collections.abc.Mapping[frozenset[str], collections.abc.Mapping[str, torch.Tensor]]) -> dict[frozenset[str], torch.Tensor]: + + + +
+ +
169    def fuse(
+170        self,
+171        x: LatentsDomainGroupsT,
+172        selection_scores: Mapping[frozenset[str], Mapping[str, torch.Tensor]],
+173    ) -> dict[frozenset[str], torch.Tensor]:
+174        """
+175        Fuses a group of latent representations into the GW representation.
+176
+177        Args:
+178            x (`LatentsDomainGroupsT`): the pre-fusion latent representations
+179            selection_scores (`Mapping[frozenset[str], Mapping[str, torch.Tensor]]`):
+180                selection scores for each group
+181
+182        Returns:
+183            `dict[frozenset[str], torch.Tensor]`: GW representation of each group
+184        """
+185        return {
+186            domains: self.gw_mod.fuse(latents, selection_scores[domains])
+187            for domains, latents in x.items()
+188        }
+
+ + +

Fuses a group of latent representations into the GW representation.

+ +
Arguments:
+ +
    +
  • x (LatentsDomainGroupsT): the pre-fusion latent representations
  • +
  • selection_scores (Mapping[frozenset[str], Mapping[str, torch.Tensor]]): selection scores for each group
  • +
+ +
Returns:
+ +
+

dict[frozenset[str], torch.Tensor]: GW representation of each group

+
+
+ + +
+
+ +
+ + def + decode( self, z: collections.abc.Mapping[frozenset[str], torch.Tensor], domains: collections.abc.Iterable[str] | None = None) -> dict[frozenset[str], dict[str, torch.Tensor]]: + + + +
+ +
190    def decode(
+191        self,
+192        z: Mapping[frozenset[str], torch.Tensor],
+193        domains: Iterable[str] | None = None,
+194    ) -> LatentsDomainGroupsDT:
+195        """
+196        Decode the group GW representation into given `domains`.
+197
+198        Args:
+199            z (`torch.Tensor`): the GW representation.
+200            domains (`Iterable[str]`): iterable of domains to decode.
+201
+202        Returns:
+203            `dict[str, torch.Tensor]`: the decoded unimodal representations.
+204        """
+205        return {
+206            domain_names: self.gw_mod.decode(gw_rep, domains)
+207            for domain_names, gw_rep in z.items()
+208        }
+
+ + +

Decode the group GW representation into given domains.

+ +
Arguments:
+ +
    +
  • z (torch.Tensor): the GW representation.
  • +
  • domains (Iterable[str]): iterable of domains to decode.
  • +
+ +
Returns:
+ +
+

dict[str, torch.Tensor]: the decoded unimodal representations.

+
+
+ + +
+
+ +
+ + def + batch_gw_states( self, latent_domains: collections.abc.Mapping[frozenset[str], collections.abc.Mapping[str, torch.Tensor]]) -> dict[str, torch.Tensor]: + + + +
+ +
226    def batch_gw_states(
+227        self, latent_domains: LatentsDomainGroupsT
+228    ) -> dict[str, torch.Tensor]:
+229        """
+230        Comptues GW states of a batch of groups of domains.
+231
+232        Args:
+233            latent_domains (`LatentsT`): the batch of groups of domains
+234
+235        Returns:
+236            `dict[str, torch.Tensor]`: states for each domain.
+237        """
+238        predictions: dict[str, torch.Tensor] = {}
+239        for domains, latents in latent_domains.items():
+240            if len(domains) > 1:
+241                continue
+242            domain_name = list(domains)[0]
+243            z = self.gw_mod.encode_and_fuse(
+244                latents, selection_module=self.selection_mod
+245            )
+246            predictions[domain_name] = z
+247        return predictions
+
+ + +

Comptues GW states of a batch of groups of domains.

+ +
Arguments:
+ +
    +
  • latent_domains (LatentsT): the batch of groups of domains
  • +
+ +
Returns:
+ +
+

dict[str, torch.Tensor]: states for each domain.

+
+
+ + +
+
+ +
+ + def + encode_domain(self, domain: Any, name: str) -> torch.Tensor: + + + +
+ +
249    def encode_domain(self, domain: Any, name: str) -> torch.Tensor:
+250        """
+251        Encodes a domain from the domain data into the unimodal representation.
+252
+253        This is a convenient proxy for the `DomainModule.encode` method and is
+254        equivalent to:
+255        ```python
+256        self.domain_mods[name].encode(domain)
+257        ```
+258
+259        Args:
+260            domain (`Any`): the domain data
+261            name (`str`): domain name to encode
+262
+263        Returns:
+264            `torch.Tensor`: the domain's unimodal representation.
+265        """
+266        return self.domain_mods[name].encode(domain)
+
+ + +

Encodes a domain from the domain data into the unimodal representation.

+ +

This is a convenient proxy for the DomainModule.encode method and is +equivalent to:

+ +
+
self.domain_mods[name].encode(domain)
+
+
+ +
Arguments:
+ +
    +
  • domain (Any): the domain data
  • +
  • name (str): domain name to encode
  • +
+ +
Returns:
+ +
+

torch.Tensor: the domain's unimodal representation.

+
+
+ + +
+
+ +
+ + def + encode_domains( self, batch: collections.abc.Mapping[frozenset[str], collections.abc.Mapping[str, typing.Any]]) -> dict[frozenset[str], dict[str, torch.Tensor]]: + + + +
+ +
268    def encode_domains(self, batch: RawDomainGroupsT) -> LatentsDomainGroupsDT:
+269        """
+270        Encode all domains in the batch.
+271
+272        Args:
+273            batch (`RawDomainGroupsT`): the batch of
+274                domain groups with raw unimodal data to encode into groups of latent
+275                representations.
+276
+277        Returns:
+278            `LatentsDomainGroupsDT`: the domains' unimodal representations.
+279        """
+280        return {
+281            domains: {
+282                name: self.domain_mods[name].encode(domain)
+283                for name, domain in data.items()
+284            }
+285            for domains, data in batch.items()
+286        }
+
+ + +

Encode all domains in the batch.

+ +
Arguments:
+ +
    +
  • batch (RawDomainGroupsT): the batch of +domain groups with raw unimodal data to encode into groups of latent +representations.
  • +
+ +
Returns:
+ +
+

LatentsDomainGroupsDT: the domains' unimodal representations.

+
+
+ + +
+
+ +
+ + def + decode_domain(self, domain: torch.Tensor, name: str) -> Any: + + + +
+ +
288    def decode_domain(self, domain: torch.Tensor, name: str) -> Any:
+289        """
+290        Decodes a domain from the unimodal representation into the domain data.
+291
+292        This is a convenient proxy for the `DomainModule.encode` method and is
+293        equivalent to:
+294        ```python
+295        self.domain_mods[name].decode(domain)
+296        ```
+297
+298        Args:
+299            domain (`torch.Tensor`): the domain data
+300            name (`str`): domain name to encode
+301
+302        Returns:
+303            `Any`: the domain's raw data.
+304        """
+305        return self.domain_mods[name].decode(domain)
+
+ + +

Decodes a domain from the unimodal representation into the domain data.

+ +

This is a convenient proxy for the DomainModule.encode method and is +equivalent to:

+ +
+
self.domain_mods[name].decode(domain)
+
+
+ +
Arguments:
+ +
    +
  • domain (torch.Tensor): the domain data
  • +
  • name (str): domain name to encode
  • +
+ +
Returns:
+ +
+

Any: the domain's raw data.

+
+
+ + +
+
+ +
+ + def + decode_domains( self, latents_domain: collections.abc.Mapping[frozenset[str], collections.abc.Mapping[str, torch.Tensor]]) -> dict[frozenset[str], dict[str, typing.Any]]: + + + +
+ +
307    def decode_domains(self, latents_domain: LatentsDomainGroupsT) -> RawDomainGroupsDT:
+308        """
+309        Decodes all domains in the batch.
+310
+311        Args:
+312            batch (`LatentsDomainGroupsT`): the batch of
+313                domain groups with unimodal latent representation to decode into
+314                groups of raw data.
+315
+316        Returns:
+317            `LatentsDomainGroupsDT`: the domains' raw data.
+318        """
+319        return {
+320            domains: {
+321                name: self.domain_mods[name].decode(domain)
+322                for name, domain in latents.items()
+323            }
+324            for domains, latents in latents_domain.items()
+325        }
+
+ + +

Decodes all domains in the batch.

+ +
Arguments:
+ +
    +
  • batch (LatentsDomainGroupsT): the batch of +domain groups with unimodal latent representation to decode into +groups of raw data.
  • +
+ +
Returns:
+ +
+

LatentsDomainGroupsDT: the domains' raw data.

+
+
+ + +
+
+ +
+ + def + generic_step( self, batch: collections.abc.Mapping[frozenset[str], collections.abc.Mapping[str, typing.Any]], mode: Literal['train', 'val', 'test', 'val/ood', 'test/ood']) -> torch.Tensor: + + + +
+ +
327    def generic_step(self, batch: RawDomainGroupsT, mode: ModelModeT) -> torch.Tensor:
+328        """
+329        The generic step used in `training_step`, `validation_step` and
+330        `test_step`.
+331
+332        Args:
+333            batch (`RawDomainGroupsT`): the batch of groups of raw unimodal data.
+334            mode (`ModelModeT`):
+335
+336        Returns:
+337            `torch.Tensor`: the loss to train on.
+338        """
+339        domain_latents = self.encode_domains(batch)
+340        batch_size = groups_batch_size(domain_latents)
+341
+342        loss_output = self.loss_mod.step(domain_latents, mode)
+343
+344        for name, metric in loss_output.all.items():
+345            self.log(
+346                f"{mode}/{name}",
+347                metric,
+348                batch_size=batch_size,
+349                add_dataloader_idx=False,
+350            )
+351
+352        return loss_output.loss
+
+ + +

The generic step used in training_step, validation_step and +test_step.

+ +
Arguments:
+ +
    +
  • batch (RawDomainGroupsT): the batch of groups of raw unimodal data.
  • +
  • mode (ModelModeT):
  • +
+ +
Returns:
+ +
+

torch.Tensor: the loss to train on.

+
+
+ + +
+
+
Inherited Members
+
+
lightning.pytorch.core.module.LightningModule
+
LightningModule
+
CHECKPOINT_HYPER_PARAMS_KEY
+
CHECKPOINT_HYPER_PARAMS_NAME
+
CHECKPOINT_HYPER_PARAMS_TYPE
+
optimizers
+
lr_schedulers
+
trainer
+
fabric
+
example_input_array
+
current_epoch
+
global_step
+
global_rank
+
local_rank
+
on_gpu
+
automatic_optimization
+
strict_loading
+
logger
+
loggers
+
print
+
log
+
log_dict
+
all_gather
+
forward
+
training_step
+
validation_step
+
test_step
+
predict_step
+
configure_callbacks
+
configure_optimizers
+
manual_backward
+
backward
+
toggle_optimizer
+
untoggle_optimizer
+
clip_gradients
+
configure_gradient_clipping
+
lr_scheduler_step
+
optimizer_step
+
optimizer_zero_grad
+
freeze
+
unfreeze
+
to_onnx
+
to_torchscript
+
load_from_checkpoint
+ +
+
lightning.fabric.utilities.device_dtype_mixin._DeviceDtypeModuleMixin
+
dtype
+
device
+
to
+
cuda
+
cpu
+
type
+
float
+
double
+
half
+ +
+
lightning.pytorch.core.mixins.hparams_mixin.HyperparametersMixin
+
save_hyperparameters
+
hparams
+
hparams_initial
+ +
+
lightning.pytorch.core.hooks.ModelHooks
+
on_fit_start
+
on_fit_end
+
on_train_start
+
on_train_end
+
on_validation_start
+
on_validation_end
+
on_test_start
+
on_test_end
+
on_predict_start
+
on_predict_end
+
on_train_batch_start
+
on_train_batch_end
+
on_validation_batch_start
+
on_validation_batch_end
+
on_test_batch_start
+
on_test_batch_end
+
on_predict_batch_start
+
on_predict_batch_end
+
on_validation_model_zero_grad
+
on_validation_model_eval
+
on_validation_model_train
+
on_test_model_eval
+
on_test_model_train
+
on_predict_model_eval
+
on_train_epoch_start
+
on_train_epoch_end
+
on_validation_epoch_start
+
on_validation_epoch_end
+
on_test_epoch_start
+
on_test_epoch_end
+
on_predict_epoch_start
+
on_predict_epoch_end
+
on_before_zero_grad
+
on_before_backward
+
on_after_backward
+
on_before_optimizer_step
+
configure_sharded_model
+
configure_model
+ +
+
lightning.pytorch.core.hooks.DataHooks
+
prepare_data_per_node
+
allow_zero_length_dataloader_with_multiple_devices
+
prepare_data
+
setup
+
teardown
+
train_dataloader
+
test_dataloader
+
val_dataloader
+
predict_dataloader
+
transfer_batch_to_device
+
on_before_batch_transfer
+
on_after_batch_transfer
+ +
+
lightning.pytorch.core.hooks.CheckpointHooks
+
on_load_checkpoint
+
on_save_checkpoint
+ +
+
torch.nn.modules.module.Module
+
dump_patches
+
training
+
call_super_init
+
register_buffer
+
register_parameter
+
add_module
+
register_module
+
get_submodule
+
get_parameter
+
get_buffer
+
get_extra_state
+
set_extra_state
+
apply
+
ipu
+
xpu
+
bfloat16
+
to_empty
+
register_full_backward_pre_hook
+
register_backward_hook
+
register_full_backward_hook
+
register_forward_pre_hook
+
register_forward_hook
+
register_state_dict_pre_hook
+
state_dict
+
register_load_state_dict_post_hook
+
load_state_dict
+
parameters
+
named_parameters
+
buffers
+
named_buffers
+
children
+
named_children
+
modules
+
named_modules
+
train
+
eval
+
requires_grad_
+
zero_grad
+
share_memory
+
extra_repr
+
compile
+ +
+
+
+
+
+ +
+ + def + freeze_domain_modules( domain_mods: collections.abc.Mapping[str, shimmer.modules.domain.DomainModule]) -> dict[str, shimmer.modules.domain.DomainModule]: + + + +
+ +
422def freeze_domain_modules(
+423    domain_mods: Mapping[str, DomainModule],
+424) -> dict[str, DomainModule]:
+425    """
+426    Freezes weights and set to eval mode the domain modules.
+427
+428    .. note::
+429        The output is casted as `dict[str, DomainModule]` type for better
+430        auto-completion, but is actually a torch `ModuleDict`.
+431
+432    Args:
+433        domain_mods (`Mapping[str, DomainModule]`): mapping of domain modules to freeze
+434
+435    Returns:
+436        `ModuleDict`: frozen modules.
+437    """
+438
+439    for mod in domain_mods.values():
+440        mod.freeze()
+441    # Cast for better auto-completion at the expense of ModuleDict
+442    return cast(dict[str, DomainModule], ModuleDict(domain_mods))
+
+ + +

Freezes weights and set to eval mode the domain modules.

+ +
+ +

The output is casted as dict[str, DomainModule] type for better +auto-completion, but is actually a torch ModuleDict.

+ +
+ +
Arguments:
+ +
    +
  • domain_mods (Mapping[str, DomainModule]): mapping of domain modules to freeze
  • +
+ +
Returns:
+ +
+

ModuleDict: frozen modules.

+
+
+ + +
+
+ +
+ + class + GWPredictions(builtins.dict): + + + +
+ +
445class GWPredictions(GWPredictionsBase):
+446    """TypedDict of the output given when calling `GlobalWorkspaceBase.predict`"""
+447
+448    demi_cycles: dict[str, torch.Tensor]
+449    """
+450    Demi-cycle predictions of the model for each domain. Only computed on domain
+451    groups with only one domain.
+452    """
+453
+454    cycles: dict[tuple[str, str], torch.Tensor]
+455    """
+456    Cycle predictions of the model from one domain through another one.
+457    Only computed on domain groups with more than one domain.
+458    The keys are tuple with start domain and intermediary domain.
+459    """
+460
+461    translations: dict[tuple[str, str], torch.Tensor]
+462    """
+463    Translation predictions of the model from one domain through another one.
+464
+465    Only computed on domain groups with more than one domain.
+466    The keys are tuples with start domain and target domain.
+467    """
+
+ + +

TypedDict of the output given when calling GlobalWorkspaceBase.predict

+
+ + +
+
+ demi_cycles: dict[str, torch.Tensor] + + +
+ + +

Demi-cycle predictions of the model for each domain. Only computed on domain +groups with only one domain.

+
+ + +
+
+
+ cycles: dict[tuple[str, str], torch.Tensor] + + +
+ + +

Cycle predictions of the model from one domain through another one. +Only computed on domain groups with more than one domain. +The keys are tuple with start domain and intermediary domain.

+
+ + +
+
+
+ translations: dict[tuple[str, str], torch.Tensor] + + +
+ + +

Translation predictions of the model from one domain through another one.

+ +

Only computed on domain groups with more than one domain. +The keys are tuples with start domain and target domain.

+
+ + +
+
+
+ states: dict[str, torch.Tensor] + + +
+ + + + +
+
+
Inherited Members
+
+
builtins.dict
+
get
+
setdefault
+
pop
+
popitem
+
keys
+
items
+
values
+
update
+
fromkeys
+
clear
+
copy
+ +
+
+
+
+
+ + + +
470class GlobalWorkspace2Domains(
+471    GlobalWorkspaceBase[GWModule, SingleDomainSelection, GWLosses2Domains]
+472):
+473    """
+474    A simple 2-domains max flavor of GlobalWorkspaceBase.
+475
+476    This is used to simplify a Global Workspace instanciation and only overrides the
+477    `__init__` method.
+478    """
+479
+480    def __init__(
+481        self,
+482        domain_mods: Mapping[str, DomainModule],
+483        gw_encoders: Mapping[str, Module],
+484        gw_decoders: Mapping[str, Module],
+485        workspace_dim: int,
+486        loss_coefs: LossCoefs,
+487        optim_lr: float = 1e-3,
+488        optim_weight_decay: float = 0.0,
+489        scheduler_args: SchedulerArgs | None = None,
+490        learn_logit_scale: bool = False,
+491        contrastive_loss: ContrastiveLossType | None = None,
+492    ) -> None:
+493        """
+494        Initializes a Global Workspace
+495
+496        Args:
+497            domain_mods (`Mapping[str, DomainModule]`): mapping of the domains
+498                connected to the GW. Keys are domain names, values are the
+499                `DomainModule`.
+500            gw_encoders (`Mapping[str, torch.nn.Module]`): mapping for each domain
+501                name to a `torch.nn.Module` class which role is to encode a
+502                unimodal latent representations into a GW representation (pre fusion).
+503            gw_decoders (`Mapping[str, torch.nn.Module]`): mapping for each domain
+504                name to a `torch.nn.Module` class which role is to decode a
+505                GW representation into a unimodal latent representations.
+506            workspace_dim (`int`): dimension of the GW.
+507            loss_coefs (`LossCoefs`): loss coefficients
+508            optim_lr (`float`): learning rate
+509            optim_weight_decay (`float`): weight decay
+510            scheduler_args (`SchedulerArgs | None`): optimization scheduler's arguments
+511            learn_logit_scale (`bool`): whether to learn the contrastive learning
+512                contrastive loss when using the default contrastive loss.
+513            contrastive_loss (`ContrastiveLossType | None`): a contrastive loss
+514                function used for alignment. `learn_logit_scale` will not affect custom
+515                contrastive losses.
+516        """
+517        domain_mods = freeze_domain_modules(domain_mods)
+518
+519        gw_mod = GWModule(domain_mods, workspace_dim, gw_encoders, gw_decoders)
+520        if contrastive_loss is None:
+521            contrastive_loss = ContrastiveLoss(
+522                torch.tensor([1 / 0.07]).log(), "mean", learn_logit_scale
+523            )
+524        selection_mod = SingleDomainSelection()
+525        loss_mod = GWLosses2Domains(
+526            gw_mod, selection_mod, domain_mods, loss_coefs, contrastive_loss
+527        )
+528
+529        super().__init__(
+530            gw_mod,
+531            selection_mod,
+532            loss_mod,
+533            optim_lr,
+534            optim_weight_decay,
+535            scheduler_args,
+536        )
+537
+538    def forward(  # type: ignore
+539        self,
+540        latent_domains: LatentsDomainGroupsT,
+541    ) -> GWPredictions:
+542        """
+543        Computes demi-cycles, cycles, and translations.
+544
+545        Args:
+546            latent_domains (`LatentsT`): Groups of domains for the computation.
+547
+548        Returns:
+549            `GWPredictions`: the predictions on the batch.
+550        """
+551        return GWPredictions(
+552            demi_cycles=batch_demi_cycles(
+553                self.gw_mod, self.selection_mod, latent_domains
+554            ),
+555            cycles=batch_cycles(
+556                self.gw_mod, self.selection_mod, latent_domains, self.domain_mods.keys()
+557            ),
+558            translations=batch_translations(
+559                self.gw_mod, self.selection_mod, latent_domains
+560            ),
+561            **super().forward(latent_domains),
+562        )
+
+ + +

A simple 2-domains max flavor of GlobalWorkspaceBase.

+ +

This is used to simplify a Global Workspace instanciation and only overrides the +__init__ method.

+
+ + +
+ +
+ + GlobalWorkspace2Domains( domain_mods: collections.abc.Mapping[str, shimmer.modules.domain.DomainModule], gw_encoders: collections.abc.Mapping[str, torch.nn.modules.module.Module], gw_decoders: collections.abc.Mapping[str, torch.nn.modules.module.Module], workspace_dim: int, loss_coefs: shimmer.modules.losses.LossCoefs, optim_lr: float = 0.001, optim_weight_decay: float = 0.0, scheduler_args: SchedulerArgs | None = None, learn_logit_scale: bool = False, contrastive_loss: collections.abc.Callable[[torch.Tensor, torch.Tensor], shimmer.modules.domain.LossOutput] | None = None) + + + +
+ +
480    def __init__(
+481        self,
+482        domain_mods: Mapping[str, DomainModule],
+483        gw_encoders: Mapping[str, Module],
+484        gw_decoders: Mapping[str, Module],
+485        workspace_dim: int,
+486        loss_coefs: LossCoefs,
+487        optim_lr: float = 1e-3,
+488        optim_weight_decay: float = 0.0,
+489        scheduler_args: SchedulerArgs | None = None,
+490        learn_logit_scale: bool = False,
+491        contrastive_loss: ContrastiveLossType | None = None,
+492    ) -> None:
+493        """
+494        Initializes a Global Workspace
+495
+496        Args:
+497            domain_mods (`Mapping[str, DomainModule]`): mapping of the domains
+498                connected to the GW. Keys are domain names, values are the
+499                `DomainModule`.
+500            gw_encoders (`Mapping[str, torch.nn.Module]`): mapping for each domain
+501                name to a `torch.nn.Module` class which role is to encode a
+502                unimodal latent representations into a GW representation (pre fusion).
+503            gw_decoders (`Mapping[str, torch.nn.Module]`): mapping for each domain
+504                name to a `torch.nn.Module` class which role is to decode a
+505                GW representation into a unimodal latent representations.
+506            workspace_dim (`int`): dimension of the GW.
+507            loss_coefs (`LossCoefs`): loss coefficients
+508            optim_lr (`float`): learning rate
+509            optim_weight_decay (`float`): weight decay
+510            scheduler_args (`SchedulerArgs | None`): optimization scheduler's arguments
+511            learn_logit_scale (`bool`): whether to learn the contrastive learning
+512                contrastive loss when using the default contrastive loss.
+513            contrastive_loss (`ContrastiveLossType | None`): a contrastive loss
+514                function used for alignment. `learn_logit_scale` will not affect custom
+515                contrastive losses.
+516        """
+517        domain_mods = freeze_domain_modules(domain_mods)
+518
+519        gw_mod = GWModule(domain_mods, workspace_dim, gw_encoders, gw_decoders)
+520        if contrastive_loss is None:
+521            contrastive_loss = ContrastiveLoss(
+522                torch.tensor([1 / 0.07]).log(), "mean", learn_logit_scale
+523            )
+524        selection_mod = SingleDomainSelection()
+525        loss_mod = GWLosses2Domains(
+526            gw_mod, selection_mod, domain_mods, loss_coefs, contrastive_loss
+527        )
+528
+529        super().__init__(
+530            gw_mod,
+531            selection_mod,
+532            loss_mod,
+533            optim_lr,
+534            optim_weight_decay,
+535            scheduler_args,
+536        )
+
+ + +

Initializes a Global Workspace

+ +
Arguments:
+ +
    +
  • domain_mods (Mapping[str, DomainModule]): mapping of the domains +connected to the GW. Keys are domain names, values are the +DomainModule.
  • +
  • gw_encoders (Mapping[str, torch.nn.Module]): mapping for each domain +name to a torch.nn.Module class which role is to encode a +unimodal latent representations into a GW representation (pre fusion).
  • +
  • gw_decoders (Mapping[str, torch.nn.Module]): mapping for each domain +name to a torch.nn.Module class which role is to decode a +GW representation into a unimodal latent representations.
  • +
  • workspace_dim (int): dimension of the GW.
  • +
  • loss_coefs (LossCoefs): loss coefficients
  • +
  • optim_lr (float): learning rate
  • +
  • optim_weight_decay (float): weight decay
  • +
  • scheduler_args (SchedulerArgs | None): optimization scheduler's arguments
  • +
  • learn_logit_scale (bool): whether to learn the contrastive learning +contrastive loss when using the default contrastive loss.
  • +
  • contrastive_loss (ContrastiveLossType | None): a contrastive loss +function used for alignment. learn_logit_scale will not affect custom +contrastive losses.
  • +
+
+ + +
+
+ +
+ + def + forward( self, latent_domains: collections.abc.Mapping[frozenset[str], collections.abc.Mapping[str, torch.Tensor]]) -> GWPredictions: + + + +
+ +
538    def forward(  # type: ignore
+539        self,
+540        latent_domains: LatentsDomainGroupsT,
+541    ) -> GWPredictions:
+542        """
+543        Computes demi-cycles, cycles, and translations.
+544
+545        Args:
+546            latent_domains (`LatentsT`): Groups of domains for the computation.
+547
+548        Returns:
+549            `GWPredictions`: the predictions on the batch.
+550        """
+551        return GWPredictions(
+552            demi_cycles=batch_demi_cycles(
+553                self.gw_mod, self.selection_mod, latent_domains
+554            ),
+555            cycles=batch_cycles(
+556                self.gw_mod, self.selection_mod, latent_domains, self.domain_mods.keys()
+557            ),
+558            translations=batch_translations(
+559                self.gw_mod, self.selection_mod, latent_domains
+560            ),
+561            **super().forward(latent_domains),
+562        )
+
+ + +

Computes demi-cycles, cycles, and translations.

+ +
Arguments:
+ +
    +
  • latent_domains (LatentsT): Groups of domains for the computation.
  • +
+ +
Returns:
+ +
+

GWPredictions: the predictions on the batch.

+
+
+ + +
+
+
Inherited Members
+
+ +
lightning.pytorch.core.module.LightningModule
+
CHECKPOINT_HYPER_PARAMS_KEY
+
CHECKPOINT_HYPER_PARAMS_NAME
+
CHECKPOINT_HYPER_PARAMS_TYPE
+
optimizers
+
lr_schedulers
+
trainer
+
fabric
+
example_input_array
+
current_epoch
+
global_step
+
global_rank
+
local_rank
+
on_gpu
+
automatic_optimization
+
strict_loading
+
logger
+
loggers
+
print
+
log
+
log_dict
+
all_gather
+
configure_callbacks
+
manual_backward
+
backward
+
toggle_optimizer
+
untoggle_optimizer
+
clip_gradients
+
configure_gradient_clipping
+
lr_scheduler_step
+
optimizer_step
+
optimizer_zero_grad
+
freeze
+
unfreeze
+
to_onnx
+
to_torchscript
+
load_from_checkpoint
+ +
+
lightning.fabric.utilities.device_dtype_mixin._DeviceDtypeModuleMixin
+
dtype
+
device
+
to
+
cuda
+
cpu
+
type
+
float
+
double
+
half
+ +
+
lightning.pytorch.core.mixins.hparams_mixin.HyperparametersMixin
+
save_hyperparameters
+
hparams
+
hparams_initial
+ +
+
lightning.pytorch.core.hooks.ModelHooks
+
on_fit_start
+
on_fit_end
+
on_train_start
+
on_train_end
+
on_validation_start
+
on_validation_end
+
on_test_start
+
on_test_end
+
on_predict_start
+
on_predict_end
+
on_train_batch_start
+
on_train_batch_end
+
on_validation_batch_start
+
on_validation_batch_end
+
on_test_batch_start
+
on_test_batch_end
+
on_predict_batch_start
+
on_predict_batch_end
+
on_validation_model_zero_grad
+
on_validation_model_eval
+
on_validation_model_train
+
on_test_model_eval
+
on_test_model_train
+
on_predict_model_eval
+
on_train_epoch_start
+
on_train_epoch_end
+
on_validation_epoch_start
+
on_validation_epoch_end
+
on_test_epoch_start
+
on_test_epoch_end
+
on_predict_epoch_start
+
on_predict_epoch_end
+
on_before_zero_grad
+
on_before_backward
+
on_after_backward
+
on_before_optimizer_step
+
configure_sharded_model
+
configure_model
+ +
+
lightning.pytorch.core.hooks.DataHooks
+
prepare_data_per_node
+
allow_zero_length_dataloader_with_multiple_devices
+
prepare_data
+
setup
+
teardown
+
train_dataloader
+
test_dataloader
+
val_dataloader
+
predict_dataloader
+
transfer_batch_to_device
+
on_before_batch_transfer
+
on_after_batch_transfer
+ +
+
lightning.pytorch.core.hooks.CheckpointHooks
+
on_load_checkpoint
+
on_save_checkpoint
+ +
+
torch.nn.modules.module.Module
+
dump_patches
+
training
+
call_super_init
+
register_buffer
+
register_parameter
+
add_module
+
register_module
+
get_submodule
+
get_parameter
+
get_buffer
+
get_extra_state
+
set_extra_state
+
apply
+
ipu
+
xpu
+
bfloat16
+
to_empty
+
register_full_backward_pre_hook
+
register_backward_hook
+
register_full_backward_hook
+
register_forward_pre_hook
+
register_forward_hook
+
register_state_dict_pre_hook
+
state_dict
+
register_load_state_dict_post_hook
+
load_state_dict
+
parameters
+
named_parameters
+
buffers
+
named_buffers
+
children
+
named_children
+
modules
+
named_modules
+
train
+
eval
+
requires_grad_
+
zero_grad
+
share_memory
+
extra_repr
+
compile
+ +
+
+
+
+
+ + + +
565class GlobalWorkspace(GlobalWorkspaceBase[GWModule, RandomSelection, GWLosses]):
+566    """The 2-domain fusion (with broadcast loss) flavor of GlobalWorkspaceBase.
+567
+568    This is used to simplify a Global Workspace instanciation and only overrides the
+569    `__init__` method.
+570    """
+571
+572    def __init__(
+573        self,
+574        domain_mods: Mapping[str, DomainModule],
+575        gw_encoders: Mapping[str, Module],
+576        gw_decoders: Mapping[str, Module],
+577        workspace_dim: int,
+578        loss_coefs: BroadcastLossCoefs,
+579        selection_temperature: float = 0.2,
+580        optim_lr: float = 1e-3,
+581        optim_weight_decay: float = 0.0,
+582        scheduler_args: SchedulerArgs | None = None,
+583        learn_logit_scale: bool = False,
+584        contrastive_loss: ContrastiveLossType | None = None,
+585    ) -> None:
+586        """
+587        Initializes a Global Workspace
+588
+589        Args:
+590            domain_mods (`Mapping[str, DomainModule]`): mapping of the domains
+591                connected to the GW. Keys are domain names, values are the
+592                `DomainModule`.
+593            gw_encoders (`Mapping[str, torch.nn.Module]`): mapping for each domain
+594                name to a `torch.nn.Module` class which role is to encode a
+595                unimodal latent representations into a GW representation (pre fusion).
+596            gw_decoders (`Mapping[str, torch.nn.Module]`): mapping for each domain
+597                name to a `torch.nn.Module` class which role is to decode a
+598                GW representation into a unimodal latent representations.
+599            workspace_dim (`int`): dimension of the GW.
+600            loss_coefs (`BroadcastLossCoefs`): loss coefs for the losses.
+601            selection_temperature (`float`): temperature value for the RandomSelection
+602                module.
+603            optim_lr (`float`): learning rate
+604            optim_weight_decay (`float`): weight decay
+605            scheduler_args (`SchedulerArgs | None`): optimization scheduler's arguments
+606            learn_logit_scale (`bool`): whether to learn the contrastive learning
+607                contrastive loss when using the default contrastive loss.
+608            contrastive_loss (`ContrastiveLossType | None`): a contrastive loss
+609                function used for alignment. `learn_logit_scale` will not affect custom
+610                contrastive losses.
+611        """
+612        domain_mods = freeze_domain_modules(domain_mods)
+613        gw_mod = GWModule(domain_mods, workspace_dim, gw_encoders, gw_decoders)
+614
+615        if contrastive_loss is None:
+616            contrastive_loss = ContrastiveLoss(
+617                torch.tensor([1 / 0.07]).log(), "mean", learn_logit_scale
+618            )
+619
+620        selection_mod = RandomSelection(selection_temperature)
+621        loss_mod = GWLosses(
+622            gw_mod, selection_mod, domain_mods, loss_coefs, contrastive_loss
+623        )
+624
+625        super().__init__(
+626            gw_mod,
+627            selection_mod,
+628            loss_mod,
+629            optim_lr,
+630            optim_weight_decay,
+631            scheduler_args,
+632        )
+633
+634    def forward(  # type: ignore
+635        self,
+636        latent_domains: LatentsDomainGroupsT,
+637    ) -> GWPredictions:
+638        """
+639        Computes demi-cycles, cycles, and translations.
+640
+641        Args:
+642            latent_domains (`LatentsT`): Groups of domains for the computation.
+643
+644        Returns:
+645            `GWPredictions`: the predictions on the batch.
+646        """
+647        return GWPredictions(
+648            demi_cycles=batch_demi_cycles(
+649                self.gw_mod, self.selection_mod, latent_domains
+650            ),
+651            cycles=batch_cycles(
+652                self.gw_mod, self.selection_mod, latent_domains, self.domain_mods.keys()
+653            ),
+654            translations=batch_translations(
+655                self.gw_mod, self.selection_mod, latent_domains
+656            ),
+657            # TODO: add other combinations
+658            **super().forward(latent_domains),
+659        )
+
+ + +

The 2-domain fusion (with broadcast loss) flavor of GlobalWorkspaceBase.

+ +

This is used to simplify a Global Workspace instanciation and only overrides the +__init__ method.

+
+ + +
+ +
+ + GlobalWorkspace( domain_mods: collections.abc.Mapping[str, shimmer.modules.domain.DomainModule], gw_encoders: collections.abc.Mapping[str, torch.nn.modules.module.Module], gw_decoders: collections.abc.Mapping[str, torch.nn.modules.module.Module], workspace_dim: int, loss_coefs: shimmer.modules.losses.BroadcastLossCoefs, selection_temperature: float = 0.2, optim_lr: float = 0.001, optim_weight_decay: float = 0.0, scheduler_args: SchedulerArgs | None = None, learn_logit_scale: bool = False, contrastive_loss: collections.abc.Callable[[torch.Tensor, torch.Tensor], shimmer.modules.domain.LossOutput] | None = None) + + + +
+ +
572    def __init__(
+573        self,
+574        domain_mods: Mapping[str, DomainModule],
+575        gw_encoders: Mapping[str, Module],
+576        gw_decoders: Mapping[str, Module],
+577        workspace_dim: int,
+578        loss_coefs: BroadcastLossCoefs,
+579        selection_temperature: float = 0.2,
+580        optim_lr: float = 1e-3,
+581        optim_weight_decay: float = 0.0,
+582        scheduler_args: SchedulerArgs | None = None,
+583        learn_logit_scale: bool = False,
+584        contrastive_loss: ContrastiveLossType | None = None,
+585    ) -> None:
+586        """
+587        Initializes a Global Workspace
+588
+589        Args:
+590            domain_mods (`Mapping[str, DomainModule]`): mapping of the domains
+591                connected to the GW. Keys are domain names, values are the
+592                `DomainModule`.
+593            gw_encoders (`Mapping[str, torch.nn.Module]`): mapping for each domain
+594                name to a `torch.nn.Module` class which role is to encode a
+595                unimodal latent representations into a GW representation (pre fusion).
+596            gw_decoders (`Mapping[str, torch.nn.Module]`): mapping for each domain
+597                name to a `torch.nn.Module` class which role is to decode a
+598                GW representation into a unimodal latent representations.
+599            workspace_dim (`int`): dimension of the GW.
+600            loss_coefs (`BroadcastLossCoefs`): loss coefs for the losses.
+601            selection_temperature (`float`): temperature value for the RandomSelection
+602                module.
+603            optim_lr (`float`): learning rate
+604            optim_weight_decay (`float`): weight decay
+605            scheduler_args (`SchedulerArgs | None`): optimization scheduler's arguments
+606            learn_logit_scale (`bool`): whether to learn the contrastive learning
+607                contrastive loss when using the default contrastive loss.
+608            contrastive_loss (`ContrastiveLossType | None`): a contrastive loss
+609                function used for alignment. `learn_logit_scale` will not affect custom
+610                contrastive losses.
+611        """
+612        domain_mods = freeze_domain_modules(domain_mods)
+613        gw_mod = GWModule(domain_mods, workspace_dim, gw_encoders, gw_decoders)
+614
+615        if contrastive_loss is None:
+616            contrastive_loss = ContrastiveLoss(
+617                torch.tensor([1 / 0.07]).log(), "mean", learn_logit_scale
+618            )
+619
+620        selection_mod = RandomSelection(selection_temperature)
+621        loss_mod = GWLosses(
+622            gw_mod, selection_mod, domain_mods, loss_coefs, contrastive_loss
+623        )
+624
+625        super().__init__(
+626            gw_mod,
+627            selection_mod,
+628            loss_mod,
+629            optim_lr,
+630            optim_weight_decay,
+631            scheduler_args,
+632        )
+
+ + +

Initializes a Global Workspace

+ +
Arguments:
+ +
    +
  • domain_mods (Mapping[str, DomainModule]): mapping of the domains +connected to the GW. Keys are domain names, values are the +DomainModule.
  • +
  • gw_encoders (Mapping[str, torch.nn.Module]): mapping for each domain +name to a torch.nn.Module class which role is to encode a +unimodal latent representations into a GW representation (pre fusion).
  • +
  • gw_decoders (Mapping[str, torch.nn.Module]): mapping for each domain +name to a torch.nn.Module class which role is to decode a +GW representation into a unimodal latent representations.
  • +
  • workspace_dim (int): dimension of the GW.
  • +
  • loss_coefs (BroadcastLossCoefs): loss coefs for the losses.
  • +
  • selection_temperature (float): temperature value for the RandomSelection +module.
  • +
  • optim_lr (float): learning rate
  • +
  • optim_weight_decay (float): weight decay
  • +
  • scheduler_args (SchedulerArgs | None): optimization scheduler's arguments
  • +
  • learn_logit_scale (bool): whether to learn the contrastive learning +contrastive loss when using the default contrastive loss.
  • +
  • contrastive_loss (ContrastiveLossType | None): a contrastive loss +function used for alignment. learn_logit_scale will not affect custom +contrastive losses.
  • +
+
+ + +
+
+ +
+ + def + forward( self, latent_domains: collections.abc.Mapping[frozenset[str], collections.abc.Mapping[str, torch.Tensor]]) -> GWPredictions: + + + +
+ +
634    def forward(  # type: ignore
+635        self,
+636        latent_domains: LatentsDomainGroupsT,
+637    ) -> GWPredictions:
+638        """
+639        Computes demi-cycles, cycles, and translations.
+640
+641        Args:
+642            latent_domains (`LatentsT`): Groups of domains for the computation.
+643
+644        Returns:
+645            `GWPredictions`: the predictions on the batch.
+646        """
+647        return GWPredictions(
+648            demi_cycles=batch_demi_cycles(
+649                self.gw_mod, self.selection_mod, latent_domains
+650            ),
+651            cycles=batch_cycles(
+652                self.gw_mod, self.selection_mod, latent_domains, self.domain_mods.keys()
+653            ),
+654            translations=batch_translations(
+655                self.gw_mod, self.selection_mod, latent_domains
+656            ),
+657            # TODO: add other combinations
+658            **super().forward(latent_domains),
+659        )
+
+ + +

Computes demi-cycles, cycles, and translations.

+ +
Arguments:
+ +
    +
  • latent_domains (LatentsT): Groups of domains for the computation.
  • +
+ +
Returns:
+ +
+

GWPredictions: the predictions on the batch.

+
+
+ + +
+
+
Inherited Members
+
+ +
lightning.pytorch.core.module.LightningModule
+
CHECKPOINT_HYPER_PARAMS_KEY
+
CHECKPOINT_HYPER_PARAMS_NAME
+
CHECKPOINT_HYPER_PARAMS_TYPE
+
optimizers
+
lr_schedulers
+
trainer
+
fabric
+
example_input_array
+
current_epoch
+
global_step
+
global_rank
+
local_rank
+
on_gpu
+
automatic_optimization
+
strict_loading
+
logger
+
loggers
+
print
+
log
+
log_dict
+
all_gather
+
configure_callbacks
+
manual_backward
+
backward
+
toggle_optimizer
+
untoggle_optimizer
+
clip_gradients
+
configure_gradient_clipping
+
lr_scheduler_step
+
optimizer_step
+
optimizer_zero_grad
+
freeze
+
unfreeze
+
to_onnx
+
to_torchscript
+
load_from_checkpoint
+ +
+
lightning.fabric.utilities.device_dtype_mixin._DeviceDtypeModuleMixin
+
dtype
+
device
+
to
+
cuda
+
cpu
+
type
+
float
+
double
+
half
+ +
+
lightning.pytorch.core.mixins.hparams_mixin.HyperparametersMixin
+
save_hyperparameters
+
hparams
+
hparams_initial
+ +
+
lightning.pytorch.core.hooks.ModelHooks
+
on_fit_start
+
on_fit_end
+
on_train_start
+
on_train_end
+
on_validation_start
+
on_validation_end
+
on_test_start
+
on_test_end
+
on_predict_start
+
on_predict_end
+
on_train_batch_start
+
on_train_batch_end
+
on_validation_batch_start
+
on_validation_batch_end
+
on_test_batch_start
+
on_test_batch_end
+
on_predict_batch_start
+
on_predict_batch_end
+
on_validation_model_zero_grad
+
on_validation_model_eval
+
on_validation_model_train
+
on_test_model_eval
+
on_test_model_train
+
on_predict_model_eval
+
on_train_epoch_start
+
on_train_epoch_end
+
on_validation_epoch_start
+
on_validation_epoch_end
+
on_test_epoch_start
+
on_test_epoch_end
+
on_predict_epoch_start
+
on_predict_epoch_end
+
on_before_zero_grad
+
on_before_backward
+
on_after_backward
+
on_before_optimizer_step
+
configure_sharded_model
+
configure_model
+ +
+
lightning.pytorch.core.hooks.DataHooks
+
prepare_data_per_node
+
allow_zero_length_dataloader_with_multiple_devices
+
prepare_data
+
setup
+
teardown
+
train_dataloader
+
test_dataloader
+
val_dataloader
+
predict_dataloader
+
transfer_batch_to_device
+
on_before_batch_transfer
+
on_after_batch_transfer
+ +
+
lightning.pytorch.core.hooks.CheckpointHooks
+
on_load_checkpoint
+
on_save_checkpoint
+ +
+
torch.nn.modules.module.Module
+
dump_patches
+
training
+
call_super_init
+
register_buffer
+
register_parameter
+
add_module
+
register_module
+
get_submodule
+
get_parameter
+
get_buffer
+
get_extra_state
+
set_extra_state
+
apply
+
ipu
+
xpu
+
bfloat16
+
to_empty
+
register_full_backward_pre_hook
+
register_backward_hook
+
register_full_backward_hook
+
register_forward_pre_hook
+
register_forward_hook
+
register_state_dict_pre_hook
+
state_dict
+
register_load_state_dict_post_hook
+
load_state_dict
+
parameters
+
named_parameters
+
buffers
+
named_buffers
+
children
+
named_children
+
modules
+
named_modules
+
train
+
eval
+
requires_grad_
+
zero_grad
+
share_memory
+
extra_repr
+
compile
+ +
+
+
+
+
+ + + +
662class GlobalWorkspaceBayesian(
+663    GlobalWorkspaceBase[GWModuleBayesian, FixedSharedSelection, GWLossesBayesian]
+664):
+665    """
+666    A simple 2-domains max GlobalWorkspaceBase with a Bayesian base uncertainty
+667    prediction.
+668
+669    This is used to simplify a Global Workspace instanciation and only overrides the
+670    `__init__` method.
+671    """
+672
+673    def __init__(
+674        self,
+675        domain_mods: Mapping[str, DomainModule],
+676        gw_encoders: Mapping[str, Module],
+677        gw_decoders: Mapping[str, Module],
+678        workspace_dim: int,
+679        loss_coefs: BroadcastLossCoefs,
+680        sensitivity_selection: float = 1,
+681        sensitivity_precision: float = 1,
+682        optim_lr: float = 1e-3,
+683        optim_weight_decay: float = 0.0,
+684        scheduler_args: SchedulerArgs | None = None,
+685        learn_logit_scale: bool = False,
+686        use_normalized_constrastive: bool = True,
+687        contrastive_loss: ContrastiveLossType | None = None,
+688        precision_softmax_temp: float = 0.01,
+689    ) -> None:
+690        """
+691        Initializes a Global Workspace
+692
+693        Args:
+694            domain_mods (`Mapping[str, DomainModule]`): mapping of the domains
+695                connected to the GW. Keys are domain names, values are the
+696                `DomainModule`.
+697            gw_encoders (`Mapping[str, torch.nn.Module]`): mapping for each domain
+698                name to a `torch.nn.Module` class which role is to encode a
+699                unimodal latent representations into a GW representation (pre fusion).
+700            gw_decoders (`Mapping[str, torch.nn.Module]`): mapping for each domain
+701                name to a `torch.nn.Module` class which role is to decode a
+702                GW representation into a unimodal latent representations.
+703            workspace_dim (`int`): dimension of the GW.
+704            loss_coefs (`LossCoefs`): loss coefficients
+705            sensitivity_selection (`float`): sensivity coef $c'_1$
+706            sensitivity_precision (`float`): sensitivity coef $c'_2$
+707            optim_lr (`float`): learning rate
+708            optim_weight_decay (`float`): weight decay
+709            scheduler_args (`SchedulerArgs | None`): optimization scheduler's arguments
+710            learn_logit_scale (`bool`): whether to learn the contrastive learning
+711                contrastive loss when using the default contrastive loss.
+712            use_normalized_constrastive (`bool`): whether to use the normalized cont
+713                loss by the precision coefs
+714            contrastive_loss (`ContrastiveLossType | None`): a contrastive loss
+715                function used for alignment. `learn_logit_scale` will not affect custom
+716                contrastive losses.
+717            precision_softmax_temp (`float`): temperature to use in softmax of
+718                precision
+719        """
+720        domain_mods = freeze_domain_modules(domain_mods)
+721
+722        gw_mod = GWModuleBayesian(
+723            domain_mods,
+724            workspace_dim,
+725            gw_encoders,
+726            gw_decoders,
+727            sensitivity_selection,
+728            sensitivity_precision,
+729            precision_softmax_temp,
+730        )
+731
+732        selection_mod = FixedSharedSelection()
+733
+734        contrastive_loss = ContrastiveLoss(
+735            torch.tensor([1]).log(), "mean", learn_logit_scale
+736        )
+737
+738        loss_mod = GWLossesBayesian(
+739            gw_mod,
+740            selection_mod,
+741            domain_mods,
+742            loss_coefs,
+743            contrastive_loss,
+744            use_normalized_constrastive,
+745        )
+746
+747        super().__init__(
+748            gw_mod,
+749            selection_mod,
+750            loss_mod,
+751            optim_lr,
+752            optim_weight_decay,
+753            scheduler_args,
+754        )
+755
+756    def forward(  # type: ignore
+757        self,
+758        latent_domains: LatentsDomainGroupsT,
+759    ) -> GWPredictions:
+760        """
+761        Computes demi-cycles, cycles, and translations.
+762
+763        Args:
+764            latent_domains (`LatentsT`): Groups of domains for the computation.
+765
+766        Returns:
+767            `GWPredictions`: the predictions on the batch.
+768        """
+769        return GWPredictions(
+770            demi_cycles=batch_demi_cycles(
+771                self.gw_mod, self.selection_mod, latent_domains
+772            ),
+773            cycles=batch_cycles(
+774                self.gw_mod, self.selection_mod, latent_domains, self.domain_mods.keys()
+775            ),
+776            translations=batch_translations(
+777                self.gw_mod, self.selection_mod, latent_domains
+778            ),
+779            **super().forward(latent_domains),
+780        )
+
+ + +

A simple 2-domains max GlobalWorkspaceBase with a Bayesian base uncertainty +prediction.

+ +

This is used to simplify a Global Workspace instanciation and only overrides the +__init__ method.

+
+ + +
+ +
+ + GlobalWorkspaceBayesian( domain_mods: collections.abc.Mapping[str, shimmer.modules.domain.DomainModule], gw_encoders: collections.abc.Mapping[str, torch.nn.modules.module.Module], gw_decoders: collections.abc.Mapping[str, torch.nn.modules.module.Module], workspace_dim: int, loss_coefs: shimmer.modules.losses.BroadcastLossCoefs, sensitivity_selection: float = 1, sensitivity_precision: float = 1, optim_lr: float = 0.001, optim_weight_decay: float = 0.0, scheduler_args: SchedulerArgs | None = None, learn_logit_scale: bool = False, use_normalized_constrastive: bool = True, contrastive_loss: collections.abc.Callable[[torch.Tensor, torch.Tensor], shimmer.modules.domain.LossOutput] | None = None, precision_softmax_temp: float = 0.01) + + + +
+ +
673    def __init__(
+674        self,
+675        domain_mods: Mapping[str, DomainModule],
+676        gw_encoders: Mapping[str, Module],
+677        gw_decoders: Mapping[str, Module],
+678        workspace_dim: int,
+679        loss_coefs: BroadcastLossCoefs,
+680        sensitivity_selection: float = 1,
+681        sensitivity_precision: float = 1,
+682        optim_lr: float = 1e-3,
+683        optim_weight_decay: float = 0.0,
+684        scheduler_args: SchedulerArgs | None = None,
+685        learn_logit_scale: bool = False,
+686        use_normalized_constrastive: bool = True,
+687        contrastive_loss: ContrastiveLossType | None = None,
+688        precision_softmax_temp: float = 0.01,
+689    ) -> None:
+690        """
+691        Initializes a Global Workspace
+692
+693        Args:
+694            domain_mods (`Mapping[str, DomainModule]`): mapping of the domains
+695                connected to the GW. Keys are domain names, values are the
+696                `DomainModule`.
+697            gw_encoders (`Mapping[str, torch.nn.Module]`): mapping for each domain
+698                name to a `torch.nn.Module` class which role is to encode a
+699                unimodal latent representations into a GW representation (pre fusion).
+700            gw_decoders (`Mapping[str, torch.nn.Module]`): mapping for each domain
+701                name to a `torch.nn.Module` class which role is to decode a
+702                GW representation into a unimodal latent representations.
+703            workspace_dim (`int`): dimension of the GW.
+704            loss_coefs (`LossCoefs`): loss coefficients
+705            sensitivity_selection (`float`): sensivity coef $c'_1$
+706            sensitivity_precision (`float`): sensitivity coef $c'_2$
+707            optim_lr (`float`): learning rate
+708            optim_weight_decay (`float`): weight decay
+709            scheduler_args (`SchedulerArgs | None`): optimization scheduler's arguments
+710            learn_logit_scale (`bool`): whether to learn the contrastive learning
+711                contrastive loss when using the default contrastive loss.
+712            use_normalized_constrastive (`bool`): whether to use the normalized cont
+713                loss by the precision coefs
+714            contrastive_loss (`ContrastiveLossType | None`): a contrastive loss
+715                function used for alignment. `learn_logit_scale` will not affect custom
+716                contrastive losses.
+717            precision_softmax_temp (`float`): temperature to use in softmax of
+718                precision
+719        """
+720        domain_mods = freeze_domain_modules(domain_mods)
+721
+722        gw_mod = GWModuleBayesian(
+723            domain_mods,
+724            workspace_dim,
+725            gw_encoders,
+726            gw_decoders,
+727            sensitivity_selection,
+728            sensitivity_precision,
+729            precision_softmax_temp,
+730        )
+731
+732        selection_mod = FixedSharedSelection()
+733
+734        contrastive_loss = ContrastiveLoss(
+735            torch.tensor([1]).log(), "mean", learn_logit_scale
+736        )
+737
+738        loss_mod = GWLossesBayesian(
+739            gw_mod,
+740            selection_mod,
+741            domain_mods,
+742            loss_coefs,
+743            contrastive_loss,
+744            use_normalized_constrastive,
+745        )
+746
+747        super().__init__(
+748            gw_mod,
+749            selection_mod,
+750            loss_mod,
+751            optim_lr,
+752            optim_weight_decay,
+753            scheduler_args,
+754        )
+
+ + +

Initializes a Global Workspace

+ +
Arguments:
+ +
    +
  • domain_mods (Mapping[str, DomainModule]): mapping of the domains +connected to the GW. Keys are domain names, values are the +DomainModule.
  • +
  • gw_encoders (Mapping[str, torch.nn.Module]): mapping for each domain +name to a torch.nn.Module class which role is to encode a +unimodal latent representations into a GW representation (pre fusion).
  • +
  • gw_decoders (Mapping[str, torch.nn.Module]): mapping for each domain +name to a torch.nn.Module class which role is to decode a +GW representation into a unimodal latent representations.
  • +
  • workspace_dim (int): dimension of the GW.
  • +
  • loss_coefs (LossCoefs): loss coefficients
  • +
  • sensitivity_selection (float): sensivity coef $c'_1$
  • +
  • sensitivity_precision (float): sensitivity coef $c'_2$
  • +
  • optim_lr (float): learning rate
  • +
  • optim_weight_decay (float): weight decay
  • +
  • scheduler_args (SchedulerArgs | None): optimization scheduler's arguments
  • +
  • learn_logit_scale (bool): whether to learn the contrastive learning +contrastive loss when using the default contrastive loss.
  • +
  • use_normalized_constrastive (bool): whether to use the normalized cont +loss by the precision coefs
  • +
  • contrastive_loss (ContrastiveLossType | None): a contrastive loss +function used for alignment. learn_logit_scale will not affect custom +contrastive losses.
  • +
  • precision_softmax_temp (float): temperature to use in softmax of +precision
  • +
+
+ + +
+
+ +
+ + def + forward( self, latent_domains: collections.abc.Mapping[frozenset[str], collections.abc.Mapping[str, torch.Tensor]]) -> GWPredictions: + + + +
+ +
756    def forward(  # type: ignore
+757        self,
+758        latent_domains: LatentsDomainGroupsT,
+759    ) -> GWPredictions:
+760        """
+761        Computes demi-cycles, cycles, and translations.
+762
+763        Args:
+764            latent_domains (`LatentsT`): Groups of domains for the computation.
+765
+766        Returns:
+767            `GWPredictions`: the predictions on the batch.
+768        """
+769        return GWPredictions(
+770            demi_cycles=batch_demi_cycles(
+771                self.gw_mod, self.selection_mod, latent_domains
+772            ),
+773            cycles=batch_cycles(
+774                self.gw_mod, self.selection_mod, latent_domains, self.domain_mods.keys()
+775            ),
+776            translations=batch_translations(
+777                self.gw_mod, self.selection_mod, latent_domains
+778            ),
+779            **super().forward(latent_domains),
+780        )
+
+ + +

Computes demi-cycles, cycles, and translations.

+ +
Arguments:
+ +
    +
  • latent_domains (LatentsT): Groups of domains for the computation.
  • +
+ +
Returns:
+ +
+

GWPredictions: the predictions on the batch.

+
+
+ + +
+
+
Inherited Members
+
+ +
lightning.pytorch.core.module.LightningModule
+
CHECKPOINT_HYPER_PARAMS_KEY
+
CHECKPOINT_HYPER_PARAMS_NAME
+
CHECKPOINT_HYPER_PARAMS_TYPE
+
optimizers
+
lr_schedulers
+
trainer
+
fabric
+
example_input_array
+
current_epoch
+
global_step
+
global_rank
+
local_rank
+
on_gpu
+
automatic_optimization
+
strict_loading
+
logger
+
loggers
+
print
+
log
+
log_dict
+
all_gather
+
configure_callbacks
+
manual_backward
+
backward
+
toggle_optimizer
+
untoggle_optimizer
+
clip_gradients
+
configure_gradient_clipping
+
lr_scheduler_step
+
optimizer_step
+
optimizer_zero_grad
+
freeze
+
unfreeze
+
to_onnx
+
to_torchscript
+
load_from_checkpoint
+ +
+
lightning.fabric.utilities.device_dtype_mixin._DeviceDtypeModuleMixin
+
dtype
+
device
+
to
+
cuda
+
cpu
+
type
+
float
+
double
+
half
+ +
+
lightning.pytorch.core.mixins.hparams_mixin.HyperparametersMixin
+
save_hyperparameters
+
hparams
+
hparams_initial
+ +
+
lightning.pytorch.core.hooks.ModelHooks
+
on_fit_start
+
on_fit_end
+
on_train_start
+
on_train_end
+
on_validation_start
+
on_validation_end
+
on_test_start
+
on_test_end
+
on_predict_start
+
on_predict_end
+
on_train_batch_start
+
on_train_batch_end
+
on_validation_batch_start
+
on_validation_batch_end
+
on_test_batch_start
+
on_test_batch_end
+
on_predict_batch_start
+
on_predict_batch_end
+
on_validation_model_zero_grad
+
on_validation_model_eval
+
on_validation_model_train
+
on_test_model_eval
+
on_test_model_train
+
on_predict_model_eval
+
on_train_epoch_start
+
on_train_epoch_end
+
on_validation_epoch_start
+
on_validation_epoch_end
+
on_test_epoch_start
+
on_test_epoch_end
+
on_predict_epoch_start
+
on_predict_epoch_end
+
on_before_zero_grad
+
on_before_backward
+
on_after_backward
+
on_before_optimizer_step
+
configure_sharded_model
+
configure_model
+ +
+
lightning.pytorch.core.hooks.DataHooks
+
prepare_data_per_node
+
allow_zero_length_dataloader_with_multiple_devices
+
prepare_data
+
setup
+
teardown
+
train_dataloader
+
test_dataloader
+
val_dataloader
+
predict_dataloader
+
transfer_batch_to_device
+
on_before_batch_transfer
+
on_after_batch_transfer
+ +
+
lightning.pytorch.core.hooks.CheckpointHooks
+
on_load_checkpoint
+
on_save_checkpoint
+ +
+
torch.nn.modules.module.Module
+
dump_patches
+
training
+
call_super_init
+
register_buffer
+
register_parameter
+
add_module
+
register_module
+
get_submodule
+
get_parameter
+
get_buffer
+
get_extra_state
+
set_extra_state
+
apply
+
ipu
+
xpu
+
bfloat16
+
to_empty
+
register_full_backward_pre_hook
+
register_backward_hook
+
register_full_backward_hook
+
register_forward_pre_hook
+
register_forward_hook
+
register_state_dict_pre_hook
+
state_dict
+
register_load_state_dict_post_hook
+
load_state_dict
+
parameters
+
named_parameters
+
buffers
+
named_buffers
+
children
+
named_children
+
modules
+
named_modules
+
train
+
eval
+
requires_grad_
+
zero_grad
+
share_memory
+
extra_repr
+
compile
+ +
+
+
+
+
+ +
+ + def + pretrained_global_workspace( checkpoint_path: str | pathlib.Path, domain_mods: collections.abc.Mapping[str, shimmer.modules.domain.DomainModule], gw_encoders: collections.abc.Mapping[str, torch.nn.modules.module.Module], gw_decoders: collections.abc.Mapping[str, torch.nn.modules.module.Module], workspace_dim: int, loss_coefs: shimmer.modules.losses.LossCoefs, contrastive_fn: collections.abc.Callable[[torch.Tensor, torch.Tensor], shimmer.modules.domain.LossOutput], **kwargs) -> GlobalWorkspace2Domains: + + + +
+ +
783def pretrained_global_workspace(
+784    checkpoint_path: str | Path,
+785    domain_mods: Mapping[str, DomainModule],
+786    gw_encoders: Mapping[str, Module],
+787    gw_decoders: Mapping[str, Module],
+788    workspace_dim: int,
+789    loss_coefs: LossCoefs,
+790    contrastive_fn: ContrastiveLossType,
+791    **kwargs,
+792) -> GlobalWorkspace2Domains:
+793    """
+794    Load a `GlobalWorkspace` flavor of `GlobalWorkspaceBase` from a checkpoint.
+795
+796    Args:
+797        checkpoint_path (`str | Path`): path to checkpoint
+798        domain_mods (`Mapping[str, DomainModule]`): mapping of the domains
+799            connected to the GW. Keys are domain names, values are the
+800            `DomainModule`.
+801        gw_encoders (`Mapping[str, torch.nn.Module]`): mapping for each domain
+802            name to a `torch.nn.Module` class which role is to encode a
+803            unimodal latent representations into a GW representation (pre fusion).
+804        gw_decoders (`Mapping[str, torch.nn.Module]`): mapping for each domain
+805            name to a `torch.nn.Module` class which role is to decode a
+806            GW representation into a unimodal latent representations.
+807        workspace_dim (`int`): dimension of the GW.
+808        loss_coefs (`LossCoefs`): loss coefficients
+809        contrastive_loss (`ContrastiveLossType`): a contrastive loss
+810            function used for alignment. `learn_logit_scale` will not affect custom
+811            contrastive losses.
+812        **kwargs: additional arguments to pass to
+813            `GlobalWorkspace.load_from_checkpoint`.
+814
+815    Returns:
+816        `GlobalWorkspace`: the pretrained `GlobalWorkspace`.
+817
+818    Raises:
+819        `TypeError`: if loaded type is not `GlobalWorkspace`.
+820    """
+821    domain_mods = freeze_domain_modules(domain_mods)
+822    gw_mod = GWModule(domain_mods, workspace_dim, gw_encoders, gw_decoders)
+823    selection_mod = SingleDomainSelection()
+824    loss_mod = GWLosses2Domains(
+825        gw_mod, selection_mod, domain_mods, loss_coefs, contrastive_fn
+826    )
+827
+828    gw = GlobalWorkspace2Domains.load_from_checkpoint(
+829        checkpoint_path,
+830        gw_mod=gw_mod,
+831        selection_mid=selection_mod,
+832        loss_coefs=loss_coefs,
+833        loss_mod=loss_mod,
+834        **kwargs,
+835    )
+836    if not isinstance(gw, GlobalWorkspace2Domains):
+837        raise TypeError("model should be of type GlobalWorkspace")
+838    return gw
+
+ + +

Load a GlobalWorkspace flavor of GlobalWorkspaceBase from a checkpoint.

+ +
Arguments:
+ +
    +
  • checkpoint_path (str | Path): path to checkpoint
  • +
  • domain_mods (Mapping[str, DomainModule]): mapping of the domains +connected to the GW. Keys are domain names, values are the +DomainModule.
  • +
  • gw_encoders (Mapping[str, torch.nn.Module]): mapping for each domain +name to a torch.nn.Module class which role is to encode a +unimodal latent representations into a GW representation (pre fusion).
  • +
  • gw_decoders (Mapping[str, torch.nn.Module]): mapping for each domain +name to a torch.nn.Module class which role is to decode a +GW representation into a unimodal latent representations.
  • +
  • workspace_dim (int): dimension of the GW.
  • +
  • loss_coefs (LossCoefs): loss coefficients
  • +
  • contrastive_loss (ContrastiveLossType): a contrastive loss +function used for alignment. learn_logit_scale will not affect custom +contrastive losses.
  • +
  • **kwargs: additional arguments to pass to +GlobalWorkspace.load_from_checkpoint.
  • +
+ +
Returns:
+ +
+

GlobalWorkspace: the pretrained GlobalWorkspace.

+
+ +
Raises:
+ + +
+ + +
+
+ + \ No newline at end of file diff --git a/docs/api/v0.5.1/shimmer/modules/gw_module.html b/docs/api/v0.5.1/shimmer/modules/gw_module.html new file mode 100644 index 00000000..bf539184 --- /dev/null +++ b/docs/api/v0.5.1/shimmer/modules/gw_module.html @@ -0,0 +1,2893 @@ + + + + + + + shimmer.modules.gw_module API documentation + + + + + + + + + + + + + +
+
+

+shimmer.modules.gw_module

+ + + + + + +
  1from abc import ABC, abstractmethod
+  2from collections.abc import Iterable, Mapping
+  3from typing import cast
+  4
+  5import torch
+  6from torch import nn
+  7
+  8from shimmer.modules.domain import DomainModule
+  9from shimmer.modules.selection import SelectionBase
+ 10from shimmer.types import LatentsDomainGroupDT, LatentsDomainGroupT
+ 11
+ 12
+ 13def get_n_layers(n_layers: int, hidden_dim: int) -> list[nn.Module]:
+ 14    """
+ 15    Makes a list of `n_layers` `nn.Linear` layers with `nn.ReLU`.
+ 16
+ 17    Args:
+ 18        n_layers (`int`): number of layers
+ 19        hidden_dim (`int`): size of the hidden dimension
+ 20
+ 21    Returns:
+ 22        `list[nn.Module]`: list of linear and relu layers.
+ 23    """
+ 24    layers: list[nn.Module] = []
+ 25    for _ in range(n_layers):
+ 26        layers.extend([nn.Linear(hidden_dim, hidden_dim), nn.ReLU()])
+ 27    return layers
+ 28
+ 29
+ 30class GWDecoder(nn.Sequential):
+ 31    """A Decoder network for GWModules."""
+ 32
+ 33    def __init__(
+ 34        self,
+ 35        in_dim: int,
+ 36        hidden_dim: int,
+ 37        out_dim: int,
+ 38        n_layers: int,
+ 39    ):
+ 40        """
+ 41        Initializes the decoder.
+ 42
+ 43        Args:
+ 44            in_dim (`int`): input dimension
+ 45            hidden_dim (`int`): hidden dimension
+ 46            out_dim (`int`): output dimension
+ 47            n_layers (`int`): number of hidden layers. The total number of layers
+ 48                will be `n_layers` + 2 (one before, one after).
+ 49        """
+ 50
+ 51        self.in_dim = in_dim
+ 52        """input dimension"""
+ 53
+ 54        self.hidden_dim = hidden_dim
+ 55        """hidden dimension"""
+ 56
+ 57        self.out_dim = out_dim
+ 58        """output dimension"""
+ 59
+ 60        self.n_layers = n_layers
+ 61        """
+ 62        number of hidden layers. The total number of layers
+ 63                will be `n_layers` + 2 (one before, one after)."""
+ 64
+ 65        super().__init__(
+ 66            nn.Linear(self.in_dim, self.hidden_dim),
+ 67            nn.ReLU(),
+ 68            *get_n_layers(n_layers, self.hidden_dim),
+ 69            nn.Linear(self.hidden_dim, self.out_dim),
+ 70        )
+ 71
+ 72
+ 73class GWEncoder(GWDecoder):
+ 74    """
+ 75    An Encoder network used in GWModules.
+ 76
+ 77    This is similar to the decoder, but adds a tanh non-linearity at the end.
+ 78    """
+ 79
+ 80    def __init__(
+ 81        self,
+ 82        in_dim: int,
+ 83        hidden_dim: int,
+ 84        out_dim: int,
+ 85        n_layers: int,
+ 86    ):
+ 87        """
+ 88        Initializes the encoder.
+ 89
+ 90        Args:
+ 91            in_dim (`int`): input dimension
+ 92            hidden_dim (`int`): hidden dimension
+ 93            out_dim (`int`): output dimension
+ 94            n_layers (`int`): number of hidden layers. The total number of layers
+ 95                will be `n_layers` + 2 (one before, one after).
+ 96        """
+ 97        super().__init__(in_dim, hidden_dim, out_dim, n_layers)
+ 98
+ 99    def forward(self, input: torch.Tensor) -> torch.Tensor:
+100        return super().forward(input)
+101
+102
+103class GWEncoderLinear(nn.Linear):
+104    """A linear Encoder network used in GWModules."""
+105
+106    def forward(self, input: torch.Tensor) -> torch.Tensor:
+107        return torch.tanh(super().forward(input))
+108
+109
+110class GWModuleBase(nn.Module, ABC):
+111    """
+112    Base class for GWModule.
+113
+114    GWModule handles encoding, decoding the unimodal representations
+115    using the `gw_encoders` and`gw_decoders`, and define
+116    some common operations in GW like cycles and translations.
+117
+118    This is an abstract class and should be implemented.
+119    For an implemented interface, see `GWModule`.
+120    """
+121
+122    def __init__(
+123        self,
+124        domain_mods: Mapping[str, DomainModule],
+125        workspace_dim: int,
+126        *args,
+127        **kwargs,
+128    ) -> None:
+129        """
+130        Initializes the GWModule.
+131
+132        Args:
+133            domain_modules (`Mapping[str, DomainModule]`): the domain modules.
+134            workspace_dim (`int`): dimension of the GW.
+135        """
+136        super().__init__()
+137
+138        self.domain_mods = domain_mods
+139        """The unimodal domain modules."""
+140
+141        self.workspace_dim = workspace_dim
+142        """Dimension of the GW"""
+143
+144    @abstractmethod
+145    def fuse(
+146        self, x: LatentsDomainGroupT, selection_scores: Mapping[str, torch.Tensor]
+147    ) -> torch.Tensor:
+148        """
+149        Merge function used to combine domains.
+150
+151        Args:
+152            x (`LatentsDomainGroupT`): the group of latent representation.
+153            selection_score (`Mapping[str, torch.Tensor]`): attention scores to
+154                use to encode the reprensetation.
+155        Returns:
+156            `torch.Tensor`: The merged representation.
+157        """
+158        ...
+159
+160    @abstractmethod
+161    def encode(self, x: LatentsDomainGroupT) -> LatentsDomainGroupDT:
+162        """
+163        Encode the latent representation infos to the pre-fusion GW representation.
+164
+165        Args:
+166            x (`LatentsDomainGroupT`): the input domain representations
+167
+168        Returns:
+169            `LatentsDomainGroupT`: pre-fusion GW representations
+170        """
+171        ...
+172
+173    def encode_and_fuse(
+174        self, x: LatentsDomainGroupT, selection_module: SelectionBase
+175    ) -> torch.Tensor:
+176        """
+177        Encode the latent representation infos to the final GW representation.
+178        It combines the encode and fuse methods.
+179
+180        Args:
+181            x (`LatentsDomainGroupT`): the input domain representations
+182            selection_score (`Mapping[str, torch.Tensor]`): attention scores to
+183                use to encode the reprensetation.
+184
+185        Returns:
+186            `torch.Tensor`: The merged representation.
+187        """
+188        encodings = self.encode(x)
+189        selection_scores = selection_module(x, encodings)
+190        return self.fuse(encodings, selection_scores)
+191
+192    @abstractmethod
+193    def decode(
+194        self, z: torch.Tensor, domains: Iterable[str] | None = None
+195    ) -> LatentsDomainGroupDT:
+196        """
+197        Decode the GW representation into given `domains`.
+198
+199        Args:
+200            z (`torch.Tensor`): the GW representation.
+201            domains (`Iterable[str]`): iterable of domains to decode.
+202
+203        Returns:
+204            `LatentsDomainGroupDT`: the decoded unimodal representations.
+205        """
+206        ...
+207
+208
+209class GWModule(GWModuleBase):
+210    """GW nn.Module. Implements `GWModuleBase`."""
+211
+212    def __init__(
+213        self,
+214        domain_modules: Mapping[str, DomainModule],
+215        workspace_dim: int,
+216        gw_encoders: Mapping[str, nn.Module],
+217        gw_decoders: Mapping[str, nn.Module],
+218    ) -> None:
+219        """
+220        Initializes the GWModule.
+221
+222        Args:
+223            domain_modules (`Mapping[str, DomainModule]`): the domain modules.
+224            workspace_dim (`int`): dimension of the GW.
+225            gw_encoders (`Mapping[str, torch.nn.Module]`): mapping for each domain
+226                name to a an torch.nn.Module class that encodes a
+227                unimodal latent representations into a GW representation (pre fusion).
+228            gw_decoders (`Mapping[str, torch.nn.Module]`): mapping for each domain
+229                name to a an torch.nn.Module class that decodes a
+230                 GW representation to a unimodal latent representation.
+231        """
+232        super().__init__(domain_modules, workspace_dim)
+233
+234        self.gw_encoders = nn.ModuleDict(gw_encoders)
+235        """The module's encoders"""
+236
+237        self.gw_decoders = nn.ModuleDict(gw_decoders)
+238        """The module's decoders"""
+239
+240    def fuse(
+241        self,
+242        x: LatentsDomainGroupT,
+243        selection_scores: Mapping[str, torch.Tensor],
+244    ) -> torch.Tensor:
+245        """
+246        Merge function used to combine domains.
+247
+248        Args:
+249            x (`LatentsDomainGroupT`): the group of latent representation.
+250            selection_score (`Mapping[str, torch.Tensor]`): attention scores to
+251                use to encode the reprensetation.
+252        Returns:
+253            `torch.Tensor`: The merged representation.
+254        """
+255        return torch.tanh(
+256            torch.sum(
+257                torch.stack(
+258                    [
+259                        selection_scores[domain].unsqueeze(1) * x[domain]
+260                        for domain in selection_scores
+261                    ]
+262                ),
+263                dim=0,
+264            )
+265        )
+266
+267    def encode(self, x: LatentsDomainGroupT) -> LatentsDomainGroupDT:
+268        """
+269        Encode the latent representation infos to the pre-fusion GW representation.
+270
+271        Args:
+272            x (`LatentsDomainGroupT`): the input domain representations.
+273
+274        Returns:
+275            `LatentsDomainGroupT`: pre-fusion representation
+276        """
+277        return {
+278            domain_name: self.gw_encoders[domain_name](domain)
+279            for domain_name, domain in x.items()
+280        }
+281
+282    def decode(
+283        self, z: torch.Tensor, domains: Iterable[str] | None = None
+284    ) -> LatentsDomainGroupDT:
+285        """
+286        Decodes a GW representation to multiple domains.
+287
+288        Args:
+289            z (`torch.Tensor`): the GW representation
+290            domains (`Iterable[str] | None`): the domains to decode to. Defaults to
+291                use keys in `gw_interfaces` (all domains).
+292        Returns:
+293            `LatentsDomainGroupDT`: decoded unimodal representation
+294        """
+295        return {
+296            domain: self.gw_decoders[domain](z)
+297            for domain in domains or self.gw_decoders.keys()
+298        }
+299
+300
+301def compute_fusion_scores(
+302    score_1: torch.Tensor,
+303    score_2: torch.Tensor,
+304    sensitivity_1: float = 1.0,
+305    sensitivity_2: float = 1.0,
+306    eps: float = 1e-6,
+307) -> torch.Tensor:
+308    """
+309    Combine precision scores using std summation in quadrature
+310
+311    The two scores should have the same dimension.
+312
+313    Args:
+314        score_1 (`torch.Tensor`): First scores.
+315        score_2 (`torch.Tensor`): Second scores.
+316        sensitivity_1 (`float`): sensitivity for the first score
+317        sensitivity_2 (`float`): sensitivity for the second score
+318        eps (`float`): a value added to avoid numerical unstability.
+319
+320    Returns:
+321        `torch.Tensor`: the combined scores
+322    """
+323    total_uncertainty = sensitivity_1 / (eps + score_1) + sensitivity_2 / (
+324        eps + score_2
+325    )
+326    final_scores = 1 / (eps + total_uncertainty)
+327    return final_scores / final_scores.sum(dim=0, keepdim=True)
+328
+329
+330class GWModuleBayesian(GWModule):
+331    """`GWModule` with a Bayesian based uncertainty prediction."""
+332
+333    def __init__(
+334        self,
+335        domain_modules: Mapping[str, DomainModule],
+336        workspace_dim: int,
+337        gw_encoders: Mapping[str, nn.Module],
+338        gw_decoders: Mapping[str, nn.Module],
+339        sensitivity_selection: float = 1,
+340        sensitivity_precision: float = 1,
+341        precision_softmax_temp: float = 0.01,
+342    ) -> None:
+343        """
+344        Initializes the GWModuleBayesian.
+345
+346        Args:
+347            domain_modules (`Mapping[str, DomainModule]`): the domain modules.
+348            workspace_dim (`int`): dimension of the GW.
+349            gw_encoders (`Mapping[str, torch.nn.Module]`): mapping for each domain
+350                name to a an torch.nn.Module class that encodes a
+351                unimodal latent representations into a GW representation (pre fusion).
+352            gw_decoders (`Mapping[str, torch.nn.Module]`): mapping for each domain
+353                name to a an torch.nn.Module class that decodes a
+354                 GW representation to a unimodal latent representation.
+355            sensitivity_selection (`float`): sensivity coef $c'_1$
+356            sensitivity_precision (`float`): sensitivity coef $c'_2$
+357            precision_softmax_temp (`float`): temperature to use in softmax of
+358                precision
+359        """
+360        super().__init__(domain_modules, workspace_dim, gw_encoders, gw_decoders)
+361
+362        self.precisions = cast(
+363            dict[str, torch.Tensor],
+364            nn.ParameterDict(
+365                {domain: torch.randn(workspace_dim) for domain in gw_encoders}
+366            ),
+367        )
+368        """Precision at the neuron level for every domain."""
+369
+370        self.sensitivity_selection = sensitivity_selection
+371        self.sensitivity_precision = sensitivity_precision
+372        self.precision_softmax_temp = precision_softmax_temp
+373
+374    def get_precision(self, domain: str, x: torch.Tensor) -> torch.Tensor:
+375        """
+376        Get the precision vector of given domain and batch
+377
+378        Args:
+379            domain (`str`):
+380            x (`torch.Tensor`): batch of inputs
+381
+382        Returns:
+383            `torch.Tensor`: batch of precision
+384        """
+385        return self.precisions[domain].unsqueeze(0).expand(x.size(0), -1)
+386
+387    def fuse(
+388        self,
+389        x: LatentsDomainGroupT,
+390        selection_scores: Mapping[str, torch.Tensor],
+391    ) -> torch.Tensor:
+392        """
+393        Merge function used to combine domains.
+394
+395        In the following, $D$ is the number of domains, $N$ the batch size, and $d$ the
+396        dimension of the Global Workspace.
+397
+398        This function needs to merge two kind of scores:
+399        * the selection scores $a\\in [0,1]^{D\\times N}$;
+400        * the precision scores $b \\in [0,1]^{D\\times N \\times d}$.
+401
+402        .. note::
+403            The precision score is obtained by predicting logits and using a softmax
+404
+405        We can obtain associated uncertainties to the scores by introducing a std
+406        variable and using bayesian integration:
+407
+408        $$a_k = \\frac{M_1}{\\sigma_k^2}$$
+409        where $M_1 = \\frac{1}{\\sum_{i=1}^D \\frac{1}{\\sigma_i^2}}$.
+410
+411        Similarly,
+412        $$b_k = \\frac{M_2}{\\mu_k^2}$$
+413        where $M_2 = \\frac{1}{\\sum_{i=1}^D \\frac{1}{\\mu_i^2}}$.
+414
+415        The we can sum the variances to obtain the final uncertainty (squared) $\\xi$:
+416        $$\\xi_k^2 = c_1 \\sigma_k^2 + c_2 \\mu_k^2$$
+417
+418        which, in terms of $a_k$ and $b_k$ yields:
+419        $$\\xi_k^2 = \\frac{c'_1}{a_k} + \\frac{c'_2}{b_k}$$
+420        where $c'_1 = c_1 \\cdot M_1$ and $c'_2 = c_2 \\cdot M_2$.
+421
+422        Finally, the finale combined coefficient is
+423        $$\\lambda_k = \\frac{M_3}{\\frac{c'_1}{a_k} + \\frac{c'_2}{b_k}}$$
+424        where
+425        $$M_3 = \\frac{1}{\\sum_{i=1}^D
+426            \\frac{1}{\\frac{c'_1}{a_i} + \\frac{c'_2}{b_i}}}$$
+427
+428        Args:
+429            x (`LatentsDomainGroupT`): the group of latent representation.
+430            selection_score (`Mapping[str, torch.Tensor]`): attention scores to
+431                use to encode the reprensetation.
+432        Returns:
+433            `torch.Tensor`: The merged representation.
+434        """
+435        scores: list[torch.Tensor] = []
+436        precisions: list[torch.Tensor] = []
+437        domains: list[torch.Tensor] = []
+438        for domain, score in selection_scores.items():
+439            scores.append(score)
+440            precisions.append(self.get_precision(domain, x[domain]))
+441            domains.append(x[domain])
+442        combined_scores = compute_fusion_scores(
+443            torch.stack(scores).unsqueeze(-1),
+444            torch.softmax(
+445                torch.tanh(torch.stack(precisions)) * self.precision_softmax_temp, dim=0
+446            ),
+447            self.sensitivity_selection,
+448            self.sensitivity_precision,
+449        )
+450        return torch.tanh(
+451            torch.sum(
+452                combined_scores * torch.stack(domains),
+453                dim=0,
+454            )
+455        )
+
+ + +
+
+ +
+ + def + get_n_layers(n_layers: int, hidden_dim: int) -> list[torch.nn.modules.module.Module]: + + + +
+ +
14def get_n_layers(n_layers: int, hidden_dim: int) -> list[nn.Module]:
+15    """
+16    Makes a list of `n_layers` `nn.Linear` layers with `nn.ReLU`.
+17
+18    Args:
+19        n_layers (`int`): number of layers
+20        hidden_dim (`int`): size of the hidden dimension
+21
+22    Returns:
+23        `list[nn.Module]`: list of linear and relu layers.
+24    """
+25    layers: list[nn.Module] = []
+26    for _ in range(n_layers):
+27        layers.extend([nn.Linear(hidden_dim, hidden_dim), nn.ReLU()])
+28    return layers
+
+ + +

Makes a list of n_layers nn.Linear layers with nn.ReLU.

+ +
Arguments:
+ +
    +
  • n_layers (int): number of layers
  • +
  • hidden_dim (int): size of the hidden dimension
  • +
+ +
Returns:
+ +
+

list[nn.Module]: list of linear and relu layers.

+
+
+ + +
+
+ +
+ + class + GWDecoder(torch.nn.modules.container.Sequential): + + + +
+ +
31class GWDecoder(nn.Sequential):
+32    """A Decoder network for GWModules."""
+33
+34    def __init__(
+35        self,
+36        in_dim: int,
+37        hidden_dim: int,
+38        out_dim: int,
+39        n_layers: int,
+40    ):
+41        """
+42        Initializes the decoder.
+43
+44        Args:
+45            in_dim (`int`): input dimension
+46            hidden_dim (`int`): hidden dimension
+47            out_dim (`int`): output dimension
+48            n_layers (`int`): number of hidden layers. The total number of layers
+49                will be `n_layers` + 2 (one before, one after).
+50        """
+51
+52        self.in_dim = in_dim
+53        """input dimension"""
+54
+55        self.hidden_dim = hidden_dim
+56        """hidden dimension"""
+57
+58        self.out_dim = out_dim
+59        """output dimension"""
+60
+61        self.n_layers = n_layers
+62        """
+63        number of hidden layers. The total number of layers
+64                will be `n_layers` + 2 (one before, one after)."""
+65
+66        super().__init__(
+67            nn.Linear(self.in_dim, self.hidden_dim),
+68            nn.ReLU(),
+69            *get_n_layers(n_layers, self.hidden_dim),
+70            nn.Linear(self.hidden_dim, self.out_dim),
+71        )
+
+ + +

A Decoder network for GWModules.

+
+ + +
+ +
+ + GWDecoder(in_dim: int, hidden_dim: int, out_dim: int, n_layers: int) + + + +
+ +
34    def __init__(
+35        self,
+36        in_dim: int,
+37        hidden_dim: int,
+38        out_dim: int,
+39        n_layers: int,
+40    ):
+41        """
+42        Initializes the decoder.
+43
+44        Args:
+45            in_dim (`int`): input dimension
+46            hidden_dim (`int`): hidden dimension
+47            out_dim (`int`): output dimension
+48            n_layers (`int`): number of hidden layers. The total number of layers
+49                will be `n_layers` + 2 (one before, one after).
+50        """
+51
+52        self.in_dim = in_dim
+53        """input dimension"""
+54
+55        self.hidden_dim = hidden_dim
+56        """hidden dimension"""
+57
+58        self.out_dim = out_dim
+59        """output dimension"""
+60
+61        self.n_layers = n_layers
+62        """
+63        number of hidden layers. The total number of layers
+64                will be `n_layers` + 2 (one before, one after)."""
+65
+66        super().__init__(
+67            nn.Linear(self.in_dim, self.hidden_dim),
+68            nn.ReLU(),
+69            *get_n_layers(n_layers, self.hidden_dim),
+70            nn.Linear(self.hidden_dim, self.out_dim),
+71        )
+
+ + +

Initializes the decoder.

+ +
Arguments:
+ +
    +
  • in_dim (int): input dimension
  • +
  • hidden_dim (int): hidden dimension
  • +
  • out_dim (int): output dimension
  • +
  • n_layers (int): number of hidden layers. The total number of layers +will be n_layers + 2 (one before, one after).
  • +
+
+ + +
+
+
+ in_dim + + +
+ + +

input dimension

+
+ + +
+
+
+ hidden_dim + + +
+ + +

hidden dimension

+
+ + +
+
+
+ out_dim + + +
+ + +

output dimension

+
+ + +
+
+
+ n_layers + + +
+ + +

number of hidden layers. The total number of layers + will be n_layers + 2 (one before, one after).

+
+ + +
+
+
Inherited Members
+
+
torch.nn.modules.container.Sequential
+
pop
+
forward
+
append
+
insert
+
extend
+ +
+
torch.nn.modules.module.Module
+
dump_patches
+
training
+
call_super_init
+
register_buffer
+
register_parameter
+
add_module
+
register_module
+
get_submodule
+
get_parameter
+
get_buffer
+
get_extra_state
+
set_extra_state
+
apply
+
cuda
+
ipu
+
xpu
+
cpu
+
type
+
float
+
double
+
half
+
bfloat16
+
to_empty
+
to
+
register_full_backward_pre_hook
+
register_backward_hook
+
register_full_backward_hook
+
register_forward_pre_hook
+
register_forward_hook
+
register_state_dict_pre_hook
+
state_dict
+
register_load_state_dict_post_hook
+
load_state_dict
+
parameters
+
named_parameters
+
buffers
+
named_buffers
+
children
+
named_children
+
modules
+
named_modules
+
train
+
eval
+
requires_grad_
+
zero_grad
+
share_memory
+
extra_repr
+
compile
+ +
+
+
+
+
+ +
+ + class + GWEncoder(GWDecoder): + + + +
+ +
 74class GWEncoder(GWDecoder):
+ 75    """
+ 76    An Encoder network used in GWModules.
+ 77
+ 78    This is similar to the decoder, but adds a tanh non-linearity at the end.
+ 79    """
+ 80
+ 81    def __init__(
+ 82        self,
+ 83        in_dim: int,
+ 84        hidden_dim: int,
+ 85        out_dim: int,
+ 86        n_layers: int,
+ 87    ):
+ 88        """
+ 89        Initializes the encoder.
+ 90
+ 91        Args:
+ 92            in_dim (`int`): input dimension
+ 93            hidden_dim (`int`): hidden dimension
+ 94            out_dim (`int`): output dimension
+ 95            n_layers (`int`): number of hidden layers. The total number of layers
+ 96                will be `n_layers` + 2 (one before, one after).
+ 97        """
+ 98        super().__init__(in_dim, hidden_dim, out_dim, n_layers)
+ 99
+100    def forward(self, input: torch.Tensor) -> torch.Tensor:
+101        return super().forward(input)
+
+ + +

An Encoder network used in GWModules.

+ +

This is similar to the decoder, but adds a tanh non-linearity at the end.

+
+ + +
+ +
+ + GWEncoder(in_dim: int, hidden_dim: int, out_dim: int, n_layers: int) + + + +
+ +
81    def __init__(
+82        self,
+83        in_dim: int,
+84        hidden_dim: int,
+85        out_dim: int,
+86        n_layers: int,
+87    ):
+88        """
+89        Initializes the encoder.
+90
+91        Args:
+92            in_dim (`int`): input dimension
+93            hidden_dim (`int`): hidden dimension
+94            out_dim (`int`): output dimension
+95            n_layers (`int`): number of hidden layers. The total number of layers
+96                will be `n_layers` + 2 (one before, one after).
+97        """
+98        super().__init__(in_dim, hidden_dim, out_dim, n_layers)
+
+ + +

Initializes the encoder.

+ +
Arguments:
+ +
    +
  • in_dim (int): input dimension
  • +
  • hidden_dim (int): hidden dimension
  • +
  • out_dim (int): output dimension
  • +
  • n_layers (int): number of hidden layers. The total number of layers +will be n_layers + 2 (one before, one after).
  • +
+
+ + +
+
+ +
+ + def + forward(self, input: torch.Tensor) -> torch.Tensor: + + + +
+ +
100    def forward(self, input: torch.Tensor) -> torch.Tensor:
+101        return super().forward(input)
+
+ + +

Define the computation performed at every call.

+ +

Should be overridden by all subclasses.

+ +
+ +

Although the recipe for forward pass needs to be defined within +this function, one should call the Module instance afterwards +instead of this since the former takes care of running the +registered hooks while the latter silently ignores them.

+ +
+
+ + +
+
+
Inherited Members
+
+ +
torch.nn.modules.container.Sequential
+
pop
+
append
+
insert
+
extend
+ +
+
torch.nn.modules.module.Module
+
dump_patches
+
training
+
call_super_init
+
register_buffer
+
register_parameter
+
add_module
+
register_module
+
get_submodule
+
get_parameter
+
get_buffer
+
get_extra_state
+
set_extra_state
+
apply
+
cuda
+
ipu
+
xpu
+
cpu
+
type
+
float
+
double
+
half
+
bfloat16
+
to_empty
+
to
+
register_full_backward_pre_hook
+
register_backward_hook
+
register_full_backward_hook
+
register_forward_pre_hook
+
register_forward_hook
+
register_state_dict_pre_hook
+
state_dict
+
register_load_state_dict_post_hook
+
load_state_dict
+
parameters
+
named_parameters
+
buffers
+
named_buffers
+
children
+
named_children
+
modules
+
named_modules
+
train
+
eval
+
requires_grad_
+
zero_grad
+
share_memory
+
extra_repr
+
compile
+ +
+
+
+
+
+ +
+ + class + GWEncoderLinear(torch.nn.modules.linear.Linear): + + + +
+ +
104class GWEncoderLinear(nn.Linear):
+105    """A linear Encoder network used in GWModules."""
+106
+107    def forward(self, input: torch.Tensor) -> torch.Tensor:
+108        return torch.tanh(super().forward(input))
+
+ + +

A linear Encoder network used in GWModules.

+
+ + +
+ +
+ + def + forward(self, input: torch.Tensor) -> torch.Tensor: + + + +
+ +
107    def forward(self, input: torch.Tensor) -> torch.Tensor:
+108        return torch.tanh(super().forward(input))
+
+ + +

Define the computation performed at every call.

+ +

Should be overridden by all subclasses.

+ +
+ +

Although the recipe for forward pass needs to be defined within +this function, one should call the Module instance afterwards +instead of this since the former takes care of running the +registered hooks while the latter silently ignores them.

+ +
+
+ + +
+
+
Inherited Members
+
+
torch.nn.modules.linear.Linear
+
Linear
+
in_features
+
out_features
+
weight
+
reset_parameters
+
extra_repr
+ +
+
torch.nn.modules.module.Module
+
dump_patches
+
training
+
call_super_init
+
register_buffer
+
register_parameter
+
add_module
+
register_module
+
get_submodule
+
get_parameter
+
get_buffer
+
get_extra_state
+
set_extra_state
+
apply
+
cuda
+
ipu
+
xpu
+
cpu
+
type
+
float
+
double
+
half
+
bfloat16
+
to_empty
+
to
+
register_full_backward_pre_hook
+
register_backward_hook
+
register_full_backward_hook
+
register_forward_pre_hook
+
register_forward_hook
+
register_state_dict_pre_hook
+
state_dict
+
register_load_state_dict_post_hook
+
load_state_dict
+
parameters
+
named_parameters
+
buffers
+
named_buffers
+
children
+
named_children
+
modules
+
named_modules
+
train
+
eval
+
requires_grad_
+
zero_grad
+
share_memory
+
compile
+ +
+
+
+
+
+ +
+ + class + GWModuleBase(torch.nn.modules.module.Module, abc.ABC): + + + +
+ +
111class GWModuleBase(nn.Module, ABC):
+112    """
+113    Base class for GWModule.
+114
+115    GWModule handles encoding, decoding the unimodal representations
+116    using the `gw_encoders` and`gw_decoders`, and define
+117    some common operations in GW like cycles and translations.
+118
+119    This is an abstract class and should be implemented.
+120    For an implemented interface, see `GWModule`.
+121    """
+122
+123    def __init__(
+124        self,
+125        domain_mods: Mapping[str, DomainModule],
+126        workspace_dim: int,
+127        *args,
+128        **kwargs,
+129    ) -> None:
+130        """
+131        Initializes the GWModule.
+132
+133        Args:
+134            domain_modules (`Mapping[str, DomainModule]`): the domain modules.
+135            workspace_dim (`int`): dimension of the GW.
+136        """
+137        super().__init__()
+138
+139        self.domain_mods = domain_mods
+140        """The unimodal domain modules."""
+141
+142        self.workspace_dim = workspace_dim
+143        """Dimension of the GW"""
+144
+145    @abstractmethod
+146    def fuse(
+147        self, x: LatentsDomainGroupT, selection_scores: Mapping[str, torch.Tensor]
+148    ) -> torch.Tensor:
+149        """
+150        Merge function used to combine domains.
+151
+152        Args:
+153            x (`LatentsDomainGroupT`): the group of latent representation.
+154            selection_score (`Mapping[str, torch.Tensor]`): attention scores to
+155                use to encode the reprensetation.
+156        Returns:
+157            `torch.Tensor`: The merged representation.
+158        """
+159        ...
+160
+161    @abstractmethod
+162    def encode(self, x: LatentsDomainGroupT) -> LatentsDomainGroupDT:
+163        """
+164        Encode the latent representation infos to the pre-fusion GW representation.
+165
+166        Args:
+167            x (`LatentsDomainGroupT`): the input domain representations
+168
+169        Returns:
+170            `LatentsDomainGroupT`: pre-fusion GW representations
+171        """
+172        ...
+173
+174    def encode_and_fuse(
+175        self, x: LatentsDomainGroupT, selection_module: SelectionBase
+176    ) -> torch.Tensor:
+177        """
+178        Encode the latent representation infos to the final GW representation.
+179        It combines the encode and fuse methods.
+180
+181        Args:
+182            x (`LatentsDomainGroupT`): the input domain representations
+183            selection_score (`Mapping[str, torch.Tensor]`): attention scores to
+184                use to encode the reprensetation.
+185
+186        Returns:
+187            `torch.Tensor`: The merged representation.
+188        """
+189        encodings = self.encode(x)
+190        selection_scores = selection_module(x, encodings)
+191        return self.fuse(encodings, selection_scores)
+192
+193    @abstractmethod
+194    def decode(
+195        self, z: torch.Tensor, domains: Iterable[str] | None = None
+196    ) -> LatentsDomainGroupDT:
+197        """
+198        Decode the GW representation into given `domains`.
+199
+200        Args:
+201            z (`torch.Tensor`): the GW representation.
+202            domains (`Iterable[str]`): iterable of domains to decode.
+203
+204        Returns:
+205            `LatentsDomainGroupDT`: the decoded unimodal representations.
+206        """
+207        ...
+
+ + +

Base class for GWModule.

+ +

GWModule handles encoding, decoding the unimodal representations +using the gw_encoders andgw_decoders, and define +some common operations in GW like cycles and translations.

+ +

This is an abstract class and should be implemented. +For an implemented interface, see GWModule.

+
+ + +
+ +
+ + GWModuleBase( domain_mods: collections.abc.Mapping[str, shimmer.modules.domain.DomainModule], workspace_dim: int, *args, **kwargs) + + + +
+ +
123    def __init__(
+124        self,
+125        domain_mods: Mapping[str, DomainModule],
+126        workspace_dim: int,
+127        *args,
+128        **kwargs,
+129    ) -> None:
+130        """
+131        Initializes the GWModule.
+132
+133        Args:
+134            domain_modules (`Mapping[str, DomainModule]`): the domain modules.
+135            workspace_dim (`int`): dimension of the GW.
+136        """
+137        super().__init__()
+138
+139        self.domain_mods = domain_mods
+140        """The unimodal domain modules."""
+141
+142        self.workspace_dim = workspace_dim
+143        """Dimension of the GW"""
+
+ + +

Initializes the GWModule.

+ +
Arguments:
+ +
    +
  • domain_modules (Mapping[str, DomainModule]): the domain modules.
  • +
  • workspace_dim (int): dimension of the GW.
  • +
+
+ + +
+
+
+ domain_mods + + +
+ + +

The unimodal domain modules.

+
+ + +
+
+
+ workspace_dim + + +
+ + +

Dimension of the GW

+
+ + +
+
+ +
+
@abstractmethod
+ + def + fuse( self, x: collections.abc.Mapping[str, torch.Tensor], selection_scores: collections.abc.Mapping[str, torch.Tensor]) -> torch.Tensor: + + + +
+ +
145    @abstractmethod
+146    def fuse(
+147        self, x: LatentsDomainGroupT, selection_scores: Mapping[str, torch.Tensor]
+148    ) -> torch.Tensor:
+149        """
+150        Merge function used to combine domains.
+151
+152        Args:
+153            x (`LatentsDomainGroupT`): the group of latent representation.
+154            selection_score (`Mapping[str, torch.Tensor]`): attention scores to
+155                use to encode the reprensetation.
+156        Returns:
+157            `torch.Tensor`: The merged representation.
+158        """
+159        ...
+
+ + +

Merge function used to combine domains.

+ +
Arguments:
+ +
    +
  • x (LatentsDomainGroupT): the group of latent representation.
  • +
  • selection_score (Mapping[str, torch.Tensor]): attention scores to +use to encode the reprensetation.
  • +
+ +
Returns:
+ +
+

torch.Tensor: The merged representation.

+
+
+ + +
+
+ +
+
@abstractmethod
+ + def + encode( self, x: collections.abc.Mapping[str, torch.Tensor]) -> dict[str, torch.Tensor]: + + + +
+ +
161    @abstractmethod
+162    def encode(self, x: LatentsDomainGroupT) -> LatentsDomainGroupDT:
+163        """
+164        Encode the latent representation infos to the pre-fusion GW representation.
+165
+166        Args:
+167            x (`LatentsDomainGroupT`): the input domain representations
+168
+169        Returns:
+170            `LatentsDomainGroupT`: pre-fusion GW representations
+171        """
+172        ...
+
+ + +

Encode the latent representation infos to the pre-fusion GW representation.

+ +
Arguments:
+ +
    +
  • x (LatentsDomainGroupT): the input domain representations
  • +
+ +
Returns:
+ +
+

LatentsDomainGroupT: pre-fusion GW representations

+
+
+ + +
+
+ +
+ + def + encode_and_fuse( self, x: collections.abc.Mapping[str, torch.Tensor], selection_module: shimmer.modules.selection.SelectionBase) -> torch.Tensor: + + + +
+ +
174    def encode_and_fuse(
+175        self, x: LatentsDomainGroupT, selection_module: SelectionBase
+176    ) -> torch.Tensor:
+177        """
+178        Encode the latent representation infos to the final GW representation.
+179        It combines the encode and fuse methods.
+180
+181        Args:
+182            x (`LatentsDomainGroupT`): the input domain representations
+183            selection_score (`Mapping[str, torch.Tensor]`): attention scores to
+184                use to encode the reprensetation.
+185
+186        Returns:
+187            `torch.Tensor`: The merged representation.
+188        """
+189        encodings = self.encode(x)
+190        selection_scores = selection_module(x, encodings)
+191        return self.fuse(encodings, selection_scores)
+
+ + +

Encode the latent representation infos to the final GW representation. +It combines the encode and fuse methods.

+ +
Arguments:
+ +
    +
  • x (LatentsDomainGroupT): the input domain representations
  • +
  • selection_score (Mapping[str, torch.Tensor]): attention scores to +use to encode the reprensetation.
  • +
+ +
Returns:
+ +
+

torch.Tensor: The merged representation.

+
+
+ + +
+
+ +
+
@abstractmethod
+ + def + decode( self, z: torch.Tensor, domains: collections.abc.Iterable[str] | None = None) -> dict[str, torch.Tensor]: + + + +
+ +
193    @abstractmethod
+194    def decode(
+195        self, z: torch.Tensor, domains: Iterable[str] | None = None
+196    ) -> LatentsDomainGroupDT:
+197        """
+198        Decode the GW representation into given `domains`.
+199
+200        Args:
+201            z (`torch.Tensor`): the GW representation.
+202            domains (`Iterable[str]`): iterable of domains to decode.
+203
+204        Returns:
+205            `LatentsDomainGroupDT`: the decoded unimodal representations.
+206        """
+207        ...
+
+ + +

Decode the GW representation into given domains.

+ +
Arguments:
+ +
    +
  • z (torch.Tensor): the GW representation.
  • +
  • domains (Iterable[str]): iterable of domains to decode.
  • +
+ +
Returns:
+ +
+

LatentsDomainGroupDT: the decoded unimodal representations.

+
+
+ + +
+
+
Inherited Members
+
+
torch.nn.modules.module.Module
+
dump_patches
+
training
+
call_super_init
+
forward
+
register_buffer
+
register_parameter
+
add_module
+
register_module
+
get_submodule
+
get_parameter
+
get_buffer
+
get_extra_state
+
set_extra_state
+
apply
+
cuda
+
ipu
+
xpu
+
cpu
+
type
+
float
+
double
+
half
+
bfloat16
+
to_empty
+
to
+
register_full_backward_pre_hook
+
register_backward_hook
+
register_full_backward_hook
+
register_forward_pre_hook
+
register_forward_hook
+
register_state_dict_pre_hook
+
state_dict
+
register_load_state_dict_post_hook
+
load_state_dict
+
parameters
+
named_parameters
+
buffers
+
named_buffers
+
children
+
named_children
+
modules
+
named_modules
+
train
+
eval
+
requires_grad_
+
zero_grad
+
share_memory
+
extra_repr
+
compile
+ +
+
+
+
+
+ +
+ + class + GWModule(GWModuleBase): + + + +
+ +
210class GWModule(GWModuleBase):
+211    """GW nn.Module. Implements `GWModuleBase`."""
+212
+213    def __init__(
+214        self,
+215        domain_modules: Mapping[str, DomainModule],
+216        workspace_dim: int,
+217        gw_encoders: Mapping[str, nn.Module],
+218        gw_decoders: Mapping[str, nn.Module],
+219    ) -> None:
+220        """
+221        Initializes the GWModule.
+222
+223        Args:
+224            domain_modules (`Mapping[str, DomainModule]`): the domain modules.
+225            workspace_dim (`int`): dimension of the GW.
+226            gw_encoders (`Mapping[str, torch.nn.Module]`): mapping for each domain
+227                name to a an torch.nn.Module class that encodes a
+228                unimodal latent representations into a GW representation (pre fusion).
+229            gw_decoders (`Mapping[str, torch.nn.Module]`): mapping for each domain
+230                name to a an torch.nn.Module class that decodes a
+231                 GW representation to a unimodal latent representation.
+232        """
+233        super().__init__(domain_modules, workspace_dim)
+234
+235        self.gw_encoders = nn.ModuleDict(gw_encoders)
+236        """The module's encoders"""
+237
+238        self.gw_decoders = nn.ModuleDict(gw_decoders)
+239        """The module's decoders"""
+240
+241    def fuse(
+242        self,
+243        x: LatentsDomainGroupT,
+244        selection_scores: Mapping[str, torch.Tensor],
+245    ) -> torch.Tensor:
+246        """
+247        Merge function used to combine domains.
+248
+249        Args:
+250            x (`LatentsDomainGroupT`): the group of latent representation.
+251            selection_score (`Mapping[str, torch.Tensor]`): attention scores to
+252                use to encode the reprensetation.
+253        Returns:
+254            `torch.Tensor`: The merged representation.
+255        """
+256        return torch.tanh(
+257            torch.sum(
+258                torch.stack(
+259                    [
+260                        selection_scores[domain].unsqueeze(1) * x[domain]
+261                        for domain in selection_scores
+262                    ]
+263                ),
+264                dim=0,
+265            )
+266        )
+267
+268    def encode(self, x: LatentsDomainGroupT) -> LatentsDomainGroupDT:
+269        """
+270        Encode the latent representation infos to the pre-fusion GW representation.
+271
+272        Args:
+273            x (`LatentsDomainGroupT`): the input domain representations.
+274
+275        Returns:
+276            `LatentsDomainGroupT`: pre-fusion representation
+277        """
+278        return {
+279            domain_name: self.gw_encoders[domain_name](domain)
+280            for domain_name, domain in x.items()
+281        }
+282
+283    def decode(
+284        self, z: torch.Tensor, domains: Iterable[str] | None = None
+285    ) -> LatentsDomainGroupDT:
+286        """
+287        Decodes a GW representation to multiple domains.
+288
+289        Args:
+290            z (`torch.Tensor`): the GW representation
+291            domains (`Iterable[str] | None`): the domains to decode to. Defaults to
+292                use keys in `gw_interfaces` (all domains).
+293        Returns:
+294            `LatentsDomainGroupDT`: decoded unimodal representation
+295        """
+296        return {
+297            domain: self.gw_decoders[domain](z)
+298            for domain in domains or self.gw_decoders.keys()
+299        }
+
+ + +

GW nn.Module. Implements GWModuleBase.

+
+ + +
+ +
+ + GWModule( domain_modules: collections.abc.Mapping[str, shimmer.modules.domain.DomainModule], workspace_dim: int, gw_encoders: collections.abc.Mapping[str, torch.nn.modules.module.Module], gw_decoders: collections.abc.Mapping[str, torch.nn.modules.module.Module]) + + + +
+ +
213    def __init__(
+214        self,
+215        domain_modules: Mapping[str, DomainModule],
+216        workspace_dim: int,
+217        gw_encoders: Mapping[str, nn.Module],
+218        gw_decoders: Mapping[str, nn.Module],
+219    ) -> None:
+220        """
+221        Initializes the GWModule.
+222
+223        Args:
+224            domain_modules (`Mapping[str, DomainModule]`): the domain modules.
+225            workspace_dim (`int`): dimension of the GW.
+226            gw_encoders (`Mapping[str, torch.nn.Module]`): mapping for each domain
+227                name to a an torch.nn.Module class that encodes a
+228                unimodal latent representations into a GW representation (pre fusion).
+229            gw_decoders (`Mapping[str, torch.nn.Module]`): mapping for each domain
+230                name to a an torch.nn.Module class that decodes a
+231                 GW representation to a unimodal latent representation.
+232        """
+233        super().__init__(domain_modules, workspace_dim)
+234
+235        self.gw_encoders = nn.ModuleDict(gw_encoders)
+236        """The module's encoders"""
+237
+238        self.gw_decoders = nn.ModuleDict(gw_decoders)
+239        """The module's decoders"""
+
+ + +

Initializes the GWModule.

+ +
Arguments:
+ +
    +
  • domain_modules (Mapping[str, DomainModule]): the domain modules.
  • +
  • workspace_dim (int): dimension of the GW.
  • +
  • gw_encoders (Mapping[str, torch.nn.Module]): mapping for each domain +name to a an torch.nn.Module class that encodes a +unimodal latent representations into a GW representation (pre fusion).
  • +
  • gw_decoders (Mapping[str, torch.nn.Module]): mapping for each domain +name to a an torch.nn.Module class that decodes a + GW representation to a unimodal latent representation.
  • +
+
+ + +
+
+
+ gw_encoders + + +
+ + +

The module's encoders

+
+ + +
+
+
+ gw_decoders + + +
+ + +

The module's decoders

+
+ + +
+
+ +
+ + def + fuse( self, x: collections.abc.Mapping[str, torch.Tensor], selection_scores: collections.abc.Mapping[str, torch.Tensor]) -> torch.Tensor: + + + +
+ +
241    def fuse(
+242        self,
+243        x: LatentsDomainGroupT,
+244        selection_scores: Mapping[str, torch.Tensor],
+245    ) -> torch.Tensor:
+246        """
+247        Merge function used to combine domains.
+248
+249        Args:
+250            x (`LatentsDomainGroupT`): the group of latent representation.
+251            selection_score (`Mapping[str, torch.Tensor]`): attention scores to
+252                use to encode the reprensetation.
+253        Returns:
+254            `torch.Tensor`: The merged representation.
+255        """
+256        return torch.tanh(
+257            torch.sum(
+258                torch.stack(
+259                    [
+260                        selection_scores[domain].unsqueeze(1) * x[domain]
+261                        for domain in selection_scores
+262                    ]
+263                ),
+264                dim=0,
+265            )
+266        )
+
+ + +

Merge function used to combine domains.

+ +
Arguments:
+ +
    +
  • x (LatentsDomainGroupT): the group of latent representation.
  • +
  • selection_score (Mapping[str, torch.Tensor]): attention scores to +use to encode the reprensetation.
  • +
+ +
Returns:
+ +
+

torch.Tensor: The merged representation.

+
+
+ + +
+
+ +
+ + def + encode( self, x: collections.abc.Mapping[str, torch.Tensor]) -> dict[str, torch.Tensor]: + + + +
+ +
268    def encode(self, x: LatentsDomainGroupT) -> LatentsDomainGroupDT:
+269        """
+270        Encode the latent representation infos to the pre-fusion GW representation.
+271
+272        Args:
+273            x (`LatentsDomainGroupT`): the input domain representations.
+274
+275        Returns:
+276            `LatentsDomainGroupT`: pre-fusion representation
+277        """
+278        return {
+279            domain_name: self.gw_encoders[domain_name](domain)
+280            for domain_name, domain in x.items()
+281        }
+
+ + +

Encode the latent representation infos to the pre-fusion GW representation.

+ +
Arguments:
+ +
    +
  • x (LatentsDomainGroupT): the input domain representations.
  • +
+ +
Returns:
+ +
+

LatentsDomainGroupT: pre-fusion representation

+
+
+ + +
+
+ +
+ + def + decode( self, z: torch.Tensor, domains: collections.abc.Iterable[str] | None = None) -> dict[str, torch.Tensor]: + + + +
+ +
283    def decode(
+284        self, z: torch.Tensor, domains: Iterable[str] | None = None
+285    ) -> LatentsDomainGroupDT:
+286        """
+287        Decodes a GW representation to multiple domains.
+288
+289        Args:
+290            z (`torch.Tensor`): the GW representation
+291            domains (`Iterable[str] | None`): the domains to decode to. Defaults to
+292                use keys in `gw_interfaces` (all domains).
+293        Returns:
+294            `LatentsDomainGroupDT`: decoded unimodal representation
+295        """
+296        return {
+297            domain: self.gw_decoders[domain](z)
+298            for domain in domains or self.gw_decoders.keys()
+299        }
+
+ + +

Decodes a GW representation to multiple domains.

+ +
Arguments:
+ +
    +
  • z (torch.Tensor): the GW representation
  • +
  • domains (Iterable[str] | None): the domains to decode to. Defaults to +use keys in gw_interfaces (all domains).
  • +
+ +
Returns:
+ +
+

LatentsDomainGroupDT: decoded unimodal representation

+
+
+ + +
+
+
Inherited Members
+
+ +
torch.nn.modules.module.Module
+
dump_patches
+
training
+
call_super_init
+
forward
+
register_buffer
+
register_parameter
+
add_module
+
register_module
+
get_submodule
+
get_parameter
+
get_buffer
+
get_extra_state
+
set_extra_state
+
apply
+
cuda
+
ipu
+
xpu
+
cpu
+
type
+
float
+
double
+
half
+
bfloat16
+
to_empty
+
to
+
register_full_backward_pre_hook
+
register_backward_hook
+
register_full_backward_hook
+
register_forward_pre_hook
+
register_forward_hook
+
register_state_dict_pre_hook
+
state_dict
+
register_load_state_dict_post_hook
+
load_state_dict
+
parameters
+
named_parameters
+
buffers
+
named_buffers
+
children
+
named_children
+
modules
+
named_modules
+
train
+
eval
+
requires_grad_
+
zero_grad
+
share_memory
+
extra_repr
+
compile
+ +
+
+
+
+
+ +
+ + def + compute_fusion_scores( score_1: torch.Tensor, score_2: torch.Tensor, sensitivity_1: float = 1.0, sensitivity_2: float = 1.0, eps: float = 1e-06) -> torch.Tensor: + + + +
+ +
302def compute_fusion_scores(
+303    score_1: torch.Tensor,
+304    score_2: torch.Tensor,
+305    sensitivity_1: float = 1.0,
+306    sensitivity_2: float = 1.0,
+307    eps: float = 1e-6,
+308) -> torch.Tensor:
+309    """
+310    Combine precision scores using std summation in quadrature
+311
+312    The two scores should have the same dimension.
+313
+314    Args:
+315        score_1 (`torch.Tensor`): First scores.
+316        score_2 (`torch.Tensor`): Second scores.
+317        sensitivity_1 (`float`): sensitivity for the first score
+318        sensitivity_2 (`float`): sensitivity for the second score
+319        eps (`float`): a value added to avoid numerical unstability.
+320
+321    Returns:
+322        `torch.Tensor`: the combined scores
+323    """
+324    total_uncertainty = sensitivity_1 / (eps + score_1) + sensitivity_2 / (
+325        eps + score_2
+326    )
+327    final_scores = 1 / (eps + total_uncertainty)
+328    return final_scores / final_scores.sum(dim=0, keepdim=True)
+
+ + +

Combine precision scores using std summation in quadrature

+ +

The two scores should have the same dimension.

+ +
Arguments:
+ +
    +
  • score_1 (torch.Tensor): First scores.
  • +
  • score_2 (torch.Tensor): Second scores.
  • +
  • sensitivity_1 (float): sensitivity for the first score
  • +
  • sensitivity_2 (float): sensitivity for the second score
  • +
  • eps (float): a value added to avoid numerical unstability.
  • +
+ +
Returns:
+ +
+

torch.Tensor: the combined scores

+
+
+ + +
+
+ +
+ + class + GWModuleBayesian(GWModule): + + + +
+ +
331class GWModuleBayesian(GWModule):
+332    """`GWModule` with a Bayesian based uncertainty prediction."""
+333
+334    def __init__(
+335        self,
+336        domain_modules: Mapping[str, DomainModule],
+337        workspace_dim: int,
+338        gw_encoders: Mapping[str, nn.Module],
+339        gw_decoders: Mapping[str, nn.Module],
+340        sensitivity_selection: float = 1,
+341        sensitivity_precision: float = 1,
+342        precision_softmax_temp: float = 0.01,
+343    ) -> None:
+344        """
+345        Initializes the GWModuleBayesian.
+346
+347        Args:
+348            domain_modules (`Mapping[str, DomainModule]`): the domain modules.
+349            workspace_dim (`int`): dimension of the GW.
+350            gw_encoders (`Mapping[str, torch.nn.Module]`): mapping for each domain
+351                name to a an torch.nn.Module class that encodes a
+352                unimodal latent representations into a GW representation (pre fusion).
+353            gw_decoders (`Mapping[str, torch.nn.Module]`): mapping for each domain
+354                name to a an torch.nn.Module class that decodes a
+355                 GW representation to a unimodal latent representation.
+356            sensitivity_selection (`float`): sensivity coef $c'_1$
+357            sensitivity_precision (`float`): sensitivity coef $c'_2$
+358            precision_softmax_temp (`float`): temperature to use in softmax of
+359                precision
+360        """
+361        super().__init__(domain_modules, workspace_dim, gw_encoders, gw_decoders)
+362
+363        self.precisions = cast(
+364            dict[str, torch.Tensor],
+365            nn.ParameterDict(
+366                {domain: torch.randn(workspace_dim) for domain in gw_encoders}
+367            ),
+368        )
+369        """Precision at the neuron level for every domain."""
+370
+371        self.sensitivity_selection = sensitivity_selection
+372        self.sensitivity_precision = sensitivity_precision
+373        self.precision_softmax_temp = precision_softmax_temp
+374
+375    def get_precision(self, domain: str, x: torch.Tensor) -> torch.Tensor:
+376        """
+377        Get the precision vector of given domain and batch
+378
+379        Args:
+380            domain (`str`):
+381            x (`torch.Tensor`): batch of inputs
+382
+383        Returns:
+384            `torch.Tensor`: batch of precision
+385        """
+386        return self.precisions[domain].unsqueeze(0).expand(x.size(0), -1)
+387
+388    def fuse(
+389        self,
+390        x: LatentsDomainGroupT,
+391        selection_scores: Mapping[str, torch.Tensor],
+392    ) -> torch.Tensor:
+393        """
+394        Merge function used to combine domains.
+395
+396        In the following, $D$ is the number of domains, $N$ the batch size, and $d$ the
+397        dimension of the Global Workspace.
+398
+399        This function needs to merge two kind of scores:
+400        * the selection scores $a\\in [0,1]^{D\\times N}$;
+401        * the precision scores $b \\in [0,1]^{D\\times N \\times d}$.
+402
+403        .. note::
+404            The precision score is obtained by predicting logits and using a softmax
+405
+406        We can obtain associated uncertainties to the scores by introducing a std
+407        variable and using bayesian integration:
+408
+409        $$a_k = \\frac{M_1}{\\sigma_k^2}$$
+410        where $M_1 = \\frac{1}{\\sum_{i=1}^D \\frac{1}{\\sigma_i^2}}$.
+411
+412        Similarly,
+413        $$b_k = \\frac{M_2}{\\mu_k^2}$$
+414        where $M_2 = \\frac{1}{\\sum_{i=1}^D \\frac{1}{\\mu_i^2}}$.
+415
+416        The we can sum the variances to obtain the final uncertainty (squared) $\\xi$:
+417        $$\\xi_k^2 = c_1 \\sigma_k^2 + c_2 \\mu_k^2$$
+418
+419        which, in terms of $a_k$ and $b_k$ yields:
+420        $$\\xi_k^2 = \\frac{c'_1}{a_k} + \\frac{c'_2}{b_k}$$
+421        where $c'_1 = c_1 \\cdot M_1$ and $c'_2 = c_2 \\cdot M_2$.
+422
+423        Finally, the finale combined coefficient is
+424        $$\\lambda_k = \\frac{M_3}{\\frac{c'_1}{a_k} + \\frac{c'_2}{b_k}}$$
+425        where
+426        $$M_3 = \\frac{1}{\\sum_{i=1}^D
+427            \\frac{1}{\\frac{c'_1}{a_i} + \\frac{c'_2}{b_i}}}$$
+428
+429        Args:
+430            x (`LatentsDomainGroupT`): the group of latent representation.
+431            selection_score (`Mapping[str, torch.Tensor]`): attention scores to
+432                use to encode the reprensetation.
+433        Returns:
+434            `torch.Tensor`: The merged representation.
+435        """
+436        scores: list[torch.Tensor] = []
+437        precisions: list[torch.Tensor] = []
+438        domains: list[torch.Tensor] = []
+439        for domain, score in selection_scores.items():
+440            scores.append(score)
+441            precisions.append(self.get_precision(domain, x[domain]))
+442            domains.append(x[domain])
+443        combined_scores = compute_fusion_scores(
+444            torch.stack(scores).unsqueeze(-1),
+445            torch.softmax(
+446                torch.tanh(torch.stack(precisions)) * self.precision_softmax_temp, dim=0
+447            ),
+448            self.sensitivity_selection,
+449            self.sensitivity_precision,
+450        )
+451        return torch.tanh(
+452            torch.sum(
+453                combined_scores * torch.stack(domains),
+454                dim=0,
+455            )
+456        )
+
+ + +

GWModule with a Bayesian based uncertainty prediction.

+
+ + +
+ +
+ + GWModuleBayesian( domain_modules: collections.abc.Mapping[str, shimmer.modules.domain.DomainModule], workspace_dim: int, gw_encoders: collections.abc.Mapping[str, torch.nn.modules.module.Module], gw_decoders: collections.abc.Mapping[str, torch.nn.modules.module.Module], sensitivity_selection: float = 1, sensitivity_precision: float = 1, precision_softmax_temp: float = 0.01) + + + +
+ +
334    def __init__(
+335        self,
+336        domain_modules: Mapping[str, DomainModule],
+337        workspace_dim: int,
+338        gw_encoders: Mapping[str, nn.Module],
+339        gw_decoders: Mapping[str, nn.Module],
+340        sensitivity_selection: float = 1,
+341        sensitivity_precision: float = 1,
+342        precision_softmax_temp: float = 0.01,
+343    ) -> None:
+344        """
+345        Initializes the GWModuleBayesian.
+346
+347        Args:
+348            domain_modules (`Mapping[str, DomainModule]`): the domain modules.
+349            workspace_dim (`int`): dimension of the GW.
+350            gw_encoders (`Mapping[str, torch.nn.Module]`): mapping for each domain
+351                name to a an torch.nn.Module class that encodes a
+352                unimodal latent representations into a GW representation (pre fusion).
+353            gw_decoders (`Mapping[str, torch.nn.Module]`): mapping for each domain
+354                name to a an torch.nn.Module class that decodes a
+355                 GW representation to a unimodal latent representation.
+356            sensitivity_selection (`float`): sensivity coef $c'_1$
+357            sensitivity_precision (`float`): sensitivity coef $c'_2$
+358            precision_softmax_temp (`float`): temperature to use in softmax of
+359                precision
+360        """
+361        super().__init__(domain_modules, workspace_dim, gw_encoders, gw_decoders)
+362
+363        self.precisions = cast(
+364            dict[str, torch.Tensor],
+365            nn.ParameterDict(
+366                {domain: torch.randn(workspace_dim) for domain in gw_encoders}
+367            ),
+368        )
+369        """Precision at the neuron level for every domain."""
+370
+371        self.sensitivity_selection = sensitivity_selection
+372        self.sensitivity_precision = sensitivity_precision
+373        self.precision_softmax_temp = precision_softmax_temp
+
+ + +

Initializes the GWModuleBayesian.

+ +
Arguments:
+ +
    +
  • domain_modules (Mapping[str, DomainModule]): the domain modules.
  • +
  • workspace_dim (int): dimension of the GW.
  • +
  • gw_encoders (Mapping[str, torch.nn.Module]): mapping for each domain +name to a an torch.nn.Module class that encodes a +unimodal latent representations into a GW representation (pre fusion).
  • +
  • gw_decoders (Mapping[str, torch.nn.Module]): mapping for each domain +name to a an torch.nn.Module class that decodes a + GW representation to a unimodal latent representation.
  • +
  • sensitivity_selection (float): sensivity coef $c'_1$
  • +
  • sensitivity_precision (float): sensitivity coef $c'_2$
  • +
  • precision_softmax_temp (float): temperature to use in softmax of +precision
  • +
+
+ + +
+
+
+ precisions + + +
+ + +

Precision at the neuron level for every domain.

+
+ + +
+
+
+ sensitivity_selection + + +
+ + + + +
+
+
+ sensitivity_precision + + +
+ + + + +
+
+
+ precision_softmax_temp + + +
+ + + + +
+
+ +
+ + def + get_precision(self, domain: str, x: torch.Tensor) -> torch.Tensor: + + + +
+ +
375    def get_precision(self, domain: str, x: torch.Tensor) -> torch.Tensor:
+376        """
+377        Get the precision vector of given domain and batch
+378
+379        Args:
+380            domain (`str`):
+381            x (`torch.Tensor`): batch of inputs
+382
+383        Returns:
+384            `torch.Tensor`: batch of precision
+385        """
+386        return self.precisions[domain].unsqueeze(0).expand(x.size(0), -1)
+
+ + +

Get the precision vector of given domain and batch

+ +
Arguments:
+ +
    +
  • domain (str):
  • +
  • x (torch.Tensor): batch of inputs
  • +
+ +
Returns:
+ +
+

torch.Tensor: batch of precision

+
+
+ + +
+
+ +
+ + def + fuse( self, x: collections.abc.Mapping[str, torch.Tensor], selection_scores: collections.abc.Mapping[str, torch.Tensor]) -> torch.Tensor: + + + +
+ +
388    def fuse(
+389        self,
+390        x: LatentsDomainGroupT,
+391        selection_scores: Mapping[str, torch.Tensor],
+392    ) -> torch.Tensor:
+393        """
+394        Merge function used to combine domains.
+395
+396        In the following, $D$ is the number of domains, $N$ the batch size, and $d$ the
+397        dimension of the Global Workspace.
+398
+399        This function needs to merge two kind of scores:
+400        * the selection scores $a\\in [0,1]^{D\\times N}$;
+401        * the precision scores $b \\in [0,1]^{D\\times N \\times d}$.
+402
+403        .. note::
+404            The precision score is obtained by predicting logits and using a softmax
+405
+406        We can obtain associated uncertainties to the scores by introducing a std
+407        variable and using bayesian integration:
+408
+409        $$a_k = \\frac{M_1}{\\sigma_k^2}$$
+410        where $M_1 = \\frac{1}{\\sum_{i=1}^D \\frac{1}{\\sigma_i^2}}$.
+411
+412        Similarly,
+413        $$b_k = \\frac{M_2}{\\mu_k^2}$$
+414        where $M_2 = \\frac{1}{\\sum_{i=1}^D \\frac{1}{\\mu_i^2}}$.
+415
+416        The we can sum the variances to obtain the final uncertainty (squared) $\\xi$:
+417        $$\\xi_k^2 = c_1 \\sigma_k^2 + c_2 \\mu_k^2$$
+418
+419        which, in terms of $a_k$ and $b_k$ yields:
+420        $$\\xi_k^2 = \\frac{c'_1}{a_k} + \\frac{c'_2}{b_k}$$
+421        where $c'_1 = c_1 \\cdot M_1$ and $c'_2 = c_2 \\cdot M_2$.
+422
+423        Finally, the finale combined coefficient is
+424        $$\\lambda_k = \\frac{M_3}{\\frac{c'_1}{a_k} + \\frac{c'_2}{b_k}}$$
+425        where
+426        $$M_3 = \\frac{1}{\\sum_{i=1}^D
+427            \\frac{1}{\\frac{c'_1}{a_i} + \\frac{c'_2}{b_i}}}$$
+428
+429        Args:
+430            x (`LatentsDomainGroupT`): the group of latent representation.
+431            selection_score (`Mapping[str, torch.Tensor]`): attention scores to
+432                use to encode the reprensetation.
+433        Returns:
+434            `torch.Tensor`: The merged representation.
+435        """
+436        scores: list[torch.Tensor] = []
+437        precisions: list[torch.Tensor] = []
+438        domains: list[torch.Tensor] = []
+439        for domain, score in selection_scores.items():
+440            scores.append(score)
+441            precisions.append(self.get_precision(domain, x[domain]))
+442            domains.append(x[domain])
+443        combined_scores = compute_fusion_scores(
+444            torch.stack(scores).unsqueeze(-1),
+445            torch.softmax(
+446                torch.tanh(torch.stack(precisions)) * self.precision_softmax_temp, dim=0
+447            ),
+448            self.sensitivity_selection,
+449            self.sensitivity_precision,
+450        )
+451        return torch.tanh(
+452            torch.sum(
+453                combined_scores * torch.stack(domains),
+454                dim=0,
+455            )
+456        )
+
+ + +

Merge function used to combine domains.

+ +

In the following, $D$ is the number of domains, $N$ the batch size, and $d$ the +dimension of the Global Workspace.

+ +

This function needs to merge two kind of scores:

+ +
    +
  • the selection scores $a\in [0,1]^{D\times N}$;
  • +
  • the precision scores $b \in [0,1]^{D\times N \times d}$.
  • +
+ +
+ +

The precision score is obtained by predicting logits and using a softmax

+ +
+ +

We can obtain associated uncertainties to the scores by introducing a std +variable and using bayesian integration:

+ +

$$a_k = \frac{M_1}{\sigma_k^2}$$ +where $M_1 = \frac{1}{\sum_{i=1}^D \frac{1}{\sigma_i^2}}$.

+ +

Similarly, +$$b_k = \frac{M_2}{\mu_k^2}$$ +where $M_2 = \frac{1}{\sum_{i=1}^D \frac{1}{\mu_i^2}}$.

+ +

The we can sum the variances to obtain the final uncertainty (squared) $\xi$: +$$\xi_k^2 = c_1 \sigma_k^2 + c_2 \mu_k^2$$

+ +

which, in terms of $a_k$ and $b_k$ yields: +$$\xi_k^2 = \frac{c'_1}{a_k} + \frac{c'_2}{b_k}$$ +where $c'_1 = c_1 \cdot M_1$ and $c'_2 = c_2 \cdot M_2$.

+ +

Finally, the finale combined coefficient is +$$\lambda_k = \frac{M_3}{\frac{c'_1}{a_k} + \frac{c'_2}{b_k}}$$ +where +$$M_3 = \frac{1}{\sum_{i=1}^D + \frac{1}{\frac{c'_1}{a_i} + \frac{c'_2}{b_i}}}$$

+ +
Arguments:
+ +
    +
  • x (LatentsDomainGroupT): the group of latent representation.
  • +
  • selection_score (Mapping[str, torch.Tensor]): attention scores to +use to encode the reprensetation.
  • +
+ +
Returns:
+ +
+

torch.Tensor: The merged representation.

+
+
+ + +
+
+
Inherited Members
+
+ + +
torch.nn.modules.module.Module
+
dump_patches
+
training
+
call_super_init
+
forward
+
register_buffer
+
register_parameter
+
add_module
+
register_module
+
get_submodule
+
get_parameter
+
get_buffer
+
get_extra_state
+
set_extra_state
+
apply
+
cuda
+
ipu
+
xpu
+
cpu
+
type
+
float
+
double
+
half
+
bfloat16
+
to_empty
+
to
+
register_full_backward_pre_hook
+
register_backward_hook
+
register_full_backward_hook
+
register_forward_pre_hook
+
register_forward_hook
+
register_state_dict_pre_hook
+
state_dict
+
register_load_state_dict_post_hook
+
load_state_dict
+
parameters
+
named_parameters
+
buffers
+
named_buffers
+
children
+
named_children
+
modules
+
named_modules
+
train
+
eval
+
requires_grad_
+
zero_grad
+
share_memory
+
extra_repr
+
compile
+ +
+
+
+
+
+ + \ No newline at end of file diff --git a/docs/api/v0.5.1/shimmer/modules/losses.html b/docs/api/v0.5.1/shimmer/modules/losses.html new file mode 100644 index 00000000..4d306516 --- /dev/null +++ b/docs/api/v0.5.1/shimmer/modules/losses.html @@ -0,0 +1,3888 @@ + + + + + + + shimmer.modules.losses API documentation + + + + + + + + + + + + + +
+
+

+shimmer.modules.losses

+ + + + + + +
  1from abc import ABC, abstractmethod
+  2from collections.abc import Generator, Mapping
+  3from itertools import product
+  4from typing import TypedDict
+  5
+  6import torch
+  7
+  8from shimmer.modules.contrastive_loss import ContrastiveLossType
+  9from shimmer.modules.domain import DomainModule, LossOutput
+ 10from shimmer.modules.gw_module import (
+ 11    GWModule,
+ 12    GWModuleBase,
+ 13    GWModuleBayesian,
+ 14)
+ 15from shimmer.modules.selection import SelectionBase
+ 16from shimmer.types import LatentsDomainGroupsT, ModelModeT
+ 17
+ 18
+ 19class GWLossesBase(torch.nn.Module, ABC):
+ 20    """
+ 21    Base Abstract Class for Global Workspace (GW) losses. This module is used
+ 22    to compute the different losses of the GW (typically translation, cycle,
+ 23    demi-cycle, contrastive losses).
+ 24    """
+ 25
+ 26    @abstractmethod
+ 27    def step(
+ 28        self,
+ 29        domain_latents: LatentsDomainGroupsT,
+ 30        mode: ModelModeT,
+ 31    ) -> LossOutput:
+ 32        """
+ 33        Computes the losses.
+ 34
+ 35        Args:
+ 36            domain_latents (`LatentsDomainGroupsT`): All latent groups
+ 37            mode (`Literal["train", "val", "test", "val/ood", "test/ood"]`): model mode
+ 38        Returns:
+ 39            `LossOutput`: the losses
+ 40        """
+ 41        ...
+ 42
+ 43
+ 44def demi_cycle_loss(
+ 45    gw_mod: GWModuleBase,
+ 46    selection_mod: SelectionBase,
+ 47    domain_mods: Mapping[str, DomainModule],
+ 48    latent_domains: LatentsDomainGroupsT,
+ 49) -> dict[str, torch.Tensor]:
+ 50    """
+ 51    Computes the demi-cycle loss.
+ 52
+ 53    This return multiple metrics:
+ 54        * `demi_cycle_{domain_name}` with the demi-cycle of a particular domain;
+ 55        * `demi_cycle_{domain_name}_{metric}` with additional metrics provided by
+ 56            the domain_mod's `compute_dcy_loss` output;
+ 57        * `demi_cycles` with the average value of all `demi_cycle_{domain_name}` values.
+ 58
+ 59    Args:
+ 60        gw_mod (`shimmer.modules.gw_module.GWModuleBase`): The GWModule to use
+ 61        selection_mod (`shimmer.modules.selection.SelectionBase`): Selection mod to use
+ 62        domain_mods (`Mapping[str, DomainModule]`): the domain modules
+ 63        latent_domains (`shimmer.types.LatentsDomainGroupsT`): the latent unimodal
+ 64            groups
+ 65
+ 66    Returns:
+ 67        `dict[str, torch.Tensor]`: a dict of metrics.
+ 68    """
+ 69    losses: dict[str, torch.Tensor] = {}
+ 70    metrics: dict[str, torch.Tensor] = {}
+ 71    for domains, latents in latent_domains.items():
+ 72        if len(domains) > 1:
+ 73            continue
+ 74        domain_name = next(iter(domains))
+ 75
+ 76        domain_mod = domain_mods[domain_name]
+ 77        x_recons = gw_mod.decode(
+ 78            gw_mod.encode_and_fuse(latents, selection_mod), domains={domain_name}
+ 79        )[domain_name]
+ 80        loss_output = domain_mod.compute_dcy_loss(x_recons, latents[domain_name])
+ 81        losses[f"demi_cycle_{domain_name}"] = loss_output.loss
+ 82        metrics.update(
+ 83            {f"demi_cycle_{domain_name}_{k}": v for k, v in loss_output.metrics.items()}
+ 84        )
+ 85    losses["demi_cycles"] = torch.stack(list(losses.values()), dim=0).mean()
+ 86    losses.update(metrics)
+ 87    return losses
+ 88
+ 89
+ 90def cycle_loss(
+ 91    gw_mod: GWModuleBase,
+ 92    selection_mod: SelectionBase,
+ 93    domain_mods: Mapping[str, DomainModule],
+ 94    latent_domains: LatentsDomainGroupsT,
+ 95) -> dict[str, torch.Tensor]:
+ 96    """
+ 97    Computes the cycle loss.
+ 98
+ 99    This return multiple metrics:
+100        * `cycle_{domain_source}_through_{domain_target}` with the cycle of
+101            a particular domain;
+102        * `cycle_{domain_source}_through_{domain_target}_{metric}` with additional
+103            metrics provided by the domain_mod's `compute_cy_loss` output;
+104        * `cycles` with the average value of all
+105            `cycle_{domain_source}_through_{domain_target}` values.
+106
+107    Args:
+108        gw_mod (`GWModuleBase`): The GWModule to use
+109        selection_mod (`shimmer.modules.selection.SelectionBase`): Selection mod to use
+110        domain_mods (`Mapping[str, DomainModule]`): the domain modules
+111        latent_domains (`LatentsDomainGroupsT`): the latent unimodal groups
+112
+113    Returns:
+114        `dict[str, torch.Tensor]`: a dict of metrics.
+115    """
+116    losses: dict[str, torch.Tensor] = {}
+117    metrics: dict[str, torch.Tensor] = {}
+118    for domains_source, latents_source in latent_domains.items():
+119        if len(domains_source) > 1:
+120            continue
+121        domain_name_source = list(domains_source)[0]
+122
+123        domain_mod = domain_mods[domain_name_source]
+124        z = gw_mod.encode_and_fuse(latents_source, selection_mod)
+125        for domain_name_target in domain_mods:
+126            if domain_name_target == domain_name_source:
+127                continue
+128
+129            x_pred = gw_mod.decode(z, domains={domain_name_target})
+130
+131            x_recons = gw_mod.decode(
+132                gw_mod.encode_and_fuse(x_pred, selection_mod),
+133                domains={domain_name_source},
+134            )
+135
+136            loss_name = f"{domain_name_source}_through_{domain_name_target}"
+137            loss_output = domain_mod.compute_cy_loss(
+138                x_recons[domain_name_source],
+139                latents_source[domain_name_source],
+140            )
+141            metrics.update(
+142                {f"cycle_{loss_name}_{k}": v for k, v in loss_output.metrics.items()}
+143            )
+144            losses[f"cycle_{loss_name}"] = loss_output.loss
+145    losses["cycles"] = torch.stack(list(losses.values()), dim=0).mean()
+146    losses.update(metrics)
+147    return losses
+148
+149
+150def translation_loss(
+151    gw_mod: GWModuleBase,
+152    selection_mod: SelectionBase,
+153    domain_mods: Mapping[str, DomainModule],
+154    latent_domains: LatentsDomainGroupsT,
+155) -> dict[str, torch.Tensor]:
+156    """
+157    Computes the translation loss.
+158
+159    This return multiple metrics:
+160        * `translation_{domain_source}_to_{domain_target}` with the translation
+161            from a domain source to a domain target;
+162        * `translation_{domain_source}_to_{domain_target}_{metric}` with
+163            additional metrics provided by the domain_mod's
+164            `compute_tr_loss` output;
+165        * `translations` with the average value of all
+166            `translation_{domain_source}_to_{domain_target}` values.
+167
+168    Args:
+169        gw_mod (`GWModuleBase`): The GWModule to use
+170        domain_mods (`Mapping[str, DomainModule]`): the domain modules
+171        latent_domains (`LatentsDomainGroupsT`): the latent unimodal groups
+172
+173    Returns:
+174        `dict[str, torch.Tensor]`: a dict of metrics.
+175    """
+176    losses: dict[str, torch.Tensor] = {}
+177    metrics: dict[str, torch.Tensor] = {}
+178    for domains, latents in latent_domains.items():
+179        if len(domains) < 2:
+180            continue
+181        for domain_name_target in domains:
+182            domain_sources = {
+183                domain: latents[domain]
+184                for domain in domains
+185                if domain != domain_name_target
+186            }
+187
+188            z = gw_mod.encode_and_fuse(domain_sources, selection_mod)
+189            mod = domain_mods[domain_name_target]
+190
+191            domain_source_names = "/".join(domain_sources.keys())
+192            loss_name = f"{domain_source_names}_to_{domain_name_target}"
+193            if loss_name in losses:
+194                raise ValueError(f"{loss_name} is already computed.")
+195
+196            prediction = gw_mod.decode(z, domains={domain_name_target})[
+197                domain_name_target
+198            ]
+199            loss_output = mod.compute_tr_loss(
+200                prediction,
+201                latents[domain_name_target],
+202            )
+203            losses[f"translation_{loss_name}"] = loss_output.loss
+204            metrics.update(
+205                {
+206                    f"translation_{loss_name}_{k}": v
+207                    for k, v in loss_output.metrics.items()
+208                }
+209            )
+210    losses["translations"] = torch.stack(list(losses.values()), dim=0).mean()
+211    losses.update(metrics)
+212    return losses
+213
+214
+215def contrastive_loss(
+216    gw_mod: GWModuleBase,
+217    latent_domains: LatentsDomainGroupsT,
+218    contrastive_fn: ContrastiveLossType,
+219) -> dict[str, torch.Tensor]:
+220    """
+221    Computes the contrastive loss.
+222
+223    This return multiple metrics:
+224        * `contrastive_{domain_1}_and_{domain_2}` with the contrastive
+225            between 2 domains;
+226        * `contrastive_{domain_1}_and_{domain_2}_{metric}` with
+227            additional metrics provided by the domain_mod's
+228            `compute_cont_loss` output;
+229        * `contrastives` with the average value of all
+230            `contrastive_{domain_1}_and_{domain_2}` values.
+231
+232    Args:
+233        gw_mod (`GWModuleBase`): The GWModule to use
+234        latent_domains (`LatentsDomainGroupsT`): the latent unimodal groups
+235        contrastive_fn (`ContrastiveLossType`): the contrastive function to apply
+236
+237    Returns:
+238        `dict[str, torch.Tensor]`: a dict of metrics.
+239    """
+240    losses: dict[str, torch.Tensor] = {}
+241    metrics: dict[str, torch.Tensor] = {}
+242    keys: list[set[str]] = []
+243
+244    for latents in latent_domains.values():
+245        if len(latents) != 2:
+246            continue
+247
+248        cont_latents = gw_mod.encode(latents)
+249        for domain1, z1 in cont_latents.items():
+250            for domain2, z2 in cont_latents.items():
+251                selected_domains = {domain1, domain2}
+252                if domain1 == domain2 or selected_domains in keys:
+253                    continue
+254
+255                keys.append(selected_domains)
+256
+257                loss_name = f"contrastive_{domain1}_and_{domain2}"
+258                loss_output = contrastive_fn(z1, z2)
+259                losses[loss_name] = loss_output.loss
+260                metrics.update(
+261                    {f"{loss_name}_{k}": v for k, v in loss_output.metrics.items()}
+262                )
+263
+264    losses["contrastives"] = torch.stack(list(losses.values()), dim=0).mean()
+265    losses.update(metrics)
+266    return losses
+267
+268
+269def contrastive_loss_bayesian(
+270    gw_mod: GWModuleBayesian,
+271    latent_domains: LatentsDomainGroupsT,
+272    contrastive_fn: ContrastiveLossType,
+273) -> dict[str, torch.Tensor]:
+274    """
+275    Computes the contrastive loss with a Bayesian based uncertainty prediction.
+276
+277    This return multiple metrics:
+278        * `contrastive_{domain_1}_and_{domain_2}` with the contrastive
+279            between 2 domains;
+280        * `contrastive_{domain_1}_and_{domain_2}_{metric}` with
+281            additional metrics provided by the domain_mod's
+282            `compute_cont_loss` output;
+283        * `contrastives` with the average value of all
+284            `contrastive_{domain_1}_and_{domain_2}` values.
+285
+286    Args:
+287        gw_mod (`GWModuleBayesian`): The GWModule to use
+288        latent_domains (`LatentsDomainGroupsT`): the latent unimodal groups
+289        contrastive_fn (`ContrastiveLossBayesianType`): the contrastive function
+290            to apply
+291
+292    Returns:
+293        `dict[str, torch.Tensor]`: a dict of metrics.
+294    """
+295    losses: dict[str, torch.Tensor] = {}
+296    metrics: dict[str, torch.Tensor] = {}
+297    keys: list[set[str]] = []
+298
+299    for latents in latent_domains.values():
+300        if len(latents) < 2:
+301            continue
+302        for domain1_name, domain1 in latents.items():
+303            z1 = gw_mod.encode({domain1_name: domain1})[domain1_name]
+304            z1_precision = gw_mod.get_precision(domain1_name, domain1)
+305            for domain2_name, domain2 in latents.items():
+306                selected_domains = {domain1_name, domain2_name}
+307                if domain1_name == domain2_name or selected_domains in keys:
+308                    continue
+309
+310                keys.append(selected_domains)
+311
+312                loss_name = f"contrastive_{domain1_name}_and_{domain2_name}"
+313                z2 = gw_mod.encode({domain2_name: domain2})[domain2_name]
+314                z2_precision = gw_mod.get_precision(domain2_name, domain2)
+315                coef = torch.softmax(
+316                    gw_mod.precision_softmax_temp
+317                    * torch.stack([z1_precision, z2_precision]),
+318                    dim=0,
+319                )
+320                norm = torch.sqrt(coef[0] * coef[1])
+321                loss_output = contrastive_fn(z1 * norm, z2 * norm)
+322                loss_output_no_norm = contrastive_fn(z1, z2)
+323                losses[loss_name] = loss_output.loss
+324                metrics.update(
+325                    {f"{loss_name}_{k}": v for k, v in loss_output.metrics.items()}
+326                )
+327                metrics[f"unnorm_{loss_name}"] = loss_output_no_norm.loss
+328
+329    losses["contrastives"] = torch.stack(list(losses.values()), dim=0).mean()
+330    losses.update(metrics)
+331    return losses
+332
+333
+334class LossCoefs(TypedDict, total=False):
+335    """
+336    Dict of loss coefficients used in the GWLosses.
+337
+338    If one is not provided, the coefficient is assumed to be 0 and will not be logged.
+339    If the loss is excplicitely set to 0, it will be logged, but not take part in
+340    the total loss.
+341    """
+342
+343    demi_cycles: float
+344    """Demi-cycle loss coefficient."""
+345
+346    cycles: float
+347    """Cycle loss coefficient."""
+348
+349    translations: float
+350    """Translation loss coefficient."""
+351
+352    contrastives: float
+353    """Contrastive loss coefficient."""
+354
+355
+356class GWLosses2Domains(GWLossesBase):
+357    """
+358    Implementation of `GWLossesBase` used for `GWModule`.
+359    """
+360
+361    def __init__(
+362        self,
+363        gw_mod: GWModule,
+364        selection_mod: SelectionBase,
+365        domain_mods: dict[str, DomainModule],
+366        loss_coefs: LossCoefs,
+367        contrastive_fn: ContrastiveLossType,
+368    ):
+369        """
+370        Main loss module to use with the GlobalWorkspace
+371
+372        Args:
+373            gw_mod (`GWModule`): the GWModule
+374            selection_mod (`SelectionBase`): selection module
+375            domain_mods (`dict[str, DomainModule]`): a dict where the key is the
+376                domain name and value is the DomainModule
+377            loss_coefs (`LossCoefs`): loss coefficients. LossCoefs object, or a
+378                mapping to float with correct keys.
+379            contrastive_fn (`ContrastiveLossType`): the contrastive function to use
+380                in contrastive loss
+381        """
+382
+383        super().__init__()
+384        self.gw_mod = gw_mod
+385        self.selection_mod = selection_mod
+386        self.domain_mods = domain_mods
+387        self.loss_coefs = loss_coefs
+388        self.contrastive_fn = contrastive_fn
+389
+390    def demi_cycle_loss(
+391        self, latent_domains: LatentsDomainGroupsT
+392    ) -> dict[str, torch.Tensor]:
+393        """
+394        Computes the demi-cycle loss.
+395
+396        See `shimmer.modules.losses.demi_cycle_loss`.
+397
+398        Args:
+399            latent_domains (`LatentsDomainGroupsT`): the latent unimodal groups
+400
+401        Returns:
+402            `dict[str, torch.Tensor]`: a dict of metrics.
+403        """
+404        return demi_cycle_loss(
+405            self.gw_mod, self.selection_mod, self.domain_mods, latent_domains
+406        )
+407
+408    def cycle_loss(
+409        self, latent_domains: LatentsDomainGroupsT
+410    ) -> dict[str, torch.Tensor]:
+411        """
+412        Computes the cycle loss.
+413
+414        See `shimmer.modules.losses.cycle_loss`.
+415
+416        Args:
+417            latent_domains (`LatentsDomainGroupsT`): the latent unimodal groups
+418
+419        Returns:
+420            `dict[str, torch.Tensor]`: a dict of metrics.
+421        """
+422        return cycle_loss(
+423            self.gw_mod, self.selection_mod, self.domain_mods, latent_domains
+424        )
+425
+426    def translation_loss(
+427        self, latent_domains: LatentsDomainGroupsT
+428    ) -> dict[str, torch.Tensor]:
+429        """
+430        Computes the translation loss.
+431
+432        See `shimmer.modules.losses.translation_loss`.
+433
+434        Args:
+435            latent_domains (`LatentsDomainGroupsT`): the latent unimodal groups
+436
+437        Returns:
+438            `dict[str, torch.Tensor]`: a dict of metrics.
+439        """
+440        return translation_loss(
+441            self.gw_mod, self.selection_mod, self.domain_mods, latent_domains
+442        )
+443
+444    def contrastive_loss(
+445        self, latent_domains: LatentsDomainGroupsT
+446    ) -> dict[str, torch.Tensor]:
+447        """
+448        Computes the contrastive loss.
+449
+450        See `shimmer.modules.losses.contrastive_loss`.
+451
+452        Args:
+453            latent_domains (`LatentsDomainGroupsT`): the latent unimodal groups
+454
+455        Returns:
+456            `dict[str, torch.Tensor]`: a dict of metrics.
+457        """
+458        return contrastive_loss(self.gw_mod, latent_domains, self.contrastive_fn)
+459
+460    def step(
+461        self, domain_latents: LatentsDomainGroupsT, mode: ModelModeT
+462    ) -> LossOutput:
+463        """
+464        Computes and returns the losses
+465
+466        Contains:
+467            - Demi-cycle metrics (see `GWLosses.demi_cycle_loss`)
+468            - Cycle metrics (see `GWLosses.cycle_loss`)
+469            - Translation metrics (see `GWLosses.translation_loss`)
+470            - Contrastive metrics (see `GWLosses.contrastive_loss`)
+471
+472        Args:
+473            domain_latents (`LatentsDomainGroupsT`): All latent groups
+474            mode (`ModelModeT`): model mode
+475        Returns:
+476            `LossOutput`: the losses
+477        """
+478        metrics: dict[str, torch.Tensor] = {}
+479
+480        metrics.update(self.demi_cycle_loss(domain_latents))
+481        metrics.update(self.cycle_loss(domain_latents))
+482        metrics.update(self.translation_loss(domain_latents))
+483        metrics.update(self.contrastive_loss(domain_latents))
+484
+485        loss = torch.stack(
+486            [
+487                metrics[name] * coef
+488                for name, coef in self.loss_coefs.items()
+489                if isinstance(coef, float) and coef > 0
+490            ],
+491            dim=0,
+492        ).mean()
+493
+494        return LossOutput(loss, metrics)
+495
+496
+497def generate_partitions(n: int) -> Generator[tuple[int, ...], None, None]:
+498    """
+499    Generates all possible partitions of zeros and ones for `n` elements,
+500    excluding the all-zeros partition.
+501
+502    Args:
+503        n (`int`): The number of modalities to generate partitions for.
+504
+505    Yields:
+506        `tuple[int, ...]`: A partition of zeros and ones, excluding the
+507        all-zeros partition.
+508    """
+509    for perm in product([0, 1], repeat=n):
+510        if any(perm):
+511            yield perm
+512
+513
+514def broadcast_loss(
+515    gw_mod: GWModuleBase,
+516    selection_mod: SelectionBase,
+517    domain_mods: Mapping[str, DomainModule],
+518    latent_domains: LatentsDomainGroupsT,
+519) -> dict[str, torch.Tensor]:
+520    """
+521    Computes broadcast loss including demi-cycle, cycle, and translation losses.
+522
+523    Args:
+524        gw_mod (`shimmer.modules.gw_module.GWModuleBase`): The GWModule to use
+525        selection_mod (`shimmer.modules.selection.SelectionBase`): Selection mod to use
+526        domain_mods (`Mapping[str, DomainModule]`): the domain modules
+527        latent_domains: The latent domain representations.
+528
+529    Returns:
+530        A dictionary with the total loss and additional metrics.
+531    """
+532    losses: dict[str, torch.Tensor] = {}
+533    metrics: dict[str, torch.Tensor] = {}
+534
+535    demi_cycle_losses: list[str] = []
+536    cycle_losses: list[str] = []
+537    translation_losses: list[str] = []
+538    fused_losses: list[str] = []
+539
+540    for group_domains, latents in latent_domains.items():
+541        encoded_latents = gw_mod.encode(latents)
+542        partitions = generate_partitions(len(group_domains))
+543        domain_names = list(latents)
+544        group_name = "-".join(group_domains)
+545
+546        for partition in partitions:
+547            selected_latents = {
+548                domain: latents[domain]
+549                for domain, present in zip(domain_names, partition, strict=True)
+550                if present
+551            }
+552            selected_encoded_latents = {
+553                domain: encoded_latents[domain] for domain in selected_latents
+554            }
+555            selected_group_label = "{" + ", ".join(sorted(selected_latents)) + "}"
+556
+557            selection_scores = selection_mod(selected_latents, selected_encoded_latents)
+558            fused_latents = gw_mod.fuse(selected_encoded_latents, selection_scores)
+559            decoded_latents = gw_mod.decode(fused_latents)
+560
+561            num_active_domains = sum(partition)
+562            num_total_domains = len(decoded_latents)
+563
+564            for domain, pred in decoded_latents.items():
+565                if domain not in group_domains:  # if we don't have ground truth
+566                    continue
+567                ground_truth = latents[domain]
+568                loss_output = domain_mods[domain].compute_loss(pred, ground_truth)
+569                loss_label = f"from_{selected_group_label}_to_{domain}"
+570                losses[loss_label + "_loss"] = loss_output.loss
+571                metrics.update(
+572                    {f"{loss_label}_{k}": v for k, v in loss_output.metrics.items()}
+573                )
+574
+575                if num_active_domains == 1 and domain in selected_latents:
+576                    demi_cycle_losses.append(loss_label + "_loss")
+577                elif domain not in selected_latents:
+578                    translation_losses.append(loss_label + "_loss")
+579                else:  # fused loss
+580                    fused_losses.append(loss_label + "_loss")
+581
+582            if num_active_domains < num_total_domains:
+583                inverse_selected_latents = {
+584                    domain: decoded_latents[domain]
+585                    for domain in decoded_latents
+586                    if domain not in selected_latents
+587                }
+588
+589                inverse_selected_group_label = (
+590                    "{" + ", ".join(sorted(inverse_selected_latents)) + "}"
+591                )
+592
+593                re_encoded_latents = gw_mod.encode(inverse_selected_latents)
+594                re_selection_scores = selection_mod(
+595                    inverse_selected_latents, re_encoded_latents
+596                )
+597                re_fused_latents = gw_mod.fuse(re_encoded_latents, re_selection_scores)
+598                re_decoded_latents = gw_mod.decode(
+599                    re_fused_latents, domains=selected_latents.keys()
+600                )
+601
+602                for domain in selected_latents:
+603                    re_ground_truth = latents[domain]
+604                    re_loss_output = domain_mods[domain].compute_loss(
+605                        re_decoded_latents[domain], re_ground_truth
+606                    )
+607                    loss_label = (
+608                        f"from_{selected_group_label}_"
+609                        f"through_{inverse_selected_group_label}_to_{domain}_"
+610                        f"case_{group_name}"
+611                    )
+612                    losses[loss_label + "_loss"] = re_loss_output.loss
+613                    metrics.update(
+614                        {
+615                            f"{loss_label}_{k}": v
+616                            for k, v in re_loss_output.metrics.items()
+617                        }
+618                    )
+619                    cycle_losses.append(loss_label + "_loss")
+620
+621    if demi_cycle_losses:
+622        metrics["demi_cycles"] = torch.mean(
+623            torch.stack([losses[loss_name] for loss_name in demi_cycle_losses])
+624        )
+625    if cycle_losses:
+626        metrics["cycles"] = torch.mean(
+627            torch.stack([losses[loss_name] for loss_name in cycle_losses])
+628        )
+629    if translation_losses:
+630        metrics["translations"] = torch.mean(
+631            torch.stack([losses[loss_name] for loss_name in translation_losses])
+632        )
+633    if fused_losses:
+634        metrics["fused"] = torch.mean(
+635            torch.stack([losses[loss_name] for loss_name in fused_losses])
+636        )
+637
+638    metrics.update(losses)
+639    return metrics
+640
+641
+642class BroadcastLossCoefs(TypedDict, total=False):
+643    """
+644    Dict of loss coefficients used in the GWLossesFusion.
+645
+646    If one is not provided, the coefficient is assumed to be 0 and will not be logged.
+647    If the loss is excplicitely set to 0, it will be logged, but not take part in
+648    the total loss.
+649    """
+650
+651    contrastives: float
+652    """Contrastive loss coefficient."""
+653
+654    fused: float
+655    """fused loss coefficient (encode multiple domains and decode to one of them)."""
+656
+657    demi_cycles: float
+658    """demi_cycles loss coefficient. Demi-cycles are always one-to-one"""
+659
+660    cycles: float
+661    """cycles loss coefficient. Cycles can be many-to-one"""
+662
+663    translations: float
+664    """translation loss coefficient. Translation, like cycles, can be many-to-one."""
+665
+666
+667class GWLosses(GWLossesBase):
+668    """
+669    Implementation of `GWLossesBase` for fusion-based models.
+670    """
+671
+672    def __init__(
+673        self,
+674        gw_mod: GWModule,
+675        selection_mod: SelectionBase,
+676        domain_mods: dict[str, DomainModule],
+677        loss_coefs: BroadcastLossCoefs,
+678        contrastive_fn: ContrastiveLossType,
+679    ):
+680        """
+681        Initializes the loss computation module for a Global Workspace Fusion model.
+682
+683        Args:
+684            gw_mod: The GWModule for the global workspace.
+685            selection_mod: The selection mechanism for the model.
+686            domain_mods: A mapping of domain names to their respective DomainModule.
+687            loss_coefs (`BroadcastLossCoefs`): coefs for the losses
+688            contrastive_fn: The function used for computing contrastive loss.
+689        """
+690        super().__init__()
+691        self.gw_mod = gw_mod
+692        self.selection_mod = selection_mod
+693        self.domain_mods = domain_mods
+694        self.loss_coefs = loss_coefs
+695        self.contrastive_fn = contrastive_fn
+696
+697    def contrastive_loss(
+698        self, latent_domains: LatentsDomainGroupsT
+699    ) -> dict[str, torch.Tensor]:
+700        """
+701        Computes the contrastive loss for the given latent domains.
+702
+703        Args:
+704            latent_domains: The latent domain representations.
+705
+706        Returns:
+707            A dictionary of contrastive loss metrics.
+708        """
+709
+710        return contrastive_loss(self.gw_mod, latent_domains, self.contrastive_fn)
+711
+712    def broadcast_loss(
+713        self, latent_domains: LatentsDomainGroupsT
+714    ) -> dict[str, torch.Tensor]:
+715        return broadcast_loss(
+716            self.gw_mod, self.selection_mod, self.domain_mods, latent_domains
+717        )
+718
+719    def step(
+720        self, domain_latents: LatentsDomainGroupsT, mode: ModelModeT
+721    ) -> LossOutput:
+722        """
+723        Performs a step of loss computation.
+724
+725        Args:
+726            domain_latents: Latent representations for all domains.
+727            mode: The mode in which the model is currently operating.
+728
+729        Returns:
+730            A LossOutput object containing the loss and metrics for this step.
+731        """
+732
+733        metrics: dict[str, torch.Tensor] = {}
+734
+735        metrics.update(self.contrastive_loss(domain_latents))
+736        metrics.update(self.broadcast_loss(domain_latents))
+737
+738        loss = torch.stack(
+739            [
+740                metrics[name] * coef
+741                for name, coef in self.loss_coefs.items()
+742                if isinstance(coef, float) and coef > 0
+743            ],
+744            dim=0,
+745        ).mean()
+746
+747        metrics["broadcast_loss"] = torch.stack(
+748            [
+749                metrics[name]
+750                for name, coef in self.loss_coefs.items()
+751                if isinstance(coef, float) and coef > 0 and name != "contrastives"
+752            ],
+753            dim=0,
+754        ).mean()
+755
+756        return LossOutput(loss, metrics)
+757
+758
+759class GWLossesBayesian(GWLossesBase):
+760    """
+761    Implementation of `GWLossesBase` used for `GWModuleBayesian`.
+762    """
+763
+764    def __init__(
+765        self,
+766        gw_mod: GWModuleBayesian,
+767        selection_mod: SelectionBase,
+768        domain_mods: dict[str, DomainModule],
+769        loss_coefs: BroadcastLossCoefs,
+770        contrastive_fn: ContrastiveLossType,
+771        use_normalized_constrastive: bool = True,
+772    ):
+773        """
+774        Loss module with uncertainty prediction to use with the GlobalWorkspaceBayesian
+775
+776        Args:
+777            gw_mod (`GWModuleBayesian`): the GWModule
+778            selection_mod (`SelectionBase`): selection module
+779            domain_mods (`dict[str, DomainModule]`): a dict where the key is the
+780                domain name and value is the DomainModule
+781            loss_coefs (`BroadcastLossCoefs`): loss coefficients
+782            contrastive_fn (`ContrastiveLossType`): the contrastive function
+783                to use in contrastive loss
+784            use_normalized_constrastive (`bool`): whether to use the normalized cont
+785                loss by the precision coefs
+786        """
+787        super().__init__()
+788
+789        self.gw_mod = gw_mod
+790        """The GWModule."""
+791
+792        self.selection_mod = selection_mod
+793        """Selection module"""
+794
+795        self.domain_mods = domain_mods
+796        """Domain modules linked to the GW."""
+797
+798        self.loss_coefs = loss_coefs
+799        """The loss coefficients."""
+800
+801        self.contrastive_fn = contrastive_fn
+802        """
+803        Contrastive loss to use.
+804        """
+805
+806        self.use_normalized_constrastive = use_normalized_constrastive
+807
+808    def contrastive_loss(
+809        self, latent_domains: LatentsDomainGroupsT
+810    ) -> dict[str, torch.Tensor]:
+811        """
+812        Contrastive loss.
+813
+814        Args:
+815            latent_domains (`LatentsDomainGroupsT`): the latent unimodal groups
+816
+817        Returns:
+818            `dict[str, torch.Tensor]`: a dict of metrics.
+819        """
+820        if self.use_normalized_constrastive:
+821            return contrastive_loss_bayesian(
+822                self.gw_mod, latent_domains, self.contrastive_fn
+823            )
+824        return contrastive_loss(self.gw_mod, latent_domains, self.contrastive_fn)
+825
+826    def broadcast_loss(
+827        self, latent_domains: LatentsDomainGroupsT
+828    ) -> dict[str, torch.Tensor]:
+829        return broadcast_loss(
+830            self.gw_mod, self.selection_mod, self.domain_mods, latent_domains
+831        )
+832
+833    def step(
+834        self, domain_latents: LatentsDomainGroupsT, mode: ModelModeT
+835    ) -> LossOutput:
+836        """
+837        Performs a step of loss computation.
+838
+839        Args:
+840            domain_latents: Latent representations for all domains.
+841            mode: The mode in which the model is currently operating.
+842
+843        Returns:
+844            A LossOutput object containing the loss and metrics for this step.
+845        """
+846
+847        metrics: dict[str, torch.Tensor] = {}
+848
+849        metrics.update(self.contrastive_loss(domain_latents))
+850        metrics.update(self.broadcast_loss(domain_latents))
+851
+852        loss = torch.stack(
+853            [
+854                metrics[name] * coef
+855                for name, coef in self.loss_coefs.items()
+856                if isinstance(coef, float) and coef > 0
+857            ],
+858            dim=0,
+859        ).mean()
+860
+861        metrics["broadcast_loss"] = torch.stack(
+862            [
+863                metrics[name]
+864                for name, coef in self.loss_coefs.items()
+865                if isinstance(coef, float) and coef > 0 and name != "contrastives"
+866            ],
+867            dim=0,
+868        ).mean()
+869
+870        return LossOutput(loss, metrics)
+
+ + +
+
+ +
+ + class + GWLossesBase(torch.nn.modules.module.Module, abc.ABC): + + + +
+ +
20class GWLossesBase(torch.nn.Module, ABC):
+21    """
+22    Base Abstract Class for Global Workspace (GW) losses. This module is used
+23    to compute the different losses of the GW (typically translation, cycle,
+24    demi-cycle, contrastive losses).
+25    """
+26
+27    @abstractmethod
+28    def step(
+29        self,
+30        domain_latents: LatentsDomainGroupsT,
+31        mode: ModelModeT,
+32    ) -> LossOutput:
+33        """
+34        Computes the losses.
+35
+36        Args:
+37            domain_latents (`LatentsDomainGroupsT`): All latent groups
+38            mode (`Literal["train", "val", "test", "val/ood", "test/ood"]`): model mode
+39        Returns:
+40            `LossOutput`: the losses
+41        """
+42        ...
+
+ + +

Base Abstract Class for Global Workspace (GW) losses. This module is used +to compute the different losses of the GW (typically translation, cycle, +demi-cycle, contrastive losses).

+
+ + +
+ +
+
@abstractmethod
+ + def + step( self, domain_latents: collections.abc.Mapping[frozenset[str], collections.abc.Mapping[str, torch.Tensor]], mode: Literal['train', 'val', 'test', 'val/ood', 'test/ood']) -> shimmer.modules.domain.LossOutput: + + + +
+ +
27    @abstractmethod
+28    def step(
+29        self,
+30        domain_latents: LatentsDomainGroupsT,
+31        mode: ModelModeT,
+32    ) -> LossOutput:
+33        """
+34        Computes the losses.
+35
+36        Args:
+37            domain_latents (`LatentsDomainGroupsT`): All latent groups
+38            mode (`Literal["train", "val", "test", "val/ood", "test/ood"]`): model mode
+39        Returns:
+40            `LossOutput`: the losses
+41        """
+42        ...
+
+ + +

Computes the losses.

+ +
Arguments:
+ +
    +
  • domain_latents (LatentsDomainGroupsT): All latent groups
  • +
  • mode (Literal["train", "val", "test", "val/ood", "test/ood"]): model mode
  • +
+ +
Returns:
+ +
+

LossOutput: the losses

+
+
+ + +
+
+
Inherited Members
+
+
torch.nn.modules.module.Module
+
Module
+
dump_patches
+
training
+
call_super_init
+
forward
+
register_buffer
+
register_parameter
+
add_module
+
register_module
+
get_submodule
+
get_parameter
+
get_buffer
+
get_extra_state
+
set_extra_state
+
apply
+
cuda
+
ipu
+
xpu
+
cpu
+
type
+
float
+
double
+
half
+
bfloat16
+
to_empty
+
to
+
register_full_backward_pre_hook
+
register_backward_hook
+
register_full_backward_hook
+
register_forward_pre_hook
+
register_forward_hook
+
register_state_dict_pre_hook
+
state_dict
+
register_load_state_dict_post_hook
+
load_state_dict
+
parameters
+
named_parameters
+
buffers
+
named_buffers
+
children
+
named_children
+
modules
+
named_modules
+
train
+
eval
+
requires_grad_
+
zero_grad
+
share_memory
+
extra_repr
+
compile
+ +
+
+
+
+
+ +
+ + def + demi_cycle_loss( gw_mod: shimmer.modules.gw_module.GWModuleBase, selection_mod: shimmer.modules.selection.SelectionBase, domain_mods: collections.abc.Mapping[str, shimmer.modules.domain.DomainModule], latent_domains: collections.abc.Mapping[frozenset[str], collections.abc.Mapping[str, torch.Tensor]]) -> dict[str, torch.Tensor]: + + + +
+ +
45def demi_cycle_loss(
+46    gw_mod: GWModuleBase,
+47    selection_mod: SelectionBase,
+48    domain_mods: Mapping[str, DomainModule],
+49    latent_domains: LatentsDomainGroupsT,
+50) -> dict[str, torch.Tensor]:
+51    """
+52    Computes the demi-cycle loss.
+53
+54    This return multiple metrics:
+55        * `demi_cycle_{domain_name}` with the demi-cycle of a particular domain;
+56        * `demi_cycle_{domain_name}_{metric}` with additional metrics provided by
+57            the domain_mod's `compute_dcy_loss` output;
+58        * `demi_cycles` with the average value of all `demi_cycle_{domain_name}` values.
+59
+60    Args:
+61        gw_mod (`shimmer.modules.gw_module.GWModuleBase`): The GWModule to use
+62        selection_mod (`shimmer.modules.selection.SelectionBase`): Selection mod to use
+63        domain_mods (`Mapping[str, DomainModule]`): the domain modules
+64        latent_domains (`shimmer.types.LatentsDomainGroupsT`): the latent unimodal
+65            groups
+66
+67    Returns:
+68        `dict[str, torch.Tensor]`: a dict of metrics.
+69    """
+70    losses: dict[str, torch.Tensor] = {}
+71    metrics: dict[str, torch.Tensor] = {}
+72    for domains, latents in latent_domains.items():
+73        if len(domains) > 1:
+74            continue
+75        domain_name = next(iter(domains))
+76
+77        domain_mod = domain_mods[domain_name]
+78        x_recons = gw_mod.decode(
+79            gw_mod.encode_and_fuse(latents, selection_mod), domains={domain_name}
+80        )[domain_name]
+81        loss_output = domain_mod.compute_dcy_loss(x_recons, latents[domain_name])
+82        losses[f"demi_cycle_{domain_name}"] = loss_output.loss
+83        metrics.update(
+84            {f"demi_cycle_{domain_name}_{k}": v for k, v in loss_output.metrics.items()}
+85        )
+86    losses["demi_cycles"] = torch.stack(list(losses.values()), dim=0).mean()
+87    losses.update(metrics)
+88    return losses
+
+ + +

Computes the demi-cycle loss.

+ +
This return multiple metrics:
+ +
+
    +
  • demi_cycle_{domain_name} with the demi-cycle of a particular domain;
  • +
  • demi_cycle_{domain_name}_{metric} with additional metrics provided by + the domain_mod's compute_dcy_loss output;
  • +
  • demi_cycles with the average value of all demi_cycle_{domain_name} values.
  • +
+
+ +
Arguments:
+ + + +
Returns:
+ +
+

dict[str, torch.Tensor]: a dict of metrics.

+
+
+ + +
+
+ +
+ + def + cycle_loss( gw_mod: shimmer.modules.gw_module.GWModuleBase, selection_mod: shimmer.modules.selection.SelectionBase, domain_mods: collections.abc.Mapping[str, shimmer.modules.domain.DomainModule], latent_domains: collections.abc.Mapping[frozenset[str], collections.abc.Mapping[str, torch.Tensor]]) -> dict[str, torch.Tensor]: + + + +
+ +
 91def cycle_loss(
+ 92    gw_mod: GWModuleBase,
+ 93    selection_mod: SelectionBase,
+ 94    domain_mods: Mapping[str, DomainModule],
+ 95    latent_domains: LatentsDomainGroupsT,
+ 96) -> dict[str, torch.Tensor]:
+ 97    """
+ 98    Computes the cycle loss.
+ 99
+100    This return multiple metrics:
+101        * `cycle_{domain_source}_through_{domain_target}` with the cycle of
+102            a particular domain;
+103        * `cycle_{domain_source}_through_{domain_target}_{metric}` with additional
+104            metrics provided by the domain_mod's `compute_cy_loss` output;
+105        * `cycles` with the average value of all
+106            `cycle_{domain_source}_through_{domain_target}` values.
+107
+108    Args:
+109        gw_mod (`GWModuleBase`): The GWModule to use
+110        selection_mod (`shimmer.modules.selection.SelectionBase`): Selection mod to use
+111        domain_mods (`Mapping[str, DomainModule]`): the domain modules
+112        latent_domains (`LatentsDomainGroupsT`): the latent unimodal groups
+113
+114    Returns:
+115        `dict[str, torch.Tensor]`: a dict of metrics.
+116    """
+117    losses: dict[str, torch.Tensor] = {}
+118    metrics: dict[str, torch.Tensor] = {}
+119    for domains_source, latents_source in latent_domains.items():
+120        if len(domains_source) > 1:
+121            continue
+122        domain_name_source = list(domains_source)[0]
+123
+124        domain_mod = domain_mods[domain_name_source]
+125        z = gw_mod.encode_and_fuse(latents_source, selection_mod)
+126        for domain_name_target in domain_mods:
+127            if domain_name_target == domain_name_source:
+128                continue
+129
+130            x_pred = gw_mod.decode(z, domains={domain_name_target})
+131
+132            x_recons = gw_mod.decode(
+133                gw_mod.encode_and_fuse(x_pred, selection_mod),
+134                domains={domain_name_source},
+135            )
+136
+137            loss_name = f"{domain_name_source}_through_{domain_name_target}"
+138            loss_output = domain_mod.compute_cy_loss(
+139                x_recons[domain_name_source],
+140                latents_source[domain_name_source],
+141            )
+142            metrics.update(
+143                {f"cycle_{loss_name}_{k}": v for k, v in loss_output.metrics.items()}
+144            )
+145            losses[f"cycle_{loss_name}"] = loss_output.loss
+146    losses["cycles"] = torch.stack(list(losses.values()), dim=0).mean()
+147    losses.update(metrics)
+148    return losses
+
+ + +

Computes the cycle loss.

+ +
This return multiple metrics:
+ +
+
    +
  • cycle_{domain_source}_through_{domain_target} with the cycle of + a particular domain;
  • +
  • cycle_{domain_source}_through_{domain_target}_{metric} with additional + metrics provided by the domain_mod's compute_cy_loss output;
  • +
  • cycles with the average value of all + cycle_{domain_source}_through_{domain_target} values.
  • +
+
+ +
Arguments:
+ +
    +
  • gw_mod (GWModuleBase): The GWModule to use
  • +
  • selection_mod (shimmer.modules.selection.SelectionBase): Selection mod to use
  • +
  • domain_mods (Mapping[str, DomainModule]): the domain modules
  • +
  • latent_domains (LatentsDomainGroupsT): the latent unimodal groups
  • +
+ +
Returns:
+ +
+

dict[str, torch.Tensor]: a dict of metrics.

+
+
+ + +
+
+ +
+ + def + translation_loss( gw_mod: shimmer.modules.gw_module.GWModuleBase, selection_mod: shimmer.modules.selection.SelectionBase, domain_mods: collections.abc.Mapping[str, shimmer.modules.domain.DomainModule], latent_domains: collections.abc.Mapping[frozenset[str], collections.abc.Mapping[str, torch.Tensor]]) -> dict[str, torch.Tensor]: + + + +
+ +
151def translation_loss(
+152    gw_mod: GWModuleBase,
+153    selection_mod: SelectionBase,
+154    domain_mods: Mapping[str, DomainModule],
+155    latent_domains: LatentsDomainGroupsT,
+156) -> dict[str, torch.Tensor]:
+157    """
+158    Computes the translation loss.
+159
+160    This return multiple metrics:
+161        * `translation_{domain_source}_to_{domain_target}` with the translation
+162            from a domain source to a domain target;
+163        * `translation_{domain_source}_to_{domain_target}_{metric}` with
+164            additional metrics provided by the domain_mod's
+165            `compute_tr_loss` output;
+166        * `translations` with the average value of all
+167            `translation_{domain_source}_to_{domain_target}` values.
+168
+169    Args:
+170        gw_mod (`GWModuleBase`): The GWModule to use
+171        domain_mods (`Mapping[str, DomainModule]`): the domain modules
+172        latent_domains (`LatentsDomainGroupsT`): the latent unimodal groups
+173
+174    Returns:
+175        `dict[str, torch.Tensor]`: a dict of metrics.
+176    """
+177    losses: dict[str, torch.Tensor] = {}
+178    metrics: dict[str, torch.Tensor] = {}
+179    for domains, latents in latent_domains.items():
+180        if len(domains) < 2:
+181            continue
+182        for domain_name_target in domains:
+183            domain_sources = {
+184                domain: latents[domain]
+185                for domain in domains
+186                if domain != domain_name_target
+187            }
+188
+189            z = gw_mod.encode_and_fuse(domain_sources, selection_mod)
+190            mod = domain_mods[domain_name_target]
+191
+192            domain_source_names = "/".join(domain_sources.keys())
+193            loss_name = f"{domain_source_names}_to_{domain_name_target}"
+194            if loss_name in losses:
+195                raise ValueError(f"{loss_name} is already computed.")
+196
+197            prediction = gw_mod.decode(z, domains={domain_name_target})[
+198                domain_name_target
+199            ]
+200            loss_output = mod.compute_tr_loss(
+201                prediction,
+202                latents[domain_name_target],
+203            )
+204            losses[f"translation_{loss_name}"] = loss_output.loss
+205            metrics.update(
+206                {
+207                    f"translation_{loss_name}_{k}": v
+208                    for k, v in loss_output.metrics.items()
+209                }
+210            )
+211    losses["translations"] = torch.stack(list(losses.values()), dim=0).mean()
+212    losses.update(metrics)
+213    return losses
+
+ + +

Computes the translation loss.

+ +
This return multiple metrics:
+ +
+
    +
  • translation_{domain_source}_to_{domain_target} with the translation + from a domain source to a domain target;
  • +
  • translation_{domain_source}_to_{domain_target}_{metric} with + additional metrics provided by the domain_mod's + compute_tr_loss output;
  • +
  • translations with the average value of all + translation_{domain_source}_to_{domain_target} values.
  • +
+
+ +
Arguments:
+ +
    +
  • gw_mod (GWModuleBase): The GWModule to use
  • +
  • domain_mods (Mapping[str, DomainModule]): the domain modules
  • +
  • latent_domains (LatentsDomainGroupsT): the latent unimodal groups
  • +
+ +
Returns:
+ +
+

dict[str, torch.Tensor]: a dict of metrics.

+
+
+ + +
+
+ +
+ + def + contrastive_loss( gw_mod: shimmer.modules.gw_module.GWModuleBase, latent_domains: collections.abc.Mapping[frozenset[str], collections.abc.Mapping[str, torch.Tensor]], contrastive_fn: collections.abc.Callable[[torch.Tensor, torch.Tensor], shimmer.modules.domain.LossOutput]) -> dict[str, torch.Tensor]: + + + +
+ +
216def contrastive_loss(
+217    gw_mod: GWModuleBase,
+218    latent_domains: LatentsDomainGroupsT,
+219    contrastive_fn: ContrastiveLossType,
+220) -> dict[str, torch.Tensor]:
+221    """
+222    Computes the contrastive loss.
+223
+224    This return multiple metrics:
+225        * `contrastive_{domain_1}_and_{domain_2}` with the contrastive
+226            between 2 domains;
+227        * `contrastive_{domain_1}_and_{domain_2}_{metric}` with
+228            additional metrics provided by the domain_mod's
+229            `compute_cont_loss` output;
+230        * `contrastives` with the average value of all
+231            `contrastive_{domain_1}_and_{domain_2}` values.
+232
+233    Args:
+234        gw_mod (`GWModuleBase`): The GWModule to use
+235        latent_domains (`LatentsDomainGroupsT`): the latent unimodal groups
+236        contrastive_fn (`ContrastiveLossType`): the contrastive function to apply
+237
+238    Returns:
+239        `dict[str, torch.Tensor]`: a dict of metrics.
+240    """
+241    losses: dict[str, torch.Tensor] = {}
+242    metrics: dict[str, torch.Tensor] = {}
+243    keys: list[set[str]] = []
+244
+245    for latents in latent_domains.values():
+246        if len(latents) != 2:
+247            continue
+248
+249        cont_latents = gw_mod.encode(latents)
+250        for domain1, z1 in cont_latents.items():
+251            for domain2, z2 in cont_latents.items():
+252                selected_domains = {domain1, domain2}
+253                if domain1 == domain2 or selected_domains in keys:
+254                    continue
+255
+256                keys.append(selected_domains)
+257
+258                loss_name = f"contrastive_{domain1}_and_{domain2}"
+259                loss_output = contrastive_fn(z1, z2)
+260                losses[loss_name] = loss_output.loss
+261                metrics.update(
+262                    {f"{loss_name}_{k}": v for k, v in loss_output.metrics.items()}
+263                )
+264
+265    losses["contrastives"] = torch.stack(list(losses.values()), dim=0).mean()
+266    losses.update(metrics)
+267    return losses
+
+ + +

Computes the contrastive loss.

+ +
This return multiple metrics:
+ +
+
    +
  • contrastive_{domain_1}_and_{domain_2} with the contrastive + between 2 domains;
  • +
  • contrastive_{domain_1}_and_{domain_2}_{metric} with + additional metrics provided by the domain_mod's + compute_cont_loss output;
  • +
  • contrastives with the average value of all + contrastive_{domain_1}_and_{domain_2} values.
  • +
+
+ +
Arguments:
+ +
    +
  • gw_mod (GWModuleBase): The GWModule to use
  • +
  • latent_domains (LatentsDomainGroupsT): the latent unimodal groups
  • +
  • contrastive_fn (ContrastiveLossType): the contrastive function to apply
  • +
+ +
Returns:
+ +
+

dict[str, torch.Tensor]: a dict of metrics.

+
+
+ + +
+
+ +
+ + def + contrastive_loss_bayesian( gw_mod: shimmer.modules.gw_module.GWModuleBayesian, latent_domains: collections.abc.Mapping[frozenset[str], collections.abc.Mapping[str, torch.Tensor]], contrastive_fn: collections.abc.Callable[[torch.Tensor, torch.Tensor], shimmer.modules.domain.LossOutput]) -> dict[str, torch.Tensor]: + + + +
+ +
270def contrastive_loss_bayesian(
+271    gw_mod: GWModuleBayesian,
+272    latent_domains: LatentsDomainGroupsT,
+273    contrastive_fn: ContrastiveLossType,
+274) -> dict[str, torch.Tensor]:
+275    """
+276    Computes the contrastive loss with a Bayesian based uncertainty prediction.
+277
+278    This return multiple metrics:
+279        * `contrastive_{domain_1}_and_{domain_2}` with the contrastive
+280            between 2 domains;
+281        * `contrastive_{domain_1}_and_{domain_2}_{metric}` with
+282            additional metrics provided by the domain_mod's
+283            `compute_cont_loss` output;
+284        * `contrastives` with the average value of all
+285            `contrastive_{domain_1}_and_{domain_2}` values.
+286
+287    Args:
+288        gw_mod (`GWModuleBayesian`): The GWModule to use
+289        latent_domains (`LatentsDomainGroupsT`): the latent unimodal groups
+290        contrastive_fn (`ContrastiveLossBayesianType`): the contrastive function
+291            to apply
+292
+293    Returns:
+294        `dict[str, torch.Tensor]`: a dict of metrics.
+295    """
+296    losses: dict[str, torch.Tensor] = {}
+297    metrics: dict[str, torch.Tensor] = {}
+298    keys: list[set[str]] = []
+299
+300    for latents in latent_domains.values():
+301        if len(latents) < 2:
+302            continue
+303        for domain1_name, domain1 in latents.items():
+304            z1 = gw_mod.encode({domain1_name: domain1})[domain1_name]
+305            z1_precision = gw_mod.get_precision(domain1_name, domain1)
+306            for domain2_name, domain2 in latents.items():
+307                selected_domains = {domain1_name, domain2_name}
+308                if domain1_name == domain2_name or selected_domains in keys:
+309                    continue
+310
+311                keys.append(selected_domains)
+312
+313                loss_name = f"contrastive_{domain1_name}_and_{domain2_name}"
+314                z2 = gw_mod.encode({domain2_name: domain2})[domain2_name]
+315                z2_precision = gw_mod.get_precision(domain2_name, domain2)
+316                coef = torch.softmax(
+317                    gw_mod.precision_softmax_temp
+318                    * torch.stack([z1_precision, z2_precision]),
+319                    dim=0,
+320                )
+321                norm = torch.sqrt(coef[0] * coef[1])
+322                loss_output = contrastive_fn(z1 * norm, z2 * norm)
+323                loss_output_no_norm = contrastive_fn(z1, z2)
+324                losses[loss_name] = loss_output.loss
+325                metrics.update(
+326                    {f"{loss_name}_{k}": v for k, v in loss_output.metrics.items()}
+327                )
+328                metrics[f"unnorm_{loss_name}"] = loss_output_no_norm.loss
+329
+330    losses["contrastives"] = torch.stack(list(losses.values()), dim=0).mean()
+331    losses.update(metrics)
+332    return losses
+
+ + +

Computes the contrastive loss with a Bayesian based uncertainty prediction.

+ +
This return multiple metrics:
+ +
+
    +
  • contrastive_{domain_1}_and_{domain_2} with the contrastive + between 2 domains;
  • +
  • contrastive_{domain_1}_and_{domain_2}_{metric} with + additional metrics provided by the domain_mod's + compute_cont_loss output;
  • +
  • contrastives with the average value of all + contrastive_{domain_1}_and_{domain_2} values.
  • +
+
+ +
Arguments:
+ +
    +
  • gw_mod (GWModuleBayesian): The GWModule to use
  • +
  • latent_domains (LatentsDomainGroupsT): the latent unimodal groups
  • +
  • contrastive_fn (ContrastiveLossBayesianType): the contrastive function +to apply
  • +
+ +
Returns:
+ +
+

dict[str, torch.Tensor]: a dict of metrics.

+
+
+ + +
+
+ +
+ + class + LossCoefs(typing.TypedDict): + + + +
+ +
335class LossCoefs(TypedDict, total=False):
+336    """
+337    Dict of loss coefficients used in the GWLosses.
+338
+339    If one is not provided, the coefficient is assumed to be 0 and will not be logged.
+340    If the loss is excplicitely set to 0, it will be logged, but not take part in
+341    the total loss.
+342    """
+343
+344    demi_cycles: float
+345    """Demi-cycle loss coefficient."""
+346
+347    cycles: float
+348    """Cycle loss coefficient."""
+349
+350    translations: float
+351    """Translation loss coefficient."""
+352
+353    contrastives: float
+354    """Contrastive loss coefficient."""
+
+ + +

Dict of loss coefficients used in the GWLosses.

+ +

If one is not provided, the coefficient is assumed to be 0 and will not be logged. +If the loss is excplicitely set to 0, it will be logged, but not take part in +the total loss.

+
+ + +
+
+ demi_cycles: float + + +
+ + +

Demi-cycle loss coefficient.

+
+ + +
+
+
+ cycles: float + + +
+ + +

Cycle loss coefficient.

+
+ + +
+
+
+ translations: float + + +
+ + +

Translation loss coefficient.

+
+ + +
+
+
+ contrastives: float + + +
+ + +

Contrastive loss coefficient.

+
+ + +
+
+
+ +
+ + class + GWLosses2Domains(GWLossesBase): + + + +
+ +
357class GWLosses2Domains(GWLossesBase):
+358    """
+359    Implementation of `GWLossesBase` used for `GWModule`.
+360    """
+361
+362    def __init__(
+363        self,
+364        gw_mod: GWModule,
+365        selection_mod: SelectionBase,
+366        domain_mods: dict[str, DomainModule],
+367        loss_coefs: LossCoefs,
+368        contrastive_fn: ContrastiveLossType,
+369    ):
+370        """
+371        Main loss module to use with the GlobalWorkspace
+372
+373        Args:
+374            gw_mod (`GWModule`): the GWModule
+375            selection_mod (`SelectionBase`): selection module
+376            domain_mods (`dict[str, DomainModule]`): a dict where the key is the
+377                domain name and value is the DomainModule
+378            loss_coefs (`LossCoefs`): loss coefficients. LossCoefs object, or a
+379                mapping to float with correct keys.
+380            contrastive_fn (`ContrastiveLossType`): the contrastive function to use
+381                in contrastive loss
+382        """
+383
+384        super().__init__()
+385        self.gw_mod = gw_mod
+386        self.selection_mod = selection_mod
+387        self.domain_mods = domain_mods
+388        self.loss_coefs = loss_coefs
+389        self.contrastive_fn = contrastive_fn
+390
+391    def demi_cycle_loss(
+392        self, latent_domains: LatentsDomainGroupsT
+393    ) -> dict[str, torch.Tensor]:
+394        """
+395        Computes the demi-cycle loss.
+396
+397        See `shimmer.modules.losses.demi_cycle_loss`.
+398
+399        Args:
+400            latent_domains (`LatentsDomainGroupsT`): the latent unimodal groups
+401
+402        Returns:
+403            `dict[str, torch.Tensor]`: a dict of metrics.
+404        """
+405        return demi_cycle_loss(
+406            self.gw_mod, self.selection_mod, self.domain_mods, latent_domains
+407        )
+408
+409    def cycle_loss(
+410        self, latent_domains: LatentsDomainGroupsT
+411    ) -> dict[str, torch.Tensor]:
+412        """
+413        Computes the cycle loss.
+414
+415        See `shimmer.modules.losses.cycle_loss`.
+416
+417        Args:
+418            latent_domains (`LatentsDomainGroupsT`): the latent unimodal groups
+419
+420        Returns:
+421            `dict[str, torch.Tensor]`: a dict of metrics.
+422        """
+423        return cycle_loss(
+424            self.gw_mod, self.selection_mod, self.domain_mods, latent_domains
+425        )
+426
+427    def translation_loss(
+428        self, latent_domains: LatentsDomainGroupsT
+429    ) -> dict[str, torch.Tensor]:
+430        """
+431        Computes the translation loss.
+432
+433        See `shimmer.modules.losses.translation_loss`.
+434
+435        Args:
+436            latent_domains (`LatentsDomainGroupsT`): the latent unimodal groups
+437
+438        Returns:
+439            `dict[str, torch.Tensor]`: a dict of metrics.
+440        """
+441        return translation_loss(
+442            self.gw_mod, self.selection_mod, self.domain_mods, latent_domains
+443        )
+444
+445    def contrastive_loss(
+446        self, latent_domains: LatentsDomainGroupsT
+447    ) -> dict[str, torch.Tensor]:
+448        """
+449        Computes the contrastive loss.
+450
+451        See `shimmer.modules.losses.contrastive_loss`.
+452
+453        Args:
+454            latent_domains (`LatentsDomainGroupsT`): the latent unimodal groups
+455
+456        Returns:
+457            `dict[str, torch.Tensor]`: a dict of metrics.
+458        """
+459        return contrastive_loss(self.gw_mod, latent_domains, self.contrastive_fn)
+460
+461    def step(
+462        self, domain_latents: LatentsDomainGroupsT, mode: ModelModeT
+463    ) -> LossOutput:
+464        """
+465        Computes and returns the losses
+466
+467        Contains:
+468            - Demi-cycle metrics (see `GWLosses.demi_cycle_loss`)
+469            - Cycle metrics (see `GWLosses.cycle_loss`)
+470            - Translation metrics (see `GWLosses.translation_loss`)
+471            - Contrastive metrics (see `GWLosses.contrastive_loss`)
+472
+473        Args:
+474            domain_latents (`LatentsDomainGroupsT`): All latent groups
+475            mode (`ModelModeT`): model mode
+476        Returns:
+477            `LossOutput`: the losses
+478        """
+479        metrics: dict[str, torch.Tensor] = {}
+480
+481        metrics.update(self.demi_cycle_loss(domain_latents))
+482        metrics.update(self.cycle_loss(domain_latents))
+483        metrics.update(self.translation_loss(domain_latents))
+484        metrics.update(self.contrastive_loss(domain_latents))
+485
+486        loss = torch.stack(
+487            [
+488                metrics[name] * coef
+489                for name, coef in self.loss_coefs.items()
+490                if isinstance(coef, float) and coef > 0
+491            ],
+492            dim=0,
+493        ).mean()
+494
+495        return LossOutput(loss, metrics)
+
+ + +

Implementation of GWLossesBase used for GWModule.

+
+ + +
+ +
+ + GWLosses2Domains( gw_mod: shimmer.modules.gw_module.GWModule, selection_mod: shimmer.modules.selection.SelectionBase, domain_mods: dict[str, shimmer.modules.domain.DomainModule], loss_coefs: LossCoefs, contrastive_fn: collections.abc.Callable[[torch.Tensor, torch.Tensor], shimmer.modules.domain.LossOutput]) + + + +
+ +
362    def __init__(
+363        self,
+364        gw_mod: GWModule,
+365        selection_mod: SelectionBase,
+366        domain_mods: dict[str, DomainModule],
+367        loss_coefs: LossCoefs,
+368        contrastive_fn: ContrastiveLossType,
+369    ):
+370        """
+371        Main loss module to use with the GlobalWorkspace
+372
+373        Args:
+374            gw_mod (`GWModule`): the GWModule
+375            selection_mod (`SelectionBase`): selection module
+376            domain_mods (`dict[str, DomainModule]`): a dict where the key is the
+377                domain name and value is the DomainModule
+378            loss_coefs (`LossCoefs`): loss coefficients. LossCoefs object, or a
+379                mapping to float with correct keys.
+380            contrastive_fn (`ContrastiveLossType`): the contrastive function to use
+381                in contrastive loss
+382        """
+383
+384        super().__init__()
+385        self.gw_mod = gw_mod
+386        self.selection_mod = selection_mod
+387        self.domain_mods = domain_mods
+388        self.loss_coefs = loss_coefs
+389        self.contrastive_fn = contrastive_fn
+
+ + +

Main loss module to use with the GlobalWorkspace

+ +
Arguments:
+ +
    +
  • gw_mod (GWModule): the GWModule
  • +
  • selection_mod (SelectionBase): selection module
  • +
  • domain_mods (dict[str, DomainModule]): a dict where the key is the +domain name and value is the DomainModule
  • +
  • loss_coefs (LossCoefs): loss coefficients. LossCoefs object, or a +mapping to float with correct keys.
  • +
  • contrastive_fn (ContrastiveLossType): the contrastive function to use +in contrastive loss
  • +
+
+ + +
+
+
+ gw_mod + + +
+ + + + +
+
+
+ selection_mod + + +
+ + + + +
+
+
+ domain_mods + + +
+ + + + +
+
+
+ loss_coefs + + +
+ + + + +
+
+
+ contrastive_fn + + +
+ + + + +
+
+ +
+ + def + demi_cycle_loss( self, latent_domains: collections.abc.Mapping[frozenset[str], collections.abc.Mapping[str, torch.Tensor]]) -> dict[str, torch.Tensor]: + + + +
+ +
391    def demi_cycle_loss(
+392        self, latent_domains: LatentsDomainGroupsT
+393    ) -> dict[str, torch.Tensor]:
+394        """
+395        Computes the demi-cycle loss.
+396
+397        See `shimmer.modules.losses.demi_cycle_loss`.
+398
+399        Args:
+400            latent_domains (`LatentsDomainGroupsT`): the latent unimodal groups
+401
+402        Returns:
+403            `dict[str, torch.Tensor]`: a dict of metrics.
+404        """
+405        return demi_cycle_loss(
+406            self.gw_mod, self.selection_mod, self.domain_mods, latent_domains
+407        )
+
+ + +

Computes the demi-cycle loss.

+ +

See demi_cycle_loss.

+ +
Arguments:
+ +
    +
  • latent_domains (LatentsDomainGroupsT): the latent unimodal groups
  • +
+ +
Returns:
+ +
+

dict[str, torch.Tensor]: a dict of metrics.

+
+
+ + +
+
+ +
+ + def + cycle_loss( self, latent_domains: collections.abc.Mapping[frozenset[str], collections.abc.Mapping[str, torch.Tensor]]) -> dict[str, torch.Tensor]: + + + +
+ +
409    def cycle_loss(
+410        self, latent_domains: LatentsDomainGroupsT
+411    ) -> dict[str, torch.Tensor]:
+412        """
+413        Computes the cycle loss.
+414
+415        See `shimmer.modules.losses.cycle_loss`.
+416
+417        Args:
+418            latent_domains (`LatentsDomainGroupsT`): the latent unimodal groups
+419
+420        Returns:
+421            `dict[str, torch.Tensor]`: a dict of metrics.
+422        """
+423        return cycle_loss(
+424            self.gw_mod, self.selection_mod, self.domain_mods, latent_domains
+425        )
+
+ + +

Computes the cycle loss.

+ +

See cycle_loss.

+ +
Arguments:
+ +
    +
  • latent_domains (LatentsDomainGroupsT): the latent unimodal groups
  • +
+ +
Returns:
+ +
+

dict[str, torch.Tensor]: a dict of metrics.

+
+
+ + +
+
+ +
+ + def + translation_loss( self, latent_domains: collections.abc.Mapping[frozenset[str], collections.abc.Mapping[str, torch.Tensor]]) -> dict[str, torch.Tensor]: + + + +
+ +
427    def translation_loss(
+428        self, latent_domains: LatentsDomainGroupsT
+429    ) -> dict[str, torch.Tensor]:
+430        """
+431        Computes the translation loss.
+432
+433        See `shimmer.modules.losses.translation_loss`.
+434
+435        Args:
+436            latent_domains (`LatentsDomainGroupsT`): the latent unimodal groups
+437
+438        Returns:
+439            `dict[str, torch.Tensor]`: a dict of metrics.
+440        """
+441        return translation_loss(
+442            self.gw_mod, self.selection_mod, self.domain_mods, latent_domains
+443        )
+
+ + +

Computes the translation loss.

+ +

See translation_loss.

+ +
Arguments:
+ +
    +
  • latent_domains (LatentsDomainGroupsT): the latent unimodal groups
  • +
+ +
Returns:
+ +
+

dict[str, torch.Tensor]: a dict of metrics.

+
+
+ + +
+
+ +
+ + def + contrastive_loss( self, latent_domains: collections.abc.Mapping[frozenset[str], collections.abc.Mapping[str, torch.Tensor]]) -> dict[str, torch.Tensor]: + + + +
+ +
445    def contrastive_loss(
+446        self, latent_domains: LatentsDomainGroupsT
+447    ) -> dict[str, torch.Tensor]:
+448        """
+449        Computes the contrastive loss.
+450
+451        See `shimmer.modules.losses.contrastive_loss`.
+452
+453        Args:
+454            latent_domains (`LatentsDomainGroupsT`): the latent unimodal groups
+455
+456        Returns:
+457            `dict[str, torch.Tensor]`: a dict of metrics.
+458        """
+459        return contrastive_loss(self.gw_mod, latent_domains, self.contrastive_fn)
+
+ + +

Computes the contrastive loss.

+ +

See contrastive_loss.

+ +
Arguments:
+ +
    +
  • latent_domains (LatentsDomainGroupsT): the latent unimodal groups
  • +
+ +
Returns:
+ +
+

dict[str, torch.Tensor]: a dict of metrics.

+
+
+ + +
+
+ +
+ + def + step( self, domain_latents: collections.abc.Mapping[frozenset[str], collections.abc.Mapping[str, torch.Tensor]], mode: Literal['train', 'val', 'test', 'val/ood', 'test/ood']) -> shimmer.modules.domain.LossOutput: + + + +
+ +
461    def step(
+462        self, domain_latents: LatentsDomainGroupsT, mode: ModelModeT
+463    ) -> LossOutput:
+464        """
+465        Computes and returns the losses
+466
+467        Contains:
+468            - Demi-cycle metrics (see `GWLosses.demi_cycle_loss`)
+469            - Cycle metrics (see `GWLosses.cycle_loss`)
+470            - Translation metrics (see `GWLosses.translation_loss`)
+471            - Contrastive metrics (see `GWLosses.contrastive_loss`)
+472
+473        Args:
+474            domain_latents (`LatentsDomainGroupsT`): All latent groups
+475            mode (`ModelModeT`): model mode
+476        Returns:
+477            `LossOutput`: the losses
+478        """
+479        metrics: dict[str, torch.Tensor] = {}
+480
+481        metrics.update(self.demi_cycle_loss(domain_latents))
+482        metrics.update(self.cycle_loss(domain_latents))
+483        metrics.update(self.translation_loss(domain_latents))
+484        metrics.update(self.contrastive_loss(domain_latents))
+485
+486        loss = torch.stack(
+487            [
+488                metrics[name] * coef
+489                for name, coef in self.loss_coefs.items()
+490                if isinstance(coef, float) and coef > 0
+491            ],
+492            dim=0,
+493        ).mean()
+494
+495        return LossOutput(loss, metrics)
+
+ + +

Computes and returns the losses

+ +
Contains:
+ +
+
    +
  • Demi-cycle metrics (see GWLosses.demi_cycle_loss)
  • +
  • Cycle metrics (see GWLosses.cycle_loss)
  • +
  • Translation metrics (see GWLosses.translation_loss)
  • +
  • Contrastive metrics (see GWLosses.contrastive_loss)
  • +
+
+ +
Arguments:
+ +
    +
  • domain_latents (LatentsDomainGroupsT): All latent groups
  • +
  • mode (ModelModeT): model mode
  • +
+ +
Returns:
+ +
+

LossOutput: the losses

+
+
+ + +
+
+
Inherited Members
+
+
torch.nn.modules.module.Module
+
dump_patches
+
training
+
call_super_init
+
forward
+
register_buffer
+
register_parameter
+
add_module
+
register_module
+
get_submodule
+
get_parameter
+
get_buffer
+
get_extra_state
+
set_extra_state
+
apply
+
cuda
+
ipu
+
xpu
+
cpu
+
type
+
float
+
double
+
half
+
bfloat16
+
to_empty
+
to
+
register_full_backward_pre_hook
+
register_backward_hook
+
register_full_backward_hook
+
register_forward_pre_hook
+
register_forward_hook
+
register_state_dict_pre_hook
+
state_dict
+
register_load_state_dict_post_hook
+
load_state_dict
+
parameters
+
named_parameters
+
buffers
+
named_buffers
+
children
+
named_children
+
modules
+
named_modules
+
train
+
eval
+
requires_grad_
+
zero_grad
+
share_memory
+
extra_repr
+
compile
+ +
+
+
+
+
+ +
+ + def + generate_partitions(n: int) -> collections.abc.Generator[tuple[int, ...], None, None]: + + + +
+ +
498def generate_partitions(n: int) -> Generator[tuple[int, ...], None, None]:
+499    """
+500    Generates all possible partitions of zeros and ones for `n` elements,
+501    excluding the all-zeros partition.
+502
+503    Args:
+504        n (`int`): The number of modalities to generate partitions for.
+505
+506    Yields:
+507        `tuple[int, ...]`: A partition of zeros and ones, excluding the
+508        all-zeros partition.
+509    """
+510    for perm in product([0, 1], repeat=n):
+511        if any(perm):
+512            yield perm
+
+ + +

Generates all possible partitions of zeros and ones for n elements, +excluding the all-zeros partition.

+ +
Arguments:
+ +
    +
  • n (int): The number of modalities to generate partitions for.
  • +
+ +
Yields:
+ +
+

tuple[int, ...]: A partition of zeros and ones, excluding the + all-zeros partition.

+
+
+ + +
+
+ +
+ + def + broadcast_loss( gw_mod: shimmer.modules.gw_module.GWModuleBase, selection_mod: shimmer.modules.selection.SelectionBase, domain_mods: collections.abc.Mapping[str, shimmer.modules.domain.DomainModule], latent_domains: collections.abc.Mapping[frozenset[str], collections.abc.Mapping[str, torch.Tensor]]) -> dict[str, torch.Tensor]: + + + +
+ +
515def broadcast_loss(
+516    gw_mod: GWModuleBase,
+517    selection_mod: SelectionBase,
+518    domain_mods: Mapping[str, DomainModule],
+519    latent_domains: LatentsDomainGroupsT,
+520) -> dict[str, torch.Tensor]:
+521    """
+522    Computes broadcast loss including demi-cycle, cycle, and translation losses.
+523
+524    Args:
+525        gw_mod (`shimmer.modules.gw_module.GWModuleBase`): The GWModule to use
+526        selection_mod (`shimmer.modules.selection.SelectionBase`): Selection mod to use
+527        domain_mods (`Mapping[str, DomainModule]`): the domain modules
+528        latent_domains: The latent domain representations.
+529
+530    Returns:
+531        A dictionary with the total loss and additional metrics.
+532    """
+533    losses: dict[str, torch.Tensor] = {}
+534    metrics: dict[str, torch.Tensor] = {}
+535
+536    demi_cycle_losses: list[str] = []
+537    cycle_losses: list[str] = []
+538    translation_losses: list[str] = []
+539    fused_losses: list[str] = []
+540
+541    for group_domains, latents in latent_domains.items():
+542        encoded_latents = gw_mod.encode(latents)
+543        partitions = generate_partitions(len(group_domains))
+544        domain_names = list(latents)
+545        group_name = "-".join(group_domains)
+546
+547        for partition in partitions:
+548            selected_latents = {
+549                domain: latents[domain]
+550                for domain, present in zip(domain_names, partition, strict=True)
+551                if present
+552            }
+553            selected_encoded_latents = {
+554                domain: encoded_latents[domain] for domain in selected_latents
+555            }
+556            selected_group_label = "{" + ", ".join(sorted(selected_latents)) + "}"
+557
+558            selection_scores = selection_mod(selected_latents, selected_encoded_latents)
+559            fused_latents = gw_mod.fuse(selected_encoded_latents, selection_scores)
+560            decoded_latents = gw_mod.decode(fused_latents)
+561
+562            num_active_domains = sum(partition)
+563            num_total_domains = len(decoded_latents)
+564
+565            for domain, pred in decoded_latents.items():
+566                if domain not in group_domains:  # if we don't have ground truth
+567                    continue
+568                ground_truth = latents[domain]
+569                loss_output = domain_mods[domain].compute_loss(pred, ground_truth)
+570                loss_label = f"from_{selected_group_label}_to_{domain}"
+571                losses[loss_label + "_loss"] = loss_output.loss
+572                metrics.update(
+573                    {f"{loss_label}_{k}": v for k, v in loss_output.metrics.items()}
+574                )
+575
+576                if num_active_domains == 1 and domain in selected_latents:
+577                    demi_cycle_losses.append(loss_label + "_loss")
+578                elif domain not in selected_latents:
+579                    translation_losses.append(loss_label + "_loss")
+580                else:  # fused loss
+581                    fused_losses.append(loss_label + "_loss")
+582
+583            if num_active_domains < num_total_domains:
+584                inverse_selected_latents = {
+585                    domain: decoded_latents[domain]
+586                    for domain in decoded_latents
+587                    if domain not in selected_latents
+588                }
+589
+590                inverse_selected_group_label = (
+591                    "{" + ", ".join(sorted(inverse_selected_latents)) + "}"
+592                )
+593
+594                re_encoded_latents = gw_mod.encode(inverse_selected_latents)
+595                re_selection_scores = selection_mod(
+596                    inverse_selected_latents, re_encoded_latents
+597                )
+598                re_fused_latents = gw_mod.fuse(re_encoded_latents, re_selection_scores)
+599                re_decoded_latents = gw_mod.decode(
+600                    re_fused_latents, domains=selected_latents.keys()
+601                )
+602
+603                for domain in selected_latents:
+604                    re_ground_truth = latents[domain]
+605                    re_loss_output = domain_mods[domain].compute_loss(
+606                        re_decoded_latents[domain], re_ground_truth
+607                    )
+608                    loss_label = (
+609                        f"from_{selected_group_label}_"
+610                        f"through_{inverse_selected_group_label}_to_{domain}_"
+611                        f"case_{group_name}"
+612                    )
+613                    losses[loss_label + "_loss"] = re_loss_output.loss
+614                    metrics.update(
+615                        {
+616                            f"{loss_label}_{k}": v
+617                            for k, v in re_loss_output.metrics.items()
+618                        }
+619                    )
+620                    cycle_losses.append(loss_label + "_loss")
+621
+622    if demi_cycle_losses:
+623        metrics["demi_cycles"] = torch.mean(
+624            torch.stack([losses[loss_name] for loss_name in demi_cycle_losses])
+625        )
+626    if cycle_losses:
+627        metrics["cycles"] = torch.mean(
+628            torch.stack([losses[loss_name] for loss_name in cycle_losses])
+629        )
+630    if translation_losses:
+631        metrics["translations"] = torch.mean(
+632            torch.stack([losses[loss_name] for loss_name in translation_losses])
+633        )
+634    if fused_losses:
+635        metrics["fused"] = torch.mean(
+636            torch.stack([losses[loss_name] for loss_name in fused_losses])
+637        )
+638
+639    metrics.update(losses)
+640    return metrics
+
+ + +

Computes broadcast loss including demi-cycle, cycle, and translation losses.

+ +
Arguments:
+ + + +
Returns:
+ +
+

A dictionary with the total loss and additional metrics.

+
+
+ + +
+
+ +
+ + class + BroadcastLossCoefs(typing.TypedDict): + + + +
+ +
643class BroadcastLossCoefs(TypedDict, total=False):
+644    """
+645    Dict of loss coefficients used in the GWLossesFusion.
+646
+647    If one is not provided, the coefficient is assumed to be 0 and will not be logged.
+648    If the loss is explicitly set to 0, it will be logged, but not take part in
+649    the total loss.
+650    """
+651
+652    contrastives: float
+653    """Contrastive loss coefficient."""
+654
+655    fused: float
+656    """fused loss coefficient (encode multiple domains and decode to one of them)."""
+657
+658    demi_cycles: float
+659    """demi_cycles loss coefficient. Demi-cycles are always one-to-one"""
+660
+661    cycles: float
+662    """cycles loss coefficient. Cycles can be many-to-one"""
+663
+664    translations: float
+665    """translation loss coefficient. Translation, like cycles, can be many-to-one."""
+
+ + +

Dict of loss coefficients used in the GWLossesFusion.

+ +

If one is not provided, the coefficient is assumed to be 0 and will not be logged. +If the loss is explicitly set to 0, it will be logged, but not take part in +the total loss.

+
+ + +
+
+ contrastives: float + + +
+ + +

Contrastive loss coefficient.

+
+ + +
+
+
+ fused: float + + +
+ + +

fused loss coefficient (encode multiple domains and decode to one of them).

+
+ + +
+
+
+ demi_cycles: float + + +
+ + +

demi_cycles loss coefficient. Demi-cycles are always one-to-one

+
+ + +
+
+
+ cycles: float + + +
+ + +

cycles loss coefficient. Cycles can be many-to-one

+
+ + +
+
+
+ translations: float + + +
+ + +

translation loss coefficient. Translation, like cycles, can be many-to-one.

+
+ + +
+
+
+ +
+ + class + GWLosses(GWLossesBase): + + + +
+ +
668class GWLosses(GWLossesBase):
+669    """
+670    Implementation of `GWLossesBase` for fusion-based models.
+671    """
+672
+673    def __init__(
+674        self,
+675        gw_mod: GWModule,
+676        selection_mod: SelectionBase,
+677        domain_mods: dict[str, DomainModule],
+678        loss_coefs: BroadcastLossCoefs,
+679        contrastive_fn: ContrastiveLossType,
+680    ):
+681        """
+682        Initializes the loss computation module for a Global Workspace Fusion model.
+683
+684        Args:
+685            gw_mod: The GWModule for the global workspace.
+686            selection_mod: The selection mechanism for the model.
+687            domain_mods: A mapping of domain names to their respective DomainModule.
+688            loss_coefs (`BroadcastLossCoefs`): coefs for the losses
+689            contrastive_fn: The function used for computing contrastive loss.
+690        """
+691        super().__init__()
+692        self.gw_mod = gw_mod
+693        self.selection_mod = selection_mod
+694        self.domain_mods = domain_mods
+695        self.loss_coefs = loss_coefs
+696        self.contrastive_fn = contrastive_fn
+697
+698    def contrastive_loss(
+699        self, latent_domains: LatentsDomainGroupsT
+700    ) -> dict[str, torch.Tensor]:
+701        """
+702        Computes the contrastive loss for the given latent domains.
+703
+704        Args:
+705            latent_domains: The latent domain representations.
+706
+707        Returns:
+708            A dictionary of contrastive loss metrics.
+709        """
+710
+711        return contrastive_loss(self.gw_mod, latent_domains, self.contrastive_fn)
+712
+713    def broadcast_loss(
+714        self, latent_domains: LatentsDomainGroupsT
+715    ) -> dict[str, torch.Tensor]:
+716        return broadcast_loss(
+717            self.gw_mod, self.selection_mod, self.domain_mods, latent_domains
+718        )
+719
+720    def step(
+721        self, domain_latents: LatentsDomainGroupsT, mode: ModelModeT
+722    ) -> LossOutput:
+723        """
+724        Performs a step of loss computation.
+725
+726        Args:
+727            domain_latents: Latent representations for all domains.
+728            mode: The mode in which the model is currently operating.
+729
+730        Returns:
+731            A LossOutput object containing the loss and metrics for this step.
+732        """
+733
+734        metrics: dict[str, torch.Tensor] = {}
+735
+736        metrics.update(self.contrastive_loss(domain_latents))
+737        metrics.update(self.broadcast_loss(domain_latents))
+738
+739        loss = torch.stack(
+740            [
+741                metrics[name] * coef
+742                for name, coef in self.loss_coefs.items()
+743                if isinstance(coef, float) and coef > 0
+744            ],
+745            dim=0,
+746        ).mean()
+747
+748        metrics["broadcast_loss"] = torch.stack(
+749            [
+750                metrics[name]
+751                for name, coef in self.loss_coefs.items()
+752                if isinstance(coef, float) and coef > 0 and name != "contrastives"
+753            ],
+754            dim=0,
+755        ).mean()
+756
+757        return LossOutput(loss, metrics)
+
+ + +

Implementation of GWLossesBase for fusion-based models.

+
+ + +
+ +
+ + GWLosses( gw_mod: shimmer.modules.gw_module.GWModule, selection_mod: shimmer.modules.selection.SelectionBase, domain_mods: dict[str, shimmer.modules.domain.DomainModule], loss_coefs: BroadcastLossCoefs, contrastive_fn: collections.abc.Callable[[torch.Tensor, torch.Tensor], shimmer.modules.domain.LossOutput]) + + + +
+ +
673    def __init__(
+674        self,
+675        gw_mod: GWModule,
+676        selection_mod: SelectionBase,
+677        domain_mods: dict[str, DomainModule],
+678        loss_coefs: BroadcastLossCoefs,
+679        contrastive_fn: ContrastiveLossType,
+680    ):
+681        """
+682        Initializes the loss computation module for a Global Workspace Fusion model.
+683
+684        Args:
+685            gw_mod: The GWModule for the global workspace.
+686            selection_mod: The selection mechanism for the model.
+687            domain_mods: A mapping of domain names to their respective DomainModule.
+688            loss_coefs (`BroadcastLossCoefs`): coefs for the losses
+689            contrastive_fn: The function used for computing contrastive loss.
+690        """
+691        super().__init__()
+692        self.gw_mod = gw_mod
+693        self.selection_mod = selection_mod
+694        self.domain_mods = domain_mods
+695        self.loss_coefs = loss_coefs
+696        self.contrastive_fn = contrastive_fn
+
+ + +

Initializes the loss computation module for a Global Workspace Fusion model.

+ +
Arguments:
+ +
    +
  • gw_mod: The GWModule for the global workspace.
  • +
  • selection_mod: The selection mechanism for the model.
  • +
  • domain_mods: A mapping of domain names to their respective DomainModule.
  • +
  • loss_coefs (BroadcastLossCoefs): coefs for the losses
  • +
  • contrastive_fn: The function used for computing contrastive loss.
  • +
+
+ + +
+
+
+ gw_mod + + +
+ + + + +
+
+
+ selection_mod + + +
+ + + + +
+
+
+ domain_mods + + +
+ + + + +
+
+
+ loss_coefs + + +
+ + + + +
+
+
+ contrastive_fn + + +
+ + + + +
+
+ +
+ + def + contrastive_loss( self, latent_domains: collections.abc.Mapping[frozenset[str], collections.abc.Mapping[str, torch.Tensor]]) -> dict[str, torch.Tensor]: + + + +
+ +
698    def contrastive_loss(
+699        self, latent_domains: LatentsDomainGroupsT
+700    ) -> dict[str, torch.Tensor]:
+701        """
+702        Computes the contrastive loss for the given latent domains.
+703
+704        Args:
+705            latent_domains: The latent domain representations.
+706
+707        Returns:
+708            A dictionary of contrastive loss metrics.
+709        """
+710
+711        return contrastive_loss(self.gw_mod, latent_domains, self.contrastive_fn)
+
+ + +

Computes the contrastive loss for the given latent domains.

+ +
Arguments:
+ +
    +
  • latent_domains: The latent domain representations.
  • +
+ +
Returns:
+ +
+

A dictionary of contrastive loss metrics.

+
+
+ + +
+
+ +
+ + def + broadcast_loss( self, latent_domains: collections.abc.Mapping[frozenset[str], collections.abc.Mapping[str, torch.Tensor]]) -> dict[str, torch.Tensor]: + + + +
+ +
713    def broadcast_loss(
+714        self, latent_domains: LatentsDomainGroupsT
+715    ) -> dict[str, torch.Tensor]:
+716        return broadcast_loss(
+717            self.gw_mod, self.selection_mod, self.domain_mods, latent_domains
+718        )
+
+ + + + +
+
+ +
+ + def + step( self, domain_latents: collections.abc.Mapping[frozenset[str], collections.abc.Mapping[str, torch.Tensor]], mode: Literal['train', 'val', 'test', 'val/ood', 'test/ood']) -> shimmer.modules.domain.LossOutput: + + + +
+ +
720    def step(
+721        self, domain_latents: LatentsDomainGroupsT, mode: ModelModeT
+722    ) -> LossOutput:
+723        """
+724        Performs a step of loss computation.
+725
+726        Args:
+727            domain_latents: Latent representations for all domains.
+728            mode: The mode in which the model is currently operating.
+729
+730        Returns:
+731            A LossOutput object containing the loss and metrics for this step.
+732        """
+733
+734        metrics: dict[str, torch.Tensor] = {}
+735
+736        metrics.update(self.contrastive_loss(domain_latents))
+737        metrics.update(self.broadcast_loss(domain_latents))
+738
+739        loss = torch.stack(
+740            [
+741                metrics[name] * coef
+742                for name, coef in self.loss_coefs.items()
+743                if isinstance(coef, float) and coef > 0
+744            ],
+745            dim=0,
+746        ).mean()
+747
+748        metrics["broadcast_loss"] = torch.stack(
+749            [
+750                metrics[name]
+751                for name, coef in self.loss_coefs.items()
+752                if isinstance(coef, float) and coef > 0 and name != "contrastives"
+753            ],
+754            dim=0,
+755        ).mean()
+756
+757        return LossOutput(loss, metrics)
+
+ + +

Performs a step of loss computation.

+ +
Arguments:
+ +
    +
  • domain_latents: Latent representations for all domains.
  • +
  • mode: The mode in which the model is currently operating.
  • +
+ +
Returns:
+ +
+

A LossOutput object containing the loss and metrics for this step.

+
+
+ + +
+
+
Inherited Members
+
+
torch.nn.modules.module.Module
+
dump_patches
+
training
+
call_super_init
+
forward
+
register_buffer
+
register_parameter
+
add_module
+
register_module
+
get_submodule
+
get_parameter
+
get_buffer
+
get_extra_state
+
set_extra_state
+
apply
+
cuda
+
ipu
+
xpu
+
cpu
+
type
+
float
+
double
+
half
+
bfloat16
+
to_empty
+
to
+
register_full_backward_pre_hook
+
register_backward_hook
+
register_full_backward_hook
+
register_forward_pre_hook
+
register_forward_hook
+
register_state_dict_pre_hook
+
state_dict
+
register_load_state_dict_post_hook
+
load_state_dict
+
parameters
+
named_parameters
+
buffers
+
named_buffers
+
children
+
named_children
+
modules
+
named_modules
+
train
+
eval
+
requires_grad_
+
zero_grad
+
share_memory
+
extra_repr
+
compile
+ +
+
+
+
+
+ +
+ + class + GWLossesBayesian(GWLossesBase): + + + +
+ +
760class GWLossesBayesian(GWLossesBase):
+761    """
+762    Implementation of `GWLossesBase` used for `GWModuleBayesian`.
+763    """
+764
+765    def __init__(
+766        self,
+767        gw_mod: GWModuleBayesian,
+768        selection_mod: SelectionBase,
+769        domain_mods: dict[str, DomainModule],
+770        loss_coefs: BroadcastLossCoefs,
+771        contrastive_fn: ContrastiveLossType,
+772        use_normalized_constrastive: bool = True,
+773    ):
+774        """
+775        Loss module with uncertainty prediction to use with the GlobalWorkspaceBayesian
+776
+777        Args:
+778            gw_mod (`GWModuleBayesian`): the GWModule
+779            selection_mod (`SelectionBase`): selection module
+780            domain_mods (`dict[str, DomainModule]`): a dict where the key is the
+781                domain name and value is the DomainModule
+782            loss_coefs (`BroadcastLossCoefs`): loss coefficients
+783            contrastive_fn (`ContrastiveLossType`): the contrastive function
+784                to use in contrastive loss
+785            use_normalized_constrastive (`bool`): whether to use the normalized contrastive
+786                loss by the precision coefs
+787        """
+788        super().__init__()
+789
+790        self.gw_mod = gw_mod
+791        """The GWModule."""
+792
+793        self.selection_mod = selection_mod
+794        """Selection module"""
+795
+796        self.domain_mods = domain_mods
+797        """Domain modules linked to the GW."""
+798
+799        self.loss_coefs = loss_coefs
+800        """The loss coefficients."""
+801
+802        self.contrastive_fn = contrastive_fn
+803        """
+804        Contrastive loss to use.
+805        """
+806
+807        self.use_normalized_constrastive = use_normalized_constrastive
+808
+809    def contrastive_loss(
+810        self, latent_domains: LatentsDomainGroupsT
+811    ) -> dict[str, torch.Tensor]:
+812        """
+813        Contrastive loss.
+814
+815        Args:
+816            latent_domains (`LatentsDomainGroupsT`): the latent unimodal groups
+817
+818        Returns:
+819            `dict[str, torch.Tensor]`: a dict of metrics.
+820        """
+821        if self.use_normalized_constrastive:
+822            return contrastive_loss_bayesian(
+823                self.gw_mod, latent_domains, self.contrastive_fn
+824            )
+825        return contrastive_loss(self.gw_mod, latent_domains, self.contrastive_fn)
+826
+827    def broadcast_loss(
+828        self, latent_domains: LatentsDomainGroupsT
+829    ) -> dict[str, torch.Tensor]:
+830        return broadcast_loss(
+831            self.gw_mod, self.selection_mod, self.domain_mods, latent_domains
+832        )
+833
+834    def step(
+835        self, domain_latents: LatentsDomainGroupsT, mode: ModelModeT
+836    ) -> LossOutput:
+837        """
+838        Performs a step of loss computation.
+839
+840        Args:
+841            domain_latents: Latent representations for all domains.
+842            mode: The mode in which the model is currently operating.
+843
+844        Returns:
+845            A LossOutput object containing the loss and metrics for this step.
+846        """
+847
+848        metrics: dict[str, torch.Tensor] = {}
+849
+850        metrics.update(self.contrastive_loss(domain_latents))
+851        metrics.update(self.broadcast_loss(domain_latents))
+852
+853        loss = torch.stack(
+854            [
+855                metrics[name] * coef
+856                for name, coef in self.loss_coefs.items()
+857                if isinstance(coef, float) and coef > 0
+858            ],
+859            dim=0,
+860        ).mean()
+861
+862        metrics["broadcast_loss"] = torch.stack(
+863            [
+864                metrics[name]
+865                for name, coef in self.loss_coefs.items()
+866                if isinstance(coef, float) and coef > 0 and name != "contrastives"
+867            ],
+868            dim=0,
+869        ).mean()
+870
+871        return LossOutput(loss, metrics)
+
+ + +

Implementation of GWLossesBase used for GWModuleBayesian.

+
+ + +
+ +
+ + GWLossesBayesian( gw_mod: shimmer.modules.gw_module.GWModuleBayesian, selection_mod: shimmer.modules.selection.SelectionBase, domain_mods: dict[str, shimmer.modules.domain.DomainModule], loss_coefs: BroadcastLossCoefs, contrastive_fn: collections.abc.Callable[[torch.Tensor, torch.Tensor], shimmer.modules.domain.LossOutput], use_normalized_constrastive: bool = True) + + + +
+ +
765    def __init__(
+766        self,
+767        gw_mod: GWModuleBayesian,
+768        selection_mod: SelectionBase,
+769        domain_mods: dict[str, DomainModule],
+770        loss_coefs: BroadcastLossCoefs,
+771        contrastive_fn: ContrastiveLossType,
+772        use_normalized_constrastive: bool = True,
+773    ):
+774        """
+775        Loss module with uncertainty prediction to use with the GlobalWorkspaceBayesian
+776
+777        Args:
+778            gw_mod (`GWModuleBayesian`): the GWModule
+779            selection_mod (`SelectionBase`): selection module
+780            domain_mods (`dict[str, DomainModule]`): a dict where the key is the
+781                domain name and value is the DomainModule
+782            loss_coefs (`BroadcastLossCoefs`): loss coefficients
+783            contrastive_fn (`ContrastiveLossType`): the contrastive function
+784                to use in contrastive loss
+785        use_normalized_constrastive (`bool`): whether to use the normalized contrastive
+786                loss by the precision coefs
+787        """
+788        super().__init__()
+789
+790        self.gw_mod = gw_mod
+791        """The GWModule."""
+792
+793        self.selection_mod = selection_mod
+794        """Selection module"""
+795
+796        self.domain_mods = domain_mods
+797        """Domain modules linked to the GW."""
+798
+799        self.loss_coefs = loss_coefs
+800        """The loss coefficients."""
+801
+802        self.contrastive_fn = contrastive_fn
+803        """
+804        Contrastive loss to use.
+805        """
+806
+807        self.use_normalized_constrastive = use_normalized_constrastive
+
+ + +

Loss module with uncertainty prediction to use with the GlobalWorkspaceBayesian

+ +
Arguments:
+ +
    +
  • gw_mod (GWModuleBayesian): the GWModule
  • +
  • selection_mod (SelectionBase): selection module
  • +
  • domain_mods (dict[str, DomainModule]): a dict where the key is the +domain name and value is the DomainModule
  • +
  • loss_coefs (BroadcastLossCoefs): loss coefficients
  • +
  • contrastive_fn (ContrastiveLossType): the contrastive function +to use in contrastive loss
  • +
  • use_normalized_constrastive (bool): whether to use the normalized contrastive +loss by the precision coefs
  • +
+
+ + +
+
+
+ gw_mod + + +
+ + +

The GWModule.

+
+ + +
+
+
+ selection_mod + + +
+ + +

Selection module

+
+ + +
+
+
+ domain_mods + + +
+ + +

Domain modules linked to the GW.

+
+ + +
+
+
+ loss_coefs + + +
+ + +

The loss coefficients.

+
+ + +
+
+
+ contrastive_fn + + +
+ + +

Contrastive loss to use.

+
+ + +
+
+
+ use_normalized_constrastive + + +
+ + + + +
+
+ +
+ + def + contrastive_loss( self, latent_domains: collections.abc.Mapping[frozenset[str], collections.abc.Mapping[str, torch.Tensor]]) -> dict[str, torch.Tensor]: + + + +
+ +
809    def contrastive_loss(
+810        self, latent_domains: LatentsDomainGroupsT
+811    ) -> dict[str, torch.Tensor]:
+812        """
+813        Contrastive loss.
+814
+815        Args:
+816            latent_domains (`LatentsDomainGroupsT`): the latent unimodal groups
+817
+818        Returns:
+819            `dict[str, torch.Tensor]`: a dict of metrics.
+820        """
+821        if self.use_normalized_constrastive:
+822            return contrastive_loss_bayesian(
+823                self.gw_mod, latent_domains, self.contrastive_fn
+824            )
+825        return contrastive_loss(self.gw_mod, latent_domains, self.contrastive_fn)
+
+ + +

Contrastive loss.

+ +
Arguments:
+ +
    +
  • latent_domains (LatentsDomainGroupsT): the latent unimodal groups
  • +
+ +
Returns:
+ +
+

dict[str, torch.Tensor]: a dict of metrics.

+
+
+ + +
+
+ +
+ + def + broadcast_loss( self, latent_domains: collections.abc.Mapping[frozenset[str], collections.abc.Mapping[str, torch.Tensor]]) -> dict[str, torch.Tensor]: + + + +
+ +
827    def broadcast_loss(
+828        self, latent_domains: LatentsDomainGroupsT
+829    ) -> dict[str, torch.Tensor]:
+830        return broadcast_loss(
+831            self.gw_mod, self.selection_mod, self.domain_mods, latent_domains
+832        )
+
+ + + + +
+
+ +
+ + def + step( self, domain_latents: collections.abc.Mapping[frozenset[str], collections.abc.Mapping[str, torch.Tensor]], mode: Literal['train', 'val', 'test', 'val/ood', 'test/ood']) -> shimmer.modules.domain.LossOutput: + + + +
+ +
834    def step(
+835        self, domain_latents: LatentsDomainGroupsT, mode: ModelModeT
+836    ) -> LossOutput:
+837        """
+838        Performs a step of loss computation.
+839
+840        Args:
+841            domain_latents: Latent representations for all domains.
+842            mode: The mode in which the model is currently operating.
+843
+844        Returns:
+845            A LossOutput object containing the loss and metrics for this step.
+846        """
+847
+848        metrics: dict[str, torch.Tensor] = {}
+849
+850        metrics.update(self.contrastive_loss(domain_latents))
+851        metrics.update(self.broadcast_loss(domain_latents))
+852
+853        loss = torch.stack(
+854            [
+855                metrics[name] * coef
+856                for name, coef in self.loss_coefs.items()
+857                if isinstance(coef, float) and coef > 0
+858            ],
+859            dim=0,
+860        ).mean()
+861
+862        metrics["broadcast_loss"] = torch.stack(
+863            [
+864                metrics[name]
+865                for name, coef in self.loss_coefs.items()
+866                if isinstance(coef, float) and coef > 0 and name != "contrastives"
+867            ],
+868            dim=0,
+869        ).mean()
+870
+871        return LossOutput(loss, metrics)
+
+ + +

Performs a step of loss computation.

+ +
Arguments:
+ +
    +
  • domain_latents: Latent representations for all domains.
  • +
  • mode: The mode in which the model is currently operating.
  • +
+ +
Returns:
+ +
+

A LossOutput object containing the loss and metrics for this step.

+
+
+ + +
+
+
Inherited Members
+
+
torch.nn.modules.module.Module
+
dump_patches
+
training
+
call_super_init
+
forward
+
register_buffer
+
register_parameter
+
add_module
+
register_module
+
get_submodule
+
get_parameter
+
get_buffer
+
get_extra_state
+
set_extra_state
+
apply
+
cuda
+
ipu
+
xpu
+
cpu
+
type
+
float
+
double
+
half
+
bfloat16
+
to_empty
+
to
+
register_full_backward_pre_hook
+
register_backward_hook
+
register_full_backward_hook
+
register_forward_pre_hook
+
register_forward_hook
+
register_state_dict_pre_hook
+
state_dict
+
register_load_state_dict_post_hook
+
load_state_dict
+
parameters
+
named_parameters
+
buffers
+
named_buffers
+
children
+
named_children
+
modules
+
named_modules
+
train
+
eval
+
requires_grad_
+
zero_grad
+
share_memory
+
extra_repr
+
compile
+ +
+
+
+
+
+ + \ No newline at end of file diff --git a/docs/api/v0.5.1/shimmer/modules/selection.html b/docs/api/v0.5.1/shimmer/modules/selection.html new file mode 100644 index 00000000..6a4bb9c0 --- /dev/null +++ b/docs/api/v0.5.1/shimmer/modules/selection.html @@ -0,0 +1,2151 @@ + + + + + + + shimmer.modules.selection API documentation + + + + + + + + + + + + + +
+
+

+shimmer.modules.selection

+ + + + + + +
  1from abc import ABC, abstractmethod
+  2from collections.abc import Iterable
+  3
+  4import torch
+  5import torch.nn as nn
+  6
+  7from shimmer.types import LatentsDomainGroupT
+  8from shimmer.utils import group_batch_size, group_device
+  9
+ 10
+ 11class SelectionBase(torch.nn.Module, ABC):
+ 12    """
+ 13    This is the base class for the selection mechanism.
+ 14    The selection mechanisms handles the "competition" between modules and *selects*
+ 15    fusion coefficients for the domains.
+ 16    """
+ 17
+ 18    def update_gw_state(self, gw_state: torch.Tensor) -> None:
+ 19        """
+ 20        Update the internal copy of the previous GW state.
+ 21        By default, this is not implemented and will raise an error if used.
+ 22
+ 23        :note..
+ 24            This is not defined as an abstractmethod as some selection method may
+ 25            not need it.
+ 26
+ 27        Args:
+ 28            gw_state (`torch.Tensor`): the previous GW state
+ 29        """
+ 30        pass
+ 31
+ 32    @abstractmethod
+ 33    def forward(
+ 34        self, domains: LatentsDomainGroupT, encodings_pre_fusion: LatentsDomainGroupT
+ 35    ) -> dict[str, torch.Tensor]:
+ 36        """
+ 37        Forward pass of the selection method.
+ 38
+ 39        Args:
+ 40            domains (`LatentsDomainGroupT`): Group of unimodal latent representations.
+ 41
+ 42        Returns:
+ 43            `dict[str, torch.Tensor]`: for each domain in the group, the fusion
+ 44            coefficient for each item in the batch.
+ 45
+ 46        Example:
+ 47            >>> SomeSelectionImplementation().forward(
+ 48            ...     {"v": torch.randn(3, 4), "t": torch.randn(3, 8)}
+ 49            ... )
+ 50            {"v": torch.Tensor([0.0, 0.4, 1.0]), "t": torch.Tensor([1.0, 0.6, 0.0])}
+ 51        """
+ 52        ...
+ 53
+ 54    # This is just for proper auto-completion...
+ 55    def __call__(
+ 56        self, domains: LatentsDomainGroupT, encodings_pre_fusion: LatentsDomainGroupT
+ 57    ) -> dict[str, torch.Tensor]:
+ 58        return super().__call__(domains, encodings_pre_fusion)
+ 59
+ 60
+ 61class SingleDomainSelection(SelectionBase):
+ 62    """
+ 63    This selection mechanism handles groups that can have multiple domains, but always
+ 64    return a selection of 1 domain from the group with a uniform distribution.
+ 65
+ 66    For example, if the group has 2 domains, there is a 50% chance of selecting each
+ 67    domain.
+ 68    """
+ 69
+ 70    def forward(
+ 71        self, domains: LatentsDomainGroupT, encodings_pre_fusion: LatentsDomainGroupT
+ 72    ) -> dict[str, torch.Tensor]:
+ 73        """
+ 74        Forward pass of the module.
+ 75
+ 76        Args:
+ 77            domains (`LatentsDomainGroupT`): input unimodal latent representations
+ 78            gw_state (`torch.Tensor`): the previous GW state
+ 79
+ 80        Returns:
+ 81            `dict[str, torch.Tensor]`: whether the domain is selected for each input
+ 82            in the batch.
+ 83        """
+ 84        selection: dict[str, torch.Tensor] = {}
+ 85        bs = group_batch_size(domains)
+ 86        choice = torch.randint(len(domains), size=(bs,), device=group_device(domains))
+ 87        for k, domain in enumerate(domains.keys()):
+ 88            selection[domain] = (choice == k).to(torch.float32)
+ 89        return selection
+ 90
+ 91
+ 92class FixedSharedSelection(SelectionBase):
+ 93    """
+ 94    This selection mechanism is deterministic and always shares the weights equally
+ 95    between domains.
+ 96
+ 97    For example, if 2 domains, it gives 0.5 for each; 3 domains, 1/3 for each...
+ 98    """
+ 99
+100    def forward(
+101        self, domains: LatentsDomainGroupT, encodings_pre_fusion: LatentsDomainGroupT
+102    ) -> dict[str, torch.Tensor]:
+103        """
+104        Forward pass of the module.
+105
+106        Args:
+107            domains (`LatentsDomainGroupT`): input unimodal latent representations
+108            gw_state (`torch.Tensor`): the previous GW state
+109
+110        Returns:
+111            `dict[str, torch.Tensor]`: whether the domain is selected for each input
+112            in the batch.
+113        """
+114        selection: dict[str, torch.Tensor] = {}
+115        bs = group_batch_size(domains)
+116        coef = torch.full((bs,), 1.0 / len(domains), device=group_device(domains))
+117        for domain in domains:
+118            selection[domain] = coef.clone()
+119        return selection
+120
+121
+122def _calculate_attention_dict(
+123    domains: LatentsDomainGroupT,
+124    keys: dict[str, torch.Tensor],
+125    query: torch.Tensor,
+126) -> dict[str, torch.Tensor]:
+127    """
+128    Args:
+129        domains (`LatentsDomainGroupT`): Group of unimodal latent representations.
+130        keys (`dict[str, torch.Tensor]`): The keys for each domain.
+131        query (`torch.Tensor`): The query tensor.
+132
+133    Returns:
+134        `dict[str, torch.Tensor]`: The attention scores for each domain.
+135    """
+136    dot_products = {
+137        domain: torch.bmm(key.unsqueeze(1), query.unsqueeze(2)).squeeze()
+138        for domain, key in keys.items()
+139    }
+140
+141    dot_products_tensor = torch.stack(list(dot_products.values()), dim=1)
+142
+143    attention_scores = torch.softmax(dot_products_tensor, dim=1)
+144
+145    attention_dict = {
+146        domain: attention_scores[:, i] for i, domain in enumerate(domains)
+147    }
+148    return attention_dict
+149
+150
+151class KQFixedQSelection(SelectionBase):
+152    """
+153    Key-Query attention with a fixed gw vector.
+154    """
+155
+156    def __init__(self, head_size: int, domain_dim: int, domain_names: Iterable[str]):
+157        """
+158        Args:
+159            head_size (`int`) : dimension of the key and query vectors.
+160            domain_dim (`int`) : dimension of the input dims (assumed to be the same
+161                for now)
+162            domain_names  (`Iterable[str]`) : list of input domains
+163        """
+164        super().__init__()
+165        self.head_size = head_size
+166        self.query_layer = nn.Linear(domain_dim, head_size)
+167        self.key_layers = nn.ModuleDict(
+168            {domain: nn.Linear(domain_dim, head_size) for domain in domain_names}
+169        )
+170        # Start with a random gw state
+171        self.register_buffer("initial_gw_state", torch.rand(domain_dim))
+172
+173    def forward(
+174        self, domains: LatentsDomainGroupT, encodings_pre_fusion: LatentsDomainGroupT
+175    ) -> dict[str, torch.Tensor]:
+176        """
+177        Compute keys and queries, match them with dot product and softmax.
+178        Does this twice, once with the static query and once with a dynamic query.
+179
+180        Args:
+181            domains (`LatentsDomainGroupT`): Group of unimodal latent representations.
+182            encodings (`LatentsDomainGroupT`): Group of pre-fusion encodings.
+183
+184        Returns:
+185            `dict[str, torch.Tensor]`: the attention scores for each domain in the
+186            group.
+187        """
+188
+189        keys = {
+190            domain: self.key_layers[domain](encoding)
+191            for domain, encoding in domains.items()
+192        }
+193
+194        batch_size = group_batch_size(domains)
+195
+196        # Retrieve random query
+197        query = self.query_layer(self.initial_gw_state.expand(batch_size, -1))
+198
+199        # Calculate the attention scores
+200        return _calculate_attention_dict(domains, keys, query)
+201
+202
+203class RandomSelection(SelectionBase):
+204    """
+205    Modified random attention to only utilize uniform-softmax scores across modalities.
+206    This version omits the binary scaling factors and focuses on generating attention
+207    coefficients using a uniform distribution followed by a domain-wise softmax.
+208    """
+209
+210    def __init__(self, temperature: float):
+211        """
+212        Args:
+213            temperature (`float`): Temperature of the softmax applied to uniform
+214                scaling factors.
+215        """
+216        super().__init__()
+217        self.temperature = temperature
+218
+219    def forward(
+220        self, domains: LatentsDomainGroupT, encodings_pre_fusion: LatentsDomainGroupT
+221    ) -> dict[str, torch.Tensor]:
+222        """
+223        Generate uniform-then-domain-wise-softmaxed samples for each domain.
+224
+225        Args:
+226            domains (`LatentsDomainGroupT`): Group of unimodal latent representations.
+227                This is not used in the function directly but determines the structure
+228                of the returned attention coefficients.
+229
+230        Returns:
+231            `dict[str, torch.Tensor]`: For each domain in the group, the fusion
+232            coefficient for each item in the batch, based solely on
+233            uniform-softmax scores.
+234        """
+235        num_domains = len(domains)
+236        batch_size = group_batch_size(domains)
+237        device = group_device(domains)
+238
+239        # Generate uniform scores
+240        uniform_scores = torch.rand(batch_size, num_domains, device=device)
+241
+242        # Apply softmax across domains with temperature scaling
+243        softmax_scores = torch.softmax(uniform_scores / self.temperature, dim=1)
+244        # Create attention dictionary for each domain
+245        attention_dict = {
+246            domain: softmax_scores[:, i] for i, domain in enumerate(domains)
+247        }
+248
+249        return attention_dict
+250
+251
+252class DynamicQueryAttention(SelectionBase):
+253    """
+254    Key-Query attention with a dynamic gw vector.
+255    The query is updated based on the scaled gw vector.
+256    """
+257
+258    def __init__(self, head_size: int, domain_dim: int, domain_names: Iterable[str]):
+259        """
+260        Args:
+261            head_size (`int`) : dimension of the key and query vectors.
+262            domain_dim (`int`) : dimension of the input dims (assumed to be the same
+263                for now)
+264            domain_names  (`Iterable[str]`) : list of input domains
+265        """
+266        super().__init__()
+267        self.head_size = head_size
+268        self.query_layer = nn.Linear(domain_dim, head_size)
+269        self.key_layers = nn.ModuleDict(
+270            {domain: nn.Linear(domain_dim, head_size) for domain in domain_names}
+271        )
+272        # Start with a random gw state
+273        self.register_buffer("initial_gw_state", torch.rand(domain_dim))
+274
+275    def fuse_weighted_encodings(
+276        self, encodings: LatentsDomainGroupT, attention_dict: dict[str, torch.Tensor]
+277    ) -> torch.Tensor:
+278        """
+279        Fuse the weighted encodings using the attention scores.
+280
+281        Args:
+282            encodings (`LatentsDomainGroupT`): Unimodal latent representation
+283            attention_dict (`dict[str, torch.Tensor]`): The attention scores for each
+284                domain in the group.
+285
+286        Returns:
+287            `torch.Tensor`: The fused tensor.
+288        """
+289        # Apply attention scores to the encodings
+290        weighted_encodings = {}
+291        for key in attention_dict:
+292            if key in encodings:
+293                # Perform element-wise multiplication
+294                weighted_encodings[key] = (
+295                    attention_dict[key].unsqueeze(1) * encodings[key]
+296                )
+297
+298        # Stack the tensors along a new dimension (dimension 0)
+299        stacked_tensors = torch.stack(list(weighted_encodings.values()))
+300
+301        # Apply fusion by summing along the newly created dimension
+302        summed_tensor = torch.sum(stacked_tensors, dim=0)
+303        return summed_tensor
+304
+305    def forward(
+306        self, domains: LatentsDomainGroupT, encodings_pre_fusion: LatentsDomainGroupT
+307    ) -> dict[str, torch.Tensor]:
+308        """
+309        Compute keys and queries, match them with dot product and softmax.
+310        Does this twice, once with the static query and once with a dynamic query.
+311
+312        Args:
+313            domains (`LatentsDomainGroupT`): Group of unimodal latent representations.
+314            encodings (`LatentsDomainGroupT`): Group of pre-fusion encodings.
+315
+316        Returns:
+317            `dict[str, torch.Tensor]`: the attention scores for each domain in the
+318            group.
+319        """
+320
+321        keys = {
+322            domain: self.key_layers[domain](encoding)
+323            for domain, encoding in domains.items()
+324        }
+325
+326        batch_size = group_batch_size(domains)
+327
+328        # Retrieve random query
+329        query = self.query_layer(self.initial_gw_state.expand(batch_size, -1))
+330
+331        # Calculate the attention scores
+332        static_attention_dict = _calculate_attention_dict(domains, keys, query)
+333
+334        # Apply the attention scores to the encodings
+335        summed_tensor = self.fuse_weighted_encodings(
+336            encodings_pre_fusion, static_attention_dict
+337        )
+338
+339        # Retrieve query (now it is dependent on the new gw state)
+340        query = self.query_layer(summed_tensor)
+341
+342        # Calculate the attention scores again
+343        dynamic_attention_dict = _calculate_attention_dict(domains, keys, query)
+344
+345        return dynamic_attention_dict
+
+ + +
+
+ +
+ + class + SelectionBase(torch.nn.modules.module.Module, abc.ABC): + + + +
+ +
12class SelectionBase(torch.nn.Module, ABC):
+13    """
+14    This is the base class for the selection mechanism.
+15    The selection mechanisms handles the "competition" between modules and *selects*
+16    fusion coefficients for the domains.
+17    """
+18
+19    def update_gw_state(self, gw_state: torch.Tensor) -> None:
+20        """
+21        Update the internal copy of the previous GW state.
+22        By default, this is not implemented and will raise an error if used.
+23
+24        :note..
+25            This is not defined as an abstractmethod as some selection method may
+26            not need it.
+27
+28        Args:
+29            gw_state (`torch.Tensor`): the previous GW state
+30        """
+31        pass
+32
+33    @abstractmethod
+34    def forward(
+35        self, domains: LatentsDomainGroupT, encodings_pre_fusion: LatentsDomainGroupT
+36    ) -> dict[str, torch.Tensor]:
+37        """
+38        Forward pass of the selection method.
+39
+40        Args:
+41            domains (`LatentsDomainGroupT`): Group of unimodal latent representations.
+42
+43        Returns:
+44            `dict[str, torch.Tensor]`: for each domain in the group, the fusion
+45            coefficient for each item in the batch.
+46
+47        Example:
+48            >>> SomeSelectionImplementation().forward(
+49            ...     {"v": torch.randn(3, 4), "t": torch.randn(3, 8)}
+50            ... )
+51            {"v": torch.Tensor([0.0, 0.4, 1.0]), "t": torch.Tensor([1.0, 0.6, 0.0])}
+52        """
+53        ...
+54
+55    # This is just for proper auto-completion...
+56    def __call__(
+57        self, domains: LatentsDomainGroupT, encodings_pre_fusion: LatentsDomainGroupT
+58    ) -> dict[str, torch.Tensor]:
+59        return super().__call__(domains, encodings_pre_fusion)
+
+ + +

This is the base class for the selection mechanism. +The selection mechanisms handles the "competition" between modules and selects +fusion coefficients for the domains.

+
+ + +
+ +
+ + def + update_gw_state(self, gw_state: torch.Tensor) -> None: + + + +
+ +
19    def update_gw_state(self, gw_state: torch.Tensor) -> None:
+20        """
+21        Update the internal copy of the previous GW state.
+22        By default, this is not implemented and will raise an error if used.
+23
+24        :note..
+25            This is not defined as an abstractmethod as some selection method may
+26            not need it.
+27
+28        Args:
+29            gw_state (`torch.Tensor`): the previous GW state
+30        """
+31        pass
+
+ + +

Update the internal copy of the previous GW state. +By default, this is not implemented and will raise an error if used.

+ +

:note.. + This is not defined as an abstractmethod as some selection method may + not need it.

+ +
Arguments:
+ +
    +
  • gw_state (torch.Tensor): the previous GW state
  • +
+
+ + +
+
+ +
+
@abstractmethod
+ + def + forward( self, domains: collections.abc.Mapping[str, torch.Tensor], encodings_pre_fusion: collections.abc.Mapping[str, torch.Tensor]) -> dict[str, torch.Tensor]: + + + +
+ +
33    @abstractmethod
+34    def forward(
+35        self, domains: LatentsDomainGroupT, encodings_pre_fusion: LatentsDomainGroupT
+36    ) -> dict[str, torch.Tensor]:
+37        """
+38        Forward pass of the selection method.
+39
+40        Args:
+41            domains (`LatentsDomainGroupT`): Group of unimodal latent representations.
+42
+43        Returns:
+44            `dict[str, torch.Tensor]`: for each domain in the group, the fusion
+45            coefficient for each item in the batch.
+46
+47        Example:
+48            >>> SomeSelectionImplementation().forward(
+49            ...     {"v": torch.randn(3, 4), "t": torch.randn(3, 8)}
+50            ... )
+51            {"v": torch.Tensor([0.0, 0.4, 1.0]), "t": torch.Tensor([1.0, 0.6, 0.0])}
+52        """
+53        ...
+
+ + +

Forward pass of the selection method.

+ +
Arguments:
+ +
    +
  • domains (LatentsDomainGroupT): Group of unimodal latent representations.
  • +
+ +
Returns:
+ +
+

dict[str, torch.Tensor]: for each domain in the group, the fusion + coefficient for each item in the batch.

+
+ +
Example:
+ +
+
+
>>> SomeSelectionImplementation().forward(
+...     {"v": torch.randn(3, 4), "t": torch.randn(3, 8)}
+... )
+{"v": torch.Tensor([0.0, 0.4, 1.0]), "t": torch.Tensor([1.0, 0.6, 0.0])}
+
+
+
+
+ + +
+
+
Inherited Members
+
+
torch.nn.modules.module.Module
+
Module
+
dump_patches
+
training
+
call_super_init
+
register_buffer
+
register_parameter
+
add_module
+
register_module
+
get_submodule
+
get_parameter
+
get_buffer
+
get_extra_state
+
set_extra_state
+
apply
+
cuda
+
ipu
+
xpu
+
cpu
+
type
+
float
+
double
+
half
+
bfloat16
+
to_empty
+
to
+
register_full_backward_pre_hook
+
register_backward_hook
+
register_full_backward_hook
+
register_forward_pre_hook
+
register_forward_hook
+
register_state_dict_pre_hook
+
state_dict
+
register_load_state_dict_post_hook
+
load_state_dict
+
parameters
+
named_parameters
+
buffers
+
named_buffers
+
children
+
named_children
+
modules
+
named_modules
+
train
+
eval
+
requires_grad_
+
zero_grad
+
share_memory
+
extra_repr
+
compile
+ +
+
+
+
+
+ +
+ + class + SingleDomainSelection(SelectionBase): + + + +
+ +
62class SingleDomainSelection(SelectionBase):
+63    """
+64    This selection mechanism handles groups that can have multiple domains, but always
+65    return a selection of 1 domain from the group with a uniform distribution.
+66
+67    For example, if the group has 2 domains, there is a 50% chance of selecting each
+68    domain.
+69    """
+70
+71    def forward(
+72        self, domains: LatentsDomainGroupT, encodings_pre_fusion: LatentsDomainGroupT
+73    ) -> dict[str, torch.Tensor]:
+74        """
+75        Forward pass of the module.
+76
+77        Args:
+78            domains (`LatentsDomainGroupT`): input unimodal latent representations
+79            gw_state (`torch.Tensor`): the previous GW state
+80
+81        Returns:
+82            `dict[str, torch.Tensor]`: whether the domain is selected for each input
+83            in the batch.
+84        """
+85        selection: dict[str, torch.Tensor] = {}
+86        bs = group_batch_size(domains)
+87        choice = torch.randint(len(domains), size=(bs,), device=group_device(domains))
+88        for k, domain in enumerate(domains.keys()):
+89            selection[domain] = (choice == k).to(torch.float32)
+90        return selection
+
+ + +

This selection mechanism handles groups that can have multiple domains, but always +return a selection of 1 domain from the group with a uniform distribution.

+ +

For example, if the group has 2 domains, there is a 50% chance of selecting each +domain.

+
+ + +
+ +
+ + def + forward( self, domains: collections.abc.Mapping[str, torch.Tensor], encodings_pre_fusion: collections.abc.Mapping[str, torch.Tensor]) -> dict[str, torch.Tensor]: + + + +
+ +
71    def forward(
+72        self, domains: LatentsDomainGroupT, encodings_pre_fusion: LatentsDomainGroupT
+73    ) -> dict[str, torch.Tensor]:
+74        """
+75        Forward pass of the module.
+76
+77        Args:
+78            domains (`LatentsDomainGroupT`): input unimodal latent representations
+79            gw_state (`torch.Tensor`): the previous GW state
+80
+81        Returns:
+82            `dict[str, torch.Tensor]`: whether the domain is selected for each input
+83            in the batch.
+84        """
+85        selection: dict[str, torch.Tensor] = {}
+86        bs = group_batch_size(domains)
+87        choice = torch.randint(len(domains), size=(bs,), device=group_device(domains))
+88        for k, domain in enumerate(domains.keys()):
+89            selection[domain] = (choice == k).to(torch.float32)
+90        return selection
+
+ + +

Forward pass of the module.

+ +
Arguments:
+ +
    +
  • domains (LatentsDomainGroupT): input unimodal latent representations
  • +
  • gw_state (torch.Tensor): the previous GW state
  • +
+ +
Returns:
+ +
+

dict[str, torch.Tensor]: whether the domain is selected for each input + in the batch.

+
+
+ + +
+
+
Inherited Members
+
+
torch.nn.modules.module.Module
+
Module
+
dump_patches
+
training
+
call_super_init
+
register_buffer
+
register_parameter
+
add_module
+
register_module
+
get_submodule
+
get_parameter
+
get_buffer
+
get_extra_state
+
set_extra_state
+
apply
+
cuda
+
ipu
+
xpu
+
cpu
+
type
+
float
+
double
+
half
+
bfloat16
+
to_empty
+
to
+
register_full_backward_pre_hook
+
register_backward_hook
+
register_full_backward_hook
+
register_forward_pre_hook
+
register_forward_hook
+
register_state_dict_pre_hook
+
state_dict
+
register_load_state_dict_post_hook
+
load_state_dict
+
parameters
+
named_parameters
+
buffers
+
named_buffers
+
children
+
named_children
+
modules
+
named_modules
+
train
+
eval
+
requires_grad_
+
zero_grad
+
share_memory
+
extra_repr
+
compile
+ +
+ +
+
+
+
+ +
+ + class + FixedSharedSelection(SelectionBase): + + + +
+ +
 93class FixedSharedSelection(SelectionBase):
+ 94    """
+ 95    This selection mechanism is deterministic and always shares the weights equally
+ 96    between domains.
+ 97
+ 98    For example, if 2 domains, it gives 0.5 for each; 3 domains, 1/3 for each...
+ 99    """
+100
+101    def forward(
+102        self, domains: LatentsDomainGroupT, encodings_pre_fusion: LatentsDomainGroupT
+103    ) -> dict[str, torch.Tensor]:
+104        """
+105        Forward pass of the module.
+106
+107        Args:
+108            domains (`LatentsDomainGroupT`): input unimodal latent representations
+109            gw_state (`torch.Tensor`): the previous GW state
+110
+111        Returns:
+112            `dict[str, torch.Tensor]`: whether the domain is selected for each input
+113            in the batch.
+114        """
+115        selection: dict[str, torch.Tensor] = {}
+116        bs = group_batch_size(domains)
+117        coef = torch.full((bs,), 1.0 / len(domains), device=group_device(domains))
+118        for domain in domains:
+119            selection[domain] = coef.clone()
+120        return selection
+
+ + +

This selection mechanism is deterministic and always shares the weights equally +between domains.

+ +

For example, if 2 domains, it gives 0.5 for each; 3 domains, 1/3 for each...

+
+ + +
+ +
+ + def + forward( self, domains: collections.abc.Mapping[str, torch.Tensor], encodings_pre_fusion: collections.abc.Mapping[str, torch.Tensor]) -> dict[str, torch.Tensor]: + + + +
+ +
101    def forward(
+102        self, domains: LatentsDomainGroupT, encodings_pre_fusion: LatentsDomainGroupT
+103    ) -> dict[str, torch.Tensor]:
+104        """
+105        Forward pass of the module.
+106
+107        Args:
+108            domains (`LatentsDomainGroupT`): input unimodal latent representations
+109            gw_state (`torch.Tensor`): the previous GW state
+110
+111        Returns:
+112            `dict[str, torch.Tensor]`: whether the domain is selected for each input
+113            in the batch.
+114        """
+115        selection: dict[str, torch.Tensor] = {}
+116        bs = group_batch_size(domains)
+117        coef = torch.full((bs,), 1.0 / len(domains), device=group_device(domains))
+118        for domain in domains:
+119            selection[domain] = coef.clone()
+120        return selection
+
+ + +

Forward pass of the module.

+ +
Arguments:
+ +
    +
  • domains (LatentsDomainGroupT): input unimodal latent representations
  • +
  • gw_state (torch.Tensor): the previous GW state
  • +
+ +
Returns:
+ +
+

dict[str, torch.Tensor]: whether the domain is selected for each input + in the batch.

+
+
+ + +
+
+
Inherited Members
+
+
torch.nn.modules.module.Module
+
Module
+
dump_patches
+
training
+
call_super_init
+
register_buffer
+
register_parameter
+
add_module
+
register_module
+
get_submodule
+
get_parameter
+
get_buffer
+
get_extra_state
+
set_extra_state
+
apply
+
cuda
+
ipu
+
xpu
+
cpu
+
type
+
float
+
double
+
half
+
bfloat16
+
to_empty
+
to
+
register_full_backward_pre_hook
+
register_backward_hook
+
register_full_backward_hook
+
register_forward_pre_hook
+
register_forward_hook
+
register_state_dict_pre_hook
+
state_dict
+
register_load_state_dict_post_hook
+
load_state_dict
+
parameters
+
named_parameters
+
buffers
+
named_buffers
+
children
+
named_children
+
modules
+
named_modules
+
train
+
eval
+
requires_grad_
+
zero_grad
+
share_memory
+
extra_repr
+
compile
+ +
+ +
+
+
+
+ +
+ + class + KQFixedQSelection(SelectionBase): + + + +
+ +
152class KQFixedQSelection(SelectionBase):
+153    """
+154    Key-Query attention with a fixed gw vector.
+155    """
+156
+157    def __init__(self, head_size: int, domain_dim: int, domain_names: Iterable[str]):
+158        """
+159        Args:
+160            head_size (`int`) : dimension of the key and query vectors.
+161            domain_dim (`int`) : dimension of the input dims (assumed to be the same
+162                for now)
+163            domain_names  (`Iterable[str]`) : list of input domains
+164        """
+165        super().__init__()
+166        self.head_size = head_size
+167        self.query_layer = nn.Linear(domain_dim, head_size)
+168        self.key_layers = nn.ModuleDict(
+169            {domain: nn.Linear(domain_dim, head_size) for domain in domain_names}
+170        )
+171        # Start with a random gw state
+172        self.register_buffer("initial_gw_state", torch.rand(domain_dim))
+173
+174    def forward(
+175        self, domains: LatentsDomainGroupT, encodings_pre_fusion: LatentsDomainGroupT
+176    ) -> dict[str, torch.Tensor]:
+177        """
+178        Compute keys and queries, match them with dot product and softmax.
+179        Does this twice, once with the static query and once with a dynamic query.
+180
+181        Args:
+182            domains (`LatentsDomainGroupT`): Group of unimodal latent representations.
+183            encodings (`LatentsDomainGroupT`): Group of pre-fusion encodings.
+184
+185        Returns:
+186            `dict[str, torch.Tensor]`: the attention scores for each domain in the
+187            group.
+188        """
+189
+190        keys = {
+191            domain: self.key_layers[domain](encoding)
+192            for domain, encoding in domains.items()
+193        }
+194
+195        batch_size = group_batch_size(domains)
+196
+197        # Retrieve random query
+198        query = self.query_layer(self.initial_gw_state.expand(batch_size, -1))
+199
+200        # Calculate the attention scores
+201        return _calculate_attention_dict(domains, keys, query)
+
+ + +

Key-Query attention with a fixed gw vector.

+
+ + +
+ +
+ + KQFixedQSelection( head_size: int, domain_dim: int, domain_names: collections.abc.Iterable[str]) + + + +
+ +
157    def __init__(self, head_size: int, domain_dim: int, domain_names: Iterable[str]):
+158        """
+159        Args:
+160            head_size (`int`) : dimension of the key and query vectors.
+161            domain_dim (`int`) : dimension of the input dims (assumed to be the same
+162                for now)
+163            domain_names  (`Iterable[str]`) : list of input domains
+164        """
+165        super().__init__()
+166        self.head_size = head_size
+167        self.query_layer = nn.Linear(domain_dim, head_size)
+168        self.key_layers = nn.ModuleDict(
+169            {domain: nn.Linear(domain_dim, head_size) for domain in domain_names}
+170        )
+171        # Start with a random gw state
+172        self.register_buffer("initial_gw_state", torch.rand(domain_dim))
+
+ + +
Arguments:
+ +
    +
  • head_size (int) : dimension of the key and query vectors.
  • +
  • domain_dim (int) : dimension of the input dims (assumed to be the same +for now)
  • +
  • domain_names (Iterable[str]) : list of input domains
  • +
+
+ + +
+
+
+ head_size + + +
+ + + + +
+
+
+ query_layer + + +
+ + + + +
+
+
+ key_layers + + +
+ + + + +
+
+ +
+ + def + forward( self, domains: collections.abc.Mapping[str, torch.Tensor], encodings_pre_fusion: collections.abc.Mapping[str, torch.Tensor]) -> dict[str, torch.Tensor]: + + + +
+ +
174    def forward(
+175        self, domains: LatentsDomainGroupT, encodings_pre_fusion: LatentsDomainGroupT
+176    ) -> dict[str, torch.Tensor]:
+177        """
+178        Compute keys and queries, match them with dot product and softmax.
+179        Does this twice, once with the static query and once with a dynamic query.
+180
+181        Args:
+182            domains (`LatentsDomainGroupT`): Group of unimodal latent representations.
+183            encodings (`LatentsDomainGroupT`): Group of pre-fusion encodings.
+184
+185        Returns:
+186            `dict[str, torch.Tensor]`: the attention scores for each domain in the
+187            group.
+188        """
+189
+190        keys = {
+191            domain: self.key_layers[domain](encoding)
+192            for domain, encoding in domains.items()
+193        }
+194
+195        batch_size = group_batch_size(domains)
+196
+197        # Retrieve random query
+198        query = self.query_layer(self.initial_gw_state.expand(batch_size, -1))
+199
+200        # Calculate the attention scores
+201        return _calculate_attention_dict(domains, keys, query)
+
+ + +

Compute keys and queries, match them with dot product and softmax. +Does this twice, once with the static query and once with a dynamic query.

+ +
Arguments:
+ +
    +
  • domains (LatentsDomainGroupT): Group of unimodal latent representations.
  • +
  • encodings (LatentsDomainGroupT): Group of pre-fusion encodings.
  • +
+ +
Returns:
+ +
+

dict[str, torch.Tensor]: the attention scores for each domain in the + group.

+
+
+ + +
+
+
Inherited Members
+
+ +
torch.nn.modules.module.Module
+
dump_patches
+
training
+
call_super_init
+
register_buffer
+
register_parameter
+
add_module
+
register_module
+
get_submodule
+
get_parameter
+
get_buffer
+
get_extra_state
+
set_extra_state
+
apply
+
cuda
+
ipu
+
xpu
+
cpu
+
type
+
float
+
double
+
half
+
bfloat16
+
to_empty
+
to
+
register_full_backward_pre_hook
+
register_backward_hook
+
register_full_backward_hook
+
register_forward_pre_hook
+
register_forward_hook
+
register_state_dict_pre_hook
+
state_dict
+
register_load_state_dict_post_hook
+
load_state_dict
+
parameters
+
named_parameters
+
buffers
+
named_buffers
+
children
+
named_children
+
modules
+
named_modules
+
train
+
eval
+
requires_grad_
+
zero_grad
+
share_memory
+
extra_repr
+
compile
+ +
+
+
+
+
+ +
+ + class + RandomSelection(SelectionBase): + + + +
+ +
204class RandomSelection(SelectionBase):
+205    """
+206    Modified random attention to only utilize uniform-softmax scores across modalities.
+207    This version omits the binary scaling factors and focuses on generating attention
+208    coefficients using a uniform distribution followed by a domain-wise softmax.
+209    """
+210
+211    def __init__(self, temperature: float):
+212        """
+213        Args:
+214            temperature (`float`): Temperature of the softmax applied to uniform
+215                scaling factors.
+216        """
+217        super().__init__()
+218        self.temperature = temperature
+219
+220    def forward(
+221        self, domains: LatentsDomainGroupT, encodings_pre_fusion: LatentsDomainGroupT
+222    ) -> dict[str, torch.Tensor]:
+223        """
+224        Generate uniform-then-domain-wise-softmaxed samples for each domain.
+225
+226        Args:
+227            domains (`LatentsDomainGroupT`): Group of unimodal latent representations.
+228                This is not used in the function directly but determines the structure
+229                of the returned attention coefficients.
+230
+231        Returns:
+232            `dict[str, torch.Tensor]`: For each domain in the group, the fusion
+233            coefficient for each item in the batch, based solely on
+234            uniform-softmax scores.
+235        """
+236        num_domains = len(domains)
+237        batch_size = group_batch_size(domains)
+238        device = group_device(domains)
+239
+240        # Generate uniform scores
+241        uniform_scores = torch.rand(batch_size, num_domains, device=device)
+242
+243        # Apply softmax across domains with temperature scaling
+244        softmax_scores = torch.softmax(uniform_scores / self.temperature, dim=1)
+245        # Create attention dictionary for each domain
+246        attention_dict = {
+247            domain: softmax_scores[:, i] for i, domain in enumerate(domains)
+248        }
+249
+250        return attention_dict
+
+ + +

Modified random attention to only utilize uniform-softmax scores across modalities. +This version omits the binary scaling factors and focuses on generating attention +coefficients using a uniform distribution followed by a domain-wise softmax.

+
+ + +
+ +
+ + RandomSelection(temperature: float) + + + +
+ +
211    def __init__(self, temperature: float):
+212        """
+213        Args:
+214            temperature (`float`): Temperature of the softmax applied to uniform
+215                scaling factors.
+216        """
+217        super().__init__()
+218        self.temperature = temperature
+
+ + +
Arguments:
+ +
    +
  • temperature (float): Temperature of the softmax applied to uniform +scaling factors.
  • +
+
+ + +
+
+
+ temperature + + +
+ + + + +
+
+ +
+ + def + forward( self, domains: collections.abc.Mapping[str, torch.Tensor], encodings_pre_fusion: collections.abc.Mapping[str, torch.Tensor]) -> dict[str, torch.Tensor]: + + + +
+ +
220    def forward(
+221        self, domains: LatentsDomainGroupT, encodings_pre_fusion: LatentsDomainGroupT
+222    ) -> dict[str, torch.Tensor]:
+223        """
+224        Generate uniform-then-domain-wise-softmaxed samples for each domain.
+225
+226        Args:
+227            domains (`LatentsDomainGroupT`): Group of unimodal latent representations.
+228                This is not used in the function directly but determines the structure
+229                of the returned attention coefficients.
+230
+231        Returns:
+232            `dict[str, torch.Tensor]`: For each domain in the group, the fusion
+233            coefficient for each item in the batch, based solely on
+234            uniform-softmax scores.
+235        """
+236        num_domains = len(domains)
+237        batch_size = group_batch_size(domains)
+238        device = group_device(domains)
+239
+240        # Generate uniform scores
+241        uniform_scores = torch.rand(batch_size, num_domains, device=device)
+242
+243        # Apply softmax across domains with temperature scaling
+244        softmax_scores = torch.softmax(uniform_scores / self.temperature, dim=1)
+245        # Create attention dictionary for each domain
+246        attention_dict = {
+247            domain: softmax_scores[:, i] for i, domain in enumerate(domains)
+248        }
+249
+250        return attention_dict
+
+ + +

Generate uniform-then-domain-wise-softmaxed samples for each domain.

+ +
Arguments:
+ +
    +
  • domains (LatentsDomainGroupT): Group of unimodal latent representations. +This is not used in the function directly but determines the structure +of the returned attention coefficients.
  • +
+ +
Returns:
+ +
+

dict[str, torch.Tensor]: For each domain in the group, the fusion + coefficient for each item in the batch, based solely on + uniform-softmax scores.

+
+
+ + +
+
+
Inherited Members
+
+ +
torch.nn.modules.module.Module
+
dump_patches
+
training
+
call_super_init
+
register_buffer
+
register_parameter
+
add_module
+
register_module
+
get_submodule
+
get_parameter
+
get_buffer
+
get_extra_state
+
set_extra_state
+
apply
+
cuda
+
ipu
+
xpu
+
cpu
+
type
+
float
+
double
+
half
+
bfloat16
+
to_empty
+
to
+
register_full_backward_pre_hook
+
register_backward_hook
+
register_full_backward_hook
+
register_forward_pre_hook
+
register_forward_hook
+
register_state_dict_pre_hook
+
state_dict
+
register_load_state_dict_post_hook
+
load_state_dict
+
parameters
+
named_parameters
+
buffers
+
named_buffers
+
children
+
named_children
+
modules
+
named_modules
+
train
+
eval
+
requires_grad_
+
zero_grad
+
share_memory
+
extra_repr
+
compile
+ +
+
+
+
+
+ +
+ + class + DynamicQueryAttention(SelectionBase): + + + +
+ +
253class DynamicQueryAttention(SelectionBase):
+254    """
+255    Key-Query attention with a dynamic gw vector.
+256    The query is updated based on the scaled gw vector.
+257    """
+258
+259    def __init__(self, head_size: int, domain_dim: int, domain_names: Iterable[str]):
+260        """
+261        Args:
+262            head_size (`int`) : dimension of the key and query vectors.
+263            domain_dim (`int`) : dimension of the input dims (assumed to be the same
+264                for now)
+265            domain_names  (`Iterable[str]`) : list of input domains
+266        """
+267        super().__init__()
+268        self.head_size = head_size
+269        self.query_layer = nn.Linear(domain_dim, head_size)
+270        self.key_layers = nn.ModuleDict(
+271            {domain: nn.Linear(domain_dim, head_size) for domain in domain_names}
+272        )
+273        # Start with a random gw state
+274        self.register_buffer("initial_gw_state", torch.rand(domain_dim))
+275
+276    def fuse_weighted_encodings(
+277        self, encodings: LatentsDomainGroupT, attention_dict: dict[str, torch.Tensor]
+278    ) -> torch.Tensor:
+279        """
+280        Fuse the weighted encodings using the attention scores.
+281
+282        Args:
+283            encodings (`LatentsDomainGroupT`): Unimodal latent representation
+284            attention_dict (`dict[str, torch.Tensor]`): The attention scores for each
+285                domain in the group.
+286
+287        Returns:
+288            `torch.Tensor`: The fused tensor.
+289        """
+290        # Apply attention scores to the encodings
+291        weighted_encodings = {}
+292        for key in attention_dict:
+293            if key in encodings:
+294                # Perform element-wise multiplication
+295                weighted_encodings[key] = (
+296                    attention_dict[key].unsqueeze(1) * encodings[key]
+297                )
+298
+299        # Stack the tensors along a new dimension (dimension 0)
+300        stacked_tensors = torch.stack(list(weighted_encodings.values()))
+301
+302        # Apply fusion by summing along the newly created dimension
+303        summed_tensor = torch.sum(stacked_tensors, dim=0)
+304        return summed_tensor
+305
+306    def forward(
+307        self, domains: LatentsDomainGroupT, encodings_pre_fusion: LatentsDomainGroupT
+308    ) -> dict[str, torch.Tensor]:
+309        """
+310        Compute keys and queries, match them with dot product and softmax.
+311        Does this twice, once with the static query and once with a dynamic query.
+312
+313        Args:
+314            domains (`LatentsDomainGroupT`): Group of unimodal latent representations.
+315            encodings (`LatentsDomainGroupT`): Group of pre-fusion encodings.
+316
+317        Returns:
+318            `dict[str, torch.Tensor]`: the attention scores for each domain in the
+319            group.
+320        """
+321
+322        keys = {
+323            domain: self.key_layers[domain](encoding)
+324            for domain, encoding in domains.items()
+325        }
+326
+327        batch_size = group_batch_size(domains)
+328
+329        # Retrieve random query
+330        query = self.query_layer(self.initial_gw_state.expand(batch_size, -1))
+331
+332        # Calculate the attention scores
+333        static_attention_dict = _calculate_attention_dict(domains, keys, query)
+334
+335        # Apply the attention scores to the encodings
+336        summed_tensor = self.fuse_weighted_encodings(
+337            encodings_pre_fusion, static_attention_dict
+338        )
+339
+340        # Retrieve query (now it is dependent on the new gw state)
+341        query = self.query_layer(summed_tensor)
+342
+343        # Calculate the attention scores again
+344        dynamic_attention_dict = _calculate_attention_dict(domains, keys, query)
+345
+346        return dynamic_attention_dict
+
+ + +

Key-Query attention with a dynamic gw vector. +The query is updated based on the scaled gw vector.

+
+ + +
+ +
+ + DynamicQueryAttention( head_size: int, domain_dim: int, domain_names: collections.abc.Iterable[str]) + + + +
+ +
259    def __init__(self, head_size: int, domain_dim: int, domain_names: Iterable[str]):
+260        """
+261        Args:
+262            head_size (`int`) : dimension of the key and query vectors.
+263            domain_dim (`int`) : dimension of the input dims (assumed to be the same
+264                for now)
+265            domain_names  (`Iterable[str]`) : list of input domains
+266        """
+267        super().__init__()
+268        self.head_size = head_size
+269        self.query_layer = nn.Linear(domain_dim, head_size)
+270        self.key_layers = nn.ModuleDict(
+271            {domain: nn.Linear(domain_dim, head_size) for domain in domain_names}
+272        )
+273        # Start with a random gw state
+274        self.register_buffer("initial_gw_state", torch.rand(domain_dim))
+
+ + +
Arguments:
+ +
    +
  • head_size (int) : dimension of the key and query vectors.
  • +
  • domain_dim (int) : dimension of the input dims (assumed to be the same +for now)
  • +
  • domain_names (Iterable[str]) : list of input domains
  • +
+
+ + +
+
+
+ head_size + + +
+ + + + +
+
+
+ query_layer + + +
+ + + + +
+
+
+ key_layers + + +
+ + + + +
+
+ +
+ + def + fuse_weighted_encodings( self, encodings: collections.abc.Mapping[str, torch.Tensor], attention_dict: dict[str, torch.Tensor]) -> torch.Tensor: + + + +
+ +
276    def fuse_weighted_encodings(
+277        self, encodings: LatentsDomainGroupT, attention_dict: dict[str, torch.Tensor]
+278    ) -> torch.Tensor:
+279        """
+280        Fuse the weighted encodings using the attention scores.
+281
+282        Args:
+283            encodings (`LatentsDomainGroupT`): Unimodal latent representation
+284            attention_dict (`dict[str, torch.Tensor]`): The attention scores for each
+285                domain in the group.
+286
+287        Returns:
+288            `torch.Tensor`: The fused tensor.
+289        """
+290        # Apply attention scores to the encodings
+291        weighted_encodings = {}
+292        for key in attention_dict:
+293            if key in encodings:
+294                # Perform element-wise multiplication
+295                weighted_encodings[key] = (
+296                    attention_dict[key].unsqueeze(1) * encodings[key]
+297                )
+298
+299        # Stack the tensors along a new dimension (dimension 0)
+300        stacked_tensors = torch.stack(list(weighted_encodings.values()))
+301
+302        # Apply fusion by summing along the newly created dimension
+303        summed_tensor = torch.sum(stacked_tensors, dim=0)
+304        return summed_tensor
+
+ + +

Fuse the weighted encodings using the attention scores.

+ +
Arguments:
+ +
    +
  • encodings (LatentsDomainGroupT): Unimodal latent representation
  • +
  • attention_dict (dict[str, torch.Tensor]): The attention scores for each +domain in the group.
  • +
+ +
Returns:
+ +
+

torch.Tensor: The fused tensor.

+
+
+ + +
+
+ +
+ + def + forward( self, domains: collections.abc.Mapping[str, torch.Tensor], encodings_pre_fusion: collections.abc.Mapping[str, torch.Tensor]) -> dict[str, torch.Tensor]: + + + +
+ +
306    def forward(
+307        self, domains: LatentsDomainGroupT, encodings_pre_fusion: LatentsDomainGroupT
+308    ) -> dict[str, torch.Tensor]:
+309        """
+310        Compute keys and queries, match them with dot product and softmax.
+311        Does this twice, once with the static query and once with a dynamic query.
+312
+313        Args:
+314            domains (`LatentsDomainGroupT`): Group of unimodal latent representations.
+315            encodings (`LatentsDomainGroupT`): Group of pre-fusion encodings.
+316
+317        Returns:
+318            `dict[str, torch.Tensor]`: the attention scores for each domain in the
+319            group.
+320        """
+321
+322        keys = {
+323            domain: self.key_layers[domain](encoding)
+324            for domain, encoding in domains.items()
+325        }
+326
+327        batch_size = group_batch_size(domains)
+328
+329        # Retrieve random query
+330        query = self.query_layer(self.initial_gw_state.expand(batch_size, -1))
+331
+332        # Calculate the attention scores
+333        static_attention_dict = _calculate_attention_dict(domains, keys, query)
+334
+335        # Apply the attention scores to the encodings
+336        summed_tensor = self.fuse_weighted_encodings(
+337            encodings_pre_fusion, static_attention_dict
+338        )
+339
+340        # Retrieve query (now it is dependent on the new gw state)
+341        query = self.query_layer(summed_tensor)
+342
+343        # Calculate the attention scores again
+344        dynamic_attention_dict = _calculate_attention_dict(domains, keys, query)
+345
+346        return dynamic_attention_dict
+
+ + +

Compute keys and queries, match them with dot product and softmax. +Does this twice, once with the static query and once with a dynamic query.

+ +
Arguments:
+ +
    +
  • domains (LatentsDomainGroupT): Group of unimodal latent representations.
  • +
  • encodings (LatentsDomainGroupT): Group of pre-fusion encodings.
  • +
+ +
Returns:
+ +
+

dict[str, torch.Tensor]: the attention scores for each domain in the + group.

+
+
+ + +
+
+
Inherited Members
+
+ +
torch.nn.modules.module.Module
+
dump_patches
+
training
+
call_super_init
+
register_buffer
+
register_parameter
+
add_module
+
register_module
+
get_submodule
+
get_parameter
+
get_buffer
+
get_extra_state
+
set_extra_state
+
apply
+
cuda
+
ipu
+
xpu
+
cpu
+
type
+
float
+
double
+
half
+
bfloat16
+
to_empty
+
to
+
register_full_backward_pre_hook
+
register_backward_hook
+
register_full_backward_hook
+
register_forward_pre_hook
+
register_forward_hook
+
register_state_dict_pre_hook
+
state_dict
+
register_load_state_dict_post_hook
+
load_state_dict
+
parameters
+
named_parameters
+
buffers
+
named_buffers
+
children
+
named_children
+
modules
+
named_modules
+
train
+
eval
+
requires_grad_
+
zero_grad
+
share_memory
+
extra_repr
+
compile
+ +
+
+
+
+
+ + \ No newline at end of file diff --git a/docs/api/v0.5.1/shimmer/modules/utils.html b/docs/api/v0.5.1/shimmer/modules/utils.html new file mode 100644 index 00000000..790987da --- /dev/null +++ b/docs/api/v0.5.1/shimmer/modules/utils.html @@ -0,0 +1,760 @@ + + + + + + + shimmer.modules.utils API documentation + + + + + + + + + + + + + +
+
+

+shimmer.modules.utils

+ + + + + + +
  1from collections.abc import Iterable
+  2
+  3import torch
+  4
+  5from shimmer.modules.gw_module import GWModuleBase
+  6from shimmer.modules.selection import SelectionBase
+  7from shimmer.types import (
+  8    LatentsDomainGroupDT,
+  9    LatentsDomainGroupsT,
+ 10    LatentsDomainGroupT,
+ 11)
+ 12
+ 13
+ 14def translation(
+ 15    gw_module: GWModuleBase,
+ 16    selection_mod: SelectionBase,
+ 17    x: LatentsDomainGroupT,
+ 18    to: str,
+ 19) -> torch.Tensor:
+ 20    """
+ 21    Translate from multiple domains to one domain.
+ 22
+ 23    Args:
+ 24        gw_module (`GWModuleBase`): GWModule to perform the translation over
+ 25        selection_mod (`SelectionBase`): selection module
+ 26        x (`LatentsDomainGroupT`): the group of latent representations
+ 27        to (`str`): the domain name to encode to
+ 28
+ 29    Returns:
+ 30        `torch.Tensor`: the translated unimodal representation
+ 31            of the provided domain.
+ 32    """
+ 33    return gw_module.decode(gw_module.encode_and_fuse(x, selection_mod), domains={to})[
+ 34        to
+ 35    ]
+ 36
+ 37
+ 38def cycle(
+ 39    gw_module: GWModuleBase,
+ 40    selection_mod: SelectionBase,
+ 41    x: LatentsDomainGroupT,
+ 42    through: str,
+ 43) -> LatentsDomainGroupDT:
+ 44    """
+ 45    Do a full cycle from a group of representation through one domain.
+ 46
+ 47    [Original domains] -> [GW] -> [through] -> [GW] -> [Original domains]
+ 48
+ 49    Args:
+ 50        gw_module (`GWModuleBase`): GWModule to perform the translation over
+ 51        selection_mod (`SelectionBase`): selection module
+ 52        x (`LatentsDomainGroupT`): group of unimodal latent representation
+ 53        through (`str`): domain name to cycle through
+ 54    Returns:
+ 55        `LatentsDomainGroupDT`: group of unimodal latent representation after
+ 56            cycling.
+ 57    """
+ 58    return {
+ 59        domain: translation(
+ 60            gw_module,
+ 61            selection_mod,
+ 62            {through: translation(gw_module, selection_mod, x, through)},
+ 63            domain,
+ 64        )
+ 65        for domain in x
+ 66    }
+ 67
+ 68
+ 69def batch_demi_cycles(
+ 70    gw_mod: GWModuleBase,
+ 71    selection_mod: SelectionBase,
+ 72    latent_domains: LatentsDomainGroupsT,
+ 73) -> dict[str, torch.Tensor]:
+ 74    """
+ 75    Computes demi-cycles of a batch of groups of domains.
+ 76
+ 77    Args:
+ 78        gw_mod (`GWModuleBase`): the GWModuleBase
+ 79        selection_mod (`SelectionBase`): selection module
+ 80        latent_domains (`LatentsT`): the batch of groups of domains
+ 81
+ 82    Returns:
+ 83        `dict[str, torch.Tensor]`: demi-cycles predictions for each domain.
+ 84    """
+ 85    predictions: dict[str, torch.Tensor] = {}
+ 86    for domains, latents in latent_domains.items():
+ 87        if len(domains) > 1:
+ 88            continue
+ 89        domain_name = list(domains)[0]
+ 90        z = translation(gw_mod, selection_mod, latents, to=domain_name)
+ 91        predictions[domain_name] = z
+ 92    return predictions
+ 93
+ 94
+ 95def batch_cycles(
+ 96    gw_mod: GWModuleBase,
+ 97    selection_mod: SelectionBase,
+ 98    latent_domains: LatentsDomainGroupsT,
+ 99    through_domains: Iterable[str],
+100) -> dict[tuple[str, str], torch.Tensor]:
+101    """
+102    Computes cycles of a batch of groups of domains.
+103
+104    Args:
+105        gw_mod (`GWModuleBase`): GWModule to use for the cycle
+106        selection_mod (`SelectionBase`): selection module
+107        latent_domains (`LatentsT`): the batch of groups of domains
+108        out_domains (`Iterable[str]`): iterable of domain names to do the cycle through.
+109            Each domain will be done separetely.
+110
+111    Returns:
+112        `dict[tuple[str, str], torch.Tensor]`: cycles predictions for each
+113            couple of (start domain, intermediary domain).
+114    """
+115    predictions: dict[tuple[str, str], torch.Tensor] = {}
+116    for domains_source, latents_source in latent_domains.items():
+117        if len(domains_source) > 1:
+118            continue
+119        domain_name_source = next(iter(domains_source))
+120        for domain_name_through in through_domains:
+121            if domain_name_source == domain_name_through:
+122                continue
+123            z = cycle(
+124                gw_mod, selection_mod, latents_source, through=domain_name_through
+125            )
+126            domains = (domain_name_source, domain_name_through)
+127            predictions[domains] = z[domain_name_source]
+128    return predictions
+129
+130
+131def batch_translations(
+132    gw_mod: GWModuleBase,
+133    selection_mod: SelectionBase,
+134    latent_domains: LatentsDomainGroupsT,
+135) -> dict[tuple[str, str], torch.Tensor]:
+136    """
+137    Computes translations of a batch of groups of domains.
+138
+139    Args:
+140        gw_mod (`GWModuleBase`): GWModule to do the translation
+141        selection_mod (`SelectionBase`): selection module
+142        latent_domains (`LatentsT`): the batch of groups of domains
+143
+144    Returns:
+145        `dict[tuple[str, str], torch.Tensor]`: translation predictions for each
+146            couple of (start domain, target domain).
+147    """
+148    predictions: dict[tuple[str, str], torch.Tensor] = {}
+149    for domains, latents in latent_domains.items():
+150        if len(domains) < 2:
+151            continue
+152        for domain_name_source in domains:
+153            for domain_name_target in domains:
+154                if domain_name_source == domain_name_target:
+155                    continue
+156                prediction = translation(
+157                    gw_mod,
+158                    selection_mod,
+159                    {domain_name_source: latents[domain_name_source]},
+160                    to=domain_name_target,
+161                )
+162                predictions[(domain_name_source, domain_name_target)] = prediction
+163    return predictions
+
+ + +
+
+ +
+ + def + translation( gw_module: shimmer.modules.gw_module.GWModuleBase, selection_mod: shimmer.modules.selection.SelectionBase, x: collections.abc.Mapping[str, torch.Tensor], to: str) -> torch.Tensor: + + + +
+ +
15def translation(
+16    gw_module: GWModuleBase,
+17    selection_mod: SelectionBase,
+18    x: LatentsDomainGroupT,
+19    to: str,
+20) -> torch.Tensor:
+21    """
+22    Translate from multiple domains to one domain.
+23
+24    Args:
+25        gw_module (`GWModuleBase`): GWModule to perform the translation over
+26        selection_mod (`SelectionBase`): selection module
+27        x (`LatentsDomainGroupT`): the group of latent representations
+28        to (`str`): the domain name to encode to
+29
+30    Returns:
+31        `torch.Tensor`: the translated unimodal representation
+32            of the provided domain.
+33    """
+34    return gw_module.decode(gw_module.encode_and_fuse(x, selection_mod), domains={to})[
+35        to
+36    ]
+
+ + +

Translate from multiple domains to one domain.

+ +
Arguments:
+ +
    +
  • gw_module (GWModuleBase): GWModule to perform the translation over
  • +
  • selection_mod (SelectionBase): selection module
  • +
  • x (LatentsDomainGroupT): the group of latent representations
  • +
  • to (str): the domain name to encode to
  • +
+ +
Returns:
+ +
+

torch.Tensor: the translated unimodal representation + of the provided domain.

+
+
+ + +
+
+ +
+ + def + cycle( gw_module: shimmer.modules.gw_module.GWModuleBase, selection_mod: shimmer.modules.selection.SelectionBase, x: collections.abc.Mapping[str, torch.Tensor], through: str) -> dict[str, torch.Tensor]: + + + +
+ +
39def cycle(
+40    gw_module: GWModuleBase,
+41    selection_mod: SelectionBase,
+42    x: LatentsDomainGroupT,
+43    through: str,
+44) -> LatentsDomainGroupDT:
+45    """
+46    Do a full cycle from a group of representation through one domain.
+47
+48    [Original domains] -> [GW] -> [through] -> [GW] -> [Original domains]
+49
+50    Args:
+51        gw_module (`GWModuleBase`): GWModule to perform the translation over
+52        selection_mod (`SelectionBase`): selection module
+53        x (`LatentsDomainGroupT`): group of unimodal latent representation
+54        through (`str`): domain name to cycle through
+55    Returns:
+56        `LatentsDomainGroupDT`: group of unimodal latent representation after
+57            cycling.
+58    """
+59    return {
+60        domain: translation(
+61            gw_module,
+62            selection_mod,
+63            {through: translation(gw_module, selection_mod, x, through)},
+64            domain,
+65        )
+66        for domain in x
+67    }
+
+ + +

Do a full cycle from a group of representation through one domain.

+ +

[Original domains] -> [GW] -> [through] -> [GW] -> [Original domains]

+ +
Arguments:
+ +
    +
  • gw_module (GWModuleBase): GWModule to perform the translation over
  • +
  • selection_mod (SelectionBase): selection module
  • +
  • x (LatentsDomainGroupT): group of unimodal latent representation
  • +
  • through (str): domain name to cycle through
  • +
+ +
Returns:
+ +
+

LatentsDomainGroupDT: group of unimodal latent representation after + cycling.

+
+
+ + +
+
+ +
+ + def + batch_demi_cycles( gw_mod: shimmer.modules.gw_module.GWModuleBase, selection_mod: shimmer.modules.selection.SelectionBase, latent_domains: collections.abc.Mapping[frozenset[str], collections.abc.Mapping[str, torch.Tensor]]) -> dict[str, torch.Tensor]: + + + +
+ +
70def batch_demi_cycles(
+71    gw_mod: GWModuleBase,
+72    selection_mod: SelectionBase,
+73    latent_domains: LatentsDomainGroupsT,
+74) -> dict[str, torch.Tensor]:
+75    """
+76    Computes demi-cycles of a batch of groups of domains.
+77
+78    Args:
+79        gw_mod (`GWModuleBase`): the GWModuleBase
+80        selection_mod (`SelectionBase`): selection module
+81        latent_domains (`LatentsT`): the batch of groups of domains
+82
+83    Returns:
+84        `dict[str, torch.Tensor]`: demi-cycles predictions for each domain.
+85    """
+86    predictions: dict[str, torch.Tensor] = {}
+87    for domains, latents in latent_domains.items():
+88        if len(domains) > 1:
+89            continue
+90        domain_name = list(domains)[0]
+91        z = translation(gw_mod, selection_mod, latents, to=domain_name)
+92        predictions[domain_name] = z
+93    return predictions
+
+ + +

Computes demi-cycles of a batch of groups of domains.

+ +
Arguments:
+ +
    +
  • gw_mod (GWModuleBase): the GWModuleBase
  • +
  • selection_mod (SelectionBase): selection module
  • +
  • latent_domains (LatentsT): the batch of groups of domains
  • +
+ +
Returns:
+ +
+

dict[str, torch.Tensor]: demi-cycles predictions for each domain.

+
+
+ + +
+
+ +
+ + def + batch_cycles( gw_mod: shimmer.modules.gw_module.GWModuleBase, selection_mod: shimmer.modules.selection.SelectionBase, latent_domains: collections.abc.Mapping[frozenset[str], collections.abc.Mapping[str, torch.Tensor]], through_domains: collections.abc.Iterable[str]) -> dict[tuple[str, str], torch.Tensor]: + + + +
+ +
 96def batch_cycles(
+ 97    gw_mod: GWModuleBase,
+ 98    selection_mod: SelectionBase,
+ 99    latent_domains: LatentsDomainGroupsT,
+100    through_domains: Iterable[str],
+101) -> dict[tuple[str, str], torch.Tensor]:
+102    """
+103    Computes cycles of a batch of groups of domains.
+104
+105    Args:
+106        gw_mod (`GWModuleBase`): GWModule to use for the cycle
+107        selection_mod (`SelectionBase`): selection module
+108        latent_domains (`LatentsT`): the batch of groups of domains
+109        out_domains (`Iterable[str]`): iterable of domain names to do the cycle through.
+110            Each domain will be done separetely.
+111
+112    Returns:
+113        `dict[tuple[str, str], torch.Tensor]`: cycles predictions for each
+114            couple of (start domain, intermediary domain).
+115    """
+116    predictions: dict[tuple[str, str], torch.Tensor] = {}
+117    for domains_source, latents_source in latent_domains.items():
+118        if len(domains_source) > 1:
+119            continue
+120        domain_name_source = next(iter(domains_source))
+121        for domain_name_through in through_domains:
+122            if domain_name_source == domain_name_through:
+123                continue
+124            z = cycle(
+125                gw_mod, selection_mod, latents_source, through=domain_name_through
+126            )
+127            domains = (domain_name_source, domain_name_through)
+128            predictions[domains] = z[domain_name_source]
+129    return predictions
+
+ + +

Computes cycles of a batch of groups of domains.

+ +
Arguments:
+ +
    +
  • gw_mod (GWModuleBase): GWModule to use for the cycle
  • +
  • selection_mod (SelectionBase): selection module
  • +
  • latent_domains (LatentsT): the batch of groups of domains
  • +
  • out_domains (Iterable[str]): iterable of domain names to do the cycle through. +Each domain will be done separetely.
  • +
+ +
Returns:
+ +
+

dict[tuple[str, str], torch.Tensor]: cycles predictions for each + couple of (start domain, intermediary domain).

+
+
+ + +
+
+ +
+ + def + batch_translations( gw_mod: shimmer.modules.gw_module.GWModuleBase, selection_mod: shimmer.modules.selection.SelectionBase, latent_domains: collections.abc.Mapping[frozenset[str], collections.abc.Mapping[str, torch.Tensor]]) -> dict[tuple[str, str], torch.Tensor]: + + + +
+ +
132def batch_translations(
+133    gw_mod: GWModuleBase,
+134    selection_mod: SelectionBase,
+135    latent_domains: LatentsDomainGroupsT,
+136) -> dict[tuple[str, str], torch.Tensor]:
+137    """
+138    Computes translations of a batch of groups of domains.
+139
+140    Args:
+141        gw_mod (`GWModuleBase`): GWModule to do the translation
+142        selection_mod (`SelectionBase`): selection module
+143        latent_domains (`LatentsT`): the batch of groups of domains
+144
+145    Returns:
+146        `dict[tuple[str, str], torch.Tensor]`: translation predictions for each
+147            couple of (start domain, target domain).
+148    """
+149    predictions: dict[tuple[str, str], torch.Tensor] = {}
+150    for domains, latents in latent_domains.items():
+151        if len(domains) < 2:
+152            continue
+153        for domain_name_source in domains:
+154            for domain_name_target in domains:
+155                if domain_name_source == domain_name_target:
+156                    continue
+157                prediction = translation(
+158                    gw_mod,
+159                    selection_mod,
+160                    {domain_name_source: latents[domain_name_source]},
+161                    to=domain_name_target,
+162                )
+163                predictions[(domain_name_source, domain_name_target)] = prediction
+164    return predictions
+
+ + +

Computes translations of a batch of groups of domains.

+ +
Arguments:
+ +
    +
  • gw_mod (GWModuleBase): GWModule to do the translation
  • +
  • selection_mod (SelectionBase): selection module
  • +
  • latent_domains (LatentsT): the batch of groups of domains
  • +
+ +
Returns:
+ +
+

dict[tuple[str, str], torch.Tensor]: translation predictions for each + couple of (start domain, target domain).

+
+
+ + +
+
+ + \ No newline at end of file diff --git a/docs/api/v0.5.1/shimmer/modules/vae.html b/docs/api/v0.5.1/shimmer/modules/vae.html new file mode 100644 index 00000000..c251cdcb --- /dev/null +++ b/docs/api/v0.5.1/shimmer/modules/vae.html @@ -0,0 +1,1288 @@ + + + + + + + shimmer.modules.vae API documentation + + + + + + + + + + + + + +
+
+

+shimmer.modules.vae

+ + + + + + +
  1import math
+  2from abc import ABC, abstractmethod
+  3from typing import Any
+  4
+  5import torch
+  6from torch import nn
+  7
+  8
+  9def reparameterize(mean: torch.Tensor, logvar: torch.Tensor) -> torch.Tensor:
+ 10    """
+ 11    Reparameterization trick for VAE
+ 12
+ 13    Args:
+ 14        mean (`torch.Tensor`): predicted means
+ 15        logvar (`torch.Tensor`): predicted log variance
+ 16
+ 17    Returns:
+ 18        `torch.Tensor`: a sample from normal distribution with provided
+ 19            parameters, sampled using the reparameterization trick.
+ 20    """
+ 21    std = (0.5 * logvar).exp()
+ 22    eps = torch.randn_like(std)
+ 23    return std * eps + mean
+ 24
+ 25
+ 26def kl_divergence_loss(mean: torch.Tensor, logvar: torch.Tensor) -> torch.Tensor:
+ 27    """
+ 28    Computes the KL divergence loss used in VAE.
+ 29
+ 30    Args:
+ 31        mean (`torch.Tensor`): predicted means
+ 32        logvar (`torch.Tensor`): predicted logvars
+ 33
+ 34    Returns:
+ 35        `torch.Tensor`: the loss
+ 36    """
+ 37    kl = -0.5 * torch.sum(1 + logvar - mean.pow(2) - logvar.exp())
+ 38    return kl
+ 39
+ 40
+ 41def gaussian_nll(
+ 42    mu: torch.Tensor, log_sigma: torch.Tensor, x: torch.Tensor
+ 43) -> torch.Tensor:
+ 44    """
+ 45    Computes gaussian nll loss used in VAE.
+ 46
+ 47    Args:
+ 48        mu (`torch.Tensor`): predictions
+ 49        log_sigma (`torch.Tensor`): log sigma
+ 50        x (`torch.Tensor`): targets
+ 51
+ 52    Returns:
+ 53        `torch.Tensor`: the Gaussian NLL loss
+ 54    """
+ 55    return (
+ 56        0.5 * torch.pow((x - mu) / log_sigma.exp(), 2)
+ 57        + log_sigma
+ 58        + 0.5 * math.log(2 * math.pi)
+ 59    )
+ 60
+ 61
+ 62class VAEEncoder(nn.Module, ABC):
+ 63    """
+ 64    Base class for a VAE encoder.
+ 65    """
+ 66
+ 67    @abstractmethod
+ 68    def forward(self, x: Any) -> tuple[torch.Tensor, torch.Tensor]:
+ 69        """
+ 70        Encode representation with VAE.
+ 71
+ 72
+ 73        Args:
+ 74            x (`Any`): Some input value
+ 75
+ 76        Returns:
+ 77            `tuple[torch.Tensor, torch.Tensor]`: the mean and log variance
+ 78        """
+ 79        ...
+ 80
+ 81
+ 82class VAEDecoder(nn.Module, ABC):
+ 83    """
+ 84    Base class for a VAE decoder.
+ 85    """
+ 86
+ 87    @abstractmethod
+ 88    def forward(self, x: torch.Tensor) -> Any:
+ 89        """
+ 90        Decode representation with VAE
+ 91
+ 92        Args:
+ 93            x (`torch.Tensor`): VAE latent representation representation
+ 94
+ 95        Returns:
+ 96            `Any`: the reconstructed input
+ 97        """
+ 98        ...
+ 99
+100
+101class VAE(nn.Module):
+102    """VAE module"""
+103
+104    def __init__(
+105        self,
+106        encoder: VAEEncoder,
+107        decoder: VAEDecoder,
+108        beta: float = 1,
+109    ):
+110        """
+111        Initializes a VAE.
+112
+113        Args:
+114            encoder (`VAEEncoder`): VAE encode
+115            decoder (`VAEDecoder`): VAE decoder
+116            beta (`float`): beta value for Beta-VAE. Defaults to 1.
+117        """
+118        super().__init__()
+119
+120        assert beta >= 0
+121
+122        self.beta = beta
+123        """Beta value for Beta-VAEs"""
+124
+125        self.encoder = encoder
+126        """The encoder"""
+127
+128        self.decoder = decoder
+129        """The decoder"""
+130
+131    def encode(self, x: Any) -> torch.Tensor:
+132        """
+133        Encode the representation and returns the mean prediction of VAE.
+134
+135        Args:
+136            x (`Any`): Some input value
+137
+138        Returns:
+139            `torch.Tensor`: The mean representation.
+140        """
+141        mean_z, _ = self.encoder(x)
+142        return mean_z
+143
+144    def decode(self, z: torch.Tensor) -> Any:
+145        """
+146        Decode the VAE latent representation into input value.
+147
+148        Args:
+149            z (`torch.Tensor`): the VAE latent representation.
+150
+151        Returns:
+152            `Any`: the reconstructed input.
+153        """
+154        return self.decoder(z)
+155
+156    def forward(self, x: Any) -> tuple[tuple[torch.Tensor, torch.Tensor], Any]:
+157        """
+158        Encode and decodes from x.
+159
+160        Args:
+161            x (`Any`): the input data
+162
+163        Returns:
+164            `tuple[tuple[torch.Tensor, torch.Tensor], Any]`: The
+165                first tuple contains the mean and logvar of the encoded input,
+166                the second item is the reconstructed input.
+167        """
+168        mean, logvar = self.encoder(x)
+169        z = reparameterize(mean, logvar)
+170
+171        x_reconstructed = self.decoder(z)
+172
+173        return (mean, logvar), x_reconstructed
+
+ + +
+
+ +
+ + def + reparameterize(mean: torch.Tensor, logvar: torch.Tensor) -> torch.Tensor: + + + +
+ +
10def reparameterize(mean: torch.Tensor, logvar: torch.Tensor) -> torch.Tensor:
+11    """
+12    Reparameterization trick for VAE
+13
+14    Args:
+15        mean (`torch.Tensor`): predicted means
+16        logvar (`torch.Tensor`): predicted log variance
+17
+18    Returns:
+19        `torch.Tensor`: a sample from normal distribution with provided
+20            parameters, sampled using the reparameterization trick.
+21    """
+22    std = (0.5 * logvar).exp()
+23    eps = torch.randn_like(std)
+24    return std * eps + mean
+
+ + +

Reparameterization trick for VAE

+ +
Arguments:
+ +
    +
  • mean (torch.Tensor): predicted means
  • +
  • logvar (torch.Tensor): predicted log variance
  • +
+ +
Returns:
+ +
+

torch.Tensor: a sample from normal distribution with provided + parameters, sampled using the reparameterization trick.

+
+
+ + +
+
+ +
+ + def + kl_divergence_loss(mean: torch.Tensor, logvar: torch.Tensor) -> torch.Tensor: + + + +
+ +
27def kl_divergence_loss(mean: torch.Tensor, logvar: torch.Tensor) -> torch.Tensor:
+28    """
+29    Computes the KL divergence loss used in VAE.
+30
+31    Args:
+32        mean (`torch.Tensor`): predicted means
+33        logvar (`torch.Tensor`): predicted logvars
+34
+35    Returns:
+36        `torch.Tensor`: the loss
+37    """
+38    kl = -0.5 * torch.sum(1 + logvar - mean.pow(2) - logvar.exp())
+39    return kl
+
+ + +

Computes the KL divergence loss used in VAE.

+ +
Arguments:
+ +
    +
  • mean (torch.Tensor): predicted means
  • +
  • logvar (torch.Tensor): predicted logvars
  • +
+ +
Returns:
+ +
+

torch.Tensor: the loss

+
+
+ + +
+
+ +
+ + def + gaussian_nll( mu: torch.Tensor, log_sigma: torch.Tensor, x: torch.Tensor) -> torch.Tensor: + + + +
+ +
42def gaussian_nll(
+43    mu: torch.Tensor, log_sigma: torch.Tensor, x: torch.Tensor
+44) -> torch.Tensor:
+45    """
+46    Computes gaussian nll loss used in VAE.
+47
+48    Args:
+49        mu (`torch.Tensor`): predictions
+50        log_sigma (`torch.Tensor`): log sigma
+51        x (`torch.Tensor`): targets
+52
+53    Returns:
+54        `torch.Tensor`: the Gaussian NLL loss
+55    """
+56    return (
+57        0.5 * torch.pow((x - mu) / log_sigma.exp(), 2)
+58        + log_sigma
+59        + 0.5 * math.log(2 * math.pi)
+60    )
+
+ + +

Computes gaussian nll loss used in VAE.

+ +
Arguments:
+ +
    +
  • mu (torch.Tensor): predictions
  • +
  • log_sigma (torch.Tensor): log sigma
  • +
  • x (torch.Tensor): targets
  • +
+ +
Returns:
+ +
+

torch.Tensor: the Gaussian NLL loss

+
+
+ + +
+
+ +
+ + class + VAEEncoder(torch.nn.modules.module.Module, abc.ABC): + + + +
+ +
63class VAEEncoder(nn.Module, ABC):
+64    """
+65    Base class for a VAE encoder.
+66    """
+67
+68    @abstractmethod
+69    def forward(self, x: Any) -> tuple[torch.Tensor, torch.Tensor]:
+70        """
+71        Encode representation with VAE.
+72
+73
+74        Args:
+75            x (`Any`): Some input value
+76
+77        Returns:
+78            `tuple[torch.Tensor, torch.Tensor]`: the mean and log variance
+79        """
+80        ...
+
+ + +

Base class for a VAE encoder.

+
+ + +
+ +
+
@abstractmethod
+ + def + forward(self, x: Any) -> tuple[torch.Tensor, torch.Tensor]: + + + +
+ +
68    @abstractmethod
+69    def forward(self, x: Any) -> tuple[torch.Tensor, torch.Tensor]:
+70        """
+71        Encode representation with VAE.
+72
+73
+74        Args:
+75            x (`Any`): Some input value
+76
+77        Returns:
+78            `tuple[torch.Tensor, torch.Tensor]`: the mean and log variance
+79        """
+80        ...
+
+ + +

Encode representation with VAE.

+ +
Arguments:
+ +
    +
  • x (Any): Some input value
  • +
+ +
Returns:
+ +
+

tuple[torch.Tensor, torch.Tensor]: the mean and log variance

+
+
+ + +
+
+
Inherited Members
+
+
torch.nn.modules.module.Module
+
Module
+
dump_patches
+
training
+
call_super_init
+
register_buffer
+
register_parameter
+
add_module
+
register_module
+
get_submodule
+
get_parameter
+
get_buffer
+
get_extra_state
+
set_extra_state
+
apply
+
cuda
+
ipu
+
xpu
+
cpu
+
type
+
float
+
double
+
half
+
bfloat16
+
to_empty
+
to
+
register_full_backward_pre_hook
+
register_backward_hook
+
register_full_backward_hook
+
register_forward_pre_hook
+
register_forward_hook
+
register_state_dict_pre_hook
+
state_dict
+
register_load_state_dict_post_hook
+
load_state_dict
+
parameters
+
named_parameters
+
buffers
+
named_buffers
+
children
+
named_children
+
modules
+
named_modules
+
train
+
eval
+
requires_grad_
+
zero_grad
+
share_memory
+
extra_repr
+
compile
+ +
+
+
+
+
+ +
+ + class + VAEDecoder(torch.nn.modules.module.Module, abc.ABC): + + + +
+ +
83class VAEDecoder(nn.Module, ABC):
+84    """
+85    Base class for a VAE decoder.
+86    """
+87
+88    @abstractmethod
+89    def forward(self, x: torch.Tensor) -> Any:
+90        """
+91        Decode representation with VAE
+92
+93        Args:
+94            x (`torch.Tensor`): VAE latent representation representation
+95
+96        Returns:
+97            `Any`: the reconstructed input
+98        """
+99        ...
+
+ + +

Base class for a VAE decoder.

+
+ + +
+ +
+
@abstractmethod
+ + def + forward(self, x: torch.Tensor) -> Any: + + + +
+ +
88    @abstractmethod
+89    def forward(self, x: torch.Tensor) -> Any:
+90        """
+91        Decode representation with VAE
+92
+93        Args:
+94            x (`torch.Tensor`): VAE latent representation representation
+95
+96        Returns:
+97            `Any`: the reconstructed input
+98        """
+99        ...
+
+ + +

Decode representation with VAE

+ +
Arguments:
+ +
    +
  • x (torch.Tensor): VAE latent representation representation
  • +
+ +
Returns:
+ +
+

Any: the reconstructed input

+
+
+ + +
+
+
Inherited Members
+
+
torch.nn.modules.module.Module
+
Module
+
dump_patches
+
training
+
call_super_init
+
register_buffer
+
register_parameter
+
add_module
+
register_module
+
get_submodule
+
get_parameter
+
get_buffer
+
get_extra_state
+
set_extra_state
+
apply
+
cuda
+
ipu
+
xpu
+
cpu
+
type
+
float
+
double
+
half
+
bfloat16
+
to_empty
+
to
+
register_full_backward_pre_hook
+
register_backward_hook
+
register_full_backward_hook
+
register_forward_pre_hook
+
register_forward_hook
+
register_state_dict_pre_hook
+
state_dict
+
register_load_state_dict_post_hook
+
load_state_dict
+
parameters
+
named_parameters
+
buffers
+
named_buffers
+
children
+
named_children
+
modules
+
named_modules
+
train
+
eval
+
requires_grad_
+
zero_grad
+
share_memory
+
extra_repr
+
compile
+ +
+
+
+
+
+ +
+ + class + VAE(torch.nn.modules.module.Module): + + + +
+ +
102class VAE(nn.Module):
+103    """VAE module"""
+104
+105    def __init__(
+106        self,
+107        encoder: VAEEncoder,
+108        decoder: VAEDecoder,
+109        beta: float = 1,
+110    ):
+111        """
+112        Initializes a VAE.
+113
+114        Args:
+115            encoder (`VAEEncoder`): VAE encode
+116            decoder (`VAEDecoder`): VAE decoder
+117            beta (`float`): beta value for Beta-VAE. Defaults to 1.
+118        """
+119        super().__init__()
+120
+121        assert beta >= 0
+122
+123        self.beta = beta
+124        """Beta value for Beta-VAEs"""
+125
+126        self.encoder = encoder
+127        """The encoder"""
+128
+129        self.decoder = decoder
+130        """The decoder"""
+131
+132    def encode(self, x: Any) -> torch.Tensor:
+133        """
+134        Encode the representation and returns the mean prediction of VAE.
+135
+136        Args:
+137            x (`Any`): Some input value
+138
+139        Returns:
+140            `torch.Tensor`: The mean representation.
+141        """
+142        mean_z, _ = self.encoder(x)
+143        return mean_z
+144
+145    def decode(self, z: torch.Tensor) -> Any:
+146        """
+147        Decode the VAE latent representation into input value.
+148
+149        Args:
+150            z (`torch.Tensor`): the VAE latent representation.
+151
+152        Returns:
+153            `Any`: the reconstructed input.
+154        """
+155        return self.decoder(z)
+156
+157    def forward(self, x: Any) -> tuple[tuple[torch.Tensor, torch.Tensor], Any]:
+158        """
+159        Encode and decodes from x.
+160
+161        Args:
+162            x (`Any`): the input data
+163
+164        Returns:
+165            `tuple[tuple[torch.Tensor, torch.Tensor], Any]`: The
+166                first tuple contains the mean and logvar of the encoded input,
+167                the second item is the reconstructed input.
+168        """
+169        mean, logvar = self.encoder(x)
+170        z = reparameterize(mean, logvar)
+171
+172        x_reconstructed = self.decoder(z)
+173
+174        return (mean, logvar), x_reconstructed
+
+ + +

VAE module

+
+ + +
+ +
+ + VAE( encoder: VAEEncoder, decoder: VAEDecoder, beta: float = 1) + + + +
+ +
105    def __init__(
+106        self,
+107        encoder: VAEEncoder,
+108        decoder: VAEDecoder,
+109        beta: float = 1,
+110    ):
+111        """
+112        Initializes a VAE.
+113
+114        Args:
+115            encoder (`VAEEncoder`): VAE encode
+116            decoder (`VAEDecoder`): VAE decoder
+117            beta (`float`): beta value for Beta-VAE. Defaults to 1.
+118        """
+119        super().__init__()
+120
+121        assert beta >= 0
+122
+123        self.beta = beta
+124        """Beta value for Beta-VAEs"""
+125
+126        self.encoder = encoder
+127        """The encoder"""
+128
+129        self.decoder = decoder
+130        """The decoder"""
+
+ + +

Initializes a VAE.

+ +
Arguments:
+ +
    +
  • encoder (VAEEncoder): VAE encode
  • +
  • decoder (VAEDecoder): VAE decoder
  • +
  • beta (float): beta value for Beta-VAE. Defaults to 1.
  • +
+
+ + +
+
+
+ beta + + +
+ + +

Beta value for Beta-VAEs

+
+ + +
+
+
+ encoder + + +
+ + +

The encoder

+
+ + +
+
+
+ decoder + + +
+ + +

The decoder

+
+ + +
+
+ +
+ + def + encode(self, x: Any) -> torch.Tensor: + + + +
+ +
132    def encode(self, x: Any) -> torch.Tensor:
+133        """
+134        Encode the representation and returns the mean prediction of VAE.
+135
+136        Args:
+137            x (`Any`): Some input value
+138
+139        Returns:
+140            `torch.Tensor`: The mean representation.
+141        """
+142        mean_z, _ = self.encoder(x)
+143        return mean_z
+
+ + +

Encode the representation and returns the mean prediction of VAE.

+ +
Arguments:
+ +
    +
  • x (Any): Some input value
  • +
+ +
Returns:
+ +
+

torch.Tensor: The mean representation.

+
+
+ + +
+
+ +
+ + def + decode(self, z: torch.Tensor) -> Any: + + + +
+ +
145    def decode(self, z: torch.Tensor) -> Any:
+146        """
+147        Decode the VAE latent representation into input value.
+148
+149        Args:
+150            z (`torch.Tensor`): the VAE latent representation.
+151
+152        Returns:
+153            `Any`: the reconstructed input.
+154        """
+155        return self.decoder(z)
+
+ + +

Decode the VAE latent representation into input value.

+ +
Arguments:
+ +
    +
  • z (torch.Tensor): the VAE latent representation.
  • +
+ +
Returns:
+ +
+

Any: the reconstructed input.

+
+
+ + +
+
+ +
+ + def + forward(self, x: Any) -> tuple[tuple[torch.Tensor, torch.Tensor], typing.Any]: + + + +
+ +
157    def forward(self, x: Any) -> tuple[tuple[torch.Tensor, torch.Tensor], Any]:
+158        """
+159        Encode and decodes from x.
+160
+161        Args:
+162            x (`Any`): the input data
+163
+164        Returns:
+165            `tuple[tuple[torch.Tensor, torch.Tensor], Any]`: The
+166                first tuple contains the mean and logvar of the encoded input,
+167                the second item is the reconstructed input.
+168        """
+169        mean, logvar = self.encoder(x)
+170        z = reparameterize(mean, logvar)
+171
+172        x_reconstructed = self.decoder(z)
+173
+174        return (mean, logvar), x_reconstructed
+
+ + +

Encode and decodes from x.

+ +
Arguments:
+ +
    +
  • x (Any): the input data
  • +
+ +
Returns:
+ +
+

tuple[tuple[torch.Tensor, torch.Tensor], Any]: The + first tuple contains the mean and logvar of the encoded input, + the second item is the reconstructed input.

+
+
+ + +
+
+
Inherited Members
+
+
torch.nn.modules.module.Module
+
dump_patches
+
training
+
call_super_init
+
register_buffer
+
register_parameter
+
add_module
+
register_module
+
get_submodule
+
get_parameter
+
get_buffer
+
get_extra_state
+
set_extra_state
+
apply
+
cuda
+
ipu
+
xpu
+
cpu
+
type
+
float
+
double
+
half
+
bfloat16
+
to_empty
+
to
+
register_full_backward_pre_hook
+
register_backward_hook
+
register_full_backward_hook
+
register_forward_pre_hook
+
register_forward_hook
+
register_state_dict_pre_hook
+
state_dict
+
register_load_state_dict_post_hook
+
load_state_dict
+
parameters
+
named_parameters
+
buffers
+
named_buffers
+
children
+
named_children
+
modules
+
named_modules
+
train
+
eval
+
requires_grad_
+
zero_grad
+
share_memory
+
extra_repr
+
compile
+ +
+
+
+
+
+ + \ No newline at end of file diff --git a/docs/api/v0.5.1/shimmer/types.html b/docs/api/v0.5.1/shimmer/types.html new file mode 100644 index 00000000..a1057714 --- /dev/null +++ b/docs/api/v0.5.1/shimmer/types.html @@ -0,0 +1,872 @@ + + + + + + + shimmer.types API documentation + + + + + + + + + + + + + +
+
+

+shimmer.types

+ + + + + + +
  1from collections.abc import Mapping
+  2from typing import Any, Literal
+  3
+  4import torch
+  5
+  6RawDomainGroupT = Mapping[str, Any]
+  7"""
+  8Matched raw unimodal data from multiple domains.
+  9Keys of the mapping are domains names and values are the domain data.
+ 10
+ 11All values in the mapping should be matched and represent the same information.
+ 12
+ 13Example:
+ 14    ```python
+ 15    def fun(domain_group: RawDomainGroupT): ...
+ 16
+ 17
+ 18    x = {
+ 19        "vision": PIL.Image.Image("path/to/dog/picture.png"),
+ 20        "language": "This is a picture of a dog.",
+ 21    }
+ 22
+ 23    fun(x)
+ 24    ```
+ 25
+ 26Note:
+ 27    This type uses `collections.abc.Mapping` and is used for functions' inputs.
+ 28    Use `RawDomainGroupDT` for functions' outputs.
+ 29
+ 30    This allows to be more generic and allow passing other mappings.
+ 31"""
+ 32
+ 33RawDomainGroupDT = dict[str, Any]
+ 34"""
+ 35Output type version of `RawDomainGroupT`.
+ 36Matched raw unimodal data from multiple domains.
+ 37Keys of the mapping are domains names and values are the domain data.
+ 38
+ 39Example:
+ 40    ```python
+ 41    def fun() -> RawDomainGroupDT:
+ 42        return {
+ 43            "vision": PIL.Image.Image("path/to/dog/picture.png"),
+ 44            "language": "This is a picture of a dog.",
+ 45        }
+ 46
+ 47    ```
+ 48
+ 49Note:
+ 50    This type uses `dict`s and is used for functions' outputs.
+ 51    Use `RawDomainGroupT` for functions' inputs.
+ 52
+ 53"""
+ 54
+ 55LatentsDomainGroupT = Mapping[str, torch.Tensor]
+ 56"""
+ 57Matched unimodal latent representations from multiple domains.
+ 58Keys of the mapping are domains names and values are `torch.Tensor` latent
+ 59representation of the domain.
+ 60
+ 61Example:
+ 62    ```python
+ 63    def fun(domain_group: LatentsDomainGroupT): ...
+ 64
+ 65
+ 66    x = {
+ 67        "vision": torch.Tensor([0.0, 1.0, 0.0, ...]),
+ 68        "language": torch.Tensor([0.0, 0.3, 0.2, ...]),
+ 69    }
+ 70
+ 71    fun(x)
+ 72    ```
+ 73
+ 74Note:
+ 75    This type uses `collections.abc.Mapping` and is used for functions' inputs.
+ 76    Use `LatentsDomainGroupDT` for functions' outputs.
+ 77
+ 78    This allows to be more generic and allow passing other mappings.
+ 79"""
+ 80
+ 81LatentsDomainGroupDT = dict[str, torch.Tensor]
+ 82"""
+ 83Matched unimodal latent representations from multiple domains.
+ 84Keys of the dict are domains names and values are `torch.Tensor` latent
+ 85representation of the domain.
+ 86
+ 87Example:
+ 88    ```python
+ 89    def fun() -> LatentsDomainGroupDT:
+ 90        return {
+ 91            "vision": torch.Tensor([0.0, 1.0, 0.0, ...]),
+ 92            "language": torch.Tensor([0.0, 0.3, 0.2, ...]),
+ 93        }
+ 94
+ 95    ```
+ 96
+ 97Note:
+ 98    This type uses `dict`s and is used for functions' outputs.
+ 99    Use `LatentsDomainGroupT` for functions' inputs.
+100"""
+101
+102RawDomainGroupsT = Mapping[frozenset[str], RawDomainGroupT]
+103"""
+104Mapping of `RawDomainGroupT`. Keys are frozenset of domains matched in the group.
+105Each group is independent and contains different data (unpaired).
+106
+107Example:
+108    ```python
+109    def fun() -> RawDomainGroupsDT:
+110        return {
+111            frozenset(["vision"]): {
+112                "vision": PIL.Image.Image("path/to/cat/picture.png"),
+113            },
+114            frozenset(["language"]): {
+115                "language": "This is a picture of a rabbit.",
+116            },
+117            frozenset(["vision", "language"]): {
+118                "vision": PIL.Image.Image("path/to/dog/picture.png"),
+119                "language": "This is a picture of a dog.",
+120            },
+121        }
+122
+123    ```
+124
+125Note:
+126    This type uses `dict`s and is used for functions' outputs.
+127    Use `RawDomainGroupsT` for functions' inputs.
+128"""
+129
+130RawDomainGroupsDT = dict[frozenset[str], RawDomainGroupDT]
+131"""
+132Mapping of `RawDomainGroupT`. Keys are frozenset of domains matched in the group.
+133Each group is independent and contains different data (unpaired).
+134
+135Example:
+136    ```python
+137    def fun() -> RawDomainGroupsDT:
+138        return {
+139            frozenset(["vision"]): {
+140                "vision": PIL.Image.Image("path/to/cat/picture.png"),
+141            },
+142            frozenset(["language"]): {
+143                "language": "This is a picture of a rabbit.",
+144            },
+145            frozenset(["vision", "language"]): {
+146                "vision": PIL.Image.Image("path/to/dog/picture.png"),
+147                "language": "This is a picture of a dog.",
+148            },
+149        }
+150
+151    ```
+152
+153Note:
+154    This type uses `dict`s and is used for functions' outputs.
+155    Use `RawDomainGroupsT` for functions' inputs.
+156"""
+157
+158LatentsDomainGroupsT = Mapping[frozenset[str], LatentsDomainGroupT]
+159"""
+160Mapping of `LatentsDomainGroupT`. Keys are frozenset of domains matched in the group.
+161Each group is independent and contains different data (unpaired).
+162
+163Example:
+164    ```python
+165    def fun(domain_group: LatentsDomainGroupsT): ...
+166
+167
+168    x = {
+169        frozenset(["vision"]): {
+170            "vision": torch.Tensor([1.0, 0.0, 0.3, ...]),
+171        },
+172        frozenset(["language"]): {
+173            "language": torch.Tensor([1.0, 0.2, 0.9, ...]),
+174        },
+175        frozenset(["vision", "language"]): {
+176            "vision": torch.Tensor([0.0, 1.0, 0.0, ...]),
+177            "language": torch.Tensor([0.0, 0.3, 0.2, ...]),
+178        },
+179    }
+180
+181    fun(x)
+182    ```
+183Note:
+184    This type uses `collections.abc.Mapping` and is used for functions' inputs.
+185    Use `LatentsDomainGroupsDT` for functions' outputs.
+186
+187    This allows to be more generic and allow passing other mappings.
+188
+189"""
+190
+191LatentsDomainGroupsDT = dict[frozenset[str], LatentsDomainGroupDT]
+192"""
+193Mapping of `LatentsDomainGroupDT`.
+194Keys are frozenset of domains matched in the group.
+195Each group is independent and contains different data (unpaired).
+196
+197Example:
+198    ```python
+199    def fun() -> LatentsDomainGroupsDT:
+200        return {
+201            frozenset(["vision"]): {
+202                "vision": torch.Tensor([1.0, 0.0, 0.3, ...]),
+203            },
+204            frozenset(["language"]): {
+205                "language": torch.Tensor([1.0, 0.2, 0.9, ...]),
+206            },
+207            frozenset(["vision", "language"]): {
+208                "vision": torch.Tensor([0.0, 1.0, 0.0, ...]),
+209                "language": torch.Tensor([0.0, 0.3, 0.2, ...]),
+210            },
+211        }
+212
+213    ```
+214
+215Note:
+216    This type uses `dict`s and is used for functions' outputs.
+217    Use `LatentsDomainGroupT` for functions' inputs.
+218"""
+219
+220
+221ModelModeT = Literal["train", "val", "test", "val/ood", "test/ood"]
+222"""
+223Mode used by pytorch lightning (train/val, ...).
+224
+225When validating or testing in out-of-distribution data, "val/ood" or "test/ood" mode is
+226used.
+227"""
+
+ + +
+
+
+ RawDomainGroupT = +collections.abc.Mapping[str, typing.Any] + + +
+ + +

Matched raw unimodal data from multiple domains. +Keys of the mapping are domains names and values are the domain data.

+ +

All values in the mapping should be matched and represent the same information.

+ +
Example:
+ +
+
+
def fun(domain_group: RawDomainGroupT): ...
+
+
+x = {
+    "vision": PIL.Image.Image("path/to/dog/picture.png"),
+    "language": "This is a picture of a dog.",
+}
+
+fun(x)
+
+
+
+ +
Note:
+ +
+

This type uses collections.abc.Mapping and is used for functions' inputs. + Use RawDomainGroupDT for functions' outputs.

+ +

This allows to be more generic and allow passing other mappings.

+
+
+ + +
+
+
+ RawDomainGroupDT = +dict[str, typing.Any] + + +
+ + +

Output type version of RawDomainGroupT. +Matched raw unimodal data from multiple domains. +Keys of the mapping are domains names and values are the domain data.

+ +
Example:
+ +
+
+
def fun() -> RawDomainGroupDT:
+    return {
+        "vision": PIL.Image.Image("path/to/dog/picture.png"),
+        "language": "This is a picture of a dog.",
+    }
+
+
+
+ +
Note:
+ +
+

This type uses dicts and is used for functions' outputs. + Use RawDomainGroupT for functions' inputs.

+
+
+ + +
+
+
+ LatentsDomainGroupT = +collections.abc.Mapping[str, torch.Tensor] + + +
+ + +

Matched unimodal latent representations from multiple domains. +Keys of the mapping are domains names and values are torch.Tensor latent +representation of the domain.

+ +
Example:
+ +
+
+
def fun(domain_group: LatentsDomainGroupT): ...
+
+
+x = {
+    "vision": torch.Tensor([0.0, 1.0, 0.0, ...]),
+    "language": torch.Tensor([0.0, 0.3, 0.2, ...]),
+}
+
+fun(x)
+
+
+
+ +
Note:
+ +
+

This type uses collections.abc.Mapping and is used for functions' inputs. + Use LatentsDomainGroupDT for functions' outputs.

+ +

This allows to be more generic and allow passing other mappings.

+
+
+ + +
+
+
+ LatentsDomainGroupDT = +dict[str, torch.Tensor] + + +
+ + +

Matched unimodal latent representations from multiple domains. +Keys of the dict are domains names and values are torch.Tensor latent +representation of the domain.

+ +
Example:
+ +
+
+
def fun() -> LatentsDomainGroupDT:
+    return {
+        "vision": torch.Tensor([0.0, 1.0, 0.0, ...]),
+        "language": torch.Tensor([0.0, 0.3, 0.2, ...]),
+    }
+
+
+
+ +
Note:
+ +
+

This type uses dicts and is used for functions' outputs. + Use LatentsDomainGroupT for functions' inputs.

+
+
+ + +
+
+
+ RawDomainGroupsT = +collections.abc.Mapping[frozenset[str], collections.abc.Mapping[str, typing.Any]] + + +
+ + +

Mapping of RawDomainGroupT. Keys are frozenset of domains matched in the group. +Each group is independent and contains different data (unpaired).

+ +
Example:
+ +
+
+
def fun() -> RawDomainGroupsDT:
+    return {
+        frozenset(["vision"]): {
+            "vision": PIL.Image.Image("path/to/cat/picture.png"),
+        },
+        frozenset(["language"]): {
+            "language": "This is a picture of a rabbit.",
+        },
+        frozenset(["vision", "language"]): {
+            "vision": PIL.Image.Image("path/to/dog/picture.png"),
+            "language": "This is a picture of a dog.",
+        },
+    }
+
+
+
+ +
Note:
+ +
+

This type uses dicts and is used for functions' outputs. + Use RawDomainGroupsT for functions' inputs.

+
+
+ + +
+
+
+ RawDomainGroupsDT = +dict[frozenset[str], dict[str, typing.Any]] + + +
+ + +

Mapping of RawDomainGroupT. Keys are frozenset of domains matched in the group. +Each group is independent and contains different data (unpaired).

+ +
Example:
+ +
+
+
def fun() -> RawDomainGroupsDT:
+    return {
+        frozenset(["vision"]): {
+            "vision": PIL.Image.Image("path/to/cat/picture.png"),
+        },
+        frozenset(["language"]): {
+            "language": "This is a picture of a rabbit.",
+        },
+        frozenset(["vision", "language"]): {
+            "vision": PIL.Image.Image("path/to/dog/picture.png"),
+            "language": "This is a picture of a dog.",
+        },
+    }
+
+
+
+ +
Note:
+ +
+

This type uses dicts and is used for functions' outputs. + Use RawDomainGroupsT for functions' inputs.

+
+
+ + +
+
+
+ LatentsDomainGroupsT = +collections.abc.Mapping[frozenset[str], collections.abc.Mapping[str, torch.Tensor]] + + +
+ + +

Mapping of LatentsDomainGroupT. Keys are frozenset of domains matched in the group. +Each group is independent and contains different data (unpaired).

+ +
Example:
+ +
+
+
def fun(domain_group: LatentsDomainGroupsT): ...
+
+
+x = {
+    frozenset(["vision"]): {
+        "vision": torch.Tensor([1.0, 0.0, 0.3, ...]),
+    },
+    frozenset(["language"]): {
+        "language": torch.Tensor([1.0, 0.2, 0.9, ...]),
+    },
+    frozenset(["vision", "language"]): {
+        "vision": torch.Tensor([0.0, 1.0, 0.0, ...]),
+        "language": torch.Tensor([0.0, 0.3, 0.2, ...]),
+    },
+}
+
+fun(x)
+
+
+
+ +
Note:
+ +
+

This type uses collections.abc.Mapping and is used for functions' inputs. + Use LatentsDomainGroupsDT for functions' outputs.

+ +

This allows to be more generic and allow passing other mappings.

+
+
+ + +
+
+
+ LatentsDomainGroupsDT = +dict[frozenset[str], dict[str, torch.Tensor]] + + +
+ + +

Mapping of LatentsDomainGroupDT. +Keys are frozenset of domains matched in the group. +Each group is independent and contains different data (unpaired).

+ +
Example:
+ +
+
+
def fun() -> LatentsDomainGroupsDT:
+    return {
+        frozenset(["vision"]): {
+            "vision": torch.Tensor([1.0, 0.0, 0.3, ...]),
+        },
+        frozenset(["language"]): {
+            "language": torch.Tensor([1.0, 0.2, 0.9, ...]),
+        },
+        frozenset(["vision", "language"]): {
+            "vision": torch.Tensor([0.0, 1.0, 0.0, ...]),
+            "language": torch.Tensor([0.0, 0.3, 0.2, ...]),
+        },
+    }
+
+
+
+ +
Note:
+ +
+

This type uses dicts and is used for functions' outputs. + Use LatentsDomainGroupT for functions' inputs.

+
+
+ + +
+
+
+ ModelModeT = +typing.Literal['train', 'val', 'test', 'val/ood', 'test/ood'] + + +
+ + +

Mode used by pytorch lightning (train/val, ...).

+ +

When validating or testing in out-of-distribution data, "val/ood" or "test/ood" mode is +used.

+
+ + +
+
+ + \ No newline at end of file diff --git a/docs/api/v0.5.1/shimmer/utils.html b/docs/api/v0.5.1/shimmer/utils.html new file mode 100644 index 00000000..87a2b3bd --- /dev/null +++ b/docs/api/v0.5.1/shimmer/utils.html @@ -0,0 +1,701 @@ + + + + + + + shimmer.utils API documentation + + + + + + + + + + + + + +
+
+

+shimmer.utils

+ + + + + + +
 1from os import PathLike
+ 2from pathlib import Path
+ 3from typing import Any
+ 4
+ 5import torch
+ 6from lightning.pytorch import Callback, LightningModule, Trainer
+ 7from migrate_ckpt import (
+ 8    ckpt_migration_key,
+ 9    get_folder_migrations,
+10    migrate_from_folder,
+11)
+12
+13from shimmer.types import LatentsDomainGroupsT, LatentsDomainGroupT
+14
+15MIGRATION_DIR = Path(__file__).parent / "ckpt_migrations"
+16
+17
+18def group_batch_size(x: LatentsDomainGroupT) -> int:
+19    for val in x.values():
+20        return val.size(0)
+21    raise ValueError("Got empty group.")
+22
+23
+24def groups_batch_size(domain_latents: LatentsDomainGroupsT) -> int:
+25    """
+26    Get the batch size of the batch.
+27
+28    Args:
+29        domain_latents (`LatentsDomainGroupsT`): the batch of groups.
+30
+31    Returns:
+32        int: the batch size.
+33    """
+34    for data in domain_latents.values():
+35        for tensor in data.values():
+36            return tensor.size(0)
+37    raise ValueError("Empty batch.")
+38
+39
+40def groups_device(domain_latents: LatentsDomainGroupsT) -> int:
+41    """
+42    Get the batch size of the batch.
+43
+44    Args:
+45        domain_latents (`LatentsDomainGroupsT`): the batch of groups.
+46
+47    Returns:
+48        int: the batch size.
+49    """
+50    for data in domain_latents.values():
+51        for tensor in data.values():
+52            return tensor.size(0)
+53    raise ValueError("Empty batch.")
+54
+55
+56def group_device(x: LatentsDomainGroupT) -> torch.device:
+57    for val in x.values():
+58        return val.device
+59    raise ValueError("Got empty group.")
+60
+61
+62def migrate_model(ckpt_path: str | PathLike, **torch_load_kwargs):
+63    """
+64    Migrates a model checkpoint
+65
+66    After the migration, the given checkpoint will be migrated.
+67    Other versions of the checkpoint will be saved under the stem-version.suffix.
+68
+69    Args:
+70        ckpt_path (`str | PathLike`):  path to checkpoint
+71        torch_load_kwargs: additional args given to torch.load.
+72    """
+73    ckpt_path = Path(ckpt_path)
+74    ckpt = torch.load(ckpt_path, **torch_load_kwargs)
+75    new_ckpt, done_migrations = migrate_from_folder(ckpt, MIGRATION_DIR)
+76    done_migration_log = ", ".join(map(lambda x: x.name, done_migrations))
+77    print(f"Migrating: {done_migration_log}")
+78    if len(done_migrations) or ckpt_migration_key not in ckpt:
+79        version = 0
+80        if ckpt_migration_key in ckpt:
+81            version = len(ckpt[ckpt_migration_key])
+82        torch.save(ckpt, ckpt_path.with_stem(f"{ckpt_path.stem}-{version}"))
+83        torch.save(new_ckpt, ckpt_path)
+84
+85
+86class SaveMigrations(Callback):
+87    def __init__(self):
+88        self.migrations = get_folder_migrations(MIGRATION_DIR)
+89
+90    def on_save_checkpoint(
+91        self, trainer: Trainer, pl_module: LightningModule, checkpoint: dict[str, Any]
+92    ):
+93        checkpoint[ckpt_migration_key] = [mig.name for mig in self.migrations]
+
+ + +
+
+
+ MIGRATION_DIR = +PosixPath('/home/runner/work/shimmer/shimmer/shimmer/ckpt_migrations') + + +
+ + + + +
+
+ +
+ + def + group_batch_size(x: collections.abc.Mapping[str, torch.Tensor]) -> int: + + + +
+ +
19def group_batch_size(x: LatentsDomainGroupT) -> int:
+20    for val in x.values():
+21        return val.size(0)
+22    raise ValueError("Got empty group.")
+
+ + + + +
+
+ +
+ + def + groups_batch_size( domain_latents: collections.abc.Mapping[frozenset[str], collections.abc.Mapping[str, torch.Tensor]]) -> int: + + + +
+ +
25def groups_batch_size(domain_latents: LatentsDomainGroupsT) -> int:
+26    """
+27    Get the batch size of the batch.
+28
+29    Args:
+30        domain_latents (`LatentsDomainGroupsT`): the batch of groups.
+31
+32    Returns:
+33        int: the batch size.
+34    """
+35    for data in domain_latents.values():
+36        for tensor in data.values():
+37            return tensor.size(0)
+38    raise ValueError("Empty batch.")
+
+ + +

Get the batch size of the batch.

+ +
Arguments:
+ +
    +
  • domain_latents (LatentsDomainGroupsT): the batch of groups.
  • +
+ +
Returns:
+ +
+

int: the batch size.

+
+
+ + +
+
+ +
+ + def + groups_device( domain_latents: collections.abc.Mapping[frozenset[str], collections.abc.Mapping[str, torch.Tensor]]) -> int: + + + +
+ +
41def groups_device(domain_latents: LatentsDomainGroupsT) -> int:
+42    """
+43    Get the batch size of the batch.
+44
+45    Args:
+46        domain_latents (`LatentsDomainGroupsT`): the batch of groups.
+47
+48    Returns:
+49        int: the batch size.
+50    """
+51    for data in domain_latents.values():
+52        for tensor in data.values():
+53            return tensor.size(0)
+54    raise ValueError("Empty batch.")
+
+ + +

Get the batch size of the batch.

+ +
Arguments:
+ +
    +
  • domain_latents (LatentsDomainGroupsT): the batch of groups.
  • +
+ +
Returns:
+ +
+

int: the batch size.

+
+
+ + +
+
+ +
+ + def + group_device(x: collections.abc.Mapping[str, torch.Tensor]) -> torch.device: + + + +
+ +
57def group_device(x: LatentsDomainGroupT) -> torch.device:
+58    for val in x.values():
+59        return val.device
+60    raise ValueError("Got empty group.")
+
+ + + + +
+
+ +
+ + def + migrate_model(ckpt_path: str | os.PathLike, **torch_load_kwargs): + + + +
+ +
63def migrate_model(ckpt_path: str | PathLike, **torch_load_kwargs):
+64    """
+65    Migrates a model checkpoint
+66
+67    After the migration, the given checkpoint will be migrated.
+68    Other versions of the checkpoint will be saved under the stem-version.suffix.
+69
+70    Args:
+71        ckpt_path (`str | PathLike`):  path to checkpoint
+72        torch_load_kwargs: additional args given to torch.load.
+73    """
+74    ckpt_path = Path(ckpt_path)
+75    ckpt = torch.load(ckpt_path, **torch_load_kwargs)
+76    new_ckpt, done_migrations = migrate_from_folder(ckpt, MIGRATION_DIR)
+77    done_migration_log = ", ".join(map(lambda x: x.name, done_migrations))
+78    print(f"Migrating: {done_migration_log}")
+79    if len(done_migrations) or ckpt_migration_key not in ckpt:
+80        version = 0
+81        if ckpt_migration_key in ckpt:
+82            version = len(ckpt[ckpt_migration_key])
+83        torch.save(ckpt, ckpt_path.with_stem(f"{ckpt_path.stem}-{version}"))
+84        torch.save(new_ckpt, ckpt_path)
+
+ + +

Migrates a model checkpoint

+ +

After the migration, the given checkpoint will be migrated. +Other versions of the checkpoint will be saved under the stem-version.suffix.

+ +
Arguments:
+ +
    +
  • ckpt_path (str | PathLike): path to checkpoint
  • +
  • torch_load_kwargs: additional args given to torch.load.
  • +
+
+ + +
+
+ +
+ + class + SaveMigrations(lightning.pytorch.callbacks.callback.Callback): + + + +
+ +
87class SaveMigrations(Callback):
+88    def __init__(self):
+89        self.migrations = get_folder_migrations(MIGRATION_DIR)
+90
+91    def on_save_checkpoint(
+92        self, trainer: Trainer, pl_module: LightningModule, checkpoint: dict[str, Any]
+93    ):
+94        checkpoint[ckpt_migration_key] = [mig.name for mig in self.migrations]
+
+ + +

Abstract base class used to build new callbacks.

+ +

Subclass this class and override any of the relevant hooks

+
+ + +
+
+ migrations + + +
+ + + + +
+
+ +
+ + def + on_save_checkpoint( self, trainer: lightning.pytorch.trainer.trainer.Trainer, pl_module: lightning.pytorch.core.module.LightningModule, checkpoint: dict[str, typing.Any]): + + + +
+ +
91    def on_save_checkpoint(
+92        self, trainer: Trainer, pl_module: LightningModule, checkpoint: dict[str, Any]
+93    ):
+94        checkpoint[ckpt_migration_key] = [mig.name for mig in self.migrations]
+
+ + +

Called when saving a checkpoint to give you a chance to store anything else you might want to save.

+ +
Arguments:
+ +
    +
  • trainer: the current ~lightning.pytorch.trainer.trainer.Trainer instance.
  • +
  • pl_module: the current ~lightning.pytorch.core.LightningModule instance.
  • +
  • checkpoint: the checkpoint dictionary that will be saved.
  • +
+
+ + +
+
+
Inherited Members
+
+
lightning.pytorch.callbacks.callback.Callback
+
state_key
+
setup
+
teardown
+
on_fit_start
+
on_fit_end
+
on_sanity_check_start
+
on_sanity_check_end
+
on_train_batch_start
+
on_train_batch_end
+
on_train_epoch_start
+
on_train_epoch_end
+
on_validation_epoch_start
+
on_validation_epoch_end
+
on_test_epoch_start
+
on_test_epoch_end
+
on_predict_epoch_start
+
on_predict_epoch_end
+
on_validation_batch_start
+
on_validation_batch_end
+
on_test_batch_start
+
on_test_batch_end
+
on_predict_batch_start
+
on_predict_batch_end
+
on_train_start
+
on_train_end
+
on_validation_start
+
on_validation_end
+
on_test_start
+
on_test_end
+
on_predict_start
+
on_predict_end
+
on_exception
+
state_dict
+
load_state_dict
+
on_load_checkpoint
+
on_before_backward
+
on_after_backward
+
on_before_optimizer_step
+
on_before_zero_grad
+ +
+
+
+
+
+ + \ No newline at end of file