
Low-latency P2P code cleanup and bug fixes #174


Merged (1 commit) on May 23, 2025
3 changes: 2 additions & 1 deletion csrc/deep_ep.cpp
@@ -1190,8 +1190,9 @@ Buffer::low_latency_combine(const torch::Tensor& x, const torch::Tensor& topk_id
}

torch::Tensor
Buffer::get_next_low_latency_combine_buffer(int num_max_dispatch_tokens_per_rank, int hidden, int num_experts) {
Buffer::get_next_low_latency_combine_buffer(int num_max_dispatch_tokens_per_rank, int hidden, int num_experts) const {
LowLatencyLayout layout(rdma_buffer_ptr, num_max_dispatch_tokens_per_rank, hidden, num_ranks, num_experts);

auto buffer = layout.buffers[low_latency_buffer_idx];
auto dtype = torch::kBFloat16;
auto num_msg_elems = static_cast<int>(buffer.num_bytes_per_combine_msg / elementSize(torch::kBFloat16));
2 changes: 1 addition & 1 deletion csrc/deep_ep.hpp
@@ -147,7 +147,7 @@ struct Buffer {
const std::optional<torch::Tensor>& out = std::nullopt);

torch::Tensor
get_next_low_latency_combine_buffer(int num_max_dispatch_tokens_per_rank, int hidden, int num_experts);
get_next_low_latency_combine_buffer(int num_max_dispatch_tokens_per_rank, int hidden, int num_experts) const;
};

} // namespace deep_ep
14 changes: 14 additions & 0 deletions csrc/kernels/ibgda_device.cuh
@@ -432,4 +432,18 @@ __device__ __forceinline__ void nvshmemi_ibgda_amo_nonfetch_add(void *rptr, cons
}
}

__device__ __forceinline__ uint64_t nvshmemi_get_p2p_ptr(const uint64_t& ptr, const int& rank, const int& dst_rank) {
// Local rank, no need for mapping
if (rank == dst_rank)
return ptr;
auto peer_base = __ldg(reinterpret_cast<uint64_t*>(nvshmemi_device_state_d.peer_heap_base_p2p) + dst_rank);

// RDMA connected only, no NVLink P2P mapping for this peer
if (peer_base == 0)
return 0;

// NVLink P2P is enabled
return peer_base + (ptr - reinterpret_cast<uint64_t>(nvshmemi_device_state_d.heap_base));
}

} // namespace deep_ep
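
The helper above returns 0 when the destination rank has no NVLink P2P mapping into the local symmetric heap, so callers translate the destination pointer once and then branch between a direct NVLink store and the IBGDA RDMA fallback. A minimal sketch of that caller pattern follows; the wrapper name signal_peer and its argument list are illustrative only (not part of this PR), and it reuses nvshmemi_ibgda_amo_nonfetch_add and st_release_sys_global exactly as the kernels below do.

// Illustrative sketch, not code added by this PR: the intended usage pattern
// for nvshmemi_get_p2p_ptr. The wrapper name `signal_peer` is hypothetical.
__device__ __forceinline__ void signal_peer(int* symmetric_flag, int value,
                                            int rank, int dst_rank,
                                            int dst_expert_local_idx) {
    const auto dst_ptr = reinterpret_cast<uint64_t>(symmetric_flag);
    const auto dst_p2p_ptr = nvshmemi_get_p2p_ptr(dst_ptr, rank, dst_rank);
    if (dst_p2p_ptr == 0) {
        // No NVLink P2P mapping for this peer: fall back to an IBGDA RDMA atomic
        nvshmemi_ibgda_amo_nonfetch_add(reinterpret_cast<int*>(dst_ptr), value,
                                        dst_rank, dst_expert_local_idx);
    } else {
        // NVLink P2P mapping available: a system-scope release store suffices
        st_release_sys_global(reinterpret_cast<int*>(dst_p2p_ptr), value);
    }
}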
64 changes: 21 additions & 43 deletions csrc/kernels/internode_ll.cu
@@ -149,20 +149,13 @@ dispatch(void* packed_recv_x, float* packed_recv_x_scales,
dst_expert_local_idx * num_ranks * num_max_dispatch_tokens_per_rank * num_bytes_per_msg +
rank * num_max_dispatch_tokens_per_rank * num_bytes_per_msg +
slot_idx * num_bytes_per_msg;
if (dst_rank != rank) {
void *peer_base_addr = (void *)__ldg((const long long unsigned *)nvshmemi_device_state_d.peer_heap_base_p2p + dst_rank);
if (peer_base_addr) {
char *req_rptr_actual = (char *)(peer_base_addr) + ((char *)dst_ptr - (char *)(nvshmemi_device_state_d.heap_base));
const auto* src_int4_ptr = reinterpret_cast<const int4*>(src_ptr);
const auto* dst_int4_ptr = reinterpret_cast<int4*>(req_rptr_actual);
UNROLLED_WARP_COPY(8, lane_id, num_int4_per_msg, dst_int4_ptr, src_int4_ptr, ld_nc_global, st_na_global);
} else {
nvshmemi_ibgda_put_nbi_warp(dst_ptr, src_ptr, num_bytes_per_msg, dst_rank, dst_expert_local_idx, lane_id, slot_idx);
}
const auto dst_p2p_ptr = nvshmemi_get_p2p_ptr(dst_ptr, rank, dst_rank);
if (dst_p2p_ptr == 0) {
nvshmemi_ibgda_put_nbi_warp(dst_ptr, src_ptr, num_bytes_per_msg, dst_rank, dst_expert_local_idx, lane_id, slot_idx);
} else {
// NOTES: only 2 load iterations for 7K hidden with 8 unrolls
const auto* src_int4_ptr = reinterpret_cast<const int4*>(src_ptr);
const auto* dst_int4_ptr = reinterpret_cast<int4*>(dst_ptr);
const auto* dst_int4_ptr = reinterpret_cast<int4*>(dst_p2p_ptr);
UNROLLED_WARP_COPY(8, lane_id, num_int4_per_msg, dst_int4_ptr, src_int4_ptr, ld_nc_global, st_na_global);
}

@@ -222,16 +215,12 @@ dispatch(void* packed_recv_x, float* packed_recv_x_scales,

// Wait local sends issued and send expert counts
while (ld_acquire_global(atomic_finish_counter_per_expert + responsible_expert_idx) != FINISHED_SUM_TAG * 2);
if (dst_rank != rank) {
void *peer_base_addr = (void *)__ldg((const long long unsigned *)nvshmemi_device_state_d.peer_heap_base_p2p + dst_rank);
if (peer_base_addr) { // P2P enabled
int *rptr_actual = (int *)((char *)(peer_base_addr) + ((char *)(rdma_recv_count + dst_expert_local_idx * num_ranks + rank) - (char *)(nvshmemi_device_state_d.heap_base)));
st_na_release(rptr_actual, -num_tokens_sent - 1);
} else {
nvshmemi_ibgda_amo_nonfetch_add(rdma_recv_count + dst_expert_local_idx * num_ranks + rank, -num_tokens_sent - 1, dst_rank, dst_expert_local_idx);
}
auto dst_ptr = reinterpret_cast<uint64_t>(rdma_recv_count + dst_expert_local_idx * num_ranks + rank);
auto dst_p2p_ptr = nvshmemi_get_p2p_ptr(dst_ptr, rank, dst_rank);
if (dst_p2p_ptr == 0) {
nvshmemi_ibgda_amo_nonfetch_add(reinterpret_cast<int*>(dst_ptr), -num_tokens_sent - 1, dst_rank, dst_expert_local_idx);
} else {
st_na_release(rdma_recv_count + dst_expert_local_idx * num_ranks + rank, -num_tokens_sent - 1);
st_release_sys_global(reinterpret_cast<int*>(dst_p2p_ptr), -num_tokens_sent - 1);
}

// Clean workspace for next use
@@ -428,22 +417,15 @@ combine(void* combined_x,
auto src_idx = __ldg(local_src_info + token_idx);
const auto buf_ptr = reinterpret_cast<int64_t>(rdma_send_x_vec_row);
const auto dst_ptr = reinterpret_cast<uint64_t>(rdma_recv_x) + (global_expert_idx * num_max_dispatch_tokens_per_rank + src_idx) * num_bytes_per_slot;
if (dst_rank == rank) {
const auto dst_int4_ptr = reinterpret_cast<int4*>(dst_ptr);
UNROLLED_WARP_COPY(7, lane_id, hidden_bf16_int4, dst_int4_ptr, x_int4, ld_nc_global, st_na_global);
} else {
const auto dst_p2p_ptr = nvshmemi_get_p2p_ptr(dst_ptr, rank, dst_rank);
if (dst_p2p_ptr == 0) {
const auto buf_int4_ptr = reinterpret_cast<int4*>(buf_ptr);
if (not zero_copy)
UNROLLED_WARP_COPY(7, lane_id, hidden_bf16_int4, buf_int4_ptr, x_int4, ld_nc_global, st_na_global);

void *peer_base_addr = (void *)__ldg((const long long unsigned *)nvshmemi_device_state_d.peer_heap_base_p2p + dst_rank);
if (peer_base_addr) {
char *req_rptr_actual = (char *)(peer_base_addr) + ((char *)dst_ptr - (char *)(nvshmemi_device_state_d.heap_base));
const auto dst_int4_ptr = reinterpret_cast<int4*>(req_rptr_actual);
UNROLLED_WARP_COPY(7, lane_id, hidden_bf16_int4, dst_int4_ptr, x_int4, ld_nc_global, st_na_global);
} else {
nvshmemi_ibgda_put_nbi_warp(dst_ptr, buf_ptr, hidden * sizeof(nv_bfloat16), dst_rank, local_expert_idx, lane_id, token_idx - offset);
}
nvshmemi_ibgda_put_nbi_warp(dst_ptr, buf_ptr, hidden * sizeof(nv_bfloat16), dst_rank, local_expert_idx, lane_id, token_idx - offset);
} else {
const auto dst_int4_ptr = reinterpret_cast<int4*>(dst_p2p_ptr);
UNROLLED_WARP_COPY(7, lane_id, hidden_bf16_int4, dst_int4_ptr, x_int4, ld_nc_global, st_na_global);
}
}

@@ -452,16 +434,12 @@
asm volatile("bar.sync %0, %1;" :: "r"(warp_group_id + 1), "r"(kNumWarpsPerGroup * 32));
if (sub_warp_id == 1 and lane_id == 0) {
while (ld_acquire_global(atomic_clean_flag) == 0);
if (dst_rank != rank) {
void *peer_base_addr = (void *)__ldg((const long long unsigned *)nvshmemi_device_state_d.peer_heap_base_p2p + dst_rank);
if (peer_base_addr) {
int *req_rptr_actual = (int *)((char *)(peer_base_addr) + ((char *)(rdma_recv_flag + global_expert_idx) - (char *)(nvshmemi_device_state_d.heap_base)));
st_na_release(req_rptr_actual, 1);
} else {
nvshmemi_ibgda_amo_nonfetch_add(rdma_recv_flag + global_expert_idx, 1, dst_rank, local_expert_idx);
}
auto dst_ptr = reinterpret_cast<uint64_t>(rdma_recv_flag + global_expert_idx);
auto dst_p2p_ptr = nvshmemi_get_p2p_ptr(dst_ptr, rank, dst_rank);
if (dst_p2p_ptr == 0) {
nvshmemi_ibgda_amo_nonfetch_add(reinterpret_cast<int*>(dst_ptr), 1, dst_rank, local_expert_idx);
} else {
st_na_release(rdma_recv_flag + global_expert_idx, 1);
st_release_sys_global(reinterpret_cast<int*>(dst_p2p_ptr), 1);
}
atomic_add_release_global(atomic_clean_flag, -1);
}
@@ -473,7 +451,7 @@ combine(void* combined_x,
if ((phases & LOW_LATENCY_RECV_PHASE) == 0)
return;

// Wait all ranks to arrive and notify PCIe usage
// Wait all ranks to arrive
if (responsible_expert_idx < num_experts) {
EP_STATIC_ASSERT(kNumWarpsPerGroup > 1, "Invalid number of warps per group");
if (sub_warp_id == 0 and lane_id == 0)
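
A note on the flag stores in the hunks above: on the NVLink path the payload is written with ordinary stores and then published through a flag, and this PR switches the flag store from st_na_release to st_release_sys_global. Judging by the names, the intent is a release store at system scope, so the peer GPU observes the copied payload no later than the flag (this also matches the PCIe memory-ordering warning added to buffer.py below). A rough sketch of such a release/acquire pair in inline PTX follows; these are illustrative definitions assuming sm_70 or newer, not the repository's actual helpers.

// Illustrative only: a plausible system-scope release store / acquire load pair
// showing the ordering idiom the P2P flag path relies on.
__device__ __forceinline__ void release_store_sys(int* ptr, int value) {
    // All prior global writes (the copied payload) become visible to a peer GPU
    // no later than this flag value.
    asm volatile("st.release.sys.global.s32 [%0], %1;" :: "l"(ptr), "r"(value) : "memory");
}

__device__ __forceinline__ int acquire_load_sys(const int* ptr) {
    int value;
    // Pairs with the release store: once the flag is observed, the payload written
    // before it on the sending GPU is visible here as well.
    asm volatile("ld.acquire.sys.global.s32 %0, [%1];" : "=r"(value) : "l"(ptr) : "memory");
    return value;
}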
10 changes: 7 additions & 3 deletions deep_ep/buffer.py
@@ -31,7 +31,8 @@ class Buffer:

def __init__(self, group: dist.ProcessGroup,
num_nvl_bytes: int = 0, num_rdma_bytes: int = 0,
low_latency_mode: bool = False, num_qps_per_rank: int = 12) -> None:
low_latency_mode: bool = False, num_qps_per_rank: int = 12,
allow_nvlink_for_low_latency_mode: bool = False) -> None:
"""
Initialize the communication buffer.

@@ -42,6 +43,10 @@ def __init__(self, group: dist.ProcessGroup,
low_latency_mode: whether to enable low-latency mode.
num_qps_per_rank: the number of QPs for RDMA, the low-latency mode requires that this number equals
to the number of local experts.
allow_nvlink_for_low_latency_mode: whether to allow NVLink traffic in low-latency mode; note that this is
somewhat incompatible with the hook-based overlapping.
Warning: PCIe connections may lead to errors due to memory-ordering issues,
please make sure all connections are via NVLink.
"""

# Initialize the CPP runtime
@@ -68,8 +73,7 @@ def __init__(self, group: dist.ProcessGroup,
if self.runtime.get_num_rdma_ranks() > 1 or low_latency_mode:
# Enable IBGDA
assert num_qps_per_rank > 0
if not os.getenv("NVSHMEM_DISABLE_P2P"):
os.environ['NVSHMEM_DISABLE_P2P'] = '1'
os.environ['NVSHMEM_DISABLE_P2P'] = '0' if allow_nvlink_for_low_latency_mode else '1'
os.environ['NVSHMEM_IB_ENABLE_IBGDA'] = '1'
os.environ['NVSHMEM_IBGDA_NIC_HANDLER'] = 'gpu'
os.environ['NVSHMEM_IBGDA_NUM_RC_PER_PE'] = f'{num_qps_per_rank}'