Skip to content

Commit

Permalink
feat(ui): add AutoFP8 quantization window
Browse files Browse the repository at this point in the history
- add AutoFP8 quantization window (currently broken)
- add more dynamic KV parameters
  • Loading branch information
leafspark committed Sep 3, 2024
1 parent e43bc48 commit a7f2dec
Show file tree
Hide file tree
Showing 4 changed files with 146 additions and 13 deletions.
2 changes: 2 additions & 0 deletions .github/workflows/build.yml
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,7 @@ jobs:
Copy-Item -Path "src\convert_hf_to_gguf.py" -Destination "$distPath\src"
Copy-Item -Path "src\convert_lora_to_gguf.py" -Destination "$distPath\src"
Copy-Item -Path "src\convert_lora_to_ggml.py" -Destination "$distPath\src"
Copy-Item -Path "src\quantize_to_fp8_dynamic.py" -Destination "$distPath\src"
- name: Copy additional files (Linux/macOS)
if: matrix.os != 'windows-latest'
Expand All @@ -76,6 +77,7 @@ jobs:
cp src/convert_hf_to_gguf.py $distPath/src/
cp src/convert_lora_to_gguf.py $distPath/src/
cp src/convert_lora_to_ggml.py $distPath/src/
cp src/quantize_to_fp8_dynamic.py $distPath/src/
- name: Generate SHA256 (Windows)
if: matrix.os == 'windows-latest'
Expand Down
99 changes: 99 additions & 0 deletions src/AutoGGUF.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,6 +151,12 @@ def __init__(self, args: List[str]) -> None:
about_action.triggered.connect(self.show_about)
help_menu.addAction(about_action)

# Tools menu
tools_menu = self.menubar.addMenu("&Tools")
autofp8_action = QAction("&AutoFP8", self)
autofp8_action.triggered.connect(self.show_autofp8_window)
tools_menu.addAction(autofp8_action)

# Content widget
content_widget = QWidget()
content_layout = QHBoxLayout(content_widget)
Expand Down Expand Up @@ -1010,6 +1016,91 @@ def browse_hf_outfile(self) -> None:
if outfile:
self.hf_outfile.setText(os.path.abspath(outfile))

def quantize_to_fp8_dynamic(self, model_dir: str, output_dir: str) -> None:
self.logger.info(f"Quantizing {os.path.basename(model_dir)} to {output_dir}")
try:
command = [
"python",
"src/quantize_to_fp8_dynamic.py",
model_dir,
output_dir,
]

logs_path = self.logs_input.text()
ensure_directory(logs_path)

timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
log_file = os.path.join(logs_path, f"autofp8_{timestamp}.log")

thread = QuantizationThread(command, os.getcwd(), log_file)
self.quant_threads.append(thread)

task_name = f"Quantizing {os.path.basename(model_dir)} with AutoFP8"
task_item = TaskListItem(task_name, log_file, show_progress_bar=False)
list_item = QListWidgetItem(self.task_list)
list_item.setSizeHint(task_item.sizeHint())
self.task_list.addItem(list_item)
self.task_list.setItemWidget(list_item, task_item)

thread.status_signal.connect(task_item.update_status)
thread.finished_signal.connect(
lambda: self.task_finished(thread, task_item)
)
thread.error_signal.connect(
lambda err: handle_error(self.logger, err, task_item)
)
thread.start()

except Exception as e:
show_error(self.logger, f"Error starting AutoFP8 quantization: {e}")
self.logger.info("AutoFP8 quantization task started")

def show_autofp8_window(self):
dialog = QDialog(self)
dialog.setWindowTitle("Quantize to FP8 Dynamic")
dialog.setFixedWidth(500)
layout = QVBoxLayout()

# Input path
input_layout = QHBoxLayout()
self.fp8_input = QLineEdit()
input_button = QPushButton(BROWSE)
input_button.clicked.connect(
lambda: self.fp8_input.setText(
QFileDialog.getExistingDirectory(self, "Open Model Folder")
)
)
input_layout.addWidget(QLabel("Input Model:"))
input_layout.addWidget(self.fp8_input)
input_layout.addWidget(input_button)
layout.addLayout(input_layout)

# Output path
output_layout = QHBoxLayout()
self.fp8_output = QLineEdit()
output_button = QPushButton(BROWSE)
output_button.clicked.connect(
lambda: self.fp8_output.setText(
QFileDialog.getExistingDirectory(self, "Open Model Folder")
)
)
output_layout.addWidget(QLabel("Output Path:"))
output_layout.addWidget(self.fp8_output)
output_layout.addWidget(output_button)
layout.addLayout(output_layout)

# Quantize button
quantize_button = QPushButton("Quantize")
quantize_button.clicked.connect(
lambda: self.quantize_to_fp8_dynamic(
self.fp8_input.text(), self.fp8_output.text()
)
)
layout.addWidget(quantize_button)

dialog.setLayout(layout)
dialog.exec()

def convert_hf_to_gguf(self) -> None:
self.logger.info(STARTING_HF_TO_GGUF_CONVERSION)
try:
Expand Down Expand Up @@ -1346,10 +1437,12 @@ def quantize_model(self) -> None:
output_name_parts.append("rq")

# Check for KV override
kv_used = bool
if any(
entry.get_override_string() for entry in self.kv_override_entries
):
output_name_parts.append("kv")
kv_used = True

# Join all parts with underscores and add .gguf extension
output_name = "_".join(output_name_parts) + ".gguf"
Expand Down Expand Up @@ -1391,6 +1484,12 @@ def quantize_model(self) -> None:
model_name=model_name,
quant_type=quant_type,
output_path=output_path,
quantization_parameters=[
kv_used, # If KV overrides are used
self.allow_requantize.isChecked(), # If requantize is used
self.pure.isChecked(), # If pure tensors option is used
self.leave_output_tensor.isChecked(), # If leave output tensor option is used
],
)
if override_string:
command.extend(["--override-kv", override_string])
Expand Down
37 changes: 34 additions & 3 deletions src/KVOverrideEntry.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,12 @@
import locale
import shutil

import psutil
from PySide6.QtWidgets import QWidget, QHBoxLayout, QLineEdit, QComboBox, QPushButton
from PySide6.QtCore import Signal, QRegularExpression
from PySide6.QtGui import QDoubleValidator, QIntValidator, QRegularExpressionValidator
from datetime import datetime
import pytz
import time
import os
import socket
Expand All @@ -24,7 +29,7 @@ def __init__(self, parent=None) -> None:
layout.addWidget(self.key_input)

self.type_combo = QComboBox()
self.type_combo.addItems(["int", "str", "float"])
self.type_combo.addItems(["int", "str", "float", "u32", "i32"])
layout.addWidget(self.type_combo)

self.value_input = QLineEdit()
Expand All @@ -46,7 +51,11 @@ def delete_clicked(self) -> None:
self.deleted.emit(self)

def get_override_string(
self, model_name=None, quant_type=None, output_path=None
self,
model_name=None,
quant_type=None,
output_path=None,
quantization_parameters=None,
) -> str: # Add arguments
key = self.key_input.text()
type_ = self.type_combo.currentText()
Expand All @@ -61,7 +70,14 @@ def get_override_string(
"{system.hostname}": lambda: socket.gethostname(),
"{system.platform}": lambda: platform.system(),
"{system.python.version}": lambda: platform.python_version(),
"{system.date}": lambda: datetime.now().strftime("%Y-%m-%d"),
"{system.timezone}": lambda: time.tzname[time.daylight],
"{system.cpus}": lambda: str(os.cpu_count()),
"{system.memory.total}": lambda: str(psutil.virtual_memory().total),
"{system.memory.free}": lambda: str(psutil.virtual_memory().free),
"{system.filesystem.used}": lambda: str(shutil.disk_usage("/").used),
"{system.kernel.version}": lambda: platform.release(),
"{system.locale}": lambda: locale.getdefaultlocale()[0],
"{process.nice}": lambda: str(os.nice(0)),
"{model.name}": lambda: (
model_name if model_name is not None else "Unknown Model"
),
Expand All @@ -71,6 +87,21 @@ def get_override_string(
"{output.path}": lambda: (
output_path if output_path is not None else "Unknown Output Path"
),
"{quant.kv}": lambda: (
quantization_parameters[0]
if quantization_parameters is not None
else False
),
"{quant.requantized}": lambda: (
quantization_parameters[1]
if quantization_parameters is not None
else False
),
"{quant.leave_output_tensor}": lambda: (
quantization_parameters[2]
if quantization_parameters is not None
else False
),
}

for param, func in dynamic_params.items():
Expand Down
21 changes: 11 additions & 10 deletions src/AutoFP8.py → src/quantize_to_fp8_dynamic.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import copy
import gc
import re
import sys
from typing import List
from typing import Optional, Tuple

Expand Down Expand Up @@ -280,12 +281,11 @@ def _prepare_calibration_data(calibration_tokens):
_prepare_calibration_data(calibration_tokens),
)

def save_quantized(self, save_dir, logger):
def save_quantized(self, save_dir):
save_quantized_model(
self.model,
quant_config=self.quantize_config,
save_dir=save_dir,
logger=logger,
)


Expand Down Expand Up @@ -489,10 +489,9 @@ def save_quantized_model(
model: AutoModelForCausalLM,
quant_config: BaseQuantizeConfig,
save_dir: str,
logger: Logger,
):
logger.info(model)
logger.info(f"Saving the model to {save_dir}")
print(model)
print(f"Saving the model to {save_dir}")
static_q_dict = {
"quantization_config": {
"quant_method": "fp8",
Expand Down Expand Up @@ -544,10 +543,8 @@ def get_kv_cache_quant_layers(model, kv_cache_quant_targets: Tuple[str]) -> List
return kv_cache_quant_layers


def quantize_to_fp8_dynamic(
input_model_dir: str, output_model_dir: str, logger: Logger
) -> None:
logger.info("Starting fp8 dynamic quantization")
def quantize_to_fp8_dynamic(input_model_dir: str, output_model_dir: str) -> None:
print("Starting fp8 dynamic quantization")
# Define quantization config with static activation scales
quantize_config = BaseQuantizeConfig(
quant_method="fp8", activation_scheme="dynamic"
Expand All @@ -557,4 +554,8 @@ def quantize_to_fp8_dynamic(
model = AutoFP8ForCausalLM.from_pretrained(input_model_dir, quantize_config)
# No examples for dynamic quantization
model.quantize([])
model.save_quantized(output_model_dir, logger)
model.save_quantized(output_model_dir)


if __name__ == "__main__":
quantize_to_fp8_dynamic(sys.argv[0], sys.argv[1])

0 comments on commit a7f2dec

Please sign in to comment.