LogRatio: apply Closure() before the log-ratio transform and revert it afterwards.

applyfeat now runs apply(Closure(), feat) and uses the closed table (cfeat) as the
input to the optional column permutation. The cache layout changes from
(rind, perm, onames) to (ccache, perm, rind, vars). Because vars is mutated in place
by popat!/push! when a permutation happens, the cached names are in the permuted
order (reference variable last).

revertfeat rebuilds a named table from the reverted matrix using the cached names,
undoes the permutation with Select(inds), and then calls
revert(Closure(), cfeat, ccache) before materializing the result.

Tests now compare each reverted table against the original input for ALR, CLR and
ILR, and add permutation checks for ILR(:c) alongside ALR(:c).

diff --git a/src/transforms/logratio.jl b/src/transforms/logratio.jl
index 8453a1b0..47e59306 100644
--- a/src/transforms/logratio.jl
+++ b/src/transforms/logratio.jl
@@ -23,22 +23,27 @@ assertions(::LogRatio) = [scitypeassert(Continuous)]
 
 function applyfeat(transform::LogRatio, feat, prep)
   cols = Tables.columns(feat)
-  onames = Tables.columnnames(cols)
-  varnames = collect(onames)
+  names = Tables.columnnames(cols)
+  vars = collect(names)
+
+  # perform closure for full revertibility
+  cfeat, ccache = apply(Closure(), feat)
 
   # reference variable
-  rvar = refvar(transform, varnames)
-  _assert(rvar ∈ varnames, "invalid reference variable")
-  rind = findfirst(==(rvar), varnames)
+  rvar = refvar(transform, vars)
+  _assert(rvar ∈ vars, "invalid reference variable")
+
+  # reference index
+  rind = findfirst(==(rvar), vars)
 
   # permute columns if necessary
-  perm = rind ≠ lastindex(varnames)
+  perm = rind ≠ lastindex(vars)
   pfeat = if perm
-    popat!(varnames, rind)
-    push!(varnames, rvar)
-    feat |> Select(varnames)
+    popat!(vars, rind)
+    push!(vars, rvar)
+    cfeat |> Select(vars)
   else
-    feat
+    cfeat
   end
 
   # apply transform
@@ -46,33 +51,38 @@ function applyfeat(transform::LogRatio, feat, prep)
   Y = applymatrix(transform, X)
 
   # new variable names
-  newnames = newvars(transform, varnames)
+  newnames = newvars(transform, vars)
 
   # return same table type
   𝒯 = (; zip(newnames, eachcol(Y))...)
   newfeat = 𝒯 |> Tables.materializer(feat)
 
-  newfeat, (rind, perm, onames)
+  newfeat, (ccache, perm, rind, vars)
 end
 
 function revertfeat(transform::LogRatio, newfeat, fcache)
+  # retrieve cache
+  ccache, perm, rind, vars = fcache
+
   # revert transform
   Y = Tables.matrix(newfeat)
   X = revertmatrix(transform, Y)
-
-  # retrieve cache
-  rind, perm, onames = fcache
+  pfeat = (; zip(vars, eachcol(X))...)
 
   # revert the permutation if necessary
-  if perm
-    n = length(onames)
+  cfeat = if perm
+    n = length(vars)
     inds = collect(1:(n - 1))
     insert!(inds, rind, n)
-    X = X[:, inds]
+    pfeat |> Select(inds)
+  else
+    pfeat
   end
 
+  # revert closure for full revertibility
+  𝒯 = revert(Closure(), cfeat, ccache)
+
   # return same table type
-  𝒯 = (; zip(onames, eachcol(X))...)
   𝒯 |> Tables.materializer(newfeat)
 end
 
diff --git a/test/transforms/logratio.jl b/test/transforms/logratio.jl
index 9414d470..c0443202 100644
--- a/test/transforms/logratio.jl
+++ b/test/transforms/logratio.jl
@@ -12,21 +12,22 @@
   n, c = apply(T, t)
   @test Tables.schema(n).names == (:ARL1, :ARL2)
   @test n == t |> ALR(:c)
-  talr = revert(T, n, c)
+  r = revert(T, n, c)
+  @test Tables.matrix(r) ≈ Tables.matrix(t)
+
   T = CLR()
   n, c = apply(T, t)
   @test Tables.schema(n).names == (:CLR1, :CLR2, :CLR3)
-  tclr = revert(T, n, c)
+  r = revert(T, n, c)
+  @test Tables.matrix(r) ≈ Tables.matrix(t)
+
   T = ILR()
   n, c = apply(T, t)
   @test Tables.schema(n).names == (:ILR1, :ILR2)
   @test n == t |> ILR(:c)
-  tilr = revert(T, n, c)
-  @test Tables.matrix(talr) ≈ Tables.matrix(tclr)
-  @test Tables.matrix(tclr) ≈ Tables.matrix(tilr)
-  @test Tables.matrix(talr) ≈ Tables.matrix(tilr)
+  r = revert(T, n, c)
+  @test Tables.matrix(r) ≈ Tables.matrix(t)
 
-  # permute columns
   a = [1.0, 0.0, 1.0]
   b = [2.0, 2.0, 2.0]
   c = [3.0, 3.0, 0.0]
@@ -35,10 +36,23 @@
 
   T = ALR(:c)
   n1, c1 = apply(T, t1)
+  r1 = revert(T, n1, c1)
+  n2, c2 = apply(T, t2)
+  r2 = revert(T, n2, c2)
+  @test n1 == n2
+  @test Tables.matrix(r1) ≈ Tables.matrix(t1)
+  @test Tables.schema(r1).names == (:a, :c, :b)
+  @test Tables.matrix(r2) ≈ Tables.matrix(t2)
+  @test Tables.schema(r2).names == (:c, :a, :b)
+
+  T = ILR(:c)
+  n1, c1 = apply(T, t1)
+  r1 = revert(T, n1, c1)
   n2, c2 = apply(T, t2)
+  r2 = revert(T, n2, c2)
   @test n1 == n2
-  tₒ = revert(T, n1, c1)
-  @test Tables.schema(tₒ).names == (:a, :c, :b)
-  tₒ = revert(T, n2, c2)
-  @test Tables.schema(tₒ).names == (:c, :a, :b)
+  @test Tables.matrix(r1) ≈ Tables.matrix(t1)
+  @test Tables.schema(r1).names == (:a, :c, :b)
+  @test Tables.matrix(r2) ≈ Tables.matrix(t2)
+  @test Tables.schema(r2).names == (:c, :a, :b)
 end