This repository has been archived by the owner on Nov 25, 2024. It is now read-only.

Commit

quick fix to a map_indice bug && add comment for parameter round_robin_size
linhu-nv committed May 27, 2024
1 parent 7352f1c commit cbf5db7
Showing 4 changed files with 9 additions and 6 deletions.
8 changes: 4 additions & 4 deletions cpp/src/wholememory/file_io.cpp
@@ -97,7 +97,7 @@ static size_t get_handle_partial_size(size_t handle_size,
* @param suggested_buffer_size : Suggested buffer size to read.
* @param wm_rank : WholeMemory rank.
* @param wm_world_size : WholeMemory world size.
* @param round_robin_size : continuous embedding size of a rank using round robin shard stratehy.
* @param round_robin_size : continuous embedding size of a rank using round robin shard strategy.
*/
static void read_file_list_to_local_memory_roundrobin(char* local_ptr,
size_t local_size,
@@ -407,7 +407,7 @@ static void read_file_list_to_local_memory(char* local_ptr,
* @param suggested_buffer_size : Suggested buffer size to read.
* @param wm_rank : WholeMemory rank.
* @param wm_world_size : WholeMemory world size.
* @param round_robin_size : continuous embedding size of a rank using round robin shard stratehy.
* @param round_robin_size : continuous embedding size of a rank using round robin shard strategy.
* @param dev_id : the device bound to the rank.
*/
static void read_file_list_to_local_memory_roundrobin_with_multi_threads(
@@ -878,7 +878,7 @@ static void read_file_list_to_local_memory_with_multi_threads(char* local_ptr,
* @param suggested_buffer_size : Suggested buffer size to read.
* @param wm_rank : WholeMemory rank.
* @param wm_world_size : WholeMemory world size.
* @param round_robin_size : continuous embedding size of a rank using round robin shard stratehy.
* @param round_robin_size : continuous embedding size of a rank using round robin shard strategy.
*/
static void read_file_list_to_local_memory_roundrobin_directio(
char* local_ptr,
@@ -1546,7 +1546,7 @@ static void read_file_list_to_local_memory_directio_with_multi_thread(
* @param suggested_buffer_size : Suggested buffer size to read.
* @param wm_rank : WholeMemory rank.
* @param wm_world_size : WholeMemory world size.
* @param round_robin_size : continuous embedding size of a rank using round robin shard stratehy.
* @param round_robin_size : continuous embedding size of a rank using round robin shard strategy.
* @param dev_id : the device bound to the rank.
*/
static void read_file_list_to_local_memory_roundrobin_directio_with_multi_threads(
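All four comment fixes in this file touch the same round_robin_size parameter. To make the corrected wording concrete, here is a minimal host-side sketch of what a block-round-robin shard mapping could look like, assuming entries are dealt out to ranks in contiguous blocks of round_robin_size entries; the helper name and the exact mapping are illustrative assumptions, not the wholememory API.

// Hypothetical sketch of round-robin sharding, NOT the wholememory API.
// Assumption: with round_robin_size > 0, embedding entries are dealt out
// in contiguous blocks of round_robin_size entries each, so block
// b = entry_id / round_robin_size lands on rank b % wm_world_size.
#include <cstdint>
#include <cstdio>

struct ShardLocation {
  int rank;          // rank that stores the entry
  int64_t local_id;  // offset inside that rank's local shard
};

static ShardLocation locate_entry(int64_t entry_id,
                                  int wm_world_size,
                                  int64_t round_robin_size) {
  int64_t block    = entry_id / round_robin_size;  // which block the entry is in
  int64_t in_block = entry_id % round_robin_size;  // offset inside the block
  ShardLocation loc;
  loc.rank     = static_cast<int>(block % wm_world_size);  // owning rank
  loc.local_id = (block / wm_world_size) * round_robin_size + in_block;
  return loc;
}

int main() {
  // Example: 2 ranks, blocks of 4 entries. Entries 0-3 land on rank 0,
  // 4-7 on rank 1, 8-11 on rank 0 again, and so on.
  for (int64_t id = 0; id < 12; ++id) {
    ShardLocation loc = locate_entry(id, /*wm_world_size=*/2, /*round_robin_size=*/4);
    printf("entry %2ld -> rank %d, local %ld\n", (long)id, loc.rank, (long)loc.local_id);
  }
  return 0;
}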
2 changes: 1 addition & 1 deletion cpp/src/wholememory_ops/functions/map_indices_func.cu
@@ -58,7 +58,7 @@ void storage_idx2wm_emb_idx_temp_fn(void* indice_ptr,
if (block_num > 1568) block_num = 1568;
IndexT* indice = static_cast<IndexT*>(indice_ptr);
IndexT* mapped_indice = static_cast<IndexT*>(mapped_indice_ptr);
storage_idx2wm_emb_idx_kernel<<<block_num, block_size>>>(
storage_idx2wm_emb_idx_kernel<<<block_num, block_size, 0, stream>>>(
indice, mapped_indice, indice_size, world_size, entry_per_rank, round_robin_size);
WM_CUDA_CHECK(cudaStreamSynchronize(stream));
return;
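This one-line change is the "map_indice bug" from the commit title: the launch omitted the stream argument, so the kernel ran on the default stream while the very next line synchronized stream, and cudaStreamSynchronize(stream) does not wait for default-stream work. A minimal standalone sketch of the corrected pattern, using a toy kernel in place of the real storage_idx2wm_emb_idx_kernel:

// Toy reproduction of the fixed launch pattern, not wholegraph's kernel.
#include <cuda_runtime.h>
#include <cstdio>

__global__ void add_one_kernel(int* data, int n) {
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n) data[i] += 1;
}

int main() {
  const int n = 1 << 20;
  int* d_data = nullptr;
  cudaMalloc(&d_data, n * sizeof(int));
  cudaStream_t stream;
  cudaStreamCreate(&stream);
  cudaMemsetAsync(d_data, 0, n * sizeof(int), stream);

  int block_size = 256;
  int block_num  = (n + block_size - 1) / block_size;

  // Buggy form: add_one_kernel<<<block_num, block_size>>>(d_data, n);
  // launches on the default stream, so cudaStreamSynchronize(stream) below
  // would not wait for it and the host could observe incomplete results.
  // Fixed form (as in the commit): pass the dynamic shared-memory size (0)
  // and the stream explicitly, ordering the kernel with the rest of the
  // work queued on `stream`.
  add_one_kernel<<<block_num, block_size, 0, stream>>>(d_data, n);

  cudaStreamSynchronize(stream);  // now covers the kernel as well
  cudaFree(d_data);
  cudaStreamDestroy(stream);
  printf("done\n");
  return 0;
}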
4 changes: 3 additions & 1 deletion python/pylibwholegraph/pylibwholegraph/torch/embedding.py
@@ -407,7 +407,7 @@ def create_embedding(
cache_policy: Union[WholeMemoryCachePolicy, None] = None,
random_init: bool = False,
gather_sms: int = -1,
round_robin_size=0,
round_robin_size: int = 0,
):
r"""
Create embedding
@@ -419,6 +419,7 @@ def create_embedding(
:param optimizer: optimizer
:param cache_policy: cache policy
:param gather_sms: the number of SMs used in gather process
:param round_robin_size: continuous embedding size of a rank using round robin shard strategy
:return: WholeMemoryEmbedding
"""
if optimizer is None:
@@ -491,6 +492,7 @@ def create_embedding_from_filelist(
:param optimizer: optimizer
:param cache_policy: cache policy
:param gather_sms: the number of SMs used in gather process
:param round_robin_size: continuous embedding size of a rank using round robin shard strategy
:return:
"""
if isinstance(filelist, str):
1 change: 1 addition & 0 deletions python/pylibwholegraph/pylibwholegraph/torch/tensor.py
@@ -156,6 +156,7 @@ def from_filelist(self, filelist: Union[List[str], str], round_robin_size: int =
"""
Load WholeMemory Tensor from file lists
:param filelist: file list to load from
:param round_robin_size: continuous embedding size of a rank using round robin shard strategy
:return: None
"""
if isinstance(filelist, str):
