Skip to content

Commit

Permalink
cleanup
Browse files Browse the repository at this point in the history
  • Loading branch information
LucasWilkinson committed Aug 9, 2024
1 parent 2dc6923 commit 28798f9
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 6 deletions.
8 changes: 4 additions & 4 deletions csrc/core/scalar_type.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ class ScalarType {
bias(bias),
signed_(signed_),
finite_values_only(finite_values_only),
nan_repr(nan_repr) {};
nan_repr(nan_repr){};

static constexpr ScalarType int_(uint8_t size_bits, int32_t bias = 0) {
return ScalarType(true, 0, size_bits - 1, bias);
Expand Down Expand Up @@ -287,9 +287,9 @@ class ScalarTypeTorch : public torch::CustomClassHolder, public ScalarType {
public:
ScalarTypeTorch(int64_t exponent, int64_t mantissa, int64_t bias,
bool _signed)
: ScalarType(exponent, mantissa, bias, _signed) {};
: ScalarType(exponent, mantissa, bias, _signed){};

ScalarTypeTorch(ScalarType type) : ScalarType(type) {};
ScalarTypeTorch(ScalarType type) : ScalarType(type){};

using Base = ScalarType;
using Self = ScalarTypeTorch;
Expand Down Expand Up @@ -361,7 +361,7 @@ class ScalarTypeTorch : public torch::CustomClassHolder, public ScalarType {
getter_func_helper = std::move(getter_func_helper)](
SelfPtr const& self) {
auto val = getter_func_helper(self);
// upconvert uint8_t, int32_t ect. to int64_t for python
// upconvert uint8_t, int32_t etc. to int64_t for python
if constexpr (std::is_integral_v<T>) {
return static_cast<int64_t>(val);
} else {
Expand Down
4 changes: 2 additions & 2 deletions csrc/quantization/gptq_marlin/gptq_marlin.cu
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ __global__ void permute_cols_kernel(int4 const* __restrict__ a_int4_ptr,
int size_k, int block_rows) {}

template <typename scalar_t, // compute dtype, half or nv_float16
const int num_bits, // number of bits used for weights
const int64_t w_type_id, // weight ScalarType id
const int threads, // number of threads in a threadblock
const int thread_m_blocks, // number of 16x16 blocks in the m
// dimension (batchsize) of the
Expand Down Expand Up @@ -507,7 +507,7 @@ __global__ void permute_cols_kernel(int4 const* __restrict__ a_int4_ptr,
}

template <typename scalar_t, // compute dtype, half or nv_float16
const int64_t w_type_id, // number of bits used for weights
const int64_t w_type_id, // weight ScalarType id
const int threads, // number of threads in a threadblock
const int thread_m_blocks, // number of 16x16 blocks in the m
// dimension (batchsize) of the
Expand Down

0 comments on commit 28798f9

Please sign in to comment.