Skip to content

Commit

Permalink
Add an example in a script and README file.
Browse files Browse the repository at this point in the history
  • Loading branch information
Talmaj Marinc committed Jun 19, 2020
1 parent 4df9b9a commit 65d681f
Show file tree
Hide file tree
Showing 2 changed files with 102 additions and 0 deletions.
47 changes: 47 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,53 @@ pip install .

to install deepCABAC extension.

## Examples

### Encoding Pytorch model weights
```
import deepCABAC
import torchvision.models as models
model = models.resnet50(pretrained=True)
encoder = deepCABAC.Encoder()
interv = 0.1
stepsize = 15
_lambda = 0.
for name, param in model.state_dict().items():
if '.weight' in name:
encoder.encodeWeightsRD( weights, interv, stepsize, _lambda )
else:
encoder.encodeWeightsRD( weights, interv, stepsize + 4, _lambda )
stream = encoder.finish()
with open('weights.bin', 'wb') as f:
f.write(stream)
```

### Decoding Pytorch model weights
```
import deepCABAC
import torchvision.models as models
model = models.resnet50(pretrained=False)
decoder = deepCABAC.Decoder()
with open('weights.bin', 'rb') as f:
stream = f.read()
decoder.getStream(stream)
state_dict = model.state_dict()
for name in state_dict.keys():
state_dict[name] = torch.tensor(decA.decodeWeights())
decoder.finish()
model.load_state_dict(state_dict)
# evaluate(model)
```

### Debugging

If you want to debug the module, on Ubuntu with gdb you can use:
Expand Down
55 changes: 55 additions & 0 deletions examples/example.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
import torch
import deepCABAC
import torchvision.models as models
import numpy as np
from tqdm import tqdm


def main():
# encoding
model = models.resnet18(pretrained=True)
encoder = deepCABAC.Encoder()

interv = 0.1
stepsize = 2**(-0.5*15)
stepsize_other = 2**(-0.5*19)
_lambda = 0.

for name, param in tqdm(model.state_dict().items()):
if '.num_batches_tracked' in name:
continue
param = param.cpu().numpy()
if '.weight' in name:
encoder.encodeWeightsRD(param, interv, stepsize, _lambda)
else:
encoder.encodeWeightsRD(param, interv, stepsize_other, _lambda)

stream = encoder.finish().tobytes()
print("Compressed size: {:2f} MB".format(1e-6 * len(stream)))
with open('weights.bin', 'wb') as f:
f.write(stream)

# decoding
model = models.resnet18(pretrained=False)
decoder = deepCABAC.Decoder()

with open('weights.bin', 'rb') as f:
stream = f.read()

decoder.getStream(np.frombuffer(stream, dtype=np.uint8))
state_dict = model.state_dict()

for name in tqdm(state_dict.keys()):
if '.num_batches_tracked' in name:
continue
param = decoder.decodeWeights()
state_dict[name] = torch.tensor(param)
decoder.finish()

model.load_state_dict(state_dict)

# evaluate(model)


if __name__ == '__main__':
main()

0 comments on commit 65d681f

Please sign in to comment.