-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathrun_aws.py
85 lines (68 loc) · 2.82 KB
/
run_aws.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
import json
import os
import hydra
from omegaconf import DictConfig, OmegaConf, errors
from sagemaker.debugger import TensorBoardOutputConfig
from sagemaker.pytorch.estimator import PyTorch
if __name__ == "__main__":

    # Define hydra calling method
    @hydra.main()
    def my_aws_app(cfg: DictConfig) -> None:
        """Submit a SageMaker PyTorch training job described by a Hydra config.

        The config is dumped to a temporary JSON file inside the script folder
        (so it is uploaded together with the source), with its path entries
        rewritten to the SageMaker container directories. The S3 locations for
        input data and outputs are derived from the original config values.

        Config keys read:
            aws.bucket_prefix, aws.root_path (falls back to the top-level
            root_path), output_subdir, data_subdir, aws.wait, aws.role,
            aws.instance_count, aws.instance_type, experiment_name.

        Args:
            cfg: Configuration object supplied by ``hydra.main``.

        Raises:
            KeyError: If ``output_subdir`` or ``data_subdir`` is missing from
                the config.
        """
        # Use the directory the script was launched from; per the original
        # note, the working directory is overridden by Hydra.
        script_folder = hydra.utils.get_original_cwd()
        as_dict = OmegaConf.to_container(cfg, resolve=False)

        # Override s3 datapath: prefer aws.root_path, else top-level root_path.
        aws_bucket = cfg.aws.bucket_prefix
        try:
            aws_root_path = aws_bucket + cfg.aws.root_path
        except errors.ConfigAttributeError:
            aws_root_path = aws_bucket + cfg.root_path

        # Get the s3 location to load / save to. These are S3 URIs, so they
        # are joined with "/" explicitly rather than with os.path.
        aws_out_path = aws_root_path + "/" + as_dict["output_subdir"]
        aws_data_path = aws_root_path + "/" + as_dict["data_subdir"]

        # Override the job json file with sagemaker local dirs. This must
        # happen after the S3 paths above were computed from the originals.
        as_dict["root_path"] = "/opt/ml/"
        as_dict["data_subdir"] = "input/data/train"
        as_dict["output_subdir"] = "output/data"

        # Set the local dir for tensorboard
        tb_log_dir = "/opt/ml/output/tensorboard/"
        as_dict["tb_log_dir"] = tb_log_dir
        tensorboard_output_config = TensorBoardOutputConfig(
            s3_output_path=aws_out_path,
            container_local_output_path=tb_log_dir,
        )

        print(OmegaConf.to_yaml(cfg))
        print("Overriden Root Path: " + aws_root_path)

        # Read every launch setting before creating the temp file, so a
        # missing config key cannot leave the file behind.
        wait = cfg.aws.wait
        role = cfg.aws.role
        instance_count = cfg.aws.instance_count
        instance_type = cfg.aws.instance_type
        env = {
            # path relative to `source_dir` below.
            "SAGEMAKER_REQUIREMENTS": "requirements.txt",
        }

        # Save json file to tmp location to be uploaded with script. The
        # relative path is also passed to the job as a hyperparameter, so it
        # keeps its forward-slash form.
        tmp_relative_path = "tmp/tmp_job.json"
        tmp_path = os.path.join(script_folder, tmp_relative_path)
        # The tmp folder may not exist yet on a fresh checkout.
        os.makedirs(os.path.dirname(tmp_path), exist_ok=True)
        with open(tmp_path, "w") as json_file:
            json.dump(as_dict, json_file)

        try:
            # Using Sagemaker prebuilt Pytorch container
            pytorch_estimator = PyTorch(
                entry_point="run.py",
                source_dir=script_folder,
                hyperparameters={"config_file": tmp_relative_path},
                role=role,
                env=env,
                instance_count=instance_count,
                py_version="py3",
                framework_version="1.5.0",
                output_path=aws_out_path,
                base_job_name=cfg.experiment_name,
                instance_type=instance_type,
                tensorboard_output_config=tensorboard_output_config,
            )
            pytorch_estimator.fit({"train": aws_data_path}, wait=wait)
        finally:
            # Always clean up the temp job file, even when the launch fails.
            os.remove(tmp_path)

    my_aws_app()