diff --git a/ai_diffusion/ui/region.py b/ai_diffusion/ui/region.py index 6d7b08758..1ed8d56a6 100644 --- a/ai_diffusion/ui/region.py +++ b/ai_diffusion/ui/region.py @@ -1,7 +1,16 @@ from __future__ import annotations from enum import Enum from PyQt5.QtWidgets import QWidget, QLabel, QToolButton, QHBoxLayout, QVBoxLayout, QFrame, QMenu -from PyQt5.QtGui import QGuiApplication, QMouseEvent, QResizeEvent, QPixmap, QImage, QPainter, QIcon +from PyQt5.QtGui import ( + QGuiApplication, + QMouseEvent, + QResizeEvent, + QPixmap, + QImage, + QPainter, + QIcon, + QFontMetrics, +) from PyQt5.QtCore import QObject, QEvent, Qt, QMetaObject, QSize, pyqtSignal from ..root import root @@ -126,9 +135,11 @@ def __init__(self, root: RootRegion, parent: QWidget, header=PromptHeader.full): self.positive = TextPromptWidget(parent=self) self.positive.line_count = min(settings.prompt_line_count, self._max_lines) + self.positive.handle_dragged.connect(self._handle_dragging) self.positive.install_event_filter(self) self.negative = TextPromptWidget(line_count=1, is_negative=True, parent=self) + self.negative.handle_dragged.connect(self._handle_dragging) self.negative.install_event_filter(self) self._no_region = QWidget(self) @@ -181,6 +192,7 @@ def __init__(self, root: RootRegion, parent: QWidget, header=PromptHeader.full): ) self._language_button.clicked.connect(self._toggle_translation_enabled) self._layout_language_button() + self._setup_resize_handle() self._setup_bindings(self._region) settings.changed.connect(self.update_settings) @@ -343,6 +355,7 @@ def update_settings(self, key: str, value): self.negative.text = "" self.negative.setVisible(value and isinstance(self._region, RootRegion)) self._layout_language_button() + self._setup_resize_handle() elif key == "prompt_translation": self._update_language() @@ -399,6 +412,28 @@ def _layout_language_button(self): self._language_button.move(pos.x() - s.width() - 2, pos.y() - s.height() - 2) self._language_button.resize(s) + def _setup_resize_handle(self): + if settings.show_negative_prompt: + self.positive.set_resize_handle(False) + self.negative.set_resize_handle(True) + else: + self.positive.set_resize_handle(True) + self.negative.set_resize_handle(False) + + def _handle_dragging(self, y_pos: int): + # math determined experimentally, sorry :( + if self.negative.isVisible(): + pos_height = self.positive.contentsRect().height() + neg_height = self.negative.contentsRect().height() + new_height = y_pos - neg_height + pos_height - 10 + else: + new_height = y_pos - 5 + fm = QFontMetrics(ensure(self.positive._multi.document()).defaultFont()) + new_line_count = round(new_height / fm.lineSpacing()) + if 1 <= new_line_count <= 10: + settings.prompt_line_count = new_line_count + self.positive.line_count = new_line_count + def resizeEvent(self, a0): super().resizeEvent(a0) self._layout_language_button() diff --git a/ai_diffusion/ui/widget.py b/ai_diffusion/ui/widget.py index 16d49ec47..275011ecf 100644 --- a/ai_diffusion/ui/widget.py +++ b/ai_diffusion/ui/widget.py @@ -440,6 +440,42 @@ def keyPressEvent(self, a0: QKeyEvent | None): super().keyPressEvent(a0) +class ResizeHandle(QWidget): + """A small resize handle that appears at the bottom of the prompt widget.""" + + handle_dragged = pyqtSignal(int) + + def __init__(self, parent: QWidget): + super().__init__(parent) + self.setCursor(Qt.CursorShape.SizeVerCursor) + self.setFixedSize(22, 8) + self._dragging = False + + def mousePressEvent(self, a0: QMouseEvent | None) -> None: + if ensure(a0).button() == Qt.MouseButton.LeftButton: + self._dragging = True + + def mouseReleaseEvent(self, a0: QMouseEvent | None) -> None: + self._dragging = False + + def mouseMoveEvent(self, a0: QMouseEvent | None) -> None: + if not self._dragging: + return + y_pos = self.mapToParent(ensure(a0).pos()).y() + self.handle_dragged.emit(y_pos) + + def paintEvent(self, a0: QPaintEvent | None) -> None: + if not self.isVisible(): + return + painter = QPainter(self) + painter.setPen(self.palette().color(QPalette.ColorRole.PlaceholderText).lighter(100)) + painter.setBrush(painter.pen().color()) + w, h = self.width(), self.height() + for i, x in enumerate(range(2, w - 1, 3)): + y = 2 * h // 3 if i % 2 == 0 else h // 3 + painter.drawEllipse(x - 1, y - 1, 2, 2) + + class TextPromptWidget(QFrame): """Wraps a single or multi-line text widget, with ability to switch between them. Using QPlainTextEdit set to a single line doesn't work properly because it still @@ -447,11 +483,12 @@ class TextPromptWidget(QFrame): activated = pyqtSignal() text_changed = pyqtSignal(str) + handle_dragged = pyqtSignal(int) _line_count = 2 _is_negative = False - def __init__(self, line_count=2, is_negative=False, parent=None): + def __init__(self, line_count=2, is_negative=False, parent=None, resize_handle=False): super().__init__(parent) self._line_count = line_count self._is_negative = is_negative @@ -474,6 +511,9 @@ def __init__(self, line_count=2, is_negative=False, parent=None): self._layout.addWidget(self._multi) self._layout.addWidget(self._single) + self._resize_handle = ResizeHandle(self) + self._resize_handle.setVisible(False) + palette: QPalette = self._multi.palette() self._base_color = palette.color(QPalette.ColorRole.Base) self.is_negative = self._is_negative @@ -496,6 +536,22 @@ def text(self, value: str): with SignalBlocker(widget): # avoid auto-completion on non-user input widget.setText(value) + def set_resize_handle(self, value: bool): + if value and not self._resize_handle.isVisible(): + self._resize_handle.setVisible(True) + self._resize_handle.handle_dragged.connect(self.handle_dragged) + self._place_resize_handle() + if not value and self._resize_handle.isVisible(): + self._resize_handle.setVisible(False) + self._resize_handle.handle_dragged.disconnect(self.handle_dragged) + + def _place_resize_handle(self): + rect = self.geometry() + self._resize_handle.move( + (rect.width() - self._resize_handle.width()) // 2, + rect.height() - self._resize_handle.height(), + ) + @property def line_count(self): return self._line_count @@ -510,6 +566,11 @@ def line_count(self, value: int): if self._line_count > 1: self._multi.line_count = self._line_count + def resizeEvent(self, a0): + super().resizeEvent(a0) + if self._resize_handle: + self._place_resize_handle() + @property def is_negative(self): return self._is_negative