Skip to content

Commit

Permalink
Ran format on save from VScode
Browse files Browse the repository at this point in the history
  • Loading branch information
rickardp committed Feb 4, 2024
1 parent fdddb11 commit b7503c9
Show file tree
Hide file tree
Showing 4 changed files with 39 additions and 39 deletions.
2 changes: 1 addition & 1 deletion .github/dependabot.yml
Original file line number Diff line number Diff line change
Expand Up @@ -8,4 +8,4 @@ updates:
major:
update-types: [major]
minor-patch:
update-types: [minor, patch]
update-types: [minor, patch]
4 changes: 2 additions & 2 deletions .github/workflows/python-package.yml
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ on:
types: [ published ]

jobs:

##
# This job matrix builds the non-CUDA versions of the libraries for all supported platforms.
##
Expand Down Expand Up @@ -120,7 +120,7 @@ jobs:
build_os=${{ matrix.os }}
build_arch=${{ matrix.arch }}
for NO_CUBLASLT in ON OFF; do
if [ ${build_os:0:6} == ubuntu ]; then
if [ ${build_os:0:6} == ubuntu ]; then
image=nvidia/cuda:${{ matrix.cuda_version }}-devel-ubuntu22.04
echo "Using image $image"
docker run --platform linux/$build_arch -i -w /src -v $PWD:/src $image sh -c \
Expand Down
60 changes: 30 additions & 30 deletions csrc/mps_kernels.metal
Original file line number Diff line number Diff line change
Expand Up @@ -83,35 +83,35 @@ static unsigned char quantize_scalar(
}
}

kernel void quantize(device float* code [[buffer(0)]],
device float* A [[buffer(1)]],
device uchar* out [[buffer(2)]],
constant uint& n [[buffer(3)]],
uint id [[thread_position_in_grid]]) {
const uint n_full = (NUM_BLOCK * (n / NUM_BLOCK)) + (n % NUM_BLOCK == 0 ? 0 : NUM_BLOCK);
uint valid_items = (id / NUM_BLOCK + 1 == (n + NUM_BLOCK - 1) / NUM_BLOCK) ? n - (id / NUM_BLOCK * NUM_BLOCK) : NUM_BLOCK;
const uint base_idx = (id / NUM_BLOCK * NUM_BLOCK);
float vals[NUM];
uchar qvals[NUM];
for (uint i = base_idx; i < n_full; i += ((n + NUM_BLOCK - 1) / NUM_BLOCK) * NUM_BLOCK) {
valid_items = n - i > NUM_BLOCK ? NUM_BLOCK : n - i;
threadgroup_barrier(mem_flags::mem_threadgroup);
for (uint j = 0; j < valid_items; j++) {
vals[j] = A[i + j];
}
for (uint j = 0; j < valid_items; j++) {
kernel void quantize(device float* code [[buffer(0)]],
device float* A [[buffer(1)]],
device uchar* out [[buffer(2)]],
constant uint& n [[buffer(3)]],
uint id [[thread_position_in_grid]]) {
const uint n_full = (NUM_BLOCK * (n / NUM_BLOCK)) + (n % NUM_BLOCK == 0 ? 0 : NUM_BLOCK);
uint valid_items = (id / NUM_BLOCK + 1 == (n + NUM_BLOCK - 1) / NUM_BLOCK) ? n - (id / NUM_BLOCK * NUM_BLOCK) : NUM_BLOCK;
const uint base_idx = (id / NUM_BLOCK * NUM_BLOCK);

float vals[NUM];
uchar qvals[NUM];

for (uint i = base_idx; i < n_full; i += ((n + NUM_BLOCK - 1) / NUM_BLOCK) * NUM_BLOCK) {
valid_items = n - i > NUM_BLOCK ? NUM_BLOCK : n - i;

threadgroup_barrier(mem_flags::mem_threadgroup);

for (uint j = 0; j < valid_items; j++) {
vals[j] = A[i + j];
}

for (uint j = 0; j < valid_items; j++) {
qvals[j] = quantize_scalar<false>(0.0f, code, vals[j]);
}
threadgroup_barrier(mem_flags::mem_threadgroup);
for (uint j = 0; j < valid_items; j++) {
out[i + j] = qvals[j];
}
}
}

threadgroup_barrier(mem_flags::mem_threadgroup);

for (uint j = 0; j < valid_items; j++) {
out[i + j] = qvals[j];
}
}
}
12 changes: 6 additions & 6 deletions csrc/mps_ops.mm
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,10 @@

static inline id<MTLDevice> get_device()
{
NSError *error = nil;
NSError *error = nil;
static id<MTLDevice> device = nil;
if(!device) {
device = MTLCreateSystemDefaultDevice();
device = MTLCreateSystemDefaultDevice();
}
if(!device) {
NSLog(@"Failed to get MPS device");
Expand All @@ -30,7 +30,7 @@

static inline id<MTLLibrary> get_library()
{
NSError *error = nil;
NSError *error = nil;
static id<MTLLibrary> library = nil;
if(!library) {
library = [get_device() newLibraryWithURL:[NSURL fileURLWithPath:@"bitsandbytes.metallib"] error:&error];
Expand All @@ -40,7 +40,7 @@
abort();
}
return library;
}
}

/*MPSGraphTensor* dequantize_mps(MPSGraphTensor* code, MPSGraphTensor* A, int n)
{
Expand All @@ -49,7 +49,7 @@
}*/


// MPSGraph function for quantize
// MPSGraph function for quantize
extern "C" MPSGraphTensor* quantize_mps(MPSGraph* graph, MPSGraphTensor* code, MPSGraphTensor* A, int n)
{
id<MTLDevice> device = get_device();
Expand All @@ -64,4 +64,4 @@
}
NSLog(@"Not implemented");
return nil;
}
}

0 comments on commit b7503c9

Please sign in to comment.