From 65d681f76aff840375c42ea46a4423a62a34f875 Mon Sep 17 00:00:00 2001 From: Talmaj Marinc Date: Fri, 19 Jun 2020 18:17:10 +0200 Subject: [PATCH] Add an example in a script and README file. --- README.md | 47 ++++++++++++++++++++++++++++++++++++++ examples/example.py | 55 +++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 102 insertions(+) create mode 100644 examples/example.py diff --git a/README.md b/README.md index 045e8d1..7860f99 100755 --- a/README.md +++ b/README.md @@ -23,6 +23,53 @@ pip install . to install deepCABAC extension. +## Examples + +### Encoding Pytorch model weights +``` +import deepCABAC +import torchvision.models as models + +model = models.resnet50(pretrained=True) +encoder = deepCABAC.Encoder() + +interv = 0.1 +stepsize = 15 +_lambda = 0. + +for name, param in model.state_dict().items(): + if '.weight' in name: + encoder.encodeWeightsRD( weights, interv, stepsize, _lambda ) + else: + encoder.encodeWeightsRD( weights, interv, stepsize + 4, _lambda ) + +stream = encoder.finish() +with open('weights.bin', 'wb') as f: + f.write(stream) +``` + +### Decoding Pytorch model weights +``` +import deepCABAC +import torchvision.models as models + +model = models.resnet50(pretrained=False) +decoder = deepCABAC.Decoder() + +with open('weights.bin', 'rb') as f: + stream = f.read() + +decoder.getStream(stream) +state_dict = model.state_dict() +for name in state_dict.keys(): + state_dict[name] = torch.tensor(decA.decodeWeights()) +decoder.finish() + +model.load_state_dict(state_dict) + +# evaluate(model) +``` + ### Debugging If you want to debug the module, on Ubuntu with gdb you can use: diff --git a/examples/example.py b/examples/example.py new file mode 100644 index 0000000..79bfe43 --- /dev/null +++ b/examples/example.py @@ -0,0 +1,55 @@ +import torch +import deepCABAC +import torchvision.models as models +import numpy as np +from tqdm import tqdm + + +def main(): + # encoding + model = models.resnet18(pretrained=True) + encoder = deepCABAC.Encoder() + + interv = 0.1 + stepsize = 2**(-0.5*15) + stepsize_other = 2**(-0.5*19) + _lambda = 0. + + for name, param in tqdm(model.state_dict().items()): + if '.num_batches_tracked' in name: + continue + param = param.cpu().numpy() + if '.weight' in name: + encoder.encodeWeightsRD(param, interv, stepsize, _lambda) + else: + encoder.encodeWeightsRD(param, interv, stepsize_other, _lambda) + + stream = encoder.finish().tobytes() + print("Compressed size: {:2f} MB".format(1e-6 * len(stream))) + with open('weights.bin', 'wb') as f: + f.write(stream) + + # decoding + model = models.resnet18(pretrained=False) + decoder = deepCABAC.Decoder() + + with open('weights.bin', 'rb') as f: + stream = f.read() + + decoder.getStream(np.frombuffer(stream, dtype=np.uint8)) + state_dict = model.state_dict() + + for name in tqdm(state_dict.keys()): + if '.num_batches_tracked' in name: + continue + param = decoder.decodeWeights() + state_dict[name] = torch.tensor(param) + decoder.finish() + + model.load_state_dict(state_dict) + + # evaluate(model) + + +if __name__ == '__main__': + main()