Skip to content

Commit

Permalink
Merge branch 'main' into main
Browse files Browse the repository at this point in the history
  • Loading branch information
hebiao064 authored Dec 11, 2024
2 parents 7d96535 + 966eb73 commit 11f667f
Show file tree
Hide file tree
Showing 3 changed files with 9 additions and 19 deletions.
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"

[project]
name = "liger_kernel"
version = "0.5.1"
version = "0.5.2"
description = "Efficient Triton kernels for LLM Training"
urls = { "Homepage" = "https://github.com/linkedin/Liger-Kernel" }
readme = { file = "README.md", content-type = "text/markdown" }
Expand Down
13 changes: 4 additions & 9 deletions test/convergence/test_mini_models_multimodal.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,7 @@

import pytest
import torch
import transformers
from datasets import load_dataset
from packaging import version
from torch.utils.data import DataLoader
from transformers import PreTrainedTokenizerFast

Expand Down Expand Up @@ -380,9 +378,8 @@ def run_mini_model_multimodal(
5e-3,
1e-5,
marks=pytest.mark.skipif(
not QWEN2_VL_AVAILABLE
or version.parse(transformers.__version__) >= version.parse("4.47.0"),
reason="Qwen2-VL not available in this version of transformers or transformers version >= 4.47.0",
not QWEN2_VL_AVAILABLE,
reason="Qwen2-VL not available in this version of transformers",
),
),
pytest.param(
Expand All @@ -401,10 +398,8 @@ def run_mini_model_multimodal(
not supports_bfloat16(), reason="bfloat16 not supported on this GPU"
),
pytest.mark.skipif(
not QWEN2_VL_AVAILABLE
or version.parse(transformers.__version__)
>= version.parse("4.47.0"),
reason="Qwen2-VL not available in this version of transformers or transformers version >= 4.47.0",
not QWEN2_VL_AVAILABLE,
reason="Qwen2-VL not available in this version of transformers",
),
],
),
Expand Down
13 changes: 4 additions & 9 deletions test/convergence/test_mini_models_with_logits.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,7 @@

import pytest
import torch
import transformers
from datasets import load_from_disk
from packaging import version
from torch.utils.data import DataLoader
from transformers.models.gemma import GemmaConfig, GemmaForCausalLM
from transformers.models.gemma2 import Gemma2Config, Gemma2ForCausalLM
Expand Down Expand Up @@ -540,9 +538,8 @@ def run_mini_model(
5e-3,
1e-5,
marks=pytest.mark.skipif(
not QWEN2_VL_AVAILABLE
or version.parse(transformers.__version__) >= version.parse("4.47.0"),
reason="Qwen2-VL not available in this version of transformers or transformers version >= 4.47.0",
not QWEN2_VL_AVAILABLE,
reason="Qwen2-VL not available in this version of transformers",
),
),
pytest.param(
Expand All @@ -561,10 +558,8 @@ def run_mini_model(
not supports_bfloat16(), reason="bfloat16 not supported on this GPU"
),
pytest.mark.skipif(
not QWEN2_VL_AVAILABLE
or version.parse(transformers.__version__)
>= version.parse("4.47.0"),
reason="Qwen2-VL not available in this version of transformers or transformers version >= 4.47.0",
not QWEN2_VL_AVAILABLE,
reason="Qwen2-VL not available in this version of transformers",
),
],
),
Expand Down

0 comments on commit 11f667f

Please sign in to comment.