ggml: replace conv 1D - 2D stage_0 and stage_1 with im2col and mul_mat #564
Conversation
Added some 'unit' tests:

master:

load_model: ggml tensor size = 320 bytes
load_model: backend buffer size = 0.00 MB
main: compute buffer size: 0.00 MB
ggml_conv2d (64): PASS

PR:

load_model: ggml tensor size = 320 bytes
load_model: backend buffer size = 0.00 MB
GGML_ASSERT: D:\proyectos\ggml\src\ggml.c:6425: ggml_can_mul_mat(a, b)

Incompatible matrix multiplication (the current implementation doesn't support 4D tensors?) 😂

Edit: Fixed:

struct ggml_tensor * ggml_mul_mat(
        struct ggml_context * ctx,
        struct ggml_tensor  * a,
        struct ggml_tensor  * b) {
    bool special_case = (a->ne[0] * a->ne[1] * a->ne[2]) == b->ne[0]; // admit custom multiplication
    GGML_ASSERT(ggml_can_mul_mat(a, b) || special_case);
    GGML_ASSERT(!ggml_is_transposed(a));

    bool is_node = false;

    if (a->grad || b->grad) {
        is_node = true;
    }

    const int64_t ne[4] = {
        special_case ? b->ne[1] : a->ne[1],
        b->ne[special_case ? 2 : 1],
        special_case ? a->ne[3] : b->ne[2],
        b->ne[3] };

    struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, MAX(a->n_dims, b->n_dims), ne);

    result->op     = GGML_OP_MUL_MAT;
    result->grad   = is_node ? ggml_dup_tensor(ctx, result) : NULL;
    result->src[0] = a;
    result->src[1] = b;

    return result;
}

Result:

D:\proyectos\ggml\build\bin\Release>test-conv2d.exe
load_model: ggml tensor size = 320 bytes
load_model: backend buffer size = 0.00 MB
Tensor im2col: [36, 4, 4, 1]
Tensor a: [3, 3, 4, 4]
Tensor mul_mat: [4, 4, 4, 1]
main: compute buffer size: 0.00 MB
Tensor im2col: [36, 4, 4, 1]
Tensor a: [3, 3, 4, 4]
Tensor mul_mat: [4, 4, 4, 1]
ggml_conv2d (64): PASS

Fixed (CPU Backend):

D:\proyectos\ggml\build\bin\Release>test-conv1d.exe
load_model: ggml tensor size = 320 bytes
load_model: backend buffer size = 0.00 MB
main: compute buffer size: 0.00 MB
ggml_conv1d (15): PASS
D:\proyectos\ggml\build\bin\Release>test-conv2d.exe
load_model: ggml tensor size = 320 bytes
load_model: backend buffer size = 0.00 MB
main: compute buffer size: 0.00 MB
ggml_conv2d (64): PASS
Overall, I think this is great - I'm surprised you pulled it off so quickly!
However, there are some things in ggml.c that should be done in a simpler way. See my comments.
@ggerganov @slaren I need some feedback about ggml_mul_mat. In the following multiplication of matrices with the shapes (3,2) x (2,3), to resolve it with ggml_mul_mat I had to pass B transposed:

D:\proyectos\cpp-projects\ggml-test\build\Release>mult-mat-test.exe
Transposed B:
[10.0 5.0
9.0 9.0
5.0 4.0]
load_model: ggml tensor size = 320 bytes
load_model: backend buffer size = 1084 bytes
Performing ggml_gemm test:
60.0 90.0 42.0
55.0 54.0 29.0
50.0 54.0 28.0
gemm_mult: PASSED
main: compute buffer size: 0.0625 KB
Performing ggml_mul_mat test:
105.0 48.0 -123612456.0
78.0 37.0 -54790864.0
-1873442825329845586624512.0 -0.0 0.0
ggml_mul_mat (6): FAILED

Am I making some mistake? Look at the test.
I closed the PR accidentally. That explains many things 😂😂😂😂😂😂😂😂
Conv2D GEMM:

Matrix A: [3, 3, 10, 10]
m: 10, n: 48, k: 90

Where: Result Conv2D (n, m) = [48, 10] -> reshape internally -> (1 * 10, 6 * 8)

Conv2D with ggml_mul_mat (remember (columns, rows) in ggml):

Result Conv2D = Transpose([10, 48]) = [48, 10] (need reshape)
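For reference, here is how those GEMM sizes can be derived from the kernel shape; a small sketch, where the output grid (6 x 8) is an assumption that reproduces n = 48, since the input size isn't stated here:

#include <stdio.h>

int main(void) {
    int KW = 3, KH = 3, IC = 10, OC = 10;  // kernel a: [3, 3, 10, 10]
    int OW = 6, OH = 8, N = 1;             // assumed output grid (6 * 8 = 48)
    printf("m = OC       = %d\n", OC);             // 10
    printf("k = IC*KH*KW = %d\n", IC * KH * KW);   // 90
    printf("n = N*OH*OW  = %d\n", N * OH * OW);    // 48
    // The (n, m) = [48, 10] GEMM result is then reshaped as described above.
    return 0;
}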
Now Conv2D is this. To reuse it, we need to implement F16 x F16 = F32 in ggml_mul_mat:

struct ggml_tensor* im2col_res = ggml_im2col(ctx0, model.a, model.b, s0, s1, p0, p1, d0, d1, true); // f32, because ggml_mul_mat doesn't support fp16 source -> fp32 output
struct ggml_tensor* conv2d_res = ggml_reshape_4d(ctx0,
    ggml_cont(ctx0, ggml_transpose(ctx0,
        ggml_mul_mat(ctx0,
            ggml_cont(ctx0, ggml_reshape_2d(ctx0, model.a, (model.a->ne[0] * model.a->ne[1] * model.a->ne[2]), model.a->ne[3])),
            ggml_cont(ctx0, ggml_reshape_2d(ctx0, im2col_res, im2col_res->ne[0], (im2col_res->ne[3] * im2col_res->ne[2] * im2col_res->ne[1])))))),
    im2col_res->ne[1], im2col_res->ne[2], model.a->ne[3], im2col_res->ne[3]);
Error in the CPU backend:

load_model: ggml tensor size = 320 bytes
load_model: backend buffer size = 0.00 MB
main: compute buffer size: 0.02 MB
GGML_ASSERT: D:\proyectos\ggml\src\ggml.c:11730: nb10 == sizeof(float) // fixed

GPU backend:

load_model: ggml tensor size = 320 bytes
load_model: backend buffer size = 0.00 MB
load_model: using CUDA backend
ggml_init_cublas: found 1 CUDA devices:
Device 0: NVIDIA GeForce RTX 3050 Laptop GPU, compute capability 8.6
main: compute buffer size: 0.02 MB
ggml_cuda_cpy: unsupported type combination (f16 to f16)
GGML_ASSERT: D:\proyectos\ggml\src\ggml-cuda.cu:7168: false
@slaren Does the CUDA implementation of ggml_mul_mat support F16 x F16? I tried to fix it with this cublasGemmEx call:

cublasGemmEx(g_cublas_handles[id], CUBLAS_OP_T, CUBLAS_OP_N,
        row_diff, src1_ncols, ne10,
        &alpha_f16, src0_ptr, CUDA_R_16F, ne00,
                    src1_ptr, CUDA_R_16F, ne10,
        &beta_f16,   dst_f16, CUDA_R_16F, ldc,
        CUBLAS_COMPUTE_16F,
        CUBLAS_GEMM_DEFAULT_TENSOR_OP));
It should work, but maybe there is a bug with fp16.
To avoid this assertion:

ggml_cuda_cpy: unsupported type combination (f16 to f16)
GGML_ASSERT: D:\proyectos\ggml\src\ggml-cuda.cu:7168: false

I added:

else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F16) {
    ggml_cpy_f16_f16_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, nb00, nb01, nb02,
                          ne10, ne11, nb10, nb11, nb12, main_stream);
}

static __device__ void cpy_1_f16_f16(const char * cxi, char * cdsti) {
    const half * xi = (const half *) cxi;
    half * dsti = (half *) cdsti;

    *dsti = *xi;
}

static void ggml_cpy_f16_f16_cuda(
    const char * cx, char * cdst, const int ne,
    const int ne00, const int ne01, const int nb00, const int nb01, const int nb02,
    const int ne10, const int ne11, const int nb10, const int nb11, const int nb12, cudaStream_t stream) {

    const int num_blocks = (ne + CUDA_CPY_BLOCK_SIZE - 1) / CUDA_CPY_BLOCK_SIZE;
    cpy_f32_f16<cpy_1_f16_f16><<<num_blocks, CUDA_CPY_BLOCK_SIZE, 0, stream>>>
        (cx, cdst, ne, ne00, ne01, nb00, nb01, nb02, ne10, ne11, nb10, nb11, nb12);
}

I! cuBLAS (v12.0) function cublasStatus_t __cdecl cublasSetStream_v2(struct cublasContext *,struct CUstream_st *) called:
i! handle: type=cublasHandle_t; val=POINTER (IN HEX:0x0000015CD6DC1730)
i! streamId: type=SOME TYPE; val=POINTER (IN HEX:0x0000015CBBBD36B0)
i! Time: 2023-10-14T19:38:20 elapsed from start 0.016667 minutes or 1.000000 seconds
i!Process=12976; Thread=18488; GPU=0; Handle=POINTER (IN HEX:0x0000015CD6DC1730); StreamId=POINTER (IN HEX:0x0000000000000000) (defaultStream); MathMode=CUBLAS_TF32_TENSOR_OP_MATH
i! COMPILED WITH: Microsoft Visual Studio / 192628806.0
I! cuBLAS (v12.0) function cublasStatus_t __cdecl cublasGemmEx(struct cublasContext *,cublasOperation_t,cublasOperation_t,int,int,int,const void *,const void *,enum cudaDataType_t,int,const void *,enum cudaDataType_t,int,const void *,void *,enum cudaDataType_t,int,cublasComputeType_t,cublasGemmAlgo_t) called:
i! handle: type=cublasHandle_t; val=POINTER (IN HEX:0x0000015CD6DC1730)
i! transa: type=cublasOperation_t; val=CUBLAS_OP_T(1)
i! transb: type=cublasOperation_t; val=CUBLAS_OP_N(0)
i! m: type=int; val=10
i! n: type=int; val=8
i! k: type=int; val=30
i! alpha: type=void; val=POINTER (IN HEX:0x0000000F12CFEE98)
i! A: type=void; val=POINTER (IN HEX:0x000000071A820E00)
i! Atype: type=cudaDataType_t; val=CUDA_R_16F(2)
i! lda: type=int; val=30
i! B: type=void; val=NULL_PTR(POINTER (IN HEX:0x0000000000000000))
i! Btype: type=cudaDataType_t; val=CUDA_R_16F(2)
i! ldb: type=int; val=30
i! beta: type=void; val=POINTER (IN HEX:0x0000000F12CFEF28)
i! C: type=void; val=POINTER (IN HEX:0x000000071A821600)
i! Ctype: type=cudaDataType_t; val=CUDA_R_16F(2)
i! ldc: type=int; val=10
i! computeType: type=cublasComputeType_t; val=CUBLAS_COMPUTE_16F(64)
i! algo: type=SOME TYPE; val=CUBLAS_GEMM_DEFAULT_TENSOR_OP(99)
i! Time: 2023-10-14T19:38:20 elapsed from start 0.016667 minutes or 1.000000 seconds
i!Process=12976; Thread=18488; GPU=0; Handle=POINTER (IN HEX:0x0000015CD6DC1730); StreamId=POINTER (IN HEX:0x0000015CBBBD36B0); MathMode=CUBLAS_TF32_TENSOR_OP_MATH
i! COMPILED WITH: Microsoft Visual Studio / 192628806.0
cuBLAS error 7 at D:\proyectos\ggml\src\ggml-cuda.cu:6337: an unsupported value or parameter was passed to the function
current device: 0

The ggml_mul_mat in the CPU backend works with the same matrices.
Try this change:

diff --git a/ggml-cuda.cu b/ggml-cuda.cu
index 654d363..96e63e0 100644
--- a/ggml-cuda.cu
+++ b/ggml-cuda.cu
@@ -6295,7 +6295,7 @@ inline void ggml_cuda_op_mul_mat_cublas(
src1_as_f16 = (half *) ggml_cuda_pool_malloc(ne * sizeof(half), &src1_as);
to_fp16_cuda(src1_ddf_i, src1_as_f16, ne, stream);
}
- const half * src1_ptr = src1->type == GGML_TYPE_F16 ? (const half *) src1_ddq_i : src1_as_f16;
+ const half * src1_ptr = src1->type == GGML_TYPE_F16 ? (const half *) src1_ddf_i : src1_as_f16;
size_t dst_as = 0;
half * dst_f16 = (half *) ggml_cuda_pool_malloc(row_diff*src1_ncols * sizeof(half), &dst_as);
As a final opinion on the topic, I would like to hear from people who actually need these features whether the current way of returning the result is suitable, or the one you propose, as long as they don't have to use ggml_cont and ggml_transpose themselves. If they did have to do that when working with the regular version of ggml_conv, then the ideal scenario would be to return the matrix without transposing it.
In ggml the following identity holds:

z0 = ggml_mul_mat(ctx, x, y);
z1 = ggml_cont(ctx, ggml_transpose(ctx, ggml_mul_mat(ctx, y, x)));

z0 == z1
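A minimal standalone check of this identity (a sketch, not part of the PR's tests; it assumes the CPU-only ggml graph API available at the time, and the matrix contents are arbitrary test values):

#include "ggml.h"
#include <stdio.h>

int main(void) {
    struct ggml_init_params params = { /*.mem_size =*/ 16*1024*1024, /*.mem_buffer =*/ NULL, /*.no_alloc =*/ false };
    struct ggml_context * ctx = ggml_init(params);

    // ggml_mul_mat requires both sources to share ne[0]
    struct ggml_tensor * x = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 4, 3); // ne = [4, 3]
    struct ggml_tensor * y = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 4, 2); // ne = [4, 2]
    for (int i = 0; i < ggml_nelements(x); ++i) ggml_set_f32_1d(x, i, (float) (i + 1));
    for (int i = 0; i < ggml_nelements(y); ++i) ggml_set_f32_1d(y, i, (float) (2*i - 3));

    struct ggml_tensor * z0 = ggml_mul_mat(ctx, x, y);                                      // ne = [3, 2]
    struct ggml_tensor * z1 = ggml_cont(ctx, ggml_transpose(ctx, ggml_mul_mat(ctx, y, x))); // ne = [3, 2]

    struct ggml_cgraph gf = ggml_build_forward(z0);
    ggml_build_forward_expand(&gf, z1);
    ggml_graph_compute_with_ctx(ctx, &gf, /*n_threads =*/ 1);

    for (int i = 0; i < ggml_nelements(z0); ++i) {
        printf("%8.1f %8.1f\n", ggml_get_f32_1d(z0, i), ggml_get_f32_1d(z1, i)); // pairs should match
    }

    ggml_free(ctx);
    return 0;
}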
@FSSRepo I think you can leave the branch as it is and I will try to fix it |
I'm going to review the whisper issue with this PR for now, so I won't be adding any more commits to this branch.
@ggerganov I don't get any error with the latest commit.

And so my fellow Americans, ask not what your country can do for you, ask what you can do for your country. And so my fellow Americans, ask not what your country can do for you, ask what you can do for your country.
The computation was failing when using BLAS - on the CPU it supports only F32. I fixed that in 53f805e.

I now applied the identity transformation that I told you about earlier:

diff --git a/src/ggml.c b/src/ggml.c
index 927f03a..b33a741 100644
--- a/src/ggml.c
+++ b/src/ggml.c
@@ -7482,10 +7482,10 @@ GGML_API struct ggml_tensor * ggml_conv_1d(
int p0,
int d0) {
struct ggml_tensor * result = ggml_im2col(ctx, a, b, s0, 0, p0, 0, d0, 0, false); // [N, OL, IC * K]
- result = ggml_reshape_3d(ctx, ggml_cont(ctx, ggml_transpose(ctx,
+ result = ggml_reshape_3d(ctx,
ggml_mul_mat(ctx,
- ggml_reshape_2d(ctx, a, (a->ne[0] * a->ne[1]), a->ne[2]), // [OC,IC, K] => [OC, IC * K]
- ggml_reshape_2d(ctx, result, result->ne[0], (result->ne[2] * result->ne[1]))))), // [N, OL, IC * K] => [N*OL, IC * K]
+ ggml_reshape_2d(ctx, result, result->ne[0], (result->ne[2] * result->ne[1])), // [N, OL, IC * K] => [N*OL, IC * K]
+ ggml_reshape_2d(ctx, a, (a->ne[0] * a->ne[1]), a->ne[2])), // [OC,IC, K] => [OC, IC * K]
result->ne[1], a->ne[2], result->ne[2]); // [N, OC, OL]
return result;
}

This again produces correct results.
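For anyone following the shapes, a small sketch (not from the PR) that traces the rewritten ggml_conv_1d above; the concrete sizes are made-up examples, and the symbols follow the diff's comments (K = kernel size, IC/OC = in/out channels, OL = output length, N = batch):

#include <stdio.h>

int main(void) {
    int K = 3, IC = 4, OC = 8, OL = 14, N = 1;   // example sizes (assumptions)
    // im2col output:             [IC*K, OL, N]   (ggml order: ne[0] = columns)
    // reshape_2d(im2col):        [IC*K, OL*N]
    // reshape_2d(a):             [IC*K, OC]
    // ggml_mul_mat(im2col, a):   [OL*N, OC]      (result ne = {src0->ne[1], src1->ne[1]})
    // reshape_3d(..., OL, OC, N): [OL, OC, N] == [N, OC, OL] in row-major notation
    printf("mul_mat result: [%d, %d]\n", OL * N, OC);
    printf("conv1d result : [%d, %d, %d]\n", OL, OC, N);
    return 0;
}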
In Conv2D you can use:

ggml_reshape_2d(ctx, a, result->ne[0], a->ne[3])), // [OC,IC, KH, KW] => [OC, IC * KH * KW]

which is the same as:

ggml_reshape_2d(ctx, a, (a->ne[0] * a->ne[1] * a->ne[2]), a->ne[3])), // [OC,IC, KH, KW] => [OC, IC * KH * KW]

And in Conv1D too.
The identity that you refer to is just swapping the matrix operands? 😂 Good, I didn't get that.
Will get back to this PR soon, sorry for delay |
Don't worry. Next week I will add CUDA kernels for these ops. Today I made the kernels and will do some tests in stable diffusion.
Co-authored-by: slaren <[email protected]>
Fixes #559.

This pull request aims to add a ggml_im2col operator, which I have no idea what it is (I'm serious). I only know that it converts images into columns, eliminating stage 0 of conv1d and conv2d. It also refactors ggml_mul_mat to support float16 source data types, presumably to reduce memory usage, and removes stage 1 of conv1d and conv2d. Additionally, I want to implement CUDA kernels for these operations.

Conv1D = im2col (1D) + Matrix Multiplication
Conv2D = im2col (2D) + Matrix Multiplication
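As a rough sketch of this decomposition (not part of the PR; the sizes below are made-up examples and conv_out is just the standard convolution output-size formula):

#include <stdio.h>

static int conv_out(int in, int k, int s, int p, int d) {
    // standard output size: floor((in + 2p - d*(k-1) - 1)/s) + 1
    return (in + 2*p - d*(k - 1) - 1) / s + 1;
}

int main(void) {
    int IW = 10, IH = 10, IC = 3, OC = 8;    // example input size / channels (assumptions)
    int KW = 3, KH = 3, s = 1, p = 0, d = 1; // example kernel / stride / padding / dilation
    int OW = conv_out(IW, KW, s, p, d);
    int OH = conv_out(IH, KH, s, p, d);
    // im2col gathers IC*KH*KW input values for each of the OW*OH output positions,
    // the kernel is flattened to [OC, IC*KH*KW], and a single matrix multiplication
    // then produces OC values per output position.
    printf("im2col : %d x %d\n", IC*KH*KW, OW*OH);
    printf("gemm   : [%d x %d] * [%d x %d] = [%d x %d]\n",
           OC, IC*KH*KW, IC*KH*KW, OW*OH, OC, OW*OH);
    return 0;
}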
Here's the plan to finish this PR. I'll be working on this as soon as possible, since I won't have more time due to my studies:

- Remove ggml_conv_1d_stage_0 and ggml_conv_1d_stage_1.
- Remove ggml_conv_2d_stage_0 and ggml_conv_2d_stage_1.
- ggml_im2col for 1D and 2D: I am going to investigate the difference in order to refactor the code better and avoid duplication.
- ggml_compute_forward_mul_mat_f16_f32: support float16 matrix multiplication; this reduces memory.
- ggml_mul_mat, and remove gemm_fp16_out_fp32. I'm testing the differences.

Suggestions and feedback are welcome.