Numerical Instability in Forward Pass #112

@thalaby

Description

Hello 😁

While running example_inference.py, I observed numerical instability in the model's forward pass: the activations grow rapidly in magnitude across layers and reach extreme values, with the blow-up starting around layer 11. To track this, I added the following prints to the StripedHyena class:

    def stateful_forward(self, x, inference_params_dict=None):
        for block_idx, block in enumerate(self.blocks):
            block_name = "mha" if block_idx in self.config.attn_layer_idxs else "hyena"
            inference_params = inference_params_dict[block_name]
            x, _ = block(x, inference_params=inference_params)
            # Added code
            print(f"Block {block_idx} stats - min: {x.min().item():.3f}, max: {x.max().item():.3f}, mean: {x.mean().item():.3f}")
        return x, inference_params_dict

    def stateless_forward(self, x, padding_mask=None):
        if type(padding_mask) == torch.Tensor:
            x = x * padding_mask[..., None]

        for block_idx, block in enumerate(self.blocks):
            x, _ = block(x, inference_params=None, padding_mask=padding_mask)
            # Added code
            print(f"Block {block_idx} stats - min: {x.min().item():.3f}, max: {x.max().item():.3f}, mean: {x.mean().item():.3f}")
        return x, None
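
A less invasive way to collect the same per-block statistics would be PyTorch forward hooks, which avoid editing the class itself. The following is only a minimal sketch: `attach_stats_hooks` and the toy `Double` module are hypothetical names introduced for illustration, and the sketch assumes each block returns a tuple whose first element is the activation tensor, as in the snippet above.

```python
import torch
import torch.nn as nn


def attach_stats_hooks(blocks, log):
    """Register a forward hook on each block that records output statistics.

    Hypothetical helper, not part of the StripedHyena code base.
    """
    handles = []
    for block_idx, block in enumerate(blocks):
        def hook(module, inputs, output, block_idx=block_idx):
            # Assumption: a block returns (x, state); fall back to a bare tensor.
            x = output[0] if isinstance(output, tuple) else output
            x = x.detach().float()  # upcast so the statistics are computed in float32
            log.append({
                "block": block_idx,
                "min": x.min().item(),
                "max": x.max().item(),
                "mean": x.mean().item(),
                "finite": bool(torch.isfinite(x).all()),
            })
        handles.append(block.register_forward_hook(hook))
    return handles


class Double(nn.Module):
    """Toy stand-in for a model block: doubles its input and returns (x, None)."""

    def forward(self, x, inference_params=None):
        return 2 * x, None


blocks = nn.ModuleList([Double() for _ in range(3)])
log = []
handles = attach_stats_hooks(blocks, log)

x = torch.ones(2, 4)
with torch.no_grad():
    for block in blocks:
        x, _ = block(x, inference_params=None)

# Remove the hooks once the statistics have been collected.
for handle in handles:
    handle.remove()

for entry in log:
    print(entry)
```

With the real model, the same helper would presumably be called on the module list that `stateful_forward` iterates over (`self.blocks` in the snippet above). The `finite` flag is an extra check that separates very large but finite values from inf/NaN.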

After running example_inference.py I got the following results:

Block 0 stats - min: -1.430, max: 1.688, mean: 0.000
Block 1 stats - min: -51.000, max: 54.500, mean: -0.003
Block 2 stats - min: -60.250, max: 74.000, mean: -0.007
Block 3 stats - min: -69.500, max: 88.000, mean: -0.009
Block 4 stats - min: -73.500, max: 93.000, mean: -0.009
Block 5 stats - min: -78.000, max: 99.000, mean: -0.012
Block 6 stats - min: -91.500, max: 115.500, mean: -0.015
Block 7 stats - min: -140.000, max: 180.000, mean: -0.008
Block 8 stats - min: -75.000, max: 79.500, mean: 0.004
Block 9 stats - min: -72.000, max: 88.500, mean: -0.011
Block 10 stats - min: -752.000, max: 724.000, mean: -0.586
Block 11 stats - min: -464896.000, max: 462848.000, mean: -130.000
Block 12 stats - min: -14548992.000, max: 18743296.000, mean: -18304.000
Block 13 stats - min: -394264576.000, max: 362807296.000, mean: 55040.000
Block 14 stats - min: -394264576.000, max: 362807296.000, mean: 55040.000
Block 15 stats - min: -394264576.000, max: 362807296.000, mean: 55040.000
Block 16 stats - min: -394264576.000, max: 362807296.000, mean: 55040.000
Block 17 stats - min: -394264576.000, max: 362807296.000, mean: 55040.000
Block 18 stats - min: -394264576.000, max: 362807296.000, mean: 55040.000
Block 19 stats - min: -394264576.000, max: 362807296.000, mean: 55040.000
Block 20 stats - min: -394264576.000, max: 362807296.000, mean: 55040.000
Block 21 stats - min: -394264576.000, max: 362807296.000, mean: 55040.000
Block 22 stats - min: -394264576.000, max: 362807296.000, mean: 55040.000
Block 23 stats - min: -394264576.000, max: 362807296.000, mean: 55040.000
Block 24 stats - min: -394264576.000, max: 362807296.000, mean: 55040.000
Block 25 stats - min: -394264576.000, max: 362807296.000, mean: 55040.000
Block 26 stats - min: -394264576.000, max: 362807296.000, mean: 55040.000
Block 27 stats - min: -394264576.000, max: 362807296.000, mean: 55040.000
Block 28 stats - min: -394264576.000, max: 362807296.000, mean: 55040.000
Block 29 stats - min: -394264576.000, max: 362807296.000, mean: 55040.000
Block 30 stats - min: -394264576.000, max: 362807296.000, mean: 55040.000
Block 31 stats - min: -394264576.000, max: 362807296.000, mean: 55040.000

Issues

  • The activations explode in magnitude as they propagate through the layers.
  • By block 10, values reach ±750, and by block 11, they explode to ±460,000.
  • From block 13 onward, the activations span roughly -394 million to +363 million, and the printed statistics are identical for every block from 13 through 31.
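
One possible reading of the identical statistics from block 13 onward, offered as a hypothesis rather than a confirmed diagnosis: if the activations are stored in bfloat16, values of this size are so coarsely spaced that moderate updates added to them are rounded away. The issue does not state the dtype, but the printed extremes are exactly representable in bfloat16 (for example, 362807296 = 173 × 2^21). The sketch below emulates bfloat16 rounding in plain Python; `to_bfloat16`, `bf16_add`, and `bf16_spacing` are hypothetical helpers written for this illustration.

```python
import math
import struct


def to_bfloat16(value: float) -> float:
    """Round a finite float to the nearest bfloat16 value (ties to even)."""
    bits = struct.unpack("<I", struct.pack("<f", value))[0]
    # bfloat16 keeps the top 16 bits of the float32 bit pattern.
    rounding_bias = 0x7FFF + ((bits >> 16) & 1)
    bits = ((bits + rounding_bias) >> 16) << 16
    return struct.unpack("<f", struct.pack("<I", bits & 0xFFFFFFFF))[0]


def bf16_add(a: float, b: float) -> float:
    """Add two numbers, rounding the inputs and the result to bfloat16."""
    return to_bfloat16(to_bfloat16(a) + to_bfloat16(b))


def bf16_spacing(value: float) -> float:
    """Gap between adjacent bfloat16 values near `value` (8 significant bits)."""
    _, exponent = math.frexp(abs(value))
    return math.ldexp(1.0, exponent - 8)


big = 362807296.0    # max reported from block 13 onward
small = 88.0         # max reported at block 3

print(to_bfloat16(big) == big)   # True: the value is exactly representable
print(bf16_spacing(big))         # 2097152.0
print(bf16_add(big, 100.0))      # 362807296.0, the update is absorbed
print(bf16_add(small, 1.0))      # 89.0, the update is visible
```

Under this assumption, any update smaller than about a million would leave the largest entries unchanged, which would be consistent with the min and max staying fixed. It does not by itself explain why the values grow so sharply at blocks 10 to 12.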

Despite these values, I still got a good inference result at the end (image attached to the original issue).

Is this what is supposed to happen, or am I missing something?

Thank you!
