diff --git a/examples/mnist/lenet-stn.jl b/examples/mnist/lenet-stn.jl index 23ca9de3f..60f2def68 100644 --- a/examples/mnist/lenet-stn.jl +++ b/examples/mnist/lenet-stn.jl @@ -12,16 +12,16 @@ data = mx.Variable(:data) # the localisation network in lenet-stn # it will increase acc about more than 1%, when num-epoch >=15 # The localization net just takes the data as input and must output a vector in R^n -loc_net = @mx.chain mx.Convolution(data, num_filter=10, kernel=(5, 5), stride=(2,2)) => +loc_net = @mx.chain mx.Convolution(data, num_filter=30, kernel=(5, 5), stride=(2, 2)) => mx.Activation(act_type=:relu) => - mx.Pooling( kernel=(2, 2), stride=(2, 2), pool_type=:max) => - mx.Convolution( num_filter=10, kernel=(3, 3), stride=(2,2), pad=(1, 1)) => + mx.Pooling(kernel=(2, 2), stride=(2, 2), pool_type=:max) => + mx.Convolution(num_filter=60, kernel=(3, 3), stride=(2, 2), pad=(1, 1)) => mx.Activation(act_type=:relu) => - mx.Pooling( global_pool=true, kernel=(2, 2), pool_type=:avg) => + mx.Pooling(global_pool=true, kernel=(2, 2), pool_type=:avg) => mx.Flatten() => mx.FullyConnected(num_hidden=6, name=:stn_loc) -data=mx.SpatialTransformer(data,loc_net, target_shape = (28,28), transform_type="affine", sampler_type="bilinear") +data = mx.SpatialTransformer(data, loc_net, target_shape=(28, 28), transform_type="affine", sampler_type="bilinear") # first conv conv1 = @mx.chain mx.Convolution(data, kernel=(5,5), num_filter=20) => @@ -57,8 +57,13 @@ train_provider, eval_provider = get_mnist_providers(batch_size; flat=false) model = mx.FeedForward(lenet, context=mx.cpu()) # optimizer -optimizer = mx.ADAM(lr=0.01, weight_decay=0.00001) +optimizer = mx.SGD(lr=0.1, momentum=.9) # fit parameters -initializer=mx.XavierInitializer(distribution = mx.xv_uniform, regularization = mx.xv_avg, magnitude = 1) -mx.fit(model, optimizer, train_provider, n_epoch=20, eval_data=eval_provider,initializer=initializer) +initializer = mx.XavierInitializer(distribution=mx.xv_normal, + regularization=mx.xv_in, + magnitude=2) +mx.fit(model, optimizer, train_provider, + n_epoch=20, + eval_data=eval_provider, + initializer=initializer)