diff --git a/tests/test_distributed.py b/tests/test_distributed.py index aa884e50..3e7cdc2f 100644 --- a/tests/test_distributed.py +++ b/tests/test_distributed.py @@ -18,7 +18,7 @@ def _test_init_parallel_context(parallel_context: ParallelContext): world_rank = dist.get_rank(parallel_context.world_pg) ranks3d = parallel_context.get_3d_ranks(world_rank) - assert type(ranks3d) and len(ranks3d) + assert isinstance(ranks3d, tuple) and len(ranks3d) assert isinstance(parallel_context.world_rank_matrix, np.ndarray) assert isinstance(parallel_context.world_ranks_to_pg, dict)