Skip to content

Commit

Permalink
Relax Colorant type
Browse files Browse the repository at this point in the history
  • Loading branch information
adrhill committed Mar 18, 2024
1 parent 8a39938 commit 9a7fd07
Showing 1 changed file with 10 additions and 13 deletions.
23 changes: 10 additions & 13 deletions src/preprocessing.jl
Original file line number Diff line number Diff line change
Expand Up @@ -165,20 +165,20 @@ function apply!(buf, ::ImageToTensor, image::Image; randstate = nothing)
return buf
end

function imagetotensor(image::AbstractArray{C, N}, T = Float32) where {C<:Color, N}
function imagetotensor(image::AbstractArray{C, N}, T = Float32) where {C<:Colorant, N}
T.(PermutedDimsArray(_channelview(image), ((i for i in 2:N+1)..., 1)))
end

#=
function imagetotensor(image::AbstractArray{C, N}, T = Float32) where {TC, C<:Color{TC, 1}, N}
function imagetotensor(image::AbstractArray{C, N}, T = Float32) where {TC, C<:Colorant{TC, 1}, N}
return T.(_channelview(image))
end
=#


# TODO: relax color type constraint, implement for other colors
# single-channel colors need a `channelview` that also expands the array
function imagetotensor!(buf, image::AbstractArray{<:Color, N}) where N
function imagetotensor!(buf, image::AbstractArray{<:Colorant, N}) where N
permutedims!(
buf,
_channelview(image),
Expand All @@ -197,22 +197,19 @@ function tensortoimage(a::AbstractArray)
end
end

function tensortoimage(C::Type{<:Color}, a::AbstractArray{T, N}) where {T, N}
function tensortoimage(C::Type{<:Colorant}, a::AbstractArray{T, N}) where {T, N}
perm = (N, 1:N-1...)
return _colorview(C, PermutedDimsArray(a, perm))
end


function _channelview(img)
chview = channelview(img)
# for single-channel colors, expand the color dimension anyway
if size(img) == size(chview)
chview = reshape(chview, 1, size(chview)...)
end
return chview
# for single-channel colors, expand the color dimension anyway
_channelview(img::AbstractArray{<:Colorant{T, N}}) where {T, N} = channelview(img)
function _channelview(img::AbstractArray{<:Colorant{T, 1}}) where T
cv = channelview(img)
return reshape(cv, 1, size(cv)...)
end

function _colorview(C::Type{<:Color}, img)
function _colorview(C::Type{<:Colorant}, img)
if size(img, 1) == 1
img = reshape(img, size(img)[2:end])
end
Expand Down

0 comments on commit 9a7fd07

Please sign in to comment.