diff --git a/src/python/turicreate/toolkits/style_transfer/_tf_model_architecture.py b/src/python/turicreate/toolkits/style_transfer/_tf_model_architecture.py index 790442d849..8518ac1779 100644 --- a/src/python/turicreate/toolkits/style_transfer/_tf_model_architecture.py +++ b/src/python/turicreate/toolkits/style_transfer/_tf_model_architecture.py @@ -24,9 +24,11 @@ def define_tensorflow_variables(net_params, trainable=True): Parameters ---------- trainable: boolean - If `true` the network updates the convolutional layers as well as the - instance norm layers of the network. If `false` only the instance norm - layers of the network are updated. + If `True` the transformer network updates the convolutional layers as + well as the instance norm layers of the network. If `False` only the + instance norm layers of the network are updated. + + Note the VGG network's parameters aren't updated Returns ------- out: dict @@ -35,21 +37,24 @@ def define_tensorflow_variables(net_params, trainable=True): tensorflow_variables = dict() for key in net_params.keys(): if "weight" in key: + # only set the parameter to train if in the transformer network + train_param = trainable and "transformer_" in key if "conv" in key: tensorflow_variables[key] = _tf.Variable( initial_value=_utils.convert_conv2d_coreml_to_tf(net_params[key]), name=key, - trainable=trainable, + trainable=train_param, ) else: + # This is the path that the instance norm takes tensorflow_variables[key] = _tf.Variable( initial_value=_utils.convert_dense_coreml_to_tf(net_params[key]), name=key, - trainable=trainable, + trainable=True, ) else: tensorflow_variables[key] = _tf.Variable( - initial_value=net_params[key], name=key, trainable=trainable + initial_value=net_params[key], name=key, trainable=False ) return tensorflow_variables