Description
`diffusers/src/diffusers/models/transformers/transformer_cosmos.py`, lines 188 to 193 at commit 42077e6:
```python
# 4. Prepare for GQA
query_idx = torch.tensor(query.size(3), device=query.device)
key_idx = torch.tensor(key.size(3), device=key.device)
value_idx = torch.tensor(value.size(3), device=value.device)
key = key.repeat_interleave(query_idx // key_idx, dim=3)
value = value.repeat_interleave(query_idx // value_idx, dim=3)
```
There is a speedup of roughly 10% to be had here in `Cosmos2TextToImagePipeline` and `Cosmos2VideoToWorldPipeline`.
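The snippet wraps the sizes in device tensors, so the repeat factor passed to `repeat_interleave` is itself a tensor. According to the PyTorch documentation for `repeat_interleave`, unless `output_size` is given the output shape must be computed from `repeats`, which needs a stream synchronization when that value lives on the GPU. That round trip is a plausible source of the overhead, though the issue does not say which change produced the ~10%. The sketch below assumes that reading. It contrasts the original pattern with one that keeps the factor as a plain Python `int` (`Tensor.size()` already returns ints) and checks on CPU that both give identical results. The function names `expand_kv_tensor_repeats` and `expand_kv_int_repeats` are hypothetical and not part of diffusers.

```python
import torch


def expand_kv_tensor_repeats(query, key, value):
    # Original pattern (hypothetical wrapper): the sizes become 0-dim tensors on
    # the input's device, so the repeat factor is a tensor. On CUDA,
    # repeat_interleave has to read it back to the host to size its output.
    query_idx = torch.tensor(query.size(3), device=query.device)
    key_idx = torch.tensor(key.size(3), device=key.device)
    value_idx = torch.tensor(value.size(3), device=value.device)
    key = key.repeat_interleave(query_idx // key_idx, dim=3)
    value = value.repeat_interleave(query_idx // value_idx, dim=3)
    return key, value


def expand_kv_int_repeats(query, key, value):
    # Assumed alternative (hypothetical): keep the factor as a Python int,
    # so the output shape is known on the host and no device read is needed.
    key = key.repeat_interleave(query.size(3) // key.size(3), dim=3)
    value = value.repeat_interleave(query.size(3) // value.size(3), dim=3)
    return key, value


# Toy shapes chosen so that the repeat factor along dim=3 is 8 // 4 = 2.
q = torch.randn(1, 2, 5, 8)
k = torch.randn(1, 2, 5, 4)
v = torch.randn(1, 2, 5, 4)

k_a, v_a = expand_kv_tensor_repeats(q, k, v)
k_b, v_b = expand_kv_int_repeats(q, k, v)
assert torch.equal(k_a, k_b) and torch.equal(v_a, v_b)
print(k_b.shape)  # torch.Size([1, 2, 5, 8])
```

Both functions compute the same values. The only difference is where the repeat factor lives, which matters only on an accelerator, where avoiding the synchronization lets the host keep queuing kernels.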