# Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License").
# You may not use this file except in compliance with the License.
# A copy of the License is located at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# or in the "license" file accompanying this file. This file is distributed
# on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
# express or implied. See the License for the specific language governing
# permissions and limitations under the License.
from __future__ import print_function

import argparse
import json
import logging
import os
import pickle as pkl

import pandas as pd
import xgboost as xgb

from sagemaker_containers import entry_point
from sagemaker_xgboost_container import distributed
from sagemaker_xgboost_container.data_utils import get_dmatrix


def _xgb_train(params, dtrain, evals, num_boost_round, model_dir, is_master):
    """Run xgb train on arguments given with rabit initialized.

    This is our rabit execution function.

    :param params: Booster hyperparameters passed to xgb.train().
    :param dtrain: Training data as an xgb.DMatrix.
    :param evals: Watchlist of (DMatrix, name) pairs evaluated each round.
    :param num_boost_round: Number of boosting rounds.
    :param model_dir: Directory where the master host writes the trained model.
    :param is_master: True if current node is master host in distributed training,
        or is running single node training job.
        Note that rabit_run will include this argument.
    """
    booster = xgb.train(params=params, dtrain=dtrain, evals=evals, num_boost_round=num_boost_round)

    if is_master:
        model_location = model_dir + "/xgboost-model"
        pkl.dump(booster, open(model_location, "wb"))
        logging.info("Stored trained model at {}".format(model_location))
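

# Note: the artifact name "xgboost-model" written above must match the file name
# that model_fn() at the bottom of this script loads when the model is deserialized
# for inference.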


if __name__ == "__main__":
    parser = argparse.ArgumentParser()

    # Hyperparameters are described here.
    parser.add_argument(
        "--max_depth",
        type=int,
    )
    parser.add_argument("--eta", type=float)
    parser.add_argument("--gamma", type=int)
    parser.add_argument("--min_child_weight", type=int)
    parser.add_argument("--subsample", type=float)
    parser.add_argument("--verbosity", type=int)
    parser.add_argument("--objective", type=str)
    parser.add_argument("--num_round", type=int)
    parser.add_argument("--tree_method", type=str, default="auto")
    parser.add_argument("--predictor", type=str, default="auto")

    # SageMaker-specific arguments. Defaults are set in the environment variables.
    parser.add_argument("--output_data_dir", type=str, default=os.environ.get("SM_OUTPUT_DATA_DIR"))
    parser.add_argument("--model_dir", type=str, default=os.environ.get("SM_MODEL_DIR"))
    parser.add_argument("--train", type=str, default=os.environ.get("SM_CHANNEL_TRAIN"))
    parser.add_argument("--validation", type=str, default=os.environ.get("SM_CHANNEL_VALIDATION"))
    parser.add_argument("--sm_hosts", type=str, default=os.environ.get("SM_HOSTS"))
    parser.add_argument("--sm_current_host", type=str, default=os.environ.get("SM_CURRENT_HOST"))

    args, _ = parser.parse_known_args()
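
    # When this script runs as a SageMaker training job, hyperparameters set on the
    # estimator arrive as command-line flags (e.g. --max_depth 5), and the SM_*
    # environment variables above are populated by the training container.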

    # Get SageMaker host information from runtime environment variables
    sm_hosts = json.loads(args.sm_hosts)
    sm_current_host = args.sm_current_host

    dtrain = get_dmatrix(args.train, "libsvm")
    dval = get_dmatrix(args.validation, "libsvm")
    watchlist = (
        [(dtrain, "train"), (dval, "validation")] if dval is not None else [(dtrain, "train")]
    )
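
    # get_dmatrix (used above) reads the files under a channel directory into a
    # DMatrix and may return None when a channel is not provided, which is why dval
    # and dtrain are checked. With "libsvm", each row is "<label> <index>:<value> ...",
    # e.g. "15 1:0.455 2:0.365 ..." (illustrative values only).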

    train_hp = {
        "max_depth": args.max_depth,
        "eta": args.eta,
        "gamma": args.gamma,
        "min_child_weight": args.min_child_weight,
        "subsample": args.subsample,
        "verbosity": args.verbosity,
        "objective": args.objective,
        "tree_method": args.tree_method,
        "predictor": args.predictor,
    }

    xgb_train_args = dict(
        params=train_hp,
        dtrain=dtrain,
        evals=watchlist,
        num_boost_round=args.num_round,
        model_dir=args.model_dir,
    )
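
    # xgb_train_args is forwarded to _xgb_train: in the distributed branch below,
    # rabit_run supplies the is_master flag on each host (update_rabit_args=True),
    # while the single-node branch sets it explicitly.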

    if len(sm_hosts) > 1:
        # Wait until all hosts are able to find each other
        entry_point._wait_hostname_resolution()

        # Execute training function after initializing rabit.
        distributed.rabit_run(
            exec_fun=_xgb_train,
            args=xgb_train_args,
            include_in_training=(dtrain is not None),
            hosts=sm_hosts,
            current_host=sm_current_host,
            update_rabit_args=True,
        )
    else:
        # If single node training, call training method directly.
        if dtrain:
            xgb_train_args["is_master"] = True
            _xgb_train(**xgb_train_args)
        else:
            raise ValueError("Training channel must have data to train model.")
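

# When the training job exits successfully, SageMaker archives the contents of
# model_dir as model.tar.gz and uploads it to S3; the hosting container later unpacks
# that archive and calls model_fn below to deserialize the booster.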


def model_fn(model_dir):
    """Deserialize and return fitted model.

    Note that this should have the same name as the serialized model in the _xgb_train method.
    """
    model_file = "xgboost-model"
    booster = pkl.load(open(os.path.join(model_dir, model_file), "rb"))
    return booster
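

# Launch sketch (not executed by this script): one way to run abalone.py as a
# SageMaker training job with the Python SDK's XGBoost framework estimator.
# The framework_version, role ARN, instance settings, hyperparameter values, and
# S3 URIs below are illustrative placeholders, not values defined in this repository;
# channel data is expected in libsvm format to match get_dmatrix(..., "libsvm") above.
#
#   import sagemaker
#   from sagemaker.xgboost.estimator import XGBoost
#
#   estimator = XGBoost(
#       entry_point="abalone.py",
#       framework_version="1.7-1",
#       role="<execution-role-arn>",
#       instance_count=2,                # >1 exercises the rabit_run branch above
#       instance_type="ml.m5.xlarge",
#       hyperparameters={
#           "max_depth": 5,
#           "eta": 0.2,
#           "gamma": 4,
#           "min_child_weight": 6,
#           "subsample": 0.7,
#           "verbosity": 1,
#           "objective": "reg:squarederror",
#           "num_round": 50,
#       },
#   )
#   estimator.fit({
#       "train": "s3://<bucket>/abalone/train",
#       "validation": "s3://<bucket>/abalone/validation",
#   })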