tolgacangoz committed
Commit 33d718d · verified · 1 Parent(s): c0c0ade

Upload matryoshka.py

Files changed (1):
  1. matryoshka.py +23 -10
matryoshka.py CHANGED
@@ -102,15 +102,21 @@ logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
 EXAMPLE_DOC_STRING = """
     Examples:
         ```py
-        >>> import torch
-        >>> from diffusers import MatryoshkaPipeline
+        >>> from diffusers import DiffusionPipeline
+        >>> from diffusers.utils import make_image_grid

-        >>> pipe = MatryoshkaPipeline.from_pretrained("A/B", torch_dtype=torch.float16, variant="fp16")
-        >>> pipe = pipe.to("cuda")
+        >>> # nesting_level=0 -> 64x64; nesting_level=1 -> 256x256 - 64x64; nesting_level=2 -> 1024x1024 - 256x256 - 64x64
+        >>> pipe = DiffusionPipeline.from_pretrained("tolgacangoz/matryoshka-diffusion-models",
+        >>>                                          custom_pipeline="matryoshka").to("cuda")

-        >>> prompt = "a photo of an astronaut riding a horse on mars"
-        >>> image = pipe(prompt).images[0]
-        >>> image
+        >>> prompt0 = "a blue jay stops on the top of a helmet of Japanese samurai, background with sakura tree"
+        >>> prompt = f"breathtaking {prompt0}. award-winning, professional, highly detailed"
+        >>> negative_prompt = "deformed, mutated, ugly, disfigured, blur, blurry, noise, noisy"
+        >>> image = pipe(prompt=prompt, negative_prompt=negative_prompt, num_inference_steps=50).images
+        >>> make_image_grid(image, rows=1, cols=len(image))
+
+        >>> pipe.change_nesting_level(<int>)  # 0, 1, or 2
+        >>> # 50+, 100+, and 250+ num_inference_steps are recommended for nesting levels 0, 1, and 2 respectively.
         ```
 """

@@ -1636,12 +1642,19 @@ class MatryoshkaFusedAttnProcessor1_0_or_2_0:

         # the output of sdp = (batch, num_heads, seq_len, head_dim)
         # TODO: add support for attn.scale when we move to Torch 2.1 if F.scaled_dot_product_attention() is available
-        hidden_states = self.attention(
+        # hidden_states = self.attention(
+        #     query,
+        #     key,
+        #     value,
+        #     mask=attention_mask,
+        #     num_heads=attn.heads,
+        # )
+        hidden_states = F.scaled_dot_product_attention(
             query,
             key,
             value,
-            mask=attention_mask,
-            num_heads=attn.heads,
+            attn_mask=attention_mask,
+            dropout=attn.dropout,
         )

         hidden_states = hidden_states.to(query.dtype)
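
For reference, the updated docstring example above can be exercised roughly as follows. This is a minimal sketch rather than part of the commit: the repository id, the `matryoshka` custom pipeline name, `change_nesting_level()`, and the recommended step counts are all taken from the docstring; choosing nesting level 2 with 250 steps is just one of the documented combinations.

```py
# Minimal usage sketch based on the updated EXAMPLE_DOC_STRING above.
from diffusers import DiffusionPipeline
from diffusers.utils import make_image_grid

pipe = DiffusionPipeline.from_pretrained(
    "tolgacangoz/matryoshka-diffusion-models", custom_pipeline="matryoshka"
).to("cuda")

prompt0 = "a blue jay stops on the top of a helmet of Japanese samurai, background with sakura tree"
prompt = f"breathtaking {prompt0}. award-winning, professional, highly detailed"
negative_prompt = "deformed, mutated, ugly, disfigured, blur, blurry, noise, noisy"

# nesting_level=2 yields 1024x1024, 256x256, and 64x64 images in a single call;
# the docstring recommends 250+ inference steps at this level.
pipe.change_nesting_level(2)
images = pipe(prompt=prompt, negative_prompt=negative_prompt, num_inference_steps=250).images
make_image_grid(images, rows=1, cols=len(images))
```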
 
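
The second hunk swaps the processor's hand-rolled `self.attention(...)` call for PyTorch's fused kernel. Below is a standalone sketch (not the pipeline's code) of what `F.scaled_dot_product_attention` computes when given an attention mask; note that in `torch.nn.functional` the dropout keyword is `dropout_p`, which is what the sketch uses.

```py
# Standalone sketch of the fused attention path the commit switches to.
import torch
import torch.nn.functional as F

batch, num_heads, seq_len, head_dim = 2, 8, 16, 64
query = torch.randn(batch, num_heads, seq_len, head_dim)
key = torch.randn(batch, num_heads, seq_len, head_dim)
value = torch.randn(batch, num_heads, seq_len, head_dim)

# Additive float mask broadcastable to (batch, num_heads, seq_len, seq_len);
# entries set to -inf are excluded from attention.
attention_mask = torch.zeros(batch, 1, seq_len, seq_len)

# The output of sdp has shape (batch, num_heads, seq_len, head_dim).
hidden_states = F.scaled_dot_product_attention(
    query,
    key,
    value,
    attn_mask=attention_mask,
    dropout_p=0.0,  # PyTorch's keyword is dropout_p
)

# Merge heads back to (batch, seq_len, num_heads * head_dim), as diffusers
# attention processors typically do after the fused call.
hidden_states = hidden_states.transpose(1, 2).reshape(batch, seq_len, num_heads * head_dim)
```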