
Refactor attention.py #1880

Closed

Description

@patrickvonplaten

attention.py currently has two concurrent attention implementations that essentially do the exact same thing:

Both `CrossAttention` and `AttentionBlock` are already used for "simple" attention - e.g. the former for Stable Diffusion and the latter for the simple DDPM UNet.
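
For context, here is a minimal sketch of the single-input self-attention computation that both classes effectively perform. This is not the actual diffusers code; the `to_q`/`to_k`/`to_v`/`to_out` names just follow `CrossAttention`'s naming convention:

```python
import torch
import torch.nn as nn


class SimpleSelfAttention(nn.Module):
    """Plain multi-head self-attention over a (batch, seq_len, channels) input."""

    def __init__(self, channels: int, num_heads: int = 1):
        super().__init__()
        assert channels % num_heads == 0
        self.num_heads = num_heads
        self.to_q = nn.Linear(channels, channels)
        self.to_k = nn.Linear(channels, channels)
        self.to_v = nn.Linear(channels, channels)
        self.to_out = nn.Linear(channels, channels)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        batch, seq_len, channels = hidden_states.shape
        head_dim = channels // self.num_heads

        def split_heads(t: torch.Tensor) -> torch.Tensor:
            # (batch, seq_len, channels) -> (batch, num_heads, seq_len, head_dim)
            return t.view(batch, seq_len, self.num_heads, head_dim).transpose(1, 2)

        q = split_heads(self.to_q(hidden_states))
        k = split_heads(self.to_k(hidden_states))
        v = split_heads(self.to_v(hidden_states))

        # Standard scaled dot-product attention.
        scores = q @ k.transpose(-1, -2) / head_dim**0.5
        out = scores.softmax(dim=-1) @ v

        # (batch, num_heads, seq_len, head_dim) -> (batch, seq_len, channels)
        out = out.transpose(1, 2).reshape(batch, seq_len, channels)
        return self.to_out(out)


x = torch.randn(2, 64, 128)
print(SimpleSelfAttention(channels=128, num_heads=4)(x).shape)  # torch.Size([2, 64, 128])
```

`AttentionBlock` additionally wraps this in a group norm and a residual connection, and `CrossAttention` generalizes the `k`/`v` input to an optional `encoder_hidden_states`, but the core computation is the same.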

We should start deprecating `AttentionBlock` very soon, as it's not viable to keep two attention mechanisms.
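
The deprecation itself could follow the usual warn-on-construction pattern. A hypothetical sketch, not a concrete API proposal:

```python
import warnings

import torch.nn as nn


class AttentionBlock(nn.Module):
    # Hypothetical shim: keep the old class importable, but warn on construction.
    def __init__(self, channels: int):
        warnings.warn(
            "`AttentionBlock` is deprecated and will be removed in a future "
            "release. Please use `CrossAttention` instead.",
            FutureWarning,
            stacklevel=2,
        )
        super().__init__()
        self.channels = channels
        # ... the original layers would follow here unchanged
```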

Deprecating this class won't be easy, as it essentially means forcing people to re-upload their weights: every model checkpoint that made use of `AttentionBlock` will eventually have to re-upload its weights to stay compatible.
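
One way to soften this would be a one-off key remapping when loading old checkpoints, so existing weights keep working until they are re-uploaded. A sketch under assumed parameter names - `query`/`key`/`value`/`proj_attn` on the `AttentionBlock` side mapping to `to_q`/`to_k`/`to_v`/`to_out.0` on the `CrossAttention` side (where `to_out.0` is assumed to be the output `Linear`):

```python
import torch


def convert_attention_block_keys(state_dict: dict) -> dict:
    """Rename AttentionBlock-style parameter keys to CrossAttention-style keys."""
    # Assumed key fragments for illustration; verify against the real modules.
    mapping = {
        ".query.": ".to_q.",
        ".key.": ".to_k.",
        ".value.": ".to_v.",
        ".proj_attn.": ".to_out.0.",
    }
    converted = {}
    for key, tensor in state_dict.items():
        new_key = key
        for old, new in mapping.items():
            new_key = new_key.replace(old, new)
        converted[new_key] = tensor
    return converted


old_sd = {"mid_block.attentions.0.query.weight": torch.zeros(4, 4)}
print(list(convert_attention_block_keys(old_sd)))
# ['mid_block.attentions.0.to_q.weight']
```

Assuming both sides are plain `nn.Linear` layers of matching shape, a pure key rename suffices; the tensors themselves don't change.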

I would propose to do this in the following way:
