Skip to content

Commit

Permalink
Fix logarithmic units (#4707)
Browse files Browse the repository at this point in the history
  • Loading branch information
gustaphe authored Apr 4, 2023
1 parent 8af7be3 commit 082b3bf
Show file tree
Hide file tree
Showing 2 changed files with 42 additions and 16 deletions.
43 changes: 27 additions & 16 deletions ext/UnitfulExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,9 @@ module UnitfulExt

import Plots: Plots, @ext_imp_use, @recipe, PlotText, Subplot, AVec, AMat, Axis
import RecipesBase
@ext_imp_use :import Unitful Quantity unit ustrip Unitful dimension Units NoUnits
@ext_imp_use :import Unitful Quantity unit ustrip Unitful dimension Units NoUnits LogScaled logunit MixedUnits Level Gain uconvert

const MissingOrQuantity = Union{Missing,<:Quantity}
const MissingOrQuantity = Union{Missing,<:Quantity,<:LogScaled}

#==========
Main recipe
Expand All @@ -17,7 +17,7 @@ Main recipe
axisletter = plotattributes[:letter] # x, y, or z
clims_types = (:contour, :contourf, :heatmap, :surface)
if axisletter === :z && get(plotattributes, :seriestype, :nothing) clims_types
u = get(plotattributes, :zunit, unit(eltype(x)))
u = get(plotattributes, :zunit, _unit(eltype(x)))
ustripattribute!(plotattributes, :clims, u)
append_unit_if_needed!(plotattributes, :colorbar_title, u)
end
Expand All @@ -33,7 +33,7 @@ function fixaxis!(attr, x, axisletter)
axisunit = Symbol(axisletter, :unit) # xunit, yunit, zunit
axis = Symbol(axisletter, :axis) # xaxis, yaxis, zaxis
# Get the unit
u = pop!(attr, axisunit, unit(eltype(x)))
u = pop!(attr, axisunit, _unit(eltype(x)))
# If the subplot already exists with data, get its unit
sp = get(attr, :subplot, 1)
if sp length(attr[:plot_object]) && attr[:plot_object].n > 0
Expand All @@ -54,12 +54,12 @@ function fixaxis!(attr, x, axisletter)
fixmarkersize!(attr)
fixlinecolor!(attr)
# Strip the unit
ustrip.(u, x)
_ustrip.(u, x)
end

# Recipe for (x::AVec, y::AVec, z::Surface) types
@recipe function f(x::AVec, y::AVec, z::AMat{T}) where {T<:Quantity} # COV_EXCL_LINE
u = get(plotattributes, :zunit, unit(eltype(z)))
u = get(plotattributes, :zunit, _unit(eltype(z)))
ustripattribute!(plotattributes, :clims, u)
z = fixaxis!(plotattributes, z, :z)
append_unit_if_needed!(plotattributes, :colorbar_title, u)
Expand Down Expand Up @@ -159,8 +159,8 @@ fixlinecolor!(attr) = ustripattribute!(attr, :line_z)
ustripattribute!(attr, key) =
if haskey(attr, key)
v = attr[key]
u = unit(eltype(v))
attr[key] = ustrip.(u, v)
u = _unit(eltype(v))
attr[key] = _ustrip.(u, v)
return u
else
return NoUnits
Expand All @@ -170,7 +170,7 @@ function ustripattribute!(attr, key, u)
if haskey(attr, key)
v = attr[key]
if eltype(v) <: Quantity
attr[key] = ustrip.(u, v)
attr[key] = _ustrip.(u, v)
end
end
u
Expand Down Expand Up @@ -204,7 +204,7 @@ Plots.protectedstring(s) = ProtectedString(s)
Append unit to labels when appropriate
=====================================#

append_unit_if_needed!(attr, key, u::Units) =
append_unit_if_needed!(attr, key, u) =
append_unit_if_needed!(attr, key, get(attr, key, nothing), u)
# dispatch on the type of `label`
append_unit_if_needed!(attr, key, label::ProtectedString, u) = nothing
Expand Down Expand Up @@ -257,25 +257,36 @@ Plots.locate_annotation(
x::MissingOrQuantity,
y::MissingOrQuantity,
label::PlotText,
) = (ustrip(x), ustrip(y), label)
) = (_ustrip(x), _ustrip(y), label)
Plots.locate_annotation(
sp::Subplot,
x::MissingOrQuantity,
y::MissingOrQuantity,
z::MissingOrQuantity,
label::PlotText,
) = (ustrip(x), ustrip(y), ustrip(z), label)
) = (_ustrip(x), _ustrip(y), _ustrip(z), label)
Plots.locate_annotation(sp::Subplot, rel::NTuple{N,<:MissingOrQuantity}, label) where {N} =
Plots.locate_annotation(sp, ustrip.(rel), label)
Plots.locate_annotation(sp, _ustrip.(rel), label)

#==================#
# ticks and limits #
#==================#
Plots._transform_ticks(ticks::AbstractArray{T}, axis) where {T<:Quantity} =
ustrip.(getaxisunit(axis), ticks)
_ustrip.(getaxisunit(axis), ticks)
Plots.process_limits(lims::AbstractArray{T}, axis) where {T<:Quantity} =
ustrip.(getaxisunit(axis), lims)
_ustrip.(getaxisunit(axis), lims)
Plots.process_limits(lims::Tuple{S,T}, axis) where {S<:Quantity,T<:Quantity} =
ustrip.(getaxisunit(axis), lims)
_ustrip.(getaxisunit(axis), lims)

function _ustrip(u, x)
u isa MixedUnits && return ustrip(uconvert(u, x))
return ustrip(u, x)
end

function _unit(x)
t = eltype(x)
t <: LogScaled && return logunit(t)
return unit(x)
end

end # module
15 changes: 15 additions & 0 deletions test/test_unitful.jl
Original file line number Diff line number Diff line change
Expand Up @@ -377,3 +377,18 @@ end
@test ncodeunits(str) == 4
@test codeunit(str) == UInt8
end

@testset "Logunits plots" begin
u = (1:3)u"B"
v = (1:3)u"dB"
x = (1:3)u"dBV"
y = (1:3)u"V"
pl = plot(u, x)
@test pl isa Plot
@test xguide(pl) == "B"
@test yguide(pl) == "dBV"
@test plot!(pl, v, y) isa Plot
pl = plot(v, y)
@test pl isa Plot
@test plot!(pl, u, x) isa Plot
end

0 comments on commit 082b3bf

Please sign in to comment.