From 1b5630db010e1089d2fa43cf288dd97935520033 Mon Sep 17 00:00:00 2001
From: ritchie
Date: Mon, 25 Nov 2024 13:44:52 +0100
Subject: [PATCH] fix: Fix Polars queries

---
Reviewer notes (commentary only; this text sits between the `---` separator
and the first `diff --git` line, outside the diff itself):
 * Q0/Q1 now read the count with .item(); .height of the one-row aggregate
   result is always 1, not COUNT(*).
 * Q2 uses the column names as spelled everywhere else in the file
   (AdvEngineID, ResolutionWidth) and averages ResolutionWidth, as its SQL
   text and pandas lambda do; it runs as a single select and still returns
   a 3-tuple.
 * Q6 uses .max().alias("e_max"), mirroring the .min().alias("e_min") next
   to it.
 * Line breaks and indentation were reconstructed from the hunk line counts;
   the blob ids on the `index` line describe the unrevised patch. If context
   whitespace does not match, apply with --ignore-whitespace.
 * NOTE(review): every Polars lambda now ends in .collect(), so the driver
   loop (not touched by this patch) presumably has to pass a LazyFrame,
   e.g. pl_df.lazy() -- confirm.
 * NOTE(review): Q29 returns an unaggregated frame built from shifts 1..89,
   while the pandas lambda sums to one scalar -- verify against the SQL.

 polars/query.py | 128 +++++++++++++++++++++++-------------------------
 1 file changed, 61 insertions(+), 67 deletions(-)

diff --git a/polars/query.py b/polars/query.py
index 96df85dd5..269463acf 100755
--- a/polars/query.py
+++ b/polars/query.py
@@ -3,7 +3,7 @@
 import pandas as pd
 import polars as pl
 import timeit
-import datetime
+from datetime import datetime, date
 import json
 
 hits = pd.read_parquet("hits.parquet")
@@ -22,48 +22,48 @@
         hits[col] = hits[col].astype(str)
 
 start = timeit.default_timer()
-pl_df = pl.DataFrame(hits)
+pl_df = pl.DataFrame(hits).rechunk()
 stop = timeit.default_timer()
 load_time = stop - start
 
 # 0: No., 1: SQL, 2: Pandas, 3: Polars
-queries = queries = [
-    ("Q0", "SELECT COUNT(*) FROM hits;", lambda x: x.count(), lambda x: x.height),
+queries = [
+    ("Q0", "SELECT COUNT(*) FROM hits;", lambda x: x.count(), lambda x: x.select(pl.len()).collect().item()),
     (
         "Q1",
         "SELECT COUNT(*) FROM hits WHERE AdvEngineID <> 0;",
         lambda x: x[x["AdvEngineID"] != 0].count(),
-        lambda x: x.filter(pl.col("AdvEngineID") != 0).height,
+        lambda x: x.select(pl.col("AdvEngineID").filter(pl.col("AdvEngineID") != 0).count()).collect().item(),
     ),
     (
         "Q2",
         "SELECT SUM(AdvEngineID), COUNT(*), AVG(ResolutionWidth) FROM hits;",
         lambda x: (x["AdvEngineID"].sum(), x.shape[0], x["ResolutionWidth"].mean()),
-        lambda x: (x["AdvEngineID"].sum(), x.height, x["ResolutionWidth"].mean()),
+        lambda x: x.select(pl.col("AdvEngineID").sum(), pl.len(), pl.col("ResolutionWidth").mean()).collect().rows()[0],
     ),
     (
         "Q3",
         "SELECT AVG(UserID) FROM hits;",
         lambda x: x["UserID"].mean(),
-        lambda x: x["UserID"].mean(),
+        lambda x: x.select(pl.col("UserID").mean()).collect().item(),
     ),
     (
         "Q4",
         "SELECT COUNT(DISTINCT UserID) FROM hits;",
         lambda x: x["UserID"].nunique(),
-        lambda x: x["UserID"].n_unique(),
+        lambda x: x.select(pl.col("UserID").n_unique()).collect().item(),
     ),
     (
         "Q5",
         "SELECT COUNT(DISTINCT SearchPhrase) FROM hits;",
         lambda x: x["SearchPhrase"].nunique(),
-        lambda x: x["SearchPhrase"].n_unique(),
+        lambda x: x.select(pl.col("SearchPhrase").n_unique()).collect().item(),
     ),
     (
         "Q6",
         "SELECT MIN(EventDate), MAX(EventDate) FROM hits;",
         lambda x: (x["EventDate"].min(), x["EventDate"].max()),
-        lambda x: (x["EventDate"].min(), x["EventDate"].max()),
+        lambda x: x.select(pl.col("EventDate").min().alias("e_min"), pl.col("EventDate").max().alias("e_max")).collect().rows()[0],
     ),
     (
         "Q7",
@@ -75,7 +75,7 @@
         lambda x: x.filter(pl.col("AdvEngineID") != 0)
         .group_by("AdvEngineID")
         .agg(pl.len().alias("count"))
-        .sort("count", descending=True),
+        .sort("count", descending=True).collect(),
     ),
     (
         "Q8",
@@ -84,7 +84,7 @@
         lambda x: x.group_by("RegionID")
         .agg(pl.col("UserID").n_unique())
         .sort("UserID", descending=True)
-        .head(10),
+        .head(10).collect(),
     ),
     (
         "Q9",
@@ -101,7 +101,7 @@
             ]
         )
         .sort("AdvEngineID_sum", descending=True)
-        .head(10),
+        .head(10).collect(),
     ),
     (
         "Q10",
@@ -114,7 +114,7 @@
         .group_by("MobilePhoneModel")
         .agg(pl.col("UserID").n_unique())
         .sort("UserID", descending=True)
-        .head(10),
+        .head(10).collect(),
     ),
     (
         "Q11",
@@ -127,7 +127,7 @@
         .group_by(["MobilePhone", "MobilePhoneModel"])
         .agg(pl.col("UserID").n_unique())
         .sort("UserID", descending=True)
-        .head(10),
+        .head(10).collect(),
     ),
     (
         "Q12",
@@ -140,7 +140,7 @@
         .group_by("SearchPhrase")
         .agg(pl.len().alias("count"))
         .sort("count", descending=True)
-        .head(10),
+        .head(10).collect(),
     ),
     (
         "Q13",
@@ -153,7 +153,7 @@
         .group_by("SearchPhrase")
         .agg(pl.col("UserID").n_unique())
         .sort("UserID", descending=True)
-        .head(10),
+        .head(10).collect(),
     ),
     (
         "Q14",
@@ -166,7 +166,7 @@
         .group_by(["SearchEngineID", "SearchPhrase"])
         .agg(pl.len().alias("count"))
         .sort("count", descending=True)
-        .head(10),
+        .head(10).collect(),
     ),
     (
         "Q15",
@@ -175,7 +175,7 @@
         lambda x: x.group_by("UserID")
         .agg(pl.len().alias("count"))
         .sort("count", descending=True)
-        .head(10),
+        .head(10).collect(),
     ),
     (
         "Q16",
@@ -184,13 +184,13 @@
         lambda x: x.group_by(["UserID", "SearchPhrase"])
         .agg(pl.len().alias("count"))
         .sort("count", descending=True)
-        .head(10),
+        .head(10).collect(),
     ),
     (
         "Q17",
         "SELECT UserID, SearchPhrase, COUNT(*) FROM hits GROUP BY UserID, SearchPhrase LIMIT 10;",
         lambda x: x.groupby(["UserID", "SearchPhrase"]).size().head(10),
-        lambda x: x.group_by(["UserID", "SearchPhrase"]).agg(pl.len()).head(10),
+        lambda x: x.group_by(["UserID", "SearchPhrase"]).agg(pl.len()).head(10).collect(),
     ),
     (
         "Q18",
@@ -203,19 +203,19 @@
         )
         .agg(pl.len().alias("count"))
         .sort("count", descending=True)
-        .head(10),
+        .head(10).collect(),
     ),
     (
         "Q19",
         "SELECT UserID FROM hits WHERE UserID = 435090932899640449;",
         lambda x: x[x["UserID"] == 435090932899640449],
-        lambda x: x.filter(pl.col("UserID") == 435090932899640449),
+        lambda x: x.select("UserID").filter(pl.col("UserID") == 435090932899640449).collect(),
     ),
     (
         "Q20",
         "SELECT COUNT(*) FROM hits WHERE URL LIKE '%google%';",
         lambda x: x[x["URL"].str.contains("google")].shape[0],
-        lambda x: x.filter(pl.col("URL").str.contains("google")).height,
+        lambda x: x.filter(pl.col("URL").str.contains("google")).select(pl.len()).collect().item(),
     ),
     (
         "Q21",
@@ -230,7 +230,7 @@
         .group_by("SearchPhrase")
         .agg([pl.col("URL").min(), pl.len().alias("count")])
         .sort("count", descending=True)
-        .head(10),
+        .head(10).collect(),
     ),
     (
         "Q22",
@@ -260,7 +260,7 @@
             ]
         )
         .sort("count", descending=True)
-        .head(10),
+        .head(10).collect(),
     ),
     (
         "Q23",
@@ -270,7 +270,7 @@
         .head(10),
         lambda x: x.filter(pl.col("URL").str.contains("google"))
         .sort("EventTime")
-        .head(10),
+        .head(10).collect(),
     ),
     (
         "Q24",
@@ -281,7 +281,7 @@
         lambda x: x.filter(pl.col("SearchPhrase") != "")
         .sort("EventTime")
         .select("SearchPhrase")
-        .head(10),
+        .head(10).collect(),
     ),
     (
         "Q25",
@@ -292,7 +292,7 @@
         lambda x: x.filter(pl.col("SearchPhrase") != "")
         .sort("SearchPhrase")
         .select("SearchPhrase")
-        .head(10),
+        .head(10).collect(),
     ),
     (
         "Q26",
@@ -303,7 +303,7 @@
         lambda x: x.filter(pl.col("SearchPhrase") != "")
         .sort(["EventTime", "SearchPhrase"])
         .select("SearchPhrase")
-        .head(10),
+        .head(10).collect(),
     ),
     (
         "Q27",
@@ -318,15 +318,13 @@
         .group_by("CounterID")  # GROUP BY CounterID
         .agg(
             [
-                pl.col("URL")
-                .map_elements(lambda y: len(y), return_dtype=pl.Int64)
-                .alias("l"),  # AVG(STRLEN(URL))
+                pl.col("URL").str.len_chars().mean().alias("l"),  # AVG(STRLEN(URL))
                 pl.len().alias("c"),  # COUNT(*)
             ]
         )
         .filter(pl.col("c") > 100000)  # HAVING COUNT(*) > 100000
         .sort("l", descending=True)  # ORDER BY l DESC
-        .limit(25),  # LIMIT 25,
+        .limit(25).collect(),  # LIMIT 25,
     ),
     (
         "Q28",
@@ -352,18 +350,14 @@
             .group_by("k")
             .agg(
                 [
-                    pl.col("Referer").map_elements(
-                        lambda y: len(y), return_dtype=pl.Int64
-                    )
-                    # .mean() # skip mean for now
-                    .alias("l"),  # AVG(STRLEN(Referer))
+                    pl.col("Referer").str.len_chars().mean().alias("l"),  # AVG(STRLEN(Referer))
                     pl.col("Referer").min().alias("min_referer"),  # MIN(Referer)
                     pl.len().alias("c"),  # COUNT(*)
                 ]
             )
             .filter(pl.col("c") > 100000)  # HAVING COUNT(*) > 100000
             .sort("l", descending=True)  # ORDER BY l DESC
-            .limit(25)  # LIMIT 25
+            .limit(25).collect()  # LIMIT 25
         ),
     ),
     (
@@ -459,7 +453,7 @@
         + x["ResolutionWidth"].shift(87).sum()
         + x["ResolutionWidth"].shift(88).sum()
         + x["ResolutionWidth"].shift(89).sum(),
-        lambda x: sum(x["ResolutionWidth"][:90] + pl.Series(range(90))),
+        lambda x: x.select(pl.sum_horizontal([pl.col("ResolutionWidth").shift(i) for i in range(1, 90)])).collect(),
     ),
     (
         "Q30",
@@ -482,7 +476,7 @@
             ]
         )
         .sort("c", descending=True)
-        .head(10),
+        .head(10).collect(),
     ),
     (
         "Q31",
@@ -505,7 +499,7 @@
             ]
         )
         .sort("c", descending=True)
-        .head(10),
+        .head(10).collect(),
     ),
     (
         "Q32",
@@ -526,7 +520,7 @@
             ]
         )
         .sort("c", descending=True)
-        .head(10),
+        .head(10).collect(),
     ),
     (
         "Q33",
@@ -535,7 +529,7 @@
         lambda x: x.group_by("URL")
         .agg(pl.len().alias("c"))
         .sort("c", descending=True)
-        .head(10),
+        .head(10).collect(),
     ),
     (
         "Q34",
@@ -544,7 +538,7 @@
         lambda x: x.group_by("URL")
         .agg(pl.len().alias("c"))
         .sort("c", descending=True)
-        .head(10),
+        .head(10).collect(),
     ),
     (
         "Q35",
@@ -562,7 +556,7 @@
         .group_by(["ClientIP"])
         .agg(pl.len().alias("c"))
         .sort("c", descending=True)
-        .head(10),
+        .head(10).collect(),
     ),
     (
         "Q36",
@@ -580,8 +574,8 @@
         .nlargest(10),
         lambda x: x.filter(
             (pl.col("CounterID") == 62)
-            & (pl.col("EventDate") >= pl.datetime(2013, 7, 1))
-            & (pl.col("EventDate") <= pl.datetime(2013, 7, 31))
+            & (pl.col("EventDate") >= datetime(2013, 7, 1))
+            & (pl.col("EventDate") <= datetime(2013, 7, 31))
             & (pl.col("DontCountHits") == 0)
             & (pl.col("IsRefresh") == 0)
             & (pl.col("URL") != "")
@@ -589,7 +583,7 @@
         .group_by("URL")
         .agg(pl.len().alias("PageViews"))
         .sort("PageViews", descending=True)
-        .head(10),
+        .head(10).collect(),
     ),
     (
         "Q37",
@@ -607,8 +601,8 @@
         .nlargest(10),
         lambda x: x.filter(
             (pl.col("CounterID") == 62)
-            & (pl.col("EventDate") >= pl.datetime(2013, 7, 1))
-            & (pl.col("EventDate") <= pl.datetime(2013, 7, 31))
+            & (pl.col("EventDate") >= datetime(2013, 7, 1))
+            & (pl.col("EventDate") <= datetime(2013, 7, 31))
             & (pl.col("DontCountHits") == 0)
             & (pl.col("IsRefresh") == 0)
             & (pl.col("Title") != "")
@@ -616,7 +610,7 @@
         .group_by("Title")
         .agg(pl.len().alias("PageViews"))
         .sort("PageViews", descending=True)
-        .head(10),
+        .head(10).collect(),
     ),
     (
         "Q38",
@@ -636,8 +630,8 @@
         .iloc[1000:1010],
         lambda x: x.filter(
             (pl.col("CounterID") == 62)
-            & (pl.col("EventDate") >= pl.datetime(2013, 7, 1))
-            & (pl.col("EventDate") <= pl.datetime(2013, 7, 31))
+            & (pl.col("EventDate") >= datetime(2013, 7, 1))
+            & (pl.col("EventDate") <= datetime(2013, 7, 31))
             & (pl.col("IsRefresh") == 0)
             & (pl.col("IsLink") != 0)
             & (pl.col("IsDownload") == 0)
@@ -645,7 +639,7 @@
         .group_by("URL")
         .agg(pl.len().alias("PageViews"))
         .sort("PageViews", descending=True)
-        .slice(1000, 10),
+        .slice(1000, 10).collect(),
     ),
     (
         "Q39",
@@ -668,8 +662,8 @@
         # note: run with `RUST_BACKTRACE=1` environment variable to display a backtrace
         # lambda x: x.filter(
         #     (pl.col("CounterID") == 62)
-        #     & (pl.col("EventDate") >= pl.datetime(2013, 7, 1))
-        #     & (pl.col("EventDate") <= pl.datetime(2013, 7, 31))
+        #     & (pl.col("EventDate") >= datetime(2013, 7, 1))
+        #     & (pl.col("EventDate") <= datetime(2013, 7, 31))
         #     & (pl.col("IsRefresh") == 0)
         # )
         # .group_by(
@@ -706,8 +700,8 @@
         .iloc[100:110],
         lambda x: x.filter(
             (pl.col("CounterID") == 62)
-            & (pl.col("EventDate") >= pl.datetime(2013, 7, 1))
-            & (pl.col("EventDate") <= pl.datetime(2013, 7, 31))
+            & (pl.col("EventDate") >= datetime(2013, 7, 1))
+            & (pl.col("EventDate") <= datetime(2013, 7, 31))
             & (pl.col("IsRefresh") == 0)
             & (pl.col("TraficSourceID").is_in([-1, 6]))
             & (pl.col("RefererHash") == 3594120000172545465)
@@ -715,7 +709,7 @@
         .group_by(["URLHash", "EventDate"])
         .agg(pl.len().alias("PageViews"))
         .sort("PageViews", descending=True)
-        .slice(100, 10),
+        .slice(100, 10).collect(),
     ),
     (
         "Q41",
@@ -735,8 +729,8 @@
         .iloc[10000:10010],
         lambda x: x.filter(
             (pl.col("CounterID") == 62)
-            & (pl.col("EventDate") >= pl.datetime(2013, 7, 1))
-            & (pl.col("EventDate") <= pl.datetime(2013, 7, 31))
+            & (pl.col("EventDate") >= datetime(2013, 7, 1))
+            & (pl.col("EventDate") <= datetime(2013, 7, 31))
             & (pl.col("IsRefresh") == 0)
             & (pl.col("DontCountHits") == 0)
             & (pl.col("URLHash") == 2868770270353813622)
@@ -744,7 +738,7 @@
         .group_by(["WindowClientWidth", "WindowClientHeight"])
         .agg(pl.len().alias("PageViews"))
         .sort("PageViews", descending=True)
-        .slice(10000, 10),
+        .slice(10000, 10).collect(),
     ),
     (
         "Q42",
@@ -766,8 +760,8 @@
         # expected leading integer in the duration string, found m
         # lambda x: x.filter(
         #     (pl.col("CounterID") == 62)
-        #     & (pl.col("EventDate") >= pl.datetime(2013, 7, 14))
-        #     & (pl.col("EventDate") <= pl.datetime(2013, 7, 15))
+        #     & (pl.col("EventDate") >= datetime(2013, 7, 14))
+        #     & (pl.col("EventDate") <= datetime(2013, 7, 15))
         #     & (pl.col("IsRefresh") == 0)
         #     & (pl.col("DontCountHits") == 0)
         # )
@@ -792,7 +786,7 @@
 
 result_json = {
     "system": "Polars (DataFrame)",
-    "date": datetime.date.today().strftime("%Y-%m-%d"),
+    "date": date.today().strftime("%Y-%m-%d"),
     "machine": "c6a.metal, 500gb gp2",
     "cluster_size": 1,
     "comment": "",