Skip to content

Commit

Permalink
Merge pull request #58 from getmetal/jo/namespaced-sessions
Browse files Browse the repository at this point in the history
add namespacing
  • Loading branch information
softboyjimbo authored May 16, 2023
2 parents 3d80caa + a225a10 commit 0feca2f
Show file tree
Hide file tree
Showing 2 changed files with 50 additions and 22 deletions.
50 changes: 28 additions & 22 deletions src/memory.rs
Original file line number Diff line number Diff line change
@@ -1,34 +1,23 @@
use crate::long_term_memory::index_messages;
use crate::models::{
AckResponse, AppState, MemoryMessage, MemoryMessagesAndContext, MemoryResponse,
AckResponse, AppState, GetSessionsQuery, MemoryMessage, MemoryMessagesAndContext,
MemoryResponse, NamespaceQuery,
};
use crate::reducer::handle_compaction;
use actix_web::{delete, error, get, post, web, HttpResponse, Responder};
use std::sync::Arc;

#[derive(serde::Deserialize)]
pub struct Pagination {
#[serde(default = "default_page")]
page: usize,
#[serde(default = "default_size")]
size: usize,
}

fn default_page() -> usize {
1
}

fn default_size() -> usize {
10
}

#[get("/sessions")]
pub async fn get_sessions(
web::Query(pagination): web::Query<Pagination>,
web::Query(pagination): web::Query<GetSessionsQuery>,
_data: web::Data<Arc<AppState>>,
redis: web::Data<redis::Client>,
) -> actix_web::Result<impl Responder> {
let Pagination { page, size } = pagination;
let GetSessionsQuery {
page,
size,
namespace,
} = pagination;

if page > 100 {
return Err(actix_web::error::ErrorBadRequest(
Expand All @@ -44,8 +33,13 @@ pub async fn get_sessions(
.await
.map_err(error::ErrorInternalServerError)?;

let sessions_key = match &namespace {
Some(namespace) => format!("sessions:{}", namespace),
None => String::from("sessions"),
};

let session_ids: Vec<String> = redis::cmd("ZRANGE")
.arg("sessions")
.arg(sessions_key)
.arg(start)
.arg(end)
.query_async(&mut conn)
Expand Down Expand Up @@ -123,6 +117,7 @@ pub async fn post_memory(
web::Json(memory_messages): web::Json<MemoryMessagesAndContext>,
data: web::Data<Arc<AppState>>,
redis: web::Data<redis::Client>,
web::Query(namespace_query): web::Query<NamespaceQuery>,
) -> actix_web::Result<impl Responder> {
let mut conn = redis
.get_tokio_connection_manager()
Expand All @@ -145,9 +140,14 @@ pub async fn post_memory(
.map_err(error::ErrorInternalServerError)?;
}

let sessions_key = match namespace_query.namespace {
Some(namespace) => format!("sessions:{}", namespace),
None => String::from("sessions"),
};

// add to sorted set of sessions
redis::cmd("ZADD")
.arg("sessions")
.arg(sessions_key)
.arg(chrono::Utc::now().timestamp())
.arg(&*session_id)
.query_async(&mut conn)
Expand Down Expand Up @@ -211,6 +211,7 @@ pub async fn post_memory(
pub async fn delete_memory(
session_id: web::Path<String>,
redis: web::Data<redis::Client>,
web::Query(namespace_query): web::Query<NamespaceQuery>,
) -> actix_web::Result<impl Responder> {
let mut conn = redis
.get_tokio_connection_manager()
Expand All @@ -222,8 +223,13 @@ pub async fn delete_memory(
let session_key = format!("session:{}", &*session_id);
let keys = vec![context_key, session_key, token_count_key];

let sessions_key = match namespace_query.namespace {
Some(namespace) => format!("sessions:{}", namespace),
None => String::from("sessions"),
};

redis::cmd("ZREM")
.arg("sessions")
.arg(sessions_key)
.arg(&*session_id)
.query_async(&mut conn)
.await
Expand Down
22 changes: 22 additions & 0 deletions src/models.rs
Original file line number Diff line number Diff line change
Expand Up @@ -128,3 +128,25 @@ pub fn parse_redisearch_response(response: &Value) -> Vec<RedisearchResult> {
_ => vec![],
}
}

#[derive(serde::Deserialize)]
pub struct NamespaceQuery {
pub namespace: Option<String>,
}

#[derive(serde::Deserialize)]
pub struct GetSessionsQuery {
#[serde(default = "default_page")]
pub page: usize,
#[serde(default = "default_size")]
pub size: usize,
pub namespace: Option<String>,
}

fn default_page() -> usize {
1
}

fn default_size() -> usize {
10
}

0 comments on commit 0feca2f

Please sign in to comment.