Fix fcase() segfault (#6452) #6451

Merged
2 changes: 2 additions & 0 deletions NEWS.md
@@ -55,6 +55,8 @@ rowwiseDT(

5. Queries like `DT[, min(x):max(x)]` now work as expected, i.e. the same as `DT[, seq(min(x), max(x))]` or `with(DT, min(x):max(x))`, [#2069](https://github.com/Rdatatable/data.table/issues/2069). Shorthand like `DT[, a:b]` meaning "select from columns `a` through `b`" still works. Thanks to @franknarf1 for reporting and @jangorecki for the fix.

6. Fixed a segfault in `fcase()`, [#6448](https://github.com/Rdatatable/data.table/issues/6448). Thanks @ethanbsmith for reporting with reprex, @aitap for finding the root cause, and @MichaelChirico for the PR.

## NOTES

1. Tests run again when some Suggests packages are missing, [#6411](https://github.com/Rdatatable/data.table/issues/6411). Thanks @aadler for the note and @MichaelChirico for the fix.
11 changes: 11 additions & 0 deletions inst/tests/tests.Rraw
@@ -19182,3 +19182,14 @@ test(2285.09, merge(merge(y, x), z), data.table(a=3L, key="a"))
test(2285.10, merge(merge(y, z), x), data.table(a=3L, key="a"))
test(2285.11, merge(merge(z, x), y), data.table(a=3L, key="a"))
test(2285.12, merge(merge(z, y), x), data.table(a=3L, key="a"))

# ensure proper PROTECT() within fcase, #6448
x <- 1:3
test(2286,
fcase(
x<2, structure(list(1), class = "foo"),
x<3, structure(list(2), class = "foo"),
# Force gc() and some allocations which have a good chance at landing in the region that was earlier left unprotected
{ gc(full = TRUE); replicate(10, FALSE); x<4 },
`attr<-`(list(3), "class", "foo")),
structure(list(1, 2, 3), class = "foo"))
68 changes: 37 additions & 31 deletions src/fifelse.c
@@ -215,14 +215,16 @@ SEXP fcaseR(SEXP rho, SEXP args) {
"Note that the default argument must be named explicitly, e.g., default=0"), narg - 2);
}
int nprotect=0, l;
int64_t len0=0, len1=0, len2=0;
SEXP ans=R_NilValue, value0=R_NilValue, tracker=R_NilValue, whens=R_NilValue, thens=R_NilValue;
int64_t n_ans=0, n_this_arg=0, n_undecided=0;
Review comment (Member):

`n_undecided` sounds good to indicate the count of the unresolved elements/cases. An alternative I can think of is `n_pending` (as in tracking how many elements are left to be evaluated).

Is `n_this_arg` meant to represent the length of the `thens` argument being processed in the current iteration? (I only see it assigned at line 285 below, so I'm wondering if it could be renamed to something specific to that.)

Reply (Member Author):

Thanks! Yes, you're right about `n_this_arg`. I'm roughly indifferent between `n_undecided` and `n_pending`; both are way better than `len2` :)

SEXP ans=R_NilValue, tracker=R_NilValue, whens=R_NilValue, thens=R_NilValue;
SEXP ans_class, ans_levels;
PROTECT_INDEX Iwhens, Ithens;
PROTECT_WITH_INDEX(whens, &Iwhens); nprotect++;
PROTECT_WITH_INDEX(thens, &Ithens); nprotect++;
SEXPTYPE type0=NILSXP;
SEXPTYPE ans_type=NILSXP;
// naout means if the output is scalar logic na
bool imask = true, naout = false, idefault = false;
bool ans_is_factor;
int *restrict p = NULL;
const int n = narg/2;
for (int i=0; i<n; ++i) {
@@ -238,35 +240,39 @@ SEXP fcaseR(SEXP rho, SEXP args) {
const int *restrict pwhens = LOGICAL(whens);
l = 0;
if (i == 0) {
len0 = xlength(whens);
len2 = len0;
type0 = TYPEOF(thens);
value0 = thens;
ans = PROTECT(allocVector(type0, len0)); nprotect++;
n_ans = xlength(whens);
n_undecided = n_ans;
ans_type = TYPEOF(thens);
ans_class = PROTECT(getAttrib(thens, R_ClassSymbol)); nprotect++;
ans_is_factor = isFactor(thens);
if (ans_is_factor) {
ans_levels = PROTECT(getAttrib(thens, R_LevelsSymbol)); nprotect++;
}
ans = PROTECT(allocVector(ans_type, n_ans)); nprotect++;
copyMostAttrib(thens, ans);
tracker = PROTECT(allocVector(INTSXP, len0)); nprotect++;
tracker = PROTECT(allocVector(INTSXP, n_ans)); nprotect++;
p = INTEGER(tracker);
} else {
imask = false;
naout = xlength(thens) == 1 && TYPEOF(thens) == LGLSXP && LOGICAL(thens)[0]==NA_LOGICAL;
if (xlength(whens) != len0) {
if (xlength(whens) != n_ans) {
// no need to check `idefault` here because the con for default is always `TRUE`
error(_("Argument #%d has length %lld which differs from that of argument #1 (%lld). "
"Please make sure all logical conditions have the same length."),
i*2+1, (long long)xlength(whens), (long long)len0);
i*2+1, (long long)xlength(whens), (long long)n_ans);
}
if (!naout && TYPEOF(thens) != type0) {
if (!naout && TYPEOF(thens) != ans_type) {
if (idefault) {
error(_("Resulting value is of type %s but 'default' is of type %s. "
"Please make sure that both arguments have the same type."), type2char(type0), type2char(TYPEOF(thens)));
"Please make sure that both arguments have the same type."), type2char(ans_type), type2char(TYPEOF(thens)));
} else {
error(_("Argument #%d is of type %s, however argument #2 is of type %s. "
"Please make sure all output values have the same type."),
i*2+2, type2char(TYPEOF(thens)), type2char(type0));
i*2+2, type2char(TYPEOF(thens)), type2char(ans_type));
}
}
if (!naout) {
if (!R_compute_identical(PROTECT(getAttrib(value0, R_ClassSymbol)), PROTECT(getAttrib(thens, R_ClassSymbol)), 0)) {
if (!R_compute_identical(ans_class, PROTECT(getAttrib(thens, R_ClassSymbol)), 0)) {
if (idefault) {
error(_("Resulting value has different class than 'default'. "
"Please make sure that both arguments have the same class."));
@@ -275,35 +281,35 @@ SEXP fcaseR(SEXP rho, SEXP args) {
"Please make sure all output values have the same class."), i*2+2);
}
}
UNPROTECT(2); // class(value0), class(thens)
UNPROTECT(1); // class(thens)
}
if (!naout && isFactor(value0)) {
if (!R_compute_identical(PROTECT(getAttrib(value0, R_LevelsSymbol)), PROTECT(getAttrib(thens, R_LevelsSymbol)), 0)) {
if (!naout && ans_is_factor) {
if (!R_compute_identical(ans_levels, PROTECT(getAttrib(thens, R_LevelsSymbol)), 0)) {
if (idefault) {
error(_("Resulting value and 'default' are both type factor but their levels are different."));
} else {
error(_("Argument #2 and argument #%d are both factor but their levels are different."), i*2+2);
}
}
UNPROTECT(2); // levels(value0), levels(thens)
UNPROTECT(1); // levels(thens)
}
}
len1 = xlength(thens);
if (len1 != len0 && len1 != 1) {
n_this_arg = xlength(thens);
if (n_this_arg != n_ans && n_this_arg != 1) {
if (idefault) {
error(_("Length of 'default' must be 1 or %lld."), (long long)len0);
error(_("Length of 'default' must be 1 or %lld."), (long long)n_ans);
} else {
error(_("Length of output value #%d (%lld) must either be 1 or match the length of the logical condition (%lld)."), i*2+2, (long long)len1, (long long)len0);
error(_("Length of output value #%d (%lld) must either be 1 or match the length of the logical condition (%lld)."), i*2+2, (long long)n_this_arg, (long long)n_ans);
}
}
int64_t thenMask = len1>1 ? INT64_MAX : 0;
int64_t thenMask = n_this_arg>1 ? INT64_MAX : 0;
switch(TYPEOF(ans)) {
case LGLSXP: {
const int *restrict pthens;
if (!naout) pthens = LOGICAL(thens); // the content is not useful if out is NA_LOGICAL scalar
int *restrict pans = LOGICAL(ans);
const int pna = NA_LOGICAL;
for (int64_t j=0; j<len2; ++j) {
for (int64_t j=0; j<n_undecided; ++j) {
const int64_t idx = imask ? j : p[j];
if (pwhens[idx]==1) {
pans[idx] = naout ? pna : pthens[idx & thenMask];
@@ -320,7 +326,7 @@ SEXP fcaseR(SEXP rho, SEXP args) {
if (!naout) pthens = INTEGER(thens); // the content is not useful if out is NA_LOGICAL scalar
int *restrict pans = INTEGER(ans);
const int pna = NA_INTEGER;
for (int64_t j=0; j<len2; ++j) {
for (int64_t j=0; j<n_undecided; ++j) {
const int64_t idx = imask ? j : p[j];
if (pwhens[idx]==1) {
pans[idx] = naout ? pna : pthens[idx & thenMask];
@@ -338,7 +344,7 @@ SEXP fcaseR(SEXP rho, SEXP args) {
double *restrict pans = REAL(ans);
const double na_double = INHERITS(ans, char_integer64) ? NA_INT64_D : NA_REAL;
const double pna = na_double;
for (int64_t j=0; j<len2; ++j) {
for (int64_t j=0; j<n_undecided; ++j) {
const int64_t idx = imask ? j : p[j];
if (pwhens[idx]==1) {
pans[idx] = naout ? pna : pthens[idx & thenMask];
@@ -355,7 +361,7 @@ SEXP fcaseR(SEXP rho, SEXP args) {
if (!naout) pthens = COMPLEX(thens); // the content is not useful if out is NA_LOGICAL scalar
Rcomplex *restrict pans = COMPLEX(ans);
const Rcomplex pna = NA_CPLX;
for (int64_t j=0; j<len2; ++j) {
for (int64_t j=0; j<n_undecided; ++j) {
const int64_t idx = imask ? j : p[j];
if (pwhens[idx]==1) {
pans[idx] = naout ? pna : pthens[idx & thenMask];
@@ -371,7 +377,7 @@ SEXP fcaseR(SEXP rho, SEXP args) {
const SEXP *restrict pthens=NULL;
if (!naout) pthens = STRING_PTR_RO(thens); // the content is not useful if out is NA_LOGICAL scalar
const SEXP pna = NA_STRING;
for (int64_t j=0; j<len2; ++j) {
for (int64_t j=0; j<n_undecided; ++j) {
const int64_t idx = imask ? j : p[j];
if (pwhens[idx]==1) {
SET_STRING_ELT(ans, idx, naout ? pna : pthens[idx & thenMask]);
@@ -388,7 +394,7 @@ SEXP fcaseR(SEXP rho, SEXP args) {
// assign the NA values as it does for other atomic types
const SEXP *restrict pthens=NULL;
if (!naout) pthens = SEXPPTR_RO(thens); // the content is not useful if out is NA_LOGICAL scalar
for (int64_t j=0; j<len2; ++j) {
for (int64_t j=0; j<n_undecided; ++j) {
const int64_t idx = imask ? j : p[j];
if (pwhens[idx]==1) {
if (!naout) SET_VECTOR_ELT(ans, idx, pthens[idx & thenMask]);
@@ -403,7 +409,7 @@ SEXP fcaseR(SEXP rho, SEXP args) {
if (l==0) {
break; // stop early as nothing left to do
}
len2 = l;
n_undecided = l;
}
UNPROTECT(nprotect); // whens, thens, ans, tracker
return ans;
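
For readers following the renames, the role of `n_ans`, `n_this_arg`, `n_undecided`, the tracker array, and `thenMask` can be sketched in plain C without the R API. This is a minimal illustration under stated assumptions, not the code in `src/fifelse.c`: the names `case_pair`, `apply_case`, and `fcase_int` are hypothetical, only the integer case is modelled, and the output is pre-filled with a caller-supplied NA stand-in rather than assigned during the passes.

```c
#include <stdbool.h>
#include <stdint.h>

/* Hypothetical type: one (condition, value) pair of an fcase()-like call. */
typedef struct {
  const bool *when;   /* condition, length n_ans */
  const int  *then;   /* value, length n_ans or 1 */
  int64_t     n_then; /* plays the role of n_this_arg */
} case_pair;

/* Hypothetical helper: apply one pair to the still-undecided elements.
 * On the first pass every index 0..n_undecided-1 is visited directly;
 * on later passes only the indices recorded in tracker are visited.
 * Returns how many elements are still undecided afterwards. */
static int64_t apply_case(int *ans, int *tracker, int64_t n_undecided,
                          bool first_pass, const case_pair *pair)
{
  /* Same idea as thenMask in the diff: a length-1 value is recycled by
   * masking the index down to 0, a full-length value keeps the index. */
  const int64_t then_mask = pair->n_then > 1 ? INT64_MAX : 0;
  int64_t left = 0;
  for (int64_t j = 0; j < n_undecided; ++j) {
    const int64_t idx = first_pass ? j : tracker[j];
    if (pair->when[idx]) {
      ans[idx] = pair->then[idx & then_mask];
    } else {
      /* Compact in place: left <= j always holds, so no unread slot
       * of tracker is overwritten. */
      tracker[left++] = (int)idx;
    }
  }
  return left;
}

/* Hypothetical driver: the first matching condition wins for each element.
 * Elements matched by no condition keep na_value (a simplification). */
static void fcase_int(int *ans, int *tracker, int64_t n_ans,
                      const case_pair *pairs, int n_pairs, int na_value)
{
  for (int64_t i = 0; i < n_ans; ++i) ans[i] = na_value;
  int64_t n_undecided = n_ans;
  for (int k = 0; k < n_pairs; ++k) {
    n_undecided = apply_case(ans, tracker, n_undecided, k == 0, &pairs[k]);
    if (n_undecided == 0) break; /* stop early as nothing is left to do */
  }
}
```

Each pass only touches elements that no earlier condition has claimed, so the work per pass shrinks as `n_undecided` shrinks, which is what the rename from `len2` is meant to convey.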