Showing 11 changed files with 1,325 additions and 1 deletion.
CMakeLists.txt
@@ -0,0 +1,83 @@
CMAKE_MINIMUM_REQUIRED(VERSION 2.8)
PROJECT(l_softmax)

find_package(CUDA REQUIRED)

EXECUTE_PROCESS(COMMAND python3.5 -c "import os; print(os.getcwd(), end='', flush=True)" OUTPUT_VARIABLE CWD)
MESSAGE(STATUS "Found CWD: " ${CWD})

# Query the name of GPU 0 via nvidia-smi and map it to the matching -gencode
# flag; falls back to compute capability 6.1 if the device is not in the table.
EXECUTE_PROCESS(COMMAND python3.5 -c "import subprocess; process = subprocess.Popen('nvidia-smi -i 0 --query-gpu=name --format=csv'.split(), stdout=subprocess.PIPE); output, _ = process.communicate(); output = str(output); device_capability_map = {
    'Tesla K80' : '37',
    'Tesla K40' : '35',
    'Tesla K20' : '35',
    'Tesla K10' : '30',
    'Tesla C2075' : '20',
    'Tesla C2050' : '20',
    'Tesla C2070' : '20',
    'Tesla V100' : '70',
    'Tesla P100' : '60',
    'Tesla P40' : '61',
    'Tesla P4' : '61',
    'Tesla M60' : '52',
    'Tesla M40' : '52',
    'GeForce GTX 1080 Ti' : '61'
}; cap = '61';
for k, v in device_capability_map.items():
    if k in output:
        cap = v
        break
print('gencode arch=compute_' + cap + ',code=sm_' + cap)" OUTPUT_VARIABLE GPU_CAPABILITY)
MESSAGE(STATUS "Found GPU_CAPABILITY: " ${GPU_CAPABILITY})

# Pass options to NVCC (-${GPU_CAPABILITY} expands to the -gencode flag printed above)
set(CUDA_NVCC_FLAGS "${CUDA_NVCC_FLAGS} --keep --keep-dir ${CWD} -${GPU_CAPABILITY} -D GOOGLE_CUDA=1 -x cu -Xcompiler -fPIC --expt-relaxed-constexpr")

#set(CUDA_NVCC_FLAGS "${CUDA_NVCC_FLAGS} --keep --keep-dir ${CWD} -gencode arch=compute_61,code=sm_61 -D GOOGLE_CUDA=1 -x cu -Xcompiler -fPIC --expt-relaxed-constexpr")

# compiler flags
SET(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -std=c++11 -O2 ${OpenMP_CXX_FLAGS} -Wall -fPIC -D_GLIBCXX_USE_CXX11_ABI=0 -DGOOGLE_CUDA=1")

# TensorFlow dependencies
EXECUTE_PROCESS(COMMAND python3.5 -c "import os; os.environ['TF_CPP_MIN_LOG_LEVEL']='3'; import tensorflow as tf; print(tf.sysconfig.get_include(), end='', flush=True)" OUTPUT_VARIABLE TF_INC)
EXECUTE_PROCESS(COMMAND python3.5 -c "import os; os.environ['TF_CPP_MIN_LOG_LEVEL']='3'; import tensorflow as tf; print(tf.sysconfig.get_lib(), end='', flush=True)" OUTPUT_VARIABLE TF_LIB)

MESSAGE(STATUS "Found TF_INC: " ${TF_INC})
MESSAGE(STATUS "Found TF_INC_EXTERNAL: " ${TF_INC}/external/nsync/public)
MESSAGE(STATUS "Found TF_LIB: " ${TF_LIB})

INCLUDE_DIRECTORIES(${TF_INC})
INCLUDE_DIRECTORIES(${TF_INC}/external/nsync/public)
LINK_DIRECTORIES(${TF_LIB})

# approach 1
# CUDA_ADD_LIBRARY(l_softmax_gpu SHARED l_softmax_op.cu OPTIONS -I$TF_INC/tensorflow/stream_executor/cuda -I/usr/local)
# ADD_LIBRARY(l_softmax SHARED
#     l_softmax_op.h
#     l_softmax_op.cc
# )
# TARGET_LINK_LIBRARIES(l_softmax tensorflow_framework ${CUDA_LIBRARIES} l_softmax_gpu)

# approach 2
CUDA_COMPILE(L_SOFTMAX_CU_O l_softmax_op.cu MODULE OPTIONS -I${TF_INC} -I/usr/local)
CUDA_COMPILE(L_SOFTMAX_GRAD_CU_O l_softmax_grad_op.cu MODULE OPTIONS -I${TF_INC} -I/usr/local)

ADD_LIBRARY(l_softmax SHARED
    ${L_SOFTMAX_CU_O}
    ${L_SOFTMAX_GRAD_CU_O}
    l_softmax_op.h
    l_softmax_op.cc
    l_softmax_grad_op.cc
    common.h
    common.cc
)

TARGET_LINK_LIBRARIES(l_softmax tensorflow_framework ${CUDA_LIBRARIES})
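As a sanity check of the build configuration, the device-to-capability lookup embedded above can be run on its own. The following is a standalone Python sketch of the same lookup, with the same table and 6.1 fallback; the script structure and error handling are mine, not part of the repo:

```python
import subprocess

# Device-name -> compute-capability map, mirroring CMakeLists.txt.
device_capability_map = {
    'Tesla K80': '37', 'Tesla K40': '35', 'Tesla K20': '35', 'Tesla K10': '30',
    'Tesla C2075': '20', 'Tesla C2050': '20', 'Tesla C2070': '20',
    'Tesla V100': '70', 'Tesla P100': '60', 'Tesla P40': '61', 'Tesla P4': '61',
    'Tesla M60': '52', 'Tesla M40': '52', 'GeForce GTX 1080 Ti': '61',
}

def detect_gencode(default_cap='61'):
    # Query the name of GPU 0; fall back to the default on any failure.
    try:
        output = subprocess.check_output(
            'nvidia-smi -i 0 --query-gpu=name --format=csv'.split()).decode()
    except (OSError, subprocess.CalledProcessError):
        output = ''
    cap = next((v for k, v in device_capability_map.items() if k in output),
               default_cap)
    return '-gencode arch=compute_{0},code=sm_{0}'.format(cap)

if __name__ == '__main__':
    print(detect_gencode())
```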
README.md
@@ -1 +1,48 @@
-# tf.extra_losses
+# Large-Margin Softmax Loss in TensorFlow C++ API

This repository contains code for a reimplementation of [Large-Margin Softmax Loss for Convolutional Neural Networks](https://arxiv.org/abs/1612.02295) in TensorFlow. If your goal is to reproduce the results of the paper published at ICML 2016, please use the [official code](https://github.com/wy1iu/LargeMargin_Softmax_Loss).

## Usage ##
To use this Op on your own machine:

- copy the header file "cuda\_config.h" from "your\_python\_path/site-packages/external/local\_config\_cuda/cuda/cuda/cuda\_config.h" to "your\_python\_path/site-packages/tensorflow/include/tensorflow/stream\_executor/cuda/cuda\_config.h".

- run the following script:

```sh
mkdir build
cd build && cmake ..
make
```

- run "test\_op.py" and check the numeric errors to verify your installation
- follow the code snippet below to integrate this Op into your own code:

```python
op_module = tf.load_op_library(so_lib_path)
large_margin_softmax = op_module.large_margin_softmax

@ops.RegisterGradient("LargeMarginSoftmax")
def _large_margin_softmax_grad(op, grad, _):
    '''The gradients for `LargeMarginSoftmax`.'''
    inputs_features = op.inputs[0]
    inputs_weights = op.inputs[1]
    inputs_labels = op.inputs[2]
    cur_lambda = op.outputs[1]
    margin_order = op.get_attr('margin_order')

    grads = op_module.large_margin_softmax_grad(inputs_features, inputs_weights, inputs_labels, grad, cur_lambda[0], margin_order)
    return [grads[0], grads[1], None, None]

var_weights = tf.Variable(initial_value, trainable=True, name='lsoftmax_weights')
result = large_margin_softmax(features, var_weights, labels, 1, 4, 1000., 0.000025, 35., 0.)
loss = tf.reduce_mean(tf.nn.sparse_softmax_cross_entropy_with_logits(labels=labels, logits=result[0]))
```

Note that the Op has more than one output: `result[0]` holds the logits, while the second output carries the current lambda, which the registered gradient reads back as `op.outputs[1]`.

All the code was tested under TensorFlow 1.6, Python 3.5, and Ubuntu 16.04 with CUDA 8.0. The outputs of this Op have been compared with those of the original Caffe code, and the differences are negligible. The gradients of this Op have been checked using [tf.test.compute\_gradient\_error](https://www.tensorflow.org/api_docs/python/tf/test/compute_gradient_error) and [tf.test.compute\_gradient](https://www.tensorflow.org/api_docs/python/tf/test/compute_gradient).
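For reference, a minimal sketch of such a gradient check in the TF 1.x style is shown below. The library path, tensor shapes, and label dtype are illustrative assumptions, not taken from this repo; adjust them to your build and data:

```python
import numpy as np
import tensorflow as tf
from tensorflow.python.framework import ops

# Hypothetical path; adjust to where cmake put the shared library.
op_module = tf.load_op_library('build/libl_softmax.so')
large_margin_softmax = op_module.large_margin_softmax

@ops.RegisterGradient("LargeMarginSoftmax")
def _large_margin_softmax_grad(op, grad, _):
    # same registration as in the snippet above, condensed
    grads = op_module.large_margin_softmax_grad(
        op.inputs[0], op.inputs[1], op.inputs[2],
        grad, op.outputs[1][0], op.get_attr('margin_order'))
    return [grads[0], grads[1], None, None]

batch, dim, num_classes = 4, 8, 10  # assumed toy shapes
with tf.Session():
    features = tf.constant(np.random.randn(batch, dim), dtype=tf.float32)
    weights = tf.constant(np.random.randn(dim, num_classes), dtype=tf.float32)
    labels = tf.constant(np.random.randint(num_classes, size=batch), dtype=tf.int32)
    logits = large_margin_softmax(features, weights, labels,
                                  1, 4, 1000., 0.000025, 35., 0.)[0]
    # Max |numeric - analytic| over all Jacobian entries; small values
    # (around 1e-4 or less for float32) indicate a healthy gradient.
    err = tf.test.compute_gradient_error(features, [batch, dim],
                                         logits, [batch, num_classes])
    print('max gradient error:', err)
```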
Any contribution to this repo is welcome.

## License ##
MIT License
common.cc
@@ -0,0 +1,47 @@
// MIT License

// Copyright (c) 2018 Changan Wang

// Permission is hereby granted, free of charge, to any person obtaining a copy
// of this software and associated documentation files (the "Software"), to deal
// in the Software without restriction, including without limitation the rights
// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
// copies of the Software, and to permit persons to whom the Software is
// furnished to do so, subject to the following conditions:

// The above copyright notice and this permission notice shall be included in all
// copies or substantial portions of the Software.

// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
// SOFTWARE.
#include "common.h"

//__attribute__((always_inline))
// template<typename T, typename std::is_same<float, typename std::remove_cv<T>::type>::value>
void atomic_float_add(volatile float* ptr, const float operand)
{
    assert(is_aligned(ptr, 4));

    volatile int32_t* iptr = reinterpret_cast<volatile int32_t*>(ptr);
    int32_t expected = *iptr;

    // CAS loop: reinterpret the float bits as int32, add the operand, and
    // retry with the freshly observed value until the swap succeeds.
    while (true)
    {
        const float value = binary_cast<float>(expected);
        const int32_t new_value = binary_cast<int32_t>(value + operand);
        const int32_t actual = __sync_val_compare_and_swap(iptr, expected, new_value);
        if (actual == expected)
            return;
        expected = actual;
    }
}

// note: the result overflows int32_t for n > 12
int32_t _factorial(int32_t n)
{
    return n > 1 ? n * _factorial(n - 1) : 1;
}
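Not part of the repo, but as a language-neutral illustration of the retry pattern in atomic_float_add: the sketch below emulates the hardware compare-and-swap with a lock (where `__sync_val_compare_and_swap` is a single atomic instruction) and does the float/int32 bit reinterpretation with `struct`:

```python
import struct
import threading

def float_to_bits(x):
    # reinterpret a float32 as int32, like binary_cast<int32_t>(float)
    return struct.unpack('<i', struct.pack('<f', x))[0]

def bits_to_float(b):
    # reinterpret an int32 as float32, like binary_cast<float>(int32_t)
    return struct.unpack('<f', struct.pack('<i', b))[0]

class CASCell:
    """One 32-bit cell with compare-and-swap; the lock only stands in for
    the atomicity that the CPU instruction provides in the C++ version."""
    def __init__(self, bits=0):
        self.bits = bits
        self._lock = threading.Lock()

    def compare_and_swap(self, expected, new_bits):
        # returns the observed value; the swap happened iff it == expected
        with self._lock:
            actual = self.bits
            if actual == expected:
                self.bits = new_bits
            return actual

def atomic_float_add(cell, operand):
    expected = cell.bits  # initial read, like *iptr
    while True:
        new_bits = float_to_bits(bits_to_float(expected) + operand)
        actual = cell.compare_and_swap(expected, new_bits)
        if actual == expected:
            return
        expected = actual  # lost the race: retry with the value we saw

cell = CASCell(float_to_bits(0.0))
workers = [threading.Thread(target=lambda: [atomic_float_add(cell, 1.0)
                                            for _ in range(1000)])
           for _ in range(4)]
for w in workers: w.start()
for w in workers: w.join()
print(bits_to_float(cell.bits))  # 4000.0
```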
common.h
@@ -0,0 +1,64 @@
// MIT License

// Copyright (c) 2018 Changan Wang

// Permission is hereby granted, free of charge, to any person obtaining a copy
// of this software and associated documentation files (the "Software"), to deal
// in the Software without restriction, including without limitation the rights
// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
// copies of the Software, and to permit persons to whom the Software is
// furnished to do so, subject to the following conditions:

// The above copyright notice and this permission notice shall be included in all
// copies or substantial portions of the Software.

// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
// SOFTWARE.
#ifndef COMMON_H_
#define COMMON_H_

#include <cstdlib>
#include <cassert>
#include <cstdint>

// atomic addition for float using gcc built-in functions for atomic memory access
// this code snippet is borrowed from https://codereview.stackexchange.com/questions/135852/atomic-floating-point-addition
template <typename Target, typename Source>
__attribute__((always_inline)) Target binary_cast(Source s)
{
    static_assert(sizeof(Target) == sizeof(Source), "binary_cast: 'Target' must have the same size as 'Source'");
    union
    {
        Source m_source;
        Target m_target;
    } u;

    u.m_source = s;
    return u.m_target;
}

template <typename T>
__attribute__((always_inline)) bool is_pow2(const T x)
{
    return (x & (x - 1)) == 0;
}

template <typename T>
__attribute__((always_inline)) bool is_aligned(const T ptr, const size_t alignment)
{
    assert(alignment > 0);
    assert(is_pow2(alignment));

    const uintptr_t p = (uintptr_t)ptr;
    return (p & (alignment - 1)) == 0;
}

extern void atomic_float_add(volatile float* ptr, const float operand);
extern int32_t _factorial(int32_t);

#endif // COMMON_H_
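These helpers translate almost directly to other languages. A hedged Python rendering of the same bit tricks, with made-up example addresses in the self-checks for illustration:

```python
import struct

def binary_cast_f2i(x):
    # reinterpret the 4 bytes of a float32 as an int32 (the union trick)
    return struct.unpack('<i', struct.pack('<f', x))[0]

def is_pow2(x):
    # same bit test as the header; the header leaves the x == 0 case to the
    # caller (is_aligned asserts alignment > 0 before calling it)
    return (x & (x - 1)) == 0

def is_aligned(address, alignment):
    assert alignment > 0 and is_pow2(alignment)
    return (address & (alignment - 1)) == 0

assert binary_cast_f2i(0.0) == 0
assert is_pow2(8) and not is_pow2(6)
assert is_aligned(0x7f3a10, 4)       # example address, divisible by 4
assert not is_aligned(0x7f3a11, 4)   # off by one byte
```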