Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Replace shark_turbine with iree.turbine #870

Merged
merged 1 commit into from
Oct 22, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions models/turbine_models/custom_models/resnet_18.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,11 @@

from transformers import AutoFeatureExtractor, AutoModelForImageClassification
import torch
from shark_turbine.aot import *
from iree.turbine.aot import *
from iree.compiler.ir import Context
import iree.runtime as rt
from turbine_models.custom_models.sd_inference import utils
import shark_turbine.ops.iree as ops
import iree.turbine.ops.iree as ops
import argparse

parser = argparse.ArgumentParser()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from iree import runtime as ireert
from iree.compiler.ir import Context
import numpy as np
from shark_turbine.aot import *
from iree.turbine.aot import *
from turbine_models.custom_models.sd_inference import utils
import torch
import torch._dynamo as dynamo
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,11 @@
from iree import runtime as ireert
from iree.compiler.ir import Context
import numpy as np
from shark_turbine.aot import *
from shark_turbine.dynamo.passes import (
from iree.turbine.aot import *
from iree.turbine.dynamo.passes import (
DEFAULT_DECOMPOSITIONS,
)
from shark_turbine.transforms.general.add_metadata import AddMetadataPass
from iree.turbine.transforms.general.add_metadata import AddMetadataPass
from turbine_models.custom_models.sd_inference import utils
import torch
import torch._dynamo as dynamo
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import torch
import numpy as np
from tqdm.auto import tqdm
from shark_turbine.ops.iree import trace_tensor
from iree.turbine.ops.iree import trace_tensor

torch.random.manual_seed(0)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,9 @@

import torch
from typing import Any, Callable, Dict, List, Optional, Union
from shark_turbine.aot import *
import shark_turbine.ops.iree as ops
from shark_turbine.transforms.general.add_metadata import AddMetadataPass
from iree.turbine.aot import *
import iree.turbine.ops.iree as ops
from iree.turbine.transforms.general.add_metadata import AddMetadataPass
from iree.compiler.ir import Context
import iree.runtime as ireert
import numpy as np
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,8 @@
import iree.compiler as ireec
from iree.compiler.ir import Context
import numpy as np
from shark_turbine.aot import *
from shark_turbine.transforms.general.add_metadata import AddMetadataPass
from iree.turbine.aot import *
from iree.turbine.transforms.general.add_metadata import AddMetadataPass
from turbine_models.custom_models.sd_inference import utils
import torch
from turbine_models.custom_models.sd3_inference.text_encoder_impls import (
Expand Down
4 changes: 2 additions & 2 deletions models/turbine_models/custom_models/sd3_inference/sd3_vae.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,8 @@
from iree import runtime as ireert
from iree.compiler.ir import Context
import numpy as np
from shark_turbine.aot import *
from shark_turbine.dynamo.passes import (
from iree.turbine.aot import *
from iree.turbine.dynamo.passes import (
DEFAULT_DECOMPOSITIONS,
)
from turbine_models.custom_models.sd_inference import utils
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import torch, math
from torch import nn
from transformers import CLIPTokenizer, T5TokenizerFast
from shark_turbine import ops
from iree.turbine import ops

#################################################################################################
### Core/Utility
Expand Down
4 changes: 2 additions & 2 deletions models/turbine_models/custom_models/sd_inference/clip.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,8 @@
import re

from iree.compiler.ir import Context
from shark_turbine.aot import *
from shark_turbine.transforms.general.add_metadata import AddMetadataPass
from iree.turbine.aot import *
from iree.turbine.transforms.general.add_metadata import AddMetadataPass
from turbine_models.custom_models.sd_inference import utils
import torch
from transformers import CLIPTextModel, CLIPTokenizer, CLIPProcessor
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,9 @@
from typing import List

import torch
from shark_turbine.aot import *
import shark_turbine.ops.iree as ops
from shark_turbine.transforms.general.add_metadata import AddMetadataPass
from iree.turbine.aot import *
import iree.turbine.ops.iree as ops
from iree.turbine.transforms.general.add_metadata import AddMetadataPass
from iree.compiler.ir import Context
import iree.runtime as ireert
import numpy as np
Expand Down
6 changes: 3 additions & 3 deletions models/turbine_models/custom_models/sd_inference/unet.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,11 @@
from iree import runtime as ireert
from iree.compiler.ir import Context
import numpy as np
from shark_turbine.aot import *
from shark_turbine.dynamo.passes import (
from iree.turbine.aot import *
from iree.turbine.dynamo.passes import (
DEFAULT_DECOMPOSITIONS,
)
from shark_turbine.transforms.general.add_metadata import AddMetadataPass
from iree.turbine.transforms.general.add_metadata import AddMetadataPass
from turbine_models.custom_models.sd_inference import utils
import torch
import torch._dynamo as dynamo
Expand Down
6 changes: 3 additions & 3 deletions models/turbine_models/custom_models/sd_inference/vae.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,11 @@

from iree.compiler.ir import Context
import numpy as np
from shark_turbine.aot import *
from shark_turbine.dynamo.passes import (
from iree.turbine.aot import *
from iree.turbine.dynamo.passes import (
DEFAULT_DECOMPOSITIONS,
)
from shark_turbine.transforms.general.add_metadata import AddMetadataPass
from iree.turbine.transforms.general.add_metadata import AddMetadataPass
from turbine_models.custom_models.sd_inference import utils
import torch
import torch._dynamo as dynamo
Expand Down
2 changes: 1 addition & 1 deletion models/turbine_models/custom_models/sdxl_inference/clip.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
import iree.compiler as ireec
from iree.compiler.ir import Context
import numpy as np
from shark_turbine.aot import *
from iree.turbine.aot import *
from turbine_models.custom_models.sd_inference import utils
import torch
from transformers import CLIPTextModel, CLIPTextModelWithProjection, CLIPTokenizer
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,8 @@
import iree.compiler as ireec
from iree.compiler.ir import Context
import numpy as np
from shark_turbine.aot import *
from shark_turbine.transforms.general.add_metadata import AddMetadataPass
from iree.turbine.aot import *
from iree.turbine.transforms.general.add_metadata import AddMetadataPass

from turbine_models.custom_models.sd_inference import utils
import torch
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,8 @@
from iree import runtime as ireert
from iree.compiler.ir import Context

from shark_turbine.aot import *
import shark_turbine.ops as ops
from iree.turbine.aot import *
import iree.turbine.ops as ops

from turbine_models.custom_models.sd_inference import utils
from turbine_models.custom_models.sd_inference.schedulers import get_scheduler
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import torch
import numpy as np
from tqdm.auto import tqdm
from shark_turbine.ops.iree import trace_tensor
from iree.turbine.ops.iree import trace_tensor

torch.random.manual_seed(0)

Expand Down
4 changes: 2 additions & 2 deletions models/turbine_models/custom_models/sdxl_inference/unet.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,8 @@
from iree import runtime as ireert
from iree.compiler.ir import Context
import numpy as np
from shark_turbine.aot import *
from shark_turbine.transforms.general.add_metadata import AddMetadataPass
from iree.turbine.aot import *
from iree.turbine.transforms.general.add_metadata import AddMetadataPass


from turbine_models.custom_models.sd_inference import utils
Expand Down
4 changes: 2 additions & 2 deletions models/turbine_models/custom_models/sdxl_inference/vae.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,8 @@
from iree import runtime as ireert
from iree.compiler.ir import Context
import numpy as np
from shark_turbine.aot import *
from shark_turbine.dynamo.passes import (
from iree.turbine.aot import *
from iree.turbine.dynamo.passes import (
DEFAULT_DECOMPOSITIONS,
)
from turbine_models.custom_models.sd_inference import utils
Expand Down
4 changes: 2 additions & 2 deletions models/turbine_models/custom_models/stateless_llama.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
from torch.utils import _pytree as pytree
from shark_turbine.aot import *
from iree.turbine.aot import *
from iree.compiler.ir import Context
from turbine_models.custom_models.llm_optimizations.streaming_llm.modify_llama import (
enable_llama_pos_shift_attention,
Expand Down Expand Up @@ -458,7 +458,7 @@ def evict_kvcache_space(self):
# TODO: Integrate with external parameters to actually be able to run
# TODO: Make more generalizable to be able to quantize with all compile_to options
if quantization == "int4" and not compile_to == "linalg":
from shark_turbine.transforms.quantization import mm_group_quant
from iree.turbine.transforms.quantization import mm_group_quant

mm_group_quant.MMGroupQuantRewriterPass(
CompiledModule.get_mlir_module(inst).operation
Expand Down
2 changes: 1 addition & 1 deletion models/turbine_models/model_builder.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from transformers import AutoModel, AutoTokenizer, AutoConfig
import torch
import shark_turbine.aot as aot
import iree.turbine.aot as aot
from turbine_models.turbine_tank import turbine_tank
import os
import re
Expand Down
4 changes: 2 additions & 2 deletions models/turbine_models/tests/pipeline_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,13 +11,13 @@
import os
import numpy as np
from iree.compiler.ir import Context
from shark_turbine.aot import *
from iree.turbine.aot import *
from turbine_models.custom_models.sd_inference import utils
from turbine_models.custom_models.pipeline_base import (
PipelineComponent,
TurbinePipelineBase,
)
from shark_turbine.transforms.general.add_metadata import AddMetadataPass
from iree.turbine.transforms.general.add_metadata import AddMetadataPass

model_metadata_forward = {
"model_name": "TestModel2xLinear",
Expand Down
2 changes: 1 addition & 1 deletion models/turbine_models/tests/stateless_llama_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
import tempfile

os.environ["TORCH_LOGS"] = "dynamic"
from shark_turbine.aot import *
from iree.turbine.aot import *
from turbine_models.custom_models import llm_runner

from turbine_models.gen_external_params.gen_external_params import (
Expand Down
Loading