adds table transforms #45
src/rowtransforms.jl
Outdated
    end
    for col in tfm.catcols
        if ismissing(x[col])
            Setfield.@set! x[col] = "missing"
Is there a better sentinel value we can use than the literal string "missing"? Maybe `missing`, `nothing`, or a symbol `:missing`?
Yeah, we discussed during the call how we can use `missing` here. I'll add a review that includes that discussion.
src/rowtransforms.jl
Outdated
function DataAugmentation.apply(tfm::NormalizeRow, item::TabularItem; randstate=nothing)
    x = (; zip(item.columns, [data for data in item.data])...)
    for col in tfm.normcols
Instead of iterating the columns twice and having Setfield repeatedly construct a `NamedTuple`, perhaps look into a helper function that does the normalization given the `tfm`, column name, and value? The function could check whether the column is in `normcols` and transform it with `normstats` if it is.
Yeah, do we even need Setfield anymore if we're standardizing on a `NamedTuple`? We could just build the transformed data as a vector or something, then construct a `NamedTuple` at the end.
Yeah we should be able to build it at the end.
src/rowtransforms.jl
Outdated
function DataAugmentation.apply(tfm::Categorify, item::TabularItem; randstate=nothing)
    x = (; zip(item.columns, [data for data in item.data])...)
    for col in tfm.categorycols
Same comment here about double iteration of the columns. I suppose it applies to `FillMissing` as well :)
Did you mean something like this?
function tfmrowvals(tfm::NormalizeRow, col, val)
    if col in tfm.cols
        colmean, colstd = tfm.dict[col]
        val = (val - colmean)/colstd
    end
    (col, val)
end
function apply(tfm::NormalizeRow, item; randstate=nothing)
    TabularItem((;
        tfmrowvals.(
            [tfm for _ in 1:length(item.columns)],
            item.columns,
            [val for val in item.data])...
        ),
        item.columns
    )
end
If this is better than the current implementation, we can even have a single `apply` which works on a `Union` of all the transforms, and different methods for `tfmrowvals`.
Probably still need to dispatch on each type separately in order to know which `tfmrowvals` function to call, right?
Yeah, there will probably be 3 `tfmrowvals` methods.
function apply(tfm::NormalizeRow, item; randstate=nothing)
    x = NamedTuple(Iterators.map(item.columns, item.data) do col, val
        if col in tfm.cols
            colmean, colstd = tfm.dict[col]
            val = (val - colmean)/colstd
        end
        (col, val)
    end)
end
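For readers who want to try the single-pass idea outside the package, here is a minimal, self-contained sketch. The `TabularItem` and `NormalizeRow` definitions are hypothetical stand-ins (the PR's real types subtype DataAugmentation types), and the sketch assumes `item.data` is a `NamedTuple` and `item.columns` a tuple of symbols.

```julia
# Hypothetical stand-in for the PR's TabularItem: row values plus column names.
struct TabularItem{T, S}
    data::T
    columns::S
end

# Hypothetical stand-in: `dict` maps a column name to its (mean, std) pair.
struct NormalizeRow{T, S}
    dict::T
    cols::S
end

function apply(tfm::NormalizeRow, item::TabularItem; randstate=nothing)
    # Single pass over (column, value); columns not in `tfm.cols` pass through.
    kvs = map(item.columns, values(item.data)) do col, val
        if col in tfm.cols
            colmean, colstd = tfm.dict[col]
            val = (val - colmean) / colstd
        end
        col => val
    end
    # Build the NamedTuple once, at the end, instead of using Setfield.
    return TabularItem((; kvs...), item.columns)
end

item = TabularItem((a = 4.0, b = "x", c = 10.0), (:a, :b, :c))
tfm = NormalizeRow(Dict(:a => (2.0, 2.0), :c => (0.0, 5.0)), (:a, :c))
normalized = apply(tfm, item)
```

In this sketch the result is wrapped in a `TabularItem` before returning, so the transform's output has the same type as its input.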
Also, I think you are using TAB for indents. Could you convert to using 4 spaces?
src/rowtransforms.jl
Outdated
function DataAugmentation.apply(tfm::FillMissing, item::TabularItem; randstate=nothing)
    x = (; zip(item.columns, [data for data in item.data])...)
    for col in tfm.contcols
        if ismissing(x[col])
            Setfield.@set! x[col] = tfm.fmvals[col]
        end
    end
    for col in tfm.catcols
        if ismissing(x[col])
            Setfield.@set! x[col] = "missing"
        end
    end
    TabularItem(x, item.columns)
end
Suggested change:
function DataAugmentation.apply(tfm::FillMissing, item::TabularItem; randstate=nothing)
    x = (; zip(item.columns, [data for data in item.data])...)
    for col in tfm.contcols
        if ismissing(x[col])
            Setfield.@set! x[col] = tfm.fmvals[col]
        end
    end
    TabularItem(x, item.columns)
end
Unless we want to allow the categorical `missing`s to be filled too.
src/rowtransforms.jl
Outdated
    if ismissing(x[col])
        Setfield.@set! x[col] = "missing"
    end
    Setfield.@set! x[col] = tfm.pooldict[col].invindex[x[col]]
Suggested change:
    if ismissing(x[col])
        Setfield.@set! x[col] = 0
    else
        Setfield.@set! x[col] = findfirst(tfm.categories .== col)
    end
Wouldn't this give the same value for a column?
Also, seeing that the `Embedding` layers won't work with 0 indexing, we should probably try to avoid it.
Made a typo, it should be `x[col]`.
Then let's make `missing == 1` and do `+ 1` for the other columns.
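A small sketch of the 1-based indexing scheme agreed on here, with index 1 reserved for `missing`. The helper name `categoryindex` is hypothetical, and it uses `isequal` plus an explicit error for unknown values, which the snippets in this thread do not do.

```julia
# Hypothetical helper: index of `val` among `categories`, shifted by one
# so that index 1 is reserved for `missing`.
function categoryindex(val, categories)
    ismissing(val) && return 1
    idx = findfirst(isequal(val), categories)
    idx === nothing && throw(ArgumentError("unknown category: $val"))
    return idx + 1
end

categories = ["red", "green", "blue"]
idx_missing = categoryindex(missing, categories)
idx_red = categoryindex("red", categories)
idx_blue = categoryindex("blue", categories)
```

With three categories the indices run from 1 to 4, so an embedding layer fed by this mapping would need one more row than there are categories.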
src/rowtransforms.jl
Outdated
function getcategorypools(catdict, catcols)
    pooldict = Dict()
    for col in catcols
        catarray = CategoricalArrays.categorical(catdict[col])
        CategoricalArrays.levels!(catarray, ["missing", CategoricalArrays.levels(catarray)...])
        pooldict[col] = catarray.pool
    end
    pooldict
end
Suggested change:
function getcategorypools(catdict, catcols)
    pooldict = Dict()
    for col in catcols
        catarray = CategoricalArrays.categorical(catdict[col])
        CategoricalArrays.levels!(catarray, ["missing", CategoricalArrays.levels(catarray)...])
        pooldict[col] = catarray.pool
    end
    pooldict
end
src/rowtransforms.jl
Outdated
struct Categorify <: DataAugmentation.Transform
    pooldict
    categorycols
end
Suggested change:
struct Categorify{T, S} <: DataAugmentation.Transform
    categories::T
    categorycols::S
end
Two changes: swap to `categories` to just be a vector of the categories. I don't think we need the complexity of categorical arrays when the mapping is just the index in a list of categories passed by the user.
To reduce the complexity, we could just use the `catdict` used in `getcategorypool` directly for `Categorify`. A vector of vectors (or a `NamedTuple`) with the classes for each categorical column could work as well.
I'm not sure this will work if `categories` is just a vector of categorical column names, as we'll have to replace the class for a categorical column with an integer, and for doing this we'll need information about all the classes which are present in a particular column.
Yeah, sorry, it should be a `NamedTuple`/`Dict`.
src/rowtransforms.jl
Outdated
struct NormalizeRow <: DataAugmentation.Transform
    normstats
    normcols
end
Suggested change:
struct NormalizeRow{T, S} <: DataAugmentation.Transform
    normstats::T
    normcols::S
end
src/rowtransforms.jl
Outdated
struct FillMissing <: DataAugmentation.Transform
    fmvals
    contcols
    catcols
end
Suggested change:
struct FillMissing{T, S} <: DataAugmentation.Transform
    fmvals::T
    contcols::S
end
Co-authored-by: lorenzoh <[email protected]>
…ard/DataAugmentation.jl into manikyabard/tabulartfms
I also wonder if we should have a transform that maps to `Flux.OneHotArray` instead of just categorical indices.
src/rowtransforms.jl
Outdated
    x = [val for val in item.data]
    for col in tfm.categorycols
        idx = findfirst(col .== item.columns)
        x[idx] = ismissing(x[idx]) ? 1 : findfirst(skipmissing(x[idx] .== tfm.catdict[col])) + 1
Suggested change:
        x[idx] = ismissing(x[idx]) ? 1 : findfirst(x[idx] .== tfm.catdict[col]) + 1
No need for `skipmissing` here, right? `x[idx]` is a value and `tfm.catdict[col]` is a vector of categorical values (which doesn't contain `missing`). `findfirst` is just assigning an index based on which symbol in `tfm.catdict[col]` matches `x[idx]`.
Initially I was thinking that if someone creates `catdict` using `unique` or something, and `missing` somehow ends up in this vector, then an error could be thrown. But yeah, it might just be better to remove it.
Probably better to `map(v -> filter!(!ismissing, v), values(catdict))` when constructing the transform. We could throw a warning when that happens too.
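One way the filter-and-warn step could look is sketched below. The function name `stripmissing` is hypothetical, and unlike the `filter!` one-liner above it returns a new `Dict` rather than mutating the caller's vectors.

```julia
# Hypothetical helper: drop `missing` from every category vector and warn
# whenever something was actually dropped.
function stripmissing(catdict::AbstractDict)
    cleaned = Dict{keytype(catdict), Vector}()
    for (col, vals) in catdict
        if any(ismissing, vals)
            @warn "Dropping `missing` from the categories of column $col"
        end
        # skipmissing + collect leaves the input vector untouched.
        cleaned[col] = collect(skipmissing(vals))
    end
    return cleaned
end

catdict = Dict(:color => ["red", missing, "blue"], :size => ["S", "M"])
cleaned = stripmissing(catdict)
```

Calling this from the transform's constructor would mean `apply` never has to consider a `missing` entry in the category list.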
src/rowtransforms.jl
Outdated
FillMissing(fmvals::T, fmcols::S) where {T, S} = FillMissing{T, S}(fmvals, fmcols)

function DataAugmentation.apply(tfm::FillMissing, item::TabularItem; randstate=nothing)
    x = [val for val in item.data]
Does `collect(item.data)` not work?
We should be able to use that.
Co-authored-by: Kyle Daruwalla <[email protected]>
This is looking really clean now; nice job!
src/rowtransforms.jl
Outdated
    cols::S
    function Categorify{T, S}(dict::T, cols::S) where {T, S}
        for (col, vals) in dict
            dict[col] = append!([], [missing], collect(skipmissing(Set(vals))))
I think here you want to do `SortedSet` from DataStructures.jl. And you don't need `skipmissing` first, cause pushing `missing` onto a set that already contains it is a no-op (`AbstractSet`s can't contain duplicates). Since it is sorted, `missing` will always map to the same index too (addressing @ToucheSir's concern from the call).
            val = (val - colmean)/colstd
        end
        (col, val)
    end)
Suggested change:
    end)
    return TabularItem(x, item.columns)
And all the other transforms too
function apply(tfm::Categorify, item; randstate=nothing)
    x = NamedTuple(Iterators.map(item.columns, item.data) do col, val
        if col in tfm.cols
            val = ismissing(val) ? 1 : findfirst(val .== tfm.dict[col]) + 1
This can just be `findfirst` if we use `SortedSet`.
Would `findfirst` work when the input function involves comparing with `missing`?
Yes, because `tfm.dict[col]` always contains `missing`, and `missing` is treated as any other element in the set.
Ah, I see what you mean. An equality comparison with `missing` is `missing`.
I think as a result of this, the whole idea of storing `missing` in `tfm.dict[col]` is not going to work. We'll have to revert to the old filtering way + the conditional shown here.
In that case, there's no need to store `missing` in the dict values either then, right? The conditional is required either way.
Yeah, that's what I meant. Filter the `missing` out of the dict, and don't add it if it isn't there.
Alright, I have updated the constructor to use `skipmissing` and `collect` for the values containing `missing`.
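The point about equality can be checked directly in Base Julia. This short sketch is only an illustration of three-valued logic, not code from the PR: `==` propagates `missing`, while `isequal` and `ismissing` always return a `Bool` and so are safe to use as `findfirst` predicates.

```julia
# `==` with `missing` on either side yields `missing`, not a Bool.
eq_value = ("a" == missing)
eq_both = (missing == missing)

# `isequal` treats `missing` as equal to itself and returns a Bool.
iseq_both = isequal(missing, missing)

# Predicates that return a Bool can locate `missing` in a vector.
vals = ["a", missing, "b"]
pos_isequal = findfirst(isequal(missing), vals)
pos_ismissing = findfirst(ismissing, vals)
```

This is why the thread settles on an explicit `ismissing(val)` branch before the `val .== tfm.dict[col]` comparison, with `missing` filtered out of the stored categories.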
src/rowtransforms.jl
Outdated
    TabularItem(x, item.columns)
Categorify(dict::T, cols::S) where {T, S} = Categorify{T, S}(dict, cols)

function apply(tfm::NormalizeRow, item; randstate=nothing)
Is `randstate` an artifact from Python? Or is it part of the DataAugmentation interface? What is its role here?
Yeah, I can't see which of these transforms requires an RNG.
Yeah, even though `randstate` isn't required for the tabular transformations, I put it there because it was a part of the transformation interface. I think internally for compositions, `apply` is called along with `randstate` args so everything might not work without it.
Yeah, it's because of how the dispatch is set up.
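To illustrate the dispatch argument, here is a toy composition that forwards a `randstate` keyword to every transform. All names here (`Transform`, `AddOne`, `Double`, `applyall`) are hypothetical and this is not how DataAugmentation.jl implements composition; it only shows why a deterministic transform still has to accept the keyword when a caller always passes it.

```julia
abstract type Transform end
struct AddOne <: Transform end
struct Double <: Transform end

# Deterministic transforms ignore `randstate` but must still accept it.
apply(::AddOne, x; randstate=nothing) = x + 1
apply(::Double, x; randstate=nothing) = 2x

# Hypothetical composition: passes one random state per transform.
function applyall(tfms, x; randstate=ntuple(_ -> nothing, length(tfms)))
    for (tfm, rs) in zip(tfms, randstate)
        x = apply(tfm, x; randstate=rs)
    end
    return x
end

result = applyall((AddOne(), Double()), 3)
```

If one of the `apply` methods omitted the keyword, the call inside `applyall` would fail with a `MethodError` for an unsupported keyword argument.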
Co-authored-by: Kyle Daruwalla <[email protected]>
Just fix the tab issue
Co-authored-by: Kyle Daruwalla <[email protected]>
Still needs some tests
The tests should be fixed now.
Just some nits on the tests. Is there a way to trigger a CI run on the latest commit?
test/rowtransforms.jl
Outdated
normdict = Dict(:col1 => (col1_mean, col1_std), :col3 => (col3_mean, col3_std))

tfm = NormalizeRow(normdict, cols_to_normalize)
# @test_nowarn apply(tfm, item)
Note to delete dangling comment before submission
@lorenzoh will have to trigger it for "first-time contributors."
Co-authored-by: Brian Chen <[email protected]>
Adds `TabularItem` for holding table row values and some transformations for it.