GPU Specification (#108)
* Made gpu id specification consistent across synthetic image generation models

* Changed gpu_id to device

* Docstring grammar

* add neuron.device to SyntheticImageGenerator init

* Fixed variable names

* adding device to start_validator.sh

* deprecating old/biased random prompt generation

* properly clear gpu of moderation pipeline

* simplifying usage of self.device

* fixing moderation pipeline device

* explicitly defining model/tokenizer for moderation pipeline to avoid accelerate auto device management

* deprecating random prompt generation

---------

Co-authored-by: benliang99 <[email protected]>
dylanuys and benliang99 authored Nov 24, 2024
1 parent b3d0a57 commit f2cefdd
Showing 6 changed files with 41 additions and 96 deletions.
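Taken together, the diffs below replace per-model gpu_id integers and accelerate's device_map="auto" placement with a single torch device string threaded from the validator config into every model. A minimal sketch of the resulting pattern for the moderation pipeline, with an illustrative stand-in for the repository's TEXT_MODERATION_MODEL constant (not the actual repo code):

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline

device = "cuda"  # e.g. from --neuron.device / the DEVICE env var

# Loading the model and tokenizer explicitly, then placing the model on one
# device by hand, keeps accelerate from auto-managing device placement.
model_name = "meta-llama/Llama-Guard-3-1B"  # stand-in, not the repo constant
model = AutoModelForCausalLM.from_pretrained(
    model_name, torch_dtype=torch.bfloat16
).to(device)
tokenizer = AutoTokenizer.from_pretrained(model_name)

moderation = pipeline("text-generation", model=model, tokenizer=tokenizer)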
33 changes: 22 additions & 11 deletions bitmind/synthetic_image_generation/image_annotation_generator.py
@@ -1,5 +1,6 @@
 # Transformer models
-from transformers import Blip2Processor, Blip2ForConditionalGeneration, pipeline
+from transformers import Blip2Processor, Blip2ForConditionalGeneration, AutoModelForCausalLM, AutoTokenizer, pipeline
+import torch
 
 # Logging and progress handling
 from transformers import logging as transformers_logging
@@ -34,7 +35,7 @@ class ImageAnnotationGenerator:
         text_moderation_pipeline (pipeline): A Hugging Face pipeline for text moderation.
 
     Methods:
-        __init__(self, model_name: str, text_moderation_model_name: str, device: str = 'auto', apply_moderation: bool = True):
+        __init__(self, model_name: str, text_moderation_model_name: str, device: str = "cuda", apply_moderation: bool = True):
             Initializes the ImageAnnotationGenerator with the specified model, device, and moderation settings.
         load_models(self):
@@ -59,7 +60,7 @@ class ImageAnnotationGenerator:
         Generates text annotations for a batch of images from the specified datasets and calculates the average processing latency.
     """
     def __init__(
-            self, model_name: str, text_moderation_model_name: str, device: str = 'auto',
+            self, model_name: str, text_moderation_model_name: str, device: str = "cuda",
             apply_moderation: bool = True
     ):
         """
@@ -68,13 +69,10 @@ def __init__(
         Args:
             model_name (str): The name of the BLIP model for generating image captions.
             text_moderation_model_name (str): The name of the model used for moderating text descriptions.
-            device (str): The device to use ('auto' to choose automatically between 'cuda' and 'cpu').
+            device (str): Device to use for model inference. Defaults to "cuda".
             apply_moderation (bool): Flag to determine whether text moderation should be applied to captions.
         """
-        self.device = torch.device(
-            'cuda' if torch.cuda.is_available() and device == 'auto' else 'cpu'
-        )
-
+        self.device = device
         self.model_name = model_name
         self.processor = Blip2Processor.from_pretrained(
             self.model_name, cache_dir=HUGGINGFACE_CACHE_DIR
@@ -99,11 +97,21 @@ def load_models(self):
         bt.logging.info(f"Loaded image annotation model {self.model_name}")
         bt.logging.info(f"Loading annotation moderation model {self.text_moderation_model_name}...")
         if self.apply_moderation:
+            model = AutoModelForCausalLM.from_pretrained(
+                self.text_moderation_model_name,
+                torch_dtype=torch.bfloat16,
+                cache_dir=HUGGINGFACE_CACHE_DIR
+            )
+
+            tokenizer = AutoTokenizer.from_pretrained(
+                self.text_moderation_model_name,
+                cache_dir=HUGGINGFACE_CACHE_DIR
+            )
+            model = model.to(self.device)
             self.text_moderation_pipeline = pipeline(
                 "text-generation",
-                model=self.text_moderation_model_name,
-                model_kwargs={"torch_dtype": torch.bfloat16, "cache_dir": HUGGINGFACE_CACHE_DIR},
-                device_map="auto"
+                model=model,
+                tokenizer=tokenizer
             )
         bt.logging.info(f"Loaded annotation moderation model {self.text_moderation_model_name}.")

@@ -114,7 +122,10 @@ def clear_gpu(self):
         bt.logging.debug(f"Clearing GPU memory after generating image annotation")
         self.model.to('cpu')
         del self.model
-        self.text_moderation_pipeline.model.to('cpu')
-        del self.text_moderation_pipeline
+        self.model = None
+        if self.text_moderation_pipeline:
+            self.text_moderation_pipeline.model.to('cpu')
+            del self.text_moderation_pipeline
+            self.text_moderation_pipeline = None
         gc.collect()
         torch.cuda.empty_cache()
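For orientation, a hypothetical usage sketch of the updated class; the checkpoint names are placeholders for the repository's IMAGE_ANNOTATION_MODEL and TEXT_MODERATION_MODEL constants:

from bitmind.synthetic_image_generation.image_annotation_generator import ImageAnnotationGenerator

# Placeholder checkpoints; substitute IMAGE_ANNOTATION_MODEL / TEXT_MODERATION_MODEL.
annotator = ImageAnnotationGenerator(
    model_name="Salesforce/blip2-opt-2.7b",
    text_moderation_model_name="meta-llama/Llama-Guard-3-1B",
    device="cuda:1",  # plain torch device string; no more 'auto' detection
)
annotator.load_models()
# ... caption images here ...
annotator.clear_gpu()  # moves models to CPU, drops references, empties the CUDA cache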
72 changes: 10 additions & 62 deletions bitmind/synthetic_image_generation/synthetic_image_generator.py
@@ -55,6 +55,7 @@ class SyntheticImageGenerator:
         diffuser_name (str): Name of the image diffuser model.
         image_annotation_generator (ImageAnnotationGenerator): The generator object for annotating images if required.
         image_cache_dir (str): Directory to cache generated images.
+        device (str): Device to use for model inference. Defaults to "cuda".
     """
     def __init__(
         self,
@@ -63,7 +64,7 @@ def __init__(
             diffuser_name=DIFFUSER_NAMES[0],
             use_random_diffuser=False,
             image_cache_dir=None,
-            gpu_id=0
+            device="cuda"
     ):
         if prompt_type not in PROMPT_TYPES:
             raise ValueError(f"Invalid prompt type '{prompt_type}'. Options are {PROMPT_TYPES}")
@@ -75,6 +76,7 @@ def __init__(
         self.use_random_diffuser = use_random_diffuser
         self.prompt_type = prompt_type
         self.prompt_generator_name = prompt_generator_name
+        self.device = device
 
         self.diffuser = None
         if self.use_random_diffuser and diffuser_name is not None:
@@ -86,11 +88,10 @@ def __init__(
         self.image_annotation_generator = None
         if self.prompt_type == 'annotation':
             self.image_annotation_generator = ImageAnnotationGenerator(model_name=IMAGE_ANNOTATION_MODEL,
-                                                                       text_moderation_model_name=TEXT_MODERATION_MODEL)
-        elif self.prompt_type == 'random':
-            bt.logging.info(f"Loading prompt generation model ({prompt_generator_name})...")
-            self.prompt_generator = pipeline(
-                'text-generation', **PROMPT_GENERATOR_ARGS[prompt_generator_name])
+                                                                       text_moderation_model_name=TEXT_MODERATION_MODEL,
+                                                                       device=self.device)
+        else:
+            raise NotImplementedError(f"Unsupported prompt_type: {self.prompt_type}")
 
         self.image_cache_dir = image_cache_dir
         if image_cache_dir is not None:
@@ -108,19 +109,13 @@ def generate(self, k: int = 1, real_images=None) -> list:
         Returns:
             list: List of dictionaries containing 'prompt', 'image', and 'id'.
         """
-        bt.logging.info("Generating prompts...")
         if self.prompt_type == 'annotation':
             if real_images is None:
                 raise ValueError(f"real_images can't be None if self.prompt_type is 'annotation'")
             prompts = [
                 self.generate_image_caption(real_images[i])
                 for i in range(k)
             ]
-        elif self.prompt_type == 'random':
-            prompts = [
-                self.generate_random_prompt(retry_attempts=10)
-                for _ in range(k)
-            ]
         else:
             raise NotImplementedError

@@ -129,7 +124,6 @@ def generate(self, k: int = 1, real_images=None) -> list:
         else:
             self.load_diffuser(self.diffuser_name)
 
-        bt.logging.info("Generating images...")
         gen_data = []
         for prompt in prompts:
             image_data = self.generate_image(prompt)
@@ -152,13 +146,12 @@ def clear_gpu(self):
         torch.cuda.empty_cache()
         self.diffuser = None
 
-    def load_diffuser(self, diffuser_name, gpu_id=None) -> None:
+    def load_diffuser(self, diffuser_name) -> None:
         """
         Loads a Hugging Face diffuser model to a specific GPU.
 
         Parameters:
             diffuser_name (str): Name of the diffuser to load.
-            gpu_index (int): Index of the GPU to use. Defaults to 0.
         """
         if diffuser_name == 'random':
             diffuser_name = np.random.choice(DIFFUSER_NAMES, 1)[0]
@@ -171,12 +164,9 @@ def load_diffuser(self, diffuser_name, gpu_id=None) -> None:
             **DIFFUSER_ARGS[diffuser_name],
             add_watermarker=False)
         self.diffuser.set_progress_bar_config(disable=True)
+        self.diffuser.to(self.device)
         if DIFFUSER_CPU_OFFLOAD_ENABLED[diffuser_name]:
             self.diffuser.enable_model_cpu_offload()
-        elif not gpu_id:
-            self.diffuser.to("cuda")
-        elif gpu_id:
-            self.diffuser.to(f"cuda:{gpu_id}")
 
         bt.logging.info(f"Loaded {diffuser_name} using {pipeline_class.__name__}.")

@@ -207,48 +197,6 @@ def generate_image_caption(self, image_sample) -> str:
         self.image_annotation_generator.clear_gpu()
         return annotation['description']
 
-    def generate_random_prompt(self, retry_attempts: int = 10) -> str:
-        """
-        Generates a prompt for image generation.
-
-        Args:
-            retry_attempts (int): Number of attempts to generate a valid prompt.
-
-        Returns:
-            str: Generated prompt.
-        """
-        seed = random.randint(100, 1000000)
-        set_seed(seed)
-
-        starters = [
-            'A photorealistic portrait',
-            'A photorealistic image of a person',
-            'A photorealistic landscape',
-            'A photorealistic scene'
-        ]
-        quality = [
-            'RAW photo', 'subject', '8k uhd', 'soft lighting', 'high quality', 'film grain'
-        ]
-        device = [
-            'Fujifilm XT3', 'iphone', 'canon EOS r8' , 'dslr',
-        ]
-
-        for _ in range(retry_attempts):
-            starting_text = np.random.choice(starters, 1)[0]
-            response = self.prompt_generator(
-                starting_text, max_length=(77 - len(starting_text)), num_return_sequences=1, truncation=True)
-
-            prompt = response[0]['generated_text'].strip()
-            prompt = re.sub('[^ ]+\.[^ ]+','', prompt)
-            prompt = prompt.replace("<", "").replace(">", "")
-
-            # temporary removal of extra context (like "featured on artstation") until we've trained our own prompt generator
-            prompt = re.split('[,;]', prompt)[0] + ', '
-            prompt += ', '.join(np.random.choice(quality, np.random.randint(len(quality)//2, len(quality))))
-            prompt += ', ' + np.random.choice(device, 1)[0]
-            if prompt != "":
-                return prompt
-
     def get_tokenizer_with_min_len(self):
         """
         Returns the tokenizer with the smallest maximum token length from the 'diffuser` object.
@@ -355,4 +303,4 @@ def generate_image(self, prompt, name = None, generate_at_target_size = False)
             'id': image_name,
             'gen_time': gen_time
         }
-        return image_data
\ No newline at end of file
+        return image_data
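A hedged usage sketch of the slimmed-down generator (real_images stands in for samples drawn from the validator's real-image datasets; this is illustration, not repository code):

from bitmind.synthetic_image_generation.synthetic_image_generator import SyntheticImageGenerator

generator = SyntheticImageGenerator(
    prompt_type='annotation',  # 'random' prompting is deprecated by this commit
    use_random_diffuser=True,
    diffuser_name=None,
    device='cuda',             # replaces the old gpu_id integer
)
# 'annotation' prompts are captions of real images, so real_images is required.
samples = generator.generate(k=1, real_images=real_images)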
23 changes: 2 additions & 21 deletions bitmind/validator/forward.py
@@ -74,9 +74,8 @@ async def forward(self):
     wandb_data = {}
 
     miner_uids = get_random_uids(self, k=self.config.neuron.sample_size)
-    bt.logging.info("Generating challenge")
+
     if np.random.rand() > self._fake_prob:
-        bt.logging.info('sampling real image')
         label = 0
         source_dataset, local_index = sample_random_real_image(self.real_image_datasets, self.total_real_images)
         wandb_data['source_dataset'] = source_dataset.huggingface_dataset_name
@@ -85,19 +84,14 @@ async def forward(self):
 
     else:
         label = 1
-
         if self.config.neuron.prompt_type == 'annotation':
-            bt.logging.info('generating fake image from annotation of real image')
-
             retries = 10
             while retries > 0:
                 retries -= 1
-
                 source_dataset, local_index = sample_random_real_image(self.real_image_datasets, self.total_real_images)
                 source_sample = source_dataset[local_index]
                 source_image = source_sample['image']
                 if source_image is None:
-                    bt.logging.warning(f"Missing image encountered at {source_image['id']}, resampling...")
                     continue
 
                 # generate captions for the real images, then synthetic images from these captions
@@ -111,23 +105,10 @@ async def forward(self):
                 wandb_data['prompt'] = sample['prompt']
                 if not np.any(np.isnan(sample['image'])):
                     break
-
                 bt.logging.warning("NaN encountered in prompt/image generation, retrying...")
-
-        elif self.config.neuron.prompt_type == 'random':
-            bt.logging.info('generating fake image using prompt_generator')
-            sample = self.synthetic_image_generator.generate(k=1)[0]
-
-            wandb_data['model'] = self.synthetic_image_generator.diffuser_name
-            wandb_data['image'] = wandb.Image(sample['image'])
-            wandb_data['prompt'] = sample['prompt']
-
         else:
-            bt.logging.error(f'unsupported neuron.prompt_type: {self.config.neuron.prompt_type}')
-            raise NotImplementedError
+            raise NotImplementedError(f'unsupported neuron.prompt_type: {self.config.neuron.prompt_type}')
-
     image = sample['image']
-
     image, level, data_aug_params = apply_augmentation_by_level(image)
 
     bt.logging.info(f"Querying {len(miner_uids)} miners...")
5 changes: 4 additions & 1 deletion neurons/validator.py
@@ -64,7 +64,10 @@ def __init__(self, config=None):
         ])
 
         self.synthetic_image_generator = SyntheticImageGenerator(
-            prompt_type='annotation', use_random_diffuser=True, diffuser_name=None)
+            prompt_type='annotation',
+            use_random_diffuser=True,
+            diffuser_name=None,
+            device=self.config.neuron.device)
 
         self._fake_prob = self.config.get('fake_prob', 0.5)

1 change: 1 addition & 0 deletions setup_validator_env.sh
@@ -30,6 +30,7 @@ WALLET_HOTKEY=default
 # Validator Port Setting:
 VALIDATOR_AXON_PORT=8092
 VALIDATOR_PROXY_PORT=10913
+DEVICE=cuda
 
 # API Keys:
 WANDB_API_KEY=your_wandb_api_key_here
3 changes: 2 additions & 1 deletion start_validator.sh
@@ -40,4 +40,5 @@ pm2 start neurons/validator.py --name bitmind_validator -- \
   --wallet.name $WALLET_NAME \
   --wallet.hotkey $WALLET_HOTKEY \
   --axon.port $VALIDATOR_AXON_PORT \
-  --proxy.port $VALIDATOR_PROXY_PORT
+  --proxy.port $VALIDATOR_PROXY_PORT \
+  --neuron.device $DEVICE
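With the flag wired through, pinning the validator to a specific GPU should come down to the env file alone, e.g. DEVICE=cuda:1 to target the second GPU (assuming, as the code above suggests, that the value is passed verbatim to torch as a device string).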
