Upload matryoshka.py
Browse files- matryoshka.py +23 -10
matryoshka.py
CHANGED
@@ -102,15 +102,21 @@ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
|
102 |
EXAMPLE_DOC_STRING = """
|
103 |
Examples:
|
104 |
```py
|
105 |
-
>>> import
|
106 |
-
>>> from diffusers import
|
107 |
|
108 |
-
>>>
|
109 |
-
>>> pipe =
|
|
|
110 |
|
111 |
-
>>>
|
112 |
-
>>>
|
113 |
-
>>>
|
|
|
|
|
|
|
|
|
|
|
114 |
```
|
115 |
"""
|
116 |
|
@@ -1636,12 +1642,19 @@ class MatryoshkaFusedAttnProcessor1_0_or_2_0:
|
|
1636 |
|
1637 |
# the output of sdp = (batch, num_heads, seq_len, head_dim)
|
1638 |
# TODO: add support for attn.scale when we move to Torch 2.1 if F.scaled_dot_product_attention() is available
|
1639 |
-
hidden_states = self.attention(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1640 |
query,
|
1641 |
key,
|
1642 |
value,
|
1643 |
-
|
1644 |
-
|
1645 |
)
|
1646 |
|
1647 |
hidden_states = hidden_states.to(query.dtype)
|
|
|
102 |
EXAMPLE_DOC_STRING = """
|
103 |
Examples:
|
104 |
```py
|
105 |
+
>>> from diffusers import DiffusionPipeline
|
106 |
+
>>> from diffusers.utils import make_image_grid
|
107 |
|
108 |
+
>>> # nesting_level=0 -> 64x64; nesting_level=1 -> 256x256 - 64x64; nesting_level=2 -> 1024x1024 - 256x256 - 64x64
|
109 |
+
>>> pipe = DiffusionPipeline.from_pretrained("tolgacangoz/matryoshka-diffusion-models",
|
110 |
+
>>> custom_pipeline="matryoshka").to("cuda")
|
111 |
|
112 |
+
>>> prompt0 = "a blue jay stops on the top of a helmet of Japanese samurai, background with sakura tree"
|
113 |
+
>>> prompt = f"breathtaking {prompt0}. award-winning, professional, highly detailed"
|
114 |
+
>>> negative_prompt = "deformed, mutated, ugly, disfigured, blur, blurry, noise, noisy"
|
115 |
+
>>> image = pipe(prompt=prompt, negative_prompt=negative_prompt, num_inference_steps=50).images
|
116 |
+
>>> make_image_grid(image, rows=1, cols=len(image))
|
117 |
+
|
118 |
+
>>> pipe.change_nesting_level(1)  # valid nesting levels: 0, 1, or 2
|
119 |
+
>>> # 50+, 100+, and 250+ num_inference_steps are recommended for nesting levels 0, 1, and 2 respectively.
|
120 |
```
|
121 |
"""
|
122 |
|
|
|
1642 |
|
1643 |
# the output of sdp = (batch, num_heads, seq_len, head_dim)
|
1644 |
# TODO: add support for attn.scale when we move to Torch 2.1 if F.scaled_dot_product_attention() is available
|
1645 |
+
# hidden_states = self.attention(
|
1646 |
+
# query,
|
1647 |
+
# key,
|
1648 |
+
# value,
|
1649 |
+
# mask=attention_mask,
|
1650 |
+
# num_heads=attn.heads,
|
1651 |
+
# )
|
1652 |
+
hidden_states = F.scaled_dot_product_attention(
|
1653 |
query,
|
1654 |
key,
|
1655 |
value,
|
1656 |
+
attn_mask=attention_mask,
|
1657 |
+
dropout_p=attn.dropout,
|
1658 |
)
|
1659 |
|
1660 |
hidden_states = hidden_states.to(query.dtype)
|