Skip to content

Commit

Permalink
refactor: add model arch
Browse files (browse the repository at this point in the history)
  • Loading branch information
aiwantaozi committed Nov 27, 2024
1 parent 5d85168 commit b558211
Showing 1 changed file with 24 additions and 9 deletions.
33 changes: 24 additions & 9 deletions vox_box/elstimator/funasr.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
import json
import logging
import os
from typing import Dict
from typing import Dict, Tuple

import yaml
from vox_box.config.config import BackendEnum, Config, TaskTypeEnum
from vox_box.downloader.downloaders import download_model
from vox_box.elstimator.base import Elstimator
Expand All @@ -17,23 +19,24 @@ def __init__(
cfg: Config,
):
self._cfg = cfg
self._optional_files = ["configuration.json", "config.json"]
self._optional_files = ["configuration.json", "config.json", "config.yaml"]

def model_info(self) -> Dict:
    """Describe the configured model as a FunASR speech-to-text candidate.

    The model identifier is the first truthy value among
    ``self._cfg.model``, ``self._cfg.huggingface_repo_id`` and
    ``self._cfg.model_scope_model_id``.

    Returns:
        The dict built by ``create_model_dict`` for that identifier, tagged
        with the STT task type, the FunASR backend, the support flag and
        the detected model architecture.
    """
    model = (
        self._cfg.model
        or self._cfg.huggingface_repo_id
        or self._cfg.model_scope_model_id
    )
    # _supported() has no fallback branch, so it can return None when no
    # model source is configured; unpacking None would raise TypeError.
    # Fall back to the "unsupported, unknown architecture" defaults that
    # _check_local_model starts from. A real (bool, str) result is a
    # non-empty tuple and therefore always truthy, so `or` only kicks in
    # for None.
    supported, model_architecture = self._supported() or (False, "")
    return create_model_dict(
        model,
        supported=supported,
        task_type=TaskTypeEnum.STT,
        backend_framework=BackendEnum.FUN_ASR,
        model_architecture=model_architecture,
    )

def _supported(self) -> bool:
def _supported(self) -> Tuple[bool, str]:
if self._cfg.model is not None:
return self._check_local_model(self._cfg.model)
elif (
Expand All @@ -42,12 +45,17 @@ def _supported(self) -> bool:
):
return self._check_remote_model()

def _check_local_model(self, base_dir: str) -> bool:
def _check_local_model(self, base_dir: str) -> Tuple[bool, str]:
configuration_json = None
config_json = None
config_yaml = None

configuration_path = os.path.join(base_dir, "configuration.json")
config_json_path = os.path.join(base_dir, "config.json")
config_yaml_path = os.path.join(base_dir, "config.yaml")

supported = False
model_architecture = ""

if os.path.exists(configuration_path):
with open(configuration_path, "r", encoding="utf-8") as f:
Expand All @@ -57,22 +65,29 @@ def _check_local_model(self, base_dir: str) -> bool:
with open(config_json_path, "r", encoding="utf-8") as f:
config_json = json.load(f)

if os.path.exists(config_yaml_path):
with open(config_yaml_path, "r", encoding="utf-8") as f:
config_yaml = yaml.safe_load(f)

if configuration_json is not None:
task = configuration_json.get("task", "")
model_type = configuration_json.get("model", {}).get("type", "")
if task == "auto-speech-recognition" and model_type == "funasr":
return True
supported = True
if config_yaml is not None:
model_architecture = config_yaml.get("model")

if config_json is not None:
architectures = config_json.get("architectures")
if (architectures is not None and "QWenLMHeadModel" in architectures) and (
config_json.get("audio", {}).get("n_layer", 0) != 0
):
return True
supported = True
model_architecture = "QwenAudio"

return False
return supported, model_architecture

def _check_remote_model(self) -> bool:
def _check_remote_model(self) -> Tuple[bool, str]:
downloaded_files = []
for f in self._optional_files:
try:
Expand Down

0 comments on commit b558211

Please sign in to comment.