Remove workarounds no longer relevant in 4.10.2 release of Transformers, closes #110
davidmezzetti committed Sep 10, 2021
1 parent b8aef46 commit 652a494
Showing 4 changed files with 8 additions and 41 deletions.
setup.py (2 changes: 1 addition & 1 deletion)
@@ -5,7 +5,7 @@
DESCRIPTION = f.read()

# Required dependencies
install = ["faiss-cpu>=1.7.1.post2", "numpy>=1.18.4", "torch>=1.6.0", "transformers>=4.8.2"]
install = ["faiss-cpu>=1.7.1.post2", "numpy>=1.18.4", "torch>=1.6.0", "transformers>=4.10.2"]

# Optional dependencies
extras = {}
src/python/txtai/models/onnx.py (17 changes: 7 additions & 10 deletions)
@@ -68,16 +68,13 @@ def autoadd(self, mapping, key, value):
"""

# pylint: disable=W0212
-        if hasattr(mapping, "_config_mapping"):
-            Params = namedtuple("Params", ["config", "model"])
-            params = Params(key, value)
-
-            mapping._modules[key] = params
-            mapping._config_mapping[key] = "config"
-            mapping._reverse_config_mapping[value] = key
-            mapping._model_mapping[key] = "model"
-        else:
-            mapping[key] = value
+        Params = namedtuple("Params", ["config", "model"])
+        params = Params(key, value)
+
+        mapping._modules[key] = params
+        mapping._config_mapping[key] = "config"
+        mapping._reverse_config_mapping[value] = key
+        mapping._model_mapping[key] = "model"

def forward(self, **inputs):
"""
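Note: the simplified autoadd above assumes every mapping passed in exposes the private registries (_modules, _config_mapping, _reverse_config_mapping, _model_mapping) that the Transformers auto-model mappings provide as of this release, so the plain-dict fallback is dropped. A minimal sketch of that registration pattern, using a hypothetical SimpleNamespace stand-in rather than a real Transformers mapping:

from collections import namedtuple
from types import SimpleNamespace

# Hypothetical stand-in for an auto-model mapping: it only needs to expose
# the four private dicts that autoadd writes into
mapping = SimpleNamespace(_modules={}, _config_mapping={}, _reverse_config_mapping={}, _model_mapping={})

# Register a (config, model) pair under a key, mirroring the new autoadd logic
Params = namedtuple("Params", ["config", "model"])
mapping._modules["key"] = Params("key", "value")
mapping._config_mapping["key"] = "config"
mapping._reverse_config_mapping["value"] = "key"
mapping._model_mapping["key"] = "model"

assert "key" in mapping._model_mapping and "key" in mapping._modules

The dict-backed mapping test removed from test/python/testonnx.py below reflects the same assumption.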
src/python/txtai/pipeline/labels.py (24 changes: 0 additions & 24 deletions)
@@ -18,9 +18,6 @@ def __init__(self, path=None, quantize=False, gpu=True, model=None, dynamic=True
# Set if labels are dynamic (zero shot) or fixed (standard text classification)
self.dynamic = dynamic

-        # Save handle to pipeline tokenizer
-        self.tokenizer = self.pipeline.tokenizer

def __call__(self, text, labels=None, multilabel=False):
"""
Applies a text classifier to text. Returns a list of (id, score) sorted by highest score,
@@ -41,14 +38,8 @@ def __call__(self, text, labels=None, multilabel=False):
"""

if self.dynamic:
-            # Override tokenizer to set truncation parameter
-            self.pipeline.tokenizer = self.tokenize

# Run zero shot classification pipeline
results = self.pipeline(text, labels, multi_label=multilabel, truncation=True)

-            # Reset tokenizer
-            self.pipeline.tokenizer = self.tokenizer
else:
# Run text classification pipeline
results = self.textclassify(text, multilabel)
@@ -103,18 +94,3 @@ def labels(self):
"""

return list(self.pipeline.model.config.id2label.values())

-    def tokenize(self, *args, **kwargs):
-        """
-        Tokenization method that forces truncation=True.
-
-        Args:
-            args: arguments
-            kwargs: named arguments
-
-        Returns:
-            tokenized output
-        """
-
-        kwargs["truncation"] = True
-        return self.tokenizer(*args, **kwargs)
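With the tokenizer override and reset gone, the truncation=True argument on the zero shot pipeline call above is all that is needed on Transformers 4.10.2+. A rough usage sketch under that assumption; the model name and inputs are illustrative, not taken from this commit:

from transformers import pipeline

# Zero shot classification with truncation passed directly on the call,
# assuming the installed transformers release honors it without a tokenizer override
classifier = pipeline("zero-shot-classification", model="facebook/bart-large-mnli")

result = classifier(
    "txtai executes machine-learning workflows to transform data",
    ["technology", "sports", "politics"],
    multi_label=False,
    truncation=True,
)

# Highest scoring label first
print(result["labels"][0], result["scores"][0])

This mirrors how Labels.__call__ now invokes the underlying zero shot pipeline directly for the dynamic path.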
test/python/testonnx.py (6 changes: 0 additions & 6 deletions)
@@ -58,12 +58,6 @@ def __init__(self):
self.assertTrue("key" in mapping._model_mapping)
self.assertTrue("key" in mapping._modules)

-        # Test mapping backed by dict
-        mapping = {}
-        model.autoadd(mapping, "key", "value")
-
-        self.assertTrue("key" in mapping)

def testDefault(self):
"""
Test exporting an ONNX model with default parameters
