forked from Anashel-RPG/anashel-utils
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathinput.py
629 lines (547 loc) · 28 KB
/
input.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
# input.py
import os
import sys
import glob
from rich.console import Console
from rich.prompt import Prompt
from rich.panel import Panel
from tqdm import tqdm
from safetensors.torch import load_file as safe_load
from tabulate import tabulate
# Shared Rich console used by every prompt, table and status message in this module.
console = Console()
def main_input():
    """Run the interactive merge wizard and return the confirmed settings.

    Shows a welcome panel, asks whether to merge two LoRA models or a LoRA into
    a checkpoint, delegates to ``option_5_merge_lora`` or
    ``option_6_merge_lora_checkpoint``, then passes the collected settings to
    ``confirm_settings``. If the settings are not confirmed, the screen is
    cleared and the wizard starts over.

    Returns:
        dict: The settings dictionary produced by the selected option, for
        ``main.py`` to process.

    Exits:
        Calls ``sys.exit(1)`` when the selected option returns no settings
        (for example when the required model files are missing).
    """
    while True:
        # Display the welcome message in a box
        console.print(Panel(
            "Welcome to the Anashel's LoRA Merging Utility!\n\n"
            "This tool allows you to merge LoRA models or merge a LoRA into a main checkpoint. "
            "The process will guide you through selecting your models and merge settings.",
            title="[bold yellow]LoRA Merger[/bold yellow]",
            expand=False
        ))

        # Ask the user which type of merge they want to perform
        console.print(
            "\n[bold yellow]Would you like to merge:[/bold yellow]\n"
            "[1] Two LoRA models\n"
            "[2] A LoRA model into a main checkpoint"
        )
        choice = Prompt.ask("[bold green]Choose an option (1-2)[/bold green]", choices=["1", "2"])

        if choice == "1":
            settings = option_5_merge_lora()  # For merging two LoRA models
        else:
            settings = option_6_merge_lora_checkpoint()  # For merging a LoRA model into a checkpoint

        if not settings:
            console.print("[bold red]Process aborted due to missing or invalid settings.[/bold red]\n")
            sys.exit(1)  # Exit the application with an error code

        if confirm_settings(settings):
            # Return settings to main.py for further processing
            return settings

        # Not confirmed: clear the screen and restart the wizard instead of
        # falling off the end of the function and returning None.
        console.clear()
# def main_input():
# while True:
# # Display the welcome message and options
# display_welcome()
#
# # Ask the user to choose an option from 1 to 5
# choice = Prompt.ask(
# "[bold green]Choose a utility (1-5)[/bold green]"
# )
#
# # Call the respective function based on the user's choice
# if choice == "1":
# settings = option_1_generate_prompt_idea()
# elif choice == "2":
# settings = option_2_generate_image()
# elif choice == "3":
# settings = option_3_create_style_variation()
# elif choice == "4":
# settings = option_4_caption_images()
# elif choice == "5":
# settings = option_5_merge_lora()
# else:
# console.print("[bold red]Invalid choice. Please enter a number from 1 to 5.[/bold red]")
# continue
#
# # Check if settings are valid before confirming
# if settings:
# # Confirm the settings with the user
# confirm = confirm_settings(settings)
# if confirm:
# # Return settings to main.py for further processing
# return settings
# else:
# # Option to restart the input process without exiting
# console.clear()
# else:
# console.print("[bold red]Process aborted due to missing or invalid settings.[/bold red]\n")
# sys.exit(1) # Exit the application with an error code
def display_welcome():
    """Print a titled panel listing the five numbered Anashel utilities."""
    utilities = (
        "Generate prompt idea to help build a dataset",
        "Generate image using Flux or Leonardo",
        "Create style variation of images to expand a dataset",
        "Caption existing images using OpenAI",
        "Merge LoRA",
    )
    menu = "\n".join(
        f"[bold magenta]{number}:[/bold magenta] {label}"
        for number, label in enumerate(utilities, start=1)
    )
    console.print(Panel(menu, title="[bold yellow]Anashel Utility's[/bold yellow]"))
def option_1_generate_prompt_idea():
    """Collect the inputs for the "Generate Prompt Idea" utility.

    Asks which kind of subject the prompts should cover (person/character,
    location or aesthetic style). It then repeatedly asks for a detailed
    description and shows a sample prompt until the user accepts.

    Returns:
        dict: ``{"utility": "Generate Prompt Idea", "type": <subject kind>,
        "detail": <last description entered>}``.
    """
    console.print("----\n")  # Visual separator for entering the new section
    console.print(
        "[bold green]Using your input, we will create 50 prompts for various images around your subject.[/bold green] "
        "This is useful to start building a dataset for training a concept, a specific style, or enhancing an existing dataset with complementary images.\n"
    )
    settings = {"utility": "Generate Prompt Idea"}

    # Menu choice -> (subject kind, fixed example prompt shown to the user).
    prompt_types = {
        "1": (
            "person or character",
            "A brave paladin in shining armor, standing on a battlefield at dawn, with a glowing sword raised.",
        ),
        "2": (
            "location",
            "A dark, eerie dungeon with flickering torches on the walls, filled with ancient stone carvings and scattered bones.",
        ),
        "3": (
            "aesthetic style",
            "A dystopian futuristic cityscape at night, with towering neon-lit skyscrapers, flying vehicles, and a moody, cyberpunk atmosphere.",
        ),
    }

    # First question: Choose the type of prompt
    console.print(
        "[bold yellow]Do you wish to create prompts related to:[/bold yellow]\n"
        "[1] A person or character (e.g., A paladin, a 1950's soldier...)\n"
        "[2] A location (e.g., Dungeon, a building, a landscape...)\n"
        "[3] An aesthetic style (e.g., Dystopian futuristic, medieval fantasy...)"
    )
    # Restricting choices makes Rich re-prompt on invalid input instead of
    # silently treating anything other than 1/2 as "aesthetic style".
    choice = Prompt.ask("[bold green]Choose a type (1-3)[/bold green]", choices=list(prompt_types))
    prompt_type, example_prompt = prompt_types[choice]
    settings["type"] = prompt_type

    # Second question: Ask for specific details (repeat until accepted)
    while True:
        settings["detail"] = Prompt.ask(
            f"[bold yellow]Please describe the {prompt_type} in more detail[/bold yellow]"
        )
        # Placeholder: the example is a fixed sample per type; no API call is made here yet.
        console.print(
            f"\n[bold cyan]Here is an example prompt based on your input:[/bold cyan]\n[italic magenta]{example_prompt}[/italic magenta]"
        )
        # Third question: Confirm or loop back
        adjust = Prompt.ask(
            "[bold yellow]Is this satisfactory?[/bold yellow] (no to adjust, yes to continue)"
        )
        if adjust.lower() in ["yes", "y", ""]:
            break
    return settings
def option_2_generate_image():
    """Collect the inputs for the "Generate Image" utility.

    Locates a ``prompt.txt`` file (one prompt per line) in
    ``01-prompt_creation/output`` and/or ``02-images_generation``. When both
    exist, the user is asked which to use. The function counts the prompts in
    the chosen file and asks which platform (Flux or Leonardo) to use.

    Returns:
        dict | None: ``{"utility", "platform", "prompt_file", "num_prompts"}``,
        or ``None`` when no ``prompt.txt`` file is found.
    """
    console.print("----\n")  # Visual separator for entering the new section
    console.print(
        "[bold green]Using Flux or Leonardo, we can generate images based on your specifications.[/bold green] "
        "This is ideal for quickly visualizing concepts, creating assets, or expanding your dataset with high-quality imagery.\n"
    )
    settings = {"utility": "Generate Image"}

    # Check for prompt.txt in the specified folders
    prompt_path_02 = "02-images_generation/prompt.txt"
    prompt_path_01 = "01-prompt_creation/output/prompt.txt"
    # isfile (not exists) so a directory named prompt.txt is never opened.
    has_01 = os.path.isfile(prompt_path_01)
    has_02 = os.path.isfile(prompt_path_02)

    if has_01 and has_02:
        console.print(
            "[bold yellow]Do you wish to use the existing prompts from:[/bold yellow]\n"
            "[1] Folder 01-prompt_creation/output/\n"
            "[2] Folder 02-images_generation"
        )
        folder_choice = Prompt.ask("[bold green]Choose your folder (1-2)[/bold green]", choices=["1", "2"])
        prompt_file = prompt_path_01 if folder_choice == "1" else prompt_path_02
    elif has_01:
        prompt_file = prompt_path_01
    elif has_02:
        prompt_file = prompt_path_02
    else:
        console.print(
            "[bold red]Error: No prompt.txt file found in either 02-images_generation or 01-prompt_creation/output.[/bold red]\n"
            "Please ensure that a prompt.txt file is present with one prompt per line before proceeding."
        )
        return None

    # One prompt per line: blank lines are not prompts, so they are not counted.
    # An explicit encoding avoids depending on the platform default; undecodable
    # bytes are replaced because only the line count is needed here.
    with open(prompt_file, "r", encoding="utf-8", errors="replace") as file:
        num_prompts = sum(1 for line in file if line.strip())
    console.print(f"\n[bold cyan]Found {num_prompts} prompts in {prompt_file}.[/bold cyan]")

    # Platform selection
    console.print(
        "[bold yellow]Which platform would you like to use for image generation?[/bold yellow]\n"
        "[1] Flux\n"
        "[2] Leonardo"
    )
    platform_choice = Prompt.ask("[bold green]Choose a platform (1-2)[/bold green]", choices=["1", "2"])
    settings["platform"] = "Flux" if platform_choice == "1" else "Leonardo"
    settings["prompt_file"] = prompt_file
    settings["num_prompts"] = num_prompts
    return settings
def option_3_create_style_variation():
    """Collect the inputs for the "Create Style Variation" utility.

    Looks for source images in ``03-style_variation/input`` and
    ``02-images_generation/output``, asking which folder to use when both have
    images. It then picks a style reference image from the root of
    ``03-style_variation``, asking when there is more than one.

    Returns:
        dict | None: ``{"utility", "selected_images", "style_image"}``, where
        ``selected_images`` is a human-readable count and ``style_image`` is a
        file name. Returns ``None`` when no source or no style image is found.
    """
    image_extensions = (".jpg", ".jpeg", ".png")

    def _list_images(folder):
        """Return sorted .jpg/.jpeg/.png paths (any case) directly inside ``folder``."""
        # A "*.[jp][pn]g" glob would miss .jpeg and, on case-sensitive
        # filesystems, upper-case extensions such as .JPG.
        return sorted(
            path
            for path in glob.glob(os.path.join(folder, "*"))
            if path.lower().endswith(image_extensions) and os.path.isfile(path)
        )

    console.print("----\n")  # Visual separator for entering the new section
    console.print(
        "[bold green]Style variation takes your images and generates variations using a style you provide.[/bold green] "
        "This is useful for expanding a dataset with different visual styles based on existing images.\n"
    )
    settings = {"utility": "Create Style Variation"}

    # Scan for images in 03-style_variation/input and 02-images_generation/output
    folder_03_input = "03-style_variation/input"
    folder_02_output = "02-images_generation/output"
    images_03 = _list_images(folder_03_input)
    images_02 = _list_images(folder_02_output)

    if not images_03 and not images_02:
        console.print(
            "[bold red]Error: No images found in 03-style_variation/input or 02-images_generation/output.[/bold red]\n"
            "Please ensure that images (.jpg, .jpeg, .png) are present before proceeding."
        )
        return None

    if images_03 and images_02:
        console.print(
            f"[bold yellow]Images found in both folders:[/bold yellow]\n"
            f"[1] {len(images_03)} images in 03-style_variation/input\n"
            f"[2] {len(images_02)} images in 02-images_generation/output"
        )
        folder_choice = Prompt.ask("[bold green]Choose your folder (1-2):[/bold green]", choices=["1", "2"])
        selected_images = images_03 if folder_choice == "1" else images_02
    else:
        selected_images = images_03 if images_03 else images_02
    console.print(f"[bold cyan]Found {len(selected_images)} images to be transformed.[/bold cyan]")

    # Style reference images live at the root of 03-style_variation; sorting keeps
    # the numbered menu stable between runs.
    style_images = _list_images("03-style_variation")
    if not style_images:
        console.print(
            "[bold red]Error: No style reference image found in 03-style_variation.[/bold red]\n"
            "Please ensure that at least one image (.jpg, .jpeg, .png) is present at the root of the folder 03-style_variation to be used as a style reference."
        )
        return None
    elif len(style_images) == 1:
        style_image = style_images[0]
        console.print(f"[bold cyan]Using {os.path.basename(style_image)} as the style reference image.[/bold cyan]")
    else:
        console.print("[bold yellow]Multiple style reference images found:[/bold yellow]")
        for idx, img in enumerate(style_images, 1):
            console.print(f"[{idx}] {os.path.basename(img)}")
        img_choice = Prompt.ask(
            f"[bold green]Choose the style reference image (1-{len(style_images)}):[/bold green]",
            choices=[str(i) for i in range(1, len(style_images) + 1)],
        )
        style_image = style_images[int(img_choice) - 1]

    settings["selected_images"] = f"{len(selected_images)} images selected"
    settings["style_image"] = os.path.basename(style_image)
    return settings
def option_4_caption_images():
    """Collect the inputs for the "Caption Images" utility.

    Counts the .jpg/.jpeg/.png images in ``04-ai_caption/input``.

    Returns:
        dict | None: ``{"utility": "Caption Images", "num_images": <count>}``,
        or ``None`` when the folder holds no images.
    """
    console.print("----\n")  # Visual separator for entering the new section
    console.print(
        "[bold green]This utility will caption existing images using OpenAI, helping you generate descriptive captions for your dataset.[/bold green]\n"
    )
    settings = {"utility": "Caption Images"}

    # Scan for images in 04-ai_caption/input. Extensions are matched
    # case-insensitively; a "*.[jp][pn]g" glob would miss .jpeg and .JPG files.
    input_folder = "04-ai_caption/input"
    images = [
        path
        for path in glob.glob(os.path.join(input_folder, "*"))
        if path.lower().endswith((".jpg", ".jpeg", ".png")) and os.path.isfile(path)
    ]

    if not images:
        console.print(
            "[bold red]Error: No images found in 04-ai_caption/input.[/bold red]\n"
            "Please ensure that images are present in the folder before proceeding."
        )
        return None

    console.print(f"[bold cyan]Found {len(images)} images for captioning in 04-ai_caption/input.[/bold cyan]")
    settings["num_images"] = len(images)
    return settings
def option_5_merge_lora():
    """Interactively choose two LoRA files and a merge strategy.

    Scans ``05a-lora_merging`` for ``.safetensors``/``.pt`` files and loads each
    one once to report its layer count and size. It then loops over model and
    strategy selection until the user confirms; models are not reloaded when
    the user adjusts.

    Returns:
        dict | None: Settings with keys:

        - ``utility``.
        - ``merge_strategy``: "Additive", "Mix" or "Weighted".
        - ``merge_type``: "adaptive", "manual" or "additive".
        - The weight key for the chosen strategy: ``add_weight``,
          ``weight_percentages`` or ``weight_percentage``.
        - ``main_lora`` and ``merge_lora``: file names.

        Returns ``None`` when fewer than two LoRA files are available.
    """
    import math

    def _display_name(file_name):
        """Return ``file_name`` without its extension (``style.safetensors`` -> ``style``)."""
        return os.path.splitext(file_name)[0]

    def _ask_index(question, count, forbidden=None):
        """Prompt until a valid 1-based index is entered; return it 0-based.

        ``forbidden`` is a 0-based index that may not be chosen. It prevents a
        LoRA from being merged with itself.
        """
        while True:
            try:
                index = int(Prompt.ask(question)) - 1
            except ValueError:
                console.print("[bold red]Please enter a valid number.[/bold red]")
                continue
            if not 0 <= index < count:
                console.print(f"[bold red]Please enter a number between 1 and {count}.[/bold red]")
            elif index == forbidden:
                console.print(
                    "[bold red]Cannot merge the same LoRA file with itself. Please select a different file.[/bold red]"
                )
            else:
                return index

    def _ask_number(question, error_message, is_valid, accept_mix=False):
        """Prompt until the answer parses as a float accepted by ``is_valid``.

        With ``accept_mix``, the answer 'mix' (any case) is also accepted and
        returned as the string ``"mix"``. Unparseable input is treated as NaN,
        so ``is_valid`` rejects it unless it explicitly allows NaN.
        """
        while True:
            answer = Prompt.ask(question)
            if accept_mix and answer.lower() == "mix":
                return "mix"
            try:
                value = float(answer)
            except ValueError:
                value = math.nan
            if is_valid(value):
                return value
            console.print(error_message)

    console.print("----\n")  # Visual separator for entering the new section
    console.print(
        "[bold green]This utility allows you to set up a merge of LoRA models by selecting two models and adjusting the merge weight percentage.[/bold green]\n"
    )

    # Step 1: Inventory the LoRA files. A missing folder is reported like an empty
    # one instead of crashing in os.listdir; sorting keeps menu indices stable.
    lora_folder = "05a-lora_merging"
    lora_files = []
    if os.path.isdir(lora_folder):
        lora_files = sorted(
            name for name in os.listdir(lora_folder)
            if name.endswith((".safetensors", ".pt"))
        )

    if not lora_files:
        console.print(
            "[bold red]Error: No LoRA files found in 05a-lora_merging.[/bold red]\n"
            "Please ensure that .safetensors or .pt files are present before proceeding."
        )
        return None
    if len(lora_files) == 1:
        console.print(
            "[bold red]Error: Only one LoRA file found in 05a-lora_merging.[/bold red]\n"
            "A minimum of two LoRA files is required to perform a merge. Please add more files to proceed."
        )
        return None

    # Load LoRA details once; the selection loop below reuses them.
    lora_details = []
    with tqdm(total=len(lora_files), desc="Loading LoRA models", unit="file", dynamic_ncols=True) as progress_bar:
        for i, lora_file in enumerate(lora_files, 1):
            lora_path = os.path.join(lora_folder, lora_file)
            # Count keys inline so each loaded model can be freed before the next load.
            num_layers = len(load_lora_model(lora_path).keys())
            file_size = get_file_size(lora_path)
            lora_details.append([i, _display_name(lora_file), num_layers, f"{file_size:.2f} MB"])
            progress_bar.update(1)

    formatted_table = tabulate(
        lora_details,
        headers=["Index", "LoRA Model", "Number of Layers", "File Size"],
        tablefmt="pretty",
        maxcolwidths=[None, 30, None, None]
    )
    file_count = len(lora_files)

    while True:
        # Start every attempt from a clean dict. Otherwise keys from a rejected
        # strategy (e.g. weight_percentage after switching to Additive) leak into
        # the summary below and into the settings returned to the caller.
        settings = {"utility": "Merge LoRA"}
        console.print(f"\n{formatted_table}")

        main_lora_index = _ask_index(f"Select the main LoRA source (1-{file_count})", file_count)
        main_lora_file = lora_files[main_lora_index]
        merge_lora_index = _ask_index(
            f"Select the LoRA to merge with (1-{file_count})", file_count, forbidden=main_lora_index
        )
        merge_lora_file = lora_files[merge_lora_index]

        console.print(
            "[bold yellow]Choose the merging strategy:[/bold yellow]\n"
            "[1] Adaptive Merge (uses tensor norms and weight)\n"
            "[2] Manual Merge (uses fixed weights you specify)\n"
            "[3] Additive Merge (uses 100% of the first and adds a percentage of the second)"
        )
        strategy_choice = Prompt.ask("[bold green]Choose a strategy (1-3)[/bold green]", choices=["1", "2", "3"])
        merge_type = {"1": "adaptive", "2": "manual", "3": "additive"}[strategy_choice]
        console.print(f"[bold cyan]Selected {merge_type.capitalize()} Merge strategy.[/bold cyan]")

        if merge_type == "additive":
            # Re-prompt on non-numeric or non-finite input instead of crashing.
            add_weight = _ask_number(
                "[bold green]Enter the percentage of the second LoRA to add (e.g., 40 for 40%)[/bold green]",
                "[bold red]Please enter a valid number.[/bold red]",
                math.isfinite,
            )
            settings["merge_strategy"] = "Additive"
            settings["add_weight"] = add_weight
            settings["merge_type"] = merge_type
            console.print(f"[bold cyan]Using Additive Merge: 100% of {main_lora_file} with {add_weight}% of {merge_lora_file}.[/bold cyan]")
        else:
            # Only the weight is re-asked on bad input; the model choices are kept.
            weight = _ask_number(
                "Enter the percentage to keep from the main model (0-100)\nYou can also type 'mix' for 25%, 50%, 75% versions",
                "[bold red]Invalid input. Please enter a number between 0 and 100 or 'mix'.[/bold red]",
                lambda value: 0 <= value <= 100,
                accept_mix=True,
            )
            if weight == "mix":
                settings["merge_strategy"] = "Mix"
                settings["weight_percentages"] = [25, 50, 75]
                settings["merge_type"] = merge_type  # Apply the selected merge strategy to the mix
                console.print(f"[bold cyan]You've chosen to create three versions with weights: 25%, 50%, and 75% using {merge_type} merge strategy.[/bold cyan]")
            else:
                settings["merge_strategy"] = "Weighted"
                settings["weight_percentage"] = weight
                settings["merge_type"] = merge_type
                # Print the percentages directly: (w / 100) * 100 can show float noise
                # such as 28.999999999999996.
                console.print(f"[bold cyan]Merge Weight: {weight}% main, {100 - weight}% merge using {merge_type} merge strategy.[/bold cyan]")

        # Display settings before confirming
        console.print(
            f"\n[bold cyan]You have chosen to merge:[/bold cyan]\n"
            f"Main LoRA: {_display_name(main_lora_file)}\n"
            f"Merge LoRA: {_display_name(merge_lora_file)}\n"
            f"Merge Strategy: {settings['merge_strategy']}\n"
            f"Merge Type: {settings['merge_type']}"
        )
        if settings["merge_strategy"] == "Mix":
            console.print(f"Weight Percentages: 25%, 50%, 75% using {settings['merge_type']} strategy")
        elif settings["merge_strategy"] == "Weighted":
            console.print(f"Weight Percentage: {settings['weight_percentage']}% using {settings['merge_type']} strategy")
        else:
            console.print(f"Add Weight Percentage: {settings['add_weight']}% using {settings['merge_type']} strategy")

        confirm = Prompt.ask(
            "[bold yellow]Is this satisfactory?[/bold yellow] (no to adjust, yes to continue)"
        )
        if confirm.lower() in ["yes", "y", ""]:
            break
        # Adjust: restart the selection without reloading the models.
        console.print("[bold yellow]Adjusting settings. Please make your selections again.[/bold yellow]")

    settings["main_lora"] = main_lora_file
    settings["merge_lora"] = merge_lora_file
    return settings
def option_6_merge_lora_checkpoint():
    """Interactively choose a LoRA, a checkpoint and a blend strategy.

    Scans ``05a-lora_merging`` (LoRA models) and ``05b-checkpoint/input``
    (checkpoints) for ``.safetensors``/``.pt`` files. Each file is loaded once
    to report its layer count and size. The function then loops over selection
    until the user confirms.

    Returns:
        dict | None: Settings with keys:

        - ``utility``.
        - ``merge_strategy``: "Mix" or "Full".
        - ``weight_percentages`` (Mix) or ``merge_weight`` (Full).
        - ``lora_model`` and ``checkpoint_model``: file names.

        Returns ``None`` when either folder has no model files.
    """
    import math

    model_extensions = (".safetensors", ".pt")

    def _list_models(folder):
        """Return sorted .safetensors/.pt file names in ``folder`` ([] if the folder is missing)."""
        if not os.path.isdir(folder):
            return []
        return sorted(name for name in os.listdir(folder) if name.endswith(model_extensions))

    def _describe(folder, file_names, label):
        """Load each model once and return table rows ``[index, name, layer count, size]``."""
        rows = []
        with tqdm(total=len(file_names), desc=label, unit="file", dynamic_ncols=True) as progress_bar:
            for i, file_name in enumerate(file_names, 1):
                path = os.path.join(folder, file_name)
                # Count keys inline so each loaded model is released before the next
                # file is loaded, rather than two models being held at once.
                num_layers = len(load_lora_model(path).keys())
                file_size = get_file_size(path)
                rows.append([i, os.path.splitext(file_name)[0], num_layers, f"{file_size:.2f} MB"])
                progress_bar.update(1)
        return rows

    def _ask_index(question, count):
        """Prompt until a valid 1-based index is entered; return it 0-based."""
        while True:
            try:
                index = int(Prompt.ask(question)) - 1
            except ValueError:
                console.print("[bold red]Please enter a valid number.[/bold red]")
                continue
            if 0 <= index < count:
                return index
            console.print(f"[bold red]Please enter a number between 1 and {count}.[/bold red]")

    console.print("----\n")  # Visual separator for entering the new section
    console.print(
        "[bold green]This utility allows you to merge a LoRA model into a main checkpoint by selecting the models and adjusting the merge weight percentage.[/bold green]\n\n!!! WARNING: I can’t even begin to explain how seriously messed up and experimental this is.\n"
    )

    lora_folder = "05a-lora_merging"
    checkpoint_folder = "05b-checkpoint/input"  # Updated folder for input checkpoints
    lora_files = _list_models(lora_folder)
    checkpoint_files = _list_models(checkpoint_folder)

    if not lora_files or not checkpoint_files:
        console.print(
            "[bold red]Error: No LoRA or checkpoint files found in the specified folders.[/bold red]\n"
            "Please ensure that .safetensors or .pt files are present in both 05a-lora_merging and 05b-checkpoint/input before proceeding."
        )
        return None

    # Load every model once; the tables are reused on each adjustment pass.
    formatted_lora_table = tabulate(
        _describe(lora_folder, lora_files, "Loading LoRA models"),
        headers=["Index", "LoRA Model", "Number of Layers", "File Size"],
        tablefmt="pretty",
        maxcolwidths=[None, 30, None, None]
    )
    formatted_checkpoint_table = tabulate(
        _describe(checkpoint_folder, checkpoint_files, "Loading Checkpoints"),
        headers=["Index", "Checkpoint Model", "Number of Layers", "File Size"],
        tablefmt="pretty",
        maxcolwidths=[None, 30, None, None]
    )

    while True:
        # Fresh dict per attempt so merge_weight / weight_percentages from a
        # rejected attempt cannot leak into the returned settings.
        settings = {"utility": "Merge LoRA Checkpoint"}
        console.print(f"\n{formatted_lora_table}")
        console.print(f"\n{formatted_checkpoint_table}")

        lora_file = lora_files[
            _ask_index(f"Select the LoRA model (1-{len(lora_files)})", len(lora_files))
        ]
        checkpoint_file = checkpoint_files[
            _ask_index(f"Select the Checkpoint model (1-{len(checkpoint_files)})", len(checkpoint_files))
        ]

        console.print(
            "[bold yellow]Choose the merging strategy:[/bold yellow]\n"
            "[1] Mix (25%, 50%, 75% versions)\n"
            "[2] Full Blend (specify weight)"
        )
        strategy_choice = Prompt.ask("[bold green]Choose a strategy (1-2)[/bold green]", choices=["1", "2"])
        if strategy_choice == "1":
            settings["merge_strategy"] = "Mix"
            settings["weight_percentages"] = [25, 50, 75]
            console.print("[bold cyan]Selected Mix strategy (25%, 50%, 75%).[/bold cyan]")
        else:
            settings["merge_strategy"] = "Full"
            # Re-prompt on non-numeric or non-finite input instead of crashing.
            while True:
                try:
                    merge_weight = float(Prompt.ask("[bold green]Enter the percentage of LoRA to merge into the checkpoint (e.g., 40 for 40%)[/bold green]"))
                except ValueError:
                    merge_weight = math.nan
                if math.isfinite(merge_weight):
                    break
                console.print("[bold red]Please enter a valid number.[/bold red]")
            settings["merge_weight"] = merge_weight
            console.print(f"[bold cyan]Using Full Blend: {merge_weight}% of the LoRA model will be merged into the checkpoint.[/bold cyan]")

        # Confirm the settings before merging
        console.print(
            f"\n[bold cyan]You have chosen to merge:[/bold cyan]\n"
            f"LoRA Model: {lora_file}\n"
            f"Checkpoint Model: {checkpoint_file}\n"
            f"Merge Strategy: {settings['merge_strategy']}"
        )
        if settings["merge_strategy"] == "Mix":
            console.print("Weight Percentages: 25%, 50%, 75%")
        else:
            console.print(f"Weight Percentage: {settings['merge_weight']}%")

        confirm = Prompt.ask(
            "[bold yellow]Is this satisfactory?[/bold yellow] (no to adjust, yes to continue)"
        )
        if confirm.lower() in ["yes", "y", ""]:
            break
        console.print("[bold yellow]Adjusting settings. Please make your selections again.[/bold yellow]")

    settings["lora_model"] = lora_file
    settings["checkpoint_model"] = checkpoint_file
    return settings
def get_file_size(file_path):
    """Return the size of ``file_path`` in megabytes (bytes / 1024**2), as a float."""
    bytes_per_megabyte = 1024 ** 2
    size_in_bytes = os.stat(file_path).st_size
    return size_in_bytes / bytes_per_megabyte
def load_lora_model(file_path):
    """Load a model's tensors from a ``.safetensors`` or ``.pt`` file.

    Args:
        file_path: Path to the model file. Paths ending in ``.safetensors`` are
            read with ``safetensors.torch.load_file``. Any other path is loaded
            with ``torch.load``.

    Returns:
        For ``.safetensors`` files, a dict mapping tensor name to tensor (on CPU).
        For other files, whatever object ``torch.load`` returns, mapped to CPU.
        Callers in this module call ``.keys()`` on it, so a dict/state_dict is
        assumed; confirm for your .pt files.
    """
    # The file is read only by the loader itself: pre-reading it into a separate
    # buffer would cost an extra full copy of the file in memory for no benefit.
    if file_path.endswith('.safetensors'):
        return safe_load(file_path)

    # torch is not imported at module level, so import it here; it is only
    # needed for the .pt path.
    import torch

    # SECURITY NOTE(review): torch.load unpickles the file and can execute
    # arbitrary code from an untrusted .pt file. Consider weights_only=True
    # (supported by recent PyTorch releases) if these files may come from
    # untrusted sources.
    # map_location="cpu" matches safe_load's default device, so both formats load
    # onto CPU even if the .pt file was saved from a GPU.
    return torch.load(file_path, map_location="cpu")
def confirm_settings(settings):
    """Print every setting to the console and confirm them without asking.

    Args:
        settings: Mapping of setting name to value, printed in iteration order
            under a "LOADING SETTING:" header.

    Returns:
        bool: Always ``True``; no user input is requested.
    """
    console.print("----\n[bold green]LOADING SETTING:[/bold green]")
    for name in settings:
        line = f"[bold magenta]{name}:[/bold magenta] {settings[name]}"
        console.print(line)
    return True