Skip to content

add sana-sprint#11074

Merged
yiyixuxu merged 22 commits into
mainfrom
sana-sprint
Mar 21, 2025
Merged

add sana-sprint#11074
yiyixuxu merged 22 commits into
mainfrom
sana-sprint

Conversation

@yiyixuxu

@yiyixuxu yiyixuxu commented Mar 17, 2025

Copy link
Copy Markdown
Collaborator
# test sana sprint
"""
python scripts/convert_sana_to_diffusers.py \
  --orig_ckpt_path /raid/.cache/huggingface/yiyi/models--JunsongChen--Sana_Sprint_1600M_1024px/snapshots/8ecfdc7e6269e5b065f5b2cf3fac9a2a1a778c6a/Sana_Sprint_1600M_1024px_36K.pth \
  --model_type SanaSprint_1600M_P1_D20 \
  --image_size 1024 \
  --dump_path /raid/yiyi/Sana-Sprint-yiyi \
  --save_full_pipeline \
  --scheduler_type scm
"""

from diffusers import SanaSprintPipeline
import torch

device = "cuda:0"
dtype = torch.bfloat16

repo = "/raid/yiyi/Sana-Sprint-yiyi"

pipeline = SanaSprintPipeline.from_pretrained(repo, torch_dtype=dtype)
pipeline.to(device)


prompt = "a tiny astronaut hatching from an egg on the moon"

image = pipeline(prompt=prompt, num_inference_steps=2).images[0]
image.save("test_out.png")

yiyi_test_8_out

vibe tests with different timesteps settings

# test sana sprint
# (pipeline)
test_max_timesteps = [1.57080, 1.56830, 1.56580, 1.56454, 1.56246, 1.55830, 1.55413, 1.55080, 1.54580]
test_intermediate_timesteps = [None, 1.0, 1.1, 1.2, 1.3, 1.4]
test_num_inference_steps = [1,2,4]

# test_max_timesteps = [1.57080]
# test_intermediate_timesteps = [None]
# test_num_inference_steps = [1]

from diffusers import SanaSprintPipeline
import torch

device = "cuda:0"
dtype = torch.bfloat16
repo = "/raid/yiyi/Sana-Sprint-yiyi"

def run_sana(pipeline, num_inference_steps, max_timesteps, intermediate_timesteps):
    prompt = "a tiny astronaut hatching from an egg on the moon"
    generator = torch.Generator(device=device).manual_seed(123)
    test_name = f"num_inference_steps_{num_inference_steps}_max_timesteps_{max_timesteps}_intermediate_timesteps_{intermediate_timesteps}"
    print(f"--------------------------------")
    print(f"Running test:")
    print(f"num_inference_steps: {num_inference_steps}")
    print(f"max_timesteps: {max_timesteps}")
    print(f"intermediate_timesteps: {intermediate_timesteps}")
    try:
        image = pipeline(prompt=prompt, num_inference_steps=num_inference_steps, max_timesteps=max_timesteps, intermediate_timesteps=intermediate_timesteps, generator=generator).images[0]
        image.save(f"yiyi_test_10_1_out_{test_name}.png")
    except Exception as e:
        print(e)
    print(f"--------------------------------")


pipeline = SanaSprintPipeline.from_pretrained(repo, torch_dtype=dtype)
pipeline.to(device)

for num_inference_steps in test_num_inference_steps:
    for max_timesteps in test_max_timesteps:
        for intermediate_timesteps in test_intermediate_timesteps:
            run_sana(pipeline, num_inference_steps, max_timesteps, intermediate_timesteps)

@HuggingFaceDocBuilderDev

Copy link
Copy Markdown

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@ishan-modi

Copy link
Copy Markdown
Contributor

Nice Work !! just a heads up this PR might have conflicts with #11040 if merged first

@lawrence-cj

lawrence-cj commented Mar 17, 2025

Copy link
Copy Markdown
Contributor

Wonderful work. Since SANA-Sprint and SANA-1.5 follow the same architecture, so this PR would make SANA-1.5 work as well.
@yiyixuxu @sayakpaul

lawrence-cj and others added 10 commits March 16, 2025 20:39
* 1. update conversion script for sana1.5;
2. add conversion script for sana-sprint;

* seperate sana and sana-sprint conversion scripts;

* update for upstream

* fix the } bug

* add a doc for SanaSprintPipeline;

* minor update;

* make style && make quality
@yiyixuxu yiyixuxu requested review from a-r-r-o-w and hlky March 20, 2025 10:47
@yiyixuxu

Copy link
Copy Markdown
Collaborator Author

@bot /style

@yiyixuxu

Copy link
Copy Markdown
Collaborator Author

@bot/ style

@github-actions

Copy link
Copy Markdown
Contributor

Style fixes have been applied. View the workflow run here.

@yiyixuxu

Copy link
Copy Markdown
Collaborator Author

cc @lawrence-cj can you do a review?

@a-r-r-o-w a-r-r-o-w left a comment

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Really amazing work! Can't wait for the release ❤️

Comment thread src/diffusers/models/transformers/sana_transformer.py Outdated
Comment thread src/diffusers/schedulers/scheduling_scm.py Outdated
Comment thread src/diffusers/schedulers/scheduling_scm.py Outdated
Comment thread src/diffusers/schedulers/scheduling_scm.py Outdated
Comment on lines +114 to +116
>>> from diffusers import SanaPipeline

>>> pipe = SanaPipeline.from_pretrained(

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Example to be updated to SanaSprintPipeline

Comment thread src/diffusers/pipelines/sana/pipeline_sana_sprint.py
Comment thread src/diffusers/pipelines/sana/pipeline_sana_sprint.py

@hlky hlky left a comment

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks @yiyixuxu

Comment on lines +149 to +152
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
# scaled_dot_product_attention expects attention_mask shape to be
# (batch, heads, source_length, target_length)
attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In other recent models we found that attention mask with shape [B, 1, 1, N] is faster as the total size is smaller and PyTorch's broadcasting handles it. Something to look into, if we see a benefit all occurrences of this code can be updated.

latents = latents.to(self.vae.dtype)
try:
image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
except torch.cuda.OutOfMemoryError as e:

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

+1 on this.

Comment thread scripts/convert_sana_to_diffusers.py Outdated
Comment thread scripts/convert_sana_to_diffusers.py Outdated
@lawrence-cj

lawrence-cj commented Mar 21, 2025

Copy link
Copy Markdown
Contributor

try:
image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
except torch.cuda.OutOfMemoryError as e:
warnings.warn(

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should this be logger.warning()?

else:
# max_timesteps=arctan(80/0.5)=1.56454 is the default from sCM paper, we choose a different value here
self.timesteps = torch.linspace(max_timesteps, 0, num_inference_steps + 1, device=device).float()
print(f"Set timesteps: {self.timesteps}")

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
print(f"Set timesteps: {self.timesteps}")

lawrence-cj and others added 2 commits March 21, 2025 05:57
* change sample prompt;

* only 1024px is supported;
@yiyixuxu yiyixuxu merged commit 8a63aa5 into main Mar 21, 2025
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

7 participants