Enable parallel backfill in run.py (#672)
* wip

* wip

* wip

* added ut

* fix indent

* added blank line

* simplification

* Update run.py

update description. 

Signed-off-by: Pengyu Hou <[email protected]>

---------

Signed-off-by: Pengyu Hou <[email protected]>
pengyu-hou authored Feb 7, 2024
1 parent 06ff92e commit b59ab4a
Showing 2 changed files with 109 additions and 26 deletions.
120 changes: 94 additions & 26 deletions api/py/ai/chronon/repo/run.py
@@ -25,7 +25,8 @@
 import re
 import subprocess
 import xml.etree.ElementTree as ET
-from datetime import datetime
+from datetime import datetime, timedelta
+import multiprocessing
 
 ONLINE_ARGS = "--online-jar={online_jar} --online-class={online_class} "
 OFFLINE_ARGS = "--conf-path={conf_path} --end-date={ds} "
@@ -373,7 +374,8 @@ def __init__(self, args, jar_path):
         else:
             self.conf_type = args.conf_type
         self.ds = args.end_ds if hasattr(args, 'end_ds') and args.end_ds else args.ds
-        self.parallelism = args.parallelism if hasattr(args, 'parallelism') and args.parallelism else 1
+        self.start_ds = args.start_ds if hasattr(args, 'start_ds') and args.start_ds else None
+        self.parallelism = int(args.parallelism) if hasattr(args, 'parallelism') and args.parallelism else 1
         self.jar_path = jar_path
         self.args = args.args if args.args else ""
         self.online_class = args.online_class
@@ -390,27 +392,22 @@ def __init__(self, args, jar_path):
         self.list_apps_cmd = args.list_apps
 
     def run(self):
-        base_args = MODE_ARGS[self.mode].format(
-            conf_path=self.conf,
-            ds=self.ds,
-            online_jar=self.online_jar,
-            online_class=self.online_class,
-        )
-        final_args = base_args + " " + str(self.args)
+        command_list = []
         if self.mode == "info":
-            command = "python3 {script} --conf {conf} --ds {ds} --repo {repo}".format(
+            command_list.append("python3 {script} --conf {conf} --ds {ds} --repo {repo}".format(
                 script=self.render_info, conf=self.conf, ds=self.ds, repo=self.repo
-            )
+            ))
         elif self.sub_help or (self.mode not in SPARK_MODES):
-            command = (
+            command_list.append(
                 "java -cp {jar} ai.chronon.spark.Driver {subcommand} {args}".format(
                     jar=self.jar_path,
-                    args="--help" if self.sub_help else final_args,
+                    args="--help" if self.sub_help else self._gen_final_args(),
                     subcommand=ROUTES[self.conf_type][self.mode],
                 )
             )
         else:
             if self.mode in ["streaming", "streaming-client"]:
+                # streaming mode
                 self.app_name = self.app_name.replace(
                     "_streaming-client_", "_streaming_"
                 ) # If the job is running cluster mode we want to kill it.
@@ -455,16 +452,87 @@ def run(self):
                         "Attempting to submit an application in client mode, but there's already"
                         " an existing one running."
                     )
-                command = (
-                    "bash {script} --class ai.chronon.spark.Driver {jar} {subcommand} {args} {additional_args}"
-                ).format(
-                    script=self.spark_submit,
-                    jar=self.jar_path,
-                    subcommand=ROUTES[self.conf_type][self.mode],
-                    args=final_args,
-                    additional_args=os.environ.get("CHRONON_CONFIG_ADDITIONAL_ARGS", ""),
-                )
-                check_call(command)
+                command = (
+                    "bash {script} --class ai.chronon.spark.Driver {jar} {subcommand} {args} {additional_args}"
+                ).format(
+                    script=self.spark_submit,
+                    jar=self.jar_path,
+                    subcommand=ROUTES[self.conf_type][self.mode],
+                    args=self._gen_final_args(),
+                    additional_args=os.environ.get("CHRONON_CONFIG_ADDITIONAL_ARGS", ""),
+                )
+                command_list.append(command)
+            else:
+                # offline mode
+                if self.parallelism > 1:
+                    assert self.start_ds is not None and self.ds is not None, \
+                        "To use parallelism, please specify --start-ds and --end-ds to " \
+                        "break down into multiple backfill jobs"
+                    date_ranges = split_date_range(self.start_ds, self.ds, self.parallelism)
+                    for (start_ds, end_ds) in date_ranges:
+                        command = (
+                            "bash {script} --class ai.chronon.spark.Driver {jar} {subcommand} {args} {additional_args}"
+                        ).format(
+                            script=self.spark_submit,
+                            jar=self.jar_path,
+                            subcommand=ROUTES[self.conf_type][self.mode],
+                            args=self._gen_final_args(start_ds=start_ds, end_ds=end_ds),
+                            additional_args=os.environ.get("CHRONON_CONFIG_ADDITIONAL_ARGS", ""),
+                        )
+                        command_list.append(command)
+                else:
+                    command = (
+                        "bash {script} --class ai.chronon.spark.Driver {jar} {subcommand} {args} {additional_args}"
+                    ).format(
+                        script=self.spark_submit,
+                        jar=self.jar_path,
+                        subcommand=ROUTES[self.conf_type][self.mode],
+                        args=self._gen_final_args(self.start_ds),
+                        additional_args=os.environ.get("CHRONON_CONFIG_ADDITIONAL_ARGS", ""),
+                    )
+                    command_list.append(command)
+        if len(command_list) > 1:
+            # parallel backfill mode
+            with multiprocessing.Pool(processes=int(self.parallelism)) as pool:
+                logging.info("Running args list {} with pool size {}".format(command_list, self.parallelism))
+                pool.map(check_call, command_list)
+        elif len(command_list) == 1:
+            check_call(command_list[0])

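Aside (not part of the diff): a minimal, self-contained sketch of the fan-out pattern run() now uses, with one shell command per date range mapped over a worker pool. Here "echo" stands in for the real spark-submit wrapper, and the ranges are illustrative.

import multiprocessing
import subprocess

def run_one(cmd):
    # Stand-in for run.py's check_call: execute one backfill command in a shell.
    subprocess.check_call(cmd, shell=True)

if __name__ == "__main__":
    ranges = [("2022-01-01", "2022-01-02"), ("2022-01-03", "2022-01-04")]
    commands = [
        "echo backfill --start-partition-override={} --end-date={}".format(s, e)
        for (s, e) in ranges
    ]
    # One worker per range, mirroring the multiprocessing.Pool usage in run() above.
    with multiprocessing.Pool(processes=len(commands)) as pool:
        pool.map(run_one, commands)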
+    def _gen_final_args(self, start_ds=None, end_ds=None):
+        base_args = MODE_ARGS[self.mode].format(
+            conf_path=self.conf,
+            ds=end_ds if end_ds else self.ds,
+            online_jar=self.online_jar,
+            online_class=self.online_class,
+        )
+        override_start_partition_arg = " --start-partition-override=" + start_ds if start_ds else ""
+        final_args = base_args + " " + str(self.args) + override_start_partition_arg
+        return final_args
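For concreteness, a hedged sketch of the arguments one split produces, assuming backfill mode maps to OFFLINE_ARGS above; the conf path is hypothetical and whitespace is shown collapsed:

# runner._gen_final_args(start_ds="2022-01-03", end_ds="2022-01-04")
# -> "--conf-path=production/joins/team/my_join --end-date=2022-01-04 --start-partition-override=2022-01-03"
#    (any user-supplied extra args from self.args appear between the two date flags)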


+def split_date_range(start_date, end_date, parallelism):
+    start_date = datetime.strptime(start_date, "%Y-%m-%d")
+    end_date = datetime.strptime(end_date, "%Y-%m-%d")
+    if start_date > end_date:
+        raise ValueError("Start date should be earlier than end date")
+    total_days = (end_date - start_date).days + 1  # +1 to include the end_date in the range
+
+    # Check if parallelism is greater than total_days
+    if parallelism > total_days:
+        raise ValueError("Parallelism should be less than or equal to total days")
+
+    split_size = total_days // parallelism
+    date_ranges = []
+
+    for i in range(parallelism):
+        split_start = start_date + timedelta(days=i * split_size)
+        if i == parallelism - 1:
+            split_end = end_date
+        else:
+            split_end = split_start + timedelta(days=split_size - 1)
+        date_ranges.append((split_start.strftime("%Y-%m-%d"), split_end.strftime("%Y-%m-%d")))
+    return date_ranges
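A worked example of the split, mirroring the unit test below: 11 days over a parallelism of 5 gives split_size = 11 // 5 = 2, and the last range absorbs the remainder:

split_date_range("2022-01-01", "2022-01-11", 5)
# [('2022-01-01', '2022-01-02'),
#  ('2022-01-03', '2022-01-04'),
#  ('2022-01-05', '2022-01-06'),
#  ('2022-01-07', '2022-01-08'),
#  ('2022-01-09', '2022-01-11')]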


def set_defaults(parser):
Expand Down Expand Up @@ -514,7 +582,8 @@ def set_defaults(parser):
"--end-ds", help="the end ds for a range backfill"
)
parser.add_argument(
"--parallelism", help="break down the backfill range into this number of tasks in parallel"
"--parallelism", help="break down the backfill range into this number of tasks in parallel. "
"Please use it along with --start-ds and --end-ds and only in manual mode"
)
parser.add_argument("--repo", help="Path to chronon repo")
parser.add_argument(
Expand Down Expand Up @@ -572,8 +641,7 @@ def set_defaults(parser):
args, unknown_args = parser.parse_known_args()
jar_type = "embedded" if args.mode in MODES_USING_EMBEDDED else "uber"
extra_args = (" " + args.online_args) if args.mode in ONLINE_MODES else ""
override_start_partition_arg = "--start-partition-override=" + args.start_ds if args.start_ds else ""
args.args = " ".join(unknown_args) + extra_args + override_start_partition_arg
args.args = " ".join(unknown_args) + extra_args
jar_path = (
args.chronon_jar
if args.chronon_jar
15 changes: 15 additions & 0 deletions api/py/test/test_run.py
@@ -289,3 +289,18 @@ def mock_check_output_with_app_other_user(cmd):
     runner = run.Runner(parse_args, 'some.jar')
     with pytest.raises(RuntimeError):
         runner.run()
+
+
+def test_split_date_range():
+    start_date = "2022-01-01"
+    end_date = "2022-01-11"
+    parallelism = 5
+    expected_result = [('2022-01-01', '2022-01-02'),
+                       ('2022-01-03', '2022-01-04'),
+                       ('2022-01-05', '2022-01-06'),
+                       ('2022-01-07', '2022-01-08'),
+                       ('2022-01-09', '2022-01-11')]
+
+    result = run.split_date_range(start_date, end_date, parallelism)
+    assert result == expected_result
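For reference, a hypothetical invocation exercising the new flags (the flags come from the diff above; the mode name and conf path are illustrative):

python3 run.py --mode backfill --conf production/joins/team/my_join \
    --start-ds 2022-01-01 --end-ds 2022-01-11 --parallelism 5

This would split the range into five backfill jobs and fan them out over the multiprocessing pool.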
