Skip to content
This repository has been archived by the owner on Dec 21, 2023. It is now read-only.

Commit

Permalink
Style transfer toolkit (#624)
Browse files Browse the repository at this point in the history
New toolkit to create a style transfer model trained on style images and
content images.
  • Loading branch information
znation authored Jun 4, 2018
1 parent 67fc610 commit 85ace03
Show file tree
Hide file tree
Showing 28 changed files with 1,982 additions and 98 deletions.
16 changes: 12 additions & 4 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -740,14 +740,22 @@ endfunction()

# Core ML is only present on macOS 10.13 or higher
if(APPLE)
EXEC_PROGRAM(sw_vers ARGS -productVersion OUTPUT_VARIABLE mac_version)
if(mac_version GREATER_EQUAL 10.13)
EXEC_PROGRAM(xcrun ARGS --show-sdk-version OUTPUT_VARIABLE mac_version RETURN_VALUE _xcrun_ret)

# Stop configuring when xcrun could not report the SDK version: the Core ML
# feature checks below compare against mac_version, which is only meaningful
# if the command succeeded. The return value is quoted so that a non-numeric
# failure string is still compared as a single argument.
if(NOT "${_xcrun_ret}" EQUAL 0)
  # FATAL_ERROR is a real message() mode; "ERROR," is not and would be printed
  # as part of an ordinary informational message.
  message(FATAL_ERROR "xcrun command failed with return code ${_xcrun_ret}.")
endif()

# Core ML is only present on macOS 10.13 or higher.
# Logic reversed to get around what seems to be a CMake bug.
if(NOT mac_version VERSION_LESS 10.13)
add_definitions(-DHAS_CORE_ML)
set(HAS_CORE_ML TRUE)
endif()

if(mac_version GREATER_EQUAL 10.14)
# Core ML only supports batch inference on macOS 10.14 or higher
# Core ML only supports batch inference on macOS 10.14 or higher
# Logic reversed to get around what seems to be a CMake bug.
if(NOT mac_version VERSION_LESS 10.14)
add_definitions(-DHAS_CORE_ML_BATCH_INFERENCE)

# GPU-accelerated training with MPS backend requires macOS 10.14 or higher
Expand Down
39 changes: 27 additions & 12 deletions src/image/image_util_impl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ namespace image_util_detail {
using namespace boost::gil;

template<typename current_pixel_type, typename new_pixel_type>
void resize_image_detail(const char* data, size_t width, size_t height, size_t channels, size_t resized_width, size_t resized_height, size_t resized_channels, char** resized_data){
void resize_image_detail(const char* data, size_t width, size_t height, size_t channels, size_t resized_width, size_t resized_height, size_t resized_channels, char** resized_data, int resample_method){
if (data == NULL){
log_and_throw("Trying to resize image with NULL data pointer");
}
Expand All @@ -37,7 +37,13 @@ void resize_image_detail(const char* data, size_t width, size_t height, size_t c
auto view = interleaved_view(width, height, (current_pixel_type*)data, width * channels * sizeof(char));
auto resized_view = interleaved_view(resized_width, resized_height, (new_pixel_type*)buf,
resized_width * resized_channels * sizeof(char));
resize_view(color_converted_view<new_pixel_type>(view), (resized_view), nearest_neighbor_sampler());
if (resample_method == 0) {
resize_view(color_converted_view<new_pixel_type>(view), (resized_view), nearest_neighbor_sampler());
} else if (resample_method == 1) {
resize_view(color_converted_view<new_pixel_type>(view), (resized_view), bilinear_sampler());
} else {
log_and_throw("Unknown resampling method");
}
}
*resized_data = buf;
}
Expand All @@ -46,53 +52,62 @@ void resize_image_detail(const char* data, size_t width, size_t height, size_t c
/**
* Resize the image, and set resized_data to resized image data.
*/
void resize_image_impl(const char* data, size_t width, size_t height, size_t channels, size_t resized_width, size_t resized_height, size_t resized_channels, char** resized_data) {
void resize_image_impl(const char* data, size_t width, size_t height, size_t channels, size_t resized_width, size_t resized_height, size_t resized_channels, char** resized_data, int resample_method) {
// This code should be simplified
if (channels == 1) {
if (resized_channels == 1){
resize_image_detail<gray8_pixel_t, gray8_pixel_t>(data, width, height, channels,
resized_width, resized_height,
resized_channels, resized_data);
resized_channels, resized_data,
resample_method);
} else if (resized_channels == 3){
resize_image_detail<gray8_pixel_t, rgb8_pixel_t>(data, width, height, channels,
resized_width, resized_height,
resized_channels, resized_data);
resized_channels, resized_data,
resample_method);
} else if (resized_channels == 4){
resize_image_detail<gray8_pixel_t, rgba8_pixel_t>(data, width, height, channels,
resized_width, resized_height,
resized_channels, resized_data);
resized_channels, resized_data,
resample_method);
} else {
log_and_throw (std::string("Unsupported channel size ") + std::to_string(channels));
}
} else if (channels == 3) {
if (resized_channels == 1){
resize_image_detail<rgb8_pixel_t, gray8_pixel_t>(data, width, height, channels,
resized_width, resized_height,
resized_channels, resized_data);
resized_channels, resized_data,
resample_method);
} else if (resized_channels == 3){
resize_image_detail<rgb8_pixel_t, rgb8_pixel_t>(data, width, height, channels,
resized_width, resized_height,
resized_channels, resized_data);
resized_channels, resized_data,
resample_method);
} else if (resized_channels == 4){
resize_image_detail<rgb8_pixel_t, rgba8_pixel_t>(data, width, height, channels,
resized_width, resized_height,
resized_channels, resized_data);
resized_channels, resized_data,
resample_method);
} else {
log_and_throw (std::string("Unsupported channel size ") + std::to_string(channels));
}
} else if (channels == 4) {
if (resized_channels == 1){
resize_image_detail<rgba8_pixel_t, gray8_pixel_t>(data, width, height, channels,
resized_width, resized_height,
resized_channels, resized_data);
resized_channels, resized_data,
resample_method);
} else if (resized_channels == 3){
resize_image_detail<rgba8_pixel_t, rgb8_pixel_t>(data, width, height, channels,
resized_width, resized_height,
resized_channels, resized_data);
resized_channels, resized_data,
resample_method);
} else if (resized_channels == 4){
resize_image_detail<rgba8_pixel_t, rgba8_pixel_t>(data, width, height, channels,
resized_width, resized_height,
resized_channels, resized_data);
resized_channels, resized_data,
resample_method);
} else {
log_and_throw (std::string("Unsupported channel size ") + std::to_string(channels));
}
Expand Down
2 changes: 1 addition & 1 deletion src/image/image_util_impl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ namespace image_util_detail {

void resize_image_impl(const char* data, size_t width, size_t height,
size_t channels, size_t resized_width, size_t resized_height,
size_t resized_channels, char** resized_data);
size_t resized_channels, char** resized_data, int resample_method);

void decode_image_impl(image_type& image);

Expand Down
36 changes: 25 additions & 11 deletions src/unity/lib/image_util.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -333,8 +333,10 @@ std::shared_ptr<unity_sarray> decode_image_sarray(std::shared_ptr<unity_sarray>
/**
* Resize an sarray of flex_images with the new size.
*/
flexible_type resize_image(const flexible_type& image, size_t resized_width, size_t resized_height, size_t resized_channels, bool decode) {
if (image.get_type() != flex_type_enum::IMAGE){
flexible_type resize_image(const flexible_type& input, size_t resized_width,
size_t resized_height, size_t resized_channels,
bool decode, int resample_method) {
if (input.get_type() != flex_type_enum::IMAGE){
std::string error = "Cannot resize non-image type";
log_and_throw(error);
}
Expand All @@ -358,13 +360,24 @@ flexible_type resize_image(const flexible_type& image, size_t resized_width, siz
decoded_image.m_width, decoded_image.m_height, decoded_image.m_channels, resized_width,
resized_height, resized_channels, &resized_data);
}
flex_image dst_img;
dst_img.m_width = resized_width;
dst_img.m_height = resized_height;
dst_img.m_channels = resized_channels;
dst_img.m_format = Format::RAW_ARRAY;
dst_img.m_image_data_size = resized_height * resized_width * resized_channels;
dst_img.m_image_data.reset(resized_data);

// Resize if necessary.
if (!has_desired_size()) {
char* resized_data;
image_util_detail::resize_image_impl(
reinterpret_cast<const char*>(image.get_image_data()),
image.m_width, image.m_height, image.m_channels,
resized_width, resized_height, resized_channels,
&resized_data, resample_method);
image.m_width = resized_width;
image.m_height = resized_height;
image.m_channels = resized_channels;
image.m_format = Format::RAW_ARRAY;
image.m_image_data_size = resized_height * resized_width * resized_channels;
image.m_image_data.reset(resized_data);
}

// Encode if necessary.
if (!decode) {
image_util_detail::encode_image_impl(dst_img);
}
Expand All @@ -380,10 +393,11 @@ std::shared_ptr<unity_sarray> resize_image_sarray(
size_t resized_width,
size_t resized_height,
size_t resized_channels,
bool decode) {
bool decode,
int resample_method) {
log_func_entry();
auto fn = [=](const flexible_type& f)->flexible_type {
return flexible_type(resize_image(f, resized_width, resized_height, resized_channels, decode));
return flexible_type(resize_image(f, resized_width, resized_height, resized_channels, decode, resample_method));
};
auto ret = image_sarray->transform_lambda(fn, flex_type_enum::IMAGE, true, 0);
return std::static_pointer_cast<unity_sarray>(ret);
Expand Down
10 changes: 7 additions & 3 deletions src/unity/lib/image_util.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -78,16 +78,20 @@ flexible_type encode_image(const flexible_type& data);
/* */
/**************************************************************************/

/** Reisze an sarray of flex_images with the new size.
/** Resize an sarray of flex_images with the new size. The sampling method
* is specified as the polynomial order of the resampling kernel, with 0
* (nearest neighbor) and 1 (bilinear) supported.
*/
flexible_type resize_image(const flexible_type& image, size_t resized_width,
size_t resized_height, size_t resized_channel, bool decode = false);
size_t resized_height, size_t resized_channel, bool decode = false,
int resample_method = 0);

/** Resize an sarray of flex_image with the new size.
*/
std::shared_ptr<unity_sarray> resize_image_sarray(
std::shared_ptr<unity_sarray> image_sarray, size_t resized_width,
size_t resized_height, size_t resized_channels, bool decode = false);
size_t resized_height, size_t resized_channels, bool decode = false,
int resample_method = 0);



Expand Down
1 change: 1 addition & 0 deletions src/unity/python/turicreate/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,7 @@
import turicreate.toolkits.image_classifier as image_classifier
import turicreate.toolkits.image_similarity as image_similarity
import turicreate.toolkits.object_detector as object_detector
import turicreate.toolkits.style_transfer as style_transfer
import turicreate.toolkits.activity_classifier as activity_classifier

from turicreate.toolkits import evaluation
Expand Down
2 changes: 2 additions & 0 deletions src/unity/python/turicreate/config/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,8 @@ def set_num_gpus(num_gpus):
>> turicreate.config.set_num_gpus(1)
>> turicreate.image_classifier.create(data, target='label')
"""
if(num_gpus < -1):
raise ValueError("'num_gpus' must be greater than or equal to -1")
set_runtime_config('TURI_NUM_GPUS', num_gpus)


Expand Down
7 changes: 0 additions & 7 deletions src/unity/python/turicreate/test/test_image_classifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,13 +113,6 @@ def test_create_with_empty_dataset(self):
with self.assertRaises(_ToolkitError):
tc.image_classifier.create(self.sf[:0], target = self.target)

def test_invalid_num_gpus(self):
num_gpus = tc.config.get_num_gpus()
tc.config.set_num_gpus(-2)
with self.assertRaises(_ToolkitError):
tc.image_classifier.create(self.sf, target=self.target)
tc.config.set_num_gpus(num_gpus)

def test_predict(self):
model = self.model
for output_type in ['class', 'probability_vector']:
Expand Down
7 changes: 0 additions & 7 deletions src/unity/python/turicreate/test/test_image_similarity.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,13 +119,6 @@ def test_create_with_empty_dataset(self):
with self.assertRaises(_ToolkitError):
tc.image_similarity.create(self.sf[:0])

def test_invalid_num_gpus(self):
num_gpus = tc.config.get_num_gpus()
tc.config.set_num_gpus(-2)
with self.assertRaises(_ToolkitError):
tc.image_similarity.create(self.sf)
tc.config.set_num_gpus(num_gpus)

def test_query(self):
model = self.model
preds = model.query(self.sf)
Expand Down
7 changes: 0 additions & 7 deletions src/unity/python/turicreate/test/test_object_detector.py
Original file line number Diff line number Diff line change
Expand Up @@ -185,13 +185,6 @@ def test_dict_annotations(self):
annotated_img = tc.object_detector.util.draw_bounding_boxes(sf_copy[self.feature],
sf_copy[self.annotations])

def test_invalid_num_gpus(self):
num_gpus = tc.config.get_num_gpus()
tc.config.set_num_gpus(-2)
with self.assertRaises(_ToolkitError):
tc.object_detector.create(self.sf)
tc.config.set_num_gpus(num_gpus)

def test_extra_classes(self):
# Create while the data has extra classes
model = tc.object_detector.create(self.sf, classes=_CLASSES[:2], max_iterations=1)
Expand Down
Loading

0 comments on commit 85ace03

Please sign in to comment.