From 50baaf4783682aa88252c7badec4490d246b5306 Mon Sep 17 00:00:00 2001 From: thxCode Date: Fri, 29 Nov 2024 13:24:44 +0800 Subject: [PATCH] refactor: sd estimate Signed-off-by: thxCode --- file_architecture.go | 44 ++++++--- file_estimate__stablediffusioncpp.go | 138 ++++++++++++++++++++------- 2 files changed, 137 insertions(+), 45 deletions(-) diff --git a/file_architecture.go b/file_architecture.go index fb6e341..cd43118 100644 --- a/file_architecture.go +++ b/file_architecture.go @@ -130,6 +130,9 @@ type ( // // Only used when Architecture is "diffusion". DiffusionArchitecture string `json:"diffusionArchitecture,omitempty"` + // DiffusionTransformer indicates whether the diffusion model is a diffusion transformer or not. + // + DiffusionTransformer bool `json:"diffusionTransformer,omitempty"` // DiffusionConditioners is the list of diffusion conditioners. // // Only used when Architecture is "diffusion". @@ -248,10 +251,12 @@ func (gf *GGUFFile) diffuserArchitecture() (ga GGUFArchitecture) { sd3_5MediumKey = "model.diffusion_model.joint_blocks.23.x_block.attn.ln_k.weight" // SD 3.5 Medium sd3_5LargeKey = "model.diffusion_model.joint_blocks.37.x_block.attn.ln_k.weight" // SD 3.5 Large - fluxKey = "model.diffusion_model.double_blocks.18.txt_attn.proj.weight" // FLUX.1-schnell - fluxKey2 = "double_blocks.18.txt_attn.proj.weight" - fluxDevKey = "model.diffusion_model.guidance_in.in_layer.weight" // FLUX.1-dev - fluxDevKey2 = "guidance_in.in_layer.weight" + fluxKey = "model.diffusion_model.double_blocks.7.txt_attn.proj.weight" + fluxKey2 = "double_blocks.7.txt_attn.proj.weight" + fluxDevAndLiteKey = "model.diffusion_model.guidance_in.in_layer.weight" // FLUX.1-dev / FLUX.1-lite + fluxDevAndLiteKey2 = "guidance_in.in_layer.weight" + fluxDevAndSchnellKey = "model.diffusion_model.double_blocks.18.txt_attn.proj.weight" // FLUX.1-dev / FLUX.1-schnell + fluxDevAndSchnellKey2 = "double_blocks.18.txt_attn.proj.weight" // Conditioner @@ -271,8 +276,10 @@ func (gf *GGUFFile) diffuserArchitecture() (ga GGUFArchitecture) { sd3_5LargeKey, fluxKey, fluxKey2, - fluxDevKey, - fluxDevKey2, + fluxDevAndLiteKey, + fluxDevAndLiteKey2, + fluxDevAndSchnellKey, + fluxDevAndSchnellKey2, openAiClipVitL14Key, openClipVitH14Key, @@ -302,17 +309,30 @@ func (gf *GGUFFile) diffuserArchitecture() (ga GGUFArchitecture) { ga.DiffusionArchitecture = "Stable Diffusion 3.5 Large" } } + ga.DiffusionTransformer = true } if _, ok := tis[fluxKey]; ok { - ga.DiffusionArchitecture = "FLUX.1-schnell" - if _, ok = tis[fluxDevKey]; ok { - ga.DiffusionArchitecture = "FLUX.1-dev" + if _, ok = tis[fluxDevAndLiteKey]; ok { + if _, ok = tis[fluxDevAndSchnellKey]; ok { + ga.DiffusionArchitecture = "FLUX.1-dev" + } else { + ga.DiffusionArchitecture = "FLUX.1-lite" + } + } else { + ga.DiffusionArchitecture = "FLUX.1-schnell" } + ga.DiffusionTransformer = true } else if _, ok := tis[fluxKey2]; ok { - ga.DiffusionArchitecture = "FLUX.1-schnell" - if _, ok = tis[fluxDevKey2]; ok { - ga.DiffusionArchitecture = "FLUX.1-dev" + if _, ok = tis[fluxDevAndLiteKey2]; ok { + if _, ok = tis[fluxDevAndSchnellKey2]; ok { + ga.DiffusionArchitecture = "FLUX.1-dev" + } else { + ga.DiffusionArchitecture = "FLUX.1-lite" + } + } else { + ga.DiffusionArchitecture = "FLUX.1-schnell" } + ga.DiffusionTransformer = true } if ti, ok := tis[openAiClipVitL14Key]; ok { diff --git a/file_estimate__stablediffusioncpp.go b/file_estimate__stablediffusioncpp.go index 770c3c8..962e215 100644 --- a/file_estimate__stablediffusioncpp.go +++ b/file_estimate__stablediffusioncpp.go @@ -5,6 +5,7 @@ import ( "github.com/gpustack/gguf-parser-go/util/ptr" "github.com/gpustack/gguf-parser-go/util/stringx" + "strings" ) // Types for StableDiffusionCpp estimation. @@ -102,7 +103,7 @@ func (gf *GGUFFile) EstimateStableDiffusionCppRun(opts ...GGUFRunEstimateOption) o.SDCOffloadAutoencoder = ptr.To(true) } if o.SDCAutoencoderTiling == nil { - o.SDCAutoencoderTiling = ptr.To(true) + o.SDCAutoencoderTiling = ptr.To(false) } // Devices. @@ -114,7 +115,11 @@ func (gf *GGUFFile) EstimateStableDiffusionCppRun(opts ...GGUFRunEstimateOption) e.Architecture = normalizeArchitecture(a.DiffusionArchitecture) // Flash attention. - e.FlashAttention = false // TODO: Implement this. + if o.FlashAttention && !strings.HasPrefix(a.DiffusionArchitecture, "Stable Diffusion 3") { + // NB(thxCode): Stable Diffusion 3 doesn't support flash attention yet, + // see https://github.com/leejet/stable-diffusion.cpp/pull/386. + e.FlashAttention = true + } // Distributable. e.Distributable = false // TODO: Implement this. @@ -160,14 +165,14 @@ func (gf *GGUFFile) EstimateStableDiffusionCppRun(opts ...GGUFRunEstimateOption) // Footprint { // Bootstrap. - e.Devices[0].Footprint = GGUFBytesScalar(5*1024*1024) /* model load */ + (gf.Size - gf.ModelSize) /* metadata */ + e.Devices[0].Footprint = GGUFBytesScalar(10*1024*1024) /* model load */ + (gf.Size - gf.ModelSize) /* metadata */ // Output buffer, // see // TODO: Implement this. } - var cdLs, aeLs, cpLs GGUFLayerTensorInfos + var cdLs, aeLs, mdLs GGUFLayerTensorInfos { var tis GGUFTensorInfos tis = gf.TensorInfos.Search(regexp.MustCompile(`^cond_stage_model\..*`)) @@ -176,13 +181,6 @@ func (gf *GGUFFile) EstimateStableDiffusionCppRun(opts ...GGUFRunEstimateOption) if len(cdLs) != len(e.Conditioners) { panic("conditioners' layers count mismatch") } - // NB(thxCode): resort the layers to match the order of the conditioners. - cdLsSorted := make([]IGGUFTensorInfos, len(cdLs)) - cdLsSorted[0] = cdLs[len(cdLs)-1] - for i := 1; i < len(cdLs); i++ { - cdLsSorted[i] = cdLs[i-1] - } - cdLs = cdLsSorted } tis = gf.TensorInfos.Search(regexp.MustCompile(`^first_stage_model\..*`)) if len(tis) != 0 { @@ -190,46 +188,120 @@ func (gf *GGUFFile) EstimateStableDiffusionCppRun(opts ...GGUFRunEstimateOption) } tis = gf.TensorInfos.Search(regexp.MustCompile(`^model\.diffusion_model\..*`)) if len(tis) != 0 { - cpLs = tis.Layers() + mdLs = tis.Layers() } else { - cpLs = gf.TensorInfos.Layers() + mdLs = gf.TensorInfos.Layers() + } + } + + var cdDevIdx, aeDevIdx, mdDevIdx int + { + if *o.SDCOffloadConditioner { + cdDevIdx = 1 + } + if *o.SDCOffloadAutoencoder { + aeDevIdx = 1 } + mdDevIdx = 1 } // Weight & Parameter. { // Conditioners. - if cdLs != nil { - d := 0 - if *o.SDCOffloadConditioner { - d = 1 - } - for i := range cdLs { - e.Conditioners[i].Devices[d].Weight = GGUFBytesScalar(cdLs[i].Bytes()) - e.Conditioners[i].Devices[d].Parameter = GGUFParametersScalar(cdLs[i].Elements()) - } + for i := range cdLs { + e.Conditioners[i].Devices[cdDevIdx].Weight = GGUFBytesScalar(cdLs[i].Bytes()) + e.Conditioners[i].Devices[cdDevIdx].Parameter = GGUFParametersScalar(cdLs[i].Elements()) } // Autoencoder. if aeLs != nil { - d := 0 - if *o.SDCOffloadAutoencoder { - d = 1 - } - e.Autoencoder.Devices[d].Weight = GGUFBytesScalar(aeLs.Bytes()) - e.Autoencoder.Devices[d].Parameter = GGUFParametersScalar(aeLs.Elements()) + e.Autoencoder.Devices[aeDevIdx].Weight = GGUFBytesScalar(aeLs.Bytes()) + e.Autoencoder.Devices[aeDevIdx].Parameter = GGUFParametersScalar(aeLs.Elements()) } - // Compute. - if cpLs != nil { - e.Devices[1].Weight = GGUFBytesScalar(cpLs.Bytes()) - e.Devices[1].Parameter = GGUFParametersScalar(cpLs.Elements()) + // Model. + if mdLs != nil { + e.Devices[mdDevIdx].Weight = GGUFBytesScalar(mdLs.Bytes()) + e.Devices[mdDevIdx].Parameter = GGUFParametersScalar(mdLs.Elements()) } } // Computation. { - // TODO: Implement this. + // Bootstrap, compute metadata, + // see https://github.com/ggerganov/llama.cpp/blob/d6ef0e77dd25f54fb5856af47e3926cf6f36c281/llama.cpp#L16135-L16136. + cm := GGMLTensorOverhead()*GGMLComputationGraphNodesMaximum + + GGMLComputationGraphOverhead(GGMLComputationGraphNodesMaximum, false) + e.Devices[0].Computation = GGUFBytesScalar(cm) + + // Work context, + // see https://github.com/leejet/stable-diffusion.cpp/blob/4570715727f35e5a07a76796d823824c8f42206c/stable-diffusion.cpp#L1467-L1481, + // https://github.com/leejet/stable-diffusion.cpp/blob/4570715727f35e5a07a76796d823824c8f42206c/stable-diffusion.cpp#L1572-L1586, + // https://github.com/leejet/stable-diffusion.cpp/blob/4570715727f35e5a07a76796d823824c8f42206c/stable-diffusion.cpp#L1675-L1679, + // https://github.com/thxCode/stable-diffusion.cpp/blob/78629d6340f763a8fe14372e0ba3ace73526a265/stable-diffusion.cpp#L2185-L2189, + // https://github.com/thxCode/stable-diffusion.cpp/blob/78629d6340f763a8fe14372e0ba3ace73526a265/stable-diffusion.cpp#L2270-L2274. + // + { + wcSize := GGUFBytesScalar(50 * 1024 * 1024) + wcSize += GGUFBytesScalar(*o.SDCWidth * *o.SDCHeight * 3 * 4 /* sizeof(float) */ * 2) // RGB + e.Devices[0].Computation += wcSize + } + + // Conditioner learned conditions, + // see https://github.com/leejet/stable-diffusion.cpp/blob/4570715727f35e5a07a76796d823824c8f42206c/conditioner.hpp#L388-L391, + // https://github.com/leejet/stable-diffusion.cpp/blob/4570715727f35e5a07a76796d823824c8f42206c/conditioner.hpp#L758-L766, + // https://github.com/leejet/stable-diffusion.cpp/blob/4570715727f35e5a07a76796d823824c8f42206c/conditioner.hpp#L1083-L1085. + switch { + case strings.HasPrefix(a.DiffusionArchitecture, "FLUX"): + for i := range cdLs { + ds := []uint64{1} + switch i { + case 0: + ds = []uint64{768, 77} + case 1: + ds = []uint64{4096, 256} + } + cds := GGUFBytesScalar(GGMLTypeF32.RowSizeOf(ds)) * 2 // include unconditioner + e.Conditioners[i].Devices[cdDevIdx].Computation += cds + } + case strings.HasPrefix(a.DiffusionArchitecture, "Stable Diffusion 3"): + for i := range cdLs { + ds := []uint64{1} + switch i { + case 0: + ds = []uint64{768, 77} + case 1: + ds = []uint64{1280, 77} + case 2: + ds = []uint64{4096, 77} + } + cds := GGUFBytesScalar(GGMLTypeF32.RowSizeOf(ds)) * 2 // include unconditioner + e.Conditioners[i].Devices[cdDevIdx].Computation += cds + } + default: + for i := range cdLs { + ds := []uint64{1} + switch i { + case 0: + ds = []uint64{768, 77} + if strings.HasSuffix(a.DiffusionArchitecture, "Refiner") { + ds = []uint64{1280, 77} + } + case 1: + ds = []uint64{1280, 77} + } + cds := GGUFBytesScalar(GGMLTypeF32.RowSizeOf(ds)) * 2 // include unconditioner + e.Conditioners[i].Devices[cdDevIdx].Computation += cds + } + } + + // Diffusion nosier, + // see https://github.com/leejet/stable-diffusion.cpp/blob/4570715727f35e5a07a76796d823824c8f42206c/stable-diffusion.cpp#L1361. + { + mds := GGUFBytesScalar(GGMLTypeF32.RowSizeOf([]uint64{uint64(*o.SDCWidth / 8), uint64(*o.SDCHeight / 8), 16, 1})) + e.Devices[mdDevIdx].Computation += mds + } + } return e