Skip to content

[core][distributed] use tcp store directly #10275

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Nov 13, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 16 additions & 10 deletions tests/distributed/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,12 +43,15 @@ def test_cuda_device_count_stateless():


def cpu_worker(rank, WORLD_SIZE, port1, port2):
pg1 = StatelessProcessGroup.create(init_method=f"tcp://127.0.0.1:{port1}",
pg1 = StatelessProcessGroup.create(host="127.0.0.1",
port=port1,
rank=rank,
world_size=WORLD_SIZE)
if rank <= 2:
pg2 = StatelessProcessGroup.create(
init_method=f"tcp://127.0.0.1:{port2}", rank=rank, world_size=3)
pg2 = StatelessProcessGroup.create(host="127.0.0.1",
port=port2,
rank=rank,
world_size=3)
data = torch.tensor([rank])
data = pg1.broadcast_obj(data, src=2)
assert data.item() == 2
Expand All @@ -62,14 +65,17 @@ def cpu_worker(rank, WORLD_SIZE, port1, port2):

def gpu_worker(rank, WORLD_SIZE, port1, port2):
torch.cuda.set_device(rank)
pg1 = StatelessProcessGroup.create(init_method=f"tcp://127.0.0.1:{port1}",
pg1 = StatelessProcessGroup.create(host="127.0.0.1",
port=port1,
rank=rank,
world_size=WORLD_SIZE)
pynccl1 = PyNcclCommunicator(pg1, device=rank)
pynccl1.disabled = False
if rank <= 2:
pg2 = StatelessProcessGroup.create(
init_method=f"tcp://127.0.0.1:{port2}", rank=rank, world_size=3)
pg2 = StatelessProcessGroup.create(host="127.0.0.1",
port=port2,
rank=rank,
world_size=3)
pynccl2 = PyNcclCommunicator(pg2, device=rank)
pynccl2.disabled = False
data = torch.tensor([rank]).cuda()
Expand All @@ -89,7 +95,8 @@ def gpu_worker(rank, WORLD_SIZE, port1, port2):


def broadcast_worker(rank, WORLD_SIZE, port1, port2):
pg1 = StatelessProcessGroup.create(init_method=f"tcp://127.0.0.1:{port1}",
pg1 = StatelessProcessGroup.create(host="127.0.0.1",
port=port1,
rank=rank,
world_size=WORLD_SIZE)
if rank == 2:
Expand All @@ -101,16 +108,15 @@ def broadcast_worker(rank, WORLD_SIZE, port1, port2):


def allgather_worker(rank, WORLD_SIZE, port1, port2):
pg1 = StatelessProcessGroup.create(init_method=f"tcp://127.0.0.1:{port1}",
pg1 = StatelessProcessGroup.create(host="127.0.0.1",
port=port1,
rank=rank,
world_size=WORLD_SIZE)
data = pg1.all_gather_obj(rank)
assert data == list(range(WORLD_SIZE))
pg1.barrier()


# TODO: investigate why this test is flaky. It hangs during initialization.
@pytest.mark.skip("Skip the test because it is flaky.")
@multi_gpu_test(num_gpus=4)
@pytest.mark.parametrize(
"worker", [cpu_worker, gpu_worker, broadcast_worker, allgather_worker])
Expand Down
28 changes: 13 additions & 15 deletions vllm/distributed/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from typing import Any, Deque, Dict, Optional, Sequence, Tuple

import torch
from torch.distributed.rendezvous import rendezvous
from torch.distributed import TCPStore

import vllm.envs as envs
from vllm.logger import init_logger
Expand Down Expand Up @@ -97,7 +97,6 @@ class StatelessProcessGroup:
group. Only use it to communicate metadata between processes.
For data-plane communication, create NCCL-related objects.
"""
prefix: str
rank: int
world_size: int
store: torch._C._distributed_c10d.Store
Expand Down Expand Up @@ -127,7 +126,7 @@ def __post_init__(self):
def send_obj(self, obj: Any, dst: int):
"""Send an object to a destination rank."""
self.expire_data()
key = f"{self.prefix}/send_to/{dst}/{self.send_dst_counter[dst]}"
key = f"send_to/{dst}/{self.send_dst_counter[dst]}"
self.store.set(key, pickle.dumps(obj))
self.send_dst_counter[dst] += 1
self.entries.append((key, time.time()))
Expand All @@ -147,8 +146,7 @@ def recv_obj(self, src: int) -> Any:
"""Receive an object from a source rank."""
obj = pickle.loads(
self.store.get(
f"{self.prefix}/send_to/{self.rank}/{self.recv_src_counter[src]}"
))
f"send_to/{self.rank}/{self.recv_src_counter[src]}"))
self.recv_src_counter[src] += 1
return obj

Expand All @@ -159,14 +157,14 @@ def broadcast_obj(self, obj: Optional[Any], src: int) -> Any:
"""
if self.rank == src:
self.expire_data()
key = (f"{self.prefix}/broadcast_from/{src}/"
key = (f"broadcast_from/{src}/"
f"{self.broadcast_send_counter}")
self.store.set(key, pickle.dumps(obj))
self.broadcast_send_counter += 1
self.entries.append((key, time.time()))
return obj
else:
key = (f"{self.prefix}/broadcast_from/{src}/"
key = (f"broadcast_from/{src}/"
f"{self.broadcast_recv_src_counter[src]}")
recv_obj = pickle.loads(self.store.get(key))
self.broadcast_recv_src_counter[src] += 1
Expand Down Expand Up @@ -194,7 +192,8 @@ def barrier(self):

@staticmethod
def create(
init_method: str,
host: str,
port: int,
rank: int,
world_size: int,
data_expiration_seconds: int = 3600,
Expand All @@ -214,15 +213,14 @@ def create(
can call `StatelessProcessGroup.create` to form a group, and then process A, B,
C, and D can call `StatelessProcessGroup.create` to form another group.
""" # noqa
from torch._C._distributed_c10d import _DEFAULT_PG_TIMEOUT
timeout = _DEFAULT_PG_TIMEOUT

store, rank, world_size = next(
rendezvous(init_method, rank, world_size, timeout=timeout))
store.set_timeout(timeout)
store = TCPStore(
host_name=host,
port=port,
world_size=world_size,
is_master=(rank == 0),
)

return StatelessProcessGroup(
prefix=init_method,
rank=rank,
world_size=world_size,
store=store,
Expand Down