Skip to content

Commit

Permalink
improve cie
Browse files Browse the repository at this point in the history
  • Loading branch information
Riccorl committed Aug 1, 2024
1 parent 2b86172 commit a164550
Show file tree
Hide file tree
Showing 9 changed files with 694 additions and 62 deletions.
Empty file added relik/cli/__init__.py
Empty file.
643 changes: 643 additions & 0 deletions relik/cli/data.py

Large diffs are not rendered by default.

33 changes: 33 additions & 0 deletions relik/cli/reader.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from relik.cli.utils import resolve_config
from relik.common.log import get_logger, print_relik_text_art
from relik.reader.trainer.train import train as reader_train
from relik.reader.trainer.train_cie import train as reader_train_cie

logger = get_logger(__name__)

Expand Down Expand Up @@ -44,3 +45,35 @@ def _reader_train(conf):
sys.argv.extend(overrides)

_reader_train()

@app.command(context_settings=dict(ignore_unknown_options=True, allow_extra_args=True))
def train_cie():
"""
Trains the reader model.
This function prints the Relik text art, resolves the configuration file path,
and then calls the `_reader_train` function to train the reader model.
Args:
None
Returns:
None
"""
print_relik_text_art()
config_dir, config_name, overrides = resolve_config("reader")

@hydra.main(
config_path=str(config_dir),
config_name=str(config_name),
version_base="1.3",
)
def _reader_train_cie(conf):
reader_train_cie(conf)

# clean sys.argv for hydra
sys.argv = sys.argv[:1]
# add the overrides to sys.argv
sys.argv.extend(overrides)

_reader_train_cie()
15 changes: 15 additions & 0 deletions relik/cli/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,3 +59,18 @@ def resolve_config(type: str | None = None) -> OmegaConf:
# cfg = compose(config_name=config_name, overrides=overrides)

return config_dir, config_name, overrides


def int_or_str_typer(value: str) -> int | None:
"""
Converts a string value to an integer or None.
Args:
value (str): The string value to be converted.
Returns:
int | None: The converted integer value or None if the input is "None".
"""
if value == "None":
return None
return int(value)
53 changes: 0 additions & 53 deletions relik/common/torch_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,20 +4,6 @@
import torch
import transformers as tr

from relik.common.utils import is_package_available

# check if ORT is available
if is_package_available("onnxruntime"):
from optimum.onnxruntime import (
ORTModel,
ORTModelForCustomTasks,
ORTModelForSequenceClassification,
ORTOptimizer,
)
from optimum.onnxruntime.configuration import AutoOptimizationConfig

# from relik.retriever.pytorch_modules import PRECISION_MAP


def get_autocast_context(
device: str | torch.device, precision: str
Expand All @@ -41,42 +27,3 @@ def get_autocast_context(
)
)
return autocast_manager


# def load_ort_optimized_hf_model(
# hf_model: tr.PreTrainedModel,
# provider: str = "CPUExecutionProvider",
# ort_model_type: callable = "ORTModelForCustomTasks",
# ) -> ORTModel:
# """
# Load an optimized ONNX Runtime HF model.
#
# Args:
# hf_model (`tr.PreTrainedModel`):
# The HF model to optimize.
# provider (`str`, optional):
# The ONNX Runtime provider to use. Defaults to "CPUExecutionProvider".
#
# Returns:
# `ORTModel`: The optimized HF model.
# """
# if isinstance(hf_model, ORTModel):
# return hf_model
# temp_dir = tempfile.mkdtemp()
# hf_model.save_pretrained(temp_dir)
# ort_model = ort_model_type.from_pretrained(
# temp_dir, export=True, provider=provider, use_io_binding=True
# )
# if is_package_available("onnxruntime"):
# optimizer = ORTOptimizer.from_pretrained(ort_model)
# optimization_config = AutoOptimizationConfig.O4()
# optimizer.optimize(save_dir=temp_dir, optimization_config=optimization_config)
# ort_model = ort_model_type.from_pretrained(
# temp_dir,
# export=True,
# provider=provider,
# use_io_binding=bool(provider == "CUDAExecutionProvider"),
# )
# return ort_model
# else:
# raise ValueError("onnxruntime is not installed. Please install Ray with `pip install relik[serve]`.")
2 changes: 2 additions & 0 deletions relik/reader/lightning_modules/relik_reader_re_pl_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ def __init__(
linears_hidden_size: Optional[int] = 512,
use_last_k_layers: int = 1,
training: bool = False,
default_reader_class: str = "relik.reader.pytorch_modules.hf.modeling_relik.RelikReaderREModel",
*args: Any,
**kwargs: Any
):
Expand All @@ -37,6 +38,7 @@ def __init__(
linears_hidden_size,
use_last_k_layers,
training=training,
default_reader_class=default_reader_class,
**kwargs,
)
self.optimizer_factory = None
Expand Down
1 change: 0 additions & 1 deletion relik/reader/pytorch_modules/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@

from relik.common.log import get_logger

# from relik.common.torch_utils import load_ort_optimized_hf_model
from relik.common.utils import get_callable_from_string
from relik.inference.data.objects import AnnotationType

Expand Down
3 changes: 0 additions & 3 deletions relik/retriever/indexers/inmemory.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,6 @@
from relik.retriever.pytorch_modules import PRECISION_MAP, RetrievedSample


# check if ORT is available
# if is_package_available("onnxruntime"):

logger = get_logger(__name__, level=logging.INFO)


Expand Down
6 changes: 1 addition & 5 deletions relik/retriever/pytorch_modules/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
from relik.common.log import get_logger
from relik.common.torch_utils import (
get_autocast_context,
) # , # load_ort_optimized_hf_model
)
from relik.common.utils import get_callable_from_string, is_package_available, to_config
from relik.retriever.common.model_inputs import ModelInputs
from relik.retriever.data.base.datasets import BaseDataset
Expand All @@ -27,10 +27,6 @@
from relik.retriever.pytorch_modules import PRECISION_MAP, RetrievedSample
from relik.retriever.pytorch_modules.hf import GoldenRetrieverModel

# check if ORT is available
if is_package_available("onnxruntime"):
from optimum.onnxruntime import ORTModel

logger = get_logger(__name__, level=logging.INFO)


Expand Down

0 comments on commit a164550

Please sign in to comment.