From 9b2c23ddeb206637867ae4eda6bd1b71badd4203 Mon Sep 17 00:00:00 2001
From: Qinghan Chen
Date: Fri, 24 May 2024 15:23:12 -0400
Subject: [PATCH 1/9] some looks fine modifications

---
 .proj.toml                             | 18 ++++++++++--------
 CMakeLists.txt                         | 11 +++++++----
 lib/kernels/include/kernels/device.h   | 18 +++++++++---------
 lib/kernels/src/device.h               |  4 ++--
 lib/kernels/src/hip/concat_kernels.cpp |  8 +++++---
 lib/utils/CMakeLists.txt               |  3 ++-
 6 files changed, 35 insertions(+), 27 deletions(-)

diff --git a/.proj.toml b/.proj.toml
index a4592dcccc..c048347913 100644
--- a/.proj.toml
+++ b/.proj.toml
@@ -2,20 +2,22 @@ project_name = "flexflow"
 testsuite_macro = "FF_TEST_SUITE"
 namespace_name = "FlexFlow"
 header_extension = ".h"
+fix_compile_commands = false
 build_targets = [
-  "utils",
-  "op-attrs",
   "kernels",
-  "substitutions",
-  "compiler",
 ]
 test_targets = [
-  "utils-tests",
-  "substitutions-tests",
-  "compiler-tests",
+  # "utils-tests",
+  # "substitutions-tests",
+  # "compiler-tests",
 ]

 [cmake_flags_extra]
-FF_CUDA_ARCH = "60"
+FF_USE_HIP_ROCM = "ON"
+FF_GPU_BACKEND = "hip_rocm"
 CMAKE_CUDA_ARCHITECTURES = "60"
+CMAKE_HIP_ARCHITECTURES = "gfx900"
+CMAKE_CXX_COMPILER = "hipcc"
+CMAKE_C_COMPILER = "hipcc"
+# FF_CUDA_ARCH = "60"
\ No newline at end of file

diff --git a/CMakeLists.txt b/CMakeLists.txt
index 1a6cf80100..e3322e06b9 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -1,6 +1,9 @@
 cmake_minimum_required(VERSION 3.10)
 project(FlexFlow)

+set(CMAKE_CXX_STANDARD 17)
+set(CMAKE_CXX_STANDARD_REQUIRED ON)
+
 set(
   CMAKE_MODULE_PATH
   ${CMAKE_MODULE_PATH}
@@ -83,9 +86,9 @@ if(${CMAKE_SYSTEM_NAME} MATCHES "Linux")
   set(LIBEXT ".so")
 endif()

-include(cuda)
-include(cudnn)
-include(nccl)
+# include(cuda)
+# include(cudnn)
+# include(nccl)
 include(CodeCoverage)
 append_coverage_compiler_flags()
 # set_property(CACHE FF_GPU_BACKEND PROPERTY STRINGS ${FF_GPU_BACKENDS})
@@ -97,7 +100,7 @@ include(doctestlib) # named doctestlib to avoid a name collision with doctest.cm
 include(visit_struct)
 include(CTest)
 include(fmt)
-include(legion)
+# include(legion)
 include(rapidcheck)
 #include(gtest)

diff --git a/lib/kernels/include/kernels/device.h b/lib/kernels/include/kernels/device.h
index 439937177a..ac44438367 100644
--- a/lib/kernels/include/kernels/device.h
+++ b/lib/kernels/include/kernels/device.h
@@ -7,7 +7,7 @@
 #include
 #elif defined(FF_USE_HIP_ROCM)
 #include
-#include
+#include
 #include
 #else
 #error "Unknown device"
@@ -57,21 +57,21 @@ typedef miopenTensorDescriptor_t ffTensorDescriptor_t;
 typedef miopenActivationDescriptor_t ffActivationDescriptor_t;
 typedef miopenPoolingDescriptor_t ffPoolingDescriptor_t;
 typedef miopenBatchNormMode_t ffBatchNormMode_t;
-typedef miopenFilterDescriptor_t ffFilterDescriptor_t;
+typedef miopenTensorDescriptor_t ffFilterDescriptor_t;
 typedef miopenConvolutionDescriptor_t ffConvolutionDescriptor_t;
-typedef miopenConvolutionFwdAlgo_t ffConvolutionFwdAlgo_t;
-typedef miopenConvolutionBwdFilterAlgo_t ffConvolutionBwdFilterAlgo_t;
-typedef miopenConvolutionBwdDataAlgo_t ffConvolutionBwdDataAlgo_t;
+// typedef miopenConvolutionFwdAlgo_t ffConvolutionFwdAlgo_t; // we don't have this one in MIOpen
+// typedef miopenConvolutionBwdFilterAlgo_t ffConvolutionBwdFilterAlgo_t; // don't have this either
+// typedef miopenConvolutionBwdDataAlgo_t ffConvolutionBwdDataAlgo_t;
 typedef miopenDropoutDescriptor_t ffDropoutDescriptor_t;
-typedef miopenOpTensorDescriptor_t ffOpTensorDescriptor_t;
+typedef miopenTensorDescriptor_t ffOpTensorDescriptor_t; // don't have this either, but miopenTensorDescriptor_t works as a placeholder
 typedef miopenReduceTensorDescriptor_t ffReduceTensorDescriptor_t;
-typedef miopenAttnDescriptor_t ffAttnDescriptor_t;
-typedef miopenSeqDataDescriptor_t ffSeqDataDescriptor_t;
+// typedef miopenAttnDescriptor_t ffAttnDescriptor_t;
+// typedef miopenSeqDataDescriptor_t ffSeqDataDescriptor_t;
 typedef miopenHandle_t ffHandle_t;
 typedef hipEvent_t ffEvent_t;
 typedef hipblasHandle_t ffblasHandle_t;
 typedef miopenStatus_t ffStatus_t;
-typedef hipblasDataType_t ffDataType_t;
+typedef hipblasDatatype_t ffDataType_t;
 typedef miopenDataType_t ffCudnnDataType_t;
 typedef hipError_t ffError_t;
 #else

diff --git a/lib/kernels/src/device.h b/lib/kernels/src/device.h
index 00f2888f45..0295908125 100644
--- a/lib/kernels/src/device.h
+++ b/lib/kernels/src/device.h
@@ -84,13 +84,13 @@ __host__ void relu_backward_kernel(DataType data_type,
                                    void *output_grad_ptr,
                                    void const *output_ptr,
                                    size_t output_size,
-                                   cudaStream_t stream);
+                                   hipStream_t stream);

 __host__ void sigmoid_backward_kernel(DataType data_type,
                                       void *output_grad_ptr,
                                       void const *output_ptr,
                                       size_t output_size,
-                                      cudaStream_t stream);
+                                      hipStream_t stream);

 template <typename DT>
 __global__ void apply_add_with_scale(DT *data_ptr,

diff --git a/lib/kernels/src/hip/concat_kernels.cpp b/lib/kernels/src/hip/concat_kernels.cpp
index 6eac034a4b..0cc0c2ae48 100644
--- a/lib/kernels/src/hip/concat_kernels.cpp
+++ b/lib/kernels/src/hip/concat_kernels.cpp
@@ -13,9 +13,10 @@
  * limitations under the License.
  */

+#include "device.h"
 #include "kernels/concat_kernels.h"
-#include "kernels/hip_helper.h"
 #include
+#include

 namespace FlexFlow {
 namespace Kernels {
@@ -72,9 +73,10 @@ void backward_kernel(hipStream_t stream,
   coord_t num_blocks = 1, output_blk_size = 1, input_blk_sizes[MAX_NUM_INPUTS];
   int num_inputs = input_grads.size();
   assert(num_inputs <= MAX_NUM_INPUTS);
+  calc_blk_size(num_blocks, output_blk_size, output_grad.shape, axis);

   for (int i = 0; i < num_inputs; i++) {
-    shape = input_grads[i].shape;
+    ArrayShape shape = input_grads[i].shape;
     size_t input_num_blocks = 1;
     calc_blk_size(input_num_blocks, input_blk_sizes[i], shape, axis);
     assert(input_num_blocks == num_blocks);
@@ -98,4 +100,4 @@ void backward_kernel(hipStream_t stream,

 } // namespace Concat
 } // namespace Kernels
-} // namespace FlexFlow
+} // namespace FlexFlow
\ No newline at end of file

diff --git a/lib/utils/CMakeLists.txt b/lib/utils/CMakeLists.txt
index a0d77b9f76..1ad1f012bb 100644
--- a/lib/utils/CMakeLists.txt
+++ b/lib/utils/CMakeLists.txt
@@ -12,7 +12,8 @@ ff_add_library(
     visit_struct
     fmt
     json
-    cuda
+    miopen
+    # cuda
 )

 add_subdirectory(ffi)
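A note on the convolution-algorithm typedefs that patch 1 comments out: MIOpen does expose convolution algorithm enums, just under different names than cuDNN's. The fragment below is a sketch of a possible alternative mapping; the miopenConv*Algorithm_t names come from recent miopen.h releases and are an assumption to verify against the installed ROCm version, not something this series actually uses.

    // Sketch: candidate MIOpen equivalents for the commented-out cuDNN-style
    // algorithm typedefs. Assumption: these enum names exist in the miopen.h
    // shipped with the ROCm version in use; check before relying on them.
    #include <miopen/miopen.h>

    typedef miopenConvFwdAlgorithm_t ffConvolutionFwdAlgo_t;
    typedef miopenConvBwdWeightsAlgorithm_t ffConvolutionBwdFilterAlgo_t;
    typedef miopenConvBwdDataAlgorithm_t ffConvolutionBwdDataAlgo_t;

The filter and op-tensor aliases in the patch are consistent with MIOpen's design: it has no dedicated filter descriptor type and describes filters with plain miopenTensorDescriptor_t, so aliasing both ffFilterDescriptor_t and ffOpTensorDescriptor_t to that type is a reasonable placeholder.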
From 9dd2682700370357b740674c9d8878686e47f00f Mon Sep 17 00:00:00 2001
From: Qinghan Chen
Date: Wed, 29 May 2024 13:38:13 -0400
Subject: [PATCH 2/9] looks good init version

---
 CMakeLists.txt             |  3 ---
 flake.nix                  | 53 ++++++++++++++++++++++++++++++--------
 lib/kernels/CMakeLists.txt | 53 ++++++++++++++++++++++++++++++------
 lib/utils/CMakeLists.txt   |  1 -
 4 files changed, 87 insertions(+), 23 deletions(-)

diff --git a/CMakeLists.txt b/CMakeLists.txt
index e3322e06b9..eccd077929 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -1,9 +1,6 @@
 cmake_minimum_required(VERSION 3.10)
 project(FlexFlow)

-set(CMAKE_CXX_STANDARD 17)
-set(CMAKE_CXX_STANDARD_REQUIRED ON)
-
 set(
   CMAKE_MODULE_PATH
   ${CMAKE_MODULE_PATH}

diff --git a/flake.nix b/flake.nix
index 653b0e7f59..a0464e883b 100644
--- a/flake.nix
+++ b/flake.nix
@@ -6,10 +6,14 @@
     extra-substituters = [
       "https://ff.cachix.org"
       "https://cuda-maintainers.cachix.org/"
+      "https://llama-cpp.cachix.org"
+      "https://nixos-rocm.cachix.org/"
     ];
     extra-trusted-public-keys = [
       "cuda-maintainers.cachix.org-1:0dq3bujKpuEPMCX6U4WylrUDZ9JyUG0VpVZa7CNfq5E="
       "ff.cachix.org-1:/kyZ0w35ToSJBjpiNfPLrL3zTjuPkUiqf2WH0GIShXM="
+      "nixos-rocm.cachix.org-1:VEpsf7pRIijjd8csKjFNBGzkBqOmw8H9PRmgAq14LnE="
+      "llama-cpp.cachix.org-1:H75X+w83wUKTIPSO1KWy9ADUrzThyGs8P5tmAbkWhQc="
     ];
   };

@@ -29,11 +33,30 @@
       pkgs = import nixpkgs {
         inherit system;
         config.allowUnfree = true;
+        config.rocmSupport = true;
       };
       lib = pkgs.lib;

+      rocm = pkgs.symlinkJoin {
+        name = "rocm";
+        paths = with pkgs.rocmPackages; [
+          rocm-thunk
+          rocm-runtime
+          rocm-device-libs
+          clr
+          hipcc
+          rccl
+          miopen
+          miopengemm
+          miopen-hip
+          hipblas
+          rocm-cmake
+          clr
+        ];
+      };
+
       mkShell = pkgs.mkShell.override {
-        stdenv = pkgs.cudaPackages.backendStdenv;
+        stdenv = pkgs.rocmPackages.llvm.rocmClangStdenv;
       };
     in
     {
@@ -61,6 +84,7 @@
         ci = mkShell {
           shellHook = ''
             export PATH="$HOME/ff/.scripts/:$PATH"
+            echo "ROCm path set to: $ROCM_PATH"
           '';

           CMAKE_FLAGS = lib.strings.concatStringsSep " " [
@@ -89,21 +113,29 @@
             ccache
             pkg-config
             python3
-            cudatoolkit
-            cudaPackages.cuda_nvcc
-            cudaPackages.cudnn
-            cudaPackages.nccl
-            cudaPackages.libcublas
-            cudaPackages.cuda_cudart
+            # cudatoolkit
+            # cudaPackages.cuda_nvcc
+            # cudaPackages.cudnn
+            # cudaPackages.nccl
+            # cudaPackages.libcublas
+            # cudaPackages.cuda_cudart
             tl-expected
-            lcov # for code coverage
-            xdg_utils # for xdg-open to open html files
           ])
           (with self.packages.${system}; [
             legion
             rapidcheckFull
             doctest
           ])
+          [ rocm ]
+          # (with pkgs.rocmPackages; [
+          #   hipcc
+          #   rccl
+          #   miopen
+          #   miopen-hip
+          #   hipblas
+          #   rocm-cmake
+          #   clr
+          # ])
         ];
       };

@@ -122,7 +154,6 @@
             compdb
             jq
             gh
-            lcov # for code coverage
           ])
           (with proj-repo.packages.${system}; [
             proj
@@ -145,4 +176,4 @@
     };
   }
 );
-}
+}
\ No newline at end of file

diff --git a/lib/kernels/CMakeLists.txt b/lib/kernels/CMakeLists.txt
index b2b81c85bd..67ab5c09a0 100644
--- a/lib/kernels/CMakeLists.txt
+++ b/lib/kernels/CMakeLists.txt
@@ -1,13 +1,50 @@
 set(project_target kernels)

+set(CMAKE_CXX_STANDARD 17)
+set(CMAKE_CXX_STANDARD_REQUIRED ON)
+
+message("rocm path: $ENV{ROCM_PATH}")
+
 project(${project_target}
-  LANGUAGES CXX CUDA)
+  LANGUAGES CXX HIP)
+
+message("rocm path after: $ENV{ROCM_PATH}")
+
+
+# if (DEFINED ENV{ROCM_PATH})
+#   set(ROCM_PATH $ENV{ROCM_PATH})
+# else()
+#   message(FATAL_ERROR "ROCM_PATH is not set")
+# endif()
+list(APPEND CMAKE_PREFIX_PATH ${ROCM_PATH})
+if(CXX_IS_HIPCC)
+  if(LINUX)
+    if (NOT ${CMAKE_CXX_COMPILER_ID} MATCHES "Clang")
+      message(WARNING "Only LLVM is supported for HIP, hint: CXX=/opt/rocm/llvm/bin/clang++")
+    endif()
+
+    message(WARNING "Setting hipcc as the C++ compiler is legacy behavior."
+            " Prefer setting the HIP compiler directly. See README for details.")
+  endif()
+else()
+  # Forward AMDGPU_TARGETS to CMAKE_HIP_ARCHITECTURES.
+  if(AMDGPU_TARGETS AND NOT CMAKE_HIP_ARCHITECTURES)
+    set(CMAKE_HIP_ARCHITECTURES ${AMDGPU_TARGETS})
+  endif()
+  cmake_minimum_required(VERSION 3.21)
+  enable_language(HIP)
+endif()
+
+
+find_package(hip REQUIRED)
+find_package(miopen REQUIRED)
+find_package(rccl REQUIRED)

 file(GLOB_RECURSE SRC
      CONFIGURE_DEPENDS
      LIST_DIRECTORIES False
-     src/*.cc
-     src/cuda/ops/*.cu
+     # src/*.cc
+     src/hip/concat_kernels.cpp
 )

 add_library(
@@ -25,9 +62,9 @@ target_include_directories(
   ${project_target}
 target_link_libraries(
   ${project_target}
   op-attrs
-  cuda
-  cudnn
-  nccl
+  MIOpen
+  hip::host
+  rccl
 )

 define_ff_vars(${project_target})

 set_target_properties(
   ${project_target}
   PROPERTIES
-  CUDA_STANDARD 17
-)
+  HIP_STANDARD 17
+)
\ No newline at end of file

diff --git a/lib/utils/CMakeLists.txt b/lib/utils/CMakeLists.txt
index 1ad1f012bb..7215684f7e 100644
--- a/lib/utils/CMakeLists.txt
+++ b/lib/utils/CMakeLists.txt
@@ -12,7 +12,6 @@ ff_add_library(
     visit_struct
     fmt
     json
-    miopen
     # cuda
 )
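Patch 2 moves the dev shell and the kernels target onto the ROCm toolchain (rocmClangStdenv, LANGUAGES CXX HIP, linking hip::host/MIOpen/rccl). Before debugging build failures inside the FlexFlow tree, it can help to confirm the toolchain itself works. A minimal standalone smoke test, assuming nothing from the repo (file name and contents are illustrative), built with `hipcc sanity_check.cpp -o sanity_check`:

    // sanity_check.cpp -- standalone HIP smoke test for the dev shell.
    // Not part of the FlexFlow tree; purely a toolchain check.
    #include <hip/hip_runtime.h>
    #include <cstdio>

    __global__ void saxpy(float a, float const *x, float *y, int n) {
      int i = blockIdx.x * blockDim.x + threadIdx.x;
      if (i < n) {
        y[i] = a * x[i] + y[i];
      }
    }

    int main() {
      int num_devices = 0;
      if (hipGetDeviceCount(&num_devices) != hipSuccess || num_devices == 0) {
        std::printf("no HIP devices visible\n");
        return 1;
      }
      hipDeviceProp_t prop;
      hipGetDeviceProperties(&prop, 0);
      std::printf("device 0: %s (%s)\n", prop.name, prop.gcnArchName);

      // Launch a trivial kernel to confirm codegen for the selected arch.
      int const n = 1024;
      float *x = nullptr;
      float *y = nullptr;
      hipMalloc(&x, n * sizeof(float));
      hipMalloc(&y, n * sizeof(float));
      hipMemset(x, 0, n * sizeof(float));
      hipMemset(y, 0, n * sizeof(float));
      saxpy<<<(n + 255) / 256, 256>>>(2.0f, x, y, n);
      hipDeviceSynchronize();
      std::printf("kernel launched and completed\n");
      hipFree(x);
      hipFree(y);
      return 0;
    }

If this compiles and prints a device name inside the `ci` shell, any remaining failures are in the project's CMake wiring rather than in the Nix-provided ROCm stack.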
"https://llama-cpp.cachix.org" + "https://nixos-rocm.cachix.org/" ]; extra-trusted-public-keys = [ "cuda-maintainers.cachix.org-1:0dq3bujKpuEPMCX6U4WylrUDZ9JyUG0VpVZa7CNfq5E=" "ff.cachix.org-1:/kyZ0w35ToSJBjpiNfPLrL3zTjuPkUiqf2WH0GIShXM=" + "nixos-rocm.cachix.org-1:VEpsf7pRIijjd8csKjFNBGzkBqOmw8H9PRmgAq14LnE=" + "llama-cpp.cachix.org-1:H75X+w83wUKTIPSO1KWy9ADUrzThyGs8P5tmAbkWhQc=" ]; }; @@ -29,11 +33,30 @@ pkgs = import nixpkgs { inherit system; config.allowUnfree = true; + config.rocmSupport = true; }; lib = pkgs.lib; + rocm = pkgs.symlinkJoin { + name = "rocm"; + paths = with pkgs.rocmPackages; [ + rocm-thunk + rocm-runtime + rocm-device-libs + clr + hipcc + rccl + miopen + miopengemm + miopen-hip + hipblas + rocm-cmake + clr + ]; + }; + mkShell = pkgs.mkShell.override { - stdenv = pkgs.cudaPackages.backendStdenv; + stdenv = pkgs.rocmPackages.llvm.rocmClangStdenv; }; in { @@ -61,6 +84,7 @@ ci = mkShell { shellHook = '' export PATH="$HOME/ff/.scripts/:$PATH" + echo "ROCm path set to: $ROCM_PATH" ''; CMAKE_FLAGS = lib.strings.concatStringsSep " " [ @@ -89,21 +113,29 @@ ccache pkg-config python3 - cudatoolkit - cudaPackages.cuda_nvcc - cudaPackages.cudnn - cudaPackages.nccl - cudaPackages.libcublas - cudaPackages.cuda_cudart + # cudatoolkit + # cudaPackages.cuda_nvcc + # cudaPackages.cudnn + # cudaPackages.nccl + # cudaPackages.libcublas + # cudaPackages.cuda_cudart tl-expected - lcov # for code coverage - xdg_utils # for xdg-open to open html files ]) (with self.packages.${system}; [ legion rapidcheckFull doctest ]) + [ rocm ] + # (with pkgs.rocmPackages; [ + # hipcc + # rccl + # miopen + # miopen-hip + # hipblas + # rocm-cmake + # clr + # ]) ]; }; @@ -122,7 +154,6 @@ compdb jq gh - lcov # for code coverage ]) (with proj-repo.packages.${system}; [ proj @@ -145,4 +176,4 @@ }; } ); -} +} \ No newline at end of file diff --git a/lib/kernels/CMakeLists.txt b/lib/kernels/CMakeLists.txt index b2b81c85bd..67ab5c09a0 100644 --- a/lib/kernels/CMakeLists.txt +++ b/lib/kernels/CMakeLists.txt @@ -1,13 +1,50 @@ set(project_target kernels) +set(CMAKE_CXX_STANDARD 17) +set(CMAKE_CXX_STANDARD_REQUIRED ON) + +message("rocm path: $ENV{ROCM_PATH}") + project(${project_target} - LANGUAGES CXX CUDA) + LANGUAGES CXX HIP) + +message("rocm path after: $ENV{ROCM_PATH}") + + +# if (DEFINED ENV{ROCM_PATH}) +# set(ROCM_PATH $ENV{ROCM_PATH}) +# else() + # message(FATAL_ERROR "ROCM_PATH is not set") +# endif() +list(APPEND CMAKE_PREFIX_PATH ${ROCM_PATH}) +if(CXX_IS_HIPCC) + if(LINUX) + if (NOT ${CMAKE_CXX_COMPILER_ID} MATCHES "Clang") + message(WARNING "Only LLVM is supported for HIP, hint: CXX=/opt/rocm/llvm/bin/clang++") + endif() + + message(WARNING "Setting hipcc as the C++ compiler is legacy behavior." + " Prefer setting the HIP compiler directly. See README for details.") + endif() +else() + # Forward AMDGPU_TARGETS to CMAKE_HIP_ARCHITECTURES. 
From 625e13f876854966a267c7452a716b487bd9ffa8 Mon Sep 17 00:00:00 2001
From: Qinghan Chen
Date: Tue, 4 Jun 2024 23:41:46 -0400
Subject: [PATCH 4/9] merge

---
 .editorconfig                           | 22 +
 .flake/pkgs/hpp2plantuml.nix            | 11 +
 .gitattributes                          |  2 +
 .github/workflows/helpers/cmake_cuda.sh |  1 +
 .github/workflows/per-lib-check.yml     | 46 +-
 .proj.toml                              |  7 +-
 CMakeLists.txt                          |  7 +-
 codecov.yml                             | 22 +
 flake.lock                              |  6 +-
 flake.nix                               |  8 +-
lib/compiler/test/src/test_optimal_cost.cc | 7 +- lib/kernels/include/kernels/array_shape.h | 7 +- lib/kernels/include/kernels/cast_kernels.h | 2 +- lib/kernels/include/kernels/conv_2d_kernels.h | 2 +- .../include/kernels/element_binary_kernels.h | 2 +- .../include/kernels/element_unary_kernels.h | 25 +- lib/kernels/include/kernels/gather_kernels.h | 24 +- lib/kernels/include/kernels/legion_dim.h | 8 +- .../include/kernels/legion_dim_t.dtg.h | 54 ++ .../include/kernels/legion_dim_t.struct.toml | 14 + lib/kernels/include/kernels/pool_2d_kernels.h | 2 +- lib/kernels/include/kernels/reduce_kernels.h | 2 +- .../include/kernels/transpose_kernels.h | 2 +- lib/kernels/src/cuda/ops/concat_kernels.cu | 11 +- .../src/cuda/ops/element_binary_kernels.cu | 53 +- .../src/cuda/ops/element_unary_kernels.cu | 159 ++-- lib/kernels/src/cuda/ops/gather_kernels.cu | 162 ++-- lib/kernels/src/cuda/ops/linear_kernels.cu | 5 +- lib/kernels/src/cuda/ops/reduce_kernels.cu | 4 +- lib/kernels/src/cuda/ops/transpose_kernels.cu | 8 +- lib/kernels/src/device.h | 2 +- .../src/hip/element_binary_kernels.cpp | 405 +++++---- lib/kernels/src/hip/element_unary_kernels.cpp | 299 +++--- lib/kernels/src/hip/embedding_kernels.cpp | 448 +++++---- lib/kernels/src/hip/partition_kernels.cpp | 24 +- lib/kernels/src/hip/pool_2d_kernels.cpp | 130 +-- lib/kernels/src/hip/softmax_kernels.cpp | 34 +- lib/kernels/src/hip/split_kernels.cpp | 4 +- lib/kernels/src/hip/topk_kernels.cpp | 16 +- lib/kernels/src/hip/transpose_kernels.cpp | 119 ++- lib/kernels/src/kernels/legion_dim_t.dtg.cc | 69 ++ lib/local-execution/CMakeLists.txt | 1 + .../include/local-execution}/arg_ref.h | 32 +- .../include/local-execution/concrete_arg.h | 55 ++ .../include/local-execution}/config.h | 24 +- .../include/local-execution}/cost_metrics.h | 4 +- .../local-execution}/device_specific.h | 17 +- .../local-execution}/legion_tensor_shape.h | 4 +- .../{ => local-execution}/local_allocator.h | 4 +- .../include/local-execution}/op_arg_ref.h | 16 +- .../local-execution/op_task_invocation.h | 97 ++ .../local-execution}/op_task_signature.h | 60 +- .../include/local-execution/op_tensor_spec.h | 21 + .../include/local-execution}/permissions.h | 9 +- .../include/local-execution}/profiling.h | 13 +- .../local-execution}/runtime_arg_ref.h | 8 +- .../include/local-execution}/serialization.h | 130 +-- .../local-execution}/sim_environment.h | 11 +- .../include/local-execution}/slot_id.h | 4 +- .../include/local-execution}/slot_type.h | 4 +- .../local-execution/task_argument_accessor.h | 155 ++++ .../include/local-execution}/tasks.h | 11 +- .../{ => local-execution}/tracked_allocator.h | 2 +- .../local-execution/variadic_tensor_ref.h | 18 + lib/local-execution/src/local_allocator.cc | 2 +- lib/local-execution/src/op_arg_ref.cc | 14 + lib/local-execution/src/op_task_invocation.cc | 100 ++ lib/local-execution/src/op_task_signature.cc | 81 ++ .../src/ops/attention.cc | 108 ++- .../src/ops/attention.h | 4 +- .../src/ops/batch_matmul.cc | 37 +- .../src/ops/batch_matmul.h | 6 +- .../src/ops/batch_norm.cc | 68 +- .../src/ops/batch_norm.h | 4 +- .../src/ops/cast.cc | 38 +- .../src/ops/cast.h | 4 +- .../src/ops/combine.cc | 35 +- .../src/ops/combine.h | 4 +- .../src/ops/concat.cc | 43 +- .../src/ops/concat.h | 4 +- .../src/ops/conv_2d.cc | 74 +- .../src/ops/conv_2d.h | 4 +- .../src/ops/dropout.cc | 54 +- .../src/ops/dropout.h | 6 +- .../src/ops/element_binary.cc | 62 +- .../src/ops/element_binary.h | 2 +- .../src/ops/element_unary.cc | 60 +- .../src/ops/element_unary.h | 7 +- 
.../src/ops/embedding.cc | 39 +- .../src/ops/embedding.h | 4 +- .../src/ops/flat.cc | 33 +- .../src/ops/flat.h | 2 +- lib/local-execution/src/ops/gather.cc | 215 +++++ lib/local-execution/src/ops/gather.h | 30 + .../src/ops/layer_norm.cc | 134 ++- .../src/ops/layer_norm.h | 4 +- .../src/ops/linear.cc | 206 ++--- .../src/ops/linear.h | 4 +- .../src/ops/noop.cc | 2 +- .../src/ops/noop.h | 2 +- .../src/ops/parallel_op.h | 2 +- .../src/ops/partition.cc | 113 +-- .../src/ops/pool_2d.cc | 140 ++- .../src/ops/pool_2d.h | 6 +- .../src/ops/reduce.cc | 112 +-- .../src/ops/reduce.h | 4 +- .../src/ops/reduction.cc | 92 +- .../src/ops/reduction.h | 6 +- .../src/ops/repartition.h | 4 +- .../src/ops/replicate.cc | 72 +- .../src/ops/replicate.h | 4 +- .../src/ops/reshape.cc | 78 +- .../src/ops/reshape.h | 4 +- .../src/ops/reverse.cc | 81 +- .../src/ops/reverse.h | 4 +- .../src/ops/softmax.cc | 92 +- .../src/ops/softmax.h | 4 +- .../src/ops/split.cc | 136 ++- .../src/ops/split.h | 6 +- .../src/ops/topk.cc | 130 +-- .../src/ops/topk.h | 4 +- .../src/ops/transpose.cc | 114 +-- .../src/ops/transpose.h | 4 +- .../src/permissions.cc | 32 +- .../src}/runtime_arg_ref.cc | 8 +- lib/local-execution/src/tracked_allocator.cc | 2 +- .../src/variadic_tensor_ref.cc | 9 + lib/op-attrs/CMakeLists.txt | 1 + .../include/op-attrs/activation.dtg.h | 40 + .../include/op-attrs/activation.enum.toml | 20 + lib/op-attrs/include/op-attrs/activation.h | 42 - .../include/op-attrs/aggregate_op.dtg.h | 40 + .../include/op-attrs/aggregate_op.enum.toml | 14 + lib/op-attrs/include/op-attrs/as_dot.h | 15 + .../op-attrs/computation_graph_op_attrs.dtg.h | 471 ++++++++++ .../op-attrs/computation_graph_op_attrs.h | 12 + .../computation_graph_op_attrs.variant.toml | 148 +++ lib/op-attrs/include/op-attrs/datatype.dtg.h | 40 + .../include/op-attrs/datatype.enum.toml | 26 + lib/op-attrs/include/op-attrs/datatype.h | 40 +- lib/op-attrs/include/op-attrs/dim_ordered.h | 69 +- .../include/op-attrs/dim_ordered/slice.h | 54 ++ .../include/op-attrs/dim_ordered/transform.h | 20 + lib/op-attrs/include/op-attrs/ff_dim.dtg.h | 54 ++ lib/op-attrs/include/op-attrs/ff_dim.h | 26 +- .../include/op-attrs/ff_dim.struct.toml | 14 + lib/op-attrs/include/op-attrs/get_op_type.h | 47 +- .../include/op-attrs/get_output_shapes.h | 20 - .../op-attrs/l1_regularizer_attrs.dtg.h | 62 ++ .../op-attrs/l1_regularizer_attrs.struct.toml | 14 + .../op-attrs/l2_regularizer_attrs.dtg.h | 62 ++ .../op-attrs/l2_regularizer_attrs.struct.toml | 14 + lib/op-attrs/include/op-attrs/op.h | 369 -------- .../include/op-attrs/operator_attrs.h | 89 +- .../include/op-attrs/operator_type.dtg.h | 124 +++ .../include/op-attrs/operator_type.enum.toml | 95 ++ lib/op-attrs/include/op-attrs/operator_type.h | 13 + lib/op-attrs/include/op-attrs/ops/attention.h | 93 +- .../multihead_attention_inputs.dtg.h | 74 ++ .../attention/multihead_attention_inputs.h | 17 + .../multihead_attention_inputs.struct.toml | 39 + .../multihead_attention_parallel_inputs.dtg.h | 82 ++ .../multihead_attention_parallel_inputs.h | 17 + ...head_attention_parallel_inputs.struct.toml | 46 + .../op-attrs/ops/attention_attrs.dtg.h | 76 ++ .../op-attrs/ops/attention_attrs.struct.toml | 43 + .../include/op-attrs/ops/batch_matmul.dtg.h | 63 ++ .../include/op-attrs/ops/batch_matmul.h | 28 +- .../op-attrs/ops/batch_matmul.struct.toml | 19 + .../include/op-attrs/ops/batch_norm.h | 10 +- .../op-attrs/ops/batch_norm_attrs.dtg.h | 62 ++ .../op-attrs/ops/batch_norm_attrs.struct.toml | 15 + .../include/op-attrs/ops/broadcast.dtg.h | 64 ++ 
lib/op-attrs/include/op-attrs/ops/broadcast.h | 17 +- .../op-attrs/ops/broadcast.struct.toml | 18 + lib/op-attrs/include/op-attrs/ops/cast.h | 9 +- .../include/op-attrs/ops/cast_attrs.dtg.h | 63 ++ .../op-attrs/ops/cast_attrs.struct.toml | 18 + lib/op-attrs/include/op-attrs/ops/combine.h | 20 +- .../include/op-attrs/ops/combine_attrs.dtg.h | 66 ++ .../op-attrs/ops/combine_attrs.struct.toml | 23 + lib/op-attrs/include/op-attrs/ops/concat.h | 9 +- .../include/op-attrs/ops/concat_attrs.dtg.h | 65 ++ .../op-attrs/ops/concat_attrs.struct.toml | 23 + lib/op-attrs/include/op-attrs/ops/conv_2d.h | 35 +- .../ops/conv_2d/conv_2d_input_shape.dtg.h | 72 ++ .../ops/conv_2d/conv_2d_input_shape.h | 13 + .../conv_2d/conv_2d_input_shape.struct.toml | 35 + .../conv_2d_parallel_input_shape.dtg.h | 76 ++ .../conv_2d/conv_2d_parallel_input_shape.h | 14 + .../conv_2d_parallel_input_shape.struct.toml | 43 + .../include/op-attrs/ops/conv_2d_attrs.dtg.h | 83 ++ .../op-attrs/ops/conv_2d_attrs.struct.toml | 29 + lib/op-attrs/include/op-attrs/ops/dropout.h | 12 +- .../include/op-attrs/ops/dropout_attrs.dtg.h | 63 ++ .../op-attrs/ops/dropout_attrs.struct.toml | 19 + .../include/op-attrs/ops/element_binary.h | 25 +- .../op-attrs/ops/element_binary_attrs.dtg.h | 70 ++ .../ops/element_binary_attrs.struct.toml | 32 + .../ops/element_scalar_unary_attrs.dtg.h | 65 ++ .../element_scalar_unary_attrs.struct.toml | 23 + .../include/op-attrs/ops/element_unary.h | 34 +- .../op-attrs/ops/element_unary_attrs.dtg.h | 63 ++ .../ops/element_unary_attrs.struct.toml | 19 + lib/op-attrs/include/op-attrs/ops/embedding.h | 49 +- .../op-attrs/ops/embedding_attrs.dtg.h | 71 ++ .../op-attrs/ops/embedding_attrs.struct.toml | 32 + lib/op-attrs/include/op-attrs/ops/flat.h | 4 +- .../include/op-attrs/ops/flat_attrs.dtg.h | 58 ++ .../op-attrs/ops/flat_attrs.struct.toml | 11 + lib/op-attrs/include/op-attrs/ops/gather.h | 7 +- .../include/op-attrs/ops/gather_attrs.dtg.h | 64 ++ .../op-attrs/ops/gather_attrs.struct.toml | 19 + lib/op-attrs/include/op-attrs/ops/input.h | 7 +- .../include/op-attrs/ops/input_attrs.dtg.h | 58 ++ .../op-attrs/ops/input_attrs.struct.toml | 11 + .../include/op-attrs/ops/layer_norm.h | 12 +- .../op-attrs/ops/layer_norm_attrs.dtg.h | 70 ++ .../op-attrs/ops/layer_norm_attrs.struct.toml | 28 + lib/op-attrs/include/op-attrs/ops/linear.h | 49 +- .../include/op-attrs/ops/linear_attrs.dtg.h | 74 ++ .../op-attrs/ops/linear_attrs.struct.toml | 37 + lib/op-attrs/include/op-attrs/ops/noop.h | 8 +- .../include/op-attrs/ops/noop_attrs.dtg.h | 58 ++ .../op-attrs/ops/noop_attrs.struct.toml | 11 + .../ops/parallel_attention_inputs.dtg.h | 66 ++ .../ops/parallel_attention_inputs.struct.toml | 26 + lib/op-attrs/include/op-attrs/ops/pool_2d.h | 51 +- .../include/op-attrs/ops/pool_2d_attrs.dtg.h | 78 ++ .../op-attrs/ops/pool_2d_attrs.struct.toml | 47 + lib/op-attrs/include/op-attrs/ops/reduce.h | 18 +- .../include/op-attrs/ops/reduce_attrs.dtg.h | 71 ++ .../op-attrs/ops/reduce_attrs.struct.toml | 29 + lib/op-attrs/include/op-attrs/ops/reduction.h | 17 +- .../op-attrs/ops/reduction_attrs.dtg.h | 62 ++ .../op-attrs/ops/reduction_attrs.struct.toml | 14 + .../include/op-attrs/ops/repartition.h | 17 +- .../op-attrs/ops/repartition_attrs.dtg.h | 66 ++ .../ops/repartition_attrs.struct.toml | 23 + lib/op-attrs/include/op-attrs/ops/replicate.h | 13 +- .../op-attrs/ops/replicate_attrs.dtg.h | 62 ++ .../op-attrs/ops/replicate_attrs.struct.toml | 16 + lib/op-attrs/include/op-attrs/ops/reshape.h | 11 +- .../include/op-attrs/ops/reshape_attrs.dtg.h | 63 ++ 
.../op-attrs/ops/reshape_attrs.struct.toml | 18 + lib/op-attrs/include/op-attrs/ops/reverse.h | 11 +- .../include/op-attrs/ops/reverse_attrs.dtg.h | 64 ++ .../op-attrs/ops/reverse_attrs.struct.toml | 19 + lib/op-attrs/include/op-attrs/ops/softmax.h | 12 +- .../include/op-attrs/ops/softmax_attrs.dtg.h | 64 ++ .../op-attrs/ops/softmax_attrs.struct.toml | 19 + lib/op-attrs/include/op-attrs/ops/split.h | 15 +- .../include/op-attrs/ops/split_attrs.dtg.h | 67 ++ .../op-attrs/ops/split_attrs.struct.toml | 24 + lib/op-attrs/include/op-attrs/ops/topk.h | 12 +- .../include/op-attrs/ops/topk_attrs.dtg.h | 63 ++ .../op-attrs/ops/topk_attrs.struct.toml | 18 + lib/op-attrs/include/op-attrs/ops/transpose.h | 12 +- .../op-attrs/ops/transpose_attrs.dtg.h | 65 ++ .../op-attrs/ops/transpose_attrs.struct.toml | 20 + .../include/op-attrs/ops/weight_attrs.dtg.h | 58 ++ .../op-attrs/ops/weight_attrs.struct.toml | 11 + .../include/op-attrs/parallel_dim.dtg.h | 128 +++ lib/op-attrs/include/op-attrs/parallel_dim.h | 17 +- .../op-attrs/parallel_dim.variant.toml | 23 + .../op-attrs/parallel_tensor_dims.dtg.h | 71 ++ .../include/op-attrs/parallel_tensor_dims.h | 57 +- .../op-attrs/parallel_tensor_dims.struct.toml | 27 + .../op-attrs/parallel_tensor_shape.dtg.h | 66 ++ .../include/op-attrs/parallel_tensor_shape.h | 52 +- .../parallel_tensor_shape.struct.toml | 23 + .../discard_copy_degree.dtg.h | 62 ++ .../discard_copy_degree.struct.toml | 14 + .../parallel_tensor_shape/sum_degree.dtg.h | 62 ++ .../sum_degree.struct.toml | 14 + .../include/op-attrs/param_sync.dtg.h | 40 + .../include/op-attrs/param_sync.enum.toml | 14 + lib/op-attrs/include/op-attrs/param_sync.h | 32 +- .../include/op-attrs/pcg_operator_attrs.dtg.h | 495 ++++++++++ .../include/op-attrs/pcg_operator_attrs.h | 13 + .../op-attrs/pcg_operator_attrs.variant.toml | 158 ++++ lib/op-attrs/include/op-attrs/pool_op.dtg.h | 40 + .../include/op-attrs/pool_op.enum.toml | 14 + .../include/op-attrs/regularizer_attrs.dtg.h | 128 +++ .../op-attrs/regularizer_attrs.variant.toml | 23 + .../op-attrs/replica_parallel_dim.dtg.h | 65 ++ .../include/op-attrs/replica_parallel_dim.h | 12 + .../op-attrs/replica_parallel_dim.struct.toml | 22 + .../op-attrs/replica_parallel_dim_set.dtg.h | 67 ++ .../op-attrs/replica_parallel_dim_set.h | 18 + .../replica_parallel_dim_set.struct.toml | 23 + .../include/op-attrs/replica_type.dtg.h | 40 + .../include/op-attrs/replica_type.enum.toml | 14 + .../include/op-attrs/shard_parallel_dim.dtg.h | 63 ++ .../include/op-attrs/shard_parallel_dim.h | 12 + .../op-attrs/shard_parallel_dim.struct.toml | 18 + .../include/op-attrs/tensor_dims.dtg.h | 63 ++ lib/op-attrs/include/op-attrs/tensor_dims.h | 24 + .../include/op-attrs/tensor_dims.struct.toml | 17 + .../include/op-attrs/tensor_shape.dtg.h | 66 ++ lib/op-attrs/include/op-attrs/tensor_shape.h | 28 +- .../include/op-attrs/tensor_shape.struct.toml | 23 + lib/op-attrs/src/batch_matmul.cc | 26 - lib/op-attrs/src/batch_norm.cc | 3 - lib/op-attrs/src/broadcast.cc | 3 - lib/op-attrs/src/combine.cc | 18 - lib/op-attrs/src/conv_2d.cc | 115 --- lib/op-attrs/src/element_binary.cc | 3 - lib/op-attrs/src/element_unary.cc | 3 - lib/op-attrs/src/embedding.cc | 9 - lib/op-attrs/src/linear.cc | 3 - lib/op-attrs/src/noop.cc | 1 - lib/op-attrs/src/op-attrs/activation.dtg.cc | 86 ++ lib/op-attrs/src/op-attrs/aggregate_op.dtg.cc | 62 ++ lib/op-attrs/src/op-attrs/as_dot.cc | 13 + .../op-attrs/computation_graph_op_attrs.cc | 11 + .../computation_graph_op_attrs.dtg.cc | 597 ++++++++++++ 
lib/op-attrs/src/op-attrs/datatype.dtg.cc | 102 +++ lib/op-attrs/src/op-attrs/ff_dim.dtg.cc | 68 ++ .../src/{ => op-attrs}/get_op_type.cc | 61 +- .../src/op-attrs/l1_regularizer_attrs.dtg.cc | 76 ++ .../src/op-attrs/l2_regularizer_attrs.dtg.cc | 76 ++ lib/op-attrs/src/op-attrs/operator_type.cc | 24 + .../src/op-attrs/operator_type.dtg.cc | 720 +++++++++++++++ .../src/{ => op-attrs/ops}/attention.cc | 180 +++- .../attention/multihead_attention_inputs.cc | 80 ++ .../multihead_attention_inputs.dtg.cc | 185 ++++ .../multihead_attention_parallel_inputs.cc | 132 +++ ...multihead_attention_parallel_inputs.dtg.cc | 209 +++++ .../src/op-attrs/ops/attention_attrs.dtg.cc | 220 +++++ lib/op-attrs/src/op-attrs/ops/batch_matmul.cc | 176 ++++ .../src/op-attrs/ops/batch_matmul.dtg.cc | 90 ++ lib/op-attrs/src/op-attrs/ops/batch_norm.cc | 10 + .../src/op-attrs/ops/batch_norm_attrs.dtg.cc | 75 ++ .../src/op-attrs/ops/broadcast.dtg.cc | 81 ++ lib/op-attrs/src/{ => op-attrs/ops}/cast.cc | 0 .../src/op-attrs/ops/cast_attrs.dtg.cc | 76 ++ lib/op-attrs/src/op-attrs/ops/combine.cc | 37 + .../src/op-attrs/ops/combine_attrs.dtg.cc | 91 ++ lib/op-attrs/src/{ => op-attrs/ops}/concat.cc | 0 .../src/op-attrs/ops/concat_attrs.dtg.cc | 91 ++ lib/op-attrs/src/op-attrs/ops/conv_2d.cc | 176 ++++ .../ops/conv_2d/conv_2d_input_shape.cc | 23 + .../ops/conv_2d/conv_2d_input_shape.dtg.cc | 157 ++++ .../conv_2d/conv_2d_parallel_input_shape.cc | 26 + .../conv_2d_parallel_input_shape.dtg.cc | 211 +++++ .../src/op-attrs/ops/conv_2d_attrs.dtg.cc | 256 ++++++ lib/op-attrs/src/op-attrs/ops/dropout.cc | 10 + .../src/op-attrs/ops/dropout_attrs.dtg.cc | 82 ++ .../src/op-attrs/ops/element_binary.cc | 75 ++ .../op-attrs/ops/element_binary_attrs.dtg.cc | 145 +++ .../ops/element_scalar_unary_attrs.dtg.cc | 98 ++ .../src/op-attrs/ops/element_unary.cc | 54 ++ .../op-attrs/ops/element_unary_attrs.dtg.cc | 79 ++ lib/op-attrs/src/op-attrs/ops/embedding.cc | 112 +++ .../src/op-attrs/ops/embedding_attrs.dtg.cc | 139 +++ lib/op-attrs/src/{ => op-attrs/ops}/flat.cc | 2 - .../src/op-attrs/ops/flat_attrs.dtg.cc | 70 ++ lib/op-attrs/src/{ => op-attrs/ops}/gather.cc | 0 .../src/op-attrs/ops/gather_attrs.dtg.cc | 78 ++ lib/op-attrs/src/op-attrs/ops/input.cc | 9 + .../src/op-attrs/ops/input_attrs.dtg.cc | 70 ++ lib/op-attrs/src/op-attrs/ops/layer_norm.cc | 10 + .../src/op-attrs/ops/layer_norm_attrs.dtg.cc | 108 +++ lib/op-attrs/src/op-attrs/ops/linear.cc | 110 +++ .../src/op-attrs/ops/linear_attrs.dtg.cc | 162 ++++ lib/op-attrs/src/op-attrs/ops/noop.cc | 10 + .../src/op-attrs/ops/noop_attrs.dtg.cc | 70 ++ .../ops/parallel_attention_inputs.dtg.cc | 88 ++ .../src/{ => op-attrs/ops}/pool_2d.cc | 29 +- .../src/op-attrs/ops/pool_2d_attrs.dtg.cc | 214 +++++ lib/op-attrs/src/op-attrs/ops/reduce.cc | 10 + .../src/op-attrs/ops/reduce_attrs.dtg.cc | 109 +++ lib/op-attrs/src/op-attrs/ops/reduction.cc | 22 + .../src/op-attrs/ops/reduction_attrs.dtg.cc | 76 ++ lib/op-attrs/src/op-attrs/ops/repartition.cc | 14 + .../src/op-attrs/ops/repartition_attrs.dtg.cc | 93 ++ lib/op-attrs/src/op-attrs/ops/replicate.cc | 13 + .../src/op-attrs/ops/replicate_attrs.dtg.cc | 76 ++ lib/op-attrs/src/op-attrs/ops/reshape.cc | 10 + .../src/op-attrs/ops/reshape_attrs.dtg.cc | 78 ++ lib/op-attrs/src/op-attrs/ops/reverse.cc | 10 + .../src/op-attrs/ops/reverse_attrs.dtg.cc | 78 ++ lib/op-attrs/src/op-attrs/ops/softmax.cc | 10 + .../src/op-attrs/ops/softmax_attrs.dtg.cc | 78 ++ lib/op-attrs/src/op-attrs/ops/split.cc | 11 + .../src/op-attrs/ops/split_attrs.dtg.cc | 96 ++ 
lib/op-attrs/src/op-attrs/ops/topk.cc | 10 + .../src/op-attrs/ops/topk_attrs.dtg.cc | 79 ++ lib/op-attrs/src/op-attrs/ops/transpose.cc | 10 + .../src/op-attrs/ops/transpose_attrs.dtg.cc | 82 ++ .../src/op-attrs/ops/weight_attrs.dtg.cc | 70 ++ lib/op-attrs/src/op-attrs/parallel_dim.dtg.cc | 116 +++ .../src/op-attrs/parallel_tensor_dims.cc | 72 ++ .../src/op-attrs/parallel_tensor_dims.dtg.cc | 101 +++ .../src/op-attrs/parallel_tensor_shape.cc | 87 ++ .../src/op-attrs/parallel_tensor_shape.dtg.cc | 94 ++ .../discard_copy_degree.dtg.cc | 76 ++ .../parallel_tensor_shape/sum_degree.dtg.cc | 75 ++ lib/op-attrs/src/op-attrs/param_sync.dtg.cc | 70 ++ .../src/op-attrs/pcg_operator_attrs.cc | 16 + .../src/op-attrs/pcg_operator_attrs.dtg.cc | 599 ++++++++++++ lib/op-attrs/src/op-attrs/pool_op.dtg.cc | 70 ++ .../src/op-attrs/regularizer_attrs.dtg.cc | 118 +++ .../src/op-attrs/replica_parallel_dim.cc | 9 + .../src/op-attrs/replica_parallel_dim.dtg.cc | 91 ++ .../src/op-attrs/replica_parallel_dim_set.cc | 36 + .../op-attrs/replica_parallel_dim_set.dtg.cc | 101 +++ lib/op-attrs/src/op-attrs/replica_type.dtg.cc | 70 ++ .../src/op-attrs/shard_parallel_dim.cc | 9 + .../src/op-attrs/shard_parallel_dim.dtg.cc | 89 ++ lib/op-attrs/src/op-attrs/tensor_dims.cc | 52 ++ lib/op-attrs/src/op-attrs/tensor_dims.dtg.cc | 78 ++ lib/op-attrs/src/op-attrs/tensor_shape.cc | 18 + lib/op-attrs/src/op-attrs/tensor_shape.dtg.cc | 92 ++ lib/op-attrs/src/op.cc | 24 - lib/op-attrs/src/operator_attrs.cc | 2 +- lib/op-attrs/src/parallel_dim.cc | 12 - .../src/parallel_dim_mapping_record_solver.cc | 362 -------- .../src/parallel_dim_mapping_record_solver.h | 106 --- lib/op-attrs/src/parallel_tensor_shape.cc | 97 -- lib/op-attrs/src/reduce.cc | 3 - lib/op-attrs/src/reduction.cc | 13 - lib/op-attrs/src/repartition.cc | 11 - lib/op-attrs/src/replicate.cc | 3 - lib/op-attrs/src/reshape.cc | 3 - lib/op-attrs/src/softmax.cc | 3 - lib/op-attrs/src/split.cc | 3 - lib/op-attrs/src/tensor_shape.cc | 13 - lib/op-attrs/src/topk.cc | 3 - lib/op-attrs/src/transpose.cc | 3 - lib/op-attrs/test/CMakeLists.txt | 13 + lib/op-attrs/test/src/dim_ordered/slice.cc | 23 + lib/op-attrs/test/src/ops/combine.cc | 59 ++ lib/op-attrs/test/src/ops/linear.cc | 231 +++++ lib/op-attrs/test/src/ops/reduction.cc | 55 ++ lib/op-attrs/test/src/ops/repartition.cc | 40 + lib/op-attrs/test/src/ops/replicate.cc | 33 + lib/op-attrs/test/src/test_attention.cc | 272 ++++++ lib/op-attrs/test/src/test_batch_matmul.cc | 268 ++++++ lib/op-attrs/test/src/test_conv_2d.cc | 62 ++ lib/op-attrs/test/src/test_dim_ordered.cc | 13 + lib/op-attrs/test/src/test_element_binary.cc | 162 ++++ lib/op-attrs/test/src/test_element_unary.cc | 73 ++ lib/op-attrs/test/src/test_embedding.cc | 160 ++++ lib/op-attrs/test/src/test_operator_attrs.cc | 35 + .../test/src/test_regularizer_attrs.cc | 14 + lib/pcg/CMakeLists.txt | 1 + lib/pcg/include/pcg/computation_graph.dtg.h | 29 + lib/pcg/include/pcg/computation_graph.h | 60 +- .../include/pcg/computation_graph.struct.toml | 13 + .../layer_added_result.dtg.h | 37 + .../layer_added_result.struct.toml | 19 + .../include/pcg/computation_graph_builder.h | 115 ++- lib/pcg/include/pcg/cpu_id_t.dtg.h | 62 ++ lib/pcg/include/pcg/cpu_id_t.struct.toml | 14 + lib/pcg/include/pcg/create_grad.dtg.h | 40 + lib/pcg/include/pcg/create_grad.enum.toml | 14 + lib/pcg/include/pcg/create_grad.h | 32 +- lib/pcg/include/pcg/dataflow_graph.h | 77 ++ lib/pcg/include/pcg/dataflow_input.dtg.h | 101 +++ .../include/pcg/dataflow_input.variant.toml | 21 + lib/pcg/include/pcg/device_id.h 
| 24 +- lib/pcg/include/pcg/device_id_t.dtg.h | 117 +++ lib/pcg/include/pcg/device_id_t.variant.toml | 22 + lib/pcg/include/pcg/device_type.dtg.h | 40 + lib/pcg/include/pcg/device_type.enum.toml | 14 + lib/pcg/include/pcg/device_type.h | 36 - .../v1/{data_type.h => data_type_value.h} | 17 - lib/pcg/include/pcg/file_format/v1/graphs.h | 68 +- .../file_format/v1/graphs/v1_graph_edge.dtg.h | 60 ++ .../v1/graphs/v1_graph_edge.struct.toml | 26 + .../v1/graphs/v1_graph_output.dtg.h | 55 ++ .../v1/graphs/v1_graph_output.struct.toml | 18 + .../v1/graphs/v1_jsonable_graph.dtg.h | 109 +++ .../v1/graphs/v1_jsonable_graph.struct.toml | 38 + .../v1/graphs/v1_multidigraph.dtg.h | 47 + .../file_format/v1/graphs/v1_multidigraph.h | 16 + .../v1/graphs/v1_multidigraph.struct.toml | 29 + .../v1/graphs/v1_operator_graph.dtg.h | 45 + .../v1/graphs/v1_operator_graph.struct.toml | 25 + .../include/pcg/file_format/v1/initializer.h | 57 -- .../pcg/file_format/v1/operator_attrs.h | 20 - .../pcg/file_format/v1/parallel_tensor.h | 37 - .../include/pcg/file_format/v1/param_sync.h | 16 - lib/pcg/include/pcg/file_format/v1/tensor.h | 36 - lib/pcg/include/pcg/gpu_id_t.dtg.h | 62 ++ lib/pcg/include/pcg/gpu_id_t.struct.toml | 14 + lib/pcg/include/pcg/initializer.h | 50 - lib/pcg/include/pcg/initializer_attrs.dtg.h | 169 ++++ .../pcg/initializer_attrs.variant.toml | 37 + .../constant_initializer_attrs.dtg.h | 56 ++ .../constant_initializer_attrs.struct.toml | 19 + .../initializers/glorot_uniform_attrs.dtg.h | 62 ++ .../glorot_uniform_attrs.struct.toml | 14 + .../initializers/norm_initializer_attrs.dtg.h | 64 ++ .../norm_initializer_attrs.struct.toml | 22 + .../uniform_initializer_attrs.dtg.h | 58 ++ .../uniform_initializer_attrs.struct.toml | 22 + .../initializers/zero_initializer_attrs.dtg.h | 58 ++ .../zero_initializer_attrs.struct.toml | 11 + lib/pcg/include/pcg/layer.h | 33 - lib/pcg/include/pcg/layer_attrs.dtg.h | 60 ++ lib/pcg/include/pcg/layer_attrs.struct.toml | 26 + lib/pcg/include/pcg/layer_guid_t.dtg.h | 46 + lib/pcg/include/pcg/layer_guid_t.struct.toml | 16 + .../include/pcg/machine_specification.dtg.h | 62 ++ lib/pcg/include/pcg/machine_specification.h | 27 +- .../pcg/machine_specification.struct.toml | 30 + lib/pcg/include/pcg/machine_view.dtg.h | 58 ++ lib/pcg/include/pcg/machine_view.h | 27 +- lib/pcg/include/pcg/machine_view.struct.toml | 23 + lib/pcg/include/pcg/num_points_t.dtg.h | 62 ++ lib/pcg/include/pcg/num_points_t.struct.toml | 14 + lib/pcg/include/pcg/open_dataflow_graph.h | 81 ++ lib/pcg/include/pcg/operator.h | 27 - .../pcg/operator_graph/operator_graph.h | 80 ++ .../operator_graph/operator_graph_input.dtg.h | 47 + .../pcg/operator_graph/operator_graph_input.h | 13 + .../operator_graph_input.struct.toml | 20 + .../operator_graph_output.dtg.h | 47 + .../operator_graph/operator_graph_output.h | 13 + .../operator_graph_output.struct.toml | 20 + lib/pcg/include/pcg/operator_guid_t.dtg.h | 46 + lib/pcg/include/pcg/operator_guid_t.h | 18 - .../include/pcg/operator_guid_t.struct.toml | 18 + lib/pcg/include/pcg/optimizer.h | 41 - lib/pcg/include/pcg/optimizer_attrs.h | 14 + .../pcg/optimizers/adam_optimizer_attrs.dtg.h | 74 ++ .../adam_optimizer_attrs.struct.toml | 38 + .../pcg/optimizers/sgd_optimizer_attrs.dtg.h | 68 ++ .../sgd_optimizer_attrs.struct.toml | 26 + .../pcg/parallel_computation_graph.dtg.h | 31 + .../include/pcg/parallel_computation_graph.h | 27 +- .../parallel_computation_graph.struct.toml | 13 + .../include/pcg/parallel_layer_attrs.dtg.h | 60 ++ .../pcg/parallel_layer_attrs.struct.toml | 
24 + lib/pcg/include/pcg/parallel_tensor.h | 50 +- .../include/pcg/parallel_tensor_attrs.dtg.h | 66 ++ .../pcg/parallel_tensor_attrs.struct.toml | 34 + lib/pcg/include/pcg/serialization.h | 61 -- lib/pcg/include/pcg/side_size_t.dtg.h | 62 ++ lib/pcg/include/pcg/side_size_t.struct.toml | 14 + lib/pcg/include/pcg/strided_rectangle.dtg.h | 65 ++ lib/pcg/include/pcg/strided_rectangle.h | 58 +- .../include/pcg/strided_rectangle.struct.toml | 19 + .../include/pcg/strided_rectangle_side.dtg.h | 65 ++ lib/pcg/include/pcg/strided_rectangle_side.h | 15 + .../pcg/strided_rectangle_side.struct.toml | 22 + lib/pcg/include/pcg/tensor.h | 40 - lib/pcg/include/pcg/tensor_attrs.dtg.h | 64 ++ lib/pcg/include/pcg/tensor_attrs.struct.toml | 33 + lib/pcg/include/pcg/tensor_guid_t.dtg.h | 46 + lib/pcg/include/pcg/tensor_guid_t.h | 17 - lib/pcg/include/pcg/tensor_guid_t.struct.toml | 18 + lib/pcg/src/computation_graph.cc | 76 -- lib/pcg/src/device_id.cc | 19 - lib/pcg/src/file_format/v1/graphs.cc | 129 ++- lib/pcg/src/file_format/v1/v1.cc | 13 - lib/pcg/src/layer.cc | 9 - lib/pcg/src/machine_view.cc | 29 - lib/pcg/src/operator.cc | 9 - lib/pcg/src/parallel_computation_graph.cc | 40 - lib/pcg/src/parallel_tensor.cc | 17 - lib/pcg/src/pcg/computation_graph.cc | 20 + lib/pcg/src/pcg/computation_graph.dtg.cc | 21 + .../layer_added_result.dtg.cc | 43 + .../{ => pcg}/computation_graph_builder.cc | 384 ++++---- lib/pcg/src/pcg/cpu_id_t.dtg.cc | 74 ++ lib/pcg/src/pcg/create_grad.dtg.cc | 70 ++ lib/pcg/src/pcg/dataflow_input.dtg.cc | 41 + lib/pcg/src/pcg/device_id.cc | 32 + lib/pcg/src/pcg/device_id_t.dtg.cc | 103 +++ lib/pcg/src/pcg/device_type.dtg.cc | 70 ++ .../v1/graphs/v1_graph_edge.dtg.cc | 94 ++ .../v1/graphs/v1_graph_output.dtg.cc | 81 ++ .../v1/graphs/v1_jsonable_graph.dtg.cc | 10 + .../v1/graphs/v1_multidigraph.dtg.cc | 56 ++ .../v1/graphs/v1_operator_graph.dtg.cc | 52 ++ lib/pcg/src/pcg/gpu_id_t.dtg.cc | 74 ++ lib/pcg/src/pcg/initializer_attrs.dtg.cc | 158 ++++ .../constant_initializer_attrs.dtg.cc | 80 ++ .../initializers/glorot_uniform_attrs.dtg.cc | 76 ++ .../norm_initializer_attrs.dtg.cc | 96 ++ .../uniform_initializer_attrs.dtg.cc | 95 ++ .../zero_initializer_attrs.dtg.cc | 71 ++ lib/pcg/src/pcg/layer_attrs.dtg.cc | 84 ++ lib/pcg/src/pcg/layer_guid_t.dtg.cc | 59 ++ lib/pcg/src/pcg/machine_specification.dtg.cc | 151 ++++ lib/pcg/src/pcg/machine_view.cc | 63 ++ lib/pcg/src/pcg/machine_view.dtg.cc | 78 ++ lib/pcg/src/pcg/num_points_t.dtg.cc | 75 ++ .../src/pcg/operator_graph/operator_graph.cc | 48 + .../operator_graph/operator_graph_input.cc | 13 + .../operator_graph_input.dtg.cc | 63 ++ .../operator_graph/operator_graph_output.cc | 13 + .../operator_graph_output.dtg.cc | 63 ++ lib/pcg/src/pcg/operator_guid_t.dtg.cc | 59 ++ .../optimizers/adam_optimizer_attrs.dtg.cc | 192 ++++ .../pcg/optimizers/sgd_optimizer_attrs.dtg.cc | 111 +++ .../src/pcg/parallel_computation_graph.dtg.cc | 21 + lib/pcg/src/pcg/parallel_layer_attrs.dtg.cc | 83 ++ lib/pcg/src/pcg/parallel_tensor_attrs.dtg.cc | 134 +++ lib/pcg/src/pcg/side_size_t.dtg.cc | 75 ++ lib/pcg/src/pcg/strided_rectangle.dtg.cc | 86 ++ lib/pcg/src/pcg/strided_rectangle_side.cc | 15 + lib/pcg/src/pcg/strided_rectangle_side.dtg.cc | 91 ++ lib/pcg/src/pcg/tensor_attrs.dtg.cc | 133 +++ lib/pcg/src/pcg/tensor_guid_t.dtg.cc | 59 ++ lib/pcg/src/serialization.cc | 3 - lib/pcg/src/strided_rectangle.cc | 39 +- lib/pcg/src/tensor.cc | 17 - lib/pcg/test/CMakeLists.txt | 13 + .../src/test_computation_graph_builder.cc | 28 + .../include/runtime/task_spec/concrete_arg.h | 46 - 
lib/runtime/src/ops/gather.cc | 416 --------- lib/runtime/src/ops/gather.h | 78 -- lib/runtime/src/parallel_op_info.h | 2 +- .../src/task_spec/op_task_invocation.cc | 47 - .../src/task_spec/op_task_invocation.h | 135 --- lib/runtime/src/task_spec/op_tensor_spec.h | 20 - .../src/task_spec/task_argument_accessor.h | 193 ---- .../src/task_spec/variadic_tensor_ref.h | 20 - .../include/substitution-generator/json.h | 162 +--- .../legacy_operator_type.dtg.h | 124 +++ .../legacy_operator_type.enum.toml | 95 ++ .../legacy_pm_parameter.dtg.h | 73 ++ .../legacy_pm_parameter.enum.toml | 44 + .../src/substitution-generator/json.cc | 25 +- .../legacy_operator_type.dtg.cc | 721 +++++++++++++++ .../legacy_pm_parameter.dtg.cc | 313 +++++++ .../test/substitution-generator/json.cc | 2 +- .../include/substitutions/attribute_expr.h | 40 - .../substitutions/constraint_type.dtg.h | 40 + .../substitutions/constraint_type.enum.toml | 11 + .../include/substitutions/graph_pattern.h | 32 +- .../substitutions/graph_pattern_match.h | 42 - .../include/substitutions/operator_pattern.h | 107 --- .../operator_pattern/eval_list_access.h | 17 + .../operator_pattern/eval_list_size.h | 16 + .../{ => operator_pattern}/get_attribute.h | 16 +- .../operator_attribute_constraint.dtg.h | 62 ++ .../operator_attribute_constraint.struct.toml | 28 + .../operator_attribute_expr.dtg.h | 143 +++ .../operator_attribute_expr.h | 16 + .../operator_attribute_expr.variant.toml | 27 + .../operator_attribute_key.dtg.h | 97 ++ .../operator_attribute_key.enum.toml | 67 ++ .../operator_attribute_list_access.dtg.h | 67 ++ ...operator_attribute_list_access.struct.toml | 22 + .../operator_attribute_list_size.dtg.h | 64 ++ .../operator_attribute_list_size.struct.toml | 19 + .../operator_attribute_pattern.dtg.h | 56 ++ .../operator_attribute_pattern.struct.toml | 20 + .../operator_attribute_value.dtg.h | 264 ++++++ .../operator_attribute_value.variant.toml | 63 ++ .../operator_pattern/satisfies_constraint.h | 15 + .../operator_pattern/satisfies_pattern.h | 14 + .../include/substitutions/output_graph.h | 35 - .../output_graph/attr_constant.dtg.h | 46 + .../output_graph/attr_constant.struct.toml | 16 + .../output_graph/output_graph_expr.dtg.h | 28 + .../output_graph_expr.struct.toml | 12 + .../output_operator_attr_access.dtg.h | 49 + .../output_operator_attr_access.struct.toml | 23 + .../output_operator_attribute_expr.dtg.h | 119 +++ ...utput_operator_attribute_expr.variant.toml | 21 + .../output_operator_attrs_assignment.dtg.h | 49 + ...tput_operator_attrs_assignment.struct.toml | 21 + .../substitutions/parallel_tensor_pattern.h | 25 - .../include/substitutions/pcg_pattern.dtg.h | 31 + .../substitutions/pcg_pattern.struct.toml | 12 + .../sub_parallel_computation_graph.dtg.h | 31 + .../sub_parallel_computation_graph.h | 21 +- ...sub_parallel_computation_graph.struct.toml | 13 + .../include/substitutions/substitution.dtg.h | 38 + .../include/substitutions/substitution.h | 26 +- .../substitutions/substitution.struct.toml | 24 + .../tensor_pattern/eval_list_access.h | 15 + .../tensor_pattern/eval_list_size.h | 15 + .../tensor_pattern/get_attribute.h | 15 + .../tensor_pattern/satisfies_constraint.h | 15 + .../tensor_pattern/satisfies_pattern.h | 14 + .../tensor_attribute_constraint.dtg.h | 62 ++ .../tensor_attribute_constraint.struct.toml | 28 + .../tensor_attribute_expr.dtg.h | 141 +++ .../tensor_pattern/tensor_attribute_expr.h | 15 + .../tensor_attribute_expr.variant.toml | 27 + .../tensor_pattern/tensor_attribute_key.dtg.h | 40 + 
.../tensor_attribute_key.enum.toml | 14 + .../tensor_attribute_list_access.dtg.h | 66 ++ .../tensor_attribute_list_access.struct.toml | 22 + .../tensor_attribute_list_size.dtg.h | 63 ++ .../tensor_attribute_list_size.struct.toml | 19 + .../tensor_attribute_pattern.dtg.h | 56 ++ .../tensor_attribute_pattern.struct.toml | 20 + .../tensor_attribute_value.dtg.h | 118 +++ .../tensor_attribute_value.variant.toml | 21 + .../unlabelled/closed_pattern_edge.dtg.h | 39 + .../closed_pattern_edge.struct.toml | 15 + .../downward_open_pattern_edge.dtg.h | 39 + .../unlabelled/downward_open_pattern_edge.h | 12 + .../downward_open_pattern_edge.struct.toml | 15 + .../unlabelled/edge_splits.dtg.h | 36 + .../substitutions/unlabelled/edge_splits.h | 21 + .../unlabelled/edge_splits.struct.toml | 15 + .../unlabelled/find_pattern_matches.h | 18 + .../unlabelled/input_pattern_edge.dtg.h | 39 + .../unlabelled/input_pattern_edge.h | 13 + .../unlabelled/input_pattern_edge.struct.toml | 15 + .../match_additional_criterion.dtg.h | 36 + .../match_additional_criterion.struct.toml | 18 + .../unlabelled/match_split.dtg.h | 29 + .../substitutions/unlabelled/match_split.h | 18 + .../unlabelled/match_split.struct.toml | 18 + .../multidigraph_pattern_match.dtg.h | 36 + .../unlabelled/multidigraph_pattern_match.h | 17 + .../multidigraph_pattern_match.struct.toml | 24 + .../unlabelled/output_pattern_edge.dtg.h | 39 + .../unlabelled/output_pattern_edge.h | 13 + .../output_pattern_edge.struct.toml | 15 + .../unlabelled/pattern_edge.dtg.h | 39 + .../substitutions/unlabelled/pattern_edge.h | 27 + .../unlabelled/pattern_edge.struct.toml | 15 + .../unlabelled/pattern_matching.h | 25 + .../unlabelled/pattern_node.dtg.h | 39 + .../unlabelled/pattern_node.struct.toml | 15 + .../unlabelled/pattern_split.dtg.h | 47 + .../substitutions/unlabelled/pattern_split.h | 23 + .../unlabelled/pattern_split.struct.toml | 22 + .../unlabelled/unlabelled_graph_pattern.dtg.h | 24 + .../unlabelled/unlabelled_graph_pattern.h | 29 + .../unlabelled_graph_pattern.struct.toml | 10 + .../unlabelled/upward_open_pattern_edge.dtg.h | 39 + .../unlabelled/upward_open_pattern_edge.h | 12 + .../upward_open_pattern_edge.struct.toml | 15 + lib/substitutions/src/graph_pattern.cc | 257 ------ lib/substitutions/src/graph_pattern_match.cc | 305 ------- .../src/sub_parallel_computation_graph.cc | 3 - lib/substitutions/src/substitution.cc | 851 ++++++++---------- .../src/substitutions/constraint_type.dtg.cc | 64 ++ .../src/substitutions/graph_pattern.cc | 42 + .../operator_pattern/eval_list_access.cc | 41 + .../operator_pattern/eval_list_size.cc | 31 + .../operator_pattern/get_attribute.cc} | 128 ++- .../operator_attribute_constraint.dtg.cc | 121 +++ .../operator_attribute_expr.cc | 23 + .../operator_attribute_expr.dtg.cc | 137 +++ .../operator_attribute_key.dtg.cc | 505 +++++++++++ .../operator_attribute_list_access.dtg.cc | 101 +++ .../operator_attribute_list_size.dtg.cc | 88 ++ .../operator_attribute_pattern.dtg.cc | 73 ++ .../operator_attribute_value.dtg.cc | 292 ++++++ .../operator_pattern/satisfies_constraint.cc | 26 + .../operator_pattern/satisfies_pattern.cc | 14 + .../output_graph/attr_constant.dtg.cc | 59 ++ .../output_graph/output_graph_expr.dtg.cc | 20 + .../output_operator_attr_access.dtg.cc | 77 ++ .../output_operator_attribute_expr.dtg.cc | 79 ++ .../output_operator_attrs_assignment.dtg.cc | 58 ++ .../src/substitutions/pcg_pattern.dtg.cc | 21 + .../sub_parallel_computation_graph.cc | 22 + .../sub_parallel_computation_graph.dtg.cc | 22 + 
.../src/substitutions/substitution.cc | 154 ++++ .../src/substitutions/substitution.dtg.cc | 28 + .../tensor_pattern/eval_list_access.cc | 24 + .../tensor_pattern/eval_list_size.cc | 21 + .../tensor_pattern/get_attribute.cc | 28 + .../tensor_pattern/satisfies_constraint.cc | 22 + .../tensor_pattern/satisfies_pattern.cc | 13 + .../tensor_attribute_constraint.dtg.cc | 119 +++ .../tensor_pattern/tensor_attribute_expr.cc | 22 + .../tensor_attribute_expr.dtg.cc | 129 +++ .../tensor_attribute_key.dtg.cc | 73 ++ .../tensor_attribute_list_access.dtg.cc | 99 ++ .../tensor_attribute_list_size.dtg.cc | 87 ++ .../tensor_attribute_pattern.dtg.cc | 71 ++ .../tensor_attribute_value.dtg.cc | 105 +++ .../unlabelled/closed_pattern_edge.dtg.cc | 45 + .../unlabelled/downward_open_pattern_edge.cc | 9 + .../downward_open_pattern_edge.dtg.cc | 52 ++ .../substitutions/unlabelled/edge_splits.cc | 35 + .../unlabelled/edge_splits.dtg.cc | 31 + .../unlabelled/find_pattern_matches.cc | 161 ++++ .../unlabelled/input_pattern_edge.cc | 9 + .../unlabelled/input_pattern_edge.dtg.cc | 45 + .../match_additional_criterion.dtg.cc | 25 + .../substitutions/unlabelled/match_split.cc | 69 ++ .../unlabelled/match_split.dtg.cc | 26 + .../unlabelled/multidigraph_pattern_match.cc | 56 ++ .../multidigraph_pattern_match.dtg.cc | 34 + .../unlabelled/output_pattern_edge.cc | 9 + .../unlabelled/output_pattern_edge.dtg.cc | 46 + .../substitutions/unlabelled/pattern_edge.cc | 50 + .../unlabelled/pattern_edge.dtg.cc | 45 + .../unlabelled/pattern_matching.cc | 74 ++ .../unlabelled/pattern_node.dtg.cc | 45 + .../substitutions/unlabelled/pattern_split.cc | 42 + .../unlabelled/pattern_split.dtg.cc | 60 ++ .../unlabelled/unlabelled_graph_pattern.cc | 52 ++ .../unlabelled_graph_pattern.dtg.cc | 18 + .../unlabelled/upward_open_pattern_edge.cc | 9 + .../upward_open_pattern_edge.dtg.cc | 52 ++ .../test/src/test_substitution.cc | 20 +- lib/utils/include/utils/bidict.h | 30 + lib/utils/include/utils/check_fmtable.h | 15 + lib/utils/include/utils/containers.decl.h | 24 +- lib/utils/include/utils/containers.h | 67 +- .../include/utils/containers/concat_vectors.h | 18 + .../utils/containers/enumerate_vector.h | 20 + .../include/utils/containers/extend_vector.h | 16 + .../utils/containers/vector_transform.h | 21 + .../include/utils/containers/zip_vectors.h | 21 + lib/utils/include/utils/exception.decl.h | 16 +- lib/utils/include/utils/exception.h | 10 + lib/utils/include/utils/fmt.decl.h | 36 +- lib/utils/include/utils/fmt.h | 85 +- lib/utils/include/utils/fmt/expected.h | 32 + lib/utils/include/utils/fmt/pair.h | 20 + lib/utils/include/utils/fmt/unordered_map.h | 45 + lib/utils/include/utils/graph/README.md | 59 +- lib/utils/include/utils/graph/docs/edges.svg | 1 + .../utils/graph/docs/generate_diagram.py | 135 +++ .../include/utils/graph/docs/labelled.svg | 1 + lib/utils/include/utils/graph/docs/open.svg | 1 + .../include/utils/graph/docs/undirected.svg | 1 + lib/utils/include/utils/graph/multidiedge.h | 4 + lib/utils/include/utils/integer_conversions.h | 13 + lib/utils/include/utils/join_strings.h | 43 + lib/utils/include/utils/json.h | 49 +- lib/utils/include/utils/optional.h | 35 +- lib/utils/include/utils/overload.h | 15 + lib/utils/include/utils/stack_string.h | 14 + lib/utils/include/utils/stack_vector.h | 39 +- lib/utils/include/utils/type_index.h | 5 +- lib/utils/include/utils/variant.h | 12 + lib/utils/src/exception.cc | 9 +- lib/utils/src/utils/graph/multidiedge.cc | 17 + lib/utils/src/utils/integer_conversions.cc | 17 + 
lib/utils/src/utils/overload.cc | 1 + .../test/common/include/test/utils/doctest.h | 17 +- lib/utils/test/common/src/main.cc | 2 + lib/utils/test/src/test_optional.cc | 10 + lib/utils/test/src/test_stack_vector.cc | 8 + lib/utils/test/src/test_variant.cc | 7 + 845 files changed, 38294 insertions(+), 8609 deletions(-) create mode 100644 .editorconfig create mode 100644 .flake/pkgs/hpp2plantuml.nix create mode 100644 .gitattributes create mode 100644 codecov.yml create mode 100644 lib/kernels/include/kernels/legion_dim_t.dtg.h create mode 100644 lib/kernels/include/kernels/legion_dim_t.struct.toml create mode 100644 lib/kernels/src/kernels/legion_dim_t.dtg.cc rename lib/{runtime/src/task_spec => local-execution/include/local-execution}/arg_ref.h (50%) create mode 100644 lib/local-execution/include/local-execution/concrete_arg.h rename lib/{runtime/include/runtime => local-execution/include/local-execution}/config.h (89%) rename lib/{runtime/src => local-execution/include/local-execution}/cost_metrics.h (95%) rename lib/{runtime/src/task_spec => local-execution/include/local-execution}/device_specific.h (64%) rename lib/{runtime/src => local-execution/include/local-execution}/legion_tensor_shape.h (92%) rename lib/local-execution/include/{ => local-execution}/local_allocator.h (82%) rename lib/{runtime/src/task_spec => local-execution/include/local-execution}/op_arg_ref.h (54%) create mode 100644 lib/local-execution/include/local-execution/op_task_invocation.h rename lib/{runtime/src/task_spec => local-execution/include/local-execution}/op_task_signature.h (73%) create mode 100644 lib/local-execution/include/local-execution/op_tensor_spec.h rename lib/{runtime/src => local-execution/include/local-execution}/permissions.h (84%) rename lib/{runtime/include/runtime => local-execution/include/local-execution}/profiling.h (64%) rename lib/{runtime/src/task_spec => local-execution/include/local-execution}/runtime_arg_ref.h (73%) rename lib/{runtime/src => local-execution/include/local-execution}/serialization.h (55%) rename lib/{runtime/src => local-execution/include/local-execution}/sim_environment.h (94%) rename lib/{runtime/include/runtime/task_spec => local-execution/include/local-execution}/slot_id.h (73%) rename lib/{runtime/src/task_spec => local-execution/include/local-execution}/slot_type.h (86%) create mode 100644 lib/local-execution/include/local-execution/task_argument_accessor.h rename lib/{runtime/src => local-execution/include/local-execution}/tasks.h (95%) rename lib/local-execution/include/{ => local-execution}/tracked_allocator.h (94%) create mode 100644 lib/local-execution/include/local-execution/variadic_tensor_ref.h create mode 100644 lib/local-execution/src/op_arg_ref.cc create mode 100644 lib/local-execution/src/op_task_invocation.cc create mode 100644 lib/local-execution/src/op_task_signature.cc rename lib/{runtime => local-execution}/src/ops/attention.cc (84%) rename lib/{runtime => local-execution}/src/ops/attention.h (91%) rename lib/{runtime => local-execution}/src/ops/batch_matmul.cc (87%) rename lib/{runtime => local-execution}/src/ops/batch_matmul.h (84%) rename lib/{runtime => local-execution}/src/ops/batch_norm.cc (81%) rename lib/{runtime => local-execution}/src/ops/batch_norm.h (89%) rename lib/{runtime => local-execution}/src/ops/cast.cc (79%) rename lib/{runtime => local-execution}/src/ops/cast.h (93%) rename lib/{runtime => local-execution}/src/ops/combine.cc (75%) rename lib/{runtime => local-execution}/src/ops/combine.h (87%) rename lib/{runtime => 
local-execution}/src/ops/concat.cc (78%) rename lib/{runtime => local-execution}/src/ops/concat.h (89%) rename lib/{runtime => local-execution}/src/ops/conv_2d.cc (78%) rename lib/{runtime => local-execution}/src/ops/conv_2d.h (89%) rename lib/{runtime => local-execution}/src/ops/dropout.cc (77%) rename lib/{runtime => local-execution}/src/ops/dropout.h (85%) rename lib/{runtime => local-execution}/src/ops/element_binary.cc (81%) rename lib/{runtime => local-execution}/src/ops/element_binary.h (95%) rename lib/{runtime => local-execution}/src/ops/element_unary.cc (76%) rename lib/{runtime => local-execution}/src/ops/element_unary.h (85%) rename lib/{runtime => local-execution}/src/ops/embedding.cc (81%) rename lib/{runtime => local-execution}/src/ops/embedding.h (88%) rename lib/{runtime => local-execution}/src/ops/flat.cc (76%) rename lib/{runtime => local-execution}/src/ops/flat.h (93%) create mode 100644 lib/local-execution/src/ops/gather.cc create mode 100644 lib/local-execution/src/ops/gather.h rename lib/{runtime => local-execution}/src/ops/layer_norm.cc (61%) rename lib/{runtime => local-execution}/src/ops/layer_norm.h (97%) rename lib/{runtime => local-execution}/src/ops/linear.cc (54%) rename lib/{runtime => local-execution}/src/ops/linear.h (98%) rename lib/{runtime => local-execution}/src/ops/noop.cc (95%) rename lib/{runtime => local-execution}/src/ops/noop.h (87%) rename lib/{runtime => local-execution}/src/ops/parallel_op.h (96%) rename lib/{runtime => local-execution}/src/ops/partition.cc (61%) rename lib/{runtime => local-execution}/src/ops/pool_2d.cc (57%) rename lib/{runtime => local-execution}/src/ops/pool_2d.h (95%) rename lib/{runtime => local-execution}/src/ops/reduce.cc (57%) rename lib/{runtime => local-execution}/src/ops/reduce.h (96%) rename lib/{runtime => local-execution}/src/ops/reduction.cc (59%) rename lib/{runtime => local-execution}/src/ops/reduction.h (94%) rename lib/{runtime => local-execution}/src/ops/repartition.h (97%) rename lib/{runtime => local-execution}/src/ops/replicate.cc (66%) rename lib/{runtime => local-execution}/src/ops/replicate.h (89%) rename lib/{runtime => local-execution}/src/ops/reshape.cc (68%) rename lib/{runtime => local-execution}/src/ops/reshape.h (97%) rename lib/{runtime => local-execution}/src/ops/reverse.cc (67%) rename lib/{runtime => local-execution}/src/ops/reverse.h (89%) rename lib/{runtime => local-execution}/src/ops/softmax.cc (68%) rename lib/{runtime => local-execution}/src/ops/softmax.h (97%) rename lib/{runtime => local-execution}/src/ops/split.cc (58%) rename lib/{runtime => local-execution}/src/ops/split.h (93%) rename lib/{runtime => local-execution}/src/ops/topk.cc (60%) rename lib/{runtime => local-execution}/src/ops/topk.h (97%) rename lib/{runtime => local-execution}/src/ops/transpose.cc (55%) rename lib/{runtime => local-execution}/src/ops/transpose.h (97%) rename lib/{runtime => local-execution}/src/permissions.cc (66%) rename lib/{runtime/src/task_spec => local-execution/src}/runtime_arg_ref.cc (55%) create mode 100644 lib/local-execution/src/variadic_tensor_ref.cc create mode 100644 lib/op-attrs/include/op-attrs/activation.dtg.h create mode 100644 lib/op-attrs/include/op-attrs/activation.enum.toml delete mode 100644 lib/op-attrs/include/op-attrs/activation.h create mode 100644 lib/op-attrs/include/op-attrs/aggregate_op.dtg.h create mode 100644 lib/op-attrs/include/op-attrs/aggregate_op.enum.toml create mode 100644 lib/op-attrs/include/op-attrs/as_dot.h create mode 100644 
lib/op-attrs/include/op-attrs/computation_graph_op_attrs.dtg.h create mode 100644 lib/op-attrs/include/op-attrs/computation_graph_op_attrs.h create mode 100644 lib/op-attrs/include/op-attrs/computation_graph_op_attrs.variant.toml create mode 100644 lib/op-attrs/include/op-attrs/datatype.dtg.h create mode 100644 lib/op-attrs/include/op-attrs/datatype.enum.toml create mode 100644 lib/op-attrs/include/op-attrs/dim_ordered/slice.h create mode 100644 lib/op-attrs/include/op-attrs/dim_ordered/transform.h create mode 100644 lib/op-attrs/include/op-attrs/ff_dim.dtg.h create mode 100644 lib/op-attrs/include/op-attrs/ff_dim.struct.toml create mode 100644 lib/op-attrs/include/op-attrs/l1_regularizer_attrs.dtg.h create mode 100644 lib/op-attrs/include/op-attrs/l1_regularizer_attrs.struct.toml create mode 100644 lib/op-attrs/include/op-attrs/l2_regularizer_attrs.dtg.h create mode 100644 lib/op-attrs/include/op-attrs/l2_regularizer_attrs.struct.toml delete mode 100644 lib/op-attrs/include/op-attrs/op.h create mode 100644 lib/op-attrs/include/op-attrs/operator_type.dtg.h create mode 100644 lib/op-attrs/include/op-attrs/operator_type.enum.toml create mode 100644 lib/op-attrs/include/op-attrs/operator_type.h create mode 100644 lib/op-attrs/include/op-attrs/ops/attention/multihead_attention_inputs.dtg.h create mode 100644 lib/op-attrs/include/op-attrs/ops/attention/multihead_attention_inputs.h create mode 100644 lib/op-attrs/include/op-attrs/ops/attention/multihead_attention_inputs.struct.toml create mode 100644 lib/op-attrs/include/op-attrs/ops/attention/multihead_attention_parallel_inputs.dtg.h create mode 100644 lib/op-attrs/include/op-attrs/ops/attention/multihead_attention_parallel_inputs.h create mode 100644 lib/op-attrs/include/op-attrs/ops/attention/multihead_attention_parallel_inputs.struct.toml create mode 100644 lib/op-attrs/include/op-attrs/ops/attention_attrs.dtg.h create mode 100644 lib/op-attrs/include/op-attrs/ops/attention_attrs.struct.toml create mode 100644 lib/op-attrs/include/op-attrs/ops/batch_matmul.dtg.h create mode 100644 lib/op-attrs/include/op-attrs/ops/batch_matmul.struct.toml create mode 100644 lib/op-attrs/include/op-attrs/ops/batch_norm_attrs.dtg.h create mode 100644 lib/op-attrs/include/op-attrs/ops/batch_norm_attrs.struct.toml create mode 100644 lib/op-attrs/include/op-attrs/ops/broadcast.dtg.h create mode 100644 lib/op-attrs/include/op-attrs/ops/broadcast.struct.toml create mode 100644 lib/op-attrs/include/op-attrs/ops/cast_attrs.dtg.h create mode 100644 lib/op-attrs/include/op-attrs/ops/cast_attrs.struct.toml create mode 100644 lib/op-attrs/include/op-attrs/ops/combine_attrs.dtg.h create mode 100644 lib/op-attrs/include/op-attrs/ops/combine_attrs.struct.toml create mode 100644 lib/op-attrs/include/op-attrs/ops/concat_attrs.dtg.h create mode 100644 lib/op-attrs/include/op-attrs/ops/concat_attrs.struct.toml create mode 100644 lib/op-attrs/include/op-attrs/ops/conv_2d/conv_2d_input_shape.dtg.h create mode 100644 lib/op-attrs/include/op-attrs/ops/conv_2d/conv_2d_input_shape.h create mode 100644 lib/op-attrs/include/op-attrs/ops/conv_2d/conv_2d_input_shape.struct.toml create mode 100644 lib/op-attrs/include/op-attrs/ops/conv_2d/conv_2d_parallel_input_shape.dtg.h create mode 100644 lib/op-attrs/include/op-attrs/ops/conv_2d/conv_2d_parallel_input_shape.h create mode 100644 lib/op-attrs/include/op-attrs/ops/conv_2d/conv_2d_parallel_input_shape.struct.toml create mode 100644 lib/op-attrs/include/op-attrs/ops/conv_2d_attrs.dtg.h create mode 100644 
lib/op-attrs/include/op-attrs/ops/conv_2d_attrs.struct.toml create mode 100644 lib/op-attrs/include/op-attrs/ops/dropout_attrs.dtg.h create mode 100644 lib/op-attrs/include/op-attrs/ops/dropout_attrs.struct.toml create mode 100644 lib/op-attrs/include/op-attrs/ops/element_binary_attrs.dtg.h create mode 100644 lib/op-attrs/include/op-attrs/ops/element_binary_attrs.struct.toml create mode 100644 lib/op-attrs/include/op-attrs/ops/element_scalar_unary_attrs.dtg.h create mode 100644 lib/op-attrs/include/op-attrs/ops/element_scalar_unary_attrs.struct.toml create mode 100644 lib/op-attrs/include/op-attrs/ops/element_unary_attrs.dtg.h create mode 100644 lib/op-attrs/include/op-attrs/ops/element_unary_attrs.struct.toml create mode 100644 lib/op-attrs/include/op-attrs/ops/embedding_attrs.dtg.h create mode 100644 lib/op-attrs/include/op-attrs/ops/embedding_attrs.struct.toml create mode 100644 lib/op-attrs/include/op-attrs/ops/flat_attrs.dtg.h create mode 100644 lib/op-attrs/include/op-attrs/ops/flat_attrs.struct.toml create mode 100644 lib/op-attrs/include/op-attrs/ops/gather_attrs.dtg.h create mode 100644 lib/op-attrs/include/op-attrs/ops/gather_attrs.struct.toml create mode 100644 lib/op-attrs/include/op-attrs/ops/input_attrs.dtg.h create mode 100644 lib/op-attrs/include/op-attrs/ops/input_attrs.struct.toml create mode 100644 lib/op-attrs/include/op-attrs/ops/layer_norm_attrs.dtg.h create mode 100644 lib/op-attrs/include/op-attrs/ops/layer_norm_attrs.struct.toml create mode 100644 lib/op-attrs/include/op-attrs/ops/linear_attrs.dtg.h create mode 100644 lib/op-attrs/include/op-attrs/ops/linear_attrs.struct.toml create mode 100644 lib/op-attrs/include/op-attrs/ops/noop_attrs.dtg.h create mode 100644 lib/op-attrs/include/op-attrs/ops/noop_attrs.struct.toml create mode 100644 lib/op-attrs/include/op-attrs/ops/parallel_attention_inputs.dtg.h create mode 100644 lib/op-attrs/include/op-attrs/ops/parallel_attention_inputs.struct.toml create mode 100644 lib/op-attrs/include/op-attrs/ops/pool_2d_attrs.dtg.h create mode 100644 lib/op-attrs/include/op-attrs/ops/pool_2d_attrs.struct.toml create mode 100644 lib/op-attrs/include/op-attrs/ops/reduce_attrs.dtg.h create mode 100644 lib/op-attrs/include/op-attrs/ops/reduce_attrs.struct.toml create mode 100644 lib/op-attrs/include/op-attrs/ops/reduction_attrs.dtg.h create mode 100644 lib/op-attrs/include/op-attrs/ops/reduction_attrs.struct.toml create mode 100644 lib/op-attrs/include/op-attrs/ops/repartition_attrs.dtg.h create mode 100644 lib/op-attrs/include/op-attrs/ops/repartition_attrs.struct.toml create mode 100644 lib/op-attrs/include/op-attrs/ops/replicate_attrs.dtg.h create mode 100644 lib/op-attrs/include/op-attrs/ops/replicate_attrs.struct.toml create mode 100644 lib/op-attrs/include/op-attrs/ops/reshape_attrs.dtg.h create mode 100644 lib/op-attrs/include/op-attrs/ops/reshape_attrs.struct.toml create mode 100644 lib/op-attrs/include/op-attrs/ops/reverse_attrs.dtg.h create mode 100644 lib/op-attrs/include/op-attrs/ops/reverse_attrs.struct.toml create mode 100644 lib/op-attrs/include/op-attrs/ops/softmax_attrs.dtg.h create mode 100644 lib/op-attrs/include/op-attrs/ops/softmax_attrs.struct.toml create mode 100644 lib/op-attrs/include/op-attrs/ops/split_attrs.dtg.h create mode 100644 lib/op-attrs/include/op-attrs/ops/split_attrs.struct.toml create mode 100644 lib/op-attrs/include/op-attrs/ops/topk_attrs.dtg.h create mode 100644 lib/op-attrs/include/op-attrs/ops/topk_attrs.struct.toml create mode 100644 lib/op-attrs/include/op-attrs/ops/transpose_attrs.dtg.h 
create mode 100644 lib/op-attrs/include/op-attrs/ops/transpose_attrs.struct.toml create mode 100644 lib/op-attrs/include/op-attrs/ops/weight_attrs.dtg.h create mode 100644 lib/op-attrs/include/op-attrs/ops/weight_attrs.struct.toml create mode 100644 lib/op-attrs/include/op-attrs/parallel_dim.dtg.h create mode 100644 lib/op-attrs/include/op-attrs/parallel_dim.variant.toml create mode 100644 lib/op-attrs/include/op-attrs/parallel_tensor_dims.dtg.h create mode 100644 lib/op-attrs/include/op-attrs/parallel_tensor_dims.struct.toml create mode 100644 lib/op-attrs/include/op-attrs/parallel_tensor_shape.dtg.h create mode 100644 lib/op-attrs/include/op-attrs/parallel_tensor_shape.struct.toml create mode 100644 lib/op-attrs/include/op-attrs/parallel_tensor_shape/discard_copy_degree.dtg.h create mode 100644 lib/op-attrs/include/op-attrs/parallel_tensor_shape/discard_copy_degree.struct.toml create mode 100644 lib/op-attrs/include/op-attrs/parallel_tensor_shape/sum_degree.dtg.h create mode 100644 lib/op-attrs/include/op-attrs/parallel_tensor_shape/sum_degree.struct.toml create mode 100644 lib/op-attrs/include/op-attrs/param_sync.dtg.h create mode 100644 lib/op-attrs/include/op-attrs/param_sync.enum.toml create mode 100644 lib/op-attrs/include/op-attrs/pcg_operator_attrs.dtg.h create mode 100644 lib/op-attrs/include/op-attrs/pcg_operator_attrs.h create mode 100644 lib/op-attrs/include/op-attrs/pcg_operator_attrs.variant.toml create mode 100644 lib/op-attrs/include/op-attrs/pool_op.dtg.h create mode 100644 lib/op-attrs/include/op-attrs/pool_op.enum.toml create mode 100644 lib/op-attrs/include/op-attrs/regularizer_attrs.dtg.h create mode 100644 lib/op-attrs/include/op-attrs/regularizer_attrs.variant.toml create mode 100644 lib/op-attrs/include/op-attrs/replica_parallel_dim.dtg.h create mode 100644 lib/op-attrs/include/op-attrs/replica_parallel_dim.h create mode 100644 lib/op-attrs/include/op-attrs/replica_parallel_dim.struct.toml create mode 100644 lib/op-attrs/include/op-attrs/replica_parallel_dim_set.dtg.h create mode 100644 lib/op-attrs/include/op-attrs/replica_parallel_dim_set.h create mode 100644 lib/op-attrs/include/op-attrs/replica_parallel_dim_set.struct.toml create mode 100644 lib/op-attrs/include/op-attrs/replica_type.dtg.h create mode 100644 lib/op-attrs/include/op-attrs/replica_type.enum.toml create mode 100644 lib/op-attrs/include/op-attrs/shard_parallel_dim.dtg.h create mode 100644 lib/op-attrs/include/op-attrs/shard_parallel_dim.h create mode 100644 lib/op-attrs/include/op-attrs/shard_parallel_dim.struct.toml create mode 100644 lib/op-attrs/include/op-attrs/tensor_dims.dtg.h create mode 100644 lib/op-attrs/include/op-attrs/tensor_dims.h create mode 100644 lib/op-attrs/include/op-attrs/tensor_dims.struct.toml create mode 100644 lib/op-attrs/include/op-attrs/tensor_shape.dtg.h create mode 100644 lib/op-attrs/include/op-attrs/tensor_shape.struct.toml delete mode 100644 lib/op-attrs/src/batch_matmul.cc delete mode 100644 lib/op-attrs/src/batch_norm.cc delete mode 100644 lib/op-attrs/src/broadcast.cc delete mode 100644 lib/op-attrs/src/combine.cc delete mode 100644 lib/op-attrs/src/conv_2d.cc delete mode 100644 lib/op-attrs/src/element_binary.cc delete mode 100644 lib/op-attrs/src/element_unary.cc delete mode 100644 lib/op-attrs/src/embedding.cc delete mode 100644 lib/op-attrs/src/linear.cc delete mode 100644 lib/op-attrs/src/noop.cc create mode 100644 lib/op-attrs/src/op-attrs/activation.dtg.cc create mode 100644 lib/op-attrs/src/op-attrs/aggregate_op.dtg.cc create mode 100644 
lib/op-attrs/src/op-attrs/as_dot.cc create mode 100644 lib/op-attrs/src/op-attrs/computation_graph_op_attrs.cc create mode 100644 lib/op-attrs/src/op-attrs/computation_graph_op_attrs.dtg.cc create mode 100644 lib/op-attrs/src/op-attrs/datatype.dtg.cc create mode 100644 lib/op-attrs/src/op-attrs/ff_dim.dtg.cc rename lib/op-attrs/src/{ => op-attrs}/get_op_type.cc (60%) create mode 100644 lib/op-attrs/src/op-attrs/l1_regularizer_attrs.dtg.cc create mode 100644 lib/op-attrs/src/op-attrs/l2_regularizer_attrs.dtg.cc create mode 100644 lib/op-attrs/src/op-attrs/operator_type.cc create mode 100644 lib/op-attrs/src/op-attrs/operator_type.dtg.cc rename lib/op-attrs/src/{ => op-attrs/ops}/attention.cc (77%) create mode 100644 lib/op-attrs/src/op-attrs/ops/attention/multihead_attention_inputs.cc create mode 100644 lib/op-attrs/src/op-attrs/ops/attention/multihead_attention_inputs.dtg.cc create mode 100644 lib/op-attrs/src/op-attrs/ops/attention/multihead_attention_parallel_inputs.cc create mode 100644 lib/op-attrs/src/op-attrs/ops/attention/multihead_attention_parallel_inputs.dtg.cc create mode 100644 lib/op-attrs/src/op-attrs/ops/attention_attrs.dtg.cc create mode 100644 lib/op-attrs/src/op-attrs/ops/batch_matmul.cc create mode 100644 lib/op-attrs/src/op-attrs/ops/batch_matmul.dtg.cc create mode 100644 lib/op-attrs/src/op-attrs/ops/batch_norm.cc create mode 100644 lib/op-attrs/src/op-attrs/ops/batch_norm_attrs.dtg.cc create mode 100644 lib/op-attrs/src/op-attrs/ops/broadcast.dtg.cc rename lib/op-attrs/src/{ => op-attrs/ops}/cast.cc (100%) create mode 100644 lib/op-attrs/src/op-attrs/ops/cast_attrs.dtg.cc create mode 100644 lib/op-attrs/src/op-attrs/ops/combine.cc create mode 100644 lib/op-attrs/src/op-attrs/ops/combine_attrs.dtg.cc rename lib/op-attrs/src/{ => op-attrs/ops}/concat.cc (100%) create mode 100644 lib/op-attrs/src/op-attrs/ops/concat_attrs.dtg.cc create mode 100644 lib/op-attrs/src/op-attrs/ops/conv_2d.cc create mode 100644 lib/op-attrs/src/op-attrs/ops/conv_2d/conv_2d_input_shape.cc create mode 100644 lib/op-attrs/src/op-attrs/ops/conv_2d/conv_2d_input_shape.dtg.cc create mode 100644 lib/op-attrs/src/op-attrs/ops/conv_2d/conv_2d_parallel_input_shape.cc create mode 100644 lib/op-attrs/src/op-attrs/ops/conv_2d/conv_2d_parallel_input_shape.dtg.cc create mode 100644 lib/op-attrs/src/op-attrs/ops/conv_2d_attrs.dtg.cc create mode 100644 lib/op-attrs/src/op-attrs/ops/dropout.cc create mode 100644 lib/op-attrs/src/op-attrs/ops/dropout_attrs.dtg.cc create mode 100644 lib/op-attrs/src/op-attrs/ops/element_binary.cc create mode 100644 lib/op-attrs/src/op-attrs/ops/element_binary_attrs.dtg.cc create mode 100644 lib/op-attrs/src/op-attrs/ops/element_scalar_unary_attrs.dtg.cc create mode 100644 lib/op-attrs/src/op-attrs/ops/element_unary.cc create mode 100644 lib/op-attrs/src/op-attrs/ops/element_unary_attrs.dtg.cc create mode 100644 lib/op-attrs/src/op-attrs/ops/embedding.cc create mode 100644 lib/op-attrs/src/op-attrs/ops/embedding_attrs.dtg.cc rename lib/op-attrs/src/{ => op-attrs/ops}/flat.cc (94%) create mode 100644 lib/op-attrs/src/op-attrs/ops/flat_attrs.dtg.cc rename lib/op-attrs/src/{ => op-attrs/ops}/gather.cc (100%) create mode 100644 lib/op-attrs/src/op-attrs/ops/gather_attrs.dtg.cc create mode 100644 lib/op-attrs/src/op-attrs/ops/input.cc create mode 100644 lib/op-attrs/src/op-attrs/ops/input_attrs.dtg.cc create mode 100644 lib/op-attrs/src/op-attrs/ops/layer_norm.cc create mode 100644 lib/op-attrs/src/op-attrs/ops/layer_norm_attrs.dtg.cc create mode 100644 
lib/op-attrs/src/op-attrs/ops/linear.cc create mode 100644 lib/op-attrs/src/op-attrs/ops/linear_attrs.dtg.cc create mode 100644 lib/op-attrs/src/op-attrs/ops/noop.cc create mode 100644 lib/op-attrs/src/op-attrs/ops/noop_attrs.dtg.cc create mode 100644 lib/op-attrs/src/op-attrs/ops/parallel_attention_inputs.dtg.cc rename lib/op-attrs/src/{ => op-attrs/ops}/pool_2d.cc (65%) create mode 100644 lib/op-attrs/src/op-attrs/ops/pool_2d_attrs.dtg.cc create mode 100644 lib/op-attrs/src/op-attrs/ops/reduce.cc create mode 100644 lib/op-attrs/src/op-attrs/ops/reduce_attrs.dtg.cc create mode 100644 lib/op-attrs/src/op-attrs/ops/reduction.cc create mode 100644 lib/op-attrs/src/op-attrs/ops/reduction_attrs.dtg.cc create mode 100644 lib/op-attrs/src/op-attrs/ops/repartition.cc create mode 100644 lib/op-attrs/src/op-attrs/ops/repartition_attrs.dtg.cc create mode 100644 lib/op-attrs/src/op-attrs/ops/replicate.cc create mode 100644 lib/op-attrs/src/op-attrs/ops/replicate_attrs.dtg.cc create mode 100644 lib/op-attrs/src/op-attrs/ops/reshape.cc create mode 100644 lib/op-attrs/src/op-attrs/ops/reshape_attrs.dtg.cc create mode 100644 lib/op-attrs/src/op-attrs/ops/reverse.cc create mode 100644 lib/op-attrs/src/op-attrs/ops/reverse_attrs.dtg.cc create mode 100644 lib/op-attrs/src/op-attrs/ops/softmax.cc create mode 100644 lib/op-attrs/src/op-attrs/ops/softmax_attrs.dtg.cc create mode 100644 lib/op-attrs/src/op-attrs/ops/split.cc create mode 100644 lib/op-attrs/src/op-attrs/ops/split_attrs.dtg.cc create mode 100644 lib/op-attrs/src/op-attrs/ops/topk.cc create mode 100644 lib/op-attrs/src/op-attrs/ops/topk_attrs.dtg.cc create mode 100644 lib/op-attrs/src/op-attrs/ops/transpose.cc create mode 100644 lib/op-attrs/src/op-attrs/ops/transpose_attrs.dtg.cc create mode 100644 lib/op-attrs/src/op-attrs/ops/weight_attrs.dtg.cc create mode 100644 lib/op-attrs/src/op-attrs/parallel_dim.dtg.cc create mode 100644 lib/op-attrs/src/op-attrs/parallel_tensor_dims.cc create mode 100644 lib/op-attrs/src/op-attrs/parallel_tensor_dims.dtg.cc create mode 100644 lib/op-attrs/src/op-attrs/parallel_tensor_shape.cc create mode 100644 lib/op-attrs/src/op-attrs/parallel_tensor_shape.dtg.cc create mode 100644 lib/op-attrs/src/op-attrs/parallel_tensor_shape/discard_copy_degree.dtg.cc create mode 100644 lib/op-attrs/src/op-attrs/parallel_tensor_shape/sum_degree.dtg.cc create mode 100644 lib/op-attrs/src/op-attrs/param_sync.dtg.cc create mode 100644 lib/op-attrs/src/op-attrs/pcg_operator_attrs.cc create mode 100644 lib/op-attrs/src/op-attrs/pcg_operator_attrs.dtg.cc create mode 100644 lib/op-attrs/src/op-attrs/pool_op.dtg.cc create mode 100644 lib/op-attrs/src/op-attrs/regularizer_attrs.dtg.cc create mode 100644 lib/op-attrs/src/op-attrs/replica_parallel_dim.cc create mode 100644 lib/op-attrs/src/op-attrs/replica_parallel_dim.dtg.cc create mode 100644 lib/op-attrs/src/op-attrs/replica_parallel_dim_set.cc create mode 100644 lib/op-attrs/src/op-attrs/replica_parallel_dim_set.dtg.cc create mode 100644 lib/op-attrs/src/op-attrs/replica_type.dtg.cc create mode 100644 lib/op-attrs/src/op-attrs/shard_parallel_dim.cc create mode 100644 lib/op-attrs/src/op-attrs/shard_parallel_dim.dtg.cc create mode 100644 lib/op-attrs/src/op-attrs/tensor_dims.cc create mode 100644 lib/op-attrs/src/op-attrs/tensor_dims.dtg.cc create mode 100644 lib/op-attrs/src/op-attrs/tensor_shape.cc create mode 100644 lib/op-attrs/src/op-attrs/tensor_shape.dtg.cc delete mode 100644 lib/op-attrs/src/op.cc delete mode 100644 lib/op-attrs/src/parallel_dim.cc delete mode 100644 
lib/op-attrs/src/parallel_dim_mapping_record_solver.cc delete mode 100644 lib/op-attrs/src/parallel_dim_mapping_record_solver.h delete mode 100644 lib/op-attrs/src/parallel_tensor_shape.cc delete mode 100644 lib/op-attrs/src/reduce.cc delete mode 100644 lib/op-attrs/src/reduction.cc delete mode 100644 lib/op-attrs/src/repartition.cc delete mode 100644 lib/op-attrs/src/replicate.cc delete mode 100644 lib/op-attrs/src/reshape.cc delete mode 100644 lib/op-attrs/src/softmax.cc delete mode 100644 lib/op-attrs/src/split.cc delete mode 100644 lib/op-attrs/src/tensor_shape.cc delete mode 100644 lib/op-attrs/src/topk.cc delete mode 100644 lib/op-attrs/src/transpose.cc create mode 100644 lib/op-attrs/test/CMakeLists.txt create mode 100644 lib/op-attrs/test/src/dim_ordered/slice.cc create mode 100644 lib/op-attrs/test/src/ops/combine.cc create mode 100644 lib/op-attrs/test/src/ops/linear.cc create mode 100644 lib/op-attrs/test/src/ops/reduction.cc create mode 100644 lib/op-attrs/test/src/ops/repartition.cc create mode 100644 lib/op-attrs/test/src/ops/replicate.cc create mode 100644 lib/op-attrs/test/src/test_attention.cc create mode 100644 lib/op-attrs/test/src/test_batch_matmul.cc create mode 100644 lib/op-attrs/test/src/test_conv_2d.cc create mode 100644 lib/op-attrs/test/src/test_dim_ordered.cc create mode 100644 lib/op-attrs/test/src/test_element_binary.cc create mode 100644 lib/op-attrs/test/src/test_element_unary.cc create mode 100644 lib/op-attrs/test/src/test_embedding.cc create mode 100644 lib/op-attrs/test/src/test_operator_attrs.cc create mode 100644 lib/op-attrs/test/src/test_regularizer_attrs.cc create mode 100644 lib/pcg/include/pcg/computation_graph.dtg.h create mode 100644 lib/pcg/include/pcg/computation_graph.struct.toml create mode 100644 lib/pcg/include/pcg/computation_graph/layer_added_result.dtg.h create mode 100644 lib/pcg/include/pcg/computation_graph/layer_added_result.struct.toml create mode 100644 lib/pcg/include/pcg/cpu_id_t.dtg.h create mode 100644 lib/pcg/include/pcg/cpu_id_t.struct.toml create mode 100644 lib/pcg/include/pcg/create_grad.dtg.h create mode 100644 lib/pcg/include/pcg/create_grad.enum.toml create mode 100644 lib/pcg/include/pcg/dataflow_graph.h create mode 100644 lib/pcg/include/pcg/dataflow_input.dtg.h create mode 100644 lib/pcg/include/pcg/dataflow_input.variant.toml create mode 100644 lib/pcg/include/pcg/device_id_t.dtg.h create mode 100644 lib/pcg/include/pcg/device_id_t.variant.toml create mode 100644 lib/pcg/include/pcg/device_type.dtg.h create mode 100644 lib/pcg/include/pcg/device_type.enum.toml delete mode 100644 lib/pcg/include/pcg/device_type.h rename lib/pcg/include/pcg/file_format/v1/{data_type.h => data_type_value.h} (64%) create mode 100644 lib/pcg/include/pcg/file_format/v1/graphs/v1_graph_edge.dtg.h create mode 100644 lib/pcg/include/pcg/file_format/v1/graphs/v1_graph_edge.struct.toml create mode 100644 lib/pcg/include/pcg/file_format/v1/graphs/v1_graph_output.dtg.h create mode 100644 lib/pcg/include/pcg/file_format/v1/graphs/v1_graph_output.struct.toml create mode 100644 lib/pcg/include/pcg/file_format/v1/graphs/v1_jsonable_graph.dtg.h create mode 100644 lib/pcg/include/pcg/file_format/v1/graphs/v1_jsonable_graph.struct.toml create mode 100644 lib/pcg/include/pcg/file_format/v1/graphs/v1_multidigraph.dtg.h create mode 100644 lib/pcg/include/pcg/file_format/v1/graphs/v1_multidigraph.h create mode 100644 lib/pcg/include/pcg/file_format/v1/graphs/v1_multidigraph.struct.toml create mode 100644 
lib/pcg/include/pcg/file_format/v1/graphs/v1_operator_graph.dtg.h create mode 100644 lib/pcg/include/pcg/file_format/v1/graphs/v1_operator_graph.struct.toml delete mode 100644 lib/pcg/include/pcg/file_format/v1/initializer.h delete mode 100644 lib/pcg/include/pcg/file_format/v1/operator_attrs.h delete mode 100644 lib/pcg/include/pcg/file_format/v1/parallel_tensor.h delete mode 100644 lib/pcg/include/pcg/file_format/v1/param_sync.h delete mode 100644 lib/pcg/include/pcg/file_format/v1/tensor.h create mode 100644 lib/pcg/include/pcg/gpu_id_t.dtg.h create mode 100644 lib/pcg/include/pcg/gpu_id_t.struct.toml delete mode 100644 lib/pcg/include/pcg/initializer.h create mode 100644 lib/pcg/include/pcg/initializer_attrs.dtg.h create mode 100644 lib/pcg/include/pcg/initializer_attrs.variant.toml create mode 100644 lib/pcg/include/pcg/initializers/constant_initializer_attrs.dtg.h create mode 100644 lib/pcg/include/pcg/initializers/constant_initializer_attrs.struct.toml create mode 100644 lib/pcg/include/pcg/initializers/glorot_uniform_attrs.dtg.h create mode 100644 lib/pcg/include/pcg/initializers/glorot_uniform_attrs.struct.toml create mode 100644 lib/pcg/include/pcg/initializers/norm_initializer_attrs.dtg.h create mode 100644 lib/pcg/include/pcg/initializers/norm_initializer_attrs.struct.toml create mode 100644 lib/pcg/include/pcg/initializers/uniform_initializer_attrs.dtg.h create mode 100644 lib/pcg/include/pcg/initializers/uniform_initializer_attrs.struct.toml create mode 100644 lib/pcg/include/pcg/initializers/zero_initializer_attrs.dtg.h create mode 100644 lib/pcg/include/pcg/initializers/zero_initializer_attrs.struct.toml delete mode 100644 lib/pcg/include/pcg/layer.h create mode 100644 lib/pcg/include/pcg/layer_attrs.dtg.h create mode 100644 lib/pcg/include/pcg/layer_attrs.struct.toml create mode 100644 lib/pcg/include/pcg/layer_guid_t.dtg.h create mode 100644 lib/pcg/include/pcg/layer_guid_t.struct.toml create mode 100644 lib/pcg/include/pcg/machine_specification.dtg.h create mode 100644 lib/pcg/include/pcg/machine_specification.struct.toml create mode 100644 lib/pcg/include/pcg/machine_view.dtg.h create mode 100644 lib/pcg/include/pcg/machine_view.struct.toml create mode 100644 lib/pcg/include/pcg/num_points_t.dtg.h create mode 100644 lib/pcg/include/pcg/num_points_t.struct.toml create mode 100644 lib/pcg/include/pcg/open_dataflow_graph.h delete mode 100644 lib/pcg/include/pcg/operator.h create mode 100644 lib/pcg/include/pcg/operator_graph/operator_graph.h create mode 100644 lib/pcg/include/pcg/operator_graph/operator_graph_input.dtg.h create mode 100644 lib/pcg/include/pcg/operator_graph/operator_graph_input.h create mode 100644 lib/pcg/include/pcg/operator_graph/operator_graph_input.struct.toml create mode 100644 lib/pcg/include/pcg/operator_graph/operator_graph_output.dtg.h create mode 100644 lib/pcg/include/pcg/operator_graph/operator_graph_output.h create mode 100644 lib/pcg/include/pcg/operator_graph/operator_graph_output.struct.toml create mode 100644 lib/pcg/include/pcg/operator_guid_t.dtg.h delete mode 100644 lib/pcg/include/pcg/operator_guid_t.h create mode 100644 lib/pcg/include/pcg/operator_guid_t.struct.toml delete mode 100644 lib/pcg/include/pcg/optimizer.h create mode 100644 lib/pcg/include/pcg/optimizer_attrs.h create mode 100644 lib/pcg/include/pcg/optimizers/adam_optimizer_attrs.dtg.h create mode 100644 lib/pcg/include/pcg/optimizers/adam_optimizer_attrs.struct.toml create mode 100644 lib/pcg/include/pcg/optimizers/sgd_optimizer_attrs.dtg.h create mode 100644 
lib/pcg/include/pcg/optimizers/sgd_optimizer_attrs.struct.toml create mode 100644 lib/pcg/include/pcg/parallel_computation_graph.dtg.h create mode 100644 lib/pcg/include/pcg/parallel_computation_graph.struct.toml create mode 100644 lib/pcg/include/pcg/parallel_layer_attrs.dtg.h create mode 100644 lib/pcg/include/pcg/parallel_layer_attrs.struct.toml create mode 100644 lib/pcg/include/pcg/parallel_tensor_attrs.dtg.h create mode 100644 lib/pcg/include/pcg/parallel_tensor_attrs.struct.toml delete mode 100644 lib/pcg/include/pcg/serialization.h create mode 100644 lib/pcg/include/pcg/side_size_t.dtg.h create mode 100644 lib/pcg/include/pcg/side_size_t.struct.toml create mode 100644 lib/pcg/include/pcg/strided_rectangle.dtg.h create mode 100644 lib/pcg/include/pcg/strided_rectangle.struct.toml create mode 100644 lib/pcg/include/pcg/strided_rectangle_side.dtg.h create mode 100644 lib/pcg/include/pcg/strided_rectangle_side.h create mode 100644 lib/pcg/include/pcg/strided_rectangle_side.struct.toml delete mode 100644 lib/pcg/include/pcg/tensor.h create mode 100644 lib/pcg/include/pcg/tensor_attrs.dtg.h create mode 100644 lib/pcg/include/pcg/tensor_attrs.struct.toml create mode 100644 lib/pcg/include/pcg/tensor_guid_t.dtg.h delete mode 100644 lib/pcg/include/pcg/tensor_guid_t.h create mode 100644 lib/pcg/include/pcg/tensor_guid_t.struct.toml delete mode 100644 lib/pcg/src/computation_graph.cc delete mode 100644 lib/pcg/src/device_id.cc delete mode 100644 lib/pcg/src/file_format/v1/v1.cc delete mode 100644 lib/pcg/src/layer.cc delete mode 100644 lib/pcg/src/machine_view.cc delete mode 100644 lib/pcg/src/operator.cc delete mode 100644 lib/pcg/src/parallel_computation_graph.cc delete mode 100644 lib/pcg/src/parallel_tensor.cc create mode 100644 lib/pcg/src/pcg/computation_graph.cc create mode 100644 lib/pcg/src/pcg/computation_graph.dtg.cc create mode 100644 lib/pcg/src/pcg/computation_graph/layer_added_result.dtg.cc rename lib/pcg/src/{ => pcg}/computation_graph_builder.cc (52%) create mode 100644 lib/pcg/src/pcg/cpu_id_t.dtg.cc create mode 100644 lib/pcg/src/pcg/create_grad.dtg.cc create mode 100644 lib/pcg/src/pcg/dataflow_input.dtg.cc create mode 100644 lib/pcg/src/pcg/device_id.cc create mode 100644 lib/pcg/src/pcg/device_id_t.dtg.cc create mode 100644 lib/pcg/src/pcg/device_type.dtg.cc create mode 100644 lib/pcg/src/pcg/file_format/v1/graphs/v1_graph_edge.dtg.cc create mode 100644 lib/pcg/src/pcg/file_format/v1/graphs/v1_graph_output.dtg.cc create mode 100644 lib/pcg/src/pcg/file_format/v1/graphs/v1_jsonable_graph.dtg.cc create mode 100644 lib/pcg/src/pcg/file_format/v1/graphs/v1_multidigraph.dtg.cc create mode 100644 lib/pcg/src/pcg/file_format/v1/graphs/v1_operator_graph.dtg.cc create mode 100644 lib/pcg/src/pcg/gpu_id_t.dtg.cc create mode 100644 lib/pcg/src/pcg/initializer_attrs.dtg.cc create mode 100644 lib/pcg/src/pcg/initializers/constant_initializer_attrs.dtg.cc create mode 100644 lib/pcg/src/pcg/initializers/glorot_uniform_attrs.dtg.cc create mode 100644 lib/pcg/src/pcg/initializers/norm_initializer_attrs.dtg.cc create mode 100644 lib/pcg/src/pcg/initializers/uniform_initializer_attrs.dtg.cc create mode 100644 lib/pcg/src/pcg/initializers/zero_initializer_attrs.dtg.cc create mode 100644 lib/pcg/src/pcg/layer_attrs.dtg.cc create mode 100644 lib/pcg/src/pcg/layer_guid_t.dtg.cc create mode 100644 lib/pcg/src/pcg/machine_specification.dtg.cc create mode 100644 lib/pcg/src/pcg/machine_view.cc create mode 100644 lib/pcg/src/pcg/machine_view.dtg.cc create mode 100644 
lib/pcg/src/pcg/num_points_t.dtg.cc create mode 100644 lib/pcg/src/pcg/operator_graph/operator_graph.cc create mode 100644 lib/pcg/src/pcg/operator_graph/operator_graph_input.cc create mode 100644 lib/pcg/src/pcg/operator_graph/operator_graph_input.dtg.cc create mode 100644 lib/pcg/src/pcg/operator_graph/operator_graph_output.cc create mode 100644 lib/pcg/src/pcg/operator_graph/operator_graph_output.dtg.cc create mode 100644 lib/pcg/src/pcg/operator_guid_t.dtg.cc create mode 100644 lib/pcg/src/pcg/optimizers/adam_optimizer_attrs.dtg.cc create mode 100644 lib/pcg/src/pcg/optimizers/sgd_optimizer_attrs.dtg.cc create mode 100644 lib/pcg/src/pcg/parallel_computation_graph.dtg.cc create mode 100644 lib/pcg/src/pcg/parallel_layer_attrs.dtg.cc create mode 100644 lib/pcg/src/pcg/parallel_tensor_attrs.dtg.cc create mode 100644 lib/pcg/src/pcg/side_size_t.dtg.cc create mode 100644 lib/pcg/src/pcg/strided_rectangle.dtg.cc create mode 100644 lib/pcg/src/pcg/strided_rectangle_side.cc create mode 100644 lib/pcg/src/pcg/strided_rectangle_side.dtg.cc create mode 100644 lib/pcg/src/pcg/tensor_attrs.dtg.cc create mode 100644 lib/pcg/src/pcg/tensor_guid_t.dtg.cc delete mode 100644 lib/pcg/src/serialization.cc delete mode 100644 lib/pcg/src/tensor.cc create mode 100644 lib/pcg/test/CMakeLists.txt create mode 100644 lib/pcg/test/src/test_computation_graph_builder.cc delete mode 100644 lib/runtime/include/runtime/task_spec/concrete_arg.h delete mode 100644 lib/runtime/src/ops/gather.cc delete mode 100644 lib/runtime/src/ops/gather.h delete mode 100644 lib/runtime/src/task_spec/op_task_invocation.cc delete mode 100644 lib/runtime/src/task_spec/op_task_invocation.h delete mode 100644 lib/runtime/src/task_spec/op_tensor_spec.h delete mode 100644 lib/runtime/src/task_spec/task_argument_accessor.h delete mode 100644 lib/runtime/src/task_spec/variadic_tensor_ref.h create mode 100644 lib/substitution-generator/include/substitution-generator/legacy_operator_type.dtg.h create mode 100644 lib/substitution-generator/include/substitution-generator/legacy_operator_type.enum.toml create mode 100644 lib/substitution-generator/include/substitution-generator/legacy_pm_parameter.dtg.h create mode 100644 lib/substitution-generator/include/substitution-generator/legacy_pm_parameter.enum.toml create mode 100644 lib/substitution-generator/src/substitution-generator/legacy_operator_type.dtg.cc create mode 100644 lib/substitution-generator/src/substitution-generator/legacy_pm_parameter.dtg.cc delete mode 100644 lib/substitutions/include/substitutions/attribute_expr.h create mode 100644 lib/substitutions/include/substitutions/constraint_type.dtg.h create mode 100644 lib/substitutions/include/substitutions/constraint_type.enum.toml delete mode 100644 lib/substitutions/include/substitutions/graph_pattern_match.h delete mode 100644 lib/substitutions/include/substitutions/operator_pattern.h create mode 100644 lib/substitutions/include/substitutions/operator_pattern/eval_list_access.h create mode 100644 lib/substitutions/include/substitutions/operator_pattern/eval_list_size.h rename lib/substitutions/include/substitutions/{ => operator_pattern}/get_attribute.h (83%) create mode 100644 lib/substitutions/include/substitutions/operator_pattern/operator_attribute_constraint.dtg.h create mode 100644 lib/substitutions/include/substitutions/operator_pattern/operator_attribute_constraint.struct.toml create mode 100644 lib/substitutions/include/substitutions/operator_pattern/operator_attribute_expr.dtg.h create mode 100644 
lib/substitutions/include/substitutions/operator_pattern/operator_attribute_expr.h create mode 100644 lib/substitutions/include/substitutions/operator_pattern/operator_attribute_expr.variant.toml create mode 100644 lib/substitutions/include/substitutions/operator_pattern/operator_attribute_key.dtg.h create mode 100644 lib/substitutions/include/substitutions/operator_pattern/operator_attribute_key.enum.toml create mode 100644 lib/substitutions/include/substitutions/operator_pattern/operator_attribute_list_access.dtg.h create mode 100644 lib/substitutions/include/substitutions/operator_pattern/operator_attribute_list_access.struct.toml create mode 100644 lib/substitutions/include/substitutions/operator_pattern/operator_attribute_list_size.dtg.h create mode 100644 lib/substitutions/include/substitutions/operator_pattern/operator_attribute_list_size.struct.toml create mode 100644 lib/substitutions/include/substitutions/operator_pattern/operator_attribute_pattern.dtg.h create mode 100644 lib/substitutions/include/substitutions/operator_pattern/operator_attribute_pattern.struct.toml create mode 100644 lib/substitutions/include/substitutions/operator_pattern/operator_attribute_value.dtg.h create mode 100644 lib/substitutions/include/substitutions/operator_pattern/operator_attribute_value.variant.toml create mode 100644 lib/substitutions/include/substitutions/operator_pattern/satisfies_constraint.h create mode 100644 lib/substitutions/include/substitutions/operator_pattern/satisfies_pattern.h delete mode 100644 lib/substitutions/include/substitutions/output_graph.h create mode 100644 lib/substitutions/include/substitutions/output_graph/attr_constant.dtg.h create mode 100644 lib/substitutions/include/substitutions/output_graph/attr_constant.struct.toml create mode 100644 lib/substitutions/include/substitutions/output_graph/output_graph_expr.dtg.h create mode 100644 lib/substitutions/include/substitutions/output_graph/output_graph_expr.struct.toml create mode 100644 lib/substitutions/include/substitutions/output_graph/output_operator_attr_access.dtg.h create mode 100644 lib/substitutions/include/substitutions/output_graph/output_operator_attr_access.struct.toml create mode 100644 lib/substitutions/include/substitutions/output_graph/output_operator_attribute_expr.dtg.h create mode 100644 lib/substitutions/include/substitutions/output_graph/output_operator_attribute_expr.variant.toml create mode 100644 lib/substitutions/include/substitutions/output_graph/output_operator_attrs_assignment.dtg.h create mode 100644 lib/substitutions/include/substitutions/output_graph/output_operator_attrs_assignment.struct.toml delete mode 100644 lib/substitutions/include/substitutions/parallel_tensor_pattern.h create mode 100644 lib/substitutions/include/substitutions/pcg_pattern.dtg.h create mode 100644 lib/substitutions/include/substitutions/pcg_pattern.struct.toml create mode 100644 lib/substitutions/include/substitutions/sub_parallel_computation_graph.dtg.h create mode 100644 lib/substitutions/include/substitutions/sub_parallel_computation_graph.struct.toml create mode 100644 lib/substitutions/include/substitutions/substitution.dtg.h create mode 100644 lib/substitutions/include/substitutions/substitution.struct.toml create mode 100644 lib/substitutions/include/substitutions/tensor_pattern/eval_list_access.h create mode 100644 lib/substitutions/include/substitutions/tensor_pattern/eval_list_size.h create mode 100644 lib/substitutions/include/substitutions/tensor_pattern/get_attribute.h create mode 100644 
lib/substitutions/include/substitutions/tensor_pattern/satisfies_constraint.h create mode 100644 lib/substitutions/include/substitutions/tensor_pattern/satisfies_pattern.h create mode 100644 lib/substitutions/include/substitutions/tensor_pattern/tensor_attribute_constraint.dtg.h create mode 100644 lib/substitutions/include/substitutions/tensor_pattern/tensor_attribute_constraint.struct.toml create mode 100644 lib/substitutions/include/substitutions/tensor_pattern/tensor_attribute_expr.dtg.h create mode 100644 lib/substitutions/include/substitutions/tensor_pattern/tensor_attribute_expr.h create mode 100644 lib/substitutions/include/substitutions/tensor_pattern/tensor_attribute_expr.variant.toml create mode 100644 lib/substitutions/include/substitutions/tensor_pattern/tensor_attribute_key.dtg.h create mode 100644 lib/substitutions/include/substitutions/tensor_pattern/tensor_attribute_key.enum.toml create mode 100644 lib/substitutions/include/substitutions/tensor_pattern/tensor_attribute_list_access.dtg.h create mode 100644 lib/substitutions/include/substitutions/tensor_pattern/tensor_attribute_list_access.struct.toml create mode 100644 lib/substitutions/include/substitutions/tensor_pattern/tensor_attribute_list_size.dtg.h create mode 100644 lib/substitutions/include/substitutions/tensor_pattern/tensor_attribute_list_size.struct.toml create mode 100644 lib/substitutions/include/substitutions/tensor_pattern/tensor_attribute_pattern.dtg.h create mode 100644 lib/substitutions/include/substitutions/tensor_pattern/tensor_attribute_pattern.struct.toml create mode 100644 lib/substitutions/include/substitutions/tensor_pattern/tensor_attribute_value.dtg.h create mode 100644 lib/substitutions/include/substitutions/tensor_pattern/tensor_attribute_value.variant.toml create mode 100644 lib/substitutions/include/substitutions/unlabelled/closed_pattern_edge.dtg.h create mode 100644 lib/substitutions/include/substitutions/unlabelled/closed_pattern_edge.struct.toml create mode 100644 lib/substitutions/include/substitutions/unlabelled/downward_open_pattern_edge.dtg.h create mode 100644 lib/substitutions/include/substitutions/unlabelled/downward_open_pattern_edge.h create mode 100644 lib/substitutions/include/substitutions/unlabelled/downward_open_pattern_edge.struct.toml create mode 100644 lib/substitutions/include/substitutions/unlabelled/edge_splits.dtg.h create mode 100644 lib/substitutions/include/substitutions/unlabelled/edge_splits.h create mode 100644 lib/substitutions/include/substitutions/unlabelled/edge_splits.struct.toml create mode 100644 lib/substitutions/include/substitutions/unlabelled/find_pattern_matches.h create mode 100644 lib/substitutions/include/substitutions/unlabelled/input_pattern_edge.dtg.h create mode 100644 lib/substitutions/include/substitutions/unlabelled/input_pattern_edge.h create mode 100644 lib/substitutions/include/substitutions/unlabelled/input_pattern_edge.struct.toml create mode 100644 lib/substitutions/include/substitutions/unlabelled/match_additional_criterion.dtg.h create mode 100644 lib/substitutions/include/substitutions/unlabelled/match_additional_criterion.struct.toml create mode 100644 lib/substitutions/include/substitutions/unlabelled/match_split.dtg.h create mode 100644 lib/substitutions/include/substitutions/unlabelled/match_split.h create mode 100644 lib/substitutions/include/substitutions/unlabelled/match_split.struct.toml create mode 100644 lib/substitutions/include/substitutions/unlabelled/multidigraph_pattern_match.dtg.h create mode 100644 
lib/substitutions/include/substitutions/unlabelled/multidigraph_pattern_match.h create mode 100644 lib/substitutions/include/substitutions/unlabelled/multidigraph_pattern_match.struct.toml create mode 100644 lib/substitutions/include/substitutions/unlabelled/output_pattern_edge.dtg.h create mode 100644 lib/substitutions/include/substitutions/unlabelled/output_pattern_edge.h create mode 100644 lib/substitutions/include/substitutions/unlabelled/output_pattern_edge.struct.toml create mode 100644 lib/substitutions/include/substitutions/unlabelled/pattern_edge.dtg.h create mode 100644 lib/substitutions/include/substitutions/unlabelled/pattern_edge.h create mode 100644 lib/substitutions/include/substitutions/unlabelled/pattern_edge.struct.toml create mode 100644 lib/substitutions/include/substitutions/unlabelled/pattern_matching.h create mode 100644 lib/substitutions/include/substitutions/unlabelled/pattern_node.dtg.h create mode 100644 lib/substitutions/include/substitutions/unlabelled/pattern_node.struct.toml create mode 100644 lib/substitutions/include/substitutions/unlabelled/pattern_split.dtg.h create mode 100644 lib/substitutions/include/substitutions/unlabelled/pattern_split.h create mode 100644 lib/substitutions/include/substitutions/unlabelled/pattern_split.struct.toml create mode 100644 lib/substitutions/include/substitutions/unlabelled/unlabelled_graph_pattern.dtg.h create mode 100644 lib/substitutions/include/substitutions/unlabelled/unlabelled_graph_pattern.h create mode 100644 lib/substitutions/include/substitutions/unlabelled/unlabelled_graph_pattern.struct.toml create mode 100644 lib/substitutions/include/substitutions/unlabelled/upward_open_pattern_edge.dtg.h create mode 100644 lib/substitutions/include/substitutions/unlabelled/upward_open_pattern_edge.h create mode 100644 lib/substitutions/include/substitutions/unlabelled/upward_open_pattern_edge.struct.toml delete mode 100644 lib/substitutions/src/graph_pattern.cc delete mode 100644 lib/substitutions/src/graph_pattern_match.cc delete mode 100644 lib/substitutions/src/sub_parallel_computation_graph.cc create mode 100644 lib/substitutions/src/substitutions/constraint_type.dtg.cc create mode 100644 lib/substitutions/src/substitutions/graph_pattern.cc create mode 100644 lib/substitutions/src/substitutions/operator_pattern/eval_list_access.cc create mode 100644 lib/substitutions/src/substitutions/operator_pattern/eval_list_size.cc rename lib/substitutions/src/{operator_attributes.cc => substitutions/operator_pattern/get_attribute.cc} (72%) create mode 100644 lib/substitutions/src/substitutions/operator_pattern/operator_attribute_constraint.dtg.cc create mode 100644 lib/substitutions/src/substitutions/operator_pattern/operator_attribute_expr.cc create mode 100644 lib/substitutions/src/substitutions/operator_pattern/operator_attribute_expr.dtg.cc create mode 100644 lib/substitutions/src/substitutions/operator_pattern/operator_attribute_key.dtg.cc create mode 100644 lib/substitutions/src/substitutions/operator_pattern/operator_attribute_list_access.dtg.cc create mode 100644 lib/substitutions/src/substitutions/operator_pattern/operator_attribute_list_size.dtg.cc create mode 100644 lib/substitutions/src/substitutions/operator_pattern/operator_attribute_pattern.dtg.cc create mode 100644 lib/substitutions/src/substitutions/operator_pattern/operator_attribute_value.dtg.cc create mode 100644 lib/substitutions/src/substitutions/operator_pattern/satisfies_constraint.cc create mode 100644 
lib/substitutions/src/substitutions/operator_pattern/satisfies_pattern.cc create mode 100644 lib/substitutions/src/substitutions/output_graph/attr_constant.dtg.cc create mode 100644 lib/substitutions/src/substitutions/output_graph/output_graph_expr.dtg.cc create mode 100644 lib/substitutions/src/substitutions/output_graph/output_operator_attr_access.dtg.cc create mode 100644 lib/substitutions/src/substitutions/output_graph/output_operator_attribute_expr.dtg.cc create mode 100644 lib/substitutions/src/substitutions/output_graph/output_operator_attrs_assignment.dtg.cc create mode 100644 lib/substitutions/src/substitutions/pcg_pattern.dtg.cc create mode 100644 lib/substitutions/src/substitutions/sub_parallel_computation_graph.cc create mode 100644 lib/substitutions/src/substitutions/sub_parallel_computation_graph.dtg.cc create mode 100644 lib/substitutions/src/substitutions/substitution.cc create mode 100644 lib/substitutions/src/substitutions/substitution.dtg.cc create mode 100644 lib/substitutions/src/substitutions/tensor_pattern/eval_list_access.cc create mode 100644 lib/substitutions/src/substitutions/tensor_pattern/eval_list_size.cc create mode 100644 lib/substitutions/src/substitutions/tensor_pattern/get_attribute.cc create mode 100644 lib/substitutions/src/substitutions/tensor_pattern/satisfies_constraint.cc create mode 100644 lib/substitutions/src/substitutions/tensor_pattern/satisfies_pattern.cc create mode 100644 lib/substitutions/src/substitutions/tensor_pattern/tensor_attribute_constraint.dtg.cc create mode 100644 lib/substitutions/src/substitutions/tensor_pattern/tensor_attribute_expr.cc create mode 100644 lib/substitutions/src/substitutions/tensor_pattern/tensor_attribute_expr.dtg.cc create mode 100644 lib/substitutions/src/substitutions/tensor_pattern/tensor_attribute_key.dtg.cc create mode 100644 lib/substitutions/src/substitutions/tensor_pattern/tensor_attribute_list_access.dtg.cc create mode 100644 lib/substitutions/src/substitutions/tensor_pattern/tensor_attribute_list_size.dtg.cc create mode 100644 lib/substitutions/src/substitutions/tensor_pattern/tensor_attribute_pattern.dtg.cc create mode 100644 lib/substitutions/src/substitutions/tensor_pattern/tensor_attribute_value.dtg.cc create mode 100644 lib/substitutions/src/substitutions/unlabelled/closed_pattern_edge.dtg.cc create mode 100644 lib/substitutions/src/substitutions/unlabelled/downward_open_pattern_edge.cc create mode 100644 lib/substitutions/src/substitutions/unlabelled/downward_open_pattern_edge.dtg.cc create mode 100644 lib/substitutions/src/substitutions/unlabelled/edge_splits.cc create mode 100644 lib/substitutions/src/substitutions/unlabelled/edge_splits.dtg.cc create mode 100644 lib/substitutions/src/substitutions/unlabelled/find_pattern_matches.cc create mode 100644 lib/substitutions/src/substitutions/unlabelled/input_pattern_edge.cc create mode 100644 lib/substitutions/src/substitutions/unlabelled/input_pattern_edge.dtg.cc create mode 100644 lib/substitutions/src/substitutions/unlabelled/match_additional_criterion.dtg.cc create mode 100644 lib/substitutions/src/substitutions/unlabelled/match_split.cc create mode 100644 lib/substitutions/src/substitutions/unlabelled/match_split.dtg.cc create mode 100644 lib/substitutions/src/substitutions/unlabelled/multidigraph_pattern_match.cc create mode 100644 lib/substitutions/src/substitutions/unlabelled/multidigraph_pattern_match.dtg.cc create mode 100644 lib/substitutions/src/substitutions/unlabelled/output_pattern_edge.cc create mode 100644 
lib/substitutions/src/substitutions/unlabelled/output_pattern_edge.dtg.cc create mode 100644 lib/substitutions/src/substitutions/unlabelled/pattern_edge.cc create mode 100644 lib/substitutions/src/substitutions/unlabelled/pattern_edge.dtg.cc create mode 100644 lib/substitutions/src/substitutions/unlabelled/pattern_matching.cc create mode 100644 lib/substitutions/src/substitutions/unlabelled/pattern_node.dtg.cc create mode 100644 lib/substitutions/src/substitutions/unlabelled/pattern_split.cc create mode 100644 lib/substitutions/src/substitutions/unlabelled/pattern_split.dtg.cc create mode 100644 lib/substitutions/src/substitutions/unlabelled/unlabelled_graph_pattern.cc create mode 100644 lib/substitutions/src/substitutions/unlabelled/unlabelled_graph_pattern.dtg.cc create mode 100644 lib/substitutions/src/substitutions/unlabelled/upward_open_pattern_edge.cc create mode 100644 lib/substitutions/src/substitutions/unlabelled/upward_open_pattern_edge.dtg.cc create mode 100644 lib/utils/include/utils/check_fmtable.h create mode 100644 lib/utils/include/utils/containers/concat_vectors.h create mode 100644 lib/utils/include/utils/containers/enumerate_vector.h create mode 100644 lib/utils/include/utils/containers/extend_vector.h create mode 100644 lib/utils/include/utils/containers/vector_transform.h create mode 100644 lib/utils/include/utils/containers/zip_vectors.h create mode 100644 lib/utils/include/utils/fmt/expected.h create mode 100644 lib/utils/include/utils/fmt/pair.h create mode 100644 lib/utils/include/utils/fmt/unordered_map.h create mode 100644 lib/utils/include/utils/graph/docs/edges.svg create mode 100644 lib/utils/include/utils/graph/docs/generate_diagram.py create mode 100644 lib/utils/include/utils/graph/docs/labelled.svg create mode 100644 lib/utils/include/utils/graph/docs/open.svg create mode 100644 lib/utils/include/utils/graph/docs/undirected.svg create mode 100644 lib/utils/include/utils/integer_conversions.h create mode 100644 lib/utils/include/utils/join_strings.h create mode 100644 lib/utils/include/utils/overload.h create mode 100644 lib/utils/src/utils/graph/multidiedge.cc create mode 100644 lib/utils/src/utils/integer_conversions.cc create mode 100644 lib/utils/src/utils/overload.cc create mode 100644 lib/utils/test/common/src/main.cc create mode 100644 lib/utils/test/src/test_optional.cc
diff --git a/.editorconfig b/.editorconfig
new file mode 100644
index 0000000000..7242dd283c
--- /dev/null
+++ b/.editorconfig
@@ -0,0 +1,22 @@
+root = true
+
+# Unix-style newlines with a newline ending every file
+[*]
+end_of_line = lf
+insert_final_newline = true
+
+[{CMakeLists.txt,*.cmake}]
+indent_style = space
+indent_size = 2
+
+[*.{cc,h,cu,cpp}]
+indent_style = space
+indent_size = 2
+
+[*.py]
+indent_style = space
+indent_size = 4
+
+[*.toml]
+indent_style = space
+indent_size = 2
diff --git a/.flake/pkgs/hpp2plantuml.nix b/.flake/pkgs/hpp2plantuml.nix
new file mode 100644
index 0000000000..d5aba814f1
--- /dev/null
+++ b/.flake/pkgs/hpp2plantuml.nix
@@ -0,0 +1,11 @@
+{buildPythonPackage, fetchPypi}:
+
+buildPythonPackage rec {
+  pname = "hpp2plantuml";
+  version = "0.8.5";
+  format = "wheel";
+  src = fetchPypi {
+    inherit pname version format;
+    sha256 = "sha256-PfTJmBypI21AAK3sMojygQfrhnRqcMmVCW4dxGfDfQg=";
+  };
+}
diff --git a/.gitattributes b/.gitattributes
new file mode 100644
index 0000000000..efec9cf353
--- /dev/null
+++ b/.gitattributes
@@ -0,0 +1,2 @@
+*.dtg.cc linguist-generated=true
+*.dtg.h linguist-generated=true
diff --git a/.github/workflows/helpers/cmake_cuda.sh b/.github/workflows/helpers/cmake_cuda.sh
index e549859a5a..f062569efb 100755
--- a/.github/workflows/helpers/cmake_cuda.sh
+++ b/.github/workflows/helpers/cmake_cuda.sh
@@ -23,6 +23,7 @@ IFS=" " read -r -a FLAGS <<< "$CMAKE_FLAGS"
   -DCMAKE_C_COMPILER_LAUNCHER=ccache \
   -DCMAKE_CXX_COMPILER_LAUNCHER=ccache \
   -DCMAKE_CUDA_COMPILER_LAUNCHER=ccache \
+  -DFF_USE_CODE_COVERAGE=ON \
   "${FLAGS[@]}"

# vim: set tabstop=2 shiftwidth=2 expandtab:
diff --git a/.github/workflows/per-lib-check.yml b/.github/workflows/per-lib-check.yml
index a53a6afc11..17de365e9e 100644
--- a/.github/workflows/per-lib-check.yml
+++ b/.github/workflows/per-lib-check.yml
@@ -72,13 +72,13 @@ jobs:
         run: |
           build_libs.sh kernels

-      - name: Build substitutions
-        run: |
-          build_libs.sh substitutions
+      # - name: Build substitutions
+      #   run: |
+      #     build_libs.sh substitutions

-      - name: Build compiler
-        run: |
-          build_libs.sh compiler
+      # - name: Build compiler
+      #   run: |
+      #     build_libs.sh compiler

       - name: Build substitution-generator
         run: |
@@ -88,14 +88,40 @@
         run: |
           test_libs.sh utils

-      - name: Test substitutions
+      - name: Test op-attrs
         run: |
-          test_libs.sh substitutions
+          test_libs.sh op-attrs

-      - name: Test compiler
+      - name: Test pcg
         run: |
-          test_libs.sh compiler
+          test_libs.sh pcg
+
+      # - name: Test substitutions
+      #   run: |
+      #     test_libs.sh substitutions
+
+      # - name: Test compiler
+      #   run: |
+      #     test_libs.sh compiler

       - name: Test substitution-generator
         run: |
           test_libs.sh substitution-generator
+
+      - name: Generate code coverage
+        run: |
+          echo "gitwork: $GITHUB_WORKSPACE"
+          lcov --capture --directory . --output-file main_coverage.info
+          lcov --extract main_coverage.info "$GITHUB_WORKSPACE/lib/*" --output-file main_coverage.info
+          lcov --remove main_coverage.info "$GITHUB_WORKSPACE/lib/*.dtg.h" "$GITHUB_WORKSPACE/lib/*.dtg.cc" --output-file main_coverage.info
+          lcov --list main_coverage.info
+
+      - name: Upload code coverage
+        uses: codecov/codecov-action@v4
+        with:
+          token: ${{ secrets.CODECOV_TOKEN }}
+          file: main_coverage.info
+          flags: unittests
+          name: codecov-umbrella
+          fail_ci_if_error: false
+          verbose: true
diff --git a/.proj.toml b/.proj.toml
index c048347913..2ca91fabb7 100644
--- a/.proj.toml
+++ b/.proj.toml
@@ -6,11 +6,15 @@ fix_compile_commands = false

 build_targets = [
   "kernels",
 ]
 test_targets = [
   # "utils-tests",
   # "substitutions-tests",
   # "compiler-tests",
+  "pcg",
+  # "substitutions",
+  # "compiler",
+  "substitution-generator",
 ]

 [cmake_flags_extra]
@@ -20,4 +24,4 @@ CMAKE_CUDA_ARCHITECTURES = "60"
 CMAKE_HIP_ARCHITECTURES = "gfx900"
 CMAKE_CXX_COMPILER = "hipcc"
 CMAKE_C_COMPILER = "hipcc"
-# FF_CUDA_ARCH = "60"
\ No newline at end of file
+# FF_CUDA_ARCH = "60"
diff --git a/CMakeLists.txt b/CMakeLists.txt
index 27d8482c63..5222af555a 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -37,6 +37,7 @@ option(FF_USE_PREBUILT_LEGION "Enable use of Legion pre-compiled library, if ava
 option(FF_USE_ALL_PREBUILT_LIBRARIES "Enable use of all pre-compiled libraries, if available" OFF)
 option(FF_USE_PYTHON "Enable Python" ON)
 option(FF_BUILD_FROM_PYPI "Build from pypi" OFF)
+option(FF_USE_CODE_COVERAGE "Enable code coverage" OFF)

 set(FF_GASNET_CONDUITS aries udp mpi ibv ucx)
 set(FF_GASNET_CONDUIT "mpi" CACHE STRING "Select GASNet conduit ${FF_GASNET_CONDUITS}")
@@ -86,8 +87,10 @@ endif()
 # include(cuda)
 # include(cudnn)
 # include(nccl)
-include(CodeCoverage)
-append_coverage_compiler_flags()
+if (FF_USE_CODE_COVERAGE)
+  include(CodeCoverage)
+  append_coverage_compiler_flags()
+endif()
 # set_property(CACHE FF_GPU_BACKEND PROPERTY STRINGS ${FF_GPU_BACKENDS})

 include(json)
diff --git a/codecov.yml b/codecov.yml
new file mode 100644
index 0000000000..326788decf
--- /dev/null
+++ b/codecov.yml
@@ -0,0 +1,22 @@
+codecov:
+  branch: repo-refactor
+  notify:
+    require_ci_to_pass: false
+
+ignore:
+  - "**/*.dtg.h"
+  - "**/*.dtg.cc"
+
+
+coverage:
+  status:
+    project:
+      default:
+        target: auto  # Automatically set target at 70% of the current project coverage
+        threshold: 0%  # Allows the coverage to drop by no more than 0% from the target
+        base: auto  # Picks the base of the pull request as a reference to compare against
+
+comment:
+  layout: "header, diff, flags, files"
+  behavior: default
+  require_changes: no
\ No newline at end of file
diff --git a/flake.lock b/flake.lock
index ffd4a02962..f0fc292a5e 100644
--- a/flake.lock
+++ b/flake.lock
@@ -43,11 +43,11 @@
       ]
     },
     "locked": {
-      "lastModified": 1712222904,
-      "narHash": "sha256-FRI/RdOTtmo9o7iwZiACD0lSSlgvKqcpppjliXUHyRU=",
+      "lastModified": 1717449667,
+      "narHash": "sha256-xFGnB44WadxlCa2LnlH82g1c89+7UAomVgytIewSwO0=",
       "owner": "lockshaw",
       "repo": "proj",
-      "rev": "5b7a82dc01fa25076a8b3db96c1f2ea4752ae990",
+      "rev": "28b37a9bd993d3de3d80695eb3834a0436c805a4",
       "type": "github"
     },
     "original": {
diff --git a/flake.nix b/flake.nix
index a0464e883b..95b7a5d876 100644
--- a/flake.nix
+++ b/flake.nix
@@ -62,6 +62,7 @@
       {
         packages = {
           legion = pkgs.callPackage ./.flake/pkgs/legion.nix { };
+          hpp2plantuml = pkgs.python3Packages.callPackage ./.flake/pkgs/hpp2plantuml.nix { };
           rapidcheckFull = pkgs.symlinkJoin {
             name = "rapidcheckFull";
             paths = (with pkgs; [ rapidcheck.out rapidcheck.dev ]);
@@ -123,6 +124,7 @@
             ])
             (with self.packages.${system}; [
               legion
+              hpp2plantuml
               rapidcheckFull
               doctest
             ])
@@ -143,6 +145,10 @@
           inputsFrom = [ ci ];
           inherit (ci) CMAKE_FLAGS;

+          VIMPLUGINS = lib.strings.concatStringsSep "," [
+            "${proj-repo.packages.${system}.proj-nvim}"
+          ];
+
           buildInputs = builtins.concatLists [
             (with pkgs; [
               clang-tools
@@ -176,4 +182,4 @@
         };
       }
     );
-}
\ No newline at end of file
+}
diff --git a/lib/compiler/test/src/test_optimal_cost.cc b/lib/compiler/test/src/test_optimal_cost.cc
index 91c7a11888..8c176eb4d2 100644
--- a/lib/compiler/test/src/test_optimal_cost.cc
+++ b/lib/compiler/test/src/test_optimal_cost.cc
@@ -41,10 +41,9 @@ TEST_SUITE(FF_TEST_SUITE) {
     MultiDiEdge e{n1, pcg.add_node_port(), n0, pcg.add_node_port()};
     pcg.add_edge(e);

-    pcg.add_output(e,
-                   ParallelTensor(ParallelTensorDims({2, 1}),
-                                  DataType::FLOAT,
-                                  CreateGrad::YES));
+    ParallelDim dim = {2, 1, false};
+    ParallelTensorDims dims = {FFOrdered<ParallelDim>{dim}};
+    pcg.add_output(e, ParallelTensor(dims, DataType::FLOAT, CreateGrad::YES));

     auto test_allowed_machine_views = [](Operator const &,
                                          MachineSpecification const &) {
diff --git a/lib/kernels/include/kernels/array_shape.h b/lib/kernels/include/kernels/array_shape.h
index 36796bc504..1cb10e8ce7 100644
--- a/lib/kernels/include/kernels/array_shape.h
+++ b/lib/kernels/include/kernels/array_shape.h
@@ -6,6 +6,7 @@
 #include "utils/stack_vector.h"
 #include "utils/visitable.h"
 #include <cstddef>
+#include <optional>
 #include <vector>

 namespace FlexFlow {
@@ -41,8 +42,10 @@ struct ArrayShape {
   std::optional<std::size_t> at_maybe(std::size_t) const;

   ArrayShape reversed_dim_order() const;
-  ArrayShape sub_shape(std::optional<legion_dim_t> start,
-                       std::optional<legion_dim_t> end);
+
+  ArrayShape
+      sub_shape(std::optional<std::variant<ff_dim_t, legion_dim_t>> start,
+                std::optional<std::variant<ff_dim_t, legion_dim_t>> end) const;

 public:
   LegionTensorDims dims;
diff --git a/lib/kernels/include/kernels/cast_kernels.h b/lib/kernels/include/kernels/cast_kernels.h
index 4e6878e318..96f9aadd52 100644
--- a/lib/kernels/include/kernels/cast_kernels.h
+++ b/lib/kernels/include/kernels/cast_kernels.h
@@ -4,7 +4,7 @@
 #include "device.h"
 #include "kernels/accessor.h"
 #include "kernels/ff_handle.h"
-#include "op-attrs/activation.h"
+#include "op-attrs/activation.dtg.h"

 namespace FlexFlow {
 namespace Kernels {
diff --git a/lib/kernels/include/kernels/conv_2d_kernels.h b/lib/kernels/include/kernels/conv_2d_kernels.h
index b646c4b7cb..0a93125367 100644
--- a/lib/kernels/include/kernels/conv_2d_kernels.h
+++ b/lib/kernels/include/kernels/conv_2d_kernels.h
@@ -4,7 +4,7 @@
 #include "device.h"
 #include "kernels/accessor.h"
 #include "kernels/ff_handle.h"
-#include "op-attrs/activation.h"
+#include "op-attrs/activation.dtg.h"
 #include "utils/visitable.h"

 namespace FlexFlow {
diff --git a/lib/kernels/include/kernels/element_binary_kernels.h b/lib/kernels/include/kernels/element_binary_kernels.h
index a9cbba420e..41447e98e6 100644
--- a/lib/kernels/include/kernels/element_binary_kernels.h
+++ b/lib/kernels/include/kernels/element_binary_kernels.h
@@ -5,7 +5,7 @@
 #include "ff_handle.h"
 #include "kernels/array_shape.h"
 #include "op-attrs/datatype.h"
-#include "op-attrs/op.h"
+#include "op-attrs/operator_type.h"

 namespace FlexFlow {

diff --git a/lib/kernels/include/kernels/element_unary_kernels.h b/lib/kernels/include/kernels/element_unary_kernels.h
index 17e0048c65..5044b0cdb2 100644
--- a/lib/kernels/include/kernels/element_unary_kernels.h
+++ b/lib/kernels/include/kernels/element_unary_kernels.h
@@ -9,9 +9,6 @@

 namespace FlexFlow {

-using ElementUnaryUnifiedAttrs =
-    std::variant<ElementUnaryAttrs, ElementScalarUnaryAttrs>;
-
 struct ElementUnaryPerDeviceState {
   ffTensorDescriptor_t inputTensor, outputTensor;
   req<ffActivationDescriptor_t> actiDesc;
@@ -27,18 +24,34 @@ namespace ElementUnary {

 ElementUnaryPerDeviceState init_kernel(ArrayShape const &input_shape,
                                        ArrayShape const &output_shape,
-                                       ElementUnaryUnifiedAttrs const &attrs);
+                                       ElementUnaryAttrs const &attrs);
+
+void forward_kernel(ffStream_t stream,
+                    ElementUnaryPerDeviceState const &device_state,
+                    ElementUnaryAttrs const &attrs,
+                    PerDeviceFFHandle &handle,
+                    GenericTensorAccessorR const &input,
+                    GenericTensorAccessorW const &output);

 void forward_kernel(ffStream_t stream,
                     ElementUnaryPerDeviceState const &device_state,
-                    ElementUnaryUnifiedAttrs const &attrs,
+                    ElementScalarUnaryAttrs const &attrs,
                     PerDeviceFFHandle &handle,
                     GenericTensorAccessorR const &input,
                     GenericTensorAccessorW const &output);

 void backward_kernel(ffStream_t stream,
                      ElementUnaryPerDeviceState const &device_state,
-                     ElementUnaryUnifiedAttrs const &attrs,
+                     ElementUnaryAttrs const &attrs,
+                     PerDeviceFFHandle &handle,
+                     GenericTensorAccessorR const &input,
+                     GenericTensorAccessorW const &input_grad,
+                     GenericTensorAccessorR const &output,
+                     GenericTensorAccessorR const &output_grad);
+
+void backward_kernel(ffStream_t stream,
+                     ElementUnaryPerDeviceState const &device_state,
+                     ElementScalarUnaryAttrs const &attrs,
                      PerDeviceFFHandle &handle,
                      GenericTensorAccessorR const &input,
                      GenericTensorAccessorW const &input_grad,
diff --git a/lib/kernels/include/kernels/gather_kernels.h b/lib/kernels/include/kernels/gather_kernels.h
index c74f9c0bb6..13bf4b898a 100644
--- a/lib/kernels/include/kernels/gather_kernels.h
+++ b/lib/kernels/include/kernels/gather_kernels.h
@@ -2,36 +2,34 @@
 #define _FLEXFLOW_OPS_KERNELS_GATHER_KERNELS_H

 #include "accessor.h"
-#include "device.h"
+#include "kernels/device.h"

 namespace FlexFlow {

 struct GatherPerDeviceState {
-  int legion_dim;
diff --git a/lib/kernels/include/kernels/gather_kernels.h b/lib/kernels/include/kernels/gather_kernels.h
index c74f9c0bb6..13bf4b898a 100644
--- a/lib/kernels/include/kernels/gather_kernels.h
+++ b/lib/kernels/include/kernels/gather_kernels.h
@@ -2,36 +2,34 @@
 #define _FLEXFLOW_OPS_KERNELS_GATHER_KERNELS_H
 
 #include "accessor.h"
-#include "device.h"
+#include "kernels/device.h"
 
 namespace FlexFlow {
 
 struct GatherPerDeviceState {
-  int legion_dim;
-  req<DataType> index_data_type;
+  PerDeviceFFHandle handle;
+  legion_dim_t legion_dim;
 };
+
 FF_VISITABLE_STRUCT_NONSTANDARD_CONSTRUCTION(GatherPerDeviceState,
-                                             legion_dim,
-                                             index_data_type);
+                                             handle,
+                                             legion_dim);
 
 namespace Kernels {
 namespace Gather {
+
 void forward_kernel(ffStream_t stream,
                     GatherPerDeviceState const &m,
                     GenericTensorAccessorR const &input,
                     GenericTensorAccessorR const &index,
-                    GenericTensorAccessorW const &output,
-                    size_t stride,
-                    size_t input_dim_size,
-                    size_t output_dim_size);
+                    GenericTensorAccessorW const &output);
+
 void backward_kernel(ffStream_t stream,
                      GatherPerDeviceState const &m,
                      GenericTensorAccessorR const &output_grad,
                      GenericTensorAccessorR const &index,
-                     GenericTensorAccessorW const &input_grad,
-                     size_t stride,
-                     size_t input_dim_size,
-                     size_t output_dim_size);
+                     GenericTensorAccessorW const &input_grad);
+
 } // namespace Gather
 } // namespace Kernels
 } // namespace FlexFlow
diff --git a/lib/kernels/include/kernels/legion_dim.h b/lib/kernels/include/kernels/legion_dim.h
index f5c1d7ccc9..cf6ebfc2d4 100644
--- a/lib/kernels/include/kernels/legion_dim.h
+++ b/lib/kernels/include/kernels/legion_dim.h
@@ -1,14 +1,14 @@
 #ifndef _FLEXFLOW_KERNELS_INCLUDE_KERNELS_LEGION_DIM_H
 #define _FLEXFLOW_KERNELS_INCLUDE_KERNELS_LEGION_DIM_H
 
+#include "kernels/legion_dim_t.dtg.h"
 #include "op-attrs/dim_ordered.h"
-#include "utils/strong_typedef.h"
 
 namespace FlexFlow {
 
-struct legion_dim_t : strong_typedef<legion_dim_t, int> {
-  using strong_typedef::strong_typedef;
-};
+legion_dim_t add_to_legion_dim(legion_dim_t, int);
+
+legion_dim_t legion_dim_from_ff_dim(ff_dim_t, int num_dimensions);
 
 template <typename T>
 using LegionOrdered = DimOrdered<legion_dim_t, T>;
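For orientation when reading the kernels later in this patch: `legion_dim_t` counts dimensions from the innermost (fastest-varying) axis, while `ff_dim_t` counts from the outermost, so the two helpers declared above amount to small index arithmetic. A possible implementation sketch (an assumption for illustration; the patch only declares them, and the `ff_dim.value()` accessor is likewise assumed):

// Sketch only: assumes legion dims are the reverse of ff dims within an
// N-dimensional shape.
legion_dim_t add_to_legion_dim(legion_dim_t dim, int value) {
  return legion_dim_t(dim.value + value);
}

legion_dim_t legion_dim_from_ff_dim(ff_dim_t ff_dim, int num_dimensions) {
  return legion_dim_t(num_dimensions - ff_dim.value() - 1);
}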
diff --git a/lib/kernels/include/kernels/legion_dim_t.dtg.h b/lib/kernels/include/kernels/legion_dim_t.dtg.h
new file mode 100644
index 0000000000..622f9c240a
--- /dev/null
+++ b/lib/kernels/include/kernels/legion_dim_t.dtg.h
@@ -0,0 +1,54 @@
+// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT!
+// If you would like to modify this datatype, instead modify
+// lib/kernels/include/kernels/legion_dim_t.struct.toml
+/* proj-data
+{
+  "generated_from": "f67d6e50c53539a21d69e7162cf965f4"
+}
+*/
+
+#ifndef _FLEXFLOW_LIB_KERNELS_INCLUDE_KERNELS_LEGION_DIM_T_DTG_H
+#define _FLEXFLOW_LIB_KERNELS_INCLUDE_KERNELS_LEGION_DIM_T_DTG_H
+
+#include "fmt/format.h"
+#include "nlohmann/json.hpp"
+#include <functional>
+#include <ostream>
+#include <string>
+
+namespace FlexFlow {
+struct legion_dim_t {
+  legion_dim_t() = delete;
+  legion_dim_t(int const &value);
+
+  bool operator==(legion_dim_t const &) const;
+  bool operator!=(legion_dim_t const &) const;
+  bool operator<(legion_dim_t const &) const;
+  bool operator>(legion_dim_t const &) const;
+  bool operator<=(legion_dim_t const &) const;
+  bool operator>=(legion_dim_t const &) const;
+  int value;
+};
+} // namespace FlexFlow
+
+namespace std {
+template <>
+struct hash<FlexFlow::legion_dim_t> {
+  size_t operator()(FlexFlow::legion_dim_t const &) const;
+};
+} // namespace std
+
+namespace nlohmann {
+template <>
+struct adl_serializer<FlexFlow::legion_dim_t> {
+  static FlexFlow::legion_dim_t from_json(json const &);
+  static void to_json(json &, FlexFlow::legion_dim_t const &);
+};
+} // namespace nlohmann
+
+namespace FlexFlow {
+std::string format_as(legion_dim_t const &);
+std::ostream &operator<<(std::ostream &, legion_dim_t const &);
+} // namespace FlexFlow
+
+#endif // _FLEXFLOW_LIB_KERNELS_INCLUDE_KERNELS_LEGION_DIM_T_DTG_H
diff --git a/lib/kernels/include/kernels/legion_dim_t.struct.toml b/lib/kernels/include/kernels/legion_dim_t.struct.toml
new file mode 100644
index 0000000000..d2afb0d73f
--- /dev/null
+++ b/lib/kernels/include/kernels/legion_dim_t.struct.toml
@@ -0,0 +1,14 @@
+namespace = "FlexFlow"
+name = "legion_dim_t"
+
+features = [
+  "eq",
+  "ord",
+  "hash",
+  "json",
+  "fmt",
+]
+
+[[fields]]
+name = "value"
+type = "int"
diff --git a/lib/kernels/include/kernels/pool_2d_kernels.h b/lib/kernels/include/kernels/pool_2d_kernels.h
index 96bb6eccf9..798c0507f8 100644
--- a/lib/kernels/include/kernels/pool_2d_kernels.h
+++ b/lib/kernels/include/kernels/pool_2d_kernels.h
@@ -3,7 +3,7 @@
 
 #include "device.h"
 #include "kernels/ff_handle.h"
-#include "op-attrs/activation.h"
+#include "op-attrs/activation.dtg.h"
 #include "op-attrs/ops/pool_2d.h"
 #include "utils/visitable.h"
 
diff --git a/lib/kernels/include/kernels/reduce_kernels.h b/lib/kernels/include/kernels/reduce_kernels.h
index 51730fb0cd..56241b73ce 100644
--- a/lib/kernels/include/kernels/reduce_kernels.h
+++ b/lib/kernels/include/kernels/reduce_kernels.h
@@ -4,7 +4,7 @@
 #include "array_shape.h"
 #include "device.h"
 #include "ff_handle.h"
-#include "op-attrs/op.h"
+#include "op-attrs/operator_type.dtg.h"
 
 namespace FlexFlow {
 
diff --git a/lib/kernels/include/kernels/transpose_kernels.h b/lib/kernels/include/kernels/transpose_kernels.h
index cb34ff6736..fa087fada3 100644
--- a/lib/kernels/include/kernels/transpose_kernels.h
+++ b/lib/kernels/include/kernels/transpose_kernels.h
@@ -8,7 +8,7 @@ namespace FlexFlow {
 
 struct TransposePerDeviceState {
   int num_dim;
-  req<std::vector<int>> perm;
+  req<std::vector<legion_dim_t>> perm;
 };
 
 FF_VISITABLE_STRUCT_NONSTANDARD_CONSTRUCTION(TransposePerDeviceState,
diff --git a/lib/kernels/src/cuda/ops/concat_kernels.cu b/lib/kernels/src/cuda/ops/concat_kernels.cu
index dcf7a41a2f..68004738d2 100644
--- a/lib/kernels/src/cuda/ops/concat_kernels.cu
+++ b/lib/kernels/src/cuda/ops/concat_kernels.cu
@@ -25,15 +25,8 @@ void calc_blk_size(size_t &num_blocks,
                    size_t &blk_size,
                    ArrayShape const &shape,
                    ff_dim_t axis) {
-  num_blocks = 1;
-  blk_size = 1;
-  for (int d = 0; d <
shape.num_dims(); d++) { - if (d <= axis) { - blk_size *= shape[legion_dim_t(d)]; - } else { - num_blocks *= shape[legion_dim_t(d)]; - } - } + blk_size = shape.sub_shape(legion_dim_t{0}, axis).num_elements(); + num_blocks = shape.sub_shape(axis, std::nullopt).num_elements(); } void forward_kernel(cudaStream_t stream, diff --git a/lib/kernels/src/cuda/ops/element_binary_kernels.cu b/lib/kernels/src/cuda/ops/element_binary_kernels.cu index be369ff064..45b4d43006 100644 --- a/lib/kernels/src/cuda/ops/element_binary_kernels.cu +++ b/lib/kernels/src/cuda/ops/element_binary_kernels.cu @@ -17,14 +17,12 @@ #include "kernels/element_binary_kernels.h" #include "kernels/ff_handle.h" #include "op-attrs/datatype.h" -#include "op-attrs/op.h" +#include "op-attrs/operator_type.h" namespace FlexFlow { namespace Kernels { namespace ElementBinary { -using OperatorType = Op; - __global__ void elewise_binary_backward_kernel(size_t volume, float const alpha, float const beta, @@ -36,28 +34,28 @@ __global__ void elewise_binary_backward_kernel(size_t volume, float *rhs_grad) { CUDA_KERNEL_LOOP(i, volume) { switch (type) { - case Op::EW_ADD: { + case OperatorType::EW_ADD: { lhs_grad[i] = alpha * out_grad[i] + beta * lhs_grad[i]; rhs_grad[i] = alpha * out_grad[i] + beta * rhs_grad[i]; break; } - case Op::EW_SUB: { + case OperatorType::EW_SUB: { lhs_grad[i] = alpha * out_grad[i] + beta * lhs_grad[i]; rhs_grad[i] = -alpha * out_grad[i] + beta * rhs_grad[i]; break; } - case Op::EW_MUL: { + case OperatorType::EW_MUL: { lhs_grad[i] = alpha * out_grad[i] * rhs[i] + beta * lhs_grad[i]; rhs_grad[i] = alpha * out_grad[i] * lhs[i] + beta * rhs_grad[i]; break; } - case Op::EW_DIV: { + case OperatorType::EW_DIV: { lhs_grad[i] = alpha * out_grad[i] / rhs[i] + beta * lhs_grad[i]; rhs_grad[i] = -alpha * out_grad[i] * lhs[i] / (rhs[i] * rhs[i]) + beta * rhs_grad[i]; break; } - case Op::EW_MAX: { + case OperatorType::EW_MAX: { lhs_grad[i] = (lhs[i] >= rhs[i]) ? alpha * out_grad[i] + beta * lhs_grad[i] : beta * lhs_grad[i]; @@ -66,7 +64,7 @@ __global__ void elewise_binary_backward_kernel(size_t volume, : beta * rhs_grad[i]; break; } - case Op::EW_MIN: { + case OperatorType::EW_MIN: { lhs_grad[i] = (lhs[i] <= rhs[i]) ? 
alpha * out_grad[i] + beta * lhs_grad[i] : beta * lhs_grad[i]; @@ -102,17 +100,17 @@ ElementBinaryPerDeviceState init_kernel(PerDeviceFFHandle handle, checkCUDNN(cudnnCreateReduceTensorDescriptor(&reduceAddDesc)); switch (op_type) { - case Op::EW_ADD: - case Op::EW_SUB: + case OperatorType::EW_ADD: + case OperatorType::EW_SUB: mode = CUDNN_OP_TENSOR_ADD; break; - case Op::EW_MUL: + case OperatorType::EW_MUL: mode = CUDNN_OP_TENSOR_MUL; break; - case Op::EW_MAX: + case OperatorType::EW_MAX: mode = CUDNN_OP_TENSOR_MAX; break; - case Op::EW_MIN: + case OperatorType::EW_MIN: mode = CUDNN_OP_TENSOR_MIN; break; default: @@ -152,13 +150,13 @@ void forward_kernel(cudaStream_t stream, checkCUDNN(cudnnSetStream(handle.dnn, stream)); float alpha1 = 1.0f, alpha2 = 1.0f, beta = 0.0f; switch (op_type) { - case Op::EW_SUB: + case OperatorType::EW_SUB: alpha2 = -1.0f; break; - case Op::EW_ADD: - case Op::EW_MUL: - case Op::EW_MAX: - case Op::EW_MIN: + case OperatorType::EW_ADD: + case OperatorType::EW_MUL: + case OperatorType::EW_MAX: + case OperatorType::EW_MIN: break; default: assert(false); @@ -167,9 +165,9 @@ void forward_kernel(cudaStream_t stream, // cudnnOpTensor if (broadcast_inputLHS) { // currently only handle add and sub - assert(op_type == Op::EW_SUB || op_type == Op::EW_ADD || - op_type == Op::EW_MUL); - if (op_type == Op::EW_SUB || op_type == Op::EW_ADD) { + assert(op_type == OperatorType::EW_SUB || op_type == OperatorType::EW_ADD || + op_type == OperatorType::EW_MUL); + if (op_type == OperatorType::EW_SUB || op_type == OperatorType::EW_ADD) { // output = (beta*output + alpha1*input1) + beta*output = input1 checkCUDNN(cudnnOpTensor(handle.dnn, m.opDesc, @@ -195,7 +193,7 @@ void forward_kernel(cudaStream_t stream, &alpha1, m.outputTensor, out_ptr)); - } else if (op_type == Op::EW_MUL) { + } else if (op_type == OperatorType::EW_MUL) { checkCUDNN(cudnnSetOpTensorDescriptor(m.opDesc, CUDNN_OP_TENSOR_ADD, CUDNN_DATA_FLOAT, @@ -258,7 +256,7 @@ void backward_kernel(cudaStream_t stream, checkCUDA(cublasSetStream(handle.blas, stream)); checkCUDNN(cudnnSetStream(handle.dnn, stream)); - if (op_type == Op::EW_ADD || op_type == Op::EW_SUB) { + if (op_type == OperatorType::EW_ADD || op_type == OperatorType::EW_SUB) { float alpha = 1.0f, beta = 1.0f; if (lhs_grad_ptr != nullptr) { if (broadcast_inputLHS) { @@ -284,7 +282,7 @@ void backward_kernel(cudaStream_t stream, lhs_grad_ptr)); } } - if (op_type == Op::EW_SUB) { + if (op_type == OperatorType::EW_SUB) { alpha = -1.0f; } if (rhs_grad_ptr != nullptr) { @@ -311,7 +309,7 @@ void backward_kernel(cudaStream_t stream, rhs_grad_ptr)); } } - } else if (op_type == Op::EW_MUL) { + } else if (op_type == OperatorType::EW_MUL) { float alpha1 = 1.0f, alpha2 = 1.0f, beta = 1.0f, zero = 0.0f; if (lhs_grad_ptr != nullptr) { if (broadcast_inputLHS) { @@ -393,7 +391,8 @@ void backward_kernel(cudaStream_t stream, rhs_grad_ptr)); } } - } else if (op_type == Op::EW_MIN || op_type == Op::EW_MAX) { + } else if (op_type == OperatorType::EW_MIN || + op_type == OperatorType::EW_MAX) { float alpha = 1.0f, beta = 1.0f; cudnnDataType_t dataType; int n; diff --git a/lib/kernels/src/cuda/ops/element_unary_kernels.cu b/lib/kernels/src/cuda/ops/element_unary_kernels.cu index 305e778726..e37d32c325 100644 --- a/lib/kernels/src/cuda/ops/element_unary_kernels.cu +++ b/lib/kernels/src/cuda/ops/element_unary_kernels.cu @@ -25,29 +25,19 @@ namespace ElementUnary { static bool use_cudnn(OperatorType op_type) { switch (op_type) { - case Op::RELU: - case Op::SIGMOID: - case Op::TANH: - case 
Op::ELU: + case OperatorType::RELU: + case OperatorType::SIGMOID: + case OperatorType::TANH: + case OperatorType::ELU: return true; default: return false; } } -template -T get_scalar(ElementUnaryUnifiedAttrs const &attrs) { - if (std::holds_alternative(attrs)) { - return (T)std::get(attrs).scalar; - } else { - T dummy_scalar; - return dummy_scalar; - } -} - -ElementUnaryPerDeviceState init_kernel(ArrayShape const &input_shape, - ArrayShape const &output_shape, - ElementUnaryUnifiedAttrs const &attrs) { +static ElementUnaryPerDeviceState init_kernel(ArrayShape const &input_shape, + ArrayShape const &output_shape, + OperatorType op_type) { ffTensorDescriptor_t inputTensor; ffTensorDescriptor_t outputTensor; @@ -57,21 +47,19 @@ ElementUnaryPerDeviceState init_kernel(ArrayShape const &input_shape, checkCUDNN(cudnnCreateTensorDescriptor(&outputTensor)); checkCUDNN(cudnnCreateActivationDescriptor(&actiDesc)); - Op op_type = get_op_type(attrs); - if (use_cudnn(op_type)) { cudnnActivationMode_t mode; switch (op_type) { - case Op::SIGMOID: + case OperatorType::SIGMOID: mode = CUDNN_ACTIVATION_SIGMOID; break; - case Op::RELU: + case OperatorType::RELU: mode = CUDNN_ACTIVATION_RELU; break; - case Op::TANH: + case OperatorType::TANH: mode = CUDNN_ACTIVATION_TANH; break; - case Op::ELU: + case OperatorType::ELU: mode = CUDNN_ACTIVATION_ELU; break; default: @@ -88,52 +76,64 @@ ElementUnaryPerDeviceState init_kernel(ArrayShape const &input_shape, return {inputTensor, outputTensor, actiDesc}; } +ElementUnaryPerDeviceState init_kernel(ArrayShape const &input_shape, + ArrayShape const &output_shape, + ElementUnaryAttrs const &attrs) { + return init_kernel(input_shape, output_shape, get_op_type(attrs)); +} + +ElementUnaryPerDeviceState init_kernel(ArrayShape const &input_shape, + ArrayShape const &output_shape, + ElementScalarUnaryAttrs const &attrs) { + return init_kernel(input_shape, output_shape, get_op_type(attrs)); +} + template __global__ void elewise_unary_forward_kernel( - coord_t volume, T const scalar, OperatorType type, T const *in, T *out) { + coord_t volume, T scalar, OperatorType type, T const *in, T *out) { CUDA_KERNEL_LOOP(i, volume) { switch (type) { - case Op::EXP: { + case OperatorType::EXP: { out[i] = (T)exp((float)in[i]); break; } - case Op::IDENTITY: { + case OperatorType::IDENTITY: { out[i] = in[i]; break; } - case Op::SCALAR_MULTIPLY: { + case OperatorType::SCALAR_MULTIPLY: { out[i] = in[i] * scalar; break; } - case Op::SCALAR_ADD: { + case OperatorType::SCALAR_ADD: { out[i] = in[i] + scalar; break; } - case Op::SCALAR_SUB: { + case OperatorType::SCALAR_SUB: { out[i] = in[i] - scalar; break; } - case Op::SCALAR_TRUE_DIV: { + case OperatorType::SCALAR_TRUE_DIV: { out[i] = in[i] / scalar; break; } - case Op::GELU: { + case OperatorType::GELU: { out[i] = (T)(in[i] * 0.5 * erfc(-in[i] * M_SQRT1_2)); break; } - case Op::RSQRT: { + case OperatorType::RSQRT: { out[i] = (T)(1.0f / sqrt((float)in[i])); break; } - case Op::POW: { + case OperatorType::POW: { out[i] = (T)(powf(in[i], scalar)); break; } - case Op::SIN: { + case OperatorType::SIN: { out[i] = (T)sin((float)in[i]); break; } - case Op::COS: { + case OperatorType::COS: { out[i] = (T)cos((float)in[i]); break; } @@ -145,7 +145,7 @@ __global__ void elewise_unary_forward_kernel( template __global__ void elewise_unary_backward_kernel(coord_t volume, - T const scalar, + T scalar, OperatorType type, T const *output, T const *output_grad, @@ -153,53 +153,53 @@ __global__ void elewise_unary_backward_kernel(coord_t volume, T *input_grad) { 
CUDA_KERNEL_LOOP(i, volume) { switch (type) { - case Op::EXP: { + case OperatorType::EXP: { // TODO: change to use output instead of recomputing input_grad[i] += (T)(output_grad[i] * exp((float)input[i])); break; } - case Op::IDENTITY: { + case OperatorType::IDENTITY: { input_grad[i] += output_grad[i]; break; } - case Op::SCALAR_MULTIPLY: { + case OperatorType::SCALAR_MULTIPLY: { input_grad[i] += output_grad[i] * scalar; break; } - case Op::SCALAR_ADD: { + case OperatorType::SCALAR_ADD: { input_grad[i] += output_grad[i]; break; } - case Op::SCALAR_SUB: { + case OperatorType::SCALAR_SUB: { input_grad[i] += output_grad[i]; break; } - case Op::SCALAR_TRUE_DIV: { + case OperatorType::SCALAR_TRUE_DIV: { input_grad[i] += output_grad[i] / scalar; break; } - case Op::GELU: { + case OperatorType::GELU: { input_grad[i] = (T)(output_grad[i] * (0.5 * erfc(-input[i] * M_SQRT1_2) - 0.5 * M_SQRT1_2 * input[i] * exp(-input[i] * input[i] * 0.5))); break; } - case Op::RSQRT: { + case OperatorType::RSQRT: { input_grad[i] = (T)(-0.5f * output_grad[i] * output[i] * output[i] * output[i]); break; } - case Op::POW: { + case OperatorType::POW: { input_grad[i] = (T)(output_grad[i] * scalar * powf(input[i], scalar - 1)); break; } - case Op::SIN: { + case OperatorType::SIN: { input_grad[i] += (T)(output_grad[i] * cos((float)input[i])); break; } - case Op::COS: { + case OperatorType::COS: { input_grad[i] += (T)(output_grad[i] * -sin((float)input[i])); break; } @@ -213,12 +213,12 @@ template struct ForwardKernel { void operator()(ffStream_t stream, ElementUnaryPerDeviceState const &m, - ElementUnaryUnifiedAttrs const &attrs, + OperatorType op_type, + std::optional scalar, PerDeviceFFHandle const &handle, GenericTensorAccessorR const &input, GenericTensorAccessorW const &output) const { checkCUDNN(cudnnSetStream(handle.dnn, stream)); - Op op_type = get_op_type(attrs); if (use_cudnn(op_type)) { float alpha = 1.0f, beta = 0.0f; checkCUDNN(cudnnActivationForward(handle.dnn, @@ -234,7 +234,7 @@ struct ForwardKernel { elewise_unary_forward_kernel> <<>>( num_elements, - get_scalar>(attrs), + static_cast>(scalar.value()), op_type, input.get(), output.get()); @@ -246,7 +246,8 @@ template struct BackwardKernel { void operator()(ffStream_t stream, ElementUnaryPerDeviceState const &m, - ElementUnaryUnifiedAttrs const &attrs, + OperatorType op_type, + std::optional scalar, PerDeviceFFHandle const &handle, GenericTensorAccessorR const &input, GenericTensorAccessorW const &input_grad, @@ -254,7 +255,6 @@ struct BackwardKernel { GenericTensorAccessorR const &output_grad) { checkCUDNN(cudnnSetStream(handle.dnn, stream)); - Op op_type = get_op_type(attrs); if (use_cudnn(op_type)) { float alpha = 1.0f; checkCUDNN(cudnnActivationBackward(handle.dnn, @@ -274,7 +274,7 @@ struct BackwardKernel { elewise_unary_backward_kernel> <<>>( num_elements, - get_scalar>(attrs), + static_cast>(scalar.value()), op_type, output.get(), output_grad.get(), @@ -286,17 +286,59 @@ struct BackwardKernel { void forward_kernel(ffStream_t stream, ElementUnaryPerDeviceState const &device_state, - ElementUnaryUnifiedAttrs const &attrs, + ElementUnaryAttrs const &attrs, PerDeviceFFHandle const &handle, GenericTensorAccessorR const &input, GenericTensorAccessorW const &output) { - DataTypeDispatch1{}( - input.data_type, stream, device_state, attrs, handle, input, output); + DataTypeDispatch1{}(input.data_type, + stream, + device_state, + get_op_type(attrs), + std::nullopt, + handle, + input, + output); +} + +void forward_kernel(ffStream_t stream, + 
ElementUnaryPerDeviceState const &device_state, + ElementScalarUnaryAttrs const &attrs, + PerDeviceFFHandle const &handle, + GenericTensorAccessorR const &input, + GenericTensorAccessorW const &output) { + DataTypeDispatch1{}(input.data_type, + stream, + device_state, + get_op_type(attrs), + attrs.scalar, + handle, + input, + output); +} + +void backward_kernel(ffStream_t stream, + ElementUnaryPerDeviceState const &device_state, + ElementUnaryAttrs const &attrs, + PerDeviceFFHandle const &handle, + GenericTensorAccessorR const &input, + GenericTensorAccessorW const &input_grad, + GenericTensorAccessorR const &output, + GenericTensorAccessorR const &output_grad) { + DataTypeDispatch1{}(input.data_type, + stream, + device_state, + get_op_type(attrs), + std::nullopt, + handle, + input, + input_grad, + output, + output_grad); } void backward_kernel(ffStream_t stream, ElementUnaryPerDeviceState const &device_state, - ElementUnaryUnifiedAttrs const &attrs, + ElementScalarUnaryAttrs const &attrs, PerDeviceFFHandle const &handle, GenericTensorAccessorR const &input, GenericTensorAccessorW const &input_grad, @@ -305,7 +347,8 @@ void backward_kernel(ffStream_t stream, DataTypeDispatch1{}(input.data_type, stream, device_state, - attrs, + get_op_type(attrs), + attrs.scalar, handle, input, input_grad, diff --git a/lib/kernels/src/cuda/ops/gather_kernels.cu b/lib/kernels/src/cuda/ops/gather_kernels.cu index 37d0112eab..e002cf7e71 100644 --- a/lib/kernels/src/cuda/ops/gather_kernels.cu +++ b/lib/kernels/src/cuda/ops/gather_kernels.cu @@ -25,10 +25,10 @@ template __global__ void gather_forward(float const *input, IndexType const *index, float *output, - size_t output_size, - size_t stride, - size_t input_dim_size, - size_t output_dim_size) { + coord_t output_size, + coord_t stride, + coord_t input_dim_size, + coord_t output_dim_size) { CUDA_KERNEL_LOOP(o, output_size) { // output tensor shape: [*, output_dim_size, stride] // output tensor stride: [output_dim_size * stride, stride, 1] @@ -39,10 +39,10 @@ __global__ void gather_forward(float const *input, // [outer_index, index[0], left_over] // Therefore, input_index = outer_index * (stride * input_dim_size) // + index[0] * stride + left_over; - size_t outer_index = o / (stride * output_dim_size); + coord_t outer_index = o / (stride * output_dim_size); // coord_t index_2 = (o / stride) % dim_size - size_t left_over = o % stride; - size_t input_idx = + coord_t left_over = o % stride; + coord_t input_idx = outer_index * (stride * input_dim_size) + index[o] * stride + left_over; output[o] = input[input_idx]; } @@ -52,10 +52,10 @@ template __global__ void gather_backward(float const *output_grad, IndexType const *index, float *input_grad, - size_t output_size, - size_t stride, - size_t input_dim_size, - size_t output_dim_size) { + coord_t output_size, + coord_t stride, + coord_t input_dim_size, + coord_t output_dim_size) { CUDA_KERNEL_LOOP(o, output_size) { // output tensor shape: [*, output_dim_size, stride] // output tensor stride: [output_dim_size * stride, stride, 1] @@ -66,10 +66,10 @@ __global__ void gather_backward(float const *output_grad, // [outer_index, index[0], left_over] // Therefore, input_index = outer_index * (stride * input_dim_size) // + index[0] * stride + left_over; - size_t outer_index = o / (stride * output_dim_size); + coord_t outer_index = o / (stride * output_dim_size); // coord_t index_2 = (o / stride) % dim_size - size_t left_over = o % stride; - size_t input_idx = + coord_t left_over = o % stride; + coord_t input_idx = outer_index 
* (stride * input_dim_size) + index[o] * stride + left_over; atomicAdd(&input_grad[input_idx], output_grad[o]); @@ -78,100 +78,96 @@ __global__ void gather_backward(float const *output_grad, template struct ForwardKernel { - void operator()(cudaStream_t stream, - GatherPerDeviceState const &m, + void operator()(ffStream_t stream, GenericTensorAccessorR const &input, GenericTensorAccessorR const &index, GenericTensorAccessorW const &output, - size_t stride, - size_t input_dim_size, - size_t output_dim_size) { - /*size_t stride = 1; - for (int i = 0; i < m->legion_dim; i++) { - stride *= (output.domain.hi()[i] - output.domain.lo()[i] + 1); - } - size_t dim_size = - output.domain.hi()[m->legion_dim] - output.domain.lo()[m->legion_dim] + - 1; -*/ - gather_forward> - <<>>(input.get(), - index.get(), - output.get(), - output.shape.get_volume(), - stride, - input_dim_size, - output_dim_size); + coord_t output_size, + coord_t stride, + coord_t input_dim_size, + coord_t output_dim_size) { + gather_forward<<>>( + input.get_float_ptr(), + index.get(), + output.get_float_ptr(), + output_size, + stride, + input_dim_size, + output_dim_size); } }; -void forward_kernel(cudaStream_t stream, +template +struct BackwardKernel { + void operator()(ffStream_t stream, + GenericTensorAccessorR const &output_grad, + GenericTensorAccessorR const &index, + GenericTensorAccessorW const &input_grad, + coord_t output_size, + coord_t stride, + coord_t input_dim_size, + coord_t output_dim_size) { + gather_backward<<>>( + output_grad.get_float_ptr(), + index.get(), + input_grad.get_float_ptr(), + output_size, + stride, + input_dim_size, + output_dim_size); + } +}; + +void forward_kernel(ffStream_t stream, GatherPerDeviceState const &m, GenericTensorAccessorR const &input, GenericTensorAccessorR const &index, - GenericTensorAccessorW const &output, - size_t stride, - size_t input_dim_size, - size_t output_dim_size) { - DataTypeDispatch1{}(m.index_data_type, + GenericTensorAccessorW const &output) { + checkCUDA(get_legion_stream(&stream)); + + coord_t stride = + output.shape.sub_shape(std::nullopt, add_to_legion_dim(m.legion_dim, 1)) + .num_elements(); + coord_t output_dim_size = output.shape[m.legion_dim]; + coord_t input_dim_size = input.shape[m.legion_dim]; + + assert(index.data_type == DataType::INT32 || + index.data_type == DataType::INT64); + + DataTypeDispatch1{}(index.data_type, stream, - m, input, index, output, + output.shape.get_volume(), stride, input_dim_size, output_dim_size); } -template -struct BackwardKernel { - void operator()(cudaStream_t stream, - GatherPerDeviceState const &m, - GenericTensorAccessorR const &output_grad, - GenericTensorAccessorR const &index, - GenericTensorAccessorW const &input_grad, - size_t stride, - size_t input_dim_size, - size_t output_dim_size) { - /*size_t stride = 1; - for (int i = 0; i < m->legion_dim; i++) { - stride *= (output_grad.domain.hi()[i] - output_grad.domain.lo()[i] + 1); - } - size_t dim_size = output_grad.domain.hi()[m->legion_dim] - - output_grad.domain.lo()[m->legion_dim] + 1; - */ - gather_backward> - <<>>(output_grad.get(), - index.get(), - input_grad.get(), - output_grad.shape.get_volume(), - stride, - input_dim_size, - output_dim_size); - } -}; - -void backward_kernel(cudaStream_t stream, +void backward_kernel(ffStream_t stream, GatherPerDeviceState const &m, GenericTensorAccessorR const &output_grad, GenericTensorAccessorR const &index, - GenericTensorAccessorW const &input_grad, - size_t stride, - size_t input_dim_size, - size_t output_dim_size) { - 
DataTypeDispatch1{}(m.index_data_type, + GenericTensorAccessorW const &input_grad) { + checkCUDA(get_legion_stream(&stream)); + + coord_t stride = + output_grad.shape + .sub_shape(std::nullopt, add_to_legion_dim(m.legion_dim, 1)) + .get_volume(); + coord_t output_dim_size = output_grad.shape[m.legion_dim]; + coord_t input_dim_size = input_grad.shape[m.legion_dim]; + + assert(index.data_type == DataType::INT32 || + index.data_type == DataType::INT64); + + DataTypeDispatch1{}(index.data_type, stream, - m, output_grad, index, input_grad, + output_grad.shape.get_volume(), stride, input_dim_size, output_dim_size); diff --git a/lib/kernels/src/cuda/ops/linear_kernels.cu b/lib/kernels/src/cuda/ops/linear_kernels.cu index 81ab34380e..9a36534a1b 100644 --- a/lib/kernels/src/cuda/ops/linear_kernels.cu +++ b/lib/kernels/src/cuda/ops/linear_kernels.cu @@ -253,9 +253,8 @@ void backward_kernel(cudaStream_t stream, // do nothing } else { RegularizerAttrs regularizer_attrs = m.regularizer.value(); - if (std::holds_alternative(regularizer_attrs)) { - L2RegularizerAttrs l2_attrs = - std::get(regularizer_attrs); + if (regularizer_attrs.has()) { + L2RegularizerAttrs l2_attrs = regularizer_attrs.get(); float lambda = l2_attrs.lambda; checkCUDA(cublasSgeam(m.handle.blas, CUBLAS_OP_N, diff --git a/lib/kernels/src/cuda/ops/reduce_kernels.cu b/lib/kernels/src/cuda/ops/reduce_kernels.cu index 8571219648..02a89da807 100644 --- a/lib/kernels/src/cuda/ops/reduce_kernels.cu +++ b/lib/kernels/src/cuda/ops/reduce_kernels.cu @@ -71,10 +71,10 @@ void backward_kernel(cudaStream_t stream, checkCUDNN(cudnnSetStream(m.handle.dnn, stream)); float alpha = 1.0, beta = 1.0f; switch (m.op_type) { - case Op::REDUCE_SUM: + case OperatorType::REDUCE_SUM: alpha = 1.0f; break; - case Op::REDUCE_MEAN: + case OperatorType::REDUCE_MEAN: // When the output is the average of multiple input elements // we need to scale the gradients by 1.0 / reduction_size alpha = 1.0f / m.reduction_size; diff --git a/lib/kernels/src/cuda/ops/transpose_kernels.cu b/lib/kernels/src/cuda/ops/transpose_kernels.cu index 7dac25d0c9..3b3f80944d 100644 --- a/lib/kernels/src/cuda/ops/transpose_kernels.cu +++ b/lib/kernels/src/cuda/ops/transpose_kernels.cu @@ -33,10 +33,10 @@ TransposePerDeviceState init_kernel(int num_dim, std::vector const &perm) { int const length = perm.size(); - std::vector perm_vector; + std::vector perm_vector; assert(length <= MAX_TENSOR_DIM); for (int i = 0; i < length; ++i) { - perm_vector.push_back(perm[i].value()); + perm_vector.push_back(legion_dim_from_ff_dim(perm[i], num_dim)); } return {num_dim, perm_vector}; @@ -77,7 +77,7 @@ void forward_kernel(cudaStream_t stream, info.in_strides[i] = info.in_strides[i - 1] * in_dim_size; info.out_strides[i] = info.out_strides[i - 1] * out_dim_size; } - info.perm[i] = m.perm[i]; + info.perm[i] = m.perm[i].value; } transpose_simple_kernel<< #if defined(FF_USE_CUDA) diff --git a/lib/kernels/src/hip/element_binary_kernels.cpp b/lib/kernels/src/hip/element_binary_kernels.cpp index 5d29c27837..bc66bbff2f 100644 --- a/lib/kernels/src/hip/element_binary_kernels.cpp +++ b/lib/kernels/src/hip/element_binary_kernels.cpp @@ -14,13 +14,72 @@ */ #include "kernels/element_binary_kernels.h" -#include "kernels/hip_helper.h" +#include "device.h" +#include "kernels/ff_handle.h" +#include "op-attrs/datatype.h" +#include "op-attrs/operator_type.dtg.h" #include namespace FlexFlow { namespace Kernels { namespace ElementBinary { +__global__ void elewise_binary_backward_kernel(coord_t volume, + float const alpha, + float 
const beta,
+                                               OperatorType type,
+                                               float const *out_grad,
+                                               float const *lhs,
+                                               float const *rhs,
+                                               float *lhs_grad,
+                                               float *rhs_grad) {
+  CUDA_KERNEL_LOOP(i, volume) {
+    switch (type) {
+      case OperatorType::EW_ADD: {
+        lhs_grad[i] = alpha * out_grad[i] + beta * lhs_grad[i];
+        rhs_grad[i] = alpha * out_grad[i] + beta * rhs_grad[i];
+        break;
+      }
+      case OperatorType::EW_SUB: {
+        lhs_grad[i] = alpha * out_grad[i] + beta * lhs_grad[i];
+        rhs_grad[i] = -alpha * out_grad[i] + beta * rhs_grad[i];
+        break;
+      }
+      case OperatorType::EW_MUL: {
+        lhs_grad[i] = alpha * out_grad[i] * rhs[i] + beta * lhs_grad[i];
+        rhs_grad[i] = alpha * out_grad[i] * lhs[i] + beta * rhs_grad[i];
+        break;
+      }
+      case OperatorType::EW_DIV: {
+        lhs_grad[i] = alpha * out_grad[i] / rhs[i] + beta * lhs_grad[i];
+        rhs_grad[i] = -alpha * out_grad[i] * lhs[i] / (rhs[i] * rhs[i]) +
+                      beta * rhs_grad[i];
+        break;
+      }
+      case OperatorType::EW_MAX: {
+        lhs_grad[i] = (lhs[i] >= rhs[i])
+                          ? alpha * out_grad[i] + beta * lhs_grad[i]
+                          : beta * lhs_grad[i];
+        rhs_grad[i] = (rhs[i] >= lhs[i])
+                          ? alpha * out_grad[i] + beta * rhs_grad[i]
+                          : beta * rhs_grad[i];
+        break;
+      }
+      case OperatorType::EW_MIN: {
+        lhs_grad[i] = (lhs[i] <= rhs[i])
+                          ? alpha * out_grad[i] + beta * lhs_grad[i]
+                          : beta * lhs_grad[i];
+        rhs_grad[i] = (rhs[i] <= lhs[i])
+                          ? alpha * out_grad[i] + beta * rhs_grad[i]
+                          : beta * rhs_grad[i];
+        break;
+      }
+      default:
+        assert(false);
+    }
+  }
+}
+
 ElementBinaryPerDeviceState init_kernel(PerDeviceFFHandle handle,
                                         OperatorType op_type,
                                         bool should_broadcast_lhs,
@@ -42,18 +101,18 @@ ElementBinaryPerDeviceState init_kernel(PerDeviceFFHandle handle,
   checkCUDNN(miopenCreateReduceTensorDescriptor(&reduceAddDesc));
 
   switch (op_type) {
-    case Op::EW_ADD:
-    case Op::EW_SUB:
+    case OperatorType::EW_ADD:
+    case OperatorType::EW_SUB:
       mode = miopenTensorOpAdd;
       break;
-    case Op::EW_MUL:
+    case OperatorType::EW_MUL:
       mode = miopenTensorOpMul;
       break;
-    case Op::EW_MAX:
+    case OperatorType::EW_MAX:
       mode = miopenTensorOpMax;
       break;
-    case Op::EW_MIN:
+    case OperatorType::EW_MIN:
       mode = miopenTensorOpMin;
       break;
     default:
       assert(false);
@@ -82,83 +141,8 @@ ElementBinaryPerDeviceState init_kernel(PerDeviceFFHandle handle,
   return per_device_state;
 }
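+// NOTE: MIOpen, unlike cuDNN, has no op-tensor descriptor object. The op
+// selector is the plain miopenTensorOp_t value chosen into `mode` above,
+// and miopenOpTensor() consumes it by value:
+//   miopenOpTensor(handle, op, &a1, aDesc, A, &a2, bDesc, B, &beta, cDesc, C)
+// computes C = op(a1 * A, a2 * B) + beta * C.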
-__global__ void elewise_binary_forward_kernel(coord_t volume,
-                                              float const alpha,
-                                              float const beta,
-                                              OperatorType type,
-                                              float const *in1,
-                                              float const *in2,
-                                              float *out) {
-  switch (type) {
-    case Op::EW_ADD: {
-      CUDA_KERNEL_LOOP(i, volume) {
-        out[i] = alpha * (in1[i] + in2[i]) + beta * out[i];
-      }
-      break;
-    }
-    case Op::EW_SUB: {
-      CUDA_KERNEL_LOOP(i, volume) {
-        out[i] = alpha * (in1[i] - in2[i]) + beta * out[i];
-      }
-      break;
-    }
-    case Op::EW_MUL: {
-      CUDA_KERNEL_LOOP(i, volume) {
-        out[i] = alpha * in1[i] * in2[i] + beta * out[i];
-      }
-      break;
-    }
-    case Op::EW_DIV: {
-      CUDA_KERNEL_LOOP(i, volume) {
-        out[i] = alpha * (in1[i] / in2[i]) + beta * out[i];
-      }
-      break;
-    }
-    default:
-      assert(false);
-  }
-}
-
-__global__ void elewise_binary_backward_kernel(coord_t volume,
-                                               float const alpha,
-                                               float const beta,
-                                               OperatorType type,
-                                               float const *out_grad,
-                                               float const *in1,
-                                               float const *in2,
-                                               float *in1_grad,
-                                               float *in2_grad) {
-  CUDA_KERNEL_LOOP(i, volume) {
-    switch (type) {
-      case Op::EW_ADD: {
-        in1_grad[i] = alpha * out_grad[i] + beta * in1_grad[i];
-        in2_grad[i] = alpha * out_grad[i] + beta * in2_grad[i];
-        break;
-      }
-      case Op::EW_SUB: {
-        in1_grad[i] = alpha * out_grad[i] + beta * in1_grad[i];
-        in2_grad[i] = -alpha * out_grad[i] + beta * in2_grad[i];
-        break;
-      }
-      case Op::EW_MUL: {
-        in1_grad[i] = alpha * out_grad[i] * in2[i] + beta * in1_grad[i];
-        in2_grad[i] = alpha * out_grad[i] * in1[i] + beta * in2_grad[i];
-        break;
-      }
-      case Op::EW_DIV: {
-        in1_grad[i] = alpha * out_grad[i] / in2[i] + beta * in1_grad[i];
-        in2_grad[i] = -alpha * out_grad[i] * in1[i] / (in2[i] * in2[i]) +
-                      beta * in2_grad[i];
-        break;
-      }
-      default:
-        assert(false);
-    }
-  }
-}
-
 void forward_kernel(hipStream_t stream,
-                    ElementBinaryPerDeviceState const *m,
+                    ElementBinaryPerDeviceState const &m,
                     float const *lhs_ptr,
                     float const *rhs_ptr,
                     float *out_ptr,
@@ -170,11 +154,13 @@ void forward_kernel(hipStream_t stream,
   float alpha1 = 1.0f, alpha2 = 1.0f, beta = 0.0f;
   switch (op_type) {
-    case Op::EW_SUB:
+    case OperatorType::EW_SUB:
       alpha2 = -1.0f;
       break;
-    case Op::EW_ADD:
-    case Op::EW_MUL:
+    case OperatorType::EW_ADD:
+    case OperatorType::EW_MUL:
+    case OperatorType::EW_MAX:
+    case OperatorType::EW_MIN:
       break;
     default:
       assert(false);
@@ -183,41 +169,76 @@ void forward_kernel(hipStream_t stream,
   // cudnnOpTensor
   if (broadcast_inputLHS) {
     // currently only handle add and sub
-    assert(op_type == Op::EW_SUB || op_type == Op::EW_ADD);
-    checkCUDNN(miopenOpTensor(handle.dnn,
-                              m.opDesc,
-                              &beta,
-                              m.outputTensor,
-                              out_ptr,
-                              &alpha1,
-                              m.inputLHSTensor,
-                              lhs_ptr,
-                              &beta,
-                              m.outputTensor,
-                              out_ptr));
-    checkCUDNN(miopenOpTensor(handle.dnn,
-                              m.opDesc,
-                              &beta,
-                              m.outputTensor,
-                              out_ptr,
-                              &alpha2,
-                              m.inputRHSTensor,
-                              rhs_ptr,
-                              &alpha1,
-                              m.outputTensor,
-                              out_ptr));
+    assert(op_type == OperatorType::EW_SUB ||
+           op_type == OperatorType::EW_ADD ||
+           op_type == OperatorType::EW_MUL);
+    if (op_type == OperatorType::EW_SUB || op_type == OperatorType::EW_ADD) {
+      checkCUDNN(miopenOpTensor(handle.dnn,
+                                m.opDesc,
+                                &beta,
+                                m.outputTensor,
+                                out_ptr,
+                                &alpha1,
+                                m.inputLHSTensor,
+                                lhs_ptr,
+                                &beta,
+                                m.outputTensor,
+                                out_ptr));
+      checkCUDNN(miopenOpTensor(handle.dnn,
+                                m.opDesc,
+                                &beta,
+                                m.outputTensor,
+                                out_ptr,
+                                &alpha2,
+                                m.inputRHSTensor,
+                                rhs_ptr,
+                                &alpha1,
+                                m.outputTensor,
+                                out_ptr));
+    } else if (op_type == OperatorType::EW_MUL) {
+      // MIOpen has no settable op-tensor descriptor, so select the op per
+      // call by passing the miopenTensorOp_t value directly.
+      checkCUDNN(miopenOpTensor(handle.dnn,
+                                miopenTensorOpMul,
+                                &alpha1,
+                                m.inputLHSTensor,
+                                lhs_ptr,
+                                &alpha2,
+                                m.inputRHSTensor,
+                                rhs_ptr,
+                                &beta,
+                                m.outputTensor,
+                                out_ptr));
+      checkCUDNN(miopenOpTensor(handle.dnn,
+                                miopenTensorOpAdd,
+                                &beta,
+                                m.outputTensor,
+                                out_ptr,
+                                &alpha1,
+                                m.inputLHSTensor,
+                                lhs_ptr,
+                                &beta,
+                                m.outputTensor,
+                                out_ptr));
+    }
   } else {
     checkCUDNN(miopenOpTensor(handle.dnn,
                               m.opDesc,
                               &alpha1,
                               m.inputLHSTensor,
                               lhs_ptr,
                               &alpha2,
                               m.inputRHSTensor,
                               rhs_ptr,
                               &beta,
                               m.outputTensor,
                               out_ptr));
   }
 }
@@ -235,10 +256,10 @@ void backward_kernel(hipStream_t stream,
   checkCUDA(hipblasSetStream(handle.blas, stream));
   checkCUDNN(miopenSetStream(handle.dnn, stream));
 
-  if (m.op_type == Op::EW_ADD || m.op_type == Op::EW_SUB) {
-    float alpha = 1.0f, alpha2 = 0.0f, beta = 1.0f;
+  if (m.op_type == OperatorType::EW_ADD || m.op_type == OperatorType::EW_SUB) {
+    float alpha = 1.0f, beta = 1.0f;
     if (lhs_grad_ptr != nullptr) {
-      if (m.broadcast_input1) {
+      if (broadcast_inputLHS) {
         checkCUDNN(miopenReduceTensor(handle.dnn,
                                       m.reduceAddDesc,
                                       nullptr /*indices*/,
                                       0 /*indicesSizeInBytes*/,
@@ -257,7 +278,7 @@ void backward_kernel(hipStream_t stream,
                                   &alpha,
                                   m.outputTensor,
                                   out_grad_ptr,
-                                  &alpha2,
+                                  &alpha,
                                   m.outputTensor,
                                   out_grad_ptr,
                                   &beta,
                                   m.inputLHSTensor,
                                   lhs_grad_ptr));
       }
     }
-    if (m.op_type == Op::EW_SUB) {
+    if (m.op_type == OperatorType::EW_SUB) {
       alpha = -1.0f;
     }
     if (rhs_grad_ptr != nullptr) {
-      if (m.broadcast_input2) {
+      if (broadcast_inputRHS) {
        checkCUDNN(miopenReduceTensor(handle.dnn,
                                      m.reduceAddDesc,
                                      nullptr /*indices*/,
                                      0 /*indicesSizeInBytes*/,
@@ -288,7 +309,7 @@ void backward_kernel(hipStream_t stream,
                                   &alpha,
                                   m.outputTensor,
                                   out_grad_ptr,
-                                  &alpha2,
+                                  &alpha,
                                   m.outputTensor,
                                   out_grad_ptr,
                                   &beta,
                                   m.inputRHSTensor,
                                   rhs_grad_ptr));
       }
     }
-  } else if (m.op_type == Op::EW_MUL) {
-    float alpha1 = 1.0f, alpha2 = 1.0f, beta = 1.0f;
+  } else if (m.op_type == OperatorType::EW_MUL) {
+    float alpha1 = 1.0f, alpha2 = 1.0f, beta = 1.0f, zero = 0.0f;
     if (lhs_grad_ptr != nullptr) {
-      checkCUDNN(miopenOpTensor(handle.dnn,
-                                m.opDesc,
-                                &alpha1,
-                                m.outputTensor,
-                                out_grad_ptr,
-                                &alpha2,
-                                m.inputRHSTensor,
-                                rhs_ptr,
-                                &beta,
-                                m.inputLHSTensor,
-                                lhs_grad_ptr));
+      if (broadcast_inputLHS) {
+        checkCUDNN(miopenOpTensor(handle.dnn,
+                                  m.opDesc,
+                                  &alpha1,
+                                  m.outputTensor,
+                                  out_grad_ptr,
+                                  &alpha2,
+                                  m.inputRHSTensor,
+                                  rhs_ptr,
+                                  &beta,
+                                  m.inputLHSTensor,
+                                  lhs_grad_ptr));
+        checkCUDNN(miopenReduceTensor(handle.dnn,
+                                      m.reduceAddDesc,
+                                      nullptr /*indices*/,
+                                      0 /*indicesSizeInBytes*/,
+                                      handle.workSpace,
+                                      handle.workSpaceSize,
+                                      &alpha1,
+                                      m.outputTensor,
+                                      out_grad_ptr,
+                                      &zero,
+                                      m.inputLHSTensor,
+                                      lhs_grad_ptr));
+      } else {
+        checkCUDNN(miopenOpTensor(handle.dnn,
+                                  m.opDesc,
+                                  &alpha1,
+                                  m.outputTensor,
+                                  out_grad_ptr,
+                                  &alpha2,
+                                  m.inputRHSTensor,
+                                  rhs_ptr,
+                                  &beta,
+                                  m.inputLHSTensor,
+                                  lhs_grad_ptr));
+      }
     }
     if (rhs_grad_ptr != nullptr) {
-      checkCUDNN(miopenOpTensor(handle.dnn,
-                                m.opDesc,
-                                &alpha1,
-                                m.outputTensor,
-                                out_grad_ptr,
-                                &alpha2,
-                                m.inputRHSTensor,
-                                lhs_ptr,
-                                &beta,
-                                m.inputLHSTensor,
-                                rhs_grad_ptr));
+      if (broadcast_inputRHS) {
+        checkCUDNN(miopenOpTensor(handle.dnn,
+                                  m.opDesc,
+                                  &alpha1,
+                                  m.outputTensor,
+                                  out_grad_ptr,
+                                  &alpha2,
+                                  m.inputLHSTensor,
+                                  lhs_ptr,
+                                  &beta,
+                                  m.inputRHSTensor,
+                                  rhs_grad_ptr));
+        checkCUDNN(miopenReduceTensor(handle.dnn,
+                                      m.reduceAddDesc,
+                                      nullptr /*indices*/,
+                                      0 /*indicesSizeInBytes*/,
+                                      handle.workSpace,
+                                      handle.workSpaceSize,
+                                      &alpha1,
+                                      m.outputTensor,
+                                      out_grad_ptr,
+                                      &zero,
+                                      m.inputRHSTensor,
+                                      rhs_grad_ptr));
+      } else {
+        checkCUDNN(miopenOpTensor(handle.dnn,
+                                  m.opDesc,
+                                  &alpha1,
+                                  m.outputTensor,
+                                  out_grad_ptr,
+                                  &alpha2,
+                                  m.inputLHSTensor,
+                                  lhs_ptr,
+                                  &beta,
+                                  m.inputRHSTensor,
+                                  rhs_grad_ptr));
+      }
+    }
+  } else if (m.op_type == OperatorType::EW_MIN ||
+             m.op_type == OperatorType::EW_MAX) {
+    float alpha = 1.0f, beta = 1.0f;
+    miopenDataType_t data_type;
+    int n;
+    int dims[MAX_TENSOR_DIM];
+    int strides[MAX_TENSOR_DIM];
+    checkCUDNN(miopenGetTensorDescriptorSize(m.outputTensor, &n));
+    // fill dims/strides before using them to compute the element count
+    checkCUDNN(
+        miopenGetTensorDescriptor(m.outputTensor, &data_type, dims, strides));
+    size_t volume = 1;
+    for (int i = 0; i < n; i++) {
+      volume *= dims[i];
+    }
+    // launch hip kernel
+    hipLaunchKernelGGL(elewise_binary_backward_kernel,
+                       GET_BLOCKS(volume),
+                       CUDA_NUM_THREADS,
+                       0,
+                       stream,
+                       volume,
+                       alpha,
+                       beta,
+                       m.op_type,
+                       out_grad_ptr,
+                       lhs_ptr,
+                       rhs_ptr,
+                       lhs_grad_ptr,
+                       rhs_grad_ptr);
   } else {
     assert(false && "Unsupported ElementWise Binary Type");
   }
 }
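As a sanity check on the EW_DIV case of elewise_binary_backward_kernel above: for out = lhs / rhs with upstream gradient g = out_grad[i], the partials are d(out)/d(lhs) = 1 / rhs and d(out)/d(rhs) = -lhs / rhs^2, so the kernel's updates lhs_grad = alpha * g / rhs + beta * lhs_grad and rhs_grad = -alpha * g * lhs / (rhs * rhs) + beta * rhs_grad are exactly the chain rule, with beta blending in the previously accumulated gradient.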
diff --git a/lib/kernels/src/hip/element_unary_kernels.cpp b/lib/kernels/src/hip/element_unary_kernels.cpp
index e79ef57592..e14018fa24 100644
--- a/lib/kernels/src/hip/element_unary_kernels.cpp
+++ b/lib/kernels/src/hip/element_unary_kernels.cpp
@@ -14,178 +14,76 @@
  */
 
 #include "kernels/element_unary_kernels.h"
+#include "device.h"
 #include "kernels/datatype_dispatch.h"
-#include "kernels/hip_helper.h"
+#include "op-attrs/get_op_type.h"
 #include <hip/hip_runtime.h>
+#include <variant>
 
 namespace FlexFlow {
 namespace Kernels {
 namespace ElementUnary {
 
+using coord_t = long long;
+
+static bool use_cudnn(OperatorType op_type) {
+  switch (op_type) {
+    case Op::RELU:
+    case Op::SIGMOID:
+    case Op::TANH:
+    case Op::ELU:
+      return true;
+    default:
+      return false;
+  }
+}
+
+template <typename T>
+T get_scalar(ElementUnaryUnifiedAttrs const &attrs) {
+  if (std::holds_alternative<ElementScalarUnaryAttrs>(attrs)) {
+    return (T)std::get<ElementScalarUnaryAttrs>(attrs).scalar;
+  } else {
+    T dummy_scalar;
+    return dummy_scalar;
+  }
+}
+
 ElementUnaryPerDeviceState init_kernel(ArrayShape const &input_shape,
                                        ArrayShape const &output_shape,
-                                       ElementUnaryAttrs const &attrs) {
-  miopenTensorDescriptor_t inputTensor;
-  miopenTensorDescriptor_t outputTensor;
-  miopenActivationDescriptor_t actiDesc;
-  miopenActivationMode_t mode;
+                                       ElementUnaryUnifiedAttrs const &attrs) {
+  ffTensorDescriptor_t inputTensor;
+  ffTensorDescriptor_t outputTensor;
+  ffActivationDescriptor_t actiDesc;
 
   checkCUDNN(miopenCreateTensorDescriptor(&inputTensor));
   checkCUDNN(miopenCreateTensorDescriptor(&outputTensor));
   checkCUDNN(miopenCreateActivationDescriptor(&actiDesc));
 
-  if (use_cudnn(attrs.op_type)) {
-    switch (attrs.op_type) {
-      case OP_SIGMOID:
+  Op op_type = get_op_type(attrs);
+
+  if (use_cudnn(op_type)) {
+    miopenActivationMode_t mode;
+    switch (op_type) {
+      case Op::SIGMOID:
         mode = miopenActivationLOGISTIC;
         break;
-      case OP_RELU:
+      case Op::RELU:
         mode = miopenActivationRELU;
         break;
-      case OP_TANH:
+      case Op::TANH:
         mode = miopenActivationTANH;
         break;
-      case OP_ELU:
+      case Op::ELU:
         mode = miopenActivationELU;
         break;
       default:
         assert(false);
     }
-    checkCUDNN(miopenSetActivationDescriptor(actiDesc, mode, 0.0, 0.0, 0.0));
-    checkCUDNN(
-        cudnnSetTensorDescriptorFromArrayShape(inputTensor, input_shape));
-    // input_domain == output_domain
-    checkCUDNN(
-        cudnnSetTensorDescriptorFromArrayShape(outputTensor, output_shape));
-  }
-
-  ElementUnaryPerDeviceState per_device_state = {
-      inputTensor, outputTensor, actiDesc};
-
-  return per_device_state;
-}
-
-bool use_cudnn(OperatorType type) {
-  if (type == OP_RELU) {
-    return true;
-  }
-  if (type == OP_SIGMOID) {
-    return true;
-  }
-  if (type == OP_TANH) {
-    return true;
-  }
-  if (type == OP_ELU) {
-    return true;
-  }
-  return false;
-}
-
-template <typename T>
-struct ForwardKernel {
-  void operator()(ffStream_t stream,
-                  ElementUnaryPerDeviceState const &m,
-                  ElementUnaryAttrs const &attrs,
-                  PerDeviceFFHandle const &handle,
-                  GenericTensorAccessorR const &input,
-                  GenericTensorAccessorW const &output) {
-    checkCUDNN(miopenSetStream(handle.dnn, stream));
-    if (use_cudnn(attrs.op_type)) {
-      float alpha = 1.0f, beta = 0.0f;
-      checkCUDNN(miopenActivationForward(handle.dnn,
-                                         m.actiDesc,
-                                         &alpha,
-                                         m.inputTensor,
-                                         input.get<T>(),
-                                         &beta,
-                                         m.outputTensor,
-                                         output.get<T>()));
-    } else {
-      size_t num_elements = input.shape.num_elements();
-      hipLaunchKernelGGL(HIP_KERNEL_NAME(elewise_unary_forward_kernel<T>),
-                         GET_BLOCKS(num_elements),
-                         CUDA_NUM_THREADS,
-                         0,
-                         stream,
-                         num_elements,
-                         (T)attrs.scalar,
-                         attrs.op_type,
-                         input.get<T>(),
-                         output.get<T>());
-    }
-  }
-}
-
-template <typename T>
-struct BackwardKernel {
-  void operator()(ffStream_t stream,
-                  ElementUnaryPerDeviceState const &m,
-                  ElementUnaryAttrs const &attrs,
-                  PerDeviceFFHandle const &handle,
-                  GenericTensorAccessorR const &input,
-                  GenericTensorAccessorR const &input_grad,
-                  GenericTensorAccessorW const &output,
-                  GenericTensorAccessorW const &output_grad) {
-    checkCUDNN(miopenSetStream(handle.dnn, stream));
-
-    if (use_cudnn(attrs.op_type)) {
-      float alpha = 1.0f;
-      float beta = 0.0f;
-      checkCUDNN(miopenActivationBackward(handle.dnn,
-                                          m.actiDesc,
-                                          &alpha,
-                                          m.outputTensor,
-                                          output.get<T>(),
-                                          m.outputTensor,
-                                          output_grad.get<T>()),
-                 m.inputTensor,
-                 input.get<T>(),
-                 &beta,
-                 m.inputTensor,
-                 input_grad.get<T>());
-    } else {
-      size_t num_elements = input.shape.num_elements();
-      hipLaunchKernelGGL(HIP_KERNEL_NAME(elewise_unary_backward_kernel<T>),
-                         GET_BLOCKS(num_elements),
-                         CUDA_NUM_THREADS,
-                         0,
-                         stream,
-                         num_elements,
-                         attrs.scalar,
-                         attrs.op_type,
-                         output.get<T>(),
-                         output_grad.get<T>(),
-                         input.get<T>(),
-                         input_grad.get<T>());
-    }
-  }
-}
-
-void forward_kernel(ffStream_t stream,
-                    ElementUnaryPerDeviceState const &device_state,
-                    ElementUnaryAttrs const &attrs,
-                    PerDeviceFFHandle const &handle,
-                    GenericTensorAccessorR const &input,
-                    GenericTensorAccessorW const &output) {
-  DataTypeDispatch1<ForwardKernel>{}(
-      input.data_type, stream, m, attrs, handle, input, output);
-}
-
-void backward_kernel(ffStream_t stream,
-                     ElementUnaryPerDeviceState const &device_state,
-                     ElementUnaryAttrs const &attrs,
-                     PerDeviceFFHandle const &handle,
-                     GenericTensorAccessorR const &input,
-                     GenericTensorAccessorW const &input_grad,
-                     GenericTensorAccessorR const &output,
-                     GenericTensorAccessorR const &output_grad) {
-  DataTypeDispatch1<BackwardKernel>{}(input.data_type,
-                                      stream,
-                                      m,
-                                      attrs,
-                                      handle,
-                                      input,
-                                      input_grad,
-                                      output,
-                                      output_grad);
+    checkCUDNN(miopenSetActivationDescriptor(actiDesc, mode, 0.0, 0.0, 0.0));
+  }
+  checkCUDNN(cudnnSetTensorDescriptorFromArrayShape(inputTensor, input_shape));
+  checkCUDNN(cudnnSetTensorDescriptorFromArrayShape(outputTensor, output_shape));
+  return {inputTensor, outputTensor, actiDesc};
 }
 
 template <typename T>
@@ -309,6 +207,115 @@ __global__ void elewise_unary_backward_kernel(coord_t volume,
   }
 }
 
+template <typename T>
+struct ForwardKernel {
+  void operator()(ffStream_t stream,
+                  ElementUnaryPerDeviceState const &m,
+                  ElementUnaryUnifiedAttrs const &attrs,
+                  PerDeviceFFHandle const &handle,
+                  GenericTensorAccessorR const &input,
+                  GenericTensorAccessorW const &output) const {
+    checkCUDNN(miopenSetStream(handle.dnn, stream));
+    Op op_type = get_op_type(attrs);
+    if (use_cudnn(op_type)) {
+      float alpha = 1.0f, beta = 0.0f;
+      checkCUDNN(miopenActivationForward(handle.dnn,
+                                         m.actiDesc,
+                                         &alpha,
+                                         m.inputTensor,
+                                         input.get<T>(),
+                                         &beta,
+                                         m.outputTensor,
+                                         output.get<T>()));
+    } else {
+      size_t num_elements = input.shape.num_elements();
+      hipLaunchKernelGGL(HIP_KERNEL_NAME(elewise_unary_forward_kernel<T>),
+                         GET_BLOCKS(num_elements),
+                         CUDA_NUM_THREADS,
+                         0,
+                         stream,
+                         num_elements,
+                         get_scalar<T>(attrs),
+                         op_type,
+                         input.get<T>(),
+                         output.get<T>());
+    }
+  }
+};
+
+template <typename T>
+struct BackwardKernel {
+  void operator()(ffStream_t stream,
+                  ElementUnaryPerDeviceState const &m,
+                  ElementUnaryUnifiedAttrs const &attrs,
+                  PerDeviceFFHandle const &handle,
+                  GenericTensorAccessorR const &input,
+                  GenericTensorAccessorW const &input_grad,
+                  GenericTensorAccessorR const &output,
+                  GenericTensorAccessorR const &output_grad) {
+    checkCUDNN(miopenSetStream(handle.dnn, stream));
+
+    Op op_type = get_op_type(attrs);
+    if (use_cudnn(op_type)) {
+      float alpha = 1.0f, beta = 0.0f;
+      checkCUDNN(miopenActivationBackward(handle.dnn,
+                                          m.actiDesc,
+                                          &alpha,
+                                          m.outputTensor,
+                                          output.get<T>(),
+                                          m.outputTensor,
+                                          output_grad.get<T>(),
+                                          m.inputTensor,
+                                          input.get<T>(),
+                                          &beta,
+                                          m.inputTensor,
+                                          input_grad.get<T>()));
+    } else {
+      size_t num_elements = input.shape.num_elements();
+      hipLaunchKernelGGL(HIP_KERNEL_NAME(elewise_unary_backward_kernel<T>),
+                         GET_BLOCKS(num_elements),
+                         CUDA_NUM_THREADS,
+                         0,
+                         stream,
+                         num_elements,
+                         get_scalar<T>(attrs),
+                         op_type,
+                         output.get<T>(),
+                         output_grad.get<T>(),
+                         input.get<T>(),
+                         input_grad.get<T>());
+    }
+  }
+};
+
+void forward_kernel(ffStream_t stream,
+                    ElementUnaryPerDeviceState const &device_state,
+                    ElementUnaryUnifiedAttrs const &attrs,
+                    PerDeviceFFHandle const &handle,
+                    GenericTensorAccessorR const &input,
+                    GenericTensorAccessorW const &output) {
+  DataTypeDispatch1<ForwardKernel>{}(
+      input.data_type, stream, device_state, attrs, handle, input, output);
+}
+
+void backward_kernel(ffStream_t stream,
+                     ElementUnaryPerDeviceState const &device_state,
+                     ElementUnaryUnifiedAttrs const &attrs,
+                     PerDeviceFFHandle const &handle,
+                     GenericTensorAccessorR const &input,
+                     GenericTensorAccessorW const &input_grad,
+                     GenericTensorAccessorR const &output,
+                     GenericTensorAccessorR const &output_grad) {
+  DataTypeDispatch1<BackwardKernel>{}(input.data_type,
+                                      stream,
+                                      device_state,
+                                      attrs,
+                                      handle,
+                                      input,
+                                      input_grad,
+                                      output,
+                                      output_grad);
+}
+
 } // namespace ElementUnary
 } // namespace Kernels
 } // namespace FlexFlow
diff --git a/lib/kernels/src/hip/embedding_kernels.cpp b/lib/kernels/src/hip/embedding_kernels.cpp
index 17edfea5c1..7ca3149f2f 100644
--- a/lib/kernels/src/hip/embedding_kernels.cpp
+++ b/lib/kernels/src/hip/embedding_kernels.cpp
@@ -14,145 +14,16 @@
  */
 
 #include "kernels/embedding_kernels.h"
+#include "device.h"
 #include "kernels/datatype_dispatch.h"
-#include "kernels/hip_helper.h"
 #include <hip/hip_runtime.h>
 
 namespace FlexFlow {
 namespace Kernels {
 namespace Embedding {
 
-template <DataType TI, DataType TD>
-struct ForwardKernel {
-  void operator()(hipStream_t stream,
-                  AggrMode aggr,
-                  GenericTensorAccessorR const &input,
-                  GenericTensorAccessorW const &output,
-                  GenericTensorAccessorR const &weight,
-                  int in_dim,
-                  int out_dim,
-                  int batch_size) {
-    assert(input.data_type == DT_INT32 || input.data_type == DT_INT64);
-    assert(weight.data_type == DT_HALF || weight.data_type == DT_FLOAT ||
-           weight.data_type == DT_DOUBLE);
-
-    if (aggr == AGGR_MODE_NONE) {
-      hipLaunchKernelGGL(HIP_KERNEL_NAME(embed_forward_no_aggr),
-                         GET_BLOCKS(output.shape.get_volume()),
-                         CUDA_NUM_THREADS,
-                         0,
-                         stream,
-                         input.get<TI>(),
-                         output.get<TD>(),
-                         weight.get<TD>(),
-                         out_dim,
-                         batch_size);
-    } else {
-      hipLaunchKernelGGL(HIP_KERNEL_NAME(embed_forward_with_aggr),
-                         GET_BLOCKS(output.shape.get_volume()),
-                         CUDA_NUM_THREADS,
-                         0,
-                         stream,
-                         input.get<TI>(),
-                         output.get<TD>(),
-                         weight.get<TD>(),
-                         out_dim,
-                         in_dim,
-                         batch_size,
-                         aggr);
-    }
-  }
-}
-
-template <DataType TI, DataType TD>
-struct BackwardKernel {
-  void operator()(hipStream_t stream,
-                  AggrMode aggr,
-                  GenericTensorAccessorR const &input,
-                  GenericTensorAccessorR const &output,
-                  GenericTensorAccessorW const &weight_grad,
-                  int in_dim,
-                  int out_dim,
-                  int batch_size) {
-    assert(input.data_type == DT_INT32 || input.data_type == DT_INT64);
-    assert(output.data_type == DT_HALF || output.data_type == DT_FLOAT,
-           || output.data_type == DT_DOUBLE);
-    if (aggr == AGGR_MODE_NONE) {
-      hipLaunchKernelGGL(HIP_KERNEL_NAME(embed_backward_no_aggr),
-                         GET_BLOCKS(output.shape.get_volume()),
-                         CUDA_NUM_THREADS,
-                         0,
- stream, - input.get(), - output.get(), - weight_grad.get(), - out_dim, - in_dim, - batch_size, - aggr); - } - } -} - -void forward_kernel(hipStream_t stream, - GenericTensorAccessorR const &input, - GenericTensorAccessorW const &output, - GenericTensorAccessorR const &weight, - DataType input_data_type, - DataType output_data_type, - AggrMode aggr, - int in_dim, - int out_dim, - int batch_size) { - DataTypeDispatch2{}(input_data_type, - output_data_type, - stream, - aggr, - input, - output, - weight, - in_dim, - out_dim, - batch_size); -} - -void backward_kernel(hipStream_t stream, - GenericTensorAccessorR const &input, - GenericTensorAccessorR const &output, - GenericTensorAccessorW const &weight_grad, - DataType input_data_type, - DataType output_data_type, - AggrMode aggr, - int in_dim, - int out_dim, - int batch_size) { - DataTypeDispatch2{}(input_data_type, - output_data_type, - stream, - aggr, - input, - output, - weight, - in_dim, - out_dim, - batch_size); -} - void rand_generate_int64_wrapper(int64_t *ptr, size_t size, int64_t p) { hipStream_t stream; - - // Randomly initialize the intput tensor to avoid out of index range issues hipLaunchKernelGGL(HIP_KERNEL_NAME(rand_generate_int), GET_BLOCKS(size), CUDA_NUM_THREADS, @@ -165,8 +36,6 @@ void rand_generate_int64_wrapper(int64_t *ptr, size_t size, int64_t p) { void rand_generate_int32_wrapper(int32_t *ptr, size_t size, int32_t p) { hipStream_t stream; - - // Randomly initialize the intput tensor to avoid out of index range issues hipLaunchKernelGGL(HIP_KERNEL_NAME(rand_generate_int), GET_BLOCKS(size), CUDA_NUM_THREADS, @@ -179,48 +48,131 @@ void rand_generate_int32_wrapper(int32_t *ptr, size_t size, int32_t p) { template __global__ void embed_forward_no_aggr( - TI const *input, TD *output, TD const *embed, int out_dim, int batch_size) { + TI const *input, TD *output, TD const *embed, int out_dim, int batch_size); +template +__global__ void embed_forward_with_aggr(TI const *input, + TD *output, + TD const *embed, + int out_dim, + int in_dim, + int batch_size, + AggregateOp aggr); +template +__global__ void embed_backward_no_aggr( + TI const *input, TD const *output, TD *embed, int out_dim, int batch_size); +template +__global__ void embed_backward_with_aggr(TI const *input, + TD const *output, + TD *embed, + int out_dim, + int in_dim, + int batch_size, + AggregateOp aggr); + +template +__global__ void embed_forward_no_aggr(int32_t const *input, + TD *output, + TD const *embed, + int out_dim, + int batch_size) { CUDA_KERNEL_LOOP(i, batch_size * out_dim) { output[i] = 0; int idx = i / out_dim; int off = i % out_dim; - TI wordIdx = input[idx]; + int32_t wordIdx = input[idx]; output[i] = embed[wordIdx * out_dim + off]; } } -template -__global__ void embed_forward_with_aggr(TI const *input, +template +__global__ void embed_forward_no_aggr(int64_t const *input, + TD *output, + TD const *embed, + int out_dim, + int batch_size) { + CUDA_KERNEL_LOOP(i, batch_size * out_dim) { + output[i] = 0; + int idx = i / out_dim; + int off = i % out_dim; + int64_t wordIdx = input[idx]; + output[i] = embed[wordIdx * out_dim + off]; + } +} + +template +__global__ void embed_forward_with_aggr(int32_t const *input, TD *output, TD const *embed, int out_dim, int in_dim, int batch_size, - AggrMode aggr) { + AggregateOp aggr) { TD scale = 1.0f / in_dim; CUDA_KERNEL_LOOP(i, batch_size * out_dim) { output[i] = 0; int idx = i / out_dim; int off = i % out_dim; for (int j = 0; j < in_dim; j++) { - TI wordIdx = input[idx * in_dim + j]; + int32_t wordIdx = input[idx * 
in_dim + j]; output[i] = output[i] + embed[wordIdx * out_dim + off]; - if (aggr == AGGR_MODE_SUM) { + if (aggr == AggregateOp::SUM) { } else { - assert(aggr == AGGR_MODE_AVG); + assert(aggr == AggregateOp::AVG); output[i] = output[i] * scale; } } } } -template -__global__ void embed_backward_no_aggr( - TI const *input, TD const *output, TD *embed, int out_dim, int batch_size) { +template +__global__ void embed_forward_with_aggr(int64_t const *input, + TD *output, + TD const *embed, + int out_dim, + int in_dim, + int batch_size, + AggregateOp aggr) { + TD scale = 1.0f / in_dim; + CUDA_KERNEL_LOOP(i, batch_size * out_dim) { + output[i] = 0; + int idx = i / out_dim; + int off = i % out_dim; + for (int j = 0; j < in_dim; j++) { + int64_t wordIdx = input[idx * in_dim + j]; + output[i] = output[i] + embed[wordIdx * out_dim + off]; + if (aggr == AggregateOp::SUM) { + } else { + assert(aggr == AggregateOp::AVG); + output[i] = output[i] * scale; + } + } + } +} + +template +__global__ void embed_backward_no_aggr(int32_t const *input, + TD const *output, + TD *embed, + int out_dim, + int batch_size) { CUDA_KERNEL_LOOP(i, batch_size * out_dim) { int idx = i / out_dim; int off = i % out_dim; - TI wordIdx = input[idx]; + int32_t wordIdx = input[idx]; + atomicAdd(embed + wordIdx * out_dim + off, output[i]); + } +} + +template +__global__ void embed_backward_no_aggr(int64_t const *input, + TD const *output, + TD *embed, + int out_dim, + int batch_size) { + CUDA_KERNEL_LOOP(i, batch_size * out_dim) { + int idx = i / out_dim; + int off = i % out_dim; + int64_t wordIdx = input[idx]; atomicAdd(embed + wordIdx * out_dim + off, output[i]); } } @@ -228,15 +180,15 @@ __global__ void embed_backward_no_aggr( // Specialization for half type template <> -__global__ void embed_backward_no_aggr(int const *input, - half const *output, - half *embed, - int out_dim, - int batch_size) { +__global__ void embed_backward_no_aggr(int32_t const *input, + half const *output, + half *embed, + int out_dim, + int batch_size) { CUDA_KERNEL_LOOP(i, batch_size * out_dim) { int idx = i / out_dim; int off = i % out_dim; - int wordIdx = input[idx]; + int32_t wordIdx = input[idx]; #if __CUDA_ARCH__ >= 700 atomicAdd(embed + wordIdx * out_dim + off, output[i]); #else @@ -269,27 +221,53 @@ __global__ void embed_backward_no_aggr(int64_t const *input, } } -template -__global__ void embed_backward_with_aggr(TI const *input, +template +__global__ void embed_backward_with_aggr(int32_t const *input, TD const *output, TD *embed, int out_dim, int in_dim, int batch_size, - AggrMode aggr) { + AggregateOp aggr) { TD scale = 1.0f / in_dim; CUDA_KERNEL_LOOP(i, batch_size * out_dim) { int idx = i / out_dim; int off = i % out_dim; TD gradient; - if (aggr == AGGR_MODE_SUM) { + if (aggr == AggregateOp::SUM) { gradient = output[i]; } else { - assert(aggr == AGGR_MODE_AVG); + assert(aggr == AggregateOp::AVG); gradient = output[i] * scale; } for (int j = 0; j < in_dim; j++) { - TI wordIdx = input[idx * in_dim + j]; + int32_t wordIdx = input[idx * in_dim + j]; + atomicAdd(embed + wordIdx * out_dim + off, gradient); + } + } +} + +template +__global__ void embed_backward_with_aggr(int64_t const *input, + TD const *output, + TD *embed, + int out_dim, + int in_dim, + int batch_size, + AggregateOp aggr) { + TD scale = 1.0f / in_dim; + CUDA_KERNEL_LOOP(i, batch_size * out_dim) { + int idx = i / out_dim; + int off = i % out_dim; + TD gradient; + if (aggr == AggregateOp::SUM) { + gradient = output[i]; + } else { + assert(aggr == AggregateOp::AVG); + gradient = output[i] 
* scale; + } + for (int j = 0; j < in_dim; j++) { + int64_t wordIdx = input[idx * in_dim + j]; atomicAdd(embed + wordIdx * out_dim + off, gradient); } } @@ -298,26 +276,26 @@ __global__ void embed_backward_with_aggr(TI const *input, // Specialization for half type template <> -__global__ void embed_backward_with_aggr(int const *input, - half const *output, - half *embed, - int out_dim, - int in_dim, - int batch_size, - AggrMode aggr) { +__global__ void embed_backward_with_aggr(int32_t const *input, + half const *output, + half *embed, + int out_dim, + int in_dim, + int batch_size, + AggregateOp aggr) { half scale = 1.0f / in_dim; CUDA_KERNEL_LOOP(i, batch_size * out_dim) { int idx = i / out_dim; int off = i % out_dim; half gradient; - if (aggr == AGGR_MODE_SUM) { + if (aggr == AggregateOp::SUM) { gradient = output[i]; } else { - assert(aggr == AGGR_MODE_AVG); + assert(aggr == AggregateOp::AVG); gradient = output[i] * scale; } for (int j = 0; j < in_dim; j++) { - int wordIdx = input[idx * in_dim + j]; + int32_t wordIdx = input[idx * in_dim + j]; #if __CUDA_ARCH__ >= 700 atomicAdd(embed + wordIdx * out_dim + off, gradient); #else @@ -337,16 +315,16 @@ __global__ void embed_backward_with_aggr(int64_t const *input, int out_dim, int in_dim, int batch_size, - AggrMode aggr) { + AggregateOp aggr) { half scale = 1.0f / in_dim; CUDA_KERNEL_LOOP(i, batch_size * out_dim) { int idx = i / out_dim; int off = i % out_dim; half gradient; - if (aggr == AGGR_MODE_SUM) { + if (aggr == AggregateOp::SUM) { gradient = output[i]; } else { - assert(aggr == AGGR_MODE_AVG); + assert(aggr == AggregateOp::AVG); gradient = output[i] * scale; } for (int j = 0; j < in_dim; j++) { @@ -370,6 +348,138 @@ __global__ void rand_generate_int(TD *ptr, size_t size, TD p) { } } +template +struct ForwardKernel { + void operator()(hipStream_t stream, + AggrMode aggr, + GenericTensorAccessorR const &input, + GenericTensorAccessorW const &output, + GenericTensorAccessorR const &weight, + int in_dim, + int out_dim, + int batch_size) { + assert(input.data_type == DataType::INT32 || + input.data_type == DataType::INT64); + assert(weight.data_type == DataType::HALF || + weight.data_type == DataType::FLOAT || + weight.data_type == DataType::DOUBLE); + + if (aggr == AggregateOp::NONE) { + hipLaunchKernelGGL(HIP_KERNEL_NAME(embed_forward_no_aggr), + GET_BLOCKS(output.shape.get_volume()), + CUDA_NUM_THREADS, + 0, + stream, + input.get(), + output.get(), + weight.get(), + out_dim, + batch_size); + } else { + assert(aggr == AggregateOp::AVG || aggr == AggregateOp::SUM); + hipLaunchKernelGGL(HIP_KERNEL_NAME(embed_forward_with_aggr), + GET_BLOCKS(output.shape.get_volume()), + CUDA_NUM_THREADS, + 0, + stream, + input.get(), + output.get(), + weight.get(), + out_dim, + in_dim, + batch_size, + aggr); + } + } +} + +template +struct BackwardKernel { + void operator()(hipStream_t stream, + AggrMode aggr, + GenericTensorAccessorR const &input, + GenericTensorAccessorR const &output, + GenericTensorAccessorW const &weight_grad, + int in_dim, + int out_dim, + int batch_size) { + assert(input.data_type == DataType::INT32 || + input.data_type == DataType::INT64); + assert(output.data_type == DataType::HALF || + output.data_type == DataType::FLOAT || + output.data_type == DataType::DOUBLE); + if (aggr == AggregateOp::NONE) { + hipLaunchKernelGGL(HIP_KERNEL_NAME(embed_backward_no_aggr), + GET_BLOCKS(output.shape.get_volume()), + CUDA_NUM_THREADS, + 0, + stream, + input.get(), + output.get(), + weight_grad.get(), + out_dim, + batch_size); + } else { + 
hipLaunchKernelGGL(HIP_KERNEL_NAME(embed_backward_with_aggr),
+                         GET_BLOCKS(output.shape.get_volume()),
+                         CUDA_NUM_THREADS,
+                         0,
+                         stream,
+                         input.get<TI>(),
+                         output.get<TD>(),
+                         weight_grad.get<TD>(),
+                         out_dim,
+                         in_dim,
+                         batch_size,
+                         aggr);
+    }
+  }
+};
+
+void forward_kernel(ffStream_t stream,
+                    GenericTensorAccessorR const &input,
+                    GenericTensorAccessorW const &output,
+                    GenericTensorAccessorR const &weight,
+                    DataType input_data_type,
+                    DataType output_data_type,
+                    AggrMode aggr,
+                    int in_dim,
+                    int out_dim,
+                    int batch_size) {
+  DataTypeDispatch2<ForwardKernel>{}(input_data_type,
+                                     output_data_type,
+                                     stream,
+                                     aggr,
+                                     input,
+                                     output,
+                                     weight,
+                                     in_dim,
+                                     out_dim,
+                                     batch_size);
+}
+
+void backward_kernel(ffStream_t stream,
+                     GenericTensorAccessorR const &input,
+                     GenericTensorAccessorR const &output,
+                     GenericTensorAccessorW const &weight_grad,
+                     DataType input_data_type,
+                     DataType output_data_type,
+                     AggrMode aggr,
+                     int in_dim,
+                     int out_dim,
+                     int batch_size) {
+  DataTypeDispatch2<BackwardKernel>{}(input_data_type,
+                                      output_data_type,
+                                      stream,
+                                      aggr,
+                                      input,
+                                      output,
+                                      weight_grad,
+                                      in_dim,
+                                      out_dim,
+                                      batch_size);
+}
+
 } // namespace Embedding
 } // namespace Kernels
 } // namespace FlexFlow
diff --git a/lib/kernels/src/hip/partition_kernels.cpp b/lib/kernels/src/hip/partition_kernels.cpp
index 3761da5c84..4591247faa 100644
--- a/lib/kernels/src/hip/partition_kernels.cpp
+++ b/lib/kernels/src/hip/partition_kernels.cpp
@@ -14,21 +14,17 @@
  */
 
 #include "kernels/partition_kernels.h"
+#include "device.h"
 #include "kernels/datatype_dispatch.h"
-#include "kernels/hip_helper.h"
 #include <hip/hip_runtime.h>
 
 namespace FlexFlow {
-
-RepartitionPerDeviceState::RepartitionPerDeviceState(FFHandler handler)
-    : PerDeviceOpState(handler) {}
-
 namespace Kernels {
 namespace Repartition {
 
 template <DataType T> struct ForwardKernel {
   void operator()(hipStream_t stream,
-                  RepartitionPerDeviceState const *m,
+                  RepartitionPerDeviceState const &m,
                   GenericTensorAccessorR const &input,
                   GenericTensorAccessorW const &output) {
     checkCUDA(hipMemcpyAsync(output.get<T>(),
@@ -41,7 +37,7 @@ template <DataType T> struct ForwardKernel {
 
 template <DataType T> struct BackwardKernel {
   void operator()(hipStream_t stream,
-                  RepartitionPerDeviceState const *m,
+                  RepartitionPerDeviceState const &m,
                   GenericTensorAccessorR const &output_grad,
                   GenericTensorAccessorW const &input_grad) {
     hipLaunchKernelGGL(HIP_KERNEL_NAME(add_kernel),
@@ -55,19 +51,25 @@ template <DataType T> struct BackwardKernel {
   }
 }
 
+RepartitionPerDeviceState
+    init_kernel(PerDeviceFFHandle const &handle, DataType data_type) {
+  RepartitionPerDeviceState per_device_state = {handle, data_type};
+  return per_device_state;
+}
+
 void forward_kernel(hipStream_t stream,
-                    RepartitionPerDeviceState const *m,
+                    RepartitionPerDeviceState const &m,
                     GenericTensorAccessorR const &input,
                     GenericTensorAccessorW const &output) {
-  DataTypeDispatch1<ForwardKernel>{}(m->data_type, stream, m, input, output)
+  DataTypeDispatch1<ForwardKernel>{}(m.data_type, stream, m, input, output);
 }
 
 void backward_kernel(hipStream_t stream,
-                     RepartitionPerDeviceState const *m,
+                     RepartitionPerDeviceState const &m,
                      GenericTensorAccessorR const &output_grad,
                      GenericTensorAccessorW const &input_grad) {
   DataTypeDispatch1<BackwardKernel>{}(
-      m->data_type, stream, m, input_grad, output_grad)
+      m.data_type, stream, m, output_grad, input_grad);
 }
 
 } // namespace Repartition
diff --git a/lib/kernels/src/hip/pool_2d_kernels.cpp b/lib/kernels/src/hip/pool_2d_kernels.cpp
index 0bb44c3e1a..ed942c105c 100644
--- a/lib/kernels/src/hip/pool_2d_kernels.cpp
+++ b/lib/kernels/src/hip/pool_2d_kernels.cpp
@@ -14,116 +14,122 @@
  */
 
 #include "kernels/pool_2d_kernels.h"
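// Editor's aside: the DataTypeDispatch1 / DataTypeDispatch2 calls in the
// embedding and partition kernels above map runtime DataType values onto the
// template parameters of the ForwardKernel / BackwardKernel functors. The
// real helper lives in kernels/datatype_dispatch.h; the following is only a
// simplified, hypothetical sketch of the one-parameter case (two DataType
// cases instead of all of them), shown to illustrate the idiom:

#include <stdexcept>
#include <utility>

enum class DataType { FLOAT, DOUBLE };

template <template <DataType> class F, typename... Args>
void dispatch1(DataType dt, Args &&...args) {
  // Select the functor instantiation that matches the runtime tag.
  switch (dt) {
    case DataType::FLOAT:
      F<DataType::FLOAT>{}(std::forward<Args>(args)...);
      return;
    case DataType::DOUBLE:
      F<DataType::DOUBLE>{}(std::forward<Args>(args)...);
      return;
  }
  throw std::runtime_error("unhandled DataType");
}

// DataTypeDispatch2 applies the same selection twice, once per template slot,
// which is what lets the non-templated forward_kernel / backward_kernel
// wrappers above service every (index type, data type) combination.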
-#include "kernels/hip_helper.h" +#include "device.h" +#include namespace FlexFlow { -Pool2DPerDeviceState::Pool2DPerDeviceState(FFHandler handler) - : PerDeviceOpState(handler) { +namespace Kernels { +namespace Pool2D { + +Pool2DPerDeviceState init_kernel(PerDeviceFFHandle handle, + optional activation, + int input_w, + int input_h, + int input_c, + int input_n, + int output_w, + int output_h, + int output_c, + int output_n, + int pad_h, + int pad_w, + int kernel_h, + int kernel_w, + int stride_h, + int stride_w, + PoolOp pool_type) { + ffTensorDescriptor_t inputTensor; + ffTensorDescriptor_t outputTensor; + ffPoolingDescriptor_t poolDesc; + ffActivationDescriptor_t actiDesc; + checkCUDNN(miopenCreateTensorDescriptor(&inputTensor)); checkCUDNN(miopenCreateTensorDescriptor(&outputTensor)); checkCUDNN(miopenCreatePoolingDescriptor(&poolDesc)); -} + checkCUDNN(miopenCreateActivationDescriptor(&actiDesc)); -namespace Kernels { -namespace Pool2D { - -void init_kernel(Pool2DPerDeviceState *m, - int input_w, - int input_h, - int input_c, - int input_n, - int output_w, - int output_h, - int output_c, - int output_n, - int pad_h, - int pad_w, - int kernel_h, - int kernel_w, - int stride_h, - int stride_w, - PoolType pool_type) { checkCUDNN(miopenSet4dTensorDescriptor( - m->inputTensor, miopenFloat, input_n, input_c, input_h, input_w)); - + inputTensor, miopenFloat, input_n, input_c, input_h, input_w)); miopenPoolingMode_t mode; - if (pool_type == POOL_MAX) { + if (pool_type == PoolOp::MAX) { mode = miopenPoolingMax; } else { - assert(pool_type == POOL_AVG); + assert(pool_type == PoolOp::AVG); mode = miopenPoolingAverage; } - checkCUDNN(miopenSet2dPoolingDescriptor( - m->poolDesc, mode, kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w)); + + checkCUDNN(miopenSetPooling2dDescriptor( + poolDesc, mode, kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w)); + int n, c, h, w; - checkCUDNN(miopenGetPoolingForwardOutputDim( - m->poolDesc, m->inputTensor, &n, &c, &h, &w)); + checkCUDNN(miopenGetPooling2dForwardOutputDim( + poolDesc, inputTensor, &n, &c, &h, &w)); assert(n == output_n); assert(c == output_c); assert(h == output_h); assert(w == output_w); checkCUDNN( - miopenSet4dTensorDescriptor(m->outputTensor, miopenFloat, n, c, h, w)); + miopenSet4dTensorDescriptor(outputTensor, miopenFloat, n, c, h, w)); + bool relu = false; + if (activation == Activation::RELU) { + relu = true; + } + Pool2DPerDeviceState state = { + handle, + inputTensor, + outputTensor, + actiDesc, + poolDesc, + relu, + }; + return state; } void forward_kernel(hipStream_t stream, - Pool2DPerDeviceState const *m, + Pool2DPerDeviceState const &m, void const *input_ptr, void *output_ptr) { - checkCUDNN(miopenSetStream(m->handle.dnn, stream)); + checkCUDNN(miopenSetStream(m.handle.dnn, stream)); float alpha = 1.0f, beta = 0.0f; - checkCUDNN(miopenPoolingForward(m->handle.dnn, - m->poolDesc, + checkCUDNN(miopenPoolingForward(m.handle.dnn, + m.poolDesc, &alpha, - m->inputTensor, + m.inputTensor, input_ptr, &beta, - m->outputTensor, + m.outputTensor, output_ptr, true, - m->handle.workSpace, - m->handle.workSpaceSize)); - if (m->profiling) { - hipEventRecord(t_end, stream); - checkCUDA(hipEventSynchronize(t_end)); - // print_tensor<4, float>(acc_input.ptr, acc_input.rect, - // "[Pool2D:forward:input]"); print_tensor<4, float>(acc_output.ptr, - // acc_output.rect, "[Pool2D:forward:output]"); - float elapsed = 0; - checkCUDA(hipEventElapsedTime(&elapsed, t_start, t_end)); - hipEventDestroy(t_start); - hipEventDestroy(t_end); - printf("%s 
[Pool2D] forward time = %.2fms\n", m->op_name, elapsed); - } + m.handle.workSpace, + m.handle.workSpaceSize)); } void backward_kernel(hipStream_t stream, - Pool2DPerDeviceState const *m, + Pool2DPerDeviceState const &m, void const *input_ptr, void *input_grad_ptr, void const *output_ptr, void const *output_grad_ptr) { - checkCUDNN(miopenSetStream(m->handle.dnn, stream)); + checkCUDNN(miopenSetStream(m.handle.dnn, stream)); float alpha = 1.0f; - float beta = 0.0f; - checkCUDNN(miopenPoolingBackward(m->handle.dnn, - m->poolDesc, + checkCUDNN(miopenPoolingBackward(m.handle.dnn, + m.poolDesc, &alpha, - m->outputTensor, + m.outputTensor, output_ptr, - m->outputTensor, + m.outputTensor, output_grad_ptr, - m->inputTensor, + m.inputTensor, input_ptr, &beta, - m->inputTensor, + m.inputTensor, input_grad_ptr, - m->handle.workSpace)); + m.handle.workSpace)); } } // namespace Pool2D diff --git a/lib/kernels/src/hip/softmax_kernels.cpp b/lib/kernels/src/hip/softmax_kernels.cpp index e7bec53962..3a8f2813b7 100644 --- a/lib/kernels/src/hip/softmax_kernels.cpp +++ b/lib/kernels/src/hip/softmax_kernels.cpp @@ -14,40 +14,36 @@ */ #include "kernels/softmax_kernels.h" -#include "kernels/hip_helper.h" +#include "device.h" #include namespace FlexFlow { -// declare Legion names -using Legion::Domain; - -SoftmaxPerDeviceState::SoftmaxPerDeviceState(FFHandler handler, - Softmax const *softmax, - Domain const &input_domain) - : PerDeviceOpState(handler) { - checkCUDNN(miopenCreateTensorDescriptor(&inputTensor)); - checkCUDNN(cudnnSetTensorDescriptorFromDomain(inputTensor, input_domain)); - dim = softmax->dim; - profiling = softmax->profiling; - std::strcpy(op_name, softmax->name); -} namespace Kernels { namespace Softmax { +SoftmaxPerDeviceState init_kernel(PerDeviceFFHandle const &handle, int dim) { + ffTensorDescriptor_t inputTensor; + + checkCUDNN(miopenCreateTensorDescriptor(&inputTensor)); + + SoftmaxPerDeviceState per_device_state = {handle, inputTensor, dim}; + return per_device_state; +} + void forward_kernel(hipStream_t stream, - SoftmaxPerDeviceState const *m, + SoftmaxPerDeviceState const &m, float const *input_ptr, float *output_ptr) { - checkCUDNN(miopenSetStream(m->handle.dnn, stream)); + checkCUDNN(miopenSetStream(m.handle.dnn, stream)); float alpha = 1.0f, beta = 0.0f; - checkCUDNN(miopenSoftmaxForward_V2(m->handle.dnn, + checkCUDNN(miopenSoftmaxForward_V2(m.handle.dnn, &alpha, - m->inputTensor, + m.inputTensor, input_ptr, &beta, - m->inputTensor, + m.inputTensor, output_ptr, MIOPEN_SOFTMAX_ACCURATE, MIOPEN_SOFTMAX_MODE_CHANNEL)); diff --git a/lib/kernels/src/hip/split_kernels.cpp b/lib/kernels/src/hip/split_kernels.cpp index 439e715c88..5599ae6d6f 100644 --- a/lib/kernels/src/hip/split_kernels.cpp +++ b/lib/kernels/src/hip/split_kernels.cpp @@ -14,12 +14,10 @@ */ #include "kernels/split_kernels.h" -#include "kernels/hip_helper.h" +#include "device.h" #include namespace FlexFlow { -// declare Legion names -using Legion::coord_t; namespace Kernels { namespace Split { diff --git a/lib/kernels/src/hip/topk_kernels.cpp b/lib/kernels/src/hip/topk_kernels.cpp index 4c9fa4f037..f085c5831f 100644 --- a/lib/kernels/src/hip/topk_kernels.cpp +++ b/lib/kernels/src/hip/topk_kernels.cpp @@ -14,15 +14,10 @@ */ #include "kernels/topk_kernels.h" -#include "kernels/hip_helper.h" +#include "device.h" #include namespace FlexFlow { -// declare Legion names -using Legion::coord_t; - -TopKPerDeviceState::TopKPerDeviceState(FFHandler handler) - : PerDeviceOpState(handler) {} namespace Kernels { namespace TopK { @@ -36,6 
+31,11 @@ struct Entry { T value; }; +TopKPerDeviceState init_kernel(bool sorted) { + TopKPerDeviceState per_device_state = {sorted}; + return per_device_state; +} + template struct LinearData { typedef Entry Entry; @@ -371,7 +371,7 @@ __global__ void topk_forward_kernel(T const *__restrict__ input, } void forward_kernel(hipStream_t stream, - TopKPerDeviceState const *m, + TopKPerDeviceState const &m, float const *input_ptr, float *output_ptr, int *indices_ptr, @@ -428,7 +428,7 @@ __global__ void topk_backward_kernel(T const *__restrict__ value_grad_ptr, } void backward_kernel(hipStream_t stream, - TopKPerDeviceState const *m, + TopKPerDeviceState const &m, float const *value_grad_ptr, int const *indices_ptr, float *in_grad_ptr, diff --git a/lib/kernels/src/hip/transpose_kernels.cpp b/lib/kernels/src/hip/transpose_kernels.cpp index de64c74719..ef9dd58c63 100644 --- a/lib/kernels/src/hip/transpose_kernels.cpp +++ b/lib/kernels/src/hip/transpose_kernels.cpp @@ -14,13 +14,12 @@ */ #include "kernels/transpose_kernels.h" -#include "kernels/hip_helper.h" +#include "device.h" +#include "kernels/accessor.h" +#include "utils/exception.h" #include namespace FlexFlow { -// declare Legion names -using Legion::coord_t; -using Legion::Domain; struct TransposeStrides { int num_dim; @@ -31,81 +30,103 @@ struct TransposeStrides { namespace Kernels { namespace Transpose { +TransposePerDeviceState init_kernel(int num_dim, + std::vector const &perm) { + int const length = perm.size(); + + std::vector perm_vector; + assert(length <= MAX_TENSOR_DIM); + for (int i = 0; i < length; ++i) { + perm_vector.push_back(perm[i].value()); + } + + return {num_dim, perm_vector}; +} + +__global__ void transpose_simple_kernel(std::size_t volume, + float const *in_ptr, + float *out_ptr, + const TransposeStrides info, + float const beta) { + CUDA_KERNEL_LOOP(o_idx, volume) { + coord_t i_idx = 0; + coord_t t = o_idx; + for (int i = info.num_dim - 1; i >= 0; i--) { + coord_t ratio = t / info.out_strides[i]; + t -= ratio * info.out_strides[i]; + i_idx += ratio * info.in_strides[info.perm[i]]; + } + out_ptr[o_idx] += out_ptr[o_idx] * beta + in_ptr[i_idx]; + } +} + void forward_kernel(hipStream_t stream, - TransposePerDeviceState const *m, - float const *input_ptr, - float *output_ptr, - Domain in_domain, - Domain out_domain) { + TransposePerDeviceState const &m, + GenericTensorAccessorW const &in_grad, + GenericTensorAccessorR const &out_grad) { TransposeStrides info; - info.num_dim = out_domain.get_dim(); - assert(info.num_dim == m->num_dim); + info.num_dim = in_grad.shape.num_dims(); + assert(info.num_dim == m.num_dim); for (int i = 0; i < info.num_dim; i++) { - int in_dim_size = (in_domain.hi()[i] - in_domain.lo()[i] + 1); - int out_dim_size = (out_domain.hi()[i] - out_domain.lo()[i] + 1); - info.in_strides[i] = (i == 0) ? 1 : info.in_strides[i - 1] * in_dim_size; - info.out_strides[i] = (i == 0) ? 
1 : info.out_strides[i - 1] * out_dim_size; - info.perm[i] = m->perm[i]; + if (i == 0) { + info.in_strides[i] = 1; + info.out_strides[i] = 1; + } else { + int in_dim_size = input.shape[legion_dim_t(i)] + 1; + int out_dim_size = output.shape[legion_dim_t(i)] + 1; + info.in_strides[i] = info.in_strides[i - 1] * in_dim_size; + info.out_strides[i] = info.out_strides[i - 1] * out_dim_size; + } + info.perm[i] = m.perm[i]; } + hipLaunchKernelGGL(transpose_simple_kernel, - GET_BLOCKS(out_domain.get_volume()), + GET_BLOCKS(output.shape.get_volume()), CUDA_NUM_THREADS, 0, stream, - out_domain.get_volume(), - input_ptr, - output_ptr, + output.shape.get_volume(), + input.get_float_ptr(), + output.get_float_ptr(), info, 0.0f /*beta*/); } void backward_kernel(hipStream_t stream, - TransposePerDeviceState const *m, + TransposePerDeviceState const &m, float *input_grad_ptr, float const *output_grad_ptr, Domain in_grad_domain, Domain out_grad_domain) { TransposeStrides info; - info.num_dim = in_grad_domain.get_dim(); - assert(info.num_dim == m->num_dim); + info.num_dim = in_grad.shape.num_dims(); + assert(info.num_dim == m.num_dim); for (int i = 0; i < info.num_dim; i++) { - int in_dim_size = (out_grad_domain.hi()[i] - out_grad_domain.lo()[i] + 1); - int out_dim_size = (in_grad_domain.hi()[i] - in_grad_domain.lo()[i] + 1); - info.in_strides[i] = (i == 0) ? 1 : info.in_strides[i - 1] * in_dim_size; - info.out_strides[i] = (i == 0) ? 1 : info.out_strides[i - 1] * out_dim_size; - info.perm[m->perm[i]] = i; + if (i == 0) { + info.in_strides[i] = 1; + info.out_strides[i] = 1; + } else { + int in_dim_size = out_grad.shape[legion_dim_t(i)] + 1; + int out_dim_size = in_grad.shape[legion_dim_t(i)] + 1; + info.in_strides[i] = info.in_strides[i - 1] * in_dim_size; + info.out_strides[i] = info.out_strides[i - 1] * out_dim_size; + } + info.perm[m.perm[i]] = i; } hipLaunchKernelGGL(transpose_simple_kernel, - GET_BLOCKS(in_grad_domain.get_volume()), + GET_BLOCKS(in_grad.shape.get_volume()), CUDA_NUM_THREADS, 0, stream, - in_grad_domain.get_volume(), - output_grad_ptr, - input_grad_ptr, + in_grad.shape.get_volume(), + out_grad.get_float_ptr(), + in_grad.get_float_ptr(), info, 1.0f /*beta*/); } -__global__ void transpose_simple_kernel(coord_t volume, - float const *in_ptr, - float *out_ptr, - const TransposeStrides info, - float const beta) { - CUDA_KERNEL_LOOP(o_idx, volume) { - coord_t i_idx = 0; - coord_t t = o_idx; - for (int i = info.num_dim - 1; i >= 0; i--) { - coord_t ratio = t / info.out_strides[i]; - t -= ratio * info.out_strides[i]; - i_idx += ratio * info.in_strides[info.perm[i]]; - } - out_ptr[o_idx] += out_ptr[o_idx] * beta + in_ptr[i_idx]; - } -} - } // namespace Transpose } // namespace Kernels } // namespace FlexFlow diff --git a/lib/kernels/src/kernels/legion_dim_t.dtg.cc b/lib/kernels/src/kernels/legion_dim_t.dtg.cc new file mode 100644 index 0000000000..99c1a3b3a2 --- /dev/null +++ b/lib/kernels/src/kernels/legion_dim_t.dtg.cc @@ -0,0 +1,69 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! 
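// Editor's aside on the transpose kernels above, before the generated
// legion_dim_t file continues: the stride bookkeeping can be checked on a
// small case. Host-only sketch with hypothetical sizes (input 2x3, output
// 3x2, perm = {1, 0}); with beta = 0 the kernel's update reduces to a copy.

#include <cassert>
#include <cstddef>

int main() {
  int const num_dim = 2;
  int in_strides[2] = {1, 2};   // input shape 2x3, dimension 0 is fastest
  int out_strides[2] = {1, 3};  // output shape 3x2
  int perm[2] = {1, 0};         // output dim i reads input dim perm[i]
  float in[6] = {0, 1, 2, 3, 4, 5};
  float out[6];
  for (std::size_t o_idx = 0; o_idx < 6; o_idx++) {
    // Decompose the flat output index digit-by-digit via out_strides, then
    // re-accumulate a flat input index through perm and in_strides.
    std::size_t i_idx = 0;
    std::size_t t = o_idx;
    for (int i = num_dim - 1; i >= 0; i--) {
      std::size_t ratio = t / out_strides[i];
      t -= ratio * out_strides[i];
      i_idx += ratio * in_strides[perm[i]];
    }
    out[o_idx] = in[i_idx];
  }
  assert(out[1] == 2); // out(1,0) == in(0,1)
  return 0;
}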
+// If you would like to modify this datatype, instead modify
+// lib/kernels/include/kernels/legion_dim_t.struct.toml
+/* proj-data
+{
+  "generated_from": "f67d6e50c53539a21d69e7162cf965f4"
+}
+*/
+
+#include "kernels/legion_dim_t.dtg.h"
+
+#include <sstream>
+
+namespace FlexFlow {
+legion_dim_t::legion_dim_t(int const &value) : value(value) {}
+bool legion_dim_t::operator==(legion_dim_t const &other) const {
+  return std::tie(this->value) == std::tie(other.value);
+}
+bool legion_dim_t::operator!=(legion_dim_t const &other) const {
+  return std::tie(this->value) != std::tie(other.value);
+}
+bool legion_dim_t::operator<(legion_dim_t const &other) const {
+  return std::tie(this->value) < std::tie(other.value);
+}
+bool legion_dim_t::operator>(legion_dim_t const &other) const {
+  return std::tie(this->value) > std::tie(other.value);
+}
+bool legion_dim_t::operator<=(legion_dim_t const &other) const {
+  return std::tie(this->value) <= std::tie(other.value);
+}
+bool legion_dim_t::operator>=(legion_dim_t const &other) const {
+  return std::tie(this->value) >= std::tie(other.value);
+}
+} // namespace FlexFlow
+
+namespace std {
+size_t hash<FlexFlow::legion_dim_t>::operator()(
+    FlexFlow::legion_dim_t const &x) const {
+  size_t result = 0;
+  result ^=
+      std::hash<int>{}(x.value) + 0x9e3779b9 + (result << 6) + (result >> 2);
+  return result;
+}
+} // namespace std
+
+namespace nlohmann {
+FlexFlow::legion_dim_t
+    adl_serializer<FlexFlow::legion_dim_t>::from_json(json const &j) {
+  return {j.at("value").template get<int>()};
+}
+void adl_serializer<FlexFlow::legion_dim_t>::to_json(
+    json &j, FlexFlow::legion_dim_t const &v) {
+  j["__type"] = "legion_dim_t";
+  j["value"] = v.value;
+}
+} // namespace nlohmann
+
+namespace FlexFlow {
+std::string format_as(legion_dim_t const &x) {
+  std::ostringstream oss;
+  oss << "<legion_dim_t value=" << x.value << ">";
+  return oss.str();
+}
+std::ostream &operator<<(std::ostream &s, legion_dim_t const &x) {
+  return s << fmt::to_string(x);
+}
+} // namespace FlexFlow
diff --git a/lib/local-execution/CMakeLists.txt b/lib/local-execution/CMakeLists.txt
index ee1d8fecdc..6b432fad75 100644
--- a/lib/local-execution/CMakeLists.txt
+++ b/lib/local-execution/CMakeLists.txt
@@ -12,4 +12,5 @@ ff_add_library(
   utils
   kernels
   pcg
+  spdlog
 )
\ No newline at end of file
diff --git a/lib/runtime/src/task_spec/arg_ref.h b/lib/local-execution/include/local-execution/arg_ref.h
similarity index 50%
rename from lib/runtime/src/task_spec/arg_ref.h
rename to lib/local-execution/include/local-execution/arg_ref.h
index 62f89f0b5c..50fe4e6f80 100644
--- a/lib/runtime/src/task_spec/arg_ref.h
+++ b/lib/local-execution/include/local-execution/arg_ref.h
@@ -1,9 +1,9 @@
-#ifndef _FLEXFLOW_RUNTIME_SRC_ARG_REF_H
-#define _FLEXFLOW_RUNTIME_SRC_ARG_REF_H
+#ifndef _FLEXFLOW_LOCAL_EXECUTION_ARG_REF_H
+#define _FLEXFLOW_LOCAL_EXECUTION_ARG_REF_H
 
 #include "kernels/ff_handle.h"
-#include "runtime/profiling.h"
-#include "runtime/task_spec/arg_type_runtime_tag.h"
+#include "local-execution/profiling.h"
+#include "local-execution/serialization.h"
 #include "utils/type_index.h"
 #include "utils/visitable.h"
 
@@ -21,37 +21,43 @@ struct ArgRefSpec {
 
   template <typename T>
   bool holds() const {
-    return this->type_tag.template matches<T>();
+    // return this->type_tag.template matches<T>();
+
+    return matches<T>(this->type_idx);
   }
 
   LABEL_TYPE const &get_ref_type() const {
     return this->ref_type;
   }
 
-  ArgTypeRuntimeTag get_type_tag() const {
-    return this->type_tag;
+  // TODO - how to extend this for legion runtime?
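// Editor's aside: holds<T>() above now compares std::type_index values where
// the old code consulted an ArgTypeRuntimeTag. A minimal sketch of the two
// helpers this relies on (assumed to match the intent of utils/type_index.h;
// simplified, not the actual implementation):

#include <typeindex>
#include <typeinfo>

template <typename T>
std::type_index get_type_index_for_type() {
  return std::type_index(typeid(T));
}

template <typename T>
bool matches(std::type_index idx) {
  // True exactly when idx was produced from the same static type T.
  return idx == get_type_index_for_type<T>();
}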
+ // ArgTypeRuntimeTag get_type_tag() const { + // return this->type_tag; + // } + std::type_index get_type_index() const { + return this->type_idx; } template static ArgRefSpec create(ArgRef const &r) { static_assert(is_serializable::value, "Type must be serializeable"); - return ArgRefSpec(ArgTypeRuntimeTag::create(), r.ref_type); + return ArgRefSpec(get_type_index_for_type(), r.ref_type); } template static ArgRefSpec create_device_specific(ArgRef const &r, size_t device_idx) { - return ArgRefSpec(ArgTypeRuntimeTag::create(), r.ref_type, device_idx); + return ArgRefSpec(get_type_index_for_type(), r.ref_type, device_idx); } private: - ArgRefSpec(ArgTypeRuntimeTag const &type_tag, LABEL_TYPE ref_type) - : type_tag(type_tag), ref_type(ref_type) {} + ArgRefSpec(std::type_index const &type_index, LABEL_TYPE ref_type) + : type_idx(type_index), ref_type(ref_type) {} - ArgTypeRuntimeTag type_tag; + std::type_index type_idx; LABEL_TYPE ref_type; - optional device_idx = nullopt; + std::optional device_idx = std::nullopt; }; } // namespace FlexFlow diff --git a/lib/local-execution/include/local-execution/concrete_arg.h b/lib/local-execution/include/local-execution/concrete_arg.h new file mode 100644 index 0000000000..2db5e45e9e --- /dev/null +++ b/lib/local-execution/include/local-execution/concrete_arg.h @@ -0,0 +1,55 @@ +#ifndef _FLEXFLOW_LOCAL_EXECUTION_CONCRETE_ARG_H +#define _FLEXFLOW_LOCAL_EXECUTION_CONCRETE_ARG_H + +#include "local-execution/serialization.h" +#include "utils/type_index.h" +#include + +namespace FlexFlow { + +struct ConcreteArgSpec { +public: + ConcreteArgSpec() = delete; + + template + T const &get() const { + assert(matches(this->type_idx)); + + return *(T const *)ptr.get(); + } + + // ArgTypeRuntimeTag get_type_tag() const { + // return this->type_tag; + // } + // size_t serialize(Legion::Serializer &) const; + + std::type_index get_type_index() const { + return this->type_idx; + } + + template + static ConcreteArgSpec create(T const &t) { + static_assert(is_serializable::value, "Type must be serializable"); + + std::type_index type_idx = get_type_index_for_type(); + std::shared_ptr ptr = + std::static_pointer_cast(std::make_shared(t)); + + return ConcreteArgSpec(type_idx, ptr); + // ArgTypeRuntimeTag::create()); + } + +private: + ConcreteArgSpec(std::type_index const &type_index, + std::shared_ptr ptr) + : type_idx(type_index), ptr(ptr) {} + // ArgTypeRuntimeTag const &); + + // ArgTypeRuntimeTag type_tag; + std::type_index type_idx; + std::shared_ptr ptr; +}; + +} // namespace FlexFlow + +#endif diff --git a/lib/runtime/include/runtime/config.h b/lib/local-execution/include/local-execution/config.h similarity index 89% rename from lib/runtime/include/runtime/config.h rename to lib/local-execution/include/local-execution/config.h index 34f45040d1..73653aebae 100644 --- a/lib/runtime/include/runtime/config.h +++ b/lib/local-execution/include/local-execution/config.h @@ -13,12 +13,11 @@ * limitations under the License. 
*/ -#ifndef _FLEXFLOW_CONFIG_H_ -#define _FLEXFLOW_CONFIG_H_ -#include "legion.h" +#ifndef _FLEXFLOW_LOCAL_EXECUTION_CONFIG_H_ +#define _FLEXFLOW_LOCAL_EXECUTION_CONFIG_H_ + #include "op-attrs/param_sync.h" #include "utils/fmt.h" -#include "utils/optional.h" #include "utils/visitable.h" #include @@ -47,6 +46,8 @@ struct FFInitInfo : public use_visitable_cmp { bool allowTensorOpMathConversion; }; +using legion_mapping_tag_id_t = unsigned long; + struct FFConfig : public use_visitable_cmp { public: enum PreservedIDs { @@ -64,7 +65,7 @@ struct FFConfig : public use_visitable_cmp { }; FFConfig() = default; - static Legion::MappingTagID get_hash_id(std::string const &pcname); + static legion_mapping_tag_id_t get_hash_id(std::string const &pcname); public: int epochs = 1; @@ -88,16 +89,17 @@ struct FFConfig : public use_visitable_cmp { bool enable_inplace_optimizations = false; // Control Tensor Op Math Conversion bool allow_tensor_op_math_conversion = false; - optional dataset_path = nullopt; - optional export_strategy_computation_graph_file = nullopt; + std::optional dataset_path = std::nullopt; + std::optional export_strategy_computation_graph_file = + std::nullopt; bool include_costs_dot_graph = false; - optional substitution_json_path = nullopt; + std::optional substitution_json_path = std::nullopt; int machine_model_version = 0; - optional machine_model_file = nullopt; + std::optional machine_model_file = std::nullopt; int simulator_segment_size = 16777216; // 16 MB int simulator_max_num_segments = 1; - optional search_num_nodes = nullopt; - optional search_num_workers = nullopt; + std::optional search_num_nodes = std::nullopt; + std::optional search_num_workers = std::nullopt; int base_optimize_threshold = 10; bool enable_control_replication = true; // The default python data loader type is 2 to enable control replication diff --git a/lib/runtime/src/cost_metrics.h b/lib/local-execution/include/local-execution/cost_metrics.h similarity index 95% rename from lib/runtime/src/cost_metrics.h rename to lib/local-execution/include/local-execution/cost_metrics.h index 77526ccd1a..edc0190daf 100644 --- a/lib/runtime/src/cost_metrics.h +++ b/lib/local-execution/include/local-execution/cost_metrics.h @@ -1,5 +1,5 @@ -#ifndef _FLEXFLOW_RUNTIME_SRC_COST_METRICS_H -#define _FLEXFLOW_RUNTIME_SRC_COST_METRICS_H +#ifndef _FLEXFLOW_LOCAL_EXECUTION_COST_METRICS_H +#define _FLEXFLOW_LOCAL_EXECUTION_COST_METRICS_H #include "utils/visitable.h" diff --git a/lib/runtime/src/task_spec/device_specific.h b/lib/local-execution/include/local-execution/device_specific.h similarity index 64% rename from lib/runtime/src/task_spec/device_specific.h rename to lib/local-execution/include/local-execution/device_specific.h index e29e4e9450..6136d16f2d 100644 --- a/lib/runtime/src/task_spec/device_specific.h +++ b/lib/local-execution/include/local-execution/device_specific.h @@ -1,7 +1,7 @@ -#ifndef _FLEXFLOW_RUNTIME_SRC_DEVICE_SPECIFIC_ARG_H -#define _FLEXFLOW_RUNTIME_SRC_DEVICE_SPECIFIC_ARG_H +#ifndef _FLEXFLOW_LOCAL_EXECUTION_DEVICE_SPECIFIC_H +#define _FLEXFLOW_LOCAL_EXECUTION_DEVICE_SPECIFIC_H -#include "serialization.h" +#include "local-execution/serialization.h" #include "utils/exception.h" namespace FlexFlow { @@ -10,10 +10,17 @@ template struct DeviceSpecific { DeviceSpecific() = delete; + DeviceSpecific(T ptr_type) { // accessor + size_t device_idx = 0; + DeviceSpecific device_specific = + DeviceSpecific::create(device_idx, ptr_type); + this->ptr = device_specific.ptr; + this->device_idx = 
device_specific.device_idx; + } template static DeviceSpecific create(size_t device_idx, Args &&...args) { - NOT_IMPLEMENTED(); + NOT_IMPLEMENTED(); // accessor } T const *get(size_t curr_device_idx) const { @@ -26,6 +33,8 @@ struct DeviceSpecific { return this->ptr; } + // TODO: can modify ptr + private: T *ptr; size_t device_idx; diff --git a/lib/runtime/src/legion_tensor_shape.h b/lib/local-execution/include/local-execution/legion_tensor_shape.h similarity index 92% rename from lib/runtime/src/legion_tensor_shape.h rename to lib/local-execution/include/local-execution/legion_tensor_shape.h index 1f5fab76a6..ff96ba9a15 100644 --- a/lib/runtime/src/legion_tensor_shape.h +++ b/lib/local-execution/include/local-execution/legion_tensor_shape.h @@ -28,8 +28,8 @@ struct LegionTensorShape : public use_visitable_cmp, DataType data_type; }; -ff_dim_t to_ff(legion_dim_t, int num_dims); -legion_dim_t to_legion(ff_dim_t, int num_dims); +ff_dim_t to_ff(legion_dim_t, size_t num_dims); +legion_dim_t to_legion(ff_dim_t, size_t num_dims); ff_dim_t to_ff(legion_dim_t, TensorShape const &); legion_dim_t to_legion(ff_dim_t, TensorShape const &); diff --git a/lib/local-execution/include/local_allocator.h b/lib/local-execution/include/local-execution/local_allocator.h similarity index 82% rename from lib/local-execution/include/local_allocator.h rename to lib/local-execution/include/local-execution/local_allocator.h index f4b253b281..b47220eb8c 100644 --- a/lib/local-execution/include/local_allocator.h +++ b/lib/local-execution/include/local-execution/local_allocator.h @@ -1,5 +1,5 @@ -#ifndef _FLEXFLOW_RUNTIME_SRC_LOCAL_ALLOCATOR_H -#define _FLEXFLOW_RUNTIME_SRC_LOCAL_ALLOCATOR_H +#ifndef _FLEXFLOW_LOCAL_EXECUTION_LOCAL_ALLOCATOR_H +#define _FLEXFLOW_LOCAL_EXECUTION_LOCAL_ALLOCATOR_H #include "kernels/allocation.h" #include diff --git a/lib/runtime/src/task_spec/op_arg_ref.h b/lib/local-execution/include/local-execution/op_arg_ref.h similarity index 54% rename from lib/runtime/src/task_spec/op_arg_ref.h rename to lib/local-execution/include/local-execution/op_arg_ref.h index 3e931d79a4..1650656b42 100644 --- a/lib/runtime/src/task_spec/op_arg_ref.h +++ b/lib/local-execution/include/local-execution/op_arg_ref.h @@ -1,8 +1,8 @@ -#ifndef _FLEXFLOW_RUNTIME_SRC_TASK_SPEC_OP_ARG_REF_H -#define _FLEXFLOW_RUNTIME_SRC_TASK_SPEC_OP_ARG_REF_H +#ifndef _FLEXFLOW_LOCAL_EXECUTION_OP_ARG_REF_H +#define _FLEXFLOW_LOCAL_EXECUTION_OP_ARG_REF_H -#include "arg_ref.h" -#include "device_specific.h" +#include "local-execution/arg_ref.h" +#include "local-execution/device_specific.h" #include "op-attrs/parallel_tensor_shape.h" namespace FlexFlow { @@ -15,13 +15,9 @@ using OpArgRef = ArgRef; using OpArgRefSpec = ArgRefSpec; template -OpArgRef> per_device_op_state() { - return {OpArgRefType::PER_DEVICE_OP_STATE}; -} +OpArgRef> per_device_op_state(); -OpArgRef input_parallel_tensor_shape(int idx) { - return {OpArgRefType::PARALLEL_TENSOR_SHAPE}; -} +OpArgRef input_parallel_tensor_shape(int idx); } // namespace FlexFlow diff --git a/lib/local-execution/include/local-execution/op_task_invocation.h b/lib/local-execution/include/local-execution/op_task_invocation.h new file mode 100644 index 0000000000..37ca5c239d --- /dev/null +++ b/lib/local-execution/include/local-execution/op_task_invocation.h @@ -0,0 +1,97 @@ +#ifndef _FLEXFLOW_LOCAL_EXECUTION_OP_TASK_INVOCATION_H +#define _FLEXFLOW_LOCAL_EXECUTION_OP_TASK_INVOCATION_H + +#include "kernels/accessor.h" +#include "local-execution/concrete_arg.h" +#include "local-execution/op_arg_ref.h" 
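// Editor's aside: op_arg_ref.h above now only declares per_device_op_state()
// and input_parallel_tensor_shape(); their bodies move to
// lib/local-execution/src/op_arg_ref.cc later in this patch. A hypothetical
// usage sketch of these refs together with the OpTaskBinding declared below
// (the slot ids, task id, and LinearPerDeviceState type are illustrative,
// not taken from this patch):
//
//   OpTaskBinding binding;
//   binding.bind(INPUT, input_tensor(0));
//   binding.bind(OUTPUT, output_tensor(0));
//   binding.bind_arg(PER_DEVICE_STATE,
//                    per_device_op_state<LinearPerDeviceState>());
//   binding.bind_arg(PROFILING, profiling_settings());
//   OpTaskInvocation invocation = {LINEAR_FWD_TASK_ID, binding};
//
// The refs are resolved only when a backend (e.g. the local executor)
// materializes a TaskArgumentAccessor for the invocation.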
+#include "local-execution/op_task_signature.h" +#include "local-execution/op_tensor_spec.h" +#include "local-execution/profiling.h" +#include "local-execution/runtime_arg_ref.h" +#include "local-execution/tasks.h" +#include "local-execution/variadic_tensor_ref.h" +#include "utils/bidict.h" +#include "utils/stack_map.h" +#include +#include +#include +#include + +namespace FlexFlow { + +enum class IsTrainable { YES, NO }; + +using OpArgSpec = + std::variant; + +struct OpTaskBinding { + OpTaskBinding() = default; + + void bind(slot_id, VariadicTensorRef const &) { + NOT_IMPLEMENTED(); + } + void bind(slot_id, OpTensorSpec const &); + void bind_grad(slot_id, OpTensorSpec const &); + + template + void bind_device_specific_arg(slot_id name, T const &t) { + NOT_IMPLEMENTED(); + } + + template + void bind_device_specific_arg(slot_id name, OpArgRef const &t) { + NOT_IMPLEMENTED(); + } + + template + void bind_arg(slot_id name, T const &t) { + this->insert_arg_spec(name, ConcreteArgSpec::create(t)); + } + + template + void bind_arg(slot_id name, RuntimeArgRef const &ref) { + this->insert_arg_spec(name, RuntimeArgRefSpec::create(ref)); + } + + template + void bind_arg(slot_id name, OpArgRef const &ref) { + this->insert_arg_spec(name, OpArgRefSpec::create(ref)); + } + + std::unordered_map, OpTensorSpec> const & + get_tensor_bindings() const; + std::unordered_map const &get_arg_bindings() const; + + void insert_arg_spec(slot_id name, OpArgSpec const &arg_spec) { + assert(!contains_key(this->arg_bindings, name)); + this->arg_bindings.insert({name, arg_spec}); + } + + std::unordered_map arg_bindings; + std::unordered_map, OpTensorSpec> tensor_bindings; +}; +FF_VISITABLE_STRUCT_NONSTANDARD_CONSTRUCTION(OpTaskBinding, + arg_bindings, + tensor_bindings); + +struct OpTaskInvocation { +public: + OpTaskInvocation() = delete; + OpTaskInvocation(task_id_t const &task_id, OpTaskBinding const &binding) + : task_id(task_id), binding(binding) {} + +public: + task_id_t task_id; + OpTaskBinding binding; +}; +FF_VISITABLE_STRUCT(OpTaskInvocation, task_id, binding); + +OpTaskSignature infer_bwd_signature(OpTaskSignature const &fwd); +OpTaskBinding infer_bwd_binding(OpTaskBinding const &fwd); + +bool is_invocation_valid(OpTaskSignature const &sig, + OpTaskInvocation const &inv); + +} // namespace FlexFlow + +#endif diff --git a/lib/runtime/src/task_spec/op_task_signature.h b/lib/local-execution/include/local-execution/op_task_signature.h similarity index 73% rename from lib/runtime/src/task_spec/op_task_signature.h rename to lib/local-execution/include/local-execution/op_task_signature.h index 656df39309..3bcb8397b7 100644 --- a/lib/runtime/src/task_spec/op_task_signature.h +++ b/lib/local-execution/include/local-execution/op_task_signature.h @@ -1,8 +1,11 @@ -#ifndef _FLEXFLOW_RUNTIME_SRC_OP_TASK_SIGNATURE_H -#define _FLEXFLOW_RUNTIME_SRC_OP_TASK_SIGNATURE_H - -#include "task_invocation.h" -#include "task_signature.h" +#ifndef _FLEXFLOW_LOCAL_EXECUTION_OP_TASK_SIGNATURE_H +#define _FLEXFLOW_LOCAL_EXECUTION_OP_TASK_SIGNATURE_H + +#include "local-execution/serialization.h" +#include "local-execution/slot_id.h" +#include "local-execution/slot_type.h" +#include "local-execution/tasks.h" +#include "utils/type_index.h" #include "utils/visitable.h" namespace FlexFlow { @@ -14,6 +17,7 @@ enum class TensorRole { }; enum class OpTaskType { INIT, FWD, BWD }; +enum class IsGrad { YES, NO }; enum class OpSlotOptions { OPTIONAL, @@ -25,7 +29,6 @@ enum class OpSlotOptions { struct OpTensorSlotSpec { public: OpTensorSlotSpec() = 
delete; - OpTensorSlotSpec(slot_id, SlotType, TensorRole); public: slot_id name; @@ -41,7 +44,9 @@ struct OpTaskSignature { OpTaskSignature() = delete; explicit OpTaskSignature(OpTaskType); - OpTaskType get_task_type() const; + OpTaskType get_task_type() const { + return this->type; + } void add_input_slot(slot_id, SlotType slot_type = SlotType::TENSOR); void add_optional_input_slot(slot_id, SlotType slot_type = SlotType::TENSOR); @@ -59,45 +64,35 @@ struct OpTaskSignature { void add_from_slot_spec(OpTensorSlotSpec const &spec); - /* void add_input_slot(slot_id, Legion::PrivilegeMode); */ - /* void add_input_slot(slot_id, SlotType, Legion::PrivilegeMode); */ - - bool operator==(OpTaskSignature const &) const; - bool operator!=(OpTaskSignature const &) const; - template void add_arg_slot(slot_id name) { static_assert(is_serializable::value, "Type must be serializable"); + this->task_arg_types.insert({name, get_type_index_for_type()}); } template - void add_return_value(); + void add_return_value() { + this->return_value = get_type_index_for_type(); + } // adds arg_slot without checking is_serializable, used for arguments that are // deviceSpecific template void add_unchecked_arg_slot(slot_id name) { - NOT_IMPLEMENTED(); + this->task_arg_types.insert({name, get_type_index_for_type()}); } - std::unordered_set get_tensor_slots(); + std::unordered_set get_tensor_slots() const; void set_arg_types(std::unordered_map const &); - std::unordered_map get_arg_types(); + std::unordered_map get_arg_types() const; -private: + OpTaskType type; + std::optional return_value; std::unordered_map task_arg_types; std::unordered_set op_tensor_slots; }; - -template -OpTaskSignature init_signature(); -template -OpTaskSignature fwd_signature(); -template -OpTaskSignature bwd_signature(); - -template -OpTaskSignature get_signature(); +FF_VISITABLE_STRUCT_NONSTANDARD_CONSTRUCTION( + OpTaskSignature, type, return_value, task_arg_types, op_tensor_slots); template void register_task(task_id_t, @@ -112,6 +107,15 @@ void register_task(task_id_t, F const &func, F const &cpu_func); +template +OpTaskSignature init_signature(); + +template +OpTaskSignature fwd_signature(); + +template +OpTaskSignature bwd_signature(); + } // namespace FlexFlow #endif diff --git a/lib/local-execution/include/local-execution/op_tensor_spec.h b/lib/local-execution/include/local-execution/op_tensor_spec.h new file mode 100644 index 0000000000..cc2cd75153 --- /dev/null +++ b/lib/local-execution/include/local-execution/op_tensor_spec.h @@ -0,0 +1,21 @@ +#ifndef _FLEXFLOW_LOCAL_EXECUTION_OP_TENSOR_SPEC_REF_H +#define _FLEXFLOW_LOCAL_EXECUTION_OP_TENSOR_SPEC_REF_H + +#include "local-execution/op_task_signature.h" + +namespace FlexFlow { + +struct OpTensorSpec { + TensorRole role; + OpSlotOptions slot_option; + req idx; +}; +FF_VISITABLE_STRUCT(OpTensorSpec, role, slot_option, idx); + +OpTensorSpec input_tensor(int); +OpTensorSpec output_tensor(int); +OpTensorSpec weight_tensor(int); + +} // namespace FlexFlow + +#endif diff --git a/lib/runtime/src/permissions.h b/lib/local-execution/include/local-execution/permissions.h similarity index 84% rename from lib/runtime/src/permissions.h rename to lib/local-execution/include/local-execution/permissions.h index e7793a1dcb..ce19e38e7e 100644 --- a/lib/runtime/src/permissions.h +++ b/lib/local-execution/include/local-execution/permissions.h @@ -1,18 +1,13 @@ -#ifndef _FLEXFLOW_RUNTIME_SRC_PERMISSION_H -#define _FLEXFLOW_RUNTIME_SRC_PERMISSION_H +#ifndef _FLEXFLOW_LOCAL_EXECUTION_PERMISSION_H +#define 
_FLEXFLOW_LOCAL_EXECUTION_PERMISSION_H -#include "legion.h" #include "utils/exception.h" #include "utils/fmt.h" -#include "utils/optional.h" namespace FlexFlow { enum class Permissions { NONE, RO, WO, RW }; -Legion::PrivilegeMode to_legion(Permissions); -optional from_legion(Legion::PrivilegeMode); - Permissions join(Permissions lhs, Permissions rhs); Permissions meet(Permissions lhs, Permissions rhs); diff --git a/lib/runtime/include/runtime/profiling.h b/lib/local-execution/include/local-execution/profiling.h similarity index 64% rename from lib/runtime/include/runtime/profiling.h rename to lib/local-execution/include/local-execution/profiling.h index 3f43ede520..bd50801fc4 100644 --- a/lib/runtime/include/runtime/profiling.h +++ b/lib/local-execution/include/local-execution/profiling.h @@ -1,21 +1,20 @@ -#ifndef _FLEXFLOW_RUNTIME_SRC_PROFILING_H -#define _FLEXFLOW_RUNTIME_SRC_PROFILING_H +#ifndef _FLEXFLOW_LOCAL_EXECUTION_PROFILING_H +#define _FLEXFLOW_LOCAL_EXECUTION_PROFILING_H #include "kernels/profiling.h" -#include "legion.h" -#include "loggers.h" +#include "spdlog/spdlog.h" namespace FlexFlow { enum class EnableProfiling { YES, NO }; template -optional +std::optional profile(F const &f, ProfilingSettings profiling, Str s, Ts &&...ts) { - optional elapsed = + std::optional elapsed = profiling_wrapper(f, profiling, std::forward(ts)...); if (elapsed.has_value()) { - log_profile.debug(s, elapsed.value()); + spdlog::debug(s, elapsed.value()); } return elapsed; } diff --git a/lib/runtime/src/task_spec/runtime_arg_ref.h b/lib/local-execution/include/local-execution/runtime_arg_ref.h similarity index 73% rename from lib/runtime/src/task_spec/runtime_arg_ref.h rename to lib/local-execution/include/local-execution/runtime_arg_ref.h index 655300e692..295f32455c 100644 --- a/lib/runtime/src/task_spec/runtime_arg_ref.h +++ b/lib/local-execution/include/local-execution/runtime_arg_ref.h @@ -1,9 +1,9 @@ #ifndef _FLEXFLOW_RUNTIME_SRC_TASK_SPEC_RUNTIME_ARG_REF_H #define _FLEXFLOW_RUNTIME_SRC_TASK_SPEC_RUNTIME_ARG_REF_H -#include "arg_ref.h" -#include "device_specific.h" -#include "runtime/config.h" +#include "local-execution/arg_ref.h" +#include "local-execution/config.h" +#include "local-execution/device_specific.h" namespace FlexFlow { @@ -20,7 +20,7 @@ using RuntimeArgRefSpec = ArgRefSpec; RuntimeArgRef profiling_settings(); RuntimeArgRef> ff_handle(); -RuntimeArgRef iteration_config(); +RuntimeArgRef> iteration_config(); } // namespace FlexFlow diff --git a/lib/runtime/src/serialization.h b/lib/local-execution/include/local-execution/serialization.h similarity index 55% rename from lib/runtime/src/serialization.h rename to lib/local-execution/include/local-execution/serialization.h index 65601990b0..147ed8159c 100644 --- a/lib/runtime/src/serialization.h +++ b/lib/local-execution/include/local-execution/serialization.h @@ -1,12 +1,9 @@ -#ifndef _FLEXFLOW_RUNTIME_SERIALIZATION_H -#define _FLEXFLOW_RUNTIME_SERIALIZATION_H +#ifndef _FLEXFLOW_LOCAL_EXECUTION_SERIALIZATION_H +#define _FLEXFLOW_LOCAL_EXECUTION_SERIALIZATION_H #include "kernels/device.h" #include "kernels/nccl.h" -#include "legion.h" -#include "legion/legion_utilities.h" #include "op-attrs/dim_ordered.h" -#include "utils/optional.h" #include "utils/required.h" #include "utils/type_traits.h" #include "utils/variant.h" @@ -28,23 +25,6 @@ namespace FlexFlow { template struct needs_serialization {}; -/* template */ -/* class Serializer { */ -/* void serialize(Legion::Serializer &, T const &) const; */ -/* void 
deserialize(Legion::Deserializer &, T &) const; */ -/* }; */ - -/* template struct trivially_serializable; */ - -/* template struct - * visit_trivially_serializable; */ - -/* template >::value && - * visit_serializable::value)>::type> */ - template struct visit_trivially_serializable; @@ -101,6 +81,10 @@ struct is_trivially_serializable< typename std::enable_if::value>::type> : std::true_type {}; +template +struct is_trivially_serializable> + : is_trivially_serializable {}; + template struct is_trivially_serializable> : is_trivially_serializable {}; @@ -155,108 +139,6 @@ static_assert(std::is_same, static_assert(visit_trivially_serializable::value, ""); static_assert(is_trivially_serializable::value, ""); -template -struct Serialization { - void serialize(Legion::Serializer &, T const &) const; - T deserialize(Legion::Deserializer &) const; -}; - -template -struct Serialization< - T, - typename std::enable_if::value>::type> { - static void serialize(Legion::Serializer &sez, T const &t) { - sez.serialize(&t, sizeof(T)); - } - - static T const &deserialize(Legion::Deserializer &dez) { - void const *cur = dez.get_current_pointer(); - dez.advance_pointer(sizeof(T)); - return *(T const *)cur; - } -}; - -struct needs_serialize_visitor { - bool result = true; - - template - void operator()(char const *, T const &t) { - result &= needs_serialize(t); - } -}; - -template -bool visit_needs_serialize(T const &t) { - needs_serialize_visitor vis; - visit_struct::for_each(t, vis); - return vis.result; -} - -struct serialize_visitor { - serialize_visitor() = delete; - explicit serialize_visitor(Legion::Serializer &sez) : sez(sez) {} - - Legion::Serializer &sez; - - template - void operator()(char const *, T const &t) { - serialize(this->sez, t); - } -}; - -template -void visit_serialize(Legion::Serializer &sez, T const &t) { - serialize_visitor vis(sez); - visit_struct::for_each(t, vis); -} - -struct deserialize_visitor { - deserialize_visitor() = delete; - explicit deserialize_visitor(Legion::Deserializer &dez) : dez(dez) {} - - Legion::Deserializer &dez; - - template - T const &operator()(char const *, T &t) { - deserialize(dez, t); - } -}; - -template -T const &visit_deserialize(Legion::Deserializer &dez) { - deserialize_visitor vis(dez); - return visit_struct::for_each(vis); -} - -template -class VisitSerialize { - void serialize(Legion::Serializer &sez, T const &t) const { - return visit_serialize(sez, t); - } - - T const &deserialize(Legion::Deserializer &dez) const { - return visit_deserialize(dez); - } -}; - -template -size_t ff_task_serialize(Legion::Serializer &sez, T const &t) { - static_assert(is_serializable::value, "Type must be serializable"); - - size_t pre_size = sez.get_used_bytes(); - Serialization::serialize(sez, t); - size_t post_size = sez.get_used_bytes(); - - return post_size - pre_size; -} - -template -T const &ff_task_deserialize(Legion::Deserializer &dez) { - static_assert(is_serializable::value, "Type must be serializable"); - - return Serialization::deserialize(dez); -} - } // namespace FlexFlow #endif diff --git a/lib/runtime/src/sim_environment.h b/lib/local-execution/include/local-execution/sim_environment.h similarity index 94% rename from lib/runtime/src/sim_environment.h rename to lib/local-execution/include/local-execution/sim_environment.h index 4297d9d970..78608a3228 100644 --- a/lib/runtime/src/sim_environment.h +++ b/lib/local-execution/include/local-execution/sim_environment.h @@ -1,12 +1,13 @@ -#ifndef _FLEXFLOW_RUNTIME_SRC_OPS_SIM_ENVIRONMENT_H -#define 
_FLEXFLOW_RUNTIME_SRC_OPS_SIM_ENVIRONMENT_H +#ifndef _FLEXFLOW_LOCAL_EXECUTION_SIM_ENVIRONMENT_H +#define _FLEXFLOW_LOCAL_EXECUTION_SIM_ENVIRONMENT_H -#include "cost_metrics.h" #include "kernels/accessor.h" #include "kernels/allocation.h" +#include "local-execution/cost_metrics.h" +#include "local-execution/op_task_invocation.h" +#include "local-execution/task_argument_accessor.h" #include "op-attrs/parallel_tensor_shape.h" -#include "task_spec/op_task_invocation.h" -#include "task_spec/task_argument_accessor.h" +#include "pcg/machine_view.h" #include namespace FlexFlow { diff --git a/lib/runtime/include/runtime/task_spec/slot_id.h b/lib/local-execution/include/local-execution/slot_id.h similarity index 73% rename from lib/runtime/include/runtime/task_spec/slot_id.h rename to lib/local-execution/include/local-execution/slot_id.h index a5e4322d3c..53820fdb2f 100644 --- a/lib/runtime/include/runtime/task_spec/slot_id.h +++ b/lib/local-execution/include/local-execution/slot_id.h @@ -1,5 +1,5 @@ -#ifndef _FLEXFLOW_RUNTIME_SRC_TASK_SPEC_SLOT_ID_H -#define _FLEXFLOW_RUNTIME_SRC_TASK_SPEC_SLOT_ID_H +#ifndef _FLEXFLOW_LOCAL_EXECUTION_TASK_SPEC_SLOT_ID_H +#define _FLEXFLOW_LOCAL_EXECUTION_TASK_SPEC_SLOT_ID_H #include "utils/strong_typedef.h" diff --git a/lib/runtime/src/task_spec/slot_type.h b/lib/local-execution/include/local-execution/slot_type.h similarity index 86% rename from lib/runtime/src/task_spec/slot_type.h rename to lib/local-execution/include/local-execution/slot_type.h index 64b79ee281..957f89fa4e 100644 --- a/lib/runtime/src/task_spec/slot_type.h +++ b/lib/local-execution/include/local-execution/slot_type.h @@ -1,5 +1,5 @@ -#ifndef _FLEXFLOW_RUNTIME_SRC_TASK_SPEC_SLOT_TYPE_H -#define _FLEXFLOW_RUNTIME_SRC_TASK_SPEC_SLOT_TYPE_H +#ifndef _FLEXFLOW_LOCAL_EXECUTION_SLOT_TYPE_H +#define _FLEXFLOW_LOCAL_EXECUTION_SLOT_TYPE_H #include "utils/fmt.h" diff --git a/lib/local-execution/include/local-execution/task_argument_accessor.h b/lib/local-execution/include/local-execution/task_argument_accessor.h new file mode 100644 index 0000000000..663c862e18 --- /dev/null +++ b/lib/local-execution/include/local-execution/task_argument_accessor.h @@ -0,0 +1,155 @@ +#ifndef _FLEXFLOW_LOCAL_EXECUTION_TASK_ARGUMENT_ACCESSOR_H +#define _FLEXFLOW_LOCAL_EXECUTION_TASK_ARGUMENT_ACCESSOR_H + +#include "kernels/accessor.h" +#include "kernels/allocation.h" +#include "kernels/linear_kernels.h" +#include "local-execution/arg_ref.h" +#include "local-execution/concrete_arg.h" +#include "local-execution/config.h" +#include "local-execution/device_specific.h" +#include "local-execution/op_task_signature.h" +#include "local-execution/permissions.h" +#include "local-execution/tasks.h" +#include "op-attrs/parallel_tensor_shape.h" +#include "utils/variant.h" +#include +#include +#include +#include +#include +#include + +namespace FlexFlow { + +template +struct privilege_mode_to_accessor_t {}; + +template <> +struct privilege_mode_to_accessor_t { + using type = GenericTensorAccessorW; +}; + +template <> +struct privilege_mode_to_accessor_t { + using type = GenericTensorAccessorR; +}; + +template <> +struct privilege_mode_to_accessor_t { + using type = GenericTensorAccessorW; +}; + +template +using privilege_mode_to_accessor = + typename privilege_mode_to_accessor_t::type; + +using PrivilegeType = + std::variant; +using PrivilegeVariadicType = std::variant, + std::vector>; + +// TODO: define device state variant in another file +using DeviceStates = std::variant; + +using OpArgRefTypeBacking = + std::variant>; +using 
RuntimeArgRefTypeBacking = std::variant, + FFIterationConfig>; + +using ArgRefBacking = std:: + variant; + +struct ITaskArgumentAccessor { + ITaskArgumentAccessor &operator=(ITaskArgumentAccessor const &) = delete; + + virtual ~ITaskArgumentAccessor() = default; + + virtual ConcreteArgSpec const &get_concrete_arg(slot_id) const = 0; + virtual OpArgRefTypeBacking const &get_op_arg_ref(slot_id) const = 0; + virtual RuntimeArgRefTypeBacking const &get_runtime_arg(slot_id) const = 0; + + virtual PrivilegeType + get_tensor(slot_id slot, Permissions priv, IsGrad is_grad) const = 0; + virtual PrivilegeVariadicType get_variadic_tensor(slot_id slot, + Permissions priv, + IsGrad is_grad) const = 0; + + virtual Allocator get_allocator() const = 0; + virtual size_t get_device_idx() const = 0; +}; +CHECK_RC_COPY_VIRTUAL_COMPLIANT(ITaskArgumentAccessor); + +struct TaskArgumentAccessor { + template + T const &get_argument(slot_id slot) const { + if constexpr (is_in_variant::value) { + return std::get(this->ptr->get_op_arg_ref(slot)); + } else if constexpr (is_in_variant::value) { + return std::get(this->ptr->get_runtime_arg(slot)); + } else { + return this->ptr->get_concrete_arg(slot).get(); + } + } + + template + privilege_mode_to_accessor get_tensor(slot_id slot) const { + return std::get>( + this->ptr->get_tensor(slot, PRIV, IsGrad::NO)); + } + + template + privilege_mode_to_accessor get_tensor_grad(slot_id slot) const { + return std::get>( + this->ptr->get_tensor(slot, PRIV, IsGrad::YES)); + } + + template + std::vector> + get_variadic_tensor(slot_id slot) const { + return std::get>>( + this->ptr->get_variadic_tensor(slot, PRIV, IsGrad::NO)); + } + + template + std::vector> + get_variadic_tensor_grad(slot_id slot) const { + return std::get>>( + this->ptr->get_variadic_tensor(slot, PRIV, IsGrad::YES)); + } + + Allocator get_allocator() const { + return this->ptr->get_allocator(); + } + + template + static + typename std::enable_if::value, + TaskArgumentAccessor>::type + create(Args &&...args) { + return TaskArgumentAccessor( + std::make_shared(std::forward(args)...)); + } + +private: + TaskArgumentAccessor(std::shared_ptr ptr) + : ptr(ptr) {} + std::shared_ptr ptr; +}; + +using DeviceStates = std::variant; + +using TaskImplFunction = std::variant< + std::function, + std::function(TaskArgumentAccessor const &)>>; + +template +TaskImplFunction get_task_impl(); + +template +OpTaskSignature get_signature(); + +} // namespace FlexFlow + +#endif diff --git a/lib/runtime/src/tasks.h b/lib/local-execution/include/local-execution/tasks.h similarity index 95% rename from lib/runtime/src/tasks.h rename to lib/local-execution/include/local-execution/tasks.h index 0e07fa3f85..c78fefd4ea 100644 --- a/lib/runtime/src/tasks.h +++ b/lib/local-execution/include/local-execution/tasks.h @@ -1,8 +1,9 @@ -#ifndef _FLEXFLOW_TASKS_H -#define _FLEXFLOW_TASKS_H +#ifndef _FLEXFLOW_LOCAL_EXECUTION_TASKS_H +#define _FLEXFLOW_LOCAL_EXECUTION_TASKS_H -#include "utils/optional.h" +#include #include +#include namespace FlexFlow { @@ -170,9 +171,9 @@ template void register_task(task_id_t, std::string const &name, F const &func, - optional cpu_func = nullopt); + std::optional cpu_func = std::nullopt); -template +template void register_task(); void register_tasks(); diff --git a/lib/local-execution/include/tracked_allocator.h b/lib/local-execution/include/local-execution/tracked_allocator.h similarity index 94% rename from lib/local-execution/include/tracked_allocator.h rename to 
lib/local-execution/include/local-execution/tracked_allocator.h index 4f51670426..ea3eec64e0 100644 --- a/lib/local-execution/include/tracked_allocator.h +++ b/lib/local-execution/include/local-execution/tracked_allocator.h @@ -2,7 +2,7 @@ #define _FLEXFLOW_LOCAL_EXECUTION_TRACKED_ALLOCATOR_H #include "kernels/allocation.h" -#include "local_allocator.h" +#include "local-execution/local_allocator.h" namespace FlexFlow { diff --git a/lib/local-execution/include/local-execution/variadic_tensor_ref.h b/lib/local-execution/include/local-execution/variadic_tensor_ref.h new file mode 100644 index 0000000000..56da1bab64 --- /dev/null +++ b/lib/local-execution/include/local-execution/variadic_tensor_ref.h @@ -0,0 +1,18 @@ +#ifndef _FLEXFLOW_LOCAL_EXECUTION_VARIADIC_TENSOR_ARG_REF_H +#define _FLEXFLOW_LOCAL_EXECUTION_VARIADIC_TENSOR_ARG_REF_H + +#include "local-execution/arg_ref.h" +#include "local-execution/op_tensor_spec.h" + +namespace FlexFlow { + +enum class VariadicTensorRefType { INPUT_TENSORS }; + +template +using VariadicTensorRef = ArgRef; + +VariadicTensorRef get_input_tensors(); + +} // namespace FlexFlow + +#endif diff --git a/lib/local-execution/src/local_allocator.cc b/lib/local-execution/src/local_allocator.cc index 0bb7d04574..d393643ead 100644 --- a/lib/local-execution/src/local_allocator.cc +++ b/lib/local-execution/src/local_allocator.cc @@ -1,4 +1,4 @@ -#include "local_allocator.h" +#include "local-execution/local_allocator.h" #include "kernels/device.h" namespace FlexFlow { diff --git a/lib/local-execution/src/op_arg_ref.cc b/lib/local-execution/src/op_arg_ref.cc new file mode 100644 index 0000000000..8e9b56272b --- /dev/null +++ b/lib/local-execution/src/op_arg_ref.cc @@ -0,0 +1,14 @@ +#include "local-execution/op_arg_ref.h" + +namespace FlexFlow { + +template +OpArgRef> per_device_op_state() { + return {OpArgRefType::PER_DEVICE_OP_STATE}; +} + +OpArgRef input_parallel_tensor_shape(int idx) { + return {OpArgRefType::PARALLEL_TENSOR_SHAPE}; +} + +} // namespace FlexFlow diff --git a/lib/local-execution/src/op_task_invocation.cc b/lib/local-execution/src/op_task_invocation.cc new file mode 100644 index 0000000000..adad2f3a72 --- /dev/null +++ b/lib/local-execution/src/op_task_invocation.cc @@ -0,0 +1,100 @@ +#include "local-execution/op_task_invocation.h" + +namespace FlexFlow { + +OpTensorSpec input_tensor(int idx, + OpSlotOptions option = OpSlotOptions::NECESSARY) { + return {TensorRole::INPUT, option, idx}; +} + +OpTensorSpec output_tensor(int idx, + OpSlotOptions option = OpSlotOptions::NECESSARY) { + return {TensorRole::OUTPUT, option, idx}; +} + +OpTensorSpec weight_tensor(int idx, + OpSlotOptions option = OpSlotOptions::NECESSARY) { + return {TensorRole::WEIGHT, option, idx}; +} + +void OpTaskBinding::bind(slot_id slot, OpTensorSpec const &tensor_spec) { + this->tensor_bindings.insert({{slot, IsGrad::NO}, tensor_spec}); +} + +void OpTaskBinding::bind_grad(slot_id slot, OpTensorSpec const &tensor_spec) { + this->tensor_bindings.insert({{slot, IsGrad::YES}, tensor_spec}); +} + +std::unordered_map, OpTensorSpec> const & + OpTaskBinding::get_tensor_bindings() const { + return this->tensor_bindings; +} + +std::unordered_map const & + OpTaskBinding::get_arg_bindings() const { + return this->arg_bindings; +} + +OpTaskBinding infer_bwd_binding(OpTaskBinding const &fwd) { + OpTaskBinding bwd; + bwd.arg_bindings = fwd.get_arg_bindings(); + bwd.tensor_bindings = fwd.get_tensor_bindings(); + for (auto const &[key, spec] : fwd.get_tensor_bindings()) { + OpSlotOptions slot_option = 
spec.slot_option;
+    if (slot_option != OpSlotOptions::UNTRAINABLE &&
+        slot_option != OpSlotOptions::OPTIONAL_UNTRAINABLE) {
+      slot_id slot = key.first;
+      bwd.bind_grad(slot, spec);
+    }
+  }
+  return bwd;
+}
+
+bool is_op_tensor_spec_invalid(OpTensorSlotSpec tensor_slot_spec,
+                               OpTensorSpec tensor_spec) {
+  return tensor_spec.role != tensor_slot_spec.tensor_role ||
+         tensor_spec.slot_option != tensor_slot_spec.slot_option;
+}
+
+bool is_tensor_invocation_valid(OpTaskSignature const &sig,
+                                OpTaskInvocation const &inv) {
+  auto tensor_bindings = inv.binding.get_tensor_bindings();
+  for (OpTensorSlotSpec const &op_tensor_slot_spec : sig.get_tensor_slots()) {
+    std::pair<slot_id, IsGrad> tensor_key =
+        std::make_pair(op_tensor_slot_spec.name, op_tensor_slot_spec.is_grad);
+    OpTensorSpec const &op_tensor_spec = tensor_bindings.at(tensor_key);
+    if (is_op_tensor_spec_invalid(op_tensor_slot_spec, op_tensor_spec)) {
+      return false;
+    }
+  }
+  return true;
+}
+
+bool is_arg_type_invalid(std::type_index expected_arg_type,
+                         OpArgSpec op_arg_spec) {
+  std::type_index arg_spec_type = std::visit(
+      [](auto &&arg) -> std::type_index { return arg.get_type_index(); },
+      op_arg_spec);
+  return arg_spec_type != expected_arg_type;
+}
+
+bool is_arg_invocation_valid(OpTaskSignature const &sig,
+                             OpTaskInvocation const &inv) {
+  auto sig_arg_types = sig.get_arg_types();
+  for (auto arg_binding : inv.binding.get_arg_bindings()) {
+    std::type_index arg_type = sig_arg_types.at(arg_binding.first);
+    if (is_arg_type_invalid(arg_type, arg_binding.second)) {
+      return false;
+    }
+  }
+
+  return true;
+}
+
+bool is_invocation_valid(OpTaskSignature const &sig,
+                         OpTaskInvocation const &inv) {
+  return is_tensor_invocation_valid(sig, inv) &&
+         is_arg_invocation_valid(sig, inv);
+}
+
+} // namespace FlexFlow
diff --git a/lib/local-execution/src/op_task_signature.cc b/lib/local-execution/src/op_task_signature.cc
new file mode 100644
index 0000000000..53a685910e
--- /dev/null
+++ b/lib/local-execution/src/op_task_signature.cc
@@ -0,0 +1,81 @@
+#include "local-execution/op_task_signature.h"
+
+namespace FlexFlow {
+
+OpTaskSignature::OpTaskSignature(OpTaskType t) : type(t) {}
+
+void OpTaskSignature::add_input_slot(slot_id name, SlotType slot_type) {
+  OpTensorSlotSpec op_tensor_slot_spec = {
+      name, slot_type, TensorRole::INPUT, IsGrad::NO, OpSlotOptions::NECESSARY};
+  this->op_tensor_slots.insert(op_tensor_slot_spec);
+}
+
+void OpTaskSignature::add_optional_input_slot(slot_id name,
+                                              SlotType slot_type) {
+  OpTensorSlotSpec op_tensor_slot_spec = {
+      name, slot_type, TensorRole::INPUT, IsGrad::NO, OpSlotOptions::OPTIONAL};
+  this->op_tensor_slots.insert(op_tensor_slot_spec);
+}
+
+void OpTaskSignature::add_untrainable_input_slot(slot_id name,
+                                                 SlotType slot_type) {
+  OpTensorSlotSpec op_tensor_slot_spec = {name,
+                                          slot_type,
+                                          TensorRole::INPUT,
+                                          IsGrad::NO,
+                                          OpSlotOptions::UNTRAINABLE};
+  this->op_tensor_slots.insert(op_tensor_slot_spec);
+}
+
+void OpTaskSignature::add_optional_untrainable_input_slot(slot_id name,
+                                                          SlotType slot_type) {
+  OpTensorSlotSpec op_tensor_slot_spec = {name,
+                                          slot_type,
+                                          TensorRole::INPUT,
+                                          IsGrad::NO,
+                                          OpSlotOptions::OPTIONAL_UNTRAINABLE};
+  this->op_tensor_slots.insert(op_tensor_slot_spec);
+}
+
+void OpTaskSignature::add_output_slot(slot_id name, SlotType slot_type) {
+  OpTensorSlotSpec op_tensor_slot_spec = {
+      name, slot_type, TensorRole::OUTPUT, IsGrad::NO, OpSlotOptions::OPTIONAL};
+  this->op_tensor_slots.insert(op_tensor_slot_spec);
+}
+
+void OpTaskSignature::add_bwd_necessary_output_slot(slot_id 
name, + SlotType slot_type) { + OpTensorSlotSpec op_tensor_slot_spec = {name, + slot_type, + TensorRole::OUTPUT, + IsGrad::NO, + OpSlotOptions::NECESSARY}; + this->op_tensor_slots.insert(op_tensor_slot_spec); +} + +void OpTaskSignature::add_weight_slot(slot_id name, SlotType slot_type) { + OpTensorSlotSpec op_tensor_slot_spec = {name, + slot_type, + TensorRole::WEIGHT, + IsGrad::NO, + OpSlotOptions::NECESSARY}; + this->op_tensor_slots.insert(op_tensor_slot_spec); +} + +void OpTaskSignature::add_optional_weight_slot(slot_id name, + SlotType slot_type) { + OpTensorSlotSpec op_tensor_slot_spec = { + name, slot_type, TensorRole::WEIGHT, IsGrad::NO, OpSlotOptions::OPTIONAL}; + this->op_tensor_slots.insert(op_tensor_slot_spec); +} + +void OpTaskSignature::set_arg_types( + std::unordered_map const &arg_type) { + this->task_arg_types = arg_type; +} + +void OpTaskSignature::add_from_slot_spec(OpTensorSlotSpec const &spec) { + this->op_tensor_slots.insert(spec); +} + +} // namespace FlexFlow diff --git a/lib/runtime/src/ops/attention.cc b/lib/local-execution/src/ops/attention.cc similarity index 84% rename from lib/runtime/src/ops/attention.cc rename to lib/local-execution/src/ops/attention.cc index 41905f9014..6e6d23cd4a 100644 --- a/lib/runtime/src/ops/attention.cc +++ b/lib/local-execution/src/ops/attention.cc @@ -15,19 +15,12 @@ #include "attention.h" #include "kernels/attention_kernels.h" -#include "legion.h" -#include "op-attrs/ops/attention.h" -#include "task_spec/op_task_signature.h" +#include "local-execution/op_task_signature.h" namespace FlexFlow { using namespace FlexFlow::Kernels::MultiHeadAttention; -using Legion::Context; -using Legion::PhysicalRegion; -using Legion::Runtime; -using Legion::Task; - enum Slots { QUERY_PARALLEL_TENSOR_SHAPE, KEY_PARALLEL_TENSOR_SHAPE, @@ -86,6 +79,12 @@ OpTaskInvocation backward(MultiHeadAttentionAttrs const &attrs) { return {ATTENTION_BWD_TASK_ID, b}; } +// OpArgBacking +// generate_op_arg_backing(std::vector +// tensor_shape_args) { + +// } + static DeviceSpecific init_task_impl(TaskArgumentAccessor const &acc) { auto const &attrs = acc.get_argument(ATTRS); @@ -122,35 +121,42 @@ static DeviceSpecific int num_samples = get_piece_shape(query_parallel_tensor_shape)[ff_dim_t(2)]; int num_heads = get_piece_shape(weight_parallel_tensor_shape)[ff_dim_t(1)]; + // MHAPerDeviceState per_device_state = + // init_kernel(handle, + // allocator, + // num_samples, + // num_heads, + // qSize, + // kSize, + // vSize, + // qProjSize, + // kProjSize, + // vProjSize, + // oProjSize, + // qoSeqLength, + // kvSeqLength, + // attrs.add_bias_kv); + // return acc.create_device_specific(per_device_state); + DeviceSpecific per_device_state = - acc.create_device_specific( - init_kernel(handle, - allocator, - num_samples, - num_heads, - qSize, - kSize, - vSize, - qProjSize, - kProjSize, - vProjSize, - oProjSize, - qoSeqLength, - kvSeqLength, - attrs.add_bias_kv)); + init_kernel(handle, + allocator, + num_samples, + num_heads, + qSize, + kSize, + vSize, + qProjSize, + kProjSize, + vProjSize, + oProjSize, + qoSeqLength, + kvSeqLength, + attrs.add_bias_kv); return per_device_state; } -static DeviceSpecific - init_task(Task const *task, - std::vector const ®ions, - Context ctx, - Runtime *runtime) { - TaskArgumentAccessor acc(task, regions, ctx, runtime); - return init_task_impl(acc); -} - -static optional forward_task_impl(TaskArgumentAccessor const &acc) { +static std::optional forward_task_impl(TaskArgumentAccessor const &acc) { auto query = acc.get_tensor(QUERY); auto key = 
acc.get_tensor(KEY); auto value = acc.get_tensor(VALUE); @@ -162,7 +168,7 @@ static optional forward_task_impl(TaskArgumentAccessor const &acc) { return profile(forward_kernel, profiling, - "[MultiHeadAttention] forward_time = %.2lfms\n", + "[MultiHeadAttention] forward_time = {:.2lf}ms\n", per_device_state, query.get_float_ptr(), key.get_float_ptr(), @@ -171,15 +177,8 @@ static optional forward_task_impl(TaskArgumentAccessor const &acc) { output.get_float_ptr()); } -static void forward_task(Task const *task, - std::vector const ®ions, - Context ctx, - Runtime *runtime) { - TaskArgumentAccessor acc(task, regions, ctx, runtime); - forward_task_impl(acc); -} - -static optional backward_task_impl(TaskArgumentAccessor const &acc) { +static std::optional + backward_task_impl(TaskArgumentAccessor const &acc) { auto query = acc.get_tensor(QUERY); auto key = acc.get_tensor(KEY); auto value = acc.get_tensor(VALUE); @@ -208,7 +207,7 @@ static optional backward_task_impl(TaskArgumentAccessor const &acc) { return profile(backward_kernel, profiling, - "[MultiHeadAttention] backward_time = %.2lfms\n", + "[MultiHeadAttention] backward_time = {:.2lf}ms\n", per_device_state, query.get_float_ptr(), query_grad.get_float_ptr(), @@ -221,14 +220,6 @@ static optional backward_task_impl(TaskArgumentAccessor const &acc) { output_grad.get_float_ptr()); } -static void backward_task(Task const *task, - std::vector const ®ions, - Context ctx, - Runtime *runtime) { - TaskArgumentAccessor acc(task, regions, ctx, runtime); - backward_task_impl(acc); -} - CostMetrics measure_operator_cost(SimEnvFactory const &sim, MultiHeadAttentionAttrs const &attrs, InputParallelTensorDesc const &query_shape, @@ -307,7 +298,12 @@ void register_task() { register_task(ATTENTION_INIT_TASK_ID, "Attention Init", init_signature(), - init_task); + init_task_impl); +} + +template <> +OpTaskSignature get_signature() { + return init_signature(); } template <> @@ -331,13 +327,13 @@ void register_task() { register_task(ATTENTION_FWD_TASK_ID, "Attention Fwd", fwd_signature(), - forward_task); + forward_task_impl); } template <> OpTaskSignature bwd_signature() { OpTaskSignature bwd = - infer_bwd_signature(get_op_signature(ATTENTION_FWD_TASK_ID)); + infer_bwd_signature(fwd_signature()); return bwd; } @@ -347,7 +343,7 @@ void register_task() { register_task(ATTENTION_BWD_TASK_ID, "Attention Bwd", bwd_signature(), - backward_task); + backward_task_impl); } } // namespace FlexFlow diff --git a/lib/runtime/src/ops/attention.h b/lib/local-execution/src/ops/attention.h similarity index 91% rename from lib/runtime/src/ops/attention.h rename to lib/local-execution/src/ops/attention.h index 09a4ef036f..c8eb17ecec 100644 --- a/lib/runtime/src/ops/attention.h +++ b/lib/local-execution/src/ops/attention.h @@ -1,9 +1,9 @@ #ifndef _FLEXFLOW_ATTENTION_H #define _FLEXFLOW_ATTENTION_H +#include "local-execution/op_task_invocation.h" +#include "local-execution/sim_environment.h" #include "op-attrs/ops/attention.h" -#include "sim_environment.h" -#include "task_spec/op_task_invocation.h" namespace FlexFlow { diff --git a/lib/runtime/src/ops/batch_matmul.cc b/lib/local-execution/src/ops/batch_matmul.cc similarity index 87% rename from lib/runtime/src/ops/batch_matmul.cc rename to lib/local-execution/src/ops/batch_matmul.cc index 5f40def699..187e97ecaa 100644 --- a/lib/runtime/src/ops/batch_matmul.cc +++ b/lib/local-execution/src/ops/batch_matmul.cc @@ -15,20 +15,14 @@ #include "batch_matmul.h" #include "kernels/batch_matmul_kernels.h" -#include "legion.h" +#include 
"local-execution/op_task_signature.h" #include "op-attrs/get_output_shapes.h" #include "op-attrs/ops/batch_matmul.h" -#include "task_spec/op_task_signature.h" namespace FlexFlow { using namespace FlexFlow::Kernels::BatchMatmul; -using Legion::Context; -using Legion::PhysicalRegion; -using Legion::Runtime; -using Legion::Task; - enum Slots { A_INPUT, // tensor B_INPUT, // tensor @@ -60,7 +54,7 @@ OpTaskInvocation backward(BatchMatmulAttrs const &attrs) { return {BATCHMATMUL_BWD_TASK_ID, bwd}; } -static optional forward_task_impl(TaskArgumentAccessor const &acc) { +static std::optional forward_task_impl(TaskArgumentAccessor const &acc) { auto a_input = acc.get_tensor(A_INPUT); auto b_input = acc.get_tensor(B_INPUT); auto output = acc.get_tensor(OUTPUT); @@ -91,7 +85,7 @@ static optional forward_task_impl(TaskArgumentAccessor const &acc) { return profile(forward_kernel, profiling, - "[BatchMatmul] forward_time = %.2lfms\n", + "[BatchMatmul] forward_time = {:.2lf}ms\n", handle, output.get_float_ptr(), a_input.get_float_ptr(), @@ -105,15 +99,8 @@ static optional forward_task_impl(TaskArgumentAccessor const &acc) { iter_config.seq_length); } -static void forward_task(Task const *task, - std::vector const ®ions, - Context ctx, - Runtime *runtime) { - TaskArgumentAccessor acc(task, regions, ctx, runtime); - forward_task_impl(acc); -} - -static optional backward_task_impl(TaskArgumentAccessor const &acc) { +static std::optional + backward_task_impl(TaskArgumentAccessor const &acc) { // BatchMatmul* bmm = (BatchMatmul*) task->args; FFIterationConfig iter_config = acc.get_argument(ITERATION_CONFIG); @@ -151,7 +138,7 @@ static optional backward_task_impl(TaskArgumentAccessor const &acc) { return profile(backward_kernel, profiling, - "[BatchMatmul] backward_time = %.2lfms\n", + "[BatchMatmul] backward_time = {:.2lf}ms\n", handle, output.get_float_ptr(), output_grad.get_float_ptr(), @@ -165,14 +152,6 @@ static optional backward_task_impl(TaskArgumentAccessor const &acc) { batch); } -static void backward_task(Task const *task, - std::vector const ®ions, - Context ctx, - Runtime *runtime) { - TaskArgumentAccessor acc(task, regions, ctx, runtime); - backward_task_impl(acc); -} - CostMetrics measure_operator_cost(SimEnvFactory const &sim, BatchMatmulAttrs const &attrs, InputParallelTensorDesc const &a_input, @@ -225,7 +204,7 @@ void register_task() { register_task(BATCHMATMUL_FWD_TASK_ID, "BatchMatmul Fwd", fwd_signature(), - forward_task); + forward_task_impl); } template <> @@ -241,7 +220,7 @@ void register_task() { register_task(BATCHMATMUL_BWD_TASK_ID, "BatchMatmul Bwd", bwd_signature(), - backward_task); + backward_task_impl); } }; // namespace FlexFlow diff --git a/lib/runtime/src/ops/batch_matmul.h b/lib/local-execution/src/ops/batch_matmul.h similarity index 84% rename from lib/runtime/src/ops/batch_matmul.h rename to lib/local-execution/src/ops/batch_matmul.h index 7d3f2308da..94457c22be 100644 --- a/lib/runtime/src/ops/batch_matmul.h +++ b/lib/local-execution/src/ops/batch_matmul.h @@ -1,10 +1,10 @@ #ifndef _FLEXFLOW_BATCH_MATMUL_H #define _FLEXFLOW_BATCH_MATMUL_H +#include "local-execution/op_task_invocation.h" +#include "local-execution/op_task_signature.h" +#include "local-execution/sim_environment.h" #include "op-attrs/ops/batch_matmul.h" -#include "sim_environment.h" -#include "task_spec/op_task_invocation.h" -#include "task_spec/op_task_signature.h" namespace FlexFlow { diff --git a/lib/runtime/src/ops/batch_norm.cc b/lib/local-execution/src/ops/batch_norm.cc similarity index 81% rename from 
lib/runtime/src/ops/batch_norm.cc rename to lib/local-execution/src/ops/batch_norm.cc index a52981a8a3..97830f90fe 100644 --- a/lib/runtime/src/ops/batch_norm.cc +++ b/lib/local-execution/src/ops/batch_norm.cc @@ -15,17 +15,11 @@ #include "batch_norm.h" #include "kernels/batch_norm_kernels.h" -#include "legion/legion_utilities.h" namespace FlexFlow { using namespace FlexFlow::Kernels::BatchNorm; -using Legion::Context; -using Legion::PhysicalRegion; -using Legion::Runtime; -using Legion::Task; - enum Slots { INPUT, // tensor SCALE, // tensor @@ -88,29 +82,19 @@ static DeviceSpecific float *runningMean; DeviceSpecific per_device_state = - acc.create_device_specific( - init_kernel(handle, - allocator, - runningMean, - output_n, - output_c, - output_h, - output_w, - attrs.relu)); + init_kernel(handle, + allocator, + runningMean, + output_n, + output_c, + output_h, + output_w, + attrs.relu); return per_device_state; } -static DeviceSpecific - init_task(Task const *task, - std::vector const ®ions, - Context ctx, - Runtime *runtime) { - TaskArgumentAccessor acc(task, regions, ctx, runtime); - return init_task_impl(acc); -} - -static optional forward_task_impl(TaskArgumentAccessor const &acc) { +static std::optional forward_task_impl(TaskArgumentAccessor const &acc) { auto per_device_state = acc.get_argument(PER_DEVICE_STATE); ProfilingSettings profiling = acc.get_argument(PROFILING); @@ -122,23 +106,16 @@ static optional forward_task_impl(TaskArgumentAccessor const &acc) { return profile(forward_kernel, profiling, - "[BatchNorm] forward_time = %.2lfms\n", - &per_device_state, + "[BatchNorm] forward_time = {:.2lf}ms\n", + per_device_state, input.get_float_ptr(), output.get_float_ptr(), scale.get_float_ptr(), bias.get_float_ptr()); } -static void forward_task(Task const *task, - std::vector const ®ions, - Context ctx, - Runtime *runtime) { - TaskArgumentAccessor acc(task, regions, ctx, runtime); - forward_task_impl(acc); -} - -static optional backward_task_impl(TaskArgumentAccessor const &acc) { +static std::optional + backward_task_impl(TaskArgumentAccessor const &acc) { auto per_device_state = acc.get_argument(PER_DEVICE_STATE); ProfilingSettings profiling = acc.get_argument(PROFILING); @@ -153,8 +130,8 @@ static optional backward_task_impl(TaskArgumentAccessor const &acc) { return profile(backward_kernel, profiling, - "[BatchNorm] backward_time = %.2lfms\n", - &per_device_state, + "[BatchNorm] backward_time = {:.2lf}ms\n", + per_device_state, input.get_float_ptr(), output_grad.get_float_ptr(), output.get_float_ptr(), @@ -165,14 +142,6 @@ static optional backward_task_impl(TaskArgumentAccessor const &acc) { output.shape.get_volume()); } -static void backward_task(Task const *task, - std::vector const ®ions, - Context ctx, - Runtime *runtime) { - TaskArgumentAccessor acc(task, regions, ctx, runtime); - backward_task_impl(acc); -} - CostMetrics measure_operator_cost(SimEnvFactory const &sim, BatchNormAttrs const &attrs, InputParallelTensorDesc const &input_shape, @@ -221,6 +190,7 @@ CostMetrics measure_operator_cost(SimEnvFactory const &sim, template <> OpTaskSignature init_signature() { OpTaskSignature init(OpTaskType::INIT); + init.add_input_slot(INPUT); init.add_input_slot(BIAS); init.add_output_slot(OUTPUT); @@ -236,7 +206,7 @@ void register_task() { register_task(BATCHNORM_INIT_TASK_ID, "BatchNorm Init", init_signature(), - init_task); + init_task_impl); } template <> @@ -258,7 +228,7 @@ void register_task() { register_task(BATCHNORM_FWD_TASK_ID, "BatchNorm Fwd", fwd_signature(), - 
forward_task); + forward_task_impl); } template <> @@ -274,7 +244,7 @@ void register_task() { register_task(BATCHNORM_BWD_TASK_ID, "BatchNorm Bwd", bwd_signature(), - backward_task); + backward_task_impl); } }; // namespace FlexFlow diff --git a/lib/runtime/src/ops/batch_norm.h b/lib/local-execution/src/ops/batch_norm.h similarity index 89% rename from lib/runtime/src/ops/batch_norm.h rename to lib/local-execution/src/ops/batch_norm.h index 906e85a57c..1745a5cac8 100644 --- a/lib/runtime/src/ops/batch_norm.h +++ b/lib/local-execution/src/ops/batch_norm.h @@ -1,9 +1,9 @@ #ifndef _FLEXFLOW_BATCH_NORM_H #define _FLEXFLOW_BATCH_NORM_H +#include "local-execution/op_task_invocation.h" +#include "local-execution/sim_environment.h" #include "op-attrs/ops/batch_norm.h" -#include "sim_environment.h" -#include "task_spec/op_task_invocation.h" namespace FlexFlow { diff --git a/lib/runtime/src/ops/cast.cc b/lib/local-execution/src/ops/cast.cc similarity index 79% rename from lib/runtime/src/ops/cast.cc rename to lib/local-execution/src/ops/cast.cc index 44230eaf46..9e1f777d73 100644 --- a/lib/runtime/src/ops/cast.cc +++ b/lib/local-execution/src/ops/cast.cc @@ -15,17 +15,12 @@ #include "cast.h" #include "kernels/cast_kernels.h" -#include "legion/legion_utilities.h" -#include "task_spec/op_task_signature.h" + +#include "local-execution/op_task_signature.h" #include "utils/hash-utils.h" using namespace FlexFlow::Kernels::Cast; -using Legion::Context; -using Legion::PhysicalRegion; -using Legion::Runtime; -using Legion::Task; - namespace FlexFlow { enum Slots { INPUT, OUTPUT, ATTRS, PROFILING }; @@ -48,7 +43,7 @@ OpTaskInvocation backward(CastAttrs const &attrs) { return {CAST_BWD_TASK_ID, binding}; } -static optional forward_task_impl(TaskArgumentAccessor const &acc) { +static std::optional forward_task_impl(TaskArgumentAccessor const &acc) { ProfilingSettings profiling = acc.get_argument(PROFILING); auto const &attrs = acc.get_argument(ATTRS); @@ -57,22 +52,15 @@ static optional forward_task_impl(TaskArgumentAccessor const &acc) { return profile(forward_kernel, profiling, - "[Cast] forward_time = %.2lfms\n", + "[Cast] forward_time = {:.2lf}ms\n", input, output, input.data_type, attrs.dtype); } -static void forward_task(Task const *task, - std::vector const ®ions, - Context ctx, - Runtime *runtime) { - TaskArgumentAccessor acc(task, regions, ctx, runtime); - forward_task_impl(acc); -} - -static optional backward_task_impl(TaskArgumentAccessor const &acc) { +static std::optional + backward_task_impl(TaskArgumentAccessor const &acc) { ProfilingSettings profiling = acc.get_argument(PROFILING); auto const &attrs = acc.get_argument(ATTRS); @@ -83,21 +71,13 @@ static optional backward_task_impl(TaskArgumentAccessor const &acc) { return profile(backward_kernel, profiling, - "[Cast] forward_time = %.2lfms\n", + "[Cast] forward_time = {:.2lf}ms\n", input_grad, output_grad, input.data_type, attrs.dtype); } -static void backward_task(Task const *task, - std::vector const ®ions, - Context ctx, - Runtime *runtime) { - TaskArgumentAccessor acc(task, regions, ctx, runtime); - backward_task_impl(acc); -} - CostMetrics measure_operator_cost(SimEnvFactory const &sim, CastAttrs const &attrs, InputParallelTensorDesc const &input_shape, @@ -143,7 +123,7 @@ void register_task() { register_task(CAST_FWD_TASK_ID, "Cast Fwd", fwd_signature(), - forward_task); + forward_task_impl); } template <> @@ -158,7 +138,7 @@ void register_task() { register_task(CAST_BWD_TASK_ID, "Cast Bwd", bwd_signature(), - backward_task); + 
backward_task_impl); } }; // namespace FlexFlow diff --git a/lib/runtime/src/ops/cast.h b/lib/local-execution/src/ops/cast.h similarity index 93% rename from lib/runtime/src/ops/cast.h rename to lib/local-execution/src/ops/cast.h index c0c500e869..69aeadf497 100644 --- a/lib/runtime/src/ops/cast.h +++ b/lib/local-execution/src/ops/cast.h @@ -15,9 +15,9 @@ #ifndef _FLEXFLOW_CAST_H #define _FLEXFLOW_CAST_H +#include "local-execution/op_task_invocation.h" +#include "local-execution/sim_environment.h" #include "op-attrs/ops/cast.h" -#include "sim_environment.h" -#include "task_spec/op_task_invocation.h" namespace FlexFlow { diff --git a/lib/runtime/src/ops/combine.cc b/lib/local-execution/src/ops/combine.cc similarity index 75% rename from lib/runtime/src/ops/combine.cc rename to lib/local-execution/src/ops/combine.cc index 46d5ebb4fe..6df09b53f4 100644 --- a/lib/runtime/src/ops/combine.cc +++ b/lib/local-execution/src/ops/combine.cc @@ -15,15 +15,11 @@ #include "combine.h" #include "kernels/combine_kernels.h" -#include "task_spec/op_task_invocation.h" +#include "local-execution/op_task_invocation.h" #include "utils/hash-utils.h" namespace FlexFlow { // declare Legion names -using Legion::Context; -using Legion::PhysicalRegion; -using Legion::Runtime; -using Legion::Task; using namespace FlexFlow::Kernels::Combine; @@ -46,7 +42,7 @@ OpTaskInvocation backward(CombineAttrs const &attrs) { return {COMBINE_BWD_TASK_ID, b}; } -static optional forward_task_impl(TaskArgumentAccessor const &acc) { +static std::optional forward_task_impl(TaskArgumentAccessor const &acc) { ProfilingSettings profiling = acc.get_argument(PROFILING); auto input = acc.get_tensor(INPUT); @@ -54,20 +50,13 @@ static optional forward_task_impl(TaskArgumentAccessor const &acc) { return profile(forward_kernel, profiling, - "[Combine] forward_time = %.2lfms\n", + "[Combine] forward_time = {:.2lf}ms\n", input, output); } -static void forward_task(Task const *task, - std::vector const ®ions, - Context ctx, - Runtime *runtime) { - TaskArgumentAccessor acc(task, regions, ctx, runtime); - forward_task_impl(acc); -} - -static optional backward_task_impl(TaskArgumentAccessor const &acc) { +static std::optional + backward_task_impl(TaskArgumentAccessor const &acc) { ProfilingSettings profiling = acc.get_argument(PROFILING); auto input_grad = acc.get_tensor_grad(INPUT); @@ -75,19 +64,11 @@ static optional backward_task_impl(TaskArgumentAccessor const &acc) { return profile(backward_kernel, profiling, - "[Combine] forward_time = %.2lfms\n", + "[Combine] backward_time = {:.2lf}ms\n", input_grad, output_grad); } -static void backward_task(Task const *task, - std::vector const ®ions, - Context ctx, - Runtime *runtime) { - TaskArgumentAccessor acc(task, regions, ctx, runtime); - backward_task_impl(acc); -} - CostMetrics measure_operator_cost(SimEnvFactory const &sim, CombineAttrs const &attrs, InputParallelTensorDesc const &input_shape, @@ -117,7 +98,7 @@ void register_task() { register_task(COMBINE_FWD_TASK_ID, "Combine Fwd", fwd_signature(), - forward_task); + forward_task_impl); } template <> @@ -133,7 +114,7 @@ void register_task() { register_task(COMBINE_BWD_TASK_ID, "Combine Bwd", bwd_signature(), - backward_task); + backward_task_impl); } }; // namespace FlexFlow diff --git a/lib/runtime/src/ops/combine.h b/lib/local-execution/src/ops/combine.h similarity index 87% rename from lib/runtime/src/ops/combine.h rename to lib/local-execution/src/ops/combine.h index 6b3a43863b..f9349a01ef 100644 --- a/lib/runtime/src/ops/combine.h +++ 
b/lib/local-execution/src/ops/combine.h @@ -1,9 +1,9 @@ #ifndef _FLEXFLOW_COMBINE_H #define _FLEXFLOW_COMBINE_H +#include "local-execution/op_task_invocation.h" +#include "local-execution/sim_environment.h" #include "op-attrs/ops/combine.h" -#include "sim_environment.h" -#include "task_spec/op_task_invocation.h" namespace FlexFlow { diff --git a/lib/runtime/src/ops/concat.cc b/lib/local-execution/src/ops/concat.cc similarity index 78% rename from lib/runtime/src/ops/concat.cc rename to lib/local-execution/src/ops/concat.cc index 1ce549cc57..f3c2eba48f 100644 --- a/lib/runtime/src/ops/concat.cc +++ b/lib/local-execution/src/ops/concat.cc @@ -15,21 +15,16 @@ #include "concat.h" #include "kernels/concat_kernels.h" -#include "legion/legion_utilities.h" + +#include "local-execution/op_task_signature.h" +#include "local-execution/variadic_tensor_ref.h" #include "op-attrs/get_output_shapes.h" -#include "task_spec/op_task_signature.h" -#include "task_spec/variadic_tensor_ref.h" #include "utils/hash-utils.h" namespace FlexFlow { using namespace FlexFlow::Kernels::Concat; -using Legion::Context; -using Legion::PhysicalRegion; -using Legion::Runtime; -using Legion::Task; - enum Slots { INPUTS, OUTPUT, ATTRS, PROFILING, HANDLE, NUM_INPUTS }; OpTaskInvocation forward(ConcatAttrs const &attrs) { @@ -48,7 +43,7 @@ OpTaskInvocation backward(ConcatAttrs const &attrs) { return {CONCAT_BWD_TASK_ID, b}; } -static optional forward_task_impl(TaskArgumentAccessor const &acc) { +static std::optional forward_task_impl(TaskArgumentAccessor const &acc) { ProfilingSettings profiling = acc.get_argument(PROFILING); auto const &attrs = acc.get_argument(ATTRS); @@ -59,21 +54,14 @@ static optional forward_task_impl(TaskArgumentAccessor const &acc) { return profile(forward_kernel, profiling, - "[Concat] forward_time = %.2lfms\n", + "[Concat] forward_time = {:.2lf}ms\n", output, inputs, attrs.axis); } -static void forward_task(Task const *task, - std::vector const ®ions, - Context ctx, - Runtime *runtime) { - TaskArgumentAccessor acc(task, regions, ctx, runtime); - forward_task_impl(acc); -} - -static optional backward_task_impl(TaskArgumentAccessor const &acc) { +static std::optional + backward_task_impl(TaskArgumentAccessor const &acc) { ProfilingSettings profiling = acc.get_argument(PROFILING); auto const &attrs = acc.get_argument(ATTRS); @@ -84,20 +72,12 @@ static optional backward_task_impl(TaskArgumentAccessor const &acc) { return profile(backward_kernel, profiling, - "[Concat] backward_time = %.2lfms\n", + "[Concat] backward_time = {:.2lf}ms\n", output_grad, input_grads, attrs.axis); } -static void backward_task(Task const *task, - std::vector const ®ions, - Context ctx, - Runtime *runtime) { - TaskArgumentAccessor acc(task, regions, ctx, runtime); - backward_task_impl(acc); -} - CostMetrics measure_operator_cost(SimEnvFactory const &sim, ConcatAttrs const &attrs, @@ -132,6 +112,7 @@ CostMetrics template <> OpTaskSignature fwd_signature() { OpTaskSignature fwd(OpTaskType::FWD); + fwd.add_arg_slot(ATTRS); fwd.add_arg_slot(PROFILING); fwd.add_input_slot(INPUTS, SlotType::VARIADIC); @@ -145,13 +126,13 @@ void register_task() { register_task(CONCAT_FWD_TASK_ID, "Concat Fwd", fwd_signature(), - forward_task); + forward_task_impl); } template <> OpTaskSignature bwd_signature() { OpTaskSignature bwd = - infer_bwd_signature(get_op_signature(CONCAT_FWD_TASK_ID)); + infer_bwd_signature(fwd_signature()); return bwd; } @@ -161,7 +142,7 @@ void register_task() { register_task(CONCAT_BWD_TASK_ID, "Concat Bwd", bwd_signature(), - 
backward_task); + backward_task_impl); } }; // namespace FlexFlow diff --git a/lib/runtime/src/ops/concat.h b/lib/local-execution/src/ops/concat.h similarity index 89% rename from lib/runtime/src/ops/concat.h rename to lib/local-execution/src/ops/concat.h index 27dec47743..fa61d87e77 100644 --- a/lib/runtime/src/ops/concat.h +++ b/lib/local-execution/src/ops/concat.h @@ -1,9 +1,9 @@ #ifndef _FLEXFLOW_CONCAT_H #define _FLEXFLOW_CONCAT_H +#include "local-execution/op_task_invocation.h" +#include "local-execution/sim_environment.h" #include "op-attrs/ops/concat.h" -#include "sim_environment.h" -#include "task_spec/op_task_invocation.h" namespace FlexFlow { diff --git a/lib/runtime/src/ops/conv_2d.cc b/lib/local-execution/src/ops/conv_2d.cc similarity index 78% rename from lib/runtime/src/ops/conv_2d.cc rename to lib/local-execution/src/ops/conv_2d.cc index 01d8abab55..eef4c21a45 100644 --- a/lib/runtime/src/ops/conv_2d.cc +++ b/lib/local-execution/src/ops/conv_2d.cc @@ -1,17 +1,10 @@ #include "conv_2d.h" #include "kernels/conv_2d_kernels.h" -#include "legion/legion_utilities.h" -#include "mpark/variant.hpp" #include "op-attrs/get_output_shapes.h" #include "utils/hash-utils.h" namespace FlexFlow { -using Legion::Context; -using Legion::PhysicalRegion; -using Legion::Runtime; -using Legion::Task; - using namespace FlexFlow::Kernels::Conv2D; enum Slots { @@ -70,33 +63,23 @@ static DeviceSpecific auto filter_grad = acc.get_tensor_grad(FILTER); DeviceSpecific per_device_state = - acc.create_device_specific( - init_kernel(handle, - attrs.activation, - attrs.kernel_h, - attrs.kernel_w, - attrs.groups, - attrs.padding_h, - attrs.padding_w, - attrs.stride_h, - attrs.stride_w, - input, - output, - filter.get_float_ptr(), - filter_grad.get_float_ptr())); + init_kernel(handle, + attrs.activation, + attrs.kernel_h, + attrs.kernel_w, + attrs.groups, + attrs.padding_h, + attrs.padding_w, + attrs.stride_h, + attrs.stride_w, + input, + output, + filter.get_float_ptr(), + filter_grad.get_float_ptr()); return per_device_state; } -static DeviceSpecific - init_task(Task const *task, - std::vector const ®ions, - Context ctx, - Runtime *runtime) { - TaskArgumentAccessor acc(task, regions, ctx, runtime); - return init_task_impl(acc); -} - -static optional forward_task_impl(TaskArgumentAccessor const &acc) { +static std::optional forward_task_impl(TaskArgumentAccessor const &acc) { ProfilingSettings profiling = acc.get_argument(PROFILING); auto per_device_state = acc.get_argument(PER_DEVICE_STATE); @@ -109,7 +92,7 @@ static optional forward_task_impl(TaskArgumentAccessor const &acc) { return profile(forward_kernel, profiling, - "[Conv2d] forward_time = %.2lfms\n", + "[Conv2d] forward_time = {:.2lf}ms\n", per_device_state, input.get_float_ptr(), output.get_float_ptr(), @@ -118,15 +101,8 @@ static optional forward_task_impl(TaskArgumentAccessor const &acc) { attrs.activation); } -static void forward_task(Task const *task, - std::vector const ®ions, - Context ctx, - Runtime *runtime) { - TaskArgumentAccessor acc(task, regions, ctx, runtime); - forward_task_impl(acc); -} - -static optional backward_task_impl(TaskArgumentAccessor const &acc) { +static std::optional + backward_task_impl(TaskArgumentAccessor const &acc) { ProfilingSettings profiling = acc.get_argument(PROFILING); auto per_device_state = acc.get_argument(PER_DEVICE_STATE); @@ -143,7 +119,7 @@ static optional backward_task_impl(TaskArgumentAccessor const &acc) { return profile(backward_kernel, profiling, - "[Conv2d] backward_time = %.2lfms\n", + "[Conv2d] 
backward_time = {:.2lf}ms\n", per_device_state, input.get_float_ptr(), input_grad.get_float_ptr(), @@ -155,14 +131,6 @@ static optional backward_task_impl(TaskArgumentAccessor const &acc) { attrs.activation); } -static void backward_task(Task const *task, - std::vector const ®ions, - Context ctx, - Runtime *runtime) { - TaskArgumentAccessor acc(task, regions, ctx, runtime); - backward_task_impl(acc); -} - CostMetrics measure_operator_cost(SimEnvFactory const &sim, Conv2DAttrs const &attrs, InputParallelTensorDesc const &input_shape, @@ -228,7 +196,7 @@ void register_task() { register_task(CONV2D_INIT_TASK_ID, "Conv2d Init", init_signature(), - init_task); + init_task_impl); } template <> @@ -252,7 +220,7 @@ void register_task() { register_task(CONV2D_FWD_TASK_ID, "Conv2d Fwd", fwd_signature(), - forward_task); + forward_task_impl); } template <> @@ -268,7 +236,7 @@ void register_task() { register_task(CONV2D_BWD_TASK_ID, "Conv2d Bwd", bwd_signature(), - backward_task); + backward_task_impl); } } // namespace FlexFlow diff --git a/lib/runtime/src/ops/conv_2d.h b/lib/local-execution/src/ops/conv_2d.h similarity index 89% rename from lib/runtime/src/ops/conv_2d.h rename to lib/local-execution/src/ops/conv_2d.h index 7225099a47..0c8181adce 100644 --- a/lib/runtime/src/ops/conv_2d.h +++ b/lib/local-execution/src/ops/conv_2d.h @@ -1,9 +1,9 @@ #ifndef _FLEXFLOW_CONV_2D_H #define _FLEXFLOW_CONV_2D_H +#include "local-execution/op_task_invocation.h" +#include "local-execution/sim_environment.h" #include "op-attrs/ops/conv_2d.h" -#include "sim_environment.h" -#include "task_spec/op_task_invocation.h" namespace FlexFlow { diff --git a/lib/runtime/src/ops/dropout.cc b/lib/local-execution/src/ops/dropout.cc similarity index 77% rename from lib/runtime/src/ops/dropout.cc rename to lib/local-execution/src/ops/dropout.cc index fe85afea38..9d680054ea 100644 --- a/lib/runtime/src/ops/dropout.cc +++ b/lib/local-execution/src/ops/dropout.cc @@ -1,18 +1,12 @@ #include "dropout.h" #include "kernels/dropout_kernels.h" -#include "legion/legion_utilities.h" +#include "local-execution/op_task_invocation.h" +#include "local-execution/op_task_signature.h" #include "op-attrs/get_output_shapes.h" -#include "task_spec/op_task_invocation.h" -#include "task_spec/task_signature.h" #include "utils/hash-utils.h" namespace FlexFlow { -using Legion::Context; -using Legion::PhysicalRegion; -using Legion::Runtime; -using Legion::Task; - using namespace FlexFlow::Kernels::Dropout; enum Slots { INPUT, OUTPUT, ATTRS, PER_DEVICE_STATE, FF_HANDLE, PROFILING }; @@ -54,21 +48,11 @@ static DeviceSpecific auto const &attrs = acc.get_argument(ATTRS); DeviceSpecific per_device_state = - acc.create_device_specific( - init_kernel(handle, attrs.rate, attrs.seed, output.shape, allocator)); + init_kernel(handle, attrs.rate, attrs.seed, output.shape, allocator); return per_device_state; } -static DeviceSpecific - init_task(Task const *task, - std::vector const ®ions, - Context ctx, - Runtime *runtime) { - TaskArgumentAccessor acc(task, regions, ctx, runtime); - return init_task_impl(acc); -} - -static optional forward_task_impl(TaskArgumentAccessor const &acc) { +static std::optional forward_task_impl(TaskArgumentAccessor const &acc) { auto per_device_state = acc.get_argument(PER_DEVICE_STATE); ProfilingSettings profiling = acc.get_argument(PROFILING); @@ -77,21 +61,14 @@ static optional forward_task_impl(TaskArgumentAccessor const &acc) { return profile(forward_kernel, profiling, - "[Dropout] forward_time = %.2lfms\n", + "[Dropout] forward_time 
= {:.2lf}ms\n", per_device_state, input.get_float_ptr(), output.get_float_ptr()); } -static void forward_task(Task const *task, - std::vector const ®ions, - Context ctx, - Runtime *runtime) { - TaskArgumentAccessor acc(task, regions, ctx, runtime); - forward_task_impl(acc); -} - -static optional backward_task_impl(TaskArgumentAccessor const &acc) { +static std::optional + backward_task_impl(TaskArgumentAccessor const &acc) { auto const &attrs = acc.get_argument(ATTRS); auto per_device_state = acc.get_argument(PER_DEVICE_STATE); @@ -102,20 +79,12 @@ static optional backward_task_impl(TaskArgumentAccessor const &acc) { return profile(backward_kernel, profiling, - "[Dropout] backward_time = %.2lfms\n", + "[Dropout] backward_time = {:.2lf}ms\n", per_device_state, output_grad.get_float_ptr(), input_grad.get_float_ptr()); } -static void backward_task(Task const *task, - std::vector const ®ions, - Context ctx, - Runtime *runtime) { - TaskArgumentAccessor acc(task, regions, ctx, runtime); - backward_task_impl(acc); -} - CostMetrics measure_operator_cost(SimEnvFactory const &sim, DropoutAttrs const &attrs, InputParallelTensorDesc const &input_shape, @@ -156,6 +125,7 @@ CostMetrics measure_operator_cost(SimEnvFactory const &sim, template <> OpTaskSignature init_signature() { OpTaskSignature init(OpTaskType::INIT); + init.add_arg_slot(ATTRS); init.add_unchecked_arg_slot(FF_HANDLE); init.add_output_slot(OUTPUT); @@ -170,7 +140,7 @@ void register_task() { register_task(DROPOUT_INIT_TASK_ID, "Dropout Init", init_signature(), - init_task); + init_task_impl); } template <> @@ -191,7 +161,7 @@ void register_task() { register_task(DROPOUT_FWD_TASK_ID, "Dropout Fwd", fwd_signature(), - forward_task); + forward_task_impl); } template <> @@ -207,7 +177,7 @@ void register_task() { register_task(DROPOUT_BWD_TASK_ID, "Dropout Bwd", bwd_signature(), - backward_task); + backward_task_impl); } }; // namespace FlexFlow diff --git a/lib/runtime/src/ops/dropout.h b/lib/local-execution/src/ops/dropout.h similarity index 85% rename from lib/runtime/src/ops/dropout.h rename to lib/local-execution/src/ops/dropout.h index 88a255d140..53fbeb3857 100644 --- a/lib/runtime/src/ops/dropout.h +++ b/lib/local-execution/src/ops/dropout.h @@ -1,10 +1,10 @@ #ifndef _FLEXFLOW_DROPOUT_H #define _FLEXFLOW_DROPOUT_H +#include "local-execution/op_task_invocation.h" +#include "local-execution/sim_environment.h" +#include "local-execution/tasks.h" #include "op-attrs/ops/dropout.h" -#include "sim_environment.h" -#include "task_spec/op_task_invocation.h" -#include "tasks.h" namespace FlexFlow { diff --git a/lib/runtime/src/ops/element_binary.cc b/lib/local-execution/src/ops/element_binary.cc similarity index 81% rename from lib/runtime/src/ops/element_binary.cc rename to lib/local-execution/src/ops/element_binary.cc index f6be2198ca..a2e9ee2ba8 100644 --- a/lib/runtime/src/ops/element_binary.cc +++ b/lib/local-execution/src/ops/element_binary.cc @@ -1,16 +1,11 @@ #include "element_binary.h" #include "kernels/element_binary_kernels.h" -#include "legion/legion_utilities.h" + #include "op-attrs/get_output_shapes.h" #include "utils/hash-utils.h" namespace FlexFlow { -using Legion::Context; -using Legion::PhysicalRegion; -using Legion::Runtime; -using Legion::Task; - using namespace FlexFlow::Kernels::ElementBinary; enum Slots { @@ -66,27 +61,17 @@ static DeviceSpecific auto const &attrs = acc.get_argument(ATTRS); DeviceSpecific per_device_state = - acc.create_device_specific( - init_kernel(handle, - attrs.type, - attrs.should_broadcast_lhs, - 
attrs.should_broadcast_rhs, - input_lhs.shape, - input_rhs.shape, - output.shape)); + init_kernel(handle, + attrs.type, + attrs.should_broadcast_lhs, + attrs.should_broadcast_rhs, + input_lhs.shape, + input_rhs.shape, + output.shape); return per_device_state; } -static DeviceSpecific - init_task(Task const *task, - std::vector const ®ions, - Context ctx, - Runtime *runtime) { - TaskArgumentAccessor acc(task, regions, ctx, runtime); - return init_task_impl(acc); -} - -static optional forward_task_impl(TaskArgumentAccessor const &acc) { +static std::optional forward_task_impl(TaskArgumentAccessor const &acc) { ProfilingSettings profiling = acc.get_argument(PROFILING); auto per_device_state = acc.get_argument(PER_DEVICE_STATE); @@ -99,7 +84,7 @@ static optional forward_task_impl(TaskArgumentAccessor const &acc) { return profile(forward_kernel, profiling, - "[ElementBinary] forward_time = %.2lfms\n", + "[ElementBinary] forward_time = {:.2lf}ms\n", per_device_state, input_lhs.get_float_ptr(), input_rhs.get_float_ptr(), @@ -109,15 +94,8 @@ static optional forward_task_impl(TaskArgumentAccessor const &acc) { handle); } -static void forward_task(Task const *task, - std::vector const ®ions, - Context ctx, - Runtime *runtime) { - TaskArgumentAccessor acc(task, regions, ctx, runtime); - forward_task_impl(acc); -} - -static optional backward_task_impl(TaskArgumentAccessor const &acc) { +static std::optional + backward_task_impl(TaskArgumentAccessor const &acc) { auto per_device_state = acc.get_argument(PER_DEVICE_STATE); ProfilingSettings profiling = acc.get_argument(PROFILING); @@ -133,7 +111,7 @@ static optional backward_task_impl(TaskArgumentAccessor const &acc) { return profile(backward_kernel, profiling, - "[ElementBinary] backward_time = %.2lfms\n", + "[ElementBinary] backward_time = {:.2lf}ms\n", per_device_state, output_grad.get_float_ptr(), input_lhs.get_float_ptr(), @@ -146,14 +124,6 @@ static optional backward_task_impl(TaskArgumentAccessor const &acc) { handle); } -static void backward_task(Task const *task, - std::vector const ®ions, - Context ctx, - Runtime *runtime) { - TaskArgumentAccessor acc(task, regions, ctx, runtime); - backward_task_impl(acc); -} - CostMetrics measure_operator_cost(SimEnvFactory const &sim, ElementBinaryAttrs const &attrs, @@ -221,7 +191,7 @@ void register_task() { register_task(ELEMENTBINARY_INIT_TASK_ID, "ElementBinary Init", init_signature(), - init_task); + init_task_impl); } template <> @@ -245,7 +215,7 @@ void register_task() { register_task(ELEMENTBINARY_FWD_TASK_ID, "ElementBinary Fwd", fwd_signature(), - forward_task); + forward_task_impl); } template <> @@ -261,7 +231,7 @@ void register_task() { register_task(ELEMENTBINARY_BWD_TASK_ID, "ElementBinary Bwd", bwd_signature(), - backward_task); + backward_task_impl); } }; // namespace FlexFlow diff --git a/lib/runtime/src/ops/element_binary.h b/lib/local-execution/src/ops/element_binary.h similarity index 95% rename from lib/runtime/src/ops/element_binary.h rename to lib/local-execution/src/ops/element_binary.h index 342909c468..fa4202dffd 100644 --- a/lib/runtime/src/ops/element_binary.h +++ b/lib/local-execution/src/ops/element_binary.h @@ -1,8 +1,8 @@ #ifndef _FLEXFLOW_ELEMENT_BINARY_H #define _FLEXFLOW_ELEMENT_BINARY_H +#include "local-execution/sim_environment.h" #include "op-attrs/ops/element_binary.h" -#include "sim_environment.h" namespace FlexFlow { diff --git a/lib/runtime/src/ops/element_unary.cc b/lib/local-execution/src/ops/element_unary.cc similarity index 76% rename from 
lib/runtime/src/ops/element_unary.cc rename to lib/local-execution/src/ops/element_unary.cc index f41a8b3551..2ad5d797f5 100644 --- a/lib/runtime/src/ops/element_unary.cc +++ b/lib/local-execution/src/ops/element_unary.cc @@ -1,15 +1,11 @@ #include "element_unary.h" #include "kernels/element_unary_kernels.h" -#include "legion/legion_utilities.h" +#include "op-attrs/get_output_shapes.h" #include "utils/hash-utils.h" namespace FlexFlow { // declare Legion names -using Legion::Context; -using Legion::PhysicalRegion; -using Legion::Runtime; -using Legion::Task; using namespace FlexFlow::Kernels::ElementUnary; @@ -27,7 +23,6 @@ enum Slots { OpTaskInvocation init(ElementUnaryUnifiedAttrs const &attrs) { OpTaskBinding b; - b.bind_arg(HANDLE, ff_handle()); b.bind_arg(ATTRS, attrs); b.bind_arg(INPUT_SHAPE, input_parallel_tensor_shape(0)); @@ -58,32 +53,21 @@ static DeviceSpecific auto const &attrs = acc.get_argument(ATTRS); ProfilingSettings profiling = acc.get_argument(PROFILING); - PerDeviceFFHandle handle = acc.get_argument(HANDLE); ParallelTensorShape input_shape = acc.get_argument(INPUT_SHAPE); ParallelTensorShape output_shape = get_output_shape(attrs, input_shape); - DeviceSpecific per_device_state = - acc.create_device_specific( - init_kernel(input_shape, output_shape, attrs)); + DeviceSpecific per_device_state = init_kernel( + get_piece_shape(input_shape), get_piece_shape(output_shape), attrs); return per_device_state; } -static DeviceSpecific - init_task(Task const *task, - std::vector const ®ions, - Context ctx, - Runtime *runtime) { - TaskArgumentAccessor acc(task, regions, ctx, runtime); - return init_task_impl(acc); -} - -static optional forward_task_impl(TaskArgumentAccessor const &acc) { +static std::optional forward_task_impl(TaskArgumentAccessor const &acc) { auto input = acc.get_tensor(INPUT); auto output = acc.get_tensor(OUTPUT); auto const &attrs = acc.get_argument(ATTRS); - auto &handle = acc.get_argument(HANDLE); + auto handle = acc.get_argument(HANDLE); ProfilingSettings profiling = acc.get_argument(PROFILING); auto per_device_state = @@ -91,7 +75,7 @@ static optional forward_task_impl(TaskArgumentAccessor const &acc) { return profile(forward_kernel, profiling, - "[ElementUnary] forward_time = %.2lfms\n", + "[ElementUnary] forward_time = {:.2lf}ms\n", per_device_state, attrs, handle, @@ -99,22 +83,15 @@ static optional forward_task_impl(TaskArgumentAccessor const &acc) { output); } -static void forward_task(Task const *task, - std::vector const ®ions, - Context ctx, - Runtime *runtime) { - TaskArgumentAccessor acc(task, regions, ctx, runtime); - forward_task_impl(acc); -} - -static optional backward_task_impl(TaskArgumentAccessor const &acc) { +static std::optional + backward_task_impl(TaskArgumentAccessor const &acc) { auto input = acc.get_tensor(INPUT); auto input_grad = acc.get_tensor_grad(INPUT); auto output = acc.get_tensor(OUTPUT); auto output_grad = acc.get_tensor_grad(OUTPUT); auto const &attrs = acc.get_argument(ATTRS); - auto &handle = acc.get_argument(HANDLE); + auto handle = acc.get_argument(HANDLE); auto per_device_state = acc.get_argument(PER_DEVICE_STATE); @@ -122,7 +99,7 @@ static optional backward_task_impl(TaskArgumentAccessor const &acc) { return profile(backward_kernel, profiling, - "[ElementUnary] backward_time = %.2lfms\n", + "[ElementUnary] backward_time = {:.2lf}ms\n", per_device_state, attrs, handle, @@ -132,14 +109,6 @@ static optional backward_task_impl(TaskArgumentAccessor const &acc) { output_grad); } -static void backward_task(Task const *task, - 
std::vector const ®ions, - Context ctx, - Runtime *runtime) { - TaskArgumentAccessor acc(task, regions, ctx, runtime); - backward_task_impl(acc); -} - CostMetrics measure_operator_cost(SimEnvFactory const &sim, ElementUnaryUnifiedAttrs const &attrs, InputParallelTensorDesc const &input_shape, @@ -147,7 +116,7 @@ CostMetrics measure_operator_cost(SimEnvFactory const &sim, MachineView const &mv) { auto env = sim.new_environment(); - ParallelTensorShape output_shape = get_output_shape(attrs, input_shape); + ParallelTensorShape output_shape = get_output_shape(attrs, input_shape.shape); SimTaskBinding init_binding; init_binding.bind_arg(HANDLE, ff_handle()); @@ -182,6 +151,7 @@ CostMetrics measure_operator_cost(SimEnvFactory const &sim, template <> OpTaskSignature init_signature() { OpTaskSignature init(OpTaskType::INIT); + init.add_arg_slot(INPUT_SHAPE); init.add_arg_slot(ATTRS); init.add_unchecked_arg_slot(HANDLE); @@ -196,7 +166,7 @@ void register_task() { register_task(ELEMENTUNARY_INIT_TASK_ID, "ElementUnary Init", init_signature(), - init_task); + init_task_impl); } template <> @@ -217,7 +187,7 @@ void register_task() { register_task(ELEMENTUNARY_FWD_TASK_ID, "ElementUnary Fwd", fwd_signature(), - forward_task); + forward_task_impl); } template <> @@ -233,7 +203,7 @@ void register_task() { register_task(ELEMENTUNARY_BWD_TASK_ID, "ElementUnary Bwd", bwd_signature(), - backward_task); + backward_task_impl); } } // namespace FlexFlow diff --git a/lib/runtime/src/ops/element_unary.h b/lib/local-execution/src/ops/element_unary.h similarity index 85% rename from lib/runtime/src/ops/element_unary.h rename to lib/local-execution/src/ops/element_unary.h index f44efc28db..e0f58e8a75 100644 --- a/lib/runtime/src/ops/element_unary.h +++ b/lib/local-execution/src/ops/element_unary.h @@ -1,15 +1,12 @@ #ifndef _ELEMENT_UNARY_H #define _ELEMENT_UNARY_H +#include "local-execution/op_task_invocation.h" +#include "local-execution/sim_environment.h" #include "op-attrs/ops/element_unary.h" -#include "op_task_invocation.h" -#include "sim_environment.h" namespace FlexFlow { -using ElementUnaryUnifiedAttrs = - variant; - template <> void register_task(); template <> diff --git a/lib/runtime/src/ops/embedding.cc b/lib/local-execution/src/ops/embedding.cc similarity index 81% rename from lib/runtime/src/ops/embedding.cc rename to lib/local-execution/src/ops/embedding.cc index a1bc915d2f..00d6d033d4 100644 --- a/lib/runtime/src/ops/embedding.cc +++ b/lib/local-execution/src/ops/embedding.cc @@ -15,17 +15,11 @@ #include "embedding.h" #include "kernels/embedding_kernels.h" -#include "legion.h" +#include "op-attrs/get_output_shapes.h" #include "op-attrs/ops/embedding.h" namespace FlexFlow { -// declare Legion names -using Legion::Context; -using Legion::PhysicalRegion; -using Legion::Runtime; -using Legion::Task; - using namespace FlexFlow::Kernels::Embedding; enum Slots { INPUT, WEIGHT, OUTPUT, ATTRS, PROFILING }; @@ -49,7 +43,7 @@ OpTaskInvocation backward(EmbeddingAttrs const &attrs) { return {EMBED_BWD_TASK_ID, b}; } -static optional forward_task_impl(TaskArgumentAccessor const &acc) { +static std::optional forward_task_impl(TaskArgumentAccessor const &acc) { auto input = acc.get_tensor(INPUT); auto weight = acc.get_tensor(WEIGHT); auto output = acc.get_tensor(OUTPUT); @@ -59,7 +53,7 @@ static optional forward_task_impl(TaskArgumentAccessor const &acc) { return profile(forward_kernel, profiling, - "[Embedding] forward_time = %.2lfms\n", + "[Embedding] forward_time = {:.2lf}ms\n", input, output, weight, @@ -71,15 
+65,8 @@ static optional forward_task_impl(TaskArgumentAccessor const &acc) { input.shape[legion_dim_t(1)]); } -static void forward_task(Task const *task, - std::vector const ®ions, - Context ctx, - Runtime *runtime) { - TaskArgumentAccessor acc(task, regions, ctx, runtime); - forward_task_impl(acc); -} - -static optional backward_task_impl(TaskArgumentAccessor const &acc) { +static std::optional + backward_task_impl(TaskArgumentAccessor const &acc) { auto input = acc.get_tensor(INPUT); auto output = acc.get_tensor(OUTPUT); auto weight_grad = acc.get_tensor_grad(WEIGHT); @@ -89,7 +76,7 @@ static optional backward_task_impl(TaskArgumentAccessor const &acc) { return profile(backward_kernel, profiling, - "[Embedding] forward_time = %.2lfms\n", + "[Embedding] backward_time = {:.2lf}ms\n", input, output, weight_grad, @@ -98,15 +85,7 @@ static optional backward_task_impl(TaskArgumentAccessor const &acc) { attrs.aggr, input.shape.get_dim(), output.shape.get_dim(), - input.shape[ff_dim_t(0)]); -} - -static void backward_task(Task const *task, - std::vector const ®ions, - Context ctx, - Runtime *runtime) { - TaskArgumentAccessor acc(task, regions, ctx, runtime); - backward_task_impl(acc); + input.shape.at(ff_dim_t(0))); } CostMetrics measure_operator_cost(SimEnvFactory const &sim, @@ -158,7 +137,7 @@ void register_task() { register_task(EMBED_FWD_TASK_ID, "Embed Fwd", fwd_signature(), - forward_task); + forward_task_impl); } template <> @@ -172,7 +151,7 @@ void register_task() { register_task(EMBED_BWD_TASK_ID, "Embed Bwd", bwd_signature(), - backward_task); + backward_task_impl); } } // namespace FlexFlow diff --git a/lib/runtime/src/ops/embedding.h b/lib/local-execution/src/ops/embedding.h similarity index 88% rename from lib/runtime/src/ops/embedding.h rename to lib/local-execution/src/ops/embedding.h index cd1b14fa66..c33b1161bf 100644 --- a/lib/runtime/src/ops/embedding.h +++ b/lib/local-execution/src/ops/embedding.h @@ -1,9 +1,9 @@ #ifndef _FLEXFLOW_EMBEDDING_H #define _FLEXFLOW_EMBEDDING_H +#include "local-execution/op_task_invocation.h" +#include "local-execution/sim_environment.h" #include "op-attrs/ops/embedding.h" -#include "sim_environment.h" -#include "task_spec/op_task_invocation.h" namespace FlexFlow { diff --git a/lib/runtime/src/ops/flat.cc b/lib/local-execution/src/ops/flat.cc similarity index 76% rename from lib/runtime/src/ops/flat.cc rename to lib/local-execution/src/ops/flat.cc index f53a6185b6..3c2499da79 100644 --- a/lib/runtime/src/ops/flat.cc +++ b/lib/local-execution/src/ops/flat.cc @@ -5,10 +5,6 @@ namespace FlexFlow { // declare Legion names -using Legion::Context; -using Legion::PhysicalRegion; -using Legion::Runtime; -using Legion::Task; using namespace FlexFlow::Kernels::Flat; @@ -30,27 +26,20 @@ OpTaskInvocation backward(FlatAttrs const &attrs) { return {FLAT_BWD_TASK_ID, b}; } -static optional forward_task_impl(TaskArgumentAccessor const &acc) { +static std::optional forward_task_impl(TaskArgumentAccessor const &acc) { ProfilingSettings profiling = acc.get_argument(PROFILING); auto input = acc.get_tensor(INPUT); auto output = acc.get_tensor(OUTPUT); return profile(forward_kernel, profiling, - "[Flat] forward_time = %.2lfms\n", + "[Flat] forward_time = {:.2lf}ms\n", input, output.get_float_ptr()); } -static void forward_task(Task const *task, - std::vector const ®ions, - Context ctx, - Runtime *runtime) { - TaskArgumentAccessor acc(task, regions, ctx, runtime); - forward_task_impl(acc); -} - -static optional backward_task_impl(TaskArgumentAccessor const &acc) { 
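// Note on the conversion applied to every op in this patch: the
// Legion-specific wrappers (forward_task / backward_task taking Task const *,
// std::vector<PhysicalRegion> const &, Context, Runtime *) are deleted, and
// the pure *_task_impl(TaskArgumentAccessor const &) functions are registered
// directly. A minimal sketch of the shape each converted forward impl takes,
// illustrative only: `Op`, `forward_kernel`, and the slot names are
// placeholders, the Permissions values are assumptions about the accessor
// API, and `profile` is assumed to run the kernel and return the elapsed
// milliseconds only when profiling is enabled:
//
//   static std::optional<float>
//       forward_task_impl(TaskArgumentAccessor const &acc) {
//     ProfilingSettings profiling =
//         acc.get_argument<ProfilingSettings>(PROFILING);
//     auto input = acc.get_tensor<Permissions::RO>(INPUT);
//     auto output = acc.get_tensor<Permissions::WO>(OUTPUT);
//     return profile(forward_kernel,
//                    profiling,
//                    "[Op] forward_time = {:.2lf}ms\n",
//                    input,
//                    output);
//   }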
+static std::optional + backward_task_impl(TaskArgumentAccessor const &acc) { ProfilingSettings profiling = acc.get_argument(PROFILING); auto input = acc.get_tensor(INPUT); @@ -59,20 +48,12 @@ static optional backward_task_impl(TaskArgumentAccessor const &acc) { return profile(backward_kernel, profiling, - "[Flat] forward_time = %.2lfms\n", + "[Flat] backward_time = {:.2lf}ms\n", input, input_grad.get_float_ptr(), output_grad.get_float_ptr()); } -static void backward_task(Task const *task, - std::vector const ®ions, - Context ctx, - Runtime *runtime) { - TaskArgumentAccessor acc(task, regions, ctx, runtime); - backward_task_impl(acc); -} - CostMetrics measure_operator_cost(SimEnvFactory const &sim, FlatAttrs const &attrs, InputParallelTensorDesc const &input_shape, @@ -115,7 +96,7 @@ void register_task() { register_task(FLAT_FWD_TASK_ID, "Flat Fwd", fwd_signature(), - forward_task); + forward_task_impl); } template <> @@ -130,7 +111,7 @@ void register_task() { register_task(FLAT_BWD_TASK_ID, "Flat Bwd", bwd_signature(), - backward_task); + backward_task_impl); } }; // namespace FlexFlow diff --git a/lib/runtime/src/ops/flat.h b/lib/local-execution/src/ops/flat.h similarity index 93% rename from lib/runtime/src/ops/flat.h rename to lib/local-execution/src/ops/flat.h index 13246028fb..d9ea4d3985 100644 --- a/lib/runtime/src/ops/flat.h +++ b/lib/local-execution/src/ops/flat.h @@ -1,8 +1,8 @@ #ifndef _FLEXFLOW_FLAT_H #define _FLEXFLOW_FLAT_H +#include "local-execution/sim_environment.h" #include "op-attrs/ops/flat.h" -#include "sim_environment.h" namespace FlexFlow { diff --git a/lib/local-execution/src/ops/gather.cc b/lib/local-execution/src/ops/gather.cc new file mode 100644 index 0000000000..50b27d72a6 --- /dev/null +++ b/lib/local-execution/src/ops/gather.cc @@ -0,0 +1,215 @@ +/* Copyright 2023 CMU, Facebook, LANL, MIT, NVIDIA, and Stanford (alphabetical) + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "gather.h" +#include "kernels/gather_kernels.h" +#include "local-execution/legion_tensor_shape.h" +#include "op-attrs/get_output_shapes.h" +#include + +namespace FlexFlow { + +using namespace FlexFlow::Kernels::Gather; + +enum Slots { INPUT, OUTPUT, INDEX, ATTRS, HANDLE, PROFILING, PER_DEVICE_STATE }; + +OpTaskInvocation init(GatherAttrs const &attrs) { + OpTaskBinding binding; + + binding.bind(INPUT, input_tensor(0)); + binding.bind(INDEX, input_tensor(1)); + binding.bind(OUTPUT, output_tensor(0)); + binding.bind_arg(ATTRS, attrs); + binding.bind_arg(HANDLE, ff_handle()); + + return {GATHER_INIT_TASK_ID, binding}; +} + +OpTaskInvocation forward(GatherAttrs const &attrs) { + OpTaskBinding binding; + + binding.bind_arg(ATTRS, attrs); + binding.bind_arg(PROFILING, profiling_settings()); + binding.bind_arg(PER_DEVICE_STATE, + per_device_op_state()); + + binding.bind(INPUT, input_tensor(0)); + binding.bind(OUTPUT, output_tensor(0)); + binding.bind(INDEX, weight_tensor(0)); + + return {GATHER_FWD_TASK_ID, binding}; +} + +OpTaskInvocation backward(GatherAttrs const &attrs) { + OpTaskBinding binding = infer_bwd_binding(forward(attrs).binding); + + return {GATHER_BWD_TASK_ID, binding}; +} + +static DeviceSpecific + init_task_impl(TaskArgumentAccessor const &acc) { + auto input = acc.get_tensor(INPUT); + auto index = acc.get_tensor(INDEX); + auto output = acc.get_tensor(OUTPUT); + + PerDeviceFFHandle handle = acc.get_argument(HANDLE); + auto const &attrs = acc.get_argument(ATTRS); + legion_dim_t legion_dim = to_legion(attrs.dim, input.shape.num_dims()); + + assert(input.shape.get_dim() == index.shape.get_dim()); + assert(output.shape.get_dim() == index.shape.get_dim()); + + for (int i = 0; i < input.shape.get_dim(); i++) { + assert(index.shape[legion_dim_t(i)] == output.shape[legion_dim_t(i)]); + if (i != legion_dim.value()) { + assert(input.shape[legion_dim_t(i)] == index.shape[legion_dim_t(i)]); + } + } + + return DeviceSpecific({handle, legion_dim}); +} + +static std::optional forward_task_impl(TaskArgumentAccessor const &acc) { + ProfilingSettings profiling = acc.get_argument(PROFILING); + auto per_device_state = + acc.get_argument(PER_DEVICE_STATE); + + auto input = acc.get_tensor(INPUT); + auto index = acc.get_tensor(INDEX); + auto output = acc.get_tensor(OUTPUT); + + return profile(forward_kernel, + profiling, + "[Gather] forward_time = {:.2lf}ms\n", + per_device_state, + input, + index, + output); +} + +static std::optional + backward_task_impl(TaskArgumentAccessor const &acc) { + ProfilingSettings profiling = acc.get_argument(PROFILING); + auto per_device_state = + acc.get_argument(PER_DEVICE_STATE); + + auto output_grad = acc.get_tensor_grad(OUTPUT); + auto index = acc.get_tensor(INDEX); + auto input_grad = acc.get_tensor_grad(INPUT); + + return profile(backward_kernel, + profiling, + "[Gather] backward_time = {:.2lf}ms\n", + per_device_state, + output_grad, + index, + input_grad); +} + +CostMetrics measure_operator_cost(SimEnvFactory const &sim, + GatherAttrs const &attrs, + InputParallelTensorDesc const &input_shape, + InputParallelTensorDesc const &index_shape, + ProfilingSettings const &settings, + MachineView const &mv) { + + auto env = sim.new_environment(); + + std::vector output_shape = + get_output_shapes(attrs, input_shape.shape, index_shape.shape); + + SimTaskBinding fwd_binding; + fwd_binding.bind_arg(PROFILING, settings); + fwd_binding.bind_arg(ATTRS, attrs); + + fwd_binding.bind(INPUT, input_shape); + fwd_binding.bind(OUTPUT, output_shape); + 
fwd_binding.bind(INDEX, index_shape); + + SimTaskBinding bwd_binding = infer_bwd_binding(fwd_binding); + + auto fwd_accessor = env.get_fwd_accessor(GATHER_FWD_TASK_ID, fwd_binding); + auto bwd_accessor = env.get_bwd_accessor(GATHER_BWD_TASK_ID, bwd_binding); + + float forward_time = forward_task_impl(fwd_accessor).value(); + float backward_time = backward_task_impl(bwd_accessor).value(); + + float sync_time = default_estimate_sync_time(env); + return make_metrics(forward_time, backward_time, sync_time, env); +} + +template <> +OpTaskSignature init_signature() { + OpTaskSignature init(OpTaskType::INIT); + + init.add_input_slot(INPUT); + init.add_input_slot(INDEX); + init.add_output_slot(OUTPUT); + + init.add_arg_slot(ATTRS); + init.add_unchecked_arg_slot(HANDLE); + + init.add_return_value(); + + return init; +} + +template <> +void register_task() { + register_task(GATHER_INIT_TASK_ID, + "Gather Init", + init_signature(), + init_task_impl); +} + +template <> +OpTaskSignature fwd_signature() { + OpTaskSignature fwd(OpTaskType::FWD); + + fwd.add_arg_slot(PROFILING); + fwd.add_arg_slot(ATTRS); + + fwd.add_input_slot(INPUT); + fwd.add_output_slot(OUTPUT); + fwd.add_weight_slot(INDEX); + + return fwd; +} + +template <> +void register_task() { + register_task(GATHER_FWD_TASK_ID, + "Gather Fwd", + fwd_signature(), + forward_task_impl); +} + +template <> +OpTaskSignature bwd_signature() { + OpTaskSignature bwd = + infer_bwd_signature(fwd_signature()); + + return bwd; +} + +template <> +void register_task() { + register_task(GATHER_BWD_TASK_ID, + "Gather Bwd", + bwd_signature(), + backward_task_impl); +} + +}; // namespace FlexFlow diff --git a/lib/local-execution/src/ops/gather.h b/lib/local-execution/src/ops/gather.h new file mode 100644 index 0000000000..e2de09d96a --- /dev/null +++ b/lib/local-execution/src/ops/gather.h @@ -0,0 +1,30 @@ +#ifndef _FLEXFLOW_GATHER_H +#define _FLEXFLOW_GATHER_H + +#include "local-execution/op_task_invocation.h" +#include "local-execution/sim_environment.h" +#include "op-attrs/ops/gather.h" + +namespace FlexFlow { + +template <> +void register_task(); +template <> +void register_task(); +template <> +void register_task(); + +OpTaskInvocation init(GatherAttrs const &); +OpTaskInvocation forward(GatherAttrs const &); +OpTaskInvocation backward(GatherAttrs const &); + +CostMetrics measure_operator_cost(SimEnvFactory const &sim_factory, + GatherAttrs const &attrs, + InputParallelTensorDesc const &input, + InputParallelTensorDesc const &index, + ProfilingSettings const &settings, + MachineView const &machine_view); + +} // namespace FlexFlow + +#endif diff --git a/lib/runtime/src/ops/layer_norm.cc b/lib/local-execution/src/ops/layer_norm.cc similarity index 61% rename from lib/runtime/src/ops/layer_norm.cc rename to lib/local-execution/src/ops/layer_norm.cc index 6bc671c249..620758772c 100644 --- a/lib/runtime/src/ops/layer_norm.cc +++ b/lib/local-execution/src/ops/layer_norm.cc @@ -15,20 +15,17 @@ #include "layer_norm.h" #include "kernels/layer_norm_kernels.h" -#include "legion/legion_utilities.h" +#include "op-attrs/get_output_shapes.h" #include "op-attrs/ops/layer_norm.h" #include "op-attrs/parallel_tensor_shape.h" -#include "utils/exceptions.h" +#include "utils/exception.h" #include "utils/hash-utils.h" #include -using Legion::Context; -using Legion::PhysicalRegion; -using Legion::Runtime; -using Legion::Task; - namespace FlexFlow { +using namespace FlexFlow::Kernels::LayerNorm; + enum Slots { PROFILING, INPUT, @@ -59,7 +56,7 @@ OpTaskInvocation 
forward(LayerNormAttrs const &attrs) {
   b.bind(GAMMA, weight_tensor(0)); // todo: this may have some problems
   b.bind(BETA, weight_tensor(1));  // how to get gamma and beta?
   b.bind_arg(PROFILING, profiling_settings());
-  b.bind_arg(PER_DEVICE_STATE, per_device_state());
+  b.bind_arg(PER_DEVICE_STATE, per_device_op_state());
 
   return {LAYERNORM_FWD_TASK_ID, b};
 }
@@ -70,71 +67,56 @@ OpTaskInvocation backward(LayerNormAttrs const &attrs) {
   return {LAYERNORM_BWD_TASK_ID, b};
 }
 
-static optional forward_task_impl(TaskArgumentAccessor const &acc) {
-  auto input = acc.get_tensor(INPUT);
-  auto output = acc.get_tensor(OUTPUT);
-  auto gamma = acc.get_tensor(GAMMA);
-  auto beta = acc.get_tensor(BETA);
+static std::optional forward_task_impl(TaskArgumentAccessor const &acc) {
+  auto input = acc.get_tensor(INPUT);
+  auto output = acc.get_tensor(OUTPUT);
+  auto gamma = acc.get_tensor(GAMMA);
+  auto beta = acc.get_tensor(BETA);
   ProfilingSettings profiling = acc.get_argument(PROFILING);
   auto &state = acc.get_argument(PER_DEVICE_STATE);
 
   return profile(forward_kernel,
                  profiling,
-                 "[LayerNorm] forward time = %.2lfms\n",
+                 "[LayerNorm] forward time = {:.2lf}ms\n",
                  state,
-                 input.get_float_ptr(),
-                 output.get_float_ptr(),
-                 gamma.get_float_ptr(),
-                 beta.get_float_ptr());
+                 input,
+                 output,
+                 gamma,
+                 beta);
 }
 
-static void forward_task(Task const *task,
-                         std::vector const &regions,
-                         Context ctx,
-                         Runtime *runtime) {
-  TaskArgumentAccessor acc(task, regions, ctx, runtime);
-  forward_task_impl(acc);
-}
+static std::optional
+    backward_task_impl(TaskArgumentAccessor const &acc) {
+  auto input = acc.get_tensor(INPUT);
+  auto gamma = acc.get_tensor(GAMMA);
 
-static optional backward_task_impl(TaskArgumentAccessor const &acc) {
-  auto input = acc.get_tensor(INPUT);
-  auto gamma = acc.get_tensor(GAMMA);
-
-  auto input_grad = acc.get_tensor(INPUT_GRAD);
-  auto gamma_grad = acc.get_tensor(GAMMA_GRAD);
-  auto beta_grad = acc.get_tensor(BETA_GRAD);
-  auto output_grad = acc.get_tensor(OUTPUT_GRAD);
+  auto input_grad = acc.get_tensor_grad(INPUT);
+  auto gamma_grad = acc.get_tensor_grad(GAMMA);
+  auto beta_grad = acc.get_tensor_grad(BETA);
+  auto output_grad = acc.get_tensor_grad(OUTPUT);
   ProfilingSettings profiling = acc.get_argument(PROFILING);
   auto &state = acc.get_argument(PER_DEVICE_STATE);
 
   return profile(backward_kernel,
                  profiling,
-                 "[LayerNorm] backward time = %.2lfms\n",
+                 "[LayerNorm] backward time = {:.2lf}ms\n",
                  state,
-                 output_grad.get_float_ptr(),
-                 input.get_float_ptr(),
-                 input_grad.get_float_ptr(),
-                 gamma.get_float_ptr(),
-                 gamma_grad.get_float_ptr(),
-                 beta_grad.get_float_ptr());
-}
-
-static void backward_task(Task const *task,
-                          std::vector const &regions,
-                          Context ctx,
-                          Runtime *runtime) {
-  TaskArgumentAccessor acc(task, regions, ctx, runtime);
-  backward_task_impl(acc);
+                 output_grad,
+                 input,
+                 input_grad,
+                 gamma,
+                 gamma_grad,
+                 beta_grad);
 }
 
 static DeviceSpecific
     init_task_impl(TaskArgumentAccessor const &acc) {
-  auto const &attrs = acc.get_argument(ATTRS);
+  auto const &attrs = acc.get_argument(ATTRS);
   Allocator allocator = acc.get_allocator();
-  auto input = acc.get_tensor(INPUT);
-  FFHandler handle = acc.get_argument(HANDLE);
+  auto input = acc.get_tensor(INPUT);
+  auto handle = acc.get_argument(HANDLE);
 
   // question: how to get batch_size and effective_num_elements
   int64_t effective_batch_size, effective_num_elements;
@@ -143,29 +125,20 @@ static DeviceSpecific
     M *= input.shape.at(legion_dim_t(attrs.axes[i]));
   }
   int num_replicas = 1;
-  for (int i = 0; i < intput.shape.num_dims(); i++) {
+  for (int i 
= 0; i < input.shape.num_dims(); i++) { num_replicas *= input.shape.at(legion_dim_t(i)); effective_num_elements = M; effective_batch_size = input.shape.get_volume() / M; - - DeviceSpecific per_device_state = - acc.create_device_specific( - init_kernel(handle, - allocator, - attrs.elementwise_affine, - effective_batch_size, - effective_num_elements, - attrs.eps)); } -} -static DeviceSpecific - init_task(Task const *task, - std::vector const ®ions, - Context ctx, - Runtime *runtime) { - TaskArgumentAccessor acc(task, regions, ctx, runtime); - return init_task_impl(acc); + DeviceSpecific per_device_state = + init_kernel(handle, + allocator, + attrs.elementwise_affine, + effective_batch_size, + effective_num_elements, + attrs.eps); + return per_device_state; } CostMetrics measure_operator_cost(SimEnvFactory const &sim_factory, @@ -173,18 +146,19 @@ CostMetrics measure_operator_cost(SimEnvFactory const &sim_factory, InputParallelTensorDesc const &input, ProfilingSettings const &settings, MachineView const &machine_view) { - auto env = sim.new_environment(); + auto env = sim_factory.new_environment(); ParallelTensorShape output_shape = get_output_shape(attrs, input.shape); SimTaskBinding init_binding; init_binding.bind_arg(HANDLE, ff_handle()); init_binding.bind_arg(ATTRS, attrs); - init.binding.bind(INPUT, input.shape); + init_binding.bind(INPUT, input.shape); auto init_accessor = env.get_init_accessor(LAYERNORM_INIT_TASK_ID, init_binding); - DeviceSpecific = init_task_impl(init_accessor); + DeviceSpecific per_device_state = + init_task_impl(init_accessor); SimTaskBinding fwd_binding; fwd_binding.bind(INPUT, input.shape); @@ -192,9 +166,8 @@ CostMetrics measure_operator_cost(SimEnvFactory const &sim_factory, fwd_binding.bind_arg(PROFILING, settings); fwd_binding.bind_arg(PER_DEVICE_STATE, per_device_state); - // TODO how to handle gamma and beta, where are they from - fwd_binding.bind(GAMMA, input_shape); - fwd_binding.bind(BETA, input_shape); + fwd_binding.bind(GAMMA, input.shape); + fwd_binding.bind(BETA, input.shape); SimTaskBinding bwd_binding = infer_bwd_binding(fwd_binding); auto fwd_accessor = env.get_fwd_accessor(LAYERNORM_FWD_TASK_ID, fwd_binding); @@ -222,7 +195,7 @@ OpTaskSignature fwd_signature() { } template <> -OpTaskSignature bwd_signature() { +OpTaskSignature bwd_signature() { OpTaskSignature bwd = infer_bwd_signature(fwd_signature()); return bwd; @@ -231,6 +204,7 @@ OpTaskSignature bwd_signature() { template <> OpTaskSignature init_signature() { OpTaskSignature init(OpTaskType::INIT); + init.add_input_slot(INPUT); init.add_arg_slot(ATTRS); init.add_unchecked_arg_slot(HANDLE); @@ -245,7 +219,7 @@ void register_task() { register_task(LAYERNORM_INIT_TASK_ID, "LayerNorm init", init_signature(), - init_task); + init_task_impl); } template <> @@ -253,15 +227,15 @@ void register_task() { register_task(LAYERNORM_FWD_TASK_ID, "LayerNorm forward", fwd_signature(), - forward_task); + forward_task_impl); } template <> void register_task() { register_task(LAYERNORM_BWD_TASK_ID, "LayerNorm backward", - bwd_signature(), - backward_task); + bwd_signature(), + backward_task_impl); } } // namespace FlexFlow diff --git a/lib/runtime/src/ops/layer_norm.h b/lib/local-execution/src/ops/layer_norm.h similarity index 97% rename from lib/runtime/src/ops/layer_norm.h rename to lib/local-execution/src/ops/layer_norm.h index 83e6733bf6..4eadb9ff09 100644 --- a/lib/runtime/src/ops/layer_norm.h +++ b/lib/local-execution/src/ops/layer_norm.h @@ -1,9 +1,9 @@ #ifndef _FLEXFLOW_RUNTIME_SRC_OPS_LAYER_NORM_H 
#define _FLEXFLOW_RUNTIME_SRC_OPS_LAYER_NORM_H +#include "local-execution/op_task_invocation.h" +#include "local-execution/sim_environment.h" #include "op-attrs/ops/layer_norm.h" -#include "op_task_invocation.h" -#include "sim_environment.h" namespace FlexFlow { diff --git a/lib/runtime/src/ops/linear.cc b/lib/local-execution/src/ops/linear.cc similarity index 54% rename from lib/runtime/src/ops/linear.cc rename to lib/local-execution/src/ops/linear.cc index 96d037913c..e2c9d9aef4 100644 --- a/lib/runtime/src/ops/linear.cc +++ b/lib/local-execution/src/ops/linear.cc @@ -1,32 +1,14 @@ #include "linear.h" #include "kernels/linear_kernels.h" -#include "layer.h" -#include "legion/legion_utilities.h" +#include "local-execution/task_argument_accessor.h" #include "op-attrs/ff_dim.h" #include "op-attrs/get_output_shapes.h" -#include "utils/exceptions.h" +#include "utils/exception.h" #include "utils/graph/views.h" #include "utils/hash-utils.h" namespace FlexFlow { -// declare Legion names -using Legion::ArgumentMap; -using Legion::Context; -using Legion::coord_t; -using Legion::Domain; -using Legion::FutureMap; -using Legion::IndexLauncher; -using Legion::InlineLauncher; -using Legion::PhysicalRegion; -using Legion::Predicate; -using Legion::Rect; -using Legion::RegionRequirement; -using Legion::Runtime; -using Legion::Task; -using Legion::TaskArgument; -using Legion::TaskLauncher; - using namespace FlexFlow::Kernels::Linear; enum slots { @@ -43,12 +25,12 @@ enum slots { OpTaskInvocation init(LinearAttrs const &attrs) { OpTaskBinding binding; - bind.bind_arg(HANDLE, ff_handle()); - bind.bind_arg(ATTRS, attrs); + binding.bind_arg(HANDLE, ff_handle()); + binding.bind_arg(ATTRS, attrs); - bind.bind(INPUT, input_tensor(0)); // input - bind.bind(WEIGHT, weight_tensor(0)); // weight - bind.bind(OUTPUT, output_tensor(0)); // output + binding.bind(INPUT, input_tensor(0)); // input + binding.bind(WEIGHT, weight_tensor(0)); // weight + binding.bind(OUTPUT, output_tensor(0)); // output return {LINEAR_INIT_TASK_ID, binding}; } @@ -56,14 +38,17 @@ OpTaskInvocation init(LinearAttrs const &attrs) { OpTaskInvocation forward(LinearAttrs const &attrs) { OpTaskBinding binding; - bind.bind(INPUT, input_tensor(0)); // input - bind.bind(WEIGHT, weight_tensor(0)); // weight - bind.bind(OUTPUT, output_tensor(0)); // output - bind.bind(BIAS, bias_tensor(0)); // bias + binding.bind(INPUT, input_tensor(0)); // input + binding.bind(WEIGHT, weight_tensor(0)); // weight + binding.bind(OUTPUT, output_tensor(0)); // output + if (attrs.use_bias) { + binding.bind(BIAS, weight_tensor(1)); // bias + } - bing.bind_arg(PROFILING, profiling_settings()); - bind.bind_arg(PER_DEVICE_STATE, per_device_state()); - bind.bind_arg(ATTRS, attrs); + binding.bind_arg(PROFILING, profiling_settings()); + binding.bind_arg(PER_DEVICE_STATE, + per_device_op_state()); + binding.bind_arg(ATTRS, attrs); return {LINEAR_FWD_TASK_ID, binding}; } @@ -74,51 +59,38 @@ OpTaskInvocation backward(LinearAttrs const &attrs) { return {LINEAR_BWD_TASK_ID, b}; } -static DeviceSpecific - init_task_impl(TaskArgumentAccessor const &acc) { - auto const &attrs = acc.get_argument(ATTRS); - Allocator allocator = acc.get_allocator(); +static LinearPerDeviceState init_task_impl(TaskArgumentAccessor const &acc) { + auto const &attrs = acc.get_argument(ATTRS); PerDeviceFFHandle handle = acc.get_argument(HANDLE); auto input = acc.get_tensor(INPUT); auto weight = acc.get_tensor(WEIGHT); auto output = acc.get_tensor(OUTPUT); int out_dim = output.shape.at(ff_dim_t{0}); - int 
batch_size = output.shape.at.(ff_dim_t{1}); + int batch_size = output.shape.at(ff_dim_t{1}); float *one_ptr; - DeviceSpecific state = - acc.create_device_specific( - init_kernel(handle, - allocator, - one_ptr, - attrs.regularizer, - attrs.use_bias, - input.data_type, - weight.data_type, - output.data_type, - batch_size, - attrs.out_channels)); + LinearPerDeviceState state = init_kernel(handle, + one_ptr, + attrs.regularizer, + attrs.use_bias, + input.data_type, + weight.data_type, + output.data_type, + batch_size, + attrs.out_channels); return state; } -static DeviceSpecific - init_task(Task const *task, - std::vector const ®ions, - Context ctx, - Runtime *runtime) { - TaskArgumentAccessor acc(task, regions, ctx, runtime); - return init_task_impl(acc); -} - -static optional forward_task_impl(TaskArgumentAccessor const &acc) { +static std::optional forward_task_impl(TaskArgumentAccessor const &acc) { auto input = acc.get_tensor(INPUT); auto weight = acc.get_tensor(WEIGHT); auto output = acc.get_tensor(OUTPUT); auto bias = acc.get_tensor(BIAS); - auto state = acc.get_device_specific(PER_DEVICE_STATE); + auto per_device_state = + acc.get_argument(PER_DEVICE_STATE); ProfilingSettings profiling = acc.get_argument(PROFILING); auto attrs = acc.get_argument(ATTRS); @@ -133,7 +105,7 @@ static optional forward_task_impl(TaskArgumentAccessor const &acc) { return profile(forward_kernel, profiling, - "[Linear] forward_time = %.2lfms\n", + "[Linear] forward_time = {:.2lf}ms\n", per_device_state, input.get_float_ptr(), output.get_float_ptr(), @@ -144,15 +116,10 @@ static optional forward_task_impl(TaskArgumentAccessor const &acc) { batch_size); } -static void forward_task(Task const *task, - std::vector const ®ions, - Context ctx, - Runtime *runtime) { - TaskArgumentAccessor acc(task, regions, ctx, runtime); - forward_task_impl(acc); -}; +; -static optional backward_task_impl(TaskArgumentAccessor const &acc) { +static std::optional + backward_task_impl(TaskArgumentAccessor const &acc) { auto input = acc.get_tensor(INPUT); auto weight = acc.get_tensor(WEIGHT); auto output = acc.get_tensor(OUTPUT); @@ -161,7 +128,8 @@ static optional backward_task_impl(TaskArgumentAccessor const &acc) { auto input_grad = acc.get_tensor_grad(INPUT); auto weight_grad = acc.get_tensor_grad(WEIGHT); auto output_grad = acc.get_tensor_grad(OUTPUT); - auto per_device_state = acc.get_argument(PER_DEVICE_STATE); + auto per_device_state = + acc.get_argument(PER_DEVICE_STATE); ProfilingSettings profiling = acc.get_argument(PROFILING); auto attrs = acc.get_argument(ATTRS); @@ -176,65 +144,63 @@ static optional backward_task_impl(TaskArgumentAccessor const &acc) { return profile(backward_kernel, profiling, - "[Linear] backward_time = %.2lfms\n", + "[Linear] backward_time = {:.2lf}ms\n", per_device_state, - input.get_float_ptr(), - input_grad.get_float_ptr(), - output.get_float_ptr(), - output_grad.get_float_ptr(), - weight.get_float_ptr(), - weight_grad.get_float_ptr(), - bias_ptr, + (void *)input.get_float_ptr(), + (void *)input_grad.get_float_ptr(), + (void *)output.get_float_ptr(), + (void *)output_grad.get_float_ptr(), + (void *)weight.get_float_ptr(), + (void *)weight_grad.get_float_ptr(), + (void *)bias_ptr, in_dim, out_dim, batch_size); } -static void backward_task(Task const *task, - std::vector const ®ions, - Context ctx, - Runtime *runtime) { - TaskArgumentAccessor acc(task, regions, ctx, runtime); - backward_task_impl(acc); -} - CostMetrics measure_operator_cost(SimEnvFactory const &sim_factory, LinearAttrs const &attrs, 
InputParallelTensorDesc const &input, ProfilingSettings const &settings, MachineView const &machine_view) { - auto env = sim.new_environment(); + auto env = sim_factory.new_environment(); - ParallelTensorShape output_shape = get_output_shape(input.shape, attrs); + ParallelTensorShape output_shape = get_output_shape(attrs, input.shape); + ParallelTensorShape weight_shape = get_weights_shape(attrs, input.shape); + ParallelTensorShape bias_shape = get_bias_shape(attrs, input.shape); SimTaskBinding init_binding; - init_binding.bind(INPUT, input_tensor(0)); - init_binding.bind(WEIGHT, weight_tensor(0)); - init_binding.bind(BIAS, bias_tensor(0)); - init_binding.bind(OUTPUT, output_tensor(0)); + init_binding.bind(INPUT, input.shape); + init_binding.bind(WEIGHT, weight_shape); + if (attrs.use_bias) { + init_binding.bind(BIAS, bias_shape); + } + init_binding.bind(OUTPUT, output_shape); init_binding.bind_arg(ATTRS, attrs); init_binding.bind_arg(HANDLE, ff_handle()); auto init_accessor = env.get_init_accessor(LINEAR_INIT_TASK_ID, init_binding); - DeviceSpecific per_device_state = - init_task_impl(init_accessor); + LinearPerDeviceState per_device_state = init_task_impl(init_accessor); SimTaskBinding fwd_binding; - fwd_bind.bind(INPUT, input_tensor(0)); // input - fwd_bind.bind(WEIGHT, weight_tensor(0)); // weight - fwd_bind.bind(OUTPUT, output_tensor(0)); // output - fwd_bind.bind(BIAS, bias_tensor(0)); // bias + fwd_binding.bind(INPUT, input.shape); // input + fwd_binding.bind(WEIGHT, weight_shape); // weight + fwd_binding.bind(OUTPUT, output_shape); // output + if (attrs.use_bias) { + fwd_binding.bind(BIAS, bias_shape); // bias + } - fwd_bid.bind_arg(PROFILING, profiling_settings()); - fwd_bind.bind_arg(PER_DEVICE_STATE, per_device_state()); - fwd_bind.bind_arg(ATTRS, attrs); + fwd_binding.bind_arg(PROFILING, profiling_settings()); + fwd_binding.bind_arg(PER_DEVICE_STATE, + per_device_op_state()); + fwd_binding.bind_arg(ATTRS, attrs); SimTaskBinding bwd_binding = infer_bwd_binding(fwd_binding); - auto fwd_accessor = env.get_accessor(LINEAR_FWD_TASK_ID, fwd_binding); - auto bwd_accessor = env.get_accessor(LINEAR_BWD_TASK_ID, bwd_binding); + auto fwd_accessor = env.get_fwd_accessor(LINEAR_FWD_TASK_ID, fwd_binding); + auto bwd_accessor = env.get_bwd_accessor(LINEAR_BWD_TASK_ID, bwd_binding); float forward_time = forward_task_impl(fwd_accessor).value(); float backward_time = backward_task_impl(bwd_accessor).value(); @@ -248,15 +214,14 @@ OpTaskSignature init_signature() { OpTaskSignature init(OpTaskType::INIT); init.add_input_slot(INPUT); - init.add_input_slot(WEIGHT); - init.add_input_slot(BIAS); + init.add_weight_slot(WEIGHT); init.add_output_slot(OUTPUT); init.add_arg_slot(ATTRS); init.add_unchecked_arg_slot(HANDLE); init.add_return_value(); - return init, + return init; } template <> @@ -264,8 +229,8 @@ OpTaskSignature fwd_signature() { OpTaskSignature fwd(OpTaskType::FWD); fwd.add_input_slot(INPUT); - fwd.add_input_slot(WEIGHT); - fwd.add_input_slot(BIAS); + fwd.add_weight_slot(WEIGHT); + fwd.add_optional_weight_slot(BIAS); fwd.add_output_slot(OUTPUT); fwd.add_arg_slot(PROFILING); @@ -281,13 +246,28 @@ OpTaskSignature bwd_signature() { return bwd; } +template <> +TaskImplFunction get_task_impl() { + return init_task_impl; +} + +template <> +TaskImplFunction get_task_impl() { + return forward_task_impl; +} + +template <> +TaskImplFunction get_task_impl() { + return backward_task_impl; +} + template <> void register_task() { register_task(LINEAR_INIT_TASK_ID, "Linear::init_task", init_signature(), - 
init_task); + init_task_impl); } template <> @@ -295,7 +275,7 @@ void register_task() { register_task(LINEAR_FWD_TASK_ID, "Linear::fwd_task", fwd_signature(), - forward_task); + forward_task_impl); } template <> @@ -303,7 +283,11 @@ void register_task() { register_task(LINEAR_BWD_TASK_ID, "Linear::bwd_task", bwd_signature(), - backward_task); + backward_task_impl); +} + +std::vector get_task_ids(LinearAttrs const &) { + return {LINEAR_INIT_TASK_ID, LINEAR_FWD_TASK_ID, LINEAR_BWD_TASK_ID}; } }; // namespace FlexFlow diff --git a/lib/runtime/src/ops/linear.h b/lib/local-execution/src/ops/linear.h similarity index 98% rename from lib/runtime/src/ops/linear.h rename to lib/local-execution/src/ops/linear.h index 2b476382ef..2ff9016114 100644 --- a/lib/runtime/src/ops/linear.h +++ b/lib/local-execution/src/ops/linear.h @@ -1,9 +1,9 @@ #ifndef _FLEXFLOW_LINEAR_H #define _FLEXFLOW_LINEAR_H +#include "local-execution/op_task_invocation.h" +#include "local-execution/sim_environment.h" #include "op-attrs/ops/linear.h" -#include "op_task_invocation.h" -#include "sim_environment.h" namespace FlexFlow { diff --git a/lib/runtime/src/ops/noop.cc b/lib/local-execution/src/ops/noop.cc similarity index 95% rename from lib/runtime/src/ops/noop.cc rename to lib/local-execution/src/ops/noop.cc index 6b8510607a..168d547c17 100644 --- a/lib/runtime/src/ops/noop.cc +++ b/lib/local-execution/src/ops/noop.cc @@ -14,7 +14,7 @@ */ #include "noop.h" -#include "task_spec/op_task_invocation.h" +#include "local-execution/op_task_invocation.h" #include "utils/hash-utils.h" namespace FlexFlow { diff --git a/lib/runtime/src/ops/noop.h b/lib/local-execution/src/ops/noop.h similarity index 87% rename from lib/runtime/src/ops/noop.h rename to lib/local-execution/src/ops/noop.h index f5cf6cc98c..fab2cf1f86 100644 --- a/lib/runtime/src/ops/noop.h +++ b/lib/local-execution/src/ops/noop.h @@ -1,9 +1,9 @@ #ifndef _FLEXFLOW_NOOP_H #define _FLEXFLOW_NOOP_H +#include "local-execution/op_task_invocation.h" #include "op-attrs/ops/input.h" #include "op-attrs/ops/noop.h" -#include "task_spec/op_task_invocation.h" namespace FlexFlow { diff --git a/lib/runtime/src/ops/parallel_op.h b/lib/local-execution/src/ops/parallel_op.h similarity index 96% rename from lib/runtime/src/ops/parallel_op.h rename to lib/local-execution/src/ops/parallel_op.h index 6b596a4fb5..e7bd98b8a8 100644 --- a/lib/runtime/src/ops/parallel_op.h +++ b/lib/local-execution/src/ops/parallel_op.h @@ -7,7 +7,7 @@ namespace FlexFlow { struct ParallelOpJoinResult { - optional op = nullopt; + std::optional op = std::nullopt; bool join_did_succeed = false; }; diff --git a/lib/runtime/src/ops/partition.cc b/lib/local-execution/src/ops/partition.cc similarity index 61% rename from lib/runtime/src/ops/partition.cc rename to lib/local-execution/src/ops/partition.cc index 2a974e96da..4b09ad026b 100644 --- a/lib/runtime/src/ops/partition.cc +++ b/lib/local-execution/src/ops/partition.cc @@ -13,32 +13,13 @@ * limitations under the License. 
 */
 
-#include "parallel_ops/partition.h"
 #include "kernels/partition_kernels.h"
-#include "op-attrs/get_output_shape.h"
-#include "utils/exceptions.h"
+#include "op-attrs/get_output_shapes.h"
+#include "repartition.h"
+#include "utils/exception.h"
 #include "utils/hash-utils.h"
 
 namespace FlexFlow {
-// declare Legion names
-using Legion::ArgumentMap;
-using Legion::Context;
-using Legion::coord_t;
-using Legion::Domain;
-using Legion::FutureMap;
-using Legion::IndexLauncher;
-using Legion::LogicalPartition;
-using Legion::LogicalRegion;
-using Legion::Machine;
-using Legion::Memory;
-using Legion::PhysicalRegion;
-using Legion::Predicate;
-using Legion::Rect;
-using Legion::RegionRequirement;
-using Legion::Runtime;
-using Legion::Task;
-using Legion::TaskArgument;
-using Legion::TaskLauncher;
 
 using namespace FlexFlow::Kernels::Repartition;
 
@@ -59,7 +40,7 @@ OpTaskInvocation forward(RepartitionAttrs const &attrs) {
   binding.bind_arg(PROFILING, profiling_settings());
   binding.bind_arg(ATTRS, attrs);
   binding.bind_arg(PER_DEVICE_STATE,
-                   per_device_state());
+                   per_device_op_state());
 
   binding.bind(INPUT, input_tensor(0));
   binding.bind(OUTPUT, output_tensor(0));
@@ -79,64 +60,39 @@ static DeviceSpecific
   // Note: use the input data type
   DeviceSpecific per_device_state =
-      acc.create_device_specific_state(
-          init_kernel(handle, input.data_type));
+      init_kernel(handle, input.data_type);
 
   return per_device_state;
 }
 
-static DeviceSpecific
-    init_task(Task const *task,
-              std::vector const &regions,
-              Context ctx,
-              Runtime *runtime) {
-  TaskArgumentAccessor acc(task, regions, ctx, runtime);
-  return init_task_impl(acc);
-}
-
-static optional forward_task_impl(TaskArgumentAccessor const &acc) {
+static std::optional forward_task_impl(TaskArgumentAccessor const &acc) {
   ProfilingSettings profiling = acc.get_argument(PROFILING);
   auto per_device_state =
       acc.get_argument(PER_DEVICE_STATE);
   auto input = acc.get_tensor(INPUT);
   auto output = acc.get_tensor(OUTPUT);
 
-  return profiling(forward,
-                   profiling,
-                   "[Reparition/Partition] forward_time = %.2lfms\n",
-                   per_device_state,
-                   input,
-                   output);
-}
-
-static void forward_task(Task const *task,
-                         std::vector const &regions,
-                         Context ctx,
-                         Runtime *runtime) {
-  TaskArgumentAccessor acc(task, regions, ctx, runtime);
-  forward_task_impl(acc);
+  return profile(forward_kernel,
+                 profiling,
+                 "[Repartition/Partition] forward_time = {:.2lf}ms\n",
+                 per_device_state,
+                 input,
+                 output);
 }
 
-static optional backward_task_impl(TaskArgumentAccessor const &acc) {
+static std::optional
+    backward_task_impl(TaskArgumentAccessor const &acc) {
   ProfilingSettings profiling = acc.get_argument(PROFILING);
   auto per_device_state =
       acc.get_argument(PER_DEVICE_STATE);
   auto input_grad = acc.get_tensor_grad(INPUT);
   auto output_grad = acc.get_tensor_grad(OUTPUT);
 
-  return profiling(backward,
-                   profiling,
-                   "[Reparition/Partition] backward_time = %.2lfms\n",
-                   per_device_state,
-                   input_grad,
-                   output_grad);
-}
-
-static void backward_task(Task const *task,
-                          std::vector const &regions,
-                          Context ctx,
-                          Runtime *runtime) {
-  TaskArgumentAccessor acc(task, regions, ctx, runtime);
-  backward_task_impl(acc);
+  return profile(backward_kernel,
+                 profiling,
+                 "[Repartition/Partition] backward_time = {:.2lf}ms\n",
+                 per_device_state,
+                 output_grad,
+                 input_grad);
 }
 
 CostMetrics measure_operator_cost(SimEnvFactory const &sim_factory,
@@ -144,7 +100,7 @@ CostMetrics measure_operator_cost(SimEnvFactory const &sim_factory,
                                   InputParallelTensorDesc const &input,
                                   ProfilingSettings const &settings,
                                   MachineView const 
&machine_view) { - auto env = sim.new_environment(); + auto env = sim_factory.new_environment(); ParallelTensorShape output_shape = get_output_shape(attrs, input.shape); @@ -165,8 +121,10 @@ CostMetrics measure_operator_cost(SimEnvFactory const &sim_factory, SimTaskBinding bwd_binding = infer_bwd_binding(fwd_binding); - auto fwd_accessor = env.get_accessor(REPARTITION_FWD_TASK_ID, fwd_binding); - auto bwd_accessor = env.get_accessor(REPARTITION_BWD_TASK_ID, bwd_binding); + auto fwd_accessor = + env.get_fwd_accessor(REPARTITION_FWD_TASK_ID, fwd_binding); + auto bwd_accessor = + env.get_bwd_accessor(REPARTITION_BWD_TASK_ID, bwd_binding); float forward_time = forward_task_impl(fwd_accessor).value(); float backward_time = backward_task_impl(bwd_accessor).value(); @@ -185,7 +143,8 @@ void register_task() { init.add_return_value(); - register_task(REPARTITION_INIT_TASK_ID, "Repartition Init", init, init_task); + register_task( + REPARTITION_INIT_TASK_ID, "Repartition Init", init, init_task_impl); } template <> @@ -197,15 +156,19 @@ void register_task() { fwd.add_arg_slot(PROFILING); fwd.add_unchecked_arg_slot(PER_DEVICE_STATE); - register_task(REPARTITION_FWD_TASK_ID, "Repartition Fwd", fwd, forward_task); + register_task( + REPARTITION_FWD_TASK_ID, "Repartition Fwd", fwd, forward_task_impl); } -template <> -void register_task() { - OpTaskSignature bwd = - infer_bwd_signature(get_op_signature(REPARTITION_FWD_TASK_ID)); +// TODO: OpTaskSignature - register_task(REPARTITION_BWD_TASK_ID, "Repartition Bwd", bwd, backward_task); -} +// template <> +// void register_task() { +// OpTaskSignature bwd = +// infer_bwd_signature(get_op_signature(REPARTITION_FWD_TASK_ID)); + +// register_task(REPARTITION_BWD_TASK_ID, "Repartition Bwd", bwd, +// backward_task_impl); +// } }; // namespace FlexFlow diff --git a/lib/runtime/src/ops/pool_2d.cc b/lib/local-execution/src/ops/pool_2d.cc similarity index 57% rename from lib/runtime/src/ops/pool_2d.cc rename to lib/local-execution/src/ops/pool_2d.cc index 577837c960..989f390380 100644 --- a/lib/runtime/src/ops/pool_2d.cc +++ b/lib/local-execution/src/ops/pool_2d.cc @@ -1,10 +1,10 @@ #include "pool_2d.h" #include "kernels/pool_2d_kernels.h" -#include "legion/legion_utilities.h" + #include "op-attrs/get_output_shapes.h" #include "op-attrs/ops/pool_2d.h" #include "utils/exception.decl.h" -#include "utils/exceptions.h" +#include "utils/exception.h" #include "utils/hash-utils.h" using namespace FlexFlow::Kernels::Pool2D; @@ -23,13 +23,13 @@ OpTaskInvocation init(Pool2DAttrs const &attrs) { return {POOL2D_INIT_TASK_ID, binding}; } -static DeviceSpecific +static DeviceSpecific init_task_impl(TaskArgumentAccessor const &acc) { auto const &attrs = acc.get_argument(ATTRS); PerDeviceFFHandle handle = acc.get_argument(HANDLE); - auto input = acc.get_tensor(INPUT); - auto output = acc.get_tensor(OUTPUT); + auto input = acc.get_tensor(INPUT); + auto output = acc.get_tensor(OUTPUT); int input_w = input.shape.at(ff_dim_t(0)) + 1; int input_h = input.shape.at(ff_dim_t(1)) + 1; @@ -64,37 +64,27 @@ static DeviceSpecific printf("Warning: changing pool_padding_w to satisfy output_w size\n"); } - DeviceSpecific state = acc.create_device_specific( - init_kernel(handle, - attrs.activation, - input_w, - input_h, - input_c, - input_n, - output_w, - output_h, - output_c, - output_n, - pad_h, - pad_w, - attrs.kernel_h, - attrs.kernel_w, - attrs.stride_h, - attrs.stride_w, - attrs.pool_type); + DeviceSpecific state = init_kernel(handle, + attrs.activation, + input_w, + input_h, + input_c, + 
input_n, + output_w, + output_h, + output_c, + output_n, + pad_h, + pad_w, + attrs.kernel_h, + attrs.kernel_w, + attrs.stride_h, + attrs.stride_w, + attrs.pool_type); return state; } -static DeviceSpecific - init_task(Task const *task, - std::vector const ®ions, - Context ctx, - Runtime *runtime) { - TaskArgumentAccessor acc(task, regions, ctx, runtime); - return init_task_impl(acc); -} - OpTaskInvocation forward(Pool2DAttrs const &attrs) { OpTaskBinding binding; binding.bind(INPUT, input_tensor(0)); @@ -102,54 +92,47 @@ OpTaskInvocation forward(Pool2DAttrs const &attrs) { binding.bind_arg(PROFILING, profiling_settings()); binding.bind_arg(PER_DEVICE_STATE, - per_device_op_state()); + per_device_op_state()); return {POOL2D_FWD_TASK_ID, binding}; } -OpTaskInvocation backward(Pool2DAttrs const &) { +OpTaskInvocation backward(Pool2DAttrs const &attrs) { OpTaskBinding b = infer_bwd_binding(forward(attrs).binding); return {POOL2D_BWD_TASK_ID, b}; } -static optional forward_task_impl(TaskArgumentAccessor const &acc) { +static std::optional forward_task_impl(TaskArgumentAccessor const &acc) { ProfilingSettings profiling = acc.get_argument(PROFILING); - Pool2dPerDeviceState state = - acc.get_argument(PER_DEVICE_STATE); + Pool2DPerDeviceState state = + acc.get_argument(PER_DEVICE_STATE); - auto input = acc.get_tensor(INPUT); - auto output = acc.get_tensor(OUTPUT); + auto input = acc.get_tensor(INPUT); + auto output = acc.get_tensor(OUTPUT); return profile(forward_kernel, - profilng, - "[Pool2D] forward_time = %.2lfms\n", + profiling, + "[Pool2D] forward_time = {:.2lf}ms\n", state, input.get_float_ptr(), output.get_float_ptr()); } -static void forward_task(Task const *task, - std::vector const ®ions, - Context ctx, - Runtime *runtime) { - TaskArgumentAccessor acc(task, regions, ctx, runtime); - forward_task_impl(acc); -} - -static optional backward_task_impl(TaskArgumentAccessor const &acc) { +static std::optional + backward_task_impl(TaskArgumentAccessor const &acc) { ProfilingSettings profiling = acc.get_argument(PROFILING); - Pool2dPerDeviceState state = - acc.get_argument(PER_DEVICE_STATE); + Pool2DPerDeviceState state = + acc.get_argument(PER_DEVICE_STATE); - auto input = acc.get_tensor(INPUT); - auto input_grad = acc.get_tensor(INPUT); - auto output = acc.get_tensor(OUTPUT); - auto output_grad = acc.get_tensor(OUTPUT); + auto input = acc.get_tensor(INPUT); + auto input_grad = acc.get_tensor(INPUT); + auto output = acc.get_tensor(OUTPUT); + auto output_grad = acc.get_tensor(OUTPUT); return profile(backward_kernel, - profilng, - "[Pool2D] backward_time = %.2lfms\n", + profiling, + "[Pool2D] backward_time = {:.2lf}ms\n", state, input.get_float_ptr(), input_grad.get_float_ptr(), @@ -157,20 +140,12 @@ static optional backward_task_impl(TaskArgumentAccessor const &acc) { output_grad.get_float_ptr()); } -static void backward_task(Task const *task, - std::vector const ®ions, - Context ctx, - Runtime *runtime) { - TaskArgumentAccessor acc(task, regions, ctx, runtime); - backward_task_impl(acc); -} - CostMetrics measure_operator_cost(SimEnvFactory const &sim_factory, Pool2DAttrs const &attrs, - ParallelTensorShape const &input, + InputParallelTensorDesc const &input, ProfilingSettings const &settings, MachineView const &machine_view) { - auto env = sim.new_environment(); + auto env = sim_factory.new_environment(); ParallelTensorShape output_shape = get_output_shape(attrs, input.shape); SimTaskBinding init_binding; @@ -181,21 +156,21 @@ CostMetrics measure_operator_cost(SimEnvFactory const &sim_factory, 
auto init_accessor = env.get_init_accessor(POOL2D_INIT_TASK_ID, init_binding); - DeviceSpecific per_device_state = + DeviceSpecific per_device_state = init_task_impl(init_accessor); SimTaskBinding fwd_binding; - fwd_binding.bind(INPUT, input_shape); + fwd_binding.bind(INPUT, input.shape); fwd_binding.bind(OUTPUT, output_shape); fwd_binding.bind_arg(PROFILING, settings); fwd_binding.bind_arg(PER_DEVICE_STATE, per_device_state); - auto fwd_accessor = env.get_accessor(POOL2D_FWD_TASK_ID, fwd_binding); + auto fwd_accessor = env.get_fwd_accessor(POOL2D_FWD_TASK_ID, fwd_binding); SimTaskBinding bwd_binding = infer_bwd_binding(fwd_binding); - auto bwd_accessor = env.get_accessor(POOL2D_BWD_TASK_ID, bwd_binding); + auto bwd_accessor = env.get_bwd_accessor(POOL2D_BWD_TASK_ID, bwd_binding); float forward_time = forward_task_impl(fwd_accessor).value(); float backward_time = backward_task_impl(bwd_accessor).value(); @@ -217,7 +192,7 @@ void register_task() { init.add_return_value(); - register_task(POOL2D_INIT_TASK_ID, "Pool2D::init", init, init_taks); + register_task(POOL2D_INIT_TASK_ID, "Pool2D::init", init, init_task_impl); } template <> @@ -228,17 +203,20 @@ void register_task() { fwd.add_output_slot(OUTPUT); fwd.add_arg_slot(PROFILING); - fwd.add_arg_slot(PER_DEVICE_STATE); + fwd.add_unchecked_arg_slot(PER_DEVICE_STATE); - register_task(POOL2D_FWD_TASK_ID, "Pool2D::forward", fwd, forward_task); + register_task(POOL2D_FWD_TASK_ID, "Pool2D::forward", fwd, forward_task_impl); } -template <> -void register_task() { - OpTaskSignature bwd = - infer_bwd_signature(get_op_signature(POOL2D_FWD_TASK_ID)); +// TODO: OpTaskSignature - register_task(POOL2D_BWD_TASK_ID, "Pool2D::backward", bwd, backward_task); -} +// template <> +// void register_task() { +// OpTaskSignature bwd = +// infer_bwd_signature(get_op_signature(POOL2D_FWD_TASK_ID)); + +// register_task(POOL2D_BWD_TASK_ID, "Pool2D::backward", bwd, +// backward_task_impl); +// } }; // namespace FlexFlow diff --git a/lib/runtime/src/ops/pool_2d.h b/lib/local-execution/src/ops/pool_2d.h similarity index 95% rename from lib/runtime/src/ops/pool_2d.h rename to lib/local-execution/src/ops/pool_2d.h index f8701f461e..0537e9f1c4 100644 --- a/lib/runtime/src/ops/pool_2d.h +++ b/lib/local-execution/src/ops/pool_2d.h @@ -1,9 +1,9 @@ #ifndef _FLEXFLOW_POOL_2D_H #define _FLEXFLOW_POOL_2D_H +#include "local-execution/op_task_invocation.h" +#include "local-execution/sim_environment.h" #include "op-attrs/ops/pool_2d.h" -#include "op_task_invocation.h" -#include "sim_environment.h" namespace FlexFlow { @@ -20,7 +20,7 @@ OpTaskInvocation backward(Pool2DAttrs const &); CostMetrics measure_operator_cost(SimEnvFactory const &sim_factory, Pool2DAttrs const &attrs, - ParallelTensorShape const &input_shape, + InputParallelTensorDesc const &input_shape, ProfilingSettings const &settings, MachineView const &machine_view); diff --git a/lib/runtime/src/ops/reduce.cc b/lib/local-execution/src/ops/reduce.cc similarity index 57% rename from lib/runtime/src/ops/reduce.cc rename to lib/local-execution/src/ops/reduce.cc index 2674dc4fef..98d1a6f522 100644 --- a/lib/runtime/src/ops/reduce.cc +++ b/lib/local-execution/src/ops/reduce.cc @@ -1,27 +1,12 @@ #include "reduce.h" #include "kernels/reduce_kernels.h" -#include "legion/legion_utilities.h" -#include "op-attrs/get_output_shape.h" -#include "utils/exceptions.h" + +#include "op-attrs/get_output_shapes.h" +#include "utils/exception.h" #include "utils/hash-utils.h" #include "utils/type_traits_core.h" namespace FlexFlow { -// declare 
Legion names -using Legion::ArgumentMap; -using Legion::Context; -using Legion::coord_t; -using Legion::Domain; -using Legion::FutureMap; -using Legion::IndexLauncher; -using Legion::PhysicalRegion; -using Legion::Predicate; -using Legion::Rect; -using Legion::RegionRequirement; -using Legion::Runtime; -using Legion::Task; -using Legion::TaskArgument; -using Legion::TaskLauncher; using namespace FlexFlow::Kernels::Reduce; @@ -35,7 +20,7 @@ enum Slots { HANDLE }; -OpTaskInvocation init(TransposeAttrs const &attrs) { +OpTaskInvocation init(ReduceAttrs const &attrs) { OpTaskBinding binding; binding.bind_arg(HANDLE, ff_handle()); @@ -54,42 +39,33 @@ static DeviceSpecific auto input = acc.get_tensor(INPUT); auto output = acc.get_tensor(OUTPUT); - OperatorType = attrs.op_type; + OperatorType op_type = attrs.op_type; // Note: How to set the reduction size? size_t reduction_size = input.shape.get_volume() / output.shape.get_volume(); DeviceSpecific per_device_state = - acc.create_device_specific(init_kernel( - handle, op_type, reduction_size, input.shape, output.shape)); + init_kernel(handle, op_type, reduction_size, input.shape, output.shape); return per_device_state; } -static DeviceSpecific - init_task(Task const *task, - std::vector const ®ions, - Context ctx, - Runtime *runtime) { - TaskArgumentAccessor acc(task, regions, ctx, runtime); - return init_task_impl(acc); -} - template <> void register_task() { - OpTaskSignature init(OpTaskType::INIT) + OpTaskSignature init(OpTaskType::INIT); - init.add_unchecked_arg_slot(HANDLE); + init.add_unchecked_arg_slot(HANDLE); init.add_arg_slot(ATTRS); init.add_return_value(); - register_task(REDUCE_INIT_TASK_ID, "Reduce::init", init, init_task); + register_task(REDUCE_INIT_TASK_ID, "Reduce::init", init, init_task_impl); } // Note: forward_kernel only needs ReducePerDeviceState, input, output OpTaskInvocation forward(ReduceAttrs const &attrs) { OpTaskBinding binding; - bind.bind_arg(PER_DEVICE_STATE, per_device_op_state()); - bind.bind_arg(PROFILING, profiling_tensor()); + binding.bind_arg(PER_DEVICE_STATE, + per_device_op_state()); + binding.bind_arg(PROFILING, profiling_settings()); binding.bind(INPUT, input_tensor(0)); binding.bind(OUTPUT, output_tensor(0)); @@ -97,7 +73,7 @@ OpTaskInvocation forward(ReduceAttrs const &attrs) { return {REDUCE_FWD_TASK_ID, binding}; } -static optional forward_task_impl(TaskArgumentAccessor const &acc) { +static std::optional forward_task_impl(TaskArgumentAccessor const &acc) { auto per_device_state = acc.get_argument(PER_DEVICE_STATE); ProfilingSettings profiling = acc.get_argument(PROFILING); @@ -107,31 +83,23 @@ static optional forward_task_impl(TaskArgumentAccessor const &acc) { return profile(forward_kernel, profiling, - "[Reduce] forward_time = %.2lfms\n", + "[Reduce] forward_time = {:.2lf}ms\n", per_device_state, input.get_float_ptr(), output.get_float_ptr()); } -static void forward_task(Task const *task, - std::vector const ®ions, - Context ctx, - Runtime *runtime) { - TaskArgumentAccessor acc(task, regions, ctx, runtime); - forward_task_impl(acc); -} - template <> void register_task() { - OpTaskSignature fwd(OpTaskType::FORWARD); + OpTaskSignature fwd(OpTaskType::FWD); - fwd.add_unchecked_arg_slot(PER_DEVICE_STATE); + fwd.add_unchecked_arg_slot(PER_DEVICE_STATE); fwd.add_arg_slot(PROFILING); fwd.add_input_slot(INPUT); fwd.add_output_slot(OUTPUT); - register_task(REDUCE_FWD_TASK_ID, "Reduce::forward", fwd, forward_task); + register_task(REDUCE_FWD_TASK_ID, "Reduce::forward", fwd, forward_task_impl); } 
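// Every task implementation in this patch funnels its kernel through
// profile(...), which returns std::optional<float>: the elapsed kernel time
// in milliseconds when profiling is enabled (measure_operator_cost unwraps
// it with .value()), and std::nullopt otherwise. A minimal host-side sketch
// of that contract, assuming a std::chrono timer and a hypothetical
// `settings.profiling` flag; the real helper presumably times with device
// events and logs via the fmt-style "{:.2lf}" messages seen in the calls.

#include <chrono>
#include <optional>
#include <utility>

struct ProfilingSettings {
  bool profiling; // stand-in; the real type lives in local-execution
};

template <typename Kernel, typename... Args>
std::optional<float> profile(Kernel &&kernel,
                             ProfilingSettings const &settings,
                             char const *fmt_msg,
                             Args &&...args) {
  if (!settings.profiling) {
    kernel(std::forward<Args>(args)...);
    return std::nullopt;
  }
  auto start = std::chrono::steady_clock::now();
  kernel(std::forward<Args>(args)...); // run the bound kernel with its tensors/state
  auto stop = std::chrono::steady_clock::now();
  (void)fmt_msg; // the real helper presumably formats and logs the message here
  // elapsed milliseconds, returned so the caller can aggregate cost metrics
  return std::chrono::duration<float, std::milli>(stop - start).count();
}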
OpTaskInvocation backward(ReduceAttrs const &attrs) { @@ -140,48 +108,44 @@ OpTaskInvocation backward(ReduceAttrs const &attrs) { return {REDUCE_BWD_TASK_ID, binding}; } -static optional backward_task_impl(TaskArgumentAccessor const &acc) { +static std::optional + backward_task_impl(TaskArgumentAccessor const &acc) { auto per_device_state = acc.get_argument(PER_DEVICE_STATE); ProfilingSettings profiling = acc.get_argument(PROFILING); - auto input_grad = acc.get_tensor_grad(INPUT); - auto output_grad = acc.get_tensor_grad(OUTPUT); + auto input_grad = acc.get_tensor_grad(INPUT); + auto output_grad = acc.get_tensor_grad(OUTPUT); return profile(backward_kernel, profiling, - "[Reduce] backward_time = %.2lfms\n", + "[Reduce] backward_time = {:.2lf}ms\n", per_device_state, - input.get_float_ptr(), - output.get_float_ptr()); + output_grad.get_float_ptr(), + input_grad.get_float_ptr()); } -static void backward_task(Task const *task, - std::vector const ®ions, - Context ctx, - Runtime *runtime) { - TaskArgumentAccessor acc(task, regions, ctx, runtime); - backward_task_impl(acc); -} +// TODO: OpTaskSignature -template <> -void register_task() { - OpTaskSignature bwd = - infer_bwd_signature(get_op_signature(REDUCE_FWD_TASK_ID)); +// template <> +// void register_task() { +// OpTaskSignature bwd = +// infer_bwd_signature(get_op_signature(REDUCE_FWD_TASK_ID)); - reister_task(REDUCE_BWD_TASK_ID, "Reduce::backward", bwd, backward_task); -} +// register_task(REDUCE_BWD_TASK_ID, "Reduce::backward", bwd, +// backward_task_impl); +// } CostMetrics measure_operator_cost(SimEnvFactory const &sim_factory, ReduceAttrs const &attrs, InputParallelTensorDesc const &input, ProfilingSettings const &settings, MachineView const &machine_view) { - auto env = sim.new_environment(); + auto env = sim_factory.new_environment(); SimTaskBinding init_binding; init_binding.bind_arg(ATTRS, attrs); - binding.bind_arg(HANDLE, ff_handle()); + init_binding.bind_arg(HANDLE, ff_handle()); auto init_accessor = env.get_init_accessor(REDUCE_INIT_TASK_ID, init_binding); DeviceSpecific per_device_state = @@ -189,10 +153,10 @@ CostMetrics measure_operator_cost(SimEnvFactory const &sim_factory, SimTaskBinding fwd_binding; ParallelTensorShape output_shape = get_output_shape(attrs, input.shape); - fwd.bind(INPUT, input.shape); - fwd.bind(OUTPUT, output_shape); - fwd.bind_arg(PROFILING, settings); - fwd.bind_arg(PER_DEVICE_STATE, per_device_state); + fwd_binding.bind(INPUT, input.shape); + fwd_binding.bind(OUTPUT, output_shape); + fwd_binding.bind_arg(PROFILING, settings); + fwd_binding.bind_arg(PER_DEVICE_STATE, per_device_state); SimTaskBinding bwd_binding = infer_bwd_binding(fwd_binding); diff --git a/lib/runtime/src/ops/reduce.h b/lib/local-execution/src/ops/reduce.h similarity index 96% rename from lib/runtime/src/ops/reduce.h rename to lib/local-execution/src/ops/reduce.h index 099083ed67..6d47ec2f4d 100644 --- a/lib/runtime/src/ops/reduce.h +++ b/lib/local-execution/src/ops/reduce.h @@ -1,9 +1,9 @@ #ifndef _FLEXFLOW_RUNTIME_SRC_OPS_REDUCE_H #define _FLEXFLOW_RUNTIME_SRC_OPS_REDUCE_H +#include "local-execution/op_task_invocation.h" +#include "local-execution/sim_environment.h" #include "op-attrs/ops/reduce.h" -#include "sim_environment.h" -#include "task_spec/op_task_invocation.h" namespace FlexFlow { diff --git a/lib/runtime/src/ops/reduction.cc b/lib/local-execution/src/ops/reduction.cc similarity index 59% rename from lib/runtime/src/ops/reduction.cc rename to lib/local-execution/src/ops/reduction.cc index 9a11d3a6f5..3fa300f64d 100644 
--- a/lib/runtime/src/ops/reduction.cc +++ b/lib/local-execution/src/ops/reduction.cc @@ -13,32 +13,14 @@ * limitations under the License. */ -#include "parallel_ops/reduction.h" +#include "reduction.h" #include "kernels/reduction_kernels.h" -#include "op-attrs/get_output_shape.h" -#include "utils/exceptions.h" +#include "op-attrs/get_output_shapes.h" +#include "utils/exception.h" #include "utils/hash-utils.h" namespace FlexFlow { // declare Legion names -using Legion::ArgumentMap; -using Legion::Context; -using Legion::coord_t; -using Legion::Domain; -using Legion::FutureMap; -using Legion::IndexLauncher; -using Legion::LogicalPartition; -using Legion::LogicalRegion; -using Legion::Machine; -using Legion::Memory; -using Legion::PhysicalRegion; -using Legion::Predicate; -using Legion::Rect; -using Legion::RegionRequirement; -using Legion::Runtime; -using Legion::Task; -using Legion::TaskArgument; -using Legion::TaskLauncher; using namespace FlexFlow::Kernels::Reduction; @@ -61,7 +43,7 @@ OpTaskInvocation backward(ReductionAttrs const &attrs) { return {REDUCTION_BWD_TASK_ID, binding}; } -static optional forward_task_impl(TaskArgumentAccessor const &acc) { +static std::optional forward_task_impl(TaskArgumentAccessor const &acc) { ProfilingSettings profiling_settings = acc.get_argument(PROFILING); @@ -71,40 +53,25 @@ static optional forward_task_impl(TaskArgumentAccessor const &acc) { size_t num_replicas = attrs.reduction_degree; - return profiling(forward_kernel, - profiling_settings, - "[Reduction] forward_time = %.2lfms\n", - input, - output, - num_replicas); + return profile(forward_kernel, + profiling_settings, + "[Reduction] forward_time = {:.2lf}ms\n", + input, + output, + num_replicas); } -static void forward_task(Task const *task, - std::vector const ®ions, - Context ctx, - Runtime *runtime) { - TaskArgumentAccessor acc(task, regions, ctx, runtime); - forward_task_impl(acc); -} - -static optional backward_task_impl(TaskArgumentAccessor const &acc) { +static std::optional + backward_task_impl(TaskArgumentAccessor const &acc) { ProfilingSettings profiling = acc.get_argument(PROFILING); - auto input_grad = acc.get_tensor_grad(INPUT); - auto output_grad = acc.get_tensor_grad(OUTPUT); - return profiling(backward_kernel, - profiling, - "[Reduction] backward_time = %.2lfms\n", - input_grad, - output_grad); -} - -static void backward_task(Task const *task, - std::vector const ®ions, - Context ctx, - Runtime *runtime) { - TaskArgumentAccessor acc(task, regions, ctx, runtime); - backward_task_impl(acc); + auto input_grad = acc.get_tensor_grad(INPUT); + auto output_grad = acc.get_tensor_grad(OUTPUT); + return profile(backward_kernel, + profiling, + "[Reduction] backward_time = {:.2lf}ms\n", + input_grad, + output_grad); } CostMetrics measure_operator_cost(SimEnvFactory const &sim_factory, @@ -114,13 +81,13 @@ CostMetrics measure_operator_cost(SimEnvFactory const &sim_factory, MachineView const &machine_view) { ParallelTensorShape output_shape = get_output_shape(attrs, input.shape); - auto env = sim.new_environment(); + auto env = sim_factory.new_environment(); SimTaskBinding fwd_binding; fwd_binding.bind_arg(PROFILING, settings); fwd_binding.bind_arg(ATTRS, attrs); fwd_binding.bind(INPUT, input.shape); - fwd.binding.bind(OUTPUT, output_shape); + fwd_binding.bind(OUTPUT, output_shape); auto fwd_accessor = env.get_fwd_accessor(REDUCTION_FWD_TASK_ID, fwd_binding); @@ -145,15 +112,18 @@ void register_task() { fwd.add_input_slot(INPUT); fwd.add_output_slot(OUTPUT); - 
register_task(REDUCTION_FWD_TASK_ID, "Reduction Fwd", fwd, forward_task); + register_task(REDUCTION_FWD_TASK_ID, "Reduction Fwd", fwd, forward_task_impl); } -template <> -void register_task() { - OpTaskSignature bwd = - infer_bwd_signature(get_op_signature(REDUCTION_FWD_TASK_ID)); +// TODO: OpTaskSignature - register_task(REDUCTION_BWD_TASK_ID, "Reduction Bwd", bwd, backward_task); -} +// template <> +// void register_task() { +// OpTaskSignature bwd = +// infer_bwd_signature(get_op_signature(REDUCTION_FWD_TASK_ID)); + +// register_task(REDUCTION_BWD_TASK_ID, "Reduction Bwd", bwd, +// backward_task_impl); +// } }; // namespace FlexFlow diff --git a/lib/runtime/src/ops/reduction.h b/lib/local-execution/src/ops/reduction.h similarity index 94% rename from lib/runtime/src/ops/reduction.h rename to lib/local-execution/src/ops/reduction.h index 978ca6b080..a69b75f310 100644 --- a/lib/runtime/src/ops/reduction.h +++ b/lib/local-execution/src/ops/reduction.h @@ -1,9 +1,9 @@ #ifndef _FLEXFLOW_REDUCTION_H #define _FLEXFLOW_REDUCTION_H -#include "op-attrs/ops/combine.h" -#include "sim_environment.h" -#include "task_spec/op_task_invocation.h" +#include "local-execution/op_task_invocation.h" +#include "local-execution/sim_environment.h" +#include "op-attrs/ops/reduction.h" namespace FlexFlow { diff --git a/lib/runtime/src/ops/repartition.h b/lib/local-execution/src/ops/repartition.h similarity index 97% rename from lib/runtime/src/ops/repartition.h rename to lib/local-execution/src/ops/repartition.h index fccc0de7be..a73bd3f808 100644 --- a/lib/runtime/src/ops/repartition.h +++ b/lib/local-execution/src/ops/repartition.h @@ -1,9 +1,9 @@ #ifndef _FLEXFLOW_PARTITION_H #define _FLEXFLOW_PARTITION_H +#include "local-execution/op_task_invocation.h" +#include "local-execution/sim_environment.h" #include "op-attrs/ops/repartition.h" -#include "sim_environment.h" -#include "task_spec/op_task_invocation.h" namespace FlexFlow { diff --git a/lib/runtime/src/ops/replicate.cc b/lib/local-execution/src/ops/replicate.cc similarity index 66% rename from lib/runtime/src/ops/replicate.cc rename to lib/local-execution/src/ops/replicate.cc index 1675a62c5f..a441985b78 100644 --- a/lib/runtime/src/ops/replicate.cc +++ b/lib/local-execution/src/ops/replicate.cc @@ -13,39 +13,21 @@ * limitations under the License. 
*/ -#include "parallel_ops/replicate.h" +#include "replicate.h" #include "kernels/replicate_kernels.h" #include "op-attrs/get_output_shapes.h" #include "op-attrs/parallel_tensor_shape.h" -#include "utils/exceptions.h" +#include "utils/exception.h" #include "utils/graph/serialparallel.h" #include "utils/hash-utils.h" #include namespace FlexFlow { // declare Legion names -using Legion::ArgumentMap; -using Legion::Context; -using Legion::coord_t; -using Legion::Domain; -using Legion::FutureMap; -using Legion::IndexLauncher; -using Legion::LogicalPartition; -using Legion::LogicalRegion; -using Legion::Machine; -using Legion::Memory; -using Legion::PhysicalRegion; -using Legion::Predicate; -using Legion::Rect; -using Legion::RegionRequirement; -using Legion::Runtime; -using Legion::Task; -using Legion::TaskArgument; -using Legion::TaskLauncher; using namespace FlexFlow::Kernels::Replicate; -enum Slots { INPUT, OUTPUT, PROFILING }; +enum Slots { INPUT, OUTPUT, ATTRS, PROFILING }; OpTaskInvocation forward(ReplicateAttrs const &attrs) { OpTaskBinding binding; @@ -54,6 +36,7 @@ OpTaskInvocation forward(ReplicateAttrs const &attrs) { binding.bind(INPUT, input_tensor(0)); binding.bind(OUTPUT, output_tensor(0)); + binding.bind_arg(ATTRS, attrs); return {REPLICATE_FWD_TASK_ID, binding}; } @@ -63,7 +46,7 @@ OpTaskInvocation backward(ReplicateAttrs const &attrs) { return {REPLICATE_BWD_TASK_ID, binding}; } -static optional forward_task_impl(TaskArgumentAccessor const &acc) { +static std::optional forward_task_impl(TaskArgumentAccessor const &acc) { ProfilingSettings profiling = acc.get_argument(PROFILING); auto input = acc.get_tensor(INPUT); @@ -71,38 +54,25 @@ static optional forward_task_impl(TaskArgumentAccessor const &acc) { return profile(forward_kernel, profiling, - "[replicate] forward_time = %.2lfms\n", + "[replicate] forward_time = {:.2lf}ms\n", input, output); } -static void forward_task(Task const *task, - std::vector const ®ions, - Context ctx, - Runtime *runtime) { - TaskArgumentAccessor acc(task, regions, ctx, runtime); - forward_task_impl(acc); -} - -static optional backward_task_impl(TaskArgumentAccessor const &acc) { +static std::optional + backward_task_impl(TaskArgumentAccessor const &acc) { ProfilingSettings profiling = acc.get_argument(PROFILING); auto input_grad = acc.get_tensor_grad(INPUT); auto output_grad = acc.get_tensor_grad(OUTPUT); + auto const &attrs = acc.get_argument(ATTRS); return profile(backward_kernel, profiling, - "[replicate] backward_time = %.2lfms\n", + "[replicate] backward_time = {:.2lf}ms\n", input_grad, - output_grad); -} - -static void backward_task(Task const *task, - std::vector const ®ions, - Context ctx, - Runtime *runtime) { - TaskArgumentAccessor acc(task, regions, ctx, runtime); - backward_task_impl(acc); + output_grad, + attrs.replicate_degree); } CostMetrics measure_operator_cost(SimEnvFactory const &sim_factory, @@ -110,7 +80,7 @@ CostMetrics measure_operator_cost(SimEnvFactory const &sim_factory, InputParallelTensorDesc const &input, ProfilingSettings const &settings, MachineView const &machine_view) { - auto env = sim.new_environment(); + auto env = sim_factory.new_environment(); SimTaskBinding fwd_binding; fwd_binding.bind_arg(PROFILING, settings); ParallelTensorShape output = get_output_shape(attrs, input.shape); @@ -136,14 +106,18 @@ void register_task() { fwd.add_input_slot(INPUT); fwd.add_output_slot(OUTPUT); - register_task(REPLICATE_FWD_TASK_ID, "Replicate fwd", fwd, forward_task); + register_task(REPLICATE_FWD_TASK_ID, "Replicate fwd", fwd, 
forward_task_impl); } -template <> -void register_task() { - OpTaskSignature bwd = infer_bwd_signature(get_op_signature(CAST_FWD_TASK_ID)); +// TODO: OpTaskSignature - register_task(REPLICATE_BWD_TASK_ID, "Replicate bwd", bwd, backward_task); -} +// template <> +// void register_task() { +// OpTaskSignature bwd = +// infer_bwd_signature(get_op_signature(CAST_FWD_TASK_ID)); + +// register_task(REPLICATE_BWD_TASK_ID, "Replicate bwd", bwd, +// backward_task_impl); +// } }; // namespace FlexFlow diff --git a/lib/runtime/src/ops/replicate.h b/lib/local-execution/src/ops/replicate.h similarity index 89% rename from lib/runtime/src/ops/replicate.h rename to lib/local-execution/src/ops/replicate.h index da2b71f098..339f805f2c 100644 --- a/lib/runtime/src/ops/replicate.h +++ b/lib/local-execution/src/ops/replicate.h @@ -1,9 +1,9 @@ #ifndef _FLEXFLOW_REPLICATE_H #define _FLEXFLOW_REPLICATE_H +#include "local-execution/op_task_invocation.h" +#include "local-execution/sim_environment.h" #include "op-attrs/ops/replicate.h" -#include "sim_environment.h" -#include "task_spec/op_task_invocation.h" namespace FlexFlow { diff --git a/lib/runtime/src/ops/reshape.cc b/lib/local-execution/src/ops/reshape.cc similarity index 68% rename from lib/runtime/src/ops/reshape.cc rename to lib/local-execution/src/ops/reshape.cc index c9dc8cff8d..efee73645b 100644 --- a/lib/runtime/src/ops/reshape.cc +++ b/lib/local-execution/src/ops/reshape.cc @@ -15,24 +15,10 @@ #include "reshape.h" #include "kernels/reshape_kernels.h" -#include "legion/legion_utilities.h" +#include "op-attrs/get_output_shapes.h" namespace FlexFlow { // declare Legion names -using Legion::ArgumentMap; -using Legion::Context; -using Legion::coord_t; -using Legion::Domain; -using Legion::FutureMap; -using Legion::IndexLauncher; -using Legion::PhysicalRegion; -using Legion::Predicate; -using Legion::Rect; -using Legion::RegionRequirement; -using Legion::Runtime; -using Legion::Task; -using Legion::TaskArgument; -using Legion::TaskLauncher; using namespace FlexFlow::Kernels::Reshape; @@ -69,74 +55,50 @@ static DeviceSpecific auto attrs = acc.get_argument(ATTRS); DeviceSpecific per_device_state = - acc.create_device_specific( - init_kernel(attrs.shape.data_type)); + init_kernel(attrs.shape.data_type); return per_device_state; } -static DeviceSpecific - init_task(Task const *task, - std::vector const ®ions, - Context ctx, - Runtime *runtime) { - TaskArgumentAccessor acc(task, regions, ctx, runtime); - return init_task_impl(acc); -} - -static optional forward_task_impl(TaskArgumentAccessor const &acc) { +static std::optional forward_task_impl(TaskArgumentAccessor const &acc) { auto per_device_state = acc.get_argument(PER_DEVICE_STATE); - Profiling profiling = acc.get_argument(PROFILING); + ProfilingSettings profiling = acc.get_argument(PROFILING); auto input = acc.get_tensor(INPUT); auto output = acc.get_tensor(OUTPUT); return profile(forward_kernel, profiling, - "[Reshape] forward time = %.2lfms\n", + "[Reshape] forward time = {:.2lf}ms\n", per_device_state, input, output); } -static void forward_task(Task const *task, - std::vector const ®ions, - Context ctx, - Runtime *runtime) { - TaskArgumentAccessor acc(task, regions, ctx, runtime); - forward_task_impl(acc); -} - -static optional backward_task_impl(TaskArgumentAccessor const &acc) { +static std::optional + backward_task_impl(TaskArgumentAccessor const &acc) { auto per_device_state = acc.get_argument(PER_DEVICE_STATE); - Profiling profiling = acc.get_argument(PROFILING); + ProfilingSettings profiling = 
acc.get_argument(PROFILING); auto input_grad = acc.get_tensor_grad(INPUT); auto output_grad = acc.get_tensor_grad(OUTPUT); return profile(backward_kernel, profiling, - "[Reshape] backward time = %.2lfms\n", + "[Reshape] backward time = {:.2lf}ms\n", per_device_state, input_grad, output_grad); } -static void backward_task(Task const *task, - std::vector const ®ions, - Context ctx, - Runtime *runtime) { - TaskArgumentAccessor acc(task, regions, ctx, runtime); - backward_task_impl(acc); -} - CostMetrics measure_operator_cost(SimEnvFactory const &sim_factory, ReshapeAttrs const &attrs, InputParallelTensorDesc const &input, ProfilingSettings const &settings, MachineView const &machine_view) { + auto env = sim_factory.new_environment(); SimTaskBinding init_binding; init_binding.bind_arg(ATTRS, attrs); auto init_accessor = @@ -168,9 +130,9 @@ void register_task() { init.add_arg_slot(ATTRS); - init.add_return_value(PER_DEVICE_STATE); + init.add_return_value(); - register_task(RESHAPE_INIT_TASK_ID, "Reshape Init", init, init_task); + register_task(RESHAPE_INIT_TASK_ID, "Reshape Init", init, init_task_impl); } template <> @@ -183,15 +145,17 @@ void register_task() { fwd.add_input_slot(INPUT); fwd.add_output_slot(OUTPUT); - register_task(RESHAPE_FWD_TASK_ID, "Reshape Fwd", fwd, forward_task); + register_task(RESHAPE_FWD_TASK_ID, "Reshape Fwd", fwd, forward_task_impl); } -template <> -void register_task() { - OpTaskSignature bwd = - infer_bwd_binding(get_op_signature(RESHAPE_FWD_TASK_ID)); +// TODO: OpTaskSignature - register_task(RESHAPE_BWD_TASK_ID, "Reshape Bwd", bwd, backward_task); -} +// template <> +// void register_task() { +// OpTaskSignature bwd = +// infer_bwd_binding(get_op_signature(RESHAPE_FWD_TASK_ID)); + +// register_task(RESHAPE_BWD_TASK_ID, "Reshape Bwd", bwd, backward_task_impl); +// } }; // namespace FlexFlow diff --git a/lib/runtime/src/ops/reshape.h b/lib/local-execution/src/ops/reshape.h similarity index 97% rename from lib/runtime/src/ops/reshape.h rename to lib/local-execution/src/ops/reshape.h index f044e3f057..14b22561a0 100644 --- a/lib/runtime/src/ops/reshape.h +++ b/lib/local-execution/src/ops/reshape.h @@ -1,9 +1,9 @@ #ifndef _FLEXFLOW_RESHAPE_H #define _FLEXFLOW_RESHAPE_H +#include "local-execution/op_task_invocation.h" +#include "local-execution/sim_environment.h" #include "op-attrs/ops/reshape.h" -#include "sim_environment.h" -#include "task_spec/op_task_invocation.h" namespace FlexFlow { diff --git a/lib/runtime/src/ops/reverse.cc b/lib/local-execution/src/ops/reverse.cc similarity index 67% rename from lib/runtime/src/ops/reverse.cc rename to lib/local-execution/src/ops/reverse.cc index ac64146cd1..7fefb3d357 100644 --- a/lib/runtime/src/ops/reverse.cc +++ b/lib/local-execution/src/ops/reverse.cc @@ -19,23 +19,9 @@ #include "op-attrs/get_output_shapes.h" namespace FlexFlow { -// declare Legion names -using Legion::ArgumentMap; -using Legion::Context; -using Legion::coord_t; -using Legion::Domain; -using Legion::FutureMap; -using Legion::IndexLauncher; -using Legion::PhysicalRegion; -using Legion::Predicate; -using Legion::Rect; -using Legion::RegionRequirement; -using Legion::Runtime; -using Legion::Task; -using Legion::TaskArgument; -using Legion::TaskLauncher; using namespace FlexFlow::Kernels::Reverse; +using coord_t = long long; enum Slots { INPUT, OUTPUT, ATTRS, PROFILING }; @@ -43,7 +29,7 @@ OpTaskInvocation forward(ReverseAttrs const &attrs) { OpTaskBinding binding; binding.bind_arg(PROFILING, profiling_settings()); - bind.bind_arg(ATTRS, attrs); + 
binding.bind_arg(ATTRS, attrs);
 
   binding.bind(INPUT, input_tensor(0));
   binding.bind(OUTPUT, output_tensor(0));
@@ -56,28 +42,28 @@ OpTaskInvocation backward(ReverseAttrs const &attrs) {
   return {REVERSE_BWD_TASK_ID, binding};
 }
 
-static optional<float> forward_task_impl(TaskArgumentAccessor const &acc) {
+static std::optional<float> forward_task_impl(TaskArgumentAccessor const &acc) {
   ProfilingSettings profiling = acc.get_argument<ProfilingSettings>(PROFILING);
   auto input = acc.get_tensor<Permissions::RO>(INPUT);
   auto output = acc.get_tensor<Permissions::WO>(OUTPUT);
   auto attrs = acc.get_argument<ReverseAttrs>(ATTRS);
 
-  int output_size = outtput.shape.get_volume();
+  int output_size = output.shape.get_volume();
   auto axis = attrs.axis;
   coord_t in_blk_size = 1, reverse_dim_size = 1, num_out_blks = 1;
   for (int i = 0; i < output.shape.get_dim(); i++) {
     if (i < axis) {
-      in_blk_size *= output.shape[i];
+      in_blk_size *= output.shape.at(ff_dim_t(i));
     } else if (i == axis) {
-      reverse_dim_size = output.shape[i];
+      reverse_dim_size = output.shape.at(ff_dim_t(i));
     } else {
-      num_out_blks *= output.shape[i];
+      num_out_blks *= output.shape.at(ff_dim_t(i));
     }
   }
 
   return profile(forward_kernel,
                  profiling,
-                 "[reverse] forward_time = %.2lfms\n",
+                 "[reverse] forward_time = {:.2lf}ms\n",
                  input.get_float_ptr(),
                  output.get_float_ptr(),
                  num_out_blks,
@@ -86,49 +72,34 @@ static optional<float> forward_task_impl(TaskArgumentAccessor const &acc) {
                  output_size);
 }
 
-static void forward_task(Task const *task,
-                         std::vector<PhysicalRegion> const &regions,
-                         Context ctx,
-                         Runtime *runtime) {
-  TaskArgumentAccessor acc(task, regions, ctx, runtime);
-  forward_task_impl(acc);
-}
-
-static optional<float> backward_task_impl(TaskArgumentAccessor const &acc) {
+static std::optional<float>
+    backward_task_impl(TaskArgumentAccessor const &acc) {
   ProfilingSettings profiling = acc.get_argument<ProfilingSettings>(PROFILING);
   auto input_grad = acc.get_tensor_grad<Permissions::RW>(INPUT);
   auto output_grad = acc.get_tensor_grad<Permissions::RO>(OUTPUT);
   auto attrs = acc.get_argument<ReverseAttrs>(ATTRS);
 
-  int axis = input.shape.get_dim() - attrs.axis - 1;
+  int axis = input_grad.shape.get_dim() - attrs.axis.value() - 1;
   coord_t in_blk_size = 1, reverse_dim_size = 1, num_out_blks = 1;
   for (int i = 0; i < input_grad.shape.get_dim(); i++) {
     if (i < axis) {
-      in_blk_size *= input_grad.shape[i];
+      in_blk_size *= input_grad.shape.at(ff_dim_t(i));
     } else if (i == axis) {
-      reverse_dim_size = input_grad.shape[i];
+      reverse_dim_size = input_grad.shape.at(ff_dim_t(i));
     } else {
-      num_out_blks *= input_grad.shape[i];
+      num_out_blks *= input_grad.shape.at(ff_dim_t(i));
    }
  }
 
   return profile(backward_kernel,
                  profiling,
-                 "[reverse] backward_time = %.2lfms\n",
+                 "[reverse] backward_time = {:.2lf}ms\n",
                  output_grad.get_float_ptr(),
                  input_grad.get_float_ptr(),
                  num_out_blks,
                  reverse_dim_size,
                  in_blk_size,
-                 input.shape.get_volume());
-}
-
-static void backward_task(Task const *task,
-                          std::vector<PhysicalRegion> const &regions,
-                          Context ctx,
-                          Runtime *runtime) {
-  TaskArgumentAccessor acc(task, regions, ctx, runtime);
-  backward_task_impl(acc);
+                 input_grad.shape.get_volume());
 }
 
 CostMetrics measure_operator_cost(SimEnvFactory const &sim_factory,
@@ -136,7 +107,7 @@ CostMetrics measure_operator_cost(SimEnvFactory const &sim_factory,
                                   InputParallelTensorDesc const &input,
                                   ProfilingSettings const &settings,
                                   MachineView const &machine_view) {
-  auto env = sim.new_environment();
+  auto env = sim_factory.new_environment();
 
   SimTaskBinding fwd_binding;
 
@@ -161,21 +132,23 @@ CostMetrics measure_operator_cost(SimEnvFactory const &sim_factory,
 }
 
 template <>
-void register_task<REVERSE_FWD_TASK_ID>()) {
+void register_task<REVERSE_FWD_TASK_ID>() {
   OpTaskSignature fwd(OpTaskType::FWD);
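  // Illustrative aside (editorial, not part of the patch): OpTaskSignature
  // declares the slots a task expects, and the OpTaskBinding built in
  // forward() above supplies them -- e.g. binding.bind_arg(PROFILING,
  // profiling_settings()) is answered by the add_arg_slot(PROFILING) call
  // just below. A minimal sketch of that pairing for a unary op, assuming
  // the same API as used in this file:
  //
  //   OpTaskSignature sig(OpTaskType::FWD);
  //   sig.add_arg_slot(PROFILING);   // pairs with bind_arg(PROFILING, ...)
  //   sig.add_input_slot(INPUT);     // pairs with bind(INPUT, input_tensor(0))
  //   sig.add_output_slot(OUTPUT);   // pairs with bind(OUTPUT, output_tensor(0))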
fwd.add_arg_slot(PROFILING); fwd.add_input_slot(INPUT); fwd.add_output_slot(OUTPUT); - register_task(REVERSE_FWD_TASK_ID, "Reverse forward", fwd, forward_task); + register_task(REVERSE_FWD_TASK_ID, "Reverse forward", fwd, forward_task_impl); } -template <> -void register_task() { - OpTaskSignature bwd = - infer_bwd_signature(get_op_signature(REVERSE_BWD_TASK_ID)); - register_task(REVERSE_BWD_TASK_ID, "Reverse backward", bwd, backward_task); -} +// TODO: OpTaskSignature +// template <> +// void register_task() { +// OpTaskSignature bwd = +// infer_bwd_signature(get_op_signature(REVERSE_BWD_TASK_ID)); +// register_task(REVERSE_BWD_TASK_ID, "Reverse backward", bwd, +// backward_task_impl); +// } }; // namespace FlexFlow diff --git a/lib/runtime/src/ops/reverse.h b/lib/local-execution/src/ops/reverse.h similarity index 89% rename from lib/runtime/src/ops/reverse.h rename to lib/local-execution/src/ops/reverse.h index af4d335429..5be501698c 100644 --- a/lib/runtime/src/ops/reverse.h +++ b/lib/local-execution/src/ops/reverse.h @@ -1,9 +1,9 @@ #ifndef _FLEXFLOW_REVERSE_H_ #define _FLEXFLOW_REVERSE_H_ +#include "local-execution/op_task_invocation.h" +#include "local-execution/sim_environment.h" #include "op-attrs/ops/reverse.h" -#include "sim_environment.h" -#include "task_spec/op_task_invocation.h" namespace FlexFlow { diff --git a/lib/runtime/src/ops/softmax.cc b/lib/local-execution/src/ops/softmax.cc similarity index 68% rename from lib/runtime/src/ops/softmax.cc rename to lib/local-execution/src/ops/softmax.cc index b67f9730a4..ea857c680b 100644 --- a/lib/runtime/src/ops/softmax.cc +++ b/lib/local-execution/src/ops/softmax.cc @@ -17,26 +17,10 @@ #include "kernels/softmax_kernels.h" #include "op-attrs/get_output_shapes.h" #include "op-attrs/parallel_tensor_shape.h" -#include "utils/exceptions.h" +#include "utils/exception.h" #include "utils/hash-utils.h" namespace FlexFlow { -// declare Legion names -using Legion::ArgumentMap; -using Legion::Context; -using Legion::coord_t; -using Legion::Domain; -using Legion::FutureMap; -using Legion::IndexLauncher; -using Legion::PhysicalRegion; -using Legion::Predicate; -using Legion::Rect; -using Legion::RegionRequirement; -using Legion::Runtime; -using Legion::Task; -using Legion::TaskArgument; -using Legion::TaskLauncher; - using namespace FlexFlow::Kernels::Softmax; enum Slots { INPUT, OUTPUT, ATTRS, PROFILING, PER_DEVICE_STATE, HANDLE }; @@ -75,21 +59,11 @@ static DeviceSpecific auto const &attrs = acc.get_argument(ATTRS); DeviceSpecific per_device_state = - acc.create_device_specific( - init_kernel(handle, attrs.dim)); + init_kernel(handle, attrs.dim.value()); return per_device_state; } -static DeviceSpecific - init_task(Task const *task, - std::vector const ®ions, - Context ctx, - Runtime *runtime) { - TaskArgumentAccessor acc(task, regions, ctx, runtime); - return init_task_impl(acc); -} - -static optional forward_task_impl(TaskArgumentAccessor const &acc) { +static std::optional forward_task_impl(TaskArgumentAccessor const &acc) { auto input = acc.get_tensor(INPUT); auto output = acc.get_tensor(OUTPUT); ProfilingSettings profiling = acc.get_argument(PROFILING); @@ -98,21 +72,14 @@ static optional forward_task_impl(TaskArgumentAccessor const &acc) { return profile(forward_kernel, profiling, - "[SoftMax] forward_time = %.2lfms\n", + "[SoftMax] forward_time = {:.2lf}ms\n", per_device_state, input.get_float_ptr(), - output.get_float_ptr(), ); + output.get_float_ptr()); } -static void forward_task(Task const *task, - std::vector const ®ions, - Context 
ctx, - Runtime *runtime) { - TaskArgumentAccessor acc(task, regions, ctx, runtime); - forward_task_impl(acc); -} - -static optional backward_task_impl(TaskArgumentAccessor const &acc) { +static std::optional + backward_task_impl(TaskArgumentAccessor const &acc) { ProfilingSettings profiling = acc.get_argument(PROFILING); auto input_grad = acc.get_tensor_grad(INPUT); @@ -124,22 +91,12 @@ static optional backward_task_impl(TaskArgumentAccessor const &acc) { assert(output_grad.shape == output.shape); - return profile( - backward_kernel, - profiling, - "[SoftMax] backward_time = %.2lfms\n", - input_grad.get_float_ptr(), - output_grad.get_float_ptr(), - output_grad.shape.volume(), // Note(lambda): get num_elements, maybe wrong - ); -} - -static void backward_task(Task const *task, - std::vector const ®ions, - Context ctx, - Runtime *runtime) { - TaskArgumentAccessor acc(task, regions, ctx, runtime); - backward_task_impl(acc); + return profile(backward_kernel, + profiling, + "[SoftMax] backward_time = {:.2lf}ms\n", + input_grad.get_float_ptr(), + output_grad.get_float_ptr(), + output_grad.shape.get_volume()); } CostMetrics measure_operator_cost(SimEnvFactory const &sim_factory, @@ -147,7 +104,7 @@ CostMetrics measure_operator_cost(SimEnvFactory const &sim_factory, InputParallelTensorDesc const &input, ProfilingSettings const &settings, MachineView const &machine_view) { - auto env = sim.new_environment(); + auto env = sim_factory.new_environment(); ParallelTensorShape output_shape = get_output_shape(attrs, input.shape); @@ -162,7 +119,6 @@ CostMetrics measure_operator_cost(SimEnvFactory const &sim_factory, init_task_impl(init_accessor); SimTaskBinding fwd_binding; - ParallelTensorShape output_shape = get_output_shape(attrs, input.shape); fwd_binding.bind(INPUT, input.shape); fwd_binding.bind(OUTPUT, output_shape); fwd_binding.bind_arg(PROFILING, settings); @@ -186,9 +142,9 @@ void register_task() { init.add_unchecked_arg_slot(HANDLE); init.add_arg_slot(ATTRS); - init.add_return_value_slot(); + init.add_return_value(); - register_task(SOFTMAX_INIT_TASK_ID, "SoftMax Init", init, init_task); + register_task(SOFTMAX_INIT_TASK_ID, "SoftMax Init", init, init_task_impl); } template <> @@ -201,15 +157,17 @@ void register_task() { fwd.add_input_slot(INPUT); fwd.add_output_slot(OUTPUT); - register_task(SOFTMAX_FWD_TASK_ID, "SoftMax Fwd", fwd, forward_task); + register_task(SOFTMAX_FWD_TASK_ID, "SoftMax Fwd", fwd, forward_task_impl); } -template <> -void register_task() { - OpTaskSignature bwd = - infer_bwd_signature(get_op_signature(SOFTMAX_FWD_TASK_ID)); +// TODO: OpTaskSignature - register_task(SOFTMAX_BWD_TASK_ID, "SoftMax Bwd", bwd, backward_task); -} +// template <> +// void register_task() { +// OpTaskSignature bwd = +// infer_bwd_signature(get_op_signature(SOFTMAX_FWD_TASK_ID)); + +// register_task(SOFTMAX_BWD_TASK_ID, "SoftMax Bwd", bwd, backward_task_impl); +// } }; // namespace FlexFlow diff --git a/lib/runtime/src/ops/softmax.h b/lib/local-execution/src/ops/softmax.h similarity index 97% rename from lib/runtime/src/ops/softmax.h rename to lib/local-execution/src/ops/softmax.h index 06b9d09d60..a83d8f4116 100644 --- a/lib/runtime/src/ops/softmax.h +++ b/lib/local-execution/src/ops/softmax.h @@ -1,9 +1,9 @@ #ifndef _FLEXFLOW_SOFTMAX_H #define _FLEXFLOW_SOFTMAX_H +#include "local-execution/op_task_invocation.h" +#include "local-execution/sim_environment.h" #include "op-attrs/ops/softmax.h" -#include "sim_environment.h" -#include "task_spec/op_task_invocation.h" namespace FlexFlow { diff --git 
a/lib/runtime/src/ops/split.cc b/lib/local-execution/src/ops/split.cc
similarity index 58%
rename from lib/runtime/src/ops/split.cc
rename to lib/local-execution/src/ops/split.cc
index 2af5d42874..13e95d37f9 100644
--- a/lib/runtime/src/ops/split.cc
+++ b/lib/local-execution/src/ops/split.cc
@@ -16,28 +16,14 @@
 #include "split.h"
 #include "kernels/array_shape.h"
 #include "kernels/split_kernels.h"
-#include "utils/exceptions.h"
+#include "op-attrs/get_output_shapes.h"
+#include "utils/exception.h"
 #include "utils/hash-utils.h"
 
 namespace FlexFlow {
-// declare Legion names
-using Legion::ArgumentMap;
-using Legion::Context;
-using Legion::coord_t;
-using Legion::Domain;
-using Legion::FutureMap;
-using Legion::IndexLauncher;
-using Legion::PhysicalRegion;
-using Legion::Predicate;
-using Legion::Rect;
-using Legion::RegionRequirement;
-using Legion::Runtime;
-using Legion::Task;
-using Legion::TaskArgument;
-using Legion::TaskLauncher;
-using PCG::Node;
 
 using namespace FlexFlow::Kernels::Split;
+using coord_t = long long;
 
 enum Slots { INPUT, OUTPUT, ATTRS, PROFILING };
 
@@ -58,96 +44,86 @@ OpTaskInvocation backward(SplitAttrs const &attrs) {
   return {SPLIT_BWD_TASK_ID, binding};
 }
 
-static optional<float> forward_task_impl(TaskArgumentAccessor const &acc) {
+void calc_block_size(coord_t &num_blocks,
+                     coord_t &block_size,
+                     ArrayShape const &array_shape,
+                     int axis) {
+  num_blocks = 1;
+  block_size = 1;
+  for (int d = 0; d < array_shape.num_elements(); d++) {
+    if (d <= axis) {
+      block_size *= array_shape.at(legion_dim_t(d));
+    } else {
+      num_blocks *= array_shape.at(legion_dim_t(d));
+    }
+  }
+}
+
+static std::optional<float> forward_task_impl(TaskArgumentAccessor const &acc) {
  ProfilingSettings profiling = acc.get_argument<ProfilingSettings>(PROFILING);
   auto input = acc.get_tensor<Permissions::RO>(INPUT);
   auto output = acc.get_tensor<Permissions::WO>(OUTPUT);
   auto attrs = acc.get_argument<SplitAttrs>(ATTRS);
 
-  coord_t num_blks, in_blk_size, out_blk_size[MAX_NUM_OUTPUTS];
-  calc_block_size(num_blks, in_blk_size, input.shape, attrs.axis);
+  coord_t num_blocks, in_block_size, out_block_size[MAX_NUM_OUTPUTS];
+  calc_block_size(num_blocks, in_block_size, input.shape, attrs.axis.value());
   for (int i = 0; i < attrs.splits.size(); i++) {
-    coord_t out_num_blks;
+    coord_t out_num_blocks;
     calc_block_size(
-        out_num_blks, out_blk_size[i], output.shape, split->legion_axis);
+        out_num_blocks, out_block_size[i], output.shape, attrs.axis.value());
   }
+  float *output_float_ptr = output.get_float_ptr();
   return profile(forward_kernel,
                  profiling,
-                 "Split forward_time = %.2lfms\n",
-                 &output.get_float_ptr(),
+                 "Split forward_time = {:.2lf}ms\n",
+                 &output_float_ptr,
                  input.get_float_ptr(),
-                 out_blk_size,
-                 in_blk_size,
-                 num_blks,
+                 out_block_size,
+                 in_block_size,
+                 num_blocks,
                  attrs.splits.size());
 }
 
-static void forward_task(Task const *task,
-                         std::vector<PhysicalRegion> const &regions,
-                         Context ctx,
-                         Runtime *runtime) {
-  TaskArgumentAccessor acc(task, regions, ctx, runtime);
-  forward_task_impl(acc);
-}
-
 // maybe we should add assert like the original code
-static optional<float> backward_task_impl(TaskArgumentAccessor const &acc) {
+static std::optional<float>
+    backward_task_impl(TaskArgumentAccessor const &acc) {
   ProfilingSettings profiling = acc.get_argument<ProfilingSettings>(PROFILING);
   auto input_grad = acc.get_tensor_grad<Permissions::RW>(INPUT);
   auto output_grad = acc.get_tensor_grad<Permissions::RO>(OUTPUT);
   auto attrs = acc.get_argument<SplitAttrs>(ATTRS);
 
-  coord_t num_blks, in_blk_size, out_blk_size[MAX_NUM_OUTPUTS];
-  calc_block_size(num_blks, in_blk_size, input_grade.shape, attrs.axis);
+  coord_t num_blocks, in_block_size, out_block_size[MAX_NUM_OUTPUTS];
+
calc_block_size( + num_blocks, in_block_size, input_grad.shape, attrs.axis.value()); for (int i = 0; i < attrs.splits.size(); i++) { - coord_t out_num_blks; - calc_block_size( - out_num_blks, out_blk_size[i], output_grad.shape, split->legion_axis); + coord_t out_num_blocks; + calc_block_size(out_num_blocks, + out_block_size[i], + output_grad.shape, + attrs.axis.value()); } + float const *output_grad_ptr = output_grad.get_float_ptr(); return profile(backward_kernel, profiling, - "Split backward_time = %.2lfms\n", + "Split backward_time = {:.2lf}ms\n", input_grad.get_float_ptr(), - &output_grad.get_float_ptr(), - out_blk_size, - in_blk_size, - num_blks, + &output_grad_ptr, + out_block_size, + in_block_size, + num_blocks, attrs.splits.size()); } -void calc_block_size(coord_t &num_blks, - coord_t &blk_size, - ArrayShape const &array_shape, - int axis) { - num_blks = 1; - blk_size = 1; - for (int d = 0; d < array_shape.get_dim(); d++) { - if (d <= axis) { - blk_size *= (domain.hi()[d] - domain.lo()[d] + 1); - blk_size *= array_shape.at(legion_dim_t(d)) + 1 - } else { - num_blks *= array_shape.at(legion_dim_t(d)) + 1 - } - } -} - -static void backward_task(Task const *task, - std::vector const ®ions, - Context ctx, - Runtime *runtime) { - TaskArgumentAccessor acc(task, regions, ctx, runtime); - backward_task_impl(acc); -} - CostMetrics measure_operator_cost(SimEnvFactory const &sim_factory, SplitAttrs const &attrs, InputParallelTensorDesc const &input, ProfilingSettings const &settings, MachineView const &machine_view) { - auto env = sim.new_environment(); + auto env = sim_factory.new_environment(); - ParallelTensorShape output_shape = get_output_shape(attrs, input.shape); + std::vector output_shape = + get_output_shapes(attrs, input.shape); SimTaskBinding fwd_binding; fwd_binding.bind(INPUT, input.shape); @@ -166,6 +142,8 @@ CostMetrics measure_operator_cost(SimEnvFactory const &sim_factory, return make_metrics(forward_time, backward_time, sync_time, env); } +// TODO: OpTaskSignature + template <> void register_task() { OpTaskSignature fwd(OpTaskType::FWD); @@ -174,15 +152,15 @@ void register_task() { fwd.add_input_slot(INPUT); fwd.add_output_slot(OUTPUT); - register_task(SPLIT_FWD_TASK_ID, "Split Fwd", fwd, forward_task); + register_task(SPLIT_FWD_TASK_ID, "Split Fwd", fwd, forward_task_impl); } -template <> -void register_task() { - OpTaskSignature bwd = - infer_bwd_signature(get_op_signature(SPLIT_FWD_TASK_ID)); +// template <> +// void register_task() { +// OpTaskSignature bwd = +// infer_bwd_signature(get_op_signature(SPLIT_FWD_TASK_ID)); - register_task(SPLIT_BWD_TASK_ID, "Split Bwd", bwd, backward_task); -} +// register_task(SPLIT_BWD_TASK_ID, "Split Bwd", bwd, backward_task_impl); +// } }; // namespace FlexFlow diff --git a/lib/runtime/src/ops/split.h b/lib/local-execution/src/ops/split.h similarity index 93% rename from lib/runtime/src/ops/split.h rename to lib/local-execution/src/ops/split.h index d63212e836..f51e0ea6af 100644 --- a/lib/runtime/src/ops/split.h +++ b/lib/local-execution/src/ops/split.h @@ -1,9 +1,9 @@ #ifndef _FLEXFLOW_SPLIT_H #define _FLEXFLOW_SPLIT_H +#include "local-execution/op_task_invocation.h" +#include "local-execution/sim_environment.h" #include "op-attrs/ops/split.h" -#include "sim_environment.h" -#include "task_spec/op_task_invocation.h" namespace FlexFlow { @@ -20,7 +20,7 @@ OpTaskInvocation backward(SplitAttrs const &); CostMetrics measure_operator_cost(SimEnvFactory const &sim_factory, SplitAttrs const &attrs, - InputParallelTensorDes const &input, + 
InputParallelTensorDesc const &input, ProfilingSettings const &settings, MachineView const &machine_view); diff --git a/lib/runtime/src/ops/topk.cc b/lib/local-execution/src/ops/topk.cc similarity index 60% rename from lib/runtime/src/ops/topk.cc rename to lib/local-execution/src/ops/topk.cc index 958516a6d9..8aceb9c6d4 100644 --- a/lib/runtime/src/ops/topk.cc +++ b/lib/local-execution/src/ops/topk.cc @@ -16,28 +16,9 @@ #include "topk.h" #include "kernels/topk_kernels.h" #include "op-attrs/get_output_shapes.h" -#include "utils/exceptions.h" +#include "utils/exception.h" namespace FlexFlow { -// declare Legion names -using Legion::ArgumentMap; -using Legion::Context; -using Legion::coord_t; -using Legion::Domain; -using Legion::FutureMap; -using Legion::IndexLauncher; -using Legion::InlineLauncher; -using Legion::Machine; -using Legion::Memory; -using Legion::PhysicalRegion; -using Legion::Predicate; -using Legion::Rect; -using Legion::RegionRequirement; -using Legion::Runtime; -using Legion::Task; -using Legion::TaskArgument; -using Legion::TaskLauncher; -using PCG::Node; using namespace FlexFlow::Kernels::TopK; @@ -50,7 +31,7 @@ enum Slots { INPUT, OUTPUT, INDICES, ATTRS, PROFILING, PER_DEVICE_STATE }; OpTaskInvocation init(TopKAttrs const &attrs) { OpTaskBinding binding; - bind.bind_arg(ATTRS, attrs); + binding.bind_arg(ATTRS, attrs); return {TOPK_INIT_TASK_ID, binding}; } @@ -60,7 +41,7 @@ OpTaskInvocation forward(TopKAttrs const &attrs) { binding.bind_arg(PER_DEVICE_STATE, per_device_op_state()); binding.bind_arg(PROFILING, profiling_settings()); - bind.bind_arg(ATTRS, attrs); + binding.bind_arg(ATTRS, attrs); binding.bind(INPUT, input_tensor(0)); binding.bind(OUTPUT, output_tensor(0)); @@ -81,23 +62,14 @@ static DeviceSpecific auto attrs = acc.get_argument(ATTRS); DeviceSpecific per_device_state = - acc.create_device_specific(init_kernel(attrs.sorted)); + init_kernel(attrs.sorted); return per_device_state; } -static DeviceSpecific - init_task(Task const *task, - std::vector const ®ions, - Context ctx, - Runtime *runtime) { - TaskArgumentAccessor acc(task, regions, ctx, runtime); - return init_task_impl(acc); -} - -static optional forward_task_impl(TaskArgumentAccessor const &acc) { +static std::optional forward_task_impl(TaskArgumentAccessor const &acc) { auto attrs = acc.get_argument(ATTRS); auto per_device_state = - acc.get_device_specific(PER_DEVICE_STATE); + acc.get_argument(PER_DEVICE_STATE); auto profiling = acc.get_argument(PROFILING); auto input = acc.get_tensor(INPUT); @@ -107,31 +79,24 @@ static optional forward_task_impl(TaskArgumentAccessor const &acc) { size_t batch_size = input.shape.get_volume() / length; auto indices = acc.get_tensor(INDICES); - return profiling(forward_kernel, - profiling, - "[TopK] forward_time = %.2lfms\n", - per_device_state, - input.get_float_ptr(), - output.get_float_ptr(), - indices.get_int32_ptr(), - batch_size, - length, - attrs.k, - attrs.sorted); + return profile(forward_kernel, + profiling, + "[TopK] forward_time = {:.2lf}ms\n", + per_device_state, + input.get_float_ptr(), + output.get_float_ptr(), + indices.get_int32_ptr(), + batch_size, + length, + attrs.k, + attrs.sorted); } -static void forward_task(Task const *task, - std::vector const ®ions, - Context ctx, - Runtime *runtime) { - TaskArgumentAccessor acc(task, regions, ctx, runtime); - forward_task_impl(acc); -} - -static optional backward_task_impl(TaskArgumentAccessor const &acc) { +static std::optional + backward_task_impl(TaskArgumentAccessor const &acc) { auto attrs = 
acc.get_argument(ATTRS); auto per_device_state = - acc.get_device_specific(PER_DEVICE_STATE); + acc.get_argument(PER_DEVICE_STATE); auto profiling = acc.get_argument(PROFILING); auto input_grad = acc.get_tensor_grad(INPUT); @@ -139,27 +104,19 @@ static optional backward_task_impl(TaskArgumentAccessor const &acc) { auto indices = acc.get_tensor(INDICES); - int length = input.shape.at(legion_dim_t(0)) + 1; - size_t batch_size = input.shape.get_volume() / length; - - return profiling(backward_kernel, - profiling, - "[TopK] backward_time = %.2lfms\n", - per_device_state, - output_grad.get_float_ptr(), - indices.get_int32_ptr(), - input_grad.get_float_ptr(), - batch_size, - length, - attrs.k); -} - -static void backward_task(Task const *task, - std::vector const ®ions, - Context ctx, - Runtime *runtime) { - TaskArgumentAccessor acc(task, regions, ctx, runtime); - backward_task_impl(acc); + int length = input_grad.shape.at(legion_dim_t(0)) + 1; + size_t batch_size = input_grad.shape.get_volume() / length; + + return profile(backward_kernel, + profiling, + "[TopK] backward_time = {:.2lf}ms\n", + per_device_state, + output_grad.get_float_ptr(), + indices.get_int32_ptr(), + input_grad.get_float_ptr(), + batch_size, + length, + attrs.k); } CostMetrics measure_operator_cost(SimEnvFactory const &sim_factory, @@ -167,9 +124,9 @@ CostMetrics measure_operator_cost(SimEnvFactory const &sim_factory, InputParallelTensorDesc const &input, ProfilingSettings const &settings, MachineView const &machine_view) { - auto env = sim.new_environment(); + auto env = sim_factory.new_environment(); - ParallelTensorShape output_shape = get_output_shapes(attrs, input.shape); + ParallelTensorShape output_shape = get_output_shape(attrs, input.shape); SimTaskBinding init_binding; init_binding.bind_arg(ATTRS, attrs); @@ -204,7 +161,7 @@ void register_task() { init.add_arg_slot(ATTRS); // Note: this may have some question init.add_return_value(); - register_task(TOPK_INIT_TASK_ID, "Topk Init", init, init_task); + register_task(TOPK_INIT_TASK_ID, "Topk Init", init, init_task_impl); } template <> @@ -219,14 +176,17 @@ void register_task() { fwd.add_output_slot(OUTPUT); fwd.add_output_slot(INDICES); - register_task(TOPK_FWD_TASK_ID, "TopK Forward", fwd, forward_task); + register_task(TOPK_FWD_TASK_ID, "TopK Forward", fwd, forward_task_impl); } -template <> -void register_task() { - OpTaskSignature bwd = infer_bwd_signature(get_op_signature(TOPK_FWD_TASK_ID)); +// TODO: OpTaskSignature - register_task(TOPK_BWD_TASK_ID, "TopK Backward", bwd, backward_task); -} +// template <> +// void register_task() { +// OpTaskSignature bwd = +// infer_bwd_signature(get_op_signature(TOPK_FWD_TASK_ID)); + +// register_task(TOPK_BWD_TASK_ID, "TopK Backward", bwd, backward_task_impl); +// } }; // namespace FlexFlow diff --git a/lib/runtime/src/ops/topk.h b/lib/local-execution/src/ops/topk.h similarity index 97% rename from lib/runtime/src/ops/topk.h rename to lib/local-execution/src/ops/topk.h index f15ff6de81..db85fd9d03 100644 --- a/lib/runtime/src/ops/topk.h +++ b/lib/local-execution/src/ops/topk.h @@ -1,9 +1,9 @@ #ifndef _FLEXFLOW_TOPK_H_ #define _FLEXFLOW_TOPK_H_ +#include "local-execution/op_task_invocation.h" +#include "local-execution/sim_environment.h" #include "op-attrs/ops/topk.h" -#include "sim_environment.h" -#include "task_spec/op_task_invocation.h" namespace FlexFlow { diff --git a/lib/runtime/src/ops/transpose.cc b/lib/local-execution/src/ops/transpose.cc similarity index 55% rename from lib/runtime/src/ops/transpose.cc rename to 
lib/local-execution/src/ops/transpose.cc index ea6182772f..c998484455 100644 --- a/lib/runtime/src/ops/transpose.cc +++ b/lib/local-execution/src/ops/transpose.cc @@ -15,27 +15,10 @@ #include "transpose.h" #include "kernels/transpose_kernels.h" -#include "legion/legion_utilities.h" +#include "op-attrs/get_output_shapes.h" #include "op-attrs/ops/transpose.h" #include "utils/exception.decl.h" -namespace FlexFlow { -// declare Legion names -using Legion::ArgumentMap; -using Legion::Context; -using Legion::coord_t; -using Legion::Domain; -using Legion::FutureMap; -using Legion::IndexLauncher; -using Legion::PhysicalRegion; -using Legion::Predicate; -using Legion::Rect; -using Legion::RegionRequirement; -using Legion::Runtime; -using Legion::Task; -using Legion::TaskArgument; -using Legion::TaskLauncher; - using namespace FlexFlow::Kernels::Transpose; namespace FlexFlow { @@ -57,33 +40,26 @@ OpTaskInvocation init(TransposeAttrs const &attrs) { static DeviceSpecific init_task_impl(TaskArgumentAccessor const &acc) { auto const &attrs = acc.get_argument(ATTRS); - std::vector perm = attrs.perm; // default convert stack_vector to vector + std::vector perm = static_cast>(attrs.perm); DeviceSpecific per_device_state = - acc.create_device_specific( - init_kernel(perm.size(), perm)); + init_kernel(perm.size(), perm); return per_device_state; } -static DeviceSpecific - init_task(Task const *task, - std::vector const ®ions, - Context ctx, - Runtime *runtime) { - TaskArgumentAccessor acc(task, regions, ctx, runtime); - return init_task_impl(acc); -} +// TODO: OpTaskSignature -template <> -void register_task(); -OpTaskSignature init(OpTaskType::INIT) +// template <> +// void register_task() { +// OpTaskSignature init(OpTaskType::INIT); - init.add_arg_slot(ATTRS); +// init.add_arg_slot(ATTRS); -init.add_return_value(); +// init.add_return_value(); -register_task(TRANSPOSE_INIT_TASK_ID, "Transpose::init", init, init_task); -} // namespace FlexFlow +// register_task(TRANSPOSE_INIT_TASK_ID, "Transpose::init", init, +// init_task_impl); +// } OpTaskInvocation forward(TransposeAttrs const &attrs) { OpTaskBinding binding; @@ -92,13 +68,13 @@ OpTaskInvocation forward(TransposeAttrs const &attrs) { per_device_op_state()); binding.bind_arg(PROFILING, profiling_settings()); - bind.bind(INPUT, input_tensor(0)); - bind.bind(OUTPUT, output_tensor(0)); + binding.bind(INPUT, input_tensor(0)); + binding.bind(OUTPUT, output_tensor(0)); return {TRANSPOSE_FWD_TASK_ID, binding}; } -static optional forward_task_impl(TaskArgumentAccessor const &acc) { +static std::optional forward_task_impl(TaskArgumentAccessor const &acc) { ProfilingSettings profiling = acc.get_argument(PROFILING); auto per_device_state = acc.get_argument(PER_DEVICE_STATE); @@ -106,47 +82,32 @@ static optional forward_task_impl(TaskArgumentAccessor const &acc) { auto input = acc.get_tensor(INPUT); auto output = acc.get_tensor(OUTPUT); - return profiling(forward_kernel, - profiling, - "[Transpose] Forward_time = %.2lf [ms]", - per_device_state, - input, - output); + return profile(forward_kernel, + profiling, + "[Transpose] Forward_time = {:.2lf} [ms]", + per_device_state, + input, + output); } -static void forward_task(Task const *task, - std::vector const ®ions, - Context ctx, - Runtime *runtime) { - TaskArgumentAccessor acc(task, regions, ctx, runtime); - forward_task_impl(acc); -} - -static optional backward_task_impl(TaskArgumentAccessor const &acc) { +static std::optional + backward_task_impl(TaskArgumentAccessor const &acc) { ProfilingSettings profiling = 
acc.get_argument(PROFILING); auto per_device_state = - acc.get_per_device_state(PER_DEVICE_STATE); + acc.get_argument(PER_DEVICE_STATE); - auto input_grad = acc.get_tensor_grad(INPUT); - auto output_grad = acc.get_tensor_grad(OUTPUT); + auto input_grad = acc.get_tensor_grad(INPUT); + auto output_grad = acc.get_tensor_grad(OUTPUT); - return profiling(backward_kernel, - profiling, - "[Transpose] Backward_time = %.2lf [ms]", - per_device_state, - input_grad, - output_grad); + return profile(backward_kernel, + profiling, + "[Transpose] Backward_time = {:.2lf} [ms]", + per_device_state, + input_grad, + output_grad); } -static void backward_task(Task const *task, - std::vector const ®ions, - Context ctx, - Runtime *runtime) { - TaskArgumentAccessor acc(task, regions, ctx, runtime); - backward_task_impl(acc); -} - -OpTaskInvocation backward(TransposeAttrs const &) { +OpTaskInvocation backward(TransposeAttrs const &attrs) { OpTaskBinding binding = infer_bwd_binding(forward(attrs).binding); return {TRANSPOSE_BWD_TASK_ID, binding}; @@ -159,7 +120,7 @@ CostMetrics &input_descs, // Note:this may have some problem ProfilingSettings const &settings, MachineView const &machine_view) { - auto env = sim.new_environment(); + auto env = sim_factory.new_environment(); SimTaskBinding init_binding; init_binding.bind_arg(ATTRS, attrs); @@ -169,12 +130,13 @@ CostMetrics DeviceSpecific per_device_state = init_task_impl(init_accessor); - ParallelTensorShape output_shape = get_output_shape(attrs, input_descs.shape); + ParallelTensorShape output_shape = + get_output_shape(attrs, input_descs.shapes); SimTaskBinding fwd_binding; fwd_binding.bind_arg(PER_DEVICE_STATE, per_device_state); fwd_binding.bind_arg(PROFILING, settings); - fwd_binding.bind(INPUT, input_descs.shape); + fwd_binding.bind(INPUT, input_descs.shapes); fwd_binding.bind(OUTPUT, output_shape); auto fwd_accessor = env.get_fwd_accessor(TRANSPOSE_FWD_TASK_ID, fwd_binding); @@ -189,4 +151,4 @@ CostMetrics return make_metrics(forward_time, backward_time, sync_time, env); } -}; // namespace FlexFlow +} // namespace FlexFlow diff --git a/lib/runtime/src/ops/transpose.h b/lib/local-execution/src/ops/transpose.h similarity index 97% rename from lib/runtime/src/ops/transpose.h rename to lib/local-execution/src/ops/transpose.h index 52e824ebbf..daa64e8e59 100644 --- a/lib/runtime/src/ops/transpose.h +++ b/lib/local-execution/src/ops/transpose.h @@ -1,9 +1,9 @@ #ifndef _FLEXFLOW_TRANSPOSE_H_ #define _FLEXFLOW_TRANSPOSE_H_ +#include "local-execution/op_task_invocation.h" +#include "local-execution/sim_environment.h" #include "op-attrs/ops/transpose.h" -#include "sim_environment.h" -#include "task_spec/op_task_invocation.h" namespace FlexFlow { diff --git a/lib/runtime/src/permissions.cc b/lib/local-execution/src/permissions.cc similarity index 66% rename from lib/runtime/src/permissions.cc rename to lib/local-execution/src/permissions.cc index 2992780ae1..e5c46b42f8 100644 --- a/lib/runtime/src/permissions.cc +++ b/lib/local-execution/src/permissions.cc @@ -1,38 +1,8 @@ -#include "permissions.h" +#include "local-execution/permissions.h" #include "utils/exception.h" namespace FlexFlow { -Legion::PrivilegeMode to_legion(Permissions p) { - switch (p) { - case Permissions::NONE: - return LEGION_NO_ACCESS; - case Permissions::RO: - return LEGION_READ_ONLY; - case Permissions::WO: - return LEGION_WRITE_ONLY; - case Permissions::RW: - return LEGION_READ_WRITE; - default: - throw mk_runtime_error("Unknown permission {}", static_cast(p)); - } -} - -optional 
from_legion(Legion::PrivilegeMode p) {
-  switch (p) {
-    case LEGION_NO_ACCESS:
-      return Permissions::NONE;
-    case LEGION_READ_ONLY:
-      return Permissions::RO;
-    case LEGION_WRITE_ONLY:
-      return Permissions::WO;
-    case LEGION_READ_WRITE:
-      return Permissions::RW;
-    default:
-      return nullopt;
-  }
-}
-
 Permissions join(Permissions lhs, Permissions rhs) {
   if (lhs <= rhs) {
     return rhs;
diff --git a/lib/runtime/src/task_spec/runtime_arg_ref.cc b/lib/local-execution/src/runtime_arg_ref.cc
similarity index 55%
rename from lib/runtime/src/task_spec/runtime_arg_ref.cc
rename to lib/local-execution/src/runtime_arg_ref.cc
index a0aa242ce6..df4f024f1d 100644
--- a/lib/runtime/src/task_spec/runtime_arg_ref.cc
+++ b/lib/local-execution/src/runtime_arg_ref.cc
@@ -1,5 +1,5 @@
-#include "runtime_arg_ref.h"
-#include "device_specific.h"
+#include "local-execution/runtime_arg_ref.h"
+#include "local-execution/device_specific.h"
 
 namespace FlexFlow {
 
@@ -11,4 +11,8 @@ RuntimeArgRef<DeviceSpecific<PerDeviceFFHandle>> ff_handle() {
   return {RuntimeArgRefType::FF_HANDLE};
 }
 
+RuntimeArgRef<DeviceSpecific<FFIterationConfig>> iteration_config() {
+  return {RuntimeArgRefType::FF_ITERATION_CONFIG};
+}
+
 } // namespace FlexFlow
diff --git a/lib/local-execution/src/tracked_allocator.cc b/lib/local-execution/src/tracked_allocator.cc
index 6d06714252..68636906c3 100644
--- a/lib/local-execution/src/tracked_allocator.cc
+++ b/lib/local-execution/src/tracked_allocator.cc
@@ -1,4 +1,4 @@
-#include "tracked_allocator.h"
+#include "local-execution/tracked_allocator.h"
 #include "kernels/device.h"
 
 namespace FlexFlow {
diff --git a/lib/local-execution/src/variadic_tensor_ref.cc b/lib/local-execution/src/variadic_tensor_ref.cc
new file mode 100644
index 0000000000..efd43a6648
--- /dev/null
+++ b/lib/local-execution/src/variadic_tensor_ref.cc
@@ -0,0 +1,9 @@
+#include "local-execution/variadic_tensor_ref.h"
+
+namespace FlexFlow {
+
+VariadicTensorRef<OpTensorSpec> get_input_tensors() {
+  return {VariadicTensorRefType::INPUT_TENSORS};
+}
+
+} // namespace FlexFlow
diff --git a/lib/op-attrs/CMakeLists.txt b/lib/op-attrs/CMakeLists.txt
index 778be53d7c..9a9721ef2d 100644
--- a/lib/op-attrs/CMakeLists.txt
+++ b/lib/op-attrs/CMakeLists.txt
@@ -12,3 +12,4 @@ ff_add_library(
 )
 
 add_subdirectory(ffi)
+add_subdirectory(test)
diff --git a/lib/op-attrs/include/op-attrs/activation.dtg.h b/lib/op-attrs/include/op-attrs/activation.dtg.h
new file mode 100644
index 0000000000..a4c0e97882
--- /dev/null
+++ b/lib/op-attrs/include/op-attrs/activation.dtg.h
@@ -0,0 +1,40 @@
+// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT!
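// Illustrative aside (editorial, not generated output): with the "fmt" and
// "json" features requested in activation.enum.toml, the declarations below
// are typically consumed like this -- a sketch assuming only what this
// header itself declares:
//
//   FlexFlow::Activation a = FlexFlow::Activation::GELU;
//   std::string s = format_as(a);                  // declared below
//   std::cout << a;                                // via operator<<
//   nlohmann::json j; to_json(j, a);               // serialize
//   FlexFlow::Activation b; from_json(j, b);       // deserialize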
+// If you would like to modify this datatype, instead modify +// lib/op-attrs/include/op-attrs/activation.enum.toml +/* proj-data +{ + "generated_from": "2b0d2e3e825732838aa5be99f2f0e6df" +} +*/ + +#ifndef _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_ACTIVATION_DTG_H +#define _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_ACTIVATION_DTG_H + +#include "fmt/format.h" +#include "nlohmann/json.hpp" +#include "rapidcheck.h" +#include +#include +#include + +namespace FlexFlow { +enum class Activation { RELU, SIGMOID, TANH, GELU }; +std::string format_as(Activation); +std::ostream &operator<<(std::ostream &, Activation); +void to_json(::nlohmann::json &, Activation); +void from_json(::nlohmann::json const &, Activation &); +} // namespace FlexFlow +namespace std { +template <> +struct hash { + size_t operator()(FlexFlow::Activation) const; +}; +} // namespace std +namespace rc { +template <> +struct Arbitrary { + static Gen arbitrary(); +}; +} // namespace rc + +#endif // _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_ACTIVATION_DTG_H diff --git a/lib/op-attrs/include/op-attrs/activation.enum.toml b/lib/op-attrs/include/op-attrs/activation.enum.toml new file mode 100644 index 0000000000..66119da9b1 --- /dev/null +++ b/lib/op-attrs/include/op-attrs/activation.enum.toml @@ -0,0 +1,20 @@ +namespace = "FlexFlow" +name = "Activation" +features = [ + "hash", + "json", + "rapidcheck", + "fmt", +] + +[[values]] +name = "RELU" + +[[values]] +name = "SIGMOID" + +[[values]] +name = "TANH" + +[[values]] +name = "GELU" diff --git a/lib/op-attrs/include/op-attrs/activation.h b/lib/op-attrs/include/op-attrs/activation.h deleted file mode 100644 index 8fa07825fd..0000000000 --- a/lib/op-attrs/include/op-attrs/activation.h +++ /dev/null @@ -1,42 +0,0 @@ -#ifndef _FLEXFLOW_OPATTRS_INCLUDE_OPATTRS_ACTIVATION_H -#define _FLEXFLOW_OPATTRS_INCLUDE_OPATTRS_ACTIVATION_H - -#include "utils/fmt.h" - -namespace FlexFlow { - -enum class Activation { RELU, SIGMOID, TANH, GELU }; - -} - -namespace fmt { - -template <> -struct formatter<::FlexFlow::Activation> : formatter { - template - auto format(::FlexFlow::Activation a, FormatContext &ctx) const - -> decltype(ctx.out()) { - using namespace FlexFlow; - - string_view name = "unknown"; - switch (a) { - case Activation::RELU: - name = "ReLU"; - break; - case Activation::SIGMOID: - name = "Sigmoid"; - break; - case Activation::TANH: - name = "Tanh"; - break; - case Activation::GELU: - name = "GeLU"; - break; - } - return formatter::format(name, ctx); - } -}; - -} // namespace fmt - -#endif diff --git a/lib/op-attrs/include/op-attrs/aggregate_op.dtg.h b/lib/op-attrs/include/op-attrs/aggregate_op.dtg.h new file mode 100644 index 0000000000..3ff3848dca --- /dev/null +++ b/lib/op-attrs/include/op-attrs/aggregate_op.dtg.h @@ -0,0 +1,40 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! 
+// If you would like to modify this datatype, instead modify +// lib/op-attrs/include/op-attrs/aggregate_op.enum.toml +/* proj-data +{ + "generated_from": "441fe9b0bb8f2dc2b31f74c58320ef30" +} +*/ + +#ifndef _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_AGGREGATE_OP_DTG_H +#define _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_AGGREGATE_OP_DTG_H + +#include "fmt/format.h" +#include "nlohmann/json.hpp" +#include "rapidcheck.h" +#include +#include +#include + +namespace FlexFlow { +enum class AggregateOp { SUM }; +std::string format_as(AggregateOp); +std::ostream &operator<<(std::ostream &, AggregateOp); +void to_json(::nlohmann::json &, AggregateOp); +void from_json(::nlohmann::json const &, AggregateOp &); +} // namespace FlexFlow +namespace std { +template <> +struct hash { + size_t operator()(FlexFlow::AggregateOp) const; +}; +} // namespace std +namespace rc { +template <> +struct Arbitrary { + static Gen arbitrary(); +}; +} // namespace rc + +#endif // _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_AGGREGATE_OP_DTG_H diff --git a/lib/op-attrs/include/op-attrs/aggregate_op.enum.toml b/lib/op-attrs/include/op-attrs/aggregate_op.enum.toml new file mode 100644 index 0000000000..27aa50f38f --- /dev/null +++ b/lib/op-attrs/include/op-attrs/aggregate_op.enum.toml @@ -0,0 +1,14 @@ +namespace = "FlexFlow" +name = "AggregateOp" +features = [ + "hash", + "json", + "rapidcheck", + "fmt", +] + +[[values]] +name = "SUM" + +[[value]] +name = "AVG" diff --git a/lib/op-attrs/include/op-attrs/as_dot.h b/lib/op-attrs/include/op-attrs/as_dot.h new file mode 100644 index 0000000000..d92557c2f4 --- /dev/null +++ b/lib/op-attrs/include/op-attrs/as_dot.h @@ -0,0 +1,15 @@ +#ifndef _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_AS_DOT_H +#define _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_AS_DOT_H + +#include "op-attrs/computation_graph_op_attrs.dtg.h" +#include "op-attrs/pcg_operator_attrs.dtg.h" +#include "utils/record_formatter.h" + +namespace FlexFlow { + +RecordFormatter as_dot(ComputationGraphOpAttrs const &); +RecordFormatter as_dot(PCGOperatorAttrs const &); + +} // namespace FlexFlow + +#endif diff --git a/lib/op-attrs/include/op-attrs/computation_graph_op_attrs.dtg.h b/lib/op-attrs/include/op-attrs/computation_graph_op_attrs.dtg.h new file mode 100644 index 0000000000..cc45628145 --- /dev/null +++ b/lib/op-attrs/include/op-attrs/computation_graph_op_attrs.dtg.h @@ -0,0 +1,471 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! 
+// If you would like to modify this datatype, instead modify +// lib/op-attrs/include/op-attrs/computation_graph_op_attrs.variant.toml +/* proj-data +{ + "generated_from": "cc0ab49405423594ffa1d8f541235a48" +} +*/ + +#ifndef _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_COMPUTATION_GRAPH_OP_ATTRS_DTG_H +#define _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_COMPUTATION_GRAPH_OP_ATTRS_DTG_H + +#include "fmt/format.h" +#include "nlohmann/json.hpp" +#include "op-attrs/ops/attention_attrs.dtg.h" +#include "op-attrs/ops/batch_matmul.dtg.h" +#include "op-attrs/ops/batch_norm_attrs.dtg.h" +#include "op-attrs/ops/broadcast.dtg.h" +#include "op-attrs/ops/cast_attrs.dtg.h" +#include "op-attrs/ops/concat_attrs.dtg.h" +#include "op-attrs/ops/conv_2d_attrs.dtg.h" +#include "op-attrs/ops/dropout_attrs.dtg.h" +#include "op-attrs/ops/element_binary_attrs.dtg.h" +#include "op-attrs/ops/element_scalar_unary_attrs.dtg.h" +#include "op-attrs/ops/element_unary_attrs.dtg.h" +#include "op-attrs/ops/embedding_attrs.dtg.h" +#include "op-attrs/ops/flat_attrs.dtg.h" +#include "op-attrs/ops/gather_attrs.dtg.h" +#include "op-attrs/ops/input_attrs.dtg.h" +#include "op-attrs/ops/layer_norm_attrs.dtg.h" +#include "op-attrs/ops/linear_attrs.dtg.h" +#include "op-attrs/ops/noop_attrs.dtg.h" +#include "op-attrs/ops/pool_2d_attrs.dtg.h" +#include "op-attrs/ops/reduce_attrs.dtg.h" +#include "op-attrs/ops/reshape_attrs.dtg.h" +#include "op-attrs/ops/reverse_attrs.dtg.h" +#include "op-attrs/ops/softmax_attrs.dtg.h" +#include "op-attrs/ops/split_attrs.dtg.h" +#include "op-attrs/ops/topk_attrs.dtg.h" +#include "op-attrs/ops/transpose_attrs.dtg.h" +#include "op-attrs/ops/weight_attrs.dtg.h" +#include "rapidcheck.h" +#include +#include +#include +#include +#include + +namespace FlexFlow { +struct ComputationGraphOpAttrs { + ComputationGraphOpAttrs() = delete; + explicit ComputationGraphOpAttrs(::FlexFlow::BatchMatmulAttrs const &); + explicit ComputationGraphOpAttrs(::FlexFlow::BatchNormAttrs const &); + explicit ComputationGraphOpAttrs(::FlexFlow::BroadcastAttrs const &); + explicit ComputationGraphOpAttrs(::FlexFlow::CastAttrs const &); + explicit ComputationGraphOpAttrs(::FlexFlow::ConcatAttrs const &); + explicit ComputationGraphOpAttrs(::FlexFlow::Conv2DAttrs const &); + explicit ComputationGraphOpAttrs(::FlexFlow::DropoutAttrs const &); + explicit ComputationGraphOpAttrs(::FlexFlow::ElementBinaryAttrs const &); + explicit ComputationGraphOpAttrs(::FlexFlow::ElementUnaryAttrs const &); + explicit ComputationGraphOpAttrs(::FlexFlow::ElementScalarUnaryAttrs const &); + explicit ComputationGraphOpAttrs(::FlexFlow::EmbeddingAttrs const &); + explicit ComputationGraphOpAttrs(::FlexFlow::FlatAttrs const &); + explicit ComputationGraphOpAttrs(::FlexFlow::GatherAttrs const &); + explicit ComputationGraphOpAttrs(::FlexFlow::InputAttrs const &); + explicit ComputationGraphOpAttrs(::FlexFlow::LayerNormAttrs const &); + explicit ComputationGraphOpAttrs(::FlexFlow::LinearAttrs const &); + explicit ComputationGraphOpAttrs(::FlexFlow::MultiHeadAttentionAttrs const &); + explicit ComputationGraphOpAttrs(::FlexFlow::NoopAttrs const &); + explicit ComputationGraphOpAttrs(::FlexFlow::Pool2DAttrs const &); + explicit ComputationGraphOpAttrs(::FlexFlow::ReduceAttrs const &); + explicit ComputationGraphOpAttrs(::FlexFlow::ReverseAttrs const &); + explicit ComputationGraphOpAttrs(::FlexFlow::ReshapeAttrs const &); + explicit ComputationGraphOpAttrs(::FlexFlow::SplitAttrs const &); + explicit ComputationGraphOpAttrs(::FlexFlow::SoftmaxAttrs const &); + explicit 
ComputationGraphOpAttrs(::FlexFlow::TopKAttrs const &); + explicit ComputationGraphOpAttrs(::FlexFlow::TransposeAttrs const &); + explicit ComputationGraphOpAttrs(::FlexFlow::WeightAttrs const &); + template + static constexpr bool IsPartOfComputationGraphOpAttrs_v = + std::is_same_v || + std::is_same_v || + std::is_same_v || + std::is_same_v || + std::is_same_v || + std::is_same_v || + std::is_same_v || + std::is_same_v || + std::is_same_v || + std::is_same_v || + std::is_same_v || + std::is_same_v || + std::is_same_v || + std::is_same_v || + std::is_same_v || + std::is_same_v || + std::is_same_v || + std::is_same_v || + std::is_same_v || + std::is_same_v || + std::is_same_v || + std::is_same_v || + std::is_same_v || + std::is_same_v || + std::is_same_v || + std::is_same_v || + std::is_same_v; + template + ReturnType visit(Visitor &&v) const { + switch (this->index()) { + case 0: { + ReturnType result = v(this->get<::FlexFlow::BatchMatmulAttrs>()); + return result; + } + case 1: { + ReturnType result = v(this->get<::FlexFlow::BatchNormAttrs>()); + return result; + } + case 2: { + ReturnType result = v(this->get<::FlexFlow::BroadcastAttrs>()); + return result; + } + case 3: { + ReturnType result = v(this->get<::FlexFlow::CastAttrs>()); + return result; + } + case 4: { + ReturnType result = v(this->get<::FlexFlow::ConcatAttrs>()); + return result; + } + case 5: { + ReturnType result = v(this->get<::FlexFlow::Conv2DAttrs>()); + return result; + } + case 6: { + ReturnType result = v(this->get<::FlexFlow::DropoutAttrs>()); + return result; + } + case 7: { + ReturnType result = v(this->get<::FlexFlow::ElementBinaryAttrs>()); + return result; + } + case 8: { + ReturnType result = v(this->get<::FlexFlow::ElementUnaryAttrs>()); + return result; + } + case 9: { + ReturnType result = v(this->get<::FlexFlow::ElementScalarUnaryAttrs>()); + return result; + } + case 10: { + ReturnType result = v(this->get<::FlexFlow::EmbeddingAttrs>()); + return result; + } + case 11: { + ReturnType result = v(this->get<::FlexFlow::FlatAttrs>()); + return result; + } + case 12: { + ReturnType result = v(this->get<::FlexFlow::GatherAttrs>()); + return result; + } + case 13: { + ReturnType result = v(this->get<::FlexFlow::InputAttrs>()); + return result; + } + case 14: { + ReturnType result = v(this->get<::FlexFlow::LayerNormAttrs>()); + return result; + } + case 15: { + ReturnType result = v(this->get<::FlexFlow::LinearAttrs>()); + return result; + } + case 16: { + ReturnType result = v(this->get<::FlexFlow::MultiHeadAttentionAttrs>()); + return result; + } + case 17: { + ReturnType result = v(this->get<::FlexFlow::NoopAttrs>()); + return result; + } + case 18: { + ReturnType result = v(this->get<::FlexFlow::Pool2DAttrs>()); + return result; + } + case 19: { + ReturnType result = v(this->get<::FlexFlow::ReduceAttrs>()); + return result; + } + case 20: { + ReturnType result = v(this->get<::FlexFlow::ReverseAttrs>()); + return result; + } + case 21: { + ReturnType result = v(this->get<::FlexFlow::ReshapeAttrs>()); + return result; + } + case 22: { + ReturnType result = v(this->get<::FlexFlow::SplitAttrs>()); + return result; + } + case 23: { + ReturnType result = v(this->get<::FlexFlow::SoftmaxAttrs>()); + return result; + } + case 24: { + ReturnType result = v(this->get<::FlexFlow::TopKAttrs>()); + return result; + } + case 25: { + ReturnType result = v(this->get<::FlexFlow::TransposeAttrs>()); + return result; + } + case 26: { + ReturnType result = v(this->get<::FlexFlow::WeightAttrs>()); + return result; + } + default: 
{ + throw std::runtime_error( + fmt::format("Unknown index {} for type ComputationGraphOpAttrs", + this->index())); + } + } + } + template + ReturnType visit(Visitor &&v) { + switch (this->index()) { + case 0: { + ReturnType result = v(this->get<::FlexFlow::BatchMatmulAttrs>()); + return result; + } + case 1: { + ReturnType result = v(this->get<::FlexFlow::BatchNormAttrs>()); + return result; + } + case 2: { + ReturnType result = v(this->get<::FlexFlow::BroadcastAttrs>()); + return result; + } + case 3: { + ReturnType result = v(this->get<::FlexFlow::CastAttrs>()); + return result; + } + case 4: { + ReturnType result = v(this->get<::FlexFlow::ConcatAttrs>()); + return result; + } + case 5: { + ReturnType result = v(this->get<::FlexFlow::Conv2DAttrs>()); + return result; + } + case 6: { + ReturnType result = v(this->get<::FlexFlow::DropoutAttrs>()); + return result; + } + case 7: { + ReturnType result = v(this->get<::FlexFlow::ElementBinaryAttrs>()); + return result; + } + case 8: { + ReturnType result = v(this->get<::FlexFlow::ElementUnaryAttrs>()); + return result; + } + case 9: { + ReturnType result = v(this->get<::FlexFlow::ElementScalarUnaryAttrs>()); + return result; + } + case 10: { + ReturnType result = v(this->get<::FlexFlow::EmbeddingAttrs>()); + return result; + } + case 11: { + ReturnType result = v(this->get<::FlexFlow::FlatAttrs>()); + return result; + } + case 12: { + ReturnType result = v(this->get<::FlexFlow::GatherAttrs>()); + return result; + } + case 13: { + ReturnType result = v(this->get<::FlexFlow::InputAttrs>()); + return result; + } + case 14: { + ReturnType result = v(this->get<::FlexFlow::LayerNormAttrs>()); + return result; + } + case 15: { + ReturnType result = v(this->get<::FlexFlow::LinearAttrs>()); + return result; + } + case 16: { + ReturnType result = v(this->get<::FlexFlow::MultiHeadAttentionAttrs>()); + return result; + } + case 17: { + ReturnType result = v(this->get<::FlexFlow::NoopAttrs>()); + return result; + } + case 18: { + ReturnType result = v(this->get<::FlexFlow::Pool2DAttrs>()); + return result; + } + case 19: { + ReturnType result = v(this->get<::FlexFlow::ReduceAttrs>()); + return result; + } + case 20: { + ReturnType result = v(this->get<::FlexFlow::ReverseAttrs>()); + return result; + } + case 21: { + ReturnType result = v(this->get<::FlexFlow::ReshapeAttrs>()); + return result; + } + case 22: { + ReturnType result = v(this->get<::FlexFlow::SplitAttrs>()); + return result; + } + case 23: { + ReturnType result = v(this->get<::FlexFlow::SoftmaxAttrs>()); + return result; + } + case 24: { + ReturnType result = v(this->get<::FlexFlow::TopKAttrs>()); + return result; + } + case 25: { + ReturnType result = v(this->get<::FlexFlow::TransposeAttrs>()); + return result; + } + case 26: { + ReturnType result = v(this->get<::FlexFlow::WeightAttrs>()); + return result; + } + default: { + throw std::runtime_error( + fmt::format("Unknown index {} for type ComputationGraphOpAttrs", + this->index())); + } + } + } + template + bool has() const { + static_assert( + IsPartOfComputationGraphOpAttrs_v, + "ComputationGraphOpAttrs::has() expected one of " + "[::FlexFlow::BatchMatmulAttrs, ::FlexFlow::BatchNormAttrs, " + "::FlexFlow::BroadcastAttrs, ::FlexFlow::CastAttrs, " + "::FlexFlow::ConcatAttrs, ::FlexFlow::Conv2DAttrs, " + "::FlexFlow::DropoutAttrs, ::FlexFlow::ElementBinaryAttrs, " + "::FlexFlow::ElementUnaryAttrs, ::FlexFlow::ElementScalarUnaryAttrs, " + "::FlexFlow::EmbeddingAttrs, ::FlexFlow::FlatAttrs, " + "::FlexFlow::GatherAttrs, 
::FlexFlow::InputAttrs, " + "::FlexFlow::LayerNormAttrs, ::FlexFlow::LinearAttrs, " + "::FlexFlow::MultiHeadAttentionAttrs, ::FlexFlow::NoopAttrs, " + "::FlexFlow::Pool2DAttrs, ::FlexFlow::ReduceAttrs, " + "::FlexFlow::ReverseAttrs, ::FlexFlow::ReshapeAttrs, " + "::FlexFlow::SplitAttrs, ::FlexFlow::SoftmaxAttrs, " + "::FlexFlow::TopKAttrs, ::FlexFlow::TransposeAttrs, " + "::FlexFlow::WeightAttrs], received T"); + return std::holds_alternative(this->raw_variant); + } + template + T const &get() const { + static_assert( + IsPartOfComputationGraphOpAttrs_v, + "ComputationGraphOpAttrs::get() expected one of " + "[::FlexFlow::BatchMatmulAttrs, ::FlexFlow::BatchNormAttrs, " + "::FlexFlow::BroadcastAttrs, ::FlexFlow::CastAttrs, " + "::FlexFlow::ConcatAttrs, ::FlexFlow::Conv2DAttrs, " + "::FlexFlow::DropoutAttrs, ::FlexFlow::ElementBinaryAttrs, " + "::FlexFlow::ElementUnaryAttrs, ::FlexFlow::ElementScalarUnaryAttrs, " + "::FlexFlow::EmbeddingAttrs, ::FlexFlow::FlatAttrs, " + "::FlexFlow::GatherAttrs, ::FlexFlow::InputAttrs, " + "::FlexFlow::LayerNormAttrs, ::FlexFlow::LinearAttrs, " + "::FlexFlow::MultiHeadAttentionAttrs, ::FlexFlow::NoopAttrs, " + "::FlexFlow::Pool2DAttrs, ::FlexFlow::ReduceAttrs, " + "::FlexFlow::ReverseAttrs, ::FlexFlow::ReshapeAttrs, " + "::FlexFlow::SplitAttrs, ::FlexFlow::SoftmaxAttrs, " + "::FlexFlow::TopKAttrs, ::FlexFlow::TransposeAttrs, " + "::FlexFlow::WeightAttrs], received T"); + return std::get(this->raw_variant); + } + template + T &get() { + static_assert( + IsPartOfComputationGraphOpAttrs_v, + "ComputationGraphOpAttrs::get() expected one of " + "[::FlexFlow::BatchMatmulAttrs, ::FlexFlow::BatchNormAttrs, " + "::FlexFlow::BroadcastAttrs, ::FlexFlow::CastAttrs, " + "::FlexFlow::ConcatAttrs, ::FlexFlow::Conv2DAttrs, " + "::FlexFlow::DropoutAttrs, ::FlexFlow::ElementBinaryAttrs, " + "::FlexFlow::ElementUnaryAttrs, ::FlexFlow::ElementScalarUnaryAttrs, " + "::FlexFlow::EmbeddingAttrs, ::FlexFlow::FlatAttrs, " + "::FlexFlow::GatherAttrs, ::FlexFlow::InputAttrs, " + "::FlexFlow::LayerNormAttrs, ::FlexFlow::LinearAttrs, " + "::FlexFlow::MultiHeadAttentionAttrs, ::FlexFlow::NoopAttrs, " + "::FlexFlow::Pool2DAttrs, ::FlexFlow::ReduceAttrs, " + "::FlexFlow::ReverseAttrs, ::FlexFlow::ReshapeAttrs, " + "::FlexFlow::SplitAttrs, ::FlexFlow::SoftmaxAttrs, " + "::FlexFlow::TopKAttrs, ::FlexFlow::TransposeAttrs, " + "::FlexFlow::WeightAttrs], received T"); + return std::get(this->raw_variant); + } + size_t index() const { + return this->raw_variant.index(); + } + bool operator==(ComputationGraphOpAttrs const &) const; + bool operator!=(ComputationGraphOpAttrs const &) const; + bool operator<(ComputationGraphOpAttrs const &) const; + bool operator>(ComputationGraphOpAttrs const &) const; + bool operator<=(ComputationGraphOpAttrs const &) const; + bool operator>=(ComputationGraphOpAttrs const &) const; + std::variant<::FlexFlow::BatchMatmulAttrs, + ::FlexFlow::BatchNormAttrs, + ::FlexFlow::BroadcastAttrs, + ::FlexFlow::CastAttrs, + ::FlexFlow::ConcatAttrs, + ::FlexFlow::Conv2DAttrs, + ::FlexFlow::DropoutAttrs, + ::FlexFlow::ElementBinaryAttrs, + ::FlexFlow::ElementUnaryAttrs, + ::FlexFlow::ElementScalarUnaryAttrs, + ::FlexFlow::EmbeddingAttrs, + ::FlexFlow::FlatAttrs, + ::FlexFlow::GatherAttrs, + ::FlexFlow::InputAttrs, + ::FlexFlow::LayerNormAttrs, + ::FlexFlow::LinearAttrs, + ::FlexFlow::MultiHeadAttentionAttrs, + ::FlexFlow::NoopAttrs, + ::FlexFlow::Pool2DAttrs, + ::FlexFlow::ReduceAttrs, + ::FlexFlow::ReverseAttrs, + ::FlexFlow::ReshapeAttrs, + ::FlexFlow::SplitAttrs, + 
::FlexFlow::SoftmaxAttrs, + ::FlexFlow::TopKAttrs, + ::FlexFlow::TransposeAttrs, + ::FlexFlow::WeightAttrs> + raw_variant; +}; +} // namespace FlexFlow +namespace std { +template <> +struct hash<::FlexFlow::ComputationGraphOpAttrs> { + size_t operator()(::FlexFlow::ComputationGraphOpAttrs const &) const; +}; +} // namespace std +namespace nlohmann { +template <> +struct adl_serializer<::FlexFlow::ComputationGraphOpAttrs> { + static ::FlexFlow::ComputationGraphOpAttrs from_json(json const &); + static void to_json(json &, ::FlexFlow::ComputationGraphOpAttrs const &); +}; +} // namespace nlohmann +namespace rc { +template <> +struct Arbitrary<::FlexFlow::ComputationGraphOpAttrs> { + static Gen<::FlexFlow::ComputationGraphOpAttrs> arbitrary(); +}; +} // namespace rc +namespace FlexFlow { +std::string format_as(::FlexFlow::ComputationGraphOpAttrs const &); +std::ostream &operator<<(std::ostream &, + ::FlexFlow::ComputationGraphOpAttrs const &); +} // namespace FlexFlow + +#endif // _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_COMPUTATION_GRAPH_OP_ATTRS_DTG_H diff --git a/lib/op-attrs/include/op-attrs/computation_graph_op_attrs.h b/lib/op-attrs/include/op-attrs/computation_graph_op_attrs.h new file mode 100644 index 0000000000..4be17798f7 --- /dev/null +++ b/lib/op-attrs/include/op-attrs/computation_graph_op_attrs.h @@ -0,0 +1,12 @@ +#ifndef _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_COMPUTATION_GRAPH_OP_ATTRS_H +#define _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_COMPUTATION_GRAPH_OP_ATTRS_H + +#include "op-attrs/computation_graph_op_attrs.dtg.h" + +namespace FlexFlow { + +OperatorType get_op_type(ComputationGraphOpAttrs const &); + +} // namespace FlexFlow + +#endif diff --git a/lib/op-attrs/include/op-attrs/computation_graph_op_attrs.variant.toml b/lib/op-attrs/include/op-attrs/computation_graph_op_attrs.variant.toml new file mode 100644 index 0000000000..bb25514e1d --- /dev/null +++ b/lib/op-attrs/include/op-attrs/computation_graph_op_attrs.variant.toml @@ -0,0 +1,148 @@ +namespace = "FlexFlow" +name = "ComputationGraphOpAttrs" +features = [ + "eq", + "ord", + "hash", + "json", + "rapidcheck", + "fmt", +] + +includes = [ + "op-attrs/ops/attention_attrs.dtg.h", + "op-attrs/ops/batch_matmul.dtg.h", + "op-attrs/ops/batch_norm_attrs.dtg.h", + "op-attrs/ops/broadcast.dtg.h", + "op-attrs/ops/cast_attrs.dtg.h", + "op-attrs/ops/concat_attrs.dtg.h", + "op-attrs/ops/conv_2d_attrs.dtg.h", + "op-attrs/ops/dropout_attrs.dtg.h", + "op-attrs/ops/element_binary_attrs.dtg.h", + "op-attrs/ops/element_scalar_unary_attrs.dtg.h", + "op-attrs/ops/element_unary_attrs.dtg.h", + "op-attrs/ops/embedding_attrs.dtg.h", + "op-attrs/ops/flat_attrs.dtg.h", + "op-attrs/ops/gather_attrs.dtg.h", + "op-attrs/ops/input_attrs.dtg.h", + "op-attrs/ops/layer_norm_attrs.dtg.h", + "op-attrs/ops/linear_attrs.dtg.h", + "op-attrs/ops/noop_attrs.dtg.h", + "op-attrs/ops/pool_2d_attrs.dtg.h", + "op-attrs/ops/reduce_attrs.dtg.h", + "op-attrs/ops/reshape_attrs.dtg.h", + "op-attrs/ops/reverse_attrs.dtg.h", + "op-attrs/ops/softmax_attrs.dtg.h", + "op-attrs/ops/split_attrs.dtg.h", + "op-attrs/ops/topk_attrs.dtg.h", + "op-attrs/ops/transpose_attrs.dtg.h", + "op-attrs/ops/weight_attrs.dtg.h", +] + +[[values]] +type = "::FlexFlow::BatchMatmulAttrs" +key = "batch_matmul" + +[[values]] +type = "::FlexFlow::BatchNormAttrs" +key = "batch_norm" + +[[values]] +type = "::FlexFlow::BroadcastAttrs" +key = "broadcast" + +[[values]] +type = "::FlexFlow::CastAttrs" +key = "cast" + +[[values]] +type = "::FlexFlow::ConcatAttrs" +key = "concat" + +[[values]] +type = 
"::FlexFlow::Conv2DAttrs" +key = "conv2d" + +[[values]] +type = "::FlexFlow::DropoutAttrs" +key = "dropout" + +[[values]] +type = "::FlexFlow::ElementBinaryAttrs" +key = "element_binary" + +[[values]] +type = "::FlexFlow::ElementUnaryAttrs" +key = "element_unary" + +[[values]] +type = "::FlexFlow::ElementScalarUnaryAttrs" +key = "element_scalar_unary" + +[[values]] +type = "::FlexFlow::EmbeddingAttrs" +key = "embedding" + +[[values]] +type = "::FlexFlow::FlatAttrs" +key = "flat" + +[[values]] +type = "::FlexFlow::GatherAttrs" +key = "gather" + +[[values]] +type = "::FlexFlow::InputAttrs" +key = "input" + +[[values]] +type = "::FlexFlow::LayerNormAttrs" +key = "layer_norm" + +[[values]] +type = "::FlexFlow::LinearAttrs" +key = "linear" + +[[values]] +type = "::FlexFlow::MultiHeadAttentionAttrs" +key = "multi_head_attention" + +[[values]] +type = "::FlexFlow::NoopAttrs" +key = "noop" + +[[values]] +type = "::FlexFlow::Pool2DAttrs" +key = "pool2d" + +[[values]] +type = "::FlexFlow::ReduceAttrs" +key = "reduce" + +[[values]] +type = "::FlexFlow::ReverseAttrs" +key = "reverse" + +[[values]] +type = "::FlexFlow::ReshapeAttrs" +key = "reshape" + +[[values]] +type = "::FlexFlow::SplitAttrs" +key = "split" + +[[values]] +type = "::FlexFlow::SoftmaxAttrs" +key = "softmax" + +[[values]] +type = "::FlexFlow::TopKAttrs" +key = "topk" + +[[values]] +type = "::FlexFlow::TransposeAttrs" +key = "transpose" + +[[values]] +type = "::FlexFlow::WeightAttrs" +key = "weight" diff --git a/lib/op-attrs/include/op-attrs/datatype.dtg.h b/lib/op-attrs/include/op-attrs/datatype.dtg.h new file mode 100644 index 0000000000..7052dba3b3 --- /dev/null +++ b/lib/op-attrs/include/op-attrs/datatype.dtg.h @@ -0,0 +1,40 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/op-attrs/include/op-attrs/datatype.enum.toml +/* proj-data +{ + "generated_from": "8315d0aa0a65b00c13aa580e923592ef" +} +*/ + +#ifndef _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_DATATYPE_DTG_H +#define _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_DATATYPE_DTG_H + +#include "fmt/format.h" +#include "nlohmann/json.hpp" +#include "rapidcheck.h" +#include +#include +#include + +namespace FlexFlow { +enum class DataType { BOOL, INT32, INT64, HALF, FLOAT, DOUBLE }; +std::string format_as(DataType); +std::ostream &operator<<(std::ostream &, DataType); +void to_json(::nlohmann::json &, DataType); +void from_json(::nlohmann::json const &, DataType &); +} // namespace FlexFlow +namespace std { +template <> +struct hash { + size_t operator()(FlexFlow::DataType) const; +}; +} // namespace std +namespace rc { +template <> +struct Arbitrary { + static Gen arbitrary(); +}; +} // namespace rc + +#endif // _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_DATATYPE_DTG_H diff --git a/lib/op-attrs/include/op-attrs/datatype.enum.toml b/lib/op-attrs/include/op-attrs/datatype.enum.toml new file mode 100644 index 0000000000..15210cfe29 --- /dev/null +++ b/lib/op-attrs/include/op-attrs/datatype.enum.toml @@ -0,0 +1,26 @@ +namespace = "FlexFlow" +name = "DataType" +features = [ + "hash", + "json", + "rapidcheck", + "fmt", +] + +[[values]] +name = "BOOL" + +[[values]] +name = "INT32" + +[[values]] +name = "INT64" + +[[values]] +name = "HALF" + +[[values]] +name = "FLOAT" + +[[values]] +name = "DOUBLE" diff --git a/lib/op-attrs/include/op-attrs/datatype.h b/lib/op-attrs/include/op-attrs/datatype.h index 643fe44c41..a435c1bc12 100644 --- a/lib/op-attrs/include/op-attrs/datatype.h +++ b/lib/op-attrs/include/op-attrs/datatype.h 
diff --git a/lib/op-attrs/include/op-attrs/datatype.h b/lib/op-attrs/include/op-attrs/datatype.h
index 643fe44c41..a435c1bc12 100644
--- a/lib/op-attrs/include/op-attrs/datatype.h
+++ b/lib/op-attrs/include/op-attrs/datatype.h
@@ -1,14 +1,13 @@
 #ifndef _FLEXFLOW_OPATTRS_INCLUDE_OPATTRS_DATATYPE_H
 #define _FLEXFLOW_OPATTRS_INCLUDE_OPATTRS_DATATYPE_H
 
+#include "op-attrs/datatype.dtg.h"
 #include "utils/fmt.h"
 #include "utils/fp16.h"
 #include
 
 namespace FlexFlow {
 
-enum class DataType { BOOL, INT32, INT64, HALF, FLOAT, DOUBLE };
-
 template <DataType>
 struct data_type_enum_to_class;
 
@@ -54,46 +53,11 @@ using DataTypeValue = std::variant<real_type<DataType::FLOAT>,
                                    real_type<DataType::DOUBLE>,
                                    real_type<DataType::INT32>,
                                    real_type<DataType::INT64>,
-                                   real_type<DataType::HALF>,
+                                   /* real_type<DataType::HALF>, */
                                    real_type<DataType::BOOL>>;
 
 size_t size_of_datatype(DataType);
 
 } // namespace FlexFlow
 
-namespace fmt {
-template <>
-struct formatter<::FlexFlow::DataType> : formatter<string_view> {
-  template <typename FormatContext>
-  auto format(::FlexFlow::DataType dt, FormatContext &ctx)
-      -> decltype(ctx.out()) {
-    using namespace FlexFlow;
-
-    string_view name = "unknown";
-    switch (dt) {
-      case DataType::BOOL:
-        name = "BOOL";
-        break;
-      case DataType::INT32:
-        name = "INT32";
-        break;
-      case DataType::INT64:
-        name = "INT64";
-        break;
-      case DataType::HALF:
-        name = "HALF";
-        break;
-      case DataType::FLOAT:
-        name = "FLOAT";
-        break;
-      case DataType::DOUBLE:
-        name = "DOUBLE";
-        break;
-    }
-    return formatter<string_view>::format(name, ctx);
-  }
-};
-
-} // namespace fmt
-
 #endif
diff --git a/lib/op-attrs/include/op-attrs/dim_ordered.h b/lib/op-attrs/include/op-attrs/dim_ordered.h
index b726d0687f..dbc237a03d 100644
--- a/lib/op-attrs/include/op-attrs/dim_ordered.h
+++ b/lib/op-attrs/include/op-attrs/dim_ordered.h
@@ -1,7 +1,8 @@
 #ifndef _FLEXFLOW_OPATTRS_INCLUDE_OPATTRS_FF_STACK_VECTOR_H
 #define _FLEXFLOW_OPATTRS_INCLUDE_OPATTRS_FF_STACK_VECTOR_H
 
-#include "op-attrs/ff_dim.h"
+#include "op-attrs/ff_dim.dtg.h"
+#include "utils/json.h"
 #include "utils/stack_vector.h"
 
 namespace FlexFlow {
@@ -28,11 +29,19 @@ struct DimOrdered {
       : contents(contents.begin(), contents.end()) {}
 
   T const &at(Idx idx) const {
-    return this->contents.at(idx.value());
+    int raw = idx.value;
+    if (raw < 0) {
+      raw = this->contents.size() + raw;
+    }
+    return this->contents.at(raw);
   }
 
   T &at(Idx idx) {
-    return this->contents.at(idx.value());
+    int raw = idx.value;
+    if (raw < 0) {
+      raw = this->contents.size() + raw;
+    }
+    return this->contents.at(raw);
   }
 
   T const &operator[](Idx idx) const {
@@ -43,6 +52,14 @@
     return this->at(idx);
   }
 
+  bool idx_is_valid(Idx const &idx) const {
+    int raw = idx.value;
+    if (raw < 0) {
+      raw = this->contents.size() + raw;
+    }
+    return (raw >= 0 && raw < this->contents.size());
+  }
+
   bool operator==(DimOrdered const &other) const {
     return this->contents == other.contents;
   }
@@ -133,6 +150,17 @@ struct DimOrdered {
 template <typename T>
 using FFOrdered = DimOrdered<T, ff_dim_t>;
 
+template <typename T>
+std::string format_as(FFOrdered<T> const &v) {
+  std::vector<T> as_vec(v.cbegin(), v.cend());
+  return fmt::format("", as_vec);
+}
+
+template <typename T>
+std::ostream &operator<<(std::ostream &s, FFOrdered<T> const &v) {
+  return (s << fmt::to_string(v));
+}
+
 template <typename T>
 auto inner_to_outer(FFOrdered<T> const &ff_ordered)
     -> decltype(reversed_container(ff_ordered)) {
@@ -160,6 +188,29 @@ FFOrdered<T> const &outer_to_inner(FFOrdered<T> const &ff_ordered) {
 }
 
 } // namespace FlexFlow
 
+/* template <typename T, typename Idx> */
+/* void to_json(json &j, DimOrdered<T, Idx> const &x) { */
+/*   /1* j = std::vector<T>{x.cbegin(), x.cend()}; *1/ */
+/* } */
+
+/* template <typename T, typename Idx> */
+/* void from_json(json const &j, DimOrdered<T, Idx> &x) { */
+/*   /1* x = DimOrdered<T, Idx>{j.template get<std::vector<T>>()}; *1/ */
+/* } */
+
+namespace nlohmann {
+template <typename T, typename Idx>
+struct adl_serializer<::FlexFlow::DimOrdered<T, Idx>> {
+  static ::FlexFlow::DimOrdered<T, Idx> from_json(json const &j) {
+    return {j.template get<std::vector<T>>()};
+  }
+
+  static void to_json(json &j, ::FlexFlow::DimOrdered<T, Idx> const &x) {
+    j = std::vector<T>{x.cbegin(), x.cend()};
+  }
+};
+} // namespace nlohmann
+
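A sketch of how the pieces above behave (not from this patch): it assumes DimOrdered is constructible from a std::vector, as its from_json implies; example() and the concrete dimension values are illustrative only.

    #include "op-attrs/dim_ordered.h"
    #include <cassert>

    using namespace FlexFlow;

    void example() {
      // FFOrdered<T> is DimOrdered<T, ff_dim_t>; negative indices now wrap
      // around, mirroring the new at() logic above.
      FFOrdered<int> dims{std::vector<int>{4, 8, 16}};
      assert(dims.at(ff_dim_t{-1}) == 16); // -1 resolves to size() - 1
      assert(dims.idx_is_valid(ff_dim_t{-3}));
      assert(!dims.idx_is_valid(ff_dim_t{3}));

      // The adl_serializer maps the container to a plain JSON array.
      nlohmann::json j = dims;               // -> [4, 8, 16]
      assert(j.get<FFOrdered<int>>() == dims);
    }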
 namespace std {
 
 template <typename T, typename Idx>
@@ -174,4 +225,16 @@ struct hash<::FlexFlow::DimOrdered<T, Idx>> {
 
 } // namespace std
 
+namespace rc {
+
+template <typename T, typename Idx>
+struct Arbitrary<::FlexFlow::DimOrdered<T, Idx>> {
+  static Gen<::FlexFlow::DimOrdered<T, Idx>> arbitrary() {
+    return gen::construct<::FlexFlow::DimOrdered<T, Idx>>(
+        gen::arbitrary<::FlexFlow::stack_vector<T, MAX_TENSOR_DIM>>());
+  }
+};
+
+} // namespace rc
+
 #endif
diff --git a/lib/op-attrs/include/op-attrs/dim_ordered/slice.h b/lib/op-attrs/include/op-attrs/dim_ordered/slice.h
new file mode 100644
index 0000000000..4d6e82b71b
--- /dev/null
+++ b/lib/op-attrs/include/op-attrs/dim_ordered/slice.h
@@ -0,0 +1,54 @@
+#ifndef _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_DIM_ORDERED_SLICE_H
+#define _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_DIM_ORDERED_SLICE_H
+
+#include "op-attrs/dim_ordered.h"
+#include "utils/containers.h"
+#include "utils/optional.h"
+
+namespace FlexFlow {
+
+template <typename T, typename Idx>
+DimOrdered<T, Idx> nonoverloaded_slice(DimOrdered<T, Idx> const &d,
+                                       std::optional<Idx> const &start,
+                                       std::optional<Idx> const &end) {
+  auto to_raw_idx = [](std::optional<Idx> const &idx) -> std::optional<int> {
+    return transform(idx, [](Idx const &i) { return i.value; });
+  };
+
+  return DimOrdered<T, Idx>{
+      subvec(as_vector(d), to_raw_idx(start), to_raw_idx(end))};
+}
+
+template <typename T, typename Idx>
+DimOrdered<T, Idx> slice(DimOrdered<T, Idx> const &d,
+                         std::optional<Idx> const &start,
+                         std::optional<Idx> const &end) {
+  return nonoverloaded_slice(d, start, end);
+}
+
+template <typename T, typename Idx>
+DimOrdered<T, Idx> slice(DimOrdered<T, Idx> const &d,
+                         std::nullopt_t const &start,
+                         Idx const &end) {
+  return nonoverloaded_slice(
+      d, std::optional<Idx>{start}, std::optional<Idx>{end});
+}
+
+template <typename T, typename Idx>
+DimOrdered<T, Idx> slice(DimOrdered<T, Idx> const &d,
+                         Idx const &start,
+                         std::nullopt_t const &end) {
+  return nonoverloaded_slice(
+      d, std::optional<Idx>{start}, std::optional<Idx>{end});
+}
+
+template <typename T, typename Idx>
+DimOrdered<T, Idx>
+    slice(DimOrdered<T, Idx> const &d, Idx const &start, Idx const &end) {
+  return nonoverloaded_slice(
+      d, std::optional<Idx>{start}, std::optional<Idx>{end});
+}
+
+} // namespace FlexFlow
+
+#endif
diff --git a/lib/op-attrs/include/op-attrs/dim_ordered/transform.h b/lib/op-attrs/include/op-attrs/dim_ordered/transform.h
new file mode 100644
index 0000000000..880f13b4d4
--- /dev/null
+++ b/lib/op-attrs/include/op-attrs/dim_ordered/transform.h
@@ -0,0 +1,20 @@
+#ifndef _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_DIM_ORDERED_TRANSFORM_H
+#define _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_DIM_ORDERED_TRANSFORM_H
+
+#include "op-attrs/dim_ordered.h"
+#include "utils/containers.h"
+#include "utils/containers/vector_transform.h"
+
+namespace FlexFlow {
+
+template <typename F, typename T, typename Idx>
+DimOrdered<std::invoke_result_t<F, T>, Idx>
+    transform(DimOrdered<T, Idx> const &d, F f) {
+  using Out = std::invoke_result_t<F, T>;
+
+  return DimOrdered<Out, Idx>{vector_transform(as_vector(d), f)};
+}
+
+} // namespace FlexFlow
+
+#endif
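A sketch of the two new helpers (not from this patch); slice_and_transform, tail, doubled, and the open-ended range are hypothetical, and FFOrdered construction is assumed as in the previous note.

    #include "op-attrs/dim_ordered/slice.h"
    #include "op-attrs/dim_ordered/transform.h"

    using namespace FlexFlow;

    void slice_and_transform(FFOrdered<int> const &dims) {
      // Keep everything from dimension 1 onward; the std::nullopt overload
      // routes the open end through nonoverloaded_slice.
      FFOrdered<int> tail = slice(dims, ff_dim_t{1}, std::nullopt);

      // Element-wise map; the result element type follows
      // std::invoke_result_t<F, T>, here still int.
      FFOrdered<int> doubled = transform(dims, [](int x) { return 2 * x; });

      (void)tail;
      (void)doubled;
    }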
diff --git a/lib/op-attrs/include/op-attrs/ff_dim.dtg.h b/lib/op-attrs/include/op-attrs/ff_dim.dtg.h
new file mode 100644
index 0000000000..1697f78196
--- /dev/null
+++ b/lib/op-attrs/include/op-attrs/ff_dim.dtg.h
@@ -0,0 +1,54 @@
+// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT!
+// If you would like to modify this datatype, instead modify
+// lib/op-attrs/include/op-attrs/ff_dim.struct.toml
+/* proj-data
+{
+  "generated_from": "a5fa89a024e95c4f2d52681a74cab30f"
+}
+*/
+
+#ifndef _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_FF_DIM_DTG_H
+#define _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_FF_DIM_DTG_H
+
+#include "fmt/format.h"
+#include "nlohmann/json.hpp"
+#include
+#include
+#include
+
+namespace FlexFlow {
+struct ff_dim_t {
+  ff_dim_t() = delete;
+  ff_dim_t(int const &value);
+
+  bool operator==(ff_dim_t const &) const;
+  bool operator!=(ff_dim_t const &) const;
+  bool operator<(ff_dim_t const &) const;
+  bool operator>(ff_dim_t const &) const;
+  bool operator<=(ff_dim_t const &) const;
+  bool operator>=(ff_dim_t const &) const;
+  int value;
+};
+} // namespace FlexFlow
+
+namespace std {
+template <>
+struct hash<FlexFlow::ff_dim_t> {
+  size_t operator()(FlexFlow::ff_dim_t const &) const;
+};
+} // namespace std
+
+namespace nlohmann {
+template <>
+struct adl_serializer<FlexFlow::ff_dim_t> {
+  static FlexFlow::ff_dim_t from_json(json const &);
+  static void to_json(json &, FlexFlow::ff_dim_t const &);
+};
+} // namespace nlohmann
+
+namespace FlexFlow {
+std::string format_as(ff_dim_t const &);
+std::ostream &operator<<(std::ostream &, ff_dim_t const &);
+} // namespace FlexFlow
+
+#endif // _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_FF_DIM_DTG_H
diff --git a/lib/op-attrs/include/op-attrs/ff_dim.h b/lib/op-attrs/include/op-attrs/ff_dim.h
index be1f148a70..e78ce4b51e 100644
--- a/lib/op-attrs/include/op-attrs/ff_dim.h
+++ b/lib/op-attrs/include/op-attrs/ff_dim.h
@@ -1,18 +1,18 @@
-#ifndef _FLEXFLOW_OPATTRS_INCLUDE_FF_DIM_H
-#define _FLEXFLOW_OPATTRS_INCLUDE_FF_DIM_H
-#include "utils/strong_typedef.h"
-#include
+#ifndef _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_FF_DIM_H
+#define _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_FF_DIM_H
 
-namespace FlexFlow {
+#include "op-attrs/ff_dim.dtg.h"
+#include "rapidcheck.h"
 
-struct ff_dim_t : public numerical_typedef {
-  using numerical_typedef::numerical_typedef;
+namespace rc {
+template <>
+struct Arbitrary<FlexFlow::ff_dim_t> {
+  static Gen<FlexFlow::ff_dim_t> arbitrary() {
+    return gen::construct<FlexFlow::ff_dim_t>(
+        gen::inRange(0, MAX_TENSOR_DIM));
+  }
 };
+} // namespace rc
 
-} // namespace FlexFlow
-
-MAKE_TYPEDEF_HASHABLE(::FlexFlow::ff_dim_t);
-MAKE_TYPEDEF_PRINTABLE(::FlexFlow::ff_dim_t, "ff_dim");
-
-#endif
+#endif // _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_FF_DIM_H
diff --git a/lib/op-attrs/include/op-attrs/ff_dim.struct.toml b/lib/op-attrs/include/op-attrs/ff_dim.struct.toml
new file mode 100644
index 0000000000..441f9826ca
--- /dev/null
+++ b/lib/op-attrs/include/op-attrs/ff_dim.struct.toml
@@ -0,0 +1,14 @@
+namespace = "FlexFlow"
+name = "ff_dim_t"
+
+features = [
+  "eq",
+  "ord",
+  "hash",
+  "json",
+  "fmt",
+]
+
+[[fields]]
+name = "value"
+type = "int"
diff --git a/lib/op-attrs/include/op-attrs/get_op_type.h b/lib/op-attrs/include/op-attrs/get_op_type.h
index a2db4ab5f0..39541aa8d6 100644
--- a/lib/op-attrs/include/op-attrs/get_op_type.h
+++ b/lib/op-attrs/include/op-attrs/get_op_type.h
@@ -1,8 +1,37 @@
 #ifndef _FLEXFLOW_OP_ATTRS_GET_OP_TYPE_H
 #define _FLEXFLOW_OP_ATTRS_GET_OP_TYPE_H
 
-#include "operator_attrs.h"
-#include "utils/variant.h"
+#include "op-attrs/ops/attention_attrs.dtg.h"
+#include "op-attrs/ops/batch_matmul.dtg.h"
+#include "op-attrs/ops/batch_norm_attrs.dtg.h"
+#include "op-attrs/ops/broadcast.dtg.h"
+#include "op-attrs/ops/cast_attrs.dtg.h"
+#include "op-attrs/ops/combine_attrs.dtg.h"
+#include "op-attrs/ops/concat_attrs.dtg.h"
+#include "op-attrs/ops/conv_2d_attrs.dtg.h"
+#include
"op-attrs/ops/dropout_attrs.dtg.h" +#include "op-attrs/ops/element_binary_attrs.dtg.h" +#include "op-attrs/ops/element_scalar_unary_attrs.dtg.h" +#include "op-attrs/ops/element_unary_attrs.dtg.h" +#include "op-attrs/ops/embedding_attrs.dtg.h" +#include "op-attrs/ops/flat_attrs.dtg.h" +#include "op-attrs/ops/gather_attrs.dtg.h" +#include "op-attrs/ops/input_attrs.dtg.h" +#include "op-attrs/ops/layer_norm_attrs.dtg.h" +#include "op-attrs/ops/linear_attrs.dtg.h" +#include "op-attrs/ops/noop_attrs.dtg.h" +#include "op-attrs/ops/pool_2d_attrs.dtg.h" +#include "op-attrs/ops/reduce_attrs.dtg.h" +#include "op-attrs/ops/reduction_attrs.dtg.h" +#include "op-attrs/ops/repartition_attrs.dtg.h" +#include "op-attrs/ops/replicate_attrs.dtg.h" +#include "op-attrs/ops/reshape_attrs.dtg.h" +#include "op-attrs/ops/reverse_attrs.dtg.h" +#include "op-attrs/ops/softmax_attrs.dtg.h" +#include "op-attrs/ops/split_attrs.dtg.h" +#include "op-attrs/ops/topk_attrs.dtg.h" +#include "op-attrs/ops/transpose_attrs.dtg.h" +#include "op-attrs/ops/weight_attrs.dtg.h" namespace FlexFlow { @@ -32,23 +61,13 @@ OperatorType get_op_type(SplitAttrs const &); OperatorType get_op_type(SoftmaxAttrs const &); OperatorType get_op_type(TopKAttrs const &); OperatorType get_op_type(TransposeAttrs const &); +OperatorType get_op_type(WeightAttrs const &); + OperatorType get_op_type(CombineAttrs const &); OperatorType get_op_type(ReductionAttrs const &); OperatorType get_op_type(RepartitionAttrs const &); OperatorType get_op_type(ReplicateAttrs const &); -struct GetOpTypeFunctor { - template - OperatorType operator()(T const &t) { - return get_op_type(t); - } -}; - -template -OperatorType get_op_type(std::variant const &attrs) { - return visit(GetOpTypeFunctor{}, attrs); -} - } // namespace FlexFlow #endif diff --git a/lib/op-attrs/include/op-attrs/get_output_shapes.h b/lib/op-attrs/include/op-attrs/get_output_shapes.h index 6fb93aac91..a826e1cb54 100644 --- a/lib/op-attrs/include/op-attrs/get_output_shapes.h +++ b/lib/op-attrs/include/op-attrs/get_output_shapes.h @@ -112,28 +112,14 @@ std::vector get_output_shapes(Attrs const &attrs, ParallelTensorShape get_output_shape(MultiHeadAttentionAttrs const &, std::vector const &); -ParallelTensorShape get_output_shape(BatchMatmulAttrs const &, - ParallelTensorShape const &, - ParallelTensorShape const &); ParallelTensorShape get_output_shape(CastAttrs const &, ParallelTensorShape const &); -ParallelTensorShape get_output_shape(CombineAttrs const &, - ParallelTensorShape const &); ParallelTensorShape get_output_shape(ConcatAttrs const &, std::vector const &); ParallelTensorShape get_output_shape(Conv2DAttrs const &, ParallelTensorShape const &); ParallelTensorShape get_output_shape(DropoutAttrs const &, ParallelTensorShape const &); -ParallelTensorShape get_output_shape(ElementBinaryAttrs const &, - ParallelTensorShape const &, - ParallelTensorShape const &); -ParallelTensorShape get_output_shape(ElementUnaryAttrs const &, - ParallelTensorShape const &); -ParallelTensorShape get_output_shape(ElementScalarUnaryAttrs const &, - ParallelTensorShape const &); -ParallelTensorShape get_output_shape(EmbeddingAttrs const &, - ParallelTensorShape const &); ParallelTensorShape get_output_shape(FlatAttrs const &, ParallelTensorShape const &); std::vector get_output_shapes(GatherAttrs const &, @@ -141,16 +127,10 @@ std::vector get_output_shapes(GatherAttrs const &, ParallelTensorShape const &); ParallelTensorShape get_output_shape(LayerNormAttrs const &, ParallelTensorShape const &); -ParallelTensorShape 
get_output_shape(LinearAttrs const &, - ParallelTensorShape const &); ParallelTensorShape get_output_shape(Pool2DAttrs const &, ParallelTensorShape const &); ParallelTensorShape get_output_shape(ReduceAttrs const &, ParallelTensorShape const &); -ParallelTensorShape get_output_shape(ReductionAttrs const &, - ParallelTensorShape const &); -ParallelTensorShape get_output_shape(RepartitionAttrs const &, - ParallelTensorShape const &); ParallelTensorShape get_output_shape(ReplicateAttrs const &, ParallelTensorShape const &); ParallelTensorShape get_output_shape(ReverseAttrs const &, diff --git a/lib/op-attrs/include/op-attrs/l1_regularizer_attrs.dtg.h b/lib/op-attrs/include/op-attrs/l1_regularizer_attrs.dtg.h new file mode 100644 index 0000000000..1d4747db7e --- /dev/null +++ b/lib/op-attrs/include/op-attrs/l1_regularizer_attrs.dtg.h @@ -0,0 +1,62 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/op-attrs/include/op-attrs/l1_regularizer_attrs.struct.toml +/* proj-data +{ + "generated_from": "50968fb8a3d43395d0eab7594f4935c0" +} +*/ + +#ifndef _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_L1_REGULARIZER_ATTRS_DTG_H +#define _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_L1_REGULARIZER_ATTRS_DTG_H + +#include "fmt/format.h" +#include "nlohmann/json.hpp" +#include "rapidcheck.h" +#include +#include +#include + +namespace FlexFlow { +struct L1RegularizerAttrs { + L1RegularizerAttrs() = delete; + L1RegularizerAttrs(float const &lambda); + + bool operator==(L1RegularizerAttrs const &) const; + bool operator!=(L1RegularizerAttrs const &) const; + bool operator<(L1RegularizerAttrs const &) const; + bool operator>(L1RegularizerAttrs const &) const; + bool operator<=(L1RegularizerAttrs const &) const; + bool operator>=(L1RegularizerAttrs const &) const; + float lambda; +}; +} // namespace FlexFlow + +namespace std { +template <> +struct hash { + size_t operator()(FlexFlow::L1RegularizerAttrs const &) const; +}; +} // namespace std + +namespace nlohmann { +template <> +struct adl_serializer { + static FlexFlow::L1RegularizerAttrs from_json(json const &); + static void to_json(json &, FlexFlow::L1RegularizerAttrs const &); +}; +} // namespace nlohmann + +namespace rc { +template <> +struct Arbitrary { + static Gen arbitrary(); +}; +} // namespace rc + +namespace FlexFlow { +std::string format_as(L1RegularizerAttrs const &); +std::ostream &operator<<(std::ostream &, L1RegularizerAttrs const &); +} // namespace FlexFlow + +#endif // _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_L1_REGULARIZER_ATTRS_DTG_H diff --git a/lib/op-attrs/include/op-attrs/l1_regularizer_attrs.struct.toml b/lib/op-attrs/include/op-attrs/l1_regularizer_attrs.struct.toml new file mode 100644 index 0000000000..60fabfb94a --- /dev/null +++ b/lib/op-attrs/include/op-attrs/l1_regularizer_attrs.struct.toml @@ -0,0 +1,14 @@ +namespace = "FlexFlow" +name = "L1RegularizerAttrs" +features = [ + "eq", + "ord", + "hash", + "json", + "rapidcheck", + "fmt", +] + +[[fields]] +name = "lambda" +type = "float" diff --git a/lib/op-attrs/include/op-attrs/l2_regularizer_attrs.dtg.h b/lib/op-attrs/include/op-attrs/l2_regularizer_attrs.dtg.h new file mode 100644 index 0000000000..981d3f4905 --- /dev/null +++ b/lib/op-attrs/include/op-attrs/l2_regularizer_attrs.dtg.h @@ -0,0 +1,62 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! 
+// If you would like to modify this datatype, instead modify +// lib/op-attrs/include/op-attrs/l2_regularizer_attrs.struct.toml +/* proj-data +{ + "generated_from": "c4f182e547ab6f0d5613e7eeb95d438e" +} +*/ + +#ifndef _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_L2_REGULARIZER_ATTRS_DTG_H +#define _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_L2_REGULARIZER_ATTRS_DTG_H + +#include "fmt/format.h" +#include "nlohmann/json.hpp" +#include "rapidcheck.h" +#include +#include +#include + +namespace FlexFlow { +struct L2RegularizerAttrs { + L2RegularizerAttrs() = delete; + L2RegularizerAttrs(float const &lambda); + + bool operator==(L2RegularizerAttrs const &) const; + bool operator!=(L2RegularizerAttrs const &) const; + bool operator<(L2RegularizerAttrs const &) const; + bool operator>(L2RegularizerAttrs const &) const; + bool operator<=(L2RegularizerAttrs const &) const; + bool operator>=(L2RegularizerAttrs const &) const; + float lambda; +}; +} // namespace FlexFlow + +namespace std { +template <> +struct hash { + size_t operator()(FlexFlow::L2RegularizerAttrs const &) const; +}; +} // namespace std + +namespace nlohmann { +template <> +struct adl_serializer { + static FlexFlow::L2RegularizerAttrs from_json(json const &); + static void to_json(json &, FlexFlow::L2RegularizerAttrs const &); +}; +} // namespace nlohmann + +namespace rc { +template <> +struct Arbitrary { + static Gen arbitrary(); +}; +} // namespace rc + +namespace FlexFlow { +std::string format_as(L2RegularizerAttrs const &); +std::ostream &operator<<(std::ostream &, L2RegularizerAttrs const &); +} // namespace FlexFlow + +#endif // _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_L2_REGULARIZER_ATTRS_DTG_H diff --git a/lib/op-attrs/include/op-attrs/l2_regularizer_attrs.struct.toml b/lib/op-attrs/include/op-attrs/l2_regularizer_attrs.struct.toml new file mode 100644 index 0000000000..adce4397a4 --- /dev/null +++ b/lib/op-attrs/include/op-attrs/l2_regularizer_attrs.struct.toml @@ -0,0 +1,14 @@ +namespace = "FlexFlow" +name = "L2RegularizerAttrs" +features = [ + "eq", + "ord", + "hash", + "json", + "rapidcheck", + "fmt", +] + +[[fields]] +name = "lambda" +type = "float" diff --git a/lib/op-attrs/include/op-attrs/op.h b/lib/op-attrs/include/op-attrs/op.h deleted file mode 100644 index 9ad83c3641..0000000000 --- a/lib/op-attrs/include/op-attrs/op.h +++ /dev/null @@ -1,369 +0,0 @@ -#ifndef _FLEXFLOW_OPATTRS_INCLUDE_OPATTRS_OP_H -#define _FLEXFLOW_OPATTRS_INCLUDE_OPATTRS_OP_H - -#include "utils/fmt.h" - -namespace FlexFlow { - -enum class Op { - NOOP, - INPUT, - WEIGHT, - CONV2D, - DROPOUT, - LINEAR, - BATCHMATMUL, - POOL2D, - SCALAR_MULTIPLY, - SCALAR_ADD, - SCALAR_FLOOR_DIV, - SCALAR_TRUE_DIV, - SCALAR_SUB, - RELU, - IDENTITY, - SIGMOID, - TANH, - ELU, - FLAT, - SOFTMAX, - BATCHNORM, - CONCAT, - SPLIT, - EMBEDDING, - CACHE, - // OP_ELEMENTWISE, - RESHAPE, - REVERSE, - TRANSPOSE, - EW_ADD, - EW_MUL, - MATMUL, - MUL, - ENLARGE, - SQUEEZE, // https://github.com/onnx/onnx/blob/master/docs/Operators.md#Squeeze - UNSQUEEZE, // https://github.com/onnx/onnx/blob/master/docs/Operators.md#Unsqueeze - EW_SUB, // https://github.com/onnx/onnx/blob/master/docs/Operators.md#Sub - EW_DIV, // https://github.com/onnx/onnx/blob/master/docs/Operators.md#Div - EW_EQUAL, // https://github.com/onnx/onnx/blob/master/docs/Operators.md#Equal - EW_GREATER, // https://github.com/onnx/onnx/blob/master/docs/Operators.md#Greater - EW_LESS, // https://github.com/onnx/onnx/blob/master/docs/Operators.md#Less - EW_MAX, // https://github.com/onnx/onnx/blob/master/docs/Operators.md#Max - 
EW_MIN, // https://github.com/onnx/onnx/blob/master/docs/Operators.md#Min - REDUCE_ARGMAX, // https://github.com/onnx/onnx/blob/master/docs/Operators.md#ArgMax - REDUCE_ARGMIN, // https://github.com/onnx/onnx/blob/master/docs/Operators.md#ArgMin - REDUCE_MAX, // https://github.com/onnx/onnx/blob/master/docs/Operators.md#ReduceMax - REDUCE_MEAN, // https://github.com/onnx/onnx/blob/master/docs/Operators.md#ReduceMean - REDUCE_MIN, // https://github.com/onnx/onnx/blob/master/docs/Operators.md#ReduceMin - REDUCE_PROD, // https://github.com/onnx/onnx/blob/master/docs/Operators.md#ReduceProd - REDUCE_SUM, // https://github.com/onnx/onnx/blob/master/docs/Operators.md#ReduceSum - PAD, // https://github.com/dmlc/tvm/blob/master/topi/python/topi/nn/pad.py - SHAPE, // https://github.com/onnx/onnx/blob/master/docs/Operators.md#Shape - SIZE, // https://github.com/onnx/onnx/blob/master/docs/Operators.md#Size - TOPK, // https://github.com/onnx/onnx/blob/master/docs/Operators.md#TopK - WHERE, // https://github.com/onnx/onnx/blob/master/docs/Operators.md#Where - CEIL, // https://github.com/onnx/onnx/blob/master/docs/Operators.md#Ceil - CAST, // https://github.com/onnx/onnx/blob/master/docs/Operators.md#Cast - EXP, // https://github.com/onnx/onnx/blob/master/docs/Operators.md#Exp - ROUND, // https://github.com/onnx/onnx/blob/master/docs/Operators.md#Round - LOG, // https://github.com/onnx/onnx/blob/master/docs/Operators.md#Log - LOGICAL_NOT, // https://github.com/onnx/onnx/blob/master/docs/Operators.md#Not - SQRT, // https://github.com/onnx/onnx/blob/master/docs/Operators.md#Sqrt - SIN, // https://github.com/onnx/onnx/blob/master/docs/Operators.md#Sin - COS, // https://github.com/onnx/onnx/blob/master/docs/Operators.md#Cos - LEAKYRELU, - SLICE, // https://github.com/onnx/onnx/blob/master/docs/Operators.md#Slice - RESIZE, // https://github.com/onnx/onnx/blob/master/docs/Operators.md#Resize - PRELU, // https://github.com/onnx/onnx/blob/master/docs/Operators.md#PRelu - GELU, - MULTIHEAD_ATTENTION, - FUSED, // Fused operator type for internal fusion optimizations - RSQRT, // https://pytorch.org/docs/stable/generated/torch.rsqrt.html - POW, // https://pytorch.org/docs/stable/generated/torch.pow.html - MEAN, // https://pytorch.org/docs/stable/generated/torch.mean.html - LAYERNORM, - GATHER, // https://pytorch.org/docs/stable/generated/torch.gather.html - BROADCAST, - // Parallel Ops - REPARTITION, - COMBINE, - REPLICATE, - REDUCTION, - BATCH, - PIPELINE, - FUSED_PARALLEL, -}; - -using OperatorType = Op; - -std::string get_operator_type_name(Op op); - -} // namespace FlexFlow - -namespace fmt { - -template <> -struct formatter<::FlexFlow::Op> : formatter { - template - auto format(::FlexFlow::Op ot, FormatContext &ctx) -> decltype(ctx.out()) { - using namespace FlexFlow; - - string_view name = "unknown"; - switch (ot) { - case Op::CONV2D: - name = "Conv2D"; - break; - case Op::DROPOUT: - name = "Dropout"; - break; - case Op::LINEAR: - name = "Dense"; - break; - case Op::BATCHMATMUL: - name = "BatchMatMul"; - break; - case Op::POOL2D: - name = "Pool2D"; - break; - case Op::SCALAR_MULTIPLY: - name = "ScalarMultiply"; - break; - case Op::SCALAR_ADD: - name = "ScalarAdd"; - break; - case Op::SCALAR_FLOOR_DIV: - name = "ScalarFloorDiv"; - break; - case Op::SCALAR_TRUE_DIV: - name = "ScalarTrueDiv"; - break; - case Op::SCALAR_SUB: - name = "ScalarSub"; - break; - case Op::RELU: - name = "ReLU"; - break; - case Op::SIGMOID: - name = "Sigmoid"; - break; - case Op::TANH: - name = "Tanh"; - break; - case Op::ELU: - name = 
"Elu"; - break; - case Op::FLAT: - name = "Flat"; - break; - case Op::SOFTMAX: - name = "Softmax"; - break; - case Op::BATCHNORM: - name = "BatchNorm"; - break; - case Op::CONCAT: - name = "Concat"; - break; - case Op::SPLIT: - name = "Split"; - break; - case Op::EMBEDDING: - name = "Embedding"; - break; - case Op::GATHER: - name = "Gather"; - break; - case Op::CACHE: - name = "Cache"; - break; - case Op::RESHAPE: - name = "Reshape"; - break; - case Op::REVERSE: - name = "Reverse"; - break; - case Op::TRANSPOSE: - name = "Transpose"; - break; - case Op::EW_ADD: - name = "Add"; - break; - case Op::EW_MUL: - name = "Mul"; - break; - case Op::MATMUL: - name = "Matmul"; - break; - case Op::MUL: - name = "Mul"; - break; - case Op::ENLARGE: - name = "Enlarge"; - break; - case Op::SQUEEZE: - name = "Squeeze"; - break; - case Op::UNSQUEEZE: - name = "Unsqueeze"; - break; - case Op::EW_SUB: - name = "Sub"; - break; - case Op::EW_DIV: - name = "Div"; - break; - case Op::EW_EQUAL: - name = "Equal"; - break; - case Op::EW_GREATER: - name = "Greater"; - break; - case Op::EW_LESS: - name = "Less"; - break; - case Op::EW_MAX: - name = "Max"; - break; - case Op::EW_MIN: - name = "Min"; - break; - case Op::REDUCE_ARGMAX: - name = "ReduceArgMax"; - break; - case Op::REDUCE_ARGMIN: - name = "ReduceArgMin"; - break; - case Op::REDUCE_MAX: - name = "ReduceMax"; - break; - case Op::REDUCE_MEAN: - name = "ReduceMean"; - break; - case Op::REDUCE_MIN: - name = "ReduceMin"; - break; - case Op::REDUCE_PROD: - name = "ReduceProd"; - break; - case Op::REDUCE_SUM: - name = "ReduceSum"; - break; - case Op::PAD: - name = "Pad"; - break; - case Op::SHAPE: - name = "Shape"; - break; - case Op::SIZE: - name = "Size"; - break; - case Op::TOPK: - name = "TopK"; - break; - case Op::WHERE: - name = "Where"; - break; - case Op::CEIL: - name = "Ceil"; - break; - case Op::CAST: - name = "Cast"; - break; - case Op::EXP: - name = "Exp"; - break; - case Op::SIN: - name = "Sin"; - break; - case Op::COS: - name = "Cos"; - break; - case Op::ROUND: - name = "Round"; - break; - case Op::LOG: - name = "Log"; - break; - case Op::LOGICAL_NOT: - name = "LogicalNot"; - break; - case Op::SQRT: - name = "Sqrt"; - break; - case Op::LEAKYRELU: - name = "LeakyReLU"; - break; - case Op::SLICE: - name = "Slice"; - break; - case Op::RESIZE: - name = "Resize"; - break; - case Op::PRELU: - name = "PReLU"; - break; - case Op::MULTIHEAD_ATTENTION: - name = "MultiHeadAttention"; - break; - case Op::INPUT: - name = "Input"; - break; - case Op::WEIGHT: - name = "Weight"; - break; - case Op::NOOP: - name = "NoOp"; - break; - case Op::FUSED: - name = "FusedOp"; - break; - case Op::RSQRT: - name = "Rsqrt"; - break; - case Op::POW: - name = "Pow"; - break; - case Op::MEAN: - name = "Mean"; - break; - case Op::LAYERNORM: - name = "LayerNorm"; - break; - case Op::IDENTITY: - name = "Identity"; - break; - // Parallel Ops - case Op::REPARTITION: - name = "Repartition"; - break; - case Op::COMBINE: - name = "Combine"; - break; - case Op::REPLICATE: - name = "Replicate"; - break; - case Op::REDUCTION: - name = "Reduction"; - break; - case Op::PIPELINE: - name = "Pipeline"; - break; - case Op::FUSED_PARALLEL: - name = "FusedParallelOp"; - break; - case Op::GELU: - name = "GeLU"; - break; - case Op::BROADCAST: - name = "Broadcast"; - break; - case Op::BATCH: - name = "Batch"; - break; - } - return formatter::format(name, ctx); - } -}; - -} // namespace fmt - -#endif diff --git a/lib/op-attrs/include/op-attrs/operator_attrs.h 
b/lib/op-attrs/include/op-attrs/operator_attrs.h index b63563cd67..268554b5be 100644 --- a/lib/op-attrs/include/op-attrs/operator_attrs.h +++ b/lib/op-attrs/include/op-attrs/operator_attrs.h @@ -2,6 +2,7 @@ #define _OPERATOR_PARAMS_H #include "op-attrs/ops/core.h" +#include "op-attrs/pcg_operator_attrs.dtg.h" #include "ops/attention.h" #include "ops/batch_matmul.h" #include "ops/batch_norm.h" @@ -31,102 +32,16 @@ #include "ops/split.h" #include "ops/topk.h" #include "ops/transpose.h" +#include "utils/record_formatter.h" #include "utils/variant.h" #include namespace FlexFlow { -using SharedOperatorAttrs = std::variant; - -static_assert(is_valid_opattr::value, ""); -static_assert(is_valid_opattr::value, ""); -static_assert(is_valid_opattr::value, ""); -static_assert(is_valid_opattr::value, ""); -static_assert(is_valid_opattr::value, ""); -static_assert(is_valid_opattr::value, ""); -static_assert(is_valid_opattr::value, ""); -static_assert(is_valid_opattr::value, ""); -static_assert(is_valid_opattr::value, ""); -static_assert(is_valid_opattr::value, ""); -static_assert(is_valid_opattr::value, ""); -static_assert(is_valid_opattr::value, ""); -static_assert(is_valid_opattr::value, ""); -static_assert(is_valid_opattr::value, ""); -static_assert(is_valid_opattr::value, ""); -static_assert(is_valid_opattr::value, ""); -static_assert(is_valid_opattr::value, ""); -static_assert(is_valid_opattr::value, ""); -static_assert(is_valid_opattr::value, ""); -static_assert(is_valid_opattr::value, ""); -static_assert(is_valid_opattr::value, ""); -static_assert(is_valid_opattr::value, ""); -static_assert(is_valid_opattr::value, ""); - -using ParallelOperatorAttrs = std:: - variant; - -using ComputationGraphAttrs = - variant_join>; -using CompGraphOperatorAttrs = ComputationGraphAttrs; - -using PCGOperatorAttrs = - variant_join; - -static_assert(is_equal_comparable::value, - "ComputationGraphAttrs must support =="); -static_assert(elements_satisfy::value, - ""); -static_assert(is_neq_comparable::value, - "ComputationGraphAttrs must support !="); -static_assert(is_lt_comparable::value, - "ComputationGraphAttrs must support <"); -static_assert(is_hashable::value, - "ComputationGraphAttrs must be hashable"); - -static_assert(is_equal_comparable::value, - "PCGOperatorAttrs must support =="); -static_assert(is_neq_comparable::value, - "PCGOperatorAttrs must support !="); -static_assert(is_lt_comparable::value, - "PCGOperatorAttrs must support <"); -static_assert(is_hashable::value, - "PCGOperatorAttrs must be hashable"); - -/* OperatorType get_op_type(CompGraphOperatorAttrs const &); */ -/* OperatorType get_op_type(PCGOperatorAttrs const &); */ - -RecordFormatter as_dot(CompGraphOperatorAttrs const &); -RecordFormatter as_dot(PCGOperatorAttrs const &); - std::vector get_output_shapes( PCGOperatorAttrs const &op_params, std::vector const &input_tensor_shapes); -bool is_parallel_op(PCGOperatorAttrs const &); bool is_valid(PCGOperatorAttrs const &, std::vector const &); diff --git a/lib/op-attrs/include/op-attrs/operator_type.dtg.h b/lib/op-attrs/include/op-attrs/operator_type.dtg.h new file mode 100644 index 0000000000..3b4bd86552 --- /dev/null +++ b/lib/op-attrs/include/op-attrs/operator_type.dtg.h @@ -0,0 +1,124 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! 
+// If you would like to modify this datatype, instead modify +// lib/op-attrs/include/op-attrs/operator_type.enum.toml +/* proj-data +{ + "generated_from": "c1c4687ef2fbc7dad996e5c25d47124c" +} +*/ + +#ifndef _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPERATOR_TYPE_DTG_H +#define _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPERATOR_TYPE_DTG_H + +#include "fmt/format.h" +#include "nlohmann/json.hpp" +#include "rapidcheck.h" +#include +#include +#include + +namespace FlexFlow { +enum class OperatorType { + NOOP, + INPUT, + WEIGHT, + CONV2D, + DROPOUT, + LINEAR, + BATCHMATMUL, + POOL2D, + SCALAR_MULTIPLY, + SCALAR_ADD, + SCALAR_FLOOR_DIV, + SCALAR_TRUE_DIV, + SCALAR_SUB, + RELU, + IDENTITY, + SIGMOID, + TANH, + ELU, + FLAT, + SOFTMAX, + BATCHNORM, + CONCAT, + SPLIT, + EMBEDDING, + CACHE, + RESHAPE, + REVERSE, + TRANSPOSE, + EW_ADD, + EW_MUL, + MATMUL, + MUL, + ENLARGE, + SQUEEZE, + UNSQUEEZE, + EW_SUB, + EW_DIV, + EW_EQUAL, + EW_GREATER, + EW_LESS, + EW_MAX, + EW_MIN, + REDUCE_ARGMAX, + REDUCE_ARGMIN, + REDUCE_MAX, + REDUCE_MEAN, + REDUCE_MIN, + REDUCE_PROD, + REDUCE_SUM, + PAD, + SHAPE, + SIZE, + TOPK, + WHERE, + CEIL, + CAST, + EXP, + ROUND, + LOG, + LOGICAL_NOT, + SQRT, + SIN, + COS, + LEAKYRELU, + SLICE, + RESIZE, + PRELU, + GELU, + MULTIHEAD_ATTENTION, + FUSED, + RSQRT, + POW, + MEAN, + LAYERNORM, + GATHER, + BROADCAST, + REPARTITION, + COMBINE, + REPLICATE, + REDUCTION, + BATCH, + PIPELINE, + FUSED_PARALLEL +}; +std::string format_as(OperatorType); +std::ostream &operator<<(std::ostream &, OperatorType); +void to_json(::nlohmann::json &, OperatorType); +void from_json(::nlohmann::json const &, OperatorType &); +} // namespace FlexFlow +namespace std { +template <> +struct hash { + size_t operator()(FlexFlow::OperatorType) const; +}; +} // namespace std +namespace rc { +template <> +struct Arbitrary { + static Gen arbitrary(); +}; +} // namespace rc + +#endif // _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPERATOR_TYPE_DTG_H diff --git a/lib/op-attrs/include/op-attrs/operator_type.enum.toml b/lib/op-attrs/include/op-attrs/operator_type.enum.toml new file mode 100644 index 0000000000..8815d69dda --- /dev/null +++ b/lib/op-attrs/include/op-attrs/operator_type.enum.toml @@ -0,0 +1,95 @@ +namespace = "FlexFlow" +name = "OperatorType" +features = [ + "hash", + "json", + "rapidcheck", + "fmt", +] + +values = [ + { name = "NOOP" }, + { name = "INPUT" }, + { name = "WEIGHT" }, + { name = "CONV2D" }, + { name = "DROPOUT" }, + { name = "LINEAR" }, + { name = "BATCHMATMUL" }, + { name = "POOL2D" }, + { name = "SCALAR_MULTIPLY" }, + { name = "SCALAR_ADD" }, + { name = "SCALAR_FLOOR_DIV" }, + { name = "SCALAR_TRUE_DIV" }, + { name = "SCALAR_SUB" }, + { name = "RELU" }, + { name = "IDENTITY" }, + { name = "SIGMOID" }, + { name = "TANH" }, + { name = "ELU" }, + { name = "FLAT" }, + { name = "SOFTMAX" }, + { name = "BATCHNORM" }, + { name = "CONCAT" }, + { name = "SPLIT" }, + { name = "EMBEDDING" }, + { name = "CACHE" }, + { name = "RESHAPE" }, + { name = "REVERSE" }, + { name = "TRANSPOSE" }, + { name = "EW_ADD" }, + { name = "EW_MUL" }, + { name = "MATMUL" }, + { name = "MUL" }, + { name = "ENLARGE" }, + { name = "SQUEEZE" }, + { name = "UNSQUEEZE" }, + { name = "EW_SUB" }, + { name = "EW_DIV" }, + { name = "EW_EQUAL" }, + { name = "EW_GREATER" }, + { name = "EW_LESS" }, + { name = "EW_MAX" }, + { name = "EW_MIN" }, + { name = "REDUCE_ARGMAX" }, + { name = "REDUCE_ARGMIN" }, + { name = "REDUCE_MAX" }, + { name = "REDUCE_MEAN" }, + { name = "REDUCE_MIN" }, + { name = "REDUCE_PROD" }, + { name = "REDUCE_SUM" }, + 
{ name = "PAD" }, + { name = "SHAPE" }, + { name = "SIZE" }, + { name = "TOPK" }, + { name = "WHERE" }, + { name = "CEIL" }, + { name = "CAST" }, + { name = "EXP" }, + { name = "ROUND" }, + { name = "LOG" }, + { name = "LOGICAL_NOT" }, + { name = "SQRT" }, + { name = "SIN" }, + { name = "COS" }, + { name = "LEAKYRELU" }, + { name = "SLICE" }, + { name = "RESIZE" }, + { name = "PRELU" }, + { name = "GELU" }, + { name = "MULTIHEAD_ATTENTION" }, + { name = "FUSED" }, + { name = "RSQRT" }, + { name = "POW" }, + { name = "MEAN" }, + { name = "LAYERNORM" }, + { name = "GATHER" }, + { name = "BROADCAST" }, + { name = "REPARTITION" }, + { name = "COMBINE" }, + { name = "REPLICATE" }, + { name = "REDUCTION" }, + { name = "BATCH" }, + { name = "PIPELINE" }, + { name = "FUSED_PARALLEL" }, +] + diff --git a/lib/op-attrs/include/op-attrs/operator_type.h b/lib/op-attrs/include/op-attrs/operator_type.h new file mode 100644 index 0000000000..4750af51ee --- /dev/null +++ b/lib/op-attrs/include/op-attrs/operator_type.h @@ -0,0 +1,13 @@ +#ifndef _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPERATOR_TYPE_H +#define _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPERATOR_TYPE_H + +#include "op-attrs/operator_type.dtg.h" + +namespace FlexFlow { + +std::string get_operator_type_name(OperatorType); +bool is_parallel_op(OperatorType); + +} // namespace FlexFlow + +#endif diff --git a/lib/op-attrs/include/op-attrs/ops/attention.h b/lib/op-attrs/include/op-attrs/ops/attention.h index ec3e592607..8233775e63 100644 --- a/lib/op-attrs/include/op-attrs/ops/attention.h +++ b/lib/op-attrs/include/op-attrs/ops/attention.h @@ -2,73 +2,62 @@ #define _FLEXFLOW_ATTENTION_ATTRS_H #include "core.h" -#include "op-attrs/parallel_tensor_shape.h" -#include "utils/visitable.h" +#include "op-attrs/ops/attention/multihead_attention_inputs.dtg.h" +#include "op-attrs/ops/attention/multihead_attention_parallel_inputs.dtg.h" +#include "op-attrs/ops/attention_attrs.dtg.h" +#include "op-attrs/parallel_tensor_shape.dtg.h" +#include "op-attrs/tensor_shape.dtg.h" +#include namespace FlexFlow { -struct MultiHeadAttentionAttrs { - req embed_dim, num_heads, kdim, vdim; - req dropout; - req bias, add_bias_kv, add_zero_attn; -}; -FF_VISITABLE_STRUCT(MultiHeadAttentionAttrs, - embed_dim, - num_heads, - kdim, - vdim, - dropout, - bias, - add_bias_kv, - add_zero_attn); - -template -struct MultiHeadAttentionInputs - : public use_visitable_cmp> { -public: - MultiHeadAttentionInputs() = delete; - - MultiHeadAttentionInputs(TensorType const &query, - TensorType const &key, - TensorType const &value) - : query(query), key(key), value(value) {} - - template - MultiHeadAttentionInputs(MultiHeadAttentionInputs const &sub) - : query(sub.query), key(sub.key), value(sub.value) {} - -public: - TensorType query; - TensorType key; - TensorType value; -}; - int get_qProjSize(MultiHeadAttentionAttrs const &); int get_vProjSize(MultiHeadAttentionAttrs const &); int get_kProjSize(MultiHeadAttentionAttrs const &); int get_oProjSize(MultiHeadAttentionAttrs const &); -int get_qSize(MultiHeadAttentionInputs const &); -int get_kSize(MultiHeadAttentionInputs const &); -int get_vSize(MultiHeadAttentionInputs const &); +int get_qSize(MultiHeadAttentionParallelInputs const &); +int get_qSize(MultiHeadAttentionInputs const &); + +int get_kSize(MultiHeadAttentionParallelInputs const &); +int get_kSize(MultiHeadAttentionInputs const &); + +int get_vSize(MultiHeadAttentionParallelInputs const &); +int get_vSize(MultiHeadAttentionInputs const &); + int get_oSize(ParallelTensorShape const &); 
+int get_oSize(TensorShape const &);
+
+int get_qoSeqLength(MultiHeadAttentionParallelInputs const &);
+int get_qoSeqLength(MultiHeadAttentionInputs const &);
 
-int get_qoSeqLength(MultiHeadAttentionInputs<ParallelTensorShape> const &);
-int get_kvSeqLength(MultiHeadAttentionInputs<ParallelTensorShape> const &);
+int get_kvSeqLength(MultiHeadAttentionParallelInputs const &);
+int get_kvSeqLength(MultiHeadAttentionInputs const &);
 
-int get_num_samples(MultiHeadAttentionInputs<ParallelTensorShape> const &);
+int get_num_samples(MultiHeadAttentionParallelInputs const &);
+int get_num_samples(MultiHeadAttentionInputs const &);
 
-TensorShape get_weights_shape(MultiHeadAttentionAttrs const &,
-                              MultiHeadAttentionInputs<TensorShape> const &);
-ParallelTensorShape
+tl::expected<TensorShape, std::string>
     get_weights_shape(MultiHeadAttentionAttrs const &,
-                      MultiHeadAttentionInputs<ParallelTensorShape> const &);
+                      TensorShape const &input_q,
+                      TensorShape const &input_k,
+                      TensorShape const &input_v);
+tl::expected<ParallelTensorShape, std::string>
+    get_weights_shape(MultiHeadAttentionAttrs const &,
+                      ParallelTensorShape const &input_q,
+                      ParallelTensorShape const &input_k,
+                      ParallelTensorShape const &input_v);
 
-ParallelTensorShape
+tl::expected<TensorShape, std::string>
+    get_output_shape(MultiHeadAttentionAttrs const &,
+                     TensorShape const &input_q,
+                     TensorShape const &input_k,
+                     TensorShape const &input_v);
+tl::expected<ParallelTensorShape, std::string>
     get_output_shape(MultiHeadAttentionAttrs const &,
-                     MultiHeadAttentionInputs<ParallelTensorShape> const &);
-TensorShape get_output_shape(MultiHeadAttentionAttrs const &,
-                             MultiHeadAttentionInputs<TensorShape> const &);
+                     ParallelTensorShape const &input_q,
+                     ParallelTensorShape const &input_k,
+                     ParallelTensorShape const &input_v);
 
 CHECK_VALID_OP_ATTR(MultiHeadAttentionAttrs);
 
 } // namespace FlexFlow
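A sketch of the new error-handling contract (not from this patch): shape inference now returns tl::expected instead of asserting, so the caller decides how to surface failures; output_or_throw is a hypothetical wrapper over the get_output_shape overload declared above.

    #include "op-attrs/ops/attention.h"
    #include <stdexcept>

    using namespace FlexFlow;

    TensorShape output_or_throw(MultiHeadAttentionAttrs const &attrs,
                                TensorShape const &q,
                                TensorShape const &k,
                                TensorShape const &v) {
      tl::expected<TensorShape, std::string> result =
          get_output_shape(attrs, q, k, v);
      if (!result.has_value()) {
        // e.g. mismatched batch sizes between q, k, and v
        throw std::runtime_error(result.error());
      }
      return result.value();
    }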
diff --git a/lib/op-attrs/include/op-attrs/ops/attention/multihead_attention_inputs.dtg.h b/lib/op-attrs/include/op-attrs/ops/attention/multihead_attention_inputs.dtg.h
new file mode 100644
index 0000000000..7b61305a1a
--- /dev/null
+++ b/lib/op-attrs/include/op-attrs/ops/attention/multihead_attention_inputs.dtg.h
@@ -0,0 +1,74 @@
+// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT!
+// If you would like to modify this datatype, instead modify
+// lib/op-attrs/include/op-attrs/ops/attention/multihead_attention_inputs.struct.toml
+/* proj-data
+{
+  "generated_from": "c57a9d1d2822a726ee9d9369d22e8e72"
+}
+*/
+
+#ifndef _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_ATTENTION_MULTIHEAD_ATTENTION_INPUTS_DTG_H
+#define _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_ATTENTION_MULTIHEAD_ATTENTION_INPUTS_DTG_H
+
+#include "fmt/format.h"
+#include "nlohmann/json.hpp"
+#include "op-attrs/datatype.dtg.h"
+#include "rapidcheck.h"
+#include
+#include
+#include
+#include
+
+namespace FlexFlow {
+struct MultiHeadAttentionInputs {
+  MultiHeadAttentionInputs() = delete;
+  MultiHeadAttentionInputs(size_t const &batch_size,
+                           size_t const &sequence_length,
+                           size_t const &query_size,
+                           size_t const &key_size,
+                           size_t const &value_size,
+                           ::FlexFlow::DataType const &datatype);
+
+  bool operator==(MultiHeadAttentionInputs const &) const;
+  bool operator!=(MultiHeadAttentionInputs const &) const;
+  bool operator<(MultiHeadAttentionInputs const &) const;
+  bool operator>(MultiHeadAttentionInputs const &) const;
+  bool operator<=(MultiHeadAttentionInputs const &) const;
+  bool operator>=(MultiHeadAttentionInputs const &) const;
+  size_t batch_size;
+  size_t sequence_length;
+  size_t query_size;
+  size_t key_size;
+  size_t value_size;
+  ::FlexFlow::DataType datatype;
+};
+} // namespace FlexFlow
+
+namespace std {
+template <>
+struct hash<FlexFlow::MultiHeadAttentionInputs> {
+  size_t operator()(FlexFlow::MultiHeadAttentionInputs const &) const;
+};
+} // namespace std
+
+namespace nlohmann {
+template <>
+struct adl_serializer<FlexFlow::MultiHeadAttentionInputs> {
+  static FlexFlow::MultiHeadAttentionInputs from_json(json const &);
+  static void to_json(json &, FlexFlow::MultiHeadAttentionInputs const &);
+};
+} // namespace nlohmann
+
+namespace rc {
+template <>
+struct Arbitrary<FlexFlow::MultiHeadAttentionInputs> {
+  static Gen<FlexFlow::MultiHeadAttentionInputs> arbitrary();
+};
+} // namespace rc
+
+namespace FlexFlow {
+std::string format_as(MultiHeadAttentionInputs const &);
+std::ostream &operator<<(std::ostream &, MultiHeadAttentionInputs const &);
+} // namespace FlexFlow
+
+#endif // _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_ATTENTION_MULTIHEAD_ATTENTION_INPUTS_DTG_H
diff --git a/lib/op-attrs/include/op-attrs/ops/attention/multihead_attention_inputs.h b/lib/op-attrs/include/op-attrs/ops/attention/multihead_attention_inputs.h
new file mode 100644
index 0000000000..aed9f577ff
--- /dev/null
+++ b/lib/op-attrs/include/op-attrs/ops/attention/multihead_attention_inputs.h
@@ -0,0 +1,17 @@
+#ifndef _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_ATTENTION_MULTIHEAD_ATTENTION_INPUTS_H
+#define _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_ATTENTION_MULTIHEAD_ATTENTION_INPUTS_H
+
+#include "op-attrs/ops/attention/multihead_attention_inputs.dtg.h"
+#include "op-attrs/tensor_shape.dtg.h"
+#include
+
+namespace FlexFlow {
+
+tl::expected<MultiHeadAttentionInputs, std::string>
+    parse_attention_input_shape(TensorShape const &input_q,
+                                TensorShape const &input_k,
+                                TensorShape const &input_v);
+
+} // namespace FlexFlow
+
+#endif
diff --git a/lib/op-attrs/include/op-attrs/ops/attention/multihead_attention_inputs.struct.toml b/lib/op-attrs/include/op-attrs/ops/attention/multihead_attention_inputs.struct.toml
new file mode 100644
index 0000000000..b82b285451
--- /dev/null
+++ b/lib/op-attrs/include/op-attrs/ops/attention/multihead_attention_inputs.struct.toml
@@ -0,0 +1,39 @@
+namespace = "FlexFlow"
+name = "MultiHeadAttentionInputs"
+features = [
+  "eq",
+  "ord",
+  "hash",
+  "json",
+  "rapidcheck",
+  "fmt",
+]
+
+includes = [
+  "",
+  "op-attrs/datatype.dtg.h",
+]
+
+[[fields]]
+name = "batch_size" +type = "size_t" + +[[fields]] +name = "sequence_length" +type = "size_t" + +[[fields]] +name = "query_size" +type = "size_t" + +[[fields]] +name = "key_size" +type = "size_t" + +[[fields]] +name = "value_size" +type = "size_t" + +[[fields]] +name = "datatype" +type = "::FlexFlow::DataType" diff --git a/lib/op-attrs/include/op-attrs/ops/attention/multihead_attention_parallel_inputs.dtg.h b/lib/op-attrs/include/op-attrs/ops/attention/multihead_attention_parallel_inputs.dtg.h new file mode 100644 index 0000000000..297b1f8f1c --- /dev/null +++ b/lib/op-attrs/include/op-attrs/ops/attention/multihead_attention_parallel_inputs.dtg.h @@ -0,0 +1,82 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/op-attrs/include/op-attrs/ops/attention/multihead_attention_parallel_inputs.struct.toml +/* proj-data +{ + "generated_from": "7c434445707968123a361c038a337da2" +} +*/ + +#ifndef _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_ATTENTION_MULTIHEAD_ATTENTION_PARALLEL_INPUTS_DTG_H +#define _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_ATTENTION_MULTIHEAD_ATTENTION_PARALLEL_INPUTS_DTG_H + +#include "fmt/format.h" +#include "nlohmann/json.hpp" +#include "op-attrs/datatype.dtg.h" +#include "op-attrs/parallel_tensor_shape/discard_copy_degree.dtg.h" +#include "op-attrs/parallel_tensor_shape/sum_degree.dtg.h" +#include "op-attrs/shard_parallel_dim.dtg.h" +#include "rapidcheck.h" +#include +#include +#include +#include + +namespace FlexFlow { +struct MultiHeadAttentionParallelInputs { + MultiHeadAttentionParallelInputs() = delete; + MultiHeadAttentionParallelInputs( + ::FlexFlow::ShardParallelDim const &batch_dim, + ::FlexFlow::ShardParallelDim const &sequence_dim, + ::FlexFlow::ShardParallelDim const &query_dim, + ::FlexFlow::ShardParallelDim const &key_dim, + ::FlexFlow::ShardParallelDim const &value_dim, + ::FlexFlow::DiscardCopyDegree const &discard_copy_degree, + ::FlexFlow::DataType const &datatype); + + bool operator==(MultiHeadAttentionParallelInputs const &) const; + bool operator!=(MultiHeadAttentionParallelInputs const &) const; + bool operator<(MultiHeadAttentionParallelInputs const &) const; + bool operator>(MultiHeadAttentionParallelInputs const &) const; + bool operator<=(MultiHeadAttentionParallelInputs const &) const; + bool operator>=(MultiHeadAttentionParallelInputs const &) const; + ::FlexFlow::ShardParallelDim batch_dim; + ::FlexFlow::ShardParallelDim sequence_dim; + ::FlexFlow::ShardParallelDim query_dim; + ::FlexFlow::ShardParallelDim key_dim; + ::FlexFlow::ShardParallelDim value_dim; + ::FlexFlow::DiscardCopyDegree discard_copy_degree; + ::FlexFlow::DataType datatype; +}; +} // namespace FlexFlow + +namespace std { +template <> +struct hash { + size_t operator()(FlexFlow::MultiHeadAttentionParallelInputs const &) const; +}; +} // namespace std + +namespace nlohmann { +template <> +struct adl_serializer { + static FlexFlow::MultiHeadAttentionParallelInputs from_json(json const &); + static void to_json(json &, + FlexFlow::MultiHeadAttentionParallelInputs const &); +}; +} // namespace nlohmann + +namespace rc { +template <> +struct Arbitrary { + static Gen arbitrary(); +}; +} // namespace rc + +namespace FlexFlow { +std::string format_as(MultiHeadAttentionParallelInputs const &); +std::ostream &operator<<(std::ostream &, + MultiHeadAttentionParallelInputs const &); +} // namespace FlexFlow + +#endif // _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_ATTENTION_MULTIHEAD_ATTENTION_PARALLEL_INPUTS_DTG_H 
diff --git a/lib/op-attrs/include/op-attrs/ops/attention/multihead_attention_parallel_inputs.h b/lib/op-attrs/include/op-attrs/ops/attention/multihead_attention_parallel_inputs.h new file mode 100644 index 0000000000..53cc3167f2 --- /dev/null +++ b/lib/op-attrs/include/op-attrs/ops/attention/multihead_attention_parallel_inputs.h @@ -0,0 +1,17 @@ +#ifndef _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_ATTENTION_MULTIHEAD_ATTENTION_PARALLEL_INPUTS_H +#define _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_ATTENTION_MULTIHEAD_ATTENTION_PARALLEL_INPUTS_H + +#include "op-attrs/ops/attention/multihead_attention_parallel_inputs.dtg.h" +#include "op-attrs/parallel_tensor_shape.dtg.h" +#include + +namespace FlexFlow { + +tl::expected + parse_attention_parallel_input_shape(ParallelTensorShape const &input_q, + ParallelTensorShape const &input_k, + ParallelTensorShape const &input_v); + +} // namespace FlexFlow + +#endif diff --git a/lib/op-attrs/include/op-attrs/ops/attention/multihead_attention_parallel_inputs.struct.toml b/lib/op-attrs/include/op-attrs/ops/attention/multihead_attention_parallel_inputs.struct.toml new file mode 100644 index 0000000000..b0636db353 --- /dev/null +++ b/lib/op-attrs/include/op-attrs/ops/attention/multihead_attention_parallel_inputs.struct.toml @@ -0,0 +1,46 @@ +namespace = "FlexFlow" +name = "MultiHeadAttentionParallelInputs" +features = [ + "eq", + "ord", + "hash", + "json", + "rapidcheck", + "fmt", +] + +includes = [ + "", + "op-attrs/datatype.dtg.h", + "op-attrs/shard_parallel_dim.dtg.h", + "op-attrs/parallel_tensor_shape/discard_copy_degree.dtg.h", + "op-attrs/parallel_tensor_shape/sum_degree.dtg.h", +] + +[[fields]] +name = "batch_dim" +type = "::FlexFlow::ShardParallelDim" + +[[fields]] +name = "sequence_dim" +type = "::FlexFlow::ShardParallelDim" + +[[fields]] +name = "query_dim" +type = "::FlexFlow::ShardParallelDim" + +[[fields]] +name = "key_dim" +type = "::FlexFlow::ShardParallelDim" + +[[fields]] +name = "value_dim" +type = "::FlexFlow::ShardParallelDim" + +[[fields]] +name = "discard_copy_degree" +type = "::FlexFlow::DiscardCopyDegree" + +[[fields]] +name = "datatype" +type = "::FlexFlow::DataType" diff --git a/lib/op-attrs/include/op-attrs/ops/attention_attrs.dtg.h b/lib/op-attrs/include/op-attrs/ops/attention_attrs.dtg.h new file mode 100644 index 0000000000..18b2906759 --- /dev/null +++ b/lib/op-attrs/include/op-attrs/ops/attention_attrs.dtg.h @@ -0,0 +1,76 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! 
+// If you would like to modify this datatype, instead modify
+// lib/op-attrs/include/op-attrs/ops/attention_attrs.struct.toml
+/* proj-data
+{
+  "generated_from": "360324465947562229dc6632a9e9a2f3"
+}
+*/
+
+#ifndef _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_ATTENTION_ATTRS_DTG_H
+#define _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_ATTENTION_ATTRS_DTG_H
+
+#include "fmt/format.h"
+#include "nlohmann/json.hpp"
+#include "rapidcheck.h"
+#include
+#include
+#include
+
+namespace FlexFlow {
+struct MultiHeadAttentionAttrs {
+  MultiHeadAttentionAttrs() = delete;
+  MultiHeadAttentionAttrs(int const &embed_dim,
+                          int const &num_heads,
+                          int const &kdim,
+                          int const &vdim,
+                          float const &dropout,
+                          bool const &bias,
+                          bool const &add_bias_kv,
+                          bool const &add_zero_attn);
+
+  bool operator==(MultiHeadAttentionAttrs const &) const;
+  bool operator!=(MultiHeadAttentionAttrs const &) const;
+  bool operator<(MultiHeadAttentionAttrs const &) const;
+  bool operator>(MultiHeadAttentionAttrs const &) const;
+  bool operator<=(MultiHeadAttentionAttrs const &) const;
+  bool operator>=(MultiHeadAttentionAttrs const &) const;
+  int embed_dim;
+  int num_heads;
+  int kdim;
+  int vdim;
+  float dropout;
+  bool bias;
+  bool add_bias_kv;
+  bool add_zero_attn;
+};
+} // namespace FlexFlow
+
+namespace std {
+template <>
+struct hash<FlexFlow::MultiHeadAttentionAttrs> {
+  size_t operator()(FlexFlow::MultiHeadAttentionAttrs const &) const;
+};
+} // namespace std
+
+namespace nlohmann {
+template <>
+struct adl_serializer<FlexFlow::MultiHeadAttentionAttrs> {
+  static FlexFlow::MultiHeadAttentionAttrs from_json(json const &);
+  static void to_json(json &, FlexFlow::MultiHeadAttentionAttrs const &);
+};
+} // namespace nlohmann
+
+namespace rc {
+template <>
+struct Arbitrary<FlexFlow::MultiHeadAttentionAttrs> {
+  static Gen<FlexFlow::MultiHeadAttentionAttrs> arbitrary();
+};
+} // namespace rc
+
+namespace FlexFlow {
+std::string format_as(MultiHeadAttentionAttrs const &);
+std::ostream &operator<<(std::ostream &, MultiHeadAttentionAttrs const &);
+} // namespace FlexFlow
+
+#endif // _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_ATTENTION_ATTRS_DTG_H
diff --git a/lib/op-attrs/include/op-attrs/ops/attention_attrs.struct.toml b/lib/op-attrs/include/op-attrs/ops/attention_attrs.struct.toml
new file mode 100644
index 0000000000..d96d8af69c
--- /dev/null
+++ b/lib/op-attrs/include/op-attrs/ops/attention_attrs.struct.toml
@@ -0,0 +1,43 @@
+namespace = "FlexFlow"
+name = "MultiHeadAttentionAttrs"
+
+features = [
+  "eq",
+  "ord",
+  "hash",
+  "json",
+  "rapidcheck",
+  "fmt",
+]
+
+[[fields]]
+name = "embed_dim"
+type = "int"
+
+[[fields]]
+name = "num_heads"
+type = "int"
+
+[[fields]]
+name = "kdim"
+type = "int"
+
+[[fields]]
+name = "vdim"
+type = "int"
+
+[[fields]]
+name = "dropout"
+type = "float"
+
+[[fields]]
+name = "bias"
+type = "bool"
+
+[[fields]]
+name = "add_bias_kv"
+type = "bool"
+
+[[fields]]
+name = "add_zero_attn"
+type = "bool"
diff --git a/lib/op-attrs/include/op-attrs/ops/batch_matmul.dtg.h b/lib/op-attrs/include/op-attrs/ops/batch_matmul.dtg.h
new file mode 100644
index 0000000000..a8ab52d2b3
--- /dev/null
+++ b/lib/op-attrs/include/op-attrs/ops/batch_matmul.dtg.h
@@ -0,0 +1,63 @@
+// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT!
+// If you would like to modify this datatype, instead modify +// lib/op-attrs/include/op-attrs/ops/batch_matmul.struct.toml +/* proj-data +{ + "generated_from": "c3bbf4c76982ef27107b74e1e6e5d360" +} +*/ + +#ifndef _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_BATCH_MATMUL_DTG_H +#define _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_BATCH_MATMUL_DTG_H + +#include "fmt/format.h" +#include "nlohmann/json.hpp" +#include "rapidcheck.h" +#include +#include +#include + +namespace FlexFlow { +struct BatchMatmulAttrs { + BatchMatmulAttrs() = delete; + BatchMatmulAttrs(int const &a_seq_length_dim, int const &b_seq_length_dim); + + bool operator==(BatchMatmulAttrs const &) const; + bool operator!=(BatchMatmulAttrs const &) const; + bool operator<(BatchMatmulAttrs const &) const; + bool operator>(BatchMatmulAttrs const &) const; + bool operator<=(BatchMatmulAttrs const &) const; + bool operator>=(BatchMatmulAttrs const &) const; + int a_seq_length_dim; + int b_seq_length_dim; +}; +} // namespace FlexFlow + +namespace std { +template <> +struct hash { + size_t operator()(FlexFlow::BatchMatmulAttrs const &) const; +}; +} // namespace std + +namespace nlohmann { +template <> +struct adl_serializer { + static FlexFlow::BatchMatmulAttrs from_json(json const &); + static void to_json(json &, FlexFlow::BatchMatmulAttrs const &); +}; +} // namespace nlohmann + +namespace rc { +template <> +struct Arbitrary { + static Gen arbitrary(); +}; +} // namespace rc + +namespace FlexFlow { +std::string format_as(BatchMatmulAttrs const &); +std::ostream &operator<<(std::ostream &, BatchMatmulAttrs const &); +} // namespace FlexFlow + +#endif // _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_BATCH_MATMUL_DTG_H diff --git a/lib/op-attrs/include/op-attrs/ops/batch_matmul.h b/lib/op-attrs/include/op-attrs/ops/batch_matmul.h index b05a5eb022..57760d1110 100644 --- a/lib/op-attrs/include/op-attrs/ops/batch_matmul.h +++ b/lib/op-attrs/include/op-attrs/ops/batch_matmul.h @@ -1,18 +1,26 @@ -#ifndef _FF_OP_META_BATCH_MATMUL_ATTRS_H -#define _FF_OP_META_BATCH_MATMUL_ATTRS_H +#ifndef _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_BATCH_MATMUL_H +#define _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_BATCH_MATMUL_H -#include "core.h" -#include "op-attrs/parallel_tensor_shape.h" -#include "utils/visitable.h" +#include "op-attrs/ops/batch_matmul.dtg.h" +#include "op-attrs/parallel_tensor_shape.dtg.h" +#include "op-attrs/tensor_shape.dtg.h" +#include namespace FlexFlow { -struct BatchMatmulAttrs { - req a_seq_length_dim, b_seq_length_dim; -}; -FF_VISITABLE_STRUCT(BatchMatmulAttrs, a_seq_length_dim, b_seq_length_dim); +bool is_valid(BatchMatmulAttrs const &, + ParallelTensorShape const &, + ParallelTensorShape const &); -CHECK_VALID_OP_ATTR(BatchMatmulAttrs); +tl::expected + get_output_shape(BatchMatmulAttrs const &attrs, + TensorShape const &input_lhs, + TensorShape const &input_rhs); + +tl::expected + get_output_shape(BatchMatmulAttrs const &attrs, + ParallelTensorShape const &input_lhs, + ParallelTensorShape const &input_rhs); } // namespace FlexFlow #endif diff --git a/lib/op-attrs/include/op-attrs/ops/batch_matmul.struct.toml b/lib/op-attrs/include/op-attrs/ops/batch_matmul.struct.toml new file mode 100644 index 0000000000..3b1dd3f687 --- /dev/null +++ b/lib/op-attrs/include/op-attrs/ops/batch_matmul.struct.toml @@ -0,0 +1,19 @@ +namespace = "FlexFlow" +name = "BatchMatmulAttrs" + +features = [ + "eq", + "ord", + "hash", + "json", + "rapidcheck", + "fmt", +] + +[[fields]] +name = "a_seq_length_dim" +type = "int" + +[[fields]] +name = 
"b_seq_length_dim" +type = "int" diff --git a/lib/op-attrs/include/op-attrs/ops/batch_norm.h b/lib/op-attrs/include/op-attrs/ops/batch_norm.h index 4ec823d4ae..b9a1d87a75 100644 --- a/lib/op-attrs/include/op-attrs/ops/batch_norm.h +++ b/lib/op-attrs/include/op-attrs/ops/batch_norm.h @@ -2,17 +2,13 @@ #define _FLEXFLOW_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_BATCH_NORM_H #include "core.h" +#include "op-attrs/ops/batch_norm_attrs.dtg.h" #include "op-attrs/parallel_tensor_shape.h" -#include "utils/visitable.h" namespace FlexFlow { -struct BatchNormAttrs { - req relu; -}; -FF_VISITABLE_STRUCT(BatchNormAttrs, relu); - -ParallelTensorShape get_output_shape(BatchNormAttrs const &); +ParallelTensorShape get_output_shape(BatchNormAttrs const &, + ParallelTensorShape const &); CHECK_VALID_OP_ATTR(BatchNormAttrs); diff --git a/lib/op-attrs/include/op-attrs/ops/batch_norm_attrs.dtg.h b/lib/op-attrs/include/op-attrs/ops/batch_norm_attrs.dtg.h new file mode 100644 index 0000000000..f153bfde7e --- /dev/null +++ b/lib/op-attrs/include/op-attrs/ops/batch_norm_attrs.dtg.h @@ -0,0 +1,62 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/op-attrs/include/op-attrs/ops/batch_norm_attrs.struct.toml +/* proj-data +{ + "generated_from": "f8e0219d8a3e008a73c38cf84d25f66e" +} +*/ + +#ifndef _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_BATCH_NORM_ATTRS_DTG_H +#define _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_BATCH_NORM_ATTRS_DTG_H + +#include "fmt/format.h" +#include "nlohmann/json.hpp" +#include "rapidcheck.h" +#include +#include +#include + +namespace FlexFlow { +struct BatchNormAttrs { + BatchNormAttrs() = delete; + BatchNormAttrs(bool const &relu); + + bool operator==(BatchNormAttrs const &) const; + bool operator!=(BatchNormAttrs const &) const; + bool operator<(BatchNormAttrs const &) const; + bool operator>(BatchNormAttrs const &) const; + bool operator<=(BatchNormAttrs const &) const; + bool operator>=(BatchNormAttrs const &) const; + bool relu; +}; +} // namespace FlexFlow + +namespace std { +template <> +struct hash { + size_t operator()(FlexFlow::BatchNormAttrs const &) const; +}; +} // namespace std + +namespace nlohmann { +template <> +struct adl_serializer { + static FlexFlow::BatchNormAttrs from_json(json const &); + static void to_json(json &, FlexFlow::BatchNormAttrs const &); +}; +} // namespace nlohmann + +namespace rc { +template <> +struct Arbitrary { + static Gen arbitrary(); +}; +} // namespace rc + +namespace FlexFlow { +std::string format_as(BatchNormAttrs const &); +std::ostream &operator<<(std::ostream &, BatchNormAttrs const &); +} // namespace FlexFlow + +#endif // _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_BATCH_NORM_ATTRS_DTG_H diff --git a/lib/op-attrs/include/op-attrs/ops/batch_norm_attrs.struct.toml b/lib/op-attrs/include/op-attrs/ops/batch_norm_attrs.struct.toml new file mode 100644 index 0000000000..bc82f3c743 --- /dev/null +++ b/lib/op-attrs/include/op-attrs/ops/batch_norm_attrs.struct.toml @@ -0,0 +1,15 @@ +namespace = "FlexFlow" +name = "BatchNormAttrs" + +features = [ + "eq", + "ord", + "hash", + "json", + "rapidcheck", + "fmt", +] + +[[fields]] +name = "relu" +type = "bool" diff --git a/lib/op-attrs/include/op-attrs/ops/broadcast.dtg.h b/lib/op-attrs/include/op-attrs/ops/broadcast.dtg.h new file mode 100644 index 0000000000..e4de3dcc75 --- /dev/null +++ b/lib/op-attrs/include/op-attrs/ops/broadcast.dtg.h @@ -0,0 +1,64 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! 
+// If you would like to modify this datatype, instead modify +// lib/op-attrs/include/op-attrs/ops/broadcast.struct.toml +/* proj-data +{ + "generated_from": "12715c970e8416eacbd0750f338478e5" +} +*/ + +#ifndef _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_BROADCAST_DTG_H +#define _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_BROADCAST_DTG_H + +#include "fmt/format.h" +#include "nlohmann/json.hpp" +#include "rapidcheck.h" +#include "utils/stack_vector.h" +#include +#include +#include + +namespace FlexFlow { +struct BroadcastAttrs { + BroadcastAttrs() = delete; + BroadcastAttrs( + ::FlexFlow::stack_vector const &target_dims); + + bool operator==(BroadcastAttrs const &) const; + bool operator!=(BroadcastAttrs const &) const; + bool operator<(BroadcastAttrs const &) const; + bool operator>(BroadcastAttrs const &) const; + bool operator<=(BroadcastAttrs const &) const; + bool operator>=(BroadcastAttrs const &) const; + ::FlexFlow::stack_vector target_dims; +}; +} // namespace FlexFlow + +namespace std { +template <> +struct hash { + size_t operator()(FlexFlow::BroadcastAttrs const &) const; +}; +} // namespace std + +namespace nlohmann { +template <> +struct adl_serializer { + static FlexFlow::BroadcastAttrs from_json(json const &); + static void to_json(json &, FlexFlow::BroadcastAttrs const &); +}; +} // namespace nlohmann + +namespace rc { +template <> +struct Arbitrary { + static Gen arbitrary(); +}; +} // namespace rc + +namespace FlexFlow { +std::string format_as(BroadcastAttrs const &); +std::ostream &operator<<(std::ostream &, BroadcastAttrs const &); +} // namespace FlexFlow + +#endif // _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_BROADCAST_DTG_H diff --git a/lib/op-attrs/include/op-attrs/ops/broadcast.h b/lib/op-attrs/include/op-attrs/ops/broadcast.h index 433bf23241..ad44060400 100644 --- a/lib/op-attrs/include/op-attrs/ops/broadcast.h +++ b/lib/op-attrs/include/op-attrs/ops/broadcast.h @@ -1,18 +1,13 @@ -#ifndef _FLEXFLOW_INCLUDE_OPATTRS_OPS_BROADCAST_H -#define _FLEXFLOW_INCLUDE_OPATTRS_OPS_BROADCAST_H +#ifndef _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_BROADCAST_H +#define _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_BROADCAST_H -#include "core.h" -#include "utils/stack_vector.h" -#include "utils/visitable.h" +#include "op-attrs/ops/broadcast.dtg.h" +#include "op-attrs/parallel_tensor_shape.dtg.h" namespace FlexFlow { -struct BroadcastAttrs { - req> target_dims; -}; -FF_VISITABLE_STRUCT(BroadcastAttrs, target_dims); - -CHECK_VALID_OP_ATTR(BroadcastAttrs); +ParallelTensorShape get_output_shape(BroadcastAttrs const &, + ParallelTensorShape const &); } // namespace FlexFlow diff --git a/lib/op-attrs/include/op-attrs/ops/broadcast.struct.toml b/lib/op-attrs/include/op-attrs/ops/broadcast.struct.toml new file mode 100644 index 0000000000..c87afa59b5 --- /dev/null +++ b/lib/op-attrs/include/op-attrs/ops/broadcast.struct.toml @@ -0,0 +1,18 @@ +namespace = "FlexFlow" +name = "BroadcastAttrs" +features = [ + "eq", + "ord", + "hash", + "json", + "rapidcheck", + "fmt", +] + +includes = [ + "utils/stack_vector.h", +] + +[[fields]] +name = "target_dims" +type = "::FlexFlow::stack_vector" diff --git a/lib/op-attrs/include/op-attrs/ops/cast.h b/lib/op-attrs/include/op-attrs/ops/cast.h index 63563f8df8..117dcb1e01 100644 --- a/lib/op-attrs/include/op-attrs/ops/cast.h +++ b/lib/op-attrs/include/op-attrs/ops/cast.h @@ -2,17 +2,10 @@ #define _FLEXFLOW_CAST_ATTRS_H #include "core.h" -#include "op-attrs/datatype.h" -#include "op-attrs/parallel_tensor_shape.h" -#include "utils/visitable.h" +#include 
"op-attrs/ops/cast_attrs.dtg.h" namespace FlexFlow { -struct CastAttrs { - req dtype; -}; -FF_VISITABLE_STRUCT(CastAttrs, dtype); - CHECK_VALID_OP_ATTR(CastAttrs); } // namespace FlexFlow diff --git a/lib/op-attrs/include/op-attrs/ops/cast_attrs.dtg.h b/lib/op-attrs/include/op-attrs/ops/cast_attrs.dtg.h new file mode 100644 index 0000000000..33391eb221 --- /dev/null +++ b/lib/op-attrs/include/op-attrs/ops/cast_attrs.dtg.h @@ -0,0 +1,63 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/op-attrs/include/op-attrs/ops/cast_attrs.struct.toml +/* proj-data +{ + "generated_from": "c171c87db89b9ec9ea7d52a50c153054" +} +*/ + +#ifndef _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_CAST_ATTRS_DTG_H +#define _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_CAST_ATTRS_DTG_H + +#include "fmt/format.h" +#include "nlohmann/json.hpp" +#include "op-attrs/datatype.h" +#include "rapidcheck.h" +#include +#include +#include + +namespace FlexFlow { +struct CastAttrs { + CastAttrs() = delete; + CastAttrs(DataType const &dtype); + + bool operator==(CastAttrs const &) const; + bool operator!=(CastAttrs const &) const; + bool operator<(CastAttrs const &) const; + bool operator>(CastAttrs const &) const; + bool operator<=(CastAttrs const &) const; + bool operator>=(CastAttrs const &) const; + DataType dtype; +}; +} // namespace FlexFlow + +namespace std { +template <> +struct hash { + size_t operator()(FlexFlow::CastAttrs const &) const; +}; +} // namespace std + +namespace nlohmann { +template <> +struct adl_serializer { + static FlexFlow::CastAttrs from_json(json const &); + static void to_json(json &, FlexFlow::CastAttrs const &); +}; +} // namespace nlohmann + +namespace rc { +template <> +struct Arbitrary { + static Gen arbitrary(); +}; +} // namespace rc + +namespace FlexFlow { +std::string format_as(CastAttrs const &); +std::ostream &operator<<(std::ostream &, CastAttrs const &); +} // namespace FlexFlow + +#endif // _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_CAST_ATTRS_DTG_H diff --git a/lib/op-attrs/include/op-attrs/ops/cast_attrs.struct.toml b/lib/op-attrs/include/op-attrs/ops/cast_attrs.struct.toml new file mode 100644 index 0000000000..6c12680ea1 --- /dev/null +++ b/lib/op-attrs/include/op-attrs/ops/cast_attrs.struct.toml @@ -0,0 +1,18 @@ +namespace = "FlexFlow" +name = "CastAttrs" +features = [ + "eq", + "ord", + "hash", + "json", + "rapidcheck", + "fmt", +] + +includes = [ + "op-attrs/datatype.h" +] + +[[fields]] +name = "dtype" +type = "DataType" diff --git a/lib/op-attrs/include/op-attrs/ops/combine.h b/lib/op-attrs/include/op-attrs/ops/combine.h index deaba9e093..d9b20fc2c5 100644 --- a/lib/op-attrs/include/op-attrs/ops/combine.h +++ b/lib/op-attrs/include/op-attrs/ops/combine.h @@ -1,20 +1,18 @@ -#ifndef _FLEXFLOW_COMBINE_ATTRS_H -#define _FLEXFLOW_COMBINE_ATTRS_H +#ifndef _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_COMBINE_H +#define _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_COMBINE_H -#include "core.h" -#include "op-attrs/ff_dim.h" -#include "op-attrs/parallel_tensor_shape.h" -#include "utils/visitable.h" +#include "op-attrs/ops/combine_attrs.dtg.h" +#include "op-attrs/ops/core.h" +#include "op-attrs/parallel_tensor_shape.dtg.h" +#include namespace FlexFlow { -struct CombineAttrs { - ff_dim_t combine_dim; - req combine_degree; -}; -FF_VISITABLE_STRUCT(CombineAttrs, combine_dim, combine_degree); CHECK_VALID_OP_ATTR(CombineAttrs); +tl::expected + get_output_shape(CombineAttrs const &, ParallelTensorShape const &); + } // 
namespace FlexFlow #endif diff --git a/lib/op-attrs/include/op-attrs/ops/combine_attrs.dtg.h b/lib/op-attrs/include/op-attrs/ops/combine_attrs.dtg.h new file mode 100644 index 0000000000..43db204bc5 --- /dev/null +++ b/lib/op-attrs/include/op-attrs/ops/combine_attrs.dtg.h @@ -0,0 +1,66 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/op-attrs/include/op-attrs/ops/combine_attrs.struct.toml +/* proj-data +{ + "generated_from": "58fc5a388fd1a325ef4142094607e39a" +} +*/ + +#ifndef _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_COMBINE_ATTRS_DTG_H +#define _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_COMBINE_ATTRS_DTG_H + +#include "fmt/format.h" +#include "nlohmann/json.hpp" +#include "op-attrs/ff_dim.dtg.h" +#include "op-attrs/ff_dim.h" +#include "rapidcheck.h" +#include +#include +#include + +namespace FlexFlow { +struct CombineAttrs { + CombineAttrs() = delete; + CombineAttrs(::FlexFlow::ff_dim_t const &combine_dim, + int const &combine_degree); + + bool operator==(CombineAttrs const &) const; + bool operator!=(CombineAttrs const &) const; + bool operator<(CombineAttrs const &) const; + bool operator>(CombineAttrs const &) const; + bool operator<=(CombineAttrs const &) const; + bool operator>=(CombineAttrs const &) const; + ::FlexFlow::ff_dim_t combine_dim; + int combine_degree; +}; +} // namespace FlexFlow + +namespace std { +template <> +struct hash { + size_t operator()(FlexFlow::CombineAttrs const &) const; +}; +} // namespace std + +namespace nlohmann { +template <> +struct adl_serializer { + static FlexFlow::CombineAttrs from_json(json const &); + static void to_json(json &, FlexFlow::CombineAttrs const &); +}; +} // namespace nlohmann + +namespace rc { +template <> +struct Arbitrary { + static Gen arbitrary(); +}; +} // namespace rc + +namespace FlexFlow { +std::string format_as(CombineAttrs const &); +std::ostream &operator<<(std::ostream &, CombineAttrs const &); +} // namespace FlexFlow + +#endif // _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_COMBINE_ATTRS_DTG_H diff --git a/lib/op-attrs/include/op-attrs/ops/combine_attrs.struct.toml b/lib/op-attrs/include/op-attrs/ops/combine_attrs.struct.toml new file mode 100644 index 0000000000..585295fe1c --- /dev/null +++ b/lib/op-attrs/include/op-attrs/ops/combine_attrs.struct.toml @@ -0,0 +1,23 @@ +namespace = "FlexFlow" +name = "CombineAttrs" +features = [ + "eq", + "ord", + "hash", + "json", + "rapidcheck", + "fmt", +] + +includes = [ + "op-attrs/ff_dim.h", + "op-attrs/ff_dim.dtg.h", +] + +[[fields]] +name = "combine_dim" +type = "::FlexFlow::ff_dim_t" + +[[fields]] +name = "combine_degree" +type = "int" diff --git a/lib/op-attrs/include/op-attrs/ops/concat.h b/lib/op-attrs/include/op-attrs/ops/concat.h index 78f848f18b..8a72708971 100644 --- a/lib/op-attrs/include/op-attrs/ops/concat.h +++ b/lib/op-attrs/include/op-attrs/ops/concat.h @@ -2,17 +2,10 @@ #define _FLEXFLOW_CONCAT_ATTRS_H #include "core.h" -#include "op-attrs/ff_dim.h" -#include "op-attrs/parallel_tensor_shape.h" -#include "utils/visitable.h" +#include "op-attrs/ops/concat_attrs.dtg.h" namespace FlexFlow { -struct ConcatAttrs { - ff_dim_t axis; - req num_inputs; -}; -FF_VISITABLE_STRUCT(ConcatAttrs, axis, num_inputs); CHECK_VALID_OP_ATTR(ConcatAttrs); } // namespace FlexFlow diff --git a/lib/op-attrs/include/op-attrs/ops/concat_attrs.dtg.h b/lib/op-attrs/include/op-attrs/ops/concat_attrs.dtg.h new file mode 100644 index 0000000000..3c26473a4e --- /dev/null +++ 
b/lib/op-attrs/include/op-attrs/ops/concat_attrs.dtg.h @@ -0,0 +1,65 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/op-attrs/include/op-attrs/ops/concat_attrs.struct.toml +/* proj-data +{ + "generated_from": "68e0520b143e0579140a2f2cdd390759" +} +*/ + +#ifndef _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_CONCAT_ATTRS_DTG_H +#define _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_CONCAT_ATTRS_DTG_H + +#include "fmt/format.h" +#include "nlohmann/json.hpp" +#include "op-attrs/ff_dim.dtg.h" +#include "op-attrs/ff_dim.h" +#include "rapidcheck.h" +#include +#include +#include + +namespace FlexFlow { +struct ConcatAttrs { + ConcatAttrs() = delete; + ConcatAttrs(::FlexFlow::ff_dim_t const &axis, int const &num_inputs); + + bool operator==(ConcatAttrs const &) const; + bool operator!=(ConcatAttrs const &) const; + bool operator<(ConcatAttrs const &) const; + bool operator>(ConcatAttrs const &) const; + bool operator<=(ConcatAttrs const &) const; + bool operator>=(ConcatAttrs const &) const; + ::FlexFlow::ff_dim_t axis; + int num_inputs; +}; +} // namespace FlexFlow + +namespace std { +template <> +struct hash { + size_t operator()(FlexFlow::ConcatAttrs const &) const; +}; +} // namespace std + +namespace nlohmann { +template <> +struct adl_serializer { + static FlexFlow::ConcatAttrs from_json(json const &); + static void to_json(json &, FlexFlow::ConcatAttrs const &); +}; +} // namespace nlohmann + +namespace rc { +template <> +struct Arbitrary { + static Gen arbitrary(); +}; +} // namespace rc + +namespace FlexFlow { +std::string format_as(ConcatAttrs const &); +std::ostream &operator<<(std::ostream &, ConcatAttrs const &); +} // namespace FlexFlow + +#endif // _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_CONCAT_ATTRS_DTG_H diff --git a/lib/op-attrs/include/op-attrs/ops/concat_attrs.struct.toml b/lib/op-attrs/include/op-attrs/ops/concat_attrs.struct.toml new file mode 100644 index 0000000000..4faa870bc4 --- /dev/null +++ b/lib/op-attrs/include/op-attrs/ops/concat_attrs.struct.toml @@ -0,0 +1,23 @@ +namespace = "FlexFlow" +name = "ConcatAttrs" +features = [ + "eq", + "ord", + "hash", + "json", + "rapidcheck", + "fmt", +] + +includes = [ + "op-attrs/ff_dim.h", + "op-attrs/ff_dim.dtg.h" +] + +[[fields]] +name = "axis" +type = "::FlexFlow::ff_dim_t" + +[[fields]] +name = "num_inputs" +type = "int" diff --git a/lib/op-attrs/include/op-attrs/ops/conv_2d.h b/lib/op-attrs/include/op-attrs/ops/conv_2d.h index 79980d545d..7759380088 100644 --- a/lib/op-attrs/include/op-attrs/ops/conv_2d.h +++ b/lib/op-attrs/include/op-attrs/ops/conv_2d.h @@ -2,35 +2,26 @@ #define _FLEXFLOW_CONV_2D_ATTRS_H #include "core.h" -#include "op-attrs/activation.h" +#include "op-attrs/ops/conv_2d_attrs.dtg.h" #include "op-attrs/parallel_tensor_shape.h" #include "op-attrs/tensor_shape.h" -#include "utils/visitable.h" namespace FlexFlow { -struct Conv2DAttrs { - int out_channels, kernel_h, kernel_w, stride_h, stride_w, padding_h, - padding_w, groups; - std::optional activation; - req use_bias; -}; - -FF_VISITABLE_STRUCT(Conv2DAttrs, - out_channels, - kernel_h, - kernel_w, - stride_h, - stride_w, - padding_h, - padding_w, - groups, - activation, - use_bias); CHECK_VALID_OP_ATTR(Conv2DAttrs); -TensorShape get_kernel_shape(Conv2DAttrs const &, TensorShape const &); -TensorShape get_bias_shape(Conv2DAttrs const &, TensorShape const &); +TensorShape get_kernel_shape(Conv2DAttrs const &attrs, + TensorShape const &input); +TensorShape get_bias_shape(Conv2DAttrs 
const &attrs, TensorShape const &input); +TensorShape get_output_shape(Conv2DAttrs const &attrs, + TensorShape const &input); + +ParallelTensorShape get_kernel_shape(Conv2DAttrs const &attrs, + ParallelTensorShape const &input_shape); +ParallelTensorShape get_bias_shape(Conv2DAttrs const &attrs, + ParallelTensorShape const &input_shape); +ParallelTensorShape get_output_shape(Conv2DAttrs const &attrs, + ParallelTensorShape const &input_shape); } // namespace FlexFlow diff --git a/lib/op-attrs/include/op-attrs/ops/conv_2d/conv_2d_input_shape.dtg.h b/lib/op-attrs/include/op-attrs/ops/conv_2d/conv_2d_input_shape.dtg.h new file mode 100644 index 0000000000..2e7833064c --- /dev/null +++ b/lib/op-attrs/include/op-attrs/ops/conv_2d/conv_2d_input_shape.dtg.h @@ -0,0 +1,72 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/op-attrs/include/op-attrs/ops/conv_2d/conv_2d_input_shape.struct.toml +/* proj-data +{ + "generated_from": "51911f58c134d55b2d0245444acbae53" +} +*/ + +#ifndef _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_CONV_2D_CONV_2D_INPUT_SHAPE_DTG_H +#define _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_CONV_2D_CONV_2D_INPUT_SHAPE_DTG_H + +#include "fmt/format.h" +#include "nlohmann/json.hpp" +#include "op-attrs/datatype.dtg.h" +#include "rapidcheck.h" +#include +#include +#include +#include + +namespace FlexFlow { +struct Conv2DInputShape { + Conv2DInputShape() = delete; + Conv2DInputShape(size_t const &num_samples, + size_t const &num_channels, + size_t const &height, + size_t const &width, + ::FlexFlow::DataType const &datatype); + + bool operator==(Conv2DInputShape const &) const; + bool operator!=(Conv2DInputShape const &) const; + bool operator<(Conv2DInputShape const &) const; + bool operator>(Conv2DInputShape const &) const; + bool operator<=(Conv2DInputShape const &) const; + bool operator>=(Conv2DInputShape const &) const; + size_t num_samples; + size_t num_channels; + size_t height; + size_t width; + ::FlexFlow::DataType datatype; +}; +} // namespace FlexFlow + +namespace std { +template <> +struct hash { + size_t operator()(FlexFlow::Conv2DInputShape const &) const; +}; +} // namespace std + +namespace nlohmann { +template <> +struct adl_serializer { + static FlexFlow::Conv2DInputShape from_json(json const &); + static void to_json(json &, FlexFlow::Conv2DInputShape const &); +}; +} // namespace nlohmann + +namespace rc { +template <> +struct Arbitrary { + static Gen arbitrary(); +}; +} // namespace rc + +namespace FlexFlow { +std::string format_as(Conv2DInputShape const &); +std::ostream &operator<<(std::ostream &, Conv2DInputShape const &); +} // namespace FlexFlow + +#endif // _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_CONV_2D_CONV_2D_INPUT_SHAPE_DTG_H diff --git a/lib/op-attrs/include/op-attrs/ops/conv_2d/conv_2d_input_shape.h b/lib/op-attrs/include/op-attrs/ops/conv_2d/conv_2d_input_shape.h new file mode 100644 index 0000000000..043f5854ae --- /dev/null +++ b/lib/op-attrs/include/op-attrs/ops/conv_2d/conv_2d_input_shape.h @@ -0,0 +1,13 @@ +#ifndef _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_CONV_2D_CONV_2D_INPUT_SHAPE_H +#define _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_CONV_2D_CONV_2D_INPUT_SHAPE_H + +#include "op-attrs/ops/conv_2d/conv_2d_input_shape.dtg.h" +#include "op-attrs/tensor_shape.dtg.h" + +namespace FlexFlow { + +Conv2DInputShape parse_input_shape(TensorShape const &input); + +} // namespace FlexFlow + +#endif diff --git 
a/lib/op-attrs/include/op-attrs/ops/conv_2d/conv_2d_input_shape.struct.toml b/lib/op-attrs/include/op-attrs/ops/conv_2d/conv_2d_input_shape.struct.toml new file mode 100644 index 0000000000..77e8c51244 --- /dev/null +++ b/lib/op-attrs/include/op-attrs/ops/conv_2d/conv_2d_input_shape.struct.toml @@ -0,0 +1,35 @@ +namespace = "FlexFlow" +name = "Conv2DInputShape" +features = [ + "eq", + "ord", + "hash", + "fmt", + "rapidcheck", + "json", +] + +includes = [ + "", + "op-attrs/datatype.dtg.h", +] + +[[fields]] +name = "num_samples" +type = "size_t" + +[[fields]] +name = "num_channels" +type = "size_t" + +[[fields]] +name = "height" +type = "size_t" + +[[fields]] +name = "width" +type = "size_t" + +[[fields]] +name = "datatype" +type = "::FlexFlow::DataType" diff --git a/lib/op-attrs/include/op-attrs/ops/conv_2d/conv_2d_parallel_input_shape.dtg.h b/lib/op-attrs/include/op-attrs/ops/conv_2d/conv_2d_parallel_input_shape.dtg.h new file mode 100644 index 0000000000..846c9e413a --- /dev/null +++ b/lib/op-attrs/include/op-attrs/ops/conv_2d/conv_2d_parallel_input_shape.dtg.h @@ -0,0 +1,76 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/op-attrs/include/op-attrs/ops/conv_2d/conv_2d_parallel_input_shape.struct.toml +/* proj-data +{ + "generated_from": "d80394bdc90f843372760310b6d17a22" +} +*/ + +#ifndef _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_CONV_2D_CONV_2D_PARALLEL_INPUT_SHAPE_DTG_H +#define _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_CONV_2D_CONV_2D_PARALLEL_INPUT_SHAPE_DTG_H + +#include "fmt/format.h" +#include "nlohmann/json.hpp" +#include "op-attrs/datatype.dtg.h" +#include "op-attrs/shard_parallel_dim.dtg.h" +#include "rapidcheck.h" +#include +#include +#include + +namespace FlexFlow { +struct Conv2DParallelInputShape { + Conv2DParallelInputShape() = delete; + Conv2DParallelInputShape(::FlexFlow::ShardParallelDim const &sample_dim, + ::FlexFlow::ShardParallelDim const &channel_dim, + ::FlexFlow::ShardParallelDim const &height_dim, + ::FlexFlow::ShardParallelDim const &width_dim, + int const &sum_reduction_degree, + int const &discard_copy_reduction_degree, + ::FlexFlow::DataType const &datatype); + + bool operator==(Conv2DParallelInputShape const &) const; + bool operator!=(Conv2DParallelInputShape const &) const; + bool operator<(Conv2DParallelInputShape const &) const; + bool operator>(Conv2DParallelInputShape const &) const; + bool operator<=(Conv2DParallelInputShape const &) const; + bool operator>=(Conv2DParallelInputShape const &) const; + ::FlexFlow::ShardParallelDim sample_dim; + ::FlexFlow::ShardParallelDim channel_dim; + ::FlexFlow::ShardParallelDim height_dim; + ::FlexFlow::ShardParallelDim width_dim; + int sum_reduction_degree; + int discard_copy_reduction_degree; + ::FlexFlow::DataType datatype; +}; +} // namespace FlexFlow + +namespace std { +template <> +struct hash { + size_t operator()(FlexFlow::Conv2DParallelInputShape const &) const; +}; +} // namespace std + +namespace nlohmann { +template <> +struct adl_serializer { + static FlexFlow::Conv2DParallelInputShape from_json(json const &); + static void to_json(json &, FlexFlow::Conv2DParallelInputShape const &); +}; +} // namespace nlohmann + +namespace rc { +template <> +struct Arbitrary { + static Gen arbitrary(); +}; +} // namespace rc + +namespace FlexFlow { +std::string format_as(Conv2DParallelInputShape const &); +std::ostream &operator<<(std::ostream &, Conv2DParallelInputShape const &); +} // namespace FlexFlow + +#endif // 
_FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_CONV_2D_CONV_2D_PARALLEL_INPUT_SHAPE_DTG_H diff --git a/lib/op-attrs/include/op-attrs/ops/conv_2d/conv_2d_parallel_input_shape.h b/lib/op-attrs/include/op-attrs/ops/conv_2d/conv_2d_parallel_input_shape.h new file mode 100644 index 0000000000..accc64e751 --- /dev/null +++ b/lib/op-attrs/include/op-attrs/ops/conv_2d/conv_2d_parallel_input_shape.h @@ -0,0 +1,14 @@ +#ifndef _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_CONV_2D_CONV_2D_PARALLEL_INPUT_SHAPE_H +#define _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_CONV_2D_CONV_2D_PARALLEL_INPUT_SHAPE_H + +#include "op-attrs/ops/conv_2d/conv_2d_parallel_input_shape.dtg.h" +#include "op-attrs/parallel_tensor_shape.dtg.h" + +namespace FlexFlow { + +Conv2DParallelInputShape + parse_parallel_input_shape(ParallelTensorShape const &input); + +} // namespace FlexFlow + +#endif diff --git a/lib/op-attrs/include/op-attrs/ops/conv_2d/conv_2d_parallel_input_shape.struct.toml b/lib/op-attrs/include/op-attrs/ops/conv_2d/conv_2d_parallel_input_shape.struct.toml new file mode 100644 index 0000000000..68cbd878d1 --- /dev/null +++ b/lib/op-attrs/include/op-attrs/ops/conv_2d/conv_2d_parallel_input_shape.struct.toml @@ -0,0 +1,43 @@ +namespace = "FlexFlow" +name = "Conv2DParallelInputShape" +features = [ + "eq", + "ord", + "hash", + "json", + "rapidcheck", + "fmt", +] + +includes = [ + "op-attrs/shard_parallel_dim.dtg.h", + "op-attrs/datatype.dtg.h", +] + +[[fields]] +name = "sample_dim" +type = "::FlexFlow::ShardParallelDim" + +[[fields]] +name = "channel_dim" +type = "::FlexFlow::ShardParallelDim" + +[[fields]] +name = "height_dim" +type = "::FlexFlow::ShardParallelDim" + +[[fields]] +name = "width_dim" +type = "::FlexFlow::ShardParallelDim" + +[[fields]] +name = "sum_reduction_degree" +type = "int" + +[[fields]] +name = "discard_copy_reduction_degree" +type = "int" + +[[fields]] +name = "datatype" +type = "::FlexFlow::DataType" diff --git a/lib/op-attrs/include/op-attrs/ops/conv_2d_attrs.dtg.h b/lib/op-attrs/include/op-attrs/ops/conv_2d_attrs.dtg.h new file mode 100644 index 0000000000..06827656da --- /dev/null +++ b/lib/op-attrs/include/op-attrs/ops/conv_2d_attrs.dtg.h @@ -0,0 +1,83 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! 
+// If you would like to modify this datatype, instead modify +// lib/op-attrs/include/op-attrs/ops/conv_2d_attrs.struct.toml +/* proj-data +{ + "generated_from": "74f98e1aacb57d847bb450e1d28d3e67" +} +*/ + +#ifndef _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_CONV_2D_ATTRS_DTG_H +#define _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_CONV_2D_ATTRS_DTG_H + +#include "fmt/format.h" +#include "nlohmann/json.hpp" +#include "op-attrs/activation.dtg.h" +#include "rapidcheck.h" +#include "utils/json.h" +#include +#include +#include +#include + +namespace FlexFlow { +struct Conv2DAttrs { + Conv2DAttrs() = delete; + Conv2DAttrs(int const &out_channels, + int const &kernel_h, + int const &kernel_w, + int const &stride_h, + int const &stride_w, + int const &padding_h, + int const &padding_w, + int const &groups, + std::optional<::FlexFlow::Activation> const &activation, + bool const &use_bias); + + bool operator==(Conv2DAttrs const &) const; + bool operator!=(Conv2DAttrs const &) const; + bool operator<(Conv2DAttrs const &) const; + bool operator>(Conv2DAttrs const &) const; + bool operator<=(Conv2DAttrs const &) const; + bool operator>=(Conv2DAttrs const &) const; + int out_channels; + int kernel_h; + int kernel_w; + int stride_h; + int stride_w; + int padding_h; + int padding_w; + int groups; + std::optional<::FlexFlow::Activation> activation; + bool use_bias; +}; +} // namespace FlexFlow + +namespace std { +template <> +struct hash { + size_t operator()(FlexFlow::Conv2DAttrs const &) const; +}; +} // namespace std + +namespace nlohmann { +template <> +struct adl_serializer { + static FlexFlow::Conv2DAttrs from_json(json const &); + static void to_json(json &, FlexFlow::Conv2DAttrs const &); +}; +} // namespace nlohmann + +namespace rc { +template <> +struct Arbitrary { + static Gen arbitrary(); +}; +} // namespace rc + +namespace FlexFlow { +std::string format_as(Conv2DAttrs const &); +std::ostream &operator<<(std::ostream &, Conv2DAttrs const &); +} // namespace FlexFlow + +#endif // _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_CONV_2D_ATTRS_DTG_H diff --git a/lib/op-attrs/include/op-attrs/ops/conv_2d_attrs.struct.toml b/lib/op-attrs/include/op-attrs/ops/conv_2d_attrs.struct.toml new file mode 100644 index 0000000000..353ef93004 --- /dev/null +++ b/lib/op-attrs/include/op-attrs/ops/conv_2d_attrs.struct.toml @@ -0,0 +1,29 @@ +namespace = "FlexFlow" +name = "Conv2DAttrs" +features = [ + "eq", + "ord", + "hash", + "json", + "rapidcheck", + "fmt", +] + +includes = [ + "", + "op-attrs/activation.dtg.h", + "utils/json.h", +] + +fields = [ + { name = "out_channels", type = "int" }, + { name = "kernel_h", type = "int" }, + { name = "kernel_w", type = "int" }, + { name = "stride_h", type = "int" }, + { name = "stride_w", type = "int" }, + { name = "padding_h", type = "int" }, + { name = "padding_w", type = "int" }, + { name = "groups", type = "int" }, + { name = "activation", type = "std::optional<::FlexFlow::Activation>" }, + { name = "use_bias", type = "bool" }, +] diff --git a/lib/op-attrs/include/op-attrs/ops/dropout.h b/lib/op-attrs/include/op-attrs/ops/dropout.h index 8e0049f526..a0493301c4 100644 --- a/lib/op-attrs/include/op-attrs/ops/dropout.h +++ b/lib/op-attrs/include/op-attrs/ops/dropout.h @@ -2,16 +2,14 @@ #define _FLEXFLOW_DROPOUT_ATTRS_H #include "core.h" -#include "op-attrs/parallel_tensor_shape.h" -#include "utils/visitable.h" +#include "op-attrs/ops/dropout_attrs.dtg.h" +#include "op-attrs/parallel_tensor_shape.dtg.h" namespace FlexFlow { -struct DropoutAttrs { - req rate; - req seed; -}; 
-FF_VISITABLE_STRUCT(DropoutAttrs, rate, seed); +ParallelTensorShape get_output_shape(DropoutAttrs const &, + ParallelTensorShape const &); + CHECK_VALID_OP_ATTR(DropoutAttrs); } // namespace FlexFlow diff --git a/lib/op-attrs/include/op-attrs/ops/dropout_attrs.dtg.h b/lib/op-attrs/include/op-attrs/ops/dropout_attrs.dtg.h new file mode 100644 index 0000000000..ef86e49560 --- /dev/null +++ b/lib/op-attrs/include/op-attrs/ops/dropout_attrs.dtg.h @@ -0,0 +1,63 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/op-attrs/include/op-attrs/ops/dropout_attrs.struct.toml +/* proj-data +{ + "generated_from": "4fdbf129ea59b8a7306813cfa4c46021" +} +*/ + +#ifndef _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_DROPOUT_ATTRS_DTG_H +#define _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_DROPOUT_ATTRS_DTG_H + +#include "fmt/format.h" +#include "nlohmann/json.hpp" +#include "rapidcheck.h" +#include +#include +#include + +namespace FlexFlow { +struct DropoutAttrs { + DropoutAttrs() = delete; + DropoutAttrs(float const &rate, unsigned long long const &seed); + + bool operator==(DropoutAttrs const &) const; + bool operator!=(DropoutAttrs const &) const; + bool operator<(DropoutAttrs const &) const; + bool operator>(DropoutAttrs const &) const; + bool operator<=(DropoutAttrs const &) const; + bool operator>=(DropoutAttrs const &) const; + float rate; + unsigned long long seed; +}; +} // namespace FlexFlow + +namespace std { +template <> +struct hash { + size_t operator()(FlexFlow::DropoutAttrs const &) const; +}; +} // namespace std + +namespace nlohmann { +template <> +struct adl_serializer { + static FlexFlow::DropoutAttrs from_json(json const &); + static void to_json(json &, FlexFlow::DropoutAttrs const &); +}; +} // namespace nlohmann + +namespace rc { +template <> +struct Arbitrary { + static Gen arbitrary(); +}; +} // namespace rc + +namespace FlexFlow { +std::string format_as(DropoutAttrs const &); +std::ostream &operator<<(std::ostream &, DropoutAttrs const &); +} // namespace FlexFlow + +#endif // _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_DROPOUT_ATTRS_DTG_H diff --git a/lib/op-attrs/include/op-attrs/ops/dropout_attrs.struct.toml b/lib/op-attrs/include/op-attrs/ops/dropout_attrs.struct.toml new file mode 100644 index 0000000000..8731e0780b --- /dev/null +++ b/lib/op-attrs/include/op-attrs/ops/dropout_attrs.struct.toml @@ -0,0 +1,19 @@ +namespace = "FlexFlow" +name = "DropoutAttrs" + +features = [ + "eq", + "ord", + "hash", + "json", + "rapidcheck", + "fmt", +] + +[[fields]] +name = "rate" +type = "float" + +[[fields]] +name = "seed" +type = "unsigned long long" diff --git a/lib/op-attrs/include/op-attrs/ops/element_binary.h b/lib/op-attrs/include/op-attrs/ops/element_binary.h index c4a096166d..d51c3a3afa 100644 --- a/lib/op-attrs/include/op-attrs/ops/element_binary.h +++ b/lib/op-attrs/include/op-attrs/ops/element_binary.h @@ -1,25 +1,20 @@ #ifndef _FLEXFLOW_ELEMENT_BINARY_ATTRS_H #define _FLEXFLOW_ELEMENT_BINARY_ATTRS_H -#include "core.h" -#include "op-attrs/datatype.h" -#include "op-attrs/op.h" +#include "op-attrs/ops/core.h" +#include "op-attrs/ops/element_binary_attrs.dtg.h" #include "op-attrs/parallel_tensor_shape.h" -#include "utils/visitable.h" +#include namespace FlexFlow { -struct ElementBinaryAttrs { - req type; - req compute_type; - req should_broadcast_lhs; - req should_broadcast_rhs; -}; -FF_VISITABLE_STRUCT(ElementBinaryAttrs, - type, - compute_type, - should_broadcast_lhs, - should_broadcast_rhs); 
+tl::expected get_output_shape( + ElementBinaryAttrs const &, TensorShape const &, TensorShape const &); +tl::expected + get_output_shape(ElementBinaryAttrs const &, + ParallelTensorShape const &, + ParallelTensorShape const &); + CHECK_VALID_OP_ATTR(ElementBinaryAttrs); } // namespace FlexFlow diff --git a/lib/op-attrs/include/op-attrs/ops/element_binary_attrs.dtg.h b/lib/op-attrs/include/op-attrs/ops/element_binary_attrs.dtg.h new file mode 100644 index 0000000000..10d93c87d3 --- /dev/null +++ b/lib/op-attrs/include/op-attrs/ops/element_binary_attrs.dtg.h @@ -0,0 +1,70 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/op-attrs/include/op-attrs/ops/element_binary_attrs.struct.toml +/* proj-data +{ + "generated_from": "2bb947c9cc92e3833ee88c908c539629" +} +*/ + +#ifndef _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_ELEMENT_BINARY_ATTRS_DTG_H +#define _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_ELEMENT_BINARY_ATTRS_DTG_H + +#include "fmt/format.h" +#include "nlohmann/json.hpp" +#include "op-attrs/datatype.h" +#include "op-attrs/operator_type.h" +#include "rapidcheck.h" +#include +#include +#include + +namespace FlexFlow { +struct ElementBinaryAttrs { + ElementBinaryAttrs() = delete; + ElementBinaryAttrs(::FlexFlow::OperatorType const &type, + ::FlexFlow::DataType const &compute_type, + bool const &should_broadcast_lhs, + bool const &should_broadcast_rhs); + + bool operator==(ElementBinaryAttrs const &) const; + bool operator!=(ElementBinaryAttrs const &) const; + bool operator<(ElementBinaryAttrs const &) const; + bool operator>(ElementBinaryAttrs const &) const; + bool operator<=(ElementBinaryAttrs const &) const; + bool operator>=(ElementBinaryAttrs const &) const; + ::FlexFlow::OperatorType type; + ::FlexFlow::DataType compute_type; + bool should_broadcast_lhs; + bool should_broadcast_rhs; +}; +} // namespace FlexFlow + +namespace std { +template <> +struct hash { + size_t operator()(FlexFlow::ElementBinaryAttrs const &) const; +}; +} // namespace std + +namespace nlohmann { +template <> +struct adl_serializer { + static FlexFlow::ElementBinaryAttrs from_json(json const &); + static void to_json(json &, FlexFlow::ElementBinaryAttrs const &); +}; +} // namespace nlohmann + +namespace rc { +template <> +struct Arbitrary { + static Gen arbitrary(); +}; +} // namespace rc + +namespace FlexFlow { +std::string format_as(ElementBinaryAttrs const &); +std::ostream &operator<<(std::ostream &, ElementBinaryAttrs const &); +} // namespace FlexFlow + +#endif // _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_ELEMENT_BINARY_ATTRS_DTG_H diff --git a/lib/op-attrs/include/op-attrs/ops/element_binary_attrs.struct.toml b/lib/op-attrs/include/op-attrs/ops/element_binary_attrs.struct.toml new file mode 100644 index 0000000000..d167c67aed --- /dev/null +++ b/lib/op-attrs/include/op-attrs/ops/element_binary_attrs.struct.toml @@ -0,0 +1,32 @@ +namespace = "FlexFlow" +name = "ElementBinaryAttrs" + +features = [ + "eq", + "ord", + "hash", + "json", + "rapidcheck", + "fmt", +] + +includes = [ + "op-attrs/operator_type.h", + "op-attrs/datatype.h", +] + +[[fields]] +name = "type" +type = "::FlexFlow::OperatorType" + +[[fields]] +name = "compute_type" +type = "::FlexFlow::DataType" + +[[fields]] +name = "should_broadcast_lhs" +type = "bool" + +[[fields]] +name = "should_broadcast_rhs" +type = "bool" diff --git a/lib/op-attrs/include/op-attrs/ops/element_scalar_unary_attrs.dtg.h 
b/lib/op-attrs/include/op-attrs/ops/element_scalar_unary_attrs.dtg.h new file mode 100644 index 0000000000..a9fe63ca71 --- /dev/null +++ b/lib/op-attrs/include/op-attrs/ops/element_scalar_unary_attrs.dtg.h @@ -0,0 +1,65 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/op-attrs/include/op-attrs/ops/element_scalar_unary_attrs.struct.toml +/* proj-data +{ + "generated_from": "aa6f98b992d46bdf7ad59158bc143a3f" +} +*/ + +#ifndef _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_ELEMENT_SCALAR_UNARY_ATTRS_DTG_H +#define _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_ELEMENT_SCALAR_UNARY_ATTRS_DTG_H + +#include "fmt/format.h" +#include "nlohmann/json.hpp" +#include "op-attrs/operator_type.h" +#include "rapidcheck.h" +#include +#include +#include + +namespace FlexFlow { +struct ElementScalarUnaryAttrs { + ElementScalarUnaryAttrs() = delete; + ElementScalarUnaryAttrs(::FlexFlow::OperatorType const &op_type, + float const &scalar); + + bool operator==(ElementScalarUnaryAttrs const &) const; + bool operator!=(ElementScalarUnaryAttrs const &) const; + bool operator<(ElementScalarUnaryAttrs const &) const; + bool operator>(ElementScalarUnaryAttrs const &) const; + bool operator<=(ElementScalarUnaryAttrs const &) const; + bool operator>=(ElementScalarUnaryAttrs const &) const; + ::FlexFlow::OperatorType op_type; + float scalar; +}; +} // namespace FlexFlow + +namespace std { +template <> +struct hash { + size_t operator()(FlexFlow::ElementScalarUnaryAttrs const &) const; +}; +} // namespace std + +namespace nlohmann { +template <> +struct adl_serializer { + static FlexFlow::ElementScalarUnaryAttrs from_json(json const &); + static void to_json(json &, FlexFlow::ElementScalarUnaryAttrs const &); +}; +} // namespace nlohmann + +namespace rc { +template <> +struct Arbitrary { + static Gen arbitrary(); +}; +} // namespace rc + +namespace FlexFlow { +std::string format_as(ElementScalarUnaryAttrs const &); +std::ostream &operator<<(std::ostream &, ElementScalarUnaryAttrs const &); +} // namespace FlexFlow + +#endif // _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_ELEMENT_SCALAR_UNARY_ATTRS_DTG_H diff --git a/lib/op-attrs/include/op-attrs/ops/element_scalar_unary_attrs.struct.toml b/lib/op-attrs/include/op-attrs/ops/element_scalar_unary_attrs.struct.toml new file mode 100644 index 0000000000..609805ab98 --- /dev/null +++ b/lib/op-attrs/include/op-attrs/ops/element_scalar_unary_attrs.struct.toml @@ -0,0 +1,23 @@ +namespace = "FlexFlow" +name = "ElementScalarUnaryAttrs" + +features = [ + "eq", + "ord", + "hash", + "json", + "rapidcheck", + "fmt", +] + +includes = [ + "op-attrs/operator_type.h" +] + +[[fields]] +name = "op_type" +type = "::FlexFlow::OperatorType" + +[[fields]] +name = "scalar" +type = "float" diff --git a/lib/op-attrs/include/op-attrs/ops/element_unary.h b/lib/op-attrs/include/op-attrs/ops/element_unary.h index 5e19b81c8c..471a2a30f5 100644 --- a/lib/op-attrs/include/op-attrs/ops/element_unary.h +++ b/lib/op-attrs/include/op-attrs/ops/element_unary.h @@ -1,26 +1,32 @@ #ifndef _FLEXFLOW_ELEMENTARY_UNARY_ATTRS_H #define _FLEXFLOW_ELEMENTARY_UNARY_ATTRS_H -#include "core.h" -#include "op-attrs/op.h" -#include "op-attrs/parallel_tensor_shape.h" -#include "utils/visitable.h" +#include "op-attrs/ops/core.h" +#include "op-attrs/ops/element_scalar_unary_attrs.dtg.h" +#include "op-attrs/ops/element_unary_attrs.dtg.h" +#include "op-attrs/parallel_tensor_shape.dtg.h" +#include "op-attrs/tensor_shape.dtg.h" +#include namespace FlexFlow 
{ -struct ElementUnaryAttrs { - req op_type; -}; -FF_VISITABLE_STRUCT(ElementUnaryAttrs, op_type); -CHECK_VALID_OP_ATTR(ElementUnaryAttrs); +tl::expected + get_output_shape(ElementUnaryAttrs const &, TensorShape const &); +tl::expected + get_output_shape(ElementUnaryAttrs const &, ParallelTensorShape const &); + +tl::expected + get_output_shape(ElementScalarUnaryAttrs const &, TensorShape const &); +tl::expected + get_output_shape(ElementScalarUnaryAttrs const &, + ParallelTensorShape const &); -struct ElementScalarUnaryAttrs { - Op op_type; - req scalar; -}; -FF_VISITABLE_STRUCT(ElementScalarUnaryAttrs, op_type, scalar); +CHECK_VALID_OP_ATTR(ElementUnaryAttrs); CHECK_VALID_OP_ATTR(ElementScalarUnaryAttrs); +using ElementUnaryUnifiedAttrs = + std::variant; + } // namespace FlexFlow #endif diff --git a/lib/op-attrs/include/op-attrs/ops/element_unary_attrs.dtg.h b/lib/op-attrs/include/op-attrs/ops/element_unary_attrs.dtg.h new file mode 100644 index 0000000000..3220234bd1 --- /dev/null +++ b/lib/op-attrs/include/op-attrs/ops/element_unary_attrs.dtg.h @@ -0,0 +1,63 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/op-attrs/include/op-attrs/ops/element_unary_attrs.struct.toml +/* proj-data +{ + "generated_from": "75272cff78d3db866122dbb1001aedbe" +} +*/ + +#ifndef _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_ELEMENT_UNARY_ATTRS_DTG_H +#define _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_ELEMENT_UNARY_ATTRS_DTG_H + +#include "fmt/format.h" +#include "nlohmann/json.hpp" +#include "op-attrs/operator_type.h" +#include "rapidcheck.h" +#include +#include +#include + +namespace FlexFlow { +struct ElementUnaryAttrs { + ElementUnaryAttrs() = delete; + ElementUnaryAttrs(::FlexFlow::OperatorType const &op_type); + + bool operator==(ElementUnaryAttrs const &) const; + bool operator!=(ElementUnaryAttrs const &) const; + bool operator<(ElementUnaryAttrs const &) const; + bool operator>(ElementUnaryAttrs const &) const; + bool operator<=(ElementUnaryAttrs const &) const; + bool operator>=(ElementUnaryAttrs const &) const; + ::FlexFlow::OperatorType op_type; +}; +} // namespace FlexFlow + +namespace std { +template <> +struct hash { + size_t operator()(FlexFlow::ElementUnaryAttrs const &) const; +}; +} // namespace std + +namespace nlohmann { +template <> +struct adl_serializer { + static FlexFlow::ElementUnaryAttrs from_json(json const &); + static void to_json(json &, FlexFlow::ElementUnaryAttrs const &); +}; +} // namespace nlohmann + +namespace rc { +template <> +struct Arbitrary { + static Gen arbitrary(); +}; +} // namespace rc + +namespace FlexFlow { +std::string format_as(ElementUnaryAttrs const &); +std::ostream &operator<<(std::ostream &, ElementUnaryAttrs const &); +} // namespace FlexFlow + +#endif // _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_ELEMENT_UNARY_ATTRS_DTG_H diff --git a/lib/op-attrs/include/op-attrs/ops/element_unary_attrs.struct.toml b/lib/op-attrs/include/op-attrs/ops/element_unary_attrs.struct.toml new file mode 100644 index 0000000000..b0e23aa5c7 --- /dev/null +++ b/lib/op-attrs/include/op-attrs/ops/element_unary_attrs.struct.toml @@ -0,0 +1,19 @@ +namespace = "FlexFlow" +name = "ElementUnaryAttrs" + +features = [ + "eq", + "ord", + "hash", + "json", + "rapidcheck", + "fmt", +] + +includes = [ + "op-attrs/operator_type.h" +] + +[[fields]] +name = "op_type" +type = "::FlexFlow::OperatorType" diff --git a/lib/op-attrs/include/op-attrs/ops/embedding.h b/lib/op-attrs/include/op-attrs/ops/embedding.h index 
733a6523da..aa67c6cb04 100644 --- a/lib/op-attrs/include/op-attrs/ops/embedding.h +++ b/lib/op-attrs/include/op-attrs/ops/embedding.h @@ -1,51 +1,26 @@ #ifndef _FLEXFLOW_EMBEDDING_ATTRS_H #define _FLEXFLOW_EMBEDDING_ATTRS_H -#include "core.h" -#include "op-attrs/datatype.h" +#include "op-attrs/ops/core.h" +#include "op-attrs/ops/embedding_attrs.dtg.h" #include "op-attrs/parallel_tensor_shape.h" #include "op-attrs/tensor_shape.h" -#include "utils/fmt.h" -#include "utils/visitable.h" +#include namespace FlexFlow { -enum class AggregateOp { SUM, AVG }; - -struct EmbeddingAttrs { - int num_entries, out_channels; - std::optional aggr; - req data_type; -}; -FF_VISITABLE_STRUCT(EmbeddingAttrs, num_entries, out_channels, aggr, data_type); CHECK_VALID_OP_ATTR(EmbeddingAttrs); -TensorShape get_weights_shape(EmbeddingAttrs const &, TensorShape const &); - -} // namespace FlexFlow - -namespace fmt { - -template <> -struct formatter<::FlexFlow::AggregateOp> : formatter { - template - auto format(::FlexFlow::AggregateOp o, FormatContext &ctx) const - -> decltype(ctx.out()) { - using namespace FlexFlow; +tl::expected get_output_shape(EmbeddingAttrs const &, + TensorShape const &); +tl::expected get_weights_shape(EmbeddingAttrs const &, + TensorShape const &); - string_view name = "unknown"; - switch (o) { - case AggregateOp::SUM: - name = "Sum"; - break; - case AggregateOp::AVG: - name = "Avg"; - break; - } - return formatter::format(name, ctx); - } -}; +tl::expected + get_output_shape(EmbeddingAttrs const &, ParallelTensorShape const &); +tl::expected + get_weights_shape(EmbeddingAttrs const &, ParallelTensorShape const &); -} // namespace fmt +} // namespace FlexFlow #endif diff --git a/lib/op-attrs/include/op-attrs/ops/embedding_attrs.dtg.h b/lib/op-attrs/include/op-attrs/ops/embedding_attrs.dtg.h new file mode 100644 index 0000000000..f1cae86460 --- /dev/null +++ b/lib/op-attrs/include/op-attrs/ops/embedding_attrs.dtg.h @@ -0,0 +1,71 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! 
+// If you would like to modify this datatype, instead modify +// lib/op-attrs/include/op-attrs/ops/embedding_attrs.struct.toml +/* proj-data +{ + "generated_from": "f2bdea52e23dee6f674f598f8691d994" +} +*/ + +#ifndef _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_EMBEDDING_ATTRS_DTG_H +#define _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_EMBEDDING_ATTRS_DTG_H + +#include "fmt/format.h" +#include "nlohmann/json.hpp" +#include "op-attrs/aggregate_op.dtg.h" +#include "op-attrs/datatype.dtg.h" +#include "rapidcheck.h" +#include "utils/stack_vector.h" +#include +#include +#include + +namespace FlexFlow { +struct EmbeddingAttrs { + EmbeddingAttrs() = delete; + EmbeddingAttrs(int const &num_entries, + int const &out_channels, + std::optional<::FlexFlow::AggregateOp> const &aggr, + ::FlexFlow::DataType const &data_type); + + bool operator==(EmbeddingAttrs const &) const; + bool operator!=(EmbeddingAttrs const &) const; + bool operator<(EmbeddingAttrs const &) const; + bool operator>(EmbeddingAttrs const &) const; + bool operator<=(EmbeddingAttrs const &) const; + bool operator>=(EmbeddingAttrs const &) const; + int num_entries; + int out_channels; + std::optional<::FlexFlow::AggregateOp> aggr; + ::FlexFlow::DataType data_type; +}; +} // namespace FlexFlow + +namespace std { +template <> +struct hash { + size_t operator()(FlexFlow::EmbeddingAttrs const &) const; +}; +} // namespace std + +namespace nlohmann { +template <> +struct adl_serializer { + static FlexFlow::EmbeddingAttrs from_json(json const &); + static void to_json(json &, FlexFlow::EmbeddingAttrs const &); +}; +} // namespace nlohmann + +namespace rc { +template <> +struct Arbitrary { + static Gen arbitrary(); +}; +} // namespace rc + +namespace FlexFlow { +std::string format_as(EmbeddingAttrs const &); +std::ostream &operator<<(std::ostream &, EmbeddingAttrs const &); +} // namespace FlexFlow + +#endif // _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_EMBEDDING_ATTRS_DTG_H diff --git a/lib/op-attrs/include/op-attrs/ops/embedding_attrs.struct.toml b/lib/op-attrs/include/op-attrs/ops/embedding_attrs.struct.toml new file mode 100644 index 0000000000..f0772c351e --- /dev/null +++ b/lib/op-attrs/include/op-attrs/ops/embedding_attrs.struct.toml @@ -0,0 +1,32 @@ +namespace = "FlexFlow" +name = "EmbeddingAttrs" +features = [ + "eq", + "ord", + "hash", + "json", + "rapidcheck", + "fmt", +] + +includes = [ + "utils/stack_vector.h", + "op-attrs/aggregate_op.dtg.h", + "op-attrs/datatype.dtg.h", +] + +[[fields]] +name = "num_entries" +type = "int" + +[[fields]] +name = "out_channels" +type = "int" + +[[fields]] +name = "aggr" +type = "std::optional<::FlexFlow::AggregateOp>" + +[[fields]] +name = "data_type" +type = "::FlexFlow::DataType" diff --git a/lib/op-attrs/include/op-attrs/ops/flat.h b/lib/op-attrs/include/op-attrs/ops/flat.h index 706689199d..d5d9069f51 100644 --- a/lib/op-attrs/include/op-attrs/ops/flat.h +++ b/lib/op-attrs/include/op-attrs/ops/flat.h @@ -2,13 +2,11 @@ #define _FLEXFLOW_FLAT_ATTRS_H #include "core.h" +#include "op-attrs/ops/flat_attrs.dtg.h" #include "op-attrs/parallel_tensor_shape.h" -#include "utils/visitable.h" namespace FlexFlow { -struct FlatAttrs {}; -FF_VISITABLE_STRUCT(FlatAttrs); CHECK_VALID_OP_ATTR(FlatAttrs); } // namespace FlexFlow diff --git a/lib/op-attrs/include/op-attrs/ops/flat_attrs.dtg.h b/lib/op-attrs/include/op-attrs/ops/flat_attrs.dtg.h new file mode 100644 index 0000000000..a94c0aeff3 --- /dev/null +++ b/lib/op-attrs/include/op-attrs/ops/flat_attrs.dtg.h @@ -0,0 +1,58 @@ +// THIS FILE WAS AUTO-GENERATED BY 
proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/op-attrs/include/op-attrs/ops/flat_attrs.struct.toml +/* proj-data +{ + "generated_from": "b63924cd671481df30fae314a199c606" +} +*/ + +#ifndef _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_FLAT_ATTRS_DTG_H +#define _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_FLAT_ATTRS_DTG_H + +#include "fmt/format.h" +#include "nlohmann/json.hpp" +#include "rapidcheck.h" +#include +#include +#include + +namespace FlexFlow { +struct FlatAttrs { + bool operator==(FlatAttrs const &) const; + bool operator!=(FlatAttrs const &) const; + bool operator<(FlatAttrs const &) const; + bool operator>(FlatAttrs const &) const; + bool operator<=(FlatAttrs const &) const; + bool operator>=(FlatAttrs const &) const; +}; +} // namespace FlexFlow + +namespace std { +template <> +struct hash { + size_t operator()(FlexFlow::FlatAttrs const &) const; +}; +} // namespace std + +namespace nlohmann { +template <> +struct adl_serializer { + static FlexFlow::FlatAttrs from_json(json const &); + static void to_json(json &, FlexFlow::FlatAttrs const &); +}; +} // namespace nlohmann + +namespace rc { +template <> +struct Arbitrary { + static Gen arbitrary(); +}; +} // namespace rc + +namespace FlexFlow { +std::string format_as(FlatAttrs const &); +std::ostream &operator<<(std::ostream &, FlatAttrs const &); +} // namespace FlexFlow + +#endif // _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_FLAT_ATTRS_DTG_H diff --git a/lib/op-attrs/include/op-attrs/ops/flat_attrs.struct.toml b/lib/op-attrs/include/op-attrs/ops/flat_attrs.struct.toml new file mode 100644 index 0000000000..e445535e29 --- /dev/null +++ b/lib/op-attrs/include/op-attrs/ops/flat_attrs.struct.toml @@ -0,0 +1,11 @@ +namespace = "FlexFlow" +name = "FlatAttrs" +features = [ + "eq", + "ord", + "hash", + "json", + "rapidcheck", + "fmt", +] +fields = [] diff --git a/lib/op-attrs/include/op-attrs/ops/gather.h b/lib/op-attrs/include/op-attrs/ops/gather.h index ca2406ef75..79516a8862 100644 --- a/lib/op-attrs/include/op-attrs/ops/gather.h +++ b/lib/op-attrs/include/op-attrs/ops/gather.h @@ -2,16 +2,11 @@ #define _FLEXFLOW_GATHER_ATTRS_H #include "core.h" -#include "op-attrs/ff_dim.h" +#include "op-attrs/ops/gather_attrs.dtg.h" #include "op-attrs/parallel_tensor_shape.h" -#include "utils/visitable.h" namespace FlexFlow { -struct GatherAttrs { - ff_dim_t dim; -}; -FF_VISITABLE_STRUCT(GatherAttrs, dim); CHECK_VALID_OP_ATTR(GatherAttrs); } // namespace FlexFlow diff --git a/lib/op-attrs/include/op-attrs/ops/gather_attrs.dtg.h b/lib/op-attrs/include/op-attrs/ops/gather_attrs.dtg.h new file mode 100644 index 0000000000..e7a35e5800 --- /dev/null +++ b/lib/op-attrs/include/op-attrs/ops/gather_attrs.dtg.h @@ -0,0 +1,64 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! 
+// If you would like to modify this datatype, instead modify +// lib/op-attrs/include/op-attrs/ops/gather_attrs.struct.toml +/* proj-data +{ + "generated_from": "4ba46b6b494a7a52edda437d2a05fcf1" +} +*/ + +#ifndef _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_GATHER_ATTRS_DTG_H +#define _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_GATHER_ATTRS_DTG_H + +#include "fmt/format.h" +#include "nlohmann/json.hpp" +#include "op-attrs/ff_dim.dtg.h" +#include "op-attrs/ff_dim.h" +#include "rapidcheck.h" +#include +#include +#include + +namespace FlexFlow { +struct GatherAttrs { + GatherAttrs() = delete; + GatherAttrs(::FlexFlow::ff_dim_t const &dim); + + bool operator==(GatherAttrs const &) const; + bool operator!=(GatherAttrs const &) const; + bool operator<(GatherAttrs const &) const; + bool operator>(GatherAttrs const &) const; + bool operator<=(GatherAttrs const &) const; + bool operator>=(GatherAttrs const &) const; + ::FlexFlow::ff_dim_t dim; +}; +} // namespace FlexFlow + +namespace std { +template <> +struct hash { + size_t operator()(FlexFlow::GatherAttrs const &) const; +}; +} // namespace std + +namespace nlohmann { +template <> +struct adl_serializer { + static FlexFlow::GatherAttrs from_json(json const &); + static void to_json(json &, FlexFlow::GatherAttrs const &); +}; +} // namespace nlohmann + +namespace rc { +template <> +struct Arbitrary { + static Gen arbitrary(); +}; +} // namespace rc + +namespace FlexFlow { +std::string format_as(GatherAttrs const &); +std::ostream &operator<<(std::ostream &, GatherAttrs const &); +} // namespace FlexFlow + +#endif // _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_GATHER_ATTRS_DTG_H diff --git a/lib/op-attrs/include/op-attrs/ops/gather_attrs.struct.toml b/lib/op-attrs/include/op-attrs/ops/gather_attrs.struct.toml new file mode 100644 index 0000000000..c8bb88dcc7 --- /dev/null +++ b/lib/op-attrs/include/op-attrs/ops/gather_attrs.struct.toml @@ -0,0 +1,19 @@ +namespace = "FlexFlow" +name = "GatherAttrs" +features = [ + "eq", + "ord", + "hash", + "json", + "rapidcheck", + "fmt", +] + +includes = [ + "op-attrs/ff_dim.h", + "op-attrs/ff_dim.dtg.h" +] + +[[fields]] +name = "dim" +type = "::FlexFlow::ff_dim_t" diff --git a/lib/op-attrs/include/op-attrs/ops/input.h b/lib/op-attrs/include/op-attrs/ops/input.h index 26c486c9ac..9fe0ee2c2d 100644 --- a/lib/op-attrs/include/op-attrs/ops/input.h +++ b/lib/op-attrs/include/op-attrs/ops/input.h @@ -2,14 +2,15 @@ #define _FLEXFLOW_OP_ATTRS_OPS_OP_ATTRS_INPUT_H #include "core.h" -#include "utils/visitable.h" +#include "op-attrs/ops/input_attrs.dtg.h" +#include "op-attrs/parallel_tensor_shape.dtg.h" namespace FlexFlow { -struct InputAttrs {}; -FF_VISITABLE_STRUCT(InputAttrs); CHECK_VALID_OP_ATTR(InputAttrs); +ParallelTensorShape get_output_shape(InputAttrs const &); + } // namespace FlexFlow #endif diff --git a/lib/op-attrs/include/op-attrs/ops/input_attrs.dtg.h b/lib/op-attrs/include/op-attrs/ops/input_attrs.dtg.h new file mode 100644 index 0000000000..aa2ca1e933 --- /dev/null +++ b/lib/op-attrs/include/op-attrs/ops/input_attrs.dtg.h @@ -0,0 +1,58 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! 
+// If you would like to modify this datatype, instead modify +// lib/op-attrs/include/op-attrs/ops/input_attrs.struct.toml +/* proj-data +{ + "generated_from": "139ea46d57a3c8738b31b17a8c59a0aa" +} +*/ + +#ifndef _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_INPUT_ATTRS_DTG_H +#define _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_INPUT_ATTRS_DTG_H + +#include "fmt/format.h" +#include "nlohmann/json.hpp" +#include "rapidcheck.h" +#include +#include +#include + +namespace FlexFlow { +struct InputAttrs { + bool operator==(InputAttrs const &) const; + bool operator!=(InputAttrs const &) const; + bool operator<(InputAttrs const &) const; + bool operator>(InputAttrs const &) const; + bool operator<=(InputAttrs const &) const; + bool operator>=(InputAttrs const &) const; +}; +} // namespace FlexFlow + +namespace std { +template <> +struct hash { + size_t operator()(FlexFlow::InputAttrs const &) const; +}; +} // namespace std + +namespace nlohmann { +template <> +struct adl_serializer { + static FlexFlow::InputAttrs from_json(json const &); + static void to_json(json &, FlexFlow::InputAttrs const &); +}; +} // namespace nlohmann + +namespace rc { +template <> +struct Arbitrary { + static Gen arbitrary(); +}; +} // namespace rc + +namespace FlexFlow { +std::string format_as(InputAttrs const &); +std::ostream &operator<<(std::ostream &, InputAttrs const &); +} // namespace FlexFlow + +#endif // _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_INPUT_ATTRS_DTG_H diff --git a/lib/op-attrs/include/op-attrs/ops/input_attrs.struct.toml b/lib/op-attrs/include/op-attrs/ops/input_attrs.struct.toml new file mode 100644 index 0000000000..7e29de78df --- /dev/null +++ b/lib/op-attrs/include/op-attrs/ops/input_attrs.struct.toml @@ -0,0 +1,11 @@ +namespace = "FlexFlow" +name = "InputAttrs" +features = [ + "eq", + "ord", + "hash", + "json", + "rapidcheck", + "fmt", +] +fields = [] diff --git a/lib/op-attrs/include/op-attrs/ops/layer_norm.h b/lib/op-attrs/include/op-attrs/ops/layer_norm.h index dab055b2c9..01130139f1 100644 --- a/lib/op-attrs/include/op-attrs/ops/layer_norm.h +++ b/lib/op-attrs/include/op-attrs/ops/layer_norm.h @@ -2,18 +2,14 @@ #define _FLEXFLOW_OP_META_OPS_LAYER_NORM_ATTRS_H #include "core.h" -#include "op-attrs/ff_dim.h" +#include "op-attrs/ops/layer_norm_attrs.dtg.h" #include "op-attrs/parallel_tensor_shape.h" -#include "utils/visitable.h" namespace FlexFlow { -struct LayerNormAttrs { - stack_vector axes; - req elementwise_affine; - req eps; -}; -FF_VISITABLE_STRUCT(LayerNormAttrs, axes, elementwise_affine, eps); +ParallelTensorShape get_output_shape(LayerNormAttrs const &, + ParallelTensorShape const &); + CHECK_VALID_OP_ATTR(LayerNormAttrs); } // namespace FlexFlow diff --git a/lib/op-attrs/include/op-attrs/ops/layer_norm_attrs.dtg.h b/lib/op-attrs/include/op-attrs/ops/layer_norm_attrs.dtg.h new file mode 100644 index 0000000000..c945206863 --- /dev/null +++ b/lib/op-attrs/include/op-attrs/ops/layer_norm_attrs.dtg.h @@ -0,0 +1,70 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! 
+// If you would like to modify this datatype, instead modify
+// lib/op-attrs/include/op-attrs/ops/layer_norm_attrs.struct.toml
+/* proj-data
+{
+  "generated_from": "349deae8d9356d3eeacd7e7d069c3155"
+}
+*/
+
+#ifndef _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_LAYER_NORM_ATTRS_DTG_H
+#define _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_LAYER_NORM_ATTRS_DTG_H
+
+#include "fmt/format.h"
+#include "nlohmann/json.hpp"
+#include "op-attrs/ff_dim.dtg.h"
+#include "op-attrs/ff_dim.h"
+#include "rapidcheck.h"
+#include "utils/stack_vector.h"
+#include <functional>
+#include <ostream>
+#include <tuple>
+
+namespace FlexFlow {
+struct LayerNormAttrs {
+  LayerNormAttrs() = delete;
+  LayerNormAttrs(::FlexFlow::stack_vector<::FlexFlow::ff_dim_t,
+                                          MAX_TENSOR_DIM> const &axes,
+                 bool const &elementwise_affine,
+                 float const &eps);
+
+  bool operator==(LayerNormAttrs const &) const;
+  bool operator!=(LayerNormAttrs const &) const;
+  bool operator<(LayerNormAttrs const &) const;
+  bool operator>(LayerNormAttrs const &) const;
+  bool operator<=(LayerNormAttrs const &) const;
+  bool operator>=(LayerNormAttrs const &) const;
+  ::FlexFlow::stack_vector<::FlexFlow::ff_dim_t, MAX_TENSOR_DIM> axes;
+  bool elementwise_affine;
+  float eps;
+};
+} // namespace FlexFlow
+
+namespace std {
+template <>
+struct hash<FlexFlow::LayerNormAttrs> {
+  size_t operator()(FlexFlow::LayerNormAttrs const &) const;
+};
+} // namespace std
+
+namespace nlohmann {
+template <>
+struct adl_serializer<FlexFlow::LayerNormAttrs> {
+  static FlexFlow::LayerNormAttrs from_json(json const &);
+  static void to_json(json &, FlexFlow::LayerNormAttrs const &);
+};
+} // namespace nlohmann
+
+namespace rc {
+template <>
+struct Arbitrary<FlexFlow::LayerNormAttrs> {
+  static Gen<FlexFlow::LayerNormAttrs> arbitrary();
+};
+} // namespace rc
+
+namespace FlexFlow {
+std::string format_as(LayerNormAttrs const &);
+std::ostream &operator<<(std::ostream &, LayerNormAttrs const &);
+} // namespace FlexFlow
+
+#endif // _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_LAYER_NORM_ATTRS_DTG_H
diff --git a/lib/op-attrs/include/op-attrs/ops/layer_norm_attrs.struct.toml b/lib/op-attrs/include/op-attrs/ops/layer_norm_attrs.struct.toml
new file mode 100644
index 0000000000..ec60d39f7f
--- /dev/null
+++ b/lib/op-attrs/include/op-attrs/ops/layer_norm_attrs.struct.toml
@@ -0,0 +1,28 @@
+namespace = "FlexFlow"
+name = "LayerNormAttrs"
+features = [
+  "eq",
+  "ord",
+  "hash",
+  "json",
+  "rapidcheck",
+  "fmt",
+]
+
+includes = [
+  "op-attrs/ff_dim.h",
+  "op-attrs/ff_dim.dtg.h",
+  "utils/stack_vector.h",
+]
+
+[[fields]]
+name = "axes"
+type = "::FlexFlow::stack_vector<::FlexFlow::ff_dim_t, MAX_TENSOR_DIM>"
+
+[[fields]]
+name = "elementwise_affine"
+type = "bool"
+
+[[fields]]
+name = "eps"
+type = "float"
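An aside for readers new to the dtg workflow: each `.struct.toml` spec above is the single source of truth, and `proj` regenerates the matching `.dtg.h` header from it. Below is a minimal usage sketch of the generated `LayerNormAttrs`; it assumes `stack_vector` accepts initializer-list construction and that `ff_dim_t` is constructible from an `int`, and is illustrative rather than part of the patch.

```cpp
#include "op-attrs/ops/layer_norm_attrs.dtg.h"
#include <iostream>

int main() {
  using namespace FlexFlow;
  // Constructor parameters follow the [[fields]] order in the TOML spec.
  LayerNormAttrs attrs{stack_vector<ff_dim_t, MAX_TENSOR_DIM>{ff_dim_t{0}},
                       /*elementwise_affine=*/true,
                       /*eps=*/1e-5f};
  nlohmann::json j = attrs;                        // from the "json" feature
  bool same = (attrs == j.get<LayerNormAttrs>());  // from the "eq" feature
  std::cout << attrs << " round-trips: " << same << "\n";  // "fmt" feature
}
```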
diff --git a/lib/op-attrs/include/op-attrs/ops/linear.h b/lib/op-attrs/include/op-attrs/ops/linear.h
index a46df59282..dd6948165e 100644
--- a/lib/op-attrs/include/op-attrs/ops/linear.h
+++ b/lib/op-attrs/include/op-attrs/ops/linear.h
@@ -1,42 +1,31 @@
 #ifndef _FLEXFLOW_LINEAR_ATTRS_H
 #define _FLEXFLOW_LINEAR_ATTRS_H
 
-#include "op-attrs/activation.h"
-#include "op-attrs/datatype.h"
 #include "op-attrs/ops/core.h"
-#include "op-attrs/parallel_tensor_shape.h"
-#include "utils/visitable.h"
+#include "op-attrs/ops/linear_attrs.dtg.h"
+#include "op-attrs/parallel_tensor_shape.dtg.h"
+#include "op-attrs/tensor_shape.dtg.h"
+#include <tl/expected.hpp>
 
 namespace FlexFlow {
 
-struct L1RegularizerAttrs {
-  req<float> lambda;
-};
-FF_VISITABLE_STRUCT(L1RegularizerAttrs, lambda);
-CHECK_VALID_OP_ATTR(L1RegularizerAttrs);
-
-struct L2RegularizerAttrs {
-  req<float> lambda;
-};
-FF_VISITABLE_STRUCT(L2RegularizerAttrs, lambda);
-CHECK_VALID_OP_ATTR(L2RegularizerAttrs);
-
-using RegularizerAttrs = std::variant<L1RegularizerAttrs, L2RegularizerAttrs>;
-
-struct LinearAttrs {
-  int out_channels;
-  bool use_bias;
-  DataType data_type;
-  std::optional<Activation> activation;
-  req<std::optional<RegularizerAttrs>> regularizer;
-};
-FF_VISITABLE_STRUCT(
-    LinearAttrs, out_channels, use_bias, data_type, activation, regularizer);
 CHECK_VALID_OP_ATTR(LinearAttrs);
 
-TensorShape get_weights_shape(LinearAttrs const &attrs,
-                              TensorShape const &input);
-TensorShape get_bias_shape(LinearAttrs const &attrs, TensorShape const &input);
+tl::expected<TensorShape, std::string>
+    get_kernel_shape(LinearAttrs const &attrs, TensorShape const &input);
+tl::expected<TensorShape, std::string> get_bias_shape(LinearAttrs const &attrs,
+                                                      TensorShape const &input);
+tl::expected<TensorShape, std::string>
+    get_output_shape(LinearAttrs const &attrs, TensorShape const &input);
+
+tl::expected<ParallelTensorShape, std::string>
+    get_kernel_shape(LinearAttrs const &attrs,
+                     ParallelTensorShape const &input);
+tl::expected<ParallelTensorShape, std::string>
+    get_bias_shape(LinearAttrs const &attrs, ParallelTensorShape const &input);
+tl::expected<ParallelTensorShape, std::string>
+    get_output_shape(LinearAttrs const &attrs,
+                     ParallelTensorShape const &input);
 
 } // namespace FlexFlow
diff --git a/lib/op-attrs/include/op-attrs/ops/linear_attrs.dtg.h b/lib/op-attrs/include/op-attrs/ops/linear_attrs.dtg.h
new file mode 100644
index 0000000000..28cd2a8b33
--- /dev/null
+++ b/lib/op-attrs/include/op-attrs/ops/linear_attrs.dtg.h
@@ -0,0 +1,74 @@
+// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT!
+// If you would like to modify this datatype, instead modify
+// lib/op-attrs/include/op-attrs/ops/linear_attrs.struct.toml
+/* proj-data
+{
+  "generated_from": "7e82d282f90e08f1e0db7d5c4ce528b7"
+}
+*/
+
+#ifndef _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_LINEAR_ATTRS_DTG_H
+#define _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_LINEAR_ATTRS_DTG_H
+
+#include "fmt/format.h"
+#include "nlohmann/json.hpp"
+#include "op-attrs/activation.dtg.h"
+#include "op-attrs/datatype.dtg.h"
+#include "op-attrs/regularizer_attrs.dtg.h"
+#include "rapidcheck.h"
+#include "utils/json.h"
+#include <functional>
+#include <ostream>
+#include <tuple>
+
+namespace FlexFlow {
+struct LinearAttrs {
+  LinearAttrs() = delete;
+  LinearAttrs(int const &out_channels,
+              bool const &use_bias,
+              ::FlexFlow::DataType const &data_type,
+              std::optional<::FlexFlow::Activation> const &activation,
+              std::optional<::FlexFlow::RegularizerAttrs> const &regularizer);
+
+  bool operator==(LinearAttrs const &) const;
+  bool operator!=(LinearAttrs const &) const;
+  bool operator<(LinearAttrs const &) const;
+  bool operator>(LinearAttrs const &) const;
+  bool operator<=(LinearAttrs const &) const;
+  bool operator>=(LinearAttrs const &) const;
+  int out_channels;
+  bool use_bias;
+  ::FlexFlow::DataType data_type;
+  std::optional<::FlexFlow::Activation> activation;
+  std::optional<::FlexFlow::RegularizerAttrs> regularizer;
+};
+} // namespace FlexFlow
+
+namespace std {
+template <>
+struct hash<FlexFlow::LinearAttrs> {
+  size_t operator()(FlexFlow::LinearAttrs const &) const;
+};
+} // namespace std
+
+namespace nlohmann {
+template <>
+struct adl_serializer<FlexFlow::LinearAttrs> {
+  static FlexFlow::LinearAttrs from_json(json const &);
+  static void to_json(json &, FlexFlow::LinearAttrs const &);
+};
+} // namespace nlohmann
+
+namespace rc {
+template <>
+struct Arbitrary<FlexFlow::LinearAttrs> {
+  static Gen<FlexFlow::LinearAttrs> arbitrary();
+};
+} // namespace rc
+
+namespace FlexFlow {
+std::string format_as(LinearAttrs const &);
+std::ostream &operator<<(std::ostream &, LinearAttrs const &);
+} // namespace FlexFlow
+
+#endif // _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_LINEAR_ATTRS_DTG_H
diff --git a/lib/op-attrs/include/op-attrs/ops/linear_attrs.struct.toml
b/lib/op-attrs/include/op-attrs/ops/linear_attrs.struct.toml new file mode 100644 index 0000000000..4ac8f83ec9 --- /dev/null +++ b/lib/op-attrs/include/op-attrs/ops/linear_attrs.struct.toml @@ -0,0 +1,37 @@ +namespace = "FlexFlow" +name = "LinearAttrs" +features = [ + "eq", + "ord", + "hash", + "json", + "rapidcheck", + "fmt", +] + +includes = [ + "op-attrs/datatype.dtg.h", + "op-attrs/activation.dtg.h", + "op-attrs/regularizer_attrs.dtg.h", + "utils/json.h", +] + +[[fields]] +name = "out_channels" +type = "int" + +[[fields]] +name = "use_bias" +type = "bool" + +[[fields]] +name = "data_type" +type = "::FlexFlow::DataType" + +[[fields]] +name = "activation" +type = "std::optional<::FlexFlow::Activation>" + +[[fields]] +name = "regularizer" +type = "std::optional<::FlexFlow::RegularizerAttrs>" diff --git a/lib/op-attrs/include/op-attrs/ops/noop.h b/lib/op-attrs/include/op-attrs/ops/noop.h index 658e1b7d98..eb01009259 100644 --- a/lib/op-attrs/include/op-attrs/ops/noop.h +++ b/lib/op-attrs/include/op-attrs/ops/noop.h @@ -2,14 +2,16 @@ #define _FLEXFLOW_OP_ATTRS_OPS_NOOP_H #include "core.h" -#include "utils/visitable.h" +#include "op-attrs/ops/noop_attrs.dtg.h" +#include "op-attrs/parallel_tensor_shape.dtg.h" namespace FlexFlow { -struct NoopAttrs {}; -FF_VISITABLE_STRUCT(NoopAttrs); CHECK_VALID_OP_ATTR(NoopAttrs); +ParallelTensorShape get_output_shape(NoopAttrs const &, + ParallelTensorShape const &); + } // namespace FlexFlow #endif diff --git a/lib/op-attrs/include/op-attrs/ops/noop_attrs.dtg.h b/lib/op-attrs/include/op-attrs/ops/noop_attrs.dtg.h new file mode 100644 index 0000000000..ed0d8c9348 --- /dev/null +++ b/lib/op-attrs/include/op-attrs/ops/noop_attrs.dtg.h @@ -0,0 +1,58 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/op-attrs/include/op-attrs/ops/noop_attrs.struct.toml +/* proj-data +{ + "generated_from": "d440077aa598fdad0e5aa95288b63c40" +} +*/ + +#ifndef _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_NOOP_ATTRS_DTG_H +#define _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_NOOP_ATTRS_DTG_H + +#include "fmt/format.h" +#include "nlohmann/json.hpp" +#include "rapidcheck.h" +#include +#include +#include + +namespace FlexFlow { +struct NoopAttrs { + bool operator==(NoopAttrs const &) const; + bool operator!=(NoopAttrs const &) const; + bool operator<(NoopAttrs const &) const; + bool operator>(NoopAttrs const &) const; + bool operator<=(NoopAttrs const &) const; + bool operator>=(NoopAttrs const &) const; +}; +} // namespace FlexFlow + +namespace std { +template <> +struct hash { + size_t operator()(FlexFlow::NoopAttrs const &) const; +}; +} // namespace std + +namespace nlohmann { +template <> +struct adl_serializer { + static FlexFlow::NoopAttrs from_json(json const &); + static void to_json(json &, FlexFlow::NoopAttrs const &); +}; +} // namespace nlohmann + +namespace rc { +template <> +struct Arbitrary { + static Gen arbitrary(); +}; +} // namespace rc + +namespace FlexFlow { +std::string format_as(NoopAttrs const &); +std::ostream &operator<<(std::ostream &, NoopAttrs const &); +} // namespace FlexFlow + +#endif // _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_NOOP_ATTRS_DTG_H diff --git a/lib/op-attrs/include/op-attrs/ops/noop_attrs.struct.toml b/lib/op-attrs/include/op-attrs/ops/noop_attrs.struct.toml new file mode 100644 index 0000000000..3d9202093c --- /dev/null +++ b/lib/op-attrs/include/op-attrs/ops/noop_attrs.struct.toml @@ -0,0 +1,11 @@ +namespace = "FlexFlow" +name = "NoopAttrs" +features = [ + 
"eq", + "ord", + "hash", + "json", + "rapidcheck", + "fmt", +] +fields = [] diff --git a/lib/op-attrs/include/op-attrs/ops/parallel_attention_inputs.dtg.h b/lib/op-attrs/include/op-attrs/ops/parallel_attention_inputs.dtg.h new file mode 100644 index 0000000000..d3903bd3b2 --- /dev/null +++ b/lib/op-attrs/include/op-attrs/ops/parallel_attention_inputs.dtg.h @@ -0,0 +1,66 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/op-attrs/include/op-attrs/ops/parallel_attention_inputs.struct.toml +/* proj-data +{ + "generated_from": "b76a39763275090d8376e1c27668d2cb" +} +*/ + +#ifndef _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_PARALLEL_ATTENTION_INPUTS_DTG_H +#define _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_PARALLEL_ATTENTION_INPUTS_DTG_H + +#include "fmt/format.h" +#include "nlohmann/json.hpp" +#include "op-attrs/parallel_tensor_shape.h" +#include "rapidcheck.h" +#include +#include +#include + +namespace FlexFlow { +struct ParallelMultiHeadAttentionInputs { + ParallelMultiHeadAttentionInputs() = delete; + ParallelMultiHeadAttentionInputs( + ::FlexFlow::ParallelTensorShape const &query, + ::FlexFlow::ParallelTensorShape const &key, + ::FlexFlow::ParallelTensorShape const &value); + + bool operator==(ParallelMultiHeadAttentionInputs const &) const; + bool operator!=(ParallelMultiHeadAttentionInputs const &) const; + ::FlexFlow::ParallelTensorShape query; + ::FlexFlow::ParallelTensorShape key; + ::FlexFlow::ParallelTensorShape value; +}; +} // namespace FlexFlow + +namespace std { +template <> +struct hash { + size_t operator()(FlexFlow::ParallelMultiHeadAttentionInputs const &) const; +}; +} // namespace std + +namespace nlohmann { +template <> +struct adl_serializer { + static FlexFlow::ParallelMultiHeadAttentionInputs from_json(json const &); + static void to_json(json &, + FlexFlow::ParallelMultiHeadAttentionInputs const &); +}; +} // namespace nlohmann + +namespace rc { +template <> +struct Arbitrary { + static Gen arbitrary(); +}; +} // namespace rc + +namespace FlexFlow { +std::string format_as(ParallelMultiHeadAttentionInputs const &); +std::ostream &operator<<(std::ostream &, + ParallelMultiHeadAttentionInputs const &); +} // namespace FlexFlow + +#endif // _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_PARALLEL_ATTENTION_INPUTS_DTG_H diff --git a/lib/op-attrs/include/op-attrs/ops/parallel_attention_inputs.struct.toml b/lib/op-attrs/include/op-attrs/ops/parallel_attention_inputs.struct.toml new file mode 100644 index 0000000000..4809ee998a --- /dev/null +++ b/lib/op-attrs/include/op-attrs/ops/parallel_attention_inputs.struct.toml @@ -0,0 +1,26 @@ +namespace = "FlexFlow" +name = "ParallelMultiHeadAttentionInputs" +features = [ + "eq", + # "ord", + "hash", + "json", + "rapidcheck", + "fmt", +] + +includes = [ + "op-attrs/parallel_tensor_shape.h" +] + +[[fields]] +name = "query" +type = "::FlexFlow::ParallelTensorShape" + +[[fields]] +name = "key" +type = "::FlexFlow::ParallelTensorShape" + +[[fields]] +name = "value" +type = "::FlexFlow::ParallelTensorShape" diff --git a/lib/op-attrs/include/op-attrs/ops/pool_2d.h b/lib/op-attrs/include/op-attrs/ops/pool_2d.h index efe29b3b2e..162f9aef05 100644 --- a/lib/op-attrs/include/op-attrs/ops/pool_2d.h +++ b/lib/op-attrs/include/op-attrs/ops/pool_2d.h @@ -2,57 +2,16 @@ #define _FLEXFLOW_POOL_2D_ATTRS_H #include "core.h" -#include "op-attrs/activation.h" -#include "op-attrs/parallel_tensor_shape.h" -#include "utils/visitable.h" +#include "op-attrs/ops/pool_2d_attrs.dtg.h" 
+#include "op-attrs/parallel_tensor_shape.dtg.h" namespace FlexFlow { -enum class PoolOp { - MAX, - AVG, -}; - -struct Pool2DAttrs { - req kernel_h, kernel_w, stride_h, stride_w, padding_h, padding_w; - req pool_type; - req activation; -}; -FF_VISITABLE_STRUCT(Pool2DAttrs, - kernel_h, - kernel_w, - stride_h, - stride_w, - padding_h, - padding_w, - pool_type, - activation); CHECK_VALID_OP_ATTR(Pool2DAttrs); -} // namespace FlexFlow - -namespace fmt { +ParallelTensorShape get_output_shape(Pool2DAttrs const &, + ParallelTensorShape const &); -template <> -struct formatter<::FlexFlow::PoolOp> : formatter { - template - auto format(::FlexFlow::PoolOp o, FormatContext &ctx) const - -> decltype(ctx.out()) { - using namespace FlexFlow; - - string_view name = "unknown"; - switch (o) { - case PoolOp::AVG: - name = "Avg"; - break; - case PoolOp::MAX: - name = "Max"; - break; - } - return formatter::format(name, ctx); - } -}; - -} // namespace fmt +} // namespace FlexFlow #endif diff --git a/lib/op-attrs/include/op-attrs/ops/pool_2d_attrs.dtg.h b/lib/op-attrs/include/op-attrs/ops/pool_2d_attrs.dtg.h new file mode 100644 index 0000000000..a5c6603302 --- /dev/null +++ b/lib/op-attrs/include/op-attrs/ops/pool_2d_attrs.dtg.h @@ -0,0 +1,78 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/op-attrs/include/op-attrs/ops/pool_2d_attrs.struct.toml +/* proj-data +{ + "generated_from": "03aeafe335f68ff831e3e73a77f45caf" +} +*/ + +#ifndef _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_POOL_2D_ATTRS_DTG_H +#define _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_POOL_2D_ATTRS_DTG_H + +#include "fmt/format.h" +#include "nlohmann/json.hpp" +#include "op-attrs/activation.dtg.h" +#include "op-attrs/pool_op.dtg.h" +#include "rapidcheck.h" +#include +#include +#include + +namespace FlexFlow { +struct Pool2DAttrs { + Pool2DAttrs() = delete; + Pool2DAttrs(int const &kernel_h, + int const &kernel_w, + int const &stride_h, + int const &stride_w, + int const &padding_h, + int const &padding_w, + ::FlexFlow::PoolOp const &pool_type, + ::FlexFlow::Activation const &activation); + + bool operator==(Pool2DAttrs const &) const; + bool operator!=(Pool2DAttrs const &) const; + bool operator<(Pool2DAttrs const &) const; + bool operator>(Pool2DAttrs const &) const; + bool operator<=(Pool2DAttrs const &) const; + bool operator>=(Pool2DAttrs const &) const; + int kernel_h; + int kernel_w; + int stride_h; + int stride_w; + int padding_h; + int padding_w; + ::FlexFlow::PoolOp pool_type; + ::FlexFlow::Activation activation; +}; +} // namespace FlexFlow + +namespace std { +template <> +struct hash { + size_t operator()(FlexFlow::Pool2DAttrs const &) const; +}; +} // namespace std + +namespace nlohmann { +template <> +struct adl_serializer { + static FlexFlow::Pool2DAttrs from_json(json const &); + static void to_json(json &, FlexFlow::Pool2DAttrs const &); +}; +} // namespace nlohmann + +namespace rc { +template <> +struct Arbitrary { + static Gen arbitrary(); +}; +} // namespace rc + +namespace FlexFlow { +std::string format_as(Pool2DAttrs const &); +std::ostream &operator<<(std::ostream &, Pool2DAttrs const &); +} // namespace FlexFlow + +#endif // _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_POOL_2D_ATTRS_DTG_H diff --git a/lib/op-attrs/include/op-attrs/ops/pool_2d_attrs.struct.toml b/lib/op-attrs/include/op-attrs/ops/pool_2d_attrs.struct.toml new file mode 100644 index 0000000000..56bf682f50 --- /dev/null +++ 
diff --git a/lib/op-attrs/include/op-attrs/ops/pool_2d_attrs.struct.toml b/lib/op-attrs/include/op-attrs/ops/pool_2d_attrs.struct.toml
new file mode 100644
index 0000000000..56bf682f50
--- /dev/null
+++ b/lib/op-attrs/include/op-attrs/ops/pool_2d_attrs.struct.toml
@@ -0,0 +1,47 @@
+namespace = "FlexFlow"
+name = "Pool2DAttrs"
+features = [
+  "eq",
+  "ord",
+  "hash",
+  "json",
+  "rapidcheck",
+  "fmt",
+]
+
+includes = [
+  "op-attrs/pool_op.dtg.h",
+  "op-attrs/activation.dtg.h",
+]
+
+[[fields]]
+name = "kernel_h"
+type = "int"
+
+[[fields]]
+name = "kernel_w"
+type = "int"
+
+[[fields]]
+name = "stride_h"
+type = "int"
+
+[[fields]]
+name = "stride_w"
+type = "int"
+
+[[fields]]
+name = "padding_h"
+type = "int"
+
+[[fields]]
+name = "padding_w"
+type = "int"
+
+[[fields]]
+name = "pool_type"
+type = "::FlexFlow::PoolOp"
+
+[[fields]]
+name = "activation"
+type = "::FlexFlow::Activation"
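The TOML spec above maps one-to-one onto the generated header: `namespace`/`name` pick the C++ type, each `features` entry enables one block of declarations ("eq"/"ord" the comparison operators, "hash" the `std::hash` specialization, "json" the `nlohmann::adl_serializer`, "rapidcheck" the `rc::Arbitrary` generator, "fmt" the `format_as`/`operator<<` pair), and the `[[fields]]` order fixes the constructor's parameter order. A sketch of constructing the type in that order; `Activation::RELU` is a guessed enumerator, not confirmed by this patch:

```cpp
#include "op-attrs/ops/pool_2d_attrs.dtg.h"

FlexFlow::Pool2DAttrs make_2x2_max_pool() {
  using namespace FlexFlow;
  // Arguments appear in the same order as the [[fields]] entries above.
  return Pool2DAttrs{/*kernel_h=*/2, /*kernel_w=*/2,
                     /*stride_h=*/2, /*stride_w=*/2,
                     /*padding_h=*/0, /*padding_w=*/0,
                     PoolOp::MAX, Activation::RELU};  // RELU is assumed
}
```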
diff --git a/lib/op-attrs/include/op-attrs/ops/reduce.h b/lib/op-attrs/include/op-attrs/ops/reduce.h
index 193d3b0dc8..04e44b4161 100644
--- a/lib/op-attrs/include/op-attrs/ops/reduce.h
+++ b/lib/op-attrs/include/op-attrs/ops/reduce.h
@@ -1,23 +1,17 @@
 #ifndef _FLEXFLOW_OP_META_OPS_REDUCE_ATTRS_H
 #define _FLEXFLOW_OP_META_OPS_REDUCE_ATTRS_H
 
-#include "core.h"
-#include "op-attrs/ff_dim.h"
-#include "op-attrs/op.h"
-#include "op-attrs/parallel_tensor_shape.h"
-#include "utils/stack_vector.h"
-#include "utils/visitable.h"
+#include "op-attrs/ops/core.h"
+#include "op-attrs/ops/reduce_attrs.dtg.h"
+#include "op-attrs/parallel_tensor_shape.dtg.h"
 
 namespace FlexFlow {
 
-struct ReduceAttrs {
-  stack_vector<ff_dim_t, MAX_TENSOR_DIM> axes;
-  req<Op> op_type;
-  req<bool> keepdims;
-};
-FF_VISITABLE_STRUCT(ReduceAttrs, axes, op_type, keepdims);
 CHECK_VALID_OP_ATTR(ReduceAttrs);
 
+ParallelTensorShape get_output_shape(ReduceAttrs const &,
+                                     ParallelTensorShape const &input_shape);
+
 } // namespace FlexFlow
 
 #endif
diff --git a/lib/op-attrs/include/op-attrs/ops/reduce_attrs.dtg.h b/lib/op-attrs/include/op-attrs/ops/reduce_attrs.dtg.h
new file mode 100644
index 0000000000..af27bf35be
--- /dev/null
+++ b/lib/op-attrs/include/op-attrs/ops/reduce_attrs.dtg.h
@@ -0,0 +1,71 @@
+// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT!
+// If you would like to modify this datatype, instead modify
+// lib/op-attrs/include/op-attrs/ops/reduce_attrs.struct.toml
+/* proj-data
+{
+  "generated_from": "097463446e254f662c7bdf5df4e12d17"
+}
+*/
+
+#ifndef _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_REDUCE_ATTRS_DTG_H
+#define _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_REDUCE_ATTRS_DTG_H
+
+#include "fmt/format.h"
+#include "nlohmann/json.hpp"
+#include "op-attrs/ff_dim.dtg.h"
+#include "op-attrs/ff_dim.h"
+#include "op-attrs/operator_type.dtg.h"
+#include "rapidcheck.h"
+#include "utils/stack_vector.h"
+#include <functional>
+#include <ostream>
+#include <tuple>
+
+namespace FlexFlow {
+struct ReduceAttrs {
+  ReduceAttrs() = delete;
+  ReduceAttrs(::FlexFlow::stack_vector<::FlexFlow::ff_dim_t,
+                                       MAX_TENSOR_DIM> const &axes,
+              ::FlexFlow::OperatorType const &op_type,
+              bool const &keepdims);
+
+  bool operator==(ReduceAttrs const &) const;
+  bool operator!=(ReduceAttrs const &) const;
+  bool operator<(ReduceAttrs const &) const;
+  bool operator>(ReduceAttrs const &) const;
+  bool operator<=(ReduceAttrs const &) const;
+  bool operator>=(ReduceAttrs const &) const;
+  ::FlexFlow::stack_vector<::FlexFlow::ff_dim_t, MAX_TENSOR_DIM> axes;
+  ::FlexFlow::OperatorType op_type;
+  bool keepdims;
+};
+} // namespace FlexFlow
+
+namespace std {
+template <>
+struct hash<FlexFlow::ReduceAttrs> {
+  size_t operator()(FlexFlow::ReduceAttrs const &) const;
+};
+} // namespace std
+
+namespace nlohmann {
+template <>
+struct adl_serializer<FlexFlow::ReduceAttrs> {
+  static FlexFlow::ReduceAttrs from_json(json const &);
+  static void to_json(json &, FlexFlow::ReduceAttrs const &);
+};
+} // namespace nlohmann
+
+namespace rc {
+template <>
+struct Arbitrary<FlexFlow::ReduceAttrs> {
+  static Gen<FlexFlow::ReduceAttrs> arbitrary();
+};
+} // namespace rc
+
+namespace FlexFlow {
+std::string format_as(ReduceAttrs const &);
+std::ostream &operator<<(std::ostream &, ReduceAttrs const &);
+} // namespace FlexFlow
+
+#endif // _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_REDUCE_ATTRS_DTG_H
diff --git a/lib/op-attrs/include/op-attrs/ops/reduce_attrs.struct.toml b/lib/op-attrs/include/op-attrs/ops/reduce_attrs.struct.toml
new file mode 100644
index 0000000000..717e7954e8
--- /dev/null
+++ b/lib/op-attrs/include/op-attrs/ops/reduce_attrs.struct.toml
@@ -0,0 +1,29 @@
+namespace = "FlexFlow"
+name = "ReduceAttrs"
+features = [
+  "eq",
+  "ord",
+  "hash",
+  "json",
+  "rapidcheck",
+  "fmt",
+]
+
+includes = [
+  "op-attrs/operator_type.dtg.h",
+  "op-attrs/ff_dim.h",
+  "op-attrs/ff_dim.dtg.h",
+  "utils/stack_vector.h",
+]
+
+[[fields]]
+name = "axes"
+type = "::FlexFlow::stack_vector<::FlexFlow::ff_dim_t, MAX_TENSOR_DIM>"
+
+[[fields]]
+name = "op_type"
+type = "::FlexFlow::OperatorType"
+
+[[fields]]
+name = "keepdims"
+type = "bool"
diff --git a/lib/op-attrs/include/op-attrs/ops/reduction.h b/lib/op-attrs/include/op-attrs/ops/reduction.h
index f848f879fc..a6047b38f9 100644
--- a/lib/op-attrs/include/op-attrs/ops/reduction.h
+++ b/lib/op-attrs/include/op-attrs/ops/reduction.h
@@ -1,20 +1,19 @@
 #ifndef _FLEXFLOW_REDUCTION_ATTRS_H
 #define _FLEXFLOW_REDUCTION_ATTRS_H
 
-#include "core.h"
-#include "op-attrs/ff_dim.h"
-#include "op-attrs/parallel_tensor_shape.h"
-#include "utils/visitable.h"
+#include "op-attrs/ops/core.h"
+#include "op-attrs/ops/reduction_attrs.dtg.h"
+#include "op-attrs/parallel_tensor_shape.dtg.h"
+#include <tl/expected.hpp>
 
 namespace FlexFlow {
 
-struct ReductionAttrs {
-  ff_dim_t reduction_dim;
-  req<int> reduction_degree;
-};
-FF_VISITABLE_STRUCT(ReductionAttrs, reduction_dim, reduction_degree);
 CHECK_VALID_OP_ATTR(ReductionAttrs);
 
+tl::expected<ParallelTensorShape, std::string>
+    get_output_shape(ReductionAttrs const &attrs,
+                     ParallelTensorShape const &input_shape);
+
 } // namespace FlexFlow
 
 #endif
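The shape-inference declarations are migrating from plain return values to `tl::expected<Shape, std::string>`, so an invalid attribute/input combination surfaces as an error string instead of an assertion failure. A usage sketch against the declaration just above (assuming `ParallelTensorShape` is streamable, which the surrounding generated headers suggest; `report` itself is illustrative):

```cpp
#include "op-attrs/ops/reduction.h"
#include <iostream>

void report(FlexFlow::ReductionAttrs const &attrs,
            FlexFlow::ParallelTensorShape const &input) {
  tl::expected<FlexFlow::ParallelTensorShape, std::string> result =
      get_output_shape(attrs, input);
  if (result.has_value()) {
    std::cout << "output shape: " << result.value() << "\n";
  } else {
    std::cout << "shape inference failed: " << result.error() << "\n";
  }
}
```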
diff --git a/lib/op-attrs/include/op-attrs/ops/reduction_attrs.dtg.h b/lib/op-attrs/include/op-attrs/ops/reduction_attrs.dtg.h
new file mode 100644
index 0000000000..9de5eb2252
--- /dev/null
+++ b/lib/op-attrs/include/op-attrs/ops/reduction_attrs.dtg.h
@@ -0,0 +1,62 @@
+// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT!
+// If you would like to modify this datatype, instead modify
+// lib/op-attrs/include/op-attrs/ops/reduction_attrs.struct.toml
+/* proj-data
+{
+  "generated_from": "1d2b5b7cf11ed04a27a6fd8215e4e2a5"
+}
+*/
+
+#ifndef _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_REDUCTION_ATTRS_DTG_H
+#define _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_REDUCTION_ATTRS_DTG_H
+
+#include "fmt/format.h"
+#include "nlohmann/json.hpp"
+#include "rapidcheck.h"
+#include <functional>
+#include <ostream>
+#include <tuple>
+
+namespace FlexFlow {
+struct ReductionAttrs {
+  ReductionAttrs() = delete;
+  ReductionAttrs(int const &reduction_degree);
+
+  bool operator==(ReductionAttrs const &) const;
+  bool operator!=(ReductionAttrs const &) const;
+  bool operator<(ReductionAttrs const &) const;
+  bool operator>(ReductionAttrs const &) const;
+  bool operator<=(ReductionAttrs const &) const;
+  bool operator>=(ReductionAttrs const &) const;
+  int reduction_degree;
+};
+} // namespace FlexFlow
+
+namespace std {
+template <>
+struct hash<FlexFlow::ReductionAttrs> {
+  size_t operator()(FlexFlow::ReductionAttrs const &) const;
+};
+} // namespace std
+
+namespace nlohmann {
+template <>
+struct adl_serializer<FlexFlow::ReductionAttrs> {
+  static FlexFlow::ReductionAttrs from_json(json const &);
+  static void to_json(json &, FlexFlow::ReductionAttrs const &);
+};
+} // namespace nlohmann
+
+namespace rc {
+template <>
+struct Arbitrary<FlexFlow::ReductionAttrs> {
+  static Gen<FlexFlow::ReductionAttrs> arbitrary();
+};
+} // namespace rc
+
+namespace FlexFlow {
+std::string format_as(ReductionAttrs const &);
+std::ostream &operator<<(std::ostream &, ReductionAttrs const &);
+} // namespace FlexFlow
+
+#endif // _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_REDUCTION_ATTRS_DTG_H
diff --git a/lib/op-attrs/include/op-attrs/ops/reduction_attrs.struct.toml b/lib/op-attrs/include/op-attrs/ops/reduction_attrs.struct.toml
new file mode 100644
index 0000000000..ee0ae54132
--- /dev/null
+++ b/lib/op-attrs/include/op-attrs/ops/reduction_attrs.struct.toml
@@ -0,0 +1,14 @@
+namespace = "FlexFlow"
+name = "ReductionAttrs"
+features = [
+  "eq",
+  "ord",
+  "hash",
+  "json",
+  "rapidcheck",
+  "fmt",
+]
+
+[[fields]]
+name = "reduction_degree"
+type = "int"
diff --git a/lib/op-attrs/include/op-attrs/ops/repartition.h b/lib/op-attrs/include/op-attrs/ops/repartition.h
index 83c4ae870b..559e7278f5 100644
--- a/lib/op-attrs/include/op-attrs/ops/repartition.h
+++ b/lib/op-attrs/include/op-attrs/ops/repartition.h
@@ -1,20 +1,19 @@
 #ifndef _FLEXFLOW_PARTITION_ATTRS_H
 #define _FLEXFLOW_PARTITION_ATTRS_H
 
-#include "core.h"
-#include "op-attrs/ff_dim.h"
-#include "op-attrs/parallel_tensor_shape.h"
-#include "utils/visitable.h"
+#include "op-attrs/ops/core.h"
+#include "op-attrs/ops/repartition_attrs.dtg.h"
+#include "op-attrs/parallel_tensor_shape.dtg.h"
+#include <tl/expected.hpp>
 
 namespace FlexFlow {
 
-struct RepartitionAttrs {
-  ff_dim_t repartition_dim;
-  req<int> repartition_degree;
-};
-FF_VISITABLE_STRUCT(RepartitionAttrs, repartition_dim, repartition_degree);
 CHECK_VALID_OP_ATTR(RepartitionAttrs);
 
+tl::expected<ParallelTensorShape, std::string>
+    get_output_shape(RepartitionAttrs const &,
+                     ParallelTensorShape const &input_shape);
+
 } // namespace FlexFlow
 
 #endif
diff --git a/lib/op-attrs/include/op-attrs/ops/repartition_attrs.dtg.h
b/lib/op-attrs/include/op-attrs/ops/repartition_attrs.dtg.h new file mode 100644 index 0000000000..66c21466f4 --- /dev/null +++ b/lib/op-attrs/include/op-attrs/ops/repartition_attrs.dtg.h @@ -0,0 +1,66 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/op-attrs/include/op-attrs/ops/repartition_attrs.struct.toml +/* proj-data +{ + "generated_from": "0a4d8b435768ce3ee37013fc550c9ebb" +} +*/ + +#ifndef _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_REPARTITION_ATTRS_DTG_H +#define _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_REPARTITION_ATTRS_DTG_H + +#include "fmt/format.h" +#include "nlohmann/json.hpp" +#include "op-attrs/ff_dim.dtg.h" +#include "op-attrs/ff_dim.h" +#include "rapidcheck.h" +#include +#include +#include + +namespace FlexFlow { +struct RepartitionAttrs { + RepartitionAttrs() = delete; + RepartitionAttrs(::FlexFlow::ff_dim_t const &repartition_dim, + int const &repartition_degree); + + bool operator==(RepartitionAttrs const &) const; + bool operator!=(RepartitionAttrs const &) const; + bool operator<(RepartitionAttrs const &) const; + bool operator>(RepartitionAttrs const &) const; + bool operator<=(RepartitionAttrs const &) const; + bool operator>=(RepartitionAttrs const &) const; + ::FlexFlow::ff_dim_t repartition_dim; + int repartition_degree; +}; +} // namespace FlexFlow + +namespace std { +template <> +struct hash { + size_t operator()(FlexFlow::RepartitionAttrs const &) const; +}; +} // namespace std + +namespace nlohmann { +template <> +struct adl_serializer { + static FlexFlow::RepartitionAttrs from_json(json const &); + static void to_json(json &, FlexFlow::RepartitionAttrs const &); +}; +} // namespace nlohmann + +namespace rc { +template <> +struct Arbitrary { + static Gen arbitrary(); +}; +} // namespace rc + +namespace FlexFlow { +std::string format_as(RepartitionAttrs const &); +std::ostream &operator<<(std::ostream &, RepartitionAttrs const &); +} // namespace FlexFlow + +#endif // _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_REPARTITION_ATTRS_DTG_H diff --git a/lib/op-attrs/include/op-attrs/ops/repartition_attrs.struct.toml b/lib/op-attrs/include/op-attrs/ops/repartition_attrs.struct.toml new file mode 100644 index 0000000000..25a33c0c15 --- /dev/null +++ b/lib/op-attrs/include/op-attrs/ops/repartition_attrs.struct.toml @@ -0,0 +1,23 @@ +namespace = "FlexFlow" +name = "RepartitionAttrs" +features = [ + "eq", + "ord", + "hash", + "json", + "rapidcheck", + "fmt", +] + +includes = [ + "op-attrs/ff_dim.h", + "op-attrs/ff_dim.dtg.h", +] + +[[fields]] +name = "repartition_dim" +type = "::FlexFlow::ff_dim_t" + +[[fields]] +name = "repartition_degree" +type = "int" diff --git a/lib/op-attrs/include/op-attrs/ops/replicate.h b/lib/op-attrs/include/op-attrs/ops/replicate.h index 92e64a4120..4c46bf88a9 100644 --- a/lib/op-attrs/include/op-attrs/ops/replicate.h +++ b/lib/op-attrs/include/op-attrs/ops/replicate.h @@ -2,19 +2,16 @@ #define _FLEXFLOW_REPLICATE_ATTRS_H #include "core.h" -#include "op-attrs/ff_dim.h" -#include "op-attrs/parallel_tensor_shape.h" -#include "utils/visitable.h" +#include "op-attrs/ops/replicate_attrs.dtg.h" +#include "op-attrs/parallel_tensor_shape.dtg.h" namespace FlexFlow { -struct ReplicateAttrs { - ff_dim_t replicate_dim; - req replicate_degree; -}; -FF_VISITABLE_STRUCT(ReplicateAttrs, replicate_dim, replicate_degree); CHECK_VALID_OP_ATTR(ReplicateAttrs); +ParallelTensorShape get_output_shape(ReplicateAttrs const &attrs, + ParallelTensorShape const &input_shape); + } // 
namespace FlexFlow #endif diff --git a/lib/op-attrs/include/op-attrs/ops/replicate_attrs.dtg.h b/lib/op-attrs/include/op-attrs/ops/replicate_attrs.dtg.h new file mode 100644 index 0000000000..ea3f0d46c7 --- /dev/null +++ b/lib/op-attrs/include/op-attrs/ops/replicate_attrs.dtg.h @@ -0,0 +1,62 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/op-attrs/include/op-attrs/ops/replicate_attrs.struct.toml +/* proj-data +{ + "generated_from": "6d3ad4d10c24dae819ffee4592a72499" +} +*/ + +#ifndef _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_REPLICATE_ATTRS_DTG_H +#define _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_REPLICATE_ATTRS_DTG_H + +#include "fmt/format.h" +#include "nlohmann/json.hpp" +#include "rapidcheck.h" +#include +#include +#include + +namespace FlexFlow { +struct ReplicateAttrs { + ReplicateAttrs() = delete; + ReplicateAttrs(int const &replicate_degree); + + bool operator==(ReplicateAttrs const &) const; + bool operator!=(ReplicateAttrs const &) const; + bool operator<(ReplicateAttrs const &) const; + bool operator>(ReplicateAttrs const &) const; + bool operator<=(ReplicateAttrs const &) const; + bool operator>=(ReplicateAttrs const &) const; + int replicate_degree; +}; +} // namespace FlexFlow + +namespace std { +template <> +struct hash { + size_t operator()(FlexFlow::ReplicateAttrs const &) const; +}; +} // namespace std + +namespace nlohmann { +template <> +struct adl_serializer { + static FlexFlow::ReplicateAttrs from_json(json const &); + static void to_json(json &, FlexFlow::ReplicateAttrs const &); +}; +} // namespace nlohmann + +namespace rc { +template <> +struct Arbitrary { + static Gen arbitrary(); +}; +} // namespace rc + +namespace FlexFlow { +std::string format_as(ReplicateAttrs const &); +std::ostream &operator<<(std::ostream &, ReplicateAttrs const &); +} // namespace FlexFlow + +#endif // _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_REPLICATE_ATTRS_DTG_H diff --git a/lib/op-attrs/include/op-attrs/ops/replicate_attrs.struct.toml b/lib/op-attrs/include/op-attrs/ops/replicate_attrs.struct.toml new file mode 100644 index 0000000000..4e43ea747a --- /dev/null +++ b/lib/op-attrs/include/op-attrs/ops/replicate_attrs.struct.toml @@ -0,0 +1,16 @@ +namespace = "FlexFlow" +name = "ReplicateAttrs" +features = [ + "eq", + "ord", + "hash", + "json", + "rapidcheck", + "fmt", +] + +includes = [ ] + +[[fields]] +name = "replicate_degree" +type = "int" diff --git a/lib/op-attrs/include/op-attrs/ops/reshape.h b/lib/op-attrs/include/op-attrs/ops/reshape.h index b118482a2b..cd2ca80c3a 100644 --- a/lib/op-attrs/include/op-attrs/ops/reshape.h +++ b/lib/op-attrs/include/op-attrs/ops/reshape.h @@ -2,17 +2,16 @@ #define _FLEXFLOW_RESHAPE_ATTRS_H #include "core.h" -#include "op-attrs/tensor_shape.h" -#include "utils/visitable.h" +#include "op-attrs/ops/reshape_attrs.dtg.h" +#include "op-attrs/parallel_tensor_shape.dtg.h" namespace FlexFlow { -struct ReshapeAttrs { - TensorShape shape; -}; -FF_VISITABLE_STRUCT(ReshapeAttrs, shape); CHECK_VALID_OP_ATTR(ReshapeAttrs); +ParallelTensorShape get_output_shape(ReshapeAttrs const &attrs, + ParallelTensorShape const &input_shape); + } // namespace FlexFlow #endif diff --git a/lib/op-attrs/include/op-attrs/ops/reshape_attrs.dtg.h b/lib/op-attrs/include/op-attrs/ops/reshape_attrs.dtg.h new file mode 100644 index 0000000000..612874790f --- /dev/null +++ b/lib/op-attrs/include/op-attrs/ops/reshape_attrs.dtg.h @@ -0,0 +1,63 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. 
DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/op-attrs/include/op-attrs/ops/reshape_attrs.struct.toml +/* proj-data +{ + "generated_from": "015d04de0ccb982e7eaa013a842880ca" +} +*/ + +#ifndef _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_RESHAPE_ATTRS_DTG_H +#define _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_RESHAPE_ATTRS_DTG_H + +#include "fmt/format.h" +#include "nlohmann/json.hpp" +#include "op-attrs/tensor_shape.dtg.h" +#include "rapidcheck.h" +#include +#include +#include + +namespace FlexFlow { +struct ReshapeAttrs { + ReshapeAttrs() = delete; + ReshapeAttrs(::FlexFlow::TensorShape const &shape); + + bool operator==(ReshapeAttrs const &) const; + bool operator!=(ReshapeAttrs const &) const; + bool operator<(ReshapeAttrs const &) const; + bool operator>(ReshapeAttrs const &) const; + bool operator<=(ReshapeAttrs const &) const; + bool operator>=(ReshapeAttrs const &) const; + ::FlexFlow::TensorShape shape; +}; +} // namespace FlexFlow + +namespace std { +template <> +struct hash { + size_t operator()(FlexFlow::ReshapeAttrs const &) const; +}; +} // namespace std + +namespace nlohmann { +template <> +struct adl_serializer { + static FlexFlow::ReshapeAttrs from_json(json const &); + static void to_json(json &, FlexFlow::ReshapeAttrs const &); +}; +} // namespace nlohmann + +namespace rc { +template <> +struct Arbitrary { + static Gen arbitrary(); +}; +} // namespace rc + +namespace FlexFlow { +std::string format_as(ReshapeAttrs const &); +std::ostream &operator<<(std::ostream &, ReshapeAttrs const &); +} // namespace FlexFlow + +#endif // _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_RESHAPE_ATTRS_DTG_H diff --git a/lib/op-attrs/include/op-attrs/ops/reshape_attrs.struct.toml b/lib/op-attrs/include/op-attrs/ops/reshape_attrs.struct.toml new file mode 100644 index 0000000000..69ac761859 --- /dev/null +++ b/lib/op-attrs/include/op-attrs/ops/reshape_attrs.struct.toml @@ -0,0 +1,18 @@ +namespace = "FlexFlow" +name = "ReshapeAttrs" +features = [ + "eq", + "ord", + "hash", + "json", + "rapidcheck", + "fmt", +] + +includes = [ + "op-attrs/tensor_shape.dtg.h", +] + +[[fields]] +name = "shape" +type = "::FlexFlow::TensorShape" diff --git a/lib/op-attrs/include/op-attrs/ops/reverse.h b/lib/op-attrs/include/op-attrs/ops/reverse.h index 6030285f14..adc62dc9ae 100644 --- a/lib/op-attrs/include/op-attrs/ops/reverse.h +++ b/lib/op-attrs/include/op-attrs/ops/reverse.h @@ -2,17 +2,16 @@ #define _FLEXFLOW_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_REVERSE_H #include "core.h" -#include "op-attrs/ff_dim.h" -#include "utils/visitable.h" +#include "op-attrs/ops/reverse_attrs.dtg.h" +#include "op-attrs/parallel_tensor_shape.dtg.h" namespace FlexFlow { -struct ReverseAttrs { - ff_dim_t axis; -}; -FF_VISITABLE_STRUCT(ReverseAttrs, axis); CHECK_VALID_OP_ATTR(ReverseAttrs); +ParallelTensorShape get_output_shape(ReverseAttrs const &attrs, + ParallelTensorShape const &input_shape); + } // namespace FlexFlow #endif diff --git a/lib/op-attrs/include/op-attrs/ops/reverse_attrs.dtg.h b/lib/op-attrs/include/op-attrs/ops/reverse_attrs.dtg.h new file mode 100644 index 0000000000..8c8c8a7a9e --- /dev/null +++ b/lib/op-attrs/include/op-attrs/ops/reverse_attrs.dtg.h @@ -0,0 +1,64 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! 
+// If you would like to modify this datatype, instead modify +// lib/op-attrs/include/op-attrs/ops/reverse_attrs.struct.toml +/* proj-data +{ + "generated_from": "c5a82c8a15ac3ce6f47dc054236ab69b" +} +*/ + +#ifndef _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_REVERSE_ATTRS_DTG_H +#define _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_REVERSE_ATTRS_DTG_H + +#include "fmt/format.h" +#include "nlohmann/json.hpp" +#include "op-attrs/ff_dim.dtg.h" +#include "op-attrs/ff_dim.h" +#include "rapidcheck.h" +#include +#include +#include + +namespace FlexFlow { +struct ReverseAttrs { + ReverseAttrs() = delete; + ReverseAttrs(::FlexFlow::ff_dim_t const &axis); + + bool operator==(ReverseAttrs const &) const; + bool operator!=(ReverseAttrs const &) const; + bool operator<(ReverseAttrs const &) const; + bool operator>(ReverseAttrs const &) const; + bool operator<=(ReverseAttrs const &) const; + bool operator>=(ReverseAttrs const &) const; + ::FlexFlow::ff_dim_t axis; +}; +} // namespace FlexFlow + +namespace std { +template <> +struct hash { + size_t operator()(FlexFlow::ReverseAttrs const &) const; +}; +} // namespace std + +namespace nlohmann { +template <> +struct adl_serializer { + static FlexFlow::ReverseAttrs from_json(json const &); + static void to_json(json &, FlexFlow::ReverseAttrs const &); +}; +} // namespace nlohmann + +namespace rc { +template <> +struct Arbitrary { + static Gen arbitrary(); +}; +} // namespace rc + +namespace FlexFlow { +std::string format_as(ReverseAttrs const &); +std::ostream &operator<<(std::ostream &, ReverseAttrs const &); +} // namespace FlexFlow + +#endif // _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_REVERSE_ATTRS_DTG_H diff --git a/lib/op-attrs/include/op-attrs/ops/reverse_attrs.struct.toml b/lib/op-attrs/include/op-attrs/ops/reverse_attrs.struct.toml new file mode 100644 index 0000000000..198346e5dd --- /dev/null +++ b/lib/op-attrs/include/op-attrs/ops/reverse_attrs.struct.toml @@ -0,0 +1,19 @@ +namespace = "FlexFlow" +name = "ReverseAttrs" +features = [ + "eq", + "ord", + "hash", + "json", + "rapidcheck", + "fmt", +] + +includes = [ + "op-attrs/ff_dim.h", + "op-attrs/ff_dim.dtg.h", +] + +[[fields]] +name = "axis" +type = "::FlexFlow::ff_dim_t" diff --git a/lib/op-attrs/include/op-attrs/ops/softmax.h b/lib/op-attrs/include/op-attrs/ops/softmax.h index 9a776737f5..d855716cfb 100644 --- a/lib/op-attrs/include/op-attrs/ops/softmax.h +++ b/lib/op-attrs/include/op-attrs/ops/softmax.h @@ -2,18 +2,16 @@ #define _FLEXFLOW_SOFTMAX_ATTRS_H #include "core.h" -#include "op-attrs/ff_dim.h" -#include "op-attrs/parallel_tensor_shape.h" -#include "utils/visitable.h" +#include "op-attrs/ops/softmax_attrs.dtg.h" +#include "op-attrs/parallel_tensor_shape.dtg.h" namespace FlexFlow { -struct SoftmaxAttrs { - ff_dim_t dim; -}; -FF_VISITABLE_STRUCT(SoftmaxAttrs, dim); CHECK_VALID_OP_ATTR(SoftmaxAttrs); +ParallelTensorShape get_output_shape(SoftmaxAttrs const &attrs, + ParallelTensorShape const &input_shape); + } // namespace FlexFlow #endif diff --git a/lib/op-attrs/include/op-attrs/ops/softmax_attrs.dtg.h b/lib/op-attrs/include/op-attrs/ops/softmax_attrs.dtg.h new file mode 100644 index 0000000000..1c855d90f4 --- /dev/null +++ b/lib/op-attrs/include/op-attrs/ops/softmax_attrs.dtg.h @@ -0,0 +1,64 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! 
+// If you would like to modify this datatype, instead modify
+// lib/op-attrs/include/op-attrs/ops/softmax_attrs.struct.toml
+/* proj-data
+{
+  "generated_from": "2ddf5a8b7daa32a43387f5fd5866bb3b"
+}
+*/
+
+#ifndef _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_SOFTMAX_ATTRS_DTG_H
+#define _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_SOFTMAX_ATTRS_DTG_H
+
+#include "fmt/format.h"
+#include "nlohmann/json.hpp"
+#include "op-attrs/ff_dim.dtg.h"
+#include "op-attrs/ff_dim.h"
+#include "rapidcheck.h"
+#include <functional>
+#include <ostream>
+#include <tuple>
+
+namespace FlexFlow {
+struct SoftmaxAttrs {
+  SoftmaxAttrs() = delete;
+  SoftmaxAttrs(::FlexFlow::ff_dim_t const &dim);
+
+  bool operator==(SoftmaxAttrs const &) const;
+  bool operator!=(SoftmaxAttrs const &) const;
+  bool operator<(SoftmaxAttrs const &) const;
+  bool operator>(SoftmaxAttrs const &) const;
+  bool operator<=(SoftmaxAttrs const &) const;
+  bool operator>=(SoftmaxAttrs const &) const;
+  ::FlexFlow::ff_dim_t dim;
+};
+} // namespace FlexFlow
+
+namespace std {
+template <>
+struct hash<FlexFlow::SoftmaxAttrs> {
+  size_t operator()(FlexFlow::SoftmaxAttrs const &) const;
+};
+} // namespace std
+
+namespace nlohmann {
+template <>
+struct adl_serializer<FlexFlow::SoftmaxAttrs> {
+  static FlexFlow::SoftmaxAttrs from_json(json const &);
+  static void to_json(json &, FlexFlow::SoftmaxAttrs const &);
+};
+} // namespace nlohmann
+
+namespace rc {
+template <>
+struct Arbitrary<FlexFlow::SoftmaxAttrs> {
+  static Gen<FlexFlow::SoftmaxAttrs> arbitrary();
+};
+} // namespace rc
+
+namespace FlexFlow {
+std::string format_as(SoftmaxAttrs const &);
+std::ostream &operator<<(std::ostream &, SoftmaxAttrs const &);
+} // namespace FlexFlow
+
+#endif // _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_SOFTMAX_ATTRS_DTG_H
diff --git a/lib/op-attrs/include/op-attrs/ops/softmax_attrs.struct.toml b/lib/op-attrs/include/op-attrs/ops/softmax_attrs.struct.toml
new file mode 100644
index 0000000000..8b839c122a
--- /dev/null
+++ b/lib/op-attrs/include/op-attrs/ops/softmax_attrs.struct.toml
@@ -0,0 +1,19 @@
+namespace = "FlexFlow"
+name = "SoftmaxAttrs"
+features = [
+  "eq",
+  "ord",
+  "hash",
+  "json",
+  "rapidcheck",
+  "fmt",
+]
+
+includes = [
+  "op-attrs/ff_dim.h",
+  "op-attrs/ff_dim.dtg.h",
+]
+
+[[fields]]
+name = "dim"
+type = "::FlexFlow::ff_dim_t"
diff --git a/lib/op-attrs/include/op-attrs/ops/split.h b/lib/op-attrs/include/op-attrs/ops/split.h
index fa66bc46f5..8fc2257760 100644
--- a/lib/op-attrs/include/op-attrs/ops/split.h
+++ b/lib/op-attrs/include/op-attrs/ops/split.h
@@ -2,19 +2,18 @@
 #define _FLEXFLOW_SPLIT_ATTRS_H
 
 #include "core.h"
-#include "op-attrs/ff_dim.h"
-#include "op-attrs/parallel_tensor_shape.h"
-#include "utils/visitable.h"
+#include "op-attrs/ops/split_attrs.dtg.h"
+#include "op-attrs/parallel_tensor_shape.dtg.h"
+#include <vector>
 
 namespace FlexFlow {
 
-struct SplitAttrs {
-  req<stack_vector<int, MAX_NUM_OUTPUTS>> splits;
-  ff_dim_t axis;
-};
-FF_VISITABLE_STRUCT(SplitAttrs, splits, axis);
 CHECK_VALID_OP_ATTR(SplitAttrs);
 
+std::vector<ParallelTensorShape>
+    get_output_shapes(SplitAttrs const &attrs,
+                      ParallelTensorShape const &input_shape);
+
 } // namespace FlexFlow
 
 #endif
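Split is the one operator in this patch whose inference yields several results, one per output, hence `std::vector<ParallelTensorShape>` rather than a single shape. A sketch of consuming it (illustrative only; `print_split_outputs` is not part of the patch, and streamability of `ParallelTensorShape` is assumed):

```cpp
#include "op-attrs/ops/split.h"
#include <cstddef>
#include <iostream>

void print_split_outputs(FlexFlow::SplitAttrs const &attrs,
                         FlexFlow::ParallelTensorShape const &input) {
  std::vector<FlexFlow::ParallelTensorShape> outputs =
      get_output_shapes(attrs, input);
  for (std::size_t i = 0; i < outputs.size(); i++) {
    std::cout << "output " << i << ": " << outputs[i] << "\n";
  }
}
```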
diff --git a/lib/op-attrs/include/op-attrs/ops/split_attrs.dtg.h b/lib/op-attrs/include/op-attrs/ops/split_attrs.dtg.h
new file mode 100644
index 0000000000..b602015e2e
--- /dev/null
+++ b/lib/op-attrs/include/op-attrs/ops/split_attrs.dtg.h
@@ -0,0 +1,67 @@
+// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT!
+// If you would like to modify this datatype, instead modify
+// lib/op-attrs/include/op-attrs/ops/split_attrs.struct.toml
+/* proj-data
+{
+  "generated_from": "cde6b5caf6739d3b02fe8fce0d8ae8c5"
+}
+*/
+
+#ifndef _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_SPLIT_ATTRS_DTG_H
+#define _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_SPLIT_ATTRS_DTG_H
+
+#include "fmt/format.h"
+#include "nlohmann/json.hpp"
+#include "op-attrs/ff_dim.dtg.h"
+#include "op-attrs/ff_dim.h"
+#include "rapidcheck.h"
+#include "utils/stack_vector.h"
+#include <functional>
+#include <ostream>
+#include <tuple>
+
+namespace FlexFlow {
+struct SplitAttrs {
+  SplitAttrs() = delete;
+  SplitAttrs(::FlexFlow::stack_vector<int, MAX_NUM_OUTPUTS> const &splits,
+             ::FlexFlow::ff_dim_t const &axis);
+
+  bool operator==(SplitAttrs const &) const;
+  bool operator!=(SplitAttrs const &) const;
+  bool operator<(SplitAttrs const &) const;
+  bool operator>(SplitAttrs const &) const;
+  bool operator<=(SplitAttrs const &) const;
+  bool operator>=(SplitAttrs const &) const;
+  ::FlexFlow::stack_vector<int, MAX_NUM_OUTPUTS> splits;
+  ::FlexFlow::ff_dim_t axis;
+};
+} // namespace FlexFlow
+
+namespace std {
+template <>
+struct hash<FlexFlow::SplitAttrs> {
+  size_t operator()(FlexFlow::SplitAttrs const &) const;
+};
+} // namespace std
+
+namespace nlohmann {
+template <>
+struct adl_serializer<FlexFlow::SplitAttrs> {
+  static FlexFlow::SplitAttrs from_json(json const &);
+  static void to_json(json &, FlexFlow::SplitAttrs const &);
+};
+} // namespace nlohmann
+
+namespace rc {
+template <>
+struct Arbitrary<FlexFlow::SplitAttrs> {
+  static Gen<FlexFlow::SplitAttrs> arbitrary();
+};
+} // namespace rc
+
+namespace FlexFlow {
+std::string format_as(SplitAttrs const &);
+std::ostream &operator<<(std::ostream &, SplitAttrs const &);
+} // namespace FlexFlow
+
+#endif // _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_SPLIT_ATTRS_DTG_H
diff --git a/lib/op-attrs/include/op-attrs/ops/split_attrs.struct.toml b/lib/op-attrs/include/op-attrs/ops/split_attrs.struct.toml
new file mode 100644
index 0000000000..8cdf7728af
--- /dev/null
+++ b/lib/op-attrs/include/op-attrs/ops/split_attrs.struct.toml
@@ -0,0 +1,24 @@
+namespace = "FlexFlow"
+name = "SplitAttrs"
+features = [
+  "eq",
+  "ord",
+  "hash",
+  "json",
+  "rapidcheck",
+  "fmt",
+]
+
+includes = [
+  "utils/stack_vector.h",
+  "op-attrs/ff_dim.h",
+  "op-attrs/ff_dim.dtg.h",
+]
+
+[[fields]]
+name = "splits"
+type = "::FlexFlow::stack_vector<int, MAX_NUM_OUTPUTS>"
+
+[[fields]]
+name = "axis"
+type = "::FlexFlow::ff_dim_t"
diff --git a/lib/op-attrs/include/op-attrs/ops/topk.h b/lib/op-attrs/include/op-attrs/ops/topk.h
index 413855913c..c6af40dd48 100644
--- a/lib/op-attrs/include/op-attrs/ops/topk.h
+++ b/lib/op-attrs/include/op-attrs/ops/topk.h
@@ -2,18 +2,16 @@
 #define _FLEXFLOW_TOPK_ATTRS_H
 
 #include "core.h"
-#include "op-attrs/parallel_tensor_shape.h"
-#include "utils/visitable.h"
+#include "op-attrs/ops/topk_attrs.dtg.h"
+#include "op-attrs/parallel_tensor_shape.dtg.h"
 
 namespace FlexFlow {
 
-struct TopKAttrs {
-  req<int> k;
-  req<bool> sorted;
-};
-FF_VISITABLE_STRUCT(TopKAttrs, k, sorted);
 CHECK_VALID_OP_ATTR(TopKAttrs);
 
+ParallelTensorShape get_output_shape(TopKAttrs const &attrs,
+                                     ParallelTensorShape const &input_shape);
+
 } // namespace FlexFlow
 
 #endif
diff --git a/lib/op-attrs/include/op-attrs/ops/topk_attrs.dtg.h b/lib/op-attrs/include/op-attrs/ops/topk_attrs.dtg.h
new file mode 100644
index 0000000000..d1f32f67b7
--- /dev/null
+++ b/lib/op-attrs/include/op-attrs/ops/topk_attrs.dtg.h
@@ -0,0 +1,63 @@
+// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT!
+// If you would like to modify this datatype, instead modify +// lib/op-attrs/include/op-attrs/ops/topk_attrs.struct.toml +/* proj-data +{ + "generated_from": "c1be9dc2acafc58690713e650663cc93" +} +*/ + +#ifndef _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_TOPK_ATTRS_DTG_H +#define _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_TOPK_ATTRS_DTG_H + +#include "fmt/format.h" +#include "nlohmann/json.hpp" +#include "rapidcheck.h" +#include +#include +#include + +namespace FlexFlow { +struct TopKAttrs { + TopKAttrs() = delete; + TopKAttrs(int const &k, bool const &sorted); + + bool operator==(TopKAttrs const &) const; + bool operator!=(TopKAttrs const &) const; + bool operator<(TopKAttrs const &) const; + bool operator>(TopKAttrs const &) const; + bool operator<=(TopKAttrs const &) const; + bool operator>=(TopKAttrs const &) const; + int k; + bool sorted; +}; +} // namespace FlexFlow + +namespace std { +template <> +struct hash { + size_t operator()(FlexFlow::TopKAttrs const &) const; +}; +} // namespace std + +namespace nlohmann { +template <> +struct adl_serializer { + static FlexFlow::TopKAttrs from_json(json const &); + static void to_json(json &, FlexFlow::TopKAttrs const &); +}; +} // namespace nlohmann + +namespace rc { +template <> +struct Arbitrary { + static Gen arbitrary(); +}; +} // namespace rc + +namespace FlexFlow { +std::string format_as(TopKAttrs const &); +std::ostream &operator<<(std::ostream &, TopKAttrs const &); +} // namespace FlexFlow + +#endif // _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_TOPK_ATTRS_DTG_H diff --git a/lib/op-attrs/include/op-attrs/ops/topk_attrs.struct.toml b/lib/op-attrs/include/op-attrs/ops/topk_attrs.struct.toml new file mode 100644 index 0000000000..9ecbf1d725 --- /dev/null +++ b/lib/op-attrs/include/op-attrs/ops/topk_attrs.struct.toml @@ -0,0 +1,18 @@ +namespace = "FlexFlow" +name = "TopKAttrs" +features = [ + "eq", + "ord", + "hash", + "json", + "rapidcheck", + "fmt", +] + +[[fields]] +name = "k" +type = "int" + +[[fields]] +name = "sorted" +type = "bool" diff --git a/lib/op-attrs/include/op-attrs/ops/transpose.h b/lib/op-attrs/include/op-attrs/ops/transpose.h index 87db435979..6e23d91d78 100644 --- a/lib/op-attrs/include/op-attrs/ops/transpose.h +++ b/lib/op-attrs/include/op-attrs/ops/transpose.h @@ -2,18 +2,16 @@ #define _FLEXFLOW_OP_META_OPS_TRANSPOSE_ATTRS_H #include "core.h" -#include "op-attrs/ff_dim.h" -#include "op-attrs/parallel_tensor_shape.h" -#include "utils/visitable.h" +#include "op-attrs/ops/transpose_attrs.dtg.h" +#include "op-attrs/parallel_tensor_shape.dtg.h" namespace FlexFlow { -struct TransposeAttrs { - req> perm; -}; -FF_VISITABLE_STRUCT(TransposeAttrs, perm); CHECK_VALID_OP_ATTR(TransposeAttrs); +ParallelTensorShape get_output_shape(TransposeAttrs const &op_attrs, + ParallelTensorShape const &input_shape); + } // namespace FlexFlow #endif diff --git a/lib/op-attrs/include/op-attrs/ops/transpose_attrs.dtg.h b/lib/op-attrs/include/op-attrs/ops/transpose_attrs.dtg.h new file mode 100644 index 0000000000..f4d932845f --- /dev/null +++ b/lib/op-attrs/include/op-attrs/ops/transpose_attrs.dtg.h @@ -0,0 +1,65 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! 
+// If you would like to modify this datatype, instead modify
+// lib/op-attrs/include/op-attrs/ops/transpose_attrs.struct.toml
+/* proj-data
+{
+  "generated_from": "de62a505821a59c4b77197c100e204f7"
+}
+*/
+
+#ifndef _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_TRANSPOSE_ATTRS_DTG_H
+#define _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_TRANSPOSE_ATTRS_DTG_H
+
+#include "fmt/format.h"
+#include "nlohmann/json.hpp"
+#include "op-attrs/dim_ordered.h"
+#include "op-attrs/ff_dim.dtg.h"
+#include "op-attrs/ff_dim.h"
+#include "rapidcheck.h"
+#include <functional>
+#include <ostream>
+#include <tuple>
+
+namespace FlexFlow {
+struct TransposeAttrs {
+  TransposeAttrs() = delete;
+  TransposeAttrs(::FlexFlow::FFOrdered<::FlexFlow::ff_dim_t> const &perm);
+
+  bool operator==(TransposeAttrs const &) const;
+  bool operator!=(TransposeAttrs const &) const;
+  bool operator<(TransposeAttrs const &) const;
+  bool operator>(TransposeAttrs const &) const;
+  bool operator<=(TransposeAttrs const &) const;
+  bool operator>=(TransposeAttrs const &) const;
+  ::FlexFlow::FFOrdered<::FlexFlow::ff_dim_t> perm;
+};
+} // namespace FlexFlow
+
+namespace std {
+template <>
+struct hash<FlexFlow::TransposeAttrs> {
+  size_t operator()(FlexFlow::TransposeAttrs const &) const;
+};
+} // namespace std
+
+namespace nlohmann {
+template <>
+struct adl_serializer<FlexFlow::TransposeAttrs> {
+  static FlexFlow::TransposeAttrs from_json(json const &);
+  static void to_json(json &, FlexFlow::TransposeAttrs const &);
+};
+} // namespace nlohmann
+
+namespace rc {
+template <>
+struct Arbitrary<FlexFlow::TransposeAttrs> {
+  static Gen<FlexFlow::TransposeAttrs> arbitrary();
+};
+} // namespace rc
+
+namespace FlexFlow {
+std::string format_as(TransposeAttrs const &);
+std::ostream &operator<<(std::ostream &, TransposeAttrs const &);
+} // namespace FlexFlow
+
+#endif // _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_TRANSPOSE_ATTRS_DTG_H
diff --git a/lib/op-attrs/include/op-attrs/ops/transpose_attrs.struct.toml b/lib/op-attrs/include/op-attrs/ops/transpose_attrs.struct.toml
new file mode 100644
index 0000000000..756091f653
--- /dev/null
+++ b/lib/op-attrs/include/op-attrs/ops/transpose_attrs.struct.toml
@@ -0,0 +1,20 @@
+namespace = "FlexFlow"
+name = "TransposeAttrs"
+features = [
+  "eq",
+  "ord",
+  "hash",
+  "json",
+  "rapidcheck",
+  "fmt",
+]
+
+includes = [
+  "op-attrs/ff_dim.h",
+  "op-attrs/ff_dim.dtg.h",
+  "op-attrs/dim_ordered.h",
+]
+
+[[fields]]
+name = "perm"
+type = "::FlexFlow::FFOrdered<::FlexFlow::ff_dim_t>"
diff --git a/lib/op-attrs/include/op-attrs/ops/weight_attrs.dtg.h b/lib/op-attrs/include/op-attrs/ops/weight_attrs.dtg.h
new file mode 100644
index 0000000000..4a19909c25
--- /dev/null
+++ b/lib/op-attrs/include/op-attrs/ops/weight_attrs.dtg.h
@@ -0,0 +1,58 @@
+// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT!
+// If you would like to modify this datatype, instead modify
+// lib/op-attrs/include/op-attrs/ops/weight_attrs.struct.toml
+/* proj-data
+{
+  "generated_from": "59f49374ffca95b2117b8940af1b6cac"
+}
+*/
+
+#ifndef _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_WEIGHT_ATTRS_DTG_H
+#define _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_WEIGHT_ATTRS_DTG_H
+
+#include "fmt/format.h"
+#include "nlohmann/json.hpp"
+#include "rapidcheck.h"
+#include <functional>
+#include <ostream>
+#include <tuple>
+
+namespace FlexFlow {
+struct WeightAttrs {
+  bool operator==(WeightAttrs const &) const;
+  bool operator!=(WeightAttrs const &) const;
+  bool operator<(WeightAttrs const &) const;
+  bool operator>(WeightAttrs const &) const;
+  bool operator<=(WeightAttrs const &) const;
+  bool operator>=(WeightAttrs const &) const;
+};
+} // namespace FlexFlow
+
+namespace std {
+template <>
+struct hash<FlexFlow::WeightAttrs> {
+  size_t operator()(FlexFlow::WeightAttrs const &) const;
+};
+} // namespace std
+
+namespace nlohmann {
+template <>
+struct adl_serializer<FlexFlow::WeightAttrs> {
+  static FlexFlow::WeightAttrs from_json(json const &);
+  static void to_json(json &, FlexFlow::WeightAttrs const &);
+};
+} // namespace nlohmann
+
+namespace rc {
+template <>
+struct Arbitrary<FlexFlow::WeightAttrs> {
+  static Gen<FlexFlow::WeightAttrs> arbitrary();
+};
+} // namespace rc
+
+namespace FlexFlow {
+std::string format_as(WeightAttrs const &);
+std::ostream &operator<<(std::ostream &, WeightAttrs const &);
+} // namespace FlexFlow
+
+#endif // _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_WEIGHT_ATTRS_DTG_H
diff --git a/lib/op-attrs/include/op-attrs/ops/weight_attrs.struct.toml b/lib/op-attrs/include/op-attrs/ops/weight_attrs.struct.toml
new file mode 100644
index 0000000000..28810a437e
--- /dev/null
+++ b/lib/op-attrs/include/op-attrs/ops/weight_attrs.struct.toml
@@ -0,0 +1,11 @@
+namespace = "FlexFlow"
+name = "WeightAttrs"
+features = [
+  "eq",
+  "ord",
+  "hash",
+  "json",
+  "rapidcheck",
+  "fmt",
+]
+fields = []
diff --git a/lib/op-attrs/include/op-attrs/parallel_dim.dtg.h b/lib/op-attrs/include/op-attrs/parallel_dim.dtg.h
new file mode 100644
index 0000000000..4115d4ce1f
--- /dev/null
+++ b/lib/op-attrs/include/op-attrs/parallel_dim.dtg.h
@@ -0,0 +1,128 @@
+// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT!
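+//
+// Usage sketch (illustrative only): ParallelDim is a closed variant over
+// ShardParallelDim and ReplicaParallelDim, so dispatch looks like
+//   ParallelDim dim{ShardParallelDim{/*size=*/8, /*degree=*/2}};
+//   if (dim.has<ShardParallelDim>()) {
+//     int degree = dim.get<ShardParallelDim>().degree;
+//   }
+//   int d = dim.visit<int>([](auto const &x) { return x.degree; });
+// The ShardParallelDim field names above are assumptions; its struct.toml
+// is not part of this diff.
+//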
+// If you would like to modify this datatype, instead modify
+// lib/op-attrs/include/op-attrs/parallel_dim.variant.toml
+/* proj-data
+{
+  "generated_from": "f382ff547aae62777e5091f00d034d84"
+}
+*/
+
+#ifndef _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_PARALLEL_DIM_DTG_H
+#define _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_PARALLEL_DIM_DTG_H
+
+#include "fmt/format.h"
+#include "nlohmann/json.hpp"
+#include "op-attrs/replica_parallel_dim.dtg.h"
+#include "op-attrs/shard_parallel_dim.dtg.h"
+#include "rapidcheck.h"
+#include <cstddef>
+#include <functional>
+#include <ostream>
+#include <type_traits>
+#include <variant>
+
+namespace FlexFlow {
+struct ParallelDim {
+  ParallelDim() = delete;
+  explicit ParallelDim(::FlexFlow::ShardParallelDim const &);
+  explicit ParallelDim(::FlexFlow::ReplicaParallelDim const &);
+  template <typename T>
+  static constexpr bool IsPartOfParallelDim_v =
+      std::is_same_v<T, ::FlexFlow::ShardParallelDim> ||
+      std::is_same_v<T, ::FlexFlow::ReplicaParallelDim>;
+  template <typename ReturnType, typename Visitor>
+  ReturnType visit(Visitor &&v) const {
+    switch (this->index()) {
+      case 0: {
+        ReturnType result = v(this->get<::FlexFlow::ShardParallelDim>());
+        return result;
+      }
+      case 1: {
+        ReturnType result = v(this->get<::FlexFlow::ReplicaParallelDim>());
+        return result;
+      }
+      default: {
+        throw std::runtime_error(fmt::format(
+            "Unknown index {} for type ParallelDim", this->index()));
+      }
+    }
+  }
+  template <typename ReturnType, typename Visitor>
+  ReturnType visit(Visitor &&v) {
+    switch (this->index()) {
+      case 0: {
+        ReturnType result = v(this->get<::FlexFlow::ShardParallelDim>());
+        return result;
+      }
+      case 1: {
+        ReturnType result = v(this->get<::FlexFlow::ReplicaParallelDim>());
+        return result;
+      }
+      default: {
+        throw std::runtime_error(fmt::format(
+            "Unknown index {} for type ParallelDim", this->index()));
+      }
+    }
+  }
+  template <typename T>
+  bool has() const {
+    static_assert(
+        IsPartOfParallelDim_v<T>,
+        "ParallelDim::has() expected one of [::FlexFlow::ShardParallelDim, "
+        "::FlexFlow::ReplicaParallelDim], received T");
+    return std::holds_alternative<T>(this->raw_variant);
+  }
+  template <typename T>
+  T const &get() const {
+    static_assert(
+        IsPartOfParallelDim_v<T>,
+        "ParallelDim::get() expected one of [::FlexFlow::ShardParallelDim, "
+        "::FlexFlow::ReplicaParallelDim], received T");
+    return std::get<T>(this->raw_variant);
+  }
+  template <typename T>
+  T &get() {
+    static_assert(
+        IsPartOfParallelDim_v<T>,
+        "ParallelDim::get() expected one of [::FlexFlow::ShardParallelDim, "
+        "::FlexFlow::ReplicaParallelDim], received T");
+    return std::get<T>(this->raw_variant);
+  }
+  size_t index() const {
+    return this->raw_variant.index();
+  }
+  bool operator==(ParallelDim const &) const;
+  bool operator!=(ParallelDim const &) const;
+  bool operator<(ParallelDim const &) const;
+  bool operator>(ParallelDim const &) const;
+  bool operator<=(ParallelDim const &) const;
+  bool operator>=(ParallelDim const &) const;
+  std::variant<::FlexFlow::ShardParallelDim, ::FlexFlow::ReplicaParallelDim>
+      raw_variant;
+};
+} // namespace FlexFlow
+namespace std {
+template <>
+struct hash<::FlexFlow::ParallelDim> {
+  size_t operator()(::FlexFlow::ParallelDim const &) const;
+};
+} // namespace std
+namespace nlohmann {
+template <>
+struct adl_serializer<::FlexFlow::ParallelDim> {
+  static ::FlexFlow::ParallelDim from_json(json const &);
+  static void to_json(json &, ::FlexFlow::ParallelDim const &);
+};
+} // namespace nlohmann
+namespace rc {
+template <>
+struct Arbitrary<::FlexFlow::ParallelDim> {
+  static Gen<::FlexFlow::ParallelDim> arbitrary();
+};
+} // namespace rc
+namespace FlexFlow {
+std::string format_as(::FlexFlow::ParallelDim const &);
+std::ostream &operator<<(std::ostream &, ::FlexFlow::ParallelDim const &);
+} // namespace FlexFlow
+
+#endif // _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_PARALLEL_DIM_DTG_H
diff --git a/lib/op-attrs/include/op-attrs/parallel_dim.h b/lib/op-attrs/include/op-attrs/parallel_dim.h
index 9d407ec469..5397ad7c68 100644
--- a/lib/op-attrs/include/op-attrs/parallel_dim.h
+++ b/lib/op-attrs/include/op-attrs/parallel_dim.h
@@ -1,24 +1,17 @@
 #ifndef _FLEXFLOW_OP_ATTRS_INCLUDE_OP_ATTRS_PARALLEL_DIM_H
 #define _FLEXFLOW_OP_ATTRS_INCLUDE_OP_ATTRS_PARALLEL_DIM_H
 
-#include "utils/type_traits.h"
-#include "utils/visitable.h"
+#include "op-attrs/parallel_dim.dtg.h"
 
 namespace FlexFlow {
 
-struct ParallelDim {
-  size_t size;
-  int degree;
-  req<bool> is_replica_dim;
-};
-FF_VISITABLE_STRUCT_NONSTANDARD_CONSTRUCTION(ParallelDim,
-                                             size,
-                                             degree,
-                                             is_replica_dim);
-
 bool is_valid(ParallelDim const &);
 bool is_replica_dim(ParallelDim const &);
 
+ParallelDim with_size_set_to(ParallelDim const &, size_t);
+ParallelDim with_degree_set_to(ParallelDim const &, int);
+ParallelDim with_is_replica_set_to(ParallelDim const &, bool);
+
 } // namespace FlexFlow
 
 #endif
diff --git a/lib/op-attrs/include/op-attrs/parallel_dim.variant.toml b/lib/op-attrs/include/op-attrs/parallel_dim.variant.toml
new file mode 100644
index 0000000000..e27e6509fe
--- /dev/null
+++ b/lib/op-attrs/include/op-attrs/parallel_dim.variant.toml
@@ -0,0 +1,23 @@
+namespace = "FlexFlow"
+name = "ParallelDim"
+features = [
+  "eq",
+  "ord",
+  "hash",
+  "json",
+  "rapidcheck",
+  "fmt",
+]
+
+includes = [
+  "op-attrs/shard_parallel_dim.dtg.h",
+  "op-attrs/replica_parallel_dim.dtg.h",
+]
+
+[[values]]
+type = "::FlexFlow::ShardParallelDim"
+key = "shard_dim"
+
+[[values]]
+type = "::FlexFlow::ReplicaParallelDim"
+key = "replica_dim"
diff --git a/lib/op-attrs/include/op-attrs/parallel_tensor_dims.dtg.h b/lib/op-attrs/include/op-attrs/parallel_tensor_dims.dtg.h
new file mode 100644
index 0000000000..71ad517095
--- /dev/null
+++ b/lib/op-attrs/include/op-attrs/parallel_tensor_dims.dtg.h
@@ -0,0 +1,71 @@
+// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT!
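+//
+// Usage sketch (illustrative only; the constructor arguments are
+// assumptions drawn from the field list below):
+//   ParallelTensorDims dims{
+//       FFOrdered<ShardParallelDim>{...},   // one entry per shard dim
+//       ReplicaParallelDimSet{SumDegree{1}, DiscardCopyDegree{1}}};
+//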
+// If you would like to modify this datatype, instead modify
+// lib/op-attrs/include/op-attrs/parallel_tensor_dims.struct.toml
+/* proj-data
+{
+  "generated_from": "aec3b6b66e34be0d5ce3055822479430"
+}
+*/
+
+#ifndef _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_PARALLEL_TENSOR_DIMS_DTG_H
+#define _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_PARALLEL_TENSOR_DIMS_DTG_H
+
+#include "fmt/format.h"
+#include "nlohmann/json.hpp"
+#include "op-attrs/dim_ordered.h"
+#include "op-attrs/replica_parallel_dim_set.dtg.h"
+#include "op-attrs/shard_parallel_dim.dtg.h"
+#include "rapidcheck.h"
+#include "utils/fmt/pair.h"
+#include "utils/fmt/unordered_map.h"
+#include <functional>
+#include <ostream>
+#include <tuple>
+#include <unordered_set>
+
+namespace FlexFlow {
+struct ParallelTensorDims {
+  ParallelTensorDims() = delete;
+  ParallelTensorDims(
+      ::FlexFlow::FFOrdered<::FlexFlow::ShardParallelDim> const &shard_dims,
+      ::FlexFlow::ReplicaParallelDimSet const &replica_dims);
+
+  bool operator==(ParallelTensorDims const &) const;
+  bool operator!=(ParallelTensorDims const &) const;
+  bool operator<(ParallelTensorDims const &) const;
+  bool operator>(ParallelTensorDims const &) const;
+  bool operator<=(ParallelTensorDims const &) const;
+  bool operator>=(ParallelTensorDims const &) const;
+  ::FlexFlow::FFOrdered<::FlexFlow::ShardParallelDim> shard_dims;
+  ::FlexFlow::ReplicaParallelDimSet replica_dims;
+};
+} // namespace FlexFlow
+
+namespace std {
+template <>
+struct hash<FlexFlow::ParallelTensorDims> {
+  size_t operator()(FlexFlow::ParallelTensorDims const &) const;
+};
+} // namespace std
+
+namespace nlohmann {
+template <>
+struct adl_serializer<FlexFlow::ParallelTensorDims> {
+  static FlexFlow::ParallelTensorDims from_json(json const &);
+  static void to_json(json &, FlexFlow::ParallelTensorDims const &);
+};
+} // namespace nlohmann
+
+namespace rc {
+template <>
+struct Arbitrary<FlexFlow::ParallelTensorDims> {
+  static Gen<FlexFlow::ParallelTensorDims> arbitrary();
+};
+} // namespace rc
+
+namespace FlexFlow {
+std::string format_as(ParallelTensorDims const &);
+std::ostream &operator<<(std::ostream &, ParallelTensorDims const &);
+} // namespace FlexFlow
+
+#endif // _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_PARALLEL_TENSOR_DIMS_DTG_H
diff --git a/lib/op-attrs/include/op-attrs/parallel_tensor_dims.h b/lib/op-attrs/include/op-attrs/parallel_tensor_dims.h
index d38ba75232..8e02e3607b 100644
--- a/lib/op-attrs/include/op-attrs/parallel_tensor_dims.h
+++ b/lib/op-attrs/include/op-attrs/parallel_tensor_dims.h
@@ -1,53 +1,32 @@
 #ifndef _FLEXFLOW_OP_ATTRS_INCLUDE_OP_ATTRS_PARALLEL_TENSOR_DIMS_H
 #define _FLEXFLOW_OP_ATTRS_INCLUDE_OP_ATTRS_PARALLEL_TENSOR_DIMS_H
 
-#include "parallel_dim.h"
-#include "utils/visitable.h"
+#include "op-attrs/parallel_dim.h"
+#include "op-attrs/parallel_tensor_dims.dtg.h"
+#include "op-attrs/tensor_dims.dtg.h"
 
 namespace FlexFlow {
 
-struct ParallelTensorDims : public use_visitable_cmp<ParallelTensorDims> {
-  explicit ParallelTensorDims(TensorDims const &);
-
-  size_t get_volume() const;
-  size_t num_dims() const;
-
-  using iterator = typename FFOrdered<ParallelDim>::iterator;
-  using const_iterator = typename FFOrdered<ParallelDim>::const_iterator;
-  using reverse_iterator = typename FFOrdered<ParallelDim>::reverse_iterator;
-  using const_reverse_iterator =
-      typename FFOrdered<ParallelDim>::const_reverse_iterator;
-  using value_type = typename FFOrdered<ParallelDim>::value_type;
-  using pointer = typename FFOrdered<ParallelDim>::pointer;
-  using const_pointer = typename FFOrdered<ParallelDim>::const_pointer;
-
-  ParallelDim const &at(ff_dim_t const &) const;
-  ParallelDim &at(ff_dim_t const &);
-
-  iterator begin();
-  const_iterator begin() const;
-  const_iterator cbegin() const;
-  iterator end();
-  const_iterator end() const;
-  const_iterator cend() const;
-  reverse_iterator rbegin();
-  const_reverse_iterator rbegin() const;
-  const_reverse_iterator crbegin() const;
-  reverse_iterator rend();
-  const_reverse_iterator rend() const;
-  const_reverse_iterator crend() const;
-
-public:
-  FFOrdered<ParallelDim> data;
-};
+FFOrdered<ShardParallelDim> ff_ordered_shard_dims(ParallelTensorDims const &);
+FFOrdered<int> ff_ordered_shard_degrees(ParallelTensorDims const &);
+std::unordered_set<ReplicaParallelDim> replica_dims(ParallelTensorDims const &);
+
+/* size_t get_volume(ParallelTensorDims const &); */
+size_t num_shard_dims(ParallelTensorDims const &);
+
+int total_replica_degree(ParallelTensorDims const &);
+int total_shard_degree(ParallelTensorDims const &);
+int total_parallel_degree(ParallelTensorDims const &);
+
+ShardParallelDim shard_dim_at_idx(ParallelTensorDims const &, ff_dim_t);
+ShardParallelDim &shard_dim_at_idx(ParallelTensorDims &, ff_dim_t);
 
 bool is_valid(ParallelTensorDims const &);
 
 TensorDims get_piece_dims(ParallelTensorDims const &);
 TensorDims get_tensor_dims_unsafe(ParallelTensorDims const &);
-} // namespace FlexFlow
+TensorDims get_reduced_dims(ParallelTensorDims const &);
 
-VISITABLE_STRUCT(::FlexFlow::ParallelTensorDims, data);
-MAKE_VISIT_HASHABLE(::FlexFlow::ParallelTensorDims);
+} // namespace FlexFlow
 
 #endif
diff --git a/lib/op-attrs/include/op-attrs/parallel_tensor_dims.struct.toml b/lib/op-attrs/include/op-attrs/parallel_tensor_dims.struct.toml
new file mode 100644
index 0000000000..ae6eab1e58
--- /dev/null
+++ b/lib/op-attrs/include/op-attrs/parallel_tensor_dims.struct.toml
@@ -0,0 +1,27 @@
+namespace = "FlexFlow"
+name = "ParallelTensorDims"
+features = [
+  "eq",
+  "ord",
+  "hash",
+  "json",
+  "rapidcheck",
+  "fmt",
+]
+
+includes = [
+  "op-attrs/dim_ordered.h",
+  "op-attrs/shard_parallel_dim.dtg.h",
+  "op-attrs/replica_parallel_dim_set.dtg.h",
+  "<unordered_set>",
+  "utils/fmt/unordered_map.h",
+  "utils/fmt/pair.h",
+]
+
+[[fields]]
+name = "shard_dims"
+type = "::FlexFlow::FFOrdered<::FlexFlow::ShardParallelDim>"
+
+[[fields]]
+name = "replica_dims"
+type = "::FlexFlow::ReplicaParallelDimSet"
diff --git a/lib/op-attrs/include/op-attrs/parallel_tensor_shape.dtg.h b/lib/op-attrs/include/op-attrs/parallel_tensor_shape.dtg.h
new file mode 100644
index 0000000000..62d291fa4f
--- /dev/null
+++ b/lib/op-attrs/include/op-attrs/parallel_tensor_shape.dtg.h
@@ -0,0 +1,66 @@
+// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT!
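+//
+// Usage sketch (illustrative only): a ParallelTensorShape pairs the
+// parallel dims with a data type, and is usually obtained by lifting a
+// plain TensorShape via lift_to_parallel, declared in
+// op-attrs/parallel_tensor_shape.h later in this patch:
+//   ParallelTensorShape shape = lift_to_parallel(input_tensor_shape);
+//   int degree = get_total_parallel_degree(shape);
+// Here input_tensor_shape is a hypothetical TensorShape value.
+//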
+// If you would like to modify this datatype, instead modify
+// lib/op-attrs/include/op-attrs/parallel_tensor_shape.struct.toml
+/* proj-data
+{
+  "generated_from": "06d657d1e95f34aebf4b721c768cbee8"
+}
+*/
+
+#ifndef _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_PARALLEL_TENSOR_SHAPE_DTG_H
+#define _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_PARALLEL_TENSOR_SHAPE_DTG_H
+
+#include "fmt/format.h"
+#include "nlohmann/json.hpp"
+#include "op-attrs/datatype.h"
+#include "op-attrs/parallel_tensor_dims.h"
+#include "rapidcheck.h"
+#include <functional>
+#include <ostream>
+#include <tuple>
+
+namespace FlexFlow {
+struct ParallelTensorShape {
+  ParallelTensorShape() = delete;
+  ParallelTensorShape(::FlexFlow::ParallelTensorDims const &dims,
+                      ::FlexFlow::DataType const &data_type);
+
+  bool operator==(ParallelTensorShape const &) const;
+  bool operator!=(ParallelTensorShape const &) const;
+  bool operator<(ParallelTensorShape const &) const;
+  bool operator>(ParallelTensorShape const &) const;
+  bool operator<=(ParallelTensorShape const &) const;
+  bool operator>=(ParallelTensorShape const &) const;
+  ::FlexFlow::ParallelTensorDims dims;
+  ::FlexFlow::DataType data_type;
+};
+} // namespace FlexFlow
+
+namespace std {
+template <>
+struct hash<FlexFlow::ParallelTensorShape> {
+  size_t operator()(FlexFlow::ParallelTensorShape const &) const;
+};
+} // namespace std
+
+namespace nlohmann {
+template <>
+struct adl_serializer<FlexFlow::ParallelTensorShape> {
+  static FlexFlow::ParallelTensorShape from_json(json const &);
+  static void to_json(json &, FlexFlow::ParallelTensorShape const &);
+};
+} // namespace nlohmann
+
+namespace rc {
+template <>
+struct Arbitrary<FlexFlow::ParallelTensorShape> {
+  static Gen<FlexFlow::ParallelTensorShape> arbitrary();
+};
+} // namespace rc
+
+namespace FlexFlow {
+std::string format_as(ParallelTensorShape const &);
+std::ostream &operator<<(std::ostream &, ParallelTensorShape const &);
+} // namespace FlexFlow
+
+#endif // _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_PARALLEL_TENSOR_SHAPE_DTG_H
diff --git a/lib/op-attrs/include/op-attrs/parallel_tensor_shape.h b/lib/op-attrs/include/op-attrs/parallel_tensor_shape.h
index fd560352bb..99be635ffc 100644
--- a/lib/op-attrs/include/op-attrs/parallel_tensor_shape.h
+++ b/lib/op-attrs/include/op-attrs/parallel_tensor_shape.h
@@ -1,55 +1,47 @@
 #ifndef _OP_META_PARALLEL_TENSOR_SHAPE_H
 #define _OP_META_PARALLEL_TENSOR_SHAPE_H
 
-#include "datatype.h"
+#include "op-attrs/parallel_tensor_shape.dtg.h"
 #include "op-attrs/tensor_shape.h"
-#include "parallel_tensor_dims.h"
-#include "utils/bidict.h"
-#include "utils/record_formatter.h"
-#include "utils/stack_vector.h"
-#include "utils/visitable.h"
-#include <unordered_map>
+#include <vector>
 
 namespace FlexFlow {
 
-/**
- * @brief Represent the shape of a ParallelTensor.
- */
-struct ParallelTensorShape : public use_visitable_cmp<ParallelTensorShape> {
-  ParallelTensorShape() = delete;
+int num_shard_dims(ParallelTensorShape const &);
+ShardParallelDim shard_dim_at_idx(ParallelTensorShape const &, ff_dim_t);
+ShardParallelDim &shard_dim_at_idx(ParallelTensorShape &, ff_dim_t);
 
-  template <typename Dims>
-  ParallelTensorShape(Dims const &dims, DataType data_type)
-      : dims(dims), data_type(data_type) {}
+FFOrdered<int> ff_ordered_shard_degrees(ParallelTensorShape const &);
 
-  ParallelTensorShape(TensorShape const &);
+std::optional<ShardParallelDim>
+    try_get_shard_dim_at_idx(ParallelTensorShape const &, ff_dim_t);
 
-  int num_dims() const;
-
-  ParallelDim const &at(ff_dim_t const &) const;
-  ParallelDim &at(ff_dim_t const &);
-  ParallelDim const &operator[](ff_dim_t const &) const;
-  ParallelDim &operator[](ff_dim_t const &);
-
-public:
-  ParallelTensorDims dims;
-  DataType data_type;
-};
+ParallelTensorShape lift_to_parallel(TensorShape const &);
+ParallelTensorShape
+    lift_to_parallel_with_degrees(TensorShape const &,
+                                  SumDegree sum_degree,
+                                  DiscardCopyDegree discard_copy_degree,
+                                  FFOrdered<int> const &shard_degrees);
 
+std::unordered_set<ReplicaParallelDim>
+    replica_dims(ParallelTensorShape const &);
 TensorShape get_piece_shape(ParallelTensorShape const &);
 int get_num_replica_dims(ParallelTensorShape const &);
 int get_num_replicas(ParallelTensorShape const &);
 
+int get_sum_degree(ParallelTensorShape const &);
+int get_discard_copy_degree(ParallelTensorShape const &);
+
+int get_total_parallel_degree(ParallelTensorShape const &);
+
 bool is_valid(ParallelTensorShape const &);
 
 TensorShape get_tensor_shape_unsafe(ParallelTensorShape const &);
 std::vector<TensorShape>
     get_tensor_shapes_unsafe(std::vector<ParallelTensorShape> const &);
-} // namespace FlexFlow
+TensorShape get_reduced_shape(ParallelTensorShape const &);
 
-VISITABLE_STRUCT(::FlexFlow::ParallelTensorShape, data_type, dims);
-MAKE_VISIT_HASHABLE(::FlexFlow::ParallelTensorShape);
+} // namespace FlexFlow
 
 #endif
diff --git a/lib/op-attrs/include/op-attrs/parallel_tensor_shape.struct.toml b/lib/op-attrs/include/op-attrs/parallel_tensor_shape.struct.toml
new file mode 100644
index 0000000000..e6197bcd51
--- /dev/null
+++ b/lib/op-attrs/include/op-attrs/parallel_tensor_shape.struct.toml
@@ -0,0 +1,23 @@
+namespace = "FlexFlow"
+name = "ParallelTensorShape"
+features = [
+  "eq",
+  "ord",
+  "hash",
+  "json",
+  "rapidcheck",
+  "fmt",
+]
+
+includes = [
+  "op-attrs/parallel_tensor_dims.h",
+  "op-attrs/datatype.h",
+]
+
+[[fields]]
+name = "dims"
+type = "::FlexFlow::ParallelTensorDims"
+
+[[fields]]
+name = "data_type"
+type = "::FlexFlow::DataType"
diff --git a/lib/op-attrs/include/op-attrs/parallel_tensor_shape/discard_copy_degree.dtg.h b/lib/op-attrs/include/op-attrs/parallel_tensor_shape/discard_copy_degree.dtg.h
new file mode 100644
index 0000000000..a820bfe81c
--- /dev/null
+++ b/lib/op-attrs/include/op-attrs/parallel_tensor_shape/discard_copy_degree.dtg.h
@@ -0,0 +1,62 @@
+// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT!
+// If you would like to modify this datatype, instead modify
+// lib/op-attrs/include/op-attrs/parallel_tensor_shape/discard_copy_degree.struct.toml
+/* proj-data
+{
+  "generated_from": "e4677d1fb25d3833570ee567f5659914"
+}
+*/
+
+#ifndef _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_PARALLEL_TENSOR_SHAPE_DISCARD_COPY_DEGREE_DTG_H
+#define _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_PARALLEL_TENSOR_SHAPE_DISCARD_COPY_DEGREE_DTG_H
+
+#include "fmt/format.h"
+#include "nlohmann/json.hpp"
+#include "rapidcheck.h"
+#include <functional>
+#include <ostream>
+#include <tuple>
+
+namespace FlexFlow {
+struct DiscardCopyDegree {
+  DiscardCopyDegree() = delete;
+  DiscardCopyDegree(int const &value);
+
+  bool operator==(DiscardCopyDegree const &) const;
+  bool operator!=(DiscardCopyDegree const &) const;
+  bool operator<(DiscardCopyDegree const &) const;
+  bool operator>(DiscardCopyDegree const &) const;
+  bool operator<=(DiscardCopyDegree const &) const;
+  bool operator>=(DiscardCopyDegree const &) const;
+  int value;
+};
+} // namespace FlexFlow
+
+namespace std {
+template <>
+struct hash<FlexFlow::DiscardCopyDegree> {
+  size_t operator()(FlexFlow::DiscardCopyDegree const &) const;
+};
+} // namespace std
+
+namespace nlohmann {
+template <>
+struct adl_serializer<FlexFlow::DiscardCopyDegree> {
+  static FlexFlow::DiscardCopyDegree from_json(json const &);
+  static void to_json(json &, FlexFlow::DiscardCopyDegree const &);
+};
+} // namespace nlohmann
+
+namespace rc {
+template <>
+struct Arbitrary<FlexFlow::DiscardCopyDegree> {
+  static Gen<FlexFlow::DiscardCopyDegree> arbitrary();
+};
+} // namespace rc
+
+namespace FlexFlow {
+std::string format_as(DiscardCopyDegree const &);
+std::ostream &operator<<(std::ostream &, DiscardCopyDegree const &);
+} // namespace FlexFlow
+
+#endif // _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_PARALLEL_TENSOR_SHAPE_DISCARD_COPY_DEGREE_DTG_H
diff --git a/lib/op-attrs/include/op-attrs/parallel_tensor_shape/discard_copy_degree.struct.toml b/lib/op-attrs/include/op-attrs/parallel_tensor_shape/discard_copy_degree.struct.toml
new file mode 100644
index 0000000000..b4905fb0ce
--- /dev/null
+++ b/lib/op-attrs/include/op-attrs/parallel_tensor_shape/discard_copy_degree.struct.toml
@@ -0,0 +1,14 @@
+namespace = "FlexFlow"
+name = "DiscardCopyDegree"
+features = [
+  "eq",
+  "ord",
+  "hash",
+  "json",
+  "rapidcheck",
+  "fmt",
+]
+
+[[fields]]
+name = "value"
+type = "int"
diff --git a/lib/op-attrs/include/op-attrs/parallel_tensor_shape/sum_degree.dtg.h b/lib/op-attrs/include/op-attrs/parallel_tensor_shape/sum_degree.dtg.h
new file mode 100644
index 0000000000..17388f8d05
--- /dev/null
+++ b/lib/op-attrs/include/op-attrs/parallel_tensor_shape/sum_degree.dtg.h
@@ -0,0 +1,62 @@
+// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT!
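+//
+// SumDegree (like DiscardCopyDegree above) is a strong typedef around int:
+// the wrapper keeps the two replica degrees from being swapped at call
+// sites. Illustrative only:
+//   SumDegree s{2};
+//   int raw = s.value;
+//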
+// If you would like to modify this datatype, instead modify
+// lib/op-attrs/include/op-attrs/parallel_tensor_shape/sum_degree.struct.toml
+/* proj-data
+{
+  "generated_from": "e94a05618f2ad92dd7b3328a1d9c6786"
+}
+*/
+
+#ifndef _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_PARALLEL_TENSOR_SHAPE_SUM_DEGREE_DTG_H
+#define _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_PARALLEL_TENSOR_SHAPE_SUM_DEGREE_DTG_H
+
+#include "fmt/format.h"
+#include "nlohmann/json.hpp"
+#include "rapidcheck.h"
+#include <functional>
+#include <ostream>
+#include <tuple>
+
+namespace FlexFlow {
+struct SumDegree {
+  SumDegree() = delete;
+  SumDegree(int const &value);
+
+  bool operator==(SumDegree const &) const;
+  bool operator!=(SumDegree const &) const;
+  bool operator<(SumDegree const &) const;
+  bool operator>(SumDegree const &) const;
+  bool operator<=(SumDegree const &) const;
+  bool operator>=(SumDegree const &) const;
+  int value;
+};
+} // namespace FlexFlow
+
+namespace std {
+template <>
+struct hash<FlexFlow::SumDegree> {
+  size_t operator()(FlexFlow::SumDegree const &) const;
+};
+} // namespace std
+
+namespace nlohmann {
+template <>
+struct adl_serializer<FlexFlow::SumDegree> {
+  static FlexFlow::SumDegree from_json(json const &);
+  static void to_json(json &, FlexFlow::SumDegree const &);
+};
+} // namespace nlohmann
+
+namespace rc {
+template <>
+struct Arbitrary<FlexFlow::SumDegree> {
+  static Gen<FlexFlow::SumDegree> arbitrary();
+};
+} // namespace rc
+
+namespace FlexFlow {
+std::string format_as(SumDegree const &);
+std::ostream &operator<<(std::ostream &, SumDegree const &);
+} // namespace FlexFlow
+
+#endif // _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_PARALLEL_TENSOR_SHAPE_SUM_DEGREE_DTG_H
diff --git a/lib/op-attrs/include/op-attrs/parallel_tensor_shape/sum_degree.struct.toml b/lib/op-attrs/include/op-attrs/parallel_tensor_shape/sum_degree.struct.toml
new file mode 100644
index 0000000000..d86917211e
--- /dev/null
+++ b/lib/op-attrs/include/op-attrs/parallel_tensor_shape/sum_degree.struct.toml
@@ -0,0 +1,14 @@
+namespace = "FlexFlow"
+name = "SumDegree"
+features = [
+  "eq",
+  "ord",
+  "hash",
+  "json",
+  "rapidcheck",
+  "fmt",
+]
+
+[[fields]]
+name = "value"
+type = "int"
diff --git a/lib/op-attrs/include/op-attrs/param_sync.dtg.h b/lib/op-attrs/include/op-attrs/param_sync.dtg.h
new file mode 100644
index 0000000000..785105fbc4
--- /dev/null
+++ b/lib/op-attrs/include/op-attrs/param_sync.dtg.h
@@ -0,0 +1,40 @@
+// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT!
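+//
+// Usage sketch (illustrative only): generated enums serialize by name and
+// print via format_as, e.g.
+//   ParamSync p = ParamSync::NCCL;
+//   nlohmann::json j;
+//   to_json(j, p);
+//   std::cout << p;   // uses operator<< declared below
+// The exact serialized spelling is determined by the generated .cc, which
+// is not part of this diff.
+//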
+// If you would like to modify this datatype, instead modify
+// lib/op-attrs/include/op-attrs/param_sync.enum.toml
+/* proj-data
+{
+  "generated_from": "288c6e9e256cf58ba5dbd0e3791c08df"
+}
+*/
+
+#ifndef _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_PARAM_SYNC_DTG_H
+#define _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_PARAM_SYNC_DTG_H
+
+#include "fmt/format.h"
+#include "nlohmann/json.hpp"
+#include "rapidcheck.h"
+#include <functional>
+#include <ostream>
+#include <string>
+
+namespace FlexFlow {
+enum class ParamSync { PS, NCCL };
+std::string format_as(ParamSync);
+std::ostream &operator<<(std::ostream &, ParamSync);
+void to_json(::nlohmann::json &, ParamSync);
+void from_json(::nlohmann::json const &, ParamSync &);
+} // namespace FlexFlow
+namespace std {
+template <>
+struct hash<FlexFlow::ParamSync> {
+  size_t operator()(FlexFlow::ParamSync) const;
+};
+} // namespace std
+namespace rc {
+template <>
+struct Arbitrary<FlexFlow::ParamSync> {
+  static Gen<FlexFlow::ParamSync> arbitrary();
+};
+} // namespace rc
+
+#endif // _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_PARAM_SYNC_DTG_H
diff --git a/lib/op-attrs/include/op-attrs/param_sync.enum.toml b/lib/op-attrs/include/op-attrs/param_sync.enum.toml
new file mode 100644
index 0000000000..b16a47ab3c
--- /dev/null
+++ b/lib/op-attrs/include/op-attrs/param_sync.enum.toml
@@ -0,0 +1,14 @@
+namespace = "FlexFlow"
+name = "ParamSync"
+features = [
+  "hash",
+  "json",
+  "rapidcheck",
+  "fmt",
+]
+
+[[values]]
+name = "PS"
+
+[[values]]
+name = "NCCL"
diff --git a/lib/op-attrs/include/op-attrs/param_sync.h b/lib/op-attrs/include/op-attrs/param_sync.h
index bfae1e712b..dd7048ff36 100644
--- a/lib/op-attrs/include/op-attrs/param_sync.h
+++ b/lib/op-attrs/include/op-attrs/param_sync.h
@@ -1,36 +1,8 @@
 #ifndef _FLEXFLOW_OPATTRS_INCLUDE_OPATTRS_PARAM_SYNC_H
 #define _FLEXFLOW_OPATTRS_INCLUDE_OPATTRS_PARAM_SYNC_H
 
-#include "utils/fmt.h"
+#include "op-attrs/param_sync.dtg.h"
 
-namespace FlexFlow {
-
-enum class ParamSync { PS, NCCL };
-
-}
-
-namespace fmt {
-
-template <>
-struct formatter<::FlexFlow::ParamSync> : formatter<string_view> {
-  template <typename FormatContext>
-  auto format(::FlexFlow::ParamSync ps, FormatContext &ctx) const
-      -> decltype(ctx.out()) {
-    using namespace FlexFlow;
-
-    string_view name = "unknown";
-    switch (ps) {
-    case ParamSync::PS:
-      name = "ParameterServer";
-      break;
-    case ParamSync::NCCL:
-      name = "NCCL";
-      break;
-    }
-    return formatter<string_view>::format(name, ctx);
-  }
-};
-
-} // namespace fmt
+namespace FlexFlow {}
 
 #endif
diff --git a/lib/op-attrs/include/op-attrs/pcg_operator_attrs.dtg.h b/lib/op-attrs/include/op-attrs/pcg_operator_attrs.dtg.h
new file mode 100644
index 0000000000..5370773a45
--- /dev/null
+++ b/lib/op-attrs/include/op-attrs/pcg_operator_attrs.dtg.h
@@ -0,0 +1,495 @@
+// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT!
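+//
+// Usage sketch (illustrative only): PCGOperatorAttrs is the tagged union
+// over all operator attrs; typical dispatch mirrors ParallelDim above:
+//   PCGOperatorAttrs op{TopKAttrs{/*k=*/5, /*sorted=*/true}};
+//   if (op.has<TopKAttrs>()) {
+//     int k = op.get<TopKAttrs>().k;
+//   }
+//   bool parallel = is_parallel_op(op);  // declared in pcg_operator_attrs.h
+//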
+// If you would like to modify this datatype, instead modify
+// lib/op-attrs/include/op-attrs/pcg_operator_attrs.variant.toml
+/* proj-data
+{
+  "generated_from": "e1d10b0c7c98524c27886bdae0972321"
+}
+*/
+
+#ifndef _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_PCG_OPERATOR_ATTRS_DTG_H
+#define _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_PCG_OPERATOR_ATTRS_DTG_H
+
+#include "fmt/format.h"
+#include "nlohmann/json.hpp"
+#include "op-attrs/ops/attention_attrs.dtg.h"
+#include "op-attrs/ops/batch_matmul.dtg.h"
+#include "op-attrs/ops/batch_norm_attrs.dtg.h"
+#include "op-attrs/ops/cast_attrs.dtg.h"
+#include "op-attrs/ops/combine_attrs.dtg.h"
+#include "op-attrs/ops/concat_attrs.dtg.h"
+#include "op-attrs/ops/conv_2d_attrs.dtg.h"
+#include "op-attrs/ops/dropout_attrs.dtg.h"
+#include "op-attrs/ops/element_binary_attrs.dtg.h"
+#include "op-attrs/ops/element_scalar_unary_attrs.dtg.h"
+#include "op-attrs/ops/element_unary_attrs.dtg.h"
+#include "op-attrs/ops/embedding_attrs.dtg.h"
+#include "op-attrs/ops/flat_attrs.dtg.h"
+#include "op-attrs/ops/gather_attrs.dtg.h"
+#include "op-attrs/ops/input_attrs.dtg.h"
+#include "op-attrs/ops/layer_norm_attrs.dtg.h"
+#include "op-attrs/ops/linear_attrs.dtg.h"
+#include "op-attrs/ops/noop_attrs.dtg.h"
+#include "op-attrs/ops/pool_2d_attrs.dtg.h"
+#include "op-attrs/ops/reduce_attrs.dtg.h"
+#include "op-attrs/ops/reduction_attrs.dtg.h"
+#include "op-attrs/ops/repartition_attrs.dtg.h"
+#include "op-attrs/ops/replicate_attrs.dtg.h"
+#include "op-attrs/ops/reshape_attrs.dtg.h"
+#include "op-attrs/ops/reverse_attrs.dtg.h"
+#include "op-attrs/ops/softmax_attrs.dtg.h"
+#include "op-attrs/ops/split_attrs.dtg.h"
+#include "op-attrs/ops/topk_attrs.dtg.h"
+#include "op-attrs/ops/transpose_attrs.dtg.h"
+#include "rapidcheck.h"
+#include <cstddef>
+#include <functional>
+#include <ostream>
+#include <type_traits>
+#include <variant>
+
+namespace FlexFlow {
+struct PCGOperatorAttrs {
+  PCGOperatorAttrs() = delete;
+  explicit PCGOperatorAttrs(::FlexFlow::BatchMatmulAttrs const &);
+  explicit PCGOperatorAttrs(::FlexFlow::BatchNormAttrs const &);
+  explicit PCGOperatorAttrs(::FlexFlow::CastAttrs const &);
+  explicit PCGOperatorAttrs(::FlexFlow::CombineAttrs const &);
+  explicit PCGOperatorAttrs(::FlexFlow::ConcatAttrs const &);
+  explicit PCGOperatorAttrs(::FlexFlow::Conv2DAttrs const &);
+  explicit PCGOperatorAttrs(::FlexFlow::DropoutAttrs const &);
+  explicit PCGOperatorAttrs(::FlexFlow::ElementBinaryAttrs const &);
+  explicit PCGOperatorAttrs(::FlexFlow::ElementUnaryAttrs const &);
+  explicit PCGOperatorAttrs(::FlexFlow::ElementScalarUnaryAttrs const &);
+  explicit PCGOperatorAttrs(::FlexFlow::EmbeddingAttrs const &);
+  explicit PCGOperatorAttrs(::FlexFlow::FlatAttrs const &);
+  explicit PCGOperatorAttrs(::FlexFlow::GatherAttrs const &);
+  explicit PCGOperatorAttrs(::FlexFlow::InputAttrs const &);
+  explicit PCGOperatorAttrs(::FlexFlow::LayerNormAttrs const &);
+  explicit PCGOperatorAttrs(::FlexFlow::LinearAttrs const &);
+  explicit PCGOperatorAttrs(::FlexFlow::MultiHeadAttentionAttrs const &);
+  explicit PCGOperatorAttrs(::FlexFlow::NoopAttrs const &);
+  explicit PCGOperatorAttrs(::FlexFlow::Pool2DAttrs const &);
+  explicit PCGOperatorAttrs(::FlexFlow::ReduceAttrs const &);
+  explicit PCGOperatorAttrs(::FlexFlow::ReductionAttrs const &);
+  explicit PCGOperatorAttrs(::FlexFlow::RepartitionAttrs const &);
+  explicit PCGOperatorAttrs(::FlexFlow::ReplicateAttrs const &);
+  explicit PCGOperatorAttrs(::FlexFlow::ReverseAttrs const &);
+  explicit PCGOperatorAttrs(::FlexFlow::ReshapeAttrs const &);
+  explicit PCGOperatorAttrs(::FlexFlow::SplitAttrs const &);
+  explicit PCGOperatorAttrs(::FlexFlow::SoftmaxAttrs const &);
+  explicit PCGOperatorAttrs(::FlexFlow::TopKAttrs const &);
+  explicit PCGOperatorAttrs(::FlexFlow::TransposeAttrs const &);
+  template <typename T>
+  static constexpr bool IsPartOfPCGOperatorAttrs_v =
+      std::is_same_v<T, ::FlexFlow::BatchMatmulAttrs> ||
+      std::is_same_v<T, ::FlexFlow::BatchNormAttrs> ||
+      std::is_same_v<T, ::FlexFlow::CastAttrs> ||
+      std::is_same_v<T, ::FlexFlow::CombineAttrs> ||
+      std::is_same_v<T, ::FlexFlow::ConcatAttrs> ||
+      std::is_same_v<T, ::FlexFlow::Conv2DAttrs> ||
+      std::is_same_v<T, ::FlexFlow::DropoutAttrs> ||
+      std::is_same_v<T, ::FlexFlow::ElementBinaryAttrs> ||
+      std::is_same_v<T, ::FlexFlow::ElementUnaryAttrs> ||
+      std::is_same_v<T, ::FlexFlow::ElementScalarUnaryAttrs> ||
+      std::is_same_v<T, ::FlexFlow::EmbeddingAttrs> ||
+      std::is_same_v<T, ::FlexFlow::FlatAttrs> ||
+      std::is_same_v<T, ::FlexFlow::GatherAttrs> ||
+      std::is_same_v<T, ::FlexFlow::InputAttrs> ||
+      std::is_same_v<T, ::FlexFlow::LayerNormAttrs> ||
+      std::is_same_v<T, ::FlexFlow::LinearAttrs> ||
+      std::is_same_v<T, ::FlexFlow::MultiHeadAttentionAttrs> ||
+      std::is_same_v<T, ::FlexFlow::NoopAttrs> ||
+      std::is_same_v<T, ::FlexFlow::Pool2DAttrs> ||
+      std::is_same_v<T, ::FlexFlow::ReduceAttrs> ||
+      std::is_same_v<T, ::FlexFlow::ReductionAttrs> ||
+      std::is_same_v<T, ::FlexFlow::RepartitionAttrs> ||
+      std::is_same_v<T, ::FlexFlow::ReplicateAttrs> ||
+      std::is_same_v<T, ::FlexFlow::ReverseAttrs> ||
+      std::is_same_v<T, ::FlexFlow::ReshapeAttrs> ||
+      std::is_same_v<T, ::FlexFlow::SplitAttrs> ||
+      std::is_same_v<T, ::FlexFlow::SoftmaxAttrs> ||
+      std::is_same_v<T, ::FlexFlow::TopKAttrs> ||
+      std::is_same_v<T, ::FlexFlow::TransposeAttrs>;
+  template <typename ReturnType, typename Visitor>
+  ReturnType visit(Visitor &&v) const {
+    switch (this->index()) {
+      case 0: {
+        ReturnType result = v(this->get<::FlexFlow::BatchMatmulAttrs>());
+        return result;
+      }
+      case 1: {
+        ReturnType result = v(this->get<::FlexFlow::BatchNormAttrs>());
+        return result;
+      }
+      case 2: {
+        ReturnType result = v(this->get<::FlexFlow::CastAttrs>());
+        return result;
+      }
+      case 3: {
+        ReturnType result = v(this->get<::FlexFlow::CombineAttrs>());
+        return result;
+      }
+      case 4: {
+        ReturnType result = v(this->get<::FlexFlow::ConcatAttrs>());
+        return result;
+      }
+      case 5: {
+        ReturnType result = v(this->get<::FlexFlow::Conv2DAttrs>());
+        return result;
+      }
+      case 6: {
+        ReturnType result = v(this->get<::FlexFlow::DropoutAttrs>());
+        return result;
+      }
+      case 7: {
+        ReturnType result = v(this->get<::FlexFlow::ElementBinaryAttrs>());
+        return result;
+      }
+      case 8: {
+        ReturnType result = v(this->get<::FlexFlow::ElementUnaryAttrs>());
+        return result;
+      }
+      case 9: {
+        ReturnType result = v(this->get<::FlexFlow::ElementScalarUnaryAttrs>());
+        return result;
+      }
+      case 10: {
+        ReturnType result = v(this->get<::FlexFlow::EmbeddingAttrs>());
+        return result;
+      }
+      case 11: {
+        ReturnType result = v(this->get<::FlexFlow::FlatAttrs>());
+        return result;
+      }
+      case 12: {
+        ReturnType result = v(this->get<::FlexFlow::GatherAttrs>());
+        return result;
+      }
+      case 13: {
+        ReturnType result = v(this->get<::FlexFlow::InputAttrs>());
+        return result;
+      }
+      case 14: {
+        ReturnType result = v(this->get<::FlexFlow::LayerNormAttrs>());
+        return result;
+      }
+      case 15: {
+        ReturnType result = v(this->get<::FlexFlow::LinearAttrs>());
+        return result;
+      }
+      case 16: {
+        ReturnType result = v(this->get<::FlexFlow::MultiHeadAttentionAttrs>());
+        return result;
+      }
+      case 17: {
+        ReturnType result = v(this->get<::FlexFlow::NoopAttrs>());
+        return result;
+      }
+      case 18: {
+        ReturnType result = v(this->get<::FlexFlow::Pool2DAttrs>());
+        return result;
+      }
+      case 19: {
+        ReturnType result = v(this->get<::FlexFlow::ReduceAttrs>());
+        return result;
+      }
+      case 20: {
+        ReturnType result = v(this->get<::FlexFlow::ReductionAttrs>());
+        return result;
+      }
+      case 21: {
+        ReturnType result = v(this->get<::FlexFlow::RepartitionAttrs>());
+        return result;
+      }
+      case 22: {
+        ReturnType result = v(this->get<::FlexFlow::ReplicateAttrs>());
+        return result;
+      }
+      case 23: {
+        ReturnType result = v(this->get<::FlexFlow::ReverseAttrs>());
+        return result;
+      }
+      case 24: {
+        ReturnType result = v(this->get<::FlexFlow::ReshapeAttrs>());
+        return result;
+      }
+      case 25: {
+        ReturnType result = v(this->get<::FlexFlow::SplitAttrs>());
+        return result;
+      }
+      case 26: {
+        ReturnType result = v(this->get<::FlexFlow::SoftmaxAttrs>());
+        return result;
+      }
+      case 27: {
+        ReturnType result = v(this->get<::FlexFlow::TopKAttrs>());
+        return result;
+      }
+      case 28: {
+        ReturnType result = v(this->get<::FlexFlow::TransposeAttrs>());
+        return result;
+      }
+      default: {
+        throw std::runtime_error(fmt::format(
+            "Unknown index {} for type PCGOperatorAttrs", this->index()));
+      }
+    }
+  }
+  template <typename ReturnType, typename Visitor>
+  ReturnType visit(Visitor &&v) {
+    switch (this->index()) {
+      case 0: {
+        ReturnType result = v(this->get<::FlexFlow::BatchMatmulAttrs>());
+        return result;
+      }
+      case 1: {
+        ReturnType result = v(this->get<::FlexFlow::BatchNormAttrs>());
+        return result;
+      }
+      case 2: {
+        ReturnType result = v(this->get<::FlexFlow::CastAttrs>());
+        return result;
+      }
+      case 3: {
+        ReturnType result = v(this->get<::FlexFlow::CombineAttrs>());
+        return result;
+      }
+      case 4: {
+        ReturnType result = v(this->get<::FlexFlow::ConcatAttrs>());
+        return result;
+      }
+      case 5: {
+        ReturnType result = v(this->get<::FlexFlow::Conv2DAttrs>());
+        return result;
+      }
+      case 6: {
+        ReturnType result = v(this->get<::FlexFlow::DropoutAttrs>());
+        return result;
+      }
+      case 7: {
+        ReturnType result = v(this->get<::FlexFlow::ElementBinaryAttrs>());
+        return result;
+      }
+      case 8: {
+        ReturnType result = v(this->get<::FlexFlow::ElementUnaryAttrs>());
+        return result;
+      }
+      case 9: {
+        ReturnType result = v(this->get<::FlexFlow::ElementScalarUnaryAttrs>());
+        return result;
+      }
+      case 10: {
+        ReturnType result = v(this->get<::FlexFlow::EmbeddingAttrs>());
+        return result;
+      }
+      case 11: {
+        ReturnType result = v(this->get<::FlexFlow::FlatAttrs>());
+        return result;
+      }
+      case 12: {
+        ReturnType result = v(this->get<::FlexFlow::GatherAttrs>());
+        return result;
+      }
+      case 13: {
+        ReturnType result = v(this->get<::FlexFlow::InputAttrs>());
+        return result;
+      }
+      case 14: {
+        ReturnType result = v(this->get<::FlexFlow::LayerNormAttrs>());
+        return result;
+      }
+      case 15: {
+        ReturnType result = v(this->get<::FlexFlow::LinearAttrs>());
+        return result;
+      }
+      case 16: {
+        ReturnType result = v(this->get<::FlexFlow::MultiHeadAttentionAttrs>());
+        return result;
+      }
+      case 17: {
+        ReturnType result = v(this->get<::FlexFlow::NoopAttrs>());
+        return result;
+      }
+      case 18: {
+        ReturnType result = v(this->get<::FlexFlow::Pool2DAttrs>());
+        return result;
+      }
+      case 19: {
+        ReturnType result = v(this->get<::FlexFlow::ReduceAttrs>());
+        return result;
+      }
+      case 20: {
+        ReturnType result = v(this->get<::FlexFlow::ReductionAttrs>());
+        return result;
+      }
+      case 21: {
+        ReturnType result = v(this->get<::FlexFlow::RepartitionAttrs>());
+        return result;
+      }
+      case 22: {
+        ReturnType result = v(this->get<::FlexFlow::ReplicateAttrs>());
+        return result;
+      }
+      case 23: {
+        ReturnType result = v(this->get<::FlexFlow::ReverseAttrs>());
+        return result;
+      }
+      case 24: {
+        ReturnType result = v(this->get<::FlexFlow::ReshapeAttrs>());
+        return result;
+      }
+      case 25: {
+        ReturnType result = v(this->get<::FlexFlow::SplitAttrs>());
+        return result;
+      }
+      case 26: {
+        ReturnType result = v(this->get<::FlexFlow::SoftmaxAttrs>());
+        return result;
+      }
+      case 27: {
+        ReturnType result = v(this->get<::FlexFlow::TopKAttrs>());
+        return result;
+      }
+      case 28: {
+        ReturnType result = v(this->get<::FlexFlow::TransposeAttrs>());
+        return result;
+      }
+      default: {
+        throw std::runtime_error(fmt::format(
+            "Unknown index {} for type PCGOperatorAttrs", this->index()));
+      }
+    }
+  }
+  template <typename T>
+  bool has() const {
+    static_assert(
+        IsPartOfPCGOperatorAttrs_v<T>,
"PCGOperatorAttrs::has() expected one of " + "[::FlexFlow::BatchMatmulAttrs, ::FlexFlow::BatchNormAttrs, " + "::FlexFlow::CastAttrs, ::FlexFlow::CombineAttrs, " + "::FlexFlow::ConcatAttrs, ::FlexFlow::Conv2DAttrs, " + "::FlexFlow::DropoutAttrs, ::FlexFlow::ElementBinaryAttrs, " + "::FlexFlow::ElementUnaryAttrs, ::FlexFlow::ElementScalarUnaryAttrs, " + "::FlexFlow::EmbeddingAttrs, ::FlexFlow::FlatAttrs, " + "::FlexFlow::GatherAttrs, ::FlexFlow::InputAttrs, " + "::FlexFlow::LayerNormAttrs, ::FlexFlow::LinearAttrs, " + "::FlexFlow::MultiHeadAttentionAttrs, ::FlexFlow::NoopAttrs, " + "::FlexFlow::Pool2DAttrs, ::FlexFlow::ReduceAttrs, " + "::FlexFlow::ReductionAttrs, ::FlexFlow::RepartitionAttrs, " + "::FlexFlow::ReplicateAttrs, ::FlexFlow::ReverseAttrs, " + "::FlexFlow::ReshapeAttrs, ::FlexFlow::SplitAttrs, " + "::FlexFlow::SoftmaxAttrs, ::FlexFlow::TopKAttrs, " + "::FlexFlow::TransposeAttrs], received T"); + return std::holds_alternative(this->raw_variant); + } + template + T const &get() const { + static_assert( + IsPartOfPCGOperatorAttrs_v, + "PCGOperatorAttrs::get() expected one of " + "[::FlexFlow::BatchMatmulAttrs, ::FlexFlow::BatchNormAttrs, " + "::FlexFlow::CastAttrs, ::FlexFlow::CombineAttrs, " + "::FlexFlow::ConcatAttrs, ::FlexFlow::Conv2DAttrs, " + "::FlexFlow::DropoutAttrs, ::FlexFlow::ElementBinaryAttrs, " + "::FlexFlow::ElementUnaryAttrs, ::FlexFlow::ElementScalarUnaryAttrs, " + "::FlexFlow::EmbeddingAttrs, ::FlexFlow::FlatAttrs, " + "::FlexFlow::GatherAttrs, ::FlexFlow::InputAttrs, " + "::FlexFlow::LayerNormAttrs, ::FlexFlow::LinearAttrs, " + "::FlexFlow::MultiHeadAttentionAttrs, ::FlexFlow::NoopAttrs, " + "::FlexFlow::Pool2DAttrs, ::FlexFlow::ReduceAttrs, " + "::FlexFlow::ReductionAttrs, ::FlexFlow::RepartitionAttrs, " + "::FlexFlow::ReplicateAttrs, ::FlexFlow::ReverseAttrs, " + "::FlexFlow::ReshapeAttrs, ::FlexFlow::SplitAttrs, " + "::FlexFlow::SoftmaxAttrs, ::FlexFlow::TopKAttrs, " + "::FlexFlow::TransposeAttrs], received T"); + return std::get(this->raw_variant); + } + template + T &get() { + static_assert( + IsPartOfPCGOperatorAttrs_v, + "PCGOperatorAttrs::get() expected one of " + "[::FlexFlow::BatchMatmulAttrs, ::FlexFlow::BatchNormAttrs, " + "::FlexFlow::CastAttrs, ::FlexFlow::CombineAttrs, " + "::FlexFlow::ConcatAttrs, ::FlexFlow::Conv2DAttrs, " + "::FlexFlow::DropoutAttrs, ::FlexFlow::ElementBinaryAttrs, " + "::FlexFlow::ElementUnaryAttrs, ::FlexFlow::ElementScalarUnaryAttrs, " + "::FlexFlow::EmbeddingAttrs, ::FlexFlow::FlatAttrs, " + "::FlexFlow::GatherAttrs, ::FlexFlow::InputAttrs, " + "::FlexFlow::LayerNormAttrs, ::FlexFlow::LinearAttrs, " + "::FlexFlow::MultiHeadAttentionAttrs, ::FlexFlow::NoopAttrs, " + "::FlexFlow::Pool2DAttrs, ::FlexFlow::ReduceAttrs, " + "::FlexFlow::ReductionAttrs, ::FlexFlow::RepartitionAttrs, " + "::FlexFlow::ReplicateAttrs, ::FlexFlow::ReverseAttrs, " + "::FlexFlow::ReshapeAttrs, ::FlexFlow::SplitAttrs, " + "::FlexFlow::SoftmaxAttrs, ::FlexFlow::TopKAttrs, " + "::FlexFlow::TransposeAttrs], received T"); + return std::get(this->raw_variant); + } + size_t index() const { + return this->raw_variant.index(); + } + bool operator==(PCGOperatorAttrs const &) const; + bool operator!=(PCGOperatorAttrs const &) const; + bool operator<(PCGOperatorAttrs const &) const; + bool operator>(PCGOperatorAttrs const &) const; + bool operator<=(PCGOperatorAttrs const &) const; + bool operator>=(PCGOperatorAttrs const &) const; + std::variant<::FlexFlow::BatchMatmulAttrs, + ::FlexFlow::BatchNormAttrs, + ::FlexFlow::CastAttrs, + ::FlexFlow::CombineAttrs, + 
+               ::FlexFlow::ConcatAttrs,
+               ::FlexFlow::Conv2DAttrs,
+               ::FlexFlow::DropoutAttrs,
+               ::FlexFlow::ElementBinaryAttrs,
+               ::FlexFlow::ElementUnaryAttrs,
+               ::FlexFlow::ElementScalarUnaryAttrs,
+               ::FlexFlow::EmbeddingAttrs,
+               ::FlexFlow::FlatAttrs,
+               ::FlexFlow::GatherAttrs,
+               ::FlexFlow::InputAttrs,
+               ::FlexFlow::LayerNormAttrs,
+               ::FlexFlow::LinearAttrs,
+               ::FlexFlow::MultiHeadAttentionAttrs,
+               ::FlexFlow::NoopAttrs,
+               ::FlexFlow::Pool2DAttrs,
+               ::FlexFlow::ReduceAttrs,
+               ::FlexFlow::ReductionAttrs,
+               ::FlexFlow::RepartitionAttrs,
+               ::FlexFlow::ReplicateAttrs,
+               ::FlexFlow::ReverseAttrs,
+               ::FlexFlow::ReshapeAttrs,
+               ::FlexFlow::SplitAttrs,
+               ::FlexFlow::SoftmaxAttrs,
+               ::FlexFlow::TopKAttrs,
+               ::FlexFlow::TransposeAttrs>
+      raw_variant;
+};
+} // namespace FlexFlow
+namespace std {
+template <>
+struct hash<::FlexFlow::PCGOperatorAttrs> {
+  size_t operator()(::FlexFlow::PCGOperatorAttrs const &) const;
+};
+} // namespace std
+namespace nlohmann {
+template <>
+struct adl_serializer<::FlexFlow::PCGOperatorAttrs> {
+  static ::FlexFlow::PCGOperatorAttrs from_json(json const &);
+  static void to_json(json &, ::FlexFlow::PCGOperatorAttrs const &);
+};
+} // namespace nlohmann
+namespace rc {
+template <>
+struct Arbitrary<::FlexFlow::PCGOperatorAttrs> {
+  static Gen<::FlexFlow::PCGOperatorAttrs> arbitrary();
+};
+} // namespace rc
+namespace FlexFlow {
+std::string format_as(::FlexFlow::PCGOperatorAttrs const &);
+std::ostream &operator<<(std::ostream &, ::FlexFlow::PCGOperatorAttrs const &);
+} // namespace FlexFlow
+
+#endif // _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_PCG_OPERATOR_ATTRS_DTG_H
diff --git a/lib/op-attrs/include/op-attrs/pcg_operator_attrs.h b/lib/op-attrs/include/op-attrs/pcg_operator_attrs.h
new file mode 100644
index 0000000000..0ad7a9f829
--- /dev/null
+++ b/lib/op-attrs/include/op-attrs/pcg_operator_attrs.h
@@ -0,0 +1,13 @@
+#ifndef _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_PCG_OPERATOR_ATTRS_H
+#define _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_PCG_OPERATOR_ATTRS_H
+
+#include "op-attrs/pcg_operator_attrs.dtg.h"
+
+namespace FlexFlow {
+
+bool is_parallel_op(PCGOperatorAttrs const &);
+OperatorType get_op_type(PCGOperatorAttrs const &);
+
+} // namespace FlexFlow
+
+#endif
diff --git a/lib/op-attrs/include/op-attrs/pcg_operator_attrs.variant.toml b/lib/op-attrs/include/op-attrs/pcg_operator_attrs.variant.toml
new file mode 100644
index 0000000000..ddb8a109d8
--- /dev/null
+++ b/lib/op-attrs/include/op-attrs/pcg_operator_attrs.variant.toml
@@ -0,0 +1,158 @@
+namespace = "FlexFlow"
+name = "PCGOperatorAttrs"
+features = [
+  "eq",
+  "ord",
+  "hash",
+  "json",
+  "rapidcheck",
+  "fmt",
+]
+
+includes = [
+  "op-attrs/ops/attention_attrs.dtg.h",
+  "op-attrs/ops/batch_matmul.dtg.h",
+  "op-attrs/ops/batch_norm_attrs.dtg.h",
+  "op-attrs/ops/cast_attrs.dtg.h",
+  "op-attrs/ops/combine_attrs.dtg.h",
+  "op-attrs/ops/concat_attrs.dtg.h",
+  "op-attrs/ops/conv_2d_attrs.dtg.h",
+  "op-attrs/ops/dropout_attrs.dtg.h",
+  "op-attrs/ops/element_binary_attrs.dtg.h",
+  "op-attrs/ops/element_scalar_unary_attrs.dtg.h",
+  "op-attrs/ops/element_unary_attrs.dtg.h",
+  "op-attrs/ops/embedding_attrs.dtg.h",
+  "op-attrs/ops/flat_attrs.dtg.h",
+  "op-attrs/ops/gather_attrs.dtg.h",
+  "op-attrs/ops/input_attrs.dtg.h",
+  "op-attrs/ops/layer_norm_attrs.dtg.h",
+  "op-attrs/ops/linear_attrs.dtg.h",
+  "op-attrs/ops/noop_attrs.dtg.h",
+  "op-attrs/ops/pool_2d_attrs.dtg.h",
+  "op-attrs/ops/reduce_attrs.dtg.h",
+  "op-attrs/ops/reduction_attrs.dtg.h",
+  "op-attrs/ops/repartition_attrs.dtg.h",
+  "op-attrs/ops/replicate_attrs.dtg.h",
+ "op-attrs/ops/reshape_attrs.dtg.h", + "op-attrs/ops/reverse_attrs.dtg.h", + "op-attrs/ops/softmax_attrs.dtg.h", + "op-attrs/ops/split_attrs.dtg.h", + "op-attrs/ops/topk_attrs.dtg.h", + "op-attrs/ops/transpose_attrs.dtg.h", +] + +[[values]] +type = "::FlexFlow::BatchMatmulAttrs" +key = "batch_matmul" + +[[values]] +type = "::FlexFlow::BatchNormAttrs" +key = "batch_norm" + +[[values]] +type = "::FlexFlow::CastAttrs" +key = "cast" + +[[values]] +type = "::FlexFlow::CombineAttrs" +key = "combine_distributed" + +[[values]] +type = "::FlexFlow::ConcatAttrs" +key = "concat" + +[[values]] +type = "::FlexFlow::Conv2DAttrs" +key = "conv2d" + +[[values]] +type = "::FlexFlow::DropoutAttrs" +key = "dropout" + +[[values]] +type = "::FlexFlow::ElementBinaryAttrs" +key = "element_binary" + +[[values]] +type = "::FlexFlow::ElementUnaryAttrs" +key = "element_unary" + +[[values]] +type = "::FlexFlow::ElementScalarUnaryAttrs" +key = "element_scalar_unary" + +[[values]] +type = "::FlexFlow::EmbeddingAttrs" +key = "embedding" + +[[values]] +type = "::FlexFlow::FlatAttrs" +key = "flat" + +[[values]] +type = "::FlexFlow::GatherAttrs" +key = "gather" + +[[values]] +type = "::FlexFlow::InputAttrs" +key = "input" + +[[values]] +type = "::FlexFlow::LayerNormAttrs" +key = "layer_norm" + +[[values]] +type = "::FlexFlow::LinearAttrs" +key = "linear" + +[[values]] +type = "::FlexFlow::MultiHeadAttentionAttrs" +key = "multi_head_attention" + +[[values]] +type = "::FlexFlow::NoopAttrs" +key = "noop" + +[[values]] +type = "::FlexFlow::Pool2DAttrs" +key = "pool2d" + +[[values]] +type = "::FlexFlow::ReduceAttrs" +key = "reduce" + +[[values]] +type = "::FlexFlow::ReductionAttrs" +key = "reduce_distributed" + +[[values]] +type = "::FlexFlow::RepartitionAttrs" +key = "partition_distributed" + +[[values]] +type = "::FlexFlow::ReplicateAttrs" +key = "replicate_distributed" + +[[values]] +type = "::FlexFlow::ReverseAttrs" +key = "reverse" + +[[values]] +type = "::FlexFlow::ReshapeAttrs" +key = "reshape" + +[[values]] +type = "::FlexFlow::SplitAttrs" +key = "split" + +[[values]] +type = "::FlexFlow::SoftmaxAttrs" +key = "softmax" + +[[values]] +type = "::FlexFlow::TopKAttrs" +key = "topk" + +[[values]] +type = "::FlexFlow::TransposeAttrs" +key = "transpose" diff --git a/lib/op-attrs/include/op-attrs/pool_op.dtg.h b/lib/op-attrs/include/op-attrs/pool_op.dtg.h new file mode 100644 index 0000000000..3511589b52 --- /dev/null +++ b/lib/op-attrs/include/op-attrs/pool_op.dtg.h @@ -0,0 +1,40 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! 
+// If you would like to modify this datatype, instead modify
+// lib/op-attrs/include/op-attrs/pool_op.enum.toml
+/* proj-data
+{
+  "generated_from": "ed1d531c6227306c909eb28eb0a66538"
+}
+*/
+
+#ifndef _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_POOL_OP_DTG_H
+#define _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_POOL_OP_DTG_H
+
+#include "fmt/format.h"
+#include "nlohmann/json.hpp"
+#include "rapidcheck.h"
+#include <functional>
+#include <ostream>
+#include <string>
+
+namespace FlexFlow {
+enum class PoolOp { MAX, AVG };
+std::string format_as(PoolOp);
+std::ostream &operator<<(std::ostream &, PoolOp);
+void to_json(::nlohmann::json &, PoolOp);
+void from_json(::nlohmann::json const &, PoolOp &);
+} // namespace FlexFlow
+namespace std {
+template <>
+struct hash<FlexFlow::PoolOp> {
+  size_t operator()(FlexFlow::PoolOp) const;
+};
+} // namespace std
+namespace rc {
+template <>
+struct Arbitrary<FlexFlow::PoolOp> {
+  static Gen<FlexFlow::PoolOp> arbitrary();
+};
+} // namespace rc
+
+#endif // _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_POOL_OP_DTG_H
diff --git a/lib/op-attrs/include/op-attrs/pool_op.enum.toml b/lib/op-attrs/include/op-attrs/pool_op.enum.toml
new file mode 100644
index 0000000000..88f4dfea19
--- /dev/null
+++ b/lib/op-attrs/include/op-attrs/pool_op.enum.toml
@@ -0,0 +1,14 @@
+namespace = "FlexFlow"
+name = "PoolOp"
+features = [
+  "hash",
+  "json",
+  "rapidcheck",
+  "fmt",
+]
+
+[[values]]
+name = "MAX"
+
+[[values]]
+name = "AVG"
diff --git a/lib/op-attrs/include/op-attrs/regularizer_attrs.dtg.h b/lib/op-attrs/include/op-attrs/regularizer_attrs.dtg.h
new file mode 100644
index 0000000000..2621b4b12c
--- /dev/null
+++ b/lib/op-attrs/include/op-attrs/regularizer_attrs.dtg.h
@@ -0,0 +1,128 @@
+// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT!
+// If you would like to modify this datatype, instead modify
+// lib/op-attrs/include/op-attrs/regularizer_attrs.variant.toml
+/* proj-data
+{
+  "generated_from": "ea060a8ab344c9772102f084903883ea"
+}
+*/
+
+#ifndef _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_REGULARIZER_ATTRS_DTG_H
+#define _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_REGULARIZER_ATTRS_DTG_H
+
+#include "fmt/format.h"
+#include "nlohmann/json.hpp"
+#include "op-attrs/l1_regularizer_attrs.dtg.h"
+#include "op-attrs/l2_regularizer_attrs.dtg.h"
+#include "rapidcheck.h"
+#include <cstddef>
+#include <functional>
+#include <ostream>
+#include <type_traits>
+#include <variant>
+
+namespace FlexFlow {
+struct RegularizerAttrs {
+  RegularizerAttrs() = delete;
+  explicit RegularizerAttrs(::FlexFlow::L1RegularizerAttrs const &);
+  explicit RegularizerAttrs(::FlexFlow::L2RegularizerAttrs const &);
+  template <typename T>
+  static constexpr bool IsPartOfRegularizerAttrs_v =
+      std::is_same_v<T, ::FlexFlow::L1RegularizerAttrs> ||
+      std::is_same_v<T, ::FlexFlow::L2RegularizerAttrs>;
+  template <typename ReturnType, typename Visitor>
+  ReturnType visit(Visitor &&v) const {
+    switch (this->index()) {
+      case 0: {
+        ReturnType result = v(this->get<::FlexFlow::L1RegularizerAttrs>());
+        return result;
+      }
+      case 1: {
+        ReturnType result = v(this->get<::FlexFlow::L2RegularizerAttrs>());
+        return result;
+      }
+      default: {
+        throw std::runtime_error(fmt::format(
+            "Unknown index {} for type RegularizerAttrs", this->index()));
+      }
+    }
+  }
+  template <typename ReturnType, typename Visitor>
+  ReturnType visit(Visitor &&v) {
+    switch (this->index()) {
+      case 0: {
+        ReturnType result = v(this->get<::FlexFlow::L1RegularizerAttrs>());
+        return result;
+      }
+      case 1: {
+        ReturnType result = v(this->get<::FlexFlow::L2RegularizerAttrs>());
+        return result;
+      }
+      default: {
+        throw std::runtime_error(fmt::format(
+            "Unknown index {} for type RegularizerAttrs", this->index()));
+      }
+    }
+  }
+  template <typename T>
+  bool has() const {
+    static_assert(IsPartOfRegularizerAttrs_v<T>,
+                  "RegularizerAttrs::has() expected one of "
"[::FlexFlow::L1RegularizerAttrs, " + "::FlexFlow::L2RegularizerAttrs], received T"); + return std::holds_alternative(this->raw_variant); + } + template + T const &get() const { + static_assert(IsPartOfRegularizerAttrs_v, + "RegularizerAttrs::get() expected one of " + "[::FlexFlow::L1RegularizerAttrs, " + "::FlexFlow::L2RegularizerAttrs], received T"); + return std::get(this->raw_variant); + } + template + T &get() { + static_assert(IsPartOfRegularizerAttrs_v, + "RegularizerAttrs::get() expected one of " + "[::FlexFlow::L1RegularizerAttrs, " + "::FlexFlow::L2RegularizerAttrs], received T"); + return std::get(this->raw_variant); + } + size_t index() const { + return this->raw_variant.index(); + } + bool operator==(RegularizerAttrs const &) const; + bool operator!=(RegularizerAttrs const &) const; + bool operator<(RegularizerAttrs const &) const; + bool operator>(RegularizerAttrs const &) const; + bool operator<=(RegularizerAttrs const &) const; + bool operator>=(RegularizerAttrs const &) const; + std::variant<::FlexFlow::L1RegularizerAttrs, ::FlexFlow::L2RegularizerAttrs> + raw_variant; +}; +} // namespace FlexFlow +namespace std { +template <> +struct hash<::FlexFlow::RegularizerAttrs> { + size_t operator()(::FlexFlow::RegularizerAttrs const &) const; +}; +} // namespace std +namespace nlohmann { +template <> +struct adl_serializer<::FlexFlow::RegularizerAttrs> { + static ::FlexFlow::RegularizerAttrs from_json(json const &); + static void to_json(json &, ::FlexFlow::RegularizerAttrs const &); +}; +} // namespace nlohmann +namespace rc { +template <> +struct Arbitrary<::FlexFlow::RegularizerAttrs> { + static Gen<::FlexFlow::RegularizerAttrs> arbitrary(); +}; +} // namespace rc +namespace FlexFlow { +std::string format_as(::FlexFlow::RegularizerAttrs const &); +std::ostream &operator<<(std::ostream &, ::FlexFlow::RegularizerAttrs const &); +} // namespace FlexFlow + +#endif // _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_REGULARIZER_ATTRS_DTG_H diff --git a/lib/op-attrs/include/op-attrs/regularizer_attrs.variant.toml b/lib/op-attrs/include/op-attrs/regularizer_attrs.variant.toml new file mode 100644 index 0000000000..d650c7f6a9 --- /dev/null +++ b/lib/op-attrs/include/op-attrs/regularizer_attrs.variant.toml @@ -0,0 +1,23 @@ +namespace = "FlexFlow" +name = "RegularizerAttrs" +features = [ + "eq", + "ord", + "hash", + "json", + "rapidcheck", + "fmt", +] + +includes = [ + "op-attrs/l1_regularizer_attrs.dtg.h", + "op-attrs/l2_regularizer_attrs.dtg.h", +] + +[[values]] +type = "::FlexFlow::L1RegularizerAttrs" +key = "l1" + +[[values]] +type = "::FlexFlow::L2RegularizerAttrs" +key = "l2" diff --git a/lib/op-attrs/include/op-attrs/replica_parallel_dim.dtg.h b/lib/op-attrs/include/op-attrs/replica_parallel_dim.dtg.h new file mode 100644 index 0000000000..250ba29947 --- /dev/null +++ b/lib/op-attrs/include/op-attrs/replica_parallel_dim.dtg.h @@ -0,0 +1,65 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! 
+// If you would like to modify this datatype, instead modify
+// lib/op-attrs/include/op-attrs/replica_parallel_dim.struct.toml
+/* proj-data
+{
+  "generated_from": "f501393070c8d55a05c43dd73a81a8d7"
+}
+*/
+
+#ifndef _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_REPLICA_PARALLEL_DIM_DTG_H
+#define _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_REPLICA_PARALLEL_DIM_DTG_H
+
+#include "fmt/format.h"
+#include "nlohmann/json.hpp"
+#include "op-attrs/replica_type.dtg.h"
+#include "rapidcheck.h"
+#include <functional>
+#include <ostream>
+#include <tuple>
+
+namespace FlexFlow {
+struct ReplicaParallelDim {
+  ReplicaParallelDim() = delete;
+  ReplicaParallelDim(int const &degree,
+                     ::FlexFlow::ReplicaType const &replica_type);
+
+  bool operator==(ReplicaParallelDim const &) const;
+  bool operator!=(ReplicaParallelDim const &) const;
+  bool operator<(ReplicaParallelDim const &) const;
+  bool operator>(ReplicaParallelDim const &) const;
+  bool operator<=(ReplicaParallelDim const &) const;
+  bool operator>=(ReplicaParallelDim const &) const;
+  int degree;
+  ::FlexFlow::ReplicaType replica_type;
+};
+} // namespace FlexFlow
+
+namespace std {
+template <>
+struct hash<FlexFlow::ReplicaParallelDim> {
+  size_t operator()(FlexFlow::ReplicaParallelDim const &) const;
+};
+} // namespace std
+
+namespace nlohmann {
+template <>
+struct adl_serializer<FlexFlow::ReplicaParallelDim> {
+  static FlexFlow::ReplicaParallelDim from_json(json const &);
+  static void to_json(json &, FlexFlow::ReplicaParallelDim const &);
+};
+} // namespace nlohmann
+
+namespace rc {
+template <>
+struct Arbitrary<FlexFlow::ReplicaParallelDim> {
+  static Gen<FlexFlow::ReplicaParallelDim> arbitrary();
+};
+} // namespace rc
+
+namespace FlexFlow {
+std::string format_as(ReplicaParallelDim const &);
+std::ostream &operator<<(std::ostream &, ReplicaParallelDim const &);
+} // namespace FlexFlow
+
+#endif // _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_REPLICA_PARALLEL_DIM_DTG_H
diff --git a/lib/op-attrs/include/op-attrs/replica_parallel_dim.h b/lib/op-attrs/include/op-attrs/replica_parallel_dim.h
new file mode 100644
index 0000000000..da3913b426
--- /dev/null
+++ b/lib/op-attrs/include/op-attrs/replica_parallel_dim.h
@@ -0,0 +1,12 @@
+#ifndef _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_REPLICA_PARALLEL_DIM_H
+#define _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_REPLICA_PARALLEL_DIM_H
+
+#include "op-attrs/replica_parallel_dim.dtg.h"
+
+namespace FlexFlow {
+
+bool is_valid(ReplicaParallelDim const &);
+
+} // namespace FlexFlow
+
+#endif
diff --git a/lib/op-attrs/include/op-attrs/replica_parallel_dim.struct.toml b/lib/op-attrs/include/op-attrs/replica_parallel_dim.struct.toml
new file mode 100644
index 0000000000..2ad442aa22
--- /dev/null
+++ b/lib/op-attrs/include/op-attrs/replica_parallel_dim.struct.toml
@@ -0,0 +1,22 @@
+namespace = "FlexFlow"
+name = "ReplicaParallelDim"
+features = [
+  "eq",
+  "ord",
+  "hash",
+  "json",
+  "rapidcheck",
+  "fmt",
+]
+
+includes = [
+  "op-attrs/replica_type.dtg.h",
+]
+
+[[fields]]
+name = "degree"
+type = "int"
+
+[[fields]]
+name = "replica_type"
+type = "::FlexFlow::ReplicaType"
diff --git a/lib/op-attrs/include/op-attrs/replica_parallel_dim_set.dtg.h b/lib/op-attrs/include/op-attrs/replica_parallel_dim_set.dtg.h
new file mode 100644
index 0000000000..321029347f
--- /dev/null
+++ b/lib/op-attrs/include/op-attrs/replica_parallel_dim_set.dtg.h
@@ -0,0 +1,67 @@
+// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT!
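+//
+// Usage sketch (illustrative only): the set fixes one degree per replica
+// kind; a fully-unreplicated tensor would use degree 1 for both:
+//   ReplicaParallelDimSet replicas{SumDegree{1}, DiscardCopyDegree{1}};
+//   // or: empty_replica_parallel_dim_set() from replica_parallel_dim_set.h
+//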
+// If you would like to modify this datatype, instead modify +// lib/op-attrs/include/op-attrs/replica_parallel_dim_set.struct.toml +/* proj-data +{ + "generated_from": "74230e2d18db5c059d3e7be0f25e746e" +} +*/ + +#ifndef _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_REPLICA_PARALLEL_DIM_SET_DTG_H +#define _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_REPLICA_PARALLEL_DIM_SET_DTG_H + +#include "fmt/format.h" +#include "nlohmann/json.hpp" +#include "op-attrs/parallel_tensor_shape/discard_copy_degree.dtg.h" +#include "op-attrs/parallel_tensor_shape/sum_degree.dtg.h" +#include "rapidcheck.h" +#include +#include +#include + +namespace FlexFlow { +struct ReplicaParallelDimSet { + ReplicaParallelDimSet() = delete; + ReplicaParallelDimSet( + ::FlexFlow::SumDegree const &sum_degree, + ::FlexFlow::DiscardCopyDegree const &discard_copy_degree); + + bool operator==(ReplicaParallelDimSet const &) const; + bool operator!=(ReplicaParallelDimSet const &) const; + bool operator<(ReplicaParallelDimSet const &) const; + bool operator>(ReplicaParallelDimSet const &) const; + bool operator<=(ReplicaParallelDimSet const &) const; + bool operator>=(ReplicaParallelDimSet const &) const; + ::FlexFlow::SumDegree sum_degree; + ::FlexFlow::DiscardCopyDegree discard_copy_degree; +}; +} // namespace FlexFlow + +namespace std { +template <> +struct hash { + size_t operator()(FlexFlow::ReplicaParallelDimSet const &) const; +}; +} // namespace std + +namespace nlohmann { +template <> +struct adl_serializer { + static FlexFlow::ReplicaParallelDimSet from_json(json const &); + static void to_json(json &, FlexFlow::ReplicaParallelDimSet const &); +}; +} // namespace nlohmann + +namespace rc { +template <> +struct Arbitrary { + static Gen arbitrary(); +}; +} // namespace rc + +namespace FlexFlow { +std::string format_as(ReplicaParallelDimSet const &); +std::ostream &operator<<(std::ostream &, ReplicaParallelDimSet const &); +} // namespace FlexFlow + +#endif // _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_REPLICA_PARALLEL_DIM_SET_DTG_H diff --git a/lib/op-attrs/include/op-attrs/replica_parallel_dim_set.h b/lib/op-attrs/include/op-attrs/replica_parallel_dim_set.h new file mode 100644 index 0000000000..74a8df339b --- /dev/null +++ b/lib/op-attrs/include/op-attrs/replica_parallel_dim_set.h @@ -0,0 +1,18 @@ +#ifndef _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_REPLICA_PARALLEL_DIM_SET_H +#define _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_REPLICA_PARALLEL_DIM_SET_H + +#include "op-attrs/replica_parallel_dim.dtg.h" +#include "op-attrs/replica_parallel_dim_set.dtg.h" +#include "op-attrs/replica_type.dtg.h" + +namespace FlexFlow { + +ReplicaParallelDimSet empty_replica_parallel_dim_set(); +int get_degree_of_replica_type(ReplicaParallelDimSet const &, ReplicaType); +std::unordered_set + get_replica_dims(ReplicaParallelDimSet const &); +bool is_valid(ReplicaParallelDimSet const &); + +} // namespace FlexFlow + +#endif diff --git a/lib/op-attrs/include/op-attrs/replica_parallel_dim_set.struct.toml b/lib/op-attrs/include/op-attrs/replica_parallel_dim_set.struct.toml new file mode 100644 index 0000000000..66f50bee9f --- /dev/null +++ b/lib/op-attrs/include/op-attrs/replica_parallel_dim_set.struct.toml @@ -0,0 +1,23 @@ +namespace = "FlexFlow" +name = "ReplicaParallelDimSet" +features = [ + "eq", + "ord", + "hash", + "json", + "rapidcheck", + "fmt", +] + +includes = [ + "op-attrs/parallel_tensor_shape/sum_degree.dtg.h", + "op-attrs/parallel_tensor_shape/discard_copy_degree.dtg.h", +] + +[[fields]] +name = "sum_degree" +type = "::FlexFlow::SumDegree" + +[[fields]] 
+name = "discard_copy_degree" +type = "::FlexFlow::DiscardCopyDegree" diff --git a/lib/op-attrs/include/op-attrs/replica_type.dtg.h b/lib/op-attrs/include/op-attrs/replica_type.dtg.h new file mode 100644 index 0000000000..3b965d3e77 --- /dev/null +++ b/lib/op-attrs/include/op-attrs/replica_type.dtg.h @@ -0,0 +1,40 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/op-attrs/include/op-attrs/replica_type.enum.toml +/* proj-data +{ + "generated_from": "6ecba7a6851b8bea93705bba24661149" +} +*/ + +#ifndef _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_REPLICA_TYPE_DTG_H +#define _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_REPLICA_TYPE_DTG_H + +#include "fmt/format.h" +#include "nlohmann/json.hpp" +#include "rapidcheck.h" +#include +#include +#include + +namespace FlexFlow { +enum class ReplicaType { SUM, DISCARD_COPY }; +std::string format_as(ReplicaType); +std::ostream &operator<<(std::ostream &, ReplicaType); +void to_json(::nlohmann::json &, ReplicaType); +void from_json(::nlohmann::json const &, ReplicaType &); +} // namespace FlexFlow +namespace std { +template <> +struct hash { + size_t operator()(FlexFlow::ReplicaType) const; +}; +} // namespace std +namespace rc { +template <> +struct Arbitrary { + static Gen arbitrary(); +}; +} // namespace rc + +#endif // _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_REPLICA_TYPE_DTG_H diff --git a/lib/op-attrs/include/op-attrs/replica_type.enum.toml b/lib/op-attrs/include/op-attrs/replica_type.enum.toml new file mode 100644 index 0000000000..0c0eb5e3ab --- /dev/null +++ b/lib/op-attrs/include/op-attrs/replica_type.enum.toml @@ -0,0 +1,14 @@ +namespace = "FlexFlow" +name = "ReplicaType" +features = [ + "hash", + "json", + "rapidcheck", + "fmt", +] + +[[values]] +name = "SUM" + +[[values]] +name = "DISCARD_COPY" diff --git a/lib/op-attrs/include/op-attrs/shard_parallel_dim.dtg.h b/lib/op-attrs/include/op-attrs/shard_parallel_dim.dtg.h new file mode 100644 index 0000000000..631852c259 --- /dev/null +++ b/lib/op-attrs/include/op-attrs/shard_parallel_dim.dtg.h @@ -0,0 +1,63 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! 
+// If you would like to modify this datatype, instead modify +// lib/op-attrs/include/op-attrs/shard_parallel_dim.struct.toml +/* proj-data +{ + "generated_from": "18e074f80556d90b9b27d6515bbf9071" +} +*/ + +#ifndef _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_SHARD_PARALLEL_DIM_DTG_H +#define _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_SHARD_PARALLEL_DIM_DTG_H + +#include "fmt/format.h" +#include "nlohmann/json.hpp" +#include "rapidcheck.h" +#include +#include +#include + +namespace FlexFlow { +struct ShardParallelDim { + ShardParallelDim() = delete; + ShardParallelDim(size_t const &size, int const °ree); + + bool operator==(ShardParallelDim const &) const; + bool operator!=(ShardParallelDim const &) const; + bool operator<(ShardParallelDim const &) const; + bool operator>(ShardParallelDim const &) const; + bool operator<=(ShardParallelDim const &) const; + bool operator>=(ShardParallelDim const &) const; + size_t size; + int degree; +}; +} // namespace FlexFlow + +namespace std { +template <> +struct hash { + size_t operator()(FlexFlow::ShardParallelDim const &) const; +}; +} // namespace std + +namespace nlohmann { +template <> +struct adl_serializer { + static FlexFlow::ShardParallelDim from_json(json const &); + static void to_json(json &, FlexFlow::ShardParallelDim const &); +}; +} // namespace nlohmann + +namespace rc { +template <> +struct Arbitrary { + static Gen arbitrary(); +}; +} // namespace rc + +namespace FlexFlow { +std::string format_as(ShardParallelDim const &); +std::ostream &operator<<(std::ostream &, ShardParallelDim const &); +} // namespace FlexFlow + +#endif // _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_SHARD_PARALLEL_DIM_DTG_H diff --git a/lib/op-attrs/include/op-attrs/shard_parallel_dim.h b/lib/op-attrs/include/op-attrs/shard_parallel_dim.h new file mode 100644 index 0000000000..0a6323192d --- /dev/null +++ b/lib/op-attrs/include/op-attrs/shard_parallel_dim.h @@ -0,0 +1,12 @@ +#ifndef _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_SHARD_PARALLEL_DIM_H +#define _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_SHARD_PARALLEL_DIM_H + +#include "op-attrs/shard_parallel_dim.dtg.h" + +namespace FlexFlow { + +bool is_valid(ShardParallelDim const &); + +} // namespace FlexFlow + +#endif diff --git a/lib/op-attrs/include/op-attrs/shard_parallel_dim.struct.toml b/lib/op-attrs/include/op-attrs/shard_parallel_dim.struct.toml new file mode 100644 index 0000000000..21c81396d1 --- /dev/null +++ b/lib/op-attrs/include/op-attrs/shard_parallel_dim.struct.toml @@ -0,0 +1,18 @@ +namespace = "FlexFlow" +name = "ShardParallelDim" +features = [ + "eq", + "ord", + "hash", + "json", + "rapidcheck", + "fmt", +] + +[[fields]] +name = "size" +type = "size_t" + +[[fields]] +name = "degree" +type = "int" diff --git a/lib/op-attrs/include/op-attrs/tensor_dims.dtg.h b/lib/op-attrs/include/op-attrs/tensor_dims.dtg.h new file mode 100644 index 0000000000..a8e46a4626 --- /dev/null +++ b/lib/op-attrs/include/op-attrs/tensor_dims.dtg.h @@ -0,0 +1,63 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! 
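A ShardParallelDim pairs a dimension size with a parallel degree. A common consistency check on such a pair, given as an assumption rather than the repo's definition of is_valid:

    #include "op-attrs/shard_parallel_dim.dtg.h"
    #include <cstddef>

    // Hypothetical helper: a shard degree usually has to divide the
    // dimension size evenly for the partition to be well-formed.
    bool divides_evenly(FlexFlow::ShardParallelDim const &d) {
      return d.degree >= 1 &&
             d.size % static_cast<std::size_t>(d.degree) == 0;
    }
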
+// If you would like to modify this datatype, instead modify +// lib/op-attrs/include/op-attrs/tensor_dims.struct.toml +/* proj-data +{ + "generated_from": "5beb89eeae9eba303f90e726c794375d" +} +*/ + +#ifndef _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_TENSOR_DIMS_DTG_H +#define _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_TENSOR_DIMS_DTG_H + +#include "fmt/format.h" +#include "nlohmann/json.hpp" +#include "op-attrs/dim_ordered.h" +#include "rapidcheck.h" +#include +#include +#include + +namespace FlexFlow { +struct TensorDims { + TensorDims() = delete; + TensorDims(::FlexFlow::FFOrdered const &ff_ordered); + + bool operator==(TensorDims const &) const; + bool operator!=(TensorDims const &) const; + bool operator<(TensorDims const &) const; + bool operator>(TensorDims const &) const; + bool operator<=(TensorDims const &) const; + bool operator>=(TensorDims const &) const; + ::FlexFlow::FFOrdered ff_ordered; +}; +} // namespace FlexFlow + +namespace std { +template <> +struct hash { + size_t operator()(FlexFlow::TensorDims const &) const; +}; +} // namespace std + +namespace nlohmann { +template <> +struct adl_serializer { + static FlexFlow::TensorDims from_json(json const &); + static void to_json(json &, FlexFlow::TensorDims const &); +}; +} // namespace nlohmann + +namespace rc { +template <> +struct Arbitrary { + static Gen arbitrary(); +}; +} // namespace rc + +namespace FlexFlow { +std::string format_as(TensorDims const &); +std::ostream &operator<<(std::ostream &, TensorDims const &); +} // namespace FlexFlow + +#endif // _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_TENSOR_DIMS_DTG_H diff --git a/lib/op-attrs/include/op-attrs/tensor_dims.h b/lib/op-attrs/include/op-attrs/tensor_dims.h new file mode 100644 index 0000000000..2391197471 --- /dev/null +++ b/lib/op-attrs/include/op-attrs/tensor_dims.h @@ -0,0 +1,24 @@ +#ifndef _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_TENSOR_DIMS_H +#define _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_TENSOR_DIMS_H + +#include "op-attrs/parallel_tensor_dims.dtg.h" +#include "op-attrs/tensor_dims.dtg.h" + +namespace FlexFlow { + +FFOrdered const &ff_ordered(TensorDims const &); + +size_t num_dims(TensorDims const &); +size_t dim_at_idx(TensorDims const &, ff_dim_t); +size_t &dim_at_idx(TensorDims &, ff_dim_t); + +ParallelTensorDims lift_to_parallel(TensorDims const &); +ParallelTensorDims + lift_to_parallel_with_degrees(TensorDims const &, + SumDegree sum_degree, + DiscardCopyDegree discard_copy_degree, + FFOrdered const &shard_degrees); + +} // namespace FlexFlow + +#endif diff --git a/lib/op-attrs/include/op-attrs/tensor_dims.struct.toml b/lib/op-attrs/include/op-attrs/tensor_dims.struct.toml new file mode 100644 index 0000000000..cff8e08b0f --- /dev/null +++ b/lib/op-attrs/include/op-attrs/tensor_dims.struct.toml @@ -0,0 +1,17 @@ +namespace = "FlexFlow" +name = "TensorDims" +features = [ + "eq", + "ord", + "hash", + "json", + "rapidcheck", + "fmt", +] +includes = [ + "op-attrs/dim_ordered.h", +] + +[[fields]] +name = "ff_ordered" +type = "::FlexFlow::FFOrdered" diff --git a/lib/op-attrs/include/op-attrs/tensor_shape.dtg.h b/lib/op-attrs/include/op-attrs/tensor_shape.dtg.h new file mode 100644 index 0000000000..f36d5d1306 --- /dev/null +++ b/lib/op-attrs/include/op-attrs/tensor_shape.dtg.h @@ -0,0 +1,66 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! 
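tensor_dims.h above only declares its helpers; a hedged sketch of the simplest one, assuming FFOrdered<size_t> exposes a size() accessor (its definition in dim_ordered.h is not part of this patch):

    #include "op-attrs/tensor_dims.dtg.h"

    namespace FlexFlow {

    // Sketch only: delegates to the underlying ordered container.
    size_t num_dims(TensorDims const &d) {
      return d.ff_ordered.size();
    }

    } // namespace FlexFlow
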
+// If you would like to modify this datatype, instead modify +// lib/op-attrs/include/op-attrs/tensor_shape.struct.toml +/* proj-data +{ + "generated_from": "ef6fa5088b89d6da4dc8bddf0a6d3294" +} +*/ + +#ifndef _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_TENSOR_SHAPE_DTG_H +#define _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_TENSOR_SHAPE_DTG_H + +#include "fmt/format.h" +#include "nlohmann/json.hpp" +#include "op-attrs/datatype.dtg.h" +#include "op-attrs/tensor_dims.dtg.h" +#include "rapidcheck.h" +#include +#include +#include + +namespace FlexFlow { +struct TensorShape { + TensorShape() = delete; + TensorShape(::FlexFlow::TensorDims const &dims, + ::FlexFlow::DataType const &data_type); + + bool operator==(TensorShape const &) const; + bool operator!=(TensorShape const &) const; + bool operator<(TensorShape const &) const; + bool operator>(TensorShape const &) const; + bool operator<=(TensorShape const &) const; + bool operator>=(TensorShape const &) const; + ::FlexFlow::TensorDims dims; + ::FlexFlow::DataType data_type; +}; +} // namespace FlexFlow + +namespace std { +template <> +struct hash { + size_t operator()(FlexFlow::TensorShape const &) const; +}; +} // namespace std + +namespace nlohmann { +template <> +struct adl_serializer { + static FlexFlow::TensorShape from_json(json const &); + static void to_json(json &, FlexFlow::TensorShape const &); +}; +} // namespace nlohmann + +namespace rc { +template <> +struct Arbitrary { + static Gen arbitrary(); +}; +} // namespace rc + +namespace FlexFlow { +std::string format_as(TensorShape const &); +std::ostream &operator<<(std::ostream &, TensorShape const &); +} // namespace FlexFlow + +#endif // _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_TENSOR_SHAPE_DTG_H diff --git a/lib/op-attrs/include/op-attrs/tensor_shape.h b/lib/op-attrs/include/op-attrs/tensor_shape.h index fa34860817..ad751461e8 100644 --- a/lib/op-attrs/include/op-attrs/tensor_shape.h +++ b/lib/op-attrs/include/op-attrs/tensor_shape.h @@ -1,34 +1,14 @@ #ifndef _FLEXFLOW_OPATTRS_TENSOR_SHAPE_H #define _FLEXFLOW_OPATTRS_TENSOR_SHAPE_H -#include "datatype.h" -#include "op-attrs/dim_ordered.h" -#include "op-attrs/ff_dim.h" -#include "utils/stack_vector.h" -#include "utils/visitable.h" +#include "op-attrs/tensor_shape.dtg.h" namespace FlexFlow { -using TensorDims = FFOrdered; - -struct TensorShape : public use_visitable_cmp { - TensorShape() = delete; - - template - TensorShape(Dims const &dims, DataType data_type) - : dims(dims), data_type(data_type) {} - - size_t at(ff_dim_t) const; - size_t operator[](ff_dim_t) const; - -public: - TensorDims dims; - DataType data_type; -}; +size_t num_dims(TensorShape const &); +size_t dim_at_idx(TensorShape const &, ff_dim_t); +size_t &dim_at_idx(TensorShape &, ff_dim_t); } // namespace FlexFlow -VISITABLE_STRUCT(::FlexFlow::TensorShape, dims, data_type); -MAKE_VISIT_HASHABLE(::FlexFlow::TensorShape); - #endif diff --git a/lib/op-attrs/include/op-attrs/tensor_shape.struct.toml b/lib/op-attrs/include/op-attrs/tensor_shape.struct.toml new file mode 100644 index 0000000000..901c3b9e60 --- /dev/null +++ b/lib/op-attrs/include/op-attrs/tensor_shape.struct.toml @@ -0,0 +1,23 @@ +namespace = "FlexFlow" +name = "TensorShape" +features = [ + "eq", + "ord", + "hash", + "json", + "rapidcheck", + "fmt", +] + +includes = [ + "op-attrs/tensor_dims.dtg.h", + "op-attrs/datatype.dtg.h", +] + +[[fields]] +name = "dims" +type = "::FlexFlow::TensorDims" + +[[fields]] +name = "data_type" +type = "::FlexFlow::DataType" diff --git a/lib/op-attrs/src/batch_matmul.cc 
b/lib/op-attrs/src/batch_matmul.cc deleted file mode 100644 index 1cc8c5cfda..0000000000 --- a/lib/op-attrs/src/batch_matmul.cc +++ /dev/null @@ -1,26 +0,0 @@ -#include "op-attrs/ops/batch_matmul.h" - -namespace FlexFlow { - -/* bool BatchMatmulAttrs::is_valid( */ -/* ParallelTensorShape const &lhs, ParallelTensorShape const &rhs) const { - */ -/* if (!lhs.is_valid() || !rhs.is_valid()) { */ -/* return false; */ -/* } */ -/* if (lhs.num_dims() != rhs.num_dims()) { */ -/* return false; */ -/* } */ -/* for (int i = lhs.num_dims() - 1; i >= 2; i--) { */ -/* if (lhs.at(i) != rhs.at(i)) { */ -/* return false; */ -/* } */ -/* } */ -/* if (lhs.at(0) != rhs.at(1)) { */ -/* return false; */ -/* } */ - -/* return true; */ -/* } */ - -} // namespace FlexFlow diff --git a/lib/op-attrs/src/batch_norm.cc b/lib/op-attrs/src/batch_norm.cc deleted file mode 100644 index 4e352d5f1c..0000000000 --- a/lib/op-attrs/src/batch_norm.cc +++ /dev/null @@ -1,3 +0,0 @@ -#include "op-attrs/ops/batch_norm.h" - -namespace FlexFlow {} // namespace FlexFlow diff --git a/lib/op-attrs/src/broadcast.cc b/lib/op-attrs/src/broadcast.cc deleted file mode 100644 index c69f480b84..0000000000 --- a/lib/op-attrs/src/broadcast.cc +++ /dev/null @@ -1,3 +0,0 @@ -#include "op-attrs/ops/broadcast.h" - -namespace FlexFlow {} // namespace FlexFlow diff --git a/lib/op-attrs/src/combine.cc b/lib/op-attrs/src/combine.cc deleted file mode 100644 index cdca524538..0000000000 --- a/lib/op-attrs/src/combine.cc +++ /dev/null @@ -1,18 +0,0 @@ -#include "op-attrs/ops/combine.h" -#include "utils/hash-utils.h" - -namespace FlexFlow { - -/* bool CombineAttrs::is_valid(ParallelTensorShape const &input) const { */ -/* return input.at(this->combine_legion_dim).degree % this->combine_degree == - * 0; */ -/* } */ - -/* ParallelTensorShape CombineAttrs::output_shape(ParallelTensorShape const - * &input_shape) const { */ -/* ParallelTensorShape output = input_shape; */ -/* output.at(this->combine_legion_dim).degree /= this->combine_degree; */ -/* return output; */ -/* } */ - -} // namespace FlexFlow diff --git a/lib/op-attrs/src/conv_2d.cc b/lib/op-attrs/src/conv_2d.cc deleted file mode 100644 index 40ba3c8b41..0000000000 --- a/lib/op-attrs/src/conv_2d.cc +++ /dev/null @@ -1,115 +0,0 @@ -#include "op-attrs/ops/conv_2d.h" -#include "parallel_dim_mapping_record.h" -#include "parallel_dim_mapping_record_solver.h" -#include "utils/vector.h" - -namespace FlexFlow { - -namespace Input { -constexpr int WIDTH = 0, HEIGHT = 1, CHANNEL = 2, SAMPLE = 3, REPLICA = 4, - NUMDIM = 5; -} - -namespace Output { -constexpr int WIDTH = 0, HEIGHT = 1, CHANNEL = 2, SAMPLE = 3, REPLICA = 4, - NUMDIM = 5; -} - -namespace Kernel { -constexpr int WIDTH = 0, HEIGHT = 1, CHANNEL_IN = 2, CHANNEL_OUT = 3, - REPLICA = 4; -constexpr int WEIGHT_IDX = 0; -} // namespace Kernel - -namespace Bias { -constexpr int CHANNEL = 0, REPLICA_1 = 1, REPLICA_2 = 2, REPLICA_3 = 3, - REPLICA_4 = 4; -constexpr int WEIGHT_IDX = 1; -} // namespace Bias - -static std::vector - construct_output_mappings(ParallelTensorShape const &input_shape) { - return construct_output_parallel_dims( - {{Input::CHANNEL, MappingOperation::REPLICATE, Output::REPLICA}, - {Input::SAMPLE, MappingOperation::PARTITION, Output::SAMPLE}, - {Input::REPLICA, MappingOperation::PARTITION, Output::CHANNEL}, - {Input::HEIGHT, MappingOperation::PARTITION, Output::HEIGHT}, - {Input::WIDTH, MappingOperation::PARTITION, Output::WIDTH}}); -} - -static std::vector - construct_kernel_mappings(ParallelTensorShape const &input_shape) { - return 
construct_weight_parallel_dims( - { - {Input::REPLICA, MappingOperation::PARTITION, Kernel::CHANNEL_OUT}, - {Input::SAMPLE, MappingOperation::REPLICATE, Kernel::REPLICA}, - {Input::CHANNEL, MappingOperation::PARTITION, Kernel::CHANNEL_IN}, - {Input::HEIGHT, - MappingOperation::REPLICATE, - Kernel::HEIGHT}, // Kernel::{HEIGHT, WEIGHT} would both work - // here - {Input::WIDTH, - MappingOperation::REPLICATE, - Kernel::WIDTH}, // same as above - }, - 0, - Kernel::WEIGHT_IDX); -} - -static std::vector - construct_bias_mappings(ParallelTensorShape const &input_shape) { - return construct_weight_parallel_dims({{Input::REPLICA, Bias::REPLICA_1}, - {Input::SAMPLE, Bias::REPLICA_2}, - {Input::CHANNEL, Bias::CHANNEL}, - {Input::HEIGHT, Bias::REPLICA_3}, - {Input::WIDTH, Bias::REPLICA_4}}, - 0, - Bias::WEIGHT_IDX); -} - -std::vector - construct_mappings(ParallelTensorShape const &input_shape, bool use_bias) { - std::vector mappings = - concat(construct_output_mappings(input_shape), - construct_kernel_mappings(input_shape)); - if (use_bias) { - std::vector bias_mappings = - construct_bias_mappings(input_shape); - mappings.insert(mappings.end(), bias_mappings.begin(), bias_mappings.end()); - } - - return mappings; -} - -TensorShape get_kernel_shape(Conv2DAttrs const &, TensorShape const &) { - NOT_IMPLEMENTED(); -} - -TensorShape get_bias_shape(Conv2DAttrs const &, TensorShape const &) { - NOT_IMPLEMENTED(); -} - -/* bool Conv2DAttrs::is_valid(ParallelTensorShape const &input_shape) const { */ -/* bool is_valid = true; */ -/* is_valid &= input_shape.is_valid(); */ -/* is_valid &= this->calculate_output_shape(input_shape).is_valid(); */ -/* is_valid &= this->calculate_kernel_shape(input_shape).is_valid(); */ -/* if (use_bias) { */ -/* is_valid &= this->calculate_bias_shape(input_shape).is_valid(); */ -/* } */ - -/* // TODO FIXME: Currently disable parallelizing the height and width - * dimension */ -/* if (input_shape.at(0).degree > 1 || input_shape.at(1).degree > 1) { */ -/* return false; */ -/* } */ - -/* return is_valid; */ - -/* } */ - -/* OperatorType Conv2DAttrs::op_type() const { */ -/* return OP_CONV2D; */ -/* } */ - -} // namespace FlexFlow diff --git a/lib/op-attrs/src/element_binary.cc b/lib/op-attrs/src/element_binary.cc deleted file mode 100644 index b713c6753f..0000000000 --- a/lib/op-attrs/src/element_binary.cc +++ /dev/null @@ -1,3 +0,0 @@ -#include "op-attrs/ops/element_binary.h" - -namespace FlexFlow {} // namespace FlexFlow diff --git a/lib/op-attrs/src/element_unary.cc b/lib/op-attrs/src/element_unary.cc deleted file mode 100644 index 481151fafb..0000000000 --- a/lib/op-attrs/src/element_unary.cc +++ /dev/null @@ -1,3 +0,0 @@ -#include "op-attrs/ops/element_unary.h" - -namespace FlexFlow {} // namespace FlexFlow diff --git a/lib/op-attrs/src/embedding.cc b/lib/op-attrs/src/embedding.cc deleted file mode 100644 index 56014fcc67..0000000000 --- a/lib/op-attrs/src/embedding.cc +++ /dev/null @@ -1,9 +0,0 @@ -#include "op-attrs/ops/embedding.h" - -namespace FlexFlow { - -TensorShape get_weights_shape(EmbeddingAttrs const &, TensorShape const &) { - NOT_IMPLEMENTED(); -} - -} // namespace FlexFlow diff --git a/lib/op-attrs/src/linear.cc b/lib/op-attrs/src/linear.cc deleted file mode 100644 index 16a94e7f6c..0000000000 --- a/lib/op-attrs/src/linear.cc +++ /dev/null @@ -1,3 +0,0 @@ -#include "op-attrs/ops/linear.h" - -namespace FlexFlow {} // namespace FlexFlow diff --git a/lib/op-attrs/src/noop.cc b/lib/op-attrs/src/noop.cc deleted file mode 100644 index 387660164f..0000000000 --- 
a/lib/op-attrs/src/noop.cc
+++ /dev/null
@@ -1 +0,0 @@
-#include "op-attrs/ops/noop.h"
diff --git a/lib/op-attrs/src/op-attrs/activation.dtg.cc b/lib/op-attrs/src/op-attrs/activation.dtg.cc
new file mode 100644
index 0000000000..5671b1720f
--- /dev/null
+++ b/lib/op-attrs/src/op-attrs/activation.dtg.cc
@@ -0,0 +1,86 @@
+// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT!
+// If you would like to modify this datatype, instead modify
+// lib/op-attrs/include/op-attrs/activation.enum.toml
+/* proj-data
+{
+  "generated_from": "2b0d2e3e825732838aa5be99f2f0e6df"
+}
+*/
+
+#include "op-attrs/activation.dtg.h"
+
+#include <sstream>
+#include <stdexcept>
+
+namespace std {
+size_t hash<FlexFlow::Activation>::operator()(FlexFlow::Activation x) const {
+  return std::hash<int>{}(static_cast<int>(x));
+}
+} // namespace std
+namespace FlexFlow {
+std::string format_as(Activation x) {
+  switch (x) {
+    case Activation::RELU:
+      return "RELU";
+    case Activation::SIGMOID:
+      return "SIGMOID";
+    case Activation::TANH:
+      return "TANH";
+    case Activation::GELU:
+      return "GELU";
+    default:
+      std::ostringstream oss;
+      oss << "Unknown Activation value " << static_cast<int>(x);
+      throw std::runtime_error(oss.str());
+  }
+}
+std::ostream &operator<<(std::ostream &s, Activation x) {
+  return s << fmt::to_string(x);
+}
+} // namespace FlexFlow
+namespace FlexFlow {
+void to_json(::nlohmann::json &j, Activation x) {
+  switch (x) {
+    case Activation::RELU:
+      j = "RELU";
+      break;
+    case Activation::SIGMOID:
+      j = "SIGMOID";
+      break;
+    case Activation::TANH:
+      j = "TANH";
+      break;
+    case Activation::GELU:
+      j = "GELU";
+      break;
+    default:
+      std::ostringstream oss;
+      oss << "Unknown Activation value " << static_cast<int>(x);
+      throw std::runtime_error(oss.str());
+  }
+}
+void from_json(::nlohmann::json const &j, Activation &x) {
+  std::string as_str = j.get<std::string>();
+  if (as_str == "RELU") {
+    x = Activation::RELU;
+  } else if (as_str == "SIGMOID") {
+    x = Activation::SIGMOID;
+  } else if (as_str == "TANH") {
+    x = Activation::TANH;
+  } else if (as_str == "GELU") {
+    x = Activation::GELU;
+  } else {
+    std::ostringstream oss;
+    oss << "Unknown Activation value " << as_str;
+    throw std::runtime_error(oss.str());
+  }
+}
+} // namespace FlexFlow
+namespace rc {
+Gen<FlexFlow::Activation> Arbitrary<FlexFlow::Activation>::arbitrary() {
+  return gen::element(FlexFlow::Activation::RELU,
+                      FlexFlow::Activation::SIGMOID,
+                      FlexFlow::Activation::TANH,
+                      FlexFlow::Activation::GELU);
+}
+} // namespace rc
diff --git a/lib/op-attrs/src/op-attrs/aggregate_op.dtg.cc b/lib/op-attrs/src/op-attrs/aggregate_op.dtg.cc
new file mode 100644
index 0000000000..72beeb27c8
--- /dev/null
+++ b/lib/op-attrs/src/op-attrs/aggregate_op.dtg.cc
@@ -0,0 +1,62 @@
+// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT!
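Everything in activation.dtg.cc funnels through format_as, so streaming and fmt formatting agree by construction; a usage sketch (illustrative, not part of the patch):

    #include "op-attrs/activation.dtg.h"
    #include <iostream>

    int main() {
      FlexFlow::Activation a = FlexFlow::Activation::GELU;
      std::cout << a << "\n";   // prints "GELU" via operator<< -> format_as
      nlohmann::json j;
      FlexFlow::to_json(j, a);  // j now holds the string "GELU"
      std::cout << j.dump() << "\n";
    }
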
+// If you would like to modify this datatype, instead modify +// lib/op-attrs/include/op-attrs/aggregate_op.enum.toml +/* proj-data +{ + "generated_from": "441fe9b0bb8f2dc2b31f74c58320ef30" +} +*/ + +#include "op-attrs/aggregate_op.dtg.h" + +#include +#include + +namespace std { +size_t hash::operator()(FlexFlow::AggregateOp x) const { + return std::hash{}(static_cast(x)); +} +} // namespace std +namespace FlexFlow { +std::string format_as(AggregateOp x) { + switch (x) { + case AggregateOp::SUM: + return "SUM"; + default: + std::ostringstream oss; + oss << "Unknown AggregateOp value " << static_cast(x); + throw std::runtime_error(oss.str()); + } +} +std::ostream &operator<<(std::ostream &s, AggregateOp x) { + return s << fmt::to_string(x); +} +} // namespace FlexFlow +namespace FlexFlow { +void to_json(::nlohmann::json &j, AggregateOp x) { + switch (x) { + case AggregateOp::SUM: + j = "SUM"; + break; + default: + std::ostringstream oss; + oss << "Unknown AggregateOp value " << static_cast(x); + throw std::runtime_error(oss.str()); + } +} +void from_json(::nlohmann::json const &j, AggregateOp &x) { + std::string as_str = j.get(); + if (as_str == "SUM") { + x = AggregateOp::SUM; + } else { + std::ostringstream oss; + oss << "Unknown AggregateOp value " << as_str; + throw std::runtime_error(oss.str()); + } +} +} // namespace FlexFlow +namespace rc { +Gen Arbitrary::arbitrary() { + return gen::element(FlexFlow::AggregateOp::SUM); +} +} // namespace rc diff --git a/lib/op-attrs/src/op-attrs/as_dot.cc b/lib/op-attrs/src/op-attrs/as_dot.cc new file mode 100644 index 0000000000..f8d05de941 --- /dev/null +++ b/lib/op-attrs/src/op-attrs/as_dot.cc @@ -0,0 +1,13 @@ +#include "op-attrs/as_dot.h" + +namespace FlexFlow { + +RecordFormatter as_dot(ComputationGraphOpAttrs const &attrs) { + NOT_IMPLEMENTED(); +} + +RecordFormatter as_dot(PCGOperatorAttrs const &attrs) { + NOT_IMPLEMENTED(); +} + +} // namespace FlexFlow diff --git a/lib/op-attrs/src/op-attrs/computation_graph_op_attrs.cc b/lib/op-attrs/src/op-attrs/computation_graph_op_attrs.cc new file mode 100644 index 0000000000..166416cbad --- /dev/null +++ b/lib/op-attrs/src/op-attrs/computation_graph_op_attrs.cc @@ -0,0 +1,11 @@ +#include "op-attrs/computation_graph_op_attrs.h" +#include "op-attrs/get_op_type.h" + +namespace FlexFlow { + +OperatorType get_op_type(ComputationGraphOpAttrs const &attrs) { + return attrs.visit( + [](auto const &x) { return get_op_type(x); }); +} + +} // namespace FlexFlow diff --git a/lib/op-attrs/src/op-attrs/computation_graph_op_attrs.dtg.cc b/lib/op-attrs/src/op-attrs/computation_graph_op_attrs.dtg.cc new file mode 100644 index 0000000000..9bcde22cd9 --- /dev/null +++ b/lib/op-attrs/src/op-attrs/computation_graph_op_attrs.dtg.cc @@ -0,0 +1,597 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! 
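The generated Arbitrary instances exist so these datatypes can be fed straight into rapidcheck properties; a sketch of the kind of property they enable (assumed test shape, not taken from the repo):

    #include "op-attrs/aggregate_op.dtg.h"
    #include "rapidcheck.h"

    int main() {
      rc::check("AggregateOp JSON round-trips", [](FlexFlow::AggregateOp op) {
        nlohmann::json j;
        FlexFlow::to_json(j, op);
        FlexFlow::AggregateOp back = op;
        FlexFlow::from_json(j, back);
        RC_ASSERT(back == op);
      });
    }
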
+// If you would like to modify this datatype, instead modify +// lib/op-attrs/include/op-attrs/computation_graph_op_attrs.variant.toml +/* proj-data +{ + "generated_from": "cc0ab49405423594ffa1d8f541235a48" +} +*/ + +#include "op-attrs/computation_graph_op_attrs.dtg.h" + +#include "fmt/format.h" +#include +#include + +namespace FlexFlow { +ComputationGraphOpAttrs::ComputationGraphOpAttrs( + ::FlexFlow::BatchMatmulAttrs const &v) + : raw_variant(v) {} +ComputationGraphOpAttrs::ComputationGraphOpAttrs( + ::FlexFlow::BatchNormAttrs const &v) + : raw_variant(v) {} +ComputationGraphOpAttrs::ComputationGraphOpAttrs( + ::FlexFlow::BroadcastAttrs const &v) + : raw_variant(v) {} +ComputationGraphOpAttrs::ComputationGraphOpAttrs(::FlexFlow::CastAttrs const &v) + : raw_variant(v) {} +ComputationGraphOpAttrs::ComputationGraphOpAttrs( + ::FlexFlow::ConcatAttrs const &v) + : raw_variant(v) {} +ComputationGraphOpAttrs::ComputationGraphOpAttrs( + ::FlexFlow::Conv2DAttrs const &v) + : raw_variant(v) {} +ComputationGraphOpAttrs::ComputationGraphOpAttrs( + ::FlexFlow::DropoutAttrs const &v) + : raw_variant(v) {} +ComputationGraphOpAttrs::ComputationGraphOpAttrs( + ::FlexFlow::ElementBinaryAttrs const &v) + : raw_variant(v) {} +ComputationGraphOpAttrs::ComputationGraphOpAttrs( + ::FlexFlow::ElementUnaryAttrs const &v) + : raw_variant(v) {} +ComputationGraphOpAttrs::ComputationGraphOpAttrs( + ::FlexFlow::ElementScalarUnaryAttrs const &v) + : raw_variant(v) {} +ComputationGraphOpAttrs::ComputationGraphOpAttrs( + ::FlexFlow::EmbeddingAttrs const &v) + : raw_variant(v) {} +ComputationGraphOpAttrs::ComputationGraphOpAttrs(::FlexFlow::FlatAttrs const &v) + : raw_variant(v) {} +ComputationGraphOpAttrs::ComputationGraphOpAttrs( + ::FlexFlow::GatherAttrs const &v) + : raw_variant(v) {} +ComputationGraphOpAttrs::ComputationGraphOpAttrs( + ::FlexFlow::InputAttrs const &v) + : raw_variant(v) {} +ComputationGraphOpAttrs::ComputationGraphOpAttrs( + ::FlexFlow::LayerNormAttrs const &v) + : raw_variant(v) {} +ComputationGraphOpAttrs::ComputationGraphOpAttrs( + ::FlexFlow::LinearAttrs const &v) + : raw_variant(v) {} +ComputationGraphOpAttrs::ComputationGraphOpAttrs( + ::FlexFlow::MultiHeadAttentionAttrs const &v) + : raw_variant(v) {} +ComputationGraphOpAttrs::ComputationGraphOpAttrs(::FlexFlow::NoopAttrs const &v) + : raw_variant(v) {} +ComputationGraphOpAttrs::ComputationGraphOpAttrs( + ::FlexFlow::Pool2DAttrs const &v) + : raw_variant(v) {} +ComputationGraphOpAttrs::ComputationGraphOpAttrs( + ::FlexFlow::ReduceAttrs const &v) + : raw_variant(v) {} +ComputationGraphOpAttrs::ComputationGraphOpAttrs( + ::FlexFlow::ReverseAttrs const &v) + : raw_variant(v) {} +ComputationGraphOpAttrs::ComputationGraphOpAttrs( + ::FlexFlow::ReshapeAttrs const &v) + : raw_variant(v) {} +ComputationGraphOpAttrs::ComputationGraphOpAttrs( + ::FlexFlow::SplitAttrs const &v) + : raw_variant(v) {} +ComputationGraphOpAttrs::ComputationGraphOpAttrs( + ::FlexFlow::SoftmaxAttrs const &v) + : raw_variant(v) {} +ComputationGraphOpAttrs::ComputationGraphOpAttrs(::FlexFlow::TopKAttrs const &v) + : raw_variant(v) {} +ComputationGraphOpAttrs::ComputationGraphOpAttrs( + ::FlexFlow::TransposeAttrs const &v) + : raw_variant(v) {} +ComputationGraphOpAttrs::ComputationGraphOpAttrs( + ::FlexFlow::WeightAttrs const &v) + : raw_variant(v) {} +bool ComputationGraphOpAttrs::operator==( + ComputationGraphOpAttrs const &other) const { + return this->raw_variant == other.raw_variant; +} +bool ComputationGraphOpAttrs::operator!=( + ComputationGraphOpAttrs const &other) 
const { + return this->raw_variant != other.raw_variant; +} +bool ComputationGraphOpAttrs::operator<( + ComputationGraphOpAttrs const &other) const { + return this->raw_variant < other.raw_variant; +} +bool ComputationGraphOpAttrs::operator>( + ComputationGraphOpAttrs const &other) const { + return this->raw_variant > other.raw_variant; +} +bool ComputationGraphOpAttrs::operator<=( + ComputationGraphOpAttrs const &other) const { + return this->raw_variant <= other.raw_variant; +} +bool ComputationGraphOpAttrs::operator>=( + ComputationGraphOpAttrs const &other) const { + return this->raw_variant >= other.raw_variant; +} +} // namespace FlexFlow +namespace std { +size_t hash<::FlexFlow::ComputationGraphOpAttrs>::operator()( + ::FlexFlow::ComputationGraphOpAttrs const &x) const { + return std::hash>{}(x.raw_variant); +} +} // namespace std +namespace nlohmann { +::FlexFlow::ComputationGraphOpAttrs + adl_serializer<::FlexFlow::ComputationGraphOpAttrs>::from_json( + json const &j) { + std::string key = j.at("type").template get(); + if (key == "batch_matmul") { + return ::FlexFlow::ComputationGraphOpAttrs{ + j.at("value").template get<::FlexFlow::BatchMatmulAttrs>()}; + } else if (key == "batch_norm") { + return ::FlexFlow::ComputationGraphOpAttrs{ + j.at("value").template get<::FlexFlow::BatchNormAttrs>()}; + } else if (key == "broadcast") { + return ::FlexFlow::ComputationGraphOpAttrs{ + j.at("value").template get<::FlexFlow::BroadcastAttrs>()}; + } else if (key == "cast") { + return ::FlexFlow::ComputationGraphOpAttrs{ + j.at("value").template get<::FlexFlow::CastAttrs>()}; + } else if (key == "concat") { + return ::FlexFlow::ComputationGraphOpAttrs{ + j.at("value").template get<::FlexFlow::ConcatAttrs>()}; + } else if (key == "conv2d") { + return ::FlexFlow::ComputationGraphOpAttrs{ + j.at("value").template get<::FlexFlow::Conv2DAttrs>()}; + } else if (key == "dropout") { + return ::FlexFlow::ComputationGraphOpAttrs{ + j.at("value").template get<::FlexFlow::DropoutAttrs>()}; + } else if (key == "element_binary") { + return ::FlexFlow::ComputationGraphOpAttrs{ + j.at("value").template get<::FlexFlow::ElementBinaryAttrs>()}; + } else if (key == "element_unary") { + return ::FlexFlow::ComputationGraphOpAttrs{ + j.at("value").template get<::FlexFlow::ElementUnaryAttrs>()}; + } else if (key == "element_scalar_unary") { + return ::FlexFlow::ComputationGraphOpAttrs{ + j.at("value").template get<::FlexFlow::ElementScalarUnaryAttrs>()}; + } else if (key == "embedding") { + return ::FlexFlow::ComputationGraphOpAttrs{ + j.at("value").template get<::FlexFlow::EmbeddingAttrs>()}; + } else if (key == "flat") { + return ::FlexFlow::ComputationGraphOpAttrs{ + j.at("value").template get<::FlexFlow::FlatAttrs>()}; + } else if (key == "gather") { + return ::FlexFlow::ComputationGraphOpAttrs{ + j.at("value").template get<::FlexFlow::GatherAttrs>()}; + } else if (key == "input") { + return ::FlexFlow::ComputationGraphOpAttrs{ + j.at("value").template get<::FlexFlow::InputAttrs>()}; + } else if (key == "layer_norm") { + return ::FlexFlow::ComputationGraphOpAttrs{ + j.at("value").template get<::FlexFlow::LayerNormAttrs>()}; + } else if (key == "linear") { + return ::FlexFlow::ComputationGraphOpAttrs{ + j.at("value").template get<::FlexFlow::LinearAttrs>()}; + } else if (key == "multi_head_attention") { + return ::FlexFlow::ComputationGraphOpAttrs{ + j.at("value").template get<::FlexFlow::MultiHeadAttentionAttrs>()}; + } else if (key == "noop") { + return ::FlexFlow::ComputationGraphOpAttrs{ + 
j.at("value").template get<::FlexFlow::NoopAttrs>()}; + } else if (key == "pool2d") { + return ::FlexFlow::ComputationGraphOpAttrs{ + j.at("value").template get<::FlexFlow::Pool2DAttrs>()}; + } else if (key == "reduce") { + return ::FlexFlow::ComputationGraphOpAttrs{ + j.at("value").template get<::FlexFlow::ReduceAttrs>()}; + } else if (key == "reverse") { + return ::FlexFlow::ComputationGraphOpAttrs{ + j.at("value").template get<::FlexFlow::ReverseAttrs>()}; + } else if (key == "reshape") { + return ::FlexFlow::ComputationGraphOpAttrs{ + j.at("value").template get<::FlexFlow::ReshapeAttrs>()}; + } else if (key == "split") { + return ::FlexFlow::ComputationGraphOpAttrs{ + j.at("value").template get<::FlexFlow::SplitAttrs>()}; + } else if (key == "softmax") { + return ::FlexFlow::ComputationGraphOpAttrs{ + j.at("value").template get<::FlexFlow::SoftmaxAttrs>()}; + } else if (key == "topk") { + return ::FlexFlow::ComputationGraphOpAttrs{ + j.at("value").template get<::FlexFlow::TopKAttrs>()}; + } else if (key == "transpose") { + return ::FlexFlow::ComputationGraphOpAttrs{ + j.at("value").template get<::FlexFlow::TransposeAttrs>()}; + } else if (key == "weight") { + return ::FlexFlow::ComputationGraphOpAttrs{ + j.at("value").template get<::FlexFlow::WeightAttrs>()}; + } else { + throw std::runtime_error(fmt::format("Unknown type key {}", key)); + } +} +void adl_serializer<::FlexFlow::ComputationGraphOpAttrs>::to_json( + json &j, ::FlexFlow::ComputationGraphOpAttrs const &x) { + j["__type"] = "ComputationGraphOpAttrs"; + switch (x.index()) { + case 0: { + j["type"] = "batch_matmul"; + j["value"] = x.get<::FlexFlow::BatchMatmulAttrs>(); + break; + } + case 1: { + j["type"] = "batch_norm"; + j["value"] = x.get<::FlexFlow::BatchNormAttrs>(); + break; + } + case 2: { + j["type"] = "broadcast"; + j["value"] = x.get<::FlexFlow::BroadcastAttrs>(); + break; + } + case 3: { + j["type"] = "cast"; + j["value"] = x.get<::FlexFlow::CastAttrs>(); + break; + } + case 4: { + j["type"] = "concat"; + j["value"] = x.get<::FlexFlow::ConcatAttrs>(); + break; + } + case 5: { + j["type"] = "conv2d"; + j["value"] = x.get<::FlexFlow::Conv2DAttrs>(); + break; + } + case 6: { + j["type"] = "dropout"; + j["value"] = x.get<::FlexFlow::DropoutAttrs>(); + break; + } + case 7: { + j["type"] = "element_binary"; + j["value"] = x.get<::FlexFlow::ElementBinaryAttrs>(); + break; + } + case 8: { + j["type"] = "element_unary"; + j["value"] = x.get<::FlexFlow::ElementUnaryAttrs>(); + break; + } + case 9: { + j["type"] = "element_scalar_unary"; + j["value"] = x.get<::FlexFlow::ElementScalarUnaryAttrs>(); + break; + } + case 10: { + j["type"] = "embedding"; + j["value"] = x.get<::FlexFlow::EmbeddingAttrs>(); + break; + } + case 11: { + j["type"] = "flat"; + j["value"] = x.get<::FlexFlow::FlatAttrs>(); + break; + } + case 12: { + j["type"] = "gather"; + j["value"] = x.get<::FlexFlow::GatherAttrs>(); + break; + } + case 13: { + j["type"] = "input"; + j["value"] = x.get<::FlexFlow::InputAttrs>(); + break; + } + case 14: { + j["type"] = "layer_norm"; + j["value"] = x.get<::FlexFlow::LayerNormAttrs>(); + break; + } + case 15: { + j["type"] = "linear"; + j["value"] = x.get<::FlexFlow::LinearAttrs>(); + break; + } + case 16: { + j["type"] = "multi_head_attention"; + j["value"] = x.get<::FlexFlow::MultiHeadAttentionAttrs>(); + break; + } + case 17: { + j["type"] = "noop"; + j["value"] = x.get<::FlexFlow::NoopAttrs>(); + break; + } + case 18: { + j["type"] = "pool2d"; + j["value"] = x.get<::FlexFlow::Pool2DAttrs>(); + break; + } + case 19: { + 
j["type"] = "reduce"; + j["value"] = x.get<::FlexFlow::ReduceAttrs>(); + break; + } + case 20: { + j["type"] = "reverse"; + j["value"] = x.get<::FlexFlow::ReverseAttrs>(); + break; + } + case 21: { + j["type"] = "reshape"; + j["value"] = x.get<::FlexFlow::ReshapeAttrs>(); + break; + } + case 22: { + j["type"] = "split"; + j["value"] = x.get<::FlexFlow::SplitAttrs>(); + break; + } + case 23: { + j["type"] = "softmax"; + j["value"] = x.get<::FlexFlow::SoftmaxAttrs>(); + break; + } + case 24: { + j["type"] = "topk"; + j["value"] = x.get<::FlexFlow::TopKAttrs>(); + break; + } + case 25: { + j["type"] = "transpose"; + j["value"] = x.get<::FlexFlow::TransposeAttrs>(); + break; + } + case 26: { + j["type"] = "weight"; + j["value"] = x.get<::FlexFlow::WeightAttrs>(); + break; + } + default: { + throw std::runtime_error(fmt::format( + "Unknown index {} for type ComputationGraphOpAttrs", x.index())); + } + } +} +} // namespace nlohmann +namespace rc { +Gen<::FlexFlow::ComputationGraphOpAttrs> + Arbitrary<::FlexFlow::ComputationGraphOpAttrs>::arbitrary() { + return gen::oneOf(gen::construct<::FlexFlow::ComputationGraphOpAttrs>( + gen::arbitrary<::FlexFlow::BatchMatmulAttrs>()), + gen::construct<::FlexFlow::ComputationGraphOpAttrs>( + gen::arbitrary<::FlexFlow::BatchNormAttrs>()), + gen::construct<::FlexFlow::ComputationGraphOpAttrs>( + gen::arbitrary<::FlexFlow::BroadcastAttrs>()), + gen::construct<::FlexFlow::ComputationGraphOpAttrs>( + gen::arbitrary<::FlexFlow::CastAttrs>()), + gen::construct<::FlexFlow::ComputationGraphOpAttrs>( + gen::arbitrary<::FlexFlow::ConcatAttrs>()), + gen::construct<::FlexFlow::ComputationGraphOpAttrs>( + gen::arbitrary<::FlexFlow::Conv2DAttrs>()), + gen::construct<::FlexFlow::ComputationGraphOpAttrs>( + gen::arbitrary<::FlexFlow::DropoutAttrs>()), + gen::construct<::FlexFlow::ComputationGraphOpAttrs>( + gen::arbitrary<::FlexFlow::ElementBinaryAttrs>()), + gen::construct<::FlexFlow::ComputationGraphOpAttrs>( + gen::arbitrary<::FlexFlow::ElementUnaryAttrs>()), + gen::construct<::FlexFlow::ComputationGraphOpAttrs>( + gen::arbitrary<::FlexFlow::ElementScalarUnaryAttrs>()), + gen::construct<::FlexFlow::ComputationGraphOpAttrs>( + gen::arbitrary<::FlexFlow::EmbeddingAttrs>()), + gen::construct<::FlexFlow::ComputationGraphOpAttrs>( + gen::arbitrary<::FlexFlow::FlatAttrs>()), + gen::construct<::FlexFlow::ComputationGraphOpAttrs>( + gen::arbitrary<::FlexFlow::GatherAttrs>()), + gen::construct<::FlexFlow::ComputationGraphOpAttrs>( + gen::arbitrary<::FlexFlow::InputAttrs>()), + gen::construct<::FlexFlow::ComputationGraphOpAttrs>( + gen::arbitrary<::FlexFlow::LayerNormAttrs>()), + gen::construct<::FlexFlow::ComputationGraphOpAttrs>( + gen::arbitrary<::FlexFlow::LinearAttrs>()), + gen::construct<::FlexFlow::ComputationGraphOpAttrs>( + gen::arbitrary<::FlexFlow::MultiHeadAttentionAttrs>()), + gen::construct<::FlexFlow::ComputationGraphOpAttrs>( + gen::arbitrary<::FlexFlow::NoopAttrs>()), + gen::construct<::FlexFlow::ComputationGraphOpAttrs>( + gen::arbitrary<::FlexFlow::Pool2DAttrs>()), + gen::construct<::FlexFlow::ComputationGraphOpAttrs>( + gen::arbitrary<::FlexFlow::ReduceAttrs>()), + gen::construct<::FlexFlow::ComputationGraphOpAttrs>( + gen::arbitrary<::FlexFlow::ReverseAttrs>()), + gen::construct<::FlexFlow::ComputationGraphOpAttrs>( + gen::arbitrary<::FlexFlow::ReshapeAttrs>()), + gen::construct<::FlexFlow::ComputationGraphOpAttrs>( + gen::arbitrary<::FlexFlow::SplitAttrs>()), + gen::construct<::FlexFlow::ComputationGraphOpAttrs>( + gen::arbitrary<::FlexFlow::SoftmaxAttrs>()), 
+ gen::construct<::FlexFlow::ComputationGraphOpAttrs>( + gen::arbitrary<::FlexFlow::TopKAttrs>()), + gen::construct<::FlexFlow::ComputationGraphOpAttrs>( + gen::arbitrary<::FlexFlow::TransposeAttrs>()), + gen::construct<::FlexFlow::ComputationGraphOpAttrs>( + gen::arbitrary<::FlexFlow::WeightAttrs>())); +} +} // namespace rc +namespace FlexFlow { +std::string format_as(::FlexFlow::ComputationGraphOpAttrs const &x) { + std::ostringstream oss; + switch (x.index()) { + case 0: { + oss << ""; + break; + } + case 1: { + oss << ""; + break; + } + case 2: { + oss << ""; + break; + } + case 3: { + oss << ""; + break; + } + case 4: { + oss << ""; + break; + } + case 5: { + oss << ""; + break; + } + case 6: { + oss << ""; + break; + } + case 7: { + oss << ""; + break; + } + case 8: { + oss << ""; + break; + } + case 9: { + oss << ""; + break; + } + case 10: { + oss << ""; + break; + } + case 11: { + oss << ""; + break; + } + case 12: { + oss << ""; + break; + } + case 13: { + oss << ""; + break; + } + case 14: { + oss << ""; + break; + } + case 15: { + oss << ""; + break; + } + case 16: { + oss << ""; + break; + } + case 17: { + oss << ""; + break; + } + case 18: { + oss << ""; + break; + } + case 19: { + oss << ""; + break; + } + case 20: { + oss << ""; + break; + } + case 21: { + oss << ""; + break; + } + case 22: { + oss << ""; + break; + } + case 23: { + oss << ""; + break; + } + case 24: { + oss << ""; + break; + } + case 25: { + oss << ""; + break; + } + case 26: { + oss << ""; + break; + } + default: { + throw std::runtime_error(fmt::format( + "Unknown index {} for type ComputationGraphOpAttrs", x.index())); + break; + } + } + return oss.str(); +} +std::ostream &operator<<(std::ostream &s, + ::FlexFlow::ComputationGraphOpAttrs const &x) { + return s << fmt::to_string(x); +} +} // namespace FlexFlow diff --git a/lib/op-attrs/src/op-attrs/datatype.dtg.cc b/lib/op-attrs/src/op-attrs/datatype.dtg.cc new file mode 100644 index 0000000000..a9c1d54f0e --- /dev/null +++ b/lib/op-attrs/src/op-attrs/datatype.dtg.cc @@ -0,0 +1,102 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! 
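computation_graph_op_attrs.cc earlier in this patch dispatches through the generated visit member; a sketch of the same pattern, assuming visit<R>(f) applies f to whichever alternative is active:

    #include "op-attrs/computation_graph_op_attrs.dtg.h"
    #include "op-attrs/get_op_type.h"

    FlexFlow::OperatorType
        op_type_of(FlexFlow::ComputationGraphOpAttrs const &attrs) {
      // Mirrors the get_op_type(ComputationGraphOpAttrs) implementation in
      // this patch: one lambda handles all 27 alternatives via overloading.
      return attrs.visit<FlexFlow::OperatorType>(
          [](auto const &x) { return get_op_type(x); });
    }
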
+// If you would like to modify this datatype, instead modify +// lib/op-attrs/include/op-attrs/datatype.enum.toml +/* proj-data +{ + "generated_from": "8315d0aa0a65b00c13aa580e923592ef" +} +*/ + +#include "op-attrs/datatype.dtg.h" + +#include +#include + +namespace std { +size_t hash::operator()(FlexFlow::DataType x) const { + return std::hash{}(static_cast(x)); +} +} // namespace std +namespace FlexFlow { +std::string format_as(DataType x) { + switch (x) { + case DataType::BOOL: + return "BOOL"; + case DataType::INT32: + return "INT32"; + case DataType::INT64: + return "INT64"; + case DataType::HALF: + return "HALF"; + case DataType::FLOAT: + return "FLOAT"; + case DataType::DOUBLE: + return "DOUBLE"; + default: + std::ostringstream oss; + oss << "Unknown DataType value " << static_cast(x); + throw std::runtime_error(oss.str()); + } +} +std::ostream &operator<<(std::ostream &s, DataType x) { + return s << fmt::to_string(x); +} +} // namespace FlexFlow +namespace FlexFlow { +void to_json(::nlohmann::json &j, DataType x) { + switch (x) { + case DataType::BOOL: + j = "BOOL"; + break; + case DataType::INT32: + j = "INT32"; + break; + case DataType::INT64: + j = "INT64"; + break; + case DataType::HALF: + j = "HALF"; + break; + case DataType::FLOAT: + j = "FLOAT"; + break; + case DataType::DOUBLE: + j = "DOUBLE"; + break; + default: + std::ostringstream oss; + oss << "Unknown DataType value " << static_cast(x); + throw std::runtime_error(oss.str()); + } +} +void from_json(::nlohmann::json const &j, DataType &x) { + std::string as_str = j.get(); + if (as_str == "BOOL") { + x = DataType::BOOL; + } else if (as_str == "INT32") { + x = DataType::INT32; + } else if (as_str == "INT64") { + x = DataType::INT64; + } else if (as_str == "HALF") { + x = DataType::HALF; + } else if (as_str == "FLOAT") { + x = DataType::FLOAT; + } else if (as_str == "DOUBLE") { + x = DataType::DOUBLE; + } else { + std::ostringstream oss; + oss << "Unknown DataType value " << as_str; + throw std::runtime_error(oss.str()); + } +} +} // namespace FlexFlow +namespace rc { +Gen Arbitrary::arbitrary() { + return gen::element(FlexFlow::DataType::BOOL, + FlexFlow::DataType::INT32, + FlexFlow::DataType::INT64, + FlexFlow::DataType::HALF, + FlexFlow::DataType::FLOAT, + FlexFlow::DataType::DOUBLE); +} +} // namespace rc diff --git a/lib/op-attrs/src/op-attrs/ff_dim.dtg.cc b/lib/op-attrs/src/op-attrs/ff_dim.dtg.cc new file mode 100644 index 0000000000..8b22dfd18d --- /dev/null +++ b/lib/op-attrs/src/op-attrs/ff_dim.dtg.cc @@ -0,0 +1,68 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! 
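Consumers typically switch over DataType the same way the generated code does; for example, a hypothetical byte-width helper (the widths are standard, the function itself is not in this patch):

    #include "op-attrs/datatype.dtg.h"
    #include <cstddef>
    #include <stdexcept>

    std::size_t size_in_bytes(FlexFlow::DataType dt) {
      using FlexFlow::DataType;
      switch (dt) {
        case DataType::BOOL:   return 1;
        case DataType::INT32:  return 4;
        case DataType::INT64:  return 8;
        case DataType::HALF:   return 2;
        case DataType::FLOAT:  return 4;
        case DataType::DOUBLE: return 8;
        default: throw std::runtime_error("unhandled DataType");
      }
    }
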
+// If you would like to modify this datatype, instead modify
+// lib/op-attrs/include/op-attrs/ff_dim.struct.toml
+/* proj-data
+{
+  "generated_from": "a5fa89a024e95c4f2d52681a74cab30f"
+}
+*/
+
+#include "op-attrs/ff_dim.dtg.h"
+
+#include <sstream>
+
+namespace FlexFlow {
+ff_dim_t::ff_dim_t(int const &value) : value(value) {}
+bool ff_dim_t::operator==(ff_dim_t const &other) const {
+  return std::tie(this->value) == std::tie(other.value);
+}
+bool ff_dim_t::operator!=(ff_dim_t const &other) const {
+  return std::tie(this->value) != std::tie(other.value);
+}
+bool ff_dim_t::operator<(ff_dim_t const &other) const {
+  return std::tie(this->value) < std::tie(other.value);
+}
+bool ff_dim_t::operator>(ff_dim_t const &other) const {
+  return std::tie(this->value) > std::tie(other.value);
+}
+bool ff_dim_t::operator<=(ff_dim_t const &other) const {
+  return std::tie(this->value) <= std::tie(other.value);
+}
+bool ff_dim_t::operator>=(ff_dim_t const &other) const {
+  return std::tie(this->value) >= std::tie(other.value);
+}
+} // namespace FlexFlow
+
+namespace std {
+size_t hash<FlexFlow::ff_dim_t>::operator()(FlexFlow::ff_dim_t const &x) const {
+  size_t result = 0;
+  result ^=
+      std::hash<int>{}(x.value) + 0x9e3779b9 + (result << 6) + (result >> 2);
+  return result;
+}
+} // namespace std
+
+namespace nlohmann {
+FlexFlow::ff_dim_t
+    adl_serializer<FlexFlow::ff_dim_t>::from_json(json const &j) {
+  return {j.at("value").template get<int>()};
+}
+void adl_serializer<FlexFlow::ff_dim_t>::to_json(json &j,
+                                                 FlexFlow::ff_dim_t const &v) {
+  j["__type"] = "ff_dim_t";
+  j["value"] = v.value;
+}
+} // namespace nlohmann
+
+namespace FlexFlow {
+std::string format_as(ff_dim_t const &x) {
+  std::ostringstream oss;
+  oss << "<ff_dim_t value=" << x.value << ">";
+  return oss.str();
+}
+std::ostream &operator<<(std::ostream &s, ff_dim_t const &x) {
+  return s << fmt::to_string(x);
+}
+} // namespace FlexFlow
diff --git a/lib/op-attrs/src/get_op_type.cc b/lib/op-attrs/src/op-attrs/get_op_type.cc
similarity index 60%
rename from lib/op-attrs/src/get_op_type.cc
rename to lib/op-attrs/src/op-attrs/get_op_type.cc
index 3fa401b647..aced8d873c 100644
--- a/lib/op-attrs/src/get_op_type.cc
+++ b/lib/op-attrs/src/op-attrs/get_op_type.cc
@@ -3,25 +3,25 @@
 namespace FlexFlow {
 
 OperatorType get_op_type(BatchMatmulAttrs const &) {
-  return Op::BATCHMATMUL;
+  return OperatorType::BATCHMATMUL;
 }
 OperatorType get_op_type(BatchNormAttrs const &) {
-  return Op::BATCHNORM;
+  return OperatorType::BATCHNORM;
 }
 OperatorType get_op_type(BroadcastAttrs const &) {
-  return Op::BROADCAST;
+  return OperatorType::BROADCAST;
 }
 OperatorType get_op_type(CastAttrs const &) {
-  return Op::CAST;
+  return OperatorType::CAST;
 }
 OperatorType get_op_type(ConcatAttrs const &) {
-  return Op::CONCAT;
+  return OperatorType::CONCAT;
 }
 OperatorType get_op_type(Conv2DAttrs const &) {
-  return Op::CONV2D;
+  return OperatorType::CONV2D;
 }
 OperatorType get_op_type(DropoutAttrs const &) {
-  return Op::DROPOUT;
+  return OperatorType::DROPOUT;
 }
 OperatorType get_op_type(ElementBinaryAttrs const &attrs) {
   return attrs.type;
@@ -33,64 +33,67 @@ OperatorType get_op_type(ElementScalarUnaryAttrs const &attrs) {
   return attrs.op_type;
 }
 OperatorType get_op_type(EmbeddingAttrs const &) {
-  return Op::EMBEDDING;
+  return OperatorType::EMBEDDING;
 }
 OperatorType get_op_type(FlatAttrs const &) {
-  return Op::FLAT;
+  return OperatorType::FLAT;
 }
 OperatorType get_op_type(GatherAttrs const &) {
-  return Op::GATHER;
+  return OperatorType::GATHER;
 }
 OperatorType get_op_type(InputAttrs const &) {
-  return Op::INPUT;
+  return OperatorType::INPUT;
 }
 OperatorType
get_op_type(LayerNormAttrs const &) { - return Op::LAYERNORM; + return OperatorType::LAYERNORM; } OperatorType get_op_type(LinearAttrs const &) { - return Op::LINEAR; + return OperatorType::LINEAR; } OperatorType get_op_type(MultiHeadAttentionAttrs const &) { - return Op::MULTIHEAD_ATTENTION; + return OperatorType::MULTIHEAD_ATTENTION; } OperatorType get_op_type(NoopAttrs const &) { - return Op::NOOP; + return OperatorType::NOOP; } OperatorType get_op_type(Pool2DAttrs const &) { - return Op::POOL2D; + return OperatorType::POOL2D; } -OperatorType get_op_type(ReduceAttrs const &) { - return Op::REDUCE_SUM; +OperatorType get_op_type(ReduceAttrs const &attrs) { + return attrs.op_type; } OperatorType get_op_type(ReshapeAttrs const &) { - return Op::RESHAPE; + return OperatorType::RESHAPE; +} +OperatorType get_op_type(ReverseAttrs const &) { + return OperatorType::REVERSE; } OperatorType get_op_type(SplitAttrs const &) { - return Op::SPLIT; + return OperatorType::SPLIT; } OperatorType get_op_type(SoftmaxAttrs const &) { - return Op::SOFTMAX; + return OperatorType::SOFTMAX; } OperatorType get_op_type(TopKAttrs const &) { - return Op::TOPK; + return OperatorType::TOPK; } OperatorType get_op_type(TransposeAttrs const &) { - return Op::TRANSPOSE; + return OperatorType::TRANSPOSE; } OperatorType get_op_type(CombineAttrs const &) { - return Op::COMBINE; + return OperatorType::COMBINE; } OperatorType get_op_type(ReductionAttrs const &) { - return Op::REDUCTION; + return OperatorType::REDUCTION; } OperatorType get_op_type(RepartitionAttrs const &) { - return Op::REPARTITION; + return OperatorType::REPARTITION; } OperatorType get_op_type(ReplicateAttrs const &) { - return Op::REPLICATE; + return OperatorType::REPLICATE; } -OperatorType get_op_type(ReverseAttrs const &attrs) { - return Op::REVERSE; +OperatorType get_op_type(WeightAttrs const &) { + return OperatorType::WEIGHT; } } // namespace FlexFlow diff --git a/lib/op-attrs/src/op-attrs/l1_regularizer_attrs.dtg.cc b/lib/op-attrs/src/op-attrs/l1_regularizer_attrs.dtg.cc new file mode 100644 index 0000000000..ed06df2c78 --- /dev/null +++ b/lib/op-attrs/src/op-attrs/l1_regularizer_attrs.dtg.cc @@ -0,0 +1,76 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! 
+// If you would like to modify this datatype, instead modify +// lib/op-attrs/include/op-attrs/l1_regularizer_attrs.struct.toml +/* proj-data +{ + "generated_from": "50968fb8a3d43395d0eab7594f4935c0" +} +*/ + +#include "op-attrs/l1_regularizer_attrs.dtg.h" + +#include + +namespace FlexFlow { +L1RegularizerAttrs::L1RegularizerAttrs(float const &lambda) : lambda(lambda) {} +bool L1RegularizerAttrs::operator==(L1RegularizerAttrs const &other) const { + return std::tie(this->lambda) == std::tie(other.lambda); +} +bool L1RegularizerAttrs::operator!=(L1RegularizerAttrs const &other) const { + return std::tie(this->lambda) != std::tie(other.lambda); +} +bool L1RegularizerAttrs::operator<(L1RegularizerAttrs const &other) const { + return std::tie(this->lambda) < std::tie(other.lambda); +} +bool L1RegularizerAttrs::operator>(L1RegularizerAttrs const &other) const { + return std::tie(this->lambda) > std::tie(other.lambda); +} +bool L1RegularizerAttrs::operator<=(L1RegularizerAttrs const &other) const { + return std::tie(this->lambda) <= std::tie(other.lambda); +} +bool L1RegularizerAttrs::operator>=(L1RegularizerAttrs const &other) const { + return std::tie(this->lambda) >= std::tie(other.lambda); +} +} // namespace FlexFlow + +namespace std { +size_t hash::operator()( + FlexFlow::L1RegularizerAttrs const &x) const { + size_t result = 0; + result ^= + std::hash{}(x.lambda) + 0x9e3779b9 + (result << 6) + (result >> 2); + return result; +} +} // namespace std + +namespace nlohmann { +FlexFlow::L1RegularizerAttrs + adl_serializer::from_json(json const &j) { + return {j.at("lambda").template get()}; +} +void adl_serializer::to_json( + json &j, FlexFlow::L1RegularizerAttrs const &v) { + j["__type"] = "L1RegularizerAttrs"; + j["lambda"] = v.lambda; +} +} // namespace nlohmann + +namespace rc { +Gen + Arbitrary::arbitrary() { + return gen::construct(gen::arbitrary()); +} +} // namespace rc + +namespace FlexFlow { +std::string format_as(L1RegularizerAttrs const &x) { + std::ostringstream oss; + oss << ""; + return oss.str(); +} +std::ostream &operator<<(std::ostream &s, L1RegularizerAttrs const &x) { + return s << fmt::to_string(x); +} +} // namespace FlexFlow diff --git a/lib/op-attrs/src/op-attrs/l2_regularizer_attrs.dtg.cc b/lib/op-attrs/src/op-attrs/l2_regularizer_attrs.dtg.cc new file mode 100644 index 0000000000..f0f3f34ee5 --- /dev/null +++ b/lib/op-attrs/src/op-attrs/l2_regularizer_attrs.dtg.cc @@ -0,0 +1,76 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! 
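The result ^= std::hash<float>{}(x.lambda) + 0x9e3779b9 + (result << 6) + (result >> 2) line above is the boost-style hash_combine idiom; the generic form it instantiates per field looks like this (sketch, not code from the repo):

    #include <cstddef>
    #include <functional>

    // Generic form of the mixing step the generator emits for each field:
    // the magic constant and shifts spread each field's hash across the seed
    // so that differing field orders produce differing hashes.
    template <typename T>
    void hash_combine(std::size_t &seed, T const &v) {
      seed ^= std::hash<T>{}(v) + 0x9e3779b9 + (seed << 6) + (seed >> 2);
    }
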
+// If you would like to modify this datatype, instead modify +// lib/op-attrs/include/op-attrs/l2_regularizer_attrs.struct.toml +/* proj-data +{ + "generated_from": "c4f182e547ab6f0d5613e7eeb95d438e" +} +*/ + +#include "op-attrs/l2_regularizer_attrs.dtg.h" + +#include + +namespace FlexFlow { +L2RegularizerAttrs::L2RegularizerAttrs(float const &lambda) : lambda(lambda) {} +bool L2RegularizerAttrs::operator==(L2RegularizerAttrs const &other) const { + return std::tie(this->lambda) == std::tie(other.lambda); +} +bool L2RegularizerAttrs::operator!=(L2RegularizerAttrs const &other) const { + return std::tie(this->lambda) != std::tie(other.lambda); +} +bool L2RegularizerAttrs::operator<(L2RegularizerAttrs const &other) const { + return std::tie(this->lambda) < std::tie(other.lambda); +} +bool L2RegularizerAttrs::operator>(L2RegularizerAttrs const &other) const { + return std::tie(this->lambda) > std::tie(other.lambda); +} +bool L2RegularizerAttrs::operator<=(L2RegularizerAttrs const &other) const { + return std::tie(this->lambda) <= std::tie(other.lambda); +} +bool L2RegularizerAttrs::operator>=(L2RegularizerAttrs const &other) const { + return std::tie(this->lambda) >= std::tie(other.lambda); +} +} // namespace FlexFlow + +namespace std { +size_t hash::operator()( + FlexFlow::L2RegularizerAttrs const &x) const { + size_t result = 0; + result ^= + std::hash{}(x.lambda) + 0x9e3779b9 + (result << 6) + (result >> 2); + return result; +} +} // namespace std + +namespace nlohmann { +FlexFlow::L2RegularizerAttrs + adl_serializer::from_json(json const &j) { + return {j.at("lambda").template get()}; +} +void adl_serializer::to_json( + json &j, FlexFlow::L2RegularizerAttrs const &v) { + j["__type"] = "L2RegularizerAttrs"; + j["lambda"] = v.lambda; +} +} // namespace nlohmann + +namespace rc { +Gen + Arbitrary::arbitrary() { + return gen::construct(gen::arbitrary()); +} +} // namespace rc + +namespace FlexFlow { +std::string format_as(L2RegularizerAttrs const &x) { + std::ostringstream oss; + oss << ""; + return oss.str(); +} +std::ostream &operator<<(std::ostream &s, L2RegularizerAttrs const &x) { + return s << fmt::to_string(x); +} +} // namespace FlexFlow diff --git a/lib/op-attrs/src/op-attrs/operator_type.cc b/lib/op-attrs/src/op-attrs/operator_type.cc new file mode 100644 index 0000000000..5a516ef122 --- /dev/null +++ b/lib/op-attrs/src/op-attrs/operator_type.cc @@ -0,0 +1,24 @@ +#include "op-attrs/operator_type.h" + +namespace FlexFlow { + +std::string get_operator_type_name(OperatorType op) { + return fmt::to_string(op); +} + +bool is_parallel_op(OperatorType const &t) { + switch (t) { + case OperatorType::REPARTITION: + case OperatorType::COMBINE: + case OperatorType::REPLICATE: + case OperatorType::REDUCTION: + case OperatorType::BATCH: + case OperatorType::PIPELINE: + case OperatorType::FUSED_PARALLEL: + return true; + default: + return false; + } +} + +} // namespace FlexFlow diff --git a/lib/op-attrs/src/op-attrs/operator_type.dtg.cc b/lib/op-attrs/src/op-attrs/operator_type.dtg.cc new file mode 100644 index 0000000000..07b6396a5a --- /dev/null +++ b/lib/op-attrs/src/op-attrs/operator_type.dtg.cc @@ -0,0 +1,720 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! 
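operator_type.cc above supplies the two hand-written helpers over the generated enum; a trivial usage sketch (illustrative only):

    #include "op-attrs/operator_type.h"
    #include <iostream>

    int main() {
      using FlexFlow::OperatorType;
      // REPLICATE is one of the cases is_parallel_op reports as true.
      std::cout << FlexFlow::get_operator_type_name(OperatorType::REPLICATE)
                << " is parallel: " << std::boolalpha
                << FlexFlow::is_parallel_op(OperatorType::REPLICATE) << "\n";
    }
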
+// If you would like to modify this datatype, instead modify +// lib/op-attrs/include/op-attrs/operator_type.enum.toml +/* proj-data +{ + "generated_from": "c1c4687ef2fbc7dad996e5c25d47124c" +} +*/ + +#include "op-attrs/operator_type.dtg.h" + +#include +#include + +namespace std { +size_t + hash::operator()(FlexFlow::OperatorType x) const { + return std::hash{}(static_cast(x)); +} +} // namespace std +namespace FlexFlow { +std::string format_as(OperatorType x) { + switch (x) { + case OperatorType::NOOP: + return "NOOP"; + case OperatorType::INPUT: + return "INPUT"; + case OperatorType::WEIGHT: + return "WEIGHT"; + case OperatorType::CONV2D: + return "CONV2D"; + case OperatorType::DROPOUT: + return "DROPOUT"; + case OperatorType::LINEAR: + return "LINEAR"; + case OperatorType::BATCHMATMUL: + return "BATCHMATMUL"; + case OperatorType::POOL2D: + return "POOL2D"; + case OperatorType::SCALAR_MULTIPLY: + return "SCALAR_MULTIPLY"; + case OperatorType::SCALAR_ADD: + return "SCALAR_ADD"; + case OperatorType::SCALAR_FLOOR_DIV: + return "SCALAR_FLOOR_DIV"; + case OperatorType::SCALAR_TRUE_DIV: + return "SCALAR_TRUE_DIV"; + case OperatorType::SCALAR_SUB: + return "SCALAR_SUB"; + case OperatorType::RELU: + return "RELU"; + case OperatorType::IDENTITY: + return "IDENTITY"; + case OperatorType::SIGMOID: + return "SIGMOID"; + case OperatorType::TANH: + return "TANH"; + case OperatorType::ELU: + return "ELU"; + case OperatorType::FLAT: + return "FLAT"; + case OperatorType::SOFTMAX: + return "SOFTMAX"; + case OperatorType::BATCHNORM: + return "BATCHNORM"; + case OperatorType::CONCAT: + return "CONCAT"; + case OperatorType::SPLIT: + return "SPLIT"; + case OperatorType::EMBEDDING: + return "EMBEDDING"; + case OperatorType::CACHE: + return "CACHE"; + case OperatorType::RESHAPE: + return "RESHAPE"; + case OperatorType::REVERSE: + return "REVERSE"; + case OperatorType::TRANSPOSE: + return "TRANSPOSE"; + case OperatorType::EW_ADD: + return "EW_ADD"; + case OperatorType::EW_MUL: + return "EW_MUL"; + case OperatorType::MATMUL: + return "MATMUL"; + case OperatorType::MUL: + return "MUL"; + case OperatorType::ENLARGE: + return "ENLARGE"; + case OperatorType::SQUEEZE: + return "SQUEEZE"; + case OperatorType::UNSQUEEZE: + return "UNSQUEEZE"; + case OperatorType::EW_SUB: + return "EW_SUB"; + case OperatorType::EW_DIV: + return "EW_DIV"; + case OperatorType::EW_EQUAL: + return "EW_EQUAL"; + case OperatorType::EW_GREATER: + return "EW_GREATER"; + case OperatorType::EW_LESS: + return "EW_LESS"; + case OperatorType::EW_MAX: + return "EW_MAX"; + case OperatorType::EW_MIN: + return "EW_MIN"; + case OperatorType::REDUCE_ARGMAX: + return "REDUCE_ARGMAX"; + case OperatorType::REDUCE_ARGMIN: + return "REDUCE_ARGMIN"; + case OperatorType::REDUCE_MAX: + return "REDUCE_MAX"; + case OperatorType::REDUCE_MEAN: + return "REDUCE_MEAN"; + case OperatorType::REDUCE_MIN: + return "REDUCE_MIN"; + case OperatorType::REDUCE_PROD: + return "REDUCE_PROD"; + case OperatorType::REDUCE_SUM: + return "REDUCE_SUM"; + case OperatorType::PAD: + return "PAD"; + case OperatorType::SHAPE: + return "SHAPE"; + case OperatorType::SIZE: + return "SIZE"; + case OperatorType::TOPK: + return "TOPK"; + case OperatorType::WHERE: + return "WHERE"; + case OperatorType::CEIL: + return "CEIL"; + case OperatorType::CAST: + return "CAST"; + case OperatorType::EXP: + return "EXP"; + case OperatorType::ROUND: + return "ROUND"; + case OperatorType::LOG: + return "LOG"; + case OperatorType::LOGICAL_NOT: + return "LOGICAL_NOT"; + case OperatorType::SQRT: + return "SQRT"; + 
case OperatorType::SIN: + return "SIN"; + case OperatorType::COS: + return "COS"; + case OperatorType::LEAKYRELU: + return "LEAKYRELU"; + case OperatorType::SLICE: + return "SLICE"; + case OperatorType::RESIZE: + return "RESIZE"; + case OperatorType::PRELU: + return "PRELU"; + case OperatorType::GELU: + return "GELU"; + case OperatorType::MULTIHEAD_ATTENTION: + return "MULTIHEAD_ATTENTION"; + case OperatorType::FUSED: + return "FUSED"; + case OperatorType::RSQRT: + return "RSQRT"; + case OperatorType::POW: + return "POW"; + case OperatorType::MEAN: + return "MEAN"; + case OperatorType::LAYERNORM: + return "LAYERNORM"; + case OperatorType::GATHER: + return "GATHER"; + case OperatorType::BROADCAST: + return "BROADCAST"; + case OperatorType::REPARTITION: + return "REPARTITION"; + case OperatorType::COMBINE: + return "COMBINE"; + case OperatorType::REPLICATE: + return "REPLICATE"; + case OperatorType::REDUCTION: + return "REDUCTION"; + case OperatorType::BATCH: + return "BATCH"; + case OperatorType::PIPELINE: + return "PIPELINE"; + case OperatorType::FUSED_PARALLEL: + return "FUSED_PARALLEL"; + default: + std::ostringstream oss; + oss << "Unknown OperatorType value " << static_cast(x); + throw std::runtime_error(oss.str()); + } +} +std::ostream &operator<<(std::ostream &s, OperatorType x) { + return s << fmt::to_string(x); +} +} // namespace FlexFlow +namespace FlexFlow { +void to_json(::nlohmann::json &j, OperatorType x) { + switch (x) { + case OperatorType::NOOP: + j = "NOOP"; + break; + case OperatorType::INPUT: + j = "INPUT"; + break; + case OperatorType::WEIGHT: + j = "WEIGHT"; + break; + case OperatorType::CONV2D: + j = "CONV2D"; + break; + case OperatorType::DROPOUT: + j = "DROPOUT"; + break; + case OperatorType::LINEAR: + j = "LINEAR"; + break; + case OperatorType::BATCHMATMUL: + j = "BATCHMATMUL"; + break; + case OperatorType::POOL2D: + j = "POOL2D"; + break; + case OperatorType::SCALAR_MULTIPLY: + j = "SCALAR_MULTIPLY"; + break; + case OperatorType::SCALAR_ADD: + j = "SCALAR_ADD"; + break; + case OperatorType::SCALAR_FLOOR_DIV: + j = "SCALAR_FLOOR_DIV"; + break; + case OperatorType::SCALAR_TRUE_DIV: + j = "SCALAR_TRUE_DIV"; + break; + case OperatorType::SCALAR_SUB: + j = "SCALAR_SUB"; + break; + case OperatorType::RELU: + j = "RELU"; + break; + case OperatorType::IDENTITY: + j = "IDENTITY"; + break; + case OperatorType::SIGMOID: + j = "SIGMOID"; + break; + case OperatorType::TANH: + j = "TANH"; + break; + case OperatorType::ELU: + j = "ELU"; + break; + case OperatorType::FLAT: + j = "FLAT"; + break; + case OperatorType::SOFTMAX: + j = "SOFTMAX"; + break; + case OperatorType::BATCHNORM: + j = "BATCHNORM"; + break; + case OperatorType::CONCAT: + j = "CONCAT"; + break; + case OperatorType::SPLIT: + j = "SPLIT"; + break; + case OperatorType::EMBEDDING: + j = "EMBEDDING"; + break; + case OperatorType::CACHE: + j = "CACHE"; + break; + case OperatorType::RESHAPE: + j = "RESHAPE"; + break; + case OperatorType::REVERSE: + j = "REVERSE"; + break; + case OperatorType::TRANSPOSE: + j = "TRANSPOSE"; + break; + case OperatorType::EW_ADD: + j = "EW_ADD"; + break; + case OperatorType::EW_MUL: + j = "EW_MUL"; + break; + case OperatorType::MATMUL: + j = "MATMUL"; + break; + case OperatorType::MUL: + j = "MUL"; + break; + case OperatorType::ENLARGE: + j = "ENLARGE"; + break; + case OperatorType::SQUEEZE: + j = "SQUEEZE"; + break; + case OperatorType::UNSQUEEZE: + j = "UNSQUEEZE"; + break; + case OperatorType::EW_SUB: + j = "EW_SUB"; + break; + case OperatorType::EW_DIV: + j = "EW_DIV"; + break; + case 
OperatorType::EW_EQUAL: + j = "EW_EQUAL"; + break; + case OperatorType::EW_GREATER: + j = "EW_GREATER"; + break; + case OperatorType::EW_LESS: + j = "EW_LESS"; + break; + case OperatorType::EW_MAX: + j = "EW_MAX"; + break; + case OperatorType::EW_MIN: + j = "EW_MIN"; + break; + case OperatorType::REDUCE_ARGMAX: + j = "REDUCE_ARGMAX"; + break; + case OperatorType::REDUCE_ARGMIN: + j = "REDUCE_ARGMIN"; + break; + case OperatorType::REDUCE_MAX: + j = "REDUCE_MAX"; + break; + case OperatorType::REDUCE_MEAN: + j = "REDUCE_MEAN"; + break; + case OperatorType::REDUCE_MIN: + j = "REDUCE_MIN"; + break; + case OperatorType::REDUCE_PROD: + j = "REDUCE_PROD"; + break; + case OperatorType::REDUCE_SUM: + j = "REDUCE_SUM"; + break; + case OperatorType::PAD: + j = "PAD"; + break; + case OperatorType::SHAPE: + j = "SHAPE"; + break; + case OperatorType::SIZE: + j = "SIZE"; + break; + case OperatorType::TOPK: + j = "TOPK"; + break; + case OperatorType::WHERE: + j = "WHERE"; + break; + case OperatorType::CEIL: + j = "CEIL"; + break; + case OperatorType::CAST: + j = "CAST"; + break; + case OperatorType::EXP: + j = "EXP"; + break; + case OperatorType::ROUND: + j = "ROUND"; + break; + case OperatorType::LOG: + j = "LOG"; + break; + case OperatorType::LOGICAL_NOT: + j = "LOGICAL_NOT"; + break; + case OperatorType::SQRT: + j = "SQRT"; + break; + case OperatorType::SIN: + j = "SIN"; + break; + case OperatorType::COS: + j = "COS"; + break; + case OperatorType::LEAKYRELU: + j = "LEAKYRELU"; + break; + case OperatorType::SLICE: + j = "SLICE"; + break; + case OperatorType::RESIZE: + j = "RESIZE"; + break; + case OperatorType::PRELU: + j = "PRELU"; + break; + case OperatorType::GELU: + j = "GELU"; + break; + case OperatorType::MULTIHEAD_ATTENTION: + j = "MULTIHEAD_ATTENTION"; + break; + case OperatorType::FUSED: + j = "FUSED"; + break; + case OperatorType::RSQRT: + j = "RSQRT"; + break; + case OperatorType::POW: + j = "POW"; + break; + case OperatorType::MEAN: + j = "MEAN"; + break; + case OperatorType::LAYERNORM: + j = "LAYERNORM"; + break; + case OperatorType::GATHER: + j = "GATHER"; + break; + case OperatorType::BROADCAST: + j = "BROADCAST"; + break; + case OperatorType::REPARTITION: + j = "REPARTITION"; + break; + case OperatorType::COMBINE: + j = "COMBINE"; + break; + case OperatorType::REPLICATE: + j = "REPLICATE"; + break; + case OperatorType::REDUCTION: + j = "REDUCTION"; + break; + case OperatorType::BATCH: + j = "BATCH"; + break; + case OperatorType::PIPELINE: + j = "PIPELINE"; + break; + case OperatorType::FUSED_PARALLEL: + j = "FUSED_PARALLEL"; + break; + default: + std::ostringstream oss; + oss << "Unknown OperatorType value " << static_cast(x); + throw std::runtime_error(oss.str()); + } +} +void from_json(::nlohmann::json const &j, OperatorType &x) { + std::string as_str = j.get(); + if (as_str == "NOOP") { + x = OperatorType::NOOP; + } else if (as_str == "INPUT") { + x = OperatorType::INPUT; + } else if (as_str == "WEIGHT") { + x = OperatorType::WEIGHT; + } else if (as_str == "CONV2D") { + x = OperatorType::CONV2D; + } else if (as_str == "DROPOUT") { + x = OperatorType::DROPOUT; + } else if (as_str == "LINEAR") { + x = OperatorType::LINEAR; + } else if (as_str == "BATCHMATMUL") { + x = OperatorType::BATCHMATMUL; + } else if (as_str == "POOL2D") { + x = OperatorType::POOL2D; + } else if (as_str == "SCALAR_MULTIPLY") { + x = OperatorType::SCALAR_MULTIPLY; + } else if (as_str == "SCALAR_ADD") { + x = OperatorType::SCALAR_ADD; + } else if (as_str == "SCALAR_FLOOR_DIV") { + x = OperatorType::SCALAR_FLOOR_DIV; 
+ } else if (as_str == "SCALAR_TRUE_DIV") { + x = OperatorType::SCALAR_TRUE_DIV; + } else if (as_str == "SCALAR_SUB") { + x = OperatorType::SCALAR_SUB; + } else if (as_str == "RELU") { + x = OperatorType::RELU; + } else if (as_str == "IDENTITY") { + x = OperatorType::IDENTITY; + } else if (as_str == "SIGMOID") { + x = OperatorType::SIGMOID; + } else if (as_str == "TANH") { + x = OperatorType::TANH; + } else if (as_str == "ELU") { + x = OperatorType::ELU; + } else if (as_str == "FLAT") { + x = OperatorType::FLAT; + } else if (as_str == "SOFTMAX") { + x = OperatorType::SOFTMAX; + } else if (as_str == "BATCHNORM") { + x = OperatorType::BATCHNORM; + } else if (as_str == "CONCAT") { + x = OperatorType::CONCAT; + } else if (as_str == "SPLIT") { + x = OperatorType::SPLIT; + } else if (as_str == "EMBEDDING") { + x = OperatorType::EMBEDDING; + } else if (as_str == "CACHE") { + x = OperatorType::CACHE; + } else if (as_str == "RESHAPE") { + x = OperatorType::RESHAPE; + } else if (as_str == "REVERSE") { + x = OperatorType::REVERSE; + } else if (as_str == "TRANSPOSE") { + x = OperatorType::TRANSPOSE; + } else if (as_str == "EW_ADD") { + x = OperatorType::EW_ADD; + } else if (as_str == "EW_MUL") { + x = OperatorType::EW_MUL; + } else if (as_str == "MATMUL") { + x = OperatorType::MATMUL; + } else if (as_str == "MUL") { + x = OperatorType::MUL; + } else if (as_str == "ENLARGE") { + x = OperatorType::ENLARGE; + } else if (as_str == "SQUEEZE") { + x = OperatorType::SQUEEZE; + } else if (as_str == "UNSQUEEZE") { + x = OperatorType::UNSQUEEZE; + } else if (as_str == "EW_SUB") { + x = OperatorType::EW_SUB; + } else if (as_str == "EW_DIV") { + x = OperatorType::EW_DIV; + } else if (as_str == "EW_EQUAL") { + x = OperatorType::EW_EQUAL; + } else if (as_str == "EW_GREATER") { + x = OperatorType::EW_GREATER; + } else if (as_str == "EW_LESS") { + x = OperatorType::EW_LESS; + } else if (as_str == "EW_MAX") { + x = OperatorType::EW_MAX; + } else if (as_str == "EW_MIN") { + x = OperatorType::EW_MIN; + } else if (as_str == "REDUCE_ARGMAX") { + x = OperatorType::REDUCE_ARGMAX; + } else if (as_str == "REDUCE_ARGMIN") { + x = OperatorType::REDUCE_ARGMIN; + } else if (as_str == "REDUCE_MAX") { + x = OperatorType::REDUCE_MAX; + } else if (as_str == "REDUCE_MEAN") { + x = OperatorType::REDUCE_MEAN; + } else if (as_str == "REDUCE_MIN") { + x = OperatorType::REDUCE_MIN; + } else if (as_str == "REDUCE_PROD") { + x = OperatorType::REDUCE_PROD; + } else if (as_str == "REDUCE_SUM") { + x = OperatorType::REDUCE_SUM; + } else if (as_str == "PAD") { + x = OperatorType::PAD; + } else if (as_str == "SHAPE") { + x = OperatorType::SHAPE; + } else if (as_str == "SIZE") { + x = OperatorType::SIZE; + } else if (as_str == "TOPK") { + x = OperatorType::TOPK; + } else if (as_str == "WHERE") { + x = OperatorType::WHERE; + } else if (as_str == "CEIL") { + x = OperatorType::CEIL; + } else if (as_str == "CAST") { + x = OperatorType::CAST; + } else if (as_str == "EXP") { + x = OperatorType::EXP; + } else if (as_str == "ROUND") { + x = OperatorType::ROUND; + } else if (as_str == "LOG") { + x = OperatorType::LOG; + } else if (as_str == "LOGICAL_NOT") { + x = OperatorType::LOGICAL_NOT; + } else if (as_str == "SQRT") { + x = OperatorType::SQRT; + } else if (as_str == "SIN") { + x = OperatorType::SIN; + } else if (as_str == "COS") { + x = OperatorType::COS; + } else if (as_str == "LEAKYRELU") { + x = OperatorType::LEAKYRELU; + } else if (as_str == "SLICE") { + x = OperatorType::SLICE; + } else if (as_str == "RESIZE") { + x = OperatorType::RESIZE; + } 
else if (as_str == "PRELU") { + x = OperatorType::PRELU; + } else if (as_str == "GELU") { + x = OperatorType::GELU; + } else if (as_str == "MULTIHEAD_ATTENTION") { + x = OperatorType::MULTIHEAD_ATTENTION; + } else if (as_str == "FUSED") { + x = OperatorType::FUSED; + } else if (as_str == "RSQRT") { + x = OperatorType::RSQRT; + } else if (as_str == "POW") { + x = OperatorType::POW; + } else if (as_str == "MEAN") { + x = OperatorType::MEAN; + } else if (as_str == "LAYERNORM") { + x = OperatorType::LAYERNORM; + } else if (as_str == "GATHER") { + x = OperatorType::GATHER; + } else if (as_str == "BROADCAST") { + x = OperatorType::BROADCAST; + } else if (as_str == "REPARTITION") { + x = OperatorType::REPARTITION; + } else if (as_str == "COMBINE") { + x = OperatorType::COMBINE; + } else if (as_str == "REPLICATE") { + x = OperatorType::REPLICATE; + } else if (as_str == "REDUCTION") { + x = OperatorType::REDUCTION; + } else if (as_str == "BATCH") { + x = OperatorType::BATCH; + } else if (as_str == "PIPELINE") { + x = OperatorType::PIPELINE; + } else if (as_str == "FUSED_PARALLEL") { + x = OperatorType::FUSED_PARALLEL; + } else { + std::ostringstream oss; + oss << "Unknown OperatorType value " << as_str; + throw std::runtime_error(oss.str()); + } +} +} // namespace FlexFlow +namespace rc { +Gen Arbitrary::arbitrary() { + return gen::element( + FlexFlow::OperatorType::NOOP, + FlexFlow::OperatorType::INPUT, + FlexFlow::OperatorType::WEIGHT, + FlexFlow::OperatorType::CONV2D, + FlexFlow::OperatorType::DROPOUT, + FlexFlow::OperatorType::LINEAR, + FlexFlow::OperatorType::BATCHMATMUL, + FlexFlow::OperatorType::POOL2D, + FlexFlow::OperatorType::SCALAR_MULTIPLY, + FlexFlow::OperatorType::SCALAR_ADD, + FlexFlow::OperatorType::SCALAR_FLOOR_DIV, + FlexFlow::OperatorType::SCALAR_TRUE_DIV, + FlexFlow::OperatorType::SCALAR_SUB, + FlexFlow::OperatorType::RELU, + FlexFlow::OperatorType::IDENTITY, + FlexFlow::OperatorType::SIGMOID, + FlexFlow::OperatorType::TANH, + FlexFlow::OperatorType::ELU, + FlexFlow::OperatorType::FLAT, + FlexFlow::OperatorType::SOFTMAX, + FlexFlow::OperatorType::BATCHNORM, + FlexFlow::OperatorType::CONCAT, + FlexFlow::OperatorType::SPLIT, + FlexFlow::OperatorType::EMBEDDING, + FlexFlow::OperatorType::CACHE, + FlexFlow::OperatorType::RESHAPE, + FlexFlow::OperatorType::REVERSE, + FlexFlow::OperatorType::TRANSPOSE, + FlexFlow::OperatorType::EW_ADD, + FlexFlow::OperatorType::EW_MUL, + FlexFlow::OperatorType::MATMUL, + FlexFlow::OperatorType::MUL, + FlexFlow::OperatorType::ENLARGE, + FlexFlow::OperatorType::SQUEEZE, + FlexFlow::OperatorType::UNSQUEEZE, + FlexFlow::OperatorType::EW_SUB, + FlexFlow::OperatorType::EW_DIV, + FlexFlow::OperatorType::EW_EQUAL, + FlexFlow::OperatorType::EW_GREATER, + FlexFlow::OperatorType::EW_LESS, + FlexFlow::OperatorType::EW_MAX, + FlexFlow::OperatorType::EW_MIN, + FlexFlow::OperatorType::REDUCE_ARGMAX, + FlexFlow::OperatorType::REDUCE_ARGMIN, + FlexFlow::OperatorType::REDUCE_MAX, + FlexFlow::OperatorType::REDUCE_MEAN, + FlexFlow::OperatorType::REDUCE_MIN, + FlexFlow::OperatorType::REDUCE_PROD, + FlexFlow::OperatorType::REDUCE_SUM, + FlexFlow::OperatorType::PAD, + FlexFlow::OperatorType::SHAPE, + FlexFlow::OperatorType::SIZE, + FlexFlow::OperatorType::TOPK, + FlexFlow::OperatorType::WHERE, + FlexFlow::OperatorType::CEIL, + FlexFlow::OperatorType::CAST, + FlexFlow::OperatorType::EXP, + FlexFlow::OperatorType::ROUND, + FlexFlow::OperatorType::LOG, + FlexFlow::OperatorType::LOGICAL_NOT, + FlexFlow::OperatorType::SQRT, + FlexFlow::OperatorType::SIN, + 
FlexFlow::OperatorType::COS,
+      FlexFlow::OperatorType::LEAKYRELU,
+      FlexFlow::OperatorType::SLICE,
+      FlexFlow::OperatorType::RESIZE,
+      FlexFlow::OperatorType::PRELU,
+      FlexFlow::OperatorType::GELU,
+      FlexFlow::OperatorType::MULTIHEAD_ATTENTION,
+      FlexFlow::OperatorType::FUSED,
+      FlexFlow::OperatorType::RSQRT,
+      FlexFlow::OperatorType::POW,
+      FlexFlow::OperatorType::MEAN,
+      FlexFlow::OperatorType::LAYERNORM,
+      FlexFlow::OperatorType::GATHER,
+      FlexFlow::OperatorType::BROADCAST,
+      FlexFlow::OperatorType::REPARTITION,
+      FlexFlow::OperatorType::COMBINE,
+      FlexFlow::OperatorType::REPLICATE,
+      FlexFlow::OperatorType::REDUCTION,
+      FlexFlow::OperatorType::BATCH,
+      FlexFlow::OperatorType::PIPELINE,
+      FlexFlow::OperatorType::FUSED_PARALLEL);
+}
+} // namespace rc
diff --git a/lib/op-attrs/src/attention.cc b/lib/op-attrs/src/op-attrs/ops/attention.cc
similarity index 77%
rename from lib/op-attrs/src/attention.cc
rename to lib/op-attrs/src/op-attrs/ops/attention.cc
index 2c1500a477..14ab2b9b00 100644
--- a/lib/op-attrs/src/attention.cc
+++ b/lib/op-attrs/src/op-attrs/ops/attention.cc
@@ -1,4 +1,9 @@
 #include "op-attrs/ops/attention.h"
+#include "op-attrs/ops/attention/multihead_attention_inputs.h"
+#include "op-attrs/ops/attention/multihead_attention_parallel_inputs.h"
+#include "op-attrs/parallel_tensor_shape.h"
+#include "op-attrs/tensor_shape.h"
+#include "utils/integer_conversions.h"
 
 namespace FlexFlow {
 
@@ -27,78 +32,175 @@ int get_oProjSize(MultiHeadAttentionAttrs const &attrs) {
 }
 
 int get_qSize(TensorShape const &query_shape) {
-  return query_shape.at(ff_dim_t(0));
+  return dim_at_idx(query_shape, ff_dim_t(0));
 }
 
 int get_kSize(TensorShape const &key_shape) {
-  return key_shape.at(ff_dim_t(0));
+  return dim_at_idx(key_shape, ff_dim_t(0));
 }
 
 int get_vSize(TensorShape const &value_shape) {
-  return value_shape.at(ff_dim_t(0));
+  return dim_at_idx(value_shape, ff_dim_t(0));
 }
 
-int get_qSize(MultiHeadAttentionInputs const &) {
+int get_qSize(MultiHeadAttentionParallelInputs const &) {
   NOT_IMPLEMENTED();
 }
 
-int get_kSize(MultiHeadAttentionInputs const &) {
+int get_qSize(MultiHeadAttentionInputs const &) {
   NOT_IMPLEMENTED();
 }
 
-int get_vSize(MultiHeadAttentionInputs const &) {
+int get_kSize(MultiHeadAttentionParallelInputs const &) {
   NOT_IMPLEMENTED();
 }
 
-TensorShape
+int get_kSize(MultiHeadAttentionInputs const &) {
+  NOT_IMPLEMENTED();
+}
+
+int get_vSize(MultiHeadAttentionParallelInputs const &) {
+  NOT_IMPLEMENTED();
+}
+
+int get_vSize(MultiHeadAttentionInputs const &) {
+  NOT_IMPLEMENTED();
+}
+
+tl::expected<TensorShape, std::string>
+    get_output_shape(MultiHeadAttentionAttrs const &attrs,
+                     TensorShape const &input_q,
+                     TensorShape const &input_k,
+                     TensorShape const &input_v) {
+  tl::expected<MultiHeadAttentionInputs, std::string> parse_result =
+      parse_attention_input_shape(input_q, input_k, input_v);
+  if (!parse_result.has_value()) {
+    return tl::unexpected(parse_result.error());
+  }
+
+  MultiHeadAttentionInputs parsed = parse_result.value();
+
+  return TensorShape{
+      TensorDims{FFOrdered<size_t>{
+          parsed.batch_size,
+          parsed.sequence_length,
+          size_t_from_int(attrs.embed_dim),
+      }},
+      parsed.datatype,
+  };
+}
+
+tl::expected<TensorShape, std::string>
     get_weights_shape(MultiHeadAttentionAttrs const &attrs,
-                      MultiHeadAttentionInputs const &inputs) {
-  size_t qParas = get_qProjSize(attrs) * get_qSize(inputs);
-  size_t kParas = get_kProjSize(attrs) * get_kSize(inputs);
-  size_t vParas = get_vProjSize(attrs) * get_vSize(inputs);
-  TensorShape output_shape = get_output_shape(attrs, inputs);
-  size_t oParas = get_oProjSize(attrs) * get_oSize(output_shape);
+                      TensorShape const &input_q,
+                      TensorShape const &input_k,
+                      TensorShape const &input_v) {
+  tl::expected<MultiHeadAttentionInputs, std::string> parse_result =
+      parse_attention_input_shape(input_q, input_k, input_v);
+  if (!parse_result.has_value()) {
+    return tl::unexpected(parse_result.error());
+  }
 
-  TensorDims dims = {qParas + kParas + vParas + oParas,
-                     static_cast<size_t>(attrs.embed_dim)};
+  MultiHeadAttentionInputs parsed = parse_result.value();
 
-  return {dims, DataType::FLOAT};
+  // W^Q_i in "Attention Is All You Need" top of page 5
+  size_t qProjectWeightSize = parsed.query_size * attrs.kdim;
+
+  // W^K_i in "Attention Is All You Need" top of page 5 (all i's put together)
+  size_t kProjectWeightSize = parsed.key_size * attrs.kdim;
+
+  // W^V_i in "Attention Is All You Need" top of page 5 (all i's put together)
+  size_t vProjectWeightSize = parsed.value_size * attrs.vdim;
+
+  // W^O in "Attention Is All You Need" top of page 5, with num_heads factored
+  // out
+  size_t outWeightSize = parsed.value_size * attrs.embed_dim;
+
+  return TensorShape{
+      TensorDims{FFOrdered<size_t>{
+          (qProjectWeightSize + kProjectWeightSize + vProjectWeightSize +
+           outWeightSize),
+          size_t_from_int(attrs.num_heads),
+      }},
+      parsed.datatype,
+  };
 }
 
-ParallelTensorShape get_output_shape(MultiHeadAttentionAttrs const &attrs,
-                                     ParallelTensorShape const &query_shape,
-                                     ParallelTensorShape const &key_shape,
-                                     ParallelTensorShape const &value_shape) {
-  /* ParallelDim replica_dim = query_shape.at(ff_dim_t(query_shape.num_dims() -
-   * 2)); */
-  /* replica_dim.size = replica_dim.degree; */
+tl::expected<ParallelTensorShape, std::string>
+    get_weights_shape(MultiHeadAttentionAttrs const &attrs,
+                      ParallelTensorShape const &input_q,
+                      ParallelTensorShape const &input_k,
+                      ParallelTensorShape const &input_v) {
+  tl::expected<MultiHeadAttentionParallelInputs, std::string> parse_result =
+      parse_attention_parallel_input_shape(input_q, input_k, input_v);
+  if (!parse_result.has_value()) {
+    return tl::unexpected(parse_result.error());
+  }
+  MultiHeadAttentionParallelInputs parsed = parse_result.value();
+
+  tl::expected<TensorShape, std::string> result_unpar_get_shape =
+      get_weights_shape(attrs,
+                        get_reduced_shape(input_q),
+                        get_reduced_shape(input_k),
+                        get_reduced_shape(input_v));
+  if (!result_unpar_get_shape.has_value()) {
+    return tl::unexpected(result_unpar_get_shape.error());
+  }
+  TensorShape unpar_shape = result_unpar_get_shape.value();
 
-  /* ParallelDim */
+  int joined_dim_degree = 1;
+  int head_dim_degree = parsed.discard_copy_degree.value;
 
-  ParallelTensorShape output_shape = query_shape;
-  output_shape.at(ff_dim_t(output_shape.num_dims() - 1)).size = attrs.embed_dim;
-  return output_shape;
+  return lift_to_parallel_with_degrees(
+      unpar_shape,
+      SumDegree{1},
+      DiscardCopyDegree{parsed.batch_dim.degree},
+      FFOrdered<int>{joined_dim_degree, head_dim_degree});
 }
 
-TensorShape get_output_shape(MultiHeadAttentionAttrs const &attrs,
-                             TensorShape const &query_shape,
-                             TensorShape const &key_shape,
-                             TensorShape const &value_shape) {
-  ParallelTensorShape parallel_shape =
+tl::expected<ParallelTensorShape, std::string>
+    get_output_shape(MultiHeadAttentionAttrs const &attrs,
+                     ParallelTensorShape const &input_q,
+                     ParallelTensorShape const &input_k,
+                     ParallelTensorShape const &input_v) {
+  tl::expected<MultiHeadAttentionParallelInputs, std::string> parse_result =
+      parse_attention_parallel_input_shape(input_q, input_k, input_v);
+  if (!parse_result.has_value()) {
+    return tl::unexpected(parse_result.error());
+  }
+  MultiHeadAttentionParallelInputs parsed = parse_result.value();
+
+  tl::expected<TensorShape, std::string> result_unpar_get_shape =
       get_output_shape(attrs,
-                       static_cast<ParallelTensorShape>(query_shape),
-                       static_cast<ParallelTensorShape>(key_shape),
-                       static_cast<ParallelTensorShape>(value_shape));
-  return get_tensor_shape_unsafe(parallel_shape);
+                       get_reduced_shape(input_q),
+                       get_reduced_shape(input_k),
+                       get_reduced_shape(input_v));
+  if (!result_unpar_get_shape.has_value()) {
+    return tl::unexpected(result_unpar_get_shape.error());
+  }
+  TensorShape unpar_shape = result_unpar_get_shape.value();
+
+  int sum_degree = parsed.discard_copy_degree.value;
+  int discard_copy_degree = 1;
+  int batch_degree = parsed.batch_dim.degree;
+  int seq_len_degree = 1;
+  int out_dim_degree = 1;
+
+  return lift_to_parallel_with_degrees(
+      unpar_shape,
+      SumDegree{sum_degree},
+      DiscardCopyDegree{discard_copy_degree},
+      FFOrdered<int>{batch_degree, seq_len_degree, out_dim_degree});
 }
 
-TensorShape get_output_shape(MultiHeadAttentionAttrs const &,
-                             MultiHeadAttentionInputs const &) {
+
+int get_oSize(ParallelTensorShape const &) {
   NOT_IMPLEMENTED();
 }
 
-int get_oSize(ParallelTensorShape const &) {
+int get_oSize(TensorShape const &) {
   NOT_IMPLEMENTED();
 }
+
 } // namespace FlexFlow
 
 // Tensor FFModel::multihead_attention(const Tensor query,
diff --git a/lib/op-attrs/src/op-attrs/ops/attention/multihead_attention_inputs.cc b/lib/op-attrs/src/op-attrs/ops/attention/multihead_attention_inputs.cc
new file mode 100644
index 0000000000..65feb642e1
--- /dev/null
+++ b/lib/op-attrs/src/op-attrs/ops/attention/multihead_attention_inputs.cc
@@ -0,0 +1,80 @@
+#include "op-attrs/ops/attention/multihead_attention_inputs.h"
+#include "op-attrs/tensor_shape.h"
+
+namespace FlexFlow {
+
+template <typename T>
+static bool all_same(T const &x, T const &y, T const &z) {
+  return x == y && y == z;
+}
+
+tl::expected<MultiHeadAttentionInputs, std::string>
+    parse_attention_input_shape(TensorShape const &input_q,
+                                TensorShape const &input_k,
+                                TensorShape const &input_v) {
+  if (num_dims(input_q) != 3) {
+    return tl::unexpected(
+        fmt::format("Query input has incorrect number of dims: {} != {}",
+                    num_dims(input_q),
+                    3));
+  }
+  if (num_dims(input_k) != 3) {
+    return tl::unexpected(
+        fmt::format("Key input has incorrect number of dims: {} != {}",
+                    num_dims(input_k),
+                    3));
+  }
+  if (num_dims(input_v) != 3) {
+    return tl::unexpected(
+        fmt::format("Value input has incorrect number of dims: {} != {}",
+                    num_dims(input_v),
+                    3));
+  }
+
+  size_t seq_len_q = dim_at_idx(input_q, ff_dim_t{-2});
+  size_t seq_len_k = dim_at_idx(input_k, ff_dim_t{-2});
+  size_t seq_len_v = dim_at_idx(input_v, ff_dim_t{-2});
+
+  if (!all_same(seq_len_q, seq_len_k, seq_len_v)) {
+    return tl::unexpected(fmt::format(
+        "Q, K, V disagree on the sequence length: {} (Q) vs {} (K) vs {} (V)",
+        seq_len_q,
+        seq_len_k,
+        seq_len_v));
+  }
+
+  size_t batch_size_q = dim_at_idx(input_q, ff_dim_t{-3});
+  size_t batch_size_k = dim_at_idx(input_k, ff_dim_t{-3});
+  size_t batch_size_v = dim_at_idx(input_v, ff_dim_t{-3});
+
+  if (!all_same(batch_size_q, batch_size_k, batch_size_v)) {
+    return tl::unexpected(fmt::format(
+        "Q, K, V disagree on the batch size: {} (Q) vs {} (K) vs {} (V)",
+        batch_size_q,
+        batch_size_k,
+        batch_size_v));
+  }
+
+  if (!all_same(input_q.data_type, input_k.data_type, input_v.data_type)) {
+    return tl::unexpected(fmt::format(
+        "Q, K, V disagree on the datatype: {} (Q) vs {} (K) vs {} (V)",
+        input_q.data_type,
+        input_k.data_type,
+        input_v.data_type));
+  }
+
+  size_t q_size = dim_at_idx(input_q, ff_dim_t{-1});
+  size_t k_size = dim_at_idx(input_k, ff_dim_t{-1});
+  size_t v_size = dim_at_idx(input_v, ff_dim_t{-1});
+
+  return MultiHeadAttentionInputs{
+      batch_size_q,
+      seq_len_q,
+      q_size,
+      k_size,
+      v_size,
+      input_q.data_type,
+  };
+}
+
+} // namespace FlexFlow
diff --git
a/lib/op-attrs/src/op-attrs/ops/attention/multihead_attention_inputs.dtg.cc b/lib/op-attrs/src/op-attrs/ops/attention/multihead_attention_inputs.dtg.cc new file mode 100644 index 0000000000..26d3138eb4 --- /dev/null +++ b/lib/op-attrs/src/op-attrs/ops/attention/multihead_attention_inputs.dtg.cc @@ -0,0 +1,185 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/op-attrs/include/op-attrs/ops/attention/multihead_attention_inputs.struct.toml +/* proj-data +{ + "generated_from": "c57a9d1d2822a726ee9d9369d22e8e72" +} +*/ + +#include "op-attrs/ops/attention/multihead_attention_inputs.dtg.h" + +#include "op-attrs/datatype.dtg.h" +#include +#include + +namespace FlexFlow { +MultiHeadAttentionInputs::MultiHeadAttentionInputs( + size_t const &batch_size, + size_t const &sequence_length, + size_t const &query_size, + size_t const &key_size, + size_t const &value_size, + ::FlexFlow::DataType const &datatype) + : batch_size(batch_size), sequence_length(sequence_length), + query_size(query_size), key_size(key_size), value_size(value_size), + datatype(datatype) {} +bool MultiHeadAttentionInputs::operator==( + MultiHeadAttentionInputs const &other) const { + return std::tie(this->batch_size, + this->sequence_length, + this->query_size, + this->key_size, + this->value_size, + this->datatype) == std::tie(other.batch_size, + other.sequence_length, + other.query_size, + other.key_size, + other.value_size, + other.datatype); +} +bool MultiHeadAttentionInputs::operator!=( + MultiHeadAttentionInputs const &other) const { + return std::tie(this->batch_size, + this->sequence_length, + this->query_size, + this->key_size, + this->value_size, + this->datatype) != std::tie(other.batch_size, + other.sequence_length, + other.query_size, + other.key_size, + other.value_size, + other.datatype); +} +bool MultiHeadAttentionInputs::operator<( + MultiHeadAttentionInputs const &other) const { + return std::tie(this->batch_size, + this->sequence_length, + this->query_size, + this->key_size, + this->value_size, + this->datatype) < std::tie(other.batch_size, + other.sequence_length, + other.query_size, + other.key_size, + other.value_size, + other.datatype); +} +bool MultiHeadAttentionInputs::operator>( + MultiHeadAttentionInputs const &other) const { + return std::tie(this->batch_size, + this->sequence_length, + this->query_size, + this->key_size, + this->value_size, + this->datatype) > std::tie(other.batch_size, + other.sequence_length, + other.query_size, + other.key_size, + other.value_size, + other.datatype); +} +bool MultiHeadAttentionInputs::operator<=( + MultiHeadAttentionInputs const &other) const { + return std::tie(this->batch_size, + this->sequence_length, + this->query_size, + this->key_size, + this->value_size, + this->datatype) <= std::tie(other.batch_size, + other.sequence_length, + other.query_size, + other.key_size, + other.value_size, + other.datatype); +} +bool MultiHeadAttentionInputs::operator>=( + MultiHeadAttentionInputs const &other) const { + return std::tie(this->batch_size, + this->sequence_length, + this->query_size, + this->key_size, + this->value_size, + this->datatype) >= std::tie(other.batch_size, + other.sequence_length, + other.query_size, + other.key_size, + other.value_size, + other.datatype); +} +} // namespace FlexFlow + +namespace std { +size_t hash::operator()( + FlexFlow::MultiHeadAttentionInputs const &x) const { + size_t result = 0; + result ^= std::hash{}(x.batch_size) + 0x9e3779b9 + (result << 6) + + 
(result >> 2); + result ^= std::hash{}(x.sequence_length) + 0x9e3779b9 + + (result << 6) + (result >> 2); + result ^= std::hash{}(x.query_size) + 0x9e3779b9 + (result << 6) + + (result >> 2); + result ^= std::hash{}(x.key_size) + 0x9e3779b9 + (result << 6) + + (result >> 2); + result ^= std::hash{}(x.value_size) + 0x9e3779b9 + (result << 6) + + (result >> 2); + result ^= std::hash<::FlexFlow::DataType>{}(x.datatype) + 0x9e3779b9 + + (result << 6) + (result >> 2); + return result; +} +} // namespace std + +namespace nlohmann { +FlexFlow::MultiHeadAttentionInputs + adl_serializer::from_json( + json const &j) { + return {j.at("batch_size").template get(), + j.at("sequence_length").template get(), + j.at("query_size").template get(), + j.at("key_size").template get(), + j.at("value_size").template get(), + j.at("datatype").template get<::FlexFlow::DataType>()}; +} +void adl_serializer::to_json( + json &j, FlexFlow::MultiHeadAttentionInputs const &v) { + j["__type"] = "MultiHeadAttentionInputs"; + j["batch_size"] = v.batch_size; + j["sequence_length"] = v.sequence_length; + j["query_size"] = v.query_size; + j["key_size"] = v.key_size; + j["value_size"] = v.value_size; + j["datatype"] = v.datatype; +} +} // namespace nlohmann + +namespace rc { +Gen + Arbitrary::arbitrary() { + return gen::construct( + gen::arbitrary(), + gen::arbitrary(), + gen::arbitrary(), + gen::arbitrary(), + gen::arbitrary(), + gen::arbitrary<::FlexFlow::DataType>()); +} +} // namespace rc + +namespace FlexFlow { +std::string format_as(MultiHeadAttentionInputs const &x) { + std::ostringstream oss; + oss << ""; + return oss.str(); +} +std::ostream &operator<<(std::ostream &s, MultiHeadAttentionInputs const &x) { + return s << fmt::to_string(x); +} +} // namespace FlexFlow diff --git a/lib/op-attrs/src/op-attrs/ops/attention/multihead_attention_parallel_inputs.cc b/lib/op-attrs/src/op-attrs/ops/attention/multihead_attention_parallel_inputs.cc new file mode 100644 index 0000000000..2cd5b7ec00 --- /dev/null +++ b/lib/op-attrs/src/op-attrs/ops/attention/multihead_attention_parallel_inputs.cc @@ -0,0 +1,132 @@ +#include "op-attrs/ops/attention/multihead_attention_parallel_inputs.h" +#include "op-attrs/ops/attention/multihead_attention_inputs.h" +#include "op-attrs/parallel_tensor_shape.h" + +namespace FlexFlow { + +template +static bool all_same(T const &x, T const &y, T const &z) { + return x == y && y == z; +} + +tl::expected + parse_attention_parallel_input_shape(ParallelTensorShape const &input_q, + ParallelTensorShape const &input_k, + ParallelTensorShape const &input_v) { + tl::expected unpar_parse_result = + parse_attention_input_shape(get_reduced_shape(input_q), + get_reduced_shape(input_k), + get_reduced_shape(input_v)); + if (!unpar_parse_result.has_value()) { + return tl::unexpected( + fmt::format("MHA unparallel input parsing failed with message: \"{}\"", + unpar_parse_result.error())); + } + + if (num_shard_dims(input_q) != 3) { + return tl::unexpected( + fmt::format("Query input has incorrect number of dims: {} != {}", + num_shard_dims(input_q), + 3)); + } + if (num_shard_dims(input_k) != 3) { + return tl::unexpected( + fmt::format("Key input has incorrect number of dims: {} != {}", + num_shard_dims(input_k), + 3)); + } + if (num_shard_dims(input_v) != 3) { + return tl::unexpected( + fmt::format("Value input has incorrect number of dims: {} != {}", + num_shard_dims(input_v), + 3)); + } + + ShardParallelDim seq_len_q = shard_dim_at_idx(input_q, ff_dim_t{-2}); + if (seq_len_q.degree != 1) { + return tl::unexpected( + 
fmt::format("Query sequence length parallel degree expected to be 1, " + "but received degree {}", + seq_len_q.degree)); + } + + ShardParallelDim seq_len_k = shard_dim_at_idx(input_k, ff_dim_t{-2}); + if (seq_len_k.degree != 1) { + return tl::unexpected( + fmt::format("Key sequence length parallel degree expected to be 1, but " + "received degree {}", + seq_len_k.degree)); + } + + ShardParallelDim seq_len_v = shard_dim_at_idx(input_v, ff_dim_t{-2}); + if (seq_len_v.degree != 1) { + return tl::unexpected( + fmt::format("Value sequence length parallel degree expected to be 1, " + "but received degree {}", + seq_len_v.degree)); + } + + ShardParallelDim batch_size_q = shard_dim_at_idx(input_q, ff_dim_t{-3}); + ShardParallelDim batch_size_k = shard_dim_at_idx(input_k, ff_dim_t{-3}); + ShardParallelDim batch_size_v = shard_dim_at_idx(input_v, ff_dim_t{-3}); + + if (!all_same( + batch_size_q.degree, batch_size_k.degree, batch_size_v.degree)) { + return tl::unexpected( + fmt::format("Q, K, V disagree on the parallel degree of the batch " + "dimension: {} (Q) vs {} (K) vs {} (V)", + batch_size_q.degree, + batch_size_k.degree, + batch_size_v.degree)); + } + + ShardParallelDim query_dim = shard_dim_at_idx(input_q, ff_dim_t{-1}); + if (query_dim.degree > 1) { + return tl::unexpected( + fmt::format("Expected query tensor to have query dim parallel degree " + "1, but received degree {}", + query_dim.degree)); + } + + ShardParallelDim key_dim = shard_dim_at_idx(input_k, ff_dim_t{-1}); + if (key_dim.degree > 1) { + return tl::unexpected( + fmt::format("Expected key tensor to have key dim parallel degree 1, " + "but received degree {}", + key_dim.degree)); + } + + ShardParallelDim value_dim = shard_dim_at_idx(input_v, ff_dim_t{-1}); + if (value_dim.degree > 1) { + return tl::unexpected( + fmt::format("Expected value tensor to have value dim parallel degree " + "1, but received degree {}", + value_dim.degree)); + } + + int discard_copy_q = get_discard_copy_degree(input_q); + int discard_copy_k = get_discard_copy_degree(input_k); + int discard_copy_v = get_discard_copy_degree(input_v); + + if (!all_same(discard_copy_q, discard_copy_k, discard_copy_v)) { + return tl::unexpected(fmt::format("Q, K, V disagree on the discard-copy " + "degree: {} (Q) vs {} (K) vs {} (V)", + discard_copy_q, + discard_copy_k, + discard_copy_v)); + } + + return MultiHeadAttentionParallelInputs{ + batch_size_q, + seq_len_q, + query_dim, + key_dim, + value_dim, + discard_copy_q, + input_q.data_type, + }; + + // return; +} + +} // namespace FlexFlow diff --git a/lib/op-attrs/src/op-attrs/ops/attention/multihead_attention_parallel_inputs.dtg.cc b/lib/op-attrs/src/op-attrs/ops/attention/multihead_attention_parallel_inputs.dtg.cc new file mode 100644 index 0000000000..94784d83cc --- /dev/null +++ b/lib/op-attrs/src/op-attrs/ops/attention/multihead_attention_parallel_inputs.dtg.cc @@ -0,0 +1,209 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! 
+// If you would like to modify this datatype, instead modify +// lib/op-attrs/include/op-attrs/ops/attention/multihead_attention_parallel_inputs.struct.toml +/* proj-data +{ + "generated_from": "7c434445707968123a361c038a337da2" +} +*/ + +#include "op-attrs/ops/attention/multihead_attention_parallel_inputs.dtg.h" + +#include "op-attrs/datatype.dtg.h" +#include "op-attrs/parallel_tensor_shape/discard_copy_degree.dtg.h" +#include "op-attrs/parallel_tensor_shape/sum_degree.dtg.h" +#include "op-attrs/shard_parallel_dim.dtg.h" +#include +#include + +namespace FlexFlow { +MultiHeadAttentionParallelInputs::MultiHeadAttentionParallelInputs( + ::FlexFlow::ShardParallelDim const &batch_dim, + ::FlexFlow::ShardParallelDim const &sequence_dim, + ::FlexFlow::ShardParallelDim const &query_dim, + ::FlexFlow::ShardParallelDim const &key_dim, + ::FlexFlow::ShardParallelDim const &value_dim, + ::FlexFlow::DiscardCopyDegree const &discard_copy_degree, + ::FlexFlow::DataType const &datatype) + : batch_dim(batch_dim), sequence_dim(sequence_dim), query_dim(query_dim), + key_dim(key_dim), value_dim(value_dim), + discard_copy_degree(discard_copy_degree), datatype(datatype) {} +bool MultiHeadAttentionParallelInputs::operator==( + MultiHeadAttentionParallelInputs const &other) const { + return std::tie(this->batch_dim, + this->sequence_dim, + this->query_dim, + this->key_dim, + this->value_dim, + this->discard_copy_degree, + this->datatype) == std::tie(other.batch_dim, + other.sequence_dim, + other.query_dim, + other.key_dim, + other.value_dim, + other.discard_copy_degree, + other.datatype); +} +bool MultiHeadAttentionParallelInputs::operator!=( + MultiHeadAttentionParallelInputs const &other) const { + return std::tie(this->batch_dim, + this->sequence_dim, + this->query_dim, + this->key_dim, + this->value_dim, + this->discard_copy_degree, + this->datatype) != std::tie(other.batch_dim, + other.sequence_dim, + other.query_dim, + other.key_dim, + other.value_dim, + other.discard_copy_degree, + other.datatype); +} +bool MultiHeadAttentionParallelInputs::operator<( + MultiHeadAttentionParallelInputs const &other) const { + return std::tie(this->batch_dim, + this->sequence_dim, + this->query_dim, + this->key_dim, + this->value_dim, + this->discard_copy_degree, + this->datatype) < std::tie(other.batch_dim, + other.sequence_dim, + other.query_dim, + other.key_dim, + other.value_dim, + other.discard_copy_degree, + other.datatype); +} +bool MultiHeadAttentionParallelInputs::operator>( + MultiHeadAttentionParallelInputs const &other) const { + return std::tie(this->batch_dim, + this->sequence_dim, + this->query_dim, + this->key_dim, + this->value_dim, + this->discard_copy_degree, + this->datatype) > std::tie(other.batch_dim, + other.sequence_dim, + other.query_dim, + other.key_dim, + other.value_dim, + other.discard_copy_degree, + other.datatype); +} +bool MultiHeadAttentionParallelInputs::operator<=( + MultiHeadAttentionParallelInputs const &other) const { + return std::tie(this->batch_dim, + this->sequence_dim, + this->query_dim, + this->key_dim, + this->value_dim, + this->discard_copy_degree, + this->datatype) <= std::tie(other.batch_dim, + other.sequence_dim, + other.query_dim, + other.key_dim, + other.value_dim, + other.discard_copy_degree, + other.datatype); +} +bool MultiHeadAttentionParallelInputs::operator>=( + MultiHeadAttentionParallelInputs const &other) const { + return std::tie(this->batch_dim, + this->sequence_dim, + this->query_dim, + this->key_dim, + this->value_dim, + this->discard_copy_degree, + 
this->datatype) >= std::tie(other.batch_dim, + other.sequence_dim, + other.query_dim, + other.key_dim, + other.value_dim, + other.discard_copy_degree, + other.datatype); +} +} // namespace FlexFlow + +namespace std { +size_t hash::operator()( + FlexFlow::MultiHeadAttentionParallelInputs const &x) const { + size_t result = 0; + result ^= std::hash<::FlexFlow::ShardParallelDim>{}(x.batch_dim) + + 0x9e3779b9 + (result << 6) + (result >> 2); + result ^= std::hash<::FlexFlow::ShardParallelDim>{}(x.sequence_dim) + + 0x9e3779b9 + (result << 6) + (result >> 2); + result ^= std::hash<::FlexFlow::ShardParallelDim>{}(x.query_dim) + + 0x9e3779b9 + (result << 6) + (result >> 2); + result ^= std::hash<::FlexFlow::ShardParallelDim>{}(x.key_dim) + 0x9e3779b9 + + (result << 6) + (result >> 2); + result ^= std::hash<::FlexFlow::ShardParallelDim>{}(x.value_dim) + + 0x9e3779b9 + (result << 6) + (result >> 2); + result ^= std::hash<::FlexFlow::DiscardCopyDegree>{}(x.discard_copy_degree) + + 0x9e3779b9 + (result << 6) + (result >> 2); + result ^= std::hash<::FlexFlow::DataType>{}(x.datatype) + 0x9e3779b9 + + (result << 6) + (result >> 2); + return result; +} +} // namespace std + +namespace nlohmann { +FlexFlow::MultiHeadAttentionParallelInputs + adl_serializer::from_json( + json const &j) { + return { + j.at("batch_dim").template get<::FlexFlow::ShardParallelDim>(), + j.at("sequence_dim").template get<::FlexFlow::ShardParallelDim>(), + j.at("query_dim").template get<::FlexFlow::ShardParallelDim>(), + j.at("key_dim").template get<::FlexFlow::ShardParallelDim>(), + j.at("value_dim").template get<::FlexFlow::ShardParallelDim>(), + j.at("discard_copy_degree").template get<::FlexFlow::DiscardCopyDegree>(), + j.at("datatype").template get<::FlexFlow::DataType>()}; +} +void adl_serializer::to_json( + json &j, FlexFlow::MultiHeadAttentionParallelInputs const &v) { + j["__type"] = "MultiHeadAttentionParallelInputs"; + j["batch_dim"] = v.batch_dim; + j["sequence_dim"] = v.sequence_dim; + j["query_dim"] = v.query_dim; + j["key_dim"] = v.key_dim; + j["value_dim"] = v.value_dim; + j["discard_copy_degree"] = v.discard_copy_degree; + j["datatype"] = v.datatype; +} +} // namespace nlohmann + +namespace rc { +Gen + Arbitrary::arbitrary() { + return gen::construct( + gen::arbitrary<::FlexFlow::ShardParallelDim>(), + gen::arbitrary<::FlexFlow::ShardParallelDim>(), + gen::arbitrary<::FlexFlow::ShardParallelDim>(), + gen::arbitrary<::FlexFlow::ShardParallelDim>(), + gen::arbitrary<::FlexFlow::ShardParallelDim>(), + gen::arbitrary<::FlexFlow::DiscardCopyDegree>(), + gen::arbitrary<::FlexFlow::DataType>()); +} +} // namespace rc + +namespace FlexFlow { +std::string format_as(MultiHeadAttentionParallelInputs const &x) { + std::ostringstream oss; + oss << ""; + return oss.str(); +} +std::ostream &operator<<(std::ostream &s, + MultiHeadAttentionParallelInputs const &x) { + return s << fmt::to_string(x); +} +} // namespace FlexFlow diff --git a/lib/op-attrs/src/op-attrs/ops/attention_attrs.dtg.cc b/lib/op-attrs/src/op-attrs/ops/attention_attrs.dtg.cc new file mode 100644 index 0000000000..ad0c094969 --- /dev/null +++ b/lib/op-attrs/src/op-attrs/ops/attention_attrs.dtg.cc @@ -0,0 +1,220 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! 
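+// NOTE (sketch, not from the patch): the nlohmann adl_serializer
+// specializations generated in these files enable direct JSON round-trips
+// (hypothetical usage, field order per the constructor below):
+//
+//   MultiHeadAttentionAttrs a{/*embed_dim=*/512, /*num_heads=*/8,
+//                             /*kdim=*/64, /*vdim=*/64, /*dropout=*/0.1f,
+//                             /*bias=*/true, /*add_bias_kv=*/false,
+//                             /*add_zero_attn=*/false};
+//   nlohmann::json j = a;                        // calls to_json
+//   auto b = j.get<MultiHeadAttentionAttrs>();   // calls from_json
+//   assert(a == b);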
+// If you would like to modify this datatype, instead modify +// lib/op-attrs/include/op-attrs/ops/attention_attrs.struct.toml +/* proj-data +{ + "generated_from": "360324465947562229dc6632a9e9a2f3" +} +*/ + +#include "op-attrs/ops/attention_attrs.dtg.h" + +#include + +namespace FlexFlow { +MultiHeadAttentionAttrs::MultiHeadAttentionAttrs(int const &embed_dim, + int const &num_heads, + int const &kdim, + int const &vdim, + float const &dropout, + bool const &bias, + bool const &add_bias_kv, + bool const &add_zero_attn) + : embed_dim(embed_dim), num_heads(num_heads), kdim(kdim), vdim(vdim), + dropout(dropout), bias(bias), add_bias_kv(add_bias_kv), + add_zero_attn(add_zero_attn) {} +bool MultiHeadAttentionAttrs::operator==( + MultiHeadAttentionAttrs const &other) const { + return std::tie(this->embed_dim, + this->num_heads, + this->kdim, + this->vdim, + this->dropout, + this->bias, + this->add_bias_kv, + this->add_zero_attn) == std::tie(other.embed_dim, + other.num_heads, + other.kdim, + other.vdim, + other.dropout, + other.bias, + other.add_bias_kv, + other.add_zero_attn); +} +bool MultiHeadAttentionAttrs::operator!=( + MultiHeadAttentionAttrs const &other) const { + return std::tie(this->embed_dim, + this->num_heads, + this->kdim, + this->vdim, + this->dropout, + this->bias, + this->add_bias_kv, + this->add_zero_attn) != std::tie(other.embed_dim, + other.num_heads, + other.kdim, + other.vdim, + other.dropout, + other.bias, + other.add_bias_kv, + other.add_zero_attn); +} +bool MultiHeadAttentionAttrs::operator<( + MultiHeadAttentionAttrs const &other) const { + return std::tie(this->embed_dim, + this->num_heads, + this->kdim, + this->vdim, + this->dropout, + this->bias, + this->add_bias_kv, + this->add_zero_attn) < std::tie(other.embed_dim, + other.num_heads, + other.kdim, + other.vdim, + other.dropout, + other.bias, + other.add_bias_kv, + other.add_zero_attn); +} +bool MultiHeadAttentionAttrs::operator>( + MultiHeadAttentionAttrs const &other) const { + return std::tie(this->embed_dim, + this->num_heads, + this->kdim, + this->vdim, + this->dropout, + this->bias, + this->add_bias_kv, + this->add_zero_attn) > std::tie(other.embed_dim, + other.num_heads, + other.kdim, + other.vdim, + other.dropout, + other.bias, + other.add_bias_kv, + other.add_zero_attn); +} +bool MultiHeadAttentionAttrs::operator<=( + MultiHeadAttentionAttrs const &other) const { + return std::tie(this->embed_dim, + this->num_heads, + this->kdim, + this->vdim, + this->dropout, + this->bias, + this->add_bias_kv, + this->add_zero_attn) <= std::tie(other.embed_dim, + other.num_heads, + other.kdim, + other.vdim, + other.dropout, + other.bias, + other.add_bias_kv, + other.add_zero_attn); +} +bool MultiHeadAttentionAttrs::operator>=( + MultiHeadAttentionAttrs const &other) const { + return std::tie(this->embed_dim, + this->num_heads, + this->kdim, + this->vdim, + this->dropout, + this->bias, + this->add_bias_kv, + this->add_zero_attn) >= std::tie(other.embed_dim, + other.num_heads, + other.kdim, + other.vdim, + other.dropout, + other.bias, + other.add_bias_kv, + other.add_zero_attn); +} +} // namespace FlexFlow + +namespace std { +size_t hash::operator()( + FlexFlow::MultiHeadAttentionAttrs const &x) const { + size_t result = 0; + result ^= std::hash{}(x.embed_dim) + 0x9e3779b9 + (result << 6) + + (result >> 2); + result ^= std::hash{}(x.num_heads) + 0x9e3779b9 + (result << 6) + + (result >> 2); + result ^= + std::hash{}(x.kdim) + 0x9e3779b9 + (result << 6) + (result >> 2); + result ^= + std::hash{}(x.vdim) + 0x9e3779b9 + 
(result << 6) + (result >> 2); + result ^= std::hash{}(x.dropout) + 0x9e3779b9 + (result << 6) + + (result >> 2); + result ^= + std::hash{}(x.bias) + 0x9e3779b9 + (result << 6) + (result >> 2); + result ^= std::hash{}(x.add_bias_kv) + 0x9e3779b9 + (result << 6) + + (result >> 2); + result ^= std::hash{}(x.add_zero_attn) + 0x9e3779b9 + (result << 6) + + (result >> 2); + return result; +} +} // namespace std + +namespace nlohmann { +FlexFlow::MultiHeadAttentionAttrs + adl_serializer::from_json( + json const &j) { + return {j.at("embed_dim").template get(), + j.at("num_heads").template get(), + j.at("kdim").template get(), + j.at("vdim").template get(), + j.at("dropout").template get(), + j.at("bias").template get(), + j.at("add_bias_kv").template get(), + j.at("add_zero_attn").template get()}; +} +void adl_serializer::to_json( + json &j, FlexFlow::MultiHeadAttentionAttrs const &v) { + j["__type"] = "MultiHeadAttentionAttrs"; + j["embed_dim"] = v.embed_dim; + j["num_heads"] = v.num_heads; + j["kdim"] = v.kdim; + j["vdim"] = v.vdim; + j["dropout"] = v.dropout; + j["bias"] = v.bias; + j["add_bias_kv"] = v.add_bias_kv; + j["add_zero_attn"] = v.add_zero_attn; +} +} // namespace nlohmann + +namespace rc { +Gen + Arbitrary::arbitrary() { + return gen::construct( + gen::arbitrary(), + gen::arbitrary(), + gen::arbitrary(), + gen::arbitrary(), + gen::arbitrary(), + gen::arbitrary(), + gen::arbitrary(), + gen::arbitrary()); +} +} // namespace rc + +namespace FlexFlow { +std::string format_as(MultiHeadAttentionAttrs const &x) { + std::ostringstream oss; + oss << ""; + return oss.str(); +} +std::ostream &operator<<(std::ostream &s, MultiHeadAttentionAttrs const &x) { + return s << fmt::to_string(x); +} +} // namespace FlexFlow diff --git a/lib/op-attrs/src/op-attrs/ops/batch_matmul.cc b/lib/op-attrs/src/op-attrs/ops/batch_matmul.cc new file mode 100644 index 0000000000..cbda4ea533 --- /dev/null +++ b/lib/op-attrs/src/op-attrs/ops/batch_matmul.cc @@ -0,0 +1,176 @@ +#include "op-attrs/ops/batch_matmul.h" +#include "op-attrs/parallel_tensor_shape.h" + +namespace FlexFlow { + +// bool BatchMatmulAttrs::is_valid( +// ParallelTensorShape const &lhs, +// ParallelTensorShape const &rhs) const { +// if (!lhs.is_valid() || !rhs.is_valid()) { +// return false; +// } +// if (lhs.num_dims() != rhs.num_dims()) { +// return false; +// } +// for (int i = lhs.num_dims() - 1; i >= 2; i--) { +// if (lhs.at(i) != rhs.at(i)) { +// return false; +// } +// } +// if (lhs.at(0) != rhs.at(1)) { +// return false; +// } +// +// return true; +// } + +bool is_valid(BatchMatmulAttrs const &, + ParallelTensorShape const &, + ParallelTensorShape const &) { + NOT_IMPLEMENTED(); +} + +tl::expected + get_output_shape(BatchMatmulAttrs const &attrs, + TensorShape const &input_lhs, + TensorShape const &input_rhs) { + // If input_lhs is a (b×n×m) tensor, + // input_rhs is a (b×m×p) tensor, + // out will be a (b×n×p) tensor. 
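+  // Example (illustrative numbers, not from the patch): an lhs of shape
+  // (b=4, n=2, m=3) with an rhs of shape (b=4, m=3, p=5) yields an output of
+  // shape (4, 2, 5); any mismatch in b or m is reported through the
+  // tl::expected error path below.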
+  // https://pytorch.org/docs/stable/generated/torch.bmm.html
+
+  if (num_dims(input_lhs) != 3) {
+    return tl::unexpected(
+        fmt::format("LHS input has incorrect number of dims: {} != {}",
+                    num_dims(input_lhs),
+                    3));
+  }
+  if (num_dims(input_rhs) != 3) {
+    return tl::unexpected(
+        fmt::format("RHS input has incorrect number of dims: {} != {}",
+                    num_dims(input_rhs),
+                    3));
+  }
+  if (input_lhs.data_type != input_rhs.data_type) {
+    return tl::unexpected(fmt::format("Input datatypes do not match: {} != {}",
+                                      input_lhs.data_type,
+                                      input_rhs.data_type));
+  }
+
+  size_t lhs_b = dim_at_idx(input_lhs, ff_dim_t{0});
+  size_t n = dim_at_idx(input_lhs, ff_dim_t{1});
+  size_t lhs_m = dim_at_idx(input_lhs, ff_dim_t{2});
+
+  size_t rhs_b = dim_at_idx(input_rhs, ff_dim_t{0});
+  size_t rhs_m = dim_at_idx(input_rhs, ff_dim_t{1});
+  size_t p = dim_at_idx(input_rhs, ff_dim_t{2});
+
+  if (lhs_b != rhs_b) {
+    return tl::unexpected(
+        fmt::format("LHS b dim ({}) != RHS b dim ({})", lhs_b, rhs_b));
+  }
+  if (lhs_m != rhs_m) {
+    return tl::unexpected(
+        fmt::format("LHS m dim ({}) != RHS m dim ({})", lhs_m, rhs_m));
+  }
+
+  return TensorShape{
+      TensorDims{
+          FFOrdered<size_t>{
+              lhs_b,
+              n,
+              p,
+          },
+      },
+      input_lhs.data_type,
+  };
+}
+
+tl::expected<ParallelTensorShape, std::string>
+    get_output_shape(BatchMatmulAttrs const &attrs,
+                     ParallelTensorShape const &input_lhs,
+                     ParallelTensorShape const &input_rhs) {
+  if (num_shard_dims(input_lhs) != 3) {
+    return tl::unexpected(
+        fmt::format("LHS input has incorrect number of shard dims: {} != {}",
+                    num_shard_dims(input_lhs),
+                    3));
+  }
+  if (num_shard_dims(input_rhs) != 3) {
+    return tl::unexpected(
+        fmt::format("RHS input has incorrect number of shard dims: {} != {}",
+                    num_shard_dims(input_rhs),
+                    3));
+  }
+  if (input_lhs.data_type != input_rhs.data_type) {
+    return tl::unexpected(fmt::format("Input datatypes do not match: {} != {}",
+                                      input_lhs.data_type,
+                                      input_rhs.data_type));
+  }
+
+  assert(get_total_parallel_degree(input_lhs) ==
+         get_total_parallel_degree(input_rhs));
+
+  ShardParallelDim lhs_b = shard_dim_at_idx(input_lhs, ff_dim_t{0});
+  ShardParallelDim n = shard_dim_at_idx(input_lhs, ff_dim_t{1});
+  ShardParallelDim lhs_m = shard_dim_at_idx(input_lhs, ff_dim_t{2});
+
+  ShardParallelDim rhs_b = shard_dim_at_idx(input_rhs, ff_dim_t{0});
+  ShardParallelDim rhs_m = shard_dim_at_idx(input_rhs, ff_dim_t{1});
+  ShardParallelDim p = shard_dim_at_idx(input_rhs, ff_dim_t{2});
+
+  if (lhs_b != rhs_b) {
+    return tl::unexpected(
+        fmt::format("LHS b dim ({}) != RHS b dim ({})", lhs_b, rhs_b));
+  }
+
+  if (lhs_m != rhs_m) {
+    return tl::unexpected(
+        fmt::format("LHS m dim ({}) != RHS m dim ({})", lhs_m, rhs_m));
+  }
+
+  if (get_discard_copy_degree(input_lhs) !=
+      get_sum_degree(input_rhs) * p.degree) {
+    return tl::unexpected(fmt::format("Unexpected number of replicas in LHS: "
+                                      "lhs.= ({}) != rhs.+ ({}) * rhs.p ({})",
+                                      get_discard_copy_degree(input_lhs),
+                                      get_sum_degree(input_rhs),
+                                      p.degree));
+  }
+
+  if (get_discard_copy_degree(input_rhs) !=
+      get_sum_degree(input_lhs) * n.degree) {
+    return tl::unexpected(fmt::format("Unexpected number of replicas in RHS: "
+                                      "rhs.= ({}) != lhs.+ ({}) * lhs.n ({})",
+                                      get_discard_copy_degree(input_rhs),
+                                      get_sum_degree(input_lhs),
+                                      n.degree));
+  }
+
+  ShardParallelDim output_b = lhs_b;
+  ShardParallelDim output_n = n;
+  ShardParallelDim output_p = p;
+
+  int output_discard_copy_degree = 1;
+  int output_sum_degree = get_total_parallel_degree(input_lhs) /
+                          (output_b.degree * output_n.degree * output_p.degree);
+
+  ParallelTensorShape result =
ParallelTensorShape{ + ParallelTensorDims{ + FFOrdered{ + output_b, + output_n, + output_p, + }, + ReplicaParallelDimSet{ + output_sum_degree, + output_discard_copy_degree, + }, + }, + input_lhs.data_type, + }; + + return result; +} + +} // namespace FlexFlow diff --git a/lib/op-attrs/src/op-attrs/ops/batch_matmul.dtg.cc b/lib/op-attrs/src/op-attrs/ops/batch_matmul.dtg.cc new file mode 100644 index 0000000000..f178d40696 --- /dev/null +++ b/lib/op-attrs/src/op-attrs/ops/batch_matmul.dtg.cc @@ -0,0 +1,90 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/op-attrs/include/op-attrs/ops/batch_matmul.struct.toml +/* proj-data +{ + "generated_from": "c3bbf4c76982ef27107b74e1e6e5d360" +} +*/ + +#include "op-attrs/ops/batch_matmul.dtg.h" + +#include + +namespace FlexFlow { +BatchMatmulAttrs::BatchMatmulAttrs(int const &a_seq_length_dim, + int const &b_seq_length_dim) + : a_seq_length_dim(a_seq_length_dim), b_seq_length_dim(b_seq_length_dim) {} +bool BatchMatmulAttrs::operator==(BatchMatmulAttrs const &other) const { + return std::tie(this->a_seq_length_dim, this->b_seq_length_dim) == + std::tie(other.a_seq_length_dim, other.b_seq_length_dim); +} +bool BatchMatmulAttrs::operator!=(BatchMatmulAttrs const &other) const { + return std::tie(this->a_seq_length_dim, this->b_seq_length_dim) != + std::tie(other.a_seq_length_dim, other.b_seq_length_dim); +} +bool BatchMatmulAttrs::operator<(BatchMatmulAttrs const &other) const { + return std::tie(this->a_seq_length_dim, this->b_seq_length_dim) < + std::tie(other.a_seq_length_dim, other.b_seq_length_dim); +} +bool BatchMatmulAttrs::operator>(BatchMatmulAttrs const &other) const { + return std::tie(this->a_seq_length_dim, this->b_seq_length_dim) > + std::tie(other.a_seq_length_dim, other.b_seq_length_dim); +} +bool BatchMatmulAttrs::operator<=(BatchMatmulAttrs const &other) const { + return std::tie(this->a_seq_length_dim, this->b_seq_length_dim) <= + std::tie(other.a_seq_length_dim, other.b_seq_length_dim); +} +bool BatchMatmulAttrs::operator>=(BatchMatmulAttrs const &other) const { + return std::tie(this->a_seq_length_dim, this->b_seq_length_dim) >= + std::tie(other.a_seq_length_dim, other.b_seq_length_dim); +} +} // namespace FlexFlow + +namespace std { +size_t hash::operator()( + FlexFlow::BatchMatmulAttrs const &x) const { + size_t result = 0; + result ^= std::hash{}(x.a_seq_length_dim) + 0x9e3779b9 + (result << 6) + + (result >> 2); + result ^= std::hash{}(x.b_seq_length_dim) + 0x9e3779b9 + (result << 6) + + (result >> 2); + return result; +} +} // namespace std + +namespace nlohmann { +FlexFlow::BatchMatmulAttrs + adl_serializer::from_json(json const &j) { + return {j.at("a_seq_length_dim").template get(), + j.at("b_seq_length_dim").template get()}; +} +void adl_serializer::to_json( + json &j, FlexFlow::BatchMatmulAttrs const &v) { + j["__type"] = "BatchMatmulAttrs"; + j["a_seq_length_dim"] = v.a_seq_length_dim; + j["b_seq_length_dim"] = v.b_seq_length_dim; +} +} // namespace nlohmann + +namespace rc { +Gen + Arbitrary::arbitrary() { + return gen::construct(gen::arbitrary(), + gen::arbitrary()); +} +} // namespace rc + +namespace FlexFlow { +std::string format_as(BatchMatmulAttrs const &x) { + std::ostringstream oss; + oss << ""; + return oss.str(); +} +std::ostream &operator<<(std::ostream &s, BatchMatmulAttrs const &x) { + return s << fmt::to_string(x); +} +} // namespace FlexFlow diff --git a/lib/op-attrs/src/op-attrs/ops/batch_norm.cc 
b/lib/op-attrs/src/op-attrs/ops/batch_norm.cc new file mode 100644 index 0000000000..7be51efa22 --- /dev/null +++ b/lib/op-attrs/src/op-attrs/ops/batch_norm.cc @@ -0,0 +1,10 @@ +#include "op-attrs/ops/batch_norm.h" + +namespace FlexFlow { + +ParallelTensorShape get_output_shape(BatchNormAttrs const &, + ParallelTensorShape const &) { + NOT_IMPLEMENTED(); +} + +} // namespace FlexFlow diff --git a/lib/op-attrs/src/op-attrs/ops/batch_norm_attrs.dtg.cc b/lib/op-attrs/src/op-attrs/ops/batch_norm_attrs.dtg.cc new file mode 100644 index 0000000000..cb8dcadae1 --- /dev/null +++ b/lib/op-attrs/src/op-attrs/ops/batch_norm_attrs.dtg.cc @@ -0,0 +1,75 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/op-attrs/include/op-attrs/ops/batch_norm_attrs.struct.toml +/* proj-data +{ + "generated_from": "f8e0219d8a3e008a73c38cf84d25f66e" +} +*/ + +#include "op-attrs/ops/batch_norm_attrs.dtg.h" + +#include + +namespace FlexFlow { +BatchNormAttrs::BatchNormAttrs(bool const &relu) : relu(relu) {} +bool BatchNormAttrs::operator==(BatchNormAttrs const &other) const { + return std::tie(this->relu) == std::tie(other.relu); +} +bool BatchNormAttrs::operator!=(BatchNormAttrs const &other) const { + return std::tie(this->relu) != std::tie(other.relu); +} +bool BatchNormAttrs::operator<(BatchNormAttrs const &other) const { + return std::tie(this->relu) < std::tie(other.relu); +} +bool BatchNormAttrs::operator>(BatchNormAttrs const &other) const { + return std::tie(this->relu) > std::tie(other.relu); +} +bool BatchNormAttrs::operator<=(BatchNormAttrs const &other) const { + return std::tie(this->relu) <= std::tie(other.relu); +} +bool BatchNormAttrs::operator>=(BatchNormAttrs const &other) const { + return std::tie(this->relu) >= std::tie(other.relu); +} +} // namespace FlexFlow + +namespace std { +size_t hash::operator()( + FlexFlow::BatchNormAttrs const &x) const { + size_t result = 0; + result ^= + std::hash{}(x.relu) + 0x9e3779b9 + (result << 6) + (result >> 2); + return result; +} +} // namespace std + +namespace nlohmann { +FlexFlow::BatchNormAttrs + adl_serializer::from_json(json const &j) { + return {j.at("relu").template get()}; +} +void adl_serializer::to_json( + json &j, FlexFlow::BatchNormAttrs const &v) { + j["__type"] = "BatchNormAttrs"; + j["relu"] = v.relu; +} +} // namespace nlohmann + +namespace rc { +Gen Arbitrary::arbitrary() { + return gen::construct(gen::arbitrary()); +} +} // namespace rc + +namespace FlexFlow { +std::string format_as(BatchNormAttrs const &x) { + std::ostringstream oss; + oss << ""; + return oss.str(); +} +std::ostream &operator<<(std::ostream &s, BatchNormAttrs const &x) { + return s << fmt::to_string(x); +} +} // namespace FlexFlow diff --git a/lib/op-attrs/src/op-attrs/ops/broadcast.dtg.cc b/lib/op-attrs/src/op-attrs/ops/broadcast.dtg.cc new file mode 100644 index 0000000000..ec08bd6a1d --- /dev/null +++ b/lib/op-attrs/src/op-attrs/ops/broadcast.dtg.cc @@ -0,0 +1,81 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! 
+// If you would like to modify this datatype, instead modify +// lib/op-attrs/include/op-attrs/ops/broadcast.struct.toml +/* proj-data +{ + "generated_from": "12715c970e8416eacbd0750f338478e5" +} +*/ + +#include "op-attrs/ops/broadcast.dtg.h" + +#include "utils/stack_vector.h" +#include + +namespace FlexFlow { +BroadcastAttrs::BroadcastAttrs( + ::FlexFlow::stack_vector const &target_dims) + : target_dims(target_dims) {} +bool BroadcastAttrs::operator==(BroadcastAttrs const &other) const { + return std::tie(this->target_dims) == std::tie(other.target_dims); +} +bool BroadcastAttrs::operator!=(BroadcastAttrs const &other) const { + return std::tie(this->target_dims) != std::tie(other.target_dims); +} +bool BroadcastAttrs::operator<(BroadcastAttrs const &other) const { + return std::tie(this->target_dims) < std::tie(other.target_dims); +} +bool BroadcastAttrs::operator>(BroadcastAttrs const &other) const { + return std::tie(this->target_dims) > std::tie(other.target_dims); +} +bool BroadcastAttrs::operator<=(BroadcastAttrs const &other) const { + return std::tie(this->target_dims) <= std::tie(other.target_dims); +} +bool BroadcastAttrs::operator>=(BroadcastAttrs const &other) const { + return std::tie(this->target_dims) >= std::tie(other.target_dims); +} +} // namespace FlexFlow + +namespace std { +size_t hash::operator()( + FlexFlow::BroadcastAttrs const &x) const { + size_t result = 0; + result ^= std::hash<::FlexFlow::stack_vector>{}( + x.target_dims) + + 0x9e3779b9 + (result << 6) + (result >> 2); + return result; +} +} // namespace std + +namespace nlohmann { +FlexFlow::BroadcastAttrs + adl_serializer::from_json(json const &j) { + return {j.at("target_dims") + .template get<::FlexFlow::stack_vector>()}; +} +void adl_serializer::to_json( + json &j, FlexFlow::BroadcastAttrs const &v) { + j["__type"] = "BroadcastAttrs"; + j["target_dims"] = v.target_dims; +} +} // namespace nlohmann + +namespace rc { +Gen Arbitrary::arbitrary() { + return gen::construct( + gen::arbitrary<::FlexFlow::stack_vector>()); +} +} // namespace rc + +namespace FlexFlow { +std::string format_as(BroadcastAttrs const &x) { + std::ostringstream oss; + oss << ""; + return oss.str(); +} +std::ostream &operator<<(std::ostream &s, BroadcastAttrs const &x) { + return s << fmt::to_string(x); +} +} // namespace FlexFlow diff --git a/lib/op-attrs/src/cast.cc b/lib/op-attrs/src/op-attrs/ops/cast.cc similarity index 100% rename from lib/op-attrs/src/cast.cc rename to lib/op-attrs/src/op-attrs/ops/cast.cc diff --git a/lib/op-attrs/src/op-attrs/ops/cast_attrs.dtg.cc b/lib/op-attrs/src/op-attrs/ops/cast_attrs.dtg.cc new file mode 100644 index 0000000000..28367f3449 --- /dev/null +++ b/lib/op-attrs/src/op-attrs/ops/cast_attrs.dtg.cc @@ -0,0 +1,76 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! 
+// If you would like to modify this datatype, instead modify +// lib/op-attrs/include/op-attrs/ops/cast_attrs.struct.toml +/* proj-data +{ + "generated_from": "c171c87db89b9ec9ea7d52a50c153054" +} +*/ + +#include "op-attrs/ops/cast_attrs.dtg.h" + +#include "op-attrs/datatype.h" +#include + +namespace FlexFlow { +CastAttrs::CastAttrs(DataType const &dtype) : dtype(dtype) {} +bool CastAttrs::operator==(CastAttrs const &other) const { + return std::tie(this->dtype) == std::tie(other.dtype); +} +bool CastAttrs::operator!=(CastAttrs const &other) const { + return std::tie(this->dtype) != std::tie(other.dtype); +} +bool CastAttrs::operator<(CastAttrs const &other) const { + return std::tie(this->dtype) < std::tie(other.dtype); +} +bool CastAttrs::operator>(CastAttrs const &other) const { + return std::tie(this->dtype) > std::tie(other.dtype); +} +bool CastAttrs::operator<=(CastAttrs const &other) const { + return std::tie(this->dtype) <= std::tie(other.dtype); +} +bool CastAttrs::operator>=(CastAttrs const &other) const { + return std::tie(this->dtype) >= std::tie(other.dtype); +} +} // namespace FlexFlow + +namespace std { +size_t + hash::operator()(FlexFlow::CastAttrs const &x) const { + size_t result = 0; + result ^= std::hash{}(x.dtype) + 0x9e3779b9 + (result << 6) + + (result >> 2); + return result; +} +} // namespace std + +namespace nlohmann { +FlexFlow::CastAttrs + adl_serializer::from_json(json const &j) { + return {j.at("dtype").template get()}; +} +void adl_serializer::to_json( + json &j, FlexFlow::CastAttrs const &v) { + j["__type"] = "CastAttrs"; + j["dtype"] = v.dtype; +} +} // namespace nlohmann + +namespace rc { +Gen Arbitrary::arbitrary() { + return gen::construct(gen::arbitrary()); +} +} // namespace rc + +namespace FlexFlow { +std::string format_as(CastAttrs const &x) { + std::ostringstream oss; + oss << ""; + return oss.str(); +} +std::ostream &operator<<(std::ostream &s, CastAttrs const &x) { + return s << fmt::to_string(x); +} +} // namespace FlexFlow diff --git a/lib/op-attrs/src/op-attrs/ops/combine.cc b/lib/op-attrs/src/op-attrs/ops/combine.cc new file mode 100644 index 0000000000..e41b78c5af --- /dev/null +++ b/lib/op-attrs/src/op-attrs/ops/combine.cc @@ -0,0 +1,37 @@ +#include "op-attrs/ops/combine.h" +#include "op-attrs/parallel_tensor_shape.h" + +namespace FlexFlow { + +tl::expected + get_output_shape(CombineAttrs const &attrs, + ParallelTensorShape const &input) { + ShardParallelDim input_dim = ({ + std::optional result = + try_get_shard_dim_at_idx(input, attrs.combine_dim); + if (!result.has_value()) { + return tl::unexpected(fmt::format( + "Failed to get shard dim at index {} in parallel tensor shape {}", + attrs.combine_dim, + input)); + } + + result.value(); + }); + + if (input_dim.degree % attrs.combine_degree != 0) { + return tl::unexpected( + fmt::format("Combine received tensor containing parallel dim {} with " + "degree {}, which is not divisible by combine degree {}", + attrs.combine_dim, + input_dim.degree, + attrs.combine_degree)); + } + + ParallelTensorShape output = input; + shard_dim_at_idx(output, attrs.combine_dim).degree /= attrs.combine_degree; + + return output; +} + +} // namespace FlexFlow diff --git a/lib/op-attrs/src/op-attrs/ops/combine_attrs.dtg.cc b/lib/op-attrs/src/op-attrs/ops/combine_attrs.dtg.cc new file mode 100644 index 0000000000..516d3b0318 --- /dev/null +++ b/lib/op-attrs/src/op-attrs/ops/combine_attrs.dtg.cc @@ -0,0 +1,91 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! 
+// If you would like to modify this datatype, instead modify +// lib/op-attrs/include/op-attrs/ops/combine_attrs.struct.toml +/* proj-data +{ + "generated_from": "58fc5a388fd1a325ef4142094607e39a" +} +*/ + +#include "op-attrs/ops/combine_attrs.dtg.h" + +#include "op-attrs/ff_dim.dtg.h" +#include "op-attrs/ff_dim.h" +#include + +namespace FlexFlow { +CombineAttrs::CombineAttrs(::FlexFlow::ff_dim_t const &combine_dim, + int const &combine_degree) + : combine_dim(combine_dim), combine_degree(combine_degree) {} +bool CombineAttrs::operator==(CombineAttrs const &other) const { + return std::tie(this->combine_dim, this->combine_degree) == + std::tie(other.combine_dim, other.combine_degree); +} +bool CombineAttrs::operator!=(CombineAttrs const &other) const { + return std::tie(this->combine_dim, this->combine_degree) != + std::tie(other.combine_dim, other.combine_degree); +} +bool CombineAttrs::operator<(CombineAttrs const &other) const { + return std::tie(this->combine_dim, this->combine_degree) < + std::tie(other.combine_dim, other.combine_degree); +} +bool CombineAttrs::operator>(CombineAttrs const &other) const { + return std::tie(this->combine_dim, this->combine_degree) > + std::tie(other.combine_dim, other.combine_degree); +} +bool CombineAttrs::operator<=(CombineAttrs const &other) const { + return std::tie(this->combine_dim, this->combine_degree) <= + std::tie(other.combine_dim, other.combine_degree); +} +bool CombineAttrs::operator>=(CombineAttrs const &other) const { + return std::tie(this->combine_dim, this->combine_degree) >= + std::tie(other.combine_dim, other.combine_degree); +} +} // namespace FlexFlow + +namespace std { +size_t hash::operator()( + FlexFlow::CombineAttrs const &x) const { + size_t result = 0; + result ^= std::hash<::FlexFlow::ff_dim_t>{}(x.combine_dim) + 0x9e3779b9 + + (result << 6) + (result >> 2); + result ^= std::hash{}(x.combine_degree) + 0x9e3779b9 + (result << 6) + + (result >> 2); + return result; +} +} // namespace std + +namespace nlohmann { +FlexFlow::CombineAttrs + adl_serializer::from_json(json const &j) { + return {j.at("combine_dim").template get<::FlexFlow::ff_dim_t>(), + j.at("combine_degree").template get()}; +} +void adl_serializer::to_json( + json &j, FlexFlow::CombineAttrs const &v) { + j["__type"] = "CombineAttrs"; + j["combine_dim"] = v.combine_dim; + j["combine_degree"] = v.combine_degree; +} +} // namespace nlohmann + +namespace rc { +Gen Arbitrary::arbitrary() { + return gen::construct( + gen::arbitrary<::FlexFlow::ff_dim_t>(), gen::arbitrary()); +} +} // namespace rc + +namespace FlexFlow { +std::string format_as(CombineAttrs const &x) { + std::ostringstream oss; + oss << ""; + return oss.str(); +} +std::ostream &operator<<(std::ostream &s, CombineAttrs const &x) { + return s << fmt::to_string(x); +} +} // namespace FlexFlow diff --git a/lib/op-attrs/src/concat.cc b/lib/op-attrs/src/op-attrs/ops/concat.cc similarity index 100% rename from lib/op-attrs/src/concat.cc rename to lib/op-attrs/src/op-attrs/ops/concat.cc diff --git a/lib/op-attrs/src/op-attrs/ops/concat_attrs.dtg.cc b/lib/op-attrs/src/op-attrs/ops/concat_attrs.dtg.cc new file mode 100644 index 0000000000..20db25d485 --- /dev/null +++ b/lib/op-attrs/src/op-attrs/ops/concat_attrs.dtg.cc @@ -0,0 +1,91 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! 
+// If you would like to modify this datatype, instead modify +// lib/op-attrs/include/op-attrs/ops/concat_attrs.struct.toml +/* proj-data +{ + "generated_from": "68e0520b143e0579140a2f2cdd390759" +} +*/ + +#include "op-attrs/ops/concat_attrs.dtg.h" + +#include "op-attrs/ff_dim.dtg.h" +#include "op-attrs/ff_dim.h" +#include + +namespace FlexFlow { +ConcatAttrs::ConcatAttrs(::FlexFlow::ff_dim_t const &axis, + int const &num_inputs) + : axis(axis), num_inputs(num_inputs) {} +bool ConcatAttrs::operator==(ConcatAttrs const &other) const { + return std::tie(this->axis, this->num_inputs) == + std::tie(other.axis, other.num_inputs); +} +bool ConcatAttrs::operator!=(ConcatAttrs const &other) const { + return std::tie(this->axis, this->num_inputs) != + std::tie(other.axis, other.num_inputs); +} +bool ConcatAttrs::operator<(ConcatAttrs const &other) const { + return std::tie(this->axis, this->num_inputs) < + std::tie(other.axis, other.num_inputs); +} +bool ConcatAttrs::operator>(ConcatAttrs const &other) const { + return std::tie(this->axis, this->num_inputs) > + std::tie(other.axis, other.num_inputs); +} +bool ConcatAttrs::operator<=(ConcatAttrs const &other) const { + return std::tie(this->axis, this->num_inputs) <= + std::tie(other.axis, other.num_inputs); +} +bool ConcatAttrs::operator>=(ConcatAttrs const &other) const { + return std::tie(this->axis, this->num_inputs) >= + std::tie(other.axis, other.num_inputs); +} +} // namespace FlexFlow + +namespace std { +size_t hash::operator()( + FlexFlow::ConcatAttrs const &x) const { + size_t result = 0; + result ^= std::hash<::FlexFlow::ff_dim_t>{}(x.axis) + 0x9e3779b9 + + (result << 6) + (result >> 2); + result ^= std::hash{}(x.num_inputs) + 0x9e3779b9 + (result << 6) + + (result >> 2); + return result; +} +} // namespace std + +namespace nlohmann { +FlexFlow::ConcatAttrs + adl_serializer::from_json(json const &j) { + return {j.at("axis").template get<::FlexFlow::ff_dim_t>(), + j.at("num_inputs").template get()}; +} +void adl_serializer::to_json( + json &j, FlexFlow::ConcatAttrs const &v) { + j["__type"] = "ConcatAttrs"; + j["axis"] = v.axis; + j["num_inputs"] = v.num_inputs; +} +} // namespace nlohmann + +namespace rc { +Gen Arbitrary::arbitrary() { + return gen::construct( + gen::arbitrary<::FlexFlow::ff_dim_t>(), gen::arbitrary()); +} +} // namespace rc + +namespace FlexFlow { +std::string format_as(ConcatAttrs const &x) { + std::ostringstream oss; + oss << ""; + return oss.str(); +} +std::ostream &operator<<(std::ostream &s, ConcatAttrs const &x) { + return s << fmt::to_string(x); +} +} // namespace FlexFlow diff --git a/lib/op-attrs/src/op-attrs/ops/conv_2d.cc b/lib/op-attrs/src/op-attrs/ops/conv_2d.cc new file mode 100644 index 0000000000..c9ec467af4 --- /dev/null +++ b/lib/op-attrs/src/op-attrs/ops/conv_2d.cc @@ -0,0 +1,176 @@ +#include "op-attrs/ops/conv_2d.h" +#include "op-attrs/ops/conv_2d/conv_2d_input_shape.h" +#include "op-attrs/ops/conv_2d/conv_2d_parallel_input_shape.h" +#include "utils/integer_conversions.h" + +namespace FlexFlow { + +TensorShape get_kernel_shape(Conv2DAttrs const &attrs, + TensorShape const &raw_input_shape) { + assert(attrs.groups == 1); // TODO(@lockshaw): currently not supported + Conv2DInputShape input = parse_input_shape(raw_input_shape); + + return TensorShape{ + TensorDims{FFOrdered{ + size_t_from_int(attrs.out_channels), + input.num_channels, + size_t_from_int(attrs.kernel_h), + size_t_from_int(attrs.kernel_w), + }}, + input.datatype, + }; +} + +TensorShape get_bias_shape(Conv2DAttrs const &attrs, + TensorShape 
const &raw_input_shape) { + assert(attrs.groups == 1); // TODO(@lockshaw): currently not supported + Conv2DInputShape input = parse_input_shape(raw_input_shape); + + return TensorShape{ + TensorDims{ + FFOrdered{size_t_from_int(attrs.out_channels)}, + }, + input.datatype, + }; +} + +TensorShape get_output_shape(Conv2DAttrs const &attrs, + TensorShape const &raw_input_shape) { + assert(attrs.groups == 1); // TODO(@lockshaw): currently not supported + Conv2DInputShape input = parse_input_shape(raw_input_shape); + + size_t out_height = + (input.height - (2 * attrs.padding_h) - (attrs.kernel_h - 1)) / + attrs.stride_h; + size_t out_width = + (input.width - (2 * attrs.padding_w) - (attrs.kernel_w - 1)) / + attrs.stride_w; + + assert(attrs.out_channels > 0); + + return TensorShape{TensorDims{FFOrdered{ + input.num_samples, + size_t_from_int(attrs.out_channels), + out_height, + out_width, + }}, + input.datatype}; +} + +ParallelTensorShape + get_kernel_shape(Conv2DAttrs const &attrs, + ParallelTensorShape const &raw_input_shape) { + assert(attrs.groups == 1); // TODO(@lockshaw): currently not supported + Conv2DParallelInputShape input = parse_parallel_input_shape(raw_input_shape); + + ShardParallelDim output_channels_dim = {size_t_from_int(attrs.out_channels), + input.discard_copy_reduction_degree}; + ShardParallelDim input_channels_dim = { + size_t_from_int(input.channel_dim.size), input.channel_dim.degree}; + ShardParallelDim kernel_height_dim = {size_t_from_int(attrs.kernel_h), 1}; + ShardParallelDim kernel_width_dim = {size_t_from_int(attrs.kernel_w), 1}; + + int sum_degree = 1; + int discard_copy_degree = input.height_dim.degree * input.width_dim.degree * + input.sum_reduction_degree; + + ParallelTensorShape result = ParallelTensorShape{ + ParallelTensorDims{ + FFOrdered{ + output_channels_dim, + input_channels_dim, + kernel_height_dim, + kernel_width_dim, + }, + ReplicaParallelDimSet{ + sum_degree, + discard_copy_degree, + }, + }, + input.datatype, + }; + + assert(total_parallel_degree(result.dims) == + total_parallel_degree(raw_input_shape.dims)); + + return result; +} + +ParallelTensorShape get_bias_shape(Conv2DAttrs const &attrs, + ParallelTensorShape const &raw_input_shape) { + assert(attrs.groups == 1); // TODO(@lockshaw): currently not supported + Conv2DParallelInputShape input = parse_parallel_input_shape(raw_input_shape); + + ShardParallelDim output_channels_dim = {size_t_from_int(attrs.out_channels), + input.discard_copy_reduction_degree}; + + int sum_degree = 1; + int discard_copy_degree = input.height_dim.degree * input.width_dim.degree * + input.sum_reduction_degree * + input.channel_dim.degree; + + ParallelTensorShape result = ParallelTensorShape{ + ParallelTensorDims{ + FFOrdered{ + output_channels_dim, + }, + ReplicaParallelDimSet{ + sum_degree, + discard_copy_degree, + }, + }, + input.datatype, + }; + + assert(total_parallel_degree(result.dims) == + total_parallel_degree(raw_input_shape.dims)); + + return result; +} + +ParallelTensorShape + get_output_shape(Conv2DAttrs const &attrs, + ParallelTensorShape const &raw_input_shape) { + assert(attrs.groups == 1); // TODO(@lockshaw): currently not supported + Conv2DParallelInputShape input = parse_parallel_input_shape(raw_input_shape); + + TensorShape unpar_output_shape = + get_output_shape(attrs, get_reduced_shape(raw_input_shape)); + + size_t num_samples = dim_at_idx(unpar_output_shape, ff_dim_t{0}); + size_t num_channels = dim_at_idx(unpar_output_shape, ff_dim_t{1}); + size_t height = dim_at_idx(unpar_output_shape, ff_dim_t{2}); 
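+  // (Shard degrees for these unparallelized dims are reattached below; the
+  // channel-parallel partial products of the input surface as sum
+  // parallelism in the output, hence sum_degree folds in channel_dim.degree.)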
+ size_t width = dim_at_idx(unpar_output_shape, ff_dim_t{3}); + + ShardParallelDim sample_dim = {num_samples, input.sample_dim.degree}; + ShardParallelDim channel_dim = {num_channels, + input.discard_copy_reduction_degree}; + ShardParallelDim height_dim = {height, input.height_dim.degree}; + ShardParallelDim width_dim = {width, input.width_dim.degree}; + + int sum_degree = input.channel_dim.degree * input.sum_reduction_degree; + int discard_copy_degree = 1; + + ParallelTensorShape result = ParallelTensorShape{ + ParallelTensorDims{ + FFOrdered{ + sample_dim, + channel_dim, + height_dim, + width_dim, + }, + ReplicaParallelDimSet{ + sum_degree, + discard_copy_degree, + }, + }, + input.datatype, + }; + + assert(total_parallel_degree(result.dims) == + total_parallel_degree(raw_input_shape.dims)); + + return result; +} + +} // namespace FlexFlow diff --git a/lib/op-attrs/src/op-attrs/ops/conv_2d/conv_2d_input_shape.cc b/lib/op-attrs/src/op-attrs/ops/conv_2d/conv_2d_input_shape.cc new file mode 100644 index 0000000000..a8a3b10bdf --- /dev/null +++ b/lib/op-attrs/src/op-attrs/ops/conv_2d/conv_2d_input_shape.cc @@ -0,0 +1,23 @@ +#include "op-attrs/ops/conv_2d/conv_2d_input_shape.h" +#include "op-attrs/tensor_shape.h" + +namespace FlexFlow { + +Conv2DInputShape parse_input_shape(TensorShape const &input) { + assert(num_dims(input) == 4); + + size_t num_samples = dim_at_idx(input, ff_dim_t{0}); + size_t in_channels = dim_at_idx(input, ff_dim_t{1}); + size_t in_height = dim_at_idx(input, ff_dim_t{2}); + size_t in_width = dim_at_idx(input, ff_dim_t{3}); + + return Conv2DInputShape{ + num_samples, + in_channels, + in_height, + in_width, + input.data_type, + }; +} + +} // namespace FlexFlow diff --git a/lib/op-attrs/src/op-attrs/ops/conv_2d/conv_2d_input_shape.dtg.cc b/lib/op-attrs/src/op-attrs/ops/conv_2d/conv_2d_input_shape.dtg.cc new file mode 100644 index 0000000000..74df30e2d7 --- /dev/null +++ b/lib/op-attrs/src/op-attrs/ops/conv_2d/conv_2d_input_shape.dtg.cc @@ -0,0 +1,157 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! 
+// If you would like to modify this datatype, instead modify +// lib/op-attrs/include/op-attrs/ops/conv_2d/conv_2d_input_shape.struct.toml +/* proj-data +{ + "generated_from": "51911f58c134d55b2d0245444acbae53" +} +*/ + +#include "op-attrs/ops/conv_2d/conv_2d_input_shape.dtg.h" + +#include "op-attrs/datatype.dtg.h" +#include +#include + +namespace FlexFlow { +Conv2DInputShape::Conv2DInputShape(size_t const &num_samples, + size_t const &num_channels, + size_t const &height, + size_t const &width, + ::FlexFlow::DataType const &datatype) + : num_samples(num_samples), num_channels(num_channels), height(height), + width(width), datatype(datatype) {} +bool Conv2DInputShape::operator==(Conv2DInputShape const &other) const { + return std::tie(this->num_samples, + this->num_channels, + this->height, + this->width, + this->datatype) == std::tie(other.num_samples, + other.num_channels, + other.height, + other.width, + other.datatype); +} +bool Conv2DInputShape::operator!=(Conv2DInputShape const &other) const { + return std::tie(this->num_samples, + this->num_channels, + this->height, + this->width, + this->datatype) != std::tie(other.num_samples, + other.num_channels, + other.height, + other.width, + other.datatype); +} +bool Conv2DInputShape::operator<(Conv2DInputShape const &other) const { + return std::tie(this->num_samples, + this->num_channels, + this->height, + this->width, + this->datatype) < std::tie(other.num_samples, + other.num_channels, + other.height, + other.width, + other.datatype); +} +bool Conv2DInputShape::operator>(Conv2DInputShape const &other) const { + return std::tie(this->num_samples, + this->num_channels, + this->height, + this->width, + this->datatype) > std::tie(other.num_samples, + other.num_channels, + other.height, + other.width, + other.datatype); +} +bool Conv2DInputShape::operator<=(Conv2DInputShape const &other) const { + return std::tie(this->num_samples, + this->num_channels, + this->height, + this->width, + this->datatype) <= std::tie(other.num_samples, + other.num_channels, + other.height, + other.width, + other.datatype); +} +bool Conv2DInputShape::operator>=(Conv2DInputShape const &other) const { + return std::tie(this->num_samples, + this->num_channels, + this->height, + this->width, + this->datatype) >= std::tie(other.num_samples, + other.num_channels, + other.height, + other.width, + other.datatype); +} +} // namespace FlexFlow + +namespace std { +size_t hash::operator()( + FlexFlow::Conv2DInputShape const &x) const { + size_t result = 0; + result ^= std::hash{}(x.num_samples) + 0x9e3779b9 + (result << 6) + + (result >> 2); + result ^= std::hash{}(x.num_channels) + 0x9e3779b9 + (result << 6) + + (result >> 2); + result ^= std::hash{}(x.height) + 0x9e3779b9 + (result << 6) + + (result >> 2); + result ^= + std::hash{}(x.width) + 0x9e3779b9 + (result << 6) + (result >> 2); + result ^= std::hash<::FlexFlow::DataType>{}(x.datatype) + 0x9e3779b9 + + (result << 6) + (result >> 2); + return result; +} +} // namespace std + +namespace nlohmann { +FlexFlow::Conv2DInputShape + adl_serializer::from_json(json const &j) { + return {j.at("num_samples").template get(), + j.at("num_channels").template get(), + j.at("height").template get(), + j.at("width").template get(), + j.at("datatype").template get<::FlexFlow::DataType>()}; +} +void adl_serializer::to_json( + json &j, FlexFlow::Conv2DInputShape const &v) { + j["__type"] = "Conv2DInputShape"; + j["num_samples"] = v.num_samples; + j["num_channels"] = v.num_channels; + j["height"] = v.height; + j["width"] = v.width; + 
j["datatype"] = v.datatype; +} +} // namespace nlohmann + +namespace rc { +Gen + Arbitrary::arbitrary() { + return gen::construct( + gen::arbitrary(), + gen::arbitrary(), + gen::arbitrary(), + gen::arbitrary(), + gen::arbitrary<::FlexFlow::DataType>()); +} +} // namespace rc + +namespace FlexFlow { +std::string format_as(Conv2DInputShape const &x) { + std::ostringstream oss; + oss << ""; + return oss.str(); +} +std::ostream &operator<<(std::ostream &s, Conv2DInputShape const &x) { + return s << fmt::to_string(x); +} +} // namespace FlexFlow diff --git a/lib/op-attrs/src/op-attrs/ops/conv_2d/conv_2d_parallel_input_shape.cc b/lib/op-attrs/src/op-attrs/ops/conv_2d/conv_2d_parallel_input_shape.cc new file mode 100644 index 0000000000..32ac4547f1 --- /dev/null +++ b/lib/op-attrs/src/op-attrs/ops/conv_2d/conv_2d_parallel_input_shape.cc @@ -0,0 +1,26 @@ +#include "op-attrs/ops/conv_2d/conv_2d_parallel_input_shape.h" +#include "op-attrs/parallel_tensor_shape.h" + +namespace FlexFlow { + +Conv2DParallelInputShape + parse_parallel_input_shape(ParallelTensorShape const &input) { + assert(num_shard_dims(input) == 4); + + ShardParallelDim sample_dim = shard_dim_at_idx(input, ff_dim_t{0}); + ShardParallelDim channel_dim = shard_dim_at_idx(input, ff_dim_t{1}); + ShardParallelDim height_dim = shard_dim_at_idx(input, ff_dim_t{2}); + ShardParallelDim width_dim = shard_dim_at_idx(input, ff_dim_t{3}); + + return Conv2DParallelInputShape{ + sample_dim, + channel_dim, + height_dim, + width_dim, + get_sum_degree(input), + get_discard_copy_degree(input), + input.data_type, + }; +} + +} // namespace FlexFlow diff --git a/lib/op-attrs/src/op-attrs/ops/conv_2d/conv_2d_parallel_input_shape.dtg.cc b/lib/op-attrs/src/op-attrs/ops/conv_2d/conv_2d_parallel_input_shape.dtg.cc new file mode 100644 index 0000000000..df854c2b8f --- /dev/null +++ b/lib/op-attrs/src/op-attrs/ops/conv_2d/conv_2d_parallel_input_shape.dtg.cc @@ -0,0 +1,211 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! 
+// If you would like to modify this datatype, instead modify +// lib/op-attrs/include/op-attrs/ops/conv_2d/conv_2d_parallel_input_shape.struct.toml +/* proj-data +{ + "generated_from": "d80394bdc90f843372760310b6d17a22" +} +*/ + +#include "op-attrs/ops/conv_2d/conv_2d_parallel_input_shape.dtg.h" + +#include "op-attrs/datatype.dtg.h" +#include "op-attrs/shard_parallel_dim.dtg.h" +#include + +namespace FlexFlow { +Conv2DParallelInputShape::Conv2DParallelInputShape( + ::FlexFlow::ShardParallelDim const &sample_dim, + ::FlexFlow::ShardParallelDim const &channel_dim, + ::FlexFlow::ShardParallelDim const &height_dim, + ::FlexFlow::ShardParallelDim const &width_dim, + int const &sum_reduction_degree, + int const &discard_copy_reduction_degree, + ::FlexFlow::DataType const &datatype) + : sample_dim(sample_dim), channel_dim(channel_dim), height_dim(height_dim), + width_dim(width_dim), sum_reduction_degree(sum_reduction_degree), + discard_copy_reduction_degree(discard_copy_reduction_degree), + datatype(datatype) {} +bool Conv2DParallelInputShape::operator==( + Conv2DParallelInputShape const &other) const { + return std::tie(this->sample_dim, + this->channel_dim, + this->height_dim, + this->width_dim, + this->sum_reduction_degree, + this->discard_copy_reduction_degree, + this->datatype) == + std::tie(other.sample_dim, + other.channel_dim, + other.height_dim, + other.width_dim, + other.sum_reduction_degree, + other.discard_copy_reduction_degree, + other.datatype); +} +bool Conv2DParallelInputShape::operator!=( + Conv2DParallelInputShape const &other) const { + return std::tie(this->sample_dim, + this->channel_dim, + this->height_dim, + this->width_dim, + this->sum_reduction_degree, + this->discard_copy_reduction_degree, + this->datatype) != + std::tie(other.sample_dim, + other.channel_dim, + other.height_dim, + other.width_dim, + other.sum_reduction_degree, + other.discard_copy_reduction_degree, + other.datatype); +} +bool Conv2DParallelInputShape::operator<( + Conv2DParallelInputShape const &other) const { + return std::tie(this->sample_dim, + this->channel_dim, + this->height_dim, + this->width_dim, + this->sum_reduction_degree, + this->discard_copy_reduction_degree, + this->datatype) < + std::tie(other.sample_dim, + other.channel_dim, + other.height_dim, + other.width_dim, + other.sum_reduction_degree, + other.discard_copy_reduction_degree, + other.datatype); +} +bool Conv2DParallelInputShape::operator>( + Conv2DParallelInputShape const &other) const { + return std::tie(this->sample_dim, + this->channel_dim, + this->height_dim, + this->width_dim, + this->sum_reduction_degree, + this->discard_copy_reduction_degree, + this->datatype) > + std::tie(other.sample_dim, + other.channel_dim, + other.height_dim, + other.width_dim, + other.sum_reduction_degree, + other.discard_copy_reduction_degree, + other.datatype); +} +bool Conv2DParallelInputShape::operator<=( + Conv2DParallelInputShape const &other) const { + return std::tie(this->sample_dim, + this->channel_dim, + this->height_dim, + this->width_dim, + this->sum_reduction_degree, + this->discard_copy_reduction_degree, + this->datatype) <= + std::tie(other.sample_dim, + other.channel_dim, + other.height_dim, + other.width_dim, + other.sum_reduction_degree, + other.discard_copy_reduction_degree, + other.datatype); +} +bool Conv2DParallelInputShape::operator>=( + Conv2DParallelInputShape const &other) const { + return std::tie(this->sample_dim, + this->channel_dim, + this->height_dim, + this->width_dim, + this->sum_reduction_degree, + 
this->discard_copy_reduction_degree, + this->datatype) >= + std::tie(other.sample_dim, + other.channel_dim, + other.height_dim, + other.width_dim, + other.sum_reduction_degree, + other.discard_copy_reduction_degree, + other.datatype); +} +} // namespace FlexFlow + +namespace std { +size_t hash::operator()( + FlexFlow::Conv2DParallelInputShape const &x) const { + size_t result = 0; + result ^= std::hash<::FlexFlow::ShardParallelDim>{}(x.sample_dim) + + 0x9e3779b9 + (result << 6) + (result >> 2); + result ^= std::hash<::FlexFlow::ShardParallelDim>{}(x.channel_dim) + + 0x9e3779b9 + (result << 6) + (result >> 2); + result ^= std::hash<::FlexFlow::ShardParallelDim>{}(x.height_dim) + + 0x9e3779b9 + (result << 6) + (result >> 2); + result ^= std::hash<::FlexFlow::ShardParallelDim>{}(x.width_dim) + + 0x9e3779b9 + (result << 6) + (result >> 2); + result ^= std::hash{}(x.sum_reduction_degree) + 0x9e3779b9 + + (result << 6) + (result >> 2); + result ^= std::hash{}(x.discard_copy_reduction_degree) + 0x9e3779b9 + + (result << 6) + (result >> 2); + result ^= std::hash<::FlexFlow::DataType>{}(x.datatype) + 0x9e3779b9 + + (result << 6) + (result >> 2); + return result; +} +} // namespace std + +namespace nlohmann { +FlexFlow::Conv2DParallelInputShape + adl_serializer::from_json( + json const &j) { + return {j.at("sample_dim").template get<::FlexFlow::ShardParallelDim>(), + j.at("channel_dim").template get<::FlexFlow::ShardParallelDim>(), + j.at("height_dim").template get<::FlexFlow::ShardParallelDim>(), + j.at("width_dim").template get<::FlexFlow::ShardParallelDim>(), + j.at("sum_reduction_degree").template get(), + j.at("discard_copy_reduction_degree").template get(), + j.at("datatype").template get<::FlexFlow::DataType>()}; +} +void adl_serializer::to_json( + json &j, FlexFlow::Conv2DParallelInputShape const &v) { + j["__type"] = "Conv2DParallelInputShape"; + j["sample_dim"] = v.sample_dim; + j["channel_dim"] = v.channel_dim; + j["height_dim"] = v.height_dim; + j["width_dim"] = v.width_dim; + j["sum_reduction_degree"] = v.sum_reduction_degree; + j["discard_copy_reduction_degree"] = v.discard_copy_reduction_degree; + j["datatype"] = v.datatype; +} +} // namespace nlohmann + +namespace rc { +Gen + Arbitrary::arbitrary() { + return gen::construct( + gen::arbitrary<::FlexFlow::ShardParallelDim>(), + gen::arbitrary<::FlexFlow::ShardParallelDim>(), + gen::arbitrary<::FlexFlow::ShardParallelDim>(), + gen::arbitrary<::FlexFlow::ShardParallelDim>(), + gen::arbitrary(), + gen::arbitrary(), + gen::arbitrary<::FlexFlow::DataType>()); +} +} // namespace rc + +namespace FlexFlow { +std::string format_as(Conv2DParallelInputShape const &x) { + std::ostringstream oss; + oss << ""; + return oss.str(); +} +std::ostream &operator<<(std::ostream &s, Conv2DParallelInputShape const &x) { + return s << fmt::to_string(x); +} +} // namespace FlexFlow diff --git a/lib/op-attrs/src/op-attrs/ops/conv_2d_attrs.dtg.cc b/lib/op-attrs/src/op-attrs/ops/conv_2d_attrs.dtg.cc new file mode 100644 index 0000000000..238b349cbe --- /dev/null +++ b/lib/op-attrs/src/op-attrs/ops/conv_2d_attrs.dtg.cc @@ -0,0 +1,256 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! 
+// If you would like to modify this datatype, instead modify +// lib/op-attrs/include/op-attrs/ops/conv_2d_attrs.struct.toml +/* proj-data +{ + "generated_from": "74f98e1aacb57d847bb450e1d28d3e67" +} +*/ + +#include "op-attrs/ops/conv_2d_attrs.dtg.h" + +#include "op-attrs/activation.dtg.h" +#include "utils/json.h" +#include +#include + +namespace FlexFlow { +Conv2DAttrs::Conv2DAttrs( + int const &out_channels, + int const &kernel_h, + int const &kernel_w, + int const &stride_h, + int const &stride_w, + int const &padding_h, + int const &padding_w, + int const &groups, + std::optional<::FlexFlow::Activation> const &activation, + bool const &use_bias) + : out_channels(out_channels), kernel_h(kernel_h), kernel_w(kernel_w), + stride_h(stride_h), stride_w(stride_w), padding_h(padding_h), + padding_w(padding_w), groups(groups), activation(activation), + use_bias(use_bias) {} +bool Conv2DAttrs::operator==(Conv2DAttrs const &other) const { + return std::tie(this->out_channels, + this->kernel_h, + this->kernel_w, + this->stride_h, + this->stride_w, + this->padding_h, + this->padding_w, + this->groups, + this->activation, + this->use_bias) == std::tie(other.out_channels, + other.kernel_h, + other.kernel_w, + other.stride_h, + other.stride_w, + other.padding_h, + other.padding_w, + other.groups, + other.activation, + other.use_bias); +} +bool Conv2DAttrs::operator!=(Conv2DAttrs const &other) const { + return std::tie(this->out_channels, + this->kernel_h, + this->kernel_w, + this->stride_h, + this->stride_w, + this->padding_h, + this->padding_w, + this->groups, + this->activation, + this->use_bias) != std::tie(other.out_channels, + other.kernel_h, + other.kernel_w, + other.stride_h, + other.stride_w, + other.padding_h, + other.padding_w, + other.groups, + other.activation, + other.use_bias); +} +bool Conv2DAttrs::operator<(Conv2DAttrs const &other) const { + return std::tie(this->out_channels, + this->kernel_h, + this->kernel_w, + this->stride_h, + this->stride_w, + this->padding_h, + this->padding_w, + this->groups, + this->activation, + this->use_bias) < std::tie(other.out_channels, + other.kernel_h, + other.kernel_w, + other.stride_h, + other.stride_w, + other.padding_h, + other.padding_w, + other.groups, + other.activation, + other.use_bias); +} +bool Conv2DAttrs::operator>(Conv2DAttrs const &other) const { + return std::tie(this->out_channels, + this->kernel_h, + this->kernel_w, + this->stride_h, + this->stride_w, + this->padding_h, + this->padding_w, + this->groups, + this->activation, + this->use_bias) > std::tie(other.out_channels, + other.kernel_h, + other.kernel_w, + other.stride_h, + other.stride_w, + other.padding_h, + other.padding_w, + other.groups, + other.activation, + other.use_bias); +} +bool Conv2DAttrs::operator<=(Conv2DAttrs const &other) const { + return std::tie(this->out_channels, + this->kernel_h, + this->kernel_w, + this->stride_h, + this->stride_w, + this->padding_h, + this->padding_w, + this->groups, + this->activation, + this->use_bias) <= std::tie(other.out_channels, + other.kernel_h, + other.kernel_w, + other.stride_h, + other.stride_w, + other.padding_h, + other.padding_w, + other.groups, + other.activation, + other.use_bias); +} +bool Conv2DAttrs::operator>=(Conv2DAttrs const &other) const { + return std::tie(this->out_channels, + this->kernel_h, + this->kernel_w, + this->stride_h, + this->stride_w, + this->padding_h, + this->padding_w, + this->groups, + this->activation, + this->use_bias) >= std::tie(other.out_channels, + other.kernel_h, + other.kernel_w, + 
other.stride_h, + other.stride_w, + other.padding_h, + other.padding_w, + other.groups, + other.activation, + other.use_bias); +} +} // namespace FlexFlow + +namespace std { +size_t hash::operator()( + FlexFlow::Conv2DAttrs const &x) const { + size_t result = 0; + result ^= std::hash{}(x.out_channels) + 0x9e3779b9 + (result << 6) + + (result >> 2); + result ^= + std::hash{}(x.kernel_h) + 0x9e3779b9 + (result << 6) + (result >> 2); + result ^= + std::hash{}(x.kernel_w) + 0x9e3779b9 + (result << 6) + (result >> 2); + result ^= + std::hash{}(x.stride_h) + 0x9e3779b9 + (result << 6) + (result >> 2); + result ^= + std::hash{}(x.stride_w) + 0x9e3779b9 + (result << 6) + (result >> 2); + result ^= std::hash{}(x.padding_h) + 0x9e3779b9 + (result << 6) + + (result >> 2); + result ^= std::hash{}(x.padding_w) + 0x9e3779b9 + (result << 6) + + (result >> 2); + result ^= + std::hash{}(x.groups) + 0x9e3779b9 + (result << 6) + (result >> 2); + result ^= std::hash>{}(x.activation) + + 0x9e3779b9 + (result << 6) + (result >> 2); + result ^= std::hash{}(x.use_bias) + 0x9e3779b9 + (result << 6) + + (result >> 2); + return result; +} +} // namespace std + +namespace nlohmann { +FlexFlow::Conv2DAttrs + adl_serializer::from_json(json const &j) { + return { + j.at("out_channels").template get(), + j.at("kernel_h").template get(), + j.at("kernel_w").template get(), + j.at("stride_h").template get(), + j.at("stride_w").template get(), + j.at("padding_h").template get(), + j.at("padding_w").template get(), + j.at("groups").template get(), + j.at("activation").template get>(), + j.at("use_bias").template get()}; +} +void adl_serializer::to_json( + json &j, FlexFlow::Conv2DAttrs const &v) { + j["__type"] = "Conv2DAttrs"; + j["out_channels"] = v.out_channels; + j["kernel_h"] = v.kernel_h; + j["kernel_w"] = v.kernel_w; + j["stride_h"] = v.stride_h; + j["stride_w"] = v.stride_w; + j["padding_h"] = v.padding_h; + j["padding_w"] = v.padding_w; + j["groups"] = v.groups; + j["activation"] = v.activation; + j["use_bias"] = v.use_bias; +} +} // namespace nlohmann + +namespace rc { +Gen Arbitrary::arbitrary() { + return gen::construct( + gen::arbitrary(), + gen::arbitrary(), + gen::arbitrary(), + gen::arbitrary(), + gen::arbitrary(), + gen::arbitrary(), + gen::arbitrary(), + gen::arbitrary(), + gen::arbitrary>(), + gen::arbitrary()); +} +} // namespace rc + +namespace FlexFlow { +std::string format_as(Conv2DAttrs const &x) { + std::ostringstream oss; + oss << ""; + return oss.str(); +} +std::ostream &operator<<(std::ostream &s, Conv2DAttrs const &x) { + return s << fmt::to_string(x); +} +} // namespace FlexFlow diff --git a/lib/op-attrs/src/op-attrs/ops/dropout.cc b/lib/op-attrs/src/op-attrs/ops/dropout.cc new file mode 100644 index 0000000000..adbd144f38 --- /dev/null +++ b/lib/op-attrs/src/op-attrs/ops/dropout.cc @@ -0,0 +1,10 @@ +#include "op-attrs/ops/dropout.h" + +namespace FlexFlow { + +ParallelTensorShape get_output_shape(DropoutAttrs const &, + ParallelTensorShape const &) { + NOT_IMPLEMENTED(); +} + +} // namespace FlexFlow diff --git a/lib/op-attrs/src/op-attrs/ops/dropout_attrs.dtg.cc b/lib/op-attrs/src/op-attrs/ops/dropout_attrs.dtg.cc new file mode 100644 index 0000000000..284443a0e4 --- /dev/null +++ b/lib/op-attrs/src/op-attrs/ops/dropout_attrs.dtg.cc @@ -0,0 +1,82 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! 
+// If you would like to modify this datatype, instead modify
+// lib/op-attrs/include/op-attrs/ops/dropout_attrs.struct.toml
+/* proj-data
+{
+  "generated_from": "4fdbf129ea59b8a7306813cfa4c46021"
+}
+*/
+
+#include "op-attrs/ops/dropout_attrs.dtg.h"
+
+#include <sstream>
+
+namespace FlexFlow {
+DropoutAttrs::DropoutAttrs(float const &rate, unsigned long long const &seed)
+    : rate(rate), seed(seed) {}
+bool DropoutAttrs::operator==(DropoutAttrs const &other) const {
+  return std::tie(this->rate, this->seed) == std::tie(other.rate, other.seed);
+}
+bool DropoutAttrs::operator!=(DropoutAttrs const &other) const {
+  return std::tie(this->rate, this->seed) != std::tie(other.rate, other.seed);
+}
+bool DropoutAttrs::operator<(DropoutAttrs const &other) const {
+  return std::tie(this->rate, this->seed) < std::tie(other.rate, other.seed);
+}
+bool DropoutAttrs::operator>(DropoutAttrs const &other) const {
+  return std::tie(this->rate, this->seed) > std::tie(other.rate, other.seed);
+}
+bool DropoutAttrs::operator<=(DropoutAttrs const &other) const {
+  return std::tie(this->rate, this->seed) <= std::tie(other.rate, other.seed);
+}
+bool DropoutAttrs::operator>=(DropoutAttrs const &other) const {
+  return std::tie(this->rate, this->seed) >= std::tie(other.rate, other.seed);
+}
+} // namespace FlexFlow
+
+namespace std {
+size_t hash<FlexFlow::DropoutAttrs>::operator()(
+    FlexFlow::DropoutAttrs const &x) const {
+  size_t result = 0;
+  result ^=
+      std::hash<float>{}(x.rate) + 0x9e3779b9 + (result << 6) + (result >> 2);
+  result ^= std::hash<unsigned long long>{}(x.seed) + 0x9e3779b9 +
+            (result << 6) + (result >> 2);
+  return result;
+}
+} // namespace std
+
+namespace nlohmann {
+FlexFlow::DropoutAttrs
+    adl_serializer<FlexFlow::DropoutAttrs>::from_json(json const &j) {
+  return {j.at("rate").template get<float>(),
+          j.at("seed").template get<unsigned long long>()};
+}
+void adl_serializer<FlexFlow::DropoutAttrs>::to_json(
+    json &j, FlexFlow::DropoutAttrs const &v) {
+  j["__type"] = "DropoutAttrs";
+  j["rate"] = v.rate;
+  j["seed"] = v.seed;
+}
+} // namespace nlohmann
+
+namespace rc {
+Gen<FlexFlow::DropoutAttrs> Arbitrary<FlexFlow::DropoutAttrs>::arbitrary() {
+  return gen::construct<FlexFlow::DropoutAttrs>(
+      gen::arbitrary<float>(), gen::arbitrary<unsigned long long>());
+}
+} // namespace rc
+
+namespace FlexFlow {
+std::string format_as(DropoutAttrs const &x) {
+  std::ostringstream oss;
+  oss << "<DropoutAttrs";
+  oss << " rate=" << x.rate;
+  oss << " seed=" << x.seed;
+  oss << ">";
+  return oss.str();
+}
+std::ostream &operator<<(std::ostream &s, DropoutAttrs const &x) {
+  return s << fmt::to_string(x);
+}
+} // namespace FlexFlow
diff --git a/lib/op-attrs/src/op-attrs/ops/element_binary.cc b/lib/op-attrs/src/op-attrs/ops/element_binary.cc
new file mode 100644
index 0000000000..16957a036c
--- /dev/null
+++ b/lib/op-attrs/src/op-attrs/ops/element_binary.cc
@@ -0,0 +1,75 @@
+#include "op-attrs/ops/element_binary.h"
+
+namespace FlexFlow {
+
+tl::expected<TensorShape, std::string>
+    get_output_shape(ElementBinaryAttrs const &attrs,
+                     TensorShape const &input_lhs,
+                     TensorShape const &input_rhs) {
+  assert(!(attrs.should_broadcast_lhs && attrs.should_broadcast_rhs));
+
+  if (attrs.should_broadcast_lhs) {
+    NOT_IMPLEMENTED();
+  } else if (attrs.should_broadcast_rhs) {
+    NOT_IMPLEMENTED();
+  } else {
+    if (input_lhs != input_rhs) {
+      return tl::unexpected(fmt::format(
+          "Expected input shapes to match, but received LHS ({}) != RHS ({})",
+          input_lhs,
+          input_rhs));
+    }
+
+    return input_lhs;
+  }
+}
+
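+// Note: for EW_ADD the inputs may legitimately be sum-parallel (addition
+// distributes over partial sums), so the overload below only rejects a
+// non-trivial discard-copy degree; the remaining element-wise binary ops are
+// currently unimplemented.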
+tl::expected<ParallelTensorShape, std::string>
+    get_output_shape(ElementBinaryAttrs const &attrs,
+                     ParallelTensorShape const &input_lhs,
+                     ParallelTensorShape const &input_rhs) {
+  assert(!(attrs.should_broadcast_lhs && attrs.should_broadcast_rhs));
+
+  if (attrs.should_broadcast_lhs) {
+    NOT_IMPLEMENTED();
+  } else if (attrs.should_broadcast_rhs) {
+    NOT_IMPLEMENTED();
+  } else {
+    if (input_lhs != input_rhs) {
+      return tl::unexpected(fmt::format(
+          "Expected input shapes to match, but received LHS ({}) != RHS ({})",
+          input_lhs,
+          input_rhs));
+    }
+
+    switch (attrs.type) {
+      case OperatorType::EW_ADD: {
+        if (get_discard_copy_degree(input_lhs) != 1) {
+          return tl::unexpected(
+              fmt::format("Elementwise Add expected discard copy degree of "
+                          "inputs to be 1, but received {}",
+                          get_discard_copy_degree(input_lhs)));
+        }
+
+        break;
+      }
+      case OperatorType::EW_SUB:
+        NOT_IMPLEMENTED();
+      case OperatorType::EW_MUL:
+        NOT_IMPLEMENTED();
+      case OperatorType::EW_DIV:
+        NOT_IMPLEMENTED();
+      case OperatorType::EW_MAX:
+        NOT_IMPLEMENTED();
+      case OperatorType::EW_MIN:
+        NOT_IMPLEMENTED();
+      default:
+        return tl::unexpected(fmt::format(
+            "Unexpected element-wise binary operator {}", attrs.type));
+    }
+
+    return input_lhs;
+  }
+}
+
+} // namespace FlexFlow
diff --git a/lib/op-attrs/src/op-attrs/ops/element_binary_attrs.dtg.cc b/lib/op-attrs/src/op-attrs/ops/element_binary_attrs.dtg.cc
new file mode 100644
index 0000000000..a0e555cb12
--- /dev/null
+++ b/lib/op-attrs/src/op-attrs/ops/element_binary_attrs.dtg.cc
@@ -0,0 +1,145 @@
+// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT!
+// If you would like to modify this datatype, instead modify
+// lib/op-attrs/include/op-attrs/ops/element_binary_attrs.struct.toml
+/* proj-data
+{
+  "generated_from": "2bb947c9cc92e3833ee88c908c539629"
+}
+*/
+
+#include "op-attrs/ops/element_binary_attrs.dtg.h"
+
+#include "op-attrs/datatype.h"
+#include "op-attrs/operator_type.h"
+#include <sstream>
+
+namespace FlexFlow {
+ElementBinaryAttrs::ElementBinaryAttrs(::FlexFlow::OperatorType const &type,
+                                       ::FlexFlow::DataType const &compute_type,
+                                       bool const &should_broadcast_lhs,
+                                       bool const &should_broadcast_rhs)
+    : type(type), compute_type(compute_type),
+      should_broadcast_lhs(should_broadcast_lhs),
+      should_broadcast_rhs(should_broadcast_rhs) {}
+bool ElementBinaryAttrs::operator==(ElementBinaryAttrs const &other) const {
+  return std::tie(this->type,
+                  this->compute_type,
+                  this->should_broadcast_lhs,
+                  this->should_broadcast_rhs) ==
+         std::tie(other.type,
+                  other.compute_type,
+                  other.should_broadcast_lhs,
+                  other.should_broadcast_rhs);
+}
+bool ElementBinaryAttrs::operator!=(ElementBinaryAttrs const &other) const {
+  return std::tie(this->type,
+                  this->compute_type,
+                  this->should_broadcast_lhs,
+                  this->should_broadcast_rhs) !=
+         std::tie(other.type,
+                  other.compute_type,
+                  other.should_broadcast_lhs,
+                  other.should_broadcast_rhs);
+}
+bool ElementBinaryAttrs::operator<(ElementBinaryAttrs const &other) const {
+  return std::tie(this->type,
+                  this->compute_type,
+                  this->should_broadcast_lhs,
+                  this->should_broadcast_rhs) <
+         std::tie(other.type,
+                  other.compute_type,
+                  other.should_broadcast_lhs,
+                  other.should_broadcast_rhs);
+}
+bool ElementBinaryAttrs::operator>(ElementBinaryAttrs const &other) const {
+  return std::tie(this->type,
+                  this->compute_type,
+                  this->should_broadcast_lhs,
+                  this->should_broadcast_rhs) >
+         std::tie(other.type,
+                  other.compute_type,
+                  other.should_broadcast_lhs,
+                  other.should_broadcast_rhs);
+}
+bool ElementBinaryAttrs::operator<=(ElementBinaryAttrs const &other) const {
+  return std::tie(this->type,
+                  this->compute_type,
+                  this->should_broadcast_lhs,
+                  this->should_broadcast_rhs) <=
+         std::tie(other.type,
+                  other.compute_type,
+                  other.should_broadcast_lhs,
+                  other.should_broadcast_rhs);
+}
+bool ElementBinaryAttrs::operator>=(ElementBinaryAttrs const &other) const {
+  return
std::tie(this->type, + this->compute_type, + this->should_broadcast_lhs, + this->should_broadcast_rhs) >= + std::tie(other.type, + other.compute_type, + other.should_broadcast_lhs, + other.should_broadcast_rhs); +} +} // namespace FlexFlow + +namespace std { +size_t hash::operator()( + FlexFlow::ElementBinaryAttrs const &x) const { + size_t result = 0; + result ^= std::hash<::FlexFlow::OperatorType>{}(x.type) + 0x9e3779b9 + + (result << 6) + (result >> 2); + result ^= std::hash<::FlexFlow::DataType>{}(x.compute_type) + 0x9e3779b9 + + (result << 6) + (result >> 2); + result ^= std::hash{}(x.should_broadcast_lhs) + 0x9e3779b9 + + (result << 6) + (result >> 2); + result ^= std::hash{}(x.should_broadcast_rhs) + 0x9e3779b9 + + (result << 6) + (result >> 2); + return result; +} +} // namespace std + +namespace nlohmann { +FlexFlow::ElementBinaryAttrs + adl_serializer::from_json(json const &j) { + return {j.at("type").template get<::FlexFlow::OperatorType>(), + j.at("compute_type").template get<::FlexFlow::DataType>(), + j.at("should_broadcast_lhs").template get(), + j.at("should_broadcast_rhs").template get()}; +} +void adl_serializer::to_json( + json &j, FlexFlow::ElementBinaryAttrs const &v) { + j["__type"] = "ElementBinaryAttrs"; + j["type"] = v.type; + j["compute_type"] = v.compute_type; + j["should_broadcast_lhs"] = v.should_broadcast_lhs; + j["should_broadcast_rhs"] = v.should_broadcast_rhs; +} +} // namespace nlohmann + +namespace rc { +Gen + Arbitrary::arbitrary() { + return gen::construct( + gen::arbitrary<::FlexFlow::OperatorType>(), + gen::arbitrary<::FlexFlow::DataType>(), + gen::arbitrary(), + gen::arbitrary()); +} +} // namespace rc + +namespace FlexFlow { +std::string format_as(ElementBinaryAttrs const &x) { + std::ostringstream oss; + oss << ""; + return oss.str(); +} +std::ostream &operator<<(std::ostream &s, ElementBinaryAttrs const &x) { + return s << fmt::to_string(x); +} +} // namespace FlexFlow diff --git a/lib/op-attrs/src/op-attrs/ops/element_scalar_unary_attrs.dtg.cc b/lib/op-attrs/src/op-attrs/ops/element_scalar_unary_attrs.dtg.cc new file mode 100644 index 0000000000..ee85474caf --- /dev/null +++ b/lib/op-attrs/src/op-attrs/ops/element_scalar_unary_attrs.dtg.cc @@ -0,0 +1,98 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! 
+// If you would like to modify this datatype, instead modify
+// lib/op-attrs/include/op-attrs/ops/element_scalar_unary_attrs.struct.toml
+/* proj-data
+{
+  "generated_from": "aa6f98b992d46bdf7ad59158bc143a3f"
+}
+*/
+
+#include "op-attrs/ops/element_scalar_unary_attrs.dtg.h"
+
+#include "op-attrs/operator_type.h"
+#include <sstream>
+
+namespace FlexFlow {
+ElementScalarUnaryAttrs::ElementScalarUnaryAttrs(
+    ::FlexFlow::OperatorType const &op_type, float const &scalar)
+    : op_type(op_type), scalar(scalar) {}
+bool ElementScalarUnaryAttrs::operator==(
+    ElementScalarUnaryAttrs const &other) const {
+  return std::tie(this->op_type, this->scalar) ==
+         std::tie(other.op_type, other.scalar);
+}
+bool ElementScalarUnaryAttrs::operator!=(
+    ElementScalarUnaryAttrs const &other) const {
+  return std::tie(this->op_type, this->scalar) !=
+         std::tie(other.op_type, other.scalar);
+}
+bool ElementScalarUnaryAttrs::operator<(
+    ElementScalarUnaryAttrs const &other) const {
+  return std::tie(this->op_type, this->scalar) <
+         std::tie(other.op_type, other.scalar);
+}
+bool ElementScalarUnaryAttrs::operator>(
+    ElementScalarUnaryAttrs const &other) const {
+  return std::tie(this->op_type, this->scalar) >
+         std::tie(other.op_type, other.scalar);
+}
+bool ElementScalarUnaryAttrs::operator<=(
+    ElementScalarUnaryAttrs const &other) const {
+  return std::tie(this->op_type, this->scalar) <=
+         std::tie(other.op_type, other.scalar);
+}
+bool ElementScalarUnaryAttrs::operator>=(
+    ElementScalarUnaryAttrs const &other) const {
+  return std::tie(this->op_type, this->scalar) >=
+         std::tie(other.op_type, other.scalar);
+}
+} // namespace FlexFlow
+
+namespace std {
+size_t hash<FlexFlow::ElementScalarUnaryAttrs>::operator()(
+    FlexFlow::ElementScalarUnaryAttrs const &x) const {
+  size_t result = 0;
+  result ^= std::hash<::FlexFlow::OperatorType>{}(x.op_type) + 0x9e3779b9 +
+            (result << 6) + (result >> 2);
+  result ^=
+      std::hash<float>{}(x.scalar) + 0x9e3779b9 + (result << 6) + (result >> 2);
+  return result;
+}
+} // namespace std
+
+namespace nlohmann {
+FlexFlow::ElementScalarUnaryAttrs
+    adl_serializer<FlexFlow::ElementScalarUnaryAttrs>::from_json(
+        json const &j) {
+  return {j.at("op_type").template get<::FlexFlow::OperatorType>(),
+          j.at("scalar").template get<float>()};
+}
+void adl_serializer<FlexFlow::ElementScalarUnaryAttrs>::to_json(
+    json &j, FlexFlow::ElementScalarUnaryAttrs const &v) {
+  j["__type"] = "ElementScalarUnaryAttrs";
+  j["op_type"] = v.op_type;
+  j["scalar"] = v.scalar;
+}
+} // namespace nlohmann
+
+namespace rc {
+Gen<FlexFlow::ElementScalarUnaryAttrs>
+    Arbitrary<FlexFlow::ElementScalarUnaryAttrs>::arbitrary() {
+  return gen::construct<FlexFlow::ElementScalarUnaryAttrs>(
+      gen::arbitrary<::FlexFlow::OperatorType>(), gen::arbitrary<float>());
+}
+} // namespace rc
+
+namespace FlexFlow {
+std::string format_as(ElementScalarUnaryAttrs const &x) {
+  std::ostringstream oss;
+  oss << "<ElementScalarUnaryAttrs";
+  oss << " op_type=" << x.op_type;
+  oss << " scalar=" << x.scalar;
+  oss << ">";
+  return oss.str();
+}
+std::ostream &operator<<(std::ostream &s, ElementScalarUnaryAttrs const &x) {
+  return s << fmt::to_string(x);
+}
+} // namespace FlexFlow
diff --git a/lib/op-attrs/src/op-attrs/ops/element_unary.cc b/lib/op-attrs/src/op-attrs/ops/element_unary.cc
new file mode 100644
index 0000000000..f703799ef3
--- /dev/null
+++ b/lib/op-attrs/src/op-attrs/ops/element_unary.cc
@@ -0,0 +1,54 @@
+#include "op-attrs/ops/element_unary.h"
+#include "op-attrs/parallel_tensor_shape.h"
+
+namespace FlexFlow {
+
+tl::expected<TensorShape, std::string>
+    get_output_shape(ElementUnaryAttrs const &attrs,
+                     TensorShape const &input_shape) {
+  return input_shape;
+}
+
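+// Note: an element-wise unary op passes shard parallelism through unchanged,
+// but the parallel overloads below require trivial sum and discard-copy
+// degrees: applying a (possibly non-linear) op to partial sums would change
+// the result.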
+tl::expected<ParallelTensorShape, std::string>
+    get_output_shape(ElementUnaryAttrs const &attrs,
+                     ParallelTensorShape const &input_shape) {
+  if (get_sum_degree(input_shape) != 1) {
+    return tl::unexpected(
+        fmt::format("Expected sum degree 1, but received sum degree {}",
+                    get_sum_degree(input_shape)));
+  }
+
+  if (get_discard_copy_degree(input_shape) != 1) {
+    return tl::unexpected(fmt::format(
+        "Expected discard copy degree 1, but received discard copy degree {}",
+        get_discard_copy_degree(input_shape)));
+  }
+
+  return input_shape;
+}
+
+tl::expected<TensorShape, std::string>
+    get_output_shape(ElementScalarUnaryAttrs const &attrs,
+                     TensorShape const &input_shape) {
+  return input_shape;
+}
+
+tl::expected<ParallelTensorShape, std::string>
+    get_output_shape(ElementScalarUnaryAttrs const &attrs,
+                     ParallelTensorShape const &input_shape) {
+  if (get_sum_degree(input_shape) != 1) {
+    return tl::unexpected(
+        fmt::format("Expected sum degree 1, but received sum degree {}",
+                    get_sum_degree(input_shape)));
+  }
+
+  if (get_discard_copy_degree(input_shape) != 1) {
+    return tl::unexpected(fmt::format(
+        "Expected discard copy degree 1, but received discard copy degree {}",
+        get_discard_copy_degree(input_shape)));
+  }
+
+  return input_shape;
+}
+
+} // namespace FlexFlow
diff --git a/lib/op-attrs/src/op-attrs/ops/element_unary_attrs.dtg.cc b/lib/op-attrs/src/op-attrs/ops/element_unary_attrs.dtg.cc
new file mode 100644
index 0000000000..bf90a3db7d
--- /dev/null
+++ b/lib/op-attrs/src/op-attrs/ops/element_unary_attrs.dtg.cc
@@ -0,0 +1,79 @@
+// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT!
+// If you would like to modify this datatype, instead modify
+// lib/op-attrs/include/op-attrs/ops/element_unary_attrs.struct.toml
+/* proj-data
+{
+  "generated_from": "75272cff78d3db866122dbb1001aedbe"
+}
+*/
+
+#include "op-attrs/ops/element_unary_attrs.dtg.h"
+
+#include "op-attrs/operator_type.h"
+#include <sstream>
+
+namespace FlexFlow {
+ElementUnaryAttrs::ElementUnaryAttrs(::FlexFlow::OperatorType const &op_type)
+    : op_type(op_type) {}
+bool ElementUnaryAttrs::operator==(ElementUnaryAttrs const &other) const {
+  return std::tie(this->op_type) == std::tie(other.op_type);
+}
+bool ElementUnaryAttrs::operator!=(ElementUnaryAttrs const &other) const {
+  return std::tie(this->op_type) != std::tie(other.op_type);
+}
+bool ElementUnaryAttrs::operator<(ElementUnaryAttrs const &other) const {
+  return std::tie(this->op_type) < std::tie(other.op_type);
+}
+bool ElementUnaryAttrs::operator>(ElementUnaryAttrs const &other) const {
+  return std::tie(this->op_type) > std::tie(other.op_type);
+}
+bool ElementUnaryAttrs::operator<=(ElementUnaryAttrs const &other) const {
+  return std::tie(this->op_type) <= std::tie(other.op_type);
+}
+bool ElementUnaryAttrs::operator>=(ElementUnaryAttrs const &other) const {
+  return std::tie(this->op_type) >= std::tie(other.op_type);
+}
+} // namespace FlexFlow
+
+namespace std {
+size_t hash<FlexFlow::ElementUnaryAttrs>::operator()(
+    FlexFlow::ElementUnaryAttrs const &x) const {
+  size_t result = 0;
+  result ^= std::hash<::FlexFlow::OperatorType>{}(x.op_type) + 0x9e3779b9 +
+            (result << 6) + (result >> 2);
+  return result;
+}
+} // namespace std
+
+namespace nlohmann {
+FlexFlow::ElementUnaryAttrs
+    adl_serializer<FlexFlow::ElementUnaryAttrs>::from_json(json const &j) {
+  return {j.at("op_type").template get<::FlexFlow::OperatorType>()};
+}
+void adl_serializer<FlexFlow::ElementUnaryAttrs>::to_json(
+    json &j, FlexFlow::ElementUnaryAttrs const &v) {
+  j["__type"] = "ElementUnaryAttrs";
+  j["op_type"] = v.op_type;
+}
+} // namespace nlohmann
+
+namespace rc {
+Gen<FlexFlow::ElementUnaryAttrs>
+    Arbitrary<FlexFlow::ElementUnaryAttrs>::arbitrary() {
+  return gen::construct<FlexFlow::ElementUnaryAttrs>(
+      gen::arbitrary<::FlexFlow::OperatorType>());
+}
+} // namespace rc
+
+namespace FlexFlow {
+std::string format_as(ElementUnaryAttrs const &x) {
+  std::ostringstream oss;
+  oss << "<ElementUnaryAttrs";
+  oss << " op_type=" << x.op_type;
+  oss << ">";
+  return oss.str();
+}
+  return s << fmt::to_string(x);
+}
+} // namespace FlexFlow
diff --git a/lib/op-attrs/src/op-attrs/ops/embedding.cc b/lib/op-attrs/src/op-attrs/ops/embedding.cc
new file mode 100644
index 0000000000..9e9ad3a194
--- /dev/null
+++ b/lib/op-attrs/src/op-attrs/ops/embedding.cc
@@ -0,0 +1,112 @@
+#include "op-attrs/ops/embedding.h"
+#include "op-attrs/dim_ordered/slice.h"
+#include "op-attrs/dim_ordered/transform.h"
+#include "utils/containers.h"
+#include "utils/integer_conversions.h"
+
+namespace FlexFlow {
+
+static std::optional<std::string> basic_check(EmbeddingAttrs const &attrs,
+                                              TensorShape const &input) {
+  if (input.data_type != DataType::INT32 &&
+      input.data_type != DataType::INT64) {
+    return fmt::format("Embedding expected input tensor to have integer "
+                       "datatype, but received tensor of datatype {}",
+                       input.data_type);
+  }
+
+  if (attrs.aggr != AggregateOp::SUM) {
+    return fmt::format(
+        "Currently unsupported aggregation op for embedding: {}", attrs.aggr);
+  }
+
+  return std::nullopt;
+}
+
+tl::expected<TensorShape, std::string>
+    get_output_shape(EmbeddingAttrs const &attrs, TensorShape const &input) {
+  {
+    std::optional<std::string> err_msg = basic_check(attrs, input);
+    if (err_msg.has_value()) {
+      return tl::unexpected(err_msg.value());
+    }
+  }
+
+  TensorShape output = input;
+  dim_at_idx(output, ff_dim_t{-1}) = attrs.out_channels;
+  output.data_type = attrs.data_type;
+  return output;
+}
+
+tl::expected<TensorShape, std::string>
+    get_weights_shape(EmbeddingAttrs const &attrs, TensorShape const &input) {
+  {
+    std::optional<std::string> err_msg = basic_check(attrs, input);
+    if (err_msg.has_value()) {
+      return tl::unexpected(err_msg.value());
+    }
+  }
+
+  return TensorShape{
+      TensorDims{
+          FFOrdered<size_t>{
+              size_t_from_int(attrs.num_entries),
+              size_t_from_int(attrs.out_channels),
+          },
+      },
+      attrs.data_type,
+  };
+}
+
+tl::expected<ParallelTensorShape, std::string>
+    get_output_shape(EmbeddingAttrs const &attrs,
+                     ParallelTensorShape const &input) {
+
+  TensorShape unpar = ({
+    tl::expected<TensorShape, std::string> result_unpar =
+        get_output_shape(attrs, get_reduced_shape(input));
+    if (!result_unpar.has_value()) {
+      return tl::unexpected(result_unpar.error());
+    }
+    result_unpar.value();
+  });
+
+  SumDegree sum_degree =
+      SumDegree{shard_dim_at_idx(input, ff_dim_t{-1}).degree};
+  DiscardCopyDegree discard_copy_degree = DiscardCopyDegree{1};
+  FFOrdered<int> shard_degrees =
+      transform(input.dims.shard_dims,
+                [](ShardParallelDim const &d) { return d.degree; });
+  shard_degrees.at(ff_dim_t{-1}) = get_discard_copy_degree(input);
+
+  return lift_to_parallel_with_degrees(
+      unpar, sum_degree, discard_copy_degree, shard_degrees);
+}
+
+tl::expected<ParallelTensorShape, std::string>
+    get_weights_shape(EmbeddingAttrs const &attrs,
+                      ParallelTensorShape const &input) {
+  TensorShape unpar = ({
+    tl::expected<TensorShape, std::string> result_unpar =
+        get_weights_shape(attrs, get_reduced_shape(input));
+    if (!result_unpar.has_value()) {
+      return tl::unexpected(result_unpar.error());
+    }
+    result_unpar.value();
+  });
+
+  SumDegree sum_degree = SumDegree{1};
+  DiscardCopyDegree discard_copy_degree = DiscardCopyDegree{product(
+      transform(ff_ordered_shard_dims(input.dims),
+                [](ShardParallelDim const &d) -> int { return d.degree; }))};
+  int entry_dim_degree = 1;
+  int out_channel_degree = get_discard_copy_degree(input);
+  FFOrdered<int> shard_degrees = {
+      entry_dim_degree,
+      out_channel_degree,
+  };
+
+  return lift_to_parallel_with_degrees(
+      unpar, sum_degree, discard_copy_degree, shard_degrees);
+}
+
+} // namespace FlexFlow
diff --git a/lib/op-attrs/src/op-attrs/ops/embedding_attrs.dtg.cc b/lib/op-attrs/src/op-attrs/ops/embedding_attrs.dtg.cc
new file mode 100644
index 0000000000..b4d4657e08 --- /dev/null +++ b/lib/op-attrs/src/op-attrs/ops/embedding_attrs.dtg.cc @@ -0,0 +1,139 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/op-attrs/include/op-attrs/ops/embedding_attrs.struct.toml +/* proj-data +{ + "generated_from": "f2bdea52e23dee6f674f598f8691d994" +} +*/ + +#include "op-attrs/ops/embedding_attrs.dtg.h" + +#include "op-attrs/aggregate_op.dtg.h" +#include "op-attrs/datatype.dtg.h" +#include "utils/stack_vector.h" +#include + +namespace FlexFlow { +EmbeddingAttrs::EmbeddingAttrs( + int const &num_entries, + int const &out_channels, + std::optional<::FlexFlow::AggregateOp> const &aggr, + ::FlexFlow::DataType const &data_type) + : num_entries(num_entries), out_channels(out_channels), aggr(aggr), + data_type(data_type) {} +bool EmbeddingAttrs::operator==(EmbeddingAttrs const &other) const { + return std::tie(this->num_entries, + this->out_channels, + this->aggr, + this->data_type) == std::tie(other.num_entries, + other.out_channels, + other.aggr, + other.data_type); +} +bool EmbeddingAttrs::operator!=(EmbeddingAttrs const &other) const { + return std::tie(this->num_entries, + this->out_channels, + this->aggr, + this->data_type) != std::tie(other.num_entries, + other.out_channels, + other.aggr, + other.data_type); +} +bool EmbeddingAttrs::operator<(EmbeddingAttrs const &other) const { + return std::tie(this->num_entries, + this->out_channels, + this->aggr, + this->data_type) < std::tie(other.num_entries, + other.out_channels, + other.aggr, + other.data_type); +} +bool EmbeddingAttrs::operator>(EmbeddingAttrs const &other) const { + return std::tie(this->num_entries, + this->out_channels, + this->aggr, + this->data_type) > std::tie(other.num_entries, + other.out_channels, + other.aggr, + other.data_type); +} +bool EmbeddingAttrs::operator<=(EmbeddingAttrs const &other) const { + return std::tie(this->num_entries, + this->out_channels, + this->aggr, + this->data_type) <= std::tie(other.num_entries, + other.out_channels, + other.aggr, + other.data_type); +} +bool EmbeddingAttrs::operator>=(EmbeddingAttrs const &other) const { + return std::tie(this->num_entries, + this->out_channels, + this->aggr, + this->data_type) >= std::tie(other.num_entries, + other.out_channels, + other.aggr, + other.data_type); +} +} // namespace FlexFlow + +namespace std { +size_t hash::operator()( + FlexFlow::EmbeddingAttrs const &x) const { + size_t result = 0; + result ^= std::hash{}(x.num_entries) + 0x9e3779b9 + (result << 6) + + (result >> 2); + result ^= std::hash{}(x.out_channels) + 0x9e3779b9 + (result << 6) + + (result >> 2); + result ^= std::hash>{}(x.aggr) + + 0x9e3779b9 + (result << 6) + (result >> 2); + result ^= std::hash<::FlexFlow::DataType>{}(x.data_type) + 0x9e3779b9 + + (result << 6) + (result >> 2); + return result; +} +} // namespace std + +namespace nlohmann { +FlexFlow::EmbeddingAttrs + adl_serializer::from_json(json const &j) { + return {j.at("num_entries").template get(), + j.at("out_channels").template get(), + j.at("aggr").template get>(), + j.at("data_type").template get<::FlexFlow::DataType>()}; +} +void adl_serializer::to_json( + json &j, FlexFlow::EmbeddingAttrs const &v) { + j["__type"] = "EmbeddingAttrs"; + j["num_entries"] = v.num_entries; + j["out_channels"] = v.out_channels; + j["aggr"] = v.aggr; + j["data_type"] = v.data_type; +} +} // namespace nlohmann + +namespace rc { +Gen Arbitrary::arbitrary() { + return gen::construct( + gen::arbitrary(), + 
gen::arbitrary(), + gen::arbitrary>(), + gen::arbitrary<::FlexFlow::DataType>()); +} +} // namespace rc + +namespace FlexFlow { +std::string format_as(EmbeddingAttrs const &x) { + std::ostringstream oss; + oss << ""; + return oss.str(); +} +std::ostream &operator<<(std::ostream &s, EmbeddingAttrs const &x) { + return s << fmt::to_string(x); +} +} // namespace FlexFlow diff --git a/lib/op-attrs/src/flat.cc b/lib/op-attrs/src/op-attrs/ops/flat.cc similarity index 94% rename from lib/op-attrs/src/flat.cc rename to lib/op-attrs/src/op-attrs/ops/flat.cc index 75d31beae4..b0683c5f08 100644 --- a/lib/op-attrs/src/flat.cc +++ b/lib/op-attrs/src/op-attrs/ops/flat.cc @@ -1,6 +1,4 @@ #include "op-attrs/ops/flat.h" -#include "parallel_dim_mapping_record.h" -#include "parallel_dim_mapping_record_solver.h" #include namespace FlexFlow { diff --git a/lib/op-attrs/src/op-attrs/ops/flat_attrs.dtg.cc b/lib/op-attrs/src/op-attrs/ops/flat_attrs.dtg.cc new file mode 100644 index 0000000000..ef34d97a89 --- /dev/null +++ b/lib/op-attrs/src/op-attrs/ops/flat_attrs.dtg.cc @@ -0,0 +1,70 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/op-attrs/include/op-attrs/ops/flat_attrs.struct.toml +/* proj-data +{ + "generated_from": "b63924cd671481df30fae314a199c606" +} +*/ + +#include "op-attrs/ops/flat_attrs.dtg.h" + +#include + +namespace FlexFlow { +bool FlatAttrs::operator==(FlatAttrs const &other) const { + return std::tie() == std::tie(); +} +bool FlatAttrs::operator!=(FlatAttrs const &other) const { + return std::tie() != std::tie(); +} +bool FlatAttrs::operator<(FlatAttrs const &other) const { + return std::tie() < std::tie(); +} +bool FlatAttrs::operator>(FlatAttrs const &other) const { + return std::tie() > std::tie(); +} +bool FlatAttrs::operator<=(FlatAttrs const &other) const { + return std::tie() <= std::tie(); +} +bool FlatAttrs::operator>=(FlatAttrs const &other) const { + return std::tie() >= std::tie(); +} +} // namespace FlexFlow + +namespace std { +size_t + hash::operator()(FlexFlow::FlatAttrs const &x) const { + size_t result = 0; + return result; +} +} // namespace std + +namespace nlohmann { +FlexFlow::FlatAttrs + adl_serializer::from_json(json const &j) { + return {}; +} +void adl_serializer::to_json( + json &j, FlexFlow::FlatAttrs const &v) { + j["__type"] = "FlatAttrs"; +} +} // namespace nlohmann + +namespace rc { +Gen Arbitrary::arbitrary() { + return gen::construct(); +} +} // namespace rc + +namespace FlexFlow { +std::string format_as(FlatAttrs const &x) { + std::ostringstream oss; + oss << ""; + return oss.str(); +} +std::ostream &operator<<(std::ostream &s, FlatAttrs const &x) { + return s << fmt::to_string(x); +} +} // namespace FlexFlow diff --git a/lib/op-attrs/src/gather.cc b/lib/op-attrs/src/op-attrs/ops/gather.cc similarity index 100% rename from lib/op-attrs/src/gather.cc rename to lib/op-attrs/src/op-attrs/ops/gather.cc diff --git a/lib/op-attrs/src/op-attrs/ops/gather_attrs.dtg.cc b/lib/op-attrs/src/op-attrs/ops/gather_attrs.dtg.cc new file mode 100644 index 0000000000..713c0f391e --- /dev/null +++ b/lib/op-attrs/src/op-attrs/ops/gather_attrs.dtg.cc @@ -0,0 +1,78 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! 
+// If you would like to modify this datatype, instead modify +// lib/op-attrs/include/op-attrs/ops/gather_attrs.struct.toml +/* proj-data +{ + "generated_from": "4ba46b6b494a7a52edda437d2a05fcf1" +} +*/ + +#include "op-attrs/ops/gather_attrs.dtg.h" + +#include "op-attrs/ff_dim.dtg.h" +#include "op-attrs/ff_dim.h" +#include + +namespace FlexFlow { +GatherAttrs::GatherAttrs(::FlexFlow::ff_dim_t const &dim) : dim(dim) {} +bool GatherAttrs::operator==(GatherAttrs const &other) const { + return std::tie(this->dim) == std::tie(other.dim); +} +bool GatherAttrs::operator!=(GatherAttrs const &other) const { + return std::tie(this->dim) != std::tie(other.dim); +} +bool GatherAttrs::operator<(GatherAttrs const &other) const { + return std::tie(this->dim) < std::tie(other.dim); +} +bool GatherAttrs::operator>(GatherAttrs const &other) const { + return std::tie(this->dim) > std::tie(other.dim); +} +bool GatherAttrs::operator<=(GatherAttrs const &other) const { + return std::tie(this->dim) <= std::tie(other.dim); +} +bool GatherAttrs::operator>=(GatherAttrs const &other) const { + return std::tie(this->dim) >= std::tie(other.dim); +} +} // namespace FlexFlow + +namespace std { +size_t hash::operator()( + FlexFlow::GatherAttrs const &x) const { + size_t result = 0; + result ^= std::hash<::FlexFlow::ff_dim_t>{}(x.dim) + 0x9e3779b9 + + (result << 6) + (result >> 2); + return result; +} +} // namespace std + +namespace nlohmann { +FlexFlow::GatherAttrs + adl_serializer::from_json(json const &j) { + return {j.at("dim").template get<::FlexFlow::ff_dim_t>()}; +} +void adl_serializer::to_json( + json &j, FlexFlow::GatherAttrs const &v) { + j["__type"] = "GatherAttrs"; + j["dim"] = v.dim; +} +} // namespace nlohmann + +namespace rc { +Gen Arbitrary::arbitrary() { + return gen::construct( + gen::arbitrary<::FlexFlow::ff_dim_t>()); +} +} // namespace rc + +namespace FlexFlow { +std::string format_as(GatherAttrs const &x) { + std::ostringstream oss; + oss << ""; + return oss.str(); +} +std::ostream &operator<<(std::ostream &s, GatherAttrs const &x) { + return s << fmt::to_string(x); +} +} // namespace FlexFlow diff --git a/lib/op-attrs/src/op-attrs/ops/input.cc b/lib/op-attrs/src/op-attrs/ops/input.cc new file mode 100644 index 0000000000..93606b603a --- /dev/null +++ b/lib/op-attrs/src/op-attrs/ops/input.cc @@ -0,0 +1,9 @@ +#include "op-attrs/ops/input.h" + +namespace FlexFlow { + +ParallelTensorShape get_output_shape(InputAttrs const &) { + NOT_IMPLEMENTED(); +} + +} // namespace FlexFlow diff --git a/lib/op-attrs/src/op-attrs/ops/input_attrs.dtg.cc b/lib/op-attrs/src/op-attrs/ops/input_attrs.dtg.cc new file mode 100644 index 0000000000..35544402f7 --- /dev/null +++ b/lib/op-attrs/src/op-attrs/ops/input_attrs.dtg.cc @@ -0,0 +1,70 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! 
+// If you would like to modify this datatype, instead modify +// lib/op-attrs/include/op-attrs/ops/input_attrs.struct.toml +/* proj-data +{ + "generated_from": "139ea46d57a3c8738b31b17a8c59a0aa" +} +*/ + +#include "op-attrs/ops/input_attrs.dtg.h" + +#include + +namespace FlexFlow { +bool InputAttrs::operator==(InputAttrs const &other) const { + return std::tie() == std::tie(); +} +bool InputAttrs::operator!=(InputAttrs const &other) const { + return std::tie() != std::tie(); +} +bool InputAttrs::operator<(InputAttrs const &other) const { + return std::tie() < std::tie(); +} +bool InputAttrs::operator>(InputAttrs const &other) const { + return std::tie() > std::tie(); +} +bool InputAttrs::operator<=(InputAttrs const &other) const { + return std::tie() <= std::tie(); +} +bool InputAttrs::operator>=(InputAttrs const &other) const { + return std::tie() >= std::tie(); +} +} // namespace FlexFlow + +namespace std { +size_t hash::operator()( + FlexFlow::InputAttrs const &x) const { + size_t result = 0; + return result; +} +} // namespace std + +namespace nlohmann { +FlexFlow::InputAttrs + adl_serializer::from_json(json const &j) { + return {}; +} +void adl_serializer::to_json( + json &j, FlexFlow::InputAttrs const &v) { + j["__type"] = "InputAttrs"; +} +} // namespace nlohmann + +namespace rc { +Gen Arbitrary::arbitrary() { + return gen::construct(); +} +} // namespace rc + +namespace FlexFlow { +std::string format_as(InputAttrs const &x) { + std::ostringstream oss; + oss << ""; + return oss.str(); +} +std::ostream &operator<<(std::ostream &s, InputAttrs const &x) { + return s << fmt::to_string(x); +} +} // namespace FlexFlow diff --git a/lib/op-attrs/src/op-attrs/ops/layer_norm.cc b/lib/op-attrs/src/op-attrs/ops/layer_norm.cc new file mode 100644 index 0000000000..437ba3638a --- /dev/null +++ b/lib/op-attrs/src/op-attrs/ops/layer_norm.cc @@ -0,0 +1,10 @@ +#include "op-attrs/ops/layer_norm.h" + +namespace FlexFlow { + +ParallelTensorShape get_output_shape(LayerNormAttrs const &, + ParallelTensorShape const &) { + NOT_IMPLEMENTED(); +} + +} // namespace FlexFlow diff --git a/lib/op-attrs/src/op-attrs/ops/layer_norm_attrs.dtg.cc b/lib/op-attrs/src/op-attrs/ops/layer_norm_attrs.dtg.cc new file mode 100644 index 0000000000..163f2e2f91 --- /dev/null +++ b/lib/op-attrs/src/op-attrs/ops/layer_norm_attrs.dtg.cc @@ -0,0 +1,108 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! 
+// If you would like to modify this datatype, instead modify +// lib/op-attrs/include/op-attrs/ops/layer_norm_attrs.struct.toml +/* proj-data +{ + "generated_from": "349deae8d9356d3eeacd7e7d069c3155" +} +*/ + +#include "op-attrs/ops/layer_norm_attrs.dtg.h" + +#include "op-attrs/ff_dim.dtg.h" +#include "op-attrs/ff_dim.h" +#include "utils/stack_vector.h" +#include + +namespace FlexFlow { +LayerNormAttrs::LayerNormAttrs( + ::FlexFlow::stack_vector<::FlexFlow::ff_dim_t, MAX_TENSOR_DIM> const &axes, + bool const &elementwise_affine, + float const &eps) + : axes(axes), elementwise_affine(elementwise_affine), eps(eps) {} +bool LayerNormAttrs::operator==(LayerNormAttrs const &other) const { + return std::tie(this->axes, this->elementwise_affine, this->eps) == + std::tie(other.axes, other.elementwise_affine, other.eps); +} +bool LayerNormAttrs::operator!=(LayerNormAttrs const &other) const { + return std::tie(this->axes, this->elementwise_affine, this->eps) != + std::tie(other.axes, other.elementwise_affine, other.eps); +} +bool LayerNormAttrs::operator<(LayerNormAttrs const &other) const { + return std::tie(this->axes, this->elementwise_affine, this->eps) < + std::tie(other.axes, other.elementwise_affine, other.eps); +} +bool LayerNormAttrs::operator>(LayerNormAttrs const &other) const { + return std::tie(this->axes, this->elementwise_affine, this->eps) > + std::tie(other.axes, other.elementwise_affine, other.eps); +} +bool LayerNormAttrs::operator<=(LayerNormAttrs const &other) const { + return std::tie(this->axes, this->elementwise_affine, this->eps) <= + std::tie(other.axes, other.elementwise_affine, other.eps); +} +bool LayerNormAttrs::operator>=(LayerNormAttrs const &other) const { + return std::tie(this->axes, this->elementwise_affine, this->eps) >= + std::tie(other.axes, other.elementwise_affine, other.eps); +} +} // namespace FlexFlow + +namespace std { +size_t hash::operator()( + FlexFlow::LayerNormAttrs const &x) const { + size_t result = 0; + result ^= + std::hash< + ::FlexFlow::stack_vector<::FlexFlow::ff_dim_t, MAX_TENSOR_DIM>>{}( + x.axes) + + 0x9e3779b9 + (result << 6) + (result >> 2); + result ^= std::hash{}(x.elementwise_affine) + 0x9e3779b9 + + (result << 6) + (result >> 2); + result ^= + std::hash{}(x.eps) + 0x9e3779b9 + (result << 6) + (result >> 2); + return result; +} +} // namespace std + +namespace nlohmann { +FlexFlow::LayerNormAttrs + adl_serializer::from_json(json const &j) { + return { + j.at("axes") + .template get< + ::FlexFlow::stack_vector<::FlexFlow::ff_dim_t, MAX_TENSOR_DIM>>(), + j.at("elementwise_affine").template get(), + j.at("eps").template get()}; +} +void adl_serializer::to_json( + json &j, FlexFlow::LayerNormAttrs const &v) { + j["__type"] = "LayerNormAttrs"; + j["axes"] = v.axes; + j["elementwise_affine"] = v.elementwise_affine; + j["eps"] = v.eps; +} +} // namespace nlohmann + +namespace rc { +Gen Arbitrary::arbitrary() { + return gen::construct( + gen::arbitrary< + ::FlexFlow::stack_vector<::FlexFlow::ff_dim_t, MAX_TENSOR_DIM>>(), + gen::arbitrary(), + gen::arbitrary()); +} +} // namespace rc + +namespace FlexFlow { +std::string format_as(LayerNormAttrs const &x) { + std::ostringstream oss; + oss << ""; + return oss.str(); +} +std::ostream &operator<<(std::ostream &s, LayerNormAttrs const &x) { + return s << fmt::to_string(x); +} +} // namespace FlexFlow diff --git a/lib/op-attrs/src/op-attrs/ops/linear.cc b/lib/op-attrs/src/op-attrs/ops/linear.cc new file mode 100644 index 0000000000..8283673378 --- /dev/null +++ 
b/lib/op-attrs/src/op-attrs/ops/linear.cc
@@ -0,0 +1,110 @@
+#include "op-attrs/ops/linear.h"
+#include "op-attrs/dim_ordered/slice.h"
+#include "op-attrs/dim_ordered/transform.h"
+#include "op-attrs/parallel_tensor_shape.h"
+#include "op-attrs/tensor_shape.h"
+#include "utils/integer_conversions.h"
+
+namespace FlexFlow {
+
+tl::expected<TensorShape, std::string>
+    get_kernel_shape(LinearAttrs const &attrs, TensorShape const &input_shape) {
+  size_t in_channels = dim_at_idx(input_shape, ff_dim_t{-1});
+
+  return TensorShape{
+      TensorDims{
+          FFOrdered<size_t>{in_channels, size_t_from_int(attrs.out_channels)},
+      },
+      input_shape.data_type,
+  };
+}
+
+tl::expected<TensorShape, std::string>
+    get_bias_shape(LinearAttrs const &attrs, TensorShape const &input_shape) {
+  return TensorShape{
+      TensorDims{
+          FFOrdered<size_t>{size_t_from_int(attrs.out_channels)},
+      },
+      input_shape.data_type,
+  };
+}
+
+tl::expected<TensorShape, std::string>
+    get_output_shape(LinearAttrs const &attrs, TensorShape const &input_shape) {
+  TensorShape output_shape = input_shape;
+  output_shape.dims.ff_ordered.at(ff_dim_t{-1}) =
+      size_t_from_int(attrs.out_channels);
+
+  return output_shape;
+}
+
+tl::expected<ParallelTensorShape, std::string>
+    get_kernel_shape(LinearAttrs const &attrs,
+                     ParallelTensorShape const &input) {
+  TensorShape unpar = ({
+    tl::expected<TensorShape, std::string> result_unpar =
+        get_kernel_shape(attrs, get_reduced_shape(input));
+    if (!result_unpar.has_value()) {
+      return tl::unexpected(result_unpar.error());
+    }
+    result_unpar.value();
+  });
+
+  SumDegree sum_degree = SumDegree{1};
+  DiscardCopyDegree discard_copy_degree = DiscardCopyDegree{
+      get_sum_degree(input) *
+      product(
+          slice(ff_ordered_shard_degrees(input), std::nullopt, ff_dim_t{-1}))};
+  FFOrdered<int> shard_degrees = FFOrdered<int>{
+      shard_dim_at_idx(input, ff_dim_t{-1}).degree,
+      get_discard_copy_degree(input),
+  };
+
+  return lift_to_parallel_with_degrees(
+      unpar, sum_degree, discard_copy_degree, shard_degrees);
+}
+
+tl::expected<ParallelTensorShape, std::string>
+    get_bias_shape(LinearAttrs const &attrs, ParallelTensorShape const &input) {
+  TensorShape unpar = ({
+    tl::expected<TensorShape, std::string> result_unpar =
+        get_bias_shape(attrs, get_reduced_shape(input));
+    if (!result_unpar.has_value()) {
+      return tl::unexpected(result_unpar.error());
+    }
+    result_unpar.value();
+  });
+
+  SumDegree sum_degree = SumDegree{
+      get_sum_degree(input) * shard_dim_at_idx(input, ff_dim_t{-1}).degree};
+  DiscardCopyDegree discard_copy_degree = DiscardCopyDegree{product(
+      slice(ff_ordered_shard_degrees(input), std::nullopt, ff_dim_t{-1}))};
+  FFOrdered<int> shard_degrees =
+      FFOrdered<int>{get_discard_copy_degree(input)};
+
+  return lift_to_parallel_with_degrees(
+      unpar, sum_degree, discard_copy_degree, shard_degrees);
+}
+
+tl::expected<ParallelTensorShape, std::string>
+    get_output_shape(LinearAttrs const &attrs,
+                     ParallelTensorShape const &input) {
+  TensorShape unpar = ({
+    tl::expected<TensorShape, std::string> result_unpar =
+        get_output_shape(attrs, get_reduced_shape(input));
+    if (!result_unpar.has_value()) {
+      return tl::unexpected(result_unpar.error());
+    }
+    result_unpar.value();
+  });
+
+  SumDegree sum_degree = SumDegree{
+      get_sum_degree(input) * shard_dim_at_idx(input, ff_dim_t{-1}).degree};
+  DiscardCopyDegree discard_copy_degree = DiscardCopyDegree{1};
+  FFOrdered<int> shard_degrees = ff_ordered_shard_degrees(input);
+  shard_degrees.at(ff_dim_t{-1}) = get_discard_copy_degree(input);
+
+  return lift_to_parallel_with_degrees(
+      unpar, sum_degree, discard_copy_degree, shard_degrees);
+}
+
+} // namespace FlexFlow
diff --git a/lib/op-attrs/src/op-attrs/ops/linear_attrs.dtg.cc b/lib/op-attrs/src/op-attrs/ops/linear_attrs.dtg.cc
new file mode 100644
index 0000000000..f3359da219
--- /dev/null
+++ b/lib/op-attrs/src/op-attrs/ops/linear_attrs.dtg.cc
@@ -0,0 +1,162 @@
+// THIS FILE WAS
AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/op-attrs/include/op-attrs/ops/linear_attrs.struct.toml +/* proj-data +{ + "generated_from": "7e82d282f90e08f1e0db7d5c4ce528b7" +} +*/ + +#include "op-attrs/ops/linear_attrs.dtg.h" + +#include "op-attrs/activation.dtg.h" +#include "op-attrs/datatype.dtg.h" +#include "op-attrs/regularizer_attrs.dtg.h" +#include "utils/json.h" +#include + +namespace FlexFlow { +LinearAttrs::LinearAttrs( + int const &out_channels, + bool const &use_bias, + ::FlexFlow::DataType const &data_type, + std::optional<::FlexFlow::Activation> const &activation, + std::optional<::FlexFlow::RegularizerAttrs> const ®ularizer) + : out_channels(out_channels), use_bias(use_bias), data_type(data_type), + activation(activation), regularizer(regularizer) {} +bool LinearAttrs::operator==(LinearAttrs const &other) const { + return std::tie(this->out_channels, + this->use_bias, + this->data_type, + this->activation, + this->regularizer) == std::tie(other.out_channels, + other.use_bias, + other.data_type, + other.activation, + other.regularizer); +} +bool LinearAttrs::operator!=(LinearAttrs const &other) const { + return std::tie(this->out_channels, + this->use_bias, + this->data_type, + this->activation, + this->regularizer) != std::tie(other.out_channels, + other.use_bias, + other.data_type, + other.activation, + other.regularizer); +} +bool LinearAttrs::operator<(LinearAttrs const &other) const { + return std::tie(this->out_channels, + this->use_bias, + this->data_type, + this->activation, + this->regularizer) < std::tie(other.out_channels, + other.use_bias, + other.data_type, + other.activation, + other.regularizer); +} +bool LinearAttrs::operator>(LinearAttrs const &other) const { + return std::tie(this->out_channels, + this->use_bias, + this->data_type, + this->activation, + this->regularizer) > std::tie(other.out_channels, + other.use_bias, + other.data_type, + other.activation, + other.regularizer); +} +bool LinearAttrs::operator<=(LinearAttrs const &other) const { + return std::tie(this->out_channels, + this->use_bias, + this->data_type, + this->activation, + this->regularizer) <= std::tie(other.out_channels, + other.use_bias, + other.data_type, + other.activation, + other.regularizer); +} +bool LinearAttrs::operator>=(LinearAttrs const &other) const { + return std::tie(this->out_channels, + this->use_bias, + this->data_type, + this->activation, + this->regularizer) >= std::tie(other.out_channels, + other.use_bias, + other.data_type, + other.activation, + other.regularizer); +} +} // namespace FlexFlow + +namespace std { +size_t hash::operator()( + FlexFlow::LinearAttrs const &x) const { + size_t result = 0; + result ^= std::hash{}(x.out_channels) + 0x9e3779b9 + (result << 6) + + (result >> 2); + result ^= std::hash{}(x.use_bias) + 0x9e3779b9 + (result << 6) + + (result >> 2); + result ^= std::hash<::FlexFlow::DataType>{}(x.data_type) + 0x9e3779b9 + + (result << 6) + (result >> 2); + result ^= std::hash>{}(x.activation) + + 0x9e3779b9 + (result << 6) + (result >> 2); + result ^= + std::hash>{}(x.regularizer) + + 0x9e3779b9 + (result << 6) + (result >> 2); + return result; +} +} // namespace std + +namespace nlohmann { +FlexFlow::LinearAttrs + adl_serializer::from_json(json const &j) { + return { + j.at("out_channels").template get(), + j.at("use_bias").template get(), + j.at("data_type").template get<::FlexFlow::DataType>(), + j.at("activation").template get>(), + j.at("regularizer") + .template get>()}; +} +void 
adl_serializer::to_json( + json &j, FlexFlow::LinearAttrs const &v) { + j["__type"] = "LinearAttrs"; + j["out_channels"] = v.out_channels; + j["use_bias"] = v.use_bias; + j["data_type"] = v.data_type; + j["activation"] = v.activation; + j["regularizer"] = v.regularizer; +} +} // namespace nlohmann + +namespace rc { +Gen Arbitrary::arbitrary() { + return gen::construct( + gen::arbitrary(), + gen::arbitrary(), + gen::arbitrary<::FlexFlow::DataType>(), + gen::arbitrary>(), + gen::arbitrary>()); +} +} // namespace rc + +namespace FlexFlow { +std::string format_as(LinearAttrs const &x) { + std::ostringstream oss; + oss << ""; + return oss.str(); +} +std::ostream &operator<<(std::ostream &s, LinearAttrs const &x) { + return s << fmt::to_string(x); +} +} // namespace FlexFlow diff --git a/lib/op-attrs/src/op-attrs/ops/noop.cc b/lib/op-attrs/src/op-attrs/ops/noop.cc new file mode 100644 index 0000000000..b2b15d820c --- /dev/null +++ b/lib/op-attrs/src/op-attrs/ops/noop.cc @@ -0,0 +1,10 @@ +#include "op-attrs/ops/noop.h" + +namespace FlexFlow { + +ParallelTensorShape get_output_shape(NoopAttrs const &, + ParallelTensorShape const &input_shape) { + return input_shape; +} + +} // namespace FlexFlow diff --git a/lib/op-attrs/src/op-attrs/ops/noop_attrs.dtg.cc b/lib/op-attrs/src/op-attrs/ops/noop_attrs.dtg.cc new file mode 100644 index 0000000000..3ef3a0119b --- /dev/null +++ b/lib/op-attrs/src/op-attrs/ops/noop_attrs.dtg.cc @@ -0,0 +1,70 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/op-attrs/include/op-attrs/ops/noop_attrs.struct.toml +/* proj-data +{ + "generated_from": "d440077aa598fdad0e5aa95288b63c40" +} +*/ + +#include "op-attrs/ops/noop_attrs.dtg.h" + +#include + +namespace FlexFlow { +bool NoopAttrs::operator==(NoopAttrs const &other) const { + return std::tie() == std::tie(); +} +bool NoopAttrs::operator!=(NoopAttrs const &other) const { + return std::tie() != std::tie(); +} +bool NoopAttrs::operator<(NoopAttrs const &other) const { + return std::tie() < std::tie(); +} +bool NoopAttrs::operator>(NoopAttrs const &other) const { + return std::tie() > std::tie(); +} +bool NoopAttrs::operator<=(NoopAttrs const &other) const { + return std::tie() <= std::tie(); +} +bool NoopAttrs::operator>=(NoopAttrs const &other) const { + return std::tie() >= std::tie(); +} +} // namespace FlexFlow + +namespace std { +size_t + hash::operator()(FlexFlow::NoopAttrs const &x) const { + size_t result = 0; + return result; +} +} // namespace std + +namespace nlohmann { +FlexFlow::NoopAttrs + adl_serializer::from_json(json const &j) { + return {}; +} +void adl_serializer::to_json( + json &j, FlexFlow::NoopAttrs const &v) { + j["__type"] = "NoopAttrs"; +} +} // namespace nlohmann + +namespace rc { +Gen Arbitrary::arbitrary() { + return gen::construct(); +} +} // namespace rc + +namespace FlexFlow { +std::string format_as(NoopAttrs const &x) { + std::ostringstream oss; + oss << ""; + return oss.str(); +} +std::ostream &operator<<(std::ostream &s, NoopAttrs const &x) { + return s << fmt::to_string(x); +} +} // namespace FlexFlow diff --git a/lib/op-attrs/src/op-attrs/ops/parallel_attention_inputs.dtg.cc b/lib/op-attrs/src/op-attrs/ops/parallel_attention_inputs.dtg.cc new file mode 100644 index 0000000000..ac8da6d2d7 --- /dev/null +++ b/lib/op-attrs/src/op-attrs/ops/parallel_attention_inputs.dtg.cc @@ -0,0 +1,88 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! 
+// If you would like to modify this datatype, instead modify +// lib/op-attrs/include/op-attrs/ops/parallel_attention_inputs.struct.toml +/* proj-data +{ + "generated_from": "b76a39763275090d8376e1c27668d2cb" +} +*/ + +#include "op-attrs/ops/parallel_attention_inputs.dtg.h" + +#include "op-attrs/parallel_tensor_shape.h" +#include + +namespace FlexFlow { +ParallelMultiHeadAttentionInputs::ParallelMultiHeadAttentionInputs( + ::FlexFlow::ParallelTensorShape const &query, + ::FlexFlow::ParallelTensorShape const &key, + ::FlexFlow::ParallelTensorShape const &value) + : query(query), key(key), value(value) {} +bool ParallelMultiHeadAttentionInputs::operator==( + ParallelMultiHeadAttentionInputs const &other) const { + return std::tie(this->query, this->key, this->value) == + std::tie(other.query, other.key, other.value); +} +bool ParallelMultiHeadAttentionInputs::operator!=( + ParallelMultiHeadAttentionInputs const &other) const { + return std::tie(this->query, this->key, this->value) != + std::tie(other.query, other.key, other.value); +} +} // namespace FlexFlow + +namespace std { +size_t hash::operator()( + FlexFlow::ParallelMultiHeadAttentionInputs const &x) const { + size_t result = 0; + result ^= std::hash<::FlexFlow::ParallelTensorShape>{}(x.query) + 0x9e3779b9 + + (result << 6) + (result >> 2); + result ^= std::hash<::FlexFlow::ParallelTensorShape>{}(x.key) + 0x9e3779b9 + + (result << 6) + (result >> 2); + result ^= std::hash<::FlexFlow::ParallelTensorShape>{}(x.value) + 0x9e3779b9 + + (result << 6) + (result >> 2); + return result; +} +} // namespace std + +namespace nlohmann { +FlexFlow::ParallelMultiHeadAttentionInputs + adl_serializer::from_json( + json const &j) { + return {j.at("query").template get<::FlexFlow::ParallelTensorShape>(), + j.at("key").template get<::FlexFlow::ParallelTensorShape>(), + j.at("value").template get<::FlexFlow::ParallelTensorShape>()}; +} +void adl_serializer::to_json( + json &j, FlexFlow::ParallelMultiHeadAttentionInputs const &v) { + j["__type"] = "ParallelMultiHeadAttentionInputs"; + j["query"] = v.query; + j["key"] = v.key; + j["value"] = v.value; +} +} // namespace nlohmann + +namespace rc { +Gen + Arbitrary::arbitrary() { + return gen::construct( + gen::arbitrary<::FlexFlow::ParallelTensorShape>(), + gen::arbitrary<::FlexFlow::ParallelTensorShape>(), + gen::arbitrary<::FlexFlow::ParallelTensorShape>()); +} +} // namespace rc + +namespace FlexFlow { +std::string format_as(ParallelMultiHeadAttentionInputs const &x) { + std::ostringstream oss; + oss << ""; + return oss.str(); +} +std::ostream &operator<<(std::ostream &s, + ParallelMultiHeadAttentionInputs const &x) { + return s << fmt::to_string(x); +} +} // namespace FlexFlow diff --git a/lib/op-attrs/src/pool_2d.cc b/lib/op-attrs/src/op-attrs/ops/pool_2d.cc similarity index 65% rename from lib/op-attrs/src/pool_2d.cc rename to lib/op-attrs/src/op-attrs/ops/pool_2d.cc index 0867aeb344..cf6ed177d3 100644 --- a/lib/op-attrs/src/pool_2d.cc +++ b/lib/op-attrs/src/op-attrs/ops/pool_2d.cc @@ -1,4 +1,16 @@ #include "op-attrs/ops/pool_2d.h" + +namespace FlexFlow { + +ParallelTensorShape get_output_shape(Pool2DAttrs const &, + ParallelTensorShape const &) { + NOT_IMPLEMENTED(); +} + +} // namespace FlexFlow + +/* +#include "op-attrs/ops/pool_2d.h" #include "parallel_dim_mapping_record.h" #include "parallel_dim_mapping_record_solver.h" @@ -14,12 +26,11 @@ constexpr int NUMDIM = 5, WIDTH = 0, HEIGHT = 1, CHANNEL = 2, SAMPLE = 3, REPLICA = 4; }; -/* bool Pool2DAttrs::is_valid(ParallelTensorShape const &input) 
const { */ -/* ParallelTensorShape output_shape = this->calculate_output_shape(input); */ +bool Pool2DAttrs::is_valid(ParallelTensorShape const &input) const { + ParallelTensorShape output_shape = this->calculate_output_shape(input); -/* return output_shape.is_valid() && (input.at(Input::REPLICA).degree == 1); - */ -/* } */ + return output_shape.is_valid() && (input.at(Input::REPLICA).degree == 1); +} static std::vector construct_mappings(ParallelTensorShape const &input_shape) { @@ -39,9 +50,9 @@ static ParallelDimMappingSolution return solve_parallel_dim_mappings(construct_mappings(input), {input}, 0, 1); } -/* ParallelTensorShape Pool2DAttrs::calculate_output_shape(ParallelTensorShape - * const &input) const { */ -/* return solve_mappings(input).output_shapes.at(0); */ -/* } */ +ParallelTensorShape Pool2DAttrs::calculate_output_shape(ParallelTensorShape +const &input) const { return solve_mappings(input).output_shapes.at(0); +} } // namespace FlexFlow +*/ diff --git a/lib/op-attrs/src/op-attrs/ops/pool_2d_attrs.dtg.cc b/lib/op-attrs/src/op-attrs/ops/pool_2d_attrs.dtg.cc new file mode 100644 index 0000000000..8c445d8b84 --- /dev/null +++ b/lib/op-attrs/src/op-attrs/ops/pool_2d_attrs.dtg.cc @@ -0,0 +1,214 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/op-attrs/include/op-attrs/ops/pool_2d_attrs.struct.toml +/* proj-data +{ + "generated_from": "03aeafe335f68ff831e3e73a77f45caf" +} +*/ + +#include "op-attrs/ops/pool_2d_attrs.dtg.h" + +#include "op-attrs/activation.dtg.h" +#include "op-attrs/pool_op.dtg.h" +#include + +namespace FlexFlow { +Pool2DAttrs::Pool2DAttrs(int const &kernel_h, + int const &kernel_w, + int const &stride_h, + int const &stride_w, + int const &padding_h, + int const &padding_w, + ::FlexFlow::PoolOp const &pool_type, + ::FlexFlow::Activation const &activation) + : kernel_h(kernel_h), kernel_w(kernel_w), stride_h(stride_h), + stride_w(stride_w), padding_h(padding_h), padding_w(padding_w), + pool_type(pool_type), activation(activation) {} +bool Pool2DAttrs::operator==(Pool2DAttrs const &other) const { + return std::tie(this->kernel_h, + this->kernel_w, + this->stride_h, + this->stride_w, + this->padding_h, + this->padding_w, + this->pool_type, + this->activation) == std::tie(other.kernel_h, + other.kernel_w, + other.stride_h, + other.stride_w, + other.padding_h, + other.padding_w, + other.pool_type, + other.activation); +} +bool Pool2DAttrs::operator!=(Pool2DAttrs const &other) const { + return std::tie(this->kernel_h, + this->kernel_w, + this->stride_h, + this->stride_w, + this->padding_h, + this->padding_w, + this->pool_type, + this->activation) != std::tie(other.kernel_h, + other.kernel_w, + other.stride_h, + other.stride_w, + other.padding_h, + other.padding_w, + other.pool_type, + other.activation); +} +bool Pool2DAttrs::operator<(Pool2DAttrs const &other) const { + return std::tie(this->kernel_h, + this->kernel_w, + this->stride_h, + this->stride_w, + this->padding_h, + this->padding_w, + this->pool_type, + this->activation) < std::tie(other.kernel_h, + other.kernel_w, + other.stride_h, + other.stride_w, + other.padding_h, + other.padding_w, + other.pool_type, + other.activation); +} +bool Pool2DAttrs::operator>(Pool2DAttrs const &other) const { + return std::tie(this->kernel_h, + this->kernel_w, + this->stride_h, + this->stride_w, + this->padding_h, + this->padding_w, + this->pool_type, + this->activation) > std::tie(other.kernel_h, + other.kernel_w, + other.stride_h, + 
other.stride_w, + other.padding_h, + other.padding_w, + other.pool_type, + other.activation); +} +bool Pool2DAttrs::operator<=(Pool2DAttrs const &other) const { + return std::tie(this->kernel_h, + this->kernel_w, + this->stride_h, + this->stride_w, + this->padding_h, + this->padding_w, + this->pool_type, + this->activation) <= std::tie(other.kernel_h, + other.kernel_w, + other.stride_h, + other.stride_w, + other.padding_h, + other.padding_w, + other.pool_type, + other.activation); +} +bool Pool2DAttrs::operator>=(Pool2DAttrs const &other) const { + return std::tie(this->kernel_h, + this->kernel_w, + this->stride_h, + this->stride_w, + this->padding_h, + this->padding_w, + this->pool_type, + this->activation) >= std::tie(other.kernel_h, + other.kernel_w, + other.stride_h, + other.stride_w, + other.padding_h, + other.padding_w, + other.pool_type, + other.activation); +} +} // namespace FlexFlow + +namespace std { +size_t hash::operator()( + FlexFlow::Pool2DAttrs const &x) const { + size_t result = 0; + result ^= + std::hash{}(x.kernel_h) + 0x9e3779b9 + (result << 6) + (result >> 2); + result ^= + std::hash{}(x.kernel_w) + 0x9e3779b9 + (result << 6) + (result >> 2); + result ^= + std::hash{}(x.stride_h) + 0x9e3779b9 + (result << 6) + (result >> 2); + result ^= + std::hash{}(x.stride_w) + 0x9e3779b9 + (result << 6) + (result >> 2); + result ^= std::hash{}(x.padding_h) + 0x9e3779b9 + (result << 6) + + (result >> 2); + result ^= std::hash{}(x.padding_w) + 0x9e3779b9 + (result << 6) + + (result >> 2); + result ^= std::hash<::FlexFlow::PoolOp>{}(x.pool_type) + 0x9e3779b9 + + (result << 6) + (result >> 2); + result ^= std::hash<::FlexFlow::Activation>{}(x.activation) + 0x9e3779b9 + + (result << 6) + (result >> 2); + return result; +} +} // namespace std + +namespace nlohmann { +FlexFlow::Pool2DAttrs + adl_serializer::from_json(json const &j) { + return {j.at("kernel_h").template get(), + j.at("kernel_w").template get(), + j.at("stride_h").template get(), + j.at("stride_w").template get(), + j.at("padding_h").template get(), + j.at("padding_w").template get(), + j.at("pool_type").template get<::FlexFlow::PoolOp>(), + j.at("activation").template get<::FlexFlow::Activation>()}; +} +void adl_serializer::to_json( + json &j, FlexFlow::Pool2DAttrs const &v) { + j["__type"] = "Pool2DAttrs"; + j["kernel_h"] = v.kernel_h; + j["kernel_w"] = v.kernel_w; + j["stride_h"] = v.stride_h; + j["stride_w"] = v.stride_w; + j["padding_h"] = v.padding_h; + j["padding_w"] = v.padding_w; + j["pool_type"] = v.pool_type; + j["activation"] = v.activation; +} +} // namespace nlohmann + +namespace rc { +Gen Arbitrary::arbitrary() { + return gen::construct( + gen::arbitrary(), + gen::arbitrary(), + gen::arbitrary(), + gen::arbitrary(), + gen::arbitrary(), + gen::arbitrary(), + gen::arbitrary<::FlexFlow::PoolOp>(), + gen::arbitrary<::FlexFlow::Activation>()); +} +} // namespace rc + +namespace FlexFlow { +std::string format_as(Pool2DAttrs const &x) { + std::ostringstream oss; + oss << ""; + return oss.str(); +} +std::ostream &operator<<(std::ostream &s, Pool2DAttrs const &x) { + return s << fmt::to_string(x); +} +} // namespace FlexFlow diff --git a/lib/op-attrs/src/op-attrs/ops/reduce.cc b/lib/op-attrs/src/op-attrs/ops/reduce.cc new file mode 100644 index 0000000000..2a8bf06ecf --- /dev/null +++ b/lib/op-attrs/src/op-attrs/ops/reduce.cc @@ -0,0 +1,10 @@ +#include "op-attrs/ops/reduce.h" + +namespace FlexFlow { + +ParallelTensorShape get_output_shape(ReduceAttrs const &, + ParallelTensorShape const &input_shape) { + 
NOT_IMPLEMENTED(); +} + +} // namespace FlexFlow diff --git a/lib/op-attrs/src/op-attrs/ops/reduce_attrs.dtg.cc b/lib/op-attrs/src/op-attrs/ops/reduce_attrs.dtg.cc new file mode 100644 index 0000000000..2aa9546956 --- /dev/null +++ b/lib/op-attrs/src/op-attrs/ops/reduce_attrs.dtg.cc @@ -0,0 +1,109 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/op-attrs/include/op-attrs/ops/reduce_attrs.struct.toml +/* proj-data +{ + "generated_from": "097463446e254f662c7bdf5df4e12d17" +} +*/ + +#include "op-attrs/ops/reduce_attrs.dtg.h" + +#include "op-attrs/ff_dim.dtg.h" +#include "op-attrs/ff_dim.h" +#include "op-attrs/operator_type.dtg.h" +#include "utils/stack_vector.h" +#include + +namespace FlexFlow { +ReduceAttrs::ReduceAttrs( + ::FlexFlow::stack_vector<::FlexFlow::ff_dim_t, MAX_TENSOR_DIM> const &axes, + ::FlexFlow::OperatorType const &op_type, + bool const &keepdims) + : axes(axes), op_type(op_type), keepdims(keepdims) {} +bool ReduceAttrs::operator==(ReduceAttrs const &other) const { + return std::tie(this->axes, this->op_type, this->keepdims) == + std::tie(other.axes, other.op_type, other.keepdims); +} +bool ReduceAttrs::operator!=(ReduceAttrs const &other) const { + return std::tie(this->axes, this->op_type, this->keepdims) != + std::tie(other.axes, other.op_type, other.keepdims); +} +bool ReduceAttrs::operator<(ReduceAttrs const &other) const { + return std::tie(this->axes, this->op_type, this->keepdims) < + std::tie(other.axes, other.op_type, other.keepdims); +} +bool ReduceAttrs::operator>(ReduceAttrs const &other) const { + return std::tie(this->axes, this->op_type, this->keepdims) > + std::tie(other.axes, other.op_type, other.keepdims); +} +bool ReduceAttrs::operator<=(ReduceAttrs const &other) const { + return std::tie(this->axes, this->op_type, this->keepdims) <= + std::tie(other.axes, other.op_type, other.keepdims); +} +bool ReduceAttrs::operator>=(ReduceAttrs const &other) const { + return std::tie(this->axes, this->op_type, this->keepdims) >= + std::tie(other.axes, other.op_type, other.keepdims); +} +} // namespace FlexFlow + +namespace std { +size_t hash::operator()( + FlexFlow::ReduceAttrs const &x) const { + size_t result = 0; + result ^= + std::hash< + ::FlexFlow::stack_vector<::FlexFlow::ff_dim_t, MAX_TENSOR_DIM>>{}( + x.axes) + + 0x9e3779b9 + (result << 6) + (result >> 2); + result ^= std::hash<::FlexFlow::OperatorType>{}(x.op_type) + 0x9e3779b9 + + (result << 6) + (result >> 2); + result ^= std::hash{}(x.keepdims) + 0x9e3779b9 + (result << 6) + + (result >> 2); + return result; +} +} // namespace std + +namespace nlohmann { +FlexFlow::ReduceAttrs + adl_serializer::from_json(json const &j) { + return { + j.at("axes") + .template get< + ::FlexFlow::stack_vector<::FlexFlow::ff_dim_t, MAX_TENSOR_DIM>>(), + j.at("op_type").template get<::FlexFlow::OperatorType>(), + j.at("keepdims").template get()}; +} +void adl_serializer::to_json( + json &j, FlexFlow::ReduceAttrs const &v) { + j["__type"] = "ReduceAttrs"; + j["axes"] = v.axes; + j["op_type"] = v.op_type; + j["keepdims"] = v.keepdims; +} +} // namespace nlohmann + +namespace rc { +Gen Arbitrary::arbitrary() { + return gen::construct( + gen::arbitrary< + ::FlexFlow::stack_vector<::FlexFlow::ff_dim_t, MAX_TENSOR_DIM>>(), + gen::arbitrary<::FlexFlow::OperatorType>(), + gen::arbitrary()); +} +} // namespace rc + +namespace FlexFlow { +std::string format_as(ReduceAttrs const &x) { + std::ostringstream oss; + oss << ""; + return oss.str(); +} 
+std::ostream &operator<<(std::ostream &s, ReduceAttrs const &x) {
+  return s << fmt::to_string(x);
+}
+} // namespace FlexFlow
diff --git a/lib/op-attrs/src/op-attrs/ops/reduction.cc b/lib/op-attrs/src/op-attrs/ops/reduction.cc
new file mode 100644
index 0000000000..0fef6f37d6
--- /dev/null
+++ b/lib/op-attrs/src/op-attrs/ops/reduction.cc
@@ -0,0 +1,22 @@
+#include "op-attrs/ops/reduction.h"
+#include "op-attrs/parallel_tensor_shape.h"
+
+namespace FlexFlow {
+
+tl::expected<ParallelTensorShape, std::string>
+    get_output_shape(ReductionAttrs const &attrs,
+                     ParallelTensorShape const &input_shape) {
+  if (get_sum_degree(input_shape) % attrs.reduction_degree != 0) {
+    return tl::unexpected(
+        fmt::format("Reduction received tensor with sum degree {}, which is "
+                    "not divisible by reduction degree {}",
+                    get_sum_degree(input_shape),
+                    attrs.reduction_degree));
+  }
+
+  ParallelTensorShape output_shape = input_shape;
+  output_shape.dims.replica_dims.sum_degree.value /= attrs.reduction_degree;
+  return output_shape;
+}
+
+} // namespace FlexFlow
diff --git a/lib/op-attrs/src/op-attrs/ops/reduction_attrs.dtg.cc b/lib/op-attrs/src/op-attrs/ops/reduction_attrs.dtg.cc
new file mode 100644
index 0000000000..2f1550bb66
--- /dev/null
+++ b/lib/op-attrs/src/op-attrs/ops/reduction_attrs.dtg.cc
@@ -0,0 +1,76 @@
+// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT!
+// If you would like to modify this datatype, instead modify
+// lib/op-attrs/include/op-attrs/ops/reduction_attrs.struct.toml
+/* proj-data
+{
+  "generated_from": "1d2b5b7cf11ed04a27a6fd8215e4e2a5"
+}
+*/
+
+#include "op-attrs/ops/reduction_attrs.dtg.h"
+
+#include <sstream>
+
+namespace FlexFlow {
+ReductionAttrs::ReductionAttrs(int const &reduction_degree)
+    : reduction_degree(reduction_degree) {}
+bool ReductionAttrs::operator==(ReductionAttrs const &other) const {
+  return std::tie(this->reduction_degree) == std::tie(other.reduction_degree);
+}
+bool ReductionAttrs::operator!=(ReductionAttrs const &other) const {
+  return std::tie(this->reduction_degree) != std::tie(other.reduction_degree);
+}
+bool ReductionAttrs::operator<(ReductionAttrs const &other) const {
+  return std::tie(this->reduction_degree) < std::tie(other.reduction_degree);
+}
+bool ReductionAttrs::operator>(ReductionAttrs const &other) const {
+  return std::tie(this->reduction_degree) > std::tie(other.reduction_degree);
+}
+bool ReductionAttrs::operator<=(ReductionAttrs const &other) const {
+  return std::tie(this->reduction_degree) <= std::tie(other.reduction_degree);
+}
+bool ReductionAttrs::operator>=(ReductionAttrs const &other) const {
+  return std::tie(this->reduction_degree) >= std::tie(other.reduction_degree);
+}
+} // namespace FlexFlow
+
+namespace std {
+size_t hash<FlexFlow::ReductionAttrs>::operator()(
+    FlexFlow::ReductionAttrs const &x) const {
+  size_t result = 0;
+  result ^= std::hash<int>{}(x.reduction_degree) + 0x9e3779b9 + (result << 6) +
+            (result >> 2);
+  return result;
+}
+} // namespace std
+
+namespace nlohmann {
+FlexFlow::ReductionAttrs
+    adl_serializer<FlexFlow::ReductionAttrs>::from_json(json const &j) {
+  return {j.at("reduction_degree").template get<int>()};
+}
+void adl_serializer<FlexFlow::ReductionAttrs>::to_json(
+    json &j, FlexFlow::ReductionAttrs const &v) {
+  j["__type"] = "ReductionAttrs";
+  j["reduction_degree"] = v.reduction_degree;
+}
+} // namespace nlohmann
+
+namespace rc {
+Gen<FlexFlow::ReductionAttrs> Arbitrary<FlexFlow::ReductionAttrs>::arbitrary() {
+  return gen::construct<FlexFlow::ReductionAttrs>(gen::arbitrary<int>());
+}
+} // namespace rc
+
+namespace FlexFlow {
+std::string format_as(ReductionAttrs const &x) {
+  std::ostringstream oss;
+  oss << "<ReductionAttrs";
+  oss << " reduction_degree=" << x.reduction_degree;
+  oss << ">";
+  return oss.str();
+}
+std::ostream &operator<<(std::ostream &s,
ReductionAttrs const &x) { + return s << fmt::to_string(x); +} +} // namespace FlexFlow diff --git a/lib/op-attrs/src/op-attrs/ops/repartition.cc b/lib/op-attrs/src/op-attrs/ops/repartition.cc new file mode 100644 index 0000000000..37a0b8a168 --- /dev/null +++ b/lib/op-attrs/src/op-attrs/ops/repartition.cc @@ -0,0 +1,14 @@ +#include "op-attrs/ops/repartition.h" + +namespace FlexFlow { + +tl::expected + get_output_shape(RepartitionAttrs const &attrs, + ParallelTensorShape const &input_shape) { + ParallelTensorShape output_shape = input_shape; + output_shape.dims.shard_dims.at(attrs.repartition_dim).degree *= + attrs.repartition_degree; + return output_shape; +} + +} // namespace FlexFlow diff --git a/lib/op-attrs/src/op-attrs/ops/repartition_attrs.dtg.cc b/lib/op-attrs/src/op-attrs/ops/repartition_attrs.dtg.cc new file mode 100644 index 0000000000..6270298c87 --- /dev/null +++ b/lib/op-attrs/src/op-attrs/ops/repartition_attrs.dtg.cc @@ -0,0 +1,93 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/op-attrs/include/op-attrs/ops/repartition_attrs.struct.toml +/* proj-data +{ + "generated_from": "0a4d8b435768ce3ee37013fc550c9ebb" +} +*/ + +#include "op-attrs/ops/repartition_attrs.dtg.h" + +#include "op-attrs/ff_dim.dtg.h" +#include "op-attrs/ff_dim.h" +#include + +namespace FlexFlow { +RepartitionAttrs::RepartitionAttrs(::FlexFlow::ff_dim_t const &repartition_dim, + int const &repartition_degree) + : repartition_dim(repartition_dim), repartition_degree(repartition_degree) { +} +bool RepartitionAttrs::operator==(RepartitionAttrs const &other) const { + return std::tie(this->repartition_dim, this->repartition_degree) == + std::tie(other.repartition_dim, other.repartition_degree); +} +bool RepartitionAttrs::operator!=(RepartitionAttrs const &other) const { + return std::tie(this->repartition_dim, this->repartition_degree) != + std::tie(other.repartition_dim, other.repartition_degree); +} +bool RepartitionAttrs::operator<(RepartitionAttrs const &other) const { + return std::tie(this->repartition_dim, this->repartition_degree) < + std::tie(other.repartition_dim, other.repartition_degree); +} +bool RepartitionAttrs::operator>(RepartitionAttrs const &other) const { + return std::tie(this->repartition_dim, this->repartition_degree) > + std::tie(other.repartition_dim, other.repartition_degree); +} +bool RepartitionAttrs::operator<=(RepartitionAttrs const &other) const { + return std::tie(this->repartition_dim, this->repartition_degree) <= + std::tie(other.repartition_dim, other.repartition_degree); +} +bool RepartitionAttrs::operator>=(RepartitionAttrs const &other) const { + return std::tie(this->repartition_dim, this->repartition_degree) >= + std::tie(other.repartition_dim, other.repartition_degree); +} +} // namespace FlexFlow + +namespace std { +size_t hash::operator()( + FlexFlow::RepartitionAttrs const &x) const { + size_t result = 0; + result ^= std::hash<::FlexFlow::ff_dim_t>{}(x.repartition_dim) + 0x9e3779b9 + + (result << 6) + (result >> 2); + result ^= std::hash{}(x.repartition_degree) + 0x9e3779b9 + + (result << 6) + (result >> 2); + return result; +} +} // namespace std + +namespace nlohmann { +FlexFlow::RepartitionAttrs + adl_serializer::from_json(json const &j) { + return {j.at("repartition_dim").template get<::FlexFlow::ff_dim_t>(), + j.at("repartition_degree").template get()}; +} +void adl_serializer::to_json( + json &j, FlexFlow::RepartitionAttrs const &v) { + j["__type"] = "RepartitionAttrs"; + 
j["repartition_dim"] = v.repartition_dim; + j["repartition_degree"] = v.repartition_degree; +} +} // namespace nlohmann + +namespace rc { +Gen + Arbitrary::arbitrary() { + return gen::construct( + gen::arbitrary<::FlexFlow::ff_dim_t>(), gen::arbitrary()); +} +} // namespace rc + +namespace FlexFlow { +std::string format_as(RepartitionAttrs const &x) { + std::ostringstream oss; + oss << ""; + return oss.str(); +} +std::ostream &operator<<(std::ostream &s, RepartitionAttrs const &x) { + return s << fmt::to_string(x); +} +} // namespace FlexFlow diff --git a/lib/op-attrs/src/op-attrs/ops/replicate.cc b/lib/op-attrs/src/op-attrs/ops/replicate.cc new file mode 100644 index 0000000000..9e163cb55a --- /dev/null +++ b/lib/op-attrs/src/op-attrs/ops/replicate.cc @@ -0,0 +1,13 @@ +#include "op-attrs/ops/replicate.h" + +namespace FlexFlow { + +ParallelTensorShape get_output_shape(ReplicateAttrs const &attrs, + ParallelTensorShape const &input_shape) { + ParallelTensorShape output_shape = input_shape; + output_shape.dims.replica_dims.discard_copy_degree.value *= + attrs.replicate_degree; + return output_shape; +} + +} // namespace FlexFlow diff --git a/lib/op-attrs/src/op-attrs/ops/replicate_attrs.dtg.cc b/lib/op-attrs/src/op-attrs/ops/replicate_attrs.dtg.cc new file mode 100644 index 0000000000..930c5beaf4 --- /dev/null +++ b/lib/op-attrs/src/op-attrs/ops/replicate_attrs.dtg.cc @@ -0,0 +1,76 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/op-attrs/include/op-attrs/ops/replicate_attrs.struct.toml +/* proj-data +{ + "generated_from": "6d3ad4d10c24dae819ffee4592a72499" +} +*/ + +#include "op-attrs/ops/replicate_attrs.dtg.h" + +#include + +namespace FlexFlow { +ReplicateAttrs::ReplicateAttrs(int const &replicate_degree) + : replicate_degree(replicate_degree) {} +bool ReplicateAttrs::operator==(ReplicateAttrs const &other) const { + return std::tie(this->replicate_degree) == std::tie(other.replicate_degree); +} +bool ReplicateAttrs::operator!=(ReplicateAttrs const &other) const { + return std::tie(this->replicate_degree) != std::tie(other.replicate_degree); +} +bool ReplicateAttrs::operator<(ReplicateAttrs const &other) const { + return std::tie(this->replicate_degree) < std::tie(other.replicate_degree); +} +bool ReplicateAttrs::operator>(ReplicateAttrs const &other) const { + return std::tie(this->replicate_degree) > std::tie(other.replicate_degree); +} +bool ReplicateAttrs::operator<=(ReplicateAttrs const &other) const { + return std::tie(this->replicate_degree) <= std::tie(other.replicate_degree); +} +bool ReplicateAttrs::operator>=(ReplicateAttrs const &other) const { + return std::tie(this->replicate_degree) >= std::tie(other.replicate_degree); +} +} // namespace FlexFlow + +namespace std { +size_t hash::operator()( + FlexFlow::ReplicateAttrs const &x) const { + size_t result = 0; + result ^= std::hash{}(x.replicate_degree) + 0x9e3779b9 + (result << 6) + + (result >> 2); + return result; +} +} // namespace std + +namespace nlohmann { +FlexFlow::ReplicateAttrs + adl_serializer::from_json(json const &j) { + return {j.at("replicate_degree").template get()}; +} +void adl_serializer::to_json( + json &j, FlexFlow::ReplicateAttrs const &v) { + j["__type"] = "ReplicateAttrs"; + j["replicate_degree"] = v.replicate_degree; +} +} // namespace nlohmann + +namespace rc { +Gen Arbitrary::arbitrary() { + return gen::construct(gen::arbitrary()); +} +} // namespace rc + +namespace FlexFlow { +std::string format_as(ReplicateAttrs const 
&x) { + std::ostringstream oss; + oss << ""; + return oss.str(); +} +std::ostream &operator<<(std::ostream &s, ReplicateAttrs const &x) { + return s << fmt::to_string(x); +} +} // namespace FlexFlow diff --git a/lib/op-attrs/src/op-attrs/ops/reshape.cc b/lib/op-attrs/src/op-attrs/ops/reshape.cc new file mode 100644 index 0000000000..7d0600550a --- /dev/null +++ b/lib/op-attrs/src/op-attrs/ops/reshape.cc @@ -0,0 +1,10 @@ +#include "op-attrs/ops/reshape.h" + +namespace FlexFlow { + +ParallelTensorShape get_output_shape(ReshapeAttrs const &attrs, + ParallelTensorShape const &input_shape) { + NOT_IMPLEMENTED(); +} + +} // namespace FlexFlow diff --git a/lib/op-attrs/src/op-attrs/ops/reshape_attrs.dtg.cc b/lib/op-attrs/src/op-attrs/ops/reshape_attrs.dtg.cc new file mode 100644 index 0000000000..b1fb350b88 --- /dev/null +++ b/lib/op-attrs/src/op-attrs/ops/reshape_attrs.dtg.cc @@ -0,0 +1,78 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/op-attrs/include/op-attrs/ops/reshape_attrs.struct.toml +/* proj-data +{ + "generated_from": "015d04de0ccb982e7eaa013a842880ca" +} +*/ + +#include "op-attrs/ops/reshape_attrs.dtg.h" + +#include "op-attrs/tensor_shape.dtg.h" +#include + +namespace FlexFlow { +ReshapeAttrs::ReshapeAttrs(::FlexFlow::TensorShape const &shape) + : shape(shape) {} +bool ReshapeAttrs::operator==(ReshapeAttrs const &other) const { + return std::tie(this->shape) == std::tie(other.shape); +} +bool ReshapeAttrs::operator!=(ReshapeAttrs const &other) const { + return std::tie(this->shape) != std::tie(other.shape); +} +bool ReshapeAttrs::operator<(ReshapeAttrs const &other) const { + return std::tie(this->shape) < std::tie(other.shape); +} +bool ReshapeAttrs::operator>(ReshapeAttrs const &other) const { + return std::tie(this->shape) > std::tie(other.shape); +} +bool ReshapeAttrs::operator<=(ReshapeAttrs const &other) const { + return std::tie(this->shape) <= std::tie(other.shape); +} +bool ReshapeAttrs::operator>=(ReshapeAttrs const &other) const { + return std::tie(this->shape) >= std::tie(other.shape); +} +} // namespace FlexFlow + +namespace std { +size_t hash::operator()( + FlexFlow::ReshapeAttrs const &x) const { + size_t result = 0; + result ^= std::hash<::FlexFlow::TensorShape>{}(x.shape) + 0x9e3779b9 + + (result << 6) + (result >> 2); + return result; +} +} // namespace std + +namespace nlohmann { +FlexFlow::ReshapeAttrs + adl_serializer::from_json(json const &j) { + return {j.at("shape").template get<::FlexFlow::TensorShape>()}; +} +void adl_serializer::to_json( + json &j, FlexFlow::ReshapeAttrs const &v) { + j["__type"] = "ReshapeAttrs"; + j["shape"] = v.shape; +} +} // namespace nlohmann + +namespace rc { +Gen Arbitrary::arbitrary() { + return gen::construct( + gen::arbitrary<::FlexFlow::TensorShape>()); +} +} // namespace rc + +namespace FlexFlow { +std::string format_as(ReshapeAttrs const &x) { + std::ostringstream oss; + oss << ""; + return oss.str(); +} +std::ostream &operator<<(std::ostream &s, ReshapeAttrs const &x) { + return s << fmt::to_string(x); +} +} // namespace FlexFlow diff --git a/lib/op-attrs/src/op-attrs/ops/reverse.cc b/lib/op-attrs/src/op-attrs/ops/reverse.cc new file mode 100644 index 0000000000..79b5bd50fb --- /dev/null +++ b/lib/op-attrs/src/op-attrs/ops/reverse.cc @@ -0,0 +1,10 @@ +#include "op-attrs/ops/reverse.h" + +namespace FlexFlow { + +ParallelTensorShape get_output_shape(ReverseAttrs const &attrs, + ParallelTensorShape const &input_shape) { + NOT_IMPLEMENTED(); +} + 
+} // namespace FlexFlow
diff --git a/lib/op-attrs/src/op-attrs/ops/reverse_attrs.dtg.cc b/lib/op-attrs/src/op-attrs/ops/reverse_attrs.dtg.cc
new file mode 100644
index 0000000000..9ac9abeb82
--- /dev/null
+++ b/lib/op-attrs/src/op-attrs/ops/reverse_attrs.dtg.cc
@@ -0,0 +1,78 @@
+// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT!
+// If you would like to modify this datatype, instead modify
+// lib/op-attrs/include/op-attrs/ops/reverse_attrs.struct.toml
+/* proj-data
+{
+  "generated_from": "c5a82c8a15ac3ce6f47dc054236ab69b"
+}
+*/
+
+#include "op-attrs/ops/reverse_attrs.dtg.h"
+
+#include "op-attrs/ff_dim.dtg.h"
+#include "op-attrs/ff_dim.h"
+#include <sstream>
+
+namespace FlexFlow {
+ReverseAttrs::ReverseAttrs(::FlexFlow::ff_dim_t const &axis) : axis(axis) {}
+bool ReverseAttrs::operator==(ReverseAttrs const &other) const {
+  return std::tie(this->axis) == std::tie(other.axis);
+}
+bool ReverseAttrs::operator!=(ReverseAttrs const &other) const {
+  return std::tie(this->axis) != std::tie(other.axis);
+}
+bool ReverseAttrs::operator<(ReverseAttrs const &other) const {
+  return std::tie(this->axis) < std::tie(other.axis);
+}
+bool ReverseAttrs::operator>(ReverseAttrs const &other) const {
+  return std::tie(this->axis) > std::tie(other.axis);
+}
+bool ReverseAttrs::operator<=(ReverseAttrs const &other) const {
+  return std::tie(this->axis) <= std::tie(other.axis);
+}
+bool ReverseAttrs::operator>=(ReverseAttrs const &other) const {
+  return std::tie(this->axis) >= std::tie(other.axis);
+}
+} // namespace FlexFlow
+
+namespace std {
+size_t hash<FlexFlow::ReverseAttrs>::operator()(
+    FlexFlow::ReverseAttrs const &x) const {
+  size_t result = 0;
+  result ^= std::hash<::FlexFlow::ff_dim_t>{}(x.axis) + 0x9e3779b9 +
+            (result << 6) + (result >> 2);
+  return result;
+}
+} // namespace std
+
+namespace nlohmann {
+FlexFlow::ReverseAttrs
+    adl_serializer<FlexFlow::ReverseAttrs>::from_json(json const &j) {
+  return {j.at("axis").template get<::FlexFlow::ff_dim_t>()};
+}
+void adl_serializer<FlexFlow::ReverseAttrs>::to_json(
+    json &j, FlexFlow::ReverseAttrs const &v) {
+  j["__type"] = "ReverseAttrs";
+  j["axis"] = v.axis;
+}
+} // namespace nlohmann
+
+namespace rc {
+Gen<FlexFlow::ReverseAttrs> Arbitrary<FlexFlow::ReverseAttrs>::arbitrary() {
+  return gen::construct<FlexFlow::ReverseAttrs>(
+      gen::arbitrary<::FlexFlow::ff_dim_t>());
+}
+} // namespace rc
+
+namespace FlexFlow {
+std::string format_as(ReverseAttrs const &x) {
+  std::ostringstream oss;
+  oss << "<ReverseAttrs";
+  oss << " axis=" << x.axis;
+  oss << ">";
+  return oss.str();
+}
+std::ostream &operator<<(std::ostream &s, ReverseAttrs const &x) {
+  return s << fmt::to_string(x);
+}
+} // namespace FlexFlow
diff --git a/lib/op-attrs/src/op-attrs/ops/softmax.cc b/lib/op-attrs/src/op-attrs/ops/softmax.cc
new file mode 100644
index 0000000000..2d870af50e
--- /dev/null
+++ b/lib/op-attrs/src/op-attrs/ops/softmax.cc
@@ -0,0 +1,10 @@
+#include "op-attrs/ops/softmax.h"
+
+namespace FlexFlow {
+
+ParallelTensorShape get_output_shape(SoftmaxAttrs const &attrs,
+                                     ParallelTensorShape const &input_shape) {
+  NOT_IMPLEMENTED();
+}
+
+} // namespace FlexFlow
diff --git a/lib/op-attrs/src/op-attrs/ops/softmax_attrs.dtg.cc b/lib/op-attrs/src/op-attrs/ops/softmax_attrs.dtg.cc
new file mode 100644
index 0000000000..4941b7438a
--- /dev/null
+++ b/lib/op-attrs/src/op-attrs/ops/softmax_attrs.dtg.cc
@@ -0,0 +1,78 @@
+// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT!
+// If you would like to modify this datatype, instead modify
+// lib/op-attrs/include/op-attrs/ops/softmax_attrs.struct.toml
+/* proj-data
+{
+  "generated_from": "2ddf5a8b7daa32a43387f5fd5866bb3b"
+}
+*/
+
+#include "op-attrs/ops/softmax_attrs.dtg.h"
+
+#include "op-attrs/ff_dim.dtg.h"
+#include "op-attrs/ff_dim.h"
+#include <sstream>
+
+namespace FlexFlow {
+SoftmaxAttrs::SoftmaxAttrs(::FlexFlow::ff_dim_t const &dim) : dim(dim) {}
+bool SoftmaxAttrs::operator==(SoftmaxAttrs const &other) const {
+  return std::tie(this->dim) == std::tie(other.dim);
+}
+bool SoftmaxAttrs::operator!=(SoftmaxAttrs const &other) const {
+  return std::tie(this->dim) != std::tie(other.dim);
+}
+bool SoftmaxAttrs::operator<(SoftmaxAttrs const &other) const {
+  return std::tie(this->dim) < std::tie(other.dim);
+}
+bool SoftmaxAttrs::operator>(SoftmaxAttrs const &other) const {
+  return std::tie(this->dim) > std::tie(other.dim);
+}
+bool SoftmaxAttrs::operator<=(SoftmaxAttrs const &other) const {
+  return std::tie(this->dim) <= std::tie(other.dim);
+}
+bool SoftmaxAttrs::operator>=(SoftmaxAttrs const &other) const {
+  return std::tie(this->dim) >= std::tie(other.dim);
+}
+} // namespace FlexFlow
+
+namespace std {
+size_t hash<FlexFlow::SoftmaxAttrs>::operator()(
+    FlexFlow::SoftmaxAttrs const &x) const {
+  size_t result = 0;
+  result ^= std::hash<::FlexFlow::ff_dim_t>{}(x.dim) + 0x9e3779b9 +
+            (result << 6) + (result >> 2);
+  return result;
+}
+} // namespace std
+
+namespace nlohmann {
+FlexFlow::SoftmaxAttrs
+    adl_serializer<FlexFlow::SoftmaxAttrs>::from_json(json const &j) {
+  return {j.at("dim").template get<::FlexFlow::ff_dim_t>()};
+}
+void adl_serializer<FlexFlow::SoftmaxAttrs>::to_json(
+    json &j, FlexFlow::SoftmaxAttrs const &v) {
+  j["__type"] = "SoftmaxAttrs";
+  j["dim"] = v.dim;
+}
+} // namespace nlohmann
+
+namespace rc {
+Gen<FlexFlow::SoftmaxAttrs> Arbitrary<FlexFlow::SoftmaxAttrs>::arbitrary() {
+  return gen::construct<FlexFlow::SoftmaxAttrs>(
+      gen::arbitrary<::FlexFlow::ff_dim_t>());
+}
+} // namespace rc
+
+namespace FlexFlow {
+std::string format_as(SoftmaxAttrs const &x) {
+  std::ostringstream oss;
+  oss << "<SoftmaxAttrs";
+  oss << " dim=" << x.dim;
+  oss << ">";
+  return oss.str();
+}
+std::ostream &operator<<(std::ostream &s, SoftmaxAttrs const &x) {
+  return s << fmt::to_string(x);
+}
+} // namespace FlexFlow
diff --git a/lib/op-attrs/src/op-attrs/ops/split.cc b/lib/op-attrs/src/op-attrs/ops/split.cc
new file mode 100644
index 0000000000..cfb4071833
--- /dev/null
+++ b/lib/op-attrs/src/op-attrs/ops/split.cc
@@ -0,0 +1,11 @@
+#include "op-attrs/ops/split.h"
+
+namespace FlexFlow {
+
+std::vector<ParallelTensorShape>
+    get_output_shapes(SplitAttrs const &attrs,
+                      ParallelTensorShape const &input_shape) {
+  NOT_IMPLEMENTED();
+}
+
+} // namespace FlexFlow
diff --git a/lib/op-attrs/src/op-attrs/ops/split_attrs.dtg.cc b/lib/op-attrs/src/op-attrs/ops/split_attrs.dtg.cc
new file mode 100644
index 0000000000..c6f7e75dbf
--- /dev/null
+++ b/lib/op-attrs/src/op-attrs/ops/split_attrs.dtg.cc
@@ -0,0 +1,96 @@
+// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT!
+// If you would like to modify this datatype, instead modify
+// lib/op-attrs/include/op-attrs/ops/split_attrs.struct.toml
+/* proj-data
+{
+  "generated_from": "cde6b5caf6739d3b02fe8fce0d8ae8c5"
+}
+*/
+
+#include "op-attrs/ops/split_attrs.dtg.h"
+
+#include "op-attrs/ff_dim.dtg.h"
+#include "op-attrs/ff_dim.h"
+#include "utils/stack_vector.h"
+#include <sstream>
+
+namespace FlexFlow {
+SplitAttrs::SplitAttrs(
+    ::FlexFlow::stack_vector<int, MAX_NUM_OUTPUTS> const &splits,
+    ::FlexFlow::ff_dim_t const &axis)
+    : splits(splits), axis(axis) {}
+bool SplitAttrs::operator==(SplitAttrs const &other) const {
+  return std::tie(this->splits, this->axis) ==
+         std::tie(other.splits, other.axis);
+}
+bool SplitAttrs::operator!=(SplitAttrs const &other) const {
+  return std::tie(this->splits, this->axis) !=
+         std::tie(other.splits, other.axis);
+}
+bool SplitAttrs::operator<(SplitAttrs const &other) const {
+  return std::tie(this->splits, this->axis) <
+         std::tie(other.splits, other.axis);
+}
+bool SplitAttrs::operator>(SplitAttrs const &other) const {
+  return std::tie(this->splits, this->axis) >
+         std::tie(other.splits, other.axis);
+}
+bool SplitAttrs::operator<=(SplitAttrs const &other) const {
+  return std::tie(this->splits, this->axis) <=
+         std::tie(other.splits, other.axis);
+}
+bool SplitAttrs::operator>=(SplitAttrs const &other) const {
+  return std::tie(this->splits, this->axis) >=
+         std::tie(other.splits, other.axis);
+}
+} // namespace FlexFlow
+
+namespace std {
+size_t hash<FlexFlow::SplitAttrs>::operator()(
+    FlexFlow::SplitAttrs const &x) const {
+  size_t result = 0;
+  result ^=
+      std::hash<::FlexFlow::stack_vector<int, MAX_NUM_OUTPUTS>>{}(x.splits) +
+      0x9e3779b9 + (result << 6) + (result >> 2);
+  result ^= std::hash<::FlexFlow::ff_dim_t>{}(x.axis) + 0x9e3779b9 +
+            (result << 6) + (result >> 2);
+  return result;
+}
+} // namespace std
+
+namespace nlohmann {
+FlexFlow::SplitAttrs
+    adl_serializer<FlexFlow::SplitAttrs>::from_json(json const &j) {
+  return {j.at("splits")
+              .template get<::FlexFlow::stack_vector<int, MAX_NUM_OUTPUTS>>(),
+          j.at("axis").template get<::FlexFlow::ff_dim_t>()};
+}
+void adl_serializer<FlexFlow::SplitAttrs>::to_json(
+    json &j, FlexFlow::SplitAttrs const &v) {
+  j["__type"] = "SplitAttrs";
+  j["splits"] = v.splits;
+  j["axis"] = v.axis;
+}
+} // namespace nlohmann
+
+namespace rc {
+Gen<FlexFlow::SplitAttrs> Arbitrary<FlexFlow::SplitAttrs>::arbitrary() {
+  return gen::construct<FlexFlow::SplitAttrs>(
+      gen::arbitrary<::FlexFlow::stack_vector<int, MAX_NUM_OUTPUTS>>(),
+      gen::arbitrary<::FlexFlow::ff_dim_t>());
+}
+} // namespace rc
+
+namespace FlexFlow {
+std::string format_as(SplitAttrs const &x) {
+  std::ostringstream oss;
+  oss << "<SplitAttrs";
+  oss << " splits=" << x.splits;
+  oss << " axis=" << x.axis;
+  oss << ">";
+  return oss.str();
+}
+std::ostream &operator<<(std::ostream &s, SplitAttrs const &x) {
+  return s << fmt::to_string(x);
+}
+} // namespace FlexFlow
diff --git a/lib/op-attrs/src/op-attrs/ops/topk.cc b/lib/op-attrs/src/op-attrs/ops/topk.cc
new file mode 100644
index 0000000000..9d2fd35a94
--- /dev/null
+++ b/lib/op-attrs/src/op-attrs/ops/topk.cc
@@ -0,0 +1,10 @@
+#include "op-attrs/ops/topk.h"
+
+namespace FlexFlow {
+
+ParallelTensorShape get_output_shape(TopKAttrs const &attrs,
+                                     ParallelTensorShape const &input_shape) {
+  NOT_IMPLEMENTED();
+}
+
+} // namespace FlexFlow
diff --git a/lib/op-attrs/src/op-attrs/ops/topk_attrs.dtg.cc b/lib/op-attrs/src/op-attrs/ops/topk_attrs.dtg.cc
new file mode 100644
index 0000000000..55ead7d858
--- /dev/null
+++ b/lib/op-attrs/src/op-attrs/ops/topk_attrs.dtg.cc
@@ -0,0 +1,79 @@
+// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT!
+// If you would like to modify this datatype, instead modify
+// lib/op-attrs/include/op-attrs/ops/topk_attrs.struct.toml
+/* proj-data
+{
+  "generated_from": "c1be9dc2acafc58690713e650663cc93"
+}
+*/
+
+#include "op-attrs/ops/topk_attrs.dtg.h"
+
+#include <sstream>
+
+namespace FlexFlow {
+TopKAttrs::TopKAttrs(int const &k, bool const &sorted) : k(k), sorted(sorted) {}
+bool TopKAttrs::operator==(TopKAttrs const &other) const {
+  return std::tie(this->k, this->sorted) == std::tie(other.k, other.sorted);
+}
+bool TopKAttrs::operator!=(TopKAttrs const &other) const {
+  return std::tie(this->k, this->sorted) != std::tie(other.k, other.sorted);
+}
+bool TopKAttrs::operator<(TopKAttrs const &other) const {
+  return std::tie(this->k, this->sorted) < std::tie(other.k, other.sorted);
+}
+bool TopKAttrs::operator>(TopKAttrs const &other) const {
+  return std::tie(this->k, this->sorted) > std::tie(other.k, other.sorted);
+}
+bool TopKAttrs::operator<=(TopKAttrs const &other) const {
+  return std::tie(this->k, this->sorted) <= std::tie(other.k, other.sorted);
+}
+bool TopKAttrs::operator>=(TopKAttrs const &other) const {
+  return std::tie(this->k, this->sorted) >= std::tie(other.k, other.sorted);
+}
+} // namespace FlexFlow
+
+namespace std {
+size_t
+    hash<FlexFlow::TopKAttrs>::operator()(FlexFlow::TopKAttrs const &x) const {
+  size_t result = 0;
+  result ^= std::hash<int>{}(x.k) + 0x9e3779b9 + (result << 6) + (result >> 2);
+  result ^=
+      std::hash<bool>{}(x.sorted) + 0x9e3779b9 + (result << 6) + (result >> 2);
+  return result;
+}
+} // namespace std
+
+namespace nlohmann {
+FlexFlow::TopKAttrs
+    adl_serializer<FlexFlow::TopKAttrs>::from_json(json const &j) {
+  return {j.at("k").template get<int>(),
+          j.at("sorted").template get<bool>()};
+}
+void adl_serializer<FlexFlow::TopKAttrs>::to_json(
+    json &j, FlexFlow::TopKAttrs const &v) {
+  j["__type"] = "TopKAttrs";
+  j["k"] = v.k;
+  j["sorted"] = v.sorted;
+}
+} // namespace nlohmann
+
+namespace rc {
+Gen<FlexFlow::TopKAttrs> Arbitrary<FlexFlow::TopKAttrs>::arbitrary() {
+  return gen::construct<FlexFlow::TopKAttrs>(gen::arbitrary<int>(),
+                                             gen::arbitrary<bool>());
+}
+} // namespace rc
+
+namespace FlexFlow {
+std::string format_as(TopKAttrs const &x) {
+  std::ostringstream oss;
+  oss << "<TopKAttrs";
+  oss << " k=" << x.k;
+  oss << " sorted=" << x.sorted;
+  oss << ">";
+  return oss.str();
+}
+std::ostream &operator<<(std::ostream &s, TopKAttrs const &x) {
+  return s << fmt::to_string(x);
+}
+} // namespace FlexFlow
diff --git a/lib/op-attrs/src/op-attrs/ops/transpose.cc b/lib/op-attrs/src/op-attrs/ops/transpose.cc
new file mode 100644
index 0000000000..75f7eb3c18
--- /dev/null
+++ b/lib/op-attrs/src/op-attrs/ops/transpose.cc
@@ -0,0 +1,10 @@
+#include "op-attrs/ops/transpose.h"
+
+namespace FlexFlow {
+
+ParallelTensorShape get_output_shape(TransposeAttrs const &op_attrs,
+                                     ParallelTensorShape const &input_shape) {
+  NOT_IMPLEMENTED();
+}
+
+} // namespace FlexFlow
diff --git a/lib/op-attrs/src/op-attrs/ops/transpose_attrs.dtg.cc b/lib/op-attrs/src/op-attrs/ops/transpose_attrs.dtg.cc
new file mode 100644
index 0000000000..0a774b992e
--- /dev/null
+++ b/lib/op-attrs/src/op-attrs/ops/transpose_attrs.dtg.cc
@@ -0,0 +1,82 @@
+// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT!
+// If you would like to modify this datatype, instead modify
+// lib/op-attrs/include/op-attrs/ops/transpose_attrs.struct.toml
+/* proj-data
+{
+  "generated_from": "de62a505821a59c4b77197c100e204f7"
+}
+*/
+
+#include "op-attrs/ops/transpose_attrs.dtg.h"
+
+#include "op-attrs/dim_ordered.h"
+#include "op-attrs/ff_dim.dtg.h"
+#include "op-attrs/ff_dim.h"
+#include <sstream>
+
+namespace FlexFlow {
+TransposeAttrs::TransposeAttrs(
+    ::FlexFlow::FFOrdered<::FlexFlow::ff_dim_t> const &perm)
+    : perm(perm) {}
+bool TransposeAttrs::operator==(TransposeAttrs const &other) const {
+  return std::tie(this->perm) == std::tie(other.perm);
+}
+bool TransposeAttrs::operator!=(TransposeAttrs const &other) const {
+  return std::tie(this->perm) != std::tie(other.perm);
+}
+bool TransposeAttrs::operator<(TransposeAttrs const &other) const {
+  return std::tie(this->perm) < std::tie(other.perm);
+}
+bool TransposeAttrs::operator>(TransposeAttrs const &other) const {
+  return std::tie(this->perm) > std::tie(other.perm);
+}
+bool TransposeAttrs::operator<=(TransposeAttrs const &other) const {
+  return std::tie(this->perm) <= std::tie(other.perm);
+}
+bool TransposeAttrs::operator>=(TransposeAttrs const &other) const {
+  return std::tie(this->perm) >= std::tie(other.perm);
+}
+} // namespace FlexFlow
+
+namespace std {
+size_t hash<FlexFlow::TransposeAttrs>::operator()(
+    FlexFlow::TransposeAttrs const &x) const {
+  size_t result = 0;
+  result ^= std::hash<::FlexFlow::FFOrdered<::FlexFlow::ff_dim_t>>{}(x.perm) +
+            0x9e3779b9 + (result << 6) + (result >> 2);
+  return result;
+}
+} // namespace std
+
+namespace nlohmann {
+FlexFlow::TransposeAttrs
+    adl_serializer<FlexFlow::TransposeAttrs>::from_json(json const &j) {
+  return {
+      j.at("perm").template get<::FlexFlow::FFOrdered<::FlexFlow::ff_dim_t>>()};
+}
+void adl_serializer<FlexFlow::TransposeAttrs>::to_json(
+    json &j, FlexFlow::TransposeAttrs const &v) {
+  j["__type"] = "TransposeAttrs";
+  j["perm"] = v.perm;
+}
+} // namespace nlohmann
+
+namespace rc {
+Gen<FlexFlow::TransposeAttrs> Arbitrary<FlexFlow::TransposeAttrs>::arbitrary() {
+  return gen::construct<FlexFlow::TransposeAttrs>(
+      gen::arbitrary<::FlexFlow::FFOrdered<::FlexFlow::ff_dim_t>>());
+}
+} // namespace rc
+
+namespace FlexFlow {
+std::string format_as(TransposeAttrs const &x) {
+  std::ostringstream oss;
+  oss << "<TransposeAttrs";
+  oss << " perm=" << x.perm;
+  oss << ">";
+  return oss.str();
+}
+std::ostream &operator<<(std::ostream &s, TransposeAttrs const &x) {
+  return s << fmt::to_string(x);
+}
+} // namespace FlexFlow
diff --git a/lib/op-attrs/src/op-attrs/ops/weight_attrs.dtg.cc b/lib/op-attrs/src/op-attrs/ops/weight_attrs.dtg.cc
new file mode 100644
index 0000000000..a288161da2
--- /dev/null
+++ b/lib/op-attrs/src/op-attrs/ops/weight_attrs.dtg.cc
@@ -0,0 +1,70 @@
+// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT!
+// If you would like to modify this datatype, instead modify
+// lib/op-attrs/include/op-attrs/ops/weight_attrs.struct.toml
+/* proj-data
+{
+  "generated_from": "59f49374ffca95b2117b8940af1b6cac"
+}
+*/
+
+#include "op-attrs/ops/weight_attrs.dtg.h"
+
+#include <sstream>
+
+namespace FlexFlow {
+bool WeightAttrs::operator==(WeightAttrs const &other) const {
+  return std::tie() == std::tie();
+}
+bool WeightAttrs::operator!=(WeightAttrs const &other) const {
+  return std::tie() != std::tie();
+}
+bool WeightAttrs::operator<(WeightAttrs const &other) const {
+  return std::tie() < std::tie();
+}
+bool WeightAttrs::operator>(WeightAttrs const &other) const {
+  return std::tie() > std::tie();
+}
+bool WeightAttrs::operator<=(WeightAttrs const &other) const {
+  return std::tie() <= std::tie();
+}
+bool WeightAttrs::operator>=(WeightAttrs const &other) const {
+  return std::tie() >= std::tie();
+}
+} // namespace FlexFlow
+
+namespace std {
+size_t hash<FlexFlow::WeightAttrs>::operator()(
+    FlexFlow::WeightAttrs const &x) const {
+  size_t result = 0;
+  return result;
+}
+} // namespace std
+
+namespace nlohmann {
+FlexFlow::WeightAttrs
+    adl_serializer<FlexFlow::WeightAttrs>::from_json(json const &j) {
+  return {};
+}
+void adl_serializer<FlexFlow::WeightAttrs>::to_json(
+    json &j, FlexFlow::WeightAttrs const &v) {
+  j["__type"] = "WeightAttrs";
+}
+} // namespace nlohmann
+
+namespace rc {
+Gen<FlexFlow::WeightAttrs> Arbitrary<FlexFlow::WeightAttrs>::arbitrary() {
+  return gen::construct<FlexFlow::WeightAttrs>();
+}
+} // namespace rc
+
+namespace FlexFlow {
+std::string format_as(WeightAttrs const &x) {
+  std::ostringstream oss;
+  oss << "<WeightAttrs";
+  oss << ">";
+  return oss.str();
+}
+std::ostream &operator<<(std::ostream &s, WeightAttrs const &x) {
+  return s << fmt::to_string(x);
+}
+} // namespace FlexFlow
diff --git a/lib/op-attrs/src/op-attrs/parallel_dim.dtg.cc b/lib/op-attrs/src/op-attrs/parallel_dim.dtg.cc
new file mode 100644
index 0000000000..886893c90a
--- /dev/null
+++ b/lib/op-attrs/src/op-attrs/parallel_dim.dtg.cc
@@ -0,0 +1,116 @@
+// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT!
+// If you would like to modify this datatype, instead modify
+// lib/op-attrs/include/op-attrs/parallel_dim.variant.toml
+/* proj-data
+{
+  "generated_from": "f382ff547aae62777e5091f00d034d84"
+}
+*/
+
+#include "op-attrs/parallel_dim.dtg.h"
+
+#include "fmt/format.h"
+#include <sstream>
+#include <stdexcept>
+
+namespace FlexFlow {
+ParallelDim::ParallelDim(::FlexFlow::ShardParallelDim const &v)
+    : raw_variant(v) {}
+ParallelDim::ParallelDim(::FlexFlow::ReplicaParallelDim const &v)
+    : raw_variant(v) {}
+bool ParallelDim::operator==(ParallelDim const &other) const {
+  return this->raw_variant == other.raw_variant;
+}
+bool ParallelDim::operator!=(ParallelDim const &other) const {
+  return this->raw_variant != other.raw_variant;
+}
+bool ParallelDim::operator<(ParallelDim const &other) const {
+  return this->raw_variant < other.raw_variant;
+}
+bool ParallelDim::operator>(ParallelDim const &other) const {
+  return this->raw_variant > other.raw_variant;
+}
+bool ParallelDim::operator<=(ParallelDim const &other) const {
+  return this->raw_variant <= other.raw_variant;
+}
+bool ParallelDim::operator>=(ParallelDim const &other) const {
+  return this->raw_variant >= other.raw_variant;
+}
+} // namespace FlexFlow
+namespace std {
+size_t hash<::FlexFlow::ParallelDim>::operator()(
+    ::FlexFlow::ParallelDim const &x) const {
+  return std::hash<std::variant<::FlexFlow::ShardParallelDim,
+                                ::FlexFlow::ReplicaParallelDim>>{}(
+      x.raw_variant);
+}
+} // namespace std
+namespace nlohmann {
+::FlexFlow::ParallelDim
+    adl_serializer<::FlexFlow::ParallelDim>::from_json(json const &j) {
+  std::string key = j.at("type").template get<std::string>();
+  if (key == "shard_dim") {
+    return ::FlexFlow::ParallelDim{
+        j.at("value").template get<::FlexFlow::ShardParallelDim>()};
+  } else if (key == "replica_dim") {
+    return ::FlexFlow::ParallelDim{
+        j.at("value").template get<::FlexFlow::ReplicaParallelDim>()};
+  } else {
+    throw std::runtime_error(fmt::format("Unknown type key {}", key));
+  }
+}
+void adl_serializer<::FlexFlow::ParallelDim>::to_json(
+    json &j, ::FlexFlow::ParallelDim const &x) {
+  j["__type"] = "ParallelDim";
+  switch (x.index()) {
+    case 0: {
+      j["type"] = "shard_dim";
+      j["value"] = x.get<::FlexFlow::ShardParallelDim>();
+      break;
+    }
+    case 1: {
+      j["type"] = "replica_dim";
+      j["value"] = x.get<::FlexFlow::ReplicaParallelDim>();
+      break;
+    }
+    default: {
+      throw std::runtime_error(
+          fmt::format("Unknown index {} for type ParallelDim", x.index()));
+    }
+  }
+}
+} // namespace nlohmann
+namespace rc {
+Gen<::FlexFlow::ParallelDim> Arbitrary<::FlexFlow::ParallelDim>::arbitrary() {
+  return gen::oneOf(gen::construct<::FlexFlow::ParallelDim>(
+                        gen::arbitrary<::FlexFlow::ShardParallelDim>()),
+                    gen::construct<::FlexFlow::ParallelDim>(
+                        gen::arbitrary<::FlexFlow::ReplicaParallelDim>()));
+}
+} // namespace rc
+namespace FlexFlow {
+std::string format_as(::FlexFlow::ParallelDim const &x) {
+  std::ostringstream oss;
+  switch (x.index()) {
+    case 0: {
+      oss << "<ParallelDim " << x.get<::FlexFlow::ShardParallelDim>() << ">";
+      break;
+    }
+    case 1: {
+      oss << "<ParallelDim " << x.get<::FlexFlow::ReplicaParallelDim>() << ">";
+      break;
+    }
+    default: {
+      throw std::runtime_error(
+          fmt::format("Unknown index {} for type ParallelDim", x.index()));
+      break;
+    }
+  }
+  return oss.str();
+}
+std::ostream &operator<<(std::ostream &s, ::FlexFlow::ParallelDim const &x) {
+  return s << fmt::to_string(x);
+}
+} // namespace FlexFlow
diff --git a/lib/op-attrs/src/op-attrs/parallel_tensor_dims.cc b/lib/op-attrs/src/op-attrs/parallel_tensor_dims.cc
new file mode 100644
index 0000000000..ff5a8224df
--- /dev/null
+++ b/lib/op-attrs/src/op-attrs/parallel_tensor_dims.cc
@@ -0,0 +1,72 @@
+#include "op-attrs/parallel_tensor_dims.h"
"op-attrs/dim_ordered/transform.h" +#include "op-attrs/replica_parallel_dim.h" +#include "op-attrs/replica_parallel_dim_set.h" +#include "op-attrs/shard_parallel_dim.h" +#include "utils/containers.h" +#include "utils/integer_conversions.h" + +namespace FlexFlow { + +FFOrdered ff_ordered_shard_dims(ParallelTensorDims const &d) { + return d.shard_dims; +} + +FFOrdered ff_ordered_shard_degrees(ParallelTensorDims const &d) { + return transform(d.shard_dims, + [](ShardParallelDim const &d) { return d.degree; }); +} + +std::unordered_set + replica_dims(ParallelTensorDims const &d) { + return get_replica_dims(d.replica_dims); +} + +size_t num_shard_dims(ParallelTensorDims const &dims) { + return dims.shard_dims.size(); +} + +int total_replica_degree(ParallelTensorDims const &dims) { + return dims.replica_dims.discard_copy_degree.value * + dims.replica_dims.sum_degree.value; +} + +int total_shard_degree(ParallelTensorDims const &dims) { + return product(transform(as_vector(dims.shard_dims), + [](ShardParallelDim const &d) { return d.degree; })); +} + +int total_parallel_degree(ParallelTensorDims const &dims) { + return total_replica_degree(dims) * total_shard_degree(dims); +} + +bool is_valid(ParallelTensorDims const &dims) { + return all_of(dims.shard_dims, + [](ShardParallelDim const &d) { return is_valid(d); }) && + all_of(replica_dims(dims), + [](ReplicaParallelDim const &d) { return is_valid(d); }); +} + +ShardParallelDim shard_dim_at_idx(ParallelTensorDims const &d, ff_dim_t idx) { + return d.shard_dims.at(idx); +} + +ShardParallelDim &shard_dim_at_idx(ParallelTensorDims &d, ff_dim_t idx) { + return d.shard_dims.at(idx); +} + +TensorDims get_piece_dims(ParallelTensorDims const &) { + NOT_IMPLEMENTED(); +} + +TensorDims get_tensor_dims_unsafe(ParallelTensorDims const &) { + NOT_IMPLEMENTED(); +} + +TensorDims get_reduced_dims(ParallelTensorDims const &dims) { + FFOrdered dim_sizes = transform( + dims.shard_dims, [](ShardParallelDim const &d) { return d.size; }); + return TensorDims{dim_sizes}; +} + +} // namespace FlexFlow diff --git a/lib/op-attrs/src/op-attrs/parallel_tensor_dims.dtg.cc b/lib/op-attrs/src/op-attrs/parallel_tensor_dims.dtg.cc new file mode 100644 index 0000000000..40be73cb9f --- /dev/null +++ b/lib/op-attrs/src/op-attrs/parallel_tensor_dims.dtg.cc @@ -0,0 +1,101 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! 
+// If you would like to modify this datatype, instead modify
+// lib/op-attrs/include/op-attrs/parallel_tensor_dims.struct.toml
+/* proj-data
+{
+  "generated_from": "aec3b6b66e34be0d5ce3055822479430"
+}
+*/
+
+#include "op-attrs/parallel_tensor_dims.dtg.h"
+
+#include "op-attrs/dim_ordered.h"
+#include "op-attrs/replica_parallel_dim_set.dtg.h"
+#include "op-attrs/shard_parallel_dim.dtg.h"
+#include "utils/fmt/pair.h"
+#include "utils/fmt/unordered_map.h"
+#include <sstream>
+#include <unordered_map>
+
+namespace FlexFlow {
+ParallelTensorDims::ParallelTensorDims(
+    ::FlexFlow::FFOrdered<::FlexFlow::ShardParallelDim> const &shard_dims,
+    ::FlexFlow::ReplicaParallelDimSet const &replica_dims)
+    : shard_dims(shard_dims), replica_dims(replica_dims) {}
+bool ParallelTensorDims::operator==(ParallelTensorDims const &other) const {
+  return std::tie(this->shard_dims, this->replica_dims) ==
+         std::tie(other.shard_dims, other.replica_dims);
+}
+bool ParallelTensorDims::operator!=(ParallelTensorDims const &other) const {
+  return std::tie(this->shard_dims, this->replica_dims) !=
+         std::tie(other.shard_dims, other.replica_dims);
+}
+bool ParallelTensorDims::operator<(ParallelTensorDims const &other) const {
+  return std::tie(this->shard_dims, this->replica_dims) <
+         std::tie(other.shard_dims, other.replica_dims);
+}
+bool ParallelTensorDims::operator>(ParallelTensorDims const &other) const {
+  return std::tie(this->shard_dims, this->replica_dims) >
+         std::tie(other.shard_dims, other.replica_dims);
+}
+bool ParallelTensorDims::operator<=(ParallelTensorDims const &other) const {
+  return std::tie(this->shard_dims, this->replica_dims) <=
+         std::tie(other.shard_dims, other.replica_dims);
+}
+bool ParallelTensorDims::operator>=(ParallelTensorDims const &other) const {
+  return std::tie(this->shard_dims, this->replica_dims) >=
+         std::tie(other.shard_dims, other.replica_dims);
+}
+} // namespace FlexFlow
+
+namespace std {
+size_t hash<FlexFlow::ParallelTensorDims>::operator()(
+    FlexFlow::ParallelTensorDims const &x) const {
+  size_t result = 0;
+  result ^= std::hash<::FlexFlow::FFOrdered<::FlexFlow::ShardParallelDim>>{}(
+                x.shard_dims) +
+            0x9e3779b9 + (result << 6) + (result >> 2);
+  result ^= std::hash<::FlexFlow::ReplicaParallelDimSet>{}(x.replica_dims) +
+            0x9e3779b9 + (result << 6) + (result >> 2);
+  return result;
+}
+} // namespace std
+
+namespace nlohmann {
+FlexFlow::ParallelTensorDims
+    adl_serializer<FlexFlow::ParallelTensorDims>::from_json(json const &j) {
+  return {
+      j.at("shard_dims")
+          .template get<::FlexFlow::FFOrdered<::FlexFlow::ShardParallelDim>>(),
+      j.at("replica_dims").template get<::FlexFlow::ReplicaParallelDimSet>()};
+}
+void adl_serializer<FlexFlow::ParallelTensorDims>::to_json(
+    json &j, FlexFlow::ParallelTensorDims const &v) {
+  j["__type"] = "ParallelTensorDims";
+  j["shard_dims"] = v.shard_dims;
+  j["replica_dims"] = v.replica_dims;
+}
+} // namespace nlohmann
+
+namespace rc {
+Gen<FlexFlow::ParallelTensorDims>
+    Arbitrary<FlexFlow::ParallelTensorDims>::arbitrary() {
+  return gen::construct<FlexFlow::ParallelTensorDims>(
+      gen::arbitrary<::FlexFlow::FFOrdered<::FlexFlow::ShardParallelDim>>(),
+      gen::arbitrary<::FlexFlow::ReplicaParallelDimSet>());
+}
+} // namespace rc
+
+namespace FlexFlow {
+std::string format_as(ParallelTensorDims const &x) {
+  std::ostringstream oss;
+  oss << "<ParallelTensorDims";
+  oss << " shard_dims=" << x.shard_dims;
+  oss << " replica_dims=" << x.replica_dims;
+  oss << ">";
+  return oss.str();
+}
+std::ostream &operator<<(std::ostream &s, ParallelTensorDims const &x) {
+  return s << fmt::to_string(x);
+}
+} // namespace FlexFlow
diff --git a/lib/op-attrs/src/op-attrs/parallel_tensor_shape.cc b/lib/op-attrs/src/op-attrs/parallel_tensor_shape.cc
new file mode 100644
index 0000000000..516cbe191f
--- /dev/null
+++ b/lib/op-attrs/src/op-attrs/parallel_tensor_shape.cc
@@ -0,0 +1,87 @@
+#include "op-attrs/parallel_tensor_shape.h"
+#include "op-attrs/tensor_dims.h"
+#include "utils/containers.h"
+#include "utils/hash-utils.h"
+
+namespace FlexFlow {
+
+int num_shard_dims(ParallelTensorShape const &s) {
+  return num_shard_dims(s.dims);
+}
+
+std::unordered_set<ReplicaParallelDim>
+    replica_dims(ParallelTensorShape const &s) {
+  return replica_dims(s.dims);
+}
+
+int get_num_replicas(ParallelTensorShape const &shape) {
+  return product(
+      transform(replica_dims(shape),
+                [](ReplicaParallelDim const &d) -> int { return d.degree; }));
+}
+
+int get_sum_degree(ParallelTensorShape const &shape) {
+  return shape.dims.replica_dims.sum_degree.value;
+}
+
+int get_discard_copy_degree(ParallelTensorShape const &shape) {
+  return shape.dims.replica_dims.discard_copy_degree.value;
+}
+
+int get_total_parallel_degree(ParallelTensorShape const &s) {
+  return total_parallel_degree(s.dims);
+}
+
+bool is_valid(ParallelTensorShape const &shape) {
+  return is_valid(shape.dims);
+}
+
+ShardParallelDim shard_dim_at_idx(ParallelTensorShape const &s, ff_dim_t d) {
+  return shard_dim_at_idx(s.dims, d);
+}
+
+ShardParallelDim &shard_dim_at_idx(ParallelTensorShape &s, ff_dim_t d) {
+  return shard_dim_at_idx(s.dims, d);
+}
+
+FFOrdered<int> ff_ordered_shard_degrees(ParallelTensorShape const &s) {
+  return ff_ordered_shard_degrees(s.dims);
+}
+
+std::optional<ShardParallelDim>
+    try_get_shard_dim_at_idx(ParallelTensorShape const &s, ff_dim_t d) {
+  if (s.dims.shard_dims.idx_is_valid(d)) {
+    return s.dims.shard_dims.at(d);
+  } else {
+    return std::nullopt;
+  }
+}
+
+ParallelTensorShape lift_to_parallel(TensorShape const &s) {
+  return {lift_to_parallel(s.dims), s.data_type};
+}
+
+ParallelTensorShape
+    lift_to_parallel_with_degrees(TensorShape const &s,
+                                  SumDegree sum_degree,
+                                  DiscardCopyDegree discard_copy_degree,
+                                  FFOrdered<int> const &shard_degrees) {
+  return ParallelTensorShape{
+      lift_to_parallel_with_degrees(
+          s.dims, sum_degree, discard_copy_degree, shard_degrees),
+      s.data_type,
+  };
+}
+
+TensorShape get_tensor_shape_unsafe(ParallelTensorShape const &) {
+  NOT_IMPLEMENTED();
+}
+
+TensorShape get_reduced_shape(ParallelTensorShape const &s) {
+  return TensorShape{
+      get_reduced_dims(s.dims),
+      s.data_type,
+  };
+}
+
+} // namespace FlexFlow
diff --git a/lib/op-attrs/src/op-attrs/parallel_tensor_shape.dtg.cc b/lib/op-attrs/src/op-attrs/parallel_tensor_shape.dtg.cc
new file mode 100644
index 0000000000..1fe82ce108
--- /dev/null
+++ b/lib/op-attrs/src/op-attrs/parallel_tensor_shape.dtg.cc
@@ -0,0 +1,94 @@
+// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT!
+// If you would like to modify this datatype, instead modify
+// lib/op-attrs/include/op-attrs/parallel_tensor_shape.struct.toml
+/* proj-data
+{
+  "generated_from": "06d657d1e95f34aebf4b721c768cbee8"
+}
+*/
+
+#include "op-attrs/parallel_tensor_shape.dtg.h"
+
+#include "op-attrs/datatype.h"
+#include "op-attrs/parallel_tensor_dims.h"
+#include <sstream>
+
+namespace FlexFlow {
+ParallelTensorShape::ParallelTensorShape(
+    ::FlexFlow::ParallelTensorDims const &dims,
+    ::FlexFlow::DataType const &data_type)
+    : dims(dims), data_type(data_type) {}
+bool ParallelTensorShape::operator==(ParallelTensorShape const &other) const {
+  return std::tie(this->dims, this->data_type) ==
+         std::tie(other.dims, other.data_type);
+}
+bool ParallelTensorShape::operator!=(ParallelTensorShape const &other) const {
+  return std::tie(this->dims, this->data_type) !=
+         std::tie(other.dims, other.data_type);
+}
+bool ParallelTensorShape::operator<(ParallelTensorShape const &other) const {
+  return std::tie(this->dims, this->data_type) <
+         std::tie(other.dims, other.data_type);
+}
+bool ParallelTensorShape::operator>(ParallelTensorShape const &other) const {
+  return std::tie(this->dims, this->data_type) >
+         std::tie(other.dims, other.data_type);
+}
+bool ParallelTensorShape::operator<=(ParallelTensorShape const &other) const {
+  return std::tie(this->dims, this->data_type) <=
+         std::tie(other.dims, other.data_type);
+}
+bool ParallelTensorShape::operator>=(ParallelTensorShape const &other) const {
+  return std::tie(this->dims, this->data_type) >=
+         std::tie(other.dims, other.data_type);
+}
+} // namespace FlexFlow
+
+namespace std {
+size_t hash<FlexFlow::ParallelTensorShape>::operator()(
+    FlexFlow::ParallelTensorShape const &x) const {
+  size_t result = 0;
+  result ^= std::hash<::FlexFlow::ParallelTensorDims>{}(x.dims) + 0x9e3779b9 +
+            (result << 6) + (result >> 2);
+  result ^= std::hash<::FlexFlow::DataType>{}(x.data_type) + 0x9e3779b9 +
+            (result << 6) + (result >> 2);
+  return result;
+}
+} // namespace std
+
+namespace nlohmann {
+FlexFlow::ParallelTensorShape
+    adl_serializer<FlexFlow::ParallelTensorShape>::from_json(json const &j) {
+  return {j.at("dims").template get<::FlexFlow::ParallelTensorDims>(),
+          j.at("data_type").template get<::FlexFlow::DataType>()};
+}
+void adl_serializer<FlexFlow::ParallelTensorShape>::to_json(
+    json &j, FlexFlow::ParallelTensorShape const &v) {
+  j["__type"] = "ParallelTensorShape";
+  j["dims"] = v.dims;
+  j["data_type"] = v.data_type;
+}
+} // namespace nlohmann
+
+namespace rc {
+Gen<FlexFlow::ParallelTensorShape>
+    Arbitrary<FlexFlow::ParallelTensorShape>::arbitrary() {
+  return gen::construct<FlexFlow::ParallelTensorShape>(
+      gen::arbitrary<::FlexFlow::ParallelTensorDims>(),
+      gen::arbitrary<::FlexFlow::DataType>());
+}
+} // namespace rc
+
+namespace FlexFlow {
+std::string format_as(ParallelTensorShape const &x) {
+  std::ostringstream oss;
+  oss << "<ParallelTensorShape";
+  oss << " dims=" << x.dims;
+  oss << " data_type=" << x.data_type;
+  oss << ">";
+  return oss.str();
+}
+std::ostream &operator<<(std::ostream &s, ParallelTensorShape const &x) {
+  return s << fmt::to_string(x);
+}
+} // namespace FlexFlow
diff --git a/lib/op-attrs/src/op-attrs/parallel_tensor_shape/discard_copy_degree.dtg.cc b/lib/op-attrs/src/op-attrs/parallel_tensor_shape/discard_copy_degree.dtg.cc
new file mode 100644
index 0000000000..4547a5df9b
--- /dev/null
+++ b/lib/op-attrs/src/op-attrs/parallel_tensor_shape/discard_copy_degree.dtg.cc
@@ -0,0 +1,76 @@
+// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT!
+// If you would like to modify this datatype, instead modify
+// lib/op-attrs/include/op-attrs/parallel_tensor_shape/discard_copy_degree.struct.toml
+/* proj-data
+{
+  "generated_from": "e4677d1fb25d3833570ee567f5659914"
+}
+*/
+
+#include "op-attrs/parallel_tensor_shape/discard_copy_degree.dtg.h"
+
+#include <sstream>
+
+namespace FlexFlow {
+DiscardCopyDegree::DiscardCopyDegree(int const &value) : value(value) {}
+bool DiscardCopyDegree::operator==(DiscardCopyDegree const &other) const {
+  return std::tie(this->value) == std::tie(other.value);
+}
+bool DiscardCopyDegree::operator!=(DiscardCopyDegree const &other) const {
+  return std::tie(this->value) != std::tie(other.value);
+}
+bool DiscardCopyDegree::operator<(DiscardCopyDegree const &other) const {
+  return std::tie(this->value) < std::tie(other.value);
+}
+bool DiscardCopyDegree::operator>(DiscardCopyDegree const &other) const {
+  return std::tie(this->value) > std::tie(other.value);
+}
+bool DiscardCopyDegree::operator<=(DiscardCopyDegree const &other) const {
+  return std::tie(this->value) <= std::tie(other.value);
+}
+bool DiscardCopyDegree::operator>=(DiscardCopyDegree const &other) const {
+  return std::tie(this->value) >= std::tie(other.value);
+}
+} // namespace FlexFlow
+
+namespace std {
+size_t hash<FlexFlow::DiscardCopyDegree>::operator()(
+    FlexFlow::DiscardCopyDegree const &x) const {
+  size_t result = 0;
+  result ^=
+      std::hash<int>{}(x.value) + 0x9e3779b9 + (result << 6) + (result >> 2);
+  return result;
+}
+} // namespace std
+
+namespace nlohmann {
+FlexFlow::DiscardCopyDegree
+    adl_serializer<FlexFlow::DiscardCopyDegree>::from_json(json const &j) {
+  return {j.at("value").template get<int>()};
+}
+void adl_serializer<FlexFlow::DiscardCopyDegree>::to_json(
+    json &j, FlexFlow::DiscardCopyDegree const &v) {
+  j["__type"] = "DiscardCopyDegree";
+  j["value"] = v.value;
+}
+} // namespace nlohmann
+
+namespace rc {
+Gen<FlexFlow::DiscardCopyDegree>
+    Arbitrary<FlexFlow::DiscardCopyDegree>::arbitrary() {
+  return gen::construct<FlexFlow::DiscardCopyDegree>(gen::arbitrary<int>());
+}
+} // namespace rc
+
+namespace FlexFlow {
+std::string format_as(DiscardCopyDegree const &x) {
+  std::ostringstream oss;
+  oss << "<DiscardCopyDegree";
+  oss << " value=" << x.value;
+  oss << ">";
+  return oss.str();
+}
+std::ostream &operator<<(std::ostream &s, DiscardCopyDegree const &x) {
+  return s << fmt::to_string(x);
+}
+} // namespace FlexFlow
diff --git a/lib/op-attrs/src/op-attrs/parallel_tensor_shape/sum_degree.dtg.cc b/lib/op-attrs/src/op-attrs/parallel_tensor_shape/sum_degree.dtg.cc
new file mode 100644
index 0000000000..cf159a1ea7
--- /dev/null
+++ b/lib/op-attrs/src/op-attrs/parallel_tensor_shape/sum_degree.dtg.cc
@@ -0,0 +1,75 @@
+// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT!
+// If you would like to modify this datatype, instead modify
+// lib/op-attrs/include/op-attrs/parallel_tensor_shape/sum_degree.struct.toml
+/* proj-data
+{
+  "generated_from": "e94a05618f2ad92dd7b3328a1d9c6786"
+}
+*/
+
+#include "op-attrs/parallel_tensor_shape/sum_degree.dtg.h"
+
+#include <sstream>
+
+namespace FlexFlow {
+SumDegree::SumDegree(int const &value) : value(value) {}
+bool SumDegree::operator==(SumDegree const &other) const {
+  return std::tie(this->value) == std::tie(other.value);
+}
+bool SumDegree::operator!=(SumDegree const &other) const {
+  return std::tie(this->value) != std::tie(other.value);
+}
+bool SumDegree::operator<(SumDegree const &other) const {
+  return std::tie(this->value) < std::tie(other.value);
+}
+bool SumDegree::operator>(SumDegree const &other) const {
+  return std::tie(this->value) > std::tie(other.value);
+}
+bool SumDegree::operator<=(SumDegree const &other) const {
+  return std::tie(this->value) <= std::tie(other.value);
+}
+bool SumDegree::operator>=(SumDegree const &other) const {
+  return std::tie(this->value) >= std::tie(other.value);
+}
+} // namespace FlexFlow
+
+namespace std {
+size_t
+    hash<FlexFlow::SumDegree>::operator()(FlexFlow::SumDegree const &x) const {
+  size_t result = 0;
+  result ^=
+      std::hash<int>{}(x.value) + 0x9e3779b9 + (result << 6) + (result >> 2);
+  return result;
+}
+} // namespace std
+
+namespace nlohmann {
+FlexFlow::SumDegree
+    adl_serializer<FlexFlow::SumDegree>::from_json(json const &j) {
+  return {j.at("value").template get<int>()};
+}
+void adl_serializer<FlexFlow::SumDegree>::to_json(
+    json &j, FlexFlow::SumDegree const &v) {
+  j["__type"] = "SumDegree";
+  j["value"] = v.value;
+}
+} // namespace nlohmann
+
+namespace rc {
+Gen<FlexFlow::SumDegree> Arbitrary<FlexFlow::SumDegree>::arbitrary() {
+  return gen::construct<FlexFlow::SumDegree>(gen::arbitrary<int>());
+}
+} // namespace rc
+
+namespace FlexFlow {
+std::string format_as(SumDegree const &x) {
+  std::ostringstream oss;
+  oss << "<SumDegree";
+  oss << " value=" << x.value;
+  oss << ">";
+  return oss.str();
+}
+std::ostream &operator<<(std::ostream &s, SumDegree const &x) {
+  return s << fmt::to_string(x);
+}
+} // namespace FlexFlow
diff --git a/lib/op-attrs/src/op-attrs/param_sync.dtg.cc b/lib/op-attrs/src/op-attrs/param_sync.dtg.cc
new file mode 100644
index 0000000000..e0d13fdd2e
--- /dev/null
+++ b/lib/op-attrs/src/op-attrs/param_sync.dtg.cc
@@ -0,0 +1,70 @@
+// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT!
+// If you would like to modify this datatype, instead modify
+// lib/op-attrs/include/op-attrs/param_sync.enum.toml
+/* proj-data
+{
+  "generated_from": "288c6e9e256cf58ba5dbd0e3791c08df"
+}
+*/
+
+#include "op-attrs/param_sync.dtg.h"
+
+#include <sstream>
+#include <stdexcept>
+
+namespace std {
+size_t hash<FlexFlow::ParamSync>::operator()(FlexFlow::ParamSync x) const {
+  return std::hash<int>{}(static_cast<int>(x));
+}
+} // namespace std
+namespace FlexFlow {
+std::string format_as(ParamSync x) {
+  switch (x) {
+    case ParamSync::PS:
+      return "PS";
+    case ParamSync::NCCL:
+      return "NCCL";
+    default:
+      std::ostringstream oss;
+      oss << "Unknown ParamSync value " << static_cast<int>(x);
+      throw std::runtime_error(oss.str());
+  }
+}
+std::ostream &operator<<(std::ostream &s, ParamSync x) {
+  return s << fmt::to_string(x);
+}
+} // namespace FlexFlow
+namespace FlexFlow {
+void to_json(::nlohmann::json &j, ParamSync x) {
+  switch (x) {
+    case ParamSync::PS:
+      j = "PS";
+      break;
+    case ParamSync::NCCL:
+      j = "NCCL";
+      break;
+    default:
+      std::ostringstream oss;
+      oss << "Unknown ParamSync value " << static_cast<int>(x);
+      throw std::runtime_error(oss.str());
+  }
+}
+void from_json(::nlohmann::json const &j, ParamSync &x) {
+  std::string as_str = j.get<std::string>();
+  if (as_str == "PS") {
+    x = ParamSync::PS;
+  } else if (as_str == "NCCL") {
+    x = ParamSync::NCCL;
+  } else {
+    std::ostringstream oss;
+    oss << "Unknown ParamSync value " << as_str;
+    throw std::runtime_error(oss.str());
+  }
+}
+} // namespace FlexFlow
+namespace rc {
+Gen<FlexFlow::ParamSync> Arbitrary<FlexFlow::ParamSync>::arbitrary() {
+  return gen::element<FlexFlow::ParamSync>(FlexFlow::ParamSync::PS,
+                                           FlexFlow::ParamSync::NCCL);
+}
+} // namespace rc
diff --git a/lib/op-attrs/src/op-attrs/pcg_operator_attrs.cc b/lib/op-attrs/src/op-attrs/pcg_operator_attrs.cc
new file mode 100644
index 0000000000..76ad48d471
--- /dev/null
+++ b/lib/op-attrs/src/op-attrs/pcg_operator_attrs.cc
@@ -0,0 +1,16 @@
+#include "op-attrs/pcg_operator_attrs.h"
+#include "op-attrs/get_op_type.h"
+
+namespace FlexFlow {
+
+bool is_parallel_op(PCGOperatorAttrs const &attrs) {
+  return (attrs.has<CombineAttrs>() || attrs.has<ReductionAttrs>() ||
+          attrs.has<RepartitionAttrs>() || attrs.has<ReplicateAttrs>());
+}
+
+OperatorType get_op_type(PCGOperatorAttrs const &attrs) {
+  return attrs.visit<OperatorType>(
+      [](auto const &x) { return get_op_type(x); });
+}
+
+} // namespace FlexFlow
diff --git a/lib/op-attrs/src/op-attrs/pcg_operator_attrs.dtg.cc b/lib/op-attrs/src/op-attrs/pcg_operator_attrs.dtg.cc
new file mode 100644
index 0000000000..5334c8a7ab
--- /dev/null
+++ b/lib/op-attrs/src/op-attrs/pcg_operator_attrs.dtg.cc
@@ -0,0 +1,599 @@
+// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT!
+// If you would like to modify this datatype, instead modify
+// lib/op-attrs/include/op-attrs/pcg_operator_attrs.variant.toml
+/* proj-data
+{
+  "generated_from": "e1d10b0c7c98524c27886bdae0972321"
+}
+*/
+
+#include "op-attrs/pcg_operator_attrs.dtg.h"
+
+#include "fmt/format.h"
+#include <sstream>
+#include <stdexcept>
+
+namespace FlexFlow {
+PCGOperatorAttrs::PCGOperatorAttrs(::FlexFlow::BatchMatmulAttrs const &v)
+    : raw_variant(v) {}
+PCGOperatorAttrs::PCGOperatorAttrs(::FlexFlow::BatchNormAttrs const &v)
+    : raw_variant(v) {}
+PCGOperatorAttrs::PCGOperatorAttrs(::FlexFlow::CastAttrs const &v)
+    : raw_variant(v) {}
+PCGOperatorAttrs::PCGOperatorAttrs(::FlexFlow::CombineAttrs const &v)
+    : raw_variant(v) {}
+PCGOperatorAttrs::PCGOperatorAttrs(::FlexFlow::ConcatAttrs const &v)
+    : raw_variant(v) {}
+PCGOperatorAttrs::PCGOperatorAttrs(::FlexFlow::Conv2DAttrs const &v)
+    : raw_variant(v) {}
+PCGOperatorAttrs::PCGOperatorAttrs(::FlexFlow::DropoutAttrs const &v)
+    : raw_variant(v) {}
+PCGOperatorAttrs::PCGOperatorAttrs(::FlexFlow::ElementBinaryAttrs const &v)
+    : raw_variant(v) {}
+PCGOperatorAttrs::PCGOperatorAttrs(::FlexFlow::ElementUnaryAttrs const &v)
+    : raw_variant(v) {}
+PCGOperatorAttrs::PCGOperatorAttrs(::FlexFlow::ElementScalarUnaryAttrs const &v)
+    : raw_variant(v) {}
+PCGOperatorAttrs::PCGOperatorAttrs(::FlexFlow::EmbeddingAttrs const &v)
+    : raw_variant(v) {}
+PCGOperatorAttrs::PCGOperatorAttrs(::FlexFlow::FlatAttrs const &v)
+    : raw_variant(v) {}
+PCGOperatorAttrs::PCGOperatorAttrs(::FlexFlow::GatherAttrs const &v)
+    : raw_variant(v) {}
+PCGOperatorAttrs::PCGOperatorAttrs(::FlexFlow::InputAttrs const &v)
+    : raw_variant(v) {}
+PCGOperatorAttrs::PCGOperatorAttrs(::FlexFlow::LayerNormAttrs const &v)
+    : raw_variant(v) {}
+PCGOperatorAttrs::PCGOperatorAttrs(::FlexFlow::LinearAttrs const &v)
+    : raw_variant(v) {}
+PCGOperatorAttrs::PCGOperatorAttrs(::FlexFlow::MultiHeadAttentionAttrs const &v)
+    : raw_variant(v) {}
+PCGOperatorAttrs::PCGOperatorAttrs(::FlexFlow::NoopAttrs const &v)
+    : raw_variant(v) {}
+PCGOperatorAttrs::PCGOperatorAttrs(::FlexFlow::Pool2DAttrs const &v)
+    : raw_variant(v) {}
+PCGOperatorAttrs::PCGOperatorAttrs(::FlexFlow::ReduceAttrs const &v)
+    : raw_variant(v) {}
+PCGOperatorAttrs::PCGOperatorAttrs(::FlexFlow::ReductionAttrs const &v)
+    : raw_variant(v) {}
+PCGOperatorAttrs::PCGOperatorAttrs(::FlexFlow::RepartitionAttrs const &v)
+    : raw_variant(v) {}
+PCGOperatorAttrs::PCGOperatorAttrs(::FlexFlow::ReplicateAttrs const &v)
+    : raw_variant(v) {}
+PCGOperatorAttrs::PCGOperatorAttrs(::FlexFlow::ReverseAttrs const &v)
+    : raw_variant(v) {}
+PCGOperatorAttrs::PCGOperatorAttrs(::FlexFlow::ReshapeAttrs const &v)
+    : raw_variant(v) {}
+PCGOperatorAttrs::PCGOperatorAttrs(::FlexFlow::SplitAttrs const &v)
+    : raw_variant(v) {}
+PCGOperatorAttrs::PCGOperatorAttrs(::FlexFlow::SoftmaxAttrs const &v)
+    : raw_variant(v) {}
+PCGOperatorAttrs::PCGOperatorAttrs(::FlexFlow::TopKAttrs const &v)
+    : raw_variant(v) {}
+PCGOperatorAttrs::PCGOperatorAttrs(::FlexFlow::TransposeAttrs const &v)
+    : raw_variant(v) {}
+bool PCGOperatorAttrs::operator==(PCGOperatorAttrs const &other) const {
+  return this->raw_variant == other.raw_variant;
+}
+bool PCGOperatorAttrs::operator!=(PCGOperatorAttrs const &other) const {
+  return this->raw_variant != other.raw_variant;
+}
+bool PCGOperatorAttrs::operator<(PCGOperatorAttrs const &other) const {
+  return this->raw_variant < other.raw_variant;
+}
+bool PCGOperatorAttrs::operator>(PCGOperatorAttrs const &other) const {
+  return this->raw_variant > other.raw_variant;
+}
+bool PCGOperatorAttrs::operator<=(PCGOperatorAttrs const &other) const {
+  return this->raw_variant <= other.raw_variant;
+}
+bool PCGOperatorAttrs::operator>=(PCGOperatorAttrs const &other) const {
+  return this->raw_variant >= other.raw_variant;
+}
+} // namespace FlexFlow
+namespace std {
+size_t hash<::FlexFlow::PCGOperatorAttrs>::operator()(
+    ::FlexFlow::PCGOperatorAttrs const &x) const {
+  return std::hash<std::variant<::FlexFlow::BatchMatmulAttrs,
+                                ::FlexFlow::BatchNormAttrs,
+                                ::FlexFlow::CastAttrs,
+                                ::FlexFlow::CombineAttrs,
+                                ::FlexFlow::ConcatAttrs,
+                                ::FlexFlow::Conv2DAttrs,
+                                ::FlexFlow::DropoutAttrs,
+                                ::FlexFlow::ElementBinaryAttrs,
+                                ::FlexFlow::ElementUnaryAttrs,
+                                ::FlexFlow::ElementScalarUnaryAttrs,
+                                ::FlexFlow::EmbeddingAttrs,
+                                ::FlexFlow::FlatAttrs,
+                                ::FlexFlow::GatherAttrs,
+                                ::FlexFlow::InputAttrs,
+                                ::FlexFlow::LayerNormAttrs,
+                                ::FlexFlow::LinearAttrs,
+                                ::FlexFlow::MultiHeadAttentionAttrs,
+                                ::FlexFlow::NoopAttrs,
+                                ::FlexFlow::Pool2DAttrs,
+                                ::FlexFlow::ReduceAttrs,
+                                ::FlexFlow::ReductionAttrs,
+                                ::FlexFlow::RepartitionAttrs,
+                                ::FlexFlow::ReplicateAttrs,
+                                ::FlexFlow::ReverseAttrs,
+                                ::FlexFlow::ReshapeAttrs,
+                                ::FlexFlow::SplitAttrs,
+                                ::FlexFlow::SoftmaxAttrs,
+                                ::FlexFlow::TopKAttrs,
+                                ::FlexFlow::TransposeAttrs>>{}(x.raw_variant);
+}
+} // namespace std
+namespace nlohmann {
+::FlexFlow::PCGOperatorAttrs
+    adl_serializer<::FlexFlow::PCGOperatorAttrs>::from_json(json const &j) {
+  std::string key = j.at("type").template get<std::string>();
+  if (key == "batch_matmul") {
+    return ::FlexFlow::PCGOperatorAttrs{
+        j.at("value").template get<::FlexFlow::BatchMatmulAttrs>()};
+  } else if (key == "batch_norm") {
+    return ::FlexFlow::PCGOperatorAttrs{
+        j.at("value").template get<::FlexFlow::BatchNormAttrs>()};
+  } else if (key == "cast") {
+    return ::FlexFlow::PCGOperatorAttrs{
+        j.at("value").template get<::FlexFlow::CastAttrs>()};
+  } else if (key == "combine_distributed") {
+    return ::FlexFlow::PCGOperatorAttrs{
+        j.at("value").template get<::FlexFlow::CombineAttrs>()};
+  } else if (key == "concat") {
+    return ::FlexFlow::PCGOperatorAttrs{
+        j.at("value").template get<::FlexFlow::ConcatAttrs>()};
+  } else if (key == "conv2d") {
+    return ::FlexFlow::PCGOperatorAttrs{
+        j.at("value").template get<::FlexFlow::Conv2DAttrs>()};
+  } else if (key == "dropout") {
+    return ::FlexFlow::PCGOperatorAttrs{
+        j.at("value").template get<::FlexFlow::DropoutAttrs>()};
+  } else if (key == "element_binary") {
+    return ::FlexFlow::PCGOperatorAttrs{
+        j.at("value").template get<::FlexFlow::ElementBinaryAttrs>()};
+  } else if (key == "element_unary") {
+    return ::FlexFlow::PCGOperatorAttrs{
+        j.at("value").template get<::FlexFlow::ElementUnaryAttrs>()};
+  } else if (key == "element_scalar_unary") {
+    return ::FlexFlow::PCGOperatorAttrs{
+        j.at("value").template get<::FlexFlow::ElementScalarUnaryAttrs>()};
+  } else if (key == "embedding") {
+    return ::FlexFlow::PCGOperatorAttrs{
+        j.at("value").template get<::FlexFlow::EmbeddingAttrs>()};
+  } else if (key == "flat") {
+    return ::FlexFlow::PCGOperatorAttrs{
+        j.at("value").template get<::FlexFlow::FlatAttrs>()};
+  } else if (key == "gather") {
+    return ::FlexFlow::PCGOperatorAttrs{
+        j.at("value").template get<::FlexFlow::GatherAttrs>()};
+  } else if (key == "input") {
+    return ::FlexFlow::PCGOperatorAttrs{
+        j.at("value").template get<::FlexFlow::InputAttrs>()};
+  } else if (key == "layer_norm") {
+    return ::FlexFlow::PCGOperatorAttrs{
+        j.at("value").template get<::FlexFlow::LayerNormAttrs>()};
+  } else if (key == "linear") {
+    return ::FlexFlow::PCGOperatorAttrs{
+        j.at("value").template get<::FlexFlow::LinearAttrs>()};
+  } else if (key == "multi_head_attention") {
+    return ::FlexFlow::PCGOperatorAttrs{
+        j.at("value").template get<::FlexFlow::MultiHeadAttentionAttrs>()};
+  } else if (key == "noop") {
+    return ::FlexFlow::PCGOperatorAttrs{
+        j.at("value").template get<::FlexFlow::NoopAttrs>()};
+  } else if (key == "pool2d") {
+    return ::FlexFlow::PCGOperatorAttrs{
+        j.at("value").template get<::FlexFlow::Pool2DAttrs>()};
+  } else if (key == "reduce") {
+    return ::FlexFlow::PCGOperatorAttrs{
+        j.at("value").template get<::FlexFlow::ReduceAttrs>()};
+  } else if (key == "reduce_distributed") {
+    return ::FlexFlow::PCGOperatorAttrs{
+        j.at("value").template get<::FlexFlow::ReductionAttrs>()};
+  } else if (key == "partition_distributed") {
+    return ::FlexFlow::PCGOperatorAttrs{
j.at("value").template get<::FlexFlow::RepartitionAttrs>()}; + } else if (key == "replicate_distributed") { + return ::FlexFlow::PCGOperatorAttrs{ + j.at("value").template get<::FlexFlow::ReplicateAttrs>()}; + } else if (key == "reverse") { + return ::FlexFlow::PCGOperatorAttrs{ + j.at("value").template get<::FlexFlow::ReverseAttrs>()}; + } else if (key == "reshape") { + return ::FlexFlow::PCGOperatorAttrs{ + j.at("value").template get<::FlexFlow::ReshapeAttrs>()}; + } else if (key == "split") { + return ::FlexFlow::PCGOperatorAttrs{ + j.at("value").template get<::FlexFlow::SplitAttrs>()}; + } else if (key == "softmax") { + return ::FlexFlow::PCGOperatorAttrs{ + j.at("value").template get<::FlexFlow::SoftmaxAttrs>()}; + } else if (key == "topk") { + return ::FlexFlow::PCGOperatorAttrs{ + j.at("value").template get<::FlexFlow::TopKAttrs>()}; + } else if (key == "transpose") { + return ::FlexFlow::PCGOperatorAttrs{ + j.at("value").template get<::FlexFlow::TransposeAttrs>()}; + } else { + throw std::runtime_error(fmt::format("Unknown type key {}", key)); + } +} +void adl_serializer<::FlexFlow::PCGOperatorAttrs>::to_json( + json &j, ::FlexFlow::PCGOperatorAttrs const &x) { + j["__type"] = "PCGOperatorAttrs"; + switch (x.index()) { + case 0: { + j["type"] = "batch_matmul"; + j["value"] = x.get<::FlexFlow::BatchMatmulAttrs>(); + break; + } + case 1: { + j["type"] = "batch_norm"; + j["value"] = x.get<::FlexFlow::BatchNormAttrs>(); + break; + } + case 2: { + j["type"] = "cast"; + j["value"] = x.get<::FlexFlow::CastAttrs>(); + break; + } + case 3: { + j["type"] = "combine_distributed"; + j["value"] = x.get<::FlexFlow::CombineAttrs>(); + break; + } + case 4: { + j["type"] = "concat"; + j["value"] = x.get<::FlexFlow::ConcatAttrs>(); + break; + } + case 5: { + j["type"] = "conv2d"; + j["value"] = x.get<::FlexFlow::Conv2DAttrs>(); + break; + } + case 6: { + j["type"] = "dropout"; + j["value"] = x.get<::FlexFlow::DropoutAttrs>(); + break; + } + case 7: { + j["type"] = "element_binary"; + j["value"] = x.get<::FlexFlow::ElementBinaryAttrs>(); + break; + } + case 8: { + j["type"] = "element_unary"; + j["value"] = x.get<::FlexFlow::ElementUnaryAttrs>(); + break; + } + case 9: { + j["type"] = "element_scalar_unary"; + j["value"] = x.get<::FlexFlow::ElementScalarUnaryAttrs>(); + break; + } + case 10: { + j["type"] = "embedding"; + j["value"] = x.get<::FlexFlow::EmbeddingAttrs>(); + break; + } + case 11: { + j["type"] = "flat"; + j["value"] = x.get<::FlexFlow::FlatAttrs>(); + break; + } + case 12: { + j["type"] = "gather"; + j["value"] = x.get<::FlexFlow::GatherAttrs>(); + break; + } + case 13: { + j["type"] = "input"; + j["value"] = x.get<::FlexFlow::InputAttrs>(); + break; + } + case 14: { + j["type"] = "layer_norm"; + j["value"] = x.get<::FlexFlow::LayerNormAttrs>(); + break; + } + case 15: { + j["type"] = "linear"; + j["value"] = x.get<::FlexFlow::LinearAttrs>(); + break; + } + case 16: { + j["type"] = "multi_head_attention"; + j["value"] = x.get<::FlexFlow::MultiHeadAttentionAttrs>(); + break; + } + case 17: { + j["type"] = "noop"; + j["value"] = x.get<::FlexFlow::NoopAttrs>(); + break; + } + case 18: { + j["type"] = "pool2d"; + j["value"] = x.get<::FlexFlow::Pool2DAttrs>(); + break; + } + case 19: { + j["type"] = "reduce"; + j["value"] = x.get<::FlexFlow::ReduceAttrs>(); + break; + } + case 20: { + j["type"] = "reduce_distributed"; + j["value"] = x.get<::FlexFlow::ReductionAttrs>(); + break; + } + case 21: { + j["type"] = "partition_distributed"; + j["value"] = x.get<::FlexFlow::RepartitionAttrs>(); + 
+      break;
+    }
+    case 22: {
+      j["type"] = "replicate_distributed";
+      j["value"] = x.get<::FlexFlow::ReplicateAttrs>();
+      break;
+    }
+    case 23: {
+      j["type"] = "reverse";
+      j["value"] = x.get<::FlexFlow::ReverseAttrs>();
+      break;
+    }
+    case 24: {
+      j["type"] = "reshape";
+      j["value"] = x.get<::FlexFlow::ReshapeAttrs>();
+      break;
+    }
+    case 25: {
+      j["type"] = "split";
+      j["value"] = x.get<::FlexFlow::SplitAttrs>();
+      break;
+    }
+    case 26: {
+      j["type"] = "softmax";
+      j["value"] = x.get<::FlexFlow::SoftmaxAttrs>();
+      break;
+    }
+    case 27: {
+      j["type"] = "topk";
+      j["value"] = x.get<::FlexFlow::TopKAttrs>();
+      break;
+    }
+    case 28: {
+      j["type"] = "transpose";
+      j["value"] = x.get<::FlexFlow::TransposeAttrs>();
+      break;
+    }
+    default: {
+      throw std::runtime_error(
+          fmt::format("Unknown index {} for type PCGOperatorAttrs", x.index()));
+    }
+  }
+}
+} // namespace nlohmann
+namespace rc {
+Gen<::FlexFlow::PCGOperatorAttrs>
+    Arbitrary<::FlexFlow::PCGOperatorAttrs>::arbitrary() {
+  return gen::oneOf(gen::construct<::FlexFlow::PCGOperatorAttrs>(
+                        gen::arbitrary<::FlexFlow::BatchMatmulAttrs>()),
+                    gen::construct<::FlexFlow::PCGOperatorAttrs>(
+                        gen::arbitrary<::FlexFlow::BatchNormAttrs>()),
+                    gen::construct<::FlexFlow::PCGOperatorAttrs>(
+                        gen::arbitrary<::FlexFlow::CastAttrs>()),
+                    gen::construct<::FlexFlow::PCGOperatorAttrs>(
+                        gen::arbitrary<::FlexFlow::CombineAttrs>()),
+                    gen::construct<::FlexFlow::PCGOperatorAttrs>(
+                        gen::arbitrary<::FlexFlow::ConcatAttrs>()),
+                    gen::construct<::FlexFlow::PCGOperatorAttrs>(
+                        gen::arbitrary<::FlexFlow::Conv2DAttrs>()),
+                    gen::construct<::FlexFlow::PCGOperatorAttrs>(
+                        gen::arbitrary<::FlexFlow::DropoutAttrs>()),
+                    gen::construct<::FlexFlow::PCGOperatorAttrs>(
+                        gen::arbitrary<::FlexFlow::ElementBinaryAttrs>()),
+                    gen::construct<::FlexFlow::PCGOperatorAttrs>(
+                        gen::arbitrary<::FlexFlow::ElementUnaryAttrs>()),
+                    gen::construct<::FlexFlow::PCGOperatorAttrs>(
+                        gen::arbitrary<::FlexFlow::ElementScalarUnaryAttrs>()),
+                    gen::construct<::FlexFlow::PCGOperatorAttrs>(
+                        gen::arbitrary<::FlexFlow::EmbeddingAttrs>()),
+                    gen::construct<::FlexFlow::PCGOperatorAttrs>(
+                        gen::arbitrary<::FlexFlow::FlatAttrs>()),
+                    gen::construct<::FlexFlow::PCGOperatorAttrs>(
+                        gen::arbitrary<::FlexFlow::GatherAttrs>()),
+                    gen::construct<::FlexFlow::PCGOperatorAttrs>(
+                        gen::arbitrary<::FlexFlow::InputAttrs>()),
+                    gen::construct<::FlexFlow::PCGOperatorAttrs>(
+                        gen::arbitrary<::FlexFlow::LayerNormAttrs>()),
+                    gen::construct<::FlexFlow::PCGOperatorAttrs>(
+                        gen::arbitrary<::FlexFlow::LinearAttrs>()),
+                    gen::construct<::FlexFlow::PCGOperatorAttrs>(
+                        gen::arbitrary<::FlexFlow::MultiHeadAttentionAttrs>()),
+                    gen::construct<::FlexFlow::PCGOperatorAttrs>(
+                        gen::arbitrary<::FlexFlow::NoopAttrs>()),
+                    gen::construct<::FlexFlow::PCGOperatorAttrs>(
+                        gen::arbitrary<::FlexFlow::Pool2DAttrs>()),
+                    gen::construct<::FlexFlow::PCGOperatorAttrs>(
+                        gen::arbitrary<::FlexFlow::ReduceAttrs>()),
+                    gen::construct<::FlexFlow::PCGOperatorAttrs>(
+                        gen::arbitrary<::FlexFlow::ReductionAttrs>()),
+                    gen::construct<::FlexFlow::PCGOperatorAttrs>(
+                        gen::arbitrary<::FlexFlow::RepartitionAttrs>()),
+                    gen::construct<::FlexFlow::PCGOperatorAttrs>(
+                        gen::arbitrary<::FlexFlow::ReplicateAttrs>()),
+                    gen::construct<::FlexFlow::PCGOperatorAttrs>(
+                        gen::arbitrary<::FlexFlow::ReverseAttrs>()),
+                    gen::construct<::FlexFlow::PCGOperatorAttrs>(
+                        gen::arbitrary<::FlexFlow::ReshapeAttrs>()),
+                    gen::construct<::FlexFlow::PCGOperatorAttrs>(
+                        gen::arbitrary<::FlexFlow::SplitAttrs>()),
+                    gen::construct<::FlexFlow::PCGOperatorAttrs>(
+                        gen::arbitrary<::FlexFlow::SoftmaxAttrs>()),
+                    gen::construct<::FlexFlow::PCGOperatorAttrs>(
+                        gen::arbitrary<::FlexFlow::TopKAttrs>()),
+                    gen::construct<::FlexFlow::PCGOperatorAttrs>(
+                        gen::arbitrary<::FlexFlow::TransposeAttrs>()));
+}
+} // namespace rc
+namespace FlexFlow {
+std::string format_as(::FlexFlow::PCGOperatorAttrs const &x) {
+  std::ostringstream oss;
+  switch (x.index()) {
+    case 0: {
+      oss << "<PCGOperatorAttrs " << x.get<::FlexFlow::BatchMatmulAttrs>()
+          << ">";
+      break;
+    }
+    case 1: {
+      oss << "<PCGOperatorAttrs " << x.get<::FlexFlow::BatchNormAttrs>()
+          << ">";
+      break;
+    }
+    case 2: {
+      oss << "<PCGOperatorAttrs " << x.get<::FlexFlow::CastAttrs>() << ">";
+      break;
+    }
+    case 3: {
+      oss << "<PCGOperatorAttrs " << x.get<::FlexFlow::CombineAttrs>() << ">";
+      break;
+    }
+    case 4: {
+      oss << "<PCGOperatorAttrs " << x.get<::FlexFlow::ConcatAttrs>() << ">";
+      break;
+    }
+    case 5: {
+      oss << "<PCGOperatorAttrs " << x.get<::FlexFlow::Conv2DAttrs>() << ">";
+      break;
+    }
+    case 6: {
+      oss << "<PCGOperatorAttrs " << x.get<::FlexFlow::DropoutAttrs>() << ">";
+      break;
+    }
+    case 7: {
+      oss << "<PCGOperatorAttrs " << x.get<::FlexFlow::ElementBinaryAttrs>()
+          << ">";
+      break;
+    }
+    case 8: {
+      oss << "<PCGOperatorAttrs " << x.get<::FlexFlow::ElementUnaryAttrs>()
+          << ">";
+      break;
+    }
+    case 9: {
+      oss << "<PCGOperatorAttrs "
+          << x.get<::FlexFlow::ElementScalarUnaryAttrs>() << ">";
+      break;
+    }
+    case 10: {
+      oss << "<PCGOperatorAttrs " << x.get<::FlexFlow::EmbeddingAttrs>()
+          << ">";
+      break;
+    }
+    case 11: {
+      oss << "<PCGOperatorAttrs " << x.get<::FlexFlow::FlatAttrs>() << ">";
+      break;
+    }
+    case 12: {
+      oss << "<PCGOperatorAttrs " << x.get<::FlexFlow::GatherAttrs>() << ">";
+      break;
+    }
+    case 13: {
+      oss << "<PCGOperatorAttrs " << x.get<::FlexFlow::InputAttrs>() << ">";
+      break;
+    }
+    case 14: {
+      oss << "<PCGOperatorAttrs " << x.get<::FlexFlow::LayerNormAttrs>()
+          << ">";
+      break;
+    }
+    case 15: {
+      oss << "<PCGOperatorAttrs " << x.get<::FlexFlow::LinearAttrs>() << ">";
+      break;
+    }
+    case 16: {
+      oss << "<PCGOperatorAttrs "
+          << x.get<::FlexFlow::MultiHeadAttentionAttrs>() << ">";
+      break;
+    }
+    case 17: {
+      oss << "<PCGOperatorAttrs " << x.get<::FlexFlow::NoopAttrs>() << ">";
+      break;
+    }
+    case 18: {
+      oss << "<PCGOperatorAttrs " << x.get<::FlexFlow::Pool2DAttrs>() << ">";
+      break;
+    }
+    case 19: {
+      oss << "<PCGOperatorAttrs " << x.get<::FlexFlow::ReduceAttrs>() << ">";
+      break;
+    }
+    case 20: {
+      oss << "<PCGOperatorAttrs " << x.get<::FlexFlow::ReductionAttrs>()
+          << ">";
+      break;
+    }
+    case 21: {
+      oss << "<PCGOperatorAttrs " << x.get<::FlexFlow::RepartitionAttrs>()
+          << ">";
+      break;
+    }
+    case 22: {
+      oss << "<PCGOperatorAttrs " << x.get<::FlexFlow::ReplicateAttrs>()
+          << ">";
+      break;
+    }
+    case 23: {
+      oss << "<PCGOperatorAttrs " << x.get<::FlexFlow::ReverseAttrs>() << ">";
+      break;
+    }
+    case 24: {
+      oss << "<PCGOperatorAttrs " << x.get<::FlexFlow::ReshapeAttrs>() << ">";
+      break;
+    }
+    case 25: {
+      oss << "<PCGOperatorAttrs " << x.get<::FlexFlow::SplitAttrs>() << ">";
+      break;
+    }
+    case 26: {
+      oss << "<PCGOperatorAttrs " << x.get<::FlexFlow::SoftmaxAttrs>() << ">";
+      break;
+    }
+    case 27: {
+      oss << "<PCGOperatorAttrs " << x.get<::FlexFlow::TopKAttrs>() << ">";
+      break;
+    }
+    case 28: {
+      oss << "<PCGOperatorAttrs " << x.get<::FlexFlow::TransposeAttrs>()
+          << ">";
+      break;
+    }
+    default: {
+      throw std::runtime_error(
+          fmt::format("Unknown index {} for type PCGOperatorAttrs", x.index()));
+      break;
+    }
+  }
+  return oss.str();
+}
+std::ostream &operator<<(std::ostream &s,
+                         ::FlexFlow::PCGOperatorAttrs const &x) {
+  return s << fmt::to_string(x);
+}
+} // namespace FlexFlow
diff --git a/lib/op-attrs/src/op-attrs/pool_op.dtg.cc b/lib/op-attrs/src/op-attrs/pool_op.dtg.cc
new file mode 100644
index 0000000000..08a6f43943
--- /dev/null
+++ b/lib/op-attrs/src/op-attrs/pool_op.dtg.cc
@@ -0,0 +1,70 @@
+// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT!
+// If you would like to modify this datatype, instead modify
+// lib/op-attrs/include/op-attrs/pool_op.enum.toml
+/* proj-data
+{
+  "generated_from": "ed1d531c6227306c909eb28eb0a66538"
+}
+*/
+
+#include "op-attrs/pool_op.dtg.h"
+
+#include <sstream>
+#include <stdexcept>
+
+namespace std {
+size_t hash<FlexFlow::PoolOp>::operator()(FlexFlow::PoolOp x) const {
+  return std::hash<int>{}(static_cast<int>(x));
+}
+} // namespace std
+namespace FlexFlow {
+std::string format_as(PoolOp x) {
+  switch (x) {
+    case PoolOp::MAX:
+      return "MAX";
+    case PoolOp::AVG:
+      return "AVG";
+    default:
+      std::ostringstream oss;
+      oss << "Unknown PoolOp value " << static_cast<int>(x);
+      throw std::runtime_error(oss.str());
+  }
+}
+std::ostream &operator<<(std::ostream &s, PoolOp x) {
+  return s << fmt::to_string(x);
+}
+} // namespace FlexFlow
+namespace FlexFlow {
+void to_json(::nlohmann::json &j, PoolOp x) {
+  switch (x) {
+    case PoolOp::MAX:
+      j = "MAX";
+      break;
+    case PoolOp::AVG:
+      j = "AVG";
+      break;
+    default:
+      std::ostringstream oss;
+      oss << "Unknown PoolOp value " << static_cast<int>(x);
+      throw std::runtime_error(oss.str());
+  }
+}
+void from_json(::nlohmann::json const &j, PoolOp &x) {
+  std::string as_str = j.get<std::string>();
+  if (as_str == "MAX") {
+    x = PoolOp::MAX;
+  } else if (as_str == "AVG") {
+    x = PoolOp::AVG;
+  } else {
+    std::ostringstream oss;
+    oss << "Unknown PoolOp value " << as_str;
+    throw std::runtime_error(oss.str());
+  }
+}
+} // namespace FlexFlow
+namespace rc {
+Gen<FlexFlow::PoolOp> Arbitrary<FlexFlow::PoolOp>::arbitrary() {
+  return gen::element<FlexFlow::PoolOp>(FlexFlow::PoolOp::MAX,
+                                        FlexFlow::PoolOp::AVG);
+}
+} // namespace rc
diff --git a/lib/op-attrs/src/op-attrs/regularizer_attrs.dtg.cc b/lib/op-attrs/src/op-attrs/regularizer_attrs.dtg.cc
new file mode 100644
index 0000000000..d1f844ab10
--- /dev/null
+++ b/lib/op-attrs/src/op-attrs/regularizer_attrs.dtg.cc
@@ -0,0 +1,118 @@
+// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT!
+// If you would like to modify this datatype, instead modify +// lib/op-attrs/include/op-attrs/regularizer_attrs.variant.toml +/* proj-data +{ + "generated_from": "ea060a8ab344c9772102f084903883ea" +} +*/ + +#include "op-attrs/regularizer_attrs.dtg.h" + +#include "fmt/format.h" +#include +#include + +namespace FlexFlow { +RegularizerAttrs::RegularizerAttrs(::FlexFlow::L1RegularizerAttrs const &v) + : raw_variant(v) {} +RegularizerAttrs::RegularizerAttrs(::FlexFlow::L2RegularizerAttrs const &v) + : raw_variant(v) {} +bool RegularizerAttrs::operator==(RegularizerAttrs const &other) const { + return this->raw_variant == other.raw_variant; +} +bool RegularizerAttrs::operator!=(RegularizerAttrs const &other) const { + return this->raw_variant != other.raw_variant; +} +bool RegularizerAttrs::operator<(RegularizerAttrs const &other) const { + return this->raw_variant < other.raw_variant; +} +bool RegularizerAttrs::operator>(RegularizerAttrs const &other) const { + return this->raw_variant > other.raw_variant; +} +bool RegularizerAttrs::operator<=(RegularizerAttrs const &other) const { + return this->raw_variant <= other.raw_variant; +} +bool RegularizerAttrs::operator>=(RegularizerAttrs const &other) const { + return this->raw_variant >= other.raw_variant; +} +} // namespace FlexFlow +namespace std { +size_t hash<::FlexFlow::RegularizerAttrs>::operator()( + ::FlexFlow::RegularizerAttrs const &x) const { + return std::hash>{}( + x.raw_variant); +} +} // namespace std +namespace nlohmann { +::FlexFlow::RegularizerAttrs + adl_serializer<::FlexFlow::RegularizerAttrs>::from_json(json const &j) { + std::string key = j.at("type").template get(); + if (key == "l1") { + return ::FlexFlow::RegularizerAttrs{ + j.at("value").template get<::FlexFlow::L1RegularizerAttrs>()}; + } else if (key == "l2") { + return ::FlexFlow::RegularizerAttrs{ + j.at("value").template get<::FlexFlow::L2RegularizerAttrs>()}; + } else { + throw std::runtime_error(fmt::format("Unknown type key {}", key)); + } +} +void adl_serializer<::FlexFlow::RegularizerAttrs>::to_json( + json &j, ::FlexFlow::RegularizerAttrs const &x) { + j["__type"] = "RegularizerAttrs"; + switch (x.index()) { + case 0: { + j["type"] = "l1"; + j["value"] = x.get<::FlexFlow::L1RegularizerAttrs>(); + break; + } + case 1: { + j["type"] = "l2"; + j["value"] = x.get<::FlexFlow::L2RegularizerAttrs>(); + break; + } + default: { + throw std::runtime_error( + fmt::format("Unknown index {} for type RegularizerAttrs", x.index())); + } + } +} +} // namespace nlohmann +namespace rc { +Gen<::FlexFlow::RegularizerAttrs> + Arbitrary<::FlexFlow::RegularizerAttrs>::arbitrary() { + return gen::oneOf(gen::construct<::FlexFlow::RegularizerAttrs>( + gen::arbitrary<::FlexFlow::L1RegularizerAttrs>()), + gen::construct<::FlexFlow::RegularizerAttrs>( + gen::arbitrary<::FlexFlow::L2RegularizerAttrs>())); +} +} // namespace rc +namespace FlexFlow { +std::string format_as(::FlexFlow::RegularizerAttrs const &x) { + std::ostringstream oss; + switch (x.index()) { + case 0: { + oss << ""; + break; + } + case 1: { + oss << ""; + break; + } + default: { + throw std::runtime_error( + fmt::format("Unknown index {} for type RegularizerAttrs", x.index())); + break; + } + } + return oss.str(); +} +std::ostream &operator<<(std::ostream &s, + ::FlexFlow::RegularizerAttrs const &x) { + return s << fmt::to_string(x); +} +} // namespace FlexFlow diff --git a/lib/op-attrs/src/op-attrs/replica_parallel_dim.cc b/lib/op-attrs/src/op-attrs/replica_parallel_dim.cc new file mode 100644 index 
0000000000..44b17c8b44
--- /dev/null
+++ b/lib/op-attrs/src/op-attrs/replica_parallel_dim.cc
@@ -0,0 +1,9 @@
+#include "op-attrs/replica_parallel_dim.h"
+
+namespace FlexFlow {
+
+bool is_valid(ReplicaParallelDim const &d) {
+  return d.degree > 0;
+}
+
+} // namespace FlexFlow
diff --git a/lib/op-attrs/src/op-attrs/replica_parallel_dim.dtg.cc b/lib/op-attrs/src/op-attrs/replica_parallel_dim.dtg.cc
new file mode 100644
index 0000000000..a1256ad79a
--- /dev/null
+++ b/lib/op-attrs/src/op-attrs/replica_parallel_dim.dtg.cc
@@ -0,0 +1,91 @@
+// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT!
+// If you would like to modify this datatype, instead modify
+// lib/op-attrs/include/op-attrs/replica_parallel_dim.struct.toml
+/* proj-data
+{
+  "generated_from": "f501393070c8d55a05c43dd73a81a8d7"
+}
+*/
+
+#include "op-attrs/replica_parallel_dim.dtg.h"
+
+#include "op-attrs/replica_type.dtg.h"
+#include <sstream>
+
+namespace FlexFlow {
+ReplicaParallelDim::ReplicaParallelDim(
+    int const &degree, ::FlexFlow::ReplicaType const &replica_type)
+    : degree(degree), replica_type(replica_type) {}
+bool ReplicaParallelDim::operator==(ReplicaParallelDim const &other) const {
+  return std::tie(this->degree, this->replica_type) ==
+         std::tie(other.degree, other.replica_type);
+}
+bool ReplicaParallelDim::operator!=(ReplicaParallelDim const &other) const {
+  return std::tie(this->degree, this->replica_type) !=
+         std::tie(other.degree, other.replica_type);
+}
+bool ReplicaParallelDim::operator<(ReplicaParallelDim const &other) const {
+  return std::tie(this->degree, this->replica_type) <
+         std::tie(other.degree, other.replica_type);
+}
+bool ReplicaParallelDim::operator>(ReplicaParallelDim const &other) const {
+  return std::tie(this->degree, this->replica_type) >
+         std::tie(other.degree, other.replica_type);
+}
+bool ReplicaParallelDim::operator<=(ReplicaParallelDim const &other) const {
+  return std::tie(this->degree, this->replica_type) <=
+         std::tie(other.degree, other.replica_type);
+}
+bool ReplicaParallelDim::operator>=(ReplicaParallelDim const &other) const {
+  return std::tie(this->degree, this->replica_type) >=
+         std::tie(other.degree, other.replica_type);
+}
+} // namespace FlexFlow
+
+namespace std {
+size_t hash<FlexFlow::ReplicaParallelDim>::operator()(
+    FlexFlow::ReplicaParallelDim const &x) const {
+  size_t result = 0;
+  result ^=
+      std::hash<int>{}(x.degree) + 0x9e3779b9 + (result << 6) + (result >> 2);
+  result ^= std::hash<::FlexFlow::ReplicaType>{}(x.replica_type) + 0x9e3779b9 +
+            (result << 6) + (result >> 2);
+  return result;
+}
+} // namespace std
+
+namespace nlohmann {
+FlexFlow::ReplicaParallelDim
+    adl_serializer<FlexFlow::ReplicaParallelDim>::from_json(json const &j) {
+  return {j.at("degree").template get<int>(),
+          j.at("replica_type").template get<::FlexFlow::ReplicaType>()};
+}
+void adl_serializer<FlexFlow::ReplicaParallelDim>::to_json(
+    json &j, FlexFlow::ReplicaParallelDim const &v) {
+  j["__type"] = "ReplicaParallelDim";
+  j["degree"] = v.degree;
+  j["replica_type"] = v.replica_type;
+}
+} // namespace nlohmann
+
+namespace rc {
+Gen<FlexFlow::ReplicaParallelDim>
+    Arbitrary<FlexFlow::ReplicaParallelDim>::arbitrary() {
+  return gen::construct<FlexFlow::ReplicaParallelDim>(
+      gen::arbitrary<int>(), gen::arbitrary<::FlexFlow::ReplicaType>());
+}
+} // namespace rc
+
+namespace FlexFlow {
+std::string format_as(ReplicaParallelDim const &x) {
+  std::ostringstream oss;
+  oss << "<ReplicaParallelDim degree=" << x.degree
+      << " replica_type=" << x.replica_type << ">";
+  return oss.str();
+}
+std::ostream &operator<<(std::ostream &s, ReplicaParallelDim const &x) {
+  return s << fmt::to_string(x);
+}
+} // namespace FlexFlow
diff --git a/lib/op-attrs/src/op-attrs/replica_parallel_dim_set.cc b/lib/op-attrs/src/op-attrs/replica_parallel_dim_set.cc
new file mode 100644
index 0000000000..7ef228e97e --- /dev/null +++ b/lib/op-attrs/src/op-attrs/replica_parallel_dim_set.cc @@ -0,0 +1,36 @@ +#include "op-attrs/replica_parallel_dim_set.h" +#include "utils/exception.h" + +namespace FlexFlow { + +ReplicaParallelDimSet empty_replica_parallel_dim_set() { + return ReplicaParallelDimSet{1, 1}; +} + +int get_order_of_replica_type(ReplicaParallelDimSet const &s, + ReplicaType replica_type) { + switch (replica_type) { + case ReplicaType::SUM: + return s.sum_degree.value; + case ReplicaType::DISCARD_COPY: + return s.discard_copy_degree.value; + default: + throw mk_runtime_error(fmt::format("Unexpected ReplicaType value: {}", + static_cast(replica_type))); + } +} + +std::unordered_set + get_replica_dims(ReplicaParallelDimSet const &s) { + return std::unordered_set{ + ReplicaParallelDim{s.sum_degree.value, ReplicaType::SUM}, + ReplicaParallelDim{s.discard_copy_degree.value, + ReplicaType::DISCARD_COPY}, + }; +} + +bool is_valid(ReplicaParallelDimSet const &s) { + return s.sum_degree.value > 0 && s.discard_copy_degree.value > 0; +} + +} // namespace FlexFlow diff --git a/lib/op-attrs/src/op-attrs/replica_parallel_dim_set.dtg.cc b/lib/op-attrs/src/op-attrs/replica_parallel_dim_set.dtg.cc new file mode 100644 index 0000000000..f8782be01b --- /dev/null +++ b/lib/op-attrs/src/op-attrs/replica_parallel_dim_set.dtg.cc @@ -0,0 +1,101 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/op-attrs/include/op-attrs/replica_parallel_dim_set.struct.toml +/* proj-data +{ + "generated_from": "74230e2d18db5c059d3e7be0f25e746e" +} +*/ + +#include "op-attrs/replica_parallel_dim_set.dtg.h" + +#include "op-attrs/parallel_tensor_shape/discard_copy_degree.dtg.h" +#include "op-attrs/parallel_tensor_shape/sum_degree.dtg.h" +#include + +namespace FlexFlow { +ReplicaParallelDimSet::ReplicaParallelDimSet( + ::FlexFlow::SumDegree const &sum_degree, + ::FlexFlow::DiscardCopyDegree const &discard_copy_degree) + : sum_degree(sum_degree), discard_copy_degree(discard_copy_degree) {} +bool ReplicaParallelDimSet::operator==( + ReplicaParallelDimSet const &other) const { + return std::tie(this->sum_degree, this->discard_copy_degree) == + std::tie(other.sum_degree, other.discard_copy_degree); +} +bool ReplicaParallelDimSet::operator!=( + ReplicaParallelDimSet const &other) const { + return std::tie(this->sum_degree, this->discard_copy_degree) != + std::tie(other.sum_degree, other.discard_copy_degree); +} +bool ReplicaParallelDimSet::operator<( + ReplicaParallelDimSet const &other) const { + return std::tie(this->sum_degree, this->discard_copy_degree) < + std::tie(other.sum_degree, other.discard_copy_degree); +} +bool ReplicaParallelDimSet::operator>( + ReplicaParallelDimSet const &other) const { + return std::tie(this->sum_degree, this->discard_copy_degree) > + std::tie(other.sum_degree, other.discard_copy_degree); +} +bool ReplicaParallelDimSet::operator<=( + ReplicaParallelDimSet const &other) const { + return std::tie(this->sum_degree, this->discard_copy_degree) <= + std::tie(other.sum_degree, other.discard_copy_degree); +} +bool ReplicaParallelDimSet::operator>=( + ReplicaParallelDimSet const &other) const { + return std::tie(this->sum_degree, this->discard_copy_degree) >= + std::tie(other.sum_degree, other.discard_copy_degree); +} +} // namespace FlexFlow + +namespace std { +size_t hash::operator()( + FlexFlow::ReplicaParallelDimSet const &x) const { + size_t result = 0; + result ^= 
std::hash<::FlexFlow::SumDegree>{}(x.sum_degree) + 0x9e3779b9 + + (result << 6) + (result >> 2); + result ^= std::hash<::FlexFlow::DiscardCopyDegree>{}(x.discard_copy_degree) + + 0x9e3779b9 + (result << 6) + (result >> 2); + return result; +} +} // namespace std + +namespace nlohmann { +FlexFlow::ReplicaParallelDimSet + adl_serializer::from_json(json const &j) { + return {j.at("sum_degree").template get<::FlexFlow::SumDegree>(), + j.at("discard_copy_degree") + .template get<::FlexFlow::DiscardCopyDegree>()}; +} +void adl_serializer::to_json( + json &j, FlexFlow::ReplicaParallelDimSet const &v) { + j["__type"] = "ReplicaParallelDimSet"; + j["sum_degree"] = v.sum_degree; + j["discard_copy_degree"] = v.discard_copy_degree; +} +} // namespace nlohmann + +namespace rc { +Gen + Arbitrary::arbitrary() { + return gen::construct( + gen::arbitrary<::FlexFlow::SumDegree>(), + gen::arbitrary<::FlexFlow::DiscardCopyDegree>()); +} +} // namespace rc + +namespace FlexFlow { +std::string format_as(ReplicaParallelDimSet const &x) { + std::ostringstream oss; + oss << ""; + return oss.str(); +} +std::ostream &operator<<(std::ostream &s, ReplicaParallelDimSet const &x) { + return s << fmt::to_string(x); +} +} // namespace FlexFlow diff --git a/lib/op-attrs/src/op-attrs/replica_type.dtg.cc b/lib/op-attrs/src/op-attrs/replica_type.dtg.cc new file mode 100644 index 0000000000..d0410c49e2 --- /dev/null +++ b/lib/op-attrs/src/op-attrs/replica_type.dtg.cc @@ -0,0 +1,70 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/op-attrs/include/op-attrs/replica_type.enum.toml +/* proj-data +{ + "generated_from": "6ecba7a6851b8bea93705bba24661149" +} +*/ + +#include "op-attrs/replica_type.dtg.h" + +#include +#include + +namespace std { +size_t hash::operator()(FlexFlow::ReplicaType x) const { + return std::hash{}(static_cast(x)); +} +} // namespace std +namespace FlexFlow { +std::string format_as(ReplicaType x) { + switch (x) { + case ReplicaType::SUM: + return "SUM"; + case ReplicaType::DISCARD_COPY: + return "DISCARD_COPY"; + default: + std::ostringstream oss; + oss << "Unknown ReplicaType value " << static_cast(x); + throw std::runtime_error(oss.str()); + } +} +std::ostream &operator<<(std::ostream &s, ReplicaType x) { + return s << fmt::to_string(x); +} +} // namespace FlexFlow +namespace FlexFlow { +void to_json(::nlohmann::json &j, ReplicaType x) { + switch (x) { + case ReplicaType::SUM: + j = "SUM"; + break; + case ReplicaType::DISCARD_COPY: + j = "DISCARD_COPY"; + break; + default: + std::ostringstream oss; + oss << "Unknown ReplicaType value " << static_cast(x); + throw std::runtime_error(oss.str()); + } +} +void from_json(::nlohmann::json const &j, ReplicaType &x) { + std::string as_str = j.get(); + if (as_str == "SUM") { + x = ReplicaType::SUM; + } else if (as_str == "DISCARD_COPY") { + x = ReplicaType::DISCARD_COPY; + } else { + std::ostringstream oss; + oss << "Unknown ReplicaType value " << as_str; + throw std::runtime_error(oss.str()); + } +} +} // namespace FlexFlow +namespace rc { +Gen Arbitrary::arbitrary() { + return gen::element( + FlexFlow::ReplicaType::SUM, FlexFlow::ReplicaType::DISCARD_COPY); +} +} // namespace rc diff --git a/lib/op-attrs/src/op-attrs/shard_parallel_dim.cc b/lib/op-attrs/src/op-attrs/shard_parallel_dim.cc new file mode 100644 index 0000000000..d27a857723 --- /dev/null +++ b/lib/op-attrs/src/op-attrs/shard_parallel_dim.cc @@ -0,0 +1,9 @@ +#include "op-attrs/shard_parallel_dim.h" + +namespace FlexFlow { + 
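The generated `std::hash` specializations in these files all fold their fields into the seed with the same boost-style hash_combine step (the golden-ratio constant 0x9e3779b9 plus two shifts). Factored out, the step that the generated bodies inline per field looks like this sketch:

```cpp
#include <cstddef>
#include <functional>

// Boost-style hash_combine: XOR the field's hash into the running seed,
// perturbed so that equal fields in different positions hash differently.
template <typename T>
void hash_combine(std::size_t &seed, T const &v) {
  seed ^= std::hash<T>{}(v) + 0x9e3779b9 + (seed << 6) + (seed >> 2);
}
```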
+bool is_valid(ShardParallelDim const &d) {
+  return d.degree > 0 && d.size > 0 && (d.size % d.degree) == 0;
+}
+
+} // namespace FlexFlow
diff --git a/lib/op-attrs/src/op-attrs/shard_parallel_dim.dtg.cc b/lib/op-attrs/src/op-attrs/shard_parallel_dim.dtg.cc
new file mode 100644
index 0000000000..9566eb486b
--- /dev/null
+++ b/lib/op-attrs/src/op-attrs/shard_parallel_dim.dtg.cc
@@ -0,0 +1,89 @@
+// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT!
+// If you would like to modify this datatype, instead modify
+// lib/op-attrs/include/op-attrs/shard_parallel_dim.struct.toml
+/* proj-data
+{
+  "generated_from": "18e074f80556d90b9b27d6515bbf9071"
+}
+*/
+
+#include "op-attrs/shard_parallel_dim.dtg.h"
+
+#include <sstream>
+
+namespace FlexFlow {
+ShardParallelDim::ShardParallelDim(size_t const &size, int const &degree)
+    : size(size), degree(degree) {}
+bool ShardParallelDim::operator==(ShardParallelDim const &other) const {
+  return std::tie(this->size, this->degree) ==
+         std::tie(other.size, other.degree);
+}
+bool ShardParallelDim::operator!=(ShardParallelDim const &other) const {
+  return std::tie(this->size, this->degree) !=
+         std::tie(other.size, other.degree);
+}
+bool ShardParallelDim::operator<(ShardParallelDim const &other) const {
+  return std::tie(this->size, this->degree) <
+         std::tie(other.size, other.degree);
+}
+bool ShardParallelDim::operator>(ShardParallelDim const &other) const {
+  return std::tie(this->size, this->degree) >
+         std::tie(other.size, other.degree);
+}
+bool ShardParallelDim::operator<=(ShardParallelDim const &other) const {
+  return std::tie(this->size, this->degree) <=
+         std::tie(other.size, other.degree);
+}
+bool ShardParallelDim::operator>=(ShardParallelDim const &other) const {
+  return std::tie(this->size, this->degree) >=
+         std::tie(other.size, other.degree);
+}
+} // namespace FlexFlow
+
+namespace std {
+size_t hash<FlexFlow::ShardParallelDim>::operator()(
+    FlexFlow::ShardParallelDim const &x) const {
+  size_t result = 0;
+  result ^=
+      std::hash<size_t>{}(x.size) + 0x9e3779b9 + (result << 6) + (result >> 2);
+  result ^=
+      std::hash<int>{}(x.degree) + 0x9e3779b9 + (result << 6) + (result >> 2);
+  return result;
+}
+} // namespace std
+
+namespace nlohmann {
+FlexFlow::ShardParallelDim
+    adl_serializer<FlexFlow::ShardParallelDim>::from_json(json const &j) {
+  return {j.at("size").template get<size_t>(),
+          j.at("degree").template get<int>()};
+}
+void adl_serializer<FlexFlow::ShardParallelDim>::to_json(
+    json &j, FlexFlow::ShardParallelDim const &v) {
+  j["__type"] = "ShardParallelDim";
+  j["size"] = v.size;
+  j["degree"] = v.degree;
+}
+} // namespace nlohmann
+
+namespace rc {
+Gen<FlexFlow::ShardParallelDim>
+    Arbitrary<FlexFlow::ShardParallelDim>::arbitrary() {
+  return gen::construct<FlexFlow::ShardParallelDim>(gen::arbitrary<size_t>(),
+                                                    gen::arbitrary<int>());
+}
+} // namespace rc
+
+namespace FlexFlow {
+std::string format_as(ShardParallelDim const &x) {
+  std::ostringstream oss;
+  oss << "<ShardParallelDim size=" << x.size << " degree=" << x.degree << ">";
+  return oss.str();
+}
+std::ostream &operator<<(std::ostream &s, ShardParallelDim const &x) {
+  return s << fmt::to_string(x);
+}
+} // namespace FlexFlow
diff --git a/lib/op-attrs/src/op-attrs/tensor_dims.cc b/lib/op-attrs/src/op-attrs/tensor_dims.cc
new file mode 100644
index 0000000000..ed40f509d9
--- /dev/null
+++ b/lib/op-attrs/src/op-attrs/tensor_dims.cc
@@ -0,0 +1,52 @@
+#include "op-attrs/tensor_dims.h"
+#include "op-attrs/replica_parallel_dim_set.h"
+#include "op-attrs/shard_parallel_dim.dtg.h"
+#include "utils/containers.h"
+#include "utils/containers/zip_vectors.h"
+#include "utils/integer_conversions.h"
+
+namespace FlexFlow {
+
+FFOrdered<size_t> const &ff_ordered(TensorDims const &dims) {
+  return dims.ff_ordered;
+}
+
+size_t num_dims(TensorDims const &dims)
{ + return dims.ff_ordered.size(); +} + +size_t dim_at_idx(TensorDims const &dims, ff_dim_t idx) { + return dims.ff_ordered.at(idx); +} + +size_t &dim_at_idx(TensorDims &dims, ff_dim_t idx) { + return dims.ff_ordered.at(idx); +} + +ParallelTensorDims lift_to_parallel(TensorDims const &dims) { + std::vector shard_degrees(num_dims(dims), + 1); // 1 repeated num_dims(dims) times + return lift_to_parallel_with_degrees(dims, 1, 1, shard_degrees); +} + +ParallelTensorDims + lift_to_parallel_with_degrees(TensorDims const &dims, + SumDegree sum_degree, + DiscardCopyDegree discard_copy_degree, + FFOrdered const &shard_degrees) { + std::vector lifted = + transform(zip(as_vector(dims.ff_ordered), as_vector(shard_degrees)), + [](std::pair const &p) { + size_t size = p.first; + int degree = p.second; + return ShardParallelDim(size, degree); + }); + + return ParallelTensorDims{FFOrdered{lifted}, + ReplicaParallelDimSet{ + sum_degree, + discard_copy_degree, + }}; +} + +} // namespace FlexFlow diff --git a/lib/op-attrs/src/op-attrs/tensor_dims.dtg.cc b/lib/op-attrs/src/op-attrs/tensor_dims.dtg.cc new file mode 100644 index 0000000000..909be323ac --- /dev/null +++ b/lib/op-attrs/src/op-attrs/tensor_dims.dtg.cc @@ -0,0 +1,78 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/op-attrs/include/op-attrs/tensor_dims.struct.toml +/* proj-data +{ + "generated_from": "5beb89eeae9eba303f90e726c794375d" +} +*/ + +#include "op-attrs/tensor_dims.dtg.h" + +#include "op-attrs/dim_ordered.h" +#include + +namespace FlexFlow { +TensorDims::TensorDims(::FlexFlow::FFOrdered const &ff_ordered) + : ff_ordered(ff_ordered) {} +bool TensorDims::operator==(TensorDims const &other) const { + return std::tie(this->ff_ordered) == std::tie(other.ff_ordered); +} +bool TensorDims::operator!=(TensorDims const &other) const { + return std::tie(this->ff_ordered) != std::tie(other.ff_ordered); +} +bool TensorDims::operator<(TensorDims const &other) const { + return std::tie(this->ff_ordered) < std::tie(other.ff_ordered); +} +bool TensorDims::operator>(TensorDims const &other) const { + return std::tie(this->ff_ordered) > std::tie(other.ff_ordered); +} +bool TensorDims::operator<=(TensorDims const &other) const { + return std::tie(this->ff_ordered) <= std::tie(other.ff_ordered); +} +bool TensorDims::operator>=(TensorDims const &other) const { + return std::tie(this->ff_ordered) >= std::tie(other.ff_ordered); +} +} // namespace FlexFlow + +namespace std { +size_t hash::operator()( + FlexFlow::TensorDims const &x) const { + size_t result = 0; + result ^= std::hash<::FlexFlow::FFOrdered>{}(x.ff_ordered) + + 0x9e3779b9 + (result << 6) + (result >> 2); + return result; +} +} // namespace std + +namespace nlohmann { +FlexFlow::TensorDims + adl_serializer::from_json(json const &j) { + return {j.at("ff_ordered").template get<::FlexFlow::FFOrdered>()}; +} +void adl_serializer::to_json( + json &j, FlexFlow::TensorDims const &v) { + j["__type"] = "TensorDims"; + j["ff_ordered"] = v.ff_ordered; +} +} // namespace nlohmann + +namespace rc { +Gen Arbitrary::arbitrary() { + return gen::construct( + gen::arbitrary<::FlexFlow::FFOrdered>()); +} +} // namespace rc + +namespace FlexFlow { +std::string format_as(TensorDims const &x) { + std::ostringstream oss; + oss << ""; + return oss.str(); +} +std::ostream &operator<<(std::ostream &s, TensorDims const &x) { + return s << fmt::to_string(x); +} +} // namespace FlexFlow diff --git a/lib/op-attrs/src/op-attrs/tensor_shape.cc 
b/lib/op-attrs/src/op-attrs/tensor_shape.cc new file mode 100644 index 0000000000..850bea6d00 --- /dev/null +++ b/lib/op-attrs/src/op-attrs/tensor_shape.cc @@ -0,0 +1,18 @@ +#include "op-attrs/tensor_shape.h" +#include "op-attrs/tensor_dims.h" + +namespace FlexFlow { + +size_t num_dims(TensorShape const &s) { + return s.dims.ff_ordered.size(); +} + +size_t dim_at_idx(TensorShape const &s, ff_dim_t idx) { + return dim_at_idx(s.dims, idx); +} + +size_t &dim_at_idx(TensorShape &s, ff_dim_t idx) { + return dim_at_idx(s.dims, idx); +} + +} // namespace FlexFlow diff --git a/lib/op-attrs/src/op-attrs/tensor_shape.dtg.cc b/lib/op-attrs/src/op-attrs/tensor_shape.dtg.cc new file mode 100644 index 0000000000..92b31930fa --- /dev/null +++ b/lib/op-attrs/src/op-attrs/tensor_shape.dtg.cc @@ -0,0 +1,92 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/op-attrs/include/op-attrs/tensor_shape.struct.toml +/* proj-data +{ + "generated_from": "ef6fa5088b89d6da4dc8bddf0a6d3294" +} +*/ + +#include "op-attrs/tensor_shape.dtg.h" + +#include "op-attrs/datatype.dtg.h" +#include "op-attrs/tensor_dims.dtg.h" +#include + +namespace FlexFlow { +TensorShape::TensorShape(::FlexFlow::TensorDims const &dims, + ::FlexFlow::DataType const &data_type) + : dims(dims), data_type(data_type) {} +bool TensorShape::operator==(TensorShape const &other) const { + return std::tie(this->dims, this->data_type) == + std::tie(other.dims, other.data_type); +} +bool TensorShape::operator!=(TensorShape const &other) const { + return std::tie(this->dims, this->data_type) != + std::tie(other.dims, other.data_type); +} +bool TensorShape::operator<(TensorShape const &other) const { + return std::tie(this->dims, this->data_type) < + std::tie(other.dims, other.data_type); +} +bool TensorShape::operator>(TensorShape const &other) const { + return std::tie(this->dims, this->data_type) > + std::tie(other.dims, other.data_type); +} +bool TensorShape::operator<=(TensorShape const &other) const { + return std::tie(this->dims, this->data_type) <= + std::tie(other.dims, other.data_type); +} +bool TensorShape::operator>=(TensorShape const &other) const { + return std::tie(this->dims, this->data_type) >= + std::tie(other.dims, other.data_type); +} +} // namespace FlexFlow + +namespace std { +size_t hash::operator()( + FlexFlow::TensorShape const &x) const { + size_t result = 0; + result ^= std::hash<::FlexFlow::TensorDims>{}(x.dims) + 0x9e3779b9 + + (result << 6) + (result >> 2); + result ^= std::hash<::FlexFlow::DataType>{}(x.data_type) + 0x9e3779b9 + + (result << 6) + (result >> 2); + return result; +} +} // namespace std + +namespace nlohmann { +FlexFlow::TensorShape + adl_serializer::from_json(json const &j) { + return {j.at("dims").template get<::FlexFlow::TensorDims>(), + j.at("data_type").template get<::FlexFlow::DataType>()}; +} +void adl_serializer::to_json( + json &j, FlexFlow::TensorShape const &v) { + j["__type"] = "TensorShape"; + j["dims"] = v.dims; + j["data_type"] = v.data_type; +} +} // namespace nlohmann + +namespace rc { +Gen Arbitrary::arbitrary() { + return gen::construct( + gen::arbitrary<::FlexFlow::TensorDims>(), + gen::arbitrary<::FlexFlow::DataType>()); +} +} // namespace rc + +namespace FlexFlow { +std::string format_as(TensorShape const &x) { + std::ostringstream oss; + oss << ""; + return oss.str(); +} +std::ostream &operator<<(std::ostream &s, TensorShape const &x) { + return s << fmt::to_string(x); +} +} // namespace FlexFlow diff --git 
a/lib/op-attrs/src/op.cc b/lib/op-attrs/src/op.cc deleted file mode 100644 index 5bc5498d6e..0000000000 --- a/lib/op-attrs/src/op.cc +++ /dev/null @@ -1,24 +0,0 @@ -#include "op-attrs/op.h" - -namespace FlexFlow { - -std::string get_operator_type_name(Op op) { - return fmt::to_string(op); -} - -bool is_parallel_op(OperatorType const &t) { - switch (t) { - case Op::REPARTITION: - case Op::COMBINE: - case Op::REPLICATE: - case Op::REDUCTION: - case Op::BATCH: - case Op::PIPELINE: - case Op::FUSED_PARALLEL: - return true; - default: - return false; - } -} - -} // namespace FlexFlow diff --git a/lib/op-attrs/src/operator_attrs.cc b/lib/op-attrs/src/operator_attrs.cc index a524ab3d14..7a0027fe61 100644 --- a/lib/op-attrs/src/operator_attrs.cc +++ b/lib/op-attrs/src/operator_attrs.cc @@ -193,7 +193,7 @@ struct IsValidFunctor { bool is_valid(PCGOperatorAttrs const &attrs, std::vector const &input_shapes) { - return visit(IsValidFunctor{input_shapes}, attrs); + NOT_IMPLEMENTED(); } /* int num_outputs(OperatorParameters const &o) { */ diff --git a/lib/op-attrs/src/parallel_dim.cc b/lib/op-attrs/src/parallel_dim.cc deleted file mode 100644 index e103625fab..0000000000 --- a/lib/op-attrs/src/parallel_dim.cc +++ /dev/null @@ -1,12 +0,0 @@ -#include "op-attrs/parallel_dim.h" - -namespace FlexFlow { - -bool is_valid(ParallelDim const &dim) { - return dim.size > 0 && dim.degree >= 1 && dim.size % dim.degree == 0; -} - -bool is_replica_dim(ParallelDim const &dim) { - return dim.is_replica_dim; -} -} // namespace FlexFlow diff --git a/lib/op-attrs/src/parallel_dim_mapping_record_solver.cc b/lib/op-attrs/src/parallel_dim_mapping_record_solver.cc deleted file mode 100644 index c7e70bb906..0000000000 --- a/lib/op-attrs/src/parallel_dim_mapping_record_solver.cc +++ /dev/null @@ -1,362 +0,0 @@ -#include "parallel_dim_mapping_record_solver.h" -#include "op-attrs/parallel_tensor_shape.h" -#include -#include - -namespace FlexFlow { - -std::vector construct_weight_parallel_dims( - std::vector &records, - std::vector> mappings, - int input_idx, - int weight_idx) { - - std::vector output; - std::transform(mappings.cbegin(), - mappings.cend(), - output.begin(), - [&](std::tuple const &mapping) { - return construct_weight_parallel_dims(std::get<0>(mapping), - std::get<2>(mapping), - input_idx, - weight_idx, - std::get<1>(mapping)); - }); - return output; -} - -std::vector construct_output_parallel_dims( - std::vector> mappings, - int input_idx, - int output_idx) { - NOT_IMPLEMENTED(); -} - -std::vector construct_weight_parallel_dims( - std::vector> mappings, - int input_idx, - int weight_idx) { - NOT_IMPLEMENTED(); -} - -ParallelDimMappingRecord - construct_output_parallel_dims(int input_dim, - int output_dim, - int input_idx, - int output_idx, - std::optional operation) { - NOT_IMPLEMENTED(); -} - -ParallelDimMappingRecord - construct_weight_parallel_dims(int input_dim, - int weight_dim, - int input_idx, - int weight_idx, - std::optional operation) { - NOT_IMPLEMENTED(); -} -/* int get_output_to_input_dim_mapping(ParallelTensorShape const &output, */ -/* int output_dim, */ -/* ParallelTensorShape const &input) { */ -/* int output_idx = -1, input_idx = -1; */ -/* for (int i = 0; i < numOutputs; i++) { */ -/* if (output == outputs[i]) { */ -/* output_idx = i; */ -/* } */ -/* } */ -/* for (int i = 0; i < numInputs; i++) { */ -/* if (input == inputs[i]) { */ -/* input_idx = i; */ -/* } */ -/* } */ -/* assert(output_idx != -1); */ -/* assert(input_idx != -1); */ -/* for (size_t i = 0; i < parallel_dims_mapping->size(); 
i++) { */ -/* if ((*parallel_dims_mapping)[i].output_idx != output_idx) { */ -/* continue; */ -/* } */ -/* if ((*parallel_dims_mapping)[i].output_dim != output_dim) { */ -/* continue; */ -/* } */ -/* if ((*parallel_dims_mapping)[i].input_idx != input_idx) { */ -/* continue; */ -/* } */ -/* // Check validness */ -/* assert((*parallel_dims_mapping)[i].weight_idx = -1); */ -/* assert((*parallel_dims_mapping)[i].weight_dim = -1); */ -/* return (*parallel_dims_mapping)[i].input_dim; */ -/* } */ -/* assert(false); */ -/* return -1; */ -/* } */ - -/* int get_output_to_weight_dim_mapping(const ParallelTensor output, */ -/* int output_dim, */ -/* const ParallelTensor weight) { */ -/* int output_idx = -1, weight_idx = -1; */ -/* for (int i = 0; i < numOutputs; i++) { */ -/* if (output == outputs[i]) { */ -/* output_idx = i; */ -/* } */ -/* } */ -/* for (int i = 0; i < numInputs; i++) { */ -/* if (weight == weights[i]) { */ -/* weight_idx = i; */ -/* } */ -/* } */ -/* assert(output_idx != -1); */ -/* assert(weight_idx != -1); */ -/* for (size_t i = 0; i < parallel_dims_mapping->size(); i++) { */ -/* if ((*parallel_dims_mapping)[i].output_idx != output_idx) { */ -/* continue; */ -/* } */ -/* if ((*parallel_dims_mapping)[i].output_dim != output_dim) { */ -/* continue; */ -/* } */ -/* if ((*parallel_dims_mapping)[i].weight_idx != weight_idx) { */ -/* continue; */ -/* } */ -/* // Check validness */ -/* assert((*parallel_dims_mapping)[i].input_idx = -1); */ -/* assert((*parallel_dims_mapping)[i].input_dim = -1); */ -/* return (*parallel_dims_mapping)[i].weight_dim; */ -/* } */ -/* assert(false); */ -/* return -1; */ -/* } */ - -/* bool check_output_input_weight_parallel_dims(bool allocate_weights) const { - */ -/* // if (!allocate_weights) { */ -/* // assert(this->numWeights == 0); */ -/* // } */ - -/* for (ParallelDimMappingRecord const &record : *parallel_dims_mapping) { */ -/* assert(record.input_idx < this->numInputs); */ -/* assert(record.input_dim < this->inputs[record.input_idx]->num_dims); */ -/* ParallelDim const &input_dim = */ -/* inputs[record.input_idx]->dims[record.input_dim]; */ -/* /1* assert (input_dim.degree != ParallelDim::UNKNOWN_DEGREE); *1/ */ -/* /1* assert (input_dim.parallel_idx != ParallelDim::UNKNOWN_INDEX); *1/ */ - -/* ParallelDim other_dim; */ -/* switch (record.get_type()) { */ -/* case MappingRecordType::INPUT_OUTPUT: */ -/* assert(record.output_idx < this->numOutputs); */ -/* assert(record.output_dim < - * this->outputs[record.output_idx]->num_dims); */ -/* other_dim = outputs[record.output_idx]->dims[record.output_dim]; */ -/* break; */ -/* case MappingRecordType::INPUT_WEIGHT: */ -/* if (!allocate_weights) { */ -/* continue; */ -/* } */ -/* if (record.weight_idx >= this->numWeights) { */ -/* // The case where some weights are not used (e.g., no bias for - * linear) */ -/* continue; */ -/* } */ -/* assert(record.weight_dim < - * this->weights[record.weight_idx]->num_dims); */ -/* other_dim = weights[record.weight_idx]->dims[record.weight_dim]; */ -/* break; */ -/* } */ - -/* assert(other_dim.degree == input_dim.degree); */ -/* assert(other_dim.parallel_idx == input_dim.parallel_idx); */ -/* } */ -/* return true; */ -/* } */ - -/* bool check_output_input_weight_same_machine_view() const { */ -/* assert(numOutputs > 0); */ -/* MachineView machine_view = outputs[0]->machine_view; */ -/* for (int i = 0; i < numOutputs; i++) { */ -/* if (outputs[i]->machine_view != machine_view) { */ -/* return false; */ -/* } */ -/* } */ -/* for (int i = 0; i < numInputs; i++) { */ -/* if 
(inputs[i]->machine_view != machine_view) { */ -/* return false; */ -/* } */ -/* } */ -/* for (int i = 0; i < numWeights; i++) { */ -/* if (weights[i]->machine_view != machine_view) { */ -/* return false; */ -/* } */ -/* } */ -/* return true; */ -/* } */ - -std::vector construct_weight_parallel_dims( - std::vector> mappings, int input_idx, int weight_idx) { - std::vector output; - std::transform(mappings.cbegin(), - mappings.cend(), - output.begin(), - [&](std::pair const &mapping) { - return construct_weight_parallel_dims( - mapping.first, mapping.second, input_idx, weight_idx); - }); - return output; -} - -void construct_weight_parallel_dims( - std::vector &records, - int input_dim, - int weight_dim, - int input_idx, - int weight_idx, - std::optional operation) { - records.push_back(ParallelDimMappingRecord::input_weight_record( - input_idx, input_dim, weight_idx, weight_dim, operation)); -} - -/* void ParallelDimMappingRecordSolver::register_weight_parallel_dims( */ -/* std::vector> mappings, int input_idx, int weight_idx) - * { */ -/* construct_weight_parallel_dims( */ -/* *this->parallel_dims_mapping, mappings, input_idx, weight_idx); */ -/* } */ - -/* void register_weight_parallel_dims( */ -/* std::vector> mappings, */ -/* int input_idx, */ -/* int weight_idx) { */ -/* construct_weight_parallel_dims( */ -/* *this->parallel_dims_mapping, mappings, input_idx, weight_idx); */ -/* } */ - -/* void register_weight_parallel_dims( */ -/* int input_dim, */ -/* int weight_dim, */ -/* int input_idx, */ -/* int weight_idx, */ -/* tl::optional operation) { */ -/* construct_weight_parallel_dims(*this->parallel_dims_mapping, */ -/* input_dim, */ -/* weight_dim, */ -/* input_idx, */ -/* weight_idx, */ -/* operation); */ -/* } */ - -void construct_output_parallel_dims( - std::vector &records, - std::vector> mappings, - int input_idx, - int output_idx) { - for (std::tuple const &mapping : mappings) { - construct_output_parallel_dims(std::get<0>(mapping), - std::get<2>(mapping), - input_idx, - output_idx, - std::get<1>(mapping)); - } -} - -void construct_output_parallel_dims( - std::vector &records, - std::vector> mappings, - int input_idx, - int output_idx) { - for (std::pair const &mapping : mappings) { - construct_output_parallel_dims( - mapping.first, mapping.second, input_idx, output_idx); - } -} - -void construct_output_parallel_dims( - std::vector &records, - int input_dim, - int output_dim, - int input_idx, - int output_idx, - std::optional operation) { - records.push_back(ParallelDimMappingRecord::input_output_record( - input_idx, input_dim, output_idx, output_dim, operation)); -} - -/* void register_output_parallel_dims( */ -/* std::vector> mappings, int input_idx, int output_idx) - * { */ -/* construct_output_parallel_dims( */ -/* *this->parallel_dims_mapping, mappings, input_idx, output_idx); */ -/* } */ - -/* void register_output_parallel_dims( */ -/* std::vector> mappings, */ -/* int input_idx, */ -/* int output_idx) { */ -/* construct_output_parallel_dims( */ -/* *this->parallel_dims_mapping, mappings, input_idx, output_idx); */ -/* } */ - -/* void register_output_parallel_dims( */ -/* int input_dim, */ -/* int output_dim, */ -/* int input_idx, */ -/* int output_idx, */ -/* tl::optional operation) { */ -/* construct_output_parallel_dims(*this->parallel_dims_mapping, */ -/* input_dim, */ -/* output_dim, */ -/* input_idx, */ -/* output_idx, */ -/* operation); */ -/* } */ - -/* ParallelDimMappingSolution solve_parallel_dim_mappings( */ -/* std::vector const &mappings, */ -/* std::vector 
const &inputs, */ -/* int numWeights, int numOutputs) { */ - -/* ParallelDimMappingSolution solution = [&]() -> ParallelDimMappingSolution { - */ -/* std::vector weight_shapes(numWeights); */ -/* std::vector output_shapes(numOutputs); */ -/* return { weight_shapes, output_shapes }; */ -/* }(); */ - -/* for (ParallelDimMappingRecord const &record : mappings) { */ -/* ParallelDim const &input_dim = - * inputs.at(record.input_idx).at(record.input_dim); */ - -/* switch (record.get_type()) { */ -/* case MappingRecordType::INPUT_OUTPUT: { */ -/* ParallelDim &output_dim = - * solution.output_shapes.at(record.output_idx).at(record.output_dim); */ -/* output_dim.degree = input_dim.degree; */ -/* output_dim.parallel_idx = input_dim.parallel_idx; */ - -/* if (output_dim.is_replica_dim) { */ -/* output_dim.size = input_dim.degree; */ -/* } */ -/* } break; */ -/* case MappingRecordType::INPUT_WEIGHT: { */ -/* ParallelDim &weight_dim = - * solution.weight_shapes.at(record.weight_idx).at(record.weight_dim); */ -/* weight_dim.degree = input_dim.degree; */ -/* weight_dim.parallel_idx = input_dim.parallel_idx; */ - -/* if (weight_dim.is_replica_dim) { */ -/* weight_dim.size = input_dim.degree; */ -/* } */ -/* } break; */ -/* } */ -/* } */ - -/* return solution; */ -/* } */ - -ParallelDimMappingSolution solve_parallel_dim_mappings( - std::vector const &mappings, - std::vector const &input, - int numWeights, - int numOutputs) { - NOT_IMPLEMENTED(); -} - -} // namespace FlexFlow diff --git a/lib/op-attrs/src/parallel_dim_mapping_record_solver.h b/lib/op-attrs/src/parallel_dim_mapping_record_solver.h deleted file mode 100644 index a46192edeb..0000000000 --- a/lib/op-attrs/src/parallel_dim_mapping_record_solver.h +++ /dev/null @@ -1,106 +0,0 @@ -/** - * @file - * @warning This is legacy code the should be removed - * (partially tracked in - * https://github.com/flexflow/FlexFlow/issues/519). - * @brief Helper functions for computing data dependencies of parallel - * operators. 
Functions based on an incorrect abstraction that should eventually - * be removed in favor of something like https://doi.org/10.1145/3302424.3303953 - */ - -#ifndef _FLEXFLOW_OP_META_SRC_PARELLEL_DIM_MAPPING_RECORD_SOLVER_H -#define _FLEXFLOW_OP_META_SRC_PARELLEL_DIM_MAPPING_RECORD_SOLVER_H - -#include "op-attrs/parallel_tensor_shape.h" -#include "parallel_dim_mapping_record.h" - -namespace FlexFlow { - -std::vector - construct_weight_parallel_dims(std::vector> mappings, - int input_idx = 0, - int weight_idx = 0); -std::vector construct_weight_parallel_dims( - std::vector> mappings, - int input_idx = 0, - int weight_idx = 0); -ParallelDimMappingRecord construct_weight_parallel_dims( - int input_dim, - int weight_dim, - int input_idx = 0, - int weight_idx = 0, - std::optional operation = std::nullopt); - -std::vector - construct_output_parallel_dims(std::vector> mappings, - int input_idx = 0, - int output_idx = 0); -std::vector construct_output_parallel_dims( - std::vector> mappings, - int input_idx = 0, - int output_idx = 0); -ParallelDimMappingRecord construct_output_parallel_dims( - int input_dim, - int output_dim, - int input_idx = 0, - int output_idx = 0, - std::optional operation = std::nullopt); - -struct ParallelDimMappingSolution { - std::vector weight_shapes; - std::vector output_shapes; -}; - -ParallelDimMappingSolution solve_parallel_dim_mappings( - std::vector const &mappings, - std::vector const &input, - int numWeights, - int numOutputs); - -/* class ParallelDimMappingRecordSolver { */ -/* /1* void register_weight_parallel_dims(std::vector> - * mappings, *1/ */ -/* /1* int input_idx = 0, *1/ */ -/* /1* int weight_idx = 0); *1/ */ - -/* /1* void register_output_parallel_dims(std::vector> - * mappings, *1/ */ -/* /1* int input_idx = 0, *1/ */ -/* /1* int output_idx = 0); *1/ */ - -/* /1* int get_output_to_input_dim_mapping(const ParallelTensor output, *1/ */ -/* /1* int output_dim, *1/ */ -/* /1* const ParallelTensor input); *1/ */ -/* /1* int get_output_to_weight_dim_mapping(const ParallelTensor output, *1/ - */ -/* /1* int output_dim, *1/ */ -/* /1* const ParallelTensor weight); *1/ - */ -/* void register_weight_parallel_dims( */ -/* std::vector> mappings, */ -/* int input_idx = 0, */ -/* int weight_idx = 0); */ -/* void register_weight_parallel_dims( */ -/* int input_dim, */ -/* int weight_dim, */ -/* int input_idx = 0, */ -/* int weight_idx = 0, */ -/* std::optional operation = std::nullopt); */ -/* void register_output_parallel_dims( */ -/* std::vector> mappings, */ -/* int input_idx = 0, */ -/* int output_idx = 0); */ -/* void register_output_parallel_dims( */ -/* int input_dim, */ -/* int output_dim, */ -/* int input_idx = 0, */ -/* int output_idx = 0, */ -/* std::optional operation = std::nullopt); */ - -/* private: */ -/* std::vector *parallel_dims_mapping; */ -/* }; */ - -} // namespace FlexFlow - -#endif diff --git a/lib/op-attrs/src/parallel_tensor_shape.cc b/lib/op-attrs/src/parallel_tensor_shape.cc deleted file mode 100644 index e226c38eac..0000000000 --- a/lib/op-attrs/src/parallel_tensor_shape.cc +++ /dev/null @@ -1,97 +0,0 @@ -#include "op-attrs/parallel_tensor_shape.h" -#include "utils/containers.h" -#include "utils/hash-utils.h" - -namespace FlexFlow { - -int ParallelTensorShape::num_dims() const { - return dims.num_dims(); -} - -static std::vector lift_dims(TensorDims const &dims) { - std::vector lifted_dims; - for (size_t dim_size : dims) { - lifted_dims.push_back({dim_size, 1, false}); - } - lifted_dims.push_back({1, 1, true}); - return lifted_dims; -} - 
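The deleted `lift_dims` above modeled replication as a trailing bool-tagged `ParallelDim{1, 1, true}`; the `lift_to_parallel` added in tensor_dims.cc earlier in this patch instead pairs each dim with a shard degree and keeps replication in an explicit `ReplicaParallelDimSet`. A simplified sketch of the new shape of the data, using hypothetical stand-in types rather than the real API:

```cpp
#include <cstddef>
#include <vector>

struct ShardDim {
  std::size_t size;
  int degree;
};

struct LiftedDims {
  std::vector<ShardDim> shard_dims; // one per tensor dim
  int sum_degree;                   // replaces the old bool-tagged replica dim
  int discard_copy_degree;
};

// Lift unpartitioned dims: every shard degree starts at 1, no replication.
LiftedDims lift(std::vector<std::size_t> const &dims) {
  LiftedDims out{{}, /*sum_degree=*/1, /*discard_copy_degree=*/1};
  for (std::size_t size : dims) {
    out.shard_dims.push_back(ShardDim{size, /*degree=*/1});
  }
  return out;
}
```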
-ParallelTensorDims::ParallelTensorDims(TensorDims const &dims) - : data(lift_dims(dims)) {} - -ParallelTensorShape::ParallelTensorShape(TensorShape const &tensor_shape) - : dims(tensor_shape.dims), data_type(tensor_shape.data_type) {} - -int get_num_replica_dims(ParallelTensorShape const &shape) { - return count(shape.dims, is_replica_dim); -} - -int get_num_replicas(ParallelTensorShape const &shape) { - return product( - transform(filter(as_vector(shape.dims), is_replica_dim), - [](ParallelDim const &d) -> int { return d.degree; })); -} - -bool is_valid(ParallelTensorDims const &dims) { - return all_of(dims, [](ParallelDim const &d) { return is_valid(d); }); -} - -bool is_valid(ParallelTensorShape const &shape) { - return is_valid(shape.dims); -} - -ParallelTensorDims::iterator ParallelTensorDims::begin() { - return data.begin(); -} - -ParallelTensorDims::const_iterator ParallelTensorDims::begin() const { - return data.begin(); -} - -ParallelTensorDims::const_iterator ParallelTensorDims::cbegin() const { - return data.cbegin(); -} - -ParallelTensorDims::iterator ParallelTensorDims::end() { - return data.end(); -} - -ParallelTensorDims::const_iterator ParallelTensorDims::end() const { - return data.end(); -} - -ParallelTensorDims::const_iterator ParallelTensorDims::cend() const { - return data.cend(); -} - -ParallelDim const &ParallelTensorDims::at(ff_dim_t const &d) const { - return data.at(d); -} - -ParallelDim &ParallelTensorDims::at(ff_dim_t const &d) { - return data.at(d); -} - -size_t ParallelTensorDims::num_dims() const { - return data.size(); -} - -ParallelDim const &ParallelTensorShape::at(ff_dim_t const &d) const { - return dims.at(d); -} - -ParallelDim &ParallelTensorShape::at(ff_dim_t const &d) { - return dims.at(d); -} -ParallelDim const &ParallelTensorShape::operator[](ff_dim_t const &d) const { - return dims.at(d); -} -ParallelDim &ParallelTensorShape::operator[](ff_dim_t const &d) { - return dims.at(d); -} - -TensorShape get_tensor_shape_unsafe(ParallelTensorShape const &) { - NOT_IMPLEMENTED(); -} -} // namespace FlexFlow diff --git a/lib/op-attrs/src/reduce.cc b/lib/op-attrs/src/reduce.cc deleted file mode 100644 index 9d1770d5be..0000000000 --- a/lib/op-attrs/src/reduce.cc +++ /dev/null @@ -1,3 +0,0 @@ -#include "op-attrs/ops/reduce.h" - -namespace FlexFlow {} // namespace FlexFlow diff --git a/lib/op-attrs/src/reduction.cc b/lib/op-attrs/src/reduction.cc deleted file mode 100644 index 22fc9bab6a..0000000000 --- a/lib/op-attrs/src/reduction.cc +++ /dev/null @@ -1,13 +0,0 @@ -#include "op-attrs/ops/reduction.h" - -namespace FlexFlow { - -/* ParallelTensorShape ReductionAttrs::output_shape(ParallelTensorShape const - * &input_shape) const { */ -/* ParallelTensorShape output = input_shape; */ -/* output.at(this->reduction_legion_dim).degree /= this->reduction_degree; */ -/* output.at(this->reduction_legion_dim).size /= this->reduction_degree; */ -/* return output; */ -/* } */ - -} // namespace FlexFlow diff --git a/lib/op-attrs/src/repartition.cc b/lib/op-attrs/src/repartition.cc deleted file mode 100644 index 672e68b4f6..0000000000 --- a/lib/op-attrs/src/repartition.cc +++ /dev/null @@ -1,11 +0,0 @@ -#include "op-attrs/ops/repartition.h" - -namespace FlexFlow { - -/* bool RepartitionAttrs::is_valid(ParallelTensorShape const &input_shape) const - * { */ -/* ParallelDim dim = input_shape.at(this->repartition_legion_dim); */ -/* return (dim.size % this->repartition_degree * dim.degree == 0); */ -/* } */ - -} // namespace FlexFlow diff --git a/lib/op-attrs/src/replicate.cc 
b/lib/op-attrs/src/replicate.cc deleted file mode 100644 index 73ad288d8c..0000000000 --- a/lib/op-attrs/src/replicate.cc +++ /dev/null @@ -1,3 +0,0 @@ -#include "op-attrs/ops/replicate.h" - -namespace FlexFlow {} // namespace FlexFlow diff --git a/lib/op-attrs/src/reshape.cc b/lib/op-attrs/src/reshape.cc deleted file mode 100644 index e8349e1f26..0000000000 --- a/lib/op-attrs/src/reshape.cc +++ /dev/null @@ -1,3 +0,0 @@ -#include "op-attrs/ops/reshape.h" - -namespace FlexFlow {} // namespace FlexFlow diff --git a/lib/op-attrs/src/softmax.cc b/lib/op-attrs/src/softmax.cc deleted file mode 100644 index 9f95da4fb7..0000000000 --- a/lib/op-attrs/src/softmax.cc +++ /dev/null @@ -1,3 +0,0 @@ -#include "op-attrs/ops/softmax.h" - -namespace FlexFlow {} // namespace FlexFlow diff --git a/lib/op-attrs/src/split.cc b/lib/op-attrs/src/split.cc deleted file mode 100644 index acda8f3262..0000000000 --- a/lib/op-attrs/src/split.cc +++ /dev/null @@ -1,3 +0,0 @@ -#include "op-attrs/ops/split.h" - -namespace FlexFlow {} // namespace FlexFlow diff --git a/lib/op-attrs/src/tensor_shape.cc b/lib/op-attrs/src/tensor_shape.cc deleted file mode 100644 index e456b31e3c..0000000000 --- a/lib/op-attrs/src/tensor_shape.cc +++ /dev/null @@ -1,13 +0,0 @@ -#include "op-attrs/tensor_shape.h" - -namespace FlexFlow { - -size_t TensorShape::at(ff_dim_t d) const { - return dims.at(d); -} - -size_t TensorShape::operator[](ff_dim_t d) const { - return dims[d]; -} - -} // namespace FlexFlow diff --git a/lib/op-attrs/src/topk.cc b/lib/op-attrs/src/topk.cc deleted file mode 100644 index 9d701e4868..0000000000 --- a/lib/op-attrs/src/topk.cc +++ /dev/null @@ -1,3 +0,0 @@ -#include "op-attrs/ops/topk.h" - -namespace FlexFlow {} // namespace FlexFlow diff --git a/lib/op-attrs/src/transpose.cc b/lib/op-attrs/src/transpose.cc deleted file mode 100644 index ad4a84a3d5..0000000000 --- a/lib/op-attrs/src/transpose.cc +++ /dev/null @@ -1,3 +0,0 @@ -#include "op-attrs/ops/transpose.h" - -namespace FlexFlow {} // namespace FlexFlow diff --git a/lib/op-attrs/test/CMakeLists.txt b/lib/op-attrs/test/CMakeLists.txt new file mode 100644 index 0000000000..b6ff72fc00 --- /dev/null +++ b/lib/op-attrs/test/CMakeLists.txt @@ -0,0 +1,13 @@ +ff_add_test_executable( + NAME + op-attrs-tests + SRC_PATTERNS + src/*.cc + PRIVATE_INCLUDE + src/ + DEPS + utils + op-attrs + doctest + utils-test-common +) diff --git a/lib/op-attrs/test/src/dim_ordered/slice.cc b/lib/op-attrs/test/src/dim_ordered/slice.cc new file mode 100644 index 0000000000..8640b077dc --- /dev/null +++ b/lib/op-attrs/test/src/dim_ordered/slice.cc @@ -0,0 +1,23 @@ +#include "op-attrs/dim_ordered/slice.h" +#include "test/utils/doctest.h" + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE( + "slice(DimOrdered, std::optional, std::optional)") { + FFOrdered d = FFOrdered{ + 1, + 2, + 3, + 4, + }; + + FFOrdered result = slice(d, std::nullopt, ff_dim_t{-1}); + FFOrdered correct = FFOrdered{ + 1, + 2, + 3, + }; + + CHECK(result == correct); + } +} diff --git a/lib/op-attrs/test/src/ops/combine.cc b/lib/op-attrs/test/src/ops/combine.cc new file mode 100644 index 0000000000..a50b3b01de --- /dev/null +++ b/lib/op-attrs/test/src/ops/combine.cc @@ -0,0 +1,59 @@ +#include "op-attrs/ops/combine.h" +#include "test/utils/doctest.h" + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("Combine shape inference") { + + ParallelTensorShape input = { + ParallelTensorDims{ + FFOrdered{ + ShardParallelDim{12, 2}, + ShardParallelDim{14, 1}, + ShardParallelDim{16, 3}, + ShardParallelDim{18, 2}, + }, + ReplicaParallelDimSet{ + 
SumDegree{3}, + DiscardCopyDegree{2}, + }, + }, + DataType::FLOAT, + }; + + SUBCASE("valid") { + ff_dim_t dim = 2; + int degree = 3; + CombineAttrs attrs = CombineAttrs{ + /*repartition_dim=*/dim, + /*repartition_degree=*/degree, + }; + + tl::expected result = + get_output_shape(attrs, input); + + tl::expected correct = [&] { + ParallelTensorShape output = input; + output.dims.shard_dims.at(dim).degree /= degree; + return output; + }(); + + CHECK(result == correct); + } + + SUBCASE("invalid") { + ff_dim_t dim = 2; + int degree = 4; + CombineAttrs attrs = CombineAttrs{ + /*repartition_dim=*/dim, + /*repartition_degree=*/degree, + }; + + tl::expected result = + get_output_shape(attrs, input); + + CHECK_MESSAGE(!result.has_value(), + "Unexpected successful result: ", + result.error()); + } + } +} diff --git a/lib/op-attrs/test/src/ops/linear.cc b/lib/op-attrs/test/src/ops/linear.cc new file mode 100644 index 0000000000..0d23dc35df --- /dev/null +++ b/lib/op-attrs/test/src/ops/linear.cc @@ -0,0 +1,231 @@ +#include "op-attrs/ops/linear.h" +#include "op-attrs/parallel_tensor_shape.h" +#include "test/utils/doctest.h" +#include "utils/integer_conversions.h" + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("Linear shape inference") { + int out_channels = 16; + LinearAttrs attrs = LinearAttrs{ + /*out_channels=*/out_channels, + /*use_bias=*/true, + /*data_type=*/DataType::FLOAT, + /*activation=*/Activation::RELU, + /*regularizer=*/std::nullopt, + }; + + size_t batch_size = 12; + size_t extra_dim = 16; + size_t in_channels = 8; + + TensorShape input = TensorShape{ + TensorDims{ + FFOrdered{ + batch_size, + extra_dim, + in_channels, + }, + }, + DataType::FLOAT, + }; + + TensorShape output = TensorShape{ + TensorDims{ + FFOrdered{ + batch_size, + extra_dim, + size_t_from_int(out_channels), + }, + }, + DataType::FLOAT, + }; + + TensorShape kernel = TensorShape{ + TensorDims{ + FFOrdered{ + in_channels, + size_t_from_int(out_channels), + }, + }, + DataType::FLOAT, + }; + + TensorShape bias = TensorShape{ + TensorDims{ + FFOrdered{ + size_t_from_int(out_channels), + }, + }, + DataType::FLOAT, + }; + + // get_output_shape + { + tl::expected output_result = + get_output_shape(attrs, input); + tl::expected output_correct = output; + CHECK(output_result == output_correct); + } + + // get_weight_shape + { + tl::expected kernel_result = + get_kernel_shape(attrs, input); + tl::expected kernel_correct = kernel; + CHECK(kernel_result == kernel_correct); + } + + // get_bias_shape + { + tl::expected bias_result = + get_bias_shape(attrs, input); + tl::expected bias_correct = bias; + CHECK(bias_result == bias_correct); + } + + auto make_input = [&](SumDegree o_sum, + DiscardCopyDegree o_eq, + int o_batch, + int o_extra_dim, + int o_channel) { + return lift_to_parallel_with_degrees( + input, o_sum, o_eq, FFOrdered{o_batch, o_extra_dim, o_channel}); + }; + + auto make_output = [&](SumDegree o_sum, + DiscardCopyDegree o_eq, + int o_batch, + int o_extra_dim, + int o_channel) { + return lift_to_parallel_with_degrees( + output, o_sum, o_eq, FFOrdered{o_batch, o_extra_dim, o_channel}); + }; + + auto make_kernel = [&](SumDegree o_sum, + DiscardCopyDegree o_eq, + int o_inchannel, + int o_outchannel) { + return lift_to_parallel_with_degrees( + kernel, o_sum, o_eq, FFOrdered{o_inchannel, o_outchannel}); + }; + + auto make_bias = + [&](SumDegree o_sum, DiscardCopyDegree o_eq, int o_outchannel) { + return lift_to_parallel_with_degrees( + bias, o_sum, o_eq, FFOrdered{o_outchannel}); + }; + + SUBCASE("data parallelism") { + int 
input_sum_degree = 2; + int extra_dim_degree = 8; + int degree = 4; + + ParallelTensorShape par_input = make_input(SumDegree{input_sum_degree}, + DiscardCopyDegree{1}, + degree, + extra_dim_degree, + 1); + + { + tl::expected result = + get_output_shape(attrs, par_input); + tl::expected correct = + make_output(SumDegree{input_sum_degree}, + DiscardCopyDegree{1}, + degree, + extra_dim_degree, + 1); + CHECK(result == correct); + } + + { + tl::expected result = + get_kernel_shape(attrs, par_input); + tl::expected correct = make_kernel( + SumDegree{1}, + DiscardCopyDegree{input_sum_degree * degree * extra_dim_degree}, + 1, + 1); + CHECK(result == correct); + } + + { + tl::expected result = + get_bias_shape(attrs, par_input); + tl::expected correct = + make_bias(SumDegree{input_sum_degree}, + DiscardCopyDegree{degree * extra_dim_degree}, + 1); + CHECK(result == correct); + } + } + + SUBCASE("reduction parallelism") { + int input_sum_degree = 2; + int degree = 4; + + ParallelTensorShape par_input = make_input( + SumDegree{input_sum_degree}, DiscardCopyDegree{1}, 1, 1, degree); + + { + tl::expected result = + get_output_shape(attrs, par_input); + tl::expected correct = + make_output(SumDegree{input_sum_degree * degree}, + DiscardCopyDegree{1}, + 1, + 1, + 1); + CHECK(result == correct); + } + + { + tl::expected result = + get_kernel_shape(attrs, par_input); + tl::expected correct = make_kernel( + SumDegree{1}, DiscardCopyDegree{input_sum_degree}, degree, 1); + CHECK(result == correct); + } + + { + tl::expected result = + get_bias_shape(attrs, par_input); + tl::expected correct = make_bias( + SumDegree{input_sum_degree * degree}, DiscardCopyDegree{1}, 1); + CHECK(result == correct); + } + } + + SUBCASE("output channel parallelism") { + int input_sum_degree = 2; + int degree = 4; + + ParallelTensorShape par_input = make_input( + SumDegree{input_sum_degree}, DiscardCopyDegree{degree}, 1, 1, 1); + + { + tl::expected result = + get_output_shape(attrs, par_input); + tl::expected correct = make_output( + SumDegree{input_sum_degree}, DiscardCopyDegree{1}, 1, 1, degree); + CHECK(result == correct); + } + + { + tl::expected result = + get_kernel_shape(attrs, par_input); + tl::expected correct = make_kernel( + SumDegree{1}, DiscardCopyDegree{input_sum_degree}, 1, degree); + CHECK(result == correct); + } + + { + tl::expected result = + get_bias_shape(attrs, par_input); + tl::expected correct = make_bias( + SumDegree{input_sum_degree}, DiscardCopyDegree{1}, degree); + CHECK(result == correct); + } + } + } +} diff --git a/lib/op-attrs/test/src/ops/reduction.cc b/lib/op-attrs/test/src/ops/reduction.cc new file mode 100644 index 0000000000..6f73951e00 --- /dev/null +++ b/lib/op-attrs/test/src/ops/reduction.cc @@ -0,0 +1,55 @@ +#include "op-attrs/ops/reduction.h" +#include "test/utils/doctest.h" + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("Reduction shape inference") { + + ParallelTensorShape input = { + ParallelTensorDims{ + FFOrdered{ + ShardParallelDim{12, 2}, + ShardParallelDim{14, 1}, + ShardParallelDim{16, 3}, + ShardParallelDim{18, 2}, + }, + ReplicaParallelDimSet{ + SumDegree{3}, + DiscardCopyDegree{2}, + }, + }, + DataType::FLOAT, + }; + + SUBCASE("valid") { + int degree = 3; + ReductionAttrs attrs = ReductionAttrs{ + /*repartition_degree=*/degree, + }; + + tl::expected result = + get_output_shape(attrs, input); + + tl::expected correct = [&] { + ParallelTensorShape output = input; + output.dims.replica_dims.sum_degree.value /= degree; + return output; + }(); + + CHECK(result == correct); + } + + 
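The valid subcase above encodes the Reduction shape rule: the output keeps the input's dims and divides the sum degree by the reduction degree, so the reduction degree must divide `SumDegree` exactly (here 3 / 3; the next subcase shows 4 failing against a sum degree of 3). A standalone distillation of that check, as a hypothetical helper rather than the library API:

```cpp
#include <optional>

// Reduction is only well-formed when the requested degree evenly divides
// the input's sum degree; std::nullopt mirrors the tl::expected error path.
std::optional<int> reduced_sum_degree(int sum_degree, int reduction_degree) {
  if (reduction_degree <= 0 || sum_degree % reduction_degree != 0) {
    return std::nullopt;
  }
  return sum_degree / reduction_degree;
}
```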
SUBCASE("invalid") { + int degree = 4; + ReductionAttrs attrs = ReductionAttrs{ + /*repartition_degree=*/degree, + }; + + tl::expected result = + get_output_shape(attrs, input); + + CHECK_MESSAGE(!result.has_value(), + "Unexpected successful result: ", + result.error()); + } + } +} diff --git a/lib/op-attrs/test/src/ops/repartition.cc b/lib/op-attrs/test/src/ops/repartition.cc new file mode 100644 index 0000000000..3b3ae92b4c --- /dev/null +++ b/lib/op-attrs/test/src/ops/repartition.cc @@ -0,0 +1,40 @@ +#include "op-attrs/ops/repartition.h" +#include "test/utils/doctest.h" + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("Repartition shape inference") { + ff_dim_t dim = 2; + int degree = 4; + RepartitionAttrs attrs = RepartitionAttrs{ + /*repartition_dim=*/dim, + /*repartition_degree=*/degree, + }; + + ParallelTensorShape input = { + ParallelTensorDims{ + FFOrdered{ + ShardParallelDim{12, 2}, + ShardParallelDim{14, 1}, + ShardParallelDim{16, 3}, + ShardParallelDim{18, 2}, + }, + ReplicaParallelDimSet{ + SumDegree{3}, + DiscardCopyDegree{2}, + }, + }, + DataType::FLOAT, + }; + + tl::expected result = + get_output_shape(attrs, input); + + tl::expected correct = [&] { + ParallelTensorShape output = input; + output.dims.shard_dims.at(dim).degree *= degree; + return output; + }(); + + CHECK(result == correct); + } +} diff --git a/lib/op-attrs/test/src/ops/replicate.cc b/lib/op-attrs/test/src/ops/replicate.cc new file mode 100644 index 0000000000..b326038388 --- /dev/null +++ b/lib/op-attrs/test/src/ops/replicate.cc @@ -0,0 +1,33 @@ +#include "op-attrs/ops/replicate.h" +#include "test/utils/doctest.h" + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("Replicate shape inference") { + ReplicateAttrs attrs = ReplicateAttrs{ + /*replicate_degree=*/4, + }; + + ParallelTensorShape input = { + ParallelTensorDims{ + FFOrdered{ + ShardParallelDim{10, 2}, + ShardParallelDim{12, 1}, + ShardParallelDim{14, 2}, + ShardParallelDim{16, 2}, + }, + ReplicaParallelDimSet{ + SumDegree{3}, + DiscardCopyDegree{2}, + }, + }, + DataType::FLOAT, + }; + + ParallelTensorShape result = get_output_shape(attrs, input); + + ParallelTensorShape correct_output = input; + correct_output.dims.replica_dims.discard_copy_degree = 8; + + CHECK(result == correct_output); + } +} diff --git a/lib/op-attrs/test/src/test_attention.cc b/lib/op-attrs/test/src/test_attention.cc new file mode 100644 index 0000000000..74ae4565ca --- /dev/null +++ b/lib/op-attrs/test/src/test_attention.cc @@ -0,0 +1,272 @@ +#include "op-attrs/ops/attention.h" +#include "op-attrs/parallel_tensor_shape.h" +#include "test/utils/doctest.h" +#include "utils/integer_conversions.h" + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("get_output_shape(MultiHeadAttentionAttrs, TensorShape, " + "TensorShape, TensorShape)") { + int embed_dim = 32; + + /* Parameter meanings match those at + * https://pytorch.org/docs/stable/generated/torch.nn.MultiheadAttention.html + */ + MultiHeadAttentionAttrs attrs = { + /*embed_dim=*/embed_dim, + /*num_heads=*/10, + /*kdim=*/embed_dim, + /*vdim=*/embed_dim, + /*dropout=*/0.0, + /*bias=*/true, + /*add_bias_kv=*/false, + /*add_zero_attn=*/false, + }; + + size_t batch_size = 40; + size_t seq_len = 48; + + TensorShape input_q = { + TensorDims{FFOrdered{ + batch_size, + seq_len, + size_t_from_int(attrs.embed_dim), + }}, + DataType::FLOAT, + }; + + TensorShape input_k = { + TensorDims{ + FFOrdered{ + batch_size, + seq_len, + size_t_from_int(attrs.kdim), + }, + }, + DataType::FLOAT, + }; + + TensorShape input_v = { + TensorDims{ + FFOrdered{ + batch_size, + 
+            seq_len,
+            size_t_from_int(attrs.vdim),
+        },
+        },
+        DataType::FLOAT,
+    };
+
+    SUBCASE("get_output_shape") {
+      tl::expected<TensorShape, std::string> result =
+          get_output_shape(attrs, input_q, input_k, input_v);
+
+      tl::expected<TensorShape, std::string> correct = TensorShape{
+          TensorDims{FFOrdered<size_t>{
+              batch_size,
+              seq_len,
+              size_t_from_int(attrs.embed_dim),
+          }},
+          DataType::FLOAT,
+      };
+
+      CHECK(result == correct);
+    }
+
+    SUBCASE("get_weights_shape") {
+      tl::expected<TensorShape, std::string> result =
+          get_weights_shape(attrs, input_q, input_k, input_v);
+
+      int qProjPerHeadWeightSize =
+          attrs.kdim * dim_at_idx(input_q, ff_dim_t{-1});
+      int kProjPerHeadWeightSize =
+          attrs.kdim * dim_at_idx(input_k, ff_dim_t{-1});
+      int vProjPerHeadWeightSize =
+          attrs.vdim * dim_at_idx(input_v, ff_dim_t{-1});
+      int oProjPerHeadWeightSize = attrs.embed_dim * attrs.vdim;
+      int perHeadWeightSize = qProjPerHeadWeightSize + kProjPerHeadWeightSize +
+                              vProjPerHeadWeightSize + oProjPerHeadWeightSize;
+
+      tl::expected<TensorShape, std::string> correct = TensorShape{
+          TensorDims{FFOrdered<size_t>{
+              size_t_from_int(perHeadWeightSize),
+              size_t_from_int(attrs.num_heads),
+          }},
+          DataType::FLOAT,
+      };
+
+      CHECK(result == correct);
+    }
+  }
+
+  TEST_CASE("parallel shape inference for MultiHeadAttentionAttrs") {
+    int embed_dim = 32;
+
+    /* Parameter meanings can be found at
+     * https://pytorch.org/docs/stable/generated/torch.nn.MultiheadAttention.html
+     */
+    MultiHeadAttentionAttrs attrs = {
+        /*embed_dim=*/embed_dim,
+        /*num_heads=*/10,
+        /*kdim=*/embed_dim,
+        /*vdim=*/embed_dim,
+        /*dropout=*/0.0,
+        /*bias=*/true,
+        /*add_bias_kv=*/false,
+        /*add_zero_attn=*/false,
+    };
+
+    size_t batchsize = 40;
+    size_t seq_len = 48;
+    size_t q_size = 56;
+    size_t k_size = 64;
+    size_t v_size = 72;
+
+    TensorShape unpar_q_shape = TensorShape{
+        TensorDims{
+            FFOrdered<size_t>{
+                batchsize,
+                seq_len,
+                q_size,
+            },
+        },
+        DataType::FLOAT,
+    };
+
+    TensorShape unpar_k_shape = TensorShape{
+        TensorDims{
+            FFOrdered<size_t>{
+                batchsize,
+                seq_len,
+                k_size,
+            },
+        },
+        DataType::FLOAT,
+    };
+
+    TensorShape unpar_v_shape = TensorShape{
+        TensorDims{
+            FFOrdered<size_t>{
+                batchsize,
+                seq_len,
+                v_size,
+            },
+        },
+        DataType::FLOAT,
+    };
+
+    tl::expected<TensorShape, std::string> result_unpar_o_shape =
+        get_output_shape(attrs, unpar_q_shape, unpar_k_shape, unpar_v_shape);
+    REQUIRE(result_unpar_o_shape.has_value());
+    TensorShape unpar_o_shape = result_unpar_o_shape.value();
+
+    tl::expected<TensorShape, std::string> result_unpar_w_shape =
+        get_weights_shape(attrs, unpar_q_shape, unpar_k_shape, unpar_v_shape);
+    REQUIRE(result_unpar_w_shape.has_value());
+    TensorShape unpar_w_shape = result_unpar_w_shape.value();
+
+    auto make_q = [&](SumDegree o_sum,
+                      DiscardCopyDegree o_eq,
+                      int o_batch,
+                      int o_seq_len,
+                      int o_q) {
+      return lift_to_parallel_with_degrees(
+          unpar_q_shape, o_sum, o_eq, FFOrdered<int>{o_batch, o_seq_len, o_q});
+    };
+
+    auto make_k = [&](SumDegree o_sum,
+                      DiscardCopyDegree o_eq,
+                      int o_batch,
+                      int o_seq_len,
+                      int o_k) {
+      return lift_to_parallel_with_degrees(
+          unpar_k_shape, o_sum, o_eq, FFOrdered<int>{o_batch, o_seq_len, o_k});
+    };
+
+    auto make_v = [&](SumDegree o_sum,
+                      DiscardCopyDegree o_eq,
+                      int o_batch,
+                      int o_seq_len,
+                      int o_v) {
+      return lift_to_parallel_with_degrees(
+          unpar_v_shape, o_sum, o_eq, FFOrdered<int>{o_batch, o_seq_len, o_v});
+    };
+
+    auto make_o = [&](SumDegree o_sum,
+                      DiscardCopyDegree o_eq,
+                      int o_batch,
+                      int o_seq_len,
+                      int o_o) {
+      return lift_to_parallel_with_degrees(
+          unpar_o_shape, o_sum, o_eq, FFOrdered<int>{o_batch, o_seq_len, o_o});
+    };
+
+    auto make_w = [&](SumDegree o_sum, DiscardCopyDegree o_eq, int o_e, int o_h) {
+      return lift_to_parallel_with_degrees(
+          unpar_w_shape, o_sum, o_eq, FFOrdered<int>{o_e, o_h});
+    };
+
+    SUBCASE("data 
parallelism") { + int o_b = 4; + ParallelTensorShape q = make_q(1, 1, o_b, 1, 1); + ParallelTensorShape k = make_k(1, 1, o_b, 1, 1); + ParallelTensorShape v = make_v(1, 1, o_b, 1, 1); + + tl::expected result_o = + get_output_shape(attrs, q, k, v); + tl::expected correct_o = + make_o(1, 1, o_b, 1, 1); + + CHECK(result_o == correct_o); + + tl::expected result_w = + get_weights_shape(attrs, q, k, v); + tl::expected correct_w = + make_w(1, o_b, 1, 1); + + CHECK(result_w == correct_w); + } + + SUBCASE("attention head parallelism") { + int o_h = 2; + ParallelTensorShape q = make_q(1, o_h, 1, 1, 1); + ParallelTensorShape k = make_k(1, o_h, 1, 1, 1); + ParallelTensorShape v = make_v(1, o_h, 1, 1, 1); + + tl::expected result_o = + get_output_shape(attrs, q, k, v); + tl::expected correct_o = + make_o(o_h, 1, 1, 1, 1); + + CHECK(result_o == correct_o); + + tl::expected result_w = + get_weights_shape(attrs, q, k, v); + tl::expected correct_w = + make_w(1, 1, 1, o_h); + + CHECK(result_w == correct_w); + } + + SUBCASE("combined data & attention head parallelism") { + int o_b = 4; + int o_h = 2; + ParallelTensorShape q = make_q(1, o_h, o_b, 1, 1); + ParallelTensorShape k = make_k(1, o_h, o_b, 1, 1); + ParallelTensorShape v = make_v(1, o_h, o_b, 1, 1); + + tl::expected result_o = + get_output_shape(attrs, q, k, v); + tl::expected correct_o = + make_o(o_h, 1, o_b, 1, 1); + + CHECK(result_o == correct_o); + + tl::expected result_w = + get_weights_shape(attrs, q, k, v); + tl::expected correct_w = + make_w(1, o_b, 1, o_h); + + CHECK(result_w == correct_w); + } + } +} diff --git a/lib/op-attrs/test/src/test_batch_matmul.cc b/lib/op-attrs/test/src/test_batch_matmul.cc new file mode 100644 index 0000000000..f48478be10 --- /dev/null +++ b/lib/op-attrs/test/src/test_batch_matmul.cc @@ -0,0 +1,268 @@ +#include "op-attrs/ops/batch_matmul.h" +#include "test/utils/doctest.h" + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("get_output_shape(BatchMatmulAttrs, TensorShape)") { + size_t b = 4; + size_t m = 6; + size_t n = 8; + size_t p = 10; + + BatchMatmulAttrs attrs = { + /*a_seq_length_dim=*/0, // TODO figure out if these arguments are still + // relevant + /*b_seq_length_dim=*/0, + }; + + TensorShape input_lhs_shape = { + TensorDims{ + FFOrdered{ + b, + n, + m, + }, + }, + DataType::FLOAT, + }; + + SUBCASE("valid") { + TensorShape input_rhs_shape = { + TensorDims{ + FFOrdered{ + b, + m, + p, + }, + }, + DataType::FLOAT, + }; + + tl::expected result = + get_output_shape(attrs, input_lhs_shape, input_rhs_shape); + + tl::expected correct_output_shape = TensorShape{ + TensorDims{ + FFOrdered{ + b, + n, + p, + }, + }, + DataType::FLOAT, + }; + + CHECK(result == correct_output_shape); + } + + SUBCASE("mismatched b") { + TensorShape input_rhs_shape = { + TensorDims{ + FFOrdered{ + b + 1, + m, + p, + }, + }, + DataType::FLOAT, + }; + + tl::expected result = + get_output_shape(attrs, input_lhs_shape, input_rhs_shape); + + CHECK(!result.has_value()); + } + + SUBCASE("mismatched m") { + TensorShape input_rhs_shape = { + TensorDims{ + FFOrdered{ + b, + m + 1, + p, + }, + }, + DataType::FLOAT, + }; + + tl::expected result = + get_output_shape(attrs, input_lhs_shape, input_rhs_shape); + + CHECK(!result.has_value()); + } + } + + TEST_CASE("get_output_shape(BatchMatmulAttrs, ParallelTensorShape)") { + size_t b = 2 * 2; + int o_b = 2; + size_t m = 3 * 3; + int o_m = 3; + size_t n = 5 * 5; + int o_n = 5; + size_t p = 7 * 7; + int o_p = 7; + int o_sum = 11; + + BatchMatmulAttrs attrs = { + /*a_seq_length_dim=*/0, // TODO figure out if these 
arguments are still + // relevant + /*b_seq_length_dim=*/0, + }; + + auto make_lhs = [&](int o_sum, int o_eq, int o_b, int o_n, int o_m) { + return ParallelTensorShape{ + ParallelTensorDims{ + FFOrdered{ + ShardParallelDim{b, o_b}, + ShardParallelDim{n, o_n}, + ShardParallelDim{m, o_m}, + }, + ReplicaParallelDimSet{ + o_sum, + o_eq, + }, + }, + DataType::FLOAT, + }; + }; + + auto make_rhs = [&](int o_sum, int o_eq, int o_b, int o_m, int o_p) { + return ParallelTensorShape{ + ParallelTensorDims{ + FFOrdered{ + ShardParallelDim{b, o_b}, + ShardParallelDim{m, o_m}, + ShardParallelDim{p, o_p}, + }, + ReplicaParallelDimSet{ + o_sum, + o_eq, + }, + }, + DataType::FLOAT, + }; + }; + + auto make_output = [&](int o_sum, int o_eq, int o_b, int o_n, int o_p) { + return ParallelTensorShape{ + ParallelTensorDims{ + FFOrdered{ + ShardParallelDim{b, o_b}, + ShardParallelDim{n, o_n}, + ShardParallelDim{p, o_p}, + }, + ReplicaParallelDimSet{ + o_sum, + o_eq, + }, + }, + DataType::FLOAT, + }; + }; + + SUBCASE("data parallel") { + tl::expected result = get_output_shape( + attrs, make_lhs(1, 1, o_b, 1, 1), make_rhs(1, 1, o_b, 1, 1)); + tl::expected correct = + make_output(1, 1, o_b, 1, 1); + + CHECK(result == correct); + } + + SUBCASE("n parallel") { + tl::expected result = get_output_shape( + attrs, make_lhs(1, 1, 1, o_n, 1), make_rhs(1, o_n, 1, 1, 1)); + tl::expected correct = + make_output(1, 1, 1, o_n, 1); + + CHECK(result == correct); + } + + SUBCASE("p parallel") { + tl::expected result = get_output_shape( + attrs, make_lhs(1, o_p, 1, 1, 1), make_rhs(1, 1, 1, 1, o_p)); + tl::expected correct = + make_output(1, 1, 1, 1, o_p); + + CHECK(result == correct); + } + + SUBCASE("reduction parallel") { + tl::expected result = get_output_shape( + attrs, make_lhs(1, 1, 1, 1, o_m), make_rhs(1, 1, 1, o_m, 1)); + tl::expected correct = + make_output(o_m, 1, 1, 1, 1); + + CHECK(result == correct); + } + + SUBCASE("propagate reduction lhs") { + tl::expected result = get_output_shape( + attrs, make_lhs(o_sum, 1, 1, 1, 1), make_rhs(1, o_sum, 1, 1, 1)); + tl::expected correct = + make_output(o_sum, 1, 1, 1, 1); + + CHECK(result == correct); + } + + SUBCASE("propagate reduction rhs") { + tl::expected result = get_output_shape( + attrs, make_lhs(1, o_sum, 1, 1, 1), make_rhs(o_sum, 1, 1, 1, 1)); + tl::expected correct = + make_output(o_sum, 1, 1, 1, 1); + + CHECK(result == correct); + } + + SUBCASE("reduction lhs & reduction rhs") { + tl::expected result = + get_output_shape(attrs, + make_lhs(o_sum, o_sum, 1, 1, 1), + make_rhs(o_sum, o_sum, 1, 1, 1)); + tl::expected correct = + make_output(o_sum * o_sum, 1, 1, 1, 1); + + CHECK(result == correct); + } + + SUBCASE("reduction lhs & rhs (invalid)") { + tl::expected result = get_output_shape( + attrs, make_lhs(o_sum, 1, 1, 1, 1), make_rhs(o_sum, 1, 1, 1, 1)); + + CHECK_MESSAGE( + !result.has_value(), "Unexpected successful value: ", result); + } + + SUBCASE("reduction lhs & n") { + tl::expected result = + get_output_shape(attrs, + make_lhs(o_sum, 1, 1, o_n, 1), + make_rhs(1, o_sum * o_n, 1, 1, 1)); + tl::expected correct = + make_output(o_sum, 1, 1, o_n, 1); + + CHECK(result == correct); + } + + SUBCASE("reduction lhs & reduction rhs & n") { + tl::expected result = + get_output_shape(attrs, + make_lhs(o_sum, o_sum, 1, o_n, 1), + make_rhs(o_sum, o_sum * o_n, 1, 1, 1)); + tl::expected correct = + make_output(o_sum * o_sum, 1, 1, o_n, 1); + + CHECK(result == correct); + } + + SUBCASE("reduction lhs & reduction rhs & n & m") { + tl::expected result = + get_output_shape(attrs, + 
make_lhs(o_sum, o_sum, 1, o_n, o_m), + make_rhs(o_sum, o_sum * o_n, 1, o_m, 1)); + tl::expected correct = + make_output(o_sum * o_sum * o_m, 1, 1, o_n, 1); + + CHECK(result == correct); + } + } +} diff --git a/lib/op-attrs/test/src/test_conv_2d.cc b/lib/op-attrs/test/src/test_conv_2d.cc new file mode 100644 index 0000000000..b16a26a7b1 --- /dev/null +++ b/lib/op-attrs/test/src/test_conv_2d.cc @@ -0,0 +1,62 @@ +#include "doctest/doctest.h" +#include "op-attrs/ops/conv_2d.h" + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("get_output_shape(Conv2DAttrs, TensorShape)") { + int out_channels = 4; + int kernel_h = 3; + int kernel_w = 2; + int stride_h = 2; + int stride_w = 2; + int padding_h = 1; + int padding_w = 1; + int groups = 1; + std::optional activation = std::nullopt; + bool use_bias = true; + + Conv2DAttrs attrs = { + /*out_channels=*/out_channels, + /*kernel_h=*/kernel_h, + /*kernel_w=*/kernel_w, + /*stride_h=*/stride_h, + /*stride_w=*/stride_w, + /*padding_h=*/padding_h, + /*padding_w=*/padding_w, + /*groups=*/groups, + /*activation=*/activation, + /*use_bias=*/true, + }; + + size_t num_samples = 7; + size_t input_channels = 6; + size_t input_height = 10; + size_t input_width = 15; + + TensorShape input_shape = { + TensorDims{FFOrdered{ + num_samples, + input_channels, + input_height, + input_width, + }}, + DataType::FLOAT, + }; + + TensorShape result = get_output_shape(attrs, input_shape); + + size_t correct_output_height = 3; + size_t correct_output_width = 6; + + TensorShape correct_output_shape = { + TensorDims{FFOrdered{ + num_samples, + static_cast(out_channels), + correct_output_height, + correct_output_width, + }}, + DataType::FLOAT, + }; + + CHECK(result == correct_output_shape); + } +} diff --git a/lib/op-attrs/test/src/test_dim_ordered.cc b/lib/op-attrs/test/src/test_dim_ordered.cc new file mode 100644 index 0000000000..17f4bae05f --- /dev/null +++ b/lib/op-attrs/test/src/test_dim_ordered.cc @@ -0,0 +1,13 @@ +#include "doctest/doctest.h" +#include "op-attrs/dim_ordered.h" +#include + +using namespace FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + + TEST_CASE_TEMPLATE("RC", T, int, double, char) { + CHECK(rc::check("generate", + [](FFOrdered ff_dim, DimOrdered dim) {})); + } +} diff --git a/lib/op-attrs/test/src/test_element_binary.cc b/lib/op-attrs/test/src/test_element_binary.cc new file mode 100644 index 0000000000..b1aedbf6b5 --- /dev/null +++ b/lib/op-attrs/test/src/test_element_binary.cc @@ -0,0 +1,162 @@ +#include "op-attrs/ops/element_binary.h" +#include "op-attrs/parallel_tensor_shape.h" +#include "test/utils/doctest.h" + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("EWAdd shape inference") { + size_t d1 = 16; + size_t d2 = 32; + size_t d3 = 24; + + ElementBinaryAttrs attrs = ElementBinaryAttrs{ + OperatorType::EW_ADD, + DataType::FLOAT, + /*should_broadcast_lhs=*/false, + /*should_broadcast_rhs=*/false, + }; + + TensorShape input_lhs = TensorShape{ + TensorDims{ + FFOrdered{ + d1, + d2, + d3, + }, + }, + DataType::FLOAT, + }; + + TensorShape input_rhs = input_lhs; + + SUBCASE("correct") { + tl::expected result = + get_output_shape(attrs, input_lhs, input_rhs); + tl::expected correct = input_lhs; + + CHECK(result == correct); + } + + SUBCASE("mismatched dim size") { + TensorShape incorrect_rhs = input_lhs; + dim_at_idx(incorrect_rhs, ff_dim_t{0}) += 1; + + tl::expected result = + get_output_shape(attrs, input_lhs, incorrect_rhs); + + CHECK_MESSAGE(!result.has_value(), + "Unexpected successful result: ", + result.error()); + } + } + + TEST_CASE("EWAdd parallel shape inference") { 
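+    // Expected behavior, sketched from the subcases below rather than from
+    // any spec: matching shard or sum degrees on lhs and rhs propagate
+    // unchanged to the output, while discard-copy parallelism and
+    // mismatched degrees are rejected for EW_ADD.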
+ size_t d1 = 16; + size_t d2 = 32; + size_t d3 = 24; + + ElementBinaryAttrs attrs = ElementBinaryAttrs{ + OperatorType::EW_ADD, + DataType::FLOAT, + /*should_broadcast_lhs=*/false, + /*should_broadcast_rhs=*/false, + }; + + TensorShape unpar_lhs = TensorShape{ + TensorDims{ + FFOrdered{ + d1, + d2, + d3, + }, + }, + DataType::FLOAT, + }; + + TensorShape unpar_rhs = unpar_lhs; + tl::expected result_unpar_output = + get_output_shape(attrs, unpar_lhs, unpar_rhs); + REQUIRE(result_unpar_output.has_value()); + TensorShape unpar_output = result_unpar_output.value(); + + auto make_lhs = [&](SumDegree o_sum, + DiscardCopyDegree o_eq, + int o_1, + int o_2, + int o_3) { + return lift_to_parallel_with_degrees( + unpar_lhs, o_sum, o_eq, FFOrdered{o_1, o_2, o_3}); + }; + + auto make_rhs = [&](SumDegree o_sum, + DiscardCopyDegree o_eq, + int o_1, + int o_2, + int o_3) { + return lift_to_parallel_with_degrees( + unpar_rhs, o_sum, o_eq, FFOrdered{o_1, o_2, o_3}); + }; + + auto make_output = [&](SumDegree o_sum, + DiscardCopyDegree o_eq, + int o_1, + int o_2, + int o_3) { + return lift_to_parallel_with_degrees( + unpar_output, o_sum, o_eq, FFOrdered{o_1, o_2, o_3}); + }; + + SUBCASE("data parallelism") { + int degree = 4; + + ParallelTensorShape input_lhs = make_lhs(1, 1, degree, 1, 1); + ParallelTensorShape input_rhs = make_rhs(1, 1, degree, 1, 1); + tl::expected result = + get_output_shape(attrs, input_lhs, input_rhs); + tl::expected correct = + make_output(1, 1, degree, 1, 1); + + CHECK(result == correct); + } + + SUBCASE("reduction parallelism") { + int degree = 4; + + ParallelTensorShape input_lhs = make_lhs(SumDegree{degree}, 1, 1, 1, 1); + ParallelTensorShape input_rhs = make_rhs(SumDegree{degree}, 1, 1, 1, 1); + tl::expected result = + get_output_shape(attrs, input_lhs, input_rhs); + tl::expected correct = + make_output(SumDegree{degree}, 1, 1, 1, 1); + + CHECK(result == correct); + } + + SUBCASE("invalid discard copy parallelism") { + int degree = 4; + + ParallelTensorShape input_lhs = + make_lhs(1, DiscardCopyDegree{degree}, 1, 1, 1); + ParallelTensorShape input_rhs = + make_rhs(1, DiscardCopyDegree{degree}, 1, 1, 1); + tl::expected result = + get_output_shape(attrs, input_lhs, input_rhs); + + CHECK_MESSAGE(!result.has_value(), + "Unexpected successful result: ", + result.error()); + } + + SUBCASE("invalid mismatched parallelism degrees") { + int degree = 4; + + ParallelTensorShape input_lhs = make_lhs(1, 1, 1, degree, 1); + ParallelTensorShape input_rhs = make_rhs(1, 1, 1, 1, degree); + tl::expected result = + get_output_shape(attrs, input_lhs, input_rhs); + + CHECK_MESSAGE(!result.has_value(), + "Unexpected successful result: ", + result.error()); + } + } +} diff --git a/lib/op-attrs/test/src/test_element_unary.cc b/lib/op-attrs/test/src/test_element_unary.cc new file mode 100644 index 0000000000..384dbc1a53 --- /dev/null +++ b/lib/op-attrs/test/src/test_element_unary.cc @@ -0,0 +1,73 @@ +#include "op-attrs/ops/element_unary.h" +#include "op-attrs/parallel_tensor_shape.h" +#include "test/utils/doctest.h" + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("ReLU shape inference") { + size_t d1 = 16; + size_t d2 = 32; + size_t d3 = 24; + + ElementUnaryAttrs attrs = ElementUnaryAttrs{OperatorType::RELU}; + + TensorShape input = TensorShape{ + TensorDims{ + FFOrdered{ + d1, + d2, + d3, + }, + }, + DataType::FLOAT, + }; + + tl::expected result = + get_output_shape(attrs, input); + tl::expected correct = input; + + CHECK(result == correct); + + auto make_i = [&](SumDegree o_sum, + DiscardCopyDegree o_eq, + 
int o_1, + int o_2, + int o_3) { + return lift_to_parallel_with_degrees( + input, o_sum, o_eq, FFOrdered{o_1, o_2, o_3}); + }; + + SUBCASE("partition i.e., sharding parallelism") { + int degree1 = 4; + int degree2 = 8; + ParallelTensorShape par_input = make_i(1, 1, degree1, 1, degree2); + + tl::expected result = + get_output_shape(attrs, par_input); + tl::expected correct = par_input; + + CHECK(result == correct); + } + + SUBCASE("sum degree > 1") { + int degree = 2; + + tl::expected result = + get_output_shape(attrs, make_i(SumDegree{degree}, 1, 1, 1, 1)); + + CHECK_MESSAGE(!result.has_value(), + "Unexpected successful result: ", + result.error()); + } + + SUBCASE("discard copy degree > 1") { + int degree = 2; + + tl::expected result = get_output_shape( + attrs, make_i(1, DiscardCopyDegree{degree}, 1, 1, 1)); + + CHECK_MESSAGE(!result.has_value(), + "Unexpected successful result: ", + result.error()); + } + } +} diff --git a/lib/op-attrs/test/src/test_embedding.cc b/lib/op-attrs/test/src/test_embedding.cc new file mode 100644 index 0000000000..7bce6bd4d9 --- /dev/null +++ b/lib/op-attrs/test/src/test_embedding.cc @@ -0,0 +1,160 @@ +#include "op-attrs/ops/embedding.h" +#include "op-attrs/parallel_tensor_shape.h" +#include "test/utils/doctest.h" +#include "utils/integer_conversions.h" + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("Sum embedding shape inference") { + int out_channels = 128; + int num_entries = 1024; + EmbeddingAttrs attrs = EmbeddingAttrs{ + /*num_entries=*/num_entries, + /*out_channels=*/out_channels, + /*aggr=*/AggregateOp::SUM, + /*data_type=*/DataType::FLOAT, + }; + + size_t batch_size = 48; + size_t features_dim = 56; + + TensorShape input = { + TensorDims{FFOrdered{ + batch_size, + features_dim, + }}, + DataType::INT32, + }; + + TensorShape output = TensorShape{ + TensorDims{ + FFOrdered{ + batch_size, + size_t_from_int(out_channels), + }, + }, + DataType::FLOAT, + }; + + TensorShape weights = TensorShape{ + TensorDims{ + FFOrdered{ + size_t_from_int(num_entries), + size_t_from_int(out_channels), + }, + }, + DataType::FLOAT, + }; + + // get_output_shape + { + tl::expected output_result = + get_output_shape(attrs, input); + tl::expected output_correct = output; + CHECK(output_result == output_correct); + } + + // get_weights_shape + { + tl::expected weight_result = + get_weights_shape(attrs, input); + tl::expected weight_correct = weights; + CHECK(weight_result == weight_correct); + } + + auto make_input = [&](SumDegree o_sum, + DiscardCopyDegree o_eq, + int o_batch, + int o_features) { + return lift_to_parallel_with_degrees( + input, o_sum, o_eq, FFOrdered{o_batch, o_features}); + }; + + auto make_output = [&](SumDegree o_sum, + DiscardCopyDegree o_eq, + int o_batch, + int o_outchannels) { + return lift_to_parallel_with_degrees( + output, o_sum, o_eq, FFOrdered{o_batch, o_outchannels}); + }; + + auto make_weights = [&](SumDegree o_sum, + DiscardCopyDegree o_eq, + int o_entries, + int o_outchannels) { + return lift_to_parallel_with_degrees( + weights, o_sum, o_eq, FFOrdered{o_entries, o_outchannels}); + }; + + SUBCASE("data parallelism") { + int degree = 4; + ParallelTensorShape par_input = + make_input(SumDegree{1}, DiscardCopyDegree{1}, degree, 1); + + { + tl::expected result = + get_output_shape(attrs, par_input); + tl::expected correct = + make_output(SumDegree{1}, DiscardCopyDegree{1}, degree, 1); + CHECK(result == correct); + } + + { + tl::expected result = + get_weights_shape(attrs, par_input); + tl::expected correct = + make_weights(SumDegree{1}, 
DiscardCopyDegree{degree}, 1, 1); + CHECK(result == correct); + } + } + + SUBCASE("input features parallelism") { + int degree = 4; + ParallelTensorShape input = + make_input(SumDegree{1}, DiscardCopyDegree{1}, 1, degree); + + { + tl::expected result = + get_output_shape(attrs, input); + tl::expected correct = + make_output(SumDegree{degree}, DiscardCopyDegree{1}, 1, 1); + CHECK(result == correct); + } + + { + tl::expected result = + get_weights_shape(attrs, input); + tl::expected correct = + make_weights(SumDegree{1}, DiscardCopyDegree{degree}, 1, 1); + CHECK(result == correct); + } + } + + SUBCASE("output channel shard parallelism") { + // NOTE (@lockshaw): in the current (parallel shape inference from just + // input tensor) representation we have to choose between either + // parallelism in the weight channel dimension or in the weight entry + // dimension. For now we choose to represent parallelism in the channel + // dimension, but partitioning in the entry dimension is also potentially + // useful as it produces sum parallelism in the output + int degree = 4; + ParallelTensorShape input = + make_input(SumDegree{1}, DiscardCopyDegree{degree}, 1, 1); + + { + tl::expected result = + get_output_shape(attrs, input); + tl::expected correct = + make_output(SumDegree{1}, DiscardCopyDegree{1}, 1, degree); + CHECK(result == correct); + } + + { + tl::expected result = + get_weights_shape(attrs, input); + tl::expected correct = + make_weights(SumDegree{1}, DiscardCopyDegree{1}, 1, degree); + CHECK(result == correct); + } + } + } +} diff --git a/lib/op-attrs/test/src/test_operator_attrs.cc b/lib/op-attrs/test/src/test_operator_attrs.cc new file mode 100644 index 0000000000..a7724dba69 --- /dev/null +++ b/lib/op-attrs/test/src/test_operator_attrs.cc @@ -0,0 +1,35 @@ +#include "doctest/doctest.h" +#include "op-attrs/computation_graph_op_attrs.dtg.h" +#include "op-attrs/pcg_operator_attrs.dtg.h" +#include "utils/json.h" +#include +#include + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("BatchNormAttrs to/from json") { + BatchNormAttrs correct = BatchNormAttrs{true}; + json j = correct; + auto result = j.get(); + CHECK(result == correct); + } + + TEST_CASE("ComputationGraphAttrs to/from json") { + ComputationGraphOpAttrs correct = + ComputationGraphOpAttrs{BatchNormAttrs{true}}; + json j = correct; + auto result = j.get(); + + CHECK(result == correct); + } + + TEST_CASE("PCGOperatorAttrs to/from json") { + PCGOperatorAttrs correct = PCGOperatorAttrs{RepartitionAttrs{ + /*repartition_dim=*/ff_dim_t{1}, + /*repartition_degree=*/4, + }}; + json j = correct; + auto result = j.get(); + + CHECK(result == correct); + } +} diff --git a/lib/op-attrs/test/src/test_regularizer_attrs.cc b/lib/op-attrs/test/src/test_regularizer_attrs.cc new file mode 100644 index 0000000000..198c3add38 --- /dev/null +++ b/lib/op-attrs/test/src/test_regularizer_attrs.cc @@ -0,0 +1,14 @@ +#include "doctest/doctest.h" +#include "op-attrs/regularizer_attrs.dtg.h" +#include + +using namespace FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + + TEST_CASE("RC") { + CHECK(rc::check("valid variant", [](RegularizerAttrs reg) { + return reg.has() || reg.has(); + })); + } +} diff --git a/lib/pcg/CMakeLists.txt b/lib/pcg/CMakeLists.txt index 81009b0f1f..e1875ca694 100644 --- a/lib/pcg/CMakeLists.txt +++ b/lib/pcg/CMakeLists.txt @@ -13,3 +13,4 @@ ff_add_library( ) add_subdirectory(ffi) +add_subdirectory(test) diff --git a/lib/pcg/include/pcg/computation_graph.dtg.h b/lib/pcg/include/pcg/computation_graph.dtg.h new file mode 100644 index 
0000000000..217b940ce6 --- /dev/null +++ b/lib/pcg/include/pcg/computation_graph.dtg.h @@ -0,0 +1,29 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/pcg/include/pcg/computation_graph.struct.toml +/* proj-data +{ + "generated_from": "8f1f0e13d75065944f7fe307e12fe280" +} +*/ + +#ifndef _FLEXFLOW_LIB_PCG_INCLUDE_PCG_COMPUTATION_GRAPH_DTG_H +#define _FLEXFLOW_LIB_PCG_INCLUDE_PCG_COMPUTATION_GRAPH_DTG_H + +#include "pcg/dataflow_graph.h" +#include "pcg/layer_attrs.dtg.h" +#include "pcg/tensor_attrs.dtg.h" + +namespace FlexFlow { +struct ComputationGraph { + ComputationGraph() = delete; + ComputationGraph( + ::FlexFlow::DataflowGraph<::FlexFlow::LayerAttrs, + ::FlexFlow::TensorAttrs> const &raw_graph); + + ::FlexFlow::DataflowGraph<::FlexFlow::LayerAttrs, ::FlexFlow::TensorAttrs> + raw_graph; +}; +} // namespace FlexFlow + +#endif // _FLEXFLOW_LIB_PCG_INCLUDE_PCG_COMPUTATION_GRAPH_DTG_H diff --git a/lib/pcg/include/pcg/computation_graph.h b/lib/pcg/include/pcg/computation_graph.h index c051fcc8c3..23003641cf 100644 --- a/lib/pcg/include/pcg/computation_graph.h +++ b/lib/pcg/include/pcg/computation_graph.h @@ -1,59 +1,23 @@ #ifndef _FLEXFLOW_PCG_INCLUDE_PCG_COMPUTATION_GRAPH_H #define _FLEXFLOW_PCG_INCLUDE_PCG_COMPUTATION_GRAPH_H -#include "layer.h" -#include "operator_guid_t.h" -#include "tensor.h" -#include "tensor_guid_t.h" -#include "utils/containers.h" -#include "utils/graph.h" -#include "utils/strong_typedef.h" -#include "visit_struct/visit_struct.hpp" +#include "pcg/computation_graph.dtg.h" +#include "pcg/computation_graph/layer_added_result.dtg.h" +#include "pcg/layer_guid_t.dtg.h" +#include "pcg/tensor_attrs.dtg.h" +#include "pcg/tensor_guid_t.dtg.h" namespace FlexFlow { -struct ComputationGraph - : public strong_typedef> { - using strong_typedef::strong_typedef; +ComputationGraph make_empty_computation_graph(); - Layer &at(operator_guid_t const &n) { - return this->value().at(n.value()); - } +std::unordered_set get_layers(ComputationGraph const &); - Layer const &at(operator_guid_t const &n) const { - return this->value().at(n.value()); - } - - Tensor &at(tensor_guid_t const &e) { - return this->value().at(e.value()); - } - - Tensor const &at(tensor_guid_t const &e) const { - return this->value().at(e.value()); - } -}; -CHECK_WELL_BEHAVED_VALUE_TYPE_NO_HASH(ComputationGraph); - -std::vector - traverse_comp_graph_forward(ComputationGraph const &comp_graph); -std::vector - traverse_comp_graph_backward(ComputationGraph const &comp_graph); -std::vector - get_outgoing_tensors(ComputationGraph const &comp_graph, operator_guid_t n); -std::vector - get_incoming_tensors(ComputationGraph const &comp_graph, operator_guid_t n); -operator_guid_t create_node(ComputationGraph &comp_graph, Layer const &layer); -tensor_guid_t create_outgoing_edge(ComputationGraph &comp_graph, - operator_guid_t node, - int idx, - Tensor tensor); - -void connect_incoming_edges(ComputationGraph &comp_graph, - std::vector const &incoming_edges, - operator_guid_t node); -CompGraphOperatorAttrs get_layer_attrs(ComputationGraph const &comp_graph, - operator_guid_t const &n); +LayerAddedResult add_layer(ComputationGraph &computation_graph, + LayerAttrs const &attrs, + std::vector const &inputs, + std::vector const &outputs); +TensorAttrs get_tensor_attrs(ComputationGraph const &, tensor_guid_t const &); } // namespace FlexFlow diff --git a/lib/pcg/include/pcg/computation_graph.struct.toml b/lib/pcg/include/pcg/computation_graph.struct.toml new 
file mode 100644 index 0000000000..a270cb8fbe --- /dev/null +++ b/lib/pcg/include/pcg/computation_graph.struct.toml @@ -0,0 +1,13 @@ +namespace = "FlexFlow" +name = "ComputationGraph" +features = [ ] + +includes = [ + "pcg/layer_attrs.dtg.h", + "pcg/tensor_attrs.dtg.h", + "pcg/dataflow_graph.h", +] + +[[fields]] +name = "raw_graph" +type = "::FlexFlow::DataflowGraph<::FlexFlow::LayerAttrs, ::FlexFlow::TensorAttrs>" diff --git a/lib/pcg/include/pcg/computation_graph/layer_added_result.dtg.h b/lib/pcg/include/pcg/computation_graph/layer_added_result.dtg.h new file mode 100644 index 0000000000..4fd78f2d44 --- /dev/null +++ b/lib/pcg/include/pcg/computation_graph/layer_added_result.dtg.h @@ -0,0 +1,37 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/pcg/include/pcg/computation_graph/layer_added_result.struct.toml +/* proj-data +{ + "generated_from": "15bf9d73ef934599c9b11807d86ae5d4" +} +*/ + +#ifndef _FLEXFLOW_LIB_PCG_INCLUDE_PCG_COMPUTATION_GRAPH_LAYER_ADDED_RESULT_DTG_H +#define _FLEXFLOW_LIB_PCG_INCLUDE_PCG_COMPUTATION_GRAPH_LAYER_ADDED_RESULT_DTG_H + +#include "fmt/format.h" +#include "pcg/layer_guid_t.dtg.h" +#include "pcg/tensor_guid_t.dtg.h" +#include +#include + +namespace FlexFlow { +struct LayerAddedResult { + LayerAddedResult() = delete; + LayerAddedResult(::FlexFlow::layer_guid_t const &layer, + std::vector<::FlexFlow::tensor_guid_t> const &outputs); + + bool operator==(LayerAddedResult const &) const; + bool operator!=(LayerAddedResult const &) const; + ::FlexFlow::layer_guid_t layer; + std::vector<::FlexFlow::tensor_guid_t> outputs; +}; +} // namespace FlexFlow + +namespace FlexFlow { +std::string format_as(LayerAddedResult const &); +std::ostream &operator<<(std::ostream &, LayerAddedResult const &); +} // namespace FlexFlow + +#endif // _FLEXFLOW_LIB_PCG_INCLUDE_PCG_COMPUTATION_GRAPH_LAYER_ADDED_RESULT_DTG_H diff --git a/lib/pcg/include/pcg/computation_graph/layer_added_result.struct.toml b/lib/pcg/include/pcg/computation_graph/layer_added_result.struct.toml new file mode 100644 index 0000000000..b02e992ba1 --- /dev/null +++ b/lib/pcg/include/pcg/computation_graph/layer_added_result.struct.toml @@ -0,0 +1,19 @@ +namespace = "FlexFlow" +name = "LayerAddedResult" +features = [ + "eq", + "fmt", +] + +includes = [ + "pcg/layer_guid_t.dtg.h", + "pcg/tensor_guid_t.dtg.h", +] + +[[fields]] +name = "layer" +type = "::FlexFlow::layer_guid_t" + +[[fields]] +name = "outputs" +type = "std::vector<::FlexFlow::tensor_guid_t>" diff --git a/lib/pcg/include/pcg/computation_graph_builder.h b/lib/pcg/include/pcg/computation_graph_builder.h index 7ba95d701b..3a1526e9c8 100644 --- a/lib/pcg/include/pcg/computation_graph_builder.h +++ b/lib/pcg/include/pcg/computation_graph_builder.h @@ -1,13 +1,13 @@ #ifndef _FLEXFLOW_PCG_INCLUDE_PCG_COMPUTATION_GRAPH_BUILDER_H #define _FLEXFLOW_PCG_INCLUDE_PCG_COMPUTATION_GRAPH_BUILDER_H -#include "computation_graph.h" -#include "optimizer.h" +#include "pcg/computation_graph.dtg.h" +#include "pcg/initializer_attrs.dtg.h" +#include "pcg/tensor_guid_t.dtg.h" namespace FlexFlow { -struct ComputationGraphBuilder - : public use_visitable_cmp { +struct ComputationGraphBuilder { public: ComputationGraphBuilder(); @@ -95,8 +95,8 @@ struct ComputationGraphBuilder std::optional const &activation = std::nullopt, int groups = 1, bool use_bias = true, - std::optional const &kernel_initializer = std::nullopt, - std::optional const &bias_initializer = std::nullopt, + std::optional const 
&kernel_initializer = std::nullopt, + std::optional const &bias_initializer = std::nullopt, std::optional const &kernel_regularizer = std::nullopt, std::optional const &name = std::nullopt); // Add a dropout layer @@ -111,7 +111,7 @@ struct ComputationGraphBuilder int outDim, AggregateOp aggr, DataType dtype = DataType::FLOAT, - std::optional const &kernel_initializer = std::nullopt, + std::optional const &kernel_initializer = std::nullopt, std::optional const &name = std::nullopt); // Add a gather layer std::vector @@ -154,15 +154,15 @@ struct ComputationGraphBuilder int a_seq_length_dim = -1, int b_seq_length_dim = -1, std::optional const &name = std::nullopt); - tensor_guid_t - dense(tensor_guid_t const &input, - int outDim, - std::optional activation = std::nullopt, - bool use_bias = true, - DataType data_type = DataType::FLOAT, - std::optional const &kernel_initializer = std::nullopt, - std::optional const &bias_initializer = std::nullopt, - std::optional const &name = std::nullopt); + tensor_guid_t dense( + tensor_guid_t const &input, + int outDim, + std::optional activation = std::nullopt, + bool use_bias = true, + DataType data_type = DataType::FLOAT, + std::optional const &kernel_initializer = std::nullopt, + std::optional const &bias_initializer = std::nullopt, + std::optional const &name = std::nullopt); // Add a cast layer tensor_guid_t cast(tensor_guid_t const &input, DataType dtype, @@ -178,11 +178,11 @@ struct ComputationGraphBuilder bool keepdims, char const *name); // Add a split layer - void split(tensor_guid_t const &input, - tensor_guid_t *outputs, - std::vector const &split, - int axis, - std::optional const &name = std::nullopt); + std::vector + split(tensor_guid_t const &input, + std::vector const &split, + int axis, + std::optional const &name = std::nullopt); // Add a flat layer tensor_guid_t flat(tensor_guid_t const &input, std::optional const &name = std::nullopt); @@ -191,8 +191,6 @@ struct ComputationGraphBuilder int dim = -1, std::optional const &name = std::nullopt); // Create input tensors and constants - tensor_guid_t input(Tensor const &input_tensor, - std::optional const &name = std::nullopt); tensor_guid_t transpose(tensor_guid_t const &input, std::vector const &perm, @@ -208,11 +206,11 @@ struct ComputationGraphBuilder tensor_guid_t reverse(tensor_guid_t const &input, int axis, std::optional const &name = std::nullopt); - void top_k(tensor_guid_t const &input, - tensor_guid_t *outputs, - int k, - bool sorted, - std::optional const &name = std::nullopt); + std::vector + top_k(tensor_guid_t const &input, + int k, + bool sorted, + std::optional const &name = std::nullopt); tensor_guid_t multihead_attention( tensor_guid_t const &query, tensor_guid_t const &key, @@ -225,44 +223,48 @@ struct ComputationGraphBuilder bool bias = true, bool add_bias_kv = false, bool add_zero_attn = false, - std::optional initializer = std::nullopt, + std::optional initializer = std::nullopt, std::optional const &name = std::nullopt); tensor_guid_t create_tensor(TensorShape const &, bool create_grad = true); - Parameter create_weight( + tensor_guid_t create_weight( TensorShape const &, bool create_grad = true, - std::optional const &initializer = std::nullopt, + std::optional const &initializer = std::nullopt, std::optional sync_type = std::nullopt); - std::vector get_outputs(operator_guid_t const &) const; - tensor_guid_t get_output(operator_guid_t const &, int idx) const; - Tensor get_tensor(tensor_guid_t const &) const; + std::vector get_outputs(LayerAttrs const &) const; + 
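+  // A hypothetical usage sketch of this builder API (method names taken
+  // from the declarations in this header; the sketch itself is not part of
+  // the original code):
+  //   ComputationGraphBuilder b;
+  //   tensor_guid_t x = b.create_tensor(input_shape);
+  //   tensor_guid_t y = b.dense(x, /*outDim=*/16);
+  //   std::vector<tensor_guid_t> parts =
+  //       b.split(y, /*split=*/{8, 8}, /*axis=*/1);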
tensor_guid_t get_output(LayerAttrs const &, int idx) const; private: - tensor_guid_t broadcast(tensor_guid_t const &, TensorShape const &); + TensorShape get_shape(tensor_guid_t const &) const; - void add_layer(Layer const &layer, - std::vector const &inputs, - std::vector const &weights, - std::vector const &outputs); - tensor_guid_t add_layer( - Layer const &layer, - std::vector const &inputs, - std::vector>> const - &weight_shapes, - TensorShape const &output_shape); - std::vector add_layer( - Layer const &layer, - std::vector const &inputs, - std::vector>> const - &weight_shapes, - std::vector const &output_shapes); + tensor_guid_t broadcast(tensor_guid_t const &, TensorShape const &); tensor_guid_t as_type(tensor_guid_t const &, DataType, std::string const &); + std::vector add_layer(LayerAttrs const &layer, + std::vector const &inputs, + std::vector const &weights, + std::vector const &outputs); + + tensor_guid_t add_layer(LayerAttrs const &layer, + std::vector const &inputs, + std::vector const &weights, + TensorAttrs const &output); + + std::vector add_layer(LayerAttrs const &layer, + std::vector const &inputs, + std::vector const &weights, + std::vector const &outputs); + + tensor_guid_t add_layer(LayerAttrs const &layer, + std::vector const &inputs, + std::vector const &weights, + TensorShape const &output); + + TensorShape get_broadcast_target_shape(std::vector const &); TensorShape get_broadcast_target_shape(std::vector const &); - TensorShape get_shape(tensor_guid_t const &t); - std::vector get_shapes(std::vector const &t); + tensor_guid_t element_binary(OperatorType, tensor_guid_t const &lhs, @@ -293,11 +295,4 @@ struct ComputationGraphBuilder } // namespace FlexFlow -VISITABLE_STRUCT(::FlexFlow::ComputationGraphBuilder, computation_graph); - -namespace FlexFlow { -static_assert( - is_well_behaved_value_type_no_hash::value, ""); -} - #endif diff --git a/lib/pcg/include/pcg/cpu_id_t.dtg.h b/lib/pcg/include/pcg/cpu_id_t.dtg.h new file mode 100644 index 0000000000..a6c81e80b0 --- /dev/null +++ b/lib/pcg/include/pcg/cpu_id_t.dtg.h @@ -0,0 +1,62 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! 
+// If you would like to modify this datatype, instead modify
+// lib/pcg/include/pcg/cpu_id_t.struct.toml
+/* proj-data
+{
+  "generated_from": "a0faf78831febfa3a02929169943d9f5"
+}
+*/
+
+#ifndef _FLEXFLOW_LIB_PCG_INCLUDE_PCG_CPU_ID_T_DTG_H
+#define _FLEXFLOW_LIB_PCG_INCLUDE_PCG_CPU_ID_T_DTG_H
+
+#include "fmt/format.h"
+#include "nlohmann/json.hpp"
+#include "rapidcheck.h"
+#include <functional>
+#include <ostream>
+#include <string>
+
+namespace FlexFlow {
+struct cpu_id_t {
+  cpu_id_t() = delete;
+  cpu_id_t(int const &cpu_index);
+
+  bool operator==(cpu_id_t const &) const;
+  bool operator!=(cpu_id_t const &) const;
+  bool operator<(cpu_id_t const &) const;
+  bool operator>(cpu_id_t const &) const;
+  bool operator<=(cpu_id_t const &) const;
+  bool operator>=(cpu_id_t const &) const;
+  int cpu_index;
+};
+} // namespace FlexFlow
+
+namespace std {
+template <>
+struct hash<FlexFlow::cpu_id_t> {
+  size_t operator()(FlexFlow::cpu_id_t const &) const;
+};
+} // namespace std
+
+namespace nlohmann {
+template <>
+struct adl_serializer<FlexFlow::cpu_id_t> {
+  static FlexFlow::cpu_id_t from_json(json const &);
+  static void to_json(json &, FlexFlow::cpu_id_t const &);
+};
+} // namespace nlohmann
+
+namespace rc {
+template <>
+struct Arbitrary<FlexFlow::cpu_id_t> {
+  static Gen<FlexFlow::cpu_id_t> arbitrary();
+};
+} // namespace rc
+
+namespace FlexFlow {
+std::string format_as(cpu_id_t const &);
+std::ostream &operator<<(std::ostream &, cpu_id_t const &);
+} // namespace FlexFlow
+
+#endif // _FLEXFLOW_LIB_PCG_INCLUDE_PCG_CPU_ID_T_DTG_H
diff --git a/lib/pcg/include/pcg/cpu_id_t.struct.toml b/lib/pcg/include/pcg/cpu_id_t.struct.toml
new file mode 100644
index 0000000000..0492a937be
--- /dev/null
+++ b/lib/pcg/include/pcg/cpu_id_t.struct.toml
@@ -0,0 +1,14 @@
+namespace = "FlexFlow"
+name = "cpu_id_t"
+features = [
+  "eq",
+  "ord",
+  "hash",
+  "json",
+  "rapidcheck",
+  "fmt",
+]
+
+[[fields]]
+name = "cpu_index"
+type = "int"
diff --git a/lib/pcg/include/pcg/create_grad.dtg.h b/lib/pcg/include/pcg/create_grad.dtg.h
new file mode 100644
index 0000000000..494ff06b75
--- /dev/null
+++ b/lib/pcg/include/pcg/create_grad.dtg.h
@@ -0,0 +1,40 @@
+// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT!
+// If you would like to modify this datatype, instead modify +// lib/pcg/include/pcg/create_grad.enum.toml +/* proj-data +{ + "generated_from": "9fd617027e850b6d6db476a49b3e0334" +} +*/ + +#ifndef _FLEXFLOW_LIB_PCG_INCLUDE_PCG_CREATE_GRAD_DTG_H +#define _FLEXFLOW_LIB_PCG_INCLUDE_PCG_CREATE_GRAD_DTG_H + +#include "fmt/format.h" +#include "nlohmann/json.hpp" +#include "rapidcheck.h" +#include +#include +#include + +namespace FlexFlow { +enum class CreateGrad { YES, NO }; +std::string format_as(CreateGrad); +std::ostream &operator<<(std::ostream &, CreateGrad); +void to_json(::nlohmann::json &, CreateGrad); +void from_json(::nlohmann::json const &, CreateGrad &); +} // namespace FlexFlow +namespace std { +template <> +struct hash { + size_t operator()(FlexFlow::CreateGrad) const; +}; +} // namespace std +namespace rc { +template <> +struct Arbitrary { + static Gen arbitrary(); +}; +} // namespace rc + +#endif // _FLEXFLOW_LIB_PCG_INCLUDE_PCG_CREATE_GRAD_DTG_H diff --git a/lib/pcg/include/pcg/create_grad.enum.toml b/lib/pcg/include/pcg/create_grad.enum.toml new file mode 100644 index 0000000000..20febe49fb --- /dev/null +++ b/lib/pcg/include/pcg/create_grad.enum.toml @@ -0,0 +1,14 @@ +namespace = "FlexFlow" +name = "CreateGrad" +features = [ + "hash", + "json", + "rapidcheck", + "fmt", +] + +[[values]] +name = "YES" + +[[values]] +name = "NO" diff --git a/lib/pcg/include/pcg/create_grad.h b/lib/pcg/include/pcg/create_grad.h index 7dd843b76d..5a12d310c2 100644 --- a/lib/pcg/include/pcg/create_grad.h +++ b/lib/pcg/include/pcg/create_grad.h @@ -1,36 +1,8 @@ #ifndef _FLEXFLOW_PCG_INCLUDE_PCG_CREATE_GRAD_H #define _FLEXFLOW_PCG_INCLUDE_PCG_CREATE_GRAD_H -#include "utils/fmt.h" +#include "pcg/create_grad_t.h" -namespace FlexFlow { - -enum class CreateGrad { YES, NO }; - -} - -namespace fmt { - -template <> -struct formatter<::FlexFlow::CreateGrad> : formatter { - template - auto format(::FlexFlow::CreateGrad ps, FormatContext &ctx) const - -> decltype(ctx.out()) { - using namespace FlexFlow; - - string_view name = "unknown"; - switch (ps) { - case CreateGrad::YES: - name = "yes"; - break; - case CreateGrad::NO: - name = "no"; - break; - } - return formatter::format(name, ctx); - } -}; - -} // namespace fmt +namespace FlexFlow {} #endif diff --git a/lib/pcg/include/pcg/dataflow_graph.h b/lib/pcg/include/pcg/dataflow_graph.h new file mode 100644 index 0000000000..f649c0444c --- /dev/null +++ b/lib/pcg/include/pcg/dataflow_graph.h @@ -0,0 +1,77 @@ +#ifndef _FLEXFLOW_LIB_PCG_INCLUDE_PCG_OPERATOR_GRAPH_DATAFLOW_GRAPH_H +#define _FLEXFLOW_LIB_PCG_INCLUDE_PCG_OPERATOR_GRAPH_DATAFLOW_GRAPH_H + +#include "utils/containers/enumerate_vector.h" +#include "utils/graph.h" + +namespace FlexFlow { + +template +struct DataflowGraph { +public: + DataflowGraph() + : g(OutputLabelledMultiDiGraph::template create< + UnorderedOutputLabelledMultiDiGraph>()) {} + + std::vector + add_operator(NodeLabel const &func, + std::vector const &inputs, + std::vector const &outputs) { + Node n = this->g.add_node(func); + for (auto const &[idx, input] : enumerate_vector(inputs)) { + this->g.add_edge(MultiDiEdge{ + input.src, input.src_idx, n, this->make_port_for_idx(idx)}); + } + + std::vector result; + for (auto const &[idx, label] : enumerate_vector(outputs)) { + MultiDiOutput output = MultiDiOutput{n, this->make_port_for_idx(idx)}; + this->g.add_output(output, label); + result.push_back(output); + } + + return result; + } + + NodePort make_port_for_idx(int idx) { + if (!this->port_mapping.contains_l(idx)) { + 
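+      // no NodePort exists for this index yet, so mint a fresh one and
+      // record the idx <-> port pairing for the at_l/at_r lookups below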
this->port_mapping.equate(idx, this->g.add_node_port()); + } + return this->port_mapping.at_l(idx); + } + + NodePort port_for_idx(int idx) const { + return this->port_mapping.at_l(idx); + } + + int idx_for_port(NodePort const &p) const { + return this->port_mapping.at_r(p); + } + + OutputLabelledMultiDiGraphView const & + get_raw_graph() const { + return this->g; + } + + NodeLabel const &at(Node const &n) const { + return this->g.at(n); + } + + OutputLabel const &at(MultiDiOutput const &o) const { + return this->g.at(o); + } + +private: + OutputLabelledMultiDiGraph g; + bidict port_mapping; +}; + +template +std::unordered_set + get_nodes(DataflowGraph const &g) { + return get_nodes(g.get_raw_graph()); +} + +} // namespace FlexFlow + +#endif diff --git a/lib/pcg/include/pcg/dataflow_input.dtg.h b/lib/pcg/include/pcg/dataflow_input.dtg.h new file mode 100644 index 0000000000..c698c75c25 --- /dev/null +++ b/lib/pcg/include/pcg/dataflow_input.dtg.h @@ -0,0 +1,101 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/pcg/include/pcg/dataflow_input.variant.toml +/* proj-data +{ + "generated_from": "d6a7f4570e36e257383529e9bf9390ec" +} +*/ + +#ifndef _FLEXFLOW_LIB_PCG_INCLUDE_PCG_DATAFLOW_INPUT_DTG_H +#define _FLEXFLOW_LIB_PCG_INCLUDE_PCG_DATAFLOW_INPUT_DTG_H + +#include "utils/graph/multidiedge.h" +#include +#include +#include +#include + +namespace FlexFlow { +struct DataflowInput { + DataflowInput() = delete; + explicit DataflowInput(::FlexFlow::MultiDiOutput const &); + explicit DataflowInput(int const &); + template + static constexpr bool IsPartOfDataflowInput_v = + std::is_same_v || std::is_same_v; + template + ReturnType visit(Visitor &&v) const { + switch (this->index()) { + case 0: { + ReturnType result = v(this->get<::FlexFlow::MultiDiOutput>()); + return result; + } + case 1: { + ReturnType result = v(this->get()); + return result; + } + default: { + throw std::runtime_error(fmt::format( + "Unknown index {} for type DataflowInput", this->index())); + } + } + } + template + ReturnType visit(Visitor &&v) { + switch (this->index()) { + case 0: { + ReturnType result = v(this->get<::FlexFlow::MultiDiOutput>()); + return result; + } + case 1: { + ReturnType result = v(this->get()); + return result; + } + default: { + throw std::runtime_error(fmt::format( + "Unknown index {} for type DataflowInput", this->index())); + } + } + } + template + bool has() const { + static_assert(IsPartOfDataflowInput_v, + "DataflowInput::has() expected one of " + "[::FlexFlow::MultiDiOutput, int], received T"); + return std::holds_alternative(this->raw_variant); + } + template + T const &get() const { + static_assert(IsPartOfDataflowInput_v, + "DataflowInput::get() expected one of " + "[::FlexFlow::MultiDiOutput, int], received T"); + return std::get(this->raw_variant); + } + template + T &get() { + static_assert(IsPartOfDataflowInput_v, + "DataflowInput::get() expected one of " + "[::FlexFlow::MultiDiOutput, int], received T"); + return std::get(this->raw_variant); + } + size_t index() const { + return this->raw_variant.index(); + } + bool operator==(DataflowInput const &) const; + bool operator!=(DataflowInput const &) const; + bool operator<(DataflowInput const &) const; + bool operator>(DataflowInput const &) const; + bool operator<=(DataflowInput const &) const; + bool operator>=(DataflowInput const &) const; + std::variant<::FlexFlow::MultiDiOutput, int> raw_variant; +}; +} // namespace FlexFlow +namespace std { +template <> 
+struct hash<::FlexFlow::DataflowInput> { + size_t operator()(::FlexFlow::DataflowInput const &) const; +}; +} // namespace std + +#endif // _FLEXFLOW_LIB_PCG_INCLUDE_PCG_DATAFLOW_INPUT_DTG_H diff --git a/lib/pcg/include/pcg/dataflow_input.variant.toml b/lib/pcg/include/pcg/dataflow_input.variant.toml new file mode 100644 index 0000000000..ac7c3ae5d7 --- /dev/null +++ b/lib/pcg/include/pcg/dataflow_input.variant.toml @@ -0,0 +1,21 @@ +namespace = "FlexFlow" +name = "DataflowInput" +features = [ + "eq", + "ord", + "hash", + # "json", + # "fmt", +] + +includes = [ + "utils/graph/multidiedge.h" , +] + +[[values]] +type = "::FlexFlow::MultiDiOutput" +key = "internal" + +[[values]] +type = "int" +key = "external" diff --git a/lib/pcg/include/pcg/device_id.h b/lib/pcg/include/pcg/device_id.h index b118d69259..be92be7081 100644 --- a/lib/pcg/include/pcg/device_id.h +++ b/lib/pcg/include/pcg/device_id.h @@ -1,35 +1,21 @@ #ifndef _FLEXFLOW_PCG_INCLUDE_PCG_DEVICE_ID_H #define _FLEXFLOW_PCG_INCLUDE_PCG_DEVICE_ID_H -#include "device_type.h" -#include "utils/strong_typedef.h" -#include +#include "pcg/cpu_id_t.dtg.h" +#include "pcg/device_id_t.dtg.h" +#include "pcg/device_type.dtg.h" +#include "pcg/gpu_id_t.dtg.h" namespace FlexFlow { -struct gpu_id_t : strong_typedef { - using strong_typedef::strong_typedef; -}; - -struct cpu_id_t : strong_typedef { - using strong_typedef::strong_typedef; -}; - -using device_id_t = std::variant; device_id_t operator+(device_id_t, size_t); DeviceType get_device_type(device_id_t); gpu_id_t unwrap_gpu(device_id_t); cpu_id_t unwrap_cpu(device_id_t); -device_id_t from_index(int, DeviceType); +device_id_t device_id_from_index(int, DeviceType); } // namespace FlexFlow -MAKE_TYPEDEF_HASHABLE(::FlexFlow::gpu_id_t); -MAKE_TYPEDEF_PRINTABLE(::FlexFlow::gpu_id_t, "gpu_id"); - -MAKE_TYPEDEF_HASHABLE(::FlexFlow::cpu_id_t); -MAKE_TYPEDEF_PRINTABLE(::FlexFlow::cpu_id_t, "cpu_id"); - #endif diff --git a/lib/pcg/include/pcg/device_id_t.dtg.h b/lib/pcg/include/pcg/device_id_t.dtg.h new file mode 100644 index 0000000000..d46f3dd079 --- /dev/null +++ b/lib/pcg/include/pcg/device_id_t.dtg.h @@ -0,0 +1,117 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! 
+// If you would like to modify this datatype, instead modify +// lib/pcg/include/pcg/device_id_t.variant.toml +/* proj-data +{ + "generated_from": "85870050c742b0159775399ec2be67e3" +} +*/ + +#ifndef _FLEXFLOW_LIB_PCG_INCLUDE_PCG_DEVICE_ID_T_DTG_H +#define _FLEXFLOW_LIB_PCG_INCLUDE_PCG_DEVICE_ID_T_DTG_H + +#include "fmt/format.h" +#include "nlohmann/json.hpp" +#include "pcg/cpu_id_t.dtg.h" +#include "pcg/gpu_id_t.dtg.h" +#include +#include +#include +#include +#include + +namespace FlexFlow { +struct device_id_t { + device_id_t() = delete; + explicit device_id_t(::FlexFlow::gpu_id_t const &); + explicit device_id_t(::FlexFlow::cpu_id_t const &); + template + static constexpr bool IsPartOfdevice_id_t_v = + std::is_same_v || + std::is_same_v; + template + ReturnType visit(Visitor &&v) const { + switch (this->index()) { + case 0: { + ReturnType result = v(this->get<::FlexFlow::gpu_id_t>()); + return result; + } + case 1: { + ReturnType result = v(this->get<::FlexFlow::cpu_id_t>()); + return result; + } + default: { + throw std::runtime_error(fmt::format( + "Unknown index {} for type device_id_t", this->index())); + } + } + } + template + ReturnType visit(Visitor &&v) { + switch (this->index()) { + case 0: { + ReturnType result = v(this->get<::FlexFlow::gpu_id_t>()); + return result; + } + case 1: { + ReturnType result = v(this->get<::FlexFlow::cpu_id_t>()); + return result; + } + default: { + throw std::runtime_error(fmt::format( + "Unknown index {} for type device_id_t", this->index())); + } + } + } + template + bool has() const { + static_assert(IsPartOfdevice_id_t_v, + "device_id_t::has() expected one of [::FlexFlow::gpu_id_t, " + "::FlexFlow::cpu_id_t], received T"); + return std::holds_alternative(this->raw_variant); + } + template + T const &get() const { + static_assert(IsPartOfdevice_id_t_v, + "device_id_t::get() expected one of [::FlexFlow::gpu_id_t, " + "::FlexFlow::cpu_id_t], received T"); + return std::get(this->raw_variant); + } + template + T &get() { + static_assert(IsPartOfdevice_id_t_v, + "device_id_t::get() expected one of [::FlexFlow::gpu_id_t, " + "::FlexFlow::cpu_id_t], received T"); + return std::get(this->raw_variant); + } + size_t index() const { + return this->raw_variant.index(); + } + bool operator==(device_id_t const &) const; + bool operator!=(device_id_t const &) const; + bool operator<(device_id_t const &) const; + bool operator>(device_id_t const &) const; + bool operator<=(device_id_t const &) const; + bool operator>=(device_id_t const &) const; + std::variant<::FlexFlow::gpu_id_t, ::FlexFlow::cpu_id_t> raw_variant; +}; +} // namespace FlexFlow +namespace std { +template <> +struct hash<::FlexFlow::device_id_t> { + size_t operator()(::FlexFlow::device_id_t const &) const; +}; +} // namespace std +namespace nlohmann { +template <> +struct adl_serializer<::FlexFlow::device_id_t> { + static ::FlexFlow::device_id_t from_json(json const &); + static void to_json(json &, ::FlexFlow::device_id_t const &); +}; +} // namespace nlohmann +namespace FlexFlow { +std::string format_as(::FlexFlow::device_id_t const &); +std::ostream &operator<<(std::ostream &, ::FlexFlow::device_id_t const &); +} // namespace FlexFlow + +#endif // _FLEXFLOW_LIB_PCG_INCLUDE_PCG_DEVICE_ID_T_DTG_H diff --git a/lib/pcg/include/pcg/device_id_t.variant.toml b/lib/pcg/include/pcg/device_id_t.variant.toml new file mode 100644 index 0000000000..71af18919f --- /dev/null +++ b/lib/pcg/include/pcg/device_id_t.variant.toml @@ -0,0 +1,22 @@ +namespace = "FlexFlow" +name = "device_id_t" +features = [ + 
"eq", + "ord", + "hash", + "json", + "fmt", +] + +includes = [ + "pcg/cpu_id_t.dtg.h", + "pcg/gpu_id_t.dtg.h", +] + +[[values]] +type = "::FlexFlow::gpu_id_t" +key = "gpu" + +[[values]] +type = "::FlexFlow::cpu_id_t" +key = "cpu" diff --git a/lib/pcg/include/pcg/device_type.dtg.h b/lib/pcg/include/pcg/device_type.dtg.h new file mode 100644 index 0000000000..f5e90dc193 --- /dev/null +++ b/lib/pcg/include/pcg/device_type.dtg.h @@ -0,0 +1,40 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/pcg/include/pcg/device_type.enum.toml +/* proj-data +{ + "generated_from": "cfe4bc5e9f7c5796b9b90b420c33935f" +} +*/ + +#ifndef _FLEXFLOW_LIB_PCG_INCLUDE_PCG_DEVICE_TYPE_DTG_H +#define _FLEXFLOW_LIB_PCG_INCLUDE_PCG_DEVICE_TYPE_DTG_H + +#include "fmt/format.h" +#include "nlohmann/json.hpp" +#include "rapidcheck.h" +#include +#include +#include + +namespace FlexFlow { +enum class DeviceType { GPU, CPU }; +std::string format_as(DeviceType); +std::ostream &operator<<(std::ostream &, DeviceType); +void to_json(::nlohmann::json &, DeviceType); +void from_json(::nlohmann::json const &, DeviceType &); +} // namespace FlexFlow +namespace std { +template <> +struct hash { + size_t operator()(FlexFlow::DeviceType) const; +}; +} // namespace std +namespace rc { +template <> +struct Arbitrary { + static Gen arbitrary(); +}; +} // namespace rc + +#endif // _FLEXFLOW_LIB_PCG_INCLUDE_PCG_DEVICE_TYPE_DTG_H diff --git a/lib/pcg/include/pcg/device_type.enum.toml b/lib/pcg/include/pcg/device_type.enum.toml new file mode 100644 index 0000000000..67f89fbc6f --- /dev/null +++ b/lib/pcg/include/pcg/device_type.enum.toml @@ -0,0 +1,14 @@ +namespace = "FlexFlow" +name = "DeviceType" +features = [ + "hash", + "json", + "fmt", + "rapidcheck", +] + +[[values]] +name = "GPU" + +[[values]] +name = "CPU" diff --git a/lib/pcg/include/pcg/device_type.h b/lib/pcg/include/pcg/device_type.h deleted file mode 100644 index 3ae374c5ea..0000000000 --- a/lib/pcg/include/pcg/device_type.h +++ /dev/null @@ -1,36 +0,0 @@ -#ifndef _FLEXFLOW_PCG_INCLUDE_PCG_DEVICE_TYPE_H -#define _FLEXFLOW_PCG_INCLUDE_PCG_DEVICE_TYPE_H - -#include "utils/fmt.h" - -namespace FlexFlow { - -enum class DeviceType { GPU, CPU }; - -} - -namespace fmt { - -template <> -struct formatter<::FlexFlow::DeviceType> : formatter { - template - auto format(::FlexFlow::DeviceType d, FormatContext &ctx) const - -> decltype(ctx.out()) { - using ::FlexFlow::DeviceType; - - string_view name = "unknown"; - switch (d) { - case DeviceType::GPU: - name = "GPU"; - break; - case DeviceType::CPU: - name = "CPU"; - break; - } - return formatter::format(name, ctx); - } -}; - -} // namespace fmt - -#endif diff --git a/lib/pcg/include/pcg/file_format/v1/data_type.h b/lib/pcg/include/pcg/file_format/v1/data_type_value.h similarity index 64% rename from lib/pcg/include/pcg/file_format/v1/data_type.h rename to lib/pcg/include/pcg/file_format/v1/data_type_value.h index eab188155f..6e4e5abc54 100644 --- a/lib/pcg/include/pcg/file_format/v1/data_type.h +++ b/lib/pcg/include/pcg/file_format/v1/data_type_value.h @@ -9,23 +9,6 @@ namespace FlexFlow { using V1DataTypeValue = std::variant; -enum class V1DataType { - BOOL, - INT32, - INT64, - HALF, - FLOAT, - DOUBLE, -}; - -NLOHMANN_JSON_SERIALIZE_ENUM(V1DataType, - {{V1DataType::BOOL, "BOOL"}, - {V1DataType::INT32, "INT32"}, - {V1DataType::INT64, "INT64"}, - {V1DataType::HALF, "HALF"}, - {V1DataType::FLOAT, "FLOAT"}, - {V1DataType::DOUBLE, "DOUBLE"}}); - } // namespace FlexFlow 
namespace nlohmann { diff --git a/lib/pcg/include/pcg/file_format/v1/graphs.h b/lib/pcg/include/pcg/file_format/v1/graphs.h index 6bc852b0f1..dad73ce142 100644 --- a/lib/pcg/include/pcg/file_format/v1/graphs.h +++ b/lib/pcg/include/pcg/file_format/v1/graphs.h @@ -1,73 +1,23 @@ #ifndef _FLEXFLOW_PCG_INCLUDE_PCG_FILE_FORMAT_V1_GRAPHS_H #define _FLEXFLOW_PCG_INCLUDE_PCG_FILE_FORMAT_V1_GRAPHS_H -#include "operator_attrs.h" -#include "parallel_tensor.h" -#include "pcg/computation_graph.h" -#include "pcg/parallel_computation_graph.h" -#include "tensor.h" +#include "pcg/computation_graph.dtg.h" +#include "pcg/file_format/v1/graphs/v1_jsonable_graph.dtg.h" +#include "pcg/layer_attrs.dtg.h" +#include "pcg/parallel_computation_graph.dtg.h" +#include "pcg/parallel_layer_attrs.dtg.h" +#include "pcg/parallel_tensor_attrs.dtg.h" +#include "pcg/tensor_attrs.dtg.h" #include "utils/json.h" -#include "utils/required.h" -#include "utils/visitable.h" namespace FlexFlow { -struct V1GraphOutput { - req srcNode; - req srcIdx; -}; -FF_VISITABLE_STRUCT(V1GraphOutput, srcNode, srcIdx); -CHECK_IS_JSONABLE(V1GraphOutput); - -struct V1GraphEdge { - req srcNode; - req srcIdx; - req dstNode; - req dstIdx; -}; -FF_VISITABLE_STRUCT(V1GraphEdge, srcNode, srcIdx, dstNode, dstIdx); -CHECK_IS_JSONABLE(V1GraphEdge); - -struct V1MultiDiGraph { - req> nodes; - req> ports; - req> edges; -}; -FF_VISITABLE_STRUCT(V1MultiDiGraph, nodes, ports, edges); -CHECK_IS_JSONABLE(V1MultiDiGraph); -V1MultiDiGraph to_v1(MultiDiGraphView const &); -V1MultiDiGraph to_v1(MultiDiGraphView const &, - std::unordered_map const &, - std::unordered_map const &); - -template -struct V1JsonableGraph { - using node_id = size_t; - using tensor_id = size_t; - - req> node_labels; - req> outputs; - req> output_labels; - V1MultiDiGraph graph; -}; - -struct V1Layer { - V1CompGraphOperatorAttrs attrs; - req> name; -}; -FF_VISITABLE_STRUCT(V1Layer, attrs, name); -V1Layer to_v1(Layer const &); - -using V1ComputationGraph = V1JsonableGraph; -FF_VISITABLE_STRUCT( - V1ComputationGraph, node_labels, outputs, output_labels, graph); +using V1ComputationGraph = V1JsonableGraph; CHECK_IS_JSONABLE(V1ComputationGraph); V1ComputationGraph to_v1(ComputationGraph const &); using V1ParallelComputationGraph = - V1JsonableGraph; -FF_VISITABLE_STRUCT( - V1ParallelComputationGraph, node_labels, outputs, output_labels, graph); + V1JsonableGraph; CHECK_IS_JSONABLE(V1ParallelComputationGraph); V1ParallelComputationGraph to_v1(ParallelComputationGraph const &); diff --git a/lib/pcg/include/pcg/file_format/v1/graphs/v1_graph_edge.dtg.h b/lib/pcg/include/pcg/file_format/v1/graphs/v1_graph_edge.dtg.h new file mode 100644 index 0000000000..e9238301d0 --- /dev/null +++ b/lib/pcg/include/pcg/file_format/v1/graphs/v1_graph_edge.dtg.h @@ -0,0 +1,60 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! 
+// If you would like to modify this datatype, instead modify
+// lib/pcg/include/pcg/file_format/v1/graphs/v1_graph_edge.struct.toml
+/* proj-data
+{
+  "generated_from": "865097b569b831af049343e933834329"
+}
+*/
+
+#ifndef _FLEXFLOW_LIB_PCG_INCLUDE_PCG_FILE_FORMAT_V1_GRAPHS_V1_GRAPH_EDGE_DTG_H
+#define _FLEXFLOW_LIB_PCG_INCLUDE_PCG_FILE_FORMAT_V1_GRAPHS_V1_GRAPH_EDGE_DTG_H
+
+#include "fmt/format.h"
+#include "nlohmann/json.hpp"
+#include <functional>
+#include <ostream>
+#include <string>
+
+namespace FlexFlow {
+struct V1GraphEdge {
+  V1GraphEdge() = delete;
+  V1GraphEdge(size_t const &srcNode,
+              size_t const &srcIdx,
+              size_t const &dstNode,
+              size_t const &dstIdx);
+
+  bool operator==(V1GraphEdge const &) const;
+  bool operator!=(V1GraphEdge const &) const;
+  bool operator<(V1GraphEdge const &) const;
+  bool operator>(V1GraphEdge const &) const;
+  bool operator<=(V1GraphEdge const &) const;
+  bool operator>=(V1GraphEdge const &) const;
+  size_t srcNode;
+  size_t srcIdx;
+  size_t dstNode;
+  size_t dstIdx;
+};
+} // namespace FlexFlow
+
+namespace std {
+template <>
+struct hash<FlexFlow::V1GraphEdge> {
+  size_t operator()(FlexFlow::V1GraphEdge const &) const;
+};
+} // namespace std
+
+namespace nlohmann {
+template <>
+struct adl_serializer<FlexFlow::V1GraphEdge> {
+  static FlexFlow::V1GraphEdge from_json(json const &);
+  static void to_json(json &, FlexFlow::V1GraphEdge const &);
+};
+} // namespace nlohmann
+
+namespace FlexFlow {
+std::string format_as(V1GraphEdge const &);
+std::ostream &operator<<(std::ostream &, V1GraphEdge const &);
+} // namespace FlexFlow
+
+#endif // _FLEXFLOW_LIB_PCG_INCLUDE_PCG_FILE_FORMAT_V1_GRAPHS_V1_GRAPH_EDGE_DTG_H
diff --git a/lib/pcg/include/pcg/file_format/v1/graphs/v1_graph_edge.struct.toml b/lib/pcg/include/pcg/file_format/v1/graphs/v1_graph_edge.struct.toml
new file mode 100644
index 0000000000..b0d2546977
--- /dev/null
+++ b/lib/pcg/include/pcg/file_format/v1/graphs/v1_graph_edge.struct.toml
@@ -0,0 +1,26 @@
+namespace = "FlexFlow"
+name = "V1GraphEdge"
+features = [
+  "eq",
+  "ord",
+  "hash",
+  "json",
+  # "rapidcheck",
+  "fmt",
+]
+
+[[fields]]
+name = "srcNode"
+type = "size_t"
+
+[[fields]]
+name = "srcIdx"
+type = "size_t"
+
+[[fields]]
+name = "dstNode"
+type = "size_t"
+
+[[fields]]
+name = "dstIdx"
+type = "size_t"
diff --git a/lib/pcg/include/pcg/file_format/v1/graphs/v1_graph_output.dtg.h b/lib/pcg/include/pcg/file_format/v1/graphs/v1_graph_output.dtg.h
new file mode 100644
index 0000000000..730282bdb9
--- /dev/null
+++ b/lib/pcg/include/pcg/file_format/v1/graphs/v1_graph_output.dtg.h
@@ -0,0 +1,55 @@
+// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT!
+// If you would like to modify this datatype, instead modify
+// lib/pcg/include/pcg/file_format/v1/graphs/v1_graph_output.struct.toml
+/* proj-data
+{
+  "generated_from": "05ff8401c3d976ea2220899edb8dfe3a"
+}
+*/
+
+#ifndef _FLEXFLOW_LIB_PCG_INCLUDE_PCG_FILE_FORMAT_V1_GRAPHS_V1_GRAPH_OUTPUT_DTG_H
+#define _FLEXFLOW_LIB_PCG_INCLUDE_PCG_FILE_FORMAT_V1_GRAPHS_V1_GRAPH_OUTPUT_DTG_H
+
+#include "fmt/format.h"
+#include "nlohmann/json.hpp"
+#include <functional>
+#include <ostream>
+#include <string>
+
+namespace FlexFlow {
+struct V1GraphOutput {
+  V1GraphOutput() = delete;
+  V1GraphOutput(size_t const &srcNode, size_t const &srcIdx);
+
+  bool operator==(V1GraphOutput const &) const;
+  bool operator!=(V1GraphOutput const &) const;
+  bool operator<(V1GraphOutput const &) const;
+  bool operator>(V1GraphOutput const &) const;
+  bool operator<=(V1GraphOutput const &) const;
+  bool operator>=(V1GraphOutput const &) const;
+  size_t srcNode;
+  size_t srcIdx;
+};
+} // namespace FlexFlow
+
+namespace std {
+template <>
+struct hash<FlexFlow::V1GraphOutput> {
+  size_t operator()(FlexFlow::V1GraphOutput const &) const;
+};
+} // namespace std
+
+namespace nlohmann {
+template <>
+struct adl_serializer<FlexFlow::V1GraphOutput> {
+  static FlexFlow::V1GraphOutput from_json(json const &);
+  static void to_json(json &, FlexFlow::V1GraphOutput const &);
+};
+} // namespace nlohmann
+
+namespace FlexFlow {
+std::string format_as(V1GraphOutput const &);
+std::ostream &operator<<(std::ostream &, V1GraphOutput const &);
+} // namespace FlexFlow
+
+#endif // _FLEXFLOW_LIB_PCG_INCLUDE_PCG_FILE_FORMAT_V1_GRAPHS_V1_GRAPH_OUTPUT_DTG_H
diff --git a/lib/pcg/include/pcg/file_format/v1/graphs/v1_graph_output.struct.toml b/lib/pcg/include/pcg/file_format/v1/graphs/v1_graph_output.struct.toml
new file mode 100644
index 0000000000..ba41f7e43f
--- /dev/null
+++ b/lib/pcg/include/pcg/file_format/v1/graphs/v1_graph_output.struct.toml
@@ -0,0 +1,18 @@
+namespace = "FlexFlow"
+name = "V1GraphOutput"
+features = [
+  "eq",
+  "ord",
+  "hash",
+  "json",
+  # "rapidcheck",
+  "fmt",
+]
+
+[[fields]]
+name = "srcNode"
+type = "size_t"
+
+[[fields]]
+name = "srcIdx"
+type = "size_t"
diff --git a/lib/pcg/include/pcg/file_format/v1/graphs/v1_jsonable_graph.dtg.h b/lib/pcg/include/pcg/file_format/v1/graphs/v1_jsonable_graph.dtg.h
new file mode 100644
index 0000000000..f183a14a9e
--- /dev/null
+++ b/lib/pcg/include/pcg/file_format/v1/graphs/v1_jsonable_graph.dtg.h
@@ -0,0 +1,109 @@
+// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT!
+// If you would like to modify this datatype, instead modify
+// lib/pcg/include/pcg/file_format/v1/graphs/v1_jsonable_graph.struct.toml
+/* proj-data
+{
+  "generated_from": "0595a9f5a6bc19f9a170cb0e42c4202d"
+}
+*/
+
+#ifndef _FLEXFLOW_LIB_PCG_INCLUDE_PCG_FILE_FORMAT_V1_GRAPHS_V1_JSONABLE_GRAPH_DTG_H
+#define _FLEXFLOW_LIB_PCG_INCLUDE_PCG_FILE_FORMAT_V1_GRAPHS_V1_JSONABLE_GRAPH_DTG_H
+
+#include "fmt/format.h"
+#include "nlohmann/json.hpp"
+#include "pcg/file_format/v1/graphs/v1_graph_output.dtg.h"
+#include "pcg/file_format/v1/graphs/v1_multidigraph.dtg.h"
+#include <sstream>
+#include <string>
+#include <unordered_map>
+
+namespace FlexFlow {
+template <typename NodeT, typename TensorT>
+struct V1JsonableGraph {
+  V1JsonableGraph() = delete;
+  V1JsonableGraph(
+      std::unordered_map<size_t, NodeT> const &node_labels,
+      std::unordered_map<size_t, V1GraphOutput> const &outputs,
+      std::unordered_map<size_t, TensorT> const &output_labels,
+      ::FlexFlow::V1MultiDiGraph const &graph);
+
+  std::unordered_map<size_t, NodeT> node_labels;
+  std::unordered_map<size_t, V1GraphOutput> outputs;
+  std::unordered_map<size_t, TensorT> output_labels;
+  ::FlexFlow::V1MultiDiGraph graph;
+};
+} // namespace FlexFlow
+
+namespace nlohmann {
+template <typename NodeT, typename TensorT>
+struct adl_serializer<FlexFlow::V1JsonableGraph<NodeT, TensorT>> {
+  static FlexFlow::V1JsonableGraph<NodeT, TensorT> from_json(json const &);
+  static void to_json(json &,
+                      FlexFlow::V1JsonableGraph<NodeT, TensorT> const &);
+};
+} // namespace nlohmann
+
+namespace FlexFlow {
+template <typename NodeT, typename TensorT>
+std::string format_as(V1JsonableGraph<NodeT, TensorT> const &);
+template <typename NodeT, typename TensorT>
+std::ostream &operator<<(std::ostream &,
+                         V1JsonableGraph<NodeT, TensorT> const &);
+} // namespace FlexFlow
+
+namespace FlexFlow {
+template <typename NodeT, typename TensorT>
+V1JsonableGraph<NodeT, TensorT>::V1JsonableGraph(
+    std::unordered_map<size_t, NodeT> const &node_labels,
+    std::unordered_map<size_t, V1GraphOutput> const &outputs,
+    std::unordered_map<size_t, TensorT> const &output_labels,
+    ::FlexFlow::V1MultiDiGraph const &graph)
+    : node_labels(node_labels), outputs(outputs), output_labels(output_labels),
+      graph(graph) {}
+} // namespace FlexFlow
+
+namespace nlohmann {
+template <typename NodeT, typename TensorT>
+FlexFlow::V1JsonableGraph<NodeT, TensorT>
+    adl_serializer<FlexFlow::V1JsonableGraph<NodeT, TensorT>>::from_json(
+        json const &j) {
+  return {
+      j.at("node_labels").template get<std::unordered_map<size_t, NodeT>>(),
+      j.at("outputs")
+          .template get<
+              std::unordered_map<size_t, FlexFlow::V1GraphOutput>>(),
+      j.at("output_labels").template get<std::unordered_map<size_t, TensorT>>(),
+      j.at("graph").template get<::FlexFlow::V1MultiDiGraph>()};
+}
+template <typename NodeT, typename TensorT>
+void adl_serializer<FlexFlow::V1JsonableGraph<NodeT, TensorT>>::to_json(
+    json &j, FlexFlow::V1JsonableGraph<NodeT, TensorT> const &v) {
+  j["__type"] = "V1JsonableGraph";
+  j["node_labels"] = v.node_labels;
+  j["outputs"] = v.outputs;
+  j["output_labels"] = v.output_labels;
+  j["graph"] = v.graph;
+}
+} // namespace nlohmann
+
+namespace FlexFlow {
+template <typename NodeT, typename TensorT>
+std::string format_as(V1JsonableGraph<NodeT, TensorT> const &x) {
+  std::ostringstream oss;
+  oss << "<V1JsonableGraph node_labels=" << x.node_labels << " outputs=" << x.outputs << " output_labels=" << x.output_labels << " graph=" << x.graph << ">";
+  return oss.str();
+}
+template <typename NodeT, typename TensorT>
+std::ostream &operator<<(std::ostream &s,
+                         V1JsonableGraph<NodeT, TensorT> const &x) {
+  return s << fmt::to_string(x);
+}
+} // namespace FlexFlow
+
+#endif // _FLEXFLOW_LIB_PCG_INCLUDE_PCG_FILE_FORMAT_V1_GRAPHS_V1_JSONABLE_GRAPH_DTG_H
diff --git a/lib/pcg/include/pcg/file_format/v1/graphs/v1_jsonable_graph.struct.toml b/lib/pcg/include/pcg/file_format/v1/graphs/v1_jsonable_graph.struct.toml
new file mode 100644
index 0000000000..ad9ba21c60
--- /dev/null
+++ b/lib/pcg/include/pcg/file_format/v1/graphs/v1_jsonable_graph.struct.toml
@@ -0,0 +1,41 @@
+namespace = "FlexFlow"
+name = "V1JsonableGraph"
+features = [
+  # "eq",
+  # "ord",
+  # "hash",
+  "json",
+  # "rapidcheck",
+  "fmt",
+]
+
+template_params = [
+  "NodeT",
+  "TensorT",
+]
+
+includes = [
+  "<unordered_map>",
+  "pcg/file_format/v1/graphs/v1_multidigraph.dtg.h",
+  "pcg/file_format/v1/graphs/v1_graph_output.dtg.h",
+]
+
+[[fields]]
+name = "node_labels"
+type = "std::unordered_map<size_t, NodeT>"
+
+[[fields]]
+name = "outputs"
+type = "std::unordered_map<size_t, ::FlexFlow::V1GraphOutput>"
+
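+# Note: node and tensor ids here are the flat size_t indices used by
+# V1MultiDiGraph, so `outputs` maps each tensor id to the V1GraphOutput
+# (source node id, source output slot) that produced it.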
+[[fields]]
+name = "output_labels"
+type = "std::unordered_map<size_t, TensorT>"
+
+[[fields]]
+name = "graph"
+type = "::FlexFlow::V1MultiDiGraph"
+
diff --git a/lib/pcg/include/pcg/file_format/v1/graphs/v1_multidigraph.dtg.h b/lib/pcg/include/pcg/file_format/v1/graphs/v1_multidigraph.dtg.h
new file mode 100644
index 0000000000..5d7edcf1d8
--- /dev/null
+++ b/lib/pcg/include/pcg/file_format/v1/graphs/v1_multidigraph.dtg.h
@@ -0,0 +1,47 @@
+// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT!
+// If you would like to modify this datatype, instead modify
+// lib/pcg/include/pcg/file_format/v1/graphs/v1_multidigraph.struct.toml
+/* proj-data
+{
+  "generated_from": "fb1033385645e54a19c9b44cef0be04b"
+}
+*/
+
+#ifndef _FLEXFLOW_LIB_PCG_INCLUDE_PCG_FILE_FORMAT_V1_GRAPHS_V1_MULTIDIGRAPH_DTG_H
+#define _FLEXFLOW_LIB_PCG_INCLUDE_PCG_FILE_FORMAT_V1_GRAPHS_V1_MULTIDIGRAPH_DTG_H
+
+#include "fmt/format.h"
+#include "nlohmann/json.hpp"
+#include "pcg/file_format/v1/graphs/v1_graph_edge.dtg.h"
+#include "utils/fmt.h"
+#include <ostream>
+#include <unordered_set>
+#include <vector>
+
+namespace FlexFlow {
+struct V1MultiDiGraph {
+  V1MultiDiGraph() = delete;
+  V1MultiDiGraph(std::vector<size_t> const &nodes,
+                 std::vector<size_t> const &ports,
+                 std::unordered_set<::FlexFlow::V1GraphEdge> const &edges);
+
+  std::vector<size_t> nodes;
+  std::vector<size_t> ports;
+  std::unordered_set<::FlexFlow::V1GraphEdge> edges;
+};
+} // namespace FlexFlow
+
+namespace nlohmann {
+template <>
+struct adl_serializer<FlexFlow::V1MultiDiGraph> {
+  static FlexFlow::V1MultiDiGraph from_json(json const &);
+  static void to_json(json &, FlexFlow::V1MultiDiGraph const &);
+};
+} // namespace nlohmann
+
+namespace FlexFlow {
+std::string format_as(V1MultiDiGraph const &);
+std::ostream &operator<<(std::ostream &, V1MultiDiGraph const &);
+} // namespace FlexFlow
+
+#endif // _FLEXFLOW_LIB_PCG_INCLUDE_PCG_FILE_FORMAT_V1_GRAPHS_V1_MULTIDIGRAPH_DTG_H
diff --git a/lib/pcg/include/pcg/file_format/v1/graphs/v1_multidigraph.h b/lib/pcg/include/pcg/file_format/v1/graphs/v1_multidigraph.h
new file mode 100644
index 0000000000..49ff850a29
--- /dev/null
+++ b/lib/pcg/include/pcg/file_format/v1/graphs/v1_multidigraph.h
@@ -0,0 +1,16 @@
+#ifndef _FLEXFLOW_LIB_PCG_INCLUDE_PCG_FILE_FORMAT_V1_GRAPHS_V1_MULTIDIGRAPH_H
+#define _FLEXFLOW_LIB_PCG_INCLUDE_PCG_FILE_FORMAT_V1_GRAPHS_V1_MULTIDIGRAPH_H
+
+#include "pcg/file_format/v1/graphs/v1_multidigraph.dtg.h"
+#include "utils/graph.h"
+
+namespace FlexFlow {
+
+V1MultiDiGraph to_v1(MultiDiGraphView const &);
+V1MultiDiGraph to_v1(MultiDiGraphView const &,
+                     std::unordered_map<Node, size_t> const &,
+                     std::unordered_map<NodePort, size_t> const &);
+
+} // namespace FlexFlow
+
+#endif
diff --git a/lib/pcg/include/pcg/file_format/v1/graphs/v1_multidigraph.struct.toml b/lib/pcg/include/pcg/file_format/v1/graphs/v1_multidigraph.struct.toml
new file mode 100644
index 0000000000..9650f3bd43
--- /dev/null
+++ b/lib/pcg/include/pcg/file_format/v1/graphs/v1_multidigraph.struct.toml
@@ -0,0 +1,29 @@
+namespace = "FlexFlow"
+name = "V1MultiDiGraph"
+features = [
+  # "eq",
+  # "ord",
+  # "hash",
+  "json",
+  # "rapidcheck",
+  "fmt",
+]
+
+includes = [
+  "<vector>",
+  "<unordered_set>",
+  "pcg/file_format/v1/graphs/v1_graph_edge.dtg.h",
+  "utils/fmt.h",
+]
+
+[[fields]]
+name = "nodes"
+type = "std::vector<size_t>"
+
+[[fields]]
+name = "ports"
+type = "std::vector<size_t>"
+
+[[fields]]
+name = "edges"
+type = "std::unordered_set<::FlexFlow::V1GraphEdge>"
diff --git a/lib/pcg/include/pcg/file_format/v1/graphs/v1_operator_graph.dtg.h b/lib/pcg/include/pcg/file_format/v1/graphs/v1_operator_graph.dtg.h
new file mode 100644
index 0000000000..7e5554d44a
--- /dev/null
+++ 
b/lib/pcg/include/pcg/file_format/v1/graphs/v1_operator_graph.dtg.h @@ -0,0 +1,45 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/pcg/include/pcg/file_format/v1/graphs/v1_operator_graph.struct.toml +/* proj-data +{ + "generated_from": "5bfd7d8755cfd8cd9dbf57d5c367038e" +} +*/ + +#ifndef _FLEXFLOW_LIB_PCG_INCLUDE_PCG_FILE_FORMAT_V1_GRAPHS_V1_OPERATOR_GRAPH_DTG_H +#define _FLEXFLOW_LIB_PCG_INCLUDE_PCG_FILE_FORMAT_V1_GRAPHS_V1_OPERATOR_GRAPH_DTG_H + +#include "fmt/format.h" +#include "nlohmann/json.hpp" +#include "pcg/file_format/v1/graphs/v1_graph_edge.dtg.h" +#include "utils/fmt.h" +#include +#include +#include + +namespace FlexFlow { +struct V1OperatorGraph { + V1OperatorGraph() = delete; + V1OperatorGraph(std::vector const &nodes, + std::unordered_set<::FlexFlow::V1GraphEdge> const &edges); + + std::vector nodes; + std::unordered_set<::FlexFlow::V1GraphEdge> edges; +}; +} // namespace FlexFlow + +namespace nlohmann { +template <> +struct adl_serializer { + static FlexFlow::V1OperatorGraph from_json(json const &); + static void to_json(json &, FlexFlow::V1OperatorGraph const &); +}; +} // namespace nlohmann + +namespace FlexFlow { +std::string format_as(V1OperatorGraph const &); +std::ostream &operator<<(std::ostream &, V1OperatorGraph const &); +} // namespace FlexFlow + +#endif // _FLEXFLOW_LIB_PCG_INCLUDE_PCG_FILE_FORMAT_V1_GRAPHS_V1_OPERATOR_GRAPH_DTG_H diff --git a/lib/pcg/include/pcg/file_format/v1/graphs/v1_operator_graph.struct.toml b/lib/pcg/include/pcg/file_format/v1/graphs/v1_operator_graph.struct.toml new file mode 100644 index 0000000000..61dc45ae2e --- /dev/null +++ b/lib/pcg/include/pcg/file_format/v1/graphs/v1_operator_graph.struct.toml @@ -0,0 +1,25 @@ +namespace = "FlexFlow" +name = "V1OperatorGraph" +features = [ + # "eq", + # "ord", + # "hash", + "json", + # "rapidcheck", + "fmt", +] + +includes = [ + "", + "", + "pcg/file_format/v1/graphs/v1_graph_edge.dtg.h", + "utils/fmt.h", +] + +[[fields]] +name = "nodes" +type = "std::vector" + +[[fields]] +name = "edges" +type = "std::unordered_set<::FlexFlow::V1GraphEdge>" diff --git a/lib/pcg/include/pcg/file_format/v1/initializer.h b/lib/pcg/include/pcg/file_format/v1/initializer.h deleted file mode 100644 index 21af7d55e0..0000000000 --- a/lib/pcg/include/pcg/file_format/v1/initializer.h +++ /dev/null @@ -1,57 +0,0 @@ -#ifndef _FLEXFLOW_PCG_INCLUDE_PCG_FILE_FORMAT_V1_INITIALIZER_H -#define _FLEXFLOW_PCG_INCLUDE_PCG_FILE_FORMAT_V1_INITIALIZER_H - -#include "data_type.h" -#include "utils/json.h" -#include "utils/required.h" -#include "utils/variant.h" -#include "utils/visitable.h" -#include "visit_struct/visit_struct_intrusive.hpp" - -namespace FlexFlow { - -struct V1GlorotInitializer { - req seed; -}; -FF_VISITABLE_STRUCT(V1GlorotInitializer, seed); - -struct V1ZeroInitializer {}; -FF_VISITABLE_STRUCT(V1ZeroInitializer); - -struct V1UniformInitializer { - int seed; - float min_val; - req max_val; -}; -FF_VISITABLE_STRUCT(V1UniformInitializer, seed, min_val, max_val); - -struct V1NormInitializer { - int seed; - float mean; - req stddev; -}; -FF_VISITABLE_STRUCT(V1NormInitializer, seed, mean, stddev); - -struct V1ConstantInitializer { - req value; -}; -FF_VISITABLE_STRUCT(V1ConstantInitializer, value); - -using V1Initializer = std::variant; - -} // namespace FlexFlow - -namespace FlexFlow { -CHECK_IS_JSONABLE(V1GlorotInitializer); -CHECK_IS_JSONABLE(V1ZeroInitializer); -CHECK_IS_JSONABLE(V1UniformInitializer); 
-CHECK_IS_JSONABLE(V1NormInitializer); -CHECK_IS_JSONABLE(V1ConstantInitializer); -CHECK_IS_JSONABLE(V1Initializer); -} // namespace FlexFlow - -#endif diff --git a/lib/pcg/include/pcg/file_format/v1/operator_attrs.h b/lib/pcg/include/pcg/file_format/v1/operator_attrs.h deleted file mode 100644 index 2830fbd301..0000000000 --- a/lib/pcg/include/pcg/file_format/v1/operator_attrs.h +++ /dev/null @@ -1,20 +0,0 @@ -#ifndef _FLEXFLOW_PCG_INCLUDE_PCG_FILE_FORMAT_V1_OPERATOR_ATTRS_H -#define _FLEXFLOW_PCG_INCLUDE_PCG_FILE_FORMAT_V1_OPERATOR_ATTRS_H - -#include "utils/json.h" -#include - -namespace FlexFlow { - -struct V1Conv2DAttrs {}; -FF_VISITABLE_STRUCT(V1Conv2DAttrs); - -static_assert( - std::is_same, std::tuple<>>::value, ""); - -using V1CompGraphOperatorAttrs = std::variant; -using V1PCGOperatorAttrs = std::variant; - -} // namespace FlexFlow - -#endif diff --git a/lib/pcg/include/pcg/file_format/v1/parallel_tensor.h b/lib/pcg/include/pcg/file_format/v1/parallel_tensor.h deleted file mode 100644 index c215569b21..0000000000 --- a/lib/pcg/include/pcg/file_format/v1/parallel_tensor.h +++ /dev/null @@ -1,37 +0,0 @@ -#ifndef _FLEXFLOW_PCG_INCLUDE_PCG_FILE_FORMAT_V1_PARALLEL_TENSOR_H -#define _FLEXFLOW_PCG_INCLUDE_PCG_FILE_FORMAT_V1_PARALLEL_TENSOR_H - -#include "data_type.h" -#include "initializer.h" -#include "param_sync.h" -#include "utils/json.h" -#include "utils/variant.h" -#include "utils/visitable.h" - -namespace FlexFlow { - -struct V1ParallelDim { - size_t size; - int degree; - req is_replica_dim; -}; -FF_VISITABLE_STRUCT(V1ParallelDim, size, degree, is_replica_dim); - -struct V1ParallelTensorShape { - std::vector dims; - req data_type; -}; -FF_VISITABLE_STRUCT(V1ParallelTensorShape, dims, data_type); - -struct V1ParallelTensor { - V1ParallelTensorShape shape; - std::optional sync_type; - std::optional initializer; - req create_grad; -}; -FF_VISITABLE_STRUCT( - V1ParallelTensor, shape, sync_type, initializer, create_grad); - -} // namespace FlexFlow - -#endif diff --git a/lib/pcg/include/pcg/file_format/v1/param_sync.h b/lib/pcg/include/pcg/file_format/v1/param_sync.h deleted file mode 100644 index 32769a8d20..0000000000 --- a/lib/pcg/include/pcg/file_format/v1/param_sync.h +++ /dev/null @@ -1,16 +0,0 @@ -#ifndef _FLEXFLOW_PCG_FILE_FORMAT_V1_PARAM_SYNC_H -#define _FLEXFLOW_PCG_FILE_FORMAT_V1_PARAM_SYNC_H - -#include "utils/json.h" - -namespace FlexFlow { - -enum class V1ParamSync { PARAM_SERVER, NCCL }; - -NLOHMANN_JSON_SERIALIZE_ENUM(V1ParamSync, - {{V1ParamSync::PARAM_SERVER, "PARAM_SERVER"}, - {V1ParamSync::NCCL, "NCCL"}}); - -} // namespace FlexFlow - -#endif diff --git a/lib/pcg/include/pcg/file_format/v1/tensor.h b/lib/pcg/include/pcg/file_format/v1/tensor.h deleted file mode 100644 index c304a41401..0000000000 --- a/lib/pcg/include/pcg/file_format/v1/tensor.h +++ /dev/null @@ -1,36 +0,0 @@ -#ifndef _FLEXFLOW_PCG_INCLUDE_PCG_FILE_FORMAT_V1_TENSOR_H -#define _FLEXFLOW_PCG_INCLUDE_PCG_FILE_FORMAT_V1_TENSOR_H - -#include "data_type.h" -#include "initializer.h" -#include "op-attrs/tensor_shape.h" -#include "param_sync.h" -#include "pcg/tensor.h" -#include "utils/visitable.h" -#include - -namespace FlexFlow { - -struct V1TensorShape { - std::vector dims; - req data_type; -}; -FF_VISITABLE_STRUCT(V1TensorShape, dims, data_type); -CHECK_IS_JSONABLE(V1TensorShape); -V1TensorShape to_v1(TensorShape const &); - -struct V1Tensor { - V1TensorShape shape; - std::optional initializer; - bool create_gradients; - std::optional sync_type; - req> name; -}; -FF_VISITABLE_STRUCT( - V1Tensor, shape, 
initializer, create_gradients, sync_type, name); -CHECK_IS_JSONABLE(V1Tensor); -V1Tensor to_v1(Tensor const &); - -} // namespace FlexFlow - -#endif diff --git a/lib/pcg/include/pcg/gpu_id_t.dtg.h b/lib/pcg/include/pcg/gpu_id_t.dtg.h new file mode 100644 index 0000000000..f0847848ca --- /dev/null +++ b/lib/pcg/include/pcg/gpu_id_t.dtg.h @@ -0,0 +1,62 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/pcg/include/pcg/gpu_id_t.struct.toml +/* proj-data +{ + "generated_from": "022355e43f43141d332be50ea3080ee2" +} +*/ + +#ifndef _FLEXFLOW_LIB_PCG_INCLUDE_PCG_GPU_ID_T_DTG_H +#define _FLEXFLOW_LIB_PCG_INCLUDE_PCG_GPU_ID_T_DTG_H + +#include "fmt/format.h" +#include "nlohmann/json.hpp" +#include "rapidcheck.h" +#include +#include +#include + +namespace FlexFlow { +struct gpu_id_t { + gpu_id_t() = delete; + gpu_id_t(int const &gpu_index); + + bool operator==(gpu_id_t const &) const; + bool operator!=(gpu_id_t const &) const; + bool operator<(gpu_id_t const &) const; + bool operator>(gpu_id_t const &) const; + bool operator<=(gpu_id_t const &) const; + bool operator>=(gpu_id_t const &) const; + int gpu_index; +}; +} // namespace FlexFlow + +namespace std { +template <> +struct hash { + size_t operator()(FlexFlow::gpu_id_t const &) const; +}; +} // namespace std + +namespace nlohmann { +template <> +struct adl_serializer { + static FlexFlow::gpu_id_t from_json(json const &); + static void to_json(json &, FlexFlow::gpu_id_t const &); +}; +} // namespace nlohmann + +namespace rc { +template <> +struct Arbitrary { + static Gen arbitrary(); +}; +} // namespace rc + +namespace FlexFlow { +std::string format_as(gpu_id_t const &); +std::ostream &operator<<(std::ostream &, gpu_id_t const &); +} // namespace FlexFlow + +#endif // _FLEXFLOW_LIB_PCG_INCLUDE_PCG_GPU_ID_T_DTG_H diff --git a/lib/pcg/include/pcg/gpu_id_t.struct.toml b/lib/pcg/include/pcg/gpu_id_t.struct.toml new file mode 100644 index 0000000000..170dbb96fa --- /dev/null +++ b/lib/pcg/include/pcg/gpu_id_t.struct.toml @@ -0,0 +1,14 @@ +namespace = "FlexFlow" +name = "gpu_id_t" +features = [ + "eq", + "ord", + "hash", + "json", + "rapidcheck", + "fmt", +] + +[[fields]] +name = "gpu_index" +type = "int" diff --git a/lib/pcg/include/pcg/initializer.h b/lib/pcg/include/pcg/initializer.h deleted file mode 100644 index 6913289653..0000000000 --- a/lib/pcg/include/pcg/initializer.h +++ /dev/null @@ -1,50 +0,0 @@ -#ifndef _FLEXFLOW_PCG_INCLUDE_PCG_INITIALIZER_H -#define _FLEXFLOW_PCG_INCLUDE_PCG_INITIALIZER_H - -#include "op-attrs/datatype.h" -#include "utils/required.h" -#include "utils/visitable.h" - -namespace FlexFlow { - -struct GlorotUniform { - req seed; - /* float scale; */ - /* DataType data_type; */ -}; -FF_VISITABLE_STRUCT(GlorotUniform, seed); - -struct ZeroInitializer { - ZeroInitializer() = default; -}; -FF_VISITABLE_STRUCT(ZeroInitializer); - -struct UniformInitializer { - int seed; - float min_val; - req max_val; -}; -FF_VISITABLE_STRUCT(UniformInitializer, seed, min_val, max_val); - -struct NormInitializer { - int seed; - float mean; - req stddev; -}; -FF_VISITABLE_STRUCT(NormInitializer, seed, mean, stddev); - -struct ConstantInitializer { - req value; -}; -FF_VISITABLE_STRUCT(ConstantInitializer, value); - -using Initializer = std::variant; -CHECK_WELL_BEHAVED_VALUE_TYPE(Initializer); - -} // namespace FlexFlow - -#endif diff --git a/lib/pcg/include/pcg/initializer_attrs.dtg.h b/lib/pcg/include/pcg/initializer_attrs.dtg.h new file mode 100644 index 
0000000000..7f5a470a90
--- /dev/null
+++ b/lib/pcg/include/pcg/initializer_attrs.dtg.h
@@ -0,0 +1,169 @@
+// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT!
+// If you would like to modify this datatype, instead modify
+// lib/pcg/include/pcg/initializer_attrs.variant.toml
+/* proj-data
+{
+  "generated_from": "f66f3a89ea937e96a058d83ab52e2826"
+}
+*/
+
+#ifndef _FLEXFLOW_LIB_PCG_INCLUDE_PCG_INITIALIZER_ATTRS_DTG_H
+#define _FLEXFLOW_LIB_PCG_INCLUDE_PCG_INITIALIZER_ATTRS_DTG_H
+
+#include "fmt/format.h"
+#include "nlohmann/json.hpp"
+#include "pcg/initializers/constant_initializer_attrs.dtg.h"
+#include "pcg/initializers/glorot_uniform_attrs.dtg.h"
+#include "pcg/initializers/norm_initializer_attrs.dtg.h"
+#include "pcg/initializers/uniform_initializer_attrs.dtg.h"
+#include "pcg/initializers/zero_initializer_attrs.dtg.h"
+#include <functional>
+#include <ostream>
+#include <string>
+#include <type_traits>
+#include <variant>
+
+namespace FlexFlow {
+struct InitializerAttrs {
+  InitializerAttrs() = delete;
+  explicit InitializerAttrs(::FlexFlow::GlorotUniformAttrs const &);
+  explicit InitializerAttrs(::FlexFlow::ZeroInitializerAttrs const &);
+  explicit InitializerAttrs(::FlexFlow::UniformInitializerAttrs const &);
+  explicit InitializerAttrs(::FlexFlow::NormInitializerAttrs const &);
+  explicit InitializerAttrs(::FlexFlow::ConstantInitializerAttrs const &);
+  template <typename T>
+  static constexpr bool IsPartOfInitializerAttrs_v =
+      std::is_same_v<T, ::FlexFlow::GlorotUniformAttrs> ||
+      std::is_same_v<T, ::FlexFlow::ZeroInitializerAttrs> ||
+      std::is_same_v<T, ::FlexFlow::UniformInitializerAttrs> ||
+      std::is_same_v<T, ::FlexFlow::NormInitializerAttrs> ||
+      std::is_same_v<T, ::FlexFlow::ConstantInitializerAttrs>;
+  template <typename ReturnType, typename Visitor>
+  ReturnType visit(Visitor &&v) const {
+    switch (this->index()) {
+      case 0: {
+        ReturnType result = v(this->get<::FlexFlow::GlorotUniformAttrs>());
+        return result;
+      }
+      case 1: {
+        ReturnType result = v(this->get<::FlexFlow::ZeroInitializerAttrs>());
+        return result;
+      }
+      case 2: {
+        ReturnType result = v(this->get<::FlexFlow::UniformInitializerAttrs>());
+        return result;
+      }
+      case 3: {
+        ReturnType result = v(this->get<::FlexFlow::NormInitializerAttrs>());
+        return result;
+      }
+      case 4: {
+        ReturnType result =
+            v(this->get<::FlexFlow::ConstantInitializerAttrs>());
+        return result;
+      }
+      default: {
+        throw std::runtime_error(fmt::format(
+            "Unknown index {} for type InitializerAttrs", this->index()));
+      }
+    }
+  }
+  template <typename ReturnType, typename Visitor>
+  ReturnType visit(Visitor &&v) {
+    switch (this->index()) {
+      case 0: {
+        ReturnType result = v(this->get<::FlexFlow::GlorotUniformAttrs>());
+        return result;
+      }
+      case 1: {
+        ReturnType result = v(this->get<::FlexFlow::ZeroInitializerAttrs>());
+        return result;
+      }
+      case 2: {
+        ReturnType result = v(this->get<::FlexFlow::UniformInitializerAttrs>());
+        return result;
+      }
+      case 3: {
+        ReturnType result = v(this->get<::FlexFlow::NormInitializerAttrs>());
+        return result;
+      }
+      case 4: {
+        ReturnType result =
+            v(this->get<::FlexFlow::ConstantInitializerAttrs>());
+        return result;
+      }
+      default: {
+        throw std::runtime_error(fmt::format(
+            "Unknown index {} for type InitializerAttrs", this->index()));
+      }
+    }
+  }
+  template <typename T>
+  bool has() const {
+    static_assert(
+        IsPartOfInitializerAttrs_v<T>,
+        "InitializerAttrs::has() expected one of "
+        "[::FlexFlow::GlorotUniformAttrs, ::FlexFlow::ZeroInitializerAttrs, "
+        "::FlexFlow::UniformInitializerAttrs, "
+        "::FlexFlow::NormInitializerAttrs, "
+        "::FlexFlow::ConstantInitializerAttrs], received T");
+    return std::holds_alternative<T>(this->raw_variant);
+  }
+  template <typename T>
+  T const &get() const {
+    static_assert(
+        IsPartOfInitializerAttrs_v<T>,
+        "InitializerAttrs::get() expected one of "
+        "[::FlexFlow::GlorotUniformAttrs, ::FlexFlow::ZeroInitializerAttrs, "
+        "::FlexFlow::UniformInitializerAttrs, "
+        "::FlexFlow::NormInitializerAttrs, "
+        "::FlexFlow::ConstantInitializerAttrs], received T");
+    return std::get<T>(this->raw_variant);
+  }
+  template <typename T>
+  T &get() {
+    static_assert(
+        IsPartOfInitializerAttrs_v<T>,
+        "InitializerAttrs::get() expected one of "
+        "[::FlexFlow::GlorotUniformAttrs, ::FlexFlow::ZeroInitializerAttrs, "
+        "::FlexFlow::UniformInitializerAttrs, "
+        "::FlexFlow::NormInitializerAttrs, "
+        "::FlexFlow::ConstantInitializerAttrs], received T");
+    return std::get<T>(this->raw_variant);
+  }
+  size_t index() const {
+    return this->raw_variant.index();
+  }
+  bool operator==(InitializerAttrs const &) const;
+  bool operator!=(InitializerAttrs const &) const;
+  bool operator<(InitializerAttrs const &) const;
+  bool operator>(InitializerAttrs const &) const;
+  bool operator<=(InitializerAttrs const &) const;
+  bool operator>=(InitializerAttrs const &) const;
+  std::variant<::FlexFlow::GlorotUniformAttrs,
+               ::FlexFlow::ZeroInitializerAttrs,
+               ::FlexFlow::UniformInitializerAttrs,
+               ::FlexFlow::NormInitializerAttrs,
+               ::FlexFlow::ConstantInitializerAttrs>
+      raw_variant;
+};
+} // namespace FlexFlow
+namespace std {
+template <>
+struct hash<::FlexFlow::InitializerAttrs> {
+  size_t operator()(::FlexFlow::InitializerAttrs const &) const;
+};
+} // namespace std
+namespace nlohmann {
+template <>
+struct adl_serializer<::FlexFlow::InitializerAttrs> {
+  static ::FlexFlow::InitializerAttrs from_json(json const &);
+  static void to_json(json &, ::FlexFlow::InitializerAttrs const &);
+};
+} // namespace nlohmann
+namespace FlexFlow {
+std::string format_as(::FlexFlow::InitializerAttrs const &);
+std::ostream &operator<<(std::ostream &, ::FlexFlow::InitializerAttrs const &);
+} // namespace FlexFlow
+
+#endif // _FLEXFLOW_LIB_PCG_INCLUDE_PCG_INITIALIZER_ATTRS_DTG_H
diff --git a/lib/pcg/include/pcg/initializer_attrs.variant.toml b/lib/pcg/include/pcg/initializer_attrs.variant.toml
new file mode 100644
index 0000000000..14a5cfdcac
--- /dev/null
+++ b/lib/pcg/include/pcg/initializer_attrs.variant.toml
@@ -0,0 +1,37 @@
+namespace = "FlexFlow"
+name = "InitializerAttrs"
+features = [
+  "eq",
+  "ord",
+  "hash",
+  "json",
+  "fmt",
+]
+
+includes = [
+  "pcg/initializers/glorot_uniform_attrs.dtg.h",
+  "pcg/initializers/zero_initializer_attrs.dtg.h",
+  "pcg/initializers/uniform_initializer_attrs.dtg.h",
+  "pcg/initializers/norm_initializer_attrs.dtg.h",
+  "pcg/initializers/constant_initializer_attrs.dtg.h",
+]
+
+[[values]]
+type = "::FlexFlow::GlorotUniformAttrs"
+key = "glorot_uniform"
+
+[[values]]
+type = "::FlexFlow::ZeroInitializerAttrs"
+key = "zero"
+
+[[values]]
+type = "::FlexFlow::UniformInitializerAttrs"
+key = "uniform"
+
+[[values]]
+type = "::FlexFlow::NormInitializerAttrs"
+key = "normal"
+
+[[values]]
+type = "::FlexFlow::ConstantInitializerAttrs"
+key = "constant"
diff --git a/lib/pcg/include/pcg/initializers/constant_initializer_attrs.dtg.h b/lib/pcg/include/pcg/initializers/constant_initializer_attrs.dtg.h
new file mode 100644
index 0000000000..1eb9eb8834
--- /dev/null
+++ b/lib/pcg/include/pcg/initializers/constant_initializer_attrs.dtg.h
@@ -0,0 +1,56 @@
+// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT!
+// If you would like to modify this datatype, instead modify
+// lib/pcg/include/pcg/initializers/constant_initializer_attrs.struct.toml
+/* proj-data
+{
+  "generated_from": "0162b9c49fe6cbfc65410c6fa8dec427"
+}
+*/
+
+#ifndef _FLEXFLOW_LIB_PCG_INCLUDE_PCG_INITIALIZERS_CONSTANT_INITIALIZER_ATTRS_DTG_H
+#define _FLEXFLOW_LIB_PCG_INCLUDE_PCG_INITIALIZERS_CONSTANT_INITIALIZER_ATTRS_DTG_H
+
+#include "fmt/format.h"
+#include "nlohmann/json.hpp"
+#include "op-attrs/datatype.h"
+#include "utils/json.h"
+#include <functional>
+#include <ostream>
+#include <string>
+
+namespace FlexFlow {
+struct ConstantInitializerAttrs {
+  ConstantInitializerAttrs() = delete;
+  ConstantInitializerAttrs(::FlexFlow::DataTypeValue const &value);
+
+  bool operator==(ConstantInitializerAttrs const &) const;
+  bool operator!=(ConstantInitializerAttrs const &) const;
+  bool operator<(ConstantInitializerAttrs const &) const;
+  bool operator>(ConstantInitializerAttrs const &) const;
+  bool operator<=(ConstantInitializerAttrs const &) const;
+  bool operator>=(ConstantInitializerAttrs const &) const;
+  ::FlexFlow::DataTypeValue value;
+};
+} // namespace FlexFlow
+
+namespace std {
+template <>
+struct hash<FlexFlow::ConstantInitializerAttrs> {
+  size_t operator()(FlexFlow::ConstantInitializerAttrs const &) const;
+};
+} // namespace std
+
+namespace nlohmann {
+template <>
+struct adl_serializer<FlexFlow::ConstantInitializerAttrs> {
+  static FlexFlow::ConstantInitializerAttrs from_json(json const &);
+  static void to_json(json &, FlexFlow::ConstantInitializerAttrs const &);
+};
+} // namespace nlohmann
+
+namespace FlexFlow {
+std::string format_as(ConstantInitializerAttrs const &);
+std::ostream &operator<<(std::ostream &, ConstantInitializerAttrs const &);
+} // namespace FlexFlow
+
+#endif // _FLEXFLOW_LIB_PCG_INCLUDE_PCG_INITIALIZERS_CONSTANT_INITIALIZER_ATTRS_DTG_H
diff --git a/lib/pcg/include/pcg/initializers/constant_initializer_attrs.struct.toml b/lib/pcg/include/pcg/initializers/constant_initializer_attrs.struct.toml
new file mode 100644
index 0000000000..3a80559d7b
--- /dev/null
+++ b/lib/pcg/include/pcg/initializers/constant_initializer_attrs.struct.toml
@@ -0,0 +1,19 @@
+namespace = "FlexFlow"
+name = "ConstantInitializerAttrs"
+features = [
+  "eq",
+  "ord",
+  "hash",
+  "json",
+  # "rapidcheck",
+  "fmt",
+]
+
+includes = [
+  "op-attrs/datatype.h",
+  "utils/json.h",
+]
+
+[[fields]]
+name = "value"
+type = "::FlexFlow::DataTypeValue"
diff --git a/lib/pcg/include/pcg/initializers/glorot_uniform_attrs.dtg.h b/lib/pcg/include/pcg/initializers/glorot_uniform_attrs.dtg.h
new file mode 100644
index 0000000000..04851fb333
--- /dev/null
+++ b/lib/pcg/include/pcg/initializers/glorot_uniform_attrs.dtg.h
@@ -0,0 +1,62 @@
+// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT!
+// If you would like to modify this datatype, instead modify +// lib/pcg/include/pcg/initializers/glorot_uniform_attrs.struct.toml +/* proj-data +{ + "generated_from": "a268b411b6d378faa11e60c8517d7be5" +} +*/ + +#ifndef _FLEXFLOW_LIB_PCG_INCLUDE_PCG_INITIALIZERS_GLOROT_UNIFORM_ATTRS_DTG_H +#define _FLEXFLOW_LIB_PCG_INCLUDE_PCG_INITIALIZERS_GLOROT_UNIFORM_ATTRS_DTG_H + +#include "fmt/format.h" +#include "nlohmann/json.hpp" +#include "rapidcheck.h" +#include +#include +#include + +namespace FlexFlow { +struct GlorotUniformAttrs { + GlorotUniformAttrs() = delete; + GlorotUniformAttrs(int const &seed); + + bool operator==(GlorotUniformAttrs const &) const; + bool operator!=(GlorotUniformAttrs const &) const; + bool operator<(GlorotUniformAttrs const &) const; + bool operator>(GlorotUniformAttrs const &) const; + bool operator<=(GlorotUniformAttrs const &) const; + bool operator>=(GlorotUniformAttrs const &) const; + int seed; +}; +} // namespace FlexFlow + +namespace std { +template <> +struct hash { + size_t operator()(FlexFlow::GlorotUniformAttrs const &) const; +}; +} // namespace std + +namespace nlohmann { +template <> +struct adl_serializer { + static FlexFlow::GlorotUniformAttrs from_json(json const &); + static void to_json(json &, FlexFlow::GlorotUniformAttrs const &); +}; +} // namespace nlohmann + +namespace rc { +template <> +struct Arbitrary { + static Gen arbitrary(); +}; +} // namespace rc + +namespace FlexFlow { +std::string format_as(GlorotUniformAttrs const &); +std::ostream &operator<<(std::ostream &, GlorotUniformAttrs const &); +} // namespace FlexFlow + +#endif // _FLEXFLOW_LIB_PCG_INCLUDE_PCG_INITIALIZERS_GLOROT_UNIFORM_ATTRS_DTG_H diff --git a/lib/pcg/include/pcg/initializers/glorot_uniform_attrs.struct.toml b/lib/pcg/include/pcg/initializers/glorot_uniform_attrs.struct.toml new file mode 100644 index 0000000000..de7f9141b0 --- /dev/null +++ b/lib/pcg/include/pcg/initializers/glorot_uniform_attrs.struct.toml @@ -0,0 +1,14 @@ +namespace = "FlexFlow" +name = "GlorotUniformAttrs" +features = [ + "eq", + "ord", + "hash", + "json", + "rapidcheck", + "fmt", +] + +[[fields]] +name = "seed" +type = "int" diff --git a/lib/pcg/include/pcg/initializers/norm_initializer_attrs.dtg.h b/lib/pcg/include/pcg/initializers/norm_initializer_attrs.dtg.h new file mode 100644 index 0000000000..e1d3e59ed7 --- /dev/null +++ b/lib/pcg/include/pcg/initializers/norm_initializer_attrs.dtg.h @@ -0,0 +1,64 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! 
+// If you would like to modify this datatype, instead modify +// lib/pcg/include/pcg/initializers/norm_initializer_attrs.struct.toml +/* proj-data +{ + "generated_from": "6843fc9ca02aea2b40e57dbc497f99ac" +} +*/ + +#ifndef _FLEXFLOW_LIB_PCG_INCLUDE_PCG_INITIALIZERS_NORM_INITIALIZER_ATTRS_DTG_H +#define _FLEXFLOW_LIB_PCG_INCLUDE_PCG_INITIALIZERS_NORM_INITIALIZER_ATTRS_DTG_H + +#include "fmt/format.h" +#include "nlohmann/json.hpp" +#include "rapidcheck.h" +#include +#include +#include + +namespace FlexFlow { +struct NormInitializerAttrs { + NormInitializerAttrs() = delete; + NormInitializerAttrs(int const &seed, float const &mean, float const &stddev); + + bool operator==(NormInitializerAttrs const &) const; + bool operator!=(NormInitializerAttrs const &) const; + bool operator<(NormInitializerAttrs const &) const; + bool operator>(NormInitializerAttrs const &) const; + bool operator<=(NormInitializerAttrs const &) const; + bool operator>=(NormInitializerAttrs const &) const; + int seed; + float mean; + float stddev; +}; +} // namespace FlexFlow + +namespace std { +template <> +struct hash { + size_t operator()(FlexFlow::NormInitializerAttrs const &) const; +}; +} // namespace std + +namespace nlohmann { +template <> +struct adl_serializer { + static FlexFlow::NormInitializerAttrs from_json(json const &); + static void to_json(json &, FlexFlow::NormInitializerAttrs const &); +}; +} // namespace nlohmann + +namespace rc { +template <> +struct Arbitrary { + static Gen arbitrary(); +}; +} // namespace rc + +namespace FlexFlow { +std::string format_as(NormInitializerAttrs const &); +std::ostream &operator<<(std::ostream &, NormInitializerAttrs const &); +} // namespace FlexFlow + +#endif // _FLEXFLOW_LIB_PCG_INCLUDE_PCG_INITIALIZERS_NORM_INITIALIZER_ATTRS_DTG_H diff --git a/lib/pcg/include/pcg/initializers/norm_initializer_attrs.struct.toml b/lib/pcg/include/pcg/initializers/norm_initializer_attrs.struct.toml new file mode 100644 index 0000000000..ec138de63e --- /dev/null +++ b/lib/pcg/include/pcg/initializers/norm_initializer_attrs.struct.toml @@ -0,0 +1,22 @@ +namespace = "FlexFlow" +name = "NormInitializerAttrs" +features = [ + "eq", + "ord", + "hash", + "json", + "rapidcheck", + "fmt", +] + +[[fields]] +name = "seed" +type = "int" + +[[fields]] +name = "mean" +type = "float" + +[[fields]] +name = "stddev" +type = "float" diff --git a/lib/pcg/include/pcg/initializers/uniform_initializer_attrs.dtg.h b/lib/pcg/include/pcg/initializers/uniform_initializer_attrs.dtg.h new file mode 100644 index 0000000000..1f4deada06 --- /dev/null +++ b/lib/pcg/include/pcg/initializers/uniform_initializer_attrs.dtg.h @@ -0,0 +1,58 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! 
+// If you would like to modify this datatype, instead modify +// lib/pcg/include/pcg/initializers/uniform_initializer_attrs.struct.toml +/* proj-data +{ + "generated_from": "f887e1db5d5dc710793ec5fa99bb7cd4" +} +*/ + +#ifndef _FLEXFLOW_LIB_PCG_INCLUDE_PCG_INITIALIZERS_UNIFORM_INITIALIZER_ATTRS_DTG_H +#define _FLEXFLOW_LIB_PCG_INCLUDE_PCG_INITIALIZERS_UNIFORM_INITIALIZER_ATTRS_DTG_H + +#include "fmt/format.h" +#include "nlohmann/json.hpp" +#include +#include +#include + +namespace FlexFlow { +struct UniformInitializerAttrs { + UniformInitializerAttrs() = delete; + UniformInitializerAttrs(int const &seed, + float const &min_val, + float const &max_val); + + bool operator==(UniformInitializerAttrs const &) const; + bool operator!=(UniformInitializerAttrs const &) const; + bool operator<(UniformInitializerAttrs const &) const; + bool operator>(UniformInitializerAttrs const &) const; + bool operator<=(UniformInitializerAttrs const &) const; + bool operator>=(UniformInitializerAttrs const &) const; + int seed; + float min_val; + float max_val; +}; +} // namespace FlexFlow + +namespace std { +template <> +struct hash { + size_t operator()(FlexFlow::UniformInitializerAttrs const &) const; +}; +} // namespace std + +namespace nlohmann { +template <> +struct adl_serializer { + static FlexFlow::UniformInitializerAttrs from_json(json const &); + static void to_json(json &, FlexFlow::UniformInitializerAttrs const &); +}; +} // namespace nlohmann + +namespace FlexFlow { +std::string format_as(UniformInitializerAttrs const &); +std::ostream &operator<<(std::ostream &, UniformInitializerAttrs const &); +} // namespace FlexFlow + +#endif // _FLEXFLOW_LIB_PCG_INCLUDE_PCG_INITIALIZERS_UNIFORM_INITIALIZER_ATTRS_DTG_H diff --git a/lib/pcg/include/pcg/initializers/uniform_initializer_attrs.struct.toml b/lib/pcg/include/pcg/initializers/uniform_initializer_attrs.struct.toml new file mode 100644 index 0000000000..11a6597c0a --- /dev/null +++ b/lib/pcg/include/pcg/initializers/uniform_initializer_attrs.struct.toml @@ -0,0 +1,22 @@ +namespace = "FlexFlow" +name = "UniformInitializerAttrs" +features = [ + "eq", + "ord", + "hash", + "json", + # "rapidcheck", + "fmt", +] + +[[fields]] +name = "seed" +type = "int" + +[[fields]] +name = "min_val" +type = "float" + +[[fields]] +name = "max_val" +type = "float" diff --git a/lib/pcg/include/pcg/initializers/zero_initializer_attrs.dtg.h b/lib/pcg/include/pcg/initializers/zero_initializer_attrs.dtg.h new file mode 100644 index 0000000000..f3086ea087 --- /dev/null +++ b/lib/pcg/include/pcg/initializers/zero_initializer_attrs.dtg.h @@ -0,0 +1,58 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! 
+// If you would like to modify this datatype, instead modify +// lib/pcg/include/pcg/initializers/zero_initializer_attrs.struct.toml +/* proj-data +{ + "generated_from": "a19d5a2cdc67a2840d6ba55250a10411" +} +*/ + +#ifndef _FLEXFLOW_LIB_PCG_INCLUDE_PCG_INITIALIZERS_ZERO_INITIALIZER_ATTRS_DTG_H +#define _FLEXFLOW_LIB_PCG_INCLUDE_PCG_INITIALIZERS_ZERO_INITIALIZER_ATTRS_DTG_H + +#include "fmt/format.h" +#include "nlohmann/json.hpp" +#include "rapidcheck.h" +#include +#include +#include + +namespace FlexFlow { +struct ZeroInitializerAttrs { + bool operator==(ZeroInitializerAttrs const &) const; + bool operator!=(ZeroInitializerAttrs const &) const; + bool operator<(ZeroInitializerAttrs const &) const; + bool operator>(ZeroInitializerAttrs const &) const; + bool operator<=(ZeroInitializerAttrs const &) const; + bool operator>=(ZeroInitializerAttrs const &) const; +}; +} // namespace FlexFlow + +namespace std { +template <> +struct hash { + size_t operator()(FlexFlow::ZeroInitializerAttrs const &) const; +}; +} // namespace std + +namespace nlohmann { +template <> +struct adl_serializer { + static FlexFlow::ZeroInitializerAttrs from_json(json const &); + static void to_json(json &, FlexFlow::ZeroInitializerAttrs const &); +}; +} // namespace nlohmann + +namespace rc { +template <> +struct Arbitrary { + static Gen arbitrary(); +}; +} // namespace rc + +namespace FlexFlow { +std::string format_as(ZeroInitializerAttrs const &); +std::ostream &operator<<(std::ostream &, ZeroInitializerAttrs const &); +} // namespace FlexFlow + +#endif // _FLEXFLOW_LIB_PCG_INCLUDE_PCG_INITIALIZERS_ZERO_INITIALIZER_ATTRS_DTG_H diff --git a/lib/pcg/include/pcg/initializers/zero_initializer_attrs.struct.toml b/lib/pcg/include/pcg/initializers/zero_initializer_attrs.struct.toml new file mode 100644 index 0000000000..db1b6238d5 --- /dev/null +++ b/lib/pcg/include/pcg/initializers/zero_initializer_attrs.struct.toml @@ -0,0 +1,11 @@ +namespace = "FlexFlow" +name = "ZeroInitializerAttrs" +features = [ + "eq", + "ord", + "hash", + "json", + "rapidcheck", + "fmt", +] +fields = [] diff --git a/lib/pcg/include/pcg/layer.h b/lib/pcg/include/pcg/layer.h deleted file mode 100644 index 9749cb9d06..0000000000 --- a/lib/pcg/include/pcg/layer.h +++ /dev/null @@ -1,33 +0,0 @@ -#ifndef _FLEXFLOW_PCG_INCLUDE_PCG_LAYER_H -#define _FLEXFLOW_PCG_INCLUDE_PCG_LAYER_H - -#include "op-attrs/operator_attrs.h" -#include "utils/stack_string.h" -#include "utils/visitable.h" - -namespace FlexFlow { - -struct Layer { -public: - Layer() = delete; - Layer(CompGraphOperatorAttrs const &attrs, - std::optional const &name); - -public: - std::optional> name; - CompGraphOperatorAttrs attrs; -}; - -} // namespace FlexFlow - -VISITABLE_STRUCT(::FlexFlow::Layer, attrs, name); -MAKE_VISIT_HASHABLE(::FlexFlow::Layer); - -namespace FlexFlow { - -FF_VISIT_FMTABLE(Layer); -// CHECK_FMTABLE(Layer); - -} // namespace FlexFlow - -#endif diff --git a/lib/pcg/include/pcg/layer_attrs.dtg.h b/lib/pcg/include/pcg/layer_attrs.dtg.h new file mode 100644 index 0000000000..6afa1757dc --- /dev/null +++ b/lib/pcg/include/pcg/layer_attrs.dtg.h @@ -0,0 +1,60 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! 
+// If you would like to modify this datatype, instead modify +// lib/pcg/include/pcg/layer_attrs.struct.toml +/* proj-data +{ + "generated_from": "b3e4f0c07a906139b599bd4696cb5e65" +} +*/ + +#ifndef _FLEXFLOW_LIB_PCG_INCLUDE_PCG_LAYER_ATTRS_DTG_H +#define _FLEXFLOW_LIB_PCG_INCLUDE_PCG_LAYER_ATTRS_DTG_H + +#include "fmt/format.h" +#include "nlohmann/json.hpp" +#include "op-attrs/computation_graph_op_attrs.dtg.h" +#include "utils/json.h" +#include "utils/stack_string.h" +#include +#include +#include +#include + +namespace FlexFlow { +struct LayerAttrs { + LayerAttrs() = delete; + LayerAttrs(::FlexFlow::ComputationGraphOpAttrs const &attrs, + std::optional<::FlexFlow::stack_string> const &name); + + bool operator==(LayerAttrs const &) const; + bool operator!=(LayerAttrs const &) const; + bool operator<(LayerAttrs const &) const; + bool operator>(LayerAttrs const &) const; + bool operator<=(LayerAttrs const &) const; + bool operator>=(LayerAttrs const &) const; + ::FlexFlow::ComputationGraphOpAttrs attrs; + std::optional<::FlexFlow::stack_string> name; +}; +} // namespace FlexFlow + +namespace std { +template <> +struct hash { + size_t operator()(FlexFlow::LayerAttrs const &) const; +}; +} // namespace std + +namespace nlohmann { +template <> +struct adl_serializer { + static FlexFlow::LayerAttrs from_json(json const &); + static void to_json(json &, FlexFlow::LayerAttrs const &); +}; +} // namespace nlohmann + +namespace FlexFlow { +std::string format_as(LayerAttrs const &); +std::ostream &operator<<(std::ostream &, LayerAttrs const &); +} // namespace FlexFlow + +#endif // _FLEXFLOW_LIB_PCG_INCLUDE_PCG_LAYER_ATTRS_DTG_H diff --git a/lib/pcg/include/pcg/layer_attrs.struct.toml b/lib/pcg/include/pcg/layer_attrs.struct.toml new file mode 100644 index 0000000000..9f8aaa5ba3 --- /dev/null +++ b/lib/pcg/include/pcg/layer_attrs.struct.toml @@ -0,0 +1,26 @@ +namespace = "FlexFlow" +name = "LayerAttrs" +features = [ + "eq", + "ord", + "hash", + "json", + # "rapidcheck", + "fmt", +] + +includes = [ + "op-attrs/computation_graph_op_attrs.dtg.h", + "utils/stack_string.h", + "", + "utils/json.h" +] + +[[fields]] +name = "attrs" +type = "::FlexFlow::ComputationGraphOpAttrs" + +[[fields]] +name = "name" +type = "std::optional<::FlexFlow::stack_string>" + diff --git a/lib/pcg/include/pcg/layer_guid_t.dtg.h b/lib/pcg/include/pcg/layer_guid_t.dtg.h new file mode 100644 index 0000000000..4bbdd36fed --- /dev/null +++ b/lib/pcg/include/pcg/layer_guid_t.dtg.h @@ -0,0 +1,46 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! 
+// If you would like to modify this datatype, instead modify +// lib/pcg/include/pcg/layer_guid_t.struct.toml +/* proj-data +{ + "generated_from": "a672ffe470fd1dde8299f91f3038ca7a" +} +*/ + +#ifndef _FLEXFLOW_LIB_PCG_INCLUDE_PCG_LAYER_GUID_T_DTG_H +#define _FLEXFLOW_LIB_PCG_INCLUDE_PCG_LAYER_GUID_T_DTG_H + +#include "fmt/format.h" +#include "utils/graph.h" +#include +#include +#include + +namespace FlexFlow { +struct layer_guid_t { + layer_guid_t() = delete; + layer_guid_t(::FlexFlow::Node const &raw_node); + + bool operator==(layer_guid_t const &) const; + bool operator!=(layer_guid_t const &) const; + bool operator<(layer_guid_t const &) const; + bool operator>(layer_guid_t const &) const; + bool operator<=(layer_guid_t const &) const; + bool operator>=(layer_guid_t const &) const; + ::FlexFlow::Node raw_node; +}; +} // namespace FlexFlow + +namespace std { +template <> +struct hash { + size_t operator()(FlexFlow::layer_guid_t const &) const; +}; +} // namespace std + +namespace FlexFlow { +std::string format_as(layer_guid_t const &); +std::ostream &operator<<(std::ostream &, layer_guid_t const &); +} // namespace FlexFlow + +#endif // _FLEXFLOW_LIB_PCG_INCLUDE_PCG_LAYER_GUID_T_DTG_H diff --git a/lib/pcg/include/pcg/layer_guid_t.struct.toml b/lib/pcg/include/pcg/layer_guid_t.struct.toml new file mode 100644 index 0000000000..c6d4073f58 --- /dev/null +++ b/lib/pcg/include/pcg/layer_guid_t.struct.toml @@ -0,0 +1,16 @@ +namespace = "FlexFlow" +name = "layer_guid_t" +features = [ + "eq", + "ord", + "hash", + "fmt", +] + +includes = [ + "utils/graph.h", +] + +[[fields]] +name = "raw_node" +type = "::FlexFlow::Node" diff --git a/lib/pcg/include/pcg/machine_specification.dtg.h b/lib/pcg/include/pcg/machine_specification.dtg.h new file mode 100644 index 0000000000..cd6ffe6c0f --- /dev/null +++ b/lib/pcg/include/pcg/machine_specification.dtg.h @@ -0,0 +1,62 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! 
+// If you would like to modify this datatype, instead modify +// lib/pcg/include/pcg/machine_specification.struct.toml +/* proj-data +{ + "generated_from": "72c3ae372af189d0c8bae74c2dbbc531" +} +*/ + +#ifndef _FLEXFLOW_LIB_PCG_INCLUDE_PCG_MACHINE_SPECIFICATION_DTG_H +#define _FLEXFLOW_LIB_PCG_INCLUDE_PCG_MACHINE_SPECIFICATION_DTG_H + +#include "fmt/format.h" +#include "nlohmann/json.hpp" +#include +#include +#include + +namespace FlexFlow { +struct MachineSpecification { + MachineSpecification() = delete; + MachineSpecification(int const &num_nodes, + int const &num_cpus_per_node, + int const &num_gpus_per_node, + float const &inter_node_bandwidth, + float const &intra_node_bandwidth); + + bool operator==(MachineSpecification const &) const; + bool operator!=(MachineSpecification const &) const; + bool operator<(MachineSpecification const &) const; + bool operator>(MachineSpecification const &) const; + bool operator<=(MachineSpecification const &) const; + bool operator>=(MachineSpecification const &) const; + int num_nodes; + int num_cpus_per_node; + int num_gpus_per_node; + float inter_node_bandwidth; + float intra_node_bandwidth; +}; +} // namespace FlexFlow + +namespace std { +template <> +struct hash { + size_t operator()(FlexFlow::MachineSpecification const &) const; +}; +} // namespace std + +namespace nlohmann { +template <> +struct adl_serializer { + static FlexFlow::MachineSpecification from_json(json const &); + static void to_json(json &, FlexFlow::MachineSpecification const &); +}; +} // namespace nlohmann + +namespace FlexFlow { +std::string format_as(MachineSpecification const &); +std::ostream &operator<<(std::ostream &, MachineSpecification const &); +} // namespace FlexFlow + +#endif // _FLEXFLOW_LIB_PCG_INCLUDE_PCG_MACHINE_SPECIFICATION_DTG_H diff --git a/lib/pcg/include/pcg/machine_specification.h b/lib/pcg/include/pcg/machine_specification.h index 1b2a02b070..cf84bf5048 100644 --- a/lib/pcg/include/pcg/machine_specification.h +++ b/lib/pcg/include/pcg/machine_specification.h @@ -1,31 +1,8 @@ #ifndef _FLEXFLOW_PCG_INCLUDE_PCG_MACHINE_SPECIFICATION_H #define _FLEXFLOW_PCG_INCLUDE_PCG_MACHINE_SPECIFICATION_H -#include "machine_view.h" -#include "utils/visitable.h" +#include "machine_specification_t.h" -namespace FlexFlow { - -struct BandwidthNetworkModelConfig - : public use_visitable_cmp { - int bandwidth; -}; - -struct MachineSpecification { - int num_nodes; - int num_cpus_per_node; - int num_gpus_per_node; - float inter_node_bandwidth; - req intra_node_bandwidth; -}; - -FF_VISITABLE_STRUCT(MachineSpecification, - num_nodes, - num_cpus_per_node, - num_gpus_per_node, - inter_node_bandwidth, - intra_node_bandwidth); - -} // namespace FlexFlow +namespace FlexFlow {} // namespace FlexFlow #endif diff --git a/lib/pcg/include/pcg/machine_specification.struct.toml b/lib/pcg/include/pcg/machine_specification.struct.toml new file mode 100644 index 0000000000..e75b5018cb --- /dev/null +++ b/lib/pcg/include/pcg/machine_specification.struct.toml @@ -0,0 +1,30 @@ +namespace = "FlexFlow" +name = "MachineSpecification" +features = [ + "eq", + "ord", + "hash", + "json", + # "rapidcheck", + "fmt", +] + +[[fields]] +name = "num_nodes" +type = "int" + +[[fields]] +name = "num_cpus_per_node" +type = "int" + +[[fields]] +name = "num_gpus_per_node" +type = "int" + +[[fields]] +name = "inter_node_bandwidth" +type = "float" + +[[fields]] +name = "intra_node_bandwidth" +type = "float" diff --git a/lib/pcg/include/pcg/machine_view.dtg.h b/lib/pcg/include/pcg/machine_view.dtg.h new file mode 
100644 index 0000000000..2eae6e2c8b --- /dev/null +++ b/lib/pcg/include/pcg/machine_view.dtg.h @@ -0,0 +1,58 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/pcg/include/pcg/machine_view.struct.toml +/* proj-data +{ + "generated_from": "16c571e6bb82d7ef88e5d2a9146638f4" +} +*/ + +#ifndef _FLEXFLOW_LIB_PCG_INCLUDE_PCG_MACHINE_VIEW_DTG_H +#define _FLEXFLOW_LIB_PCG_INCLUDE_PCG_MACHINE_VIEW_DTG_H + +#include "fmt/format.h" +#include "nlohmann/json.hpp" +#include "pcg/device_id_t.dtg.h" +#include "pcg/strided_rectangle.dtg.h" +#include +#include +#include + +namespace FlexFlow { +struct MachineView { + MachineView() = delete; + MachineView(::FlexFlow::device_id_t const &start, + ::FlexFlow::StridedRectangle const &rect); + + bool operator==(MachineView const &) const; + bool operator!=(MachineView const &) const; + bool operator<(MachineView const &) const; + bool operator>(MachineView const &) const; + bool operator<=(MachineView const &) const; + bool operator>=(MachineView const &) const; + ::FlexFlow::device_id_t start; + ::FlexFlow::StridedRectangle rect; +}; +} // namespace FlexFlow + +namespace std { +template <> +struct hash { + size_t operator()(FlexFlow::MachineView const &) const; +}; +} // namespace std + +namespace nlohmann { +template <> +struct adl_serializer { + static FlexFlow::MachineView from_json(json const &); + static void to_json(json &, FlexFlow::MachineView const &); +}; +} // namespace nlohmann + +namespace FlexFlow { +std::string format_as(MachineView const &); +std::ostream &operator<<(std::ostream &, MachineView const &); +} // namespace FlexFlow + +#endif // _FLEXFLOW_LIB_PCG_INCLUDE_PCG_MACHINE_VIEW_DTG_H diff --git a/lib/pcg/include/pcg/machine_view.h b/lib/pcg/include/pcg/machine_view.h index 7521cd209a..625b128d35 100644 --- a/lib/pcg/include/pcg/machine_view.h +++ b/lib/pcg/include/pcg/machine_view.h @@ -1,30 +1,19 @@ #ifndef _FLEXFLOW_PCG_INCLUDE_PCG_MACHINE_VIEW_H #define _FLEXFLOW_PCG_INCLUDE_PCG_MACHINE_VIEW_H -#include "device_id.h" -#include "device_type.h" -#include "strided_rectangle.h" -#include "utils/graph.h" -#include "utils/visitable.h" +#include "pcg/cpu_id_t.dtg.h" +#include "pcg/device_id_t.dtg.h" +#include "pcg/device_type.dtg.h" +#include "pcg/gpu_id_t.dtg.h" +#include "pcg/machine_view.dtg.h" +#include "pcg/num_points_t.dtg.h" +#include "pcg/side_size_t.dtg.h" #include -#include #include namespace FlexFlow { -struct MachineView { - std::vector device_ids() const; - - device_id_t at(FFOrdered const &coord) const; - StridedRectangleSide at(size_t) const; - -public: - device_id_t start; - StridedRectangle rect; -}; - -FF_VISITABLE_STRUCT(MachineView, start, rect); - +std::vector device_ids(MachineView const &); std::size_t num_dims(MachineView const &); std::size_t num_devices(MachineView const &); DeviceType get_device_type(MachineView const &); diff --git a/lib/pcg/include/pcg/machine_view.struct.toml b/lib/pcg/include/pcg/machine_view.struct.toml new file mode 100644 index 0000000000..c97731991f --- /dev/null +++ b/lib/pcg/include/pcg/machine_view.struct.toml @@ -0,0 +1,23 @@ +namespace = "FlexFlow" +name = "MachineView" +features = [ + "eq", + "ord", + "hash", + "json", + # "rapidcheck", + "fmt", +] + +includes = [ + "pcg/device_id_t.dtg.h", + "pcg/strided_rectangle.dtg.h", +] + +[[fields]] +name = "start" +type = "::FlexFlow::device_id_t" + +[[fields]] +name = "rect" +type = "::FlexFlow::StridedRectangle" diff --git a/lib/pcg/include/pcg/num_points_t.dtg.h 
b/lib/pcg/include/pcg/num_points_t.dtg.h new file mode 100644 index 0000000000..3b8e0e0c6c --- /dev/null +++ b/lib/pcg/include/pcg/num_points_t.dtg.h @@ -0,0 +1,62 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/pcg/include/pcg/num_points_t.struct.toml +/* proj-data +{ + "generated_from": "2a862b92055eda0508447d2f4df52f71" +} +*/ + +#ifndef _FLEXFLOW_LIB_PCG_INCLUDE_PCG_NUM_POINTS_T_DTG_H +#define _FLEXFLOW_LIB_PCG_INCLUDE_PCG_NUM_POINTS_T_DTG_H + +#include "fmt/format.h" +#include "nlohmann/json.hpp" +#include "rapidcheck.h" +#include +#include +#include + +namespace FlexFlow { +struct num_points_t { + num_points_t() = delete; + num_points_t(int const &unwrapped); + + bool operator==(num_points_t const &) const; + bool operator!=(num_points_t const &) const; + bool operator<(num_points_t const &) const; + bool operator>(num_points_t const &) const; + bool operator<=(num_points_t const &) const; + bool operator>=(num_points_t const &) const; + int unwrapped; +}; +} // namespace FlexFlow + +namespace std { +template <> +struct hash { + size_t operator()(FlexFlow::num_points_t const &) const; +}; +} // namespace std + +namespace nlohmann { +template <> +struct adl_serializer { + static FlexFlow::num_points_t from_json(json const &); + static void to_json(json &, FlexFlow::num_points_t const &); +}; +} // namespace nlohmann + +namespace rc { +template <> +struct Arbitrary { + static Gen arbitrary(); +}; +} // namespace rc + +namespace FlexFlow { +std::string format_as(num_points_t const &); +std::ostream &operator<<(std::ostream &, num_points_t const &); +} // namespace FlexFlow + +#endif // _FLEXFLOW_LIB_PCG_INCLUDE_PCG_NUM_POINTS_T_DTG_H diff --git a/lib/pcg/include/pcg/num_points_t.struct.toml b/lib/pcg/include/pcg/num_points_t.struct.toml new file mode 100644 index 0000000000..b389245c63 --- /dev/null +++ b/lib/pcg/include/pcg/num_points_t.struct.toml @@ -0,0 +1,14 @@ +namespace = "FlexFlow" +name = "num_points_t" +features = [ + "eq", + "ord", + "hash", + "json", + "rapidcheck", + "fmt", +] + +[[fields]] +name = "unwrapped" +type = "int" diff --git a/lib/pcg/include/pcg/open_dataflow_graph.h b/lib/pcg/include/pcg/open_dataflow_graph.h new file mode 100644 index 0000000000..b3367686b3 --- /dev/null +++ b/lib/pcg/include/pcg/open_dataflow_graph.h @@ -0,0 +1,81 @@ +// #ifndef _FLEXFLOW_LIB_PCG_INCLUDE_PCG_OPEN_DATAFLOW_GRAPH_H +// #define _FLEXFLOW_LIB_PCG_INCLUDE_PCG_OPEN_DATAFLOW_GRAPH_H +// +// #include "utils/containers/enumerate_vector.h" +// #include "utils/graph.h" +// #include "pcg/dataflow_input.dtg.h" +// +// namespace FlexFlow { +// +// template +// struct OpenDataflowGraph { +// public: +// OpenDataflowGraph() +// : g(OutputLabelledOpenMultiDiGraph::template +// create< +// UnorderedOutputLabelledOpenMultiDiGraph>()) +// { } +// +// DataflowInput add_external_input(OutputLabel const &label) { +// /* size_t src_node_idx = edge_uid_ctr; */ +// /* edge_uid_ctr++; */ +// /* size_t src_port_idx = 0; */ +// /* edge_uid_t edge_uid = { src_node_idx, src_port_idx }; */ +// /* return MultiDiOutput{edge_uid}; */ +// } +// +// std::vector add_operator(NodeLabel const &func, +// std::vector const &inputs, std::vector const +// &outputs) { +// Node n = this->g.add_node(func); +// for (auto const &[idx, input] : enumerate_vector(inputs)) { +// this->g.add_edge(MultiDiEdge{input.src, input.src_idx, n, +// this->make_port_for_idx(idx)}); +// } +// +// std::vector result; +// for (auto const &[idx, label] : 
enumerate_vector(outputs)) { +// MultiDiOutput output = MultiDiOutput{n, this->make_port_for_idx(idx)}; +// this->g.add_output(output, label); +// result.push_back(output); +// } +// +// return result; +// } +// +// NodePort make_port_for_idx(int idx) { +// if (!this->port_mapping.contains_l(idx)) { +// this->port_mapping.equate(idx, this->g.add_node_port()); +// } +// return this->port_mapping.at_l(idx); +// } +// +// NodePort port_for_idx(int idx) const { +// return this->port_mapping.at_l(idx); +// } +// +// int idx_for_port(NodePort const &p) const { +// return this->port_mapping.at_r(p); +// } +// +// OutputLabelledMultiDiGraphView const +// &get_raw_graph() const { +// return this->g; +// } +// +// NodeLabel const &at(Node const &n) const { +// return this->g.at(n); +// } +// +// OutputLabel const &at(MultiDiOutput const &o) const { +// return this->g.at(o); +// } +// private: +// OutputLabelledOpenMultiDiGraph g; +// bidict port_mapping; +// size_t edge_uid_ctr = 0; +// }; +// +// } // namespace FlexFlow +// +// #endif diff --git a/lib/pcg/include/pcg/operator.h b/lib/pcg/include/pcg/operator.h deleted file mode 100644 index bb9a4cf5e4..0000000000 --- a/lib/pcg/include/pcg/operator.h +++ /dev/null @@ -1,27 +0,0 @@ -#ifndef _FLEXFLOW_PCG_INCLUDE_PCG_OPERATOR_H -#define _FLEXFLOW_PCG_INCLUDE_PCG_OPERATOR_H - -#include "op-attrs/operator_attrs.h" -#include "utils/stack_string.h" -#include "utils/visitable.h" - -#include - -namespace FlexFlow { - -struct Operator { -public: - operator PCGOperatorAttrs() const; - -public: - PCGOperatorAttrs attrs; - req> name; -}; - -FF_VISITABLE_STRUCT(Operator, attrs, name); - -static_assert(is_well_behaved_value_type::value); - -} // namespace FlexFlow - -#endif diff --git a/lib/pcg/include/pcg/operator_graph/operator_graph.h b/lib/pcg/include/pcg/operator_graph/operator_graph.h new file mode 100644 index 0000000000..5fca50d4c7 --- /dev/null +++ b/lib/pcg/include/pcg/operator_graph/operator_graph.h @@ -0,0 +1,80 @@ +#ifndef _FLEXFLOW_LIB_PCG_INCLUDE_PCG_OPERATOR_GRAPH_H +#define _FLEXFLOW_LIB_PCG_INCLUDE_PCG_OPERATOR_GRAPH_H + +#include "pcg/operator_graph/operator_graph_input.dtg.h" +#include "pcg/operator_graph/operator_graph_output.dtg.h" +#include "utils/graph.h" + +namespace FlexFlow { + +struct OperatorGraphOutputQuery {}; +struct OperatorGraphEdge {}; + +Node get_src_node(OperatorGraphEdge const &); +Node get_dst_node(OperatorGraphEdge const &); +int get_src_idx(OperatorGraphEdge const &); +int get_dst_idx(OperatorGraphEdge const &); + +struct OperatorGraphEdgeQuery; + +struct OperatorGraphView { +public: + using Edge = OperatorGraphEdge; + using EdgeQuery = OperatorGraphEdgeQuery; + + OperatorGraphView(OperatorGraphView const &); + OperatorGraphView &operator=(OperatorGraphView const &); + + OperatorGraphView(OperatorGraphView &&); + OperatorGraphView &&operator=(OperatorGraphView &&); + + std::unordered_set query_nodes(NodeQuery const &) const; + std::unordered_set + query_outputs(OperatorGraphOutputQuery const &) const; + std::unordered_set + query_edges(OperatorGraphEdgeQuery const &) const; + + struct Impl; + std::unique_ptr impl; +}; +CHECK_WELL_BEHAVED_VALUE_TYPE_NO_EQ(OperatorGraphView); + +std::unordered_set get_outputs(OperatorGraphView const &); +std::vector get_outputs(OperatorGraphView const &, + Node const &); +std::unordered_set get_uses(OperatorGraphView const &, + OperatorGraphOutput const &); + +struct OperatorGraph { +public: + OperatorGraph(); + OperatorGraph(OperatorGraph const &) = default; + OperatorGraph 
&operator=(OperatorGraph const &) = default; + + Node add_node(std::vector const &inputs, + int num_outputs); + +private: + struct Impl; + std::unique_ptr impl; +}; + +struct value_t; + +template +struct LabelledOperatorGraphView : virtual OperatorGraphView { + NodeLabel const &at(Node const &) const; + OutputLabel const &at(OperatorGraphOutput const &) const; +}; + +template +struct LabelledOperatorGraph + : virtual LabelledOperatorGraphView { + Node add_node(NodeLabel const &, + std::vector const &inputs, + std::vector const &output_labels); +}; + +} // namespace FlexFlow + +#endif diff --git a/lib/pcg/include/pcg/operator_graph/operator_graph_input.dtg.h b/lib/pcg/include/pcg/operator_graph/operator_graph_input.dtg.h new file mode 100644 index 0000000000..13904f220d --- /dev/null +++ b/lib/pcg/include/pcg/operator_graph/operator_graph_input.dtg.h @@ -0,0 +1,47 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/pcg/include/pcg/operator_graph/operator_graph_input.struct.toml +/* proj-data +{ + "generated_from": "57d9c9afc86f43049c6f035c74477afd" +} +*/ + +#ifndef _FLEXFLOW_LIB_PCG_INCLUDE_PCG_OPERATOR_GRAPH_OPERATOR_GRAPH_INPUT_DTG_H +#define _FLEXFLOW_LIB_PCG_INCLUDE_PCG_OPERATOR_GRAPH_OPERATOR_GRAPH_INPUT_DTG_H + +#include "fmt/format.h" +#include "utils/graph.h" +#include +#include +#include + +namespace FlexFlow { +struct OperatorGraphInput { + OperatorGraphInput() = delete; + OperatorGraphInput(::FlexFlow::Node const &node, int const &idx); + + bool operator==(OperatorGraphInput const &) const; + bool operator!=(OperatorGraphInput const &) const; + bool operator<(OperatorGraphInput const &) const; + bool operator>(OperatorGraphInput const &) const; + bool operator<=(OperatorGraphInput const &) const; + bool operator>=(OperatorGraphInput const &) const; + ::FlexFlow::Node node; + int idx; +}; +} // namespace FlexFlow + +namespace std { +template <> +struct hash { + size_t operator()(FlexFlow::OperatorGraphInput const &) const; +}; +} // namespace std + +namespace FlexFlow { +std::string format_as(OperatorGraphInput const &); +std::ostream &operator<<(std::ostream &, OperatorGraphInput const &); +} // namespace FlexFlow + +#endif // _FLEXFLOW_LIB_PCG_INCLUDE_PCG_OPERATOR_GRAPH_OPERATOR_GRAPH_INPUT_DTG_H diff --git a/lib/pcg/include/pcg/operator_graph/operator_graph_input.h b/lib/pcg/include/pcg/operator_graph/operator_graph_input.h new file mode 100644 index 0000000000..18e7710186 --- /dev/null +++ b/lib/pcg/include/pcg/operator_graph/operator_graph_input.h @@ -0,0 +1,13 @@ +#ifndef _FLEXFLOW_LIB_PCG_INCLUDE_PCG_OPERATOR_GRAPH_OPERATOR_GRAPH_INPUT_H +#define _FLEXFLOW_LIB_PCG_INCLUDE_PCG_OPERATOR_GRAPH_OPERATOR_GRAPH_INPUT_H + +#include "pcg/operator_graph/operator_graph_input.dtg.h" + +namespace FlexFlow { + +Node get_node(OperatorGraphInput const &); +int get_idx(OperatorGraphInput const &); + +} // namespace FlexFlow + +#endif diff --git a/lib/pcg/include/pcg/operator_graph/operator_graph_input.struct.toml b/lib/pcg/include/pcg/operator_graph/operator_graph_input.struct.toml new file mode 100644 index 0000000000..a729f75bae --- /dev/null +++ b/lib/pcg/include/pcg/operator_graph/operator_graph_input.struct.toml @@ -0,0 +1,20 @@ +namespace = "FlexFlow" +name = "OperatorGraphInput" +features = [ + "eq", + "ord", + "hash", + "fmt", +] + +includes = [ + "utils/graph.h" +] + +[[fields]] +name = "node" +type = "::FlexFlow::Node" + +[[fields]] +name = "idx" +type = "int" diff --git 
a/lib/pcg/include/pcg/operator_graph/operator_graph_output.dtg.h b/lib/pcg/include/pcg/operator_graph/operator_graph_output.dtg.h new file mode 100644 index 0000000000..40bdc245b8 --- /dev/null +++ b/lib/pcg/include/pcg/operator_graph/operator_graph_output.dtg.h @@ -0,0 +1,47 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/pcg/include/pcg/operator_graph/operator_graph_output.struct.toml +/* proj-data +{ + "generated_from": "3931cb388b00e0634495cdb89cb2af54" +} +*/ + +#ifndef _FLEXFLOW_LIB_PCG_INCLUDE_PCG_OPERATOR_GRAPH_OPERATOR_GRAPH_OUTPUT_DTG_H +#define _FLEXFLOW_LIB_PCG_INCLUDE_PCG_OPERATOR_GRAPH_OPERATOR_GRAPH_OUTPUT_DTG_H + +#include "fmt/format.h" +#include "utils/graph.h" +#include +#include +#include + +namespace FlexFlow { +struct OperatorGraphOutput { + OperatorGraphOutput() = delete; + OperatorGraphOutput(::FlexFlow::Node const &node, int const &idx); + + bool operator==(OperatorGraphOutput const &) const; + bool operator!=(OperatorGraphOutput const &) const; + bool operator<(OperatorGraphOutput const &) const; + bool operator>(OperatorGraphOutput const &) const; + bool operator<=(OperatorGraphOutput const &) const; + bool operator>=(OperatorGraphOutput const &) const; + ::FlexFlow::Node node; + int idx; +}; +} // namespace FlexFlow + +namespace std { +template <> +struct hash { + size_t operator()(FlexFlow::OperatorGraphOutput const &) const; +}; +} // namespace std + +namespace FlexFlow { +std::string format_as(OperatorGraphOutput const &); +std::ostream &operator<<(std::ostream &, OperatorGraphOutput const &); +} // namespace FlexFlow + +#endif // _FLEXFLOW_LIB_PCG_INCLUDE_PCG_OPERATOR_GRAPH_OPERATOR_GRAPH_OUTPUT_DTG_H diff --git a/lib/pcg/include/pcg/operator_graph/operator_graph_output.h b/lib/pcg/include/pcg/operator_graph/operator_graph_output.h new file mode 100644 index 0000000000..d50b74f496 --- /dev/null +++ b/lib/pcg/include/pcg/operator_graph/operator_graph_output.h @@ -0,0 +1,13 @@ +#ifndef _FLEXFLOW_LIB_PCG_INCLUDE_PCG_OPERATOR_GRAPH_OPERATOR_GRAPH_OUTPUT_H +#define _FLEXFLOW_LIB_PCG_INCLUDE_PCG_OPERATOR_GRAPH_OPERATOR_GRAPH_OUTPUT_H + +#include "pcg/operator_graph/operator_graph_output.dtg.h" + +namespace FlexFlow { + +Node get_node(OperatorGraphOutput const &); +int get_idx(OperatorGraphOutput const &); + +} // namespace FlexFlow + +#endif diff --git a/lib/pcg/include/pcg/operator_graph/operator_graph_output.struct.toml b/lib/pcg/include/pcg/operator_graph/operator_graph_output.struct.toml new file mode 100644 index 0000000000..044d4c8df3 --- /dev/null +++ b/lib/pcg/include/pcg/operator_graph/operator_graph_output.struct.toml @@ -0,0 +1,20 @@ +namespace = "FlexFlow" +name = "OperatorGraphOutput" +features = [ + "eq", + "ord", + "hash", + "fmt", +] + +includes = [ + "utils/graph.h" +] + +[[fields]] +name = "node" +type = "::FlexFlow::Node" + +[[fields]] +name = "idx" +type = "int" diff --git a/lib/pcg/include/pcg/operator_guid_t.dtg.h b/lib/pcg/include/pcg/operator_guid_t.dtg.h new file mode 100644 index 0000000000..bf08150e5e --- /dev/null +++ b/lib/pcg/include/pcg/operator_guid_t.dtg.h @@ -0,0 +1,46 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! 
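Note: each entry in a .struct.toml features list maps to a fixed block in the emitted .dtg.h, as the generated headers in this patch illustrate: "eq" and "ord" produce the six comparison operators, "hash" the std::hash specialization, "fmt" the format_as and operator<< declarations, "json" an nlohmann::adl_serializer, and "rapidcheck" an rc::Arbitrary. A minimal sketch of that generated surface (hypothetical function and Node value; OperatorGraphOutput is declared above):

  #include "pcg/operator_graph/operator_graph_output.dtg.h"
  #include <sstream>

  void dtgen_surface_sketch(FlexFlow::Node const &n) {
    FlexFlow::OperatorGraphOutput a{n, 0};
    FlexFlow::OperatorGraphOutput b{n, 1};
    bool lt = a < b;                                           // from "ord"
    size_t h = std::hash<FlexFlow::OperatorGraphOutput>{}(a);  // from "hash"
    std::ostringstream oss;
    oss << a;                                                  // from "fmt"
    (void)lt; (void)h;
  }
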
+// If you would like to modify this datatype, instead modify +// lib/pcg/include/pcg/operator_guid_t.struct.toml +/* proj-data +{ + "generated_from": "348b5a610f4ff6f545884564ee9a1e6a" +} +*/ + +#ifndef _FLEXFLOW_LIB_PCG_INCLUDE_PCG_OPERATOR_GUID_T_DTG_H +#define _FLEXFLOW_LIB_PCG_INCLUDE_PCG_OPERATOR_GUID_T_DTG_H + +#include "fmt/format.h" +#include "utils/graph.h" +#include +#include +#include + +namespace FlexFlow { +struct operator_guid_t { + operator_guid_t() = delete; + operator_guid_t(::FlexFlow::Node const &raw_graph_node); + + bool operator==(operator_guid_t const &) const; + bool operator!=(operator_guid_t const &) const; + bool operator<(operator_guid_t const &) const; + bool operator>(operator_guid_t const &) const; + bool operator<=(operator_guid_t const &) const; + bool operator>=(operator_guid_t const &) const; + ::FlexFlow::Node raw_graph_node; +}; +} // namespace FlexFlow + +namespace std { +template <> +struct hash { + size_t operator()(FlexFlow::operator_guid_t const &) const; +}; +} // namespace std + +namespace FlexFlow { +std::string format_as(operator_guid_t const &); +std::ostream &operator<<(std::ostream &, operator_guid_t const &); +} // namespace FlexFlow + +#endif // _FLEXFLOW_LIB_PCG_INCLUDE_PCG_OPERATOR_GUID_T_DTG_H diff --git a/lib/pcg/include/pcg/operator_guid_t.h b/lib/pcg/include/pcg/operator_guid_t.h deleted file mode 100644 index 46b640774a..0000000000 --- a/lib/pcg/include/pcg/operator_guid_t.h +++ /dev/null @@ -1,18 +0,0 @@ -#ifndef _FLEXFLOW_PCG_INCLUDE_PCG_OPERATOR_GUID_T_H -#define _FLEXFLOW_PCG_INCLUDE_PCG_OPERATOR_GUID_T_H - -#include "utils/graph.h" -#include "utils/strong_typedef.h" - -namespace FlexFlow { - -struct operator_guid_t : strong_typedef { - using strong_typedef::strong_typedef; -}; - -} // namespace FlexFlow - -MAKE_TYPEDEF_PRINTABLE(::FlexFlow::operator_guid_t, "operator_guid"); -MAKE_TYPEDEF_HASHABLE(::FlexFlow::operator_guid_t); - -#endif diff --git a/lib/pcg/include/pcg/operator_guid_t.struct.toml b/lib/pcg/include/pcg/operator_guid_t.struct.toml new file mode 100644 index 0000000000..f89d30137e --- /dev/null +++ b/lib/pcg/include/pcg/operator_guid_t.struct.toml @@ -0,0 +1,18 @@ +namespace = "FlexFlow" +name = "operator_guid_t" +features = [ + "eq", + "ord", + "hash", + # "json", + # "rapidcheck", + "fmt", +] + +includes = [ + "utils/graph.h" +] + +[[fields]] +name = "raw_graph_node" +type = "::FlexFlow::Node" diff --git a/lib/pcg/include/pcg/optimizer.h b/lib/pcg/include/pcg/optimizer.h deleted file mode 100644 index 0bb3fab974..0000000000 --- a/lib/pcg/include/pcg/optimizer.h +++ /dev/null @@ -1,41 +0,0 @@ -#ifndef _FLEXFLOW_PCG_INCLUDE_PCG_OPTIMIZER_H -#define _FLEXFLOW_PCG_INCLUDE_PCG_OPTIMIZER_H - -#include "utils/variant.h" -#include "utils/visitable.h" - -namespace FlexFlow { - -struct SGDOptimizer { - double lr; - double momentum; - bool nesterov; - req weight_decay; -}; -FF_VISITABLE_STRUCT(SGDOptimizer, lr, momentum, nesterov, weight_decay); - -struct AdamOptimizer { - double alpha; - double beta1; - double beta2; - double weight_decay; - double epsilon; - double alpha_t; - double beta_t; - req beta2_t; -}; -FF_VISITABLE_STRUCT(AdamOptimizer, - alpha, - beta1, - beta2, - weight_decay, - epsilon, - alpha_t, - beta_t, - beta2_t); - -using Optimizer = std::variant; - -} // namespace FlexFlow - -#endif diff --git a/lib/pcg/include/pcg/optimizer_attrs.h b/lib/pcg/include/pcg/optimizer_attrs.h new file mode 100644 index 0000000000..4bac74b999 --- /dev/null +++ b/lib/pcg/include/pcg/optimizer_attrs.h @@ -0,0 +1,14 @@ +#ifndef 
_FLEXFLOW_PCG_INCLUDE_PCG_OPTIMIZER_ATTRS_H +#define _FLEXFLOW_PCG_INCLUDE_PCG_OPTIMIZER_ATTRS_H + +#include "pcg/optimizers/adam_optimizer_attrs.h" +#include "pcg/optimizers/sgd_optimizer_attrs.h" +#include "utils/variant.h" + +namespace FlexFlow { + +using OptimizerAttrs = std::variant<SGDOptimizerAttrs, AdamOptimizerAttrs>; + +} // namespace FlexFlow + +#endif diff --git a/lib/pcg/include/pcg/optimizers/adam_optimizer_attrs.dtg.h b/lib/pcg/include/pcg/optimizers/adam_optimizer_attrs.dtg.h new file mode 100644 index 0000000000..a5a6a5ed0a --- /dev/null +++ b/lib/pcg/include/pcg/optimizers/adam_optimizer_attrs.dtg.h @@ -0,0 +1,74 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/pcg/include/pcg/optimizers/adam_optimizer_attrs.struct.toml +/* proj-data +{ + "generated_from": "f49e1bebcb0ef2bc3c210073e3183d4d" +} +*/ + +#ifndef _FLEXFLOW_LIB_PCG_INCLUDE_PCG_OPTIMIZERS_ADAM_OPTIMIZER_ATTRS_DTG_H +#define _FLEXFLOW_LIB_PCG_INCLUDE_PCG_OPTIMIZERS_ADAM_OPTIMIZER_ATTRS_DTG_H + +#include "fmt/format.h" +#include "nlohmann/json.hpp" +#include "rapidcheck.h" +#include +#include +#include + +namespace FlexFlow { +struct AdamOptimizerAttrs { + AdamOptimizerAttrs() = delete; + AdamOptimizerAttrs(double const &alpha, + double const &beta1, + double const &beta2, + double const &weight_decay, + double const &alpha_t, + double const &beta_t, + double const &beta2_t); + + bool operator==(AdamOptimizerAttrs const &) const; + bool operator!=(AdamOptimizerAttrs const &) const; + bool operator<(AdamOptimizerAttrs const &) const; + bool operator>(AdamOptimizerAttrs const &) const; + bool operator<=(AdamOptimizerAttrs const &) const; + bool operator>=(AdamOptimizerAttrs const &) const; + double alpha; + double beta1; + double beta2; + double weight_decay; + double alpha_t; + double beta_t; + double beta2_t; +}; +} // namespace FlexFlow + +namespace std { +template <> +struct hash<FlexFlow::AdamOptimizerAttrs> { + size_t operator()(FlexFlow::AdamOptimizerAttrs const &) const; +}; +} // namespace std + +namespace nlohmann { +template <> +struct adl_serializer<FlexFlow::AdamOptimizerAttrs> { + static FlexFlow::AdamOptimizerAttrs from_json(json const &); + static void to_json(json &, FlexFlow::AdamOptimizerAttrs const &); +}; +} // namespace nlohmann + +namespace rc { +template <> +struct Arbitrary<FlexFlow::AdamOptimizerAttrs> { + static Gen<FlexFlow::AdamOptimizerAttrs> arbitrary(); +}; +} // namespace rc + +namespace FlexFlow { +std::string format_as(AdamOptimizerAttrs const &); +std::ostream &operator<<(std::ostream &, AdamOptimizerAttrs const &); +} // namespace FlexFlow + +#endif // _FLEXFLOW_LIB_PCG_INCLUDE_PCG_OPTIMIZERS_ADAM_OPTIMIZER_ATTRS_DTG_H diff --git a/lib/pcg/include/pcg/optimizers/adam_optimizer_attrs.struct.toml b/lib/pcg/include/pcg/optimizers/adam_optimizer_attrs.struct.toml new file mode 100644 index 0000000000..fd3e83cc4a --- /dev/null +++ b/lib/pcg/include/pcg/optimizers/adam_optimizer_attrs.struct.toml @@ -0,0 +1,38 @@ +namespace = "FlexFlow" +name = "AdamOptimizerAttrs" +features = [ + "eq", + "ord", + "hash", + "json", + "rapidcheck", + "fmt", +] + +[[fields]] +name = "alpha" +type = "double" + +[[fields]] +name = "beta1" +type = "double" + +[[fields]] +name = "beta2" +type = "double" + +[[fields]] +name = "weight_decay" +type = "double" + +[[fields]] +name = "alpha_t" +type = "double" + +[[fields]] +name = "beta_t" +type = "double" + +[[fields]] +name = "beta2_t" +type = "double" diff --git a/lib/pcg/include/pcg/optimizers/sgd_optimizer_attrs.dtg.h b/lib/pcg/include/pcg/optimizers/sgd_optimizer_attrs.dtg.h new file mode 100644 index 0000000000..f6a17f2354 --- /dev/null +++
b/lib/pcg/include/pcg/optimizers/sgd_optimizer_attrs.dtg.h @@ -0,0 +1,68 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/pcg/include/pcg/optimizers/sgd_optimizer_attrs.struct.toml +/* proj-data +{ + "generated_from": "d18c91cdddc760f1fb3990d2c817ee87" +} +*/ + +#ifndef _FLEXFLOW_LIB_PCG_INCLUDE_PCG_OPTIMIZERS_SGD_OPTIMIZER_ATTRS_DTG_H +#define _FLEXFLOW_LIB_PCG_INCLUDE_PCG_OPTIMIZERS_SGD_OPTIMIZER_ATTRS_DTG_H + +#include "fmt/format.h" +#include "nlohmann/json.hpp" +#include "rapidcheck.h" +#include +#include +#include + +namespace FlexFlow { +struct SGDOptimizerAttrs { + SGDOptimizerAttrs() = delete; + SGDOptimizerAttrs(double const &lr, + double const &momentum, + bool const &nesterov, + double const &weight_decay); + + bool operator==(SGDOptimizerAttrs const &) const; + bool operator!=(SGDOptimizerAttrs const &) const; + bool operator<(SGDOptimizerAttrs const &) const; + bool operator>(SGDOptimizerAttrs const &) const; + bool operator<=(SGDOptimizerAttrs const &) const; + bool operator>=(SGDOptimizerAttrs const &) const; + double lr; + double momentum; + bool nesterov; + double weight_decay; +}; +} // namespace FlexFlow + +namespace std { +template <> +struct hash { + size_t operator()(FlexFlow::SGDOptimizerAttrs const &) const; +}; +} // namespace std + +namespace nlohmann { +template <> +struct adl_serializer { + static FlexFlow::SGDOptimizerAttrs from_json(json const &); + static void to_json(json &, FlexFlow::SGDOptimizerAttrs const &); +}; +} // namespace nlohmann + +namespace rc { +template <> +struct Arbitrary { + static Gen arbitrary(); +}; +} // namespace rc + +namespace FlexFlow { +std::string format_as(SGDOptimizerAttrs const &); +std::ostream &operator<<(std::ostream &, SGDOptimizerAttrs const &); +} // namespace FlexFlow + +#endif // _FLEXFLOW_LIB_PCG_INCLUDE_PCG_OPTIMIZERS_SGD_OPTIMIZER_ATTRS_DTG_H diff --git a/lib/pcg/include/pcg/optimizers/sgd_optimizer_attrs.struct.toml b/lib/pcg/include/pcg/optimizers/sgd_optimizer_attrs.struct.toml new file mode 100644 index 0000000000..37affb0e1f --- /dev/null +++ b/lib/pcg/include/pcg/optimizers/sgd_optimizer_attrs.struct.toml @@ -0,0 +1,26 @@ +namespace = "FlexFlow" +name = "SGDOptimizerAttrs" +features = [ + "eq", + "ord", + "hash", + "json", + "rapidcheck", + "fmt", +] + +[[fields]] +name = "lr" +type = "double" + +[[fields]] +name = "momentum" +type = "double" + +[[fields]] +name = "nesterov" +type = "bool" + +[[fields]] +name = "weight_decay" +type = "double" diff --git a/lib/pcg/include/pcg/parallel_computation_graph.dtg.h b/lib/pcg/include/pcg/parallel_computation_graph.dtg.h new file mode 100644 index 0000000000..01fbb7d30c --- /dev/null +++ b/lib/pcg/include/pcg/parallel_computation_graph.dtg.h @@ -0,0 +1,31 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! 
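Taken together, the optimizer change replaces the visitable SGDOptimizer/AdamOptimizer structs with the generated SGDOptimizerAttrs and AdamOptimizerAttrs above, joined in the OptimizerAttrs variant from optimizer_attrs.h. A minimal construction sketch (argument order follows the generated constructor; the numeric values are placeholders):

  #include "pcg/optimizer_attrs.h"
  #include <variant>

  FlexFlow::OptimizerAttrs make_sgd_attrs() {
    // (lr, momentum, nesterov, weight_decay)
    FlexFlow::SGDOptimizerAttrs sgd{0.01, 0.9, true, 1e-4};
    return FlexFlow::OptimizerAttrs{sgd};
  }

  bool uses_sgd(FlexFlow::OptimizerAttrs const &o) {
    return std::holds_alternative<FlexFlow::SGDOptimizerAttrs>(o);
  }
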
+// If you would like to modify this datatype, instead modify +// lib/pcg/include/pcg/parallel_computation_graph.struct.toml +/* proj-data +{ + "generated_from": "e4db0f603f7b8947dda13e01f96c40fb" +} +*/ + +#ifndef _FLEXFLOW_LIB_PCG_INCLUDE_PCG_PARALLEL_COMPUTATION_GRAPH_DTG_H +#define _FLEXFLOW_LIB_PCG_INCLUDE_PCG_PARALLEL_COMPUTATION_GRAPH_DTG_H + +#include "pcg/dataflow_graph.h" +#include "pcg/parallel_layer_attrs.dtg.h" +#include "pcg/parallel_tensor_attrs.dtg.h" + +namespace FlexFlow { +struct ParallelComputationGraph { + ParallelComputationGraph() = delete; + ParallelComputationGraph( + ::FlexFlow::DataflowGraph<::FlexFlow::ParallelLayerAttrs, + ::FlexFlow::ParallelTensorAttrs> const + &raw_graph); + + ::FlexFlow::DataflowGraph<::FlexFlow::ParallelLayerAttrs, + ::FlexFlow::ParallelTensorAttrs> + raw_graph; +}; +} // namespace FlexFlow + +#endif // _FLEXFLOW_LIB_PCG_INCLUDE_PCG_PARALLEL_COMPUTATION_GRAPH_DTG_H diff --git a/lib/pcg/include/pcg/parallel_computation_graph.h b/lib/pcg/include/pcg/parallel_computation_graph.h index 39a69a80ab..9d7103f4fd 100644 --- a/lib/pcg/include/pcg/parallel_computation_graph.h +++ b/lib/pcg/include/pcg/parallel_computation_graph.h @@ -1,31 +1,8 @@ #ifndef _FLEXFLOW_PCG_INCLUDE_PCG_PARALLEL_COMPUTATION_GRAPH_H #define _FLEXFLOW_PCG_INCLUDE_PCG_PARALLEL_COMPUTATION_GRAPH_H -#include "operator.h" -#include "parallel_tensor.h" -#include "utils/graph.h" +#include "pcg/parallel_computation_graph_t.h" -namespace FlexFlow { - -struct ParallelComputationGraph - : public strong_typedef< - ParallelComputationGraph, - OutputLabelledMultiDiGraph> { - using strong_typedef::strong_typedef; -}; -CHECK_WELL_BEHAVED_VALUE_TYPE_NO_HASH(ParallelComputationGraph); - -bool operator==(ParallelComputationGraph const &, - ParallelComputationGraph const &); - -} // namespace FlexFlow - -namespace std { - -template <> -struct hash { - size_t operator()(FlexFlow::ParallelComputationGraph const &g) const; -}; -} // namespace std +namespace FlexFlow {} #endif diff --git a/lib/pcg/include/pcg/parallel_computation_graph.struct.toml b/lib/pcg/include/pcg/parallel_computation_graph.struct.toml new file mode 100644 index 0000000000..d4e305abe5 --- /dev/null +++ b/lib/pcg/include/pcg/parallel_computation_graph.struct.toml @@ -0,0 +1,13 @@ +namespace = "FlexFlow" +name = "ParallelComputationGraph" +features = [ ] + +includes = [ + "pcg/dataflow_graph.h", + "pcg/parallel_tensor_attrs.dtg.h", + "pcg/parallel_layer_attrs.dtg.h", +] + +[[fields]] +name = "raw_graph" +type = "::FlexFlow::DataflowGraph<::FlexFlow::ParallelLayerAttrs, ::FlexFlow::ParallelTensorAttrs>" diff --git a/lib/pcg/include/pcg/parallel_layer_attrs.dtg.h b/lib/pcg/include/pcg/parallel_layer_attrs.dtg.h new file mode 100644 index 0000000000..4c7fce4038 --- /dev/null +++ b/lib/pcg/include/pcg/parallel_layer_attrs.dtg.h @@ -0,0 +1,60 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! 
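With this change ParallelComputationGraph stops being a strong_typedef over an OutputLabelledMultiDiGraph and becomes a plain generated wrapper whose single public member, raw_graph, is the labelled DataflowGraph. Its features list is empty, so the hand-rolled operator== and std::hash specialization (removed later in this patch) get no generated replacement. Call sites move from the .value() accessor to the field, roughly:

  #include "pcg/parallel_computation_graph.dtg.h"

  void access_sketch(FlexFlow::ParallelComputationGraph const &pcg) {
    // before: auto const &g = pcg.value();
    auto const &g = pcg.raw_graph;  // DataflowGraph<ParallelLayerAttrs, ParallelTensorAttrs>
    (void)g;
  }
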
+// If you would like to modify this datatype, instead modify +// lib/pcg/include/pcg/parallel_layer_attrs.struct.toml +/* proj-data +{ + "generated_from": "97fa0b11c59ae892a8a530ffd67e33ad" +} +*/ + +#ifndef _FLEXFLOW_LIB_PCG_INCLUDE_PCG_PARALLEL_LAYER_ATTRS_DTG_H +#define _FLEXFLOW_LIB_PCG_INCLUDE_PCG_PARALLEL_LAYER_ATTRS_DTG_H + +#include "fmt/format.h" +#include "nlohmann/json.hpp" +#include "op-attrs/operator_attrs.h" +#include "utils/stack_string.h" +#include +#include +#include +#include + +namespace FlexFlow { +struct ParallelLayerAttrs { + ParallelLayerAttrs() = delete; + ParallelLayerAttrs( + ::FlexFlow::PCGOperatorAttrs const &attrs, + std::optional<::FlexFlow::stack_string> const &name); + + bool operator==(ParallelLayerAttrs const &) const; + bool operator!=(ParallelLayerAttrs const &) const; + bool operator<(ParallelLayerAttrs const &) const; + bool operator>(ParallelLayerAttrs const &) const; + bool operator<=(ParallelLayerAttrs const &) const; + bool operator>=(ParallelLayerAttrs const &) const; + ::FlexFlow::PCGOperatorAttrs attrs; + std::optional<::FlexFlow::stack_string> name; +}; +} // namespace FlexFlow + +namespace std { +template <> +struct hash { + size_t operator()(FlexFlow::ParallelLayerAttrs const &) const; +}; +} // namespace std + +namespace nlohmann { +template <> +struct adl_serializer { + static FlexFlow::ParallelLayerAttrs from_json(json const &); + static void to_json(json &, FlexFlow::ParallelLayerAttrs const &); +}; +} // namespace nlohmann + +namespace FlexFlow { +std::string format_as(ParallelLayerAttrs const &); +std::ostream &operator<<(std::ostream &, ParallelLayerAttrs const &); +} // namespace FlexFlow + +#endif // _FLEXFLOW_LIB_PCG_INCLUDE_PCG_PARALLEL_LAYER_ATTRS_DTG_H diff --git a/lib/pcg/include/pcg/parallel_layer_attrs.struct.toml b/lib/pcg/include/pcg/parallel_layer_attrs.struct.toml new file mode 100644 index 0000000000..9b1f8f47aa --- /dev/null +++ b/lib/pcg/include/pcg/parallel_layer_attrs.struct.toml @@ -0,0 +1,24 @@ +namespace = "FlexFlow" +name = "ParallelLayerAttrs" +features = [ + "eq", + "ord", + "hash", + "json", + # "rapidcheck", + "fmt", +] + +includes = [ + "op-attrs/operator_attrs.h", + "utils/stack_string.h", + "", +] + +[[fields]] +name = "attrs" +type = "::FlexFlow::PCGOperatorAttrs" + +[[fields]] +name = "name" +type = "std::optional<::FlexFlow::stack_string>" diff --git a/lib/pcg/include/pcg/parallel_tensor.h b/lib/pcg/include/pcg/parallel_tensor.h index 652b408c15..de41e0fb21 100644 --- a/lib/pcg/include/pcg/parallel_tensor.h +++ b/lib/pcg/include/pcg/parallel_tensor.h @@ -21,56 +21,12 @@ #ifndef _FLEXFLOW_PCG_INCLUDE_PCG_PARALLEL_TENSOR_H #define _FLEXFLOW_PCG_INCLUDE_PCG_PARALLEL_TENSOR_H -#include "create_grad.h" -#include "initializer.h" -#include "op-attrs/parallel_tensor_shape.h" -#include "op-attrs/param_sync.h" +#include "pcg/parallel_tensor_attrs.h" -namespace FlexFlow { - -/** - * @brief Base structure of the parallel tensor representation. - * - * @details Parallel tensor is the fundamental component to support the - * representation and exploration of parallelization strategies. 
- */ -struct ParallelTensor : public use_visitable_cmp { - ParallelTensor() = delete; - - ParallelTensor(ParallelTensorShape const &, - CreateGrad create_gradients, - std::optional sync_type = std::nullopt, - std::optional initializer = std::nullopt); - ParallelTensor(ParallelTensorDims const &, - DataType, - CreateGrad create_gradients, - std::optional sync_type = std::nullopt, - std::optional initializer = std::nullopt); - - ParallelTensorShape get_shape() const; - -public: - ParallelTensorDims dims; - DataType data_type; - std::optional sync_type = std::nullopt; - std::optional initializer = std::nullopt; - CreateGrad create_gradients; -}; - -using ParallelParameter = ParallelTensor; - -} // namespace FlexFlow - -VISITABLE_STRUCT(::FlexFlow::ParallelTensor, - dims, - data_type, - sync_type, - initializer, - create_gradients); -MAKE_VISIT_HASHABLE(::FlexFlow::ParallelTensor); +namespace FlexFlow {} // namespace FlexFlow namespace FlexFlow { -static_assert(is_well_behaved_value_type::value, ""); +static_assert(is_well_behaved_value_type::value, ""); } #endif diff --git a/lib/pcg/include/pcg/parallel_tensor_attrs.dtg.h b/lib/pcg/include/pcg/parallel_tensor_attrs.dtg.h new file mode 100644 index 0000000000..fa6b153b0a --- /dev/null +++ b/lib/pcg/include/pcg/parallel_tensor_attrs.dtg.h @@ -0,0 +1,66 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/pcg/include/pcg/parallel_tensor_attrs.struct.toml +/* proj-data +{ + "generated_from": "b3e086b380bbc41d99332e1463a34b28" +} +*/ + +#ifndef _FLEXFLOW_LIB_PCG_INCLUDE_PCG_PARALLEL_TENSOR_ATTRS_DTG_H +#define _FLEXFLOW_LIB_PCG_INCLUDE_PCG_PARALLEL_TENSOR_ATTRS_DTG_H + +#include "fmt/format.h" +#include "nlohmann/json.hpp" +#include "op-attrs/parallel_tensor_shape.dtg.h" +#include "op-attrs/param_sync.dtg.h" +#include "pcg/create_grad.dtg.h" +#include "pcg/initializer_attrs.dtg.h" +#include +#include +#include +#include + +namespace FlexFlow { +struct ParallelTensorAttrs { + ParallelTensorAttrs() = delete; + ParallelTensorAttrs( + ::FlexFlow::ParallelTensorShape const &shape, + std::optional<::FlexFlow::ParamSync> const &sync_type, + std::optional<::FlexFlow::InitializerAttrs> const &initializer, + ::FlexFlow::CreateGrad const &create_gradients); + + bool operator==(ParallelTensorAttrs const &) const; + bool operator!=(ParallelTensorAttrs const &) const; + bool operator<(ParallelTensorAttrs const &) const; + bool operator>(ParallelTensorAttrs const &) const; + bool operator<=(ParallelTensorAttrs const &) const; + bool operator>=(ParallelTensorAttrs const &) const; + ::FlexFlow::ParallelTensorShape shape; + std::optional<::FlexFlow::ParamSync> sync_type; + std::optional<::FlexFlow::InitializerAttrs> initializer; + ::FlexFlow::CreateGrad create_gradients; +}; +} // namespace FlexFlow + +namespace std { +template <> +struct hash { + size_t operator()(FlexFlow::ParallelTensorAttrs const &) const; +}; +} // namespace std + +namespace nlohmann { +template <> +struct adl_serializer { + static FlexFlow::ParallelTensorAttrs from_json(json const &); + static void to_json(json &, FlexFlow::ParallelTensorAttrs const &); +}; +} // namespace nlohmann + +namespace FlexFlow { +std::string format_as(ParallelTensorAttrs const &); +std::ostream &operator<<(std::ostream &, ParallelTensorAttrs const &); +} // namespace FlexFlow + +#endif // _FLEXFLOW_LIB_PCG_INCLUDE_PCG_PARALLEL_TENSOR_ATTRS_DTG_H diff --git a/lib/pcg/include/pcg/parallel_tensor_attrs.struct.toml 
b/lib/pcg/include/pcg/parallel_tensor_attrs.struct.toml new file mode 100644 index 0000000000..1f81b56ec8 --- /dev/null +++ b/lib/pcg/include/pcg/parallel_tensor_attrs.struct.toml @@ -0,0 +1,34 @@ +namespace = "FlexFlow" +name = "ParallelTensorAttrs" +features = [ + "eq", + "ord", + "hash", + "json", + # "rapidcheck", + "fmt", +] + +includes = [ + "op-attrs/parallel_tensor_shape.dtg.h", + "op-attrs/param_sync.dtg.h", + "pcg/initializer_attrs.dtg.h", + "pcg/create_grad.dtg.h", + "", +] + +[[fields]] +name = "shape" +type = "::FlexFlow::ParallelTensorShape" + +[[fields]] +name = "sync_type" +type = "std::optional<::FlexFlow::ParamSync>" + +[[fields]] +name = "initializer" +type = "std::optional<::FlexFlow::InitializerAttrs>" + +[[fields]] +name = "create_gradients" +type = "::FlexFlow::CreateGrad" diff --git a/lib/pcg/include/pcg/serialization.h b/lib/pcg/include/pcg/serialization.h deleted file mode 100644 index 28e16aeb1e..0000000000 --- a/lib/pcg/include/pcg/serialization.h +++ /dev/null @@ -1,61 +0,0 @@ -#ifndef _FLEXFLOW_PCG_INCLUDE_PCG_SERIALIZATION_H -#define _FLEXFLOW_PCG_INCLUDE_PCG_SERIALIZATION_H - -#include "computation_graph.h" -#include "layer.h" -#include "machine_specification.h" -#include "parallel_computation_graph.h" -#include "parallel_tensor.h" -#include "tensor_mapping.h" -#include "utils/json.h" - -namespace FlexFlow { - -void from_json(json const &, ComputationGraph &); -void to_json(json &, ComputationGraph const &); - -void from_json(json const &, ParallelComputationGraph &); -void to_json(json &, ParallelComputationGraph const &); - -void from_json(json const &, Layer &); -void to_json(json &, Layer const &); - -void from_json(json const &, ParallelTensor &); -void to_json(json &, ParallelTensor const &); - -void from_json(json const &, Tensor &); -void to_json(json &, Tensor const &); - -void from_json(json const &, Initializer &); -void to_json(json &, Initializer const &); - -void from_json(json const &, MachineSpecification &); -void to_json(json &, MachineSpecification const &); - -void from_json(json const &, Operator &); -void to_json(json &, Operator const &); - -void from_json(json const &, MachineView &); -void to_json(json &, MachineView const &); - -void from_json(json const &, StridedRectangle &); -void to_json(json &, StridedRectangle const &); - -void from_json(json const &, StridedRectangleSide &); -void to_json(json &, StridedRectangleSide const &); - -void from_json(json const &, ParallelTensorDims &); -void to_json(json &, ParallelTensorDims const &); - -void from_json(json const &, TensorDims &); -void to_json(json &, TensorDims const &); - -void from_json(json const &, TensorMapping &); -void to_json(json &, TensorMapping const &); - -void from_json(json const &, ParallelTensorShape &); -void to_json(json &, ParallelTensorShape const &); - -} // namespace FlexFlow - -#endif diff --git a/lib/pcg/include/pcg/side_size_t.dtg.h b/lib/pcg/include/pcg/side_size_t.dtg.h new file mode 100644 index 0000000000..fce31b1c9d --- /dev/null +++ b/lib/pcg/include/pcg/side_size_t.dtg.h @@ -0,0 +1,62 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! 
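The deleted serialization.h declared pairwise to_json/from_json overloads by hand; the generated headers cover the same types through nlohmann::adl_serializer specializations (declared for ParallelTensorAttrs above). Conversion then goes through nlohmann's standard entry points. A round-trip sketch, assuming a ParallelTensorAttrs value is in scope:

  #include "pcg/parallel_tensor_attrs.dtg.h"
  #include "nlohmann/json.hpp"
  #include <cassert>

  void json_round_trip(FlexFlow::ParallelTensorAttrs const &attrs) {
    nlohmann::json j = attrs;                                // adl_serializer::to_json
    auto restored = j.get<FlexFlow::ParallelTensorAttrs>();  // adl_serializer::from_json
    assert(restored == attrs);                               // operator== from the "eq" feature
  }
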
+// If you would like to modify this datatype, instead modify +// lib/pcg/include/pcg/side_size_t.struct.toml +/* proj-data +{ + "generated_from": "6a1669890e547dcc7a4ddb90be05be15" +} +*/ + +#ifndef _FLEXFLOW_LIB_PCG_INCLUDE_PCG_SIDE_SIZE_T_DTG_H +#define _FLEXFLOW_LIB_PCG_INCLUDE_PCG_SIDE_SIZE_T_DTG_H + +#include "fmt/format.h" +#include "nlohmann/json.hpp" +#include "rapidcheck.h" +#include +#include +#include + +namespace FlexFlow { +struct side_size_t { + side_size_t() = delete; + side_size_t(int const &unwrapped); + + bool operator==(side_size_t const &) const; + bool operator!=(side_size_t const &) const; + bool operator<(side_size_t const &) const; + bool operator>(side_size_t const &) const; + bool operator<=(side_size_t const &) const; + bool operator>=(side_size_t const &) const; + int unwrapped; +}; +} // namespace FlexFlow + +namespace std { +template <> +struct hash { + size_t operator()(FlexFlow::side_size_t const &) const; +}; +} // namespace std + +namespace nlohmann { +template <> +struct adl_serializer { + static FlexFlow::side_size_t from_json(json const &); + static void to_json(json &, FlexFlow::side_size_t const &); +}; +} // namespace nlohmann + +namespace rc { +template <> +struct Arbitrary { + static Gen arbitrary(); +}; +} // namespace rc + +namespace FlexFlow { +std::string format_as(side_size_t const &); +std::ostream &operator<<(std::ostream &, side_size_t const &); +} // namespace FlexFlow + +#endif // _FLEXFLOW_LIB_PCG_INCLUDE_PCG_SIDE_SIZE_T_DTG_H diff --git a/lib/pcg/include/pcg/side_size_t.struct.toml b/lib/pcg/include/pcg/side_size_t.struct.toml new file mode 100644 index 0000000000..dbaad4fedb --- /dev/null +++ b/lib/pcg/include/pcg/side_size_t.struct.toml @@ -0,0 +1,14 @@ +namespace = "FlexFlow" +name = "side_size_t" +features = [ + "eq", + "ord", + "hash", + "json", + "rapidcheck", + "fmt", +] + +[[fields]] +name = "unwrapped" +type = "int" diff --git a/lib/pcg/include/pcg/strided_rectangle.dtg.h b/lib/pcg/include/pcg/strided_rectangle.dtg.h new file mode 100644 index 0000000000..df6a16a0ad --- /dev/null +++ b/lib/pcg/include/pcg/strided_rectangle.dtg.h @@ -0,0 +1,65 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! 
+// If you would like to modify this datatype, instead modify +// lib/pcg/include/pcg/strided_rectangle.struct.toml +/* proj-data +{ + "generated_from": "817bbe017d179aa469822a4032d08836" +} +*/ + +#ifndef _FLEXFLOW_LIB_PCG_INCLUDE_PCG_STRIDED_RECTANGLE_DTG_H +#define _FLEXFLOW_LIB_PCG_INCLUDE_PCG_STRIDED_RECTANGLE_DTG_H + +#include "fmt/format.h" +#include "nlohmann/json.hpp" +#include "op-attrs/dim_ordered.h" +#include "pcg/strided_rectangle_side.dtg.h" +#include "rapidcheck.h" +#include +#include +#include + +namespace FlexFlow { +struct StridedRectangle { + StridedRectangle() = delete; + StridedRectangle( + ::FlexFlow::FFOrdered<::FlexFlow::StridedRectangleSide> const &sides); + + bool operator==(StridedRectangle const &) const; + bool operator!=(StridedRectangle const &) const; + bool operator<(StridedRectangle const &) const; + bool operator>(StridedRectangle const &) const; + bool operator<=(StridedRectangle const &) const; + bool operator>=(StridedRectangle const &) const; + ::FlexFlow::FFOrdered<::FlexFlow::StridedRectangleSide> sides; +}; +} // namespace FlexFlow + +namespace std { +template <> +struct hash { + size_t operator()(FlexFlow::StridedRectangle const &) const; +}; +} // namespace std + +namespace nlohmann { +template <> +struct adl_serializer { + static FlexFlow::StridedRectangle from_json(json const &); + static void to_json(json &, FlexFlow::StridedRectangle const &); +}; +} // namespace nlohmann + +namespace rc { +template <> +struct Arbitrary { + static Gen arbitrary(); +}; +} // namespace rc + +namespace FlexFlow { +std::string format_as(StridedRectangle const &); +std::ostream &operator<<(std::ostream &, StridedRectangle const &); +} // namespace FlexFlow + +#endif // _FLEXFLOW_LIB_PCG_INCLUDE_PCG_STRIDED_RECTANGLE_DTG_H diff --git a/lib/pcg/include/pcg/strided_rectangle.h b/lib/pcg/include/pcg/strided_rectangle.h index d123d7c6ac..24ae51ac41 100644 --- a/lib/pcg/include/pcg/strided_rectangle.h +++ b/lib/pcg/include/pcg/strided_rectangle.h @@ -1,62 +1,16 @@ #ifndef _FLEXFLOW_PCG_INCLUDE_PCG_STRIDED_RECTANGLE_H #define _FLEXFLOW_PCG_INCLUDE_PCG_STRIDED_RECTANGLE_H -#include "op-attrs/dim_ordered.h" -#include "op-attrs/ff_dim.h" -#include "utils/stack_vector.h" -#include "utils/strong_typedef.h" -#include "utils/visitable.h" +#include "op-attrs/ff_dim.dtg.h" +#include "pcg/side_size_t.dtg.h" +#include "pcg/strided_rectangle.dtg.h" namespace FlexFlow { -struct num_points_t : public strong_typedef { - using strong_typedef::strong_typedef; -}; - -struct side_size_t : public strong_typedef { - using strong_typedef::strong_typedef; -}; - -struct StridedRectangleSide { -public: - StridedRectangleSide() = delete; - StridedRectangleSide(num_points_t const &, int stride); - StridedRectangleSide(side_size_t const &, int stride); - - num_points_t get_num_points() const; - side_size_t get_size() const; - int get_stride() const; - - side_size_t at(num_points_t) const; - num_points_t at(side_size_t) const; - -public: - num_points_t num_points; - req stride; -}; - -FF_VISITABLE_STRUCT_NONSTANDARD_CONSTRUCTION(StridedRectangleSide, - num_points, - stride); - -struct StridedRectangle { -public: - size_t at(FFOrdered const &) const; - StridedRectangleSide at(ff_dim_t const &) const; - size_t num_dims() const; - -public: - FFOrdered sides; -}; - -FF_VISITABLE_STRUCT(StridedRectangle, sides); +size_t get_num_dims(StridedRectangle const &); +StridedRectangleSide get_side_at_idx(StridedRectangle const &, + ff_dim_t const &); } // namespace FlexFlow 
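strided_rectangle.h follows the same migration pattern as machine_view.h earlier in this patch: member functions on the hand-written struct become free functions over the generated one. Roughly, with hypothetical rect and dim values:

  // before (members on the deleted hand-written struct):
  //   size_t n = rect.num_dims();
  //   StridedRectangleSide side = rect.at(dim);
  // after (free functions declared above):
  size_t n = FlexFlow::get_num_dims(rect);
  FlexFlow::StridedRectangleSide side = FlexFlow::get_side_at_idx(rect, dim);
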
-MAKE_TYPEDEF_HASHABLE(::FlexFlow::num_points_t); -MAKE_TYPEDEF_PRINTABLE(::FlexFlow::num_points_t, "num_points"); - -MAKE_TYPEDEF_HASHABLE(::FlexFlow::side_size_t); -MAKE_TYPEDEF_PRINTABLE(::FlexFlow::side_size_t, "side_size"); - #endif diff --git a/lib/pcg/include/pcg/strided_rectangle.struct.toml b/lib/pcg/include/pcg/strided_rectangle.struct.toml new file mode 100644 index 0000000000..3dfd90e296 --- /dev/null +++ b/lib/pcg/include/pcg/strided_rectangle.struct.toml @@ -0,0 +1,19 @@ +namespace = "FlexFlow" +name = "StridedRectangle" +features = [ + "eq", + "ord", + "hash", + "json", + "rapidcheck", + "fmt", +] + +includes = [ + "pcg/strided_rectangle_side.dtg.h", + "op-attrs/dim_ordered.h", +] + +[[fields]] +name = "sides" +type = "::FlexFlow::FFOrdered<::FlexFlow::StridedRectangleSide>" diff --git a/lib/pcg/include/pcg/strided_rectangle_side.dtg.h b/lib/pcg/include/pcg/strided_rectangle_side.dtg.h new file mode 100644 index 0000000000..3e4365c24d --- /dev/null +++ b/lib/pcg/include/pcg/strided_rectangle_side.dtg.h @@ -0,0 +1,65 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/pcg/include/pcg/strided_rectangle_side.struct.toml +/* proj-data +{ + "generated_from": "b14fcf1e28c262d22b92fac691ede3d4" +} +*/ + +#ifndef _FLEXFLOW_LIB_PCG_INCLUDE_PCG_STRIDED_RECTANGLE_SIDE_DTG_H +#define _FLEXFLOW_LIB_PCG_INCLUDE_PCG_STRIDED_RECTANGLE_SIDE_DTG_H + +#include "fmt/format.h" +#include "nlohmann/json.hpp" +#include "pcg/num_points_t.dtg.h" +#include "rapidcheck.h" +#include +#include +#include + +namespace FlexFlow { +struct StridedRectangleSide { + StridedRectangleSide() = delete; + StridedRectangleSide(::FlexFlow::num_points_t const &num_points, + int const &stride); + + bool operator==(StridedRectangleSide const &) const; + bool operator!=(StridedRectangleSide const &) const; + bool operator<(StridedRectangleSide const &) const; + bool operator>(StridedRectangleSide const &) const; + bool operator<=(StridedRectangleSide const &) const; + bool operator>=(StridedRectangleSide const &) const; + ::FlexFlow::num_points_t num_points; + int stride; +}; +} // namespace FlexFlow + +namespace std { +template <> +struct hash { + size_t operator()(FlexFlow::StridedRectangleSide const &) const; +}; +} // namespace std + +namespace nlohmann { +template <> +struct adl_serializer { + static FlexFlow::StridedRectangleSide from_json(json const &); + static void to_json(json &, FlexFlow::StridedRectangleSide const &); +}; +} // namespace nlohmann + +namespace rc { +template <> +struct Arbitrary { + static Gen arbitrary(); +}; +} // namespace rc + +namespace FlexFlow { +std::string format_as(StridedRectangleSide const &); +std::ostream &operator<<(std::ostream &, StridedRectangleSide const &); +} // namespace FlexFlow + +#endif // _FLEXFLOW_LIB_PCG_INCLUDE_PCG_STRIDED_RECTANGLE_SIDE_DTG_H diff --git a/lib/pcg/include/pcg/strided_rectangle_side.h b/lib/pcg/include/pcg/strided_rectangle_side.h new file mode 100644 index 0000000000..1486b73143 --- /dev/null +++ b/lib/pcg/include/pcg/strided_rectangle_side.h @@ -0,0 +1,15 @@ +#ifndef _FLEXFLOW_LIB_PCG_INCLUDE_PCG_STRIDED_RECTANGLE_SIDE_H +#define _FLEXFLOW_LIB_PCG_INCLUDE_PCG_STRIDED_RECTANGLE_SIDE_H + +#include "pcg/side_size_t.dtg.h" +#include "pcg/strided_rectangle_side.dtg.h" + +namespace FlexFlow { + +StridedRectangleSide strided_side_from_size_and_stride(side_size_t, int stride); + +side_size_t get_side_size(StridedRectangleSide const &); + +} // namespace FlexFlow + +#endif diff 
--git a/lib/pcg/include/pcg/strided_rectangle_side.struct.toml b/lib/pcg/include/pcg/strided_rectangle_side.struct.toml new file mode 100644 index 0000000000..f26adfafd5 --- /dev/null +++ b/lib/pcg/include/pcg/strided_rectangle_side.struct.toml @@ -0,0 +1,22 @@ +namespace = "FlexFlow" +name = "StridedRectangleSide" +features = [ + "eq", + "ord", + "hash", + "json", + "rapidcheck", + "fmt", +] + +includes = [ + "pcg/num_points_t.dtg.h", +] + +[[fields]] +name = "num_points" +type = "::FlexFlow::num_points_t" + +[[fields]] +name = "stride" +type = "int" diff --git a/lib/pcg/include/pcg/tensor.h b/lib/pcg/include/pcg/tensor.h deleted file mode 100644 index b5ff857a6c..0000000000 --- a/lib/pcg/include/pcg/tensor.h +++ /dev/null @@ -1,40 +0,0 @@ -#ifndef _FLEXFLOW_PCG_INCLUDE_PCG_TENSOR_H -#define _FLEXFLOW_PCG_INCLUDE_PCG_TENSOR_H - -#include "create_grad.h" -#include "initializer.h" -#include "op-attrs/param_sync.h" -#include "op-attrs/tensor_shape.h" - -namespace FlexFlow { - -struct Tensor { - /* Tensor() = delete; */ - /* Tensor(TensorShape const &, */ - /* CreateGrad create_gradients, */ - /* optional initializer = nullopt, */ - /* optional sync_type = nullopt); */ - - size_t get_volume() const; - TensorShape get_shape() const; - int num_dims() const; - - operator TensorShape() const; - -public: - TensorDims dims; - DataType data_type; - std::optional initializer; - bool create_gradients; - req> sync_type; -}; -FF_VISITABLE_STRUCT( - Tensor, dims, data_type, initializer, create_gradients, sync_type); - -using Parameter = Tensor; - -Tensor construct_tensor_from_output_shape(TensorShape const &); - -} // namespace FlexFlow - -#endif diff --git a/lib/pcg/include/pcg/tensor_attrs.dtg.h b/lib/pcg/include/pcg/tensor_attrs.dtg.h new file mode 100644 index 0000000000..8bc9d3ce9d --- /dev/null +++ b/lib/pcg/include/pcg/tensor_attrs.dtg.h @@ -0,0 +1,64 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! 
+// If you would like to modify this datatype, instead modify +// lib/pcg/include/pcg/tensor_attrs.struct.toml +/* proj-data +{ + "generated_from": "68447a4357476647ef25dd39dfd12578" +} +*/ + +#ifndef _FLEXFLOW_LIB_PCG_INCLUDE_PCG_TENSOR_ATTRS_DTG_H +#define _FLEXFLOW_LIB_PCG_INCLUDE_PCG_TENSOR_ATTRS_DTG_H + +#include "fmt/format.h" +#include "nlohmann/json.hpp" +#include "op-attrs/param_sync.dtg.h" +#include "op-attrs/tensor_shape.dtg.h" +#include "pcg/initializer_attrs.dtg.h" +#include +#include +#include +#include + +namespace FlexFlow { +struct TensorAttrs { + TensorAttrs() = delete; + TensorAttrs(::FlexFlow::TensorShape const &shape, + std::optional<::FlexFlow::InitializerAttrs> const &initializer, + bool const &create_gradients, + std::optional<::FlexFlow::ParamSync> const &sync_type); + + bool operator==(TensorAttrs const &) const; + bool operator!=(TensorAttrs const &) const; + bool operator<(TensorAttrs const &) const; + bool operator>(TensorAttrs const &) const; + bool operator<=(TensorAttrs const &) const; + bool operator>=(TensorAttrs const &) const; + ::FlexFlow::TensorShape shape; + std::optional<::FlexFlow::InitializerAttrs> initializer; + bool create_gradients; + std::optional<::FlexFlow::ParamSync> sync_type; +}; +} // namespace FlexFlow + +namespace std { +template <> +struct hash { + size_t operator()(FlexFlow::TensorAttrs const &) const; +}; +} // namespace std + +namespace nlohmann { +template <> +struct adl_serializer { + static FlexFlow::TensorAttrs from_json(json const &); + static void to_json(json &, FlexFlow::TensorAttrs const &); +}; +} // namespace nlohmann + +namespace FlexFlow { +std::string format_as(TensorAttrs const &); +std::ostream &operator<<(std::ostream &, TensorAttrs const &); +} // namespace FlexFlow + +#endif // _FLEXFLOW_LIB_PCG_INCLUDE_PCG_TENSOR_ATTRS_DTG_H diff --git a/lib/pcg/include/pcg/tensor_attrs.struct.toml b/lib/pcg/include/pcg/tensor_attrs.struct.toml new file mode 100644 index 0000000000..eefb6da702 --- /dev/null +++ b/lib/pcg/include/pcg/tensor_attrs.struct.toml @@ -0,0 +1,33 @@ +namespace = "FlexFlow" +name = "TensorAttrs" +features = [ + "eq", + "ord", + "hash", + "json", + # "rapidcheck", + "fmt", +] + +includes = [ + "op-attrs/tensor_shape.dtg.h", + "pcg/initializer_attrs.dtg.h", + "op-attrs/param_sync.dtg.h", + "", +] + +[[fields]] +name = "shape" +type = "::FlexFlow::TensorShape" + +[[fields]] +name = "initializer" +type = "std::optional<::FlexFlow::InitializerAttrs>" + +[[fields]] +name = "create_gradients" +type = "bool" + +[[fields]] +name = "sync_type" +type = "std::optional<::FlexFlow::ParamSync>" diff --git a/lib/pcg/include/pcg/tensor_guid_t.dtg.h b/lib/pcg/include/pcg/tensor_guid_t.dtg.h new file mode 100644 index 0000000000..c6109c6103 --- /dev/null +++ b/lib/pcg/include/pcg/tensor_guid_t.dtg.h @@ -0,0 +1,46 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! 
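TensorAttrs above replaces the old Tensor struct, folding dims and data_type into a single TensorShape and dropping the member helpers (get_volume, get_shape, num_dims). A construction sketch using the generated all-fields constructor (shape is a hypothetical TensorShape value):

  #include "pcg/tensor_attrs.dtg.h"
  #include <optional>

  FlexFlow::TensorAttrs make_tensor_attrs(FlexFlow::TensorShape const &shape) {
    return FlexFlow::TensorAttrs{
        shape,
        /*initializer=*/std::nullopt,
        /*create_gradients=*/true,
        /*sync_type=*/std::nullopt};
  }
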
+// If you would like to modify this datatype, instead modify +// lib/pcg/include/pcg/tensor_guid_t.struct.toml +/* proj-data +{ + "generated_from": "dc15fcbb876ec70509dfa8b662963bc3" +} +*/ + +#ifndef _FLEXFLOW_LIB_PCG_INCLUDE_PCG_TENSOR_GUID_T_DTG_H +#define _FLEXFLOW_LIB_PCG_INCLUDE_PCG_TENSOR_GUID_T_DTG_H + +#include "fmt/format.h" +#include "utils/graph.h" +#include +#include +#include + +namespace FlexFlow { +struct tensor_guid_t { + tensor_guid_t() = delete; + tensor_guid_t(::FlexFlow::MultiDiOutput const &raw_graph_output); + + bool operator==(tensor_guid_t const &) const; + bool operator!=(tensor_guid_t const &) const; + bool operator<(tensor_guid_t const &) const; + bool operator>(tensor_guid_t const &) const; + bool operator<=(tensor_guid_t const &) const; + bool operator>=(tensor_guid_t const &) const; + ::FlexFlow::MultiDiOutput raw_graph_output; +}; +} // namespace FlexFlow + +namespace std { +template <> +struct hash { + size_t operator()(FlexFlow::tensor_guid_t const &) const; +}; +} // namespace std + +namespace FlexFlow { +std::string format_as(tensor_guid_t const &); +std::ostream &operator<<(std::ostream &, tensor_guid_t const &); +} // namespace FlexFlow + +#endif // _FLEXFLOW_LIB_PCG_INCLUDE_PCG_TENSOR_GUID_T_DTG_H diff --git a/lib/pcg/include/pcg/tensor_guid_t.h b/lib/pcg/include/pcg/tensor_guid_t.h deleted file mode 100644 index 3e4e840a5f..0000000000 --- a/lib/pcg/include/pcg/tensor_guid_t.h +++ /dev/null @@ -1,17 +0,0 @@ -#ifndef _FLEXFLOW_PCG_INCLUDE_PCG_TENSOR_GUID_T_H -#define _FLEXFLOW_PCG_INCLUDE_PCG_TENSOR_GUID_T_H - -#include "utils/graph.h" - -namespace FlexFlow { - -struct tensor_guid_t : strong_typedef { - using strong_typedef::strong_typedef; -}; - -} // namespace FlexFlow - -MAKE_TYPEDEF_PRINTABLE(::FlexFlow::tensor_guid_t, "tensor_guid"); -MAKE_TYPEDEF_HASHABLE(::FlexFlow::tensor_guid_t); - -#endif diff --git a/lib/pcg/include/pcg/tensor_guid_t.struct.toml b/lib/pcg/include/pcg/tensor_guid_t.struct.toml new file mode 100644 index 0000000000..aea4fad108 --- /dev/null +++ b/lib/pcg/include/pcg/tensor_guid_t.struct.toml @@ -0,0 +1,18 @@ +namespace = "FlexFlow" +name = "tensor_guid_t" +features = [ + "eq", + "ord", + "hash", + # "json", + # "rapidcheck", + "fmt", +] + +includes = [ + "utils/graph.h" +] + +[[fields]] +name = "raw_graph_output" +type = "::FlexFlow::MultiDiOutput" diff --git a/lib/pcg/src/computation_graph.cc b/lib/pcg/src/computation_graph.cc deleted file mode 100644 index 18fded6d3e..0000000000 --- a/lib/pcg/src/computation_graph.cc +++ /dev/null @@ -1,76 +0,0 @@ -#include "pcg/computation_graph.h" - -namespace FlexFlow { - -std::vector - traverse_comp_graph_forward(ComputationGraph const &comp_graph) { - std::vector layers = get_topological_ordering(comp_graph.value()); - return transform(layers, [&](Node const &e) -> operator_guid_t { - return operator_guid_t{e}; - }); -} - -std::vector - traverse_comp_graph_backward(ComputationGraph const &comp_graph) { - std::vector layers = - reversed>(get_topological_ordering(comp_graph.value())); - return transform(layers, [&](Node const &e) -> operator_guid_t { - return operator_guid_t{e}; - }); -} - -std::vector - sort_edge_set(std::unordered_set edges) { - return transform( - sorted_by(edges, compare_by([](MultiDiEdge const &e) { - return e.src_idx; - })), - [&](MultiDiEdge const &e) -> tensor_guid_t { return tensor_guid_t{e}; }); -} - -std::vector - get_outgoing_tensors(ComputationGraph const &comp_graph, - operator_guid_t n) { - return sort_edge_set(get_outgoing_edges(comp_graph.value(), 
n.value())); -} - -std::vector - get_incoming_tensors(ComputationGraph const &comp_graph, - operator_guid_t n) { - return sort_edge_set(get_incoming_edges(comp_graph.value(), n.value())); -} - -operator_guid_t create_node(ComputationGraph &comp_graph, Layer const &layer) { - Node added_node = comp_graph.value().add_node(layer); - return operator_guid_t{added_node}; -} - -tensor_guid_t create_outgoing_edge(ComputationGraph &comp_graph, - operator_guid_t node, - int idx, - Tensor tensor) { - MultiDiOutput edge = {node.value(), NodePort{idx}}; - comp_graph.value().add_output(edge, tensor); - return tensor_guid_t{edge}; -} - -void connect_incoming_edges(ComputationGraph &comp_graph, - std::vector const &incoming_edges, - operator_guid_t node) { - size_t incoming_edge_dst_port = 0; - for (tensor_guid_t input : incoming_edges) { - MultiDiOutput input_view = input.value(); - MultiDiEdge edge = {node.value(), - NodePort{incoming_edge_dst_port++}, - input_view.src, - input_view.src_idx}; - comp_graph.value().add_edge(edge); - } -} - -CompGraphOperatorAttrs get_layer_attrs(ComputationGraph const &comp_graph, - operator_guid_t const &n) { - return comp_graph.at(n).attrs; -} - -} // namespace FlexFlow diff --git a/lib/pcg/src/device_id.cc b/lib/pcg/src/device_id.cc deleted file mode 100644 index 2849df7c3c..0000000000 --- a/lib/pcg/src/device_id.cc +++ /dev/null @@ -1,19 +0,0 @@ -#include "pcg/device_id.h" -#include "utils/exception.h" -#include - -namespace FlexFlow { - -DeviceType get_device_type(device_id_t const &id) { - if (std::holds_alternative(id)) { - return DeviceType::GPU; - } else { - assert(std::holds_alternative(id)); - return DeviceType::CPU; - } -} - -device_id_t operator+(device_id_t, size_t) { - NOT_IMPLEMENTED(); -} -} // namespace FlexFlow diff --git a/lib/pcg/src/file_format/v1/graphs.cc b/lib/pcg/src/file_format/v1/graphs.cc index d00de7b0c1..eabd266e25 100644 --- a/lib/pcg/src/file_format/v1/graphs.cc +++ b/lib/pcg/src/file_format/v1/graphs.cc @@ -1,23 +1,39 @@ #include "pcg/file_format/v1/graphs.h" +#include "pcg/dataflow_graph.h" +#include "pcg/file_format/v1/graphs/v1_multidigraph.h" +#include "pcg/file_format/v1/graphs/v1_operator_graph.dtg.h" #include "utils/graph/algorithms.h" +#include "utils/integer_conversions.h" namespace FlexFlow { -V1MultiDiGraph to_v1(MultiDiGraphView const &g) { - return to_v1(g, - enumerate(get_nodes(g)).reversed(), - enumerate(get_present_node_ports(g)).reversed()); -} +/* static V1OperatorGraph to_v1(OperatorGraphView const &g, bidict + * const &nodes) { */ +/* std::unordered_set edges; */ +/* for (MultiDiEdge const &e : get_edges(g)) { */ +/* size_t src_node = nodes.at_l(get_src_node(e)); */ +/* size_t dst_node = nodes.at_l(get_dst_node(e)); */ +/* size_t src_idx = size_t_from_int(get_src_idx(e)); */ +/* size_t dst_idx = size_t_from_int(get_dst_idx(e)); */ +/* V1GraphEdge v1_e = {src_node, src_idx, dst_node, dst_idx}; */ +/* edges.insert(v1_e); */ +/* } */ + +/* return V1OperatorGraph{ */ +/* count(nodes.size()), */ +/* edges, */ +/* }; */ +/* } */ -V1MultiDiGraph to_v1(MultiDiGraphView const &g, - std::unordered_map const &nodes, - std::unordered_map const &node_ports) { +static V1MultiDiGraph to_v1(MultiDiGraphView const &g, + bidict const &nodes, + bidict const &node_ports) { std::unordered_set edges; for (MultiDiEdge const &e : get_edges(g)) { - edges.insert({nodes.at(e.src), - node_ports.at(e.src_idx), - nodes.at(e.dst), - node_ports.at(e.dst_idx)}); + edges.insert({nodes.at_l(e.src), + node_ports.at_l(e.src_idx), + nodes.at_l(e.dst), + 
node_ports.at_l(e.dst_idx)}); } return V1MultiDiGraph{ @@ -27,32 +43,101 @@ V1MultiDiGraph to_v1(MultiDiGraphView const &g, }; } +/* static V1MultiDiGraph to_v1(MultiDiGraphView const &g) { */ +/* return to_v1(g, */ +/* enumerate(get_nodes(g)).reversed(), */ +/* enumerate(get_present_node_ports(g)).reversed()); */ +/* } */ + +/* template */ +/* static V1JsonableGraph */ +/* to_v1(LabelledOperatorGraphView const &g) { */ + +/* bidict nodes = enumerate(get_nodes(g)); */ + +/* V1OperatorGraph unlabelled = to_v1(g, nodes.reversed()); */ +/* std::unordered_map node_labels = */ +/* map_values(nodes, [&](Node const &n) { return g.at(n); }); */ + +/* bidict outputs_bidict = + * enumerate(get_outputs(g)); */ +/* std::unordered_map outputs = */ +/* map_values(outputs_bidict, [&](OperatorGraphOutput const &o) { */ +/* return V1GraphOutput{nodes.at_r(get_node(o)), + * size_t_from_int(get_idx(o))}; */ +/* }); */ + +/* std::unordered_map output_labels = map_values( */ +/* outputs_bidict, [&](OperatorGraphOutput const &o) { return g.at(o); }); + */ + +/* return {node_labels, outputs, output_labels, unlabelled}; */ +/* } */ + template -V1JsonableGraph())), - decltype(to_v1(std::declval()))> - to_v1(OutputLabelledMultiDiGraph const &g) { - using V1NodeLabel = decltype(to_v1(std::declval())); - using V1OutputLabel = decltype(to_v1(std::declval())); +static bidict + get_ports_by_idx(DataflowGraph const &g) { + bidict result; + for (NodePort const &p : get_present_node_ports(g.get_raw_graph())) { + size_t idx = size_t_from_int(g.idx_for_port(p)); + result.equate(idx, p); + } + return result; +} + +template +static V1JsonableGraph + to_v1(DataflowGraph const &g) { + + bidict nodes = enumerate(get_nodes(g.get_raw_graph())); + bidict node_ports = get_ports_by_idx(g); + + V1MultiDiGraph unlabelled = + to_v1(g.get_raw_graph(), nodes.reversed(), node_ports.reversed()); + std::unordered_map node_labels = + map_values(nodes, [&](Node const &n) { return g.at(n); }); + bidict outputs_bidict = + enumerate(get_outputs(g.get_raw_graph())); + std::unordered_map outputs = + map_values(outputs_bidict, [&](MultiDiOutput const &o) { + return V1GraphOutput{nodes.at_r(o.src), node_ports.at_r(o.src_idx)}; + }); + + std::unordered_map output_labels = map_values( + outputs_bidict, [&](MultiDiOutput const &o) { return g.at(o); }); + + return {node_labels, outputs, output_labels, unlabelled}; +} + +template +static V1JsonableGraph + to_v1(OutputLabelledMultiDiGraphView const &g) { bidict nodes = enumerate(get_nodes(g)); bidict node_ports = enumerate(get_present_node_ports(g)); V1MultiDiGraph unlabelled = to_v1(g, nodes.reversed(), node_ports.reversed()); - std::unordered_map node_labels = - map_values(nodes, [&](Node const &n) { return to_v1(g.at(n)); }); + std::unordered_map node_labels = + map_values(nodes, [&](Node const &n) { return g.at(n); }); + bidict outputs_bidict = enumerate(get_outputs(g)); std::unordered_map outputs = map_values(outputs_bidict, [&](MultiDiOutput const &o) { return V1GraphOutput{nodes.at_r(o.src), node_ports.at_r(o.src_idx)}; }); - std::unordered_map output_labels = map_values( - outputs_bidict, [&](MultiDiOutput const &o) { return to_v1(g.at(o)); }); + + std::unordered_map output_labels = map_values( + outputs_bidict, [&](MultiDiOutput const &o) { return g.at(o); }); return {node_labels, outputs, output_labels, unlabelled}; } V1ComputationGraph to_v1(ComputationGraph const &g) { - return to_v1(g.value()); + return to_v1(g.raw_graph); +} + +V1ParallelComputationGraph to_v1(ParallelComputationGraph const 
&g) { + return to_v1(g.raw_graph); } } // namespace FlexFlow diff --git a/lib/pcg/src/file_format/v1/v1.cc b/lib/pcg/src/file_format/v1/v1.cc deleted file mode 100644 index 7715985eed..0000000000 --- a/lib/pcg/src/file_format/v1/v1.cc +++ /dev/null @@ -1,13 +0,0 @@ -#include "pcg/file_format/v1/v1.h" - -namespace FlexFlow { - -V1Tensor to_v1(Tensor const &) { - NOT_IMPLEMENTED(); -} - -V1Layer to_v1(Layer const &) { - NOT_IMPLEMENTED(); -} - -} // namespace FlexFlow diff --git a/lib/pcg/src/layer.cc b/lib/pcg/src/layer.cc deleted file mode 100644 index 00fb07a8c5..0000000000 --- a/lib/pcg/src/layer.cc +++ /dev/null @@ -1,9 +0,0 @@ -#include "pcg/layer.h" - -namespace FlexFlow { - -Layer::Layer(CompGraphOperatorAttrs const &_attrs, - std::optional const &_name) - : attrs(_attrs), name(_name) {} - -} // namespace FlexFlow diff --git a/lib/pcg/src/machine_view.cc b/lib/pcg/src/machine_view.cc deleted file mode 100644 index 46f87833f0..0000000000 --- a/lib/pcg/src/machine_view.cc +++ /dev/null @@ -1,29 +0,0 @@ -#include "pcg/machine_view.h" -#include "utils/utils.h" - -namespace FlexFlow { - -static StridedRectangle make_1d_rect(int start, int stop, int stride) { - assert(stop > start); - assert(stride > 0); - StridedRectangleSide side = {side_size_t(stop - start), stride}; - StridedRectangle rect = {{side}}; - return rect; -} - -MachineView make_1d_machine_view(gpu_id_t start, gpu_id_t stop, int stride) { - StridedRectangle rect = make_1d_rect(start.value(), stop.value(), stride); - return {start, rect}; -} - -MachineView make_1d_machine_view(cpu_id_t start, cpu_id_t stop, int stride) { - StridedRectangle rect = make_1d_rect(start.value(), stop.value(), stride); - return {start, rect}; -} - -device_id_t MachineView::at(FFOrdered const &coord) const { - size_t offset = this->rect.at(coord); - return this->start + offset; -} - -} // namespace FlexFlow diff --git a/lib/pcg/src/operator.cc b/lib/pcg/src/operator.cc deleted file mode 100644 index 9d36ae1b25..0000000000 --- a/lib/pcg/src/operator.cc +++ /dev/null @@ -1,9 +0,0 @@ -#include "pcg/operator.h" - -namespace FlexFlow { - -Operator::operator PCGOperatorAttrs() const { - return attrs; -} - -} // namespace FlexFlow diff --git a/lib/pcg/src/parallel_computation_graph.cc b/lib/pcg/src/parallel_computation_graph.cc deleted file mode 100644 index 011c40eb4c..0000000000 --- a/lib/pcg/src/parallel_computation_graph.cc +++ /dev/null @@ -1,40 +0,0 @@ -#include "pcg/parallel_computation_graph.h" -#include "utils/graph/algorithms.h" - -namespace FlexFlow { - -bool operator==(ParallelComputationGraph const &lhs, - ParallelComputationGraph const &rhs) { - return std::hash{}(lhs) == - std::hash{}(rhs); -} - -} // namespace FlexFlow - -namespace std { - -size_t hash::operator()( - FlexFlow::ParallelComputationGraph const &g) const { - using namespace FlexFlow; - - size_t h = 0; - - std::vector ordered_nodes = get_topological_ordering(g.value()); - hash_combine(h, ordered_nodes.size()); - - std::unordered_map node_index; - for (int i = 0; i < ordered_nodes.size(); ++i) { - node_index[ordered_nodes[i]] = i; - hash_combine(h, g.value().at(ordered_nodes[i])); - } - - for (MultiDiEdge const &edge : get_edges(g.value())) { - hash_combine(h, node_index.at(edge.src)); - hash_combine(h, node_index.at(edge.dst)); - hash_combine(h, g.value().at(edge)); - } - - return h; -} - -} // namespace std diff --git a/lib/pcg/src/parallel_tensor.cc b/lib/pcg/src/parallel_tensor.cc deleted file mode 100644 index ff53e456ec..0000000000 --- a/lib/pcg/src/parallel_tensor.cc +++ 
/dev/null @@ -1,17 +0,0 @@ -#include "pcg/parallel_tensor.h" - -namespace FlexFlow { - -ParallelTensor::ParallelTensor(ParallelTensorDims const &dims, - DataType data_type, - CreateGrad create_gradients, - std::optional sync_type, - std::optional initializer) - : dims(dims), data_type(data_type), sync_type(sync_type), - initializer(initializer), create_gradients(create_gradients) {} - -ParallelTensorShape ParallelTensor::get_shape() const { - return ParallelTensorShape(dims, data_type); -} - -} // namespace FlexFlow diff --git a/lib/pcg/src/pcg/computation_graph.cc b/lib/pcg/src/pcg/computation_graph.cc new file mode 100644 index 0000000000..12a72ca837 --- /dev/null +++ b/lib/pcg/src/pcg/computation_graph.cc @@ -0,0 +1,20 @@ +#include "pcg/computation_graph.h" +#include "utils/containers.h" + +namespace FlexFlow { + +ComputationGraph make_empty_computation_graph() { + return ComputationGraph{DataflowGraph{}}; +} + +std::unordered_set get_layers(ComputationGraph const &cg) { + return transform(get_nodes(cg.raw_graph), + [&](Node const &n) { return layer_guid_t{n}; }); +} + +TensorAttrs get_tensor_attrs(ComputationGraph const &cg, + tensor_guid_t const &t) { + return cg.raw_graph.at(t.raw_graph_output); +} + +} // namespace FlexFlow diff --git a/lib/pcg/src/pcg/computation_graph.dtg.cc b/lib/pcg/src/pcg/computation_graph.dtg.cc new file mode 100644 index 0000000000..bb6233a910 --- /dev/null +++ b/lib/pcg/src/pcg/computation_graph.dtg.cc @@ -0,0 +1,21 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/pcg/include/pcg/computation_graph.struct.toml +/* proj-data +{ + "generated_from": "8f1f0e13d75065944f7fe307e12fe280" +} +*/ + +#include "pcg/computation_graph.dtg.h" + +#include "pcg/dataflow_graph.h" +#include "pcg/layer_attrs.dtg.h" +#include "pcg/tensor_attrs.dtg.h" + +namespace FlexFlow { +ComputationGraph::ComputationGraph( + ::FlexFlow::DataflowGraph<::FlexFlow::LayerAttrs, + ::FlexFlow::TensorAttrs> const &raw_graph) + : raw_graph(raw_graph) {} +} // namespace FlexFlow diff --git a/lib/pcg/src/pcg/computation_graph/layer_added_result.dtg.cc b/lib/pcg/src/pcg/computation_graph/layer_added_result.dtg.cc new file mode 100644 index 0000000000..18b394f6d0 --- /dev/null +++ b/lib/pcg/src/pcg/computation_graph/layer_added_result.dtg.cc @@ -0,0 +1,43 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! 
+// If you would like to modify this datatype, instead modify +// lib/pcg/include/pcg/computation_graph/layer_added_result.struct.toml +/* proj-data +{ + "generated_from": "15bf9d73ef934599c9b11807d86ae5d4" +} +*/ + +#include "pcg/computation_graph/layer_added_result.dtg.h" + +#include "pcg/layer_guid_t.dtg.h" +#include "pcg/tensor_guid_t.dtg.h" +#include + +namespace FlexFlow { +LayerAddedResult::LayerAddedResult( + ::FlexFlow::layer_guid_t const &layer, + std::vector<::FlexFlow::tensor_guid_t> const &outputs) + : layer(layer), outputs(outputs) {} +bool LayerAddedResult::operator==(LayerAddedResult const &other) const { + return std::tie(this->layer, this->outputs) == + std::tie(other.layer, other.outputs); +} +bool LayerAddedResult::operator!=(LayerAddedResult const &other) const { + return std::tie(this->layer, this->outputs) != + std::tie(other.layer, other.outputs); +} +} // namespace FlexFlow + +namespace FlexFlow { +std::string format_as(LayerAddedResult const &x) { + std::ostringstream oss; + oss << ""; + return oss.str(); +} +std::ostream &operator<<(std::ostream &s, LayerAddedResult const &x) { + return s << fmt::to_string(x); +} +} // namespace FlexFlow diff --git a/lib/pcg/src/computation_graph_builder.cc b/lib/pcg/src/pcg/computation_graph_builder.cc similarity index 52% rename from lib/pcg/src/computation_graph_builder.cc rename to lib/pcg/src/pcg/computation_graph_builder.cc index f237232a76..8c69b3a724 100644 --- a/lib/pcg/src/computation_graph_builder.cc +++ b/lib/pcg/src/pcg/computation_graph_builder.cc @@ -1,94 +1,151 @@ #include "pcg/computation_graph_builder.h" +#include "op-attrs/computation_graph_op_attrs.h" #include "op-attrs/get_op_type.h" #include "op-attrs/get_output_shapes.h" +#include "op-attrs/ops/element_binary.h" +#include "op-attrs/ops/embedding.h" +#include "op-attrs/ops/weight_attrs.dtg.h" +#include "pcg/computation_graph.h" +#include "utils/containers.h" +#include "utils/containers/concat_vectors.h" +#include "utils/containers/enumerate_vector.h" #include "utils/expected.h" #include "utils/fmt.h" namespace FlexFlow { -tensor_guid_t ComputationGraphBuilder::add_layer( - Layer const &layer, - std::vector const &inputs, - std::vector>> const - &weight_shapes, - TensorShape const &output_shape) { - operator_guid_t node = create_node(computation_graph, layer); - connect_incoming_edges(computation_graph, inputs, node); - return create_outgoing_edge(computation_graph, - node, - 0, - construct_tensor_from_output_shape(output_shape)); +ComputationGraphBuilder::ComputationGraphBuilder() + : computation_graph(make_empty_computation_graph()) {} + +TensorShape ComputationGraphBuilder::get_shape(tensor_guid_t const &t) const { + return get_tensor_attrs(this->computation_graph, t).shape; +} + +tensor_guid_t ComputationGraphBuilder::create_tensor(TensorShape const &shape, + bool create_grad) { + TensorAttrs tensor_attrs = {shape, std::nullopt, create_grad, std::nullopt}; + LayerAttrs layer_attrs = LayerAttrs{ + ComputationGraphOpAttrs{InputAttrs{}}, + std::nullopt, + }; + + return this->add_layer(layer_attrs, {}, {}, tensor_attrs); } std::vector ComputationGraphBuilder::add_layer( - Layer const &layer, + LayerAttrs const &layer, std::vector const &inputs, - std::vector>> const - &weight_shapes, - std::vector const &output_shapes) { - operator_guid_t node = create_node(computation_graph, layer); - connect_incoming_edges(computation_graph, inputs, node); - std::vector output_tensor_guids; - for (int i = 0; i < output_shapes.size(); ++i) { - 
output_tensor_guids.push_back(create_outgoing_edge(
-        computation_graph,
-        node,
-        i,
-        construct_tensor_from_output_shape(output_shapes[i])));
+    std::vector<TensorAttrs> const &weights,
+    std::vector<TensorAttrs> const &outputs) {
+  std::vector<MultiDiOutput> raw_weight_tensors;
+  for (auto const &kv : enumerate_vector(weights)) {
+    int weight_idx = kv.first;
+    TensorAttrs weight_tensor_attrs = kv.second;
+
+    std::optional<std::string> weight_name =
+        transform(layer.name, [&](std::string const &layer_name) {
+          return fmt::format("{}.weights[{}]", layer_name, weight_idx);
+        });
+    LayerAttrs weight_layer_attrs = LayerAttrs{
+        ComputationGraphOpAttrs{WeightAttrs{}},
+        weight_name,
+    };
+    std::vector<MultiDiOutput> weight_layer_inputs = {};
+    std::vector<TensorAttrs> weight_output_attrs = {weight_tensor_attrs};
+    raw_weight_tensors.push_back(
+        get_only(this->computation_graph.raw_graph.add_operator(
+            weight_layer_attrs, weight_layer_inputs, weight_output_attrs)));
   }
-  return output_tensor_guids;
+
+  std::vector<MultiDiOutput> raw_inputs = transform(
+      inputs, [](tensor_guid_t const &t) { return t.raw_graph_output; });
+  std::vector<MultiDiOutput> raw_outputs =
+      this->computation_graph.raw_graph.add_operator(
+          layer, concat_vectors(raw_inputs, raw_weight_tensors), outputs);
+  return transform(raw_outputs,
+                   [](MultiDiOutput const &o) { return tensor_guid_t{o}; });
 }
 
-tensor_guid_t ComputationGraphBuilder::broadcast(tensor_guid_t const &,
-                                                 TensorShape const &) {
-  NOT_IMPLEMENTED();
+tensor_guid_t
+    ComputationGraphBuilder::add_layer(LayerAttrs const &layer,
+                                       std::vector<tensor_guid_t> const &inputs,
+                                       std::vector<TensorAttrs> const &weights,
+                                       TensorAttrs const &output) {
+  std::vector<TensorAttrs> outputs = {output};
+  return get_only(this->add_layer(layer, inputs, weights, outputs));
 }
+
+std::vector<tensor_guid_t> ComputationGraphBuilder::add_layer(
+    LayerAttrs const &layer,
+    std::vector<tensor_guid_t> const &inputs,
+    std::vector<TensorAttrs> const &weights,
+    std::vector<TensorShape> const &outputs) {
+  return this->add_layer(
+      layer, inputs, weights, transform(outputs, [](TensorShape const &s) {
+        return TensorAttrs{s, std::nullopt, true, std::nullopt};
+      }));
+}
+
 tensor_guid_t
-    ComputationGraphBuilder::cast(tensor_guid_t const &input,
-                                  DataType dtype,
-                                  std::optional<std::string> const &name){
-    NOT_IMPLEMENTED()}
+    ComputationGraphBuilder::add_layer(LayerAttrs const &layer,
+                                       std::vector<tensor_guid_t> const &inputs,
+                                       std::vector<TensorAttrs> const &weights,
+                                       TensorShape const &output) {
+  return get_only(this->add_layer(
+      layer, inputs, weights, std::vector<TensorShape>{output}));
+}
 
 tensor_guid_t ComputationGraphBuilder::as_type(tensor_guid_t const &x,
                                                DataType data_type,
                                                std::string const &name) {
-  Tensor tensor = computation_graph.at(x);
-  if (tensor.data_type < data_type) {
+  DataType x_datatype = this->get_shape(x).data_type;
+  if (x_datatype < data_type) {
     return this->cast(x, data_type, name);
-  } else if (tensor.data_type > data_type) {
-    throw mk_runtime_error("Could not convert provided tensor data type {} to "
-                           "desired data type {}",
-                           tensor.data_type,
-                           data_type);
+  } else if (x_datatype > data_type) {
+    throw mk_runtime_error(
+        fmt::format("Could not convert provided tensor data type {} to "
+                    "desired data type {}",
+                    x_datatype,
+                    data_type));
+  } else {
+    return x;
   }
-  return x;
+}
+
+tensor_guid_t ComputationGraphBuilder::broadcast(tensor_guid_t const &,
+                                                 TensorShape const &) {
+  NOT_IMPLEMENTED();
+}
+
+tensor_guid_t
+    ComputationGraphBuilder::cast(tensor_guid_t const &input,
+                                  DataType dtype,
+                                  std::optional<std::string> const &name) {
+  NOT_IMPLEMENTED();
+}
 
 static std::string get_default_name(OperatorType op_type) {
   return get_operator_type_name(op_type);
 }
 
-static std::string
get_default_name(ComputationGraphAttrs const &attrs) { +static std::string get_default_name(ComputationGraphOpAttrs const &attrs) { return get_default_name(get_op_type(attrs)); } -template -static std::string get_default_name(std::variant const &attrs) { - return get_default_name(widen(attrs)); -} - tensor_guid_t ComputationGraphBuilder::element_unary( ElementUnaryAttrs const &attrs, tensor_guid_t const &x, std::optional const &maybe_name) { - std::string name = maybe_name.value_or(get_default_name(attrs)); + std::string name = + maybe_name.value_or(get_default_name(ComputationGraphOpAttrs{attrs})); tensor_guid_t input = this->as_type(x, DataType::FLOAT, name + "input_pre_cast"); - Layer layer = {attrs, name}; + LayerAttrs layer = LayerAttrs{ComputationGraphOpAttrs{attrs}, name}; + TensorShape output_shape = - get_output_shape(attrs, computation_graph.at(input)); + throw_if_unexpected(get_output_shape(attrs, this->get_shape(input))); return this->add_layer(layer, {input}, {}, output_shape); } @@ -97,14 +154,16 @@ tensor_guid_t ComputationGraphBuilder::element_scalar_unary( ElementScalarUnaryAttrs const &attrs, tensor_guid_t const &x, std::optional const &maybe_name) { - std::string name = maybe_name.value_or(get_default_name(attrs)); + std::string name = + maybe_name.value_or(get_default_name(ComputationGraphOpAttrs{attrs})); tensor_guid_t input = this->as_type(x, DataType::FLOAT, name + "input_pre_cast"); - Layer layer = {attrs, name}; + LayerAttrs layer = {ComputationGraphOpAttrs{attrs}, name}; + TensorShape output_shape = - get_output_shape(attrs, computation_graph.at(input)); + throw_if_unexpected(get_output_shape(attrs, this->get_shape(input))); return this->add_layer(layer, {input}, {}, output_shape); } @@ -133,146 +192,109 @@ tensor_guid_t ComputationGraphBuilder::element_binary( std::optional const &maybe_name) { std::string name = maybe_name.value_or(get_default_name(op_type)); - Tensor lhs_tensor = computation_graph.at(lhs); - Tensor rhs_tensor = computation_graph.at(rhs); - - TensorShape compute_shape = - this->get_broadcast_target_shape({lhs_tensor, rhs_tensor}); - DataType compute_type = std::max(lhs_tensor.data_type, rhs_tensor.data_type); + TensorShape compute_shape = this->get_broadcast_target_shape({lhs, rhs}); + DataType compute_type = + std::max(this->get_shape(lhs).data_type, this->get_shape(rhs).data_type); - tensor_guid_t const lhs_input = - this->as_type(this->broadcast(lhs, compute_shape), - compute_type, - name + "_inputl_pre_cast"); - tensor_guid_t const rhs_input = - this->as_type(this->broadcast(rhs, compute_shape), - compute_type, - name + "_inputr_pre_cast"); + tensor_guid_t lhs_input = this->as_type(this->broadcast(lhs, compute_shape), + compute_type, + name + "_inputl_pre_cast"); + tensor_guid_t rhs_input = this->as_type(this->broadcast(rhs, compute_shape), + compute_type, + name + "_inputr_pre_cast"); ElementBinaryAttrs attrs = {op_type, compute_type, false, false}; - Layer layer = {attrs, name}; - TensorShape output_shape = get_output_shape( - attrs, computation_graph.at(lhs_input), computation_graph.at(rhs_input)); + LayerAttrs layer = {ComputationGraphOpAttrs{attrs}, name}; - return this->add_layer(layer, {lhs_input, rhs_input}, {}, output_shape); -} + TensorShape output_shape = throw_if_unexpected(get_output_shape( + attrs, this->get_shape(lhs_input), this->get_shape(rhs_input))); -tensor_guid_t ComputationGraphBuilder::dense( - tensor_guid_t const &input, - int outDim, - std::optional activation, - bool use_bias, - DataType data_type, - std::optional 
const &kernel_initializer, - std::optional const &bias_initializer, - std::optional const &name) { - LinearAttrs attrs = { - outDim, use_bias, data_type, activation.value(), std::nullopt}; - std::string unwrapped_name = name.value_or(get_default_name(attrs)); - - tensor_guid_t input_recast = - this->as_type(input, data_type, unwrapped_name + "input_recast"); - - Tensor input_recast_tensor = computation_graph.at(input_recast); - Layer layer = {attrs, name}; - TensorShape output_shape = get_output_shape(attrs, input_recast_tensor); - Tensor output = { - output_shape.dims, data_type, std::nullopt, false, std::nullopt}; - - std::vector>> weights; - - weights.push_back( - {get_weights_shape(attrs, input_recast_tensor), kernel_initializer}); - - if (use_bias) { - weights.push_back( - {get_bias_shape(attrs, input_recast_tensor), bias_initializer}); - } - - return this->add_layer(layer, {input_recast}, weights, output); + return this->add_layer(layer, {lhs_input, rhs_input}, {}, output_shape); } tensor_guid_t ComputationGraphBuilder::exp(tensor_guid_t const &input, std::optional const &name) { - return this->element_unary(Op::EXP, input, name); + return this->element_unary(OperatorType::EXP, input, name); } tensor_guid_t ComputationGraphBuilder::add(tensor_guid_t const &lhs, tensor_guid_t const &rhs, std::optional const &name) { - return this->element_binary(Op::EW_ADD, lhs, rhs, name); + return this->element_binary(OperatorType::EW_ADD, lhs, rhs, name); } tensor_guid_t ComputationGraphBuilder::subtract(tensor_guid_t const &lhs, tensor_guid_t const &rhs, std::optional const &name) { - return this->element_binary(Op::EW_SUB, lhs, rhs, name); + return this->element_binary(OperatorType::EW_SUB, lhs, rhs, name); } tensor_guid_t ComputationGraphBuilder::multiply(tensor_guid_t const &lhs, tensor_guid_t const &rhs, std::optional const &name) { - return this->element_binary(Op::EW_MUL, lhs, rhs, name); + return this->element_binary(OperatorType::EW_MUL, lhs, rhs, name); } tensor_guid_t ComputationGraphBuilder::divide(tensor_guid_t const &lhs, tensor_guid_t const &rhs, std::optional const &name) { - return this->element_binary(Op::EW_DIV, lhs, rhs, name); + return this->element_binary(OperatorType::EW_DIV, lhs, rhs, name); } tensor_guid_t ComputationGraphBuilder::max(tensor_guid_t const &lhs, tensor_guid_t const &rhs, std::optional const &name) { - return this->element_binary(Op::EW_MAX, lhs, rhs, name); + return this->element_binary(OperatorType::EW_MAX, lhs, rhs, name); } tensor_guid_t ComputationGraphBuilder::min(tensor_guid_t const &lhs, tensor_guid_t const &rhs, std::optional const &name) { - return this->element_binary(Op::EW_MIN, lhs, rhs, name); + return this->element_binary(OperatorType::EW_MIN, lhs, rhs, name); } tensor_guid_t ComputationGraphBuilder::rsqrt(tensor_guid_t const &input, std::optional const &name) { - return this->element_unary(Op::RSQRT, input, name); + return this->element_unary(OperatorType::RSQRT, input, name); } tensor_guid_t ComputationGraphBuilder::pow(tensor_guid_t const &input, float exponent, std::optional const &name) { - return this->element_scalar_unary(Op::POW, input, exponent, name); + return this->element_scalar_unary(OperatorType::POW, input, exponent, name); } tensor_guid_t ComputationGraphBuilder::scalar_multiply( tensor_guid_t const &input, float scalar, std::optional const &name) { - return this->element_scalar_unary(Op::SCALAR_MULTIPLY, input, scalar, name); + return this->element_scalar_unary( + OperatorType::SCALAR_MULTIPLY, input, scalar, name); } 
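// A minimal usage sketch of the scalar wrappers in this block — each one
// simply forwards to element_scalar_unary with the matching OperatorType,
// which resolves the default layer name and output shape. The builder and
// shape names below are hypothetical (not from this patch), and input_shape
// is assumed to be FLOAT-typed so no pre-cast layer is inserted:
//
//   ComputationGraphBuilder b;
//   tensor_guid_t x = b.create_tensor(input_shape, /*create_grad=*/true);
//   tensor_guid_t y = b.scalar_multiply(x, 2.0f, std::nullopt);
//   tensor_guid_t z = b.scalar_add(y, 1.0f, std::nullopt);  // computes 2*x + 1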
tensor_guid_t ComputationGraphBuilder::scalar_add( tensor_guid_t const &input, float scalar, std::optional const &name) { - return this->element_scalar_unary(Op::SCALAR_ADD, input, scalar, name); + return this->element_scalar_unary( + OperatorType::SCALAR_ADD, input, scalar, name); } tensor_guid_t ComputationGraphBuilder::scalar_sub( tensor_guid_t const &lhs, float rhs, std::optional const &name) { - return this->element_scalar_unary(Op::SCALAR_SUB, lhs, rhs, name); + return this->element_scalar_unary(OperatorType::SCALAR_SUB, lhs, rhs, name); } tensor_guid_t ComputationGraphBuilder::scalar_truediv( @@ -280,55 +302,61 @@ tensor_guid_t ComputationGraphBuilder::scalar_truediv( float denominator, std::optional const &name) { return this->element_scalar_unary( - Op::SCALAR_TRUE_DIV, numerator, denominator, name); + OperatorType::SCALAR_TRUE_DIV, numerator, denominator, name); } tensor_guid_t ComputationGraphBuilder::sin(tensor_guid_t const &input, std::optional const &name) { - return this->element_unary(Op::SIN, input, name); + return this->element_unary(OperatorType::SIN, input, name); } tensor_guid_t ComputationGraphBuilder::cos(tensor_guid_t const &input, std::optional const &name) { - return this->element_unary(Op::COS, input, name); + return this->element_unary(OperatorType::COS, input, name); } tensor_guid_t ComputationGraphBuilder::relu(tensor_guid_t const &input, std::optional const &name) { - return this->element_unary(Op::RELU, input, name); + return this->element_unary(OperatorType::RELU, input, name); } tensor_guid_t ComputationGraphBuilder::identity(tensor_guid_t const &input, std::optional const &name) { - return this->element_unary(Op::IDENTITY, input, name); + return this->element_unary(OperatorType::IDENTITY, input, name); } tensor_guid_t ComputationGraphBuilder::gelu(tensor_guid_t const &input, std::optional const &name) { - return this->element_unary(Op::GELU, input, name); + return this->element_unary(OperatorType::GELU, input, name); } tensor_guid_t ComputationGraphBuilder::sigmoid(tensor_guid_t const &input, std::optional const &name) { - return this->element_unary(Op::SIGMOID, input, name); + return this->element_unary(OperatorType::SIGMOID, input, name); } tensor_guid_t ComputationGraphBuilder::tanh(tensor_guid_t const &input, std::optional const &name) { - return this->element_unary(Op::TANH, input, name); + return this->element_unary(OperatorType::TANH, input, name); } tensor_guid_t ComputationGraphBuilder::elu(tensor_guid_t const &input, std::optional const &name) { - return this->element_unary(Op::ELU, input, name); + return this->element_unary(OperatorType::ELU, input, name); +} + +static TensorAttrs make_weight_attrs( + TensorShape const &shape, + std::optional const &initializer_attrs) { + return TensorAttrs{shape, initializer_attrs, true, std::nullopt}; } tensor_guid_t ComputationGraphBuilder::conv2d( @@ -343,8 +371,8 @@ tensor_guid_t ComputationGraphBuilder::conv2d( std::optional const &activation, int groups, bool use_bias, - std::optional const &kernel_initializer, - std::optional const &bias_initializer, + std::optional const &kernel_initializer, + std::optional const &bias_initializer, std::optional const &kernel_regularizer, std::optional const &maybe_name) { Conv2DAttrs attrs = {outChannels, @@ -357,23 +385,26 @@ tensor_guid_t ComputationGraphBuilder::conv2d( groups, activation, use_bias}; - std::string name = maybe_name.value_or(get_default_name(attrs)); + + std::string name = + maybe_name.value_or(get_default_name(ComputationGraphOpAttrs{attrs})); 
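// A note on the as_type pre-cast used below (and throughout this file):
// as_type only ever widens. Assuming the usual INT32 < FLOAT < DOUBLE
// ordering of DataType, an input ordering below FLOAT is routed through
// cast(), one ordering above FLOAT throws mk_runtime_error, and a FLOAT
// input is returned unchanged. A sketch of that contract, with hypothetical
// tensor names keyed by dtype (note cast() is still NOT_IMPLEMENTED in this
// patch, so the widening branch currently aborts at runtime):
//
//   tensor_guid_t a = b.as_type(int32_tensor, DataType::FLOAT, "x");   // widened via cast()
//   tensor_guid_t c = b.as_type(float_tensor, DataType::FLOAT, "x");   // returned as-is
//   // b.as_type(double_tensor, DataType::FLOAT, "x") -> throws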
tensor_guid_t input = this->as_type(x, DataType::FLOAT, name + "input_pre_cast"); - Tensor input_tensor = computation_graph.at(input); + LayerAttrs layer = {ComputationGraphOpAttrs{attrs}, name}; - Layer layer = {attrs, name}; - TensorShape output_shape = get_output_shape(attrs, input_tensor); + TensorShape input_shape = this->get_shape(input); + TensorShape output_shape = get_output_shape(attrs, input_shape); - std::vector>> weights; + std::vector weights; - weights.push_back( - {get_kernel_shape(attrs, input_tensor), kernel_initializer}); + weights.push_back(make_weight_attrs(get_kernel_shape(attrs, input_shape), + kernel_initializer)); if (use_bias) { - weights.push_back({get_bias_shape(attrs, input_tensor), bias_initializer}); + weights.push_back(make_weight_attrs(get_bias_shape(attrs, input_shape), + bias_initializer)); } return this->add_layer(layer, {input}, weights, output_shape); @@ -385,14 +416,14 @@ tensor_guid_t ComputationGraphBuilder::dropout( unsigned long long seed, std::optional const &maybe_name) { DropoutAttrs attrs = {rate, seed}; - std::string name = maybe_name.value_or(get_default_name(attrs)); + std::string name = + maybe_name.value_or(get_default_name(ComputationGraphOpAttrs{attrs})); - Layer layer = {attrs, name}; + LayerAttrs layer = {ComputationGraphOpAttrs{attrs}, name}; tensor_guid_t input = this->as_type(x, DataType::FLOAT, name + "input_pre_cast"); - TensorShape output_shape = - get_output_shape(attrs, computation_graph.at(input)); + TensorShape output_shape = get_output_shape(attrs, this->get_shape(input)); return this->add_layer(layer, {input}, {}, output_shape); } @@ -403,21 +434,26 @@ tensor_guid_t ComputationGraphBuilder::embedding( int outDim, AggregateOp aggr, DataType dtype, - std::optional const &kernel_initializer, + std::optional const &kernel_initializer, std::optional const &maybe_name) { EmbeddingAttrs attrs = {num_entries, outDim, aggr, dtype}; - std::string name = maybe_name.value_or(get_default_name(attrs)); + std::string name = + maybe_name.value_or(get_default_name(ComputationGraphOpAttrs{attrs})); - Layer layer = {attrs, name}; + LayerAttrs layer = {ComputationGraphOpAttrs{attrs}, name}; tensor_guid_t input = this->as_type(x, DataType::FLOAT, name + "input_pre_cast"); - Tensor input_tensor = computation_graph.at(input); - TensorShape output_shape = get_output_shape(attrs, input_tensor); - TensorShape weights_shape = get_weights_shape(attrs, input_tensor); + TensorShape input_shape = this->get_shape(input); - return this->add_layer( - layer, {input}, {{weights_shape, kernel_initializer}}, output_shape); + TensorAttrs weight_attrs = make_weight_attrs( + throw_if_unexpected(get_weights_shape(attrs, input_shape)), + kernel_initializer); + + TensorShape output_shape = + throw_if_unexpected(get_output_shape(attrs, this->get_shape(input))); + + return this->add_layer(layer, {input}, {weight_attrs}, output_shape); } std::vector ComputationGraphBuilder::gather( @@ -426,42 +462,30 @@ std::vector ComputationGraphBuilder::gather( ff_dim_t dim, std::optional const &maybe_name) { GatherAttrs attrs = {dim}; - std::string name = maybe_name.value_or(get_default_name(attrs)); + std::string name = + maybe_name.value_or(get_default_name(ComputationGraphOpAttrs{attrs})); - Layer layer = {attrs, name}; - Tensor index_tensor = computation_graph.at(index); - if (index_tensor.data_type != DataType::INT32 && - index_tensor.data_type != DataType::INT64) { + LayerAttrs layer = {ComputationGraphOpAttrs{attrs}, name}; + if (this->get_shape(index).data_type != 
DataType::INT32 &&
+      this->get_shape(index).data_type != DataType::INT64) {
     throw mk_runtime_error("Invalid data type for input tensor 2 for Gather: "
                            "{} (should be {} or {})",
-                           index_tensor.data_type,
+                           this->get_shape(index).data_type,
                            DataType::INT32,
                            DataType::INT64);
   }
   std::vector<TensorShape> output_shapes =
-      get_output_shapes(attrs, computation_graph.at(input), index_tensor);
+      get_output_shapes(attrs, this->get_shape(input), this->get_shape(index));
 
   return this->add_layer(layer, {input}, {}, output_shapes);
 }
 
-tensor_guid_t
-    ComputationGraphBuilder::input(Tensor const &input_tensor,
-                                   std::optional<std::string> const &name) {
-  InputAttrs input_attrs = {};
-  std::string str_name = name.value_or(get_default_name(input_attrs));
-
-  Layer layer = {input_attrs, str_name};
-
-  return this->add_layer(layer, {}, {}, input_tensor);
-}
-
-TensorShape ComputationGraphBuilder::get_shape(tensor_guid_t const &t) {
-  return computation_graph.at(t).get_shape();
-}
-std::vector<TensorShape>
-    ComputationGraphBuilder::get_shapes(std::vector<tensor_guid_t> const &) {
-  NOT_IMPLEMENTED();
-}
+/* std::vector<TensorShape>
+ *     ComputationGraphBuilder::get_shapes(std::vector<tensor_guid_t> const &ts)
+ *     const { */
+/*   return transform(ts, [&](tensor_guid_t const &t) { return
+ *   this->get_shape(t); }); */
+/* } */
 
 // tensor_guid_t ComputationGraphBuilder::aggregate(
 //     tensor_guid_t const &gate_preds,
@@ -475,7 +499,7 @@ std::vector
 //   AggregateAttrs attrs = {n, lambda_bal};
 //   std::string name = maybe_name.value_or(get_default_name(attrs));
 
-//   Layer layer = {attrs, name};
+//   LayerAttrs layer = {attrs, name};
 //   TensorShape output_shape = get_output_shape(attrs,
 //                                               get_shape(gate_preds),
 //                                               get_shape(gate_assign),
@@ -494,15 +518,21 @@ tensor_guid_t ComputationGraphBuilder::batch_norm(
     bool relu,
     std::optional<std::string> const &maybe_name) {
   BatchNormAttrs attrs = BatchNormAttrs{relu};
-  std::string name = maybe_name.value_or(get_default_name(attrs));
+  std::string name =
+      maybe_name.value_or(get_default_name(ComputationGraphOpAttrs{attrs}));
 
-  Layer layer = {attrs, name};
+  LayerAttrs layer = {ComputationGraphOpAttrs{attrs}, name};
 
-  TensorShape output_shape = get_output_shape(attrs, get_shape(input));
+  TensorShape output_shape = get_output_shape(attrs, this->get_shape(input));
 
   return this->add_layer(layer, {input}, {}, output_shape);
 }
 
+TensorShape ComputationGraphBuilder::get_broadcast_target_shape(
+    std::vector<tensor_guid_t> const &) {
+  NOT_IMPLEMENTED();
+}
+
 TensorShape ComputationGraphBuilder::get_broadcast_target_shape(
     std::vector<Tensor> const &) {
   NOT_IMPLEMENTED();
diff --git a/lib/pcg/src/pcg/cpu_id_t.dtg.cc b/lib/pcg/src/pcg/cpu_id_t.dtg.cc
new file mode 100644
index 0000000000..f865442eb0
--- /dev/null
+++ b/lib/pcg/src/pcg/cpu_id_t.dtg.cc
@@ -0,0 +1,74 @@
+// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT!
+// If you would like to modify this datatype, instead modify +// lib/pcg/include/pcg/cpu_id_t.struct.toml +/* proj-data +{ + "generated_from": "a0faf78831febfa3a02929169943d9f5" +} +*/ + +#include "pcg/cpu_id_t.dtg.h" + +#include + +namespace FlexFlow { +cpu_id_t::cpu_id_t(int const &cpu_index) : cpu_index(cpu_index) {} +bool cpu_id_t::operator==(cpu_id_t const &other) const { + return std::tie(this->cpu_index) == std::tie(other.cpu_index); +} +bool cpu_id_t::operator!=(cpu_id_t const &other) const { + return std::tie(this->cpu_index) != std::tie(other.cpu_index); +} +bool cpu_id_t::operator<(cpu_id_t const &other) const { + return std::tie(this->cpu_index) < std::tie(other.cpu_index); +} +bool cpu_id_t::operator>(cpu_id_t const &other) const { + return std::tie(this->cpu_index) > std::tie(other.cpu_index); +} +bool cpu_id_t::operator<=(cpu_id_t const &other) const { + return std::tie(this->cpu_index) <= std::tie(other.cpu_index); +} +bool cpu_id_t::operator>=(cpu_id_t const &other) const { + return std::tie(this->cpu_index) >= std::tie(other.cpu_index); +} +} // namespace FlexFlow + +namespace std { +size_t hash::operator()(FlexFlow::cpu_id_t const &x) const { + size_t result = 0; + result ^= std::hash{}(x.cpu_index) + 0x9e3779b9 + (result << 6) + + (result >> 2); + return result; +} +} // namespace std + +namespace nlohmann { +FlexFlow::cpu_id_t + adl_serializer::from_json(json const &j) { + return {j.at("cpu_index").template get()}; +} +void adl_serializer::to_json(json &j, + FlexFlow::cpu_id_t const &v) { + j["__type"] = "cpu_id_t"; + j["cpu_index"] = v.cpu_index; +} +} // namespace nlohmann + +namespace rc { +Gen Arbitrary::arbitrary() { + return gen::construct(gen::arbitrary()); +} +} // namespace rc + +namespace FlexFlow { +std::string format_as(cpu_id_t const &x) { + std::ostringstream oss; + oss << ""; + return oss.str(); +} +std::ostream &operator<<(std::ostream &s, cpu_id_t const &x) { + return s << fmt::to_string(x); +} +} // namespace FlexFlow diff --git a/lib/pcg/src/pcg/create_grad.dtg.cc b/lib/pcg/src/pcg/create_grad.dtg.cc new file mode 100644 index 0000000000..b2b7e3233b --- /dev/null +++ b/lib/pcg/src/pcg/create_grad.dtg.cc @@ -0,0 +1,70 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! 
+// If you would like to modify this datatype, instead modify +// lib/pcg/include/pcg/create_grad.enum.toml +/* proj-data +{ + "generated_from": "9fd617027e850b6d6db476a49b3e0334" +} +*/ + +#include "pcg/create_grad.dtg.h" + +#include +#include + +namespace std { +size_t hash::operator()(FlexFlow::CreateGrad x) const { + return std::hash{}(static_cast(x)); +} +} // namespace std +namespace FlexFlow { +std::string format_as(CreateGrad x) { + switch (x) { + case CreateGrad::YES: + return "YES"; + case CreateGrad::NO: + return "NO"; + default: + std::ostringstream oss; + oss << "Unknown CreateGrad value " << static_cast(x); + throw std::runtime_error(oss.str()); + } +} +std::ostream &operator<<(std::ostream &s, CreateGrad x) { + return s << fmt::to_string(x); +} +} // namespace FlexFlow +namespace FlexFlow { +void to_json(::nlohmann::json &j, CreateGrad x) { + switch (x) { + case CreateGrad::YES: + j = "YES"; + break; + case CreateGrad::NO: + j = "NO"; + break; + default: + std::ostringstream oss; + oss << "Unknown CreateGrad value " << static_cast(x); + throw std::runtime_error(oss.str()); + } +} +void from_json(::nlohmann::json const &j, CreateGrad &x) { + std::string as_str = j.get(); + if (as_str == "YES") { + x = CreateGrad::YES; + } else if (as_str == "NO") { + x = CreateGrad::NO; + } else { + std::ostringstream oss; + oss << "Unknown CreateGrad value " << as_str; + throw std::runtime_error(oss.str()); + } +} +} // namespace FlexFlow +namespace rc { +Gen Arbitrary::arbitrary() { + return gen::element(FlexFlow::CreateGrad::YES, + FlexFlow::CreateGrad::NO); +} +} // namespace rc diff --git a/lib/pcg/src/pcg/dataflow_input.dtg.cc b/lib/pcg/src/pcg/dataflow_input.dtg.cc new file mode 100644 index 0000000000..bd5a43dfa9 --- /dev/null +++ b/lib/pcg/src/pcg/dataflow_input.dtg.cc @@ -0,0 +1,41 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! 
+// If you would like to modify this datatype, instead modify +// lib/pcg/include/pcg/dataflow_input.variant.toml +/* proj-data +{ + "generated_from": "d6a7f4570e36e257383529e9bf9390ec" +} +*/ + +#include "pcg/dataflow_input.dtg.h" + +namespace FlexFlow { +DataflowInput::DataflowInput(::FlexFlow::MultiDiOutput const &v) + : raw_variant(v) {} +DataflowInput::DataflowInput(int const &v) : raw_variant(v) {} +bool DataflowInput::operator==(DataflowInput const &other) const { + return this->raw_variant == other.raw_variant; +} +bool DataflowInput::operator!=(DataflowInput const &other) const { + return this->raw_variant != other.raw_variant; +} +bool DataflowInput::operator<(DataflowInput const &other) const { + return this->raw_variant < other.raw_variant; +} +bool DataflowInput::operator>(DataflowInput const &other) const { + return this->raw_variant > other.raw_variant; +} +bool DataflowInput::operator<=(DataflowInput const &other) const { + return this->raw_variant <= other.raw_variant; +} +bool DataflowInput::operator>=(DataflowInput const &other) const { + return this->raw_variant >= other.raw_variant; +} +} // namespace FlexFlow +namespace std { +size_t hash<::FlexFlow::DataflowInput>::operator()( + ::FlexFlow::DataflowInput const &x) const { + return std::hash>{}( + x.raw_variant); +} +} // namespace std diff --git a/lib/pcg/src/pcg/device_id.cc b/lib/pcg/src/pcg/device_id.cc new file mode 100644 index 0000000000..35b0c9aeda --- /dev/null +++ b/lib/pcg/src/pcg/device_id.cc @@ -0,0 +1,32 @@ +#include "pcg/device_id.h" +#include "utils/exception.h" +#include + +namespace FlexFlow { + +device_id_t operator+(device_id_t, size_t) { + NOT_IMPLEMENTED(); +} + +DeviceType get_device_type(device_id_t const &device_id) { + if (device_id.has()) { + return DeviceType::GPU; + } else { + assert(device_id.has()); + return DeviceType::CPU; + } +} + +gpu_id_t unwrap_gpu(device_id_t device_id) { + return device_id.get(); +} + +cpu_id_t unwrap_cpu(device_id_t device_id) { + return device_id.get(); +} + +device_id_t device_id_from_index(int, DeviceType) { + NOT_IMPLEMENTED(); +} + +} // namespace FlexFlow diff --git a/lib/pcg/src/pcg/device_id_t.dtg.cc b/lib/pcg/src/pcg/device_id_t.dtg.cc new file mode 100644 index 0000000000..517c6c198c --- /dev/null +++ b/lib/pcg/src/pcg/device_id_t.dtg.cc @@ -0,0 +1,103 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! 
+// If you would like to modify this datatype, instead modify +// lib/pcg/include/pcg/device_id_t.variant.toml +/* proj-data +{ + "generated_from": "85870050c742b0159775399ec2be67e3" +} +*/ + +#include "pcg/device_id_t.dtg.h" + +#include "fmt/format.h" +#include +#include + +namespace FlexFlow { +device_id_t::device_id_t(::FlexFlow::gpu_id_t const &v) : raw_variant(v) {} +device_id_t::device_id_t(::FlexFlow::cpu_id_t const &v) : raw_variant(v) {} +bool device_id_t::operator==(device_id_t const &other) const { + return this->raw_variant == other.raw_variant; +} +bool device_id_t::operator!=(device_id_t const &other) const { + return this->raw_variant != other.raw_variant; +} +bool device_id_t::operator<(device_id_t const &other) const { + return this->raw_variant < other.raw_variant; +} +bool device_id_t::operator>(device_id_t const &other) const { + return this->raw_variant > other.raw_variant; +} +bool device_id_t::operator<=(device_id_t const &other) const { + return this->raw_variant <= other.raw_variant; +} +bool device_id_t::operator>=(device_id_t const &other) const { + return this->raw_variant >= other.raw_variant; +} +} // namespace FlexFlow +namespace std { +size_t hash<::FlexFlow::device_id_t>::operator()( + ::FlexFlow::device_id_t const &x) const { + return std::hash>{}( + x.raw_variant); +} +} // namespace std +namespace nlohmann { +::FlexFlow::device_id_t + adl_serializer<::FlexFlow::device_id_t>::from_json(json const &j) { + std::string key = j.at("type").template get(); + if (key == "gpu") { + return ::FlexFlow::device_id_t{ + j.at("value").template get<::FlexFlow::gpu_id_t>()}; + } else if (key == "cpu") { + return ::FlexFlow::device_id_t{ + j.at("value").template get<::FlexFlow::cpu_id_t>()}; + } else { + throw std::runtime_error(fmt::format("Unknown type key {}", key)); + } +} +void adl_serializer<::FlexFlow::device_id_t>::to_json( + json &j, ::FlexFlow::device_id_t const &x) { + j["__type"] = "device_id_t"; + switch (x.index()) { + case 0: { + j["type"] = "gpu"; + j["value"] = x.get<::FlexFlow::gpu_id_t>(); + break; + } + case 1: { + j["type"] = "cpu"; + j["value"] = x.get<::FlexFlow::cpu_id_t>(); + break; + } + default: { + throw std::runtime_error( + fmt::format("Unknown index {} for type device_id_t", x.index())); + } + } +} +} // namespace nlohmann +namespace FlexFlow { +std::string format_as(::FlexFlow::device_id_t const &x) { + std::ostringstream oss; + switch (x.index()) { + case 0: { + oss << ""; + break; + } + case 1: { + oss << ""; + break; + } + default: { + throw std::runtime_error( + fmt::format("Unknown index {} for type device_id_t", x.index())); + break; + } + } + return oss.str(); +} +std::ostream &operator<<(std::ostream &s, ::FlexFlow::device_id_t const &x) { + return s << fmt::to_string(x); +} +} // namespace FlexFlow diff --git a/lib/pcg/src/pcg/device_type.dtg.cc b/lib/pcg/src/pcg/device_type.dtg.cc new file mode 100644 index 0000000000..8279cc4c16 --- /dev/null +++ b/lib/pcg/src/pcg/device_type.dtg.cc @@ -0,0 +1,70 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! 
+// If you would like to modify this datatype, instead modify +// lib/pcg/include/pcg/device_type.enum.toml +/* proj-data +{ + "generated_from": "cfe4bc5e9f7c5796b9b90b420c33935f" +} +*/ + +#include "pcg/device_type.dtg.h" + +#include +#include + +namespace std { +size_t hash::operator()(FlexFlow::DeviceType x) const { + return std::hash{}(static_cast(x)); +} +} // namespace std +namespace FlexFlow { +std::string format_as(DeviceType x) { + switch (x) { + case DeviceType::GPU: + return "GPU"; + case DeviceType::CPU: + return "CPU"; + default: + std::ostringstream oss; + oss << "Unknown DeviceType value " << static_cast(x); + throw std::runtime_error(oss.str()); + } +} +std::ostream &operator<<(std::ostream &s, DeviceType x) { + return s << fmt::to_string(x); +} +} // namespace FlexFlow +namespace FlexFlow { +void to_json(::nlohmann::json &j, DeviceType x) { + switch (x) { + case DeviceType::GPU: + j = "GPU"; + break; + case DeviceType::CPU: + j = "CPU"; + break; + default: + std::ostringstream oss; + oss << "Unknown DeviceType value " << static_cast(x); + throw std::runtime_error(oss.str()); + } +} +void from_json(::nlohmann::json const &j, DeviceType &x) { + std::string as_str = j.get(); + if (as_str == "GPU") { + x = DeviceType::GPU; + } else if (as_str == "CPU") { + x = DeviceType::CPU; + } else { + std::ostringstream oss; + oss << "Unknown DeviceType value " << as_str; + throw std::runtime_error(oss.str()); + } +} +} // namespace FlexFlow +namespace rc { +Gen Arbitrary::arbitrary() { + return gen::element(FlexFlow::DeviceType::GPU, + FlexFlow::DeviceType::CPU); +} +} // namespace rc diff --git a/lib/pcg/src/pcg/file_format/v1/graphs/v1_graph_edge.dtg.cc b/lib/pcg/src/pcg/file_format/v1/graphs/v1_graph_edge.dtg.cc new file mode 100644 index 0000000000..713aa941d2 --- /dev/null +++ b/lib/pcg/src/pcg/file_format/v1/graphs/v1_graph_edge.dtg.cc @@ -0,0 +1,94 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! 
+// If you would like to modify this datatype, instead modify +// lib/pcg/include/pcg/file_format/v1/graphs/v1_graph_edge.struct.toml +/* proj-data +{ + "generated_from": "865097b569b831af049343e933834329" +} +*/ + +#include "pcg/file_format/v1/graphs/v1_graph_edge.dtg.h" + +#include + +namespace FlexFlow { +V1GraphEdge::V1GraphEdge(size_t const &srcNode, + size_t const &srcIdx, + size_t const &dstNode, + size_t const &dstIdx) + : srcNode(srcNode), srcIdx(srcIdx), dstNode(dstNode), dstIdx(dstIdx) {} +bool V1GraphEdge::operator==(V1GraphEdge const &other) const { + return std::tie(this->srcNode, this->srcIdx, this->dstNode, this->dstIdx) == + std::tie(other.srcNode, other.srcIdx, other.dstNode, other.dstIdx); +} +bool V1GraphEdge::operator!=(V1GraphEdge const &other) const { + return std::tie(this->srcNode, this->srcIdx, this->dstNode, this->dstIdx) != + std::tie(other.srcNode, other.srcIdx, other.dstNode, other.dstIdx); +} +bool V1GraphEdge::operator<(V1GraphEdge const &other) const { + return std::tie(this->srcNode, this->srcIdx, this->dstNode, this->dstIdx) < + std::tie(other.srcNode, other.srcIdx, other.dstNode, other.dstIdx); +} +bool V1GraphEdge::operator>(V1GraphEdge const &other) const { + return std::tie(this->srcNode, this->srcIdx, this->dstNode, this->dstIdx) > + std::tie(other.srcNode, other.srcIdx, other.dstNode, other.dstIdx); +} +bool V1GraphEdge::operator<=(V1GraphEdge const &other) const { + return std::tie(this->srcNode, this->srcIdx, this->dstNode, this->dstIdx) <= + std::tie(other.srcNode, other.srcIdx, other.dstNode, other.dstIdx); +} +bool V1GraphEdge::operator>=(V1GraphEdge const &other) const { + return std::tie(this->srcNode, this->srcIdx, this->dstNode, this->dstIdx) >= + std::tie(other.srcNode, other.srcIdx, other.dstNode, other.dstIdx); +} +} // namespace FlexFlow + +namespace std { +size_t hash::operator()( + FlexFlow::V1GraphEdge const &x) const { + size_t result = 0; + result ^= std::hash{}(x.srcNode) + 0x9e3779b9 + (result << 6) + + (result >> 2); + result ^= std::hash{}(x.srcIdx) + 0x9e3779b9 + (result << 6) + + (result >> 2); + result ^= std::hash{}(x.dstNode) + 0x9e3779b9 + (result << 6) + + (result >> 2); + result ^= std::hash{}(x.dstIdx) + 0x9e3779b9 + (result << 6) + + (result >> 2); + return result; +} +} // namespace std + +namespace nlohmann { +FlexFlow::V1GraphEdge + adl_serializer::from_json(json const &j) { + return {j.at("srcNode").template get(), + j.at("srcIdx").template get(), + j.at("dstNode").template get(), + j.at("dstIdx").template get()}; +} +void adl_serializer::to_json( + json &j, FlexFlow::V1GraphEdge const &v) { + j["__type"] = "V1GraphEdge"; + j["srcNode"] = v.srcNode; + j["srcIdx"] = v.srcIdx; + j["dstNode"] = v.dstNode; + j["dstIdx"] = v.dstIdx; +} +} // namespace nlohmann + +namespace FlexFlow { +std::string format_as(V1GraphEdge const &x) { + std::ostringstream oss; + oss << ""; + return oss.str(); +} +std::ostream &operator<<(std::ostream &s, V1GraphEdge const &x) { + return s << fmt::to_string(x); +} +} // namespace FlexFlow diff --git a/lib/pcg/src/pcg/file_format/v1/graphs/v1_graph_output.dtg.cc b/lib/pcg/src/pcg/file_format/v1/graphs/v1_graph_output.dtg.cc new file mode 100644 index 0000000000..fa0b792a37 --- /dev/null +++ b/lib/pcg/src/pcg/file_format/v1/graphs/v1_graph_output.dtg.cc @@ -0,0 +1,81 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! 
+// If you would like to modify this datatype, instead modify +// lib/pcg/include/pcg/file_format/v1/graphs/v1_graph_output.struct.toml +/* proj-data +{ + "generated_from": "05ff8401c3d976ea2220899edb8dfe3a" +} +*/ + +#include "pcg/file_format/v1/graphs/v1_graph_output.dtg.h" + +#include + +namespace FlexFlow { +V1GraphOutput::V1GraphOutput(size_t const &srcNode, size_t const &srcIdx) + : srcNode(srcNode), srcIdx(srcIdx) {} +bool V1GraphOutput::operator==(V1GraphOutput const &other) const { + return std::tie(this->srcNode, this->srcIdx) == + std::tie(other.srcNode, other.srcIdx); +} +bool V1GraphOutput::operator!=(V1GraphOutput const &other) const { + return std::tie(this->srcNode, this->srcIdx) != + std::tie(other.srcNode, other.srcIdx); +} +bool V1GraphOutput::operator<(V1GraphOutput const &other) const { + return std::tie(this->srcNode, this->srcIdx) < + std::tie(other.srcNode, other.srcIdx); +} +bool V1GraphOutput::operator>(V1GraphOutput const &other) const { + return std::tie(this->srcNode, this->srcIdx) > + std::tie(other.srcNode, other.srcIdx); +} +bool V1GraphOutput::operator<=(V1GraphOutput const &other) const { + return std::tie(this->srcNode, this->srcIdx) <= + std::tie(other.srcNode, other.srcIdx); +} +bool V1GraphOutput::operator>=(V1GraphOutput const &other) const { + return std::tie(this->srcNode, this->srcIdx) >= + std::tie(other.srcNode, other.srcIdx); +} +} // namespace FlexFlow + +namespace std { +size_t hash::operator()( + FlexFlow::V1GraphOutput const &x) const { + size_t result = 0; + result ^= std::hash{}(x.srcNode) + 0x9e3779b9 + (result << 6) + + (result >> 2); + result ^= std::hash{}(x.srcIdx) + 0x9e3779b9 + (result << 6) + + (result >> 2); + return result; +} +} // namespace std + +namespace nlohmann { +FlexFlow::V1GraphOutput + adl_serializer::from_json(json const &j) { + return {j.at("srcNode").template get(), + j.at("srcIdx").template get()}; +} +void adl_serializer::to_json( + json &j, FlexFlow::V1GraphOutput const &v) { + j["__type"] = "V1GraphOutput"; + j["srcNode"] = v.srcNode; + j["srcIdx"] = v.srcIdx; +} +} // namespace nlohmann + +namespace FlexFlow { +std::string format_as(V1GraphOutput const &x) { + std::ostringstream oss; + oss << ""; + return oss.str(); +} +std::ostream &operator<<(std::ostream &s, V1GraphOutput const &x) { + return s << fmt::to_string(x); +} +} // namespace FlexFlow diff --git a/lib/pcg/src/pcg/file_format/v1/graphs/v1_jsonable_graph.dtg.cc b/lib/pcg/src/pcg/file_format/v1/graphs/v1_jsonable_graph.dtg.cc new file mode 100644 index 0000000000..7f7e670782 --- /dev/null +++ b/lib/pcg/src/pcg/file_format/v1/graphs/v1_jsonable_graph.dtg.cc @@ -0,0 +1,10 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/pcg/include/pcg/file_format/v1/graphs/v1_jsonable_graph.struct.toml +/* proj-data +{ + "generated_from": "0595a9f5a6bc19f9a170cb0e42c4202d" +} +*/ + +#include "pcg/file_format/v1/graphs/v1_jsonable_graph.dtg.h" diff --git a/lib/pcg/src/pcg/file_format/v1/graphs/v1_multidigraph.dtg.cc b/lib/pcg/src/pcg/file_format/v1/graphs/v1_multidigraph.dtg.cc new file mode 100644 index 0000000000..0f5a83b02f --- /dev/null +++ b/lib/pcg/src/pcg/file_format/v1/graphs/v1_multidigraph.dtg.cc @@ -0,0 +1,56 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! 
+// If you would like to modify this datatype, instead modify +// lib/pcg/include/pcg/file_format/v1/graphs/v1_multidigraph.struct.toml +/* proj-data +{ + "generated_from": "fb1033385645e54a19c9b44cef0be04b" +} +*/ + +#include "pcg/file_format/v1/graphs/v1_multidigraph.dtg.h" + +#include "pcg/file_format/v1/graphs/v1_graph_edge.dtg.h" +#include "utils/fmt.h" +#include +#include +#include + +namespace FlexFlow { +V1MultiDiGraph::V1MultiDiGraph( + std::vector const &nodes, + std::vector const &ports, + std::unordered_set<::FlexFlow::V1GraphEdge> const &edges) + : nodes(nodes), ports(ports), edges(edges) {} +} // namespace FlexFlow + +namespace nlohmann { +FlexFlow::V1MultiDiGraph + adl_serializer::from_json(json const &j) { + return {j.at("nodes").template get>(), + j.at("ports").template get>(), + j.at("edges") + .template get>()}; +} +void adl_serializer::to_json( + json &j, FlexFlow::V1MultiDiGraph const &v) { + j["__type"] = "V1MultiDiGraph"; + j["nodes"] = v.nodes; + j["ports"] = v.ports; + j["edges"] = v.edges; +} +} // namespace nlohmann + +namespace FlexFlow { +std::string format_as(V1MultiDiGraph const &x) { + std::ostringstream oss; + oss << ""; + return oss.str(); +} +std::ostream &operator<<(std::ostream &s, V1MultiDiGraph const &x) { + return s << fmt::to_string(x); +} +} // namespace FlexFlow diff --git a/lib/pcg/src/pcg/file_format/v1/graphs/v1_operator_graph.dtg.cc b/lib/pcg/src/pcg/file_format/v1/graphs/v1_operator_graph.dtg.cc new file mode 100644 index 0000000000..19f1e09d07 --- /dev/null +++ b/lib/pcg/src/pcg/file_format/v1/graphs/v1_operator_graph.dtg.cc @@ -0,0 +1,52 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/pcg/include/pcg/file_format/v1/graphs/v1_operator_graph.struct.toml +/* proj-data +{ + "generated_from": "5bfd7d8755cfd8cd9dbf57d5c367038e" +} +*/ + +#include "pcg/file_format/v1/graphs/v1_operator_graph.dtg.h" + +#include "pcg/file_format/v1/graphs/v1_graph_edge.dtg.h" +#include "utils/fmt.h" +#include +#include +#include + +namespace FlexFlow { +V1OperatorGraph::V1OperatorGraph( + std::vector const &nodes, + std::unordered_set<::FlexFlow::V1GraphEdge> const &edges) + : nodes(nodes), edges(edges) {} +} // namespace FlexFlow + +namespace nlohmann { +FlexFlow::V1OperatorGraph + adl_serializer::from_json(json const &j) { + return {j.at("nodes").template get>(), + j.at("edges") + .template get>()}; +} +void adl_serializer::to_json( + json &j, FlexFlow::V1OperatorGraph const &v) { + j["__type"] = "V1OperatorGraph"; + j["nodes"] = v.nodes; + j["edges"] = v.edges; +} +} // namespace nlohmann + +namespace FlexFlow { +std::string format_as(V1OperatorGraph const &x) { + std::ostringstream oss; + oss << ""; + return oss.str(); +} +std::ostream &operator<<(std::ostream &s, V1OperatorGraph const &x) { + return s << fmt::to_string(x); +} +} // namespace FlexFlow diff --git a/lib/pcg/src/pcg/gpu_id_t.dtg.cc b/lib/pcg/src/pcg/gpu_id_t.dtg.cc new file mode 100644 index 0000000000..e2385a83ce --- /dev/null +++ b/lib/pcg/src/pcg/gpu_id_t.dtg.cc @@ -0,0 +1,74 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! 
+// If you would like to modify this datatype, instead modify +// lib/pcg/include/pcg/gpu_id_t.struct.toml +/* proj-data +{ + "generated_from": "022355e43f43141d332be50ea3080ee2" +} +*/ + +#include "pcg/gpu_id_t.dtg.h" + +#include + +namespace FlexFlow { +gpu_id_t::gpu_id_t(int const &gpu_index) : gpu_index(gpu_index) {} +bool gpu_id_t::operator==(gpu_id_t const &other) const { + return std::tie(this->gpu_index) == std::tie(other.gpu_index); +} +bool gpu_id_t::operator!=(gpu_id_t const &other) const { + return std::tie(this->gpu_index) != std::tie(other.gpu_index); +} +bool gpu_id_t::operator<(gpu_id_t const &other) const { + return std::tie(this->gpu_index) < std::tie(other.gpu_index); +} +bool gpu_id_t::operator>(gpu_id_t const &other) const { + return std::tie(this->gpu_index) > std::tie(other.gpu_index); +} +bool gpu_id_t::operator<=(gpu_id_t const &other) const { + return std::tie(this->gpu_index) <= std::tie(other.gpu_index); +} +bool gpu_id_t::operator>=(gpu_id_t const &other) const { + return std::tie(this->gpu_index) >= std::tie(other.gpu_index); +} +} // namespace FlexFlow + +namespace std { +size_t hash::operator()(FlexFlow::gpu_id_t const &x) const { + size_t result = 0; + result ^= std::hash{}(x.gpu_index) + 0x9e3779b9 + (result << 6) + + (result >> 2); + return result; +} +} // namespace std + +namespace nlohmann { +FlexFlow::gpu_id_t + adl_serializer::from_json(json const &j) { + return {j.at("gpu_index").template get()}; +} +void adl_serializer::to_json(json &j, + FlexFlow::gpu_id_t const &v) { + j["__type"] = "gpu_id_t"; + j["gpu_index"] = v.gpu_index; +} +} // namespace nlohmann + +namespace rc { +Gen Arbitrary::arbitrary() { + return gen::construct(gen::arbitrary()); +} +} // namespace rc + +namespace FlexFlow { +std::string format_as(gpu_id_t const &x) { + std::ostringstream oss; + oss << ""; + return oss.str(); +} +std::ostream &operator<<(std::ostream &s, gpu_id_t const &x) { + return s << fmt::to_string(x); +} +} // namespace FlexFlow diff --git a/lib/pcg/src/pcg/initializer_attrs.dtg.cc b/lib/pcg/src/pcg/initializer_attrs.dtg.cc new file mode 100644 index 0000000000..2a4e97db1e --- /dev/null +++ b/lib/pcg/src/pcg/initializer_attrs.dtg.cc @@ -0,0 +1,158 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! 
+// If you would like to modify this datatype, instead modify +// lib/pcg/include/pcg/initializer_attrs.variant.toml +/* proj-data +{ + "generated_from": "f66f3a89ea937e96a058d83ab52e2826" +} +*/ + +#include "pcg/initializer_attrs.dtg.h" + +#include "fmt/format.h" +#include +#include + +namespace FlexFlow { +InitializerAttrs::InitializerAttrs(::FlexFlow::GlorotUniformAttrs const &v) + : raw_variant(v) {} +InitializerAttrs::InitializerAttrs(::FlexFlow::ZeroInitializerAttrs const &v) + : raw_variant(v) {} +InitializerAttrs::InitializerAttrs(::FlexFlow::UniformInitializerAttrs const &v) + : raw_variant(v) {} +InitializerAttrs::InitializerAttrs(::FlexFlow::NormInitializerAttrs const &v) + : raw_variant(v) {} +InitializerAttrs::InitializerAttrs( + ::FlexFlow::ConstantInitializerAttrs const &v) + : raw_variant(v) {} +bool InitializerAttrs::operator==(InitializerAttrs const &other) const { + return this->raw_variant == other.raw_variant; +} +bool InitializerAttrs::operator!=(InitializerAttrs const &other) const { + return this->raw_variant != other.raw_variant; +} +bool InitializerAttrs::operator<(InitializerAttrs const &other) const { + return this->raw_variant < other.raw_variant; +} +bool InitializerAttrs::operator>(InitializerAttrs const &other) const { + return this->raw_variant > other.raw_variant; +} +bool InitializerAttrs::operator<=(InitializerAttrs const &other) const { + return this->raw_variant <= other.raw_variant; +} +bool InitializerAttrs::operator>=(InitializerAttrs const &other) const { + return this->raw_variant >= other.raw_variant; +} +} // namespace FlexFlow +namespace std { +size_t hash<::FlexFlow::InitializerAttrs>::operator()( + ::FlexFlow::InitializerAttrs const &x) const { + return std::hash>{}( + x.raw_variant); +} +} // namespace std +namespace nlohmann { +::FlexFlow::InitializerAttrs + adl_serializer<::FlexFlow::InitializerAttrs>::from_json(json const &j) { + std::string key = j.at("type").template get(); + if (key == "glorot_uniform") { + return ::FlexFlow::InitializerAttrs{ + j.at("value").template get<::FlexFlow::GlorotUniformAttrs>()}; + } else if (key == "zero") { + return ::FlexFlow::InitializerAttrs{ + j.at("value").template get<::FlexFlow::ZeroInitializerAttrs>()}; + } else if (key == "uniform") { + return ::FlexFlow::InitializerAttrs{ + j.at("value").template get<::FlexFlow::UniformInitializerAttrs>()}; + } else if (key == "normal") { + return ::FlexFlow::InitializerAttrs{ + j.at("value").template get<::FlexFlow::NormInitializerAttrs>()}; + } else if (key == "constant") { + return ::FlexFlow::InitializerAttrs{ + j.at("value").template get<::FlexFlow::ConstantInitializerAttrs>()}; + } else { + throw std::runtime_error(fmt::format("Unknown type key {}", key)); + } +} +void adl_serializer<::FlexFlow::InitializerAttrs>::to_json( + json &j, ::FlexFlow::InitializerAttrs const &x) { + j["__type"] = "InitializerAttrs"; + switch (x.index()) { + case 0: { + j["type"] = "glorot_uniform"; + j["value"] = x.get<::FlexFlow::GlorotUniformAttrs>(); + break; + } + case 1: { + j["type"] = "zero"; + j["value"] = x.get<::FlexFlow::ZeroInitializerAttrs>(); + break; + } + case 2: { + j["type"] = "uniform"; + j["value"] = x.get<::FlexFlow::UniformInitializerAttrs>(); + break; + } + case 3: { + j["type"] = "normal"; + j["value"] = x.get<::FlexFlow::NormInitializerAttrs>(); + break; + } + case 4: { + j["type"] = "constant"; + j["value"] = x.get<::FlexFlow::ConstantInitializerAttrs>(); + break; + } + default: { + throw std::runtime_error( + fmt::format("Unknown index {} for type 
InitializerAttrs", x.index())); + } + } +} +} // namespace nlohmann +namespace FlexFlow { +std::string format_as(::FlexFlow::InitializerAttrs const &x) { + std::ostringstream oss; + switch (x.index()) { + case 0: { + oss << ""; + break; + } + case 1: { + oss << ""; + break; + } + case 2: { + oss << ""; + break; + } + case 3: { + oss << ""; + break; + } + case 4: { + oss << ""; + break; + } + default: { + throw std::runtime_error( + fmt::format("Unknown index {} for type InitializerAttrs", x.index())); + break; + } + } + return oss.str(); +} +std::ostream &operator<<(std::ostream &s, + ::FlexFlow::InitializerAttrs const &x) { + return s << fmt::to_string(x); +} +} // namespace FlexFlow diff --git a/lib/pcg/src/pcg/initializers/constant_initializer_attrs.dtg.cc b/lib/pcg/src/pcg/initializers/constant_initializer_attrs.dtg.cc new file mode 100644 index 0000000000..9770c35248 --- /dev/null +++ b/lib/pcg/src/pcg/initializers/constant_initializer_attrs.dtg.cc @@ -0,0 +1,80 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/pcg/include/pcg/initializers/constant_initializer_attrs.struct.toml +/* proj-data +{ + "generated_from": "0162b9c49fe6cbfc65410c6fa8dec427" +} +*/ + +#include "pcg/initializers/constant_initializer_attrs.dtg.h" + +#include "op-attrs/datatype.h" +#include "utils/json.h" +#include + +namespace FlexFlow { +ConstantInitializerAttrs::ConstantInitializerAttrs( + ::FlexFlow::DataTypeValue const &value) + : value(value) {} +bool ConstantInitializerAttrs::operator==( + ConstantInitializerAttrs const &other) const { + return std::tie(this->value) == std::tie(other.value); +} +bool ConstantInitializerAttrs::operator!=( + ConstantInitializerAttrs const &other) const { + return std::tie(this->value) != std::tie(other.value); +} +bool ConstantInitializerAttrs::operator<( + ConstantInitializerAttrs const &other) const { + return std::tie(this->value) < std::tie(other.value); +} +bool ConstantInitializerAttrs::operator>( + ConstantInitializerAttrs const &other) const { + return std::tie(this->value) > std::tie(other.value); +} +bool ConstantInitializerAttrs::operator<=( + ConstantInitializerAttrs const &other) const { + return std::tie(this->value) <= std::tie(other.value); +} +bool ConstantInitializerAttrs::operator>=( + ConstantInitializerAttrs const &other) const { + return std::tie(this->value) >= std::tie(other.value); +} +} // namespace FlexFlow + +namespace std { +size_t hash::operator()( + FlexFlow::ConstantInitializerAttrs const &x) const { + size_t result = 0; + result ^= std::hash<::FlexFlow::DataTypeValue>{}(x.value) + 0x9e3779b9 + + (result << 6) + (result >> 2); + return result; +} +} // namespace std + +namespace nlohmann { +FlexFlow::ConstantInitializerAttrs + adl_serializer::from_json( + json const &j) { + return {j.at("value").template get<::FlexFlow::DataTypeValue>()}; +} +void adl_serializer::to_json( + json &j, FlexFlow::ConstantInitializerAttrs const &v) { + j["__type"] = "ConstantInitializerAttrs"; + j["value"] = v.value; +} +} // namespace nlohmann + +namespace FlexFlow { +std::string format_as(ConstantInitializerAttrs const &x) { + std::ostringstream oss; + oss << ""; + return oss.str(); +} +std::ostream &operator<<(std::ostream &s, ConstantInitializerAttrs const &x) { + return s << fmt::to_string(x); +} +} // namespace FlexFlow diff --git a/lib/pcg/src/pcg/initializers/glorot_uniform_attrs.dtg.cc b/lib/pcg/src/pcg/initializers/glorot_uniform_attrs.dtg.cc new file mode 100644 index 
0000000000..0c8ae6e60c --- /dev/null +++ b/lib/pcg/src/pcg/initializers/glorot_uniform_attrs.dtg.cc @@ -0,0 +1,76 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/pcg/include/pcg/initializers/glorot_uniform_attrs.struct.toml +/* proj-data +{ + "generated_from": "a268b411b6d378faa11e60c8517d7be5" +} +*/ + +#include "pcg/initializers/glorot_uniform_attrs.dtg.h" + +#include + +namespace FlexFlow { +GlorotUniformAttrs::GlorotUniformAttrs(int const &seed) : seed(seed) {} +bool GlorotUniformAttrs::operator==(GlorotUniformAttrs const &other) const { + return std::tie(this->seed) == std::tie(other.seed); +} +bool GlorotUniformAttrs::operator!=(GlorotUniformAttrs const &other) const { + return std::tie(this->seed) != std::tie(other.seed); +} +bool GlorotUniformAttrs::operator<(GlorotUniformAttrs const &other) const { + return std::tie(this->seed) < std::tie(other.seed); +} +bool GlorotUniformAttrs::operator>(GlorotUniformAttrs const &other) const { + return std::tie(this->seed) > std::tie(other.seed); +} +bool GlorotUniformAttrs::operator<=(GlorotUniformAttrs const &other) const { + return std::tie(this->seed) <= std::tie(other.seed); +} +bool GlorotUniformAttrs::operator>=(GlorotUniformAttrs const &other) const { + return std::tie(this->seed) >= std::tie(other.seed); +} +} // namespace FlexFlow + +namespace std { +size_t hash::operator()( + FlexFlow::GlorotUniformAttrs const &x) const { + size_t result = 0; + result ^= + std::hash{}(x.seed) + 0x9e3779b9 + (result << 6) + (result >> 2); + return result; +} +} // namespace std + +namespace nlohmann { +FlexFlow::GlorotUniformAttrs + adl_serializer::from_json(json const &j) { + return {j.at("seed").template get()}; +} +void adl_serializer::to_json( + json &j, FlexFlow::GlorotUniformAttrs const &v) { + j["__type"] = "GlorotUniformAttrs"; + j["seed"] = v.seed; +} +} // namespace nlohmann + +namespace rc { +Gen + Arbitrary::arbitrary() { + return gen::construct(gen::arbitrary()); +} +} // namespace rc + +namespace FlexFlow { +std::string format_as(GlorotUniformAttrs const &x) { + std::ostringstream oss; + oss << ""; + return oss.str(); +} +std::ostream &operator<<(std::ostream &s, GlorotUniformAttrs const &x) { + return s << fmt::to_string(x); +} +} // namespace FlexFlow diff --git a/lib/pcg/src/pcg/initializers/norm_initializer_attrs.dtg.cc b/lib/pcg/src/pcg/initializers/norm_initializer_attrs.dtg.cc new file mode 100644 index 0000000000..aceac12212 --- /dev/null +++ b/lib/pcg/src/pcg/initializers/norm_initializer_attrs.dtg.cc @@ -0,0 +1,96 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! 
+// If you would like to modify this datatype, instead modify +// lib/pcg/include/pcg/initializers/norm_initializer_attrs.struct.toml +/* proj-data +{ + "generated_from": "6843fc9ca02aea2b40e57dbc497f99ac" +} +*/ + +#include "pcg/initializers/norm_initializer_attrs.dtg.h" + +#include + +namespace FlexFlow { +NormInitializerAttrs::NormInitializerAttrs(int const &seed, + float const &mean, + float const &stddev) + : seed(seed), mean(mean), stddev(stddev) {} +bool NormInitializerAttrs::operator==(NormInitializerAttrs const &other) const { + return std::tie(this->seed, this->mean, this->stddev) == + std::tie(other.seed, other.mean, other.stddev); +} +bool NormInitializerAttrs::operator!=(NormInitializerAttrs const &other) const { + return std::tie(this->seed, this->mean, this->stddev) != + std::tie(other.seed, other.mean, other.stddev); +} +bool NormInitializerAttrs::operator<(NormInitializerAttrs const &other) const { + return std::tie(this->seed, this->mean, this->stddev) < + std::tie(other.seed, other.mean, other.stddev); +} +bool NormInitializerAttrs::operator>(NormInitializerAttrs const &other) const { + return std::tie(this->seed, this->mean, this->stddev) > + std::tie(other.seed, other.mean, other.stddev); +} +bool NormInitializerAttrs::operator<=(NormInitializerAttrs const &other) const { + return std::tie(this->seed, this->mean, this->stddev) <= + std::tie(other.seed, other.mean, other.stddev); +} +bool NormInitializerAttrs::operator>=(NormInitializerAttrs const &other) const { + return std::tie(this->seed, this->mean, this->stddev) >= + std::tie(other.seed, other.mean, other.stddev); +} +} // namespace FlexFlow + +namespace std { +size_t hash::operator()( + FlexFlow::NormInitializerAttrs const &x) const { + size_t result = 0; + result ^= + std::hash{}(x.seed) + 0x9e3779b9 + (result << 6) + (result >> 2); + result ^= + std::hash{}(x.mean) + 0x9e3779b9 + (result << 6) + (result >> 2); + result ^= + std::hash{}(x.stddev) + 0x9e3779b9 + (result << 6) + (result >> 2); + return result; +} +} // namespace std + +namespace nlohmann { +FlexFlow::NormInitializerAttrs + adl_serializer::from_json(json const &j) { + return {j.at("seed").template get(), + j.at("mean").template get(), + j.at("stddev").template get()}; +} +void adl_serializer::to_json( + json &j, FlexFlow::NormInitializerAttrs const &v) { + j["__type"] = "NormInitializerAttrs"; + j["seed"] = v.seed; + j["mean"] = v.mean; + j["stddev"] = v.stddev; +} +} // namespace nlohmann + +namespace rc { +Gen + Arbitrary::arbitrary() { + return gen::construct( + gen::arbitrary(), gen::arbitrary(), gen::arbitrary()); +} +} // namespace rc + +namespace FlexFlow { +std::string format_as(NormInitializerAttrs const &x) { + std::ostringstream oss; + oss << ""; + return oss.str(); +} +std::ostream &operator<<(std::ostream &s, NormInitializerAttrs const &x) { + return s << fmt::to_string(x); +} +} // namespace FlexFlow diff --git a/lib/pcg/src/pcg/initializers/uniform_initializer_attrs.dtg.cc b/lib/pcg/src/pcg/initializers/uniform_initializer_attrs.dtg.cc new file mode 100644 index 0000000000..a9c62675d0 --- /dev/null +++ b/lib/pcg/src/pcg/initializers/uniform_initializer_attrs.dtg.cc @@ -0,0 +1,95 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! 
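The rc::Arbitrary specializations emitted alongside these types plug them into
rapidcheck, so serialization round-trips can be property-tested in a couple of
lines. A hedged sketch (the property description and main harness are
illustrative, not part of this patch):

    #include "pcg/initializers/norm_initializer_attrs.dtg.h"
    #include <nlohmann/json.hpp>
    #include <rapidcheck.h>

    int main() {
      // For every generated NormInitializerAttrs, to_json followed by
      // from_json should reproduce the original value.
      rc::check("NormInitializerAttrs JSON round-trip",
                [](FlexFlow::NormInitializerAttrs const &attrs) {
                  nlohmann::json j = attrs;
                  RC_ASSERT(j.get<FlexFlow::NormInitializerAttrs>() == attrs);
                });
    }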
+// If you would like to modify this datatype, instead modify +// lib/pcg/include/pcg/initializers/uniform_initializer_attrs.struct.toml +/* proj-data +{ + "generated_from": "f887e1db5d5dc710793ec5fa99bb7cd4" +} +*/ + +#include "pcg/initializers/uniform_initializer_attrs.dtg.h" + +#include + +namespace FlexFlow { +UniformInitializerAttrs::UniformInitializerAttrs(int const &seed, + float const &min_val, + float const &max_val) + : seed(seed), min_val(min_val), max_val(max_val) {} +bool UniformInitializerAttrs::operator==( + UniformInitializerAttrs const &other) const { + return std::tie(this->seed, this->min_val, this->max_val) == + std::tie(other.seed, other.min_val, other.max_val); +} +bool UniformInitializerAttrs::operator!=( + UniformInitializerAttrs const &other) const { + return std::tie(this->seed, this->min_val, this->max_val) != + std::tie(other.seed, other.min_val, other.max_val); +} +bool UniformInitializerAttrs::operator<( + UniformInitializerAttrs const &other) const { + return std::tie(this->seed, this->min_val, this->max_val) < + std::tie(other.seed, other.min_val, other.max_val); +} +bool UniformInitializerAttrs::operator>( + UniformInitializerAttrs const &other) const { + return std::tie(this->seed, this->min_val, this->max_val) > + std::tie(other.seed, other.min_val, other.max_val); +} +bool UniformInitializerAttrs::operator<=( + UniformInitializerAttrs const &other) const { + return std::tie(this->seed, this->min_val, this->max_val) <= + std::tie(other.seed, other.min_val, other.max_val); +} +bool UniformInitializerAttrs::operator>=( + UniformInitializerAttrs const &other) const { + return std::tie(this->seed, this->min_val, this->max_val) >= + std::tie(other.seed, other.min_val, other.max_val); +} +} // namespace FlexFlow + +namespace std { +size_t hash::operator()( + FlexFlow::UniformInitializerAttrs const &x) const { + size_t result = 0; + result ^= + std::hash{}(x.seed) + 0x9e3779b9 + (result << 6) + (result >> 2); + result ^= std::hash{}(x.min_val) + 0x9e3779b9 + (result << 6) + + (result >> 2); + result ^= std::hash{}(x.max_val) + 0x9e3779b9 + (result << 6) + + (result >> 2); + return result; +} +} // namespace std + +namespace nlohmann { +FlexFlow::UniformInitializerAttrs + adl_serializer::from_json( + json const &j) { + return {j.at("seed").template get(), + j.at("min_val").template get(), + j.at("max_val").template get()}; +} +void adl_serializer::to_json( + json &j, FlexFlow::UniformInitializerAttrs const &v) { + j["__type"] = "UniformInitializerAttrs"; + j["seed"] = v.seed; + j["min_val"] = v.min_val; + j["max_val"] = v.max_val; +} +} // namespace nlohmann + +namespace FlexFlow { +std::string format_as(UniformInitializerAttrs const &x) { + std::ostringstream oss; + oss << ""; + return oss.str(); +} +std::ostream &operator<<(std::ostream &s, UniformInitializerAttrs const &x) { + return s << fmt::to_string(x); +} +} // namespace FlexFlow diff --git a/lib/pcg/src/pcg/initializers/zero_initializer_attrs.dtg.cc b/lib/pcg/src/pcg/initializers/zero_initializer_attrs.dtg.cc new file mode 100644 index 0000000000..933501a734 --- /dev/null +++ b/lib/pcg/src/pcg/initializers/zero_initializer_attrs.dtg.cc @@ -0,0 +1,71 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! 
+// If you would like to modify this datatype, instead modify +// lib/pcg/include/pcg/initializers/zero_initializer_attrs.struct.toml +/* proj-data +{ + "generated_from": "a19d5a2cdc67a2840d6ba55250a10411" +} +*/ + +#include "pcg/initializers/zero_initializer_attrs.dtg.h" + +#include + +namespace FlexFlow { +bool ZeroInitializerAttrs::operator==(ZeroInitializerAttrs const &other) const { + return std::tie() == std::tie(); +} +bool ZeroInitializerAttrs::operator!=(ZeroInitializerAttrs const &other) const { + return std::tie() != std::tie(); +} +bool ZeroInitializerAttrs::operator<(ZeroInitializerAttrs const &other) const { + return std::tie() < std::tie(); +} +bool ZeroInitializerAttrs::operator>(ZeroInitializerAttrs const &other) const { + return std::tie() > std::tie(); +} +bool ZeroInitializerAttrs::operator<=(ZeroInitializerAttrs const &other) const { + return std::tie() <= std::tie(); +} +bool ZeroInitializerAttrs::operator>=(ZeroInitializerAttrs const &other) const { + return std::tie() >= std::tie(); +} +} // namespace FlexFlow + +namespace std { +size_t hash::operator()( + FlexFlow::ZeroInitializerAttrs const &x) const { + size_t result = 0; + return result; +} +} // namespace std + +namespace nlohmann { +FlexFlow::ZeroInitializerAttrs + adl_serializer::from_json(json const &j) { + return {}; +} +void adl_serializer::to_json( + json &j, FlexFlow::ZeroInitializerAttrs const &v) { + j["__type"] = "ZeroInitializerAttrs"; +} +} // namespace nlohmann + +namespace rc { +Gen + Arbitrary::arbitrary() { + return gen::construct(); +} +} // namespace rc + +namespace FlexFlow { +std::string format_as(ZeroInitializerAttrs const &x) { + std::ostringstream oss; + oss << ""; + return oss.str(); +} +std::ostream &operator<<(std::ostream &s, ZeroInitializerAttrs const &x) { + return s << fmt::to_string(x); +} +} // namespace FlexFlow diff --git a/lib/pcg/src/pcg/layer_attrs.dtg.cc b/lib/pcg/src/pcg/layer_attrs.dtg.cc new file mode 100644 index 0000000000..21c53ad4e8 --- /dev/null +++ b/lib/pcg/src/pcg/layer_attrs.dtg.cc @@ -0,0 +1,84 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! 
+// If you would like to modify this datatype, instead modify +// lib/pcg/include/pcg/layer_attrs.struct.toml +/* proj-data +{ + "generated_from": "b3e4f0c07a906139b599bd4696cb5e65" +} +*/ + +#include "pcg/layer_attrs.dtg.h" + +#include "op-attrs/computation_graph_op_attrs.dtg.h" +#include "utils/json.h" +#include "utils/stack_string.h" +#include +#include + +namespace FlexFlow { +LayerAttrs::LayerAttrs( + ::FlexFlow::ComputationGraphOpAttrs const &attrs, + std::optional<::FlexFlow::stack_string> const &name) + : attrs(attrs), name(name) {} +bool LayerAttrs::operator==(LayerAttrs const &other) const { + return std::tie(this->attrs, this->name) == std::tie(other.attrs, other.name); +} +bool LayerAttrs::operator!=(LayerAttrs const &other) const { + return std::tie(this->attrs, this->name) != std::tie(other.attrs, other.name); +} +bool LayerAttrs::operator<(LayerAttrs const &other) const { + return std::tie(this->attrs, this->name) < std::tie(other.attrs, other.name); +} +bool LayerAttrs::operator>(LayerAttrs const &other) const { + return std::tie(this->attrs, this->name) > std::tie(other.attrs, other.name); +} +bool LayerAttrs::operator<=(LayerAttrs const &other) const { + return std::tie(this->attrs, this->name) <= std::tie(other.attrs, other.name); +} +bool LayerAttrs::operator>=(LayerAttrs const &other) const { + return std::tie(this->attrs, this->name) >= std::tie(other.attrs, other.name); +} +} // namespace FlexFlow + +namespace std { +size_t hash::operator()( + FlexFlow::LayerAttrs const &x) const { + size_t result = 0; + result ^= std::hash<::FlexFlow::ComputationGraphOpAttrs>{}(x.attrs) + + 0x9e3779b9 + (result << 6) + (result >> 2); + result ^= + std::hash>>{}(x.name) + + 0x9e3779b9 + (result << 6) + (result >> 2); + return result; +} +} // namespace std + +namespace nlohmann { +FlexFlow::LayerAttrs + adl_serializer::from_json(json const &j) { + return { + j.at("attrs").template get<::FlexFlow::ComputationGraphOpAttrs>(), + j.at("name") + .template get>>()}; +} +void adl_serializer::to_json( + json &j, FlexFlow::LayerAttrs const &v) { + j["__type"] = "LayerAttrs"; + j["attrs"] = v.attrs; + j["name"] = v.name; +} +} // namespace nlohmann + +namespace FlexFlow { +std::string format_as(LayerAttrs const &x) { + std::ostringstream oss; + oss << ""; + return oss.str(); +} +std::ostream &operator<<(std::ostream &s, LayerAttrs const &x) { + return s << fmt::to_string(x); +} +} // namespace FlexFlow diff --git a/lib/pcg/src/pcg/layer_guid_t.dtg.cc b/lib/pcg/src/pcg/layer_guid_t.dtg.cc new file mode 100644 index 0000000000..9d92608569 --- /dev/null +++ b/lib/pcg/src/pcg/layer_guid_t.dtg.cc @@ -0,0 +1,59 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! 
+// If you would like to modify this datatype, instead modify +// lib/pcg/include/pcg/layer_guid_t.struct.toml +/* proj-data +{ + "generated_from": "a672ffe470fd1dde8299f91f3038ca7a" +} +*/ + +#include "pcg/layer_guid_t.dtg.h" + +#include "utils/graph.h" +#include + +namespace FlexFlow { +layer_guid_t::layer_guid_t(::FlexFlow::Node const &raw_node) + : raw_node(raw_node) {} +bool layer_guid_t::operator==(layer_guid_t const &other) const { + return std::tie(this->raw_node) == std::tie(other.raw_node); +} +bool layer_guid_t::operator!=(layer_guid_t const &other) const { + return std::tie(this->raw_node) != std::tie(other.raw_node); +} +bool layer_guid_t::operator<(layer_guid_t const &other) const { + return std::tie(this->raw_node) < std::tie(other.raw_node); +} +bool layer_guid_t::operator>(layer_guid_t const &other) const { + return std::tie(this->raw_node) > std::tie(other.raw_node); +} +bool layer_guid_t::operator<=(layer_guid_t const &other) const { + return std::tie(this->raw_node) <= std::tie(other.raw_node); +} +bool layer_guid_t::operator>=(layer_guid_t const &other) const { + return std::tie(this->raw_node) >= std::tie(other.raw_node); +} +} // namespace FlexFlow + +namespace std { +size_t hash::operator()( + FlexFlow::layer_guid_t const &x) const { + size_t result = 0; + result ^= std::hash<::FlexFlow::Node>{}(x.raw_node) + 0x9e3779b9 + + (result << 6) + (result >> 2); + return result; +} +} // namespace std + +namespace FlexFlow { +std::string format_as(layer_guid_t const &x) { + std::ostringstream oss; + oss << ""; + return oss.str(); +} +std::ostream &operator<<(std::ostream &s, layer_guid_t const &x) { + return s << fmt::to_string(x); +} +} // namespace FlexFlow diff --git a/lib/pcg/src/pcg/machine_specification.dtg.cc b/lib/pcg/src/pcg/machine_specification.dtg.cc new file mode 100644 index 0000000000..238c61a014 --- /dev/null +++ b/lib/pcg/src/pcg/machine_specification.dtg.cc @@ -0,0 +1,151 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! 
+// If you would like to modify this datatype, instead modify +// lib/pcg/include/pcg/machine_specification.struct.toml +/* proj-data +{ + "generated_from": "72c3ae372af189d0c8bae74c2dbbc531" +} +*/ + +#include "pcg/machine_specification.dtg.h" + +#include + +namespace FlexFlow { +MachineSpecification::MachineSpecification(int const &num_nodes, + int const &num_cpus_per_node, + int const &num_gpus_per_node, + float const &inter_node_bandwidth, + float const &intra_node_bandwidth) + : num_nodes(num_nodes), num_cpus_per_node(num_cpus_per_node), + num_gpus_per_node(num_gpus_per_node), + inter_node_bandwidth(inter_node_bandwidth), + intra_node_bandwidth(intra_node_bandwidth) {} +bool MachineSpecification::operator==(MachineSpecification const &other) const { + return std::tie(this->num_nodes, + this->num_cpus_per_node, + this->num_gpus_per_node, + this->inter_node_bandwidth, + this->intra_node_bandwidth) == + std::tie(other.num_nodes, + other.num_cpus_per_node, + other.num_gpus_per_node, + other.inter_node_bandwidth, + other.intra_node_bandwidth); +} +bool MachineSpecification::operator!=(MachineSpecification const &other) const { + return std::tie(this->num_nodes, + this->num_cpus_per_node, + this->num_gpus_per_node, + this->inter_node_bandwidth, + this->intra_node_bandwidth) != + std::tie(other.num_nodes, + other.num_cpus_per_node, + other.num_gpus_per_node, + other.inter_node_bandwidth, + other.intra_node_bandwidth); +} +bool MachineSpecification::operator<(MachineSpecification const &other) const { + return std::tie(this->num_nodes, + this->num_cpus_per_node, + this->num_gpus_per_node, + this->inter_node_bandwidth, + this->intra_node_bandwidth) < + std::tie(other.num_nodes, + other.num_cpus_per_node, + other.num_gpus_per_node, + other.inter_node_bandwidth, + other.intra_node_bandwidth); +} +bool MachineSpecification::operator>(MachineSpecification const &other) const { + return std::tie(this->num_nodes, + this->num_cpus_per_node, + this->num_gpus_per_node, + this->inter_node_bandwidth, + this->intra_node_bandwidth) > + std::tie(other.num_nodes, + other.num_cpus_per_node, + other.num_gpus_per_node, + other.inter_node_bandwidth, + other.intra_node_bandwidth); +} +bool MachineSpecification::operator<=(MachineSpecification const &other) const { + return std::tie(this->num_nodes, + this->num_cpus_per_node, + this->num_gpus_per_node, + this->inter_node_bandwidth, + this->intra_node_bandwidth) <= + std::tie(other.num_nodes, + other.num_cpus_per_node, + other.num_gpus_per_node, + other.inter_node_bandwidth, + other.intra_node_bandwidth); +} +bool MachineSpecification::operator>=(MachineSpecification const &other) const { + return std::tie(this->num_nodes, + this->num_cpus_per_node, + this->num_gpus_per_node, + this->inter_node_bandwidth, + this->intra_node_bandwidth) >= + std::tie(other.num_nodes, + other.num_cpus_per_node, + other.num_gpus_per_node, + other.inter_node_bandwidth, + other.intra_node_bandwidth); +} +} // namespace FlexFlow + +namespace std { +size_t hash::operator()( + FlexFlow::MachineSpecification const &x) const { + size_t result = 0; + result ^= std::hash{}(x.num_nodes) + 0x9e3779b9 + (result << 6) + + (result >> 2); + result ^= std::hash{}(x.num_cpus_per_node) + 0x9e3779b9 + (result << 6) + + (result >> 2); + result ^= std::hash{}(x.num_gpus_per_node) + 0x9e3779b9 + (result << 6) + + (result >> 2); + result ^= std::hash{}(x.inter_node_bandwidth) + 0x9e3779b9 + + (result << 6) + (result >> 2); + result ^= std::hash{}(x.intra_node_bandwidth) + 0x9e3779b9 + + (result << 6) + 
(result >> 2);
+  return result;
+}
+} // namespace std
+
+namespace nlohmann {
+FlexFlow::MachineSpecification
+    adl_serializer<FlexFlow::MachineSpecification>::from_json(json const &j) {
+  return {j.at("num_nodes").template get<int>(),
+          j.at("num_cpus_per_node").template get<int>(),
+          j.at("num_gpus_per_node").template get<int>(),
+          j.at("inter_node_bandwidth").template get<float>(),
+          j.at("intra_node_bandwidth").template get<float>()};
+}
+void adl_serializer<FlexFlow::MachineSpecification>::to_json(
+    json &j, FlexFlow::MachineSpecification const &v) {
+  j["__type"] = "MachineSpecification";
+  j["num_nodes"] = v.num_nodes;
+  j["num_cpus_per_node"] = v.num_cpus_per_node;
+  j["num_gpus_per_node"] = v.num_gpus_per_node;
+  j["inter_node_bandwidth"] = v.inter_node_bandwidth;
+  j["intra_node_bandwidth"] = v.intra_node_bandwidth;
+}
+} // namespace nlohmann
+
+namespace FlexFlow {
+std::string format_as(MachineSpecification const &x) {
+  std::ostringstream oss;
+  oss << "<MachineSpecification num_nodes=" << x.num_nodes
+      << " num_cpus_per_node=" << x.num_cpus_per_node
+      << " num_gpus_per_node=" << x.num_gpus_per_node
+      << " inter_node_bandwidth=" << x.inter_node_bandwidth
+      << " intra_node_bandwidth=" << x.intra_node_bandwidth << ">";
+  return oss.str();
+}
+std::ostream &operator<<(std::ostream &s, MachineSpecification const &x) {
+  return s << fmt::to_string(x);
+}
+} // namespace FlexFlow
diff --git a/lib/pcg/src/pcg/machine_view.cc b/lib/pcg/src/pcg/machine_view.cc
new file mode 100644
index 0000000000..ff1d34852b
--- /dev/null
+++ b/lib/pcg/src/pcg/machine_view.cc
@@ -0,0 +1,63 @@
+#include "pcg/machine_view.h"
+#include "pcg/strided_rectangle.dtg.h"
+#include "pcg/strided_rectangle_side.h"
+
+namespace FlexFlow {
+
+std::vector<device_id_t> device_ids(MachineView const &) {
+  NOT_IMPLEMENTED();
+}
+
+std::size_t num_dims(MachineView const &) {
+  NOT_IMPLEMENTED();
+}
+
+std::size_t num_devices(MachineView const &) {
+  NOT_IMPLEMENTED();
+}
+
+DeviceType get_device_type(MachineView const &) {
+  NOT_IMPLEMENTED();
+}
+
+static StridedRectangle make_1d_rect(int start, int stop, int stride) {
+  assert(stop > start);
+  assert(stride > 0);
+  StridedRectangleSide side =
+      strided_side_from_size_and_stride(side_size_t{stop - start}, stride);
+  StridedRectangle rect = {{side}};
+  return rect;
+}
+
+MachineView make_1d_machine_view(gpu_id_t start, gpu_id_t stop, int stride) {
+  StridedRectangle rect = make_1d_rect(start.gpu_index, stop.gpu_index, stride);
+  return {device_id_t{start}, rect};
+}
+
+MachineView make_1d_machine_view(cpu_id_t start, cpu_id_t stop, int stride) {
+  StridedRectangle rect = make_1d_rect(start.cpu_index, stop.cpu_index, stride);
+  return {device_id_t{start}, rect};
+}
+
+MachineView make_1d_machine_view(device_id_t start,
+                                 num_points_t num_points,
+                                 int stride) {
+  NOT_IMPLEMENTED();
+}
+
+MachineView make_1d_machine_view(device_id_t start,
+                                 side_size_t interval_size,
+                                 int stride) {
+  NOT_IMPLEMENTED();
+}
+
+MachineView make_1d_machine_view(device_id_t start, size_t interval_size) {
+  NOT_IMPLEMENTED();
+}
+
+/* device_id_t MachineView::at(FFOrdered const &coord) const { */
+/*   size_t offset = this->rect.at(coord); */
+/*   return this->start + offset; */
+/* } */
+
+} // namespace FlexFlow
diff --git a/lib/pcg/src/pcg/machine_view.dtg.cc b/lib/pcg/src/pcg/machine_view.dtg.cc
new file mode 100644
index 0000000000..edab125e3d
--- /dev/null
+++ b/lib/pcg/src/pcg/machine_view.dtg.cc
@@ -0,0 +1,78 @@
+// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT!
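Context for make_1d_rect above: a 1-D machine view over the half-open device
range [start, stop) with a given stride is a StridedRectangle with a single
side, and strided_side_from_size_and_stride (still NOT_IMPLEMENTED at this
point in the series) has to invert get_side_size, which this patch later
defines as num_points * stride. For example, gpus 0 to 8 with stride 2 give a
side of size 8 covering the 4 points {0, 2, 4, 6}. A sketch of that inversion
under this assumption, with local stand-in types rather than the real headers:

    #include <cassert>

    // Stand-ins for the generated pcg types, for illustration only.
    struct num_points_t { int unwrapped; };
    struct side_size_t { int unwrapped; };
    struct StridedRectangleSide { num_points_t num_points; int stride; };

    // Hypothetical implementation: since size == num_points * stride,
    // num_points = size / stride (size 8, stride 2 -> 4 points).
    StridedRectangleSide strided_side_from_size_and_stride(side_size_t size,
                                                           int stride) {
      assert(stride > 0);
      assert(size.unwrapped % stride == 0);
      return StridedRectangleSide{num_points_t{size.unwrapped / stride},
                                  stride};
    }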
+// If you would like to modify this datatype, instead modify +// lib/pcg/include/pcg/machine_view.struct.toml +/* proj-data +{ + "generated_from": "16c571e6bb82d7ef88e5d2a9146638f4" +} +*/ + +#include "pcg/machine_view.dtg.h" + +#include "pcg/device_id_t.dtg.h" +#include "pcg/strided_rectangle.dtg.h" +#include + +namespace FlexFlow { +MachineView::MachineView(::FlexFlow::device_id_t const &start, + ::FlexFlow::StridedRectangle const &rect) + : start(start), rect(rect) {} +bool MachineView::operator==(MachineView const &other) const { + return std::tie(this->start, this->rect) == std::tie(other.start, other.rect); +} +bool MachineView::operator!=(MachineView const &other) const { + return std::tie(this->start, this->rect) != std::tie(other.start, other.rect); +} +bool MachineView::operator<(MachineView const &other) const { + return std::tie(this->start, this->rect) < std::tie(other.start, other.rect); +} +bool MachineView::operator>(MachineView const &other) const { + return std::tie(this->start, this->rect) > std::tie(other.start, other.rect); +} +bool MachineView::operator<=(MachineView const &other) const { + return std::tie(this->start, this->rect) <= std::tie(other.start, other.rect); +} +bool MachineView::operator>=(MachineView const &other) const { + return std::tie(this->start, this->rect) >= std::tie(other.start, other.rect); +} +} // namespace FlexFlow + +namespace std { +size_t hash::operator()( + FlexFlow::MachineView const &x) const { + size_t result = 0; + result ^= std::hash<::FlexFlow::device_id_t>{}(x.start) + 0x9e3779b9 + + (result << 6) + (result >> 2); + result ^= std::hash<::FlexFlow::StridedRectangle>{}(x.rect) + 0x9e3779b9 + + (result << 6) + (result >> 2); + return result; +} +} // namespace std + +namespace nlohmann { +FlexFlow::MachineView + adl_serializer::from_json(json const &j) { + return {j.at("start").template get<::FlexFlow::device_id_t>(), + j.at("rect").template get<::FlexFlow::StridedRectangle>()}; +} +void adl_serializer::to_json( + json &j, FlexFlow::MachineView const &v) { + j["__type"] = "MachineView"; + j["start"] = v.start; + j["rect"] = v.rect; +} +} // namespace nlohmann + +namespace FlexFlow { +std::string format_as(MachineView const &x) { + std::ostringstream oss; + oss << ""; + return oss.str(); +} +std::ostream &operator<<(std::ostream &s, MachineView const &x) { + return s << fmt::to_string(x); +} +} // namespace FlexFlow diff --git a/lib/pcg/src/pcg/num_points_t.dtg.cc b/lib/pcg/src/pcg/num_points_t.dtg.cc new file mode 100644 index 0000000000..7a0a849814 --- /dev/null +++ b/lib/pcg/src/pcg/num_points_t.dtg.cc @@ -0,0 +1,75 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! 
+// If you would like to modify this datatype, instead modify +// lib/pcg/include/pcg/num_points_t.struct.toml +/* proj-data +{ + "generated_from": "2a862b92055eda0508447d2f4df52f71" +} +*/ + +#include "pcg/num_points_t.dtg.h" + +#include + +namespace FlexFlow { +num_points_t::num_points_t(int const &unwrapped) : unwrapped(unwrapped) {} +bool num_points_t::operator==(num_points_t const &other) const { + return std::tie(this->unwrapped) == std::tie(other.unwrapped); +} +bool num_points_t::operator!=(num_points_t const &other) const { + return std::tie(this->unwrapped) != std::tie(other.unwrapped); +} +bool num_points_t::operator<(num_points_t const &other) const { + return std::tie(this->unwrapped) < std::tie(other.unwrapped); +} +bool num_points_t::operator>(num_points_t const &other) const { + return std::tie(this->unwrapped) > std::tie(other.unwrapped); +} +bool num_points_t::operator<=(num_points_t const &other) const { + return std::tie(this->unwrapped) <= std::tie(other.unwrapped); +} +bool num_points_t::operator>=(num_points_t const &other) const { + return std::tie(this->unwrapped) >= std::tie(other.unwrapped); +} +} // namespace FlexFlow + +namespace std { +size_t hash::operator()( + FlexFlow::num_points_t const &x) const { + size_t result = 0; + result ^= std::hash{}(x.unwrapped) + 0x9e3779b9 + (result << 6) + + (result >> 2); + return result; +} +} // namespace std + +namespace nlohmann { +FlexFlow::num_points_t + adl_serializer::from_json(json const &j) { + return {j.at("unwrapped").template get()}; +} +void adl_serializer::to_json( + json &j, FlexFlow::num_points_t const &v) { + j["__type"] = "num_points_t"; + j["unwrapped"] = v.unwrapped; +} +} // namespace nlohmann + +namespace rc { +Gen Arbitrary::arbitrary() { + return gen::construct(gen::arbitrary()); +} +} // namespace rc + +namespace FlexFlow { +std::string format_as(num_points_t const &x) { + std::ostringstream oss; + oss << ""; + return oss.str(); +} +std::ostream &operator<<(std::ostream &s, num_points_t const &x) { + return s << fmt::to_string(x); +} +} // namespace FlexFlow diff --git a/lib/pcg/src/pcg/operator_graph/operator_graph.cc b/lib/pcg/src/pcg/operator_graph/operator_graph.cc new file mode 100644 index 0000000000..461fc8027c --- /dev/null +++ b/lib/pcg/src/pcg/operator_graph/operator_graph.cc @@ -0,0 +1,48 @@ +#include "pcg/operator_graph/operator_graph.h" +#include "utils/graph.h" + +namespace FlexFlow { + +/* struct OperatorGraphView::Impl { */ +/* MultiDiGraphView raw_graph; */ +/* }; */ + +/* struct OperatorGraph::Impl { */ +/* MultiDiGraph raw_graph; */ +/* }; */ + +/* std::unordered_set get_outputs(OperatorGraphView const + * &g) { */ +/* return transform(get_outputs(g.impl->raw_graph), [](MultiDiOutput const &o) + * {}); */ +/* } */ + +/* std::vector get_outputs(OperatorGraphView const &, Node + * const &) { */ +/* NOT_IMPLEMENTED(); */ +/* } */ + +/* std::unordered_set get_uses(OperatorGraphView const &, + * OperatorGraphOutput const &) { */ +/* NOT_IMPLEMENTED(); */ +/* } */ + +/* Node get_src_node(OperatorGraphEdge const &) { */ +/* NOT_IMPLEMENTED(); */ +/* } */ + +/* Node get_dst_node(OperatorGraphEdge const &) { */ +/* NOT_IMPLEMENTED(); */ +/* } */ + +/* int get_src_idx(OperatorGraphEdge const &) { */ +/* NOT_IMPLEMENTED(); */ +/* } */ + +/* int get_dst_idx(OperatorGraphEdge const &) { */ +/* NOT_IMPLEMENTED(); */ +/* } */ + +/* OperatorGraphView::query_nodes */ + +} // namespace FlexFlow diff --git a/lib/pcg/src/pcg/operator_graph/operator_graph_input.cc 
b/lib/pcg/src/pcg/operator_graph/operator_graph_input.cc new file mode 100644 index 0000000000..945034dd73 --- /dev/null +++ b/lib/pcg/src/pcg/operator_graph/operator_graph_input.cc @@ -0,0 +1,13 @@ +#include "pcg/operator_graph/operator_graph_input.h" + +namespace FlexFlow { + +Node get_node(OperatorGraphInput const &i) { + return i.node; +} + +int get_idx(OperatorGraphInput const &i) { + return i.idx; +} + +} // namespace FlexFlow diff --git a/lib/pcg/src/pcg/operator_graph/operator_graph_input.dtg.cc b/lib/pcg/src/pcg/operator_graph/operator_graph_input.dtg.cc new file mode 100644 index 0000000000..381c948ad0 --- /dev/null +++ b/lib/pcg/src/pcg/operator_graph/operator_graph_input.dtg.cc @@ -0,0 +1,63 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/pcg/include/pcg/operator_graph/operator_graph_input.struct.toml +/* proj-data +{ + "generated_from": "57d9c9afc86f43049c6f035c74477afd" +} +*/ + +#include "pcg/operator_graph/operator_graph_input.dtg.h" + +#include "utils/graph.h" +#include + +namespace FlexFlow { +OperatorGraphInput::OperatorGraphInput(::FlexFlow::Node const &node, + int const &idx) + : node(node), idx(idx) {} +bool OperatorGraphInput::operator==(OperatorGraphInput const &other) const { + return std::tie(this->node, this->idx) == std::tie(other.node, other.idx); +} +bool OperatorGraphInput::operator!=(OperatorGraphInput const &other) const { + return std::tie(this->node, this->idx) != std::tie(other.node, other.idx); +} +bool OperatorGraphInput::operator<(OperatorGraphInput const &other) const { + return std::tie(this->node, this->idx) < std::tie(other.node, other.idx); +} +bool OperatorGraphInput::operator>(OperatorGraphInput const &other) const { + return std::tie(this->node, this->idx) > std::tie(other.node, other.idx); +} +bool OperatorGraphInput::operator<=(OperatorGraphInput const &other) const { + return std::tie(this->node, this->idx) <= std::tie(other.node, other.idx); +} +bool OperatorGraphInput::operator>=(OperatorGraphInput const &other) const { + return std::tie(this->node, this->idx) >= std::tie(other.node, other.idx); +} +} // namespace FlexFlow + +namespace std { +size_t hash::operator()( + FlexFlow::OperatorGraphInput const &x) const { + size_t result = 0; + result ^= std::hash<::FlexFlow::Node>{}(x.node) + 0x9e3779b9 + (result << 6) + + (result >> 2); + result ^= + std::hash{}(x.idx) + 0x9e3779b9 + (result << 6) + (result >> 2); + return result; +} +} // namespace std + +namespace FlexFlow { +std::string format_as(OperatorGraphInput const &x) { + std::ostringstream oss; + oss << ""; + return oss.str(); +} +std::ostream &operator<<(std::ostream &s, OperatorGraphInput const &x) { + return s << fmt::to_string(x); +} +} // namespace FlexFlow diff --git a/lib/pcg/src/pcg/operator_graph/operator_graph_output.cc b/lib/pcg/src/pcg/operator_graph/operator_graph_output.cc new file mode 100644 index 0000000000..bdfe1a9795 --- /dev/null +++ b/lib/pcg/src/pcg/operator_graph/operator_graph_output.cc @@ -0,0 +1,13 @@ +#include "pcg/operator_graph/operator_graph_output.h" + +namespace FlexFlow { + +Node get_node(OperatorGraphOutput const &o) { + return o.node; +} + +int get_idx(OperatorGraphOutput const &o) { + return o.idx; +} + +} // namespace FlexFlow diff --git a/lib/pcg/src/pcg/operator_graph/operator_graph_output.dtg.cc b/lib/pcg/src/pcg/operator_graph/operator_graph_output.dtg.cc new file mode 100644 index 0000000000..88c23c0c67 --- /dev/null +++ 
b/lib/pcg/src/pcg/operator_graph/operator_graph_output.dtg.cc @@ -0,0 +1,63 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/pcg/include/pcg/operator_graph/operator_graph_output.struct.toml +/* proj-data +{ + "generated_from": "3931cb388b00e0634495cdb89cb2af54" +} +*/ + +#include "pcg/operator_graph/operator_graph_output.dtg.h" + +#include "utils/graph.h" +#include + +namespace FlexFlow { +OperatorGraphOutput::OperatorGraphOutput(::FlexFlow::Node const &node, + int const &idx) + : node(node), idx(idx) {} +bool OperatorGraphOutput::operator==(OperatorGraphOutput const &other) const { + return std::tie(this->node, this->idx) == std::tie(other.node, other.idx); +} +bool OperatorGraphOutput::operator!=(OperatorGraphOutput const &other) const { + return std::tie(this->node, this->idx) != std::tie(other.node, other.idx); +} +bool OperatorGraphOutput::operator<(OperatorGraphOutput const &other) const { + return std::tie(this->node, this->idx) < std::tie(other.node, other.idx); +} +bool OperatorGraphOutput::operator>(OperatorGraphOutput const &other) const { + return std::tie(this->node, this->idx) > std::tie(other.node, other.idx); +} +bool OperatorGraphOutput::operator<=(OperatorGraphOutput const &other) const { + return std::tie(this->node, this->idx) <= std::tie(other.node, other.idx); +} +bool OperatorGraphOutput::operator>=(OperatorGraphOutput const &other) const { + return std::tie(this->node, this->idx) >= std::tie(other.node, other.idx); +} +} // namespace FlexFlow + +namespace std { +size_t hash::operator()( + FlexFlow::OperatorGraphOutput const &x) const { + size_t result = 0; + result ^= std::hash<::FlexFlow::Node>{}(x.node) + 0x9e3779b9 + (result << 6) + + (result >> 2); + result ^= + std::hash{}(x.idx) + 0x9e3779b9 + (result << 6) + (result >> 2); + return result; +} +} // namespace std + +namespace FlexFlow { +std::string format_as(OperatorGraphOutput const &x) { + std::ostringstream oss; + oss << ""; + return oss.str(); +} +std::ostream &operator<<(std::ostream &s, OperatorGraphOutput const &x) { + return s << fmt::to_string(x); +} +} // namespace FlexFlow diff --git a/lib/pcg/src/pcg/operator_guid_t.dtg.cc b/lib/pcg/src/pcg/operator_guid_t.dtg.cc new file mode 100644 index 0000000000..46b031f7e1 --- /dev/null +++ b/lib/pcg/src/pcg/operator_guid_t.dtg.cc @@ -0,0 +1,59 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! 
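A note on the comparison boilerplate repeated through these generated files:
std::tie builds a tuple of references over the listed fields, so every
operator delegates to lexicographic tuple comparison (first field first, the
next as a tie-break, and so on). A minimal illustration with a hypothetical
type, not from the patch:

    #include <tuple>

    struct Coord {
      int node;
      int idx;
      // Compares node first, then idx, exactly like the generated
      // operator< bodies above.
      bool operator<(Coord const &other) const {
        return std::tie(node, idx) < std::tie(other.node, other.idx);
      }
    };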
+// If you would like to modify this datatype, instead modify +// lib/pcg/include/pcg/operator_guid_t.struct.toml +/* proj-data +{ + "generated_from": "348b5a610f4ff6f545884564ee9a1e6a" +} +*/ + +#include "pcg/operator_guid_t.dtg.h" + +#include "utils/graph.h" +#include + +namespace FlexFlow { +operator_guid_t::operator_guid_t(::FlexFlow::Node const &raw_graph_node) + : raw_graph_node(raw_graph_node) {} +bool operator_guid_t::operator==(operator_guid_t const &other) const { + return std::tie(this->raw_graph_node) == std::tie(other.raw_graph_node); +} +bool operator_guid_t::operator!=(operator_guid_t const &other) const { + return std::tie(this->raw_graph_node) != std::tie(other.raw_graph_node); +} +bool operator_guid_t::operator<(operator_guid_t const &other) const { + return std::tie(this->raw_graph_node) < std::tie(other.raw_graph_node); +} +bool operator_guid_t::operator>(operator_guid_t const &other) const { + return std::tie(this->raw_graph_node) > std::tie(other.raw_graph_node); +} +bool operator_guid_t::operator<=(operator_guid_t const &other) const { + return std::tie(this->raw_graph_node) <= std::tie(other.raw_graph_node); +} +bool operator_guid_t::operator>=(operator_guid_t const &other) const { + return std::tie(this->raw_graph_node) >= std::tie(other.raw_graph_node); +} +} // namespace FlexFlow + +namespace std { +size_t hash::operator()( + FlexFlow::operator_guid_t const &x) const { + size_t result = 0; + result ^= std::hash<::FlexFlow::Node>{}(x.raw_graph_node) + 0x9e3779b9 + + (result << 6) + (result >> 2); + return result; +} +} // namespace std + +namespace FlexFlow { +std::string format_as(operator_guid_t const &x) { + std::ostringstream oss; + oss << ""; + return oss.str(); +} +std::ostream &operator<<(std::ostream &s, operator_guid_t const &x) { + return s << fmt::to_string(x); +} +} // namespace FlexFlow diff --git a/lib/pcg/src/pcg/optimizers/adam_optimizer_attrs.dtg.cc b/lib/pcg/src/pcg/optimizers/adam_optimizer_attrs.dtg.cc new file mode 100644 index 0000000000..d362459cc3 --- /dev/null +++ b/lib/pcg/src/pcg/optimizers/adam_optimizer_attrs.dtg.cc @@ -0,0 +1,192 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! 
+// If you would like to modify this datatype, instead modify +// lib/pcg/include/pcg/optimizers/adam_optimizer_attrs.struct.toml +/* proj-data +{ + "generated_from": "f49e1bebcb0ef2bc3c210073e3183d4d" +} +*/ + +#include "pcg/optimizers/adam_optimizer_attrs.dtg.h" + +#include + +namespace FlexFlow { +AdamOptimizerAttrs::AdamOptimizerAttrs(double const &alpha, + double const &beta1, + double const &beta2, + double const &weight_decay, + double const &alpha_t, + double const &beta_t, + double const &beta2_t) + : alpha(alpha), beta1(beta1), beta2(beta2), weight_decay(weight_decay), + alpha_t(alpha_t), beta_t(beta_t), beta2_t(beta2_t) {} +bool AdamOptimizerAttrs::operator==(AdamOptimizerAttrs const &other) const { + return std::tie(this->alpha, + this->beta1, + this->beta2, + this->weight_decay, + this->alpha_t, + this->beta_t, + this->beta2_t) == std::tie(other.alpha, + other.beta1, + other.beta2, + other.weight_decay, + other.alpha_t, + other.beta_t, + other.beta2_t); +} +bool AdamOptimizerAttrs::operator!=(AdamOptimizerAttrs const &other) const { + return std::tie(this->alpha, + this->beta1, + this->beta2, + this->weight_decay, + this->alpha_t, + this->beta_t, + this->beta2_t) != std::tie(other.alpha, + other.beta1, + other.beta2, + other.weight_decay, + other.alpha_t, + other.beta_t, + other.beta2_t); +} +bool AdamOptimizerAttrs::operator<(AdamOptimizerAttrs const &other) const { + return std::tie(this->alpha, + this->beta1, + this->beta2, + this->weight_decay, + this->alpha_t, + this->beta_t, + this->beta2_t) < std::tie(other.alpha, + other.beta1, + other.beta2, + other.weight_decay, + other.alpha_t, + other.beta_t, + other.beta2_t); +} +bool AdamOptimizerAttrs::operator>(AdamOptimizerAttrs const &other) const { + return std::tie(this->alpha, + this->beta1, + this->beta2, + this->weight_decay, + this->alpha_t, + this->beta_t, + this->beta2_t) > std::tie(other.alpha, + other.beta1, + other.beta2, + other.weight_decay, + other.alpha_t, + other.beta_t, + other.beta2_t); +} +bool AdamOptimizerAttrs::operator<=(AdamOptimizerAttrs const &other) const { + return std::tie(this->alpha, + this->beta1, + this->beta2, + this->weight_decay, + this->alpha_t, + this->beta_t, + this->beta2_t) <= std::tie(other.alpha, + other.beta1, + other.beta2, + other.weight_decay, + other.alpha_t, + other.beta_t, + other.beta2_t); +} +bool AdamOptimizerAttrs::operator>=(AdamOptimizerAttrs const &other) const { + return std::tie(this->alpha, + this->beta1, + this->beta2, + this->weight_decay, + this->alpha_t, + this->beta_t, + this->beta2_t) >= std::tie(other.alpha, + other.beta1, + other.beta2, + other.weight_decay, + other.alpha_t, + other.beta_t, + other.beta2_t); +} +} // namespace FlexFlow + +namespace std { +size_t hash::operator()( + FlexFlow::AdamOptimizerAttrs const &x) const { + size_t result = 0; + result ^= + std::hash{}(x.alpha) + 0x9e3779b9 + (result << 6) + (result >> 2); + result ^= + std::hash{}(x.beta1) + 0x9e3779b9 + (result << 6) + (result >> 2); + result ^= + std::hash{}(x.beta2) + 0x9e3779b9 + (result << 6) + (result >> 2); + result ^= std::hash{}(x.weight_decay) + 0x9e3779b9 + (result << 6) + + (result >> 2); + result ^= std::hash{}(x.alpha_t) + 0x9e3779b9 + (result << 6) + + (result >> 2); + result ^= std::hash{}(x.beta_t) + 0x9e3779b9 + (result << 6) + + (result >> 2); + result ^= std::hash{}(x.beta2_t) + 0x9e3779b9 + (result << 6) + + (result >> 2); + return result; +} +} // namespace std + +namespace nlohmann { +FlexFlow::AdamOptimizerAttrs + adl_serializer::from_json(json const &j) { + 
return {j.at("alpha").template get(), + j.at("beta1").template get(), + j.at("beta2").template get(), + j.at("weight_decay").template get(), + j.at("alpha_t").template get(), + j.at("beta_t").template get(), + j.at("beta2_t").template get()}; +} +void adl_serializer::to_json( + json &j, FlexFlow::AdamOptimizerAttrs const &v) { + j["__type"] = "AdamOptimizerAttrs"; + j["alpha"] = v.alpha; + j["beta1"] = v.beta1; + j["beta2"] = v.beta2; + j["weight_decay"] = v.weight_decay; + j["alpha_t"] = v.alpha_t; + j["beta_t"] = v.beta_t; + j["beta2_t"] = v.beta2_t; +} +} // namespace nlohmann + +namespace rc { +Gen + Arbitrary::arbitrary() { + return gen::construct(gen::arbitrary(), + gen::arbitrary(), + gen::arbitrary(), + gen::arbitrary(), + gen::arbitrary(), + gen::arbitrary(), + gen::arbitrary()); +} +} // namespace rc + +namespace FlexFlow { +std::string format_as(AdamOptimizerAttrs const &x) { + std::ostringstream oss; + oss << ""; + return oss.str(); +} +std::ostream &operator<<(std::ostream &s, AdamOptimizerAttrs const &x) { + return s << fmt::to_string(x); +} +} // namespace FlexFlow diff --git a/lib/pcg/src/pcg/optimizers/sgd_optimizer_attrs.dtg.cc b/lib/pcg/src/pcg/optimizers/sgd_optimizer_attrs.dtg.cc new file mode 100644 index 0000000000..d5e668917b --- /dev/null +++ b/lib/pcg/src/pcg/optimizers/sgd_optimizer_attrs.dtg.cc @@ -0,0 +1,111 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/pcg/include/pcg/optimizers/sgd_optimizer_attrs.struct.toml +/* proj-data +{ + "generated_from": "d18c91cdddc760f1fb3990d2c817ee87" +} +*/ + +#include "pcg/optimizers/sgd_optimizer_attrs.dtg.h" + +#include + +namespace FlexFlow { +SGDOptimizerAttrs::SGDOptimizerAttrs(double const &lr, + double const &momentum, + bool const &nesterov, + double const &weight_decay) + : lr(lr), momentum(momentum), nesterov(nesterov), + weight_decay(weight_decay) {} +bool SGDOptimizerAttrs::operator==(SGDOptimizerAttrs const &other) const { + return std::tie( + this->lr, this->momentum, this->nesterov, this->weight_decay) == + std::tie(other.lr, other.momentum, other.nesterov, other.weight_decay); +} +bool SGDOptimizerAttrs::operator!=(SGDOptimizerAttrs const &other) const { + return std::tie( + this->lr, this->momentum, this->nesterov, this->weight_decay) != + std::tie(other.lr, other.momentum, other.nesterov, other.weight_decay); +} +bool SGDOptimizerAttrs::operator<(SGDOptimizerAttrs const &other) const { + return std::tie( + this->lr, this->momentum, this->nesterov, this->weight_decay) < + std::tie(other.lr, other.momentum, other.nesterov, other.weight_decay); +} +bool SGDOptimizerAttrs::operator>(SGDOptimizerAttrs const &other) const { + return std::tie( + this->lr, this->momentum, this->nesterov, this->weight_decay) > + std::tie(other.lr, other.momentum, other.nesterov, other.weight_decay); +} +bool SGDOptimizerAttrs::operator<=(SGDOptimizerAttrs const &other) const { + return std::tie( + this->lr, this->momentum, this->nesterov, this->weight_decay) <= + std::tie(other.lr, other.momentum, other.nesterov, other.weight_decay); +} +bool SGDOptimizerAttrs::operator>=(SGDOptimizerAttrs const &other) const { + return std::tie( + this->lr, this->momentum, this->nesterov, this->weight_decay) >= + std::tie(other.lr, other.momentum, other.nesterov, other.weight_decay); +} +} // namespace FlexFlow + +namespace std { +size_t hash::operator()( + FlexFlow::SGDOptimizerAttrs const &x) const { + size_t result = 0; + result ^= + std::hash{}(x.lr) + 0x9e3779b9 
+ (result << 6) + (result >> 2); + result ^= std::hash{}(x.momentum) + 0x9e3779b9 + (result << 6) + + (result >> 2); + result ^= std::hash{}(x.nesterov) + 0x9e3779b9 + (result << 6) + + (result >> 2); + result ^= std::hash{}(x.weight_decay) + 0x9e3779b9 + (result << 6) + + (result >> 2); + return result; +} +} // namespace std + +namespace nlohmann { +FlexFlow::SGDOptimizerAttrs + adl_serializer::from_json(json const &j) { + return {j.at("lr").template get(), + j.at("momentum").template get(), + j.at("nesterov").template get(), + j.at("weight_decay").template get()}; +} +void adl_serializer::to_json( + json &j, FlexFlow::SGDOptimizerAttrs const &v) { + j["__type"] = "SGDOptimizerAttrs"; + j["lr"] = v.lr; + j["momentum"] = v.momentum; + j["nesterov"] = v.nesterov; + j["weight_decay"] = v.weight_decay; +} +} // namespace nlohmann + +namespace rc { +Gen + Arbitrary::arbitrary() { + return gen::construct(gen::arbitrary(), + gen::arbitrary(), + gen::arbitrary(), + gen::arbitrary()); +} +} // namespace rc + +namespace FlexFlow { +std::string format_as(SGDOptimizerAttrs const &x) { + std::ostringstream oss; + oss << ""; + return oss.str(); +} +std::ostream &operator<<(std::ostream &s, SGDOptimizerAttrs const &x) { + return s << fmt::to_string(x); +} +} // namespace FlexFlow diff --git a/lib/pcg/src/pcg/parallel_computation_graph.dtg.cc b/lib/pcg/src/pcg/parallel_computation_graph.dtg.cc new file mode 100644 index 0000000000..e4e1555b4a --- /dev/null +++ b/lib/pcg/src/pcg/parallel_computation_graph.dtg.cc @@ -0,0 +1,21 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/pcg/include/pcg/parallel_computation_graph.struct.toml +/* proj-data +{ + "generated_from": "e4db0f603f7b8947dda13e01f96c40fb" +} +*/ + +#include "pcg/parallel_computation_graph.dtg.h" + +#include "pcg/dataflow_graph.h" +#include "pcg/parallel_layer_attrs.dtg.h" +#include "pcg/parallel_tensor_attrs.dtg.h" + +namespace FlexFlow { +ParallelComputationGraph::ParallelComputationGraph( + ::FlexFlow::DataflowGraph<::FlexFlow::ParallelLayerAttrs, + ::FlexFlow::ParallelTensorAttrs> const &raw_graph) + : raw_graph(raw_graph) {} +} // namespace FlexFlow diff --git a/lib/pcg/src/pcg/parallel_layer_attrs.dtg.cc b/lib/pcg/src/pcg/parallel_layer_attrs.dtg.cc new file mode 100644 index 0000000000..455fb22baf --- /dev/null +++ b/lib/pcg/src/pcg/parallel_layer_attrs.dtg.cc @@ -0,0 +1,83 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! 
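Context for the Adam fields serialized above: beta_t and beta2_t are the
running products beta1^t and beta2^t, and alpha_t is the bias-corrected step
size. This patch only defines the attribute record; assuming the kernels use
the standard Adam formulation, the per-step bookkeeping would look roughly
like:

    #include <cmath>

    // Stand-in for the serialized attribute record above.
    struct AdamState {
      double alpha, beta1, beta2, alpha_t, beta_t, beta2_t;
    };

    // Advance the running products after one optimizer step (standard Adam
    // bias correction; an assumption about the kernels, not from this patch).
    void next_step(AdamState &s) {
      s.beta_t *= s.beta1;   // beta1^t
      s.beta2_t *= s.beta2;  // beta2^t
      s.alpha_t = s.alpha * std::sqrt(1.0 - s.beta2_t) / (1.0 - s.beta_t);
    }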
+// If you would like to modify this datatype, instead modify +// lib/pcg/include/pcg/parallel_layer_attrs.struct.toml +/* proj-data +{ + "generated_from": "97fa0b11c59ae892a8a530ffd67e33ad" +} +*/ + +#include "pcg/parallel_layer_attrs.dtg.h" + +#include "op-attrs/operator_attrs.h" +#include "utils/stack_string.h" +#include +#include + +namespace FlexFlow { +ParallelLayerAttrs::ParallelLayerAttrs( + ::FlexFlow::PCGOperatorAttrs const &attrs, + std::optional<::FlexFlow::stack_string> const &name) + : attrs(attrs), name(name) {} +bool ParallelLayerAttrs::operator==(ParallelLayerAttrs const &other) const { + return std::tie(this->attrs, this->name) == std::tie(other.attrs, other.name); +} +bool ParallelLayerAttrs::operator!=(ParallelLayerAttrs const &other) const { + return std::tie(this->attrs, this->name) != std::tie(other.attrs, other.name); +} +bool ParallelLayerAttrs::operator<(ParallelLayerAttrs const &other) const { + return std::tie(this->attrs, this->name) < std::tie(other.attrs, other.name); +} +bool ParallelLayerAttrs::operator>(ParallelLayerAttrs const &other) const { + return std::tie(this->attrs, this->name) > std::tie(other.attrs, other.name); +} +bool ParallelLayerAttrs::operator<=(ParallelLayerAttrs const &other) const { + return std::tie(this->attrs, this->name) <= std::tie(other.attrs, other.name); +} +bool ParallelLayerAttrs::operator>=(ParallelLayerAttrs const &other) const { + return std::tie(this->attrs, this->name) >= std::tie(other.attrs, other.name); +} +} // namespace FlexFlow + +namespace std { +size_t hash::operator()( + FlexFlow::ParallelLayerAttrs const &x) const { + size_t result = 0; + result ^= std::hash<::FlexFlow::PCGOperatorAttrs>{}(x.attrs) + 0x9e3779b9 + + (result << 6) + (result >> 2); + result ^= + std::hash>>{}(x.name) + + 0x9e3779b9 + (result << 6) + (result >> 2); + return result; +} +} // namespace std + +namespace nlohmann { +FlexFlow::ParallelLayerAttrs + adl_serializer::from_json(json const &j) { + return { + j.at("attrs").template get<::FlexFlow::PCGOperatorAttrs>(), + j.at("name") + .template get>>()}; +} +void adl_serializer::to_json( + json &j, FlexFlow::ParallelLayerAttrs const &v) { + j["__type"] = "ParallelLayerAttrs"; + j["attrs"] = v.attrs; + j["name"] = v.name; +} +} // namespace nlohmann + +namespace FlexFlow { +std::string format_as(ParallelLayerAttrs const &x) { + std::ostringstream oss; + oss << ""; + return oss.str(); +} +std::ostream &operator<<(std::ostream &s, ParallelLayerAttrs const &x) { + return s << fmt::to_string(x); +} +} // namespace FlexFlow diff --git a/lib/pcg/src/pcg/parallel_tensor_attrs.dtg.cc b/lib/pcg/src/pcg/parallel_tensor_attrs.dtg.cc new file mode 100644 index 0000000000..ae5d618172 --- /dev/null +++ b/lib/pcg/src/pcg/parallel_tensor_attrs.dtg.cc @@ -0,0 +1,134 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! 
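One subtlety in the serializers above: fields like the optional layer name are
std::optional, which nlohmann::json does not support out of the box, so the
utils/json.h header that several of these generated files pull in is assumed
to supply an adl_serializer along these lines (a sketch of the usual pattern,
not the actual FlexFlow header):

    #include <nlohmann/json.hpp>
    #include <optional>

    namespace nlohmann {
    template <typename T>
    struct adl_serializer<std::optional<T>> {
      static void to_json(json &j, std::optional<T> const &v) {
        if (v.has_value()) {
          j = v.value();
        } else {
          j = nullptr; // null encodes nullopt
        }
      }
      static std::optional<T> from_json(json const &j) {
        if (j.is_null()) {
          return std::nullopt;
        }
        return j.get<T>();
      }
    };
    } // namespace nlohmann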
+// If you would like to modify this datatype, instead modify +// lib/pcg/include/pcg/parallel_tensor_attrs.struct.toml +/* proj-data +{ + "generated_from": "b3e086b380bbc41d99332e1463a34b28" +} +*/ + +#include "pcg/parallel_tensor_attrs.dtg.h" + +#include "op-attrs/parallel_tensor_shape.dtg.h" +#include "op-attrs/param_sync.dtg.h" +#include "pcg/create_grad.dtg.h" +#include "pcg/initializer_attrs.dtg.h" +#include +#include + +namespace FlexFlow { +ParallelTensorAttrs::ParallelTensorAttrs( + ::FlexFlow::ParallelTensorShape const &shape, + std::optional<::FlexFlow::ParamSync> const &sync_type, + std::optional<::FlexFlow::InitializerAttrs> const &initializer, + ::FlexFlow::CreateGrad const &create_gradients) + : shape(shape), sync_type(sync_type), initializer(initializer), + create_gradients(create_gradients) {} +bool ParallelTensorAttrs::operator==(ParallelTensorAttrs const &other) const { + return std::tie(this->shape, + this->sync_type, + this->initializer, + this->create_gradients) == std::tie(other.shape, + other.sync_type, + other.initializer, + other.create_gradients); +} +bool ParallelTensorAttrs::operator!=(ParallelTensorAttrs const &other) const { + return std::tie(this->shape, + this->sync_type, + this->initializer, + this->create_gradients) != std::tie(other.shape, + other.sync_type, + other.initializer, + other.create_gradients); +} +bool ParallelTensorAttrs::operator<(ParallelTensorAttrs const &other) const { + return std::tie(this->shape, + this->sync_type, + this->initializer, + this->create_gradients) < std::tie(other.shape, + other.sync_type, + other.initializer, + other.create_gradients); +} +bool ParallelTensorAttrs::operator>(ParallelTensorAttrs const &other) const { + return std::tie(this->shape, + this->sync_type, + this->initializer, + this->create_gradients) > std::tie(other.shape, + other.sync_type, + other.initializer, + other.create_gradients); +} +bool ParallelTensorAttrs::operator<=(ParallelTensorAttrs const &other) const { + return std::tie(this->shape, + this->sync_type, + this->initializer, + this->create_gradients) <= std::tie(other.shape, + other.sync_type, + other.initializer, + other.create_gradients); +} +bool ParallelTensorAttrs::operator>=(ParallelTensorAttrs const &other) const { + return std::tie(this->shape, + this->sync_type, + this->initializer, + this->create_gradients) >= std::tie(other.shape, + other.sync_type, + other.initializer, + other.create_gradients); +} +} // namespace FlexFlow + +namespace std { +size_t hash::operator()( + FlexFlow::ParallelTensorAttrs const &x) const { + size_t result = 0; + result ^= std::hash<::FlexFlow::ParallelTensorShape>{}(x.shape) + 0x9e3779b9 + + (result << 6) + (result >> 2); + result ^= std::hash>{}(x.sync_type) + + 0x9e3779b9 + (result << 6) + (result >> 2); + result ^= + std::hash>{}(x.initializer) + + 0x9e3779b9 + (result << 6) + (result >> 2); + result ^= std::hash<::FlexFlow::CreateGrad>{}(x.create_gradients) + + 0x9e3779b9 + (result << 6) + (result >> 2); + return result; +} +} // namespace std + +namespace nlohmann { +FlexFlow::ParallelTensorAttrs + adl_serializer::from_json(json const &j) { + return { + j.at("shape").template get<::FlexFlow::ParallelTensorShape>(), + j.at("sync_type").template get>(), + j.at("initializer") + .template get>(), + j.at("create_gradients").template get<::FlexFlow::CreateGrad>()}; +} +void adl_serializer::to_json( + json &j, FlexFlow::ParallelTensorAttrs const &v) { + j["__type"] = "ParallelTensorAttrs"; + j["shape"] = v.shape; + j["sync_type"] = v.sync_type; + 
j["initializer"] = v.initializer;
+  j["create_gradients"] = v.create_gradients;
+}
+} // namespace nlohmann
+
+namespace FlexFlow {
+std::string format_as(ParallelTensorAttrs const &x) {
+  std::ostringstream oss;
+  oss << "<ParallelTensorAttrs shape=" << x.shape
+      << " sync_type=" << x.sync_type << " initializer=" << x.initializer
+      << " create_gradients=" << x.create_gradients << ">";
+  return oss.str();
+}
+std::ostream &operator<<(std::ostream &s, ParallelTensorAttrs const &x) {
+  return s << fmt::to_string(x);
+}
+} // namespace FlexFlow
diff --git a/lib/pcg/src/pcg/side_size_t.dtg.cc b/lib/pcg/src/pcg/side_size_t.dtg.cc
new file mode 100644
index 0000000000..54db2974fe
--- /dev/null
+++ b/lib/pcg/src/pcg/side_size_t.dtg.cc
@@ -0,0 +1,75 @@
+// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT!
+// If you would like to modify this datatype, instead modify
+// lib/pcg/include/pcg/side_size_t.struct.toml
+/* proj-data
+{
+  "generated_from": "6a1669890e547dcc7a4ddb90be05be15"
+}
+*/
+
+#include "pcg/side_size_t.dtg.h"
+
+#include <sstream>
+
+namespace FlexFlow {
+side_size_t::side_size_t(int const &unwrapped) : unwrapped(unwrapped) {}
+bool side_size_t::operator==(side_size_t const &other) const {
+  return std::tie(this->unwrapped) == std::tie(other.unwrapped);
+}
+bool side_size_t::operator!=(side_size_t const &other) const {
+  return std::tie(this->unwrapped) != std::tie(other.unwrapped);
+}
+bool side_size_t::operator<(side_size_t const &other) const {
+  return std::tie(this->unwrapped) < std::tie(other.unwrapped);
+}
+bool side_size_t::operator>(side_size_t const &other) const {
+  return std::tie(this->unwrapped) > std::tie(other.unwrapped);
+}
+bool side_size_t::operator<=(side_size_t const &other) const {
+  return std::tie(this->unwrapped) <= std::tie(other.unwrapped);
+}
+bool side_size_t::operator>=(side_size_t const &other) const {
+  return std::tie(this->unwrapped) >= std::tie(other.unwrapped);
+}
+} // namespace FlexFlow
+
+namespace std {
+size_t hash<FlexFlow::side_size_t>::operator()(
+    FlexFlow::side_size_t const &x) const {
+  size_t result = 0;
+  result ^= std::hash<int>{}(x.unwrapped) + 0x9e3779b9 + (result << 6) +
+            (result >> 2);
+  return result;
+}
+} // namespace std
+
+namespace nlohmann {
+FlexFlow::side_size_t
+    adl_serializer<FlexFlow::side_size_t>::from_json(json const &j) {
+  return {j.at("unwrapped").template get<int>()};
+}
+void adl_serializer<FlexFlow::side_size_t>::to_json(
+    json &j, FlexFlow::side_size_t const &v) {
+  j["__type"] = "side_size_t";
+  j["unwrapped"] = v.unwrapped;
+}
+} // namespace nlohmann
+
+namespace rc {
+Gen<FlexFlow::side_size_t> Arbitrary<FlexFlow::side_size_t>::arbitrary() {
+  return gen::construct<FlexFlow::side_size_t>(gen::arbitrary<int>());
+}
+} // namespace rc
+
+namespace FlexFlow {
+std::string format_as(side_size_t const &x) {
+  std::ostringstream oss;
+  oss << "<side_size_t unwrapped=" << x.unwrapped << ">";
+  return oss.str();
+}
+std::ostream &operator<<(std::ostream &s, side_size_t const &x) {
+  return s << fmt::to_string(x);
+}
+} // namespace FlexFlow
diff --git a/lib/pcg/src/pcg/strided_rectangle.dtg.cc b/lib/pcg/src/pcg/strided_rectangle.dtg.cc
new file mode 100644
index 0000000000..e743a2722a
--- /dev/null
+++ b/lib/pcg/src/pcg/strided_rectangle.dtg.cc
@@ -0,0 +1,86 @@
+// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT!
+namespace rc {
+Gen<FlexFlow::side_size_t> Arbitrary<FlexFlow::side_size_t>::arbitrary() {
+  return gen::construct<FlexFlow::side_size_t>(gen::arbitrary<int>());
+}
+} // namespace rc
+
+namespace FlexFlow {
+std::string format_as(side_size_t const &x) {
+  std::ostringstream oss;
+  oss << "<side_size_t";
+  oss << " unwrapped=" << x.unwrapped;
+  oss << ">";
+  return oss.str();
+}
+std::ostream &operator<<(std::ostream &s, side_size_t const &x) {
+  return s << fmt::to_string(x);
+}
+} // namespace FlexFlow
diff --git a/lib/pcg/src/pcg/strided_rectangle.dtg.cc b/lib/pcg/src/pcg/strided_rectangle.dtg.cc
new file mode 100644
index 0000000000..e743a2722a
--- /dev/null
+++ b/lib/pcg/src/pcg/strided_rectangle.dtg.cc
@@ -0,0 +1,86 @@
+// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT!
+// If you would like to modify this datatype, instead modify
+// lib/pcg/include/pcg/strided_rectangle.struct.toml
+/* proj-data
+{
+  "generated_from": "817bbe017d179aa469822a4032d08836"
+}
+*/
+
+#include "pcg/strided_rectangle.dtg.h"
+
+#include "op-attrs/dim_ordered.h"
+#include "pcg/strided_rectangle_side.dtg.h"
+#include <sstream>
+
+namespace FlexFlow {
+StridedRectangle::StridedRectangle(
+    ::FlexFlow::FFOrdered<::FlexFlow::StridedRectangleSide> const &sides)
+    : sides(sides) {}
+bool StridedRectangle::operator==(StridedRectangle const &other) const {
+  return std::tie(this->sides) == std::tie(other.sides);
+}
+bool StridedRectangle::operator!=(StridedRectangle const &other) const {
+  return std::tie(this->sides) != std::tie(other.sides);
+}
+bool StridedRectangle::operator<(StridedRectangle const &other) const {
+  return std::tie(this->sides) < std::tie(other.sides);
+}
+bool StridedRectangle::operator>(StridedRectangle const &other) const {
+  return std::tie(this->sides) > std::tie(other.sides);
+}
+bool StridedRectangle::operator<=(StridedRectangle const &other) const {
+  return std::tie(this->sides) <= std::tie(other.sides);
+}
+bool StridedRectangle::operator>=(StridedRectangle const &other) const {
+  return std::tie(this->sides) >= std::tie(other.sides);
+}
+} // namespace FlexFlow
+
+namespace std {
+size_t hash<FlexFlow::StridedRectangle>::operator()(
+    FlexFlow::StridedRectangle const &x) const {
+  size_t result = 0;
+  result ^=
+      std::hash<::FlexFlow::FFOrdered<::FlexFlow::StridedRectangleSide>>{}(
+          x.sides) +
+      0x9e3779b9 + (result << 6) + (result >> 2);
+  return result;
+}
+} // namespace std
+
+namespace nlohmann {
+FlexFlow::StridedRectangle
+    adl_serializer<FlexFlow::StridedRectangle>::from_json(json const &j) {
+  return {j.at("sides")
+              .template get<
+                  ::FlexFlow::FFOrdered<::FlexFlow::StridedRectangleSide>>()};
+}
+void adl_serializer<FlexFlow::StridedRectangle>::to_json(
+    json &j, FlexFlow::StridedRectangle const &v) {
+  j["__type"] = "StridedRectangle";
+  j["sides"] = v.sides;
+}
+} // namespace nlohmann
+
+namespace rc {
+Gen<FlexFlow::StridedRectangle>
+    Arbitrary<FlexFlow::StridedRectangle>::arbitrary() {
+  return gen::construct<FlexFlow::StridedRectangle>(
+      gen::arbitrary<
+          ::FlexFlow::FFOrdered<::FlexFlow::StridedRectangleSide>>());
+}
+} // namespace rc
+
+namespace FlexFlow {
+std::string format_as(StridedRectangle const &x) {
+  std::ostringstream oss;
+  oss << "<StridedRectangle";
+  oss << " sides=" << x.sides;
+  oss << ">";
+  return oss.str();
+}
+std::ostream &operator<<(std::ostream &s, StridedRectangle const &x) {
+  return s << fmt::to_string(x);
+}
+} // namespace FlexFlow
diff --git a/lib/pcg/src/pcg/strided_rectangle_side.cc b/lib/pcg/src/pcg/strided_rectangle_side.cc
new file mode 100644
index 0000000000..80258886d7
--- /dev/null
+++ b/lib/pcg/src/pcg/strided_rectangle_side.cc
@@ -0,0 +1,15 @@
+#include "pcg/strided_rectangle_side.h"
+#include "utils/exception.h"
+
+namespace FlexFlow {
+
+StridedRectangleSide strided_side_from_size_and_stride(side_size_t,
+                                                       int stride) {
+  NOT_IMPLEMENTED();
+}
+
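+// A side with n points at stride s covers offsets {0, s, ..., (n-1)*s};
+// its size is reported here as n * s, i.e. measured in stride cells.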
+side_size_t get_side_size(StridedRectangleSide const &s) {
+  return s.num_points.unwrapped * s.stride;
+}
+
+} // namespace FlexFlow
diff --git a/lib/pcg/src/pcg/strided_rectangle_side.dtg.cc b/lib/pcg/src/pcg/strided_rectangle_side.dtg.cc
new file mode 100644
index 0000000000..0bb31b0496
--- /dev/null
+++ b/lib/pcg/src/pcg/strided_rectangle_side.dtg.cc
@@ -0,0 +1,91 @@
+// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT!
+// If you would like to modify this datatype, instead modify
+// lib/pcg/include/pcg/strided_rectangle_side.struct.toml
+/* proj-data
+{
+  "generated_from": "b14fcf1e28c262d22b92fac691ede3d4"
+}
+*/
+
+#include "pcg/strided_rectangle_side.dtg.h"
+
+#include "pcg/num_points_t.dtg.h"
+#include <sstream>
+
+namespace FlexFlow {
+StridedRectangleSide::StridedRectangleSide(
+    ::FlexFlow::num_points_t const &num_points, int const &stride)
+    : num_points(num_points), stride(stride) {}
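+// All comparisons below delegate to std::tie, giving lexicographic
+// order over (num_points, stride) with num_points compared first.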
+bool StridedRectangleSide::operator==(StridedRectangleSide const &other) const {
+  return std::tie(this->num_points, this->stride) ==
+         std::tie(other.num_points, other.stride);
+}
+bool StridedRectangleSide::operator!=(StridedRectangleSide const &other) const {
+  return std::tie(this->num_points, this->stride) !=
+         std::tie(other.num_points, other.stride);
+}
+bool StridedRectangleSide::operator<(StridedRectangleSide const &other) const {
+  return std::tie(this->num_points, this->stride) <
+         std::tie(other.num_points, other.stride);
+}
+bool StridedRectangleSide::operator>(StridedRectangleSide const &other) const {
+  return std::tie(this->num_points, this->stride) >
+         std::tie(other.num_points, other.stride);
+}
+bool StridedRectangleSide::operator<=(StridedRectangleSide const &other) const {
+  return std::tie(this->num_points, this->stride) <=
+         std::tie(other.num_points, other.stride);
+}
+bool StridedRectangleSide::operator>=(StridedRectangleSide const &other) const {
+  return std::tie(this->num_points, this->stride) >=
+         std::tie(other.num_points, other.stride);
+}
+} // namespace FlexFlow
+
+namespace std {
+size_t hash<FlexFlow::StridedRectangleSide>::operator()(
+    FlexFlow::StridedRectangleSide const &x) const {
+  size_t result = 0;
+  result ^= std::hash<::FlexFlow::num_points_t>{}(x.num_points) + 0x9e3779b9 +
+            (result << 6) + (result >> 2);
+  result ^=
+      std::hash<int>{}(x.stride) + 0x9e3779b9 + (result << 6) + (result >> 2);
+  return result;
+}
+} // namespace std
+
+namespace nlohmann {
+FlexFlow::StridedRectangleSide
+    adl_serializer<FlexFlow::StridedRectangleSide>::from_json(json const &j) {
+  return {j.at("num_points").template get<::FlexFlow::num_points_t>(),
+          j.at("stride").template get<int>()};
+}
+void adl_serializer<FlexFlow::StridedRectangleSide>::to_json(
+    json &j, FlexFlow::StridedRectangleSide const &v) {
+  j["__type"] = "StridedRectangleSide";
+  j["num_points"] = v.num_points;
+  j["stride"] = v.stride;
+}
+} // namespace nlohmann
+
+namespace rc {
+Gen<FlexFlow::StridedRectangleSide>
+    Arbitrary<FlexFlow::StridedRectangleSide>::arbitrary() {
+  return gen::construct<FlexFlow::StridedRectangleSide>(
+      gen::arbitrary<::FlexFlow::num_points_t>(), gen::arbitrary<int>());
+}
+} // namespace rc
+
+namespace FlexFlow {
+std::string format_as(StridedRectangleSide const &x) {
+  std::ostringstream oss;
+  oss << "<StridedRectangleSide";
+  oss << " num_points=" << x.num_points;
+  oss << " stride=" << x.stride;
+  oss << ">";
+  return oss.str();
+}
+std::ostream &operator<<(std::ostream &s, StridedRectangleSide const &x) {
+  return s << fmt::to_string(x);
+}
+} // namespace FlexFlow
diff --git a/lib/pcg/src/pcg/tensor_attrs.dtg.cc b/lib/pcg/src/pcg/tensor_attrs.dtg.cc
new file mode 100644
index 0000000000..46a6fb8d50
--- /dev/null
+++ b/lib/pcg/src/pcg/tensor_attrs.dtg.cc
@@ -0,0 +1,133 @@
+// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT!
+// If you would like to modify this datatype, instead modify
+// lib/pcg/include/pcg/tensor_attrs.struct.toml
+/* proj-data
+{
+  "generated_from": "68447a4357476647ef25dd39dfd12578"
+}
+*/
+
+#include "pcg/tensor_attrs.dtg.h"
+
+#include "op-attrs/param_sync.dtg.h"
+#include "op-attrs/tensor_shape.dtg.h"
+#include "pcg/initializer_attrs.dtg.h"
+#include <sstream>
+#include <tuple>
+
+namespace FlexFlow {
+TensorAttrs::TensorAttrs(
+    ::FlexFlow::TensorShape const &shape,
+    std::optional<::FlexFlow::InitializerAttrs> const &initializer,
+    bool const &create_gradients,
+    std::optional<::FlexFlow::ParamSync> const &sync_type)
+    : shape(shape), initializer(initializer),
+      create_gradients(create_gradients), sync_type(sync_type) {}
+bool TensorAttrs::operator==(TensorAttrs const &other) const {
+  return std::tie(this->shape,
+                  this->initializer,
+                  this->create_gradients,
+                  this->sync_type) == std::tie(other.shape,
+                                               other.initializer,
+                                               other.create_gradients,
+                                               other.sync_type);
+}
+bool TensorAttrs::operator!=(TensorAttrs const &other) const {
+  return std::tie(this->shape,
+                  this->initializer,
+                  this->create_gradients,
+                  this->sync_type) != std::tie(other.shape,
+                                               other.initializer,
+                                               other.create_gradients,
+                                               other.sync_type);
+}
+bool TensorAttrs::operator<(TensorAttrs const &other) const {
+  return std::tie(this->shape,
+                  this->initializer,
+                  this->create_gradients,
+                  this->sync_type) < std::tie(other.shape,
+                                              other.initializer,
+                                              other.create_gradients,
+                                              other.sync_type);
+}
+bool TensorAttrs::operator>(TensorAttrs const &other) const {
+  return std::tie(this->shape,
+                  this->initializer,
+                  this->create_gradients,
+                  this->sync_type) > std::tie(other.shape,
+                                              other.initializer,
+                                              other.create_gradients,
+                                              other.sync_type);
+}
+bool TensorAttrs::operator<=(TensorAttrs const &other) const {
+  return std::tie(this->shape,
+                  this->initializer,
+                  this->create_gradients,
+                  this->sync_type) <= std::tie(other.shape,
+                                               other.initializer,
+                                               other.create_gradients,
+                                               other.sync_type);
+}
+bool TensorAttrs::operator>=(TensorAttrs const &other) const {
+  return std::tie(this->shape,
+                  this->initializer,
+                  this->create_gradients,
+                  this->sync_type) >= std::tie(other.shape,
+                                               other.initializer,
+                                               other.create_gradients,
+                                               other.sync_type);
+}
+} // namespace FlexFlow
+
+namespace std {
+size_t hash<FlexFlow::TensorAttrs>::operator()(
+    FlexFlow::TensorAttrs const &x) const {
+  size_t result = 0;
+  result ^= std::hash<::FlexFlow::TensorShape>{}(x.shape) + 0x9e3779b9 +
+            (result << 6) + (result >> 2);
+  result ^=
+      std::hash<std::optional<::FlexFlow::InitializerAttrs>>{}(x.initializer) +
+      0x9e3779b9 + (result << 6) + (result >> 2);
+  result ^= std::hash<bool>{}(x.create_gradients) + 0x9e3779b9 + (result << 6) +
+            (result >> 2);
+  result ^= std::hash<std::optional<::FlexFlow::ParamSync>>{}(x.sync_type) +
+            0x9e3779b9 + (result << 6) + (result >> 2);
+  return result;
+}
+} // namespace std
+
+namespace nlohmann {
+FlexFlow::TensorAttrs
+    adl_serializer<FlexFlow::TensorAttrs>::from_json(json const &j) {
+  return {
+      j.at("shape").template get<::FlexFlow::TensorShape>(),
+      j.at("initializer")
+          .template get<std::optional<::FlexFlow::InitializerAttrs>>(),
+      j.at("create_gradients").template get<bool>(),
+      j.at("sync_type").template get<std::optional<::FlexFlow::ParamSync>>()};
+}
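+// to_json stamps each object with a "__type" tag so serialized graphs
+// stay self-describing; from_json reads only the named fields.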
+void adl_serializer<FlexFlow::TensorAttrs>::to_json(
+    json &j, FlexFlow::TensorAttrs const &v) {
+  j["__type"] = "TensorAttrs";
+  j["shape"] = v.shape;
+  j["initializer"] = v.initializer;
+  j["create_gradients"] = v.create_gradients;
+  j["sync_type"] = v.sync_type;
+}
+} // namespace nlohmann
+
+namespace FlexFlow {
+std::string format_as(TensorAttrs const &x) {
+  std::ostringstream oss;
+  oss << "<TensorAttrs";
+  oss << " shape=" << x.shape;
+  oss << " initializer=" << x.initializer;
+  oss << " create_gradients=" << x.create_gradients;
+  oss << " sync_type=" << x.sync_type;
+  oss << ">";
+  return oss.str();
+}
+std::ostream &operator<<(std::ostream &s, TensorAttrs const &x) {
+  return s << fmt::to_string(x);
+}
+} // namespace FlexFlow
diff --git a/lib/pcg/src/pcg/tensor_guid_t.dtg.cc b/lib/pcg/src/pcg/tensor_guid_t.dtg.cc
new file mode 100644
index 0000000000..9d57291112
--- /dev/null
+++ b/lib/pcg/src/pcg/tensor_guid_t.dtg.cc
@@ -0,0 +1,59 @@
+// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT!
+// If you would like to modify this datatype, instead modify
+// lib/pcg/include/pcg/tensor_guid_t.struct.toml
+/* proj-data
+{
+  "generated_from": "dc15fcbb876ec70509dfa8b662963bc3"
+}
+*/
+
+#include "pcg/tensor_guid_t.dtg.h"
+
+#include "utils/graph.h"
+#include <sstream>
+
+namespace FlexFlow {
+tensor_guid_t::tensor_guid_t(::FlexFlow::MultiDiOutput const &raw_graph_output)
+    : raw_graph_output(raw_graph_output) {}
+bool tensor_guid_t::operator==(tensor_guid_t const &other) const {
+  return std::tie(this->raw_graph_output) == std::tie(other.raw_graph_output);
+}
+bool tensor_guid_t::operator!=(tensor_guid_t const &other) const {
+  return std::tie(this->raw_graph_output) != std::tie(other.raw_graph_output);
+}
+bool tensor_guid_t::operator<(tensor_guid_t const &other) const {
+  return std::tie(this->raw_graph_output) < std::tie(other.raw_graph_output);
+}
+bool tensor_guid_t::operator>(tensor_guid_t const &other) const {
+  return std::tie(this->raw_graph_output) > std::tie(other.raw_graph_output);
+}
+bool tensor_guid_t::operator<=(tensor_guid_t const &other) const {
+  return std::tie(this->raw_graph_output) <= std::tie(other.raw_graph_output);
+}
+bool tensor_guid_t::operator>=(tensor_guid_t const &other) const {
+  return std::tie(this->raw_graph_output) >= std::tie(other.raw_graph_output);
+}
+} // namespace FlexFlow
+
+namespace std {
+size_t hash<FlexFlow::tensor_guid_t>::operator()(
+    FlexFlow::tensor_guid_t const &x) const {
+  size_t result = 0;
+  result ^= std::hash<::FlexFlow::MultiDiOutput>{}(x.raw_graph_output) +
+            0x9e3779b9 + (result << 6) + (result >> 2);
+  return result;
+}
+} // namespace std
+
+namespace FlexFlow {
+std::string format_as(tensor_guid_t const &x) {
+  std::ostringstream oss;
+  oss << "<tensor_guid_t";
+  oss << " raw_graph_output=" << x.raw_graph_output;
+  oss << ">";
+  return oss.str();
+}
+std::ostream &operator<<(std::ostream &s, tensor_guid_t const &x) {
+  return s << fmt::to_string(x);
+}
+} // namespace FlexFlow
diff --git a/lib/pcg/src/serialization.cc b/lib/pcg/src/serialization.cc
deleted file mode 100644
index 439c03916d..0000000000
--- a/lib/pcg/src/serialization.cc
+++ /dev/null
@@ -1,3 +0,0 @@
-#include "pcg/serialization.h"
-
-namespace FlexFlow {}
diff --git a/lib/pcg/src/strided_rectangle.cc b/lib/pcg/src/strided_rectangle.cc
index 27ef9a7f5b..9c8ff69b42 100644
--- a/lib/pcg/src/strided_rectangle.cc
+++ b/lib/pcg/src/strided_rectangle.cc
@@ -3,34 +3,23 @@
 
 namespace FlexFlow {
 
-size_t StridedRectangle::at(FFOrdered<num_points_t> const &coord) const {
-  assert(coord.size() == this->num_dims());
-
-  size_t _1d_stride = 1;
-  size_t idx = 0;
-  for (auto dim : inner_to_outer_idxs(this->sides)) {
-    idx += this->sides.at(dim).at(coord.at(dim)).value() * _1d_stride;
-    _1d_stride *= this->sides.at(dim).get_size().value();
-  }
-  return idx;
-}
-
-StridedRectangleSide::StridedRectangleSide(side_size_t const &num, int stride)
-    : num_points(num.value()), stride(stride) {}
-
-side_size_t StridedRectangleSide::at(num_points_t) const {
-  NOT_IMPLEMENTED();
-}
-
-num_points_t StridedRectangleSide::at(side_size_t) const {
-  NOT_IMPLEMENTED();
-}
-
-side_size_t StridedRectangleSide::get_size() const {
+/* size_t StridedRectangle::at(FFOrdered<num_points_t> const &coord) const { */
+/*   assert(coord.size() == this->num_dims()); */
+
+/*   size_t _1d_stride = 1; */
+/*   size_t idx = 0; */
+/*   for (auto dim : inner_to_outer_idxs(this->sides)) { */
+/*     idx += this->sides.at(dim).at(coord.at(dim)).value() * _1d_stride; */
+/*     _1d_stride *= this->sides.at(dim).get_size().value(); */
+/*   } */
+/*   return idx; */
+/* } */
+
+size_t get_num_dims(StridedRectangle const &) {
   NOT_IMPLEMENTED();
 }
 
-size_t StridedRectangle::num_dims() const {
+size_t get_side_at_idx(StridedRectangle const &) {
   NOT_IMPLEMENTED();
 }
diff --git a/lib/pcg/src/tensor.cc b/lib/pcg/src/tensor.cc
deleted file mode 100644
index df29ee0065..0000000000
--- a/lib/pcg/src/tensor.cc
+++ /dev/null
@@ -1,17 +0,0 @@
-#include "pcg/tensor.h"
-
-namespace FlexFlow {
-
-Tensor::operator TensorShape() const {
-  return TensorShape{dims, data_type};
-}
-
-TensorShape Tensor::get_shape() const {
-  return TensorShape(*this);
-}
-
-Tensor construct_tensor_from_output_shape(TensorShape const &shape) {
-  return Tensor{shape.dims, shape.data_type, std::nullopt, false, std::nullopt};
-}
-
-} // namespace FlexFlow
diff --git a/lib/pcg/test/CMakeLists.txt b/lib/pcg/test/CMakeLists.txt
new file mode 100644
index 0000000000..685d1d8b88
--- /dev/null
+++ b/lib/pcg/test/CMakeLists.txt
@@ -0,0 +1,13 @@
+ff_add_test_executable(
+  NAME
+    pcg-tests
+  SRC_PATTERNS
+    src/*.cc
+  PRIVATE_INCLUDE
+    src/
+  DEPS
+    utils
+    pcg
+    doctest
+    utils-test-common
+)
diff --git a/lib/pcg/test/src/test_computation_graph_builder.cc b/lib/pcg/test/src/test_computation_graph_builder.cc
new file mode 100644
index 0000000000..e88e231bd0
--- /dev/null
+++ b/lib/pcg/test/src/test_computation_graph_builder.cc
@@ -0,0 +1,28 @@
+#include "doctest/doctest.h"
+#include "pcg/computation_graph.h"
+#include "pcg/computation_graph_builder.h"
+
+TEST_SUITE(FF_TEST_SUITE) {
+  TEST_CASE("ComputationGraphBuilder") {
+    ComputationGraphBuilder b;
+
+    size_t batch_size = 2;
+
+    TensorShape input_shape = {
+        TensorDims{FFOrdered<size_t>{batch_size, 3, 10, 10}},
+        DataType::FLOAT,
+    };
+
+    tensor_guid_t input = b.create_tensor(input_shape, /*create_grad=*/true);
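+    // With a 3x3 kernel, stride 1, and no padding, each 10x10 spatial map
+    // shrinks to (10 - 3 + 1) = 8, so the conv output is 2 x 5 x 8 x 8.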
+    tensor_guid_t output = b.conv2d(input,
+                                    /*outChannels=*/5,
+                                    /*kernelH=*/3,
+                                    /*kernelW=*/3,
+                                    /*strideH=*/1,
+                                    /*strideW=*/1,
+                                    /*paddingH=*/0,
+                                    /*paddingW=*/0);
+    // ComputationGraph cg = b.computation_graph;
+    // CHECK(get_layers(cg).size() == 1);
+  }
+}
diff --git a/lib/runtime/include/runtime/task_spec/concrete_arg.h b/lib/runtime/include/runtime/task_spec/concrete_arg.h
deleted file mode 100644
index 1d973eb81a..0000000000
--- a/lib/runtime/include/runtime/task_spec/concrete_arg.h
+++ /dev/null
@@ -1,46 +0,0 @@
-#ifndef _FLEXFLOW_RUNTIME_INCLUDE_RUNTIME_TASK_SPEC_CONCRETE_ARG_H
-#define _FLEXFLOW_RUNTIME_INCLUDE_RUNTIME_TASK_SPEC_CONCRETE_ARG_H
-
-#include "arg_type_runtime_tag.h"
-#include "utils/type_index.h"
-#include <memory>
-
-namespace FlexFlow {
-
-struct ConcreteArgSpec {
-public:
-  ConcreteArgSpec() = delete;
-
-  template <typename T>
-  T const &get() {
-    assert(this->type_tag.matches<T>());
-
-    return *(T const *)ptr.get();
-  }
-
-  ArgTypeRuntimeTag get_type_tag() const {
-    return this->type_tag;
-  }
-  size_t serialize(Legion::Serializer &) const;
-
-  template <typename T>
-  static ConcreteArgSpec create(T const &t) {
-    static_assert(is_serializable<T>::value, "Type must be serializable");
-
-    return ConcreteArgSpec(type_index<T>(),
-                           std::make_shared<T>(t),
-                           ArgTypeRuntimeTag::create<T>());
-  }
-
-private:
-  ConcreteArgSpec(std::type_index,
-                  std::shared_ptr<void const>,
-                  ArgTypeRuntimeTag const &);
-
-  ArgTypeRuntimeTag type_tag;
-  std::shared_ptr<void const> ptr;
-};
-
-} // namespace FlexFlow
-
-#endif
diff --git a/lib/runtime/src/ops/gather.cc 
b/lib/runtime/src/ops/gather.cc deleted file mode 100644 index 9ef53ffc6a..0000000000 --- a/lib/runtime/src/ops/gather.cc +++ /dev/null @@ -1,416 +0,0 @@ -/* Copyright 2023 CMU, Facebook, LANL, MIT, NVIDIA, and Stanford (alphabetical) - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "gather.h" -#include "embedding.h" -#include "kernels/gather_kernels.h" -#include "legion/legion_utilities.h" - -namespace FlexFlow { - -// declare Legion names -using Legion::ArgumentMap; -using Legion::Context; -using Legion::coord_t; -using Legion::Domain; -using Legion::FutureMap; -using Legion::IndexLauncher; -using Legion::PhysicalRegion; -using Legion::Predicate; -using Legion::Rect; -using Legion::RegionRequirement; -using Legion::Runtime; -using Legion::Task; -using Legion::TaskArgument; -using Legion::TaskLauncher; -using PCG::Node; - -using namespace FlexFlow::Kernels::Gather; - -GatherParams Gather::get_params() const { - GatherParams params; - params.legion_dim = this->legion_dim; - params.layer_guid = this->layer_guid; - return params; -} - -Tensor FFModel::gather(const Tensor input, - const Tensor index, - int dim, - char const *name) { - Layer *gather = new Layer(this, - OP_GATHER, - DT_FLOAT, - name, - 2 /*inputs*/, - 0 /*weights*/, - 1 /*output*/, - input, - index); - assert(index->data_type == DT_INT32 || index->data_type == DT_INT64); - assert(input->num_dims == index->num_dims); - int legion_dim = input->num_dims - 1 - dim; - // https://pytorch.org/docs/stable/generated/torch.gather.html - // Currently we assume index.size(d) == input.size(d) for all - // dimensions d != dim, which is a stronger constraint that PyTorch's - for (int i = 0; i < input->num_dims; i++) { - if (i != legion_dim) { - assert(input->dims[i] == index->dims[i]); - } - } - int dims[MAX_TENSOR_DIM]; - for (int i = 0; i < index->num_dims; i++) { - dims[i] = index->dims[i]; - } - gather->outputs[0] = create_tensor_legion_ordering( - index->num_dims, dims, input->data_type, gather, 0, true /*create_grad*/); - gather->add_int_property("legion_dim", legion_dim); - layers.push_back(gather); - return gather->outputs[0]; -} - -Op *Gather::create_operator_from_layer( - FFModel &model, - Layer const *layer, - std::vector const &inputs) { - long long value; - layer->get_int_property("legion_dim", value); - int legion_dim = value; - return new Gather( - model, layer->layer_guid, inputs[0], inputs[1], legion_dim, layer->name); -} - -Gather::Gather(FFModel &model, - GatherParams const ¶ms, - std::pair const &inputs, - char const *name) - : Gather(model, - params.layer_guid, - inputs.first, - inputs.second, - params.legion_dim, - name) {} - -Gather::Gather(FFModel &model, - LayerID const &_layer_guid, - const ParallelTensor input, - const ParallelTensor index, - int _legion_dim, - char const *name) - : Op(model, - OP_GATHER, - input->data_type, - name, - 2 /*inputs*/, - 0 /*weights*/, - 1 /*outputs*/, - input, - index), - legion_dim(_legion_dim) { - layer_guid = _layer_guid; - // Assume that input and index have 
the same paralleldim except - // for the legion_dim-th dim, which cannot be parallelized - for (int i = 0; i < input->num_dims; i++) { - if (i != legion_dim) { - assert(input->dims[i] == index->dims[i]); - } - } - assert(index->dims[legion_dim].degree == 1); - assert(input->dims[legion_dim].degree == 1); - // output has the same parallel dims as index - ParallelDim dims[MAX_TENSOR_DIM]; - for (int i = 0; i < index->num_dims; i++) { - dims[i] = index->dims[i]; - } - outputs[0] = model.create_parallel_tensor_legion_ordering( - index->num_dims, dims, input->data_type, this); -} - -void Gather::serialize(Legion::Serializer &sez) const { - GatherParams params = get_params(); - sez.serialize(params.legion_dim); - sez.serialize(this->layer_guid.id); -} - -using PCG::Node; -/*static*/ -Node Gather::deserialize(FFModel &ff, - Legion::Deserializer &dez, - ParallelTensor inputs[], - int num_inputs) { - assert(num_inputs == 2); - int legion_dim; - dez.deserialize(legion_dim); - size_t id; - dez.deserialize(id); - LayerID layer_guid(id); - - GatherParams params; - params.legion_dim = legion_dim; - params.layer_guid = layer_guid; - return ff.get_or_create_node({inputs[0], inputs[1]}, params); -} - -Op *Gather::materialize(FFModel &ff, - ParallelTensor inputs[], - int num_inputs) const { - GatherParams params = get_params(); - return new Gather(ff, params, {inputs[0], inputs[1]}, this->name); -} - -void Gather::init(FFModel const &ff) { - assert(check_output_input_weight_same_parallel_is()); - parallel_is = outputs[0]->parallel_is; - ArgumentMap argmap; - Context ctx = ff.config.lg_ctx; - Runtime *runtime = ff.config.lg_hlr; - set_argumentmap_for_init(ff, argmap); - IndexLauncher launcher(GATHER_INIT_TASK_ID, - parallel_is, - TaskArgument(this, sizeof(Gather)), - argmap, - Predicate::TRUE_PRED, - false /*must*/, - 0 /*mapper_id*/, - outputs[0]->machine_view.hash()); - launcher.add_region_requirement(RegionRequirement(inputs[0]->part, - 0 /*projection id*/, - READ_ONLY, - EXCLUSIVE, - inputs[0]->region)); - launcher.add_field(0, FID_DATA); - launcher.add_region_requirement(RegionRequirement(inputs[1]->part, - 0 /*projection id*/, - READ_ONLY, - EXCLUSIVE, - inputs[1]->region)); - launcher.add_field(1, FID_DATA); - launcher.add_region_requirement(RegionRequirement(outputs[0]->part, - 0 /*projection id*/, - WRITE_ONLY, - EXCLUSIVE, - outputs[0]->region)); - launcher.add_field(2, FID_DATA); - FutureMap fm = runtime->execute_index_space(ctx, launcher); - fm.wait_all_results(); - set_opmeta_from_futuremap(ff, fm); -} - -PerDeviceOpState *Gather::init_task(Task const *task, - std::vector const ®ions, - Context ctx, - Runtime *runtime) { - assert(regions.size() == 3); - assert(task->regions.size() == 3); - Gather const *gather = (Gather const *)task->args; - FFHandler handle = *((FFHandler const *)task->local_args); - GatherMeta *m = new GatherMeta(handle, gather); - GenericTensorAccessorR input = helperGetGenericTensorAccessorRO( - m->input_type[0], regions[0], task->regions[0], FID_DATA, ctx, runtime); - GenericTensorAccessorR index = helperGetGenericTensorAccessorRO( - m->input_type[1], regions[1], task->regions[1], FID_DATA, ctx, runtime); - GenericTensorAccessorW output = helperGetGenericTensorAccessorWO( - m->output_type[0], regions[2], task->regions[2], FID_DATA, ctx, runtime); - assert(input.domain.get_dim() == index.domain.get_dim()); - assert(output.domain.get_dim() == index.domain.get_dim()); - for (int i = 0; i < input.domain.get_dim(); i++) { - assert(index.domain.hi()[i] == output.domain.hi()[i]); 
- assert(index.domain.lo()[i] == output.domain.lo()[i]); - if (i != m->legion_dim) { - assert(input.domain.hi()[i] == index.domain.hi()[i]); - assert(input.domain.lo()[i] == index.domain.lo()[i]); - } - } - return m; -} - -void Gather::forward(FFModel const &ff) { - ArgumentMap argmap; - Context ctx = ff.config.lg_ctx; - Runtime *runtime = ff.config.lg_hlr; - set_argumentmap_for_forward(ff, argmap); - IndexLauncher launcher(GATHER_FWD_TASK_ID, - parallel_is, - TaskArgument(nullptr, false), - argmap, - Predicate::TRUE_PRED, - false /*must*/, - 0 /*mapper_id*/, - outputs[0]->machine_view.hash()); - launcher.add_region_requirement(RegionRequirement(inputs[0]->part, - 0 /*projection id*/, - READ_ONLY, - EXCLUSIVE, - inputs[0]->region)); - launcher.add_field(0, FID_DATA); - launcher.add_region_requirement(RegionRequirement(inputs[1]->part, - 0 /*projection id*/, - READ_ONLY, - EXCLUSIVE, - inputs[1]->region)); - launcher.add_field(1, FID_DATA); - launcher.add_region_requirement(RegionRequirement(outputs[0]->part, - 0 /*projection id*/, - WRITE_ONLY, - EXCLUSIVE, - outputs[0]->region)); - launcher.add_field(2, FID_DATA); - runtime->execute_index_space(ctx, launcher); -} - -void Gather::forward_task(Task const *task, - std::vector const ®ions, - Context ctx, - Runtime *runtime) { - assert(regions.size() == 3); - assert(task->regions.size() == 3); - GatherMeta const *m = *((GatherMeta **)task->local_args); - GenericTensorAccessorR input = helperGetGenericTensorAccessorRO( - m->input_type[0], regions[0], task->regions[0], FID_DATA, ctx, runtime); - GenericTensorAccessorR index = helperGetGenericTensorAccessorRO( - m->input_type[1], regions[1], task->regions[1], FID_DATA, ctx, runtime); - GenericTensorAccessorW output = helperGetGenericTensorAccessorWO( - m->output_type[0], regions[2], task->regions[2], FID_DATA, ctx, runtime); - forward_kernel_wrapper(m, input, index, output); -} - -void Gather::backward(FFModel const &ff) { - ArgumentMap argmap; - Context ctx = ff.config.lg_ctx; - Runtime *runtime = ff.config.lg_hlr; - set_argumentmap_for_backward(ff, argmap); - IndexLauncher launcher(GATHER_BWD_TASK_ID, - parallel_is, - TaskArgument(NULL, 0), - argmap, - Predicate::TRUE_PRED, - false /*must*/, - 0 /*mapper_id*/, - outputs[0]->machine_view.hash()); - launcher.add_region_requirement(RegionRequirement(outputs[0]->part_grad, - 0 /*projection id*/, - READ_ONLY, - EXCLUSIVE, - outputs[0]->region_grad)); - launcher.add_field(0, FID_DATA); - launcher.add_region_requirement(RegionRequirement(inputs[1]->part, - 0 /*projection id*/, - READ_ONLY, - EXCLUSIVE, - inputs[1]->region)); - launcher.add_field(1, FID_DATA); - launcher.add_region_requirement(RegionRequirement(inputs[0]->part_grad, - 0 /*projection id*/, - WRITE_ONLY, - EXCLUSIVE, - inputs[0]->region_grad)); - launcher.add_field(2, FID_DATA); - runtime->execute_index_space(ctx, launcher); -} - -void Gather::backward_task(Task const *task, - std::vector const ®ions, - Context ctx, - Runtime *runtime) { - assert(regions.size() == 3); - assert(task->regions.size() == 3); - GatherMeta const *m = *((GatherMeta **)task->local_args); - GenericTensorAccessorR output_grad = helperGetGenericTensorAccessorRO( - m->output_type[0], regions[0], task->regions[0], FID_DATA, ctx, runtime); - GenericTensorAccessorR index = helperGetGenericTensorAccessorRO( - m->input_type[1], regions[1], task->regions[1], FID_DATA, ctx, runtime); - GenericTensorAccessorW input_grad = helperGetGenericTensorAccessorRW( - m->input_type[0], regions[2], task->regions[2], FID_DATA, ctx, 
runtime); - backward_kernel_wrapper(m, output_grad, index, input_grad); -} - -bool Gather::measure_operator_cost(Simulator *sim, - MachineView const &mv, - CostMetrics &cost_metrics) const { - ParallelTensorBase sub_input, sub_index, sub_output; - if (!outputs[0]->get_sub_tensor(mv, sub_output)) { - return false; - } - if (!inputs[0]->get_sub_tensor(mv, sub_input)) { - return false; - } - if (!inputs[1]->get_sub_tensor(mv, sub_index)) { - return false; - } - GatherMeta *m = new GatherMeta(sim->handler, this); - sim->free_all(); - bool out_of_memory = false; - Domain input_domain = sub_input.get_domain(); - void *input_ptr = sim->allocate(sub_input.get_volume(), inputs[0]->data_type); - cost_metrics.inputs_memory += cost_metrics.total_mem_diff_from(sim->offset); - GenericTensorAccessorW input_acc( - inputs[0]->data_type, input_domain, input_ptr); - Domain index_domain = sub_index.get_domain(); - void *index_ptr = sim->allocate(sub_index.get_volume(), inputs[1]->data_type); - cost_metrics.inputs_memory += cost_metrics.total_mem_diff_from(sim->offset); - GenericTensorAccessorW index_acc( - inputs[1]->data_type, index_domain, index_ptr); - out_of_memory = out_of_memory || (input_ptr == NULL) || (index_ptr == NULL); - Domain out_domain = sub_output.get_domain(); - void *output_ptr = - sim->allocate(sub_output.get_volume(), outputs[0]->data_type); - out_of_memory = out_of_memory || (output_ptr == NULL); - cost_metrics.outputs_memory += cost_metrics.total_mem_diff_from(sim->offset); - GenericTensorAccessorW output_acc( - outputs[0]->data_type, out_domain, output_ptr); - if (out_of_memory) { - cost_metrics.forward_time = Simulator::MAXIMUM_TASK_RUN_TIME; - cost_metrics.backward_time = Simulator::MAXIMUM_TASK_RUN_TIME; - return true; - } - - std::function forward, backward; - forward = [&] { - forward_kernel_wrapper(m, input_acc, index_acc, output_acc); - }; - if (sim->computationMode == COMP_MODE_TRAINING) { - backward = [&] { - backward_kernel_wrapper(m, output_acc, index_acc, input_acc); - }; - } - - inner_measure_operator_cost(sim, forward, backward, cost_metrics); - - if (sim->computationMode == COMP_MODE_TRAINING) { - printf("[Measure Gather] name(%s) forward_time(%.4lf) " - "backward_time(%.4lf)\n", - name, - cost_metrics.forward_time, - cost_metrics.backward_time); - } else { - printf("[Measure Gather] name(%s) forward_time(%.4lf)\n", - name, - cost_metrics.forward_time); - } - delete m; - return true; -} - -}; // namespace FlexFlow - -namespace std { -size_t hash::operator()( - FlexFlow::GatherParams const ¶ms) const { - size_t key = 0; - hash_combine(key, params.legion_dim); - hash_combine(key, params.layer_guid.id); - return key; -} -}; // namespace std diff --git a/lib/runtime/src/ops/gather.h b/lib/runtime/src/ops/gather.h deleted file mode 100644 index 1ea20b71f5..0000000000 --- a/lib/runtime/src/ops/gather.h +++ /dev/null @@ -1,78 +0,0 @@ -#ifndef _FLEXFLOW_OPS_GATHER_H -#define _FLEXFLOW_OPS_GATHER_H - -#include "op-attrs/ops/gather.h" -#include "op_task_invocation.h" -#include "sim_environment.h" - -namespace FlexFlow { - -template <> -void register_task(); -template <> -void register_task(); -template <> -void register_task(); - -OpTaskInvocation init(GatherAttrs const &); -OpTaskInvocation forward(GatherAttrs const &); -OpTaskInvocation backward(GatherAttrs const &); - -CostMetrics measure_operator_cost(SimEnvFactory const &sim_factory, - GatherAttrs const &attrs, - ParallelTensorShape const &input_shape, - ParallelTensorShape const &index_shape, - ProfilingSettings const 
&settings, - MachineView const &machine_view); - -/* class Gather : public Op { */ -/* public: */ -/* Gather(FFModel &model, */ -/* ParallelTensor const &input, */ -/* ParallelTensor const &index, */ -/* int legion_dim, */ -/* char const *name = nullptr); */ -/* void init(FFModel const &) override; */ -/* void forward(FFModel const &) override; */ -/* void backward(FFModel const &) override; */ - -/* static Op * */ -/* create_operator_from_layer(FFModel &model, */ -/* Layer const *layer, */ -/* std::vector const &inputs); - */ - -/* static PerDeviceOpState *init_task(Legion::Task const *task, */ -/* std::vector const - * ®ions, */ -/* Legion::Context ctx, */ -/* Legion::Runtime *runtime); */ -/* static void forward_task(Legion::Task const *task, */ -/* std::vector const - * ®ions, */ -/* Legion::Context ctx, */ -/* Legion::Runtime *runtime); */ -/* static void backward_task(Legion::Task const *task, */ -/* std::vector const - * ®ions, */ -/* Legion::Context ctx, */ -/* Legion::Runtime *runtime); */ -/* bool measure_operator_cost(Simulator *sim, */ -/* MachineView const &pc, */ -/* CostMetrics &cost_metrics) const override; */ -/* void serialize(Legion::Serializer &s) const override; */ -/* /1* static PCG::Node deserialize(FFModel &ff, *1/ */ -/* /1* Legion::Deserializer &d, *1/ */ -/* /1* ParallelTensor inputs[], *1/ */ -/* /1* int num_inputs); *1/ */ -/* Op *materialize(FFModel &ff, */ -/* ParallelTensor inputs[], */ -/* int num_inputs) const override; */ - -/* public: */ -/* int legion_dim; */ -/* }; */ - -} // namespace FlexFlow - -#endif diff --git a/lib/runtime/src/parallel_op_info.h b/lib/runtime/src/parallel_op_info.h index 11e2f03477..ebd44f012b 100644 --- a/lib/runtime/src/parallel_op_info.h +++ b/lib/runtime/src/parallel_op_info.h @@ -2,7 +2,7 @@ #define _FLEXFLOW_PARALLEL_OPS_PARALLEL_OP_INFO_H #include "op-attrs/ff_dim.h" -#include "op-attrs/op.h" +#include "op-attrs/operator_type.h" #include "utils/visitable.h" #include #include diff --git a/lib/runtime/src/task_spec/op_task_invocation.cc b/lib/runtime/src/task_spec/op_task_invocation.cc deleted file mode 100644 index fbbfe47726..0000000000 --- a/lib/runtime/src/task_spec/op_task_invocation.cc +++ /dev/null @@ -1,47 +0,0 @@ -#include "op_task_invocation.h" -#include "task_argument_accessor.h" - -namespace FlexFlow { - -OpTaskSignature get_signature(task_id_t const &) { - NOT_IMPLEMENTED(); -} - -OpTensorSpec::OpTensorSpec(TensorRole _role, int _idx) - : role(_role), idx(_idx) {} - -OpTensorSpec input_tensor(int idx) { - return {TensorRole::INPUT, idx}; -} - -OpTensorSpec output_tensor(int idx) { - return {TensorRole::OUTPUT, idx}; -} - -OpTensorSpec weight_tensor(int idx) { - return {TensorRole::WEIGHT, idx}; -} - -// OpTaskBinding::OpTaskBinding() { -// this->serializer.reserve_bytes(sizeof(TaskArgumentFormat)); -// } - -void OpTaskBinding::bind(slot_id slot, OpTensorSpec const &tensor_spec) { - this->tensor_bindings.insert({{slot, IsGrad::NO}, tensor_spec}); -} - -void OpTaskBinding::bind_grad(slot_id slot, OpTensorSpec const &tensor_spec) { - this->tensor_bindings.insert({{slot, IsGrad::YES}, tensor_spec}); -} - -std::unordered_map, OpTensorSpec> const & - OpTaskBinding::get_tensor_bindings() const { - return this->tensor_bindings; -} - -std::unordered_map const & - OpTaskBinding::get_arg_bindings() const { - return this->arg_bindings; -} - -} // namespace FlexFlow diff --git a/lib/runtime/src/task_spec/op_task_invocation.h b/lib/runtime/src/task_spec/op_task_invocation.h deleted file mode 100644 index 
56e709734e..0000000000 --- a/lib/runtime/src/task_spec/op_task_invocation.h +++ /dev/null @@ -1,135 +0,0 @@ -#ifndef _FLEXFLOW_RUNTIME_OP_TASK_SPEC_H -#define _FLEXFLOW_RUNTIME_OP_TASK_SPEC_H - -#include "accessor.h" -#include "index_task_invocation.h" -#include "legion.h" -#include "op_arg_ref.h" -#include "op_task_signature.h" -#include "op_tensor_spec.h" -#include "runtime/config.h" -#include "runtime/profiling.h" -#include "serialization.h" -#include "standard_task_invocation.h" -#include "tasks.h" -#include "utils/bidict.h" -#include "utils/optional.h" -#include "utils/stack_map.h" -#include "variadic_tensor_ref.h" -#include -#include -#include - -namespace FlexFlow { - -enum class IsTrainable { YES, NO }; - -using OpArgSpec = variant; - -struct OpTaskBinding { - OpTaskBinding() = default; - - static_assert(is_subeq_variant::value, ""); - - void bind(slot_id, OpTensorSpec const &); - void bind_grad(slot_id, OpTensorSpec const &); - - template - void bind(slot_id name, VariadicTensorRef const &t) { - NOT_IMPLEMENTED(); - } - - template - void bind_device_specific_arg(slot_id name, T const &t) { - NOT_IMPLEMENTED(); - } - - template - void bind_device_specific_arg(slot_id name, OpArgRef const &t) { - NOT_IMPLEMENTED(); - } - - template - void bind_arg(slot_id name, T const &t) { - this->insert_arg_spec(name, ConcreteArgSpec::create(t)); - } - - template - void bind_arg(slot_id name, RuntimeArgRef const &ref) { - this->insert_arg_spec(name, RuntimeArgRefSpec::create(ref)); - } - - template - void bind_arg(slot_id name, OpArgRef const &ref) { - this->insert_arg_spec(name, OpArgRefSpec::create(ref)); - } - - template - void bind_arg(slot_id name, TypedFuture const &f) { - this->insert_arg_spec(name, CheckedTypedFuture::create(f)); - } - - template - void bind_arg(slot_id name, TypedFutureMap const &fm) { - this->insert_arg_spec(name, CheckedTypedFutureMap::create(fm)); - } - - std::unordered_map, OpTensorSpec> const & - get_tensor_bindings() const; - std::unordered_map const &get_arg_bindings() const; - -private: - void insert_arg_spec(slot_id name, OpArgSpec const &arg_spec) { - assert(!contains_key(this->arg_bindings, name)); - this->arg_bindings.insert({name, arg_spec}); - } - - // template - // ArgSpec generate_arg_spec(T const &t) { - // static_assert(is_serializable, "Type must be serializable"); - - // size_t pre_size = serializer.get_used_bytes(); - // ff_task_serialize(serializer, t); - // size_t post_size = serializer.get_used_bytes(); - // return { - // typeid(T), - // pre_size, - // post_size - pre_size - // }; - // } - - /* Legion::Serializer serializer; */ - std::unordered_map arg_bindings; - std::unordered_map, OpTensorSpec> tensor_bindings; -}; - -struct OpTaskInvocation : public use_visitable_cmp { -public: - OpTaskInvocation() = delete; - OpTaskInvocation(task_id_t const &task_id, OpTaskBinding const &binding) - : task_id(task_id), binding(binding) {} - -public: - task_id_t task_id; - OpTaskBinding binding; -}; - -OpTaskSignature infer_bwd_signature(OpTaskSignature const &fwd); -OpTaskBinding infer_bwd_binding(OpTaskBinding const &fwd); -OpTaskSignature get_op_signature(task_id_t const &); - -/* std::unordered_map get_regions_idxs(TaskArgumentFormat - * const &); */ - -/* TaskArgumentFormat compile_task_invocation(OpTaskSignature const &, - * OpTaskBinding const &); */ - -} // namespace FlexFlow - -#endif diff --git a/lib/runtime/src/task_spec/op_tensor_spec.h b/lib/runtime/src/task_spec/op_tensor_spec.h deleted file mode 100644 index d859bb3072..0000000000 --- 
a/lib/runtime/src/task_spec/op_tensor_spec.h +++ /dev/null @@ -1,20 +0,0 @@ -#ifndef _FLEXFLOW_RUNTIME_SRC_TASK_SPEC_OP_TENSOR_SPEC_REF_H -#define _FLEXFLOW_RUNTIME_SRC_TASK_SPEC_OP_TENSOR_SPEC_REF_H - -#include "op_task_signature.h" - -namespace FlexFlow { - -struct OpTensorSpec { - TensorRole role; - req idx; -}; -FF_VISITABLE_STRUCT(OpTensorSpec, role, idx); - -OpTensorSpec input_tensor(int); -OpTensorSpec output_tensor(int); -OpTensorSpec weight_tensor(int); - -} // namespace FlexFlow - -#endif diff --git a/lib/runtime/src/task_spec/task_argument_accessor.h b/lib/runtime/src/task_spec/task_argument_accessor.h deleted file mode 100644 index 9cc05b8252..0000000000 --- a/lib/runtime/src/task_spec/task_argument_accessor.h +++ /dev/null @@ -1,193 +0,0 @@ -#ifndef _FLEXFLOW_RUNTIME_SRC_TASK_ARGUMENT_ACCESSOR_H -#define _FLEXFLOW_RUNTIME_SRC_TASK_ARGUMENT_ACCESSOR_H - -#include "accessor.h" -#include "device_specific.h" -#include "realm_allocator.h" -#include "runtime/config.h" -#include "utils/exception.h" -#include "utils/stack_map.h" -#include "utils/strong_typedef.h" -#include - -namespace FlexFlow { - -struct region_idx_t : strong_typedef { - using strong_typedef::strong_typedef; -}; - -FF_TYPEDEF_HASHABLE(region_idx_t); -FF_TYPEDEF_PRINTABLE(region_idx_t, "region_idx"); - -using NonvariadicFormat = region_idx_t; -using VariadicFormat = std::vector; - -using TensorArgumentFormat = variant; - -bool is_variadic(TensorArgumentFormat const &); -VariadicFormat get_variadic_format(TensorArgumentFormat const &); -NonvariadicFormat get_nonvariadic_format(TensorArgumentFormat const &); - -struct TaskArgumentFormat { - std::type_index type; - size_t start; - req end; -}; -FF_VISITABLE_STRUCT(TaskArgumentFormat, type, start, end); - -struct FutureArgumentFormat { - std::type_index type; - req future_idx; -}; -FF_VISITABLE_STRUCT(FutureArgumentFormat, type, future_idx); - -struct TaskArgumentsFormat { - TaskArgumentsFormat() = default; - - stack_map region_idxs; - stack_map args; - stack_map futures; - stack_map regions; - stack_map data_types; - - void insert(std::pair const &); - void insert(std::pair const &); - - void insert(region_idx_t, Legion::PrivilegeMode, DataType); - void insert(slot_id, region_idx_t); - void insert(slot_id, std::vector const &); -}; - -FF_VISITABLE_STRUCT_NONSTANDARD_CONSTRUCTION( - TaskArgumentsFormat, region_idxs, args, futures, regions, data_types); - -Legion::PrivilegeMode get_privileges(TaskArgumentsFormat const &, - region_idx_t const &); -Legion::PrivilegeMode get_privileges(TaskArgumentsFormat const &, - parallel_tensor_guid_t const &); -Permissions get_permissions(TaskArgumentsFormat const &, region_idx_t const &); -Permissions get_permissions(TaskArgumentsFormat const &, - parallel_tensor_guid_t const &); -region_idx_t get_region_idx(TaskArgumentsFormat const &, - parallel_tensor_guid_t const &); -DataType get_datatype(TaskArgumentsFormat const &, region_idx_t const &); - -struct TaskArgumentAccessor { - TaskArgumentAccessor(Legion::Task const *task, - std::vector const ®ions, - Legion::Context ctx, - Legion::Runtime *runtime); - - Allocator get_allocator() const { - return get_gpu_memory_allocator(this->task); - } - - template - T const &get_argument(slot_id slot) const { - NOT_IMPLEMENTED(); - // TaskArgumentFormat arg_fmt = this->args_fmt.args.at(slot); - // std::type_index actual_type = arg_fmt.type; - // std::type_index requested_type = {typeid(T)}; - - // if (actual_type != requested_type) { - // throw mk_runtime_error( - // "Type mismatch in argument 
access (\"{}\" != \"{}\")", - // actual_type.name(), - // requested_type.name()); - // } - - // void *start_ptr = &((std::uint8_t *)this->task->args)[arg_fmt.start]; - // Legion::Deserializer dez(start_ptr, arg_fmt.start); - - // return ff_task_deserialize(dez); - } - - template - optional get_optional_argument(slot_id) const { - NOT_IMPLEMENTED(); - } - - template - std::vector get_variadic_argument(slot_id) const { - NOT_IMPLEMENTED(); - } - - template - privilege_mode_to_accessor - get_generic_accessor(region_idx_t const &idx) const { - auto tensor_privs = get_permissions(this->args_fmt, idx); - if (tensor_privs != PRIV) { - throw mk_runtime_error( - "Privilege mismatch while accessing tensor: {} != {}", - tensor_privs, - PRIV); - } - - return helperGetGenericTensorAccessor( - get_datatype(this->args_fmt, idx), - regions[idx.value()], - task->regions[idx.value()], - FID_DATA, - ctx, - runtime); - } - - template - privilege_mode_to_accessor get_tensor(slot_id slot) const { - auto argument_format = - get(this->args_fmt.region_idxs.at(slot)); - - return this->get_generic_accessor(argument_format); - } - - template - privilege_mode_to_accessor get_tensor_grad(slot_id slot) const { - NOT_IMPLEMENTED(); - } - - template - std::vector> - get_variadic_tensor(slot_id slot) const { - std::vector> result; - - auto argument_format = - get(this->args_fmt.region_idxs.at(slot)); - for (NonvariadicFormat const &argument : argument_format) { - result.push_back(this->get_generic_accessor(argument)); - } - - return result; - } - - template - std::vector> - get_variadic_tensor_grad(slot_id slot) const { - NOT_IMPLEMENTED(); - } - - template - T *unwrap(DeviceSpecific const &arg) const { - return arg.get(this->get_device_idx()); - } - - template - DeviceSpecific create_device_specific(Args &&...args) const { - return DeviceSpecific::create(this->get_device_idx(), - std::forward(args)...); - } - - size_t get_device_idx() const { - NOT_IMPLEMENTED(); - } - -private: - Legion::Task const *task; - std::vector const ®ions; - Legion::Context ctx; - Legion::Runtime *runtime; - TaskArgumentsFormat const &args_fmt; -}; - -} // namespace FlexFlow - -#endif diff --git a/lib/runtime/src/task_spec/variadic_tensor_ref.h b/lib/runtime/src/task_spec/variadic_tensor_ref.h deleted file mode 100644 index a9d1b54731..0000000000 --- a/lib/runtime/src/task_spec/variadic_tensor_ref.h +++ /dev/null @@ -1,20 +0,0 @@ -#ifndef _FLEXFLOW_RUNTIME_SRC_TASK_SPEC_VARIADIC_TENSOR_ARG_REF_H -#define _FLEXFLOW_RUNTIME_SRC_TASK_SPEC_VARIADIC_TENSOR_ARG_REF_H - -#include "arg_ref.h" -#include "op_tensor_spec.h" - -namespace FlexFlow { - -enum class VariadicTensorRefType { INPUT_TENSORS }; - -template -using VariadicTensorRef = ArgRef; - -VariadicTensorRef get_input_tensors() { - return {VariadicTensorRefType::INPUT_TENSORS}; -} - -} // namespace FlexFlow - -#endif diff --git a/lib/substitution-generator/include/substitution-generator/json.h b/lib/substitution-generator/include/substitution-generator/json.h index dbde110f8d..5563d8a835 100644 --- a/lib/substitution-generator/include/substitution-generator/json.h +++ b/lib/substitution-generator/include/substitution-generator/json.h @@ -1,166 +1,16 @@ #ifndef _FLEXFLOW_SUBSTITUTION_LOADER_H #define _FLEXFLOW_SUBSTITUTION_LOADER_H -#include "op-attrs/op.h" +#include "substitution-generator/legacy_operator_type.dtg.h" +#include "substitution-generator/legacy_pm_parameter.dtg.h" #include #include +#include namespace FlexFlow { -enum PMParameter { - PM_OP_TYPE, // AnyOp - PM_NUM_INPUTS, // AnyOp - 
PM_NUM_OUTPUTS, // AnyOp - PM_GROUP, // Conv2D - PM_KERNEL_H, // Conv2D, Pool2D - PM_KERNEL_W, // Conv2D, Pool2D - PM_STRIDE_H, // Conv2D, Pool2D - PM_STRIDE_W, // Conv2D, Pool2D - PM_PADDING_H, // Conv2D, Pool2D - PM_PADDING_W, // Conv2D, Pool2D - PM_ACTI, // Conv2D, Pool2D - PM_NUMDIM, // Concat, Transpose - PM_AXIS, // Concat, Split - PM_PERM, // Transpose - PM_OUTSHUFFLE, // Transpose - PM_MERGE_GCONV_COUNT, // MergeGConv - PM_AXES, // Squeeze, Unsqueeze, Reduce* - PM_KEEP_DIMS, // Reduce* - PM_EPSILON, // BatchNorm - PM_REPARTITION_DIM, // Repartition - PM_REPARTITION_DEGREE, // Repartition - PM_REPLICATE_DIM, // Replicate - PM_REPLICATE_DEGREE, // Replicate - PM_COMBINE_DIM, // Combine - PM_COMBINE_DEGREE, // Combine - PM_REDUCTION_DIM, // Reduction - PM_REDUCTION_DEGREE, // Reduction - PM_SOFTMAX_DIM, // Softmax - PM_NUM_HEADS, // MultiHeadAttention - PM_INVALID, - PM_PARALLEL_DIM, - PM_PARALLEL_DEGREE, - PM_PAD, -}; - -NLOHMANN_JSON_SERIALIZE_ENUM(PMParameter, - {{PM_INVALID, nullptr}, - {PM_OP_TYPE, "PM_OP_TYPE"}, - {PM_NUM_INPUTS, "PM_NUM_INPUTS"}, - {PM_NUM_OUTPUTS, "PM_NUM_OUTPUTS"}, - {PM_GROUP, "PM_GROUP"}, - {PM_KERNEL_H, "PM_KERNEL_H"}, - {PM_KERNEL_W, "PM_KERNEL_W"}, - {PM_STRIDE_H, "PM_STRIDE_H"}, - {PM_STRIDE_W, "PM_STRIDE_W"}, - {PM_PADDING_H, "PM_PADDING_H"}, - {PM_PADDING_W, "PM_PADDING_W"}, - {PM_ACTI, "PM_ACTI"}, - {PM_NUMDIM, "PM_NUMDIM"}, - {PM_AXIS, "PM_AXIS"}, - {PM_PERM, "PM_PERM"}, - {PM_OUTSHUFFLE, "PM_OUTSHUFFLE"}, - {PM_MERGE_GCONV_COUNT, "PM_MERGE_GCONV_COUNT"}, - {PM_AXES, "PM_AXES"}, - {PM_KEEP_DIMS, "PM_KEEP_DIMS"}, - {PM_EPSILON, "PM_EPSILON"}, - {PM_REPARTITION_DIM, "PM_REPARTITION_DIM"}, - {PM_REPARTITION_DEGREE, "PM_REPARTITION_DEGREE"}, - {PM_REPLICATE_DIM, "PM_REPLICATE_DIM"}, - {PM_REPLICATE_DEGREE, "PM_REPLICATE_DEGREE"}, - {PM_COMBINE_DIM, "PM_COMBINE_DIM"}, - {PM_COMBINE_DEGREE, "PM_COMBINE_DEGREE"}, - {PM_REDUCTION_DIM, "PM_REDUCTION_DIM"}, - {PM_REDUCTION_DEGREE, "PM_REDUCTION_DEGREE"}, - {PM_SOFTMAX_DIM, "PM_SOFTMAX_DIM"}, - {PM_NUM_HEADS, "PM_NUM_HEADS"}, - {PM_PARALLEL_DIM, "PM_PARALLEL_DIM"}, - {PM_PARALLEL_DEGREE, "PM_PARALLEL_DEGREE"}, - {PM_PAD, "PM_PAD"}}) - -NLOHMANN_JSON_SERIALIZE_ENUM(Op, - {{Op::NOOP, "OP_NOOP"}, - {Op::CONV2D, "OP_CONV2D"}, - {Op::DROPOUT, "OP_DROPOUT"}, - {Op::LINEAR, "OP_LINEAR"}, - {Op::BATCHMATMUL, "OP_BATCHMATMUL"}, - {Op::POOL2D, "OP_POOL2D_MAX"}, - {Op::SCALAR_MULTIPLY, "OP_SCALAR_MULTIPLY"}, - {Op::SCALAR_ADD, "OP_SCALAR_ADD"}, - {Op::SCALAR_FLOOR_DIV, "OP_SCALAR_FLOOR_DIV"}, - {Op::SCALAR_TRUE_DIV, "OP_SCALAR_TRUE_DIV"}, - {Op::SCALAR_SUB, "OP_SCALAR_SUB"}, - {Op::RELU, "OP_RELU"}, - {Op::IDENTITY, "OP_IDENTITY"}, - {Op::SIGMOID, "OP_SIGMOID"}, - {Op::TANH, "OP_TANH"}, - {Op::ELU, "OP_ELU"}, - {Op::FLAT, "OP_FLAT"}, - {Op::SOFTMAX, "OP_SOFTMAX"}, - {Op::BATCHNORM, "OP_BATCHNORM"}, - {Op::CONCAT, "OP_CONCAT"}, - {Op::SPLIT, "OP_SPLIT"}, - {Op::EMBEDDING, "OP_EMBEDDING"}, - {Op::CACHE, "OP_CACHE"}, - {Op::RESHAPE, "OP_RESHAPE"}, - {Op::REVERSE, "OP_REVERSE"}, - {Op::TRANSPOSE, "OP_TRANSPOSE"}, - {Op::EW_ADD, "OP_EW_ADD"}, - {Op::EW_MUL, "OP_EW_MUL"}, - {Op::MATMUL, "OP_MATMUL"}, - {Op::MUL, "OP_MUL"}, - {Op::ENLARGE, "OP_ENLARGE"}, - {Op::SQUEEZE, "OP_SQUEEZE"}, - {Op::UNSQUEEZE, "OP_UNSQUEEZE"}, - {Op::EW_SUB, "OP_EW_SUB"}, - {Op::EW_DIV, "OP_EW_DIV"}, - {Op::EW_EQUAL, "OP_EW_EQUAL"}, - {Op::EW_GREATER, "OP_EW_GREATER"}, - {Op::EW_LESS, "OP_EW_LESS"}, - {Op::EW_MAX, "OP_EW_MAX"}, - {Op::EW_MIN, "OP_EW_MIN"}, - {Op::REDUCE_ARGMAX, "OP_REDUCE_ARGMAX"}, - {Op::REDUCE_ARGMIN, "OP_REDUCE_ARGMIN"}, - 
{Op::REDUCE_MAX, "OP_REDUCE_MAX"},
-                              {Op::REDUCE_MEAN, "OP_REDUCE_MEAN"},
-                              {Op::REDUCE_MIN, "OP_REDUCE_MIN"},
-                              {Op::REDUCE_PROD, "OP_REDUCE_PROD"},
-                              {Op::REDUCE_SUM, "OP_REDUCE_SUM"},
-                              {Op::PAD, "OP_PAD"},
-                              {Op::SHAPE, "OP_SHAPE"},
-                              {Op::SIZE, "OP_SIZE"},
-                              {Op::TOPK, "OP_TOPK"},
-                              {Op::WHERE, "OP_WHERE"},
-                              {Op::CEIL, "OP_CEIL"},
-                              {Op::CAST, "OP_CAST"},
-                              {Op::EXP, "OP_EXP"},
-                              {Op::ROUND, "OP_ROUND"},
-                              {Op::LOG, "OP_LOG"},
-                              {Op::LOGICAL_NOT, "OP_LOGICAL_NOT"},
-                              {Op::SQRT, "OP_SQRT"},
-                              {Op::SIN, "OP_SIN"},
-                              {Op::COS, "OP_COS"},
-                              {Op::LEAKYRELU, "OP_LEAKYRELU"},
-                              {Op::SLICE, "OP_SLICE"},
-                              {Op::RESIZE, "OP_RESIZE"},
-                              {Op::PRELU, "OP_PRELU"},
-                              {Op::GELU, "OP_GELU"},
-                              {Op::MULTIHEAD_ATTENTION,
-                               "OP_MULTIHEAD_ATTENTION"},
-                              {Op::FUSED, "OP_FUSED"},
-                              {Op::RSQRT, "OP_RSQRT"},
-                              {Op::POW, "OP_POW"},
-                              {Op::MEAN, "OP_MEAN"},
-                              {Op::LAYERNORM, "OP_LAYERNORM"},
-                              {Op::REPARTITION, "OP_PARTITION"},
-                              {Op::COMBINE, "OP_COMBINE"},
-                              {Op::REPLICATE, "OP_REPLICATE"},
-                              {Op::REDUCTION, "OP_REDUCE"},
-                              {Op::PIPELINE, "OP_PIPELINE"},
-                              {Op::FUSED_PARALLEL, "OP_FUSED_PARALLEL"}})
-
 struct Parameter {
-  PMParameter key;
+  LegacyPMParameter key;
   int value;
 };
 void from_json(nlohmann::json const &j, Parameter &p);
@@ -172,11 +22,11 @@ struct Tensor {
 void from_json(nlohmann::json const &j, Tensor &t);
 
 struct Operator {
-  OperatorType op_type;
+  LegacyOperatorType op_type;
   std::vector<Tensor> input;
  std::vector<Parameter> para;
 
-  std::optional<int> at(PMParameter key) const;
+  std::optional<int> at(LegacyPMParameter key) const;
 };
 void from_json(nlohmann::json const &j, Operator &t);
diff --git a/lib/substitution-generator/include/substitution-generator/legacy_operator_type.dtg.h b/lib/substitution-generator/include/substitution-generator/legacy_operator_type.dtg.h
new file mode 100644
index 0000000000..a74f6ef4f3
--- /dev/null
+++ b/lib/substitution-generator/include/substitution-generator/legacy_operator_type.dtg.h
@@ -0,0 +1,124 @@
+// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT!
+// If you would like to modify this datatype, instead modify
+// lib/substitution-generator/include/substitution-generator/legacy_operator_type.enum.toml
+/* proj-data
+{
+  "generated_from": "d6ba52e2b0d58b7cb533dae3894b0486"
+}
+*/
+
+#ifndef _FLEXFLOW_LIB_SUBSTITUTION_GENERATOR_INCLUDE_SUBSTITUTION_GENERATOR_LEGACY_OPERATOR_TYPE_DTG_H
+#define _FLEXFLOW_LIB_SUBSTITUTION_GENERATOR_INCLUDE_SUBSTITUTION_GENERATOR_LEGACY_OPERATOR_TYPE_DTG_H
+
+#include "fmt/format.h"
+#include "nlohmann/json.hpp"
+#include "rapidcheck.h"
+#include <functional>
+#include <ostream>
+#include <string>
+
+namespace FlexFlow {
+enum class LegacyOperatorType {
+  NOOP,
+  INPUT,
+  WEIGHT,
+  CONV2D,
+  DROPOUT,
+  LINEAR,
+  BATCHMATMUL,
+  POOL2D,
+  SCALAR_MULTIPLY,
+  SCALAR_ADD,
+  SCALAR_FLOOR_DIV,
+  SCALAR_TRUE_DIV,
+  SCALAR_SUB,
+  RELU,
+  IDENTITY,
+  SIGMOID,
+  TANH,
+  ELU,
+  FLAT,
+  SOFTMAX,
+  BATCHNORM,
+  CONCAT,
+  SPLIT,
+  EMBEDDING,
+  CACHE,
+  RESHAPE,
+  REVERSE,
+  TRANSPOSE,
+  EW_ADD,
+  EW_MUL,
+  MATMUL,
+  MUL,
+  ENLARGE,
+  SQUEEZE,
+  UNSQUEEZE,
+  EW_SUB,
+  EW_DIV,
+  EW_EQUAL,
+  EW_GREATER,
+  EW_LESS,
+  EW_MAX,
+  EW_MIN,
+  REDUCE_ARGMAX,
+  REDUCE_ARGMIN,
+  REDUCE_MAX,
+  REDUCE_MEAN,
+  REDUCE_MIN,
+  REDUCE_PROD,
+  REDUCE_SUM,
+  PAD,
+  SHAPE,
+  SIZE,
+  TOPK,
+  WHERE,
+  CEIL,
+  CAST,
+  EXP,
+  ROUND,
+  LOG,
+  LOGICAL_NOT,
+  SQRT,
+  SIN,
+  COS,
+  LEAKYRELU,
+  SLICE,
+  RESIZE,
+  PRELU,
+  GELU,
+  MULTIHEAD_ATTENTION,
+  FUSED,
+  RSQRT,
+  POW,
+  MEAN,
+  LAYERNORM,
+  GATHER,
+  BROADCAST,
+  REPARTITION,
+  COMBINE,
+  REPLICATE,
+  REDUCTION,
+  BATCH,
+  PIPELINE,
+  FUSED_PARALLEL
+};
+std::string format_as(LegacyOperatorType);
+std::ostream &operator<<(std::ostream &, LegacyOperatorType);
+void to_json(::nlohmann::json &, LegacyOperatorType);
+void from_json(::nlohmann::json const &, LegacyOperatorType &);
+} // namespace FlexFlow
+namespace std {
+template <>
+struct hash<FlexFlow::LegacyOperatorType> {
+  size_t operator()(FlexFlow::LegacyOperatorType) const;
+};
+} // namespace std
+namespace rc {
+template <>
+struct Arbitrary<FlexFlow::LegacyOperatorType> {
+  static Gen<FlexFlow::LegacyOperatorType> arbitrary();
+};
+} // namespace rc
+
+#endif // _FLEXFLOW_LIB_SUBSTITUTION_GENERATOR_INCLUDE_SUBSTITUTION_GENERATOR_LEGACY_OPERATOR_TYPE_DTG_H
diff --git a/lib/substitution-generator/include/substitution-generator/legacy_operator_type.enum.toml b/lib/substitution-generator/include/substitution-generator/legacy_operator_type.enum.toml
new file mode 100644
index 0000000000..3f0bcccf6f
--- /dev/null
+++ b/lib/substitution-generator/include/substitution-generator/legacy_operator_type.enum.toml
@@ -0,0 +1,95 @@
+namespace = "FlexFlow"
+name = "LegacyOperatorType"
+
+features = [
+  "hash",
+  "json",
+  "rapidcheck",
+  "fmt",
+]
+
+values = [
+  { name = "NOOP", json_key = "OP_NOOP" },
+  { name = "INPUT", json_key = "OP_INPUT" },
+  { name = "WEIGHT", json_key = "OP_WEIGHT" },
+  { name = "CONV2D", json_key = "OP_CONV2D" },
+  { name = "DROPOUT", json_key = "OP_DROPOUT" },
+  { name = "LINEAR", json_key = "OP_LINEAR" },
+  { name = "BATCHMATMUL", json_key = "OP_BATCHMATMUL" },
+  { name = "POOL2D", json_key = "OP_POOL2D" },
+  { name = "SCALAR_MULTIPLY", json_key = "OP_SCALAR_MULTIPLY" },
+  { name = "SCALAR_ADD", json_key = "OP_SCALAR_ADD" },
+  { name = "SCALAR_FLOOR_DIV", json_key = "OP_SCALAR_FLOOR_DIV" },
+  { name = "SCALAR_TRUE_DIV", json_key = "OP_SCALAR_TRUE_DIV" },
+  { name = "SCALAR_SUB", json_key = "OP_SCALAR_SUB" },
+  { name = "RELU", json_key = "OP_RELU" },
+  { name = "IDENTITY", json_key = "OP_IDENTITY" },
+  { name = "SIGMOID", json_key = "OP_SIGMOID" },
+  { name = "TANH", json_key = "OP_TANH" },
+  { name = "ELU", json_key = "OP_ELU" },
+  { name = 
"FLAT", json_key = "OP_FLAT" }, + { name = "SOFTMAX", json_key = "OP_SOFTMAX" }, + { name = "BATCHNORM", json_key = "OP_BATCHNORM" }, + { name = "CONCAT", json_key = "OP_CONCAT" }, + { name = "SPLIT", json_key = "OP_SPLIT" }, + { name = "EMBEDDING", json_key = "OP_EMBEDDING" }, + { name = "CACHE", json_key = "OP_CACHE" }, + { name = "RESHAPE", json_key = "OP_RESHAPE" }, + { name = "REVERSE", json_key = "OP_REVERSE" }, + { name = "TRANSPOSE", json_key = "OP_TRANSPOSE" }, + { name = "EW_ADD", json_key = "OP_EW_ADD" }, + { name = "EW_MUL", json_key = "OP_EW_MUL" }, + { name = "MATMUL", json_key = "OP_MATMUL" }, + { name = "MUL", json_key = "OP_MUL" }, + { name = "ENLARGE", json_key = "OP_ENLARGE" }, + { name = "SQUEEZE", json_key = "OP_SQUEEZE" }, + { name = "UNSQUEEZE", json_key = "OP_UNSQUEEZE" }, + { name = "EW_SUB", json_key = "OP_EW_SUB" }, + { name = "EW_DIV", json_key = "OP_EW_DIV" }, + { name = "EW_EQUAL", json_key = "OP_EW_EQUAL" }, + { name = "EW_GREATER", json_key = "OP_EW_GREATER" }, + { name = "EW_LESS", json_key = "OP_EW_LESS" }, + { name = "EW_MAX", json_key = "OP_EW_MAX" }, + { name = "EW_MIN", json_key = "OP_EW_MIN" }, + { name = "REDUCE_ARGMAX", json_key = "OP_REDUCE_ARGMAX" }, + { name = "REDUCE_ARGMIN", json_key = "OP_REDUCE_ARGMIN" }, + { name = "REDUCE_MAX", json_key = "OP_REDUCE_MAX" }, + { name = "REDUCE_MEAN", json_key = "OP_REDUCE_MEAN" }, + { name = "REDUCE_MIN", json_key = "OP_REDUCE_MIN" }, + { name = "REDUCE_PROD", json_key = "OP_REDUCE_PROD" }, + { name = "REDUCE_SUM", json_key = "OP_REDUCE_SUM" }, + { name = "PAD", json_key = "OP_PAD" }, + { name = "SHAPE", json_key = "OP_SHAPE" }, + { name = "SIZE", json_key = "OP_SIZE" }, + { name = "TOPK", json_key = "OP_TOPK" }, + { name = "WHERE", json_key = "OP_WHERE" }, + { name = "CEIL", json_key = "OP_CEIL" }, + { name = "CAST", json_key = "OP_CAST" }, + { name = "EXP", json_key = "OP_EXP" }, + { name = "ROUND", json_key = "OP_ROUND" }, + { name = "LOG", json_key = "OP_LOG" }, + { name = "LOGICAL_NOT", json_key = "OP_LOGICAL_NOT" }, + { name = "SQRT", json_key = "OP_SQRT" }, + { name = "SIN", json_key = "OP_SIN" }, + { name = "COS", json_key = "OP_COS" }, + { name = "LEAKYRELU", json_key = "OP_LEAKYRELU" }, + { name = "SLICE", json_key = "OP_SLICE" }, + { name = "RESIZE", json_key = "OP_RESIZE" }, + { name = "PRELU", json_key = "OP_PRELU" }, + { name = "GELU", json_key = "OP_GELU" }, + { name = "MULTIHEAD_ATTENTION", json_key = "OP_MULTIHEAD_ATTENTION" }, + { name = "FUSED", json_key = "OP_FUSED" }, + { name = "RSQRT", json_key = "OP_RSQRT" }, + { name = "POW", json_key = "OP_POW" }, + { name = "MEAN", json_key = "OP_MEAN" }, + { name = "LAYERNORM", json_key = "OP_LAYERNORM" }, + { name = "GATHER", json_key = "OP_GATHER" }, + { name = "BROADCAST", json_key = "OP_BROADCAST" }, + { name = "REPARTITION", json_key = "OP_PARTITION" }, + { name = "COMBINE", json_key = "OP_COMBINE" }, + { name = "REPLICATE", json_key = "OP_REPLICATE" }, + { name = "REDUCTION", json_key = "OP_REDUCE" }, + { name = "BATCH", json_key = "OP_BATCH" }, + { name = "PIPELINE", json_key = "OP_PIPELINE" }, + { name = "FUSED_PARALLEL", json_key = "OP_FUSED_PARALLEL" }, +] diff --git a/lib/substitution-generator/include/substitution-generator/legacy_pm_parameter.dtg.h b/lib/substitution-generator/include/substitution-generator/legacy_pm_parameter.dtg.h new file mode 100644 index 0000000000..2435024b9a --- /dev/null +++ b/lib/substitution-generator/include/substitution-generator/legacy_pm_parameter.dtg.h @@ -0,0 +1,73 @@ +// THIS FILE WAS AUTO-GENERATED 
diff --git a/lib/substitution-generator/include/substitution-generator/legacy_pm_parameter.dtg.h b/lib/substitution-generator/include/substitution-generator/legacy_pm_parameter.dtg.h
new file mode 100644
index 0000000000..2435024b9a
--- /dev/null
+++ b/lib/substitution-generator/include/substitution-generator/legacy_pm_parameter.dtg.h
@@ -0,0 +1,73 @@
+// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT!
+// If you would like to modify this datatype, instead modify
+// lib/substitution-generator/include/substitution-generator/legacy_pm_parameter.enum.toml
+/* proj-data
+{
+  "generated_from": "e8dda0047c91e576878b86df2fec0b6b"
+}
+*/
+
+#ifndef _FLEXFLOW_LIB_SUBSTITUTION_GENERATOR_INCLUDE_SUBSTITUTION_GENERATOR_LEGACY_PM_PARAMETER_DTG_H
+#define _FLEXFLOW_LIB_SUBSTITUTION_GENERATOR_INCLUDE_SUBSTITUTION_GENERATOR_LEGACY_PM_PARAMETER_DTG_H
+
+#include "fmt/format.h"
+#include "nlohmann/json.hpp"
+#include "rapidcheck.h"
+#include <functional>
+#include <ostream>
+#include <string>
+
+namespace FlexFlow {
+enum class LegacyPMParameter {
+  OP_TYPE,
+  NUM_INPUTS,
+  NUM_OUTPUTS,
+  GROUP,
+  KERNEL_H,
+  KERNEL_W,
+  STRIDE_H,
+  STRIDE_W,
+  PADDING_H,
+  PADDING_W,
+  ACTI,
+  NUMDIM,
+  AXIS,
+  PERM,
+  OUTSHUFFLE,
+  MERGE_GCONV_COUNT,
+  AXES,
+  KEEP_DIMS,
+  EPSILON,
+  REPARTITION_DIM,
+  REPARTITION_DEGREE,
+  REPLICATE_DIM,
+  REPLICATE_DEGREE,
+  COMBINE_DIM,
+  COMBINE_DEGREE,
+  REDUCTION_DIM,
+  REDUCTION_DEGREE,
+  SOFTMAX_DIM,
+  NUM_HEADS,
+  PARALLEL_DIM,
+  PARALLEL_DEGREE,
+  PAD
+};
+std::string format_as(LegacyPMParameter);
+std::ostream &operator<<(std::ostream &, LegacyPMParameter);
+void to_json(::nlohmann::json &, LegacyPMParameter);
+void from_json(::nlohmann::json const &, LegacyPMParameter &);
+} // namespace FlexFlow
+namespace std {
+template <>
+struct hash<FlexFlow::LegacyPMParameter> {
+  size_t operator()(FlexFlow::LegacyPMParameter) const;
+};
+} // namespace std
+namespace rc {
+template <>
+struct Arbitrary<FlexFlow::LegacyPMParameter> {
+  static Gen<FlexFlow::LegacyPMParameter> arbitrary();
+};
+} // namespace rc
+
+#endif // _FLEXFLOW_LIB_SUBSTITUTION_GENERATOR_INCLUDE_SUBSTITUTION_GENERATOR_LEGACY_PM_PARAMETER_DTG_H
diff --git a/lib/substitution-generator/include/substitution-generator/legacy_pm_parameter.enum.toml b/lib/substitution-generator/include/substitution-generator/legacy_pm_parameter.enum.toml
new file mode 100644
index 0000000000..e71a71a5a8
--- /dev/null
+++ b/lib/substitution-generator/include/substitution-generator/legacy_pm_parameter.enum.toml
@@ -0,0 +1,44 @@
+namespace = "FlexFlow"
+name = "LegacyPMParameter"
+
+features = [
+  "hash",
+  "json",
+  "rapidcheck",
+  "fmt",
+]
+
+values = [
+  { name = "OP_TYPE", json_key = "PM_OP_TYPE" },
+  { name = "NUM_INPUTS", json_key = "PM_NUM_INPUTS" },
+  { name = "NUM_OUTPUTS", json_key = "PM_NUM_OUTPUTS" },
+  { name = "GROUP", json_key = "PM_GROUP" },
+  { name = "KERNEL_H", json_key = "PM_KERNEL_H" },
+  { name = "KERNEL_W", json_key = "PM_KERNEL_W" },
+  { name = "STRIDE_H", json_key = "PM_STRIDE_H" },
+  { name = "STRIDE_W", json_key = "PM_STRIDE_W" },
+  { name = "PADDING_H", json_key = "PM_PADDING_H" },
+  { name = "PADDING_W", json_key = "PM_PADDING_W" },
+  { name = "ACTI", json_key = "PM_ACTI" },
+  { name = "NUMDIM", json_key = "PM_NUMDIM" },
+  { name = "AXIS", json_key = "PM_AXIS" },
+  { name = "PERM", json_key = "PM_PERM" },
+  { name = "OUTSHUFFLE", json_key = "PM_OUTSHUFFLE" },
+  { name = "MERGE_GCONV_COUNT", json_key = "PM_MERGE_GCONV_COUNT" },
+  { name = "AXES", json_key = "PM_AXES" },
+  { name = "KEEP_DIMS", json_key = "PM_KEEP_DIMS" },
+  { name = "EPSILON", json_key = "PM_EPSILON" },
+  { name = "REPARTITION_DIM", json_key = "PM_REPARTITION_DIM" },
+  { name = "REPARTITION_DEGREE", json_key = "PM_REPARTITION_DEGREE" },
+  { name = "REPLICATE_DIM", json_key = "PM_REPLICATE_DIM" },
+  { name = "REPLICATE_DEGREE", json_key = "PM_REPLICATE_DEGREE" },
+  { name = "COMBINE_DIM", json_key = "PM_COMBINE_DIM" },
+  { name = "COMBINE_DEGREE", json_key = "PM_COMBINE_DEGREE" },
+  { name = "REDUCTION_DIM", json_key = 
"PM_REDUCTION_DIM" }, + { name = "REDUCTION_DEGREE", json_key = "PM_REDUCTION_DEGREE" }, + { name = "SOFTMAX_DIM", json_key = "PM_SOFTMAX_DIM" }, + { name = "NUM_HEADS", json_key = "PM_NUM_HEADS" }, + { name = "PARALLEL_DIM", json_key = "PM_PARALLEL_DIM" }, + { name = "PARALLEL_DEGREE", json_key = "PM_PARALLEL_DEGREE" }, + { name = "PAD", json_key = "PM_PAD" }, +] diff --git a/lib/substitution-generator/src/substitution-generator/json.cc b/lib/substitution-generator/src/substitution-generator/json.cc index 7e6a93b863..940ecb3e36 100644 --- a/lib/substitution-generator/src/substitution-generator/json.cc +++ b/lib/substitution-generator/src/substitution-generator/json.cc @@ -10,11 +10,6 @@ namespace FlexFlow { void from_json(json const &j, Parameter &p) { j.at("key").get_to(p.key); j.at("value").get_to(p.value); - if (p.key == PM_INVALID) { - std::ostringstream oss; - oss << "Attempted to load invalid PMParameter: " << j.at("key"); - throw std::runtime_error(oss.str()); - } } void from_json(json const &j, Tensor &t) { @@ -22,17 +17,17 @@ void from_json(json const &j, Tensor &t) { j.at("tsId").get_to(t.tsId); } -std::optional Operator::at(PMParameter key) const { - std::optional value = std::nullopt; - for (Parameter const &p : this->para) { - if (p.key == key) { - assert(!value.has_value()); - value = p.key; - } - } +/* std::optional Operator::at(LegacyPMParameter key) const { */ +/* std::optional value = std::nullopt; */ +/* for (Parameter const &p : this->para) { */ +/* if (p.key == key) { */ +/* assert(!value.has_value()); */ +/* value = p.key; */ +/* } */ +/* } */ - return value; -} +/* return value; */ +/* } */ void from_json(json const &j, Operator &o) { j.at("type").get_to(o.op_type); diff --git a/lib/substitution-generator/src/substitution-generator/legacy_operator_type.dtg.cc b/lib/substitution-generator/src/substitution-generator/legacy_operator_type.dtg.cc new file mode 100644 index 0000000000..94c65e33fd --- /dev/null +++ b/lib/substitution-generator/src/substitution-generator/legacy_operator_type.dtg.cc @@ -0,0 +1,721 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! 
+// If you would like to modify this datatype, instead modify +// lib/substitution-generator/include/substitution-generator/legacy_operator_type.enum.toml +/* proj-data +{ + "generated_from": "d6ba52e2b0d58b7cb533dae3894b0486" +} +*/ + +#include "substitution-generator/legacy_operator_type.dtg.h" + +#include +#include + +namespace std { +size_t hash::operator()( + FlexFlow::LegacyOperatorType x) const { + return std::hash{}(static_cast(x)); +} +} // namespace std +namespace FlexFlow { +std::string format_as(LegacyOperatorType x) { + switch (x) { + case LegacyOperatorType::NOOP: + return "NOOP"; + case LegacyOperatorType::INPUT: + return "INPUT"; + case LegacyOperatorType::WEIGHT: + return "WEIGHT"; + case LegacyOperatorType::CONV2D: + return "CONV2D"; + case LegacyOperatorType::DROPOUT: + return "DROPOUT"; + case LegacyOperatorType::LINEAR: + return "LINEAR"; + case LegacyOperatorType::BATCHMATMUL: + return "BATCHMATMUL"; + case LegacyOperatorType::POOL2D: + return "POOL2D"; + case LegacyOperatorType::SCALAR_MULTIPLY: + return "SCALAR_MULTIPLY"; + case LegacyOperatorType::SCALAR_ADD: + return "SCALAR_ADD"; + case LegacyOperatorType::SCALAR_FLOOR_DIV: + return "SCALAR_FLOOR_DIV"; + case LegacyOperatorType::SCALAR_TRUE_DIV: + return "SCALAR_TRUE_DIV"; + case LegacyOperatorType::SCALAR_SUB: + return "SCALAR_SUB"; + case LegacyOperatorType::RELU: + return "RELU"; + case LegacyOperatorType::IDENTITY: + return "IDENTITY"; + case LegacyOperatorType::SIGMOID: + return "SIGMOID"; + case LegacyOperatorType::TANH: + return "TANH"; + case LegacyOperatorType::ELU: + return "ELU"; + case LegacyOperatorType::FLAT: + return "FLAT"; + case LegacyOperatorType::SOFTMAX: + return "SOFTMAX"; + case LegacyOperatorType::BATCHNORM: + return "BATCHNORM"; + case LegacyOperatorType::CONCAT: + return "CONCAT"; + case LegacyOperatorType::SPLIT: + return "SPLIT"; + case LegacyOperatorType::EMBEDDING: + return "EMBEDDING"; + case LegacyOperatorType::CACHE: + return "CACHE"; + case LegacyOperatorType::RESHAPE: + return "RESHAPE"; + case LegacyOperatorType::REVERSE: + return "REVERSE"; + case LegacyOperatorType::TRANSPOSE: + return "TRANSPOSE"; + case LegacyOperatorType::EW_ADD: + return "EW_ADD"; + case LegacyOperatorType::EW_MUL: + return "EW_MUL"; + case LegacyOperatorType::MATMUL: + return "MATMUL"; + case LegacyOperatorType::MUL: + return "MUL"; + case LegacyOperatorType::ENLARGE: + return "ENLARGE"; + case LegacyOperatorType::SQUEEZE: + return "SQUEEZE"; + case LegacyOperatorType::UNSQUEEZE: + return "UNSQUEEZE"; + case LegacyOperatorType::EW_SUB: + return "EW_SUB"; + case LegacyOperatorType::EW_DIV: + return "EW_DIV"; + case LegacyOperatorType::EW_EQUAL: + return "EW_EQUAL"; + case LegacyOperatorType::EW_GREATER: + return "EW_GREATER"; + case LegacyOperatorType::EW_LESS: + return "EW_LESS"; + case LegacyOperatorType::EW_MAX: + return "EW_MAX"; + case LegacyOperatorType::EW_MIN: + return "EW_MIN"; + case LegacyOperatorType::REDUCE_ARGMAX: + return "REDUCE_ARGMAX"; + case LegacyOperatorType::REDUCE_ARGMIN: + return "REDUCE_ARGMIN"; + case LegacyOperatorType::REDUCE_MAX: + return "REDUCE_MAX"; + case LegacyOperatorType::REDUCE_MEAN: + return "REDUCE_MEAN"; + case LegacyOperatorType::REDUCE_MIN: + return "REDUCE_MIN"; + case LegacyOperatorType::REDUCE_PROD: + return "REDUCE_PROD"; + case LegacyOperatorType::REDUCE_SUM: + return "REDUCE_SUM"; + case LegacyOperatorType::PAD: + return "PAD"; + case LegacyOperatorType::SHAPE: + return "SHAPE"; + case LegacyOperatorType::SIZE: + return "SIZE"; + case 
LegacyOperatorType::TOPK: + return "TOPK"; + case LegacyOperatorType::WHERE: + return "WHERE"; + case LegacyOperatorType::CEIL: + return "CEIL"; + case LegacyOperatorType::CAST: + return "CAST"; + case LegacyOperatorType::EXP: + return "EXP"; + case LegacyOperatorType::ROUND: + return "ROUND"; + case LegacyOperatorType::LOG: + return "LOG"; + case LegacyOperatorType::LOGICAL_NOT: + return "LOGICAL_NOT"; + case LegacyOperatorType::SQRT: + return "SQRT"; + case LegacyOperatorType::SIN: + return "SIN"; + case LegacyOperatorType::COS: + return "COS"; + case LegacyOperatorType::LEAKYRELU: + return "LEAKYRELU"; + case LegacyOperatorType::SLICE: + return "SLICE"; + case LegacyOperatorType::RESIZE: + return "RESIZE"; + case LegacyOperatorType::PRELU: + return "PRELU"; + case LegacyOperatorType::GELU: + return "GELU"; + case LegacyOperatorType::MULTIHEAD_ATTENTION: + return "MULTIHEAD_ATTENTION"; + case LegacyOperatorType::FUSED: + return "FUSED"; + case LegacyOperatorType::RSQRT: + return "RSQRT"; + case LegacyOperatorType::POW: + return "POW"; + case LegacyOperatorType::MEAN: + return "MEAN"; + case LegacyOperatorType::LAYERNORM: + return "LAYERNORM"; + case LegacyOperatorType::GATHER: + return "GATHER"; + case LegacyOperatorType::BROADCAST: + return "BROADCAST"; + case LegacyOperatorType::REPARTITION: + return "REPARTITION"; + case LegacyOperatorType::COMBINE: + return "COMBINE"; + case LegacyOperatorType::REPLICATE: + return "REPLICATE"; + case LegacyOperatorType::REDUCTION: + return "REDUCTION"; + case LegacyOperatorType::BATCH: + return "BATCH"; + case LegacyOperatorType::PIPELINE: + return "PIPELINE"; + case LegacyOperatorType::FUSED_PARALLEL: + return "FUSED_PARALLEL"; + default: + std::ostringstream oss; + oss << "Unknown LegacyOperatorType value " << static_cast(x); + throw std::runtime_error(oss.str()); + } +} +std::ostream &operator<<(std::ostream &s, LegacyOperatorType x) { + return s << fmt::to_string(x); +} +} // namespace FlexFlow +namespace FlexFlow { +void to_json(::nlohmann::json &j, LegacyOperatorType x) { + switch (x) { + case LegacyOperatorType::NOOP: + j = "OP_NOOP"; + break; + case LegacyOperatorType::INPUT: + j = "OP_INPUT"; + break; + case LegacyOperatorType::WEIGHT: + j = "OP_WEIGHT"; + break; + case LegacyOperatorType::CONV2D: + j = "OP_CONV2D"; + break; + case LegacyOperatorType::DROPOUT: + j = "OP_DROPOUT"; + break; + case LegacyOperatorType::LINEAR: + j = "OP_LINEAR"; + break; + case LegacyOperatorType::BATCHMATMUL: + j = "OP_BATCHMATMUL"; + break; + case LegacyOperatorType::POOL2D: + j = "OP_POOL2D"; + break; + case LegacyOperatorType::SCALAR_MULTIPLY: + j = "OP_SCALAR_MULTIPLY"; + break; + case LegacyOperatorType::SCALAR_ADD: + j = "OP_SCALAR_ADD"; + break; + case LegacyOperatorType::SCALAR_FLOOR_DIV: + j = "OP_SCALAR_FLOOR_DIV"; + break; + case LegacyOperatorType::SCALAR_TRUE_DIV: + j = "OP_SCALAR_TRUE_DIV"; + break; + case LegacyOperatorType::SCALAR_SUB: + j = "OP_SCALAR_SUB"; + break; + case LegacyOperatorType::RELU: + j = "OP_RELU"; + break; + case LegacyOperatorType::IDENTITY: + j = "OP_IDENTITY"; + break; + case LegacyOperatorType::SIGMOID: + j = "OP_SIGMOID"; + break; + case LegacyOperatorType::TANH: + j = "OP_TANH"; + break; + case LegacyOperatorType::ELU: + j = "OP_ELU"; + break; + case LegacyOperatorType::FLAT: + j = "OP_FLAT"; + break; + case LegacyOperatorType::SOFTMAX: + j = "OP_SOFTMAX"; + break; + case LegacyOperatorType::BATCHNORM: + j = "OP_BATCHNORM"; + break; + case LegacyOperatorType::CONCAT: + j = "OP_CONCAT"; + break; + case 
LegacyOperatorType::SPLIT: + j = "OP_SPLIT"; + break; + case LegacyOperatorType::EMBEDDING: + j = "OP_EMBEDDING"; + break; + case LegacyOperatorType::CACHE: + j = "OP_CACHE"; + break; + case LegacyOperatorType::RESHAPE: + j = "OP_RESHAPE"; + break; + case LegacyOperatorType::REVERSE: + j = "OP_REVERSE"; + break; + case LegacyOperatorType::TRANSPOSE: + j = "OP_TRANSPOSE"; + break; + case LegacyOperatorType::EW_ADD: + j = "OP_EW_ADD"; + break; + case LegacyOperatorType::EW_MUL: + j = "OP_EW_MUL"; + break; + case LegacyOperatorType::MATMUL: + j = "OP_MATMUL"; + break; + case LegacyOperatorType::MUL: + j = "OP_MUL"; + break; + case LegacyOperatorType::ENLARGE: + j = "OP_ENLARGE"; + break; + case LegacyOperatorType::SQUEEZE: + j = "OP_SQUEEZE"; + break; + case LegacyOperatorType::UNSQUEEZE: + j = "OP_UNSQUEEZE"; + break; + case LegacyOperatorType::EW_SUB: + j = "OP_EW_SUB"; + break; + case LegacyOperatorType::EW_DIV: + j = "OP_EW_DIV"; + break; + case LegacyOperatorType::EW_EQUAL: + j = "OP_EW_EQUAL"; + break; + case LegacyOperatorType::EW_GREATER: + j = "OP_EW_GREATER"; + break; + case LegacyOperatorType::EW_LESS: + j = "OP_EW_LESS"; + break; + case LegacyOperatorType::EW_MAX: + j = "OP_EW_MAX"; + break; + case LegacyOperatorType::EW_MIN: + j = "OP_EW_MIN"; + break; + case LegacyOperatorType::REDUCE_ARGMAX: + j = "OP_REDUCE_ARGMAX"; + break; + case LegacyOperatorType::REDUCE_ARGMIN: + j = "OP_REDUCE_ARGMIN"; + break; + case LegacyOperatorType::REDUCE_MAX: + j = "OP_REDUCE_MAX"; + break; + case LegacyOperatorType::REDUCE_MEAN: + j = "OP_REDUCE_MEAN"; + break; + case LegacyOperatorType::REDUCE_MIN: + j = "OP_REDUCE_MIN"; + break; + case LegacyOperatorType::REDUCE_PROD: + j = "OP_REDUCE_PROD"; + break; + case LegacyOperatorType::REDUCE_SUM: + j = "OP_REDUCE_SUM"; + break; + case LegacyOperatorType::PAD: + j = "OP_PAD"; + break; + case LegacyOperatorType::SHAPE: + j = "OP_SHAPE"; + break; + case LegacyOperatorType::SIZE: + j = "OP_SIZE"; + break; + case LegacyOperatorType::TOPK: + j = "OP_TOPK"; + break; + case LegacyOperatorType::WHERE: + j = "OP_WHERE"; + break; + case LegacyOperatorType::CEIL: + j = "OP_CEIL"; + break; + case LegacyOperatorType::CAST: + j = "OP_CAST"; + break; + case LegacyOperatorType::EXP: + j = "OP_EXP"; + break; + case LegacyOperatorType::ROUND: + j = "OP_ROUND"; + break; + case LegacyOperatorType::LOG: + j = "OP_LOG"; + break; + case LegacyOperatorType::LOGICAL_NOT: + j = "OP_LOGICAL_NOT"; + break; + case LegacyOperatorType::SQRT: + j = "OP_SQRT"; + break; + case LegacyOperatorType::SIN: + j = "OP_SIN"; + break; + case LegacyOperatorType::COS: + j = "OP_COS"; + break; + case LegacyOperatorType::LEAKYRELU: + j = "OP_LEAKYRELU"; + break; + case LegacyOperatorType::SLICE: + j = "OP_SLICE"; + break; + case LegacyOperatorType::RESIZE: + j = "OP_RESIZE"; + break; + case LegacyOperatorType::PRELU: + j = "OP_PRELU"; + break; + case LegacyOperatorType::GELU: + j = "OP_GELU"; + break; + case LegacyOperatorType::MULTIHEAD_ATTENTION: + j = "OP_MULTIHEAD_ATTENTION"; + break; + case LegacyOperatorType::FUSED: + j = "OP_FUSED"; + break; + case LegacyOperatorType::RSQRT: + j = "OP_RSQRT"; + break; + case LegacyOperatorType::POW: + j = "OP_POW"; + break; + case LegacyOperatorType::MEAN: + j = "OP_MEAN"; + break; + case LegacyOperatorType::LAYERNORM: + j = "OP_LAYERNORM"; + break; + case LegacyOperatorType::GATHER: + j = "OP_GATHER"; + break; + case LegacyOperatorType::BROADCAST: + j = "OP_BROADCAST"; + break; + case LegacyOperatorType::REPARTITION: + j = "OP_PARTITION"; + break; + case 
LegacyOperatorType::COMBINE: + j = "OP_COMBINE"; + break; + case LegacyOperatorType::REPLICATE: + j = "OP_REPLICATE"; + break; + case LegacyOperatorType::REDUCTION: + j = "OP_REDUCE"; + break; + case LegacyOperatorType::BATCH: + j = "OP_BATCH"; + break; + case LegacyOperatorType::PIPELINE: + j = "OP_PIPELINE"; + break; + case LegacyOperatorType::FUSED_PARALLEL: + j = "OP_FUSED_PARALLEL"; + break; + default: + std::ostringstream oss; + oss << "Unknown LegacyOperatorType value " << static_cast(x); + throw std::runtime_error(oss.str()); + } +} +void from_json(::nlohmann::json const &j, LegacyOperatorType &x) { + std::string as_str = j.get(); + if (as_str == "OP_NOOP") { + x = LegacyOperatorType::NOOP; + } else if (as_str == "OP_INPUT") { + x = LegacyOperatorType::INPUT; + } else if (as_str == "OP_WEIGHT") { + x = LegacyOperatorType::WEIGHT; + } else if (as_str == "OP_CONV2D") { + x = LegacyOperatorType::CONV2D; + } else if (as_str == "OP_DROPOUT") { + x = LegacyOperatorType::DROPOUT; + } else if (as_str == "OP_LINEAR") { + x = LegacyOperatorType::LINEAR; + } else if (as_str == "OP_BATCHMATMUL") { + x = LegacyOperatorType::BATCHMATMUL; + } else if (as_str == "OP_POOL2D") { + x = LegacyOperatorType::POOL2D; + } else if (as_str == "OP_SCALAR_MULTIPLY") { + x = LegacyOperatorType::SCALAR_MULTIPLY; + } else if (as_str == "OP_SCALAR_ADD") { + x = LegacyOperatorType::SCALAR_ADD; + } else if (as_str == "OP_SCALAR_FLOOR_DIV") { + x = LegacyOperatorType::SCALAR_FLOOR_DIV; + } else if (as_str == "OP_SCALAR_TRUE_DIV") { + x = LegacyOperatorType::SCALAR_TRUE_DIV; + } else if (as_str == "OP_SCALAR_SUB") { + x = LegacyOperatorType::SCALAR_SUB; + } else if (as_str == "OP_RELU") { + x = LegacyOperatorType::RELU; + } else if (as_str == "OP_IDENTITY") { + x = LegacyOperatorType::IDENTITY; + } else if (as_str == "OP_SIGMOID") { + x = LegacyOperatorType::SIGMOID; + } else if (as_str == "OP_TANH") { + x = LegacyOperatorType::TANH; + } else if (as_str == "OP_ELU") { + x = LegacyOperatorType::ELU; + } else if (as_str == "OP_FLAT") { + x = LegacyOperatorType::FLAT; + } else if (as_str == "OP_SOFTMAX") { + x = LegacyOperatorType::SOFTMAX; + } else if (as_str == "OP_BATCHNORM") { + x = LegacyOperatorType::BATCHNORM; + } else if (as_str == "OP_CONCAT") { + x = LegacyOperatorType::CONCAT; + } else if (as_str == "OP_SPLIT") { + x = LegacyOperatorType::SPLIT; + } else if (as_str == "OP_EMBEDDING") { + x = LegacyOperatorType::EMBEDDING; + } else if (as_str == "OP_CACHE") { + x = LegacyOperatorType::CACHE; + } else if (as_str == "OP_RESHAPE") { + x = LegacyOperatorType::RESHAPE; + } else if (as_str == "OP_REVERSE") { + x = LegacyOperatorType::REVERSE; + } else if (as_str == "OP_TRANSPOSE") { + x = LegacyOperatorType::TRANSPOSE; + } else if (as_str == "OP_EW_ADD") { + x = LegacyOperatorType::EW_ADD; + } else if (as_str == "OP_EW_MUL") { + x = LegacyOperatorType::EW_MUL; + } else if (as_str == "OP_MATMUL") { + x = LegacyOperatorType::MATMUL; + } else if (as_str == "OP_MUL") { + x = LegacyOperatorType::MUL; + } else if (as_str == "OP_ENLARGE") { + x = LegacyOperatorType::ENLARGE; + } else if (as_str == "OP_SQUEEZE") { + x = LegacyOperatorType::SQUEEZE; + } else if (as_str == "OP_UNSQUEEZE") { + x = LegacyOperatorType::UNSQUEEZE; + } else if (as_str == "OP_EW_SUB") { + x = LegacyOperatorType::EW_SUB; + } else if (as_str == "OP_EW_DIV") { + x = LegacyOperatorType::EW_DIV; + } else if (as_str == "OP_EW_EQUAL") { + x = LegacyOperatorType::EW_EQUAL; + } else if (as_str == "OP_EW_GREATER") { + x = LegacyOperatorType::EW_GREATER; + } 
else if (as_str == "OP_EW_LESS") { + x = LegacyOperatorType::EW_LESS; + } else if (as_str == "OP_EW_MAX") { + x = LegacyOperatorType::EW_MAX; + } else if (as_str == "OP_EW_MIN") { + x = LegacyOperatorType::EW_MIN; + } else if (as_str == "OP_REDUCE_ARGMAX") { + x = LegacyOperatorType::REDUCE_ARGMAX; + } else if (as_str == "OP_REDUCE_ARGMIN") { + x = LegacyOperatorType::REDUCE_ARGMIN; + } else if (as_str == "OP_REDUCE_MAX") { + x = LegacyOperatorType::REDUCE_MAX; + } else if (as_str == "OP_REDUCE_MEAN") { + x = LegacyOperatorType::REDUCE_MEAN; + } else if (as_str == "OP_REDUCE_MIN") { + x = LegacyOperatorType::REDUCE_MIN; + } else if (as_str == "OP_REDUCE_PROD") { + x = LegacyOperatorType::REDUCE_PROD; + } else if (as_str == "OP_REDUCE_SUM") { + x = LegacyOperatorType::REDUCE_SUM; + } else if (as_str == "OP_PAD") { + x = LegacyOperatorType::PAD; + } else if (as_str == "OP_SHAPE") { + x = LegacyOperatorType::SHAPE; + } else if (as_str == "OP_SIZE") { + x = LegacyOperatorType::SIZE; + } else if (as_str == "OP_TOPK") { + x = LegacyOperatorType::TOPK; + } else if (as_str == "OP_WHERE") { + x = LegacyOperatorType::WHERE; + } else if (as_str == "OP_CEIL") { + x = LegacyOperatorType::CEIL; + } else if (as_str == "OP_CAST") { + x = LegacyOperatorType::CAST; + } else if (as_str == "OP_EXP") { + x = LegacyOperatorType::EXP; + } else if (as_str == "OP_ROUND") { + x = LegacyOperatorType::ROUND; + } else if (as_str == "OP_LOG") { + x = LegacyOperatorType::LOG; + } else if (as_str == "OP_LOGICAL_NOT") { + x = LegacyOperatorType::LOGICAL_NOT; + } else if (as_str == "OP_SQRT") { + x = LegacyOperatorType::SQRT; + } else if (as_str == "OP_SIN") { + x = LegacyOperatorType::SIN; + } else if (as_str == "OP_COS") { + x = LegacyOperatorType::COS; + } else if (as_str == "OP_LEAKYRELU") { + x = LegacyOperatorType::LEAKYRELU; + } else if (as_str == "OP_SLICE") { + x = LegacyOperatorType::SLICE; + } else if (as_str == "OP_RESIZE") { + x = LegacyOperatorType::RESIZE; + } else if (as_str == "OP_PRELU") { + x = LegacyOperatorType::PRELU; + } else if (as_str == "OP_GELU") { + x = LegacyOperatorType::GELU; + } else if (as_str == "OP_MULTIHEAD_ATTENTION") { + x = LegacyOperatorType::MULTIHEAD_ATTENTION; + } else if (as_str == "OP_FUSED") { + x = LegacyOperatorType::FUSED; + } else if (as_str == "OP_RSQRT") { + x = LegacyOperatorType::RSQRT; + } else if (as_str == "OP_POW") { + x = LegacyOperatorType::POW; + } else if (as_str == "OP_MEAN") { + x = LegacyOperatorType::MEAN; + } else if (as_str == "OP_LAYERNORM") { + x = LegacyOperatorType::LAYERNORM; + } else if (as_str == "OP_GATHER") { + x = LegacyOperatorType::GATHER; + } else if (as_str == "OP_BROADCAST") { + x = LegacyOperatorType::BROADCAST; + } else if (as_str == "OP_PARTITION") { + x = LegacyOperatorType::REPARTITION; + } else if (as_str == "OP_COMBINE") { + x = LegacyOperatorType::COMBINE; + } else if (as_str == "OP_REPLICATE") { + x = LegacyOperatorType::REPLICATE; + } else if (as_str == "OP_REDUCE") { + x = LegacyOperatorType::REDUCTION; + } else if (as_str == "OP_BATCH") { + x = LegacyOperatorType::BATCH; + } else if (as_str == "OP_PIPELINE") { + x = LegacyOperatorType::PIPELINE; + } else if (as_str == "OP_FUSED_PARALLEL") { + x = LegacyOperatorType::FUSED_PARALLEL; + } else { + std::ostringstream oss; + oss << "Unknown LegacyOperatorType value " << as_str; + throw std::runtime_error(oss.str()); + } +} +} // namespace FlexFlow +namespace rc { +Gen + Arbitrary::arbitrary() { + return gen::element( + FlexFlow::LegacyOperatorType::NOOP, + 
FlexFlow::LegacyOperatorType::INPUT, + FlexFlow::LegacyOperatorType::WEIGHT, + FlexFlow::LegacyOperatorType::CONV2D, + FlexFlow::LegacyOperatorType::DROPOUT, + FlexFlow::LegacyOperatorType::LINEAR, + FlexFlow::LegacyOperatorType::BATCHMATMUL, + FlexFlow::LegacyOperatorType::POOL2D, + FlexFlow::LegacyOperatorType::SCALAR_MULTIPLY, + FlexFlow::LegacyOperatorType::SCALAR_ADD, + FlexFlow::LegacyOperatorType::SCALAR_FLOOR_DIV, + FlexFlow::LegacyOperatorType::SCALAR_TRUE_DIV, + FlexFlow::LegacyOperatorType::SCALAR_SUB, + FlexFlow::LegacyOperatorType::RELU, + FlexFlow::LegacyOperatorType::IDENTITY, + FlexFlow::LegacyOperatorType::SIGMOID, + FlexFlow::LegacyOperatorType::TANH, + FlexFlow::LegacyOperatorType::ELU, + FlexFlow::LegacyOperatorType::FLAT, + FlexFlow::LegacyOperatorType::SOFTMAX, + FlexFlow::LegacyOperatorType::BATCHNORM, + FlexFlow::LegacyOperatorType::CONCAT, + FlexFlow::LegacyOperatorType::SPLIT, + FlexFlow::LegacyOperatorType::EMBEDDING, + FlexFlow::LegacyOperatorType::CACHE, + FlexFlow::LegacyOperatorType::RESHAPE, + FlexFlow::LegacyOperatorType::REVERSE, + FlexFlow::LegacyOperatorType::TRANSPOSE, + FlexFlow::LegacyOperatorType::EW_ADD, + FlexFlow::LegacyOperatorType::EW_MUL, + FlexFlow::LegacyOperatorType::MATMUL, + FlexFlow::LegacyOperatorType::MUL, + FlexFlow::LegacyOperatorType::ENLARGE, + FlexFlow::LegacyOperatorType::SQUEEZE, + FlexFlow::LegacyOperatorType::UNSQUEEZE, + FlexFlow::LegacyOperatorType::EW_SUB, + FlexFlow::LegacyOperatorType::EW_DIV, + FlexFlow::LegacyOperatorType::EW_EQUAL, + FlexFlow::LegacyOperatorType::EW_GREATER, + FlexFlow::LegacyOperatorType::EW_LESS, + FlexFlow::LegacyOperatorType::EW_MAX, + FlexFlow::LegacyOperatorType::EW_MIN, + FlexFlow::LegacyOperatorType::REDUCE_ARGMAX, + FlexFlow::LegacyOperatorType::REDUCE_ARGMIN, + FlexFlow::LegacyOperatorType::REDUCE_MAX, + FlexFlow::LegacyOperatorType::REDUCE_MEAN, + FlexFlow::LegacyOperatorType::REDUCE_MIN, + FlexFlow::LegacyOperatorType::REDUCE_PROD, + FlexFlow::LegacyOperatorType::REDUCE_SUM, + FlexFlow::LegacyOperatorType::PAD, + FlexFlow::LegacyOperatorType::SHAPE, + FlexFlow::LegacyOperatorType::SIZE, + FlexFlow::LegacyOperatorType::TOPK, + FlexFlow::LegacyOperatorType::WHERE, + FlexFlow::LegacyOperatorType::CEIL, + FlexFlow::LegacyOperatorType::CAST, + FlexFlow::LegacyOperatorType::EXP, + FlexFlow::LegacyOperatorType::ROUND, + FlexFlow::LegacyOperatorType::LOG, + FlexFlow::LegacyOperatorType::LOGICAL_NOT, + FlexFlow::LegacyOperatorType::SQRT, + FlexFlow::LegacyOperatorType::SIN, + FlexFlow::LegacyOperatorType::COS, + FlexFlow::LegacyOperatorType::LEAKYRELU, + FlexFlow::LegacyOperatorType::SLICE, + FlexFlow::LegacyOperatorType::RESIZE, + FlexFlow::LegacyOperatorType::PRELU, + FlexFlow::LegacyOperatorType::GELU, + FlexFlow::LegacyOperatorType::MULTIHEAD_ATTENTION, + FlexFlow::LegacyOperatorType::FUSED, + FlexFlow::LegacyOperatorType::RSQRT, + FlexFlow::LegacyOperatorType::POW, + FlexFlow::LegacyOperatorType::MEAN, + FlexFlow::LegacyOperatorType::LAYERNORM, + FlexFlow::LegacyOperatorType::GATHER, + FlexFlow::LegacyOperatorType::BROADCAST, + FlexFlow::LegacyOperatorType::REPARTITION, + FlexFlow::LegacyOperatorType::COMBINE, + FlexFlow::LegacyOperatorType::REPLICATE, + FlexFlow::LegacyOperatorType::REDUCTION, + FlexFlow::LegacyOperatorType::BATCH, + FlexFlow::LegacyOperatorType::PIPELINE, + FlexFlow::LegacyOperatorType::FUSED_PARALLEL); +} +} // namespace rc diff --git a/lib/substitution-generator/src/substitution-generator/legacy_pm_parameter.dtg.cc 
b/lib/substitution-generator/src/substitution-generator/legacy_pm_parameter.dtg.cc new file mode 100644 index 0000000000..c8df4ccd7d --- /dev/null +++ b/lib/substitution-generator/src/substitution-generator/legacy_pm_parameter.dtg.cc @@ -0,0 +1,313 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/substitution-generator/include/substitution-generator/legacy_pm_parameter.enum.toml +/* proj-data +{ + "generated_from": "e8dda0047c91e576878b86df2fec0b6b" +} +*/ + +#include "substitution-generator/legacy_pm_parameter.dtg.h" + +#include +#include + +namespace std { +size_t hash::operator()( + FlexFlow::LegacyPMParameter x) const { + return std::hash{}(static_cast(x)); +} +} // namespace std +namespace FlexFlow { +std::string format_as(LegacyPMParameter x) { + switch (x) { + case LegacyPMParameter::OP_TYPE: + return "OP_TYPE"; + case LegacyPMParameter::NUM_INPUTS: + return "NUM_INPUTS"; + case LegacyPMParameter::NUM_OUTPUTS: + return "NUM_OUTPUTS"; + case LegacyPMParameter::GROUP: + return "GROUP"; + case LegacyPMParameter::KERNEL_H: + return "KERNEL_H"; + case LegacyPMParameter::KERNEL_W: + return "KERNEL_W"; + case LegacyPMParameter::STRIDE_H: + return "STRIDE_H"; + case LegacyPMParameter::STRIDE_W: + return "STRIDE_W"; + case LegacyPMParameter::PADDING_H: + return "PADDING_H"; + case LegacyPMParameter::PADDING_W: + return "PADDING_W"; + case LegacyPMParameter::ACTI: + return "ACTI"; + case LegacyPMParameter::NUMDIM: + return "NUMDIM"; + case LegacyPMParameter::AXIS: + return "AXIS"; + case LegacyPMParameter::PERM: + return "PERM"; + case LegacyPMParameter::OUTSHUFFLE: + return "OUTSHUFFLE"; + case LegacyPMParameter::MERGE_GCONV_COUNT: + return "MERGE_GCONV_COUNT"; + case LegacyPMParameter::AXES: + return "AXES"; + case LegacyPMParameter::KEEP_DIMS: + return "KEEP_DIMS"; + case LegacyPMParameter::EPSILON: + return "EPSILON"; + case LegacyPMParameter::REPARTITION_DIM: + return "REPARTITION_DIM"; + case LegacyPMParameter::REPARTITION_DEGREE: + return "REPARTITION_DEGREE"; + case LegacyPMParameter::REPLICATE_DIM: + return "REPLICATE_DIM"; + case LegacyPMParameter::REPLICATE_DEGREE: + return "REPLICATE_DEGREE"; + case LegacyPMParameter::COMBINE_DIM: + return "COMBINE_DIM"; + case LegacyPMParameter::COMBINE_DEGREE: + return "COMBINE_DEGREE"; + case LegacyPMParameter::REDUCTION_DIM: + return "REDUCTION_DIM"; + case LegacyPMParameter::REDUCTION_DEGREE: + return "REDUCTION_DEGREE"; + case LegacyPMParameter::SOFTMAX_DIM: + return "SOFTMAX_DIM"; + case LegacyPMParameter::NUM_HEADS: + return "NUM_HEADS"; + case LegacyPMParameter::PARALLEL_DIM: + return "PARALLEL_DIM"; + case LegacyPMParameter::PARALLEL_DEGREE: + return "PARALLEL_DEGREE"; + case LegacyPMParameter::PAD: + return "PAD"; + default: + std::ostringstream oss; + oss << "Unknown LegacyPMParameter value " << static_cast(x); + throw std::runtime_error(oss.str()); + } +} +std::ostream &operator<<(std::ostream &s, LegacyPMParameter x) { + return s << fmt::to_string(x); +} +} // namespace FlexFlow +namespace FlexFlow { +void to_json(::nlohmann::json &j, LegacyPMParameter x) { + switch (x) { + case LegacyPMParameter::OP_TYPE: + j = "PM_OP_TYPE"; + break; + case LegacyPMParameter::NUM_INPUTS: + j = "PM_NUM_INPUTS"; + break; + case LegacyPMParameter::NUM_OUTPUTS: + j = "PM_NUM_OUTPUTS"; + break; + case LegacyPMParameter::GROUP: + j = "PM_GROUP"; + break; + case LegacyPMParameter::KERNEL_H: + j = "PM_KERNEL_H"; + break; + case LegacyPMParameter::KERNEL_W: + j = 
"PM_KERNEL_W"; + break; + case LegacyPMParameter::STRIDE_H: + j = "PM_STRIDE_H"; + break; + case LegacyPMParameter::STRIDE_W: + j = "PM_STRIDE_W"; + break; + case LegacyPMParameter::PADDING_H: + j = "PM_PADDING_H"; + break; + case LegacyPMParameter::PADDING_W: + j = "PM_PADDING_W"; + break; + case LegacyPMParameter::ACTI: + j = "PM_ACTI"; + break; + case LegacyPMParameter::NUMDIM: + j = "PM_NUMDIM"; + break; + case LegacyPMParameter::AXIS: + j = "PM_AXIS"; + break; + case LegacyPMParameter::PERM: + j = "PM_PERM"; + break; + case LegacyPMParameter::OUTSHUFFLE: + j = "PM_OUTSHUFFLE"; + break; + case LegacyPMParameter::MERGE_GCONV_COUNT: + j = "PM_MERGE_GCONV_COUNT"; + break; + case LegacyPMParameter::AXES: + j = "PM_AXES"; + break; + case LegacyPMParameter::KEEP_DIMS: + j = "PM_KEEP_DIMS"; + break; + case LegacyPMParameter::EPSILON: + j = "PM_EPSILON"; + break; + case LegacyPMParameter::REPARTITION_DIM: + j = "PM_REPARTITION_DIM"; + break; + case LegacyPMParameter::REPARTITION_DEGREE: + j = "PM_REPARTITION_DEGREE"; + break; + case LegacyPMParameter::REPLICATE_DIM: + j = "PM_REPLICATE_DIM"; + break; + case LegacyPMParameter::REPLICATE_DEGREE: + j = "PM_REPLICATE_DEGREE"; + break; + case LegacyPMParameter::COMBINE_DIM: + j = "PM_COMBINE_DIM"; + break; + case LegacyPMParameter::COMBINE_DEGREE: + j = "PM_COMBINE_DEGREE"; + break; + case LegacyPMParameter::REDUCTION_DIM: + j = "PM_REDUCTION_DIM"; + break; + case LegacyPMParameter::REDUCTION_DEGREE: + j = "PM_REDUCTION_DEGREE"; + break; + case LegacyPMParameter::SOFTMAX_DIM: + j = "PM_SOFTMAX_DIM"; + break; + case LegacyPMParameter::NUM_HEADS: + j = "PM_NUM_HEADS"; + break; + case LegacyPMParameter::PARALLEL_DIM: + j = "PM_PARALLEL_DIM"; + break; + case LegacyPMParameter::PARALLEL_DEGREE: + j = "PM_PARALLEL_DEGREE"; + break; + case LegacyPMParameter::PAD: + j = "PM_PAD"; + break; + default: + std::ostringstream oss; + oss << "Unknown LegacyPMParameter value " << static_cast(x); + throw std::runtime_error(oss.str()); + } +} +void from_json(::nlohmann::json const &j, LegacyPMParameter &x) { + std::string as_str = j.get(); + if (as_str == "PM_OP_TYPE") { + x = LegacyPMParameter::OP_TYPE; + } else if (as_str == "PM_NUM_INPUTS") { + x = LegacyPMParameter::NUM_INPUTS; + } else if (as_str == "PM_NUM_OUTPUTS") { + x = LegacyPMParameter::NUM_OUTPUTS; + } else if (as_str == "PM_GROUP") { + x = LegacyPMParameter::GROUP; + } else if (as_str == "PM_KERNEL_H") { + x = LegacyPMParameter::KERNEL_H; + } else if (as_str == "PM_KERNEL_W") { + x = LegacyPMParameter::KERNEL_W; + } else if (as_str == "PM_STRIDE_H") { + x = LegacyPMParameter::STRIDE_H; + } else if (as_str == "PM_STRIDE_W") { + x = LegacyPMParameter::STRIDE_W; + } else if (as_str == "PM_PADDING_H") { + x = LegacyPMParameter::PADDING_H; + } else if (as_str == "PM_PADDING_W") { + x = LegacyPMParameter::PADDING_W; + } else if (as_str == "PM_ACTI") { + x = LegacyPMParameter::ACTI; + } else if (as_str == "PM_NUMDIM") { + x = LegacyPMParameter::NUMDIM; + } else if (as_str == "PM_AXIS") { + x = LegacyPMParameter::AXIS; + } else if (as_str == "PM_PERM") { + x = LegacyPMParameter::PERM; + } else if (as_str == "PM_OUTSHUFFLE") { + x = LegacyPMParameter::OUTSHUFFLE; + } else if (as_str == "PM_MERGE_GCONV_COUNT") { + x = LegacyPMParameter::MERGE_GCONV_COUNT; + } else if (as_str == "PM_AXES") { + x = LegacyPMParameter::AXES; + } else if (as_str == "PM_KEEP_DIMS") { + x = LegacyPMParameter::KEEP_DIMS; + } else if (as_str == "PM_EPSILON") { + x = LegacyPMParameter::EPSILON; + } else if (as_str == "PM_REPARTITION_DIM") 
{ + x = LegacyPMParameter::REPARTITION_DIM; + } else if (as_str == "PM_REPARTITION_DEGREE") { + x = LegacyPMParameter::REPARTITION_DEGREE; + } else if (as_str == "PM_REPLICATE_DIM") { + x = LegacyPMParameter::REPLICATE_DIM; + } else if (as_str == "PM_REPLICATE_DEGREE") { + x = LegacyPMParameter::REPLICATE_DEGREE; + } else if (as_str == "PM_COMBINE_DIM") { + x = LegacyPMParameter::COMBINE_DIM; + } else if (as_str == "PM_COMBINE_DEGREE") { + x = LegacyPMParameter::COMBINE_DEGREE; + } else if (as_str == "PM_REDUCTION_DIM") { + x = LegacyPMParameter::REDUCTION_DIM; + } else if (as_str == "PM_REDUCTION_DEGREE") { + x = LegacyPMParameter::REDUCTION_DEGREE; + } else if (as_str == "PM_SOFTMAX_DIM") { + x = LegacyPMParameter::SOFTMAX_DIM; + } else if (as_str == "PM_NUM_HEADS") { + x = LegacyPMParameter::NUM_HEADS; + } else if (as_str == "PM_PARALLEL_DIM") { + x = LegacyPMParameter::PARALLEL_DIM; + } else if (as_str == "PM_PARALLEL_DEGREE") { + x = LegacyPMParameter::PARALLEL_DEGREE; + } else if (as_str == "PM_PAD") { + x = LegacyPMParameter::PAD; + } else { + std::ostringstream oss; + oss << "Unknown LegacyPMParameter value " << as_str; + throw std::runtime_error(oss.str()); + } +} +} // namespace FlexFlow +namespace rc { +Gen + Arbitrary::arbitrary() { + return gen::element( + FlexFlow::LegacyPMParameter::OP_TYPE, + FlexFlow::LegacyPMParameter::NUM_INPUTS, + FlexFlow::LegacyPMParameter::NUM_OUTPUTS, + FlexFlow::LegacyPMParameter::GROUP, + FlexFlow::LegacyPMParameter::KERNEL_H, + FlexFlow::LegacyPMParameter::KERNEL_W, + FlexFlow::LegacyPMParameter::STRIDE_H, + FlexFlow::LegacyPMParameter::STRIDE_W, + FlexFlow::LegacyPMParameter::PADDING_H, + FlexFlow::LegacyPMParameter::PADDING_W, + FlexFlow::LegacyPMParameter::ACTI, + FlexFlow::LegacyPMParameter::NUMDIM, + FlexFlow::LegacyPMParameter::AXIS, + FlexFlow::LegacyPMParameter::PERM, + FlexFlow::LegacyPMParameter::OUTSHUFFLE, + FlexFlow::LegacyPMParameter::MERGE_GCONV_COUNT, + FlexFlow::LegacyPMParameter::AXES, + FlexFlow::LegacyPMParameter::KEEP_DIMS, + FlexFlow::LegacyPMParameter::EPSILON, + FlexFlow::LegacyPMParameter::REPARTITION_DIM, + FlexFlow::LegacyPMParameter::REPARTITION_DEGREE, + FlexFlow::LegacyPMParameter::REPLICATE_DIM, + FlexFlow::LegacyPMParameter::REPLICATE_DEGREE, + FlexFlow::LegacyPMParameter::COMBINE_DIM, + FlexFlow::LegacyPMParameter::COMBINE_DEGREE, + FlexFlow::LegacyPMParameter::REDUCTION_DIM, + FlexFlow::LegacyPMParameter::REDUCTION_DEGREE, + FlexFlow::LegacyPMParameter::SOFTMAX_DIM, + FlexFlow::LegacyPMParameter::NUM_HEADS, + FlexFlow::LegacyPMParameter::PARALLEL_DIM, + FlexFlow::LegacyPMParameter::PARALLEL_DEGREE, + FlexFlow::LegacyPMParameter::PAD); +} +} // namespace rc diff --git a/lib/substitution-generator/test/substitution-generator/json.cc b/lib/substitution-generator/test/substitution-generator/json.cc index d12b294a2e..befdaf1308 100644 --- a/lib/substitution-generator/test/substitution-generator/json.cc +++ b/lib/substitution-generator/test/substitution-generator/json.cc @@ -18,7 +18,7 @@ TEST_SUITE(FF_TEST_SUITE) { Operator o; from_json(j, o); - CHECK(o.op_type == Op::EW_ADD); + CHECK(o.op_type == LegacyOperatorType::EW_ADD); CHECK(o.input.size() == 2); CHECK(o.input[0].opId == -2); CHECK(o.input[0].tsId == 0); diff --git a/lib/substitutions/include/substitutions/attribute_expr.h b/lib/substitutions/include/substitutions/attribute_expr.h deleted file mode 100644 index 0afd48b431..0000000000 --- a/lib/substitutions/include/substitutions/attribute_expr.h +++ /dev/null @@ -1,40 +0,0 @@ -#ifndef 
_FLEXFLOW_SUBSTITUTIONS_CONSTRAINT_H -#define _FLEXFLOW_SUBSTITUTIONS_CONSTRAINT_H - -#include "utils/variant.h" - -namespace FlexFlow { - -enum class ConstraintType { EQUAL }; - -template -struct ListIndexAccess { - T attribute_key; - req index; -}; - -template -struct ListSize { - req attribute_key; -}; - -template -using AttributeExpr = std::variant, ListSize>; - -template -struct AttributeConstraint { - ConstraintType constraint_type; - AttributeExpr attribute_expr; - V attribute_value; -}; - -template -struct AttributePattern { - std::vector> attribute_constraints; - // TODO: Revert to unordered_set once we have visitable for templates - // std::unordered_set> attribute_constraints; -}; - -} // namespace FlexFlow - -#endif diff --git a/lib/substitutions/include/substitutions/constraint_type.dtg.h b/lib/substitutions/include/substitutions/constraint_type.dtg.h new file mode 100644 index 0000000000..f99b794e46 --- /dev/null +++ b/lib/substitutions/include/substitutions/constraint_type.dtg.h @@ -0,0 +1,40 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/substitutions/include/substitutions/constraint_type.enum.toml +/* proj-data +{ + "generated_from": "06b029d76658cb434abf08b1fdb86137" +} +*/ + +#ifndef _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_CONSTRAINT_TYPE_DTG_H +#define _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_CONSTRAINT_TYPE_DTG_H + +#include "fmt/format.h" +#include "nlohmann/json.hpp" +#include "rapidcheck.h" +#include +#include +#include + +namespace FlexFlow { +enum class ConstraintType { EQUAL }; +std::string format_as(ConstraintType); +std::ostream &operator<<(std::ostream &, ConstraintType); +void to_json(::nlohmann::json &, ConstraintType); +void from_json(::nlohmann::json const &, ConstraintType &); +} // namespace FlexFlow +namespace std { +template <> +struct hash { + size_t operator()(FlexFlow::ConstraintType) const; +}; +} // namespace std +namespace rc { +template <> +struct Arbitrary { + static Gen arbitrary(); +}; +} // namespace rc + +#endif // _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_CONSTRAINT_TYPE_DTG_H diff --git a/lib/substitutions/include/substitutions/constraint_type.enum.toml b/lib/substitutions/include/substitutions/constraint_type.enum.toml new file mode 100644 index 0000000000..8646ba1c83 --- /dev/null +++ b/lib/substitutions/include/substitutions/constraint_type.enum.toml @@ -0,0 +1,11 @@ +namespace = "FlexFlow" +name = "ConstraintType" +features = [ + "json", + "hash", + "rapidcheck", + "fmt", +] + +[[values]] +name = "EQUAL" diff --git a/lib/substitutions/include/substitutions/graph_pattern.h b/lib/substitutions/include/substitutions/graph_pattern.h index 4f4021203b..5f03a6e92e 100644 --- a/lib/substitutions/include/substitutions/graph_pattern.h +++ b/lib/substitutions/include/substitutions/graph_pattern.h @@ -1,32 +1,24 @@ #ifndef _FLEXFLOW_SUBSTITUTIONS_SUBSTITUTIONS_H #define _FLEXFLOW_SUBSTITUTIONS_SUBSTITUTIONS_H -#include "graph_pattern_match.h" -#include "operator_pattern.h" -#include "parallel_tensor_pattern.h" -#include "sub_parallel_computation_graph.h" +#include "substitutions/pcg_pattern.dtg.h" +#include "substitutions/sub_parallel_computation_graph.dtg.h" +#include "substitutions/unlabelled/pattern_edge.dtg.h" +#include "substitutions/unlabelled/pattern_matching.h" +#include "substitutions/unlabelled/pattern_node.dtg.h" +#include "substitutions/unlabelled/unlabelled_graph_pattern.dtg.h" namespace FlexFlow { -struct GraphPattern - : 
public strong_typedef< - GraphPattern, - OutputLabelledOpenMultiDiGraph> { - using strong_typedef::strong_typedef; -}; +UnlabelledGraphPattern get_unlabelled_pattern(PCGPattern const &); -GraphSplit split_pattern(OpenMultiDiGraphView const &pattern); - -bool is_singleton_pattern(OpenMultiDiGraphView const &); - -bool operator_satisfies(Operator const ¶ms, OperatorPattern const &pattern); - -bool parallel_tensor_satisfies(ParallelTensor const ¶ms, - ParallelTensorPattern const &pattern); +TensorAttributePattern get_tensor_pattern(PCGPattern const &, + PatternEdge const &); +OperatorAttributePattern get_operator_pattern(PCGPattern const &, + PatternNode const &); bool assignment_satisfies(SubParallelComputationGraph const &, - GraphPattern const &, + PCGPattern const &, MultiDiGraphPatternMatch const &); } // namespace FlexFlow diff --git a/lib/substitutions/include/substitutions/graph_pattern_match.h b/lib/substitutions/include/substitutions/graph_pattern_match.h deleted file mode 100644 index bf6d6b6921..0000000000 --- a/lib/substitutions/include/substitutions/graph_pattern_match.h +++ /dev/null @@ -1,42 +0,0 @@ -#ifndef _FLEXFLOW_SUBSTITUTIONS_GRAPH_PATTERN_MATCH_H -#define _FLEXFLOW_SUBSTITUTIONS_GRAPH_PATTERN_MATCH_H - -#include "utils/graph.h" -#include "utils/visitable.h" - -namespace FlexFlow { - -struct MultiDiGraphPatternMatch { - using PatternNode = Node; - using PCGNode = Node; - using PatternEdge = OpenMultiDiEdge; - using PCGEdge = OpenMultiDiEdge; - - bidict node_assignment; - bidict edge_assignment; -}; - -struct MatchSplit { - MultiDiGraphPatternMatch prefix_submatch; - MultiDiGraphPatternMatch postfix_submatch; -}; - -struct MatchAdditionalCriterion { - std::function node_criterion; - std::function - edge_criterion; -}; - -bool pattern_matches(OpenMultiDiGraphView const &pattern, - OpenMultiDiGraphView const &graph, - MultiDiGraphPatternMatch const &match, - MatchAdditionalCriterion const &additional_criterion); - -std::vector - find_pattern_matches(OpenMultiDiGraphView const &pattern, - OpenMultiDiGraphView const &graph, - MatchAdditionalCriterion const &additional_criterion); - -} // namespace FlexFlow - -#endif diff --git a/lib/substitutions/include/substitutions/operator_pattern.h b/lib/substitutions/include/substitutions/operator_pattern.h deleted file mode 100644 index 8fc4ebefc2..0000000000 --- a/lib/substitutions/include/substitutions/operator_pattern.h +++ /dev/null @@ -1,107 +0,0 @@ -#ifndef _FLEXFLOW_SUBSTITUTIONS_OPERATOR_PATTERN_H -#define _FLEXFLOW_SUBSTITUTIONS_OPERATOR_PATTERN_H - -#include "attribute_expr.h" -#include "op-attrs/activation.h" -#include "op-attrs/datatype.h" -#include "op-attrs/op.h" -#include "pcg/operator.h" -#include -#include - -namespace FlexFlow { - -enum class OperatorAttributeKey { - OP_TYPE, // AnyOp - USE_BIAS, - GROUPS, - POOL_TYPE, - KERNEL_H, - KERNEL_W, - DATA_TYPE, - SCALAR, - STRIDE_H, - STRIDE_W, - PADDING_H, - PADDING_W, - AGGR, - NUM_ENTRIES, - OUT_CHANNELS, - ACTIVATION, - NUMDIM, - AXIS, - PERMUTATION, - OUTSHUFFLE, - MERGE_GCONV_COUNT, - AXES, - KEEP_DIMS, - EPSILON, - PARALLEL_OP_DIM, - PARALLEL_OP_DEGREE, - SOFTMAX_DIM, - NUM_HEADS, - PARALLEL_DIM, - PARALLEL_DEGREE, - PAD, - EMBED_DIM, - KDIM, - VDIM, - DROPOUT, - BIAS, - ADD_BIAS_KV, - ADD_ZERO_ATTN, - A_SEQ_LENGTH_DIM, - B_SEQ_LENGTH_DIM, - RELU, - TARGET_DIMS, - RATE, - SEED, - SHOULD_BROADCAST_LHS, - SHOULD_BROADCAST_RHS, - DIM, - ELEMENTWISE_AFFINE, - REGULARIZER, - SHAPE, - SPLITS, - K, - SORTED, - COMBINE_DIM, - COMBINE_DEGREE, - NUM_INPUTS -}; - -using 
OperatorAttributeValue = - std::variant, - stack_vector, - OperatorType, - Activation, - ff_dim_t, - unsigned long long, - AggregateOp, - stack_vector, - std::optional, - PoolOp, - TensorShape, - DataType>; - -FF_VISITABLE_STRUCT(ListIndexAccess, - attribute_key, - index); -FF_VISITABLE_STRUCT(ListSize, attribute_key); - -using OperatorAttributeConstraint = - AttributeConstraint; - -using OperatorPattern = - AttributePattern; - -std::optional - evaluate_attribute_expr(Operator const &attrs, - AttributeExpr const &expr); - -} // namespace FlexFlow - -#endif diff --git a/lib/substitutions/include/substitutions/operator_pattern/eval_list_access.h b/lib/substitutions/include/substitutions/operator_pattern/eval_list_access.h new file mode 100644 index 0000000000..93d2d56384 --- /dev/null +++ b/lib/substitutions/include/substitutions/operator_pattern/eval_list_access.h @@ -0,0 +1,17 @@ +#ifndef _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_OPERATOR_PATTERN_OPERATOR_ATTRIBUTE_LIST_ACCESS_H +#define _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_OPERATOR_PATTERN_OPERATOR_ATTRIBUTE_LIST_ACCESS_H + +#include "op-attrs/pcg_operator_attrs.dtg.h" +#include "substitutions/operator_pattern/operator_attribute_list_access.dtg.h" +#include "substitutions/operator_pattern/operator_attribute_value.dtg.h" +#include + +namespace FlexFlow { + +std::optional + eval_list_access(PCGOperatorAttrs const &attrs, + OperatorAttributeListIndexAccess const &); + +} // namespace FlexFlow + +#endif diff --git a/lib/substitutions/include/substitutions/operator_pattern/eval_list_size.h b/lib/substitutions/include/substitutions/operator_pattern/eval_list_size.h new file mode 100644 index 0000000000..236a248945 --- /dev/null +++ b/lib/substitutions/include/substitutions/operator_pattern/eval_list_size.h @@ -0,0 +1,16 @@ +#ifndef _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_OPERATOR_PATTERN_OPERATOR_ATTRIBUTE_LIST_SIZE_H +#define _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_OPERATOR_PATTERN_OPERATOR_ATTRIBUTE_LIST_SIZE_H + +#include "op-attrs/pcg_operator_attrs.dtg.h" +#include "substitutions/operator_pattern/operator_attribute_list_size.dtg.h" +#include "substitutions/operator_pattern/operator_attribute_value.dtg.h" + +namespace FlexFlow { + +std::optional + eval_list_size(PCGOperatorAttrs const &attrs, + OperatorAttributeListSize const &acc); + +} // namespace FlexFlow + +#endif diff --git a/lib/substitutions/include/substitutions/get_attribute.h b/lib/substitutions/include/substitutions/operator_pattern/get_attribute.h similarity index 83% rename from lib/substitutions/include/substitutions/get_attribute.h rename to lib/substitutions/include/substitutions/operator_pattern/get_attribute.h index 0e6dd4c69b..93f4a2bc0f 100644 --- a/lib/substitutions/include/substitutions/get_attribute.h +++ b/lib/substitutions/include/substitutions/operator_pattern/get_attribute.h @@ -1,9 +1,10 @@ #ifndef _FLEXFLOW_SUBSTITUTIONS_GET_ATTRIBUTES_H #define _FLEXFLOW_SUBSTITUTIONS_GET_ATTRIBUTES_H -#include "op-attrs/operator_attrs.h" -#include "operator_pattern.h" -#include "utils/optional.h" +#include "op-attrs/pcg_operator_attrs.dtg.h" +#include "substitutions/operator_pattern/operator_attribute_key.dtg.h" +#include "substitutions/operator_pattern/operator_attribute_value.dtg.h" +#include namespace FlexFlow { @@ -11,6 +12,8 @@ std::optional get_attribute(PCGOperatorAttrs const &, OperatorAttributeKey); std::optional get_attribute(BatchMatmulAttrs const &p, OperatorAttributeKey); +std::optional get_attribute(BatchNormAttrs const 
&p, + OperatorAttributeKey); std::optional get_attribute(CastAttrs const &p, OperatorAttributeKey); std::optional get_attribute(CombineAttrs const &p, @@ -33,12 +36,17 @@ std::optional get_attribute(FlatAttrs const &p, OperatorAttributeKey); std::optional get_attribute(GatherAttrs const &p, OperatorAttributeKey); +std::optional get_attribute(InputAttrs const &p, + OperatorAttributeKey); std::optional get_attribute(LayerNormAttrs const &p, OperatorAttributeKey); std::optional get_attribute(LinearAttrs const &p, OperatorAttributeKey); std::optional get_attribute(MultiHeadAttentionAttrs const &p, OperatorAttributeKey); + +std::optional get_attribute(NoopAttrs const &p, + OperatorAttributeKey); std::optional get_attribute(Pool2DAttrs const &p, OperatorAttributeKey); std::optional get_attribute(ReduceAttrs const &p, @@ -51,6 +59,8 @@ std::optional get_attribute(ReplicateAttrs const &p, OperatorAttributeKey); std::optional get_attribute(ReshapeAttrs const &p, OperatorAttributeKey); +std::optional get_attribute(ReverseAttrs const &p, + OperatorAttributeKey); std::optional get_attribute(SplitAttrs const &p, OperatorAttributeKey); std::optional get_attribute(SoftmaxAttrs const &p, diff --git a/lib/substitutions/include/substitutions/operator_pattern/operator_attribute_constraint.dtg.h b/lib/substitutions/include/substitutions/operator_pattern/operator_attribute_constraint.dtg.h new file mode 100644 index 0000000000..35ec9e499f --- /dev/null +++ b/lib/substitutions/include/substitutions/operator_pattern/operator_attribute_constraint.dtg.h @@ -0,0 +1,62 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/substitutions/include/substitutions/operator_pattern/operator_attribute_constraint.struct.toml +/* proj-data +{ + "generated_from": "7867bd0f403866c13417171bb5ec364c" +} +*/ + +#ifndef _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_OPERATOR_PATTERN_OPERATOR_ATTRIBUTE_CONSTRAINT_DTG_H +#define _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_OPERATOR_PATTERN_OPERATOR_ATTRIBUTE_CONSTRAINT_DTG_H + +#include "fmt/format.h" +#include "nlohmann/json.hpp" +#include "substitutions/constraint_type.dtg.h" +#include "substitutions/operator_pattern/operator_attribute_expr.dtg.h" +#include "substitutions/operator_pattern/operator_attribute_value.dtg.h" +#include +#include +#include + +namespace FlexFlow { +struct OperatorAttributeConstraint { + OperatorAttributeConstraint() = delete; + OperatorAttributeConstraint( + ::FlexFlow::ConstraintType const &constraint_type, + ::FlexFlow::OperatorAttributeExpr const &attribute_expr, + ::FlexFlow::OperatorAttributeValue const &attribute_value); + + bool operator==(OperatorAttributeConstraint const &) const; + bool operator!=(OperatorAttributeConstraint const &) const; + bool operator<(OperatorAttributeConstraint const &) const; + bool operator>(OperatorAttributeConstraint const &) const; + bool operator<=(OperatorAttributeConstraint const &) const; + bool operator>=(OperatorAttributeConstraint const &) const; + ::FlexFlow::ConstraintType constraint_type; + ::FlexFlow::OperatorAttributeExpr attribute_expr; + ::FlexFlow::OperatorAttributeValue attribute_value; +}; +} // namespace FlexFlow + +namespace std { +template <> +struct hash { + size_t operator()(FlexFlow::OperatorAttributeConstraint const &) const; +}; +} // namespace std + +namespace nlohmann { +template <> +struct adl_serializer { + static FlexFlow::OperatorAttributeConstraint from_json(json const &); + static void to_json(json &, 
FlexFlow::OperatorAttributeConstraint const &); +}; +} // namespace nlohmann + +namespace FlexFlow { +std::string format_as(OperatorAttributeConstraint const &); +std::ostream &operator<<(std::ostream &, OperatorAttributeConstraint const &); +} // namespace FlexFlow + +#endif // _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_OPERATOR_PATTERN_OPERATOR_ATTRIBUTE_CONSTRAINT_DTG_H diff --git a/lib/substitutions/include/substitutions/operator_pattern/operator_attribute_constraint.struct.toml b/lib/substitutions/include/substitutions/operator_pattern/operator_attribute_constraint.struct.toml new file mode 100644 index 0000000000..646faf878e --- /dev/null +++ b/lib/substitutions/include/substitutions/operator_pattern/operator_attribute_constraint.struct.toml @@ -0,0 +1,28 @@ +namespace = "FlexFlow" +name = "OperatorAttributeConstraint" +features = [ + "eq", + "ord", + "hash", + "json", + # "rapidcheck", + "fmt", +] + +includes = [ + "substitutions/constraint_type.dtg.h", + "substitutions/operator_pattern/operator_attribute_expr.dtg.h", + "substitutions/operator_pattern/operator_attribute_value.dtg.h", +] + +[[fields]] +name = "constraint_type" +type = "::FlexFlow::ConstraintType" + +[[fields]] +name = "attribute_expr" +type = "::FlexFlow::OperatorAttributeExpr" + +[[fields]] +name = "attribute_value" +type = "::FlexFlow::OperatorAttributeValue" diff --git a/lib/substitutions/include/substitutions/operator_pattern/operator_attribute_expr.dtg.h b/lib/substitutions/include/substitutions/operator_pattern/operator_attribute_expr.dtg.h new file mode 100644 index 0000000000..a66a035ba8 --- /dev/null +++ b/lib/substitutions/include/substitutions/operator_pattern/operator_attribute_expr.dtg.h @@ -0,0 +1,143 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! 
+// If you would like to modify this datatype, instead modify
+// lib/substitutions/include/substitutions/operator_pattern/operator_attribute_expr.variant.toml
+/* proj-data
+{
+  "generated_from": "15d26dd1f08092ecc82b725aa9411597"
+}
+*/
+
+#ifndef _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_OPERATOR_PATTERN_OPERATOR_ATTRIBUTE_EXPR_DTG_H
+#define _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_OPERATOR_PATTERN_OPERATOR_ATTRIBUTE_EXPR_DTG_H
+
+#include "fmt/format.h"
+#include "nlohmann/json.hpp"
+#include "substitutions/operator_pattern/operator_attribute_key.dtg.h"
+#include "substitutions/operator_pattern/operator_attribute_list_access.dtg.h"
+#include "substitutions/operator_pattern/operator_attribute_list_size.dtg.h"
+#include <functional>
+#include <ostream>
+#include <string>
+#include <type_traits>
+#include <variant>
+
+namespace FlexFlow {
+struct OperatorAttributeExpr {
+  OperatorAttributeExpr() = delete;
+  explicit OperatorAttributeExpr(::FlexFlow::OperatorAttributeKey const &);
+  explicit OperatorAttributeExpr(::FlexFlow::OperatorAttributeListSize const &);
+  explicit OperatorAttributeExpr(
+      ::FlexFlow::OperatorAttributeListIndexAccess const &);
+  template <typename T>
+  static constexpr bool IsPartOfOperatorAttributeExpr_v =
+      std::is_same_v<T, ::FlexFlow::OperatorAttributeKey> ||
+      std::is_same_v<T, ::FlexFlow::OperatorAttributeListSize> ||
+      std::is_same_v<T, ::FlexFlow::OperatorAttributeListIndexAccess>;
+  template <typename ReturnType, typename Visitor>
+  ReturnType visit(Visitor &&v) const {
+    switch (this->index()) {
+      case 0: {
+        ReturnType result = v(this->get<::FlexFlow::OperatorAttributeKey>());
+        return result;
+      }
+      case 1: {
+        ReturnType result =
+            v(this->get<::FlexFlow::OperatorAttributeListSize>());
+        return result;
+      }
+      case 2: {
+        ReturnType result =
+            v(this->get<::FlexFlow::OperatorAttributeListIndexAccess>());
+        return result;
+      }
+      default: {
+        throw std::runtime_error(fmt::format(
+            "Unknown index {} for type OperatorAttributeExpr", this->index()));
+      }
+    }
+  }
+  template <typename ReturnType, typename Visitor>
+  ReturnType visit(Visitor &&v) {
+    switch (this->index()) {
+      case 0: {
+        ReturnType result = v(this->get<::FlexFlow::OperatorAttributeKey>());
+        return result;
+      }
+      case 1: {
+        ReturnType result =
+            v(this->get<::FlexFlow::OperatorAttributeListSize>());
+        return result;
+      }
+      case 2: {
+        ReturnType result =
+            v(this->get<::FlexFlow::OperatorAttributeListIndexAccess>());
+        return result;
+      }
+      default: {
+        throw std::runtime_error(fmt::format(
+            "Unknown index {} for type OperatorAttributeExpr", this->index()));
+      }
+    }
+  }
+  template <typename T>
+  bool has() const {
+    static_assert(IsPartOfOperatorAttributeExpr_v<T>,
+                  "OperatorAttributeExpr::has() expected one of "
+                  "[::FlexFlow::OperatorAttributeKey, "
+                  "::FlexFlow::OperatorAttributeListSize, "
+                  "::FlexFlow::OperatorAttributeListIndexAccess], received T");
+    return std::holds_alternative<T>(this->raw_variant);
+  }
+  template <typename T>
+  T const &get() const {
+    static_assert(IsPartOfOperatorAttributeExpr_v<T>,
+                  "OperatorAttributeExpr::get() expected one of "
+                  "[::FlexFlow::OperatorAttributeKey, "
+                  "::FlexFlow::OperatorAttributeListSize, "
+                  "::FlexFlow::OperatorAttributeListIndexAccess], received T");
+    return std::get<T>(this->raw_variant);
+  }
+  template <typename T>
+  T &get() {
+    static_assert(IsPartOfOperatorAttributeExpr_v<T>,
+                  "OperatorAttributeExpr::get() expected one of "
+                  "[::FlexFlow::OperatorAttributeKey, "
+                  "::FlexFlow::OperatorAttributeListSize, "
+                  "::FlexFlow::OperatorAttributeListIndexAccess], received T");
+    return std::get<T>(this->raw_variant);
+  }
+  size_t index() const {
+    return this->raw_variant.index();
+  }
+  bool operator==(OperatorAttributeExpr const &) const;
+  bool operator!=(OperatorAttributeExpr const &) const;
+  bool operator<(OperatorAttributeExpr const &) const;
+ bool operator>(OperatorAttributeExpr const &) const; + bool operator<=(OperatorAttributeExpr const &) const; + bool operator>=(OperatorAttributeExpr const &) const; + std::variant<::FlexFlow::OperatorAttributeKey, + ::FlexFlow::OperatorAttributeListSize, + ::FlexFlow::OperatorAttributeListIndexAccess> + raw_variant; +}; +} // namespace FlexFlow +namespace std { +template <> +struct hash<::FlexFlow::OperatorAttributeExpr> { + size_t operator()(::FlexFlow::OperatorAttributeExpr const &) const; +}; +} // namespace std +namespace nlohmann { +template <> +struct adl_serializer<::FlexFlow::OperatorAttributeExpr> { + static ::FlexFlow::OperatorAttributeExpr from_json(json const &); + static void to_json(json &, ::FlexFlow::OperatorAttributeExpr const &); +}; +} // namespace nlohmann +namespace FlexFlow { +std::string format_as(::FlexFlow::OperatorAttributeExpr const &); +std::ostream &operator<<(std::ostream &, + ::FlexFlow::OperatorAttributeExpr const &); +} // namespace FlexFlow + +#endif // _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_OPERATOR_PATTERN_OPERATOR_ATTRIBUTE_EXPR_DTG_H diff --git a/lib/substitutions/include/substitutions/operator_pattern/operator_attribute_expr.h b/lib/substitutions/include/substitutions/operator_pattern/operator_attribute_expr.h new file mode 100644 index 0000000000..4528847771 --- /dev/null +++ b/lib/substitutions/include/substitutions/operator_pattern/operator_attribute_expr.h @@ -0,0 +1,16 @@ +#ifndef _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_OPERATOR_PATTERN_OPERATOR_ATTRIBUTE_EXPR_H +#define _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_OPERATOR_PATTERN_OPERATOR_ATTRIBUTE_EXPR_H + +#include "pcg/parallel_layer_attrs.dtg.h" +#include "substitutions/operator_pattern/operator_attribute_expr.dtg.h" +#include "substitutions/operator_pattern/operator_attribute_value.dtg.h" +#include + +namespace FlexFlow { + +std::optional + evaluate_attribute_expr(PCGOperatorAttrs const &attrs, + OperatorAttributeExpr const &expr); +} // namespace FlexFlow + +#endif diff --git a/lib/substitutions/include/substitutions/operator_pattern/operator_attribute_expr.variant.toml b/lib/substitutions/include/substitutions/operator_pattern/operator_attribute_expr.variant.toml new file mode 100644 index 0000000000..ff79ecaaa5 --- /dev/null +++ b/lib/substitutions/include/substitutions/operator_pattern/operator_attribute_expr.variant.toml @@ -0,0 +1,27 @@ +namespace = "FlexFlow" +name = "OperatorAttributeExpr" +features = [ + "eq", + "ord", + "hash", + "json", + "fmt", +] + +includes = [ + "substitutions/operator_pattern/operator_attribute_key.dtg.h", + "substitutions/operator_pattern/operator_attribute_list_access.dtg.h", + "substitutions/operator_pattern/operator_attribute_list_size.dtg.h", +] + +[[values]] +type = "::FlexFlow::OperatorAttributeKey" +key = "key" + +[[values]] +type = "::FlexFlow::OperatorAttributeListSize" +key = "list_size" + +[[values]] +type = "::FlexFlow::OperatorAttributeListIndexAccess" +key = "list_idx" diff --git a/lib/substitutions/include/substitutions/operator_pattern/operator_attribute_key.dtg.h b/lib/substitutions/include/substitutions/operator_pattern/operator_attribute_key.dtg.h new file mode 100644 index 0000000000..49a5ccbbe6 --- /dev/null +++ b/lib/substitutions/include/substitutions/operator_pattern/operator_attribute_key.dtg.h @@ -0,0 +1,97 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! 
+// If you would like to modify this datatype, instead modify
+// lib/substitutions/include/substitutions/operator_pattern/operator_attribute_key.enum.toml
+/* proj-data
+{
+  "generated_from": "e637388397720b328b1f4b9ba6b14611"
+}
+*/
+
+#ifndef _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_OPERATOR_PATTERN_OPERATOR_ATTRIBUTE_KEY_DTG_H
+#define _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_OPERATOR_PATTERN_OPERATOR_ATTRIBUTE_KEY_DTG_H
+
+#include "fmt/format.h"
+#include "nlohmann/json.hpp"
+#include "rapidcheck.h"
+#include <functional>
+#include <ostream>
+#include <string>
+
+namespace FlexFlow {
+enum class OperatorAttributeKey {
+  OP_TYPE,
+  USE_BIAS,
+  GROUPS,
+  POOL_TYPE,
+  KERNEL_H,
+  KERNEL_W,
+  DATA_TYPE,
+  SCALAR,
+  STRIDE_H,
+  STRIDE_W,
+  PADDING_H,
+  PADDING_W,
+  AGGR,
+  NUM_ENTRIES,
+  OUT_CHANNELS,
+  ACTIVATION,
+  NUMDIM,
+  AXIS,
+  PERMUTATION,
+  OUTSHUFFLE,
+  MERGE_GCONV_COUNT,
+  AXES,
+  KEEP_DIMS,
+  EPSILON,
+  PARALLEL_OP_DIM,
+  PARALLEL_OP_DEGREE,
+  SOFTMAX_DIM,
+  NUM_HEADS,
+  PARALLEL_DIM,
+  PARALLEL_DEGREE,
+  PAD,
+  EMBED_DIM,
+  KDIM,
+  VDIM,
+  DROPOUT,
+  BIAS,
+  ADD_BIAS_KV,
+  ADD_ZERO_ATTN,
+  A_SEQ_LENGTH_DIM,
+  B_SEQ_LENGTH_DIM,
+  RELU,
+  TARGET_DIMS,
+  RATE,
+  SEED,
+  SHOULD_BROADCAST_LHS,
+  SHOULD_BROADCAST_RHS,
+  DIM,
+  ELEMENTWISE_AFFINE,
+  REGULARIZER,
+  SHAPE,
+  SPLITS,
+  K,
+  SORTED,
+  COMBINE_DIM,
+  COMBINE_DEGREE,
+  NUM_INPUTS
+};
+std::string format_as(OperatorAttributeKey);
+std::ostream &operator<<(std::ostream &, OperatorAttributeKey);
+void to_json(::nlohmann::json &, OperatorAttributeKey);
+void from_json(::nlohmann::json const &, OperatorAttributeKey &);
+} // namespace FlexFlow
+namespace std {
+template <>
+struct hash<FlexFlow::OperatorAttributeKey> {
+  size_t operator()(FlexFlow::OperatorAttributeKey) const;
+};
+} // namespace std
+namespace rc {
+template <>
+struct Arbitrary<FlexFlow::OperatorAttributeKey> {
+  static Gen<FlexFlow::OperatorAttributeKey> arbitrary();
+};
+} // namespace rc
+
+#endif // _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_OPERATOR_PATTERN_OPERATOR_ATTRIBUTE_KEY_DTG_H
diff --git a/lib/substitutions/include/substitutions/operator_pattern/operator_attribute_key.enum.toml b/lib/substitutions/include/substitutions/operator_pattern/operator_attribute_key.enum.toml
new file mode 100644
index 0000000000..59e913750e
--- /dev/null
+++ b/lib/substitutions/include/substitutions/operator_pattern/operator_attribute_key.enum.toml
@@ -0,0 +1,67 @@
+namespace = "FlexFlow"
+name = "OperatorAttributeKey"
+features = [
+  "json",
+  "hash",
+  "fmt",
+  "rapidcheck",
+]
+
+values = [
+  { name = "OP_TYPE" },
+  { name = "USE_BIAS" },
+  { name = "GROUPS" },
+  { name = "POOL_TYPE" },
+  { name = "KERNEL_H" },
+  { name = "KERNEL_W" },
+  { name = "DATA_TYPE" },
+  { name = "SCALAR" },
+  { name = "STRIDE_H" },
+  { name = "STRIDE_W" },
+  { name = "PADDING_H" },
+  { name = "PADDING_W" },
+  { name = "AGGR" },
+  { name = "NUM_ENTRIES" },
+  { name = "OUT_CHANNELS" },
+  { name = "ACTIVATION" },
+  { name = "NUMDIM" },
+  { name = "AXIS" },
+  { name = "PERMUTATION" },
+  { name = "OUTSHUFFLE" },
+  { name = "MERGE_GCONV_COUNT" },
+  { name = "AXES" },
+  { name = "KEEP_DIMS" },
+  { name = "EPSILON" },
+  { name = "PARALLEL_OP_DIM" },
+  { name = "PARALLEL_OP_DEGREE" },
+  { name = "SOFTMAX_DIM" },
+  { name = "NUM_HEADS" },
+  { name = "PARALLEL_DIM" },
+  { name = "PARALLEL_DEGREE" },
+  { name = "PAD" },
+  { name = "EMBED_DIM" },
+  { name = "KDIM" },
+  { name = "VDIM" },
+  { name = "DROPOUT" },
+  { name = "BIAS" },
+  { name = "ADD_BIAS_KV" },
+  { name = "ADD_ZERO_ATTN" },
+  { name = "A_SEQ_LENGTH_DIM" },
+  { name = "B_SEQ_LENGTH_DIM" },
+  { name = "RELU" },
+  { name = "TARGET_DIMS" },
+  { name = "RATE" },
+  { name = "SEED" },
+  { name = "SHOULD_BROADCAST_LHS" },
+  { name = "SHOULD_BROADCAST_RHS" },
+  { name = "DIM" },
+  { name = "ELEMENTWISE_AFFINE" },
+  { name = "REGULARIZER" },
+  { name = "SHAPE" },
+  { name = "SPLITS" },
+  { name = "K" },
+  { name = "SORTED" },
+  { name = "COMBINE_DIM" },
+  { name = "COMBINE_DEGREE" },
+  { name = "NUM_INPUTS" },
+]
diff --git a/lib/substitutions/include/substitutions/operator_pattern/operator_attribute_list_access.dtg.h b/lib/substitutions/include/substitutions/operator_pattern/operator_attribute_list_access.dtg.h
new file mode 100644
index 0000000000..5a30c40f8d
--- /dev/null
+++ b/lib/substitutions/include/substitutions/operator_pattern/operator_attribute_list_access.dtg.h
@@ -0,0 +1,67 @@
+// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT!
+// If you would like to modify this datatype, instead modify
+// lib/substitutions/include/substitutions/operator_pattern/operator_attribute_list_access.struct.toml
+/* proj-data
+{
+  "generated_from": "1dc90d1e823f05b82c1a5ff433fbf000"
+}
+*/
+
+#ifndef _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_OPERATOR_PATTERN_OPERATOR_ATTRIBUTE_LIST_ACCESS_DTG_H
+#define _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_OPERATOR_PATTERN_OPERATOR_ATTRIBUTE_LIST_ACCESS_DTG_H
+
+#include "fmt/format.h"
+#include "nlohmann/json.hpp"
+#include "rapidcheck.h"
+#include "substitutions/operator_pattern/operator_attribute_key.dtg.h"
+#include <functional>
+#include <ostream>
+#include <tuple>
+
+namespace FlexFlow {
+struct OperatorAttributeListIndexAccess {
+  OperatorAttributeListIndexAccess() = delete;
+  OperatorAttributeListIndexAccess(
+      ::FlexFlow::OperatorAttributeKey const &attribute_key, int const &index);
+
+  bool operator==(OperatorAttributeListIndexAccess const &) const;
+  bool operator!=(OperatorAttributeListIndexAccess const &) const;
+  bool operator<(OperatorAttributeListIndexAccess const &) const;
+  bool operator>(OperatorAttributeListIndexAccess const &) const;
+  bool operator<=(OperatorAttributeListIndexAccess const &) const;
+  bool operator>=(OperatorAttributeListIndexAccess const &) const;
+  ::FlexFlow::OperatorAttributeKey attribute_key;
+  int index;
+};
+} // namespace FlexFlow
+
+namespace std {
+template <>
+struct hash<FlexFlow::OperatorAttributeListIndexAccess> {
+  size_t operator()(FlexFlow::OperatorAttributeListIndexAccess const &) const;
+};
+} // namespace std
+
+namespace nlohmann {
+template <>
+struct adl_serializer<FlexFlow::OperatorAttributeListIndexAccess> {
+  static FlexFlow::OperatorAttributeListIndexAccess from_json(json const &);
+  static void to_json(json &,
+                      FlexFlow::OperatorAttributeListIndexAccess const &);
+};
+} // namespace nlohmann
+
+namespace rc {
+template <>
+struct Arbitrary<FlexFlow::OperatorAttributeListIndexAccess> {
+  static Gen<FlexFlow::OperatorAttributeListIndexAccess> arbitrary();
+};
+} // namespace rc
+
+namespace FlexFlow {
+std::string format_as(OperatorAttributeListIndexAccess const &);
+std::ostream &operator<<(std::ostream &,
+                         OperatorAttributeListIndexAccess const &);
+} // namespace FlexFlow
+
+#endif // _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_OPERATOR_PATTERN_OPERATOR_ATTRIBUTE_LIST_ACCESS_DTG_H
diff --git a/lib/substitutions/include/substitutions/operator_pattern/operator_attribute_list_access.struct.toml b/lib/substitutions/include/substitutions/operator_pattern/operator_attribute_list_access.struct.toml
new file mode 100644
index 0000000000..bceff393d2
--- /dev/null
+++ b/lib/substitutions/include/substitutions/operator_pattern/operator_attribute_list_access.struct.toml
@@ -0,0 +1,22 @@
+namespace = "FlexFlow"
+name = "OperatorAttributeListIndexAccess"
+features = [
+  "eq",
+  "ord",
+  "hash",
+  "rapidcheck",
+  "json",
"fmt", +] + +includes = [ + "substitutions/operator_pattern/operator_attribute_key.dtg.h" +] + +[[fields]] +name = "attribute_key" +type = "::FlexFlow::OperatorAttributeKey" + +[[fields]] +name = "index" +type = "int" diff --git a/lib/substitutions/include/substitutions/operator_pattern/operator_attribute_list_size.dtg.h b/lib/substitutions/include/substitutions/operator_pattern/operator_attribute_list_size.dtg.h new file mode 100644 index 0000000000..17d76a08f1 --- /dev/null +++ b/lib/substitutions/include/substitutions/operator_pattern/operator_attribute_list_size.dtg.h @@ -0,0 +1,64 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/substitutions/include/substitutions/operator_pattern/operator_attribute_list_size.struct.toml +/* proj-data +{ + "generated_from": "30999ad6b0603e380bc33d32fa088e45" +} +*/ + +#ifndef _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_OPERATOR_PATTERN_OPERATOR_ATTRIBUTE_LIST_SIZE_DTG_H +#define _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_OPERATOR_PATTERN_OPERATOR_ATTRIBUTE_LIST_SIZE_DTG_H + +#include "fmt/format.h" +#include "nlohmann/json.hpp" +#include "rapidcheck.h" +#include "substitutions/operator_pattern/operator_attribute_key.dtg.h" +#include +#include +#include + +namespace FlexFlow { +struct OperatorAttributeListSize { + OperatorAttributeListSize() = delete; + OperatorAttributeListSize( + ::FlexFlow::OperatorAttributeKey const &attribute_key); + + bool operator==(OperatorAttributeListSize const &) const; + bool operator!=(OperatorAttributeListSize const &) const; + bool operator<(OperatorAttributeListSize const &) const; + bool operator>(OperatorAttributeListSize const &) const; + bool operator<=(OperatorAttributeListSize const &) const; + bool operator>=(OperatorAttributeListSize const &) const; + ::FlexFlow::OperatorAttributeKey attribute_key; +}; +} // namespace FlexFlow + +namespace std { +template <> +struct hash { + size_t operator()(FlexFlow::OperatorAttributeListSize const &) const; +}; +} // namespace std + +namespace nlohmann { +template <> +struct adl_serializer { + static FlexFlow::OperatorAttributeListSize from_json(json const &); + static void to_json(json &, FlexFlow::OperatorAttributeListSize const &); +}; +} // namespace nlohmann + +namespace rc { +template <> +struct Arbitrary { + static Gen arbitrary(); +}; +} // namespace rc + +namespace FlexFlow { +std::string format_as(OperatorAttributeListSize const &); +std::ostream &operator<<(std::ostream &, OperatorAttributeListSize const &); +} // namespace FlexFlow + +#endif // _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_OPERATOR_PATTERN_OPERATOR_ATTRIBUTE_LIST_SIZE_DTG_H diff --git a/lib/substitutions/include/substitutions/operator_pattern/operator_attribute_list_size.struct.toml b/lib/substitutions/include/substitutions/operator_pattern/operator_attribute_list_size.struct.toml new file mode 100644 index 0000000000..271b545fda --- /dev/null +++ b/lib/substitutions/include/substitutions/operator_pattern/operator_attribute_list_size.struct.toml @@ -0,0 +1,19 @@ +namespace = "FlexFlow" +name = "OperatorAttributeListSize" +features = [ + "eq", + "ord", + "hash", + "rapidcheck", + "json", + "fmt", +] + +includes = [ + "substitutions/operator_pattern/operator_attribute_key.dtg.h", +] + + +[[fields]] +name = "attribute_key" +type = "::FlexFlow::OperatorAttributeKey" diff --git a/lib/substitutions/include/substitutions/operator_pattern/operator_attribute_pattern.dtg.h 
b/lib/substitutions/include/substitutions/operator_pattern/operator_attribute_pattern.dtg.h new file mode 100644 index 0000000000..7bce198f3d --- /dev/null +++ b/lib/substitutions/include/substitutions/operator_pattern/operator_attribute_pattern.dtg.h @@ -0,0 +1,56 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/substitutions/include/substitutions/operator_pattern/operator_attribute_pattern.struct.toml +/* proj-data +{ + "generated_from": "968d7a3e93303a7fa7482bbcd50246b6" +} +*/ + +#ifndef _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_OPERATOR_PATTERN_OPERATOR_ATTRIBUTE_PATTERN_DTG_H +#define _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_OPERATOR_PATTERN_OPERATOR_ATTRIBUTE_PATTERN_DTG_H + +#include "fmt/format.h" +#include "nlohmann/json.hpp" +#include "substitutions/operator_pattern/operator_attribute_constraint.dtg.h" +#include "utils/fmt.h" +#include +#include +#include +#include + +namespace FlexFlow { +struct OperatorAttributePattern { + OperatorAttributePattern() = delete; + OperatorAttributePattern( + std::unordered_set<::FlexFlow::OperatorAttributeConstraint> const + &attribute_constraints); + + bool operator==(OperatorAttributePattern const &) const; + bool operator!=(OperatorAttributePattern const &) const; + std::unordered_set<::FlexFlow::OperatorAttributeConstraint> + attribute_constraints; +}; +} // namespace FlexFlow + +namespace std { +template <> +struct hash { + size_t operator()(FlexFlow::OperatorAttributePattern const &) const; +}; +} // namespace std + +namespace nlohmann { +template <> +struct adl_serializer { + static FlexFlow::OperatorAttributePattern from_json(json const &); + static void to_json(json &, FlexFlow::OperatorAttributePattern const &); +}; +} // namespace nlohmann + +namespace FlexFlow { +std::string format_as(OperatorAttributePattern const &); +std::ostream &operator<<(std::ostream &, OperatorAttributePattern const &); +} // namespace FlexFlow + +#endif // _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_OPERATOR_PATTERN_OPERATOR_ATTRIBUTE_PATTERN_DTG_H diff --git a/lib/substitutions/include/substitutions/operator_pattern/operator_attribute_pattern.struct.toml b/lib/substitutions/include/substitutions/operator_pattern/operator_attribute_pattern.struct.toml new file mode 100644 index 0000000000..6facf7d3bc --- /dev/null +++ b/lib/substitutions/include/substitutions/operator_pattern/operator_attribute_pattern.struct.toml @@ -0,0 +1,20 @@ +namespace = "FlexFlow" +name = "OperatorAttributePattern" +features = [ + "eq", + # "ord", + "hash", + "json", + # "rapidcheck", + "fmt", +] + +includes = [ + "", + "utils/fmt.h", + "substitutions/operator_pattern/operator_attribute_constraint.dtg.h", +] + +[[fields]] +name = "attribute_constraints" +type = "std::unordered_set<::FlexFlow::OperatorAttributeConstraint>" diff --git a/lib/substitutions/include/substitutions/operator_pattern/operator_attribute_value.dtg.h b/lib/substitutions/include/substitutions/operator_pattern/operator_attribute_value.dtg.h new file mode 100644 index 0000000000..080909d147 --- /dev/null +++ b/lib/substitutions/include/substitutions/operator_pattern/operator_attribute_value.dtg.h @@ -0,0 +1,264 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! 
+// If you would like to modify this datatype, instead modify
+// lib/substitutions/include/substitutions/operator_pattern/operator_attribute_value.variant.toml
+/* proj-data
+{
+  "generated_from": "de14592f1f4bcfb52689bc95e9d3b55f"
+}
+*/
+
+#ifndef _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_OPERATOR_PATTERN_OPERATOR_ATTRIBUTE_VALUE_DTG_H
+#define _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_OPERATOR_PATTERN_OPERATOR_ATTRIBUTE_VALUE_DTG_H
+
+#include "fmt/format.h"
+#include "nlohmann/json.hpp"
+#include "op-attrs/activation.dtg.h"
+#include "op-attrs/aggregate_op.dtg.h"
+#include "op-attrs/datatype.dtg.h"
+#include "op-attrs/ff_dim.dtg.h"
+#include "op-attrs/operator_type.dtg.h"
+#include "op-attrs/pool_op.dtg.h"
+#include "op-attrs/regularizer_attrs.dtg.h"
+#include "op-attrs/tensor_shape.dtg.h"
+#include <functional>
+#include <optional>
+#include <ostream>
+#include <string>
+#include <type_traits>
+#include <variant>
+#include <vector>
+
+namespace FlexFlow {
+struct OperatorAttributeValue {
+  OperatorAttributeValue() = delete;
+  OperatorAttributeValue(int const &);
+  OperatorAttributeValue(bool const &);
+  OperatorAttributeValue(std::vector<int> const &);
+  OperatorAttributeValue(std::vector<::FlexFlow::ff_dim_t> const &);
+  OperatorAttributeValue(::FlexFlow::OperatorType const &);
+  OperatorAttributeValue(::FlexFlow::Activation const &);
+  OperatorAttributeValue(::FlexFlow::ff_dim_t const &);
+  OperatorAttributeValue(size_t const &);
+  OperatorAttributeValue(::FlexFlow::AggregateOp const &);
+  OperatorAttributeValue(std::optional<::FlexFlow::RegularizerAttrs> const &);
+  OperatorAttributeValue(::FlexFlow::PoolOp const &);
+  OperatorAttributeValue(::FlexFlow::TensorShape const &);
+  OperatorAttributeValue(::FlexFlow::DataType const &);
+  template <typename T>
+  static constexpr bool IsPartOfOperatorAttributeValue_v =
+      std::is_same_v<T, int> || std::is_same_v<T, bool> ||
+      std::is_same_v<T, std::vector<int>> ||
+      std::is_same_v<T, std::vector<::FlexFlow::ff_dim_t>> ||
+      std::is_same_v<T, ::FlexFlow::OperatorType> ||
+      std::is_same_v<T, ::FlexFlow::Activation> ||
+      std::is_same_v<T, ::FlexFlow::ff_dim_t> || std::is_same_v<T, size_t> ||
+      std::is_same_v<T, ::FlexFlow::AggregateOp> ||
+      std::is_same_v<T, std::optional<::FlexFlow::RegularizerAttrs>> ||
+      std::is_same_v<T, ::FlexFlow::PoolOp> ||
+      std::is_same_v<T, ::FlexFlow::TensorShape> ||
+      std::is_same_v<T, ::FlexFlow::DataType>;
+  template <typename ReturnType, typename Visitor>
+  ReturnType visit(Visitor &&v) const {
+    switch (this->index()) {
+      case 0: {
+        ReturnType result = v(this->get<int>());
+        return result;
+      }
+      case 1: {
+        ReturnType result = v(this->get<bool>());
+        return result;
+      }
+      case 2: {
+        ReturnType result = v(this->get<std::vector<int>>());
+        return result;
+      }
+      case 3: {
+        ReturnType result = v(this->get<std::vector<::FlexFlow::ff_dim_t>>());
+        return result;
+      }
+      case 4: {
+        ReturnType result = v(this->get<::FlexFlow::OperatorType>());
+        return result;
+      }
+      case 5: {
+        ReturnType result = v(this->get<::FlexFlow::Activation>());
+        return result;
+      }
+      case 6: {
+        ReturnType result = v(this->get<::FlexFlow::ff_dim_t>());
+        return result;
+      }
+      case 7: {
+        ReturnType result = v(this->get<size_t>());
+        return result;
+      }
+      case 8: {
+        ReturnType result = v(this->get<::FlexFlow::AggregateOp>());
+        return result;
+      }
+      case 9: {
+        ReturnType result =
+            v(this->get<std::optional<::FlexFlow::RegularizerAttrs>>());
+        return result;
+      }
+      case 10: {
+        ReturnType result = v(this->get<::FlexFlow::PoolOp>());
+        return result;
+      }
+      case 11: {
+        ReturnType result = v(this->get<::FlexFlow::TensorShape>());
+        return result;
+      }
+      case 12: {
+        ReturnType result = v(this->get<::FlexFlow::DataType>());
+        return result;
+      }
+      default: {
+        throw std::runtime_error(fmt::format(
+            "Unknown index {} for type OperatorAttributeValue", this->index()));
+      }
+    }
+  }
+  template <typename ReturnType, typename Visitor>
+  ReturnType visit(Visitor &&v) {
+    switch (this->index()) {
+      case 0: {
+        ReturnType result = v(this->get<int>());
+        return result;
+      }
+      case 1: {
+        ReturnType result = v(this->get<bool>());
+        return result;
+      }
+      case 2: {
+        ReturnType result = v(this->get<std::vector<int>>());
+        return result;
+      }
+      case 3: {
+        ReturnType result = v(this->get<std::vector<::FlexFlow::ff_dim_t>>());
+        return result;
+      }
+      case 4: {
+        ReturnType result = v(this->get<::FlexFlow::OperatorType>());
+        return result;
+      }
+      case 5: {
+        ReturnType result = v(this->get<::FlexFlow::Activation>());
+        return result;
+      }
+      case 6: {
+        ReturnType result = v(this->get<::FlexFlow::ff_dim_t>());
+        return result;
+      }
+      case 7: {
+        ReturnType result = v(this->get<size_t>());
+        return result;
+      }
+      case 8: {
+        ReturnType result = v(this->get<::FlexFlow::AggregateOp>());
+        return result;
+      }
+      case 9: {
+        ReturnType result =
+            v(this->get<std::optional<::FlexFlow::RegularizerAttrs>>());
+        return result;
+      }
+      case 10: {
+        ReturnType result = v(this->get<::FlexFlow::PoolOp>());
+        return result;
+      }
+      case 11: {
+        ReturnType result = v(this->get<::FlexFlow::TensorShape>());
+        return result;
+      }
+      case 12: {
+        ReturnType result = v(this->get<::FlexFlow::DataType>());
+        return result;
+      }
+      default: {
+        throw std::runtime_error(fmt::format(
+            "Unknown index {} for type OperatorAttributeValue", this->index()));
+      }
+    }
+  }
+  template <typename T>
+  bool has() const {
+    static_assert(
+        IsPartOfOperatorAttributeValue_v<T>,
+        "OperatorAttributeValue::has() expected one of [int, bool, "
+        "std::vector<int>, std::vector<::FlexFlow::ff_dim_t>, "
+        "::FlexFlow::OperatorType, ::FlexFlow::Activation, "
+        "::FlexFlow::ff_dim_t, size_t, ::FlexFlow::AggregateOp, "
+        "std::optional<::FlexFlow::RegularizerAttrs>, ::FlexFlow::PoolOp, "
+        "::FlexFlow::TensorShape, ::FlexFlow::DataType], received T");
+    return std::holds_alternative<T>(this->raw_variant);
+  }
+  template <typename T>
+  T const &get() const {
+    static_assert(
+        IsPartOfOperatorAttributeValue_v<T>,
+        "OperatorAttributeValue::get() expected one of [int, bool, "
+        "std::vector<int>, std::vector<::FlexFlow::ff_dim_t>, "
+        "::FlexFlow::OperatorType, ::FlexFlow::Activation, "
+        "::FlexFlow::ff_dim_t, size_t, ::FlexFlow::AggregateOp, "
+        "std::optional<::FlexFlow::RegularizerAttrs>, ::FlexFlow::PoolOp, "
+        "::FlexFlow::TensorShape, ::FlexFlow::DataType], received T");
+    return std::get<T>(this->raw_variant);
+  }
+  template <typename T>
+  T &get() {
+    static_assert(
+        IsPartOfOperatorAttributeValue_v<T>,
+        "OperatorAttributeValue::get() expected one of [int, bool, "
+        "std::vector<int>, std::vector<::FlexFlow::ff_dim_t>, "
+        "::FlexFlow::OperatorType, ::FlexFlow::Activation, "
+        "::FlexFlow::ff_dim_t, size_t, ::FlexFlow::AggregateOp, "
+        "std::optional<::FlexFlow::RegularizerAttrs>, ::FlexFlow::PoolOp, "
+        "::FlexFlow::TensorShape, ::FlexFlow::DataType], received T");
+    return std::get<T>(this->raw_variant);
+  }
+  size_t index() const {
+    return this->raw_variant.index();
+  }
+  bool operator==(OperatorAttributeValue const &) const;
+  bool operator!=(OperatorAttributeValue const &) const;
+  bool operator<(OperatorAttributeValue const &) const;
+  bool operator>(OperatorAttributeValue const &) const;
+  bool operator<=(OperatorAttributeValue const &) const;
+  bool operator>=(OperatorAttributeValue const &) const;
+  std::variant<int,
+               bool,
+               std::vector<int>,
+               std::vector<::FlexFlow::ff_dim_t>,
+               ::FlexFlow::OperatorType,
+               ::FlexFlow::Activation,
+               ::FlexFlow::ff_dim_t,
+               size_t,
+               ::FlexFlow::AggregateOp,
+               std::optional<::FlexFlow::RegularizerAttrs>,
+               ::FlexFlow::PoolOp,
+               ::FlexFlow::TensorShape,
+               ::FlexFlow::DataType>
+      raw_variant;
+};
+} // namespace FlexFlow
+namespace std {
+template <>
+struct hash<::FlexFlow::OperatorAttributeValue> {
+  size_t operator()(::FlexFlow::OperatorAttributeValue const &) const;
+};
+} // namespace std
+namespace nlohmann {
+template <>
+struct adl_serializer<::FlexFlow::OperatorAttributeValue> {
+  static ::FlexFlow::OperatorAttributeValue from_json(json const &);
+  static void to_json(json &, ::FlexFlow::OperatorAttributeValue const &);
+};
+} // namespace nlohmann
+namespace FlexFlow {
+std::string format_as(::FlexFlow::OperatorAttributeValue const &);
+std::ostream &operator<<(std::ostream &,
+                         ::FlexFlow::OperatorAttributeValue const &);
+} // namespace FlexFlow
+
+#endif // _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_OPERATOR_PATTERN_OPERATOR_ATTRIBUTE_VALUE_DTG_H
diff --git a/lib/substitutions/include/substitutions/operator_pattern/operator_attribute_value.variant.toml b/lib/substitutions/include/substitutions/operator_pattern/operator_attribute_value.variant.toml
new file mode 100644
index 0000000000..9ab88e63c2
--- /dev/null
+++ b/lib/substitutions/include/substitutions/operator_pattern/operator_attribute_value.variant.toml
@@ -0,0 +1,63 @@
+namespace = "FlexFlow"
+name = "OperatorAttributeValue"
+features = [
+  "eq",
+  "ord",
+  "hash",
+  "fmt",
+  "json",
+]
+explicit_constructors = false
+
+includes = [
+  "<vector>",
+  "<optional>",
+  "op-attrs/operator_type.dtg.h",
+  "op-attrs/ff_dim.dtg.h",
+  "op-attrs/activation.dtg.h",
+  "op-attrs/aggregate_op.dtg.h",
+  "op-attrs/regularizer_attrs.dtg.h",
+  "op-attrs/pool_op.dtg.h",
+  "op-attrs/tensor_shape.dtg.h",
+  "op-attrs/datatype.dtg.h",
+  "<cstddef>",
+]
+
+[[values]]
+type = "int"
+
+[[values]]
+type = "bool"
+
+[[values]]
+type = "std::vector<int>"
+
+[[values]]
+type = "std::vector<::FlexFlow::ff_dim_t>"
+
+[[values]]
+type = "::FlexFlow::OperatorType"
+
+[[values]]
+type = "::FlexFlow::Activation"
+
+[[values]]
+type = "::FlexFlow::ff_dim_t"
+
+[[values]]
+type = "size_t"
+
+[[values]]
+type = "::FlexFlow::AggregateOp"
+
+[[values]]
+type = "std::optional<::FlexFlow::RegularizerAttrs>"
+
+[[values]]
+type = "::FlexFlow::PoolOp"
+
+[[values]]
+type = "::FlexFlow::TensorShape"
+
+[[values]]
+type = "::FlexFlow::DataType"
diff --git a/lib/substitutions/include/substitutions/operator_pattern/satisfies_constraint.h b/lib/substitutions/include/substitutions/operator_pattern/satisfies_constraint.h
new file mode 100644
index 0000000000..7ddda2219c
--- /dev/null
+++ b/lib/substitutions/include/substitutions/operator_pattern/satisfies_constraint.h
@@ -0,0 +1,15 @@
+#ifndef _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_OPERATOR_PATTERN_SATISFIES_CONSTRAINT_H
+#define _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_OPERATOR_PATTERN_SATISFIES_CONSTRAINT_H
+
+#include "op-attrs/pcg_operator_attrs.dtg.h"
+#include "substitutions/operator_pattern/operator_attribute_constraint.dtg.h"
+
+namespace FlexFlow {
+
+bool operator_satisfies_constraint(
+    PCGOperatorAttrs const &params,
+    OperatorAttributeConstraint const &constraint);
+
+} // namespace FlexFlow
+
+#endif
diff --git a/lib/substitutions/include/substitutions/operator_pattern/satisfies_pattern.h b/lib/substitutions/include/substitutions/operator_pattern/satisfies_pattern.h
new file mode 100644
index 0000000000..ca4d5c13fa
--- /dev/null
+++ b/lib/substitutions/include/substitutions/operator_pattern/satisfies_pattern.h
@@ -0,0 +1,14 @@
+#ifndef _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_OPERATOR_PATTERN_SATISFIES_PATTERN_H
+#define _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_OPERATOR_PATTERN_SATISFIES_PATTERN_H
+
+#include "op-attrs/pcg_operator_attrs.dtg.h"
+#include "substitutions/operator_pattern/operator_attribute_pattern.dtg.h"
+
+namespace FlexFlow {
+
+bool operator_satisfies_pattern(PCGOperatorAttrs const &attrs,
+                                OperatorAttributePattern const &pattern);
+
+} // namespace FlexFlow
+
+#endif
diff --git a/lib/substitutions/include/substitutions/output_graph.h b/lib/substitutions/include/substitutions/output_graph.h
deleted file mode 100644
index 4ed90aed06..0000000000
--- a/lib/substitutions/include/substitutions/output_graph.h
+++ /dev/null
@@ -1,35 +0,0 @@
-#ifndef _FLEXFLOW_SUBSTITUTIONS_OUTPUT_GRAPH_H
-#define _FLEXFLOW_SUBSTITUTIONS_OUTPUT_GRAPH_H
-
-#include "utils/graph.h"
-
-namespace FlexFlow {
-
-// NOTE(@wmdi) I am not sure whether these should be part of attribute expr.
-struct OperatorAttrAccess {
-  Node node;
-  AttributeExpr<OperatorAttributeKey> attr_expr;
-};
-
-struct AttrConstant {
-  OperatorAttributeValue value;
-};
-
-using OperatorAttributeExpr = std::variant<OperatorAttrAccess, AttrConstant>;
-
-// NOTE(@wmdi): Not sure if it aligns with other design. Or alternatively we can
-// define the assignment for each operator type.
-struct OperatorAttrAssignment {
-  std::unordered_map<OperatorAttributeKey, OperatorAttributeExpr> assignments;
-};
-
-struct OutputGraphExpr
-    : public strong_typedef<
-          OutputGraphExpr,
-          NodeLabelledOpenMultiDiGraph<OperatorAttrAssignment>> {
-  using strong_typedef::strong_typedef;
-};
-
-} // namespace FlexFlow
-
-#endif
diff --git a/lib/substitutions/include/substitutions/output_graph/attr_constant.dtg.h b/lib/substitutions/include/substitutions/output_graph/attr_constant.dtg.h
new file mode 100644
index 0000000000..9dd20bb10e
--- /dev/null
+++ b/lib/substitutions/include/substitutions/output_graph/attr_constant.dtg.h
@@ -0,0 +1,46 @@
+// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT!
+// If you would like to modify this datatype, instead modify
+// lib/substitutions/include/substitutions/output_graph/attr_constant.struct.toml
+/* proj-data
+{
+  "generated_from": "1e5beabcb8e3657d8fe9c9c8b1310cb1"
+}
+*/
+
+#ifndef _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_OUTPUT_GRAPH_ATTR_CONSTANT_DTG_H
+#define _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_OUTPUT_GRAPH_ATTR_CONSTANT_DTG_H
+
+#include "fmt/format.h"
+#include "substitutions/operator_pattern/operator_attribute_value.dtg.h"
+#include <functional>
+#include <ostream>
+#include <tuple>
+
+namespace FlexFlow {
+struct AttrConstant {
+  AttrConstant() = delete;
+  AttrConstant(::FlexFlow::OperatorAttributeValue const &value);
+
+  bool operator==(AttrConstant const &) const;
+  bool operator!=(AttrConstant const &) const;
+  bool operator<(AttrConstant const &) const;
+  bool operator>(AttrConstant const &) const;
+  bool operator<=(AttrConstant const &) const;
+  bool operator>=(AttrConstant const &) const;
+  ::FlexFlow::OperatorAttributeValue value;
+};
+} // namespace FlexFlow
+
+namespace std {
+template <>
+struct hash<FlexFlow::AttrConstant> {
+  size_t operator()(FlexFlow::AttrConstant const &) const;
+};
+} // namespace std
+
+namespace FlexFlow {
+std::string format_as(AttrConstant const &);
+std::ostream &operator<<(std::ostream &, AttrConstant const &);
+} // namespace FlexFlow
+
+#endif // _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_OUTPUT_GRAPH_ATTR_CONSTANT_DTG_H
diff --git a/lib/substitutions/include/substitutions/output_graph/attr_constant.struct.toml b/lib/substitutions/include/substitutions/output_graph/attr_constant.struct.toml
new file mode 100644
index 0000000000..68973f9c0c
--- /dev/null
+++ b/lib/substitutions/include/substitutions/output_graph/attr_constant.struct.toml
@@ -0,0 +1,16 @@
+namespace = "FlexFlow"
+name = "AttrConstant"
+features = [
+  "eq",
+  "ord",
+  "hash",
+  "fmt",
+]
+
+includes = [
+  "substitutions/operator_pattern/operator_attribute_value.dtg.h",
+]
+
+[[fields]]
+name = "value"
+type = "::FlexFlow::OperatorAttributeValue"
diff --git a/lib/substitutions/include/substitutions/output_graph/output_graph_expr.dtg.h b/lib/substitutions/include/substitutions/output_graph/output_graph_expr.dtg.h
new file mode 100644
index 0000000000..3d6fb21574
--- /dev/null
+++ b/lib/substitutions/include/substitutions/output_graph/output_graph_expr.dtg.h
@@ -0,0 +1,28 @@
+// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT!
+// If you would like to modify this datatype, instead modify
+// lib/substitutions/include/substitutions/output_graph/output_graph_expr.struct.toml
+/* proj-data
+{
+  "generated_from": "9084c9afb2724504a6f4db4288a83a0d"
+}
+*/
+
+#ifndef _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_OUTPUT_GRAPH_OUTPUT_GRAPH_EXPR_DTG_H
+#define _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_OUTPUT_GRAPH_OUTPUT_GRAPH_EXPR_DTG_H
+
+#include "substitutions/output_graph/output_operator_attrs_assignment.dtg.h"
+#include "utils/graph.h"
+
+namespace FlexFlow {
+struct OutputGraphExpr {
+  OutputGraphExpr() = delete;
+  OutputGraphExpr(::FlexFlow::NodeLabelledOpenMultiDiGraph<
+                  ::FlexFlow::OutputOperatorAttrsAssignment> const &raw_graph);
+
+  ::FlexFlow::NodeLabelledOpenMultiDiGraph<
+      ::FlexFlow::OutputOperatorAttrsAssignment>
+      raw_graph;
+};
+} // namespace FlexFlow
+
+#endif // _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_OUTPUT_GRAPH_OUTPUT_GRAPH_EXPR_DTG_H
diff --git a/lib/substitutions/include/substitutions/output_graph/output_graph_expr.struct.toml b/lib/substitutions/include/substitutions/output_graph/output_graph_expr.struct.toml
new file mode 100644
index 0000000000..37d87f7820
--- /dev/null
+++ b/lib/substitutions/include/substitutions/output_graph/output_graph_expr.struct.toml
@@ -0,0 +1,12 @@
+namespace = "FlexFlow"
+name = "OutputGraphExpr"
+features = []
+
+includes = [
+  "utils/graph.h",
+  "substitutions/output_graph/output_operator_attrs_assignment.dtg.h",
+]
+
+[[fields]]
+name = "raw_graph"
+type = "::FlexFlow::NodeLabelledOpenMultiDiGraph<::FlexFlow::OutputOperatorAttrsAssignment>"
diff --git a/lib/substitutions/include/substitutions/output_graph/output_operator_attr_access.dtg.h b/lib/substitutions/include/substitutions/output_graph/output_operator_attr_access.dtg.h
new file mode 100644
index 0000000000..0d585f0aa0
--- /dev/null
+++ b/lib/substitutions/include/substitutions/output_graph/output_operator_attr_access.dtg.h
@@ -0,0 +1,49 @@
+// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT!
+// If you would like to modify this datatype, instead modify
+// lib/substitutions/include/substitutions/output_graph/output_operator_attr_access.struct.toml
+/* proj-data
+{
+  "generated_from": "e3b3a741183fcb38cfa68aacb82e12d1"
+}
+*/
+
+#ifndef _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_OUTPUT_GRAPH_OUTPUT_OPERATOR_ATTR_ACCESS_DTG_H
+#define _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_OUTPUT_GRAPH_OUTPUT_OPERATOR_ATTR_ACCESS_DTG_H
+
+#include "fmt/format.h"
+#include "substitutions/operator_pattern/operator_attribute_expr.dtg.h"
+#include "utils/graph.h"
+#include <functional>
+#include <ostream>
+#include <tuple>
+
+namespace FlexFlow {
+struct OutputOperatorAttrAccess {
+  OutputOperatorAttrAccess() = delete;
+  OutputOperatorAttrAccess(::FlexFlow::Node const &node,
+                           ::FlexFlow::OperatorAttributeExpr const &attr_expr);
+
+  bool operator==(OutputOperatorAttrAccess const &) const;
+  bool operator!=(OutputOperatorAttrAccess const &) const;
+  bool operator<(OutputOperatorAttrAccess const &) const;
+  bool operator>(OutputOperatorAttrAccess const &) const;
+  bool operator<=(OutputOperatorAttrAccess const &) const;
+  bool operator>=(OutputOperatorAttrAccess const &) const;
+  ::FlexFlow::Node node;
+  ::FlexFlow::OperatorAttributeExpr attr_expr;
+};
+} // namespace FlexFlow
+
+namespace std {
+template <>
+struct hash<FlexFlow::OutputOperatorAttrAccess> {
+  size_t operator()(FlexFlow::OutputOperatorAttrAccess const &) const;
+};
+} // namespace std
+
+namespace FlexFlow {
+std::string format_as(OutputOperatorAttrAccess const &);
+std::ostream &operator<<(std::ostream &, OutputOperatorAttrAccess const &);
+} // namespace FlexFlow
+
+#endif // _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_OUTPUT_GRAPH_OUTPUT_OPERATOR_ATTR_ACCESS_DTG_H
diff --git a/lib/substitutions/include/substitutions/output_graph/output_operator_attr_access.struct.toml b/lib/substitutions/include/substitutions/output_graph/output_operator_attr_access.struct.toml
new file mode 100644
index 0000000000..51aae54730
--- /dev/null
+++ b/lib/substitutions/include/substitutions/output_graph/output_operator_attr_access.struct.toml
@@ -0,0 +1,23 @@
+namespace = "FlexFlow"
+name = "OutputOperatorAttrAccess"
+features = [
+  "eq",
+  "ord",
+  "hash",
+  "fmt",
+]
+
+includes = [
+  "utils/graph.h",
+  "substitutions/operator_pattern/operator_attribute_expr.dtg.h",
+]
+
+[[fields]]
+name = "node"
+type = "::FlexFlow::Node"
+
+# NOTE(@wmdi) I am not sure whether these should be part of attribute expr.
+[[fields]]
+name = "attr_expr"
+type = "::FlexFlow::OperatorAttributeExpr"
+
diff --git a/lib/substitutions/include/substitutions/output_graph/output_operator_attribute_expr.dtg.h b/lib/substitutions/include/substitutions/output_graph/output_operator_attribute_expr.dtg.h
new file mode 100644
index 0000000000..327c230b61
--- /dev/null
+++ b/lib/substitutions/include/substitutions/output_graph/output_operator_attribute_expr.dtg.h
@@ -0,0 +1,119 @@
+// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT!
+// If you would like to modify this datatype, instead modify
+// lib/substitutions/include/substitutions/output_graph/output_operator_attribute_expr.variant.toml
+/* proj-data
+{
+  "generated_from": "89ebf777a5b909eef78ab5a5a177e041"
+}
+*/
+
+#ifndef _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_OUTPUT_GRAPH_OUTPUT_OPERATOR_ATTRIBUTE_EXPR_DTG_H
+#define _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_OUTPUT_GRAPH_OUTPUT_OPERATOR_ATTRIBUTE_EXPR_DTG_H
+
+#include "fmt/format.h"
+#include "substitutions/output_graph/attr_constant.dtg.h"
+#include "substitutions/output_graph/output_operator_attr_access.dtg.h"
+#include <functional>
+#include <ostream>
+#include <string>
+#include <type_traits>
+#include <variant>
+
+namespace FlexFlow {
+struct OutputOperatorAttributeExpr {
+  OutputOperatorAttributeExpr() = delete;
+  explicit OutputOperatorAttributeExpr(
+      ::FlexFlow::OutputOperatorAttrAccess const &);
+  explicit OutputOperatorAttributeExpr(::FlexFlow::AttrConstant const &);
+  template <typename T>
+  static constexpr bool IsPartOfOutputOperatorAttributeExpr_v =
+      std::is_same_v<T, ::FlexFlow::OutputOperatorAttrAccess> ||
+      std::is_same_v<T, ::FlexFlow::AttrConstant>;
+  template <typename ReturnType, typename Visitor>
+  ReturnType visit(Visitor &&v) const {
+    switch (this->index()) {
+      case 0: {
+        ReturnType result =
+            v(this->get<::FlexFlow::OutputOperatorAttrAccess>());
+        return result;
+      }
+      case 1: {
+        ReturnType result = v(this->get<::FlexFlow::AttrConstant>());
+        return result;
+      }
+      default: {
+        throw std::runtime_error(
+            fmt::format("Unknown index {} for type OutputOperatorAttributeExpr",
+                        this->index()));
+      }
+    }
+  }
+  template <typename ReturnType, typename Visitor>
+  ReturnType visit(Visitor &&v) {
+    switch (this->index()) {
+      case 0: {
+        ReturnType result =
+            v(this->get<::FlexFlow::OutputOperatorAttrAccess>());
+        return result;
+      }
+      case 1: {
+        ReturnType result = v(this->get<::FlexFlow::AttrConstant>());
+        return result;
+      }
+      default: {
+        throw std::runtime_error(
+            fmt::format("Unknown index {} for type OutputOperatorAttributeExpr",
+                        this->index()));
+      }
+    }
+  }
+  template <typename T>
+  bool has() const {
+    static_assert(IsPartOfOutputOperatorAttributeExpr_v<T>,
+                  "OutputOperatorAttributeExpr::has() expected one of "
+                  "[::FlexFlow::OutputOperatorAttrAccess, "
+                  "::FlexFlow::AttrConstant], received T");
+    return std::holds_alternative<T>(this->raw_variant);
+  }
+  template <typename T>
+  T const &get() const {
+    static_assert(IsPartOfOutputOperatorAttributeExpr_v<T>,
+                  "OutputOperatorAttributeExpr::get() expected one of "
+                  "[::FlexFlow::OutputOperatorAttrAccess, "
+                  "::FlexFlow::AttrConstant], received T");
+    return std::get<T>(this->raw_variant);
+  }
+  template <typename T>
+  T &get() {
+    static_assert(IsPartOfOutputOperatorAttributeExpr_v<T>,
+                  "OutputOperatorAttributeExpr::get() expected one of "
+                  "[::FlexFlow::OutputOperatorAttrAccess, "
+                  "::FlexFlow::AttrConstant], received T");
+    return std::get<T>(this->raw_variant);
+  }
+  size_t index() const {
+    return this->raw_variant.index();
+  }
+  bool operator==(OutputOperatorAttributeExpr const &) const;
+  bool operator!=(OutputOperatorAttributeExpr const &) const;
+  bool operator<(OutputOperatorAttributeExpr const &) const;
+  bool operator>(OutputOperatorAttributeExpr const &) const;
+  bool operator<=(OutputOperatorAttributeExpr const &) const;
+  bool operator>=(OutputOperatorAttributeExpr const &) const;
+  std::variant<::FlexFlow::OutputOperatorAttrAccess, ::FlexFlow::AttrConstant>
+      raw_variant;
+};
+} // namespace FlexFlow
+namespace std {
+template <>
+struct hash<::FlexFlow::OutputOperatorAttributeExpr> {
+  size_t operator()(::FlexFlow::OutputOperatorAttributeExpr const &) const;
+};
+} // namespace std
+namespace FlexFlow {
+std::string format_as(::FlexFlow::OutputOperatorAttributeExpr const &);
+std::ostream &operator<<(std::ostream &,
+                         ::FlexFlow::OutputOperatorAttributeExpr const &);
+} // namespace FlexFlow
+
+#endif // _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_OUTPUT_GRAPH_OUTPUT_OPERATOR_ATTRIBUTE_EXPR_DTG_H
diff --git a/lib/substitutions/include/substitutions/output_graph/output_operator_attribute_expr.variant.toml b/lib/substitutions/include/substitutions/output_graph/output_operator_attribute_expr.variant.toml
new file mode 100644
index 0000000000..19810a0151
--- /dev/null
+++ b/lib/substitutions/include/substitutions/output_graph/output_operator_attribute_expr.variant.toml
@@ -0,0 +1,21 @@
+namespace = "FlexFlow"
+name = "OutputOperatorAttributeExpr"
+features = [
+  "eq",
+  "ord",
+  "hash",
+  "fmt",
+]
+
+includes = [
+  "substitutions/output_graph/attr_constant.dtg.h",
+  "substitutions/output_graph/output_operator_attr_access.dtg.h",
+]
+
+[[values]]
+type = "::FlexFlow::OutputOperatorAttrAccess"
+key = "attr_ref"
+
+[[values]]
+type = "::FlexFlow::AttrConstant"
+key = "constant"
diff --git a/lib/substitutions/include/substitutions/output_graph/output_operator_attrs_assignment.dtg.h b/lib/substitutions/include/substitutions/output_graph/output_operator_attrs_assignment.dtg.h
new file mode 100644
index 0000000000..5586a90a08
--- /dev/null
+++ b/lib/substitutions/include/substitutions/output_graph/output_operator_attrs_assignment.dtg.h
@@ -0,0 +1,49 @@
+// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT!
+// If you would like to modify this datatype, instead modify
+// lib/substitutions/include/substitutions/output_graph/output_operator_attrs_assignment.struct.toml
+/* proj-data
+{
+  "generated_from": "bbfb309c5a39a729da23dace4df4a9de"
+}
+*/
+
+#ifndef _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_OUTPUT_GRAPH_OUTPUT_OPERATOR_ATTRS_ASSIGNMENT_DTG_H
+#define _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_OUTPUT_GRAPH_OUTPUT_OPERATOR_ATTRS_ASSIGNMENT_DTG_H
+
+#include "fmt/format.h"
+#include "substitutions/operator_pattern/operator_attribute_key.dtg.h"
+#include "substitutions/output_graph/output_operator_attribute_expr.dtg.h"
+#include <functional>
+#include <ostream>
+#include <tuple>
+#include <unordered_map>
+
+namespace FlexFlow {
+struct OutputOperatorAttrsAssignment {
+  OutputOperatorAttrsAssignment() = delete;
+  OutputOperatorAttrsAssignment(
+      std::unordered_map<::FlexFlow::OperatorAttributeKey,
+                         ::FlexFlow::OutputOperatorAttributeExpr> const
+          &assignments);
+
+  bool operator==(OutputOperatorAttrsAssignment const &) const;
+  bool operator!=(OutputOperatorAttrsAssignment const &) const;
+  std::unordered_map<::FlexFlow::OperatorAttributeKey,
+                     ::FlexFlow::OutputOperatorAttributeExpr>
+      assignments;
+};
+} // namespace FlexFlow
+
+namespace std {
+template <>
+struct hash<FlexFlow::OutputOperatorAttrsAssignment> {
+  size_t operator()(FlexFlow::OutputOperatorAttrsAssignment const &) const;
+};
+} // namespace std
+
+namespace FlexFlow {
+std::string format_as(OutputOperatorAttrsAssignment const &);
+std::ostream &operator<<(std::ostream &, OutputOperatorAttrsAssignment const &);
+} // namespace FlexFlow
+
+#endif // _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_OUTPUT_GRAPH_OUTPUT_OPERATOR_ATTRS_ASSIGNMENT_DTG_H
diff --git a/lib/substitutions/include/substitutions/output_graph/output_operator_attrs_assignment.struct.toml b/lib/substitutions/include/substitutions/output_graph/output_operator_attrs_assignment.struct.toml
new file mode 100644
index 0000000000..9781515803
--- /dev/null
+++ b/lib/substitutions/include/substitutions/output_graph/output_operator_attrs_assignment.struct.toml
@@ -0,0 +1,21 @@
+namespace = "FlexFlow"
+name = "OutputOperatorAttrsAssignment"
+features = [
+  "eq",
+  # "ord",
+  "hash",
+  # "json",
+  "fmt",
+]
+
+includes = [
+  "substitutions/operator_pattern/operator_attribute_key.dtg.h",
+  "substitutions/output_graph/output_operator_attribute_expr.dtg.h",
+  "<unordered_map>",
+]
+
+# NOTE(@wmdi): Not sure if it aligns with other design. Or alternatively we can
+# define the assignment for each operator type.
+[[fields]]
+name = "assignments"
+type = "std::unordered_map<::FlexFlow::OperatorAttributeKey, ::FlexFlow::OutputOperatorAttributeExpr>"
diff --git a/lib/substitutions/include/substitutions/parallel_tensor_pattern.h b/lib/substitutions/include/substitutions/parallel_tensor_pattern.h
deleted file mode 100644
index 741554142f..0000000000
--- a/lib/substitutions/include/substitutions/parallel_tensor_pattern.h
+++ /dev/null
@@ -1,25 +0,0 @@
-#ifndef _FLEXFLOW_SUBSTITUTIONS_TENSOR_PATTERN_H
-#define _FLEXFLOW_SUBSTITUTIONS_TENSOR_PATTERN_H
-
-#include "attribute_expr.h"
-#include "pcg/parallel_tensor.h"
-
-namespace FlexFlow {
-
-enum class TensorAttributeKey { DIM_SIZES, DIM_DEGREES };
-
-using TensorAttributeValue = std::variant<size_t, std::vector<size_t>>;
-
-using TensorAttributeConstraint =
-    AttributeConstraint<TensorAttributeKey, TensorAttributeValue>;
-
-using ParallelTensorPattern =
-    AttributePattern<TensorAttributeKey, TensorAttributeValue>;
-
-std::optional<TensorAttributeValue>
-    evaluate_attribute_expr(ParallelTensor const &tensor_shape,
-                            AttributeExpr<TensorAttributeKey> const &expr);
-
-} // namespace FlexFlow
-
-#endif
diff --git a/lib/substitutions/include/substitutions/pcg_pattern.dtg.h b/lib/substitutions/include/substitutions/pcg_pattern.dtg.h
new file mode 100644
index 0000000000..0c0cc41891
--- /dev/null
+++ b/lib/substitutions/include/substitutions/pcg_pattern.dtg.h
@@ -0,0 +1,31 @@
+// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT!
+// If you would like to modify this datatype, instead modify
+// lib/substitutions/include/substitutions/pcg_pattern.struct.toml
+/* proj-data
+{
+  "generated_from": "f536f846828ba39266dd4a1fbaeec0e6"
+}
+*/
+
+#ifndef _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_PCG_PATTERN_DTG_H
+#define _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_PCG_PATTERN_DTG_H
+
+#include "substitutions/operator_pattern/operator_attribute_pattern.dtg.h"
+#include "substitutions/tensor_pattern/tensor_attribute_pattern.dtg.h"
+#include "utils/graph.h"
+
+namespace FlexFlow {
+struct PCGPattern {
+  PCGPattern() = delete;
+  PCGPattern(::FlexFlow::OutputLabelledOpenMultiDiGraph<
+             ::FlexFlow::OperatorAttributePattern,
+             ::FlexFlow::TensorAttributePattern> const &raw_graph);
+
+  ::FlexFlow::OutputLabelledOpenMultiDiGraph<
+      ::FlexFlow::OperatorAttributePattern,
+      ::FlexFlow::TensorAttributePattern>
+      raw_graph;
+};
+} // namespace FlexFlow
+
+#endif // _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_PCG_PATTERN_DTG_H
diff --git a/lib/substitutions/include/substitutions/pcg_pattern.struct.toml b/lib/substitutions/include/substitutions/pcg_pattern.struct.toml
new file mode 100644
index 0000000000..191d66a38c
--- /dev/null
+++ b/lib/substitutions/include/substitutions/pcg_pattern.struct.toml
@@ -0,0 +1,12 @@
+namespace = "FlexFlow"
+name = "PCGPattern"
+features = []
+includes = [
+  "utils/graph.h",
+  "substitutions/operator_pattern/operator_attribute_pattern.dtg.h",
+  "substitutions/tensor_pattern/tensor_attribute_pattern.dtg.h",
+]
+
+[[fields]]
+name = "raw_graph"
+type = "::FlexFlow::OutputLabelledOpenMultiDiGraph<::FlexFlow::OperatorAttributePattern, ::FlexFlow::TensorAttributePattern>"
diff --git a/lib/substitutions/include/substitutions/sub_parallel_computation_graph.dtg.h b/lib/substitutions/include/substitutions/sub_parallel_computation_graph.dtg.h
new file mode 100644
index 0000000000..d31d65d83b
--- /dev/null
+++ b/lib/substitutions/include/substitutions/sub_parallel_computation_graph.dtg.h
@@ -0,0 +1,31 @@
+// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT!
+// If you would like to modify this datatype, instead modify
+// lib/substitutions/include/substitutions/sub_parallel_computation_graph.struct.toml
+/* proj-data
+{
+  "generated_from": "0022d1b2c1447667695a120c154a0168"
+}
+*/
+
+#ifndef _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_SUB_PARALLEL_COMPUTATION_GRAPH_DTG_H
+#define _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_SUB_PARALLEL_COMPUTATION_GRAPH_DTG_H
+
+#include "pcg/parallel_layer_attrs.dtg.h"
+#include "pcg/parallel_tensor_attrs.dtg.h"
+#include "utils/graph.h"
+
+namespace FlexFlow {
+struct SubParallelComputationGraph {
+  SubParallelComputationGraph() = delete;
+  SubParallelComputationGraph(
+      ::FlexFlow::OutputLabelledOpenMultiDiGraph<
+          ::FlexFlow::ParallelLayerAttrs,
+          ::FlexFlow::ParallelTensorAttrs> const &raw_graph);
+
+  ::FlexFlow::OutputLabelledOpenMultiDiGraph<::FlexFlow::ParallelLayerAttrs,
+                                             ::FlexFlow::ParallelTensorAttrs>
+      raw_graph;
+};
+} // namespace FlexFlow
+
+#endif // _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_SUB_PARALLEL_COMPUTATION_GRAPH_DTG_H
diff --git a/lib/substitutions/include/substitutions/sub_parallel_computation_graph.h b/lib/substitutions/include/substitutions/sub_parallel_computation_graph.h
index 0d6bfe7628..5d40f3f975 100644
--- a/lib/substitutions/include/substitutions/sub_parallel_computation_graph.h
+++ b/lib/substitutions/include/substitutions/sub_parallel_computation_graph.h
@@ -1,18 +1,17 @@
-#ifndef _FLEXFLOW_SUBSTITUTIONS_SUB_PARALLEL_COMPUTATION_GRAPH_H
-#define _FLEXFLOW_SUBSTITUTIONS_SUB_PARALLEL_COMPUTATION_GRAPH_H
+#ifndef _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_SUB_PARALLEL_COMPUTATION_GRAPH_H
+#define _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_SUB_PARALLEL_COMPUTATION_GRAPH_H
 
-#include "op-attrs/operator_attrs.h"
-#include "pcg/machine_view.h"
-#include "pcg/operator.h"
-#include "pcg/parallel_tensor.h"
-#include "utils/graph.h"
+#include "substitutions/sub_parallel_computation_graph.dtg.h"
 
 namespace FlexFlow {
 
-using SubParallelComputationGraph =
-    OutputLabelledOpenMultiDiGraph<Operator, ParallelTensor>;
-
-CHECK_WELL_BEHAVED_VALUE_TYPE_NO_EQ(SubParallelComputationGraph);
+ParallelLayerAttrs get_parallel_layer_attrs(SubParallelComputationGraph const &,
+                                            Node const &);
+PCGOperatorAttrs get_operator_attrs(SubParallelComputationGraph const &,
+                                    Node const &);
+ParallelTensorAttrs
+    get_parallel_tensor_attrs(SubParallelComputationGraph const &,
+                              OpenMultiDiEdge const &);
 
 } // namespace FlexFlow
 
diff --git a/lib/substitutions/include/substitutions/sub_parallel_computation_graph.struct.toml b/lib/substitutions/include/substitutions/sub_parallel_computation_graph.struct.toml
new file mode 100644
index 0000000000..1ba04b544c
--- /dev/null
+++ b/lib/substitutions/include/substitutions/sub_parallel_computation_graph.struct.toml
@@ -0,0 +1,13 @@
+namespace = "FlexFlow"
+name = "SubParallelComputationGraph"
+features = [ ]
+
+includes = [
+  "pcg/parallel_layer_attrs.dtg.h",
+  "pcg/parallel_tensor_attrs.dtg.h",
+  "utils/graph.h",
+]
+
+[[fields]]
+name = "raw_graph"
+type = "::FlexFlow::OutputLabelledOpenMultiDiGraph<::FlexFlow::ParallelLayerAttrs, ::FlexFlow::ParallelTensorAttrs>"
diff --git a/lib/substitutions/include/substitutions/substitution.dtg.h b/lib/substitutions/include/substitutions/substitution.dtg.h
new file mode 100644
index 0000000000..5f50d9bafc
--- /dev/null
+++ b/lib/substitutions/include/substitutions/substitution.dtg.h
@@ -0,0 +1,38 @@
+// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT!
+// If you would like to modify this datatype, instead modify
+// lib/substitutions/include/substitutions/substitution.struct.toml
+/* proj-data
+{
+  "generated_from": "c101f1d63e2d8d80a0ec9c5f5db4fa12"
+}
+*/
+
+#ifndef _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_SUBSTITUTION_DTG_H
+#define _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_SUBSTITUTION_DTG_H
+
+#include "substitutions/output_graph/output_graph_expr.dtg.h"
+#include "substitutions/pcg_pattern.dtg.h"
+
+namespace FlexFlow {
+struct Substitution {
+  Substitution() = delete;
+  Substitution(::FlexFlow::PCGPattern const &pcg_pattern,
+               ::FlexFlow::OutputGraphExpr const &output_graph_expr,
+               ::FlexFlow::bidict<::FlexFlow::InputMultiDiEdge,
+                                  ::FlexFlow::InputMultiDiEdge> const
+                   &input_edge_match_to_output,
+               ::FlexFlow::bidict<::FlexFlow::OutputMultiDiEdge,
+                                  ::FlexFlow::OutputMultiDiEdge> const
+                   &output_edge_match_to_output);
+
+  ::FlexFlow::PCGPattern pcg_pattern;
+  ::FlexFlow::OutputGraphExpr output_graph_expr;
+  ::FlexFlow::bidict<::FlexFlow::InputMultiDiEdge, ::FlexFlow::InputMultiDiEdge>
+      input_edge_match_to_output;
+  ::FlexFlow::bidict<::FlexFlow::OutputMultiDiEdge,
+                     ::FlexFlow::OutputMultiDiEdge>
+      output_edge_match_to_output;
+};
+} // namespace FlexFlow
+
+#endif // _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_SUBSTITUTION_DTG_H
diff --git a/lib/substitutions/include/substitutions/substitution.h b/lib/substitutions/include/substitutions/substitution.h
index 8dbe4e66cf..1aa2b2946b 100644
--- a/lib/substitutions/include/substitutions/substitution.h
+++ b/lib/substitutions/include/substitutions/substitution.h
@@ -1,24 +1,12 @@
 #ifndef _FLEXFLOW_SUBSTITUTIONS_SUBSTITUTION_H
 #define _FLEXFLOW_SUBSTITUTIONS_SUBSTITUTION_H
 
-#include "graph_pattern.h"
-#include "output_graph.h"
-#include "sub_parallel_computation_graph.h"
+#include "sub_parallel_computation_graph.dtg.h"
+#include "substitutions/substitution.dtg.h"
+#include "substitutions/unlabelled/multidigraph_pattern_match.dtg.h"
 
 namespace FlexFlow {
 
-struct Substitution {
-  using InputPatternInput = InputMultiDiEdge;
-  using InputPatternOutput = OutputMultiDiEdge;
-  using OutputPatternInput = InputMultiDiEdge;
-  using OutputPatternOutput = OutputMultiDiEdge;
-
-  GraphPattern input_graph;
-  OutputGraphExpr output_graph_expr;
-  bidict<InputPatternInput, OutputPatternInput> input_mapping;
-  bidict<InputPatternOutput, OutputPatternOutput> output_mapping;
-};
-
 bool is_valid_substitution(Substitution const &);
 
 SubParallelComputationGraph
@@ -28,12 +16,4 @@ SubParallelComputationGraph
 
 } // namespace FlexFlow
 
-namespace std {
-template <>
-struct hash<FlexFlow::Substitution> {
-  size_t operator()(FlexFlow::Substitution const &) const;
-};
-
-}; // namespace std
-
 #endif
diff --git a/lib/substitutions/include/substitutions/substitution.struct.toml b/lib/substitutions/include/substitutions/substitution.struct.toml
new file mode 100644
index 0000000000..eb630e9308
--- /dev/null
+++ b/lib/substitutions/include/substitutions/substitution.struct.toml
@@ -0,0 +1,24 @@
+namespace = "FlexFlow"
+name = "Substitution"
+features = []
+
+includes = [
+  "substitutions/pcg_pattern.dtg.h",
+  "substitutions/output_graph/output_graph_expr.dtg.h",
+]
+
+[[fields]]
+name = "pcg_pattern"
+type = "::FlexFlow::PCGPattern"
+
+[[fields]]
+name = "output_graph_expr"
+type = "::FlexFlow::OutputGraphExpr"
+
+[[fields]]
+name = "input_edge_match_to_output"
+type = "::FlexFlow::bidict<::FlexFlow::InputMultiDiEdge, ::FlexFlow::InputMultiDiEdge>"
+
+[[fields]]
+name = "output_edge_match_to_output"
+type = "::FlexFlow::bidict<::FlexFlow::OutputMultiDiEdge, ::FlexFlow::OutputMultiDiEdge>"
diff --git a/lib/substitutions/include/substitutions/tensor_pattern/eval_list_access.h b/lib/substitutions/include/substitutions/tensor_pattern/eval_list_access.h
new file mode 100644
index 0000000000..e245e800b2
--- /dev/null
+++ b/lib/substitutions/include/substitutions/tensor_pattern/eval_list_access.h
@@ -0,0 +1,15 @@
+#ifndef _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_TENSOR_PATTERN_EVAL_LIST_ACCESS_H
+#define _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_TENSOR_PATTERN_EVAL_LIST_ACCESS_H
+
+#include "pcg/parallel_tensor_attrs.dtg.h"
+#include "substitutions/tensor_pattern/tensor_attribute_list_access.dtg.h"
+#include "substitutions/tensor_pattern/tensor_attribute_value.dtg.h"
+
+namespace FlexFlow {
+
+TensorAttributeValue eval_list_access(ParallelTensorAttrs const &attrs,
+                                      TensorAttributeListIndexAccess const &);
+
+} // namespace FlexFlow
+
+#endif
diff --git a/lib/substitutions/include/substitutions/tensor_pattern/eval_list_size.h b/lib/substitutions/include/substitutions/tensor_pattern/eval_list_size.h
new file mode 100644
index 0000000000..de0d58e14f
--- /dev/null
+++ b/lib/substitutions/include/substitutions/tensor_pattern/eval_list_size.h
@@ -0,0 +1,15 @@
+#ifndef _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_TENSOR_PATTERN_EVAL_LIST_SIZE_H
+#define _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_TENSOR_PATTERN_EVAL_LIST_SIZE_H
+
+#include "pcg/parallel_tensor_attrs.dtg.h"
+#include "substitutions/tensor_pattern/tensor_attribute_list_size.dtg.h"
+#include "substitutions/tensor_pattern/tensor_attribute_value.dtg.h"
+
+namespace FlexFlow {
+
+TensorAttributeValue eval_list_size(ParallelTensorAttrs const &attrs,
+                                    TensorAttributeListSize const &);
+
+} // namespace FlexFlow
+
+#endif
diff --git a/lib/substitutions/include/substitutions/tensor_pattern/get_attribute.h b/lib/substitutions/include/substitutions/tensor_pattern/get_attribute.h
new file mode 100644
index 0000000000..eedca2da82
--- /dev/null
+++ b/lib/substitutions/include/substitutions/tensor_pattern/get_attribute.h
@@ -0,0 +1,15 @@
+#ifndef _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_TENSOR_PATTERN_GET_ATTRIBUTE_H
+#define _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_TENSOR_PATTERN_GET_ATTRIBUTE_H
+
+#include "pcg/parallel_tensor_attrs.dtg.h"
+#include "substitutions/tensor_pattern/tensor_attribute_key.dtg.h"
+#include "substitutions/tensor_pattern/tensor_attribute_value.dtg.h"
+
+namespace FlexFlow {
+
+TensorAttributeValue get_attribute(ParallelTensorAttrs const &,
+                                   TensorAttributeKey);
+
+} // namespace FlexFlow
+
+#endif
diff --git a/lib/substitutions/include/substitutions/tensor_pattern/satisfies_constraint.h b/lib/substitutions/include/substitutions/tensor_pattern/satisfies_constraint.h
new file mode 100644
index 0000000000..6c11b421a8
--- /dev/null
+++ b/lib/substitutions/include/substitutions/tensor_pattern/satisfies_constraint.h
@@ -0,0 +1,15 @@
+#ifndef _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_TENSOR_PATTERN_SATISFIES_CONSTRAINT_H
+#define _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_TENSOR_PATTERN_SATISFIES_CONSTRAINT_H
+
+#include "pcg/parallel_tensor_attrs.dtg.h"
+#include "substitutions/tensor_pattern/tensor_attribute_constraint.dtg.h"
+
+namespace FlexFlow {
+
+bool parallel_tensor_satisfies_constraint(
+    ParallelTensorAttrs const &params,
+    TensorAttributeConstraint const &constraint);
+
+} // namespace FlexFlow
+
+#endif
diff --git a/lib/substitutions/include/substitutions/tensor_pattern/satisfies_pattern.h b/lib/substitutions/include/substitutions/tensor_pattern/satisfies_pattern.h
diff --git a/lib/substitutions/include/substitutions/tensor_pattern/satisfies_pattern.h b/lib/substitutions/include/substitutions/tensor_pattern/satisfies_pattern.h
new file mode 100644
index 0000000000..b8b46669c6
--- /dev/null
+++ b/lib/substitutions/include/substitutions/tensor_pattern/satisfies_pattern.h
@@ -0,0 +1,14 @@
+#ifndef _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_TENSOR_PATTERN_SATISFIES_PATTERN_H
+#define _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_TENSOR_PATTERN_SATISFIES_PATTERN_H
+
+#include "pcg/parallel_tensor_attrs.dtg.h"
+#include "substitutions/tensor_pattern/tensor_attribute_pattern.dtg.h"
+
+namespace FlexFlow {
+
+bool parallel_tensor_satisfies_pattern(ParallelTensorAttrs const &attrs,
+                                       TensorAttributePattern const &pattern);
+
+} // namespace FlexFlow
+
+#endif
diff --git a/lib/substitutions/include/substitutions/tensor_pattern/tensor_attribute_constraint.dtg.h b/lib/substitutions/include/substitutions/tensor_pattern/tensor_attribute_constraint.dtg.h
new file mode 100644
index 0000000000..ba705a5d35
--- /dev/null
+++ b/lib/substitutions/include/substitutions/tensor_pattern/tensor_attribute_constraint.dtg.h
@@ -0,0 +1,62 @@
+// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT!
+// If you would like to modify this datatype, instead modify
+// lib/substitutions/include/substitutions/tensor_pattern/tensor_attribute_constraint.struct.toml
+/* proj-data
+{
+  "generated_from": "29dbf81668bc864b06af52261060335e"
+}
+*/
+
+#ifndef _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_TENSOR_PATTERN_TENSOR_ATTRIBUTE_CONSTRAINT_DTG_H
+#define _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_TENSOR_PATTERN_TENSOR_ATTRIBUTE_CONSTRAINT_DTG_H
+
+#include "fmt/format.h"
+#include "nlohmann/json.hpp"
+#include "substitutions/constraint_type.dtg.h"
+#include "substitutions/tensor_pattern/tensor_attribute_expr.dtg.h"
+#include "substitutions/tensor_pattern/tensor_attribute_value.dtg.h"
+#include <functional>
+#include <ostream>
+#include <tuple>
+
+namespace FlexFlow {
+struct TensorAttributeConstraint {
+  TensorAttributeConstraint() = delete;
+  TensorAttributeConstraint(
+      ::FlexFlow::ConstraintType const &constraint_type,
+      ::FlexFlow::TensorAttributeExpr const &attribute_expr,
+      ::FlexFlow::TensorAttributeValue const &attribute_value);
+
+  bool operator==(TensorAttributeConstraint const &) const;
+  bool operator!=(TensorAttributeConstraint const &) const;
+  bool operator<(TensorAttributeConstraint const &) const;
+  bool operator>(TensorAttributeConstraint const &) const;
+  bool operator<=(TensorAttributeConstraint const &) const;
+  bool operator>=(TensorAttributeConstraint const &) const;
+  ::FlexFlow::ConstraintType constraint_type;
+  ::FlexFlow::TensorAttributeExpr attribute_expr;
+  ::FlexFlow::TensorAttributeValue attribute_value;
+};
+} // namespace FlexFlow
+
+namespace std {
+template <>
+struct hash<FlexFlow::TensorAttributeConstraint> {
+  size_t operator()(FlexFlow::TensorAttributeConstraint const &) const;
+};
+} // namespace std
+
+namespace nlohmann {
+template <>
+struct adl_serializer<FlexFlow::TensorAttributeConstraint> {
+  static FlexFlow::TensorAttributeConstraint from_json(json const &);
+  static void to_json(json &, FlexFlow::TensorAttributeConstraint const &);
+};
+} // namespace nlohmann
+
+namespace FlexFlow {
+std::string format_as(TensorAttributeConstraint const &);
+std::ostream &operator<<(std::ostream &, TensorAttributeConstraint const &);
+} // namespace FlexFlow
+
+#endif // _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_TENSOR_PATTERN_TENSOR_ATTRIBUTE_CONSTRAINT_DTG_H
diff --git a/lib/substitutions/include/substitutions/tensor_pattern/tensor_attribute_constraint.struct.toml b/lib/substitutions/include/substitutions/tensor_pattern/tensor_attribute_constraint.struct.toml
new file mode 100644
index 0000000000..6aba719e08
--- /dev/null
+++ b/lib/substitutions/include/substitutions/tensor_pattern/tensor_attribute_constraint.struct.toml
@@ -0,0 +1,28 @@
+namespace = "FlexFlow"
+name = "TensorAttributeConstraint"
+features = [
+  "eq",
+  "ord",
+  "hash",
+  "json",
+  # "rapidcheck",
+  "fmt",
+]
+
+includes = [
+  "substitutions/constraint_type.dtg.h",
+  "substitutions/tensor_pattern/tensor_attribute_expr.dtg.h",
+  "substitutions/tensor_pattern/tensor_attribute_value.dtg.h",
+]
+
+[[fields]]
+name = "constraint_type"
+type = "::FlexFlow::ConstraintType"
+
+[[fields]]
+name = "attribute_expr"
+type = "::FlexFlow::TensorAttributeExpr"
+
+[[fields]]
+name = "attribute_value"
+type = "::FlexFlow::TensorAttributeValue"
diff --git a/lib/substitutions/include/substitutions/tensor_pattern/tensor_attribute_expr.dtg.h b/lib/substitutions/include/substitutions/tensor_pattern/tensor_attribute_expr.dtg.h
new file mode 100644
index 0000000000..d34be357c5
--- /dev/null
+++ b/lib/substitutions/include/substitutions/tensor_pattern/tensor_attribute_expr.dtg.h
@@ -0,0 +1,141 @@
+// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT!
+// If you would like to modify this datatype, instead modify
+// lib/substitutions/include/substitutions/tensor_pattern/tensor_attribute_expr.variant.toml
+/* proj-data
+{
+  "generated_from": "b91285329f12f1b409805cbf9be575b2"
+}
+*/
+
+#ifndef _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_TENSOR_PATTERN_TENSOR_ATTRIBUTE_EXPR_DTG_H
+#define _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_TENSOR_PATTERN_TENSOR_ATTRIBUTE_EXPR_DTG_H
+
+#include "fmt/format.h"
+#include "nlohmann/json.hpp"
+#include "substitutions/tensor_pattern/tensor_attribute_key.dtg.h"
+#include "substitutions/tensor_pattern/tensor_attribute_list_access.dtg.h"
+#include "substitutions/tensor_pattern/tensor_attribute_list_size.dtg.h"
+#include <cstddef>
+#include <functional>
+#include <ostream>
+#include <type_traits>
+#include <variant>
+
+namespace FlexFlow {
+struct TensorAttributeExpr {
+  TensorAttributeExpr() = delete;
+  explicit TensorAttributeExpr(::FlexFlow::TensorAttributeKey const &);
+  explicit TensorAttributeExpr(::FlexFlow::TensorAttributeListSize const &);
+  explicit TensorAttributeExpr(
+      ::FlexFlow::TensorAttributeListIndexAccess const &);
+  template <typename T>
+  static constexpr bool IsPartOfTensorAttributeExpr_v =
+      std::is_same_v<T, ::FlexFlow::TensorAttributeKey> ||
+      std::is_same_v<T, ::FlexFlow::TensorAttributeListSize> ||
+      std::is_same_v<T, ::FlexFlow::TensorAttributeListIndexAccess>;
+  template <typename ReturnType, typename Visitor>
+  ReturnType visit(Visitor &&v) const {
+    switch (this->index()) {
+      case 0: {
+        ReturnType result = v(this->get<::FlexFlow::TensorAttributeKey>());
+        return result;
+      }
+      case 1: {
+        ReturnType result = v(this->get<::FlexFlow::TensorAttributeListSize>());
+        return result;
+      }
+      case 2: {
+        ReturnType result =
+            v(this->get<::FlexFlow::TensorAttributeListIndexAccess>());
+        return result;
+      }
+      default: {
+        throw std::runtime_error(fmt::format(
+            "Unknown index {} for type TensorAttributeExpr", this->index()));
+      }
+    }
+  }
+  template <typename ReturnType, typename Visitor>
+  ReturnType visit(Visitor &&v) {
+    switch (this->index()) {
+      case 0: {
+        ReturnType result = v(this->get<::FlexFlow::TensorAttributeKey>());
+        return result;
+      }
+      case 1: {
+        ReturnType result = v(this->get<::FlexFlow::TensorAttributeListSize>());
+        return result;
+      }
+      case 2: {
+        ReturnType result =
+            v(this->get<::FlexFlow::TensorAttributeListIndexAccess>());
+        return result;
+      }
+      default: {
+        throw std::runtime_error(fmt::format(
+            "Unknown index {} for type TensorAttributeExpr", this->index()));
+      }
+    }
+  }
+  template <typename T>
+  bool has() const {
+    static_assert(
+        IsPartOfTensorAttributeExpr_v<T>,
+        "TensorAttributeExpr::has() expected one of "
+        "[::FlexFlow::TensorAttributeKey, ::FlexFlow::TensorAttributeListSize, "
+        "::FlexFlow::TensorAttributeListIndexAccess], received T");
+    return std::holds_alternative<T>(this->raw_variant);
+  }
+  template <typename T>
+  T const &get() const {
+    static_assert(
+        IsPartOfTensorAttributeExpr_v<T>,
+        "TensorAttributeExpr::get() expected one of "
+        "[::FlexFlow::TensorAttributeKey, ::FlexFlow::TensorAttributeListSize, "
+        "::FlexFlow::TensorAttributeListIndexAccess], received T");
+    return std::get<T>(this->raw_variant);
+  }
+  template <typename T>
+  T &get() {
+    static_assert(
+        IsPartOfTensorAttributeExpr_v<T>,
+        "TensorAttributeExpr::get() expected one of "
+        "[::FlexFlow::TensorAttributeKey, ::FlexFlow::TensorAttributeListSize, "
+        "::FlexFlow::TensorAttributeListIndexAccess], received T");
+    return std::get<T>(this->raw_variant);
+  }
+  size_t index() const {
+    return this->raw_variant.index();
+  }
+  bool operator==(TensorAttributeExpr const &) const;
+  bool operator!=(TensorAttributeExpr const &) const;
+  bool operator<(TensorAttributeExpr const &) const;
+  bool operator>(TensorAttributeExpr const &) const;
+  bool operator<=(TensorAttributeExpr const &) const;
+  bool operator>=(TensorAttributeExpr const &) const;
+  std::variant<::FlexFlow::TensorAttributeKey,
+               ::FlexFlow::TensorAttributeListSize,
+               ::FlexFlow::TensorAttributeListIndexAccess>
+      raw_variant;
+};
+} // namespace FlexFlow
+namespace std {
+template <>
+struct hash<::FlexFlow::TensorAttributeExpr> {
+  size_t operator()(::FlexFlow::TensorAttributeExpr const &) const;
+};
+} // namespace std
+namespace nlohmann {
+template <>
+struct adl_serializer<::FlexFlow::TensorAttributeExpr> {
+  static ::FlexFlow::TensorAttributeExpr from_json(json const &);
+  static void to_json(json &, ::FlexFlow::TensorAttributeExpr const &);
+};
+} // namespace nlohmann
+namespace FlexFlow {
+std::string format_as(::FlexFlow::TensorAttributeExpr const &);
+std::ostream &operator<<(std::ostream &,
+                         ::FlexFlow::TensorAttributeExpr const &);
+} // namespace FlexFlow
+
+#endif // _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_TENSOR_PATTERN_TENSOR_ATTRIBUTE_EXPR_DTG_H
diff --git a/lib/substitutions/include/substitutions/tensor_pattern/tensor_attribute_expr.h b/lib/substitutions/include/substitutions/tensor_pattern/tensor_attribute_expr.h
new file mode 100644
index 0000000000..98d4394530
--- /dev/null
+++ b/lib/substitutions/include/substitutions/tensor_pattern/tensor_attribute_expr.h
@@ -0,0 +1,15 @@
+#ifndef _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_TENSOR_PATTERN_TENSOR_ATTRIBUTE_EXPR_H
+#define _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_TENSOR_PATTERN_TENSOR_ATTRIBUTE_EXPR_H
+
+#include "pcg/parallel_tensor_attrs.dtg.h"
+#include "substitutions/tensor_pattern/tensor_attribute_expr.dtg.h"
+#include "substitutions/tensor_pattern/tensor_attribute_value.dtg.h"
+
+namespace FlexFlow {
+
+TensorAttributeValue evaluate_attribute_expr(ParallelTensorAttrs const &attrs,
+                                             TensorAttributeExpr const &expr);
+
+} // namespace FlexFlow
+
+#endif
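The generated `visit` above dispatches on the active alternative, which is how `evaluate_attribute_expr` from `tensor_attribute_expr.h` can be written without touching `raw_variant` directly. A sketch under assumptions (only the declaration is part of this patch; the visitor composition below is illustrative):

    #include "substitutions/tensor_pattern/eval_list_access.h"
    #include "substitutions/tensor_pattern/eval_list_size.h"
    #include "substitutions/tensor_pattern/get_attribute.h"
    #include "substitutions/tensor_pattern/tensor_attribute_expr.h"

    namespace FlexFlow {

    // One operator() per variant alternative; TensorAttributeExpr::visit
    // selects the matching overload from the active index.
    struct EvalExprVisitor {
      ParallelTensorAttrs const &attrs;

      TensorAttributeValue operator()(TensorAttributeKey const &key) const {
        return get_attribute(attrs, key);
      }
      TensorAttributeValue operator()(TensorAttributeListSize const &s) const {
        return eval_list_size(attrs, s);
      }
      TensorAttributeValue
          operator()(TensorAttributeListIndexAccess const &a) const {
        return eval_list_access(attrs, a);
      }
    };

    TensorAttributeValue evaluate_attribute_expr(ParallelTensorAttrs const &attrs,
                                                 TensorAttributeExpr const &expr) {
      return expr.visit<TensorAttributeValue>(EvalExprVisitor{attrs});
    }

    } // namespace FlexFlow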
"TensorAttributeExpr" +features = [ + "eq", + "ord", + "hash", + "json", + "fmt", +] + +includes = [ + "substitutions/tensor_pattern/tensor_attribute_key.dtg.h", + "substitutions/tensor_pattern/tensor_attribute_list_access.dtg.h", + "substitutions/tensor_pattern/tensor_attribute_list_size.dtg.h", +] + +[[values]] +type = "::FlexFlow::TensorAttributeKey" +key = "key" + +[[values]] +type = "::FlexFlow::TensorAttributeListSize" +key = "list_size" + +[[values]] +type = "::FlexFlow::TensorAttributeListIndexAccess" +key = "list_idx" diff --git a/lib/substitutions/include/substitutions/tensor_pattern/tensor_attribute_key.dtg.h b/lib/substitutions/include/substitutions/tensor_pattern/tensor_attribute_key.dtg.h new file mode 100644 index 0000000000..50a0aa49e8 --- /dev/null +++ b/lib/substitutions/include/substitutions/tensor_pattern/tensor_attribute_key.dtg.h @@ -0,0 +1,40 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/substitutions/include/substitutions/tensor_pattern/tensor_attribute_key.enum.toml +/* proj-data +{ + "generated_from": "63a7c40c1e5b582f98b59750a35f0a08" +} +*/ + +#ifndef _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_TENSOR_PATTERN_TENSOR_ATTRIBUTE_KEY_DTG_H +#define _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_TENSOR_PATTERN_TENSOR_ATTRIBUTE_KEY_DTG_H + +#include "fmt/format.h" +#include "nlohmann/json.hpp" +#include "rapidcheck.h" +#include +#include +#include + +namespace FlexFlow { +enum class TensorAttributeKey { DIM_SIZES, DIM_DEGREES }; +std::string format_as(TensorAttributeKey); +std::ostream &operator<<(std::ostream &, TensorAttributeKey); +void to_json(::nlohmann::json &, TensorAttributeKey); +void from_json(::nlohmann::json const &, TensorAttributeKey &); +} // namespace FlexFlow +namespace std { +template <> +struct hash { + size_t operator()(FlexFlow::TensorAttributeKey) const; +}; +} // namespace std +namespace rc { +template <> +struct Arbitrary { + static Gen arbitrary(); +}; +} // namespace rc + +#endif // _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_TENSOR_PATTERN_TENSOR_ATTRIBUTE_KEY_DTG_H diff --git a/lib/substitutions/include/substitutions/tensor_pattern/tensor_attribute_key.enum.toml b/lib/substitutions/include/substitutions/tensor_pattern/tensor_attribute_key.enum.toml new file mode 100644 index 0000000000..3df36d13ac --- /dev/null +++ b/lib/substitutions/include/substitutions/tensor_pattern/tensor_attribute_key.enum.toml @@ -0,0 +1,14 @@ +namespace = "FlexFlow" +name = "TensorAttributeKey" +features = [ + "hash", + "json", + "rapidcheck", + "fmt", +] + +[[values]] +name = "DIM_SIZES" + +[[values]] +name = "DIM_DEGREES" diff --git a/lib/substitutions/include/substitutions/tensor_pattern/tensor_attribute_list_access.dtg.h b/lib/substitutions/include/substitutions/tensor_pattern/tensor_attribute_list_access.dtg.h new file mode 100644 index 0000000000..473f4e1698 --- /dev/null +++ b/lib/substitutions/include/substitutions/tensor_pattern/tensor_attribute_list_access.dtg.h @@ -0,0 +1,66 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! 
diff --git a/lib/substitutions/include/substitutions/tensor_pattern/tensor_attribute_list_access.dtg.h b/lib/substitutions/include/substitutions/tensor_pattern/tensor_attribute_list_access.dtg.h
new file mode 100644
index 0000000000..473f4e1698
--- /dev/null
+++ b/lib/substitutions/include/substitutions/tensor_pattern/tensor_attribute_list_access.dtg.h
@@ -0,0 +1,66 @@
+// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT!
+// If you would like to modify this datatype, instead modify
+// lib/substitutions/include/substitutions/tensor_pattern/tensor_attribute_list_access.struct.toml
+/* proj-data
+{
+  "generated_from": "41f5449cd700b6d7ab017f3efa39dc1d"
+}
+*/
+
+#ifndef _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_TENSOR_PATTERN_TENSOR_ATTRIBUTE_LIST_ACCESS_DTG_H
+#define _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_TENSOR_PATTERN_TENSOR_ATTRIBUTE_LIST_ACCESS_DTG_H
+
+#include "fmt/format.h"
+#include "nlohmann/json.hpp"
+#include "rapidcheck.h"
+#include "substitutions/tensor_pattern/tensor_attribute_key.dtg.h"
+#include <functional>
+#include <ostream>
+#include <tuple>
+
+namespace FlexFlow {
+struct TensorAttributeListIndexAccess {
+  TensorAttributeListIndexAccess() = delete;
+  TensorAttributeListIndexAccess(
+      ::FlexFlow::TensorAttributeKey const &attribute_key, int const &index);
+
+  bool operator==(TensorAttributeListIndexAccess const &) const;
+  bool operator!=(TensorAttributeListIndexAccess const &) const;
+  bool operator<(TensorAttributeListIndexAccess const &) const;
+  bool operator>(TensorAttributeListIndexAccess const &) const;
+  bool operator<=(TensorAttributeListIndexAccess const &) const;
+  bool operator>=(TensorAttributeListIndexAccess const &) const;
+  ::FlexFlow::TensorAttributeKey attribute_key;
+  int index;
+};
+} // namespace FlexFlow
+
+namespace std {
+template <>
+struct hash<FlexFlow::TensorAttributeListIndexAccess> {
+  size_t operator()(FlexFlow::TensorAttributeListIndexAccess const &) const;
+};
+} // namespace std
+
+namespace nlohmann {
+template <>
+struct adl_serializer<FlexFlow::TensorAttributeListIndexAccess> {
+  static FlexFlow::TensorAttributeListIndexAccess from_json(json const &);
+  static void to_json(json &, FlexFlow::TensorAttributeListIndexAccess const &);
+};
+} // namespace nlohmann
+
+namespace rc {
+template <>
+struct Arbitrary<FlexFlow::TensorAttributeListIndexAccess> {
+  static Gen<FlexFlow::TensorAttributeListIndexAccess> arbitrary();
+};
+} // namespace rc
+
+namespace FlexFlow {
+std::string format_as(TensorAttributeListIndexAccess const &);
+std::ostream &operator<<(std::ostream &,
+                         TensorAttributeListIndexAccess const &);
+} // namespace FlexFlow
+
+#endif // _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_TENSOR_PATTERN_TENSOR_ATTRIBUTE_LIST_ACCESS_DTG_H
diff --git a/lib/substitutions/include/substitutions/tensor_pattern/tensor_attribute_list_access.struct.toml b/lib/substitutions/include/substitutions/tensor_pattern/tensor_attribute_list_access.struct.toml
new file mode 100644
index 0000000000..a57dd25845
--- /dev/null
+++ b/lib/substitutions/include/substitutions/tensor_pattern/tensor_attribute_list_access.struct.toml
@@ -0,0 +1,22 @@
+namespace = "FlexFlow"
+name = "TensorAttributeListIndexAccess"
+features = [
+  "eq",
+  "ord",
+  "hash",
+  "rapidcheck",
+  "json",
+  "fmt",
+]
+
+includes = [
+  "substitutions/tensor_pattern/tensor_attribute_key.dtg.h"
+]
+
+[[fields]]
+name = "attribute_key"
+type = "::FlexFlow::TensorAttributeKey"
+
+[[fields]]
+name = "index"
+type = "int"
diff --git a/lib/substitutions/include/substitutions/tensor_pattern/tensor_attribute_list_size.dtg.h b/lib/substitutions/include/substitutions/tensor_pattern/tensor_attribute_list_size.dtg.h
new file mode 100644
index 0000000000..1630014bdf
--- /dev/null
+++ b/lib/substitutions/include/substitutions/tensor_pattern/tensor_attribute_list_size.dtg.h
@@ -0,0 +1,63 @@
+// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT!
+// If you would like to modify this datatype, instead modify
+// lib/substitutions/include/substitutions/tensor_pattern/tensor_attribute_list_size.struct.toml
+/* proj-data
+{
+  "generated_from": "ec72cd39de5d1c0f0478696d7b83e4e9"
+}
+*/
+
+#ifndef _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_TENSOR_PATTERN_TENSOR_ATTRIBUTE_LIST_SIZE_DTG_H
+#define _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_TENSOR_PATTERN_TENSOR_ATTRIBUTE_LIST_SIZE_DTG_H
+
+#include "fmt/format.h"
+#include "nlohmann/json.hpp"
+#include "rapidcheck.h"
+#include "substitutions/tensor_pattern/tensor_attribute_key.dtg.h"
+#include <functional>
+#include <ostream>
+#include <tuple>
+
+namespace FlexFlow {
+struct TensorAttributeListSize {
+  TensorAttributeListSize() = delete;
+  TensorAttributeListSize(::FlexFlow::TensorAttributeKey const &attribute_key);
+
+  bool operator==(TensorAttributeListSize const &) const;
+  bool operator!=(TensorAttributeListSize const &) const;
+  bool operator<(TensorAttributeListSize const &) const;
+  bool operator>(TensorAttributeListSize const &) const;
+  bool operator<=(TensorAttributeListSize const &) const;
+  bool operator>=(TensorAttributeListSize const &) const;
+  ::FlexFlow::TensorAttributeKey attribute_key;
+};
+} // namespace FlexFlow
+
+namespace std {
+template <>
+struct hash<FlexFlow::TensorAttributeListSize> {
+  size_t operator()(FlexFlow::TensorAttributeListSize const &) const;
+};
+} // namespace std
+
+namespace nlohmann {
+template <>
+struct adl_serializer<FlexFlow::TensorAttributeListSize> {
+  static FlexFlow::TensorAttributeListSize from_json(json const &);
+  static void to_json(json &, FlexFlow::TensorAttributeListSize const &);
+};
+} // namespace nlohmann
+
+namespace rc {
+template <>
+struct Arbitrary<FlexFlow::TensorAttributeListSize> {
+  static Gen<FlexFlow::TensorAttributeListSize> arbitrary();
+};
+} // namespace rc
+
+namespace FlexFlow {
+std::string format_as(TensorAttributeListSize const &);
+std::ostream &operator<<(std::ostream &, TensorAttributeListSize const &);
+} // namespace FlexFlow
+
+#endif // _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_TENSOR_PATTERN_TENSOR_ATTRIBUTE_LIST_SIZE_DTG_H
diff --git a/lib/substitutions/include/substitutions/tensor_pattern/tensor_attribute_list_size.struct.toml b/lib/substitutions/include/substitutions/tensor_pattern/tensor_attribute_list_size.struct.toml
new file mode 100644
index 0000000000..c876696343
--- /dev/null
+++ b/lib/substitutions/include/substitutions/tensor_pattern/tensor_attribute_list_size.struct.toml
@@ -0,0 +1,19 @@
+namespace = "FlexFlow"
+name = "TensorAttributeListSize"
+features = [
+  "eq",
+  "ord",
+  "hash",
+  "rapidcheck",
+  "json",
+  "fmt",
+]
+
+includes = [
+  "substitutions/tensor_pattern/tensor_attribute_key.dtg.h",
+]
+
+
+[[fields]]
+name = "attribute_key"
+type = "::FlexFlow::TensorAttributeKey"
diff --git a/lib/substitutions/include/substitutions/tensor_pattern/tensor_attribute_pattern.dtg.h b/lib/substitutions/include/substitutions/tensor_pattern/tensor_attribute_pattern.dtg.h
new file mode 100644
index 0000000000..ecc4bc7da0
--- /dev/null
+++ b/lib/substitutions/include/substitutions/tensor_pattern/tensor_attribute_pattern.dtg.h
@@ -0,0 +1,56 @@
+// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT!
+// If you would like to modify this datatype, instead modify
+// lib/substitutions/include/substitutions/tensor_pattern/tensor_attribute_pattern.struct.toml
+/* proj-data
+{
+  "generated_from": "42a51afce383f1ddc3d70827aa94a68f"
+}
+*/
+
+#ifndef _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_TENSOR_PATTERN_TENSOR_ATTRIBUTE_PATTERN_DTG_H
+#define _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_TENSOR_PATTERN_TENSOR_ATTRIBUTE_PATTERN_DTG_H
+
+#include "fmt/format.h"
+#include "nlohmann/json.hpp"
+#include "substitutions/tensor_pattern/tensor_attribute_constraint.dtg.h"
+#include "utils/hash-utils.h"
+#include <functional>
+#include <ostream>
+#include <tuple>
+#include <unordered_set>
+
+namespace FlexFlow {
+struct TensorAttributePattern {
+  TensorAttributePattern() = delete;
+  TensorAttributePattern(
+      std::unordered_set<::FlexFlow::TensorAttributeConstraint> const
+          &attribute_constraints);
+
+  bool operator==(TensorAttributePattern const &) const;
+  bool operator!=(TensorAttributePattern const &) const;
+  std::unordered_set<::FlexFlow::TensorAttributeConstraint>
+      attribute_constraints;
+};
+} // namespace FlexFlow
+
+namespace std {
+template <>
+struct hash<FlexFlow::TensorAttributePattern> {
+  size_t operator()(FlexFlow::TensorAttributePattern const &) const;
+};
+} // namespace std
+
+namespace nlohmann {
+template <>
+struct adl_serializer<FlexFlow::TensorAttributePattern> {
+  static FlexFlow::TensorAttributePattern from_json(json const &);
+  static void to_json(json &, FlexFlow::TensorAttributePattern const &);
+};
+} // namespace nlohmann
+
+namespace FlexFlow {
+std::string format_as(TensorAttributePattern const &);
+std::ostream &operator<<(std::ostream &, TensorAttributePattern const &);
+} // namespace FlexFlow
+
+#endif // _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_TENSOR_PATTERN_TENSOR_ATTRIBUTE_PATTERN_DTG_H
diff --git a/lib/substitutions/include/substitutions/tensor_pattern/tensor_attribute_pattern.struct.toml b/lib/substitutions/include/substitutions/tensor_pattern/tensor_attribute_pattern.struct.toml
new file mode 100644
index 0000000000..43f45e95b9
--- /dev/null
+++ b/lib/substitutions/include/substitutions/tensor_pattern/tensor_attribute_pattern.struct.toml
@@ -0,0 +1,20 @@
+namespace = "FlexFlow"
+name = "TensorAttributePattern"
+features = [
+  "eq",
+  # "ord",
+  "hash",
+  "json",
+  # "rapidcheck",
+  "fmt",
+]
+
+includes = [
+  "<unordered_set>",
+  "substitutions/tensor_pattern/tensor_attribute_constraint.dtg.h",
+  "utils/hash-utils.h",
+]
+
+[[fields]]
+name = "attribute_constraints"
+type = "std::unordered_set<::FlexFlow::TensorAttributeConstraint>"
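A `TensorAttributePattern` is just an unordered set of constraints, so `parallel_tensor_satisfies_pattern` plausibly reduces to "every constraint holds". A sketch under that assumption (the patch only adds the declaration):

    #include <algorithm>
    #include "substitutions/tensor_pattern/satisfies_constraint.h"
    #include "substitutions/tensor_pattern/satisfies_pattern.h"

    namespace FlexFlow {

    bool parallel_tensor_satisfies_pattern(ParallelTensorAttrs const &attrs,
                                           TensorAttributePattern const &pattern) {
      // A tensor matches iff it satisfies every constraint in the set.
      return std::all_of(pattern.attribute_constraints.begin(),
                         pattern.attribute_constraints.end(),
                         [&](TensorAttributeConstraint const &c) {
                           return parallel_tensor_satisfies_constraint(attrs, c);
                         });
    }

    } // namespace FlexFlow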
diff --git a/lib/substitutions/include/substitutions/tensor_pattern/tensor_attribute_value.dtg.h b/lib/substitutions/include/substitutions/tensor_pattern/tensor_attribute_value.dtg.h
new file mode 100644
index 0000000000..948a7abae6
--- /dev/null
+++ b/lib/substitutions/include/substitutions/tensor_pattern/tensor_attribute_value.dtg.h
@@ -0,0 +1,118 @@
+// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT!
+// If you would like to modify this datatype, instead modify
+// lib/substitutions/include/substitutions/tensor_pattern/tensor_attribute_value.variant.toml
+/* proj-data
+{
+  "generated_from": "d80cf2e618d64df284c2647430a12a86"
+}
+*/
+
+#ifndef _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_TENSOR_PATTERN_TENSOR_ATTRIBUTE_VALUE_DTG_H
+#define _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_TENSOR_PATTERN_TENSOR_ATTRIBUTE_VALUE_DTG_H
+
+#include "fmt/format.h"
+#include "nlohmann/json.hpp"
+#include "utils/fmt.h"
+#include "utils/hash-utils-core.h"
+#include <cstddef>
+#include <functional>
+#include <ostream>
+#include <type_traits>
+#include <variant>
+#include <vector>
+
+namespace FlexFlow {
+struct TensorAttributeValue {
+  TensorAttributeValue() = delete;
+  explicit TensorAttributeValue(size_t const &);
+  explicit TensorAttributeValue(std::vector<size_t> const &);
+  template <typename T>
+  static constexpr bool IsPartOfTensorAttributeValue_v =
+      std::is_same_v<T, size_t> || std::is_same_v<T, std::vector<size_t>>;
+  template <typename ReturnType, typename Visitor>
+  ReturnType visit(Visitor &&v) const {
+    switch (this->index()) {
+      case 0: {
+        ReturnType result = v(this->get<size_t>());
+        return result;
+      }
+      case 1: {
+        ReturnType result = v(this->get<std::vector<size_t>>());
+        return result;
+      }
+      default: {
+        throw std::runtime_error(fmt::format(
+            "Unknown index {} for type TensorAttributeValue", this->index()));
+      }
+    }
+  }
+  template <typename ReturnType, typename Visitor>
+  ReturnType visit(Visitor &&v) {
+    switch (this->index()) {
+      case 0: {
+        ReturnType result = v(this->get<size_t>());
+        return result;
+      }
+      case 1: {
+        ReturnType result = v(this->get<std::vector<size_t>>());
+        return result;
+      }
+      default: {
+        throw std::runtime_error(fmt::format(
+            "Unknown index {} for type TensorAttributeValue", this->index()));
+      }
+    }
+  }
+  template <typename T>
+  bool has() const {
+    static_assert(IsPartOfTensorAttributeValue_v<T>,
+                  "TensorAttributeValue::has() expected one of [size_t, "
+                  "std::vector<size_t>], received T");
+    return std::holds_alternative<T>(this->raw_variant);
+  }
+  template <typename T>
+  T const &get() const {
+    static_assert(IsPartOfTensorAttributeValue_v<T>,
+                  "TensorAttributeValue::get() expected one of [size_t, "
+                  "std::vector<size_t>], received T");
+    return std::get<T>(this->raw_variant);
+  }
+  template <typename T>
+  T &get() {
+    static_assert(IsPartOfTensorAttributeValue_v<T>,
+                  "TensorAttributeValue::get() expected one of [size_t, "
+                  "std::vector<size_t>], received T");
+    return std::get<T>(this->raw_variant);
+  }
+  size_t index() const {
+    return this->raw_variant.index();
+  }
+  bool operator==(TensorAttributeValue const &) const;
+  bool operator!=(TensorAttributeValue const &) const;
+  bool operator<(TensorAttributeValue const &) const;
+  bool operator>(TensorAttributeValue const &) const;
+  bool operator<=(TensorAttributeValue const &) const;
+  bool operator>=(TensorAttributeValue const &) const;
+  std::variant<size_t, std::vector<size_t>> raw_variant;
+};
+} // namespace FlexFlow
+namespace std {
+template <>
+struct hash<::FlexFlow::TensorAttributeValue> {
+  size_t operator()(::FlexFlow::TensorAttributeValue const &) const;
+};
+} // namespace std
+namespace nlohmann {
+template <>
+struct adl_serializer<::FlexFlow::TensorAttributeValue> {
+  static ::FlexFlow::TensorAttributeValue from_json(json const &);
+  static void to_json(json &, ::FlexFlow::TensorAttributeValue const &);
+};
+} // namespace nlohmann
+namespace FlexFlow {
+std::string format_as(::FlexFlow::TensorAttributeValue const &);
+std::ostream &operator<<(std::ostream &,
+                         ::FlexFlow::TensorAttributeValue const &);
+} // namespace FlexFlow
+
+#endif // _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_TENSOR_PATTERN_TENSOR_ATTRIBUTE_VALUE_DTG_H
diff --git a/lib/substitutions/include/substitutions/tensor_pattern/tensor_attribute_value.variant.toml b/lib/substitutions/include/substitutions/tensor_pattern/tensor_attribute_value.variant.toml
new file mode 100644
index 0000000000..91313f159b
--- /dev/null
+++ b/lib/substitutions/include/substitutions/tensor_pattern/tensor_attribute_value.variant.toml
@@ -0,0 +1,21 @@
+namespace = "FlexFlow"
+name = "TensorAttributeValue"
+features = [
+  "eq",
+  "ord",
+  "hash",
+  "json",
+  "fmt",
+]
+
+includes = [
+  "<vector>",
+  "utils/hash-utils-core.h",
+  "utils/fmt.h",
+]
+
+[[values]]
+type = "size_t"
+
+[[values]]
+type = "std::vector<size_t>"
diff --git a/lib/substitutions/include/substitutions/unlabelled/closed_pattern_edge.dtg.h b/lib/substitutions/include/substitutions/unlabelled/closed_pattern_edge.dtg.h
new file mode 100644
index 0000000000..6bf815791d
--- /dev/null
+++ b/lib/substitutions/include/substitutions/unlabelled/closed_pattern_edge.dtg.h
@@ -0,0 +1,39 @@
+// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT!
+// If you would like to modify this datatype, instead modify
+// lib/substitutions/include/substitutions/unlabelled/closed_pattern_edge.struct.toml
+/* proj-data
+{
+  "generated_from": "b4086fd78ca7ec0475ed7abfd034304c"
+}
+*/
+
+#ifndef _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_UNLABELLED_CLOSED_PATTERN_EDGE_DTG_H
+#define _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_UNLABELLED_CLOSED_PATTERN_EDGE_DTG_H
+
+#include "utils/graph.h"
+#include <functional>
+#include <tuple>
+
+namespace FlexFlow {
+struct ClosedPatternEdge {
+  ClosedPatternEdge() = delete;
+  ClosedPatternEdge(::FlexFlow::MultiDiEdge const &raw_edge);
+
+  bool operator==(ClosedPatternEdge const &) const;
+  bool operator!=(ClosedPatternEdge const &) const;
+  bool operator<(ClosedPatternEdge const &) const;
+  bool operator>(ClosedPatternEdge const &) const;
+  bool operator<=(ClosedPatternEdge const &) const;
+  bool operator>=(ClosedPatternEdge const &) const;
+  ::FlexFlow::MultiDiEdge raw_edge;
+};
+} // namespace FlexFlow
+
+namespace std {
+template <>
+struct hash<FlexFlow::ClosedPatternEdge> {
+  size_t operator()(FlexFlow::ClosedPatternEdge const &) const;
+};
+} // namespace std
+
+#endif // _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_UNLABELLED_CLOSED_PATTERN_EDGE_DTG_H
diff --git a/lib/substitutions/include/substitutions/unlabelled/closed_pattern_edge.struct.toml b/lib/substitutions/include/substitutions/unlabelled/closed_pattern_edge.struct.toml
new file mode 100644
index 0000000000..d609ca1c27
--- /dev/null
+++ b/lib/substitutions/include/substitutions/unlabelled/closed_pattern_edge.struct.toml
@@ -0,0 +1,15 @@
+namespace = "FlexFlow"
+name = "ClosedPatternEdge"
+features = [
+  "eq",
+  "ord",
+  "hash",
+]
+
+includes = [
+  "utils/graph.h",
+]
+
+[[fields]]
+name = "raw_edge"
+type = "::FlexFlow::MultiDiEdge"
diff --git a/lib/substitutions/include/substitutions/unlabelled/downward_open_pattern_edge.dtg.h b/lib/substitutions/include/substitutions/unlabelled/downward_open_pattern_edge.dtg.h
new file mode 100644
index 0000000000..5ce0e63073
--- /dev/null
+++ b/lib/substitutions/include/substitutions/unlabelled/downward_open_pattern_edge.dtg.h
@@ -0,0 +1,39 @@
+// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT!
+// If you would like to modify this datatype, instead modify
+// lib/substitutions/include/substitutions/unlabelled/downward_open_pattern_edge.struct.toml
+/* proj-data
+{
+  "generated_from": "c67ec363a91ce090dc538dcf76fa1f12"
+}
+*/
+
+#ifndef _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_UNLABELLED_DOWNWARD_OPEN_PATTERN_EDGE_DTG_H
+#define _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_UNLABELLED_DOWNWARD_OPEN_PATTERN_EDGE_DTG_H
+
+#include "utils/graph.h"
+#include <functional>
+#include <tuple>
+
+namespace FlexFlow {
+struct DownwardOpenPatternEdge {
+  DownwardOpenPatternEdge() = delete;
+  DownwardOpenPatternEdge(::FlexFlow::DownwardOpenMultiDiEdge const &raw_edge);
+
+  bool operator==(DownwardOpenPatternEdge const &) const;
+  bool operator!=(DownwardOpenPatternEdge const &) const;
+  bool operator<(DownwardOpenPatternEdge const &) const;
+  bool operator>(DownwardOpenPatternEdge const &) const;
+  bool operator<=(DownwardOpenPatternEdge const &) const;
+  bool operator>=(DownwardOpenPatternEdge const &) const;
+  ::FlexFlow::DownwardOpenMultiDiEdge raw_edge;
+};
+} // namespace FlexFlow
+
+namespace std {
+template <>
+struct hash<FlexFlow::DownwardOpenPatternEdge> {
+  size_t operator()(FlexFlow::DownwardOpenPatternEdge const &) const;
+};
+} // namespace std
+
+#endif // _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_UNLABELLED_DOWNWARD_OPEN_PATTERN_EDGE_DTG_H
diff --git a/lib/substitutions/include/substitutions/unlabelled/downward_open_pattern_edge.h b/lib/substitutions/include/substitutions/unlabelled/downward_open_pattern_edge.h
new file mode 100644
index 0000000000..9855d96e46
--- /dev/null
+++ b/lib/substitutions/include/substitutions/unlabelled/downward_open_pattern_edge.h
@@ -0,0 +1,12 @@
+#ifndef _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_UNLABELLED_DOWNWARD_OPEN_PATTERN_EDGE_H
+#define _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_UNLABELLED_DOWNWARD_OPEN_PATTERN_EDGE_H
+
+#include "substitutions/unlabelled/downward_open_pattern_edge.dtg.h"
+
+namespace FlexFlow {
+
+int get_src_idx(DownwardOpenPatternEdge const &);
+
+} // namespace FlexFlow
+
+#endif
diff --git a/lib/substitutions/include/substitutions/unlabelled/downward_open_pattern_edge.struct.toml b/lib/substitutions/include/substitutions/unlabelled/downward_open_pattern_edge.struct.toml
new file mode 100644
index 0000000000..2dda7498f0
--- /dev/null
+++ b/lib/substitutions/include/substitutions/unlabelled/downward_open_pattern_edge.struct.toml
@@ -0,0 +1,15 @@
+namespace = "FlexFlow"
+name = "DownwardOpenPatternEdge"
+features = [
+  "eq",
+  "ord",
+  "hash",
+]
+
+includes = [
+  "utils/graph.h",
+]
+
+[[fields]]
+name = "raw_edge"
+type = "::FlexFlow::DownwardOpenMultiDiEdge"
diff --git a/lib/substitutions/include/substitutions/unlabelled/edge_splits.dtg.h b/lib/substitutions/include/substitutions/unlabelled/edge_splits.dtg.h
new file mode 100644
index 0000000000..e92fe547b1
--- /dev/null
+++ b/lib/substitutions/include/substitutions/unlabelled/edge_splits.dtg.h
@@ -0,0 +1,36 @@
+// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT!
+// If you would like to modify this datatype, instead modify
+// lib/substitutions/include/substitutions/unlabelled/edge_splits.struct.toml
+/* proj-data
+{
+  "generated_from": "f172b041a99f4de1d396e5d451a5e64d"
+}
+*/
+
+#ifndef _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_UNLABELLED_EDGE_SPLITS_DTG_H
+#define _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_UNLABELLED_EDGE_SPLITS_DTG_H
+
+#include "utils/bidict.h"
+#include "utils/graph.h"
+#include <tuple>
+#include <utility>
+
+namespace FlexFlow {
+struct UnlabelledPatternEdgeSplits {
+  UnlabelledPatternEdgeSplits() = delete;
+  UnlabelledPatternEdgeSplits(
+      ::FlexFlow::bidict<::FlexFlow::MultiDiEdge,
+                         std::pair<::FlexFlow::OutputMultiDiEdge,
+                                   ::FlexFlow::InputMultiDiEdge>> const
+          &unwrapped);
+
+  bool operator==(UnlabelledPatternEdgeSplits const &) const;
+  bool operator!=(UnlabelledPatternEdgeSplits const &) const;
+  ::FlexFlow::bidict<
+      ::FlexFlow::MultiDiEdge,
+      std::pair<::FlexFlow::OutputMultiDiEdge, ::FlexFlow::InputMultiDiEdge>>
+      unwrapped;
+};
+} // namespace FlexFlow
+
+#endif // _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_UNLABELLED_EDGE_SPLITS_DTG_H
diff --git a/lib/substitutions/include/substitutions/unlabelled/edge_splits.h b/lib/substitutions/include/substitutions/unlabelled/edge_splits.h
new file mode 100644
index 0000000000..58704500ac
--- /dev/null
+++ b/lib/substitutions/include/substitutions/unlabelled/edge_splits.h
@@ -0,0 +1,21 @@
+#ifndef _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_UNLABELLED_EDGE_SPLITS_H
+#define _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_UNLABELLED_EDGE_SPLITS_H
+
+#include "substitutions/unlabelled/closed_pattern_edge.dtg.h"
+#include "substitutions/unlabelled/edge_splits.dtg.h"
+#include "substitutions/unlabelled/input_pattern_edge.dtg.h"
+#include "substitutions/unlabelled/output_pattern_edge.dtg.h"
+#include <vector>
+
+namespace FlexFlow {
+
+std::pair<OutputPatternEdge, InputPatternEdge>
+    get_split_edges(UnlabelledPatternEdgeSplits const &,
+                    ClosedPatternEdge const &);
+
+std::vector<std::tuple<ClosedPatternEdge, OutputPatternEdge, InputPatternEdge>>
+    as_closed_output_input_tuples(UnlabelledPatternEdgeSplits const &);
+
+} // namespace FlexFlow
+
+#endif
diff --git a/lib/substitutions/include/substitutions/unlabelled/edge_splits.struct.toml b/lib/substitutions/include/substitutions/unlabelled/edge_splits.struct.toml
new file mode 100644
index 0000000000..fa714296c8
--- /dev/null
+++ b/lib/substitutions/include/substitutions/unlabelled/edge_splits.struct.toml
@@ -0,0 +1,15 @@
+namespace = "FlexFlow"
+name = "UnlabelledPatternEdgeSplits"
+features = [
+  "eq",
+]
+
+includes = [
+  "utils/bidict.h",
+  "utils/graph.h",
+  "<utility>",
+]
+
+[[fields]]
+name = "unwrapped"
+type = "::FlexFlow::bidict<::FlexFlow::MultiDiEdge, std::pair<::FlexFlow::OutputMultiDiEdge, ::FlexFlow::InputMultiDiEdge>>"
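When a pattern is split, each closed edge crossing the cut is divided into an output half and an input half; `UnlabelledPatternEdgeSplits` records that correspondence in both directions. A consumption sketch (the `std::pair` return type of `get_split_edges` is a reconstruction above and is assumed here):

    #include "substitutions/unlabelled/edge_splits.h"

    using namespace FlexFlow;

    void inspect_cut_edge(UnlabelledPatternEdgeSplits const &splits,
                          ClosedPatternEdge const &cut_edge) {
      // Recover the two open halves of one closed edge that was cut.
      std::pair<OutputPatternEdge, InputPatternEdge> halves =
          get_split_edges(splits, cut_edge);
      OutputPatternEdge const &source_side_half = halves.first;
      InputPatternEdge const &destination_side_half = halves.second;
      (void)source_side_half;
      (void)destination_side_half;
    }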
diff --git a/lib/substitutions/include/substitutions/unlabelled/find_pattern_matches.h b/lib/substitutions/include/substitutions/unlabelled/find_pattern_matches.h
new file mode 100644
index 0000000000..29c5740c0e
--- /dev/null
+++ b/lib/substitutions/include/substitutions/unlabelled/find_pattern_matches.h
@@ -0,0 +1,18 @@
+#ifndef _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_UNLABELLED_FIND_PATTERN_MATCHES_H
+#define _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_UNLABELLED_FIND_PATTERN_MATCHES_H
+
+#include "substitutions/unlabelled/match_additional_criterion.dtg.h"
+#include "substitutions/unlabelled/multidigraph_pattern_match.dtg.h"
+#include "substitutions/unlabelled/unlabelled_graph_pattern.dtg.h"
+#include "utils/graph.h"
+
+namespace FlexFlow {
+
+std::vector<MultiDiGraphPatternMatch>
+    find_pattern_matches(UnlabelledGraphPattern const &pattern,
+                         OpenMultiDiGraphView const &graph,
+                         MatchAdditionalCriterion const &additional_criterion);
+
+} // namespace FlexFlow
+
+#endif
diff --git a/lib/substitutions/include/substitutions/unlabelled/input_pattern_edge.dtg.h b/lib/substitutions/include/substitutions/unlabelled/input_pattern_edge.dtg.h
new file mode 100644
index 0000000000..f292acba14
--- /dev/null
+++ b/lib/substitutions/include/substitutions/unlabelled/input_pattern_edge.dtg.h
@@ -0,0 +1,39 @@
+// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT!
+// If you would like to modify this datatype, instead modify
+// lib/substitutions/include/substitutions/unlabelled/input_pattern_edge.struct.toml
+/* proj-data
+{
+  "generated_from": "d0cc0e65c4e3feb2e9b8435947c99e5f"
+}
+*/
+
+#ifndef _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_UNLABELLED_INPUT_PATTERN_EDGE_DTG_H
+#define _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_UNLABELLED_INPUT_PATTERN_EDGE_DTG_H
+
+#include "utils/graph.h"
+#include <functional>
+#include <tuple>
+
+namespace FlexFlow {
+struct InputPatternEdge {
+  InputPatternEdge() = delete;
+  InputPatternEdge(::FlexFlow::InputMultiDiEdge const &raw_edge);
+
+  bool operator==(InputPatternEdge const &) const;
+  bool operator!=(InputPatternEdge const &) const;
+  bool operator<(InputPatternEdge const &) const;
+  bool operator>(InputPatternEdge const &) const;
+  bool operator<=(InputPatternEdge const &) const;
+  bool operator>=(InputPatternEdge const &) const;
+  ::FlexFlow::InputMultiDiEdge raw_edge;
+};
+} // namespace FlexFlow
+
+namespace std {
+template <>
+struct hash<FlexFlow::InputPatternEdge> {
+  size_t operator()(FlexFlow::InputPatternEdge const &) const;
+};
+} // namespace std
+
+#endif // _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_UNLABELLED_INPUT_PATTERN_EDGE_DTG_H
diff --git a/lib/substitutions/include/substitutions/unlabelled/input_pattern_edge.h b/lib/substitutions/include/substitutions/unlabelled/input_pattern_edge.h
new file mode 100644
index 0000000000..b05fa479db
--- /dev/null
+++ b/lib/substitutions/include/substitutions/unlabelled/input_pattern_edge.h
@@ -0,0 +1,13 @@
+#ifndef _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_UNLABELLED_INPUT_PATTERN_EDGE_H
+#define _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_UNLABELLED_INPUT_PATTERN_EDGE_H
+
+#include "substitutions/unlabelled/input_pattern_edge.dtg.h"
+#include "substitutions/unlabelled/pattern_node.dtg.h"
+
+namespace FlexFlow {
+
+PatternNode get_dst_node(InputPatternEdge const &);
+
+} // namespace FlexFlow
+
+#endif
diff --git a/lib/substitutions/include/substitutions/unlabelled/input_pattern_edge.struct.toml b/lib/substitutions/include/substitutions/unlabelled/input_pattern_edge.struct.toml
new file mode 100644
index 0000000000..6da52b58aa
--- /dev/null
+++ b/lib/substitutions/include/substitutions/unlabelled/input_pattern_edge.struct.toml
@@ -0,0 +1,15 @@
+namespace = "FlexFlow"
+name = "InputPatternEdge"
+features = [
+  "eq",
+  "ord",
+  "hash",
+]
+
+includes = [
+  "utils/graph.h"
+]
+
+[[fields]]
+name = "raw_edge"
+type = "::FlexFlow::InputMultiDiEdge"
diff --git a/lib/substitutions/include/substitutions/unlabelled/match_additional_criterion.dtg.h b/lib/substitutions/include/substitutions/unlabelled/match_additional_criterion.dtg.h
new file mode 100644
index 0000000000..e910be21ba
--- /dev/null
+++ b/lib/substitutions/include/substitutions/unlabelled/match_additional_criterion.dtg.h
@@ -0,0 +1,36 @@
+// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT!
+// If you would like to modify this datatype, instead modify
+// lib/substitutions/include/substitutions/unlabelled/match_additional_criterion.struct.toml
+/* proj-data
+{
+  "generated_from": "2dff356c85dccda1fce8f714d41c6202"
+}
+*/
+
+#ifndef _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_UNLABELLED_MATCH_ADDITIONAL_CRITERION_DTG_H
+#define _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_UNLABELLED_MATCH_ADDITIONAL_CRITERION_DTG_H
+
+#include "substitutions/unlabelled/pattern_edge.dtg.h"
+#include "substitutions/unlabelled/pattern_node.dtg.h"
+#include "utils/graph.h"
+#include <functional>
+
+namespace FlexFlow {
+struct MatchAdditionalCriterion {
+  MatchAdditionalCriterion() = delete;
+  MatchAdditionalCriterion(
+      std::function<bool(::FlexFlow::PatternNode const &,
+                         ::FlexFlow::Node const &)> const &node_criterion,
+      std::function<bool(::FlexFlow::PatternEdge const &,
+                         ::FlexFlow::OpenMultiDiEdge const &)> const
+          &edge_criterion);
+
+  std::function<bool(::FlexFlow::PatternNode const &, ::FlexFlow::Node const &)>
+      node_criterion;
+  std::function<bool(::FlexFlow::PatternEdge const &,
+                     ::FlexFlow::OpenMultiDiEdge const &)>
+      edge_criterion;
+};
+} // namespace FlexFlow
+
+#endif // _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_UNLABELLED_MATCH_ADDITIONAL_CRITERION_DTG_H
diff --git a/lib/substitutions/include/substitutions/unlabelled/match_additional_criterion.struct.toml b/lib/substitutions/include/substitutions/unlabelled/match_additional_criterion.struct.toml
new file mode 100644
index 0000000000..c0107d84e9
--- /dev/null
+++ b/lib/substitutions/include/substitutions/unlabelled/match_additional_criterion.struct.toml
@@ -0,0 +1,18 @@
+namespace = "FlexFlow"
+name = "MatchAdditionalCriterion"
+features = []
+
+includes = [
+  "<functional>",
+  "utils/graph.h",
+  "substitutions/unlabelled/pattern_node.dtg.h",
+  "substitutions/unlabelled/pattern_edge.dtg.h",
+]
+
+[[fields]]
+name = "node_criterion"
+type = "std::function<bool(::FlexFlow::PatternNode const &, ::FlexFlow::Node const &)>"
+
+[[fields]]
+name = "edge_criterion"
+type = "std::function<bool(::FlexFlow::PatternEdge const &, ::FlexFlow::OpenMultiDiEdge const &)>"
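`MatchAdditionalCriterion` carries the label-level callbacks that the unlabelled matcher consults on top of pure structure. The simplest instance accepts every assignment, reducing matching to graph shape alone (a sketch; the `std::function` signatures are the reconstructions above):

    #include "substitutions/unlabelled/match_additional_criterion.dtg.h"

    using namespace FlexFlow;

    // Accept any node and edge assignment: purely structural matching.
    MatchAdditionalCriterion structural_only() {
      return MatchAdditionalCriterion{
          [](PatternNode const &, Node const &) { return true; },
          [](PatternEdge const &, OpenMultiDiEdge const &) { return true; },
      };
    }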
diff --git a/lib/substitutions/include/substitutions/unlabelled/match_split.dtg.h b/lib/substitutions/include/substitutions/unlabelled/match_split.dtg.h
new file mode 100644
index 0000000000..aa17814c52
--- /dev/null
+++ b/lib/substitutions/include/substitutions/unlabelled/match_split.dtg.h
@@ -0,0 +1,29 @@
+// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT!
+// If you would like to modify this datatype, instead modify
+// lib/substitutions/include/substitutions/unlabelled/match_split.struct.toml
+/* proj-data
+{
+  "generated_from": "e44c4347e07263a493cbbd5caccedd22"
+}
+*/
+
+#ifndef _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_UNLABELLED_MATCH_SPLIT_DTG_H
+#define _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_UNLABELLED_MATCH_SPLIT_DTG_H
+
+#include "substitutions/unlabelled/multidigraph_pattern_match.dtg.h"
+#include <tuple>
+
+namespace FlexFlow {
+struct MatchSplit {
+  MatchSplit() = delete;
+  MatchSplit(MultiDiGraphPatternMatch const &prefix_submatch,
+             MultiDiGraphPatternMatch const &postfix_submatch);
+
+  bool operator==(MatchSplit const &) const;
+  bool operator!=(MatchSplit const &) const;
+  MultiDiGraphPatternMatch prefix_submatch;
+  MultiDiGraphPatternMatch postfix_submatch;
+};
+} // namespace FlexFlow
+
+#endif // _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_UNLABELLED_MATCH_SPLIT_DTG_H
diff --git a/lib/substitutions/include/substitutions/unlabelled/match_split.h b/lib/substitutions/include/substitutions/unlabelled/match_split.h
new file mode 100644
index 0000000000..a23bc3f89a
--- /dev/null
+++ b/lib/substitutions/include/substitutions/unlabelled/match_split.h
@@ -0,0 +1,18 @@
+#ifndef _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_UNLABELLED_MATCH_SPLIT_H
+#define _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_UNLABELLED_MATCH_SPLIT_H
+
+#include "substitutions/unlabelled/match_split.dtg.h"
+#include "substitutions/unlabelled/multidigraph_pattern_match.dtg.h"
+#include "substitutions/unlabelled/pattern_split.dtg.h"
+#include "substitutions/unlabelled/unlabelled_graph_pattern.dtg.h"
+
+namespace FlexFlow {
+
+MatchSplit empty_match_split();
+MatchSplit apply_split(UnlabelledGraphPattern const &pattern,
+                       MultiDiGraphPatternMatch const &match,
+                       PatternSplit const &split);
+
+} // namespace FlexFlow
+
+#endif
diff --git a/lib/substitutions/include/substitutions/unlabelled/match_split.struct.toml b/lib/substitutions/include/substitutions/unlabelled/match_split.struct.toml
new file mode 100644
index 0000000000..3fd77e7b4a
--- /dev/null
+++ b/lib/substitutions/include/substitutions/unlabelled/match_split.struct.toml
@@ -0,0 +1,18 @@
+namespace = "FlexFlow"
+name = "MatchSplit"
+features = [
+  "eq",
+  # "ord",
+]
+
+includes = [
+  "substitutions/unlabelled/multidigraph_pattern_match.dtg.h"
+]
+
+[[fields]]
+name = "prefix_submatch"
+type = "MultiDiGraphPatternMatch"
+
+[[fields]]
+name = "postfix_submatch"
+type = "MultiDiGraphPatternMatch"
diff --git a/lib/substitutions/include/substitutions/unlabelled/multidigraph_pattern_match.dtg.h b/lib/substitutions/include/substitutions/unlabelled/multidigraph_pattern_match.dtg.h
new file mode 100644
index 0000000000..30f81504fe
--- /dev/null
+++ b/lib/substitutions/include/substitutions/unlabelled/multidigraph_pattern_match.dtg.h
@@ -0,0 +1,36 @@
+// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT!
+// If you would like to modify this datatype, instead modify
+// lib/substitutions/include/substitutions/unlabelled/multidigraph_pattern_match.struct.toml
+/* proj-data
+{
+  "generated_from": "9842661a5d4e7d717f12d2c27da7df0d"
+}
+*/
+
+#ifndef _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_UNLABELLED_MULTIDIGRAPH_PATTERN_MATCH_DTG_H
+#define _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_UNLABELLED_MULTIDIGRAPH_PATTERN_MATCH_DTG_H
+
+#include "substitutions/unlabelled/pattern_edge.dtg.h"
+#include "substitutions/unlabelled/pattern_node.dtg.h"
+#include "utils/bidict.h"
+#include "utils/graph.h"
+#include <tuple>
+
+namespace FlexFlow {
+struct MultiDiGraphPatternMatch {
+  MultiDiGraphPatternMatch() = delete;
+  MultiDiGraphPatternMatch(
+      ::FlexFlow::bidict<::FlexFlow::PatternNode, ::FlexFlow::Node> const
+          &node_assignment,
+      ::FlexFlow::bidict<::FlexFlow::PatternEdge,
+                         ::FlexFlow::OpenMultiDiEdge> const &edge_assignment);
+
+  bool operator==(MultiDiGraphPatternMatch const &) const;
+  bool operator!=(MultiDiGraphPatternMatch const &) const;
+  ::FlexFlow::bidict<::FlexFlow::PatternNode, ::FlexFlow::Node> node_assignment;
+  ::FlexFlow::bidict<::FlexFlow::PatternEdge, ::FlexFlow::OpenMultiDiEdge>
+      edge_assignment;
+};
+} // namespace FlexFlow
+
+#endif // _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_UNLABELLED_MULTIDIGRAPH_PATTERN_MATCH_DTG_H
diff --git a/lib/substitutions/include/substitutions/unlabelled/multidigraph_pattern_match.h b/lib/substitutions/include/substitutions/unlabelled/multidigraph_pattern_match.h
new file mode 100644
index 0000000000..aacae6d42a
--- /dev/null
+++ b/lib/substitutions/include/substitutions/unlabelled/multidigraph_pattern_match.h
@@ -0,0 +1,17 @@
+#ifndef _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_UNLABELLED_MULTIDIGRAPH_PATTERN_MATCH_H
+#define _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_UNLABELLED_MULTIDIGRAPH_PATTERN_MATCH_H
+
+#include "substitutions/unlabelled/edge_splits.dtg.h"
+#include "substitutions/unlabelled/multidigraph_pattern_match.dtg.h"
+
+namespace FlexFlow {
+
+MultiDiGraphPatternMatch empty_multidigraph_pattern_match();
+std::optional<MultiDiGraphPatternMatch>
+    unsplit_matches(MultiDiGraphPatternMatch const &prefix,
+                    MultiDiGraphPatternMatch const &postfix,
+                    UnlabelledPatternEdgeSplits const &edge_splits);
+
+} // namespace FlexFlow
+
+#endif
diff --git a/lib/substitutions/include/substitutions/unlabelled/multidigraph_pattern_match.struct.toml b/lib/substitutions/include/substitutions/unlabelled/multidigraph_pattern_match.struct.toml
new file mode 100644
index 0000000000..778767ab62
--- /dev/null
+++ b/lib/substitutions/include/substitutions/unlabelled/multidigraph_pattern_match.struct.toml
@@ -0,0 +1,24 @@
+namespace = "FlexFlow"
+# TODO(@lockshaw): rename to UnlabelledGraphPatternMatch
+name = "MultiDiGraphPatternMatch"
+features = [
+  "eq",
+  # "ord",
+  # "hash",
+  # "fmt",
+]
+
+includes = [
+  "utils/bidict.h",
+  "utils/graph.h",
+  "substitutions/unlabelled/pattern_edge.dtg.h",
+  "substitutions/unlabelled/pattern_node.dtg.h",
+]
+
+[[fields]]
+name = "node_assignment"
+type = "::FlexFlow::bidict<::FlexFlow::PatternNode, ::FlexFlow::Node>"
+
+[[fields]]
+name = "edge_assignment"
+type = "::FlexFlow::bidict<::FlexFlow::PatternEdge, ::FlexFlow::OpenMultiDiEdge>"
diff --git a/lib/substitutions/include/substitutions/unlabelled/output_pattern_edge.dtg.h b/lib/substitutions/include/substitutions/unlabelled/output_pattern_edge.dtg.h
new file mode 100644
index 0000000000..04ec8c656d
--- /dev/null
+++ b/lib/substitutions/include/substitutions/unlabelled/output_pattern_edge.dtg.h
@@ -0,0 +1,39 @@
+// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT!
+// If you would like to modify this datatype, instead modify
+// lib/substitutions/include/substitutions/unlabelled/output_pattern_edge.struct.toml
+/* proj-data
+{
+  "generated_from": "3222696e351c3e203e008714245c737f"
+}
+*/
+
+#ifndef _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_UNLABELLED_OUTPUT_PATTERN_EDGE_DTG_H
+#define _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_UNLABELLED_OUTPUT_PATTERN_EDGE_DTG_H
+
+#include "utils/graph.h"
+#include <functional>
+#include <tuple>
+
+namespace FlexFlow {
+struct OutputPatternEdge {
+  OutputPatternEdge() = delete;
+  OutputPatternEdge(::FlexFlow::OutputMultiDiEdge const &raw_edge);
+
+  bool operator==(OutputPatternEdge const &) const;
+  bool operator!=(OutputPatternEdge const &) const;
+  bool operator<(OutputPatternEdge const &) const;
+  bool operator>(OutputPatternEdge const &) const;
+  bool operator<=(OutputPatternEdge const &) const;
+  bool operator>=(OutputPatternEdge const &) const;
+  ::FlexFlow::OutputMultiDiEdge raw_edge;
+};
+} // namespace FlexFlow
+
+namespace std {
+template <>
+struct hash<FlexFlow::OutputPatternEdge> {
+  size_t operator()(FlexFlow::OutputPatternEdge const &) const;
+};
+} // namespace std
+
+#endif // _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_UNLABELLED_OUTPUT_PATTERN_EDGE_DTG_H
diff --git a/lib/substitutions/include/substitutions/unlabelled/output_pattern_edge.h b/lib/substitutions/include/substitutions/unlabelled/output_pattern_edge.h
new file mode 100644
index 0000000000..72e8ff02cf
--- /dev/null
+++ b/lib/substitutions/include/substitutions/unlabelled/output_pattern_edge.h
@@ -0,0 +1,13 @@
+#ifndef _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_UNLABELLED_OUTPUT_PATTERN_EDGE_H
+#define _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_UNLABELLED_OUTPUT_PATTERN_EDGE_H
+
+#include "substitutions/unlabelled/output_pattern_edge.dtg.h"
+#include "substitutions/unlabelled/pattern_node.dtg.h"
+
+namespace FlexFlow {
+
+PatternNode get_src_node(OutputPatternEdge const &);
+
+} // namespace FlexFlow
+
+#endif
diff --git a/lib/substitutions/include/substitutions/unlabelled/output_pattern_edge.struct.toml b/lib/substitutions/include/substitutions/unlabelled/output_pattern_edge.struct.toml
new file mode 100644
index 0000000000..362cbc3265
--- /dev/null
+++ b/lib/substitutions/include/substitutions/unlabelled/output_pattern_edge.struct.toml
@@ -0,0 +1,15 @@
+namespace = "FlexFlow"
+name = "OutputPatternEdge"
+features = [
+  "eq",
+  "ord",
+  "hash",
+]
+
+includes = [
+  "utils/graph.h",
+]
+
+[[fields]]
+name = "raw_edge"
+type = "::FlexFlow::OutputMultiDiEdge"
diff --git a/lib/substitutions/include/substitutions/unlabelled/pattern_edge.dtg.h b/lib/substitutions/include/substitutions/unlabelled/pattern_edge.dtg.h
new file mode 100644
index 0000000000..4883590130
--- /dev/null
+++ b/lib/substitutions/include/substitutions/unlabelled/pattern_edge.dtg.h
@@ -0,0 +1,39 @@
+// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT!
+// If you would like to modify this datatype, instead modify
+// lib/substitutions/include/substitutions/unlabelled/pattern_edge.struct.toml
+/* proj-data
+{
+  "generated_from": "a3eff166b0c8be2ddf3f7305eec094fd"
+}
+*/
+
+#ifndef _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_UNLABELLED_PATTERN_EDGE_DTG_H
+#define _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_UNLABELLED_PATTERN_EDGE_DTG_H
+
+#include "utils/graph.h"
+#include <functional>
+#include <tuple>
+
+namespace FlexFlow {
+struct PatternEdge {
+  PatternEdge() = delete;
+  PatternEdge(::FlexFlow::OpenMultiDiEdge const &raw_edge);
+
+  bool operator==(PatternEdge const &) const;
+  bool operator!=(PatternEdge const &) const;
+  bool operator<(PatternEdge const &) const;
+  bool operator>(PatternEdge const &) const;
+  bool operator<=(PatternEdge const &) const;
+  bool operator>=(PatternEdge const &) const;
+  ::FlexFlow::OpenMultiDiEdge raw_edge;
+};
+} // namespace FlexFlow
+
+namespace std {
+template <>
+struct hash<FlexFlow::PatternEdge> {
+  size_t operator()(FlexFlow::PatternEdge const &) const;
+};
+} // namespace std
+
+#endif // _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_UNLABELLED_PATTERN_EDGE_DTG_H
diff --git a/lib/substitutions/include/substitutions/unlabelled/pattern_edge.h b/lib/substitutions/include/substitutions/unlabelled/pattern_edge.h
new file mode 100644
index 0000000000..79db533d4e
--- /dev/null
+++ b/lib/substitutions/include/substitutions/unlabelled/pattern_edge.h
@@ -0,0 +1,27 @@
+#ifndef _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_PATTERN_EDGE_H
+#define _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_PATTERN_EDGE_H
+
+#include "substitutions/unlabelled/closed_pattern_edge.dtg.h"
+#include "substitutions/unlabelled/input_pattern_edge.dtg.h"
+#include "substitutions/unlabelled/output_pattern_edge.dtg.h"
+#include "substitutions/unlabelled/pattern_edge.dtg.h"
+#include "substitutions/unlabelled/pattern_node.dtg.h"
+
+namespace FlexFlow {
+
+std::unordered_set<PatternNode> get_nodes(PatternEdge const &);
+bool is_closed_edge(PatternEdge const &);
+bool is_input_edge(PatternEdge const &);
+bool is_output_edge(PatternEdge const &);
+
+ClosedPatternEdge require_closed_edge(PatternEdge const &);
+InputPatternEdge require_input_edge(PatternEdge const &);
+OutputPatternEdge require_output_edge(PatternEdge const &);
+
+PatternEdge pattern_edge_from_input_edge(InputPatternEdge const &);
+PatternEdge pattern_edge_from_output_edge(OutputPatternEdge const &);
+PatternEdge pattern_edge_from_closed_edge(ClosedPatternEdge const &);
+
+} // namespace FlexFlow
+
+#endif
diff --git a/lib/substitutions/include/substitutions/unlabelled/pattern_edge.struct.toml b/lib/substitutions/include/substitutions/unlabelled/pattern_edge.struct.toml
new file mode 100644
index 0000000000..4abfa1c0db
--- /dev/null
+++ b/lib/substitutions/include/substitutions/unlabelled/pattern_edge.struct.toml
@@ -0,0 +1,15 @@
+namespace = "FlexFlow"
+name = "PatternEdge"
+features = [
+  "eq",
+  "ord",
+  "hash",
+]
+
+includes = [
+  "utils/graph.h",
+]
+
+[[fields]]
+name = "raw_edge"
+type = "::FlexFlow::OpenMultiDiEdge"
diff --git a/lib/substitutions/include/substitutions/unlabelled/pattern_matching.h b/lib/substitutions/include/substitutions/unlabelled/pattern_matching.h
new file mode 100644
index 0000000000..223886b411
--- /dev/null
+++ b/lib/substitutions/include/substitutions/unlabelled/pattern_matching.h
@@ -0,0 +1,25 @@
+#ifndef _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_UNLABELLED_PATTERN_MATCHING_H
+#define _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_UNLABELLED_PATTERN_MATCHING_H
+
+#include "substitutions/unlabelled/match_additional_criterion.dtg.h"
+#include "substitutions/unlabelled/match_split.dtg.h"
+#include "substitutions/unlabelled/multidigraph_pattern_match.dtg.h"
+#include "substitutions/unlabelled/unlabelled_graph_pattern.dtg.h"
+#include "utils/graph.h"
+
+namespace FlexFlow {
+
+bool unlabelled_pattern_does_match(
+    UnlabelledGraphPattern const &pattern,
+    OpenMultiDiGraphView const &graph,
+    MultiDiGraphPatternMatch const &match,
+    MatchAdditionalCriterion const &additional_criterion);
+
+std::vector<MultiDiGraphPatternMatch>
+    find_pattern_matches(UnlabelledGraphPattern const &pattern,
+                         OpenMultiDiGraphView const &graph,
+                         MatchAdditionalCriterion const &additional_criterion);
+
+} // namespace FlexFlow
+
+#endif
diff --git a/lib/substitutions/include/substitutions/unlabelled/pattern_node.dtg.h b/lib/substitutions/include/substitutions/unlabelled/pattern_node.dtg.h
new file mode 100644
index 0000000000..56471c2e08
--- /dev/null
+++ b/lib/substitutions/include/substitutions/unlabelled/pattern_node.dtg.h
@@ -0,0 +1,39 @@
+// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT!
+// If you would like to modify this datatype, instead modify
+// lib/substitutions/include/substitutions/unlabelled/pattern_node.struct.toml
+/* proj-data
+{
+  "generated_from": "a0e58ade010a9b250d2c1c378fde2639"
+}
+*/
+
+#ifndef _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_UNLABELLED_PATTERN_NODE_DTG_H
+#define _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_UNLABELLED_PATTERN_NODE_DTG_H
+
+#include "utils/graph.h"
+#include <functional>
+#include <tuple>
+
+namespace FlexFlow {
+struct PatternNode {
+  PatternNode() = delete;
+  PatternNode(::FlexFlow::Node const &raw_node);
+
+  bool operator==(PatternNode const &) const;
+  bool operator!=(PatternNode const &) const;
+  bool operator<(PatternNode const &) const;
+  bool operator>(PatternNode const &) const;
+  bool operator<=(PatternNode const &) const;
+  bool operator>=(PatternNode const &) const;
+  ::FlexFlow::Node raw_node;
+};
+} // namespace FlexFlow
+
+namespace std {
+template <>
+struct hash<FlexFlow::PatternNode> {
+  size_t operator()(FlexFlow::PatternNode const &) const;
+};
+} // namespace std
+
+#endif // _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_UNLABELLED_PATTERN_NODE_DTG_H
diff --git a/lib/substitutions/include/substitutions/unlabelled/pattern_node.struct.toml b/lib/substitutions/include/substitutions/unlabelled/pattern_node.struct.toml
new file mode 100644
index 0000000000..ecd0253516
--- /dev/null
+++ b/lib/substitutions/include/substitutions/unlabelled/pattern_node.struct.toml
@@ -0,0 +1,15 @@
+namespace = "FlexFlow"
+name = "PatternNode"
+features = [
+  "eq",
+  "ord",
+  "hash",
+]
+
+includes = [
+  "utils/graph.h",
+]
+
+[[fields]]
+name = "raw_node"
+type = "::FlexFlow::Node"
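`find_pattern_matches` enumerates candidate assignments and `unlabelled_pattern_does_match` re-checks a single one, so the two can cross-validate each other. A usage sketch, assuming the pattern, graph, and criterion come from elsewhere:

    #include <cassert>
    #include "substitutions/unlabelled/pattern_matching.h"

    using namespace FlexFlow;

    void check_all_matches(UnlabelledGraphPattern const &pattern,
                           OpenMultiDiGraphView const &graph,
                           MatchAdditionalCriterion const &criterion) {
      // Every match the finder returns should itself satisfy the matcher.
      for (MultiDiGraphPatternMatch const &m :
           find_pattern_matches(pattern, graph, criterion)) {
        assert(unlabelled_pattern_does_match(pattern, graph, m, criterion));
      }
    }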
diff --git a/lib/substitutions/include/substitutions/unlabelled/pattern_split.dtg.h b/lib/substitutions/include/substitutions/unlabelled/pattern_split.dtg.h
new file mode 100644
index 0000000000..453c4020a8
--- /dev/null
+++ b/lib/substitutions/include/substitutions/unlabelled/pattern_split.dtg.h
@@ -0,0 +1,47 @@
+// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT!
+// If you would like to modify this datatype, instead modify
+// lib/substitutions/include/substitutions/unlabelled/pattern_split.struct.toml
+/* proj-data
+{
+  "generated_from": "8604edb5bd1a546ffa94ef496888e46d"
+}
+*/
+
+#ifndef _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_UNLABELLED_PATTERN_SPLIT_DTG_H
+#define _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_UNLABELLED_PATTERN_SPLIT_DTG_H
+
+#include "fmt/format.h"
+#include "nlohmann/json.hpp"
+#include "substitutions/unlabelled/pattern_node.dtg.h"
+#include "utils/graph.h"
+#include <ostream>
+#include <string>
+#include <unordered_set>
+
+namespace FlexFlow {
+struct PatternSplit {
+  PatternSplit() = delete;
+  PatternSplit(std::unordered_set<::FlexFlow::PatternNode> const &first,
+               std::unordered_set<::FlexFlow::PatternNode> const &second);
+
+  bool operator==(PatternSplit const &) const;
+  bool operator!=(PatternSplit const &) const;
+  std::unordered_set<::FlexFlow::PatternNode> first;
+  std::unordered_set<::FlexFlow::PatternNode> second;
+};
+} // namespace FlexFlow
+
+namespace nlohmann {
+template <>
+struct adl_serializer<FlexFlow::PatternSplit> {
+  static FlexFlow::PatternSplit from_json(json const &);
+  static void to_json(json &, FlexFlow::PatternSplit const &);
+};
+} // namespace nlohmann
+
+namespace FlexFlow {
+std::string format_as(PatternSplit const &);
+std::ostream &operator<<(std::ostream &, PatternSplit const &);
+} // namespace FlexFlow
+
+#endif // _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_UNLABELLED_PATTERN_SPLIT_DTG_H
diff --git a/lib/substitutions/include/substitutions/unlabelled/pattern_split.h b/lib/substitutions/include/substitutions/unlabelled/pattern_split.h
new file mode 100644
index 0000000000..3fcc5cb12f
--- /dev/null
+++ b/lib/substitutions/include/substitutions/unlabelled/pattern_split.h
@@ -0,0 +1,23 @@
+#ifndef _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_UNLABELLED_PATTERN_SPLIT_H
+#define _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_UNLABELLED_PATTERN_SPLIT_H
+
+#include "substitutions/unlabelled/edge_splits.dtg.h"
+#include "substitutions/unlabelled/pattern_split.dtg.h"
+#include "substitutions/unlabelled/unlabelled_graph_pattern.dtg.h"
+
+namespace FlexFlow {
+
+PatternSplit find_even_split(UnlabelledGraphPattern const &);
+
+GraphSplit get_raw_split(PatternSplit const &);
+
+UnlabelledPatternEdgeSplits
+    get_edge_splits(UnlabelledGraphPattern const &pattern,
+                    PatternSplit const &split);
+
+std::pair<UnlabelledGraphPattern, UnlabelledGraphPattern>
+    apply_split(UnlabelledGraphPattern const &, PatternSplit const &);
+
+} // namespace FlexFlow
+
+#endif
diff --git a/lib/substitutions/include/substitutions/unlabelled/pattern_split.struct.toml b/lib/substitutions/include/substitutions/unlabelled/pattern_split.struct.toml
new file mode 100644
index 0000000000..04d1080ff7
--- /dev/null
+++ b/lib/substitutions/include/substitutions/unlabelled/pattern_split.struct.toml
@@ -0,0 +1,22 @@
+namespace = "FlexFlow"
+name = "PatternSplit"
+features = [
+  "eq",
+  # "ord",
+  "json",
+  "fmt",
+]
+
+includes = [
+  "utils/graph.h",
+  "<unordered_set>",
+  "substitutions/unlabelled/pattern_node.dtg.h",
+]
+
+[[fields]]
+name = "first"
+type = "std::unordered_set<::FlexFlow::PatternNode>"
+
+[[fields]]
+name = "second"
+type = "std::unordered_set<::FlexFlow::PatternNode>"
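pattern_split.h is the backbone of the divide-and-conquer matcher: a pattern is halved with find_even_split and apply_split until it bottoms out at single-node patterns (is_singleton_pattern, declared in unlabelled_graph_pattern.h below). A rough sketch of that recursion shape, where count_nodes_via_split is illustrative only and not the patch's matcher:

    #include "substitutions/unlabelled/pattern_split.h"
    #include "substitutions/unlabelled/unlabelled_graph_pattern.h"
    using namespace FlexFlow;

    size_t count_nodes_via_split(UnlabelledGraphPattern const &pattern) {
      if (is_singleton_pattern(pattern)) {
        return 1;  // base case: a single-node pattern is handled directly
      }
      PatternSplit split = find_even_split(pattern);
      std::pair<UnlabelledGraphPattern, UnlabelledGraphPattern> halves =
          apply_split(pattern, split);
      // Each half is solved independently and the results recombined,
      // mirroring how find_pattern_matches unsplits its submatches.
      return count_nodes_via_split(halves.first) +
             count_nodes_via_split(halves.second);
    }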
diff --git a/lib/substitutions/include/substitutions/unlabelled/unlabelled_graph_pattern.dtg.h b/lib/substitutions/include/substitutions/unlabelled/unlabelled_graph_pattern.dtg.h
new file mode 100644
index 0000000000..a2ba6c26d2
--- /dev/null
+++ b/lib/substitutions/include/substitutions/unlabelled/unlabelled_graph_pattern.dtg.h
@@ -0,0 +1,24 @@
+// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT!
+// If you would like to modify this datatype, instead modify
+// lib/substitutions/include/substitutions/unlabelled/unlabelled_graph_pattern.struct.toml
+/* proj-data
+{
+  "generated_from": "f494ed79eb1ba4010155e456b452157f"
+}
+*/
+
+#ifndef _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_UNLABELLED_UNLABELLED_GRAPH_PATTERN_DTG_H
+#define _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_UNLABELLED_UNLABELLED_GRAPH_PATTERN_DTG_H
+
+#include "utils/graph.h"
+
+namespace FlexFlow {
+struct UnlabelledGraphPattern {
+  UnlabelledGraphPattern() = delete;
+  UnlabelledGraphPattern(::FlexFlow::OpenMultiDiGraphView const &raw_graph);
+
+  ::FlexFlow::OpenMultiDiGraphView raw_graph;
+};
+} // namespace FlexFlow
+
+#endif // _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_UNLABELLED_UNLABELLED_GRAPH_PATTERN_DTG_H
diff --git a/lib/substitutions/include/substitutions/unlabelled/unlabelled_graph_pattern.h b/lib/substitutions/include/substitutions/unlabelled/unlabelled_graph_pattern.h
new file mode 100644
index 0000000000..9bb63037be
--- /dev/null
+++ b/lib/substitutions/include/substitutions/unlabelled/unlabelled_graph_pattern.h
@@ -0,0 +1,29 @@
+#ifndef _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_UNLABELLED_UNLABELLED_GRAPH_PATTERN_H
+#define _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_UNLABELLED_UNLABELLED_GRAPH_PATTERN_H
+
+#include "substitutions/unlabelled/downward_open_pattern_edge.dtg.h"
+#include "substitutions/unlabelled/pattern_edge.dtg.h"
+#include "substitutions/unlabelled/pattern_node.dtg.h"
+#include "substitutions/unlabelled/unlabelled_graph_pattern.dtg.h"
+#include "substitutions/unlabelled/upward_open_pattern_edge.dtg.h"
+
+namespace FlexFlow {
+
+size_t num_nodes(UnlabelledGraphPattern const &);
+bool is_singleton_pattern(UnlabelledGraphPattern const &);
+std::unordered_set<PatternNode> get_nodes(UnlabelledGraphPattern const &);
+std::unordered_set<PatternEdge> get_edges(UnlabelledGraphPattern const &);
+std::vector<PatternNode>
+    get_topological_ordering(UnlabelledGraphPattern const &);
+
+std::unordered_set<UpwardOpenPatternEdge>
+    get_incoming_edges(UnlabelledGraphPattern const &, PatternNode const &);
+std::unordered_set<DownwardOpenPatternEdge>
+    get_outgoing_edges(UnlabelledGraphPattern const &, PatternNode const &);
+
+UnlabelledGraphPattern get_subgraph(UnlabelledGraphPattern const &,
+                                    std::unordered_set<PatternNode> const &);
+
+} // namespace FlexFlow
+
+#endif
diff --git a/lib/substitutions/include/substitutions/unlabelled/unlabelled_graph_pattern.struct.toml b/lib/substitutions/include/substitutions/unlabelled/unlabelled_graph_pattern.struct.toml
new file mode 100644
index 0000000000..03f4bd5523
--- /dev/null
+++ b/lib/substitutions/include/substitutions/unlabelled/unlabelled_graph_pattern.struct.toml
@@ -0,0 +1,10 @@
+namespace = "FlexFlow"
+name = "UnlabelledGraphPattern"
+features = []
+includes = [
+  "utils/graph.h"
+]
+
+[[fields]]
+name = "raw_graph"
+type = "::FlexFlow::OpenMultiDiGraphView"
diff --git a/lib/substitutions/include/substitutions/unlabelled/upward_open_pattern_edge.dtg.h b/lib/substitutions/include/substitutions/unlabelled/upward_open_pattern_edge.dtg.h
new file mode 100644
index 0000000000..82440b5820
--- /dev/null
+++ b/lib/substitutions/include/substitutions/unlabelled/upward_open_pattern_edge.dtg.h
@@ -0,0 +1,39 @@
+// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT!
+// If you would like to modify this datatype, instead modify +// lib/substitutions/include/substitutions/unlabelled/upward_open_pattern_edge.struct.toml +/* proj-data +{ + "generated_from": "a1d4c9d1dd94eb456c5e29d80ad579da" +} +*/ + +#ifndef _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_UNLABELLED_UPWARD_OPEN_PATTERN_EDGE_DTG_H +#define _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_UNLABELLED_UPWARD_OPEN_PATTERN_EDGE_DTG_H + +#include "utils/graph.h" +#include +#include + +namespace FlexFlow { +struct UpwardOpenPatternEdge { + UpwardOpenPatternEdge() = delete; + UpwardOpenPatternEdge(::FlexFlow::UpwardOpenMultiDiEdge const &raw_edge); + + bool operator==(UpwardOpenPatternEdge const &) const; + bool operator!=(UpwardOpenPatternEdge const &) const; + bool operator<(UpwardOpenPatternEdge const &) const; + bool operator>(UpwardOpenPatternEdge const &) const; + bool operator<=(UpwardOpenPatternEdge const &) const; + bool operator>=(UpwardOpenPatternEdge const &) const; + ::FlexFlow::UpwardOpenMultiDiEdge raw_edge; +}; +} // namespace FlexFlow + +namespace std { +template <> +struct hash { + size_t operator()(FlexFlow::UpwardOpenPatternEdge const &) const; +}; +} // namespace std + +#endif // _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_UNLABELLED_UPWARD_OPEN_PATTERN_EDGE_DTG_H diff --git a/lib/substitutions/include/substitutions/unlabelled/upward_open_pattern_edge.h b/lib/substitutions/include/substitutions/unlabelled/upward_open_pattern_edge.h new file mode 100644 index 0000000000..998cf1a519 --- /dev/null +++ b/lib/substitutions/include/substitutions/unlabelled/upward_open_pattern_edge.h @@ -0,0 +1,12 @@ +#ifndef _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_UNLABELLED_UPWARD_OPEN_PATTERN_EDGE_H +#define _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_UNLABELLED_UPWARD_OPEN_PATTERN_EDGE_H + +#include "substitutions/unlabelled/upward_open_pattern_edge.dtg.h" + +namespace FlexFlow { + +int get_dst_idx(UpwardOpenPatternEdge const &); + +} // namespace FlexFlow + +#endif diff --git a/lib/substitutions/include/substitutions/unlabelled/upward_open_pattern_edge.struct.toml b/lib/substitutions/include/substitutions/unlabelled/upward_open_pattern_edge.struct.toml new file mode 100644 index 0000000000..a4c3bad809 --- /dev/null +++ b/lib/substitutions/include/substitutions/unlabelled/upward_open_pattern_edge.struct.toml @@ -0,0 +1,15 @@ +namespace = "FlexFlow" +name = "UpwardOpenPatternEdge" +features = [ + "eq", + "ord", + "hash", +] + +includes = [ + "utils/graph.h", +] + +[[fields]] +name = "raw_edge" +type = "::FlexFlow::UpwardOpenMultiDiEdge" diff --git a/lib/substitutions/src/graph_pattern.cc b/lib/substitutions/src/graph_pattern.cc deleted file mode 100644 index 296a975626..0000000000 --- a/lib/substitutions/src/graph_pattern.cc +++ /dev/null @@ -1,257 +0,0 @@ -#include "substitutions/graph_pattern.h" -#include "op-attrs/operator_attrs.h" -#include "op-attrs/parallel_tensor_shape.h" -#include "pcg/parallel_computation_graph.h" -#include "substitutions/get_attribute.h" -#include "substitutions/graph_pattern_match.h" -#include "substitutions/operator_pattern.h" -#include "substitutions/parallel_tensor_pattern.h" - -namespace FlexFlow { - -std::optional - evaluate_list_index_access(int index, - std::optional const &v) { - if (!v.has_value() || - !std::holds_alternative>(v.value()) || - !std::holds_alternative>( - v.value())) { - return std::nullopt; - } - - if (index >= MAX_TENSOR_DIM) { - return std::nullopt; - } - - if (std::holds_alternative>(v.value())) { - return 
get>(v.value()).at(index); - } else { - return get>(v.value()).at(index); - } -} - -std::optional - evaluate_list_index_access(int const &index, - std::optional const &v) { - if (!v.has_value() || !std::holds_alternative>(v.value())) { - return std::nullopt; - } - - auto vec = get>(v.value()); - - if (index >= vec.size()) { - return std::nullopt; - } - - return vec.at(index); -} - -std::optional - evaluate_list_size(std::optional const &v) { - return MAX_TENSOR_DIM; -} - -std::optional - evaluate_list_size(std::optional const &v) { - if (!v.has_value() || !std::holds_alternative>(v.value())) { - return std::nullopt; - } - - return (int)get>(v.value()).size(); -} - -struct EvaluateOperatorAttributeExpr { - EvaluateOperatorAttributeExpr(Operator const &attrs) : attrs(attrs) {} - - std::optional - operator()(OperatorAttributeKey const &key) { - return get_attribute(this->attrs.attrs, key); - } - - std::optional - operator()(ListIndexAccess const &index_access) { - std::optional v = - get_attribute(this->attrs.attrs, index_access.attribute_key); - return evaluate_list_index_access(index_access.index, v); - } - - std::optional - operator()(ListSize const &list_size) { - std::optional v = - get_attribute(this->attrs.attrs, list_size.attribute_key); - return evaluate_list_size(v); - } - -private: - Operator attrs; -}; - -std::optional - evaluate_tensor_attribute_expr(ParallelTensor const &, - AttributeExpr const &); - -struct EvaluateTensorAttributeExpr { - EvaluateTensorAttributeExpr(ParallelTensor const &tensor_shape) - : tensor_shape(tensor_shape) {} - - template - std::optional evaluate(T const &t) { - return this->operator()(t); - } - - std::optional operator()(TensorAttributeKey key) { - switch (key) { - case TensorAttributeKey::DIM_SIZES: { - std::vector result; - for (ParallelDim const &dim : this->tensor_shape.dims) { - result.push_back(dim.size); - } - return result; - } - case TensorAttributeKey::DIM_DEGREES: { - std::vector result; - for (ParallelDim const &dim : this->tensor_shape.dims) { - result.push_back(dim.degree); - } - return result; - } - default: - throw std::runtime_error("Unknown TensorAttributeKey"); - } - } - - std::optional - operator()(ListIndexAccess const &index_access) { - std::optional v = - this->evaluate(index_access.attribute_key); - return evaluate_list_index_access(index_access.index, v); - } - - std::optional - operator()(ListSize const &list_size) { - return evaluate_list_size(this->evaluate(list_size.attribute_key)); - } - -private: - ParallelTensor tensor_shape; -}; - -std::optional - evaluate_attribute_expr(ParallelTensor const &tensor_shape, - AttributeExpr const &expr) { - return visit(EvaluateTensorAttributeExpr(tensor_shape), expr); -} - -std::optional - evaluate_attribute_expr(Operator const &attrs, - AttributeExpr const &expr) { - return visit(EvaluateOperatorAttributeExpr(attrs), expr); -} - -template -std::optional satisfies(ConstraintType constraint_type, - V const &constraint_value, - std::optional const &maybe_attribute_value) { - if (!maybe_attribute_value.has_value()) { - return std::nullopt; - } - V attr_val = maybe_attribute_value.value(); - - if (attr_val.index() != constraint_value.index()) { - return std::nullopt; - } - - if (constraint_type == ConstraintType::EQUAL) { - return attr_val == constraint_value; - } else { - throw std::runtime_error("Unknown constraint_type"); - } -} - -std::optional satisfies(ParallelTensor const &tensor_shape, - TensorAttributeConstraint const &constraint) { - auto value = 
evaluate_attribute_expr(tensor_shape, constraint.attribute_expr); - return satisfies( - constraint.constraint_type, constraint.attribute_value, value); -} - -std::optional satisfies(Operator const ¶ms, - OperatorAttributeConstraint const &constraint) { - auto value = evaluate_attribute_expr(params, constraint.attribute_expr); - OperatorAttributeValue v = value.value(); - return satisfies( - constraint.constraint_type, constraint.attribute_value, value); -} - -template -std::optional optional_all_of(Container const &container, - Function const &func) { - for (auto const &element : container) { - std::optional condition = func(element); - if (!condition.has_value()) { - return std::nullopt; - } - - if (!condition.value()) { - return false; - } - } - return true; -} - -std::optional satisfies(Operator const ¶ms, - OperatorPattern const &pattern) { - return optional_all_of(pattern.attribute_constraints, - [&](OperatorAttributeConstraint const &c) { - return satisfies(params, c); - }); -} - -std::optional satisfies(ParallelTensor const ¶ms, - ParallelTensorPattern const &pattern) { - return optional_all_of( - pattern.attribute_constraints, - [&](TensorAttributeConstraint const &c) { return satisfies(params, c); }); -} - -bool operator_satisfies(Operator const ¶ms, - OperatorPattern const &pattern) { - return satisfies(params, pattern).value_or(false); -} - -bool parallel_tensor_satisfies(ParallelTensor const ¶ms, - ParallelTensorPattern const &pattern) { - return satisfies(params, pattern).value_or(false); -} - -bool assignment_satisfies(SubParallelComputationGraph const &pcg, - GraphPattern const &pattern, - MultiDiGraphPatternMatch const &patternMatch) { - bool result = true; - for (auto const &kv : patternMatch.node_assignment) { - Node patternNode = kv.first; - Node pcgNode = kv.second; - std::optional constraintResult = - satisfies(pcg.at(pcgNode), pattern.value().at(patternNode)); - result &= constraintResult.value_or(false); - } - - for (auto const &kv : patternMatch.edge_assignment) { - OpenMultiDiEdge patternEdge = kv.first; - OpenMultiDiEdge pcgEdge = kv.second; - std::optional constraintResult = - satisfies(pcg.at(pcgEdge), pattern.value().at(patternEdge)); - result &= constraintResult.value_or(false); - } - - result &= pattern_matches( - pattern, - pcg, - patternMatch, - MatchAdditionalCriterion{[](Node const &, Node const &) { return true; }, - [](OpenMultiDiEdge const &, - OpenMultiDiEdge const &) { return true; }}); - - return result; -} -} // namespace FlexFlow diff --git a/lib/substitutions/src/graph_pattern_match.cc b/lib/substitutions/src/graph_pattern_match.cc deleted file mode 100644 index f9c6b9a773..0000000000 --- a/lib/substitutions/src/graph_pattern_match.cc +++ /dev/null @@ -1,305 +0,0 @@ -#include "substitutions/graph_pattern.h" -#include "utils/hash-utils.h" -#include - -namespace FlexFlow { - -GraphSplit split_pattern(OpenMultiDiGraphView const &pattern) { - std::vector topological_ordering = get_topological_ordering(pattern); - assert(topological_ordering.size() >= 2); - - int split_point = topological_ordering.size() / 2; - auto split = vector_split(topological_ordering, split_point); - std::unordered_set prefix(split.first.begin(), split.first.end()); - std::unordered_set postfix(split.second.begin(), split.second.end()); - return {prefix, postfix}; -} - -std::pair - apply_split(OpenMultiDiGraphView const &pattern, GraphSplit const &split) { - return {get_subgraph(pattern, split.first), - get_subgraph(pattern, split.second)}; -} - -/* -Given a match and a pattern 
split, gets the submatches in subpatterns. -*/ -MatchSplit apply_split(OpenMultiDiGraphView const &pattern, - MultiDiGraphPatternMatch const &match, - GraphSplit const &split) { - auto prefix = split.first; - auto postfix = split.second; - - MatchSplit result; - - for (auto const &kv : match.node_assignment) { - Node pattern_node = kv.first; - Node graph_node = kv.second; - if (contains(split.first, pattern_node)) { - result.prefix_submatch.node_assignment.equate(pattern_node, graph_node); - } else { - assert(contains(split.second, pattern_node)); - result.postfix_submatch.node_assignment.equate(pattern_node, graph_node); - } - } - - auto edge_splits = get_edge_splits(pattern, split); - - std::function - handle_edge = [&](OpenMultiDiEdge const &pattern_edge, - OpenMultiDiEdge const &graph_edge) -> void { - auto edge_nodes = get_nodes(pattern_edge); - if (is_subseteq_of(edge_nodes, prefix)) { - result.prefix_submatch.edge_assignment.equate(pattern_edge, graph_edge); - } else if (is_subseteq_of(edge_nodes, postfix)) { - result.postfix_submatch.edge_assignment.equate(pattern_edge, graph_edge); - } else { - assert(is_standard_edge(pattern_edge)); - assert(is_standard_edge(graph_edge)); - auto standard_edge = std::get(pattern_edge); - auto divided = edge_splits.at_l(standard_edge); - auto divided_graph_edge = split_edge(get(graph_edge)); - handle_edge(divided.first, divided_graph_edge.first); - handle_edge(divided.second, divided_graph_edge.second); - } - }; - - for (auto const &kv : match.edge_assignment) { - OpenMultiDiEdge pattern_edge = kv.first; - OpenMultiDiEdge graph_edge = match.edge_assignment.at_l(pattern_edge); - handle_edge(pattern_edge, graph_edge); - } - - return result; -} - -bool is_singleton_pattern(OpenMultiDiGraphView const &pattern) { - return num_nodes(pattern) == 1; -} - -bool pattern_matches(OpenMultiDiGraphView const &pattern, - OpenMultiDiGraphView const &graph, - MultiDiGraphPatternMatch const &match, - MatchAdditionalCriterion const &additional_criterion) { - if (is_singleton_pattern(pattern)) { - Node pattern_node = get_only(get_nodes(pattern)); - Node graph_matched_node = match.node_assignment.at_l(pattern_node); - if (!additional_criterion.node_criterion(pattern_node, - graph_matched_node)) { - return false; - } - for (OpenMultiDiEdge const &e : get_edges(pattern)) { - OpenMultiDiEdge graph_matched_edge = match.edge_assignment.at_l(e); - - assert(is_input_edge(e) || is_output_edge(e)); - if (is_input_edge(e)) { - if (is_output_edge(graph_matched_edge)) { - return false; - } - UpwardOpenMultiDiEdge matched_edge = - narrow(graph_matched_edge).value(); - InputMultiDiEdge input_edge = std::get(e); - if (match.node_assignment.at_l(input_edge.dst) != - get_dst_node(matched_edge)) { - return false; - } - } else { - if (is_input_edge(graph_matched_edge)) { - return false; - } - DownwardOpenMultiDiEdge matched_edge = - narrow(graph_matched_edge).value(); - OutputMultiDiEdge output_edge = std::get(e); - if (match.node_assignment.at_l(output_edge.src) != - get_src_node(matched_edge)) { - return false; - } - } - - if (!additional_criterion.edge_criterion(e, graph_matched_edge)) { - return false; - } - } - - return true; - } - - auto split = split_pattern(pattern); - auto subpatterns = apply_split(pattern, split); - auto submatches = apply_split(pattern, match, split); - - return pattern_matches(subpatterns.first, - graph, - submatches.prefix_submatch, - additional_criterion) && - pattern_matches(subpatterns.second, - graph, - submatches.postfix_submatch, - 
additional_criterion); -} - -template -bool dst_compare(T const &lhs, T const &rhs) { - return get_dst_idx(lhs) < get_dst_idx(rhs); -} - -template -bool src_compare(T const &lhs, T const &rhs) { - return get_src_idx(lhs) < get_src_idx(rhs); -} - -std::optional - get_candidate_singleton_match(OpenMultiDiGraphView const &pattern, - OpenMultiDiGraphView const &graph, - Node const &graph_node) { - assert(is_singleton_pattern(pattern)); - - Node pattern_node = get_only(get_nodes(pattern)); - - MultiDiGraphPatternMatch match; - match.node_assignment.equate(pattern_node, graph_node); - - std::unordered_set incoming = - get_incoming_edges(graph, graph_node); - std::unordered_set outgoing = - get_outgoing_edges(graph, graph_node); - - std::unordered_set pattern_incoming = - get_incoming_edges(pattern, pattern_node); - std::unordered_set pattern_outgoing = - get_outgoing_edges(pattern, pattern_node); - - if (!pattern_incoming.empty() && pattern_incoming.size() != incoming.size()) { - return std::nullopt; - } - - if (!pattern_outgoing.empty() && pattern_outgoing.size() != outgoing.size()) { - return std::nullopt; - } - - std::vector incoming_ordered = - sorted_by(incoming, dst_compare); - std::vector outgoing_ordered = - sorted_by(outgoing, src_compare); - - std::vector pattern_incoming_ordered = - sorted_by(pattern_incoming, dst_compare); - std::vector pattern_outgoing_ordered = - sorted_by(pattern_outgoing, src_compare); - - if (pattern_incoming.size()) { - std::unordered_map node_port_mapping; - for (int i = 0; i < incoming_ordered.size(); ++i) { - UpwardOpenMultiDiEdge graph_edge = incoming_ordered[i], - pattern_edge = pattern_incoming_ordered[i]; - NodePort graph_port = get_dst_idx(graph_edge), - pattern_port = get_dst_idx(pattern_edge); - if (!contains_key(node_port_mapping, graph_port)) { - node_port_mapping.emplace(graph_port, pattern_port); - } else { - if (pattern_port != node_port_mapping.at(graph_port)) { - return std::nullopt; - } - } - match.edge_assignment.equate(widen(pattern_edge), - widen(graph_edge)); - } - } - - if (pattern_outgoing.size()) { - std::unordered_map node_port_mapping; - for (int i = 0; i < outgoing_ordered.size(); ++i) { - DownwardOpenMultiDiEdge graph_edge = outgoing_ordered[i], - pattern_edge = pattern_outgoing_ordered[i]; - NodePort graph_port = get_src_idx(graph_edge), - pattern_port = get_src_idx(pattern_edge); - if (!contains_key(node_port_mapping, graph_port)) { - node_port_mapping.insert({graph_port, pattern_port}); - } else { - if (pattern_port != node_port_mapping.at(graph_port)) { - return std::nullopt; - } - } - match.edge_assignment.equate(widen(pattern_edge), - widen(graph_edge)); - } - } - - return match; -} - -std::optional unsplit_matches( - MultiDiGraphPatternMatch const &prefix, - MultiDiGraphPatternMatch const &postfix, - bidict> const - &edge_splits) { - MultiDiGraphPatternMatch result; - std::unordered_set handled; - for (auto const &kv : edge_splits) { - MultiDiEdge standard_edge = kv.first; - OutputMultiDiEdge output_edge = kv.second.first; - InputMultiDiEdge input_edge = kv.second.second; - handled.insert(output_edge); - handled.insert(input_edge); - - OpenMultiDiEdge output_graph_edge = - prefix.edge_assignment.at_l(output_edge); - OpenMultiDiEdge input_graph_edge = postfix.edge_assignment.at_l(input_edge); - if (output_graph_edge == input_graph_edge) { - result.edge_assignment.equate(standard_edge, output_graph_edge); - } else { - return std::nullopt; - } - } - - for (auto const &kv : - merge_maps(prefix.edge_assignment, 
postfix.edge_assignment)) { - if (!contains(handled, kv.first)) { - result.edge_assignment.equate(kv.first, kv.second); - } - } - - result.node_assignment = - merge_maps(prefix.node_assignment, postfix.node_assignment); - - return result; -} - -std::vector - find_pattern_matches(OpenMultiDiGraphView const &pattern, - OpenMultiDiGraphView const &graph, - MatchAdditionalCriterion const &additional_criterion) { - std::vector matches; - if (is_singleton_pattern(pattern)) { - for (Node const &graph_node : get_nodes(graph)) { - std::optional candidate = - get_candidate_singleton_match(pattern, graph, graph_node); - if (candidate.has_value() && - pattern_matches( - pattern, graph, candidate.value(), additional_criterion)) { - matches.push_back(candidate.value()); - } - } - } else { - GraphSplit split = split_pattern(pattern); - auto subpatterns = apply_split(pattern, split); - auto prefix_matches = - find_pattern_matches(subpatterns.first, graph, additional_criterion); - auto postfix_matches = - find_pattern_matches(subpatterns.second, graph, additional_criterion); - auto edge_splits = get_edge_splits(pattern, split); - for (MultiDiGraphPatternMatch const &prefix_match : prefix_matches) { - for (MultiDiGraphPatternMatch const &postfix_match : postfix_matches) { - std::optional unsplit = - unsplit_matches(prefix_match, postfix_match, edge_splits); - if (unsplit.has_value()) { - matches.push_back(unsplit.value()); - } - } - } - } - - return matches; -} - -} // namespace FlexFlow diff --git a/lib/substitutions/src/sub_parallel_computation_graph.cc b/lib/substitutions/src/sub_parallel_computation_graph.cc deleted file mode 100644 index e8cb093222..0000000000 --- a/lib/substitutions/src/sub_parallel_computation_graph.cc +++ /dev/null @@ -1,3 +0,0 @@ -#include "substitutions/sub_parallel_computation_graph.h" - -namespace FlexFlow {} // namespace FlexFlow diff --git a/lib/substitutions/src/substitution.cc b/lib/substitutions/src/substitution.cc index 15816185ee..94993f3c90 100644 --- a/lib/substitutions/src/substitution.cc +++ b/lib/substitutions/src/substitution.cc @@ -2,475 +2,386 @@ namespace FlexFlow { -struct DeriveValidOperatorAttributeExpr { - template - std::unordered_set> - operator()(T const &t) { - return derive_valid_operator_attribute_expr(t); - } - - std::unordered_set> - derive_valid_operator_attribute_expr(OperatorAttributeKey const &key) { - return {key}; - } - - std::unordered_set> - derive_valid_operator_attribute_expr( - ListIndexAccess const &access) { - return {access, access.attribute_key}; - } - - std::unordered_set> - derive_valid_operator_attribute_expr( - ListSize const &ls) { - return {ls, ls.attribute_key}; - } -}; - -std::unordered_set> - get_valid_operator_attribute_exprs(OperatorPattern const &pattern) { - return set_union(transform( - pattern.attribute_constraints, [](OperatorAttributeConstraint const &t) { - return visit(DeriveValidOperatorAttributeExpr{}, t.attribute_expr); - })); -} - -bool is_valid_operator_attribute_expr( - OperatorPattern const &pattern, - AttributeExpr const &expr) { - return contains(get_valid_operator_attribute_exprs(pattern), expr); -} - -struct IsValidOperatorAttributeExprFunctor { - GraphPattern const &graph_pattern; - - template - bool operator()(T const &t) const { - return is_valid(t); - } - - bool is_valid(OperatorAttrAccess const &t) const { - return is_valid_operator_attribute_expr(graph_pattern.value().at(t.node), - t.attr_expr); - } - - bool is_valid(AttrConstant const &t) const { - return true; - } -}; - -bool 
is_valid_operator_attribute_expr(GraphPattern const &pattern, - OperatorAttributeExpr const &expr) { - return visit(IsValidOperatorAttributeExprFunctor{pattern}, expr); -} - -bool is_valid_substitution(Substitution const &s) { - for (Node const &node : get_nodes(s.output_graph_expr.value())) { - for (OperatorAttributeExpr expr : - values(s.output_graph_expr.value().at(node).assignments)) { - if (!is_valid_operator_attribute_expr(s.input_graph, expr)) { - return false; - } - } - } - return true; -} - -struct EvaluateOperatorAttributeExpr { - SubParallelComputationGraph const &graph; - MultiDiGraphPatternMatch const &match; - - template - OperatorAttributeValue operator()(T const &t) { - return evaluate(t); - } - - OperatorAttributeValue evaluate(OperatorAttrAccess const &t) { - Node node_in_pattern = t.node; - Node node_in_pcg = match.node_assignment.at_l(node_in_pattern); - return evaluate_attribute_expr(graph.at(node_in_pcg), t.attr_expr).value(); - } - - OperatorAttributeValue evaluate(AttrConstant const &t) { - return t.value; - } -}; - -OperatorAttributeValue - evaluate_graph_attribute_expr(SubParallelComputationGraph const &g, - MultiDiGraphPatternMatch const &match, - OperatorAttributeExpr const &expr) { - return visit(EvaluateOperatorAttributeExpr{g, match}, expr); -} - -Operator get_operator_attrs(SubParallelComputationGraph const &graph, - MultiDiGraphPatternMatch const &match, - OperatorAttrAssignment const &assignment) { - std::unordered_map assignments; - for (auto const &[key, expr] : assignment.assignments) { - OperatorAttributeValue value = - evaluate_graph_attribute_expr(graph, match, expr); - assignments.emplace(key, value); - } - assert(contains_key(assignments, OperatorAttributeKey::OP_TYPE)); - assert(std::holds_alternative( - assignments.at(OperatorAttributeKey::OP_TYPE))); - OperatorType op_type = - std::get(assignments.at(OperatorAttributeKey::OP_TYPE)); - switch (op_type) { - case Op::BATCHMATMUL: - return Operator{ - BatchMatmulAttrs{std::get(assignments.at( - OperatorAttributeKey::A_SEQ_LENGTH_DIM)), - std::get(assignments.at( - OperatorAttributeKey::B_SEQ_LENGTH_DIM))}, - std::nullopt}; - case Op::BATCHNORM: - return Operator{BatchNormAttrs{std::get( - assignments.at(OperatorAttributeKey::RELU))}, - std::nullopt}; - case Op::CAST: - return Operator{CastAttrs{std::get( - assignments.at(OperatorAttributeKey::DATA_TYPE))}, - std::nullopt}; - case Op::CONCAT: - return Operator{ - ConcatAttrs{ - std::get(assignments.at(OperatorAttributeKey::AXIS)), - std::get(assignments.at(OperatorAttributeKey::NUM_INPUTS))}, - std::nullopt}; - case Op::CONV2D: - return Operator{ - Conv2DAttrs{ - std::get(assignments.at(OperatorAttributeKey::OUT_CHANNELS)), - std::get(assignments.at(OperatorAttributeKey::KERNEL_H)), - std::get(assignments.at(OperatorAttributeKey::KERNEL_W)), - std::get(assignments.at(OperatorAttributeKey::STRIDE_H)), - std::get(assignments.at(OperatorAttributeKey::STRIDE_W)), - std::get(assignments.at(OperatorAttributeKey::PADDING_H)), - std::get(assignments.at(OperatorAttributeKey::PADDING_W)), - std::get(assignments.at(OperatorAttributeKey::GROUPS)), - std::get( - assignments.at(OperatorAttributeKey::ACTIVATION)), - std::get(assignments.at(OperatorAttributeKey::USE_BIAS))}, - std::nullopt}; - case Op::DROPOUT: - return Operator{DropoutAttrs{std::get(assignments.at( - OperatorAttributeKey::RATE)), - std::get(assignments.at( - OperatorAttributeKey::SEED))}, - std::nullopt}; - case Op::EW_ADD: - case Op::EW_DIV: - case Op::EW_EQUAL: - case Op::EW_GREATER: - case 
Op::EW_LESS: - case Op::EW_MAX: - case Op::EW_MIN: - case Op::EW_MUL: - case Op::EW_SUB: - return Operator{ - ElementBinaryAttrs{op_type, - std::get(assignments.at( - OperatorAttributeKey::DATA_TYPE)), - std::get(assignments.at( - OperatorAttributeKey::SHOULD_BROADCAST_LHS)), - std::get(assignments.at( - OperatorAttributeKey::SHOULD_BROADCAST_RHS))}, - std::nullopt}; - case Op::SCALAR_ADD: - case Op::SCALAR_FLOOR_DIV: - case Op::SCALAR_MULTIPLY: - case Op::SCALAR_SUB: - case Op::SCALAR_TRUE_DIV: - return Operator{ - ElementScalarUnaryAttrs{ - op_type, - std::get(assignments.at(OperatorAttributeKey::SCALAR))}, - std::nullopt}; - case Op::EXP: - case Op::IDENTITY: - case Op::GELU: - case Op::RSQRT: - case Op::POW: - case Op::SIN: - case Op::COS: - return Operator{ElementUnaryAttrs{op_type}, std::nullopt}; - case Op::EMBEDDING: - return Operator{ - EmbeddingAttrs{ - std::get(assignments.at(OperatorAttributeKey::NUM_ENTRIES)), - std::get(assignments.at(OperatorAttributeKey::OUT_CHANNELS)), - std::get(assignments.at(OperatorAttributeKey::AGGR)), - std::get( - assignments.at(OperatorAttributeKey::OP_TYPE))}, - std::nullopt}; - case Op::FLAT: - return Operator{FlatAttrs{}, std::nullopt}; - case Op::GATHER: - return Operator{GatherAttrs{std::get( - assignments.at(OperatorAttributeKey::DIM))}, - std::nullopt}; - case Op::INPUT: - return Operator{InputAttrs{}, std::nullopt}; - case Op::LAYERNORM: - return Operator{ - LayerNormAttrs{ - std::get>( - assignments.at(OperatorAttributeKey::AXES)), - std::get( - assignments.at(OperatorAttributeKey::ELEMENTWISE_AFFINE)), - std::get(assignments.at(OperatorAttributeKey::EPSILON))}, - std::nullopt}; - case Op::LINEAR: - return Operator{ - LinearAttrs{ - std::get(assignments.at(OperatorAttributeKey::OUT_CHANNELS)), - std::get(assignments.at(OperatorAttributeKey::USE_BIAS)), - std::get( - assignments.at(OperatorAttributeKey::DATA_TYPE)), - std::get( - assignments.at(OperatorAttributeKey::ACTIVATION)), - std::get>( - assignments.at(OperatorAttributeKey::REGULARIZER))}, - std::nullopt}; - case Op::MULTIHEAD_ATTENTION: - return Operator{ - MultiHeadAttentionAttrs{ - std::get(assignments.at(OperatorAttributeKey::EMBED_DIM)), - std::get(assignments.at(OperatorAttributeKey::NUM_HEADS)), - std::get(assignments.at(OperatorAttributeKey::NUM_HEADS)), - std::get(assignments.at(OperatorAttributeKey::VDIM)), - std::get(assignments.at(OperatorAttributeKey::DROPOUT)), - std::get(assignments.at(OperatorAttributeKey::BIAS)), - std::get(assignments.at(OperatorAttributeKey::ADD_BIAS_KV)), - std::get( - assignments.at(OperatorAttributeKey::ADD_ZERO_ATTN))}, - std::nullopt}; - case Op::NOOP: - return Operator{NoopAttrs{}, std::nullopt}; - case Op::POOL2D: - return Operator{ - Pool2DAttrs{ - std::get(assignments.at(OperatorAttributeKey::KERNEL_H)), - std::get(assignments.at(OperatorAttributeKey::KERNEL_W)), - std::get(assignments.at(OperatorAttributeKey::STRIDE_H)), - std::get(assignments.at(OperatorAttributeKey::STRIDE_W)), - std::get(assignments.at(OperatorAttributeKey::PADDING_H)), - std::get(assignments.at(OperatorAttributeKey::PADDING_W)), - std::get(assignments.at(OperatorAttributeKey::POOL_TYPE)), - std::get( - assignments.at(OperatorAttributeKey::ACTIVATION))}, - std::nullopt}; - case Op::REDUCE_ARGMAX: - case Op::REDUCE_ARGMIN: - case Op::REDUCE_MAX: - case Op::REDUCE_MEAN: - case Op::REDUCE_MIN: - case Op::REDUCE_PROD: - case Op::REDUCE_SUM: - return Operator{ - ReduceAttrs{ - std::get>( - assignments.at(OperatorAttributeKey::AXES)), - op_type, - 
std::get(assignments.at(OperatorAttributeKey::KEEP_DIMS))}, - std::nullopt}; - case Op::REVERSE: - return Operator{ReverseAttrs{std::get( - assignments.at(OperatorAttributeKey::AXIS))}, - std::nullopt}; - case Op::RESHAPE: - return Operator{ReshapeAttrs{std::get( - assignments.at(OperatorAttributeKey::SHAPE))}, - std::nullopt}; - case Op::SPLIT: - return Operator{ - SplitAttrs{ - std::get>( - assignments.at(OperatorAttributeKey::SPLITS)), - std::get(assignments.at(OperatorAttributeKey::AXIS))}, - std::nullopt}; - case Op::SOFTMAX: - return Operator{SoftmaxAttrs{std::get( - assignments.at(OperatorAttributeKey::DIM))}, - std::nullopt}; - case Op::TOPK: - return Operator{ - TopKAttrs{ - std::get(assignments.at(OperatorAttributeKey::K)), - std::get(assignments.at(OperatorAttributeKey::SORTED))}, - std::nullopt}; - case Op::TRANSPOSE: - return Operator{ - TransposeAttrs{std::get>( - assignments.at(OperatorAttributeKey::PERMUTATION))}, - std::nullopt}; - case Op::COMBINE: - return Operator{CombineAttrs{std::get(assignments.at( - OperatorAttributeKey::PARALLEL_DIM)), - std::get(assignments.at( - OperatorAttributeKey::PARALLEL_DEGREE))}, - std::nullopt}; - case Op::REDUCTION: - return Operator{ - ReductionAttrs{std::get(assignments.at( - OperatorAttributeKey::PARALLEL_DIM)), - std::get(assignments.at( - OperatorAttributeKey::PARALLEL_DEGREE))}, - std::nullopt}; - case Op::REPARTITION: - return Operator{ - RepartitionAttrs{std::get(assignments.at( - OperatorAttributeKey::PARALLEL_DIM)), - std::get(assignments.at( - OperatorAttributeKey::PARALLEL_DEGREE))}, - std::nullopt}; - case Op::REPLICATE: - return Operator{ - ReplicateAttrs{std::get(assignments.at( - OperatorAttributeKey::PARALLEL_DIM)), - std::get(assignments.at( - OperatorAttributeKey::PARALLEL_DEGREE))}, - std::nullopt}; - default: - throw mk_runtime_error("Unknown Operator"); - } -} - -struct AddMappedEdgeFunctor { - bidict const &node_mapping; - SubParallelComputationGraph &new_pcg; - - template - void operator()(T const &t) { - return add_mapped_edge(t); - } - - void add_mapped_edge(InputMultiDiEdge const &e) { - new_pcg.add_edge(InputMultiDiEdge{ - node_mapping.at_l(e.dst), new_pcg.add_node_port(), e.uid}); - } - - void add_mapped_edge(OutputMultiDiEdge const &e) { - new_pcg.add_edge(OutputMultiDiEdge{ - node_mapping.at_l(e.src), new_pcg.add_node_port(), e.uid}); - } - - void add_mapped_edge(MultiDiEdge const &e) { - new_pcg.add_edge(MultiDiEdge{node_mapping.at_l(e.dst), - new_pcg.add_node_port(), - node_mapping.at_l(e.src), - new_pcg.add_node_port()}); - } -}; - -struct AddNewEdgeFunctor { - SubParallelComputationGraph const &old_pcg; - SubParallelComputationGraph &new_pcg; - MultiDiGraphPatternMatch const &match; - bidict node_mapping; - - template - void operator()(TO const &old_edge, TN const &new_edge) { - return add_new_edge(old_edge, new_edge); - } - - void add_new_edge(InputMultiDiEdge const &old_edge, - InputMultiDiEdge const &new_edge) { - new_pcg.add_edge(InputMultiDiEdge{node_mapping.at_l(new_edge.dst), - new_pcg.add_node_port(), - old_edge.uid}); - } - - void add_new_edge(MultiDiEdge const &old_edge, - InputMultiDiEdge const &new_edge) { - new_pcg.add_edge(MultiDiEdge{node_mapping.at_l(new_edge.dst), - new_pcg.add_node_port(), - node_mapping.at_l(old_edge.src), - new_pcg.add_node_port()}); - } - - void add_new_edge(OutputMultiDiEdge const &old_edge, - OutputMultiDiEdge const &new_edge) { - new_pcg.add_edge(OutputMultiDiEdge{node_mapping.at_l(new_edge.src), - new_pcg.add_node_port(), - old_edge.uid}); - } - - void 
add_new_edge(MultiDiEdge const &old_edge, - OutputMultiDiEdge const &new_edge) { - new_pcg.add_edge(MultiDiEdge{node_mapping.at_l(old_edge.dst), - new_pcg.add_node_port(), - node_mapping.at_l(new_edge.src), - new_pcg.add_node_port()}); - } - - void add_new_edge(InputMultiDiEdge const &, OutputMultiDiEdge const &) { - assert(false); - } - - void add_new_edge(OpenMultiDiEdge const &, MultiDiEdge const &) { - assert(false); - } - - void add_new_edge(OutputMultiDiEdge const &, InputMultiDiEdge const &) { - assert(false); - } -}; - -SubParallelComputationGraph - apply_substitution(SubParallelComputationGraph const &pcg, - Substitution const &substitution, - MultiDiGraphPatternMatch const &match) { - SubParallelComputationGraph new_pcg = - OutputLabelledOpenMultiDiGraph::template create< - UnorderedOutputLabelledOpenMultiDiGraph>(); - bidict node_mapping; // Refactor it with global nodes - for (Node const &node : get_nodes(pcg)) { - if (!contains_r(match.node_assignment, node)) { - node_mapping.equate(node, new_pcg.add_node(pcg.at(node))); - } - } - for (OpenMultiDiEdge const &edge : get_edges(pcg)) { - if (!contains_r(match.edge_assignment, edge)) { - visit(AddMappedEdgeFunctor{node_mapping, new_pcg}, edge); - } - } - for (Node const &output_node : - get_nodes(substitution.output_graph_expr.value())) { - Operator new_op = get_operator_attrs( - pcg, match, substitution.output_graph_expr.value().at(output_node)); - Node new_node = new_pcg.add_node(new_op); - node_mapping.equate(output_node, new_node); - } - for (OpenMultiDiEdge const &output_edge : - get_edges(substitution.output_graph_expr.value())) { - if (std::holds_alternative(output_edge)) { - InputMultiDiEdge e = std::get(output_edge); - OpenMultiDiEdge original_edge = - match.edge_assignment.at_l(substitution.input_mapping.at_r(e)); - visit(AddNewEdgeFunctor{pcg, new_pcg, match, node_mapping}, - original_edge, - output_edge); - } else if (std::holds_alternative(output_edge)) { - OutputMultiDiEdge e = std::get(output_edge); - OpenMultiDiEdge original_edge = - match.edge_assignment.at_l(substitution.output_mapping.at_r(e)); - visit(AddNewEdgeFunctor{pcg, new_pcg, match, node_mapping}, - original_edge, - output_edge); - } else { - assert(std::holds_alternative(output_edge)); - MultiDiEdge e = std::get(output_edge); - new_pcg.add_edge(MultiDiEdge{node_mapping.at_l(e.dst), - new_pcg.add_node_port(), - node_mapping.at_l(e.src), - new_pcg.add_node_port()}); - } - } - - return new_pcg; -} +/* struct DeriveValidOperatorAttributeExpr { */ +/* template */ +/* std::unordered_set> */ +/* operator()(T const &t) { */ +/* return derive_valid_operator_attribute_expr(t); */ +/* } */ + +/* std::unordered_set> */ +/* derive_valid_operator_attribute_expr(OperatorAttributeKey const &key) { + */ +/* return {key}; */ +/* } */ + +/* std::unordered_set> */ +/* derive_valid_operator_attribute_expr( */ +/* ListIndexAccess const &access) { */ +/* return {access, access.attribute_key}; */ +/* } */ + +/* std::unordered_set> */ +/* derive_valid_operator_attribute_expr( */ +/* ListSize const &ls) { */ +/* return {ls, ls.attribute_key}; */ +/* } */ +/* }; */ + +/* std::unordered_set> */ +/* get_valid_operator_attribute_exprs(OperatorPattern const &pattern) { */ +/* return set_union(transform( */ +/* pattern.attribute_constraints, [](OperatorAttributeConstraint const &t) + * { */ +/* return visit(DeriveValidOperatorAttributeExpr{}, t.attribute_expr); + */ +/* })); */ +/* } */ + +/* bool is_valid_operator_attribute_expr( */ +/* OperatorPattern const &pattern, */ +/* 
AttributeExpr const &expr) { */ +/* return contains(get_valid_operator_attribute_exprs(pattern), expr); */ +/* } */ + +/* struct IsValidOperatorAttributeExprFunctor { */ +/* GraphPattern const &graph_pattern; */ + +/* template */ +/* bool operator()(T const &t) const { */ +/* return is_valid(t); */ +/* } */ + +/* bool is_valid(OperatorAttrAccess const &t) const { */ +/* return is_valid_operator_attribute_expr(graph_pattern.value().at(t.node), + */ +/* t.attr_expr); */ +/* } */ + +/* bool is_valid(AttrConstant const &t) const { */ +/* return true; */ +/* } */ +/* }; */ + +/* bool is_valid_operator_attribute_expr(GraphPattern const &pattern, */ +/* OperatorAttributeExpr const &expr) { */ +/* return visit(IsValidOperatorAttributeExprFunctor{pattern}, expr); */ +/* } */ + +/* bool is_valid_substitution(Substitution const &s) { */ +/* for (Node const &node : get_nodes(s.output_graph_expr.value())) { */ +/* for (OperatorAttributeExpr expr : */ +/* values(s.output_graph_expr.value().at(node).assignments)) { */ +/* if (!is_valid_operator_attribute_expr(s.input_graph, expr)) { */ +/* return false; */ +/* } */ +/* } */ +/* } */ +/* return true; */ +/* } */ + +/* struct EvaluateOperatorAttributeExpr { */ +/* SubParallelComputationGraph const &graph; */ +/* MultiDiGraphPatternMatch const &match; */ + +/* template */ +/* OperatorAttributeValue operator()(T const &t) { */ +/* return evaluate(t); */ +/* } */ + +/* OperatorAttributeValue evaluate(OperatorAttrAccess const &t) { */ +/* Node node_in_pattern = t.node; */ +/* Node node_in_pcg = match.node_assignment.at_l(node_in_pattern); */ +/* return evaluate_attribute_expr(graph.at(node_in_pcg), + * t.attr_expr).value(); */ +/* } */ + +/* OperatorAttributeValue evaluate(AttrConstant const &t) { */ +/* return t.value; */ +/* } */ +/* }; */ + +/* OperatorAttributeValue */ +/* evaluate_graph_attribute_expr(SubParallelComputationGraph const &g, */ +/* MultiDiGraphPatternMatch const &match, */ +/* OperatorAttributeExpr const &expr) { */ +/* return visit(EvaluateOperatorAttributeExpr{g, match}, expr); */ +/* } */ + +/* Operator get_operator_attrs(SubParallelComputationGraph const &graph, */ +/* MultiDiGraphPatternMatch const &match, */ +/* OperatorAttrAssignment const &assignment) { */ +/* std::unordered_map + * assignments; */ +/* for (auto const &[key, expr] : assignment.assignments) { */ +/* OperatorAttributeValue value = */ +/* evaluate_graph_attribute_expr(graph, match, expr); */ +/* assignments.emplace(key, value); */ +/* } */ +/* assert(contains_key(assignments, OperatorAttributeKey::OP_TYPE)); */ +/* assert(std::holds_alternative( */ +/* assignments.at(OperatorAttributeKey::OP_TYPE))); */ +/* OperatorType op_type = */ +/* std::get(assignments.at(OperatorAttributeKey::OP_TYPE)); + */ +/* switch (op_type) { */ +/* case OperatorType::BATCHMATMUL: */ +/* return Operator{ */ +/* BatchMatmulAttrs{std::get(assignments.at( */ +/* OperatorAttributeKey::A_SEQ_LENGTH_DIM)), */ +/* std::get(assignments.at( */ +/* OperatorAttributeKey::B_SEQ_LENGTH_DIM))}, */ +/* std::nullopt}; */ +/* case OperatorType::BATCHNORM: */ +/* return Operator{BatchNormAttrs{std::get( */ +/* assignments.at(OperatorAttributeKey::RELU))}, */ +/* std::nullopt}; */ +/* case OperatorType::CAST: */ +/* return Operator{CastAttrs{std::get( */ +/* assignments.at(OperatorAttributeKey::DATA_TYPE))}, + */ +/* std::nullopt}; */ +/* case OperatorType::CONCAT: */ +/* return Operator{ */ +/* ConcatAttrs{ */ +/* std::get(assignments.at(OperatorAttributeKey::AXIS)), + */ +/* 
std::get(assignments.at(OperatorAttributeKey::NUM_INPUTS))}, + */ +/* std::nullopt}; */ +/* case OperatorType::CONV2D: */ +/* return Operator{ */ +/* Conv2DAttrs{ */ +/* std::get(assignments.at(OperatorAttributeKey::OUT_CHANNELS)), + */ +/* std::get(assignments.at(OperatorAttributeKey::KERNEL_H)), + */ +/* std::get(assignments.at(OperatorAttributeKey::KERNEL_W)), + */ +/* std::get(assignments.at(OperatorAttributeKey::STRIDE_H)), + */ +/* std::get(assignments.at(OperatorAttributeKey::STRIDE_W)), + */ +/* std::get(assignments.at(OperatorAttributeKey::PADDING_H)), + */ +/* std::get(assignments.at(OperatorAttributeKey::PADDING_W)), + */ +/* std::get(assignments.at(OperatorAttributeKey::GROUPS)), */ +/* std::get( */ +/* assignments.at(OperatorAttributeKey::ACTIVATION)), */ +/* std::get(assignments.at(OperatorAttributeKey::USE_BIAS))}, + */ +/* std::nullopt}; */ +/* case OperatorType::DROPOUT: */ +/* return Operator{DropoutAttrs{std::get(assignments.at( */ +/* OperatorAttributeKey::RATE)), */ +/* std::get(assignments.at( */ +/* OperatorAttributeKey::SEED))}, */ +/* std::nullopt}; */ +/* case OperatorType::EW_ADD: */ +/* case OperatorType::EW_DIV: */ +/* case OperatorType::EW_EQUAL: */ +/* case OperatorType::EW_GREATER: */ +/* case OperatorType::EW_LESS: */ +/* case OperatorType::EW_MAX: */ +/* case OperatorType::EW_MIN: */ +/* case OperatorType::EW_MUL: */ +/* case OperatorType::EW_SUB: */ +/* return Operator{ */ +/* ElementBinaryAttrs{op_type, */ +/* std::get(assignments.at( */ +/* OperatorAttributeKey::DATA_TYPE)), */ +/* std::get(assignments.at( */ +/* OperatorAttributeKey::SHOULD_BROADCAST_LHS)), + */ +/* std::get(assignments.at( */ +/* OperatorAttributeKey::SHOULD_BROADCAST_RHS))}, + */ +/* std::nullopt}; */ +/* case OperatorType::SCALAR_ADD: */ +/* case OperatorType::SCALAR_FLOOR_DIV: */ +/* case OperatorType::SCALAR_MULTIPLY: */ +/* case OperatorType::SCALAR_SUB: */ +/* case OperatorType::SCALAR_TRUE_DIV: */ +/* return Operator{ */ +/* ElementScalarUnaryAttrs{ */ +/* op_type, */ +/* std::get(assignments.at(OperatorAttributeKey::SCALAR))}, + */ +/* std::nullopt}; */ +/* case OperatorType::EXP: */ +/* case OperatorType::IDENTITY: */ +/* case OperatorType::GELU: */ +/* case OperatorType::RSQRT: */ +/* case OperatorType::POW: */ +/* case OperatorType::SIN: */ +/* case OperatorType::COS: */ +/* return Operator{ElementUnaryAttrs{op_type}, std::nullopt}; */ +/* case OperatorType::EMBEDDING: */ +/* return Operator{ */ +/* EmbeddingAttrs{ */ +/* std::get(assignments.at(OperatorAttributeKey::NUM_ENTRIES)), + */ +/* std::get(assignments.at(OperatorAttributeKey::OUT_CHANNELS)), + */ +/* std::get(assignments.at(OperatorAttributeKey::AGGR)), + */ +/* std::get( */ +/* assignments.at(OperatorAttributeKey::OP_TYPE))}, */ +/* std::nullopt}; */ +/* case OperatorType::FLAT: */ +/* return Operator{FlatAttrs{}, std::nullopt}; */ +/* case OperatorType::GATHER: */ +/* return Operator{GatherAttrs{std::get( */ +/* assignments.at(OperatorAttributeKey::DIM))}, */ +/* std::nullopt}; */ +/* case OperatorType::INPUT: */ +/* return Operator{InputAttrs{}, std::nullopt}; */ +/* case OperatorType::LAYERNORM: */ +/* return Operator{ */ +/* LayerNormAttrs{ */ +/* std::get>( */ +/* assignments.at(OperatorAttributeKey::AXES)), */ +/* std::get( */ +/* assignments.at(OperatorAttributeKey::ELEMENTWISE_AFFINE)), + */ +/* std::get(assignments.at(OperatorAttributeKey::EPSILON))}, + */ +/* std::nullopt}; */ +/* case OperatorType::LINEAR: */ +/* return Operator{ */ +/* LinearAttrs{ */ +/* 
std::get(assignments.at(OperatorAttributeKey::OUT_CHANNELS)), + */ +/* std::get(assignments.at(OperatorAttributeKey::USE_BIAS)), + */ +/* std::get( */ +/* assignments.at(OperatorAttributeKey::DATA_TYPE)), */ +/* std::get( */ +/* assignments.at(OperatorAttributeKey::ACTIVATION)), */ +/* std::get>( */ +/* assignments.at(OperatorAttributeKey::REGULARIZER))}, */ +/* std::nullopt}; */ +/* case OperatorType::MULTIHEAD_ATTENTION: */ +/* return Operator{ */ +/* MultiHeadAttentionAttrs{ */ +/* std::get(assignments.at(OperatorAttributeKey::EMBED_DIM)), + */ +/* std::get(assignments.at(OperatorAttributeKey::NUM_HEADS)), + */ +/* std::get(assignments.at(OperatorAttributeKey::NUM_HEADS)), + */ +/* std::get(assignments.at(OperatorAttributeKey::VDIM)), */ +/* std::get(assignments.at(OperatorAttributeKey::DROPOUT)), + */ +/* std::get(assignments.at(OperatorAttributeKey::BIAS)), */ +/* std::get(assignments.at(OperatorAttributeKey::ADD_BIAS_KV)), + */ +/* std::get( */ +/* assignments.at(OperatorAttributeKey::ADD_ZERO_ATTN))}, */ +/* std::nullopt}; */ +/* case OperatorType::NOOP: */ +/* return Operator{NoopAttrs{}, std::nullopt}; */ +/* case OperatorType::POOL2D: */ +/* return Operator{ */ +/* Pool2DAttrs{ */ +/* std::get(assignments.at(OperatorAttributeKey::KERNEL_H)), + */ +/* std::get(assignments.at(OperatorAttributeKey::KERNEL_W)), + */ +/* std::get(assignments.at(OperatorAttributeKey::STRIDE_H)), + */ +/* std::get(assignments.at(OperatorAttributeKey::STRIDE_W)), + */ +/* std::get(assignments.at(OperatorAttributeKey::PADDING_H)), + */ +/* std::get(assignments.at(OperatorAttributeKey::PADDING_W)), + */ +/* std::get(assignments.at(OperatorAttributeKey::POOL_TYPE)), + */ +/* std::get( */ +/* assignments.at(OperatorAttributeKey::ACTIVATION))}, */ +/* std::nullopt}; */ +/* case OperatorType::REDUCE_ARGMAX: */ +/* case OperatorType::REDUCE_ARGMIN: */ +/* case OperatorType::REDUCE_MAX: */ +/* case OperatorType::REDUCE_MEAN: */ +/* case OperatorType::REDUCE_MIN: */ +/* case OperatorType::REDUCE_PROD: */ +/* case OperatorType::REDUCE_SUM: */ +/* return Operator{ */ +/* ReduceAttrs{ */ +/* std::get>( */ +/* assignments.at(OperatorAttributeKey::AXES)), */ +/* op_type, */ +/* std::get(assignments.at(OperatorAttributeKey::KEEP_DIMS))}, + */ +/* std::nullopt}; */ +/* case OperatorType::REVERSE: */ +/* return Operator{ReverseAttrs{std::get( */ +/* assignments.at(OperatorAttributeKey::AXIS))}, */ +/* std::nullopt}; */ +/* case OperatorType::RESHAPE: */ +/* return Operator{ReshapeAttrs{std::get( */ +/* assignments.at(OperatorAttributeKey::SHAPE))}, */ +/* std::nullopt}; */ +/* case OperatorType::SPLIT: */ +/* return Operator{ */ +/* SplitAttrs{ */ +/* std::get>( */ +/* assignments.at(OperatorAttributeKey::SPLITS)), */ +/* std::get(assignments.at(OperatorAttributeKey::AXIS))}, + */ +/* std::nullopt}; */ +/* case OperatorType::SOFTMAX: */ +/* return Operator{SoftmaxAttrs{std::get( */ +/* assignments.at(OperatorAttributeKey::DIM))}, */ +/* std::nullopt}; */ +/* case OperatorType::TOPK: */ +/* return Operator{ */ +/* TopKAttrs{ */ +/* std::get(assignments.at(OperatorAttributeKey::K)), */ +/* std::get(assignments.at(OperatorAttributeKey::SORTED))}, + */ +/* std::nullopt}; */ +/* case OperatorType::TRANSPOSE: */ +/* return Operator{ */ +/* TransposeAttrs{std::get>( */ +/* assignments.at(OperatorAttributeKey::PERMUTATION))}, */ +/* std::nullopt}; */ +/* case OperatorType::COMBINE: */ +/* return Operator{CombineAttrs{std::get(assignments.at( */ +/* OperatorAttributeKey::PARALLEL_DIM)), + */ +/* std::get(assignments.at( */ +/* 
OperatorAttributeKey::PARALLEL_DEGREE))}, + */ +/* std::nullopt}; */ +/* case OperatorType::REDUCTION: */ +/* return Operator{ */ +/* ReductionAttrs{std::get(assignments.at( */ +/* OperatorAttributeKey::PARALLEL_DIM)), */ +/* std::get(assignments.at( */ +/* OperatorAttributeKey::PARALLEL_DEGREE))}, */ +/* std::nullopt}; */ +/* case OperatorType::REPARTITION: */ +/* return Operator{ */ +/* RepartitionAttrs{std::get(assignments.at( */ +/* OperatorAttributeKey::PARALLEL_DIM)), */ +/* std::get(assignments.at( */ +/* OperatorAttributeKey::PARALLEL_DEGREE))}, */ +/* std::nullopt}; */ +/* case OperatorType::REPLICATE: */ +/* return Operator{ */ +/* ReplicateAttrs{std::get(assignments.at( */ +/* OperatorAttributeKey::PARALLEL_DIM)), */ +/* std::get(assignments.at( */ +/* OperatorAttributeKey::PARALLEL_DEGREE))}, */ +/* std::nullopt}; */ +/* default: */ +/* throw mk_runtime_error("Unknown Operator"); */ +/* } */ +/* } */ } // namespace FlexFlow diff --git a/lib/substitutions/src/substitutions/constraint_type.dtg.cc b/lib/substitutions/src/substitutions/constraint_type.dtg.cc new file mode 100644 index 0000000000..aa5c30dbe9 --- /dev/null +++ b/lib/substitutions/src/substitutions/constraint_type.dtg.cc @@ -0,0 +1,64 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/substitutions/include/substitutions/constraint_type.enum.toml +/* proj-data +{ + "generated_from": "06b029d76658cb434abf08b1fdb86137" +} +*/ + +#include "substitutions/constraint_type.dtg.h" + +#include +#include + +namespace std { +size_t hash::operator()( + FlexFlow::ConstraintType x) const { + return std::hash{}(static_cast(x)); +} +} // namespace std +namespace FlexFlow { +std::string format_as(ConstraintType x) { + switch (x) { + case ConstraintType::EQUAL: + return "EQUAL"; + default: + std::ostringstream oss; + oss << "Unknown ConstraintType value " << static_cast(x); + throw std::runtime_error(oss.str()); + } +} +std::ostream &operator<<(std::ostream &s, ConstraintType x) { + return s << fmt::to_string(x); +} +} // namespace FlexFlow +namespace FlexFlow { +void to_json(::nlohmann::json &j, ConstraintType x) { + switch (x) { + case ConstraintType::EQUAL: + j = "EQUAL"; + break; + default: + std::ostringstream oss; + oss << "Unknown ConstraintType value " << static_cast(x); + throw std::runtime_error(oss.str()); + } +} +void from_json(::nlohmann::json const &j, ConstraintType &x) { + std::string as_str = j.get(); + if (as_str == "EQUAL") { + x = ConstraintType::EQUAL; + } else { + std::ostringstream oss; + oss << "Unknown ConstraintType value " << as_str; + throw std::runtime_error(oss.str()); + } +} +} // namespace FlexFlow +namespace rc { +Gen Arbitrary::arbitrary() { + return gen::element( + FlexFlow::ConstraintType::EQUAL); +} +} // namespace rc diff --git a/lib/substitutions/src/substitutions/graph_pattern.cc b/lib/substitutions/src/substitutions/graph_pattern.cc new file mode 100644 index 0000000000..22cf12b4cf --- /dev/null +++ b/lib/substitutions/src/substitutions/graph_pattern.cc @@ -0,0 +1,42 @@ +#include "substitutions/graph_pattern.h" +#include "substitutions/operator_pattern/satisfies_pattern.h" +#include "substitutions/sub_parallel_computation_graph.h" +#include "substitutions/tensor_pattern/satisfies_pattern.h" + +namespace FlexFlow { + +UnlabelledGraphPattern get_unlabelled_pattern(PCGPattern const &p) { + return UnlabelledGraphPattern{p.raw_graph}; +} + +TensorAttributePattern get_tensor_pattern(PCGPattern const &p, + PatternEdge 
const &e) {
+  return p.raw_graph.at(e.raw_edge);
+}
+
+OperatorAttributePattern get_operator_pattern(PCGPattern const &p,
+                                              PatternNode const &n) {
+  return p.raw_graph.at(n.raw_node);
+}
+
+bool assignment_satisfies(SubParallelComputationGraph const &pcg,
+                          PCGPattern const &pattern,
+                          MultiDiGraphPatternMatch const &patternMatch) {
+  return unlabelled_pattern_does_match(
+      get_unlabelled_pattern(pattern),
+      pcg.raw_graph,
+      patternMatch,
+      MatchAdditionalCriterion{
+          [&](PatternNode const &patternNode, Node const &pcgNode) {
+            return operator_satisfies_pattern(
+                get_operator_attrs(pcg, pcgNode),
+                get_operator_pattern(pattern, patternNode));
+          },
+          [&](PatternEdge const &patternEdge, OpenMultiDiEdge const &pcgEdge) {
+            return parallel_tensor_satisfies_pattern(
+                get_parallel_tensor_attrs(pcg, pcgEdge),
+                get_tensor_pattern(pattern, patternEdge));
+          }});
+}
+
+} // namespace FlexFlow
diff --git a/lib/substitutions/src/substitutions/operator_pattern/eval_list_access.cc b/lib/substitutions/src/substitutions/operator_pattern/eval_list_access.cc
new file mode 100644
index 0000000000..53973dc1cb
--- /dev/null
+++ b/lib/substitutions/src/substitutions/operator_pattern/eval_list_access.cc
@@ -0,0 +1,41 @@
+#include "substitutions/operator_pattern/eval_list_access.h"
+#include "substitutions/operator_pattern/get_attribute.h"
+#include "utils/overload.h"
+
+namespace FlexFlow {
+
+std::optional<OperatorAttributeValue>
+    eval_list_access(PCGOperatorAttrs const &attrs,
+                     OperatorAttributeListIndexAccess const &acc) {
+  std::optional<OperatorAttributeValue> from_attr =
+      get_attribute(attrs, acc.attribute_key);
+
+  if (!from_attr.has_value()) {
+    return std::nullopt;
+  }
+
+  return from_attr.value().visit<std::optional<OperatorAttributeValue>>(
+      [&](auto const &v) -> std::optional<OperatorAttributeValue> {
+        using T = std::decay_t<decltype(v)>;
+
+        if constexpr (std::is_same_v<T, std::vector<int>>) {
+          if (acc.index >= v.size()) {
+            return std::nullopt;
+          } else {
+            int value = v.at(acc.index);
+            return OperatorAttributeValue{value};
+          }
+        } else if constexpr (std::is_same_v<T, std::vector<ff_dim_t>>) {
+          if (acc.index >= v.size()) {
+            return std::nullopt;
+          } else {
+            ff_dim_t value = v.at(acc.index);
+            return OperatorAttributeValue{value};
+          }
+        } else {
+          throw mk_runtime_error("Invalid operand");
+        }
+      });
+}
+
+} // namespace FlexFlow
diff --git a/lib/substitutions/src/substitutions/operator_pattern/eval_list_size.cc b/lib/substitutions/src/substitutions/operator_pattern/eval_list_size.cc
new file mode 100644
index 0000000000..a3ae9c84d1
--- /dev/null
+++ b/lib/substitutions/src/substitutions/operator_pattern/eval_list_size.cc
@@ -0,0 +1,31 @@
+#include "substitutions/operator_pattern/eval_list_size.h"
+#include "substitutions/operator_pattern/get_attribute.h"
+#include "utils/overload.h"
+
+namespace FlexFlow {
+
+std::optional<OperatorAttributeValue>
+    eval_list_size(PCGOperatorAttrs const &attrs,
+                   OperatorAttributeListSize const &acc) {
+  std::optional<OperatorAttributeValue> from_attr =
+      get_attribute(attrs, acc.attribute_key);
+
+  if (!from_attr.has_value()) {
+    return std::nullopt;
+  }
+
+  return from_attr.value().visit<std::optional<OperatorAttributeValue>>(
+      [&](auto const &v) -> std::optional<OperatorAttributeValue> {
+        using T = std::decay_t<decltype(v)>;
+
+        if constexpr (std::is_same_v<T, std::vector<int>> ||
+                      std::is_same_v<T, std::vector<ff_dim_t>>) {
+          size_t size = v.size();
+          return OperatorAttributeValue{size};
+        } else {
+          throw mk_runtime_error("Invalid operand");
+        }
+      });
+}
+
+} // namespace FlexFlow
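eval_list_access and eval_list_size evaluate the two composite operator-attribute expressions against a concrete operator, returning nullopt rather than throwing when the key is absent or the index is out of range. A hypothetical driver follows; the key and index fields of the expression structs are assumed from their dtg definitions, which are not shown in this patch:

    // `attrs` is some PCGOperatorAttrs value; the field names below are
    // assumptions about the generated expression types.
    OperatorAttributeListIndexAccess access{OperatorAttributeKey::AXES, 0};
    std::optional<OperatorAttributeValue> first_axis =
        eval_list_access(attrs, access);   // nullopt if AXES is missing or short

    OperatorAttributeListSize size_expr{OperatorAttributeKey::AXES};
    std::optional<OperatorAttributeValue> num_axes =
        eval_list_size(attrs, size_expr);  // nullopt if AXES is not list-valued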
diff --git a/lib/substitutions/src/operator_attributes.cc b/lib/substitutions/src/substitutions/operator_pattern/get_attribute.cc
similarity index 72%
rename from lib/substitutions/src/operator_attributes.cc
rename to lib/substitutions/src/substitutions/operator_pattern/get_attribute.cc
index 8bd8688194..e168760c3b 100644
--- a/lib/substitutions/src/operator_attributes.cc
+++ b/lib/substitutions/src/substitutions/operator_pattern/get_attribute.cc
@@ -1,11 +1,26 @@
+#include "substitutions/operator_pattern/get_attribute.h"
 #include "op-attrs/get_op_type.h"
-#include "substitutions/get_attribute.h"
+#include "utils/containers.h"
 
 namespace FlexFlow {
 
 std::optional<OperatorAttributeValue> get_attribute(BatchMatmulAttrs const &p,
                                                     OperatorAttributeKey key) {
   switch (key) {
+    case OperatorAttributeKey::OP_TYPE:
+      return get_op_type(p);
+    default:
+      return std::nullopt;
+  }
+}
+
+std::optional<OperatorAttributeValue> get_attribute(BatchNormAttrs const &p,
+                                                    OperatorAttributeKey key) {
+  switch (key) {
+    case OperatorAttributeKey::OP_TYPE:
+      return get_op_type(p);
+    case OperatorAttributeKey::RELU:
+      return p.relu;
     default:
       return std::nullopt;
   }
@@ -14,6 +29,8 @@ std::optional<OperatorAttributeValue> get_attribute(BatchMatmulAttrs const &p,
 std::optional<OperatorAttributeValue> get_attribute(CastAttrs const &p,
                                                     OperatorAttributeKey key) {
   switch (key) {
+    case OperatorAttributeKey::OP_TYPE:
+      return get_op_type(p);
     case OperatorAttributeKey::DATA_TYPE:
       return p.dtype;
     default:
@@ -24,6 +41,8 @@ std::optional<OperatorAttributeValue> get_attribute(CombineAttrs const &p,
                                                     OperatorAttributeKey key) {
   switch (key) {
+    case OperatorAttributeKey::OP_TYPE:
+      return get_op_type(p);
     case OperatorAttributeKey::PARALLEL_OP_DIM:
       return p.combine_dim;
     case OperatorAttributeKey::PARALLEL_DIM:
@@ -36,6 +55,8 @@ std::optional<OperatorAttributeValue> get_attribute(CombineAttrs const &p,
 std::optional<OperatorAttributeValue> get_attribute(ConcatAttrs const &p,
                                                     OperatorAttributeKey key) {
   switch (key) {
+    case OperatorAttributeKey::OP_TYPE:
+      return get_op_type(p);
     case OperatorAttributeKey::AXIS:
       return p.axis;
     default:
@@ -46,6 +67,8 @@ std::optional<OperatorAttributeValue> get_attribute(ConcatAttrs const &p,
 std::optional<OperatorAttributeValue> get_attribute(Conv2DAttrs const &p,
                                                     OperatorAttributeKey key) {
   switch (key) {
+    case OperatorAttributeKey::OP_TYPE:
+      return get_op_type(p);
     case OperatorAttributeKey::KERNEL_H:
       return p.kernel_h;
     case OperatorAttributeKey::KERNEL_W:
@@ -72,6 +95,8 @@ std::optional<OperatorAttributeValue> get_attribute(Conv2DAttrs const &p,
 std::optional<OperatorAttributeValue> get_attribute(ElementBinaryAttrs const &p,
                                                     OperatorAttributeKey key) {
   switch (key) {
+    case OperatorAttributeKey::OP_TYPE:
+      return get_op_type(p);
     default:
       return std::nullopt;
   }
@@ -80,6 +105,8 @@ std::optional<OperatorAttributeValue> get_attribute(ElementBinaryAttrs const &p,
 std::optional<OperatorAttributeValue> get_attribute(ElementUnaryAttrs const &p,
                                                     OperatorAttributeKey key) {
   switch (key) {
+    case OperatorAttributeKey::OP_TYPE:
+      return get_op_type(p);
     default:
       return std::nullopt;
   }
@@ -88,6 +115,8 @@ std::optional<OperatorAttributeValue> get_attribute(ElementUnaryAttrs const &p,
 std::optional<OperatorAttributeValue>
     get_attribute(ElementScalarUnaryAttrs const &p, OperatorAttributeKey key) {
   switch (key) {
+    case OperatorAttributeKey::OP_TYPE:
+      return get_op_type(p);
     default:
       return std::nullopt;
   }
@@ -96,6 +125,8 @@ std::optional<OperatorAttributeValue>
 std::optional<OperatorAttributeValue> get_attribute(DropoutAttrs const &p,
                                                     OperatorAttributeKey key) {
   switch (key) {
+    case OperatorAttributeKey::OP_TYPE:
+      return get_op_type(p);
     default:
       return std::nullopt;
   }
@@ -104,6 +135,8 @@ std::optional<OperatorAttributeValue> get_attribute(DropoutAttrs const &p,
 std::optional<OperatorAttributeValue> get_attribute(EmbeddingAttrs const &p,
                                                     OperatorAttributeKey key) {
   switch (key) {
+    case OperatorAttributeKey::OP_TYPE:
+      return get_op_type(p);
     case OperatorAttributeKey::DATA_TYPE:
       return p.data_type;
     case OperatorAttributeKey::AGGR:
@@ -120,6 +153,8 @@ std::optional<OperatorAttributeValue> get_attribute(EmbeddingAttrs const &p,
 std::optional<OperatorAttributeValue> get_attribute(FlatAttrs const &p,
                                                     OperatorAttributeKey key) {
   switch (key) {
+    case OperatorAttributeKey::OP_TYPE:
+      return get_op_type(p);
     default:
       return std::nullopt;
   }
@@ -128,6 +163,8 @@ std::optional<OperatorAttributeValue> get_attribute(FlatAttrs const &p,
 std::optional<OperatorAttributeValue> get_attribute(GatherAttrs const &p,
                                                     OperatorAttributeKey key) {
   switch (key) {
+    case OperatorAttributeKey::OP_TYPE:
+      return get_op_type(p);
     case OperatorAttributeKey::AXIS:
       return p.dim;
     default:
@@ -135,9 +172,21 @@ std::optional<OperatorAttributeValue> get_attribute(GatherAttrs const &p,
   }
 }
 
+std::optional<OperatorAttributeValue> get_attribute(InputAttrs const &p,
+                                                    OperatorAttributeKey key) {
+  switch (key) {
+    case OperatorAttributeKey::OP_TYPE:
+      return get_op_type(p);
+    default:
+      return std::nullopt;
+  }
+}
+
 std::optional<OperatorAttributeValue> get_attribute(LayerNormAttrs const &p,
                                                     OperatorAttributeKey key) {
   switch (key) {
+    case OperatorAttributeKey::OP_TYPE:
+      return get_op_type(p);
     default:
       return std::nullopt;
   }
@@ -146,6 +195,8 @@ std::optional<OperatorAttributeValue> get_attribute(LayerNormAttrs const &p,
 std::optional<OperatorAttributeValue> get_attribute(LinearAttrs const &p,
                                                     OperatorAttributeKey key) {
   switch (key) {
+    case OperatorAttributeKey::OP_TYPE:
+      return get_op_type(p);
     case OperatorAttributeKey::OUT_CHANNELS:
       return p.out_channels;
     case OperatorAttributeKey::USE_BIAS:
@@ -166,6 +217,8 @@ std::optional<OperatorAttributeValue> get_attribute(LinearAttrs const &p,
 std::optional<OperatorAttributeValue>
     get_attribute(MultiHeadAttentionAttrs const &p, OperatorAttributeKey key) {
   switch (key) {
+    case OperatorAttributeKey::OP_TYPE:
+      return get_op_type(p);
     case OperatorAttributeKey::NUM_HEADS:
       return p.num_heads;
     case OperatorAttributeKey::USE_BIAS:
@@ -175,9 +228,21 @@ std::optional<OperatorAttributeValue>
   }
 }
 
+std::optional<OperatorAttributeValue> get_attribute(NoopAttrs const &p,
+                                                    OperatorAttributeKey key) {
+  switch (key) {
+    case OperatorAttributeKey::OP_TYPE:
+      return get_op_type(p);
+    default:
+      return std::nullopt;
+  }
+}
+
 std::optional<OperatorAttributeValue> get_attribute(Pool2DAttrs const &p,
                                                     OperatorAttributeKey key) {
   switch (key) {
+    case OperatorAttributeKey::OP_TYPE:
+      return get_op_type(p);
     case OperatorAttributeKey::KERNEL_H:
       return p.kernel_h;
     case OperatorAttributeKey::KERNEL_W:
@@ -202,6 +267,8 @@ std::optional<OperatorAttributeValue> get_attribute(Pool2DAttrs const &p,
 std::optional<OperatorAttributeValue> get_attribute(ReduceAttrs const &p,
                                                     OperatorAttributeKey key) {
   switch (key) {
+    case OperatorAttributeKey::OP_TYPE:
+      return get_op_type(p);
     default:
       return std::nullopt;
   }
@@ -210,8 +277,8 @@ std::optional<OperatorAttributeValue> get_attribute(ReduceAttrs const &p,
 std::optional<OperatorAttributeValue> get_attribute(ReductionAttrs const &p,
                                                     OperatorAttributeKey key) {
   switch (key) {
-    case OperatorAttributeKey::PARALLEL_OP_DIM:
-      return p.reduction_dim;
+    case OperatorAttributeKey::OP_TYPE:
+      return get_op_type(p);
     case OperatorAttributeKey::PARALLEL_OP_DEGREE:
       return p.reduction_degree;
     default:
@@ -222,6 +289,8 @@ std::optional<OperatorAttributeValue> get_attribute(ReductionAttrs const &p,
 std::optional<OperatorAttributeValue> get_attribute(RepartitionAttrs const &p,
                                                     OperatorAttributeKey key) {
   switch (key) {
+    case OperatorAttributeKey::OP_TYPE:
+      return get_op_type(p);
     case OperatorAttributeKey::PARALLEL_OP_DIM:
       return p.repartition_dim;
     case OperatorAttributeKey::PARALLEL_OP_DEGREE:
@@ -234,8 +303,8 @@ std::optional<OperatorAttributeValue> get_attribute(RepartitionAttrs const &p,
 std::optional<OperatorAttributeValue> get_attribute(ReplicateAttrs const &p,
                                                     OperatorAttributeKey key) {
   switch (key) {
-    case OperatorAttributeKey::PARALLEL_OP_DIM:
-      return p.replicate_dim;
+    case OperatorAttributeKey::OP_TYPE:
+      return get_op_type(p);
     case OperatorAttributeKey::PARALLEL_OP_DEGREE:
       return p.replicate_degree;
     default:
@@ -246,6 +315,20 @@ std::optional<OperatorAttributeValue> get_attribute(ReplicateAttrs const &p,
 std::optional<OperatorAttributeValue> get_attribute(ReshapeAttrs const &p,
                                                     OperatorAttributeKey key) {
   switch (key) {
+    case OperatorAttributeKey::OP_TYPE:
+      return get_op_type(p);
+    default:
+      return std::nullopt;
+  }
+}
+
+std::optional<OperatorAttributeValue> get_attribute(ReverseAttrs const &p,
+                                                    OperatorAttributeKey key) {
+  switch (key) {
+    case OperatorAttributeKey::OP_TYPE:
+      return get_op_type(p);
+    case OperatorAttributeKey::AXIS:
+      return p.axis;
     default:
       return std::nullopt;
   }
@@ -254,6 +337,8 @@ std::optional<OperatorAttributeValue> get_attribute(ReshapeAttrs const &p,
 std::optional<OperatorAttributeValue> get_attribute(SplitAttrs const &p,
                                                     OperatorAttributeKey key) {
   switch (key) {
+    case OperatorAttributeKey::OP_TYPE:
+      return get_op_type(p);
     case OperatorAttributeKey::AXIS:
       return p.axis;
     default:
@@ -264,6 +349,8 @@ std::optional<OperatorAttributeValue> get_attribute(SplitAttrs const &p,
 std::optional<OperatorAttributeValue> get_attribute(SoftmaxAttrs const &p,
                                                     OperatorAttributeKey key) {
   switch (key) {
+    case OperatorAttributeKey::OP_TYPE:
+      return get_op_type(p);
     case OperatorAttributeKey::AXIS:
       return p.dim;
     default:
@@ -274,6 +361,8 @@ std::optional<OperatorAttributeValue> get_attribute(SoftmaxAttrs const &p,
 std::optional<OperatorAttributeValue> get_attribute(TopKAttrs const &p,
                                                     OperatorAttributeKey key) {
   switch (key) {
+    case OperatorAttributeKey::OP_TYPE:
+      return get_op_type(p);
     default:
       return std::nullopt;
   }
@@ -282,38 +371,19 @@ std::optional<OperatorAttributeValue> get_attribute(TopKAttrs const &p,
 std::optional<OperatorAttributeValue> get_attribute(TransposeAttrs const &p,
                                                     OperatorAttributeKey key) {
   switch (key) {
+    case OperatorAttributeKey::OP_TYPE:
+      return get_op_type(p);
     case OperatorAttributeKey::PERMUTATION:
-      return p.perm;
+      return as_vector(p.perm);
     default:
       return std::nullopt;
   }
 }
 
-struct GetAttribute {
-  GetAttribute(OperatorAttributeKey key) : key(key) {}
-
-  template <typename T>
-  std::optional<OperatorAttributeValue> operator()(T const &t) {
-    return get_attribute(t, this->key);
-  }
-
-private:
-  OperatorAttributeKey key;
-};
-
-struct GetOpType {
-  template <typename T>
-  std::optional<OperatorAttributeValue> operator()(T const &t) {
-    return get_op_type(t);
-  }
-};
-
 std::optional<OperatorAttributeValue> get_attribute(PCGOperatorAttrs const &p,
                                                     OperatorAttributeKey key) {
-  if (key == OperatorAttributeKey::OP_TYPE) {
-    return std::visit(GetOpType{}, p);
-  }
-  return std::visit(GetAttribute(key), p);
+  return p.visit<std::optional<OperatorAttributeValue>>(
+      [&](auto const &attrs) { return get_attribute(attrs, key); });
 }
 
 } // namespace FlexFlow
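[Editor's sketch, not part of the patch: with OP_TYPE now handled uniformly in every overload, the old GetAttribute/GetOpType functor pair collapses into the single generic visit at the end of the file. A call-site sketch; the OperatorType::CONV2D enumerator name is an assumption:]

    #include "substitutions/operator_pattern/get_attribute.h"

    using namespace FlexFlow;

    bool is_conv2d(PCGOperatorAttrs const &attrs) {
      // Every overload answers OP_TYPE, so this is non-empty for any operator.
      std::optional<OperatorAttributeValue> op_type =
          get_attribute(attrs, OperatorAttributeKey::OP_TYPE);
      return op_type == OperatorAttributeValue{OperatorType::CONV2D};
    }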
diff --git a/lib/substitutions/src/substitutions/operator_pattern/operator_attribute_constraint.dtg.cc b/lib/substitutions/src/substitutions/operator_pattern/operator_attribute_constraint.dtg.cc
new file mode 100644
index 0000000000..bc913b7c1a
--- /dev/null
+++ b/lib/substitutions/src/substitutions/operator_pattern/operator_attribute_constraint.dtg.cc
@@ -0,0 +1,121 @@
+// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT!
+// If you would like to modify this datatype, instead modify
+// lib/substitutions/include/substitutions/operator_pattern/operator_attribute_constraint.struct.toml
+/* proj-data
+{
+  "generated_from": "7867bd0f403866c13417171bb5ec364c"
+}
+*/
+
+#include "substitutions/operator_pattern/operator_attribute_constraint.dtg.h"
+
+#include "substitutions/constraint_type.dtg.h"
+#include "substitutions/operator_pattern/operator_attribute_expr.dtg.h"
+#include "substitutions/operator_pattern/operator_attribute_value.dtg.h"
+#include <sstream>
+
+namespace FlexFlow {
+OperatorAttributeConstraint::OperatorAttributeConstraint(
+    ::FlexFlow::ConstraintType const &constraint_type,
+    ::FlexFlow::OperatorAttributeExpr const &attribute_expr,
+    ::FlexFlow::OperatorAttributeValue const &attribute_value)
+    : constraint_type(constraint_type), attribute_expr(attribute_expr),
+      attribute_value(attribute_value) {}
+bool OperatorAttributeConstraint::operator==(
+    OperatorAttributeConstraint const &other) const {
+  return std::tie(this->constraint_type,
+                  this->attribute_expr,
+                  this->attribute_value) == std::tie(other.constraint_type,
+                                                     other.attribute_expr,
+                                                     other.attribute_value);
+}
+bool OperatorAttributeConstraint::operator!=(
+    OperatorAttributeConstraint const &other) const {
+  return std::tie(this->constraint_type,
+                  this->attribute_expr,
+                  this->attribute_value) != std::tie(other.constraint_type,
+                                                     other.attribute_expr,
+                                                     other.attribute_value);
+}
+bool OperatorAttributeConstraint::operator<(
+    OperatorAttributeConstraint const &other) const {
+  return std::tie(this->constraint_type,
+                  this->attribute_expr,
+                  this->attribute_value) < std::tie(other.constraint_type,
+                                                    other.attribute_expr,
+                                                    other.attribute_value);
+}
+bool OperatorAttributeConstraint::operator>(
+    OperatorAttributeConstraint const &other) const {
+  return std::tie(this->constraint_type,
+                  this->attribute_expr,
+                  this->attribute_value) > std::tie(other.constraint_type,
+                                                    other.attribute_expr,
+                                                    other.attribute_value);
+}
+bool OperatorAttributeConstraint::operator<=(
+    OperatorAttributeConstraint const &other) const {
+  return std::tie(this->constraint_type,
+                  this->attribute_expr,
+                  this->attribute_value) <= std::tie(other.constraint_type,
+                                                     other.attribute_expr,
+                                                     other.attribute_value);
+}
+bool OperatorAttributeConstraint::operator>=(
+    OperatorAttributeConstraint const &other) const {
+  return std::tie(this->constraint_type,
+                  this->attribute_expr,
+                  this->attribute_value) >= std::tie(other.constraint_type,
+                                                     other.attribute_expr,
+                                                     other.attribute_value);
+}
+} // namespace FlexFlow
+
+namespace std {
+size_t hash<FlexFlow::OperatorAttributeConstraint>::operator()(
+    FlexFlow::OperatorAttributeConstraint const &x) const {
+  size_t result = 0;
+  result ^= std::hash<::FlexFlow::ConstraintType>{}(x.constraint_type) +
+            0x9e3779b9 + (result << 6) + (result >> 2);
+  result ^= std::hash<::FlexFlow::OperatorAttributeExpr>{}(x.attribute_expr) +
+            0x9e3779b9 + (result << 6) + (result >> 2);
+  result ^= std::hash<::FlexFlow::OperatorAttributeValue>{}(x.attribute_value) +
+            0x9e3779b9 + (result << 6) + (result >> 2);
+  return result;
+}
+} // namespace std
+
+namespace nlohmann {
+FlexFlow::OperatorAttributeConstraint
+    adl_serializer<FlexFlow::OperatorAttributeConstraint>::from_json(
+        json const &j) {
+  return {
+      j.at("constraint_type").template get<::FlexFlow::ConstraintType>(),
+      j.at("attribute_expr").template get<::FlexFlow::OperatorAttributeExpr>(),
+      j.at("attribute_value")
+          .template get<::FlexFlow::OperatorAttributeValue>()};
+}
+void adl_serializer<FlexFlow::OperatorAttributeConstraint>::to_json(
+    json &j, FlexFlow::OperatorAttributeConstraint const &v) {
+  j["__type"] = "OperatorAttributeConstraint";
+  j["constraint_type"] = v.constraint_type;
+  j["attribute_expr"] = v.attribute_expr;
+  j["attribute_value"] = v.attribute_value;
+}
+} // namespace nlohmann
+
+namespace FlexFlow {
+std::string format_as(OperatorAttributeConstraint const &x) {
+  std::ostringstream oss;
+  oss << "<OperatorAttributeConstraint constraint_type=" << x.constraint_type
+      << " attribute_expr=" << x.attribute_expr
+      << " attribute_value=" << x.attribute_value << ">";
+  return oss.str();
+}
+std::ostream &operator<<(std::ostream &s,
+                         OperatorAttributeConstraint const &x) {
+  return s << fmt::to_string(x);
+}
+} // namespace FlexFlow
diff --git a/lib/substitutions/src/substitutions/operator_pattern/operator_attribute_expr.cc b/lib/substitutions/src/substitutions/operator_pattern/operator_attribute_expr.cc
new file mode 100644
index 0000000000..4a55fa3de3
--- /dev/null
+++ b/lib/substitutions/src/substitutions/operator_pattern/operator_attribute_expr.cc
@@ -0,0 +1,23 @@
+#include "substitutions/operator_pattern/operator_attribute_expr.h"
+#include "substitutions/operator_pattern/eval_list_access.h"
+#include "substitutions/operator_pattern/eval_list_size.h"
+#include "substitutions/operator_pattern/get_attribute.h"
+#include "utils/overload.h"
+
+namespace FlexFlow {
+
+std::optional<OperatorAttributeValue>
+    evaluate_attribute_expr(PCGOperatorAttrs const &attrs,
+                            OperatorAttributeExpr const &expr) {
+  return expr.visit<std::optional<OperatorAttributeValue>>(overload{
+      [&](OperatorAttributeKey const &k) { return get_attribute(attrs, k); },
+      [&](OperatorAttributeListSize const &k) {
+        return eval_list_size(attrs, k);
+      },
+      [&](OperatorAttributeListIndexAccess const &k) {
+        return eval_list_access(attrs, k);
+      },
+  });
+}
+
+} // namespace FlexFlow
diff --git a/lib/substitutions/src/substitutions/operator_pattern/operator_attribute_expr.dtg.cc b/lib/substitutions/src/substitutions/operator_pattern/operator_attribute_expr.dtg.cc
new file mode 100644
index 0000000000..60c77d8d0f
--- /dev/null
+++ b/lib/substitutions/src/substitutions/operator_pattern/operator_attribute_expr.dtg.cc
@@ -0,0 +1,137 @@
+// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT!
+// If you would like to modify this datatype, instead modify +// lib/substitutions/include/substitutions/operator_pattern/operator_attribute_expr.variant.toml +/* proj-data +{ + "generated_from": "15d26dd1f08092ecc82b725aa9411597" +} +*/ + +#include "substitutions/operator_pattern/operator_attribute_expr.dtg.h" + +#include "fmt/format.h" +#include +#include + +namespace FlexFlow { +OperatorAttributeExpr::OperatorAttributeExpr( + ::FlexFlow::OperatorAttributeKey const &v) + : raw_variant(v) {} +OperatorAttributeExpr::OperatorAttributeExpr( + ::FlexFlow::OperatorAttributeListSize const &v) + : raw_variant(v) {} +OperatorAttributeExpr::OperatorAttributeExpr( + ::FlexFlow::OperatorAttributeListIndexAccess const &v) + : raw_variant(v) {} +bool OperatorAttributeExpr::operator==( + OperatorAttributeExpr const &other) const { + return this->raw_variant == other.raw_variant; +} +bool OperatorAttributeExpr::operator!=( + OperatorAttributeExpr const &other) const { + return this->raw_variant != other.raw_variant; +} +bool OperatorAttributeExpr::operator<( + OperatorAttributeExpr const &other) const { + return this->raw_variant < other.raw_variant; +} +bool OperatorAttributeExpr::operator>( + OperatorAttributeExpr const &other) const { + return this->raw_variant > other.raw_variant; +} +bool OperatorAttributeExpr::operator<=( + OperatorAttributeExpr const &other) const { + return this->raw_variant <= other.raw_variant; +} +bool OperatorAttributeExpr::operator>=( + OperatorAttributeExpr const &other) const { + return this->raw_variant >= other.raw_variant; +} +} // namespace FlexFlow +namespace std { +size_t hash<::FlexFlow::OperatorAttributeExpr>::operator()( + ::FlexFlow::OperatorAttributeExpr const &x) const { + return std::hash< + std::variant<::FlexFlow::OperatorAttributeKey, + ::FlexFlow::OperatorAttributeListSize, + ::FlexFlow::OperatorAttributeListIndexAccess>>{}( + x.raw_variant); +} +} // namespace std +namespace nlohmann { +::FlexFlow::OperatorAttributeExpr + adl_serializer<::FlexFlow::OperatorAttributeExpr>::from_json( + json const &j) { + std::string key = j.at("type").template get(); + if (key == "key") { + return ::FlexFlow::OperatorAttributeExpr{ + j.at("value").template get<::FlexFlow::OperatorAttributeKey>()}; + } else if (key == "list_size") { + return ::FlexFlow::OperatorAttributeExpr{ + j.at("value").template get<::FlexFlow::OperatorAttributeListSize>()}; + } else if (key == "list_idx") { + return ::FlexFlow::OperatorAttributeExpr{ + j.at("value") + .template get<::FlexFlow::OperatorAttributeListIndexAccess>()}; + } else { + throw std::runtime_error(fmt::format("Unknown type key {}", key)); + } +} +void adl_serializer<::FlexFlow::OperatorAttributeExpr>::to_json( + json &j, ::FlexFlow::OperatorAttributeExpr const &x) { + j["__type"] = "OperatorAttributeExpr"; + switch (x.index()) { + case 0: { + j["type"] = "key"; + j["value"] = x.get<::FlexFlow::OperatorAttributeKey>(); + break; + } + case 1: { + j["type"] = "list_size"; + j["value"] = x.get<::FlexFlow::OperatorAttributeListSize>(); + break; + } + case 2: { + j["type"] = "list_idx"; + j["value"] = x.get<::FlexFlow::OperatorAttributeListIndexAccess>(); + break; + } + default: { + throw std::runtime_error(fmt::format( + "Unknown index {} for type OperatorAttributeExpr", x.index())); + } + } +} +} // namespace nlohmann +namespace FlexFlow { +std::string format_as(::FlexFlow::OperatorAttributeExpr const &x) { + std::ostringstream oss; + switch (x.index()) { + case 0: { + oss << ""; + break; + } + case 1: { + oss << ""; + break; + } + 
case 2: { + oss << ""; + break; + } + default: { + throw std::runtime_error(fmt::format( + "Unknown index {} for type OperatorAttributeExpr", x.index())); + break; + } + } + return oss.str(); +} +std::ostream &operator<<(std::ostream &s, + ::FlexFlow::OperatorAttributeExpr const &x) { + return s << fmt::to_string(x); +} +} // namespace FlexFlow diff --git a/lib/substitutions/src/substitutions/operator_pattern/operator_attribute_key.dtg.cc b/lib/substitutions/src/substitutions/operator_pattern/operator_attribute_key.dtg.cc new file mode 100644 index 0000000000..a24e1c12e4 --- /dev/null +++ b/lib/substitutions/src/substitutions/operator_pattern/operator_attribute_key.dtg.cc @@ -0,0 +1,505 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/substitutions/include/substitutions/operator_pattern/operator_attribute_key.enum.toml +/* proj-data +{ + "generated_from": "e637388397720b328b1f4b9ba6b14611" +} +*/ + +#include "substitutions/operator_pattern/operator_attribute_key.dtg.h" + +#include +#include + +namespace std { +size_t hash::operator()( + FlexFlow::OperatorAttributeKey x) const { + return std::hash{}(static_cast(x)); +} +} // namespace std +namespace FlexFlow { +std::string format_as(OperatorAttributeKey x) { + switch (x) { + case OperatorAttributeKey::OP_TYPE: + return "OP_TYPE"; + case OperatorAttributeKey::USE_BIAS: + return "USE_BIAS"; + case OperatorAttributeKey::GROUPS: + return "GROUPS"; + case OperatorAttributeKey::POOL_TYPE: + return "POOL_TYPE"; + case OperatorAttributeKey::KERNEL_H: + return "KERNEL_H"; + case OperatorAttributeKey::KERNEL_W: + return "KERNEL_W"; + case OperatorAttributeKey::DATA_TYPE: + return "DATA_TYPE"; + case OperatorAttributeKey::SCALAR: + return "SCALAR"; + case OperatorAttributeKey::STRIDE_H: + return "STRIDE_H"; + case OperatorAttributeKey::STRIDE_W: + return "STRIDE_W"; + case OperatorAttributeKey::PADDING_H: + return "PADDING_H"; + case OperatorAttributeKey::PADDING_W: + return "PADDING_W"; + case OperatorAttributeKey::AGGR: + return "AGGR"; + case OperatorAttributeKey::NUM_ENTRIES: + return "NUM_ENTRIES"; + case OperatorAttributeKey::OUT_CHANNELS: + return "OUT_CHANNELS"; + case OperatorAttributeKey::ACTIVATION: + return "ACTIVATION"; + case OperatorAttributeKey::NUMDIM: + return "NUMDIM"; + case OperatorAttributeKey::AXIS: + return "AXIS"; + case OperatorAttributeKey::PERMUTATION: + return "PERMUTATION"; + case OperatorAttributeKey::OUTSHUFFLE: + return "OUTSHUFFLE"; + case OperatorAttributeKey::MERGE_GCONV_COUNT: + return "MERGE_GCONV_COUNT"; + case OperatorAttributeKey::AXES: + return "AXES"; + case OperatorAttributeKey::KEEP_DIMS: + return "KEEP_DIMS"; + case OperatorAttributeKey::EPSILON: + return "EPSILON"; + case OperatorAttributeKey::PARALLEL_OP_DIM: + return "PARALLEL_OP_DIM"; + case OperatorAttributeKey::PARALLEL_OP_DEGREE: + return "PARALLEL_OP_DEGREE"; + case OperatorAttributeKey::SOFTMAX_DIM: + return "SOFTMAX_DIM"; + case OperatorAttributeKey::NUM_HEADS: + return "NUM_HEADS"; + case OperatorAttributeKey::PARALLEL_DIM: + return "PARALLEL_DIM"; + case OperatorAttributeKey::PARALLEL_DEGREE: + return "PARALLEL_DEGREE"; + case OperatorAttributeKey::PAD: + return "PAD"; + case OperatorAttributeKey::EMBED_DIM: + return "EMBED_DIM"; + case OperatorAttributeKey::KDIM: + return "KDIM"; + case OperatorAttributeKey::VDIM: + return "VDIM"; + case OperatorAttributeKey::DROPOUT: + return "DROPOUT"; + case OperatorAttributeKey::BIAS: + return "BIAS"; + case 
OperatorAttributeKey::ADD_BIAS_KV: + return "ADD_BIAS_KV"; + case OperatorAttributeKey::ADD_ZERO_ATTN: + return "ADD_ZERO_ATTN"; + case OperatorAttributeKey::A_SEQ_LENGTH_DIM: + return "A_SEQ_LENGTH_DIM"; + case OperatorAttributeKey::B_SEQ_LENGTH_DIM: + return "B_SEQ_LENGTH_DIM"; + case OperatorAttributeKey::RELU: + return "RELU"; + case OperatorAttributeKey::TARGET_DIMS: + return "TARGET_DIMS"; + case OperatorAttributeKey::RATE: + return "RATE"; + case OperatorAttributeKey::SEED: + return "SEED"; + case OperatorAttributeKey::SHOULD_BROADCAST_LHS: + return "SHOULD_BROADCAST_LHS"; + case OperatorAttributeKey::SHOULD_BROADCAST_RHS: + return "SHOULD_BROADCAST_RHS"; + case OperatorAttributeKey::DIM: + return "DIM"; + case OperatorAttributeKey::ELEMENTWISE_AFFINE: + return "ELEMENTWISE_AFFINE"; + case OperatorAttributeKey::REGULARIZER: + return "REGULARIZER"; + case OperatorAttributeKey::SHAPE: + return "SHAPE"; + case OperatorAttributeKey::SPLITS: + return "SPLITS"; + case OperatorAttributeKey::K: + return "K"; + case OperatorAttributeKey::SORTED: + return "SORTED"; + case OperatorAttributeKey::COMBINE_DIM: + return "COMBINE_DIM"; + case OperatorAttributeKey::COMBINE_DEGREE: + return "COMBINE_DEGREE"; + case OperatorAttributeKey::NUM_INPUTS: + return "NUM_INPUTS"; + default: + std::ostringstream oss; + oss << "Unknown OperatorAttributeKey value " << static_cast(x); + throw std::runtime_error(oss.str()); + } +} +std::ostream &operator<<(std::ostream &s, OperatorAttributeKey x) { + return s << fmt::to_string(x); +} +} // namespace FlexFlow +namespace FlexFlow { +void to_json(::nlohmann::json &j, OperatorAttributeKey x) { + switch (x) { + case OperatorAttributeKey::OP_TYPE: + j = "OP_TYPE"; + break; + case OperatorAttributeKey::USE_BIAS: + j = "USE_BIAS"; + break; + case OperatorAttributeKey::GROUPS: + j = "GROUPS"; + break; + case OperatorAttributeKey::POOL_TYPE: + j = "POOL_TYPE"; + break; + case OperatorAttributeKey::KERNEL_H: + j = "KERNEL_H"; + break; + case OperatorAttributeKey::KERNEL_W: + j = "KERNEL_W"; + break; + case OperatorAttributeKey::DATA_TYPE: + j = "DATA_TYPE"; + break; + case OperatorAttributeKey::SCALAR: + j = "SCALAR"; + break; + case OperatorAttributeKey::STRIDE_H: + j = "STRIDE_H"; + break; + case OperatorAttributeKey::STRIDE_W: + j = "STRIDE_W"; + break; + case OperatorAttributeKey::PADDING_H: + j = "PADDING_H"; + break; + case OperatorAttributeKey::PADDING_W: + j = "PADDING_W"; + break; + case OperatorAttributeKey::AGGR: + j = "AGGR"; + break; + case OperatorAttributeKey::NUM_ENTRIES: + j = "NUM_ENTRIES"; + break; + case OperatorAttributeKey::OUT_CHANNELS: + j = "OUT_CHANNELS"; + break; + case OperatorAttributeKey::ACTIVATION: + j = "ACTIVATION"; + break; + case OperatorAttributeKey::NUMDIM: + j = "NUMDIM"; + break; + case OperatorAttributeKey::AXIS: + j = "AXIS"; + break; + case OperatorAttributeKey::PERMUTATION: + j = "PERMUTATION"; + break; + case OperatorAttributeKey::OUTSHUFFLE: + j = "OUTSHUFFLE"; + break; + case OperatorAttributeKey::MERGE_GCONV_COUNT: + j = "MERGE_GCONV_COUNT"; + break; + case OperatorAttributeKey::AXES: + j = "AXES"; + break; + case OperatorAttributeKey::KEEP_DIMS: + j = "KEEP_DIMS"; + break; + case OperatorAttributeKey::EPSILON: + j = "EPSILON"; + break; + case OperatorAttributeKey::PARALLEL_OP_DIM: + j = "PARALLEL_OP_DIM"; + break; + case OperatorAttributeKey::PARALLEL_OP_DEGREE: + j = "PARALLEL_OP_DEGREE"; + break; + case OperatorAttributeKey::SOFTMAX_DIM: + j = "SOFTMAX_DIM"; + break; + case OperatorAttributeKey::NUM_HEADS: + j = "NUM_HEADS"; 
+ break; + case OperatorAttributeKey::PARALLEL_DIM: + j = "PARALLEL_DIM"; + break; + case OperatorAttributeKey::PARALLEL_DEGREE: + j = "PARALLEL_DEGREE"; + break; + case OperatorAttributeKey::PAD: + j = "PAD"; + break; + case OperatorAttributeKey::EMBED_DIM: + j = "EMBED_DIM"; + break; + case OperatorAttributeKey::KDIM: + j = "KDIM"; + break; + case OperatorAttributeKey::VDIM: + j = "VDIM"; + break; + case OperatorAttributeKey::DROPOUT: + j = "DROPOUT"; + break; + case OperatorAttributeKey::BIAS: + j = "BIAS"; + break; + case OperatorAttributeKey::ADD_BIAS_KV: + j = "ADD_BIAS_KV"; + break; + case OperatorAttributeKey::ADD_ZERO_ATTN: + j = "ADD_ZERO_ATTN"; + break; + case OperatorAttributeKey::A_SEQ_LENGTH_DIM: + j = "A_SEQ_LENGTH_DIM"; + break; + case OperatorAttributeKey::B_SEQ_LENGTH_DIM: + j = "B_SEQ_LENGTH_DIM"; + break; + case OperatorAttributeKey::RELU: + j = "RELU"; + break; + case OperatorAttributeKey::TARGET_DIMS: + j = "TARGET_DIMS"; + break; + case OperatorAttributeKey::RATE: + j = "RATE"; + break; + case OperatorAttributeKey::SEED: + j = "SEED"; + break; + case OperatorAttributeKey::SHOULD_BROADCAST_LHS: + j = "SHOULD_BROADCAST_LHS"; + break; + case OperatorAttributeKey::SHOULD_BROADCAST_RHS: + j = "SHOULD_BROADCAST_RHS"; + break; + case OperatorAttributeKey::DIM: + j = "DIM"; + break; + case OperatorAttributeKey::ELEMENTWISE_AFFINE: + j = "ELEMENTWISE_AFFINE"; + break; + case OperatorAttributeKey::REGULARIZER: + j = "REGULARIZER"; + break; + case OperatorAttributeKey::SHAPE: + j = "SHAPE"; + break; + case OperatorAttributeKey::SPLITS: + j = "SPLITS"; + break; + case OperatorAttributeKey::K: + j = "K"; + break; + case OperatorAttributeKey::SORTED: + j = "SORTED"; + break; + case OperatorAttributeKey::COMBINE_DIM: + j = "COMBINE_DIM"; + break; + case OperatorAttributeKey::COMBINE_DEGREE: + j = "COMBINE_DEGREE"; + break; + case OperatorAttributeKey::NUM_INPUTS: + j = "NUM_INPUTS"; + break; + default: + std::ostringstream oss; + oss << "Unknown OperatorAttributeKey value " << static_cast(x); + throw std::runtime_error(oss.str()); + } +} +void from_json(::nlohmann::json const &j, OperatorAttributeKey &x) { + std::string as_str = j.get(); + if (as_str == "OP_TYPE") { + x = OperatorAttributeKey::OP_TYPE; + } else if (as_str == "USE_BIAS") { + x = OperatorAttributeKey::USE_BIAS; + } else if (as_str == "GROUPS") { + x = OperatorAttributeKey::GROUPS; + } else if (as_str == "POOL_TYPE") { + x = OperatorAttributeKey::POOL_TYPE; + } else if (as_str == "KERNEL_H") { + x = OperatorAttributeKey::KERNEL_H; + } else if (as_str == "KERNEL_W") { + x = OperatorAttributeKey::KERNEL_W; + } else if (as_str == "DATA_TYPE") { + x = OperatorAttributeKey::DATA_TYPE; + } else if (as_str == "SCALAR") { + x = OperatorAttributeKey::SCALAR; + } else if (as_str == "STRIDE_H") { + x = OperatorAttributeKey::STRIDE_H; + } else if (as_str == "STRIDE_W") { + x = OperatorAttributeKey::STRIDE_W; + } else if (as_str == "PADDING_H") { + x = OperatorAttributeKey::PADDING_H; + } else if (as_str == "PADDING_W") { + x = OperatorAttributeKey::PADDING_W; + } else if (as_str == "AGGR") { + x = OperatorAttributeKey::AGGR; + } else if (as_str == "NUM_ENTRIES") { + x = OperatorAttributeKey::NUM_ENTRIES; + } else if (as_str == "OUT_CHANNELS") { + x = OperatorAttributeKey::OUT_CHANNELS; + } else if (as_str == "ACTIVATION") { + x = OperatorAttributeKey::ACTIVATION; + } else if (as_str == "NUMDIM") { + x = OperatorAttributeKey::NUMDIM; + } else if (as_str == "AXIS") { + x = OperatorAttributeKey::AXIS; + } else if (as_str == 
"PERMUTATION") { + x = OperatorAttributeKey::PERMUTATION; + } else if (as_str == "OUTSHUFFLE") { + x = OperatorAttributeKey::OUTSHUFFLE; + } else if (as_str == "MERGE_GCONV_COUNT") { + x = OperatorAttributeKey::MERGE_GCONV_COUNT; + } else if (as_str == "AXES") { + x = OperatorAttributeKey::AXES; + } else if (as_str == "KEEP_DIMS") { + x = OperatorAttributeKey::KEEP_DIMS; + } else if (as_str == "EPSILON") { + x = OperatorAttributeKey::EPSILON; + } else if (as_str == "PARALLEL_OP_DIM") { + x = OperatorAttributeKey::PARALLEL_OP_DIM; + } else if (as_str == "PARALLEL_OP_DEGREE") { + x = OperatorAttributeKey::PARALLEL_OP_DEGREE; + } else if (as_str == "SOFTMAX_DIM") { + x = OperatorAttributeKey::SOFTMAX_DIM; + } else if (as_str == "NUM_HEADS") { + x = OperatorAttributeKey::NUM_HEADS; + } else if (as_str == "PARALLEL_DIM") { + x = OperatorAttributeKey::PARALLEL_DIM; + } else if (as_str == "PARALLEL_DEGREE") { + x = OperatorAttributeKey::PARALLEL_DEGREE; + } else if (as_str == "PAD") { + x = OperatorAttributeKey::PAD; + } else if (as_str == "EMBED_DIM") { + x = OperatorAttributeKey::EMBED_DIM; + } else if (as_str == "KDIM") { + x = OperatorAttributeKey::KDIM; + } else if (as_str == "VDIM") { + x = OperatorAttributeKey::VDIM; + } else if (as_str == "DROPOUT") { + x = OperatorAttributeKey::DROPOUT; + } else if (as_str == "BIAS") { + x = OperatorAttributeKey::BIAS; + } else if (as_str == "ADD_BIAS_KV") { + x = OperatorAttributeKey::ADD_BIAS_KV; + } else if (as_str == "ADD_ZERO_ATTN") { + x = OperatorAttributeKey::ADD_ZERO_ATTN; + } else if (as_str == "A_SEQ_LENGTH_DIM") { + x = OperatorAttributeKey::A_SEQ_LENGTH_DIM; + } else if (as_str == "B_SEQ_LENGTH_DIM") { + x = OperatorAttributeKey::B_SEQ_LENGTH_DIM; + } else if (as_str == "RELU") { + x = OperatorAttributeKey::RELU; + } else if (as_str == "TARGET_DIMS") { + x = OperatorAttributeKey::TARGET_DIMS; + } else if (as_str == "RATE") { + x = OperatorAttributeKey::RATE; + } else if (as_str == "SEED") { + x = OperatorAttributeKey::SEED; + } else if (as_str == "SHOULD_BROADCAST_LHS") { + x = OperatorAttributeKey::SHOULD_BROADCAST_LHS; + } else if (as_str == "SHOULD_BROADCAST_RHS") { + x = OperatorAttributeKey::SHOULD_BROADCAST_RHS; + } else if (as_str == "DIM") { + x = OperatorAttributeKey::DIM; + } else if (as_str == "ELEMENTWISE_AFFINE") { + x = OperatorAttributeKey::ELEMENTWISE_AFFINE; + } else if (as_str == "REGULARIZER") { + x = OperatorAttributeKey::REGULARIZER; + } else if (as_str == "SHAPE") { + x = OperatorAttributeKey::SHAPE; + } else if (as_str == "SPLITS") { + x = OperatorAttributeKey::SPLITS; + } else if (as_str == "K") { + x = OperatorAttributeKey::K; + } else if (as_str == "SORTED") { + x = OperatorAttributeKey::SORTED; + } else if (as_str == "COMBINE_DIM") { + x = OperatorAttributeKey::COMBINE_DIM; + } else if (as_str == "COMBINE_DEGREE") { + x = OperatorAttributeKey::COMBINE_DEGREE; + } else if (as_str == "NUM_INPUTS") { + x = OperatorAttributeKey::NUM_INPUTS; + } else { + std::ostringstream oss; + oss << "Unknown OperatorAttributeKey value " << as_str; + throw std::runtime_error(oss.str()); + } +} +} // namespace FlexFlow +namespace rc { +Gen + Arbitrary::arbitrary() { + return gen::element( + FlexFlow::OperatorAttributeKey::OP_TYPE, + FlexFlow::OperatorAttributeKey::USE_BIAS, + FlexFlow::OperatorAttributeKey::GROUPS, + FlexFlow::OperatorAttributeKey::POOL_TYPE, + FlexFlow::OperatorAttributeKey::KERNEL_H, + FlexFlow::OperatorAttributeKey::KERNEL_W, + FlexFlow::OperatorAttributeKey::DATA_TYPE, + FlexFlow::OperatorAttributeKey::SCALAR, + 
FlexFlow::OperatorAttributeKey::STRIDE_H, + FlexFlow::OperatorAttributeKey::STRIDE_W, + FlexFlow::OperatorAttributeKey::PADDING_H, + FlexFlow::OperatorAttributeKey::PADDING_W, + FlexFlow::OperatorAttributeKey::AGGR, + FlexFlow::OperatorAttributeKey::NUM_ENTRIES, + FlexFlow::OperatorAttributeKey::OUT_CHANNELS, + FlexFlow::OperatorAttributeKey::ACTIVATION, + FlexFlow::OperatorAttributeKey::NUMDIM, + FlexFlow::OperatorAttributeKey::AXIS, + FlexFlow::OperatorAttributeKey::PERMUTATION, + FlexFlow::OperatorAttributeKey::OUTSHUFFLE, + FlexFlow::OperatorAttributeKey::MERGE_GCONV_COUNT, + FlexFlow::OperatorAttributeKey::AXES, + FlexFlow::OperatorAttributeKey::KEEP_DIMS, + FlexFlow::OperatorAttributeKey::EPSILON, + FlexFlow::OperatorAttributeKey::PARALLEL_OP_DIM, + FlexFlow::OperatorAttributeKey::PARALLEL_OP_DEGREE, + FlexFlow::OperatorAttributeKey::SOFTMAX_DIM, + FlexFlow::OperatorAttributeKey::NUM_HEADS, + FlexFlow::OperatorAttributeKey::PARALLEL_DIM, + FlexFlow::OperatorAttributeKey::PARALLEL_DEGREE, + FlexFlow::OperatorAttributeKey::PAD, + FlexFlow::OperatorAttributeKey::EMBED_DIM, + FlexFlow::OperatorAttributeKey::KDIM, + FlexFlow::OperatorAttributeKey::VDIM, + FlexFlow::OperatorAttributeKey::DROPOUT, + FlexFlow::OperatorAttributeKey::BIAS, + FlexFlow::OperatorAttributeKey::ADD_BIAS_KV, + FlexFlow::OperatorAttributeKey::ADD_ZERO_ATTN, + FlexFlow::OperatorAttributeKey::A_SEQ_LENGTH_DIM, + FlexFlow::OperatorAttributeKey::B_SEQ_LENGTH_DIM, + FlexFlow::OperatorAttributeKey::RELU, + FlexFlow::OperatorAttributeKey::TARGET_DIMS, + FlexFlow::OperatorAttributeKey::RATE, + FlexFlow::OperatorAttributeKey::SEED, + FlexFlow::OperatorAttributeKey::SHOULD_BROADCAST_LHS, + FlexFlow::OperatorAttributeKey::SHOULD_BROADCAST_RHS, + FlexFlow::OperatorAttributeKey::DIM, + FlexFlow::OperatorAttributeKey::ELEMENTWISE_AFFINE, + FlexFlow::OperatorAttributeKey::REGULARIZER, + FlexFlow::OperatorAttributeKey::SHAPE, + FlexFlow::OperatorAttributeKey::SPLITS, + FlexFlow::OperatorAttributeKey::K, + FlexFlow::OperatorAttributeKey::SORTED, + FlexFlow::OperatorAttributeKey::COMBINE_DIM, + FlexFlow::OperatorAttributeKey::COMBINE_DEGREE, + FlexFlow::OperatorAttributeKey::NUM_INPUTS); +} +} // namespace rc diff --git a/lib/substitutions/src/substitutions/operator_pattern/operator_attribute_list_access.dtg.cc b/lib/substitutions/src/substitutions/operator_pattern/operator_attribute_list_access.dtg.cc new file mode 100644 index 0000000000..71b71d4a51 --- /dev/null +++ b/lib/substitutions/src/substitutions/operator_pattern/operator_attribute_list_access.dtg.cc @@ -0,0 +1,101 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! 
+// If you would like to modify this datatype, instead modify +// lib/substitutions/include/substitutions/operator_pattern/operator_attribute_list_access.struct.toml +/* proj-data +{ + "generated_from": "1dc90d1e823f05b82c1a5ff433fbf000" +} +*/ + +#include "substitutions/operator_pattern/operator_attribute_list_access.dtg.h" + +#include "substitutions/operator_pattern/operator_attribute_key.dtg.h" +#include + +namespace FlexFlow { +OperatorAttributeListIndexAccess::OperatorAttributeListIndexAccess( + ::FlexFlow::OperatorAttributeKey const &attribute_key, int const &index) + : attribute_key(attribute_key), index(index) {} +bool OperatorAttributeListIndexAccess::operator==( + OperatorAttributeListIndexAccess const &other) const { + return std::tie(this->attribute_key, this->index) == + std::tie(other.attribute_key, other.index); +} +bool OperatorAttributeListIndexAccess::operator!=( + OperatorAttributeListIndexAccess const &other) const { + return std::tie(this->attribute_key, this->index) != + std::tie(other.attribute_key, other.index); +} +bool OperatorAttributeListIndexAccess::operator<( + OperatorAttributeListIndexAccess const &other) const { + return std::tie(this->attribute_key, this->index) < + std::tie(other.attribute_key, other.index); +} +bool OperatorAttributeListIndexAccess::operator>( + OperatorAttributeListIndexAccess const &other) const { + return std::tie(this->attribute_key, this->index) > + std::tie(other.attribute_key, other.index); +} +bool OperatorAttributeListIndexAccess::operator<=( + OperatorAttributeListIndexAccess const &other) const { + return std::tie(this->attribute_key, this->index) <= + std::tie(other.attribute_key, other.index); +} +bool OperatorAttributeListIndexAccess::operator>=( + OperatorAttributeListIndexAccess const &other) const { + return std::tie(this->attribute_key, this->index) >= + std::tie(other.attribute_key, other.index); +} +} // namespace FlexFlow + +namespace std { +size_t hash::operator()( + FlexFlow::OperatorAttributeListIndexAccess const &x) const { + size_t result = 0; + result ^= std::hash<::FlexFlow::OperatorAttributeKey>{}(x.attribute_key) + + 0x9e3779b9 + (result << 6) + (result >> 2); + result ^= + std::hash{}(x.index) + 0x9e3779b9 + (result << 6) + (result >> 2); + return result; +} +} // namespace std + +namespace nlohmann { +FlexFlow::OperatorAttributeListIndexAccess + adl_serializer::from_json( + json const &j) { + return { + j.at("attribute_key").template get<::FlexFlow::OperatorAttributeKey>(), + j.at("index").template get()}; +} +void adl_serializer::to_json( + json &j, FlexFlow::OperatorAttributeListIndexAccess const &v) { + j["__type"] = "OperatorAttributeListIndexAccess"; + j["attribute_key"] = v.attribute_key; + j["index"] = v.index; +} +} // namespace nlohmann + +namespace rc { +Gen + Arbitrary::arbitrary() { + return gen::construct( + gen::arbitrary<::FlexFlow::OperatorAttributeKey>(), + gen::arbitrary()); +} +} // namespace rc + +namespace FlexFlow { +std::string format_as(OperatorAttributeListIndexAccess const &x) { + std::ostringstream oss; + oss << ""; + return oss.str(); +} +std::ostream &operator<<(std::ostream &s, + OperatorAttributeListIndexAccess const &x) { + return s << fmt::to_string(x); +} +} // namespace FlexFlow diff --git a/lib/substitutions/src/substitutions/operator_pattern/operator_attribute_list_size.dtg.cc b/lib/substitutions/src/substitutions/operator_pattern/operator_attribute_list_size.dtg.cc new file mode 100644 index 0000000000..eb7ae28131 --- /dev/null +++ 
b/lib/substitutions/src/substitutions/operator_pattern/operator_attribute_list_size.dtg.cc @@ -0,0 +1,88 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/substitutions/include/substitutions/operator_pattern/operator_attribute_list_size.struct.toml +/* proj-data +{ + "generated_from": "30999ad6b0603e380bc33d32fa088e45" +} +*/ + +#include "substitutions/operator_pattern/operator_attribute_list_size.dtg.h" + +#include "substitutions/operator_pattern/operator_attribute_key.dtg.h" +#include + +namespace FlexFlow { +OperatorAttributeListSize::OperatorAttributeListSize( + ::FlexFlow::OperatorAttributeKey const &attribute_key) + : attribute_key(attribute_key) {} +bool OperatorAttributeListSize::operator==( + OperatorAttributeListSize const &other) const { + return std::tie(this->attribute_key) == std::tie(other.attribute_key); +} +bool OperatorAttributeListSize::operator!=( + OperatorAttributeListSize const &other) const { + return std::tie(this->attribute_key) != std::tie(other.attribute_key); +} +bool OperatorAttributeListSize::operator<( + OperatorAttributeListSize const &other) const { + return std::tie(this->attribute_key) < std::tie(other.attribute_key); +} +bool OperatorAttributeListSize::operator>( + OperatorAttributeListSize const &other) const { + return std::tie(this->attribute_key) > std::tie(other.attribute_key); +} +bool OperatorAttributeListSize::operator<=( + OperatorAttributeListSize const &other) const { + return std::tie(this->attribute_key) <= std::tie(other.attribute_key); +} +bool OperatorAttributeListSize::operator>=( + OperatorAttributeListSize const &other) const { + return std::tie(this->attribute_key) >= std::tie(other.attribute_key); +} +} // namespace FlexFlow + +namespace std { +size_t hash::operator()( + FlexFlow::OperatorAttributeListSize const &x) const { + size_t result = 0; + result ^= std::hash<::FlexFlow::OperatorAttributeKey>{}(x.attribute_key) + + 0x9e3779b9 + (result << 6) + (result >> 2); + return result; +} +} // namespace std + +namespace nlohmann { +FlexFlow::OperatorAttributeListSize + adl_serializer::from_json( + json const &j) { + return { + j.at("attribute_key").template get<::FlexFlow::OperatorAttributeKey>()}; +} +void adl_serializer::to_json( + json &j, FlexFlow::OperatorAttributeListSize const &v) { + j["__type"] = "OperatorAttributeListSize"; + j["attribute_key"] = v.attribute_key; +} +} // namespace nlohmann + +namespace rc { +Gen + Arbitrary::arbitrary() { + return gen::construct( + gen::arbitrary<::FlexFlow::OperatorAttributeKey>()); +} +} // namespace rc + +namespace FlexFlow { +std::string format_as(OperatorAttributeListSize const &x) { + std::ostringstream oss; + oss << ""; + return oss.str(); +} +std::ostream &operator<<(std::ostream &s, OperatorAttributeListSize const &x) { + return s << fmt::to_string(x); +} +} // namespace FlexFlow diff --git a/lib/substitutions/src/substitutions/operator_pattern/operator_attribute_pattern.dtg.cc b/lib/substitutions/src/substitutions/operator_pattern/operator_attribute_pattern.dtg.cc new file mode 100644 index 0000000000..5eaf54bb5f --- /dev/null +++ b/lib/substitutions/src/substitutions/operator_pattern/operator_attribute_pattern.dtg.cc @@ -0,0 +1,73 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! 
+// If you would like to modify this datatype, instead modify +// lib/substitutions/include/substitutions/operator_pattern/operator_attribute_pattern.struct.toml +/* proj-data +{ + "generated_from": "968d7a3e93303a7fa7482bbcd50246b6" +} +*/ + +#include "substitutions/operator_pattern/operator_attribute_pattern.dtg.h" + +#include "substitutions/operator_pattern/operator_attribute_constraint.dtg.h" +#include "utils/fmt.h" +#include +#include + +namespace FlexFlow { +OperatorAttributePattern::OperatorAttributePattern( + std::unordered_set<::FlexFlow::OperatorAttributeConstraint> const + &attribute_constraints) + : attribute_constraints(attribute_constraints) {} +bool OperatorAttributePattern::operator==( + OperatorAttributePattern const &other) const { + return std::tie(this->attribute_constraints) == + std::tie(other.attribute_constraints); +} +bool OperatorAttributePattern::operator!=( + OperatorAttributePattern const &other) const { + return std::tie(this->attribute_constraints) != + std::tie(other.attribute_constraints); +} +} // namespace FlexFlow + +namespace std { +size_t hash::operator()( + FlexFlow::OperatorAttributePattern const &x) const { + size_t result = 0; + result ^= + std::hash>{}( + x.attribute_constraints) + + 0x9e3779b9 + (result << 6) + (result >> 2); + return result; +} +} // namespace std + +namespace nlohmann { +FlexFlow::OperatorAttributePattern + adl_serializer::from_json( + json const &j) { + return { + j.at("attribute_constraints") + .template get< + std::unordered_set<::FlexFlow::OperatorAttributeConstraint>>()}; +} +void adl_serializer::to_json( + json &j, FlexFlow::OperatorAttributePattern const &v) { + j["__type"] = "OperatorAttributePattern"; + j["attribute_constraints"] = v.attribute_constraints; +} +} // namespace nlohmann + +namespace FlexFlow { +std::string format_as(OperatorAttributePattern const &x) { + std::ostringstream oss; + oss << ""; + return oss.str(); +} +std::ostream &operator<<(std::ostream &s, OperatorAttributePattern const &x) { + return s << fmt::to_string(x); +} +} // namespace FlexFlow diff --git a/lib/substitutions/src/substitutions/operator_pattern/operator_attribute_value.dtg.cc b/lib/substitutions/src/substitutions/operator_pattern/operator_attribute_value.dtg.cc new file mode 100644 index 0000000000..376a9c2ce8 --- /dev/null +++ b/lib/substitutions/src/substitutions/operator_pattern/operator_attribute_value.dtg.cc @@ -0,0 +1,292 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! 
+// If you would like to modify this datatype, instead modify +// lib/substitutions/include/substitutions/operator_pattern/operator_attribute_value.variant.toml +/* proj-data +{ + "generated_from": "de14592f1f4bcfb52689bc95e9d3b55f" +} +*/ + +#include "substitutions/operator_pattern/operator_attribute_value.dtg.h" + +#include "fmt/format.h" +#include +#include + +namespace FlexFlow { +OperatorAttributeValue::OperatorAttributeValue(int const &v) : raw_variant(v) {} +OperatorAttributeValue::OperatorAttributeValue(bool const &v) + : raw_variant(v) {} +OperatorAttributeValue::OperatorAttributeValue(std::vector const &v) + : raw_variant(v) {} +OperatorAttributeValue::OperatorAttributeValue( + std::vector<::FlexFlow::ff_dim_t> const &v) + : raw_variant(v) {} +OperatorAttributeValue::OperatorAttributeValue( + ::FlexFlow::OperatorType const &v) + : raw_variant(v) {} +OperatorAttributeValue::OperatorAttributeValue(::FlexFlow::Activation const &v) + : raw_variant(v) {} +OperatorAttributeValue::OperatorAttributeValue(::FlexFlow::ff_dim_t const &v) + : raw_variant(v) {} +OperatorAttributeValue::OperatorAttributeValue(size_t const &v) + : raw_variant(v) {} +OperatorAttributeValue::OperatorAttributeValue(::FlexFlow::AggregateOp const &v) + : raw_variant(v) {} +OperatorAttributeValue::OperatorAttributeValue( + std::optional<::FlexFlow::RegularizerAttrs> const &v) + : raw_variant(v) {} +OperatorAttributeValue::OperatorAttributeValue(::FlexFlow::PoolOp const &v) + : raw_variant(v) {} +OperatorAttributeValue::OperatorAttributeValue(::FlexFlow::TensorShape const &v) + : raw_variant(v) {} +OperatorAttributeValue::OperatorAttributeValue(::FlexFlow::DataType const &v) + : raw_variant(v) {} +bool OperatorAttributeValue::operator==( + OperatorAttributeValue const &other) const { + return this->raw_variant == other.raw_variant; +} +bool OperatorAttributeValue::operator!=( + OperatorAttributeValue const &other) const { + return this->raw_variant != other.raw_variant; +} +bool OperatorAttributeValue::operator<( + OperatorAttributeValue const &other) const { + return this->raw_variant < other.raw_variant; +} +bool OperatorAttributeValue::operator>( + OperatorAttributeValue const &other) const { + return this->raw_variant > other.raw_variant; +} +bool OperatorAttributeValue::operator<=( + OperatorAttributeValue const &other) const { + return this->raw_variant <= other.raw_variant; +} +bool OperatorAttributeValue::operator>=( + OperatorAttributeValue const &other) const { + return this->raw_variant >= other.raw_variant; +} +} // namespace FlexFlow +namespace std { +size_t hash<::FlexFlow::OperatorAttributeValue>::operator()( + ::FlexFlow::OperatorAttributeValue const &x) const { + return std::hash, + std::vector<::FlexFlow::ff_dim_t>, + ::FlexFlow::OperatorType, + ::FlexFlow::Activation, + ::FlexFlow::ff_dim_t, + size_t, + ::FlexFlow::AggregateOp, + std::optional<::FlexFlow::RegularizerAttrs>, + ::FlexFlow::PoolOp, + ::FlexFlow::TensorShape, + ::FlexFlow::DataType>>{}(x.raw_variant); +} +} // namespace std +namespace nlohmann { +::FlexFlow::OperatorAttributeValue + adl_serializer<::FlexFlow::OperatorAttributeValue>::from_json( + json const &j) { + std::string key = j.at("type").template get(); + if (key == "int") { + return ::FlexFlow::OperatorAttributeValue{ + j.at("value").template get()}; + } else if (key == "bool") { + return ::FlexFlow::OperatorAttributeValue{ + j.at("value").template get()}; + } else if (key == "std::vector") { + return ::FlexFlow::OperatorAttributeValue{ + j.at("value").template get>()}; + } 
else if (key == "std::vector<::FlexFlow::ff_dim_t>") { + return ::FlexFlow::OperatorAttributeValue{ + j.at("value").template get>()}; + } else if (key == "::FlexFlow::OperatorType") { + return ::FlexFlow::OperatorAttributeValue{ + j.at("value").template get<::FlexFlow::OperatorType>()}; + } else if (key == "::FlexFlow::Activation") { + return ::FlexFlow::OperatorAttributeValue{ + j.at("value").template get<::FlexFlow::Activation>()}; + } else if (key == "::FlexFlow::ff_dim_t") { + return ::FlexFlow::OperatorAttributeValue{ + j.at("value").template get<::FlexFlow::ff_dim_t>()}; + } else if (key == "size_t") { + return ::FlexFlow::OperatorAttributeValue{ + j.at("value").template get()}; + } else if (key == "::FlexFlow::AggregateOp") { + return ::FlexFlow::OperatorAttributeValue{ + j.at("value").template get<::FlexFlow::AggregateOp>()}; + } else if (key == "std::optional<::FlexFlow::RegularizerAttrs>") { + return ::FlexFlow::OperatorAttributeValue{ + j.at("value") + .template get>()}; + } else if (key == "::FlexFlow::PoolOp") { + return ::FlexFlow::OperatorAttributeValue{ + j.at("value").template get<::FlexFlow::PoolOp>()}; + } else if (key == "::FlexFlow::TensorShape") { + return ::FlexFlow::OperatorAttributeValue{ + j.at("value").template get<::FlexFlow::TensorShape>()}; + } else if (key == "::FlexFlow::DataType") { + return ::FlexFlow::OperatorAttributeValue{ + j.at("value").template get<::FlexFlow::DataType>()}; + } else { + throw std::runtime_error(fmt::format("Unknown type key {}", key)); + } +} +void adl_serializer<::FlexFlow::OperatorAttributeValue>::to_json( + json &j, ::FlexFlow::OperatorAttributeValue const &x) { + j["__type"] = "OperatorAttributeValue"; + switch (x.index()) { + case 0: { + j["type"] = "int"; + j["value"] = x.get(); + break; + } + case 1: { + j["type"] = "bool"; + j["value"] = x.get(); + break; + } + case 2: { + j["type"] = "std::vector"; + j["value"] = x.get>(); + break; + } + case 3: { + j["type"] = "std::vector<::FlexFlow::ff_dim_t>"; + j["value"] = x.get>(); + break; + } + case 4: { + j["type"] = "::FlexFlow::OperatorType"; + j["value"] = x.get<::FlexFlow::OperatorType>(); + break; + } + case 5: { + j["type"] = "::FlexFlow::Activation"; + j["value"] = x.get<::FlexFlow::Activation>(); + break; + } + case 6: { + j["type"] = "::FlexFlow::ff_dim_t"; + j["value"] = x.get<::FlexFlow::ff_dim_t>(); + break; + } + case 7: { + j["type"] = "size_t"; + j["value"] = x.get(); + break; + } + case 8: { + j["type"] = "::FlexFlow::AggregateOp"; + j["value"] = x.get<::FlexFlow::AggregateOp>(); + break; + } + case 9: { + j["type"] = "std::optional<::FlexFlow::RegularizerAttrs>"; + j["value"] = x.get>(); + break; + } + case 10: { + j["type"] = "::FlexFlow::PoolOp"; + j["value"] = x.get<::FlexFlow::PoolOp>(); + break; + } + case 11: { + j["type"] = "::FlexFlow::TensorShape"; + j["value"] = x.get<::FlexFlow::TensorShape>(); + break; + } + case 12: { + j["type"] = "::FlexFlow::DataType"; + j["value"] = x.get<::FlexFlow::DataType>(); + break; + } + default: { + throw std::runtime_error(fmt::format( + "Unknown index {} for type OperatorAttributeValue", x.index())); + } + } +} +} // namespace nlohmann +namespace FlexFlow { +std::string format_as(::FlexFlow::OperatorAttributeValue const &x) { + std::ostringstream oss; + switch (x.index()) { + case 0: { + oss << ""; + break; + } + case 1: { + oss << ""; + break; + } + case 2: { + oss << "=" + << x.get>() << ">"; + break; + } + case 3: { + oss << "=" + << x.get>() << ">"; + break; + } + case 4: { + oss << ""; + break; + } + case 5: { + oss 
<< ""; + break; + } + case 6: { + oss << ""; + break; + } + case 7: { + oss << ""; + break; + } + case 8: { + oss << ""; + break; + } + case 9: { + oss << "=" + << x.get>() << ">"; + break; + } + case 10: { + oss << ""; + break; + } + case 11: { + oss << ""; + break; + } + case 12: { + oss << ""; + break; + } + default: { + throw std::runtime_error(fmt::format( + "Unknown index {} for type OperatorAttributeValue", x.index())); + break; + } + } + return oss.str(); +} +std::ostream &operator<<(std::ostream &s, + ::FlexFlow::OperatorAttributeValue const &x) { + return s << fmt::to_string(x); +} +} // namespace FlexFlow diff --git a/lib/substitutions/src/substitutions/operator_pattern/satisfies_constraint.cc b/lib/substitutions/src/substitutions/operator_pattern/satisfies_constraint.cc new file mode 100644 index 0000000000..ae42515cc8 --- /dev/null +++ b/lib/substitutions/src/substitutions/operator_pattern/satisfies_constraint.cc @@ -0,0 +1,26 @@ +#include "substitutions/operator_pattern/satisfies_constraint.h" +#include "substitutions/operator_pattern/operator_attribute_expr.h" + +namespace FlexFlow { + +bool operator_satisfies_constraint( + PCGOperatorAttrs const &attrs, + OperatorAttributeConstraint const &constraint) { + std::optional expr_val = + evaluate_attribute_expr(attrs, constraint.attribute_expr); + + if (!expr_val.has_value()) { + return false; + } + + switch (constraint.constraint_type) { + case ConstraintType::EQUAL: + return expr_val.value() == constraint.attribute_value; + default: + throw mk_runtime_error( + fmt::format("Unknown constraint type {}", + static_cast(constraint.constraint_type))); + } +} + +} // namespace FlexFlow diff --git a/lib/substitutions/src/substitutions/operator_pattern/satisfies_pattern.cc b/lib/substitutions/src/substitutions/operator_pattern/satisfies_pattern.cc new file mode 100644 index 0000000000..60ab363cc6 --- /dev/null +++ b/lib/substitutions/src/substitutions/operator_pattern/satisfies_pattern.cc @@ -0,0 +1,14 @@ +#include "substitutions/operator_pattern/satisfies_pattern.h" +#include "substitutions/operator_pattern/satisfies_constraint.h" + +namespace FlexFlow { + +bool operator_satisfies_pattern(PCGOperatorAttrs const &attrs, + OperatorAttributePattern const &pattern) { + return all_of(pattern.attribute_constraints, + [&](OperatorAttributeConstraint const &c) { + return operator_satisfies_constraint(attrs, c); + }); +} + +} // namespace FlexFlow diff --git a/lib/substitutions/src/substitutions/output_graph/attr_constant.dtg.cc b/lib/substitutions/src/substitutions/output_graph/attr_constant.dtg.cc new file mode 100644 index 0000000000..f20afc1164 --- /dev/null +++ b/lib/substitutions/src/substitutions/output_graph/attr_constant.dtg.cc @@ -0,0 +1,59 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! 
diff --git a/lib/substitutions/src/substitutions/output_graph/attr_constant.dtg.cc b/lib/substitutions/src/substitutions/output_graph/attr_constant.dtg.cc
new file mode 100644
index 0000000000..f20afc1164
--- /dev/null
+++ b/lib/substitutions/src/substitutions/output_graph/attr_constant.dtg.cc
@@ -0,0 +1,59 @@
+// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT!
+// If you would like to modify this datatype, instead modify
+// lib/substitutions/include/substitutions/output_graph/attr_constant.struct.toml
+/* proj-data
+{
+  "generated_from": "1e5beabcb8e3657d8fe9c9c8b1310cb1"
+}
+*/
+
+#include "substitutions/output_graph/attr_constant.dtg.h"
+
+#include "substitutions/operator_pattern/operator_attribute_value.dtg.h"
+#include <sstream>
+
+namespace FlexFlow {
+AttrConstant::AttrConstant(::FlexFlow::OperatorAttributeValue const &value)
+    : value(value) {}
+bool AttrConstant::operator==(AttrConstant const &other) const {
+  return std::tie(this->value) == std::tie(other.value);
+}
+bool AttrConstant::operator!=(AttrConstant const &other) const {
+  return std::tie(this->value) != std::tie(other.value);
+}
+bool AttrConstant::operator<(AttrConstant const &other) const {
+  return std::tie(this->value) < std::tie(other.value);
+}
+bool AttrConstant::operator>(AttrConstant const &other) const {
+  return std::tie(this->value) > std::tie(other.value);
+}
+bool AttrConstant::operator<=(AttrConstant const &other) const {
+  return std::tie(this->value) <= std::tie(other.value);
+}
+bool AttrConstant::operator>=(AttrConstant const &other) const {
+  return std::tie(this->value) >= std::tie(other.value);
+}
+} // namespace FlexFlow
+
+namespace std {
+size_t hash<FlexFlow::AttrConstant>::operator()(
+    FlexFlow::AttrConstant const &x) const {
+  size_t result = 0;
+  result ^= std::hash<::FlexFlow::OperatorAttributeValue>{}(x.value) +
+            0x9e3779b9 + (result << 6) + (result >> 2);
+  return result;
+}
+} // namespace std
+
+namespace FlexFlow {
+std::string format_as(AttrConstant const &x) {
+  std::ostringstream oss;
+  oss << "<AttrConstant value=" << x.value << ">";
+  return oss.str();
+}
+std::ostream &operator<<(std::ostream &s, AttrConstant const &x) {
+  return s << fmt::to_string(x);
+}
+} // namespace FlexFlow
diff --git a/lib/substitutions/src/substitutions/output_graph/output_graph_expr.dtg.cc b/lib/substitutions/src/substitutions/output_graph/output_graph_expr.dtg.cc
new file mode 100644
index 0000000000..7d07bf9218
--- /dev/null
+++ b/lib/substitutions/src/substitutions/output_graph/output_graph_expr.dtg.cc
@@ -0,0 +1,20 @@
+// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT!
+// If you would like to modify this datatype, instead modify
+// lib/substitutions/include/substitutions/output_graph/output_graph_expr.struct.toml
+/* proj-data
+{
+  "generated_from": "9084c9afb2724504a6f4db4288a83a0d"
+}
+*/
+
+#include "substitutions/output_graph/output_graph_expr.dtg.h"
+
+#include "substitutions/output_graph/output_operator_attrs_assignment.dtg.h"
+#include "utils/graph.h"
+
+namespace FlexFlow {
+OutputGraphExpr::OutputGraphExpr(
+    ::FlexFlow::NodeLabelledOpenMultiDiGraph<
+        ::FlexFlow::OutputOperatorAttrsAssignment> const &raw_graph)
+    : raw_graph(raw_graph) {}
+} // namespace FlexFlow
diff --git a/lib/substitutions/src/substitutions/output_graph/output_operator_attr_access.dtg.cc b/lib/substitutions/src/substitutions/output_graph/output_operator_attr_access.dtg.cc
new file mode 100644
index 0000000000..0c6abc925d
--- /dev/null
+++ b/lib/substitutions/src/substitutions/output_graph/output_operator_attr_access.dtg.cc
@@ -0,0 +1,77 @@
+// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT!
+// If you would like to modify this datatype, instead modify
+// lib/substitutions/include/substitutions/output_graph/output_operator_attr_access.struct.toml
+/* proj-data
+{
+  "generated_from": "e3b3a741183fcb38cfa68aacb82e12d1"
+}
+*/
+
+#include "substitutions/output_graph/output_operator_attr_access.dtg.h"
+
+#include "substitutions/operator_pattern/operator_attribute_expr.dtg.h"
+#include "utils/graph.h"
+#include <sstream>
+
+namespace FlexFlow {
+OutputOperatorAttrAccess::OutputOperatorAttrAccess(
+    ::FlexFlow::Node const &node,
+    ::FlexFlow::OperatorAttributeExpr const &attr_expr)
+    : node(node), attr_expr(attr_expr) {}
+bool OutputOperatorAttrAccess::operator==(
+    OutputOperatorAttrAccess const &other) const {
+  return std::tie(this->node, this->attr_expr) ==
+         std::tie(other.node, other.attr_expr);
+}
+bool OutputOperatorAttrAccess::operator!=(
+    OutputOperatorAttrAccess const &other) const {
+  return std::tie(this->node, this->attr_expr) !=
+         std::tie(other.node, other.attr_expr);
+}
+bool OutputOperatorAttrAccess::operator<(
+    OutputOperatorAttrAccess const &other) const {
+  return std::tie(this->node, this->attr_expr) <
+         std::tie(other.node, other.attr_expr);
+}
+bool OutputOperatorAttrAccess::operator>(
+    OutputOperatorAttrAccess const &other) const {
+  return std::tie(this->node, this->attr_expr) >
+         std::tie(other.node, other.attr_expr);
+}
+bool OutputOperatorAttrAccess::operator<=(
+    OutputOperatorAttrAccess const &other) const {
+  return std::tie(this->node, this->attr_expr) <=
+         std::tie(other.node, other.attr_expr);
+}
+bool OutputOperatorAttrAccess::operator>=(
+    OutputOperatorAttrAccess const &other) const {
+  return std::tie(this->node, this->attr_expr) >=
+         std::tie(other.node, other.attr_expr);
+}
+} // namespace FlexFlow
+
+namespace std {
+size_t hash<FlexFlow::OutputOperatorAttrAccess>::operator()(
+    FlexFlow::OutputOperatorAttrAccess const &x) const {
+  size_t result = 0;
+  result ^= std::hash<::FlexFlow::Node>{}(x.node) + 0x9e3779b9 + (result << 6) +
+            (result >> 2);
+  result ^= std::hash<::FlexFlow::OperatorAttributeExpr>{}(x.attr_expr) +
+            0x9e3779b9 + (result << 6) + (result >> 2);
+  return result;
+}
+} // namespace std
+
+namespace FlexFlow {
+std::string format_as(OutputOperatorAttrAccess const &x) {
+  std::ostringstream oss;
+  oss << "<OutputOperatorAttrAccess";
+  oss << " node=" << x.node;
+  oss << " attr_expr=" << x.attr_expr;
+  oss << ">";
+  return oss.str();
+}
+std::ostream &operator<<(std::ostream &s, OutputOperatorAttrAccess const &x) {
+  return s << fmt::to_string(x);
+}
+} // namespace FlexFlow
diff --git a/lib/substitutions/src/substitutions/output_graph/output_operator_attribute_expr.dtg.cc b/lib/substitutions/src/substitutions/output_graph/output_operator_attribute_expr.dtg.cc
new file mode 100644
index 0000000000..bf1b07c825
--- /dev/null
+++ b/lib/substitutions/src/substitutions/output_graph/output_operator_attribute_expr.dtg.cc
@@ -0,0 +1,79 @@
+// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT!
+// If you would like to modify this datatype, instead modify +// lib/substitutions/include/substitutions/output_graph/output_operator_attribute_expr.variant.toml +/* proj-data +{ + "generated_from": "89ebf777a5b909eef78ab5a5a177e041" +} +*/ + +#include "substitutions/output_graph/output_operator_attribute_expr.dtg.h" + +#include + +namespace FlexFlow { +OutputOperatorAttributeExpr::OutputOperatorAttributeExpr( + ::FlexFlow::OutputOperatorAttrAccess const &v) + : raw_variant(v) {} +OutputOperatorAttributeExpr::OutputOperatorAttributeExpr( + ::FlexFlow::AttrConstant const &v) + : raw_variant(v) {} +bool OutputOperatorAttributeExpr::operator==( + OutputOperatorAttributeExpr const &other) const { + return this->raw_variant == other.raw_variant; +} +bool OutputOperatorAttributeExpr::operator!=( + OutputOperatorAttributeExpr const &other) const { + return this->raw_variant != other.raw_variant; +} +bool OutputOperatorAttributeExpr::operator<( + OutputOperatorAttributeExpr const &other) const { + return this->raw_variant < other.raw_variant; +} +bool OutputOperatorAttributeExpr::operator>( + OutputOperatorAttributeExpr const &other) const { + return this->raw_variant > other.raw_variant; +} +bool OutputOperatorAttributeExpr::operator<=( + OutputOperatorAttributeExpr const &other) const { + return this->raw_variant <= other.raw_variant; +} +bool OutputOperatorAttributeExpr::operator>=( + OutputOperatorAttributeExpr const &other) const { + return this->raw_variant >= other.raw_variant; +} +} // namespace FlexFlow +namespace std { +size_t hash<::FlexFlow::OutputOperatorAttributeExpr>::operator()( + ::FlexFlow::OutputOperatorAttributeExpr const &x) const { + return std::hash>{}(x.raw_variant); +} +} // namespace std +namespace FlexFlow { +std::string format_as(::FlexFlow::OutputOperatorAttributeExpr const &x) { + std::ostringstream oss; + switch (x.index()) { + case 0: { + oss << ""; + break; + } + case 1: { + oss << ""; + break; + } + default: { + throw std::runtime_error(fmt::format( + "Unknown index {} for type OutputOperatorAttributeExpr", x.index())); + break; + } + } + return oss.str(); +} +std::ostream &operator<<(std::ostream &s, + ::FlexFlow::OutputOperatorAttributeExpr const &x) { + return s << fmt::to_string(x); +} +} // namespace FlexFlow diff --git a/lib/substitutions/src/substitutions/output_graph/output_operator_attrs_assignment.dtg.cc b/lib/substitutions/src/substitutions/output_graph/output_operator_attrs_assignment.dtg.cc new file mode 100644 index 0000000000..7a1950482a --- /dev/null +++ b/lib/substitutions/src/substitutions/output_graph/output_operator_attrs_assignment.dtg.cc @@ -0,0 +1,58 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! 
+// If you would like to modify this datatype, instead modify +// lib/substitutions/include/substitutions/output_graph/output_operator_attrs_assignment.struct.toml +/* proj-data +{ + "generated_from": "bbfb309c5a39a729da23dace4df4a9de" +} +*/ + +#include "substitutions/output_graph/output_operator_attrs_assignment.dtg.h" + +#include "substitutions/operator_pattern/operator_attribute_key.dtg.h" +#include "substitutions/output_graph/output_operator_attribute_expr.dtg.h" +#include +#include + +namespace FlexFlow { +OutputOperatorAttrsAssignment::OutputOperatorAttrsAssignment( + std::unordered_map<::FlexFlow::OperatorAttributeKey, + ::FlexFlow::OutputOperatorAttributeExpr> const + &assignments) + : assignments(assignments) {} +bool OutputOperatorAttrsAssignment::operator==( + OutputOperatorAttrsAssignment const &other) const { + return std::tie(this->assignments) == std::tie(other.assignments); +} +bool OutputOperatorAttrsAssignment::operator!=( + OutputOperatorAttrsAssignment const &other) const { + return std::tie(this->assignments) != std::tie(other.assignments); +} +} // namespace FlexFlow + +namespace std { +size_t hash::operator()( + FlexFlow::OutputOperatorAttrsAssignment const &x) const { + size_t result = 0; + result ^= + std::hash>{}( + x.assignments) + + 0x9e3779b9 + (result << 6) + (result >> 2); + return result; +} +} // namespace std + +namespace FlexFlow { +std::string format_as(OutputOperatorAttrsAssignment const &x) { + std::ostringstream oss; + oss << ""; + return oss.str(); +} +std::ostream &operator<<(std::ostream &s, + OutputOperatorAttrsAssignment const &x) { + return s << fmt::to_string(x); +} +} // namespace FlexFlow diff --git a/lib/substitutions/src/substitutions/pcg_pattern.dtg.cc b/lib/substitutions/src/substitutions/pcg_pattern.dtg.cc new file mode 100644 index 0000000000..7133ab42a7 --- /dev/null +++ b/lib/substitutions/src/substitutions/pcg_pattern.dtg.cc @@ -0,0 +1,21 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! 
+// If you would like to modify this datatype, instead modify +// lib/substitutions/include/substitutions/pcg_pattern.struct.toml +/* proj-data +{ + "generated_from": "f536f846828ba39266dd4a1fbaeec0e6" +} +*/ + +#include "substitutions/pcg_pattern.dtg.h" + +#include "substitutions/operator_pattern/operator_attribute_pattern.dtg.h" +#include "substitutions/tensor_pattern/tensor_attribute_pattern.dtg.h" +#include "utils/graph.h" + +namespace FlexFlow { +PCGPattern::PCGPattern(::FlexFlow::OutputLabelledOpenMultiDiGraph< + ::FlexFlow::OperatorAttributePattern, + ::FlexFlow::TensorAttributePattern> const &raw_graph) + : raw_graph(raw_graph) {} +} // namespace FlexFlow diff --git a/lib/substitutions/src/substitutions/sub_parallel_computation_graph.cc b/lib/substitutions/src/substitutions/sub_parallel_computation_graph.cc new file mode 100644 index 0000000000..7736113819 --- /dev/null +++ b/lib/substitutions/src/substitutions/sub_parallel_computation_graph.cc @@ -0,0 +1,22 @@ +#include "substitutions/sub_parallel_computation_graph.h" + +namespace FlexFlow { + +ParallelLayerAttrs + get_parallel_layer_attrs(SubParallelComputationGraph const &spcg, + Node const &n) { + return spcg.raw_graph.at(n); +} + +PCGOperatorAttrs get_operator_attrs(SubParallelComputationGraph const &spcg, + Node const &n) { + return get_parallel_layer_attrs(spcg, n).attrs; +} + +ParallelTensorAttrs + get_parallel_tensor_attrs(SubParallelComputationGraph const &spcg, + OpenMultiDiEdge const &e) { + return spcg.raw_graph.at(e); +} + +} // namespace FlexFlow diff --git a/lib/substitutions/src/substitutions/sub_parallel_computation_graph.dtg.cc b/lib/substitutions/src/substitutions/sub_parallel_computation_graph.dtg.cc new file mode 100644 index 0000000000..83baef2cfc --- /dev/null +++ b/lib/substitutions/src/substitutions/sub_parallel_computation_graph.dtg.cc @@ -0,0 +1,22 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! 
+// If you would like to modify this datatype, instead modify +// lib/substitutions/include/substitutions/sub_parallel_computation_graph.struct.toml +/* proj-data +{ + "generated_from": "0022d1b2c1447667695a120c154a0168" +} +*/ + +#include "substitutions/sub_parallel_computation_graph.dtg.h" + +#include "pcg/parallel_layer_attrs.dtg.h" +#include "pcg/parallel_tensor_attrs.dtg.h" +#include "utils/graph.h" + +namespace FlexFlow { +SubParallelComputationGraph::SubParallelComputationGraph( + ::FlexFlow::OutputLabelledOpenMultiDiGraph< + ::FlexFlow::ParallelLayerAttrs, + ::FlexFlow::ParallelTensorAttrs> const &raw_graph) + : raw_graph(raw_graph) {} +} // namespace FlexFlow diff --git a/lib/substitutions/src/substitutions/substitution.cc b/lib/substitutions/src/substitutions/substitution.cc new file mode 100644 index 0000000000..e900175bc6 --- /dev/null +++ b/lib/substitutions/src/substitutions/substitution.cc @@ -0,0 +1,154 @@ +#include "substitutions/substitution.h" + +namespace FlexFlow { + +/* struct AddMappedEdgeFunctor { */ +/* bidict const &node_mapping; */ +/* SubParallelComputationGraph &new_pcg; */ + +/* template */ +/* void operator()(T const &t) { */ +/* return add_mapped_edge(t); */ +/* } */ + +/* void add_mapped_edge(InputMultiDiEdge const &e) { */ +/* new_pcg.add_edge(InputMultiDiEdge{ */ +/* node_mapping.at_l(e.dst), new_pcg.add_node_port(), e.uid}); */ +/* } */ + +/* void add_mapped_edge(OutputMultiDiEdge const &e) { */ +/* new_pcg.add_edge(OutputMultiDiEdge{ */ +/* node_mapping.at_l(e.src), new_pcg.add_node_port(), e.uid}); */ +/* } */ + +/* void add_mapped_edge(MultiDiEdge const &e) { */ +/* new_pcg.add_edge(MultiDiEdge{node_mapping.at_l(e.dst), */ +/* new_pcg.add_node_port(), */ +/* node_mapping.at_l(e.src), */ +/* new_pcg.add_node_port()}); */ +/* } */ +/* }; */ + +/* struct AddNewEdgeFunctor { */ +/* SubParallelComputationGraph const &old_pcg; */ +/* SubParallelComputationGraph &new_pcg; */ +/* MultiDiGraphPatternMatch const &match; */ +/* bidict node_mapping; */ + +/* template */ +/* void operator()(TO const &old_edge, TN const &new_edge) { */ +/* return add_new_edge(old_edge, new_edge); */ +/* } */ + +/* void add_new_edge(InputMultiDiEdge const &old_edge, */ +/* InputMultiDiEdge const &new_edge) { */ +/* new_pcg.add_edge(InputMultiDiEdge{node_mapping.at_l(new_edge.dst), */ +/* new_pcg.add_node_port(), */ +/* old_edge.uid}); */ +/* } */ + +/* void add_new_edge(MultiDiEdge const &old_edge, */ +/* InputMultiDiEdge const &new_edge) { */ +/* new_pcg.add_edge(MultiDiEdge{node_mapping.at_l(new_edge.dst), */ +/* new_pcg.add_node_port(), */ +/* node_mapping.at_l(old_edge.src), */ +/* new_pcg.add_node_port()}); */ +/* } */ + +/* void add_new_edge(OutputMultiDiEdge const &old_edge, */ +/* OutputMultiDiEdge const &new_edge) { */ +/* new_pcg.add_edge(OutputMultiDiEdge{node_mapping.at_l(new_edge.src), */ +/* new_pcg.add_node_port(), */ +/* old_edge.uid}); */ +/* } */ + +/* void add_new_edge(MultiDiEdge const &old_edge, */ +/* OutputMultiDiEdge const &new_edge) { */ +/* new_pcg.add_edge(MultiDiEdge{node_mapping.at_l(old_edge.dst), */ +/* new_pcg.add_node_port(), */ +/* node_mapping.at_l(new_edge.src), */ +/* new_pcg.add_node_port()}); */ +/* } */ + +/* void add_new_edge(InputMultiDiEdge const &, OutputMultiDiEdge const &) { */ +/* assert(false); */ +/* } */ + +/* void add_new_edge(OpenMultiDiEdge const &, MultiDiEdge const &) { */ +/* assert(false); */ +/* } */ + +/* void add_new_edge(OutputMultiDiEdge const &, InputMultiDiEdge const &) { */ +/* assert(false); */ +/* } */ +/* }; */ + 
+/* SubParallelComputationGraph */ +/* apply_substitution(SubParallelComputationGraph const &pcg, */ +/* Substitution const &substitution, */ +/* MultiDiGraphPatternMatch const &match) { */ +/* SubParallelComputationGraph new_pcg = */ +/* OutputLabelledOpenMultiDiGraph::template + * create< */ +/* UnorderedOutputLabelledOpenMultiDiGraph>(); */ +/* bidict node_mapping; // Refactor it with global nodes */ +/* for (Node const &node : get_nodes(pcg)) { */ +/* if (!contains_r(match.node_assignment, node)) { */ +/* node_mapping.equate(node, new_pcg.add_node(pcg.at(node))); */ +/* } */ +/* } */ +/* for (OpenMultiDiEdge const &edge : get_edges(pcg)) { */ +/* if (!contains_r(match.edge_assignment, edge)) { */ +/* visit(AddMappedEdgeFunctor{node_mapping, new_pcg}, edge); */ +/* } */ +/* } */ +/* for (Node const &output_node : */ +/* get_nodes(substitution.output_graph_expr.value())) { */ +/* Operator new_op = get_operator_attrs( */ +/* pcg, match, substitution.output_graph_expr.value().at(output_node)); + */ +/* Node new_node = new_pcg.add_node(new_op); */ +/* node_mapping.equate(output_node, new_node); */ +/* } */ +/* for (OpenMultiDiEdge const &output_edge : */ +/* get_edges(substitution.output_graph_expr.value())) { */ +/* if (std::holds_alternative(output_edge)) { */ +/* InputMultiDiEdge e = std::get(output_edge); */ +/* OpenMultiDiEdge original_edge = */ +/* match.edge_assignment.at_l(substitution.input_mapping.at_r(e)); */ +/* visit(AddNewEdgeFunctor{pcg, new_pcg, match, node_mapping}, */ +/* original_edge, */ +/* output_edge); */ +/* } else if (std::holds_alternative(output_edge)) { */ +/* OutputMultiDiEdge e = std::get(output_edge); */ +/* OpenMultiDiEdge original_edge = */ +/* match.edge_assignment.at_l(substitution.output_mapping.at_r(e)); */ +/* visit(AddNewEdgeFunctor{pcg, new_pcg, match, node_mapping}, */ +/* original_edge, */ +/* output_edge); */ +/* } else { */ +/* assert(std::holds_alternative(output_edge)); */ +/* MultiDiEdge e = std::get(output_edge); */ +/* new_pcg.add_edge(MultiDiEdge{node_mapping.at_l(e.dst), */ +/* new_pcg.add_node_port(), */ +/* node_mapping.at_l(e.src), */ +/* new_pcg.add_node_port()}); */ +/* } */ +/* } */ + +/* return new_pcg; */ +/* } */ + +bool is_valid_substitution(Substitution const &) { + NOT_IMPLEMENTED(); +} + +SubParallelComputationGraph + apply_substitution(SubParallelComputationGraph const &, + Substitution const &, + MultiDiGraphPatternMatch const &) { + NOT_IMPLEMENTED(); +} + +} // namespace FlexFlow diff --git a/lib/substitutions/src/substitutions/substitution.dtg.cc b/lib/substitutions/src/substitutions/substitution.dtg.cc new file mode 100644 index 0000000000..67d39d6ff7 --- /dev/null +++ b/lib/substitutions/src/substitutions/substitution.dtg.cc @@ -0,0 +1,28 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! 
+// If you would like to modify this datatype, instead modify
+// lib/substitutions/include/substitutions/substitution.struct.toml
+/* proj-data
+{
+  "generated_from": "c101f1d63e2d8d80a0ec9c5f5db4fa12"
+}
+*/
+
+#include "substitutions/substitution.dtg.h"
+
+#include "substitutions/output_graph/output_graph_expr.dtg.h"
+#include "substitutions/pcg_pattern.dtg.h"
+
+namespace FlexFlow {
+Substitution::Substitution(
+    ::FlexFlow::PCGPattern const &pcg_pattern,
+    ::FlexFlow::OutputGraphExpr const &output_graph_expr,
+    ::FlexFlow::bidict<::FlexFlow::InputMultiDiEdge,
+                       ::FlexFlow::InputMultiDiEdge> const
+        &input_edge_match_to_output,
+    ::FlexFlow::bidict<::FlexFlow::OutputMultiDiEdge,
+                       ::FlexFlow::OutputMultiDiEdge> const
+        &output_edge_match_to_output)
+    : pcg_pattern(pcg_pattern), output_graph_expr(output_graph_expr),
+      input_edge_match_to_output(input_edge_match_to_output),
+      output_edge_match_to_output(output_edge_match_to_output) {}
+} // namespace FlexFlow
diff --git a/lib/substitutions/src/substitutions/tensor_pattern/eval_list_access.cc b/lib/substitutions/src/substitutions/tensor_pattern/eval_list_access.cc
new file mode 100644
index 0000000000..ea4833d36a
--- /dev/null
+++ b/lib/substitutions/src/substitutions/tensor_pattern/eval_list_access.cc
@@ -0,0 +1,24 @@
+#include "substitutions/tensor_pattern/eval_list_access.h"
+#include "substitutions/tensor_pattern/get_attribute.h"
+#include "utils/containers.h"
+#include "utils/overload.h"
+
+namespace FlexFlow {
+
+TensorAttributeValue
+    eval_list_access(ParallelTensorAttrs const &attrs,
+                     TensorAttributeListIndexAccess const &acc) {
+  TensorAttributeValue from_attr = get_attribute(attrs, acc.attribute_key);
+
+  return from_attr.visit(overload{
+      [&](std::vector<size_t> const &v) -> TensorAttributeValue {
+        return TensorAttributeValue{
+            static_cast<size_t>(at_idx(v, acc.index).value())};
+      },
+      [](auto &&) -> TensorAttributeValue {
+        throw mk_runtime_error("Invalid operand");
+      },
+  });
+}
+
+} // namespace FlexFlow
diff --git a/lib/substitutions/src/substitutions/tensor_pattern/eval_list_size.cc b/lib/substitutions/src/substitutions/tensor_pattern/eval_list_size.cc
new file mode 100644
index 0000000000..d1e97adc37
--- /dev/null
+++ b/lib/substitutions/src/substitutions/tensor_pattern/eval_list_size.cc
@@ -0,0 +1,21 @@
+#include "substitutions/tensor_pattern/eval_list_size.h"
+#include "substitutions/tensor_pattern/get_attribute.h"
+#include "utils/overload.h"
+
+namespace FlexFlow {
+
+TensorAttributeValue eval_list_size(ParallelTensorAttrs const &attrs,
+                                    TensorAttributeListSize const &acc) {
+  TensorAttributeValue from_attr = get_attribute(attrs, acc.attribute_key);
+
+  return from_attr.visit(overload{
+      [](std::vector<size_t> const &v) -> TensorAttributeValue {
+        return TensorAttributeValue{v.size()};
+      },
+      [](auto &&) -> TensorAttributeValue {
+        throw mk_runtime_error("Invalid operand");
+      },
+  });
+}
+
+} // namespace FlexFlow
diff --git a/lib/substitutions/src/substitutions/tensor_pattern/get_attribute.cc b/lib/substitutions/src/substitutions/tensor_pattern/get_attribute.cc
new file mode 100644
index 0000000000..7c42bdd904
--- /dev/null
+++ b/lib/substitutions/src/substitutions/tensor_pattern/get_attribute.cc
@@ -0,0 +1,28 @@
+#include "substitutions/tensor_pattern/get_attribute.h"
+#include "utils/containers.h"
+#include "utils/integer_conversions.h"
+
+namespace FlexFlow {
+
+TensorAttributeValue get_attribute(ParallelTensorAttrs const &attrs,
+                                   TensorAttributeKey key) {
+  switch (key) {
+    case TensorAttributeKey::DIM_SIZES: {
+      std::vector<size_t> sizes =
+          transform(as_vector(ff_ordered_shard_dims(attrs.shape.dims)),
+                    [](ShardParallelDim const &d) { return d.size; });
+      return TensorAttributeValue{sizes};
+    }
+    case TensorAttributeKey::DIM_DEGREES: {
+      std::vector<size_t> degrees = transform(
+          as_vector(ff_ordered_shard_dims(attrs.shape.dims)),
+          [](ShardParallelDim const &d) { return size_t_from_int(d.degree); });
+      return TensorAttributeValue{degrees};
+    }
+    default:
+      throw std::runtime_error(
+          fmt::format("Unknown TensorAttributeKey {}", static_cast<int>(key)));
+  }
+}
+
+} // namespace FlexFlow
diff --git a/lib/substitutions/src/substitutions/tensor_pattern/satisfies_constraint.cc b/lib/substitutions/src/substitutions/tensor_pattern/satisfies_constraint.cc
new file mode 100644
index 0000000000..974bfcabc0
--- /dev/null
+++ b/lib/substitutions/src/substitutions/tensor_pattern/satisfies_constraint.cc
@@ -0,0 +1,22 @@
+#include "substitutions/tensor_pattern/satisfies_constraint.h"
+#include "substitutions/tensor_pattern/tensor_attribute_expr.h"
+
+namespace FlexFlow {
+
+bool parallel_tensor_satisfies_constraint(
+    ParallelTensorAttrs const &attrs,
+    TensorAttributeConstraint const &constraint) {
+  TensorAttributeValue expr_val =
+      evaluate_attribute_expr(attrs, constraint.attribute_expr);
+
+  switch (constraint.constraint_type) {
+    case ConstraintType::EQUAL:
+      return expr_val == constraint.attribute_value;
+    default:
+      throw mk_runtime_error(
+          fmt::format("Unknown constraint type {}",
+                      static_cast<int>(constraint.constraint_type)));
+  }
+}
+
+} // namespace FlexFlow
diff --git a/lib/substitutions/src/substitutions/tensor_pattern/satisfies_pattern.cc b/lib/substitutions/src/substitutions/tensor_pattern/satisfies_pattern.cc
new file mode 100644
index 0000000000..35fec2dfea
--- /dev/null
+++ b/lib/substitutions/src/substitutions/tensor_pattern/satisfies_pattern.cc
@@ -0,0 +1,13 @@
+#include "substitutions/tensor_pattern/satisfies_pattern.h"
+#include "substitutions/tensor_pattern/satisfies_constraint.h"
+
+namespace FlexFlow {
+
+bool parallel_tensor_satisfies_pattern(ParallelTensorAttrs const &attrs,
+                                       TensorAttributePattern const &pattern) {
+  return all_of(pattern.attribute_constraints,
+                [&](TensorAttributeConstraint const &c) {
+                  return parallel_tensor_satisfies_constraint(attrs, c);
+                });
+}
+} // namespace FlexFlow
diff --git a/lib/substitutions/src/substitutions/tensor_pattern/tensor_attribute_constraint.dtg.cc b/lib/substitutions/src/substitutions/tensor_pattern/tensor_attribute_constraint.dtg.cc
new file mode 100644
index 0000000000..6f9df90fb2
--- /dev/null
+++ b/lib/substitutions/src/substitutions/tensor_pattern/tensor_attribute_constraint.dtg.cc
@@ -0,0 +1,119 @@
+// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT!
+// If you would like to modify this datatype, instead modify +// lib/substitutions/include/substitutions/tensor_pattern/tensor_attribute_constraint.struct.toml +/* proj-data +{ + "generated_from": "29dbf81668bc864b06af52261060335e" +} +*/ + +#include "substitutions/tensor_pattern/tensor_attribute_constraint.dtg.h" + +#include "substitutions/constraint_type.dtg.h" +#include "substitutions/tensor_pattern/tensor_attribute_expr.dtg.h" +#include "substitutions/tensor_pattern/tensor_attribute_value.dtg.h" +#include + +namespace FlexFlow { +TensorAttributeConstraint::TensorAttributeConstraint( + ::FlexFlow::ConstraintType const &constraint_type, + ::FlexFlow::TensorAttributeExpr const &attribute_expr, + ::FlexFlow::TensorAttributeValue const &attribute_value) + : constraint_type(constraint_type), attribute_expr(attribute_expr), + attribute_value(attribute_value) {} +bool TensorAttributeConstraint::operator==( + TensorAttributeConstraint const &other) const { + return std::tie(this->constraint_type, + this->attribute_expr, + this->attribute_value) == std::tie(other.constraint_type, + other.attribute_expr, + other.attribute_value); +} +bool TensorAttributeConstraint::operator!=( + TensorAttributeConstraint const &other) const { + return std::tie(this->constraint_type, + this->attribute_expr, + this->attribute_value) != std::tie(other.constraint_type, + other.attribute_expr, + other.attribute_value); +} +bool TensorAttributeConstraint::operator<( + TensorAttributeConstraint const &other) const { + return std::tie(this->constraint_type, + this->attribute_expr, + this->attribute_value) < std::tie(other.constraint_type, + other.attribute_expr, + other.attribute_value); +} +bool TensorAttributeConstraint::operator>( + TensorAttributeConstraint const &other) const { + return std::tie(this->constraint_type, + this->attribute_expr, + this->attribute_value) > std::tie(other.constraint_type, + other.attribute_expr, + other.attribute_value); +} +bool TensorAttributeConstraint::operator<=( + TensorAttributeConstraint const &other) const { + return std::tie(this->constraint_type, + this->attribute_expr, + this->attribute_value) <= std::tie(other.constraint_type, + other.attribute_expr, + other.attribute_value); +} +bool TensorAttributeConstraint::operator>=( + TensorAttributeConstraint const &other) const { + return std::tie(this->constraint_type, + this->attribute_expr, + this->attribute_value) >= std::tie(other.constraint_type, + other.attribute_expr, + other.attribute_value); +} +} // namespace FlexFlow + +namespace std { +size_t hash::operator()( + FlexFlow::TensorAttributeConstraint const &x) const { + size_t result = 0; + result ^= std::hash<::FlexFlow::ConstraintType>{}(x.constraint_type) + + 0x9e3779b9 + (result << 6) + (result >> 2); + result ^= std::hash<::FlexFlow::TensorAttributeExpr>{}(x.attribute_expr) + + 0x9e3779b9 + (result << 6) + (result >> 2); + result ^= std::hash<::FlexFlow::TensorAttributeValue>{}(x.attribute_value) + + 0x9e3779b9 + (result << 6) + (result >> 2); + return result; +} +} // namespace std + +namespace nlohmann { +FlexFlow::TensorAttributeConstraint + adl_serializer::from_json( + json const &j) { + return { + j.at("constraint_type").template get<::FlexFlow::ConstraintType>(), + j.at("attribute_expr").template get<::FlexFlow::TensorAttributeExpr>(), + j.at("attribute_value").template get<::FlexFlow::TensorAttributeValue>()}; +} +void adl_serializer::to_json( + json &j, FlexFlow::TensorAttributeConstraint const &v) { + j["__type"] = "TensorAttributeConstraint"; + 
j["constraint_type"] = v.constraint_type; + j["attribute_expr"] = v.attribute_expr; + j["attribute_value"] = v.attribute_value; +} +} // namespace nlohmann + +namespace FlexFlow { +std::string format_as(TensorAttributeConstraint const &x) { + std::ostringstream oss; + oss << ""; + return oss.str(); +} +std::ostream &operator<<(std::ostream &s, TensorAttributeConstraint const &x) { + return s << fmt::to_string(x); +} +} // namespace FlexFlow diff --git a/lib/substitutions/src/substitutions/tensor_pattern/tensor_attribute_expr.cc b/lib/substitutions/src/substitutions/tensor_pattern/tensor_attribute_expr.cc new file mode 100644 index 0000000000..33bcc1a082 --- /dev/null +++ b/lib/substitutions/src/substitutions/tensor_pattern/tensor_attribute_expr.cc @@ -0,0 +1,22 @@ +#include "substitutions/tensor_pattern/tensor_attribute_expr.h" +#include "substitutions/tensor_pattern/eval_list_access.h" +#include "substitutions/tensor_pattern/eval_list_size.h" +#include "substitutions/tensor_pattern/get_attribute.h" +#include "utils/overload.h" + +namespace FlexFlow { + +TensorAttributeValue evaluate_attribute_expr(ParallelTensorAttrs const &attrs, + TensorAttributeExpr const &expr) { + + return expr.visit(overload{ + [&](TensorAttributeKey const &key) { return get_attribute(attrs, key); }, + [&](TensorAttributeListSize const &s) { + return eval_list_size(attrs, s); + }, + [&](TensorAttributeListIndexAccess const &s) { + return eval_list_access(attrs, s); + }}); +} + +} // namespace FlexFlow diff --git a/lib/substitutions/src/substitutions/tensor_pattern/tensor_attribute_expr.dtg.cc b/lib/substitutions/src/substitutions/tensor_pattern/tensor_attribute_expr.dtg.cc new file mode 100644 index 0000000000..a42f18bf26 --- /dev/null +++ b/lib/substitutions/src/substitutions/tensor_pattern/tensor_attribute_expr.dtg.cc @@ -0,0 +1,129 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! 
+// If you would like to modify this datatype, instead modify +// lib/substitutions/include/substitutions/tensor_pattern/tensor_attribute_expr.variant.toml +/* proj-data +{ + "generated_from": "b91285329f12f1b409805cbf9be575b2" +} +*/ + +#include "substitutions/tensor_pattern/tensor_attribute_expr.dtg.h" + +#include "fmt/format.h" +#include +#include + +namespace FlexFlow { +TensorAttributeExpr::TensorAttributeExpr( + ::FlexFlow::TensorAttributeKey const &v) + : raw_variant(v) {} +TensorAttributeExpr::TensorAttributeExpr( + ::FlexFlow::TensorAttributeListSize const &v) + : raw_variant(v) {} +TensorAttributeExpr::TensorAttributeExpr( + ::FlexFlow::TensorAttributeListIndexAccess const &v) + : raw_variant(v) {} +bool TensorAttributeExpr::operator==(TensorAttributeExpr const &other) const { + return this->raw_variant == other.raw_variant; +} +bool TensorAttributeExpr::operator!=(TensorAttributeExpr const &other) const { + return this->raw_variant != other.raw_variant; +} +bool TensorAttributeExpr::operator<(TensorAttributeExpr const &other) const { + return this->raw_variant < other.raw_variant; +} +bool TensorAttributeExpr::operator>(TensorAttributeExpr const &other) const { + return this->raw_variant > other.raw_variant; +} +bool TensorAttributeExpr::operator<=(TensorAttributeExpr const &other) const { + return this->raw_variant <= other.raw_variant; +} +bool TensorAttributeExpr::operator>=(TensorAttributeExpr const &other) const { + return this->raw_variant >= other.raw_variant; +} +} // namespace FlexFlow +namespace std { +size_t hash<::FlexFlow::TensorAttributeExpr>::operator()( + ::FlexFlow::TensorAttributeExpr const &x) const { + return std::hash>{}( + x.raw_variant); +} +} // namespace std +namespace nlohmann { +::FlexFlow::TensorAttributeExpr + adl_serializer<::FlexFlow::TensorAttributeExpr>::from_json(json const &j) { + std::string key = j.at("type").template get(); + if (key == "key") { + return ::FlexFlow::TensorAttributeExpr{ + j.at("value").template get<::FlexFlow::TensorAttributeKey>()}; + } else if (key == "list_size") { + return ::FlexFlow::TensorAttributeExpr{ + j.at("value").template get<::FlexFlow::TensorAttributeListSize>()}; + } else if (key == "list_idx") { + return ::FlexFlow::TensorAttributeExpr{ + j.at("value") + .template get<::FlexFlow::TensorAttributeListIndexAccess>()}; + } else { + throw std::runtime_error(fmt::format("Unknown type key {}", key)); + } +} +void adl_serializer<::FlexFlow::TensorAttributeExpr>::to_json( + json &j, ::FlexFlow::TensorAttributeExpr const &x) { + j["__type"] = "TensorAttributeExpr"; + switch (x.index()) { + case 0: { + j["type"] = "key"; + j["value"] = x.get<::FlexFlow::TensorAttributeKey>(); + break; + } + case 1: { + j["type"] = "list_size"; + j["value"] = x.get<::FlexFlow::TensorAttributeListSize>(); + break; + } + case 2: { + j["type"] = "list_idx"; + j["value"] = x.get<::FlexFlow::TensorAttributeListIndexAccess>(); + break; + } + default: { + throw std::runtime_error(fmt::format( + "Unknown index {} for type TensorAttributeExpr", x.index())); + } + } +} +} // namespace nlohmann +namespace FlexFlow { +std::string format_as(::FlexFlow::TensorAttributeExpr const &x) { + std::ostringstream oss; + switch (x.index()) { + case 0: { + oss << ""; + break; + } + case 1: { + oss << ""; + break; + } + case 2: { + oss << ""; + break; + } + default: { + throw std::runtime_error(fmt::format( + "Unknown index {} for type TensorAttributeExpr", x.index())); + break; + } + } + return oss.str(); +} +std::ostream &operator<<(std::ostream &s, + 
::FlexFlow::TensorAttributeExpr const &x) {
+  return s << fmt::to_string(x);
+}
+} // namespace FlexFlow
diff --git a/lib/substitutions/src/substitutions/tensor_pattern/tensor_attribute_key.dtg.cc b/lib/substitutions/src/substitutions/tensor_pattern/tensor_attribute_key.dtg.cc
new file mode 100644
index 0000000000..fe87c63777
--- /dev/null
+++ b/lib/substitutions/src/substitutions/tensor_pattern/tensor_attribute_key.dtg.cc
@@ -0,0 +1,73 @@
+// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT!
+// If you would like to modify this datatype, instead modify
+// lib/substitutions/include/substitutions/tensor_pattern/tensor_attribute_key.enum.toml
+/* proj-data
+{
+  "generated_from": "63a7c40c1e5b582f98b59750a35f0a08"
+}
+*/
+
+#include "substitutions/tensor_pattern/tensor_attribute_key.dtg.h"
+
+#include <sstream>
+#include <stdexcept>
+
+namespace std {
+size_t hash<FlexFlow::TensorAttributeKey>::operator()(
+    FlexFlow::TensorAttributeKey x) const {
+  return std::hash<int>{}(static_cast<int>(x));
+}
+} // namespace std
+namespace FlexFlow {
+std::string format_as(TensorAttributeKey x) {
+  switch (x) {
+    case TensorAttributeKey::DIM_SIZES:
+      return "DIM_SIZES";
+    case TensorAttributeKey::DIM_DEGREES:
+      return "DIM_DEGREES";
+    default:
+      std::ostringstream oss;
+      oss << "Unknown TensorAttributeKey value " << static_cast<int>(x);
+      throw std::runtime_error(oss.str());
+  }
+}
+std::ostream &operator<<(std::ostream &s, TensorAttributeKey x) {
+  return s << fmt::to_string(x);
+}
+} // namespace FlexFlow
+namespace FlexFlow {
+void to_json(::nlohmann::json &j, TensorAttributeKey x) {
+  switch (x) {
+    case TensorAttributeKey::DIM_SIZES:
+      j = "DIM_SIZES";
+      break;
+    case TensorAttributeKey::DIM_DEGREES:
+      j = "DIM_DEGREES";
+      break;
+    default:
+      std::ostringstream oss;
+      oss << "Unknown TensorAttributeKey value " << static_cast<int>(x);
+      throw std::runtime_error(oss.str());
+  }
+}
+void from_json(::nlohmann::json const &j, TensorAttributeKey &x) {
+  std::string as_str = j.get<std::string>();
+  if (as_str == "DIM_SIZES") {
+    x = TensorAttributeKey::DIM_SIZES;
+  } else if (as_str == "DIM_DEGREES") {
+    x = TensorAttributeKey::DIM_DEGREES;
+  } else {
+    std::ostringstream oss;
+    oss << "Unknown TensorAttributeKey value " << as_str;
+    throw std::runtime_error(oss.str());
+  }
+}
+} // namespace FlexFlow
+namespace rc {
+Gen<FlexFlow::TensorAttributeKey>
+    Arbitrary<FlexFlow::TensorAttributeKey>::arbitrary() {
+  return gen::element(
+      FlexFlow::TensorAttributeKey::DIM_SIZES,
+      FlexFlow::TensorAttributeKey::DIM_DEGREES);
+}
+} // namespace rc
diff --git a/lib/substitutions/src/substitutions/tensor_pattern/tensor_attribute_list_access.dtg.cc b/lib/substitutions/src/substitutions/tensor_pattern/tensor_attribute_list_access.dtg.cc
new file mode 100644
index 0000000000..4e28de2c28
--- /dev/null
+++ b/lib/substitutions/src/substitutions/tensor_pattern/tensor_attribute_list_access.dtg.cc
@@ -0,0 +1,99 @@
+// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT!
+// If you would like to modify this datatype, instead modify
+// lib/substitutions/include/substitutions/tensor_pattern/tensor_attribute_list_access.struct.toml
+/* proj-data
+{
+  "generated_from": "41f5449cd700b6d7ab017f3efa39dc1d"
+}
+*/
+
+#include "substitutions/tensor_pattern/tensor_attribute_list_access.dtg.h"
+
+#include "substitutions/tensor_pattern/tensor_attribute_key.dtg.h"
+#include <sstream>
+
+namespace FlexFlow {
+TensorAttributeListIndexAccess::TensorAttributeListIndexAccess(
+    ::FlexFlow::TensorAttributeKey const &attribute_key, int const &index)
+    : attribute_key(attribute_key), index(index) {}
+bool TensorAttributeListIndexAccess::operator==(
+    TensorAttributeListIndexAccess const &other) const {
+  return std::tie(this->attribute_key, this->index) ==
+         std::tie(other.attribute_key, other.index);
+}
+bool TensorAttributeListIndexAccess::operator!=(
+    TensorAttributeListIndexAccess const &other) const {
+  return std::tie(this->attribute_key, this->index) !=
+         std::tie(other.attribute_key, other.index);
+}
+bool TensorAttributeListIndexAccess::operator<(
+    TensorAttributeListIndexAccess const &other) const {
+  return std::tie(this->attribute_key, this->index) <
+         std::tie(other.attribute_key, other.index);
+}
+bool TensorAttributeListIndexAccess::operator>(
+    TensorAttributeListIndexAccess const &other) const {
+  return std::tie(this->attribute_key, this->index) >
+         std::tie(other.attribute_key, other.index);
+}
+bool TensorAttributeListIndexAccess::operator<=(
+    TensorAttributeListIndexAccess const &other) const {
+  return std::tie(this->attribute_key, this->index) <=
+         std::tie(other.attribute_key, other.index);
+}
+bool TensorAttributeListIndexAccess::operator>=(
+    TensorAttributeListIndexAccess const &other) const {
+  return std::tie(this->attribute_key, this->index) >=
+         std::tie(other.attribute_key, other.index);
+}
+} // namespace FlexFlow
+
+namespace std {
+size_t hash<FlexFlow::TensorAttributeListIndexAccess>::operator()(
+    FlexFlow::TensorAttributeListIndexAccess const &x) const {
+  size_t result = 0;
+  result ^= std::hash<::FlexFlow::TensorAttributeKey>{}(x.attribute_key) +
+            0x9e3779b9 + (result << 6) + (result >> 2);
+  result ^=
+      std::hash<int>{}(x.index) + 0x9e3779b9 + (result << 6) + (result >> 2);
+  return result;
+}
+} // namespace std
+
+namespace nlohmann {
+FlexFlow::TensorAttributeListIndexAccess
+    adl_serializer<FlexFlow::TensorAttributeListIndexAccess>::from_json(
+        json const &j) {
+  return {j.at("attribute_key").template get<::FlexFlow::TensorAttributeKey>(),
+          j.at("index").template get<int>()};
+}
+void adl_serializer<FlexFlow::TensorAttributeListIndexAccess>::to_json(
+    json &j, FlexFlow::TensorAttributeListIndexAccess const &v) {
+  j["__type"] = "TensorAttributeListIndexAccess";
+  j["attribute_key"] = v.attribute_key;
+  j["index"] = v.index;
+}
+} // namespace nlohmann
+
+namespace rc {
+Gen<FlexFlow::TensorAttributeListIndexAccess>
+    Arbitrary<FlexFlow::TensorAttributeListIndexAccess>::arbitrary() {
+  return gen::construct<FlexFlow::TensorAttributeListIndexAccess>(
+      gen::arbitrary<::FlexFlow::TensorAttributeKey>(), gen::arbitrary<int>());
+}
+} // namespace rc
+
+namespace FlexFlow {
+std::string format_as(TensorAttributeListIndexAccess const &x) {
+  std::ostringstream oss;
+  oss << "<TensorAttributeListIndexAccess";
+  oss << " attribute_key=" << x.attribute_key;
+  oss << " index=" << x.index;
+  oss << ">";
+  return oss.str();
+}
+std::ostream &operator<<(std::ostream &s,
+                         TensorAttributeListIndexAccess const &x) {
+  return s << fmt::to_string(x);
+}
+} // namespace FlexFlow
diff --git a/lib/substitutions/src/substitutions/tensor_pattern/tensor_attribute_list_size.dtg.cc b/lib/substitutions/src/substitutions/tensor_pattern/tensor_attribute_list_size.dtg.cc
new file mode 100644
index 0000000000..24d8b6c025
--- /dev/null
+++ b/lib/substitutions/src/substitutions/tensor_pattern/tensor_attribute_list_size.dtg.cc
@@ -0,0 +1,87 @@
+// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT!
+// If you would like to modify this datatype, instead modify
+// lib/substitutions/include/substitutions/tensor_pattern/tensor_attribute_list_size.struct.toml
+/* proj-data
+{
+  "generated_from": "ec72cd39de5d1c0f0478696d7b83e4e9"
+}
+*/
+
+#include "substitutions/tensor_pattern/tensor_attribute_list_size.dtg.h"
+
+#include "substitutions/tensor_pattern/tensor_attribute_key.dtg.h"
+#include <sstream>
+
+namespace FlexFlow {
+TensorAttributeListSize::TensorAttributeListSize(
+    ::FlexFlow::TensorAttributeKey const &attribute_key)
+    : attribute_key(attribute_key) {}
+bool TensorAttributeListSize::operator==(
+    TensorAttributeListSize const &other) const {
+  return std::tie(this->attribute_key) == std::tie(other.attribute_key);
+}
+bool TensorAttributeListSize::operator!=(
+    TensorAttributeListSize const &other) const {
+  return std::tie(this->attribute_key) != std::tie(other.attribute_key);
+}
+bool TensorAttributeListSize::operator<(
+    TensorAttributeListSize const &other) const {
+  return std::tie(this->attribute_key) < std::tie(other.attribute_key);
+}
+bool TensorAttributeListSize::operator>(
+    TensorAttributeListSize const &other) const {
+  return std::tie(this->attribute_key) > std::tie(other.attribute_key);
+}
+bool TensorAttributeListSize::operator<=(
+    TensorAttributeListSize const &other) const {
+  return std::tie(this->attribute_key) <= std::tie(other.attribute_key);
+}
+bool TensorAttributeListSize::operator>=(
+    TensorAttributeListSize const &other) const {
+  return std::tie(this->attribute_key) >= std::tie(other.attribute_key);
+}
+} // namespace FlexFlow
+
+namespace std {
+size_t hash<FlexFlow::TensorAttributeListSize>::operator()(
+    FlexFlow::TensorAttributeListSize const &x) const {
+  size_t result = 0;
+  result ^= std::hash<::FlexFlow::TensorAttributeKey>{}(x.attribute_key) +
+            0x9e3779b9 + (result << 6) + (result >> 2);
+  return result;
+}
+} // namespace std
+
+namespace nlohmann {
+FlexFlow::TensorAttributeListSize
+    adl_serializer<FlexFlow::TensorAttributeListSize>::from_json(
+        json const &j) {
+  return {j.at("attribute_key").template get<::FlexFlow::TensorAttributeKey>()};
+}
+void adl_serializer<FlexFlow::TensorAttributeListSize>::to_json(
+    json &j, FlexFlow::TensorAttributeListSize const &v) {
+  j["__type"] = "TensorAttributeListSize";
+  j["attribute_key"] = v.attribute_key;
+}
+} // namespace nlohmann
+
+namespace rc {
+Gen<FlexFlow::TensorAttributeListSize>
+    Arbitrary<FlexFlow::TensorAttributeListSize>::arbitrary() {
+  return gen::construct<FlexFlow::TensorAttributeListSize>(
+      gen::arbitrary<::FlexFlow::TensorAttributeKey>());
+}
+} // namespace rc
+
+namespace FlexFlow {
+std::string format_as(TensorAttributeListSize const &x) {
+  std::ostringstream oss;
+  oss << "<TensorAttributeListSize";
+  oss << " attribute_key=" << x.attribute_key;
+  oss << ">";
+  return oss.str();
+}
+std::ostream &operator<<(std::ostream &s, TensorAttributeListSize const &x) {
+  return s << fmt::to_string(x);
+}
+} // namespace FlexFlow
diff --git a/lib/substitutions/src/substitutions/tensor_pattern/tensor_attribute_pattern.dtg.cc b/lib/substitutions/src/substitutions/tensor_pattern/tensor_attribute_pattern.dtg.cc
new file mode 100644
index 0000000000..121549d4dc
--- /dev/null
+++ b/lib/substitutions/src/substitutions/tensor_pattern/tensor_attribute_pattern.dtg.cc
@@ -0,0 +1,71 @@
+// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT!
+// If you would like to modify this datatype, instead modify +// lib/substitutions/include/substitutions/tensor_pattern/tensor_attribute_pattern.struct.toml +/* proj-data +{ + "generated_from": "42a51afce383f1ddc3d70827aa94a68f" +} +*/ + +#include "substitutions/tensor_pattern/tensor_attribute_pattern.dtg.h" + +#include "substitutions/tensor_pattern/tensor_attribute_constraint.dtg.h" +#include "utils/hash-utils.h" +#include +#include + +namespace FlexFlow { +TensorAttributePattern::TensorAttributePattern( + std::unordered_set<::FlexFlow::TensorAttributeConstraint> const + &attribute_constraints) + : attribute_constraints(attribute_constraints) {} +bool TensorAttributePattern::operator==( + TensorAttributePattern const &other) const { + return std::tie(this->attribute_constraints) == + std::tie(other.attribute_constraints); +} +bool TensorAttributePattern::operator!=( + TensorAttributePattern const &other) const { + return std::tie(this->attribute_constraints) != + std::tie(other.attribute_constraints); +} +} // namespace FlexFlow + +namespace std { +size_t hash::operator()( + FlexFlow::TensorAttributePattern const &x) const { + size_t result = 0; + result ^= + std::hash>{}( + x.attribute_constraints) + + 0x9e3779b9 + (result << 6) + (result >> 2); + return result; +} +} // namespace std + +namespace nlohmann { +FlexFlow::TensorAttributePattern + adl_serializer::from_json(json const &j) { + return {j.at("attribute_constraints") + .template get< + std::unordered_set<::FlexFlow::TensorAttributeConstraint>>()}; +} +void adl_serializer::to_json( + json &j, FlexFlow::TensorAttributePattern const &v) { + j["__type"] = "TensorAttributePattern"; + j["attribute_constraints"] = v.attribute_constraints; +} +} // namespace nlohmann + +namespace FlexFlow { +std::string format_as(TensorAttributePattern const &x) { + std::ostringstream oss; + oss << ""; + return oss.str(); +} +std::ostream &operator<<(std::ostream &s, TensorAttributePattern const &x) { + return s << fmt::to_string(x); +} +} // namespace FlexFlow diff --git a/lib/substitutions/src/substitutions/tensor_pattern/tensor_attribute_value.dtg.cc b/lib/substitutions/src/substitutions/tensor_pattern/tensor_attribute_value.dtg.cc new file mode 100644 index 0000000000..27a82c4ffe --- /dev/null +++ b/lib/substitutions/src/substitutions/tensor_pattern/tensor_attribute_value.dtg.cc @@ -0,0 +1,105 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! 
+// If you would like to modify this datatype, instead modify +// lib/substitutions/include/substitutions/tensor_pattern/tensor_attribute_value.variant.toml +/* proj-data +{ + "generated_from": "d80cf2e618d64df284c2647430a12a86" +} +*/ + +#include "substitutions/tensor_pattern/tensor_attribute_value.dtg.h" + +#include "fmt/format.h" +#include +#include + +namespace FlexFlow { +TensorAttributeValue::TensorAttributeValue(size_t const &v) : raw_variant(v) {} +TensorAttributeValue::TensorAttributeValue(std::vector const &v) + : raw_variant(v) {} +bool TensorAttributeValue::operator==(TensorAttributeValue const &other) const { + return this->raw_variant == other.raw_variant; +} +bool TensorAttributeValue::operator!=(TensorAttributeValue const &other) const { + return this->raw_variant != other.raw_variant; +} +bool TensorAttributeValue::operator<(TensorAttributeValue const &other) const { + return this->raw_variant < other.raw_variant; +} +bool TensorAttributeValue::operator>(TensorAttributeValue const &other) const { + return this->raw_variant > other.raw_variant; +} +bool TensorAttributeValue::operator<=(TensorAttributeValue const &other) const { + return this->raw_variant <= other.raw_variant; +} +bool TensorAttributeValue::operator>=(TensorAttributeValue const &other) const { + return this->raw_variant >= other.raw_variant; +} +} // namespace FlexFlow +namespace std { +size_t hash<::FlexFlow::TensorAttributeValue>::operator()( + ::FlexFlow::TensorAttributeValue const &x) const { + return std::hash>>{}(x.raw_variant); +} +} // namespace std +namespace nlohmann { +::FlexFlow::TensorAttributeValue + adl_serializer<::FlexFlow::TensorAttributeValue>::from_json(json const &j) { + std::string key = j.at("type").template get(); + if (key == "size_t") { + return ::FlexFlow::TensorAttributeValue{ + j.at("value").template get()}; + } else if (key == "std::vector") { + return ::FlexFlow::TensorAttributeValue{ + j.at("value").template get>()}; + } else { + throw std::runtime_error(fmt::format("Unknown type key {}", key)); + } +} +void adl_serializer<::FlexFlow::TensorAttributeValue>::to_json( + json &j, ::FlexFlow::TensorAttributeValue const &x) { + j["__type"] = "TensorAttributeValue"; + switch (x.index()) { + case 0: { + j["type"] = "size_t"; + j["value"] = x.get(); + break; + } + case 1: { + j["type"] = "std::vector"; + j["value"] = x.get>(); + break; + } + default: { + throw std::runtime_error(fmt::format( + "Unknown index {} for type TensorAttributeValue", x.index())); + } + } +} +} // namespace nlohmann +namespace FlexFlow { +std::string format_as(::FlexFlow::TensorAttributeValue const &x) { + std::ostringstream oss; + switch (x.index()) { + case 0: { + oss << ""; + break; + } + case 1: { + oss << "=" + << x.get>() << ">"; + break; + } + default: { + throw std::runtime_error(fmt::format( + "Unknown index {} for type TensorAttributeValue", x.index())); + break; + } + } + return oss.str(); +} +std::ostream &operator<<(std::ostream &s, + ::FlexFlow::TensorAttributeValue const &x) { + return s << fmt::to_string(x); +} +} // namespace FlexFlow diff --git a/lib/substitutions/src/substitutions/unlabelled/closed_pattern_edge.dtg.cc b/lib/substitutions/src/substitutions/unlabelled/closed_pattern_edge.dtg.cc new file mode 100644 index 0000000000..fbefc6f01a --- /dev/null +++ b/lib/substitutions/src/substitutions/unlabelled/closed_pattern_edge.dtg.cc @@ -0,0 +1,45 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! 
+// If you would like to modify this datatype, instead modify +// lib/substitutions/include/substitutions/unlabelled/closed_pattern_edge.struct.toml +/* proj-data +{ + "generated_from": "b4086fd78ca7ec0475ed7abfd034304c" +} +*/ + +#include "substitutions/unlabelled/closed_pattern_edge.dtg.h" + +#include "utils/graph.h" + +namespace FlexFlow { +ClosedPatternEdge::ClosedPatternEdge(::FlexFlow::MultiDiEdge const &raw_edge) + : raw_edge(raw_edge) {} +bool ClosedPatternEdge::operator==(ClosedPatternEdge const &other) const { + return std::tie(this->raw_edge) == std::tie(other.raw_edge); +} +bool ClosedPatternEdge::operator!=(ClosedPatternEdge const &other) const { + return std::tie(this->raw_edge) != std::tie(other.raw_edge); +} +bool ClosedPatternEdge::operator<(ClosedPatternEdge const &other) const { + return std::tie(this->raw_edge) < std::tie(other.raw_edge); +} +bool ClosedPatternEdge::operator>(ClosedPatternEdge const &other) const { + return std::tie(this->raw_edge) > std::tie(other.raw_edge); +} +bool ClosedPatternEdge::operator<=(ClosedPatternEdge const &other) const { + return std::tie(this->raw_edge) <= std::tie(other.raw_edge); +} +bool ClosedPatternEdge::operator>=(ClosedPatternEdge const &other) const { + return std::tie(this->raw_edge) >= std::tie(other.raw_edge); +} +} // namespace FlexFlow + +namespace std { +size_t hash::operator()( + FlexFlow::ClosedPatternEdge const &x) const { + size_t result = 0; + result ^= std::hash<::FlexFlow::MultiDiEdge>{}(x.raw_edge) + 0x9e3779b9 + + (result << 6) + (result >> 2); + return result; +} +} // namespace std diff --git a/lib/substitutions/src/substitutions/unlabelled/downward_open_pattern_edge.cc b/lib/substitutions/src/substitutions/unlabelled/downward_open_pattern_edge.cc new file mode 100644 index 0000000000..704e0aea1a --- /dev/null +++ b/lib/substitutions/src/substitutions/unlabelled/downward_open_pattern_edge.cc @@ -0,0 +1,9 @@ +#include "substitutions/unlabelled/downward_open_pattern_edge.h" + +namespace FlexFlow { + +int get_src_idx(DownwardOpenPatternEdge const &e) { + return get_src_idx(e.raw_edge); +} + +} // namespace FlexFlow diff --git a/lib/substitutions/src/substitutions/unlabelled/downward_open_pattern_edge.dtg.cc b/lib/substitutions/src/substitutions/unlabelled/downward_open_pattern_edge.dtg.cc new file mode 100644 index 0000000000..30c52fbbb2 --- /dev/null +++ b/lib/substitutions/src/substitutions/unlabelled/downward_open_pattern_edge.dtg.cc @@ -0,0 +1,52 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! 
+// If you would like to modify this datatype, instead modify +// lib/substitutions/include/substitutions/unlabelled/downward_open_pattern_edge.struct.toml +/* proj-data +{ + "generated_from": "c67ec363a91ce090dc538dcf76fa1f12" +} +*/ + +#include "substitutions/unlabelled/downward_open_pattern_edge.dtg.h" + +#include "utils/graph.h" + +namespace FlexFlow { +DownwardOpenPatternEdge::DownwardOpenPatternEdge( + ::FlexFlow::DownwardOpenMultiDiEdge const &raw_edge) + : raw_edge(raw_edge) {} +bool DownwardOpenPatternEdge::operator==( + DownwardOpenPatternEdge const &other) const { + return std::tie(this->raw_edge) == std::tie(other.raw_edge); +} +bool DownwardOpenPatternEdge::operator!=( + DownwardOpenPatternEdge const &other) const { + return std::tie(this->raw_edge) != std::tie(other.raw_edge); +} +bool DownwardOpenPatternEdge::operator<( + DownwardOpenPatternEdge const &other) const { + return std::tie(this->raw_edge) < std::tie(other.raw_edge); +} +bool DownwardOpenPatternEdge::operator>( + DownwardOpenPatternEdge const &other) const { + return std::tie(this->raw_edge) > std::tie(other.raw_edge); +} +bool DownwardOpenPatternEdge::operator<=( + DownwardOpenPatternEdge const &other) const { + return std::tie(this->raw_edge) <= std::tie(other.raw_edge); +} +bool DownwardOpenPatternEdge::operator>=( + DownwardOpenPatternEdge const &other) const { + return std::tie(this->raw_edge) >= std::tie(other.raw_edge); +} +} // namespace FlexFlow + +namespace std { +size_t hash::operator()( + FlexFlow::DownwardOpenPatternEdge const &x) const { + size_t result = 0; + result ^= std::hash<::FlexFlow::DownwardOpenMultiDiEdge>{}(x.raw_edge) + + 0x9e3779b9 + (result << 6) + (result >> 2); + return result; +} +} // namespace std diff --git a/lib/substitutions/src/substitutions/unlabelled/edge_splits.cc b/lib/substitutions/src/substitutions/unlabelled/edge_splits.cc new file mode 100644 index 0000000000..33ea7dc9f6 --- /dev/null +++ b/lib/substitutions/src/substitutions/unlabelled/edge_splits.cc @@ -0,0 +1,35 @@ +#include "substitutions/unlabelled/edge_splits.h" + +namespace FlexFlow { + +std::pair + get_split_edges(UnlabelledPatternEdgeSplits const &splits, + ClosedPatternEdge const &e) { + std::pair raw_result = + splits.unwrapped.at_l(e.raw_edge); + return { + OutputPatternEdge{raw_result.first}, + InputPatternEdge{raw_result.second}, + }; +} + +std::vector> + as_closed_output_input_tuples(UnlabelledPatternEdgeSplits const &s) { + std::vector< + std::tuple> + result; + + for (auto const &kv : s.unwrapped) { + MultiDiEdge standard_edge = kv.first; + OutputMultiDiEdge output_edge = kv.second.first; + InputMultiDiEdge input_edge = kv.second.second; + + result.push_back({ClosedPatternEdge{standard_edge}, + OutputPatternEdge{output_edge}, + InputPatternEdge{input_edge}}); + } + + return result; +} + +} // namespace FlexFlow diff --git a/lib/substitutions/src/substitutions/unlabelled/edge_splits.dtg.cc b/lib/substitutions/src/substitutions/unlabelled/edge_splits.dtg.cc new file mode 100644 index 0000000000..4da15179da --- /dev/null +++ b/lib/substitutions/src/substitutions/unlabelled/edge_splits.dtg.cc @@ -0,0 +1,31 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! 
+// If you would like to modify this datatype, instead modify +// lib/substitutions/include/substitutions/unlabelled/edge_splits.struct.toml +/* proj-data +{ + "generated_from": "f172b041a99f4de1d396e5d451a5e64d" +} +*/ + +#include "substitutions/unlabelled/edge_splits.dtg.h" + +#include "utils/bidict.h" +#include "utils/graph.h" +#include + +namespace FlexFlow { +UnlabelledPatternEdgeSplits::UnlabelledPatternEdgeSplits( + ::FlexFlow::bidict<::FlexFlow::MultiDiEdge, + std::pair<::FlexFlow::OutputMultiDiEdge, + ::FlexFlow::InputMultiDiEdge>> const + &unwrapped) + : unwrapped(unwrapped) {} +bool UnlabelledPatternEdgeSplits::operator==( + UnlabelledPatternEdgeSplits const &other) const { + return std::tie(this->unwrapped) == std::tie(other.unwrapped); +} +bool UnlabelledPatternEdgeSplits::operator!=( + UnlabelledPatternEdgeSplits const &other) const { + return std::tie(this->unwrapped) != std::tie(other.unwrapped); +} +} // namespace FlexFlow diff --git a/lib/substitutions/src/substitutions/unlabelled/find_pattern_matches.cc b/lib/substitutions/src/substitutions/unlabelled/find_pattern_matches.cc new file mode 100644 index 0000000000..8c787ca255 --- /dev/null +++ b/lib/substitutions/src/substitutions/unlabelled/find_pattern_matches.cc @@ -0,0 +1,161 @@ +#include "substitutions/unlabelled/find_pattern_matches.h" +#include "substitutions/unlabelled/downward_open_pattern_edge.h" +#include "substitutions/unlabelled/multidigraph_pattern_match.h" +#include "substitutions/unlabelled/unlabelled_graph_pattern.h" +#include "substitutions/unlabelled/upward_open_pattern_edge.h" +#include "utils/containers.h" + +namespace FlexFlow { + +static std::vector + sorted_by_dst_idx(std::unordered_set const &in) { + return sorted_by( + in, compare_by([](UpwardOpenPatternEdge const &e) { + return get_dst_idx(e); + })); +} + +static std::vector + sorted_by_src_idx(std::unordered_set const &in) { + return sorted_by( + in, + compare_by( + [](DownwardOpenPatternEdge const &e) { return get_src_idx(e); })); +} + +static std::vector + sorted_by_dst_idx(std::unordered_set const &in) { + return sorted_by( + in, compare_by([](UpwardOpenPatternEdge const &e) { + return get_dst_idx(e); + })); +} + +static std::vector + sorted_by_src_idx(std::unordered_set const &in) { + return sorted_by( + in, + compare_by( + [](DownwardOpenMultiDiEdge const &e) { return get_src_idx(e); })); +} + +static std::optional + get_candidate_singleton_match(UnlabelledGraphPattern const &pattern, + OpenMultiDiGraphView const &graph, + Node const &graph_node) { + assert(is_singleton_pattern(pattern)); + + PatternNode pattern_node = get_only(get_nodes(pattern)); + + MultiDiGraphPatternMatch match = empty_multidigraph_pattern_match(); + match.node_assignment.equate(pattern_node, graph_node); + + std::unordered_set incoming = + get_incoming_edges(graph, graph_node); + std::unordered_set outgoing = + get_outgoing_edges(graph, graph_node); + + std::unordered_set pattern_incoming = + get_incoming_edges(pattern, pattern_node); + std::unordered_set pattern_outgoing = + get_outgoing_edges(pattern, pattern_node); + + if (!pattern_incoming.empty() && pattern_incoming.size() != incoming.size()) { + return std::nullopt; + } + + if (!pattern_outgoing.empty() && pattern_outgoing.size() != outgoing.size()) { + return std::nullopt; + } + + std::vector incoming_ordered = + sorted_by_dst_idx(incoming); + std::vector outgoing_ordered = + sorted_by_src_idx(outgoing); + + std::vector pattern_incoming_ordered = + sorted_by_dst_idx(pattern_incoming); + std::vector 
pattern_outgoing_ordered =
+      sorted_by_src_idx(pattern_outgoing);
+
+  if (pattern_incoming.size() > 0) {
+    std::unordered_map<NodePort, NodePort> node_port_mapping;
+    for (int i = 0; i < incoming_ordered.size(); ++i) {
+      UpwardOpenMultiDiEdge graph_edge = incoming_ordered[i];
+      UpwardOpenPatternEdge pattern_edge = pattern_incoming_ordered[i];
+      NodePort graph_port = get_dst_idx(graph_edge),
+               pattern_port = get_dst_idx(pattern_edge);
+      if (!contains_key(node_port_mapping, graph_port)) {
+        node_port_mapping.emplace(graph_port, pattern_port);
+      } else {
+        if (pattern_port != node_port_mapping.at(graph_port)) {
+          return std::nullopt;
+        }
+      }
+      match.edge_assignment.equate(widen<PatternEdge>(pattern_edge),
+                                   widen<OpenMultiDiEdge>(graph_edge));
+    }
+  }
+
+  if (pattern_outgoing.size() > 0) {
+    std::unordered_map<NodePort, NodePort> node_port_mapping;
+    for (int i = 0; i < outgoing_ordered.size(); ++i) {
+      DownwardOpenMultiDiEdge graph_edge = outgoing_ordered[i];
+      DownwardOpenPatternEdge pattern_edge = pattern_outgoing_ordered[i];
+
+      NodePort graph_port = get_src_idx(graph_edge),
+               pattern_port = get_src_idx(pattern_edge);
+      if (!contains_key(node_port_mapping, graph_port)) {
+        node_port_mapping.insert({graph_port, pattern_port});
+      } else {
+        if (pattern_port != node_port_mapping.at(graph_port)) {
+          return std::nullopt;
+        }
+      }
+      match.edge_assignment.equate(widen<PatternEdge>(pattern_edge),
+                                   widen<OpenMultiDiEdge>(graph_edge));
+    }
+  }
+
+  return match;
+}
+
+std::vector<MultiDiGraphPatternMatch>
+    find_pattern_matches(UnlabelledGraphPattern const &pattern,
+                         OpenMultiDiGraphView const &graph,
+                         MatchAdditionalCriterion const &additional_criterion) {
+  std::vector<MultiDiGraphPatternMatch> matches;
+  if (is_singleton_pattern(pattern)) {
+    for (Node const &graph_node : get_nodes(graph)) {
+      std::optional<MultiDiGraphPatternMatch> candidate =
+          get_candidate_singleton_match(pattern, graph, graph_node);
+      if (candidate.has_value() &&
+          pattern_does_match(
+              pattern, graph, candidate.value(), additional_criterion)) {
+        matches.push_back(candidate.value());
+      }
+    }
+  } else {
+    GraphSplit split = split_pattern(pattern);
+    auto subpatterns = apply_split(pattern, split);
+    auto prefix_matches =
+        find_pattern_matches(subpatterns.first, graph, additional_criterion);
+    auto postfix_matches =
+        find_pattern_matches(subpatterns.second, graph, additional_criterion);
+    auto edge_splits = get_edge_splits(pattern, split);
+    for (MultiDiGraphPatternMatch const &prefix_match : prefix_matches) {
+      for (MultiDiGraphPatternMatch const &postfix_match : postfix_matches) {
+        std::optional<MultiDiGraphPatternMatch> unsplit =
+            unsplit_matches(prefix_match, postfix_match, edge_splits);
+        if (unsplit.has_value()) {
+          matches.push_back(unsplit.value());
+        }
+      }
+    }
+  }
+
+  return matches;
+}
+
+} // namespace FlexFlow
diff --git a/lib/substitutions/src/substitutions/unlabelled/input_pattern_edge.cc b/lib/substitutions/src/substitutions/unlabelled/input_pattern_edge.cc
new file mode 100644
index 0000000000..2eff39bb1e
--- /dev/null
+++ b/lib/substitutions/src/substitutions/unlabelled/input_pattern_edge.cc
@@ -0,0 +1,9 @@
+#include "substitutions/unlabelled/input_pattern_edge.h"
+
+namespace FlexFlow {
+
+PatternNode get_dst_node(InputPatternEdge const &e) {
+  return PatternNode{e.raw_edge.dst};
+}
+
+} // namespace FlexFlow
diff --git a/lib/substitutions/src/substitutions/unlabelled/input_pattern_edge.dtg.cc b/lib/substitutions/src/substitutions/unlabelled/input_pattern_edge.dtg.cc
new file mode 100644
index 0000000000..f3f5a8ce45
--- /dev/null
+++ b/lib/substitutions/src/substitutions/unlabelled/input_pattern_edge.dtg.cc
@@ -0,0 +1,45 @@
+// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT!
+// If you would like to modify this datatype, instead modify +// lib/substitutions/include/substitutions/unlabelled/input_pattern_edge.struct.toml +/* proj-data +{ + "generated_from": "d0cc0e65c4e3feb2e9b8435947c99e5f" +} +*/ + +#include "substitutions/unlabelled/input_pattern_edge.dtg.h" + +#include "utils/graph.h" + +namespace FlexFlow { +InputPatternEdge::InputPatternEdge(::FlexFlow::InputMultiDiEdge const &raw_edge) + : raw_edge(raw_edge) {} +bool InputPatternEdge::operator==(InputPatternEdge const &other) const { + return std::tie(this->raw_edge) == std::tie(other.raw_edge); +} +bool InputPatternEdge::operator!=(InputPatternEdge const &other) const { + return std::tie(this->raw_edge) != std::tie(other.raw_edge); +} +bool InputPatternEdge::operator<(InputPatternEdge const &other) const { + return std::tie(this->raw_edge) < std::tie(other.raw_edge); +} +bool InputPatternEdge::operator>(InputPatternEdge const &other) const { + return std::tie(this->raw_edge) > std::tie(other.raw_edge); +} +bool InputPatternEdge::operator<=(InputPatternEdge const &other) const { + return std::tie(this->raw_edge) <= std::tie(other.raw_edge); +} +bool InputPatternEdge::operator>=(InputPatternEdge const &other) const { + return std::tie(this->raw_edge) >= std::tie(other.raw_edge); +} +} // namespace FlexFlow + +namespace std { +size_t hash::operator()( + FlexFlow::InputPatternEdge const &x) const { + size_t result = 0; + result ^= std::hash<::FlexFlow::InputMultiDiEdge>{}(x.raw_edge) + 0x9e3779b9 + + (result << 6) + (result >> 2); + return result; +} +} // namespace std diff --git a/lib/substitutions/src/substitutions/unlabelled/match_additional_criterion.dtg.cc b/lib/substitutions/src/substitutions/unlabelled/match_additional_criterion.dtg.cc new file mode 100644 index 0000000000..613159ad83 --- /dev/null +++ b/lib/substitutions/src/substitutions/unlabelled/match_additional_criterion.dtg.cc @@ -0,0 +1,25 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! 
+// If you would like to modify this datatype, instead modify
+// lib/substitutions/include/substitutions/unlabelled/match_additional_criterion.struct.toml
+/* proj-data
+{
+  "generated_from": "2dff356c85dccda1fce8f714d41c6202"
+}
+*/
+
+#include "substitutions/unlabelled/match_additional_criterion.dtg.h"
+
+#include "substitutions/unlabelled/pattern_edge.dtg.h"
+#include "substitutions/unlabelled/pattern_node.dtg.h"
+#include "utils/graph.h"
+#include <functional>
+
+namespace FlexFlow {
+MatchAdditionalCriterion::MatchAdditionalCriterion(
+    std::function<bool(PatternNode const &, Node const &)> const
+        &node_criterion,
+    std::function<bool(PatternEdge const &, OpenMultiDiEdge const &)> const
+        &edge_criterion)
+    : node_criterion(node_criterion), edge_criterion(edge_criterion) {}
+} // namespace FlexFlow
diff --git a/lib/substitutions/src/substitutions/unlabelled/match_split.cc b/lib/substitutions/src/substitutions/unlabelled/match_split.cc
new file mode 100644
index 0000000000..ef0397d6a8
--- /dev/null
+++ b/lib/substitutions/src/substitutions/unlabelled/match_split.cc
@@ -0,0 +1,69 @@
+#include "substitutions/unlabelled/match_split.h"
+#include "substitutions/unlabelled/edge_splits.h"
+#include "substitutions/unlabelled/multidigraph_pattern_match.h"
+#include "substitutions/unlabelled/pattern_edge.h"
+#include "substitutions/unlabelled/pattern_split.h"
+
+namespace FlexFlow {
+
+MatchSplit empty_match_split() {
+  return MatchSplit{empty_multidigraph_pattern_match(),
+                    empty_multidigraph_pattern_match()};
+}
+
+MatchSplit apply_split(UnlabelledGraphPattern const &pattern,
+                       MultiDiGraphPatternMatch const &match,
+                       PatternSplit const &split) {
+  std::unordered_set<PatternNode> prefix = split.first;
+  std::unordered_set<PatternNode> postfix = split.second;
+
+  MatchSplit result = empty_match_split();
+
+  for (auto const &[pattern_node, match_node] : match.node_assignment) {
+    if (contains(split.first, pattern_node)) {
+      result.prefix_submatch.node_assignment.equate(pattern_node, match_node);
+    } else {
+      assert(contains(split.second, pattern_node));
+      result.postfix_submatch.node_assignment.equate(pattern_node, match_node);
+    }
+  }
+
+  UnlabelledPatternEdgeSplits edge_splits = get_edge_splits(pattern, split);
+
+  std::function<void(PatternEdge const &, OpenMultiDiEdge const &)>
+      handle_edge = [&](PatternEdge const &pattern_edge,
+                        OpenMultiDiEdge const &graph_edge) -> void {
+    std::unordered_set<PatternNode> edge_nodes = get_nodes(pattern_edge);
+
+    if (is_subseteq_of(edge_nodes, prefix)) {
+      result.prefix_submatch.edge_assignment.equate(pattern_edge, graph_edge);
+    } else if (is_subseteq_of(edge_nodes, postfix)) {
+      result.postfix_submatch.edge_assignment.equate(pattern_edge, graph_edge);
+    } else {
+      assert(is_standard_edge(graph_edge));
+
+      ClosedPatternEdge closed_edge = require_closed_edge(pattern_edge);
+
+      auto split = get_split_edges(edge_splits, closed_edge);
+      OutputPatternEdge output_edge = split.first;
+      InputPatternEdge input_edge = split.second;
+
+      auto split_graph_edge = split_edge(std::get<MultiDiEdge>(graph_edge));
+      OutputMultiDiEdge output_graph_edge = split_graph_edge.first;
+      InputMultiDiEdge input_graph_edge = split_graph_edge.second;
+
+      handle_edge(pattern_edge_from_input_edge(input_edge),
+                  OpenMultiDiEdge{input_graph_edge});
+      handle_edge(pattern_edge_from_output_edge(output_edge),
+                  OpenMultiDiEdge{output_graph_edge});
+    }
+  };
+
+  for (auto const &[pattern_edge, match_edge] : match.edge_assignment) {
+    handle_edge(pattern_edge, match_edge);
+  }
+
+  return result;
+}
+
+} // namespace FlexFlow
diff --git a/lib/substitutions/src/substitutions/unlabelled/match_split.dtg.cc b/lib/substitutions/src/substitutions/unlabelled/match_split.dtg.cc
new file mode 100644
index
0000000000..ffbdf96912 --- /dev/null +++ b/lib/substitutions/src/substitutions/unlabelled/match_split.dtg.cc @@ -0,0 +1,26 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/substitutions/include/substitutions/unlabelled/match_split.struct.toml +/* proj-data +{ + "generated_from": "e44c4347e07263a493cbbd5caccedd22" +} +*/ + +#include "substitutions/unlabelled/match_split.dtg.h" + +#include "substitutions/unlabelled/multidigraph_pattern_match.dtg.h" + +namespace FlexFlow { +MatchSplit::MatchSplit(MultiDiGraphPatternMatch const &prefix_submatch, + MultiDiGraphPatternMatch const &postfix_submatch) + : prefix_submatch(prefix_submatch), postfix_submatch(postfix_submatch) {} +bool MatchSplit::operator==(MatchSplit const &other) const { + return std::tie(this->prefix_submatch, this->postfix_submatch) == + std::tie(other.prefix_submatch, other.postfix_submatch); +} +bool MatchSplit::operator!=(MatchSplit const &other) const { + return std::tie(this->prefix_submatch, this->postfix_submatch) != + std::tie(other.prefix_submatch, other.postfix_submatch); +} +} // namespace FlexFlow diff --git a/lib/substitutions/src/substitutions/unlabelled/multidigraph_pattern_match.cc b/lib/substitutions/src/substitutions/unlabelled/multidigraph_pattern_match.cc new file mode 100644 index 0000000000..8f4fd7f535 --- /dev/null +++ b/lib/substitutions/src/substitutions/unlabelled/multidigraph_pattern_match.cc @@ -0,0 +1,56 @@ +#include "substitutions/unlabelled/multidigraph_pattern_match.h" +#include "substitutions/unlabelled/edge_splits.h" +#include "substitutions/unlabelled/pattern_edge.h" +#include "utils/containers.h" + +namespace FlexFlow { + +MultiDiGraphPatternMatch empty_multidigraph_pattern_match() { + return MultiDiGraphPatternMatch{ + bidict{}, + bidict{}, + }; +} + +std::optional + unsplit_matches(MultiDiGraphPatternMatch const &prefix, + MultiDiGraphPatternMatch const &postfix, + UnlabelledPatternEdgeSplits const &edge_splits) { + + MultiDiGraphPatternMatch result = empty_multidigraph_pattern_match(); + + std::unordered_set handled; + for (auto const &coi : as_closed_output_input_tuples(edge_splits)) { + ClosedPatternEdge closed_edge = std::get(coi); + OutputPatternEdge output_edge = std::get(coi); + InputPatternEdge input_edge = std::get(coi); + + handled.insert(pattern_edge_from_output_edge(output_edge)); + handled.insert(pattern_edge_from_input_edge(input_edge)); + + OpenMultiDiEdge output_graph_edge = + prefix.edge_assignment.at_l(pattern_edge_from_output_edge(output_edge)); + OpenMultiDiEdge input_graph_edge = + postfix.edge_assignment.at_l(pattern_edge_from_input_edge(input_edge)); + if (output_graph_edge == input_graph_edge) { + result.edge_assignment.equate(pattern_edge_from_closed_edge(closed_edge), + output_graph_edge); + } else { + return std::nullopt; + } + } + + for (auto const &kv : + merge_maps(prefix.edge_assignment, postfix.edge_assignment)) { + if (!contains(handled, kv.first)) { + result.edge_assignment.equate(kv.first, kv.second); + } + } + + result.node_assignment = + merge_maps(prefix.node_assignment, postfix.node_assignment); + + return result; +} + +} // namespace FlexFlow diff --git a/lib/substitutions/src/substitutions/unlabelled/multidigraph_pattern_match.dtg.cc b/lib/substitutions/src/substitutions/unlabelled/multidigraph_pattern_match.dtg.cc new file mode 100644 index 0000000000..9fc2169dd7 --- /dev/null +++ b/lib/substitutions/src/substitutions/unlabelled/multidigraph_pattern_match.dtg.cc @@ -0,0 
+1,34 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/substitutions/include/substitutions/unlabelled/multidigraph_pattern_match.struct.toml +/* proj-data +{ + "generated_from": "9842661a5d4e7d717f12d2c27da7df0d" +} +*/ + +#include "substitutions/unlabelled/multidigraph_pattern_match.dtg.h" + +#include "substitutions/unlabelled/pattern_edge.dtg.h" +#include "substitutions/unlabelled/pattern_node.dtg.h" +#include "utils/bidict.h" +#include "utils/graph.h" + +namespace FlexFlow { +MultiDiGraphPatternMatch::MultiDiGraphPatternMatch( + ::FlexFlow::bidict<::FlexFlow::PatternNode, ::FlexFlow::Node> const + &node_assignment, + ::FlexFlow::bidict<::FlexFlow::PatternEdge, + ::FlexFlow::OpenMultiDiEdge> const &edge_assignment) + : node_assignment(node_assignment), edge_assignment(edge_assignment) {} +bool MultiDiGraphPatternMatch::operator==( + MultiDiGraphPatternMatch const &other) const { + return std::tie(this->node_assignment, this->edge_assignment) == + std::tie(other.node_assignment, other.edge_assignment); +} +bool MultiDiGraphPatternMatch::operator!=( + MultiDiGraphPatternMatch const &other) const { + return std::tie(this->node_assignment, this->edge_assignment) != + std::tie(other.node_assignment, other.edge_assignment); +} +} // namespace FlexFlow diff --git a/lib/substitutions/src/substitutions/unlabelled/output_pattern_edge.cc b/lib/substitutions/src/substitutions/unlabelled/output_pattern_edge.cc new file mode 100644 index 0000000000..6e70fc8df6 --- /dev/null +++ b/lib/substitutions/src/substitutions/unlabelled/output_pattern_edge.cc @@ -0,0 +1,9 @@ +#include "substitutions/unlabelled/output_pattern_edge.h" + +namespace FlexFlow { + +PatternNode get_src_node(OutputPatternEdge const &e) { + return PatternNode{e.raw_edge.src}; +} + +} // namespace FlexFlow diff --git a/lib/substitutions/src/substitutions/unlabelled/output_pattern_edge.dtg.cc b/lib/substitutions/src/substitutions/unlabelled/output_pattern_edge.dtg.cc new file mode 100644 index 0000000000..fb9de06135 --- /dev/null +++ b/lib/substitutions/src/substitutions/unlabelled/output_pattern_edge.dtg.cc @@ -0,0 +1,46 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! 
+// If you would like to modify this datatype, instead modify +// lib/substitutions/include/substitutions/unlabelled/output_pattern_edge.struct.toml +/* proj-data +{ + "generated_from": "3222696e351c3e203e008714245c737f" +} +*/ + +#include "substitutions/unlabelled/output_pattern_edge.dtg.h" + +#include "utils/graph.h" + +namespace FlexFlow { +OutputPatternEdge::OutputPatternEdge( + ::FlexFlow::OutputMultiDiEdge const &raw_edge) + : raw_edge(raw_edge) {} +bool OutputPatternEdge::operator==(OutputPatternEdge const &other) const { + return std::tie(this->raw_edge) == std::tie(other.raw_edge); +} +bool OutputPatternEdge::operator!=(OutputPatternEdge const &other) const { + return std::tie(this->raw_edge) != std::tie(other.raw_edge); +} +bool OutputPatternEdge::operator<(OutputPatternEdge const &other) const { + return std::tie(this->raw_edge) < std::tie(other.raw_edge); +} +bool OutputPatternEdge::operator>(OutputPatternEdge const &other) const { + return std::tie(this->raw_edge) > std::tie(other.raw_edge); +} +bool OutputPatternEdge::operator<=(OutputPatternEdge const &other) const { + return std::tie(this->raw_edge) <= std::tie(other.raw_edge); +} +bool OutputPatternEdge::operator>=(OutputPatternEdge const &other) const { + return std::tie(this->raw_edge) >= std::tie(other.raw_edge); +} +} // namespace FlexFlow + +namespace std { +size_t hash::operator()( + FlexFlow::OutputPatternEdge const &x) const { + size_t result = 0; + result ^= std::hash<::FlexFlow::OutputMultiDiEdge>{}(x.raw_edge) + + 0x9e3779b9 + (result << 6) + (result >> 2); + return result; +} +} // namespace std diff --git a/lib/substitutions/src/substitutions/unlabelled/pattern_edge.cc b/lib/substitutions/src/substitutions/unlabelled/pattern_edge.cc new file mode 100644 index 0000000000..3dd4987705 --- /dev/null +++ b/lib/substitutions/src/substitutions/unlabelled/pattern_edge.cc @@ -0,0 +1,50 @@ +#include "substitutions/unlabelled/pattern_edge.h" +#include "utils/containers.h" + +namespace FlexFlow { + +std::unordered_set get_nodes(PatternEdge const &e) { + return transform(get_nodes(e.raw_edge), + [](Node const &n) { return PatternNode{n}; }); +} + +bool is_standard_edge(PatternEdge const &e) { + return is_standard_edge(e.raw_edge); +} + +bool is_input_edge(PatternEdge const &e) { + return is_input_edge(e.raw_edge); +} + +bool is_output_edge(PatternEdge const &e) { + return is_output_edge(e.raw_edge); +} + +ClosedPatternEdge require_closed_edge(PatternEdge const &e) { + assert(is_closed_edge(e)); + return ClosedPatternEdge{std::get(e.raw_edge)}; +} + +InputPatternEdge require_input_edge(PatternEdge const &e) { + assert(is_input_edge(e)); + return InputPatternEdge{std::get(e.raw_edge)}; +} + +OutputPatternEdge require_output_edge(PatternEdge const &e) { + assert(is_output_edge(e)); + return OutputPatternEdge{std::get(e.raw_edge)}; +} + +PatternEdge pattern_edge_from_input_edge(InputPatternEdge const &e) { + return PatternEdge{OpenMultiDiEdge{e.raw_edge}}; +} + +PatternEdge pattern_edge_from_output_edge(OutputPatternEdge const &e) { + return PatternEdge{OpenMultiDiEdge{e.raw_edge}}; +} + +PatternEdge pattern_edge_from_closed_edge(ClosedPatternEdge const &e) { + return PatternEdge{OpenMultiDiEdge{e.raw_edge}}; +} + +} // namespace FlexFlow diff --git a/lib/substitutions/src/substitutions/unlabelled/pattern_edge.dtg.cc b/lib/substitutions/src/substitutions/unlabelled/pattern_edge.dtg.cc new file mode 100644 index 0000000000..e4d11d0d7e --- /dev/null +++ b/lib/substitutions/src/substitutions/unlabelled/pattern_edge.dtg.cc @@ -0,0 
+1,45 @@ +// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT! +// If you would like to modify this datatype, instead modify +// lib/substitutions/include/substitutions/unlabelled/pattern_edge.struct.toml +/* proj-data +{ + "generated_from": "a3eff166b0c8be2ddf3f7305eec094fd" +} +*/ + +#include "substitutions/unlabelled/pattern_edge.dtg.h" + +#include "utils/graph.h" + +namespace FlexFlow { +PatternEdge::PatternEdge(::FlexFlow::OpenMultiDiEdge const &raw_edge) + : raw_edge(raw_edge) {} +bool PatternEdge::operator==(PatternEdge const &other) const { + return std::tie(this->raw_edge) == std::tie(other.raw_edge); +} +bool PatternEdge::operator!=(PatternEdge const &other) const { + return std::tie(this->raw_edge) != std::tie(other.raw_edge); +} +bool PatternEdge::operator<(PatternEdge const &other) const { + return std::tie(this->raw_edge) < std::tie(other.raw_edge); +} +bool PatternEdge::operator>(PatternEdge const &other) const { + return std::tie(this->raw_edge) > std::tie(other.raw_edge); +} +bool PatternEdge::operator<=(PatternEdge const &other) const { + return std::tie(this->raw_edge) <= std::tie(other.raw_edge); +} +bool PatternEdge::operator>=(PatternEdge const &other) const { + return std::tie(this->raw_edge) >= std::tie(other.raw_edge); +} +} // namespace FlexFlow + +namespace std { +size_t hash::operator()( + FlexFlow::PatternEdge const &x) const { + size_t result = 0; + result ^= std::hash<::FlexFlow::OpenMultiDiEdge>{}(x.raw_edge) + 0x9e3779b9 + + (result << 6) + (result >> 2); + return result; +} +} // namespace std diff --git a/lib/substitutions/src/substitutions/unlabelled/pattern_matching.cc b/lib/substitutions/src/substitutions/unlabelled/pattern_matching.cc new file mode 100644 index 0000000000..335b9664ea --- /dev/null +++ b/lib/substitutions/src/substitutions/unlabelled/pattern_matching.cc @@ -0,0 +1,74 @@ +#include "substitutions/unlabelled/pattern_matching.h" +#include "substitutions/unlabelled/input_pattern_edge.h" +#include "substitutions/unlabelled/match_split.h" +#include "substitutions/unlabelled/output_pattern_edge.h" +#include "substitutions/unlabelled/pattern_edge.h" +#include "substitutions/unlabelled/pattern_split.h" +#include "substitutions/unlabelled/unlabelled_graph_pattern.h" +#include + +namespace FlexFlow { + +bool unlabelled_pattern_does_match( + UnlabelledGraphPattern const &pattern, + OpenMultiDiGraphView const &graph, + MultiDiGraphPatternMatch const &match, + MatchAdditionalCriterion const &additional_criterion) { + if (is_singleton_pattern(pattern)) { + PatternNode pattern_node = get_only(get_nodes(pattern)); + Node matched_node = match.node_assignment.at_l(pattern_node); + if (!additional_criterion.node_criterion(pattern_node, matched_node)) { + return false; + } + for (PatternEdge const &e : get_edges(pattern)) { + OpenMultiDiEdge matched_edge = match.edge_assignment.at_l(e); + + assert(is_input_edge(e) || is_output_edge(e)); + if (is_input_edge(e)) { + if (is_output_edge(matched_edge)) { + return false; + } + UpwardOpenMultiDiEdge matched_edge = + narrow(matched_edge).value(); + InputPatternEdge input_edge = require_input_edge(e); + if (match.node_assignment.at_l(get_dst_node(input_edge)) != + get_dst_node(matched_edge)) { + return false; + } + } else { + if (is_input_edge(matched_edge)) { + return false; + } + DownwardOpenMultiDiEdge matched_edge = + narrow(matched_edge).value(); + OutputPatternEdge output_edge = require_output_edge(e); + if (match.node_assignment.at_l(get_src_node(output_edge)) != + get_src_node(matched_edge)) { + return 
false;
+        }
+      }
+
+      if (!additional_criterion.edge_criterion(e, matched_edge)) {
+        return false;
+      }
+    }
+
+    return true;
+  }
+
+  PatternSplit split = find_even_split(pattern);
+  std::pair<UnlabelledGraphPattern, UnlabelledGraphPattern> subpatterns =
+      apply_split(pattern, split);
+  auto submatches = apply_split(pattern, match, split);
+
+  return unlabelled_pattern_does_match(subpatterns.first,
+                                       graph,
+                                       submatches.prefix_submatch,
+                                       additional_criterion) &&
+         unlabelled_pattern_does_match(subpatterns.second,
+                                       graph,
+                                       submatches.postfix_submatch,
+                                       additional_criterion);
+}
+
+} // namespace FlexFlow
diff --git a/lib/substitutions/src/substitutions/unlabelled/pattern_node.dtg.cc b/lib/substitutions/src/substitutions/unlabelled/pattern_node.dtg.cc
new file mode 100644
index 0000000000..6ea64de69e
--- /dev/null
+++ b/lib/substitutions/src/substitutions/unlabelled/pattern_node.dtg.cc
@@ -0,0 +1,45 @@
+// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT!
+// If you would like to modify this datatype, instead modify
+// lib/substitutions/include/substitutions/unlabelled/pattern_node.struct.toml
+/* proj-data
+{
+  "generated_from": "a0e58ade010a9b250d2c1c378fde2639"
+}
+*/
+
+#include "substitutions/unlabelled/pattern_node.dtg.h"
+
+#include "utils/graph.h"
+
+namespace FlexFlow {
+PatternNode::PatternNode(::FlexFlow::Node const &raw_node)
+    : raw_node(raw_node) {}
+bool PatternNode::operator==(PatternNode const &other) const {
+  return std::tie(this->raw_node) == std::tie(other.raw_node);
+}
+bool PatternNode::operator!=(PatternNode const &other) const {
+  return std::tie(this->raw_node) != std::tie(other.raw_node);
+}
+bool PatternNode::operator<(PatternNode const &other) const {
+  return std::tie(this->raw_node) < std::tie(other.raw_node);
+}
+bool PatternNode::operator>(PatternNode const &other) const {
+  return std::tie(this->raw_node) > std::tie(other.raw_node);
+}
+bool PatternNode::operator<=(PatternNode const &other) const {
+  return std::tie(this->raw_node) <= std::tie(other.raw_node);
+}
+bool PatternNode::operator>=(PatternNode const &other) const {
+  return std::tie(this->raw_node) >= std::tie(other.raw_node);
+}
+} // namespace FlexFlow
+
+namespace std {
+size_t hash<FlexFlow::PatternNode>::operator()(
+    FlexFlow::PatternNode const &x) const {
+  size_t result = 0;
+  result ^= std::hash<::FlexFlow::Node>{}(x.raw_node) + 0x9e3779b9 +
+            (result << 6) + (result >> 2);
+  return result;
+}
+} // namespace std
diff --git a/lib/substitutions/src/substitutions/unlabelled/pattern_split.cc b/lib/substitutions/src/substitutions/unlabelled/pattern_split.cc
new file mode 100644
index 0000000000..e116c062df
--- /dev/null
+++ b/lib/substitutions/src/substitutions/unlabelled/pattern_split.cc
@@ -0,0 +1,42 @@
+#include "substitutions/unlabelled/pattern_split.h"
+
+namespace FlexFlow {
+
+PatternSplit find_even_split(UnlabelledGraphPattern const &p) {
+  std::vector<PatternNode> topological_ordering =
+      get_topological_ordering(p);
+  assert(topological_ordering.size() >= 2);
+
+  int split_point = topological_ordering.size() / 2;
+  auto split = vector_split(topological_ordering, split_point);
+  std::unordered_set<PatternNode> prefix(split.first.begin(),
+                                         split.first.end());
+  std::unordered_set<PatternNode> postfix(split.second.begin(),
+                                          split.second.end());
+  return {prefix, postfix};
+}
+
+GraphSplit get_raw_split(PatternSplit const &s) {
+  return std::pair{
+      transform(s.first, [](PatternNode const &n) { return n.raw_node; }),
+      transform(s.second, [](PatternNode const &n) { return n.raw_node; }),
+  };
+}
+
+UnlabelledPatternEdgeSplits
+    get_edge_splits(UnlabelledGraphPattern const &pattern,
+                    PatternSplit const &split) {
+  bidict<MultiDiEdge, std::pair<OutputMultiDiEdge, InputMultiDiEdge>>
+      raw_result = get_edge_splits(pattern.raw_graph, get_raw_split(split));
+  return UnlabelledPatternEdgeSplits{raw_result};
+}
+
+std::pair<UnlabelledGraphPattern, UnlabelledGraphPattern>
+    apply_split(UnlabelledGraphPattern const &p, PatternSplit const &s) {
+  return {
+      get_subgraph(p, s.first),
+      get_subgraph(p, s.second),
+  };
+}
+
+} // namespace FlexFlow
diff --git a/lib/substitutions/src/substitutions/unlabelled/pattern_split.dtg.cc b/lib/substitutions/src/substitutions/unlabelled/pattern_split.dtg.cc
new file mode 100644
index 0000000000..bbcd4c3902
--- /dev/null
+++ b/lib/substitutions/src/substitutions/unlabelled/pattern_split.dtg.cc
@@ -0,0 +1,60 @@
+// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT!
+// If you would like to modify this datatype, instead modify
+// lib/substitutions/include/substitutions/unlabelled/pattern_split.struct.toml
+/* proj-data
+{
+  "generated_from": "8604edb5bd1a546ffa94ef496888e46d"
+}
+*/
+
+#include "substitutions/unlabelled/pattern_split.dtg.h"
+
+#include "substitutions/unlabelled/pattern_node.dtg.h"
+#include "utils/graph.h"
+#include <sstream>
+#include <unordered_set>
+
+namespace FlexFlow {
+PatternSplit::PatternSplit(
+    std::unordered_set<::FlexFlow::PatternNode> const &first,
+    std::unordered_set<::FlexFlow::PatternNode> const &second)
+    : first(first), second(second) {}
+bool PatternSplit::operator==(PatternSplit const &other) const {
+  return std::tie(this->first, this->second) ==
+         std::tie(other.first, other.second);
+}
+bool PatternSplit::operator!=(PatternSplit const &other) const {
+  return std::tie(this->first, this->second) !=
+         std::tie(other.first, other.second);
+}
+} // namespace FlexFlow
+
+namespace nlohmann {
+FlexFlow::PatternSplit
+    adl_serializer<FlexFlow::PatternSplit>::from_json(json const &j) {
+  return {
+      j.at("first")
+          .template get<std::unordered_set<::FlexFlow::PatternNode>>(),
+      j.at("second")
+          .template get<std::unordered_set<::FlexFlow::PatternNode>>()};
+}
+void adl_serializer<FlexFlow::PatternSplit>::to_json(
+    json &j, FlexFlow::PatternSplit const &v) {
+  j["__type"] = "PatternSplit";
+  j["first"] = v.first;
+  j["second"] = v.second;
+}
+} // namespace nlohmann
+
+namespace FlexFlow {
+std::string format_as(PatternSplit const &x) {
+  std::ostringstream oss;
+  oss << "<PatternSplit";
+  oss << " first=" << x.first;
+  oss << " second=" << x.second;
+  oss << ">";
+  return oss.str();
+}
+std::ostream &operator<<(std::ostream &s, PatternSplit const &x) {
+  return s << fmt::to_string(x);
+}
+} // namespace FlexFlow
diff --git a/lib/substitutions/src/substitutions/unlabelled/unlabelled_graph_pattern.cc b/lib/substitutions/src/substitutions/unlabelled/unlabelled_graph_pattern.cc
new file mode 100644
index 0000000000..df10507a04
--- /dev/null
+++ b/lib/substitutions/src/substitutions/unlabelled/unlabelled_graph_pattern.cc
@@ -0,0 +1,52 @@
+#include "substitutions/unlabelled/unlabelled_graph_pattern.h"
+#include "utils/containers.h"
+
+namespace FlexFlow {
+
+size_t num_nodes(UnlabelledGraphPattern const &p) {
+  return num_nodes(p.raw_graph);
+}
+
+bool is_singleton_pattern(UnlabelledGraphPattern const &pattern) {
+  return num_nodes(pattern) == 1;
+}
+
+std::unordered_set<PatternNode> get_nodes(UnlabelledGraphPattern const &p) {
+  return transform(get_nodes(p.raw_graph),
+                   [](Node const &n) { return PatternNode{n}; });
+}
+
+std::unordered_set<PatternEdge> get_edges(UnlabelledGraphPattern const &p) {
+  return transform(get_edges(p.raw_graph),
+                   [](OpenMultiDiEdge const &e) { return PatternEdge{e}; });
+}
+
+std::vector<PatternNode>
+    get_topological_ordering(UnlabelledGraphPattern const &p) {
+  return transform(get_topological_ordering(p.raw_graph),
+                   [](Node const &n) { return PatternNode{n}; });
+}
+
+UnlabelledGraphPattern get_subgraph(UnlabelledGraphPattern const &p,
+                                    std::unordered_set<PatternNode> const &n) {
+  return UnlabelledGraphPattern{get_subgraph(
+      p.raw_graph,
+      transform(n, [](PatternNode const &n) { return n.raw_node; }))};
+}
+
+std::unordered_set<UpwardOpenPatternEdge>
+    get_incoming_edges(UnlabelledGraphPattern const &p, PatternNode const &n) {
+  return transform(get_incoming_edges(p.raw_graph, n.raw_node),
+                   [](UpwardOpenMultiDiEdge const &e) {
+                     return UpwardOpenPatternEdge{e};
+                   });
+}
+
+std::unordered_set<DownwardOpenPatternEdge>
+    get_outgoing_edges(UnlabelledGraphPattern const &p, PatternNode const &n) {
+  return transform(get_outgoing_edges(p.raw_graph, n.raw_node),
+                   [](DownwardOpenMultiDiEdge const &e) {
+                     return DownwardOpenPatternEdge{e};
+                   });
+}
+
+} // namespace FlexFlow
diff --git a/lib/substitutions/src/substitutions/unlabelled/unlabelled_graph_pattern.dtg.cc b/lib/substitutions/src/substitutions/unlabelled/unlabelled_graph_pattern.dtg.cc
new file mode 100644
index 0000000000..019209ee86
--- /dev/null
+++ b/lib/substitutions/src/substitutions/unlabelled/unlabelled_graph_pattern.dtg.cc
@@ -0,0 +1,18 @@
+// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT!
+// If you would like to modify this datatype, instead modify
+// lib/substitutions/include/substitutions/unlabelled/unlabelled_graph_pattern.struct.toml
+/* proj-data
+{
+  "generated_from": "f494ed79eb1ba4010155e456b452157f"
+}
+*/
+
+#include "substitutions/unlabelled/unlabelled_graph_pattern.dtg.h"
+
+#include "utils/graph.h"
+
+namespace FlexFlow {
+UnlabelledGraphPattern::UnlabelledGraphPattern(
+    ::FlexFlow::OpenMultiDiGraphView const &raw_graph)
+    : raw_graph(raw_graph) {}
+} // namespace FlexFlow
diff --git a/lib/substitutions/src/substitutions/unlabelled/upward_open_pattern_edge.cc b/lib/substitutions/src/substitutions/unlabelled/upward_open_pattern_edge.cc
new file mode 100644
index 0000000000..8664f3c66c
--- /dev/null
+++ b/lib/substitutions/src/substitutions/unlabelled/upward_open_pattern_edge.cc
@@ -0,0 +1,9 @@
+#include "substitutions/unlabelled/upward_open_pattern_edge.h"
+
+namespace FlexFlow {
+
+int get_dst_idx(UpwardOpenPatternEdge const &e) {
+  return get_dst_idx(e.raw_edge);
+}
+
+} // namespace FlexFlow
diff --git a/lib/substitutions/src/substitutions/unlabelled/upward_open_pattern_edge.dtg.cc b/lib/substitutions/src/substitutions/unlabelled/upward_open_pattern_edge.dtg.cc
new file mode 100644
index 0000000000..ca8dd6c020
--- /dev/null
+++ b/lib/substitutions/src/substitutions/unlabelled/upward_open_pattern_edge.dtg.cc
@@ -0,0 +1,52 @@
+// THIS FILE WAS AUTO-GENERATED BY proj. DO NOT MODIFY IT!
+// If you would like to modify this datatype, instead modify +// lib/substitutions/include/substitutions/unlabelled/upward_open_pattern_edge.struct.toml +/* proj-data +{ + "generated_from": "a1d4c9d1dd94eb456c5e29d80ad579da" +} +*/ + +#include "substitutions/unlabelled/upward_open_pattern_edge.dtg.h" + +#include "utils/graph.h" + +namespace FlexFlow { +UpwardOpenPatternEdge::UpwardOpenPatternEdge( + ::FlexFlow::UpwardOpenMultiDiEdge const &raw_edge) + : raw_edge(raw_edge) {} +bool UpwardOpenPatternEdge::operator==( + UpwardOpenPatternEdge const &other) const { + return std::tie(this->raw_edge) == std::tie(other.raw_edge); +} +bool UpwardOpenPatternEdge::operator!=( + UpwardOpenPatternEdge const &other) const { + return std::tie(this->raw_edge) != std::tie(other.raw_edge); +} +bool UpwardOpenPatternEdge::operator<( + UpwardOpenPatternEdge const &other) const { + return std::tie(this->raw_edge) < std::tie(other.raw_edge); +} +bool UpwardOpenPatternEdge::operator>( + UpwardOpenPatternEdge const &other) const { + return std::tie(this->raw_edge) > std::tie(other.raw_edge); +} +bool UpwardOpenPatternEdge::operator<=( + UpwardOpenPatternEdge const &other) const { + return std::tie(this->raw_edge) <= std::tie(other.raw_edge); +} +bool UpwardOpenPatternEdge::operator>=( + UpwardOpenPatternEdge const &other) const { + return std::tie(this->raw_edge) >= std::tie(other.raw_edge); +} +} // namespace FlexFlow + +namespace std { +size_t hash::operator()( + FlexFlow::UpwardOpenPatternEdge const &x) const { + size_t result = 0; + result ^= std::hash<::FlexFlow::UpwardOpenMultiDiEdge>{}(x.raw_edge) + + 0x9e3779b9 + (result << 6) + (result >> 2); + return result; +} +} // namespace std diff --git a/lib/substitutions/test/src/test_substitution.cc b/lib/substitutions/test/src/test_substitution.cc index df22d8a620..2d9320275d 100644 --- a/lib/substitutions/test/src/test_substitution.cc +++ b/lib/substitutions/test/src/test_substitution.cc @@ -8,8 +8,10 @@ using namespace FlexFlow; TEST_SUITE(FF_TEST_SUITE) { TEST_CASE("apply_substitution") { OperatorPattern operator_pattern_n0{ - std::vector{OperatorAttributeConstraint{ - ConstraintType::EQUAL, OperatorAttributeKey::OP_TYPE, Op::LINEAR}}}; + std::vector{ + OperatorAttributeConstraint{ConstraintType::EQUAL, + OperatorAttributeKey::OP_TYPE, + OperatorType::LINEAR}}}; ParallelTensorPattern tensor_pattern_e0{ std::vector{ @@ -38,12 +40,13 @@ TEST_SUITE(FF_TEST_SUITE) { GraphPattern input_graph{ig}; OperatorAttrAssignment op_ass_n1{ - {{OperatorAttributeKey::OP_TYPE, AttrConstant{Op::REPARTITION}}, + {{OperatorAttributeKey::OP_TYPE, + AttrConstant{OperatorType::REPARTITION}}, {OperatorAttributeKey::PARALLEL_DIM, AttrConstant{ff_dim_t{0}}}, {OperatorAttributeKey::PARALLEL_DEGREE, AttrConstant{2}}}}; OperatorAttrAssignment op_ass_n2{ - {{OperatorAttributeKey::OP_TYPE, AttrConstant{Op::LINEAR}}, + {{OperatorAttributeKey::OP_TYPE, AttrConstant{OperatorType::LINEAR}}, {OperatorAttributeKey::OUT_CHANNELS, OperatorAttrAccess{n0, OperatorAttributeKey::OUT_CHANNELS}}, {OperatorAttributeKey::USE_BIAS, @@ -56,7 +59,7 @@ TEST_SUITE(FF_TEST_SUITE) { OperatorAttrAccess{n0, OperatorAttributeKey::REGULARIZER}}}}; OperatorAttrAssignment op_ass_n3{ - {{OperatorAttributeKey::OP_TYPE, AttrConstant{Op::REDUCTION}}, + {{OperatorAttributeKey::OP_TYPE, AttrConstant{OperatorType::REDUCTION}}, {OperatorAttributeKey::PARALLEL_DIM, AttrConstant{ff_dim_t{0}}}, {OperatorAttributeKey::PARALLEL_DEGREE, AttrConstant{2}}}}; @@ -100,10 +103,9 @@ TEST_SUITE(FF_TEST_SUITE) { MultiDiEdge e4{n5, p5, n4, 
p4}; pcg.add_edge(e4); - pcg.add_label(e4, - ParallelTensor(ParallelTensorDims({2, 1}), - DataType::FLOAT, - CreateGrad::YES)); + ParallelDim dim = {2, 1, false}; + ParallelTensorDims dims = {FFOrdered{dim}}; + pcg.add_label(e4, ParallelTensor(dims, DataType::FLOAT, CreateGrad::YES)); MatchAdditionalCriterion criterion{ [&](Node const &pattern_node, Node const &graph_node) { diff --git a/lib/utils/include/utils/bidict.h b/lib/utils/include/utils/bidict.h index 0869b0f9e8..6af18c2a4a 100644 --- a/lib/utils/include/utils/bidict.h +++ b/lib/utils/include/utils/bidict.h @@ -1,6 +1,7 @@ #ifndef _FLEXFLOW_UTILS_BIDICT_H #define _FLEXFLOW_UTILS_BIDICT_H +#include "utils/fmt/unordered_map.h" #include #include @@ -55,9 +56,22 @@ struct bidict { bwd_map.insert({lr.second, lr.first}); } + bool operator==(bidict const &other) const { + bool result = this->fwd_map == other.fwd_map; + assert(result == (this->bwd_map == other.bwd_map)); + return result; + } + + bool operator!=(bidict const &other) const { + bool result = this->fwd_map != other.fwd_map; + assert(result == (this->bwd_map != other.bwd_map)); + return result; + } + R const &at_l(L const &l) const { return fwd_map.at(l); } + L const &at_r(R const &r) const { return bwd_map.at(r); } @@ -163,6 +177,22 @@ struct bidict { std::unordered_map bwd_map; }; +template +std::unordered_map format_as(bidict const &b) { + return b; +} + } // namespace FlexFlow +namespace std { + +template +struct hash<::FlexFlow::bidict> { + size_t operator()(::FlexFlow::bidict const &b) const { + return hash>{}(b); + } +}; + +} // namespace std + #endif diff --git a/lib/utils/include/utils/check_fmtable.h b/lib/utils/include/utils/check_fmtable.h new file mode 100644 index 0000000000..3b4e55c459 --- /dev/null +++ b/lib/utils/include/utils/check_fmtable.h @@ -0,0 +1,15 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CHECK_FMTABLE_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CHECK_FMTABLE_H + +#define CHECK_FMTABLE(...) 
\ + static_assert(::FlexFlow::is_fmtable<__VA_ARGS__>::value, \ + #__VA_ARGS__ " must be fmtable"); + +namespace FlexFlow { + +template +using is_fmtable = ::fmt::is_formattable; + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/containers.decl.h b/lib/utils/include/utils/containers.decl.h index 0332a331b2..b02c95bf77 100644 --- a/lib/utils/include/utils/containers.decl.h +++ b/lib/utils/include/utils/containers.decl.h @@ -16,19 +16,6 @@ struct get_element_type; template using get_element_type_t = typename get_element_type::type; -template -std::string join_strings(InputIt first, - InputIt last, - std::string const &delimiter, - F const &f); - -template -std::string - join_strings(InputIt first, InputIt last, std::string const &delimiter); - -template -std::string join_strings(Container const &c, std::string const &delimiter); - template typename Container::const_iterator find(Container const &c, typename Container::value_type const &e); @@ -154,6 +141,9 @@ template > bidict generate_bidict(C const &c, F const &f); +template +std::optional at_idx(std::vector const &v, size_t idx); + template std::function lookup_in(std::unordered_map const &m); @@ -208,6 +198,9 @@ void extend(C &lhs, std::optional const &e); template bool all_of(C const &c, F const &f); +template +std::optional optional_all_of(Container const &, Function const &); + template int count(C const &c, F const &f); @@ -226,11 +219,6 @@ template auto transform(req const &c, F const &f) -> decltype(transform(std::declval(), std::declval())); -template ()(std::declval()))> -std::vector vector_transform(F const &f, std::vector const &v); - template ()(std::declval()))> diff --git a/lib/utils/include/utils/containers.h b/lib/utils/include/utils/containers.h index 1606eb0605..fbaf572df1 100644 --- a/lib/utils/include/utils/containers.h +++ b/lib/utils/include/utils/containers.h @@ -5,6 +5,8 @@ #include "containers.decl.h" #include "required_core.h" #include "type_traits_core.h" +#include "utils/containers/extend_vector.h" +#include "utils/containers/vector_transform.h" #include "utils/exception.h" #include "utils/type_traits.h" #include @@ -21,38 +23,6 @@ namespace FlexFlow { -template -std::string join_strings(InputIt first, - InputIt last, - std::string const &delimiter, - F const &f) { - std::ostringstream oss; - bool first_iter = true; - /* int i = 0; */ - for (; first != last; first++) { - if (!first_iter) { - oss << delimiter; - } - oss << *first; - /* break; */ - first_iter = false; - /* i++; */ - } - return oss.str(); -} - -template -std::string - join_strings(InputIt first, InputIt last, std::string const &delimiter) { - using Ref = typename InputIt::reference; - return join_strings(first, last, delimiter, [](Ref r) { return r; }); -} - -template -std::string join_strings(Container const &c, std::string const &delimiter) { - return join_strings(c.cbegin(), c.cend(), delimiter); -} - template typename Container::const_iterator find(Container const &c, typename Container::value_type const &e) { @@ -346,6 +316,15 @@ bidict generate_bidict(C const &c, F const &f) { return {transformed.cbegin(), transformed.cend()}; } +template +std::optional at_idx(std::vector const &v, size_t idx) { + if (idx >= v.size()) { + return std::nullopt; + } else { + return v.at(idx); + } +} + template std::function lookup_in(std::unordered_map const &m) { return [&m](K const &k) -> V { return m.at(k); }; @@ -441,6 +420,7 @@ T get_first(std::unordered_set const &s) { template void extend(std::vector &lhs, C const &rhs) { + 
extend_vector(lhs, rhs);
-  lhs.reserve(lhs.size() + std::distance(rhs.begin(), rhs.end()));
-  lhs.insert(lhs.end(), rhs.begin(), rhs.end());
 }
@@ -468,6 +448,22 @@ bool all_of(C const &c, F const &f) {
   return true;
 }
 
+template <typename Container, typename Function>
+std::optional<bool> optional_all_of(Container const &container,
+                                    Function const &func) {
+  for (auto const &element : container) {
+    std::optional<bool> condition = func(element);
+    if (!condition.has_value()) {
+      return std::nullopt;
+    }
+
+    if (!condition.value()) {
+      return false;
+    }
+  }
+  return true;
+}
+
 template <typename C, typename F>
 int count(C const &c, F const &f) {
   int result = 0;
@@ -509,11 +505,6 @@ auto transform(req const &c, F const &f)
   return transform(static_cast(c), f);
 }
 
-template
-std::vector vector_transform(F const &f, std::vector const &v) {
-  return transform(v, f);
-}
-
 template
 std::unordered_set transform(std::unordered_set const &v, F const &f) {
   std::unordered_set result;
@@ -693,7 +684,7 @@ std::vector subvec(std::vector const &v,
   auto resolve_loc = [&](int idx) ->
       typename std::vector::iterator::difference_type {
     if (idx < 0) {
-      return v.size() - idx;
+      return v.size() + idx;
     } else {
      return idx;
    }
diff --git a/lib/utils/include/utils/containers/concat_vectors.h b/lib/utils/include/utils/containers/concat_vectors.h
new file mode 100644
index 0000000000..7940a37510
--- /dev/null
+++ b/lib/utils/include/utils/containers/concat_vectors.h
@@ -0,0 +1,18 @@
+#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_CONCAT_VECTORS_H
+#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_CONCAT_VECTORS_H
+
+#include "utils/containers/extend_vector.h"
+
+namespace FlexFlow {
+
+template <typename T>
+std::vector<T> concat_vectors(std::vector<T> const &prefix,
+                              std::vector<T> const &postfix) {
+  std::vector<T> result = prefix;
+  extend_vector(result, postfix);
+  return result;
+}
+
+} // namespace FlexFlow
+
+#endif
diff --git a/lib/utils/include/utils/containers/enumerate_vector.h b/lib/utils/include/utils/containers/enumerate_vector.h
new file mode 100644
index 0000000000..8d36a5fe3b
--- /dev/null
+++ b/lib/utils/include/utils/containers/enumerate_vector.h
@@ -0,0 +1,20 @@
+#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_ENUMERATE_H
+#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_ENUMERATE_H
+
+#include <utility>
+#include <vector>
+
+namespace FlexFlow {
+
+template <typename T>
+std::vector<std::pair<int, T>> enumerate_vector(std::vector<T> const &v) {
+  std::vector<std::pair<int, T>> result;
+  for (int i = 0; i < v.size(); i++) {
+    result.push_back({i, v.at(i)});
+  }
+  return result;
+}
+
+} // namespace FlexFlow
+
+#endif
diff --git a/lib/utils/include/utils/containers/extend_vector.h b/lib/utils/include/utils/containers/extend_vector.h
new file mode 100644
index 0000000000..62ce94e49c
--- /dev/null
+++ b/lib/utils/include/utils/containers/extend_vector.h
@@ -0,0 +1,16 @@
+#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_EXTEND_VECTOR_H
+#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_EXTEND_VECTOR_H
+
+#include <vector>
+
+namespace FlexFlow {
+
+template <typename T, typename C>
+void extend_vector(std::vector<T> &lhs, C const &rhs) {
+  lhs.reserve(lhs.size() + std::distance(rhs.begin(), rhs.end()));
+  lhs.insert(lhs.end(), rhs.begin(), rhs.end());
+}
+
+} // namespace FlexFlow
+
+#endif
diff --git a/lib/utils/include/utils/containers/vector_transform.h b/lib/utils/include/utils/containers/vector_transform.h
new file mode 100644
index 0000000000..6d13584775
--- /dev/null
+++ b/lib/utils/include/utils/containers/vector_transform.h
@@ -0,0 +1,21 @@
+#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_VECTOR_TRANSFORM_H
+#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_VECTOR_TRANSFORM_H
+
+#include <algorithm>
+#include <vector>
+
+namespace FlexFlow {
+
+template <typename F, typename In>
+std::vector<std::invoke_result_t<F, In>>
+    vector_transform(std::vector<In> const &v, F const &f) {
+  using Out = std::invoke_result_t<F, In>;
+
+  std::vector<Out> result;
+  std::transform(v.cbegin(), v.cend(), std::back_inserter(result), f);
+  return result;
+}
+
+} // namespace FlexFlow
+
+#endif
diff --git a/lib/utils/include/utils/containers/zip_vectors.h b/lib/utils/include/utils/containers/zip_vectors.h
new file mode 100644
index 0000000000..d32e539bef
--- /dev/null
+++ b/lib/utils/include/utils/containers/zip_vectors.h
@@ -0,0 +1,21 @@
+#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_ZIP_VECTORS_H
+#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_ZIP_VECTORS_H
+
+#include <algorithm>
+#include <vector>
+
+namespace FlexFlow {
+
+template <typename L, typename R>
+std::vector<std::pair<L, R>> zip(std::vector<L> const &l,
+                                 std::vector<R> const &r) {
+  std::vector<std::pair<L, R>> result;
+  for (int i = 0; i < std::min(l.size(), r.size()); i++) {
+    result.push_back(std::make_pair(l.at(i), r.at(i)));
+  }
+  return result;
+}
+
+} // namespace FlexFlow
+
+#endif
diff --git a/lib/utils/include/utils/exception.decl.h b/lib/utils/include/utils/exception.decl.h
index d27174f474..93c450294b 100644
--- a/lib/utils/include/utils/exception.decl.h
+++ b/lib/utils/include/utils/exception.decl.h
@@ -3,20 +3,30 @@
 
 #include "utils/fmt.decl.h"
 #include <stdexcept>
+#include <tl/expected.hpp>
 
 namespace FlexFlow {
 
 #ifdef FF_REQUIRE_IMPLEMENTED
-#define NOT_IMPLEMENTED() static_assert(false, "Function not yet implemented");
+#define NOT_IMPLEMENTED()                                                      \
+  static_assert(false,                                                         \
+                "Function " __FUNC__ " not yet implemented " __FILE__          \
+                ":" __LINE__);
 #else
-#define NOT_IMPLEMENTED() throw not_implemented();
+#define NOT_IMPLEMENTED()                                                      \
+  throw not_implemented(__PRETTY_FUNCTION__, __FILE__, __LINE__);
 #endif
 
 class not_implemented : public std::logic_error {
 public:
-  not_implemented();
+  not_implemented(std::string const &function_name,
+                  std::string const &file_name,
+                  int line);
 };
 
+template <typename T, typename E>
+T throw_if_unexpected(tl::expected<T, E> const &r);
+
 template <typename... T>
 std::runtime_error mk_runtime_error(fmt::format_string<T...> fmt_str,
                                     T &&...args);
diff --git a/lib/utils/include/utils/exception.h b/lib/utils/include/utils/exception.h
index fd3a0b7ee0..a00d2dba2b 100644
--- a/lib/utils/include/utils/exception.h
+++ b/lib/utils/include/utils/exception.h
@@ -4,9 +4,19 @@
 
 #include "utils/exception.decl.h"
 #include "utils/fmt.h"
 #include <stdexcept>
+#include <tl/expected.hpp>
 
 namespace FlexFlow {
 
+template <typename T, typename E>
+T throw_if_unexpected(tl::expected<T, E> const &r) {
+  if (r.has_value()) {
+    return r.value();
+  } else {
+    throw std::runtime_error(fmt::to_string(r.error()));
+  }
+}
+
 template <typename... T>
 std::runtime_error mk_runtime_error(fmt::format_string<T...> fmt_str,
                                     T &&...args) {
diff --git a/lib/utils/include/utils/fmt.decl.h b/lib/utils/include/utils/fmt.decl.h
index 7adb2052ad..04902c8240 100644
--- a/lib/utils/include/utils/fmt.decl.h
+++ b/lib/utils/include/utils/fmt.decl.h
@@ -2,22 +2,18 @@
 #define _FLEXFLOW_UTILS_INCLUDE_UTILS_FMT_DECL_H
 
 #include "fmt/format.h"
+#include "utils/check_fmtable.h"
 #include <unordered_set>
+#include <type_traits>
+#include <variant>
 #include <vector>
 
-#define CHECK_FMTABLE(...)                                                     \
-  static_assert(::FlexFlow::is_fmtable<__VA_ARGS__>::value,                    \
-                #__VA_ARGS__ " must be fmtable");
-
 #define DELEGATE_OSTREAM(...)
\ template <> \ struct delegate_ostream_operator<__VA_ARGS__> : std::true_type {} namespace FlexFlow { -template -using is_fmtable = ::fmt::is_formattable; - template struct delegate_ostream_operator : std::false_type {}; @@ -30,20 +26,38 @@ typename std::enable_if>::value, namespace fmt { -template -struct formatter<::std::unordered_set> : formatter<::std::string> { +template +struct formatter< + ::std::unordered_set, + Char, + std::enable_if_t>::value>> + : formatter<::std::string, Char> { template auto format(::std::unordered_set const &m, FormatContext &ctx) -> decltype(ctx.out()); }; -template -struct formatter<::std::vector> : formatter<::std::string> { +/* template */ +/* std::string format_as(::std::unordered_set const &); */ + +template +struct formatter< + ::std::vector, + Char, + std::enable_if_t>::value>> + : formatter<::std::string> { template auto format(::std::vector const &m, FormatContext &ctx) -> decltype(ctx.out()); }; +template +struct formatter<::std::variant> : formatter<::std::string> { + template + auto format(::std::variant const &m, FormatContext &ctx) + -> decltype(ctx.out()); +}; + } // namespace fmt #endif diff --git a/lib/utils/include/utils/fmt.h b/lib/utils/include/utils/fmt.h index 9cb56e4e2b..967a41f22b 100644 --- a/lib/utils/include/utils/fmt.h +++ b/lib/utils/include/utils/fmt.h @@ -1,16 +1,69 @@ #ifndef _FLEXFLOW_UTILS_INCLUDE_FMT_H #define _FLEXFLOW_UTILS_INCLUDE_FMT_H -#include "utils/containers.decl.h" +#include "utils/containers.h" #include "utils/fmt.decl.h" #include "utils/test_types.h" #include "utils/type_traits_core.h" +#include +#include #include +#include -#include +namespace fmt { + +template +template +auto formatter< + ::std::unordered_set, + Char, + std::enable_if_t>::value>>:: + format(::std::unordered_set const &m, FormatContext &ctx) + -> decltype(ctx.out()) { + /* CHECK_FMTABLE(T); */ + + /* std::string result = ::FlexFlow::join_strings( */ + /* m.cbegin(), m.cend(), ", ", [](T const &t) { return fmt::to_string(t); + * }); */ + std::string result = ""; + return formatter::format(result, ctx); +} + +template +template +auto formatter< + ::std::vector, + Char, + std::enable_if_t>::value>>:: + format(::std::vector const &m, FormatContext &ctx) + -> decltype(ctx.out()) { + CHECK_FMTABLE(T); + + std::string result = ::FlexFlow::join_strings( + m.cbegin(), m.cend(), ", ", [](T const &t) { return fmt::to_string(t); }); + return formatter::format("[" + result + "]", ctx); +} + +template +template +auto formatter<::std::variant>::format(::std::variant const &m, + FormatContext &ctx) + -> decltype(ctx.out()) { + + std::string result = + std::visit([](auto &&x) { return fmt::to_string(x); }, m); + return formatter::format(result, ctx); +} +} // namespace fmt namespace FlexFlow { +template +struct delegate_ostream_operator> : std::true_type {}; + +template +struct delegate_ostream_operator> : std::true_type {}; + template struct delegate_ostream_operator> : std::true_type {}; @@ -31,32 +84,4 @@ typename std::enable_if>::value, } // namespace FlexFlow -namespace fmt { - -template -template -auto formatter<::std::unordered_set>::format( - ::std::unordered_set const &m, FormatContext &ctx) - -> decltype(ctx.out()) { - CHECK_FMTABLE(T); - - std::string result = join_strings( - m.cbegin(), m.cend(), ", ", [](T const &t) { return fmt::to_string(t); }); - return formatter::format(result, ctx); -} - -template -template -auto formatter<::std::vector>::format(::std::vector const &m, - FormatContext &ctx) - -> decltype(ctx.out()) { - CHECK_FMTABLE(T); - - 
std::string result = join_strings( - m.cbegin(), m.cend(), ", ", [](T const &t) { return fmt::to_string(t); }); - return formatter::format(result, ctx); -} - -} // namespace fmt - #endif diff --git a/lib/utils/include/utils/fmt/expected.h b/lib/utils/include/utils/fmt/expected.h new file mode 100644 index 0000000000..5edd054ebe --- /dev/null +++ b/lib/utils/include/utils/fmt/expected.h @@ -0,0 +1,32 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_FMT_EXPECTED_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_FMT_EXPECTED_H + +#include "fmt/format.h" +#include "utils/check_fmtable.h" +#include +#include + +namespace fmt { + +template +struct formatter<::tl::expected, Char> + /* std::enable_if_t>::value>> */ + : formatter<::std::string> { + template + auto format(::tl::expected const &m, FormatContext &ctx) + -> decltype(ctx.out()) { + + std::string result; + if (m.has_value()) { + result = fmt::format("expected({})", m.value()); + } else { + result = fmt::format("unexpected({})", m.error()); + } + + return formatter::format(result, ctx); + } +}; + +} // namespace fmt + +#endif diff --git a/lib/utils/include/utils/fmt/pair.h b/lib/utils/include/utils/fmt/pair.h new file mode 100644 index 0000000000..eb1147ae3c --- /dev/null +++ b/lib/utils/include/utils/fmt/pair.h @@ -0,0 +1,20 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_FMT_PAIR_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_FMT_PAIR_H + +#include "fmt/format.h" +#include "utils/check_fmtable.h" +#include + +namespace FlexFlow { + +template +std::ostream &operator<<(std::ostream &s, std::pair const &m) { + CHECK_FMTABLE(L); + CHECK_FMTABLE(R); + + return s << fmt::to_string(m); +} + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/fmt/unordered_map.h b/lib/utils/include/utils/fmt/unordered_map.h new file mode 100644 index 0000000000..19701bfb0c --- /dev/null +++ b/lib/utils/include/utils/fmt/unordered_map.h @@ -0,0 +1,45 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_FMT_UNORDERED_MAP_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_FMT_UNORDERED_MAP_H + +#include "fmt/format.h" +#include "utils/check_fmtable.h" +#include "utils/join_strings.h" +#include + +namespace fmt { + +template +struct formatter< + ::std::unordered_map, + Char, + std::enable_if_t>::value>> + : formatter<::std::string> { + template + auto format(::std::unordered_map const &m, FormatContext &ctx) + -> decltype(ctx.out()) { + /* CHECK_FMTABLE(K); */ + /* CHECK_FMTABLE(V); */ + + /* std::string result = ::FlexFlow::join_strings( */ + /* m.cbegin(), m.cend(), ", ", [](std::pair const &p) { return + * fmt::to_string(p); }); */ + std::string result = ""; + return formatter::format(result, ctx); + } +}; + +} // namespace fmt + +namespace FlexFlow { + +template +std::ostream &operator<<(std::ostream &s, std::unordered_map const &m) { + CHECK_FMTABLE(K); + CHECK_FMTABLE(V); + + return s << fmt::to_string(m); +} + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/graph/README.md b/lib/utils/include/utils/graph/README.md index c62b2df294..25b0103f9c 100644 --- a/lib/utils/include/utils/graph/README.md +++ b/lib/utils/include/utils/graph/README.md @@ -4,8 +4,8 @@ FlexFlow's graph library very intentionally attempts to balance performance and ease of use. The graph library aims to have a very simple external interface that is highly decoupled from the underlying representations, so performance and internal implementations can be tuned and modified over time without breaking the code that uses the library. 
-Because FlexFlow's graphs are not on the scale of machine memory or not so large that single traversals takes nontrivial time, the graph library intentially avoids performance opportunites that would expose many of these performance aspects to user code.
-Of course, there are also some optimizations that simply have not been done due to time constraints: for example, algorithms currently are able to be specialized for the underlyign representation being used, but this could be added without modifying the user-side interface.
+Because FlexFlow's graphs are neither at the scale of machine memory nor so large that single traversals take nontrivial time, the graph library intentionally avoids performance opportunities that would expose many of these performance aspects to user code.
+Of course, there are also some optimizations that simply have not been done due to time constraints: for example, algorithms currently are able to be specialized for the underlying representation being used, but this could be added without modifying the user-side interface.
 
 ## Usage
 
@@ -17,7 +17,7 @@ At their core, they are as follows:
 
 - `UndirectedGraph`: at most one edge allowed between every pair of nodes, edges are undirected
 - `DirectedGraph`: at most one edge allowed between every ordered pair of nodes, edges are directed (i.e., have a source node and a destination node)
-`MultiDiGraph`: arbitrary numbers of edges allowed between every pair of nodes, but each must have not only source/destination nodes but also _source/destination` indices_, which serve to disambiguate different edges between the same nodes. There can exist at most one edge for every ordered tuple of source node, destination node, source index, and destination index.
+- `MultiDiGraph`: arbitrary numbers of edges allowed between every pair of nodes, but each must have not only source/destination nodes but also _source/destination indices_, which serve to disambiguate different edges between the same nodes. There can exist at most one edge for every ordered tuple of source node, destination node, source index, and destination index.
 
 Examples of the different graph variants are shown below.
 
@@ -149,6 +149,7 @@ To add an edge between two nodes `Node n1` and `Node n2` to an `UndirectedGraph
 In `UndirectedGraph` the order of the arguments of `add_edge` doesn't matter as edges are undirected, but the order does matter for `DiGraph` and `MultiDiGraph`. `MultiDiGraph::add_edge` takes in two additional arguments of type `NodePort`, specifying the source and destination indices. Similar to `Node`s, `NodePort`s can be generated via `g.add_node_port()`.
+`NodePort`: an opaque object used within `MultiDiGraph` to disambiguate between multiple edges. `MultiDiGraph` will be able to distinguish between two edges that share the same source and destination as long as at least one `NodePort` differs. Within the context of a PCG, `NodePort`s can be thought of as the various inputs and outputs of a single node (see the sketch below).
 
 The last paragraph covered the base API used to write to graphs, but we also want to be able to read from graphs. Reading from graphs is implemented with the `query_nodes` and `query_edges` methods, which can be thought of as executing a database query over the nodes and edges of the target graph, respectively (where queries are restricted to an incredibly simple set of operations).
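[Editor's note: to make the write API above concrete, here is a minimal sketch of building a `MultiDiGraph`. It is a sketch under stated assumptions, not confirmed by this README: the `create<AdjacencyMultiDiGraph>()` call and the `MultiDiEdge{dst, dst_idx, src, src_idx}` field order are inferred from the test code later in this patch series.]

```cpp
#include "utils/graph.h"

using namespace FlexFlow;

void build_small_multidigraph() {
  // Assumption: the backing representation is selected at creation time;
  // the name AdjacencyMultiDiGraph is illustrative.
  MultiDiGraph g = MultiDiGraph::create<AdjacencyMultiDiGraph>();

  // Nodes and node ports are both handed out by the graph itself.
  Node n1 = g.add_node();
  Node n2 = g.add_node();
  NodePort p1 = g.add_node_port();
  NodePort p2 = g.add_node_port();

  // An edge is identified by (src, dst, src_idx, dst_idx), so these two
  // edges coexist between n1 and n2 because their destination ports differ.
  // Assumption: MultiDiEdge aggregates {dst, dst_idx, src, src_idx}.
  g.add_edge(MultiDiEdge{n2, p1, n1, p1});
  g.add_edge(MultiDiEdge{n2, p2, n1, p1});
}
```

Here the `NodePort` values play exactly the role of the disambiguating source/destination indices described above.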
@@ -179,6 +180,16 @@ Generally users will use underlying representations provided by the graph library
 
 [^1]: At some point we will likely add actual runtime checks on this, but for now we rely on the user not to mess up. Currently the implementation will keep going silently until the incorrectness grows so large that something breaks/crashes.
 [^2]: See  if you're not familiar with the term _type coercion_
 
+### Open, Upward, Downward
+
+`Open` is meant in a sense similar to the topological one: that is, a graph that contains some edges where one of the two endpoint nodes is not present in the graph itself.
+We can further specify the "openness" of a **directed** graph by stating whether it is `UpwardOpen` (so some of the incoming edges are open) or `DownwardOpen` (so some of the outgoing edges are open).
+
+![Open graphs inheritance diagram](docs/open.svg)
+
+Arrows with pointed tips indicate inheritance, while arrows with square tips indicate that the pointing class has a `cow_ptr` of the type of the pointed class. (For more info, see [cow_ptr](#cow_ptr-and-interfaces).)
+
+
 ### Labelled Graphs
 
 As nice as all of the above is, graphs without labels are mostly useless--in practice, nodes and edges represent some other system and the properties of that system (or at least a way to map the result of graph algorithms back to the underlying system) are necessary.
@@ -191,6 +202,46 @@ As such, the labelled graph types provide the typical `at` method (as on `std::unordered_map`)
 
 [^3]: `operator[]` currently is not present because all nodes must have labels and we don't require label types to be default constructible, though some simple template programming could probably add `operator[]` support in the cases where the label types _are_ default constructible.
 
+![Labelled Graphs Inheritance Diagram](docs/labelled.svg)
+
+
+
 ## Internals
 
-TODO @lockshaw
+Most of the major graph classes in the library come in sets of four. For a given class `ClassName` we have:
+1. `ClassName`
+2. `ClassNameView`
+3. `IClassName`
+4. `IClassNameView`
+
+General rules which apply to most classes:
+- `ClassName` (virtually) inherits from `ClassNameView`. Similarly, `IClassName` (virtually) inherits from `IClassNameView`.
+- `ClassName` has, as a member variable, a `cow_ptr` of type `IClassName`. The same holds for `ClassNameView`.
+Thus, the bulk of the inheritance that actually extends functionality is present among the `IClassNameView` classes.
+
+
+### cow_ptr and Interfaces
+
+The reason for the existence of the `View` variants has been explained in previous sections.
+The existence of the `I(nterface)` variants stems from C++'s approach to modeling polymorphism.
+
+C++ polymorphism is achieved at runtime through the use of [virtual functions](https://www.learncpp.com/cpp-tutorial/virtual-functions/), which allow for a single function defined on some superclass to also work correctly on its subclasses.
+
+To create objects with polymorphic behaviour, we use the following syntax:
+`BaseClass* obj = new DerivedClass(); // or alternatives such as std::shared_ptr<BaseClass> obj = std::make_shared<DerivedClass>();`
+Any call to `obj`'s member functions is resolved at runtime (dynamic binding), with C++ calling the most derived implementation of the function.
+
+While this pattern works nicely, the way instantiation is done leaves the burden of memory management on the user.
+To address this, graph classes store a `cow_ptr` as a member variable, which points to an instance of its corresponding interface class, as sketched below.
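[Editor's note: as a reading aid for the four-class pattern just described, here is a small, self-contained schematic. It is not the library's actual code: the names are invented for illustration, and `std::shared_ptr` stands in for the library's `cow_ptr`, which additionally performs copy-on-write before mutation.]

```cpp
#include <memory>

// Plays the role of IClassNameView: the read-only virtual interface.
struct IExampleGraphView {
  virtual int num_nodes() const = 0;
  virtual ~IExampleGraphView() = default;
};

// Plays the role of IClassName: extends the view interface with mutation.
struct IExampleGraph : IExampleGraphView {
  virtual void add_node() = 0;
};

// Plays the role of ClassName: a thin wrapper that owns an interface
// object and forwards every call to it.
struct ExampleGraph {
  // The underlying representation is picked once, at creation time;
  // afterwards callers only ever see ExampleGraph.
  template <typename Impl>
  static ExampleGraph create() {
    return ExampleGraph{std::make_shared<Impl>()};
  }

  // Wrapper methods delegate to the interface object, which implements
  // the actual logic.
  int num_nodes() const { return this->ptr->num_nodes(); }
  void add_node() { this->ptr->add_node(); }

private:
  explicit ExampleGraph(std::shared_ptr<IExampleGraph> p)
      : ptr(std::move(p)) {}
  std::shared_ptr<IExampleGraph> ptr; // the real library uses cow_ptr here
};

// One concrete representation, dispatched to virtually at runtime.
struct AdjacencyExampleGraph final : IExampleGraph {
  int num_nodes() const override { return this->count; }
  void add_node() override { this->count++; }
  int count = 0;
};

// usage:
//   ExampleGraph g = ExampleGraph::create<AdjacencyExampleGraph>();
//   g.add_node();
```

The point of this shape is that user code only ever names the wrapper type, while virtual dispatch through the stored interface pointer selects the representation chosen in `create`.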
+
+All member functions present in `ClassName` and `ClassNameView` delegate their calls to their corresponding interface classes (which implement the actual logic), meaning that these classes essentially act as wrappers around their interface counterparts.
+
+To create graphs within the library, we thus use the following syntax:
+`BaseGraph obj = BaseGraph::create<DerivedGraph>();`
+
+This results in an object that, while of type `BaseGraph`, dispatches at runtime to the member function implementations defined in `DerivedGraph`.
+
+### Virtual Inheritance
+Due to the complexity of the graph library, diamond-style inheritance patterns emerge (consider, for example, the `OutputLabelledOpenMultiDiGraphView` class, which inherits from both `NodeLabelledOpenMultiDiGraphView` and `OutputLabelledMultiDiGraphView`, both of which in turn inherit from `NodeLabelledMultiDiGraphView`).
+In the case of a diamond inheritance pattern, C++ will instantiate multiple copies of the base class whenever we instantiate a derived class.
+To address this issue, we employ [Virtual Inheritance](https://en.wikipedia.org/wiki/Virtual_inheritance), which removes the ambiguity associated with the multiple copies.
diff --git a/lib/utils/include/utils/graph/docs/edges.svg b/lib/utils/include/utils/graph/docs/edges.svg
new file mode 100644
index 0000000000..0e01479dc2
--- /dev/null
+++ b/lib/utils/include/utils/graph/docs/edges.svg
@@ -0,0 +1 @@
+[one-line SVG: diagram of the edge classes (DiInput, DiOutput, DirectedEdge, MultiDiInput, MultiDiOutput, MultiDiEdge, InputMultiDiEdge, OutputMultiDiEdge, UndirectedEdge); the SVG markup was lost in extraction, leaving only the embedded class names]
\ No newline at end of file
diff --git a/lib/utils/include/utils/graph/docs/generate_diagram.py b/lib/utils/include/utils/graph/docs/generate_diagram.py
new file mode 100644
index 0000000000..5a4fa2e456
--- /dev/null
+++ b/lib/utils/include/utils/graph/docs/generate_diagram.py
@@ -0,0 +1,135 @@
+'''
+Script to generate a PlantUML graph for the inheritance / dependency hierarchy between the graph classes
+Modify the `headers` and `selected_groups` variables to generate different diagrams
+'''
+
+import re
+from dataclasses import dataclass
+from collections import defaultdict
+from hpp2plantuml import CreatePlantUMLFile
+import os
+
+@dataclass
+class Component:
+    name: str
+    rawstring: str
+
+def clean_puml(puml : bytes) -> str:
+    puml = puml.decode().split('\n')
+    puml = filter(lambda string : all(not string.strip(' \t').startswith(char) for char in '+-#'), puml) #remove info related to class members
+    puml = (line.strip('\t') for line in puml)
+    puml = '\n'.join(puml)
+    puml = puml.replace(" {\n}", '')
+    puml = re.sub(r' <.*?<.*?>>', '', puml) #remove the templates
+    return puml
+
+def remove_enum(puml):
+    return puml.replace('\nenum LRDirection {\nLEFT\nRIGHT\n}\n', '')
+
+
+def remove_namespace(puml):
+    pattern = r'namespace FlexFlow {([^}]*)}'
+    puml = re.sub(pattern, lambda x: x.group(1).strip(), puml, flags=re.DOTALL)
+    puml = puml.replace('FlexFlow.', '')
+    return puml
+
+def get_components(puml):
+    components = []
+    for line in puml.split('\n'):
+        if 'class' in line:
+            name = re.sub(r'\b(?:class|abstract\s+class)\b ', '', line)
+            components.append(Component(name, line))
+    return components
+
+def get_additional_cowptr_connections(components):
+    extra_connections = []
+    names = {c.name for c in components}
+    for name in names:
+        if 'I'+name in names:
+            extra_connections.append(f'I{name} *-- {name}')
+    return extra_connections
+
+def get_connections(puml, includeaggregation=False):
+    pattern = '--' if includeaggregation else '<|--'
+    connections = []
+    for line in puml.split('\n'):
+        if pattern in 
line: + connections.append(line) + return connections + +def filter_by_groups(groups, components): + component_classifications = defaultdict(list) + filtered_components = [] + for component in components: + for packagename in groups: + filtering_func = GROUPS[packagename] + if filtering_func(component.name): + component_classifications[packagename].append(component) + filtered_components.append(component) + break + return component_classifications, filtered_components + + +def filter_connections(connections, components): + filtered_connections = [] + component_names = {comp.name for comp in components} + for conn in connections: + parent, _, child = conn.split(' ') + if parent in component_names and child in component_names: + filtered_connections.append(conn) + return filtered_connections + +if __name__=='__main__': + + # Provide directory path(s) and selected_groups to generate the corresponding puml file + headers = ["../labelled/*.h", "../*.h"] + selected_groups = ('Labelled','Labelled.NodeLabelled','Labelled.OutputLabelled') + output_filename = 'output.puml' + + selected_groups = sorted(selected_groups, reverse=True) #to ensure that classification for subcategories is given precedence + GROUPS = { + 'Graph' : lambda comp : 'Graph' in comp, + 'Edges' : lambda comp : any(comp.endswith(pattern) for pattern in ('Input', 'Output', 'Edge')), + 'Open' : lambda comp : 'Open' in comp and 'Query' not in comp, # doesn't include Upwards or Downwards + 'Open.Upward' : lambda comp : 'Upward' in comp and 'Query' not in comp, + 'Open.Downward' : lambda comp : 'Downward' in comp and 'Query' not in comp, + 'DiGraphs.MultiDiGraphs' : lambda comp : 'MultiDiGraph' in comp, + 'DiGraphs' : lambda comp : 'DiGraph' in comp, + 'Undirected' : lambda comp : 'UndirectedGraph' in comp, + + 'Labelled' : lambda comp : 'Labelled' in comp, + 'Labelled.NodeLabelled' : lambda comp : 'NodeLabelled' in comp, + 'Labelled.OutputLabelled' : lambda comp : 'OutputLabelled' in comp + } + TEMP_FILENAME = 'generate_diagram_temp.puml' + + CreatePlantUMLFile(headers, output_file = TEMP_FILENAME) + + with open(TEMP_FILENAME, 'rb') as tempfile: + puml : bytes = tempfile.read() + os.remove(TEMP_FILENAME) + + puml = clean_puml(puml) + puml = remove_enum(puml) + puml = remove_namespace(puml) + + components = get_components(puml) + connections = get_connections(puml) + cowptr_connections = get_additional_cowptr_connections(components) + connections += cowptr_connections + + packageclassification, components = filter_by_groups(selected_groups, components) + connections = filter_connections(connections, components) + + final_puml = "" + final_puml += "@startuml\nleft to right direction\n\n" + + for packagename, components in packageclassification.items(): + component_string = '\n'.join(f'\t{c.rawstring}' for c in components) + final_puml+=f'package {packagename} {{ \n{component_string} \n}}\n\n' + + final_puml+='\n'.join(connections) + final_puml+="\n\n@enduml" + print(final_puml) + with open(output_filename, 'w') as file: + file.write(final_puml) diff --git a/lib/utils/include/utils/graph/docs/labelled.svg b/lib/utils/include/utils/graph/docs/labelled.svg new file mode 100644 index 0000000000..a439c85c04 --- /dev/null +++ b/lib/utils/include/utils/graph/docs/labelled.svg @@ -0,0 +1 @@ 
+[one-line SVG: inheritance diagram of the labelled graph classes (the NodeLabelled*/OutputLabelled* graphs, views, subgraph views, and their I* interfaces); the SVG markup was lost in extraction, leaving only the embedded class names]
\ No newline at end of file
diff --git a/lib/utils/include/utils/graph/docs/open.svg b/lib/utils/include/utils/graph/docs/open.svg
new file mode 100644
index 0000000000..87766063f4
--- /dev/null
+++ b/lib/utils/include/utils/graph/docs/open.svg
@@ -0,0 +1 @@
+[one-line SVG: inheritance diagram of the open graph classes (OpenMultiDiGraph, UpwardOpenMultiDiGraph, DownwardOpenMultiDiGraph, their views, subgraph views, and I* interfaces); the SVG markup was lost in extraction, leaving only the embedded class names]
\ No newline at end of file
diff --git a/lib/utils/include/utils/graph/docs/undirected.svg b/lib/utils/include/utils/graph/docs/undirected.svg
new file mode 100644
index 0000000000..f04d893a45
--- /dev/null
+++ b/lib/utils/include/utils/graph/docs/undirected.svg
@@ -0,0 +1 @@
+[one-line SVG: inheritance diagram of the undirected graph classes (UndirectedGraph, UndirectedGraphView, HashmapUndirectedGraph, JoinedUndirectedGraphView, ViewDiGraphAsUndirectedGraph, ViewUndirectedGraphAsDiGraph, and I* interfaces); the SVG markup was lost in extraction, leaving only the embedded class names]
\ No newline at end of file
diff --git a/lib/utils/include/utils/graph/multidiedge.h b/lib/utils/include/utils/graph/multidiedge.h
index 808981afa1..d7c2c1590b 100644
--- a/lib/utils/include/utils/graph/multidiedge.h
+++ b/lib/utils/include/utils/graph/multidiedge.h
@@ -17,6 +17,10 @@ FF_VISIT_FMTABLE(MultiDiInput);
 
 struct MultiDiOutput : DiOutput {
   NodePort src_idx;
+
+  bool operator>(MultiDiOutput const &) const;
+  bool operator>=(MultiDiOutput const &) const;
+  bool operator<=(MultiDiOutput const &) const;
 };
 FF_VISITABLE_STRUCT(MultiDiOutput, src, src_idx);
 FF_VISIT_FMTABLE(MultiDiOutput);
diff --git a/lib/utils/include/utils/integer_conversions.h b/lib/utils/include/utils/integer_conversions.h
new file mode 100644
index 0000000000..154aaa2a67
--- /dev/null
+++ b/lib/utils/include/utils/integer_conversions.h
@@ -0,0 +1,13 @@
+#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_INTEGER_CONVERSIONS_H
+#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_INTEGER_CONVERSIONS_H
+
+#include <cstddef>
+
+namespace FlexFlow {
+
+size_t size_t_from_int(int);
+int int_from_size_t(size_t);
+
+} // namespace FlexFlow
+
+#endif
diff --git a/lib/utils/include/utils/join_strings.h b/lib/utils/include/utils/join_strings.h
new file mode 100644
index 0000000000..db82004317
--- /dev/null
+++ b/lib/utils/include/utils/join_strings.h
@@ -0,0 +1,43 @@
+#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_JOIN_STRINGS_H
+#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_JOIN_STRINGS_H
+
+#include <sstream>
+#include <string>
+
+namespace FlexFlow {
+
+template <typename InputIt, typename F>
+std::string 
join_strings(InputIt first, + InputIt last, + std::string const &delimiter, + F const &f) { + std::ostringstream oss; + bool first_iter = true; + /* int i = 0; */ + for (; first != last; first++) { + if (!first_iter) { + oss << delimiter; + } + oss << *first; + /* break; */ + first_iter = false; + /* i++; */ + } + return oss.str(); +} + +template +std::string + join_strings(InputIt first, InputIt last, std::string const &delimiter) { + using Ref = typename InputIt::reference; + return join_strings(first, last, delimiter, [](Ref r) { return r; }); +} + +template +std::string join_strings(Container const &c, std::string const &delimiter) { + return join_strings(c.cbegin(), c.cend(), delimiter); +} + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/json.h b/lib/utils/include/utils/json.h index 010943a9f9..f56917e329 100644 --- a/lib/utils/include/utils/json.h +++ b/lib/utils/include/utils/json.h @@ -143,40 +143,53 @@ struct VariantToJsonFunctor { void operator()(T const &t) { static_assert(is_jsonable::value, ""); - j["type"] = get_name(t); - j["value"] = t; + j = t; } }; template void variant_to_json(json &j, std::variant const &v) { - visit(::FlexFlow::VariantToJsonFunctor{j}, v.value); + json jval; + visit(::FlexFlow::VariantToJsonFunctor{jval}, v); + j["value"] = jval; + j["index"] = v.index(); } -template -struct VariantFromJsonFunctor { - VariantFromJsonFunctor(json const &j) : j(j) {} +template +std::optional variant_from_json_impl(json const &j) { + using Type = typename std::variant_alternative::type; - json const &j; - - template - std::optional - operator()(std::integral_constant const &) const { - using Type = typename std::variant_alternative::type; + if (j.at("index").get() == Idx) { + return j.at("value").get(); + } + return std::nullopt; +} - if (visit_struct::get_name()) { - return j.at("value").get(); +template +std::optional variant_from_json_impl(json const &j, + std::index_sequence) { + // If there were no errors when parsing, all but one element of the array + // will be nullopt. This is because each call to variant_from_json_impl will + // have a unique index and exactly one of them will match the index in the + // json object. 
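+  // (The pack expansion below instantiates variant_from_json_impl<Variant, Is>
+  // once per alternative, so the whole lookup is linear in the number of
+  // variant alternatives.)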
+ std::array, sizeof...(Is)> results{ + variant_from_json_impl(j)...}; + for (std::optional &maybe : results) { + if (maybe) { + return maybe.value(); } } -}; + return std::nullopt; +} template std::variant variant_from_json(json const &j) { - ::FlexFlow::VariantFromJsonFunctor> func(j); - auto result = seq_map(func, seq_enumerate_args_t{}); + using Variant = std::variant; + std::optional result = variant_from_json_impl( + j, std::make_index_sequence()); if (!result.has_value()) { throw ::FlexFlow::mk_runtime_error("Invalid type {} found in json", - j.at("type").get()); + j.at("index").get()); } return result.value(); } diff --git a/lib/utils/include/utils/optional.h b/lib/utils/include/utils/optional.h index 71b6d9d975..2594a96c8e 100644 --- a/lib/utils/include/utils/optional.h +++ b/lib/utils/include/utils/optional.h @@ -2,6 +2,7 @@ #define _FLEXFLOW_UTILS_INCLUDE_UTILS_OPTIONAL_H #include "fmt.h" +#include "rapidcheck.h" #include "utils/exception.h" #include "utils/optional.decl" @@ -23,12 +24,28 @@ T const &assert_unwrap(std::optional const &o) { return o.value(); } +template +std::optional> transform(std::optional const &o, + F &&f) { + using Return = std::invoke_result_t; + if (o.has_value()) { + Return r = f(o.value()); + return std::optional{r}; + } else { + return std::nullopt; + } +} + } // namespace FlexFlow namespace fmt { -template -struct formatter<::std::optional> : formatter { +template +struct formatter< + ::std::optional, + Char, + std::enable_if_t>::value>> + : formatter { template auto format(::std::optional const &q, FormatContext &ctx) -> decltype(ctx.out()) { @@ -44,4 +61,18 @@ struct formatter<::std::optional> : formatter { } // namespace fmt +namespace rc { + +template +struct Arbitrary> { + static Gen> arbitrary() { + return gen::map( + gen::maybe(std::move(gen::arbitrary())), [](Maybe &&m) { + return m ? std::optional(std::move(*m)) : std::optional(); + }); + } +}; + +} // namespace rc + #endif diff --git a/lib/utils/include/utils/overload.h b/lib/utils/include/utils/overload.h new file mode 100644 index 0000000000..7e0431eba5 --- /dev/null +++ b/lib/utils/include/utils/overload.h @@ -0,0 +1,15 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_OVERLOAD_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_OVERLOAD_H + +namespace FlexFlow { + +template +struct overload : Ts... { + using Ts::operator()...; +}; +template +overload(Ts...) 
-> overload; + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/stack_string.h b/lib/utils/include/utils/stack_string.h index 71b092d2c1..0074877768 100644 --- a/lib/utils/include/utils/stack_string.h +++ b/lib/utils/include/utils/stack_string.h @@ -4,6 +4,7 @@ #include "fmt/core.h" #include "stack_vector.h" #include "utils/fmt.h" +#include "utils/json.h" #include "utils/type_traits.h" #include #include @@ -64,6 +65,19 @@ struct stack_basic_string { template using stack_string = stack_basic_string; +template +void to_json(json &j, stack_string const &v) { + std::string as_string = v; + j = as_string; +} + +template +void from_json(json const &j, stack_string &v) { + std::string as_string; + j.get_to(as_string); + v = stack_string{as_string}; +} + } // namespace FlexFlow namespace std { diff --git a/lib/utils/include/utils/stack_vector.h b/lib/utils/include/utils/stack_vector.h index fe665ed749..d47886b055 100644 --- a/lib/utils/include/utils/stack_vector.h +++ b/lib/utils/include/utils/stack_vector.h @@ -3,7 +3,9 @@ #include "containers.h" #include "hash-utils.h" +#include "rapidcheck.h" #include "utils/fmt.h" +#include "utils/json.h" #include "utils/test_types.h" #include "utils/type_traits.h" #include @@ -38,9 +40,10 @@ struct stack_vector { template stack_vector(Iterator start, Iterator end) { - assert(end - start >= 0); - assert(end - start <= MAXSIZE); - for (; start < end; start++) { + size_t elements_added = 0; + for (; start != end; start++) { + elements_added++; + assert(elements_added <= MAXSIZE); this->push_back(static_cast(*start)); } } @@ -310,8 +313,24 @@ struct stack_vector { implies, is_lt_comparable>::value, ""); }; +template +struct delegate_ostream_operator> : std::true_type {}; + // CHECK_FMTABLE(stack_vector); +template +void to_json(json &j, stack_vector const &v) { + std::vector as_vec(v.begin(), v.end()); + j = as_vec; +} + +template +void from_json(json const &j, stack_vector &v) { + std::vector as_vec; + j.get_to(as_vec); + v = stack_vector{as_vec.begin(), as_vec.end()}; +} + } // namespace FlexFlow namespace std { @@ -329,4 +348,18 @@ struct hash<::FlexFlow::stack_vector> { } // namespace std +namespace rc { + +template +struct Arbitrary<::FlexFlow::stack_vector> { + static Gen<::FlexFlow::stack_vector> arbitrary() { + return gen::mapcat(gen::inRange(0, MAXSIZE), [](size_t size) { + return gen::container<::FlexFlow::stack_vector>( + size, gen::arbitrary()); + }); + } +}; + +} // namespace rc + #endif diff --git a/lib/utils/include/utils/type_index.h b/lib/utils/include/utils/type_index.h index 49e893faa0..77a377a48d 100644 --- a/lib/utils/include/utils/type_index.h +++ b/lib/utils/include/utils/type_index.h @@ -3,17 +3,18 @@ #include "fmt.h" #include +#include namespace FlexFlow { template -std::type_index type_index() { +std::type_index get_type_index_for_type() { return std::type_index(typeid(T)); } template bool matches(std::type_index idx) { - return idx == type_index(); + return idx == get_type_index_for_type(); } } // namespace FlexFlow diff --git a/lib/utils/include/utils/variant.h b/lib/utils/include/utils/variant.h index 272caaffde..bb2286a9cd 100644 --- a/lib/utils/include/utils/variant.h +++ b/lib/utils/include/utils/variant.h @@ -1,6 +1,7 @@ #ifndef _FLEXFLOW_UTILS_VARIANT_H #define _FLEXFLOW_UTILS_VARIANT_H +#include "rapidcheck.h" #include "utils/type_traits.h" #include #include @@ -212,4 +213,15 @@ std::optional cast(VariantIn const &v) { } // namespace FlexFlow +namespace rc { + +template +struct Arbitrary> { + static Gen> 
arbitrary() { + return gen::oneOf(gen::cast>(gen::arbitrary())...); + } +}; + +} // namespace rc + #endif diff --git a/lib/utils/src/exception.cc b/lib/utils/src/exception.cc index 7dccdc3074..5f78491ef2 100644 --- a/lib/utils/src/exception.cc +++ b/lib/utils/src/exception.cc @@ -2,7 +2,12 @@ namespace FlexFlow { -not_implemented::not_implemented() - : std::logic_error("Function not yet implemented"){}; +not_implemented::not_implemented(std::string const &function_name, + std::string const &file_name, + int line) + : std::logic_error(fmt::format("Function '{}' not yet implemented at {}:{}", + function_name, + file_name, + line)){}; } diff --git a/lib/utils/src/utils/graph/multidiedge.cc b/lib/utils/src/utils/graph/multidiedge.cc new file mode 100644 index 0000000000..cd3655c8e6 --- /dev/null +++ b/lib/utils/src/utils/graph/multidiedge.cc @@ -0,0 +1,17 @@ +#include "utils/graph/multidiedge.h" + +namespace FlexFlow { + +bool MultiDiOutput::operator>(MultiDiOutput const &other) const { + return !(*this < other) && !(*this == other); +} + +bool MultiDiOutput::operator>=(MultiDiOutput const &other) const { + return !(*this < other); +} + +bool MultiDiOutput::operator<=(MultiDiOutput const &other) const { + return (*this < other) || (*this == other); +} + +} // namespace FlexFlow diff --git a/lib/utils/src/utils/integer_conversions.cc b/lib/utils/src/utils/integer_conversions.cc new file mode 100644 index 0000000000..34ee3109bf --- /dev/null +++ b/lib/utils/src/utils/integer_conversions.cc @@ -0,0 +1,17 @@ +#include "utils/integer_conversions.h" +#include +#include + +namespace FlexFlow { + +size_t size_t_from_int(int x) { + assert(x >= 0); + return static_cast(x); +} + +int int_from_size_t(size_t x) { + assert(x < std::numeric_limits::max()); + return static_cast(x); +} + +} // namespace FlexFlow diff --git a/lib/utils/src/utils/overload.cc b/lib/utils/src/utils/overload.cc new file mode 100644 index 0000000000..55bfbdc08d --- /dev/null +++ b/lib/utils/src/utils/overload.cc @@ -0,0 +1 @@ +#include "utils/overload.h" diff --git a/lib/utils/test/common/include/test/utils/doctest.h b/lib/utils/test/common/include/test/utils/doctest.h index 47c7ebde6d..ff7683dbcd 100644 --- a/lib/utils/test/common/include/test/utils/doctest.h +++ b/lib/utils/test/common/include/test/utils/doctest.h @@ -1,7 +1,9 @@ -#define DOCTEST_CONFIG_IMPLEMENT_WITH_MAIN #include "doctest/doctest.h" #include "utils/containers.decl.h" +#include "utils/fmt/expected.h" +#include #include +#include #include #include #include @@ -64,10 +66,11 @@ namespace doctest { // } // }; -// template -// struct StringMaker> { -// static String convert(std::vector const &vec) { -// return doctest_print_container(vec, "[ ", ", ", " ]").c_str(); -// } -// }; +template +struct StringMaker> { + static String convert(tl::expected const &m) { + return toString(fmt::to_string(m)); + } +}; + } // namespace doctest diff --git a/lib/utils/test/common/src/main.cc b/lib/utils/test/common/src/main.cc new file mode 100644 index 0000000000..9522fa7fdb --- /dev/null +++ b/lib/utils/test/common/src/main.cc @@ -0,0 +1,2 @@ +#define DOCTEST_CONFIG_IMPLEMENT_WITH_MAIN +#include "doctest/doctest.h" diff --git a/lib/utils/test/src/test_optional.cc b/lib/utils/test/src/test_optional.cc new file mode 100644 index 0000000000..8ef9e18f18 --- /dev/null +++ b/lib/utils/test/src/test_optional.cc @@ -0,0 +1,10 @@ +#include "test/utils/doctest.h" +#include "utils/optional.h" +#include + +TEST_SUITE(FF_TEST_SUITE) { + + TEST_CASE_TEMPLATE("RC arbitrary", T, int, double, char) { 
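+    // Smoke test: only verifies that rc::Arbitrary<std::optional<T>> can
+    // generate values; the property body is intentionally empty.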
+ CHECK(rc::check("generate", [](std::optional o) {})); + } +} diff --git a/lib/utils/test/src/test_stack_vector.cc b/lib/utils/test/src/test_stack_vector.cc index 6c0ecf36f3..141cd30e95 100644 --- a/lib/utils/test/src/test_stack_vector.cc +++ b/lib/utils/test/src/test_stack_vector.cc @@ -1,6 +1,7 @@ #include "test/utils/doctest.h" #include "utils/stack_vector.h" #include +#include using namespace FlexFlow; @@ -76,4 +77,11 @@ TEST_SUITE(FF_TEST_SUITE) { vector.push_back(20); CHECK(vector.back() == 20); } + + TEST_CASE_TEMPLATE("RC arbitrary", T, int, double, char) { + constexpr std::size_t MAXSIZE = 10; + CHECK(rc::check("within bound", [](stack_vector v) { + return v.size() <= MAXSIZE; + })); + } } diff --git a/lib/utils/test/src/test_variant.cc b/lib/utils/test/src/test_variant.cc index 0fef782c0e..7cffe9fbe4 100644 --- a/lib/utils/test/src/test_variant.cc +++ b/lib/utils/test/src/test_variant.cc @@ -1,5 +1,6 @@ #include "test/utils/doctest.h" #include "utils/variant.h" +#include TEST_SUITE(FF_TEST_SUITE) { TEST_CASE("widen and narrow functions") { @@ -69,4 +70,10 @@ TEST_SUITE(FF_TEST_SUITE) { // Check the result CHECK(get(wider_variant) == 42); } + + TEST_CASE("RC arbitrary") { + CHECK(rc::check("valid type", [](std::variant v) { + return std::holds_alternative(v) || std::holds_alternative(v); + })); + } } From 3b74d9961d6a0d4e9e772eecf6f0c7db54cfbfc7 Mon Sep 17 00:00:00 2001 From: Qinghan Chen Date: Wed, 5 Jun 2024 00:10:26 -0400 Subject: [PATCH 5/9] small fix --- .proj.toml | 1 - 1 file changed, 1 deletion(-) diff --git a/.proj.toml b/.proj.toml index 2ca91fabb7..8916c5dc09 100644 --- a/.proj.toml +++ b/.proj.toml @@ -6,7 +6,6 @@ fix_compile_commands = false build_targets = [ "kernels", -<<<<<<< HEAD ] test_targets = [ # "utils-tests", From 09b39a816503a4720f74fda4fd83ccfca84ae479 Mon Sep 17 00:00:00 2001 From: Qinghan Chen Date: Wed, 5 Jun 2024 20:06:47 -0400 Subject: [PATCH 6/9] ckpt, currently cmake succeeds but need to figure system hardware detection problem --- .flake/pkgs/legion.nix | 34 ++++++++++++++++++++++++++++------ 1 file changed, 28 insertions(+), 6 deletions(-) diff --git a/.flake/pkgs/legion.nix b/.flake/pkgs/legion.nix index 814ef85e00..5b1713e672 100644 --- a/.flake/pkgs/legion.nix +++ b/.flake/pkgs/legion.nix @@ -2,9 +2,12 @@ , stdenv , fetchFromGitLab , cmake +, clang , python3 -, cudaPackages ? { } -, cudaCapabilities ? [ "60" "70" "80" "86" ] +# , cudaPackages ? { } +# , cudaCapabilities ? [ "60" "70" "80" "86" ] +, rocm +, rocmPackages , maxDim ? 5 }: @@ -13,7 +16,7 @@ let cmakeFlag = x: if x then "1" else "0"; - inherit (cudaPackages) cudatoolkit; + # inherit (cudaPackages) cudatoolkit; in stdenv.mkDerivation rec { @@ -35,14 +38,33 @@ stdenv.mkDerivation rec { cmakeFlags = [ "-DLegion_USE_Python=1" "-DLegion_BUILD_BINDINGS=1" - "-DLegion_USE_CUDA=1" - "-DLegion_CUDA_ARCH=${lib.concatStringsSep "," cudaCapabilities}" + "-DLegion_USE_HIP=1" + "-DHIP_THRUST_ROOT_DIR=${rocm}/hip-thrust" + + "-DLegion_USE_CUDA=0" + # "-DLegion_CUDA_ARCH=${lib.concatStringsSep "," cudaCapabilities}" "-DLegion_MAX_DIM=${toString maxDim}" + ]; + preConfigure = '' + echo "configuring Legion" + echo "including rocm path" + export ROCM_PATH=${rocm} + export HIP_PATH=${rocm}/hip + export HIP_THRUST_ROOT_DIR=${rocm}/hip-thrust + echo "rocm path is $ROCM_PATH" + echo "hip path is $HIP_PATH" + echo "hip thrust path is $HIP_THRUST_ROOT_DIR" + ''; + + preUnpack = '' + echo "Running pre-unpack steps..." 
+''; + buildInputs = [ python3 - cudatoolkit + rocm ]; meta = with lib; { From e102e3ca8617ecc6f48881dbf99f9f66f04d5b0b Mon Sep 17 00:00:00 2001 From: Qinghan Chen Date: Fri, 21 Jun 2024 00:02:35 -0400 Subject: [PATCH 7/9] add rocm support and config option for legion build --- .flake/pkgs/legion.nix | 74 ++++++++++++++++++++------------- .flake/pkgs/rocthrust.nix | 87 +++++++++++++++++++++++++++++++++++++++ 2 files changed, 132 insertions(+), 29 deletions(-) create mode 100644 .flake/pkgs/rocthrust.nix diff --git a/.flake/pkgs/legion.nix b/.flake/pkgs/legion.nix index 5b1713e672..d51e4c6e95 100644 --- a/.flake/pkgs/legion.nix +++ b/.flake/pkgs/legion.nix @@ -1,22 +1,37 @@ { lib -, stdenv , fetchFromGitLab , cmake -, clang +, config , python3 -# , cudaPackages ? { } -# , cudaCapabilities ? [ "60" "70" "80" "86" ] -, rocm -, rocmPackages +, cudaPackages ? { } +, cudaCapabilities ? [ "60" "70" "80" "86" ] +, rocmPackages ? { } , maxDim ? 5 +, useCuda ? config.cudaSupport +, useRocm ? config.rocmSupport +, stdenv ? if useCuda then cudaPackages.backendStdenv else rocmPackages.llvm.rocmClangStdenv }: # from https://codeberg.org/Uli/nix-things/src/commit/776519e382c81b136c1d0b10d8c7b52b4acb9192/overlays/cq/python/libclang-python.nix let cmakeFlag = x: if x then "1" else "0"; + inherit (cudaPackages) cudatoolkit; + inherit (lib) + cmakeBool + cmakeFeature + optionals + ; - # inherit (cudaPackages) cudatoolkit; + cudaBuildInputs = with cudaPackages; [ + cudatoolkit + ]; + rocmBuildInputs = with rocmPackages; [ + clr + rocthrust + rocprim + llvm.clang + ]; in stdenv.mkDerivation rec { @@ -38,38 +53,39 @@ stdenv.mkDerivation rec { cmakeFlags = [ "-DLegion_USE_Python=1" "-DLegion_BUILD_BINDINGS=1" - "-DLegion_USE_HIP=1" - "-DHIP_THRUST_ROOT_DIR=${rocm}/hip-thrust" - - "-DLegion_USE_CUDA=0" - # "-DLegion_CUDA_ARCH=${lib.concatStringsSep "," cudaCapabilities}" "-DLegion_MAX_DIM=${toString maxDim}" + ] + ++ optionals useRocm [ + # TODO: this is the legacy way of setting hip compiler. Once we update nixpkgs version we should use the new way. It will be a quick fix + (cmakeFeature "Legion_USE_HIP" "1") + (cmakeFeature "HIP_ARCHITECTURES" (builtins.concatStringsSep ";" rocmPackages.clr.gpuTargets)) + (cmakeFeature "HIP_COMPILER" "${rocmPackages.llvm.clang}/bin/clang") + (cmakeFeature "HIP_RUNTIME" "rocclr") + (cmakeFeature "HIP_PLATFORM" "amd") + (cmakeFeature "HIP_PATH" "${rocmPackages.clr}/hip") + (cmakeFeature "HIP_ROOT_DIR" "${rocmPackages.clr}") + (cmakeFeature "HIP_THRUST_ROOT_DIR" "${rocmPackages.rocthrust}") + (cmakeFeature "ROCM_PATH" "${rocmPackages.clr}") - ]; + (cmakeFeature "CMAKE_CXX_COMPILER" "hipcc") + (cmakeFeature "CMAKE_C_COMPILER" "hipcc") + ] + ++ optionals useCuda [ + (cmakeFeature "Legion_USE_CUDA" "1") + (cmakeFeature "CMAKE_CUDA_ARCH" (builtins.concatStringsSep ";" cudaCapabilities)) + ]; - preConfigure = '' - echo "configuring Legion" - echo "including rocm path" - export ROCM_PATH=${rocm} - export HIP_PATH=${rocm}/hip - export HIP_THRUST_ROOT_DIR=${rocm}/hip-thrust - echo "rocm path is $ROCM_PATH" - echo "hip path is $HIP_PATH" - echo "hip thrust path is $HIP_THRUST_ROOT_DIR" - ''; - preUnpack = '' - echo "Running pre-unpack steps..." 
-''; buildInputs = [ python3 - rocm - ]; + ] + ++ optionals useCuda cudaBuildInputs + ++ optionals useRocm rocmBuildInputs; meta = with lib; { description = "Legion is a parallel programming model for distributed, heterogeneous machines"; homepage = "https://github.com/StanfordLegion/legion"; license = licenses.asl20; }; -} +} \ No newline at end of file diff --git a/.flake/pkgs/rocthrust.nix b/.flake/pkgs/rocthrust.nix new file mode 100644 index 0000000000..8b0b7c37f0 --- /dev/null +++ b/.flake/pkgs/rocthrust.nix @@ -0,0 +1,87 @@ +{ lib +, stdenv +, fetchFromGitHub +, rocmUpdateScript +, cmake +, rocm-cmake +, rocprim +, clr +, gtest +, buildTests ? false +, buildBenchmarks ? false +, gpuTargets ? [ ] +}: + +stdenv.mkDerivation (finalAttrs: { + pname = "rocthrust"; + version = "6.0.2"; + + outputs = [ + "out" + ] ++ lib.optionals buildTests [ + "test" + ] ++ lib.optionals buildBenchmarks [ + "benchmark" + ]; + + src = fetchFromGitHub { + owner = "ROCm"; + repo = "rocThrust"; + rev = "rocm-${finalAttrs.version}"; + hash = "sha256-Zk7FxcedaDUbx9RCX8aWN0xZO/B5cOs/l5MDqZKQpJo="; + }; + + nativeBuildInputs = [ + cmake + rocm-cmake + rocprim + clr + ]; + + buildInputs = lib.optionals buildTests [ + gtest + ]; + + cmakeFlags = [ + "-DCMAKE_CXX_COMPILER=hipcc" + "-DHIP_ROOT_DIR=${clr}" + # Manually define CMAKE_INSTALL_ + # See: https://github.com/NixOS/nixpkgs/pull/197838 + "-DCMAKE_INSTALL_BINDIR=bin" + "-DCMAKE_INSTALL_LIBDIR=lib" + "-DCMAKE_INSTALL_INCLUDEDIR=include" + ] ++ lib.optionals (gpuTargets != [ ]) [ + "-DAMDGPU_TARGETS=${lib.concatStringsSep ";" gpuTargets}" + ] ++ lib.optionals buildTests [ + "-DBUILD_TEST=ON" + ] ++ lib.optionals buildBenchmarks [ + "-DBUILD_BENCHMARKS=ON" + ] ++ lib.optionals (buildTests || buildBenchmarks) [ + "-DCMAKE_CXX_FLAGS=-Wno-deprecated-builtins" # Too much spam + ]; + + postInstall = lib.optionalString buildTests '' + mkdir -p $test/bin + mv $out/bin/{test_*,*.hip} $test/bin + '' + lib.optionalString buildBenchmarks '' + mkdir -p $benchmark/bin + mv $out/bin/benchmark_* $benchmark/bin + '' + lib.optionalString (buildTests || buildBenchmarks) '' + rm -rf $out/bin + ''; + + passthru.updateScript = rocmUpdateScript { + name = finalAttrs.pname; + owner = finalAttrs.src.owner; + repo = finalAttrs.src.repo; + }; + + meta = with lib; { + description = "ROCm parallel algorithm library"; + homepage = "https://github.com/ROCm/rocThrust"; + license = with licenses; [ asl20 ]; + maintainers = teams.rocm.members; + platforms = platforms.linux; + broken = versions.minor finalAttrs.version != versions.minor stdenv.cc.version || versionAtLeast finalAttrs.version "7.0.0"; + }; +}) \ No newline at end of file From 1d2941e02a5f64fabb20ac4e237199c40907a6dc Mon Sep 17 00:00:00 2001 From: Qinghan Chen Date: Fri, 21 Jun 2024 00:03:36 -0400 Subject: [PATCH 8/9] rm redundant pkg --- .flake/pkgs/rocthrust.nix | 87 --------------------------------------- 1 file changed, 87 deletions(-) delete mode 100644 .flake/pkgs/rocthrust.nix diff --git a/.flake/pkgs/rocthrust.nix b/.flake/pkgs/rocthrust.nix deleted file mode 100644 index 8b0b7c37f0..0000000000 --- a/.flake/pkgs/rocthrust.nix +++ /dev/null @@ -1,87 +0,0 @@ -{ lib -, stdenv -, fetchFromGitHub -, rocmUpdateScript -, cmake -, rocm-cmake -, rocprim -, clr -, gtest -, buildTests ? false -, buildBenchmarks ? false -, gpuTargets ? 
[ ] -}: - -stdenv.mkDerivation (finalAttrs: { - pname = "rocthrust"; - version = "6.0.2"; - - outputs = [ - "out" - ] ++ lib.optionals buildTests [ - "test" - ] ++ lib.optionals buildBenchmarks [ - "benchmark" - ]; - - src = fetchFromGitHub { - owner = "ROCm"; - repo = "rocThrust"; - rev = "rocm-${finalAttrs.version}"; - hash = "sha256-Zk7FxcedaDUbx9RCX8aWN0xZO/B5cOs/l5MDqZKQpJo="; - }; - - nativeBuildInputs = [ - cmake - rocm-cmake - rocprim - clr - ]; - - buildInputs = lib.optionals buildTests [ - gtest - ]; - - cmakeFlags = [ - "-DCMAKE_CXX_COMPILER=hipcc" - "-DHIP_ROOT_DIR=${clr}" - # Manually define CMAKE_INSTALL_ - # See: https://github.com/NixOS/nixpkgs/pull/197838 - "-DCMAKE_INSTALL_BINDIR=bin" - "-DCMAKE_INSTALL_LIBDIR=lib" - "-DCMAKE_INSTALL_INCLUDEDIR=include" - ] ++ lib.optionals (gpuTargets != [ ]) [ - "-DAMDGPU_TARGETS=${lib.concatStringsSep ";" gpuTargets}" - ] ++ lib.optionals buildTests [ - "-DBUILD_TEST=ON" - ] ++ lib.optionals buildBenchmarks [ - "-DBUILD_BENCHMARKS=ON" - ] ++ lib.optionals (buildTests || buildBenchmarks) [ - "-DCMAKE_CXX_FLAGS=-Wno-deprecated-builtins" # Too much spam - ]; - - postInstall = lib.optionalString buildTests '' - mkdir -p $test/bin - mv $out/bin/{test_*,*.hip} $test/bin - '' + lib.optionalString buildBenchmarks '' - mkdir -p $benchmark/bin - mv $out/bin/benchmark_* $benchmark/bin - '' + lib.optionalString (buildTests || buildBenchmarks) '' - rm -rf $out/bin - ''; - - passthru.updateScript = rocmUpdateScript { - name = finalAttrs.pname; - owner = finalAttrs.src.owner; - repo = finalAttrs.src.repo; - }; - - meta = with lib; { - description = "ROCm parallel algorithm library"; - homepage = "https://github.com/ROCm/rocThrust"; - license = with licenses; [ asl20 ]; - maintainers = teams.rocm.members; - platforms = platforms.linux; - broken = versions.minor finalAttrs.version != versions.minor stdenv.cc.version || versionAtLeast finalAttrs.version "7.0.0"; - }; -}) \ No newline at end of file From c34fb69e34eee6922adbd6e8ae567ba16e451419 Mon Sep 17 00:00:00 2001 From: Qinghan Chen Date: Fri, 21 Jun 2024 21:13:36 -0400 Subject: [PATCH 9/9] checkpoint --- .flake/pkgs/legion.nix | 2 ++ .proj.toml | 15 +++++++----- CMakeLists.txt | 2 +- flake.nix | 52 +++++++++++++++++++++++++++--------------- 4 files changed, 46 insertions(+), 25 deletions(-) diff --git a/.flake/pkgs/legion.nix b/.flake/pkgs/legion.nix index d51e4c6e95..543ba3c04b 100644 --- a/.flake/pkgs/legion.nix +++ b/.flake/pkgs/legion.nix @@ -67,6 +67,8 @@ stdenv.mkDerivation rec { (cmakeFeature "HIP_THRUST_ROOT_DIR" "${rocmPackages.rocthrust}") (cmakeFeature "ROCM_PATH" "${rocmPackages.clr}") + (cmakeFeature "HIP_INCLUDE_DIRS" "${rocmPackages.clr}/hip/include") + (cmakeFeature "CMAKE_CXX_COMPILER" "hipcc") (cmakeFeature "CMAKE_C_COMPILER" "hipcc") ] diff --git a/.proj.toml b/.proj.toml index bc276f3589..ee2d677c18 100644 --- a/.proj.toml +++ b/.proj.toml @@ -6,27 +6,30 @@ fix_compile_commands = false build_targets = [ "kernels", - "pcg", + # "pcg", # "substitutions", # "compiler", - "substitution-generator", - "local-execution", + # "substitution-generator", + # "local-execution", ] test_targets = [ # "utils-tests", # "substitutions-tests", # "compiler-tests", - "pcg", + # "pcg", # "substitutions", # "compiler", - "substitution-generator", + # "substitution-generator", ] [cmake_flags_extra] FF_USE_HIP_ROCM = "ON" FF_GPU_BACKEND = "hip_rocm" -CMAKE_CUDA_ARCHITECTURES = "60" +# CMAKE_CUDA_ARCHITECTURES = "60" CMAKE_HIP_ARCHITECTURES = "gfx900" +# HIP_PLATFORM = "amd" +# HIP_RUNTIME = "rocclr" 
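+# (HIP_PLATFORM / HIP_RUNTIME are left commented out above on the assumption
+# that hipcc already defaults to the amd platform and the rocclr runtime)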
CMAKE_CXX_COMPILER = "hipcc" CMAKE_C_COMPILER = "hipcc" + # FF_CUDA_ARCH = "60" diff --git a/CMakeLists.txt b/CMakeLists.txt index 5222af555a..211f9a867c 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -100,7 +100,7 @@ include(doctestlib) # named doctestlib to avoid a name collision with doctest.cm include(visit_struct) include(CTest) include(fmt) -# include(legion) +include(legion) include(rapidcheck) #include(gtest) diff --git a/flake.nix b/flake.nix index 253d009225..b718ea05e4 100644 --- a/flake.nix +++ b/flake.nix @@ -37,6 +37,8 @@ }; lib = pkgs.lib; + inherit (pkgs.rocmPackages) clr miopen miopengemm rccl rocm-runtime; + rocm = pkgs.symlinkJoin { name = "rocm"; paths = with pkgs.rocmPackages; [ @@ -46,12 +48,13 @@ clr hipcc rccl + llvm.clang miopen miopengemm miopen-hip hipblas rocm-cmake - clr + hip-common ]; }; @@ -84,7 +87,13 @@ devShells = rec { ci = mkShell { shellHook = '' + export HIP_COMPILER="${pkgs.rocmPackages.llvm.clang}/bin/clang" export PATH="$HOME/ff/.scripts/:$PATH" + export ROCM_PATH=${clr} + export HIP_DEVICE_LIB_PATH="${pkgs.rocmPackages.rocm-device-libs}/amdgcn/bitcode" + # export HIP_ROOT_DIR=${clr} + # export HIP_PATH=${clr}/hip + # export HIP_INCLUDE_DIRS=${clr}/hip/include echo "ROCm path set to: $ROCM_PATH" ''; @@ -100,6 +109,14 @@ "-DFF_USE_EXTERNAL_RANGEV3=ON" "-DFF_USE_EXTERNAL_BOOST_PREPROCESSOR=ON" "-DFF_USE_EXTERNAL_TYPE_INDEX=ON" + + # hip related flags + "-DHIP_PLATFORM=amd" + # "-DHIP_RUNTIME=rocclr" + # "-DHIP_COMPILER=${pkgs.rocmPackages.llvm.clang}/bin/clang" + "-DHIP_PATH=${clr}/hip" + "-DHIP_ROOT_DIR=${clr}/hip" + ]; RC_PARAMS = "max_discard_ratio=100"; @@ -116,12 +133,6 @@ ccache pkg-config python3 - # cudatoolkit - # cudaPackages.cuda_nvcc - # cudaPackages.cudnn - # cudaPackages.nccl - # cudaPackages.libcublas - # cudaPackages.cuda_cudart tl-expected ]) (with self.packages.${system}; [ @@ -130,16 +141,21 @@ rapidcheckFull doctest ]) - [ rocm ] - # (with pkgs.rocmPackages; [ - # hipcc - # rccl - # miopen - # miopen-hip - # hipblas - # rocm-cmake - # clr - # ]) + (with pkgs.rocmPackages; [ + clr + miopen + miopengemm + rccl + rocm-runtime + hipblas + hipcc + hip-common + rocm-cmake + miopen-hip + rocm-thunk + rocm-device-libs + ]) + # [ rocm ] ]; }; @@ -184,4 +200,4 @@ }; } ); -} +} \ No newline at end of file