-
Notifications
You must be signed in to change notification settings - Fork 9
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #59 from InfiniTensor/dev
Dev
- Loading branch information
Showing
81 changed files
with
2,226 additions
and
442 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
28 changes: 28 additions & 0 deletions
28
src/04kernel/include/kernel/attributes/mat_mul_integer_info.h
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,28 @@ | ||
#ifndef KERNEL_MAT_MUL_INTEGER_INFO_H | ||
#define KERNEL_MAT_MUL_INTEGER_INFO_H | ||
|
||
#include "kernel/attributes/broadcaster.h" | ||
|
||
namespace refactor::kernel { | ||
|
||
struct MatMulIntegerInfo { | ||
struct Input { | ||
bool | ||
withZeroPoint, | ||
signed_, | ||
scalar; | ||
|
||
Input(TensorRefs const &, size_t i) noexcept; | ||
}; | ||
|
||
Input a, b; | ||
dim_t m, k, n; | ||
Broadcaster broadcaster; | ||
|
||
explicit MatMulIntegerInfo(TensorRefs const &inputs) noexcept; | ||
dim_t batch() const noexcept; | ||
}; | ||
|
||
}// namespace refactor::kernel | ||
|
||
#endif// KERNEL_MAT_MUL_INTEGER_INFO_H |
18 changes: 18 additions & 0 deletions
18
src/04kernel/include/kernel/collectors/dequantize_linear.h
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,18 @@ | ||
#ifndef KERNEL_DEQUANTIZE_LINEAR_H | ||
#define KERNEL_DEQUANTIZE_LINEAR_H | ||
|
||
#include "../collector.h" | ||
|
||
namespace refactor::kernel { | ||
|
||
struct DequantizeLinearCollector final : public InfoCollector { | ||
|
||
explicit DequantizeLinearCollector(decltype(_target)) noexcept; | ||
|
||
std::vector<KernelBox> | ||
filter(TensorRefs inputs, TensorRefs outputs) const final; | ||
}; | ||
|
||
}// namespace refactor::kernel | ||
|
||
#endif// KERNEL_DEQUANTIZE_LINEAR_H |
18 changes: 18 additions & 0 deletions
18
src/04kernel/include/kernel/collectors/dynamic_quantize_linear.h
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,18 @@ | ||
#ifndef KERNEL_DYNAMIC_QUANTIZE_LINEAR_H | ||
#define KERNEL_DYNAMIC_QUANTIZE_LINEAR_H | ||
|
||
#include "../collector.h" | ||
|
||
namespace refactor::kernel { | ||
|
||
struct DynamicQuantizeLinearCollector final : public InfoCollector { | ||
|
||
explicit DynamicQuantizeLinearCollector(decltype(_target)) noexcept; | ||
|
||
std::vector<KernelBox> | ||
filter(TensorRefs inputs, TensorRefs outputs) const final; | ||
}; | ||
|
||
}// namespace refactor::kernel | ||
|
||
#endif// KERNEL_DYNAMIC_QUANTIZE_LINEAR_H |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,19 @@ | ||
#ifndef KERNEL_MAT_MUL_INTEGER_H | ||
#define KERNEL_MAT_MUL_INTEGER_H | ||
|
||
#include "../collector.h" | ||
|
||
namespace refactor::kernel { | ||
|
||
struct MatMulIntegerCollector final : public InfoCollector { | ||
|
||
constexpr MatMulIntegerCollector(decltype(_target) target) noexcept | ||
: InfoCollector(target) {} | ||
|
||
std::vector<KernelBox> | ||
filter(TensorRefs inputs, TensorRefs outputs) const final; | ||
}; | ||
|
||
}// namespace refactor::kernel | ||
|
||
#endif// KERNEL_MAT_MUL_INTEGER_H |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,42 @@ | ||
#include "kernel/attributes/mat_mul_integer_info.h" | ||
|
||
namespace refactor::kernel { | ||
|
||
/// Describes the zero point of operand `i`, which is looked up at `inputs[i + 2]`.
///
/// The fields keep their defaults (`withZeroPoint = false`, `signed_ = true`,
/// `scalar = true`) when that tensor is absent, or when its data is available
/// and every byte of it is zero.
MatMulIntegerInfo::Input::Input(TensorRefs const &inputs, size_t i) noexcept
    : withZeroPoint(false),
      signed_(true),
      scalar(true) {
    auto const zeroPointIdx = i + 2;
    if (inputs.size() <= zeroPointIdx) {
        // No zero-point tensor was passed for this operand.
        return;
    }

    auto const &zeroPoint = inputs[zeroPointIdx].get();
    auto const count = zeroPoint.elementsSize();

    if (zeroPoint.data) {
        // The values are known here: an all-zero zero point is treated as absent.
        auto bytes = slice(zeroPoint.data->get<uint8_t>(), count);
        auto const isZero = [](auto byte) { return byte == 0; };
        if (std::all_of(bytes.begin(), bytes.end(), isZero)) {
            return;
        }
    }

    withZeroPoint = true;
    signed_ = zeroPoint.dataType == DataType::I8;
    scalar = count == 1;
}
|
||
MatMulIntegerInfo::MatMulIntegerInfo(TensorRefs const &inputs) noexcept | ||
: a(inputs, 0), | ||
b(inputs, 1), | ||
#define A (inputs[0].get().shape) | ||
#define B (inputs[1].get().shape) | ||
m(A.rbegin()[1]), | ||
k(A.rbegin()[0]), | ||
n(B.rbegin()[0]), | ||
broadcaster({slice(A.data(), A.size() - 2), | ||
slice(B.data(), B.size() - 2)}) { | ||
} | ||
#undef A | ||
#undef B | ||
|
||
/// Returns the broadcaster's `outputsCount`.
dim_t MatMulIntegerInfo::batch() const noexcept {
    auto const &bc = this->broadcaster;
    return bc.outputsCount;
}
|
||
}// namespace refactor::kernel |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,32 @@ | ||
#include "kernel/collectors/dequantize_linear.h" | ||
#include "../kernels/dequantize_linear/cpu_kernel.hh" | ||
#include "../kernels/dequantize_linear/cuda_kernel.hh" | ||
|
||
namespace refactor::kernel { | ||
|
||
/// Creates the collector by forwarding `target` to the `InfoCollector` base.
DequantizeLinearCollector::DequantizeLinearCollector(decltype(_target) target) noexcept
    : InfoCollector(target) {
}
|
||
std::vector<KernelBox> | ||
DequantizeLinearCollector::filter(TensorRefs inputs, TensorRefs outputs) const { | ||
auto const &output = outputs[0]; | ||
std::vector<KernelBox> ans; | ||
switch (_target) { | ||
case decltype(_target)::Cpu: | ||
if (auto ptr = DequantizeLinearCpu::build(inputs, output); ptr) { | ||
ans.emplace_back(std::move(ptr)); | ||
} | ||
break; | ||
case decltype(_target)::Nvidia: | ||
if (auto ptr = DequantizeLinearCuda::build(inputs, output); ptr) { | ||
ans.emplace_back(std::move(ptr)); | ||
} | ||
break; | ||
default: | ||
UNREACHABLEX(void, "Unknown target"); | ||
} | ||
return ans; | ||
} | ||
|
||
}// namespace refactor::kernel |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,33 @@ | ||
#include "kernel/collectors/dynamic_quantize_linear.h" | ||
#include "../kernels/dynamic_quantize_linear/cpu_kernel.hh" | ||
#include "../kernels/dynamic_quantize_linear/cuda_kernel.hh" | ||
|
||
namespace refactor::kernel { | ||
|
||
/// Creates the collector by forwarding `target` to the `InfoCollector` base.
DynamicQuantizeLinearCollector::DynamicQuantizeLinearCollector(decltype(_target) target) noexcept
    : InfoCollector(target) {
}
|
||
std::vector<KernelBox> | ||
DynamicQuantizeLinearCollector::filter(TensorRefs inputs, TensorRefs outputs) const { | ||
auto size = inputs[0].get().elementsSize(); | ||
|
||
std::vector<KernelBox> ans; | ||
switch (_target) { | ||
case decltype(_target)::Cpu: | ||
if (auto ptr = DynamicQuantizeLinearCpu::build(size); ptr) { | ||
ans.emplace_back(std::move(ptr)); | ||
} | ||
break; | ||
case decltype(_target)::Nvidia: | ||
if (auto ptr = DynamicQuantizeLinearCuda::build(size); ptr) { | ||
ans.emplace_back(std::move(ptr)); | ||
} | ||
break; | ||
default: | ||
UNREACHABLEX(void, "Unknown target"); | ||
} | ||
return ans; | ||
} | ||
|
||
}// namespace refactor::kernel |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,30 @@ | ||
#include "kernel/collectors/mat_mul_integer.h" | ||
#include "../../src/kernels/mat_mul_integer/cpu_kernel.hh" | ||
#include "../../src/kernels/mat_mul_integer/cublas_kernel.hh" | ||
#include "kernel/attributes/mat_mul_integer_info.h" | ||
|
||
namespace refactor::kernel { | ||
|
||
std::vector<KernelBox> | ||
MatMulIntegerCollector::filter(TensorRefs inputs, TensorRefs outputs) const { | ||
MatMulIntegerInfo info(inputs); | ||
|
||
std::vector<KernelBox> ans; | ||
switch (_target) { | ||
case decltype(_target)::Cpu: | ||
if (auto ptr = MatMulIntegerCpu::build(info); ptr) { | ||
ans.emplace_back(std::move(ptr)); | ||
} | ||
break; | ||
case decltype(_target)::Nvidia: | ||
if (auto ptr = MatMulIntegerCublas::build(info); ptr) { | ||
ans.emplace_back(std::move(ptr)); | ||
} | ||
break; | ||
default: | ||
UNREACHABLEX(void, "Unknown target"); | ||
} | ||
return ans; | ||
} | ||
|
||
}// namespace refactor::kernel |
Oops, something went wrong.