-
-
Notifications
You must be signed in to change notification settings - Fork 18
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
adds table transforms #45
Changes from 16 commits
2fa8f76
c1767c5
2e0b87c
7a57684
3f8db5b
bfaec10
90c3bae
eef71ec
cd42f1d
38f35a0
a033d1d
5a40475
0de742b
128cf97
be5dbff
e90dc4b
1a22791
a03bc72
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,4 @@ | ||
struct TabularItem{T} <: Item | ||
data::T | ||
columns | ||
end |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,125 @@ | ||
""" | ||
NormalizeRow(dict, cols) | ||
|
||
Normalizes the values of a row present in `TabularItem` for the columns | ||
specified in `cols` using `dict`, which contains the column names as | ||
dictionary keys and the mean and standard deviation tuple present as | ||
dictionary values. | ||
|
||
## Example | ||
|
||
```julia | ||
using DataAugmentation | ||
|
||
cols = [:col1, :col2, :col3] | ||
row = (; zip(cols, [1, 2, 3])...) | ||
item = TabularItem(row, cols) | ||
normdict = Dict(:col1 => (1, 1), :col2 => (2, 2)) | ||
|
||
tfm = NormalizeRow(normdict, [:col1, :col2]) | ||
apply(tfm, item) | ||
``` | ||
""" | ||
struct NormalizeRow{T, S} <: Transform | ||
dict::T | ||
cols::S | ||
end | ||
|
||
function apply(tfm::NormalizeRow, item::TabularItem; randstate=nothing) | ||
x = NamedTuple(Iterators.map(item.columns, item.data) do col, val | ||
if col in tfm.cols | ||
colmean, colstd = tfm.dict[col] | ||
val = (val - colmean)/colstd | ||
end | ||
(col, val) | ||
end) | ||
TabularItem(x, item.columns) | ||
end | ||
|
||
""" | ||
FillMissing(dict, cols) | ||
|
||
Fills the missing values of a row present in `TabularItem` for the columns | ||
specified in `cols` using `dict`, which contains the column names as | ||
dictionary keys and the value to fill the column with present as | ||
dictionary values. | ||
|
||
## Example | ||
|
||
```julia | ||
using DataAugmentation | ||
|
||
cols = [:col1, :col2, :col3] | ||
row = (; zip(cols, [1, 2, 3])...) | ||
item = TabularItem(row, cols) | ||
fmdict = Dict(:col1 => 100, :col2 => 100) | ||
|
||
tfm = FillMissing(fmdict, [:col1, :col2]) | ||
apply(tfm, item) | ||
``` | ||
""" | ||
struct FillMissing{T, S} <: Transform | ||
dict::T | ||
cols::S | ||
end | ||
|
||
function apply(tfm::FillMissing, item::TabularItem; randstate=nothing) | ||
x = NamedTuple(Iterators.map(item.columns, item.data) do col, val | ||
if col in tfm.cols && ismissing(val) | ||
val = tfm.dict[col] | ||
end | ||
(col, val) | ||
end) | ||
TabularItem(x, item.columns) | ||
end | ||
|
||
""" | ||
Categorify(dict, cols) | ||
|
||
Label encodes the values of a row present in `TabularItem` for the | ||
columns specified in `cols` using `dict`, which contains the column | ||
names as dictionary keys and the unique values of column present | ||
as dictionary values. | ||
|
||
if there are any `missing` values in the values to be transformed, | ||
they are replaced by 1. | ||
|
||
## Example | ||
|
||
```julia | ||
using DataAugmentation | ||
|
||
cols = [:col1, :col2, :col3] | ||
row = (; zip(cols, ["cat", 2, 3])...) | ||
item = TabularItem(row, cols) | ||
catdict = Dict(:col1 => ["dog", "cat"]) | ||
|
||
tfm = Categorify(catdict, [:col1]) | ||
apply(tfm, item) | ||
``` | ||
""" | ||
struct Categorify{T, S} <: Transform | ||
dict::T | ||
cols::S | ||
function Categorify{T, S}(dict::T, cols::S) where {T, S} | ||
for (col, vals) in dict | ||
if any(ismissing, vals) | ||
dict[col] = filter(!ismissing, vals) | ||
@warn "There is a missing value present for category '$col' which will be removed from Categorify dict" | ||
end | ||
end | ||
new{T, S}(dict, cols) | ||
end | ||
end | ||
|
||
Categorify(dict::T, cols::S) where {T, S} = Categorify{T, S}(dict, cols) | ||
|
||
function apply(tfm::Categorify, item::TabularItem; randstate=nothing) | ||
x = NamedTuple(Iterators.map(item.columns, item.data) do col, val | ||
if col in tfm.cols | ||
val = ismissing(val) ? 1 : findfirst(val .== tfm.dict[col]) + 1 | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This can just be There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Would There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yes, because There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Ah I see what you mean. An equality comparison with There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think as a result of this, the whole storing missing in the There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. In that case, there's no need to store There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yeah that's what I meant. Filter the There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Alright, I have updated the constructor to use |
||
end | ||
(col, val) | ||
end) | ||
TabularItem(x, item.columns) | ||
end |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,56 @@ | ||
include("imports.jl") | ||
|
||
@testset ExtendedTestSet "`NormalizeRow`" begin | ||
cols = [:col1, :col2, :col3] | ||
item = TabularItem((; zip(cols, [1, "a", 10])...), cols) | ||
manikyabard marked this conversation as resolved.
Show resolved
Hide resolved
|
||
cols_to_normalize = [:col1, :col3] | ||
col1_mean, col1_std = 10, 100 | ||
col3_mean, col3_std = 100, 10 | ||
normdict = Dict(:col1 => (col1_mean, col1_std), :col3 => (col3_mean, col3_std)) | ||
|
||
tfm = NormalizeRow(normdict, cols_to_normalize) | ||
# @test_nowarn apply(tfm, item) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Note to delete dangling comment before submission |
||
testapply(tfm, item) | ||
titem = apply(tfm, item) | ||
@test titem.data[:col1] == (item.data[:col1] - col1_mean)/col1_std | ||
@test titem.data[:col3] == (item.data[:col3] - col3_mean)/col3_std | ||
end | ||
|
||
@testset ExtendedTestSet "`FillMissing`" begin | ||
cols = [:col1, :col2, :col3] | ||
item = TabularItem((; zip(cols, [1, missing, missing])...), cols) | ||
cols_to_fill = [:col1, :col3] | ||
col1_fmval = 1000. | ||
col3_fmval = 1000. | ||
fmdict = Dict() | ||
fmdict[:col1], fmdict[:col3] = col1_fmval, col3_fmval | ||
manikyabard marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
||
tfm1 = FillMissing(fmdict, cols_to_fill) | ||
@test_nowarn apply(tfm1, item) | ||
titem = apply(tfm1, item) | ||
@test titem.data[:col1] == (ismissing(item.data[:col1]) ? col1_fmval : item.data[:col1]) | ||
manikyabard marked this conversation as resolved.
Show resolved
Hide resolved
|
||
@test titem.data[:col3] == (ismissing(item.data[:col3]) ? col3_fmval : item.data[:col3]) | ||
@test ismissing(titem.data[:col2]) | ||
|
||
fmdict[:col2] = "d" | ||
tfm2 = FillMissing(fmdict, [:col1, :col2, :col3]) | ||
testapply(tfm2, item) | ||
titem2 = apply(tfm2, item) | ||
@test titem2.data[:col2] == (ismissing(item.data[:col2]) ? "d" : item.data[:col2]) | ||
end | ||
|
||
@testset ExtendedTestSet "`Categorify`" begin | ||
cols = [:col1, :col2, :col3, :col4] | ||
item = TabularItem((; zip(cols, [1, "a", "A", missing])...), cols) | ||
cols_to_categorify = [:col2, :col3, :col4] | ||
|
||
categorydict = Dict(:col2 => ["a", "b", "c"], :col3 => ["C", "B", "A"], :col4 => [missing, 10, 20]) | ||
tfm = Categorify(categorydict, cols_to_categorify) | ||
@test !any(ismissing.(tfm.dict[:col4])) | ||
manikyabard marked this conversation as resolved.
Show resolved
Hide resolved
|
||
@test_nowarn apply(tfm, item) | ||
testapply(tfm, item) | ||
titem = apply(tfm, item) | ||
@test titem.data[:col2] == 2 | ||
@test titem.data[:col3] == 4 | ||
@test titem.data[:col4] == 1 | ||
end |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
And all the other transforms too