Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add DLRM Model Computational Graph #1532

Merged
merged 20 commits into from
Feb 10, 2025
Merged
Show file tree
Hide file tree
Changes from 8 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions bin/export-model-arch/src/export_model_arch.cc
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
#include "export_model_arch/json_sp_model_export.dtg.h"
#include "models/bert/bert.h"
#include "models/candle_uno/candle_uno.h"
#include "models/dlrm/dlrm.h"
#include "models/inception_v3/inception_v3.h"
#include "models/split_test/split_test.h"
#include "models/transformer/transformer.h"
Expand Down Expand Up @@ -68,6 +69,8 @@ tl::expected<ComputationGraph, std::string>
return get_candle_uno_computation_graph(get_default_candle_uno_config());
} else if (model_name == "bert") {
return get_bert_computation_graph(get_default_bert_config());
} else if (model_name == "dlrm") {
return get_dlrm_computation_graph(get_default_dlrm_config());
} else if (model_name == "split_test") {
int batch_size = 8;
return get_split_test_computation_graph(batch_size);
Expand Down Expand Up @@ -143,6 +146,7 @@ int main(int argc, char **argv) {
"inception_v3",
"candle_uno",
"bert",
"dlrm",
"split_test",
"single_operator"};
CLIArgumentKey key_model_name = cli_add_positional_argument(
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
#include "compiler/series_parallel/computation_graph/get_computation_graph_series_parallel_decomposition.h"
#include "models/bert/bert.h"
#include "models/candle_uno/candle_uno.h"
#include "models/dlrm/dlrm.h"
#include "models/inception_v3/inception_v3.h"
#include "models/split_test/split_test.h"
#include "models/transformer/transformer.h"
Expand Down Expand Up @@ -324,6 +325,16 @@ TEST_SUITE(FF_TEST_SUITE) {

CHECK(sp_decomposition.has_value());
}

// Sanity check: the DLRM computation graph must admit a
// series-parallel decomposition (required by the compiler passes).
SUBCASE("dlrm") {
ComputationGraph cg =
get_dlrm_computation_graph(get_default_dlrm_config());

std::optional<SeriesParallelDecomposition> sp_decomposition =
get_computation_graph_series_parallel_decomposition(cg);

// Only checks that a decomposition exists, not its exact structure.
CHECK(sp_decomposition.has_value());
}
}
}

Expand Down Expand Up @@ -393,5 +404,13 @@ TEST_SUITE(FF_TEST_SUITE) {
std::string result =
render_preprocessed_computation_graph_for_sp_decomposition(cg);
}

// Smoke test: rendering the preprocessed DLRM graph must not throw.
// The rendered string itself is intentionally not inspected.
SUBCASE("dlrm") {
ComputationGraph cg =
get_dlrm_computation_graph(get_default_dlrm_config());

std::string result =
render_preprocessed_computation_graph_for_sp_decomposition(cg);
}
}
}
57 changes: 57 additions & 0 deletions lib/models/include/models/dlrm/dlrm.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
/**
 * @file dlrm.h
 *
 * @brief DLRM model
 *
 * @details The DLRM implementation refers to the examples at
 * https://github.com/flexflow/FlexFlow/blob/78307b0e8beb5d41ee003be8b5db168c2b3ef4e2/examples/cpp/DLRM/dlrm.cc
 * and
 * https://github.com/pytorch/torchrec/blob/7e7819e284398d7dc420e3bf149107ad310fa861/torchrec/models/dlrm.py#L440.
 */

#ifndef _FLEXFLOW_LIB_MODELS_INCLUDE_MODELS_DLRM_H
#define _FLEXFLOW_LIB_MODELS_INCLUDE_MODELS_DLRM_H

#include "models/dlrm/dlrm_config.dtg.h"
#include "pcg/computation_graph_builder.h"

namespace FlexFlow {

// Helper functions to construct the DLRM model

/**
 * @brief Get the default DLRM config.
 *
 * @details The configs here refer to the example at
 * https://github.com/flexflow/FlexFlow/blob/78307b0e8beb5d41ee003be8b5db168c2b3ef4e2/examples/cpp/DLRM/dlrm.cc.
 */
DLRMConfig get_default_dlrm_config();

/**
 * @brief Append an MLP (a chain of dense layers) to the graph.
 *
 * @param cgb Builder the layers are added to.
 * @param config Model config (provides the initializer seed).
 * @param input Input tensor fed to the first dense layer.
 * @param mlp_layers Layer widths; mlp_layers[i] -> mlp_layers[i+1] gives one
 *                   dense layer, so N entries produce N-1 layers.
 *
 * NOTE: the element type is `int` to match both the definition in dlrm.cc and
 * the `dense_arch_layer_sizes`/`over_arch_layer_sizes` config fields (the
 * previous `std::vector<size_t>` declaration did not match the definition and
 * could not be linked against).
 */
tensor_guid_t create_dlrm_mlp(ComputationGraphBuilder &cgb,
                              DLRMConfig const &config,
                              tensor_guid_t const &input,
                              std::vector<int> const &mlp_layers);

/**
 * @brief Append one sparse embedding table (lookup + cast to float).
 *
 * @param cgb Builder the layers are added to.
 * @param config Model config (provides the initializer seed).
 * @param input Integer index tensor to look up.
 * @param input_dim Number of entries in the embedding table.
 * @param output_dim Embedding dimension of each entry.
 */
tensor_guid_t create_dlrm_sparse_embedding_network(ComputationGraphBuilder &cgb,
                                                   DLRMConfig const &config,
                                                   tensor_guid_t const &input,
                                                   int input_dim,
                                                   int output_dim);

/**
 * @brief Combine the bottom-MLP output with the embedding outputs
 *        (currently only concatenation is supported).
 */
tensor_guid_t create_dlrm_interact_features(
    ComputationGraphBuilder &cgb,
    DLRMConfig const &config,
    tensor_guid_t const &bottom_mlp_output,
    std::vector<tensor_guid_t> const &emb_outputs);

/**
 * @brief Get the DLRM computation graph.
 *
 * @param config The config of DLRM model.
 * @return ComputationGraph The computation graph of a DLRM model.
 */
ComputationGraph get_dlrm_computation_graph(DLRMConfig const &config);

} // namespace FlexFlow

#endif
53 changes: 53 additions & 0 deletions lib/models/include/models/dlrm/dlrm_config.struct.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
# Code-generation spec for the DLRMConfig struct (FlexFlow dtgen format).
namespace = "FlexFlow"
name = "DLRMConfig"

features = [
"eq",
"ord",
"hash",
"json",
"rapidcheck",
"fmt",
]

includes = [
"<vector>",
"<string>",
]

# fmt/hash support for the std::vector fields below.
src_includes = [
"utils/fmt/vector.h",
"utils/hash/vector.h",
]

# Output dimension of every embedding table.
[[fields]]
name = "embedding_dim"
type = "int"

# Number of indices looked up per sample (second dim of each sparse input).
[[fields]]
name = "embedding_bag_size"
type = "int"

# Number of entries per embedding table; one table is created per element.
[[fields]]
name = "embedding_size"
type = "std::vector<int>"

# Bottom ("dense arch") MLP widths; the first entry is the dense feature dim.
[[fields]]
name = "dense_arch_layer_sizes"
type = "std::vector<int>"

# Top ("over arch") MLP widths applied after feature interaction.
[[fields]]
name = "over_arch_layer_sizes"
type = "std::vector<int>"

# Feature-interaction mode; currently only "cat" is supported.
[[fields]]
name = "arch_interaction_op"
type = "std::string"

[[fields]]
name = "batch_size"
type = "int"

# Seed passed to all weight initializers.
[[fields]]
name = "seed"
type = "int"
170 changes: 170 additions & 0 deletions lib/models/src/models/dlrm/dlrm.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,170 @@
#include "models/dlrm/dlrm.h"
#include "pcg/computation_graph.h"
#include "utils/containers/concat_vectors.h"
#include "utils/containers/transform.h"
#include "utils/containers/zip.h"

namespace FlexFlow {

DLRMConfig get_default_dlrm_config() {
  // Default hyperparameters, mirroring the original FlexFlow DLRM example:
  // four 1M-entry embedding tables, a 4->64->64 bottom MLP and a
  // 64->64->2 top MLP.
  std::vector<int> table_sizes = {1000000, 1000000, 1000000, 1000000};
  std::vector<int> bottom_mlp_sizes = {4, 64, 64};
  std::vector<int> top_mlp_sizes = {64, 64, 2};

  return DLRMConfig{
      /*embedding_dim=*/64,
      /*embedding_bag_size=*/1,
      /*embedding_size=*/table_sizes,
      /*dense_arch_layer_sizes=*/bottom_mlp_sizes,
      /*over_arch_layer_sizes=*/top_mlp_sizes,
      /*arch_interaction_op=*/"cat",
      /*batch_size=*/64,
      // Seed comes from std::rand(); without an srand() call elsewhere this
      // yields the same sequence on every run.
      /*seed=*/std::rand(),
  };
}

/**
 * @brief Append a chain of dense+ReLU layers to the graph.
 *
 * @param cgb Builder the layers are added to.
 * @param config Model config (provides the initializer seed).
 * @param input Tensor fed to the first dense layer.
 * @param mlp_layers Layer widths; entry i -> i+1 gives one dense layer,
 *                   so N entries produce N-1 layers.
 * @return Output tensor of the last layer (or `input` when fewer than two
 *         widths are given, i.e. no layer can be formed).
 */
tensor_guid_t create_dlrm_mlp(ComputationGraphBuilder &cgb,
                              DLRMConfig const &config,
                              tensor_guid_t const &input,
                              std::vector<int> const &mlp_layers) {
  tensor_guid_t t = input;

  // Guard: mlp_layers.size() is unsigned, so `size() - 1` in the loop bound
  // below would wrap around for an empty vector and the loop would run until
  // .at() threw. Fewer than two widths means there is no layer to build.
  if (mlp_layers.size() < 2) {
    return t;
  }

  // Refer to
  // https://github.com/facebookresearch/dlrm/blob/64063a359596c72a29c670b4fcc9450bb342e764/dlrm_s_pytorch.py#L218-L228
  // for example initializer.
  for (size_t i = 0; i < mlp_layers.size() - 1; i++) {
    // Glorot-style normal init for the projection: std = sqrt(2 / (fan_in + fan_out)).
    float std_dev = sqrt(2.0f / (mlp_layers.at(i + 1) + mlp_layers.at(i)));
    InitializerAttrs projection_initializer =
        InitializerAttrs{NormInitializerAttrs{
            /*seed=*/config.seed,
            /*mean=*/0,
            /*stddev=*/std_dev,
        }};

    // Bias init uses only the layer's output width: std = sqrt(2 / fan_out).
    std_dev = sqrt(2.0f / mlp_layers.at(i + 1));
    InitializerAttrs bias_initializer = InitializerAttrs{NormInitializerAttrs{
        /*seed=*/config.seed,
        /*mean=*/0,
        /*stddev=*/std_dev,
    }};

    t = cgb.dense(/*input=*/t,
                  /*outDim=*/mlp_layers.at(i + 1),
                  /*activation=*/Activation::RELU,
                  /*use_bias=*/true,
                  /*data_type=*/DataType::FLOAT,
                  /*projection_initializer=*/projection_initializer,
                  /*bias_initializer=*/bias_initializer);
  }
  return t;
}

/**
 * @brief Append one sparse embedding table: a SUM-aggregated lookup in half
 *        precision, cast back to float.
 */
tensor_guid_t create_dlrm_sparse_embedding_network(ComputationGraphBuilder &cgb,
                                                   DLRMConfig const &config,
                                                   tensor_guid_t const &input,
                                                   int input_dim,
                                                   int output_dim) {
  // Uniform(-1/sqrt(num_entries), +1/sqrt(num_entries)) table initialization.
  float scale = sqrt(1.0f / input_dim);
  InitializerAttrs table_initializer =
      InitializerAttrs{UniformInitializerAttrs{
          /*seed=*/config.seed,
          /*min_val=*/-scale,
          /*max_val=*/scale,
      }};

  // Lookup is performed in HALF precision; the result is widened to FLOAT so
  // downstream layers operate on a uniform dtype.
  tensor_guid_t embedded =
      cgb.embedding(input,
                    /*num_entries=*/input_dim,
                    /*outDim=*/output_dim,
                    /*aggr=*/AggregateOp::SUM,
                    /*dtype=*/DataType::HALF,
                    /*kernel_initializer=*/table_initializer);
  return cgb.cast(embedded, DataType::FLOAT);
}

/**
 * @brief Combine the bottom-MLP output with the embedding outputs.
 *
 * Only the "cat" interaction mode is implemented; any other value of
 * config.arch_interaction_op raises an error.
 */
tensor_guid_t create_dlrm_interact_features(
    ComputationGraphBuilder &cgb,
    DLRMConfig const &config,
    tensor_guid_t const &bottom_mlp_output,
    std::vector<tensor_guid_t> const &emb_outputs) {
  bool uses_concat = (config.arch_interaction_op == "cat");
  if (!uses_concat) {
    throw mk_runtime_error(fmt::format(
        "Currently only arch_interaction_op=cat is supported, but found "
        "arch_interaction_op={}. If you need support for additional "
        "arch_interaction_op value, please create an issue.",
        config.arch_interaction_op));
  }

  // Concatenate the dense features with every embedding output along axis 1
  // (axis 0 is the batch dimension).
  std::vector<tensor_guid_t> all_features =
      concat_vectors({bottom_mlp_output}, emb_outputs);
  return cgb.concat(/*tensors=*/all_features, /*axis=*/1);
}

// Builds the full DLRM graph: sparse embedding lookups + a bottom MLP over
// the dense features, a "cat" feature interaction, and a top MLP.
ComputationGraph get_dlrm_computation_graph(DLRMConfig const &config) {
ComputationGraphBuilder cgb;

// Local helper: register an input tensor of the given dims/dtype on cgb.
auto create_input_tensor = [&](FFOrdered<size_t> const &dims,
DataType const &data_type) -> tensor_guid_t {
TensorShape input_shape = TensorShape{
TensorDims{dims},
data_type,
};
return cgb.create_input(input_shape, CreateGrad::YES);
};

// Create input tensors
// NOTE(review): this fill-constructor invokes create_input_tensor ONCE and
// copies the resulting tensor guid embedding_size.size() times, so every
// embedding table below reads the SAME input tensor. If each table is meant
// to have its own sparse input, this should create a distinct tensor per
// table (doing so would also change the graph's layer count, which the unit
// test currently pins) — confirm intent.
std::vector<tensor_guid_t> sparse_inputs(
config.embedding_size.size(),
create_input_tensor({static_cast<size_t>(config.batch_size),
static_cast<size_t>(config.embedding_bag_size)},
DataType::INT64));

// Dense input width is the first entry of dense_arch_layer_sizes.
tensor_guid_t dense_input = create_input_tensor(
{static_cast<size_t>(config.batch_size),
static_cast<size_t>(config.dense_arch_layer_sizes.front())},
DataType::FLOAT);

// Construct the model
tensor_guid_t bottom_mlp_output = create_dlrm_mlp(
/*cgb=*/cgb,
/*config=*/config,
/*input=*/dense_input,
/*mlp_layers=*/config.dense_arch_layer_sizes);

// One embedding network per table in config.embedding_size.
std::vector<tensor_guid_t> emb_outputs;
for (size_t i = 0; i < config.embedding_size.size(); i++) {
int input_dim = config.embedding_size.at(i);
emb_outputs.emplace_back(create_dlrm_sparse_embedding_network(
/*cgb=*/cgb,
/*config=*/config,
/*input=*/sparse_inputs.at(i),
/*input_dim=*/input_dim,
/*output_dim=*/config.embedding_dim));
}

tensor_guid_t interacted_features = create_dlrm_interact_features(
/*cgb=*/cgb,
/*config=*/config,
/*bottom_mlp_output=*/bottom_mlp_output,
/*emb_outputs=*/emb_outputs);

// The returned guid is unused: the top MLP's layers are already recorded in
// cgb, and the graph (not a tensor) is the product of this function.
tensor_guid_t output = create_dlrm_mlp(
/*cgb=*/cgb,
/*config=*/config,
/*input=*/interacted_features,
/*mlp_layers=*/config.over_arch_layer_sizes);

return cgb.computation_graph;
}

} // namespace FlexFlow
19 changes: 19 additions & 0 deletions lib/models/test/src/models/dlrm/dlrm.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
#include "models/dlrm/dlrm.h"
#include "pcg/computation_graph.h"
#include <doctest/doctest.h>

using namespace ::FlexFlow;

TEST_SUITE(FF_TEST_SUITE) {
TEST_CASE("get_dlrm_computation_graph") {
DLRMConfig config = get_default_dlrm_config();

// Building the graph with the default config must succeed.
ComputationGraph result = get_dlrm_computation_graph(config);

SUBCASE("num layers") {
// Golden-value regression check on the total layer count; 27 was
// presumably observed from the current default config (4 embedding
// tables, 4->64->64 bottom MLP, 64->64->2 top MLP) — it must be
// re-derived if the default config or the builder's layer
// decomposition changes. TODO confirm the breakdown.
int result_num_layers = get_layers(result).size();
int correct_num_layers = 27;
CHECK(result_num_layers == correct_num_layers);
}
}
}
Loading
Loading