Skip to content
This repository has been archived by the owner on Dec 21, 2023. It is now read-only.

Commit

Permalink
Updated Initialization for Style Transfer (#2988)
Browse files Browse the repository at this point in the history
* Updated Initialization for Style Transfer

* Xavier to Uniform

* Updated initializer to be controlled by a flag, fixed a bug identified in testing on bolt

* Black reformatting
  • Loading branch information
abhishekpratapa authored Feb 10, 2020
1 parent 43ee1d1 commit 38f12b2
Show file tree
Hide file tree
Showing 6 changed files with 101 additions and 23 deletions.
13 changes: 13 additions & 0 deletions src/ml/neural_net/weight_init.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,19 @@ void xavier_weight_initializer::operator()(float* first_weight,
}
}

// Builds an initializer that samples from a uniform distribution constructed
// from (lower_bound, upper_bound), drawing randomness from a caller-owned
// engine.
//
// Only a reference to *random_engine is stored: the pointer is dereferenced
// here, so it must be non-null, and the engine must outlive this object.
uniform_weight_initializer::uniform_weight_initializer(
    float lower_bound, float upper_bound, std::mt19937* random_engine)
    : dist_(lower_bound, upper_bound), random_engine_(*random_engine) {}

// Overwrites every element of the half-open range [first_weight, last_weight)
// with a fresh sample drawn from dist_ using the referenced random engine.
// An empty range (first_weight == last_weight) writes nothing.
void uniform_weight_initializer::operator()(float* first_weight,
                                            float* last_weight) {
  float* cursor = first_weight;
  while (cursor != last_weight) {
    // Each draw advances the state of the shared engine.
    *cursor = dist_(random_engine_);
    ++cursor;
  }
}

// Records the given value in the scalar_ member; the constructor does no
// other work.
scalar_weight_initializer::scalar_weight_initializer(float scalar)
: scalar_(scalar) {}

Expand Down
26 changes: 26 additions & 0 deletions src/ml/neural_net/weight_init.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,32 @@ class xavier_weight_initializer {
std::mt19937& random_engine_;
};


/**
* Callable weight initializer that fills a range of floats with samples from
* a uniform distribution, using an externally owned random engine.
*/
class uniform_weight_initializer {
public:

/**
* Creates a weight initializer that performs uniform initialization.
*
* \param lower_bound The lower bound of the uniform distribution to be sampled
* \param upper_bound The upper bound of the uniform distribution to be sampled
*     (std::uniform_real_distribution requires lower_bound <= upper_bound)
* \param random_engine The random number generator to use. It is stored by
*     reference, so it must be non-null and must remain valid for the
*     lifetime of this instance.
*/
uniform_weight_initializer(float lower_bound, float upper_bound,
std::mt19937* random_engine);

/**
* Sets each value in [first_weight, last_weight) to an independent sample
* drawn uniformly at random from the range [lower_bound, upper_bound).
*/
void operator()(float* first_weight, float* last_weight);

private:

// Distribution configured with the bounds passed to the constructor.
std::uniform_real_distribution<float> dist_;
// Non-owning reference to the caller's engine; sampling advances its state.
std::mt19937& random_engine_;
};

struct scalar_weight_initializer {
/**
* Creates a weight initializer that initializes all of the weights to a
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -247,6 +247,7 @@ def create(
options["num_styles"] = len(style_dataset)
options["resnet_mlmodel_path"] = pretrained_resnet_model.get_model_path("coreml")
options["vgg_mlmodel_path"] = pretrained_vgg16_model.get_model_path("coreml")
options["pretrained_weights"] = params["pretrained_weights"]

model.train(style_dataset[style_feature], content_dataset[content_feature], options)
return StyleTransfer(model_proxy=model, name=name)
Expand Down
19 changes: 17 additions & 2 deletions src/toolkits/style_transfer/style_transfer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -686,6 +686,13 @@ void style_transfer::init_train(gl_sarray style, gl_sarray content,
}
size_t num_styles = num_styles_iter->second;

// Consume the optional "pretrained_weights" flag so that it is not passed on
// with the remaining options. The flag defaults to false when absent.
bool pretrained_weights = false;
auto pretrained_weights_iter = opts.find("pretrained_weights");
if (pretrained_weights_iter != opts.end()) {
  pretrained_weights = pretrained_weights_iter->second;
  // Erase only when the key was found: calling erase() with the end()
  // iterator is undefined behavior for standard associative containers.
  opts.erase(pretrained_weights_iter);
}

init_options(opts);

if (read_state<flexible_type>("random_seed") == FLEX_UNDEFINED) {
Expand All @@ -694,9 +701,11 @@ void style_transfer::init_train(gl_sarray style, gl_sarray content,
add_or_update_state({{"random_seed", random_seed}});
}

int random_seed = read_state<int>("random_seed");

m_training_data_iterator =
create_iterator(content, style, /* repeat */ true,
/* training */ true, static_cast<int>(num_styles));
/* training */ true, random_seed);

m_training_compute_context = create_compute_context();
if (m_training_compute_context == nullptr) {
Expand All @@ -709,7 +718,13 @@ void style_transfer::init_train(gl_sarray style, gl_sarray content,
{"styles", style_sframe_with_index(style)},
{"num_content_images", content.size()}});

m_resnet_spec = init_resnet(resnet_mlmodel_path, num_styles);
// TODO: change to include random seed.
if (pretrained_weights) {
m_resnet_spec = init_resnet(resnet_mlmodel_path, num_styles);
} else {
m_resnet_spec = init_resnet(num_styles, random_seed);
}

m_vgg_spec = init_vgg_16(vgg_mlmodel_path);

float_array_map weight_params = m_resnet_spec->export_params_view();
Expand Down
62 changes: 42 additions & 20 deletions src/toolkits/style_transfer/style_transfer_model_definition.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@

#include <toolkits/style_transfer/style_transfer_model_definition.hpp>

#include <random>

#include <ml/neural_net/weight_init.hpp>
#include <toolkits/coreml_export/mlmodel_include.hpp>

Expand All @@ -20,14 +22,33 @@ using CoreML::Specification::NeuralNetworkLayer;
using turi::neural_net::float_array_map;
using turi::neural_net::model_spec;
using turi::neural_net::scalar_weight_initializer;
using turi::neural_net::uniform_weight_initializer;
using turi::neural_net::weight_initializer;
using turi::neural_net::zero_weight_initializer;


using padding_type = model_spec::padding_type;

namespace {

constexpr float LOWER_BOUND = -0.07;
constexpr float UPPER_BOUND = 0.07;

// TODO: refactor code to be more readable with loops
void define_resnet(model_spec& nn_spec, size_t num_styles) {
void define_resnet(model_spec& nn_spec, size_t num_styles, bool initialize=false, int random_seed=0) {
std::mt19937 random_engine;
std::seed_seq seed_seq{random_seed};
random_engine = std::mt19937(seed_seq);

weight_initializer initializer;

// This is to make sure that when the uniform initialization is not needed extra work is avoided
if (initialize) {
initializer = uniform_weight_initializer(LOWER_BOUND, UPPER_BOUND, &random_engine);
} else {
initializer = zero_weight_initializer();
}

nn_spec.add_padding(
/* name */ "transformer_pad0",
/* input */ "image",
Expand All @@ -46,7 +67,7 @@ void define_resnet(model_spec& nn_spec, size_t num_styles) {
/* stride_height */ 1,
/* stride_width */ 1,
/* padding */ padding_type::VALID,
/* weight_init_fn */ zero_weight_initializer());
/* weight_init_fn */ initializer);

nn_spec.add_inner_product(
/* name */ "transformer_encode_1_inst_gamma",
Expand Down Expand Up @@ -102,7 +123,7 @@ void define_resnet(model_spec& nn_spec, size_t num_styles) {
/* stride_height */ 2,
/* stride_width */ 2,
/* padding */ padding_type::VALID,
/* weight_init_fn */ zero_weight_initializer());
/* weight_init_fn */ initializer);

nn_spec.add_inner_product(
/* name */ "transformer_encode_2_inst_gamma",
Expand Down Expand Up @@ -158,7 +179,7 @@ void define_resnet(model_spec& nn_spec, size_t num_styles) {
/* stride_height */ 2,
/* stride_width */ 2,
/* padding */ padding_type::VALID,
/* weight_init_fn */ zero_weight_initializer());
/* weight_init_fn */ initializer);

nn_spec.add_inner_product(
/* name */ "transformer_encode_3_inst_gamma",
Expand Down Expand Up @@ -214,7 +235,7 @@ void define_resnet(model_spec& nn_spec, size_t num_styles) {
/* stride_height */ 1,
/* stride_width */ 1,
/* padding */ padding_type::VALID,
/* weight_init_fn */ zero_weight_initializer());
/* weight_init_fn */ initializer);

nn_spec.add_inner_product(
/* name */ "transformer_residual_1_inst_1_gamma",
Expand Down Expand Up @@ -270,7 +291,7 @@ void define_resnet(model_spec& nn_spec, size_t num_styles) {
/* stride_height */ 1,
/* stride_width */ 1,
/* padding */ padding_type::VALID,
/* weight_init_fn */ zero_weight_initializer());
/* weight_init_fn */ initializer);

nn_spec.add_inner_product(
/* name */ "transformer_residual_1_inst_2_gamma",
Expand Down Expand Up @@ -327,7 +348,7 @@ void define_resnet(model_spec& nn_spec, size_t num_styles) {
/* stride_height */ 1,
/* stride_width */ 1,
/* padding */ padding_type::VALID,
/* weight_init_fn */ zero_weight_initializer());
/* weight_init_fn */ initializer);

nn_spec.add_inner_product(
/* name */ "transformer_residual_2_inst_1_gamma",
Expand Down Expand Up @@ -383,7 +404,7 @@ void define_resnet(model_spec& nn_spec, size_t num_styles) {
/* stride_height */ 1,
/* stride_width */ 1,
/* padding */ padding_type::VALID,
/* weight_init_fn */ zero_weight_initializer());
/* weight_init_fn */ initializer);

nn_spec.add_inner_product(
/* name */ "transformer_residual_2_inst_2_gamma",
Expand Down Expand Up @@ -440,7 +461,7 @@ void define_resnet(model_spec& nn_spec, size_t num_styles) {
/* stride_height */ 1,
/* stride_width */ 1,
/* padding */ padding_type::VALID,
/* weight_init_fn */ zero_weight_initializer());
/* weight_init_fn */ initializer);

nn_spec.add_inner_product(
/* name */ "transformer_residual_3_inst_1_gamma",
Expand Down Expand Up @@ -496,7 +517,7 @@ void define_resnet(model_spec& nn_spec, size_t num_styles) {
/* stride_height */ 1,
/* stride_width */ 1,
/* padding */ padding_type::VALID,
/* weight_init_fn */ zero_weight_initializer());
/* weight_init_fn */ initializer);

nn_spec.add_inner_product(
/* name */ "transformer_residual_3_inst_2_gamma",
Expand Down Expand Up @@ -553,7 +574,7 @@ void define_resnet(model_spec& nn_spec, size_t num_styles) {
/* stride_height */ 1,
/* stride_width */ 1,
/* padding */ padding_type::VALID,
/* weight_init_fn */ zero_weight_initializer());
/* weight_init_fn */ initializer);

nn_spec.add_inner_product(
/* name */ "transformer_residual_4_inst_1_gamma",
Expand Down Expand Up @@ -609,7 +630,7 @@ void define_resnet(model_spec& nn_spec, size_t num_styles) {
/* stride_height */ 1,
/* stride_width */ 1,
/* padding */ padding_type::VALID,
/* weight_init_fn */ zero_weight_initializer());
/* weight_init_fn */ initializer);

nn_spec.add_inner_product(
/* name */ "transformer_residual_4_inst_2_gamma",
Expand Down Expand Up @@ -666,7 +687,7 @@ void define_resnet(model_spec& nn_spec, size_t num_styles) {
/* stride_height */ 1,
/* stride_width */ 1,
/* padding */ padding_type::VALID,
/* weight_init_fn */ zero_weight_initializer());
/* weight_init_fn */ initializer);

nn_spec.add_inner_product(
/* name */ "transformer_residual_5_inst_1_gamma",
Expand Down Expand Up @@ -722,7 +743,7 @@ void define_resnet(model_spec& nn_spec, size_t num_styles) {
/* stride_height */ 1,
/* stride_width */ 1,
/* padding */ padding_type::VALID,
/* weight_init_fn */ zero_weight_initializer());
/* weight_init_fn */ initializer);

nn_spec.add_inner_product(
/* name */ "transformer_residual_5_inst_2_gamma",
Expand Down Expand Up @@ -785,7 +806,7 @@ void define_resnet(model_spec& nn_spec, size_t num_styles) {
/* stride_height */ 1,
/* stride_width */ 1,
/* padding */ padding_type::VALID,
/* weight_init_fn */ zero_weight_initializer());
/* weight_init_fn */ initializer);

nn_spec.add_inner_product(
/* name */ "transformer_decoding_1_inst_gamma",
Expand Down Expand Up @@ -847,7 +868,7 @@ void define_resnet(model_spec& nn_spec, size_t num_styles) {
/* stride_height */ 1,
/* stride_width */ 1,
/* padding */ padding_type::VALID,
/* weight_init_fn */ zero_weight_initializer());
/* weight_init_fn */ initializer);

nn_spec.add_inner_product(
/* name */ "transformer_decoding_2_inst_gamma",
Expand Down Expand Up @@ -903,7 +924,7 @@ void define_resnet(model_spec& nn_spec, size_t num_styles) {
/* stride_height */ 1,
/* stride_width */ 1,
/* padding */ padding_type::VALID,
/* weight_init_fn */ zero_weight_initializer());
/* weight_init_fn */ initializer);

nn_spec.add_inner_product(
/* name */ "transformer_instancenorm5_gamma",
Expand Down Expand Up @@ -1164,9 +1185,10 @@ std::unique_ptr<model_spec> init_resnet(const std::string& path) {
return spec;
}

std::unique_ptr<neural_net::model_spec> init_resnet(size_t num_styles) {
std::unique_ptr<neural_net::model_spec> init_resnet(size_t num_styles,
int random_seed) {
std::unique_ptr<model_spec> nn_spec(new model_spec());
define_resnet(*nn_spec, num_styles);
define_resnet(*nn_spec, num_styles, /* initialize */ true, random_seed);
return nn_spec;
}

Expand All @@ -1190,4 +1212,4 @@ std::unique_ptr<model_spec> init_vgg_16(const std::string& path) {
}

} // namespace style_transfer
} // namespace turi
} // namespace turi
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,8 @@ namespace turi {
namespace style_transfer {

std::unique_ptr<neural_net::model_spec> init_resnet(const std::string& path);
std::unique_ptr<neural_net::model_spec> init_resnet(size_t num_styles);
std::unique_ptr<neural_net::model_spec> init_resnet(size_t num_styles,
int random_seed=0);
std::unique_ptr<neural_net::model_spec> init_resnet(const std::string& path,
size_t num_styles);
std::unique_ptr<neural_net::model_spec> init_vgg_16();
Expand Down

0 comments on commit 38f12b2

Please sign in to comment.