Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix accuracy issues for both Keras (all models) and PyTorch (inception_v3) #5

Open
wants to merge 4 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 1 addition & 3 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,7 @@ Benchmarks for **every** pre-trained model in PyTorch and Keras-Tensorflow. Benc

## Why this is helpful

Combining Keras and PyTorch benchmarks into a single framework lets researchers decide which platform is best for a given model. For example `resnet` architectures perform better in PyTorch and `inception` architectures perform better in Keras (see below). These benchmarks serve as a standard from which to start new projects or debug current implementations.

For researchers exploring Keras and PyTorch models, these benchmarks serve as a standard from which to start new projects or debug current implementations.
Combining Keras and PyTorch benchmarks into a single framework lets researchers decide which platform is best for a given model. For example `resnet` architectures perform better in PyTorch and `inception` architectures perform better in Keras (see below). These benchmarks serve as a standard from which to start new projects or debug current implementations.

Many researchers struggle with reproducible accuracy benchmarks of pre-trained Keras (Tensorflow) models on ImageNet. Examples of issues are [here1](https://github.com/keras-team/keras/issues/10040), [here2](https://github.com/keras-team/keras/issues/10979), [here3](http://blog.datumbox.com/the-batch-normalization-layer-of-keras-is-broken/), [here4](https://github.com/keras-team/keras/issues/8672), and [here5](https://github.com/keras-team/keras/issues/7848).

Expand Down
47 changes: 40 additions & 7 deletions imagenet_keras_get_predictions.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
import numpy as np
import os
import sys
from PIL import Image

# Use PyTorch/torchvision for dataloading (more reliable/faster)
from torchvision import datasets
Expand Down Expand Up @@ -149,11 +150,11 @@ def main(args = parser.parse_args()):
# Create output directory if it does not exist
if not os.path.exists(args.output_dir):
os.makedirs(args.output_dir)

# Grab imagenet data
val_dataset = datasets.ImageFolder(args.val_dir)
img_paths, labels = (list(t) for t in zip(*val_dataset.imgs))

# Run forward pass inference on all models for all examples in val set.
models = keras_models if args.model is None else [args.model]
for model in models:
Expand All @@ -176,6 +177,23 @@ def main(args = parser.parse_args()):

# In[9]:

def crop_center(img, target_size):
    """Crop a ``target_size`` window around the middle of ``img``.

    ``target_size`` is indexed in the network's (H, W) order: element 0 is
    the crop height and element 1 is the crop width. ``img`` must expose
    ``size`` as a (width, height) pair and a ``crop`` method that accepts a
    (left, top, right, bottom) box; whatever ``crop`` returns is returned.
    """
    img_w, img_h = img.size
    crop_h, crop_w = target_size[0], target_size[1]
    # Offsets are measured from the floored image centre, so the window's
    # top-left corner may be negative when the image is smaller than the crop.
    x0 = img_w // 2 - crop_w // 2
    y0 = img_h // 2 - crop_h // 2
    box = (x0, y0, x0 + crop_w, y0 + crop_h)
    return img.crop(box)


def shortest_edge_scale(img, target_size, scale):
    """Resize ``img`` relative to its shortest edge before center-cropping.

    ``target_size`` is indexed in the network's (H, W) order. Each image
    dimension is multiplied by ``target / scale`` (truncated to an int) and
    then floor-divided by the shortest edge of the image, so for a square
    ``target_size`` the shortest edge ends up at ``floor(target / scale)``
    pixels. ``img`` must expose ``size`` as (width, height) and a ``resize``
    method; the image is resampled with ``Image.BILINEAR``.
    """
    width, height = img.size
    target_h, target_w = target_size[0], target_size[1]
    short_side = min(width, height)
    # Truncate first, then floor-divide: keeps the exact integer results of
    # the original arithmetic.
    new_w = int(width * target_w / scale) // short_side
    new_h = int(height * target_h / scale) // short_side
    return img.resize((new_w, new_h), resample=Image.BILINEAR)


def process_model(
model_name,
Expand All @@ -201,16 +219,31 @@ def process_model(

# Create Keras model
model = Model(weights='imagenet')

# Preprocessing and Forward pass through validation set.
probs = []
inputs = []
batch_size = 64
for i, img_path in enumerate(img_paths):
if i % 32 == 0:
img = image.load_img(img_path, target_size=None)
img = shortest_edge_scale(img, img_size, 0.875)
img = crop_center(img, img_size)
img = np.expand_dims(image.img_to_array(img), axis=0)
inputs.append(img)

current_batch = 0
if i % (batch_size + 1) == 0:
current_batch = batch_size
elif i == len(img_paths) - 1:
current_batch = len(img_paths) % batch_size

if current_batch:
inputs = np.concatenate(inputs, axis=0)
probs.append(model.predict_on_batch(preprocess_model(inputs)))
print("\r{} completed: {:.2%}".format(model_name, i / len(img_paths)), end="")
sys.stdout.flush()
img = image.load_img(img_path, target_size=img_size)
img = np.expand_dims(image.img_to_array(img), axis=0)
probs.append(model.predict(preprocess_model(img)))
inputs = []

probs = np.vstack(probs)
if save_all_probs:
np.save(wfn_base + "probs.npy", probs.astype(np.float16))
Expand Down
4 changes: 2 additions & 2 deletions imagenet_pytorch_get_predictions.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ def main(args = parser.parse_args()):
dataloaders = {}
for img_size in [224, 299]:
val_transform = transforms.Compose([
transforms.Resize(256),
transforms.Resize(img_size // 0.875),
transforms.CenterCrop(img_size),
transforms.ToTensor(),
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
Expand Down Expand Up @@ -132,7 +132,7 @@ def process_model(
model = model.to(device)
wfn_base = os.path.join(out_dir, model_name + "_pytorch_imagenet_")
probs, labels = [], []
loader = dataloaders[299] if model_name is "inception_v3" else dataloaders[224]
loader = dataloaders[299] if model_name == "inception_v3" else dataloaders[224]

# Inference, with no gradient changing
model.eval() # set model to inference mode (not train mode)
Expand Down