Skip to content

Commit

Permalink
to_tensor only for requested fields. Use add_to_tensor_inputs to …
Browse files Browse the repository at this point in the history
…add custom fields.
  • Loading branch information
Riccorl committed Mar 16, 2021
1 parent a2b05a1 commit 1e99572
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 21 deletions.
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

setuptools.setup(
name="transformer_embedder", # Replace with your own username
version="1.6.2",
version="1.6.3",
author="Riccardo Orlando",
author_email="[email protected]",
description="Word level transformer based embeddings",
Expand Down
39 changes: 19 additions & 20 deletions transformer_embedder/tokenizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -418,9 +418,27 @@ def add_to_tensor_inputs(self, names: Union[str, set]):
"""
if isinstance(names, str):
keys = {names}
names = {names}
self.to_tensor_inputs |= names

def to_tensor(self, batch: Union[List[dict], dict]) -> Dict[str, torch.Tensor]:
    """
    Return the batch in input as PyTorch tensors.

    Only the fields whose name is in ``self.to_tensor_inputs`` are converted
    with ``torch.as_tensor``; every other field is returned unchanged.

    Args:
        batch (List[dict] or dict): batch in input. A list of per-sample
            dicts is first merged into a single dict of lists, using the
            keys of the first sample.

    Returns:
        Dict: the batch with the requested fields converted to tensors
    """
    if isinstance(batch, list):
        # A list has no ``.items()``: collate the samples into a dict of
        # lists first. An empty list yields an empty batch.
        batch = (
            {key: [sample[key] for sample in batch] for key in batch[0]}
            if batch
            else {}
        )
    # convert to tensor only the requested fields
    return {
        k: torch.as_tensor(v) if k in self.to_tensor_inputs else v
        for k, v in batch.items()
    }

def _load_spacy(self) -> spacy.tokenizer.Tokenizer:
"""
Download and load spacy model.
Expand Down Expand Up @@ -459,25 +477,6 @@ def _clean_output(output: Union[List, Dict]) -> Dict:
output = {k: [d[k] for d in output] for k in output[0]}
return output

@staticmethod
def to_tensor(batch: Union[List[dict], dict]) -> Dict[str, torch.Tensor]:
"""
Return a the batch in input as Pytorch tensors.
Args:
batch (List[dict] or dict): batch in input
Returns:
Dict: the batch as tensor
"""
# convert to tensor
batch = {
k: torch.as_tensor(v) if isinstance(v[0], list) else v
for k, v in batch.items()
}
return batch

@staticmethod
def _get_token_type_id(config) -> int:
"""
Expand Down

0 comments on commit 1e99572

Please sign in to comment.