Skip to content

Commit

Permalink
[OPIK-309] Create prompt endpoint (#531)
Browse files Browse the repository at this point in the history
* [OPIK-309] Create prompt endpoint

* Add logic to create first version when specified

* Address PR review comments
  • Loading branch information
thiagohora authored Nov 4, 2024
1 parent e53d35d commit 203299d
Show file tree
Hide file tree
Showing 18 changed files with 852 additions and 7 deletions.
56 changes: 56 additions & 0 deletions apps/opik-backend/src/main/java/com/comet/opik/api/Prompt.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
package com.comet.opik.api;

import com.fasterxml.jackson.annotation.JsonIgnoreProperties;
import com.fasterxml.jackson.annotation.JsonView;
import com.fasterxml.jackson.databind.PropertyNamingStrategies;
import com.fasterxml.jackson.databind.annotation.JsonNaming;
import io.swagger.v3.oas.annotations.media.Schema;
import jakarta.annotation.Nullable;
import jakarta.validation.constraints.NotBlank;
import jakarta.validation.constraints.Pattern;
import lombok.Builder;

import java.time.Instant;
import java.util.List;
import java.util.UUID;

import static com.comet.opik.utils.ValidationUtils.NULL_OR_NOT_BLANK;

@Builder(toBuilder = true)
@JsonIgnoreProperties(ignoreUnknown = true)
@JsonNaming(PropertyNamingStrategies.SnakeCaseStrategy.class)
public record Prompt(
@JsonView( {
Prompt.View.Public.class, Prompt.View.Write.class}) UUID id,
@JsonView({Prompt.View.Public.class, Prompt.View.Write.class}) @NotBlank String name,
@JsonView({Prompt.View.Public.class,
Prompt.View.Write.class}) @Pattern(regexp = NULL_OR_NOT_BLANK, message = "must not be blank") String description,
@JsonView({
Prompt.View.Write.class}) @Pattern(regexp = NULL_OR_NOT_BLANK, message = "must not be blank") @Nullable String template,
@JsonView({Prompt.View.Public.class}) @Schema(accessMode = Schema.AccessMode.READ_ONLY) Instant createdAt,
@JsonView({Prompt.View.Public.class}) @Schema(accessMode = Schema.AccessMode.READ_ONLY) String createdBy,
@JsonView({Prompt.View.Public.class}) @Schema(accessMode = Schema.AccessMode.READ_ONLY) Instant lastUpdatedAt,
@JsonView({Prompt.View.Public.class}) @Schema(accessMode = Schema.AccessMode.READ_ONLY) String lastUpdatedBy){

public static class View {
public static class Write {
}

public static class Public {
}
}

public record PromptPage(
@JsonView( {
Project.View.Public.class}) int page,
@JsonView({Project.View.Public.class}) int size,
@JsonView({Project.View.Public.class}) long total,
@JsonView({Project.View.Public.class}) List<Prompt> content)
implements
Page<Prompt>{

public static Prompt.PromptPage empty(int page) {
return new Prompt.PromptPage(page, 0, 0, List.of());
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
package com.comet.opik.api;

import com.fasterxml.jackson.annotation.JsonIgnoreProperties;
import com.fasterxml.jackson.annotation.JsonView;
import com.fasterxml.jackson.databind.PropertyNamingStrategies;
import com.fasterxml.jackson.databind.annotation.JsonNaming;
import io.swagger.v3.oas.annotations.media.Schema;
import jakarta.annotation.Nullable;
import jakarta.validation.constraints.NotNull;
import lombok.Builder;

import java.time.Instant;
import java.util.List;
import java.util.Set;
import java.util.UUID;

@Builder(toBuilder = true)
@JsonIgnoreProperties(ignoreUnknown = true)
@JsonNaming(PropertyNamingStrategies.SnakeCaseStrategy.class)
public record PromptVersion(
@JsonView( {
PromptVersion.View.Public.class,
PromptVersion.View.Detail.class}) @Schema(description = "version unique identifier, generated if absent") UUID id,
@JsonView({PromptVersion.View.Public.class,
PromptVersion.View.Detail.class}) @Schema(accessMode = Schema.AccessMode.READ_ONLY) UUID promptId,
@JsonView({PromptVersion.View.Public.class,
PromptVersion.View.Detail.class}) @Schema(description = "version short unique identifier, generated if absent") String commit,
@JsonView({PromptVersion.View.Detail.class}) @NotNull String template,
@JsonView({
PromptVersion.View.Detail.class}) @Schema(accessMode = Schema.AccessMode.READ_ONLY) @Nullable Set<String> variables,
@JsonView({PromptVersion.View.Public.class,
PromptVersion.View.Detail.class}) @Schema(accessMode = Schema.AccessMode.READ_ONLY) Instant createdAt,
@JsonView({PromptVersion.View.Public.class,
PromptVersion.View.Detail.class}) @Schema(accessMode = Schema.AccessMode.READ_ONLY) String createdBy){

public static class View {
public static class Public {
}

public static class Detail {
}
}

@Builder
public record PromptVersionPage(
@JsonView( {
PromptVersion.View.Public.class}) int page,
@JsonView({PromptVersion.View.Public.class}) int size,
@JsonView({PromptVersion.View.Public.class}) long total,
@JsonView({PromptVersion.View.Public.class}) List<PromptVersion> content)
implements
Page<PromptVersion>{

public static PromptVersion.PromptVersionPage empty(int page) {
return new PromptVersion.PromptVersionPage(page, 0, 0, List.of());
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,14 @@
public class EntityAlreadyExistsException extends ClientErrorException {

public EntityAlreadyExistsException(ErrorMessage response) {
super(Response.status(Response.Status.CONFLICT).entity(response).build());
this((Object) response);
}

public EntityAlreadyExistsException(io.dropwizard.jersey.errors.ErrorMessage response) {
this((Object) response);
}

private EntityAlreadyExistsException(Object entity) {
super(Response.status(Response.Status.CONFLICT).entity(entity).build());
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ public Response getById(@PathParam("id") UUID id) {
}

@POST
@Operation(operationId = "createProject", summary = "Create project", description = "Get project", responses = {
@Operation(operationId = "createProject", summary = "Create project", description = "Create project", responses = {
@ApiResponse(responseCode = "201", description = "Created", headers = {
@Header(name = "Location", required = true, example = "${basePath}/v1/private/projects/{projectId}", schema = @Schema(implementation = String.class))}),
@ApiResponse(responseCode = "422", description = "Unprocessable Content", content = @Content(schema = @Schema(implementation = ErrorMessage.class))),
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
package com.comet.opik.api.resources.v1.priv;

import com.codahale.metrics.annotation.Timed;
import com.comet.opik.api.Prompt;
import com.comet.opik.api.error.ErrorMessage;
import com.comet.opik.domain.PromptService;
import com.comet.opik.infrastructure.auth.RequestContext;
import com.comet.opik.infrastructure.ratelimit.RateLimited;
import com.fasterxml.jackson.annotation.JsonView;
import io.swagger.v3.oas.annotations.Operation;
import io.swagger.v3.oas.annotations.headers.Header;
import io.swagger.v3.oas.annotations.media.Content;
import io.swagger.v3.oas.annotations.media.Schema;
import io.swagger.v3.oas.annotations.parameters.RequestBody;
import io.swagger.v3.oas.annotations.responses.ApiResponse;
import jakarta.inject.Inject;
import jakarta.inject.Provider;
import jakarta.validation.Valid;
import jakarta.ws.rs.Consumes;
import jakarta.ws.rs.POST;
import jakarta.ws.rs.Path;
import jakarta.ws.rs.Produces;
import jakarta.ws.rs.core.Context;
import jakarta.ws.rs.core.MediaType;
import jakarta.ws.rs.core.Response;
import jakarta.ws.rs.core.UriInfo;
import lombok.NonNull;
import lombok.RequiredArgsConstructor;
import lombok.extern.slf4j.Slf4j;

@Path("/v1/private/prompts")
@Produces(MediaType.APPLICATION_JSON)
@Consumes(MediaType.APPLICATION_JSON)
@Timed
@Slf4j
@RequiredArgsConstructor(onConstructor_ = @Inject)
public class PromptResource {

private final @NonNull Provider<RequestContext> requestContext;
private final @NonNull PromptService promptService;

@POST
@Operation(operationId = "createPrompt", summary = "Create prompt", description = "Create prompt", responses = {
@ApiResponse(responseCode = "201", description = "Created", headers = {
@Header(name = "Location", required = true, example = "${basePath}/v1/private/prompts/{promptId}", schema = @Schema(implementation = String.class))}),
@ApiResponse(responseCode = "422", description = "Unprocessable Content", content = @Content(schema = @Schema(implementation = ErrorMessage.class))),
@ApiResponse(responseCode = "400", description = "Bad Request", content = @Content(schema = @Schema(implementation = ErrorMessage.class))),
@ApiResponse(responseCode = "409", description = "Conflict", content = @Content(schema = @Schema(implementation = io.dropwizard.jersey.errors.ErrorMessage.class))),

})
@RateLimited
public Response createPrompt(
@RequestBody(content = @Content(schema = @Schema(implementation = Prompt.class))) @JsonView(Prompt.View.Write.class) @Valid Prompt prompt,
@Context UriInfo uriInfo) {

String workspaceId = requestContext.get().getWorkspaceId();

log.info("Creating prompt with name '{}', on workspace_id '{}'", prompt.name(), workspaceId);
prompt = promptService.create(prompt);
log.info("Prompt created with id '{}' name '{}', on workspace_id '{}'", prompt.id(), prompt.name(),
workspaceId);

var resourceUri = uriInfo.getAbsolutePathBuilder().path("/%s".formatted(prompt.id())).build();

return Response.created(resourceUri).build();
}

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
package com.comet.opik.domain;

import lombok.NonNull;
import lombok.experimental.UtilityClass;

import java.util.UUID;

@UtilityClass
class CommitUtils {

public String getCommit(@NonNull UUID id) {
return id.toString().substring(id.toString().length() - 8);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,7 @@ List<Dataset> find(@Bind("limit") int limit,
Optional<Dataset> findByName(@Bind("workspace_id") String workspaceId, @Bind("name") String name);

@SqlBatch("UPDATE datasets SET last_created_experiment_at = :experimentCreatedAt WHERE id = :datasetId AND workspace_id = :workspace_id")
int[] recordExperiments(@Bind("workspace_id") String workspaceId, @BindMethods Collection<DatasetLastExperimentCreated> datasets);
int[] recordExperiments(@Bind("workspace_id") String workspaceId,
@BindMethods Collection<DatasetLastExperimentCreated> datasets);

}
Original file line number Diff line number Diff line change
Expand Up @@ -278,7 +278,8 @@ private List<Dataset> enrichDatasetWithAdditionalInformation(List<Dataset> datas
return datasets.stream()
.map(dataset -> {
var resume = experimentSummary.computeIfAbsent(dataset.id(), ExperimentSummary::empty);
var datasetItemSummary = datasetItemSummaryMap.computeIfAbsent(dataset.id(), DatasetItemSummary::empty);
var datasetItemSummary = datasetItemSummaryMap.computeIfAbsent(dataset.id(),
DatasetItemSummary::empty);

return dataset.toBuilder()
.experimentCount(resume.experimentCount())
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
package com.comet.opik.domain;

import com.comet.opik.api.error.EntityAlreadyExistsException;
import com.google.common.base.Preconditions;
import org.jdbi.v3.core.statement.UnableToExecuteStatementException;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.sql.SQLIntegrityConstraintViolationException;
import java.util.function.Supplier;

interface EntityConstraintHandler<T> {

Logger log = LoggerFactory.getLogger(EntityConstraintHandler.class);

static <E> EntityConstraintHandler<E> handle(EntityConstraintAction<E> entityAction) {
return () -> entityAction;
}

interface EntityConstraintAction<T> {
T execute();
}

EntityConstraintAction<T> wrappedAction();

default T withError(Supplier<EntityAlreadyExistsException> errorProvider) {
try {
return wrappedAction().execute();
} catch (UnableToExecuteStatementException e) {
if (e.getCause() instanceof SQLIntegrityConstraintViolationException) {
throw errorProvider.get();
} else {
throw e;
}
}
}

default T withRetry(int times, Supplier<EntityAlreadyExistsException> errorProvider) {
Preconditions.checkArgument(times > 0, "Retry times must be greater than 0");

return internalRetry(times, errorProvider);
}

private T internalRetry(int times, Supplier<EntityAlreadyExistsException> errorProvider) {
try {
return wrappedAction().execute();
} catch (UnableToExecuteStatementException e) {
if (e.getCause() instanceof SQLIntegrityConstraintViolationException) {
if (times > 0) {
log.warn("Retrying due to constraint violation, remaining attempts: {}", times);
return internalRetry(times - 1, errorProvider);
}
throw errorProvider.get();
} else {
throw e;
}
}
}

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
package com.comet.opik.domain;

import com.comet.opik.api.Prompt;
import com.comet.opik.infrastructure.db.UUIDArgumentFactory;
import org.jdbi.v3.sqlobject.config.RegisterArgumentFactory;
import org.jdbi.v3.sqlobject.config.RegisterConstructorMapper;
import org.jdbi.v3.sqlobject.customizer.Bind;
import org.jdbi.v3.sqlobject.customizer.BindMethods;
import org.jdbi.v3.sqlobject.statement.SqlQuery;
import org.jdbi.v3.sqlobject.statement.SqlUpdate;

import java.util.UUID;

@RegisterConstructorMapper(Prompt.class)
@RegisterArgumentFactory(UUIDArgumentFactory.class)
interface PromptDAO {

@SqlUpdate("INSERT INTO prompts (id, name, description, created_by, last_updated_by, workspace_id) " +
"VALUES (:bean.id, :bean.name, :bean.description, :bean.createdBy, :bean.lastUpdatedBy, :workspaceId)")
void save(@Bind("workspaceId") String workspaceId, @BindMethods("bean") Prompt prompt);

@SqlQuery("SELECT * FROM prompts WHERE id = :id AND workspace_id = :workspaceId")
Prompt findById(@Bind("id") UUID id, @Bind("workspaceId") String workspaceId);

}
Loading

0 comments on commit 203299d

Please sign in to comment.