-
Notifications
You must be signed in to change notification settings - Fork 3
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
b1acb45
commit 96e22c8
Showing
345 changed files
with
304,584 additions
and
2 deletions.
There are no files selected for viewing
Binary file not shown.
Large diffs are not rendered by default.
Oops, something went wrong.
Large diffs are not rendered by default.
Oops, something went wrong.
Large diffs are not rendered by default.
Oops, something went wrong.
Large diffs are not rendered by default.
Oops, something went wrong.
Large diffs are not rendered by default.
Oops, something went wrong.
Large diffs are not rendered by default.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,390 @@ | ||
# coding=utf-8 | ||
|
||
import os | ||
|
||
import argparse | ||
from packaging import version | ||
|
||
from accelerate.utils import ProjectConfiguration, set_seed | ||
#### | ||
import torch | ||
import random | ||
import numpy as np | ||
|
||
def set_seeds(seed): | ||
set_seed(42) | ||
|
||
random.seed(seed) | ||
np.random.seed(seed) | ||
|
||
torch.manual_seed(seed) | ||
torch.cuda.manual_seed(seed) | ||
torch.cuda.manual_seed_all(seed) | ||
|
||
torch.backends.cudnn.deterministic = True | ||
torch.backends.cudnn.benchmark = False | ||
|
||
set_seeds(42) | ||
#### | ||
import pickle | ||
from diffusers import StableDiffusionPipeline, EulerAncestralDiscreteScheduler | ||
from diffusers import DDIMScheduler | ||
|
||
from diffusers.utils.import_utils import is_xformers_available | ||
|
||
def parse_args(): | ||
parser = argparse.ArgumentParser(description="Simple example of a training script.") | ||
parser.add_argument( | ||
"--pretrained_model_name_or_path", | ||
type=str, | ||
default=None, | ||
required=True, | ||
help="Path to pretrained model or model identifier from huggingface.co/models.", | ||
) | ||
parser.add_argument( | ||
"--model_path", | ||
type=str, | ||
default=None, | ||
required=True, | ||
help="Path to pretrained model or model identifier from huggingface.co/models.", | ||
) | ||
parser.add_argument( | ||
"--revision", | ||
type=str, | ||
default=None, | ||
required=False, | ||
help="Revision of pretrained model identifier from huggingface.co/models.", | ||
) | ||
parser.add_argument( | ||
"--dataset_name", | ||
type=str, | ||
default=None, | ||
help=( | ||
"The name of the Dataset (from the HuggingFace hub) to train on (could be your own, possibly private," | ||
" dataset). It can also be a path pointing to a local copy of a dataset in your filesystem," | ||
" or to a folder containing files that 🤗 Datasets can understand." | ||
), | ||
) | ||
parser.add_argument( | ||
"--dataset_config_name", | ||
type=str, | ||
default=None, | ||
help="The config of the Dataset, leave as None if there's only one config.", | ||
) | ||
parser.add_argument( | ||
"--train_data_dir", | ||
type=str, | ||
default=None, | ||
help=( | ||
"A folder containing the training data. Folder contents must follow the structure described in" | ||
" https://huggingface.co/docs/datasets/image_dataset#imagefolder. In particular, a `metadata.jsonl` file" | ||
" must exist to provide the captions for the images. Ignored if `dataset_name` is specified." | ||
), | ||
) | ||
parser.add_argument( | ||
"--image_column", type=str, default="image", help="The column of the dataset containing an image." | ||
) | ||
parser.add_argument( | ||
"--caption_column", | ||
type=str, | ||
default="text", | ||
help="The column of the dataset containing a caption or a list of captions.", | ||
) | ||
parser.add_argument( | ||
"--validation_prompt", type=str, default=None, help="A prompt that is sampled during training for inference." | ||
) | ||
parser.add_argument( | ||
"--num_validation_images", | ||
type=int, | ||
default=4, | ||
help="Number of images that should be generated during validation with `validation_prompt`.", | ||
) | ||
parser.add_argument( | ||
"--validation_epochs", | ||
type=int, | ||
default=1, | ||
help=( | ||
"Run fine-tuning validation every X epochs. The validation process consists of running the prompt" | ||
" `args.validation_prompt` multiple times: `args.num_validation_images`." | ||
), | ||
) | ||
parser.add_argument( | ||
"--max_train_samples", | ||
type=int, | ||
default=None, | ||
help=( | ||
"For debugging purposes or quicker training, truncate the number of training examples to this " | ||
"value if set." | ||
), | ||
) | ||
parser.add_argument( | ||
"--output_dir", | ||
type=str, | ||
default="sd-model-finetuned-lora", | ||
help="The output directory where the model predictions and checkpoints will be written.", | ||
) | ||
parser.add_argument( | ||
"--cache_dir", | ||
type=str, | ||
default=None, | ||
help="The directory where the downloaded models and datasets will be stored.", | ||
) | ||
parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.") | ||
parser.add_argument( | ||
"--resolution", | ||
type=int, | ||
default=256, | ||
help=( | ||
"The resolution for input images, all the images in the train/validation dataset will be resized to this" | ||
" resolution" | ||
), | ||
) | ||
parser.add_argument( | ||
"--center_crop", | ||
default=False, | ||
action="store_true", | ||
help=( | ||
"Whether to center crop the input images to the resolution. If not set, the images will be randomly" | ||
" cropped. The images will be resized to the resolution first before cropping." | ||
), | ||
) | ||
parser.add_argument( | ||
"--random_flip", | ||
action="store_true", | ||
help="whether to randomly flip images horizontally", | ||
) | ||
parser.add_argument( | ||
"--train_batch_size", type=int, default=16, help="Batch size (per device) for the training dataloader." | ||
) | ||
parser.add_argument("--num_train_epochs", type=int, default=100) | ||
parser.add_argument( | ||
"--max_train_steps", | ||
type=int, | ||
default=None, | ||
help="Total number of training steps to perform. If provided, overrides num_train_epochs.", | ||
) | ||
parser.add_argument( | ||
"--gradient_accumulation_steps", | ||
type=int, | ||
default=1, | ||
help="Number of updates steps to accumulate before performing a backward/update pass.", | ||
) | ||
parser.add_argument( | ||
"--gradient_checkpointing", | ||
action="store_true", | ||
help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.", | ||
) | ||
parser.add_argument( | ||
"--learning_rate", | ||
type=float, | ||
default=1e-4, | ||
help="Initial learning rate (after the potential warmup period) to use.", | ||
) | ||
parser.add_argument( | ||
"--scale_lr", | ||
action="store_true", | ||
default=False, | ||
help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.", | ||
) | ||
parser.add_argument( | ||
"--lr_scheduler", | ||
type=str, | ||
default="constant", | ||
help=( | ||
'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",' | ||
' "constant", "constant_with_warmup"]' | ||
), | ||
) | ||
parser.add_argument( | ||
"--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler." | ||
) | ||
parser.add_argument( | ||
"--use_8bit_adam", action="store_true", help="Whether or not to use 8-bit Adam from bitsandbytes." | ||
) | ||
parser.add_argument( | ||
"--allow_tf32", | ||
action="store_true", | ||
help=( | ||
"Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see" | ||
" https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices" | ||
), | ||
) | ||
parser.add_argument( | ||
"--dataloader_num_workers", | ||
type=int, | ||
default=0, | ||
help=( | ||
"Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process." | ||
), | ||
) | ||
parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.") | ||
parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.") | ||
parser.add_argument("--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use.") | ||
parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer") | ||
parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.") | ||
parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.") | ||
parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.") | ||
parser.add_argument( | ||
"--hub_model_id", | ||
type=str, | ||
default=None, | ||
help="The name of the repository to keep in sync with the local `output_dir`.", | ||
) | ||
parser.add_argument( | ||
"--logging_dir", | ||
type=str, | ||
default="logs", | ||
help=( | ||
"[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to" | ||
" *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***." | ||
), | ||
) | ||
parser.add_argument( | ||
"--mixed_precision", | ||
type=str, | ||
default=None, | ||
choices=["no", "fp16", "bf16"], | ||
help=( | ||
"Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >=" | ||
" 1.10.and an Nvidia Ampere GPU. Default to the value of accelerate config of the current system or the" | ||
" flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config." | ||
), | ||
) | ||
parser.add_argument( | ||
"--report_to", | ||
type=str, | ||
default="tensorboard", | ||
help=( | ||
'The integration to report the results and logs to. Supported platforms are `"tensorboard"`' | ||
' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.' | ||
), | ||
) | ||
parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank") | ||
parser.add_argument( | ||
"--checkpointing_steps", | ||
type=int, | ||
default=500, | ||
help=( | ||
"Save a checkpoint of the training state every X updates. These checkpoints are only suitable for resuming" | ||
" training using `--resume_from_checkpoint`." | ||
), | ||
) | ||
parser.add_argument( | ||
"--checkpoints_total_limit", | ||
type=int, | ||
default=None, | ||
help=( | ||
"Max number of checkpoints to store. Passed as `total_limit` to the `Accelerator` `ProjectConfiguration`." | ||
" See Accelerator::save_state https://huggingface.co/docs/accelerate/package_reference/accelerator#accelerate.Accelerator.save_state" | ||
" for more docs" | ||
), | ||
) | ||
parser.add_argument( | ||
"--resume_from_checkpoint", | ||
type=str, | ||
default=None, | ||
help=( | ||
"Whether training should be resumed from a previous checkpoint. Use a path saved by" | ||
' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.' | ||
), | ||
) | ||
parser.add_argument( | ||
"--enable_xformers_memory_efficient_attention", action="store_true", help="Whether or not to use xformers." | ||
) | ||
parser.add_argument("--noise_offset", type=float, default=0, help="The scale of noise offset.") | ||
|
||
parser.add_argument( | ||
"--snr_gamma", | ||
type=float, | ||
default=None, | ||
help="SNR weighting gamma to be used if rebalancing the loss. Recommended value is 5.0. " | ||
"More details here: https://arxiv.org/abs/2303.09556.", | ||
) | ||
|
||
parser.add_argument("--gen_seed", type=int, default=0, help="A seed for reproducible training.") | ||
|
||
|
||
args = parser.parse_args() | ||
env_local_rank = int(os.environ.get("LOCAL_RANK", -1)) | ||
if env_local_rank != -1 and env_local_rank != args.local_rank: | ||
args.local_rank = env_local_rank | ||
|
||
# Sanity checks | ||
# if args.dataset_name is None and args.train_data_dir is None: | ||
# raise ValueError("Need either a dataset name or a training folder.") | ||
|
||
return args | ||
|
||
|
||
def main(): | ||
args = parse_args() | ||
|
||
pipe = StableDiffusionPipeline.from_pretrained(args.pretrained_model_name_or_path, | ||
torch_dtype=torch.float16).to('cuda') | ||
#### | ||
pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config) | ||
print('DDIM!') | ||
#### | ||
def dummpy(images, **kwargs): | ||
return images, False | ||
|
||
pipe.safety_checker = dummpy | ||
#### | ||
# if args.enable_xformers_memory_efficient_attention: | ||
# if is_xformers_available(): | ||
# import xformers | ||
|
||
# xformers_version = version.parse(xformers.__version__) | ||
# if xformers_version == version.parse("0.0.16"): | ||
# logger.warn( | ||
# "xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details." | ||
# ) | ||
# # pipe.unet.enable_xformers_memory_efficient_attention() | ||
# pass | ||
# else: | ||
# raise ValueError("xformers is not available. Make sure it is installed correctly") | ||
#### | ||
print(args.model_path) | ||
pipe.unet.load_attn_procs(args.model_path) | ||
pipe.unet.eval() | ||
|
||
total = 1000 | ||
|
||
for i in range(0, total, args.train_batch_size): | ||
bsz = args.train_batch_size | ||
|
||
if total-i<args.train_batch_size: | ||
bsz = total-i | ||
|
||
print(bsz) | ||
print([i+j for j in range(bsz)]) | ||
#### | ||
prompt_list = [] | ||
for j in range(bsz): | ||
if i+j<500: | ||
prompt_list.append('a ukiyo e painting') | ||
else: | ||
prompt_list.append('a post impressionism painting') | ||
print(prompt_list) | ||
#### | ||
generator = [torch.Generator('cpu').manual_seed(args.gen_seed*total+i+j) for j in range(bsz)] | ||
|
||
images = pipe(prompt_list, | ||
height=args.resolution, | ||
width=args.resolution, | ||
num_inference_steps=50, | ||
generator=generator, | ||
eta=0.0).images | ||
|
||
print(len(images)) | ||
|
||
for idx, image in enumerate(images): | ||
os.makedirs(args.output_dir, exist_ok=True) | ||
print(os.path.join(args.output_dir, '{}.png'.format(i+idx))) | ||
image.save(os.path.join(args.output_dir, '{}.png'.format(i+idx))) | ||
|
||
|
||
if __name__ == "__main__": | ||
main() | ||
|
||
|
Oops, something went wrong.