Merged
54 changes: 40 additions & 14 deletions README.md
@@ -154,24 +154,38 @@ Full loss curves and per-epoch breakdown: [summary.md](summary.md)
### Loading from Python

```python
import torch
from src.model import UNetModel, DDIMScheduler
import sys, torch
sys.path.insert(0, "src") # make src/ importable

from huggingface_hub import hf_hub_download
from SD_Model import UNetModel # legacy single-file module
# — or, equivalently, the refactored module: from model import UNetModel

checkpoint = hf_hub_download(
repo_id="atandra2000/sd-from-scratch-v1",
filename="sd_epoch_042.pt",
local_dir="checkpoints",
)

ckpt = torch.load(checkpoint, map_location="cpu", weights_only=False)
unet_sd = ckpt["unet_state_dict"]

# Load EMA shadow (produces better images than live weights)
unet = UNetModel(in_ch=4, out_ch=4, ch=320, res_blks=2,
attn_lvls=(1, 2, 3), ch_mults=(1, 2, 4, 4),
heads=8, ctx_dim=768)
unet.load_state_dict(unet_sd, strict=True)
shadow = ckpt["ema_state_dict"]["shadow"]
cleaned = {}
for k, v in shadow.items():
for prefix in ("module.", "unet.", "_orig_mod."):
if k.startswith(prefix):
k = k[len(prefix):]
break
cleaned[k] = v
unet.load_state_dict(cleaned, strict=False) # strict=False: a few shadow keys may be absent
unet.eval()
```

See `src/inference.py:load_ema_unet()` for the canonical loader used in production.
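
The loop in the snippet above strips at most one wrapper prefix per key. A checkpoint saved from a model that was both compiled and wrapped in DDP may carry stacked prefixes such as `module._orig_mod.`, so a minimal sketch of a helper that strips them repeatedly is shown here. The name `strip_wrapper_prefixes` is hypothetical and not part of the repo; the prefix list is copied from the snippet, and plain integers stand in for tensors.

```python
from typing import Any, Dict, Tuple

# Prefixes copied from the README snippet; the helper itself is hypothetical.
WRAPPER_PREFIXES: Tuple[str, ...] = ("module.", "unet.", "_orig_mod.")


def strip_wrapper_prefixes(
    state: Dict[str, Any],
    prefixes: Tuple[str, ...] = WRAPPER_PREFIXES,
) -> Dict[str, Any]:
    """Return a copy of `state` with every leading wrapper prefix removed."""
    cleaned: Dict[str, Any] = {}
    for key, value in state.items():
        stripped = True
        while stripped:  # repeat until no prefix matches the start of the key
            stripped = False
            for prefix in prefixes:
                if key.startswith(prefix):
                    key = key[len(prefix):]
                    stripped = True
        cleaned[key] = value
    return cleaned


# Toy stand-in for ckpt["ema_state_dict"]["shadow"]
shadow = {
    "module._orig_mod.conv_in.weight": 1,
    "unet.conv_in.bias": 2,
    "out.weight": 3,
}
cleaned = strip_wrapper_prefixes(shadow)
print(sorted(cleaned))
```

When loading with `strict=False`, it is worth inspecting the `missing_keys` and `unexpected_keys` fields of the value returned by `load_state_dict` to confirm that only the expected handful of keys differ.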

---

## Training Reproduction
@@ -240,18 +254,30 @@ python src/inference.py \
### Python API

```python
from src.generate import generate_images

images = generate_images(
prompts=["a beautiful sunset over mountains"],
checkpoint_path="checkpoints/sd_epoch_042.pt",
num_steps=50,
guidance_scale=7.5,
seed=42,
import sys
sys.path.insert(0, "src")

import torch
from transformers import CLIPTokenizer
from generate import load_model, generate

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = load_model("checkpoints/sd_epoch_042.pt", device)
tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")

images = generate(
model = model,
tokenizer = tokenizer,
prompts = ["a beautiful sunset over mountains"],
num_steps = 50,
guidance_scale = 7.5,
seed = 42,
output_path = "output.png",
)
images[0].save("output.png")
```

Note: `generate()` is the function in `src/generate.py`. It takes an already loaded `StableDiffusionModel`, not a checkpoint path; `load_model()` in the snippet above builds that model from the checkpoint file.

See [docs/inference.md](docs/inference.md) for all options.

---
89 changes: 58 additions & 31 deletions docs/blog_post.md
@@ -484,10 +484,10 @@ Higher than 9.0 introduces oversaturation and CFG artefacts (waxy skin, blown hi

Negative prompts replace the unconditional embedding with something like `"blurry, low quality, distorted, deformed"`. They're the fastest way to prune common generative failure modes.

**Implementation note for the repo:** I have two inference scripts.
**Implementation note for the repo:** I have two inference scripts. Both wire negative prompts end-to-end in the current release:

- `SD_ImageGen.py` (CUDA) — properly wires negative prompts through `generate(..., negative_prompts=...)`.
- `inference.py` (CUDA/MPS, Apple Silicon friendly) — currently parses `--negative` but doesn't thread it into `generate()`; the unconditional branch still uses the empty string. **If you want true negative-prompt CFG, use `SD_ImageGen.py` until I fix that.**
- `SD_ImageGen.py` (CUDA) — `generate(..., negative_prompts=...)` parameter.
- `inference.py` (CUDA/MPS, Apple Silicon friendly) — `--negative` flag is parsed in `main()` and passed to `generate()` as `negative_prompts`; both the per-sample and broadcast (single negative for the whole batch) cases work.
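
To make the per-sample and broadcast cases concrete, here is a minimal sketch of how a negative-prompt argument might be expanded to the batch and how the two noise predictions are then combined under classifier-free guidance. `expand_negative_prompts` and `cfg_combine` are hypothetical names, not functions from the repo, and plain Python floats stand in for the noise tensors.

```python
from typing import List, Optional, Sequence, Union


def expand_negative_prompts(
    negative_prompts: Optional[Union[str, Sequence[str]]],
    batch_size: int,
) -> List[str]:
    """Return exactly one negative prompt per sample.

    None       -> empty string for every sample (plain unconditional branch)
    one prompt -> broadcast to the whole batch
    N prompts  -> used as-is; N must equal batch_size
    """
    if negative_prompts is None:
        return [""] * batch_size
    if isinstance(negative_prompts, str):
        return [negative_prompts] * batch_size
    negatives = list(negative_prompts)
    if len(negatives) == 1:
        return negatives * batch_size
    if len(negatives) != batch_size:
        raise ValueError(
            f"got {len(negatives)} negative prompts for a batch of {batch_size}"
        )
    return negatives


def cfg_combine(
    eps_neg: List[float], eps_cond: List[float], guidance_scale: float
) -> List[float]:
    """Classifier-free guidance: eps = eps_neg + s * (eps_cond - eps_neg)."""
    return [n + guidance_scale * (c - n) for n, c in zip(eps_neg, eps_cond)]


print(expand_negative_prompts(["blurry, low quality"], batch_size=3))
print(cfg_combine([0.0, 1.0], [1.0, 1.0], guidance_scale=7.5))
```

With an empty negative prompt, `eps_neg` is the ordinary unconditional prediction, so the same formula covers both plain CFG and negative-prompt CFG.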

### 6.4 Live vs EMA — The A/B Test That Mattered Most

Expand Down Expand Up @@ -520,14 +520,7 @@ pred_x0 = pred_x0.clamp(-1.0, 1.0) # ← THIS

It looks reasonable. It is catastrophically wrong. **SD latents are not in `[-1, 1]`** — their standard deviation is ≈ 4.0. Clamping decapitates the signal.
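
A quick numerical check illustrates the scale of the damage. The sketch below assumes the latents are roughly zero-mean Gaussian with the standard deviation of 4.0 quoted above (an idealisation, since real latents are not exactly Gaussian) and measures how many values a `[-1, 1]` clamp alters.

```python
import random
import statistics

random.seed(0)
LATENT_STD = 4.0  # figure quoted in the text; the Gaussian shape is an assumption

latents = [random.gauss(0.0, LATENT_STD) for _ in range(100_000)]
clamped = [max(-1.0, min(1.0, x)) for x in latents]

# Fraction of values the clamp actually changed
altered = sum(1 for x, c in zip(latents, clamped) if x != c) / len(latents)

print(f"fraction of values changed by the clamp: {altered:.2f}")
print(f"std before: {statistics.pstdev(latents):.2f}")
print(f"std after:  {statistics.pstdev(clamped):.2f}")
```

Under this assumption about four in five values fall outside `[-1, 1]`, so the clamp rewrites most of the predicted latent rather than trimming a few outliers.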

I removed it at inference and the images snapped into focus.

```python
# in inference.py: monkey-patch the scheduler before sampling
SD_Model.DDIMScheduler.step = _fixed_ddim_step # the no-clamp version
```

**Caveat:** the clamp is still present in `SD_Model.py:736` because removing it changes behaviour for any old script that imports the class directly. The inference scripts patch it out at runtime. The proper long-term fix is to delete that line.
I made the clamp **opt-in** and turned it off by default. The current `DDIMScheduler.__init__` takes a `clamp_pred_x0: bool = False` parameter; passing `True` reproduces the old training-time behaviour. Both `inference.py` and `SD_ImageGen.py` default to `False`, so the wrong behaviour is no longer reachable from the supported code paths.
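
For illustration, a deterministic DDIM update (η = 0) with such a flag could look like the sketch below. Only the parameter name `clamp_pred_x0` comes from the text; the function `ddim_step`, its signature, and the use of plain floats instead of tensors are assumptions, not the repo's code.

```python
import math
from typing import List


def ddim_step(
    x_t: List[float],
    eps: List[float],
    alpha_bar_t: float,
    alpha_bar_prev: float,
    clamp_pred_x0: bool = False,
) -> List[float]:
    """One deterministic DDIM step (eta = 0) on a flat list of latent values."""
    sqrt_a_t = math.sqrt(alpha_bar_t)
    sqrt_one_minus_a_t = math.sqrt(1.0 - alpha_bar_t)
    sqrt_a_prev = math.sqrt(alpha_bar_prev)
    sqrt_one_minus_a_prev = math.sqrt(1.0 - alpha_bar_prev)

    x_prev = []
    for x, e in zip(x_t, eps):
        # Estimate the clean latent from the noisy latent and predicted noise
        pred_x0 = (x - sqrt_one_minus_a_t * e) / sqrt_a_t
        if clamp_pred_x0:
            # Opt-in only: valid for pixel-range data, harmful for wide latents
            pred_x0 = max(-1.0, min(1.0, pred_x0))
        # Re-noise the estimate to the previous (less noisy) timestep
        x_prev.append(sqrt_a_prev * pred_x0 + sqrt_one_minus_a_prev * e)
    return x_prev


# Final step (alpha_bar_prev = 1): the output equals pred_x0
free = ddim_step([2.0], [0.0], alpha_bar_t=0.25, alpha_bar_prev=1.0)
clamped = ddim_step([2.0], [0.0], alpha_bar_t=0.25, alpha_bar_prev=1.0,
                    clamp_pred_x0=True)
print(free, clamped)
```

Keeping the flag off by default means a caller has to ask for the clamp explicitly, which matches the intent described above.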

**Lesson:** Never assume latent distributions match pixel distributions.

Expand Down Expand Up @@ -571,8 +564,6 @@ If I started over tomorrow:
2. **Start with a lower Min-SNR γ (2.0–2.5), but validate it carefully.** For aesthetic-heavy data, lower γ can speed convergence, but test it visually, not just by loss.
3. **Test inference every single epoch.** I'd have caught the latent clamp bug in week one instead of week six.
4. **Validate with EMA from the very first epoch.** Don't ship a fix you forgot to backfill into your visualisations.
5. **Delete the training-time `pred_x0.clamp`** instead of monkey-patching around it. Make the wrong version unreachable.
6. **Wire `--negative` through `inference.py:generate()` properly** so both inference scripts behave identically.
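
As a reference for item 2, Min-SNR-γ weighting for an ε-prediction objective scales each timestep's loss by min(SNR_t, γ) / SNR_t, where SNR_t = ᾱ_t / (1 − ᾱ_t). The sketch below assumes ε-prediction (the text does not say which parameterisation was trained), and `snr` and `min_snr_weight` are hypothetical names.

```python
def snr(alpha_bar_t: float) -> float:
    """Signal-to-noise ratio of the forward process at cumulative alpha alpha_bar_t."""
    return alpha_bar_t / (1.0 - alpha_bar_t)


def min_snr_weight(alpha_bar_t: float, gamma: float) -> float:
    """Min-SNR-gamma loss weight for an epsilon-prediction objective."""
    s = snr(alpha_bar_t)
    return min(s, gamma) / s


# Low-noise timesteps (high SNR) are down-weighted; noisy ones keep weight 1.
for alpha_bar in (0.99, 0.9, 0.5, 0.1):
    print(alpha_bar, round(min_snr_weight(alpha_bar, gamma=2.5), 4))
```

Lowering γ shrinks the weight on low-noise timesteps further, which is why a change that looks good on the loss curve still needs a visual check.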

---

@@ -591,25 +582,61 @@ If I started over tomorrow:

```
StableDiffusion/
├── SD_Model.py # UNet + VAE/CLIP wrappers + DDPM/DDIM schedulers
├── SD_Train.py # 2× RTX 5090 DDP + BF16 training loop
├── SD_ImageGen.py # CUDA inference (full negative-prompt CFG)
├── inference.py # CUDA/MPS inference (DDIM clamp monkey-patched)
├── encode_latents.py # 4-stage VAE → fp16 .npy pipeline
├── 01_download_metadata.py # LAION parquet snapshot
├── 01b_download_diffusiondb.py # DiffusionDB → 512×512 tar shards
├── 01c_download_journeydb_images.py # JourneyDB → 512×512 tar shards
├── 02_filter_metadata.py # aesthetic / CLIP / watermark / NSFW / dedup
├── 03_download_images.py # img2dataset LAION downloader
├── 03_build_hf_dataset.py # DiffusionDB/JourneyDB → Arrow HF dataset
├── 04_preprocess_to_cache.py # Tars → parquet (image_key + CLIP tokens)
├── 05_build_hf_dataset.py # Parquet → Arrow HF dataset
├── sd_epoch_042.pt # Released checkpoint (~12.5 GB)
├── sd-val-imgs/ # val_epoch_001..043.png (live-weight grids)
├── sd-logs/ # captured training.log / output*.log
└── generated_images/ # curated epoch-42 renders
├── src/ # Core implementation
│ ├── model.py # UNet + DDPM/DDIM schedulers (refactored)
│ ├── SD_Model.py # Legacy single-file module (kept for reproducibility)
│ ├── SD_Model_v2.py # Earlier refactor (experimental)
│ ├── SD_Model_scratch.py # Throwaway early prototype
│ ├── train.py # 2× RTX 5090 DDP + BF16 training loop (refactored)
│ ├── SD_Train.py # Training loop (the one that produced the checkpoint)
│ ├── SD_Train_v2.py # Earlier refactor
│ ├── inference.py # CUDA/MPS inference (negative prompts wired)
│ ├── SD_ImageGen.py # CUDA inference (negative prompts wired)
│ ├── generate.py # Programmatic generation API
│ ├── encode_latents.py # VAE pre-encoding → .npy
│ └── encode_pipeline.py # Data-parallel latent encoder (2-GPU)
├── data_pipeline/ # LAION-2B + DiffusionDB + JourneyDB pipeline
│ ├── 01_download_metadata.py # LAION parquet snapshot
│ ├── 01b_download_diffusiondb.py # DiffusionDB → 512×512 tar shards
│ ├── 01c_download_journeydb_images.py # JourneyDB → 512×512 tar shards
│ ├── 02_filter_metadata.py # aesthetic / CLIP / watermark / NSFW / dedup
│ ├── 03_download_images.py # img2dataset LAION downloader
│ ├── 03_build_hf_dataset.py # DiffusionDB/JourneyDB → Arrow HF dataset
│ ├── 04_preprocess_to_cache.py # Tars → parquet (image_key + CLIP tokens)
│ ├── 05_build_hf_dataset.py # Parquet → Arrow HF dataset
│ └── 06_filter_dataset.py # Final aesthetic / dedup pass
├── configs/
│ └── config.py # Dataclass-based configuration
├── docs/ # Architecture, training, data-pipeline, inference guides
│ ├── architecture.md
│ ├── training-loop.md
│ ├── data-pipeline.md
│ ├── inference.md
│ ├── blog_post.md # (this file)
│ └── images/ # Hero collage, sample gallery, architecture diagram
├── tests/ # CPU smoke tests
│ ├── test_unet_forward.py
│ └── test_ddim_step.py
├── scripts/
│ └── download_checkpoint.py # Hugging Face Hub checkpoint downloader
├── results/ # Training artifacts
│ ├── samples/ # Curated epoch-42 renders
│ ├── loss_curve.csv
│ └── training_status.md
├── assets/ # Architecture diagram, plots
├── sd_epoch_042.pt # Released checkpoint (hosted on HF Hub, not in repo)
├── LICENSE # MIT
├── CITATION.cff
├── README.md
├── requirements.txt
├── .github/workflows/ci.yml # Lint + smoke-test on every push
└── .env.example
```

The released checkpoint lives at
[`atandra2000/sd-from-scratch-v1`](https://huggingface.co/atandra2000/sd-from-scratch-v1) on
the Hugging Face Hub — download it with `python scripts/download_checkpoint.py`.

---

## Final Thoughts
@@ -621,7 +648,7 @@ Training a Stable Diffusion model from scratch was one of the most rewarding eng
- knowing which optimisations compose and which ones explode when you stack them,
- and trusting visual validation over a loss number that lies to you.

If you're considering this: do it. But go in with your eyes open. The transformers will be fine. It's the JPEG decoder, the pinned buffer, the EMA decay and the one clamp at line 736 that will decide whether your model converges.
If you're considering this: do it. But go in with your eyes open. The transformers will be fine. It's the JPEG decoder, the pinned buffer, the EMA decay and the latent-space distribution mismatch that will decide whether your model converges.

---
