Skip to content

Commit

Permalink
refactor: set path of cookie. (#17008)
Browse files Browse the repository at this point in the history
  • Loading branch information
youngsofun authored Dec 6, 2024
1 parent 5ca9e64 commit 2ede35d
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 16 deletions.
15 changes: 9 additions & 6 deletions src/query/service/src/servers/http/middleware/session.rs
Original file line number Diff line number Diff line change
Expand Up @@ -322,6 +322,12 @@ pub struct HTTPSessionEndpoint<E> {
pub auth_manager: Arc<AuthMgr>,
}

fn make_cookie(name: impl Into<String>, value: impl Into<String>) -> Cookie {
let mut cookie = Cookie::new_with_str(name, value);
cookie.set_path("/");
cookie
}

impl<E> HTTPSessionEndpoint<E> {
#[async_backtrace::framed]
async fn auth(&self, req: &Request, query_id: String) -> Result<HttpQueryContext> {
Expand Down Expand Up @@ -364,17 +370,15 @@ impl<E> HTTPSessionEndpoint<E> {
Some(id1.clone())
}
(Some(id), None) => {
req.cookie()
.add(Cookie::new_with_str(COOKIE_SESSION_ID, id));
req.cookie().add(make_cookie(COOKIE_SESSION_ID, id));
Some(id.clone())
}
(None, Some(id)) => Some(id.clone()),
(None, None) => {
if cookie_enabled {
let id = Uuid::new_v4().to_string();
info!("new session id: {}", id);
req.cookie()
.add(Cookie::new_with_str(COOKIE_SESSION_ID, &id));
req.cookie().add(make_cookie(COOKIE_SESSION_ID, &id));
Some(id)
} else {
None
Expand All @@ -399,8 +403,7 @@ impl<E> HTTPSessionEndpoint<E> {

if cookie_enabled {
let ts = unix_ts().as_secs().to_string();
req.cookie()
.add(Cookie::new_with_str(COOKIE_LAST_ACCESS_TIME, ts));
req.cookie().add(make_cookie(COOKIE_LAST_ACCESS_TIME, ts));
}

let session = session_manager.register_session(session)?;
Expand Down
18 changes: 8 additions & 10 deletions tests/suites/1_stateful/09_http_handler/09_0009_cookie.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,17 +14,14 @@ def __init__(self):
super().__init__()

def set_cookie(self, cookie: Cookie, *args, **kwargs):
assert cookie.path == "/" , cookie
# "" is prefix of any host name or IP, so it will be applied
cookie.domain = ""
cookie.path = "/"
super().set_cookie(cookie, *args, **kwargs)

def get_dict(self, domain=None, path=None):
# 忽略 domain 和 path 参数,返回所有 Cookie
return {cookie.name: cookie.value for cookie in self}


def do_query(session_client, query, session_state=None):
url = f"http://localhost:8000/v1/query"
url = f"http://127.0.0.1:8000/v1/query"
query_payload = {
"sql": query,
"pagination": {"wait_time_secs": 100, "max_rows_per_page": 2},
Expand All @@ -47,22 +44,23 @@ def test_simple():
resp = do_query(client, "select 1")
assert resp.status_code == 200, resp.text
assert resp.json()["data"] == [["1"]], resp.text
sid = client.cookies.get("session_id")
# print(sid)
# print(client.cookies)
sid = client.cookies.get("session_id", path="/")

last_access_time1 = int(client.cookies.get("last_access_time"))
# print(last_access_time1)
assert time.time() - 10 < last_access_time1 < time.time()
assert time.time() - 10 < last_access_time1 <= time.time()

time.sleep(1.5)

resp = do_query(client, "select 1")
assert resp.status_code == 200, resp.text
assert resp.json()["data"] == [["1"]], resp.text
sid2 = client.cookies.get("session_id")
# print(client.cookies)
last_access_time2 = int(client.cookies.get("last_access_time"))
assert sid2 == sid
assert last_access_time1 < last_access_time2 < time.time()
assert last_access_time1 < last_access_time2 <= time.time()


def test_temp_table():
Expand Down

0 comments on commit 2ede35d

Please sign in to comment.