From 7ec96619cc1ffb07c4014b4fbd80ee5ec53ffaca Mon Sep 17 00:00:00 2001
From: Andrei Panferov
Date: Fri, 21 Jul 2023 12:47:17 +0000
Subject: [PATCH 1/8] dev deps update

---
 setup.cfg | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/setup.cfg b/setup.cfg
index 351fe5b..a938ff2 100644
--- a/setup.cfg
+++ b/setup.cfg
@@ -39,11 +39,11 @@ dev =
     pytest==6.2.5
     pytest-forked
     pytest-asyncio==0.16.0
-    accelerate==0.15.0
+    accelerate==0.20.3
     black==22.3.0
     isort==5.10.1
     psutil
-    peft>=0.3.0
+    peft==0.4.0
     einops==0.6.1
 [options.packages.find]
 where = src

From 73c5a65aa8acc8ffb09164d2261960076b67fd96 Mon Sep 17 00:00:00 2001
From: Andrei Panferov
Date: Fri, 21 Jul 2023 12:51:17 +0000
Subject: [PATCH 2/8] peft==0.3.0

---
 setup.cfg | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/setup.cfg b/setup.cfg
index a938ff2..4596c47 100644
--- a/setup.cfg
+++ b/setup.cfg
@@ -43,7 +43,7 @@ dev =
     black==22.3.0
     isort==5.10.1
     psutil
-    peft==0.4.0
+    peft==0.3.0
     einops==0.6.1
 [options.packages.find]
 where = src

From 5d7716509df61a60d066be9a3f164c4d7c3fdb80 Mon Sep 17 00:00:00 2001
From: Andrei Panferov
Date: Fri, 21 Jul 2023 13:50:15 +0000
Subject: [PATCH 3/8] llama-2

---
 src/tensor_parallel/slicing_configs.py | 16 ++++++++++++----
 tests/test_transformers.py             |  1 +
 2 files changed, 13 insertions(+), 4 deletions(-)

diff --git a/src/tensor_parallel/slicing_configs.py b/src/tensor_parallel/slicing_configs.py
index 3eb59d1..62e4a2f 100644
--- a/src/tensor_parallel/slicing_configs.py
+++ b/src/tensor_parallel/slicing_configs.py
@@ -348,6 +348,8 @@ def get_llama_config(model_config: PretrainedConfig, devices: Sequence[torch.dev
     world_size = len(devices)
     num_heads = model_config.num_attention_heads
     head_dim = model_config.hidden_size // model_config.num_attention_heads
+    num_kv = model_config.num_key_value_heads
+    q_per_kv = model_config.num_attention_heads // model_config.num_key_value_heads
 
     gather_kv_across_ranks = CollectiveOperation(
         world_size=world_size, func=lambda *kvs: gather_kv(*kvs, world_size=world_size)
@@ -356,10 +358,14 @@ def get_llama_config(model_config: PretrainedConfig, devices: Sequence[torch.dev
     return Config(
         state_rules={
             # LlamaAttention
-            r".*self_attn\.q_proj\.weight$": SplitInChunks(world_size=world_size, dim=0, chunk_size=head_dim),
+            r".*self_attn\.q_proj\.weight$": SplitInChunks(
+                world_size=world_size, dim=0, chunk_size=q_per_kv * head_dim
+            ),
             r".*self_attn\.k_proj\.weight$": SplitInChunks(world_size=world_size, dim=0, chunk_size=head_dim),
             r".*self_attn\.v_proj\.weight$": SplitInChunks(world_size=world_size, dim=0, chunk_size=head_dim),
-            r".*self_attn\.o_proj\.weight$": SplitInChunks(world_size=world_size, dim=1, chunk_size=head_dim),
+            r".*self_attn\.o_proj\.weight$": SplitInChunks(
+                world_size=world_size, dim=1, chunk_size=q_per_kv * head_dim
+            ),
             # LlamaFeedForward
             r".*mlp\.gate_proj\.weight$": Split(world_size=world_size, dim=0),
             r".*mlp\.down_proj\.weight$": Split(world_size=world_size, dim=1),
@@ -379,8 +385,10 @@ def get_llama_config(model_config: PretrainedConfig, devices: Sequence[torch.dev
         },
         attr_rules={
             r".*self_attn$": {
-                "hidden_size": partial(split_inner_dim, num_heads=num_heads, world_size=world_size),
-                "num_heads": partial(split_num_heads, world_size=world_size),
+                "hidden_size": partial(split_inner_dim, num_heads=num_kv, world_size=world_size),
+                "num_key_value_heads": partial(split_num_heads, world_size=world_size),
+                "num_heads": lambda n, rank: q_per_kv
+                * split_num_heads(n // q_per_kv, rank=rank, world_size=world_size),
             }
         },
     )
diff --git a/tests/test_transformers.py b/tests/test_transformers.py
index 51edc24..aab638d 100644
--- a/tests/test_transformers.py
+++ b/tests/test_transformers.py
@@ -83,6 +83,7 @@ def all_equal(iterator):
         "trl-internal-testing/tiny-random-GPTNeoXForCausalLM",
         "Salesforce/codegen-350M-mono",
         "Bingsu/llama-190m-arch",
+        "BlackSamorez/llama-2-tiny-testing",
         "BlackSamorez/falcon-40b-tiny-testing",
     ],
 )

From 4ae93c01818760a0d9488a74d201bb1adb40ff98 Mon Sep 17 00:00:00 2001
From: Andrei Panferov
Date: Fri, 21 Jul 2023 14:15:59 +0000
Subject: [PATCH 4/8] Back compat llama

---
 src/tensor_parallel/slicing_configs.py | 9 ++++++---
 1 file changed, 6 insertions(+), 3 deletions(-)

diff --git a/src/tensor_parallel/slicing_configs.py b/src/tensor_parallel/slicing_configs.py
index 62e4a2f..fb0f405 100644
--- a/src/tensor_parallel/slicing_configs.py
+++ b/src/tensor_parallel/slicing_configs.py
@@ -346,10 +346,13 @@ def get_llama_config(model_config: PretrainedConfig, devices: Sequence[torch.dev
     assert model_config.model_type == "llama", f"Trying to pass {model_config.model_type} as llama config"
 
     world_size = len(devices)
-    num_heads = model_config.num_attention_heads
     head_dim = model_config.hidden_size // model_config.num_attention_heads
-    num_kv = model_config.num_key_value_heads
-    q_per_kv = model_config.num_attention_heads // model_config.num_key_value_heads
+    try:
+        num_kv = model_config.num_key_value_heads
+        q_per_kv = model_config.num_attention_heads // model_config.num_key_value_heads
+    except AttributeError:
+        num_kv = model_config.num_attention_heads
+        q_per_kv = 1
 
     gather_kv_across_ranks = CollectiveOperation(
         world_size=world_size, func=lambda *kvs: gather_kv(*kvs, world_size=world_size)

From e0409366d23420e93e73ec32dfd219ec34c73050 Mon Sep 17 00:00:00 2001
From: Andrei Panferov
Date: Fri, 21 Jul 2023 14:34:47 +0000
Subject: [PATCH 5/8] back compat again

---
 src/tensor_parallel/slicing_configs.py | 10 ++++++++--
 1 file changed, 8 insertions(+), 2 deletions(-)

diff --git a/src/tensor_parallel/slicing_configs.py b/src/tensor_parallel/slicing_configs.py
index fb0f405..f8504fa 100644
--- a/src/tensor_parallel/slicing_configs.py
+++ b/src/tensor_parallel/slicing_configs.py
@@ -350,15 +350,17 @@ def get_llama_config(model_config: PretrainedConfig, devices: Sequence[torch.dev
     try:
         num_kv = model_config.num_key_value_heads
         q_per_kv = model_config.num_attention_heads // model_config.num_key_value_heads
+        new_modeling = True
     except AttributeError:
         num_kv = model_config.num_attention_heads
         q_per_kv = 1
+        new_modeling = False
 
     gather_kv_across_ranks = CollectiveOperation(
         world_size=world_size, func=lambda *kvs: gather_kv(*kvs, world_size=world_size)
     )  # this operation ensures that we get attention cache for all heads on each device
 
-    return Config(
+    config = Config(
         state_rules={
             # LlamaAttention
             r".*self_attn\.q_proj\.weight$": SplitInChunks(
@@ -389,13 +391,17 @@ def get_llama_config(model_config: PretrainedConfig, devices: Sequence[torch.dev
         },
         attr_rules={
             r".*self_attn$": {
                 "hidden_size": partial(split_inner_dim, num_heads=num_kv, world_size=world_size),
-                "num_key_value_heads": partial(split_num_heads, world_size=world_size),
                 "num_heads": lambda n, rank: q_per_kv
                 * split_num_heads(n // q_per_kv, rank=rank, world_size=world_size),
             }
         },
     )
+    if new_modeling:
+        config.attr_rules[r".*self_attn$"]["num_key_value_heads"] = partial(split_num_heads, world_size=world_size)
+
+    return config
+
 
 def get_refined_web_config(model_config: PretrainedConfig, devices: Sequence[torch.device]) -> Config:
     # We can't use `RWConfig`` since it's custom code

From a66f23b94382a3566f54f1a5007cbdbeb0efb701 Mon Sep 17 00:00:00 2001
From: Andrei Panferov
Date: Fri, 21 Jul 2023 14:50:55 +0000
Subject: [PATCH 6/8] skip llama2 with old transformers

---
 tests/test_transformers.py | 43 +++++++++++++++++++++++--------------------
 1 file changed, 23 insertions(+), 20 deletions(-)

diff --git a/tests/test_transformers.py b/tests/test_transformers.py
index aab638d..9d2c439 100644
--- a/tests/test_transformers.py
+++ b/tests/test_transformers.py
@@ -72,39 +72,42 @@ def all_equal(iterator):
 )  # basically asserting that all of those have the same config
 
 
+def prepare_model(model_name, use_lora):
+    if model_name == "BlackSamorez/falcon-40b-tiny-testing" and torch.__version__ < "2.0":
+        pytest.skip(f"Not testing {model_name} with torch=={torch.__version__}")
+    if model_name == "BlackSamorez/llama-2-tiny-testing" and transformers.__version__ < "4.31":
+        pytest.skip(f"Not testing {model_name} with transformers=={transformers.__version__}")
+
+    try:
+        model = AutoModelForCausalLM.from_pretrained(model_name, low_cpu_mem_usage=True, trust_remote_code=True).float()
+    except KeyError as err:
+        pytest.skip(f"Could not create model {model_name} with error {err}")
+    if use_lora:
+        if model_name == "gpt2":
+            pytest.skip("Not testing LoRA for gpt2")
+        model = add_lora(model, model_name)
+    return model
+
+
 @pytest.mark.parametrize("use_lora", [False, True])
 @pytest.mark.parametrize("use_config", [False, True])
 @pytest.mark.parametrize("devices", [("cpu",) * 2, ("cpu",) * 3])
 @pytest.mark.parametrize(
     "model_name",
     [
-        "bigscience/bloom-560m",
-        "gpt2",
-        "trl-internal-testing/tiny-random-GPTNeoXForCausalLM",
-        "Salesforce/codegen-350M-mono",
+        # "bigscience/bloom-560m",
+        # "gpt2",
+        # "trl-internal-testing/tiny-random-GPTNeoXForCausalLM",
+        # "Salesforce/codegen-350M-mono",
         "Bingsu/llama-190m-arch",
         "BlackSamorez/llama-2-tiny-testing",
-        "BlackSamorez/falcon-40b-tiny-testing",
+        # "BlackSamorez/falcon-40b-tiny-testing",
     ],
 )
 def test_forward_gpt2_like(use_lora, use_config, devices, model_name):
     torch.manual_seed(0)
 
-    if model_name == "BlackSamorez/falcon-40b-tiny-testing" and torch.__version__ < "2.0":
-        pytest.skip(f"Not testing {model_name} with torch=={torch.__version__}")
-
-    try:
-        model = (
-            AutoModelForCausalLM.from_pretrained(model_name, low_cpu_mem_usage=True, trust_remote_code=True)
-            .float()
-            .to(devices[0])
-        )
-    except KeyError as err:
-        pytest.skip(f"Could not create model {model_name} with error {err}")
-    if use_lora:
-        if model_name == "gpt2":
-            pytest.skip("Not testing LoRA for gpt2")
-        model = add_lora(model, model_name)
+    model = prepare_model(model_name, use_lora)
 
     inp1 = torch.randint(1, 1000, size=(2, 3), device=devices[0])
     inp2 = torch.randint(1, 1000, size=(2, 1), device=devices[0])

From 18c38cc0987604c0aea00126d858b1a53d65b962 Mon Sep 17 00:00:00 2001
From: Andrei Panferov
Date: Fri, 21 Jul 2023 14:53:13 +0000
Subject: [PATCH 7/8] idk

---
 tests/test_transformers.py | 10 +++++-----
 1 file changed, 5 insertions(+), 5 deletions(-)

diff --git a/tests/test_transformers.py b/tests/test_transformers.py
index 9d2c439..a77e84e 100644
--- a/tests/test_transformers.py
+++ b/tests/test_transformers.py
@@ -95,13 +95,13 @@ def prepare_model(model_name, use_lora):
 @pytest.mark.parametrize(
     "model_name",
     [
-        # "bigscience/bloom-560m",
-        # "gpt2",
-        # "trl-internal-testing/tiny-random-GPTNeoXForCausalLM",
-        # "Salesforce/codegen-350M-mono",
+        "bigscience/bloom-560m",
+        "gpt2",
+        "trl-internal-testing/tiny-random-GPTNeoXForCausalLM",
+        "Salesforce/codegen-350M-mono",
         "Bingsu/llama-190m-arch",
         "BlackSamorez/llama-2-tiny-testing",
-        # "BlackSamorez/falcon-40b-tiny-testing",
+        "BlackSamorez/falcon-40b-tiny-testing",
     ],
 )
 def test_forward_gpt2_like(use_lora, use_config, devices, model_name):

From 7fbd0a1c28e74eec2aba0cbafa408caacc65dbdb Mon Sep 17 00:00:00 2001
From: Andrei Panferov
Date: Fri, 21 Jul 2023 15:41:45 +0000
Subject: [PATCH 8/8] version bump

---
 setup.cfg | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/setup.cfg b/setup.cfg
index 4596c47..030677c 100644
--- a/setup.cfg
+++ b/setup.cfg
@@ -1,6 +1,6 @@
 [metadata]
 name = tensor_parallel
-version = 1.2.8
+version = 1.2.9
 author = Andrei Panferov and Yaroslav Lisnyak
 author_email = yalisnyak@nes.com
 description = Automatically shard your large model between multiple GPUs, works without torch.distributed