Skip to content

Commit

Permalink
style: format code with Black
Browse files Browse the repository at this point in the history
  • Loading branch information
leafspark committed Aug 5, 2024
1 parent 2dc5bd9 commit fa51f7c
Show file tree
Hide file tree
Showing 21 changed files with 8,226 additions and 6,933 deletions.
43 changes: 26 additions & 17 deletions src/AutoGGUF.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ def __init__(self):

self.logger.info(INITIALIZING_AUTOGGUF)
self.setWindowTitle(WINDOW_TITLE)
self.setWindowIcon(QIcon(resource_path("assets/favicon.ico")))
self.setWindowIcon(QIcon(resource_path("assets/favicon.ico")))
self.setGeometry(100, 100, 1600, 1200)

ensure_directory(os.path.abspath("quantized_models"))
Expand Down Expand Up @@ -171,7 +171,7 @@ def __init__(self):
"Q5_K_S",
"Q5_K_M",
"Q6_K",
"Q8_0",
"Q8_0",
"Q4_0",
"Q4_1",
"Q5_0",
Expand All @@ -180,7 +180,7 @@ def __init__(self):
"Q4_0_4_8",
"Q4_0_8_8",
"BF16",
"F16",
"F16",
"F32",
"COPY",
]
Expand Down Expand Up @@ -452,8 +452,13 @@ def __init__(self):
# Output Type Dropdown
self.lora_output_type_combo = QComboBox()
self.lora_output_type_combo.addItems(["GGML", "GGUF"])
self.lora_output_type_combo.currentIndexChanged.connect(self.update_base_model_visibility)
lora_layout.addRow(self.create_label(OUTPUT_TYPE, SELECT_OUTPUT_TYPE), self.lora_output_type_combo)
self.lora_output_type_combo.currentIndexChanged.connect(
self.update_base_model_visibility
)
lora_layout.addRow(
self.create_label(OUTPUT_TYPE, SELECT_OUTPUT_TYPE),
self.lora_output_type_combo,
)

# Base Model Path (initially hidden)
self.base_model_label = self.create_label(BASE_MODEL, SELECT_BASE_MODEL_FILE)
Expand All @@ -471,7 +476,9 @@ def __init__(self):
wrapper_layout = QHBoxLayout(self.base_model_wrapper)
wrapper_layout.addWidget(self.base_model_label)
wrapper_layout.addWidget(self.base_model_widget, 1) # Give it a stretch factor
wrapper_layout.setContentsMargins(0, 0, 0, 0) # Remove margins for better alignment
wrapper_layout.setContentsMargins(
0, 0, 0, 0
) # Remove margins for better alignment

# Add the wrapper to the layout
lora_layout.addRow(self.base_model_wrapper)
Expand Down Expand Up @@ -545,7 +552,7 @@ def __init__(self):
# Modify the task list to support right-click menu
self.task_list.setContextMenuPolicy(Qt.ContextMenuPolicy.CustomContextMenu)
self.task_list.customContextMenuRequested.connect(self.show_task_context_menu)

# Set inital state
self.update_base_model_visibility(self.lora_output_type_combo.currentIndex())

Expand Down Expand Up @@ -1200,19 +1207,19 @@ def delete_task(self, item):
if reply == QMessageBox.StandardButton.Yes:
# Retrieve the task_item before removing it from the list
task_item = self.task_list.itemWidget(item)

# Remove the item from the list
row = self.task_list.row(item)
self.task_list.takeItem(row)

# If the task is still running, terminate it
if task_item and task_item.log_file:
for thread in self.quant_threads:
if thread.log_file == task_item.log_file:
thread.terminate()
self.quant_threads.remove(thread)
break
break

# Delete the task_item widget
if task_item:
task_item.deleteLater()
Expand Down Expand Up @@ -1395,7 +1402,7 @@ def quantize_model(self):
override_string = entry.get_override_string(
model_name=model_name,
quant_type=quant_type,
output_path=output_path
output_path=output_path,
)
if override_string:
command.extend(["--override-kv", override_string])
Expand All @@ -1413,7 +1420,7 @@ def quantize_model(self):
log_file = os.path.join(
logs_path, f"{model_name}_{timestamp}_{quant_type}.log"
)

# Log quant command
command_str = " ".join(command)
self.logger.info(f"{QUANTIZATION_COMMAND}: {command_str}")
Expand All @@ -1430,7 +1437,9 @@ def quantize_model(self):
self.task_list.setItemWidget(list_item, task_item)

# Connect the output signal to the new progress parsing function
thread.output_signal.connect(lambda line: self.parse_progress(line, task_item))
thread.output_signal.connect(
lambda line: self.parse_progress(line, task_item)
)
thread.status_signal.connect(task_item.update_status)
thread.finished_signal.connect(lambda: self.task_finished(thread))
thread.error_signal.connect(lambda err: self.handle_error(err, task_item))
Expand Down Expand Up @@ -1556,7 +1565,7 @@ def generate_imatrix(self):

timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
log_file = os.path.join(self.logs_input.text(), f"imatrix_{timestamp}.log")

# Log command
command_str = " ".join(command)
self.logger.info(f"{IMATRIX_GENERATION_COMMAND}: {command_str}")
Expand All @@ -1580,7 +1589,7 @@ def generate_imatrix(self):
except Exception as e:
self.show_error(ERROR_STARTING_IMATRIX_GENERATION.format(str(e)))
self.logger.info(IMATRIX_GENERATION_TASK_STARTED)

def show_error(self, message):
self.logger.error(ERROR_MESSAGE.format(message))
QMessageBox.critical(self, ERROR, message)
Expand Down Expand Up @@ -1617,4 +1626,4 @@ def closeEvent(self, event: QCloseEvent):
app = QApplication(sys.argv)
window = AutoGGUF()
window.show()
sys.exit(app.exec())
sys.exit(app.exec())
109 changes: 55 additions & 54 deletions src/DownloadThread.py
Original file line number Diff line number Diff line change
@@ -1,54 +1,55 @@
from PyQt6.QtWidgets import *
from PyQt6.QtCore import *
from PyQt6.QtGui import *
import os
import sys
import psutil
import subprocess
import time
import signal
import json
import platform
import requests
import zipfile
from datetime import datetime

class DownloadThread(QThread):
progress_signal = pyqtSignal(int)
finished_signal = pyqtSignal(str)
error_signal = pyqtSignal(str)

def __init__(self, url, save_path):
super().__init__()
self.url = url
self.save_path = save_path

def run(self):
try:
response = requests.get(self.url, stream=True)
response.raise_for_status()
total_size = int(response.headers.get('content-length', 0))
block_size = 8192
downloaded = 0

with open(self.save_path, 'wb') as file:
for data in response.iter_content(block_size):
size = file.write(data)
downloaded += size
if total_size:
progress = int((downloaded / total_size) * 100)
self.progress_signal.emit(progress)

# Extract the downloaded zip file
extract_dir = os.path.splitext(self.save_path)[0]
with zipfile.ZipFile(self.save_path, 'r') as zip_ref:
zip_ref.extractall(extract_dir)

# Remove the zip file after extraction
os.remove(self.save_path)

self.finished_signal.emit(extract_dir)
except Exception as e:
self.error_signal.emit(str(e))
if os.path.exists(self.save_path):
os.remove(self.save_path)
from PyQt6.QtWidgets import *
from PyQt6.QtCore import *
from PyQt6.QtGui import *
import os
import sys
import psutil
import subprocess
import time
import signal
import json
import platform
import requests
import zipfile
from datetime import datetime


class DownloadThread(QThread):
progress_signal = pyqtSignal(int)
finished_signal = pyqtSignal(str)
error_signal = pyqtSignal(str)

def __init__(self, url, save_path):
super().__init__()
self.url = url
self.save_path = save_path

def run(self):
try:
response = requests.get(self.url, stream=True)
response.raise_for_status()
total_size = int(response.headers.get("content-length", 0))
block_size = 8192
downloaded = 0

with open(self.save_path, "wb") as file:
for data in response.iter_content(block_size):
size = file.write(data)
downloaded += size
if total_size:
progress = int((downloaded / total_size) * 100)
self.progress_signal.emit(progress)

# Extract the downloaded zip file
extract_dir = os.path.splitext(self.save_path)[0]
with zipfile.ZipFile(self.save_path, "r") as zip_ref:
zip_ref.extractall(extract_dir)

# Remove the zip file after extraction
os.remove(self.save_path)

self.finished_signal.emit(extract_dir)
except Exception as e:
self.error_signal.emit(str(e))
if os.path.exists(self.save_path):
os.remove(self.save_path)
Loading

0 comments on commit fa51f7c

Please sign in to comment.