Skip to content

Commit

Permalink
Update project.toml
Browse files Browse the repository at this point in the history
  • Loading branch information
hebiao064 committed Dec 11, 2024
1 parent 11f667f commit 41b97fa
Show file tree
Hide file tree
Showing 3 changed files with 14 additions and 8 deletions.
3 changes: 2 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ __pycache__/
site/
.cache/
.venv/
venv/
.ipynb_checkpoints/

# Misc
Expand All @@ -16,4 +17,4 @@ dist/
uv.lock

# Benchmark images
benchmark/visualizations
benchmark/visualizations
3 changes: 2 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
[build-system]
requires = ["setuptools>=42", "wheel"]
requires = ["setuptools>=42", "wheel", "setuptools-scm"]
build-backend = "setuptools.build_meta"

[project]
Expand All @@ -9,6 +9,7 @@ description = "Efficient Triton kernels for LLM Training"
urls = { "Homepage" = "https://github.com/linkedin/Liger-Kernel" }
readme = { file = "README.md", content-type = "text/markdown" }
license = { file = "LICENSE" }
dynamic = ["dependencies", "optional-dependencies"]

[tool.setuptools]
package-dir = {"" = "src"}
Expand Down
16 changes: 10 additions & 6 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

def get_default_dependencies():
"""Determine the appropriate dependencies based on detected hardware."""
platform = _get_platform()
platform = get_platform()

if platform in ["cuda", "cpu"]:
return [
Expand All @@ -24,7 +24,7 @@ def get_default_dependencies():

def get_optional_dependencies():
"""Get optional dependency groups."""
platform = _get_platform()
platform = get_platform()

if platform in ["cuda", "cpu"]:
return {
Expand All @@ -43,7 +43,7 @@ def get_optional_dependencies():
"seaborn",
],
"transformers": [
"transformers>=4.44.2"
"transformers~=4.0"
]
}
elif platform == "rocm":
Expand All @@ -58,22 +58,26 @@ def get_optional_dependencies():
}


def _get_platform() -> Literal["cuda", "rocm", "cpu"]:
def get_platform() -> Literal["cuda", "rocm", "cpu"]:
"""
Detect whether the system has NVIDIA or AMD GPU without torch dependency.
"""
# Try nvidia-smi first
try:
subprocess.run(["nvidia-smi"], check=True, capture_output=True)
subprocess.run(["nvidia-smi"], check=True)
print("NVIDIA GPU detected")
return "cuda"
except (subprocess.SubprocessError, FileNotFoundError):
# If nvidia-smi fails, check for ROCm
try:
subprocess.run(["rocm-smi"], check=True, capture_output=True)
subprocess.run(["rocm-smi"], check=True)
print("ROCm GPU detected")
return "rocm"
except (subprocess.SubprocessError, FileNotFoundError):
print("No GPU detected")
return "cpu"


setup(
name="liger_kernel",
package_dir={"": "src"},
Expand Down

0 comments on commit 41b97fa

Please sign in to comment.