From b7836dd405da07cb1d0ef57d9fd72e47094d95ab Mon Sep 17 00:00:00 2001 From: gejielun <18243586937@163.com> Date: Mon, 13 Feb 2023 15:14:06 +0800 Subject: [PATCH] fix(data_join): fix batch id error --- fedlearner/trainer/data_visitor.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/fedlearner/trainer/data_visitor.py b/fedlearner/trainer/data_visitor.py index 77edffeac..8df7845cb 100644 --- a/fedlearner/trainer/data_visitor.py +++ b/fedlearner/trainer/data_visitor.py @@ -344,10 +344,13 @@ def __init__(self, datablocks = [] for dirname, _, filenames in tf.io.gfile.walk(data_path): + base_path = data_path + if os.path.basename(data_path) != 'batch': + base_path = os.path.dirname(data_path) for filename in filenames: if not fnmatch(os.path.join(dirname, filename), wildcard): continue - subdirname = os.path.relpath(dirname, data_path) + subdirname = os.path.relpath(dirname, base_path) block_id = os.path.join(subdirname, filename) datablock = _RawDataBlock( id=block_id, data_path=os.path.join(dirname, filename), @@ -359,11 +362,14 @@ def __init__(self, local_data_path) local_datablocks = [] if local_data_path and tf.io.gfile.exists(local_data_path): + local_base_path = local_data_path + if os.path.basename(local_data_path) != 'batch': + local_base_path = os.path.dirname(local_data_path) for dirname, _, filenames in tf.io.gfile.walk(local_data_path): for filename in filenames: if not fnmatch(os.path.join(dirname, filename), wildcard): continue - subdirname = os.path.relpath(dirname, local_data_path) + subdirname = os.path.relpath(dirname, local_base_path) block_id = os.path.join(subdirname, filename) datablock = _RawDataBlock( id=block_id, data_path=os.path.join(dirname, filename),