Skip to content

Commit

Permalink
support torch.sub/sign/abs in eager mode
Browse files Browse the repository at this point in the history
Summary:
As title
Achieve by `vizard.quantization.functional. FloatFunctional` and `vizard.quantization.prepare.prepare_eager`

Differential Revision: D48377683
  • Loading branch information
jiaxuzhu92 authored and facebook-github-bot committed Sep 18, 2023
1 parent d49077d commit 67d280a
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 8 deletions.
9 changes: 6 additions & 3 deletions d2go/quantization/fx.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,16 +10,19 @@

TORCH_VERSION: Tuple[int, ...] = tuple(int(x) for x in torch.__version__.split(".")[:2])
if TORCH_VERSION > (1, 10):
from torch.ao.quantization.quantize import convert
from torch.ao.quantization.quantize import convert, prepare, prepare_qat
from torch.ao.quantization.quantize_fx import convert_fx, prepare_fx, prepare_qat_fx
else:
from torch.quantization.quantize import convert
from torch.quantization.quantize import convert, prepare, prepare_qat
from torch.quantization.quantize_fx import convert_fx, prepare_fx, prepare_qat_fx


@fb_overwritable()
def get_prepare_fx_fn(cfg, is_qat):
return prepare_qat_fx if is_qat else prepare_fx
if cfg.QUANTIZATION.EAGER_MODE:
return prepare_qat if is_qat else prepare
else:
return prepare_qat_fx if is_qat else prepare_fx


@fb_overwritable()
Expand Down
7 changes: 2 additions & 5 deletions d2go/quantization/modeling.py
Original file line number Diff line number Diff line change
Expand Up @@ -352,11 +352,8 @@ def prepare_fake_quant_model(cfg, model, is_qat, example_input=None):
)
model = default_prepare_for_quant(cfg, model)
# NOTE: eager model needs to call prepare after `prepare_for_quant`
if is_qat:
torch.ao.quantization.prepare_qat(model, inplace=True)
else:
torch.ao.quantization.prepare(model, inplace=True)

prepare_fn = get_prepare_fx_fn(cfg, is_qat)
prepare_fn(model, inplace=True)
else:
# FX graph mode requires the model to be symbolically traceable, swap common
# modules like SyncBN to FX-friendly version.
Expand Down

0 comments on commit 67d280a

Please sign in to comment.