Skip to content

Commit bc92a03

Browse files
authored
Merge pull request #2 from mgonzs13/main
hf_hub_download_with_shards
2 parents 225b48b + 7436efb commit bc92a03

File tree

3 files changed

+63
-3
lines changed

3 files changed

+63
-3
lines changed

README.md

+1-1
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,6 @@ With this library, C++ developers can integrate Hugging Face model downloads dir
66

77
[![GitHub License](https://img.shields.io/github/license/agonzc34/huggingface-hub-cpp)](https://opensource.org/license/mit) [![GitHub release](https://img.shields.io/github/release/agonzc34/huggingface-hub-cpp.svg)](https://github.com/agonzc34/huggingface-hub-cpp/releases) [![Code Size](https://img.shields.io/github/languages/code-size/agonzc34/huggingface-hub-cpp.svg?branch=main)](https://github.com/agonzc34/huggingface-hub-cpp?branch=main) [![Last Commit](https://img.shields.io/github/last-commit/agonzc34/huggingface-hub-cpp.svg)](https://github.com/agonzc34/huggingface-hub-cpp/commits/main) [![GitHub issues](https://img.shields.io/github/issues/agonzc34/huggingface-hub-cpp)](https://github.com/agonzc34/huggingface-hub-cpp/issues) [![GitHub pull requests](https://img.shields.io/github/issues-pr/agonzc34/huggingface-hub-cpp)](https://github.com/agonzc34/huggingface-hub-cpp/pulls) [![Contributors](https://img.shields.io/github/contributors/agonzc34/huggingface-hub-cpp.svg)](https://github.com/agonzc34/huggingface-hub-cpp/graphs/contributors) [![Build](https://github.com/agonzc34/huggingface-hub-cpp/actions/workflows/cmake-build-status.yml/badge.svg)](https://github.com/agonzc34/huggingface-hub-cpp/actions/workflows/cmake-build-status.yml?branch=main) [![Doxygen Deployment](https://github.com/agonzc34/huggingface-hub-cpp/actions/workflows/doxygen-deployment.yml/badge.svg)](https://agonzc34.github.io/huggingface-hub-cpp/)
88

9-
109
## Table of Contents
1110

1211
- [huggingface-hub-cpp](#huggingface-hub-cpp)
@@ -64,6 +63,7 @@ int main() {
6463
} else {
6564
std::cout << "Error" << std::endl;
6665
}
66+
}
6767
```
6868

6969
### Running the demo app

include/huggingface_hub.h

+20
Original file line numberDiff line numberDiff line change
@@ -95,5 +95,25 @@ hf_hub_download(const std::string &repo_id, const std::string &filename,
9595
const std::string &cache_dir = "~/.cache/huggingface/hub",
9696
bool force_download = false);
9797

98+
/**
99+
* @brief Download a file from Hugging Face Hub.
100+
*
101+
* This function downloads a specified file from a given repository on the
102+
* Hugging Face Hub and saves it to the specified cache directory.
103+
*
104+
* @param repo_id The repository ID.
105+
* @param filename The name of the file to download.
106+
* @param cache_dir The directory to cache the downloaded file. Default is
107+
* "~/.cache/huggingface/hub".
108+
* @param force_download If true, forces the download even if the file already
109+
* exists in the cache.
110+
* @return A DownloadResult structure containing the success status and the path
111+
* of the downloaded file.
112+
*/
113+
struct DownloadResult hf_hub_download_with_shards(
114+
const std::string &repo_id, const std::string &filename,
115+
const std::string &cache_dir = "~/.cache/huggingface/hub",
116+
bool force_download = false);
117+
98118
#endif // HUGGINGFACE_HUB_H
99119
} // namespace huggingface_hub

src/huggingface_hub.cpp

+42-2
Original file line numberDiff line numberDiff line change
@@ -21,19 +21,23 @@
2121
// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
2222
// SOFTWARE.
2323

24-
#include "huggingface_hub.h"
2524
#include <algorithm>
2625
#include <chrono>
2726
#include <csignal>
28-
#include <curl/curl.h>
2927
#include <filesystem>
3028
#include <fstream>
3129
#include <iomanip>
30+
#include <regex>
3231
#include <sstream>
32+
33+
#include <curl/curl.h>
3334
#include <sys/stat.h>
3435
#include <sys/types.h>
3536

37+
#include "huggingface_hub.h"
38+
3639
namespace huggingface_hub {
40+
3741
volatile sig_atomic_t stop_download = 0;
3842
void handle_sigint(int) { stop_download = 1; }
3943

@@ -369,4 +373,40 @@ struct DownloadResult hf_hub_download(const std::string &repo_id,
369373
result.success = res == CURLE_OK;
370374
return result;
371375
}
376+
377+
struct DownloadResult hf_hub_download_with_shards(const std::string &repo_id,
378+
const std::string &filename,
379+
const std::string &cache_dir,
380+
bool force_download) {
381+
382+
std::regex pattern(R"(-([0-9]+)-of-([0-9]+)\.gguf)");
383+
std::smatch match;
384+
385+
if (std::regex_search(filename, match, pattern)) {
386+
int total_shards = std::stoi(match[2]);
387+
std::string base_name = filename.substr(0, match.position(0));
388+
389+
// Download shards
390+
for (int i = 1; i <= total_shards; ++i) {
391+
char shard_file[512];
392+
snprintf(shard_file, sizeof(shard_file), "%s-%05d-of-%05d.gguf",
393+
base_name.c_str(), i, total_shards);
394+
auto aux_res =
395+
hf_hub_download(repo_id, shard_file, cache_dir, force_download);
396+
397+
if (!aux_res.success) {
398+
return aux_res;
399+
}
400+
}
401+
402+
// Return first shard
403+
char first_shard[512];
404+
snprintf(first_shard, sizeof(first_shard), "%s-00001-of-%05d.gguf",
405+
base_name.c_str(), total_shards);
406+
return hf_hub_download(repo_id, first_shard, cache_dir, false);
407+
}
408+
409+
return hf_hub_download(repo_id, filename, cache_dir, force_download);
410+
}
411+
372412
} // namespace huggingface_hub

0 commit comments

Comments
 (0)