From cfb29a6144a7aca72e6184795f78262fe429e43b Mon Sep 17 00:00:00 2001 From: Luca Medeiros <67411094+luca-medeiros@users.noreply.github.com> Date: Tue, 22 Oct 2024 11:15:50 +0900 Subject: [PATCH] fix default values for sam 2.1 --- app.py | 8 ++++---- lang_sam/lang_sam.py | 2 +- lang_sam/server.py | 2 +- 3 files changed, 6 insertions(+), 6 deletions(-) diff --git a/app.py b/app.py index 176da0f..6f10a74 100644 --- a/app.py +++ b/app.py @@ -45,7 +45,7 @@ def inference(sam_type, box_threshold, text_threshold, image, text_prompt): with gr.Blocks(title="lang-sam") as blocks: with gr.Row(): - sam_model_choices = gr.Dropdown(choices=list(SAM_MODELS.keys()), label="SAM Model", value="sam2_hiera_small") + sam_model_choices = gr.Dropdown(choices=list(SAM_MODELS.keys()), label="SAM Model", value="sam2.1_hiera_small") box_threshold = gr.Slider(minimum=0.0, maximum=1.0, value=0.3, label="Box Threshold") text_threshold = gr.Slider(minimum=0.0, maximum=1.0, value=0.25, label="Text Threshold") with gr.Row(): @@ -63,21 +63,21 @@ def inference(sam_type, box_threshold, text_threshold, image, text_prompt): examples = [ [ - "sam2_hiera_small", + "sam2.1_hiera_small", 0.32, 0.25, os.path.join(os.path.dirname(__file__), "assets", "fruits.jpg"), "kiwi. watermelon. blueberry.", ], [ - "sam2_hiera_small", + "sam2.1_hiera_small", 0.3, 0.25, os.path.join(os.path.dirname(__file__), "assets", "car.jpeg"), "wheel.", ], [ - "sam2_hiera_small", + "sam2.1_hiera_small", 0.3, 0.25, os.path.join(os.path.dirname(__file__), "assets", "food.jpg"), diff --git a/lang_sam/lang_sam.py b/lang_sam/lang_sam.py index 3c319d0..79b62d0 100644 --- a/lang_sam/lang_sam.py +++ b/lang_sam/lang_sam.py @@ -6,7 +6,7 @@ class LangSAM: - def __init__(self, sam_type="sam2_hiera_small", ckpt_path: str | None = None): + def __init__(self, sam_type="sam2.1_hiera_small", ckpt_path: str | None = None): self.sam_type = sam_type self.sam = SAM() self.sam.build_model(sam_type, ckpt_path) diff --git a/lang_sam/server.py b/lang_sam/server.py index e460af5..34911c7 100644 --- a/lang_sam/server.py +++ b/lang_sam/server.py @@ -14,7 +14,7 @@ class LangSAMAPI(ls.LitAPI): def setup(self, device: str) -> None: """Initialize or load the LangSAM model.""" - self.model = LangSAM(sam_type="sam2_hiera_small") + self.model = LangSAM(sam_type="sam2.1_hiera_small") print("LangSAM model initialized.") def decode_request(self, request) -> dict: