feat: add concurrent table synchronization using ThreadPoolExecutor
daoleno committed Dec 17, 2023
1 parent d604984 commit 24787d3
Showing 1 changed file with 59 additions and 67 deletions.
126 changes: 59 additions & 67 deletions bq-syncer/sync_parquet.py
@@ -1,4 +1,5 @@
import argparse
import concurrent.futures
import os
import sys
import time
@@ -33,6 +34,59 @@
os.makedirs(output_directory, exist_ok=True)


def sync_table(table_item, index, total_tables):
last_timestamp = 0
table_id = table_item.table_id
table_ref = dataset_ref.table(table_id)
table = bqclient.get_table(table_ref) # get table object

parquet_file_path = os.path.join(output_directory, f"{table_id}.parquet")

# If the parquet file exists, get the maximum timestamp
if os.path.exists(parquet_file_path):
df_old = pl.read_parquet(parquet_file_path)
if "datastream_metadata.source_timestamp" in df_old.columns:
last_timestamp = df_old["datastream_metadata.source_timestamp"].max()

# Generate list of fields, preserving the original schema's order.
fields = [
f.name
if f.name != "datastream_metadata"
else "datastream_metadata.source_timestamp"
for f in table.schema
]
try:
query = f"SELECT {', '.join(fields)} FROM `{table_ref}` WHERE datastream_metadata.source_timestamp > {last_timestamp}"
query_job = bqclient.query(query)
iterator = query_job.result(page_size=10000)

pages_received = 0
for page in iterator.pages:
pages_received += 1
print(
f"[{datetime.now()}] Processing table {index}/{total_tables}: {table_id} - Page {pages_received}"
)
items = list(page)
if len(items) == 0:
print(
f"[{datetime.now()}] No data received for table {index}/{total_tables}: {table_id}"
)
continue
df = pl.DataFrame({field: data for field, data in zip(fields, zip(*items))})
if os.path.exists(parquet_file_path):
df_old = pl.read_parquet(parquet_file_path)
df = df_old.vstack(df)
df.write_parquet(parquet_file_path)
print(
f"[{datetime.now()}] Data sync of table {index}/{total_tables}: {table_id} completed."
)
except Exception as table_related_error:
print(
f"An error occurred while processing table {table_id}: {table_related_error}"
)
sys.exit(1)


def perform_sync_task():
global is_task_running

@@ -48,73 +102,11 @@ def perform_sync_task():
total_tables = len(tables)

try:
for index, table_item in enumerate(tables, start=1):
last_timestamp = 0
table_id = table_item.table_id
table_ref = dataset_ref.table(table_id)
table = bqclient.get_table(table_ref) # get table object

parquet_file_path = os.path.join(output_directory, f"{table_id}.parquet")

# If the parquet file exists, get the maximum timestamp
if os.path.exists(parquet_file_path):
df_old = pl.read_parquet(parquet_file_path)

# Check if column 'source_timestamp' exists in the DataFrame
if "datastream_metadata.source_timestamp" in df_old.columns:
last_timestamp = df_old[
"datastream_metadata.source_timestamp"
].max()

# Generate list of fields, preserving the original schema's order.
fields = [
f.name
if f.name != "datastream_metadata"
else "datastream_metadata.source_timestamp"
for f in table.schema
]

try:
query = f"SELECT {', '.join(fields)} FROM `{table_ref}` WHERE datastream_metadata.source_timestamp > {last_timestamp}"
query_job = bqclient.query(query)
iterator = query_job.result(page_size=10000)

pages_received = 0

for page in iterator.pages:
pages_received += 1
print(
f"[{datetime.now()}] Processing table {index}/{total_tables}: {table_id} - Page {pages_received}"
)

items = list(page)

# If no data, skip the loop
if len(items) == 0:
print(
f"[{datetime.now()}] No data received for table {index}/{total_tables}: {table_id}"
)
continue

df = pl.DataFrame(
{field: data for field, data in zip(fields, zip(*items))}
)

if os.path.exists(parquet_file_path):
df_old = pl.read_parquet(parquet_file_path)
df = df_old.vstack(df)

df.write_parquet(parquet_file_path)

print(
f"[{datetime.now()}] Data sync of table {index}/{total_tables}: {table_id} completed."
)
except Exception as table_related_error:
print(
f"An error occurred while processing table {table_id}: {table_related_error}"
)
sys.exit(1)

with concurrent.futures.ThreadPoolExecutor(
max_workers=5
) as executor:  # adjust max_workers as needed
for index, table_item in enumerate(tables, start=1):
executor.submit(sync_table, table_item, index, total_tables)
except Exception as e:
print(f"An error occurred: {e}")
is_task_running = False
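
A note on error handling with this pattern: executor.submit() returns a Future, and any exception raised inside sync_table (including the SystemExit raised by sys.exit(1), since it runs in a worker thread) is captured on that Future rather than stopping the process. Because the futures returned here are never inspected, a failing table can fail silently. Below is a minimal, illustrative sketch (not part of the commit) that reuses the sync_table, tables, and total_tables names from this diff; run_sync_concurrently and its max_workers default are hypothetical.

import concurrent.futures

def run_sync_concurrently(tables, total_tables, max_workers=5):
    # Hypothetical wrapper, not part of the commit: submit one sync_table job
    # per table and surface worker exceptions in the calling thread.
    with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor:
        futures = {
            executor.submit(sync_table, table_item, index, total_tables): table_item.table_id
            for index, table_item in enumerate(tables, start=1)
        }
        for future in concurrent.futures.as_completed(futures):
            # result() re-raises any exception (or SystemExit) captured in the worker.
            future.result()
            print(f"Finished syncing {futures[future]}")

Calling future.result() is what turns a per-table failure into a visible error; without it, the exception stays on the unobserved Future and the calling code reports nothing.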