Skip to content
This repository has been archived by the owner on Dec 3, 2024. It is now read-only.

Unify supernet registry #29

Draft
wants to merge 1 commit into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
56 changes: 56 additions & 0 deletions dynast/supernetwork/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,3 +11,59 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Dict


class SupernetRegistryHolder(type):

_REGISTRY: Dict[str, "SupernetRegistryHolder"] = {}

def __new__(cls, name, bases, attrs):
new_cls = type.__new__(cls, name, bases, attrs)
print(new_cls.__name__)
cls._REGISTRY[new_cls._name] = new_cls()
return new_cls

@classmethod
def get_registry(cls):
return {
k: v for k, v in cls._REGISTRY.items() if k not in ['SupernetBaseRegisteredClass', 'dynast.supernetwork']
}


class SupernetBaseRegisteredClass(metaclass=SupernetRegistryHolder):
_name = __name__
_encoding = None
_parameters = None
_evaluation_interface = None
_linas_innerloop_evals = None
_supernet_type = None
_supernet_metrics = None

@property
def encoding(self):
return self._encoding

@property
def parameters(self):
return self._parameters

@property
def evaluation_interface(self):
return self._evaluation_interface

@property
def linas_innerloop_evals(self):
return self._linas_innerloop_evals

@property
def supernet_type(self):
return self._supernet_type

@property
def supernet_metrics(self):
return self._supernet_metrics

def __str__(self) -> str:
return self._name
43 changes: 43 additions & 0 deletions dynast/supernetwork/image_classification/ofa/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,3 +11,46 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from dynast.supernetwork import SupernetBaseRegisteredClass
from dynast.supernetwork.image_classification.ofa.ofa_encoding import OFAMobileNetV3Encoding, OFAResNet50Encoding
from dynast.supernetwork.image_classification.ofa.ofa_interface import (
EvaluationInterfaceOFAMobileNetV3,
EvaluationInterfaceOFAResNet50,
)


class OFAResNet50Supernet(SupernetBaseRegisteredClass):
_name = 'ofa_resnet50'
_encoding = OFAResNet50Encoding
_parameters = {
'd': {'count': 5, 'vars': [0, 1, 2]},
'e': {'count': 18, 'vars': [0.2, 0.25, 0.35]},
'w': {'count': 6, 'vars': [0, 1, 2]},
}
_evaluation_interface = EvaluationInterfaceOFAResNet50
_linas_innerloop_evals = 5000
_supernet_type = 'image_classification'
_supernet_metrics = ['params', 'latency', 'macs', 'accuracy_top1']


class OFAMBv3_d234_e346_k357_w10_Supernet(SupernetBaseRegisteredClass):
_name = 'ofa_mbv3_d234_e346_k357_w1.0'
_encoding = OFAMobileNetV3Encoding
_parameters = {
'ks': {'count': 20, 'vars': [3, 5, 7]},
'e': {'count': 20, 'vars': [3, 4, 6]},
'd': {'count': 5, 'vars': [2, 3, 4]},
}
_evaluation_interface = EvaluationInterfaceOFAMobileNetV3
_linas_innerloop_evals = 20000
_supernet_type = 'image_classification'
_supernet_metrics = ['params', 'latency', 'macs', 'accuracy_top1']


class OFAMBv3_d234_e346_k357_w12_Supernet(OFAMBv3_d234_e346_k357_w10_Supernet):
_name = 'ofa_mbv3_d234_e346_k357_w1.2'


class OFAProxyless_d234_e346_k357_w13_Supernet(OFAMBv3_d234_e346_k357_w10_Supernet):
_name = 'ofa_proxyless_d234_e346_k357_w1.3'
25 changes: 25 additions & 0 deletions dynast/supernetwork/machine_translation/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,3 +11,28 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


from dynast.supernetwork import SupernetBaseRegisteredClass
from dynast.supernetwork.machine_translation.transformer_encoding import TransformerLTEncoding
from dynast.supernetwork.machine_translation.transformer_interface import EvaluationInterfaceTransformerLT


class TransformerLT_WMT_en_de(SupernetBaseRegisteredClass):
_name = 'transformer_lt_wmt_en_de'
_encoding = TransformerLTEncoding
_parameters = {
'encoder_embed_dim': {'count': 1, 'vars': [640, 512]},
'decoder_embed_dim': {'count': 1, 'vars': [640, 512]},
'encoder_ffn_embed_dim': {'count': 6, 'vars': [3072, 2048, 1024]},
'decoder_ffn_embed_dim': {'count': 6, 'vars': [3072, 2048, 1024]},
'decoder_layer_num': {'count': 1, 'vars': [6, 5, 4, 3, 2, 1]},
'encoder_self_attention_heads': {'count': 6, 'vars': [8, 4]},
'decoder_self_attention_heads': {'count': 6, 'vars': [8, 4]},
'decoder_ende_attention_heads': {'count': 6, 'vars': [8, 4]},
'decoder_arbitrary_ende_attn': {'count': 6, 'vars': [-1, 1, 2]},
}
_evaluation_interface = EvaluationInterfaceTransformerLT
_linas_innerloop_evals = 10000
_supernet_type = 'machine_translation'
_supernet_metrics = ['latency', 'macs', 'params', 'bleu']
90 changes: 6 additions & 84 deletions dynast/supernetwork/supernetwork_registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@
# limitations under the License.


from dynast.supernetwork import SupernetBaseRegisteredClass, SupernetRegistryHolder
from dynast.supernetwork.image_classification.ofa import OFAResNet50Supernet
from dynast.supernetwork.image_classification.ofa.ofa_encoding import OFAMobileNetV3Encoding, OFAResNet50Encoding
from dynast.supernetwork.image_classification.ofa.ofa_interface import (
EvaluationInterfaceOFAMobileNetV3,
Expand All @@ -23,92 +25,12 @@
from dynast.supernetwork.text_classification.bert_encoding import BertSST2Encoding
from dynast.supernetwork.text_classification.bert_interface import EvaluationInterfaceBertSST2

SUPERNET_ENCODING = {
'ofa_resnet50': OFAResNet50Encoding,
'ofa_mbv3_d234_e346_k357_w1.0': OFAMobileNetV3Encoding,
'ofa_mbv3_d234_e346_k357_w1.2': OFAMobileNetV3Encoding,
'ofa_proxyless_d234_e346_k357_w1.3': OFAMobileNetV3Encoding,
'transformer_lt_wmt_en_de': TransformerLTEncoding,
'bert_base_sst2': BertSST2Encoding,
}

SUPERNET_PARAMETERS = {
'ofa_resnet50': {
'd': {'count': 5, 'vars': [0, 1, 2]},
'e': {'count': 18, 'vars': [0.2, 0.25, 0.35]},
'w': {'count': 6, 'vars': [0, 1, 2]},
},
'ofa_mbv3_d234_e346_k357_w1.0': {
'ks': {'count': 20, 'vars': [3, 5, 7]},
'e': {'count': 20, 'vars': [3, 4, 6]},
'd': {'count': 5, 'vars': [2, 3, 4]},
},
'ofa_mbv3_d234_e346_k357_w1.2': {
'ks': {'count': 20, 'vars': [3, 5, 7]},
'e': {'count': 20, 'vars': [3, 4, 6]},
'd': {'count': 5, 'vars': [2, 3, 4]},
},
'ofa_proxyless_d234_e346_k357_w1.3': {
'ks': {'count': 20, 'vars': [3, 5, 7]},
'e': {'count': 20, 'vars': [3, 4, 6]},
'd': {'count': 5, 'vars': [2, 3, 4]},
},
'transformer_lt_wmt_en_de': {
'encoder_embed_dim': {'count': 1, 'vars': [640, 512]},
'decoder_embed_dim': {'count': 1, 'vars': [640, 512]},
'encoder_ffn_embed_dim': {'count': 6, 'vars': [3072, 2048, 1024]},
'decoder_ffn_embed_dim': {'count': 6, 'vars': [3072, 2048, 1024]},
'decoder_layer_num': {'count': 1, 'vars': [6, 5, 4, 3, 2, 1]},
'encoder_self_attention_heads': {'count': 6, 'vars': [8, 4]},
'decoder_self_attention_heads': {'count': 6, 'vars': [8, 4]},
'decoder_ende_attention_heads': {'count': 6, 'vars': [8, 4]},
'decoder_arbitrary_ende_attn': {'count': 6, 'vars': [-1, 1, 2]},
},
'bert_base_sst2': {
'num_layers': {'count': 1, 'vars': [6, 7, 8, 9, 10, 11, 12]},
'num_attention_heads': {'count': 12, 'vars': [6, 8, 10, 12]},
'intermediate_size': {'count': 12, 'vars': [1024, 2048, 3072]},
},
}

EVALUATION_INTERFACE = {
'ofa_resnet50': EvaluationInterfaceOFAResNet50,
'ofa_mbv3_d234_e346_k357_w1.0': EvaluationInterfaceOFAMobileNetV3,
'ofa_mbv3_d234_e346_k357_w1.2': EvaluationInterfaceOFAMobileNetV3,
'ofa_proxyless_d234_e346_k357_w1.3': EvaluationInterfaceOFAMobileNetV3,
'transformer_lt_wmt_en_de': EvaluationInterfaceTransformerLT,
'bert_base_sst2': EvaluationInterfaceBertSST2,
}

LINAS_INNERLOOP_EVALS = {
'ofa_resnet50': 5000,
'ofa_mbv3_d234_e346_k357_w1.0': 20000,
'ofa_mbv3_d234_e346_k357_w1.2': 20000,
'ofa_proxyless_d234_e346_k357_w1.3': 20000,
'transformer_lt_wmt_en_de': 10000,
'bert_base_sst2': 20000,
}
def main():
print([s for s in SupernetRegistryHolder.get_registry()])

SUPERNET_TYPE = {
'image_classification': [
'ofa_resnet50',
'ofa_mbv3_d234_e346_k357_w1.0',
'ofa_mbv3_d234_e346_k357_w1.2',
'ofa_proxyless_d234_e346_k357_w1.3',
],
'machine_translation': ['transformer_lt_wmt_en_de'],
'text_classification': ['bert_base_sst2'],
'recommendation': [],
}

SUPERNET_METRICS = {
'ofa_resnet50': ['params', 'latency', 'macs', 'accuracy_top1'],
'ofa_mbv3_d234_e346_k357_w1.0': ['params', 'latency', 'macs', 'accuracy_top1'],
'ofa_mbv3_d234_e346_k357_w1.2': ['params', 'latency', 'macs', 'accuracy_top1'],
'ofa_proxyless_d234_e346_k357_w1.3': ['params', 'latency', 'macs', 'accuracy_top1'],
'transformer_lt_wmt_en_de': ['latency', 'macs', 'params', 'bleu'],
'bert_base_sst2': ['latency', 'macs', 'params', 'accuracy_sst2'],
}


SEARCH_ALGORITHMS = ['linas', 'evolutionary', 'random']
if __name__ == '__main__':
main()
19 changes: 19 additions & 0 deletions dynast/supernetwork/text_classification/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,3 +11,22 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


from dynast.supernetwork import SupernetBaseRegisteredClass
from dynast.supernetwork.text_classification.bert_encoding import BertSST2Encoding
from dynast.supernetwork.text_classification.bert_interface import EvaluationInterfaceBertSST2


class BertBase_SST2(SupernetBaseRegisteredClass):
_name = 'bert_base_sst2'
_encoding = BertSST2Encoding
_parameters = {
'num_layers': {'count': 1, 'vars': [6, 7, 8, 9, 10, 11, 12]},
'num_attention_heads': {'count': 12, 'vars': [6, 8, 10, 12]},
'intermediate_size': {'count': 12, 'vars': [1024, 2048, 3072]},
}
_evaluation_interface = EvaluationInterfaceBertSST2
_linas_innerloop_evals = 20000
_supernet_type = 'text_classification'
_supernet_metrics = ['latency', 'macs', 'params', 'accuracy_sst2']