diff --git a/cpp/src/wholememory/file_io.cpp b/cpp/src/wholememory/file_io.cpp
index 1ad7e85aa..31b87c144 100644
--- a/cpp/src/wholememory/file_io.cpp
+++ b/cpp/src/wholememory/file_io.cpp
@@ -97,7 +97,7 @@ static size_t get_handle_partial_size(size_t handle_size,
  * @param suggested_buffer_size : Suggested buffer size to read.
  * @param wm_rank : WholeMemory rank.
  * @param wm_world_size : WholeMemory world size.
- * @param round_robin_size : continuous embedding size of a rank using round robin shard stratehy.
+ * @param round_robin_size : continuous embedding size of a rank using round robin shard strategy.
  */
 static void read_file_list_to_local_memory_roundrobin(char* local_ptr,
                                                        size_t local_size,
@@ -407,7 +407,7 @@ static void read_file_list_to_local_memory(char* local_ptr,
  * @param suggested_buffer_size : Suggested buffer size to read.
  * @param wm_rank : WholeMemory rank.
  * @param wm_world_size : WholeMemory world size.
- * @param round_robin_size : continuous embedding size of a rank using round robin shard stratehy.
+ * @param round_robin_size : continuous embedding size of a rank using round robin shard strategy.
  * @param dev_id : the device bound to the rank.
  */
 static void read_file_list_to_local_memory_roundrobin_with_multi_threads(
@@ -878,7 +878,7 @@ static void read_file_list_to_local_memory_with_multi_threads(char* local_ptr,
  * @param suggested_buffer_size : Suggested buffer size to read.
  * @param wm_rank : WholeMemory rank.
  * @param wm_world_size : WholeMemory world size.
- * @param round_robin_size : continuous embedding size of a rank using round robin shard stratehy.
+ * @param round_robin_size : continuous embedding size of a rank using round robin shard strategy.
  */
 static void read_file_list_to_local_memory_roundrobin_directio(
   char* local_ptr,
@@ -1546,7 +1546,7 @@ static void read_file_list_to_local_memory_directio_with_multi_thread(
  * @param suggested_buffer_size : Suggested buffer size to read.
  * @param wm_rank : WholeMemory rank.
  * @param wm_world_size : WholeMemory world size.
- * @param round_robin_size : continuous embedding size of a rank using round robin shard stratehy.
+ * @param round_robin_size : continuous embedding size of a rank using round robin shard strategy.
  * @param dev_id : the device bound to the rank.
  */
 static void read_file_list_to_local_memory_roundrobin_directio_with_multi_threads(
diff --git a/cpp/src/wholememory_ops/functions/map_indices_func.cu b/cpp/src/wholememory_ops/functions/map_indices_func.cu
index 97d6ca868..1a1418179 100644
--- a/cpp/src/wholememory_ops/functions/map_indices_func.cu
+++ b/cpp/src/wholememory_ops/functions/map_indices_func.cu
@@ -58,7 +58,7 @@ void storage_idx2wm_emb_idx_temp_fn(void* indice_ptr,
   if (block_num > 1568) block_num = 1568;
   IndexT* indice        = static_cast<IndexT*>(indice_ptr);
   IndexT* mapped_indice = static_cast<IndexT*>(mapped_indice_ptr);
-  storage_idx2wm_emb_idx_kernel<<>>(
+  storage_idx2wm_emb_idx_kernel<<>>(
     indice, mapped_indice, indice_size, world_size, entry_per_rank, round_robin_size);
   WM_CUDA_CHECK(cudaStreamSynchronize(stream));
   return;
diff --git a/python/pylibwholegraph/pylibwholegraph/torch/embedding.py b/python/pylibwholegraph/pylibwholegraph/torch/embedding.py
index 8ad83bd77..8abc92be9 100644
--- a/python/pylibwholegraph/pylibwholegraph/torch/embedding.py
+++ b/python/pylibwholegraph/pylibwholegraph/torch/embedding.py
@@ -407,7 +407,7 @@ def create_embedding(
     cache_policy: Union[WholeMemoryCachePolicy, None] = None,
     random_init: bool = False,
     gather_sms: int = -1,
-    round_robin_size=0,
+    round_robin_size: int = 0,
 ):
     r"""
     Create embedding
@@ -419,6 +419,7 @@ def create_embedding(
     :param optimizer: optimizer
     :param cache_policy: cache policy
     :param gather_sms: the number of SMs used in gather process
+    :param round_robin_size: continuous embedding size of a rank using round robin shard strategy
     :return: WholeMemoryEmbedding
     """
     if optimizer is None:
@@ -491,6 +492,7 @@ def create_embedding_from_filelist(
     :param optimizer: optimizer
     :param cache_policy: cache policy
     :param gather_sms: the number of SMs used in gather process
+    :param round_robin_size: continuous embedding size of a rank using round robin shard strategy
     :return:
     """
     if isinstance(filelist, str):
diff --git a/python/pylibwholegraph/pylibwholegraph/torch/tensor.py b/python/pylibwholegraph/pylibwholegraph/torch/tensor.py
index ee62e9964..84ee59eee 100644
--- a/python/pylibwholegraph/pylibwholegraph/torch/tensor.py
+++ b/python/pylibwholegraph/pylibwholegraph/torch/tensor.py
@@ -67,7 +67,7 @@ def gather(self,
         embedding_count = indice.shape[0]
         current_cuda_device = "cuda:%d" % (torch.cuda.current_device(),)
         output_dtype = (
-            force_dtype if force_dtype is not None else self.embedding_tensor.dtype
+            force_dtype if force_dtype is not None else self.dtype
         )
         output_tensor = torch.empty(
             [embedding_count, embedding_dim],
@@ -156,6 +156,7 @@ def from_filelist(self, filelist: Union[List[str], str], round_robin_size: int =
         """
         Load WholeMemory Tensor from file lists
         :param filelist: file list to load from
+        :param round_robin_size: continuous embedding size of a rank using round robin shard strategy
        :return: None
         """
         if isinstance(filelist, str):
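For reference, a minimal sketch of how the `round_robin_size` keyword added above might be used when creating an embedding. Only the keyword arguments that appear in this diff (`cache_policy`, `gather_sms`, `round_robin_size`) are taken from the change; the communicator helper and the positional arguments (memory type, location, dtype, shape) are assumptions about the surrounding pylibwholegraph API, not part of this patch:

```python
# Hypothetical usage sketch; everything outside the keyword arguments shown in
# the diff is an assumption about the surrounding pylibwholegraph API.
import torch
import pylibwholegraph.torch as wgth

global_comm = wgth.get_global_communicator()  # assumed helper for the global communicator

# With round_robin_size=256, each rank would hold contiguous blocks of 256
# embedding rows, with blocks dealt out to ranks in round-robin order.
node_feat = wgth.create_embedding(
    global_comm,
    "distributed",     # memory type (assumed positional argument)
    "cpu",             # memory location (assumed positional argument)
    torch.float32,     # embedding dtype (assumed positional argument)
    [1_000_000, 128],  # [row count, embedding dim] (assumed positional argument)
    cache_policy=None,
    gather_sms=-1,
    round_robin_size=256,
)
```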