Skip to content

Commit

Permalink
Merge pull request #555 from iceljc/features/add-image-mask-edit
Browse files Browse the repository at this point in the history
add image mask edit
  • Loading branch information
iceljc authored Jul 19, 2024
2 parents 37fbd6d + b1fb9d6 commit 5637a41
Show file tree
Hide file tree
Showing 27 changed files with 578 additions and 140 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,4 @@ public class IncomingMessageModel : MessageConfig
/// Postback message
/// </summary>
public PostbackMessageModel? Postback { get; set; }

public List<BotSharpFile> Files { get; set; } = new List<BotSharpFile>();
}
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,9 @@ Task<IEnumerable<MessageFileModel>> GetChatFiles(string conversationId, string s

#region Image
Task<RoleDialogModel> GenerateImage(string? provider, string? model, string text);
Task<RoleDialogModel> VarifyImage(string? provider, string? model, BotSharpFile file);
Task<RoleDialogModel> VaryImage(string? provider, string? model, BotSharpFile image);
Task<RoleDialogModel> EditImage(string? provider, string? model, string text, BotSharpFile image);
Task<RoleDialogModel> EditImage(string? provider, string? model, string text, BotSharpFile image, BotSharpFile mask);
#endregion

#region Pdf
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
namespace BotSharp.Abstraction.Files.Models;

public class InputMessageFiles
{
public List<BotSharpFile> Files { get; set; } = new List<BotSharpFile>();
public BotSharpFile? Mask { get; set; }
}

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -18,4 +18,8 @@ public interface IImageCompletion
Task<RoleDialogModel> GetImageGeneration(Agent agent, RoleDialogModel message);

Task<RoleDialogModel> GetImageVariation(Agent agent, RoleDialogModel message, Stream image, string imageFileName);

Task<RoleDialogModel> GetImageEdits(Agent agent, RoleDialogModel message, Stream image, string imageFileName);

Task<RoleDialogModel> GetImageEdits(Agent agent, RoleDialogModel message, Stream image, string imageFileName, Stream mask, string maskFileName);
}
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
namespace BotSharp.Abstraction.Models;

public class MessageConfig
public class MessageConfig : InputMessageFiles
{
/// <summary>
/// Completion Provider
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,25 +14,77 @@ public async Task<RoleDialogModel> GenerateImage(string? provider, string? model
return message;
}

public async Task<RoleDialogModel> VarifyImage(string? provider, string? model, BotSharpFile file)
public async Task<RoleDialogModel> VaryImage(string? provider, string? model, BotSharpFile image)
{
if (string.IsNullOrWhiteSpace(file?.FileUrl) && string.IsNullOrWhiteSpace(file?.FileData))
if (string.IsNullOrWhiteSpace(image?.FileUrl) && string.IsNullOrWhiteSpace(image?.FileData))
{
throw new ArgumentException($"Please fill in at least file url or file data!");
throw new ArgumentException($"Cannot find image url or data!");
}

var completion = CompletionProvider.GetImageCompletion(_services, provider: provider ?? "openai", model: model ?? "dall-e-2");
var bytes = await DownloadFile(file);
var bytes = await DownloadFile(image);
using var stream = new MemoryStream();
stream.Write(bytes, 0, bytes.Length);
stream.Position = 0;

var message = await completion.GetImageVariation(new Agent()
{
Id = Guid.Empty.ToString()
}, new RoleDialogModel(AgentRole.User, string.Empty), stream, file.FileName ?? string.Empty);
}, new RoleDialogModel(AgentRole.User, string.Empty), stream, image.FileName ?? string.Empty);

stream.Close();
return message;
}

public async Task<RoleDialogModel> EditImage(string? provider, string? model, string text, BotSharpFile image)
{
if (string.IsNullOrWhiteSpace(image?.FileUrl) && string.IsNullOrWhiteSpace(image?.FileData))
{
throw new ArgumentException($"Cannot find image url or data!");
}

var completion = CompletionProvider.GetImageCompletion(_services, provider: provider ?? "openai", model: model ?? "dall-e-2");
var bytes = await DownloadFile(image);
using var stream = new MemoryStream();
stream.Write(bytes, 0, bytes.Length);
stream.Position = 0;

var message = await completion.GetImageEdits(new Agent()
{
Id = Guid.Empty.ToString()
}, new RoleDialogModel(AgentRole.User, text), stream, image.FileName ?? string.Empty);

stream.Close();
return message;
}

public async Task<RoleDialogModel> EditImage(string? provider, string? model, string text, BotSharpFile image, BotSharpFile mask)
{
if ((string.IsNullOrWhiteSpace(image?.FileUrl) && string.IsNullOrWhiteSpace(image?.FileData)) ||
(string.IsNullOrWhiteSpace(mask?.FileUrl) && string.IsNullOrWhiteSpace(mask?.FileData)))
{
throw new ArgumentException($"Cannot find image/mask url or data");
}

var completion = CompletionProvider.GetImageCompletion(_services, provider: provider ?? "openai", model: model ?? "dall-e-2");
var imageBytes = await DownloadFile(image);
var maskBytes = await DownloadFile(mask);

using var imageStream = new MemoryStream();
imageStream.Write(imageBytes, 0, imageBytes.Length);
imageStream.Position = 0;

using var maskStream = new MemoryStream();
maskStream.Write(maskBytes, 0, maskBytes.Length);
maskStream.Position = 0;

var message = await completion.GetImageEdits(new Agent()
{
Id = Guid.Empty.ToString()
}, new RoleDialogModel(AgentRole.User, text), imageStream, image.FileName ?? string.Empty, maskStream, mask.FileName ?? string.Empty);

imageStream.Close();
maskStream.Close();
return message;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@ public async Task<string> TextCompletion([FromBody] IncomingMessageModel input)
return await textCompletion.GetCompletion(input.Text, Guid.Empty.ToString(), Guid.NewGuid().ToString());
}

#region Chat
[HttpPost("/instruct/chat-completion")]
public async Task<string> ChatCompletion([FromBody] IncomingMessageModel input)
{
Expand All @@ -75,7 +76,9 @@ public async Task<string> ChatCompletion([FromBody] IncomingMessageModel input)
});
return message.Content;
}
#endregion

#region Read image
[HttpPost("/instruct/multi-modal")]
public async Task<string> MultiModalCompletion([FromBody] IncomingMessageModel input)
{
Expand Down Expand Up @@ -105,7 +108,9 @@ public async Task<string> MultiModalCompletion([FromBody] IncomingMessageModel i
return error;
}
}
#endregion

#region Generate image
[HttpPost("/instruct/image-generation")]
public async Task<ImageGenerationViewModel> ImageGeneration([FromBody] IncomingMessageModel input)
{
Expand All @@ -129,7 +134,9 @@ public async Task<ImageGenerationViewModel> ImageGeneration([FromBody] IncomingM
return imageViewModel;
}
}
#endregion

#region Edit image
[HttpPost("/instruct/image-variation")]
public async Task<ImageGenerationViewModel> ImageVariation([FromBody] IncomingMessageModel input)
{
Expand All @@ -140,12 +147,12 @@ public async Task<ImageGenerationViewModel> ImageVariation([FromBody] IncomingMe

try
{
var file = input.Files.FirstOrDefault(x => !string.IsNullOrWhiteSpace(x.FileUrl) || !string.IsNullOrWhiteSpace(x.FileData));
if (file == null)
var image = input.Files.FirstOrDefault(x => !string.IsNullOrWhiteSpace(x.FileUrl) || !string.IsNullOrWhiteSpace(x.FileData));
if (image == null)
{
return new ImageGenerationViewModel { Message = "Error! Cannot find an image!" };
}
var message = await fileService.VarifyImage(input.Provider, input.Model, file);
var message = await fileService.VaryImage(input.Provider, input.Model, image);
imageViewModel.Content = message.Content;
imageViewModel.Images = message.GeneratedImages.Select(x => ImageViewModel.ToViewModel(x)).ToList();
return imageViewModel;
Expand All @@ -159,6 +166,67 @@ public async Task<ImageGenerationViewModel> ImageVariation([FromBody] IncomingMe
}
}

[HttpPost("/instruct/image-edit")]
public async Task<ImageGenerationViewModel> ImageEdit([FromBody] IncomingMessageModel input)
{
var fileService = _services.GetRequiredService<IBotSharpFileService>();
var state = _services.GetRequiredService<IConversationStateService>();
input.States.ForEach(x => state.SetState(x.Key, x.Value, activeRounds: x.ActiveRounds, source: StateSource.External));
var imageViewModel = new ImageGenerationViewModel();

try
{
var image = input.Files.FirstOrDefault(x => !string.IsNullOrWhiteSpace(x.FileUrl) || !string.IsNullOrWhiteSpace(x.FileData));
if (image == null)
{
return new ImageGenerationViewModel { Message = "Error! Cannot find an image!" };
}
var message = await fileService.EditImage(input.Provider, input.Model, input.Text, image);
imageViewModel.Content = message.Content;
imageViewModel.Images = message.GeneratedImages.Select(x => ImageViewModel.ToViewModel(x)).ToList();
return imageViewModel;
}
catch (Exception ex)
{
var error = $"Error in image edit. {ex.Message}";
_logger.LogError(error);
imageViewModel.Message = error;
return imageViewModel;
}
}

[HttpPost("/instruct/image-mask-edit")]
public async Task<ImageGenerationViewModel> ImageMaskEdit([FromBody] IncomingMessageModel input)
{
var fileService = _services.GetRequiredService<IBotSharpFileService>();
var state = _services.GetRequiredService<IConversationStateService>();
input.States.ForEach(x => state.SetState(x.Key, x.Value, activeRounds: x.ActiveRounds, source: StateSource.External));
var imageViewModel = new ImageGenerationViewModel();

try
{
var image = input.Files.FirstOrDefault(x => !string.IsNullOrWhiteSpace(x.FileUrl) || !string.IsNullOrWhiteSpace(x.FileData));
var mask = input.Mask;
if (image == null || mask == null)
{
return new ImageGenerationViewModel { Message = "Error! Cannot find an image or mask!" };
}
var message = await fileService.EditImage(input.Provider, input.Model, input.Text, image, mask);
imageViewModel.Content = message.Content;
imageViewModel.Images = message.GeneratedImages.Select(x => ImageViewModel.ToViewModel(x)).ToList();
return imageViewModel;
}
catch (Exception ex)
{
var error = $"Error in image mask edit. {ex.Message}";
_logger.LogError(error);
imageViewModel.Message = error;
return imageViewModel;
}
}
#endregion

#region Pdf
[HttpPost("/instruct/pdf-completion")]
public async Task<PdfCompletionViewModel> PdfCompletion([FromBody] IncomingMessageModel input)
{
Expand All @@ -181,4 +249,5 @@ public async Task<PdfCompletionViewModel> PdfCompletion([FromBody] IncomingMessa
return viewModel;
}
}
#endregion
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
using OpenAI.Images;

namespace BotSharp.Plugin.AzureOpenAI.Providers.Image;

public partial class ImageCompletionProvider
{
public async Task<RoleDialogModel> GetImageEdits(Agent agent, RoleDialogModel message, Stream image, string imageFileName)
{
var client = ProviderHelper.GetClient(Provider, _model, _services);
var (prompt, imageCount, options) = PrepareEditOptions(message);
var imageClient = client.GetImageClient(_model);

var response = imageClient.GenerateImageEdits(image, imageFileName, prompt, imageCount, options);
var images = response.Value;

var generatedImages = GetImageGenerations(images, options.ResponseFormat);
var content = string.Join("\r\n", generatedImages.Where(x => !string.IsNullOrWhiteSpace(x.Description)).Select(x => x.Description));
var responseMessage = new RoleDialogModel(AgentRole.Assistant, content)
{
CurrentAgentId = agent.Id,
MessageId = message?.MessageId ?? string.Empty,
GeneratedImages = generatedImages
};

return await Task.FromResult(responseMessage);
}

public async Task<RoleDialogModel> GetImageEdits(Agent agent, RoleDialogModel message,
Stream image, string imageFileName, Stream mask, string maskFileName)
{
var client = ProviderHelper.GetClient(Provider, _model, _services);
var (prompt, imageCount, options) = PrepareEditOptions(message);
var imageClient = client.GetImageClient(_model);

var response = imageClient.GenerateImageEdits(image, imageFileName, prompt, mask, maskFileName, imageCount, options);
var images = response.Value;

var generatedImages = GetImageGenerations(images, options.ResponseFormat);
var content = string.Join("\r\n", generatedImages.Where(x => !string.IsNullOrWhiteSpace(x.Description)).Select(x => x.Description));
var responseMessage = new RoleDialogModel(AgentRole.Assistant, content)
{
CurrentAgentId = agent.Id,
MessageId = message?.MessageId ?? string.Empty,
GeneratedImages = generatedImages
};

return await Task.FromResult(responseMessage);
}

private (string, int, ImageEditOptions) PrepareEditOptions(RoleDialogModel message)
{
var prompt = message?.Payload ?? message?.Content ?? string.Empty;

var state = _services.GetRequiredService<IConversationStateService>();
var size = GetImageSize(state.GetState("image_size"));
var format = GetImageFormat(state.GetState("image_format"));
var count = GetImageCount(state.GetState("image_count", "1"));

var options = new ImageEditOptions
{
Size = size,
ResponseFormat = format
};
return (prompt, count, options);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -7,36 +7,13 @@ public partial class ImageCompletionProvider
public async Task<RoleDialogModel> GetImageGeneration(Agent agent, RoleDialogModel message)
{
var client = ProviderHelper.GetClient(Provider, _model, _services);
var (prompt, imageCount, options) = PrepareOptions(message);
var (prompt, imageCount, options) = PrepareGenerationOptions(message);
var imageClient = client.GetImageClient(_model);

var response = imageClient.GenerateImages(prompt, imageCount, options);
var values = response.Value;

var generatedImages = new List<ImageGeneration>();
foreach (var value in values)
{
if (value == null) continue;

var generatedImage = new ImageGeneration { Description = value?.RevisedPrompt ?? string.Empty };
if (options.ResponseFormat == GeneratedImageFormat.Uri)
{
generatedImage.ImageUrl = value?.ImageUri?.AbsoluteUri ?? string.Empty;
}
else if (options.ResponseFormat == GeneratedImageFormat.Bytes)
{
var base64Str = string.Empty;
var bytes = value?.ImageBytes?.ToArray();
if (!bytes.IsNullOrEmpty())
{
base64Str = Convert.ToBase64String(bytes);
}
generatedImage.ImageData = base64Str;
}

generatedImages.Add(generatedImage);
}
var images = response.Value;

var generatedImages = GetImageGenerations(images, options.ResponseFormat);
var content = string.Join("\r\n", generatedImages.Where(x => !string.IsNullOrWhiteSpace(x.Description)).Select(x => x.Description));
var responseMessage = new RoleDialogModel(AgentRole.Assistant, content)
{
Expand All @@ -48,7 +25,7 @@ public async Task<RoleDialogModel> GetImageGeneration(Agent agent, RoleDialogMod
return await Task.FromResult(responseMessage);
}

private (string, int, ImageGenerationOptions) PrepareOptions(RoleDialogModel message)
private (string, int, ImageGenerationOptions) PrepareGenerationOptions(RoleDialogModel message)
{
var prompt = message?.Payload ?? message?.Content ?? string.Empty;

Expand Down
Loading

0 comments on commit 5637a41

Please sign in to comment.