Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[feat] support flux #252

Open
wants to merge 9 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10,000 changes: 10,000 additions & 0 deletions FID_caption.txt

Large diffs are not rendered by default.

90 changes: 90 additions & 0 deletions calculate_FID.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
import os
import requests
import zipfile
from PIL import Image
from pytorch_fid import fid_score
import torch
import argparse


# Paths
COCO_DOWNLOAD_URL = "http://images.cocodataset.org/zips/val2017.zip"
COCO_DIR = "coco_dataset"
REAL_IMAGES_DIR = os.path.join(COCO_DIR, "real_images")


def download_coco(output_dir):
os.makedirs(output_dir, exist_ok=True)
coco_zip_path = os.path.join(output_dir, "val2017.zip")

# Download MS-COCO validation set
print("Downloading MS-COCO validation images...")
with requests.get(COCO_DOWNLOAD_URL, stream=True) as r:
r.raise_for_status()
with open(coco_zip_path, "wb") as f:
for chunk in r.iter_content(chunk_size=8192):
f.write(chunk)

print("Extracting images...")
with zipfile.ZipFile(coco_zip_path, "r") as zip_ref:
zip_ref.extractall(output_dir)

os.rename(os.path.join(output_dir, "val2017"), REAL_IMAGES_DIR)
os.remove(coco_zip_path)
print(f"MS-COCO images downloaded and extracted to {REAL_IMAGES_DIR}")

# Function to preprocess images (resize to 299x299)
def preprocess_images(input_dir, output_dir, target_size=(299, 299)):
os.makedirs(output_dir, exist_ok=True)
for img_name in os.listdir(input_dir):
img_path = os.path.join(input_dir, img_name)
output_path = os.path.join(output_dir, img_name)
with Image.open(img_path) as img:
img = img.convert("RGB").resize(target_size)
img.save(output_path)
print(f"Preprocessed images saved to {output_dir}")

# Function to calculate FID
def calculate_fid(real_dir, generated_dir):
print("Calculating FID score...")
fid_value = fid_score.calculate_fid_given_paths(
[real_dir, generated_dir],
batch_size=50, # Adjust based on your hardware
device="cuda" if torch.cuda.is_available() else "cpu",
dims=2048, # Default feature dimensions for Inception-v3 pool3
)
print(f"FID Score: {fid_value}")
return fid_value




if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Calculate FID for generated images.")
parser.add_argument(
"generated_images_dir",
type=str,
help="Path to the directory containing generated images."
)
args = parser.parse_args()

GENERATED_IMAGES_DIR = args.generated_images_dir
if not os.path.exists(REAL_IMAGES_DIR):
download_coco(COCO_DIR)
preprocessed_coco_dir = os.path.join(COCO_DIR, "preprocessed_real_images")
preprocess_images(REAL_IMAGES_DIR, preprocessed_coco_dir)

if not os.path.exists(GENERATED_IMAGES_DIR):
raise ValueError("Generated images directory does not exist. Add your images to 'generated_images/'.")

print(f"Calculating FID for {GENERATED_IMAGES_DIR}...")
fid_value = calculate_fid(preprocessed_coco_dir, GENERATED_IMAGES_DIR)

fid_score_file = os.path.join(GENERATED_IMAGES_DIR, "fid_score.txt")
with open(fid_score_file, "w") as f:
f.write(f"FID Score: {fid_value}\\n")
print(f"FID score saved to {fid_score_file}")



# python calculate_FID.py --generated_images_dir outputs/flux-pab
111 changes: 111 additions & 0 deletions examples/flux/sample.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,111 @@
from videosys.pipelines.flux.pipeline_flux_pab import FluxConfig, FluxPipeline, FluxPABConfig
import torch
import time


def run_base():
# change num_gpus for multi-gpu inference
# sampling parameters are defined in the config

# config = OpenSoraConfig(num_sampling_steps=30, cfg_scale=7.0, num_gpus=1)
# engine = VideoSysEngine(config)
config = FluxConfig()

pipe = FluxPipeline(
config = config
)

prompts = [
"Sunset over the sea.",
"A group of people wading around a large body of water.",
"A cabinet towels a toilet and a sink.",
"A few boys read comic books together outside.",
"There is a man and a young girl snowboarding.",
"A male snowboarder sits with his board on a snowy hill.",
"A woman wearing a skirt shows off her tattoos.",
"A large dog is tied up to a fire hydrant.",
"A man making a pizza in a kitchen."
]

# Generate and save images
for prompt in prompts:
start_time = time.time()
image = pipe(
prompt,
height=1024,
width=1024,
guidance_scale=3.5,
num_inference_steps=50,
max_sequence_length=512,
generator=torch.Generator("cuda:0").manual_seed(0)
).images[0]
end_time = time.time()
elapsed_time = end_time - start_time
print(f"'{prompt}' | {elapsed_time:.2f} s.")
pipe.save_image(image, f"./outputs/flux/{prompt}.png")

def run_pab():
pab_config = FluxPABConfig(
spatial_broadcast=True,
spatial_threshold=[100, 930],
spatial_range=5,
temporal_broadcast=False,
cross_broadcast=True,
cross_threshold=[100, 930],
cross_range=5,
mlp_broadcast=True
)
config = FluxConfig(
enable_pab=True,
pab_config=pab_config)
pipe = FluxPipeline(
config = config
)

prompts = [
"Sunset over the sea.",
"A group of people wading around a large body of water.",
"A cabinet towels a toilet and a sink.",
"A few boys read comic books together outside.",
"There is a man and a young girl snowboarding.",
"A male snowboarder sits with his board on a snowy hill.",
"A woman wearing a skirt shows off her tattoos.",
"A large dog is tied up to a fire hydrant.",
"A man making a pizza in a kitchen."
]

for prompt in prompts:
start_time = time.time()
image = pipe(
prompt,
height=1024,
width=1024,
guidance_scale=3.5,
num_inference_steps=50,
max_sequence_length=512,
generator=torch.Generator("cuda:0").manual_seed(0)
).images[0]
end_time = time.time()
elapsed_time = end_time - start_time
print(f"'{prompt}' | {elapsed_time:.2f} s.")
pipe.save_image(image, f"./outputs/flux-pab/{prompt.replace(' ', '_')}.png")

# results = pipe(
# prompts,
# height=1024,
# width=1024,
# guidance_scale=3.5,
# num_inference_steps=50,
# max_sequence_length=512,
# generator=torch.Generator("cuda:0").manual_seed(0)
# )

# for idx, image in enumerate(results.images):
# safe_filename = f"./outputs/flux-pab-batch/image_{idx}.png"
# pipe.save_image(image, safe_filename)


if __name__ == "__main__":
run_base()
# run_low_mem()
run_pab()
139 changes: 139 additions & 0 deletions examples/flux/sample_pab.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,139 @@
# process_captions.py
import os
import time
import argparse
import csv
from datetime import datetime
from typing import List, Tuple
from videosys.pipelines.flux.pipeline_flux_pab import FluxConfig, FluxPipeline, FluxPABConfig

def read_caption_file(caption_file: str, start_idx: int, end_idx: int) -> List[Tuple[str, str]]:
"""Read specific range of captions from file"""
caption_data = []
with open(caption_file, 'r') as f:
for i, line in enumerate(f):
if i >= start_idx and i < end_idx:
image_id, caption = line.strip().split(' ', 1)
caption_data.append((image_id, caption))
if i >= end_idx:
break
return caption_data

def log_to_csv(log_file: str, entries: List[dict]):
"""Write generation log to CSV file"""
fieldnames = ['batch_id', 'image_ids', 'prompts', 'start_time', 'end_time', 'processing_time']

file_exists = os.path.isfile(log_file)
with open(log_file, 'a', newline='', encoding='utf-8') as f:
writer = csv.DictWriter(f, fieldnames=fieldnames)
if not file_exists:
writer.writeheader()
writer.writerows(entries)

def process_captions(start_idx: int, end_idx: int,
caption_file: str = 'FID_caption.txt',
output_dir: str = './outputs/flux-pab',
log_file: str = './logs/generation_log.csv',
batch_size: int = 4,
height: int = 1024,
width: int = 1024,
guidance_scale: float = 3.5,
num_steps: int = 50,
max_seq_len: int = 512):

# Load captions
caption_data = read_caption_file(caption_file, start_idx, end_idx)
print(f"Processing captions from index {start_idx} to {end_idx}")

# Configure pipeline
pab_config = FluxPABConfig(
spatial_broadcast=True,
spatial_threshold=[100, 930],
spatial_range=5,
temporal_broadcast=False,
cross_broadcast=True,
cross_threshold=[100, 930],
cross_range=5,
mlp_broadcast=True
)
config = FluxConfig(
enable_pab=True,
pab_config=pab_config
)

# Initialize pipeline
pipe = FluxPipeline(
config=config,
device="cuda"
)

# Create output directories
os.makedirs(output_dir, exist_ok=True)
os.makedirs(os.path.dirname(log_file), exist_ok=True)

# Process captions
log_entries = []
total_batches = (len(caption_data) + batch_size - 1) // batch_size

for batch_idx in range(total_batches):
start_pos = batch_idx * batch_size
end_pos = min(start_pos + batch_size, len(caption_data))
batch_data = caption_data[start_pos:end_pos]
batch_ids, batch_prompts = zip(*batch_data)

start_time = datetime.now()

# Generate images
images = pipe(
list(batch_prompts),
height=height,
width=width,
guidance_scale=guidance_scale,
num_inference_steps=num_steps,
max_sequence_length=max_seq_len
).images

end_time = datetime.now()
processing_time = (end_time - start_time).total_seconds()

# Save images
for image_id, image in zip(batch_ids, images):
save_path = os.path.join(output_dir, f"{image_id}.png")
image.save(save_path)

# Log batch information
log_entry = {
'batch_id': batch_idx,
'image_ids': ','.join(batch_ids),
'prompts': '|'.join(batch_prompts),
'start_time': start_time.strftime('%Y-%m-%d %H:%M:%S.%f'),
'end_time': end_time.strftime('%Y-%m-%d %H:%M:%S.%f'),
'processing_time': processing_time
}
log_entries.append(log_entry)

print(f"Batch {batch_idx + 1}/{total_batches} processed in {processing_time:.2f}s")

if len(log_entries) >= 5:
log_to_csv(log_file, log_entries)
log_entries = []

if log_entries:
log_to_csv(log_file, log_entries)

if __name__ == "__main__":
parser = argparse.ArgumentParser(description='Process captions with Flux pipeline')
parser.add_argument('--start-idx', type=int, required=True)
parser.add_argument('--end-idx', type=int, required=True)
parser.add_argument('--caption-file', type=str, default='FID_caption.txt')
parser.add_argument('--output-dir', type=str, default='./outputs/flux-pab')
parser.add_argument('--log-file', type=str, default='./logs/generation_log.csv')
parser.add_argument('--batch-size', type=int, default=4)
parser.add_argument('--height', type=int, default=1024)
parser.add_argument('--width', type=int, default=1024)
parser.add_argument('--guidance-scale', type=float, default=3.5)
parser.add_argument('--num-steps', type=int, default=50)
parser.add_argument('--max-seq-len', type=int, default=512)

args = parser.parse_args()
process_captions(**vars(args))
2 changes: 1 addition & 1 deletion examples/open_sora_plan/sample.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ def run_base():
guidance_scale=7.5,
num_inference_steps=100,
seed=-1,
).video[0]
).images[0]
engine.save_video(video, f"./outputs/{prompt}.mp4")


Expand Down
43 changes: 43 additions & 0 deletions flux-test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
import os

import torch

os.environ["CUDA_VISIBLE_DEVICES"] = "5,6"

from videosys.pipelines.flux.pipeline_flux_pab import FluxConfig, FluxPipeline

pipe = FluxPipeline(
# "black-forest-labs/FLUX.1-dev",
# torch_dtype=torch.bfloat16,
# device_map="balanced",
config=FluxConfig()
)

# transformer = FluxTransformer2DModel.from_pretrained(
# "black-forest-labs/FLUX.1-dev",
# subfolder="transformer",
# torch_dtype=torch.bfloat16,
# low_cpu_mem_usage=True,
# offload_state_dict=False,
# ).to("cuda:1")

from diffusers import FluxPipeline

pipe = FluxPipeline.from_pretrained(
"black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16, device_map="balanced" # 自动分配到多个 GPU
)
pipe.enable_model_cpu_offload()

prompt = "A cat holding a sign that says hello world"

image = pipe(
prompt,
height=1024,
width=1024,
guidance_scale=3.5,
num_inference_steps=50,
max_sequence_length=512,
generator=torch.Generator("cuda:0").manual_seed(0),
).images[0]

image.save("flux-dev.png")
Loading