Skip to content

Commit

Permalink
example: update example of lenet-stn
Browse files Browse the repository at this point in the history
make its optimizer configured same as Python's

fix #369
  • Loading branch information
iblislin committed Dec 11, 2017
1 parent db09528 commit ebd7404
Showing 1 changed file with 13 additions and 8 deletions.
21 changes: 13 additions & 8 deletions examples/mnist/lenet-stn.jl
Original file line number Diff line number Diff line change
Expand Up @@ -12,16 +12,16 @@ data = mx.Variable(:data)
# the localisation network in lenet-stn
# it will increase acc about more than 1%, when num-epoch >=15
# The localization net just takes the data as input and must output a vector in R^n
loc_net = @mx.chain mx.Convolution(data, num_filter=10, kernel=(5, 5), stride=(2,2)) =>
loc_net = @mx.chain mx.Convolution(data, num_filter=30, kernel=(5, 5), stride=(2, 2)) =>
mx.Activation(act_type=:relu) =>
mx.Pooling( kernel=(2, 2), stride=(2, 2), pool_type=:max) =>
mx.Convolution( num_filter=10, kernel=(3, 3), stride=(2,2), pad=(1, 1)) =>
mx.Pooling(kernel=(2, 2), stride=(2, 2), pool_type=:max) =>
mx.Convolution(num_filter=60, kernel=(3, 3), stride=(2, 2), pad=(1, 1)) =>
mx.Activation(act_type=:relu) =>
mx.Pooling( global_pool=true, kernel=(2, 2), pool_type=:avg) =>
mx.Pooling(global_pool=true, kernel=(2, 2), pool_type=:avg) =>
mx.Flatten() =>
mx.FullyConnected(num_hidden=6, name=:stn_loc)

data=mx.SpatialTransformer(data,loc_net, target_shape = (28,28), transform_type="affine", sampler_type="bilinear")
data = mx.SpatialTransformer(data, loc_net, target_shape=(28, 28), transform_type="affine", sampler_type="bilinear")

# first conv
conv1 = @mx.chain mx.Convolution(data, kernel=(5,5), num_filter=20) =>
Expand Down Expand Up @@ -57,8 +57,13 @@ train_provider, eval_provider = get_mnist_providers(batch_size; flat=false)
model = mx.FeedForward(lenet, context=mx.cpu())

# optimizer
optimizer = mx.ADAM(lr=0.01, weight_decay=0.00001)
optimizer = mx.SGD(lr=0.1, momentum=.9)

# fit parameters
initializer=mx.XavierInitializer(distribution = mx.xv_uniform, regularization = mx.xv_avg, magnitude = 1)
mx.fit(model, optimizer, train_provider, n_epoch=20, eval_data=eval_provider,initializer=initializer)
initializer = mx.XavierInitializer(distribution=mx.xv_normal,
regularization=mx.xv_in,
magnitude=2)
mx.fit(model, optimizer, train_provider,
n_epoch=20,
eval_data=eval_provider,
initializer=initializer)

0 comments on commit ebd7404

Please sign in to comment.