Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

adds table transforms #45

Merged
merged 18 commits into from
Aug 10, 2021
Merged
Show file tree
Hide file tree
Changes from 16 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion src/DataAugmentation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ include("./sequence.jl")
include("./items/arrayitem.jl")
include("./projective/base.jl")
include("./items/image.jl")
include("./items/table.jl")
include("./items/keypoints.jl")
include("./items/mask.jl")
include("./projective/compose.jl")
Expand All @@ -36,6 +37,7 @@ include("./projective/affine.jl")
include("./projective/warp.jl")
include("./oneof.jl")
include("./preprocessing.jl")
include("./rowtransforms.jl")
include("./colortransforms.jl")
include("testing.jl")
include("./visualization.jl")
Expand All @@ -49,6 +51,7 @@ export Item,
Sequence,
Project,
Image,
TabularItem,
Keypoints,
Polygon,
ToEltype,
Expand Down Expand Up @@ -88,7 +91,8 @@ export Item,
onehot,
showitems,
showgrid,
Bounds
Bounds,
getcategorypools


end # module
4 changes: 4 additions & 0 deletions src/items/table.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
struct TabularItem{T} <: Item
data::T
columns
end
125 changes: 125 additions & 0 deletions src/rowtransforms.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,125 @@
"""
NormalizeRow(dict, cols)

Normalizes the values of a row present in `TabularItem` for the columns
specified in `cols` using `dict`, which contains the column names as
dictionary keys and the mean and standard deviation tuple present as
dictionary values.

## Example

```julia
using DataAugmentation

cols = [:col1, :col2, :col3]
row = (; zip(cols, [1, 2, 3])...)
item = TabularItem(row, cols)
normdict = Dict(:col1 => (1, 1), :col2 => (2, 2))

tfm = NormalizeRow(normdict, [:col1, :col2])
apply(tfm, item)
```
"""
struct NormalizeRow{T, S} <: Transform
dict::T
cols::S
end

function apply(tfm::NormalizeRow, item::TabularItem; randstate=nothing)
x = NamedTuple(Iterators.map(item.columns, item.data) do col, val
if col in tfm.cols
colmean, colstd = tfm.dict[col]
val = (val - colmean)/colstd
end
(col, val)
end)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
end)
end)
return TabularItem(x, item.columns)

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

And all the other transforms too

TabularItem(x, item.columns)
end

"""
FillMissing(dict, cols)

Fills the missing values of a row present in `TabularItem` for the columns
specified in `cols` using `dict`, which contains the column names as
dictionary keys and the value to fill the column with present as
dictionary values.

## Example

```julia
using DataAugmentation

cols = [:col1, :col2, :col3]
row = (; zip(cols, [1, 2, 3])...)
item = TabularItem(row, cols)
fmdict = Dict(:col1 => 100, :col2 => 100)

tfm = FillMissing(fmdict, [:col1, :col2])
apply(tfm, item)
```
"""
struct FillMissing{T, S} <: Transform
dict::T
cols::S
end

function apply(tfm::FillMissing, item::TabularItem; randstate=nothing)
x = NamedTuple(Iterators.map(item.columns, item.data) do col, val
if col in tfm.cols && ismissing(val)
val = tfm.dict[col]
end
(col, val)
end)
TabularItem(x, item.columns)
end

"""
Categorify(dict, cols)

Label encodes the values of a row present in `TabularItem` for the
columns specified in `cols` using `dict`, which contains the column
names as dictionary keys and the unique values of column present
as dictionary values.

if there are any `missing` values in the values to be transformed,
they are replaced by 1.

## Example

```julia
using DataAugmentation

cols = [:col1, :col2, :col3]
row = (; zip(cols, ["cat", 2, 3])...)
item = TabularItem(row, cols)
catdict = Dict(:col1 => ["dog", "cat"])

tfm = Categorify(catdict, [:col1])
apply(tfm, item)
```
"""
struct Categorify{T, S} <: Transform
dict::T
cols::S
function Categorify{T, S}(dict::T, cols::S) where {T, S}
for (col, vals) in dict
if any(ismissing, vals)
dict[col] = filter(!ismissing, vals)
@warn "There is a missing value present for category '$col' which will be removed from Categorify dict"
end
end
new{T, S}(dict, cols)
end
end

Categorify(dict::T, cols::S) where {T, S} = Categorify{T, S}(dict, cols)

function apply(tfm::Categorify, item::TabularItem; randstate=nothing)
x = NamedTuple(Iterators.map(item.columns, item.data) do col, val
if col in tfm.cols
val = ismissing(val) ? 1 : findfirst(val .== tfm.dict[col]) + 1
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This can just be findfirst if we use SortedSet.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Would findfirst work when the input function involves comparing with missing?

Copy link
Member

@darsnack darsnack Jun 30, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, because tfm.dict[col] always contains missing, and missing is treated as any other element in the set.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah I see what you mean. An equality comparison with missing is missing.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think as a result of this, the whole storing missing in the tfm.dict[col] is not going to work. We'll have to revert to the old filtering way + the conditional shown here.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In that case, there's no need to store missing in the dict values either then right? The conditional is required either way.

Copy link
Member

@darsnack darsnack Jun 30, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah that's what I meant. Filter the missing out of the dict, and don't add it if it isn't there.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Alright, I have updated the constructor to use skipmissing and collect for the values containing missing.

end
(col, val)
end)
TabularItem(x, item.columns)
end
3 changes: 2 additions & 1 deletion test/imports.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ using CoordinateTransformations
using DataAugmentation: Item, Transform, getrandstate, itemdata, setdata, ComposedProjectiveTransform,
projectionbounds, getprojection, offsetcropbounds,
CroppedProjectiveTransform, getbounds, project, project!, makebuffer, imagetotensor, imagetotensor!,
normalize, normalize!, tensortoimage, denormalize, denormalize!
normalize, normalize!, tensortoimage, denormalize, denormalize!,
NormalizeRow, FillMissing, Categorify, TabularItem
using DataAugmentation: testitem, testapply, testapply!, testprojective
import DataAugmentation: apply, compose
56 changes: 56 additions & 0 deletions test/rowtransforms.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
include("imports.jl")

@testset ExtendedTestSet "`NormalizeRow`" begin
cols = [:col1, :col2, :col3]
item = TabularItem((; zip(cols, [1, "a", 10])...), cols)
manikyabard marked this conversation as resolved.
Show resolved Hide resolved
cols_to_normalize = [:col1, :col3]
col1_mean, col1_std = 10, 100
col3_mean, col3_std = 100, 10
normdict = Dict(:col1 => (col1_mean, col1_std), :col3 => (col3_mean, col3_std))

tfm = NormalizeRow(normdict, cols_to_normalize)
# @test_nowarn apply(tfm, item)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Note to delete dangling comment before submission

testapply(tfm, item)
titem = apply(tfm, item)
@test titem.data[:col1] == (item.data[:col1] - col1_mean)/col1_std
@test titem.data[:col3] == (item.data[:col3] - col3_mean)/col3_std
end

@testset ExtendedTestSet "`FillMissing`" begin
cols = [:col1, :col2, :col3]
item = TabularItem((; zip(cols, [1, missing, missing])...), cols)
cols_to_fill = [:col1, :col3]
col1_fmval = 1000.
col3_fmval = 1000.
fmdict = Dict()
fmdict[:col1], fmdict[:col3] = col1_fmval, col3_fmval
manikyabard marked this conversation as resolved.
Show resolved Hide resolved

tfm1 = FillMissing(fmdict, cols_to_fill)
@test_nowarn apply(tfm1, item)
titem = apply(tfm1, item)
@test titem.data[:col1] == (ismissing(item.data[:col1]) ? col1_fmval : item.data[:col1])
manikyabard marked this conversation as resolved.
Show resolved Hide resolved
@test titem.data[:col3] == (ismissing(item.data[:col3]) ? col3_fmval : item.data[:col3])
@test ismissing(titem.data[:col2])

fmdict[:col2] = "d"
tfm2 = FillMissing(fmdict, [:col1, :col2, :col3])
testapply(tfm2, item)
titem2 = apply(tfm2, item)
@test titem2.data[:col2] == (ismissing(item.data[:col2]) ? "d" : item.data[:col2])
end

@testset ExtendedTestSet "`Categorify`" begin
cols = [:col1, :col2, :col3, :col4]
item = TabularItem((; zip(cols, [1, "a", "A", missing])...), cols)
cols_to_categorify = [:col2, :col3, :col4]

categorydict = Dict(:col2 => ["a", "b", "c"], :col3 => ["C", "B", "A"], :col4 => [missing, 10, 20])
tfm = Categorify(categorydict, cols_to_categorify)
@test !any(ismissing.(tfm.dict[:col4]))
manikyabard marked this conversation as resolved.
Show resolved Hide resolved
@test_nowarn apply(tfm, item)
testapply(tfm, item)
titem = apply(tfm, item)
@test titem.data[:col2] == 2
@test titem.data[:col3] == 4
@test titem.data[:col4] == 1
end
3 changes: 3 additions & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -44,4 +44,7 @@ include("./imports.jl")
@testset ExtendedTestSet "visualization.jl" begin
include("visualization.jl")
end
@testset ExtendedTestSet "rowtransforms.jl" begin
include("rowtransforms.jl")
end
end