diff --git a/src/query/service/src/servers/http/middleware/session.rs b/src/query/service/src/servers/http/middleware/session.rs index 12388ab5a38d..27ad1ee793f9 100644 --- a/src/query/service/src/servers/http/middleware/session.rs +++ b/src/query/service/src/servers/http/middleware/session.rs @@ -322,6 +322,12 @@ pub struct HTTPSessionEndpoint { pub auth_manager: Arc, } +fn make_cookie(name: impl Into, value: impl Into) -> Cookie { + let mut cookie = Cookie::new_with_str(name, value); + cookie.set_path("/"); + cookie +} + impl HTTPSessionEndpoint { #[async_backtrace::framed] async fn auth(&self, req: &Request, query_id: String) -> Result { @@ -364,8 +370,7 @@ impl HTTPSessionEndpoint { Some(id1.clone()) } (Some(id), None) => { - req.cookie() - .add(Cookie::new_with_str(COOKIE_SESSION_ID, id)); + req.cookie().add(make_cookie(COOKIE_SESSION_ID, id)); Some(id.clone()) } (None, Some(id)) => Some(id.clone()), @@ -373,8 +378,7 @@ impl HTTPSessionEndpoint { if cookie_enabled { let id = Uuid::new_v4().to_string(); info!("new session id: {}", id); - req.cookie() - .add(Cookie::new_with_str(COOKIE_SESSION_ID, &id)); + req.cookie().add(make_cookie(COOKIE_SESSION_ID, &id)); Some(id) } else { None @@ -399,8 +403,7 @@ impl HTTPSessionEndpoint { if cookie_enabled { let ts = unix_ts().as_secs().to_string(); - req.cookie() - .add(Cookie::new_with_str(COOKIE_LAST_ACCESS_TIME, ts)); + req.cookie().add(make_cookie(COOKIE_LAST_ACCESS_TIME, ts)); } let session = session_manager.register_session(session)?; diff --git a/tests/suites/1_stateful/09_http_handler/09_0009_cookie.py b/tests/suites/1_stateful/09_http_handler/09_0009_cookie.py index 521e887949f4..7af8f6a92eac 100755 --- a/tests/suites/1_stateful/09_http_handler/09_0009_cookie.py +++ b/tests/suites/1_stateful/09_http_handler/09_0009_cookie.py @@ -14,17 +14,14 @@ def __init__(self): super().__init__() def set_cookie(self, cookie: Cookie, *args, **kwargs): + assert cookie.path == "/" , cookie + # "" is prefix of any host name or IP, so it will be applied cookie.domain = "" - cookie.path = "/" super().set_cookie(cookie, *args, **kwargs) - def get_dict(self, domain=None, path=None): - # 忽略 domain 和 path 参数,返回所有 Cookie - return {cookie.name: cookie.value for cookie in self} - def do_query(session_client, query, session_state=None): - url = f"http://localhost:8000/v1/query" + url = f"http://127.0.0.1:8000/v1/query" query_payload = { "sql": query, "pagination": {"wait_time_secs": 100, "max_rows_per_page": 2}, @@ -47,12 +44,12 @@ def test_simple(): resp = do_query(client, "select 1") assert resp.status_code == 200, resp.text assert resp.json()["data"] == [["1"]], resp.text - sid = client.cookies.get("session_id") - # print(sid) + # print(client.cookies) + sid = client.cookies.get("session_id", path="/") last_access_time1 = int(client.cookies.get("last_access_time")) # print(last_access_time1) - assert time.time() - 10 < last_access_time1 < time.time() + assert time.time() - 10 < last_access_time1 <= time.time() time.sleep(1.5) @@ -60,9 +57,10 @@ def test_simple(): assert resp.status_code == 200, resp.text assert resp.json()["data"] == [["1"]], resp.text sid2 = client.cookies.get("session_id") + # print(client.cookies) last_access_time2 = int(client.cookies.get("last_access_time")) assert sid2 == sid - assert last_access_time1 < last_access_time2 < time.time() + assert last_access_time1 < last_access_time2 <= time.time() def test_temp_table():