This repository has been archived by the owner on Jan 21, 2025. It is now read-only.

MoE models with heterogeneous expert width #342

Open · wants to merge 1 commit into master
222 changes: 157 additions & 65 deletions mesh_tensorflow/transformer/moe.py
@@ -21,6 +21,7 @@
TODO(noam): Remove the other copy of this code from tensor2tensor.
TODO(noam): Write a new, simpler, cleaner version of this code.
"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
@@ -29,10 +30,9 @@

import mesh_tensorflow as mtf
from mesh_tensorflow.transformer import transformer

import numpy as np
import tensorflow.compat.v1 as tf


@gin.configurable
class MoE1D(transformer.TransformerLayer):
"""Mixture of Experts Layer."""
@@ -64,7 +64,9 @@ def __init__(self,
z_loss=None,
word_embed_mode=None,
use_second_place_expert_prob=None,
use_second_place_expert_prob_temp=None):
use_second_place_expert_prob_temp=None,
num_layers=1,
heterogeneous_mask_info=None):
self._hparams = HParams(
moe_gating=moe_gating,
moe_num_experts=num_experts,
@@ -93,7 +95,9 @@ def __init__(self,
moe_use_second_place_expert_prob=(
use_second_place_expert_prob),
moe_use_second_place_expert_prob_temp=(
use_second_place_expert_prob_temp))
use_second_place_expert_prob_temp),
moe_num_layers=num_layers,
moe_heterogeneous_mask_info=heterogeneous_mask_info)
self._activation = activation
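# A hypothetical gin configuration for this layer (binding names assumed
# from the constructor arguments above; not part of the diff):
#   moe.MoE1D.num_layers = 2
#   moe.MoE1D.heterogeneous_mask_info = [
#       {'percent_number': .5, 'layers': 1, 'width': 1},
#       {'percent_number': .5, 'layers': 2, 'width': 2}]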

def call(self, context, x, losses=None):
@@ -125,7 +129,8 @@ def call(self, context, x, losses=None):
nonpadding=context.nonpadding,
activation=self._activation,
num_microbatches=context.num_microbatches,
token_embeddings=context.input_embeddings)
token_embeddings=context.input_embeddings,
context=context)
if context.losses is not None:
context.losses.append(loss)
if not has_length_dim:
@@ -200,7 +205,7 @@ def call(self, context, x, losses=None):
def transformer_moe_layer_v1(
inputs, output_dim, hparams, train, variable_dtype,
layout=None, mesh_shape=None, nonpadding=None, activation=mtf.relu,
num_microbatches=None, token_embeddings=None):
num_microbatches=None, token_embeddings=None, context=None):
"""Local mixture of experts that works well on TPU.

Adapted from the paper https://arxiv.org/abs/1701.06538
Expand Down Expand Up @@ -279,6 +284,7 @@ def transformer_moe_layer_v1(
[batch_dim(s), length_dim, input_dim]. These are the word embeddings
that correspond to the inputs. They can optionally be used to make
routing decisions.
context: a Context.

Returns:
outputs: a Tensor with shape [batch_dim(s), length_dim, output_dim]
@@ -334,9 +340,24 @@
#
# pylint: enable=line-too-long
orig_inputs = inputs
hidden_dim = mtf.Dimension("expert_hidden", hparams.moe_hidden_size)

experts_dim = mtf.Dimension("experts", hparams.moe_num_experts)

if hparams.moe_heterogeneous_mask_info is not None:
tf.logging.info("moe_heterogeneous_mask_info: {}".format(
hparams.moe_heterogeneous_mask_info))
heterogeneous_mask = generate_heterogeneous_expert_masks(
hparams.moe_heterogeneous_mask_info,
hparams.moe_num_experts,
experts_dim,
mesh=inputs.mesh)
# Overwrite num_layers and the hidden width with the mask dimensions.
# TODO(chaimerl): depending on whether the function output is flattened
# or not, these indices might need adjustment.
hparams.moe_num_layers = heterogeneous_mask.shape[1].size
hparams.moe_hidden_size = heterogeneous_mask.shape[0].size
hidden_dim = mtf.Dimension("expert_hidden", hparams.moe_hidden_size)
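# For example (hypothetical values): with default_width=256 and
# moe_heterogeneous_mask_info = [{'percent_number': .5, 'layers': 1, 'width': 1},
# {'percent_number': .5, 'layers': 2, 'width': 2}],
# the mask has shape [512, 2, num_experts], so moe_hidden_size becomes 512
# and moe_num_layers becomes 2.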

# We "cheat" here and look at the mesh shape and layout. This is to ensure
# that the number of groups is a multiple of the mesh dimension
# over which those groups are split.
@@ -489,64 +510,81 @@ def transformer_moe_layer_v1(
input_dim
]))

# Now feed the expert inputs through the experts.
h = mtf.layers.dense_product(
expert_inputs,
reduced_dims=expert_inputs.shape.dims[-1:],
new_dims=[hidden_dim],
expert_dims=[experts_dim],
activation_functions=activation, use_bias=False,
variable_dtype=variable_dtype, name="wi")

if hparams.moe_dropout_rate != 0.0:
h = mtf.dropout(h, is_training=train,
keep_prob=1.0 - hparams.moe_dropout_rate)

def _compute_output(hidden, layer_name):
"""Compute the output of the attention layer from the hidden vector."""
expert_output = mtf.layers.dense(
hidden, output_dim, expert_dims=[experts_dim], use_bias=False,
reduced_dims=hidden.shape.dims[-1:], variable_dtype=variable_dtype,
name=layer_name)

# Extra reshape reduces communication cost for model-parallel versions.
# For model-parallel versions, this reshape causes an mtf.slice and for non-
# model-parallel versions, this has no effect.
expert_output = mtf.reshape(
expert_output,
mtf.Shape([
outer_batch_dim, experts_dim_unsplit, num_groups_dim,
expert_capacity_dim, d_model_split_dim
]))

# Split over experts -> split over batch
expert_output = mtf.reshape(
expert_output,
mtf.Shape([
outer_batch_dim,
experts_dim_unsplit,
num_groups_dim,
expert_capacity_dim,
output_dim,
]))
moe_output_dims = moe_input_dims[:-1] + [output_dim]
output = mtf.einsum([expert_output, combine_tensor],
mtf.Shape(moe_output_dims))
output = mtf.reshape(output, batch_and_length_dims + [output_dim])
return output

if hparams.moe_use_experts_attention:
# We share k_h and v_h with no degradation in performance
q_h, k_h = h, h
outputs = []
q = _compute_output(q_h, layer_name="q_wo")
k = _compute_output(k_h, layer_name="k_wo")
outputs.append(q)
outputs.append(k)
return outputs, loss * hparams.moe_loss_coef
else:
output = _compute_output(h, layer_name="wo")
return output, loss * hparams.moe_loss_coef
# When heterogeneous experts are enabled, heterogeneous_mask has shape
# [hidden_size, num_layers, num_experts]. Stack moe_num_layers
# feed-forward sublayers inside each expert, applying pre-normalization
# and a residual connection between consecutive sublayers.
for layer in range(hparams.moe_num_layers):
with tf.variable_scope("expert_layer_{}".format(layer)):
res_h = 0.0
if layer > 0:
res_h = expert_inputs
expert_inputs = transformer.sublayer_rms_norm(
expert_inputs, None, context)

# Now feed the expert inputs through the experts.
h = mtf.layers.dense_product(
expert_inputs,
reduced_dims=expert_inputs.shape.dims[-1:],
new_dims=[hidden_dim],
expert_dims=[experts_dim],
activation_functions=activation, use_bias=False,
variable_dtype=variable_dtype, name="wi")

# apply dropout
if hparams.moe_dropout_rate != 0.0:
h = mtf.dropout(h, is_training=train,
keep_prob=1.0 - hparams.moe_dropout_rate)
# Apply the heterogeneity mask, if configured.
if hparams.moe_heterogeneous_mask_info is not None:
# TODO(chaimerl): change to include width; this needs to be applied
# within the expert, i.e. on h.
heterogeneous_mask_slice = mtf.slice(
heterogeneous_mask, layer, 1, heterogeneous_mask.shape[1].name)

# Get rid of the expert layers dimension.
heterogeneous_mask_slice = mtf.reshape(heterogeneous_mask_slice,
[heterogeneous_mask_slice.shape[0],
heterogeneous_mask_slice.shape[-1]])
h *= mtf.cast(heterogeneous_mask_slice, h.dtype)
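# Note (assumed semantics): mtf broadcasts the [expert_hidden, experts]
# mask slice against h over the named dims, zeroing the hidden units
# outside this sublayer's active width for each expert.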
# Project back from the expert_hidden dim to the d_model (output) dim.
expert_output = mtf.layers.dense(
h, output_dim, expert_dims=[experts_dim], use_bias=False,
reduced_dims=h.shape.dims[-1:], variable_dtype=variable_dtype,
name="wo")

if layer < (hparams.moe_num_layers - 1):
expert_output = transformer.sublayer_dropout(
expert_output, None, context)
expert_output += res_h
expert_inputs = expert_output

# Extra reshape reduces communication cost for model-parallel versions.
# For model-parallel versions, this reshape causes an mtf.slice and for non-
# model-parallel versions, this has no effect.
expert_output = mtf.reshape(
expert_output,
mtf.Shape([
outer_batch_dim, experts_dim_unsplit, num_groups_dim,
expert_capacity_dim, d_model_split_dim
]))

# Split over experts -> split over batch
expert_output = mtf.reshape(
expert_output,
mtf.Shape([
outer_batch_dim,
experts_dim_unsplit,
num_groups_dim,
expert_capacity_dim,
output_dim,
]))
moe_output_dims = moe_input_dims[:-1] + [output_dim]
output = mtf.einsum([expert_output, combine_tensor],
mtf.Shape(moe_output_dims))
output = mtf.reshape(output, batch_and_length_dims + [output_dim])
return output, loss * hparams.moe_loss_coef


def transformer_moe_layer_v2(
@@ -1720,3 +1758,57 @@ def __init__(self, **kwargs):

def add_hparam(self, k, v):
setattr(self, k, v)


def generate_heterogeneous_expert_masks(
mask_info, num_experts, experts_dim, mesh, default_width=256):
"""Returns mask of shape [num_layers, num_experts, hidden_size].

# mask_info
# num_experts: number of experts in the model
# experts_dim: mtf dimension for experts (partitioned)
# mesh: mesh object
#
# Example mask_info format:
# mask_info = [{'percent_number': .5, 'layers': 1, 'width':1},
# {'percent_number': .5, 'layers': 2, 'width':2}]
"""
# Get max num layers across all groups.
max_layers = max(mask_i["layers"] for mask_i in mask_info)
# Get max width across all groups, in hidden units.
max_width = max(mask_i["width"] for mask_i in mask_info) * default_width
# Will be shape [max_width, max_layers, num_experts]
expert_mask = np.zeros([max_width, max_layers, 0])
for idx, mask_i in enumerate(mask_info):
if mask_i["percent_number"] < 1.0:
num_experts_in_mask = int(num_experts * mask_i["percent_number"])
else:
num_experts_in_mask = int(mask_i["percent_number"])
# Note: percent_number == 1.0 is ambiguous (all experts or one expert);
# the last-position adjustment below resolves any resulting mismatch.
if idx == (len(mask_info) - 1): # last position
num_experts_in_mask_tmp = num_experts - expert_mask.shape[2]
if num_experts_in_mask_tmp != num_experts_in_mask:
tf.logging.info(
"Expert layer probabilities do not evenly divide "
"the number of experts: {} {}".format(
num_experts_in_mask, num_experts_in_mask_tmp))
num_experts_in_mask = num_experts_in_mask_tmp
mask = np.zeros([int(max_width), int(max_layers),
num_experts_in_mask])
# Mark the first (width * default_width) hidden units of the first
# `layers` layers as active; everything else stays zeroed out.
mask[:(mask_i["width"]*default_width), :mask_i["layers"], :] = 1
expert_mask = np.concatenate([expert_mask, mask], axis=2) # expert dim
assert expert_mask.shape[2] == num_experts
tf.logging.info("heterogeneous mask: {}".format(expert_mask))

# Now import the numpy mask into Mesh TF.
layers_dim = mtf.Dimension("num_expert_layers", max_layers)
width_dim = mtf.Dimension("expert_hidden", max_width)
expert_mask_tf = tf.convert_to_tensor(expert_mask)
expert_mask_mtf = mtf.import_tf_tensor(
mesh, tf_tensor=expert_mask_tf,
shape=[width_dim, layers_dim, experts_dim])
return expert_mask_mtf
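A minimal NumPy sketch (not part of the diff) of the mask layout this function builds, with hypothetical small values (num_experts=4, default_width=2) so the shapes are easy to check:

import numpy as np

mask_info = [{'percent_number': .5, 'layers': 1, 'width': 1},
             {'percent_number': .5, 'layers': 2, 'width': 2}]
num_experts, default_width = 4, 2
max_layers = max(m["layers"] for m in mask_info)                 # 2
max_width = max(m["width"] for m in mask_info) * default_width   # 4
expert_mask = np.zeros([max_width, max_layers, 0])
for m in mask_info:
    n = int(num_experts * m["percent_number"])                   # 2 experts per group
    mask = np.zeros([max_width, max_layers, n])
    # Activate the group's width and layers; the rest stays zero.
    mask[:m["width"] * default_width, :m["layers"], :] = 1
    expert_mask = np.concatenate([expert_mask, mask], axis=2)
print(expert_mask.shape)   # (4, 2, 4): [hidden, layers, experts]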