Skip to content

Commit

Permalink
Add support for saving 16 bit and 32 bit images (#2083)
Browse files Browse the repository at this point in the history
* Add support for saving 16 bit and 32 bit images

* Option labels
  • Loading branch information
RunDevelopment authored Aug 12, 2023
1 parent 284c01c commit e1212bd
Show file tree
Hide file tree
Showing 2 changed files with 79 additions and 17 deletions.
28 changes: 16 additions & 12 deletions backend/src/nodes/impl/image_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,33 +111,37 @@ def normalize(img: np.ndarray) -> np.ndarray:
def to_uint8(
img: np.ndarray,
normalized=False,
dither=False,
) -> np.ndarray:
"""
Returns a new uint8 image with the given image data.
If `normalized` is `False`, then the image will be normalized before being converted to uint8.
If `dither` is `True`, then dithering will be used to minimize the quantization error.
"""
if img.dtype == np.uint8:
return img.copy()

if not normalized or img.dtype != np.float32:
img = normalize(img)

if not dither:
return (img * 255).round().astype(np.uint8)
return (img * 255).round().astype(np.uint8)


# random dithering
truth = img * 255
quant = truth.round()
def to_uint16(
img: np.ndarray,
normalized=False,
) -> np.ndarray:
"""
Returns a new uint8 image with the given image data.
err = truth - quant
r = np.random.default_rng(0).uniform(0, 1, img.shape).astype(np.float32)
quant += np.sign(err) * (np.abs(err) > r)
If `normalized` is `False`, then the image will be normalized before being converted to uint8.
"""
if img.dtype == np.uint16:
return img.copy()

if not normalized or img.dtype != np.float32:
img = normalize(img)

return quant.astype(np.uint8)
return (img * 65535).round().astype(np.uint16)


def shift(img: np.ndarray, amount_x: int, amount_y: int, fill: FillColor) -> np.ndarray:
Expand Down
68 changes: 63 additions & 5 deletions backend/src/packages/chaiNNer_standard/image/io/save_image.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import os
from enum import Enum
from typing import Dict, List
from typing import Dict, List, Literal

import cv2
import numpy as np
Expand All @@ -19,7 +19,7 @@
to_dxgi,
)
from nodes.impl.dds.texconv import save_as_dds
from nodes.impl.image_utils import cv_save_image, to_uint8
from nodes.impl.image_utils import cv_save_image, to_uint8, to_uint16
from nodes.properties.inputs import (
SUPPORTED_DDS_FORMATS,
BoolInput,
Expand Down Expand Up @@ -87,6 +87,17 @@ class JpegSubsampling(Enum):
FACTOR_420 = int(cv2.IMWRITE_JPEG_SAMPLING_FACTOR_420)


class PngColorDepth(Enum):
U8 = "u8"
U16 = "u16"


class TiffColorDepth(Enum):
U8 = "u8"
U16 = "u16"
F32 = "f32"


@io_group.register(
schema_id="chainner:image:save",
name="Save Image",
Expand All @@ -111,6 +122,17 @@ class JpegSubsampling(Enum):
default_value=ImageFormat.PNG,
option_labels=IMAGE_FORMAT_LABELS,
).with_id(4),
if_enum_group(4, ImageFormat.PNG)(
EnumInput(
PngColorDepth,
"Color Depth",
default_value=PngColorDepth.U8,
option_labels={
PngColorDepth.U8: "8 Bits/Channel",
PngColorDepth.U16: "16 Bits/Channel",
},
).with_id(15),
),
if_enum_group(4, ImageFormat.WEBP)(
BoolInput("Lossless", default=False).with_id(14),
),
Expand Down Expand Up @@ -140,6 +162,18 @@ class JpegSubsampling(Enum):
).with_id(11),
BoolInput("Progressive", default=False).with_id(12),
),
if_enum_group(4, ImageFormat.TIFF)(
EnumInput(
TiffColorDepth,
"Color Depth",
default_value=TiffColorDepth.U8,
option_labels={
TiffColorDepth.U8: "8 Bits/Channel",
TiffColorDepth.U16: "16 Bits/Channel",
TiffColorDepth.F32: "32 Bits/Channel (Float)",
},
).with_id(16),
),
if_enum_group(4, ImageFormat.DDS)(
DdsFormatDropdown().with_id(6),
if_enum_group(6, SUPPORTED_BC7_FORMATS)(
Expand Down Expand Up @@ -172,10 +206,12 @@ def save_image_node(
relative_path: str | None,
filename: str,
image_format: ImageFormat,
png_color_depth: PngColorDepth,
webp_lossless: bool,
quality: int,
jpeg_chroma_subsampling: JpegSubsampling,
jpeg_progressive: bool,
tiff_color_depth: TiffColorDepth,
dds_format: DDSFormat,
dds_bc7_compression: BC7Compression,
dds_error_metric: DDSErrorMetric,
Expand All @@ -191,11 +227,11 @@ def save_image_node(
# Create directory if it doesn't exist
os.makedirs(base_directory, exist_ok=True)

# Put image back in int range
img = to_uint8(img, normalized=True)

# DDS files are handled separately
if image_format == ImageFormat.DDS:
# we only support 8bits of precision for DDS
img = to_uint8(img, normalized=True)

# remap legacy DX9 formats
legacy_dds = dds_format in LEGACY_TO_DXGI

Expand All @@ -215,6 +251,9 @@ def save_image_node(

# Some formats are handled by PIL
if image_format == ImageFormat.GIF or image_format == ImageFormat.TGA:
# we only support 8bits of precision for those formats
img = to_uint8(img, normalized=True)

channels = get_h_w_c(img)[2]
if channels == 1:
# PIL supports grayscale images just fine, so we don't need to do any conversion
Expand Down Expand Up @@ -248,6 +287,25 @@ def save_image_node(
else:
params = []

# the bit depth depends on the image format and settings
precision: Literal["u8", "u16", "f32"] = "u8"
if image_format == ImageFormat.PNG:
if png_color_depth == PngColorDepth.U16:
precision = "u16"
elif image_format == ImageFormat.TIFF:
if tiff_color_depth == TiffColorDepth.U16:
precision = "u16"
elif tiff_color_depth == TiffColorDepth.F32:
precision = "f32"

if precision == "u8":
img = to_uint8(img, normalized=True)
elif precision == "u16":
img = to_uint16(img, normalized=True)
elif precision == "f32":
# chainner images are always f32
pass

cv_save_image(full_path, img, params)


Expand Down

0 comments on commit e1212bd

Please sign in to comment.