Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ bitsandbytes==0.45.5
ExifRead==3.3.1
imagesize==1.4.1
pillow==11.2.1
pillow-jxl-plugin~=1.3.4
pyparsing==3.2.1
PySide6==6.9.0
transformers==4.48.3
Expand Down
1 change: 1 addition & 0 deletions taggui/auto_captioning/auto_captioning_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

import numpy as np
import torch
import pillow_jxl
from PIL import Image as PilImage
from PIL.ImageOps import exif_transpose
from transformers import (AutoModelForVision2Seq, AutoProcessor,
Expand Down
73 changes: 49 additions & 24 deletions taggui/models/image_list_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,28 +12,36 @@
import imagesize
from PySide6.QtCore import (QAbstractListModel, QModelIndex, QMimeData, QPoint,
QRect, QSize, Qt, QUrl, Signal, Slot)
from PySide6.QtGui import QIcon, QImageReader, QPixmap
from PySide6.QtGui import QIcon, QImage, QImageReader, QPixmap
from PySide6.QtWidgets import QMessageBox
import pillow_jxl
from PIL import Image as pilimage # Import Pillow's Image class


from utils.image import Image, ImageMarking, Marking
from utils.jxlutil import get_jxl_size
from utils.settings import DEFAULT_SETTINGS, settings
from utils.utils import get_confirmation_dialog_reply, pluralize
import utils.target_dimension as target_dimension

UNDO_STACK_SIZE = 32

def pil_to_qimage(pil_image):
    """Convert a PIL image to a self-contained RGBA ``QImage``.

    Args:
        pil_image: A Pillow image in any mode; it is converted to RGBA first.

    Returns:
        A ``QImage`` in ``Format_RGBA8888`` that owns its own pixel data.
    """
    rgba_image = pil_image.convert("RGBA")
    data = rgba_image.tobytes("raw", "RGBA")
    qimage = QImage(data, rgba_image.width, rgba_image.height,
                    QImage.Format_RGBA8888)
    # The buffer-based QImage constructor wraps `data` instead of copying it,
    # and Qt requires that buffer to outlive the image. `data` is a local that
    # goes away when this function returns, so hand back a deep copy that is
    # detached from it.
    return qimage.copy()

def get_file_paths(directory_path: Path) -> set[Path]:
    """
    Recursively get all file paths in a directory, including those in
    subdirectories.

    Args:
        directory_path: Directory to search.

    Returns:
        The set of paths of every regular file found below `directory_path`.
        Directories themselves are not included.
    """
    # `rglob` already descends into every subdirectory, so no manual
    # recursion is needed (recursing as well would walk each subtree again).
    return {path for path in directory_path.rglob("*") if path.is_file()}


Expand Down Expand Up @@ -85,7 +93,7 @@ def mimeData(self, indexes):
def rowCount(self, parent=None) -> int:
return len(self.images)

def data(self, index, role=None) -> Image | str | QIcon | QSize:
def data(self, index: QModelIndex, role=None) -> Image | str | QIcon | QSize:
image = self.images[index.row()]
if role == Qt.ItemDataRole.UserRole:
return image
Expand All @@ -101,21 +109,37 @@ def data(self, index, role=None) -> Image | str | QIcon | QSize:
# it. Otherwise, generate a thumbnail and save it to the image.
if image.thumbnail:
return image.thumbnail
image_reader = QImageReader(str(image.path))
# Rotate the image based on the orientation tag.
image_reader.setAutoTransform(True)
if image.crop:
crop = image.crop
else:
crop = QRect(QPoint(0, 0), image_reader.size())
if crop.height() > crop.width()*3:
# keep it reasonable, higher than 3x the width doesn't make sense
crop.setTop((crop.height() - crop.width()*3)//2) # center crop
crop.setHeight(crop.width()*3)
image_reader.setClipRect(crop)
pixmap = QPixmap.fromImageReader(image_reader).scaledToWidth(
self.image_list_image_width,
Qt.TransformationMode.SmoothTransformation)
crop = image.crop
try:
if image.path.suffix.lower() == ".jxl":
pil_image = pilimage.open(image.path) # Uses pillow-jxl
qimage = pil_to_qimage(pil_image)
if not crop:
crop = QRect(QPoint(0, 0), qimage.size())
if crop.height() > crop.width()*3:
# keep it reasonable, higher than 3x the width doesn't make sense
crop.setTop((crop.height() - crop.width()*3)//2) # center crop
crop.setHeight(crop.width()*3)

pixmap = QPixmap.fromImage(qimage).scaledToWidth(
self.image_list_image_width,
Qt.TransformationMode.SmoothTransformation)
else:
image_reader = QImageReader(str(image.path))
# Rotate the image based on the orientation tag.
image_reader.setAutoTransform(True)
if not crop:
crop = QRect(QPoint(0, 0), image_reader.size())
if crop.height() > crop.width()*3:
# keep it reasonable, higher than 3x the width doesn't make sense
crop.setTop((crop.height() - crop.width()*3)//2) # center crop
crop.setHeight(crop.width()*3)
image_reader.setClipRect(crop)
pixmap = QPixmap.fromImageReader(image_reader).scaledToWidth(
self.image_list_image_width,
Qt.TransformationMode.SmoothTransformation)
except Exception as e:
print(f"Error loading image {image.path}: {e}")
thumbnail = QIcon(pixmap)
image.thumbnail = thumbnail
return thumbnail
Expand Down Expand Up @@ -167,9 +191,10 @@ def load_directory(self, directory_path: Path):
if path.suffix == '.json'}
for image_path in image_paths:
try:
dimensions = imagesize.get(image_path)
# Check the Exif orientation tag and rotate the dimensions if
# necessary.
if str(image_path).endswith('jxl'):
dimensions = get_jxl_size(image_path)
else:
dimensions = pilimage.open(image_path).size
with open(image_path, 'rb') as image_file:
try:
exif_tags = exifread.process_file(
Expand Down
184 changes: 184 additions & 0 deletions taggui/utils/jxlutil.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,184 @@
# Modified from https://github.com/Fraetor/jxl_decode
# Added partial read support for up to 200x speedup
import os

class JXLBitstream:
    """
    A little-endian stream of bits read lazily from a file object.

    The codestream is either contiguous, starting at `offset`, or split
    across several container boxes described by `offsets`. Each entry of
    `offsets` is indexed as `[index, start, length]`: `start` is the byte
    position of a chunk in the file and `length` its size in bytes.

    Note that `get_bits(n)` pulls `n` *bytes* from the file for every `n`
    bits requested. As long as no read is cut short by the end of the file,
    `shift` (the number of bits consumed) therefore equals the number of
    bytes buffered so far, and the chunk-boundary arithmetic below relies on
    that equality.
    """

    def __init__(self, file, offset=0, offsets=None) -> None:
        """
        Args:
            file: Seekable binary file object.
            offset: Byte position at which a contiguous codestream starts.
            offsets: Optional list of `[index, start, length]` chunks. When
                non-empty it overrides `offset` with the start of the first
                chunk.
        """
        self.shift = 0
        self.bitstream = bytearray()
        self.file = file
        self.offset = offset
        # A fresh list per instance; a `[]` default would be shared between
        # all instances created without `offsets`.
        self.offsets = [] if offsets is None else offsets
        # Bookkeeping for chunked reads; only consulted when `offsets` is set.
        self.previous_data_len = 0  # total length of the chunks already left
        self.index = 0  # position of the current chunk in `offsets`
        self.partial_to_read_length = 0
        if self.offsets:
            self.offset = self.offsets[0][1]
        self.file.seek(self.offset)

    def get_bits(self, length: int = 1) -> int:
        """
        Return the next `length` bits of the stream as an unsigned integer.

        Raises:
            IndexError: If the stream is chunked and the request runs past
                the last chunk listed in `offsets`.
        """
        if (self.offsets and self.shift + length
                > self.previous_data_len + self.offsets[self.index][2]):
            # The request crosses the end of the current chunk. Drain what is
            # left of it (possibly nothing, when the read position is exactly
            # at the chunk end) and move on to the following chunk(s), so
            # that bytes lying between chunks never enter the bitstream.
            self.partial_to_read_length = length
            self.partial_read(0, length)
            self.bitstream += self.file.read(self.partial_to_read_length)
        else:
            self.bitstream += self.file.read(length)
        bitmask = 2**length - 1
        bits = (int.from_bytes(self.bitstream, "little") >> self.shift) & bitmask
        self.shift += length
        return bits

    def partial_read(self, current_length, length):
        """
        Read the remainder of the current chunk and seek to the next one.

        Recurses while the request of `length` bytes also extends past the
        end of the chunk that was just entered. `current_length` is the
        number of bytes of the request already read.
        """
        self.previous_data_len += self.offsets[self.index][2]
        to_read_length = self.previous_data_len - (self.shift + current_length)
        self.bitstream += self.file.read(to_read_length)
        current_length += to_read_length
        self.partial_to_read_length -= to_read_length
        self.index += 1
        self.file.seek(self.offsets[self.index][1])
        if self.shift + length > self.previous_data_len + self.offsets[self.index][2]:
            self.partial_read(current_length, length)


def decode_codestream(file, offset=0, offsets=None):
    """
    Read the image dimensions from the SizeHeader of a JPEG XL codestream.

    JXL codestream specification: https://www.iso.org/standard/85066.html

    Args:
        file: Seekable binary file object containing the codestream.
        offset: Byte position of the codestream signature when the
            codestream is stored contiguously.
        offsets: Optional list of `[index, start, length]` chunks for a
            codestream that is split across several boxes.

    Returns:
        A `(width, height)` tuple.
    """
    # Number of bits used for a dimension, chosen by a 2-bit selector.
    dimension_bits = (9, 13, 18, 30)
    # Fixed aspect ratios (numerator, denominator) chosen by the 3-bit ratio.
    aspect_ratios = {
        1: (1, 1),
        2: (12, 10),
        3: (4, 3),
        4: (3, 2),
        5: (16, 9),
        6: (5, 4),
        7: (2, 1),
    }

    # `offsets or []` avoids a shared mutable default while still handing the
    # bitstream a list.
    codestream = JXLBitstream(file, offset=offset, offsets=offsets or [])

    # Skip signature
    codestream.get_bits(16)

    # SizeHeader
    div8 = codestream.get_bits(1)

    def read_dimension():
        """Read one explicitly coded dimension (height, or width if no ratio)."""
        if div8:
            return 8 * (1 + codestream.get_bits(5))
        selector = codestream.get_bits(2)
        return 1 + codestream.get_bits(dimension_bits[selector])

    height = read_dimension()
    ratio = codestream.get_bits(3)
    if not ratio:
        # No fixed aspect ratio: the width is coded the same way as the height.
        width = read_dimension()
    else:
        numerator, denominator = aspect_ratios[ratio]
        width = (height * numerator) // denominator
    return width, height


def decode_container(file):
    """
    Parses the ISOBMFF container, locates the codestream, and decodes its size.

    JXL container specification: ISO/IEC 18181-2

    Args:
        file: Seekable binary file object positioned anywhere; it is rewound
            as needed.

    Returns:
        The `(width, height)` tuple produced by `decode_codestream`.

    Raises:
        ValueError: If the signature or file type box is wrong, a box header
            is invalid, or the container holds no `jxlc`/`jxlp` box.
    """
    # Seek/tell works for any seekable binary stream, not only for objects
    # that are backed by a real file descriptor.
    file.seek(0, os.SEEK_END)
    file_size = file.tell()

    def parse_box(file, file_start) -> dict:
        """Read the header of the box starting at byte `file_start`."""
        file.seek(file_start)
        short_length = int.from_bytes(file.read(4), "big")
        extended_length = None
        if 1 < short_length <= 8:
            raise ValueError(f"Invalid LBox at byte {file_start}.")
        if short_length == 1:
            # A length of 1 means the real size follows the type as 8 bytes.
            file.seek(file_start + 8)
            extended_length = int.from_bytes(file.read(8), "big")
            if extended_length <= 16:
                raise ValueError(f"Invalid XLBox at byte {file_start}.")
        if extended_length:
            header_length = 16
            box_length = extended_length
        else:
            header_length = 8
            if short_length == 0:
                # A length of 0 means the box runs to the end of the file.
                box_length = file_size - file_start
            else:
                box_length = short_length
        file.seek(file_start + 4)
        box_type = file.read(4)
        file.seek(file_start)
        return {
            "length": box_length,
            "type": box_type,
            "offset": header_length,
        }

    file.seek(0)
    # Reject files missing required boxes. These two boxes are required to be
    # at the start and contain no values, so we can manually check their
    # presence.
    # Signature box. (Redundant as has already been checked.)
    if file.read(12) != bytes.fromhex("0000000C 4A584C20 0D0A870A"):
        raise ValueError("Invalid signature box.")
    # File Type box.
    if file.read(20) != bytes.fromhex(
        "00000014 66747970 6A786C20 00000000 6A786C20"
    ):
        raise ValueError("Invalid file type box.")

    codestream_offset = None  # payload start of a `jxlc` box, once found
    offsets = []  # [index, payload start, payload length] per `jxlp` box
    container_pointer = 32  # first byte after the two fixed boxes
    while container_pointer < file_size:
        box = parse_box(file, container_pointer)
        payload_start = container_pointer + box["offset"]
        if box["type"] == b"jxlc":
            # The whole codestream lives in this box; no need to look further.
            codestream_offset = payload_start
            break
        if box["type"] == b"jxlp":
            # Partial codestream box: a 4-byte index precedes the data.
            file.seek(payload_start)
            index = int.from_bytes(file.read(4), "big")
            offsets.append(
                [index, payload_start + 4, box["length"] - box["offset"] - 4]
            )
        container_pointer += box["length"]

    if codestream_offset is None and not offsets:
        # Without this check the container header itself would be decoded as
        # a codestream and yield meaningless dimensions.
        raise ValueError("No codestream box found in container.")

    offsets.sort(key=lambda entry: entry[0])
    file.seek(0)

    return decode_codestream(
        file, offset=codestream_offset or 0, offsets=offsets
    )


def get_jxl_size(path):
    """
    Return the dimensions of the JPEG XL file at `path`.

    A file that starts with the bare codestream signature (FF 0A) is decoded
    directly; anything else is treated as a container.
    """
    with open(path, "rb") as file:
        starts_with_codestream = file.read(2) == bytes.fromhex("FF0A")
        decoder = decode_codestream if starts_with_codestream else decode_container
        return decoder(file)
4 changes: 2 additions & 2 deletions taggui/utils/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,8 @@
# Defaults for settings that are accessed from multiple places.
DEFAULT_SETTINGS = {
'font_size': 16,
# Common image formats that are supported in PySide6.
'image_list_file_formats': 'bmp, gif, jpg, jpeg, png, tif, tiff, webp',
# Common image formats that are supported in PySide6, as well as JPEG XL.
'image_list_file_formats': 'bmp, gif, jpg, jpeg, jxl, png, tif, tiff, webp',
'image_list_image_width': 200,
'tag_separator': ',',
'insert_space_after_tag_separator': True,
Expand Down
16 changes: 14 additions & 2 deletions taggui/widgets/image_viewer.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,14 @@
from math import ceil, floor, sqrt
from PySide6.QtCore import (QModelIndex, QPersistentModelIndex, QPoint, QPointF,
QRect, QRectF, QSize, Qt, Signal, Slot)
from PySide6.QtGui import (QAction, QActionGroup, QColor, QIcon,
from PySide6.QtGui import (QAction, QActionGroup, QColor, QIcon, QImage,
QPainter, QPainterPath, QPen, QPixmap, QTransform,
QMouseEvent)
from PySide6.QtWidgets import (QGraphicsItem, QGraphicsLineItem,
QGraphicsPixmapItem, QGraphicsRectItem,
QGraphicsTextItem, QGraphicsScene, QGraphicsView,
QMenu, QVBoxLayout, QWidget)
from PIL import Image as pilimage
from utils.settings import settings
from models.proxy_image_list_model import ProxyImageListModel
from utils.image import Image, ImageMarking, Marking
Expand Down Expand Up @@ -672,7 +673,18 @@ def load_image(self, proxy_image_index: QModelIndex, is_complete = True):
if is_complete:
self.marking_items.clear()
self.view.clear_scene()
pixmap = QPixmap(str(image.path))
if image.path.suffix.lower() == ".jxl":
pil_image = pilimage.open(image.path) # Decode JXL using Pillow
pil_image = pil_image.convert("RGBA") # Ensure RGBA format

pixmap = QPixmap(QImage(
pil_image.tobytes("raw", "RGBA"),
pil_image.width,
pil_image.height,
QImage.Format_RGBA8888
))
else:
pixmap = QPixmap(str(image.path))
image_item = QGraphicsPixmapItem(pixmap)
image_item.setZValue(0)
self.scene.setSceneRect(image_item.boundingRect()
Expand Down