Server flash winograd1 #2

Draft · wants to merge 28 commits into master

Commits (28)
f723f56
Add server example
stduhpf Aug 25, 2024
76fa293
server: remove pingpong endpoint
stduhpf Aug 26, 2024
ae238de
Server: Fix missing return on non-void function
stduhpf Aug 27, 2024
c4b8c47
Server: change default host
stduhpf Aug 27, 2024
9a81f4a
Server: Fix printf
stduhpf Aug 27, 2024
f88143f
repair flash attention in _ext
Green-Sky Sep 1, 2024
408cb05
make flash attention in the diffusion model a runtime flag
Green-Sky Sep 7, 2024
e904b86
remove old flash attention option and switch vae over to attn_ext
Green-Sky Sep 7, 2024
cc7efa2
rdy for merge
Green-Sky Sep 7, 2024
90d420a
update docs
Green-Sky Sep 8, 2024
8ab5666
winograd conv2d works but performance is not better
Sep 28, 2024
830fd19
use nn_conv_2d to set winograd
Sep 29, 2024
b37803f
added Conv2d1x3x3 block for winograd
Sep 29, 2024
4e9f036
fixed a bug where recursive transform not propagating down
Sep 30, 2024
8529431
server: move httplib to thirdparty folder
stduhpf Oct 4, 2024
533da39
Server: accept json inputs + return bas64 image
stduhpf Oct 4, 2024
a1e3f04
server: use t.join() instead of infinite loop
stduhpf Oct 4, 2024
83095ca
server: fix CI Build
stduhpf Oct 4, 2024
453f145
server: attach image metadata in response
stduhpf Oct 5, 2024
125707a
server: add simple client script
stduhpf Oct 5, 2024
2f4dfa4
Server: add client docstrings
stduhpf Oct 5, 2024
9912ad3
server: client: fix image preview on non-windows os
stduhpf Oct 5, 2024
c9f2781
server: support sampling method arg
stduhpf Oct 6, 2024
1c59983
server: small test client fixes
stduhpf Oct 5, 2024
0c8554d
Merge branch 'rescue_flash_attn' of github.com:Green-Sky/stable-diffu…
Green-Sky Oct 6, 2024
986b630
force enable flash_attention on server
Green-Sky Oct 6, 2024
a160cc9
Merge branch 'add-winograd-conv2d-v1' of https://github.com/bssrdf/st…
Green-Sky Oct 12, 2024
7c32ace
update ggml for winograd
Green-Sky Oct 12, 2024
6 changes: 0 additions & 6 deletions CMakeLists.txt
@@ -29,7 +29,6 @@ option(SD_HIPBLAS "sd: rocm backend" OFF)
option(SD_METAL "sd: metal backend" OFF)
option(SD_VULKAN "sd: vulkan backend" OFF)
option(SD_SYCL "sd: sycl backend" OFF)
option(SD_FLASH_ATTN "sd: use flash attention for x4 less memory usage" OFF)
option(SD_FAST_SOFTMAX "sd: x1.5 faster softmax, indeterministic (sometimes, same seed don't generate same image), cuda only" OFF)
option(SD_BUILD_SHARED_LIBS "sd: build shared libs" OFF)
#option(SD_BUILD_SERVER "sd: build server example" ON)
@@ -61,11 +60,6 @@ if (SD_HIPBLAS)
endif()
endif ()

if(SD_FLASH_ATTN)
message("-- Use Flash Attention for memory optimization")
add_definitions(-DSD_USE_FLASH_ATTENTION)
endif()

set(SD_LIB stable-diffusion)

file(GLOB SD_LIB_SOURCES
21 changes: 17 additions & 4 deletions README.md
@@ -24,7 +24,7 @@ Inference of Stable Diffusion and Flux in pure C/C++
- Full CUDA, Metal, Vulkan and SYCL backend for GPU acceleration.
- Can load ckpt, safetensors and diffusers models/checkpoints. Standalone VAE models
- No need to convert to `.ggml` or `.gguf` anymore!
- Flash Attention for memory usage optimization (only cpu for now)
- Flash Attention for memory usage optimization
- Original `txt2img` and `img2img` mode
- Negative prompt
- [stable-diffusion-webui](https://github.com/AUTOMATIC1111/stable-diffusion-webui) style tokenizer (not all the features, only token weighting for now)
@@ -182,11 +182,21 @@ Example of text2img by using SYCL backend:

##### Using Flash Attention

Enabling flash attention reduces memory usage by at least 400 MB. At the moment, it is not supported when CUBLAS is enabled because the kernel implementation is missing.
Enabling flash attention for the diffusion model reduces memory usage by a model-dependent amount, e.g.:
- Flux 768x768: ~600 MB
- SD2 768x768: ~1400 MB

For most backends it slows generation down, but on CUDA it generally speeds it up as well.
At the moment it is only supported by some models and some backends (CPU, CUDA/ROCm, Metal).

Run by adding `--diffusion-fa` to the arguments and watch for:
```
cmake .. -DSD_FLASH_ATTN=ON
cmake --build . --config Release
[INFO ] stable-diffusion.cpp:312 - Using flash attention in the diffusion model
```
and for the compute buffer shrinking in the debug log:
```
[DEBUG] ggml_extend.hpp:1004 - flux compute buffer size: 650.00 MB(VRAM)
```

### Run
@@ -239,6 +249,9 @@ arguments:
--vae-tiling process vae in tiles to reduce memory usage
--vae-on-cpu keep vae in cpu (for low vram)
--clip-on-cpu keep clip in cpu (for low vram).
--diffusion-fa use flash attention in the diffusion model (for low vram).
Might lower quality, since it implies converting k and v to f16.
This might crash if it is not supported by the backend.
--control-net-cpu keep controlnet in cpu (for low vram)
--canny apply canny preprocessor (edge detection)
--color Colors the logging tags according to level
40 changes: 27 additions & 13 deletions common.hpp
@@ -49,12 +49,15 @@ class UpSampleBlock : public GGMLBlock {
int out_channels)
: channels(channels),
out_channels(out_channels) {
blocks["conv"] = std::shared_ptr<GGMLBlock>(new Conv2d(channels, out_channels, {3, 3}, {1, 1}, {1, 1}));
if(channels % 8 == 0 && out_channels % 64 == 0)
blocks["conv"] = std::shared_ptr<GGMLBlock>(new Conv2d1x3x3(channels, out_channels));
else
blocks["conv"] = std::shared_ptr<GGMLBlock>(new Conv2d(channels, out_channels, {3, 3}, {1, 1}, {1, 1}));
}

struct ggml_tensor* forward(struct ggml_context* ctx, struct ggml_tensor* x) {
// x: [N, channels, h, w]
auto conv = std::dynamic_pointer_cast<Conv2d>(blocks["conv"]);
auto conv = std::dynamic_pointer_cast<UnaryBlock>(blocks["conv"]);

x = ggml_upscale(ctx, x, 2); // [N, channels, h*2, w*2]
x = conv->forward(ctx, x); // [N, out_channels, h*2, w*2]
@@ -82,7 +85,12 @@ class ResBlock : public GGMLBlock {
if (dims == 3) {
return std::shared_ptr<GGMLBlock>(new Conv3dnx1x1(in_channels, out_channels, kernel_size.first, 1, padding.first));
} else {
return std::shared_ptr<GGMLBlock>(new Conv2d(in_channels, out_channels, kernel_size, {1, 1}, padding));
if (kernel_size.first == 3 && kernel_size.second == 3 &&
in_channels % 8 == 0 && out_channels % 64 == 0 &&
padding.first == 1 && padding.second == 1)
return std::shared_ptr<GGMLBlock>(new Conv2d1x3x3(in_channels, out_channels));
else
return std::shared_ptr<GGMLBlock>(new Conv2d(in_channels, out_channels, kernel_size, {1, 1}, padding));
}
}

@@ -138,8 +146,9 @@ class ResBlock : public GGMLBlock {
// in_layers
auto h = in_layers_0->forward(ctx, x);
h = ggml_silu_inplace(ctx, h);
// print_ggml_tensor(h, true, "bef in_layer");
h = in_layers_2->forward(ctx, h); // [N, out_channels, h, w] if dims == 2 else [N, out_channels, t, h, w]

// print_ggml_tensor(h, true, "aft in_layer");
// emb_layers
if (!skip_t_emb) {
auto emb_layer_1 = std::dynamic_pointer_cast<Linear>(blocks["emb_layers.1"]);
@@ -245,16 +254,19 @@ class CrossAttention : public GGMLBlock {
int64_t context_dim;
int64_t n_head;
int64_t d_head;
bool flash_attn;

public:
CrossAttention(int64_t query_dim,
int64_t context_dim,
int64_t n_head,
int64_t d_head)
int64_t d_head,
bool flash_attn = false)
: n_head(n_head),
d_head(d_head),
query_dim(query_dim),
context_dim(context_dim) {
context_dim(context_dim),
flash_attn(flash_attn) {
int64_t inner_dim = d_head * n_head;

blocks["to_q"] = std::shared_ptr<GGMLBlock>(new Linear(query_dim, inner_dim, false));
@@ -283,7 +295,7 @@ class CrossAttention : public GGMLBlock {
auto k = to_k->forward(ctx, context); // [N, n_context, inner_dim]
auto v = to_v->forward(ctx, context); // [N, n_context, inner_dim]

x = ggml_nn_attention_ext(ctx, q, k, v, n_head, NULL, false); // [N, n_token, inner_dim]
x = ggml_nn_attention_ext(ctx, q, k, v, n_head, NULL, false, false, flash_attn); // [N, n_token, inner_dim]

x = to_out_0->forward(ctx, x); // [N, n_token, query_dim]
return x;
@@ -301,15 +313,16 @@ class BasicTransformerBlock : public GGMLBlock {
int64_t n_head,
int64_t d_head,
int64_t context_dim,
bool ff_in = false)
bool ff_in = false,
bool flash_attn = false)
: n_head(n_head), d_head(d_head), ff_in(ff_in) {
// disable_self_attn is always False
// disable_temporal_crossattention is always False
// switch_temporal_ca_to_sa is always False
// inner_dim is always None or equal to dim
// gated_ff is always True
blocks["attn1"] = std::shared_ptr<GGMLBlock>(new CrossAttention(dim, dim, n_head, d_head));
blocks["attn2"] = std::shared_ptr<GGMLBlock>(new CrossAttention(dim, context_dim, n_head, d_head));
blocks["attn1"] = std::shared_ptr<GGMLBlock>(new CrossAttention(dim, dim, n_head, d_head, flash_attn));
blocks["attn2"] = std::shared_ptr<GGMLBlock>(new CrossAttention(dim, context_dim, n_head, d_head, flash_attn));
blocks["ff"] = std::shared_ptr<GGMLBlock>(new FeedForward(dim, dim));
blocks["norm1"] = std::shared_ptr<GGMLBlock>(new LayerNorm(dim));
blocks["norm2"] = std::shared_ptr<GGMLBlock>(new LayerNorm(dim));
@@ -374,7 +387,8 @@ class SpatialTransformer : public GGMLBlock {
int64_t n_head,
int64_t d_head,
int64_t depth,
int64_t context_dim)
int64_t context_dim,
bool flash_attn = false)
: in_channels(in_channels),
n_head(n_head),
d_head(d_head),
@@ -388,7 +402,7 @@

for (int i = 0; i < depth; i++) {
std::string name = "transformer_blocks." + std::to_string(i);
blocks[name] = std::shared_ptr<GGMLBlock>(new BasicTransformerBlock(inner_dim, n_head, d_head, context_dim));
blocks[name] = std::shared_ptr<GGMLBlock>(new BasicTransformerBlock(inner_dim, n_head, d_head, context_dim, false, flash_attn));
}

blocks["proj_out"] = std::shared_ptr<GGMLBlock>(new Conv2d(inner_dim, in_channels, {1, 1}));
@@ -511,4 +525,4 @@ class VideoResBlock : public ResBlock {
}
};

#endif // __COMMON_HPP__
#endif // __COMMON_HPP__
26 changes: 21 additions & 5 deletions diffusion_model.hpp
@@ -24,21 +24,28 @@ struct DiffusionModel {
virtual void get_param_tensors(std::map<std::string, struct ggml_tensor*>& tensors) = 0;
virtual size_t get_params_buffer_size() = 0;
virtual int64_t get_adm_in_channels() = 0;
virtual void transform(int n) = 0;

};

struct UNetModel : public DiffusionModel {
UNetModelRunner unet;

UNetModel(ggml_backend_t backend,
ggml_type wtype,
SDVersion version = VERSION_SD1)
: unet(backend, wtype, version) {
SDVersion version = VERSION_SD1,
bool flash_attn = false)
: unet(backend, wtype, version, flash_attn) {
}

void alloc_params_buffer() {
unet.alloc_params_buffer();
}

void transform(int n){
unet.transform(n);
}

void free_params_buffer() {
unet.free_params_buffer();
}
@@ -108,6 +115,10 @@ struct MMDiTModel : public DiffusionModel {
return 768 + 1280;
}

void transform(int n){

}

void compute(int n_threads,
struct ggml_tensor* x,
struct ggml_tensor* timesteps,
@@ -129,8 +140,9 @@ struct FluxModel : public DiffusionModel {

FluxModel(ggml_backend_t backend,
ggml_type wtype,
SDVersion version = VERSION_FLUX_DEV)
: flux(backend, wtype, version) {
SDVersion version = VERSION_FLUX_DEV,
bool flash_attn = false)
: flux(backend, wtype, version, flash_attn) {
}

void alloc_params_buffer() {
@@ -157,6 +169,10 @@
return 768;
}

void transform(int n){

}

void compute(int n_threads,
struct ggml_tensor* x,
struct ggml_tensor* timesteps,
@@ -173,4 +189,4 @@
}
};

#endif
#endif
3 changes: 2 additions & 1 deletion examples/CMakeLists.txt
@@ -1,3 +1,4 @@
include_directories(${CMAKE_CURRENT_SOURCE_DIR})

add_subdirectory(cli)
add_subdirectory(cli)
add_subdirectory(server)
10 changes: 9 additions & 1 deletion examples/cli/main.cpp
@@ -116,6 +116,7 @@ struct SDParams {
bool normalize_input = false;
bool clip_on_cpu = false;
bool vae_on_cpu = false;
bool diffusion_flash_attn = false;
bool canny_preprocess = false;
bool color = false;
int upscale_repeats = 1;
@@ -145,6 +146,7 @@ void print_params(SDParams params) {
printf(" clip on cpu: %s\n", params.clip_on_cpu ? "true" : "false");
printf(" controlnet cpu: %s\n", params.control_net_cpu ? "true" : "false");
printf(" vae decoder on cpu:%s\n", params.vae_on_cpu ? "true" : "false");
printf(" diffusion flash attention:%s\n", params.diffusion_flash_attn ? "true" : "false");
printf(" strength(control): %.2f\n", params.control_strength);
printf(" prompt: %s\n", params.prompt.c_str());
printf(" negative_prompt: %s\n", params.negative_prompt.c_str());
@@ -213,6 +215,9 @@ void print_usage(int argc, const char* argv[]) {
printf(" --vae-tiling process vae in tiles to reduce memory usage\n");
printf(" --vae-on-cpu keep vae in cpu (for low vram)\n");
printf(" --clip-on-cpu keep clip in cpu (for low vram).\n");
printf(" --diffusion-fa use flash attention in the diffusion model (for low vram).\n");
printf(" Might lower quality, since it implies converting k and v to f16.\n");
printf(" This might crash if it is not supported by the backend.\n");
printf(" --control-net-cpu keep controlnet in cpu (for low vram)\n");
printf(" --canny apply canny preprocessor (edge detection)\n");
printf(" --color Colors the logging tags according to level\n");
@@ -457,6 +462,8 @@ void parse_args(int argc, const char** argv, SDParams& params) {
params.clip_on_cpu = true; // will slow down get_learned_condition but necessary for low MEM GPUs
} else if (arg == "--vae-on-cpu") {
params.vae_on_cpu = true; // will slow down latent decoding but necessary for low MEM GPUs
} else if (arg == "--diffusion-fa") {
params.diffusion_flash_attn = true; // can reduce MEM significantly
} else if (arg == "--canny") {
params.canny_preprocess = true;
} else if (arg == "-b" || arg == "--batch-count") {
@@ -782,7 +789,8 @@ int main(int argc, const char* argv[]) {
params.schedule,
params.clip_on_cpu,
params.control_net_cpu,
params.vae_on_cpu);
params.vae_on_cpu,
params.diffusion_flash_attn);

if (sd_ctx == NULL) {
printf("new_sd_ctx_t failed\n");
6 changes: 6 additions & 0 deletions examples/server/CMakeLists.txt
@@ -0,0 +1,6 @@
set(TARGET sd-server)

add_executable(${TARGET} main.cpp)
install(TARGETS ${TARGET} RUNTIME)
target_link_libraries(${TARGET} PRIVATE stable-diffusion ${CMAKE_THREAD_LIBS_INIT})
target_compile_features(${TARGET} PUBLIC cxx_std_11)
42 changes: 42 additions & 0 deletions examples/server/b64.cpp
@@ -0,0 +1,42 @@

// Base64 helpers, adapted from https://stackoverflow.com/a/34571089/5155484

#include <cstdint>
#include <string>
#include <vector>

// Base64 alphabet; '=' is used only for padding.
static const std::string b = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/";

// Encode a byte string to base64 (no line wrapping).
static std::string base64_encode(const std::string &in) {
std::string out;

int val=0, valb=-6;
for (uint8_t c : in) {
val = (val<<8) + c;
valb += 8;
while (valb>=0) {
out.push_back(b[(val>>valb)&0x3F]);
valb-=6;
}
}
if (valb>-6) out.push_back(b[((val<<8)>>(valb+8))&0x3F]);
while (out.size()%4) out.push_back('=');
return out;
}


// Decode a base64 string; decoding stops at the first non-base64 character.
static std::string base64_decode(const std::string &in) {

std::string out;

std::vector<int> T(256,-1);
for (int i=0; i<64; i++) T[b[i]] = i;

int val=0, valb=-8;
for (uint8_t c : in) {
if (T[c] == -1) break;
val = (val<<6) + T[c];
valb += 6;
if (valb>=0) {
out.push_back(char((val>>valb)&0xFF));
valb-=8;
}
}
return out;
}
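
A minimal round-trip sketch of these helpers, for illustration only (it assumes the two functions above are visible, e.g. by including this file; the sample bytes stand in for the PNG data the server encodes):
```
// Round-trip sketch: encode a few raw bytes, then decode them back.
#include <cassert>
#include <iostream>
#include <string>

int main() {
    std::string raw = "\x89PNG\r\n";               // first bytes of a PNG header
    std::string encoded = base64_encode(raw);      // "iVBORw0K"
    std::string decoded = base64_decode(encoded);  // original bytes restored
    assert(decoded == raw);
    std::cout << encoded << std::endl;
    return 0;
}
```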