
Support for Flux Controls + Flex.2 #692


Open · wants to merge 6 commits into master
3 changes: 2 additions & 1 deletion examples/cli/main.cpp
@@ -905,7 +905,8 @@ int main(int argc, const char* argv[]) {
input_image_buffer};

sd_image_t* control_image = NULL;
- if (params.control_net_path.size() > 0 && params.control_image_path.size() > 0) {
+ if (params.control_image_path.size() > 0) {
printf("load image from '%s'\n", params.control_image_path.c_str());
int c = 0;
control_image_buffer = stbi_load(params.control_image_path.c_str(), &params.width, &params.height, &c, 3);
if (control_image_buffer == NULL) {
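Note on this hunk: the control image is now loaded whenever a control image path is supplied, not only when a ControlNet model path is also given, because the Flux Controls / Flex.2 support added below feeds the control image to the diffusion model itself (via c_concat) rather than through a separate ControlNet. A minimal, self-contained sketch of the loading pattern the CLI uses (helper name and error handling are illustrative, not part of the PR):

// Illustrative sketch only: load an RGB control image with stb_image and wrap it
// in an sd_image_t, mirroring the CLI code above.
#include <cstdint>
#include <cstdio>
#include "stb_image.h"          // stbi_load / stbi_image_free
#include "stable-diffusion.h"   // sd_image_t

static sd_image_t* load_control_image(const char* path) {
    int width = 0, height = 0, c = 0;
    uint8_t* data = stbi_load(path, &width, &height, &c, 3);  // force 3 channels
    if (data == NULL) {
        fprintf(stderr, "load image from '%s' failed\n", path);
        return NULL;
    }
    // the caller frees the pixel buffer with stbi_image_free() after generation
    return new sd_image_t{(uint32_t)width, (uint32_t)height, 3, data};
}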
38 changes: 34 additions & 4 deletions flux.hpp
@@ -984,7 +984,8 @@ namespace Flux {
struct ggml_tensor* pe,
struct ggml_tensor* mod_index_arange = NULL,
std::vector<ggml_tensor*> ref_latents = {},
- std::vector<int> skip_layers = {}) {
+ std::vector<int> skip_layers = {},
+ SDVersion version = VERSION_FLUX) {
// Forward pass of DiT.
// x: (N, C, H, W) tensor of spatial inputs (images or latent representations of images)
// timestep: (N,) tensor of diffusion timesteps
@@ -1007,14 +1008,38 @@
auto img = process_img(ctx, x);
uint64_t img_tokens = img->ne[1];

- if (c_concat != NULL) {
+ if (version == VERSION_FLUX_FILL) {
+ GGML_ASSERT(c_concat != NULL);
ggml_tensor* masked = ggml_view_4d(ctx, c_concat, c_concat->ne[0], c_concat->ne[1], C, 1, c_concat->nb[1], c_concat->nb[2], c_concat->nb[3], 0);
ggml_tensor* mask = ggml_view_4d(ctx, c_concat, c_concat->ne[0], c_concat->ne[1], 8 * 8, 1, c_concat->nb[1], c_concat->nb[2], c_concat->nb[3], c_concat->nb[2] * C);

masked = process_img(ctx, masked);
mask = process_img(ctx, mask);

img = ggml_concat(ctx, img, ggml_concat(ctx, masked, mask, 0), 0);
+ } else if (version == VERSION_FLEX_2) {
+ GGML_ASSERT(c_concat != NULL);
+ ggml_tensor* masked = ggml_view_4d(ctx, c_concat, c_concat->ne[0], c_concat->ne[1], C, 1, c_concat->nb[1], c_concat->nb[2], c_concat->nb[3], 0);
+ ggml_tensor* mask = ggml_view_4d(ctx, c_concat, c_concat->ne[0], c_concat->ne[1], 1, 1, c_concat->nb[1], c_concat->nb[2], c_concat->nb[3], c_concat->nb[2] * C);
+ ggml_tensor* control = ggml_view_4d(ctx, c_concat, c_concat->ne[0], c_concat->ne[1], C, 1, c_concat->nb[1], c_concat->nb[2], c_concat->nb[3], c_concat->nb[2] * (C + 1));
+
+ masked = ggml_pad(ctx, masked, pad_w, pad_h, 0, 0);
+ mask = ggml_pad(ctx, mask, pad_w, pad_h, 0, 0);
+ control = ggml_pad(ctx, control, pad_w, pad_h, 0, 0);
+
+ masked = patchify(ctx, masked, patch_size);
+ mask = patchify(ctx, mask, patch_size);
+ control = patchify(ctx, control, patch_size);
+
+ img = ggml_concat(ctx, img, ggml_concat(ctx, ggml_concat(ctx, masked, mask, 0), control, 0), 0);
+ } else if (version == VERSION_FLUX_CONTROLS) {
+ GGML_ASSERT(c_concat != NULL);
+
+ ggml_tensor* control = ggml_pad(ctx, c_concat, pad_w, pad_h, 0, 0);
+
+ control = patchify(ctx, control, patch_size);
+
+ img = ggml_concat(ctx, img, control, 0);
}

if (ref_latents.size() > 0) {
@@ -1055,13 +1080,17 @@
SDVersion version = VERSION_FLUX,
bool flash_attn = false,
bool use_mask = false)
- : GGMLRunner(backend), use_mask(use_mask) {
+ : GGMLRunner(backend), version(version), use_mask(use_mask) {
flux_params.flash_attn = flash_attn;
flux_params.guidance_embed = false;
flux_params.depth = 0;
flux_params.depth_single_blocks = 0;
if (version == VERSION_FLUX_FILL) {
flux_params.in_channels = 384;
+ } else if (version == VERSION_FLUX_CONTROLS) {
+ flux_params.in_channels = 128;
+ } else if (version == VERSION_FLEX_2) {
+ flux_params.in_channels = 196;
}
for (auto pair : tensor_types) {
std::string tensor_name = pair.first;
@@ -1171,7 +1200,8 @@ namespace Flux {
pe,
mod_index_arange,
ref_latents,
- skip_layers);
+ skip_layers,
+ version);

ggml_build_forward_expand(gf, out);

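For orientation, the in_channels constants above follow directly from how forward() concatenates streams along the feature axis: the Flux latent has 16 channels and is patchified 2x2, so each extra latent stream contributes 64 features per token, the full-resolution mask contributes 8x8 values per latent cell for Fill, and a single-channel mask contributes 4 for Flex.2. A compile-time check of that arithmetic (illustrative, not code from the PR; the 16-channel latent and the 8x VAE downsampling factor are properties of Flux assumed here):

// Worked check of the in_channels values used above (384 / 128 / 196).
constexpr int C          = 16;  // Flux latent channels
constexpr int patch_size = 2;   // 2x2 patchify before img_in

constexpr int img       = C * patch_size * patch_size;      // 64:  base latent stream
constexpr int masked    = C * patch_size * patch_size;      // 64:  masked latent (inpainting)
constexpr int fill_mask = 8 * 8 * patch_size * patch_size;  // 256: 8x8 pixel mask per latent cell
constexpr int flex_mask = 1 * patch_size * patch_size;      // 4:   single-channel mask
constexpr int control   = C * patch_size * patch_size;      // 64:  control latent

static_assert(img + masked + fill_mask           == 384, "VERSION_FLUX_FILL");
static_assert(img + control                      == 128, "VERSION_FLUX_CONTROLS");
static_assert(img + masked + flex_mask + control == 196, "VERSION_FLEX_2");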
14 changes: 10 additions & 4 deletions ggml_extend.hpp
@@ -380,18 +380,24 @@ __STATIC_INLINE__ void sd_mask_to_tensor(const uint8_t* image_data,

__STATIC_INLINE__ void sd_apply_mask(struct ggml_tensor* image_data,
struct ggml_tensor* mask,
- struct ggml_tensor* output) {
+ struct ggml_tensor* output,
+ float masked_value = 0.5f) {
int64_t width = output->ne[0];
int64_t height = output->ne[1];
int64_t channels = output->ne[2];
+ float rescale_mx = mask->ne[0]/output->ne[0];
+ float rescale_my = mask->ne[1]/output->ne[1];
GGML_ASSERT(output->type == GGML_TYPE_F32);
for (int ix = 0; ix < width; ix++) {
for (int iy = 0; iy < height; iy++) {
- float m = ggml_tensor_get_f32(mask, ix, iy);
+ int mx = (int)(ix * rescale_mx);
+ int my = (int)(iy * rescale_my);
+ float m = ggml_tensor_get_f32(mask, mx, my);
m = round(m); // inpaint models need binary masks
- ggml_tensor_set_f32(mask, m, ix, iy);
+ ggml_tensor_set_f32(mask, m, mx, my);
for (int k = 0; k < channels; k++) {
- float value = (1 - m) * (ggml_tensor_get_f32(image_data, ix, iy, k) - .5) + .5;
+ float value = ggml_tensor_get_f32(image_data, ix, iy, k);
+ value = (1 - m) * (value - masked_value) + masked_value;
ggml_tensor_set_f32(output, value, ix, iy, k);
}
}
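The sd_apply_mask() changes do two things: the mask may now be larger than the output (the scale factors are the integer ratios mask->ne[0]/output->ne[0] and mask->ne[1]/output->ne[1], so the mask is expected to be the same size as or an integer multiple of the output and is sampled nearest-neighbor), and the value masked pixels are pulled toward is a parameter instead of the hard-coded 0.5. A standalone sketch of the same rule on plain float buffers (interleaved-channel layout and names are my own, not the PR's):

// Minimal sketch, no ggml: nearest-neighbor mask lookup plus blend toward
// masked_value (0.5 = mid-gray for images normalized to [0, 1]).
#include <cmath>
#include <cstddef>
#include <vector>

static void apply_mask(const std::vector<float>& image,   // width*height*channels, [0,1]
                       const std::vector<float>& mask,    // mask_w*mask_h, [0,1]
                       std::vector<float>& output,        // width*height*channels
                       int width, int height, int channels,
                       int mask_w, int mask_h,
                       float masked_value = 0.5f) {
    // integer ratios, as in the ggml version above
    int rescale_x = mask_w / width;
    int rescale_y = mask_h / height;
    for (int iy = 0; iy < height; iy++) {
        for (int ix = 0; ix < width; ix++) {
            // binarize: inpaint models expect a hard mask
            float m = std::round(mask[(size_t)(iy * rescale_y) * mask_w + ix * rescale_x]);
            for (int k = 0; k < channels; k++) {
                size_t idx = ((size_t)iy * width + ix) * channels + k;
                // m == 1: replace with masked_value; m == 0: keep the original pixel
                output[idx] = (1 - m) * (image[idx] - masked_value) + masked_value;
            }
        }
    }
}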
9 changes: 7 additions & 2 deletions model.cpp
@@ -1685,10 +1685,15 @@ SDVersion ModelLoader::get_sd_version() {
}

if (is_flux) {
- is_inpaint = input_block_weight.ne[0] == 384;
- if (is_inpaint) {
+ if (input_block_weight.ne[0] == 384) {
return VERSION_FLUX_FILL;
}
+ if (input_block_weight.ne[0] == 128) {
+ return VERSION_FLUX_CONTROLS;
+ }
+ if(input_block_weight.ne[0] == 196){
+ return VERSION_FLEX_2;
+ }
return VERSION_FLUX;
}

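get_sd_version() now distinguishes the Flux variants purely by the width (ne[0]) of the input projection weight, i.e. the model's in_channels. As a summary (the mapping to released checkpoints is my reading of the PR, not something it states explicitly):

// Illustrative helper only; the real logic lives in ModelLoader::get_sd_version().
#include <cstdint>

enum FluxVariant { FLUX, FLUX_FILL, FLUX_CONTROLS, FLEX_2 };

static FluxVariant flux_variant_from_in_channels(int64_t in_channels) {
    switch (in_channels) {
        case 384: return FLUX_FILL;      // FLUX.1 Fill (inpainting)
        case 128: return FLUX_CONTROLS;  // FLUX.1 Canny / Depth ("Flux Controls")
        case 196: return FLEX_2;         // Flex.2 (inpaint + control channels)
        default:  return FLUX;           // base FLUX.1: 64 = 16 channels * 2 * 2 patch
    }
}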
12 changes: 9 additions & 3 deletions model.h
@@ -31,11 +31,13 @@ enum SDVersion {
VERSION_SD3,
VERSION_FLUX,
VERSION_FLUX_FILL,
+ VERSION_FLUX_CONTROLS,
+ VERSION_FLEX_2,
VERSION_COUNT,
};

static inline bool sd_version_is_flux(SDVersion version) {
- if (version == VERSION_FLUX || version == VERSION_FLUX_FILL) {
+ if (version == VERSION_FLUX || version == VERSION_FLUX_FILL || version == VERSION_FLUX_CONTROLS || version == VERSION_FLEX_2 ) {
return true;
}
return false;
@@ -70,7 +72,7 @@ static inline bool sd_version_is_sdxl(SDVersion version) {
}

static inline bool sd_version_is_inpaint(SDVersion version) {
- if (version == VERSION_SD1_INPAINT || version == VERSION_SD2_INPAINT || version == VERSION_SDXL_INPAINT || version == VERSION_FLUX_FILL) {
+ if (version == VERSION_SD1_INPAINT || version == VERSION_SD2_INPAINT || version == VERSION_SDXL_INPAINT || version == VERSION_FLUX_FILL || version == VERSION_FLEX_2) {
return true;
}
return false;
@@ -87,8 +89,12 @@ static inline bool sd_version_is_unet_edit(SDVersion version) {
return version == VERSION_SD1_PIX2PIX || version == VERSION_SDXL_PIX2PIX;
}

+ static inline bool sd_version_is_control(SDVersion version) {
+ return version == VERSION_FLUX_CONTROLS || version == VERSION_FLEX_2;
+ }

static bool sd_version_is_inpaint_or_unet_edit(SDVersion version) {
- return sd_version_is_unet_edit(version) || sd_version_is_inpaint(version);
+ return sd_version_is_unet_edit(version) || sd_version_is_inpaint(version)|| sd_version_is_control(version);
}

enum PMVersion {
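The new predicate composes with the existing ones as follows: Flex.2 counts as both an inpaint model and a control model, Flux Controls is control-only, and Flux Fill stays inpaint-only; sd_version_is_inpaint_or_unet_edit() now returns true for all three. A small illustrative check (assumes the project headers are on the include path; not part of the PR):

#include <cassert>
#include "model.h"

int main() {
    assert(sd_version_is_control(VERSION_FLUX_CONTROLS));
    assert(sd_version_is_control(VERSION_FLEX_2));
    assert(!sd_version_is_control(VERSION_FLUX_FILL));

    assert(sd_version_is_inpaint(VERSION_FLEX_2));        // Flex.2 carries mask channels
    assert(!sd_version_is_inpaint(VERSION_FLUX_CONTROLS));

    // all three now satisfy the combined helper
    assert(sd_version_is_inpaint_or_unet_edit(VERSION_FLUX_FILL));
    assert(sd_version_is_inpaint_or_unet_edit(VERSION_FLUX_CONTROLS));
    assert(sd_version_is_inpaint_or_unet_edit(VERSION_FLEX_2));
    return 0;
}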