-
Notifications
You must be signed in to change notification settings - Fork 6
/
binding.go
288 lines (240 loc) · 8.33 KB
/
binding.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
// Copyright (c) seasonjs. All rights reserved.
// Licensed under the MIT License. See License.txt in the project root for license information.
package sd
import (
"github.com/ebitengine/purego"
"runtime"
"unsafe"
)
// LogLevel is the severity passed to a CLogCallback; the native library
// reports it as an int which SetLogCallBack converts to this type.
type LogLevel int

// RNGType selects the random number generator; it is passed to new_sd_ctx as
// an int.
type RNGType int

// SampleMethod selects the sampler; it is passed to txt2img/img2img as an int.
type SampleMethod int

// Schedule selects the schedule; it is passed to new_sd_ctx as an int.
type Schedule int

// WType selects the weight type used by the native library; it is passed to
// new_sd_ctx and new_upscaler_ctx as an int.
type WType int

// Log levels, in increasing order of severity.
const (
	DEBUG LogLevel = iota
	INFO
	WARN
	ERROR
)

// Random number generator choices.
const (
	STD_DEFAULT_RNG RNGType = iota
	CUDA_RNG
)

// Sample methods. The numeric values are what the native library receives.
const (
	EULER_A SampleMethod = iota
	EULER
	HEUN
	DPM2
	DPMPP2S_A
	DPMPP2M
	DPMPP2Mv2
	LCM
	N_SAMPLE_METHODS // appears to be a count sentinel, not a sampler
)

// Schedules. The numeric values are what the native library receives.
const (
	DEFAULT Schedule = iota
	DISCRETE
	KARRAS
	N_SCHEDULES // appears to be a count sentinel, not a schedule
)

// Weight types. Every constant is explicitly typed as WType: in a const group
// an explicit "= value" without a type yields an untyped int constant, so
// without the type annotation `x := F16` would be an int rather than a WType.
//
// The numbering has gaps (4 and 5 are unused); the values presumably mirror
// the native library's type enum — keep in sync with it (TODO confirm).
const (
	F32   WType = 0
	F16   WType = 1
	Q4_0  WType = 2
	Q4_1  WType = 3
	Q5_0  WType = 6
	Q5_1  WType = 7
	Q8_0  WType = 8
	Q8_1  WType = 9
	Q2_K  WType = 10
	Q3_K  WType = 11
	Q4_K  WType = 12
	Q5_K  WType = 13
	Q6_K  WType = 14
	Q8_K  WType = 15
	I8    WType = 16
	I16   WType = 17
	I32   WType = 18
	COUNT WType = 19 // don't use this when specifying a type
)
// Opaque wrappers around native context handles. The ctx field stores the
// raw address returned by the native constructor; the Free* methods skip the
// native free call when it is zero.
type (
	// CStableDiffusionCtx wraps the handle returned by new_sd_ctx.
	CStableDiffusionCtx struct {
		ctx uintptr
	}

	// CUpScalerCtx wraps the handle returned by new_upscaler_ctx.
	CUpScalerCtx struct {
		ctx uintptr
	}
)
type (
	// CLogCallback receives log messages forwarded from the native library.
	CLogCallback func(level LogLevel, text string)

	// CStableDiffusion is the Go-facing API over the native stable-diffusion
	// library: context management, image generation, upscaling and logging.
	CStableDiffusion interface {
		// Diagnostics and logging.
		GetSystemInfo() string
		SetLogCallBack(cb CLogCallback)

		// Diffusion contexts and generation.
		NewCtx(modelPath, vaePath, taesdPath, loraModelDir string, vaeDecodeOnly, vaeTiling, freeParamsImmediately bool, nThreads int, wType WType, rngType RNGType, schedule Schedule) *CStableDiffusionCtx
		FreeCtx(ctx *CStableDiffusionCtx)
		PredictImage(ctx *CStableDiffusionCtx, prompt, negativePrompt string, clipSkip int, cfgScale float32, width, height int, sampleMethod SampleMethod, sampleSteps int, seed int64, batchCount int) []Image
		ImagePredictImage(ctx *CStableDiffusionCtx, img Image, prompt, negativePrompt string, clipSkip int, cfgScale float32, width, height int, sampleMethod SampleMethod, sampleSteps int, strength float32, seed int64, batchCount int) []Image

		// Upscaler contexts and upscaling.
		NewUpscalerCtx(esrganPath string, nThreads int, wType WType) *CUpScalerCtx
		FreeUpscalerCtx(ctx *CUpScalerCtx)
		UpscaleImage(ctx *CUpScalerCtx, img Image, upscaleFactor uint32) Image

		// Close releases the underlying native library.
		Close() error
	}
)
// Image layouts. The field order of the c* structs must not change: their
// addresses are handed to native code, which reads them by offset.
type (
	// cImage is the native image record: dimensions, channel count and the
	// address of channel*width*height bytes of pixel data.
	cImage struct {
		width   uint32
		height  uint32
		channel uint32
		data    uintptr
	}

	// cDarwinImage has the same layout as cImage but keeps data as a typed
	// pointer. NOTE(review): presumably used on darwin, where purego can pass
	// structs by value — confirm against the platform-specific files.
	cDarwinImage struct {
		width   uint32
		height  uint32
		channel uint32
		data    *byte
	}

	// Image is the Go-side image; Data holds Channel*Width*Height bytes.
	Image struct {
		Width   uint32
		Height  uint32
		Channel uint32
		Data    []byte
	}
)
// CStableDiffusionImpl implements CStableDiffusion on top of a dynamically
// loaded native library. Each func field is bound to the native symbol of the
// matching name by NewCStableDiffusion.
type CStableDiffusionImpl struct {
	// libSd is the library handle released by Close.
	libSd uintptr

	// Diagnostics (sd_get_system_info).
	sdGetSystemInfo func() string

	// Context construction (new_sd_ctx).
	newSdCtx func(modelPath, vaePath, taesdPath, loraModelDir string, vaeDecodeOnly, vaeTiling, freeParamsImmediately bool, nThreads, wtype, rngType, schedule int) uintptr

	// Logging (sd_set_log_callback).
	sdSetLogCallback func(callback func(level int, text uintptr, data uintptr) uintptr, data uintptr)

	// Generation (txt2img / img2img); both return the address of the result images.
	txt2img func(ctx uintptr, prompt, negativePrompt string, clipSkip int, cfgScale float32, width, height, sampleMethod, sampleSteps int, seed int64, batchCount int) uintptr
	img2img func(ctx, img uintptr, prompt, negativePrompt string, clipSkip int, cfgScale float32, width, height, sampleMethod, sampleSteps int, strength float32, seed int64, batchCount int) uintptr

	// Context destruction (free_sd_ctx).
	freeSdCtx func(ctx uintptr)

	// Upscaler lifecycle and operation (new_upscaler_ctx / free_upscaler_ctx / upscale).
	newUpscalerCtx  func(esrganPath string, nThreads, wtype int) uintptr
	freeUpscalerCtx func(ctx uintptr)
	upscale         func(ctx, img uintptr, upscaleFactor uint32) uintptr
}
// NewCStableDiffusion loads the native library at libraryPath and binds every
// exported function used by CStableDiffusionImpl.
//
// The returned value owns the library handle; call Close to release it.
// An error is returned only when the library itself cannot be opened.
func NewCStableDiffusion(libraryPath string) (*CStableDiffusionImpl, error) {
	libSd, err := openLibrary(libraryPath)
	if err != nil {
		return nil, err
	}

	// Record the handle so Close can actually unload the library.
	impl := &CStableDiffusionImpl{libSd: libSd}

	// Each symbol is bound to the field with the matching signature;
	// sd_get_system_info belongs in sdGetSystemInfo (GetSystemInfo calls it).
	purego.RegisterLibFunc(&impl.sdGetSystemInfo, libSd, "sd_get_system_info")
	purego.RegisterLibFunc(&impl.newSdCtx, libSd, "new_sd_ctx")
	purego.RegisterLibFunc(&impl.sdSetLogCallback, libSd, "sd_set_log_callback")
	purego.RegisterLibFunc(&impl.txt2img, libSd, "txt2img")
	purego.RegisterLibFunc(&impl.img2img, libSd, "img2img")
	purego.RegisterLibFunc(&impl.freeSdCtx, libSd, "free_sd_ctx")
	purego.RegisterLibFunc(&impl.newUpscalerCtx, libSd, "new_upscaler_ctx")
	purego.RegisterLibFunc(&impl.freeUpscalerCtx, libSd, "free_upscaler_ctx")
	purego.RegisterLibFunc(&impl.upscale, libSd, "upscale")
	return impl, nil
}
// NewCtx creates a native stable-diffusion context via new_sd_ctx, converting
// the typed enum arguments to plain ints for the native call.
//
// The returned handle is wrapped as-is: a zero (failed) handle is not reported
// as an error and still yields a non-nil *CStableDiffusionCtx.
func (c *CStableDiffusionImpl) NewCtx(modelPath string, vaePath string, taesdPath string, loraModelDir string, vaeDecodeOnly bool, vaeTiling bool, freeParamsImmediately bool, nThreads int, wType WType, rngType RNGType, schedule Schedule) *CStableDiffusionCtx {
	handle := c.newSdCtx(
		modelPath, vaePath, taesdPath, loraModelDir,
		vaeDecodeOnly, vaeTiling, freeParamsImmediately,
		nThreads,
		int(wType), int(rngType), int(schedule),
	)
	wrapped := new(CStableDiffusionCtx)
	wrapped.ctx = handle
	return wrapped
}
// PredictImage runs native text-to-image generation (txt2img) and converts the
// batchCount result images to Go Images.
//
// The returned Image.Data slices point at memory owned by the native library;
// nothing here frees it.
func (c *CStableDiffusionImpl) PredictImage(ctx *CStableDiffusionCtx, prompt string, negativePrompt string, clipSkip int, cfgScale float32, width int, height int, sampleMethod SampleMethod, sampleSteps int, seed int64, batchCount int) []Image {
	method := int(sampleMethod)
	return goImageSlice(
		c.txt2img(ctx.ctx, prompt, negativePrompt, clipSkip, cfgScale, width, height, method, sampleSteps, seed, batchCount),
		batchCount,
	)
}
// ImagePredictImage runs native image-to-image generation (img2img) seeded by
// img and converts the batchCount result images to Go Images.
//
// It returns nil without calling the native library when img.Data is empty or
// shorter than Width*Height*Channel bytes. Native code reads that many bytes
// from the buffer, so a short buffer would be read out of bounds.
func (c *CStableDiffusionImpl) ImagePredictImage(ctx *CStableDiffusionCtx, img Image, prompt string, negativePrompt string, clipSkip int, cfgScale float32, width int, height int, sampleMethod SampleMethod, sampleSteps int, strength float32, seed int64, batchCount int) []Image {
	// Compute in uint64 so large dimensions cannot wrap around.
	need := uint64(img.Width) * uint64(img.Height) * uint64(img.Channel)
	if len(img.Data) == 0 || uint64(len(img.Data)) < need {
		return nil
	}
	ci := cImage{
		width:   img.Width,
		height:  img.Height,
		channel: img.Channel,
		data:    uintptr(unsafe.Pointer(&img.Data[0])),
	}
	// NOTE(review): &ci is passed as a uintptr, which is not updated if the
	// goroutine stack moves before the native call. Consider declaring the
	// img2img field with a *cImage parameter so purego keeps it alive/pinned.
	images := c.img2img(ctx.ctx, uintptr(unsafe.Pointer(&ci)), prompt, negativePrompt, clipSkip, cfgScale, width, height, int(sampleMethod), sampleSteps, strength, seed, batchCount)
	// The pixel buffer was only referenced through a uintptr above, which the
	// GC does not track; keep it alive until the native call has returned.
	runtime.KeepAlive(&img.Data[0])
	return goImageSlice(images, batchCount)
}
// SetLogCallBack installs cb as the receiver of native log messages. A bridge
// function converts the native (level, text) pair into (LogLevel, string)
// before invoking cb. The user-data argument is always 0.
//
// NOTE(review): each call registers a fresh native callback; purego limits
// how many callbacks can exist, so avoid calling this repeatedly.
func (c *CStableDiffusionImpl) SetLogCallBack(cb CLogCallback) {
	bridge := func(level int, text uintptr, _ uintptr) uintptr {
		msg := goString(text)
		cb(LogLevel(level), msg)
		return 0
	}
	c.sdSetLogCallback(bridge, 0)
}
// GetSystemInfo returns the string reported by the native
// sd_get_system_info function bound to sdGetSystemInfo.
func (c *CStableDiffusionImpl) GetSystemInfo() string {
	info := c.sdGetSystemInfo()
	return info
}
// FreeCtx releases the native stable-diffusion context held by ctx.
//
// A nil ctx or a zero handle is a no-op. After freeing, the handle is cleared,
// so calling FreeCtx again on the same ctx does not free the native context twice.
func (c *CStableDiffusionImpl) FreeCtx(ctx *CStableDiffusionCtx) {
	if ctx == nil || ctx.ctx == 0 {
		return
	}
	c.freeSdCtx(ctx.ctx)
	// Clear the handle to prevent a double free on repeated calls.
	ctx.ctx = 0
	// NOTE(review): a forced full GC is expensive and frees no native memory;
	// kept to preserve the existing behavior.
	runtime.GC()
}
// NewUpscalerCtx creates a native upscaler context from the model at
// esrganPath via new_upscaler_ctx.
//
// The handle is wrapped as-is: a zero (failed) handle is not reported as an
// error and still yields a non-nil *CUpScalerCtx.
func (c *CStableDiffusionImpl) NewUpscalerCtx(esrganPath string, nThreads int, wType WType) *CUpScalerCtx {
	upCtx := new(CUpScalerCtx)
	upCtx.ctx = c.newUpscalerCtx(esrganPath, nThreads, int(wType))
	return upCtx
}
// FreeUpscalerCtx releases the native upscaler context held by ctx.
//
// A nil ctx or a zero handle is a no-op. After freeing, the handle is cleared,
// so calling FreeUpscalerCtx again on the same ctx does not free the native
// context twice.
func (c *CStableDiffusionImpl) FreeUpscalerCtx(ctx *CUpScalerCtx) {
	if ctx == nil || ctx.ctx == 0 {
		return
	}
	c.freeUpscalerCtx(ctx.ctx)
	// Clear the handle to prevent a double free on repeated calls.
	ctx.ctx = 0
	// NOTE(review): a forced full GC is expensive and frees no native memory;
	// kept to preserve the existing behavior.
	runtime.GC()
}
// Close unloads the native library if a handle is held.
//
// After a successful close the handle is cleared, so a second Close is a
// no-op instead of closing the same handle again. If closing fails, the
// handle is kept and the error is returned.
func (c *CStableDiffusionImpl) Close() error {
	if c.libSd == 0 {
		return nil
	}
	if err := closeLibrary(c.libSd); err != nil {
		return err
	}
	c.libSd = 0
	return nil
}
// UpscaleImage upscales img by upscaleFactor using the native upscaler
// context and returns the result.
//
// It returns a zero Image in two cases:
//   - img.Data is empty or shorter than Width*Height*Channel bytes, because
//     native code would read past the buffer.
//   - The native call returns a null result.
//
// If the result carries a null data pointer, only the dimensions are filled
// in. Otherwise Data points at memory owned by the native library; nothing
// here frees it.
func (c *CStableDiffusionImpl) UpscaleImage(ctx *CUpScalerCtx, img Image, upscaleFactor uint32) Image {
	// Compute in uint64 so large dimensions cannot wrap around.
	need := uint64(img.Width) * uint64(img.Height) * uint64(img.Channel)
	if len(img.Data) == 0 || uint64(len(img.Data)) < need {
		return Image{}
	}
	ci := cImage{
		width:   img.Width,
		height:  img.Height,
		channel: img.Channel,
		data:    uintptr(unsafe.Pointer(&img.Data[0])),
	}
	// NOTE(review): &ci is passed as a uintptr, which is not updated if the
	// goroutine stack moves before the native call. Consider declaring the
	// upscale field with a *cImage parameter so purego keeps it alive/pinned.
	uptr := c.upscale(ctx.ctx, uintptr(unsafe.Pointer(&ci)), upscaleFactor)
	// The pixel buffer was only referenced through a uintptr; keep it alive
	// until the native call has returned.
	runtime.KeepAlive(&img.Data[0])

	// Take the address and dereference it to avoid go vet's
	// "possible misuse of unsafe.Pointer" warning.
	ptr := *(*unsafe.Pointer)(unsafe.Pointer(&uptr))
	if ptr == nil {
		return Image{}
	}
	out := (*cImage)(ptr)
	res := Image{
		Width:   out.width,
		Height:  out.height,
		Channel: out.channel,
	}
	dataPtr := *(*unsafe.Pointer)(unsafe.Pointer(&out.data))
	// unsafe.Slice panics on a nil pointer with non-zero length.
	if dataPtr != nil {
		n := int(out.channel) * int(out.width) * int(out.height)
		res.Data = unsafe.Slice((*byte)(dataPtr), n)
	}
	return res
}
// goString copies the NUL-terminated byte string at address c into a new Go
// string. A zero address yields "".
//
// The bytes are copied rather than aliased. Callers such as the log callback
// bridge hand the result to user code that may keep it after the native
// buffer is freed or reused. An aliasing string would then change or dangle.
func goString(c uintptr) string {
	// We take the address and then dereference it to trick go vet from creating a possible misuse of unsafe.Pointer
	ptr := *(*unsafe.Pointer)(unsafe.Pointer(&c))
	if ptr == nil {
		return ""
	}
	n := 0
	for *(*byte)(unsafe.Add(ptr, n)) != 0 {
		n++
	}
	// string([]byte) copies; unsafe.String would share the native memory.
	return string(unsafe.Slice((*byte)(ptr), n))
}
// goImageSlice converts a native array of size cImage records starting at
// address c into Go Images.
//
// It returns nil for a zero address or a negative size. For a zero size it
// returns an empty slice. An entry with a null data pointer keeps its
// dimensions but gets a nil Data. Other entries' Data slices point at memory
// owned by the native library; nothing here frees it.
func goImageSlice(c uintptr, size int) []Image {
	// We take the address and then dereference it to trick go vet from creating a possible misuse of unsafe.Pointer
	ptr := *(*unsafe.Pointer)(unsafe.Pointer(&c))
	if ptr == nil || size < 0 {
		return nil
	}
	native := unsafe.Slice((*cImage)(ptr), size)
	goImages := make([]Image, 0, size)
	for i := range native {
		src := &native[i]
		img := Image{
			Width:   src.width,
			Height:  src.height,
			Channel: src.channel,
		}
		// Check this entry's own data pointer: unsafe.Slice panics on a nil
		// pointer with non-zero length.
		dataPtr := *(*unsafe.Pointer)(unsafe.Pointer(&src.data))
		if dataPtr != nil {
			n := int(src.channel) * int(src.width) * int(src.height)
			img.Data = unsafe.Slice((*byte)(dataPtr), n)
		}
		goImages = append(goImages, img)
	}
	return goImages
}