Commit
improve reliability of logging calls and live output post requests
ndharasz committed Oct 22, 2024
1 parent e51afda commit 608b2d9
Showing 1 changed file with 59 additions and 59 deletions.
predict.py
@@ -9,7 +9,7 @@
 import urllib
 import time
 import random
-import glob
+import io
 from inspect import signature

 from numerapi import NumerAPI
@@ -92,85 +92,94 @@ def exit_with_help(error):


 def retry_request_with_backoff(
-    url: str,
+    request_func: callable,
     retries: int = 10,
     delay_base: float = 1.5,
     delay_exp: float = 1.5,
-    debug: bool = False,
 ):
+    delay_base = max(1.1, delay_base)
+    delay_exp = max(1.1, delay_exp)
     curr_delay = delay_base
     for i in range(retries):
         try:
-            response = requests.get(url, stream=True, allow_redirects=True)
+            response = request_func()
+            logging.info("HTTP Response Status: %s", response.status_code)
+            if response.status_code >= 500:
+                logging.error("Encountered Server Error. Retrying...")
+                time.sleep(curr_delay)
+                curr_delay **= random.uniform(1, delay_exp)
+            elif 200 <= response.status_code < 300:
+                logging.debug("Request successful! Returning...")
+                return response
+            else:
+                raise RuntimeError(f"HTTP Error {response.reason} - {response.text}")
         except requests.exceptions.ConnectionError:
             logging.error("Connection reset! Retrying...")
             time.sleep(curr_delay)
             curr_delay **= random.uniform(1, delay_exp)
             continue
-        if response.status_code >= 500:
-            if debug:
-                logging.debug(f"Encountered Server Error. Retrying...")
+        except requests.exceptions.SSLError as e:
+            logging.error("SSL Error: %s", e)
+        finally:
+            logging.info("Retrying in %s seconds...", curr_delay)
             time.sleep(curr_delay)
             curr_delay **= random.uniform(1, delay_exp)
-        elif response.status_code != 200:
-            logging.error(f"{response.reason} {response.text}")
-            sys.exit(1)
-        else:
-            logging.debug(f"Request successful! Returning...")
-            return response
+    raise RuntimeError(f"Could not complete function call after {retries} retries...")
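Note: the helper now takes a request_func callable instead of a URL, so the same backoff loop can wrap any HTTP call. A rough usage sketch, with an illustrative URL and retry count that are not from this commit:

    import requests

    # Wrap a GET in a zero-argument callable, as main() does further down.
    response = retry_request_with_backoff(
        lambda: requests.get("https://example.com/model.pkl", stream=True, allow_redirects=True),
        retries=3,
    )
    # Each retry applies curr_delay **= random.uniform(1, delay_exp), so the
    # wait grows with random jitter in the exponent on every failed attempt.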


 def get_data(dataset, output_dir):
     if os.path.exists(dataset):
         dataset_path = dataset
-        logging.info(f"Using local {dataset_path} for live data")
+        logging.info("Using local %s for live data", dataset_path)
     elif dataset.startswith("/"):
-        logging.error(f"Local dataset not found - {dataset} does not exist!")
+        logging.error("Local dataset not found - %s does not exist!", dataset)
         exit_with_help(1)
     else:
         dataset_path = os.path.join(output_dir, dataset)
-        logging.info(f"Using NumerAPI to download {dataset} for live data")
+        logging.info("Using NumerAPI to download %s for live data", dataset)
         napi = NumerAPI()
         napi.download_dataset(dataset, dataset_path)
-    logging.info(f"Loading live features {dataset_path}")
+    logging.info("Loading live features %s", dataset_path)
     live_features = pd.read_parquet(dataset_path)
     return live_features


 def main(args):
     logging.getLogger().setLevel(logging.DEBUG if args.debug else logging.INFO)

-    logging.info(f"Running numerai-predict:{os.getenv('GIT_REF')} Python{py_version()}")
+    logging.info(
+        "Running numerai-predict:%s - Python %s", os.getenv("GIT_REF"), py_version()
+    )

     if args.model.lower().startswith("http"):
         truncated_url = args.model.split("?")[0]
-        logging.info(f"Downloading model {truncated_url}")
-        response = retry_request_with_backoff(args.model)
+        logging.info("Downloading model %s", truncated_url)
+        response = retry_request_with_backoff(
+            lambda: requests.get(args.model, stream=True, allow_redirects=True)
+        )
         model_name = truncated_url.split("/")[-1]
         model_pkl = os.path.join(args.output_dir, model_name)
-        logging.info(f"Saving model to {model_pkl}")
+        logging.info("Saving model to %s", model_pkl)
         with open(model_pkl, "wb") as f:
             shutil.copyfileobj(response.raw, f)
     else:
         model_pkl = args.model

-    logging.info(f"Loading model {model_pkl}")
+    logging.info("Loading model %s", model_pkl)
     try:
         model = pd.read_pickle(model_pkl)
     except pickle.UnpicklingError as e:
-        logging.error(f"Invalid pickle - {e}")
+        logging.error("Invalid pickle - %s", e)
         if args.debug:
             logging.exception(e)
         exit_with_help(1)
     except TypeError as e:
-        logging.error(f"Pickle incompatible with {py_version()}")
-        logging.exception(e) if args.debug else logging.error(e)
+        logging.error("Pickle incompatible with %s", py_version())
+        if args.debug:
+            logging.exception(e)
         exit_with_help(1)
     except ModuleNotFoundError as e:
-        logging.error(f"Import error reading pickle - {e}")
+        logging.error("Import error reading pickle - %s", e)
         if args.debug:
             logging.exception(e)
         exit_with_help(1)
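Note: the f-string to %s conversions in this hunk are more than style. The logging module defers %-formatting until a record is actually emitted, and a formatting failure is handled by the logging machinery rather than raised at the call site. A small illustration (not part of the commit):

    import logging

    logging.basicConfig(level=logging.INFO)
    payload = list(range(1_000_000))

    logging.debug(f"payload: {payload}")   # f-string builds the huge message eagerly, then the record is dropped
    logging.debug("payload: %s", payload)  # filtered by level before any formatting happens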
@@ -182,15 +191,15 @@ def main(args):
     if num_args > 1:
         benchmark_models = get_data(args.benchmarks, args.output_dir)

-    logging.info(f"Predicting on {len(live_features)} rows of live features")
+    logging.info("Predicting on %s rows of live features", len(live_features))
     try:
         if num_args == 1:
             predictions = model(live_features)
         elif num_args == 2:
             predictions = model(live_features, benchmark_models)
         else:
             logging.error(
-                f"Invalid pickle function - {model_pkl} must have 1 or 2 arguments"
+                "Invalid pickle function - %s must have 1 or 2 arguments", model_pkl
             )
             exit_with_help(1)
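For context, the num_args dispatch above means the pickled model must be a callable of one of two shapes. The names below are hypothetical, since the runner inspects the signature, not the name:

    import pandas as pd

    # One-argument form: receives only the live features.
    def predict(live_features: pd.DataFrame) -> pd.DataFrame:
        ...

    # Two-argument form: also receives the benchmark models frame.
    def predict_with_benchmarks(
        live_features: pd.DataFrame, benchmark_models: pd.DataFrame
    ) -> pd.DataFrame:
        ...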

@@ -199,7 +208,8 @@
             exit_with_help(1)
         elif type(predictions) != pd.DataFrame:
             logging.error(
-                f"Pickle function is invalid - returned {type(predictions)} instead of pd.DataFrame"
+                "Pickle function is invalid - returned %s instead of pd.DataFrame",
+                type(predictions),
             )
             exit_with_help(1)
         elif len(predictions) == 0:
@@ -214,48 +224,38 @@
             )
             exit_with_help(1)
     except TypeError as e:
-        logging.error(f"Pickle function is invalid - {e}")
+        logging.error("Pickle function is invalid - %s", e)
         if args.debug:
             logging.exception(e)
         exit_with_help(1)
     except Exception as e:
         logging.exception(e)
         exit_with_help(1)

-    logging.info(f"Generated {len(predictions)} predictions")
+    logging.info("Generated %s predictions", len(predictions))
     logging.debug(predictions)

-    predictions_csv = os.path.join(
+    predictions_csv_file_name = os.path.join(
         args.output_dir, f"live_predictions-{secrets.token_hex(6)}.csv"
     )
-    logging.info(f"Saving predictions to {predictions_csv}")
-    with open(predictions_csv, "w") as f:
-        predictions.to_csv(f)
-
     if args.post_url:
-        logging.info(f"Uploading predictions to {args.post_url}")
-        files = {"file": open(predictions_csv, "rb")}
-
-        MAX_RETRIES = 5
-        RETRY_DELAY = 1.5
-        RETRY_EXP = 1.5
-        for i in range(MAX_RETRIES):
-            try:
-                r = requests.post(args.post_url, data=args.post_data, files=files)
-                if r.status_code >= 500:
-                    logging.info("Encountered S3 Server Error.")
-                elif r.status_code not in [200, 204]:
-                    logging.info(f"HTTP Response Status: {r.status_code}")
-                    logging.error(r.reason)
-                    logging.error(r.text)
-                else:
-                    sys.exit(0)
-            except requests.exceptions.SSLError as e:
-                logging.error(f"SSL Error: {e}")
-            finally:
-                logging.info(f"Retrying in {RETRY_DELAY} seconds...")
-                time.sleep(RETRY_DELAY)
-                RETRY_DELAY **= random.uniform(1, RETRY_EXP)
+        logging.info("Uploading predictions to %s", args.post_url)
+        csv_buffer = io.StringIO()
+        predictions.to_csv(csv_buffer)
+        retry_request_with_backoff(
+            lambda: (
+                csv_buffer.seek(0)
+                or requests.post(
+                    args.post_url,
+                    data=args.post_data,
+                    files={"file": (predictions_csv_file_name, csv_buffer, "text/csv")},
+                )
+            )
+        )
+    else:
+        logging.info("Saving predictions to %s", predictions_csv_file_name)
+        with open(predictions_csv_file_name, "w") as f:
+            predictions.to_csv(f)


 if __name__ == "__main__":
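Note: the upload path now posts from an in-memory buffer through the shared retry helper. requests reads the file object to EOF on each attempt, so every retry has to rewind the buffer first; seek(0) returns 0, which lets the lambda fall through to requests.post via "or". A standalone sketch of the same pattern, with a placeholder endpoint and sample data that are not from this commit:

    import io
    import requests

    csv_buffer = io.StringIO("id,prediction\nabc123,0.5\n")

    def post_predictions():
        csv_buffer.seek(0)  # rewind so a retry re-sends the full body
        return requests.post(
            "https://example.com/upload",  # placeholder endpoint
            files={"file": ("predictions.csv", csv_buffer, "text/csv")},
        )

    response = retry_request_with_backoff(post_predictions, retries=3)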
