Skip to content

Add MLXArray and DType extensions#429

Open
ronaldmannak wants to merge 5 commits into
ml-explore:mainfrom
PicoMLX:greatestFiniteMagnitude
Open

Add MLXArray and DType extensions#429
ronaldmannak wants to merge 5 commits into
ml-explore:mainfrom
PicoMLX:greatestFiniteMagnitude

Conversation

@ronaldmannak

Copy link
Copy Markdown
Contributor

Proposed changes

Moving MLXArray and DType extensions from MLX-Swift-LM PR 369 to MLX-Swift, see ml-explore/mlx-swift-lm#369

Checklist

  • I have read the CONTRIBUTING document
  • I have run pre-commit run --all-files to format my code / installed pre-commit prior to committing changes
  • I have added tests that prove my fix is effective or that my feature works
  • I have updated the necessary documentation (if needed)

case .bfloat16: return Float(bitPattern: 0x7F7F_0000)
default: return .greatestFiniteMagnitude
}
}

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I wonder, is this the same as:

    /// For floating point values return the floating point info, similar to `numpy.finfo`.
    public var finfo: FInfo? {
        isFloatingPoint ? FInfo(dtype: self) : nil
    }

    /// Floating point info.
    public struct FInfo: Sendable {
        public let dtype: DType

...
        /// The largest representable number
        public var max: Double {
            switch dtype {
            #if !arch(x86_64)
                case .float16: Double(Float16.greatestFiniteMagnitude)
            #else
                case .float16: 65500.0
            #endif
            case .float32: Double(Float.greatestFiniteMagnitude)
            case .bfloat16: 3.3895313892515355e+38
            case .complex64: Double.greatestFiniteMagnitude
            case .float64: Double.greatestFiniteMagnitude
            default:
                fatalError("\(dtype) is not a floating point type")
            }
        }

The Double might actually make this semi-useless though -- we can't use it in an MLXArray on the GPU.

Questions:

  • are these the same?
  • should we fix this (and maybe some of the other properties) to be Float?
  • and replace the bfloat16 with your version? (IIRC I just got that value from numpy)

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It might be nice if FInfo wasn't optional. We could return values for the ints as well, mostly.

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Anyway, see what you think. I added FInfo at one point for porting some code from vlm, but I don't think it is widely used.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You're right, we can indeed use DType.finfo. I do think it will need to stay an optional and floating point-only since integers probably don't make sense here. At least not for masking, and smallestNormal, smallestSubnormal, etc are meaningless for ints.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@davidkoski so the code in mlx-swift-lm became pretty ugly (e.g. let bound = DType.float16.finfo!.max) since the codebase rarely uses exceptions, so I've re-added greatestFiniteMagnitude but instead of duplicating the code, it's now calling finfo. greatestFiniteMagnitude does have a fatalerror in case of non-fp, but at least API is cleaner now.

I've also updated finfo with clearer magical numbers, including for float16 on x86, which had a different value set than on Apple Silicon (65500 vs 65504). If there was a reason for the difference, I can change that back

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks good -- I think the use case for this will be in cases with known dtypes so the fatalError is appropriate (programming error). finfo is optional but perhaps greatestFiniteMagnitude should document that it is for floating point types only since it isn't clear from the API.

This is why I was wondering if we could squeeze finfo to make it non-optional, but even numpy does't go that far :-)

raise ValueError(f"data type {dtype!r} not inexact")

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@davidkoski comment edited, as well as added to skill.md

Re: Double, my original implementation returned float but since finfo returned Double, I settled for Double for greatestFiniteMagnitude as well. Should I change both greatestFiniteMagnitude as well as the finfo methods to Double?

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants