Skip to content

Commit

Permalink
Set environment variables for PyTorch before initializing with DDP.
Browse files Browse the repository at this point in the history
  • Loading branch information
priyakasimbeg committed Feb 4, 2025
1 parent f6ca2bc commit b4ed6cc
Showing 1 changed file with 13 additions and 15 deletions.
28 changes: 13 additions & 15 deletions submission_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,11 +22,6 @@
import json
import os

os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'expandable_segments:True'
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "2" # Disables tensorRT, cuda warnings.
# disable only for deepspeech if it works fine for other workloads.
os.environ['XLA_FLAGS'] = '--xla_gpu_enable_triton_gemm=false'

import struct
import time
from types import MappingProxyType
Expand Down Expand Up @@ -56,6 +51,11 @@
from algorithmic_efficiency.pytorch_utils import sync_ddp_time
from algorithmic_efficiency.workloads import workloads

# Environment variables
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "2" # Disables tensorRT, cuda warnings.
# TODO: apply the flag below only for deepspeech if other workloads work fine without it.
os.environ['XLA_FLAGS'] = '--xla_gpu_enable_triton_gemm=false'

# TODO(znado): make a nicer registry of workloads that lookup in.
BASE_WORKLOADS_DIR = workloads.BASE_WORKLOADS_DIR

Expand Down Expand Up @@ -681,6 +681,14 @@ def main(_):
else:
profiler = PassThroughProfiler()

# Set PyTorch environment variables before pytorch_init() initializes DDP.
base_workload = workloads.get_base_workload_name(FLAGS.workload)
if base_workload == 'librispeech_conformer':
os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'expandable_segments:True'

if FLAGS.set_pytorch_max_split_size:
os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'max_split_size_mb:256'

if FLAGS.framework == 'pytorch':
pytorch_init(USE_PYTORCH_DDP, RANK, profiler)

Expand All @@ -692,9 +700,6 @@ def main(_):

workload_metadata = WORKLOADS[FLAGS.workload]

# Prevent OOM on librispeech conformer.
base_workload = workloads.get_base_workload_name(FLAGS.workload)

if base_workload in [
'librispeech_conformer',
'librispeech_deepspeech',
Expand All @@ -703,13 +708,6 @@ def main(_):
]:
os.environ['XLA_PYTHON_CLIENT_MEM_FRACTION'] = '0.80'

if base_workload != 'librispeech_conformer':
# Remove the environment variable (only for workloads other than librispeech conformer).
del os.environ['PYTORCH_CUDA_ALLOC_CONF']

if FLAGS.set_pytorch_max_split_size:
os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'max_split_size_mb:256'

# Extend path according to framework.
workload_metadata['workload_path'] = os.path.join(
BASE_WORKLOADS_DIR,
Expand Down

0 comments on commit b4ed6cc

Please sign in to comment.