Skip to content

Commit

Permalink
Switch head (#180)
Browse files Browse the repository at this point in the history
* support caching of metadata

* upd version

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Add tests for switchencoderandhead

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Upd version

* `poetry lock` versions

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: M. Yusuf Sarıgöz <[email protected]>
  • Loading branch information
3 people authored Oct 6, 2022
1 parent 760020a commit f2be2a4
Show file tree
Hide file tree
Showing 12 changed files with 378 additions and 518 deletions.
700 changes: 207 additions & 493 deletions poetry.lock

Large diffs are not rendered by default.

4 changes: 2 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "quaterion"
version = "0.1.30"
version = "0.1.31"
description = "Similarity Learning fine-tuning framework"
authors = ["Quaterion Authors <[email protected]>"]
packages = [
Expand All @@ -16,7 +16,7 @@ keywords = ["framework", "similarity-learning", "metric-learning", "similarity",
python = ">=3.8,<3.11"
torch = ">=1.8.2"
pytorch-lightning = "^1.6.4"
quaterion-models = "^0.1.16"
quaterion-models = "0.1.17"
loguru = "^0.5.3"
mmh3 = "^3.0.0"
pytorch-metric-learning = {version = "^1.3.0", optional = true}
Expand Down
16 changes: 13 additions & 3 deletions quaterion/dataset/train_collator.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,8 @@

from typing import Any, Callable, Dict, List, Tuple, Union

from quaterion_models.types import CollateFnType
from quaterion_models.types import CollateFnType, MetaExtractorFnType
from quaterion_models.utils.meta import merge_meta

from quaterion.dataset import SimilarityGroupSample, SimilarityPairSample

Expand All @@ -22,10 +23,14 @@ class TrainCollator:
"""

def __init__(
self, pre_collate_fn: Callable, encoder_collates: Dict[str, CollateFnType]
self,
pre_collate_fn: Callable,
encoder_collates: Dict[str, CollateFnType],
meta_extractors: Dict[str, MetaExtractorFnType],
):
self.pre_collate_fn = pre_collate_fn
self.encoder_collates = encoder_collates
self.meta_extractors = meta_extractors

def pre_encoder_collate(
self, features: List[Any], ids: List[int] = None, encoder_name: str = None
Expand All @@ -35,15 +40,20 @@ def pre_encoder_collate(
"""
return features

def process_meta(self, meta: Dict[str, List]) -> Any:
    """Post-process the meta information collected for one batch.

    Args:
        meta: mapping from encoder name to the meta extracted for that encoder

    Returns:
        Any: whatever `merge_meta` produces for the given mapping
    """
    merged = merge_meta(meta)
    return merged

def __call__(
self,
batch: List[Tuple[int, Union[SimilarityPairSample, SimilarityGroupSample]]],
):
ids, features, labels = self.pre_collate_fn(batch)

encoder_collate_result = {}
meta = {}
for encoder_name, collate_fn in self.encoder_collates.items():
encoder_features = self.pre_encoder_collate(features, ids, encoder_name)
encoder_collate_result[encoder_name] = collate_fn(encoder_features)
meta[encoder_name] = self.meta_extractors[encoder_name](encoder_features)

return encoder_collate_result, labels
return {"data": encoder_collate_result, "meta": self.process_meta(meta)}, labels
25 changes: 23 additions & 2 deletions quaterion/train/cache/cache_encoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from typing import Any, Callable, Hashable, List, Tuple, Union

from quaterion_models.encoders import Encoder
from quaterion_models.types import CollateFnType, TensorInterchange
from quaterion_models.types import CollateFnType, MetaExtractorFnType, TensorInterchange
from torch import Tensor

KeyExtractorType = Callable[[Any], Hashable]
Expand Down Expand Up @@ -58,6 +58,24 @@ def embedding_size(self) -> int:
"""
return self._encoder.embedding_size

def cache_extract_meta(self, batch: List[Any]) -> List[dict]:
    """Extract meta information for the records of a batch.

    This implementation only declares the interface and always raises.

    Args:
        batch: batch of data

    Returns:
        List[dict]: list of meta information

    Raises:
        NotImplementedError: always
    """
    raise NotImplementedError()

def get_meta_extractor(self) -> MetaExtractorFnType:
    """Return the callable that extracts meta information from a batch.

    Returns:
        MetaExtractorFnType: the bound `cache_extract_meta` method
    """
    extractor = self.cache_extract_meta
    return extractor

def cache_collate(
self, batch: Union[Tuple[List[Hashable], List[Any]], List[Hashable]]
) -> "CacheCollateReturnType":
Expand Down Expand Up @@ -120,13 +138,16 @@ def is_filled(self) -> bool:
"""Check if cache already filled"""
raise NotImplementedError()

def fill_cache(self, keys: List[Hashable], data: "TensorInterchange") -> None:
def fill_cache(
self, keys: List[Hashable], data: "TensorInterchange", meta: List[Any]
) -> None:
"""Apply wrapped encoder to data and store processed data on
corresponding device.
Args:
keys: Hash keys which should be associated with resulting vectors
data: Tuple of keys and batches suitable for encoder
meta: List of batch meta information
"""
raise NotImplementedError()
Expand Down
8 changes: 5 additions & 3 deletions quaterion/train/cache/cache_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,14 +47,16 @@ def predict_step(
torch.Tensor: loss mock
"""
features, _labels = batch
features_data = features["data"]
features_meta = features["meta"]
for encoder_name, encoder in self.encoders.items():
if encoder_name not in features:
if encoder_name not in features_data:
continue
keys, encoder_features = features.get(encoder_name)
keys, encoder_features = features_data.get(encoder_name)
if len(keys) == 0:
# empty batch possible if all unique object already cached
continue
encoder.fill_cache(keys, encoder_features)
encoder.fill_cache(keys, encoder_features, features_meta[encoder_name])

return torch.Tensor([1])

Expand Down
16 changes: 14 additions & 2 deletions quaterion/train/cache/cache_train_collator.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from collections import defaultdict
from typing import Any, Dict, Hashable, List

from quaterion_models.types import CollateFnType
from quaterion_models.types import CollateFnType, MetaExtractorFnType

from quaterion.dataset.train_collator import TrainCollator
from quaterion.train.cache.cache_config import KeyExtractorType
Expand All @@ -15,11 +15,12 @@ def __init__(
self,
pre_collate_fn,
encoder_collates: Dict[str, "CollateFnType"],
meta_extractors: Dict[str, "MetaExtractorFnType"],
key_extractors: Dict[str, "KeyExtractorType"],
cachable_encoders: List[str],
mode: CacheMode,
):
super().__init__(pre_collate_fn, encoder_collates)
super().__init__(pre_collate_fn, encoder_collates, meta_extractors)
self.cachable_encoders = cachable_encoders
self.mode = mode
self.key_extractors = key_extractors
Expand Down Expand Up @@ -78,3 +79,14 @@ def pre_encoder_collate(
return keys

raise NotImplementedError(f"Cache mode {self.mode} is not implemented")

def process_meta(self, meta: Dict[str, List]) -> Any:
    """Shape the collected meta according to the current cache mode.

    Args:
        meta: mapping from encoder name to the meta extracted for that encoder

    Returns:
        Any: the untouched per-encoder mapping in FILL mode, or the parent
            class result in TRAIN mode

    Raises:
        NotImplementedError: for any other cache mode
    """
    if self.mode == CacheMode.FILL:
        # Cache filling needs to know which meta belongs to which encoder,
        # so the per-encoder mapping is passed through as is.
        return meta

    if self.mode == CacheMode.TRAIN:
        # Training uses the default behavior of the parent collator.
        return super().process_meta(meta)

    raise NotImplementedError(f"Cache mode {self.mode} is not implemented")
28 changes: 24 additions & 4 deletions quaterion/train/cache/in_memory_cache_encoder.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import pickle
from typing import Hashable, List
from typing import Any, Hashable, List, Union

import torch
from quaterion_models.encoders import Encoder
Expand All @@ -23,6 +23,7 @@ def __init__(
):
super().__init__(encoder)
self._cache = None
self._meta_cache = {}
self._offset_map = {}
self._cache_type = cache_type
self._original_device = None
Expand Down Expand Up @@ -53,6 +54,16 @@ def forward(self, batch: "TensorInterchange") -> Tensor:

return embeddings

def cache_extract_meta(self, batch: Union[tuple, List[Hashable]]) -> List[dict]:
    """Provide meta information for a batch in either cache phase.

    Args:
        batch: a `(keys, features)` tuple during cache filling, otherwise
            a list of keys (assumed to be the training phase)

    Returns:
        List[dict]: meta produced by the wrapped encoder's extractor for a
            tuple batch, or the stored meta looked up per key otherwise
    """
    if not isinstance(batch, tuple):
        # Only keys are provided here: serve meta remembered in the cache.
        return [self._meta_cache[key] for key in batch]

    # Cache filling phase: keys are not needed to compute the meta.
    _keys, features = batch
    extract = self._encoder.get_meta_extractor()
    return extract(features)

def get_collate_fn(self) -> "CollateFnType":
"""Provides function that converts raw data batch into suitable input.
Expand All @@ -65,7 +76,9 @@ def get_collate_fn(self) -> "CollateFnType":
def is_filled(self) -> bool:
return self._cache is not None

def fill_cache(self, keys: List[Hashable], data: "TensorInterchange") -> None:
def fill_cache(
self, keys: List[Hashable], data: "TensorInterchange", meta: List[Any]
) -> None:
embeddings = self._encoder(data)
if self.cache_type == CacheType.CPU:
embeddings = embeddings.to("cpu")
Expand All @@ -76,21 +89,28 @@ def fill_cache(self, keys: List[Hashable], data: "TensorInterchange") -> None:
self._offset_map[key] = len(self._offset_map)
self._tmp.append(embeddings)

for key, meta_item in zip(keys, meta):
self._meta_cache[key] = meta_item

def finish_fill(self):
self._cache = torch.cat(self._tmp)
self._tmp = []

def reset_cache(self) -> None:
"""Resets cache."""
self._cache = None
self._meta_cache = {}
self._offset_map = {}
self._tmp = []

def save_cache(self, path):
pickle.dump([self._cache.to("cpu"), self._offset_map], open(path, "wb"))
pickle.dump(
[self._cache.to("cpu"), self._offset_map, self._meta_cache],
open(path, "wb"),
)

def load_cache(self, path):
self._cache, self._offset_map = pickle.load(open(path, "rb"))
self._cache, self._offset_map, self._meta_cache = pickle.load(open(path, "rb"))
if self.cache_type != CacheType.CPU:
device = self._encoder_device()
self._cache = self._cache.to(device)
3 changes: 3 additions & 0 deletions quaterion/train/cache_mixin.py
Original file line number Diff line number Diff line change
Expand Up @@ -182,6 +182,9 @@ def _cache(
encoder_collates={
name: encoder.get_collate_fn() for name, encoder in encoders.items()
},
meta_extractors={
name: encoder.get_meta_extractor() for name, encoder in encoders.items()
},
key_extractors=key_extractors,
cachable_encoders=list(cache_encoders.keys()),
mode=CacheMode.TRAIN,
Expand Down
5 changes: 5 additions & 0 deletions quaterion/train/trainable_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -452,10 +452,15 @@ def setup_dataloader(self, dataloader: SimilarityDataLoader):
(key, encoder.get_collate_fn())
for key, encoder in self.model.encoders.items()
)
meta_extractors = dict(
(key, encoder.get_meta_extractor())
for key, encoder in self.model.encoders.items()
)

collator = TrainCollator(
pre_collate_fn=dataloader.collate_fn,
encoder_collates=encoder_collate_fns,
meta_extractors=meta_extractors,
)

dataloader.collate_fn = collator
17 changes: 11 additions & 6 deletions tests/cache/test_cache_dataloader.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,15 +40,20 @@ def test_cache_dataloader():
assert isinstance(encoder, InMemoryCacheEncoder)
assert len(encoder._cache) == len(dataset.data) * 2

cached_ids, labels = next(iter(dataloader))
print("cached_batch: ", cached_ids)
cached_batch, labels = next(iter(dataloader))
cached_data = cached_batch["data"]
cached_meta = cached_batch["meta"]
print("cached_batch: ", cached_data)
print("cached_meta: ", cached_meta)

assert len(cached_data[DEFAULT_ENCODER_KEY]) == len(cached_meta)

# check that batch for cache contains only IDs
assert isinstance(cached_ids[DEFAULT_ENCODER_KEY], list)
assert len(cached_ids[DEFAULT_ENCODER_KEY]) == batch_size * 2
assert isinstance(cached_ids[DEFAULT_ENCODER_KEY][0], int)
assert isinstance(cached_data[DEFAULT_ENCODER_KEY], list)
assert len(cached_data[DEFAULT_ENCODER_KEY]) == batch_size * 2
assert isinstance(cached_data[DEFAULT_ENCODER_KEY][0], int)

cached_result = cache_trainable_model.model.forward(cached_ids)
cached_result = cache_trainable_model.model.forward(cached_batch)
print("cached_result: ", cached_result)

# Same, without cache
Expand Down
52 changes: 49 additions & 3 deletions tests/model_fixtures.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
from typing import Dict, Union
from typing import Any, Dict, List, Union

import torch
from quaterion_models.encoders import Encoder
from quaterion_models.heads import EncoderHead, GatedHead
from quaterion_models.encoders import Encoder, SwitchEncoder
from quaterion_models.heads import EmptyHead, EncoderHead, GatedHead, SwitchHead
from quaterion_models.types import TensorInterchange
from torch import Tensor
from torch.utils.data import Dataset
Expand Down Expand Up @@ -30,6 +30,15 @@ def __init__(self):
"mandarin ".strip(): torch.tensor([0.0, 1.0, 0.0, 0.0, 0.0, 0.0]),
}

@classmethod
def extract_meta(cls, batch: List[Any]) -> List[dict]:
    """Build one meta dict per record, holding the record's first element.

    Args:
        batch: records that support indexing (e.g. strings)

    Returns:
        List[dict]: one `{"first_letter": ...}` dict per record, in order
    """
    meta = []
    for name in batch:
        meta.append({"first_letter": name[0]})
    return meta

@property
def trainable(self) -> bool:
return False
Expand Down Expand Up @@ -94,3 +103,40 @@ def configure_head(self, input_embedding_size: int) -> EncoderHead:

def configure_optimizers(self):
return torch.optim.Adam(params=self.model.parameters(), lr=0.001)


class FakeEncoderWithNegativeOutput(FakeEncoder):
    """Variant of `FakeEncoder` whose output embeddings are negated."""

    def forward(self, batch: TensorInterchange) -> Tensor:
        """Stack the tensors stored for each word of the batch and negate them."""
        stacked = torch.stack([self.tensors[word] for word in batch])
        return -stacked

    @classmethod
    def load(cls, input_path: str) -> "Encoder":
        """Return a fresh instance; `input_path` is not read."""
        return FakeEncoderWithNegativeOutput()


class FakeSwitchEncoder(SwitchEncoder):
    """Switch encoder fixture that picks an option from the record's prefix."""

    @classmethod
    def encoder_selection(cls, record: Any) -> str:
        """Return "positive" for records starting with "m", else "negative"."""
        return "positive" if record.startswith("m") else "negative"


class FakeTrainableModelWithSwitchEncoder(FakeTrainableModel):
    """Trainable model fixture combining `FakeSwitchEncoder` with a `SwitchHead`."""

    def configure_encoders(self) -> Union[Encoder, Dict[str, Encoder]]:
        """Create a switch encoder with a "positive" and a "negative" option."""
        options = {
            "positive": FakeEncoder(),
            "negative": FakeEncoderWithNegativeOutput(),
        }
        return FakeSwitchEncoder(options=options)

    def configure_head(self, input_embedding_size: int) -> EncoderHead:
        """Create a switch head with one gated head per encoder option."""
        heads = {
            "positive": GatedHead(input_embedding_size),
            "negative": GatedHead(input_embedding_size),
        }
        return SwitchHead(heads, input_embedding_size)
22 changes: 22 additions & 0 deletions tests/test_switch_encoder.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
import pytorch_lightning as pl

from quaterion import Quaterion
from quaterion.dataset import PairsSimilarityDataLoader

from .model_fixtures import FakePairDataset, FakeTrainableModelWithSwitchEncoder


class TestSwitchEncoder:
    """Smoke test: a model with a switch encoder and switch head can be fitted."""

    def test_switch_encoder_and_head(self):
        """Run a single CPU epoch of `Quaterion.fit` on the fake pair dataset."""
        model = FakeTrainableModelWithSwitchEncoder()
        dataset = FakePairDataset()
        data_loader = PairsSimilarityDataLoader(dataset, batch_size=3)

        trainer_args = Quaterion.trainer_defaults(model, data_loader)
        trainer_args["callbacks"].pop(1)  # remove EarlyStopping callback
        trainer_args.update(accelerator="cpu", max_epochs=1)

        trainer = pl.Trainer(**trainer_args)
        Quaterion.fit(
            trainable_model=model,
            trainer=trainer,
            train_dataloader=data_loader,
        )

0 comments on commit f2be2a4

Please sign in to comment.