Feature: LL nvlink p2p #173

Merged · 1 commit · May 23, 2025
36 changes: 32 additions & 4 deletions csrc/kernels/internode_ll.cu
@@ -150,7 +150,15 @@ dispatch(void* packed_recv_x, float* packed_recv_x_scales,
                       rank * num_max_dispatch_tokens_per_rank * num_bytes_per_msg +
                       slot_idx * num_bytes_per_msg;
        if (dst_rank != rank) {
-           nvshmemi_ibgda_put_nbi_warp(dst_ptr, src_ptr, num_bytes_per_msg, dst_rank, dst_expert_local_idx, lane_id, slot_idx);
+           void *peer_base_addr = (void *)__ldg((const long long unsigned *)nvshmemi_device_state_d.peer_heap_base_p2p + dst_rank);
+           if (peer_base_addr) {
+               char *req_rptr_actual = (char *)(peer_base_addr) + ((char *)dst_ptr - (char *)(nvshmemi_device_state_d.heap_base));
+               const auto* src_int4_ptr = reinterpret_cast<const int4*>(src_ptr);
+               const auto* dst_int4_ptr = reinterpret_cast<int4*>(req_rptr_actual);
+               UNROLLED_WARP_COPY(8, lane_id, num_int4_per_msg, dst_int4_ptr, src_int4_ptr, ld_nc_global, st_na_global);
+           } else {
+               nvshmemi_ibgda_put_nbi_warp(dst_ptr, src_ptr, num_bytes_per_msg, dst_rank, dst_expert_local_idx, lane_id, slot_idx);
+           }
        } else {
            // NOTES: only 2 load iterations for 7K hidden with 8 unrolls
            const auto* src_int4_ptr = reinterpret_cast<const int4*>(src_ptr);
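The fast path above hinges on a single address translation. Every NVSHMEM symmetric object sits at the same offset inside each PE's heap, and `nvshmemi_device_state_d.peer_heap_base_p2p[dst_rank]` is non-null exactly when the peer's heap is P2P-mapped into this process's address space (e.g. via CUDA IPC over NVLink). A minimal sketch of the rebasing, with a hypothetical helper name `translate_to_peer`:

```cuda
// Sketch only (helper name is ours, not from this PR): rebase a local
// symmetric-heap pointer onto a peer's P2P-mapped copy of that heap.
__device__ __forceinline__ void* translate_to_peer(void* symm_ptr,
                                                   void* local_heap_base,
                                                   void* peer_heap_base) {
    // The offset inside the symmetric heap is identical on every PE ...
    const ptrdiff_t offset =
        static_cast<char*>(symm_ptr) - static_cast<char*>(local_heap_base);
    // ... so adding it to the peer's mapped base yields an address this
    // GPU can dereference directly over NVLink, no RDMA needed.
    return static_cast<char*>(peer_heap_base) + offset;
}
```

Once rebased, the warp writes the message with plain `int4` stores instead of posting an IBGDA put.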
@@ -215,7 +223,13 @@ dispatch(void* packed_recv_x, float* packed_recv_x_scales,
        // Wait local sends issued and send expert counts
        while (ld_acquire_global(atomic_finish_counter_per_expert + responsible_expert_idx) != FINISHED_SUM_TAG * 2);
        if (dst_rank != rank) {
-           nvshmemi_ibgda_amo_nonfetch_add(rdma_recv_count + dst_expert_local_idx * num_ranks + rank, -num_tokens_sent - 1, dst_rank, dst_expert_local_idx);
+           void *peer_base_addr = (void *)__ldg((const long long unsigned *)nvshmemi_device_state_d.peer_heap_base_p2p + dst_rank);
+           if (peer_base_addr) { // P2P enabled
+               int *rptr_actual = (int *)((char *)(peer_base_addr) + ((char *)(rdma_recv_count + dst_expert_local_idx * num_ranks + rank) - (char *)(nvshmemi_device_state_d.heap_base)));
+               st_na_release(rptr_actual, -num_tokens_sent - 1);
+           } else {
+               nvshmemi_ibgda_amo_nonfetch_add(rdma_recv_count + dst_expert_local_idx * num_ranks + rank, -num_tokens_sent - 1, dst_rank, dst_expert_local_idx);
+           }
        } else {
            st_na_release(rdma_recv_count + dst_expert_local_idx * num_ranks + rank, -num_tokens_sent - 1);
        }
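Note the substitution on the signaling side: the RDMA branch uses a non-fetching atomic add, while the P2P branch issues a plain release store. This works because each `(expert, rank)` slot of `rdma_recv_count` is zeroed beforehand and written by exactly one warp, so no read-modify-write is required, and release ordering makes the copied token payloads visible before the count itself. A sketch of what a store like `st_na_release` boils down to (assumption: a system-scope PTX release store, the usual idiom; this is not copied from DeepEP's utilities):

```cuda
// Sketch of a system-scope release store: all prior global stores (the
// copied tokens) become visible to the peer before this value does.
__device__ __forceinline__ void release_store_sys(int* ptr, int value) {
    asm volatile("st.release.sys.global.s32 [%0], %1;"
                 :: "l"(ptr), "r"(value)
                 : "memory");
}
```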
@@ -421,7 +435,15 @@ combine(void* combined_x,
        const auto buf_int4_ptr = reinterpret_cast<int4*>(buf_ptr);
        if (not zero_copy)
            UNROLLED_WARP_COPY(7, lane_id, hidden_bf16_int4, buf_int4_ptr, x_int4, ld_nc_global, st_na_global);
-       nvshmemi_ibgda_put_nbi_warp(dst_ptr, buf_ptr, hidden * sizeof(nv_bfloat16), dst_rank, local_expert_idx, lane_id, token_idx - offset);
+
+       void *peer_base_addr = (void *)__ldg((const long long unsigned *)nvshmemi_device_state_d.peer_heap_base_p2p + dst_rank);
+       if (peer_base_addr) {
+           char *req_rptr_actual = (char *)(peer_base_addr) + ((char *)dst_ptr - (char *)(nvshmemi_device_state_d.heap_base));
+           const auto dst_int4_ptr = reinterpret_cast<int4*>(req_rptr_actual);
+           UNROLLED_WARP_COPY(7, lane_id, hidden_bf16_int4, dst_int4_ptr, x_int4, ld_nc_global, st_na_global);
+       } else {
+           nvshmemi_ibgda_put_nbi_warp(dst_ptr, buf_ptr, hidden * sizeof(nv_bfloat16), dst_rank, local_expert_idx, lane_id, token_idx - offset);
+       }
    }
}

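The combine path reuses the same rebasing, and when P2P is available the warp streams directly from the source tensor `x_int4` to the destination rank rather than from the staging buffer used by the RDMA branch. Stripped of the unroll factor and the `ld_nc_global`/`st_na_global` wrappers, `UNROLLED_WARP_COPY` is essentially a lane-strided warp copy; a simplified sketch:

```cuda
// Simplified sketch of the UNROLLED_WARP_COPY access pattern (the real
// macro also unrolls by a compile-time factor and takes custom
// load/store wrappers).
__device__ __forceinline__ void warp_copy(int4* dst, const int4* src,
                                          int num_int4, int lane_id) {
    // The 32 lanes of a warp cooperatively copy one message,
    // each lane taking a strided slice.
    for (int i = lane_id; i < num_int4; i += 32)
        dst[i] = src[i];
}
```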
@@ -431,7 +453,13 @@ combine(void* combined_x,
    if (sub_warp_id == 1 and lane_id == 0) {
        while (ld_acquire_global(atomic_clean_flag) == 0);
        if (dst_rank != rank) {
-           nvshmemi_ibgda_amo_nonfetch_add(rdma_recv_flag + global_expert_idx, 1, dst_rank, local_expert_idx);
+           void *peer_base_addr = (void *)__ldg((const long long unsigned *)nvshmemi_device_state_d.peer_heap_base_p2p + dst_rank);
+           if (peer_base_addr) {
+               int *req_rptr_actual = (int *)((char *)(peer_base_addr) + ((char *)(rdma_recv_flag + global_expert_idx) - (char *)(nvshmemi_device_state_d.heap_base)));
+               st_na_release(req_rptr_actual, 1);
+           } else {
+               nvshmemi_ibgda_amo_nonfetch_add(rdma_recv_flag + global_expert_idx, 1, dst_rank, local_expert_idx);
+           }
        } else {
            st_na_release(rdma_recv_flag + global_expert_idx, 1);
        }
3 changes: 2 additions & 1 deletion deep_ep/buffer.py
@@ -68,7 +68,8 @@ def __init__(self, group: dist.ProcessGroup,
        if self.runtime.get_num_rdma_ranks() > 1 or low_latency_mode:
            # Enable IBGDA
            assert num_qps_per_rank > 0
-           os.environ['NVSHMEM_DISABLE_P2P'] = '1'
+           if not os.getenv("NVSHMEM_DISABLE_P2P"):
+               os.environ['NVSHMEM_DISABLE_P2P'] = '1'
            os.environ['NVSHMEM_IB_ENABLE_IBGDA'] = '1'
            os.environ['NVSHMEM_IBGDA_NIC_HANDLER'] = 'gpu'
            os.environ['NVSHMEM_IBGDA_NUM_RC_PER_PE'] = f'{num_qps_per_rank}'
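With this change the Python side no longer unconditionally forces `NVSHMEM_DISABLE_P2P=1` in IBGDA mode: a value already present in the environment is respected. Exporting `NVSHMEM_DISABLE_P2P=0` before constructing the buffer therefore lets NVSHMEM keep its P2P heap mappings, which is what makes `peer_heap_base_p2p` non-null and activates the NVLink fast path in the kernels above; leaving the variable unset preserves the previous RDMA-only behavior.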