Skip to content

Commit

Permalink
Merge pull request #554 from iceljc/features/add-image-edit
Browse files Browse the repository at this point in the history
Features/add image edit
  • Loading branch information
iceljc authored Jul 19, 2024
2 parents 5f459a0 + 4311eaf commit 37fbd6d
Show file tree
Hide file tree
Showing 18 changed files with 457 additions and 675 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

namespace BotSharp.Abstraction.MLTasks;

public interface IImageVariation
public interface IImageCompletion
{
/// <summary>
/// The LLM provider, such as Microsoft Azure, OpenAI, or Claude AI
Expand All @@ -15,5 +15,7 @@ public interface IImageVariation
/// <param name="model">deployment name</param>
void SetModelName(string model);

Task<RoleDialogModel> GetImageGeneration(Agent agent, RoleDialogModel message);

Task<RoleDialogModel> GetImageVariation(Agent agent, RoleDialogModel message, Stream image, string imageFileName);
}
5 changes: 0 additions & 5 deletions src/Infrastructure/BotSharp.Abstraction/MLTasks/IImageEdit.cs

This file was deleted.

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ public partial class BotSharpFileService
{
public async Task<RoleDialogModel> GenerateImage(string? provider, string? model, string text)
{
var completion = CompletionProvider.GetImageGeneration(_services, provider: provider ?? "openai", model: model ?? "dall-e-3");
var completion = CompletionProvider.GetImageCompletion(_services, provider: provider ?? "openai", model: model ?? "dall-e-3");
var message = await completion.GetImageGeneration(new Agent()
{
Id = Guid.Empty.ToString(),
Expand All @@ -21,7 +21,7 @@ public async Task<RoleDialogModel> VarifyImage(string? provider, string? model,
throw new ArgumentException($"Please fill in at least file url or file data!");
}

var completion = CompletionProvider.GetImageVariation(_services, provider: provider ?? "openai", model: model ?? "dall-e-2");
var completion = CompletionProvider.GetImageCompletion(_services, provider: provider ?? "openai", model: model ?? "dall-e-2");
var bytes = await DownloadFile(file);
using var stream = new MemoryStream();
stream.Write(bytes, 0, bytes.Length);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ namespace BotSharp.Core.Infrastructures;

public class CompletionProvider
{
public static object? GetCompletion(IServiceProvider services,
public static object GetCompletion(IServiceProvider services,
string? provider = null,
string? model = null,
AgentLlmConfig? agentConfig = null)
Expand All @@ -18,26 +18,20 @@ public class CompletionProvider

if (settings.Type == LlmModelType.Text)
{
return GetTextCompletion(services,
provider: provider,
model: model,
agentConfig: agentConfig);
return GetTextCompletion(services, provider: provider, model: model, agentConfig: agentConfig);
}
else if (settings.Type == LlmModelType.Embedding)
{
return GetTextEmbedding(services,
provider: provider,
model: model);
return GetTextEmbedding(services, provider: provider, model: model);
}
else if (settings.Type == LlmModelType.Chat)
else if (settings.Type == LlmModelType.Image)
{
return GetChatCompletion(services,
provider: provider,
model: model,
agentConfig: agentConfig);
return GetImageCompletion(services, provider: provider, model: model);
}
else
{
return GetChatCompletion(services, provider: provider, model: model, agentConfig: agentConfig);
}

return null;
}

public static IChatCompletion GetChatCompletion(IServiceProvider services,
Expand Down Expand Up @@ -82,36 +76,15 @@ public static ITextCompletion GetTextCompletion(IServiceProvider services,
return completer;
}

public static IImageGeneration GetImageGeneration(IServiceProvider services,
string? provider = null,
string? model = null,
string? modelId = null,
bool imageGenerate = false,
AgentLlmConfig? agentConfig = null)
{
var completions = services.GetServices<IImageGeneration>();
(provider, model) = GetProviderAndModel(services, provider: provider, model: model, modelId: modelId,
imageGenerate: imageGenerate, agentConfig: agentConfig);

var completer = completions.FirstOrDefault(x => x.Provider == provider);
if (completer == null)
{
var logger = services.GetRequiredService<ILogger<CompletionProvider>>();
logger.LogError($"Can't resolve completion provider by {provider}");
}

completer?.SetModelName(model);
return completer;
}

public static IImageVariation GetImageVariation(IServiceProvider services,
public static IImageCompletion GetImageCompletion(IServiceProvider services,
string? provider = null,
string? model = null,
string? modelId = null,
bool imageGenerate = false)
{
var completions = services.GetServices<IImageVariation>();
(provider, model) = GetProviderAndModel(services, provider: provider, model: model, modelId: modelId, imageGenerate: imageGenerate);
var completions = services.GetServices<IImageCompletion>();
(provider, model) = GetProviderAndModel(services, provider: provider,
model: model, modelId: modelId, imageGenerate: imageGenerate);

var completer = completions.FirstOrDefault(x => x.Provider == provider);
if (completer == null)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -141,6 +141,10 @@ public async Task<ImageGenerationViewModel> ImageVariation([FromBody] IncomingMe
try
{
var file = input.Files.FirstOrDefault(x => !string.IsNullOrWhiteSpace(x.FileUrl) || !string.IsNullOrWhiteSpace(x.FileData));
if (file == null)
{
return new ImageGenerationViewModel { Message = "Error! Cannot find an image!" };
}
var message = await fileService.VarifyImage(input.Provider, input.Model, file);
imageViewModel.Content = message.Content;
imageViewModel.Images = message.GeneratedImages.Select(x => ImageViewModel.ToViewModel(x)).ToList();
Expand Down
3 changes: 1 addition & 2 deletions src/Plugins/BotSharp.Plugin.AzureOpenAI/AzureOpenAiPlugin.cs
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,6 @@ public void RegisterDI(IServiceCollection services, IConfiguration config)
services.AddScoped<ITextCompletion, TextCompletionProvider>();
services.AddScoped<IChatCompletion, ChatCompletionProvider>();
services.AddScoped<ITextEmbedding, TextEmbeddingProvider>();
services.AddScoped<IImageGeneration, ImageGenerationProvider>();
services.AddScoped<IImageVariation, ImageVariationProvider>();
services.AddScoped<IImageCompletion, ImageCompletionProvider>();
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
using OpenAI.Images;

namespace BotSharp.Plugin.AzureOpenAI.Providers.Image;

/// <summary>
/// Image-generation part of the provider: turns a text prompt into one or more images.
/// </summary>
public partial class ImageCompletionProvider
{
    /// <summary>
    /// Generates images for the prompt carried by <paramref name="message"/> and wraps
    /// them in an assistant-role <see cref="RoleDialogModel"/>.
    /// </summary>
    /// <param name="agent">Agent whose <c>Id</c> is stamped on the response as <c>CurrentAgentId</c>.</param>
    /// <param name="message">
    /// Incoming message; its <c>Payload</c> (or, failing that, its <c>Content</c>) is used as the prompt,
    /// and its <c>MessageId</c> is copied to the response (empty string when the message is null).
    /// </param>
    /// <returns>
    /// A response whose <c>GeneratedImages</c> holds one entry per non-null image returned by the
    /// client, and whose content is the non-blank image descriptions joined by <c>"\r\n"</c>.
    /// </returns>
    public async Task<RoleDialogModel> GetImageGeneration(Agent agent, RoleDialogModel message)
    {
        var client = ProviderHelper.GetClient(Provider, _model, _services);
        var (prompt, imageCount, options) = PrepareOptions(message);
        var imageClient = client.GetImageClient(_model);

        // Await the asynchronous API instead of blocking the calling thread on the network request.
        var response = await imageClient.GenerateImagesAsync(prompt, imageCount, options);
        var values = response.Value;

        var generatedImages = new List<ImageGeneration>();
        foreach (var value in values)
        {
            if (value == null)
            {
                continue;
            }

            var generatedImage = new ImageGeneration { Description = value.RevisedPrompt ?? string.Empty };
            if (options.ResponseFormat == GeneratedImageFormat.Uri)
            {
                generatedImage.ImageUrl = value.ImageUri?.AbsoluteUri ?? string.Empty;
            }
            else if (options.ResponseFormat == GeneratedImageFormat.Bytes)
            {
                // Raw bytes are exposed to callers as a base64 string; empty when no bytes came back.
                var base64Str = string.Empty;
                var bytes = value.ImageBytes?.ToArray();
                if (!bytes.IsNullOrEmpty())
                {
                    base64Str = Convert.ToBase64String(bytes);
                }
                generatedImage.ImageData = base64Str;
            }

            generatedImages.Add(generatedImage);
        }

        var content = string.Join("\r\n", generatedImages.Where(x => !string.IsNullOrWhiteSpace(x.Description)).Select(x => x.Description));
        var responseMessage = new RoleDialogModel(AgentRole.Assistant, content)
        {
            CurrentAgentId = agent.Id,
            MessageId = message?.MessageId ?? string.Empty,
            GeneratedImages = generatedImages
        };

        return responseMessage;
    }

    /// <summary>
    /// Builds the prompt, the requested image count and the generation options.
    /// Size, quality, style, format and count are read from the conversation state keys
    /// <c>image_size</c>, <c>image_quality</c>, <c>image_style</c>, <c>image_format</c> and
    /// <c>image_count</c> (requested with <c>"1"</c> as the second argument, presumably the fallback value).
    /// </summary>
    /// <param name="message">Message supplying the prompt; a null message yields an empty prompt.</param>
    /// <returns>The prompt, the image count and the populated <see cref="ImageGenerationOptions"/>.</returns>
    private (string, int, ImageGenerationOptions) PrepareOptions(RoleDialogModel message)
    {
        var prompt = message?.Payload ?? message?.Content ?? string.Empty;

        var state = _services.GetRequiredService<IConversationStateService>();
        var size = GetImageSize(state.GetState("image_size"));
        var quality = GetImageQuality(state.GetState("image_quality"));
        var style = GetImageStyle(state.GetState("image_style"));
        var format = GetImageFormat(state.GetState("image_format"));
        var count = GetImageCount(state.GetState("image_count", "1"));

        var options = new ImageGenerationOptions
        {
            Size = size,
            Quality = quality,
            Style = style,
            ResponseFormat = format
        };
        return (prompt, count, options);
    }
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
using OpenAI.Images;

namespace BotSharp.Plugin.AzureOpenAI.Providers.Image;

/// <summary>
/// Image-variation part of the provider: produces variations of a supplied image.
/// </summary>
public partial class ImageCompletionProvider
{
    /// <summary>
    /// Requests variations of <paramref name="image"/> and wraps them in an assistant-role
    /// <see cref="RoleDialogModel"/>.
    /// </summary>
    /// <param name="agent">Agent whose <c>Id</c> is stamped on the response as <c>CurrentAgentId</c>.</param>
    /// <param name="message">
    /// Originating message; only its <c>MessageId</c> is used (empty string when the message is null).
    /// </param>
    /// <param name="image">Stream with the source image, passed to the image client as-is.</param>
    /// <param name="imageFileName">File name forwarded to the image client together with the stream.</param>
    /// <returns>
    /// A response whose <c>GeneratedImages</c> holds one entry per non-null image returned by the
    /// client, and whose content is the non-blank image descriptions joined by <c>"\r\n"</c>.
    /// </returns>
    public async Task<RoleDialogModel> GetImageVariation(Agent agent, RoleDialogModel message, Stream image, string imageFileName)
    {
        var client = ProviderHelper.GetClient(Provider, _model, _services);
        var (imageCount, options) = PrepareOptions();
        var imageClient = client.GetImageClient(_model);

        // Await the asynchronous API instead of blocking the calling thread on the network request.
        var response = await imageClient.GenerateImageVariationsAsync(image, imageFileName, imageCount, options);
        var values = response.Value;

        var generatedImages = new List<ImageGeneration>();
        foreach (var value in values)
        {
            if (value == null)
            {
                continue;
            }

            var generatedImage = new ImageGeneration { Description = value.RevisedPrompt ?? string.Empty };
            if (options.ResponseFormat == GeneratedImageFormat.Uri)
            {
                generatedImage.ImageUrl = value.ImageUri?.AbsoluteUri ?? string.Empty;
            }
            else if (options.ResponseFormat == GeneratedImageFormat.Bytes)
            {
                // Raw bytes are exposed to callers as a base64 string; empty when no bytes came back.
                var base64Str = string.Empty;
                var bytes = value.ImageBytes?.ToArray();
                if (!bytes.IsNullOrEmpty())
                {
                    base64Str = Convert.ToBase64String(bytes);
                }
                generatedImage.ImageData = base64Str;
            }

            generatedImages.Add(generatedImage);
        }

        var content = string.Join("\r\n", generatedImages.Where(x => !string.IsNullOrWhiteSpace(x.Description)).Select(x => x.Description));
        var responseMessage = new RoleDialogModel(AgentRole.Assistant, content)
        {
            CurrentAgentId = agent.Id,
            MessageId = message?.MessageId ?? string.Empty,
            GeneratedImages = generatedImages
        };

        return responseMessage;
    }

    /// <summary>
    /// Builds the requested image count and the variation options.
    /// Size, format and count are read from the conversation state keys <c>image_size</c>,
    /// <c>image_format</c> and <c>image_count</c> (requested with <c>"1"</c> as the second argument,
    /// presumably the fallback value).
    /// </summary>
    /// <returns>The image count and the populated <see cref="ImageVariationOptions"/>.</returns>
    private (int, ImageVariationOptions) PrepareOptions()
    {
        var state = _services.GetRequiredService<IConversationStateService>();
        var size = GetImageSize(state.GetState("image_size"));
        var format = GetImageFormat(state.GetState("image_format"));
        var count = GetImageCount(state.GetState("image_count", "1"));

        var options = new ImageVariationOptions
        {
            Size = size,
            ResponseFormat = format
        };
        return (count, options);
    }
}
Loading

0 comments on commit 37fbd6d

Please sign in to comment.