With either torch.compile or Triton, the forward/backward passes currently keep too many activations, which is probably bottlenecking training.
For some reason I got about a 30% speedup at 1B scale, but it does not seem to do better at larger scales. Either way, adding a good fused linear, fused layernorm, and fused multi-head layernorm would be very helpful.
The reason I would prefer this over torch.compile is that torch.compile with max-autotune takes an eternity. :P
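To sketch what "attaching" such a kernel looks like: the fused ops in the references below wrap a hand-written forward/backward in a `torch.autograd.Function`, recomputing cheap intermediates in the backward pass instead of saving them as activations. Here is a minimal pure-PyTorch illustration of that pattern for GEGLU (a real version would dispatch to a Triton kernel inside `forward`/`backward`; the class name is illustrative, not an existing API):

```python
import math
import torch

class FusedGEGLU(torch.autograd.Function):
    """Illustrative custom-autograd pattern used by fused kernels:
    save only the inputs, recompute gelu in backward to cut activation memory."""

    @staticmethod
    def forward(ctx, gate, up):
        ctx.save_for_backward(gate, up)  # no extra activations stored
        return torch.nn.functional.gelu(gate) * up

    @staticmethod
    def backward(ctx, grad_out):
        gate, up = ctx.saved_tensors
        # exact GELU and its derivative, recomputed on the fly
        gelu_gate = torch.nn.functional.gelu(gate)
        dgelu = 0.5 * (1.0 + torch.erf(gate / math.sqrt(2.0))) \
            + gate * torch.exp(-0.5 * gate * gate) / math.sqrt(2.0 * math.pi)
        return grad_out * up * dgelu, grad_out * gelu_gate

gate = torch.randn(4, 8, dtype=torch.float64, requires_grad=True)
up = torch.randn(4, 8, dtype=torch.float64, requires_grad=True)
out = FusedGEGLU.apply(gate, up)
```

The hand-written backward can be checked against autograd with `torch.autograd.gradcheck(FusedGEGLU.apply, (gate, up))`.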
Some references
https://github.com/crowsonkb/k-diffusion/blob/21d12c91ad4550e8fcf3308ff9fe7116b3f19a08/k_diffusion/models/image_transformer_v2.py#L90
Trident has a lot of them implemented: https://github.com/kakaobrain/trident
Unsloth has many backward passes implemented: https://github.com/unslothai/unsloth/blob/main/unsloth/kernels/geglu.py
Triton code extraction https://youtu.be/LuhJEEJQgUM?t=2234