Commit 00b4c3d

common : support tag-based --hf-repo like on ollama (#11195)
* common : support tag-based hf_repo like on ollama
* fix build
* various fixes
* small fixes
* fix style
* fix windows build?
* move common_get_hf_file to common.cpp
* fix complain with noreturn
1 parent 7426a26 commit 00b4c3d

File tree

3 files changed: +130 -17 lines

common/arg.cpp

+22 -11
@@ -130,17 +130,26 @@ std::string common_arg::to_string() {
 
 static void common_params_handle_model_default(
         std::string & model,
-        std::string & model_url,
+        const std::string & model_url,
         std::string & hf_repo,
-        std::string & hf_file) {
+        std::string & hf_file,
+        const std::string & hf_token) {
     if (!hf_repo.empty()) {
         // short-hand to avoid specifying --hf-file -> default it to --model
         if (hf_file.empty()) {
             if (model.empty()) {
-                throw std::invalid_argument("error: --hf-repo requires either --hf-file or --model\n");
+                auto auto_detected = common_get_hf_file(hf_repo, hf_token);
+                if (auto_detected.first.empty() || auto_detected.second.empty()) {
+                    exit(1); // built without CURL, error message already printed
+                }
+                hf_repo = auto_detected.first;
+                hf_file = auto_detected.second;
+            } else {
+                hf_file = model;
             }
-            hf_file = model;
-        } else if (model.empty()) {
+        }
+        // make sure model path is present (for caching purposes)
+        if (model.empty()) {
             // this is to avoid different repo having same file name, or same file name in different subdirs
             std::string filename = hf_repo + "_" + hf_file;
             // to make sure we don't have any slashes in the filename
@@ -290,8 +299,8 @@ static bool common_params_parse_ex(int argc, char ** argv, common_params_context
     }
 
     // TODO: refactor model params in a common struct
-    common_params_handle_model_default(params.model, params.model_url, params.hf_repo, params.hf_file);
-    common_params_handle_model_default(params.vocoder.model, params.vocoder.model_url, params.vocoder.hf_repo, params.vocoder.hf_file);
+    common_params_handle_model_default(params.model, params.model_url, params.hf_repo, params.hf_file, params.hf_token);
+    common_params_handle_model_default(params.vocoder.model, params.vocoder.model_url, params.vocoder.hf_repo, params.vocoder.hf_file, params.hf_token);
 
     if (params.escape) {
         string_process_escapes(params.prompt);
@@ -1583,21 +1592,23 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
         }
     ).set_env("LLAMA_ARG_MODEL_URL"));
     add_opt(common_arg(
-        {"-hfr", "--hf-repo"}, "REPO",
-        "Hugging Face model repository (default: unused)",
+        {"-hf", "-hfr", "--hf-repo"}, "<user>/<model>[:quant]",
+        "Hugging Face model repository; quant is optional, case-insensitive, default to Q4_K_M, or falls back to the first file in the repo if Q4_K_M doesn't exist.\n"
+        "example: unsloth/phi-4-GGUF:q4_k_m\n"
+        "(default: unused)",
         [](common_params & params, const std::string & value) {
            params.hf_repo = value;
         }
     ).set_env("LLAMA_ARG_HF_REPO"));
     add_opt(common_arg(
         {"-hff", "--hf-file"}, "FILE",
-        "Hugging Face model file (default: unused)",
+        "Hugging Face model file. If specified, it will override the quant in --hf-repo (default: unused)",
         [](common_params & params, const std::string & value) {
            params.hf_file = value;
         }
     ).set_env("LLAMA_ARG_HF_FILE"));
     add_opt(common_arg(
-        {"-hfrv", "--hf-repo-v"}, "REPO",
+        {"-hfv", "-hfrv", "--hf-repo-v"}, "<user>/<model>[:quant]",
         "Hugging Face model repository for the vocoder model (default: unused)",
         [](common_params & params, const std::string & value) {
            params.vocoder.hf_repo = value;
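
The cache-naming rule referenced in the first hunk's context lines ("this is to avoid different repo having same file name") is easy to see in isolation. Below is a minimal, self-contained sketch; the helper name `cache_filename` is hypothetical and the real logic lives inside `common_params_handle_model_default`:

```cpp
#include <iostream>
#include <string>

// Hypothetical helper illustrating the cache-naming rule quoted above:
// prefix the file with the repo, then strip slashes so the result is a
// single path component and cannot collide across repos or subdirs.
static std::string cache_filename(const std::string & hf_repo, const std::string & hf_file) {
    std::string filename = hf_repo + "_" + hf_file;
    for (char & c : filename) {
        if (c == '/' || c == '\\') {
            c = '_';
        }
    }
    return filename;
}

int main() {
    // two repos shipping the same file name map to distinct cache entries
    std::cout << cache_filename("unsloth/phi-4-GGUF", "phi-4-Q4_K_M.gguf") << "\n";
    // -> unsloth_phi-4-GGUF_phi-4-Q4_K_M.gguf
}
```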

common/common.cpp

+100 -6
@@ -73,6 +73,22 @@
 #include <sys/syslimits.h>
 #endif
 #define LLAMA_CURL_MAX_URL_LENGTH 2084 // Maximum URL Length in Chrome: 2083
+
+//
+// CURL utils
+//
+
+using curl_ptr = std::unique_ptr<CURL, decltype(&curl_easy_cleanup)>;
+
+// cannot use unique_ptr for curl_slist, because we cannot update without destroying the old one
+struct curl_slist_ptr {
+    struct curl_slist * ptr = nullptr;
+    ~curl_slist_ptr() {
+        if (ptr) {
+            curl_slist_free_all(ptr);
+        }
+    }
+};
 #endif // LLAMA_USE_CURL
 
 using json = nlohmann::ordered_json;
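
The comment on `curl_slist_ptr` is the key design note here: `curl_slist_append` returns the (possibly new) head of the list, so the pointer must stay reassignable after construction, which `std::unique_ptr`'s fixed ownership of a single pointer value does not accommodate. A standalone sketch of the intended usage, assuming libcurl is installed:

```cpp
#include <curl/curl.h>
#include <memory>

// same shape as the wrappers introduced in the hunk above
using curl_ptr = std::unique_ptr<CURL, decltype(&curl_easy_cleanup)>;

struct curl_slist_ptr {
    struct curl_slist * ptr = nullptr;
    ~curl_slist_ptr() {
        if (ptr) {
            curl_slist_free_all(ptr);
        }
    }
};

int main() {
    curl_ptr       curl(curl_easy_init(), &curl_easy_cleanup);
    curl_slist_ptr http_headers;
    if (!curl) {
        return 1;
    }
    // curl_slist_append may return a new list head, so the member pointer is
    // reassigned each time; the wrapper only guarantees curl_slist_free_all()
    // runs on scope exit
    http_headers.ptr = curl_slist_append(http_headers.ptr, "Accept: application/json");
    http_headers.ptr = curl_slist_append(http_headers.ptr, "User-Agent: llama-cpp");
    curl_easy_setopt(curl.get(), CURLOPT_HTTPHEADER, http_headers.ptr);
    return 0; // header list and handle freed by the destructors
}
```
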
@@ -1130,7 +1146,8 @@ static bool curl_perform_with_retry(const std::string & url, CURL * curl, int ma
 
 static bool common_download_file(const std::string & url, const std::string & path, const std::string & hf_token) {
     // Initialize libcurl
-    std::unique_ptr<CURL, decltype(&curl_easy_cleanup)> curl(curl_easy_init(), &curl_easy_cleanup);
+    curl_ptr       curl(curl_easy_init(), &curl_easy_cleanup);
+    curl_slist_ptr http_headers;
     if (!curl) {
         LOG_ERR("%s: error initializing libcurl\n", __func__);
         return false;
@@ -1144,11 +1161,9 @@ static bool common_download_file(const std::string & url, const std::string & pa
 
     // Check if hf-token or bearer-token was specified
     if (!hf_token.empty()) {
-        std::string auth_header = "Authorization: Bearer ";
-        auth_header += hf_token.c_str();
-        struct curl_slist *http_headers = NULL;
-        http_headers = curl_slist_append(http_headers, auth_header.c_str());
-        curl_easy_setopt(curl.get(), CURLOPT_HTTPHEADER, http_headers);
+        std::string auth_header = "Authorization: Bearer " + hf_token;
+        http_headers.ptr = curl_slist_append(http_headers.ptr, auth_header.c_str());
+        curl_easy_setopt(curl.get(), CURLOPT_HTTPHEADER, http_headers.ptr);
     }
 
 #if defined(_WIN32)
@@ -1444,6 +1459,80 @@ struct llama_model * common_load_model_from_hf(
     return common_load_model_from_url(model_url, local_path, hf_token, params);
 }
 
+/**
+ * Allow getting the HF file from the HF repo with tag (like ollama), for example:
+ * - bartowski/Llama-3.2-3B-Instruct-GGUF:q4
+ * - bartowski/Llama-3.2-3B-Instruct-GGUF:Q4_K_M
+ * - bartowski/Llama-3.2-3B-Instruct-GGUF:q5_k_s
+ * Tag is optional, default to "latest" (meaning it checks for Q4_K_M first, then Q4, then if not found, return the first GGUF file in repo)
+ *
+ * Return pair of <repo, file> (with "repo" already having tag removed)
+ *
+ * Note: we use the Ollama-compatible HF API, but not using the blobId. Instead, we use the special "ggufFile" field which returns the value for "hf_file". This is done to be backward-compatible with existing cache files.
+ */
+std::pair<std::string, std::string> common_get_hf_file(const std::string & hf_repo_with_tag, const std::string & hf_token) {
+    auto parts = string_split<std::string>(hf_repo_with_tag, ':');
+    std::string tag = parts.size() > 1 ? parts.back() : "latest";
+    std::string hf_repo = parts[0];
+    if (string_split<std::string>(hf_repo, '/').size() != 2) {
+        throw std::invalid_argument("error: invalid HF repo format, expected <user>/<model>[:quant]\n");
+    }
+
+    // fetch model info from Hugging Face Hub API
+    json model_info;
+    curl_ptr       curl(curl_easy_init(), &curl_easy_cleanup);
+    curl_slist_ptr http_headers;
+    std::string res_str;
+    std::string url = "https://huggingface.co/v2/" + hf_repo + "/manifests/" + tag;
+    curl_easy_setopt(curl.get(), CURLOPT_URL, url.c_str());
+    curl_easy_setopt(curl.get(), CURLOPT_NOPROGRESS, 1L);
+    typedef size_t(*CURLOPT_WRITEFUNCTION_PTR)(void * ptr, size_t size, size_t nmemb, void * data);
+    auto write_callback = [](void * ptr, size_t size, size_t nmemb, void * data) -> size_t {
+        static_cast<std::string *>(data)->append((char * ) ptr, size * nmemb);
+        return size * nmemb;
+    };
+    curl_easy_setopt(curl.get(), CURLOPT_WRITEFUNCTION, static_cast<CURLOPT_WRITEFUNCTION_PTR>(write_callback));
+    curl_easy_setopt(curl.get(), CURLOPT_WRITEDATA, &res_str);
+#if defined(_WIN32)
+    curl_easy_setopt(curl.get(), CURLOPT_SSL_OPTIONS, CURLSSLOPT_NATIVE_CA);
+#endif
+    if (!hf_token.empty()) {
+        std::string auth_header = "Authorization: Bearer " + hf_token;
+        http_headers.ptr = curl_slist_append(http_headers.ptr, auth_header.c_str());
+    }
+    // Important: the User-Agent must be "llama-cpp" to get the "ggufFile" field in the response
+    http_headers.ptr = curl_slist_append(http_headers.ptr, "User-Agent: llama-cpp");
+    http_headers.ptr = curl_slist_append(http_headers.ptr, "Accept: application/json");
+    curl_easy_setopt(curl.get(), CURLOPT_HTTPHEADER, http_headers.ptr);
+
+    CURLcode res = curl_easy_perform(curl.get());
+
+    if (res != CURLE_OK) {
+        throw std::runtime_error("error: cannot make GET request to HF API");
+    }
+
+    long res_code;
+    curl_easy_getinfo(curl.get(), CURLINFO_RESPONSE_CODE, &res_code);
+    if (res_code == 200) {
+        model_info = json::parse(res_str);
+    } else if (res_code == 401) {
+        throw std::runtime_error("error: model is private or does not exist; if you are accessing a gated model, please provide a valid HF token");
+    } else {
+        throw std::runtime_error(string_format("error from HF API, response code: %ld, data: %s", res_code, res_str.c_str()));
+    }
+
+    // check response
+    if (!model_info.contains("ggufFile")) {
+        throw std::runtime_error("error: model does not have ggufFile");
+    }
+    json & gguf_file = model_info.at("ggufFile");
+    if (!gguf_file.contains("rfilename")) {
+        throw std::runtime_error("error: ggufFile does not have rfilename");
+    }
+
+    return std::make_pair(hf_repo, gguf_file.at("rfilename"));
+}
+
 #else
 
 struct llama_model * common_load_model_from_url(
@@ -1465,6 +1554,11 @@ struct llama_model * common_load_model_from_hf(
     return nullptr;
 }
 
+std::pair<std::string, std::string> common_get_hf_file(const std::string &, const std::string &) {
+    LOG_WRN("%s: llama.cpp built without libcurl, downloading from Hugging Face not supported.\n", __func__);
+    return std::make_pair("", "");
+}
+
 #endif // LLAMA_USE_CURL
 
 //
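
The manifest endpoint is described only indirectly by the patch: the code reads a single `ggufFile` object with an `rfilename` string out of the JSON body. As a hedged, self-contained sketch of that extraction, using an invented response body that contains only the fields the patch actually consumes (the real huggingface.co response may carry more):

```cpp
#include <nlohmann/json.hpp>
#include <iostream>
#include <stdexcept>
#include <string>

using json = nlohmann::ordered_json;

int main() {
    // invented response body -- only "ggufFile.rfilename" is read by the patch
    std::string res_str = R"({"ggufFile": {"rfilename": "phi-4-Q4_K_M.gguf"}})";

    // same checks as common_get_hf_file performs on the parsed manifest
    json model_info = json::parse(res_str);
    if (!model_info.contains("ggufFile")) {
        throw std::runtime_error("error: model does not have ggufFile");
    }
    json & gguf_file = model_info.at("ggufFile");
    if (!gguf_file.contains("rfilename")) {
        throw std::runtime_error("error: ggufFile does not have rfilename");
    }
    std::cout << gguf_file.at("rfilename").get<std::string>() << "\n"; // phi-4-Q4_K_M.gguf
}
```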

common/common.h

+8
@@ -454,6 +454,11 @@ static bool string_starts_with(const std::string & str,
     return str.rfind(prefix, 0) == 0;
 }
 
+static bool string_ends_with(const std::string & str,
+                             const std::string & suffix) {  // While we wait for C++20's std::string::ends_with...
+    return str.size() >= suffix.size() && str.compare(str.size()-suffix.size(), suffix.size(), suffix) == 0;
+}
+
 bool string_parse_kv_override(const char * data, std::vector<llama_model_kv_override> & overrides);
 void string_process_escapes(std::string & input);
 
@@ -501,6 +506,9 @@ struct llama_model * common_load_model_from_hf(
     const std::string & local_path,
     const std::string & hf_token,
     const struct llama_model_params & params);
+std::pair<std::string, std::string> common_get_hf_file(
+    const std::string & hf_repo_with_tag,
+    const std::string & hf_token);
 
 // clear LoRA adapters from context, then apply new list of adapters
 void common_set_adapter_lora(struct llama_context * ctx, std::vector<common_adapter_lora_info> & lora);
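
As a quick sanity check of the new `string_ends_with` helper (and its existing sibling `string_starts_with`), the snippet below copies both definitions so it compiles on its own; the file names are made-up examples:

```cpp
#include <cassert>
#include <string>

// copies of the helpers from common.h, so this snippet stands alone
static bool string_starts_with(const std::string & str, const std::string & prefix) {
    return str.rfind(prefix, 0) == 0;
}

static bool string_ends_with(const std::string & str, const std::string & suffix) {
    return str.size() >= suffix.size() && str.compare(str.size() - suffix.size(), suffix.size(), suffix) == 0;
}

int main() {
    assert(string_starts_with("unsloth/phi-4-GGUF", "unsloth/"));
    assert(string_ends_with("phi-4-Q4_K_M.gguf", ".gguf"));
    // the size check keeps the suffix-longer-than-string case well-defined
    assert(!string_ends_with("gguf", "phi-4.gguf"));
    return 0;
}
```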
