Skip to content

Commit

Permalink
refactor: sd estimate
Browse files Browse the repository at this point in the history
Signed-off-by: thxCode <[email protected]>
  • Loading branch information
thxCode committed Nov 29, 2024
1 parent ca00e4f commit 50baaf4
Show file tree
Hide file tree
Showing 2 changed files with 137 additions and 45 deletions.
44 changes: 32 additions & 12 deletions file_architecture.go
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,9 @@ type (
//
// Only used when Architecture is "diffusion".
DiffusionArchitecture string `json:"diffusionArchitecture,omitempty"`
// DiffusionTransformer indicates whether the diffusion model is a diffusion transformer or not.
//
// Only used when Architecture is "diffusion".
DiffusionTransformer bool `json:"diffusionTransformer,omitempty"`
// DiffusionConditioners is the list of diffusion conditioners.
//
// Only used when Architecture is "diffusion".
Expand Down Expand Up @@ -248,10 +251,12 @@ func (gf *GGUFFile) diffuserArchitecture() (ga GGUFArchitecture) {
sd3_5MediumKey = "model.diffusion_model.joint_blocks.23.x_block.attn.ln_k.weight" // SD 3.5 Medium
sd3_5LargeKey = "model.diffusion_model.joint_blocks.37.x_block.attn.ln_k.weight" // SD 3.5 Large

fluxKey = "model.diffusion_model.double_blocks.18.txt_attn.proj.weight" // FLUX.1-schnell
fluxKey2 = "double_blocks.18.txt_attn.proj.weight"
fluxDevKey = "model.diffusion_model.guidance_in.in_layer.weight" // FLUX.1-dev
fluxDevKey2 = "guidance_in.in_layer.weight"
fluxKey = "model.diffusion_model.double_blocks.7.txt_attn.proj.weight"
fluxKey2 = "double_blocks.7.txt_attn.proj.weight"
fluxDevAndLiteKey = "model.diffusion_model.guidance_in.in_layer.weight" // FLUX.1-dev / FLUX.1-lite
fluxDevAndLiteKey2 = "guidance_in.in_layer.weight"
fluxDevAndSchnellKey = "model.diffusion_model.double_blocks.18.txt_attn.proj.weight" // FLUX.1-dev / FLUX.1-schnell
fluxDevAndSchnellKey2 = "double_blocks.18.txt_attn.proj.weight"

// Conditioner

Expand All @@ -271,8 +276,10 @@ func (gf *GGUFFile) diffuserArchitecture() (ga GGUFArchitecture) {
sd3_5LargeKey,
fluxKey,
fluxKey2,
fluxDevKey,
fluxDevKey2,
fluxDevAndLiteKey,
fluxDevAndLiteKey2,
fluxDevAndSchnellKey,
fluxDevAndSchnellKey2,

openAiClipVitL14Key,
openClipVitH14Key,
Expand Down Expand Up @@ -302,17 +309,30 @@ func (gf *GGUFFile) diffuserArchitecture() (ga GGUFArchitecture) {
ga.DiffusionArchitecture = "Stable Diffusion 3.5 Large"
}
}
ga.DiffusionTransformer = true
}
if _, ok := tis[fluxKey]; ok {
ga.DiffusionArchitecture = "FLUX.1-schnell"
if _, ok = tis[fluxDevKey]; ok {
ga.DiffusionArchitecture = "FLUX.1-dev"
if _, ok = tis[fluxDevAndLiteKey]; ok {
if _, ok = tis[fluxDevAndSchnellKey]; ok {
ga.DiffusionArchitecture = "FLUX.1-dev"
} else {
ga.DiffusionArchitecture = "FLUX.1-lite"
}
} else {
ga.DiffusionArchitecture = "FLUX.1-schnell"
}
ga.DiffusionTransformer = true
} else if _, ok := tis[fluxKey2]; ok {
ga.DiffusionArchitecture = "FLUX.1-schnell"
if _, ok = tis[fluxDevKey2]; ok {
ga.DiffusionArchitecture = "FLUX.1-dev"
if _, ok = tis[fluxDevAndLiteKey2]; ok {
if _, ok = tis[fluxDevAndSchnellKey2]; ok {
ga.DiffusionArchitecture = "FLUX.1-dev"
} else {
ga.DiffusionArchitecture = "FLUX.1-lite"
}
} else {
ga.DiffusionArchitecture = "FLUX.1-schnell"
}
ga.DiffusionTransformer = true
}

if ti, ok := tis[openAiClipVitL14Key]; ok {
Expand Down
138 changes: 105 additions & 33 deletions file_estimate__stablediffusioncpp.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (

"github.com/gpustack/gguf-parser-go/util/ptr"
"github.com/gpustack/gguf-parser-go/util/stringx"
"strings"
)

// Types for StableDiffusionCpp estimation.
Expand Down Expand Up @@ -102,7 +103,7 @@ func (gf *GGUFFile) EstimateStableDiffusionCppRun(opts ...GGUFRunEstimateOption)
o.SDCOffloadAutoencoder = ptr.To(true)
}
if o.SDCAutoencoderTiling == nil {
o.SDCAutoencoderTiling = ptr.To(true)
o.SDCAutoencoderTiling = ptr.To(false)
}

// Devices.
Expand All @@ -114,7 +115,11 @@ func (gf *GGUFFile) EstimateStableDiffusionCppRun(opts ...GGUFRunEstimateOption)
e.Architecture = normalizeArchitecture(a.DiffusionArchitecture)

// Flash attention.
e.FlashAttention = false // TODO: Implement this.
if o.FlashAttention && !strings.HasPrefix(a.DiffusionArchitecture, "Stable Diffusion 3") {
// NB(thxCode): Stable Diffusion 3 doesn't support flash attention yet,
// see https://github.com/leejet/stable-diffusion.cpp/pull/386.
e.FlashAttention = true
}

// Distributable.
e.Distributable = false // TODO: Implement this.
Expand Down Expand Up @@ -160,14 +165,14 @@ func (gf *GGUFFile) EstimateStableDiffusionCppRun(opts ...GGUFRunEstimateOption)
// Footprint
{
// Bootstrap.
e.Devices[0].Footprint = GGUFBytesScalar(5*1024*1024) /* model load */ + (gf.Size - gf.ModelSize) /* metadata */
e.Devices[0].Footprint = GGUFBytesScalar(10*1024*1024) /* model load */ + (gf.Size - gf.ModelSize) /* metadata */

// Output buffer.
// TODO: Implement this.
}

var cdLs, aeLs, cpLs GGUFLayerTensorInfos
var cdLs, aeLs, mdLs GGUFLayerTensorInfos
{
var tis GGUFTensorInfos
tis = gf.TensorInfos.Search(regexp.MustCompile(`^cond_stage_model\..*`))
Expand All @@ -176,60 +181,127 @@ func (gf *GGUFFile) EstimateStableDiffusionCppRun(opts ...GGUFRunEstimateOption)
if len(cdLs) != len(e.Conditioners) {
panic("conditioners' layers count mismatch")
}
// NB(thxCode): re-sort the layers to match the order of the conditioners:
// the last layer group moves to the front and the others shift back by one.
cdLsSorted := make([]IGGUFTensorInfos, len(cdLs))
cdLsSorted[0] = cdLs[len(cdLs)-1]
for i := 1; i < len(cdLs); i++ {
cdLsSorted[i] = cdLs[i-1]
}
cdLs = cdLsSorted
}
tis = gf.TensorInfos.Search(regexp.MustCompile(`^first_stage_model\..*`))
if len(tis) != 0 {
aeLs = tis.Layers()
}
tis = gf.TensorInfos.Search(regexp.MustCompile(`^model\.diffusion_model\..*`))
if len(tis) != 0 {
cpLs = tis.Layers()
mdLs = tis.Layers()
} else {
cpLs = gf.TensorInfos.Layers()
mdLs = gf.TensorInfos.Layers()
}
}

var cdDevIdx, aeDevIdx, mdDevIdx int
{
if *o.SDCOffloadConditioner {
cdDevIdx = 1
}
if *o.SDCOffloadAutoencoder {
aeDevIdx = 1
}
mdDevIdx = 1
}

// Weight & Parameter.
{
// Conditioners.
if cdLs != nil {
d := 0
if *o.SDCOffloadConditioner {
d = 1
}
for i := range cdLs {
e.Conditioners[i].Devices[d].Weight = GGUFBytesScalar(cdLs[i].Bytes())
e.Conditioners[i].Devices[d].Parameter = GGUFParametersScalar(cdLs[i].Elements())
}
for i := range cdLs {
e.Conditioners[i].Devices[cdDevIdx].Weight = GGUFBytesScalar(cdLs[i].Bytes())
e.Conditioners[i].Devices[cdDevIdx].Parameter = GGUFParametersScalar(cdLs[i].Elements())
}

// Autoencoder.
if aeLs != nil {
d := 0
if *o.SDCOffloadAutoencoder {
d = 1
}
e.Autoencoder.Devices[d].Weight = GGUFBytesScalar(aeLs.Bytes())
e.Autoencoder.Devices[d].Parameter = GGUFParametersScalar(aeLs.Elements())
e.Autoencoder.Devices[aeDevIdx].Weight = GGUFBytesScalar(aeLs.Bytes())
e.Autoencoder.Devices[aeDevIdx].Parameter = GGUFParametersScalar(aeLs.Elements())
}

// Compute.
if cpLs != nil {
e.Devices[1].Weight = GGUFBytesScalar(cpLs.Bytes())
e.Devices[1].Parameter = GGUFParametersScalar(cpLs.Elements())
// Model.
if mdLs != nil {
e.Devices[mdDevIdx].Weight = GGUFBytesScalar(mdLs.Bytes())
e.Devices[mdDevIdx].Parameter = GGUFParametersScalar(mdLs.Elements())
}
}

// Computation.
{
// TODO: Implement this.
// Bootstrap, compute metadata,
// see https://github.com/ggerganov/llama.cpp/blob/d6ef0e77dd25f54fb5856af47e3926cf6f36c281/llama.cpp#L16135-L16136.
cm := GGMLTensorOverhead()*GGMLComputationGraphNodesMaximum +
GGMLComputationGraphOverhead(GGMLComputationGraphNodesMaximum, false)
e.Devices[0].Computation = GGUFBytesScalar(cm)

// Work context,
// see https://github.com/leejet/stable-diffusion.cpp/blob/4570715727f35e5a07a76796d823824c8f42206c/stable-diffusion.cpp#L1467-L1481,
// https://github.com/leejet/stable-diffusion.cpp/blob/4570715727f35e5a07a76796d823824c8f42206c/stable-diffusion.cpp#L1572-L1586,
// https://github.com/leejet/stable-diffusion.cpp/blob/4570715727f35e5a07a76796d823824c8f42206c/stable-diffusion.cpp#L1675-L1679,
// https://github.com/thxCode/stable-diffusion.cpp/blob/78629d6340f763a8fe14372e0ba3ace73526a265/stable-diffusion.cpp#L2185-L2189,
// https://github.com/thxCode/stable-diffusion.cpp/blob/78629d6340f763a8fe14372e0ba3ace73526a265/stable-diffusion.cpp#L2270-L2274.
//
{
wcSize := GGUFBytesScalar(50 * 1024 * 1024)
wcSize += GGUFBytesScalar(*o.SDCWidth * *o.SDCHeight * 3 * 4 /* sizeof(float) */ * 2) // RGB
e.Devices[0].Computation += wcSize
}

// Conditioner learned conditions,
// see https://github.com/leejet/stable-diffusion.cpp/blob/4570715727f35e5a07a76796d823824c8f42206c/conditioner.hpp#L388-L391,
// https://github.com/leejet/stable-diffusion.cpp/blob/4570715727f35e5a07a76796d823824c8f42206c/conditioner.hpp#L758-L766,
// https://github.com/leejet/stable-diffusion.cpp/blob/4570715727f35e5a07a76796d823824c8f42206c/conditioner.hpp#L1083-L1085.
switch {
case strings.HasPrefix(a.DiffusionArchitecture, "FLUX"):
for i := range cdLs {
ds := []uint64{1}
switch i {
case 0:
ds = []uint64{768, 77}
case 1:
ds = []uint64{4096, 256}
}
cds := GGUFBytesScalar(GGMLTypeF32.RowSizeOf(ds)) * 2 // include unconditioner
e.Conditioners[i].Devices[cdDevIdx].Computation += cds
}
case strings.HasPrefix(a.DiffusionArchitecture, "Stable Diffusion 3"):
for i := range cdLs {
ds := []uint64{1}
switch i {
case 0:
ds = []uint64{768, 77}
case 1:
ds = []uint64{1280, 77}
case 2:
ds = []uint64{4096, 77}
}
cds := GGUFBytesScalar(GGMLTypeF32.RowSizeOf(ds)) * 2 // include unconditioner
e.Conditioners[i].Devices[cdDevIdx].Computation += cds
}
default:
for i := range cdLs {
ds := []uint64{1}
switch i {
case 0:
ds = []uint64{768, 77}
if strings.HasSuffix(a.DiffusionArchitecture, "Refiner") {
ds = []uint64{1280, 77}
}
case 1:
ds = []uint64{1280, 77}
}
cds := GGUFBytesScalar(GGMLTypeF32.RowSizeOf(ds)) * 2 // include unconditioner
e.Conditioners[i].Devices[cdDevIdx].Computation += cds
}
}

// Diffusion noise,
// see https://github.com/leejet/stable-diffusion.cpp/blob/4570715727f35e5a07a76796d823824c8f42206c/stable-diffusion.cpp#L1361.
{
mds := GGUFBytesScalar(GGMLTypeF32.RowSizeOf([]uint64{uint64(*o.SDCWidth / 8), uint64(*o.SDCHeight / 8), 16, 1}))
e.Devices[mdDevIdx].Computation += mds
}

}

return e
Expand Down

0 comments on commit 50baaf4

Please sign in to comment.