Skip to content

Commit

Permalink
update
Browse files Browse the repository at this point in the history
  • Loading branch information
ling0322 committed Jun 23, 2024
1 parent 1b0842a commit 858da23
Show file tree
Hide file tree
Showing 7 changed files with 28 additions and 15 deletions.
2 changes: 2 additions & 0 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@ option(WITH_OPENMP "Build with OpenMP." ON)
option(WITH_CUTLASS "build MatMul operators with CUTLASS." OFF)
option(MKL_PREFIX "Prefix for MKL headers and libraries." "/opt/intel/mkl")

set(CMAKE_MSVC_RUNTIME_LIBRARY "MultiThreaded$<$<CONFIG:Debug>:Debug>")

if(WITH_CUDA)
add_definitions("-DLIBLLM_CUDA_ENABLED")
find_package(CUDAToolkit REQUIRED)
Expand Down
3 changes: 2 additions & 1 deletion go/llm/llm.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,8 @@

package llm

// #cgo LDFLAGS: -ldl
// #cgo linux LDFLAGS: -ldl
// #cgo darwin LDFLAGS: -ldl
// #include <stdlib.h>
// #include "llm_api.h"
import "C"
Expand Down
5 changes: 4 additions & 1 deletion src/libllm/benchmark_main.cc
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
#include "libllm/lut/flags.h"
#include "libllm/lut/random.h"
#include "libllm/lut/time.h"
#include "libllm/operators.h"
#include "libllm/model_for_generation.h"

constexpr int MagicNumber = 0x55aa;
Expand Down Expand Up @@ -165,7 +166,7 @@ void benchmarkLlama(std::shared_ptr<llama::LlamaModel> model, int ctxLength, DTy
}

int benchmarkMain(Device device) {
CHECK(llmInit(LLM_API_VERSION) == LLM_OK);
libllm::initOperators();

LlamaType llamaType = LlamaType::Llama2_7B;
DType weightType = libllm::DType::kQInt4x32;
Expand All @@ -182,6 +183,8 @@ int benchmarkMain(Device device) {
libllm::benchmarkLlama(model, 512, libllm::DType::kQInt4x32);

printf("----------------------------------------------------------\n");

libllm::destroyOperators();
return 0;
}

Expand Down
11 changes: 8 additions & 3 deletions src/libllm/cuda/gemm_cublas.cc
Original file line number Diff line number Diff line change
Expand Up @@ -26,10 +26,11 @@ namespace libllm {
namespace op {
namespace cuda {

std::shared_ptr<Gemm> CublasGemm::create() {
std::shared_ptr<CublasGemm> mm = std::make_shared<CublasGemm>();
Gemm *CublasGemm::create() {
CublasGemm *mm = new CublasGemm();
mm->_handle = {nullptr, safeDestroyCublas};
if (CUBLAS_STATUS_SUCCESS != cublasCreate(mm->_handle.get_pp())) {
delete mm;
return nullptr;
} else {
return mm;
Expand Down Expand Up @@ -138,6 +139,10 @@ lut::ErrorCode CublasGemm::hgemmArray(
} // op
} // ly

std::shared_ptr<libllm::op::cuda::Gemm> llmCreateCudaOpExtGemm() {
libllm::op::cuda::Gemm *llmGemmExt_New() {
return libllm::op::cuda::CublasGemm::create();
}

// C-ABI counterpart to llmGemmExt_New(): destroys a Gemm instance inside
// the module that allocated it. A nullptr argument is a harmless no-op
// (identical to plain delete on nullptr).
void llmGemmExt_Delete(libllm::op::cuda::Gemm *gemm) {
  if (gemm != nullptr) {
    delete gemm;
  }
}
5 changes: 3 additions & 2 deletions src/libllm/cuda/gemm_cublas.h
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ namespace cuda {
/// @brief Operators implemented by cuBLAS.
class CublasGemm : public Gemm {
public:
static std::shared_ptr<Gemm> create();
static Gemm *create();

lut::ErrorCode hgemm(
bool transA,
Expand Down Expand Up @@ -78,5 +78,6 @@ class CublasGemm : public Gemm {
} // ly

extern "C" {
// C ABI of the dynamically loaded cuBLAS GEMM extension; these symbols are
// resolved at runtime via SharedLibrary::getFunc (see matmul.cc). A raw
// pointer plus an explicit deleter keeps ownership transfer ABI-safe across
// the shared-library boundary (no std::shared_ptr in the exported surface).
EXTAPI libllm::op::cuda::Gemm *llmGemmExt_New();
EXTAPI void llmGemmExt_Delete(libllm::op::cuda::Gemm *gemm);
} // extern "C"
8 changes: 5 additions & 3 deletions src/libllm/cuda/matmul.cc
Original file line number Diff line number Diff line change
Expand Up @@ -64,10 +64,12 @@ std::shared_ptr<MatMul> MatMul::createCublas() {

mm->_gemmExtLib = lut::SharedLibrary::open("llmextcublas");

std::function<std::shared_ptr<op::cuda::Gemm>()> factory;
factory = mm->_gemmExtLib->getFunc<std::shared_ptr<op::cuda::Gemm>()>("llmCreateCudaOpExtGemm");
std::function<op::cuda::Gemm *()> factory;
std::function<void(op::cuda::Gemm *)> deleter;
factory = mm->_gemmExtLib->getFunc<op::cuda::Gemm *()>("llmGemmExt_New");
deleter = mm->_gemmExtLib->getFunc<void(op::cuda::Gemm *)>("llmGemmExt_Delete");

mm->_gemm = factory();
mm->_gemm = std::shared_ptr<op::cuda::Gemm>(factory(), deleter);
if (!mm->_gemm) throw lut::AbortedError("unable to create MatMul operator.");

return mm;
Expand Down
9 changes: 4 additions & 5 deletions src/libllm/test_main.cc
Original file line number Diff line number Diff line change
Expand Up @@ -19,20 +19,19 @@

#include "../../third_party/catch2/catch_amalgamated.hpp"
#include "libllm/cpu/kernel/interface.h"
#include "libllm/llm.h"
#include "libllm/operators.h"
#include "libllm/lut/error.h"
#include "libllm/lut/log.h"

// Test entry point: brings up libllm operator backends, runs the Catch2
// session, then tears the backends down.
//
// NOTE(review): the rendered diff interleaved the removed llmInit/llmDestroy
// calls with their initOperators/destroyOperators replacements; only the
// post-commit sequence is kept here, since this commit also dropped the
// libllm/llm.h include that declared llmInit/llmDestroy.
int main(int argc, char **argv) {
  // Initialize operator backends before any test allocates tensors.
  libllm::initOperators();

  // Enable some slow kernels for reference.
  libllm::op::cpu::kernel::setAllowSlowKernel(true);

  int result = Catch::Session().run(argc, argv);

  // Release operator backend resources after the whole session finishes.
  libllm::destroyOperators();

  return result;
}

0 comments on commit 858da23

Please sign in to comment.