Skip to content

Commit

Permalink
refactor: move some strings to localizations
Browse files Browse the repository at this point in the history
  • Loading branch information
leafspark committed Aug 6, 2024
1 parent ab7ffb0 commit eca2ecc
Show file tree
Hide file tree
Showing 3 changed files with 55 additions and 62 deletions.
112 changes: 51 additions & 61 deletions src/AutoGGUF.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,19 @@
from imports_and_globals import ensure_directory, open_file_safe, resource_path
from localizations import *

from functools import wraps


def handle_load_preset_error(func):
@wraps(func)
def wrapper(self, *args, **kwargs):
try:
return func(self, *args, **kwargs)
except Exception as e:
QMessageBox.critical(self, ERROR, FAILED_TO_LOAD_PRESET.format(str(e)))

return wrapper


class AutoGGUF(QMainWindow):
def __init__(self):
Expand Down Expand Up @@ -679,53 +692,49 @@ def save_preset(self):
)
self.logger.info(PRESET_SAVED_TO.format(file_name))

@handle_load_preset_error
def load_preset(self):
self.logger.info(LOADING_PRESET)
file_name, _ = QFileDialog.getOpenFileName(self, LOAD_PRESET, "", JSON_FILES)
if file_name:
try:
with open(file_name, "r") as f:
preset = json.load(f)

self.quant_type.clearSelection()
for quant_type in preset.get("quant_types", []):
items = self.quant_type.findItems(quant_type, Qt.MatchExactly)
if items:
items[0].setSelected(True)
self.allow_requantize.setChecked(preset.get("allow_requantize", False))
self.leave_output_tensor.setChecked(
preset.get("leave_output_tensor", False)
)
self.pure.setChecked(preset.get("pure", False))
self.imatrix.setText(preset.get("imatrix", ""))
self.include_weights.setText(preset.get("include_weights", ""))
self.exclude_weights.setText(preset.get("exclude_weights", ""))
self.use_output_tensor_type.setChecked(
preset.get("use_output_tensor_type", False)
)
self.output_tensor_type.setCurrentText(
preset.get("output_tensor_type", "")
)
self.use_token_embedding_type.setChecked(
preset.get("use_token_embedding_type", False)
)
self.token_embedding_type.setCurrentText(
preset.get("token_embedding_type", "")
)
self.keep_split.setChecked(preset.get("keep_split", False))
self.extra_arguments.setText(preset.get("extra_arguments", ""))
with open(file_name, "r") as f:
preset = json.load(f)

self.quant_type.clearSelection()
for quant_type in preset.get("quant_types", []):
items = self.quant_type.findItems(quant_type, Qt.MatchExactly)
if items:
items[0].setSelected(True)
self.allow_requantize.setChecked(preset.get("allow_requantize", False))
self.leave_output_tensor.setChecked(
preset.get("leave_output_tensor", False)
)
self.pure.setChecked(preset.get("pure", False))
self.imatrix.setText(preset.get("imatrix", ""))
self.include_weights.setText(preset.get("include_weights", ""))
self.exclude_weights.setText(preset.get("exclude_weights", ""))
self.use_output_tensor_type.setChecked(
preset.get("use_output_tensor_type", False)
)
self.output_tensor_type.setCurrentText(preset.get("output_tensor_type", ""))
self.use_token_embedding_type.setChecked(
preset.get("use_token_embedding_type", False)
)
self.token_embedding_type.setCurrentText(
preset.get("token_embedding_type", "")
)
self.keep_split.setChecked(preset.get("keep_split", False))
self.extra_arguments.setText(preset.get("extra_arguments", ""))

# Clear existing KV overrides and add new ones
for entry in self.kv_override_entries:
self.remove_kv_override(entry)
for override in preset.get("kv_overrides", []):
self.add_kv_override(override)
# Clear existing KV overrides and add new ones
for entry in self.kv_override_entries:
self.remove_kv_override(entry)
for override in preset.get("kv_overrides", []):
self.add_kv_override(override)

QMessageBox.information(
self, PRESET_LOADED, PRESET_LOADED_FROM.format(file_name)
)
except Exception as e:
QMessageBox.critical(self, ERROR, FAILED_TO_LOAD_PRESET.format(str(e)))
QMessageBox.information(
self, PRESET_LOADED, PRESET_LOADED_FROM.format(file_name)
)
self.logger.info(PRESET_LOADED_FROM.format(file_name))

def save_task_preset(self, task_item):
Expand Down Expand Up @@ -812,9 +821,7 @@ def delete_lora_adapter_item(self, adapter_widget):

def browse_hf_model_input(self):
self.logger.info("Browsing for HuggingFace model directory")
model_dir = QFileDialog.getExistingDirectory(
self, "Select HuggingFace Model Directory"
)
model_dir = QFileDialog.getExistingDirectory(self, SELECT_HF_MODEL_DIRECTORY)
if model_dir:
self.hf_model_input.setText(os.path.abspath(model_dir))

Expand Down Expand Up @@ -1340,11 +1347,6 @@ def cancel_task(self, item):
task_item.update_status(CANCELED)
break

def retry_task(self, item):
task_item = self.task_list.itemWidget(item)
# TODO: Implement the logic to restart the task
pass

def delete_task(self, item):
self.logger.info(DELETING_TASK.format(item.text()))
reply = QMessageBox.question(
Expand Down Expand Up @@ -1666,7 +1668,6 @@ def quantize_model(self):

def update_model_info(self, model_info):
self.logger.debug(UPDATING_MODEL_INFO.format(model_info))
# TODO: Do something with this
pass

def parse_progress(self, line, task_item):
Expand Down Expand Up @@ -1734,17 +1735,6 @@ def browse_imatrix_output(self):
if output_file:
self.imatrix_output.setText(os.path.abspath(output_file))

def update_gpu_offload_spinbox(self, value):
self.gpu_offload_spinbox.setValue(value)

def update_gpu_offload_slider(self, value):
self.gpu_offload_slider.setValue(value)

def toggle_gpu_offload_auto(self, state):
is_auto = state == Qt.CheckState.Checked
self.gpu_offload_slider.setEnabled(not is_auto)
self.gpu_offload_spinbox.setEnabled(not is_auto)

def generate_imatrix(self):
self.logger.info(STARTING_IMATRIX_GENERATION)
try:
Expand Down
3 changes: 2 additions & 1 deletion src/error_handling.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from PyQt6.QtWidgets import QMessageBox
from localizations import *


def show_error(logger, message):
logger.error(ERROR_MESSAGE.format(message))
QMessageBox.critical(None, ERROR, message)
Expand All @@ -9,4 +10,4 @@ def show_error(logger, message):
def handle_error(logger, error_message, task_item):
logger.error(TASK_ERROR.format(error_message))
show_error(logger, error_message)
task_item.update_status(ERROR)
task_item.update_status(ERROR)
2 changes: 2 additions & 0 deletions src/localizations.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,7 @@ def __init__(self):
self.SPLIT_MAX_SIZE = ""
self.DRY_RUN = ""
self.CONVERT_HF_TO_GGUF = ""
self.SELECT_HF_MODEL_DIRECTORY = ""

# General Messages
self.ERROR = ""
Expand Down Expand Up @@ -372,6 +373,7 @@ def __init__(self):
self.SPLIT_MAX_SIZE = "Split Max Size:"
self.DRY_RUN = "Dry Run"
self.CONVERT_HF_TO_GGUF = "Convert HF to GGUF"
self.SELECT_HF_MODEL_DIRECTORY = "Select HuggingFace Model Directory"

# General Messages
self.ERROR = "Error"
Expand Down

0 comments on commit eca2ecc

Please sign in to comment.