Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

adds table transforms #45

Merged
merged 18 commits into from
Aug 10, 2021
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ authors = ["lorenzoh <[email protected]>"]
version = "0.2.3"

[deps]
CategoricalArrays = "324d7699-5711-5eae-9e2f-1d82baa6b597"
ColorBlendModes = "60508b50-96e1-4007-9d6c-f475c410f16b"
CoordinateTransformations = "150eb455-5306-5404-9cee-2592286d6298"
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
Expand Down
7 changes: 6 additions & 1 deletion src/DataAugmentation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ module DataAugmentation

using ColorBlendModes
using CoordinateTransformations
using CategoricalArrays
using Distributions: Sampleable, Uniform, Categorical
using ImageDraw
using Images
Expand All @@ -28,6 +29,7 @@ include("./sequence.jl")
include("./items/arrayitem.jl")
include("./projective/base.jl")
include("./items/image.jl")
include("./items/table.jl")
include("./items/keypoints.jl")
include("./items/mask.jl")
include("./projective/compose.jl")
Expand All @@ -36,6 +38,7 @@ include("./projective/affine.jl")
include("./projective/warp.jl")
include("./oneof.jl")
include("./preprocessing.jl")
include("./rowtransforms.jl")
include("./colortransforms.jl")
include("testing.jl")
include("./visualization.jl")
Expand All @@ -49,6 +52,7 @@ export Item,
Sequence,
Project,
Image,
TabularItem,
Keypoints,
Polygon,
ToEltype,
Expand Down Expand Up @@ -88,7 +92,8 @@ export Item,
onehot,
showitems,
showgrid,
Bounds
Bounds,
getcategorypools


end # module
4 changes: 4 additions & 0 deletions src/items/table.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
struct TabularItem{T} <: DataAugmentation.Item
manikyabard marked this conversation as resolved.
Show resolved Hide resolved
data::T
columns
end
60 changes: 60 additions & 0 deletions src/rowtransforms.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
struct NormalizeRow <: DataAugmentation.Transform
normstats
normcols
end
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
struct NormalizeRow <: DataAugmentation.Transform
normstats
normcols
end
struct NormalizeRow{T, S} <: DataAugmentation.Transform
normstats::T
normcols::S
end


struct Categorify <: DataAugmentation.Transform
pooldict
categorycols
end
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
struct Categorify <: DataAugmentation.Transform
pooldict
categorycols
end
struct Categorify{T, S} <: DataAugmentation.Transform
categories::T
categorycols::S
end

Two changes: swap to categories to just be a vector of the categories. I don't think we need the complexity of categorical arrays when the mapping is just the index in a list of categories passed by the user.

Copy link
Contributor Author

@manikyabard manikyabard Jun 21, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

To reduce the complexity, we could just use the catdict used in getcategorypool directly for Categorify. A vector of vectors (or a NamedTuple) with the classes for each categorical column could work as well.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not sure this will work if categories is just a vector of categorical column names as we'll have to replace the class for a categorical column with an integer, and for doing this we'll need information about all the classes which are present in a particular column.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, sorry it should be a NamedTuple/Dict.


struct FillMissing <: DataAugmentation.Transform
fmvals
contcols
catcols
end
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
struct FillMissing <: DataAugmentation.Transform
fmvals
contcols
catcols
end
struct FillMissing{T, S} <: DataAugmentation.Transform
fmvals::T
contcols::S
end


function DataAugmentation.apply(tfm::FillMissing, item::TabularItem; randstate=nothing)
x = (; zip(item.columns, [data for data in item.data])...)
for col in tfm.contcols
if ismissing(x[col])
Setfield.@set! x[col] = tfm.fmvals[col]
end
end
for col in tfm.catcols
if ismissing(x[col])
Setfield.@set! x[col] = "missing"
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is there a better sentinel value we can use than the literal string "missing"? Maybe missing, nothing, or a symbol :missing?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah we discussed during the call how we can use missing here. I'll add a review that includes that discussion.

end
end
TabularItem(x, item.columns)
end
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
function DataAugmentation.apply(tfm::FillMissing, item::TabularItem; randstate=nothing)
x = (; zip(item.columns, [data for data in item.data])...)
for col in tfm.contcols
if ismissing(x[col])
Setfield.@set! x[col] = tfm.fmvals[col]
end
end
for col in tfm.catcols
if ismissing(x[col])
Setfield.@set! x[col] = "missing"
end
end
TabularItem(x, item.columns)
end
function DataAugmentation.apply(tfm::FillMissing, item::TabularItem; randstate=nothing)
x = (; zip(item.columns, [data for data in item.data])...)
for col in tfm.contcols
if ismissing(x[col])
Setfield.@set! x[col] = tfm.fmvals[col]
end
end
TabularItem(x, item.columns)
end

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Unless we want to allow the catergorical missings to be filled too


function DataAugmentation.apply(tfm::NormalizeRow, item::TabularItem; randstate=nothing)
x = (; zip(item.columns, [data for data in item.data])...)
for col in tfm.normcols
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Instead of iterating the columns twice and having setfield repeatedly construct a namedtuple, perhaps look into a helper function that does the normalization given the tfm, column name and value? The function could check if the column is in normcols and transform it with normstats if it is.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah do we even need Setfield anymore if we're standardizing on a NamedTuple? We could just build the transformed data as a vector or something then construct a NamedTuple at the end.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah we should be able to build it at the end.

colmean, colstd = tfm.normstats[col]
Setfield.@set! x[col] = (x[col] - colmean)/colstd
end
TabularItem(x, item.columns)
end

function DataAugmentation.apply(tfm::Categorify, item::TabularItem; randstate=nothing)
x = (; zip(item.columns, [data for data in item.data])...)
for col in tfm.categorycols
Copy link
Member

@ToucheSir ToucheSir Jun 18, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same comment here about double iteration of the columns. I suppose it applies to FillMissing as well :)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Did you mean something like this?

function tfmrowvals(tfm::NormalizeRow, col, val)
    if col in tfm.cols
        colmean, colstd = tfm.dict[col]
        val = (val - colmean)/colstd
    end
    (col, val)
end

function apply(tfm::NormalizeRow, item; randstate=nothing)
    TabularItem((; 
            tfmrowvals.(
                [tfm for _ in 1:length(item.columns)],
                item.columns,
                [val for val in item.data])...
            ), 
        item.columns
    )
end

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If this is better than the current implementation, we can even have a single apply which works on Union of all the transforms, and different methods for tfmrowvals.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Probably still need to dispatch on each type separately in order to know which tfmrowvals function to call, right?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah there will probably be 3 tfmrowvals methods.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

function apply(tfm::NormalizeRow, item; randstate=nothing)
    x = NamedTuple(Iterators.map(item.cols, item.data) do col, val
        if col in tfm.cols
            colmean, colstd = tfm.dict[col]
            val = (val - colmean)/colstd
        end
        (col, val)
    end)
end

if ismissing(x[col])
Setfield.@set! x[col] = "missing"
end
Setfield.@set! x[col] = tfm.pooldict[col].invindex[x[col]]
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
if ismissing(x[col])
Setfield.@set! x[col] = "missing"
end
Setfield.@set! x[col] = tfm.pooldict[col].invindex[x[col]]
if ismissing(x[col])
Setfield.@set! x[col] = 0
else
Setfield.@set! x[col] = findfirst(tfm.categories .== col)
end

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Wouldn't this give the same value for a column?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Also seeing that the Embedding layers won't work with 0 indexing, we should probably try to avoid it.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Made a typo, it should be x[col].

Then let's make missing == 1 and do + 1 for the other columns

end
TabularItem(x, item.columns)
end

function getcategorypools(catdict, catcols)
pooldict = Dict()
for col in catcols
catarray = CategoricalArrays.categorical(catdict[col])
CategoricalArrays.levels!(catarray, ["missing", CategoricalArrays.levels(catarray)...])
pooldict[col] = catarray.pool
end
pooldict
end
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
function getcategorypools(catdict, catcols)
pooldict = Dict()
for col in catcols
catarray = CategoricalArrays.categorical(catdict[col])
CategoricalArrays.levels!(catarray, ["missing", CategoricalArrays.levels(catarray)...])
pooldict[col] = catarray.pool
end
pooldict
end