diff --git a/Makefile b/Makefile
index a6f16952..7e4d1ea3 100644
--- a/Makefile
+++ b/Makefile
@@ -5,7 +5,7 @@ SHLIB_LINK += -lcurl
 
 MODULE_big = vector
 DATA = $(wildcard sql/*--*.sql)
-OBJS = src/hnsw.o src/hnswbuild.o src/hnswinsert.o src/hnswscan.o src/hnswutils.o src/hnswvacuum.o src/ivfbuild.o src/ivfflat.o src/ivfinsert.o src/ivfkmeans.o src/ivfscan.o src/ivfutils.o src/ivfvacuum.o src/vector.o src/pinecone_api.o src/pinecone.o src/cJSON.o
+OBJS = src/hnsw.o src/hnswbuild.o src/hnswinsert.o src/hnswscan.o src/hnswutils.o src/hnswvacuum.o src/ivfbuild.o src/ivfflat.o src/ivfinsert.o src/ivfkmeans.o src/ivfscan.o src/ivfutils.o src/ivfvacuum.o src/vector.o src/pinecone_api.o src/pinecone.o src/cJSON.o src/pinecone_helpers.o
 HEADERS = src/vector.h src/pinecone_api.h src/pinecone.h src/cJSON.h
 
 TESTS = $(wildcard test/sql/*.sql)
diff --git a/sql/vector.sql b/sql/vector.sql
index a5f0efcf..ed119a30 100644
--- a/sql/vector.sql
+++ b/sql/vector.sql
@@ -110,6 +110,20 @@ CREATE FUNCTION vector_ip_pinecone_metric_name() RETURNS int4
 CREATE FUNCTION vector_cosine_pinecone_metric_name() RETURNS int4
 	AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE;
 
+-- pinecone helper functions
+
+CREATE TYPE pinecone_index_stats AS (
+	name text,
+	dimension integer,
+	metric text,
+	host text,
+	status json,
+	spec json
+);
+
+CREATE FUNCTION pinecone_indexes() RETURNS SETOF pinecone_index_stats
+	AS 'MODULE_PATHNAME' LANGUAGE C VOLATILE STRICT PARALLEL SAFE;
+
 -- aggregates
 
 CREATE AGGREGATE avg(vector) (
diff --git a/src/pinecone_api.c b/src/pinecone_api.c
index 459d2bb8..99d0ce4c 100644
--- a/src/pinecone_api.c
+++ b/src/pinecone_api.c
@@ -7,15 +7,19 @@
 
 size_t write_callback(char *contents, size_t size, size_t nmemb, void *userdata) {
     size_t real_size = size * nmemb; // size of the response
-    char **str = (char **)userdata; // cast the userdata to a string pointer
-    *str = malloc(real_size + 1); // allocate memory for the response
-    if (*str == NULL) {
-        ereport(ERROR, (errmsg("Failed to allocate memory for response")));
-        return 0;
+    // libcurl may deliver one response in several chunks, so append to the buffer
+    ResponseData *response_data = (ResponseData *)userdata;
+    // the extra byte leaves room for the terminator written below;
+    // realloc(NULL, n) acts like malloc(n), so the first chunk needs no special case
+    char *grown = realloc(response_data->data, response_data->length + real_size + 1);
+    if (grown == NULL) {
+        return 0; // a short count makes libcurl fail the transfer; the old buffer stays valid
     }
-    memcpy(*str, contents, real_size);
-    (*str)[real_size] = '\0'; // null terminate the string
-    elog(DEBUG1, "Response (write_callback): %s", *str);
+    response_data->data = grown;
+    memcpy(response_data->data + response_data->length, contents, real_size);
+    response_data->length += real_size;
+    response_data->data[response_data->length] = '\0'; // null terminate the string
+    elog(DEBUG1, "Response (write_callback): %.*s", (int) real_size, contents); // contents is not NUL-terminated
     return real_size;
 }
 
@@ -28,7 +32,7 @@ struct curl_slist *create_common_headers(const char *api_key) {
     return headers;
 }
 
-void set_curl_options(CURL *hnd, const char *api_key, const char *url, const char *method, char** response_data) {
+void set_curl_options(CURL *hnd, const char *api_key, const char *url, const char *method, ResponseData *response_data) {
     struct curl_slist *headers = create_common_headers(api_key);
     curl_easy_setopt(hnd, CURLOPT_HTTPHEADER, headers);
     curl_easy_setopt(hnd, CURLOPT_CUSTOMREQUEST, method);
@@ -43,14 +47,46 @@ void set_curl_options(CURL *hnd, const char *api_key, const char *url, const char *method, const cha
 cJSON* describe_index(const char *api_key, const char *index_name) {
     CURL *hnd = curl_easy_init();
     cJSON *response_json;
-    char *response_data = NULL;
+    ResponseData response_data = {NULL, 0};
     char url[100] = "https://api.pinecone.io/indexes/"; strcat(url, index_name);
     set_curl_options(hnd, api_key, url, "GET", &response_data);
     curl_easy_perform(hnd);
-    response_json = cJSON_Parse(response_data);
+    response_json = cJSON_Parse(response_data.data);
     return response_json;
 }
 
+/* Fetch the "indexes" array from https://api.pinecone.io/indexes.
+ * Returns NULL if no handle could be created, the request failed, or the
+ * response is not JSON with an "indexes" member.
+ * The result is detached from the response: the caller frees it with cJSON_Delete. */
+cJSON* list_indexes(const char *api_key) {
+    CURL *hnd = curl_easy_init();
+    ResponseData response_data = {NULL, 0};
+    cJSON *response_json;
+    cJSON *indexes;
+    CURLcode res;
+    if (hnd == NULL) {
+        return NULL;
+    }
+    set_curl_options(hnd, api_key, "https://api.pinecone.io/indexes", "GET", &response_data);
+    res = curl_easy_perform(hnd);
+    curl_easy_cleanup(hnd); // release the handle on every path
+    if (res != CURLE_OK || response_data.data == NULL) {
+        elog(WARNING, "list_indexes: no usable response: %s", curl_easy_strerror(res));
+        free(response_data.data);
+        return NULL;
+    }
+    response_json = cJSON_Parse(response_data.data);
+    free(response_data.data); // the parsed tree holds its own copies of the text
+    if (response_json == NULL) {
+        return NULL;
+    }
+    // detach so the rest of the response can be freed without invalidating the result
+    indexes = cJSON_DetachItemFromObject(response_json, "indexes");
+    cJSON_Delete(response_json);
+    return indexes;
+}
+
 /* name, dimension, metric
  * serverless: cloud, region
  * pod: environment, replicas, pod_type, pods, shards, metadata_config
@@ -60,7 +96,7 @@ cJSON* create_index(const char *api_key, const char *index_name, const int dimen
     CURL *hnd = curl_easy_init();
     cJSON *body = cJSON_CreateObject();
     cJSON *spec_json = cJSON_Parse(server_spec);
-    char *response_data = NULL;
+    ResponseData response_data = {NULL, 0};
     cJSON *response_json;
     // add fields to body
     elog(DEBUG1, "Creating index %s with dimension %d and metric %s", index_name, dimension, metric);
@@ -74,12 +110,12 @@ cJSON* create_index(const char *api_key, const char *index_name, const int dimen
     curl_easy_perform(hnd);
     curl_easy_cleanup(hnd);
     // return response_data as json
-    response_json = cJSON_Parse(response_data);
+    response_json = cJSON_Parse(response_data.data);
     return response_json;
 }
 
 cJSON* pinecone_api_query_index(const char *api_key, const char *index_host, const int topK, cJSON *query_vector_values, cJSON *filter) {
-    char* response_data = NULL;
+    ResponseData response_data = {NULL, 0};
     CURL *hnd = curl_easy_init();
     cJSON *body = cJSON_CreateObject();
     char url[100] = "https://"; strcat(url, index_host); strcat(url, "/query"); // e.g. https://t1-23kshha.svc.apw5-4e34-81fa.pinecone.io/query
@@ -92,7 +128,7 @@ cJSON* pinecone_api_query_index(const char *api_key, const char *index_host, con
     set_curl_options(hnd, api_key, url, "POST", &response_data);
     curl_easy_setopt(hnd, CURLOPT_POSTFIELDS, cJSON_Print(body));
     curl_easy_perform(hnd);
-    return cJSON_Parse(response_data);
+    return cJSON_Parse(response_data.data);
 }
 
 void pinecone_bulk_upsert(const char *api_key, const char *index_host, cJSON *vectors, int batch_size) {
@@ -126,7 +162,8 @@ CURL* get_pinecone_upsert_handle(const char *api_key, const char *index_host, cJ
     CURL *hnd = curl_easy_init();
     cJSON *body = cJSON_CreateObject();
     char *body_str;
-    char *response_data = NULL;
+    // NOTE(review): &response_data (a stack local) is passed to set_curl_options below; if this function returns hnd and the transfer runs later, the write callback would use a dangling pointer - confirm, and move this to heap storage if so
+    ResponseData response_data = {NULL, 0};
     char url[100] = "https://"; strcat(url, index_host); strcat(url, "/vectors/upsert"); // https://t1-23kshha.svc.apw5-4e34-81fa.pinecone.io/vectors/upsert
     cJSON_AddItemToObject(body, "vectors", vectors);
     set_curl_options(hnd, api_key, url, "POST", &response_data);
diff --git a/src/pinecone_api.h b/src/pinecone_api.h
index 6c1cc2d5..771457da 100644
--- a/src/pinecone_api.h
+++ b/src/pinecone_api.h
@@ -6,9 +6,17 @@
 
 typedef CURL** CURLHandleList;
 
+/* Response body accumulated by write_callback across libcurl write calls.
+ * data is NULL until the first chunk arrives and is NUL-terminated afterwards; length excludes the terminator. */
+typedef struct {
+    char *data;
+    size_t length;
+} ResponseData;
+
 struct curl_slist *create_common_headers(const char *api_key);
-void set_curl_options(CURL *hnd, const char *api_key, const char *url, const char *method, char** response_data);
+void set_curl_options(CURL *hnd, const char *api_key, const char *url, const char *method, ResponseData *response_data);
 cJSON* describe_index(const char *api_key, const char *index_name);
+cJSON* list_indexes(const char *api_key);
 cJSON* create_index(const char *api_key, const char *index_name, const int dimension, const char *metric, const char *spec);
 cJSON* pinecone_api_query_index(const char *api_key, const char *index_host, const int topK, cJSON *query_vector_values, cJSON *filter);
 // void pinecone_upsert_one(const char *api_key, const char *index_host, cJSON *vector);
diff --git a/src/pinecone_helpers.c b/src/pinecone_helpers.c
new file mode 100644
index 00000000..27d987e3
--- /dev/null
+++ b/src/pinecone_helpers.c
@@ -0,0 +1,132 @@
+#include "postgres.h"
+#include "fmgr.h"
+// NOTE(review): an angle-bracket #include stood here, but its header name is missing from this patch text; restore it from the original commit
+#include "catalog/pg_type.h" // INT4OID, TEXTOID, BOOLOID, JSONOID
+#include "funcapi.h"
+#include "pinecone_api.h"
+#include "pinecone.h"
+#include "cJSON.h"
+#include "utils/builtins.h"
+
+/*
+ * pinecone_indexes: SQL-callable set-returning function.
+ * Emits one row per element of the array returned by list_indexes(). Each
+ * result column is filled from the JSON member named like the column; a
+ * missing member, or one whose JSON type does not fit the column, becomes NULL.
+ * Raises an error when the API key is unset, the listing cannot be fetched,
+ * or a result column has a type other than int4, text, bool or json.
+ */
+PGDLLEXPORT PG_FUNCTION_INFO_V1(pinecone_indexes);
+Datum
+pinecone_indexes(PG_FUNCTION_ARGS) {
+    ReturnSetInfo *rsinfo = (ReturnSetInfo *) fcinfo->resultinfo;
+    Tuplestorestate *tupstore;
+    TupleDesc tupdesc;
+    MemoryContext per_query_ctx, oldcontext;
+    cJSON *indexes;
+    cJSON *index;
+    Datum *values;
+    bool *nulls;
+    char *printed;
+
+    /* check to see if caller supports us returning a tuplestore */
+    if (rsinfo == NULL || !IsA(rsinfo, ReturnSetInfo))
+        ereport(ERROR,
+                (errcode(ERRCODE_FEATURE_NOT_SUPPORTED),
+                 errmsg("set-valued function called in context that cannot accept a set")));
+    if (!(rsinfo->allowedModes & SFRM_Materialize))
+        ereport(ERROR,
+                (errcode(ERRCODE_FEATURE_NOT_SUPPORTED),
+                 errmsg("materialize mode required, but it is not allowed in this context")));
+
+    /* get a tuple descriptor for our result type */
+    switch (get_call_result_type(fcinfo, NULL, &tupdesc))
+    {
+        case TYPEFUNC_COMPOSITE:
+            /* success */
+            break;
+        case TYPEFUNC_RECORD:
+            /* failed to determine actual type of RECORD */
+            ereport(ERROR,
+                    (errcode(ERRCODE_FEATURE_NOT_SUPPORTED),
+                     errmsg("function returning record called in context that cannot accept type record")));
+            break;
+        default:
+            /* result type isn't a tuple */
+            ereport(ERROR,
+                    (errcode(ERRCODE_FEATURE_NOT_SUPPORTED),
+                     errmsg("function result type must be a row type")));
+            break;
+    }
+
+    // create a tuple store and tuple descriptor in the per-query context
+    per_query_ctx = rsinfo->econtext->ecxt_per_query_memory;
+    oldcontext = MemoryContextSwitchTo(per_query_ctx);
+    /* create a tuple store */
+    tupdesc = CreateTupleDescCopy(tupdesc);
+    tupstore = tuplestore_begin_heap(true, false, 100);
+    MemoryContextSwitchTo(oldcontext);
+
+    // validate the api key
+    if (pinecone_api_key == NULL || strlen(pinecone_api_key) == 0) {
+        ereport(ERROR, (errmsg("Pinecone API key is not set")));
+    }
+    indexes = list_indexes(pinecone_api_key);
+    if (!cJSON_IsArray(indexes)) {
+        cJSON_Delete(indexes); // ereport(ERROR) does not return, so free first
+        ereport(ERROR, (errmsg("Failed to list Pinecone indexes")));
+    }
+    printed = cJSON_Print(indexes);
+    elog(DEBUG1, "Indexes: %s", printed != NULL ? printed : "(null)");
+    cJSON_free(printed); // cJSON_Print output is heap-allocated and owned by the caller
+
+    // size the scratch arrays from the descriptor rather than a fixed slot count
+    values = palloc(sizeof(Datum) * tupdesc->natts);
+    nulls = palloc(sizeof(bool) * tupdesc->natts);
+
+    cJSON_ArrayForEach(index, indexes) {
+        HeapTuple tuple;
+        for (int i = 0; i < tupdesc->natts; i++) {
+            Form_pg_attribute attr = TupleDescAttr(tupdesc, i);
+            char* name = NameStr(attr->attname);
+            Oid type = attr->atttypid;
+            cJSON *value = cJSON_GetObjectItem(index, name); // the column name doubles as the JSON key
+            switch (type) {
+                case INT4OID:
+                    nulls[i] = value == NULL || !cJSON_IsNumber(value);
+                    if (!nulls[i]) values[i] = Int32GetDatum((int)cJSON_GetNumberValue(value));
+                    break;
+                case TEXTOID:
+                    nulls[i] = value == NULL || !cJSON_IsString(value);
+                    if (!nulls[i]) values[i] = PointerGetDatum(cstring_to_text(cJSON_GetStringValue(value)));
+                    break;
+                case BOOLOID:
+                    nulls[i] = value == NULL || !cJSON_IsBool(value);
+                    if (!nulls[i]) values[i] = BoolGetDatum(cJSON_IsTrue(value));
+                    break;
+                case JSONOID:
+                    printed = (value != NULL) ? cJSON_Print(value) : NULL;
+                    nulls[i] = printed == NULL;
+                    if (!nulls[i]) {
+                        values[i] = PointerGetDatum(cstring_to_text(printed));
+                        cJSON_free(printed); // cstring_to_text made its own copy
+                    }
+                    break;
+                default:
+                    cJSON_Delete(indexes); // ereport(ERROR) does not return, so free first
+                    ereport(ERROR, (errmsg("Unsupported type")));
+                    break;
+            }
+        }
+        tuple = heap_form_tuple(tupdesc, values, nulls);
+        tuplestore_puttuple(tupstore, tuple);
+        heap_freetuple(tuple);
+    }
+    cJSON_Delete(indexes);
+
+    rsinfo->returnMode = SFRM_Materialize;
+    rsinfo->setResult = tupstore;
+    rsinfo->setDesc = tupdesc;
+    // when returning a set, we must return a null Datum
+    return (Datum) 0;
+}