Merge pull request #91 from BlackSamorez/falcon
Falcon predefined config
IaroslavLisniak authored Jun 20, 2023
2 parents 8b7a3b4 + bbd5f40 commit f8475bc
Showing 3 changed files with 73 additions and 4 deletions.
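
For context, a minimal usage sketch of how a Falcon (RefinedWeb) checkpoint would pick up the new predefined config added in this PR; it assumes the library's documented `tp.tensor_parallel(...)` entry point, and the checkpoint name and device list are purely illustrative:

```python
import torch
import tensor_parallel as tp
from transformers import AutoModelForCausalLM

# Falcon ships custom modeling code on the Hub, hence trust_remote_code=True
model = AutoModelForCausalLM.from_pretrained(
    "tiiuae/falcon-40b", trust_remote_code=True, torch_dtype=torch.bfloat16
)

# model.config.model_type == "RefinedWeb" now matches an entry in
# PREDEFINED_CONFIGS, so the sharding rules added below are applied automatically
model = tp.tensor_parallel(model, ["cuda:0", "cuda:1"])
```
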
4 changes: 2 additions & 2 deletions setup.cfg
@@ -1,6 +1,6 @@
[metadata]
name = tensor_parallel
version = 1.2.6
version = 1.2.7
author = Andrei Panferov and Yaroslav Lisnyak
author_email = [email protected]
description = Automatically shard your large model between multiple GPUs, works without torch.distributed
@@ -44,6 +44,6 @@ dev =
isort==5.10.1
psutil
peft>=0.3.0

einops==0.6.1
[options.packages.find]
where = src
55 changes: 55 additions & 0 deletions src/tensor_parallel/slicing_configs.py
@@ -386,6 +386,60 @@ def get_llama_config(model_config: PretrainedConfig, devices: Sequence[torch.dev
)


def get_refined_web_config(model_config: PretrainedConfig, devices: Sequence[torch.device]) -> Config:
    # We can't use `RWConfig` since it's custom code
assert model_config.model_type == "RefinedWeb", f"Trying to pass {model_config.model_type} as RefinedWeb config"
    assert not model_config.bias and not model_config.alibi, "Running Falcon with biases or alibi is not supported"

world_size = len(devices)
hidden_size = model_config.hidden_size
num_heads = model_config.n_head
num_kv = model_config.n_head_kv
head_dim = hidden_size // num_heads
q_per_kv = num_heads // num_kv

gather_kv_across_ranks = CollectiveOperation(
world_size=world_size, func=lambda *kvs: gather_kv(*kvs, world_size=world_size)
) # this operation ensures that we get attention cache for all heads on each device

return Config(
state_rules={
# Attention
r".*self_attention\.query_key_value\.weight$": SplitInChunks(
world_size=world_size, dim=0, chunk_size=(2 + q_per_kv) * head_dim
),
r".*self_attention\.dense\.weight$": SplitInChunks(
world_size=world_size, dim=1, chunk_size=q_per_kv * head_dim
),
# MLP
r".*mlp\.dense_h_to_4h\.weight$": Split(world_size=world_size, dim=0),
r".*mlp\.dense_4h_to_h\.weight$": Split(world_size=world_size, dim=1),
# RWModel
r".*word_embeddings\.weight$": Split(world_size=world_size, dim=1),
},
input_rules={
r".*self_attention$": {"layer_past": select_kv_for_rank},
r".*lm_head$": {0: "split -1"}, # note: we need to split lm_head inputs because
# ... lm_head's weights (tied embeddings) are already split across input dimension
},
output_rules={
r".*self_attention$": {0: "sum", 1: gather_kv_across_ranks},
r".*\.mlp$": {0: "sum"},
r".*word_embeddings$": {0: "gather -1"},
r".*lm_head$": {0: "sum"},
},
attr_rules={
r".*self_attention$": {
"num_kv": partial(split_num_heads, world_size=world_size),
"num_heads": lambda n, rank: q_per_kv
* split_num_heads(n // q_per_kv, rank=rank, world_size=world_size),
}
},
)
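
A quick arithmetic sketch of why the attention chunk size is `(2 + q_per_kv) * head_dim`: Falcon's multi-query layout packs, for each key/value head, a group of `q_per_kv` query heads plus one key head and one value head contiguously in `query_key_value.weight`, so splitting dim 0 in chunks of that size keeps every group whole on a single device. The dimensions below are Falcon-40B-like and used purely for illustration:

```python
hidden_size = 8192                        # model_config.hidden_size
num_heads = 128                           # model_config.n_head
num_kv = 8                                # model_config.n_head_kv

head_dim = hidden_size // num_heads       # 64
q_per_kv = num_heads // num_kv            # 16

# one packed group = q_per_kv query heads + 1 key head + 1 value head
chunk_size = (2 + q_per_kv) * head_dim    # 18 * 64 = 1152

# query_key_value.weight stacks one such group per kv head along dim 0
qkv_rows = num_kv * chunk_size            # 8 * 1152 = 9216
assert qkv_rows == (num_heads + 2 * num_kv) * head_dim
```

The matching `self_attention.dense` split uses chunks of `q_per_kv * head_dim` along dim 1, since the attention output concatenates only the query heads.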


PREDEFINED_CONFIGS: Dict[str, ConfigGetter] = {
"bloom": get_bloom_config,
"t5": get_t5_config,
@@ -394,4 +448,5 @@ def get_llama_config(model_config: PretrainedConfig, devices: Sequence[torch.dev
"gpt_neox": get_gpt_neox_config,
"codegen": get_codegen_config,
"llama": get_llama_config,
"RefinedWeb": get_refined_web_config,
}
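
The registry above is keyed by the Hugging Face `model_type` string. A hypothetical sketch of how such a registry is typically consumed (the library's actual lookup lives elsewhere and may differ):

```python
def resolve_predefined_config(model_config, devices):
    # Hypothetical helper: dispatch on the HF model_type string
    getter = PREDEFINED_CONFIGS.get(model_config.model_type)
    if getter is None:
        raise NotImplementedError(f"No predefined tensor_parallel config for {model_config.model_type!r}")
    return getter(model_config, devices)
```
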
18 changes: 16 additions & 2 deletions tests/test_transformers.py
@@ -83,13 +83,21 @@ def all_equal(iterator):
"trl-internal-testing/tiny-random-GPTNeoXForCausalLM",
"Salesforce/codegen-350M-mono",
"Bingsu/llama-190m-arch",
"BlackSamorez/falcon-40b-tiny-testing",
],
)
def test_forward_gpt2_like(use_lora, use_config, devices, model_name):
torch.manual_seed(0)

if model_name == "BlackSamorez/falcon-40b-tiny-testing" and torch.__version__ < "2.0":
pytest.skip(f"Not testing {model_name} with torch=={torch.__version__}")

try:
model = AutoModelForCausalLM.from_pretrained(model_name, low_cpu_mem_usage=True).float().to(devices[0])
model = (
AutoModelForCausalLM.from_pretrained(model_name, low_cpu_mem_usage=True, trust_remote_code=True)
.float()
.to(devices[0])
)
except KeyError as err:
pytest.skip(f"Could not create model {model_name} with error {err}")
if use_lora:
@@ -202,12 +210,16 @@ def test_forward_bert_like(use_lora, use_config, devices, model_name):
"gpt2",
"trl-internal-testing/tiny-random-GPTNeoXForCausalLM",
"Bingsu/llama-190m-arch",
"BlackSamorez/falcon-40b-tiny-testing",
],
)
@pytest.mark.parametrize("devices", [("cpu",) * 2, ("cpu",) * 3])
def test_generate(generate_kwargs, model_name, devices):
torch.manual_seed(0)

if model_name == "BlackSamorez/falcon-40b-tiny-testing" and torch.__version__ < "2.0":
pytest.skip(f"Not testing {model_name} with torch=={torch.__version__}")

def _generate_scores(model, input_ids, generate_kwargs):
scores_tuple = model.generate(
input_ids,
@@ -237,7 +249,9 @@ def _assert_scores_allclose_long_enough(
)
else:
model = (
transformers.AutoModelForCausalLM.from_pretrained(model_name, low_cpu_mem_usage=True)
transformers.AutoModelForCausalLM.from_pretrained(
model_name, low_cpu_mem_usage=True, trust_remote_code=True
)
.float()
.to(devices[0])
)