Autoregressive mode and embedding calculation addition
metric-space committed Oct 3, 2023
1 parent 4df4837 commit 0de926c
Showing 4 changed files with 29 additions and 6 deletions.
3 changes: 3 additions & 0 deletions dalm/cli.py
@@ -241,6 +241,7 @@ def train_retriever_only(
     ] = True,
     use_peft: Annotated[bool, typer.Option(help="Whether to use Peft during fine-tuning.")] = True,
     use_bnb: Annotated[bool, typer.Option(help="Whether to use BNB during fine-tuning.")] = True,
+    is_autoregressive: Annotated[bool, typer.Option(help="Whether the model is autoregressive.")] = False,
 ) -> None:
     """Train only the retriever using contrastive training"""
     train_retriever(
@@ -387,6 +388,7 @@ def eval_retriever(
         TorchDtype, typer.Option(help="torch.dtype to use for tensors. float16 or bfloat16.")
     ] = TorchDtype.float16,
     top_k: Annotated[int, typer.Option(help="Top K retrieval")] = 10,
+    is_autoregressive: Annotated[bool, typer.Option(help="Whether the model is autoregressive.")] = False,
 ) -> None:
     """Evaluate your retriever only"""
     evaluate_retriever(
@@ -401,6 +403,7 @@ def eval_retriever(
         device=device,
         torch_dtype=torch_dtype.value,
         top_k=top_k,
+        is_autoregressive=is_autoregressive,
     )


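For context, a minimal sketch of how typer surfaces the new option on the command line. This is not part of the commit, and `app` is only an assumed name for the Typer application in dalm/cli.py:

from typer.testing import CliRunner

from dalm.cli import app  # assumed name for the Typer application

runner = CliRunner()
result = runner.invoke(app, ["eval-retriever", "--help"])
# typer derives the CLI flag "--is-autoregressive" from the is_autoregressive parameter
assert "--is-autoregressive" in result.output

Since the option defaults to False, existing invocations are unaffected.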
5 changes: 4 additions & 1 deletion dalm/eval/eval_retriever_only.py
@@ -105,11 +105,14 @@ def evaluate_retriever(
     device: str = "cuda",
     torch_dtype: Literal["float16", "bfloat16"] = "float16",
     top_k: int = 10,
+    is_autoregressive: bool = False,
 ) -> None:
     """Runs retriever-only evaluation. See `dalm eval-retriever --help` for details on params"""
     test_dataset = load_dataset(dataset_or_path)
     selected_torch_dtype: Final[torch.dtype] = torch.float16 if torch_dtype == "float16" else torch.bfloat16
-    retriever_model = AutoModelForSentenceEmbedding(retriever_name_or_path, get_peft=False, use_bnb=False)
+    retriever_model = AutoModelForSentenceEmbedding(
+        retriever_name_or_path, get_peft=False, use_bnb=False, is_autoregressive=is_autoregressive
+    )
     retriever_tokenizer = retriever_model.tokenizer

     processed_datasets = preprocess_dataset(
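A hedged usage sketch of the updated evaluator, not part of the commit: the leading parameter names mirror the identifiers used in the function body, the remaining arguments are left at their defaults, and the model and dataset names are placeholders.

from dalm.eval.eval_retriever_only import evaluate_retriever

evaluate_retriever(
    dataset_or_path="qa-eval-dataset",    # placeholder dataset name or path
    retriever_name_or_path="gpt2",        # placeholder causal-LM backbone
    device="cuda",
    torch_dtype="float16",
    top_k=10,
    is_autoregressive=True,               # new flag introduced by this commit
)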
16 changes: 12 additions & 4 deletions dalm/models/retriever_only_base_model.py
@@ -12,6 +12,7 @@ def __init__(
         normalize: bool = True,
         use_bnb: bool = True,
         get_peft: bool = True,
+        is_autoregressive: bool = False,
     ) -> None:
         super(AutoModelForSentenceEmbedding, self).__init__()

@@ -30,18 +31,25 @@ def __init__(
         )

         self.normalize = normalize
+        self.is_autoregressive = is_autoregressive
         self.tokenizer = AutoTokenizer.from_pretrained(model_name)

     def forward(self, input_ids: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
-        model_output = self.model(input_ids, attention_mask)
-        embeddings = self.mean_pooling(model_output, attention_mask)
+        if self.is_autoregressive:
+            # we take the last hidden state of the model
+            token_embeddings = self.model.sample(
+                input_ids, attention_mask, output_hidden_states=True, return_dict_in_generate=True
+            ).hidden_states[-1]
+        else:
+            # First element of model_output contains all token embeddings
+            token_embeddings = self.model(input_ids, attention_mask)[0]
+        embeddings = self.mean_pooling(token_embeddings, attention_mask)
         if self.normalize:
             embeddings = torch.nn.functional.normalize(embeddings, p=2, dim=1)

         return embeddings

-    def mean_pooling(self, model_output: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
-        token_embeddings = model_output[0]  # First element of model_output contains all token embeddings
+    def mean_pooling(self, token_embeddings: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
         input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
         return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(input_mask_expanded.sum(1), min=1e-9)

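Whichever branch produces token_embeddings, the pooling step is now the same masked average over real tokens. A standalone sketch, not from the repo, of what mean_pooling computes:

import torch

token_embeddings = torch.randn(2, 5, 8)          # (batch, seq_len, hidden)
attention_mask = torch.tensor([[1, 1, 1, 0, 0],
                               [1, 1, 1, 1, 1]])  # 1 = real token, 0 = padding

mask = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
summed = torch.sum(token_embeddings * mask, 1)    # sum over non-padding positions only
counts = torch.clamp(mask.sum(1), min=1e-9)       # guard against division by zero
embeddings = summed / counts                      # (batch, hidden) sentence embeddings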
11 changes: 10 additions & 1 deletion dalm/training/retriever_only/train_retriever_only.py
@@ -163,6 +163,11 @@ def parse_args() -> Namespace:
         action="store_true",
         help="Whether to use model quantization.",
     )
+    parser.add_argument(
+        "--is_autoregressive",
+        action="store_true",
+        help="Whether the model is autoregressive (a causal LM).",
+    )
     args = parser.parse_args()

     return args
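For reference, a minimal sketch, not from the commit, of how the store_true flag added above behaves when parsed:

from argparse import ArgumentParser

parser = ArgumentParser()
parser.add_argument("--is_autoregressive", action="store_true")

print(parser.parse_args([]).is_autoregressive)                       # False
print(parser.parse_args(["--is_autoregressive"]).is_autoregressive)  # True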
@@ -195,6 +200,7 @@ def train_retriever(
     sanity_test: bool = True,
     use_peft: bool = True,
     use_bnb: bool = True,
+    is_autoregressive: bool = False,
 ) -> None:
     # Get the passed in vars before beginning training, in case we report training
     args = dict(locals())
@@ -227,7 +233,9 @@
     os.makedirs(output_dir, exist_ok=True)
     accelerator.wait_for_everyone()

-    model = AutoModelForSentenceEmbedding(retriever_name_or_path, use_bnb=use_bnb, get_peft=use_peft)
+    model = AutoModelForSentenceEmbedding(
+        retriever_name_or_path, use_bnb=use_bnb, get_peft=use_peft, is_autoregressive=is_autoregressive
+    )
     tokenizer = model.tokenizer

     # dataset download and preprocessing
@@ -443,6 +451,7 @@ def main() -> None:
         sanity_test=args.sanity_test,
         use_peft=args.use_peft,
         use_bnb=args.use_bnb,
+        is_autoregressive=args.is_autoregressive,
     )


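Finally, a hedged sketch of the programmatic path that main() takes, not part of the commit: parameter names beyond those visible in this diff are assumptions, and the model and dataset names are placeholders.

from dalm.training.retriever_only.train_retriever_only import train_retriever

train_retriever(
    dataset_or_path="contrastive-train-dataset",  # assumed parameter name; placeholder value
    retriever_name_or_path="facebook/opt-125m",   # placeholder autoregressive backbone
    output_dir="./retriever-checkpoints",         # assumed parameter name
    use_peft=True,
    use_bnb=True,
    is_autoregressive=True,                       # new flag wired through by this commit
)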
