Skip to content

Commit

Permalink
fix: use proper status in TaskListItem
Browse files (browse the repository at this point in the history)
- use proper status in TaskListItem
- make sure to pass quant_threads and Logger to TaskListItem
- remove unnecessary logging in quantize_to_fp8_dynamic.py and optimize imports
  • Loading branch information
leafspark committed Sep 3, 2024
1 parent a7f2dec commit a91f804
Show file tree
Hide file tree
Showing 3 changed files with 51 additions and 19 deletions.
30 changes: 25 additions & 5 deletions src/AutoGGUF.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,6 @@ def __init__(self, args: List[str]) -> None:
self.delete_task = partial(TaskListItem.delete_task, self)
self.show_task_context_menu = partial(TaskListItem.show_task_context_menu, self)
self.show_task_properties = partial(TaskListItem.show_task_properties, self)
self.cancel_task_by_item = partial(TaskListItem.cancel_task_by_item, self)
self.toggle_gpu_offload_auto = partial(ui_update.toggle_gpu_offload_auto, self)
self.update_threads_spinbox = partial(ui_update.update_threads_spinbox, self)
self.update_threads_slider = partial(ui_update.update_threads_slider, self)
Expand Down Expand Up @@ -1036,7 +1035,13 @@ def quantize_to_fp8_dynamic(self, model_dir: str, output_dir: str) -> None:
self.quant_threads.append(thread)

task_name = f"Quantizing {os.path.basename(model_dir)} with AutoFP8"
task_item = TaskListItem(task_name, log_file, show_progress_bar=False)
task_item = TaskListItem(
task_name,
log_file,
show_progress_bar=False,
logger=self.logger,
quant_threads=self.quant_threads,
)
list_item = QListWidgetItem(self.task_list)
list_item.setSizeHint(task_item.sizeHint())
self.task_list.addItem(list_item)
Expand Down Expand Up @@ -1152,7 +1157,13 @@ def convert_hf_to_gguf(self) -> None:
self.quant_threads.append(thread)

task_name = CONVERTING_TO_GGUF.format(os.path.basename(model_dir))
task_item = TaskListItem(task_name, log_file, show_progress_bar=False)
task_item = TaskListItem(
task_name,
log_file,
show_progress_bar=False,
logger=self.logger,
quant_threads=self.quant_threads,
)
list_item = QListWidgetItem(self.task_list)
list_item.setSizeHint(task_item.sizeHint())
self.task_list.addItem(list_item)
Expand Down Expand Up @@ -1516,7 +1527,10 @@ def quantize_model(self) -> None:
self.quant_threads.append(thread)

task_item = TaskListItem(
QUANTIZING_MODEL_TO.format(model_name, quant_type), log_file
QUANTIZING_MODEL_TO.format(model_name, quant_type),
log_file,
show_properties=True,
logger=self.logger,
)
list_item = QListWidgetItem(self.task_list)
list_item.setSizeHint(task_item.sizeHint())
Expand Down Expand Up @@ -1687,7 +1701,13 @@ def generate_imatrix(self) -> None:
task_name = GENERATING_IMATRIX_FOR.format(
os.path.basename(self.imatrix_model.text())
)
task_item = TaskListItem(task_name, log_file, show_progress_bar=False)
task_item = TaskListItem(
task_name,
log_file,
show_progress_bar=False,
logger=self.logger,
quant_threads=self.quant_threads,
)
list_item = QListWidgetItem(self.task_list)
list_item.setSizeHint(task_item.sizeHint())
self.task_list.addItem(list_item)
Expand Down
38 changes: 26 additions & 12 deletions src/TaskListItem.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from typing import List

from PySide6.QtCore import *
from PySide6.QtGui import QAction
from PySide6.QtWidgets import *
Expand All @@ -15,19 +17,34 @@
SHOWING_PROPERTIES_FOR_TASK,
DELETE,
RESTART,
IN_PROGRESS,
ERROR,
)
from ModelInfoDialog import ModelInfoDialog
from QuantizationThread import QuantizationThread
from Logger import Logger


class TaskListItem(QWidget):
def __init__(
self, task_name, log_file, show_progress_bar=True, parent=None
self,
task_name,
log_file,
show_progress_bar=True,
parent=None,
show_properties=False,
logger=Logger,
quant_threads=List[QuantizationThread],
) -> None:
super().__init__(parent)
self.quant_threads = quant_threads
self.task_name = task_name
self.log_file = log_file
self.logger = logger
self.show_properties = show_properties
self.status = "Pending"
layout = QHBoxLayout(self)

self.task_label = QLabel(task_name)
self.progress_bar = QProgressBar()
self.progress_bar.setRange(0, 100)
Expand Down Expand Up @@ -84,7 +101,8 @@ def show_task_properties(self, item) -> None:
model_info_dialog.exec()
break

def cancel_task_by_item(self, item) -> None:
def cancel_task(self, item) -> None:
self.logger.info(CANCELLING_TASK.format(item.text()))
task_item = self.task_list.itemWidget(item)
for thread in self.quant_threads:
if thread.log_file == task_item.log_file:
Expand All @@ -93,15 +111,11 @@ def cancel_task_by_item(self, item) -> None:
self.quant_threads.remove(thread)
break

def cancel_task(self, item) -> None:
self.logger.info(CANCELLING_TASK.format(item.text()))
self.cancel_task_by_item(item)

def delete_task(self, item) -> None:
self.logger.info(DELETING_TASK.format(item.text()))

# Cancel the task first
self.cancel_task_by_item(item)
self.cancel_task(item)

reply = QMessageBox.question(
self,
Expand All @@ -121,21 +135,21 @@ def delete_task(self, item) -> None:
def update_status(self, status) -> None:
self.status = status
self.status_label.setText(status)
if status == "In Progress":
if status == IN_PROGRESS:
# Only start timer if showing percentage progress
if self.progress_bar.isVisible():
self.progress_bar.setRange(0, 100)
self.progress_timer.start(100)
elif status == "Completed":
elif status == COMPLETED:
self.progress_timer.stop()
self.progress_bar.setValue(100)
elif status == "Canceled":
elif status == CANCELED:
self.progress_timer.stop()
self.progress_bar.setValue(0)

def set_error(self) -> None:
self.status = "Error"
self.status_label.setText("Error")
self.status = ERROR
self.status_label.setText(ERROR)
self.status_label.setStyleSheet("color: red;")
self.progress_bar.setRange(0, 100)
self.progress_timer.stop()
Expand Down
2 changes: 0 additions & 2 deletions src/quantize_to_fp8_dynamic.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@
import tqdm
from transformers import AutoModelForCausalLM, AutoTokenizer

from Logger import Logger

# https://github.com/neuralmagic/AutoFP8

Expand Down Expand Up @@ -544,7 +543,6 @@ def get_kv_cache_quant_layers(model, kv_cache_quant_targets: Tuple[str]) -> List


def quantize_to_fp8_dynamic(input_model_dir: str, output_model_dir: str) -> None:
print("Starting fp8 dynamic quantization")
# Define quantization config with static activation scales
quantize_config = BaseQuantizeConfig(
quant_method="fp8", activation_scheme="dynamic"
Expand Down

0 comments on commit a91f804

Please sign in to comment.