Skip to content

Commit

Permalink
get sd-models (#5)
Browse files Browse the repository at this point in the history
  • Loading branch information
Robothy authored Oct 14, 2023
1 parent 7058bb9 commit 2d6c39e
Show file tree
Hide file tree
Showing 9 changed files with 209 additions and 29 deletions.
4 changes: 4 additions & 0 deletions build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ ext {
apacheHttpClient5Version = '5.2.1'
junitVersion = '5.9.1'
lombokVersion = '1.18.22'
mockServerVersion = "5.14.0"
}

dependencies {
Expand All @@ -40,8 +41,11 @@ dependencies {
testAnnotationProcessor "org.projectlombok:lombok:${lombokVersion}"
testImplementation "org.projectlombok:lombok:${lombokVersion}"

implementation "org.mock-server:mockserver-junit-jupiter-no-dependencies:${mockServerVersion}"

testImplementation platform("org.junit:junit-bom:${junitVersion}")
testImplementation 'org.junit.jupiter:junit-jupiter'
testImplementation 'org.junit.jupiter:junit-jupiter-api'
}

java {
Expand Down
11 changes: 11 additions & 0 deletions src/main/java/io/github/robothy/sdwebui/sdk/GetSdModels.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
package io.github.robothy.sdwebui.sdk;

import io.github.robothy.sdwebui.sdk.models.results.SdModel;

import java.util.List;

public interface GetSdModels {

List<SdModel> getSdModels();

}
56 changes: 28 additions & 28 deletions src/main/java/io/github/robothy/sdwebui/sdk/Main.java
Original file line number Diff line number Diff line change
Expand Up @@ -15,34 +15,34 @@
public class Main {

public static void main(String[] args) throws IOException {
SdWebui sd = SdWebui.create("http://localhost:7860");

Txt2ImgResult txt2ImgResult = sd.txt2Img(Txt2ImageOptions.builder()
.prompt("1dog")
.samplerName("DPM++ 2M Karras")
.steps(20)
.cfgScale(7)
.seed(32749528)
.build());

Path step1Path = Paths.get("docs/images/txt2img-dog.png");
Files.write(step1Path, Base64.getDecoder().decode(txt2ImgResult.getImages().get(0)));

Image2ImageResult image2ImageResult = sd.img2img(Image2ImageOptions.builder()
.prompt("1dog, glass")
.negativePrompt("bad fingers")
.samplerName("DPM++ 2M Karras")
.seed(32749528)
.cfgScale(7)
.denoisingStrength(0.3)
.initImages(List.of(txt2ImgResult.getImages().get(0)))
.build());


String base64img = image2ImageResult.getImages().get(0);

Path filepath = Paths.get("docs/images/img2img-dog.png");
Files.write(filepath, Base64.getDecoder().decode(base64img));
// SdWebui sd = SdWebui.create("http://localhost:7860");
//
// Txt2ImgResult txt2ImgResult = sd.txt2Img(Txt2ImageOptions.builder()
// .prompt("1dog")
// .samplerName("DPM++ 2M Karras")
// .steps(20)
// .cfgScale(7)
// .seed(32749528)
// .build());
//
// Path step1Path = Paths.get("docs/images/txt2img-dog.png");
// Files.write(step1Path, Base64.getDecoder().decode(txt2ImgResult.getImages().get(0)));
//
// Image2ImageResult image2ImageResult = sd.img2img(Image2ImageOptions.builder()
// .prompt("1dog, glass")
// .negativePrompt("bad fingers")
// .samplerName("DPM++ 2M Karras")
// .seed(32749528)
// .cfgScale(7)
// .denoisingStrength(0.3)
// .initImages(List.of(txt2ImgResult.getImages().get(0)))
// .build());
//
//
// String base64img = image2ImageResult.getImages().get(0);
//
// Path filepath = Paths.get("docs/images/img2img-dog.png");
// Files.write(filepath, Base64.getDecoder().decode(base64img));
}

}
2 changes: 1 addition & 1 deletion src/main/java/io/github/robothy/sdwebui/sdk/SdWebui.java
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

import java.lang.reflect.Proxy;

public interface SdWebui extends SystemInfoFetcher, Txt2Image, Image2Image {
public interface SdWebui extends SystemInfoFetcher, Txt2Image, Image2Image, GetSdModels {

static SdWebui create(String endpoint) {
SdWebuiOptions options = new SdWebuiOptions(endpoint);
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
package io.github.robothy.sdwebui.sdk.models.results;

import com.fasterxml.jackson.annotation.JsonProperty;
import lombok.Getter;

@Getter
public class SdModel {

@JsonProperty("title")
private String title;

@JsonProperty("model_name")
private String modelName;

@JsonProperty("hash")
private String hash;

@JsonProperty("sha256")
private String sha256;

@JsonProperty("filename")
private String filename;

@JsonProperty("config")
private String config;

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
package io.github.robothy.sdwebui.sdk.services;

import io.github.robothy.sdwebui.sdk.GetSdModels;
import io.github.robothy.sdwebui.sdk.SdWebuiBeanContainer;
import io.github.robothy.sdwebui.sdk.models.results.SdModel;

import java.util.Arrays;
import java.util.List;

public class DefaultGetSdModelService implements GetSdModels {

private final SdWebuiBeanContainer container;

public DefaultGetSdModelService(SdWebuiBeanContainer container) {
this.container = container;
}

@Override
public List<SdModel> getSdModels() {
return Arrays.asList(this.container.getBean(CommonGetService.class).getData("/sdapi/v1/sd-models", SdModel[].class));
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,8 @@ private void init() {
this.services.put(SystemInfo.class, new CacheableSystemInfoFetcher(sdWebuiOptions.getEndpoint(), this));
this.services.put(Txt2Image.class, new DefaultTxt2ImageService(this));
this.services.put(Image2Image.class, new DefaultImage2ImageService(this));
this.services.put(CommonGetService.class, new CommonGetService(this));
this.services.put(GetSdModels.class, new DefaultGetSdModelService(this));
}

}
36 changes: 36 additions & 0 deletions src/test/java/io/github/robothy/sdwebui/sdk/MockSdServer.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
package io.github.robothy.sdwebui.sdk;

import org.junit.jupiter.api.extension.AfterEachCallback;
import org.junit.jupiter.api.extension.BeforeAllCallback;
import org.junit.jupiter.api.extension.BeforeEachCallback;
import org.junit.jupiter.api.extension.ExtensionContext;
import org.mockserver.client.MockServerClient;
import org.mockserver.netty.MockServer;

import java.net.http.HttpRequest;
import java.util.List;
import java.util.function.Predicate;

public class MockSdServer implements BeforeAllCallback, BeforeEachCallback, AfterEachCallback {

private int port;

public int getPort() {
return port;
}

@Override
public void afterEach(ExtensionContext context) throws Exception {

}

@Override
public void beforeAll(ExtensionContext context) throws Exception {

}

@Override
public void beforeEach(ExtensionContext context) throws Exception {

}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
package io.github.robothy.sdwebui.sdk.models.results;

import com.fasterxml.jackson.core.JsonProcessingException;
import com.fasterxml.jackson.databind.ObjectMapper;
import io.github.robothy.sdwebui.sdk.SdWebui;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.ExtendWith;
import org.mockserver.client.MockServerClient;
import org.mockserver.junit.jupiter.MockServerExtension;
import org.mockserver.model.HttpRequest;
import org.mockserver.model.HttpResponse;

import java.util.List;

import static org.junit.jupiter.api.Assertions.*;

@ExtendWith(MockServerExtension.class)
class SdModelTest {

private static final String JSON = "{\n" +
" \"title\": \"v1-5-pruned-emaonly.ckpt [cc6cb27103]\",\n" +
" \"model_name\": \"v1-5-pruned-emaonly\",\n" +
" \"hash\": \"cc6cb27103\",\n" +
" \"sha256\": \"cc6cb27103417325ff94f52b7a5d2dde45a7515b25c255d8e396c90014281516\",\n" +
" \"filename\": \"C:\\\\Users\\\\admin\\\\PythonProjects\\\\stable-diffusion-webui\\\\models\\\\Stable-diffusion\\\\v1-5-pruned-emaonly.ckpt\",\n" +
" \"config\": null\n" +
"}";

@Test
void testSerialization() throws JsonProcessingException {
SdModel sdModel = new ObjectMapper().readValue(JSON, SdModel.class);
assertEquals("v1-5-pruned-emaonly.ckpt [cc6cb27103]", sdModel.getTitle());
assertEquals("v1-5-pruned-emaonly", sdModel.getModelName());
assertEquals("cc6cb27103", sdModel.getHash());
assertEquals("cc6cb27103417325ff94f52b7a5d2dde45a7515b25c255d8e396c90014281516", sdModel.getSha256());
assertEquals("C:\\Users\\admin\\PythonProjects\\stable-diffusion-webui\\models\\Stable-diffusion\\v1-5-pruned-emaonly.ckpt", sdModel.getFilename());
assertNull(sdModel.getConfig());
}

@Test
void getGetSdModels(MockServerClient client) {
client.when(new HttpRequest().withMethod("GET").withPath("/sdapi/v1/sd-models"))
.respond(new HttpResponse().withStatusCode(200).withBody(" [" +
"{\n" +
" \"title\": \"MoyouArtificial_v10502g.safetensors [b6c1edcbe9]\",\n" +
" \"model_name\": \"MoyouArtificial_v10502g\",\n" +
" \"hash\": \"b6c1edcbe9\",\n" +
" \"sha256\": \"b6c1edcbe9ef9fa3d38c3787d351211a775e6254b832234d97042800f33345d1\",\n" +
" \"filename\": \"C:\\\\Users\\\\admin\\\\PythonProjects\\\\stable-diffusion-webui\\\\models\\\\Stable-diffusion\\\\MoyouArtificial_v10502g.safetensors\",\n" +
" \"config\": null\n" +
" },\n" +
" {\n" +
" \"title\": \"v1-5-pruned-emaonly.ckpt [cc6cb27103]\",\n" +
" \"model_name\": \"v1-5-pruned-emaonly\",\n" +
" \"hash\": \"cc6cb27103\",\n" +
" \"sha256\": \"cc6cb27103417325ff94f52b7a5d2dde45a7515b25c255d8e396c90014281516\",\n" +
" \"filename\": \"C:\\\\Users\\\\admin\\\\PythonProjects\\\\stable-diffusion-webui\\\\models\\\\Stable-diffusion\\\\v1-5-pruned-emaonly.ckpt\",\n" +
" \"config\": null\n" +
" }\n" +
"]"));
List<SdModel> sdModels = SdWebui.create("http://localhost:" + client.remoteAddress().getPort()).getSdModels();
assertEquals(2, sdModels.size());
assertEquals("MoyouArtificial_v10502g.safetensors [b6c1edcbe9]", sdModels.get(0).getTitle());
assertEquals("v1-5-pruned-emaonly.ckpt [cc6cb27103]", sdModels.get(1).getTitle());
assertEquals("MoyouArtificial_v10502g", sdModels.get(0).getModelName());
assertEquals("v1-5-pruned-emaonly", sdModels.get(1).getModelName());
assertEquals("b6c1edcbe9", sdModels.get(0).getHash());
assertEquals("cc6cb27103", sdModels.get(1).getHash());
assertEquals("b6c1edcbe9ef9fa3d38c3787d351211a775e6254b832234d97042800f33345d1", sdModels.get(0).getSha256());
assertEquals("cc6cb27103417325ff94f52b7a5d2dde45a7515b25c255d8e396c90014281516", sdModels.get(1).getSha256());
assertEquals("C:\\Users\\admin\\PythonProjects\\stable-diffusion-webui\\models\\Stable-diffusion\\MoyouArtificial_v10502g.safetensors", sdModels.get(0).getFilename());
assertEquals("C:\\Users\\admin\\PythonProjects\\stable-diffusion-webui\\models\\Stable-diffusion\\v1-5-pruned-emaonly.ckpt", sdModels.get(1).getFilename());
assertNull(sdModels.get(0).getConfig());
assertNull(sdModels.get(1).getConfig());
}

}

0 comments on commit 2d6c39e

Please sign in to comment.