Skip to content

Commit

Permalink
Merge pull request #87 from BlackSamorez/peft_unimport
Browse files Browse the repository at this point in the history
Remove PEFT from the hard dependencies and replace it with runtime checks.
  • Loading branch information
Andrei Panferov authored Jun 14, 2023
2 parents 3bb330e + 3833fad commit 62e5424
Show file tree
Hide file tree
Showing 4 changed files with 31 additions and 4 deletions.
4 changes: 2 additions & 2 deletions setup.cfg
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[metadata]
name = tensor_parallel
version = 1.2.4
version = 1.2.5
author = Andrei Panferov and Yaroslav Lisnyak
author_email = [email protected]
description = Automatically shard your large model between multiple GPUs, works without torch.distributed
Expand Down Expand Up @@ -34,7 +34,6 @@ python_requires = >=3.7
install_requires =
torch>=1.11
transformers>=4.20.1
peft>=0.3.0
[options.extras_require]
dev =
pytest==6.2.5
Expand All @@ -44,6 +43,7 @@ dev =
black==22.3.0
isort==5.10.1
psutil
peft>=0.3.0

[options.packages.find]
where = src
4 changes: 2 additions & 2 deletions src/tensor_parallel/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
from typing import Any, Callable, Dict, Sequence, Union

import torch
from peft.tuners.lora import LoraLayer
from torch import nn

import tensor_parallel.cross_device_ops as cross_device_ops
Expand All @@ -19,6 +18,7 @@
NCCLAllReduce,
)
from tensor_parallel.state_actions import LegacyStateAction, Split, StateAction
from tensor_parallel.utils import check_lora

logger = logging.getLogger(__file__)

Expand Down Expand Up @@ -129,7 +129,7 @@ def add_lora_rules(model: nn.Module, config: Config) -> Config:
lora_input_rules = {}
lora_output_rules = {}
for name, module in model.named_modules():
if isinstance(module, LoraLayer):
if check_lora(module=module):
for pattern, action in config.state_rules.items():
if pattern.search(name + ".weight") is not None:
if isinstance(action, Split):
Expand Down
7 changes: 7 additions & 0 deletions src/tensor_parallel/imports.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
import pkg_resources


def verify_peft_version():
    """Ensure an installed ``peft`` package satisfies the minimum supported version.

    PEFT is an optional dependency; this is called lazily right before any
    LoRA-specific code path is taken.

    Raises:
        ImportError: if ``peft`` is not installed, or its version is below 0.3.0.
    """
    try:
        peft_version = pkg_resources.get_distribution("peft").version
    except pkg_resources.DistributionNotFound as err:
        raise ImportError("tensor_parallel only works with peft>=0.3.0, but peft is not installed") from err
    # Compare parsed versions, not raw strings: lexicographically "0.10.0" < "0.3.0",
    # which would wrongly reject newer releases.
    if pkg_resources.parse_version(peft_version) < pkg_resources.parse_version("0.3.0"):
        raise ImportError("tensor_parallel only works with peft>=0.3.0")
20 changes: 20 additions & 0 deletions src/tensor_parallel/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,15 @@
Based on: https://stackoverflow.com/questions/49739102/python-nested-dictionary-comparison
"""

from inspect import getmodule
from itertools import chain
from typing import Mapping, Optional, Sequence

from torch import nn
from transformers.modeling_outputs import BaseModelOutputWithPastAndCrossAttentions

from tensor_parallel.imports import verify_peft_version


def nested_compare(t, u):
"""
Expand Down Expand Up @@ -123,3 +126,20 @@ def find_tied_weight_aliases(
find_tied_weight_aliases(module=submodule, destination=destination, prefix=prefix + name + ".")

return destination


def check_lora(module: nn.Module) -> bool:
    """Return True iff ``module`` is a LoRA ``Linear`` layer from a supported PEFT install.

    Args:
        module (nn.Module): candidate module.

    Returns:
        bool: True when the module's class is defined in ``peft.tuners.lora``
            and is named ``Linear``; False otherwise.
    """
    origin = getmodule(module)
    # Modules not defined by peft.tuners.lora can never be LoRA layers.
    if origin is None or origin.__name__ != "peft.tuners.lora":
        return False
    # Only validate the PEFT version once we know peft is actually in play.
    verify_peft_version()
    return type(module).__name__ == "Linear"

0 comments on commit 62e5424

Please sign in to comment.