Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix decimal bug #10

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 27 additions & 2 deletions dbms/src/AggregateFunctions/AggregateFunctionAvg.h
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,16 @@ class AggregateFunctionAvg final : public IAggregateFunctionDataHelper<Aggregate
}
String getName() const override { return "avg"; }

AggregateDataPtr adjust_place(AggregateDataPtr place) const {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

These two functions are implemented twice, in both avg and sum. They can simply be static and put somewhere common.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I recommend to name these two functions more specifically as alignDecimalPlace, put more comments, and use the camel-case naming as CH does (i.e. alignDecimalPlace rather than align_decimal_place).

size_t pad = (32 - uint64_t(place) % 32) % 32;
return place + pad;
}

AggregateDataPtr adjust_place(ConstAggregateDataPtr place) const {
size_t pad = (32 - uint64_t(place) % 32) % 32;
return (char*)(place + pad);
}

DataTypePtr getReturnType() const override
{
if constexpr (IsDecimal<T>)
Expand All @@ -54,33 +64,48 @@ class AggregateFunctionAvg final : public IAggregateFunctionDataHelper<Aggregate

void add(AggregateDataPtr place, const IColumn ** columns, size_t row_num, Arena *) const override
{
if constexpr (IsDecimal<T>) {
place = adjust_place(place);
}
this->data(place).sum += static_cast<const ColumnVector<T> &>(*columns[0]).getData()[row_num];
++this->data(place).count;
}

void merge(AggregateDataPtr place, ConstAggregateDataPtr rhs, Arena *) const override
{
if constexpr (IsDecimal<T>) {
place = adjust_place(place);
rhs = adjust_place(rhs);
}
this->data(place).sum += this->data(rhs).sum;
this->data(place).count += this->data(rhs).count;
}

void serialize(ConstAggregateDataPtr place, WriteBuffer & buf) const override
{
if constexpr (IsDecimal<T>) {
place = adjust_place(place);
}
writeBinary(this->data(place).sum, buf);
writeVarUInt(this->data(place).count, buf);
}

void deserialize(AggregateDataPtr place, ReadBuffer & buf, Arena *) const override
{
if constexpr (IsDecimal<T>) {
place = adjust_place(place);
}
readBinary(this->data(place).sum, buf);
readVarUInt(this->data(place).count, buf);
}

void insertResultInto(ConstAggregateDataPtr place, IColumn & to) const override
{
if constexpr (IsDecimal<T>)
if constexpr (IsDecimal<T>) {
place = adjust_place(place);
static_cast<ColumnDecimal &>(to).getData().push_back(
this->data(place).sum.getAvg(this->data(place).count, result_prec, result_scale));
}
else
static_cast<ColumnFloat64 &>(to).getData().push_back(
static_cast<Float64>(this->data(place).sum) / this->data(place).count);
Expand All @@ -89,7 +114,7 @@ class AggregateFunctionAvg final : public IAggregateFunctionDataHelper<Aggregate
void create(AggregateDataPtr place) const override {
using Data = AggregateFunctionAvgData<typename NearestFieldType<T>::Type>;
if constexpr (IsDecimal<T>)
new (place) Data(result_prec, result_scale);
new (adjust_place(place)) Data(result_prec, result_scale);
else
new (place) Data;
}
Expand Down
29 changes: 26 additions & 3 deletions dbms/src/AggregateFunctions/AggregateFunctionSum.h
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,7 @@ class AggregateFunctionSum final : public IAggregateFunctionDataHelper<Data, Agg
PrecType result_prec;

AggregateFunctionSum(){}

AggregateFunctionSum(PrecType prec, ScaleType scale) {
SumDecimalInferer::infer(prec, scale, result_prec, result_scale);
};
Expand All @@ -120,35 +120,58 @@ class AggregateFunctionSum final : public IAggregateFunctionDataHelper<Data, Agg
}
}

AggregateDataPtr adjust_place(AggregateDataPtr place) const {
size_t pad = (32 - uint64_t(place) % 32) % 32;
return place + pad;
}

AggregateDataPtr adjust_place(ConstAggregateDataPtr place) const {
size_t pad = (32 - uint64_t(place) % 32) % 32;
return (char*)(place + pad);
}

void create(AggregateDataPtr place) const override {
if constexpr (IsDecimal<T> && IsDecimal<TResult>)
new (place) Data(result_prec, result_scale);
if constexpr (IsDecimal<T> && IsDecimal<TResult>) {
new (adjust_place(place)) Data(result_prec, result_scale);
}
else
new (place) Data;
}

void add(AggregateDataPtr place, const IColumn ** columns, size_t row_num, Arena *) const override
{
if constexpr (IsDecimal<T> && IsDecimal<TResult>)
place = adjust_place(place);
this->data(place).add(static_cast<const ColumnVector<T> &>(*columns[0]).getData()[row_num]);
}

void merge(AggregateDataPtr place, ConstAggregateDataPtr rhs, Arena *) const override
{
if constexpr (IsDecimal<T> && IsDecimal<TResult>) {
place = adjust_place(place);
rhs = adjust_place(rhs);
}
this->data(place).merge(this->data(rhs));
}

void serialize(ConstAggregateDataPtr place, WriteBuffer & buf) const override
{
if constexpr (IsDecimal<T> && IsDecimal<TResult>)
place = adjust_place(place);
this->data(place).write(buf);
}

void deserialize(AggregateDataPtr place, ReadBuffer & buf, Arena *) const override
{
if constexpr (IsDecimal<T> && IsDecimal<TResult>)
place = adjust_place(place);
this->data(place).read(buf);
}

void insertResultInto(ConstAggregateDataPtr place, IColumn & to) const override
{
if constexpr (IsDecimal<T> && IsDecimal<TResult>)
place = adjust_place(place);
static_cast<ColumnVector<TResult> &>(to).getData().push_back(this->data(place).get());
}

Expand Down
2 changes: 2 additions & 0 deletions dbms/src/AggregateFunctions/IAggregateFunction.h
Original file line number Diff line number Diff line change
Expand Up @@ -154,6 +154,8 @@ class IAggregateFunctionDataHelper : public IAggregateFunctionHelper<Derived>

size_t sizeOfData() const override
{
if constexpr (IsDecimal<T>)
return sizeof(Data) + 32;
return sizeof(Data);
}

Expand Down