Merge branch 'main' into fix/distribute
Andrei Panferov authored Jul 21, 2023
2 parents c201f9b + 304c3c4 commit 3efdf94
Showing 3 changed files with 45 additions and 24 deletions.
6 changes: 3 additions & 3 deletions setup.cfg
@@ -1,6 +1,6 @@
 [metadata]
 name = tensor_parallel
-version = 1.2.8
+version = 1.2.9
 author = Andrei Panferov and Yaroslav Lisnyak
 author_email = [email protected]
 description = Automatically shard your large model between multiple GPUs, works without torch.distributed
@@ -39,11 +39,11 @@ dev =
     pytest==6.2.5
     pytest-forked
     pytest-asyncio==0.16.0
-    accelerate==0.15.0
+    accelerate==0.20.3
     black==22.3.0
     isort==5.10.1
     psutil
-    peft>=0.3.0
+    peft==0.3.0
     einops==0.6.1
 [options.packages.find]
 where = src
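The dev extra now pins accelerate to 0.20.3 and requires peft at exactly 0.3.0. Below is a minimal sketch, not part of the commit, for checking that a local dev environment matches the new pins; the EXPECTED mapping simply lists the two changed requirements.

# Illustrative helper only; it is not shipped with the package.
from importlib.metadata import PackageNotFoundError, version

EXPECTED = {"accelerate": "0.20.3", "peft": "0.3.0"}

for name, wanted in EXPECTED.items():
    try:
        found = version(name)
    except PackageNotFoundError:
        print(f"{name}: not installed (expected {wanted})")
        continue
    status = "ok" if found == wanted else f"mismatch (expected {wanted})"
    print(f"{name}: {found} {status}")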
29 changes: 23 additions & 6 deletions src/tensor_parallel/slicing_configs.py
@@ -346,20 +346,31 @@ def get_llama_config(model_config: PretrainedConfig, devices: Sequence[torch.device]) -> Config:
     assert model_config.model_type == "llama", f"Trying to pass {model_config.model_type} as llama config"

     world_size = len(devices)
-    num_heads = model_config.num_attention_heads
     head_dim = model_config.hidden_size // model_config.num_attention_heads
+    try:
+        num_kv = model_config.num_key_value_heads
+        q_per_kv = model_config.num_attention_heads // model_config.num_key_value_heads
+        new_modeling = True
+    except AttributeError:
+        num_kv = model_config.num_attention_heads
+        q_per_kv = 1
+        new_modeling = False

     gather_kv_across_ranks = CollectiveOperation(
         world_size=world_size, func=lambda *kvs: gather_kv(*kvs, world_size=world_size)
     )  # this operation ensures that we get attention cache for all heads on each device

-    return Config(
+    config = Config(
         state_rules={
             # LlamaAttention
-            r".*self_attn\.q_proj\.weight$": SplitInChunks(world_size=world_size, dim=0, chunk_size=head_dim),
+            r".*self_attn\.q_proj\.weight$": SplitInChunks(
+                world_size=world_size, dim=0, chunk_size=q_per_kv * head_dim
+            ),
             r".*self_attn\.k_proj\.weight$": SplitInChunks(world_size=world_size, dim=0, chunk_size=head_dim),
             r".*self_attn\.v_proj\.weight$": SplitInChunks(world_size=world_size, dim=0, chunk_size=head_dim),
-            r".*self_attn\.o_proj\.weight$": SplitInChunks(world_size=world_size, dim=1, chunk_size=head_dim),
+            r".*self_attn\.o_proj\.weight$": SplitInChunks(
+                world_size=world_size, dim=1, chunk_size=q_per_kv * head_dim
+            ),
             # LlamaFeedForward
             r".*mlp\.gate_proj\.weight$": Split(world_size=world_size, dim=0),
             r".*mlp\.down_proj\.weight$": Split(world_size=world_size, dim=1),
@@ -379,12 +390,18 @@ def get_llama_config(model_config: PretrainedConfig, devices: Sequence[torch.device]) -> Config:
         },
         attr_rules={
             r".*self_attn$": {
-                "hidden_size": partial(split_inner_dim, num_heads=num_heads, world_size=world_size),
-                "num_heads": partial(split_num_heads, world_size=world_size),
+                "hidden_size": partial(split_inner_dim, num_heads=num_kv, world_size=world_size),
+                "num_heads": lambda n, rank: q_per_kv
+                * split_num_heads(n // q_per_kv, rank=rank, world_size=world_size),
             }
         },
     )

+    if new_modeling:
+        config.attr_rules[r".*self_attn$"]["num_key_value_heads"] = partial(split_num_heads, world_size=world_size)
+
+    return config
+

 def get_refined_web_config(model_config: PretrainedConfig, devices: Sequence[torch.device]) -> Config:
     # We can't use `RWConfig`` since it's custom code
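The slicing_configs.py change teaches the Llama config about grouped-query attention: when the model config defines num_key_value_heads, q_proj and o_proj are split in chunks of q_per_kv * head_dim, so every group of query heads lands on the same shard as the key/value head it attends with. Below is an illustrative sketch of that grouping, not the library's actual SplitInChunks code; the toy sizes and the round-robin rank assignment are assumptions.

import torch

# Toy grouped-query attention sizes (hypothetical, chosen only for illustration).
hidden_size = 32
num_attention_heads = 8
num_key_value_heads = 2
world_size = 2

head_dim = hidden_size // num_attention_heads            # 4
q_per_kv = num_attention_heads // num_key_value_heads    # 4 query heads per KV head

# q_proj maps hidden_size -> num_attention_heads * head_dim; its rows are query heads.
q_proj_weight = torch.randn(num_attention_heads * head_dim, hidden_size)

# Chunking dim 0 by q_per_kv * head_dim yields one chunk per KV head, keeping the
# whole group of query heads that shares a KV head inside a single chunk.
chunks = q_proj_weight.split(q_per_kv * head_dim, dim=0)
assert len(chunks) == num_key_value_heads

# Hand chunks out round-robin across ranks (the assignment scheme here is assumed).
shards = [torch.cat(chunks[rank::world_size], dim=0) for rank in range(world_size)]
for rank, shard in enumerate(shards):
    # Each rank holds complete query groups that line up with its k_proj/v_proj shards.
    print(rank, shard.shape)  # torch.Size([16, 32]) with the toy sizes above

With the old chunk_size of head_dim, chunks were single query heads, so a group sharing one KV head could end up spread across ranks and no longer match the k_proj/v_proj sharding once num_key_value_heads is smaller than num_attention_heads.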
34 changes: 19 additions & 15 deletions tests/test_transformers.py
@@ -72,6 +72,23 @@ def all_equal(iterator):
     )  # basically asserting that all of those have the same config


+def prepare_model(model_name, use_lora):
+    if model_name == "BlackSamorez/falcon-40b-tiny-testing" and torch.__version__ < "2.0":
+        pytest.skip(f"Not testing {model_name} with torch=={torch.__version__}")
+    if model_name == "BlackSamorez/llama-2-tiny-testing" and transformers.__version__ < "4.31":
+        pytest.skip(f"Not testing {model_name} with transformers=={transformers.__version__}")
+
+    try:
+        model = AutoModelForCausalLM.from_pretrained(model_name, low_cpu_mem_usage=True, trust_remote_code=True).float()
+    except KeyError as err:
+        pytest.skip(f"Could not create model {model_name} with error {err}")
+    if use_lora:
+        if model_name == "gpt2":
+            pytest.skip("Not testing LoRA for gpt2")
+        model = add_lora(model, model_name)
+    return model
+
+
 @pytest.mark.parametrize("use_lora", [False, True])
 @pytest.mark.parametrize("use_config", [False, True])
 @pytest.mark.parametrize("devices", [("cpu",) * 2, ("cpu",) * 3])
@@ -83,27 +100,14 @@ def all_equal(iterator):
         "trl-internal-testing/tiny-random-GPTNeoXForCausalLM",
         "Salesforce/codegen-350M-mono",
         "Bingsu/llama-190m-arch",
+        "BlackSamorez/llama-2-tiny-testing",
+        "BlackSamorez/falcon-40b-tiny-testing",
     ],
 )
 def test_forward_gpt2_like(use_lora, use_config, devices, model_name):
     torch.manual_seed(0)

-    if model_name == "BlackSamorez/falcon-40b-tiny-testing" and torch.__version__ < "2.0":
-        pytest.skip(f"Not testing {model_name} with torch=={torch.__version__}")
-
-    try:
-        model = (
-            AutoModelForCausalLM.from_pretrained(model_name, low_cpu_mem_usage=True, trust_remote_code=True)
-            .float()
-            .to(devices[0])
-        )
-    except KeyError as err:
-        pytest.skip(f"Could not create model {model_name} with error {err}")
-    if use_lora:
-        if model_name == "gpt2":
-            pytest.skip("Not testing LoRA for gpt2")
-        model = add_lora(model, model_name)
+    model = prepare_model(model_name, use_lora)

     inp1 = torch.randint(1, 1000, size=(2, 3), device=devices[0])
     inp2 = torch.randint(1, 1000, size=(2, 1), device=devices[0])
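The test change factors model construction into the new prepare_model helper and adds tiny Llama-2 and Falcon checkpoints to the forward-test matrix. Below is a sketch of the scenario those parametrizations exercise, assuming the package's top-level tensor_parallel entry point and a hand-picked tolerance; it is not copied from the test file.

import torch
import tensor_parallel as tp
from transformers import AutoModelForCausalLM

torch.manual_seed(0)
model_name = "BlackSamorez/llama-2-tiny-testing"  # one of the tiny checkpoints added by the diff
devices = ("cpu", "cpu")

model = AutoModelForCausalLM.from_pretrained(model_name, trust_remote_code=True).float()
inputs = torch.randint(1, 1000, size=(2, 3), device=devices[0])

# Reference forward pass on a single device.
ref_logits = model(inputs).logits

# Shard the same model across two (CPU) devices and compare outputs.
model_tp = tp.tensor_parallel(model, devices)
tp_logits = model_tp(inputs).logits

torch.testing.assert_close(tp_logits, ref_logits, atol=1e-4, rtol=1e-4)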
