High Batch Size with SD3 Dreambooth Destabilizes Training #8619

@CiaraStrawberry

Describe the bug

I have been trying to train a slightly modified IP-Adapter architecture for SD3 over the past few days, and wrote the training script by copying the up-to-date weighting, noise, and loss code from the train_dreambooth_sd3.py script. Nothing I did would allow the model to train: after 2,000-3,000 steps at batch size 40 and lr 5e-6, the output would just turn to mush.
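For context, the copied noise/loss flow is the rectified-flow formulation used for SD3; a minimal sketch of my understanding (argument names and shapes are illustrative, not the script verbatim):

```python
import torch

def sd3_flow_matching_loss(transformer, latents, encoder_hidden_states,
                           pooled_projections, timesteps, sigmas, weighting):
    # Sketch of the rectified-flow noising + weighted MSE used for SD3
    # DreamBooth-style training (simplified; not train_dreambooth_sd3.py verbatim).
    # `sigmas` and `weighting` are assumed to already be broadcast to the
    # latents' shape, e.g. (batch, 1, 1, 1).
    noise = torch.randn_like(latents)
    # Interpolate between clean latents and pure noise at the sampled noise level.
    noisy_latents = (1.0 - sigmas) * latents + sigmas * noise
    model_pred = transformer(
        hidden_states=noisy_latents,
        encoder_hidden_states=encoder_hidden_states,
        pooled_projections=pooled_projections,
        timestep=timesteps,
    ).sample
    # Flow-matching target: the velocity pointing from data towards noise.
    target = noise - latents
    per_sample = weighting * (model_pred.float() - target.float()) ** 2
    return per_sample.reshape(latents.shape[0], -1).mean(dim=1).mean()
```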

[Image: example of the degraded output after 2-3k steps at batch size 40]

Now, after dropping to a batch size of 4 and lr 8e-7, the problem appears to have gone away completely, with no hint of degradation 40,000 steps in.

The only other possible explanation is that around the same time I also removed a torch.autocast block around the model forward pass that shouldn't have been there, given I was also using Accelerate. I don't think that was the source of the issue, though, since the same block was present in previous, perfectly functional runs using a very similar script. I'll do an extra test to check at some point over the next few days when I have a chance.
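For reference, the removed pattern looked roughly like this (a toy sketch with stand-in model and data; when Accelerate is configured with mixed_precision, it already wraps the forward pass in autocast, so the extra block is redundant at best):

```python
import torch
from accelerate import Accelerator

# Toy stand-ins so the sketch runs end to end; the real run used the SD3 transformer.
model = torch.nn.Linear(16, 16)
optimizer = torch.optim.AdamW(model.parameters(), lr=8e-7)
dataloader = torch.utils.data.DataLoader(torch.randn(64, 16), batch_size=4)

accelerator = Accelerator(mixed_precision="bf16")  # assumed setting
model, optimizer, dataloader = accelerator.prepare(model, optimizer, dataloader)

for batch in dataloader:
    # Before: the forward pass was additionally wrapped in torch.autocast,
    # even though Accelerate already applies autocast when mixed_precision is set:
    #
    #     with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
    #         pred = model(batch)
    #
    # After: call the model directly and let Accelerate manage precision.
    pred = model(batch)
    loss = (pred - batch).pow(2).mean()
    accelerator.backward(loss)
    optimizer.step()
    optimizer.zero_grad()
```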

I have been using the newly modified weighting pushed a day or two ago, with logit-normal weighting. (The same thing seemed to happen with the original sigma_sqrt weighting.)
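For clarity, this is roughly what I mean by the two schemes (a sketch of my understanding, not the diffusers implementation verbatim):

```python
import torch

def logit_normal_sigmas(batch_size, logit_mean=0.0, logit_std=1.0):
    # Logit-normal scheme (SD3 paper): u ~ N(mean, std), sigma = sigmoid(u).
    # The weighting enters through the sampling density itself, concentrating
    # training on mid-schedule noise levels.
    u = torch.normal(mean=logit_mean, std=logit_std, size=(batch_size,))
    return torch.sigmoid(u)

def sigma_sqrt_weighting(sigmas):
    # sigma_sqrt scheme, as I understand it: sigmas sampled more uniformly and
    # the per-sample MSE weighted by roughly 1 / sigma^2 instead.
    return sigmas.float() ** -2.0

print(logit_normal_sigmas(4))
print(sigma_sqrt_weighting(torch.rand(4)))
```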

Reproduction

Run SD3 training with logit-normal weighting, presumably with a LoRA or similar, using a batch size of at least 40 and a learning rate around 6e-6.

Logs

System Info

4x A100 RunPod server

Who can help?

@sayakpaul
