From 203299d976c454acd0e1fd349bec4beda3b2f94c Mon Sep 17 00:00:00 2001 From: Thiago dos Santos Hora Date: Mon, 4 Nov 2024 16:36:14 +0100 Subject: [PATCH] [OPIK-309] Create prompt endpoint (#531) * [OPIK-309] Create prompt endpoint * Add logic to create first version when specified * Address PR review comments --- .../main/java/com/comet/opik/api/Prompt.java | 56 ++++ .../com/comet/opik/api/PromptVersion.java | 58 ++++ .../error/EntityAlreadyExistsException.java | 10 +- .../resources/v1/priv/ProjectsResource.java | 2 +- .../api/resources/v1/priv/PromptResource.java | 68 ++++ .../com/comet/opik/domain/CommitUtils.java | 14 + .../com/comet/opik/domain/DatasetDAO.java | 3 +- .../com/comet/opik/domain/DatasetService.java | 3 +- .../opik/domain/EntityConstraintHandler.java | 60 ++++ .../java/com/comet/opik/domain/PromptDAO.java | 25 ++ .../com/comet/opik/domain/PromptService.java | 118 +++++++ .../comet/opik/domain/PromptVersionDAO.java | 25 ++ ..._increate_prompt_version_commit_length.sql | 6 + .../v1/events/DatasetEventListenerTest.java | 1 - .../v1/priv/DatasetExperimentE2ETest.java | 2 +- .../v1/priv/DatasetsResourceTest.java | 3 +- .../resources/v1/priv/PromptResourceTest.java | 317 ++++++++++++++++++ .../domain/EntityConstraintHandlerTest.java | 88 +++++ 18 files changed, 852 insertions(+), 7 deletions(-) create mode 100644 apps/opik-backend/src/main/java/com/comet/opik/api/Prompt.java create mode 100644 apps/opik-backend/src/main/java/com/comet/opik/api/PromptVersion.java create mode 100644 apps/opik-backend/src/main/java/com/comet/opik/api/resources/v1/priv/PromptResource.java create mode 100644 apps/opik-backend/src/main/java/com/comet/opik/domain/CommitUtils.java create mode 100644 apps/opik-backend/src/main/java/com/comet/opik/domain/EntityConstraintHandler.java create mode 100644 apps/opik-backend/src/main/java/com/comet/opik/domain/PromptDAO.java create mode 100644 apps/opik-backend/src/main/java/com/comet/opik/domain/PromptService.java create mode 100644 apps/opik-backend/src/main/java/com/comet/opik/domain/PromptVersionDAO.java create mode 100644 apps/opik-backend/src/main/resources/liquibase/db-app-state/migrations/000005_increate_prompt_version_commit_length.sql create mode 100644 apps/opik-backend/src/test/java/com/comet/opik/api/resources/v1/priv/PromptResourceTest.java create mode 100644 apps/opik-backend/src/test/java/com/comet/opik/domain/EntityConstraintHandlerTest.java diff --git a/apps/opik-backend/src/main/java/com/comet/opik/api/Prompt.java b/apps/opik-backend/src/main/java/com/comet/opik/api/Prompt.java new file mode 100644 index 0000000000..a726fb4f4e --- /dev/null +++ b/apps/opik-backend/src/main/java/com/comet/opik/api/Prompt.java @@ -0,0 +1,56 @@ +package com.comet.opik.api; + +import com.fasterxml.jackson.annotation.JsonIgnoreProperties; +import com.fasterxml.jackson.annotation.JsonView; +import com.fasterxml.jackson.databind.PropertyNamingStrategies; +import com.fasterxml.jackson.databind.annotation.JsonNaming; +import io.swagger.v3.oas.annotations.media.Schema; +import jakarta.annotation.Nullable; +import jakarta.validation.constraints.NotBlank; +import jakarta.validation.constraints.Pattern; +import lombok.Builder; + +import java.time.Instant; +import java.util.List; +import java.util.UUID; + +import static com.comet.opik.utils.ValidationUtils.NULL_OR_NOT_BLANK; + +@Builder(toBuilder = true) +@JsonIgnoreProperties(ignoreUnknown = true) +@JsonNaming(PropertyNamingStrategies.SnakeCaseStrategy.class) +public record Prompt( + @JsonView( { + Prompt.View.Public.class, Prompt.View.Write.class}) UUID id, + @JsonView({Prompt.View.Public.class, Prompt.View.Write.class}) @NotBlank String name, + @JsonView({Prompt.View.Public.class, + Prompt.View.Write.class}) @Pattern(regexp = NULL_OR_NOT_BLANK, message = "must not be blank") String description, + @JsonView({ + Prompt.View.Write.class}) @Pattern(regexp = NULL_OR_NOT_BLANK, message = "must not be blank") @Nullable String template, + @JsonView({Prompt.View.Public.class}) @Schema(accessMode = Schema.AccessMode.READ_ONLY) Instant createdAt, + @JsonView({Prompt.View.Public.class}) @Schema(accessMode = Schema.AccessMode.READ_ONLY) String createdBy, + @JsonView({Prompt.View.Public.class}) @Schema(accessMode = Schema.AccessMode.READ_ONLY) Instant lastUpdatedAt, + @JsonView({Prompt.View.Public.class}) @Schema(accessMode = Schema.AccessMode.READ_ONLY) String lastUpdatedBy){ + + public static class View { + public static class Write { + } + + public static class Public { + } + } + + public record PromptPage( + @JsonView( { + Project.View.Public.class}) int page, + @JsonView({Project.View.Public.class}) int size, + @JsonView({Project.View.Public.class}) long total, + @JsonView({Project.View.Public.class}) List content) + implements + Page{ + + public static Prompt.PromptPage empty(int page) { + return new Prompt.PromptPage(page, 0, 0, List.of()); + } + } +} diff --git a/apps/opik-backend/src/main/java/com/comet/opik/api/PromptVersion.java b/apps/opik-backend/src/main/java/com/comet/opik/api/PromptVersion.java new file mode 100644 index 0000000000..cef10c243a --- /dev/null +++ b/apps/opik-backend/src/main/java/com/comet/opik/api/PromptVersion.java @@ -0,0 +1,58 @@ +package com.comet.opik.api; + +import com.fasterxml.jackson.annotation.JsonIgnoreProperties; +import com.fasterxml.jackson.annotation.JsonView; +import com.fasterxml.jackson.databind.PropertyNamingStrategies; +import com.fasterxml.jackson.databind.annotation.JsonNaming; +import io.swagger.v3.oas.annotations.media.Schema; +import jakarta.annotation.Nullable; +import jakarta.validation.constraints.NotNull; +import lombok.Builder; + +import java.time.Instant; +import java.util.List; +import java.util.Set; +import java.util.UUID; + +@Builder(toBuilder = true) +@JsonIgnoreProperties(ignoreUnknown = true) +@JsonNaming(PropertyNamingStrategies.SnakeCaseStrategy.class) +public record PromptVersion( + @JsonView( { + PromptVersion.View.Public.class, + PromptVersion.View.Detail.class}) @Schema(description = "version unique identifier, generated if absent") UUID id, + @JsonView({PromptVersion.View.Public.class, + PromptVersion.View.Detail.class}) @Schema(accessMode = Schema.AccessMode.READ_ONLY) UUID promptId, + @JsonView({PromptVersion.View.Public.class, + PromptVersion.View.Detail.class}) @Schema(description = "version short unique identifier, generated if absent") String commit, + @JsonView({PromptVersion.View.Detail.class}) @NotNull String template, + @JsonView({ + PromptVersion.View.Detail.class}) @Schema(accessMode = Schema.AccessMode.READ_ONLY) @Nullable Set variables, + @JsonView({PromptVersion.View.Public.class, + PromptVersion.View.Detail.class}) @Schema(accessMode = Schema.AccessMode.READ_ONLY) Instant createdAt, + @JsonView({PromptVersion.View.Public.class, + PromptVersion.View.Detail.class}) @Schema(accessMode = Schema.AccessMode.READ_ONLY) String createdBy){ + + public static class View { + public static class Public { + } + + public static class Detail { + } + } + + @Builder + public record PromptVersionPage( + @JsonView( { + PromptVersion.View.Public.class}) int page, + @JsonView({PromptVersion.View.Public.class}) int size, + @JsonView({PromptVersion.View.Public.class}) long total, + @JsonView({PromptVersion.View.Public.class}) List content) + implements + Page{ + + public static PromptVersion.PromptVersionPage empty(int page) { + return new PromptVersion.PromptVersionPage(page, 0, 0, List.of()); + } + } +} \ No newline at end of file diff --git a/apps/opik-backend/src/main/java/com/comet/opik/api/error/EntityAlreadyExistsException.java b/apps/opik-backend/src/main/java/com/comet/opik/api/error/EntityAlreadyExistsException.java index df74e5e107..b8a16bddd4 100644 --- a/apps/opik-backend/src/main/java/com/comet/opik/api/error/EntityAlreadyExistsException.java +++ b/apps/opik-backend/src/main/java/com/comet/opik/api/error/EntityAlreadyExistsException.java @@ -6,6 +6,14 @@ public class EntityAlreadyExistsException extends ClientErrorException { public EntityAlreadyExistsException(ErrorMessage response) { - super(Response.status(Response.Status.CONFLICT).entity(response).build()); + this((Object) response); + } + + public EntityAlreadyExistsException(io.dropwizard.jersey.errors.ErrorMessage response) { + this((Object) response); + } + + private EntityAlreadyExistsException(Object entity) { + super(Response.status(Response.Status.CONFLICT).entity(entity).build()); } } diff --git a/apps/opik-backend/src/main/java/com/comet/opik/api/resources/v1/priv/ProjectsResource.java b/apps/opik-backend/src/main/java/com/comet/opik/api/resources/v1/priv/ProjectsResource.java index 186d52022b..ac0a722297 100644 --- a/apps/opik-backend/src/main/java/com/comet/opik/api/resources/v1/priv/ProjectsResource.java +++ b/apps/opik-backend/src/main/java/com/comet/opik/api/resources/v1/priv/ProjectsResource.java @@ -95,7 +95,7 @@ public Response getById(@PathParam("id") UUID id) { } @POST - @Operation(operationId = "createProject", summary = "Create project", description = "Get project", responses = { + @Operation(operationId = "createProject", summary = "Create project", description = "Create project", responses = { @ApiResponse(responseCode = "201", description = "Created", headers = { @Header(name = "Location", required = true, example = "${basePath}/v1/private/projects/{projectId}", schema = @Schema(implementation = String.class))}), @ApiResponse(responseCode = "422", description = "Unprocessable Content", content = @Content(schema = @Schema(implementation = ErrorMessage.class))), diff --git a/apps/opik-backend/src/main/java/com/comet/opik/api/resources/v1/priv/PromptResource.java b/apps/opik-backend/src/main/java/com/comet/opik/api/resources/v1/priv/PromptResource.java new file mode 100644 index 0000000000..249bcf6232 --- /dev/null +++ b/apps/opik-backend/src/main/java/com/comet/opik/api/resources/v1/priv/PromptResource.java @@ -0,0 +1,68 @@ +package com.comet.opik.api.resources.v1.priv; + +import com.codahale.metrics.annotation.Timed; +import com.comet.opik.api.Prompt; +import com.comet.opik.api.error.ErrorMessage; +import com.comet.opik.domain.PromptService; +import com.comet.opik.infrastructure.auth.RequestContext; +import com.comet.opik.infrastructure.ratelimit.RateLimited; +import com.fasterxml.jackson.annotation.JsonView; +import io.swagger.v3.oas.annotations.Operation; +import io.swagger.v3.oas.annotations.headers.Header; +import io.swagger.v3.oas.annotations.media.Content; +import io.swagger.v3.oas.annotations.media.Schema; +import io.swagger.v3.oas.annotations.parameters.RequestBody; +import io.swagger.v3.oas.annotations.responses.ApiResponse; +import jakarta.inject.Inject; +import jakarta.inject.Provider; +import jakarta.validation.Valid; +import jakarta.ws.rs.Consumes; +import jakarta.ws.rs.POST; +import jakarta.ws.rs.Path; +import jakarta.ws.rs.Produces; +import jakarta.ws.rs.core.Context; +import jakarta.ws.rs.core.MediaType; +import jakarta.ws.rs.core.Response; +import jakarta.ws.rs.core.UriInfo; +import lombok.NonNull; +import lombok.RequiredArgsConstructor; +import lombok.extern.slf4j.Slf4j; + +@Path("/v1/private/prompts") +@Produces(MediaType.APPLICATION_JSON) +@Consumes(MediaType.APPLICATION_JSON) +@Timed +@Slf4j +@RequiredArgsConstructor(onConstructor_ = @Inject) +public class PromptResource { + + private final @NonNull Provider requestContext; + private final @NonNull PromptService promptService; + + @POST + @Operation(operationId = "createPrompt", summary = "Create prompt", description = "Create prompt", responses = { + @ApiResponse(responseCode = "201", description = "Created", headers = { + @Header(name = "Location", required = true, example = "${basePath}/v1/private/prompts/{promptId}", schema = @Schema(implementation = String.class))}), + @ApiResponse(responseCode = "422", description = "Unprocessable Content", content = @Content(schema = @Schema(implementation = ErrorMessage.class))), + @ApiResponse(responseCode = "400", description = "Bad Request", content = @Content(schema = @Schema(implementation = ErrorMessage.class))), + @ApiResponse(responseCode = "409", description = "Conflict", content = @Content(schema = @Schema(implementation = io.dropwizard.jersey.errors.ErrorMessage.class))), + + }) + @RateLimited + public Response createPrompt( + @RequestBody(content = @Content(schema = @Schema(implementation = Prompt.class))) @JsonView(Prompt.View.Write.class) @Valid Prompt prompt, + @Context UriInfo uriInfo) { + + String workspaceId = requestContext.get().getWorkspaceId(); + + log.info("Creating prompt with name '{}', on workspace_id '{}'", prompt.name(), workspaceId); + prompt = promptService.create(prompt); + log.info("Prompt created with id '{}' name '{}', on workspace_id '{}'", prompt.id(), prompt.name(), + workspaceId); + + var resourceUri = uriInfo.getAbsolutePathBuilder().path("/%s".formatted(prompt.id())).build(); + + return Response.created(resourceUri).build(); + } + +} diff --git a/apps/opik-backend/src/main/java/com/comet/opik/domain/CommitUtils.java b/apps/opik-backend/src/main/java/com/comet/opik/domain/CommitUtils.java new file mode 100644 index 0000000000..7253d3ee98 --- /dev/null +++ b/apps/opik-backend/src/main/java/com/comet/opik/domain/CommitUtils.java @@ -0,0 +1,14 @@ +package com.comet.opik.domain; + +import lombok.NonNull; +import lombok.experimental.UtilityClass; + +import java.util.UUID; + +@UtilityClass +class CommitUtils { + + public String getCommit(@NonNull UUID id) { + return id.toString().substring(id.toString().length() - 8); + } +} diff --git a/apps/opik-backend/src/main/java/com/comet/opik/domain/DatasetDAO.java b/apps/opik-backend/src/main/java/com/comet/opik/domain/DatasetDAO.java index c687dc72a9..d96b85b49d 100644 --- a/apps/opik-backend/src/main/java/com/comet/opik/domain/DatasetDAO.java +++ b/apps/opik-backend/src/main/java/com/comet/opik/domain/DatasetDAO.java @@ -81,6 +81,7 @@ List find(@Bind("limit") int limit, Optional findByName(@Bind("workspace_id") String workspaceId, @Bind("name") String name); @SqlBatch("UPDATE datasets SET last_created_experiment_at = :experimentCreatedAt WHERE id = :datasetId AND workspace_id = :workspace_id") - int[] recordExperiments(@Bind("workspace_id") String workspaceId, @BindMethods Collection datasets); + int[] recordExperiments(@Bind("workspace_id") String workspaceId, + @BindMethods Collection datasets); } diff --git a/apps/opik-backend/src/main/java/com/comet/opik/domain/DatasetService.java b/apps/opik-backend/src/main/java/com/comet/opik/domain/DatasetService.java index 4002a9c59c..48df7716fe 100644 --- a/apps/opik-backend/src/main/java/com/comet/opik/domain/DatasetService.java +++ b/apps/opik-backend/src/main/java/com/comet/opik/domain/DatasetService.java @@ -278,7 +278,8 @@ private List enrichDatasetWithAdditionalInformation(List datas return datasets.stream() .map(dataset -> { var resume = experimentSummary.computeIfAbsent(dataset.id(), ExperimentSummary::empty); - var datasetItemSummary = datasetItemSummaryMap.computeIfAbsent(dataset.id(), DatasetItemSummary::empty); + var datasetItemSummary = datasetItemSummaryMap.computeIfAbsent(dataset.id(), + DatasetItemSummary::empty); return dataset.toBuilder() .experimentCount(resume.experimentCount()) diff --git a/apps/opik-backend/src/main/java/com/comet/opik/domain/EntityConstraintHandler.java b/apps/opik-backend/src/main/java/com/comet/opik/domain/EntityConstraintHandler.java new file mode 100644 index 0000000000..a0919e854f --- /dev/null +++ b/apps/opik-backend/src/main/java/com/comet/opik/domain/EntityConstraintHandler.java @@ -0,0 +1,60 @@ +package com.comet.opik.domain; + +import com.comet.opik.api.error.EntityAlreadyExistsException; +import com.google.common.base.Preconditions; +import org.jdbi.v3.core.statement.UnableToExecuteStatementException; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.sql.SQLIntegrityConstraintViolationException; +import java.util.function.Supplier; + +interface EntityConstraintHandler { + + Logger log = LoggerFactory.getLogger(EntityConstraintHandler.class); + + static EntityConstraintHandler handle(EntityConstraintAction entityAction) { + return () -> entityAction; + } + + interface EntityConstraintAction { + T execute(); + } + + EntityConstraintAction wrappedAction(); + + default T withError(Supplier errorProvider) { + try { + return wrappedAction().execute(); + } catch (UnableToExecuteStatementException e) { + if (e.getCause() instanceof SQLIntegrityConstraintViolationException) { + throw errorProvider.get(); + } else { + throw e; + } + } + } + + default T withRetry(int times, Supplier errorProvider) { + Preconditions.checkArgument(times > 0, "Retry times must be greater than 0"); + + return internalRetry(times, errorProvider); + } + + private T internalRetry(int times, Supplier errorProvider) { + try { + return wrappedAction().execute(); + } catch (UnableToExecuteStatementException e) { + if (e.getCause() instanceof SQLIntegrityConstraintViolationException) { + if (times > 0) { + log.warn("Retrying due to constraint violation, remaining attempts: {}", times); + return internalRetry(times - 1, errorProvider); + } + throw errorProvider.get(); + } else { + throw e; + } + } + } + +} diff --git a/apps/opik-backend/src/main/java/com/comet/opik/domain/PromptDAO.java b/apps/opik-backend/src/main/java/com/comet/opik/domain/PromptDAO.java new file mode 100644 index 0000000000..6bfaee9d6b --- /dev/null +++ b/apps/opik-backend/src/main/java/com/comet/opik/domain/PromptDAO.java @@ -0,0 +1,25 @@ +package com.comet.opik.domain; + +import com.comet.opik.api.Prompt; +import com.comet.opik.infrastructure.db.UUIDArgumentFactory; +import org.jdbi.v3.sqlobject.config.RegisterArgumentFactory; +import org.jdbi.v3.sqlobject.config.RegisterConstructorMapper; +import org.jdbi.v3.sqlobject.customizer.Bind; +import org.jdbi.v3.sqlobject.customizer.BindMethods; +import org.jdbi.v3.sqlobject.statement.SqlQuery; +import org.jdbi.v3.sqlobject.statement.SqlUpdate; + +import java.util.UUID; + +@RegisterConstructorMapper(Prompt.class) +@RegisterArgumentFactory(UUIDArgumentFactory.class) +interface PromptDAO { + + @SqlUpdate("INSERT INTO prompts (id, name, description, created_by, last_updated_by, workspace_id) " + + "VALUES (:bean.id, :bean.name, :bean.description, :bean.createdBy, :bean.lastUpdatedBy, :workspaceId)") + void save(@Bind("workspaceId") String workspaceId, @BindMethods("bean") Prompt prompt); + + @SqlQuery("SELECT * FROM prompts WHERE id = :id AND workspace_id = :workspaceId") + Prompt findById(@Bind("id") UUID id, @Bind("workspaceId") String workspaceId); + +} diff --git a/apps/opik-backend/src/main/java/com/comet/opik/domain/PromptService.java b/apps/opik-backend/src/main/java/com/comet/opik/domain/PromptService.java new file mode 100644 index 0000000000..9207435bf2 --- /dev/null +++ b/apps/opik-backend/src/main/java/com/comet/opik/domain/PromptService.java @@ -0,0 +1,118 @@ +package com.comet.opik.domain; + +import com.comet.opik.api.Prompt; +import com.comet.opik.api.PromptVersion; +import com.comet.opik.api.error.EntityAlreadyExistsException; +import com.comet.opik.infrastructure.auth.RequestContext; +import com.google.inject.ImplementedBy; +import io.dropwizard.jersey.errors.ErrorMessage; +import jakarta.inject.Inject; +import jakarta.inject.Provider; +import jakarta.inject.Singleton; +import lombok.NonNull; +import lombok.RequiredArgsConstructor; +import lombok.extern.slf4j.Slf4j; +import org.apache.commons.lang3.StringUtils; +import ru.vyarus.guicey.jdbi3.tx.TransactionTemplate; + +import java.util.UUID; + +import static com.comet.opik.infrastructure.db.TransactionTemplateAsync.WRITE; + +@ImplementedBy(PromptServiceImpl.class) +public interface PromptService { + Prompt create(Prompt prompt); + +} + +@Singleton +@Slf4j +@RequiredArgsConstructor(onConstructor_ = @Inject) +class PromptServiceImpl implements PromptService { + + private static final String ALREADY_EXISTS = "Prompt id or name already exists"; + private static final String VERSION_ALREADY_EXISTS = "Prompt version already exists"; + private final @NonNull Provider requestContext; + private final @NonNull IdGenerator idGenerator; + private final @NonNull TransactionTemplate transactionTemplate; + + @Override + public Prompt create(Prompt prompt) { + + String workspaceId = requestContext.get().getWorkspaceId(); + String userName = requestContext.get().getUserName(); + + var newPrompt = prompt.toBuilder() + .id(prompt.id() == null ? idGenerator.generateId() : prompt.id()) + .createdBy(userName) + .lastUpdatedBy(userName) + .build(); + + IdGenerator.validateVersion(prompt.id(), "prompt"); + + var createdPrompt = EntityConstraintHandler + .handle(() -> savePrompt(workspaceId, newPrompt)) + .withError(this::newPromptConflict); + + log.info("Prompt created with id '{}' name '{}', on workspace_id '{}'", createdPrompt.id(), + createdPrompt.name(), + workspaceId); + + if (!StringUtils.isEmpty(prompt.template())) { + EntityConstraintHandler + .handle(() -> createPromptVersionFromPromptRequest(prompt, createdPrompt, workspaceId)) + .withRetry(3, this::newVersionConflict); + } + + return createdPrompt; + } + + private PromptVersion createPromptVersionFromPromptRequest(Prompt prompt, Prompt createdPrompt, + String workspaceId) { + log.info("Creating prompt version for prompt id '{}'", createdPrompt.id()); + + var createdVersion = transactionTemplate.inTransaction(WRITE, handle -> { + PromptVersionDAO promptVersionDAO = handle.attach(PromptVersionDAO.class); + + UUID versionId = idGenerator.generateId(); + PromptVersion promptVersion = PromptVersion.builder() + .id(versionId) + .promptId(createdPrompt.id()) + .commit(CommitUtils.getCommit(versionId)) + .template(prompt.template()) + .createdBy(createdPrompt.createdBy()) + .build(); + + promptVersionDAO.save(workspaceId, promptVersion); + + return promptVersionDAO.findById(versionId, workspaceId); + }); + + log.info("Created Prompt version for prompt id '{}'", createdPrompt.id()); + + return createdVersion; + } + + private Prompt savePrompt(String workspaceId, Prompt newPrompt) { + return transactionTemplate.inTransaction(WRITE, handle -> { + PromptDAO promptDAO = handle.attach(PromptDAO.class); + + promptDAO.save(workspaceId, newPrompt); + + return promptDAO.findById(newPrompt.id(), workspaceId); + }); + } + + private EntityAlreadyExistsException newConflict(String alreadyExists) { + log.info(alreadyExists); + return new EntityAlreadyExistsException(new ErrorMessage(alreadyExists)); + } + + private EntityAlreadyExistsException newVersionConflict() { + return newConflict(VERSION_ALREADY_EXISTS); + } + + private EntityAlreadyExistsException newPromptConflict() { + return newConflict(ALREADY_EXISTS); + } +} diff --git a/apps/opik-backend/src/main/java/com/comet/opik/domain/PromptVersionDAO.java b/apps/opik-backend/src/main/java/com/comet/opik/domain/PromptVersionDAO.java new file mode 100644 index 0000000000..f1f46fb000 --- /dev/null +++ b/apps/opik-backend/src/main/java/com/comet/opik/domain/PromptVersionDAO.java @@ -0,0 +1,25 @@ +package com.comet.opik.domain; + +import com.comet.opik.api.PromptVersion; +import com.comet.opik.infrastructure.db.UUIDArgumentFactory; +import org.jdbi.v3.sqlobject.config.RegisterArgumentFactory; +import org.jdbi.v3.sqlobject.config.RegisterConstructorMapper; +import org.jdbi.v3.sqlobject.customizer.Bind; +import org.jdbi.v3.sqlobject.customizer.BindMethods; +import org.jdbi.v3.sqlobject.statement.SqlQuery; +import org.jdbi.v3.sqlobject.statement.SqlUpdate; + +import java.util.UUID; + +@RegisterConstructorMapper(PromptVersion.class) +@RegisterArgumentFactory(UUIDArgumentFactory.class) +interface PromptVersionDAO { + + @SqlUpdate("INSERT INTO prompt_versions (id, prompt_id, commit, template, created_by, workspace_id) " + + "VALUES (:bean.id, :bean.promptId, :bean.commit, :bean.template, :bean.createdBy, :workspace_id)") + void save(@Bind("workspace_id") String workspaceId, @BindMethods("bean") PromptVersion prompt); + + @SqlQuery("SELECT * FROM prompt_versions WHERE id = :id AND workspace_id = :workspace_id") + PromptVersion findById(@Bind("id") UUID id, @Bind("workspace_id") String workspaceId); + +} diff --git a/apps/opik-backend/src/main/resources/liquibase/db-app-state/migrations/000005_increate_prompt_version_commit_length.sql b/apps/opik-backend/src/main/resources/liquibase/db-app-state/migrations/000005_increate_prompt_version_commit_length.sql new file mode 100644 index 0000000000..021459eb50 --- /dev/null +++ b/apps/opik-backend/src/main/resources/liquibase/db-app-state/migrations/000005_increate_prompt_version_commit_length.sql @@ -0,0 +1,6 @@ +--liquibase formatted sql +--changeset thiagohora:increate_prompt_version_commit_length + +ALTER TABLE prompt_versions MODIFY COLUMN commit VARCHAR(8); + +--rollback ALTER TABLE prompt_versions MODIFY COLUMN commit VARCHAR(7); diff --git a/apps/opik-backend/src/test/java/com/comet/opik/api/resources/v1/events/DatasetEventListenerTest.java b/apps/opik-backend/src/test/java/com/comet/opik/api/resources/v1/events/DatasetEventListenerTest.java index ca55c88d16..73b585a627 100644 --- a/apps/opik-backend/src/test/java/com/comet/opik/api/resources/v1/events/DatasetEventListenerTest.java +++ b/apps/opik-backend/src/test/java/com/comet/opik/api/resources/v1/events/DatasetEventListenerTest.java @@ -51,7 +51,6 @@ class DatasetEventListenerTest { private static final String BASE_RESOURCE_URI = "%s/v1/private/datasets"; private static final String EXPERIMENT_RESOURCE_URI = "%s/v1/private/experiments"; - private static final String API_KEY = UUID.randomUUID().toString(); private static final String USER = UUID.randomUUID().toString(); private static final String WORKSPACE_ID = UUID.randomUUID().toString(); diff --git a/apps/opik-backend/src/test/java/com/comet/opik/api/resources/v1/priv/DatasetExperimentE2ETest.java b/apps/opik-backend/src/test/java/com/comet/opik/api/resources/v1/priv/DatasetExperimentE2ETest.java index 8f122ef488..37e907ff6a 100644 --- a/apps/opik-backend/src/test/java/com/comet/opik/api/resources/v1/priv/DatasetExperimentE2ETest.java +++ b/apps/opik-backend/src/test/java/com/comet/opik/api/resources/v1/priv/DatasetExperimentE2ETest.java @@ -47,7 +47,7 @@ import static org.assertj.core.api.Assertions.within; @TestInstance(TestInstance.Lifecycle.PER_CLASS) -@DisplayName("Dataset Event Listener") +@DisplayName("Dataset Experiments E2E Test") class DatasetExperimentE2ETest { private static final String BASE_RESOURCE_URI = "%s/v1/private/datasets"; diff --git a/apps/opik-backend/src/test/java/com/comet/opik/api/resources/v1/priv/DatasetsResourceTest.java b/apps/opik-backend/src/test/java/com/comet/opik/api/resources/v1/priv/DatasetsResourceTest.java index e5c2b9224b..b3f0910bbf 100644 --- a/apps/opik-backend/src/test/java/com/comet/opik/api/resources/v1/priv/DatasetsResourceTest.java +++ b/apps/opik-backend/src/test/java/com/comet/opik/api/resources/v1/priv/DatasetsResourceTest.java @@ -136,7 +136,8 @@ class DatasetsResourceTest { public static final String[] IGNORED_FIELDS_DATA_ITEM = {"createdAt", "lastUpdatedAt", "experimentItems", "createdBy", "lastUpdatedBy"}; public static final String[] DATASET_IGNORED_FIELDS = {"id", "createdAt", "lastUpdatedAt", "createdBy", - "lastUpdatedBy", "experimentCount", "mostRecentExperimentAt", "lastCreatedExperimentAt", "datasetItemsCount"}; + "lastUpdatedBy", "experimentCount", "mostRecentExperimentAt", "lastCreatedExperimentAt", + "datasetItemsCount"}; public static final String API_KEY = UUID.randomUUID().toString(); private static final String USER = UUID.randomUUID().toString(); diff --git a/apps/opik-backend/src/test/java/com/comet/opik/api/resources/v1/priv/PromptResourceTest.java b/apps/opik-backend/src/test/java/com/comet/opik/api/resources/v1/priv/PromptResourceTest.java new file mode 100644 index 0000000000..ec87ae95c9 --- /dev/null +++ b/apps/opik-backend/src/test/java/com/comet/opik/api/resources/v1/priv/PromptResourceTest.java @@ -0,0 +1,317 @@ +package com.comet.opik.api.resources.v1.priv; + +import com.comet.opik.api.Prompt; +import com.comet.opik.api.error.ErrorMessage; +import com.comet.opik.api.resources.utils.AuthTestUtils; +import com.comet.opik.api.resources.utils.ClickHouseContainerUtils; +import com.comet.opik.api.resources.utils.ClientSupportUtils; +import com.comet.opik.api.resources.utils.MigrationUtils; +import com.comet.opik.api.resources.utils.MySQLContainerUtils; +import com.comet.opik.api.resources.utils.RedisContainerUtils; +import com.comet.opik.api.resources.utils.TestDropwizardAppExtensionUtils; +import com.comet.opik.api.resources.utils.TestUtils; +import com.comet.opik.api.resources.utils.WireMockUtils; +import com.comet.opik.infrastructure.DatabaseAnalyticsFactory; +import com.comet.opik.infrastructure.auth.RequestContext; +import com.comet.opik.podam.PodamFactoryUtils; +import com.github.tomakehurst.wiremock.client.WireMock; +import com.redis.testcontainers.RedisContainer; +import jakarta.ws.rs.client.Entity; +import jakarta.ws.rs.core.HttpHeaders; +import jakarta.ws.rs.core.MediaType; +import org.jdbi.v3.core.Jdbi; +import org.junit.jupiter.api.AfterAll; +import org.junit.jupiter.api.BeforeAll; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.DisplayName; +import org.junit.jupiter.api.Nested; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.TestInstance; +import org.junit.jupiter.api.extension.RegisterExtension; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.Arguments; +import org.junit.jupiter.params.provider.MethodSource; +import org.testcontainers.clickhouse.ClickHouseContainer; +import org.testcontainers.containers.MySQLContainer; +import org.testcontainers.junit.jupiter.Testcontainers; +import org.testcontainers.lifecycle.Startables; +import ru.vyarus.dropwizard.guice.test.ClientSupport; +import ru.vyarus.dropwizard.guice.test.jupiter.ext.TestDropwizardAppExtension; +import uk.co.jemos.podam.api.PodamFactory; + +import java.sql.SQLException; +import java.util.List; +import java.util.UUID; +import java.util.stream.Stream; + +import static com.comet.opik.api.resources.utils.ClickHouseContainerUtils.DATABASE_NAME; +import static com.comet.opik.api.resources.utils.MigrationUtils.CLICKHOUSE_CHANGELOG_FILE; +import static com.comet.opik.infrastructure.auth.RequestContext.SESSION_COOKIE; +import static com.comet.opik.infrastructure.auth.RequestContext.WORKSPACE_HEADER; +import static com.comet.opik.infrastructure.auth.TestHttpClientUtils.UNAUTHORIZED_RESPONSE; +import static com.github.tomakehurst.wiremock.client.WireMock.equalTo; +import static com.github.tomakehurst.wiremock.client.WireMock.matching; +import static com.github.tomakehurst.wiremock.client.WireMock.matchingJsonPath; +import static com.github.tomakehurst.wiremock.client.WireMock.okJson; +import static com.github.tomakehurst.wiremock.client.WireMock.post; +import static com.github.tomakehurst.wiremock.client.WireMock.urlPathEqualTo; +import static org.assertj.core.api.Assertions.assertThat; +import static org.junit.jupiter.params.provider.Arguments.arguments; + +@Testcontainers(parallel = true) +@TestInstance(TestInstance.Lifecycle.PER_CLASS) +@DisplayName("Prompt Resource Test") +class PromptResourceTest { + + private static final String RESOURCE_PATH = "%s/v1/private/prompts"; + + private static final String API_KEY = UUID.randomUUID().toString(); + private static final String USER = UUID.randomUUID().toString(); + private static final String WORKSPACE_ID = UUID.randomUUID().toString(); + private static final String TEST_WORKSPACE = UUID.randomUUID().toString(); + + private static final RedisContainer REDIS = RedisContainerUtils.newRedisContainer(); + private static final ClickHouseContainer CLICKHOUSE_CONTAINER = ClickHouseContainerUtils.newClickHouseContainer(); + private static final MySQLContainer MYSQL = MySQLContainerUtils.newMySQLContainer(); + + @RegisterExtension + private static final TestDropwizardAppExtension app; + + private static final WireMockUtils.WireMockRuntime wireMock; + + static { + Startables.deepStart(REDIS, CLICKHOUSE_CONTAINER, MYSQL).join(); + wireMock = WireMockUtils.startWireMock(); + + DatabaseAnalyticsFactory databaseAnalyticsFactory = ClickHouseContainerUtils + .newDatabaseAnalyticsFactory(CLICKHOUSE_CONTAINER, DATABASE_NAME); + + app = TestDropwizardAppExtensionUtils.newTestDropwizardAppExtension( + MYSQL.getJdbcUrl(), databaseAnalyticsFactory, wireMock.runtimeInfo(), REDIS.getRedisURI()); + } + + private final PodamFactory factory = PodamFactoryUtils.newPodamFactory(); + + private String baseURI; + private ClientSupport client; + + @BeforeAll + void setUpAll(ClientSupport client, Jdbi jdbi) throws SQLException { + + MigrationUtils.runDbMigration(jdbi, MySQLContainerUtils.migrationParameters()); + + try (var connection = CLICKHOUSE_CONTAINER.createConnection("")) { + MigrationUtils.runDbMigration(connection, CLICKHOUSE_CHANGELOG_FILE, + ClickHouseContainerUtils.migrationParameters()); + } + + this.baseURI = "http://localhost:%d".formatted(client.getPort()); + this.client = client; + + ClientSupportUtils.config(client); + + mockTargetWorkspace(API_KEY, TEST_WORKSPACE, WORKSPACE_ID); + } + + private static void mockTargetWorkspace(String apiKey, String workspaceName, String workspaceId) { + AuthTestUtils.mockTargetWorkspace(wireMock.server(), apiKey, workspaceName, workspaceId, USER); + } + + @AfterAll + void tearDownAll() { + wireMock.server().stop(); + } + + @Nested + @DisplayName("Api Key Authentication:") + @TestInstance(TestInstance.Lifecycle.PER_CLASS) + class ApiKey { + + private final String fakeApikey = UUID.randomUUID().toString(); + private final String okApikey = UUID.randomUUID().toString(); + + Stream credentials() { + return Stream.of( + arguments(okApikey, true), + arguments(fakeApikey, false), + arguments("", false)); + } + + @BeforeEach + void setUp() { + + wireMock.server().stubFor( + post(urlPathEqualTo("/opik/auth")) + .withHeader(HttpHeaders.AUTHORIZATION, equalTo(fakeApikey)) + .withRequestBody(matchingJsonPath("$.workspaceName", matching(".+"))) + .willReturn(WireMock.unauthorized())); + + wireMock.server().stubFor( + post(urlPathEqualTo("/opik/auth")) + .withHeader(HttpHeaders.AUTHORIZATION, equalTo("")) + .withRequestBody(matchingJsonPath("$.workspaceName", matching(".+"))) + .willReturn(WireMock.unauthorized())); + } + + @ParameterizedTest + @MethodSource("credentials") + @DisplayName("create prompt: when api key is present, then return proper response") + void createPrompt__whenApiKeyIsPresent__thenReturnProperResponse(String apiKey, boolean success) { + + var prompt = factory.manufacturePojo(Prompt.class); + String workspaceName = UUID.randomUUID().toString(); + + mockTargetWorkspace(okApikey, workspaceName, WORKSPACE_ID); + + try (var actualResponse = client.target(RESOURCE_PATH.formatted(baseURI)) + .request() + .accept(MediaType.APPLICATION_JSON_TYPE) + .header(HttpHeaders.AUTHORIZATION, apiKey) + .header(WORKSPACE_HEADER, workspaceName) + .post(Entity.entity(prompt, MediaType.APPLICATION_JSON_TYPE))) { + + if (success) { + assertThat(actualResponse.getStatusInfo().getStatusCode()).isEqualTo(201); + assertThat(actualResponse.hasEntity()).isFalse(); + } else { + assertThat(actualResponse.getStatusInfo().getStatusCode()).isEqualTo(401); + assertThat(actualResponse.hasEntity()).isTrue(); + assertThat(actualResponse.readEntity(io.dropwizard.jersey.errors.ErrorMessage.class)) + .isEqualTo(UNAUTHORIZED_RESPONSE); + } + } + } + + } + + @Nested + @DisplayName("Session Token Authentication:") + @TestInstance(TestInstance.Lifecycle.PER_CLASS) + class SessionTokenCookie { + + private final String sessionToken = UUID.randomUUID().toString(); + private final String fakeSessionToken = UUID.randomUUID().toString(); + + Stream credentials() { + return Stream.of( + arguments(sessionToken, true, "OK_" + UUID.randomUUID()), + arguments(fakeSessionToken, false, UUID.randomUUID().toString())); + } + + @BeforeAll + void setUp() { + wireMock.server().stubFor( + post(urlPathEqualTo("/opik/auth-session")) + .withCookie(SESSION_COOKIE, equalTo(sessionToken)) + .withRequestBody(matchingJsonPath("$.workspaceName", matching("OK_.+"))) + .willReturn(okJson(AuthTestUtils.newWorkspaceAuthResponse(USER, WORKSPACE_ID)))); + + wireMock.server().stubFor( + post(urlPathEqualTo("/opik/auth-session")) + .withCookie(SESSION_COOKIE, equalTo(fakeSessionToken)) + .withRequestBody(matchingJsonPath("$.workspaceName", matching(".+"))) + .willReturn(WireMock.unauthorized())); + } + + @ParameterizedTest + @MethodSource("credentials") + @DisplayName("create prompt: when session token is present, then return proper response") + void createPrompt__whenSessionTokenIsPresent__thenReturnProperResponse(String sessionToken, boolean success, + String workspaceName) { + var prompt = factory.manufacturePojo(Prompt.class); + + try (var actualResponse = client.target(RESOURCE_PATH.formatted(baseURI)).request() + .accept(MediaType.APPLICATION_JSON_TYPE) + .cookie(SESSION_COOKIE, sessionToken) + .header(WORKSPACE_HEADER, workspaceName) + .post(Entity.entity(prompt, MediaType.APPLICATION_JSON_TYPE))) { + + if (success) { + assertThat(actualResponse.getStatusInfo().getStatusCode()).isEqualTo(201); + assertThat(actualResponse.hasEntity()).isFalse(); + } else { + assertThat(actualResponse.getStatusInfo().getStatusCode()).isEqualTo(401); + assertThat(actualResponse.hasEntity()).isTrue(); + assertThat(actualResponse.readEntity(io.dropwizard.jersey.errors.ErrorMessage.class)) + .isEqualTo(UNAUTHORIZED_RESPONSE); + } + } + } + } + + private UUID createPrompt(Prompt prompt, String apiKey, String workspaceName) { + try (var response = client.target(RESOURCE_PATH.formatted(baseURI)) + .request() + .header(HttpHeaders.AUTHORIZATION, apiKey) + .header(RequestContext.WORKSPACE_HEADER, workspaceName) + .post(Entity.json(prompt))) { + + assertThat(response.getStatus()).isEqualTo(201); + + return TestUtils.getIdFromLocation(response.getLocation()); + } + } + + @Nested + @DisplayName("Create Prompt") + @TestInstance(TestInstance.Lifecycle.PER_CLASS) + class CreatePrompt { + + @Test + @DisplayName("Should create prompt") + void shouldCreatePrompt() { + + var prompt = factory.manufacturePojo(Prompt.class); + + var promptId = createPrompt(prompt, API_KEY, TEST_WORKSPACE); + + assertThat(promptId).isNotNull(); + } + + @ParameterizedTest + @MethodSource + @DisplayName("when prompt state is invalid, then return conflict") + void when__promptIsInvalid__thenReturnError(Prompt prompt, int expectedStatusCode, Object expectedBody, + Class expectedResponseClass) { + + try (var response = client.target(RESOURCE_PATH.formatted(baseURI)) + .request() + .header(HttpHeaders.AUTHORIZATION, API_KEY) + .header(RequestContext.WORKSPACE_HEADER, TEST_WORKSPACE) + .post(Entity.json(prompt))) { + + assertThat(response.getStatus()).isEqualTo(expectedStatusCode); + + var actualBody = response.readEntity(expectedResponseClass); + + assertThat(actualBody).isEqualTo(expectedBody); + } + } + + Stream when__promptIsInvalid__thenReturnError() { + Prompt prompt = factory.manufacturePojo(Prompt.class).toBuilder() + .id(UUID.randomUUID()) + .build(); + + Prompt duplicatedPrompt = factory.manufacturePojo(Prompt.class); + createPrompt(duplicatedPrompt, API_KEY, TEST_WORKSPACE); + + return Stream.of( + Arguments.of(prompt, 400, + new ErrorMessage(List.of("prompt id must be a version 7 UUID")), + ErrorMessage.class), + Arguments.of(duplicatedPrompt.toBuilder().name(UUID.randomUUID().toString()).build(), 409, + new io.dropwizard.jersey.errors.ErrorMessage("Prompt id or name already exists"), + io.dropwizard.jersey.errors.ErrorMessage.class), + Arguments.of(duplicatedPrompt.toBuilder().id(factory.manufacturePojo(UUID.class)).build(), 409, + new io.dropwizard.jersey.errors.ErrorMessage("Prompt id or name already exists"), + io.dropwizard.jersey.errors.ErrorMessage.class), + Arguments.of(factory.manufacturePojo(Prompt.class).toBuilder().description("").build(), 422, + new ErrorMessage(List.of("description must not be blank")), + ErrorMessage.class), + Arguments.of(factory.manufacturePojo(Prompt.class).toBuilder().name("").build(), 422, + new ErrorMessage(List.of("name must not be blank")), ErrorMessage.class)); + } + } + +} \ No newline at end of file diff --git a/apps/opik-backend/src/test/java/com/comet/opik/domain/EntityConstraintHandlerTest.java b/apps/opik-backend/src/test/java/com/comet/opik/domain/EntityConstraintHandlerTest.java new file mode 100644 index 0000000000..5a19c88d79 --- /dev/null +++ b/apps/opik-backend/src/test/java/com/comet/opik/domain/EntityConstraintHandlerTest.java @@ -0,0 +1,88 @@ +package com.comet.opik.domain; + +import com.comet.opik.api.error.EntityAlreadyExistsException; +import io.dropwizard.jersey.errors.ErrorMessage; +import org.jdbi.v3.core.statement.StatementContext; +import org.jdbi.v3.core.statement.UnableToExecuteStatementException; +import org.junit.jupiter.api.Test; +import org.mockito.Mockito; + +import java.sql.SQLIntegrityConstraintViolationException; +import java.util.function.Supplier; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertThrows; + +class EntityConstraintHandlerTest { + + private static final Supplier ENTITY_ALREADY_EXISTS = () -> new EntityAlreadyExistsException( + new ErrorMessage(409, "Entity already exists")); + + @Test + void testWithError() { + EntityConstraintHandler handler = EntityConstraintHandler.handle(() -> { + throwDuplicateEntryException(); + return null; + }); + + assertThrows(EntityAlreadyExistsException.class, () -> handler.withError(ENTITY_ALREADY_EXISTS)); + } + + private static void throwDuplicateEntryException() { + throw new UnableToExecuteStatementException(new SQLIntegrityConstraintViolationException( + "Duplicate entry '1' for key 'PRIMARY'"), Mockito.mock(StatementContext.class)); + } + + @Test + void testWithRetrySuccess() { + EntityConstraintHandler handler = EntityConstraintHandler.handle(() -> "Success"); + + assertEquals("Success", handler.withRetry(3, ENTITY_ALREADY_EXISTS)); + } + + @Test + void testWithRetryFailure() { + EntityConstraintHandler.EntityConstraintAction action = Mockito + .spy(new EntityConstraintHandler.EntityConstraintAction() { + @Override + public String execute() { + throwDuplicateEntryException(); + return ""; + } + }); + + EntityConstraintHandler handler = EntityConstraintHandler.handle(action); + + final int NUM_OF_RETRIES = 3; + + assertThrows(EntityAlreadyExistsException.class, + () -> handler.withRetry(NUM_OF_RETRIES, ENTITY_ALREADY_EXISTS)); + Mockito.verify(action, Mockito.times(NUM_OF_RETRIES + 1)).execute(); + } + + @Test + void testWithRetryExhausted() { + EntityConstraintHandler.EntityConstraintAction action = Mockito + .spy(new EntityConstraintHandler.EntityConstraintAction() { + @Override + public String execute() { + throwDuplicateEntryException(); + return ""; + } + }); + + EntityConstraintHandler handler = EntityConstraintHandler.handle(action); + + assertThrows(EntityAlreadyExistsException.class, () -> handler.withRetry(1, ENTITY_ALREADY_EXISTS)); + Mockito.verify(action, Mockito.times(2)).execute(); + } + + @Test + void testWithRetryNonConstraintViolation() { + EntityConstraintHandler handler = EntityConstraintHandler.handle(() -> { + throw new UnableToExecuteStatementException(new RuntimeException(), Mockito.mock(StatementContext.class)); + }); + + assertThrows(UnableToExecuteStatementException.class, () -> handler.withRetry(3, ENTITY_ALREADY_EXISTS)); + } +}