componentwise config #68

Open · wants to merge 4 commits into main
53 changes: 34 additions & 19 deletions src/sfast/compilers/stable_diffusion_pipeline_compiler.py
@@ -13,6 +13,17 @@
 logger = logging.getLogger()
 
 
+@dataclass
+class ComponentCompileConfig:
+    text_encoder: bool = True
+    text_encoder_2: bool = True
+    image_processor: bool = True
+    vae_encode: bool = True
+    vae_decode: bool = True
+    unet: bool = True
+    controlnet: bool = True
+    scheduler: bool = False
+
 class CompilationConfig:
 
     @dataclass
@@ -53,8 +64,9 @@ class Default:
             Triton generated CUDA kernels are faster than PyTorch's CUDA kernels.
             However, Triton has a lot of bugs, and can increase the CPU overhead,
             though the overhead can be reduced by enabling CUDA graph.
-        trace_scheduler:
-            Whether to trace the scheduler.
+        compiles:
+            Components to compile. If compiling a component fails with an
+            error, set its entry to False to skip it.
         '''
         memory_format: torch.memory_format = (
             torch.channels_last if gpu_device.device_has_tensor_core() else
@@ -70,6 +82,7 @@ class Default:
         enable_cuda_graph: bool = False
         enable_triton: bool = False
         trace_scheduler: bool = False
+        # field(default_factory=...) avoids a disallowed mutable dataclass default
+        compiles: ComponentCompileConfig = field(
+            default_factory=ComponentCompileConfig)
 
 
 def compile(m, config):
@@ -78,9 +91,10 @@ def compile(m, config):
         'cuda' if torch.cuda.is_available() else 'cpu')
 
     enable_cuda_graph = config.enable_cuda_graph and device.type == 'cuda'
 
-    m.unet = compile_unet(m.unet, config)
-    if hasattr(m, 'controlnet'):
+    if config.compiles.unet:
+        m.unet = compile_unet(m.unet, config)
+    if hasattr(m, 'controlnet') and config.compiles.controlnet:
         m.controlnet = compile_unet(m.controlnet, config)
 
     if config.enable_xformers:
@@ -93,9 +107,9 @@ def compile(m, config):
     lazy_trace_ = _build_lazy_trace(config)
 
     # SVD doesn't have a text encoder
-    if getattr(m, 'text_encoder', None) is not None:
+    if getattr(m, 'text_encoder', None) is not None and config.compiles.text_encoder:
         m.text_encoder.forward = lazy_trace_(m.text_encoder.forward)
-    if getattr(m, 'text_encoder_2', None) is not None:
+    if getattr(m, 'text_encoder_2', None) is not None and config.compiles.text_encoder_2:
         m.text_encoder_2.forward = lazy_trace_(m.text_encoder_2.forward)
     if (not packaging.version.parse('2.0.0') <= packaging.version.parse(
             torch.__version__) < packaging.version.parse('2.1.0')):
@@ -106,37 +120,38 @@ def compile(m, config):

         When executing AttnProcessor in TorchScript
         """
-        if hasattr(m.vae, 'decode'):
+        # May be incompatible with AutoencoderKLOutput's latent_dist of type
+        # DiagonalGaussianDistribution for img2img
+        if hasattr(m.vae, 'encode') and config.compiles.vae_encode:
+            m.vae.encode = lazy_trace_(m.vae.encode)
+        if hasattr(m.vae, 'decode') and config.compiles.vae_decode:
             m.vae.decode = lazy_trace_(m.vae.decode)
-        # Incompatible with AutoencoderKLOutput's latent_dist of type DiagonalGaussianDistribution
-        # For img2img
-        # if hasattr(m.vae, 'encode'):
-        #     m.vae.encode = lazy_trace_(m.vae.encode)
-        if config.trace_scheduler:
+        if config.compiles.scheduler:
             m.scheduler.scale_model_input = lazy_trace_(
                 m.scheduler.scale_model_input)
             m.scheduler.step = lazy_trace_(m.scheduler.step)
 
     if enable_cuda_graph:
-        if getattr(m, 'text_encoder', None) is not None:
+        if getattr(m, 'text_encoder', None) is not None and config.compiles.text_encoder:
             m.text_encoder.forward = make_dynamic_graphed_callable(
                 m.text_encoder.forward)
-        if getattr(m, 'text_encoder_2', None) is not None:
+        if getattr(m, 'text_encoder_2', None) is not None and config.compiles.text_encoder_2:
             m.text_encoder_2.forward = make_dynamic_graphed_callable(
                 m.text_encoder_2.forward)
-        if hasattr(m.vae, 'decode'):
+        if hasattr(m.vae, 'encode') and config.compiles.vae_encode:
+            m.vae.encode = make_dynamic_graphed_callable(m.vae.encode)
+        if hasattr(m.vae, 'decode') and config.compiles.vae_decode:
             m.vae.decode = make_dynamic_graphed_callable(m.vae.decode)
-        # if hasattr(m.vae, 'encode'):
-        #     m.vae.encode = make_dynamic_graphed_callable(m.vae.encode)
 
-    if hasattr(m, 'image_processor'):
+    if hasattr(m, 'image_processor') and config.compiles.image_processor:
         from sfast.libs.diffusers.image_processor import patch_image_prcessor
         patch_image_prcessor(m.image_processor)
 
     return m
 
 
 def compile_unet(m, config):
+    assert config.compiles.unet or config.compiles.controlnet
     # attribute `device` is not generally available
     device = m.device if hasattr(m, 'device') else torch.device(
         'cuda' if torch.cuda.is_available() else 'cpu')
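For reference, a minimal sketch of how the new per-component switches would be used from caller code, assuming this repo's public `compile` and `CompilationConfig.Default` entry points (the diffusers pipeline and model ID below are placeholders, not part of this PR):

    import torch
    from diffusers import StableDiffusionPipeline

    from sfast.compilers.stable_diffusion_pipeline_compiler import (
        CompilationConfig, compile)

    # Placeholder pipeline; any diffusers pipeline should work the same way.
    model = StableDiffusionPipeline.from_pretrained(
        'runwayml/stable-diffusion-v1-5',
        torch_dtype=torch.float16).to('cuda')

    config = CompilationConfig.Default()
    # Opt out of a single component, e.g. if tracing vae.encode is
    # incompatible with img2img latents, as noted in the diff above.
    config.compiles.vae_encode = False

    model = compile(model, config)

Every other component keeps compiling as before; only the disabled entries are skipped.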
2 changes: 1 addition & 1 deletion tests/compilers/test_stable_diffusion_pipeline_compiler.py
@@ -415,7 +415,7 @@ def call_torch_compiled_model():
         config.enable_triton = True
     except ImportError:
         logger.warning('triton not installed, skip')
-    # config.trace_scheduler = True
+    # config.compiles.scheduler = True
     config.enable_cuda_graph = enable_cuda_graph
     compiled_model = compile(load_model(), config)
 
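Because `compiles` is itself a dataclass, a whole component map can also be swapped in at once instead of toggling entries one by one; a sketch using the field names and defaults from the diff above:

    from sfast.compilers.stable_diffusion_pipeline_compiler import (
        CompilationConfig, ComponentCompileConfig)

    config = CompilationConfig.Default()
    config.compiles = ComponentCompileConfig(
        controlnet=False,  # skip compiling the ControlNet
        scheduler=True,    # opt in; stays False by default, as in the test
    )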