Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

save captioning in history.jsonl file with settings and selected images #187

Draft
wants to merge 7 commits into
base: main
Choose a base branch
from
Draft
1 change: 1 addition & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ pillow==10.3.0
pyparsing==3.1.2
PySide6==6.7.1
transformers==4.41.2
gitpython==3.1.43

# PyTorch
torch==2.2.2; platform_system != "Windows"
Expand Down
6 changes: 6 additions & 0 deletions taggui/auto_captioning/captioning_thread.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
get_xcomposer2_error_message,
get_xcomposer2_inputs)
from models.image_list_model import ImageListModel
from widgets.history_list import HistoryListModel
from utils.enums import CaptionDevice, CaptionModelType, CaptionPosition
from utils.image import Image
from utils.settings import get_tag_separator
Expand Down Expand Up @@ -143,10 +144,12 @@ class CaptioningThread(QThread):

def __init__(self, parent, image_list_model: ImageListModel,
selected_image_indices: list[QModelIndex],
history_list_model: HistoryListModel,
caption_settings: dict, tag_separator: str,
models_directory_path: Path | None):
super().__init__(parent)
self.image_list_model = image_list_model
self.history_list_model = history_list_model
self.selected_image_indices = selected_image_indices
self.caption_settings = caption_settings
self.tag_separator = tag_separator
Expand Down Expand Up @@ -396,6 +399,9 @@ def run(self):
print(error_message)
return
processor, model = self.load_processor_and_model(device, model_type)

self.history_list_model.append(self.caption_settings, model, self.image_list_model, self.selected_image_indices)

# CogVLM and CogAgent have to be monkey patched every time because
# `caption_start` might have changed.
caption_start = self.caption_settings['caption_start']
Expand Down
2 changes: 2 additions & 0 deletions taggui/utils/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@
'insert_space_after_tag_separator': True,
'autocomplete_tags': True,
'models_directory_path': ''
# directory_path: '' # added by main_window.load_directory
# more added by auto_captioner.get_caption_settings
}


Expand Down
8 changes: 8 additions & 0 deletions taggui/utils/utils.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import git
import sys
from pathlib import Path

Expand Down Expand Up @@ -39,3 +40,10 @@ def get_confirmation_dialog_reply(title: str, question: str) -> int:
| QMessageBox.StandardButton.Cancel)
confirmation_dialog.setDefaultButton(QMessageBox.StandardButton.Yes)
return confirmation_dialog.exec()

def get_repo_infos(path: str) -> dict[str, str]:
    """Return the Git origin URL and HEAD revision for the repo containing `path`.

    The repository is searched for in `path` and its parent directories.

    Args:
        path: A path inside the Git working tree to inspect.

    Returns:
        A dict with the keys "origin" (URL of the remote named "origin") and
        "revision" (hex SHA of the HEAD commit). A value is an empty string
        when it cannot be determined, e.g. when `path` is not inside a Git
        repository, there is no "origin" remote, or HEAD has no commit yet.
    """
    infos = {"origin": "", "revision": ""}
    try:
        repo = git.Repo(path, search_parent_directories=True)
    except (git.InvalidGitRepositoryError, git.NoSuchPathError):
        # Not running from a Git checkout: report "unknown" instead of
        # letting the exception propagate to the caller.
        return infos
    # Use the repo as a context manager so its resources are released.
    with repo:
        # Look the remote up by name; attribute access (`remotes.origin`)
        # raises when no remote called "origin" exists.
        infos["origin"] = next(
            (remote.url for remote in repo.remotes
             if remote.name == "origin"), "")
        # A freshly initialized repository has no commit behind HEAD.
        if repo.head.is_valid():
            infos["revision"] = repo.head.commit.hexsha
    return infos
40 changes: 37 additions & 3 deletions taggui/widgets/auto_captioner.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
import sys
from pathlib import Path
import json
from datetime import datetime

from PySide6.QtCore import QModelIndex, Qt, Signal, Slot
from PySide6.QtGui import QFontMetrics, QTextCursor
Expand All @@ -11,6 +13,7 @@
from auto_captioning.captioning_thread import CaptioningThread
from auto_captioning.models import MODELS, get_model_type
from models.image_list_model import ImageListModel
from widgets.history_list import HistoryListModel
from utils.big_widgets import TallPushButton
from utils.enums import CaptionDevice, CaptionModelType, CaptionPosition
from utils.settings import DEFAULT_SETTINGS, get_settings, get_tag_separator
Expand All @@ -37,7 +40,6 @@ def set_text_edit_height(text_edit: QPlainTextEdit, line_count: int):
+ text_edit.frameWidth() * 2)
text_edit.setFixedHeight(height)


class HorizontalLine(QFrame):
def __init__(self):
super().__init__()
Expand Down Expand Up @@ -340,6 +342,37 @@ def get_caption_settings(self) -> dict:
}
}

def set_captions_settings(self, caption_settings: dict):
    """Apply saved caption settings to the widgets of this form.

    Only keys that are present in `caption_settings` (and in its optional
    nested `generation_parameters` and `wd_tagger_settings` dicts) are
    applied; the widgets belonging to missing keys are left untouched.

    Args:
        caption_settings: Settings dict, e.g. the `settings` part of a
            history entry.
    """
    def apply(values: dict, setters: dict):
        # Call each setter whose key is present, in declaration order.
        for key, setter in setters.items():
            if key in values:
                setter(values[key])

    apply(caption_settings, {
        'model': self.model_combo_box.setCurrentText,
        'prompt': self.prompt_text_edit.setPlainText,
        'caption_start': self.caption_start_line_edit.setText,
        'caption_position': self.caption_position_combo_box.setCurrentText,
        'device': self.device_combo_box.setCurrentText,
        'gpu_index': self.gpu_index_spin_box.setValue,
        'load_in_4_bit': self.load_in_4_bit_check_box.setChecked,
        'remove_tag_separators':
            self.remove_tag_separators_check_box.setChecked,
        'bad_words': self.bad_words_line_edit.setText,
        'forced_words': self.forced_words_line_edit.setText,
    })

    apply(caption_settings.get('generation_parameters', {}), {
        'min_new_tokens': self.min_new_token_count_spin_box.setValue,
        'max_new_tokens': self.max_new_token_count_spin_box.setValue,
        'num_beams': self.beam_count_spin_box.setValue,
        'length_penalty': self.length_penalty_spin_box.setValue,
        'do_sample': self.use_sampling_check_box.setChecked,
        'temperature': self.temperature_spin_box.setValue,
        'top_k': self.top_k_spin_box.setValue,
        'top_p': self.top_p_spin_box.setValue,
        'repetition_penalty': self.repetition_penalty_spin_box.setValue,
        'no_repeat_ngram_size':
            self.no_repeat_ngram_size_spin_box.setValue,
    })

    # These must be the setters: calling the getters (`isChecked`, `value`,
    # `toPlainText`) with an argument raises a TypeError.
    apply(caption_settings.get('wd_tagger_settings', {}), {
        'show_probabilities': self.show_probabilities_check_box.setChecked,
        'min_probability': self.min_probability_spin_box.setValue,
        'max_tags': self.max_tags_spin_box.setValue,
        'tags_to_exclude': self.tags_to_exclude_text_edit.setPlainText,
    })

@Slot()
def restore_stdout_and_stderr():
Expand All @@ -351,10 +384,11 @@ class AutoCaptioner(QDockWidget):
caption_generated = Signal(QModelIndex, str, list)

def __init__(self, image_list_model: ImageListModel,
image_list: ImageList):
image_list: ImageList, history_list_model: HistoryListModel):
super().__init__()
self.image_list_model = image_list_model
self.image_list = image_list
self.history_list_model = history_list_model
self.settings = get_settings()
self.is_captioning = False
self.captioning_thread = None
Expand Down Expand Up @@ -467,7 +501,7 @@ def generate_captions(self):
models_directory_path = (Path(models_directory_path)
if models_directory_path else None)
self.captioning_thread = CaptioningThread(
self, self.image_list_model, selected_image_indices,
self, self.image_list_model, selected_image_indices, self.history_list_model,
caption_settings, tag_separator, models_directory_path)
self.captioning_thread.text_outputted.connect(
self.update_console_text_edit)
Expand Down
130 changes: 130 additions & 0 deletions taggui/widgets/history_list.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,130 @@
import json
from pathlib import Path
from datetime import datetime
from typing import Any, Callable, Dict

from transformers import AutoModel
from PySide6.QtCore import (QAbstractListModel, QModelIndex, Qt)
from PySide6.QtWidgets import (QDockWidget, QListView, QVBoxLayout, QWidget)

from auto_captioning.models import get_model_type
from models.image_list_model import ImageListModel
from utils.enums import CaptionModelType

class HistoryListModel(QAbstractListModel):
    """List model holding the captioning history of an image directory.

    Every entry is a dict describing one captioning run: the date, the app
    infos together with the caption settings, the model infos, and the
    captioned images. Entries are persisted as one JSON object per line in
    a history file inside the image directory.
    """

    # Name of the JSON Lines history file inside the image directory.
    HISTORY_FILE_NAME = "!0_history.jsonl"

    def __init__(self, repo_infos):
        """Create an empty history.

        Args:
            repo_infos: Dict that is merged into the `app` section of every
                history entry created by `append`.
        """
        super().__init__()
        self.history_list = []
        self.app_infos = repo_infos
        self.image_directory_path: Path | None = None

    def data(self, index: QModelIndex,
             role: Qt.ItemDataRole = Qt.ItemDataRole.DisplayRole):
        """Return the entry (UserRole) or its label (DisplayRole).

        The label is the entry's date followed by the first 20 characters
        of its prompt. `None` is returned for invalid indices and for all
        other roles.
        """
        if not index.isValid():
            return None
        item = self.history_list[index.row()]
        if role == Qt.ItemDataRole.UserRole:
            return item
        if role == Qt.ItemDataRole.DisplayRole:
            prompt = item['app']['settings']['prompt']
            return f"{item['date']} '{prompt[:20]}'"
        return None

    def rowCount(self, parent: QModelIndex | None = None) -> int:
        """Return the number of history entries."""
        return len(self.history_list)

    def load_directory(self, image_directory_path: Path):
        """Replace the history with the one stored in `image_directory_path`.

        Blank lines in the history file are ignored and lines that are not
        valid JSON are reported and skipped, so a damaged history file does
        not prevent the directory from loading.
        """
        history_path = image_directory_path / self.HISTORY_FILE_NAME
        # Parse the file before resetting the model so that a failure can
        # never leave the model between beginResetModel and endResetModel.
        entries = []
        if history_path.exists():
            with open(history_path, encoding='utf-8') as file:
                for line_number, line in enumerate(file, start=1):
                    line = line.strip()
                    if not line:
                        # E.g. produced by "\r\n" records that were written
                        # in text mode on Windows ("\r\r\n" on disk).
                        continue
                    try:
                        entries.append(json.loads(line))
                    except json.JSONDecodeError as error:
                        print(f'Skipping invalid history entry in '
                              f'{history_path} (line {line_number}): '
                              f'{error}')
        self.beginResetModel()
        self.image_directory_path = image_directory_path
        self.history_list = entries
        self.endResetModel()

    def append(self, caption_settings: dict, model: AutoModel,
               image_list_model: ImageListModel,
               selected_image_indices: list[QModelIndex]) -> None:
        """Record one captioning run in the model and in the history file.

        Args:
            caption_settings: Settings of the run. The dict is not modified.
            model: The loaded captioning model; its `config` (if any) is
                used for the model infos.
            image_list_model: Model providing the images via `images`.
            selected_image_indices: Indices of the captioned images.
        """
        # NOTE(review): this is called from `CaptioningThread.run()`, so
        # presumably not on the GUI thread. Qt expects model changes to be
        # made on the thread owning the model -- consider routing this call
        # through a signal.

        # Function-scope import keeps this block self-contained.
        import copy

        # Deep copy: the nested `generation_parameters` dict is edited
        # below, and the caller keeps using its own dict for the actual
        # generation afterwards.
        settings = copy.deepcopy(caption_settings)
        model_id = settings["model"]
        model_type = get_model_type(model_id)

        # Clean up the settings: drop what is machine specific or unused.
        for key in ("device", "gpu_index"):
            settings.pop(key, None)
        if model_type == CaptionModelType.WD_TAGGER:
            settings.pop("generation_parameters", None)
        else:
            settings.pop("wd_tagger_settings", None)
            generation_parameters = settings.get("generation_parameters", {})
            if not generation_parameters.get("do_sample"):
                # Only these are sampling-only. `repetition_penalty` and
                # `no_repeat_ngram_size` also affect greedy/beam search and
                # are therefore kept.
                for key in ("temperature", "top_k", "top_p"):
                    generation_parameters.pop(key, None)

        # App infos.
        app = {**self.app_infos, "settings": settings}

        # Model infos. Fall back to the selected model ID and the caption
        # model type for models that do not expose a `config`.
        model_config = getattr(model, "config", None)
        model_infos = {
            "name": getattr(model_config, "name_or_path", None) or model_id,
            "type": str(getattr(model_config, "model_type", model_type)),
        }

        # Images, stored relative to the image directory.
        images = []
        if self.image_directory_path is not None:
            images = sorted(
                str(image_list_model.images[index.row()].path.relative_to(
                    self.image_directory_path))
                for index in selected_image_indices)

        entry = {
            "date": datetime.now().strftime('%Y-%m-%d %H:%M:%S'),
            "history_version": 0,
            "app": app,
            "model": model_infos,
            "images": images,
        }

        # Append to the internal list.
        row = self.rowCount()
        self.beginInsertRows(QModelIndex(), row, row)
        self.history_list.append(entry)
        self.endInsertRows()

        # Append to the history file.
        if self.image_directory_path is not None:
            history_path = self.image_directory_path / self.HISTORY_FILE_NAME
            with open(history_path, 'a', encoding='utf-8') as file:
                # Write "\n" only: text mode already translates it to the
                # platform's line ending.
                file.write(json.dumps(entry, separators=(',', ':')) + '\n')

class HistoryList(QDockWidget):
    """Dock widget that shows the history entries of a model in a list."""

    def __init__(self, model: HistoryListModel):
        """Build the dock around a list view displaying `model`."""
        super().__init__()
        # Callback receiving the caption settings of a clicked entry. It is
        # None until it gets assigned from outside.
        self.set_captions_settings: (
            Callable[[Dict[str, Any]], None] | None) = None

        self.setObjectName('history_list')
        self.setWindowTitle('History')
        allowed_areas = (Qt.DockWidgetArea.LeftDockWidgetArea
                         | Qt.DockWidgetArea.RightDockWidgetArea)
        self.setAllowedAreas(allowed_areas)

        content = QWidget()

        self.listView = QListView()
        self.listView.setModel(model)
        self.listView.clicked.connect(self.item_clicked)

        content_layout = QVBoxLayout(content)
        content_layout.addWidget(self.listView)

        self.setWidget(content)

    def item_clicked(self, current: QModelIndex):
        """Pass the caption settings of the current entry to the callback.

        Nothing happens when `current` is invalid or when no callback has
        been assigned.
        """
        if not current.isValid():
            return
        selected_index = self.listView.currentIndex()
        history_entry = self.listView.model().data(selected_index,
                                                   Qt.UserRole)
        settings = history_entry['app']['settings']

        callback = self.set_captions_settings
        if callback is None:
            return
        callback(settings)
20 changes: 18 additions & 2 deletions taggui/widgets/main_window.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,12 +20,13 @@
from utils.key_press_forwarder import KeyPressForwarder
from utils.settings import DEFAULT_SETTINGS, get_settings, get_tag_separator
from utils.shortcut_remover import ShortcutRemover
from utils.utils import get_resource_path, pluralize
from utils.utils import get_repo_infos, get_resource_path, pluralize
from widgets.all_tags_editor import AllTagsEditor
from widgets.auto_captioner import AutoCaptioner
from widgets.image_list import ImageList
from widgets.image_tags_editor import ImageTagsEditor
from widgets.image_viewer import ImageViewer
from widgets.history_list import HistoryList, HistoryListModel

ICON_PATH = Path('images/icon.ico')
GITHUB_REPOSITORY_URL = 'https://github.com/jhc13/taggui'
Expand All @@ -49,6 +50,7 @@ def __init__(self, app: QApplication):
self.image_list_model, tokenizer, tag_separator)
self.image_list_model.proxy_image_list_model = (
self.proxy_image_list_model)
self.history_list_model = HistoryListModel(get_repo_infos(__file__))
self.tag_counter_model = TagCounterModel()
self.image_tag_list_model = ImageTagListModel()

Expand All @@ -64,6 +66,10 @@ def __init__(self, app: QApplication):
tag_separator, image_list_image_width)
self.addDockWidget(Qt.DockWidgetArea.LeftDockWidgetArea,
self.image_list)
self.history_list = HistoryList(self.history_list_model)
self.addDockWidget(Qt.DockWidgetArea.LeftDockWidgetArea,
self.history_list)
self.tabifyDockWidget(self.image_list, self.history_list)
self.image_tags_editor = ImageTagsEditor(
self.proxy_image_list_model, self.tag_counter_model,
self.image_tag_list_model, self.image_list, tokenizer,
Expand All @@ -76,7 +82,7 @@ def __init__(self, app: QApplication):
self.addDockWidget(Qt.DockWidgetArea.RightDockWidgetArea,
self.all_tags_editor)
self.auto_captioner = AutoCaptioner(self.image_list_model,
self.image_list)
self.image_list, self.history_list_model)
self.addDockWidget(Qt.DockWidgetArea.RightDockWidgetArea,
self.auto_captioner)
self.tabifyDockWidget(self.all_tags_editor, self.auto_captioner)
Expand All @@ -99,6 +105,7 @@ def __init__(self, app: QApplication):
self.undo_action = QAction('Undo', parent=self)
self.redo_action = QAction('Redo', parent=self)
self.toggle_image_list_action = QAction('Images', parent=self)
self.toggle_history_list_action = QAction('History', parent=self)
self.toggle_image_tags_editor_action = QAction('Image Tags',
parent=self)
self.toggle_all_tags_editor_action = QAction('All Tags', parent=self)
Expand All @@ -110,6 +117,7 @@ def __init__(self, app: QApplication):
.selectionModel())
self.image_list_model.image_list_selection_model = (
self.image_list_selection_model)
self.history_list.set_captions_settings = self.auto_captioner.caption_settings_form.set_captions_settings
self.connect_image_list_signals()
self.connect_image_tags_editor_signals()
self.connect_all_tags_editor_signals()
Expand Down Expand Up @@ -207,6 +215,7 @@ def load_directory(self, path: Path, select_index: int = 0):
self.settings.setValue('directory_path', str(path))
self.setWindowTitle(path.name)
self.image_list_model.load_directory(path)
self.history_list_model.load_directory(path)
self.image_list.filter_line_edit.clear()
self.all_tags_editor.filter_line_edit.clear()
# Clear the current index first to make sure that the `currentChanged`
Expand Down Expand Up @@ -352,18 +361,22 @@ def create_menus(self):

view_menu = menu_bar.addMenu('View')
self.toggle_image_list_action.setCheckable(True)
self.toggle_history_list_action.setCheckable(True)
self.toggle_image_tags_editor_action.setCheckable(True)
self.toggle_all_tags_editor_action.setCheckable(True)
self.toggle_auto_captioner_action.setCheckable(True)
self.toggle_image_list_action.triggered.connect(
lambda is_checked: self.image_list.setVisible(is_checked))
self.toggle_history_list_action.triggered.connect(
lambda is_checked: self.history_list.setVisible(is_checked))
self.toggle_image_tags_editor_action.triggered.connect(
lambda is_checked: self.image_tags_editor.setVisible(is_checked))
self.toggle_all_tags_editor_action.triggered.connect(
lambda is_checked: self.all_tags_editor.setVisible(is_checked))
self.toggle_auto_captioner_action.triggered.connect(
lambda is_checked: self.auto_captioner.setVisible(is_checked))
view_menu.addAction(self.toggle_image_list_action)
view_menu.addAction(self.toggle_history_list_action)
view_menu.addAction(self.toggle_image_tags_editor_action)
view_menu.addAction(self.toggle_all_tags_editor_action)
view_menu.addAction(self.toggle_auto_captioner_action)
Expand Down Expand Up @@ -459,6 +472,9 @@ def connect_image_list_signals(self):
self.image_list.visibilityChanged.connect(
lambda: self.toggle_image_list_action.setChecked(
self.image_list.isVisible()))
self.history_list.visibilityChanged.connect(
lambda: self.toggle_history_list_action.setChecked(
self.history_list.isVisible()))

@Slot()
def update_image_tags(self):
Expand Down