Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Source/MLXNN/PositionalEncoding.swift
Original file line number Diff line number Diff line change
Expand Up @@ -168,7 +168,7 @@ final public class ALiBi: Module {
}

let x1 = MLXArray(key.offset ..< key.qSequenceLength).expandedDimensions(axis: 1)
let x2 = MLXArray(0 ..< key.kSequenceLength).expandedDimensions(axis: 1)
let x2 = MLXArray(0 ..< key.kSequenceLength).expandedDimensions(axis: 0)
let distanceMatrix = -abs(expandedDimensions((x1 - x2), axes: [0, 1]))

let slope = alibiSlope(numHeads: key.numHeads)
Expand Down
41 changes: 41 additions & 0 deletions Tests/MLXTests/PositionalEncodingTests.swift
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
// Copyright © 2024 Apple Inc.

import Foundation
import MLX
import MLXNN
import XCTest

class MLXNNPositionalEncodingTests: XCTestCase {
func testALiBiMatrixIsRelativeDistance() {
// With a single head the slope is 256**-1, so the bias added to the
// attention scores is the relative-distance matrix -(|i - j|) / 256.
Stream.withNewDefaultStream(device: .cpu) {
let q = 4
let k = 4
let attentionScores = MLXArray.zeros([1, 1, q, k])
let output = ALiBi().callAsFunction(attentionScores: attentionScores)

let slope = 1.0 / 256.0
var expectedValues = [Double]()
for i in 0 ..< q {
for j in 0 ..< k {
expectedValues.append(-Double(abs(i - j)) * slope)
}
}
let expected = MLXArray(converting: expectedValues, [1, 1, q, k])

assertEqual(output, expected)
}
}

func testALiBiSupportsDifferentQueryAndKeyLengths() {
// The query and key sequence lengths differ, so the distance matrix must
// be a proper (q, k) outer difference rather than an elementwise one.
Stream.withNewDefaultStream(device: .cpu) {
let attentionScores = MLXArray.zeros([1, 1, 2, 3])
let output = ALiBi().callAsFunction(attentionScores: attentionScores)

XCTAssertEqual(output.shape, [1, 1, 2, 3])
}
}
}