Merge branch 'xsn/vision' of https://github.com/ngxson/llama.cpp into wirthual/fix-vision
wirthual committed Oct 9, 2024
2 parents 3ca3898 + a88c0d5 commit c430c21
Showing 12 changed files with 311 additions and 133 deletions.
21 changes: 21 additions & 0 deletions common/common.cpp
@@ -1474,6 +1474,27 @@ std::vector<llama_token> llama_tokenize(
return result;
}

// TODO: this function is hacky and needs to be improved
std::vector<llama_token> llama_tokenize_with_img(
const struct llama_context * ctx,
const std::string & text,
bool add_special,
bool parse_special) {
static const std::string IMG_PLACEMENT = "<img_placement>";
std::vector<std::string> parts = string_split(text, IMG_PLACEMENT);
std::vector<llama_token> output;
for (const auto & part : parts) {
bool add_bos = &parts.front() == &part;
auto tokens = llama_tokenize(ctx, part, add_special && add_bos, parse_special);
output.insert(output.end(), tokens.begin(), tokens.end());
if (&parts.back() != &part) {
// add an image token between two consecutive parts
output.push_back(TOKEN_IMG_PLACEMENT);
}
}
return output;
}

std::string llama_token_to_piece(const struct llama_context * ctx, llama_token token, bool special) {
std::string piece;
piece.resize(piece.capacity()); // using string internal cache, 15 bytes + '\n'
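For reference, the split-and-interleave behaviour of llama_tokenize_with_img is easy to reproduce in isolation. Below is a minimal standalone sketch of the same logic, with a toy word-based tokenizer standing in for llama_tokenize; the helper names, the sentinel constant, and the printed ids are illustrative only, not part of the patch.

// toy_tokenize_with_img.cpp: illustrative sketch; a toy tokenizer stands in for llama_tokenize
#include <cstdio>
#include <sstream>
#include <string>
#include <vector>

static const int TOY_IMG_SENTINEL = -1000;          // mirrors TOKEN_IMG_PLACEMENT

// split on a string delimiter, same idea as the new string_split overload in common.h
static std::vector<std::string> split_on(std::string s, const std::string & delim) {
    std::vector<std::string> parts;
    size_t pos;
    while ((pos = s.find(delim)) != std::string::npos) {
        parts.push_back(s.substr(0, pos));
        s.erase(0, pos + delim.size());
    }
    parts.push_back(s);
    return parts;
}

// toy "tokenizer": one arbitrary id per whitespace-separated word, just for illustration
static std::vector<int> toy_tokenize(const std::string & text) {
    std::vector<int> ids;
    std::istringstream iss(text);
    std::string w;
    while (iss >> w) {
        ids.push_back((int) w.size());
    }
    return ids;
}

int main() {
    std::string prompt = "USER:<img_placement>\nwhat did you see?\nASSISTANT:";
    std::vector<std::string> parts = split_on(prompt, "<img_placement>");
    std::vector<int> out;
    for (size_t i = 0; i < parts.size(); i++) {
        std::vector<int> toks = toy_tokenize(parts[i]);
        out.insert(out.end(), toks.begin(), toks.end());
        if (i + 1 < parts.size()) {
            out.push_back(TOY_IMG_SENTINEL); // one sentinel between every pair of parts
        }
    }
    for (int id : out) {
        printf("%d ", id);                   // the sentinel marks where image embeddings go
    }
    printf("\n");
    return 0;
}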
25 changes: 25 additions & 0 deletions common/common.h
@@ -378,6 +378,20 @@ static std::vector<T> string_split(const std::string & str, char delim) {
return values;
}

// split string by a `std::string delim` instead of `char delim`
static std::vector<std::string> string_split(std::string s, const std::string & delimiter) {
std::vector<std::string> tokens;
size_t pos = 0;
std::string token;
while ((pos = s.find(delimiter)) != std::string::npos) {
token = s.substr(0, pos);
tokens.push_back(token);
s.erase(0, pos + delimiter.length());
}
tokens.push_back(s);
return tokens;
}

bool string_parse_kv_override(const char * data, std::vector<llama_model_kv_override> & overrides);
void string_process_escapes(std::string & input);

@@ -447,6 +461,17 @@ std::vector<llama_token> llama_tokenize(
bool add_special,
bool parse_special = false);

const llama_token TOKEN_IMG_PLACEMENT = -1000;

// tokenize with "placeholder" for image embedding tokens
// "<img_placement>" will be replaced with TOKEN_IMG_PLACEMENT
// TODO: this function is hacky and needs to be improved
std::vector<llama_token> llama_tokenize_with_img(
const struct llama_context * ctx,
const std::string & text,
bool add_special,
bool parse_special = false);

// tokenizes a token into a piece, optionally renders special/control tokens
// should work similar to Python's `tokenizer.id_to_piece`
std::string llama_token_to_piece(
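A quick check of how the new string-delimiter overload of string_split behaves at the edges: empty leading and trailing parts are kept, and text without the delimiter comes back as a single element. The sketch below is standalone, with the overload copied verbatim from the diff; the file name and the "<img>" tag are only for illustration.

// string_split_demo.cpp: standalone copy of the new overload, for illustration
#include <cstdio>
#include <string>
#include <vector>

static std::vector<std::string> string_split(std::string s, const std::string & delimiter) {
    std::vector<std::string> tokens;
    size_t pos = 0;
    while ((pos = s.find(delimiter)) != std::string::npos) {
        tokens.push_back(s.substr(0, pos));
        s.erase(0, pos + delimiter.length());
    }
    tokens.push_back(s);
    return tokens;
}

int main() {
    // "<img>hello<img>" -> "", "hello", "" : empty parts at the ends are preserved,
    // which is what lets llama_tokenize_with_img insert a sentinel even when the
    // image tag starts or ends the prompt.
    for (const auto & part : string_split("<img>hello<img>", "<img>")) {
        printf("[%s]\n", part.c_str());
    }
    // no delimiter present -> the whole string comes back as one part
    for (const auto & part : string_split("no images here", "<img>")) {
        printf("[%s]\n", part.c_str());
    }
    return 0;
}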
2 changes: 1 addition & 1 deletion common/vision.cpp
@@ -31,7 +31,7 @@ llama_img * load_image_from_file(const char * fname) {
// printf("\n");
// }
// printf("\n");
llama_img * result = llama_img_alloc(nx, ny);
llama_img * result = llama_img_init(nx, ny);
memcpy(result->data, img, nx*ny*3);
stbi_image_free(img);
return result;
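For context, load_image_from_file decodes the file with stb_image into tightly packed RGB before copying nx*ny*3 bytes into the llama_img buffer. A minimal standalone sketch of just that decode step, assuming stb_image.h is available as it is in the llama.cpp tree; the buffer handling here is illustrative and not the llama_img API.

// decode_rgb_sketch.cpp: illustrative; assumes stb_image.h is on the include path
#define STB_IMAGE_IMPLEMENTATION
#include "stb_image.h"
#include <cstdio>
#include <vector>

int main(int argc, char ** argv) {
    if (argc < 2) { fprintf(stderr, "usage: %s <image>\n", argv[0]); return 1; }
    int nx, ny, nc;
    // request 3 channels so the buffer is always packed RGB, nx*ny*3 bytes total
    unsigned char * img = stbi_load(argv[1], &nx, &ny, &nc, 3);
    if (!img) { fprintf(stderr, "failed to load %s\n", argv[1]); return 1; }
    std::vector<unsigned char> rgb(img, img + (size_t) nx * ny * 3); // what gets memcpy'd into llama_img
    stbi_image_free(img);
    printf("%d x %d, %zu bytes of RGB\n", nx, ny, rgb.size());
    return 0;
}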
6 changes: 4 additions & 2 deletions convert_hf_to_gguf.py
@@ -471,7 +471,7 @@ def load_hparams(dir_model: Path):
text_config = AutoConfig.from_pretrained(text_config["_name_or_path"]).to_dict()
hparams = {**text_config, **hparams}
return hparams

@staticmethod
def load_preprocessor_config(dir_model: Path):
file_path = dir_model / "preprocessor_config.json"
@@ -1590,7 +1590,7 @@ def set_gguf_parameters(self):

# For vision model
if self.vparams is not None and self.preprocessor_config is not None:
self.gguf_writer.add_vision_type("clip")
self.gguf_writer.add_vision_type("clip-vit")
self.gguf_writer.add_vision_image_size(self.vparams["image_size"])
self.gguf_writer.add_vision_patch_size(self.vparams["patch_size"])
self.gguf_writer.add_vision_clip_architecture("llava")
@@ -1600,6 +1600,8 @@ def set_gguf_parameters(self):
self.gguf_writer.add_vision_clip_head_count(self.vparams["num_attention_heads"])
self.gguf_writer.add_vision_clip_image_mean(self.preprocessor_config["image_mean"])
self.gguf_writer.add_vision_clip_image_std(self.preprocessor_config["image_std"])
self.gguf_writer.add_vision_clip_select_layer(self.hparams["vision_feature_layer"])
self.gguf_writer.add_vision_clip_patch_merge_type(gguf.CLIPPatchMergeType.FLAT)
max_pos_embd = (self.vparams["image_size"] // self.vparams["patch_size"])**2 + 1
self.gguf_writer.add_vision_clip_max_position_embeddings(max_pos_embd)
# TODO: should not hardcode these, but they are currently missing from config.json
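The max_pos_embd expression above is just the patch grid plus the CLS token. As a concrete check, with the CLIP ViT-L/14-336 settings commonly used by LLaVA (image_size 336 and patch_size 14, assumed here rather than read from this diff), that gives 24*24 + 1 = 577 positions, i.e. 576 image patches plus CLS; 576 is also the patch count that appears hard-coded in the simple.cpp example below. A small sketch of the arithmetic:

// max_pos_embd_check.cpp: arithmetic sketch for the position-embedding count
#include <cstdio>

int main() {
    const int image_size   = 336;                      // assumed CLIP ViT-L/14-336 values (LLaVA 1.5)
    const int patch_size   = 14;
    const int per_side     = image_size / patch_size;  // 24
    const int n_patches    = per_side * per_side;      // 576 image tokens
    const int max_pos_embd = n_patches + 1;            // +1 for the CLS token -> 577
    printf("patches per side: %d, patches: %d, max position embeddings: %d\n",
           per_side, n_patches, max_pos_embd);
    return 0;
}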
111 changes: 60 additions & 51 deletions examples/simple/simple.cpp
@@ -15,7 +15,9 @@ static void print_usage(int, char ** argv) {
int main(int argc, char ** argv) {
gpt_params params;

params.prompt = "Hello my name is";
//params.prompt = "Hello my name is";
params.prompt = "A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions.\n"
"USER:<img_placement>\nwhat did you see?\nASSISTANT:";
params.n_predict = 32;

if (!gpt_params_parse(argc, argv, params, LLAMA_EXAMPLE_COMMON, print_usage)) {
@@ -62,52 +64,10 @@ int main(int argc, char ** argv) {

llama_sampler_chain_add(smpl, llama_sampler_init_greedy());




// TODO: this is for testing; DELETE ME
int n_cur = 0;
params.prompt = "A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions.\nUSER:";
{
llama_img_batch ibatch;
ibatch.n_imgs = 1;
ibatch.imgs = (llama_img **) malloc(1024);
ibatch.imgs[0] = load_image_from_file("../models/eiffel-tower-3349075_1280.jpg");
llama_vision_encode(ctx, &ibatch);

auto tokens = ::llama_tokenize(ctx, params.prompt, true);
int n_imgs = ibatch.n_imgs;
int n_embd = llama_n_embd(model);
int n_patches = llama_vision_n_patches(ctx);
printf("n_embd = %d ; n_patches = %d \n", n_embd, n_patches);
float * output_img = llama_vision_get_embeddings(ctx, 0);

n_cur += tokens.size();
llama_batch batch = llama_batch_init(512, 0, 1);
llama_batch_clear(batch);
for (auto t : tokens) { llama_batch_add(batch, t, n_cur, { 0 }, false); n_cur++; }
if (llama_decode(ctx, batch) != 0) {
LOG("%s: llama_decode() failed\n", __func__);
return 1;
}

// for (int k = 0; k < 10; k++) printf("%f\n", output_img[k]);
llama_batch_clear(batch);
batch = {int32_t(n_patches*n_imgs), nullptr, output_img, nullptr, nullptr, nullptr, nullptr, n_cur, 1, 0, };
if (llama_decode(ctx, batch) != 0) {
LOG("%s: llama_decode() failed\n", __func__);
return 1;
}
n_cur += n_embd*n_imgs;
}
params.prompt = "\nwhat did you see?\nASSISTANT:";



// tokenize the prompt

std::vector<llama_token> tokens_list;
tokens_list = ::llama_tokenize(ctx, params.prompt, true);
tokens_list = ::llama_tokenize_with_img(ctx, params.prompt, true);

const int n_ctx = llama_n_ctx(ctx);
const int n_kv_req = tokens_list.size() + (n_predict - tokens_list.size());
@@ -127,33 +87,82 @@ int main(int argc, char ** argv) {
LOG("\n");

for (auto id : tokens_list) {
LOG("%s", llama_token_to_piece(ctx, id).c_str());
if (id == TOKEN_IMG_PLACEMENT) {
LOG("<img_placement>");
} else {
LOG("%s", llama_token_to_piece(ctx, id).c_str());
}
}

LOG("\n\n");

// load image
llama_batch_img img_batch = llama_batch_img_init(1);
img_batch.imgs[0] = load_image_from_file("../models/eiffel-tower-3349075_1280.jpg");

// create a llama_batch with size 512
// we use this object to submit token data for decoding

llama_batch batch = llama_batch_init(512, 0, 1);

// evaluate the initial prompt
for (size_t i = 0; i < tokens_list.size(); i++) {
//llama_batch_add(batch, tokens_list[i], i, { 0 }, false);
if (i == 0) continue;
llama_batch_add(batch, tokens_list[i], n_cur, { 0 }, false);
n_cur++;
int n_cur = 0;
int i_img = 0;
for (auto id : tokens_list) {
if (id == TOKEN_IMG_PLACEMENT) {
img_batch.pos[i_img] = n_cur;
n_cur += llama_img_n_tokens(ctx, img_batch.imgs[i_img]);
i_img++;
} else {
llama_batch_add(batch, id, n_cur, { 0 }, false);
printf("pos %d tok %d --> %s\n", n_cur, id, llama_token_to_piece(ctx, id).c_str());
n_cur++;
}
}

// llama_decode will output logits only for the last token of the prompt
batch.logits[batch.n_tokens - 1] = true;

if (llama_encode_vision(ctx, img_batch) != 0) {
LOG("%s: llama_encode_vision() failed\n", __func__);
return 1;
}

n_cur = 0;
{
auto t1 = ::llama_tokenize(ctx, "A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions.\nUSER:", false);
auto t2 = ::llama_tokenize(ctx, "\nwhat did you see?\nASSISTANT:", false);
t1.insert(t1.begin(), 1);

n_cur = 0;
llama_batch_clear(batch);
llama_batch_add(batch, 1, 0, { 0 }, false);
llama_decode(ctx, batch);

n_cur = t1.size();
llama_batch_clear(batch);
llama_batch batch0 = {int32_t(576), nullptr, _test_get_img_embd(ctx), nullptr, nullptr, nullptr, nullptr, n_cur, 1, 0, };
llama_decode(ctx, batch0);

n_cur = 0;
llama_batch_clear(batch);
for (auto t : t1) { llama_batch_add(batch, t, n_cur, { 0 }, false); n_cur++; }
llama_decode(ctx, batch);

n_cur = t1.size() + 576;
llama_batch_clear(batch);
printf("pos %d\n", n_cur);
for (auto t : t2) { llama_batch_add(batch, t, n_cur, { 0 }, false); n_cur++; }
batch.logits[batch.n_tokens - 1] = true;
}

if (llama_decode(ctx, batch) != 0) {
LOG("%s: llama_decode() failed\n", __func__);
return 1;
}

// main loop

//int n_cur = batch.n_tokens;
int n_decode = 0;

const auto t_main_start = ggml_time_us();
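The new prompt loop above interleaves text positions with a block of image positions: each text token takes one position, and when the placeholder is hit the cursor jumps ahead by the number of image embedding tokens before the remaining text continues. A standalone sketch of that bookkeeping follows, with a fixed patch count standing in for llama_img_n_tokens; the 576 value and the token ids are assumptions for illustration only.

// position_bookkeeping_sketch.cpp: illustrative only; no llama API calls
#include <cstdio>
#include <vector>

int main() {
    const int IMG_SENTINEL = -1000; // mirrors TOKEN_IMG_PLACEMENT
    const int n_img_tokens = 576;   // e.g. a 24x24 patch grid; real code would ask llama_img_n_tokens()

    // pretend output of llama_tokenize_with_img: arbitrary text ids with one image sentinel
    std::vector<int> tokens = { 1, 42, 43, 44, IMG_SENTINEL, 45, 46, 47 };

    int n_cur = 0;
    for (int id : tokens) {
        if (id == IMG_SENTINEL) {
            printf("image occupies positions [%d, %d)\n", n_cur, n_cur + n_img_tokens);
            n_cur += n_img_tokens;   // image embeddings fill this whole range
        } else {
            printf("pos %4d  text token %d\n", n_cur, id);
            n_cur++;                 // one position per text token
        }
    }
    printf("next free position: %d\n", n_cur);
    return 0;
}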
21 changes: 14 additions & 7 deletions gguf-py/gguf/constants.py
@@ -173,13 +173,15 @@ class Tokenizer:
MIDDLE_ID = "tokenizer.ggml.middle_token_id"
EOT_ID = "tokenizer.ggml.eot_token_id"
EOM_ID = "tokenizer.ggml.eom_token_id"
IMAGE_START_ID = "tokenizer.ggml.image_start_token_id"
IMAGE_END_ID = "tokenizer.ggml.image_end_token_id"

class Adapter:
TYPE = "adapter.type"
LORA_ALPHA = "adapter.lora.alpha"

class Vision:
# only support vision.type = "clip" for now
# only support vision.type = "clip-vit" for now
TYPE = "vision.type"
IMAGE_SIZE = "vision.image_size"
PATCH_SIZE = "vision.patch_size"
@@ -196,7 +198,10 @@ class Clip:
PROJECTION_DIM = "vision.clip.projection_dim"
USE_GELU = "vision.clip.use_gelu"
MAX_POS_EMBEDDING = "vision.clip.max_position_embeddings"
MAX_SLICES = "vision.clip.max_slices"
PROJECTOR_TYPE = "vision.clip.projector_type"
SELECT_LAYER = "vision.clip.select_layer"
PATCH_MERGE_TYPE = "vision.clip.patch_merge_type"
HEAD_COUNT = "vision.clip.attention.head_count"
LAYERNORM_EPS = "vision.clip.attention.layer_norm_epsilon"

@@ -370,8 +375,7 @@ class MODEL_TENSOR(IntEnum):
ENC_FFN_UP = auto()
ENC_OUTPUT_NORM = auto()
# vision
V_MMPROJ_A = auto()
V_MMPROJ_B = auto()
V_MMPROJ = auto()
V_ENC_EMBD_CLS = auto()
V_ENC_EMBD_PATCH = auto()
V_ENC_EMBD_POS = auto()
@@ -547,8 +551,7 @@ class MODEL_TENSOR(IntEnum):
MODEL_TENSOR.ENC_FFN_UP: "enc.blk.{bid}.ffn_up",
MODEL_TENSOR.ENC_OUTPUT_NORM: "enc.output_norm",
# vision
MODEL_TENSOR.V_MMPROJ_A: "v.mmproj_a",
MODEL_TENSOR.V_MMPROJ_B: "v.mmproj_b",
MODEL_TENSOR.V_MMPROJ: "v.mmproj_{bid}",
MODEL_TENSOR.V_ENC_EMBD_CLS: "v.enc.embd.cls",
MODEL_TENSOR.V_ENC_EMBD_PATCH: "v.enc.embd.patch",
MODEL_TENSOR.V_ENC_EMBD_POS: "v.enc.embd.pos",
@@ -1338,8 +1341,7 @@ class MODEL_TENSOR(IntEnum):
MODEL_TENSOR.FFN_UP,
],
MODEL_ARCH.LLAVA_VISION: [
MODEL_TENSOR.V_MMPROJ_A,
MODEL_TENSOR.V_MMPROJ_B,
MODEL_TENSOR.V_MMPROJ,
MODEL_TENSOR.V_ENC_EMBD_CLS,
MODEL_TENSOR.V_ENC_EMBD_PATCH,
MODEL_TENSOR.V_ENC_EMBD_POS,
@@ -1430,6 +1432,11 @@ class CLIPProjectorType(Enum):
MLP = 'mlp'


class CLIPPatchMergeType(Enum):
FLAT = 'flat'
SPATIAL_UNPAD = 'spatial_unpad'


class GGMLQuantizationType(IntEnum):
F32 = 0
F16 = 1
10 changes: 10 additions & 0 deletions gguf-py/gguf/gguf_writer.py
@@ -27,6 +27,7 @@
PoolingType,
TokenType,
CLIPProjectorType,
CLIPPatchMergeType,
)

from .quants import quant_shape_from_byte_shape
@@ -848,6 +849,15 @@ def add_vision_clip_max_position_embeddings(self, value: int) -> None:
def add_vision_clip_projector_type(self, value: CLIPProjectorType) -> None:
self.add_string(Keys.Vision.Clip.PROJECTOR_TYPE, value.value)

def add_vision_clip_max_slices(self, value: int) -> None:
self.add_uint32(Keys.Vision.Clip.MAX_SLICES, value)

def add_vision_clip_select_layer(self, value: int) -> None:
self.add_int32(Keys.Vision.Clip.SELECT_LAYER, value)

def add_vision_clip_patch_merge_type(self, value: CLIPPatchMergeType) -> None:
self.add_string(Keys.Vision.Clip.PATCH_MERGE_TYPE, value.value)

def add_vision_clip_layer_norm_epsilon(self, value: float) -> None:
self.add_float32(Keys.Vision.Clip.LAYERNORM_EPS, value)

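Metadata written through these helpers can be read back with ggml's gguf C reader API. The sketch below assumes the gguf_* functions from ggml (gguf_init_from_file, gguf_find_key, gguf_get_val_str, gguf_get_val_i32); their header location and exact signatures have shifted between ggml revisions, so treat this as an outline rather than a drop-in snippet. The key strings come from the constants.py changes above.

// read_vision_meta_sketch.cpp: assumes ggml's gguf reader API; newer ggml declares it in gguf.h
#include "ggml.h"
#include <cstdio>

int main(int argc, char ** argv) {
    if (argc < 2) { fprintf(stderr, "usage: %s <model.gguf>\n", argv[0]); return 1; }

    struct gguf_init_params params = { /*.no_alloc =*/ true, /*.ctx =*/ nullptr };
    struct gguf_context * gctx = gguf_init_from_file(argv[1], params);
    if (!gctx) { fprintf(stderr, "failed to open %s\n", argv[1]); return 1; }

    int kid;
    if ((kid = gguf_find_key(gctx, "vision.type")) >= 0) {
        printf("vision.type                  = %s\n", gguf_get_val_str(gctx, kid));
    }
    if ((kid = gguf_find_key(gctx, "vision.clip.select_layer")) >= 0) {
        // stored as int32 because the selected layer may be negative (e.g. -2 for LLaVA)
        printf("vision.clip.select_layer     = %d\n", gguf_get_val_i32(gctx, kid));
    }
    if ((kid = gguf_find_key(gctx, "vision.clip.patch_merge_type")) >= 0) {
        printf("vision.clip.patch_merge_type = %s\n", gguf_get_val_str(gctx, kid));
    }

    gguf_free(gctx);
    return 0;
}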
8 changes: 4 additions & 4 deletions gguf-py/gguf/tensor_mapping.py
@@ -680,12 +680,12 @@ class TensorNameMap:
"encoder.final_layer_norm", # t5
),

MODEL_TENSOR.V_MMPROJ_A: (
"multi_modal_projector.linear_1",
MODEL_TENSOR.V_MMPROJ: (
"multi_modal_projector.linear_{bid}",
),

MODEL_TENSOR.V_MMPROJ_B: (
"multi_modal_projector.linear_2",
MODEL_TENSOR.V_MMPROJ: (
"multi_modal_projector.linear_{bid}",
),

MODEL_TENSOR.V_ENC_EMBD_CLS: (
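The switch from the separate V_MMPROJ_A and V_MMPROJ_B tensors to a single V_MMPROJ with a {bid} placeholder means the projector layers are addressed by index on both sides of the mapping ("multi_modal_projector.linear_1" to "v.mmproj_1", "multi_modal_projector.linear_2" to "v.mmproj_2"). The small sketch below only illustrates the placeholder substitution; the real mapping is handled by gguf-py's TensorNameMap.

// bid_mapping_sketch.cpp: illustrative substitution of the {bid} placeholder
#include <cstdio>
#include <string>

static std::string with_bid(std::string templ, int bid) {
    const std::string ph = "{bid}";
    const size_t pos = templ.find(ph);
    if (pos != std::string::npos) {
        templ.replace(pos, ph.size(), std::to_string(bid));
    }
    return templ;
}

int main() {
    for (int bid = 1; bid <= 2; bid++) {
        printf("%s -> %s\n",
               with_bid("multi_modal_projector.linear_{bid}", bid).c_str(),
               with_bid("v.mmproj_{bid}", bid).c_str());
    }
    return 0;
}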
(The remaining 4 changed files in this commit are not shown here.)
